機械学習で利用する分類器の一つ

決定木

決定木(decision tree)は、データを複数のクラスに分類する教師あり学習のアルゴリズムの一つである。主に木構造を利用した分類アルゴリズムである。学習しやすく、また学習結果の解釈も容易である。しかし、過学習を起こしやすい。例えば、決定木を利用した脊椎動物の分類は、以下のようなモデルを利用することができる。

決定木による脊椎動物の分類

scikit-learn を利用した決定木の作成と予測

決定木モデルの作成と予測

scikit-learn ライブラリー中の tree.DecisionTreeClassifier クラスを利用することで、決定木モデルの作成と予測を行うことができる。

from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

if __name__ == '__main__':
    iris = datasets.load_iris()
    print(iris.data)
    ## [[ 5.1  3.5  1.4  0.2]
    ##  [ 4.9  3.   1.4  0.2]
    ##  [ 4.7  3.2  1.3  0.2]
    ##  [ 4.6  3.1  1.5  0.2]
    ## ..
    ##  [ 6.2  3.4  5.4  2.3]
    ##  [ 5.9  3.   5.1  1.8]]
   
    print(iris.target)
    ## [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
    ##  0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
    ##  1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
    ##  2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
    ##  2 2]

    clf = DecisionTreeClassifier(max_depth = 3)
    clf = clf.fit(iris.data, iris.target)

    predicted_target = clf.predict(iris.data)
    print(predicted_target)
    ## [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
    ##  0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1 1
    ##  1 1 1 2 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 1 2 2 2 2
    ##  2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
    ##  2 2]

    score = accuracy_score(iris.target, predicted_target)

    print(score)
    ## 0.973333333333

可視化

決定木モデルを可視化するには、pydotplus ライブラリーなどを利用して行う。次のスクリプトを動かすには、Python の pydotplus ライブラリーをインストールする必要があるほか、システムに GraphViz もインストールする必要がある。

import pydotplus
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz
from sklearn.externals.six import StringIO

if __name__ == '__main__':

    iris = datasets.load_iris()
    clf = DecisionTreeClassifier(max_depth = 3)
    clf = clf.fit(iris.data, iris.target)

    dot = StringIO()
    export_graphviz(clf, out_file = dot)
    graph = pydotplus.graph_from_dot_data(dot.getvalue())
    graph.write_pdf("graph.pdf")

生成された決定木の構造は次のように PDF として出力される。

決定木モデルの可視化

交差検証

決定木モデルを作成する時、階層を指定する必要がある。この階層を、交差検証により決定していく。

from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import cross_val_score

if __name__ == '__main__':

    iris = datasets.load_iris()

    max_depth_list = [2, 3, 4, 5, 6, 7, 8, 9]

    for max_depth in max_depth_list:
        clf = DecisionTreeClassifier(max_depth = max_depth)
        score = cross_val_score(estimator = clf, X = iris.data, y = iris.target, cv = 10)
        print([max_depth, score.mean()])
        ## [2, 0.94666666666666666]
        ## [3, 0.95999999999999996]
        ## [4, 0.95333333333333337]
        ## [5, 0.95333333333333337]
        ## [6, 0.95333333333333337]
        ## [7, 0.95333333333333337]
        ## [8, 0.95333333333333337]
        ## [9, 0.95999999999999996]

階層の他にハイパーパラメーターが複数ある場合は、次のように GridSearchCV メソッドを利用すると便利である。このとき、ハイパーパラメーターの名前をディクショナリーのキーとして、ハイパーパラメーターを値として保存して、GridSearchCV に与える。

from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import GridSearchCV

if __name__ == '__main__':

    iris = datasets.load_iris()

    params = {'max_depth': [2, 3, 4, 5, 6, 7, 8, 9],
              'criterion': ['gini', 'entropy']}
    
    clf = GridSearchCV(DecisionTreeClassifier(), params, cv = 10)
    clf.fit(X = iris.data, y = iris.target)

    print(clf.best_estimator_)
    ## DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3,
    ##             max_features=None, max_leaf_nodes=None,
    ##             min_impurity_decrease=0.0, min_impurity_split=None,
    ##             min_samples_leaf=1, min_samples_split=2,
    ##             min_weight_fraction_leaf=0.0, presort=False, random_state=None,
    ##             splitter='best')

    print(clf.best_score_)
    ##0.96

    print(clf.best_params_)
    ## {'criterion': 'gini', 'max_depth': 3}

References