I'm trying to implement the ID3 algorithm to generate a decision tree, but I get a StackOverflowError when I run my code. While debugging I noticed that the looping begins once the attribute list gets down to 4 (from the initial 9). The tree-generating code is below. All the functions I'm calling have been tested and work properly. The stack trace (posted at the end) points into entropy, another function that uses streams, but I've tested entropy and informationGain separately and I know they work. Keep in mind that I'm working with random data, so the error is thrown on some runs and not on others.
This is the TreeNode structure:
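import java.util.List;

public class TreeNode {
    List<Patient> samples;
    List<TreeNode> children;
    TreeNode parent;
    Integer attribute;
    String attributeValue;
    String className;

    public TreeNode(List<Patient> samples, List<TreeNode> children, TreeNode parent, Integer attribute,
            String attributeValue, String className) {
        this.samples = samples;
        this.children = children;
        this.parent = parent;
        this.attribute = attribute;
        this.attributeValue = attributeValue;
        this.className = className;
    }

    // The three members below are called by id3 but were not shown; these bodies are assumed.
    // Shallow copy: the copy shares samples and children with the original.
    public TreeNode(TreeNode other) {
        this(other.samples, other.children, other.parent, other.attribute, other.attributeValue, other.className);
    }

    public void setClassName(String className) {
        this.className = className;
    }

    public void addChild(TreeNode child) {
        children.add(child);
    }
}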
And this is the code that throws the error:
public TreeNode id3(List<Patient> patients, List<Integer> attributes, TreeNode root) {
    // Base case 1: every patient in this subset has the same class
    boolean isLeaf = patients.stream().collect(Collectors.groupingBy(i -> i.className)).keySet().size() == 1;
    if (isLeaf) {
        root.setClassName(patients.get(0).className);
        return root;
    }
    // Base case 2: no attributes left to split on
    if (attributes.size() == 0) {
        root.setClassName(mostCommonClass(patients));
        return root;
    }
    int bestAttribute = maxInformationGainAttribute(patients, attributes);
    Set<String> attributeValues = attributeValues(patients, bestAttribute);
    for (String value : attributeValues) {
        List<Patient> branch = patients.stream().filter(i -> i.patientData[bestAttribute].equals(value))
                .collect(Collectors.toList());
        TreeNode child = new TreeNode(branch, new ArrayList<>(), root, bestAttribute, value, null);
        if (branch.isEmpty()) {
            child.setClassName(mostCommonClass(patients));
            root.addChild(new TreeNode(child));
        } else {
            List<Integer> newAttributes = new ArrayList<>(attributes);
            newAttributes.remove(Integer.valueOf(bestAttribute)); // remove by value, not by index
            root.addChild(new TreeNode(id3(branch, newAttributes, child)));
        }
    }
    return root;
}
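The question doesn't show mostCommonClass, attributeValues or the Patient class. For anyone trying to reproduce the problem, here is a minimal sketch of what the two helpers presumably do. The Patient class is a hypothetical stand-in with only the two fields id3 reads, Id3Helpers is a hypothetical holder class, and both helper bodies are assumptions rather than the asker's code.

```java
import java.util.*;
import java.util.stream.Collectors;

// Hypothetical stand-in for the question's Patient class:
// only the two fields that id3 reads are modelled.
class Patient {
    final String[] patientData;
    final String className;

    Patient(String[] patientData, String className) {
        this.patientData = patientData;
        this.className = className;
    }
}

// Hypothetical holder for the helpers; the bodies are assumed, not the asker's code.
class Id3Helpers {
    // Assumed behaviour: the class label that occurs most often (ties broken arbitrarily).
    static String mostCommonClass(List<Patient> patients) {
        return patients.stream()
                .collect(Collectors.groupingBy(p -> p.className, Collectors.counting()))
                .entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElseThrow(() -> new IllegalArgumentException("no patients"));
    }

    // Assumed behaviour: the distinct values the attribute takes in this subset.
    static Set<String> attributeValues(List<Patient> patients, int attribute) {
        return patients.stream()
                .map(p -> p.patientData[attribute])
                .collect(Collectors.toSet());
    }
}
```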
These are the other functions:
public static double entropy(List<Patient> patients) {
    double entropy = 0.0;
    double recurP = (double) patients.stream().filter(i -> i.className.equals("recurrence-events")).count()
            / (double) patients.size();
    double noRecurP = (double) patients.stream().filter(i -> i.className.equals("no-recurrence-events")).count()
            / (double) patients.size();
    // Each nonzero probability contributes p * log2(p); a zero probability contributes nothing
    entropy -= (recurP > 0 ? recurP * Math.log(recurP) / Math.log(2) : 0)
            + (noRecurP > 0 ? noRecurP * Math.log(noRecurP) / Math.log(2) : 0);
    return entropy;
}
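For reference, this is the quantity entropy is meant to compute for the two class labels, in bits, where S_r and S_n are the recurrence-events and no-recurrence-events patients in S. The division by ln 2 applies to each nonzero term, and a class with zero patients contributes nothing; for example, two patients of each class give H(S) = 1.

$$
H(S) = -\,p_r \log_2 p_r - p_n \log_2 p_n = -\frac{p_r \ln p_r + p_n \ln p_n}{\ln 2},
\qquad p_r = \frac{|S_r|}{|S|},\quad p_n = \frac{|S_n|}{|S|},\quad 0 \log_2 0 := 0
$$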
public static double informationGain(List<Patient> patients, int attribute) {
    double informationGain = entropy(patients);
    // Partition the patients by their value for this attribute
    Map<String, List<Patient>> patientsGroupedByAttribute = patients.stream()
            .collect(Collectors.groupingBy(i -> i.patientData[attribute]));
    // Subtract the weighted entropy of each subset
    for (List<Patient> subset : patientsGroupedByAttribute.values()) {
        informationGain -= proportion(subset, patients) * entropy(subset);
    }
    return informationGain;
}
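informationGain computes the usual ID3 gain below, assuming proportion(lp, patients) returns lp.size() / (double) patients.size() (that helper isn't shown). The gain of A is zero whenever each subset S_v has the same class proportions as S. In particular, an attribute that takes a single value on S has S_v = S and therefore zero gain, so a node where every remaining attribute has zero gain is quite possible with random data.

$$
\mathrm{Gain}(S, A) = H(S) - \sum_{v \in \mathrm{values}(A)} \frac{|S_v|}{|S|}\, H(S_v),
\qquad S_v = \{\, s \in S : s_A = v \,\}
$$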
private static int maxInformationGainAttribute(List<Patient> patients, List<Integer> attributes) {
    int maxAttribute = 0;
    double maxInformationGain = 0;
    for (int i : attributes) {
        double gain = informationGain(patients, i); // compute once per attribute
        if (gain > maxInformationGain) {
            maxAttribute = i;
            maxInformationGain = gain;
        }
    }
    return maxAttribute;
}
The exception:
Exception in thread "main" java.lang.StackOverflowError
at java.util.stream.ReferencePipeline$2$1.accept(Unknown Source)
at java.util.ArrayList$ArrayListSpliterator.forEachRemaining(Unknown Source)
at java.util.stream.AbstractPipeline.copyInto(Unknown Source)
at java.util.stream.AbstractPipeline.wrapAndCopyInto(Unknown Source)
at java.util.stream.ReduceOps$ReduceOp.evaluateSequential(Unknown Source)
at java.util.stream.AbstractPipeline.evaluate(Unknown Source)
at java.util.stream.LongPipeline.reduce(Unknown Source)
at java.util.stream.LongPipeline.sum(Unknown Source)
at java.util.stream.ReferencePipeline.count(Unknown Source)
at Patient.entropy(Patient.java:39)
at Patient.informationGain(Patient.java:67)
at Patient.maxInformationGainAttribute(Patient.java:85)
at Patient.id3(Patient.java:109)
The line
root.addChild(new TreeNode(id3(branch, newAttributes, child)));
is called on every recursive step, and since each call adds frames to the stack, it eventually overflows. The stream and entropy frames at the top of the trace are only where the stack finally ran out; entropy isn't recursive, so the recursion filling the stack is in id3. That tells me something is wrong in the logic and none of the base cases that end the recursion (the early return root statements) are being reached. I don't know enough about the desired behavior or the starting data to pinpoint what is going wrong, but I would start by stepping through the code with a debugger and checking that the method behaves the way you expect at each level. I know that isn't a great answer, but it's a starting point; hopefully it helps, or someone else will chime in with a more specific solution.
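One concrete thing worth checking while stepping through: maxInformationGainAttribute starts with maxAttribute = 0 and maxInformationGain = 0 and only updates on a strictly greater gain, so when no attribute in attributes has a positive gain it returns 0, whether or not 0 is still in the list. If attribute 0 was already used higher up in the tree, every patient in the current subset has the same value for it, so the loop produces a single branch containing all of patients, newAttributes.remove(...) removes nothing, and id3 is called again with the same patients and the same attributes. Nothing changes between calls, so no base case is ever reached. That fits the symptoms: it only happens with some random data sets, and only after several attributes have been removed. Below is a minimal sketch of a guarded selection, not the asker's code: the Patient class is a hypothetical stand-in with only the fields used here, GuardedSelection and bestAttribute are hypothetical names, and entropy is computed in bits over whatever class labels occur.

```java
import java.util.*;
import java.util.stream.Collectors;

// Hypothetical stand-in for the question's Patient: only the fields used here.
class Patient {
    final String[] patientData;
    final String className;

    Patient(String[] patientData, String className) {
        this.patientData = patientData;
        this.className = className;
    }
}

class GuardedSelection {
    // Entropy in bits over the class labels present in this subset.
    static double entropy(List<Patient> patients) {
        Map<String, Long> counts = patients.stream()
                .collect(Collectors.groupingBy(p -> p.className, Collectors.counting()));
        double h = 0.0;
        for (long c : counts.values()) {
            double p = (double) c / patients.size(); // every count here is > 0, so p > 0
            h -= p * Math.log(p) / Math.log(2);
        }
        return h;
    }

    // Gain(S, A) = H(S) - sum over values v of |S_v| / |S| * H(S_v)
    static double informationGain(List<Patient> patients, int attribute) {
        double gain = entropy(patients);
        Map<String, List<Patient>> groups = patients.stream()
                .collect(Collectors.groupingBy(p -> p.patientData[attribute]));
        for (List<Patient> subset : groups.values()) {
            gain -= (double) subset.size() / patients.size() * entropy(subset);
        }
        return gain;
    }

    // Picks only from `attributes`, and returns empty when no attribute has a
    // positive gain, instead of falling back to attribute 0.
    static OptionalInt bestAttribute(List<Patient> patients, List<Integer> attributes) {
        OptionalInt best = OptionalInt.empty();
        double bestGain = 0.0;
        for (int a : attributes) {
            double g = informationGain(patients, a);
            if (g > bestGain) {
                best = OptionalInt.of(a);
                bestGain = g;
            }
        }
        return best;
    }
}
```

In id3 this would stand in for maxInformationGainAttribute: if the result is empty, set the node's class to mostCommonClass(patients) and return it, just like the attributes.size() == 0 case; otherwise split on best.getAsInt(). Because the chosen attribute is always one that is still in attributes, each recursive call gets one attribute fewer, so the recursion depth is bounded by the number of attributes.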