package water;

import hex.Model;
import hex.ModelMetrics;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.tree.CompressedTree;
import hex.tree.SharedTreeModel;
import hex.tree.drf.DRF;
import hex.tree.drf.DRFModel;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import hex.tree.isofor.IsolationForest;
import hex.tree.isofor.IsolationForestModel;
import java.io.File;
import java.io.IOException;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.fvec.Frame;
import water.util.Log;

/* loaded from: input_file:water/ModelSerializationTest.class */
/**
 * Round-trip tests for binary model serialization.
 *
 * <p>Each test builds or trains a model, exports it to a temporary file with
 * {@code exportBinaryModel}, re-imports it with {@code Model.importBinaryModel}, and asserts that
 * the re-imported model (and, for tree models, every compressed tree) serializes to exactly the
 * same bytes as the original.
 */
public class ModelSerializationTest extends TestUtil {

    /** Empty ignored-columns list, used when no dataset column should be ignored. */
    private static final String[] ESA = new String[0];

    /**
     * Minimal model with no-op scoring, used to check serialization of a bare {@link Model}
     * independent of any real algorithm.
     */
    static class BlahModel extends Model<BlahModel, BlahModel.BlahParameters, BlahModel.BlahOutput> {

        /** Output holder; the three flags are forwarded unchanged to {@code Model.Output}'s constructor. */
        static class BlahOutput extends Model.Output {
            public BlahOutput(boolean z, boolean z2, boolean z3) {
                super(z, z2, z3);
            }
        }

        /** Parameters reporting the fixed algorithm name "Blah" and zero progress units. */
        static class BlahParameters extends Model.Parameters {
            BlahParameters() {
            }

            public String algoName() {
                return "Blah";
            }

            public String fullName() {
                return "Blah";
            }

            public String javaName() {
                return BlahModel.class.getName();
            }

            public long progressUnits() {
                return 0L;
            }
        }

        public BlahModel(Key key, BlahParameters parameters, BlahOutput output) {
            super(key, parameters, output);
        }

        /** Not used by these tests; always returns {@code null}. */
        public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
            return null;
        }

        /** Not used by these tests; always returns an empty array. */
        protected double[] score0(double[] data, double[] preds) {
            return new double[0];
        }
    }

    /** Calls {@code stall_till_cloudsize(1)} once before any test in this class runs. */
    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /** Round-trips a hand-built {@link BlahModel} that was registered in the DKV. */
    @Test
    public void testSimpleModel() throws IOException {
        BlahModel model = new BlahModel(Key.make("BLAHModel"), new BlahModel.BlahParameters(),
                new BlahModel.BlahOutput(false, false, false));
        DKV.put(model._key, model);
        BlahModel loadedModel = null;
        try {
            // saveAndLoad(model) deletes the original model after exporting it.
            loadedModel = saveAndLoad(model);
            assertModelBinaryEquals(model, loadedModel);
        } finally {
            if (loadedModel != null) {
                loadedModel.delete();
            }
        }
    }

    /** Round-trips a 5-tree GBM trained on iris with a categorical response (C5). */
    @Test
    public void testGBMModelMultinomial() throws IOException {
        GBMModel loadedModel = null;
        try {
            GBMModel model = prepareGBMModel("smalldata/iris/iris.csv", ESA, "C5", true, 5);
            // Capture the trees before saveAndLoad deletes the original model.
            CompressedTree[][] trees = getTrees(model);
            loadedModel = saveAndLoad(model);
            assertModelBinaryEquals(model, loadedModel);
            assertTreeEquals("Trees have to be binary same", trees, getTrees(loadedModel));
        } finally {
            if (loadedModel != null) {
                loadedModel.delete();
            }
        }
    }

    /** Round-trips a 5-tree GBM trained on prostate with CAPSULE converted to categorical. */
    @Test
    public void testGBMModelBinomial() throws IOException {
        GBMModel loadedModel = null;
        try {
            GBMModel model = prepareGBMModel("smalldata/logreg/prostate.csv", ar(new String[]{"ID"}), "CAPSULE", true, 5);
            CompressedTree[][] trees = getTrees(model);
            loadedModel = saveAndLoad(model);
            assertModelBinaryEquals(model, loadedModel);
            assertTreeEquals("Trees have to be binary same", trees, getTrees(loadedModel));
        } finally {
            if (loadedModel != null) {
                loadedModel.delete();
            }
        }
    }

    /** Round-trips a 5-tree DRF trained on iris with a categorical response (C5). */
    @Test
    public void testDRFModelMultinomial() throws IOException {
        DRFModel loadedModel = null;
        try {
            DRFModel model = prepareDRFModel("smalldata/iris/iris.csv", ESA, "C5", true, 5);
            CompressedTree[][] trees = getTrees(model);
            loadedModel = saveAndLoad(model);
            assertModelBinaryEquals(model, loadedModel);
            assertTreeEquals("Trees have to be binary same", trees, getTrees(loadedModel));
        } finally {
            if (loadedModel != null) {
                loadedModel.delete();
            }
        }
    }

    /** Round-trips a 5-tree DRF trained on prostate with CAPSULE converted to categorical. */
    @Test
    public void testDRFModelBinomial() throws IOException {
        DRFModel model = null;
        DRFModel loadedModel = null;
        try {
            model = prepareDRFModel("smalldata/logreg/prostate.csv", ar(new String[]{"ID"}), "CAPSULE", true, 5);
            CompressedTree[][] trees = getTrees(model);
            loadedModel = saveAndLoad(model);
            assertModelBinaryEquals(model, loadedModel);
            assertTreeEquals("Trees have to be binary same", trees, getTrees(loadedModel));
        } finally {
            if (model != null) {
                model.delete();
            }
            if (loadedModel != null) {
                loadedModel.delete();
            }
        }
    }

    /** Round-trips a 5-tree isolation forest trained on prostate without ID and CAPSULE. */
    @Test
    public void testIsolationForestModel() throws IOException {
        IsolationForestModel model = null;
        IsolationForestModel loadedModel = null;
        try {
            model = prepareIsoForModel("smalldata/logreg/prostate.csv", ar(new String[]{"ID", "CAPSULE"}), 5);
            CompressedTree[][] trees = getTrees(model);
            loadedModel = saveAndLoad(model);
            assertModelBinaryEquals(model, loadedModel);
            assertTreeEquals("Trees have to be binary same", trees, getTrees(loadedModel));
        } finally {
            if (model != null) {
                model.delete();
            }
            if (loadedModel != null) {
                loadedModel.delete();
            }
        }
    }

    /** Round-trips a Poisson GLM trained on cars.csv predicting "power (hp)". */
    @Test
    public void testGLMModel() throws IOException {
        GLMModel loadedModel = null;
        try {
            GLMModel model = prepareGLMModel("smalldata/junit/cars.csv", ESA, "power (hp)", GLMModel.GLMParameters.Family.poisson);
            loadedModel = saveAndLoad(model);
            assertModelBinaryEquals(model, loadedModel);
        } finally {
            if (loadedModel != null) {
                loadedModel.delete();
            }
        }
    }

    /**
     * Parses {@code dataset} and trains a GBM on it with scoring after every iteration.
     * The parsed frame is always deleted, including when training fails.
     *
     * @param dataset        test-file path passed to {@code parse_test_file}
     * @param ignoredColumns columns excluded from training
     * @param response       response column name
     * @param classification if true, the response column is first converted to categorical
     * @param ntrees         number of trees to build
     * @return the trained model
     */
    private GBMModel prepareGBMModel(String dataset, String[] ignoredColumns, String response, boolean classification, int ntrees) {
        Frame frame = parse_test_file(dataset);
        try {
            if (classification) {
                ensureCategorical(frame, response);
            }
            GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = frame._key;
            params._ignored_columns = ignoredColumns;
            params._response_column = response;
            params._ntrees = ntrees;
            params._score_each_iteration = true;
            return new GBM(params).trainModel().get();
        } finally {
            if (frame != null) {
                frame.delete();
            }
        }
    }

    /**
     * Parses {@code dataset} and trains a DRF on it with scoring after every iteration.
     * The parsed frame is always deleted, including when training fails.
     *
     * @param dataset        test-file path passed to {@code parse_test_file}
     * @param ignoredColumns columns excluded from training
     * @param response       response column name
     * @param classification if true, the response column is first converted to categorical
     * @param ntrees         number of trees to build
     * @return the trained model
     */
    private DRFModel prepareDRFModel(String dataset, String[] ignoredColumns, String response, boolean classification, int ntrees) {
        Frame frame = parse_test_file(dataset);
        try {
            if (classification) {
                ensureCategorical(frame, response);
            }
            DRFModel.DRFParameters params = new DRFModel.DRFParameters();
            params._train = frame._key;
            params._ignored_columns = ignoredColumns;
            params._response_column = response;
            params._ntrees = ntrees;
            params._score_each_iteration = true;
            return new DRF(params).trainModel().get();
        } finally {
            if (frame != null) {
                frame.delete();
            }
        }
    }

    /**
     * Converts {@code column} of {@code frame} to categorical if it is not already categorical,
     * then re-publishes the modified frame in the DKV.
     */
    private static void ensureCategorical(Frame frame, String column) {
        if (!frame.vec(column).isCategorical()) {
            // Remove the vec returned by replace() (presumably the displaced non-categorical column).
            frame.replace(frame.find(column), frame.vec(column).toCategoricalVec()).remove();
            DKV.put(frame._key, frame);
        }
    }

    /**
     * Parses {@code dataset} and trains an isolation forest on it with scoring after every
     * iteration. The parsed frame is always deleted.
     */
    private IsolationForestModel prepareIsoForModel(String dataset, String[] ignoredColumns, int ntrees) {
        Frame frame = parse_test_file(dataset);
        try {
            IsolationForestModel.IsolationForestParameters params = new IsolationForestModel.IsolationForestParameters();
            params._train = frame._key;
            params._ignored_columns = ignoredColumns;
            params._ntrees = ntrees;
            params._score_each_iteration = true;
            return new IsolationForest(params).trainModel().get();
        } finally {
            if (frame != null) {
                frame.delete();
            }
        }
    }

    /** Parses {@code dataset} and trains a GLM of the given family on it. The parsed frame is always deleted. */
    private GLMModel prepareGLMModel(String dataset, String[] ignoredColumns, String response, GLMModel.GLMParameters.Family family) {
        Frame frame = parse_test_file(dataset);
        try {
            GLMModel.GLMParameters params = new GLMModel.GLMParameters();
            params._train = frame._key;
            params._ignored_columns = ignoredColumns;
            params._response_column = response;
            params._family = family;
            return new GLM(params).trainModel().get();
        } finally {
            if (frame != null) {
                frame.delete();
            }
        }
    }

    /** Exports and re-imports {@code model}, deleting the original after export. */
    private <M extends Model<?, ?, ?>> M saveAndLoad(M model) throws IOException {
        return saveAndLoad(model, true);
    }

    /**
     * Exports {@code model} to a temporary file and imports it back.
     *
     * @param model         model to round-trip
     * @param deleteOriginal if true, {@code model.delete()} is called after the export and before the import
     * @return the re-imported model
     * @throws IOException if the temporary file cannot be created, or propagated from export/import
     */
    @SuppressWarnings("unchecked")
    private <M extends Model<?, ?, ?>> M saveAndLoad(M model, boolean deleteOriginal) throws IOException {
        File tmpFile = File.createTempFile(model.getClass().getSimpleName(), null);
        try {
            String path = tmpFile.getAbsolutePath();
            model.exportBinaryModel(path, true);
            if (deleteOriginal) {
                model.delete();
            }
            return (M) Model.importBinaryModel(path);
        } finally {
            // Best effort: a leftover temp file is logged, not treated as a test failure.
            if (!tmpFile.delete()) {
                Log.err(new Object[]{"Temporary file " + tmpFile + " was not deleted."});
            }
        }
    }

    /** Asserts that both models serialize via {@code write(AutoBuffer)} to identical byte arrays. */
    public static void assertModelBinaryEquals(Model model, Model model2) {
        Assert.assertArrayEquals("The serialized models are not binary same!", model.write(new AutoBuffer()).buf(), model2.write(new AutoBuffer()).buf());
    }

    /**
     * Asserts that two {@link Iced} objects serialize to identical bytes. Two nulls are equal.
     * If only one side is null, the check fails with {@code message} (not an NPE).
     */
    public static void assertIcedBinaryEquals(String message, Iced expected, Iced actual) {
        if (expected == null) {
            Assert.assertNull(message, actual);
        } else {
            Assert.assertNotNull(message, actual);
            Assert.assertArrayEquals(message, expected.write(new AutoBuffer()).buf(), actual.write(new AutoBuffer()).buf());
        }
    }

    /** Asserts that two tree grids are binary-identical, including each tree's {@code _key}. */
    public static void assertTreeEquals(String message, CompressedTree[][] expected, CompressedTree[][] actual) {
        assertTreeEquals(message, expected, actual, false);
    }

    /**
     * Asserts that two [tree][class] grids of compressed trees have the same shape and that each
     * pair of trees is binary-identical.
     *
     * @param ignoreKeys if true, each tree's {@code _key} is temporarily cleared for the comparison
     *                   and always restored afterwards, even if the assertion fails
     */
    public static void assertTreeEquals(String message, CompressedTree[][] expected, CompressedTree[][] actual, boolean ignoreKeys) {
        Assert.assertEquals("Number of trees has to match", expected.length, actual.length);
        for (int t = 0; t < expected.length; t++) {
            Assert.assertEquals("Number of trees per tree has to match", expected[t].length, actual[t].length);
            for (int c = 0; c < expected[t].length; c++) {
                assertSingleTreeEquals(message, expected[t][c], actual[t][c], ignoreKeys);
            }
        }
    }

    /** Compares one pair of trees, optionally with their {@code _key} fields masked out. */
    private static void assertSingleTreeEquals(String message, CompressedTree expected, CompressedTree actual, boolean ignoreKeys) {
        if (!ignoreKeys) {
            assertIcedBinaryEquals(message, expected, actual);
            return;
        }
        // Read both keys before clearing either, so restoration works even when both sides are the same object.
        Key expectedKey = expected != null ? expected._key : null;
        Key actualKey = actual != null ? actual._key : null;
        if (expected != null) {
            expected._key = null;
        }
        if (actual != null) {
            actual._key = null;
        }
        try {
            assertIcedBinaryEquals(message, expected, actual);
        } finally {
            if (expected != null) {
                expected._key = expectedKey;
            }
            if (actual != null) {
                actual._key = actualKey;
            }
        }
    }

    /**
     * Collects a tree model's compressed trees into a {@code [ntrees][nclasses]} grid.
     * Cells with no tree key stay {@code null}.
     */
    public static CompressedTree[][] getTrees(SharedTreeModel sharedTreeModel) {
        // Explicit cast: through the raw SharedTreeModel type, _output is only typed as its erased bound.
        SharedTreeModel.SharedTreeOutput output = (SharedTreeModel.SharedTreeOutput) sharedTreeModel._output;
        int ntrees = output._ntrees;
        int nclasses = output.nclasses();
        CompressedTree[][] trees = new CompressedTree[ntrees][nclasses];
        for (int t = 0; t < ntrees; t++) {
            for (int c = 0; c < nclasses; c++) {
                if (output._treeKeys[t][c] != null) {
                    trees[t][c] = output.ctree(t, c);
                }
            }
        }
        return trees;
    }
}
