package hex.psvm;

import hex.ConfusionMatrix;
import hex.ConfusionMatrixTest;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsSupervised;
import hex.genmodel.algos.psvm.ScorerFactory;
import hex.genmodel.algos.psvm.SupportVectorScorer;
import hex.psvm.PSVMModel;
import hex.splitframe.ShuffleSplitFrame;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.H2O;
import water.Key;
import water.MRTask;
import water.Scope;
import water.TestUtil;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.TransformWrappedVec;
import water.fvec.Vec;
import water.rapids.ast.prims.math.AstSgn;
import water.util.FrameUtils;

/* loaded from: input_file:hex/psvm/PSVMTest.class */
public class PSVMTest extends TestUtil {

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:hex/psvm/PSVMTest$CheckScorersTask.class */
    public static class CheckScorersTask extends MRTask {
        private final Key<PSVMModel> _model_key;
        private transient PSVMModel _model;

        CheckScorersTask(Key<PSVMModel> key) {
            this._model_key = key;
        }

        protected void setupLocal() {
            this._model = this._model_key.get();
        }

        public void map(Chunk[] chunkArr, NewChunk[] newChunkArr) {
            double d = this._model._output._rho;
            SupportVectorScorer makeScorer = ScorerFactory.makeScorer(this._model._parms._kernel_type, this._model._parms.kernelParms(), this._model._output._compressed_svs);
            double[] dArr = new double[chunkArr.length];
            for (int i = 0; i < chunkArr[0]._len; i++) {
                for (int i2 = 0; i2 < chunkArr.length; i2++) {
                    dArr[i2] = chunkArr[i2].atd(i);
                }
                newChunkArr[0].addNum(makeScorer.score0(dArr) + d);
            }
            for (double d2 : BulkScorerFactory.makeScorer(this._model._parms._kernel_type, this._model._parms.kernelParms(), this._model._output._compressed_svs, (int) this._model._output._svs_count, true).bulkScore0(chunkArr)) {
                newChunkArr[1].addNum(d2 + d);
            }
            for (double d3 : BulkScorerFactory.makeScorer(this._model._parms._kernel_type, this._model._parms.kernelParms(), this._model._output._compressed_svs, (int) this._model._output._svs_count, true).bulkScore0(chunkArr)) {
                newChunkArr[2].addNum(d3 + d);
            }
        }
    }

    /* loaded from: input_file:hex/psvm/PSVMTest$SVMTrainer.class */
    private static class SVMTrainer extends H2O.RemoteRunnable<SVMTrainer> {
        private final PSVMModel.PSVMParameters _parms;
        private PSVMModel _model;

        private SVMTrainer(PSVMModel.PSVMParameters pSVMParameters) {
            this._parms = pSVMParameters;
        }

        public void run() {
            this._model = new PSVM(this._parms).trainModel().get();
        }

        /* JADX INFO: Access modifiers changed from: private */
        public PSVMModel train() {
            return ((SVMTrainer) H2O.runOnLeaderNode(this))._model;
        }
    }

    /** Waits until the test cloud has reached the requested size (1) before any test runs. */
    @BeforeClass
    public static void setup() {
        final int requiredCloudSize = 1;
        TestUtil.stall_till_cloudsize(requiredCloudSize);
    }

    /**
     * Trains a PSVM model on the splice dataset (through {@link SVMTrainer}) and checks the
     * fitted {@code rho}, the support vector counts, the predictions against a reference
     * prediction file, the confusion matrix and the individual scorers.
     */
    @Test
    public void testSplice() {
        try {
            Scope.enter();
            Frame trainFrame = parse_test_file("./smalldata/splice/splice.svm");
            Scope.track(trainFrame);

            PSVMModel.PSVMParameters parms = new PSVMModel.PSVMParameters();
            parms._gamma = 0.01d;
            parms._rank_ratio = 0.1d;
            parms._train = trainFrame._key;
            parms._response_column = "C1";

            PSVMModel model = new SVMTrainer(parms).train();
            Assert.assertNotNull(model);
            Scope.track_generic(model);
            Assert.assertEquals(2.38873807d, model._output._rho, 1.0E-6d);
            Assert.assertEquals(662L, model._output._svs_count);
            Assert.assertEquals(612L, model._output._bsv_count);
            Assert.assertNotNull(model._output._compressed_svs);
            Assert.assertNotEquals(0L, model._output._compressed_svs.length);

            Frame expected = parse_test_file("./smalldata/splice/splice_icf100_preds.csv");
            Scope.track(expected);
            // Replace the reference "predict" column with a categorical built from the
            // AstSgn transform of the reference "score" column.
            int predictIdx = expected.find("predict");
            Vec signAsString = Scope.track(
                    new TransformWrappedVec(expected.vec("score"), new AstSgn()).toStringVec());
            Vec expectedLabels = Scope.track(signAsString.toCategoricalVec());
            expected.replace(predictIdx, expectedLabels);

            Frame predicted = model.score(trainFrame);
            Scope.track(predicted);
            // Explicit cast: the declared variable type is narrower than ModelMetrics.
            ModelMetricsSupervised metrics =
                    (ModelMetricsSupervised) ModelMetrics.getFromDKV(model, trainFrame);
            Assert.assertNotNull(metrics);
            Scope.track_generic(metrics);
            System.out.println(predicted.toTwoDimTable().toString());

            assertVecEquals(expected.vec("predict"), predicted.vec("predict"), 0.0d);
            checkCM(model, trainFrame, trainFrame.vec("C1"), predicted.vec(0));
            checkScorers(model, trainFrame, expected.vec("score"));
        } finally {
            // Exit the scope exactly once, whether or not the test body failed.
            Scope.exit();
        }
    }

    /**
     * Trains a PSVM model on the prostate training file, scores both the training and the
     * test file, verifies that metrics were stored for each, and compares the test-set
     * confusion matrix with one built from the predictions.
     */
    @Test
    public void testProstate() {
        try {
            Scope.enter();
            Frame train = parse_test_file("./smalldata/logreg/prostate_train.csv").toCategoricalCol("CAPSULE");
            Scope.track(train);
            Frame test = parse_test_file("./smalldata/logreg/prostate_test.csv").toCategoricalCol("CAPSULE");
            Scope.track(test);

            PSVMModel.PSVMParameters parms = new PSVMModel.PSVMParameters();
            parms._train = train._key;
            parms._response_column = "CAPSULE";
            parms._gamma = 0.1d;
            parms._hyper_param = 2.0d;

            PSVMModel model = new PSVM(parms).trainModel().get();
            Assert.assertNotNull(model);
            Scope.track_generic(model);

            Frame trainPredictions = model.score(train);
            Scope.track(trainPredictions);
            Assert.assertNotNull(ModelMetrics.getFromDKV(model, train));

            Frame testPredictions = model.score(test);
            Scope.track(testPredictions);
            Assert.assertNotNull(ModelMetrics.getFromDKV(model, test));

            checkCM(model, test, test.vec(parms._response_column), testPredictions.vec(0));
        } finally {
            // Exit the scope exactly once, whether or not the test body failed.
            Scope.exit();
        }
    }

    /**
     * Trains on prostate twice - once with the frame as parsed and once with the frame one-hot
     * encoded by the test itself - and requires identical training and validation MSE.
     */
    @Test
    public void testProstateWithCategoricals() {
        final double exact = 0.0d;
        PSVMModel.PSVMModelOutput plainOutput = trainOnProstate(false);
        PSVMModel.PSVMModelOutput encodedOutput = trainOnProstate(true);

        double expectedTrainMse = encodedOutput._training_metrics._MSE;
        double actualTrainMse = plainOutput._training_metrics._MSE;
        Assert.assertEquals(expectedTrainMse, actualTrainMse, exact);

        double expectedValidMse = encodedOutput._validation_metrics._MSE;
        double actualValidMse = plainOutput._validation_metrics._MSE;
        Assert.assertEquals(expectedValidMse, actualValidMse, exact);
    }

    /**
     * Trains a PSVM model on a train/validation split of the prostate dataset.
     *
     * @param z when {@code true}, RACE is moved to the first column and the frame is passed
     *          through {@code FrameUtils.CategoricalOneHotEncoder} before splitting
     *          (NOTE(review): the {@code "CAPSULE"} array handed to the encoder is presumably
     *          the list of columns to leave un-encoded - confirm)
     * @return the output of the trained model; the scope is exited before returning
     */
    private PSVMModel.PSVMModelOutput trainOnProstate(boolean z) {
        try {
            Scope.enter();
            Frame data = parse_test_file("./smalldata/logreg/prostate.csv")
                    .toCategoricalCol("CAPSULE")
                    .toCategoricalCol("RACE");
            Scope.track(data);

            if (z) {
                Vec race = data.remove("RACE");
                data.insertVec(0, "RACE", race);
                Frame encoded = (Frame) new FrameUtils.CategoricalOneHotEncoder(data, new String[]{"CAPSULE"}).exec().get();
                Scope.track(encoded);
                data = encoded;
            }

            Frame[] splits = splitFrameTrainValid(data, 0.8d, -889275714L);
            Frame train = Scope.track(splits[0]);
            Frame valid = Scope.track(splits[1]);

            PSVMModel.PSVMParameters parms = new PSVMModel.PSVMParameters();
            parms._train = train._key;
            parms._valid = valid._key;
            parms._response_column = "CAPSULE";
            parms._ignored_columns = new String[]{"ID"};
            parms._gamma = 0.4d;
            parms._hyper_param = 2.0d;
            parms._disable_training_metrics = false;

            PSVMModel model = new PSVM(parms).trainModel().get();
            Assert.assertNotNull(model);
            Scope.track_generic(model);
            return model._output;
        } finally {
            // Exit the scope exactly once, on both the normal and the failing path.
            Scope.exit();
        }
    }

    /**
     * Trains a PSVM model on svmguide1, scores the matching test file and checks that the
     * stored binomial metrics report an MSE within 0.05 of 0.1.
     */
    @Test
    public void testSVMGuide1() {
        try {
            Scope.enter();
            Frame train = parse_test_file("./smalldata/svm_test/svmguide1.svm").toCategoricalCol("C1");
            Scope.track(train);

            PSVMModel.PSVMParameters parms = new PSVMModel.PSVMParameters();
            parms._train = train._key;
            parms._response_column = "C1";
            parms._gamma = 0.1d;

            PSVMModel model = new PSVM(parms).trainModel().get();
            Assert.assertNotNull(model);
            Scope.track_generic(model);

            Frame test = parse_test_file("./smalldata/svm_test/svmguide1_test.svm").toCategoricalCol("C1");
            Scope.track(test);
            Frame predictions = model.score(test);
            Scope.track(predictions);

            // Explicit cast: the declared variable type is narrower than ModelMetrics.
            ModelMetricsBinomial metrics = (ModelMetricsBinomial) ModelMetrics.getFromDKV(model, test);
            Assert.assertNotNull(metrics);
            Assert.assertEquals(0.1d, metrics.mse(), 0.05d);
        } finally {
            // Exit the scope exactly once, whether or not the test body failed.
            Scope.exit();
        }
    }

    /**
     * Trains a PSVM model on the scaled svmguide3 dataset, scores the matching test file,
     * checks that the share of rows counted by {@code nzCnt()} in the first prediction column
     * is within 0.15 of 1.0, and validates the confusion matrix.
     */
    @Test
    public void testSVMGuide3() {
        try {
            Scope.enter();
            Frame train = parse_test_file("./smalldata/svm_test/svmguide3scale.svm");
            Scope.track(train);

            PSVMModel.PSVMParameters parms = new PSVMModel.PSVMParameters();
            parms._train = train._key;
            parms._response_column = "C1";
            parms._gamma = 0.125d;
            parms._hyper_param = 1.0d;

            PSVMModel model = new PSVM(parms).trainModel().get();
            Assert.assertNotNull(model);
            Scope.track_generic(model);

            Frame test = parse_test_file("./smalldata/svm_test/svmguide3scale_test.svm");
            Scope.track(test);
            Frame predictions = model.score(test);
            Scope.track(predictions);

            // Divide in floating point: an integer division would truncate the ratio to a
            // whole number and make the 0.15 tolerance meaningless.
            double nonZeroShare = (double) predictions.vec(0).nzCnt() / predictions.numRows();
            Assert.assertEquals(1.0d, nonZeroShare, 0.15d);

            checkCM(model, test, test.vec(parms._response_column), predictions.vec(0));
        } finally {
            // Exit the scope exactly once, whether or not the test body failed.
            Scope.exit();
        }
    }

    /**
     * Shuffle-splits a frame into two parts whose keys are the source frame's key suffixed
     * with {@code "_train"} and {@code "_valid"}.
     *
     * @param frame frame to split
     * @param d     ratio for the first part; the second part gets {@code 1 - d}
     * @param j     value forwarded as the last argument of {@code shuffleSplitFrame}
     *              (presumably the shuffle seed - confirm)
     * @return the frames returned by {@code ShuffleSplitFrame.shuffleSplitFrame}
     */
    private static Frame[] splitFrameTrainValid(Frame frame, double d, long j) {
        Key[] destinationKeys = {
                Key.make(frame._key + "_train"),
                Key.make(frame._key + "_valid")
        };
        double[] ratios = {d, 1.0d - d};
        return ShuffleSplitFrame.shuffleSplitFrame(frame, destinationKeys, ratios, j);
    }

    /**
     * Compares the confusion matrix stored in the model's binomial metrics for {@code frame}
     * with one built directly from the actual and predicted label vectors.
     *
     * @param pSVMModel model whose metrics are looked up and whose response domain is used
     * @param frame     frame the metrics were computed on
     * @param vec       actual labels; if not categorical it is converted, a trailing
     *                  {@code "1"} level is renamed to {@code "+1"}, and the result is
     *                  adapted to the model's response domain
     * @param vec2      predicted labels; converted to categorical when necessary
     */
    private static void checkCM(PSVMModel pSVMModel, Frame frame, Vec vec, Vec vec2) {
        String[] responseDomain = pSVMModel._output._domains[pSVMModel._output.responseIdx()];
        Scope.enter();
        try {
            Vec actual = vec;
            if (!actual.isCategorical()) {
                Vec asCategorical = Scope.track(actual.toCategoricalVec());
                String[] levels = asCategorical.domain();
                int last = levels.length - 1;
                if ("1".equals(levels[last])) {
                    levels[last] = "+1";
                }
                actual = Scope.track(asCategorical.adaptTo(responseDomain));
            }

            Vec predicted = vec2;
            if (!predicted.isCategorical()) {
                predicted = Scope.track(predicted.toCategoricalVec());
            }

            ConfusionMatrix expectedCM = ConfusionMatrixTest.buildCM(actual, predicted);
            ConfusionMatrix modelCM = ModelMetricsBinomial.getFromDKV(pSVMModel, frame).cm();
            System.out.println(modelCM.table().toString());
            ConfusionMatrixTest.assertCMEqual(responseDomain, expectedCM._cm, modelCM);
        } finally {
            // Exit the scope exactly once, whether or not an assertion failed.
            Scope.exit();
        }
    }

    /**
     * Runs {@link CheckScorersTask} over the predictors of {@code frame} and requires each of
     * its three output columns to match the expected scores within 1e-6.
     *
     * @param pSVMModel model to score with; its response column must be column 0 of the frame
     * @param frame     frame containing the response followed by the predictors
     * @param vec       expected scores
     */
    private static void checkScorers(PSVMModel pSVMModel, Frame frame, Vec vec) {
        final String responseColumn = pSVMModel._parms._response_column;
        Assert.assertEquals(responseColumn, frame.name(0));

        // Work on a copy of the frame so the caller's frame keeps its response column.
        Frame predictors = new Frame(frame);
        predictors.remove(responseColumn);

        // Three output columns of type code 3 (presumably the numeric vec type - confirm).
        final int outputColumns = 3;
        CheckScorersTask task = new CheckScorersTask(pSVMModel._key);
        Frame scores = task.doAll(outputColumns, (byte) 3, predictors).outputFrame();
        Scope.track(scores);

        for (int col = 0; col < outputColumns; col++) {
            assertVecEquals(vec, scores.vec(col), 1.0E-6d);
        }
    }
}
