- We start with a few imports, most notably scikit-learn:
```python
import gzip
import pickle

import numpy as np
import graphviz
from sklearn import tree
```
- Let's load the data and split it into inputs and outputs:
```python
with gzip.open('balanced_fit.npy.gz', 'rb') as f:
    balanced_fit = np.load(f)
with open('ordered_features', 'rb') as f:
    ordered_features = pickle.load(f)

train_X = balanced_fit[:, :-2]  # drop the last two columns: POS and errors
train_Y = balanced_fit[:, -1]   # the last column is the error label to predict
```
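To make the slicing concrete, here is a small sketch on a toy array. It assumes, as the comment in the code above indicates, that each row of `balanced_fit` ends with the POS column followed by the error label; the array `toy_fit` and its values are hypothetical stand-ins, not real data.

```python
import numpy as np

# Hypothetical stand-in for balanced_fit: two feature columns,
# then POS, then the error label (0 or 1)
toy_fit = np.array([
    [30.0, 12.0, 1001, 0],
    [5.0, 40.0, 1050, 1],
    [28.0, 15.0, 1100, 0],
])

toy_X = toy_fit[:, :-2]  # every column except the last two (POS and error)
toy_Y = toy_fit[:, -1]   # only the last column (the error label)

print(toy_X.shape)  # (3, 2)
print(toy_Y)        # [0. 1. 0.]
```

Because POS is excluded from `toy_X`, the classifier cannot simply memorize genomic positions; it only sees the remaining feature columns.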
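- We fit a decision tree classifier, limiting its depth to 4:
```python
# A maximum depth of 4 keeps the tree small enough to inspect visually
estimator = tree.DecisionTreeClassifier(max_depth=4)
tree_fit = estimator.fit(train_X, train_Y)
print(tree_fit.feature_importances_)
```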
We also print `feature_importances_`, which holds one value per column of `train_X`, in column order. Here it suggests that QUAL is the most important feature.
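Reading a raw importance array is easier once each value is paired with its feature name. The following is a minimal sketch; the helper `rank_features` is hypothetical, and the example importances and names are made up for illustration. It assumes the names list is aligned position by position with the columns of the training matrix.

```python
import numpy as np

def rank_features(importances, names):
    """Return (name, importance) pairs sorted from most to least important.

    Assumes names[i] labels column i of the training matrix.
    """
    importances = np.asarray(importances)
    order = np.argsort(importances)[::-1]  # column indices, largest importance first
    return [(names[i], float(importances[i])) for i in order]

# Hypothetical importances and feature names, for illustration only
example = rank_features([0.1, 0.7, 0.2], ['DP', 'QUAL', 'MQ'])
for name, score in example:
    print(f'{name}: {score:.2f}')
# QUAL: 0.70
# MQ: 0.20
# DP: 0.10
```

With the real data this would be called as `rank_features(tree_fit.feature_importances_, ordered_features)`, but only if `ordered_features` names exactly the columns kept in `train_X`; if it also lists POS and the error label, those trailing names would need to be dropped first.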
- Let's draw the tree:
graphviz_representation = tree.export_graphviz(tree_fit, ...