package hex.word2vec;

import hex.word2vec.Word2Vec;
import hex.word2vec.Word2VecModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Map;
import java.util.Random;
import org.hamcrest.CoreMatchers;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import water.DKV;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.TestFrameBuilder;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.runner.CloudSize;
import water.runner.H2ORunner;
import water.util.ArrayUtils;
import water.util.Log;

/**
 * Tests for {@link Word2Vec} / {@link Word2VecModel}: training on a toy corpus, importing
 * pre-trained embeddings, converting a model back to a frame, aggregated transforms, and two
 * opt-in tests that train on the text8 corpus.
 *
 * <p>Tests that create keys do so between {@code Scope.enter()} and a {@code Scope.exit()} placed
 * in a {@code finally} block, so cleanup runs exactly once whether the test passes or fails.
 */
@CloudSize(1)
@RunWith(H2ORunner.class)
public class Word2VecTest extends TestUtil {

    /** Used by tests that expect a specific failure message. */
    @Rule
    public ExpectedException ee = ExpectedException.none();

    /**
     * Trains a skip-gram/HSM model on a corpus made of "a b" pairs followed by "a c" pairs and
     * checks that the synonyms of "a" are exactly "b" and "c", and that a non-aggregated
     * transform yields values for known words and NAs for an unseen word and for a null entry.
     */
    @Test
    public void testW2V_SG_HSM_small() {
        final int vecSize = 10;

        // 100 "a b" pairs followed by 10 "a c" pairs.
        String[] words = new String[220];
        for (int i = 0; i < 200; i += 2) {
            words[i] = "a";
            words[i + 1] = "b";
        }
        for (int i = 200; i < 220; i += 2) {
            words[i] = "a";
            words[i + 1] = "c";
        }

        Scope.enter();
        try {
            Vec wordVec = Scope.track(svec(words));
            Frame train = Scope.track(new Frame(Key.make(), new String[]{"Words"}, new Vec[]{wordVec}));
            DKV.put(train);

            Word2VecModel.Word2VecParameters params = new Word2VecModel.Word2VecParameters();
            params._train = train._key;
            params._min_word_freq = 5;
            params._word_model = Word2Vec.WordModel.SkipGram;
            params._norm_model = Word2Vec.NormModel.HSM;
            params._vec_size = vecSize;
            params._window_size = 5;
            params._sent_sample_rate = 0.001f;
            params._init_learning_rate = 0.025f;
            params._epochs = 1;

            Word2VecModel model = Scope.track_generic(new Word2Vec(params).trainModel().get());

            Map<String, Float> synonyms = model.findSynonyms("a", 2);
            logResults(synonyms);
            Assert.assertEquals(new HashSet<String>(Arrays.asList("b", "c")), synonyms.keySet());

            // Rows 0-2 are vocabulary words; row 3 is an unseen word and row 4 is null.
            Vec input = Scope.track(svec(new String[]{"a", "b", "c", "Unseen", null}));
            Frame embeddings = Scope.track(model.transform(input, Word2VecModel.AggregateMethod.NONE));
            Assert.assertEquals(vecSize, embeddings.numCols());
            for (int col = 0; col < vecSize; col++) {
                Vec component = embeddings.vec(col);
                for (int row = 0; row < 3; row++) {
                    Assert.assertFalse(component.isNA(row));
                }
                for (int row = 3; row < 5; row++) {
                    Assert.assertTrue(component.isNA(row));
                }
            }
        } finally {
            Scope.exit();
        }
    }

    /**
     * Builds a model from a frame of pre-trained two-dimensional embeddings spread over several
     * chunks and checks that every word is transformed back to the vector it was given.
     */
    @Test
    public void testW2V_pretrained() {
        String[] words = new String[1000];
        double[] v1 = new double[words.length];
        double[] v2 = new double[words.length];
        for (int i = 0; i < words.length; i++) {
            words[i] = "word" + i;
            // Floating-point division: integer division would yield 0 for every i and give all
            // words the same vector, which would make the per-word check below meaningless.
            v1[i] = (double) i / words.length;
            v2[i] = 1.0d - v1[i];
        }

        Scope.enter();
        try {
            // Built inside the try block so the scope is exited even if frame construction fails.
            Frame pretrained = new TestFrameBuilder()
                    .withName("w2v-pretrained")
                    .withColNames(new String[]{"Word", "V1", "V2"})
                    .withVecTypes(new byte[]{Vec.T_STR, Vec.T_NUM, Vec.T_NUM})
                    .withDataForCol(0, words)
                    .withDataForCol(1, v1)
                    .withDataForCol(2, v2)
                    .withChunkLayout(new long[]{100, 100, 20, 80, 100, 100, 100, 100, 100, 100, 100})
                    .build();
            Scope.track(pretrained);

            Word2VecModel model = Word2Vec.fromPretrainedModel(pretrained).get();
            Scope.track_generic(model);

            for (int i = 0; i < words.length; i++) {
                float[] expected = {(float) v1[i], (float) v2[i]};
                Assert.assertArrayEquals("wordvec " + i, expected, model.transform(words[i]), 1.0E-4f);
            }
        } finally {
            Scope.exit();
        }
    }

    /**
     * Importing a pre-trained frame whose vector columns include a time column and a categorical
     * column is expected to fail with a message naming the offending columns.
     */
    @Test
    public void testImportPretrained_invalid() {
        Scope.enter();
        try {
            Frame pretrained = new TestFrameBuilder()
                    .withName("w2v-pretrained")
                    .withColNames(new String[]{"Word", "V1", "V2", "V3"})
                    .withVecTypes(new byte[]{Vec.T_STR, Vec.T_TIME, Vec.T_NUM, Vec.T_CAT})
                    .withDataForCol(0, ar(new String[]{"a"}))
                    .withDataForCol(1, ar(new long[]{System.currentTimeMillis()}))
                    .withDataForCol(2, ard(new double[]{3.141592653589793d}))
                    .withDataForCol(3, ar(new String[]{"C1"}))
                    .build();
            // Tracked explicitly, like the frames in the other tests of this class.
            Scope.track(pretrained);

            this.ee.expectMessage("All components of word2vec mapping are expected to be numeric. Invalid columns: V1 (type Time), V3 (type Enum)");
            Word2Vec.fromPretrainedModel(pretrained);
        } finally {
            Scope.exit();
        }
    }

    /**
     * Round-trips random two-dimensional embeddings: frame -> model (via {@code _pre_trained})
     * -> frame, and checks that column names, words and vector components are preserved.
     */
    @Test
    public void testW2V_toFrame() {
        // Fixed seed so that a failure can be reproduced with the same data.
        Random random = new Random(42L);
        String[] words = new String[1000];
        double[] v1 = new double[words.length];
        double[] v2 = new double[words.length];
        for (int i = 0; i < words.length; i++) {
            words[i] = "word" + i;
            v1[i] = random.nextDouble();
            v2[i] = random.nextDouble();
        }

        Scope.enter();
        try {
            Frame expected = new TestFrameBuilder()
                    .withName("w2v")
                    .withColNames(new String[]{"Word", "V1", "V2"})
                    .withVecTypes(new byte[]{Vec.T_STR, Vec.T_NUM, Vec.T_NUM})
                    .withDataForCol(0, words)
                    .withDataForCol(1, v1)
                    .withDataForCol(2, v2)
                    .withChunkLayout(new long[]{100, 900})
                    .build();
            Scope.track(expected);

            Word2VecModel.Word2VecParameters params = new Word2VecModel.Word2VecParameters();
            params._vec_size = 2;
            params._pre_trained = expected._key;

            Word2VecModel model = Scope.track_generic(new Word2Vec(params).trainModel().get());
            Frame actual = Scope.track(model.toFrame());

            Assert.assertArrayEquals(expected._names, actual._names);
            assertStringVecEquals(expected.vec(0), actual.vec(0));
            assertVecEquals(expected.vec(1), actual.vec(1), 1.0E-4d);
            assertVecEquals(expected.vec(2), actual.vec(2), 1.0E-4d);
        } finally {
            Scope.exit();
        }
    }

    /**
     * Trains a skip-gram/HSM model on the text8 corpus and checks that the 20 closest synonyms of
     * "dog" contain at least one of "cat", "dogs" or "hound". Skipped unless the
     * {@code testW2V} system property is set.
     */
    @Test
    public void testW2V_SG_HSM() {
        assumeText8TestsEnabled();
        Frame corpus = parse_test_file("bigdata/laptop/text8.gz", "NA", 0, new byte[]{Vec.T_STR});
        Word2VecModel model = null;
        try {
            Word2VecModel.Word2VecParameters params = new Word2VecModel.Word2VecParameters();
            params._train = corpus._key;
            params._min_word_freq = 5;
            params._word_model = Word2Vec.WordModel.SkipGram;
            params._norm_model = Word2Vec.NormModel.HSM;
            params._vec_size = 100;
            params._window_size = 4;
            params._sent_sample_rate = 0.001f;
            params._init_learning_rate = 0.025f;
            params._epochs = 10;

            model = new Word2Vec(params).trainModel().get();

            Map<String, Float> synonyms = model.findSynonyms("dog", 20);
            logResults(synonyms);
            Assert.assertTrue(synonyms.containsKey("cat") || synonyms.containsKey("dogs") || synonyms.containsKey("hound"));
        } finally {
            corpus.remove();
            if (model != null) {
                model.delete();
            }
        }
    }

    /**
     * Trains a CBOW/HSM model on the text8 corpus and checks that "dogs" is among the 10 closest
     * synonyms of "dog". Skipped unless the {@code testW2V} system property is set.
     */
    @Test
    public void testW2V_CBOW_HSM() {
        assumeText8TestsEnabled();
        Frame corpus = parse_test_file("bigdata/laptop/text8.gz", "NA", 0, new byte[]{Vec.T_STR});
        Word2VecModel model = null;
        try {
            Word2VecModel.Word2VecParameters params = new Word2VecModel.Word2VecParameters();
            params._train = corpus._key;
            params._min_word_freq = 20;
            params._word_model = Word2Vec.WordModel.CBOW;
            params._norm_model = Word2Vec.NormModel.HSM;
            params._vec_size = 100;
            params._window_size = 4;
            params._sent_sample_rate = 0.01f;
            params._init_learning_rate = 0.05f;
            params._epochs = 10;

            model = new Word2Vec(params).trainModel().get();

            Map<String, Float> synonyms = model.findSynonyms("dog", 10);
            logResults(synonyms);
            // This check is strictly stronger than the "cat or dogs or hound" check used by the
            // skip-gram test, so that looser assertion is not repeated here.
            Assert.assertTrue(synonyms.containsKey("dogs"));
        } finally {
            corpus.remove();
            if (model != null) {
                model.delete();
            }
        }
    }

    /**
     * Checks {@code AggregateMethod.AVERAGE}: a column of words in which {@code null} terminates
     * a sentence is transformed into one averaged vector per sentence.
     *
     * <p>The trained word vectors are overwritten with a 2x2 identity matrix, so output component
     * {@code k} of a sentence is the fraction of its known words whose vocabulary index is
     * {@code k}. The input is laid out with one chunk per array below, and the expected values
     * treat a sentence that is not terminated at the end of a chunk as continuing into the next
     * chunk.
     */
    @Test
    public void testTransformAggregate() {
        Scope.enter();
        try {
            Vec wordVec = Scope.track(svec(new String[]{"a", "b"}));
            Frame train = Scope.track(new Frame(Key.make(), new String[]{"Words"}, new Vec[]{wordVec}));
            DKV.put(train);

            Word2VecModel.Word2VecParameters params = new Word2VecModel.Word2VecParameters();
            params._train = train._key;
            params._min_word_freq = 0;
            params._epochs = 1;
            params._vec_size = 2;

            Word2VecModel model = Scope.track_generic(new Word2Vec(params).trainModel().get());
            // Replace the learned vectors with known values and publish the modified model.
            model._output._vecs = new float[]{1.0f, 0.0f, 0.0f, 1.0f};
            DKV.put(model);

            // One inner array per chunk of the input vec.
            String[][] chunks = {
                    {"a", "b", null, "a", "c", null, "c", null, "a", "a"},
                    {"a", "b", null},
                    {null, null},
                    {"b", "b", "a"},
                    {"b"}
            };
            long[] chunkLayout = new long[chunks.length];
            String[] sentences = new String[0];
            for (int i = 0; i < chunks.length; i++) {
                sentences = ArrayUtils.append(sentences, chunks[i]);
                chunkLayout[i] = chunks[i].length;
            }

            Frame data = new TestFrameBuilder()
                    .withName("data")
                    .withColNames(new String[]{"Sentences"})
                    .withVecTypes(new byte[]{Vec.T_STR})
                    .withDataForCol(0, sentences)
                    .withChunkLayout(chunkLayout)
                    .build();
            Frame aggregated = Scope.track(model.transform(data.vec(0), Word2VecModel.AggregateMethod.AVERAGE));

            // Per-sentence expectations; NaN marks sentences without any known word.
            Vec expectedA = Scope.track(dvec(new double[]{0.5d, 1.0d, Double.NaN, 0.75d, Double.NaN, Double.NaN, 0.25d}));
            Vec expectedB = Scope.track(dvec(new double[]{0.5d, 0.0d, Double.NaN, 0.25d, Double.NaN, Double.NaN, 0.75d}));

            int idxA = ((Integer) model._output._vocab.get(new BufferedString("a"))).intValue();
            int idxB = ((Integer) model._output._vocab.get(new BufferedString("b"))).intValue();
            assertVecEquals(expectedA, aggregated.vec(idxA), 1.0E-4d);
            assertVecEquals(expectedB, aggregated.vec(idxB), 1.0E-4d);
        } finally {
            Scope.exit();
        }
    }

    /**
     * Skips the calling test unless the {@code testW2V} system property is set; guards the tests
     * that train on the text8 corpus.
     */
    private static void assumeText8TestsEnabled() {
        Assume.assumeThat("word2vec test enabled", System.getProperty("testW2V"), CoreMatchers.is(CoreMatchers.notNullValue()));
    }

    /**
     * Logs the given synonyms, one per line and numbered from 0, ordered by descending score.
     *
     * @param map synonym to score, as returned by {@code findSynonyms}
     */
    private void logResults(Map<String, Float> map) {
        ArrayList<Map.Entry<String, Float>> ranked = new ArrayList<Map.Entry<String, Float>>(map.entrySet());
        Collections.sort(ranked, new Comparator<Map.Entry<String, Float>>() {
            @Override
            public int compare(Map.Entry<String, Float> left, Map.Entry<String, Float> right) {
                // Arguments are swapped on purpose to sort from highest to lowest score.
                return right.getValue().compareTo(left.getValue());
            }
        });
        int rank = 0;
        for (Map.Entry<String, Float> entry : ranked) {
            Log.info(rank + ". " + entry.getKey() + ", " + entry.getValue());
            rank++;
        }
    }
}
