داستان ترنسفورمرها (۶): داستان XLNet

نبردی تاریخی میان مدل‌های autoregressive و autoencoder. برگرفته از فیلم The Great Battle
نبردی تاریخی میان مدل‌های autoregressive و autoencoder. برگرفته از فیلم The Great Battle


اگر در دنیای یادگیری عمیق به خصوص در حوزه پردازش زبان طبیعی سیر و سلوک داشتید قطعا اسم برت به گوش‌تون خورده. ظهور برت یک انقلاب عظیم در حوزه پردازش زبان طبیعی بود و ما هم در قسمت چهارم از سریال داستان ترنسفورمر‌ها اون رو براتون روایت کردیم. پس از گذر از پیچ‌های تاریخی در حوزه یادگیری مدل زبانی حالا در این قسمت نوبت به XLNet رسیده. مدلی که از برت الهام گرفته و با نگاهی دیگه تونسته عملکرد برت رو تا حد خوبی بهبود بده. در واقع، مدل‌هایی مثل برت که از توکن‌های masked استفاده می‌کنند به صورت ضمنی ارتباط بین این توکن‌ها رو در نظر نمی‌گیرند و همین نکته می‌تونه باعث ضعف عملکرد مدل‌هایی مثل برت بشه. اگر نمی‌دونید سر این قصه از کجا شروع شده، به جهت آشنایی با مشکلات RNNها، مکانیزم توجه و معماری ترنسفورمرها می‌تونید قسمت‌های قبلی این سریال رو بخونید.

انواع مدل‌های زبانی؛ نبرد بین autoregressive و autoencoder

مدل‌های زبانی رو بر اساس نحوه pre-train می‌تونیم به دو دسته Autoregressive و Autoencoder تقسیم کنیم. مدل‌های زبانی که به صورت autoregressive پیش‌آموزش شبکه رو انجام می‌دهند، در واقع توزیع یک کلمه به ازای کلمات قبلی و یا کلمات بعدی رو حدس می‌زنند. همین یه طرفه به قاضی رفتن باعث ضعیف‌تر شدن مدل میشه. یعنی مدل‌های AR یک نگاه یه‌طرفه به مدل زبانی دارند در حالیکه مدل‌های مبتنی بر autoencoder سعی کردند به نوعی نگاه دوطرفه به مدل زبانی ایجاد کنند. مثلا مدل BERT که معرف حضور همه هست در دسته autoencoderها قرار می‌گیره. در BERT با استفاده از توکن MASK و جایگزینی اون با برخی توکن‌ها در فاز پیش‌آموزش و اجبار مدل به حدس‌زدن کلمات masked شده، در واقع به نوعی اثر نگاه دوطرفه در مدل ایجاد میشه. عملا در تسک‌های downstream در پردازش زبان طبیعی به این نگاه دوطرفه نیاز داریم. مثلا انسان‌ها در ترجمه زبان وقتی می‌خوان یه جمله رو ترجمه کنند به تمامی جمله توجه می‌کنند و ممکنه برخی از کلمات انتهای جمله زبان مبدا رو در ابتدای جمله زبان مقصد بیارند. همین نگاه باعث میشه مدل‌های autoregressive به مشکل بخورند و در ظاهر مدلی مثل BERT پیروز میدان قلمداد بشه!

شکست BERT؛ XLNet برمی‌خیزد

اما مدل‌هایی مثل BERT هم خالی از اشکال نیستند. به عنوان ایراد اول، وقتی برخی از توکن‌ها با توکن MASK جایگزین می‌شود و مدل سعی می‌کنه این توکن‌های masked رو تشخیص بده، به صورت ضمنی ارتباط بین این توکن‌های maskشده در نظر گرفته نمیشه. در واقع مدل برت نمی‌تونه joint probability بین این توکن‌ها رو در نظر بگیره چرا که در هر لحظه به دنبال حدس‌زدن یکی از این کلمات maskشده است و عملا ممکنه در اون لحظه هنوز بخش عمده‌ای از این کلمات masked باقی مونده باشند و به صورت ضمنی فرض می‌کنه که توکن‌های maskشده از هم مستقل هستند. برای مثال جمله "محمدرضا [MASK] در فیلم [MASK] بازی کرده است" را در نظر بگیرید. این که مدل انتخاب کنه برای ماسک اول فروتن، گلزار یا شریفی نیا رو انتخاب کنه بر روی انتخاب هاش برای ماسک دوم نیز باید اثر بذاره.

همچنین به عنوان ایراد دوم می‌تونیم به موضوع ناسازگاری بین فاز pretraining و finetuning اشاره کنیم. چرا که عملا توکنی به نام MASK در دیتای واقعی وجود نداره. به‌خاطر وجود توکن MASK عملا معماری برت پیچیده‌تر شده که هم بتونه تسک pretrain رو به‌خوبی هندل کنه و هم تسک fine tune رو. به این علت که توکن mask در دیتای فاین‌تیون وجود نداره ممکنه به خاطر همین maskکردن‌ها بخشی از اطلاعات از دست بره یا ممکنه نویز وارد فرآیند یادگیری بشه. برای حل این مشکلات XLNet وارد میدون میشه. مدلی که ذاتا autoregressive است اما با اتخاذ تکنیک permutation سعی کرده اثر نگاه یک‌طرفه در این نوع مدل‌ها رو از بین ببره از طرفی به خاطر ذات autoregressive بودنش دیگه نگران ناسازگاری بین فاز‌های pre-training و fine tuning نیست. همچنین برای اینکه بتونه متن‌های با طول بلندتر رو هم ساپورت کنه از Transformr-XL الهام گرفته و در نتیجه نشون داده که تونسته BERT رو شکست بده. در تصویر زیر می‌تونید تفاوت XLNet و BERT رو مشاهده کنید.

در این مثال فرض شده که جمله ورودی [New, York, is, a, city] است. همچنین توکن‌های هدف هم [New, york] هستند. با توجه به این مثال، مدل BERT به هر یک از کلمات New و York به صورت جداگانه نگاه می‌کنه در حالیکه مدل XLNet می‌تونه ارتباط بین New و York رو در نظر بگیره. بنابراین می‌تونه نتایج بهتری رو با در نظر گرفتن ارتبط بین دو توکن new و york تولید کنه.

مکانیزم Permutation

مدل XLNet برای فاز pre-training از Permutation Language Modeling استفاده می‌کنه که همون‌طور که از اسمش پیداست ایده اصلی اون درباره جایگشت‌های متفاوت از جمله ورودیه. فرض کنید که یک جمله ورودی به شکل [x1, x2, x3, x4] داریم. در این صورت تعداد کل جایگشت‌ها برابر با ۴! یا همون ۲۴ است. فرض کنید توکنی که باید مدل اون رو حدس بزنه همون x3 باشه. تابع هدف به صورت کلی در این نوع مدل زبانی به صورت زیره:

در این فرمول مقدار Z_T برابر با مجموعه کل جایگشت‌های دنباله به طول T است. همچنین x_z_t بیانگر توکن tام از جایگشت z است و x_z<t برابر با تمامی توکن‌های قبل از اندیس t است. با توجه به فرمول بالا لگاریتم احتمال رخداد توکنی که مدل باید حدس بزنه به شرط توکن‌های قبلی حساب میشه و امیدریاضی بر روی تمامی جایگشت‌ها حساب میشه. اما در عمل به دلیل اینکه تعداد جایگشت‌ها می‌تونه خیلی زیاد باشه فقط بعضی از اون‌ها به صورت رندم انتخاب می‌شوند، به صورتی که توکن هدف حتما در مجموعه جایگشت‌های انتخابی، در تمامی ایندکس‌های اول تا آخر جمله ظاهر شده باشه.

جایگشت‌های مورد قبول برای حدس زدن توکن x3
جایگشت‌های مورد قبول برای حدس زدن توکن x3

مکانیزم Attention Mask

با توضیحات بالا احتمالا متوجه شده باشید که پیاده‌سازی تابع هدف معرفی‌شده با استفاده از transformer یک چالش اصلی داره. فرض کنید که جمله ورودی برابر با x=[This, is, a, sentence] باشد. در این صورت مدل احتمال Pr(This|is) را همسان با Pr(This|a) می‌تونه ببینه در حالیکه می‌دونیم پوزیشن کلمات is و a متفاوته و همین موضوع می‌تونه بر احتمال محاسبه شده تاثیر بذاره. در واقع نیاز داریم تا اطلاعات پوزیشن کلمات context (یعنی کلماتی که به شرط وجود اون‌ها احتمال کلمه هدف رو مشخص می‌کنیم) رو داشته باشیم. خوشبختانه معماری پایه ترنسفومر به صورت پیش‌فرض این مشکل رو حل می‌کنه. در این معماری اطلاعات پوزیشن توکن‌ها با بازنمایی هر توکن ترکیب می‌شه و اینطوری اثر پوزیشن هر توکن context رو می‌تونیم در محاسبه احتمال‌ها ببینیم. اما از اونجایی که در XLNet مکانیزم جایگشتی داریم پس ترتیب کلمات در جمله می‌تونه بهم بخوره. از اینجاست که دیگه مکانیزم position embedding به‌تنهایی به‌درد نمیخوره. XLNet برای پیاده‌سازی مکانیزم جایگشتی از مفهوم attention mask استفاده می‌کنه. به عبارت دیگه مدل همیشه ترتیب اصلی کلمات رو حفظ می‌کنه و فقط کلماتی رو که نباید در نظر بگیره در محاسبه بردارهای attention ازشون چشم‌پوشی می‌کنه. مثلا فرض کنید بخوایم جایگشت [a, is, sentence, This] رو به تابع هدف مدل بدیم. در این‌صورت برای حدس زدن اولین کلمه از این جایگشت (یعنی a) هیچ کانتکستی رو لازم نداریم بنابراین بردار attention mask برابر با [0,0,0,0] خواهد بود. اگر بخوایم کلمه is رو حدس بزنیم، در جایگشت موردنظر فقط کلمه a قبل از اون اومده پس بردار attention mask برابر با [0,0,1,0] میشه. چرا که در جمله اصلی کلمه a در جایگاه سوم قرار داره. به همین شکل ماتریس attention mask ساخته می‌شه که به‌صورت زیر خواهد بود.

از نگاه دیگه تابع هدف می‌تونه شامل موارد زیر باشه. در این تصویر کلماتی که در نظر گرفته نمیشه با زیرخط مشخص شده:

مکانیزم Two-Stream Self-Attention

اما با توجه به توضیحات داده شده هنوز یه چالش دیگه باقی‌مونده. در واقع ما علاوه بر اینکه نیاز داریم تا اندیس توکن‌های context رو بدونیم، نیاز داریم تا اندیس توکن هدف رو هم داشته باشیم. به عبارت دیگه دنبال Pr(This|1, is+2) هستیم. اما همون‌طور که می‌دونید اینجا دیگه معماری ترنسفومر نمی‌تونه به‌ما کمکی بکنه چرا که پوزیشن کلمه This رو همراه با بازنمایی خود کلمه This ایجاد می‌کنه. به عبارت دیگه ترنسفومر اصلی مقدار Pr(This|This+1, is+2) رو برمی‌گردونه که طبیعتا ما نمی‌تونیم بازنمایی This رو داشته باشیم چرا که کلمه‌ای هست که می‌خوایم اون رو حدس بزنیم. برای این مشکل، از two-stream self-attention استفاده می‌کنیم. هر توکن در هر لایه self-attention دو تا وکتور داره. یک وکتور h که مربوط به content stream میشه و یک بردار g که مربوط به query stream است. بردارهای content stream همانند بردارهای attention عادی در شبکه ترنسفورمر هستند و با بازنمایی هر توکن به علاوه بازنمایی پوزیشن هر توکن مقداردهی اولیه می‌شوند. اما بردارهای query stream با بازنمایی یک کلمه مشخص مانند w به علاوه بازنمایی پوزیشن هر توکن مقداردهی اولیه می‌شوند. در واقع چون در بردارهای query stream صرفا از بازنمایی یک کلمه ثابت استفاده شده، اثر استفاده از توکن mask رو در برت داره. مکانیزم به‌روزرسانی هر یک از این stream ها به صورت جداگانه انجام میشه. به‌طوری که برای به‌روزرسانی h_i وکتور‌های content stream که unmasked هستند به علاوه خود h_i استفاده می‌شوند. مثلا برای به‌روزرسانی کلمه a با توجه به جایگشتی که در قسمت قبل مثال زدیم، بردار mask به صورت [0, 0, 1, 0] استفاده می‌شود. چرا که attention mask این کلمه برابر با [0, 0, 0, 0] بود و اندیس اصلی خود این کلمه نیز ۳ است. یا مثلا برای کلمه is بردار mask به صورت [0, 1, 1, 0] خواهد بود. این پروسه از بردارهای content stream به عنوان query, key, value استفاده می‌کند. از طرف دیگر، هر g_i با استفاده از بردارهای content vector بر اساس attention mask و خود g_i به‌روزرسانی میشه. مثلا برای به‌روزرسانی بردار g_4 که مربوط به کلمه sentence است باید content vector دو کلمه is و a به همراه خود g_4 استفاده شود چرا که attention mask آن برابر با [0, 1, 1, 0] است. در تصویر زیر نحوه به‌روزرسانی content vector ها و query vector به ازای کلمه چهارم یعنی sentence را مشاهده می‌کنید.

با توجه به توضیحات بالا، تابع هدف شامل مقادیر زیر می‌شود که در این تصویر علامت * به معنای این است که اطلاعات پوزیشن آن توکن در نظر گرفته شده است.

برای اینکه این موفقیت عظیم رو بیشتر درک کنید، در زیر می‌تونید جدول مقایسه نتایج XLNet رو با BERT ببینید.

یه نکته مهم اینه که علی‌رغم بهبود همه‌جانبه از سمت XLNet در تسک‌های معروف، اما قدرت این مدل نسبت به برت در تسک‌هایی که به‌نوعی generative هستند بیشتر مشهوده. درواقع شبکه برت در تسک‌های تولید زبان ضعف بیشتری از خودش نشون میده و دلیلش هم تقریبا مشخصه. چون وقتی می‌خوایم متنی رو تولید کنیم طبیعتا کلمات ماقبل رو فقط دیدیم و همین فرض باعث میشه که AR language modelها قدرت بیشتری در تسک‌های تولید زبانی داشته باشند.

جمع‌بندی

مدل XLNet یه روش برای پیش‌آموزش مدل زبانیه که با استفاده از permutation language modeling سعی داره هم از خوبی‌های AR language modelingها استفاده کنه و هم از AE language modelها. معماری این مدل طوری طراحی شده که با استفاده از مکانیزم Two-Stream Self-Attention بتونه بر کمبود‌های معماری استاندارد ترنسفورمر برای پیاده‌سازی این تابع هدف غلبه کنه و تونسته نتایج بهتری نسبت به BERT بر روی خیلی از تسک‌های استاندارد پردازش زبان به‌دست بیاره.

منابع

[1]: XLNet: Generalized Autoregressive Pretraining for Language Understanding

[2]: Understanding XLNet

[3]: What is XLNet and why it outperformes BERT