ویرگول
ورودثبت نام
Melika Sadeghi
Melika Sadeghi
خواندن ۱۰ دقیقه·۲ سال پیش

Improved Deep Metric Learning with Multi-class N-pair Loss Objective

دیپ متریک لرنینگ در سال های اخیر شهرت و محبوبیت زیادی پیدا کرده است

اگرچه فریم های موجود دیپ لرنینگ بر اساس contrastive loss و triplet loss معمولا مشکل کند بودن سرعت convergence دارند , زیرا آن ها یک example منفی را به کار می گیرند و در حالی که در هر آپدیت بقیه example های منفی را نادیده می گیرند

در این مقاله قصد داریم که این مشکل را با یک روش دیپ متریک لرنینگ جدید به نام multi class N pair loss برطرف کنیم

از یک example منفی استفاده نکنیم و از N-1 مثال منفی استفاده کنیم

برای رسیدگی به این مشکل ما از (N+1 tuplet loss) استفاده می کنیم که مثال مثبت را از N-1 مثال منفی شناسایی کند در زمانی که N=2 است این برابر همان triplet loss هست

یک نگرانی در باره (N+1 tuplet loss) این است که با افزایش تعداد مثال محاسبه آن از مرتبه A = N^2 می شود

برای غلبه بر این موضوع یک روش بهینه batch construction معرفی می کنیم که برای ساختن N tuplet با اندازه N+1 , به جای N(N+1) مثال از 2N مثال استفاده کند

Preliminary: Distance Metric Learning

لازم به ذکر است , Contrastive loss یک جفت مثال را به عنوان ورودی می گیرد و تخمین می زند که این دو مثال متعلق به یک کلاس هستند یا نه و در واقع فرمول loss به شکل زیر است :


در واقع اگر دو مثال متعلق به یک کلاس باشند قسمت دوم رابطه جمع صفر شده و در قسمت اول کافی است فاصله بین fi و fj به کمترین مقدار ممکن برسد و اگر دو مثال از دو کلاس مختلف باشند , قسمت اول رابطه جمع صفر شده و در قسمت دوم کافی است که فاصله بین fi و fj از مارجین که برابر m هست بیشتر باشد

از طرفی در triplet loss , تنها اختلاف مثال مثبت و منفی با کوئری در نظرفته میشود که لازم است این تفاوت از یک مارجین m ای بزرگ تر باشد


با وجود استفاده گسترده از دو روش اما هر دوی triplet loss و contrastive loss مشکل کند بودن سرعت converge شدن دارند

Deep Metric Learning with Multiple Negative Examples

فلسفه اساسی پشت triple loss این هست که برای ورودی یا کوئری می خواهیم فاصله بردار های embedding مثال های مثبت با query را کم کنیم و فاصله بردار های embedding مثال های منفی را از کوئری زیاد کنیم

لازم به ذکر است triplet loss در حین یک آپدیت فقط ورودی را با یک مثال منفی مقایسه می کند و سایر مثال های منفی از کلاس های دیگر را نادیده می گیرد

اگر چه hard negative data mining با خروجی کلاس های زیاد , از نظر محاسباتی هزینه زیادی دارد ما به دنبال یک جایگزین هستیم , یک loss function که در هر آپدیت گروهی از مثال های منفی را به کار می گیرد و نیاز هست که تمام مثال های منفی در یک زمان قابل شناسایی باشد ,ما متود پشنهادی را به صورت زیر فرمول بندی می کنیم :


که در واقع (N+1 tuplet) داریم که {x,x+,x1,x2,….Xn-1} که x و x+ به ترتیب anchor و positive هستند و x1 و x+ به ترتیب anchor و positive هستند و x1 و.... Xn-1 مثال های منفی می باشند . اگر تعداد مثال های منفی هر کلاس را به یک عدد مثال محدود کنیم باز هم از نظر optimization سنگین است

فرمول loss بالا با N= 2 همانند triplet loss عمل می کند و هر دوی loss function ها معادل یکدیگرند

اما در مورد N> 2 در باره مزیتtuplet loss (n+1) بر triplet loss بحث خواهیم کرد


N-pair loss for efficient deep metric learning

فرض کنید که ما (N+1 tuplet loss) را برای دیپ متریک لرنینگ به کار گرفته ایم زمانی که بچ سایز SGD

برابر M می باشد , تعداد M * (N+1 ) مثال هست که در هر آپدیت باید پس داده شود

از انجا که تعداد مثال ها در هر بچ به صورت نمایی از مرتبه MN افزایش می یابد از نظر محاسباتی غیر ممکن می شود

N pair loss for efficient deep metric learning:

فرض کنید ما Tuplet loss (N+1) را بر فریم ورک های دیپ متریک لرنینگ اعمال کردیم , زمانی که اندازه بچ سایز SGD , برابر M باشد , تعداد M*(N+1) مثال وجود دارد که در یک اپدیت باید به تابع f داده شود از آنجا که تعداد مثال ها برای ارزیابی در هر بچ به صورت نمایی از M و N می باشد محاسبات برای شبکه دیپ غیر عملی خواهد شد

حال یک روش کارامد را برای جلوگیری از سربار محاسبات معرفی می کنیم.

فرض کنید,.{(XN,XN+),...(X1,X1+)} N جفت از N کلاس مختلف هستند به طوریکه yi≠yj به ازای هر

عضو i≠j, ما N تا tuplet می سازیم به صورت {Si}از N جفت به طوریکه si={xi,x1+,x2+,…Xn+} کهXi

کوئری برای si هست , xi+ مثال مثبت برای کوئری و xj+ به طوریکه j≠i مثال های منفی هستند


باتوجه به عكس فوق مشاهده مي شود كه N-pair batch construction در قسمت c از 2N , بردار embedding برای ساخت N تا (N+1 tuplet loss ) مجزا استفاده می کند و Triplet loss از 3N انتقال برای محاسبه وکتور های embedding استفاده می کند

و N+1 Tuplet loss به N*(N+1 ) انتقال نیاز دارد

و روش پیشنهادی ذکر شده به 2N انتقال نیاز دارد

همچنین محاسبه LOSS روش پیشنهادی به صورت زیر است


Hard negative class mining

محاسبه وکتورهای embedding برای مثال های متعدد برای تعداد کلاس های بالا از نظر محاسباتی

پیچیده است از طرفی (N pair loss) به صورت تئوری به N کلاس که دو به دو با یکدیگر منفی هستد نیاز دارد. برای غلبه بر این مشکل مفهوم negative class mining را معرفی می کنیم که مخالف negative instance mining است

به طور کلی negative class mining for N pair loss به صورت زیر محاسبه می شود :

به صورت رندوم تعداد c کلاس را انتخاب می کنیم برای هر کلاس تعداد یک یا دو مثال از آن را برای استخراج وکتور embedding انتخاب می کنیم . سپس از مرحله اول از بین c کلاس یک کلاس را به صورت رندوم انتخاب می کنیم

سپس یک کلاس دیگر را اضافه می کینم که بیشترین تخلف را نسبت به triplet loss دارد یعنی در فضای embed از بقیه کلاس ها نزدیک تر به کلاس انتخاب شده در مرحله قبل است اگر دو تا کلاس بود که فاصله آن ها یکی بود به صورت رندوم یکی از آن دو را انتخاب می کند

سپس از بین دو کلاس انتخاب شده در مراحل قبل (2 example) از هر کدام بیرون می کشیم

Fine-grained visual object recognition and verification

دو دیتاست ماشین و گل را به صورت زیر در نظر می گیریم :

ديتاست Car 333 از 164863 عکس تشکیل شده است که شامل 33 کتگوری مختلف ماشین است و این دیتاست شامل 157023 عکس برای داده اموزش و 7840 عکس برای داده تست است

ديتاست Flower 610 از 61771 عکس گل تشکیل شده است که 58721شامل 610 کتگوری مختلف گل است و این دیتاست شامل 58721 عکس برای داده آموزش و 3050 عکس برای داده تست است

ما شبکه را برای 40000 بار آموزش تکرار کرده که در هر بچ 144 مثال وجود دارد که به عبارتی 72 جفت در هر بچ برای N_ pair losses است

ما 5 old cross validation را روی داده train اعمال می کنیم و میانگین عملکرد را روی داده تست گزارش می کنیم

برای هر دو دیتاست ماشین و گل در هر دو بخش recognition و verification بهبود قابل ملاحظه ای از 72 pair loss نسبت به triplet loss مشاهده می شود

هم چنین اگر چه اعمال negative data mininig بر triplet بهبود قابل توجهی به ان بخشیده اما همچنان با 72 pair loss قابل رقابت نیست

در مقایسه با softmax , در قسمت verification , softmax علکرد ضعیف تری نسبت به 72 pair mc داشته است اما در در بخش recognition با 72 pair mc رقابت نزدیکی داشته


Distance metric learning for unseen object recognition

لازم به ذكر است Dictance metric learning را میتوان به عنوان متریکی برای یادگیری و تعمیم دهی مدل برای داده های دیده نشده استفاده کرد . سه دیتاست زیر را در نظر می گیریم :

که شامل 120053 عکس از 22634 کتگوری مختلف است که 59551 کتگوری اول را به داده آموزش اختصاص داده و ما بقی را به داده تست

که شامل 16185 عکس از 196 مدل کتگوری مختلف است که 98 کتگری اول را به داده آموزش اختصاص داده و مابقی کتگوری ها را به داده تست

که شامل 11788 عکس از 200 گونه پرنده است که 100 کتگوری اول برای اموزش و مابقی برای تست است

برخلاف سکشن قبلی کتگوری های آموزش و تست جداهستند که این حل مسئله را پرچالش تر می کند و و می تواند دچار overfitting روی داده آموزش شود و همچنین قدرت تعمیم دهی برای برای موارد دیده نشده سخت تر باشد

ما شبکه را برای 20000 بار آموزش تکرار کرده که در هر بچ 120 مثال وجود دارد که به عبارتی 60 جفت در هر بچ برای N_ pair losses

در این قسمت هم همانند سکشن قبلی روند یکسانی را ملاحظه می کنیم و Triplet loss بدترین عملکرد را در میان تمام loss function ها داشته است

و Negative data mining عملکرد triplet loss را با فرار کردن از نقاط بهینه محلی بهبود بخشیده

اما مدل N pair loss بدون نیاز به هزینه محاسبات بیشتر برای انجام negative data miming عملکرد بهتری نسبت به triplet-nm داشته

حال با اعمال negative data mining بر روی N pair loss حتی میتوان به دقت و عملکرد مناسب تری رسید


Face verification and identification

در نهایت ما متریک لرنینگ را برای تشخیص چهره به کار می بریم , مسئله پیدا کردن چهره ای که به چهره مد نظر ما بیشترین شباهت را دارد (متعلق به یک نفر هستند) از میان تعداد زیادی مثال منفی در گالری چهره ها

.بدین منظور شبکه را بر دیتاست webface اموزش می دهیم که شامل 494414 تا عکس هست از 10575 شخص مختلف و می خواهیم که کیفیت embedding شبکه را با متریک لرنینگ های مختلف ارزیابی کنیم

همه شبکه ها برای 240000 بار تکرار مختلف اموزش داده می شوند در حالیکه نرخ یادگیری از 0.0003 به 0.0001 و 0.00003 کاهش می یابد

نتیجه این است که مدل triplet loss حدود 95 درصد دقت برای تسک verification داشته است اما برای identification دقت این مدل کاهش یافته البته اگرچه negative data mining باعث بهبود دقت می شود اما میزان آن محدود بوده

در مقایسه با triplet loss , N pair mc loss دقت قابل ملاحظه ای داشته و افزون بر این با قرار دادن مقدار N = 320 به مقدار 98 درصد دقت برای تسک verification شده است

شایان ذکر است که N-pair-ovo(N pair one-vs-one ) نسبت به triplet loss اولیه عملکرد بهتری داشته اما نسبت به multi class N pair loss (N_pair loss mc ) عملکرد ضعیف تری داشته است

البته راه های دیگری برای افزایش دقت تسک face verification وجود دارد به عنوان مثال استفاده از triplet network که با میلیون ها مثال tarin شده دقت تست را به 99.63 درصد ارتقا می دهد

یا مثلا استفاده از شبکه های دیپپ نتورکی که از تلفیق هر دوی triplet loss و softmax loss استفاده می کند

Analysis on tuplet construction methods


در جدول بالا(M×N) بدین معناست که از M کلاس مختلف در هر بچ استفاده شده و N بیانگر تعداد مثال های مثبت هر کلاس است

با ثابت نگه داشتن تعداد مثال های هر بچ اگر به جای N کلاس مختلف از N/2 کلاس مختلف استفاده کنیم در نتیجه باید به جای 2 مثال مثبت در هر کلاس از 4 مثال مثبت استفاده کنیم اما از انجا که N pair loss

برای مثال های متعدد مثبت هندل نشده , مشاهده می کنیم که با کاهش تعداد کلاس ها درجه خاصی از افت عملکرد به وجود می آید

با این وجود, همه این نتایج از نتایج triplet loss بهتر است

این نتایج تایید کننده آموزش با چندین کلاس منفی است و پیشنهاد می شود که تا جایی که ممکن است از تعداد کلاس های منفی بیشتری استفاده کنید


deep metricmetric learningpair loss
شاید از این پست‌ها خوشتان بیاید