package hex.Infogram;

import hex.Infogram.InfogramModel;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelBuilderHelper;
import hex.SplitFrame;
import hex.deeplearning.DeepLearningModel;
import hex.glm.GLMModel;
import hex.schemas.DRFV3;
import hex.schemas.DeepLearningV3;
import hex.schemas.GBMV3;
import hex.schemas.GLMV3;
import hex.schemas.XGBoostV3;
import hex.tree.drf.DRFModel;
import hex.tree.gbm.GBMModel;
import hex.tree.xgboost.XGBoostModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.DoubleStream;
import water.DKV;
import water.Key;
import water.Scope;
import water.api.schemas3.ModelParametersSchemaV3;
import water.fvec.Frame;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.util.TwoDimTable;

/* loaded from: input_file:hex/Infogram/InfogramUtils.class */
/**
 * Static helpers used while building an Infogram model: choosing predictors, assembling the
 * training frames for the sub-models, creating and training those sub-models, and turning
 * their results into normalized conditional-mutual-information (CMI) values.
 */
public class InfogramUtils {

    /** Weights-column name looked up in the incoming training frame. */
    private static final String INTERNAL_CV_WEIGHTS = "__internal_cv_weights__";

    /** Name under which that weights column is re-added to the frames built here. */
    private static final String INFOGRAM_CV_WEIGHTS = "infogram_internal_cv_weights_";

    /**
     * Lists the columns of {@code train} that are eligible as predictors.
     *
     * <p>Starts from all column names of {@code train} and removes the names returned by
     * {@code parms.getNonPredictors()}, every protected column, and {@code foldColumn}.
     *
     * @param parms      Infogram parameters supplying the non-predictor and protected column names
     * @param train      frame whose column names are considered
     * @param foldColumn additional column to exclude; may be {@code null}
     * @return the remaining column names, in frame order
     */
    public static String[] extractPredictors(InfogramModel.InfogramParameters parms, Frame train, String foldColumn) {
        List<String> predictors = new ArrayList<>(Arrays.asList(train.names()));
        for (String nonPredictor : parms.getNonPredictors()) {
            predictors.remove(nonPredictor);
        }
        if (parms._protected_columns != null) {
            for (String protectedColumn : parms._protected_columns) {
                predictors.remove(protectedColumn);
            }
        }
        if (foldColumn != null) {
            predictors.remove(foldColumn);
        }
        return predictors.toArray(new String[0]);
    }

    /**
     * Picks the {@code parms._top_n_features} most important predictors.
     *
     * <p>If there are no more eligible predictors than requested, they are returned unchanged.
     * Otherwise a single model of {@code parms._algorithm} is trained on a frame holding all
     * eligible predictors. The first {@code _top_n_features} row headers of its variable-importance
     * table are returned.
     *
     * <p>Side effects: the key of the training frame created here is stored in
     * {@code generatedFrameKeys[0]}. {@code parms._infogram_algorithm_parameters._train} is set to
     * that key. The trained model is tracked with {@link Scope#track_generic}.
     *
     * @param parms              Infogram parameters (algorithm, top-N count, algorithm parameters)
     * @param trainFrame         source training frame
     * @param eligiblePredictors candidate predictor names
     * @param generatedFrameKeys slot 0 receives the key of the frame created here
     * @return the selected predictor names
     */
    public static String[] extractTopKPredictors(InfogramModel.InfogramParameters parms, Frame trainFrame,
                                                 String[] eligiblePredictors, Key<Frame>[] generatedFrameKeys) {
        int topN = parms._top_n_features;
        if (topN >= eligiblePredictors.length) {
            return eligiblePredictors;
        }
        Frame topTrain = extractTrainingFrame(parms, eligiblePredictors, 1.0d, trainFrame);
        generatedFrameKeys[0] = topTrain._key;
        parms._infogram_algorithm_parameters._train = topTrain._key;
        Model.Parameters[] modelParams = buildModelParameters(
                new Frame[]{topTrain}, parms._infogram_algorithm_parameters, 1, parms._algorithm);
        Model model = ModelBuilderHelper.trainModelsParallel(buildModelBuilders(modelParams), 1)[0].get();
        Scope.track_generic(model);
        TwoDimTable varImp = extractVarImp(parms._algorithm, model);
        String[] topPredictors = new String[topN];
        System.arraycopy(varImp.getRowHeaders(), 0, topPredictors, 0, topN);
        return topPredictors;
    }

    /**
     * Returns the index of the first {@code null} slot in {@code keys}, or -1 if every slot is filled.
     */
    public static int findstart(Key<Frame>[] keys) {
        for (int i = 0; i < keys.length; i++) {
            if (keys[i] == null) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Returns the variable-importance table of {@code model}, casting it to the concrete model
     * class that {@code algorithm} produces ({@code AUTO} is treated as GBM).
     *
     * @throws UnsupportedOperationException if {@code algorithm} is not handled here
     */
    public static TwoDimTable extractVarImp(InfogramModel.InfogramParameters.Algorithm algorithm, Model model) {
        switch (algorithm) {
            case AUTO: // AUTO is built as a GBM (see buildModelParameters)
            case gbm:
                return ((GBMModel) model)._output._variable_importances;
            case glm:
                return ((GLMModel) model)._output._variable_importances;
            case deeplearning:
                return ((DeepLearningModel) model)._output._variable_importances;
            case drf:
                return ((DRFModel) model)._output._variable_importances;
            case xgboost:
                return ((XGBoostModel) model)._output._variable_importances;
            default:
                throw new UnsupportedOperationException("Unknown algo: " + algorithm);
        }
    }

    /**
     * Builds a new frame, published to the DKV, that holds the given predictor columns plus every
     * non-predictor column (as reported by {@code parms.getNonPredictors()}) present in the
     * training frame.
     *
     * <p>Side effects on {@code parms}:
     * <ul>
     *   <li>If the training frame has an {@code __internal_cv_weights__} column, it is added as
     *       {@code infogram_internal_cv_weights_}, and {@code _weights_column} is set to that name.</li>
     *   <li>{@code _fold_column} is cleared when it is unset, is missing from the frame, or when
     *       {@code _weights_column} already named one of the CV-weights columns on entry.</li>
     * </ul>
     *
     * @param parms        Infogram parameters (read and updated as described above)
     * @param predictors   predictor columns to copy; may be {@code null}
     * @param dataFraction when below 1, the frame is first split with this fraction and only the
     *                     first split is used
     * @param trainFrame   source training frame
     * @return the new frame (already in the DKV)
     */
    public static Frame extractTrainingFrame(InfogramModel.InfogramParameters parms, String[] predictors,
                                             double dataFraction, Frame trainFrame) {
        if (dataFraction < 1.0d) {
            // Use the requested fraction so that the split agrees with the guard above.
            Key[] splitKeys = new Key[]{Key.make("ig_train_" + trainFrame._key), Key.make("ig_discard" + trainFrame._key)};
            SplitFrame splitter = new SplitFrame(trainFrame, new double[]{dataFraction, 1.0d - dataFraction}, splitKeys);
            splitter.exec().get();
            Key[] destinations = splitter._destination_frames;
            trainFrame = (Frame) DKV.get(destinations[0]).get();
            // NOTE(review): confirm whether DKV.remove also frees the discarded frame's Vecs.
            DKV.remove(destinations[1]);
        }

        Frame result = new Frame(Key.make());
        if (predictors != null) {
            for (String name : predictors) {
                result.add(name, trainFrame.vec(name));
            }
        }

        String[] nonPredictors = parms.getNonPredictors();
        List<String> available = Arrays.asList(trainFrame.names());
        boolean hasCvWeights = available.contains(INTERNAL_CV_WEIGHTS);
        String foldColumn = parms._fold_column;
        // Computed before the loop, because the loop may overwrite parms._weights_column.
        boolean dropFoldColumn = parms._weights_column != null && hasCvWeights
                && (INTERNAL_CV_WEIGHTS.equals(parms._weights_column) || INFOGRAM_CV_WEIGHTS.equals(parms._weights_column));

        for (String column : nonPredictors) {
            boolean isCvWeightsName = INTERNAL_CV_WEIGHTS.equals(column) || INFOGRAM_CV_WEIGHTS.equals(column);
            if (isCvWeightsName && hasCvWeights) {
                result.add(INFOGRAM_CV_WEIGHTS, trainFrame.vec(INTERNAL_CV_WEIGHTS));
                parms._weights_column = INFOGRAM_CV_WEIGHTS;
            } else if (column.equals(foldColumn)) {
                if (available.contains(foldColumn) && !dropFoldColumn) {
                    result.add(column, trainFrame.vec(column));
                }
            } else if (available.contains(column)) {
                result.add(column, trainFrame.vec(column));
            }
        }

        if (foldColumn == null || !available.contains(foldColumn) || dropFoldColumn) {
            parms._fold_column = null;
        }
        DKV.put(result);
        return result;
    }

    /**
     * Produces one description per sub-model: one for each predictor, followed by one for the
     * baseline model.
     *
     * <p>Without sensitive attributes, entry {@code i} reads "Model built missing predictor
     * &lt;name&gt;", and the last entry describes the full model. With sensitive attributes, entry
     * {@code i} describes the model built with the sensitive features plus that predictor, and the
     * last entry describes the sensitive-features-only model.
     *
     * @param predictors          predictor names, one per non-baseline sub-model
     * @param sensitiveAttributes protected columns, or {@code null} if there are none
     * @return an array of length {@code predictors.length + 1} with every slot filled
     */
    public static String[] generateModelDescription(String[] predictors, String[] sensitiveAttributes) {
        int numPredictors = predictors.length;
        String[] descriptions = new String[numPredictors + 1];
        boolean withSensitive = sensitiveAttributes != null;
        String perPredictorPrefix = withSensitive
                ? "Model built with sensitive_features and predictor "
                : "Model built missing predictor ";
        for (int i = 0; i < numPredictors; i++) {
            descriptions[i] = perPredictorPrefix + predictors[i];
        }
        descriptions[numPredictors] = withSensitive
                ? "Model built with sensitive_features only"
                : "Full model built with all predictors";
        return descriptions;
    }

    /**
     * Creates {@code numModels} copies of {@code templateParams}, converted to the parameter class
     * of {@code algorithm} ({@code AUTO} is treated as GBM). Copy {@code i} has no ignored columns
     * and trains on {@code trainingFrames[i]}.
     *
     * @throws UnsupportedOperationException if {@code algorithm} is not handled here
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static Model.Parameters[] buildModelParameters(Frame[] trainingFrames, Model.Parameters templateParams,
                                                          int numModels, InfogramModel.InfogramParameters.Algorithm algorithm) {
        ModelParametersSchemaV3 schema;
        switch (algorithm) {
            case AUTO: // AUTO defaults to GBM
            case gbm:
                schema = new GBMV3.GBMParametersV3();
                break;
            case glm:
                schema = new GLMV3.GLMParametersV3();
                break;
            case deeplearning:
                schema = new DeepLearningV3.DeepLearningParametersV3();
                break;
            case drf:
                schema = new DRFV3.DRFParametersV3();
                break;
            case xgboost:
                schema = new XGBoostV3.XGBoostParametersV3();
                break;
            default:
                throw new UnsupportedOperationException("Unknown algo: " + algorithm);
        }
        Model.Parameters[] result = new Model.Parameters[numModels];
        for (int i = 0; i < numModels; i++) {
            // Round-trip through the schema to get a fresh parameter object of the right class.
            result[i] = (Model.Parameters) schema.fillFromImpl(templateParams).createAndFillImpl();
            result[i]._ignored_columns = null;
            result[i]._train = trainingFrames[i]._key;
        }
        return result;
    }

    /** Creates one {@link ModelBuilder} per parameter object, in the same order. */
    public static ModelBuilder[] buildModelBuilders(Model.Parameters[] params) {
        ModelBuilder[] builders = new ModelBuilder[params.length];
        for (int i = 0; i < params.length; i++) {
            builders[i] = ModelBuilder.make(params[i]);
        }
        return builders;
    }

    /**
     * Builds a six-column frame (published to the DKV) summarizing the per-column results.
     *
     * <p>The column names are {@code column, admissible, admissible_index, X, Y, cmi_raw}. When
     * {@code informationNames} is true, X and Y are {@code total_information} and
     * {@code net_information}; otherwise they are {@code relevance_index} and {@code safety_index}.
     */
    public static Frame generateCMIRelevance(String[] columns, double[] admissible, double[] admissibleIndex,
                                             double[] xValues, double[] yValues, double[] cmiRaw,
                                             boolean informationNames) {
        Vec.VectorGroup group = Vec.VectorGroup.VG_LEN1;
        String[] names = informationNames
                ? new String[]{"column", "admissible", "admissible_index", "total_information", "net_information", "cmi_raw"}
                : new String[]{"column", "admissible", "admissible_index", "relevance_index", "safety_index", "cmi_raw"};
        Key<Frame> key = Key.make();
        Vec[] vecs = new Vec[]{
                Vec.makeVec(columns, group.addVec()),
                Vec.makeVec(admissible, group.addVec()),
                Vec.makeVec(admissibleIndex, group.addVec()),
                Vec.makeVec(xValues, group.addVec()),
                Vec.makeVec(yValues, group.addVec()),
                Vec.makeVec(cmiRaw, group.addVec())
        };
        Frame frame = new Frame(key, names, vecs);
        DKV.put(frame);
        return frame;
    }

    /**
     * Removes the keys in {@code keys} from the DKV, stopping at the first {@code null} slot
     * (the same position {@link #findstart} reports).
     */
    public static void removeFromDKV(Key<Frame>[] keys) {
        for (Key<Frame> key : keys) {
            if (key == null) {
                break;
            }
            DKV.remove(key);
        }
    }

    /**
     * Turns raw per-model scores into normalized CMI values.
     *
     * <p>The last element of {@code scores} is the reference (baseline) score. Every other entry
     * {@code i} is replaced IN PLACE by {@code max(0, last - scores[i])} when {@code fromLast} is
     * true, or by {@code max(0, scores[i] - last)} otherwise. A new array of those clamped values,
     * divided by their maximum, is returned (all zeros if the maximum is 0).
     *
     * @param scores   per-model scores, baseline last; entries {@code 0..length-2} are overwritten
     * @param fromLast selects the direction of the difference, as described above
     * @return normalized values, of length {@code scores.length - 1} (empty for empty input)
     */
    public static double[] calculateFinalCMI(double[] scores, boolean fromLast) {
        if (scores.length == 0) {
            return new double[0];
        }
        int last = scores.length - 1;
        double max = 0.0d;
        for (int i = 0; i < last; i++) {
            double delta = fromLast ? scores[last] - scores[i] : scores[i] - scores[last];
            scores[i] = Math.max(0.0d, delta);
            if (scores[i] > max) {
                max = scores[i];
            }
        }
        double scale = max == 0.0d ? 0.0d : 1.0d / max;
        double[] normalized = new double[last];
        for (int i = 0; i < last; i++) {
            normalized[i] = scores[i] * scale;
        }
        return normalized;
    }

    /**
     * Copies {@code base}, drops the columns in {@code toRemove} (if non-null), appends the columns
     * {@code toAdd} taken from {@code source}, publishes the result to the DKV, and returns it.
     */
    public static Frame subtractAdd2Frame(Frame base, Frame source, String[] toRemove, String[] toAdd) {
        Frame result = new Frame(base);
        if (toRemove != null) {
            for (String name : toRemove) {
                result.remove(name);
            }
        }
        for (String name : toAdd) {
            result.add(name, source.vec(name));
        }
        DKV.put(result);
        return result;
    }

    /**
     * Reads the validation admissible-score frame of {@code model}. Column 5 is stored into
     * {@code values[index]}, and the string column 0 is appended to {@code columnNames}. The frame
     * is then removed.
     *
     * <p>NOTE(review): this assumes the frame has the layout produced by
     * {@link #generateCMIRelevance} (column 0 = "column", column 5 = "cmi_raw"). Confirm this.
     * {@code columnNames} is appended to regardless of {@code index}.
     */
    public static void extractInfogramInfo(InfogramModel model, double[][] values, List<List<String>> columnNames, int index) {
        Frame scores = DKV.getGet(((InfogramModel.InfogramModelOutput) model._output)._admissible_score_key_valid);
        values[index] = vec2array(scores.vec(5));
        columnNames.add(new ArrayList<>(Arrays.asList(strVec2array(scores.vec(0)))));
        scores.remove();
    }

    /** Copies a numeric {@link Vec} into a {@code double[]}; the Vec must fit in an int-indexed array. */
    static double[] vec2array(Vec vec) {
        assert vec.length() < Integer.MAX_VALUE;
        int n = (int) vec.length();
        double[] out = new double[n];
        for (int i = 0; i < n; i++) {
            out[i] = vec.at(i);
        }
        return out;
    }

    /** Copies a string {@link Vec} into a {@code String[]}; missing entries stay {@code null}. */
    static String[] strVec2array(Vec vec) {
        assert vec.length() < Integer.MAX_VALUE;
        int n = (int) vec.length();
        BufferedString scratch = new BufferedString();
        String[] out = new String[n];
        for (int i = 0; i < n; i++) {
            BufferedString value = vec.atStr(scratch, i);
            if (value != null) {
                out[i] = value.toString();
            }
        }
        return out;
    }
}
