Print out confusion matrix with Java


This might sound stupid, but I couldn't find any simple example to reference. Can someone please give an example of printing a confusion matrix using Java?

Something like this (the output):

p\a   Head    Tail
Head    1      4
Tail    4      1

Assume the data is stored in a HashMap like this:

HashMap<String, Integer>
key (String)    = "Head, Tail"
value (Integer) = 4

update (sample code):

import java.util.HashMap;
import java.util.Map;

public class Main {
    public static void main(String[] args) {

        HashMap<String, Integer> cmatrix = new HashMap<String, Integer>();

        // the string part can hold more than 2 values, all separated with a comma
        cmatrix.put("tail, head", 1);
        cmatrix.put("head, tail", 4);
        cmatrix.put("tail, tail", 1);
        cmatrix.put("head, head", 4);

        // print each "key : count" pair (HashMap iteration order is not guaranteed)
        for (Map.Entry<String, Integer> entry : cmatrix.entrySet()) {
            System.out.println(entry.getKey() + " : " + entry.getValue());
        }
    }
}

thanks!

Jason (best answer)

To simplify the code, let's assume no spaces in the source data:

cmatrix.put("tail,head", 1);
cmatrix.put("head,tail", 4);
cmatrix.put("tail,tail", 1);
cmatrix.put("head,head", 4);
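
If you would rather not clean up the source data by hand, one option is to normalize the keys first: split each key on commas, trim each part, and join the parts back with a bare comma. This is only a sketch; the class name `KeyNormalizer` and method `normalizeKeys` are hypothetical, and summing the counts of two raw keys that normalize to the same key is an assumption about what you would want.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;

public class KeyNormalizer {
    // Rebuilds the map so every key has the form "a,b": each comma-separated
    // part is trimmed, then the parts are re-joined with a bare comma.
    // Assumption: if two raw keys normalize to the same key, their counts are summed.
    static Map<String, Integer> normalizeKeys(Map<String, Integer> raw) {
        Map<String, Integer> clean = new HashMap<String, Integer>();
        for (Map.Entry<String, Integer> entry : raw.entrySet()) {
            String key = Arrays.stream(entry.getKey().split(","))
                    .map(String::trim)
                    .collect(Collectors.joining(","));
            clean.merge(key, entry.getValue(), Integer::sum);
        }
        return clean;
    }

    public static void main(String[] args) {
        // Keys with inconsistent spacing around the commas.
        Map<String, Integer> raw = new HashMap<String, Integer>();
        raw.put("tail, head", 1);
        raw.put("head ,tail", 4);
        raw.put("tail, tail", 1);
        raw.put("head, head", 4);
        System.out.println(normalizeKeys(raw).keySet());
    }
}
```

After this step the rest of the answer can use the cleaned map unchanged, since every lookup builds keys as `row + "," + column`.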

First, we need to gather the names of the classes:

Set<String> classNames = new HashSet<String>();
for(String key : cmatrix.keySet()) {
    // each key holds the class names, e.g. "tail,head"; split() never returns null
    classNames.addAll(Arrays.asList(key.split(",")));
}

Next, sort the class names:

List<String> sortedClassNames = new ArrayList<String>();
sortedClassNames.addAll(classNames);
Collections.sort(sortedClassNames);
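
On Java 8 or later, the gathering and sorting steps can also be collapsed into one pass by collecting into a `TreeSet`, which keeps elements unique and in natural (alphabetical) order. A sketch, with the hypothetical names `SortedClassNames` and `classNames`; it also trims each part, so spaces in the keys do not produce duplicate class names (the later lookups would still need exactly matching keys).

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.stream.Collectors;

public class SortedClassNames {
    // Splits every key on commas, trims the parts, and collects them into a
    // TreeSet, which removes duplicates and keeps them sorted.
    static SortedSet<String> classNames(Map<String, Integer> cmatrix) {
        return cmatrix.keySet().stream()
                .flatMap(key -> Arrays.stream(key.split(",")))
                .map(String::trim)
                .collect(Collectors.toCollection(TreeSet::new));
    }

    public static void main(String[] args) {
        Map<String, Integer> cmatrix = new HashMap<String, Integer>();
        cmatrix.put("tail,head", 1);
        cmatrix.put("head,tail", 4);
        System.out.println(classNames(cmatrix)); // prints [head, tail]
    }
}
```

The result can be iterated directly in the loops below, or copied into a `List` with `new ArrayList<String>(set)` if you need index access.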

Then print out the header row.

System.out.print("p/a");
for(String predictedClassName : sortedClassNames) {
    System.out.print("\t" + predictedClassName);
}
System.out.println();

Then print out each line:

for(String actualClassName : sortedClassNames) {
    System.out.print(actualClassName);
    for(String predictedClassName : sortedClassNames) {
        Integer value = cmatrix.get(actualClassName + "," + predictedClassName);
        System.out.print("\t");
        if(value != null) {
            System.out.print(value);
        }
    }
    System.out.println();
}
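
Putting the steps together, a complete runnable version might look like the sketch below. The class name `ConfusionMatrixPrinter` and the `render` method are hypothetical; the logic follows the steps above (header "p/a", tab-separated cells, empty cell for a missing pair, keys assumed to be "row,column" with no spaces), except that the table is built into a String instead of being printed piece by piece, which makes it easier to test or write to a file.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class ConfusionMatrixPrinter {
    // Builds the tab-separated table and returns it as a String.
    // Keys are assumed to be "rowClass,columnClass" with no spaces.
    static String render(Map<String, Integer> cmatrix) {
        // 1. Gather the distinct class names from every key.
        Set<String> classNames = new HashSet<String>();
        for (String key : cmatrix.keySet()) {
            classNames.addAll(Arrays.asList(key.split(",")));
        }
        // 2. Sort them so rows and columns appear in the same order.
        List<String> sorted = new ArrayList<String>(classNames);
        Collections.sort(sorted);

        // 3. Header row: one column per class.
        StringBuilder out = new StringBuilder("p/a");
        for (String col : sorted) {
            out.append('\t').append(col);
        }
        out.append('\n');

        // 4. One row per class; a pair missing from the map leaves the cell empty.
        for (String row : sorted) {
            out.append(row);
            for (String col : sorted) {
                Integer value = cmatrix.get(row + "," + col);
                out.append('\t');
                if (value != null) {
                    out.append(value);
                }
            }
            out.append('\n');
        }
        return out.toString();
    }

    public static void main(String[] args) {
        Map<String, Integer> cmatrix = new HashMap<String, Integer>();
        cmatrix.put("tail,head", 1);
        cmatrix.put("head,tail", 4);
        cmatrix.put("tail,tail", 1);
        cmatrix.put("head,head", 4);
        System.out.print(render(cmatrix));
    }
}
```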

I'll leave the 'prettying up' of the output as an exercise for the reader.
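
One possible way to do the prettying up, as a sketch: right-align every cell to a common width with `String.format` and print 0 instead of leaving a cell empty. The names `PrettyConfusionMatrix` and `render` are hypothetical, the key format is assumed to be "row,column" with no spaces, and the caller passes the class order (for example, the sorted list of class names).

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class PrettyConfusionMatrix {
    // Right-aligns every cell to the width of the widest label or count and
    // prints 0 for label pairs that never occurred.
    static String render(Map<String, Integer> cmatrix, List<String> classes) {
        String corner = "p\\a";
        int width = corner.length();
        for (String c : classes) {
            width = Math.max(width, c.length());
        }
        for (Integer v : cmatrix.values()) {
            width = Math.max(width, String.valueOf(v).length());
        }
        String cell = "%" + width + "s"; // e.g. "%4s" pads on the left to 4 chars

        StringBuilder out = new StringBuilder(String.format(cell, corner));
        for (String col : classes) {
            out.append(' ').append(String.format(cell, col));
        }
        out.append('\n');
        for (String row : classes) {
            out.append(String.format(cell, row));
            for (String col : classes) {
                int value = cmatrix.getOrDefault(row + "," + col, 0);
                out.append(' ').append(String.format(cell, value));
            }
            out.append('\n');
        }
        return out.toString();
    }

    public static void main(String[] args) {
        Map<String, Integer> cmatrix = new HashMap<String, Integer>();
        cmatrix.put("Head,Head", 1);
        cmatrix.put("Head,Tail", 4);
        cmatrix.put("Tail,Head", 4);
        cmatrix.put("Tail,Tail", 1);
        System.out.print(render(cmatrix, Arrays.asList("Head", "Tail")));
    }
}
```

Using spaces rather than tabs keeps the columns aligned even when a class name is longer than one tab stop.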

demongolem

Confusion matrices are just complicated enough that it is worth looking into an open source solution. One that can easily be integrated into code without adding tons of unnecessary extras is here. Other data science / NLP packages also include implementations in their distributions; even if using the entire library is too much, their source files can give guidance on how to do it.

The benefit of using one of these is that, besides holding the matrix itself, they also give you some metrics for free, such as Cohen's kappa and the more basic precision, recall, and F-measure scores.
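
If you only need a few of those metrics, they are short to compute yourself once the counts are in a 2D array. The sketch below uses the textbook definitions and is not the API of any particular library. It assumes rows are predicted classes and columns are actual classes (the "p\a" layout from the question), so precision divides a diagonal cell by its row sum and recall divides it by its column sum; Cohen's kappa is (observed agreement - chance agreement) / (1 - chance agreement). The class and method names are hypothetical, and filling the `int[][]` from the HashMap (for example, indexing classes by their position in the sorted class-name list) is left out.

```java
public class ConfusionMetrics {
    // Assumed convention: m[i][j] = number of items predicted as class i
    // whose actual class is j.

    static double precision(int[][] m, int i) {
        int predicted = 0; // row sum: everything predicted as class i
        for (int j = 0; j < m.length; j++) {
            predicted += m[i][j];
        }
        return predicted == 0 ? 0.0 : (double) m[i][i] / predicted;
    }

    static double recall(int[][] m, int i) {
        int actual = 0; // column sum: everything that actually is class i
        for (int[] row : m) {
            actual += row[i];
        }
        return actual == 0 ? 0.0 : (double) m[i][i] / actual;
    }

    static double f1(int[][] m, int i) {
        double p = precision(m, i);
        double r = recall(m, i);
        return p + r == 0 ? 0.0 : 2 * p * r / (p + r);
    }

    // Cohen's kappa; undefined (NaN or infinite) for an empty matrix or
    // when chance agreement equals 1.
    static double cohensKappa(int[][] m) {
        int k = m.length;
        double total = 0, diagonal = 0;
        double[] rowSums = new double[k];
        double[] colSums = new double[k];
        for (int i = 0; i < k; i++) {
            for (int j = 0; j < k; j++) {
                total += m[i][j];
                rowSums[i] += m[i][j];
                colSums[j] += m[i][j];
            }
            diagonal += m[i][i];
        }
        double observed = diagonal / total;
        double expected = 0;
        for (int i = 0; i < k; i++) {
            expected += (rowSums[i] / total) * (colSums[i] / total);
        }
        return (observed - expected) / (1 - expected);
    }

    public static void main(String[] args) {
        int[][] m = {{1, 4}, {4, 1}}; // the Head/Tail counts from the question
        System.out.println(precision(m, 0)); // prints 0.2
        System.out.println(recall(m, 0));
        System.out.println(f1(m, 0));
        System.out.println(cohensKappa(m));
    }
}
```

For the question's Head/Tail matrix, observed agreement is 2/10 and chance agreement is 0.5, so kappa comes out at -0.6: the predictions agree with the actual labels less often than chance would.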