علیرضا محمدی
علیرضا محمدی
خواندن ۱۰ دقیقه·۲ سال پیش

Hierarchical Triplet Loss

در تسک‌های metric learning یا similarity learning به دنبال یادگیری یک تابع فاصله هستیم که این تابع به عنوان معیاری برای مقایسه تصاویر مختلف و بررسی شباهت آن‌ها مورد استفاده قرار می‌گیرد.

بر این اساس وقتی تابع فاصله‌ی مناسبی آموزش دیده شود، تصاویر ورودی با محتوای مشابه به ناحیه مشخصی در فضای embedded تصویر می‌شوند که از تصاویر با محتوای متفاوت فاصله خواهند داشت.


Deep Metric Learning

با رشد و توسعه شبکه‌های عصبی عمیق، پیشرفت زیادی نیز در زمینه‌ی metric learning رخ داده‌است. استفاده از شبکه‌های عصبی عمیق این امکان را فراهم می‌سازد تا بتوان فضای embedded را به نحوی شکل داد که با استفاده از توابع فاصله‌ی ساده مثل فاصله اقلیدسی و فاصله کسینوسی امکان تمایز بین تصاویر با محتوای متفاوت وجود داشته باشد.
بر این اساس ورودی با استفاده از یک شبکه عصبی عمیق، به فضای embedded تصویر می‌شود که در این فضا می‌توان با استفاده از یک تابع فاصله اقلیدسی، فاصله بین دو نقطه را تعریف کرد؛ تصاویر با محتوای مشابه به به نواحی مشابهی در فضای ثانویه تصویر خواهند شد و دو تصویری که محتوای متفاوتی دارند، در این فضای ثانویه نیز با یکدیگر فاصله خواهند داشت.

توابع خطا

در یک شبکه‌ی عصبی عمیق به منظور طبقه‌بندی تصاویر، از یک تابع هدف مثل softmax استفاده می‌شود که احتمال هر یک از کلاس‌ها را به ازای هر تصویر ورودی محاسبه می‌کند. در این شبکه به منظور یادگیری وزن‌ها، از روش گرادیان نزولی استفاده می‌شود؛ گرادیان به ازای هر یک از تک تصاویر ورودی محاسبه شده و سپس وزن‌های شبکه بروزرسانی خواهد شد.
در توسعه‌های اخیر، توابع loss مختلفی مورد بررسی و استفاده قرار گرفته‌اند، مثل:

  • contrastive loss
  • triplet loss
  • quadruplet loss

مقدار این توابع بر اساس چند نمونه‌ی همبسته محاسبه می‌شود با این منظور که نمونه‌های مشابه و مربوط به یک کلاس، در فضای ثانویه به یکدیگر نزدیک شوند و نمونه‌های متعلق به کلاس‌های متفاوت از یکدگیر فاصله بگیرند.

از همین رو انتخاب نمونه‌های مناسب به منظور آموزش شبکه با چالش‌های زیادی همراه است، زیرا ورودی شبکه به صورت یک دوتایی یا چندتایی از تصاویر است و ایجاد چنین نمونه‌هایی در یک دیتاست با تعداد تصاویر و کلاس‌های زیاد، حالت‌های زیادی را ایجاد می‌کند. برای مثال در یک دیتاست با N تصویر، تعداد زوج‌های ممکن برای استفاده از تابع contrastive loss از مرتبه O(N^2) و تعداد سه‌تایی‌ها برای استفاده از تابع triplet loss از مرتبه O(N^3) است که عملا بررسی تمامی این حالات، مخصوصا در یک دیتاست بزرگ امکان پذیر نخواهد بود.

در یک شبکه که به منظور طبقه‌بندی تصاویر مورد استفاده قرار می‌گیرد وقتی از mini batch استفاده می‌کنیم، گرادیان‌های محاسبه شده به منظور بهینه‌سازی پارامتر‌های شبکه، محلی خواهند بود زیرا به دلیل محدودیت محاسباتی و فضای ذخیره‌سازی، امکان استفاده از تمامی دیتاست در یک batch وجود ندارد و شبکه تنها بر روی توزیع محلی دیتاهای یک mini batch تمرکز می‌کند.

بنابراین نمی‌توان توزیع فراگیر دیتا را به سادگی مورد بررسی قرار داد به همین دلیل یادگیری شبکه کند شده و امکان گرفتار شدن در کمینه‌های محلی نیز زیاد می‌شود.

این مساله در شبکه‌های deep metric نیز به صورت شدیدتری بروز پیدا می‌کند، زیرا اندازه‌ی فضای نمونه‌ها همانطور که گفته شده از مرتبه O(N^2) تا O(N^4) افزایش خواهد یافت که در این فضا بحث redundancy نیز مساله‌ساز خواهد شد و اکثر نمونه‌ها نیز اطلاعات کافی ارائه نمی‌کنند بنابراین نمونه‌برداری تصاد فی و استفاده از این نمونه‌های redundant باعث کاهش شدید عملکرد شبکه و همگرایی کند آن خواهد شد.

به همین دلیل اساس بیشتر تحقیقات و توسعه‌های انجام شده در زمینه‌ی metric learning ، ارائه روشی به منظور مواجهه با این redundancy و استراتژی مناسب به منظور انتخاب نمونه‌هایی است که اطلاعات کافی را در اختیار شبکه قرار دهند و روند آموزش آن را تسریع کنند.

Triplet Loss

همانطور که در بخش قبل اشاره شد، یکی از توابع loss مورد استفاده در شبکه‌های metric learning تابع triplet است که مقدار آن بر اساس دسته‌ی سه‌تایی از تصاویر محاسبه می‌شود. در این دسته‌ی سه‌تایی یک تصویر به عنوان مرجع، یک تصویر از کلاس مشابه با تصویر مرجع و یک تصویر از کلاس متفاوتی انتخاب می‌شود

تاثیر کاهش تابع خطای triplet
تاثیر کاهش تابع خطای triplet


برای مثال (x_i,y_i) نمونه‌ی i ام در دیتاست آموزشی

است و

بردار ویژگی متناظر با ورودی x_i در فضای ثانویه است که معمولا این بردار ویژگی نرمالیزه می‌شود تا فرآیند آموزش شبکه پایدارتر شود. در طی فرآیند آموزش شبکه، سه نمونه از دیتاست آموزشی انتخاب می‌شود و سه‌تایی

تشکیل داده می‌شود که در آن x_a تصویر مرجع، x_p تصویر با کلاس مشابه و x_n تصویر از کلاس متفاوت است. هدف کلی این تابع loss این است که نمونه‌های مرتبط با یک کلاس را در فضای ثانویه به یکدیگر نزدیک کرده و نمونه‌های متفاوت را دور کند، بر این اساس تابع خطا به صورت زیر تعریف می‌شود:

با توجه به این تابع خطا، در صورتی که فاصله بین تصویر مرجع و نمونه مثبت کم‌تر از فاصله تا نمونه منفی باشد شبکه به هدف خود رسیده‌است و نتیجتا مقدار تابع خطا و گرادیان آن نیز صفر خواهد بود و وزن‌های شبکه بروزرسانی نمی‌شود.

اما در صورتی که فاصله تصویر مرجع تا نمونه منفی کم‌تر باشد، گرادیان غیر صفر بوده و وزن‌های شبکه بروزرسانی می‌شود. پارامتر alpha به نحوی حاشیه‌ای را برای آموزش شبکه در نظر می‌گیرد تا فاصله نمونه‌های مربوط به کلاس متفاوت به اندازه‌ی مشخصی زیاد شود.

این پارامتر نقش کلیدی در انتخاب نمونه‌ها در آموزش شبکه‌ی deep metric با استفاده از تابع خطای triplet بازی می‌کند.

Hierarchical Triplet Loss

در مدل‌های متداول کنونی metric learning که از توابع خطای contrastive، triplet و quadruplet استفاده می‌کنند با چالش‌های جدی روبرو هستیم:

  • همانطور که در بخش قبل هم به آن اشاره شد، به دلیل محدودیت‌های حافظه در پردازنده‌های گرافیکی نمی‌توان تمامی نمونه‌ها را در یک batch به شبکه تحویل داد و از همین رو نمی‌توان به سادگی، توزیع فراگیر دیتا را مورد بررسی قرار داد که این موضوع می‌تواند موجب گرفتار شدن مدل در کمینه‌های محلی شود.

برای حل این مشکل راهکار استفاده از ساختار درختی در [1] ارائه شده‌است. به این صورت که در هر مرحله درختی را تشکیل می‌دهیم که برگ‌های آن در سطح صفر، کلاس‌های موجود در دیتاست هستند و در هر سطح بر اساس مدل کنونی و یک مقدار آستانه، کلاس‌هایی با محتوای مشابه و نزدیک به هم، تلفیق می‌شوند. بر اساس این ساختار درختی می‌توانیم روابط موجود در دیتاست را بررسی کنیم.

  • مشکل دیگری که برای مثال در یک مدل با تابع خطای triplet با آن روبرو هستیم، این است که تعداد زیادی از triplet های ایجاد شده، به دلیل در نظر گرفتن یک مقدار ثابت برای violate margin ، گرادیانی را ایجاد نمی‌کنند و از این رو اطلاعات جدیدی را به مدل اضافه نخواهند کرد، در حالی که سرعت همگرایی و آموزش شبکه را به شدت کاهش می‌دهند. به عبارتی با تمامی دیتاها به صورت یکسان برخورد می‌شود. برای حل این مساله نیز ایده‌ی استفاده از dynamic margin ارائه شده‌است. به این صورت که این مقدار حاشیه را به نحوی انتخاب می‌کنیم تا تمرکز مدل بر روی نمونه‌های دارای اطلاعات باشد.
  • مورد نهایی نیز بحث انتخاب tripletهایی است که اطلاعات مناسبی را در هنگام آموزش مدل ارائه کنند. آنچه در [1] بررسی شده‌است این است که به جای دور کردن صرفا نمونه‌هایی با کلاس‌های متفاوت، تلاش کنیم تا کلاس‌هایی با محتوای مشابه را نیز از یکدیگر دور کنیم و از این رو مدلی با قابلیت جداسازی قوی‌تر ایجاد کنیم.

الگوریتم HTL ارائه شده در [1] دارای دو بخش اصلی است:

  • ایجاد ساختار درختی از کلاس‌های موجود در دیتاست بر اساس مدل کنونی که سطح اول این درخت متشکل از تمامی کلاس‌هاست و در سطح آخر آن، تمامی نود‌ها با یکدیگر تلفیق شده‌اند.
  • تشکیل تابع خطای جدید بر مبنای خطای triplet که پارامتر آستانه در آن به صورت متغیر در نظر گرفته می‌شود.

روند ایجاد ساختار درختی

با استفاده از شبکه عصبی

که وزن‌های آن با استفاده از خطای triplet مرسوم آموزش دیده‌است، ساختار درختی فراگیر را ایجاد می‌کنیم. با توجه به اینکه

بیان نمونه‌ی x_i در فضای ثانویه است، ماتریس فاصله‌ی C بین تمامی کلاس‌ها را بر روی دیتاست آموزشی D محاسبه می‌کنیم:

که در آنd(p, q) فاصله‌ی بین دو کلاس می‌باشد. مقدار بردار r_i به دلیل پایداری فرآیند یادگیری مدل، به مقدار واحد نرمالیزه می‌شود. در صورتی که در معادله از نرم مرتبه اول استفاده شود، مقدار این فاصله‌ در بازه‌ی [0,4] و در صورتی که از نرم اقلیدسی استفاده شود در بازه‌ی [0,2] خواهد بود.

برای ساخت درخت از کلاس‌های موجود، ابتدا با برگ‌هایی متشکل از تمامی کلاس‌ها در سطح صفر شروع می‌کنیم و در سطوح بعدی به تدریج این نود‌ها را بر اساس ماتریس فاصله محاسبه شده، تلفیق می کنیم تا اینکه در مرحله آخر تنها یک نود باقی می‌ماند.

در سطح صفر، حد آستانه برای تلفیق نود‌ها به صورت میانگین فواصل درونی هر کلاس محاسبه می‌شود:

که در آن n_c تعداد نمونه‌های موجود در کلاس c ام می‌باشد. در سطح l ام از ساختار درختی، نود‌هایی که فاصله‌ی آن‌ها کم‌تر از حد آستانه‌ی

باشد، با یکدیگر تلفیق می‌شوند.

تابع خطا و حد آستانه

پس از ایجاد ساختار درختی مشابه با آنچه در بخش قبل به آن اشاره شد، به صورت تصادفی 'l نود از سطح صفر درخت H را انتخاب می‌کنیم که هر نود بیانگر یک کلاس می‌باشد، این کار به منظور حفظ تنوع در نمونه‌های آموزشی انجام می‌شود.

سپس m- 1 نزدیک‌ترین کلاس را به هر یک از این 'l نود در سطح صفر انتخاب می‌کنیم. این کار با هدف افزایش قدرت مدل در جدایی‌سازی بین کلاس‌هایی با محتوای تقریبا مشابه انجام می‌شود.

در نهایت t تصویر را از هر کلاس انتخاب می‌کنیم که نتیجتا به n = l' * m * t تصویر در mini batch M می‌رسیم. به منظور ایجاد تمامی triplet های ممکن در این mini batch ، ابتدا باید از بین l' m کلاس، دو کلاس مثبت و منفی را انتخاب کنیم که تعداد حالات ممکن برای آن برابر با

است. سپس دو نمونه‌ی مثبت و مرجع را از بین دیتاهای کلاس مثبت انتخاب می‌کنیم و از بین نمونه‌های کلاس منفی انتخاب شده‌ نیز نمونه‌ی negetive را انتخاب می‌کنیم پس بنابراین تعداد triplet های موجود در M برابر با

خواهد بود.پس از ایجاد triplet های ممکن در mini batch ، مقدار تابع خطا را محاسبه می‌کنیم:

یکی از مهمترین وجه تمایز‌های تابع خطای بررسی شده در [1] با تابع خطای مرسوم triplet استفاده از حد آستانه‌ی متغیر است. مقدار این حد آستانه بر اساس ساختار درختی وطبق ارتباط بین کلاس مرجع و کلاس منفی محاسبه می‌شود. برای مثال برای سه‌تایی T مقدار آستانه برابر است با:

این مقدار آستانه از سه‌ترم تشکیل شده‌است. ترم beta برای اطمینان از این موضوع است که کلاس‌های تصاویر در مرحله کنونی آموزش نسبت به مرحله قبلی از یکدیگر دورتر شوند. ترم d حد آستانه‌ای که به خاطر آن دو کلاس مرجع و کلاس منفی در ساختار درختی با یکدیگر تلفیق شده‌اند را در نظر می‌گیرد. با توجه به آنکه اگر در سه‌تایی انتخاب شده، فاصله بین کلاس مرجع و کلاس منفی، بیشتر از کلاس مرجع و کلاس مثبت باشد، در حالت عادی ممکن است به خاطر انتخاب حد آستانه ثابت، گرادیان برابر با صفر باشد و این سه‌تایی تاثیری در آموزش نداشته باشد،‌ اما اگر این ترم را در حد آستانه متغیر در نظر بگیریم، چنین سه‌تایی‌هایی نیز می‌توانند در ایجاد گرادیان نقش داشته باشند و موجب دورتر شدن کلاس‌هایی که با یکدیگر confusion زیادی دارند بشویم.

ترم d برابر با حد آستانه‌ی تلفیق کلاس مرجع و کلاس منفی در سطحی از درخت H است که در سطح بعدی این دو کلاس تلفیق شده‌باشند.

در ترم سوم نیز فاصله‌ی متوسط بین نمونه‌های کلاس مرجع در نظر گرفته شده‌است:

در این تابع خطای ارائه شده، نمونه‌ی مرجع در triplet تلاش می‌کند تا نمونه‌هایی با مفاهیم متفاوت را از خودش دور کند، از طرفی استفاده از حد آستانه‌ی متغیر باعث می‌شود که triplet هایی که در آن‌ها فاصله بین نمونه مرجع و نمونه منفی خیلی زیاد است نیز در ایجاد گرادیان نقش داشته باشند.

پیاده‌سازی

به منظور پیاده‌سازی HTL از فریم‌ورک Caffe بر روی پردازنده گرافیکی NVIDIA TITAN X استفاده شده‌است که دارای 12 گیگابایت فضای ذخیره‌سازی می‌باشد. برای ساختار شبکه عصبی از معماری GoogLeNet با تکنیک batch-norm بهره گرفته شده‌است که وزن‌های آن به صورت از پیش آموزش دیده شده بر روی دیتاست ImageNet می‌باشد با این تفاوت که لایه‌های fully connected نهایی این شبکه حذف شده و یک لایه جدید d
بعدی اضافه شده‌است که بیان هر تصویر در فضای ثانویه را ایجاد می‌کند.

وزن‌های این لایه بر اساس نویز تصادفی و با استفاده از فیلتر Xavier مقداردهی شده‌اند.

در [1] به منظور استفاده از 650 تصویر در هر mini batch تغییراتی در مدیریت حافظه فریم‌ورک مربوطه داده‌شده‌است. هر تصویر ورودی این شبکه دارای ابعاد 224*224 بوده که مقدار میانگین پیکسل‌ها به عنوان پیش‌پردازش، از آن‌ها کاسته شده‌است.

نتایج

الگوریتم HTL ارائه شده در [1] بر روی دیتاست‌های مختلفی مورد بررسی قرار گرفته‌است. یکی از این دیتاست‌ها Cars-196 و Stanford Online Products است. برای بررسی عملکرد این الگوریتم، عمق ساختار درختی برابر با 16 و مقدار beta = 0.2 انتخاب شده‌است. کل فرآیند آموزش بر روی این دیتاست به 30 epoch احتیاج دارد و اندازه batch برابر با 50 می‌باشد. در هر 10 epoch نرخ یادگیری یک‌دهم می‌شود.

نتایج مقایسه این روش با چند روش دیگر در جدول زیر و به ازای Recall@x های مختلف نمایش داده‌شده‌است:

مقایسه نتایج
مقایسه نتایج


این نتایج نشان می‌دهد که روش HTL ارائه شده در [1] عملکرد triplet loss را به صورت قابل توجهی افزایش داده و قابلیت تعمیم مدل نیز افزایش داشته‌است.



[1] Ge, W., Huang, W., Dong, D., Scott, M.R. (2018). Deep Metric Learning with Hierarchical Triplet Loss. ECCV.

triplet losshierarchical triplet lossmetric learningdeep metric learningطراحی و تحلیل شبکه‌های عصبی عمیق
شاید از این پست‌ها خوشتان بیاید