حل یک مثال با الگوریتم Classification Tree در R

در این مقاله سعی داریم با استفاده از الگوریتم Classification Tree در محیط زبان R یک مثال حل کنیم ، بدون مقدمه اضافی سراغ مراحل حل مسئله می رویم . در ابتدا می بایست بسته tree را در محیط توسعه خود (که اینجا نرم افزار RStduio می باشد) با دستور ذیل نصب کنیم.

install.packages("tree")

پس از نصب این بسته به تابع tree موجود در این بسته با فرمت ذیل دسترسی خواهیم داشت

آرگومان های وروردی این تابع عبارتند از:

  • آرگومان z : دیتا فریم کلاس های مورد نظر برای دسته بندی می باشد
  • آرگومان data : دیتا ستی هست که می خواهم دسته بندی روی آن انجام شود
  • آرگومان split: مشخص کننده معیار تقسیم براساس انحراف یا gini می باشد

در ادامه گام به گام پیش می رویم

گام اول : بارگزاری بسته های مورد نظر در محیط توسعه

برای مسئله مورد نظر خود نیازمندیم داده های مربوط به یکسری خوردو که در دیتاست Vehicle موجود هست را به محیط R اضافه کنیم، این داده ها در بسته mlbench هستند و کافیست تا این بسته را نصب و در کد خود با تابع library فراخوانی کنیم.

install.packages("mlbench ")

library (mlbench)

data (Vehicle)

این مجموعه شامل داده‌های چهار نوع خودرو (اتوبوس دو طبقه، وانت شورولت، ساب 9000 و اوپل مانتا 400) است که شامل 846 مشاهده و 18 ویژگی عددی استخراج شده و همچنین یک متغیر اسمی کلاس اشیاء می باشد. که شرح هریک از متغییرها (=ویژگی ها) بصورت ذیل می باشد

با دستور summary می توانیم یکسری اطلاعات آماری برای هر ویژگی بدست آورد بطور نمونه برای متغییر Comp نتیجه بصورت زیر می باشد

summary(Vehicle [1])

گام دوم : آماده سازی داده ها و پارامترها

ما از 846 مشاهده ، 500 مشاهده را به صورت تصادفی برای ایجاد یک مجموعه آموزشی (train) انتخاب می کنیم، از این مشاهدات برای ساختن درخت طبقه بندی استفاده خواهیم کرد. برای اینکار از دستورات رایج ذیل استفاده می کنیم

set.seed (107)

N=nrow(Vehicle)

train <- sample(1:N, 500, FALSE)

گام سوم : تخمین با درخت تصمیم

اکنون آماده ساختن درخت تصمیم با استفاده از نمونه آموزشی هستیم برای این کار از دستور ذیل استفاده می نماییم

fit<- tree(Class ~., data = Vehicle[train ,], split ="deviance")

ما از انحراف (deviance) به عنوان معیار تقسیم استفاده می کنیم،ممکن است از اینکه R چقدر سریع درخت را می سازد شگفت زده شوید!

ذکر این نکته مهم است که به یاد داشته باشید که متغیر پاسخ ما در اینجا factor (یا به عبارتی categorical)است در صورتی که امکان دارد متغییر پاسخ در دیتا ستی که استفاده می کنیم عددی باشد که می بایست به با استفاده از تابع factor به نوع factor تبدیل کنیم، می‌برد. برای چک کردن نوع متغییر پاسخ می توانیم از این فرمان استفاده نماییم

class(Vehicle$Class)

که خروجی آن

[1] "factor"

نوع متغییر Class را factor بر می گرداند. حال کافیست از متغییر fit یک summary بگیریم

summary(fit)

نتیجه خروجی بصورت ذیل می باشد

مواردی که این خلاصه به ما نشان می دهد بصورت ذیل است

  • نوع الگوریتم درختی که استفاده شده، در این مورد درخت طبقه بندی؛
  • فرمول مورد استفاده برای fit کردن درخت؛
  • متغیرهای مورد استفاده برای fit کردن درخت که اینجا 10 متغییر است
  • تعداد گره های پایانی (برگ های درخت) در این مورد 16؛
  • انحراف میانگین باقیمانده - 0.9425؛
  • میزان خطای طبقه بندی 0.252 یا 25.2 درصد

همچنین برای اینکه بصورت گرافیکی درخت را نمایش دهیم می توانیم از دستور plot استفاده نماییم و برای نمایش مقادیر هر node می توان از دستور text استفاده کرد

plot(fit); text(fit)


نکته : ارتفاع خطوط عمودی در شکل فوق متناسب با کاهش انحراف است. هرچه این خط طولانی تر باشد کاهش بزرگتر است. این نحوه نمایش به شما امکان می دهد بخش های مهم را بلافاصله شناسایی کنید. اگر می خواهید مدل را با استفاده از طول های یکنواخت رسم کنید، می توانید از دستور زیر استفاده کنید

plot(fit,type="uniform") ; text(fit)

گام 4 : ارزیابی مدل

یکی از مشکلات رایج درختان طبقه بندی over fit شدن مدل است .یکی از روش‌های کاهش این ریسک، استفاده از اعتبارسنجی متقاطع (cross-validation) می باشد. برای هر نمونه ، مدل را fitمی‌کنیم و بررسی می‌کنیم که درخت در چه سطحی بهترین نتایج را می‌دهد (با استفاده از deviance یا نرخ طبقه‌بندی اشتباه). برای این منظور از تابع ()cv.tree استفاده می کنیم در اینجا ما اعتبارسنجی متقاطع leave-one-out را با استفاده از انحراف و نرخ طبقه‌بندی اشتباه را انتخاب نمودیم . پس از رسم نتایج کنار هم همانطور که در شکل زیر ترسیم شده است. خط ناهموار نشان می دهد که حداقل تفاوت بین انحراف و طبقه بندی اشتباه رخ داده است و از آنجایی که طبقه‌بندی اشتباه و انحراف cross validated شده هر دو به حداقل تعداد شاخه‌های درخت طبق بندی بدست آمده در مرحله قبل می رسند لذا هرس این درخت بیش از این نتیجه ای نخواهد داشت.برای اجرای موارد فوق از دستورات زیر استفاده می کنیم

fitM.cv <- cv.tree(fit ,K=346,FUN=prune.misclass)

fitP.cv <- cv.tree(fit ,K=346,FUN=prune.tree)

par(mfrow = c(1, 2))

plot(fitM.cv)

plot(fitP.cv)

گام پنجم : پیش بینی

نها یتا در این مرحله برای انجام پیشبینی و تشکیل ماتریس confusion می توانیم بصورت زیر عمل کنیم

pred <-predict(fit ,newdata=Vehicle[-train ,])

pred.class <- colnames(pred)[max.col(pred , ties.method = c("random"))]

table(Vehicle$Class[-train],pred.class ,dnn=c("Observed Class","Predicted Class" ))

که این نتیجه را در بر دارد

نهایتا برای محاسبه خطا مدل می توان به صورت ذیل عمل کرد که نشان دهنده 32.4% خطا برای این مدل است.

error_rate = (1-sum(pred.class== Vehicle$Class[-train])/346)

round(error_rate ,3)