Meysam.A
Meysam.A
خواندن ۵ دقیقه·۳ سال پیش

آشنایی با Cross Validation در Scikit-Learn


فرآیند cross validation به این شکله که دیتاست به k بخش(fold) تقسیم میشه و k بار از یه بخش جدید برای test و از بقیه(k-1) برای train استفاده میشه، به این ترتیب زیرمجموعه‌های مختلفی از کل دیتایی که داریم رو، برای تست استفاده میکنیم. حالا هدف ازین کار یا پیدا کردن بهترین hyperparameterهای مدلامون هست، یا ارزیابی مدل‌ها، که در هر صورت دید منطقی تری نسبت به generalization مدل پیدا میکنیم، یعنی میفهمیم مدلمون تا چه حد عملکرد خوبی روی دیتاهایی که تجربه نکرده(unseen data) خواهد داشت. به مثال زیر توجه کنید:

اینجا هر بار دیتاستمون رو با یه نسبت دلخواه تقسیم کردیم و مدلمون رو با یه بخش از دیتا train و با یه بخش دیگه test کردیم
اینجا هر بار دیتاستمون رو با یه نسبت دلخواه تقسیم کردیم و مدلمون رو با یه بخش از دیتا train و با یه بخش دیگه test کردیم

میبینیم که طی هر split، مدلمون عملکرد خیلی متفاوتی داشته و دلیلش هم اینه که هر بار با نمونه‌های متفاوتی test شده. اما ما نمیتونیم به test کردن مدلمون به شکل بالا و با فقط یه بخش از دیتاست بسنده کنیم، چون گاهی میخوایم چندین مدل رو باهم مقایسه میکنیم، و بهتره سعی کنیم از تمام دیتایی که داریم بعنوان test-set استفاده کنیم، همچنین این کارو باید طوری انجام بدیم که قبلش مدلمون این test-set رو تجربه نکرده باشه، تا بفهمیم در حالت کلی کدوم یکی از مدلامون عملکرد بهتری توی predict کردن داده‌هایی که تجربش نکردن(unseen data) دارن، و اینجاست که باید cross validation انجام بدیم(البته اگه بخوایم همزمان هم مدلامون رو tune کنیم، و هم ارزیابیشون کنیم، باید از nested cross validation استفاده کنیم، که تو این پست راجبش توضیح داده شده)


استراتژی‌های مختلفی برای cross validation بر اساس ماهیت داده‌ها وجود داره که در کتابخونه scikit-learn براشون کلاس‌هایی بنام Cross validation Splitters هست که بهشون cross validator هم گفته میشه(در حالی که صرفا splitter هستن فقط)، مثل KFoldStratified ، LeaveOneOut ، ShuffleSplit ، KFold و ... که میان به ما ایندکس داده‌های split شده در هر مرحله رو برمیگردونن و هر کدوم در موقعیت خاصی(با توجه به داده‌هامون) استفاده میشن که تو یه مطلب دیگه دقیق تر بهشون اشاره میکنیم. قطعه کدای زیر مثالی از ساده‌ترین splitter هستن بنام KFold، که تو مواقعی که دیتاهامون هیچ ترتیب زمانی یا هیچ ارتباط گروهی‌ای باهم ندارن(تاثیر وجود این ارتباط رو هم بعدا بررسی میکنیم)، استفاده میشه و خروجیشو میبینید:

اینجا داده‌هارو به 3 قسمت تقسیم کردیم و از shuffle استفاده نکردیم، که یعنی داده‌های test با همون ترتیبی که داشتن جدا شدن(یه بار 3 تای اول، یبار 3 تای دوم و ...)
اینجا داده‌هارو به 3 قسمت تقسیم کردیم و از shuffle استفاده نکردیم، که یعنی داده‌های test با همون ترتیبی که داشتن جدا شدن(یه بار 3 تای اول، یبار 3 تای دوم و ...)
اما اینجا با shuffle=True ،
اما اینجا با shuffle=True ،

خب، حالا فرض کنید با این splitterها میخوایم روی دیتاستمون cross validation انجام بدیم و درواقع مدلمون رو با زیر مجموعه‌های مختلفی از feature‌ها (X) و label‌ها (Y) مون ، train و test کنیم. بدون shuffle دیتامون(هم X و هم Y) به این شکل تقسیم میشه:

به ترتیب هر بار یک چهارم بعدی ِدیتا split شده(ایندکسِ داده‌ها مرتب هستش)
به ترتیب هر بار یک چهارم بعدی ِدیتا split شده(ایندکسِ داده‌ها مرتب هستش)


و همچنین با shuffle کردن دیتا، ایندکس داده‌ها بُر میخورن و تو هر split ، اون یک سوم داده‌ها (test set) از بخشهای رندومی انتخاب میشن:


این splitter ها میتونن بعنوان ورودی"cv" در توابع cross_validate ، cross_val_score و GridSearchCV و امثال اینا پاس داده بشن:

از یکی از splitter هایی که ساخته بودیم بعنوان ورودی
از یکی از splitter هایی که ساخته بودیم بعنوان ورودی


خب راست به چپ هم که نمیشه بنویسیم تو ویرگول? پس مجبوریم اول هر عنوان انگلیسی یه چیز فارسیم بذاریم ?

  • فانکشن GridSearchCV : برای Hyperparameter tuning استفادش میکنیم و کارکردش به این شکله که طی فرآیند cross validation میاد یه دیکشنری از ما میگیره (به ورودی param_grid پاسش میدیم) که شامل یه سری پارامترهای مخصوص مدلمون هست و ما میخوایم بفهمیم مدلمون با کدوم ترکیب ازین پارامترها، روی دیتاستی که داریم، بهتر عمل میکنه و بعد از fit کردن روی دیتامون با مشخصه best_estimator ، مدلی که با بهترین پارامترا fit شده رو میتونیم داشته باشیم:
توجه کنید که اینجا اصلا ورودی
توجه کنید که اینجا اصلا ورودی
  • فانکشن cross_validate : صرفا خوده عمل cross validation رو انجام میده و خروجیش یه دیکشنری شامل زمان fit شدن، زمان evaluate کردن و scoreهای مربوط به split‌های مختلف هستش:
اینجا ورودی
اینجا ورودی

تو عکس بالا به test_score مربوط به split چهارم توجه کنید، میبینیم که از همه split ها کمتره.دلیل این overfitting چی میتونه باشه؟ جواب: ازونجایی که "cv" رو یه آبجکت KFold دادیم، تو این حالت splitter ما داده‌ها رو بُر نمیزنه و عملا تو split آخر، 25% آخر دیتاست قرار گرفتن. پس باید دیتاستمون رو هم دقیق‌تر بررسی کنیم تا ببینیم دلیل چی میتونه باشه، و آیا دیتاهامون تو اون تیکه شکل خاصی دارن؟ البته این حالت با یه بُر زدن رفع میشه و حتی ممکنه تو همین مثالم علتش چیزه خاصی نباشه، ولی نکته اینجاست که بعضی وقتا(هر چند خیلی کم) ممکنه حین cross validation به همچین مواردی بر بخوریم که دید بهتری نسبت به داده‌هامون پیدا میکنیم.

  • فانکشن cross_val_score : مستقیما test_score ها رو برمیگردونه:
ورودی
ورودی

یه سری دیگه هم ازین توابع هستن که خودتون در صورت نیاز بررسیشون کنین، مثل learning_curve ، RandomizedSearchCv و ...

اغلب تو مواردی که برای آشنایی یا تمرین روی دیتاست‌ها ازین فانکشن‌ها استفاده میشه، پارامتر "cv" رو یه عددی مثل k و یا حتی "none" در نظر میگیریم، که در این حالت، این فانکشن‌ها بصورت پیشفرض از (shuffle=False)StratifiedKFold بعنوان آبجکت "cv" استفاده میکنن که مشابه KFold هستش، منتهی میاد علاوه بر اینکه داده‌هامون رو هر بار به k دسته تقسیم میکنه، نسبت داده‌ها رو هم رعایت میکنه، یعنی اگه فرضا سه کلاس B ، A و C داشته باشیم که به ترتیب 40% ، 30% و 30% از داده‌هامون رو تشکیل بدن، تو هر split ، هم بخش test و هم بخش train، شامل 40% از کلاس A و 30% از کلاس B و 30% از کلاس C خواهند بود.

نکته: StratifiedKFold در صورتی که estimator مون یه classifier باشه(مثل SVM)، و داده‌هامون 2 دسته و یا بیشتر باشن بعنوان splitter پیشفرض هست، در غیر این صورت همون (shuffle=False)KFold فراخوانی میشه.

حالا که تا اینجا اومدین، بد نیست نحوه split کردن خوده همین StratifiedKFold رو هم ببینید:

میبینیم که نسبت کلاس‌ها در هر split حفظ شده (30% از هر کلاس برای train و 30% برای test)
میبینیم که نسبت کلاس‌ها در هر split حفظ شده (30% از هر کلاس برای train و 30% برای test)

ضمنا شکل مربوط به (shuffle=True)StratifiedKFold هم شبیه (shuffle=True)Kfold میشد و از لحاظ بصری نمیشه تشخیص داد که نسبت کلاس‌ها حفظ شده واسه همین دیگه شکلشو نذاشتم.


خب فعلا کافیه، ببخشید اگه طولانی شد. اینجا فقط سعی شد تا کلیت cross validation گفته بشه و یکم به اهمیت این کار اشاره کنیم و انشاالله تو مطلب مفصل تری به انواع روشهای cross validation (در واقع انواع استراتژی‌های splitting) و علت استفاده از هر نوع میپردازیم.

ضمنا کدهای مربوط به این مطلب ازین آدرس قابل دسترسی هستن. موفق باشید?

منبع

cross validationkfoldStratifiedKFoldgridsearchcvScikit_Learn
every day is a chance to learn more
شاید از این پست‌ها خوشتان بیاید