package hex.grid;

import hex.Model;
import hex.Model.Parameters;
import hex.ModelCategory;
import hex.ModelContainer;
import hex.ModelExportOption;
import hex.ModelMetrics;
import hex.ScoringInfo;
import hex.faulttolerance.Recoverable;
import hex.faulttolerance.Recovery;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Array;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import water.AutoBuffer;
import water.DKV;
import water.Freezable;
import water.Futures;
import water.H2O;
import water.Iced;
import water.Job;
import water.Key;
import water.Keyed;
import water.Lockable;
import water.api.schemas3.KeyV3;
import water.fvec.Frame;
import water.fvec.persist.PersistUtils;
import water.persist.Persist;
import water.util.ArrayUtils;
import water.util.FileUtils;
import water.util.IcedHashMap;
import water.util.IcedLong;
import water.util.Log;
import water.util.PojoUtils;
import water.util.StringUtils;
import water.util.TwoDimTable;

/* loaded from: input_file:hex/grid/Grid.class */
/**
 * Result of a hyper-parameter grid search.
 *
 * <p>A grid keeps:
 * <ul>
 *   <li>the keys of the models trained so far, indexed by a checksum of their parameters
 *       (see {@link #getModelKey(Model.Parameters)});</li>
 *   <li>the failures recorded while building models, bucketed by model key;</li>
 *   <li>the searched hyper-space (names, values, search criteria, parallelism);</li>
 *   <li>optional scoring history.</li>
 * </ul>
 * A grid can be exported to and imported from binary files, optionally together with its models.
 *
 * @param <MP> type of the model parameters shared by the models of this grid
 */
public class Grid<MP extends Model.Parameters> extends Lockable<Grid<MP>> implements ModelContainer<Model>, Recoverable<Grid<MP>> {

    /** Prototype grid with no key, no parameters and an empty hyper-space. */
    public static final Grid GRID_PROTO = new Grid(null, null, null, new HashMap<String, Object[]>(), null, null, 0);

    // NOTE(review): instance field declaration order is kept exactly as in the original class in case
    // the Iced serializer relies on it for binary compatibility of exported grids - confirm before reordering.

    /** Checksum of a model's parameters -> key of the model built with those parameters. */
    private final IcedHashMap<IcedLong, Key<Model>> _models;
    /** Model key (or {@link #NO_MODEL_FAILURES_KEY}) -> failures recorded for that key. */
    private final IcedHashMap<Key<Model>, SearchFailure> _failures;
    /** Private copy of the base parameters of the search (null only for {@link #GRID_PROTO}). */
    private final MP _params;
    /** Names of the hyper-parameters that were searched over. */
    private final String[] _hyper_names;
    private HyperParameters _hyper_params;
    private int _parallelism;
    private HyperSpaceSearchCriteria _search_criteria;
    /** Strategy used to map hyper-parameter names to fields of {@code MP}. */
    private final PojoUtils.FieldNaming _field_naming_strategy;
    private ScoringInfo[] _scoring_infos;

    /** Bucket for failures that cannot be attributed to a concrete model key. */
    private static final Key<Model> NO_MODEL_FAILURES_KEY = Key.makeUserHidden("GridSearchFailureEmptyModelKey");

    /**
     * Accumulates information about model builds that failed during a grid search, plus
     * search-level warnings.
     *
     * <p>Every recorded failure grows the parameter, raw-parameter, detail and stack-trace arrays
     * together, so index {@code i} of each array describes the same failure.
     *
     * @param <MP> type of the model parameters of the failed builds
     */
    public static final class SearchFailure<MP extends Model.Parameters> extends Iced<SearchFailure> {
        private MP[] _failed_params;
        private String[] _failure_details;
        private String[] _failure_stack_traces;
        private String[][] _failed_raw_params;
        private String[] _warning_details;

        /**
         * Creates an empty failure record.
         *
         * @param paramsClass runtime class used to allocate the failed-parameters array; when null,
         *                    {@link #getFailedParameters()} stays null until failures are merged in
         *                    with the array-based append
         */
        @SuppressWarnings("unchecked")
        private SearchFailure(Class<MP> paramsClass) {
            _failed_params = paramsClass != null ? (MP[]) Array.newInstance(paramsClass, 0) : null;
            _failure_details = new String[0];
            _failed_raw_params = new String[0][];
            _failure_stack_traces = new String[0];
            _warning_details = new String[0];
        }

        /** Returns a copy of {@code array} with {@code element} appended at the end. */
        private static <T> T[] appendOne(T[] array, T element) {
            T[] grown = Arrays.copyOf(array, array.length + 1);
            grown[array.length] = element;
            return grown;
        }

        /**
         * Records a single failed model build.
         *
         * @param params         parameters of the failed model (may be null when unknown)
         * @param rawParams      hyper-parameter values rendered as strings; must not be null
         * @param failureDetails short failure message
         * @param stackTrace     stack trace of the failure, rendered as a string
         */
        public void appendFailedModelParameters(MP params, String[] rawParams, String failureDetails, String stackTrace) {
            assert rawParams != null : "API has to always pass rawParams";
            _failed_params = appendOne(_failed_params, params);
            _failure_details = appendOne(_failure_details, failureDetails);
            _failed_raw_params = appendOne(_failed_raw_params, rawParams);
            _failure_stack_traces = appendOne(_failure_stack_traces, stackTrace);
        }

        /**
         * Adds a warning about the given hyper-parameter when it is part of the searched
         * hyper-parameters. Only {@code "alpha"} currently has a warning. Each distinct warning is
         * logged and stored at most once, even when this method is called for every failure.
         *
         * @param hyperNames names of the searched hyper-parameters (may be null)
         * @param hyperName  hyper-parameter to check
         */
        public void appendWarningMessage(String[] hyperNames, String hyperName) {
            if (hyperNames == null || !Arrays.asList(hyperNames).contains(hyperName)) {
                return;
            }
            String warning = null;
            if ("alpha".equals(hyperName)) {
                warning = "Adding alpha array to hyperparameter runs slower with gridsearch. This is due to the fact that the algo has to run initialization for every alpha value. Setting the alpha array as a model parameter will skip the initialization and run faster overall.";
            }
            // Skip warnings that were already recorded so repeated failures do not spam the log.
            if (warning == null || Arrays.asList(_warning_details).contains(warning)) {
                return;
            }
            Log.warn(warning);
            _warning_details = appendOne(_warning_details, warning);
        }

        /**
         * Records a batch of failures; the four arrays must be parallel (same length, same order).
         *
         * @param params         failed parameters
         * @param rawParams      raw hyper-parameter values per failure; must not be null
         * @param failureDetails failure messages
         * @param stackTraces    failure stack traces
         */
        @SuppressWarnings("unchecked")
        public void appendFailedModelParameters(MP[] params, String[][] rawParams, String[] failureDetails, String[] stackTraces) {
            assert rawParams != null : "API has to always pass rawParams";
            _failed_params = (MP[]) ArrayUtils.append(_failed_params, params);
            _failed_raw_params = (String[][]) ArrayUtils.append(_failed_raw_params, rawParams);
            _failure_details = ArrayUtils.append(_failure_details, failureDetails);
            _failure_stack_traces = ArrayUtils.append(_failure_stack_traces, stackTraces);
        }

        /**
         * Records a failure that has no associated model parameters.
         *
         * @param rawParams raw hyper-parameter values; must not be null
         * @param e         the failure
         */
        void appendFailedModelParameters(Object[] rawParams, Exception e) {
            assert rawParams != null : "Raw parameters should be always != null !";
            appendFailedModelParameters((MP) null, ArrayUtils.toString(rawParams), e.getMessage(), StringUtils.toString(e));
        }

        /** @return the failed parameters (entries may be null), or null if no params class was known */
        public Model.Parameters[] getFailedParameters() {
            return _failed_params;
        }

        /** @return failure messages, parallel to {@link #getFailedParameters()} */
        public String[] getFailureDetails() {
            return _failure_details;
        }

        /** @return search-level warnings */
        public String[] getWarningDetails() {
            return _warning_details;
        }

        /** @return failure stack traces, parallel to {@link #getFailureDetails()} */
        public String[] getFailureStackTraces() {
            return _failure_stack_traces;
        }

        /** @return raw hyper-parameter values of each failure, parallel to {@link #getFailureDetails()} */
        public String[][] getFailedRawParameters() {
            return _failed_raw_params;
        }

        /**
         * @return number of recorded failures. This is based on the failure-detail array, which is
         *         never null and grows together with the other failure arrays.
         */
        public int getFailureCount() {
            return _failure_details.length;
        }
    }

    /**
     * Creates a grid.
     *
     * @param key                 DKV key of the grid
     * @param mp                  base parameters of the search; a clone is stored (may be null)
     * @param strArr              names of the searched hyper-parameters
     * @param map                 hyper-parameter name -> candidate values
     * @param searchCriteria      search criteria
     * @param fieldNaming         strategy mapping hyper-parameter names to parameter fields
     * @param parallelism         requested search parallelism
     */
    @SuppressWarnings("unchecked")
    protected Grid(Key key, MP mp, String[] strArr, Map<String, Object[]> map, HyperSpaceSearchCriteria searchCriteria, PojoUtils.FieldNaming fieldNaming, int parallelism) {
        super(key);
        _models = new IcedHashMap<>();
        _scoring_infos = null;
        // Keep a private copy so later changes to the caller's parameters do not leak into the grid.
        _params = mp != null ? (MP) mp.clone() : null;
        _hyper_names = strArr;
        _failures = new IcedHashMap<>();
        _field_naming_strategy = fieldNaming;
        update(map, searchCriteria, parallelism);
    }

    /**
     * Creates a grid describing the hyper-space of the given walker.
     *
     * @param key         DKV key of the grid
     * @param walker      walker providing parameters, hyper-parameters and search criteria
     * @param parallelism requested search parallelism
     */
    public Grid(Key key, HyperSpaceWalker<MP, ?> walker, int parallelism) {
        this(key, walker.getParams(), walker.getAllHyperParamNames(), walker.getHyperParams(), walker.search_criteria(), walker.getParametersBuilderFactory().getFieldNamingStrategy(), parallelism);
    }

    /**
     * Replaces the hyper-parameter values, search criteria and parallelism of this grid.
     *
     * @param map                 hyper-parameter name -> candidate values
     * @param searchCriteria      new search criteria
     * @param parallelism         new parallelism
     */
    public void update(Map<String, Object[]> map, HyperSpaceSearchCriteria searchCriteria, int parallelism) {
        _hyper_params = new HyperParameters(map);
        _search_criteria = searchCriteria;
        _parallelism = parallelism;
    }

    /** @return hyper-parameter name -> candidate values */
    public Map<String, Object[]> getHyperParams() {
        return _hyper_params.getValues();
    }

    public HyperSpaceSearchCriteria getSearchCriteria() {
        return _search_criteria;
    }

    public int getParallelism() {
        return _parallelism;
    }

    /** @return algorithm name of the base parameters */
    public String getModelName() {
        return _params.algoName();
    }

    public ScoringInfo[] getScoringInfos() {
        return _scoring_infos;
    }

    public void setScoringInfos(ScoringInfo[] scoringInfos) {
        _scoring_infos = scoringInfos;
    }

    /** @return training frame of the base parameters */
    public Frame getTrainingFrame() {
        return _params.train();
    }

    /**
     * Looks up the model built with the given parameters.
     *
     * @param mp model parameters
     * @return the model, or null if no model is registered for these parameters
     */
    public Model getModel(MP mp) {
        Key<Model> modelKey = getModelKey(mp);
        return modelKey != null ? modelKey.get() : null;
    }

    /**
     * @param mp model parameters
     * @return key of the model registered for the checksum of {@code mp} (ignoring
     *         {@code GridSearch.IGNORED_FIELDS_PARAM_HASH}), or null
     */
    public Key<Model> getModelKey(MP mp) {
        return getModelKey(mp.checksum(GridSearch.IGNORED_FIELDS_PARAM_HASH));
    }

    /** @return key of the model registered under the given parameter checksum, or null */
    public Key<Model> getModelKey(long checksum) {
        return _models.get(IcedLong.valueOf(checksum));
    }

    /**
     * Registers a model under a parameter checksum.
     *
     * @return the key previously registered under this checksum, or null
     */
    public synchronized Key<Model> putModel(long checksum, Key<Model> modelKey) {
        return _models.put(IcedLong.valueOf(checksum), modelKey);
    }

    /**
     * Records a failure in the bucket of {@code modelKey} (or the shared no-model bucket), creating
     * the bucket on first use.
     */
    private void appendFailedModelParameters(Key<Model> modelKey, MP params, String[] rawParams, Throwable t) {
        String failureDetails = isJobCanceled(t) ? "Job Canceled" : t.getMessage();
        String stackTrace = StringUtils.toString(t);
        Key<Model> bucketKey = modelKey != null ? modelKey : NO_MODEL_FAILURES_KEY;
        SearchFailure failure = _failures.get(bucketKey);
        if (failure == null) {
            failure = new SearchFailure(_params.getClass());
            _failures.put(bucketKey, failure);
        }
        failure.appendFailedModelParameters(params, rawParams, failureDetails, stackTrace);
        failure.appendWarningMessage(_hyper_names, "alpha");
    }

    /**
     * @param t a throwable (may be null)
     * @return true if {@code t} or any exception in its cause chain is a {@code Job.JobCancelledException}
     */
    public static boolean isJobCanceled(Throwable t) {
        for (Throwable cause = t; cause != null; cause = cause.getCause()) {
            if (cause instanceof Job.JobCancelledException) {
                return true;
            }
        }
        return false;
    }

    /**
     * Records a failed model build; the raw parameters are taken from the hyper-parameter fields of
     * {@code params}.
     *
     * @param modelKey key of the failed model (may be null)
     * @param params   parameters of the failed model; must not be null
     * @param t        the failure
     */
    public void appendFailedModelParameters(Key<Model> modelKey, MP params, Throwable t) {
        assert params != null : "Model parameters should be always != null !";
        appendFailedModelParameters(modelKey, params, ArrayUtils.toString(getHyperValues(params)), t);
    }

    /**
     * Records a failure for which only raw hyper-parameter values are available.
     *
     * @param modelKey  key of the failed model (may be null)
     * @param rawParams raw hyper-parameter values; must not be null
     * @param e         the failure
     */
    public void appendFailedModelParameters(Key<Model> modelKey, Object[] rawParams, Exception e) {
        assert rawParams != null : "Raw parameters should be always != null !";
        appendFailedModelParameters(modelKey, null, ArrayUtils.toString(rawParams), e);
    }

    /** @return keys of all models in this grid, sorted */
    @Override // hex.ModelContainer
    @SuppressWarnings("unchecked")
    public Key<Model>[] getModelKeys() {
        Key<Model>[] keys = _models.values().toArray(new Key[_models.size()]);
        Arrays.sort(keys);
        return keys;
    }

    /**
     * @return the models of this grid in map iteration order; an entry is null when its key is null
     *         or its model can no longer be fetched
     */
    @Override // hex.ModelContainer
    public Model[] getModels() {
        Collection<Key<Model>> keys = _models.values();
        Model[] models = new Model[keys.size()];
        int i = 0;
        for (Key<Model> key : keys) {
            models[i++] = key != null ? key.get() : null;
        }
        return models;
    }

    @Override // hex.ModelContainer
    public int getModelCount() {
        return _models.size();
    }

    /**
     * @return a new {@link SearchFailure} combining the failures of every bucket, plus the alpha
     *         warning when "alpha" was searched over
     */
    public SearchFailure getFailures() {
        SearchFailure all = new SearchFailure(_params != null ? _params.getClass() : null);
        for (SearchFailure bucket : _failures.values()) {
            all.appendFailedModelParameters(bucket._failed_params, bucket._failed_raw_params, bucket._failure_details, bucket._failure_stack_traces);
        }
        all.appendWarningMessage(_hyper_names, "alpha");
        return all;
    }

    /** @return total number of failures across all buckets */
    public int countTotalFailures() {
        int total = 0;
        for (SearchFailure bucket : _failures.values()) {
            total += bucket.getFailureCount();
        }
        return total;
    }

    /** Drops failures that are not attributed to any model key. */
    public void clearNonRelatedFailures() {
        _failures.remove(NO_MODEL_FAILURES_KEY);
    }

    /**
     * @param mp model parameters
     * @return values of the searched hyper-parameter fields of {@code mp}, in {@link #getHyperNames()} order
     */
    public Object[] getHyperValues(MP mp) {
        Object[] values = new Object[_hyper_names.length];
        for (int i = 0; i < _hyper_names.length; i++) {
            values[i] = PojoUtils.getFieldValue(mp, _hyper_names[i], _field_naming_strategy);
        }
        return values;
    }

    public String[] getHyperNames() {
        return _hyper_names;
    }

    /**
     * Removes this grid; when {@code cascade} is true, every model referenced by the grid is
     * removed as well. The model map is cleared in both cases.
     */
    @Override // water.Keyed
    public Futures remove_impl(Futures futures, boolean cascade) {
        if (cascade) {
            for (Key<Model> modelKey : _models.values()) {
                Keyed.remove(modelKey, futures, true);
            }
        }
        _models.clear();
        return super.remove_impl(futures, cascade);
    }

    /** Writes the keys of all models, then delegates to the superclass. */
    @Override // water.Keyed
    public AutoBuffer writeAll_impl(AutoBuffer autoBuffer) {
        for (Key<Model> modelKey : _models.values()) {
            autoBuffer.putKey(modelKey);
        }
        return super.writeAll_impl(autoBuffer);
    }

    /** Not supported for grids. */
    @Override // water.Keyed
    public Keyed readAll_impl(AutoBuffer autoBuffer, Futures futures) {
        throw H2O.unimpl();
    }

    /** Not supported for grids. */
    @Override // water.Keyed
    protected long checksum_impl() {
        throw H2O.unimpl();
    }

    @Override // water.Keyed
    public Class<KeyV3.GridKeyV3> makeSchema() {
        return KeyV3.GridKeyV3.class;
    }

    /**
     * Builds a summary table with one row per given model key. The columns are the searched
     * hyper-parameters, the model id and, if {@code sortBy} is given, that metric.
     *
     * @param modelKeys  models to summarize, in row order
     * @param sortBy     name of the metric column to add (may be null)
     * @param decreasing whether the keys are ordered by decreasing metric; only used in the description
     * @return the table, or null if there are no hyper-parameter names or no model keys
     */
    public TwoDimTable createSummaryTable(Key<Model>[] modelKeys, String sortBy, boolean decreasing) {
        if (_hyper_names == null || modelKeys == null || modelKeys.length == 0) {
            return null;
        }
        final int nHyper = _hyper_names.length;
        final int extraCols = sortBy != null ? 2 : 1;
        String[] colTypes = new String[nHyper + extraCols];
        String[] colFormats = new String[nHyper + extraCols];
        Arrays.fill(colTypes, "string");
        Arrays.fill(colFormats, "%s");
        // Infer each hyper-parameter column type from its first candidate value.
        Map<String, Object[]> hyperValues = _hyper_params.getValues();
        for (int c = 0; c < nHyper; c++) {
            Object[] candidates = hyperValues.get(_hyper_names[c]);
            if (candidates == null || candidates.length == 0) {
                continue;
            }
            Object first = candidates[0];
            if (first instanceof Double || first instanceof Float) {
                colTypes[c] = "double";
                colFormats[c] = "%.5f";
            } else if (first instanceof Integer || first instanceof Long) {
                colTypes[c] = "long";
                colFormats[c] = "%d";
            }
        }
        if (sortBy != null) {
            colTypes[colTypes.length - 1] = "double";
            colFormats[colFormats.length - 1] = "%.5f";
        }
        String[] colNames = Arrays.copyOf(_hyper_names, nHyper + extraCols);
        colNames[nHyper] = "model_ids";
        if (sortBy != null) {
            colNames[nHyper + 1] = sortBy;
        }
        // One row per requested key; sizing by the grid's model count broke when the counts differed.
        TwoDimTable table = new TwoDimTable("Hyper-Parameter Search Summary",
                sortBy != null ? "ordered by " + (decreasing ? "decreasing " : "increasing ") + sortBy : null,
                new String[modelKeys.length], colNames, colTypes, colFormats, "");
        for (int row = 0; row < modelKeys.length; row++) {
            Key<Model> key = modelKeys[row];
            Model.Parameters parms = ((Model) DKV.getGet(key))._parms;
            for (int c = 0; c < nHyper; c++) {
                Object value = PojoUtils.getFieldValue(parms, _hyper_names[c], _field_naming_strategy);
                table.set(row, c, unwrapSingletonArray(value));
            }
            table.set(row, nHyper, key.toString());
            if (sortBy != null) {
                table.set(row, nHyper + 1, Double.valueOf(ModelMetrics.getMetricFromModel(key, sortBy)));
            }
        }
        Log.info(table);
        return table;
    }

    /**
     * Converts a one-element primitive or object array into its single (boxed) element. Any other
     * value, including null and arrays of other lengths, is returned unchanged.
     */
    private static Object unwrapSingletonArray(Object value) {
        if (value == null || !value.getClass().isArray()) {
            return value;
        }
        if (value instanceof float[] && ((float[]) value).length == 1) {
            return Float.valueOf(((float[]) value)[0]);
        }
        if (value instanceof double[] && ((double[]) value).length == 1) {
            return Double.valueOf(((double[]) value)[0]);
        }
        if (value instanceof int[] && ((int[]) value).length == 1) {
            return Integer.valueOf(((int[]) value)[0]);
        }
        if (value instanceof long[] && ((long[]) value).length == 1) {
            return Long.valueOf(((long[]) value)[0]);
        }
        if (value instanceof Object[] && ((Object[]) value).length == 1) {
            return ((Object[]) value)[0];
        }
        return value;
    }

    /**
     * Builds the scoring-history table of this grid.
     *
     * <p>The model category is taken from the first model in the grid; Binomial is used when the
     * grid is empty or that model has been removed. The validation, cross-validation and
     * autoencoder flags come from the first scoring info.
     */
    public TwoDimTable createScoringHistoryTable() {
        Iterator<Key<Model>> keys = _models.values().iterator();
        if (!keys.hasNext()) {
            return ScoringInfo.createScoringHistoryTable(_scoring_infos, false, false, ModelCategory.Binomial, false);
        }
        Key<Model> key = keys.next();
        Model model = key.get();
        if (model == null) {
            Log.warn("Cannot create grid scoring history table; Model has been removed: " + key);
            return ScoringInfo.createScoringHistoryTable(_scoring_infos, false, false, ModelCategory.Binomial, false);
        }
        ScoringInfo first = (_scoring_infos != null && _scoring_infos.length > 0) ? _scoring_infos[0] : null;
        boolean validation = first != null && first.validation;
        boolean crossValidation = first != null && first.cross_validation;
        boolean autoencoder = first != null && first.is_autoencoder;
        return ScoringInfo.createScoringHistoryTable(_scoring_infos, validation, crossValidation, model._output.getModelCategory(), autoencoder);
    }

    /**
     * Writes this grid to {@code location/<grid key>} and, if requested, every model to
     * {@code location/<model key>}.
     *
     * @param location      target directory
     * @param exportModels  whether to also export the models
     * @param options       model export options
     * @return paths of all written files, grid first
     */
    @Override // hex.faulttolerance.Recoverable
    public List<String> exportBinary(String location, boolean exportModels, ModelExportOption... options) {
        Objects.requireNonNull(location);
        assert _key != null;
        String gridPath = location + "/" + _key;
        PersistUtils.write(FileUtils.getURI(gridPath), ab -> ab.put(this));
        List<String> written = new ArrayList<>();
        written.add(gridPath);
        if (exportModels) {
            exportModelsBinary(written, location, options);
        }
        return written;
    }

    /**
     * Exports every model of the grid to {@code location/<model key>}, appending each path to
     * {@code exportedFiles}.
     *
     * @throws IllegalStateException if a model referenced by the grid can no longer be fetched
     * @throws RuntimeException      wrapping the {@link IOException} of a failed model export
     */
    private void exportModelsBinary(List<String> exportedFiles, String location, ModelExportOption... options) {
        Objects.requireNonNull(location);
        // Same iteration order as getModels(), but keeps the key for error reporting.
        for (Key<Model> modelKey : _models.values()) {
            Model model = modelKey != null ? modelKey.get() : null;
            if (model == null) {
                throw new IllegalStateException("Failed to write grid model " + modelKey + ": model has been removed");
            }
            try {
                String modelPath = location + "/" + model._key.toString();
                exportedFiles.add(modelPath);
                model.exportBinaryModel(modelPath, true, options);
            } catch (IOException e) {
                throw new RuntimeException("Failed to write grid model " + model._key.toString(), e);
            }
        }
    }

    /**
     * Loads a grid and its models from binary files and puts the grid into the DKV.
     *
     * @param gridPath       path of the exported grid file; models are expected next to it
     * @param withReferences whether to also load the referenced objects saved with the grid
     * @return the imported grid
     * @throws IllegalArgumentException if the grid file is missing, or references were requested
     *                                  but not saved
     */
    public static Grid importBinary(String gridPath, boolean withReferences) {
        URI gridUri = FileUtils.getURI(gridPath);
        if (!PersistUtils.exists(gridUri)) {
            throw new IllegalArgumentException("Grid file not found " + gridUri);
        }
        Persist persist = H2O.getPM().getPersistForURI(gridUri);
        String gridDirectory = persist.getParent(gridUri.toString());
        Grid grid = readGridBinary(gridUri, persist);
        Recovery recovery = new Recovery(gridDirectory);
        URI referencesUri = FileUtils.getURI(recovery.referencesMetaFile(grid));
        if (withReferences && !PersistUtils.exists(referencesUri)) {
            throw new IllegalArgumentException("Requested to load with references, but the grid was saved without references.");
        }
        grid.importModelsBinary(gridDirectory);
        if (withReferences) {
            recovery.loadReferences(grid);
        }
        DKV.put(grid);
        return grid;
    }

    /**
     * Reads a grid object from the given file. The stream is closed on every path, including when
     * the file does not contain a grid.
     *
     * @throws IllegalArgumentException if the file does not contain a {@code Grid}
     * @throws IllegalStateException    wrapping an {@link IOException} from opening or closing the file
     */
    private static Grid readGridBinary(URI uri, Persist persist) {
        try (InputStream in = persist.open(uri.toString())) {
            Freezable freezable = new AutoBuffer(in).get();
            if (!(freezable instanceof Grid)) {
                throw new IllegalArgumentException(String.format("Given file '%s' is not a Grid", uri.toString()));
            }
            return (Grid) freezable;
        } catch (IOException e) {
            throw new IllegalStateException("Failed to open grid file.", e);
        }
    }

    /**
     * Imports every model of this grid from {@code directory/<model key>}.
     *
     * @throws IllegalStateException if a model cannot be loaded
     */
    private void importModelsBinary(String directory) {
        for (Key<Model> modelKey : _models.values()) {
            String modelPath = directory + "/" + modelKey.toString();
            try {
                Model imported = Model.importBinaryModel(modelPath);
                assert imported != null;
            } catch (IOException e) {
                throw new IllegalStateException("Unable to load model from " + modelPath, e);
            }
        }
    }

    /** @return keys the base parameters depend on */
    @Override // hex.faulttolerance.Recoverable
    public Set<Key<?>> getDependentKeys() {
        return _params.getDependentKeys();
    }

    /** @return the grid's private copy of the base parameters (null only for the prototype) */
    public MP getParams() {
        return _params;
    }
}
