/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.Boostable;
import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.GainRatio;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.logging.Logger;

/**
 * Binary decision-tree classifier.
 *
 * <p>Every internal {@link Node} compares one feature value of the instance
 * against a threshold, both taken from the node's {@link GainRatio}. Values
 * less than or equal to the threshold go to the left child, all others to the
 * right child. The label distribution stored in the reached leaf's
 * {@code GainRatio} is the classification result.
 */
public class C45 extends Classifier implements Boostable, Serializable {
    private static final Logger logger = MalletLogger.getLogger(C45.class.getName());

    /** Root of the tree. {@link #getSize()}, {@link #print()} and {@link #prune()} tolerate {@code null}. */
    Node m_root;

    private static final long serialVersionUID = 1L;

    /** Version tag written first by {@code writeObject} and verified by {@code readObject}. */
    private static final int CURRENT_SERIAL_VERSION = 1;

    /**
     * Creates a classifier around an already-built tree.
     *
     * @param instancePipe pipe handed to the {@link Classifier} superclass
     * @param root root node of the decision tree
     */
    public C45(Pipe instancePipe, Node root) {
        super(instancePipe);
        this.m_root = root;
    }

    /** @return the root node of the tree (may be {@code null}) */
    public Node getRoot() {
        return this.m_root;
    }

    /**
     * Walks down from {@code node} until a leaf is reached, choosing the left
     * child when the feature value is {@code <=} the node's threshold and the
     * right child otherwise.
     *
     * @param node node to start from
     * @param fv feature vector of the instance being classified
     * @return the leaf the feature vector ends up in
     */
    private Node getLeaf(Node node, FeatureVector fv) {
        Node current = node;
        while (!current.isLeaf()) {
            GainRatio gainRatio = current.getGainRatio();
            boolean goLeft = fv.value(gainRatio.getMaxValuedIndex()) <= gainRatio.getMaxValuedThreshold();
            current = goLeft ? current.getLeftChild() : current.getRightChild();
        }
        return current;
    }

    /**
     * Classifies an instance whose data is a {@link FeatureVector}.
     *
     * @param instance instance to classify; its data is cast to {@code FeatureVector}
     * @return a classification carrying the base label distribution of the leaf reached
     */
    @Override
    public Classification classify(Instance instance) {
        FeatureVector fv = (FeatureVector) instance.getData();
        assert (this.instancePipe == null || fv.getAlphabet() == this.instancePipe.getDataAlphabet());
        Node leaf = this.getLeaf(this.m_root, fv);
        return new Classification(instance, this, leaf.getGainRatio().getBaseLabelDistribution());
    }

    /**
     * Prunes the tree by delegating to {@link Node#computeCostAndPrune()} on the
     * root. Does nothing when there is no root, matching {@link #getSize()} and
     * {@link #print()}.
     */
    public void prune() {
        Node root = this.getRoot();
        if (root != null) {
            root.computeCostAndPrune();
        }
    }

    /**
     * @return 0 when there is no root; otherwise 1 plus the number of non-leaf
     *         nodes below the root (see {@link Node#getNumDescendants()})
     */
    public int getSize() {
        Node root = this.getRoot();
        if (root == null) {
            return 0;
        }
        return 1 + root.getNumDescendants();
    }

    /** Prints the tree to standard output; prints nothing when there is no root. */
    @Override
    public void print() {
        Node root = this.getRoot();
        if (root != null) {
            root.print();
        }
    }

    /** Serial form: version tag, then the instance pipe, then the root node. */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(CURRENT_SERIAL_VERSION);
        out.writeObject(this.getInstancePipe());
        out.writeObject(this.m_root);
    }

    /**
     * Reads the form produced by {@code writeObject}.
     *
     * @throws ClassNotFoundException if the stored version tag differs from
     *         {@link #CURRENT_SERIAL_VERSION}
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        int version = in.readInt();
        if (version != CURRENT_SERIAL_VERSION) {
            throw new ClassNotFoundException("Mismatched C45 versions: wanted " + CURRENT_SERIAL_VERSION + ", got " + version);
        }
        this.instancePipe = (Pipe) in.readObject();
        this.m_root = (Node) in.readObject();
    }

    /**
     * A node of the decision tree. A node owns a subset of the instances of a
     * shared {@link InstanceList}, identified by index, and a {@link GainRatio}
     * computed over that subset. A node is a leaf until {@link #split()} gives
     * it two children.
     */
    public static class Node implements Serializable {
        private static final long serialVersionUID = 1L;

        /** Gain ratio created by {@code GainRatio.createGainRatio} for this node's instances. */
        GainRatio m_gainRatio;
        /** Backing instance list; set to {@code null} by {@link #stopGrowth()}. */
        InstanceList m_ilist;
        /** Indices into {@link #m_ilist} of the instances belonging to this node. */
        int[] m_instIndices;
        /** Data alphabet of the instance list, used to look up feature objects by index. */
        Alphabet m_dataDict;
        /** Value passed to {@code GainRatio.createGainRatio} and handed on to child nodes. */
        int m_minNumInsts;
        /** Parent node, or {@code null} for the root. */
        Node m_parent;
        Node m_leftChild;
        Node m_rightChild;

        /**
         * Creates a node that owns every instance of {@code ilist}.
         *
         * @param ilist backing instance list
         * @param parent parent node, or {@code null} for a root
         * @param minNumInsts forwarded to {@code GainRatio.createGainRatio}
         */
        public Node(InstanceList ilist, Node parent, int minNumInsts) {
            this(ilist, parent, minNumInsts, null);
        }

        /**
         * Creates a node that owns the instances of {@code ilist} selected by
         * {@code instIndices}.
         *
         * @param ilist backing instance list
         * @param parent parent node, or {@code null} for a root
         * @param minNumInsts forwarded to {@code GainRatio.createGainRatio}
         * @param instIndices indices into {@code ilist}; {@code null} selects
         *        all instances ({@code 0 .. ilist.size()-1})
         */
        public Node(InstanceList ilist, Node parent, int minNumInsts, int[] instIndices) {
            if (instIndices == null) {
                instIndices = new int[ilist.size()];
                for (int i = 0; i < instIndices.length; i++) {
                    instIndices[i] = i;
                }
            }
            this.m_gainRatio = GainRatio.createGainRatio(ilist, instIndices, minNumInsts);
            this.m_ilist = ilist;
            this.m_instIndices = instIndices;
            this.m_dataDict = this.m_ilist.getDataAlphabet();
            this.m_minNumInsts = minNumInsts;
            this.m_parent = parent;
            this.m_leftChild = null;
            this.m_rightChild = null;
        }

        /**
         * Returns the backing instance list, or fails in the same way
         * {@link #split()} does once {@link #stopGrowth()} has released it.
         *
         * @param action verb phrase inserted into the error message
         * @throws IllegalStateException if the instance list has been released
         */
        private InstanceList requireInstanceList(String action) {
            if (this.m_ilist == null) {
                throw new IllegalStateException("Frozen.  Cannot " + action + ".");
            }
            return this.m_ilist;
        }

        /** @return the number of ancestors of this node (0 for the root) */
        public int depth() {
            int depth = 0;
            for (Node ancestor = this.m_parent; ancestor != null; ancestor = ancestor.m_parent) {
                depth++;
            }
            return depth;
        }

        /** @return the number of instances owned by this node */
        public int getSize() {
            return this.m_instIndices.length;
        }

        /** @return {@code true} when this node has neither a left nor a right child */
        public boolean isLeaf() {
            return this.m_leftChild == null && this.m_rightChild == null;
        }

        /** @return {@code true} when this node has no parent */
        public boolean isRoot() {
            return this.m_parent == null;
        }

        public Node getParent() {
            return this.m_parent;
        }

        public Node getLeftChild() {
            return this.m_leftChild;
        }

        public Node getRightChild() {
            return this.m_rightChild;
        }

        public GainRatio getGainRatio() {
            return this.m_gainRatio;
        }

        /** @return the data-alphabet entry at the gain ratio's max-valued index */
        public Object getSplitFeature() {
            return this.m_dataDict.lookupObject(this.m_gainRatio.getMaxValuedIndex());
        }

        /**
         * Builds a new instance list, using the backing list's pipe, that holds
         * exactly the instances owned by this node.
         *
         * @throws IllegalStateException if {@link #stopGrowth()} has been called
         */
        public InstanceList getInstances() {
            InstanceList source = this.requireInstanceList("get instances");
            InstanceList ret = new InstanceList(source.getPipe());
            for (int index : this.m_instIndices) {
                ret.add((Instance) source.get(index));
            }
            return ret;
        }

        /**
         * Counts the non-leaf nodes strictly below this node. Leaf descendants
         * are not counted, so a leaf or a node with two leaf children yields 0.
         */
        public int getNumDescendants() {
            if (this.isLeaf()) {
                return 0;
            }
            int count = 0;
            Node left = this.getLeftChild();
            if (!left.isLeaf()) {
                count += 1 + left.getNumDescendants();
            }
            Node right = this.getRightChild();
            if (!right.isLeaf()) {
                count += 1 + right.getNumDescendants();
            }
            return count;
        }

        /**
         * Partitions this node's instances on the gain ratio's feature and
         * threshold and creates the two children: instances whose feature value
         * is {@code <=} the threshold go left, the rest go right.
         *
         * @throws IllegalStateException if {@link #stopGrowth()} has been called
         */
        public void split() {
            InstanceList ilist = this.requireInstanceList("split");
            int total = this.m_instIndices.length;
            int splitIndex = this.m_gainRatio.getMaxValuedIndex();
            double splitThreshold = this.m_gainRatio.getMaxValuedThreshold();

            // First pass: decide the side of every instance so the child arrays can be sized exactly.
            boolean[] toLeftChild = new boolean[total];
            int numLeftChildren = 0;
            for (int i = 0; i < total; i++) {
                Instance instance = (Instance) ilist.get(this.m_instIndices[i]);
                FeatureVector fv = (FeatureVector) instance.getData();
                if (fv.value(splitIndex) <= splitThreshold) {
                    toLeftChild[i] = true;
                    numLeftChildren++;
                }
            }
            logger.info("leftChild.size=" + numLeftChildren + " rightChild.size=" + (total - numLeftChildren));

            // Second pass: distribute the instance indices, preserving their relative order.
            int[] leftIndices = new int[numLeftChildren];
            int[] rightIndices = new int[total - numLeftChildren];
            int li = 0;
            int ri = 0;
            for (int i = 0; i < total; i++) {
                if (toLeftChild[i]) {
                    leftIndices[li++] = this.m_instIndices[i];
                } else {
                    rightIndices[ri++] = this.m_instIndices[i];
                }
            }
            this.m_leftChild = new Node(ilist, this, this.m_minNumInsts, leftIndices);
            this.m_rightChild = new Node(ilist, this, this.m_minNumInsts, rightIndices);
        }

        /**
         * Recursively computes the cost of the subtree rooted here and removes
         * both children when keeping this node as a leaf is no more expensive.
         *
         * <p>The leaf cost is {@code getMDL() + 1}. The split cost is
         * {@code log2(numSplitPoints) + 1} plus the minimal costs of both
         * children. Children are pruned first, so pruning proceeds bottom-up.
         *
         * @return the smaller of the leaf cost and the split cost
         * @throws IllegalStateException if {@link #stopGrowth()} has been called
         */
        public double computeCostAndPrune() {
            double costS = this.getMDL();
            if (this.isLeaf()) {
                return costS + 1.0;
            }
            double minCost1 = this.getLeftChild().computeCostAndPrune();
            double minCost2 = this.getRightChild().computeCostAndPrune();
            double costSplit = Math.log(this.m_gainRatio.getNumSplitPointsForBestFeature()) / GainRatio.log2;
            double minCostN = Math.min(costS + 1.0, costSplit + 1.0 + minCost1 + minCost2);
            if (Maths.almostEquals(minCostN, costS + 1.0)) {
                // Being a leaf is at least as cheap as splitting: drop the subtree.
                this.m_leftChild = null;
                this.m_rightChild = null;
            }
            return minCostN;
        }

        /**
         * Computes this node's cost term used by {@link #computeCostAndPrune()}
         * (presumably a minimum-description-length estimate, going by the name).
         *
         * <p>With {@code n = getSize()} and {@code k} = size of the target
         * alphabet, the value is
         * {@code n * baseEntropy + (k-1) * log2(n/2) / 2 + log2(pi^(k/2) / gamma(k/2))}.
         *
         * @throws IllegalStateException if {@link #stopGrowth()} has been called
         */
        public double getMDL() {
            InstanceList ilist = this.requireInstanceList("compute MDL");
            int numClasses = ilist.getTargetAlphabet().size();
            double mdl = (double) this.getSize() * this.getGainRatio().getBaseEntropy();
            mdl += (double) (numClasses - 1) * Math.log((double) this.getSize() / 2.0) / (2.0 * GainRatio.log2);
            double piPow = Math.pow(Math.PI, (double) numClasses / 2.0);
            double gammaVal = Maths.gamma((double) numClasses / 2.0);
            return mdl + Math.log(piPow / gammaVal) / GainRatio.log2;
        }

        /**
         * Releases the reference to the backing instance list in this node and
         * all of its descendants. Afterwards {@link #split()},
         * {@link #getInstances()} and {@link #getMDL()} throw
         * {@link IllegalStateException}.
         */
        public void stopGrowth() {
            if (this.m_leftChild != null) {
                this.m_leftChild.stopGrowth();
            }
            if (this.m_rightChild != null) {
                this.m_rightChild.stopGrowth();
            }
            this.m_ilist = null;
        }

        /** @return {@link #getStringBufferName()} as a {@code String} */
        public String getName() {
            return this.getStringBufferName().toString();
        }

        /**
         * Describes the path from the root to this node. The root is
         * {@code root}; any other node is the {@code " && "}-joined list of the
         * tests taken on the way down, each rendered as
         * {@code ("feature" <= threshold)} or {@code ("feature" > threshold)}.
         */
        public StringBuffer getStringBufferName() {
            StringBuffer sb = new StringBuffer();
            Node parent = this.m_parent;
            if (parent == null) {
                return sb.append("root");
            }
            // The root's own name is omitted; deeper nodes are prefixed with their parent's path.
            if (parent.getParent() != null) {
                sb.append(parent.getStringBufferName());
                sb.append(" && ");
            }
            sb.append("(\"");
            sb.append(this.m_dataDict.lookupObject(parent.getGainRatio().getMaxValuedIndex()).toString());
            sb.append("\"");
            sb.append(parent.getLeftChild() == this ? " <= " : " > ");
            sb.append(parent.getGainRatio().getMaxValuedThreshold());
            return sb.append(")");
        }

        /** Prints the subtree rooted at this node to standard output without indentation. */
        public void print() {
            this.print("");
        }

        /**
         * Prints the subtree rooted at this node to standard output.
         *
         * <p>A leaf prints {@code root:<label> <majority>/<size>} (without the
         * prefix). An internal node prints one line per branch; a leaf child is
         * summarised on the same line, a non-leaf child is printed recursively
         * with {@code "|    "} appended to the prefix.
         *
         * @param prefix indentation printed before each branch line
         */
        public void print(String prefix) {
            if (this.isLeaf()) {
                System.out.println("root:" + this.majoritySummary());
                return;
            }
            String featName = this.m_dataDict.lookupObject(this.getGainRatio().getMaxValuedIndex()).toString();
            double threshold = this.getGainRatio().getMaxValuedThreshold();
            this.printBranch(prefix, "\"" + featName + "\" <= " + threshold + ":", this.m_leftChild);
            this.printBranch(prefix, "\"" + featName + "\" > " + threshold + ":", this.m_rightChild);
        }

        /** Prints one branch line and either the child's leaf summary or its subtree. */
        private void printBranch(String prefix, String condition, Node child) {
            System.out.print(prefix + condition);
            if (child.isLeaf()) {
                System.out.println(child.majoritySummary());
            } else {
                System.out.println();
                child.print(prefix + "|    ");
            }
        }

        /**
         * Formats {@code <best label> <majority count>/<size>} for this node.
         * The count is recovered from the best label's proportion and rounded to
         * the nearest integer; a plain {@code (int)} cast would truncate products
         * such as {@code 0.58 * 100 = 57.99999999999999} to 57.
         */
        private String majoritySummary() {
            int bestLabelIndex = this.getGainRatio().getBaseLabelDistribution().getBestIndex();
            double proportion = this.getGainRatio().getBaseLabelDistribution().value(bestLabelIndex);
            int numMajorityLabel = (int) Math.round(proportion * (double) this.getSize());
            return this.getGainRatio().getBaseLabelDistribution().getBestLabel() + " " + numMajorityLabel + "/" + this.getSize();
        }
    }
}

