package hex;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import water.DKV;
import water.Iced;
import water.Key;
import water.Keyed;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OKeyNotFoundArgumentException;
import water.fvec.Frame;
import water.util.IcedHashMap;
import water.util.Log;
import water.util.PojoUtils;
import water.util.TwoDimTable;

/* loaded from: input_file:hex/ModelMetrics.class */
public class ModelMetrics extends Keyed<ModelMetrics> {
    public String _description;
    private Key _modelKey;
    private ModelCategory _model_category;
    private long _model_checksum;
    private Key _frameKey;
    private long _frame_checksum;
    public final long _scoring_time;
    public final CustomMetric _custom_metric;
    private transient Model _model;
    private transient Frame _frame;
    public final double _MSE;
    public final long _nobs;

    /* loaded from: input_file:hex/ModelMetrics$MetricBuilder.class */
    public static abstract class MetricBuilder<T extends MetricBuilder<T>> extends Iced<T> {
        public transient double[] _work;
        public double _sumsqe;
        public long _count;
        public double _wcount;
        public double _wY;
        public double _wYY;
        public CustomMetric _customMetric = null;
        static final /* synthetic */ boolean $assertionsDisabled;

        public double weightedSigma() {
            if (this._count <= 1) {
                return 0.0d;
            }
            return Math.sqrt(1.0d * ((this._wYY / this._wcount) - ((this._wY * this._wY) / (this._wcount * this._wcount))));
        }

        public abstract double[] perRow(double[] dArr, float[] fArr, Model model);

        public double[] perRow(double[] dArr, float[] fArr, double d, double d2, Model model) {
            if ($assertionsDisabled || (d == 1.0d && d2 == 0.0d)) {
                return perRow(dArr, fArr, model);
            }
            throw new AssertionError();
        }

        public void reduce(T t) {
            this._sumsqe += t._sumsqe;
            this._count += t._count;
            this._wcount += t._wcount;
            this._wY += t._wY;
            this._wYY += t._wYY;
        }

        public void postGlobal() {
            postGlobal(null);
        }

        public void postGlobal(CustomMetric customMetric) {
            this._customMetric = customMetric;
        }

        public abstract ModelMetrics makeModelMetrics(Model model, Frame frame, Frame frame2, Frame frame3);

        public void setCustomMetric(CustomMetric customMetric) {
            this._customMetric = customMetric;
        }

        static {
            $assertionsDisabled = !ModelMetrics.class.desiredAssertionStatus();
        }
    }

    /* loaded from: input_file:hex/ModelMetrics$MetricsComparator.class */
    private static class MetricsComparator implements Comparator<Key<Model>> {
        String _sort_by;
        boolean decreasing;

        public MetricsComparator(String str, boolean z) {
            this._sort_by = null;
            this.decreasing = false;
            this._sort_by = str;
            this.decreasing = z;
        }

        @Override // java.util.Comparator
        public int compare(Key<Model> key, Key<Model> key2) {
            double metricFromModel = ModelMetrics.getMetricFromModel(key, this._sort_by);
            double metricFromModel2 = ModelMetrics.getMetricFromModel(key2, this._sort_by);
            return this.decreasing ? Double.compare(metricFromModel2, metricFromModel) : Double.compare(metricFromModel, metricFromModel2);
        }
    }

    /* loaded from: input_file:hex/ModelMetrics$MetricsComparatorForFrame.class */
    private static class MetricsComparatorForFrame implements Comparator<Key<Model>> {
        String _sort_by;
        boolean decreasing;
        Frame frame;
        IcedHashMap<Key<Model>, ModelMetrics> cachedMetrics = new IcedHashMap<>();

        public MetricsComparatorForFrame(Frame frame, String str, boolean z) {
            this._sort_by = null;
            this.decreasing = false;
            this.frame = null;
            this._sort_by = str;
            this.decreasing = z;
            this.frame = frame;
        }

        private final ModelMetrics findMetricsForModel(Key<Model> key) {
            ModelMetrics modelMetrics = this.cachedMetrics.get(key);
            if (null != modelMetrics) {
                return modelMetrics;
            }
            if (null == key.get()) {
                Log.warn("Tried to compare metrics for a model which was not found in the DKV: " + key);
                throw new H2OKeyNotFoundArgumentException(key.toString());
            }
            Model model = key.get();
            ModelMetrics fromDKV = ModelMetrics.getFromDKV(model, this.frame);
            if (null == fromDKV) {
                model.score(this.frame);
                fromDKV = ModelMetrics.getFromDKV(model, this.frame);
                if (null == fromDKV) {
                    Log.warn("Tried to compare metrics for a model/frame combination which was not found in the DKV: (" + key + ", " + this.frame._key.toString() + ")");
                    throw new H2OKeyNotFoundArgumentException(key.toString());
                }
            }
            this.cachedMetrics.put(key, fromDKV);
            return fromDKV;
        }

        @Override // java.util.Comparator
        public int compare(Key<Model> key, Key<Model> key2) {
            ModelMetrics findMetricsForModel = findMetricsForModel(key);
            ModelMetrics findMetricsForModel2 = findMetricsForModel(key2);
            double metricFromModelMetric = ModelMetrics.getMetricFromModelMetric(findMetricsForModel, this._sort_by);
            double metricFromModelMetric2 = ModelMetrics.getMetricFromModelMetric(findMetricsForModel2, this._sort_by);
            return this.decreasing ? Double.compare(metricFromModelMetric2, metricFromModelMetric) : Double.compare(metricFromModelMetric, metricFromModelMetric2);
        }
    }

    public ModelMetrics(Model model, Frame frame, long j, double d, String str, CustomMetric customMetric) {
        super(buildKey(model, frame));
        withModelAndFrame(model, frame);
        this._description = str;
        this._MSE = d;
        this._nobs = j;
        this._scoring_time = System.currentTimeMillis();
        this._custom_metric = customMetric;
    }

    private void setModelAndFrameFields(Model model, Frame frame) {
        PojoUtils.setField(this, "_modelKey", model == null ? null : model._key);
        PojoUtils.setField(this, "_frameKey", frame == null ? null : frame._key);
        PojoUtils.setField(this, "_model_category", model == null ? null : model._output.getModelCategory());
        PojoUtils.setField(this, "_model_checksum", Long.valueOf(model == null ? 0L : model.checksum()));
        try {
            PojoUtils.setField(this, "_frame_checksum", Long.valueOf(frame.checksum()));
        } catch (Throwable th) {
        }
    }

    public final ModelMetrics withModelAndFrame(Model model, Frame frame) {
        long checksum;
        this._modelKey = model == null ? null : model._key;
        this._model_category = model == null ? null : model._output.getModelCategory();
        this._model_checksum = model == null ? 0L : model.checksum();
        this._frameKey = frame == null ? null : frame._key;
        if (frame == null) {
            checksum = 0;
        } else {
            try {
                checksum = frame.checksum();
            } catch (Throwable th) {
            }
        }
        this._frame_checksum = checksum;
        this._key = buildKey(model, frame);
        return this;
    }

    public ModelMetrics withDescription(String str) {
        this._description = str;
        return this;
    }

    /* JADX WARN: Multi-variable type inference failed */
    public ModelMetrics deepCloneWithDifferentModelAndFrame(Model model, Frame frame) {
        ModelMetrics modelMetrics = (ModelMetrics) m321clone();
        modelMetrics._key = buildKey(model, frame);
        modelMetrics.setModelAndFrameFields(model, frame);
        return modelMetrics;
    }

    public long residual_degrees_of_freedom() {
        throw new UnsupportedOperationException("residual degrees of freedom is not supported for this metric class");
    }

    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("Model Metrics Type: " + getClass().getSimpleName().substring(12) + "\n");
        sb.append(" Description: " + (this._description == null ? "N/A" : this._description) + "\n");
        sb.append(" model id: " + this._modelKey + "\n");
        sb.append(" frame id: " + this._frameKey + "\n");
        sb.append(" MSE: " + ((float) this._MSE) + "\n");
        sb.append(" RMSE: " + ((float) rmse()) + "\n");
        return sb.toString();
    }

    public final Model model() {
        if (this._model != null) {
            return this._model;
        }
        Model model = (Model) DKV.getGet(this._modelKey);
        this._model = model;
        return model;
    }

    public final Frame frame() {
        if (this._frame != null) {
            return this._frame;
        }
        Frame frame = (Frame) DKV.getGet(this._frameKey);
        this._frame = frame;
        return frame;
    }

    public double mse() {
        return this._MSE;
    }

    public double rmse() {
        return Math.sqrt(this._MSE);
    }

    public ConfusionMatrix cm() {
        return null;
    }

    public float[] hr() {
        return null;
    }

    public AUC2 auc_obj() {
        return null;
    }

    public static double getMetricFromModel(Key<Model> key, String str) {
        Model model = (Model) DKV.getGet(key);
        if (null == model) {
            throw new H2OIllegalArgumentException("Cannot find model " + key);
        }
        return getMetricFromModelMetric(model._output._cross_validation_metrics != null ? model._output._cross_validation_metrics : model._output._validation_metrics != null ? model._output._validation_metrics : model._output._training_metrics, str);
    }

    public static double getMetricFromModelMetric(ModelMetrics modelMetrics, String str) {
        double doubleValue;
        if (null == str || str.equals("")) {
            throw new H2OIllegalArgumentException("Need a valid criterion, but got '" + str + "'.");
        }
        Method method = null;
        ConfusionMatrix cm = modelMetrics.cm();
        try {
            method = modelMetrics.getClass().getMethod(str.toLowerCase(), new Class[0]);
        } catch (Exception e) {
        }
        if (null == method && null != cm) {
            try {
                method = cm.getClass().getMethod(str.toLowerCase(), new Class[0]);
            } catch (Exception e2) {
            }
        }
        if (null == method) {
            throw new H2OIllegalArgumentException("Failed to find ModelMetrics for criterion: " + str);
        }
        try {
            doubleValue = ((Double) method.invoke(modelMetrics, new Object[0])).doubleValue();
        } catch (Exception e3) {
            try {
                doubleValue = ((Double) method.invoke(cm, new Object[0])).doubleValue();
            } catch (Exception e4) {
                throw new H2OIllegalArgumentException("Failed to get metric: " + str + " from ModelMetrics object: " + modelMetrics, "Failed to get metric: " + str + " from ModelMetrics object: " + modelMetrics + ", criterion: " + method + ", exception: " + e4);
            }
        }
        return doubleValue;
    }

    public static Set<String> getAllowedMetrics(Key<Model> key) {
        HashSet hashSet = new HashSet();
        Model model = (Model) DKV.getGet(key);
        if (null == model) {
            throw new H2OIllegalArgumentException("Cannot find model " + key);
        }
        ModelMetrics modelMetrics = model._output._cross_validation_metrics != null ? model._output._cross_validation_metrics : model._output._validation_metrics != null ? model._output._validation_metrics : model._output._training_metrics;
        ConfusionMatrix cm = modelMetrics.cm();
        HashSet hashSet2 = new HashSet();
        hashSet2.add("makeSchema");
        hashSet2.add("hr");
        hashSet2.add("cm");
        hashSet2.add("auc_obj");
        hashSet2.add("remove");
        hashSet2.add("nobs");
        if (modelMetrics != null) {
            for (Method method : modelMetrics.getClass().getMethods()) {
                if (!hashSet2.contains(method.getName())) {
                    try {
                        ((Double) method.invoke(modelMetrics, new Object[0])).doubleValue();
                        hashSet.add(method.getName().toLowerCase());
                    } catch (Exception e) {
                    }
                }
            }
        }
        if (cm != null) {
            for (Method method2 : cm.getClass().getMethods()) {
                if (!hashSet2.contains(method2.getName())) {
                    try {
                        ((Double) method2.invoke(cm, new Object[0])).doubleValue();
                        hashSet.add(method2.getName().toLowerCase());
                    } catch (Exception e2) {
                    }
                }
            }
        }
        return hashSet;
    }

    public static List<Key<Model>> sortModelsByMetric(String str, boolean z, List<Key<Model>> list) {
        ArrayList arrayList = new ArrayList();
        arrayList.addAll(list);
        Collections.sort(arrayList, new MetricsComparator(str, z));
        return arrayList;
    }

    public static List<Key<Model>> sortModelsByMetric(Frame frame, String str, boolean z, List<Key<Model>> list) {
        ArrayList arrayList = new ArrayList();
        arrayList.addAll(list);
        Collections.sort(arrayList, new MetricsComparatorForFrame(frame, str, z));
        return arrayList;
    }

    public static TwoDimTable calcVarImp(VarImp varImp) {
        if (varImp == null) {
            return null;
        }
        double[] dArr = new double[varImp._varimp.length];
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] = varImp._varimp[i];
        }
        return calcVarImp(dArr, varImp._names);
    }

    public static TwoDimTable calcVarImp(float[] fArr, String[] strArr) {
        double[] dArr = new double[fArr.length];
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] = fArr[i];
        }
        return calcVarImp(dArr, strArr);
    }

    public static TwoDimTable calcVarImp(double[] dArr, String[] strArr) {
        return calcVarImp(dArr, strArr, "Variable Importances", new String[]{"Relative Importance", "Scaled Importance", "Percentage"});
    }

    /* JADX WARN: Type inference failed for: r9v2, types: [java.lang.String[], java.lang.String[][]] */
    public static TwoDimTable calcVarImp(final double[] dArr, String[] strArr, String str, String[] strArr2) {
        if (dArr == null) {
            return null;
        }
        if (strArr == null) {
            strArr = new String[dArr.length];
            for (int i = 0; i < strArr.length; i++) {
                strArr[i] = "C" + String.valueOf(i + 1);
            }
        }
        Integer[] numArr = new Integer[dArr.length];
        for (int i2 = 0; i2 < numArr.length; i2++) {
            numArr[i2] = Integer.valueOf(i2);
        }
        Arrays.sort(numArr, new Comparator<Integer>() { // from class: hex.ModelMetrics.1
            @Override // java.util.Comparator
            public int compare(Integer num, Integer num2) {
                return Double.compare(-dArr[num.intValue()], -dArr[num2.intValue()]);
            }
        });
        double d = 0.0d;
        double d2 = dArr[numArr[0].intValue()];
        String[] strArr3 = new String[dArr.length];
        double[][] dArr2 = new double[dArr.length][3];
        int i3 = 0;
        for (Integer num : numArr) {
            int intValue = num.intValue();
            d += dArr[intValue];
            strArr3[i3] = strArr[intValue];
            dArr2[i3][0] = dArr[intValue];
            int i4 = i3;
            i3++;
            dArr2[i4][1] = dArr[intValue] / d2;
        }
        int i5 = 0;
        for (Integer num2 : numArr) {
            int i6 = i5;
            i5++;
            dArr2[i6][2] = dArr[num2.intValue()] / d;
        }
        String[] strArr4 = new String[3];
        String[] strArr5 = new String[3];
        Arrays.fill(strArr4, "double");
        Arrays.fill(strArr5, "%5f");
        return new TwoDimTable(str, null, strArr3, strArr2, strArr4, strArr5, "Variable", new String[dArr.length], dArr2);
    }

    public static Key<ModelMetrics> buildKey(Key key, long j, Key key2, long j2) {
        return Key.make("modelmetrics_" + key + "@" + j + "_on_" + key2 + "@" + j2);
    }

    public static Key<ModelMetrics> buildKey(Model model, Frame frame) {
        if (frame == null || model == null) {
            return null;
        }
        return buildKey(model._key, model.checksum(), frame._key, frame.checksum());
    }

    public boolean isForModel(Model model) {
        return this._model_checksum == model.checksum();
    }

    public boolean isForFrame(Frame frame) {
        return this._frame_checksum == frame.checksum();
    }

    public static ModelMetrics getFromDKV(Model model, Frame frame) {
        return (ModelMetrics) DKV.getGet(buildKey(model, frame));
    }

    @Override // water.Keyed
    protected long checksum_impl() {
        return (this._frame_checksum * 13) + (this._model_checksum * 17);
    }
}
