متد gather در pytorch

نیاز مادر اختراع است

اگر با پایتورچ آشنا هستید (که انشالله هستید چون لازمه فهمیدن این پست دانستن pytorch است ) احتمالا با دستورهای مربوط به indexing و slicing هم آشنایی دارید به فرض مثال تنسور زیر رو در نظر بگیریم:

فرض کنید از شما می‌پرسند چطوری میشه محتویات باکس‌های قرمز و سبز رو بدون استفاده از for به دست آورید. احتمالا شما هم به سادگی می‌گویید x[1,:] و x[:, 7] که البته کاملا هم درست هستند. اما حالا اگر شما در صورت مساله ای قرار گرفتید که بایستی از تنسور زیر مقادیر زیر رو بیرون بکشید و تحت یک تنسور خروجی بدید چه کار میکنید؟ (اگر در جوابتون for وجود داره این نکته رو بایستی متذکر شوم که استفاده از for در کدهای یادگیری عمیق یک چیزی تو مایه‌های استفاده از goto در زبان‌های برنامه نویسی است، هم فلسفه یادگیری عمیق که موازی کاری هست رو زیر سوال می‌برید هم کلی فحش برای خودتون به رایگان میخرید!)

چگونه از gather استفاده کنیم ۱۰۱

در صورتی که دستور gather رو نمی‌شناسید و تمایلی هم ندارید که با کثیف کاری کار رو پیش ببرید و میل به یادگیری چیز جدیدی دارید با ما در ادامه همراه باشید!

مساله بالا رو میشه به این صورت فرموله کرد که ما یک تنسور ورودی داریم که میخوایم روی عناصر بعد صفرم این تنسور به نوعی iterate کنیم و از هر کدوم این عناصر، عنصرهای در جای مشخصی رو جمع‌‌آوری کنیم. متناظر با این مساله، دستور gather در پایتورچ هم سه ورودی می‌گیره، input یا تنسور ورودی، dim یا همون بعدی که میخوایم در امتداد اون داده‌های مورد نظرمون رو جمع آوری و جدا کنیم و در نهایت index که تنسوری است که وظیفه‌اش مشخص کردن اینه که دقیقا کدوم عناصر رو میخوایم برداریم. نکته مهمی که بایستی توجه بشه اینه که تعداد ابعاد تنسور input با index بایستی یکسان باشند. از طرفی در صورتی که مثلا ابعاد تنسور ورودی شما 30*20*10 باشه و مثلا dim هم برابر با صفر باشه، ابعاد تنسور index باید به فرم N*20*30 باشه.

فرض کنید که برای مثال این تنسورهای ما دو بعدی (مثل قسمت قبل) باشند، در این صورت خروجی متد gather رو اگر بخوایم خیلی فرموله شده و ریاضی نشون بدیم این شکلی میشه:

out[i][j] = input[index[i][j]][j]  # if dim == 0
out[i][j] = input[i][index[i][j]]  # if dim == 1

اما این که نوشتیم یعنی چه؟ فرض کنیم که dim برابر با یک باشه. در این صورت حاصل gather این میشه که مشخص میکنه از هر سطری مقدارهای توی چه ستون‌هاییش باید انتخاب شوند و بالعکس اگر dim برابر با صفر باشه gather مشخص میکنه که از هر ستونی مقدارهای توی چه سطریش باید انتخاب بشوند. برگردیم به مثال قسمت قبل خودمون و بخواهیم حلش کنیم، بایستی همچین کدی رو بنویسیم (من دیگه از همون بیخ ابتدای قضیه کدش رو نوشتم که راحت کپی پیست کنید)

x = torch.arange(50).reshape(5,10)
torch.gather(x, 1, torch.tensor([[2, 6], [1, 3], [3, 5], [1, 6], [8, 9]]))

با اجرای کدهای بالا نتیجه زیر حاصل میشه:

به همین سادگی به همین خوشمزگی!

که عشق آسان نمود اول ولی افتاد مشکل‌ها

این چیزی که در قسمت قبل دیدیم یک نمونه دو بعدی بود. اما اگر مساله ما سه بعدی بود چی؟ مثالش میشه وقتی که مثلا شما در یک کیس پردازش زبانی یک ماتریس سه بعدی از کلمات دارید که اندازه بعد‌ها به ترتیب سایز بچ، اندازه طول جمله و اندازه بازنمایی یک هر یک از کلمات جمله هستند. حالا شما در موقعیتی قرار دارید که بایستی از هر جمله کلمات خاصی رو با بردار بازنمایی شون بیرون بکشید. انجام دادن gather روی تنسور سه بعدی کمی سخت و تریکی هستش اما نگران نباشید انشالله که موشکافانه می‌فهمیمش. برای فهم این مساله سعی میکنیم از طریق مثال عملی روشنش کنیم. اول از همه بیایید و یک تنسور ورودی مثالی بسازیم (خرده مگیرید چرا for استفاده کرده‌ایم، از قصد استفاده کردیم که بتونیم تنسوری بسازیم که جلوتر نتیجه gather روی این تنسور قابل دنبال کردن و سنجیدن باشه)

batch_size = 4
max_seq_len = 9
hidden_size = 8
x = torch.empty(batch_size, max_seq_len, hidden_size)
    for i in range(batch_size):
        for j in range(max_seq_len):
            for k in range(hidden_size):
                x[i,j,k] = i + j*10 + k*100

ما یک تنسور فرضی رو ساختیم که سایز بچ‌اش چهار هست و هر جمله اون نه کلمه داره و اندازه بردار بازنمایی هر کلمه هم هشت هست. مقادیر این تنسور رو هم از قصد جوری مقداردهی کردیم که بشه بعدا چک کرد آیا درست برداشتیم یا نه. حالا فرض کنید بخوایم از جمله صفر کلمات یک و پنج، از جمله یک کلمات دو و چهار، از جمله دو کلمات یک و هفت و از جمله چهارم هم کلمات شش و هشت رو برداریم. از اونجایی که از هر جمله دو کلمه رو انتخاب کردیم بنا به قاعده قسمت قبل بایستی اندازه تنسور index اینجا برابر با 8*2*4 باشه و dim رو هم برابر با یک قرار بدیم (چون میخوایم از بعد کلمات انتخاب کنیم). خب اگر مساله ما دو بعدی بود ما ‌میتونستیم تنسور ایندکس زیر رو تشکیل بدیم و به راحتی عملیات gather رو انجام بدیم:

token_indexes = torch.LongTensor([[1,5],
                                   [2,4],
                                   [1,7],
                                   [6,8]])

درد اما اینجاست که این تنسور به درد ما نمیخوره. چرا؟ چون که ما باید عوض این تنسور، یک تنسور سه بعدی داشته باشیم که به ازای تک تک عناصر بعد سوم یعنی بعد بردار بازنمایی هم مشخص کرده باشه که کدوم کلمات رو برداریم. از اونجایی که ما برای یک کلمه تمامی ابعاد بردار بازنمایی‌اش رو برمیداریم یا اگر از اون وری فکر کنیم برای تمام ابعاد بردار بازنمایی یک کلمه مشخص رو برمیداریم (این جوری نیست که چهارتای اول بردار بازنمایی رو از کلمه اول برداریم چهارتای بعدی رو از یک کلمه دیگه). بنا به این خاصیت پس می‌تونیم این تنسور index مون رو برایش یک بعد سومی بسازیم و محتوای این تنسور را در بعد سوم repeat کنیم (این قدر repeat کنیم تا تمامی عناصر بردار بازنمایی مشخص بشوند و سایز تنسور index بشود 8*2*4) بنابراین کد زیر رو اجرا میکنیم.

indices = token_indexes.repeat(1,8).reshape(4, -1,2).transpose(2,1)
print(indices.shape) ## torch.Size([4, 2, 8])

گیج شدید؟ برای این که کمی شفاف سازی داشته باشیم قبل از رفتن به گام قبل، یک لحظه این indices رو با هم چک کنیم:

print(indices[0,:,0]) # tensor([1, 5])
print(indices[0,0,:]) # tensor([1, 1, 1, 1, 1, 1, 1, 1])
print(indices[:,0,0]) # tensor([1, 2, 1, 6])

خط اول در واقع بیان میکنه که عنصر صفر بردار بازنمایی چه کلماتی از جمله صفر رو میخواد جمع آوری کنه که جواب کلمات یک و پنج است. خط دوم میگه که عناصر بردار بازنمایی کلمه جمع آوری شده صفر جمله صفر دقیقا کدوم کلمه رو از اون جمله برمیدارند که جواب همگی یک است. و خط آخر هم میگه عنصر اول بردار بازنمایی کدوم کلمات جمع آوری شده صفر جملات رو جمع آوری میکنند که جواب یک و دو و یک و شش هست (شاید بهتر باشه این تیکه رو خودتون بیشتر تست کنید که بهتر متوجه بشید)

در نهایت می‌ریم که ببینم جواب نهایی عملیات ما چه شکلی شد:

torch.gather(x,1,indices)

همین طور که مشاهده می‌کنید تنسور آماده صرفه. نوش جان :)

پ.ن.۱: بار دیگر تاکید میشه که تمامی این دردسرها و متدهای عجیب و غریب برای اینه که از به کار بردن حلقه در کدهای یادگیری عمیق به شدت پرهیز بشه. اینجا مثال‌ها اندازه‌هاشون کوچک بودند ولی در کاربرد واقعی وقتی مثلا سایز بردارنهان ۷۶۸ باشه یا تعداد کلماتتون به ۵۱۲ برسه انجام for در مقایسه با توابعی مثل gather مثل مقایسه دویدن لاکپشت و یوزپلنگ است. پس اگر قصد خبره شدن در یادگیری عمیق رو داریم به هر for ای که به کار میبریم بایستی به چشم عنود و عداوت نگریسته بشه و در صورت ممکن سعی بر پاکسازی کد از حلقه‌ها بشه.

پ.ن.۲: به کانال ما سر بزنید!