/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.AdaBoostM2;
import cc.mallet.classify.Boostable;
import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
import java.util.Arrays;
import java.util.Random;
import java.util.logging.Logger;

/**
 * Trains an {@link AdaBoostM2} ensemble with the AdaBoost.M2 pseudo-loss scheme,
 * using a caller-supplied weak learner that must implement {@link Boostable}.
 *
 * <p>Each training instance is expanded into one (instance, incorrect label) pair per
 * wrong class. A weight distribution over these pairs is maintained. In every round,
 * the weak learner is trained on a weighted resample of the pairs, its pseudo-loss
 * {@code epsilon = 0.5 * sum_k w_k * (1 - h(x, y_true) + h(x, y_wrong))} is measured,
 * and the pair weights are rescaled by {@code beta^(0.5 * (1 + h(x, y_true) - h(x, y_wrong)))}
 * with {@code beta = epsilon / (1 - epsilon)}. Training stops early if a round's
 * pseudo-loss stays (almost) zero after repeated resampling.
 */
public class AdaBoostM2Trainer extends ClassifierTrainer<AdaBoostM2> {
    private static final Logger logger = MalletLogger.getLogger(AdaBoostM2Trainer.class.getName());
    /** Maximum number of resamples tried in one round while the pseudo-loss stays at zero. */
    private static final int MAX_NUM_RESAMPLING_ITERATIONS = 10;
    ClassifierTrainer weakLearner;
    int numRounds;
    AdaBoostM2 classifier;

    /**
     * Returns the classifier built by the most recent call to {@link #train}, including
     * runs that stopped early, or {@code null} if {@code train} has not been called yet.
     */
    @Override
    public AdaBoostM2 getClassifier() {
        return this.classifier;
    }

    /**
     * Creates a trainer that runs at most {@code numRounds} boosting rounds.
     *
     * @param weakLearner trainer for the base classifiers; must implement {@link Boostable}
     * @param numRounds   maximum number of boosting rounds; must be positive
     * @throws IllegalArgumentException if the weak learner is not boostable or
     *                                  {@code numRounds <= 0}
     */
    public AdaBoostM2Trainer(ClassifierTrainer weakLearner, int numRounds) {
        if (!(weakLearner instanceof Boostable)) {
            throw new IllegalArgumentException("weak learner not boostable");
        }
        if (numRounds <= 0) {
            throw new IllegalArgumentException("number of rounds must be positive");
        }
        this.weakLearner = weakLearner;
        this.numRounds = numRounds;
    }

    /**
     * Creates a trainer that runs at most 100 boosting rounds.
     *
     * @param weakLearner trainer for the base classifiers; must implement {@link Boostable}
     * @throws IllegalArgumentException if the weak learner is not boostable
     */
    public AdaBoostM2Trainer(ClassifierTrainer weakLearner) {
        this(weakLearner, 100);
    }

    /**
     * Runs AdaBoost.M2 on {@code trainingList}, stores the resulting ensemble (see
     * {@link #getClassifier()}) and returns it.
     *
     * @param trainingList labeled training instances; feature selections are not supported
     * @return the trained ensemble; shorter than {@code numRounds} if training stopped early
     * @throws UnsupportedOperationException if the list carries a {@link FeatureSelection}
     * @throws IllegalArgumentException      if there are no instances or fewer than two classes
     */
    @Override
    public AdaBoostM2 train(InstanceList trainingList) {
        FeatureSelection selectedFeatures = trainingList.getFeatureSelection();
        if (selectedFeatures != null) {
            throw new UnsupportedOperationException("FeatureSelection not yet implemented.");
        }
        int numClasses = trainingList.getTargetAlphabet().size();
        int numInstances = trainingList.size();
        int numPairs = numInstances * (numClasses - 1);
        if (numPairs <= 0) {
            // Checked up front; otherwise this fails later with a misleading sampler
            // error or a NegativeArraySizeException.
            throw new IllegalArgumentException(
                    "AdaBoostM2 needs at least one training instance and at least two classes");
        }

        // Expand every instance into one (instance, incorrect label) pair per wrong class.
        // trainingInsts holds the instance of pair k; classIndices[k] holds its wrong label.
        InstanceList trainingInsts = new InstanceList(trainingList.getPipe());
        int[] classIndices = new int[numPairs];
        int numAdded = 0;
        for (int i = 0; i < numInstances; i++) {
            Instance inst = (Instance) trainingList.get(i);
            int trueClassIndex = inst.getLabeling().getBestIndex();
            for (int j = 0; j < numClasses; j++) {
                if (j != trueClassIndex) {
                    trainingInsts.add(inst, 1.0);
                    classIndices[numAdded++] = j;
                }
            }
        }

        // Start from a uniform distribution over the pairs.
        double[] weights = new double[numPairs];
        Arrays.fill(weights, 1.0 / numPairs);
        int[] instIndices = new int[numPairs];
        for (int k = 0; k < numPairs; k++) {
            instIndices[k] = k;
        }

        Random random = new Random();
        Classifier[] weakLearners = new Classifier[this.numRounds];
        double[] classifierWeights = new double[this.numRounds];
        double[] exponents = new double[numPairs];
        for (int round = 0; round < this.numRounds; round++) {
            logger.info("===========  AdaBoostM2Trainer round " + (round + 1) + " begin");
            double epsilon;
            int resamplingIterations = 0;
            do {
                // Train the weak learner on a weighted resample of the pairs.
                int[] sampleIndices = this.sampleWithWeights(instIndices, weights, random);
                InstanceList roundTrainingInsts = new InstanceList(trainingInsts.getPipe(), sampleIndices.length);
                for (int sampleIndex : sampleIndices) {
                    roundTrainingInsts.add((Instance) trainingInsts.get(sampleIndex), 1.0);
                }
                weakLearners[round] = this.weakLearner.train(roundTrainingInsts);

                // Pseudo-loss over all pairs; also record each pair's weight-update exponent.
                epsilon = 0.0;
                for (int k = 0; k < numPairs; k++) {
                    Classification c = weakLearners[round].classify((Instance) trainingInsts.get(k));
                    double htCorrect = c.valueOfCorrectLabel();
                    double htWrong = c.getLabeling().value(classIndices[k]);
                    epsilon += weights[k] * (1.0 - htCorrect + htWrong);
                    exponents[k] = 1.0 + htCorrect - htWrong;
                }
                epsilon *= 0.5;
                // A zero pseudo-loss makes beta zero, so resample a few times before giving up.
            } while (Maths.almostEquals(epsilon, 0.0) && ++resamplingIterations < MAX_NUM_RESAMPLING_ITERATIONS);

            if (Maths.almostEquals(epsilon, 0.0)) {
                logger.info("AdaBoostM2Trainer stopped at " + (round + 1) + " / " + this.numRounds + " pseudo-loss=" + epsilon);
                // Store the truncated ensemble so getClassifier() also reflects early stops.
                this.classifier = earlyStoppedClassifier(trainingInsts, weakLearners, classifierWeights, round);
                return this.classifier;
            }

            double beta = epsilon / (1.0 - epsilon);
            classifierWeights[round] = Math.log(1.0 / beta);
            double sum = 0.0;
            for (int k = 0; k < numPairs; k++) {
                weights[k] *= Math.pow(beta, 0.5 * exponents[k]);
                sum += weights[k];
            }
            MatrixOps.timesEquals(weights, 1.0 / sum);
            logger.info("===========  AdaBoostM2Trainer round " + (round + 1) + " finished, pseudo-loss = " + epsilon);
        }
        logClassifierWeights(classifierWeights);
        this.classifier = new AdaBoostM2(trainingInsts.getPipe(), weakLearners, classifierWeights);
        return this.classifier;
    }

    /**
     * Builds the ensemble returned when round {@code round} reached zero pseudo-loss.
     * Only the learners from rounds before {@code round} are kept. In round 0 that would
     * leave nothing, so the round-0 learner is kept alone with weight 1.
     */
    private static AdaBoostM2 earlyStoppedClassifier(InstanceList trainingInsts, Classifier[] weakLearners,
                                                     double[] classifierWeights, int round) {
        int numClassifiersToUse = round == 0 ? 1 : round;
        if (round == 0) {
            classifierWeights[0] = 1.0;
        }
        double[] keptWeights = Arrays.copyOf(classifierWeights, numClassifiersToUse);
        Classifier[] keptLearners = Arrays.copyOf(weakLearners, numClassifiersToUse);
        logClassifierWeights(keptWeights);
        return new AdaBoostM2(trainingInsts.getPipe(), keptLearners, keptWeights);
    }

    /** Logs each weak learner's weight at INFO level. */
    private static void logClassifierWeights(double[] classifierWeights) {
        for (int i = 0; i < classifierWeights.length; i++) {
            logger.info("AdaBoostM2Trainer weight[weakLearner[" + i + "]]=" + classifierWeights[i]);
        }
    }

    /**
     * Draws {@code data.length} elements of {@code data} with replacement, each draw
     * picking index {@code b} with probability {@code weights[b] / sum(weights)}.
     *
     * <p>The draws are made as sorted uniform points on {@code [0, sum(weights)]}. These are
     * the normalized cumulative sums of {@code n + 1} i.i.d. exponential spacings, which are
     * distributed as uniform order statistics. One linear sweep over the cumulative weights
     * then maps the points to elements. Unlike pinning the last point to the total weight,
     * this does not force the last element into every sample.
     *
     * @throws IllegalArgumentException if the lengths differ, a weight is negative or NaN,
     *                                  or the weights do not sum to a positive value
     */
    private int[] sampleWithWeights(int[] data, double[] weights, Random random) {
        if (weights.length != data.length) {
            throw new IllegalArgumentException("length of weight vector must equal number of data points");
        }
        int n = data.length;
        double sumOfWeights = 0.0;
        for (int i = 0; i < n; i++) {
            // The negated form also rejects NaN weights.
            if (!(weights[i] >= 0.0)) {
                throw new IllegalArgumentException("weight vector must be non-negative");
            }
            sumOfWeights += weights[i];
        }
        if (!(sumOfWeights > 0.0)) {
            throw new IllegalArgumentException("weights must sum to positive value");
        }

        // Sorted uniform points via exponential spacings; 1 - nextDouble() lies in (0, 1].
        double[] points = new double[n];
        double cumulative = 0.0;
        for (int i = 0; i < n; i++) {
            cumulative -= Math.log(1.0 - random.nextDouble());
            points[i] = cumulative;
        }
        double total = cumulative - Math.log(1.0 - random.nextDouble());
        double scale = total > 0.0 ? sumOfWeights / total : 0.0;
        for (int i = 0; i < n; i++) {
            // Clamp so floating-point rounding can never push a point past the final bucket.
            points[i] = Math.min(points[i] * scale, sumOfWeights);
        }

        // Bucket b covers (cumulative weight before b, cumulative weight through b].
        // The running sum adds the same values in the same order as sumOfWeights, so the
        // last bucket ends exactly at sumOfWeights and every point is assigned.
        int[] sample = new int[n];
        int a = 0;
        double cumulativeWeight = 0.0;
        for (int b = 0; b < n && a < n; b++) {
            cumulativeWeight += weights[b];
            if (weights[b] <= 0.0) {
                continue; // zero-weight elements must never be drawn
            }
            while (a < n && points[a] <= cumulativeWeight) {
                sample[a++] = data[b];
            }
        }
        return sample;
    }
}

