/*
 * Decompiled with CFR 0.152.
 */
package hex.word2vec;

import hex.word2vec.Word2Vec;
import hex.word2vec.Word2VecModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Map;
import java.util.Random;
import org.hamcrest.CoreMatchers;
import org.hamcrest.Matcher;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import water.DKV;
import water.Key;
import water.Keyed;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.TestFrameBuilder;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.util.ArrayUtils;
import water.util.Log;

public class Word2VecTest
extends TestUtil {

    /** Lets individual tests declare the message of an exception they expect to be thrown. */
    @Rule
    public ExpectedException ee = ExpectedException.none();

    /** Calls {@code stall_till_cloudsize(1)} before any test runs (presumably waits for a 1-node cloud). */
    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Trains a small SkipGram / hierarchical-softmax model on a synthetic corpus in which "a" is
     * paired 100 times with "b" and 10 times with "c". Checks that the 2 synonyms of "a" are
     * exactly {"b", "c"}, and that a per-word transform yields non-NA vectors for the known words
     * and NA for an unseen word and for a missing value.
     */
    @Test
    public void testW2V_SG_HSM_small() {
        String[] words = new String[220];
        // 100 "a b" pairs ...
        for (int i = 0; i < 200; i += 2) {
            words[i] = "a";
            words[i + 1] = "b";
        }
        // ... followed by 10 "a c" pairs (still above _min_word_freq = 5)
        for (int i = 200; i < 220; i += 2) {
            words[i] = "a";
            words[i + 1] = "c";
        }
        Scope.enter();
        try {
            Vec v = Scope.track(svec(words));
            Frame fr = Scope.track(new Frame(Key.make(), new String[]{"Words"}, new Vec[]{v}));
            DKV.put(fr);

            Word2VecModel.Word2VecParameters p = new Word2VecModel.Word2VecParameters();
            p._train = fr._key;
            p._min_word_freq = 5;
            p._word_model = Word2Vec.WordModel.SkipGram;
            p._norm_model = Word2Vec.NormModel.HSM;
            p._vec_size = 10;
            p._window_size = 5;
            p._sent_sample_rate = 0.001f;
            p._init_learning_rate = 0.025f;
            p._epochs = 1;
            Word2VecModel w2vm = (Word2VecModel) new Word2Vec(p).trainModel().get();
            Scope.track_generic(w2vm);

            Map<String, Float> hm = w2vm.findSynonyms("a", 2);
            logResults(hm);
            Assert.assertEquals(new HashSet<String>(Arrays.asList("b", "c")), hm.keySet());

            // rows 0..2 are vocabulary words, row 3 is unseen, row 4 is a missing value
            Vec testWordVec = Scope.track(svec(new String[]{"a", "b", "c", "Unseen", null}));
            Frame wv = Scope.track(w2vm.transform(testWordVec, Word2VecModel.AggregateMethod.NONE));
            Assert.assertEquals(10, wv.numCols());
            for (int col = 0; col < 10; ++col) {
                for (int row = 0; row < 3; ++row) {
                    Assert.assertFalse(wv.vec(col).isNA(row));
                }
                for (int row = 3; row < 5; ++row) {
                    Assert.assertTrue(wv.vec(col).isNA(row));
                }
            }
        } finally {
            Scope.exit();
        }
    }

    /**
     * Builds a model from a pre-trained mapping frame (word + two numeric components, with an
     * irregular chunk layout) and checks that {@code transform(word)} returns the imported vector.
     */
    @Test
    public void testW2V_pretrained() {
        String[] words = new String[1000];
        double[] v1 = new double[words.length];
        double[] v2 = new double[words.length];
        for (int i = 0; i < words.length; ++i) {
            words[i] = "word" + i;
            v1[i] = (float) i / (float) words.length;
            v2[i] = 1.0 - v1[i];
        }
        Scope.enter();
        // The frame is built inside the try so the scope is always exited, even if build() fails.
        try {
            Frame pretrained = new TestFrameBuilder()
                    .withName("w2v-pretrained")
                    .withColNames(new String[]{"Word", "V1", "V2"})
                    .withVecTypes(new byte[]{Vec.T_STR, Vec.T_NUM, Vec.T_NUM})
                    .withDataForCol(0, words)
                    .withDataForCol(1, v1)
                    .withDataForCol(2, v2)
                    .withChunkLayout(new long[]{100L, 100L, 20L, 80L, 100L, 100L, 100L, 100L, 100L, 100L, 100L})
                    .build();
            Scope.track(pretrained);

            Word2VecModel w2vm = (Word2VecModel) Word2Vec.fromPretrainedModel(pretrained).get();
            Scope.track_generic(w2vm);
            for (int i = 0; i < words.length; ++i) {
                float[] wordVector = w2vm.transform(words[i]);
                Assert.assertArrayEquals("wordvec " + i,
                        new float[]{(float) v1[i], (float) v2[i]}, wordVector, 1.0E-4f);
            }
        } finally {
            Scope.exit();
        }
    }

    /**
     * Verifies that importing a pre-trained mapping whose component columns are not all numeric
     * (here a Time and a categorical column) is rejected with a message naming those columns.
     */
    @Test
    public void testImportPretrained_invalid() {
        Scope.enter();
        try {
            Frame pretrained = new TestFrameBuilder()
                    .withName("w2v-pretrained")
                    .withColNames(new String[]{"Word", "V1", "V2", "V3"})
                    .withVecTypes(new byte[]{Vec.T_STR, Vec.T_TIME, Vec.T_NUM, Vec.T_CAT})
                    .withDataForCol(0, ar(new String[]{"a"}))
                    .withDataForCol(1, ar(new long[]{System.currentTimeMillis()}))
                    .withDataForCol(2, ard(new double[]{Math.PI}))
                    .withDataForCol(3, ar(new String[]{"C1"}))
                    .build();
            // tracked so the frame is removed even though the call below is expected to throw
            Scope.track(pretrained);
            ee.expectMessage("All components of word2vec mapping are expected to be numeric. Invalid columns: V1 (type Time), V3 (type Enum)");
            Word2Vec.fromPretrainedModel(pretrained);
        } finally {
            Scope.exit();
        }
    }

    /**
     * Round-trips random embeddings through a model built via {@code _pre_trained} and
     * {@link Word2VecModel#toFrame()}, checking column names, words and both components.
     */
    @Test
    public void testW2V_toFrame() {
        Random r = new Random();
        String[] words = new String[1000];
        double[] v1 = new double[words.length];
        double[] v2 = new double[words.length];
        for (int i = 0; i < words.length; ++i) {
            words[i] = "word" + i;
            v1[i] = r.nextDouble();
            v2[i] = r.nextDouble();
        }
        Scope.enter();
        try {
            Frame expected = new TestFrameBuilder()
                    .withName("w2v")
                    .withColNames(new String[]{"Word", "V1", "V2"})
                    .withVecTypes(new byte[]{Vec.T_STR, Vec.T_NUM, Vec.T_NUM})
                    .withDataForCol(0, words)
                    .withDataForCol(1, v1)
                    .withDataForCol(2, v2)
                    .withChunkLayout(new long[]{100L, 900L})
                    .build();
            Scope.track(expected);

            Word2VecModel.Word2VecParameters p = new Word2VecModel.Word2VecParameters();
            p._vec_size = 2;
            p._pre_trained = expected._key;
            Word2VecModel w2vm = (Word2VecModel) new Word2Vec(p).trainModel().get();
            Scope.track_generic(w2vm);

            Frame result = Scope.track(w2vm.toFrame());
            Assert.assertArrayEquals(expected._names, result._names);
            assertStringVecEquals(expected.vec(0), result.vec(0));
            assertVecEquals(expected.vec(1), result.vec(1), 1.0E-4);
            assertVecEquals(expected.vec(2), result.vec(2), 1.0E-4);
        } finally {
            Scope.exit();
        }
    }

    /**
     * Large-corpus quality check on text8; runs only when the {@code testW2V} system property is
     * set. Expects at least one of "cat", "dogs" or "hound" among the 20 synonyms of "dog".
     */
    @Test
    public void testW2V_SG_HSM() {
        Assume.assumeThat("word2vec test enabled", System.getProperty("testW2V"),
                CoreMatchers.is(CoreMatchers.notNullValue()));
        Frame fr = parse_test_file("bigdata/laptop/text8.gz", "NA", 0, new byte[]{Vec.T_STR});
        Word2VecModel w2vm = null;
        try {
            Word2VecModel.Word2VecParameters p = new Word2VecModel.Word2VecParameters();
            p._train = fr._key;
            p._min_word_freq = 5;
            p._word_model = Word2Vec.WordModel.SkipGram;
            p._norm_model = Word2Vec.NormModel.HSM;
            p._vec_size = 100;
            p._window_size = 4;
            p._sent_sample_rate = 0.001f;
            p._init_learning_rate = 0.025f;
            p._epochs = 10;
            w2vm = (Word2VecModel) new Word2Vec(p).trainModel().get();

            Map<String, Float> hm = w2vm.findSynonyms("dog", 20);
            logResults(hm);
            Assert.assertTrue("expected 'cat', 'dogs' or 'hound' among synonyms of 'dog': " + hm.keySet(),
                    hm.containsKey("cat") || hm.containsKey("dogs") || hm.containsKey("hound"));
        } finally {
            fr.remove();
            if (w2vm != null) {
                w2vm.delete();
            }
        }
    }

    /**
     * Checks {@code AggregateMethod.AVERAGE} over NA-separated sentences spread across several
     * chunks (sentences may span chunk boundaries). The model's embeddings are overwritten with
     * {1,0} / {0,1}, so per the expected vectors below each output column is the share of one word
     * ("a" or "b") in a sentence. The out-of-vocabulary "c" is ignored, and a sentence with no
     * known words averages to NaN.
     */
    @Test
    public void testTransformAggregate() {
        Scope.enter();
        try {
            Vec v = Scope.track(svec(new String[]{"a", "b"}));
            Frame fr = Scope.track(new Frame(Key.make(), new String[]{"Words"}, new Vec[]{v}));
            DKV.put(fr);

            Word2VecModel.Word2VecParameters p = new Word2VecModel.Word2VecParameters();
            p._train = fr._key;
            p._min_word_freq = 0;
            p._epochs = 1;
            p._vec_size = 2;
            Word2VecModel w2vm = (Word2VecModel) new Word2Vec(p).trainModel().get();
            Scope.track_generic(w2vm);

            // replace the trained embeddings with two orthogonal unit vectors
            Word2VecModel.Word2VecOutput output = (Word2VecModel.Word2VecOutput) w2vm._output;
            output._vecs = new float[]{1.0f, 0.0f, 0.0f, 1.0f};
            DKV.put(w2vm);

            String[][] chunks = new String[][]{
                    {"a", "b", null, "a", "c", null, "c", null, "a", "a"},
                    {"a", "b", null},
                    {null, null},
                    {"b", "b", "a"},
                    {"b"}};
            long[] layout = new long[chunks.length];
            String[] sentences = new String[]{};
            for (int i = 0; i < chunks.length; ++i) {
                sentences = ArrayUtils.append(sentences, chunks[i]);
                layout[i] = chunks[i].length;
            }
            Frame f = new TestFrameBuilder()
                    .withName("data")
                    .withColNames(new String[]{"Sentences"})
                    .withVecTypes(new byte[]{Vec.T_STR})
                    .withDataForCol(0, sentences)
                    .withChunkLayout(layout)
                    .build();
            // previously never tracked, leaking the frame's keys after the test
            Scope.track(f);

            Frame result = Scope.track(w2vm.transform(f.vec(0), Word2VecModel.AggregateMethod.AVERAGE));
            Vec expectedAs = Scope.track(dvec(new double[]{0.5, 1.0, Double.NaN, 0.75, Double.NaN, Double.NaN, 0.25}));
            Vec expectedBs = Scope.track(dvec(new double[]{0.5, 0.0, Double.NaN, 0.25, Double.NaN, Double.NaN, 0.75}));
            int aIdx = (Integer) output._vocab.get(new BufferedString("a"));
            int bIdx = (Integer) output._vocab.get(new BufferedString("b"));
            assertVecEquals(expectedAs, result.vec(aIdx), 1.0E-4);
            assertVecEquals(expectedBs, result.vec(bIdx), 1.0E-4);
        } finally {
            Scope.exit();
        }
    }

    /**
     * Logs the given synonym map sorted by value in descending order (presumably the similarity
     * score), one entry per line, each prefixed with its 0-based rank.
     *
     * @param hm word -> value map as returned by {@code findSynonyms}
     */
    private void logResults(Map<String, Float> hm) {
        ArrayList<Map.Entry<String, Float>> result = new ArrayList<Map.Entry<String, Float>>(hm.entrySet());
        Collections.sort(result, new Comparator<Map.Entry<String, Float>>() {
            @Override
            public int compare(Map.Entry<String, Float> o1, Map.Entry<String, Float> o2) {
                // reversed operands -> highest value first
                return o2.getValue().compareTo(o1.getValue());
            }
        });
        int rank = 0;
        for (Map.Entry<String, Float> entry : result) {
            Log.info(rank++ + ". " + entry.getKey() + ", " + entry.getValue());
        }
    }
}

