/*
 * Decompiled with CFR 0.152.
 */
package hex.psvm;

import hex.ConfusionMatrix;
import hex.ConfusionMatrixTest;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsSupervised;
import hex.genmodel.algos.psvm.KernelParameters;
import hex.genmodel.algos.psvm.KernelType;
import hex.genmodel.algos.psvm.ScorerFactory;
import hex.genmodel.algos.psvm.SupportVectorScorer;
import hex.psvm.BulkScorerFactory;
import hex.psvm.BulkSupportVectorScorer;
import hex.psvm.PSVM;
import hex.psvm.PSVMModel;
import hex.splitframe.ShuffleSplitFrame;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.H2O;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.Scope;
import water.TestUtil;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.TransformWrappedVec;
import water.fvec.Vec;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.prims.math.AstSgn;
import water.util.FrameUtils;

public class PSVMTest
extends TestUtil {
    /**
     * One-time fixture: waits until the cloud has reached a size of one node before any
     * test in this class runs.
     */
    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Trains a PSVM model on the splice dataset (response column {@code C1}) through
     * {@link SVMTrainer} and verifies the trained output, the predictions, the confusion
     * matrix and the standalone scorers.
     *
     * <p>Reference predictions come from {@code splice_icf100_preds.csv}. Its
     * {@code predict} column is replaced by the sign of the {@code score} column
     * (string vec converted to categorical) before it is compared with the model's own
     * {@code predict} column.
     */
    @Test
    public void testSplice() {
        try {
            Scope.enter();
            Frame splice = parse_test_file("./smalldata/splice/splice.svm");
            Scope.track(splice);

            PSVMModel.PSVMParameters params = new PSVMModel.PSVMParameters();
            params._gamma = 0.01;
            params._rank_ratio = 0.1;
            params._train = splice._key;
            params._response_column = "C1";

            PSVMModel model = new SVMTrainer(params).train();
            Assert.assertNotNull(model);
            Scope.track_generic(model);

            // Sanity-check the trained model output.
            PSVMModel.PSVMModelOutput output = model._output;
            Assert.assertEquals(2.38873807, output._rho, 1.0E-6);
            Assert.assertEquals(662L, output._svs_count);
            Assert.assertEquals(612L, output._bsv_count);
            Assert.assertNotNull(output._compressed_svs);
            Assert.assertNotEquals(0L, (long) output._compressed_svs.length);

            // Load the reference predictions and rebuild their "predict" column from sign(score).
            Frame reference = parse_test_file("./smalldata/splice/splice_icf100_preds.csv");
            Scope.track(reference);
            int predictIdx = reference.find("predict");
            Vec signAsString = Scope.track(
                    new TransformWrappedVec(reference.vec("score"), new AstSgn()).toStringVec());
            Vec signAsCategorical = Scope.track(signAsString.toCategoricalVec());
            reference.replace(predictIdx, signAsCategorical);

            // Score the training frame; metrics are expected to be stored in the DKV as a side effect.
            Frame scored = model.score(splice);
            Scope.track(scored);
            ModelMetricsSupervised metrics = (ModelMetricsSupervised) ModelMetrics.getFromDKV(model, splice);
            Assert.assertNotNull(metrics);
            Scope.track_generic(metrics);

            System.out.println(scored.toTwoDimTable().toString());

            assertVecEquals(reference.vec("predict"), scored.vec("predict"), 0.0);
            checkCM(model, splice, splice.vec("C1"), scored.vec(0));
            checkScorers(model, splice, reference.vec("score"));
        } finally {
            Scope.exit();
        }
    }

    /**
     * Trains a PSVM model on the prostate training set (categorical response
     * {@code CAPSULE}), scores both the training and the test set, checks that binomial
     * metrics exist for each, and validates the test-set confusion matrix.
     */
    @Test
    public void testProstate() {
        try {
            Scope.enter();
            Frame trainFrame = parse_test_file("./smalldata/logreg/prostate_train.csv").toCategoricalCol("CAPSULE");
            Scope.track(trainFrame);
            Frame testFrame = parse_test_file("./smalldata/logreg/prostate_test.csv").toCategoricalCol("CAPSULE");
            Scope.track(testFrame);

            PSVMModel.PSVMParameters params = new PSVMModel.PSVMParameters();
            params._train = trainFrame._key;
            params._response_column = "CAPSULE";
            params._gamma = 0.1;
            params._hyper_param = 2.0;

            PSVMModel model = new PSVM(params).trainModel().get();
            Assert.assertNotNull(model);
            Scope.track_generic(model);

            // Scoring the training frame must leave binomial metrics behind.
            Frame trainPredictions = model.score(trainFrame);
            Scope.track(trainPredictions);
            ModelMetricsBinomial trainMetrics = (ModelMetricsBinomial) ModelMetrics.getFromDKV(model, trainFrame);
            Assert.assertNotNull(trainMetrics);

            // Same for the held-out test frame.
            Frame testPredictions = model.score(testFrame);
            Scope.track(testPredictions);
            ModelMetricsBinomial testMetrics = (ModelMetricsBinomial) ModelMetrics.getFromDKV(model, testFrame);
            Assert.assertNotNull(testMetrics);

            checkCM(model, testFrame, testFrame.vec(params._response_column), testPredictions.vec(0));
        } finally {
            Scope.exit();
        }
    }

    /**
     * Trains on the prostate data twice, once with the frame as parsed and once with the
     * frame one-hot encoded up front, and requires the training and validation MSE of
     * both runs to be exactly equal.
     */
    @Test
    public void testProstateWithCategoricals() {
        PSVMModel.PSVMModelOutput plainOutput = trainOnProstate(false);
        PSVMModel.PSVMModelOutput oneHotOutput = trainOnProstate(true);

        double oneHotTrainMse = oneHotOutput._training_metrics._MSE;
        double plainTrainMse = plainOutput._training_metrics._MSE;
        Assert.assertEquals(oneHotTrainMse, plainTrainMse, 0.0);

        double oneHotValidMse = oneHotOutput._validation_metrics._MSE;
        double plainValidMse = plainOutput._validation_metrics._MSE;
        Assert.assertEquals(oneHotValidMse, plainValidMse, 0.0);
    }

    /**
     * Trains a PSVM model on {@code prostate.csv} with an 80/20 train/validation split and
     * returns the model's output object.
     *
     * <p>{@code CAPSULE} and {@code RACE} are converted to categorical columns. The
     * {@code ID} column is ignored and training metrics are explicitly left enabled.
     * All tracked keys are released in the {@code finally} block before the output is
     * handed to the caller.
     *
     * @param encode when {@code true}, {@code RACE} is moved to column 0 and the frame is
     *               passed through {@link FrameUtils.CategoricalOneHotEncoder} (with
     *               {@code CAPSULE} given as the column to skip) before splitting
     * @return the output of the trained model
     */
    private PSVMModel.PSVMModelOutput trainOnProstate(boolean encode) {
        try {
            Scope.enter();
            Frame data = parse_test_file("./smalldata/logreg/prostate.csv")
                    .toCategoricalCol("CAPSULE")
                    .toCategoricalCol("RACE");
            Scope.track(data);

            if (encode) {
                // Move RACE to the front, then one-hot encode everything except the response.
                data.insertVec(0, "RACE", data.remove("RACE"));
                String[] skippedColumns = {"CAPSULE"};
                Frame oneHot = new FrameUtils.CategoricalOneHotEncoder(data, skippedColumns).exec().get();
                Scope.track(oneHot);
                data = oneHot;
            }

            // Fixed seed keeps the split identical between the encoded and the plain run.
            final long splitSeed = 0xCAFEBABE;
            Frame[] parts = splitFrameTrainValid(data, 0.8, splitSeed);
            Frame trainPart = Scope.track(parts[0]);
            Frame validPart = Scope.track(parts[1]);

            PSVMModel.PSVMParameters params = new PSVMModel.PSVMParameters();
            params._train = trainPart._key;
            params._valid = validPart._key;
            params._response_column = "CAPSULE";
            params._ignored_columns = new String[]{"ID"};
            params._gamma = 0.4;
            params._hyper_param = 2.0;
            params._disable_training_metrics = false;

            PSVMModel model = new PSVM(params).trainModel().get();
            Assert.assertNotNull(model);
            Scope.track_generic(model);

            return model._output;
        } finally {
            Scope.exit();
        }
    }

    /**
     * Trains on {@code svmguide1.svm} (categorical response {@code C1}), scores
     * {@code svmguide1_test.svm} and requires the test-set MSE to be within 0.05 of 0.1.
     */
    @Test
    public void testSVMGuide1() {
        try {
            Scope.enter();
            Frame trainFrame = parse_test_file("./smalldata/svm_test/svmguide1.svm").toCategoricalCol("C1");
            Scope.track(trainFrame);

            PSVMModel.PSVMParameters params = new PSVMModel.PSVMParameters();
            params._train = trainFrame._key;
            params._response_column = "C1";
            params._gamma = 0.1;

            PSVMModel model = new PSVM(params).trainModel().get();
            Assert.assertNotNull(model);
            Scope.track_generic(model);

            Frame testFrame = parse_test_file("./smalldata/svm_test/svmguide1_test.svm").toCategoricalCol("C1");
            Scope.track(testFrame);
            Frame testPredictions = model.score(testFrame);
            Scope.track(testPredictions);

            ModelMetricsBinomial testMetrics = (ModelMetricsBinomial) ModelMetrics.getFromDKV(model, testFrame);
            Assert.assertNotNull(testMetrics);
            Assert.assertEquals(0.1, testMetrics.mse(), 0.05);
        } finally {
            Scope.exit();
        }
    }

    /**
     * Trains on {@code svmguide3scale.svm} (response {@code C1}, left as parsed), scores
     * {@code svmguide3scale_test.svm}, requires the fraction of non-zero entries in the
     * first prediction column to be within 0.15 of 1.0, and validates the confusion matrix.
     */
    @Test
    public void testSVMGuide3() {
        try {
            Scope.enter();
            Frame trainFrame = parse_test_file("./smalldata/svm_test/svmguide3scale.svm");
            Scope.track(trainFrame);

            PSVMModel.PSVMParameters params = new PSVMModel.PSVMParameters();
            params._train = trainFrame._key;
            params._response_column = "C1";
            params._gamma = 0.125;
            params._hyper_param = 1.0;

            PSVMModel model = new PSVM(params).trainModel().get();
            Assert.assertNotNull(model);
            Scope.track_generic(model);

            Frame testFrame = parse_test_file("./smalldata/svm_test/svmguide3scale_test.svm");
            Scope.track(testFrame);
            Frame testPredictions = model.score(testFrame);
            Scope.track(testPredictions);

            // Share of rows whose first prediction column is non-zero.
            double nonZeroFraction = (double) testPredictions.vec(0).nzCnt() / testPredictions.numRows();
            Assert.assertEquals(1.0, nonZeroFraction, 0.15);

            checkCM(model, testFrame, testFrame.vec(params._response_column), testPredictions.vec(0));
        } finally {
            Scope.exit();
        }
    }

    /**
     * Shuffle-splits a frame into two parts via {@link ShuffleSplitFrame}.
     *
     * @param fr    frame to split; its key is used as the prefix of both destination keys
     * @param ratio fraction requested for the first part; the second gets {@code 1 - ratio}
     * @param seed  seed handed to the shuffle split
     * @return two frames, keyed {@code <fr key>_train} and {@code <fr key>_valid}, in that order
     */
    private static Frame[] splitFrameTrainValid(Frame fr, double ratio, long seed) {
        Key[] destinationKeys = {
                Key.make(fr._key + "_train"),
                Key.make(fr._key + "_valid")
        };
        double[] ratios = {ratio, 1.0 - ratio};
        return ShuffleSplitFrame.shuffleSplitFrame(fr, destinationKeys, ratios, seed);
    }

    /**
     * Rebuilds a confusion matrix from the actual and predicted vecs and asserts that it
     * matches the one stored in the binomial metrics of {@code model} on {@code frame}.
     *
     * <p>Non-categorical inputs are converted to categorical first. For the actuals, a
     * trailing domain level {@code "1"} is renamed to {@code "+1"} and the vec is then
     * adapted to the model's response domain.
     *
     * @param model     trained model whose stored metrics are checked
     * @param frame     frame the metrics were computed on (used for the DKV lookup)
     * @param actuals   true response values
     * @param predicted predicted labels
     */
    private static void checkCM(PSVMModel model, Frame frame, Vec actuals, Vec predicted) {
        PSVMModel.PSVMModelOutput output = model._output;
        String[] responseDomain = output._domains[output.responseIdx()];
        Scope.enter();
        try {
            Vec actualLabels = actuals;
            if (!actualLabels.isCategorical()) {
                actualLabels = Scope.track(actualLabels.toCategoricalVec());
                // Rename the last level in place so that "1" lines up with "+1".
                String[] levels = actualLabels.domain();
                int lastLevel = levels.length - 1;
                if ("1".equals(levels[lastLevel])) {
                    levels[lastLevel] = "+1";
                }
                actualLabels = Scope.track(actualLabels.adaptTo(responseDomain));
            }

            Vec predictedLabels = predicted;
            if (!predictedLabels.isCategorical()) {
                predictedLabels = Scope.track(predictedLabels.toCategoricalVec());
            }

            ConfusionMatrix rebuilt = ConfusionMatrixTest.buildCM(actualLabels, predictedLabels);
            ConfusionMatrix stored = ModelMetricsBinomial.getFromDKV(model, frame).cm();
            System.out.println(stored.table().toString());
            ConfusionMatrixTest.assertCMEqual(responseDomain, rebuilt._cm, stored);
        } finally {
            Scope.exit();
        }
    }

    /**
     * Runs {@link CheckScorersTask} over the feature columns of {@code f} and asserts that
     * each of its three output columns equals {@code expected} within 1e-6.
     *
     * @param model    trained model; its response column must be the first column of {@code f}
     * @param f        frame to score (the response column is dropped from a copy, not from {@code f})
     * @param expected reference scores
     */
    private static void checkScorers(PSVMModel model, Frame f, Vec expected) {
        final String responseColumn = model._parms._response_column;
        Assert.assertEquals(responseColumn, f.name(0));

        // Work on a copy of the frame so the caller's frame keeps its response column.
        Frame featuresOnly = new Frame(f);
        featuresOnly.remove(model._parms._response_column);

        final int scorerCount = 3;
        Frame scores = new CheckScorersTask(model._key)
                .doAll(scorerCount, Vec.T_NUM, featuresOnly)
                .outputFrame();
        Scope.track(scores);

        for (int col = 0; col < scorerCount; col++) {
            assertVecEquals(expected, scores.vec(col), 1.0E-6);
        }
    }

    private static class CheckScorersTask
    extends MRTask {
        private final Key<PSVMModel> _model_key;
        private transient PSVMModel _model;

        CheckScorersTask(Key<PSVMModel> modelKey) {
            this._model_key = modelKey;
        }

        protected void setupLocal() {
            this._model = (PSVMModel)this._model_key.get();
        }

        public void map(Chunk[] cs, NewChunk[] ncs) {
            double[] scoresPojo;
            double[] scoresRaw;
            double rho = ((PSVMModel.PSVMModelOutput)this._model._output)._rho;
            SupportVectorScorer scorer = ScorerFactory.makeScorer((KernelType)((PSVMModel.PSVMParameters)this._model._parms)._kernel_type, (KernelParameters)((PSVMModel.PSVMParameters)this._model._parms).kernelParms(), (byte[])((PSVMModel.PSVMModelOutput)this._model._output)._compressed_svs);
            double[] row = new double[cs.length];
            for (int i = 0; i < cs[0]._len; ++i) {
                for (int j = 0; j < cs.length; ++j) {
                    row[j] = cs[j].atd(i);
                }
                double s = scorer.score0(row);
                ncs[0].addNum(s + rho);
            }
            BulkSupportVectorScorer rawBulkScorer = BulkScorerFactory.makeScorer((KernelType)((PSVMModel.PSVMParameters)this._model._parms)._kernel_type, (KernelParameters)((PSVMModel.PSVMParameters)this._model._parms).kernelParms(), (byte[])((PSVMModel.PSVMModelOutput)this._model._output)._compressed_svs, (int)((int)((PSVMModel.PSVMModelOutput)this._model._output)._svs_count), (boolean)true);
            for (double s : scoresRaw = rawBulkScorer.bulkScore0(cs)) {
                ncs[1].addNum(s + rho);
            }
            BulkSupportVectorScorer pojoBulkScorer = BulkScorerFactory.makeScorer((KernelType)((PSVMModel.PSVMParameters)this._model._parms)._kernel_type, (KernelParameters)((PSVMModel.PSVMParameters)this._model._parms).kernelParms(), (byte[])((PSVMModel.PSVMModelOutput)this._model._output)._compressed_svs, (int)((int)((PSVMModel.PSVMModelOutput)this._model._output)._svs_count), (boolean)true);
            for (double s : scoresPojo = pojoBulkScorer.bulkScore0(cs)) {
                ncs[2].addNum(s + rho);
            }
        }
    }

    private static class SVMTrainer
    extends H2O.RemoteRunnable<SVMTrainer> {
        private final PSVMModel.PSVMParameters _parms;
        private PSVMModel _model;

        private SVMTrainer(PSVMModel.PSVMParameters parms) {
            this._parms = parms;
        }

        public void run() {
            this._model = (PSVMModel)new PSVM(this._parms).trainModel().get();
        }

        private PSVMModel train() {
            return ((SVMTrainer)H2O.runOnLeaderNode((H2O.RemoteRunnable)this))._model;
        }
    }
}

