/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.BalancedWinnow;
import cc.mallet.classify.Boostable;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Labeling;
import java.io.Serializable;
import java.util.Arrays;

/**
 * Trains a {@link BalancedWinnow} classifier using multiplicative weight updates.
 *
 * <p>Every label keeps one weight per data feature plus a trailing bias weight, all
 * initialised to {@code 1.0}. For each training instance, every label is scored as the sum
 * over the instance's feature locations of {@code value * weight}, plus that label's bias
 * weight. Two kinds of update can follow:
 * <ul>
 *   <li>If the highest-scoring label is not the correct one, the weights of the features
 *       present in the instance (and the bias) are multiplied by {@code (1 - epsilon)} for the
 *       predicted label and by {@code (1 + epsilon)} for the correct label.</li>
 *   <li>If the prediction is correct but {@code best / runnerUp - 1 < delta}, the same update
 *       is applied with the runner-up label in place of the predicted one.</li>
 * </ul>
 * After every pass over the data, the working {@code epsilon} is multiplied by
 * {@code (1 - coolingRate)}.
 */
public class BalancedWinnowTrainer
extends ClassifierTrainer<BalancedWinnow>
implements Boostable,
Serializable {
    private static final long serialVersionUID = 1L;
    /** Default initial update rate. */
    public static final double DEFAULT_EPSILON = 0.5;
    /** Default relative margin below which a correct prediction still triggers an update. */
    public static final double DEFAULT_DELTA = 0.1;
    /** Default number of passes over the training data. */
    public static final int DEFAULT_MAX_ITERATIONS = 30;
    /** Default factor used to shrink the update rate after each pass. */
    public static final double DEFAULT_COOLING_RATE = 0.5;

    /** Initial update rate; a working copy is reduced after every pass. */
    double m_epsilon;
    /** Relative margin threshold for updating on correctly classified instances. */
    double m_delta;
    /** Number of passes over the training data. */
    int m_maxIterations;
    /** After each pass the working update rate is multiplied by {@code 1 - m_coolingRate}. */
    double m_coolingRate;
    /** Weights indexed {@code [label][feature]}; the last column holds each label's bias. */
    double[][] m_weights;
    /** Classifier produced by the most recent {@link #train} call; null before training. */
    BalancedWinnow classifier;

    /**
     * Returns the classifier built by the most recent call to {@link #train}.
     *
     * @return the last trained classifier, or {@code null} if {@code train} has not run yet
     */
    @Override
    public BalancedWinnow getClassifier() {
        return this.classifier;
    }

    /** Creates a trainer that uses the {@code DEFAULT_*} hyper-parameters. */
    public BalancedWinnowTrainer() {
        this(DEFAULT_EPSILON, DEFAULT_DELTA, DEFAULT_MAX_ITERATIONS, DEFAULT_COOLING_RATE);
    }

    /**
     * Creates a trainer with explicit hyper-parameters.
     *
     * @param epsilon initial multiplicative update rate
     * @param delta relative margin below which a correct prediction still triggers an update
     * @param maxIterations number of passes over the training data
     * @param coolingRate after each pass the update rate is multiplied by {@code 1 - coolingRate}
     */
    public BalancedWinnowTrainer(double epsilon, double delta, int maxIterations, double coolingRate) {
        this.m_epsilon = epsilon;
        this.m_delta = delta;
        this.m_maxIterations = maxIterations;
        this.m_coolingRate = coolingRate;
    }

    /**
     * Trains a new {@link BalancedWinnow} classifier on {@code trainingList}.
     *
     * <p>Instances whose labeling is {@code null} are skipped. Each instance's data is cast to
     * {@link FeatureVector}.
     *
     * @param trainingList the training instances
     * @return the trained classifier, which is also retained for {@link #getClassifier()}
     * @throws UnsupportedOperationException if {@code trainingList} has a feature selection
     */
    @Override
    public BalancedWinnow train(InstanceList trainingList) {
        FeatureSelection selectedFeatures = trainingList.getFeatureSelection();
        if (selectedFeatures != null) {
            throw new UnsupportedOperationException("FeatureSelection not yet implemented.");
        }
        Alphabet dataAlphabet = trainingList.getDataAlphabet();
        int numLabels = trainingList.getTargetAlphabet().size();
        // Column numFeats is the bias; every instance implicitly "fires" it.
        int numFeats = dataAlphabet.size();
        this.m_weights = new double[numLabels][numFeats + 1];
        for (double[] labelWeights : this.m_weights) {
            Arrays.fill(labelWeights, 1.0);
        }

        double[] results = new double[numLabels];
        double epsilon = this.m_epsilon;
        for (int iter = 0; iter < this.m_maxIterations; iter++) {
            for (int ii = 0; ii < trainingList.size(); ii++) {
                Instance inst = (Instance) trainingList.get(ii);
                Labeling labeling = inst.getLabeling();
                if (labeling == null) {
                    // An unlabeled instance has no correct label to promote; skip it.
                    continue;
                }
                FeatureVector fv = (FeatureVector) inst.getData();
                int correctIndex = labeling.getBestIndex();
                scoreLabels(fv, numFeats, results);

                // Index-based argmax / runner-up search. This deliberately avoids a numeric
                // sentinel: Double.MIN_VALUE is the smallest *positive* double, so using it
                // would mis-rank labels whose scores are zero or negative. Ties keep the
                // lowest index as the winner.
                int predictedIndex = -1;
                int secondHighestIndex = -1;
                for (int label = 0; label < numLabels; label++) {
                    if (predictedIndex < 0 || results[label] > results[predictedIndex]) {
                        secondHighestIndex = predictedIndex;
                        predictedIndex = label;
                    } else if (secondHighestIndex < 0 || results[label] > results[secondHighestIndex]) {
                        secondHighestIndex = label;
                    }
                }

                if (predictedIndex != correctIndex) {
                    updateWeights(fv, numFeats, predictedIndex, correctIndex, epsilon);
                } else if (secondHighestIndex >= 0
                        && results[predictedIndex] / results[secondHighestIndex] - 1.0 < this.m_delta) {
                    // Correct, but the runner-up is too close: widen the margin.
                    // NOTE(review): this ratio test is only a meaningful margin when scores are
                    // positive (presumably the case for non-negative features and epsilon < 1).
                    updateWeights(fv, numFeats, secondHighestIndex, correctIndex, epsilon);
                }
            }
            epsilon *= 1.0 - this.m_coolingRate;
        }
        this.classifier = new BalancedWinnow(trainingList.getPipe(), this.m_weights);
        return this.classifier;
    }

    /**
     * Scores every label for one instance.
     *
     * <p>The score is the sum of {@code value * weight} over the feature vector's locations,
     * plus the label's bias weight. {@code scores[l]} is overwritten for each label {@code l}.
     */
    private void scoreLabels(FeatureVector fv, int biasIndex, double[] scores) {
        int numLocations = fv.numLocations();
        for (int label = 0; label < scores.length; label++) {
            double[] weights = this.m_weights[label];
            double sum = 0.0;
            for (int loc = 0; loc < numLocations; loc++) {
                sum += fv.valueAtLocation(loc) * weights[fv.indexAtLocation(loc)];
            }
            scores[label] = sum + weights[biasIndex];
        }
    }

    /**
     * Applies one promote/demote step to two labels.
     *
     * <p>The weight of every feature present in {@code fv}, plus the bias weight, is
     * multiplied by {@code (1 - epsilon)} for {@code demoteIndex} and by {@code (1 + epsilon)}
     * for {@code promoteIndex}. Feature values are not used, only their indices.
     */
    private void updateWeights(FeatureVector fv, int biasIndex, int demoteIndex, int promoteIndex, double epsilon) {
        double[] demoted = this.m_weights[demoteIndex];
        double[] promoted = this.m_weights[promoteIndex];
        double demoteFactor = 1.0 - epsilon;
        double promoteFactor = 1.0 + epsilon;
        int numLocations = fv.numLocations();
        for (int loc = 0; loc < numLocations; loc++) {
            int fi = fv.indexAtLocation(loc);
            demoted[fi] *= demoteFactor;
            promoted[fi] *= promoteFactor;
        }
        demoted[biasIndex] *= demoteFactor;
        promoted[biasIndex] *= promoteFactor;
    }
}

