رگراسیون خطی به زبان ساده

سلام (= راستش تصمیم گرفتم هماهنگ با مطالعه های خودم توی حوزه ی هوش مصنوعی و یادگیری ماشین، یه سری مفاهیمی که ممکنه کمی پیچیده باشن رو به زبون ساده توضیح بدم. این پست درباره ی «Linear Regression» یا «رگراسیون خطی» هست که یکی از متدهای یادگیری ماشینه. قبل از اینکه برم سراغ مفاهیم و تعاریف، بهتره بگم اصن قضیه چیه و میخوایم چیکار کنیم! ما در واقع میخوایم به سیستم، یه دیتاست بدیم و سیستم اون رو بررسی کنه و الگویی رو ازش کشف کنه. حالا اگه یه داده ای که توی دیتای اولیه نیست رو از سیستم بخوایم، سیستم باید بتونه با توجه به الگوریتم هاش حدس مناسبی رو برای ما انجام بده، که با بقیه دیتاست همخوانی داشته باشه. به نمودار پایین نگاه کنین. این نمودار فرضی میزان فروش یه بستنی فروشی رو در دماهای مختلف در طول سال نشون میده:

به طور کلی میتونیم حدس بزنیم که وقتی هوا سردتره، فروش بستنی کمتره و وقتی هوا گرمتر میشه، فروش هم بیشتر میشه. اما برای این قاعده فرمول دقیقی وجود نداره! از طرفی، برای میزان فروش در دماهایی مثل 21 یا 24 درجه سانتیگراد، اطلاعاتی در نمودار نیست و باید حدسشون بزنیم! این حدس رو میسپاریم به یادگیری ماشین! و اینجاست که رگراسیون خطی وارد میشه. قبل از شروع بذارید این ترکیب کلمات رو تعریف کنم. رگراسیون (Regression) یعنی پیشبینی مقداری که پیوستگی داره. حالا این ینی چی؟ بذارین نقطه مقابلش هم تعریف کنم: «طبقه بندی» یا «Classification» که یعنی پیشبینی مقداری که پیوستگی نداره و گسسته اس. رگراسیون با پیشبینی عدد سروکار داره و یه مقدار کمّی خروجی میده، درحالی که طبقه بندی بین دو یا چند چیز(مثل Yes و No، یا مرد و زن) که کیفی هستن تصمیم گیری میکنه. به این مثال دقت کنین:

این یه نمونه از طبقه بندی هست که با توجه به میزان قد و وزن، جنسیت هر نفر رو حدس میزنه. نقاط مثلثی مرد و نقاط دایره ای زن هستن. پس سیستم بین این دو تصمیم میگیره. ولی ما اینجا با رگراسیون کار داریم، که قراره یه عدد برامون خروجی بده! مثلا دمای هوا رو به عنوان ورودی بگیره و میزان فروش رو حدس بزنه. و یا برعکسش! خب، حالا بریم سراغ رگراسیون خطی. به نمودار پایین دقت کنین:

این همون نمودار فروش بستنی بر حسب دماییه که بالاتر داشتیم. اما اینجا یه خط به نمودار اضافه کردیم و این خط میتونه مقادیر مجهولمون(مثل فروش برای دمای 21 و 24 درجه سانتیگراد) رو حدس بزنه! اما این خط از کجا اومده و چرا این خط؟! ما به کمک رگراسیون خطی و یه سری الگوریتم میتونیم چنین خطی رو برای نمودارمون درنظر بگیریم که بتونه به ازای هر ورودی معتبر، حدس معقولی رو به ما تحویل بده. حالا این خط چه خصوصیاتی باید داشته باشه؟ منطقاً اینکه تا حد امکان به نقاط روی نمودار نزدیک باشه تا جواب عجیب و غریب بهمون تحویل نده! برای مثال، مشخصا خط های نمودار زیر انتخاب خوبی نیستن:

برای اینکه خط مورد نظر تا حد امکان به نقاط نمودار نزدیک باشه و ازشون عبور کنه، یه تعریف داریم: مجموع فواصل نقاط تا خط باید حداقل باشد. که منطقی هم به نظر میاد. هرچی فاصله خط از نقاط کمتر، حدس دقیقتر! حالا چجوری این تعریف رو به زبان ریاضی بنویسیم؟ به این فرمول نگاه کنین:

قبل از اینکه از قیافه ترسناک این فرمول ناامید شین و صفحه رو ببندین، اجازه بدین توضیحش بدم. قرار بود فاصله هر نقطه از خط رو حساب کنیم، جمعشون کنیم و حداقل این مجموع ها رو بین تمامی خطوط ممکن پیدا کنیم. عبارت توان دار توی فرمول هم فاصله هر نقطه ی y ⁽ⁱ⁾ روی نمودار رو از نقطه ای که خط براش پیشبینی کرده محاسبه میکنه. سیگما مجموع این فواصل رو به اندازه ی m (که تعداد داده ها یا همون نقطه ها باشه) رو محاسبه میکنه و ترم 1/2m هم این مقدار رو کوچیکتر و استانداردتر میکنه. به این تابع J( θ₀ , θ₁ ) میگیم تابع هزینه یا کاست (Cost Function). ریاضی به ما میگه که مقدار این تابع عجیب به ازای θ₀ و θ₁ باید حداقل باشه. حالا این θ₀ و θ₁ چی هستن؟! تتا یک و تتا صفر ضرایبی هستن که به ترتیب شیب خط و مکان این خط در صفحه مختصات رو مشخص میکنن. و ما میتونیم معادله ی خط مورد نظرمون رو به کمک این دو عدد بنویسیم. یعنی به صورت مقابل : y = θ₀ + θ₁x . این معادله ی خطی که نوشتم، همون عبارت hθ توی فرمول بالاست. صرفا بدونین به این تابع که تابع خطمون هست میگیم فرض یا Hypothesis. خب حالا کافیه برای این تابع هزینه مون، یه مینیموم پیدا کنیم. الگوریتم هایی هست برای اینکه این مینیموم پیدا بشه؛ مثلا اینکه یه خط فرضی درنظر بگیریم، به یه جهت بچرخونیمش(یا درواقع θ₁ رو تغییر بدیم) و تابع هزینه رو حساب کنیم و تاجایی چرخش رو ادامه بدیم که به حداقل تابع هزینه برسیم. و حالا باید مکان این خط رو مشخص کنیم یا همون θ₀. به اندازه ای θ₀ رو تغییر بدیم تا تابع هزینه باز هم حداقل بشه. و درنهایت ما مینیمم تابع هزینه رو پیدا کردیم. روشهای خوشگلتری مثل کاهش گرادیانی (Gradient Descent) هم هست که خب، بهتره خودتون دربارش بخونید (=

توی این پست هم پیاده سازی این الگوریتم رو در پایتون، مفصّصّصّل توضیح دادم (=