mojtaba rayati
mojtaba rayati
خواندن ۷ دقیقه·۳ سال پیش

روش S2SD

روش S2SD برای بهبود عملکرد deep metric learning (DML) توسعه داده شده است.

روش DML چیست و چه کاربردی دارد؟ برای فهم بهتر مثال زیر را در نظر بگیرید. کاربرد معمولی که از شبکه‌های عصبی می شناسیم مسئله classification است. مثلاً وقتی در خیابان قدم می­زنیم تشخیص می دهیم که این شی درخت است، این شی خودرو هست، این شی پرنده هست و ... . و این توانایی طی آموزش طولانی (داده برداری زیاد) برای ما به وجود آمده است.

اما گاه فردی را می بینیم که او را نمی شناسیم(کلاسی مشخصی برای او نداریم) اما می دانیم این فرد را در مکانی دیگر دیده بودیم. در مسئله دوم آموزش زیادی (داده برداری زیاد) برای چهره فرد برای ما انجام نشده، اما شناسایی انجام شده است. در واقع ما آموزش دیده ایم که ویژگی‌های چهره افراد را در ذهن بسپاریم و اگر آن ویژگی‌ها را مجدد ببینیم این تشخیص برای ما اتفاق می­افتد که این همان فرد قبلی است.

DML روشی است که می‌تواند برای شبکه های عصبی عملکرد مشابه را ایجاد کند. یعنی شبکهبتواند تشخیص دهد آیا دو چهره مربوط به یک نفر هستند یا خیر. مدل کردن ویژگی ها عموما با یک بردار انجام می شود و Metric تابعیست که نزدیکی دو بردار ویژگی را مدل می کند. وظیفه شبکه­ی عصبی نیز استخراج آن ویژگی­هاست.

شکل 1. استخراج ویژگی توسط شبکه برای DML
شکل 1. استخراج ویژگی توسط شبکه برای DML

loss شبکه نیز تابعی از Metric بردارهاست. برای مثال طبق تعریف تابع متریک، متریک دو بردار مشابه صفر است. به همین دلیل اگر دو چهره از فرد یکسانی باشند و در عین حال شبکه متریک غیر صفری برگرداند خطایی برای شبکه در نظر گرفته می­شود.

در عمل چه نکاتی باید رعایت شود تا شبکه عملکرد خوبی در DML داشته باشد؟ در تحقیقات دیده شده که هرچه ابعاد خروجی شبکه گسترش پیدا می کند قدرت تعمیم DML نیز افزایش پیدا می­کند. قدرت تعمیم به چه معناست؟ یعنی شبکه­ای که بر روی تعداد محدودی چهره آموزش دیده چقدر می­تواند چهره های جدید را بهتر تطبیق (یا معادلا تفکیک) دهد.

قدرت تعمیم یکی از مهمترین شاخص­هاییست که از شبکه انتظار داریم. گفته شد که مطالعات نشان داده­اند که افزایش بعد منجر به افزایش قدرت تعمیم نیز می شود. اما در کاربرد های DML در بسیاری از مواقع ما نیاز به ایجاد پایگاه داده­هایی بزرگ داریم و همزمان لازم است تطبیق تصویر دو چهره با سرعت کافی انجام پذیرد. در نتیجه اگر ابعاد شبکه یا خروجی نهایی بزرگ باشد سرعت کارکرد کاهش پیدا می­کند. راه حل چیست؟ روش­های مبتنی بر knowledge distillation روشیهاییند که می توانند اطلاعات یک شبکه بزرگتر به یک شبکه کوچکتر را بفهمانند به طوری که شبکه کوچکتر رفتاری مشابه با شبکه بزرگتر از خود نشان دهد.

ایده اصلی S2SD اینست این است که برای مسئله DML با استفاده از knowledge distillation پیشنهادی خود، شبکه­ای طراحی کند که در این حالی که به قدر کافی کوچک است، رفتاری مشابه شبکه­های بزرگتر را تقلید کند تا بتواند قدرت تعمیم بیشتری از خود نشان دهد. حال به جزئیات این ایده می پردازیم.

معرفی metric و loss:

فضای ویژگی را تصور کنید که توسط شبکه­های بزرگ استخراج می­شوند، شبکه­هایی نظیر resnet و یا inception. این فضای ویژگی Φ نامیده شده و این فضای ویژگی متناسب با مسئله classification به دست آمده است. هدف اینست که تابع fی را پیدا کنیم که این ویژگی­ها را برای تطبیق دو چهره بهینه کند. این تابع همان شبکه عصبی­ای است که نیاز داریم آموزش دهیم. خروجی این شبکه (تابع) را Ψ نامیده شده است. متریک به صورت زیر را تعریف شده است.

یعنی نزدیکی دو چهره معادل است با فاصله­ای اقلدیسی دو بردار ویژگی به دست آمده از شبکه f. حال loss شبکه را نیز از جنس triplet تعریف شده است. triplet loss تلاش می­کند بین متریک دو تصویر یکسان تفاوتی با متریک دو تصویر غیر کسان ایجاد کند بدون اینکه اندازه­ای دلخواهی برای این متریک قائل باشد. یعنی هدف اصلی آن تفاوتگذاریست.

در معادله بالا i و j به معنای تصویر مختلف از یک فرد، i و k نیز به معنای دو تصویر مختلف از دو فرد هستند. مقدار m نیز عددی مثبت است یک آستانه­ی حداقلی برای این تفاوت ایجاد کند.

روش knowledge distillation:

گفته شد که برای knowledge distillation باید از شبکه های بزرگتر برای آموزش شبکه کوچکتر استفاده کرد. شبکه بزرگتر با تابع g را در نظر بگیرید که همزمان با f بر روی مجموعه Φ عمل می کند. و خروجی آن Ψg نامیده شده و مشابه با f آموزش می بیند. در این حالت شبکه fهم تلاش می­کند که Ltriplet را کمینه سازد و هم سعی می کند رفتار شبکه g را نیز تقلید کند. تقلید گفته شده با مشابه­سازی cosine similarity matrix (CSM) انجام می­شود.

ابتدا را بردارها Ψ نرمال می کنیم. CSM برای ست Ψ را Dمی­نامیم (Dij=ψiTψj). در این صورت هر درایه اندازه کسینوس زاویه بین هر دو بردار را نمایش می دهد. در طول آموزش تلاش می‌شود که f CSM مشابهی با g برای هر batch از داده ها را به وجود ‌آورد. اما معیار شباهت بین دو CSM چیست؟ آیا باید اندازه درایه های آن مشابه اندازه درایه دومی شود. مثال فرضی زیر را در نظر بگیرید اولین سطر مرجع بوده و می­خواهیم دومی و سومی را از نظر رفتاری با آن مقایسه کنیم.

در سطح اول (D1) می­بینیم که بین داده اول بیشترین شباهت با داده چهارم وجود دارد. برای مجموعه داده دوم (D2) بیشترین شباهت داده اول با داده های سوم و پنجم بوده همچنین در سطر سوم (D3) نیز شباهت داده اول و چهارم بیشتر است. پس می­بینیم اگرچه اندازه­ی درایه های سطر دوم با سطر اول مشابه­تر است، اما از لحاظ رفتاری سطر سوم و اول سنخیت بیشتری با یکدیگر دارند. برای کمی کردن این تشابه می­توان از معیار پراکندگی Kullback-Leibler استفاده کرد در واقع KL میزان شباهت دو توزیع را نشان می دهد و با رابطه­ی زیر به دست می ­آید.

در اینجا نیز از همین منطق برای ایجاد خطا استفاده شد یعنی هرچه ماتریس CSM تابع f، از نظر پراکندگی KL، تفاوت بیشتری با g داشته باشد، خطای بیشتری برای آن در نظر گرفته می شود. مدل ریاضی آن به شرح زیر است.

در اینجا σ تابع softmax بوده و T متغیر میزان temperature است (یعنی اگر به سمت بی نهایت برود خروجی softmax برای همه یکسان می شود). دقت شود خطای distillationتنها برای آموزش f استفاده خواهد شد و مفروض این است که رفتار g نیازی به تقلید ندارد. علامت † نشان دهنده همین است که g با این خطا آموزش نخواهد یافت. با تعریف این خطا، خطای کلی برای شبکه را به صورت زیر در نظر گرفته می­شود.

در اینجا L_DML همان Ltriplet است. نمودار بلوک این ساختار در زیر آمده است.

شکل 2. پیاده ­سازی knowledge distillition از شبکه g به f با استفاده از CSM
شکل 2. پیاده ­سازی knowledge distillition از شبکه g به f با استفاده از CSM


باز هم Knowledge distillation!

در مرحله قبل این احتمال داده شد که اگر اطلاعات از یک شبکه بزرگتر به شبکه کوچکتر انتقال یابد قدرت تعمیم شبکه کوچک بالا می رود. محتمل است که شبکه های دیگری نیز وجود داشته باشند که برای مجموعه ای از داده ها قدرت تعمیم بهتری را فراهم آورد. و به این دلیل شاید بهتر باشد به جای استفاده تنها از gاز مجموعه ای از شبکه ها با ابعاد مختلف استفاده شود و هر کدام از آنها اطلاعات خود را به f بفرستند. به این صورت که مجموعه ای از شبکه ها از g1 تا gm که ابعاد g_(i+1)>g_i است همزمان بر روی داده­ ها عمل می کنند و اطلاعات خود را به f بفرستند. تابع هزینه در این حالت به شکل زیر خواهد شد.

مسئله بعدی این است که هنگامی که ابعاد در میانه شبکه کاهش پیدا میکند شبکه دچار dimensionality bottleneck می­شود. در اینجا نیز از φ به f یک bottleneck ایجاد می شود که ممکن است باعث شود قدرت تعمیم f کاهش یابد. همچنین دیده شده φ به دلیل ابعاد بالا می‌تواند برای مسئله DML مورد استفاده قرار گیرد به همین دلیل knowledge distillation از φ به fنیز انجام شده تا اگر اطلاعاتی از دست رفت بازیابی شود. در نتیجه loss نهایی پیشنهادی S2SD به فرم زیر خواهد بود.

نمودار بلوکی کل شبکه به صورت زیر شده است.

شکل 3. نمودار بلوکی S2SD
شکل 3. نمودار بلوکی S2SD


نتایج:

با این تغییرات نویسندگان مقاله S2SD ادعا کرده­اند که توانستند ۷ درصد بهبود در معیار Recal@1 بگیرند. در فرمول­ها loss متغیر γ که واضحا میزان تاثیر knowledge distillationرا در کل فرایند یادگیری کنترل می­کرد. شکل زیر نشان دهنده ­ی تغییرات معیار Recall@1 بر حسب این متغیر برای سه دیتاست مختلف از داده هست.

شکل 4.  تغییرات Recall@1 برحسب میزان تاثیر knowledge distillation (با متغیر γ)
شکل 4. تغییرات Recall@1 برحسب میزان تاثیر knowledge distillation (با متغیر γ)

واضحا دیده­ می­شود که معیار با knowledge distillationبهبود یافته اما در عین حال مقدار بهینه نیز برای آن وجود دارد.

همچنین اثر توسعه lossهایی که در مراحل مختلف بیان شد بررسی شده و گزارش شده که آخرین lossپیشنهادی در کل بهترین عملکرد را از خود نشان داده است. همچنین اثر افزایش بعد بر Recall@1 نشان داده شده است.

شکل 5. تاثیر ابعاد و lossهای مختلف بر recall@1
شکل 5. تاثیر ابعاد و lossهای مختلف بر recall@1


همچینین در مقاله بیان شده بهتر است که شبکه مورد استفاده برای f و giها یک MLP دو لایه باشد. شکل زیر اثر شبکه­هایی مختلف را در معیار سنجش نشان می­دهد.

شکل 6. تاثیر شبکه­های مختلف بر Recall@1
شکل 6. تاثیر شبکه­های مختلف بر Recall@1




مرجع:

[1] Roth, Karsten, et al. "Simultaneous Similarity-based Self-Distillation for Deep Metric Learning." International Conference on Machine Learning. PMLR, 2021.

شبکه عصبیشبکه
شاید از این پست‌ها خوشتان بیاید