/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.stats;

import com.aliasi.corpus.ObjectHandler;
import com.aliasi.io.Reporter;
import com.aliasi.io.Reporters;
import com.aliasi.matrix.DenseVector;
import com.aliasi.matrix.Matrices;
import com.aliasi.matrix.Vector;
import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.RegressionPrior;
import com.aliasi.stats.Statistics;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Compilable;
import com.aliasi.util.Math;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Formatter;
import java.util.IllegalFormatException;
import java.util.Locale;

/**
 * A multinomial logistic regression model over {@code K} outcomes, represented by
 * {@code K-1} weight vectors. The last outcome implicitly has an all-zero weight
 * vector (its linear score is always 0), so it acts as the reference category.
 *
 * <p>Instances wrap the supplied weight-vector array without copying it; the
 * {@link #estimate} methods rely on this to update weights in place while training.
 * Both {@link #compileTo(ObjectOutput)} and Java serialization go through the nested
 * {@code Externalizer} proxy, which writes only the dimensions and weight values.
 */
public class LogisticRegression
implements Compilable,
Serializable {
    static final long serialVersionUID = -8585743596322227589L;

    /** The {@code K-1} non-reference weight vectors (shared with the constructor argument, not copied). */
    private final Vector[] mWeightVectors;

    /**
     * Constructs a model from the given weight vectors, one per non-reference outcome.
     * The array is stored by reference.
     *
     * @param weightVectors the weight vectors; must be non-empty and all of one dimensionality
     * @throws IllegalArgumentException if the array is empty or dimensionalities differ
     */
    public LogisticRegression(Vector[] weightVectors) {
        if (weightVectors.length < 1) {
            String msg = "Require at least one weight vector.";
            throw new IllegalArgumentException(msg);
        }
        int numDimensions = weightVectors[0].numDimensions();
        for (int k = 1; k < weightVectors.length; ++k) {
            if (numDimensions != weightVectors[k].numDimensions()) {
                String msg = "All weight vectors must be same dimensionality. Found weightVectors[0].numDimensions()=" + numDimensions + " weightVectors[" + k + "]=" + weightVectors[k].numDimensions();
                throw new IllegalArgumentException(msg);
            }
        }
        this.mWeightVectors = weightVectors;
    }

    /**
     * Constructs a binary (two-outcome) model from a single weight vector.
     *
     * @param weightVector the weight vector for outcome 0; outcome 1 is the reference
     */
    public LogisticRegression(Vector weightVector) {
        this.mWeightVectors = new Vector[]{weightVector};
    }

    /** Returns the dimensionality of input vectors accepted by this model. */
    public int numInputDimensions() {
        return this.mWeightVectors[0].numDimensions();
    }

    /** Returns the number of outcomes, which is one more than the number of weight vectors. */
    public int numOutcomes() {
        return this.mWeightVectors.length + 1;
    }

    /**
     * Returns unmodifiable views of this model's weight vectors. The returned array is
     * fresh, but the views reflect any later in-place changes to the underlying vectors.
     */
    public Vector[] weightVectors() {
        Vector[] immutables = new Vector[this.mWeightVectors.length];
        for (int k = 0; k < immutables.length; ++k) {
            immutables[k] = Matrices.unmodifiableVector(this.mWeightVectors[k]);
        }
        return immutables;
    }

    /**
     * Returns the conditional probability of each outcome given the input.
     *
     * @param x the input vector
     * @return a new array of length {@link #numOutcomes()} whose entries sum to 1
     * @throws IllegalArgumentException if {@code x} has the wrong dimensionality
     */
    public double[] classify(Vector x) {
        double[] ysHat = new double[this.numOutcomes()];
        this.classify(x, ysHat);
        return ysHat;
    }

    /**
     * Writes the conditional probability of each outcome given the input into
     * {@code ysHat}, computed as a softmax over the linear scores.
     *
     * @param x the input vector
     * @param ysHat output array; must have length {@link #numOutcomes()}
     * @throws IllegalArgumentException if {@code x} has the wrong dimensionality or
     *     {@code ysHat} has the wrong length
     */
    public void classify(Vector x, double[] ysHat) {
        if (this.numInputDimensions() != x.numDimensions()) {
            String msg = "Vector and classifer must be of same dimensionality. Regression model this.numInputDimensions()=" + this.numInputDimensions() + " Vector x.numDimensions()=" + x.numDimensions();
            throw new IllegalArgumentException(msg);
        }
        if (ysHat.length != this.numOutcomes()) {
            // Deriving the outcome count from ysHat.length would read past mWeightVectors
            // (too long) or silently drop outcomes from the normalization (too short).
            String msg = "Output array must have one entry per outcome. Found ysHat.length=" + ysHat.length + " numOutcomes()=" + this.numOutcomes();
            throw new IllegalArgumentException(msg);
        }
        int numWeightVectors = this.mWeightVectors.length;
        // Reference outcome: implicit zero weight vector, so its score is exactly 0.
        ysHat[numWeightVectors] = 0.0;
        // Starting max at 0 accounts for the reference score; subtracting the max
        // before exponentiating keeps exp() from overflowing.
        double max = 0.0;
        for (int k = 0; k < numWeightVectors; ++k) {
            ysHat[k] = x.dotProduct(this.mWeightVectors[k]);
            if (ysHat[k] > max) {
                max = ysHat[k];
            }
        }
        double z = 0.0;
        for (int k = 0; k < ysHat.length; ++k) {
            ysHat[k] = java.lang.Math.exp(ysHat[k] - max);
            z += ysHat[k];
        }
        for (int k = 0; k < ysHat.length; ++k) {
            ysHat[k] /= z;
        }
    }

    @Override
    public void compileTo(ObjectOutput out) throws IOException {
        out.writeObject(new Externalizer(this));
    }

    /** Serialization proxy hook: serializes this model via its {@code Externalizer}. */
    Object writeReplace() {
        return new Externalizer(this);
    }

    /**
     * Estimates a model with default settings: no hot start, no per-epoch handler, a
     * rolling average over 10 epochs, and a block size of {@code max(1, cs.length / 50)}.
     *
     * @see #estimate(Vector[], int[], RegressionPrior, int, LogisticRegression,
     *     AnnealingSchedule, double, int, int, int, ObjectHandler, Reporter)
     */
    public static LogisticRegression estimate(Vector[] xs, int[] cs, RegressionPrior prior, AnnealingSchedule annealingSchedule, Reporter reporter, double minImprovement, int minEpochs, int maxEpochs) {
        LogisticRegression hotStart = null;
        ObjectHandler<LogisticRegression> handler = null;
        int rollingAverageSize = 10;
        int blockSize = java.lang.Math.max(1, cs.length / 50);
        return LogisticRegression.estimate(xs, cs, prior, blockSize, hotStart, annealingSchedule, minImprovement, rollingAverageSize, minEpochs, maxEpochs, handler, reporter);
    }

    /**
     * Estimates a model by block stochastic gradient descent on the log likelihood,
     * with optional prior-based regularization.
     *
     * @param xs training inputs; all must share one dimensionality
     * @param cs training outcomes, parallel to {@code xs}; non-negative, and the
     *     largest determines the number of outcomes (must be at least 1)
     * @param prior regression prior, or {@code null} for no prior (contributes 0 to the
     *     reported log prior)
     * @param blockSize number of instances per gradient block; must be at least 1
     * @param hotStart model whose weights initialize training, or {@code null} for zeros;
     *     must match the data's number of outcomes and dimensionality
     * @param annealingSchedule supplies learning rates and may reject epochs
     * @param minImprovement {@code Double.NaN} to disable convergence monitoring,
     *     otherwise the rolling-average relative improvement below which training stops
     * @param rollingAverageSize number of epochs in the rolling average; must be at least
     *     1 when convergence is monitored
     * @param minEpochs minimum number of completed epochs before convergence may stop training
     * @param maxEpochs maximum number of epochs to run
     * @param handler receives the live model after every epoch (its weights keep
     *     changing in later epochs), or {@code null}
     * @param reporter progress reporter, or {@code null} for silent
     * @return the trained model
     * @throws IllegalArgumentException if any argument is invalid as described above
     */
    public static LogisticRegression estimate(Vector[] xs, int[] cs, RegressionPrior prior, int blockSize, LogisticRegression hotStart, AnnealingSchedule annealingSchedule, double minImprovement, int rollingAverageSize, int minEpochs, int maxEpochs, ObjectHandler<LogisticRegression> handler, Reporter reporter) {
        if (reporter == null) {
            reporter = Reporters.silent();
        }
        reporter.info("Logistic Regression Estimation");
        boolean monitoringConvergence = !Double.isNaN(minImprovement);
        reporter.info("Monitoring convergence=" + monitoringConvergence);
        if (minImprovement < 0.0) {
            throw fatal(reporter, "Min improvement should be Double.NaN to turn off convergence or >= 0.0 otherwise. Found minImprovement=" + minImprovement);
        }
        if (xs.length < 1) {
            throw fatal(reporter, "Require at least one training instance.");
        }
        if (xs.length != cs.length) {
            throw fatal(reporter, "Require same number of training instances as outcomes. Found xs.length=" + xs.length + " cs.length=" + cs.length);
        }
        if (blockSize < 1) {
            throw fatal(reporter, "Block size must be at least 1. Found blockSize=" + blockSize);
        }
        if (monitoringConvergence && rollingAverageSize < 1) {
            throw fatal(reporter, "Rolling average size must be at least 1 when monitoring convergence. Found rollingAverageSize=" + rollingAverageSize);
        }
        for (int j = 0; j < cs.length; ++j) {
            if (cs[j] < 0) {
                throw fatal(reporter, "Outcomes must be non-negative. Found cs[" + j + "]=" + cs[j]);
            }
        }
        int numTrainingInstances = xs.length;
        int numOutcomesMinus1 = Math.max(cs);
        if (numOutcomesMinus1 < 1) {
            throw fatal(reporter, "Require at least two outcomes; the largest outcome in cs must be at least 1. Found max(cs)=" + numOutcomesMinus1);
        }
        int numOutcomes = numOutcomesMinus1 + 1;
        int numDimensions = xs[0].numDimensions();
        if (prior != null) {
            prior.verifyNumberOfDimensions(numDimensions);
        }
        for (int i = 1; i < xs.length; ++i) {
            if (xs[i].numDimensions() != numDimensions) {
                throw fatal(reporter, "Number of dimensions must match for all input vectors. Found xs[0].numDimensions()=" + numDimensions + " xs[" + i + "].numDimensions()=" + xs[i].numDimensions());
            }
        }
        if (hotStart != null && (hotStart.numOutcomes() != numOutcomes || hotStart.numInputDimensions() != numDimensions)) {
            throw fatal(reporter, "Hot start model must match the training data. Found hotStart.numOutcomes()=" + hotStart.numOutcomes() + " numOutcomes=" + numOutcomes + " hotStart.numInputDimensions()=" + hotStart.numInputDimensions() + " numDimensions=" + numDimensions);
        }

        DenseVector[] weightVectors = initialWeightVectors(hotStart, numOutcomesMinus1, numDimensions);
        // The model shares weightVectors, so in-place updates are visible through it.
        LogisticRegression regression = new LogisticRegression(weightVectors);
        boolean hasPrior = prior != null && !prior.isUniform();
        reporter.info("Number of dimensions=" + numDimensions);
        reporter.info("Number of Outcomes=" + numOutcomes);
        reporter.info("Number of Parameters=" + (long)(numOutcomes - 1) * (long)numDimensions);
        reporter.info("Number of Training Instances=" + cs.length);
        reporter.info("Prior=" + prior);
        reporter.info("Annealing Schedule=" + annealingSchedule);
        reporter.info("Minimum Epochs=" + minEpochs);
        reporter.info("Maximum Epochs=" + maxEpochs);
        reporter.info("Minimum Improvement Per Period=" + minImprovement);
        reporter.info("Has Informative Prior=" + hasPrior);

        // Very negative but finite, so the first relative difference is well defined.
        double lastLog2LikelihoodAndPrior = -Double.MAX_VALUE / 2.0;
        double[] rollingAbsDiffs = new double[monitoringConvergence ? rollingAverageSize : 0];
        // +Infinity entries prevent convergence until the window has filled.
        Arrays.fill(rollingAbsDiffs, Double.POSITIVE_INFINITY);
        int rollingAveragePosition = 0;
        double bestLog2LikelihoodAndPrior = Double.NEGATIVE_INFINITY;
        int partialBlockSize = numTrainingInstances % blockSize;
        int numFullBlocks = numTrainingInstances / blockSize;
        double[][] blockCondProbs = new double[blockSize][numOutcomes];

        for (int epoch = 0; epoch < maxEpochs; ++epoch) {
            // Snapshot the weights so a rejected epoch can be rolled back.
            DenseVector[] weightVectorCopies = annealingSchedule.allowsRejection() ? LogisticRegression.copy(weightVectors) : weightVectors;
            double learningRate = annealingSchedule.learningRate(epoch);
            for (int b = 0; b < numFullBlocks; ++b) {
                LogisticRegression.adjustBlock(b * blockSize, (b + 1) * blockSize, xs, cs, weightVectors, learningRate, prior, blockCondProbs, regression);
                if (reporter.isDebugEnabled()) {
                    reporter.debug("          epoch " + epoch + " " + (int)(100.0 * (double)(b + 1) * (double)blockSize / (double)numTrainingInstances) + "% complete");
                }
            }
            if (partialBlockSize > 0) {
                LogisticRegression.adjustBlock(numFullBlocks * blockSize, numTrainingInstances, xs, cs, weightVectors, learningRate, prior, blockCondProbs, regression);
            }
            if (handler != null) {
                reporter.debug("handling regression for epoch");
                handler.handle(regression);
            }
            if (!monitoringConvergence) {
                reporter.info("Unmonitored Epoch=" + epoch);
                continue;
            }

            reporter.debug("computing log likelihood");
            double log2Likelihood = LogisticRegression.log2Likelihood(xs, cs, regression);
            double log2Prior = prior == null ? 0.0 : prior.log2Prior(weightVectors);
            double log2LikelihoodAndPrior = log2Likelihood + log2Prior;
            if (log2LikelihoodAndPrior > bestLog2LikelihoodAndPrior) {
                bestLog2LikelihoodAndPrior = log2LikelihoodAndPrior;
            }
            LogisticRegression.reportEpoch(reporter, epoch, learningRate, log2Likelihood, log2Prior, log2LikelihoodAndPrior, bestLog2LikelihoodAndPrior);

            if (!annealingSchedule.receivedError(epoch, learningRate, -log2LikelihoodAndPrior)) {
                reporter.info("Annealing rejected update at learningRate=" + learningRate + " error=" + -log2LikelihoodAndPrior);
                weightVectors = weightVectorCopies;
                regression = new LogisticRegression(weightVectors);
                continue;
            }

            double relativeAbsDiff = Math.relativeAbsoluteDifference(lastLog2LikelihoodAndPrior, log2LikelihoodAndPrior);
            rollingAbsDiffs[rollingAveragePosition] = relativeAbsDiff;
            rollingAveragePosition = (rollingAveragePosition + 1) % rollingAbsDiffs.length;
            double rollingAvgAbsDiff = Statistics.mean(rollingAbsDiffs);
            reporter.debug("relativeAbsDiff=" + relativeAbsDiff + " rollingAvg=" + rollingAvgAbsDiff);
            lastLog2LikelihoodAndPrior = log2LikelihoodAndPrior;
            // epoch is 0-based, so epoch + 1 epochs have completed at this point.
            if (epoch + 1 >= minEpochs && rollingAvgAbsDiff < minImprovement) {
                reporter.info("Converged with Rolling Average Absolute Difference=" + rollingAvgAbsDiff);
                break;
            }
        }
        return regression;
    }

    /** Reports a fatal message and returns (for the caller to throw) a matching exception. */
    private static IllegalArgumentException fatal(Reporter reporter, String msg) {
        reporter.fatal(msg);
        return new IllegalArgumentException(msg);
    }

    /** Builds fresh, mutable starting weights: copies of the hot start's, or all zeros. */
    private static DenseVector[] initialWeightVectors(LogisticRegression hotStart, int numWeightVectors, int numDimensions) {
        DenseVector[] weightVectors = new DenseVector[numWeightVectors];
        for (int k = 0; k < numWeightVectors; ++k) {
            weightVectors[k] = hotStart == null
                ? new DenseVector(numDimensions)
                : new DenseVector(hotStart.mWeightVectors[k]);
        }
        return weightVectors;
    }

    /** Logs one formatted line of per-epoch statistics at info level. */
    private static void reportEpoch(Reporter reporter, int epoch, double learningRate, double log2Likelihood, double log2Prior, double log2LikelihoodAndPrior, double bestLog2LikelihoodAndPrior) {
        if (!reporter.isInfoEnabled()) {
            return;
        }
        try (Formatter formatter = new Formatter(Locale.ENGLISH)) {
            formatter.format("epoch=%5d lr=%11.9f ll=%11.4f lp=%11.4f llp=%11.4f llp*=%11.4f", epoch, learningRate, log2Likelihood, log2Prior, log2LikelihoodAndPrior, bestLog2LikelihoodAndPrior);
            reporter.info(formatter.toString());
        } catch (IllegalFormatException e) {
            reporter.warn("Illegal format in Logistic Regression");
        }
    }

    /**
     * Returns the sum over instances of the base-2 log of the probability the model
     * assigns to each instance's category.
     *
     * @param inputs the input vectors
     * @param cats the categories, parallel to {@code inputs}
     * @param regression the model to evaluate
     * @return the total log2 likelihood (at most 0)
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public static double log2Likelihood(Vector[] inputs, int[] cats, LogisticRegression regression) {
        if (inputs.length != cats.length) {
            String msg = "Inputs and categories must be same length. Found inputs.length=" + inputs.length + " cats.length=" + cats.length;
            throw new IllegalArgumentException(msg);
        }
        double log2Likelihood = 0.0;
        double[] conditionalProbs = new double[regression.numOutcomes()];
        for (int j = 0; j < inputs.length; ++j) {
            regression.classify(inputs[j], conditionalProbs);
            log2Likelihood += Math.log2(conditionalProbs[cats[j]]);
        }
        return log2Likelihood;
    }

    /**
     * Applies one block of gradient updates for instances {@code [start, end)}. All
     * probabilities are computed before any weight changes, then the prior step is
     * applied with a learning rate scaled by the block's share of the data.
     */
    private static void adjustBlock(int start, int end, Vector[] xs, int[] cs, DenseVector[] weightVectors, double learningRate, RegressionPrior prior, double[][] conditionalProbs, LogisticRegression regression) {
        for (int j = start; j < end; ++j) {
            regression.classify(xs[j], conditionalProbs[j - start]);
        }
        for (int j = start; j < end; ++j) {
            for (int k = 0; k < weightVectors.length; ++k) {
                LogisticRegression.adjustWeightsWithConditionalProbs(weightVectors[k], conditionalProbs[j - start][k], learningRate, xs[j], k, cs[j]);
            }
        }
        if (prior != null && !prior.isUniform()) {
            LogisticRegression.adjustWeightsWithPrior(weightVectors, prior, learningRate * (double)(end - start) / (double)xs.length);
        }
    }

    /**
     * Moves each weight against the prior's gradient, clipping at the prior's mode so a
     * single step never overshoots it. Weights already at the mode are left unchanged.
     */
    private static void adjustWeightsWithPrior(DenseVector[] weightVectors, RegressionPrior prior, double learningRate) {
        for (DenseVector weightVector : weightVectors) {
            int numDimensions = weightVector.numDimensions();
            for (int i = 0; i < numDimensions; ++i) {
                double weight = weightVector.value(i);
                double mode = prior.mode(i);
                if (weight == mode) {
                    continue;
                }
                double delta = prior.gradient(weight, i) * learningRate;
                if (delta == 0.0) {
                    continue;
                }
                double adjusted = weight - delta;
                boolean crossedMode = weight > mode ? adjusted < mode : adjusted > mode;
                if (crossedMode) {
                    adjusted = mode;
                }
                weightVector.setValue(i, adjusted);
            }
        }
    }

    /**
     * Applies the likelihood gradient step for one instance to weight vector {@code k}:
     * {@code w_k -= learningRate * (p(k|x) - [k == c]) * x}.
     */
    private static void adjustWeightsWithConditionalProbs(DenseVector weightVectorsK, double conditionalProb, double learningRate, Vector xsJ, int k, int csJ) {
        double conditionalProbMinusTruth = k == csJ ? conditionalProb - 1.0 : conditionalProb;
        if (conditionalProbMinusTruth == 0.0) {
            return;
        }
        weightVectorsK.increment(-learningRate * conditionalProbMinusTruth, xsJ);
    }

    /** Returns deep copies of the given vectors. */
    private static DenseVector[] copy(DenseVector[] xs) {
        DenseVector[] result = new DenseVector[xs.length];
        for (int k = 0; k < xs.length; ++k) {
            result[k] = new DenseVector(xs[k]);
        }
        return result;
    }

    /**
     * Serialization proxy. The stream holds {@code numOutcomes}, then {@code numDimensions},
     * then each non-reference weight vector's values in order.
     */
    static class Externalizer
    extends AbstractExternalizable {
        static final long serialVersionUID = -2256261505231943102L;
        final LogisticRegression mRegression;

        /** No-arg constructor required for deserialization. */
        public Externalizer() {
            this(null);
        }

        public Externalizer(LogisticRegression regression) {
            this.mRegression = regression;
        }

        @Override
        public void writeExternal(ObjectOutput out) throws IOException {
            Vector[] weightVectors = this.mRegression.mWeightVectors;
            int numDimensions = weightVectors[0].numDimensions();
            out.writeInt(weightVectors.length + 1);
            out.writeInt(numDimensions);
            for (Vector weightVector : weightVectors) {
                for (int i = 0; i < numDimensions; ++i) {
                    out.writeDouble(weightVector.value(i));
                }
            }
        }

        @Override
        public Object read(ObjectInput in) throws IOException {
            int numOutcomes = in.readInt();
            int numDimensions = in.readInt();
            if (numOutcomes < 2 || numDimensions < 0) {
                // Guard against corrupt input before allocating arrays from its sizes.
                throw new IOException("Corrupt LogisticRegression stream. Found numOutcomes=" + numOutcomes + " numDimensions=" + numDimensions);
            }
            Vector[] weightVectors = new Vector[numOutcomes - 1];
            for (int c = 0; c < weightVectors.length; ++c) {
                DenseVector weightVectorsC = new DenseVector(numDimensions);
                weightVectors[c] = weightVectorsC;
                for (int i = 0; i < numDimensions; ++i) {
                    weightVectorsC.setValue(i, in.readDouble());
                }
            }
            return new LogisticRegression(weightVectors);
        }
    }
}

