سلام !
اگر سری آموزش قبلی تحت عنوان " شروع یک مسئله ماشین لرنینگ با داده های کم " رو دنبال کرده باشید به یاد دارید که این متد را به عنوان یک روش برای حل مسئله ی few shot learning یا همان fsl معرفی کردیم.
در سری آموزش ذکر شده یک طبقه بندی اصلی برای این موضوع مطرح کردیم : دیتا ، مدل و الگوریتم.
متد Proto Net در طبقه بندی Model و در دسته ی Embedding models و هم چنین از نوع Task-invariant Embedding Model بود که در این دسته بندی ، Matching Net ,Relation Net و ... هم قرار دارند که در آموزش های دیگر در مورد آنها هم بحث خواهد شد.
- تعریف ساده و کلی :
فرض کنید یک مسئله ی فیو شات به صورت 3way-5shot داریم، یعنی سه کلاس و ۵ شات در هر کلاس .
فرض کنید که داده های ما ( دایره های سبز و آبی و نارنجی ) داده های با ابعاد بالا بوده اند که به وسیله ی یک embedding function در این فضای دکارتی دوبعدی نمایش داده شده اند. ما این داده ها را Support set می نامیم.
برای هر دسته که با یک رنگ متفاوت نشان داده شده ، چنانچه میانگین داده ها را بگیریم به یک نقطه خواهیم رسید که نماینده ی کلاس مربوطه است و آنرا نقطه ی prototype گوییم. در تصویر بالا C1,C2,C3 هرکدام پروتوتایپ کلاس خود هستند .
حالا یک داده ی بدون لیبل که میخواهیم لیبل آن را پیش بینی کنیم به نام X با همان embedding function ذکر شده در قبل ، به مجموعه اضافه میکنیم. این داده را جزو دسته ی query set مینامیم. کاری که ما باید انجام دهیم محاسبه ی فاصله ی اقلیدسی این داده با هر یک از پروتوتایپ های بالا و انتخاب لیبلی است که داده ی X نزدیک ترین فاصله را با آنها داشته است . به این منظور فاصله ها را با فانکشن d پیدا میکنیم و سپس روی این فاصله ها یک softmax پیاده میکنیم تا احتمال تعلق این داده به هر یک از دسته ها را به ما بدهد.( در اینجا خروجی softmax برای دسته ی c2 بیشترین میزان است .
آنچه در بالا توضیح داده شد ،کلیت کاری بود که انجام میشود . حالا سوالاتی که ممکن است در ذهن شما پیش بیاید این ها هستند :
- نحوه ی پیدا کردن prototype چگونه است ؟
برای پیدا کردن پروتوتایپ از فرمول زیر استفاده میکنیم :
نترسید ! Ck همان پروتایپ ها هستند. k نشان دهنده ی شماره کلاس هست ( مثلا اگر سه کلاس داشته باشیم k=1,2,3 است ) . Sk به مجموعه ی همه ی سمپل ها ی support set گفته میشود که به صورت جفتی یعنی خود سمپل و لیبل آن درونش قرار گرفته اند .
تابع fφ هم اصلا چیز ترسناکی نیست !همون فانکشن embedding هست که میاد داده های با ابعاد بالاتر رو به ابعاد پایین تر تبدیل میکنه.
خب پس Ck به این صورت به دست میاد که جمع همه ی شات های ما ( البته پس از embedded شدن ) که با (fφ(Xi نمایش میدیم را در یک کلاس Sk ( مثلا کلاس با ۵ تا شات ) به دست میاره و تقسیم بر اندازه ی Sk میکنه که اندازه ی Sk هم همون تعداد شات ها هست. به این صورت میانگین گیری میشه و پروتوتایپ Ck بدست میاد .
- نحوه ی انجام embedding چگونه است ؟ embedding function چگونه کار میکند ؟
برای انجام عمل embedding الگورتیم های مختلفی وجود دارد که توضیح آن خارج از بحث فعلی است.دو منبع زیر برای مطالعه ی بیشتر پیشنهاد میشود :
- نحوه ی پیدا کردن فاصله ی بین داده X و هر کدام از پروتوتایپ ها چیست ؟
گفتیم که وقتی فاصله ها را با تابع d محاسبه کردیم ، با استفاده از یک softmax یک توزیع بین ۰ تا ۱ به دست می آوریم . فرمول softmax به صورت زیر بود :
حالا برای فاصله هایی که ما داریم ، همین فرمول را به صورت زیر بازنویسی میکنیم :
این فرمول به این معنا است که احتمال برابر بودن لیبل داده ی X با k را به این صورت میابیم. تابع d محاسبه کننده فاصله و مقادیر درون آن یعنی فاصله ی بین (fφ(X و Ck همان فاصله ی بین نقطه ی X که embedded شده و پروتوتایپ ها است. علامت منفی هم به دلیل این قرار داده میشود که هرچه d بزرگتر باشد کسر هم بزرگتر میشه و یعنی فاصله بیشتر برابر با احتمال تعلق بیشتر ! اما ما میخواهیم هرچه فاصله بیشتر بود احتمال کوچکتر باشد. برای همین از علامت منفی استفاده میکنیم.
- مقدار Loss برای بهبود دادن و یادگیری مدل چگونه بدست می آید ؟
تابع زیر را در نظر بگیرید :
مقدار Loss را با J نشان میدهیم و این تابع برابر است با مقدار فرمول بالا. این J میبایست minimize شود.
عبارت جلوی Log همان pφ است که در فرمول 4 آمده . چنانچه جایگذاری کنیم داریم :
همینطوری که در تصویر بالا میبینید ، J برای هر episode بدست می آید . یک episode شامل یک زیرمجموعه ای از training set ما است. اگر به توضیح هر خط در سمت راست نگاه کنید در چهار خط اول ، مقدمه کار یعنی انتخاب اندیس برای اپیزود ، انتخاب مجموعه ی support se و query set و نهایتا ساخت پروتوتایپ ها انجام شده است. در ادامه ابتدا مقدار J را برابر ۰ قرار میدهد ( مقدار Loss اولیه ) و در یک حلقه ی او در تو مقدار Loss آپدیت میشود. Nc تعداد کلاس های هر اپیزود و Ns تعداد سمپل های هر کلاس است.
نتایج تست این متد روی دو دیتاست مشهور mini image net و omniglot را در زیر مشاهده میکنید.
خلاصه :
ابتدا یک توضیح خیلی ساده از جایگاه این متد در دسته بندی مسائل فیوشات و نحوه کارکرد این متد ارائه کردیم. سپس به بررسی عمیق تر در قالب پرسش و پاسخ هایی که به ذهن شما خطور میکرد پرداختیم و نهایتا آنچه گفته شد را در قالب کد بیان کردیم. در آخر نیز دو جدول از نتایج دقت روی دیتاست های مشهور را ارائه دادیم.
برای این مطلب از paper اصلی این متد استفاده کردم.
از مطالعه این مطلب تا انتها بی نهایت سپاسگذارم و منتظر انتقادات ، پیشنهادات و سوالات شما هستم.
ارادت مند شما ، امید عرب خرزوقی .