/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.classification.tree;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import de.jungblut.classification.AbstractClassifier;
import de.jungblut.classification.tree.AbstractTreeNode;
import de.jungblut.classification.tree.FeatureType;
import de.jungblut.classification.tree.LeafNode;
import de.jungblut.classification.tree.NominalNode;
import de.jungblut.classification.tree.NumericalNode;
import de.jungblut.classification.tree.Split;
import de.jungblut.classification.tree.TreeCompiler;
import de.jungblut.datastructure.ArrayUtils;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.sparse.SparseDoubleVector;
import de.jungblut.math.tuple.Tuple;
import gnu.trove.TIntCollection;
import gnu.trove.iterator.TIntObjectIterator;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.set.hash.TDoubleHashSet;
import gnu.trove.set.hash.TIntHashSet;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.apache.commons.math3.util.FastMath;
import org.apache.hadoop.io.WritableUtils;

/**
 * Decision tree classifier that recursively splits the training data on the
 * feature with the highest information gain.
 *
 * <p>Nominal features produce one child per distinct (integer-truncated) value
 * seen in the data; numerical features produce a binary "lower or equal" /
 * "higher" split. Outcomes with a single dimension are treated as binary
 * labels stored at index 0, all other outcomes are read through
 * {@code maxIndex()} of the outcome vector.
 *
 * <p>Instances are obtained through the static {@code create*} factories or
 * {@link #deserialize(DataInput)}.
 */
public final class DecisionTree
extends AbstractClassifier {

    /** Natural logarithm of two, used to turn natural logarithms into base-2 ones. */
    private static final double LOG2 = FastMath.log(2.0);
    /** Outcome dimension above which multi-class predictions use a sparse vector. */
    private static final int SPARSE_PREDICTION_THRESHOLD = 10;

    private AbstractTreeNode rootNode;
    private FeatureType[] featureTypes;
    /** Size of the random feature subset per node; values {@code <= 0} disable the sampling. */
    private int numRandomFeaturesToChoose;
    private int maxHeight = 25;
    private long seed = System.currentTimeMillis();
    private boolean binaryClassification = true;
    /** Whether {@link #train} compiles the tree after building it. */
    private boolean compile = false;
    /** Name and bytes of the compiled tree; both {@code null} while no compiled form exists. */
    private String compiledName = null;
    private byte[] compiledClass = null;
    private int outcomeDimension;
    private int numFeatures;

    private DecisionTree() {
    }

    /**
     * Constructor used by {@link #deserialize(DataInput)} to restore a trained
     * tree. Note that it switches the compile flag on.
     */
    private DecisionTree(AbstractTreeNode rootNode, FeatureType[] featureTypes, boolean binaryClassification, int numFeatures, int outcomeDimension) {
        this.binaryClassification = binaryClassification;
        this.rootNode = rootNode;
        this.featureTypes = featureTypes;
        this.numFeatures = numFeatures;
        this.outcomeDimension = outcomeDimension;
        this.compile = true;
    }

    /**
     * Builds the tree from the given examples. If no feature types were
     * configured, every feature is treated as nominal.
     *
     * @param features the training examples, all of the same dimension.
     * @param outcome the outcome for each example, parallel to {@code features}.
     * @throws IllegalArgumentException if no example is given, if the number of
     *         examples and outcomes differ, or if the configured feature types
     *         do not match the feature dimension.
     * @throws RuntimeException wrapping any failure of the tree compilation.
     */
    @Override
    public void train(DoubleVector[] features, DoubleVector[] outcome) {
        Preconditions.checkArgument(features.length == outcome.length, "Number of examples and outcomes must match!");
        Preconditions.checkArgument(features.length > 0, "At least one training example is required!");
        if (this.featureTypes == null) {
            this.featureTypes = new FeatureType[features[0].getDimension()];
            Arrays.fill(this.featureTypes, FeatureType.NOMINAL);
        }
        Preconditions.checkArgument(this.featureTypes.length == features[0].getDimension(), "FeatureType length must match the dimension of the features!");
        this.binaryClassification = outcome[0].getDimension() == 1;
        this.outcomeDimension = this.binaryClassification ? 2 : outcome[0].getDimension();
        this.numFeatures = features[0].getDimension();
        // Drop the compiled form of any previous training run: otherwise
        // compileTree() would be a no-op and serialize() would write the old model.
        this.compiledName = null;
        this.compiledClass = null;
        TIntHashSet possibleFeatureIndices = this.getPossibleFeatures();
        this.rootNode = this.build(Lists.newArrayList(features), Lists.newArrayList(outcome), possibleFeatureIndices, 0);
        if (this.compile) {
            try {
                this.compileTree();
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Predicts the class of the given feature vector. A negative class index
     * returned by the tree is mapped to class 0.
     *
     * @return for binary classification a one-dimensional vector holding the
     *         class index; otherwise a vector of the outcome dimension with 1.0
     *         at the predicted class index.
     */
    @Override
    public DoubleVector predict(DoubleVector features) {
        int clz = Math.max(0, this.rootNode.predict(features));
        if (this.binaryClassification) {
            return new DenseDoubleVector(new double[]{clz});
        }
        DoubleVector vec;
        if (this.outcomeDimension > SPARSE_PREDICTION_THRESHOLD) {
            vec = new SparseDoubleVector(this.outcomeDimension);
        } else {
            vec = new DenseDoubleVector(this.outcomeDimension);
        }
        vec.set(clz, 1.0);
        return vec;
    }

    /**
     * Compiles the current root node through {@link TreeCompiler} and replaces
     * it with the loaded result. Does nothing if a compiled form already exists.
     *
     * @throws Exception whatever the compiler or the class loading throws.
     */
    public void compileTree() throws Exception {
        if (this.compiledClass == null) {
            this.compiledName = TreeCompiler.generateClassName();
            this.compiledClass = TreeCompiler.compileNode(this.compiledName, this.rootNode);
            this.rootNode = TreeCompiler.load(this.compiledName, this.compiledClass);
        }
    }

    /**
     * Picks a random subset of the given feature indices if random feature
     * selection is configured and the set is larger than the configured subset
     * size; otherwise the given set is returned as is.
     *
     * <p>A new generator is created from the configured seed on every call, so
     * the drawn positions are the same for equally sized inputs.
     */
    TIntHashSet chooseRandomFeatures(TIntHashSet possibleFeatureIndices) {
        boolean samplingEnabled = this.numRandomFeaturesToChoose > 0 && this.numRandomFeaturesToChoose < this.numFeatures;
        if (!samplingEnabled || possibleFeatureIndices.size() <= this.numRandomFeaturesToChoose) {
            return possibleFeatureIndices;
        }
        TIntHashSet chosen = new TIntHashSet();
        int[] candidates = possibleFeatureIndices.toArray();
        Random rnd = new Random(this.seed);
        // duplicates are swallowed by the set, so draw until enough distinct ones exist
        while (chosen.size() < this.numRandomFeaturesToChoose) {
            chosen.add(candidates[rnd.nextInt(candidates.length)]);
        }
        return chosen;
    }

    /**
     * Recursively builds the (sub-)tree for the given examples.
     *
     * @param features the examples that reached this node.
     * @param outcome the outcomes parallel to {@code features}.
     * @param possibleFeatureIndices the feature indices that may still be split on.
     * @param level the depth of this node, the root being 0.
     * @return a leaf if only one class is present, no feature is left or the
     *         maximum height is reached; otherwise a split node.
     */
    private AbstractTreeNode build(List<DoubleVector> features, List<DoubleVector> outcome, TIntHashSet possibleFeatureIndices, int level) {
        possibleFeatureIndices = this.chooseRandomFeatures(possibleFeatureIndices);
        int[] countOutcomeClasses = this.getPossibleClasses(outcome);
        int presentClasses = 0;
        int lastPresentClass = -1;
        for (int i = 0; i < countOutcomeClasses.length; ++i) {
            if (countOutcomeClasses[i] != 0) {
                ++presentClasses;
                lastPresentClass = i;
            }
        }
        if (presentClasses == 1) {
            // pure node: nothing left to separate
            return new LeafNode(lastPresentClass);
        }
        if (possibleFeatureIndices.isEmpty() || level >= this.maxHeight) {
            return new LeafNode(ArrayUtils.maxIndex(countOutcomeClasses));
        }
        double targetEntropy = DecisionTree.getEntropy(countOutcomeClasses);
        // Ascending index order plus the strict comparison keeps the lowest
        // feature index when several features share the best gain.
        Split bestSplit = null;
        for (int featureIndex = 0; featureIndex < this.numFeatures; ++featureIndex) {
            if (!possibleFeatureIndices.contains(featureIndex)) {
                continue;
            }
            Split candidate = this.computeSplit(targetEntropy, featureIndex, countOutcomeClasses, features, outcome);
            if (bestSplit == null || candidate.getInformationGain() > bestSplit.getInformationGain()) {
                bestSplit = candidate;
            }
        }
        if (bestSplit == null) {
            // no usable candidate feature: fall back to the majority class
            return new LeafNode(ArrayUtils.maxIndex(countOutcomeClasses));
        }
        if (this.featureTypes[bestSplit.getSplitAttributeIndex()].isNominal()) {
            return this.buildNominalNode(bestSplit, features, outcome, possibleFeatureIndices, level);
        }
        return this.buildNumericalNode(bestSplit, features, outcome, possibleFeatureIndices, level);
    }

    /** Creates a nominal split node with one child per distinct value of the split feature. */
    private AbstractTreeNode buildNominalNode(Split bestSplit, List<DoubleVector> features, List<DoubleVector> outcome, TIntHashSet possibleFeatureIndices, int level) {
        int splitIndex = bestSplit.getSplitAttributeIndex();
        int[] nominalValues = this.getNominalValues(splitIndex, features).toArray();
        NominalNode node = new NominalNode(splitIndex, nominalValues.length);
        for (int child = 0; child < nominalValues.length; ++child) {
            int nominalValue = nominalValues[child];
            node.nominalSplitValues[child] = nominalValue;
            Tuple<List<DoubleVector>, List<DoubleVector>> filtered = this.filterNominal(features, outcome, splitIndex, nominalValue);
            // a nominal feature is used up once it has been split on
            TIntHashSet newPossibleFeatures = new TIntHashSet(possibleFeatureIndices);
            newPossibleFeatures.remove(splitIndex);
            node.children[child] = this.build(filtered.getFirst(), filtered.getSecond(), newPossibleFeatures, level + 1);
        }
        node.sortInternal();
        return node;
    }

    /** Creates a binary split node on the numerical split value of the given split. */
    private AbstractTreeNode buildNumericalNode(Split bestSplit, List<DoubleVector> features, List<DoubleVector> outcome, TIntHashSet possibleFeatureIndices, int level) {
        int splitIndex = bestSplit.getSplitAttributeIndex();
        double splitValue = bestSplit.getNumericalSplitValue();
        Tuple<List<DoubleVector>, List<DoubleVector>> lowerPart = this.filterNumeric(features, outcome, splitIndex, splitValue, true);
        Tuple<List<DoubleVector>, List<DoubleVector>> higherPart = this.filterNumeric(features, outcome, splitIndex, splitValue, false);
        TIntHashSet newPossibleFeatures = new TIntHashSet(possibleFeatureIndices);
        if (lowerPart.getFirst().isEmpty() || higherPart.getFirst().isEmpty()) {
            // the split did not separate anything, so this feature is dropped below
            newPossibleFeatures.remove(splitIndex);
        } else {
            // numerical features may be split on again further down
            for (int i = 0; i < this.featureTypes.length; ++i) {
                if (this.featureTypes[i].isNumerical()) {
                    newPossibleFeatures.add(i);
                }
            }
        }
        AbstractTreeNode lower = this.build(lowerPart.getFirst(), lowerPart.getSecond(), new TIntHashSet(newPossibleFeatures), level + 1);
        AbstractTreeNode higher = this.build(higherPart.getFirst(), higherPart.getSecond(), new TIntHashSet(newPossibleFeatures), level + 1);
        return new NumericalNode(splitIndex, splitValue, lower, higher);
    }

    /**
     * Selects the examples (and their outcomes) whose integer-truncated value of
     * the given feature equals {@code nominalValue}.
     *
     * @return a tuple of the selected features and the parallel outcomes.
     */
    private Tuple<List<DoubleVector>, List<DoubleVector>> filterNominal(List<DoubleVector> features, List<DoubleVector> outcome, int bestSplitIndex, int nominalValue) {
        List<DoubleVector> newFeatures = new ArrayList<DoubleVector>();
        List<DoubleVector> newOutcomes = new ArrayList<DoubleVector>();
        Iterator<DoubleVector> outcomeIterator = outcome.iterator();
        for (DoubleVector feature : features) {
            DoubleVector out = outcomeIterator.next();
            if ((int)feature.get(bestSplitIndex) == nominalValue) {
                newFeatures.add(feature);
                newOutcomes.add(out);
            }
        }
        return new Tuple<List<DoubleVector>, List<DoubleVector>>(newFeatures, newOutcomes);
    }

    /**
     * Selects the examples (and their outcomes) on one side of a numerical
     * split.
     *
     * @param lower if true the examples with a value {@code <= splitValue} are
     *        selected, otherwise those with a value {@code > splitValue}.
     * @return a tuple of the selected features and the parallel outcomes.
     */
    private Tuple<List<DoubleVector>, List<DoubleVector>> filterNumeric(List<DoubleVector> features, List<DoubleVector> outcome, int bestSplitIndex, double splitValue, boolean lower) {
        List<DoubleVector> newFeatures = new ArrayList<DoubleVector>();
        List<DoubleVector> newOutcomes = new ArrayList<DoubleVector>();
        Iterator<DoubleVector> outcomeIterator = outcome.iterator();
        for (DoubleVector feature : features) {
            DoubleVector out = outcomeIterator.next();
            double value = feature.get(bestSplitIndex);
            boolean selected = lower ? value <= splitValue : value > splitValue;
            if (selected) {
                newFeatures.add(feature);
                newOutcomes.add(out);
            }
        }
        return new Tuple<List<DoubleVector>, List<DoubleVector>>(newFeatures, newOutcomes);
    }

    /**
     * Computes the best split on the given feature.
     *
     * @param overallEntropy the entropy of the outcomes at the current node.
     * @param countOutcomeClasses the class counts at the current node (not used
     *        by the computation).
     * @return for nominal features the split with its information gain; for
     *         numerical features the split value with the highest information
     *         gain among all values observed for the feature.
     */
    private Split computeSplit(double overallEntropy, int featureIndex, int[] countOutcomeClasses, List<DoubleVector> features, List<DoubleVector> outcome) {
        if (this.featureTypes[featureIndex].isNominal()) {
            return this.computeNominalSplit(overallEntropy, featureIndex, features, outcome);
        }
        TDoubleHashSet possibleFeatureValues = new TDoubleHashSet();
        for (DoubleVector feature : features) {
            possibleFeatureValues.add(feature.get(featureIndex));
        }
        double bestInfogain = -1.0;
        double bestSplit = 0.0;
        // every observed value is tried as the threshold
        for (double value : possibleFeatureValues.toArray()) {
            double ig = this.computeNumericalInfogain(features, outcome, overallEntropy, featureIndex, value);
            if (ig > bestInfogain) {
                bestInfogain = ig;
                bestSplit = value;
            }
        }
        return new Split(featureIndex, bestInfogain, bestSplit);
    }

    /** Information gain of partitioning the examples by each distinct value of a nominal feature. */
    private Split computeNominalSplit(double overallEntropy, int featureIndex, List<DoubleVector> features, List<DoubleVector> outcome) {
        // per feature value: how often each class occurs, and how many examples carry the value
        TIntObjectHashMap<int[]> featureValueOutcomeCount = new TIntObjectHashMap<int[]>();
        TIntIntHashMap rowSums = new TIntIntHashMap();
        int numExamples = 0;
        Iterator<DoubleVector> outcomeIterator = outcome.iterator();
        for (DoubleVector feature : features) {
            int classIndex = this.getOutcomeClassIndex(outcomeIterator.next());
            int nominalFeatureValue = (int)feature.get(featureIndex);
            int[] classCounts = featureValueOutcomeCount.get(nominalFeatureValue);
            if (classCounts == null) {
                classCounts = new int[this.outcomeDimension];
                featureValueOutcomeCount.put(nominalFeatureValue, classCounts);
            }
            classCounts[classIndex]++;
            rowSums.put(nominalFeatureValue, rowSums.get(nominalFeatureValue) + 1);
            ++numExamples;
        }
        // entropy of each partition, weighted by the share of examples in it
        double entropySum = 0.0;
        TIntObjectIterator<int[]> iterator = featureValueOutcomeCount.iterator();
        while (iterator.hasNext()) {
            iterator.advance();
            double weight = (double)rowSums.get(iterator.key()) / (double)numExamples;
            entropySum += weight * DecisionTree.getEntropy(iterator.value());
        }
        return new Split(featureIndex, overallEntropy - entropySum);
    }

    /**
     * Information gain of splitting the examples into "value of the feature
     * {@code <= value}" and "value of the feature {@code > value}".
     *
     * @param overallEntropy the entropy of the outcomes before the split.
     * @return the overall entropy minus the size-weighted entropies of both sides.
     */
    private double computeNumericalInfogain(List<DoubleVector> features, List<DoubleVector> outcome, double overallEntropy, int featureIndex, double value) {
        double invDatasize = 1.0 / (double)features.size();
        // Two independent counters: filling a 2D array with one shared row would
        // merge both sides into the same counts.
        int[] lowClassCounts = new int[this.outcomeDimension];
        int[] highClassCounts = new int[this.outcomeDimension];
        int lowCount = 0;
        int highCount = 0;
        Iterator<DoubleVector> outcomeIterator = outcome.iterator();
        for (DoubleVector feature : features) {
            int idx = this.getOutcomeClassIndex(outcomeIterator.next());
            if (feature.get(featureIndex) > value) {
                highClassCounts[idx]++;
                ++highCount;
            } else {
                lowClassCounts[idx]++;
                ++lowCount;
            }
        }
        double gain = overallEntropy;
        gain -= (double)lowCount * invDatasize * DecisionTree.getEntropy(lowClassCounts);
        gain -= (double)highCount * invDatasize * DecisionTree.getEntropy(highClassCounts);
        return gain;
    }

    /**
     * @return the class index of an outcome: the integer-truncated value at
     *         index 0 for binary classification, otherwise the index of the
     *         maximum element.
     */
    private int getOutcomeClassIndex(DoubleVector out) {
        return this.binaryClassification ? (int)out.get(0) : out.maxIndex();
    }

    /** @return the distinct integer-truncated values of the given feature. */
    private TIntHashSet getNominalValues(int featureIndex, List<DoubleVector> features) {
        TIntHashSet uniqueFeatures = new TIntHashSet();
        for (DoubleVector vec : features) {
            uniqueFeatures.add((int)vec.get(featureIndex));
        }
        return uniqueFeatures;
    }

    /** @return an array of the outcome dimension that counts the examples per class index. */
    private int[] getPossibleClasses(List<DoubleVector> outcome) {
        int[] clzs = new int[this.outcomeDimension];
        for (DoubleVector out : outcome) {
            clzs[this.getOutcomeClassIndex(out)]++;
        }
        return clzs;
    }

    /**
     * Sets the type of every feature; the array length is checked against the
     * feature dimension when training.
     *
     * @return this instance.
     */
    public DecisionTree setFeatureTypes(FeatureType[] featureTypes) {
        this.featureTypes = featureTypes;
        return this;
    }

    /**
     * Sets how many randomly chosen features are considered per node. Values
     * {@code <= 0} or not below the number of features disable the sampling.
     *
     * @return this instance.
     */
    public DecisionTree setNumRandomFeaturesToChoose(int numRandomFeaturesToChoose) {
        this.numRandomFeaturesToChoose = numRandomFeaturesToChoose;
        return this;
    }

    /**
     * Sets whether training compiles the tree once it is built.
     *
     * @return this instance.
     */
    public DecisionTree setCompiled(boolean compiled) {
        this.compile = compiled;
        return this;
    }

    /**
     * Sets the level at which the tree stops splitting and emits a majority leaf.
     *
     * @return this instance.
     */
    public DecisionTree setMaxHeight(int max) {
        this.maxHeight = max;
        return this;
    }

    /**
     * Sets the seed of the random feature selection.
     *
     * @return this instance.
     */
    public DecisionTree setSeed(long seed) {
        this.seed = seed;
        return this;
    }

    void setNumFeatures(int numFeatures) {
        this.numFeatures = numFeatures;
    }

    /** @return a new set holding every feature index from 0 up to the number of features. */
    TIntHashSet getPossibleFeatures() {
        TIntHashSet possibleFeatureIndices = new TIntHashSet();
        for (int i = 0; i < this.numFeatures; ++i) {
            possibleFeatureIndices.add(i);
        }
        return possibleFeatureIndices;
    }

    /**
     * Writes the given tree to the output: classification mode, outcome
     * dimension, number of features, the feature type ordinals and then either
     * the node structure or, if a compiled form exists, its name and bytes.
     *
     * @throws IOException if writing fails; any other failure is wrapped into
     *         an IOException.
     */
    public static void serialize(DecisionTree tree, DataOutput out) throws IOException {
        try {
            out.writeBoolean(tree.binaryClassification);
            WritableUtils.writeVInt(out, tree.outcomeDimension);
            WritableUtils.writeVInt(out, tree.numFeatures);
            for (FeatureType type : tree.featureTypes) {
                WritableUtils.writeVInt(out, type.ordinal());
            }
            boolean compiled = tree.compiledClass != null;
            out.writeBoolean(compiled);
            if (compiled) {
                out.writeUTF(tree.compiledName);
                WritableUtils.writeCompressedByteArray(out, tree.compiledClass);
            } else {
                tree.rootNode.write(out);
            }
        }
        catch (IOException e) {
            // already the declared type, no need to wrap it a second time
            throw e;
        }
        catch (Exception e) {
            throw new IOException(e);
        }
    }

    /**
     * Reads a tree in the format written by
     * {@link #serialize(DecisionTree, DataOutput)}.
     *
     * @throws IOException if reading fails, the stream holds an invalid feature
     *         count or feature type, or the compiled tree cannot be loaded.
     */
    public static DecisionTree deserialize(DataInput in) throws IOException {
        boolean binary = in.readBoolean();
        int outcomeDimension = WritableUtils.readVInt(in);
        int numFeatures = WritableUtils.readVInt(in);
        if (numFeatures < 0) {
            throw new IOException("Corrupt stream: negative number of features " + numFeatures);
        }
        // values() copies the constant array on every call, so fetch it once
        FeatureType[] knownTypes = FeatureType.values();
        FeatureType[] arr = new FeatureType[numFeatures];
        for (int i = 0; i < numFeatures; ++i) {
            int ordinal = WritableUtils.readVInt(in);
            if (ordinal < 0 || ordinal >= knownTypes.length) {
                throw new IOException("Corrupt stream: unknown feature type ordinal " + ordinal);
            }
            arr[i] = knownTypes[ordinal];
        }
        if (!in.readBoolean()) {
            AbstractTreeNode root = AbstractTreeNode.read(in);
            return new DecisionTree(root, arr, binary, numFeatures, outcomeDimension);
        }
        String name = in.readUTF();
        byte[] compiled = WritableUtils.readCompressedByteArray(in);
        AbstractTreeNode loadedRoot;
        try {
            loadedRoot = TreeCompiler.load(name, compiled);
        }
        catch (Exception e) {
            throw new IOException(e);
        }
        DecisionTree tree = new DecisionTree(loadedRoot, arr, binary, numFeatures, outcomeDimension);
        // Keep the compiled form so that the tree can be serialized again and
        // compileTree() does not try to compile the already compiled root.
        tree.compiledName = name;
        tree.compiledClass = compiled;
        return tree;
    }

    /** @return a new untrained tree that treats all features as nominal unless configured otherwise. */
    public static DecisionTree create() {
        return new DecisionTree();
    }

    /** @return a new untrained tree with the given feature types. */
    public static DecisionTree create(FeatureType[] featureTypes) {
        return new DecisionTree().setFeatureTypes(featureTypes);
    }

    /** @return a new untrained tree that is compiled after training. */
    public static DecisionTree createCompiledTree() {
        return new DecisionTree().setCompiled(true);
    }

    /** @return a new untrained tree with the given feature types that is compiled after training. */
    public static DecisionTree createCompiledTree(FeatureType[] featureTypes) {
        return new DecisionTree().setFeatureTypes(featureTypes).setCompiled(true);
    }

    /**
     * Computes the entropy, in bits, of the distribution described by the given
     * class counts. Classes that do not occur contribute nothing.
     *
     * @param outcomeCounter the number of examples per class.
     * @return the entropy; 0.0 if at most one class occurs.
     */
    static double getEntropy(int[] outcomeCounter) {
        double sum = 0.0;
        for (int x : outcomeCounter) {
            sum += (double)x;
        }
        double entropySum = 0.0;
        for (int x : outcomeCounter) {
            if (x == 0) {
                // p * log2(p) tends to 0 for p -> 0; skipping also avoids log(0).
                // The other classes must still be accounted for.
                continue;
            }
            double probability = (double)x / sum;
            entropySum -= probability * DecisionTree.log2(probability);
        }
        return entropySum;
    }

    /** @return the base-2 logarithm of the given number. */
    private static double log2(double num) {
        return FastMath.log(num) / LOG2;
    }
}

