package ai.h2o.automl;

import ai.h2o.automl.ModelSelectionStrategies;
import ai.h2o.automl.ModelSelectionStrategy;
import ai.h2o.automl.dummy.DummyModel;
import ai.h2o.automl.events.EventLog;
import ai.h2o.automl.leaderboard.Leaderboard;
import hex.Model;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Supplier;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import water.DKV;
import water.Key;
import water.Keyed;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.Vec;
import water.runner.CloudSize;
import water.runner.H2ORunner;
import water.util.ArrayUtils;

/**
 * Tests for the model-selection strategies in {@link ModelSelectionStrategies}.
 *
 * <p>Each test works on seven {@link DummyScoreModel}s scored against a 5-row frame. A model
 * named {@code dummy_<x>} predicts correctly only the rows whose id (column "A") is below
 * {@code floor(x)}, so models with a larger integer part make more correct predictions on the
 * "logloss" leaderboard used by the strategies. Models 1.0, 2.0, 3.0 and 4.0 form the "old"
 * group; 1.1, 2.2 and 5.5 form the "new" group.
 */
@CloudSize(1)
@RunWith(H2ORunner.class)
public class ModelSelectionStrategiesTest {

    /** Keyed objects put into the DKV by {@link #prepareModels()}; removed after each test. */
    private final List<Keyed> toDelete = new ArrayList<>();
    private Frame fr;
    private double[] perfectPreds;
    private Model[] oldModels;
    private Model[] newModels;
    private Supplier<ModelSelectionStrategies.LeaderboardHolder> leaderboardSupplier;

    /**
     * Dummy model whose predictions equal {@code perfectPreds} for rows with id below
     * {@code goodPredsCount}, and are inverted ({@code 1 - expected}) for every other row.
     */
    static class DummyScoreModel extends DummyModel {
        private double[] _perfectPreds;
        private int _goodPredsCount;

        public DummyScoreModel(String key, double[] perfectPreds, int goodPredsCount) {
            super(key);
            _perfectPreds = perfectPreds;
            _goodPredsCount = goodPredsCount;
        }

        @Override
        protected double[] score0(double[] data, double[] preds) {
            // Column "A" holds 1-based row ids (1..5), used to look up the expected prediction.
            int rowId = (int) data[0];
            double expected = _perfectPreds[rowId - 1];
            double p = rowId < _goodPredsCount ? expected : 1.0 - expected;
            return new double[]{p, p, 1.0 - p};
        }
    }

    @Before
    public void prepareModels() {
        fr = new Frame(Key.make("dummy_fr"),
                new String[]{"A", "B", "target"},
                new Vec[]{
                        TestUtil.ivec(new int[]{1, 2, 3, 4, 5}),
                        TestUtil.ivec(new int[]{1, 2, 3, 4, 5}),
                        TestUtil.cvec(new String[]{"foo", "foo", "foo", "bar", "bar"})
                });
        perfectPreds = new double[]{1.0, 1.0, 1.0, 0.0, 0.0};
        DKV.put(fr);
        toDelete.add(fr);

        // Every call creates (or fetches) the "selection_lb" event log and leaderboard, and hands
        // them out through a holder whose cleanup() removes both.
        leaderboardSupplier = () -> {
            final EventLog eventLog = EventLog.getOrMake(Key.make("selection_lb"));
            final Leaderboard leaderboard = Leaderboard.getOrMake("selection_lb", eventLog, fr, "logloss");
            return new ModelSelectionStrategies.LeaderboardHolder() {
                public Leaderboard get() {
                    return leaderboard;
                }

                public void cleanup() {
                    leaderboard.remove();
                    eventLog.remove();
                }
            };
        };

        Model[] models = Arrays.stream(new double[]{1.0, 2.0, 3.0, 4.0, 1.1, 2.2, 5.5})
                .mapToObj(this::makeDummyModel)
                .toArray(Model[]::new);
        // ArrayUtils.subarray takes (array, offset, length): first 4 are "old", last 3 are "new".
        oldModels = ArrayUtils.subarray(models, 0, 4);
        newModels = ArrayUtils.subarray(models, 4, 3);
    }

    /**
     * Creates a {@link DummyScoreModel} named {@code dummy_<id>}, puts it into the DKV and
     * schedules it for removal. The model predicts rows with id below {@code floor(id)} correctly
     * and is tagged with an "odd" or "even" extra parameter according to the parity of
     * {@code floor(id)}.
     */
    private Model makeDummyModel(double id) {
        int goodPredsCount = (int) Math.floor(id);
        DummyScoreModel model = new DummyScoreModel("dummy_" + id, perfectPreds, goodPredsCount);
        ((DummyModel.DummyModelParameters) model._parms)._moreParams
                .put(goodPredsCount % 2 == 0 ? "even" : "odd", String.valueOf(goodPredsCount));
        DummyModel.DummyModelOutput output = (DummyModel.DummyModelOutput) model._output;
        output._names = fr.names();
        // Only "target" is categorical; the two integer feature columns have no domain.
        output._domains = new String[][]{null, null, fr.vec("target").domain()};
        DKV.put(model);
        toDelete.add(model);
        return model;
    }

    /** Returns the DKV keys of the given models, preserving their order. */
    @SuppressWarnings("unchecked")
    private static Key<Model>[] keysOf(Model[] models) {
        return (Key<Model>[]) Arrays.stream(models).map(m -> m._key).toArray(Key[]::new);
    }

    /** Renders keys as strings so selections can be compared against expected model ids. */
    private static String[] namesOf(Key<?>[] keys) {
        return Arrays.stream(keys).map(k -> k.toString()).toArray(String[]::new);
    }

    @After
    public void cleanup() {
        toDelete.forEach(k -> k.remove());
    }

    @Test
    public void testKeepBestN() {
        Scope.enter();
        try {
            ModelSelectionStrategy<Model> strategy = new ModelSelectionStrategies.KeepBestN<>(2, leaderboardSupplier);
            ModelSelectionStrategy.Selection<Model> selection = strategy.select(keysOf(oldModels), keysOf(newModels));
            Assert.assertNotNull(selection);
            Assert.assertEquals(1, selection._add.length);
            Assert.assertArrayEquals(new String[]{"dummy_5.5"}, namesOf(selection._add));
            Assert.assertEquals(3, selection._remove.length);
            Assert.assertArrayEquals(new String[]{"dummy_1.0", "dummy_2.0", "dummy_3.0"}, namesOf(selection._remove));
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testKeepBestN_NoAdd() {
        Scope.enter();
        try {
            // Roles swapped: the weaker "new" group is the original one, so nothing gets added.
            ModelSelectionStrategy<Model> strategy = new ModelSelectionStrategies.KeepBestN<>(1, leaderboardSupplier);
            ModelSelectionStrategy.Selection<Model> selection = strategy.select(keysOf(newModels), keysOf(oldModels));
            Assert.assertNotNull(selection);
            Assert.assertEquals(0, selection._add.length);
            Assert.assertEquals(2, selection._remove.length);
            Assert.assertArrayEquals(new String[]{"dummy_1.1", "dummy_2.2"}, namesOf(selection._remove));
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testKeepBestN_NoRemove() {
        Scope.enter();
        try {
            ModelSelectionStrategy<Model> strategy = new ModelSelectionStrategies.KeepBestN<>(10, leaderboardSupplier);
            ModelSelectionStrategy.Selection<Model> selection = strategy.select(keysOf(oldModels), keysOf(newModels));
            Assert.assertNotNull(selection);
            Assert.assertEquals(3, selection._add.length);
            Assert.assertArrayEquals(new String[]{"dummy_5.5", "dummy_2.2", "dummy_1.1"}, namesOf(selection._add));
            Assert.assertEquals(0, selection._remove.length);
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testKeepBestConstantSize() {
        Scope.enter();
        try {
            ModelSelectionStrategy<Model> strategy = new ModelSelectionStrategies.KeepBestConstantSize<>(leaderboardSupplier);
            ModelSelectionStrategy.Selection<Model> selection = strategy.select(keysOf(oldModels), keysOf(newModels));
            Assert.assertNotNull(selection);
            Assert.assertEquals(1, selection._add.length);
            Assert.assertArrayEquals(new String[]{"dummy_5.5"}, namesOf(selection._add));
            Assert.assertEquals(1, selection._remove.length);
            Assert.assertArrayEquals(new String[]{"dummy_1.0"}, namesOf(selection._remove));
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testKeepBestNFromSubgroup() {
        Scope.enter();
        try {
            // Subgroup = models tagged "odd" (floor(id) is odd): 1.0, 3.0, 1.1 and 5.5.
            ModelSelectionStrategy<Model> strategy = new ModelSelectionStrategies.KeepBestNFromSubgroup<>(
                    1,
                    key -> ((DummyModel.DummyModelParameters) key.get()._parms)._moreParams.containsKey("odd"),
                    leaderboardSupplier);
            ModelSelectionStrategy.Selection<Model> selection = strategy.select(keysOf(oldModels), keysOf(newModels));
            Assert.assertNotNull(selection);
            Assert.assertEquals(1, selection._add.length);
            Assert.assertArrayEquals(new String[]{"dummy_5.5"}, namesOf(selection._add));
            Assert.assertEquals(2, selection._remove.length);
            Assert.assertArrayEquals(new String[]{"dummy_1.0", "dummy_3.0"}, namesOf(selection._remove));
        } finally {
            Scope.exit();
        }
    }
}
