/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.targetencoding;

import ai.h2o.targetencoding.BroadcastJoinForTargetEncoder;
import ai.h2o.targetencoding.TargetEncoder;
import ai.h2o.targetencoding.TargetEncoderFrameHelper;
import com.pholser.junit.quickcheck.Mode;
import com.pholser.junit.quickcheck.Property;
import com.pholser.junit.quickcheck.generator.InRange;
import com.pholser.junit.quickcheck.runner.JUnitQuickcheck;
import java.util.Map;
import java.util.Random;
import org.junit.After;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.CategoricalWrappedVec;
import water.fvec.Frame;
import water.fvec.TestFrameBuilder;
import water.fvec.Vec;
import water.rapids.Merge;
import water.util.IcedHashMapGeneric;

/**
 * Tests for {@link BroadcastJoinForTargetEncoder}.
 *
 * <p>Covers two areas:
 * <ul>
 *   <li>joining a frame with an encoding-map frame on a categorical column, with or without a fold
 *       column, while keeping the left frame's row order;</li>
 *   <li>fold-value handling in {@code FrameWithEncodingDataToArray}.</li>
 * </ul>
 * Also contains a test showing that {@link Merge#merge} does not keep the expected row order.
 */
@RunWith(JUnitQuickcheck.class)
public class BroadcastJoinTest extends TestUtil {

    /** Frame created by some tests; deleted (and cleared) in {@link #afterEach()}. */
    private Frame fr;

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Joins a 3-row left frame with a 3-row right frame on ({@code ColA}, {@code fold}).
     *
     * <p>Checks that:
     * <ul>
     *   <li>the left frame's row order ("a", "c", "b") is kept;</li>
     *   <li>rows with a matching (ColA, fold) pair receive the right frame's numerator and
     *       denominator;</li>
     *   <li>row "b", whose fold differs between the two frames, ends up NA.</li>
     * </ul>
     * Runs once with zero-based fold values and once with one-based fold values.
     */
    @Property(trials = 2, mode = Mode.EXHAUSTIVE)
    public void joinPerformsWithoutLoosingOriginalOrderTest(boolean isZeroBasedFoldValues) {
        Scope.enter();
        try {
            long[] leftFolds = isZeroBasedFoldValues ? new long[]{1L, 0L, 1L} : new long[]{2L, 1L, 2L};
            long[] rightFolds = isZeroBasedFoldValues ? new long[]{1L, 0L, 0L} : new long[]{2L, 1L, 1L};

            Frame leftFr = new TestFrameBuilder()
                    .withName("testFrame")
                    .withColNames(new String[]{"ColA", "fold"})
                    .withVecTypes(new byte[]{Vec.T_CAT, Vec.T_NUM})
                    .withDataForCol(0, ar(new String[]{"a", "c", "b"}))
                    .withDataForCol(1, leftFolds)
                    .withChunkLayout(new long[]{1L, 1L, 1L})
                    .build();

            Frame rightFr = new TestFrameBuilder()
                    .withName("testFrame2")
                    .withColNames(new String[]{"ColA", "fold",
                            TargetEncoder.NUMERATOR_COL_NAME, TargetEncoder.DENOMINATOR_COL_NAME})
                    .withVecTypes(new byte[]{Vec.T_CAT, Vec.T_NUM, Vec.T_NUM, Vec.T_NUM})
                    .withDataForCol(0, ar(new String[]{"a", "b", "c"}))
                    .withDataForCol(1, rightFolds)
                    .withDataForCol(2, ar(new long[]{22L, 33L, 42L}))
                    .withDataForCol(3, ar(new long[]{44L, 66L, 84L}))
                    .withChunkLayout(new long[]{1L, 1L, 1L})
                    .build();

            // Zero-filled numerator/denominator columns are appended to the left frame before joining.
            Vec emptyNumerator = leftFr.anyVec().makeCon(0.0);
            leftFr.add(TargetEncoder.NUMERATOR_COL_NAME, emptyNumerator);
            Vec emptyDenominator = leftFr.anyVec().makeCon(0.0);
            leftFr.add(TargetEncoder.DENOMINATOR_COL_NAME, emptyDenominator);
            Scope.track(emptyNumerator);
            Scope.track(emptyDenominator);

            Frame joined = BroadcastJoinForTargetEncoder.join(leftFr, new int[]{0}, 1, rightFr, new int[]{0}, 1, 2);

            assertStringVecEquals(cvec(new String[]{"a", "c", "b"}), joined.vec("ColA"));
            Assert.assertEquals(22.0, joined.vec(TargetEncoder.NUMERATOR_COL_NAME).at(0L), 1.0E-5);
            Assert.assertEquals(44.0, joined.vec(TargetEncoder.DENOMINATOR_COL_NAME).at(0L), 1.0E-5);
            Assert.assertEquals(42.0, joined.vec(TargetEncoder.NUMERATOR_COL_NAME).at(1L), 1.0E-5);
            Assert.assertEquals(84.0, joined.vec(TargetEncoder.DENOMINATOR_COL_NAME).at(1L), 1.0E-5);
            // Row "b": no right-frame row has the same (ColA, fold) pair.
            Assert.assertTrue(joined.vec(TargetEncoder.NUMERATOR_COL_NAME).isNA(2L));
            Assert.assertTrue(joined.vec(TargetEncoder.DENOMINATOR_COL_NAME).isNA(2L));
        } finally {
            Scope.exit();
        }
    }

    /**
     * Property test comparing the join output with the encoding map.
     *
     * <p>Builds a random left frame, adds a k-fold column and computes an encoding map for
     * {@code ColA}. It then joins the left frame with that map. For one random row, the test checks
     * two things:
     * <ul>
     *   <li>{@code ColA} keeps its order after the join;</li>
     *   <li>the numerator/denominator for that row's (ColA, fold) pair equal the encoding-map row
     *       for the same pair.</li>
     * </ul>
     */
    @Property(trials = 5)
    public void joinWorksWithoutLoosingOriginalOrderTest(@InRange(minInt = 2, maxInt = 10000) int sizeOfLeftFrame,
                                                         @InRange(minInt = 1, maxInt = 1000) int numberOfFolds) {
        Scope.enter();
        long seed = 1234L;
        Map<String, Frame> encodingMap = null;
        try {
            String responseColumnName = "response";
            String catColumnName = "ColA";
            String foldColumnName = "fold";
            String[] colAValues = randomArrOfStrings(sizeOfLeftFrame);

            Frame leftFr = new TestFrameBuilder()
                    .withName("testFrame")
                    .withColNames(new String[]{catColumnName, responseColumnName})
                    .withVecTypes(new byte[]{Vec.T_CAT, Vec.T_CAT})
                    .withDataForCol(0, colAValues)
                    .withRandomBinaryDataForCol(1, sizeOfLeftFrame, seed)
                    .withChunkLayout(new long[]{sizeOfLeftFrame / 2, sizeOfLeftFrame - sizeOfLeftFrame / 2})
                    .build();
            // Skip generated inputs whose response column does not have exactly two levels.
            Assume.assumeTrue(leftFr.vec(responseColumnName).cardinality() == 2);

            TargetEncoderFrameHelper.addKFoldColumn(leftFr, foldColumnName, numberOfFolds, seed);
            // Skip generated inputs where the fold column does not contain exactly numberOfFolds distinct values.
            Assume.assumeTrue(
                    ((Vec) leftFr.vec(foldColumnName).clone()).toCategoricalVec().cardinality() == numberOfFolds);

            TargetEncoder tec = new TargetEncoder(new String[]{catColumnName});
            encodingMap = tec.prepareEncodingMap(leftFr, responseColumnName, foldColumnName);
            Frame encodingMapForColA = encodingMap.get(catColumnName);

            Vec emptyNumerator = leftFr.anyVec().makeCon(0.0);
            leftFr.add(TargetEncoder.NUMERATOR_COL_NAME, emptyNumerator);
            Vec emptyDenominator = leftFr.anyVec().makeCon(0.0);
            leftFr.add(TargetEncoder.DENOMINATOR_COL_NAME, emptyDenominator);
            Scope.track(emptyNumerator);
            Scope.track(emptyDenominator);

            Frame joined = BroadcastJoinForTargetEncoder.join(
                    leftFr, new int[]{0}, leftFr.find(foldColumnName),
                    encodingMapForColA, new int[]{0}, encodingMapForColA.find(foldColumnName),
                    numberOfFolds);
            Scope.track(joined);

            assertStringVecEquals(leftFr.vec(catColumnName), joined.vec(catColumnName));

            int randomIdx = new Random(seed).nextInt(sizeOfLeftFrame);
            double colAValue = leftFr.vec(catColumnName).at(randomIdx);
            double foldValue = leftFr.vec(foldColumnName).at(randomIdx);

            // NOTE(review): the "actual" side is filtered from leftFr, not from `joined`. This relies
            // on the join writing encodings into leftFr's numerator/denominator columns. Confirm this.
            Frame filteredByColA = TargetEncoderFrameHelper.filterByValue(leftFr, 0, colAValue);
            Frame filteredByFoldAndColAColumns = TargetEncoderFrameHelper.filterByValue(
                    filteredByColA, filteredByColA.find(foldColumnName), foldValue);
            Frame filteredByColAFromEM = TargetEncoderFrameHelper.filterByValue(encodingMapForColA, 0, colAValue);
            Frame filteredByFoldAndColAColumnsFromEM = TargetEncoderFrameHelper.filterByValue(
                    filteredByColAFromEM, filteredByColAFromEM.find(foldColumnName), foldValue);
            Scope.track(filteredByColA, filteredByFoldAndColAColumns,
                    filteredByColAFromEM, filteredByFoldAndColAColumnsFromEM);

            Assert.assertEquals(
                    filteredByFoldAndColAColumns.vec(TargetEncoder.NUMERATOR_COL_NAME).at(0L),
                    filteredByFoldAndColAColumnsFromEM.vec(TargetEncoder.NUMERATOR_COL_NAME).at(0L),
                    1.0E-5);
            Assert.assertEquals(
                    filteredByFoldAndColAColumns.vec(TargetEncoder.DENOMINATOR_COL_NAME).at(0L),
                    filteredByFoldAndColAColumnsFromEM.vec(TargetEncoder.DENOMINATOR_COL_NAME).at(0L),
                    1.0E-5);
        } finally {
            if (encodingMap != null) {
                TargetEncoderFrameHelper.encodingMapCleanUp(encodingMap);
            }
            Scope.exit();
        }
    }

    /**
     * Shows that {@link Merge#merge} does not keep the expected row order.
     *
     * <p>The assertion that the merged {@code ColB} column is ("a", "b", "c", "e", "a") is expected
     * to fail with an {@link AssertionError}.
     */
    @Test(expected = AssertionError.class)
    public void mergeWillUseRightFramesOrderAndGroupByValues() {
        Scope.enter();
        Frame res = null;
        try {
            Frame leftFr = new TestFrameBuilder()
                    .withName("leftFrame")
                    .withColNames(new String[]{"ColA", "ColB"})
                    .withVecTypes(new byte[]{Vec.T_CAT, Vec.T_NUM})
                    .withDataForCol(0, ar(new String[]{"a", "b", "c", "e", "a"}))
                    .withDataForCol(1, ard(new double[]{-1.0, 2.0, 3.0, 4.0, 7.0}))
                    .build();
            Frame holdoutEncodingMap = new TestFrameBuilder()
                    .withName("holdoutEncodingMap")
                    .withColNames(new String[]{"ColB", "ColC"})
                    .withVecTypes(new byte[]{Vec.T_CAT, Vec.T_NUM})
                    .withDataForCol(0, ar(new String[]{"c", "a", "e", "b"}))
                    .withDataForCol(1, ard(new double[]{2.0, 3.0, 4.0, 5.0}))
                    .build();

            int[][] levelMaps = {
                    CategoricalWrappedVec.computeMap(holdoutEncodingMap.vec(0).domain(), leftFr.vec(0).domain())
            };
            res = Merge.merge(holdoutEncodingMap, leftFr, new int[]{0}, new int[]{0}, false, levelMaps);
            assertStringVecEquals(cvec(new String[]{"a", "b", "c", "e", "a"}), res.vec("ColB"));
        } finally {
            // res stays null if Merge.merge throws. Guard it so an NPE cannot hide the real failure.
            if (res != null) {
                res.delete();
            }
            Scope.exit();
        }
    }

    /**
     * Expects an {@link AssertionError} when {@code FrameWithEncodingDataToArray} gets an
     * out-of-range fold value.
     */
    @Test(expected = AssertionError.class)
    public void foldValuesThatAreBiggerThanIntegerWillCauseExceptionTest() {
        // NOTE(review): despite the test name, this is Integer.MIN_VALUE. That is presumably what an
        // original `Integer.MAX_VALUE + 1` overflowed to in int arithmetic; it is not a value above
        // Integer.MAX_VALUE. Kept unchanged to preserve current behaviour; confirm the intended value.
        long outOfRangeFoldValue = Integer.MIN_VALUE;
        fr = new TestFrameBuilder()
                .withName("testFrame")
                .withColNames(new String[]{"ColA", "fold",
                        TargetEncoder.NUMERATOR_COL_NAME, TargetEncoder.DENOMINATOR_COL_NAME})
                .withVecTypes(new byte[]{Vec.T_CAT, Vec.T_NUM, Vec.T_NUM, Vec.T_NUM})
                .withDataForCol(0, ar(new String[]{"a", "b", "c"}))
                .withDataForCol(1, ar(new long[]{outOfRangeFoldValue, 33L, 42L}))
                .withDataForCol(2, ar(new long[]{44L, 66L, 84L}))
                .withDataForCol(3, ar(new long[]{88L, 132L, 168L}))
                .withChunkLayout(new long[]{2L, 1L})
                .build();
        int cardinality = fr.vec("ColA").cardinality();
        // Arguments 0..3 are the column indices of ColA, fold, numerator and denominator in `fr`.
        // The last argument is presumably the maximum fold value (TODO confirm against the task).
        new BroadcastJoinForTargetEncoder.FrameWithEncodingDataToArray(
                0, 1, 2, 3, cardinality, (int) outOfRangeFoldValue)
                .doAll(fr)
                .getEncodingDataArray();
    }

    /**
     * Property test: {@code FrameWithEncodingDataToArray} completes without throwing when the fold
     * values (0, 1, 2) and the numerator/denominator values are small positive numbers.
     */
    @Property(trials = 100)
    public void foldValuesThatAreInRangeWouldNotCauseExceptionTest(@InRange(minInt = 1, maxInt = 1000) int randomInt) {
        fr = new TestFrameBuilder()
                .withName("testFrame")
                .withColNames(new String[]{"ColA", "fold",
                        TargetEncoder.NUMERATOR_COL_NAME, TargetEncoder.DENOMINATOR_COL_NAME})
                .withVecTypes(new byte[]{Vec.T_CAT, Vec.T_NUM, Vec.T_NUM, Vec.T_NUM})
                .withDataForCol(0, ar(new String[]{"a", "b", "c"}))
                .withDataForCol(1, ar(new long[]{0L, 1L, 2L}))
                .withDataForCol(2, ar(new long[]{randomInt, 66L, 84L}))
                .withDataForCol(3, ar(new long[]{88L, 132L, randomInt}))
                .withChunkLayout(new long[]{2L, 1L})
                .build();
        int cardinality = fr.vec("ColA").cardinality();
        new BroadcastJoinForTargetEncoder.FrameWithEncodingDataToArray(
                0, 1, 2, 3, cardinality, Math.max(randomInt, 42))
                .doAll(fr)
                .getEncodingDataArray();
    }

    /**
     * Joins on {@code ColA} only, passing -1 for both fold-column indices (presumably meaning "no fold
     * column"). Checks that the left order ("a", "c", "b") is kept and that each row gets the
     * numerator/denominator of its matching right-frame row.
     */
    @Test
    public void joinWithoutFoldColumnTest() {
        Frame rightFr = null;
        try {
            fr = new TestFrameBuilder()
                    .withName("testFrame")
                    .withColNames(new String[]{"ColA"})
                    .withVecTypes(new byte[]{Vec.T_CAT})
                    .withDataForCol(0, ar(new String[]{"a", "c", "b"}))
                    .build();
            rightFr = new TestFrameBuilder()
                    .withName("testFrame2")
                    .withColNames(new String[]{"ColA",
                            TargetEncoder.NUMERATOR_COL_NAME, TargetEncoder.DENOMINATOR_COL_NAME})
                    .withVecTypes(new byte[]{Vec.T_CAT, Vec.T_NUM, Vec.T_NUM})
                    .withDataForCol(0, ar(new String[]{"a", "b", "c"}))
                    .withDataForCol(1, ar(new long[]{22L, 33L, 42L}))
                    .withDataForCol(2, ar(new long[]{44L, 66L, 84L}))
                    .withChunkLayout(new long[]{1L, 1L, 1L})
                    .build();

            // The zero vecs become columns of `fr`, so they are deleted together with it in afterEach().
            Vec emptyNumerator = Vec.makeZero(fr.numRows());
            fr.add(TargetEncoder.NUMERATOR_COL_NAME, emptyNumerator);
            Vec emptyDenominator = Vec.makeZero(fr.numRows());
            fr.add(TargetEncoder.DENOMINATOR_COL_NAME, emptyDenominator);

            Frame joined = BroadcastJoinForTargetEncoder.join(fr, new int[]{0}, -1, rightFr, new int[]{0}, -1, 0);

            // Scope only covers the expected-value vecs built below. Exit it even if an assertion fails.
            Scope.enter();
            try {
                assertStringVecEquals(cvec(new String[]{"a", "c", "b"}), joined.vec("ColA"));
                assertVecEquals(vec(new int[]{22, 42, 33}), joined.vec(TargetEncoder.NUMERATOR_COL_NAME), 1.0E-5);
                assertVecEquals(vec(new int[]{44, 84, 66}), joined.vec(TargetEncoder.DENOMINATOR_COL_NAME), 1.0E-5);
            } finally {
                Scope.exit();
            }
        } finally {
            if (rightFr != null) {
                rightFr.delete();
            }
        }
    }

    /**
     * Returns {@code size} strings. Each is the decimal form of a random int in
     * [0, max(1, size / 2)). The generator is unseeded, so values differ between runs.
     */
    private String[] randomArrOfStrings(int size) {
        String[] arr = new String[size];
        Random rg = new Random();
        int cardinality = size / 2;
        for (int a = 0; a < size; ++a) {
            arr[a] = Integer.toString(rg.nextInt(Math.max(1, cardinality)));
        }
        return arr;
    }

    /** Deletes the frame stored in {@link #fr}, if any, and clears the field. */
    @After
    public void afterEach() {
        if (fr != null) {
            fr.delete();
            fr = null;
        }
    }
}

