package hex;

import hex.AUUC;
import hex.ModelMetricsSupervised;
import java.util.Arrays;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import water.Key;
import water.MRTask;
import water.Scope;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.C8DVolatileChunk;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;

/**
 * Model metrics for binomial uplift models.
 *
 * <p>Besides the supervised metrics inherited from {@link ModelMetricsSupervised} it holds:
 * <ul>
 *   <li>the area under the uplift curve ({@link AUUC});</li>
 *   <li>{@code _ate}: the weighted mean of the per-row treatment-effect prediction over all rows;</li>
 *   <li>{@code _att}: the same mean over treated rows;</li>
 *   <li>{@code _atc}: the same mean over control rows.</li>
 * </ul>
 * Any of the three averages is {@code NaN} when the corresponding set of rows is empty.
 */
public class ModelMetricsBinomialUplift extends ModelMetricsSupervised {

    /** Number of AUUC bins used when none is configured (a bin count of {@code -1}). */
    private static final int DEFAULT_AUUC_NBINS = 1000;

    public final AUUC _auuc;
    public double _ate;
    public double _att;
    public double _atc;

    /**
     * Accumulates per-row statistics and builds a {@link ModelMetricsBinomialUplift}.
     *
     * <p>Each scored row supplies {@code ds[0]} (the prediction accumulated as the treatment effect)
     * and {@code yact = [response, treatment]}. Both the response and the treatment are expected
     * to be encoded as 0/1.
     */
    public static class MetricBuilderBinomialUplift extends ModelMetricsSupervised.MetricBuilderSupervised<MetricBuilderBinomialUplift> {
        /** AUUC accumulator; {@code null} when the builder was created without thresholds. */
        protected AUUC.AUUCBuilder _auuc;
        /** Weighted sum of treatment-effect predictions over all accepted rows. */
        public double _sumTE;
        /** Weighted sum of treatment-effect predictions over treated rows only. */
        public double _sumTETreatment;
        /**
         * Weighted number of treated rows.
         * NOTE(review): stored as {@code long}, so fractional weights are truncated on every row.
         */
        public long _treatmentCount;

        /**
         * @param domain     response domain (the two class labels)
         * @param thresholds AUUC bin thresholds; when {@code null} no AUUC accumulator is created
         */
        public MetricBuilderBinomialUplift(String[] domain, double[] thresholds) {
            super(2, domain);
            if (thresholds != null) {
                _auuc = new AUUC.AUUCBuilder(thresholds);
            }
        }

        /** Adds one row with weight {@code 1.0} and offset {@code 0.0}. */
        @Override
        public double[] perRow(double[] ds, float[] yact, Model m) {
            return perRow(ds, yact, 1.0, 0.0, m);
        }

        /**
         * Adds one scored row to the accumulated statistics.
         *
         * <p>The row is skipped in these cases: the response or treatment is missing; the
         * prediction is {@code NaN}; the weight is zero or {@code NaN}; or the response is not 0/1.
         *
         * @param ds     predictions; {@code ds[0]} is accumulated as the treatment effect
         * @param yact   {@code [response, treatment]}
         * @param weight row weight
         * @param offset unused
         * @param m      unused
         * @return {@code ds}, unchanged
         */
        @Override
        public double[] perRow(double[] ds, float[] yact, double weight, double offset, Model m) {
            assert yact.length == 2 : "Treatment must be included in `yact` when calculating AUUC";
            if (Float.isNaN(yact[0]) || Float.isNaN(yact[1])) {
                // Missing response or treatment: a NaN treatment would otherwise be cast to 0 (control).
                return ds;
            }
            if (Double.isNaN(ds[0])) {
                // A missing prediction would turn every treatment-effect sum into NaN.
                return ds;
            }
            if (weight == 0 || Double.isNaN(weight)) {
                return ds;
            }
            int y = (int) yact[0];
            if (y != 0 && y != 1) {
                return ds; // the actual response is effectively missing
            }
            _wY += weight * y;
            _wYY += weight * y * y;
            _count++;
            _wcount += weight;
            int treatment = (int) yact[1];
            double weightedEffect = ds[0] * weight;
            _sumTE += weightedEffect;
            _sumTETreatment += treatment * weightedEffect;
            _treatmentCount += treatment * weight; // compound assignment narrows to long
            if (_auuc != null) {
                _auuc.perRow(weightedEffect, weight, y, treatment);
            }
            return ds;
        }

        /** Merges the statistics accumulated by {@code other} into this builder. */
        @Override
        public void reduce(MetricBuilderBinomialUplift other) {
            super.reduce(other);
            if (_auuc != null) {
                _auuc.reduce(other._auuc);
            }
            _sumTE += other._sumTE;
            _sumTETreatment += other._sumTETreatment;
            _treatmentCount += other._treatmentCount;
        }

        /**
         * Builds the metrics.
         *
         * <p>When {@code preds} is given, the AUUC is computed directly from the first prediction
         * column, the response and (if the model names one) the treatment column. Without a model,
         * the last column of the frame is used as the response, but only if it is categorical.
         */
        @Override
        public ModelMetrics makeModelMetrics(Model m, Frame f, Frame frameWithWeights, Frame preds) {
            Vec response = null;
            Vec treatment = null;
            AUUC.AUUCType auucType = m == null ? AUUC.AUUCType.AUTO : m._parms._auuc_type;
            if (preds != null) {
                if (frameWithWeights == null) {
                    frameWithWeights = f;
                }
                if (m != null) {
                    response = frameWithWeights.vec(m._parms._response_column);
                    if (m._parms._treatment_column != null) {
                        treatment = frameWithWeights.vec(m._parms._treatment_column);
                    }
                } else {
                    // No model: assume the last column is the actual response, if it is categorical.
                    Vec last = frameWithWeights.vec(f.numCols() - 1);
                    if (last.isCategorical()) {
                        response = last;
                    }
                }
            }
            int auucNbins = (m == null || m._parms._auuc_nbins == -1) ? DEFAULT_AUUC_NBINS : m._parms._auuc_nbins;
            return makeModelMetrics(m, f, preds, response, treatment, auucType, auucNbins);
        }

        /** Computes the AUUC from {@code preds} when both predictions and response are available. */
        private ModelMetrics makeModelMetrics(Model m, Frame f, Frame preds, Vec response, Vec treatment, AUUC.AUUCType auucType, int auucNbins) {
            AUUC auuc = null;
            if (preds != null && response != null) {
                auuc = new AUUC(preds.vec(0), response, treatment, auucType, auucNbins);
            }
            return makeModelMetrics(m, f, auuc);
        }

        /** Builds the metrics using an AUUC derived from this builder's accumulator. */
        public ModelMetrics makeModelMetrics(Model m, Frame f, AUUC.AUUCType auucType) {
            return makeModelMetrics(m, f, new AUUC(_auuc, auucType));
        }

        /**
         * Builds the metrics and, when {@code m} is non-null, registers them with the model.
         *
         * @param auuc precomputed AUUC; when {@code null}, it is derived from this builder's
         *             accumulator (and the weighted sigma is computed)
         */
        public ModelMetrics makeModelMetrics(Model m, Frame f, AUUC auuc) {
            double sigma = Double.NaN;
            double ate = Double.NaN;
            double att = Double.NaN;
            double atc = Double.NaN;
            if (_wcount > 0) {
                if (auuc == null) {
                    sigma = weightedSigma();
                    AUUC.AUUCType auucType = m == null ? AUUC.AUUCType.AUTO : m._parms._auuc_type;
                    auuc = new AUUC(_auuc, auucType);
                }
                ate = _sumTE / _wcount;
                att = ratioOrNaN(_sumTETreatment, _treatmentCount);
                atc = ratioOrNaN(_sumTE - _sumTETreatment, _wcount - _treatmentCount);
            } else {
                auuc = new AUUC();
            }
            ModelMetricsBinomialUplift mm = new ModelMetricsBinomialUplift(m, f, _count, _domain, ate, att, atc, sigma, auuc, _customMetric);
            if (m != null) {
                m.addModelMetrics(mm);
            }
            return mm;
        }

        /** Returns {@code num / den}, or {@code NaN} when {@code den} is zero (no rows in that group). */
        private static double ratioOrNaN(double num, double den) {
            return den == 0 ? Double.NaN : num / den;
        }

        /** Creates a one-column volatile-double cache for the first prediction. */
        @Override
        public Frame makePredictionCache(Model model, Vec vec) {
            return new Frame(vec.makeVolatileDoubles(1));
        }

        /** Stores {@code preds[0]} into row {@code row} of cache chunk {@code idx}. */
        @Override
        public void cachePrediction(double[] preds, Chunk[] chks, int row, int idx, Model m) {
            assert preds.length == 3;
            ((C8DVolatileChunk) chks[idx]).getValues()[row] = preds[0];
        }

        @Override
        public String toString() {
            return "";
        }
    }

    /**
     * Scores a three-column frame {@code [prediction, label, treatment]} into a
     * {@link MetricBuilderBinomialUplift}, giving each row weight {@code 1.0}.
     */
    private static class UpliftBinomialMetrics extends MRTask<UpliftBinomialMetrics> {
        String[] domain;
        double[] thresholds;
        public MetricBuilderBinomialUplift _mb;

        public UpliftBinomialMetrics(String[] domain, double[] thresholds) {
            this.domain = domain;
            this.thresholds = thresholds;
        }

        @Override
        public void map(Chunk[] chks) {
            _mb = new MetricBuilderBinomialUplift(domain, thresholds);
            Chunk predictions = chks[0];
            Chunk labels = chks[1];
            Chunk treatment = chks[2];
            double[] ds = new double[1];
            float[] yact = new float[2];
            for (int row = 0; row < predictions._len; row++) {
                ds[0] = predictions.atd(row);
                yact[0] = (float) labels.atd(row);
                yact[1] = (float) treatment.atd(row);
                _mb.perRow(ds, yact, 1.0, 0.0, null);
            }
        }

        @Override
        public void reduce(UpliftBinomialMetrics other) {
            _mb.reduce(other._mb);
        }
    }

    /**
     * @param model        owning model (may be {@code null})
     * @param frame        scored frame
     * @param nobs         number of accepted rows
     * @param domain       response domain
     * @param ate          average treatment effect over all rows
     * @param att          average treatment effect over treated rows
     * @param atc          average treatment effect over control rows
     * @param sigma        weighted sigma passed to the supervised base class
     * @param auuc         area-under-uplift-curve results
     * @param customMetric optional custom metric
     */
    public ModelMetricsBinomialUplift(Model model, Frame frame, long nobs, String[] domain, double ate, double att, double atc, double sigma, AUUC auuc, CustomMetric customMetric) {
        super(model, frame, nobs, 0.0, domain, sigma, customMetric);
        _ate = ate;
        _att = att;
        _atc = atc;
        _auuc = auuc;
    }

    /**
     * Looks up the metrics of {@code model} on {@code frame}.
     *
     * @throws H2OIllegalArgumentException if the stored metrics are missing or of another type
     */
    public static ModelMetricsBinomialUplift getFromDKV(Model model, Frame frame) {
        ModelMetrics mm = ModelMetrics.getFromDKV(model, frame);
        if (mm instanceof ModelMetricsBinomialUplift) {
            return (ModelMetricsBinomialUplift) mm;
        }
        throw new H2OIllegalArgumentException("Expected to find a Binomial ModelMetrics for model: " + model._key.toString() + " and frame: " + frame._key.toString(), "Expected to find a ModelMetricsBinomial for model: " + model._key.toString() + " and frame: " + frame._key.toString() + " but found a: " + (mm == null ? null : mm.getClass()));
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        sb.append("ATE:").append((float) _ate).append("\n");
        sb.append("ATT:").append((float) _att).append("\n");
        sb.append("ATC:").append((float) _atc).append("\n");
        if (_auuc != null) {
            sb.append("Default AUUC: ").append((float) _auuc.auuc()).append("\n");
            sb.append("Qini AUUC: ").append((float) _auuc.auucByType(AUUC.AUUCType.qini)).append("\n");
            sb.append("Lift AUUC: ").append((float) _auuc.auucByType(AUUC.AUUCType.lift)).append("\n");
            sb.append("Gain AUUC: ").append((float) _auuc.auucByType(AUUC.AUUCType.gain)).append("\n");
            sb.append("Normalized Qini AUUC: ").append((float) _auuc.auucNormalizedByType(AUUC.AUUCType.qini)).append("\n");
            sb.append("Normalized Lift AUUC: ").append((float) _auuc.auucNormalizedByType(AUUC.AUUCType.lift)).append("\n");
            sb.append("Normalized Gain AUUC: ").append((float) _auuc.auucNormalizedByType(AUUC.AUUCType.gain)).append("\n");
            sb.append("Qini: ").append((float) _auuc.qini()).append("\n");
        }
        return sb.toString();
    }

    /** @throws NullPointerException if these metrics were constructed without an AUUC */
    public double auuc() {
        return _auuc.auuc();
    }

    /** @throws NullPointerException if these metrics were constructed without an AUUC */
    public double qini() {
        return _auuc.qini();
    }

    /** @throws NullPointerException if these metrics were constructed without an AUUC */
    public double auucNormalized() {
        return _auuc.auucNormalized();
    }

    /** @throws NullPointerException if these metrics were constructed without an AUUC */
    public int nbins() {
        return _auuc._nBins;
    }

    public double ate() {
        return _ate;
    }

    public double att() {
        return _att;
    }

    public double atc() {
        return _atc;
    }

    @Override
    protected StringBuilder appendToStringMetrics(StringBuilder sb) {
        return sb;
    }

    /**
     * Computes uplift binomial metrics from user-supplied vectors.
     *
     * @param predictedProbs       numeric predictions (used as the treatment effect)
     * @param actualLabels         actual responses; converted to categorical
     * @param treatment            categorical treatment column with exactly two levels
     * @param domain               response domain; when {@code null}, the labels' own domain is used
     * @param auucType             AUUC type used for the default AUUC
     * @param auucNbins            number of AUUC bins: {@code -1} (default, capped by the row count)
     *                             or a value in {@code [1, nrows]}
     * @param customAuucThresholds explicit AUUC thresholds, or {@code null} to use quantiles
     * @throws IllegalArgumentException on missing or invalid inputs
     */
    public static ModelMetricsBinomialUplift make(Vec predictedProbs, Vec actualLabels, Vec treatment, String[] domain, AUUC.AUUCType auucType, int auucNbins, double[] customAuucThresholds) {
        Scope.enter();
        Vec adaptedLabels = null;
        try {
            // Validate before dereferencing anything, so callers get the intended message instead of an NPE.
            if (actualLabels == null || predictedProbs == null || treatment == null) {
                throw new IllegalArgumentException("Missing actualLabels or predicted probabilities or treatment values for uplift binomial metrics!");
            }
            Vec labels = actualLabels.toCategoricalVec();
            if (labels == null) {
                throw new IllegalArgumentException("Missing actualLabels or predicted probabilities or treatment values for uplift binomial metrics!");
            }
            if (domain == null) {
                domain = labels.domain();
            }
            if (!predictedProbs.isNumeric()) {
                throw new IllegalArgumentException("Predicted probabilities must be numeric per-class probabilities for uplift binomial metrics.");
            }
            if (domain.length != 2) {
                throw new IllegalArgumentException("Domain must have 2 class labels, but is " + Arrays.toString(domain) + " for uplift binomial metrics.");
            }
            adaptedLabels = labels.adaptTo(domain);
            if (adaptedLabels.cardinality() != 2) {
                throw new IllegalArgumentException("Adapted domain must have 2 class labels, but is " + Arrays.toString(adaptedLabels.domain()) + " for uplift binomial metrics.");
            }
            if (!treatment.isCategorical() || treatment.cardinality() != 2) {
                throw new IllegalArgumentException("Treatment values should be categorical value and have 2 class " + Arrays.toString(treatment.domain()) + " for uplift binomial uplift metrics.");
            }
            long nrows = treatment.length();
            if (customAuucThresholds != null) {
                if (customAuucThresholds.length == 0) {
                    throw new IllegalArgumentException("Custom AUUC thresholds array should have size greater than 0.");
                }
                if (auucNbins != customAuucThresholds.length) {
                    Log.info("Custom AUUC thresholds are specified, so number of AUUC bins will equal to thresholds size.");
                }
            }
            if (auucNbins < -1 || auucNbins == 0 || auucNbins > nrows) {
                throw new IllegalArgumentException("The number of bins to calculate AUUC need to be -1 (default value) or higher than zero, but less than data size.");
            }
            if (auucNbins == -1) {
                auucNbins = (int) Math.min(DEFAULT_AUUC_NBINS, nrows);
            }
            Frame fr = new Frame(predictedProbs);
            fr.add("labels", adaptedLabels);
            fr.add("treatment", treatment);
            String[] adaptedDomain = adaptedLabels.domain();
            double[] thresholds = customAuucThresholds != null
                    ? customAuucThresholds
                    : AUUC.calculateQuantileThresholds(auucNbins, predictedProbs);
            MetricBuilderBinomialUplift mb = new UpliftBinomialMetrics(adaptedDomain, thresholds).doAll(fr)._mb;
            // Success path: remove the adapted labels at the same point as before; clear the reference
            // so the finally block does not remove them twice.
            Vec toRemove = adaptedLabels;
            adaptedLabels = null;
            toRemove.remove();
            ModelMetricsBinomialUplift mm = (ModelMetricsBinomialUplift) mb.makeModelMetrics((Model) null, fr, auucType);
            mm._description = "Computed on user-given predictions and labels.";
            return mm;
        } finally {
            if (adaptedLabels != null) {
                adaptedLabels.remove(); // failure path: do not leak the adapted label vec
            }
            Scope.exit();
        }
    }
}
