package opennlp.tools.ml.perceptron;

import com.fasterxml.jackson.core.util.MinimalPrettyPrinter;
import java.io.IOException;
import opennlp.tools.ml.AbstractEventTrainer;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.EvalParameters;
import opennlp.tools.ml.model.MutableContext;

/* loaded from: input_file:opennlp/tools/ml/perceptron/PerceptronTrainer.class */
/**
 * Trains a {@link PerceptronModel} with the perceptron algorithm, optionally averaging the
 * weights collected over the training iterations.
 *
 * <p>Each iteration visits every indexed event as many times as it was seen. When the highest
 * scoring outcome differs from the event's true outcome, every predicate of the event moves
 * weight from the predicted outcome to the true outcome. Training stops after the requested
 * number of iterations, or earlier when the training-set accuracy differs by less than the
 * tolerance from each of the three previous iterations. Progress is printed to
 * {@code System.out}.
 */
public class PerceptronTrainer extends AbstractEventTrainer {
    /** Algorithm name accepted by {@link #isValid()}. */
    public static final String PERCEPTRON_VALUE = "PERCEPTRON";
    /** Default tolerance for the early-stopping accuracy check. */
    public static final double TOLERANCE_DEFAULT = 1.0E-5d;

    /** With skipped averaging, every iteration below this number is still averaged. */
    private static final int SKIPPED_AVERAGING_WARMUP = 20;

    private int numUniqueEvents;
    private int numEvents;
    private int numPreds;
    private int numOutcomes;
    private int[][] contexts;
    private float[][] values;
    private int[] outcomeList;
    private int[] numTimesEventsSeen;
    private String[] outcomeLabels;
    private String[] predLabels;
    private boolean printMessages = true;
    private double tolerance = TOLERANCE_DEFAULT;
    /** Per-iteration step size decrease in percent, or {@code null} if disabled. */
    private Double stepSizeDecrease;
    private boolean useSkippedAveraging;

    /**
     * Returns whether this trainer can handle the configured algorithm.
     *
     * @return {@code true} if no algorithm is set or it equals {@link #PERCEPTRON_VALUE}
     */
    @Override
    public boolean isValid() {
        String algorithm = getAlgorithm();
        return algorithm == null || PERCEPTRON_VALUE.equals(algorithm);
    }

    /**
     * Returns whether the data indexer should sort and merge events.
     *
     * @return always {@code false} for this trainer
     */
    @Override
    public boolean isSortAndMerge() {
        return false;
    }

    /**
     * Reads the perceptron parameters ("UseAverage", "UseSkippedAveraging",
     * "StepSizeDecrease", "Tolerance"), applies them to this trainer and trains a model.
     *
     * @param dataIndexer the indexed training events
     * @return the trained perceptron model
     * @throws IllegalArgumentException if {@link #isValid()} is {@code false} or a parameter
     *     value is out of range
     * @throws IOException declared by the overridden method
     */
    @Override
    public AbstractModel doTrain(DataIndexer dataIndexer) throws IOException {
        if (!isValid()) {
            throw new IllegalArgumentException("trainParams are not valid!");
        }
        int iterations = getIterations();
        int cutoff = getCutoff();
        boolean useAverage = getBooleanParam("UseAverage", true);
        boolean skippedAveraging = getBooleanParam("UseSkippedAveraging", false);
        // Skipped averaging is a form of averaging, so it forces averaging on.
        if (skippedAveraging) {
            useAverage = true;
        }
        double decrease = getDoubleParam("StepSizeDecrease", 0.0d);
        double tol = getDoubleParam("Tolerance", TOLERANCE_DEFAULT);
        setSkippedAveraging(skippedAveraging);
        if (decrease > 0.0d) {
            setStepSizeDecrease(decrease);
        }
        setTolerance(tol);
        return trainModel(iterations, dataIndexer, cutoff, useAverage);
    }

    /**
     * Sets the early-stopping tolerance on the change in training-set accuracy.
     *
     * @param d the tolerance, must not be negative
     * @throws IllegalArgumentException if {@code d} is negative
     */
    public void setTolerance(double d) {
        if (d < 0.0d) {
            throw new IllegalArgumentException("tolerance must be a positive number but is " + d + "!");
        }
        this.tolerance = d;
    }

    /**
     * Enables the step size decrease. Before each iteration the step size is reduced by
     * {@code d} percent.
     *
     * @param d the decrease in percent, between 0 and 100 inclusive
     * @throws IllegalArgumentException if {@code d} is outside [0, 100]
     */
    public void setStepSizeDecrease(double d) {
        if (d < 0.0d || d > 100.0d) {
            throw new IllegalArgumentException("decrease must be between 0 and 100 but is " + d + "!");
        }
        this.stepSizeDecrease = Double.valueOf(d);
    }

    /**
     * Enables or disables skipped averaging. When enabled, parameters are only added to the
     * average for the first iterations and for iterations whose number is a perfect square.
     * Skipped averaging only takes effect when averaging itself is enabled.
     *
     * @param z {@code true} to enable skipped averaging
     */
    public void setSkippedAveraging(boolean z) {
        this.useSkippedAveraging = z;
    }

    /**
     * Trains a model with parameter averaging enabled.
     *
     * @param iterations the maximum number of training iterations
     * @param dataIndexer the indexed training events
     * @param cutoff accepted for API compatibility; not used by this method
     * @return the trained perceptron model
     */
    public AbstractModel trainModel(int iterations, DataIndexer dataIndexer, int cutoff) {
        return trainModel(iterations, dataIndexer, cutoff, true);
    }

    /**
     * Trains a perceptron model from the given indexed events.
     *
     * @param iterations the maximum number of training iterations
     * @param dataIndexer the indexed training events
     * @param cutoff accepted for API compatibility; not used by this method (presumably the
     *     indexer has already applied it — confirm)
     * @param useAverage {@code true} to return averaged rather than final weights
     * @return the trained perceptron model
     */
    public AbstractModel trainModel(int iterations, DataIndexer dataIndexer, int cutoff, boolean useAverage) {
        display("Incorporating indexed data for training...  \n");
        this.contexts = dataIndexer.getContexts();
        this.values = dataIndexer.getValues();
        this.numTimesEventsSeen = dataIndexer.getNumTimesEventsSeen();
        this.numEvents = dataIndexer.getNumEvents();
        this.numUniqueEvents = this.contexts.length;
        this.outcomeLabels = dataIndexer.getOutcomeLabels();
        this.outcomeList = dataIndexer.getOutcomeList();
        this.predLabels = dataIndexer.getPredLabels();
        this.numPreds = this.predLabels.length;
        this.numOutcomes = this.outcomeLabels.length;
        display("done.\n");
        display("\tNumber of Event Tokens: " + this.numUniqueEvents + "\n");
        display("\t    Number of Outcomes: " + this.numOutcomes + "\n");
        display("\t  Number of Predicates: " + this.numPreds + "\n");
        display("Computing model parameters...\n");
        MutableContext[] finalParameters = findParameters(iterations, useAverage);
        display("...done.\n");
        return new PerceptronModel(finalParameters, this.predLabels, this.outcomeLabels);
    }

    /**
     * Runs the perceptron training loop.
     *
     * @param iterations the maximum number of iterations
     * @param useAverage whether to return the average of the collected weights
     * @return one parameter context per predicate
     */
    private MutableContext[] findParameters(int iterations, boolean useAverage) {
        display("Performing " + iterations + " iterations.\n");

        // Every predicate has a weight for every outcome.
        int[] allOutcomesPattern = new int[numOutcomes];
        for (int oi = 0; oi < numOutcomes; oi++) {
            allOutcomesPattern[oi] = oi;
        }

        MutableContext[] params = newZeroParameters(allOutcomesPattern);
        EvalParameters evalParams = new EvalParameters(params, numOutcomes);
        MutableContext[] summedParams = useAverage ? newZeroParameters(allOutcomesPattern) : null;

        int numTimesSummed = 0;
        double prevAccuracy1 = 0.0d;
        double prevAccuracy2 = 0.0d;
        double prevAccuracy3 = 0.0d;
        double stepSize = 1.0d;

        for (int iteration = 1; iteration <= iterations; iteration++) {
            // The decrease is applied before every iteration, including the first one.
            if (stepSizeDecrease != null) {
                stepSize *= 1.0d - stepSizeDecrease.doubleValue() / 100.0d;
            }
            displayIteration(iteration);

            int numCorrect = trainOneIteration(params, evalParams, stepSize);
            double trainingAccuracy = accuracy(numCorrect);
            if (iteration < 10 || iteration % 10 == 0) {
                display(". (" + numCorrect + "/" + numEvents + ") " + trainingAccuracy + "\n");
            }

            if (useAverage && (!useSkippedAveraging
                    || iteration < SKIPPED_AVERAGING_WARMUP || isPerfectSquare(iteration))) {
                numTimesSummed++;
                addParameters(summedParams, params);
            }

            // Stop once accuracy has settled relative to the last three iterations.
            if (Math.abs(prevAccuracy1 - trainingAccuracy) < tolerance
                    && Math.abs(prevAccuracy2 - trainingAccuracy) < tolerance
                    && Math.abs(prevAccuracy3 - trainingAccuracy) < tolerance) {
                display("Stopping: change in training set accuracy less than " + tolerance + "\n");
                break;
            }
            prevAccuracy1 = prevAccuracy2;
            prevAccuracy2 = prevAccuracy3;
            prevAccuracy3 = trainingAccuracy;
        }

        trainingStats(evalParams);

        // With no summed iterations the average is undefined; params are still all zero then.
        if (!useAverage || numTimesSummed == 0) {
            return params;
        }
        for (int pi = 0; pi < numPreds; pi++) {
            double[] sums = summedParams[pi].getParameters();
            for (int oi = 0; oi < numOutcomes; oi++) {
                summedParams[pi].setParameter(oi, sums[oi] / numTimesSummed);
            }
        }
        return summedParams;
    }

    /** Creates one zero-initialized parameter context per predicate. */
    private MutableContext[] newZeroParameters(int[] outcomePattern) {
        MutableContext[] contextsForPreds = new MutableContext[numPreds];
        for (int pi = 0; pi < numPreds; pi++) {
            // Java arrays are zero-filled on allocation, so no explicit reset is needed.
            contextsForPreds[pi] = new MutableContext(outcomePattern, new double[numOutcomes]);
        }
        return contextsForPreds;
    }

    /**
     * Makes one pass over all event tokens and updates {@code params} on every misprediction.
     *
     * @return the number of event tokens predicted correctly during the pass
     */
    private int trainOneIteration(MutableContext[] params, EvalParameters evalParams, double stepSize) {
        int numCorrect = 0;
        for (int ei = 0; ei < numUniqueEvents; ei++) {
            int targetOutcome = outcomeList[ei];
            for (int ni = 0; ni < numTimesEventsSeen[ei]; ni++) {
                int predictedOutcome = predictOutcome(ei, evalParams);
                if (predictedOutcome == targetOutcome) {
                    numCorrect++;
                } else {
                    updateParameters(params, ei, targetOutcome, predictedOutcome, stepSize);
                }
            }
        }
        return numCorrect;
    }

    /** Returns the index of the highest-scoring outcome for the given event. */
    private int predictOutcome(int eventIndex, EvalParameters evalParams) {
        double[] scores = new double[numOutcomes];
        float[] eventValues = values != null ? values[eventIndex] : null;
        PerceptronModel.eval(contexts[eventIndex], eventValues, scores, evalParams, false);
        return maxIndex(scores);
    }

    /** Moves weight from the predicted outcome to the target outcome for each predicate. */
    private void updateParameters(MutableContext[] params, int eventIndex, int targetOutcome,
            int predictedOutcome, double stepSize) {
        int[] eventContext = contexts[eventIndex];
        for (int ci = 0; ci < eventContext.length; ci++) {
            int pred = eventContext[ci];
            double delta = values == null ? stepSize : stepSize * values[eventIndex][ci];
            params[pred].updateParameter(targetOutcome, delta);
            params[pred].updateParameter(predictedOutcome, -delta);
        }
    }

    /** Adds the current weights in {@code source} to the running sums in {@code target}. */
    private void addParameters(MutableContext[] target, MutableContext[] source) {
        for (int pi = 0; pi < numPreds; pi++) {
            double[] current = source[pi].getParameters();
            for (int oi = 0; oi < numOutcomes; oi++) {
                target[pi].updateParameter(oi, current[oi]);
            }
        }
    }

    /** Returns the fraction of event tokens predicted correctly, or 0 when there are none. */
    private double accuracy(int numCorrect) {
        // The cast is required: int / int would truncate the accuracy to 0 or 1.
        return numEvents == 0 ? 0.0d : (double) numCorrect / numEvents;
    }

    /**
     * Computes and prints the training-set accuracy of the given (non-averaged) parameters.
     *
     * @return the training-set accuracy
     */
    private double trainingStats(EvalParameters evalParameters) {
        int numCorrect = 0;
        for (int ei = 0; ei < numUniqueEvents; ei++) {
            int targetOutcome = outcomeList[ei];
            for (int ni = 0; ni < numTimesEventsSeen[ei]; ni++) {
                if (predictOutcome(ei, evalParameters) == targetOutcome) {
                    numCorrect++;
                }
            }
        }
        double trainingAccuracy = accuracy(numCorrect);
        display("Stats: (" + numCorrect + "/" + numEvents + ") " + trainingAccuracy + "\n");
        return trainingAccuracy;
    }

    /** Returns the index of the largest value; ties resolve to the lowest index. */
    private int maxIndex(double[] dArr) {
        int best = 0;
        for (int i = 1; i < dArr.length; i++) {
            if (dArr[i] > dArr[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Prints {@code str} to standard output when messages are enabled. */
    private void display(String str) {
        if (this.printMessages) {
            System.out.print(str);
        }
    }

    /** Prints a right-aligned iteration label for iterations up to 10 and every 10th after. */
    private void displayIteration(int i) {
        if (i > 10 && i % 10 != 0) {
            return;
        }
        if (i < 10) {
            display("  " + i + ":  ");
        } else if (i < 100) {
            display(" " + i + ":  ");
        } else {
            display(i + ":  ");
        }
    }

    /** Returns whether {@code i} is a perfect square. */
    private static boolean isPerfectSquare(int i) {
        int root = (int) Math.sqrt(i);
        return root * root == i;
    }
}
