I get wrong class names in the decision tree image produced by scikit-learn's tree.export_graphviz. The program is as follows:
import matplotlib.pyplot as plt
import matplotlib.image as img
import pydot
from sklearn import tree
digital_table = [[0, 0], [0, 1], [1, 0], [1, 1]]
digital_label = ['zero', 'one', 'two', 'three']
digital_name = ['idx-1', 'idx-2']
digital_tree = tree.DecisionTreeClassifier()
digital_tree.fit(digital_table, digital_label)
with open("digital.dot", 'w') as f:
    # export_graphviz writes the DOT source straight to the open file
    tree.export_graphviz(digital_tree,
                         feature_names=digital_name,
                         class_names=digital_label,
                         filled=True, rounded=True,
                         out_file=f)
(graph,) = pydot.graph_from_dot_file("digital.dot")
graph.write_png("digital.png")
plt.imshow(img.imread('digital.png'))
plt.show()
The output is as follows:
The problem is with the class names shown in the leaves. For example, the green box should be labeled 'three' when both idx-1 and idx-2 are 1, but the image shows the label 'one'. Can anyone explain why?
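export_graphviz matches class_names by position against the classes the classifier learned, and the classifier stores those classes in ascending (sorted) order, not in the order of your training rows. With DecisionTreeClassifier it therefore helps to change the class labels into numbers like 0, 1, 2.
Then look at the classes_ attribute of the fitted classifier.
It will give you the class labels in ascending order. Then specify your class_names in the same order. They can be strings.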
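A minimal sketch of the fix, reusing the variable names from the question. Two things here are my own additions rather than part of the original program: random_state=0 (only to make the run repeatable) and out_file=None (so the DOT text comes back as a string and no Graphviz install is needed to try it). With string labels, the sorted order is alphabetical, which is why 'one' ended up on the leaf that belongs to 'three'.

```python
from sklearn import tree

digital_table = [[0, 0], [0, 1], [1, 0], [1, 1]]
digital_label = ['zero', 'one', 'two', 'three']
digital_name = ['idx-1', 'idx-2']

# random_state is an assumption added for repeatability only
digital_tree = tree.DecisionTreeClassifier(random_state=0)
digital_tree.fit(digital_table, digital_label)

# classes_ holds the unique labels in sorted order
# (alphabetical for strings), not in training-row order.
sorted_names = [str(c) for c in digital_tree.classes_]
print(sorted_names)  # ['one', 'three', 'two', 'zero']

# Passing the names in that same order keeps each leaf label
# aligned with the class index the tree uses internally.
dot_source = tree.export_graphviz(digital_tree,
                                  feature_names=digital_name,
                                  class_names=sorted_names,
                                  filled=True, rounded=True,
                                  out_file=None)  # return DOT text as a string

# Sanity check: the model itself was never wrong, only the labels in the picture.
prediction = digital_tree.predict([[1, 1]])[0]
print(prediction)  # three
```

To get the PNG as in the question, write dot_source to digital.dot (or pass out_file=f as before) and keep the pydot steps unchanged.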