/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.classification.tree;

import com.google.common.base.Preconditions;
import de.jungblut.classification.AbstractClassifier;
import de.jungblut.classification.Classifier;
import de.jungblut.classification.ClassifierFactory;
import de.jungblut.classification.meta.Voter;
import de.jungblut.classification.tree.DecisionTree;
import de.jungblut.classification.tree.FeatureType;
import de.jungblut.math.DoubleVector;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;

/**
 * A random forest classifier: an ensemble of {@link DecisionTree}s trained through a
 * {@link Voter} using {@link Voter.SelectionType#BAGGING} selection, where every tree is
 * configured to consider only a random subset of the features.
 * <p>
 * Instances are obtained via {@link #create(int)} or {@link #create(int, FeatureType[])} and
 * configured through the fluent setters before calling {@link #train}.
 * <p>
 * Thread-safety note: {@link #predict} and {@link #predictProbability} switch the combining
 * type of the shared voter before delegating to it, so interleaving the two from multiple
 * threads is not safe.
 */
public final class RandomForest extends AbstractClassifier {

    /** Number of trees in the forest; {@link #train} requires at least two. */
    private final int numTrees;
    /** Per-feature types; {@code null} means "treat every feature as NOMINAL" at training time. */
    private FeatureType[] featureTypes;
    /** Thread count handed to the underlying voter. */
    private int numThreads = 1;
    /** Random features per tree; {@code <= 0} means "(int) sqrt(dimension)" at training time. */
    private int numRandomFeaturesToChoose = 0;
    /** Maximum height handed to each tree. */
    private int maxHeight = Integer.MAX_VALUE;
    private boolean verbose;
    /** When set, trees are built with {@link DecisionTree#createCompiledTree}. */
    private boolean compile = false;
    /** The ensemble; {@code null} until {@link #train} or {@link #deserialize} provides it. */
    private Voter<DecisionTree> trees;

    private RandomForest(int numTrees) {
        this.numTrees = numTrees;
    }

    private RandomForest(int numTrees, Voter<DecisionTree> voter) {
        this(numTrees);
        this.trees = voter;
    }

    /**
     * Trains {@code numTrees} decision trees on the given examples, replacing any previously
     * trained ensemble.
     * <p>
     * Defaults (all-NOMINAL feature types, {@code (int) sqrt(dimension)} random features) are
     * resolved for this call only and are not stored. The same instance can therefore be
     * re-trained on data of a different dimension.
     *
     * @param features the training examples; the dimension is taken from the first example
     *                 (NOTE(review): presumably all examples share it — not checked here).
     * @param outcome  the outcome per example, aligned by index with {@code features}.
     * @throws IllegalArgumentException if the array lengths differ, there are fewer than two
     *         trees, there are no examples, the configured feature types do not match the
     *         dimension, or more random features are requested than exist.
     */
    @Override
    public void train(DoubleVector[] features, DoubleVector[] outcome) {
        Preconditions.checkArgument(features.length == outcome.length,
                "Number of examples and outcomes must match!");
        Preconditions.checkArgument(this.numTrees > 1,
                "There must be at least two trees to make up a forest!");
        // Guard the features[0] access below against an empty training set.
        Preconditions.checkArgument(features.length > 0,
                "There must be at least one training example!");

        final int numFeatures = features[0].getDimension();
        final FeatureType[] types =
                this.featureTypes != null ? this.featureTypes : allNominal(numFeatures);
        final int numRandomFeatures = this.numRandomFeaturesToChoose > 0
                ? this.numRandomFeaturesToChoose
                : (int) Math.sqrt(numFeatures);

        Preconditions.checkArgument(types.length == numFeatures,
                "FeatureType length must match the dimension of the features! Given: "
                        + numFeatures + ", but expected: " + types.length);
        // "<=" matches the message; "<" rejected the sqrt default on 1-dimensional data.
        Preconditions.checkArgument(numRandomFeatures <= numFeatures,
                "Number of random features to choose must be lower or equal than the number of features!");

        DecisionTreeFactory factory =
                new DecisionTreeFactory(types, numRandomFeatures, this.maxHeight, this.compile);
        this.trees = Voter.create(this.numTrees, Voter.CombiningType.MAJORITY, factory)
                .selectionType(Voter.SelectionType.BAGGING)
                .numThreads(this.numThreads)
                .verbose(this.verbose);
        this.trees.train(features, outcome);
    }

    /**
     * Predicts using the voter's {@link Voter.CombiningType#MAJORITY} combining type.
     *
     * @throws IllegalStateException if the forest was neither trained nor deserialized.
     */
    @Override
    public DoubleVector predict(DoubleVector features) {
        Voter<DecisionTree> voter = trainedTrees();
        voter.setCombiningType(Voter.CombiningType.MAJORITY);
        return voter.predict(features);
    }

    /**
     * Predicts using the voter's {@link Voter.CombiningType#PROBABILITY} combining type.
     *
     * @throws IllegalStateException if the forest was neither trained nor deserialized.
     */
    @Override
    public DoubleVector predictProbability(DoubleVector features) {
        Voter<DecisionTree> voter = trainedTrees();
        voter.setCombiningType(Voter.CombiningType.PROBABILITY);
        return voter.predict(features);
    }

    /** Makes {@link #train} build its trees via {@link DecisionTree#createCompiledTree}. */
    public RandomForest compile() {
        this.compile = true;
        return this;
    }

    /** Enables verbose output of the underlying voter; same as {@code verbose(true)}. */
    public RandomForest verbose() {
        return this.verbose(true);
    }

    /** Sets the flag that is passed to {@link Voter#verbose} at training time. */
    public RandomForest verbose(boolean verb) {
        this.verbose = verb;
        return this;
    }

    /** Sets the maximum height passed to every tree (default {@link Integer#MAX_VALUE}). */
    public RandomForest setMaxHeight(int max) {
        this.maxHeight = max;
        return this;
    }

    /** Sets the thread count passed to the underlying voter (default 1). */
    public RandomForest numThreads(int numThreads) {
        this.numThreads = numThreads;
        return this;
    }

    /**
     * Sets how many random features each tree may choose from. Values {@code <= 0} select
     * {@code (int) Math.sqrt(dimension)} at training time.
     */
    public RandomForest setNumRandomFeaturesToChoose(int numRandomFeaturesToChoose) {
        this.numRandomFeaturesToChoose = numRandomFeaturesToChoose;
        return this;
    }

    /**
     * Sets the per-feature types. {@code null} means all features are treated as
     * {@link FeatureType#NOMINAL}.
     */
    public RandomForest setFeatureTypes(FeatureType[] types) {
        this.featureTypes = types;
        return this;
    }

    /** Creates an untrained forest of {@code numTrees} trees (must be > 1 to train). */
    public static RandomForest create(int numTrees) {
        return new RandomForest(numTrees);
    }

    /** Creates an untrained forest of {@code numTrees} trees with the given feature types. */
    public static RandomForest create(int numTrees, FeatureType[] types) {
        return new RandomForest(numTrees).setFeatureTypes(types);
    }

    /**
     * Writes the tree count followed by every tree via {@link DecisionTree#serialize}.
     * Only the trees are written; builder settings such as max height are not.
     *
     * @throws IllegalStateException if the forest has no trees yet (checked before writing).
     */
    public static void serialize(RandomForest tree, DataOutput out) throws IOException {
        Voter<DecisionTree> voter = tree.trainedTrees();
        out.writeInt(tree.numTrees);
        for (Classifier c : voter.getClassifier()) {
            DecisionTree.serialize((DecisionTree) c, out);
        }
    }

    /** Reads a forest in the format written by {@link #serialize}. */
    public static RandomForest deserialize(DataInput in) throws IOException {
        int numTrees = in.readInt();
        // The factory is never used for training here: the trees are read from the stream.
        Voter<DecisionTree> voter = Voter.create(numTrees, Voter.CombiningType.MAJORITY,
                new ClassifierFactory<DecisionTree>() {
                    @Override
                    public DecisionTree newInstance() {
                        return null;
                    }
                });
        // NOTE(review): relies on getClassifier() exposing the voter's backing array — verify in Voter.
        for (int i = 0; i < numTrees; ++i) {
            voter.getClassifier()[i] = DecisionTree.deserialize(in);
        }
        return new RandomForest(numTrees, voter);
    }

    /** Returns the ensemble, failing fast if the forest has not been trained or loaded. */
    private Voter<DecisionTree> trainedTrees() {
        Preconditions.checkState(this.trees != null,
                "The forest must be trained or deserialized before it can be used!");
        return this.trees;
    }

    /** Returns a fresh array of {@code dimension} entries, all {@link FeatureType#NOMINAL}. */
    private static FeatureType[] allNominal(int dimension) {
        FeatureType[] types = new FeatureType[dimension];
        Arrays.fill(types, FeatureType.NOMINAL);
        return types;
    }

    /**
     * Creates identically configured, untrained trees for the voter. It is a static nested
     * class that receives its configuration explicitly, so it holds no reference to the forest.
     */
    private static final class DecisionTreeFactory implements ClassifierFactory<DecisionTree> {
        private final FeatureType[] featureTypes;
        private final int numRandomFeaturesToChoose;
        private final int maxHeight;
        private final boolean compile;

        private DecisionTreeFactory(FeatureType[] featureTypes, int numRandomFeaturesToChoose,
                int maxHeight, boolean compile) {
            this.featureTypes = featureTypes;
            this.numRandomFeaturesToChoose = numRandomFeaturesToChoose;
            this.maxHeight = maxHeight;
            this.compile = compile;
        }

        @Override
        public DecisionTree newInstance() {
            if (this.compile) {
                return DecisionTree.createCompiledTree(this.featureTypes)
                        .setNumRandomFeaturesToChoose(this.numRandomFeaturesToChoose)
                        .setMaxHeight(this.maxHeight);
            }
            return DecisionTree.create(this.featureTypes)
                    .setNumRandomFeaturesToChoose(this.numRandomFeaturesToChoose)
                    .setMaxHeight(this.maxHeight);
        }
    }
}

