/*
 * Decompiled with CFR 0.152.
 */
package hex.util;

import hex.CreateFrame;
import hex.DataInfo;
import hex.Model;
import hex.aggregator.Aggregator;
import hex.aggregator.AggregatorModel;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import water.DKV;
import water.H2O;
import water.Key;
import water.Keyed;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.RebalanceDataSet;
import water.fvec.Vec;
import water.parser.ParseDataset;
import water.parser.ParserTest;
import water.util.Log;

public class AggregatorTest
extends TestUtil {
    /** Calls {@code stall_till_cloudsize(1)} once before the tests in this class run. */
    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /** Runs the shared {@link #testAggregator(int)} scenario with a target of 100 exemplars. */
    @Test
    public void testAggregator100() {
        final int targetExemplars = 100;
        testAggregator(targetExemplars);
    }

    /** Runs the shared {@link #testAggregator(int)} scenario with a target of 1000 exemplars. */
    @Test
    public void testAggregator1k() {
        final int targetExemplars = 1000;
        testAggregator(targetExemplars);
    }

    /** Runs the shared {@link #testAggregator(int)} scenario with a target of 13 exemplars. */
    @Test
    public void testAggregator13() {
        final int targetExemplars = 13;
        testAggregator(targetExemplars);
    }

    /** Runs the shared {@link #testAggregator(int)} scenario with a target of 10000 exemplars. */
    @Test
    public void testAggregator10k() {
        final int targetExemplars = 10000;
        testAggregator(targetExemplars);
    }

    /**
     * Shared scenario: generates a 100000-row, 2-column frame with {@link CreateFrame}
     * (fixed seed 1234), trains an {@link Aggregator} on it, prints the first rows of the
     * output frame and verifies the exemplar count with {@link #checkNumExemplars}.
     *
     * @param max value assigned to {@code _target_num_exemplars}
     */
    public void testAggregator(int max) {
        // Describe the synthetic input frame.
        CreateFrame generator = new CreateFrame();
        generator.seed = 1234L;
        generator.rows = 100000L;
        generator.cols = 2;
        generator.categorical_fraction = 0.1;
        generator.integer_fraction = 0.3;
        generator.integer_range = 100L;
        generator.real_range = 100L;
        Frame input = (Frame) generator.execImpl().get();

        AggregatorModel.AggregatorParameters params = new AggregatorModel.AggregatorParameters();
        params._target_num_exemplars = max;
        params._train = input._key;

        // Train and report wall-clock time.
        long startedAt = System.currentTimeMillis();
        AggregatorModel model = (AggregatorModel) new Aggregator(params).trainModel().get();
        double elapsedSeconds = (System.currentTimeMillis() - startedAt) / 1000.0;
        System.out.println("AggregatorModel finished in: " + elapsedSeconds + " seconds");
        model.checkConsistency();

        AggregatorModel.AggregatorOutput modelOutput = (AggregatorModel.AggregatorOutput) model._output;
        Frame exemplars = (Frame) modelOutput._output_frame.get();
        System.out.println(exemplars.toTwoDimTable(0L, 10));

        // Cleanup order is kept as-is: input first, then the check, then the remaining keys.
        input.delete();
        checkNumExemplars(model);
        exemplars.remove();
        model.remove();
    }

    /**
     * Trains an {@link Aggregator} with {@code Eigen} categorical encoding on a generated
     * 1000x10 frame whose categorical columns are created with {@code factors = 5}, then
     * runs the model's consistency check and logs the exemplar count.
     */
    @Test
    public void testAggregatorEigen() {
        CreateFrame generator = new CreateFrame();
        generator.seed = 1234L;
        generator.rows = 1000L;
        generator.cols = 10;
        generator.factors = 5;
        generator.categorical_fraction = 0.6;
        generator.integer_fraction = 0.0;
        generator.binary_fraction = 0.0;
        generator.missing_fraction = 0.0;
        generator.real_range = 100L;
        generator.integer_range = 100L;
        Frame input = (Frame) generator.execImpl().get();

        AggregatorModel.AggregatorParameters params = new AggregatorModel.AggregatorParameters();
        params._categorical_encoding = Model.Parameters.CategoricalEncodingScheme.Eigen;
        params._train = input._key;

        long startedAt = System.currentTimeMillis();
        AggregatorModel model = (AggregatorModel) new Aggregator(params).trainModel().get();
        double elapsedSeconds = (System.currentTimeMillis() - startedAt) / 1000.0;
        System.out.println("AggregatorModel finished in: " + elapsedSeconds + " seconds");
        model.checkConsistency();

        AggregatorModel.AggregatorOutput modelOutput = (AggregatorModel.AggregatorOutput) model._output;
        Frame exemplars = (Frame) modelOutput._output_frame.get();
        System.out.println(exemplars.toTwoDimTable(0L, 10));
        Log.info("Number of exemplars: " + model._exemplars.length);

        exemplars.remove();
        input.remove();
        model.remove();
    }

    /**
     * Same flow as the Eigen test but on a generated 10000x10 frame whose categorical
     * columns are created with {@code factors = 1000}; uses {@code Eigen} encoding.
     */
    @Test
    public void testAggregatorEigenHighCardinality() {
        CreateFrame generator = new CreateFrame();
        generator.seed = 1234L;
        generator.rows = 10000L;
        generator.cols = 10;
        generator.factors = 1000;
        generator.categorical_fraction = 0.6;
        generator.integer_fraction = 0.0;
        generator.binary_fraction = 0.0;
        generator.missing_fraction = 0.0;
        generator.real_range = 100L;
        generator.integer_range = 100L;
        Frame input = (Frame) generator.execImpl().get();

        AggregatorModel.AggregatorParameters params = new AggregatorModel.AggregatorParameters();
        params._categorical_encoding = Model.Parameters.CategoricalEncodingScheme.Eigen;
        params._train = input._key;

        long startedAt = System.currentTimeMillis();
        AggregatorModel model = (AggregatorModel) new Aggregator(params).trainModel().get();
        double elapsedSeconds = (System.currentTimeMillis() - startedAt) / 1000.0;
        System.out.println("AggregatorModel finished in: " + elapsedSeconds + " seconds");
        model.checkConsistency();

        AggregatorModel.AggregatorOutput modelOutput = (AggregatorModel.AggregatorOutput) model._output;
        Frame exemplars = (Frame) modelOutput._output_frame.get();
        System.out.println(exemplars.toTwoDimTable(0L, 10));
        Log.info("Number of exemplars: " + model._exemplars.length);

        exemplars.remove();
        input.remove();
        model.remove();
    }

    /**
     * High-cardinality scenario (10000x10 frame, {@code factors = 1000}) trained with
     * {@code Enum} categorical encoding; prints every row of the output frame.
     */
    @Test
    public void testAggregatorEigenHighCardinalityEnum() {
        CreateFrame generator = new CreateFrame();
        generator.seed = 1234L;
        generator.rows = 10000L;
        generator.cols = 10;
        generator.factors = 1000;
        generator.categorical_fraction = 0.6;
        generator.integer_fraction = 0.0;
        generator.binary_fraction = 0.0;
        generator.missing_fraction = 0.0;
        generator.real_range = 100L;
        generator.integer_range = 100L;
        Frame input = (Frame) generator.execImpl().get();

        AggregatorModel.AggregatorParameters params = new AggregatorModel.AggregatorParameters();
        params._categorical_encoding = Model.Parameters.CategoricalEncodingScheme.Enum;
        params._train = input._key;

        long startedAt = System.currentTimeMillis();
        AggregatorModel model = (AggregatorModel) new Aggregator(params).trainModel().get();
        double elapsedSeconds = (System.currentTimeMillis() - startedAt) / 1000.0;
        System.out.println("AggregatorModel finished in: " + elapsedSeconds + " seconds");
        model.checkConsistency();

        AggregatorModel.AggregatorOutput modelOutput = (AggregatorModel.AggregatorOutput) model._output;
        Frame exemplars = (Frame) modelOutput._output_frame.get();
        int rowsToPrint = (int) exemplars.numRows();
        System.out.println(exemplars.toTwoDimTable(0L, rowsToPrint));
        Log.info("Number of exemplars: " + model._exemplars.length);

        exemplars.remove();
        input.remove();
        model.remove();
    }

    /**
     * Parses a small inline pipe-separated dataset (27 rows, 4 columns), trains an
     * {@link Aggregator} with {@code Enum} encoding and a target of 5 exemplars, and
     * expects exactly 17 exemplars.
     */
    @Test
    public void testAggregatorEigenLowCardinalityEnum() {
        String[] rows = {"1|2|A|A", "1|2|A|A", "1|2|A|A", "1|2|A|A", "1|2|A|A", "2|2|A|B", "2|2|A|A", "1|4|A|A", "1|2|B|A", "1|2|B|A", "1|2|A|A", "1|2|A|A", "4|5|C|A", "4|5|D|A", "2|5|D|A", "3|5|E|A", "4|5|F|A", "4|5|G|A", "4|5|H|A", "4|5|I|A", "4|5|J|A", "4|5|K|A", "4|5|L|A", "4|5|M|A", "4|5|N|A", "4|5|O|A", "4|5|P|A"};
        // Join the rows into one newline-terminated text blob for the parser.
        StringBuilder text = new StringBuilder();
        for (int r = 0; r < rows.length; r++) {
            text.append(rows[r]);
            text.append("\n");
        }
        Key rawKey = ParserTest.makeByteVec(new String[]{text.toString()});
        Key parsedKey = Key.make("r1");
        Frame input = ParseDataset.parse(parsedKey, new Key[]{rawKey});

        AggregatorModel.AggregatorParameters params = new AggregatorModel.AggregatorParameters();
        params._categorical_encoding = Model.Parameters.CategoricalEncodingScheme.Enum;
        params._target_num_exemplars = 5;
        params._train = input._key;

        long startedAt = System.currentTimeMillis();
        AggregatorModel model = (AggregatorModel) new Aggregator(params).trainModel().get();
        double elapsedSeconds = (System.currentTimeMillis() - startedAt) / 1000.0;
        System.out.println("AggregatorModel finished in: " + elapsedSeconds + " seconds");
        model.checkConsistency();

        AggregatorModel.AggregatorOutput modelOutput = (AggregatorModel.AggregatorOutput) model._output;
        Frame exemplars = (Frame) modelOutput._output_frame.get();
        int rowsToPrint = (int) exemplars.numRows();
        System.out.println(exemplars.toTwoDimTable(0L, rowsToPrint));
        Log.info("Number of exemplars: " + model._exemplars.length);
        Assert.assertTrue(model._exemplars.length == 17);

        exemplars.remove();
        input.remove();
        model.remove();
    }

    /**
     * Parses the same small inline pipe-separated dataset (27 rows, 4 columns), trains an
     * {@link Aggregator} with {@code EnumLimited} encoding and a target of 5 exemplars,
     * and expects exactly 7 exemplars.
     */
    @Test
    public void testAggregatorEigenLowCardinalityEnumLimited() {
        String[] rows = {"1|2|A|A", "1|2|A|A", "1|2|A|A", "1|2|A|A", "1|2|A|A", "2|2|A|B", "2|2|A|A", "1|4|A|A", "1|2|B|A", "1|2|B|A", "1|2|A|A", "1|2|A|A", "4|5|C|A", "4|5|D|A", "2|5|D|A", "3|5|E|A", "4|5|F|A", "4|5|G|A", "4|5|H|A", "4|5|I|A", "4|5|J|A", "4|5|K|A", "4|5|L|A", "4|5|M|A", "4|5|N|A", "4|5|O|A", "4|5|P|A"};
        // Join the rows into one newline-terminated text blob for the parser.
        StringBuilder text = new StringBuilder();
        for (int r = 0; r < rows.length; r++) {
            text.append(rows[r]);
            text.append("\n");
        }
        Key rawKey = ParserTest.makeByteVec(new String[]{text.toString()});
        Key parsedKey = Key.make("r1");
        Frame input = ParseDataset.parse(parsedKey, new Key[]{rawKey});

        AggregatorModel.AggregatorParameters params = new AggregatorModel.AggregatorParameters();
        params._categorical_encoding = Model.Parameters.CategoricalEncodingScheme.EnumLimited;
        params._target_num_exemplars = 5;
        params._train = input._key;

        long startedAt = System.currentTimeMillis();
        AggregatorModel model = (AggregatorModel) new Aggregator(params).trainModel().get();
        double elapsedSeconds = (System.currentTimeMillis() - startedAt) / 1000.0;
        System.out.println("AggregatorModel finished in: " + elapsedSeconds + " seconds");
        model.checkConsistency();

        AggregatorModel.AggregatorOutput modelOutput = (AggregatorModel.AggregatorOutput) model._output;
        Frame exemplars = (Frame) modelOutput._output_frame.get();
        int rowsToPrint = (int) exemplars.numRows();
        System.out.println(exemplars.toTwoDimTable(0L, rowsToPrint));
        Log.info("Number of exemplars: " + model._exemplars.length);
        Assert.assertTrue(model._exemplars.length == 7);

        exemplars.remove();
        input.remove();
        model.remove();
    }

    /**
     * Trains an {@link Aggregator} with {@code Binary} categorical encoding and
     * {@code NORMALIZE} transform on a generated 1000x10 frame with 10% missing values,
     * and expects exactly 1000 exemplars.
     */
    @Test
    public void testAggregatorBinary() {
        CreateFrame generator = new CreateFrame();
        generator.seed = 1234L;
        generator.rows = 1000L;
        generator.cols = 10;
        generator.factors = 5;
        generator.categorical_fraction = 0.6;
        generator.integer_fraction = 0.0;
        generator.binary_fraction = 0.0;
        generator.missing_fraction = 0.1;
        generator.real_range = 100L;
        generator.integer_range = 100L;
        Frame input = (Frame) generator.execImpl().get();

        AggregatorModel.AggregatorParameters params = new AggregatorModel.AggregatorParameters();
        params._categorical_encoding = Model.Parameters.CategoricalEncodingScheme.Binary;
        params._transform = DataInfo.TransformType.NORMALIZE;
        params._train = input._key;

        long startedAt = System.currentTimeMillis();
        AggregatorModel model = (AggregatorModel) new Aggregator(params).trainModel().get();
        double elapsedSeconds = (System.currentTimeMillis() - startedAt) / 1000.0;
        System.out.println("AggregatorModel finished in: " + elapsedSeconds + " seconds");
        model.checkConsistency();

        AggregatorModel.AggregatorOutput modelOutput = (AggregatorModel.AggregatorOutput) model._output;
        Frame exemplars = (Frame) modelOutput._output_frame.get();
        System.out.println(exemplars.toTwoDimTable(0L, 10));
        Log.info("Number of exemplars: " + model._exemplars.length);
        Assert.assertTrue(model._exemplars.length == 1000);

        exemplars.remove();
        input.remove();
        model.remove();
    }

    /**
     * Trains an {@link Aggregator} with {@code OneHotExplicit} categorical encoding and
     * {@code NORMALIZE} transform on a generated 1000x10 frame, targeting 278 exemplars,
     * and verifies the exemplar count with {@link #checkNumExemplars}.
     *
     * <p>The body runs inside a {@link Scope}; the scope is exited in a {@code finally}
     * block so that a failing training run or assertion does not leave it entered.
     */
    @Test
    public void testAggregatorOneHot() {
        Scope.enter();
        try {
            CreateFrame cf = new CreateFrame();
            cf.rows = 1000L;
            cf.cols = 10;
            cf.categorical_fraction = 0.6;
            cf.integer_fraction = 0.0;
            cf.binary_fraction = 0.0;
            cf.real_range = 100L;
            cf.integer_range = 100L;
            cf.missing_fraction = 0.1;
            cf.factors = 5;
            cf.seed = 1234L;
            Frame frame = (Frame) cf.execImpl().get();

            AggregatorModel.AggregatorParameters parms = new AggregatorModel.AggregatorParameters();
            parms._train = frame._key;
            parms._target_num_exemplars = 278;
            parms._transform = DataInfo.TransformType.NORMALIZE;
            parms._categorical_encoding = Model.Parameters.CategoricalEncodingScheme.OneHotExplicit;

            long start = System.currentTimeMillis();
            AggregatorModel agg = (AggregatorModel) new Aggregator(parms).trainModel().get();
            System.out.println("AggregatorModel finished in: " + (double) (System.currentTimeMillis() - start) / 1000.0 + " seconds");
            agg.checkConsistency();

            Frame output = (Frame) ((AggregatorModel.AggregatorOutput) agg._output)._output_frame.get();
            System.out.println(output.toTwoDimTable(0L, 10));
            checkNumExemplars(agg);

            output.remove();
            frame.remove();
            agg.remove();
        } finally {
            // Always balance Scope.enter(), even when the test body throws.
            Scope.exit();
        }
    }

    /**
     * Currently ignored. Trains an {@link Aggregator} on the airlines test file with a
     * target of 500 exemplars and 5% relative tolerance, then verifies the exemplar count.
     */
    @Ignore
    @Test
    public void testAirlines() {
        Frame input = parse_test_file("smalldata/airlines/allyears2k_headers.zip");

        AggregatorModel.AggregatorParameters params = new AggregatorModel.AggregatorParameters();
        params._rel_tol_num_exemplars = 0.05;
        params._target_num_exemplars = 500;
        params._train = input._key;

        long startedAt = System.currentTimeMillis();
        AggregatorModel model = (AggregatorModel) new Aggregator(params).trainModel().get();
        double elapsedSeconds = (System.currentTimeMillis() - startedAt) / 1000.0;
        System.out.println("AggregatorModel finished in: " + elapsedSeconds + " seconds");
        model.checkConsistency();
        input.delete();

        AggregatorModel.AggregatorOutput modelOutput = (AggregatorModel.AggregatorOutput) model._output;
        Frame exemplars = (Frame) modelOutput._output_frame.get();
        exemplars.remove();
        checkNumExemplars(model);
        model.remove();
    }

    /**
     * Trains an {@link Aggregator} on the covtype 20k test file with a target of 500
     * exemplars and 5% relative tolerance, logs the output frame and verifies the
     * exemplar count.
     */
    @Test
    public void testCovtype() {
        Frame input = parse_test_file("smalldata/covtype/covtype.20k.data");

        AggregatorModel.AggregatorParameters params = new AggregatorModel.AggregatorParameters();
        params._rel_tol_num_exemplars = 0.05;
        params._target_num_exemplars = 500;
        params._train = input._key;

        long startedAt = System.currentTimeMillis();
        AggregatorModel model = (AggregatorModel) new Aggregator(params).trainModel().get();
        double elapsedSeconds = (System.currentTimeMillis() - startedAt) / 1000.0;
        System.out.println("AggregatorModel finished in: " + elapsedSeconds + " seconds");
        model.checkConsistency();
        input.delete();

        AggregatorModel.AggregatorOutput modelOutput = (AggregatorModel.AggregatorOutput) model._output;
        Frame exemplars = (Frame) modelOutput._output_frame.get();
        Log.info("Exemplars: " + exemplars.toString());
        exemplars.remove();
        checkNumExemplars(model);
        model.remove();
    }

    /**
     * Logs the model's exemplar count and asserts that it lies within
     * {@code target * (1 - tol)} and {@code target * (1 + tol)} inclusive, where
     * {@code target} is {@code _target_num_exemplars} and {@code tol} is
     * {@code _rel_tol_num_exemplars} from the model's parameters.
     *
     * @param m trained model to check
     */
    public void checkNumExemplars(AggregatorModel m) {
        Log.info("Number of exemplars: " + m._exemplars.length);

        AggregatorModel.AggregatorParameters params = (AggregatorModel.AggregatorParameters) m._parms;
        double found = (double) m._exemplars.length;

        double lowerBound = (1.0 - params._rel_tol_num_exemplars) * (double) params._target_num_exemplars;
        Assert.assertTrue(found >= lowerBound);

        double upperBound = (1.0 + params._rel_tol_num_exemplars) * (double) params._target_num_exemplars;
        Assert.assertTrue(found <= upperBound);
    }

    /**
     * Trains a baseline {@link Aggregator} on the covtype 20k test file (target 137
     * exemplars, 5% tolerance), then for each chunk count in {1, 2, 5, 10, 50, 100}
     * rebalances the frame, trains again with the same settings, and asserts that the
     * exemplar count equals the baseline and is within tolerance.
     */
    @Test
    public void testChunks() {
        Frame frame = parse_test_file("smalldata/covtype/covtype.20k.data");

        AggregatorModel.AggregatorParameters parms = new AggregatorModel.AggregatorParameters();
        parms._train = frame._key;
        parms._target_num_exemplars = 137;
        parms._rel_tol_num_exemplars = 0.05;

        long start = System.currentTimeMillis();
        AggregatorModel agg = (AggregatorModel) new Aggregator(parms).trainModel().get();
        System.out.println("AggregatorModel finished in: " + (double) (System.currentTimeMillis() - start) / 1000.0 + " seconds");
        agg.checkConsistency();
        Frame output = (Frame) ((AggregatorModel.AggregatorOutput) agg._output)._output_frame.get();
        checkNumExemplars(agg);
        output.remove();
        agg.remove();

        for (int i : new int[]{1, 2, 5, 10, 50, 100}) {
            Key key = Key.make();
            RebalanceDataSet rb = new RebalanceDataSet(frame, key, i);
            H2O.submitTask((H2O.H2OCountedCompleter) rb);
            rb.join();
            Frame rebalanced = (Frame) DKV.get(key).get();

            parms = new AggregatorModel.AggregatorParameters();
            // NOTE(review): this trains on the original frame, not on `rebalanced`, so the
            // chunk layout under test is never actually fed to the Aggregator. Left as-is
            // because switching to rebalanced._key may change the outcome — confirm intent.
            parms._train = frame._key;
            parms._target_num_exemplars = 137;
            parms._rel_tol_num_exemplars = 0.05;

            start = System.currentTimeMillis();
            AggregatorModel agg2 = (AggregatorModel) new Aggregator(parms).trainModel().get();
            System.out.println("AggregatorModel finished in: " + (double) (System.currentTimeMillis() - start) / 1000.0 + " seconds");
            agg2.checkConsistency();
            Log.info("Number of exemplars for " + i + " chunks: " + agg2._exemplars.length);
            rebalanced.delete();

            Assert.assertEquals("Exemplar count differs from baseline for " + i + " chunks",
                    agg._exemplars.length, agg2._exemplars.length);
            output = (Frame) ((AggregatorModel.AggregatorOutput) agg2._output)._output_frame.get();
            output.remove();
            // Check the model trained in this iteration (the original re-checked the baseline).
            checkNumExemplars(agg2);
            agg2.remove();
        }
        frame.delete();
    }

    /**
     * Currently ignored. Trains an {@link Aggregator} on the covtype 20k test file with a
     * target of 117 exemplars and, for every exemplar, checks that the frame returned by
     * {@code scoreExemplarMembers} has as many rows as the model's {@code _counts} entry.
     */
    @Ignore
    @Test
    public void testCovtypeMemberIndices() {
        Frame frame = parse_test_file("smalldata/covtype/covtype.20k.data");

        AggregatorModel.AggregatorParameters parms = new AggregatorModel.AggregatorParameters();
        parms._train = frame._key;
        parms._target_num_exemplars = 117;

        long start = System.currentTimeMillis();
        AggregatorModel agg = (AggregatorModel) new Aggregator(parms).trainModel().get();
        System.out.println("AggregatorModel finished in: " + (double) (System.currentTimeMillis() - start) / 1000.0 + " seconds");
        agg.checkConsistency();
        Log.info("Number of exemplars: " + agg._exemplars.length);

        Key memberKey = Key.make();
        for (int i = 0; i < agg._exemplars.length; ++i) {
            Frame members = agg.scoreExemplarMembers(memberKey, i);
            long memberRows = members.numRows();
            // Free the member frame before asserting so a failure does not leave it behind.
            members.delete();
            // JUnit assertion instead of the `assert` statement, which is skipped without -ea.
            Assert.assertTrue("Member row count mismatch for exemplar " + i, memberRows == agg._counts[i]);
        }

        Frame output = (Frame) ((AggregatorModel.AggregatorOutput) agg._output)._output_frame.get();
        output.remove();
        checkNumExemplars(agg);
        frame.delete();
        agg.remove();
    }

    /**
     * Converts three columns of the weather test file to categorical, trains an
     * {@link Aggregator} with a target of 17 exemplars, and asserts that the output frame
     * has at most 17 rows and that some categorical column's domain length in the output
     * differs from the input.
     */
    @Test
    public void testDomains() {
        Frame weather = parse_test_file("smalldata/junit/weather.csv");

        // Replace each listed column by its categorical counterpart and drop the old Vec.
        String[] toCategorical = {"MaxWindSpeed", "RelHumid9am", "Cloud9am"};
        for (String column : toCategorical) {
            Vec numeric = weather.vec(column);
            Vec categorical = numeric.toCategoricalVec();
            weather.remove(column);
            weather.add(column, categorical);
            numeric.remove();
        }
        DKV.put((Keyed) weather);

        AggregatorModel.AggregatorParameters params = new AggregatorModel.AggregatorParameters();
        params._target_num_exemplars = 17;
        params._train = weather._key;
        AggregatorModel model = (AggregatorModel) new Aggregator(params).trainModel().get();

        AggregatorModel.AggregatorOutput modelOutput = (AggregatorModel.AggregatorOutput) model._output;
        Frame exemplars = (Frame) modelOutput._output_frame.get();
        Assert.assertTrue(exemplars.numRows() <= 17L);

        // Stop at the first categorical column whose domain length changed.
        boolean same = true;
        for (int col = 0; col < weather.numCols(); col++) {
            if (weather.vec(col).isCategorical()) {
                same = weather.domains()[col].length == exemplars.domains()[col].length;
                if (!same) {
                    break;
                }
            }
        }

        weather.remove();
        exemplars.remove();
        model.remove();
        Assert.assertFalse(same);
    }

    /**
     * Currently ignored. Trains an {@link Aggregator} with default parameters on the
     * MNIST training file and verifies the exemplar count.
     */
    @Ignore
    @Test
    public void testMNIST() {
        Frame input = parse_test_file("bigdata/laptop/mnist/train.csv.gz");

        AggregatorModel.AggregatorParameters params = new AggregatorModel.AggregatorParameters();
        params._train = input._key;

        long startedAt = System.currentTimeMillis();
        AggregatorModel model = (AggregatorModel) new Aggregator(params).trainModel().get();
        double elapsedSeconds = (System.currentTimeMillis() - startedAt) / 1000.0;
        System.out.println("AggregatorModel finished in: " + elapsedSeconds + " seconds");
        model.checkConsistency();
        input.delete();

        AggregatorModel.AggregatorOutput modelOutput = (AggregatorModel.AggregatorOutput) model._output;
        Frame exemplars = (Frame) modelOutput._output_frame.get();
        exemplars.remove();
        Log.info("Number of exemplars: " + model._exemplars.length);
        checkNumExemplars(model);
        model.remove();
    }

    /**
     * Trains an {@link Aggregator} on the covtype 20k test file with
     * {@code _save_mapping_frame} enabled (target 500 exemplars, 5% tolerance), fetches
     * both the output frame and the mapping frame, and verifies the exemplar count.
     *
     * <p>Cleanup in the {@code finally} block is null-guarded so that a failure early in
     * the test is reported as-is instead of being masked by a NullPointerException.
     */
    @Test
    public void testCovtypeMapping() {
        Frame frame = null;
        Frame output = null;
        Frame mapping = null;
        AggregatorModel agg = null;
        try {
            frame = parse_test_file("smalldata/covtype/covtype.20k.data");

            AggregatorModel.AggregatorParameters parms = new AggregatorModel.AggregatorParameters();
            parms._train = frame._key;
            parms._target_num_exemplars = 500;
            parms._rel_tol_num_exemplars = 0.05;
            parms._save_mapping_frame = true;

            long start = System.currentTimeMillis();
            agg = (AggregatorModel) new Aggregator(parms).trainModel().get();
            System.out.println("AggregatorModel finished in: " + (double) (System.currentTimeMillis() - start) / 1000.0 + " seconds");
            agg.checkConsistency();

            output = (Frame) ((AggregatorModel.AggregatorOutput) agg._output)._output_frame.get();
            Log.info("Exemplars: " + output.toString());
            mapping = (Frame) ((AggregatorModel.AggregatorOutput) agg._output)._mapping_frame.get();
            checkNumExemplars(agg);
        } finally {
            // Any of these may still be null if the test failed before assigning them.
            if (frame != null) {
                frame.delete();
            }
            if (output != null) {
                output.remove();
            }
            if (mapping != null) {
                mapping.remove();
            }
            if (agg != null) {
                agg.remove();
            }
        }
    }

    /**
     * Trains an {@link Aggregator} on the {@code uuid_airline.csv} test file with
     * {@code _save_mapping_frame} enabled (target 10 exemplars, 50% tolerance), fetches
     * both the output frame and the mapping frame, and verifies the exemplar count.
     *
     * <p>Cleanup in the {@code finally} block is null-guarded so that a failure early in
     * the test is reported as-is instead of being masked by a NullPointerException.
     */
    @Test
    public void testAirlinesUuidNA() {
        Frame frame = null;
        Frame output = null;
        Frame mapping = null;
        AggregatorModel agg = null;
        try {
            frame = parse_test_file("smalldata/airlines/uuid_airline.csv");

            AggregatorModel.AggregatorParameters parms = new AggregatorModel.AggregatorParameters();
            parms._train = frame._key;
            parms._target_num_exemplars = 10;
            parms._rel_tol_num_exemplars = 0.5;
            parms._save_mapping_frame = true;

            long start = System.currentTimeMillis();
            agg = (AggregatorModel) new Aggregator(parms).trainModel().get();
            System.out.println("AggregatorModel finished in: " + (double) (System.currentTimeMillis() - start) / 1000.0 + " seconds");
            agg.checkConsistency();

            output = (Frame) ((AggregatorModel.AggregatorOutput) agg._output)._output_frame.get();
            Log.info("Exemplars: " + output.toString());
            mapping = (Frame) ((AggregatorModel.AggregatorOutput) agg._output)._mapping_frame.get();
            checkNumExemplars(agg);
        } finally {
            // Any of these may still be null if the test failed before assigning them.
            if (frame != null) {
                frame.delete();
            }
            if (output != null) {
                output.remove();
            }
            if (mapping != null) {
                mapping.remove();
            }
            if (agg != null) {
                agg.remove();
            }
        }
    }
}

