Example Java solution for assignment 9 (Decision Trees)

Disclaimer: This thread was imported from the old forum, so some formatting may not display correctly. The original thread begins in the second post of this thread.

Example Java solution for assignment 9 (Decision Trees)
Hello everyone,

as promised in the tutorial today: Here is my Java solution for assignment 9.
Please let me know if there are any errors (or if you have suggestions for improvements).

The solution is supposed to be short and readable - and not efficient at all :wink:

I used Java Streams again, which might be a bit more challenging to read if you’re not used to it. Feel free to ask if you have any questions!

package info.kwarc.teaching.AI.DecisionTrees;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;


public class MyDecisionTable extends DecisionTable {
    /* Intentionally simple rather than efficient: the row data is re-scanned
     * for every probability/entropy query instead of being counted once. */

    /** Cached log(2) so entropies can be converted to base 2 with one division. */
    private static final double LOG2 = Math.log(2);

    MyDecisionTable(List<TableRow> rows) {
        super(rows);
    }

    /**
     * Streams the rows that satisfy every assumption in {@code assumes},
     * i.e. rows whose value for each assumed attribute equals the assumed value.
     */
    private Stream<TableRow> getMatchingRows(List<AttVal> assumes) {
        return rows.stream().filter(
                row -> assumes.stream().allMatch(attVal -> row.get(attVal.attr).equals(attVal.value))
        );
    }

    /**
     * Conditional probability that {@code target} has {@code value} given the
     * assumptions, estimated as a relative frequency over the matching rows.
     *
     * @throws IllegalStateException if no row satisfies the assumptions
     *         (the probability would be undefined)
     */
    private double getProbability(List<AttVal> assumes, Attribute target, String value) {
        long remainingRows = getMatchingRows(assumes).count();
        if (remainingRows == 0) throw new IllegalStateException("No rows left");
        long countedRows = getMatchingRows(assumes).filter(row -> row.get(target).equals(value)).count();
        return countedRows / (double) remainingRows;
    }

    /** Shannon entropy (base 2) of {@code target} over the rows matching {@code assumes}. */
    private double entropy(List<AttVal> assumes, Attribute target) {
        double entropy = 0.0;
        for (String value : target.range) {
            double p = getProbability(assumes, target, value);
            if (p < 1e-10) continue;        // p == 0 contributes nothing (and log(0) is undefined)
            entropy += -p * Math.log(p) / LOG2;
        }
        return entropy;
    }

    /**
     * Information gain of splitting on {@code attr}: the entropy of {@code target}
     * under the current assumptions minus the expected entropy after the split.
     */
    @Override
    protected double entropyGain(Attribute attr, List<AttVal> assumes, Attribute target) {
        double entropyGain = entropy(assumes, target);
        for (String val : attr.range) {
            double p = getProbability(assumes, attr, val);
            if (p < 1e-10) continue;        // this value never occurs under the assumptions
            // Temporarily extend the assumptions; restored immediately afterwards
            // so the caller's list is unchanged.
            assumes.add(new AttVal(attr, val));
            entropyGain -= p * entropy(assumes, target);
            assumes.remove(assumes.size() - 1);
        }
        return entropyGain;
    }

    /**
     * Returns the most common value of {@code target} among the rows matching
     * {@code assumes}. All values declared in {@code target.range} start at
     * count 0, so the result is well-defined even when some value never occurs.
     * Using {@code merge} (instead of {@code replace} + unboxed {@code get})
     * also counts values missing from the declared range rather than throwing
     * a NullPointerException on them.
     */
    private String getMostCommonValue(List<AttVal> assumes, Attribute target) {
        Map<String, Integer> counts = target.range.stream().collect(Collectors.toMap(val -> val, val -> 0));
        getMatchingRows(assumes).forEach(row -> counts.merge(row.get(target), 1, Integer::sum));
        return counts.entrySet().stream().max(Map.Entry.comparingByValue())
                .orElseThrow(() -> new IllegalStateException("target attribute has an empty range"))
                .getKey();
    }

    /**
     * Recursively builds a decision tree for {@code target} (standard ID3):
     * stop on an empty example set (use {@code defaultVal}), on a pure example
     * set, or when no attribute is left; otherwise split on the attribute with
     * the maximal information gain.
     */
    private Tree findTree(List<AttVal> assumes, Attribute target, String defaultVal) {
        // findAny short-circuits on the first match, unlike count() == 0
        // which would always scan every row.
        if (!getMatchingRows(assumes).findAny().isPresent()) {
            // no examples (= matching rows) left
            return new Leaf(defaultVal);
        }

        // Safe: the guard above guarantees at least one matching row exists.
        String firstVal = getMatchingRows(assumes).findFirst().get().get(target);
        if (getMatchingRows(assumes).allMatch(row -> row.get(target).equals(firstVal))) {
            // all remaining rows have the same value (firstVal) for the target
            return new Leaf(firstVal);
        }

        // map of choices (attributes not yet used in `assumes`) to their information gain
        Map<Attribute, Double> choices = attributes.stream()
                .filter(attr -> attr != target && assumes.stream().noneMatch(attVal -> attVal.attr == attr))
                .collect(Collectors.toMap(attr -> attr, attr -> entropyGain(attr, assumes, target)));

        if (choices.isEmpty()) {
            // every available attribute has already been used; fall back to majority vote
            return new Leaf(getMostCommonValue(assumes, target));
        }

        // create subtree using the attribute with the maximal gain
        // (safe: choices was just checked to be non-empty)
        Attribute attribute = choices.entrySet().stream().max(Map.Entry.comparingByValue())
                .orElseThrow(() -> new IllegalStateException("no choices left"))
                .getKey();
        defaultVal = getMostCommonValue(assumes, target);
        List<Evaluation> accepts = new ArrayList<>();
        for (String value : attribute.range) {
            // Same temporary-extension pattern as in entropyGain.
            assumes.add(new AttVal(attribute, value));
            accepts.add(new Evaluation(value, findTree(assumes, target, defaultVal)));
            assumes.remove(assumes.size() - 1);
        }

        return new SubTree(attribute, accepts);
    }

    /**
     * Builds the decision tree for {@code target} from all rows; the overall
     * most common target value serves as the initial default for empty branches.
     */
    @Override
    public Tree computeTree(Attribute target) {
        return findTree(new ArrayList<>(), target, getMostCommonValue(new ArrayList<>(), target));
    }
}

Attachment:
MyDecisionTable.java: https://fsi.cs.fau.de/unb-attachments/post_156794/MyDecisionTable.java

2 „Gefällt mir“