scikit learn decision tree export graphviz - wrong class names in the decision tree


I get wrong class names in the decision tree produced by scikit-learn's tree.export_graphviz. The program is as follows:

import matplotlib.pyplot as plt
import matplotlib.image as img
import pydot
from sklearn import tree

digital_table = [[0, 0], [0, 1], [1, 0], [1, 1]]
digital_label = ['zero', 'one', 'two', 'three']
digital_name = ['idx-1', 'idx-2']

digital_tree = tree.DecisionTreeClassifier()
digital_tree.fit(digital_table, digital_label)

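with open("digital.dot", 'w') as f:
    # export_graphviz writes to the open file and returns None here,
    # so its result is not assigned back to f
    tree.export_graphviz(digital_tree,
                         feature_names=digital_name,
                         class_names=digital_label,
                         filled=True, rounded=True,
                         out_file=f)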
(graph,) = pydot.graph_from_dot_file("digital.dot")
graph.write_png("digital.png")

plt.imshow(img.imread('digital.png'))
plt.show()

The output is as follows:

[Image: the decision tree]

The problem is with the class names shown in the leaves. For example, the green box should be labeled 'three' when both idx-1 and idx-2 are 1, but the image shows the label 'one'. Can anyone explain why?


There are 2 answers

0
noobie2023 On

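When you fit a DecisionTreeClassifier, the class labels are encoded internally as integers 0, 1, 2, ... in sorted order, not in the order of your label list. String labels are accepted, but export_graphviz looks up class_names by that internal index.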

Then use:

class_names = decision_tree_classifier.classes_

This gives you the class labels in ascending (sorted) order. Pass class_names to export_graphviz in that same order. The entries can be strings.
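As a rough sketch of that fix applied to the data from the question (the random_state=0 argument and the variable names ordered_names and dot_source are my own additions, and out_file=None is used only so the DOT text comes back as a string rather than being written to a file):

```python
from sklearn import tree

digital_table = [[0, 0], [0, 1], [1, 0], [1, 1]]
digital_label = ['zero', 'one', 'two', 'three']
digital_name = ['idx-1', 'idx-2']

# random_state is an assumption, added only to make the run repeatable
digital_tree = tree.DecisionTreeClassifier(random_state=0)
digital_tree.fit(digital_table, digital_label)

# classes_ holds the labels in the sorted order the tree uses internally
ordered_names = [str(c) for c in digital_tree.classes_]
print(ordered_names)

# With out_file=None, export_graphviz returns the DOT source as a string
dot_source = tree.export_graphviz(digital_tree,
                                  feature_names=digital_name,
                                  class_names=ordered_names,
                                  filled=True, rounded=True,
                                  out_file=None)
```

The dot_source string can then be written to digital.dot and rendered as in the question.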

1
Vicente Valencia Navech On

Try sorting the class labels alphabetically before passing them to export_graphviz, for example class_names=sorted(digital_label).
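To see why sorting helps, here is a small check that reproduces the mislabeling without Graphviz. It is only an illustration under my own assumptions: the names pred, idx, wrong_name and right_name and the random_state=0 argument are not from the question.

```python
from sklearn import tree

digital_table = [[0, 0], [0, 1], [1, 0], [1, 1]]
digital_label = ['zero', 'one', 'two', 'three']

# random_state is an assumption, added only to make the run repeatable
clf = tree.DecisionTreeClassifier(random_state=0)
clf.fit(digital_table, digital_label)

# The tree itself classifies [1, 1] correctly
pred = clf.predict([[1, 1]])[0]

# Position of that class in the tree's internal (sorted) ordering
idx = list(clf.classes_).index(pred)

# Indexing the unsorted list with the internal index gives the wrong name
wrong_name = digital_label[idx]

# Indexing the sorted list gives the name the tree actually means
right_name = sorted(digital_label)[idx]

print(pred, idx, wrong_name, right_name)
```

If the label list contains duplicates (one label per training row rather than one per class), sorted(set(labels)) would be the safer thing to pass.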