package hex;

import au.com.bytecode.opencsv.CSVWriter;
import hex.ModelMetricsSupervised;
import hex.genmodel.GenModel;
import java.util.Arrays;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import water.Key;
import water.MRTask;
import water.Scope;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.MathUtils;
import water.util.TwoDimTable;

/* loaded from: input_file:hex/ModelMetricsOrdinal.class */
public class ModelMetricsOrdinal extends ModelMetricsSupervised {
    /** Hit ratios as handed to the constructor; also returned by {@link #hr()}. */
    public final float[] _hit_ratios;
    /** Confusion matrix as handed to the constructor; may be {@code null}. Also returned by {@link #cm()}. */
    public final ConfusionMatrix _cm;
    /** Log loss as handed to the constructor; also returned by {@link #logloss()}. */
    public final double _logloss;
    /** Mean per-class error taken from {@link #_cm}; {@code NaN} when the matrix is {@code null} or reports itself as too large. */
    public final double _mean_per_class_error;

    /* loaded from: input_file:hex/ModelMetricsOrdinal$MetricBuilderOrdinal.class */
    public static class MetricBuilderOrdinal<T extends MetricBuilderOrdinal<T>> extends ModelMetricsSupervised.MetricBuilderSupervised<T> {
        double[][] _cm;
        double[] _hits;
        int _K;
        double _logloss;
        public transient double[] _priorDistribution;
        static final /* synthetic */ boolean $assertionsDisabled;

        public MetricBuilderOrdinal(int i, String[] strArr) {
            super(i, strArr);
            this._cm = strArr.length > ConfusionMatrix.maxClasses() ? (double[][]) null : new double[strArr.length][strArr.length];
            this._K = Math.min(10, this._nclasses);
            this._hits = new double[this._K];
        }

        @Override // hex.ModelMetrics.MetricBuilder
        public double[] perRow(double[] dArr, float[] fArr, Model model) {
            return perRow(dArr, fArr, 1.0d, CMAESOptimizer.DEFAULT_STOPFITNESS, model);
        }

        @Override // hex.ModelMetrics.MetricBuilder
        public double[] perRow(double[] dArr, float[] fArr, double d, double d2, Model model) {
            if (this._cm != null && !Float.isNaN(fArr[0]) && !ArrayUtils.hasNaNs(dArr)) {
                if (d == CMAESOptimizer.DEFAULT_STOPFITNESS || Double.isNaN(d)) {
                    return dArr;
                }
                int i = (int) fArr[0];
                this._count++;
                this._wcount += d;
                this._wY += d * i;
                this._wYY += d * i * i;
                double d3 = i + 1 < dArr.length ? 1.0d - dArr[i + 1] : 1.0d;
                this._sumsqe += d * d3 * d3;
                if (!$assertionsDisabled && Double.isNaN(this._sumsqe)) {
                    throw new AssertionError();
                }
                double[] dArr2 = this._cm[i];
                int i2 = (int) dArr[0];
                dArr2[i2] = dArr2[i2] + 1.0d;
                if (this._K > 0 && i < dArr.length - 1) {
                    ModelMetricsOrdinal.updateHits(d, i, dArr, this._hits, model != null ? model._output._priorClassDist : this._priorDistribution);
                }
                this._logloss += d * MathUtils.logloss(d3);
                return dArr;
            }
            return dArr;
        }

        @Override // hex.ModelMetrics.MetricBuilder
        public void reduce(T t) {
            if (this._cm == null) {
                return;
            }
            super.reduce((MetricBuilderOrdinal<T>) t);
            if (!$assertionsDisabled && t._K != this._K) {
                throw new AssertionError();
            }
            ArrayUtils.add(this._cm, t._cm);
            this._hits = ArrayUtils.add(this._hits, t._hits);
            this._logloss += t._logloss;
        }

        @Override // hex.ModelMetrics.MetricBuilder
        public ModelMetrics makeModelMetrics(Model model, Frame frame, Frame frame2, Frame frame3) {
            double d = Double.NaN;
            double d2 = Double.NaN;
            float[] fArr = new float[this._K];
            ConfusionMatrix confusionMatrix = new ConfusionMatrix(this._cm, this._domain);
            double weightedSigma = weightedSigma();
            if (this._wcount > CMAESOptimizer.DEFAULT_STOPFITNESS) {
                if (this._hits != null) {
                    for (int i = 0; i < fArr.length; i++) {
                        fArr[i] = (float) (this._hits[i] / this._wcount);
                    }
                    for (int i2 = 1; i2 < fArr.length; i2++) {
                        int i3 = i2;
                        fArr[i3] = fArr[i3] + fArr[i2 - 1];
                    }
                }
                d = this._sumsqe / this._wcount;
                d2 = this._logloss / this._wcount;
            }
            ModelMetricsOrdinal modelMetricsOrdinal = new ModelMetricsOrdinal(model, frame, this._count, d, this._domain, weightedSigma, confusionMatrix, fArr, d2, this._customMetric);
            if (model != null) {
                model.addModelMetrics(modelMetricsOrdinal);
            }
            return modelMetricsOrdinal;
        }

        static {
            $assertionsDisabled = !ModelMetricsOrdinal.class.desiredAssertionStatus();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:hex/ModelMetricsOrdinal$OrdinalMetrics.class */
    /**
     * Distributed task that feeds every row of a frame through a
     * {@link MetricBuilderOrdinal}. The last column of the frame is read as the
     * actual label; all columns before it are copied into the prediction row.
     */
    public static class OrdinalMetrics extends MRTask<OrdinalMetrics> {
        /** Class domain used to size the metric builder. */
        String[] domain;
        /** Builder created in {@link #map} and merged in {@link #reduce}. */
        private MetricBuilderOrdinal _mb;

        public OrdinalMetrics(String[] strArr) {
            domain = strArr;
        }

        /**
         * Creates a fresh builder for this chunk set and pushes each row into it
         * with a {@code null} model.
         */
        @Override // water.MRTask
        public void map(Chunk[] chunkArr) {
            _mb = new MetricBuilderOrdinal(domain.length, domain);
            final int labelCol = chunkArr.length - 1;
            final Chunk labels = chunkArr[labelCol];
            // Slot 0 holds the predicted class, slots 1..labelCol the column values.
            final double[] preds = new double[chunkArr.length];
            final int rowCount = chunkArr[0]._len;
            for (int row = 0; row < rowCount; row++) {
                for (int col = 0; col < labelCol; col++) {
                    preds[col + 1] = chunkArr[col].atd(row);
                }
                preds[0] = GenModel.getPrediction(preds, null, preds, 0.5d);
                final float[] actual = {(float) labels.at8(row)};
                _mb.perRow(preds, actual, null);
            }
        }

        /** Merges the other task's builder into this task's builder. */
        @Override // water.MRTask
        public void reduce(OrdinalMetrics ordinalMetrics) {
            _mb.reduce(ordinalMetrics._mb);
        }
    }

    /**
     * Creates ordinal metrics from precomputed values.
     *
     * @param model           forwarded to the superclass
     * @param frame           forwarded to the superclass
     * @param j               forwarded to the superclass
     * @param d               forwarded to the superclass
     * @param strArr          forwarded to the superclass
     * @param d2              forwarded to the superclass
     * @param confusionMatrix stored as {@link #_cm}; may be {@code null}
     * @param fArr            stored as {@link #_hit_ratios}
     * @param d3              stored as {@link #_logloss}
     * @param customMetric    forwarded to the superclass
     */
    public ModelMetricsOrdinal(Model model, Frame frame, long j, double d, String[] strArr, double d2, ConfusionMatrix confusionMatrix, float[] fArr, double d3, CustomMetric customMetric) {
        super(model, frame, j, d, strArr, d2, customMetric);
        _cm = confusionMatrix;
        _hit_ratios = fArr;
        _logloss = d3;
        // Only a present, not-too-large matrix can supply a per-class error.
        if (confusionMatrix != null && !confusionMatrix.tooLarge()) {
            _mean_per_class_error = confusionMatrix.mean_per_class_error();
        } else {
            _mean_per_class_error = Double.NaN;
        }
    }

    /**
     * Renders the superclass summary followed by log loss, mean per-class error,
     * hit ratios and, when a confusion matrix is present, either its ASCII form
     * (up to 20 classes) or a note that it is too large to print.
     */
    @Override // hex.ModelMetricsSupervised, hex.ModelMetrics
    public String toString() {
        final String eol = CSVWriter.DEFAULT_LINE_END;
        final StringBuilder out = new StringBuilder();
        out.append(super.toString());
        out.append(" logloss: ").append((float) _logloss).append(eol);
        out.append(" mean_per_class_error: ").append((float) _mean_per_class_error).append(eol);
        out.append(" hit ratios: ").append(Arrays.toString(_hit_ratios)).append(eol);
        if (cm() == null) {
            return out.toString();
        }
        if (cm().nclasses() > 20) {
            out.append(" CM: too large to print.\n");
        } else {
            out.append(" CM: ").append(cm().toASCII());
        }
        return out.toString();
    }

    /** @return the log loss held in {@link #_logloss} */
    public double logloss() {
        return _logloss;
    }

    /** @return the value held in {@link #_mean_per_class_error}; may be {@code NaN} */
    public double mean_per_class_error() {
        return _mean_per_class_error;
    }

    /** @return the confusion matrix held in {@link #_cm}; may be {@code null} */
    @Override // hex.ModelMetrics
    public ConfusionMatrix cm() {
        return _cm;
    }

    /** @return the hit-ratio array held in {@link #_hit_ratios} (the array itself, not a copy) */
    @Override // hex.ModelMetrics
    public float[] hr() {
        return _hit_ratios;
    }

    /**
     * Looks up the metrics stored for a model/frame pair and returns them as
     * ordinal metrics.
     *
     * @param model model whose metrics are wanted
     * @param frame frame the metrics were computed on
     * @return the stored metrics, cast to {@link ModelMetricsOrdinal}
     * @throws H2OIllegalArgumentException if the lookup yields anything other
     *         than a {@link ModelMetricsOrdinal}, including nothing at all
     */
    public static ModelMetricsOrdinal getFromDKV(Model model, Frame frame) {
        ModelMetrics fromDKV = ModelMetrics.getFromDKV(model, frame);
        if (fromDKV instanceof ModelMetricsOrdinal) {
            return (ModelMetricsOrdinal) fromDKV;
        }
        final String where = "model: " + model._key.toString() + " and frame: " + frame._key.toString();
        // instanceof is also false for null, so guard the getClass() call:
        // a failed lookup must surface as this exception, not as an NPE.
        final String found = fromDKV == null ? "null" : String.valueOf(fromDKV.getClass());
        throw new H2OIllegalArgumentException("Expected to find an Ordinal ModelMetrics for " + where, "Expected to find a ModelMetricsOrdinal for " + where + " but found a: " + found);
    }

    /**
     * Convenience overload of the five-argument {@code updateHits} that passes
     * {@code null} as the prior class distribution.
     *
     * @param d     row weight
     * @param i     actual class index
     * @param dArr  prediction row
     * @param dArr2 hit accumulators, updated in place
     */
    public static void updateHits(double d, int i, double[] dArr, double[] dArr2) {
        final double[] noPrior = null;
        updateHits(d, i, dArr, dArr2, noPrior);
    }

    /**
     * Credits one row to the hit bucket of the rank at which its actual class
     * shows up among the predictions.
     *
     * <p>{@code dArr[0]} is read as the top predicted class index and
     * {@code dArr[1 + c]} as the value for class {@code c}. If the top
     * prediction is the actual class, rank 0 is credited. Otherwise the values
     * are copied, the already-used classes are zeroed one at a time and
     * {@code GenModel.getPrediction} is asked for the next-best class until the
     * actual class is found or the ranks run out. When there are as many ranks
     * as value columns and no bucket changed, the last rank is credited.
     *
     * <p>Every bucket, including rank 0, is credited with the row weight
     * {@code d}: the accumulated hits are later divided by the summed weights,
     * so crediting rank 0 with a plain {@code 1} would distort the top-1 ratio
     * whenever weights differ from one.
     *
     * @param d     row weight added to the matching bucket
     * @param i     actual class index
     * @param dArr  prediction row; not modified
     * @param dArr2 hit accumulators, updated in place
     * @param dArr3 prior class distribution handed to {@code GenModel.getPrediction}; may be {@code null}
     */
    public static void updateHits(double d, int i, double[] dArr, double[] dArr2, double[] dArr3) {
        if (i == dArr[0]) {
            dArr2[0] += d;
            return;
        }
        final double before = ArrayUtils.sum(dArr2);
        // Work on a copy so the caller's prediction row stays intact.
        final double[] remaining = Arrays.copyOf(dArr, dArr.length);
        remaining[1 + (int) dArr[0]] = 0.0d;
        for (int rank = 1; rank < dArr2.length; rank++) {
            final int next = GenModel.getPrediction(remaining, dArr3, dArr, 0.5d);
            // Knock out this class so the next round yields the next-best one.
            remaining[1 + next] = 0.0d;
            if (next == i) {
                dArr2[rank] += d;
                break;
            }
        }
        // With one rank per value column a hit must exist; fall back to the worst rank.
        if (dArr2.length == dArr.length - 1 && ArrayUtils.sum(dArr2) == before) {
            dArr2[dArr2.length - 1] += d;
        }
    }

    /**
     * Builds a one-column table of hit ratios with one row per entry of
     * {@code fArr}; rows are labelled {@code "1"}, {@code "2"}, ... and the
     * single column is headed {@code "Hit Ratio"}.
     *
     * @param fArr hit ratios to tabulate
     * @return the populated table
     */
    public static TwoDimTable getHitRatioTable(float[] fArr) {
        final int ranks = fArr.length;
        final String[] rowLabels = new String[ranks];
        for (int k = 0; k < ranks; k++) {
            rowLabels[k] = Integer.toString(k + 1);
        }
        final String title = "Top-" + ranks + " Hit Ratios";
        final String[] colHeaders = {"Hit Ratio"};
        final String[] colTypes = {"float"};
        final String[] colFormats = {"%f"};
        final TwoDimTable table = new TwoDimTable(title, null, rowLabels, colHeaders, colTypes, colFormats, "K");
        for (int k = 0; k < ranks; k++) {
            table.set(k, 0, Float.valueOf(fArr[k]));
        }
        return table;
    }

    /**
     * Computes ordinal metrics using the frame's own column names as the class
     * domain.
     *
     * @param frame per-class prediction columns
     * @param vec   actual labels
     * @return the computed metrics
     * @throws IllegalArgumentException if the column names and the label domain
     *         share no value at all
     */
    public static ModelMetricsOrdinal make(Frame frame, Vec vec) {
        final String[] columnNames = frame.names();
        final String[] labelDomain = vec.domain();
        // A union as large as both inputs together means nothing overlapped.
        final int unionSize = ArrayUtils.union(columnNames, labelDomain, true).length;
        final boolean nothingInCommon = unionSize == columnNames.length + labelDomain.length;
        if (nothingInCommon) {
            throw new IllegalArgumentException("Column names of per-class-probabilities and categorical domain of actual labels have no common values!");
        }
        return make(frame, vec, frame.names());
    }

    /**
     * Computes ordinal metrics from user-supplied predictions and labels.
     *
     * <p>The labels are converted to a categorical vec and adapted to
     * {@code strArr}, appended to a copy of {@code frame} as column
     * {@code "labels"}, and the combined frame is run through
     * {@link OrdinalMetrics}. The {@code Scope} entered here is always exited,
     * and the adapted label vec is always removed, even when validation or the
     * task fails.
     *
     * @param frame  prediction columns; every column must be numeric with values in {@code [0, 1]}
     * @param vec    actual labels; must have as many rows as {@code frame}
     * @param strArr class domain; must have one entry per column of {@code frame}
     * @return the computed metrics
     * @throws IllegalArgumentException if an argument is missing or any of the
     *         checks above fails
     */
    public static ModelMetricsOrdinal make(Frame frame, Vec vec, String[] strArr) {
        // Reject missing inputs before touching them or opening a scope.
        if (vec == null || frame == null) {
            throw new IllegalArgumentException("Missing actualLabels or predictedProbs for multinomial metrics!");
        }
        Scope.enter();
        try {
            Vec categoricalVec = vec.toCategoricalVec();
            if (categoricalVec == null) {
                throw new IllegalArgumentException("Missing actualLabels or predictedProbs for multinomial metrics!");
            }
            if (categoricalVec.length() != frame.numRows()) {
                throw new IllegalArgumentException("Both arguments must have the same length for multinomial metrics (" + categoricalVec.length() + "!=" + frame.numRows() + ")!");
            }
            for (Vec vec2 : frame.vecs()) {
                if (!vec2.isNumeric()) {
                    throw new IllegalArgumentException("Predicted probabilities must be numeric per-class probabilities for multinomial metrics.");
                }
                if (vec2.min() < 0.0d || vec2.max() > 1.0d) {
                    throw new IllegalArgumentException("Predicted probabilities must be between 0 and 1 for multinomial metrics.");
                }
            }
            int numCols = frame.numCols();
            if (strArr.length != numCols) {
                throw new IllegalArgumentException("Given domain has " + strArr.length + " classes, but predictions have " + numCols + " columns (per-class probabilities) for multinomial metrics.");
            }
            Vec adaptTo = categoricalVec.adaptTo(strArr);
            Frame frame2 = new Frame(frame);
            frame2.add("labels", adaptTo);
            MetricBuilderOrdinal metricBuilderOrdinal;
            try {
                metricBuilderOrdinal = new OrdinalMetrics(adaptTo.domain()).doAll(frame2)._mb;
            } finally {
                // Drop the adapted labels whether or not the task succeeded.
                adaptTo.remove();
            }
            ModelMetricsOrdinal modelMetricsOrdinal = (ModelMetricsOrdinal) metricBuilderOrdinal.makeModelMetrics(null, frame2, null, null);
            modelMetricsOrdinal._description = "Computed on user-given predictions and labels.";
            return modelMetricsOrdinal;
        } finally {
            // Balance Scope.enter() on every path, including validation failures.
            Scope.exit(new Key[0]);
        }
    }
}
