Meysam.A
Meysam.A
خواندن ۲ دقیقه·۳ سال پیش

ایجاد Transformer کاستوم در Pipelineهای Scikit-learn


اگه با کتابخونه‌ی Scikit-learn آشنایی دارید، حتما پیش اومده که حین train مدل‌هاتون از pipeline استفاده کردین. بعنوان مثال فرض کنید تمام داده‌های دیتافریم ما عددی هستند و در pipeline زیر ما تعیین کردیم که قبل از فیت شدن لاجستیک رگرسیون روی داده‌هامون، داده‌ها ابتدا scale بشن:


حالا فرض کنید میخوایم از یک Transformer خاص و مد نظر خودمون به نام multiplierTransformer در Pipeline مون استفاده کنیم که یه ستون جدید (بنام multiple) به دیتافریم مون اضافه کنه که شامل حاصلضرب داده‌های دو ستون دلخواه a و b باشه:


برای این منظور، کلاس زیر رو تعریف میکنیم که دارای دو متد fit و transform بوده، و در سازنده یا constructor اون، ستون‌هایی که قراره بعنوان ورودی بدیم ست میکنیم:


سپس میتونیم از Transformer جدیدمون استفاده کنیم و پایپلاینی که ایجاد کردیم رو روی داده‌هامون فیت کنیم و بعد روی داده‌های تست، predict انجام بدیم :


و به این ترتیب با عبور داده‌ها از step اول pipelineمون، در هر سطر از دیتافریم، یه ستون جدید بنام 'multiple' از حاصلضرب دو ستون a و b ایجاد میشه. و میتونیم ترنسفورمری که ساختیم رو به تنهایی(میخوایم فقط خروجی ترنسفورمر رو چک کنیم) روی یه دیتافریم ساختگی امتحان کنیم:


* در مثال بالا ما متد fit_transform رو فراخوانی کردیم تا خروجی رو ببینیم و در پایپلاینمون هیچ estimator ای نبود(مثل LogisticRegression)، چون estimatorها متد transform ندارن(predict دارن).




نکته 1 : درواقع وقتی متد fit رو در pipeline فراخوانی می‌کنیم، متد fit مربوط به Transformerهای حاضر در pipeline فراخوانی میشه. و با فراخوانی متد predict مربوط به pipeline، تنها متد transform مربوط به Transformerهای حاضر در pipeline فراخوانی میشه.

نکته 2 : متد fit اونجاییه که فاز "یادگیری (learning)" اتفاق میوفته و پارامترهای مدلمون و همچنین Transformerهای حاضر در pipeline ، در متد fit ایجاد میشن.



حالا بعنوان مثال فرض کنید بخوایم در ترسنفورمری که ساختیم، به اون حاصلضربمون(ضرب دو ستون a و b) میانگین داده‌های دو ستون رو اضافه کنیم و اینبار حاصل رو در ستون "multipleAndMeans" قرار بدیم. برای این منظور باید عملیات بدست آوردن میانگین‌ها رو در متد "fit" انجام بدیم تا وقتی متد transform فراخوانی شد، این مقادیر(میانگین اعداد دو ستون) رو داشته باشیم:



میتونین ازین لینک به کدها دسترسی داشته باشین.

custom transformerscikit learnscikit learn pipeline
every day is a chance to learn more
شاید از این پست‌ها خوشتان بیاید