こんにちは、sue124です。
前回は以下の書籍の写経を終えた感想を書きましたが、その中で本書の良くない点として挙げた「特に決定木のモデルの可視化が省略されている」ことに関して、自分でやり方を調べたので、書いていきたいと思います。
今回は巷で良く使われているアヤメのデータから品種を特定する決定木のモデルを作って、それを可視化していきます。
import pandas as pd #対応する名前に変換する関数 def name(num): if num == 0: return 'Setosa' elif num == 1: return 'Veriscolour' else: return 'Virginica' #アヤメのデータを格納する(説明変数をdataX, 目的変数をdataYとする) from sklearn import datasets data = datasets.load_iris() iris_dataX_df = pd.DataFrame(data=data.data, columns=data.feature_names) iris_dataY_df = pd.DataFrame(data=data.target) iris_dataY_df = iris_dataY_df.rename(columns={0: "Species"}) iris_dataY_df["Species"] = iris_dataY_df["Species"].apply(name)
iris_dataX_df, iris_dataY_df の中身はそれぞれこんな感じです。
iris_dataX_df.head()
iris_dataY_df.head()
このデータから決定木のモデルを作ります。
#データを分割する(訓練用:評価用 = 7:3) from sklearn.model_selection import train_test_split X_train, X_test, Y_train, Y_test = train_test_split(iris_dataX_df, iris_dataY_df, test_size=0.3) #決定木モデル作成 from sklearn.tree import DecisionTreeClassifier model = DecisionTreeClassifier() model.fit(X_train, Y_train)
上記で作ったモデルの可視化するには、以下のようにします。
#決定木モデルの可視化 import matplotlib.pyplot as plt from sklearn.tree import plot_tree %matplotlib inline fig = plt.figure(figsize=(20, 10)) plot_tree(model, fontsize=14);
たったこれだけで可視化できます。数行でかけるのだから、書籍の中で省略してほしくなかったなーと思います。
決定木って作ったモデルを可視化してナンボですし。
あと図中には明示されていませんが、左に分岐すると枠内に書いている条件が「True」、右に分岐すると「False」です。
ところで図を良く見ると「gini」の文字が。よくよく確認すると、このモジュールでは決定木のモデル生成に「エントロピー」でなく「ジニ不純度」を使っているそうです。公式ドキュメントを見てても良く「impurity」という単語が出てくるなぁと思ったら、そういうことだったらしいです。
###2020/06/07 追記###
エントロピーとジニ不純度の違いについてまとめてみました。
sue124.hatenablog.com