دانشجوی دکتری مهندسی کامپیوتر دانشگاه تهران، علاقهمند به حوزه علوم شناختی و اعصاب شناختی، فلسفه، هوش مصنوعی، علوم داده، بصری سازی داده، روانشناسی، اقتصاد و مالی
چگونگی کار با الگوریتم درخت تصمیم به کمک کتابخانه Scikit-Learn
در این پست میخواهیم یک مثال واقعی از یک مساله طبقهبندی را به کمک الگوریتم درخت تصمیم(Decision Tree) حل کنیم و برای این کار از کتابخانه Sci kit-Learn در پایتون استفاده میکنیم. تمامی کدهای نوشتن برنامه نیز در اینجا قرار میگیرد و هر پردازش و یا تغییری که در هر قسمت از فرآیند اجرای آن انجام دادهام نیز در اینجا ذکر میشود. تا جایی که امکان داشته باشد سعی میکنم مطالب را بسیار ساده و روان توضیح بدهم که در فهم و درک آن هیچ مشکلی وجود نداشته باشد. بنابراین تا پایان این پست با خیال آسوده مطالب را دنبال کنید.
در این مساله میخواهیم بر اساس اطلاعاتی که از یک تعدادی از دانشآموزان یک مدرسه داریم، مدلی بسازیم که بتواند پیشبینی کند که یک دانشآموز بر اساس نمرات قبلی امتحانات، آیا در امتحان نهایی قبول میشود یا رد میشود؟ در اصل مساله طبقهبندی دانشآموزان در دو دسته قبول و رد میباشد. مجموعه دادهای که برای این مساله در نظر گرفته شده است را میتوانید از لینک زیر دانلود کنید.
https://archive.ics.uci.edu/ml/datasets/student+performance
این مجموعه داده شامل اطلاعات ۶۴۹ دانشآموز میباشد. برای هر دانشآموز ۳۰ دادهی مختلف ذخیره شده است که این دادهها شامل جنسیت و سن و تحصیلات والدین و محل و زندگی و ... میباشد. در قسمت زیر این اطلاعات به علاوه جنس دادهی آنها را میتوانیم مشاهده کنیم.
همانطور که مشاهده میکنیم، جنس دادهها متفاوت میباشد. تعدادی از دادهها عددی میباشند. تعدادی دیگر طبقهبندی شده هستند و به صورت رشته حرفی هستند. بعضی از دادهها باینری هستند و این تفاوت جنس دادهها کمی کار را دشوار میکند. معمولا زمانی که در هنگام استفاده از الگوریتم درخت تصمیم با این مدل دادهها سر و کار داریم باید ابتدا دادهها را پردازش اولیه بکنیم و مجموعه دادهی جدیدی بسازیم که در تحلیل کار آسان تر شود. برای همین ابتدا سعی میکنیم با روشهایی که توضیح میدهیم تمامی اطلاعات مربوط به دانشآموزان را به صورت عددی در بیاوریم. از طرفی وقتی کمی مجموعه دادهها را بررسی کنید متوجه میشوید که تعدادی از این دادهها اطلاعات دقیقتری را برای پیشبینی در اختیار ما قرار میدهند. برای مثال تعداد غیبتهای دانشآموز یا تعداد امتحانات رد شده قبلی هر دانشآموز تاثیر بیشتری در دقت پیشبینی ما خواهد داشت. البته قبل از تست واقعی این گفتهها، در حد حدس و گمان ولی با احتمال بالا میباشد. الگوریتم درخت تصمیم هم به ترتیب به دنبال ویژگیهایی میگردد که بیشترین میزان پیشبینی پذیری را برای ما ایجاد میکند.
برای اینکه بتوانیم در هر مرحله خروجی را مشاهده کنیم، از ابزار Jupyter برای برنامهنویسی استفاده میکنیم، ولی شما میتوانید از هر IDE پایتونی برای نوشتن کدها و اجرای آن استفاده کنید. ابتدا با دستور زیر دادهها را که در یک فایل csv ذخیره شده است به کمک کتابخانه pandas لود کرده و برای اینکه اطمینان داشته باشیم که همه دادهها لود شده است، طول دادههای لود شده را چک میکنیم.
همانطور که مشاهده میکنیم تعداد دادههای لود شده ۶۴۹ است که درست میباشد. در مجموعه دادهی اصلی سه نمره وجود دارد که مجموع نمرات این سه درس معیاری برای موفقیت و یا عدم موفقیت دانشاموز میباشد. اگر مجموعه این سه نمره از ۳۵ بیشتر باشد دانشآموز موفق و اگر کمتر باشد دانشآموز ناموفق بوده است. بنابراین ما از این سه نمره استفاده کرده و یک ستون به عنوان لیبل میسازیم که در آن مقدار ۱ به معنای موفقیت دانشآموز و مقدار ۰ به معنای عدم موفقیت میباشد. این کار را با تکه کد زیر انجام میدهیم. در اصل ما در تکه کد زیر مجموع سه نمره دانشآموز را محاسبه کرده و بررسی میکنیم که اگر مجموع این نمرات بیشتر از ۳۵ باشد، دانشآموز قبول و در غیر این صورت دانشآموز مردود شده است. نتیجه را در یک ستون ذخیره میکنیم و آن را به عنوان لیبل استفاده میکنیم.
برای اجرای آنچه که گفتیم باید این محاسبه مجموع سه نمره و بررسی مقدار آن را برای هر سطر از داده انجام دهیم. برای این کار از تابع apply که یکی از توابع کتابخانه pandas میباشد استفاده میکنیم. axis=1 به معنی این است که این تابع را برای هر سطر انجام بده. بعد از اینکه ستون جدید به نام pass را با دادهی محاسبه شده برای هر سطر پر کردیم میتوانیم به کمک تابع drop سه ستون مربوط به سه تمره را از دادهها حذف کنیم.
همانطور که مشاهده میکنیم مجموعه دادهی ما دارای ۳۱ ستون است که یکی از آنها با عنوان pass نشاندهنده لیبل و مقادیر ۰ و ۱ دارد. یکی از کارهای مهمی که در رابطه با هر مجموعه دادهای باید انجام دهید، این است که میزان بایاس یا بالانس بودن دادهها را بررسی کنید. بالانس بودن داده به این معنی است که برای مثال اگر میخواهید یک مساله طبقهبندی دو کلاسه را انجام دهید، باید در مجموعه دادههای آموزش از هر کلاس به تعداد کافی و تقریبا برابر داده داشته باشید. در مجموعه دادهای که ما با آن کار میکنیم بعد از لیبل زدن به دادهها میتوان این موضوع را بررسی کرد. در تکه کد زیر ما تعداد دادههایی که دارای لیبل ۱ به معنی قبول و تعداد دادههایی با لیبل ۰ به معنی مردود را محاسبه کرده و نسبت آنها را به تعداد کل دادهها محاسبه کردیم.
همانطور که مشاهده میشود تقریبا ۵۰ درصد دادهها دارای لیبل ۱ و ۵۰ درصد دیگر دادهها دارای لیبل صفر هستند. بنابراین دادههای ما بالانس میباشد و میتوانیم فرآیند یادگیری را اجرا کنیم. زمانی که دادهها بالانس نیستند و اصطلاحا بایاس به یک طبقه میشوند یادگیری دقیق نمیباشد. بایاس به یک طبقه به این معنی است که در مجموعه دادهی ما اکثر دادهها دارای یک لیبل خاص هستند و مابقی لیبلها تعداد کمی داده دارند.
در قسمت دوم فرآیند آمادهسازی داده باید ویژگیهایی که دارای مقدار عددی نیستند را به مقدار عددی تبدیل کنیم. یکی از روشهای بسیار محبوب و مناسب برای اینکار استفاده از روش one-hot میباشد. این روش به این صورت است که برای هر ستون که دارای مقادیر غیر عددی است ابتدا تمام حالات ممکن مقدار را پیدا میکند و به ازاری هر مقدار یک ستون به دادهها اضافه میکند. مقدار این ستون جدید صفر یا یک میباشد. برای درک بهتر با یک مثال در مساله خودمان توضیح میدهیم. ویژگی Mjob نشاندهنده شغل مادر برای دانشآموز بوده است. با کمک تکه کد زیر میتوانیم مقادیر منحصر به فرد این ویژگی را مشاهده کنیم.
اگر بخواهیم به کمک روش one-hot این ستون را به مقادیر عددی تبدیل کنیم، به جای ستون Mjob پنج ستون به داده اضافه میکنیم و آنها را Mjob_at_home، Mjob_health، Mjob_other، Mjob_services و Mjob_teacher مینامیم. سپس برای هر دانشآموز فقط مقدار یکی از ستونها برابر با یک و مابقی برابر با صفر خواهد بود. برای مثال برای دانشآموزی که مقدار ستون Mjob او health میباشد، بعد از one-hot کردن مقدار ستون Mjob_health برابر با یک و بقیه ستونها برابر با صفر خواهند بود. برای انجام این کار از تابع get_dummies کتابخانه pandas استفاده میکنیم و لیست ستونهایی که میخواهیم عملیات one-hot روی آنها انجام شود را به عنوان ورودی میدهیم. خروجی یک مجموعه دادهی جدید خواهد بود که تمامی ستونهای خواسته شده به کمک روش one-hot به شکل جدید در آمدهاند. تکه کد زیر این کار را نشان میدهد و ما ۵ سطر اول را بعد از اجرای کد نشان دادهایم.
همانطور که مشاهده میکنیم بعد از عملیات one-hot تعداد ستونها از ۳۱ به ۵۷ تا رسیده است. حالا بعد از آمادهسازی مجموعه داده و عددی کردن مقادیر باید یک بار دادهها را shuffle کرده و دادههای آموزش و تست را جدا کنیم. کتابخانه Scikit learn تابع آماده برای shuffle کردن داده دارد ولی ما در اینجا از توابع کتابخانه pandas برای این کار استفاده میکنیم. در تکه کد زیر ما ابتدا به کمک تابع sample دادهها را shuffle میکنیم. مقدار frac=1 نشاندهنده این است که چه درصدی از دادهها را به عنوان خروجی برگرداند که ما با مقدار ۱ مشخص میکنیم که کل دادهها را میخواهیم.
سپس ۵۰۰ دادهی اول را به عنوان دادهی آموزش و مابقی که ۱۴۹ تا میباشد را به عنوان دادهی تست مشخص میکنیم. بعد از آن ستون pass را که به عنوان لیبل میباشد را از ویژگیها جدا کرده و این کار را هم برای دادههای آموزش و هم برای دادههای تست انجام میدهیم.
بعد از آماده شدن دادههای آموزش و تست زمان آن است که مدل درخت تصمیم را به کمک کتابخانه scikit learn بسازیم. برای این کار از تابع DecisionTreeClassifier استفاده میکنیم. برای بدست آوردن information gain از entropy استفاده میکنیم. و مدل را طوری تنظیم میکنیم که درخت حداکثر تا عمق ۵ لایه پایین برود. سپس مدل را بر روی دادهی آموزش fit میکنیم. تکه کد زیر این کار را انجام میدهد.
بعد از اینکه مدل روی دادههای آموزش، آموزش دید باید آن را بر روی دادههای تست، تست کنیم تا میزان دقت نهایی را بدست بیاوریم. تکه کد زیر این کار را انجام میدهد.
همانطور که مشاهده میکنیم دقت مدل بر روی دادهی تست برابر با ۶۹ درصد میباشد. اولین کاری که بعد از تست انجام میدهیم cross validation میباشد تا اطمینان پیدا کنیم که دادهها بایاس نمیباشد. cross validation به این صورت میباشد که هر بار یک تکهی متفاوتی از داده را به عنوان دادهی تست انتخاب میکند مدل را آموزش میدهد و دقت را اندازهگیری میکند و در نهایت میانگین دقتها را گزارش میکند. با این کار اطمینان پیدا میکنیم که مدل بر روی دادههای خیلی بد و یا خیلی خوب آموزش ندیده است و اصطلاحا بایاس نشده است. این کار را به کمک تابع cross_val_score انجام میدهیم. مدل، کل داده و لیبل را به این تابع میدهیم و مشخص میکنیم که cross validation چندتایی میخواهیم. ما در کد این مقدار را ۵ تنظیم کردهایم. عدد ۵ نشان میدهد که تابع دادهها را به ۵ قسمت مساوی تقسیم میکند. هر بار یک قسمت را به عنوان دادهی تست و مابقی را به عنوان دادهی آموزش در نظر میگیرد و دقت مدل را محاسبه میکند. تکه کد زیر این کار را انجام میدهد. (دقت داشته باشید که کل داده که ویژگیها و لیبل آن جدا شده است به عنوان ورودی به تابع داده شده است).
تابع cross validation هم یک میانگین برای دقت و هم یک انحراف معیار برای دقت محاسبه میکند که ما آنها را در خروجی کد بالا چاپ کردهایم. دقت میانگین برابر با ۶۹ درصد و انجراف معیار مثبت و منفی ۶ درصد میباشد.
یکی از پارامترهایی که برای ساخت مدل استفاده کردیم حداکثر عمق درخت بوده که آن را ۵ در نظر گرفته بودیم. سوال این است که آیا برای بهبود مدل این مقدار تاثیر گذار است؟ میتوانیم جواب این سوال را تست کنیم. ما مدلهایی میسازیم که دارای عمق ۱ تا عمق ۲۰ متغییر باسند و دقت هر کدام را به وسیبه cross validation اندازهگیری میکنیم. سپس بهتریم مقدار را مشخص میکنیم. تکه کد زیر این کار را انجام میدهد.
همانطور که مشاهده میشود با افزایش عمق درخت دقت کمتر میشود. احتمالا به این دلیل است که افزایش عمق درخت باعث ایجاد over fit میباشد. over fit به زبان ساده به معنی حفظ کردن داده میباشد نه یادگیری آن. به همین دلیل زمانی که داده را حفظ کنیم اگر دادهی جدیدی به عنوان ورودی به ما داده شود دقت پایین میآید.
برای نمایش میزان دقت بر اساس عمق درخت میتوانیم از ابزار بصریسازی پایتون استفاده کنیم. در تکه کد زیر ابتدا یک مجموعه داده جدید درست میکنیم و مقادیر اندازهی عمق درخت، میانگین دقت و انجراف معیار را در آن وارد میکنیم.
سپس به کمک توابع matplotlib که ابزار بصریسازی پایتون میباشد یک errorBar رسم میکنیم. تکه کد زیر این کار را انجام میدهد.
توضیح پیادهسازی مدل درخت تصمیم در این جا به اتمام رسیده است. اگر بخواهیم مراحل کار را به صورت خلاصه مرور کنیم به این صورت است که ما ابتدا مجموعه دادهی ورودی را گرفته و آن را بررسی اولیه میکنیم. دادهها و نوع آنها را مورد بررسی قرار میدهیم. سپس شروع به آمادهسازی مجموعه داده برای مناسبسازی آموزش میکنیم. در نهایت مدل را با دادههای آموزش، آموزش داده و دقت مدل را با دادههای تست، تست میکنیم. بعد از آن هم برای بهبود دقت مدل سعی در تغییر پارامترهای تاثیرگذار در مدل و فرآیند آموزش میکنیم.
مطلبی دیگر از این انتشارات
رگرسیون منطقی (logistic regression)
مطلبی دیگر از این انتشارات
پرکاربردترین تکنیکهای آماری در علمداده
مطلبی دیگر از این انتشارات
پیچیدگی کد