元生技データサイエンティストのメモ帳

勉強したことの備忘録とか雑記とか

Pythonで作った決定木のモデルを可視化する

こんにちは、sue124です。
前回は以下の書籍の写経を終えた感想を書きましたが、その中で本書の良くない点として挙げた「特に決定木のモデルの可視化が省略されている」ことに関して、自分でやり方を調べたので、書いていきたいと思います。

今回は巷で良く使われているアヤメのデータから品種を特定する決定木のモデルを作って、それを可視化していきます。

import pandas as pd
#対応する名前に変換する関数
def name(num):
    if num == 0:
        return 'Setosa'
    elif num == 1:
        return 'Veriscolour'
    else:
        return 'Virginica'

#アヤメのデータを格納する(説明変数をdataX, 目的変数をdataYとする)
from sklearn import datasets

data = datasets.load_iris()

iris_dataX_df = pd.DataFrame(data=data.data, columns=data.feature_names)

iris_dataY_df = pd.DataFrame(data=data.target)
iris_dataY_df = iris_dataY_df.rename(columns={0: "Species"})
iris_dataY_df["Species"] = iris_dataY_df["Species"].apply(name)

iris_dataX_df, iris_dataY_df の中身はそれぞれこんな感じです。

iris_dataX_df.head()

f:id:sue124:20200419205659p:plain

iris_dataY_df.head()

f:id:sue124:20200419205723p:plain

このデータから決定木のモデルを作ります。

#データを分割する(訓練用:評価用 = 7:3)
from sklearn.model_selection import train_test_split
X_train, X_test, Y_train, Y_test = train_test_split(iris_dataX_df, iris_dataY_df, test_size=0.3)

#決定木モデル作成
from sklearn.tree import DecisionTreeClassifier
model = DecisionTreeClassifier()
model.fit(X_train, Y_train)

上記で作ったモデルの可視化するには、以下のようにします。

#決定木モデルの可視化
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
%matplotlib inline

fig = plt.figure(figsize=(20, 10))
plot_tree(model, fontsize=14);

f:id:sue124:20200419210050p:plain

たったこれだけで可視化できます。数行でかけるのだから、書籍の中で省略してほしくなかったなーと思います。
決定木って作ったモデルを可視化してナンボですし。
あと図中には明示されていませんが、左に分岐すると枠内に書いている条件が「True」、右に分岐すると「False」です。


ところで図を良く見ると「gini」の文字が。よくよく確認すると、このモジュールでは決定木のモデル生成に「エントロピー」でなく「ジニ不純度」を使っているそうです。公式ドキュメントを見てても良く「impurity」という単語が出てくるなぁと思ったら、そういうことだったらしいです。

###2020/06/07 追記###
エントロピーとジニ不純度の違いについてまとめてみました。
sue124.hatenablog.com