package org.apache.mahout.df.builder;

import java.util.Arrays;
import java.util.Random;
import org.apache.commons.lang.ArrayUtils;
import org.apache.mahout.df.data.Data;
import org.apache.mahout.df.data.Dataset;
import org.apache.mahout.df.data.conditions.Condition;
import org.apache.mahout.df.node.CategoricalNode;
import org.apache.mahout.df.node.Leaf;
import org.apache.mahout.df.node.Node;
import org.apache.mahout.df.node.NumericalNode;
import org.apache.mahout.df.split.IgSplit;
import org.apache.mahout.df.split.OptIgSplit;
import org.apache.mahout.df.split.Split;

/* loaded from: input_file:WEB-INF/lib/mahout-core-0.2.jar:org/apache/mahout/df/builder/DefaultTreeBuilder.class */
/**
 * Recursive decision-tree builder: at each node it draws {@code m} distinct
 * attributes at random, keeps the candidate split with the highest {@code ig}
 * value reported by the configured {@link IgSplit}, and recurses on the
 * resulting subsets of the data.
 */
public class DefaultTreeBuilder implements TreeBuilder {
    /** Number of attributes drawn at random as split candidates at each node. */
    private int m = 1;

    /** Strategy used to compute the candidate split for one attribute. */
    private IgSplit igSplit = new OptIgSplit();

    /**
     * Sets the number of attributes drawn at random at each node.
     *
     * @param i number of candidate attributes; {@link #build(Random, Data)}
     *          rejects values that are not in {@code [1, nbAttributes]}
     */
    public void setM(int i) {
        this.m = i;
    }

    /**
     * Replaces the strategy used to compute candidate splits.
     *
     * @param igSplit split strategy to use for subsequent builds
     */
    public void setIgSplit(IgSplit igSplit) {
        this.igSplit = igSplit;
    }

    /**
     * Builds the (sub)tree for the given data.
     *
     * @param random source of randomness for attribute selection (and passed to
     *               {@code Data.majorityLabel})
     * @param data   instances to build the tree from
     * @return the root node of the built tree
     * @throws IllegalArgumentException if {@code m} is negative or greater than
     *                                  the number of attributes of the dataset
     * @throws IllegalStateException    if {@code m} is 0, so that no attribute
     *                                  can be selected for splitting
     */
    @Override // org.apache.mahout.df.builder.TreeBuilder
    public Node build(Random random, Data data) {
        // Terminal cases first; their order matters because each one assumes
        // the previous ones did not apply.
        if (data.isEmpty()) {
            // No instance reached this node: -1 is used as the label.
            return new Leaf(-1);
        }
        if (data.isIdentical()) {
            // NOTE(review): presumably all instances share the same attribute
            // values, so no split can separate them - confirm against Data.
            return new Leaf(data.majorityLabel(random));
        }
        if (data.identicalLabel()) {
            // NOTE(review): presumably all instances carry one label, hence the
            // first instance's label is representative - confirm against Data.
            return new Leaf(data.get(0).label);
        }

        Split best = selectBestSplit(random, data);
        if (data.getDataset().isNumerical(best.attr)) {
            return buildNumericalNode(random, data, best);
        }
        return buildCategoricalNode(random, data, best);
    }

    /** Evaluates m randomly chosen attributes and returns the split with the highest ig. */
    private Split selectBestSplit(Random random, Data data) {
        Split best = null;
        for (int attr : randomAttributes(data.getDataset(), random, this.m)) {
            Split candidate = this.igSplit.computeSplit(data, attr);
            // Strict comparison: on ties the first evaluated attribute wins.
            if (best == null || best.ig < candidate.ig) {
                best = candidate;
            }
        }
        if (best == null) {
            // Only reachable when m == 0; fail with a clear message instead of
            // dereferencing a null split further down.
            throw new IllegalStateException(
                    "no attribute selected for splitting, m must be >= 1 but was " + this.m);
        }
        return best;
    }

    /** Builds a binary node: values below the split go to the first child, the others to the second. */
    private Node buildNumericalNode(Random random, Data data, Split best) {
        // The "lesser" child is built before the "greater or equals" child so
        // that the random generator is consumed in a fixed order.
        Node lowChild = build(random, data.subset(Condition.lesser(best.attr, best.split)));
        Node highChild = build(random, data.subset(Condition.greaterOrEquals(best.attr, best.split)));
        return new NumericalNode(best.attr, best.split, lowChild, highChild);
    }

    /** Builds a multi-way node with one child per value returned by {@code data.values(attr)}. */
    private Node buildCategoricalNode(Random random, Data data, Split best) {
        double[] values = data.values(best.attr);
        Node[] children = new Node[values.length];
        for (int index = 0; index < values.length; index++) {
            children[index] = build(random, data.subset(Condition.equals(best.attr, values[index])));
        }
        return new CategoricalNode(best.attr, values, children);
    }

    /**
     * Draws {@code i} distinct attribute indices uniformly at random.
     *
     * @param dataset dataset providing the number of attributes
     * @param random  source of randomness
     * @param i       number of attributes to draw
     * @return an array of {@code i} distinct indices in {@code [0, nbAttributes)}
     * @throws IllegalArgumentException if {@code i} is negative or greater than
     *                                  the number of attributes
     */
    protected static int[] randomAttributes(Dataset dataset, Random random, int i) {
        if (i < 0) {
            throw new IllegalArgumentException("m < 0");
        }
        if (i > dataset.nbAttributes()) {
            throw new IllegalArgumentException("m > num attributes");
        }
        int[] selected = new int[i];
        // -1 is never a valid index, so unfilled slots cannot be mistaken for
        // attribute 0 by the duplicate check below.
        Arrays.fill(selected, -1);
        for (int slot = 0; slot < i; slot++) {
            int candidate;
            // Redraw until the index has not been picked yet; terminates
            // because i <= nbAttributes.
            do {
                candidate = random.nextInt(dataset.nbAttributes());
            } while (ArrayUtils.contains(selected, candidate));
            selected[slot] = candidate;
        }
        return selected;
    }
}
