package hex.grid;

import hex.Model;
import hex.Model.Parameters;
import hex.ModelParametersBuilderFactory;
import hex.ScoreKeeper;
import hex.ScoringInfo;
import hex.grid.HyperSpaceSearchCriteria;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.Set;
import water.exceptions.H2OIllegalArgumentException;
import water.util.PojoUtils;

/* loaded from: input_file:hex/grid/HyperSpaceWalker.class */
public interface HyperSpaceWalker<MP extends Model.Parameters, C extends HyperSpaceSearchCriteria> {

    /* loaded from: input_file:hex/grid/HyperSpaceWalker$BaseWalker.class */
    /**
     * Shared state and helpers for the concrete hyperspace walkers: the common model
     * parameters, the hyperparameter map, the search criteria and utilities that turn a
     * vector of per-hyperparameter indices into concrete model parameters.
     *
     * @param <MP> model parameters type
     * @param <C> search criteria type
     */
    public static abstract class BaseWalker<MP extends Model.Parameters, C extends HyperSpaceSearchCriteria> implements HyperSpaceWalker<MP, C> {
        protected final C _search_criteria;
        final transient ModelParametersBuilderFactory<MP> _paramsBuilderFactory;
        final MP _params;
        protected final Map<String, Object[]> _hyperParams;
        protected final String[] _hyperParamNames;
        /**
         * Number of points in the hyperspace. Assigned in the constructor after
         * {@code _hyperParams}: a field initializer would run before the constructor body
         * and dereference a still-null map.
         */
        protected final long _maxHyperSpaceSize;

        /* loaded from: input_file:hex/grid/HyperSpaceWalker$BaseWalker$WalkerFactory.class */
        public static class WalkerFactory<MP extends Model.Parameters, C extends HyperSpaceSearchCriteria> {
            /**
             * Creates the walker that implements {@code c.strategy()}.
             *
             * @throws H2OIllegalArgumentException if the strategy is neither Cartesian nor RandomDiscrete
             */
            public static <MP extends Model.Parameters, C extends HyperSpaceSearchCriteria> HyperSpaceWalker create(MP mp, Map<String, Object[]> map, ModelParametersBuilderFactory<MP> modelParametersBuilderFactory, C c) {
                HyperSpaceSearchCriteria.Strategy strategy = c.strategy();
                if (strategy == HyperSpaceSearchCriteria.Strategy.Cartesian) {
                    return new CartesianWalker<MP>(mp, map, modelParametersBuilderFactory, (HyperSpaceSearchCriteria.CartesianSearchCriteria) c);
                }
                if (strategy == HyperSpaceSearchCriteria.Strategy.RandomDiscrete) {
                    return new RandomDiscreteValueWalker<MP>(mp, map, modelParametersBuilderFactory, (HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria) c);
                }
                throw new H2OIllegalArgumentException("strategy", "GridSearch", strategy);
            }
        }

        @Override // hex.grid.HyperSpaceWalker
        public C search_criteria() {
            return this._search_criteria;
        }

        /** Base walkers never stop the search early. */
        @Override // hex.grid.HyperSpaceWalker
        public boolean stopEarly(Model model, ScoringInfo[] scoringInfoArr) {
            return false;
        }

        /**
         * @param mp common model parameters every combination starts from
         * @param map hyperparameter name to candidate values
         * @param modelParametersBuilderFactory factory used to apply hyperparameter values
         * @param c search criteria
         * @throws H2OIllegalArgumentException if a value list is empty/null, a hyperparameter is
         *     also set to a non-default value in {@code mp}, or default parameters cannot be created
         */
        public BaseWalker(MP mp, Map<String, Object[]> map, ModelParametersBuilderFactory<MP> modelParametersBuilderFactory, C c) {
            this._params = mp;
            this._hyperParams = map;
            this._paramsBuilderFactory = modelParametersBuilderFactory;
            this._hyperParamNames = map.keySet().toArray(new String[0]);
            this._search_criteria = c;
            this._maxHyperSpaceSize = computeMaxSizeOfHyperSpace();
            checkHyperParamsAgainstParams(mp, map);
        }

        /**
         * Rejects empty value lists and hyperparameters whose field is also set, to a
         * non-default value, on the common model parameters.
         */
        private static void checkHyperParamsAgainstParams(Model.Parameters params, Map<String, Object[]> hyperParams) {
            final Model.Parameters defaults;
            // Only the reflective instantiation is guarded: the validation errors below must
            // reach the caller with their own message instead of being rewrapped.
            try {
                defaults = params.getClass().newInstance();
            } catch (InstantiationException | IllegalAccessException e) {
                throw new H2OIllegalArgumentException("Failed to instantiate a new Model.Parameters object to get the default values.");
            }
            for (Map.Entry<String, Object[]> entry : hyperParams.entrySet()) {
                String name = entry.getKey();
                Object[] values = entry.getValue();
                if (values == null || values.length == 0) {
                    throw new H2OIllegalArgumentException("Grid search hyperparameter value list is empty for hyperparameter: " + name);
                }
                // seed is exempt from the ambiguity check (presumably its default is not a stable value — confirm).
                if ("seed".equals(name) || "_seed".equals(name)) {
                    continue;
                }
                String field = name.startsWith("_") ? name : "_" + name;
                Object defaultValue = PojoUtils.getFieldValue(defaults, field, PojoUtils.FieldNaming.CONSISTENT);
                Object actualValue = PojoUtils.getFieldValue(params, field, PojoUtils.FieldNaming.CONSISTENT);
                if (defaultValue == null) {
                    if (actualValue != null) {
                        throw ambiguousParameter(name);
                    }
                } else if (actualValue != null) {
                    boolean unchanged = defaultValue.getClass().isArray()
                            ? PojoUtils.arraysEquals(defaultValue, actualValue)
                            : defaultValue.equals(actualValue);
                    if (!unchanged) {
                        throw ambiguousParameter(name);
                    }
                }
            }
        }

        /** Builds the error for a hyperparameter that is set both on the params and in the map. */
        private static H2OIllegalArgumentException ambiguousParameter(String name) {
            return new H2OIllegalArgumentException("Grid search model parameter '" + name + "' is set in both the model parameters and in the hyperparameters map.  This is ambiguous; set it in one place or the other, not both.");
        }

        @Override // hex.grid.HyperSpaceWalker
        public String[] getHyperParamNames() {
            return this._hyperParamNames;
        }

        @Override // hex.grid.HyperSpaceWalker
        public long getMaxHyperSpaceSize() {
            return this._maxHyperSpaceSize;
        }

        @Override // hex.grid.HyperSpaceWalker
        public MP getParams() {
            return this._params;
        }

        @Override // hex.grid.HyperSpaceWalker
        public ModelParametersBuilderFactory<MP> getParametersBuilderFactory() {
            return this._paramsBuilderFactory;
        }

        /**
         * Applies one hyperparameter combination to {@code mp} through the builder factory.
         *
         * @param mp parameters to modify (callers pass a clone of the common parameters)
         * @param objArr values aligned with {@code _hyperParamNames}
         * @return the built parameters
         */
        protected MP getModelParams(MP mp, Object[] objArr) {
            ModelParametersBuilderFactory.ModelParametersBuilder<MP> builder = this._paramsBuilderFactory.get(mp);
            for (int i = 0; i < this._hyperParamNames.length; i++) {
                String name = this._hyperParamNames[i];
                // "valid" is accepted as an alias for the validation_frame parameter.
                if (name.equals("valid")) {
                    name = "validation_frame";
                }
                builder.set(name, objArr[i]);
            }
            return builder.build();
        }

        /**
         * Returns the number of points in the hyperspace: the product of the lengths of all
         * non-null value lists (null lists are skipped).
         */
        protected long computeMaxSizeOfHyperSpace() {
            long size = 1;
            for (Object[] values : this._hyperParams.values()) {
                if (values != null) {
                    size *= values.length;
                }
            }
            return size;
        }

        /**
         * Fills {@code objArr} with the hyperparameter values selected by {@code iArr}
         * (one index per entry of {@code _hyperParamNames}) and returns it.
         */
        protected Object[] hypers(int[] iArr, Object[] objArr) {
            for (int i = 0; i < iArr.length; i++) {
                objArr[i] = this._hyperParams.get(this._hyperParamNames[i])[iArr[i]];
            }
            return objArr;
        }

        /**
         * Hashes an index vector as its mixed-radix encoding (first hyperparameter is the
         * least significant digit). Int arithmetic wraps, so the hash is collision-free
         * whenever the hyperspace has at most 2^32 points; the previous
         * {@code deepHashCode(index * length)} scheme collided even on small grids
         * (e.g. two hyperparameters with 32 values each).
         */
        protected int integerHash(int[] iArr) {
            int hash = 0;
            int radix = 1;
            for (int i = 0; i < iArr.length; i++) {
                hash += iArr[i] * radix;
                radix *= this._hyperParams.get(this._hyperParamNames[i]).length;
            }
            return hash;
        }
    }

    /* loaded from: input_file:hex/grid/HyperSpaceWalker$CartesianWalker.class */
    /**
     * Walker that enumerates the full Cartesian product of the hyperparameter value lists.
     * Indices advance like an odometer with the first hyperparameter varying fastest.
     *
     * @param <MP> model parameters type
     */
    public static class CartesianWalker<MP extends Model.Parameters> extends BaseWalker<MP, HyperSpaceSearchCriteria.CartesianSearchCriteria> {
        public CartesianWalker(MP mp, Map<String, Object[]> map, ModelParametersBuilderFactory<MP> modelParametersBuilderFactory, HyperSpaceSearchCriteria.CartesianSearchCriteria cartesianSearchCriteria) {
            super(mp, map, modelParametersBuilderFactory, cartesianSearchCriteria);
        }

        /** Returns a fresh iterator positioned before the first combination. */
        @Override // hex.grid.HyperSpaceWalker
        public HyperSpaceIterator<MP> iterator() {
            return new HyperSpaceIterator<MP>() {
                /** Indices of the current combination; null until the first successful call. */
                private int[] _currentHyperparamIndices = null;

                /**
                 * Advances to the next combination and returns a clone of the common
                 * parameters with it applied.
                 *
                 * @throws NoSuchElementException once every combination has been produced;
                 *     the current position is left unchanged so hasNext() stays false
                 */
                @Override // hex.grid.HyperSpaceWalker.HyperSpaceIterator
                public MP nextModelParameters(Model model) {
                    int[] next = _currentHyperparamIndices == null
                            ? new int[_hyperParamNames.length]
                            : nextModelIndices(_currentHyperparamIndices);
                    if (next == null) {
                        throw new NoSuchElementException("No more elements to explore in hyper-space!");
                    }
                    _currentHyperparamIndices = next;

                    @SuppressWarnings("unchecked")
                    MP common = (MP) _params.m76clone();
                    MP params = getModelParams(common, hypers(next, new Object[_hyperParamNames.length]));

                    // If the criteria report a RandomDiscrete strategy, cap the per-model runtime at
                    // the remaining search budget, floored to whole seconds.
                    if (_search_criteria != null && _search_criteria.strategy() == HyperSpaceSearchCriteria.Strategy.RandomDiscrete) {
                        double timeLeft = time_remaining_secs();
                        if (timeLeft > 0.0d) {
                            double budget = params._max_runtime_secs > 0.0d
                                    ? StrictMath.min(params._max_runtime_secs, timeLeft)
                                    : timeLeft;
                            params._max_runtime_secs = (long) StrictMath.floor(budget);
                        }
                    }
                    return params;
                }

                /** True before the first call, then while any index can still be incremented. */
                @Override // hex.grid.HyperSpaceWalker.HyperSpaceIterator
                public boolean hasNext(Model model) {
                    if (_currentHyperparamIndices == null) {
                        return true;
                    }
                    for (int i = 0; i < _currentHyperparamIndices.length; i++) {
                        if (_currentHyperparamIndices[i] + 1 < _hyperParams.get(_hyperParamNames[i]).length) {
                            return true;
                        }
                    }
                    return false;
                }

                @Override // hex.grid.HyperSpaceWalker.HyperSpaceIterator
                public void reset() {
                    _currentHyperparamIndices = null;
                }

                /** Cartesian search has no time budget. */
                @Override // hex.grid.HyperSpaceWalker.HyperSpaceIterator
                public double time_remaining_secs() {
                    return Double.MAX_VALUE;
                }

                /** Cartesian search has no time budget. */
                @Override // hex.grid.HyperSpaceWalker.HyperSpaceIterator
                public double max_runtime_secs() {
                    return Double.MAX_VALUE;
                }

                /** The hyperspace size, saturated at Integer.MAX_VALUE. */
                @Override // hex.grid.HyperSpaceWalker.HyperSpaceIterator
                public int max_models() {
                    long size = _maxHyperSpaceSize;
                    return size > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) size;
                }

                /** Failures do not affect Cartesian enumeration. */
                @Override // hex.grid.HyperSpaceWalker.HyperSpaceIterator
                public void modelFailed(Model model) {
                }

                @Override // hex.grid.HyperSpaceWalker.HyperSpaceIterator
                public Object[] getCurrentRawParameters() {
                    return hypers(_currentHyperparamIndices, new Object[_hyperParamNames.length]);
                }
            };
        }

        /**
         * Advances {@code iArr} in place to the next combination (odometer order, first
         * index fastest).
         *
         * @return {@code iArr} after the increment, or {@code null} (leaving {@code iArr}
         *     untouched) when it already is the last combination
         */
        /* JADX INFO: Access modifiers changed from: private */
        public int[] nextModelIndices(int[] iArr) {
            int digit = 0;
            while (digit < iArr.length && iArr[digit] + 1 >= this._hyperParams.get(this._hyperParamNames[digit]).length) {
                digit++;
            }
            if (digit == iArr.length) {
                return null;
            }
            Arrays.fill(iArr, 0, digit, 0);
            iArr[digit]++;
            return iArr;
        }
    }

    /* loaded from: input_file:hex/grid/HyperSpaceWalker$HyperSpaceIterator.class */
    /**
     * Stateful cursor over the hyperparameter combinations of a {@link HyperSpaceWalker}.
     *
     * @param <MP> model parameters type produced for each combination
     */
    public interface HyperSpaceIterator<MP extends Model.Parameters> {

        /**
         * Reports whether {@link #nextModelParameters(Model)} can produce another combination.
         * The walkers in this file do not read {@code previousModel}.
         */
        boolean hasNext(Model previousModel);

        /**
         * Moves to the next combination and returns model parameters with it applied.
         * Callers should check {@link #hasNext(Model)} first; an exhausted iterator may throw
         * {@link java.util.NoSuchElementException}.
         */
        MP nextModelParameters(Model previousModel);

        /** Returns the raw hyperparameter values of the current combination, in hyperparameter-name order. */
        Object[] getCurrentRawParameters();

        /** Tells the iterator that building the model for the current combination failed. */
        void modelFailed(Model failedModel);

        /** Returns the iterator to its initial state. */
        void reset();

        /** Upper bound on the number of models the search should build. */
        int max_models();

        /** Time budget of the whole search, in seconds. */
        double max_runtime_secs();

        /** Seconds left of the search's time budget. */
        double time_remaining_secs();
    }

    /* loaded from: input_file:hex/grid/HyperSpaceWalker$RandomDiscreteValueWalker.class */
    /**
     * Walker that samples hyperparameter combinations uniformly at random without
     * replacement, bounded by the hyperspace size and the criteria's max_models.
     *
     * @param <MP> model parameters type
     */
    public static class RandomDiscreteValueWalker<MP extends Model.Parameters> extends BaseWalker<MP, HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria> {
        Random random;
        /** Every distinct index vector produced so far (entries are never equal to each other). */
        private final List<int[]> _visitedPermutations;
        /** integerHash of every visited index vector; a fast pre-filter for _visitedPermutations. */
        private final Set<Integer> _visitedPermutationHashes;

        /**
         * A criteria seed of -1 selects an unseeded {@link Random}; any other value seeds it
         * for reproducible sampling.
         */
        public RandomDiscreteValueWalker(MP mp, Map<String, Object[]> map, ModelParametersBuilderFactory<MP> modelParametersBuilderFactory, HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria randomDiscreteValueSearchCriteria) {
            super(mp, map, modelParametersBuilderFactory, randomDiscreteValueSearchCriteria);
            this._visitedPermutations = new ArrayList<int[]>();
            this._visitedPermutationHashes = new LinkedHashSet<Integer>();
            if (-1 == randomDiscreteValueSearchCriteria.seed()) {
                this.random = new Random();
            } else {
                this.random = new Random(randomDiscreteValueSearchCriteria.seed());
            }
        }

        /** Delegates to ScoreKeeper.stopEarly using the criteria's stopping settings. */
        @Override // hex.grid.HyperSpaceWalker.BaseWalker, hex.grid.HyperSpaceWalker
        public boolean stopEarly(Model model, ScoringInfo[] scoringInfoArr) {
            return ScoreKeeper.stopEarly(ScoringInfo.scoreKeepers(scoringInfoArr), search_criteria().stopping_rounds(), model._output.isClassifier(), search_criteria().stopping_metric(), search_criteria().stopping_tolerance(), "grid's best", true);
        }

        /** Returns a fresh iterator whose runtime clock starts now. */
        @Override // hex.grid.HyperSpaceWalker
        public HyperSpaceIterator<MP> iterator() {
            return new HyperSpaceIterator<MP>() {
                private int[] _currentHyperparamIndices = null;
                /** Models counted toward max_models; decremented by modelFailed. */
                private int _currentPermutationNum = 0;
                private long _start_time = System.currentTimeMillis();

                @Override // hex.grid.HyperSpaceWalker.HyperSpaceIterator
                public MP nextModelParameters(Model model) {
                    int[] indices = nextModelIndices();
                    if (indices == null) {
                        throw new NoSuchElementException("No more elements to explore in hyper-space!");
                    }
                    _currentHyperparamIndices = indices;
                    _visitedPermutations.add(indices);
                    _visitedPermutationHashes.add(Integer.valueOf(integerHash(indices)));
                    _currentPermutationNum++;
                    @SuppressWarnings("unchecked")
                    MP common = (MP) _params.m76clone();
                    return getModelParams(common, hypers(indices, new Object[_hyperParamNames.length]));
                }

                /**
                 * True while unvisited combinations remain and, unless max_models is 0
                 * (unlimited), fewer than max_models models have been counted.
                 */
                @Override // hex.grid.HyperSpaceWalker.HyperSpaceIterator
                public boolean hasNext(Model model) {
                    int maxModels = search_criteria().max_models();
                    // Count distinct visited vectors, not hashes: a hash collision must not keep
                    // the space looking unexhausted forever.
                    return _visitedPermutations.size() < _maxHyperSpaceSize
                            && (maxModels == 0 || _currentPermutationNum < maxModels);
                }

                @Override // hex.grid.HyperSpaceWalker.HyperSpaceIterator
                public void reset() {
                    _start_time = System.currentTimeMillis();
                    _currentPermutationNum = 0;
                    _currentHyperparamIndices = null;
                    _visitedPermutations.clear();
                    _visitedPermutationHashes.clear();
                }

                @Override // hex.grid.HyperSpaceWalker.HyperSpaceIterator
                public double max_runtime_secs() {
                    return search_criteria().max_runtime_secs();
                }

                @Override // hex.grid.HyperSpaceWalker.HyperSpaceIterator
                public int max_models() {
                    return search_criteria().max_models();
                }

                /** Criteria max_runtime_secs minus seconds elapsed since construction or reset. */
                @Override // hex.grid.HyperSpaceWalker.HyperSpaceIterator
                public double time_remaining_secs() {
                    return search_criteria().max_runtime_secs() - ((System.currentTimeMillis() - _start_time) / 1000.0d);
                }

                /** A failed model does not count toward max_models. */
                @Override // hex.grid.HyperSpaceWalker.HyperSpaceIterator
                public void modelFailed(Model model) {
                    _currentPermutationNum--;
                }

                @Override // hex.grid.HyperSpaceWalker.HyperSpaceIterator
                public Object[] getCurrentRawParameters() {
                    return hypers(_currentHyperparamIndices, new Object[_hyperParamNames.length]);
                }
            };
        }

        /**
         * Draws random index vectors until one that has not been visited is found.
         * Loops forever if every combination was already visited, so callers must check
         * hasNext first.
         */
        /* JADX INFO: Access modifiers changed from: private */
        public int[] nextModelIndices() {
            int[] candidate = new int[this._hyperParamNames.length];
            do {
                for (int i = 0; i < candidate.length; i++) {
                    candidate[i] = this.random.nextInt(this._hyperParams.get(this._hyperParamNames[i]).length);
                }
            } while (isVisited(candidate));
            return candidate;
        }

        /** Hash pre-check, confirmed against the stored vectors so a collision cannot reject a new point. */
        private boolean isVisited(int[] candidate) {
            if (!this._visitedPermutationHashes.contains(Integer.valueOf(integerHash(candidate)))) {
                return false;
            }
            for (int[] seen : this._visitedPermutations) {
                if (Arrays.equals(seen, candidate)) {
                    return true;
                }
            }
            return false;
        }
    }

    /** Returns the model parameters shared by every combination before hyperparameters are applied. */
    MP getParams();

    /** Returns the factory used to apply hyperparameter values onto model parameters. */
    ModelParametersBuilderFactory<MP> getParametersBuilderFactory();

    /** Returns the search criteria this walker was created with. */
    C search_criteria();

    /** Returns the hyperparameter names, in the order used for index vectors and raw values. */
    String[] getHyperParamNames();

    /** Returns the number of hyperparameter combinations in the hyperspace. */
    long getMaxHyperSpaceSize();

    /** Returns a new iterator over the hyperspace. */
    HyperSpaceIterator<MP> iterator();

    /** Decides, from a model and its scoring history, whether the search should stop early. */
    boolean stopEarly(Model model, ScoringInfo[] scoringHistory);
}
