package hex;

import hex.deeplearning.DeepLearning;
import hex.deeplearning.DeepLearningModel;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.tree.drf.DRF;
import hex.tree.drf.DRFModel;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.MRTask;
import water.TestUtil;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.ast.prims.advmath.AstKFold;
import water.util.ArrayUtils;

/**
 * Checks the per-fold cross-validation prediction frames kept via
 * {@code _keep_cross_validation_predictions = true}: in the frame for fold {@code k}, every
 * row whose fold id is not {@code k} must have a prediction of exactly 0.
 *
 * <p>GBM, DRF and Deep Learning are trained on the iris {@code class} column (3 prediction
 * columns checked); GLM is trained on {@code sepal_len} (1 prediction column checked). All
 * models use an explicit 3-fold {@code foldId} column built with a fixed seed.
 */
public class XValPredictionsCheck extends TestUtil {

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Trains one model per algorithm with an explicit fold column and validates each model's
     * kept cross-validation predictions via {@link #checkModel(Model, Vec, int)}.
     */
    @Test
    public void testXValPredictions() {
        Frame frame = null;
        try {
            frame = parse_test_file("smalldata/iris/iris_wheader.csv");
            // Fold-id column: a vec laid out like "class" (makeZero) filled by kfoldColumn
            // with 3 folds and a fixed seed so the test is reproducible.
            Frame foldFrame = new Frame(new String[]{"foldId"},
                    new Vec[]{AstKFold.kfoldColumn(frame.vec("class").makeZero(), 3, 543216789L)});
            frame.add(foldFrame);
            DKV.put(frame);
            Vec foldId = foldFrame.anyVec();

            GBMModel.GBMParameters gbmParams = new GBMModel.GBMParameters();
            gbmParams._train = frame._key;
            gbmParams._response_column = "class";
            gbmParams._ntrees = 1;
            gbmParams._max_depth = 1;
            gbmParams._fold_column = "foldId";
            gbmParams._distribution = DistributionFamily.multinomial;
            gbmParams._keep_cross_validation_predictions = true;
            checkModel(new GBM(gbmParams).trainModel().get(), foldId, 3);

            DRFModel.DRFParameters drfParams = new DRFModel.DRFParameters();
            drfParams._train = frame._key;
            drfParams._response_column = "class";
            drfParams._ntrees = 1;
            drfParams._max_depth = 1;
            drfParams._fold_column = "foldId";
            drfParams._distribution = DistributionFamily.multinomial;
            drfParams._keep_cross_validation_predictions = true;
            checkModel(new DRF(drfParams).trainModel().get(), foldId, 3);

            GLMModel.GLMParameters glmParams = new GLMModel.GLMParameters();
            glmParams._train = frame._key;
            glmParams._response_column = "sepal_len";
            glmParams._fold_column = "foldId";
            glmParams._keep_cross_validation_predictions = true;
            checkModel(new GLM(glmParams).trainModel().get(), foldId, 1);

            DeepLearningModel.DeepLearningParameters dlParams = new DeepLearningModel.DeepLearningParameters();
            dlParams._train = frame._key;
            dlParams._response_column = "class";
            dlParams._hidden = new int[]{1};
            dlParams._epochs = 1.0d;
            dlParams._fold_column = "foldId";
            dlParams._keep_cross_validation_predictions = true;
            checkModel(new DeepLearning(dlParams).trainModel().get(), foldId, 3);
        } finally {
            // frame also owns the appended foldId vec, so this removes it as well.
            if (frame != null) {
                frame.remove();
            }
        }
    }

    /**
     * Validates the cross-validation outputs of {@code model} and cleans up after it: the
     * model, its cross-validation models, every per-fold prediction frame and the holdout
     * predictions frame are removed, even when a check fails.
     *
     * @param model  a model trained with {@code _keep_cross_validation_predictions = true}
     * @param foldId the fold-id vec the model was trained with
     * @param nclass number of prediction vecs to check per fold frame: the frame's
     *               {@code anyVec()} when 1, otherwise the vecs selected by
     *               {@code ArrayUtils.range(1, nclass)}
     */
    void checkModel(Model model, Vec foldId, int nclass) {
        try {
            // NOTE(review): DRF is excluded, presumably because its training metrics are not
            // computed over all rows (out-of-bag) — TODO confirm.
            if (!(model instanceof DRFModel)) {
                Assert.assertEquals(model._output._training_metrics._nobs,
                        model._output._cross_validation_metrics._nobs);
            }
        } finally {
            model.delete();
            model.deleteCrossValidationModels();
        }
        // Read from the local, in-memory model object; the per-fold prediction frames are
        // still fetched from the DKV below.
        Key[] xvalPredKeys = model._output._cross_validation_predictions;
        Key holdoutKey = model._output._cross_validation_holdout_predictions_frame_id;
        try {
            // The fold index is the position of the key in the array; this assumes fold ids
            // in foldId are 0-based and ordered the same way as the prediction keys.
            for (int fold = 0; fold < xvalPredKeys.length; fold++) {
                Frame preds = DKV.getGet(xvalPredKeys[fold]);
                Assert.assertNotNull("missing cross-validation predictions frame " + xvalPredKeys[fold], preds);
                try {
                    checkFoldPredictions(preds, foldId, fold, nclass);
                } finally {
                    preds.delete();
                }
            }
        } finally {
            if (holdoutKey != null) {
                holdoutKey.remove();
            }
        }
    }

    /**
     * Asserts that {@code preds} has one row per {@code foldId} row and that each checked
     * prediction is exactly 0 on every row whose fold id differs from {@code fold}.
     */
    private static void checkFoldPredictions(Frame preds, Vec foldId, final int fold, int nclass) {
        Assert.assertEquals(foldId.length(), preds.numRows());
        // Column 0 of the task input is the fold id; columns 1..nclass are predictions.
        Vec[] vecs = new Vec[nclass + 1];
        vecs[0] = foldId;
        if (nclass == 1) {
            vecs[1] = preds.anyVec();
        } else {
            System.arraycopy(preds.vecs(ArrayUtils.range(1, nclass)), 0, vecs, 1, nclass);
        }
        new MRTask() {
            @Override
            public void map(Chunk[] chks) {
                Chunk foldChunk = chks[0];
                for (int row = 0; row < foldChunk._len; row++) {
                    if (foldChunk.at8(row) == fold) {
                        continue; // holdout row of this fold: any prediction is allowed
                    }
                    for (int col = 1; col < chks.length; col++) {
                        assert chks[col].atd(row) == 0.0d
                                : "fold " + fold + ": expected 0 prediction at chunk row " + row + ", column " + col;
                    }
                }
            }
        }.doAll(vecs);
    }
}
