package water.rapids;

import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsRegression;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.IntStream;
import water.Keyed;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.MRUtils;
import water.util.TwoDimTable;
import water.util.VecUtils;

/**
 * Permutation variable importance (PVI) for a trained {@link Model}.
 *
 * <p>For every selected feature, the column is shuffled and the model is re-scored on the frame.
 * The importance of the feature is the absolute difference between the metric obtained with
 * the shuffled column and the baseline metric obtained on the unshuffled frame.
 */
public class PermutationVarImp {
    private final Model _model;
    private final Frame _inputFrame;

    /**
     * Creates a PVI calculator for {@code model} evaluated on {@code frame}.
     *
     * @param model trained model whose features are evaluated
     * @param frame frame to score; must contain at least 2 rows and the model's response column
     * @throws IllegalArgumentException if the frame has fewer than 2 rows or lacks the response column
     */
    public PermutationVarImp(Model model, Frame frame) {
        if (frame.numRows() < 2) {
            throw new IllegalArgumentException("Frame must contain more than 1 rows to be used in permutation variable importance!");
        }
        if (!ArrayUtils.contains(frame.names(), model._parms._response_column)) {
            throw new IllegalArgumentException("Frame must contain the response column for the use in permutation variable importance!");
        }
        _model = model;
        _inputFrame = frame;
    }

    /**
     * Extracts a single metric value from {@code modelMetrics}.
     *
     * @throws IllegalArgumentException if the metric is not available (the lookup returns NaN)
     */
    private static double getMetric(ModelMetrics modelMetrics, String metric) {
        assert modelMetrics != null : "No model metrics found for the scored frame";
        double value = ModelMetrics.getMetricFromModelMetric(modelMetrics, metric);
        if (Double.isNaN(value)) {
            throw new IllegalArgumentException("Model doesn't support the following metric: " + metric);
        }
        return value;
    }

    /**
     * Normalizes the metric name to lower case and validates it against the model's allowed metrics.
     *
     * <p>{@code "auto"} is resolved from the type of the training metrics:
     * binomial -> auc, regression -> rmse, multinomial -> logloss.
     *
     * @throws IllegalArgumentException if "auto" cannot be resolved or the metric is not allowed
     */
    private String inferAndValidateMetric(String metric) {
        Set<String> allowedMetrics = ModelMetrics.getAllowedMetrics(_model._key);
        // Locale.ROOT: default-locale lowering breaks names containing 'I' (e.g. Turkish locale).
        String normalized = metric.toLowerCase(Locale.ROOT);
        if (normalized.equals("auto")) {
            if (_model._output._training_metrics instanceof ModelMetricsBinomial) {
                normalized = "auc";
            } else if (_model._output._training_metrics instanceof ModelMetricsRegression) {
                normalized = "rmse";
            } else if (_model._output._training_metrics instanceof ModelMetricsMultinomial) {
                normalized = "logloss";
            } else {
                throw new IllegalArgumentException("Unable to infer metric. Please specify metric for permutation variable importance.");
            }
        }
        if (!allowedMetrics.contains(normalized)) {
            throw new IllegalArgumentException("Permutation Variable Importance doesn't support " + normalized + " for model " + _model._key);
        }
        return normalized;
    }

    /**
     * Schedules shuffling of the next selected feature after column {@code lastIdx}.
     *
     * @return a future producing the shuffled Vec, or {@code null} if no selected feature remains
     */
    private Future<Vec> precomputeShuffledVec(ExecutorService executor, Frame frame, Set<String> features,
                                              String[] names, int lastIdx, long seed) {
        for (int col = lastIdx + 1; col < frame.numCols(); col++) {
            if (features.contains(names[col])) {
                final int target = col;
                return executor.submit(() -> VecUtils.shuffleVec(frame.vec(target), seed));
            }
        }
        return null;
    }

    /**
     * Returns the set of features to evaluate.
     *
     * <p>The candidates are {@code requested} if non-empty, otherwise all frame columns. Non-predictors
     * and the model's ignored columns are removed.
     */
    private Set<String> selectFeatures(String[] names, String[] requested) {
        String[] candidates = (requested == null || requested.length == 0) ? names : requested;
        Set<String> selected = new HashSet<>(Arrays.asList(candidates));
        selected.removeAll(Arrays.asList(_model._parms.getNonPredictors()));
        if (_model._parms._ignored_columns != null) {
            selected.removeAll(Arrays.asList(_model._parms._ignored_columns));
        }
        return selected;
    }

    /**
     * Returns the frame to evaluate on.
     *
     * <p>For {@code nSamples <= 1} this is the input frame itself. Otherwise it is a new sample:
     * {@code MRUtils.sampleFrame} is used for more than 1000 rows or when a weights column is set,
     * and {@code MRUtils.sampleFrameSmall} is used otherwise. A sampled frame must be removed by the caller.
     */
    private Frame sampleInputFrame(long nSamples, long seed) {
        if (nSamples <= 1) {
            return _inputFrame;
        }
        if (nSamples > 1000 || _model._parms._weights_column != null) {
            return MRUtils.sampleFrame(_inputFrame, nSamples, _model._parms._weights_column, seed);
        }
        return MRUtils.sampleFrameSmall(_inputFrame, (int) nSamples, seed);
    }

    /**
     * Cancels a pending shuffle. If the shuffle has already completed, its Vec is removed so it does not leak.
     *
     * <p>NOTE(review): a task that is running when cancelled may still finish and leave its Vec behind.
     */
    private static void discardPending(Future<Vec> pending) {
        if (pending == null) {
            return;
        }
        if (!pending.cancel(true) && pending.isDone() && !pending.isCancelled()) {
            try {
                Vec leftover = pending.get();
                if (leftover != null) {
                    leftover.remove();
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } catch (ExecutionException ignored) {
                // The shuffle itself failed, so there is no Vec to clean up.
            }
        }
    }

    /**
     * Computes the permutation importance of each selected feature.
     *
     * @param metric   already-validated metric name
     * @param nSamples number of rows to sample; values <= 1 (other than 1) use the whole input frame
     * @param features features to evaluate; null or empty means all predictor columns
     * @param seed     shuffling/sampling seed; -1 draws a random seed
     * @return feature name -> |metric(shuffled) - metric(baseline)|
     * @throws IllegalArgumentException if {@code nSamples == 1}
     */
    Map<String, Double> calculatePermutationVarImp(String metric, long nSamples, String[] features, long seed) {
        if (seed == -1) {
            seed = new Random().nextLong();
        }
        if (nSamples == 1) {
            throw new IllegalArgumentException("Unable to permute one row. Please set n_samples to higher value or to -1 to use the whole dataset.");
        }
        final String[] names = _inputFrame.names();
        final Set<String> toEvaluate = selectFeatures(names, features);
        final Frame frame = sampleInputFrame(nSamples, seed);
        final ExecutorService executor = Executors.newSingleThreadExecutor();
        Future<Vec> pending = null;  // shuffle of the next feature, computed while the current one is scored
        Vec shuffled = null;         // shuffled column not yet removed
        Vec original = null;         // original column currently displaced from the frame
        int replacedIdx = -1;
        try {
            // score() is assumed to publish metrics that getFromDKV then reads.
            _model.score(frame).remove();
            final double baseline = getMetric(ModelMetrics.getFromDKV(_model, frame), metric);
            pending = precomputeShuffledVec(executor, frame, toEvaluate, names, -1, seed);
            final Map<String, Double> importance = new HashMap<>();
            for (int i = 0; i < frame.numCols(); i++) {
                if (!toEvaluate.contains(names[i])) {
                    continue;
                }
                assert pending != null;
                shuffled = pending.get();
                pending = precomputeShuffledVec(executor, frame, toEvaluate, names, i, seed);
                original = frame.replace(i, shuffled);
                replacedIdx = i;
                _model.score(frame).remove();
                double permuted = getMetric(ModelMetrics.getFromDKV(_model, frame), metric);
                importance.put(names[i], Math.abs(permuted - baseline));
                frame.replace(i, original);
                original = null;
                replacedIdx = -1;
                shuffled.remove();
                shuffled = null;
            }
            return importance;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Unable to calculate the permutation variable importance.", e);
        } catch (ExecutionException e) {
            throw new RuntimeException("Unable to calculate the permutation variable importance.", e);
        } finally {
            // Restore first: the frame may be the caller's input frame, and the shuffled Vec must not
            // remain inside a frame that is about to be removed.
            if (original != null) {
                frame.replace(replacedIdx, original);
            }
            if (frame != _inputFrame) {
                frame.remove();
            }
            if (shuffled != null) {
                shuffled.remove();
            }
            discardPending(pending);
            executor.shutdownNow();
        }
    }

    /**
     * Computes permutation variable importance and returns it as a variable-importance table.
     *
     * @param metric   metric name, or "auto"
     * @param nSamples rows to sample (-1 for whole frame)
     * @param features features to evaluate (null/empty for all)
     * @param seed     seed, or -1 for random
     */
    public TwoDimTable getPermutationVarImp(String metric, long nSamples, String[] features, long seed) {
        Map<String, Double> importance = calculatePermutationVarImp(inferAndValidateMetric(metric), nSamples, features, seed);
        String[] varNames = new String[importance.size()];
        double[] varImp = new double[importance.size()];
        int k = 0;
        for (Map.Entry<String, Double> entry : importance.entrySet()) {
            varNames[k] = entry.getKey();
            varImp[k] = entry.getValue();
            k++;
        }
        return ModelMetrics.calcVarImp(varImp, varNames);
    }

    /**
     * Runs permutation variable importance {@code nRepeats} times and tabulates every run.
     *
     * <p>Rows are ordered by descending importance in the first run. With a fixed seed, run r uses
     * {@code seed + r}.
     *
     * @throws IllegalArgumentException if {@code nRepeats < 1}
     */
    public TwoDimTable getRepeatedPermutationVarImp(String metric, long nSamples, int nRepeats, String[] features, long seed) {
        if (nRepeats < 1) {
            throw new IllegalArgumentException("Number of repeats must be at least 1, got " + nRepeats + ".");
        }
        String validatedMetric = inferAndValidateMetric(metric);
        List<Map<String, Double>> runs = new ArrayList<>(nRepeats);
        for (int r = 0; r < nRepeats; r++) {
            runs.add(calculatePermutationVarImp(validatedMetric, nSamples, features, seed == -1 ? -1L : seed + r));
        }
        List<Map.Entry<String, Double>> ordered = new ArrayList<>(runs.get(0).entrySet());
        ordered.sort(Map.Entry.comparingByValue(Collections.reverseOrder()));

        String[] rowHeaders = new String[ordered.size()];
        double[][] values = new double[ordered.size()][nRepeats];
        for (int row = 0; row < ordered.size(); row++) {
            String name = ordered.get(row).getKey();
            rowHeaders[row] = name;
            for (int r = 0; r < nRepeats; r++) {
                values[row][r] = runs.get(r).get(name);
            }
        }
        String[] colHeaders = IntStream.range(0, nRepeats).mapToObj(r -> "Run " + (r + 1)).toArray(String[]::new);
        String[] colTypes = new String[nRepeats];
        Arrays.fill(colTypes, "double");
        return new TwoDimTable("Repeated Permutation Variable Importance", null, rowHeaders, colHeaders,
                colTypes, null, "Variable", new String[rowHeaders.length], values);
    }

    /** Computes permutation variable importance on the whole frame with a random seed. */
    public TwoDimTable getPermutationVarImp(String metric) {
        return getPermutationVarImp(metric, -1L, null, -1L);
    }
}
