package hex.word2vec;

import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.word2vec.Word2Vec;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import water.DKV;
import water.H2O;
import water.Iced;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.MemoryManager;
import water.Scope;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.util.IcedHashMap;
import water.util.IcedHashMapGeneric;
import water.util.IcedLong;
import water.util.RandomBase;
import water.util.RandomUtils;

/* loaded from: input_file:hex/word2vec/Word2VecModel.class */
public class Word2VecModel extends Model<Word2VecModel, Word2VecParameters, Word2VecOutput> {
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:hex/word2vec/Word2VecModel$AggregateMethod.class */
    /**
     * Selects how {@link #transform(Vec, AggregateMethod)} turns a column of words into embeddings.
     */
    public enum AggregateMethod {
        /** Any value other than {@code AVERAGE} routes the transform through the per-word task. */
        NONE,
        /** Routes the transform through the aggregating task, which divides summed vectors by a count. */
        AVERAGE
    }

    /* loaded from: input_file:hex/word2vec/Word2VecModel$ConvertToFrameTask.class */
    /**
     * Distributed task that materialises the model's vocabulary and embeddings as frame rows.
     *
     * <p>It runs over a single driver column and writes, for every row, the word at that row
     * index followed by the {@code _vecSize} components of its vector.
     */
    private static class ConvertToFrameTask extends MRTask<ConvertToFrameTask> {
        private Key<Word2VecModel> _modelKey;
        // Not shipped with the task; re-fetched from the DKV in setupLocal().
        private transient Word2VecModel _model;

        public ConvertToFrameTask(Word2VecModel word2VecModel) {
            _modelKey = word2VecModel._key;
        }

        /** Resolves the model from its key before any chunk is mapped. */
        protected void setupLocal() {
            _model = DKV.getGet(_modelKey);
        }

        /**
         * Emits one output row per input row: the word in column 0, vector components after it.
         *
         * @param chunkArr    exactly one input chunk, used only for its row range
         * @param newChunkArr {@code _vecSize + 1} output chunks
         */
        public void map(Chunk[] chunkArr, NewChunk[] newChunkArr) {
            assert chunkArr.length == 1;
            assert newChunkArr.length == ((Word2VecOutput) _model._output)._vecSize + 1;

            Chunk rows = chunkArr[0];
            int firstRow = (int) rows.start();
            Word2VecOutput output = (Word2VecOutput) _model._output;
            // Vectors are stored back to back, so the chunk's first vector starts here.
            int vecPos = output._vecSize * firstRow;
            for (int row = 0; row < rows._len; row++) {
                newChunkArr[0].addStr(output._words[firstRow + row]);
                for (int col = 1; col < newChunkArr.length; col++) {
                    newChunkArr[col].addNum(output._vecs[vecPos++]);
                }
            }
        }
    }

    /* loaded from: input_file:hex/word2vec/Word2VecModel$Vocabulary.class */
    /**
     * Keyed wrapper around the word-to-index map so it can be stored in and fetched from the DKV.
     */
    public static class Vocabulary extends Keyed<Vocabulary> {
        /** Maps each vocabulary word to its integer index. */
        IcedHashMapGeneric<BufferedString, Integer> _data;

        /**
         * Wraps the given map under a freshly generated key.
         *
         * @param wordToIndex map stored by reference, not copied
         */
        Vocabulary(IcedHashMapGeneric<BufferedString, Integer> wordToIndex) {
            super(Key.make());
            _data = wordToIndex;
        }
    }

    /* loaded from: input_file:hex/word2vec/Word2VecModel$Word2VecAggregateTask.class */
    /**
     * Distributed task selected by {@code transform(Vec, AggregateMethod)} when the method is
     * {@code AVERAGE}.
     *
     * <p>NOTE(review): the decompiler could not reconstruct {@link #map}; the body below is only
     * a stub that throws. The real logic must be restored from the original sources or bytecode
     * before this class can be used.
     */
    private static class Word2VecAggregateTask extends MRTask<Word2VecAggregateTask> {
        private Word2VecModel _model;
        // Decompiler rendering of the compiler-generated assertion switch.
        static final /* synthetic */ boolean $assertionsDisabled;

        public Word2VecAggregateTask(Word2VecModel word2VecModel) {
            this._model = word2VecModel;
        }

        /* JADX WARN: Code restructure failed: missing block: B:41:0x00c2, code lost:
        
            r10 = 0;
            r0 = r9.nextChunk();
            r9 = r0;
         */
        /*
            Code decompiled incorrectly, please refer to instructions dump.
            To view partially-correct add '--show-bad-code' argument
        */
        // Placeholder only: always throws UnsupportedOperationException.
        // The lost fragment above shows the original advanced to chunk.nextChunk().
        public void map(water.fvec.Chunk[] r7, water.fvec.NewChunk[] r8) {
            /*
                Method dump skipped, instructions count: 221
                To view this dump add '--comments-level debug' option
            */
            throw new UnsupportedOperationException("Method not decompiled: hex.word2vec.Word2VecModel.Word2VecAggregateTask.map(water.fvec.Chunk[], water.fvec.NewChunk[]):void");
        }

        /**
         * Appends one output row: each component of {@code fArr} divided by {@code i}, or an NA
         * in every column when {@code i} is zero.
         *
         * @param i           divisor (presumably the number of summed word vectors - confirm against map)
         * @param fArr        component sums, indexed like {@code newChunkArr}
         * @param newChunkArr output chunks, one per vector component
         */
        private void writeAggregate(int i, float[] fArr, NewChunk[] newChunkArr) {
            if (i != 0) {
                for (int i2 = 0; i2 < newChunkArr.length; i2++) {
                    newChunkArr[i2].addNum(fArr[i2] / i);
                }
                return;
            }
            // Zero divisor: there is nothing to average, so the row is all NA.
            for (NewChunk newChunk : newChunkArr) {
                newChunk.addNA();
            }
        }

        /**
         * Returns the index of the first NA row in {@code chunk}, or -1 if there is none.
         * Not referenced by any code visible here; presumably used by the undecompiled map.
         */
        private int findNA(Chunk chunk) {
            for (int i = 0; i < chunk._len; i++) {
                if (chunk.isNA(i)) {
                    return i;
                }
            }
            return -1;
        }

        static {
            $assertionsDisabled = !Word2VecModel.class.desiredAssertionStatus();
        }
    }

    /* loaded from: input_file:hex/word2vec/Word2VecModel$Word2VecModelInfo.class */
    /**
     * Training state for a Word2Vec model: the weight arrays plus keys to the published
     * vocabulary, word counts and Huffman tree.
     */
    public static class Word2VecModelInfo extends Iced {
        /** Sum of the counts of all words kept in the vocabulary. */
        long _vocabWordCount;
        long _totalProcessedWords = 0;
        /** Flat weight array of {@code _vec_size} floats per vocabulary word, randomly initialised. */
        float[] _syn0;
        /** Second weight array of the same size as {@code _syn0}; only allocated here. */
        float[] _syn1;
        Key<HBWTree> _treeKey;
        Key<Vocabulary> _vocabKey;
        Key<WordCounts> _wordCountsKey;
        private Word2VecParameters _parameters;

        public final Word2VecParameters getParams() {
            return this._parameters;
        }

        public Word2VecModelInfo() {
        }

        /**
         * Builds the initial state from raw word counts.
         *
         * <p>Words rarer than {@code _min_word_freq} are dropped; the rest are indexed in
         * ascending order of count, and that count array is handed to the Huffman tree builder.
         *
         * @throws IllegalArgumentException if {@code _vec_size * vocabulary size} does not fit
         *                                  in an {@code int}-indexed array
         */
        private Word2VecModelInfo(Word2VecParameters word2VecParameters, WordCounts wordCounts) {
            this._parameters = word2VecParameters;

            long keptWordCount = 0;
            ArrayList<Map.Entry<BufferedString, IcedLong>> kept = new ArrayList<>(wordCounts._data.size());
            for (Map.Entry<BufferedString, IcedLong> entry : wordCounts._data.entrySet()) {
                long count = entry.getValue()._val;
                if (count >= this._parameters._min_word_freq) {
                    kept.add(entry);
                    keptWordCount += count;
                }
            }
            Collections.sort(kept, new Comparator<Map.Entry<BufferedString, IcedLong>>() {
                @Override
                public int compare(Map.Entry<BufferedString, IcedLong> a, Map.Entry<BufferedString, IcedLong> b) {
                    return Long.compare(a.getValue()._val, b.getValue()._val);
                }
            });

            int vocabSize = kept.size();
            long[] countsByIndex = new long[vocabSize];
            Vocabulary vocabulary = new Vocabulary(new IcedHashMapGeneric<BufferedString, Integer>());
            int index = 0;
            for (Map.Entry<BufferedString, IcedLong> entry : kept) {
                countsByIndex[index] = entry.getValue()._val;
                vocabulary._data.put(entry.getKey(), Integer.valueOf(index));
                index++;
            }

            HBWTree tree = HBWTree.buildHuffmanBinaryWordTree(countsByIndex);
            this._vocabWordCount = keptWordCount;
            this._treeKey = publish(tree);
            this._vocabKey = publish(vocabulary);
            this._wordCountsKey = publish(wordCounts);

            // Do the multiplication in long: an int product can wrap silently and produce
            // wrongly sized (or negative-length) weight arrays.
            long weightCountLong = (long) this._parameters._vec_size * vocabSize;
            if (weightCountLong > Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Word2Vec weight matrix is too large: vec_size="
                        + this._parameters._vec_size + " * vocabulary size=" + vocabSize
                        + " exceeds the maximum array length.");
            }
            int weightCount = (int) weightCountLong;

            RandomBase rng = RandomUtils.getRNG(new long[]{912559, 55930});
            this._syn1 = MemoryManager.malloc4f(weightCount);
            this._syn0 = MemoryManager.malloc4f(weightCount);
            for (int k = 0; k < weightCount; k++) {
                this._syn0[k] = (rng.nextFloat() - 0.5f) / this._parameters._vec_size;
            }
        }

        /**
         * Counts the words of the training vector and builds the initial model state from them.
         */
        public static Word2VecModelInfo createInitialModelInfo(Word2VecParameters word2VecParameters) {
            return new Word2VecModelInfo(word2VecParameters, new WordCounts(((WordCountTask) new WordCountTask().doAll(new Vec[]{word2VecParameters.trainVec()}))._counts));
        }

        /** Registers the object with the current Scope, stores it in the DKV and returns its key. */
        private static <T extends Keyed<T>> Key<T> publish(T t) {
            Scope.track_generic(t);
            DKV.put(t);
            return t._key;
        }
    }

    /* loaded from: input_file:hex/word2vec/Word2VecModel$Word2VecOutput.class */
    /**
     * Model output: the vocabulary and its embedding vectors in a flat array layout.
     */
    public static class Word2VecOutput extends Model.Output {
        /** Number of components per word vector; copied from the {@code _vec_size} parameter. */
        public int _vecSize;
        // Not assigned anywhere in this file.
        public int _epochs;
        /** Vocabulary words, positioned at their vocabulary index. */
        public BufferedString[] _words;
        /** All vectors back to back: word k occupies [k * _vecSize, (k + 1) * _vecSize). */
        public float[] _vecs;
        /** Maps each word to its index into {@code _words} (and its slot in {@code _vecs}). */
        public IcedHashMapGeneric<BufferedString, Integer> _vocab;

        public Word2VecOutput(Word2Vec word2Vec) {
            super(word2Vec);
        }

        /** Always reports the word-embedding category. */
        public ModelCategory getModelCategory() {
            return ModelCategory.WordEmbedding;
        }
    }

    /* loaded from: input_file:hex/word2vec/Word2VecModel$Word2VecParameters.class */
    /**
     * User-facing parameters of the Word2Vec algorithm.
     */
    public static class Word2VecParameters extends Model.Parameters {
        static final int MAX_VEC_SIZE = 10000;
        public Word2Vec.WordModel _word_model = Word2Vec.WordModel.SkipGram;
        public Word2Vec.NormModel _norm_model = Word2Vec.NormModel.HSM;
        /** Words whose count is below this threshold are left out of the vocabulary. */
        public int _min_word_freq = 5;
        /** Number of components in every word vector. */
        public int _vec_size = 100;
        public int _window_size = 5;
        public int _epochs = 5;
        public float _init_learning_rate = 0.025f;
        public float _sent_sample_rate = 0.001f;
        /** Optional frame with pre-trained embeddings; {@code null} means train from scratch. */
        public Key<Frame> _pre_trained;

        public String algoName() {
            return "Word2Vec";
        }

        public String fullName() {
            return "Word2Vec";
        }

        public String javaName() {
            return Word2VecModel.class.getName();
        }

        /**
         * Returns the number of progress units for the job: the chunk count of the pre-trained
         * frame when one is given, otherwise training chunks times epochs.
         */
        public long progressUnits() {
            if (isPreTrained()) {
                return this._pre_trained.get().anyVec().nChunks();
            }
            // Widen before multiplying so chunks * epochs cannot wrap around in int arithmetic.
            return (long) trainVec().nChunks() * this._epochs;
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        /** Returns whether a pre-trained embedding frame was supplied. */
        public boolean isPreTrained() {
            return this._pre_trained != null;
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        /** Returns the first column of the training frame, which holds the words. */
        public Vec trainVec() {
            return train().vec(0);
        }
    }

    /* loaded from: input_file:hex/word2vec/Word2VecModel$Word2VecTransformTask.class */
    /**
     * Distributed task that maps every word of a string column to its embedding, one output
     * row per input row.
     */
    private static class Word2VecTransformTask extends MRTask<Word2VecTransformTask> {
        private Word2VecModel _model;

        public Word2VecTransformTask(Word2VecModel word2VecModel) {
            this._model = word2VecModel;
        }

        /**
         * Writes the vector of each word into the output chunks. Rows that are NA, or whose
         * word is not in the vocabulary, become NA in every output column.
         *
         * @param chunkArr    exactly one input chunk of strings
         * @param newChunkArr one output chunk per vector component
         */
        public void map(Chunk[] chunkArr, NewChunk[] newChunkArr) {
            assert chunkArr.length == 1;
            Chunk words = chunkArr[0];
            // Reused across rows to avoid allocating a new string holder per word.
            BufferedString buffer = new BufferedString();
            for (int row = 0; row < words._len; row++) {
                // The lookup result must be kept: the decompiled code dropped it and then read
                // from an undefined variable.
                float[] vector = words.isNA(row) ? null : this._model.transform(words.atStr(buffer, row));
                if (vector == null) {
                    for (NewChunk out : newChunkArr) {
                        out.addNA();
                    }
                } else {
                    for (int col = 0; col < newChunkArr.length; col++) {
                        newChunkArr[col].addNum(vector[col]);
                    }
                }
            }
        }
    }

    /* loaded from: input_file:hex/word2vec/Word2VecModel$WordCounts.class */
    /**
     * Keyed wrapper around the per-word occurrence counts so they can live in the DKV.
     */
    public static class WordCounts extends Keyed<WordCounts> {
        /** Maps each word to the number of times it was counted. */
        IcedHashMap<BufferedString, IcedLong> _data;

        /**
         * Wraps the given counts under a freshly generated key.
         *
         * @param counts map stored by reference, not copied
         */
        WordCounts(IcedHashMap<BufferedString, IcedLong> counts) {
            super(Key.make());
            _data = counts;
        }
    }

    /**
     * Creates the model under the given key.
     *
     * <p>When assertions are enabled, verifies that the key stored by the superclass has the
     * same bytes as the key passed in.
     */
    public Word2VecModel(Key<Word2VecModel> key, Word2VecParameters word2VecParameters, Word2VecOutput word2VecOutput) {
        super(key, word2VecParameters, word2VecOutput);
        if ($assertionsDisabled) {
            return;
        }
        boolean sameKeyBytes = Arrays.equals(this._key._kb, key._kb);
        if (!sameKeyBytes) {
            throw new AssertionError();
        }
    }

    /** Not supported: Word2Vec has no model metrics, so this always throws. */
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] strArr) {
        throw H2O.unimpl("No Model Metrics for Word2Vec.");
    }

    /** Not implemented for Word2Vec; always throws. */
    public double[] score0(Chunk[] chunkArr, int i, double[] dArr, double[] dArr2) {
        throw H2O.unimpl();
    }

    /** Not implemented for Word2Vec; always throws. */
    protected double[] score0(double[] dArr, double[] dArr2) {
        throw H2O.unimpl();
    }

    /* renamed from: getMojo, reason: merged with bridge method [inline-methods] */
    /** Returns a new MOJO writer bound to this model. The decompiler renamed this from {@code getMojo}. */
    public Word2VecMojoWriter m409getMojo() {
        return new Word2VecMojoWriter(this);
    }

    /**
     * Converts the model into a frame with one row per vocabulary word: a "Word" column
     * followed by columns "V1".."Vn" holding the vector components.
     *
     * <p>A temporary driver vector with one row per word is created for the conversion task and
     * removed again whether or not the conversion succeeds.
     *
     * @return the newly built frame
     */
    public Frame toFrame() {
        // Column type codes handed to the task. 2 is the code the rest of this class treats as
        // "string vector"; 3 is presumably the numeric code - confirm against Vec.
        final byte stringType = 2;
        final byte numericType = 3;

        Word2VecOutput output = (Word2VecOutput) this._output;
        Vec driver = Vec.makeZero(output._words.length);
        try {
            byte[] types = new byte[1 + output._vecSize];
            Arrays.fill(types, numericType);
            types[0] = stringType;

            String[] names = new String[types.length];
            names[0] = "Word";
            for (int col = 1; col < names.length; col++) {
                names[col] = "V" + col;
            }
            return ((ConvertToFrameTask) new ConvertToFrameTask(this).doAll(types, new Vec[]{driver})).outputFrame(names, (String[][]) null);
        } finally {
            // Single cleanup point: the driver vector is removed exactly once on every path.
            driver.remove();
        }
    }

    /**
     * Looks up the embedding of a single word.
     *
     * @param str the word to look up
     * @return a copy of the word's vector, or {@code null} if the word is not in the vocabulary
     */
    public float[] transform(String str) {
        BufferedString word = new BufferedString(str);
        return transform(word);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Looks up the embedding of a single word.
     *
     * @param bufferedString the word to look up
     * @return a fresh copy of the word's vector, or {@code null} if the word is not in the
     *         vocabulary
     */
    public float[] transform(BufferedString bufferedString) {
        Word2VecOutput output = (Word2VecOutput) this._output;
        // One map lookup instead of containsKey + get: this runs once per word when
        // transforming a whole column.
        Integer index = output._vocab.get(bufferedString);
        if (index == null) {
            return null;
        }
        int from = index.intValue() * output._vecSize;
        return Arrays.copyOfRange(output._vecs, from, from + output._vecSize);
    }

    /**
     * Transforms a string column into a frame of embeddings with one numeric column per
     * vector component.
     *
     * @param vec             column of words; must be a string vector
     * @param aggregateMethod {@code AVERAGE} runs the aggregating task, anything else the
     *                        per-word task
     * @return the frame produced by the selected task, stored under a new key
     * @throws IllegalArgumentException if {@code vec} is not a string vector
     */
    public Frame transform(Vec vec, AggregateMethod aggregateMethod) {
        boolean isStringVec = vec.get_type() == 2;
        if (!isStringVec) {
            throw new IllegalArgumentException("Expected a string vector, got " + vec.get_type_str() + " vector.");
        }

        byte[] outputTypes = new byte[((Word2VecOutput) this._output)._vecSize];
        Arrays.fill(outputTypes, (byte) 3);

        MRTask<?> task;
        if (aggregateMethod == AggregateMethod.AVERAGE) {
            task = new Word2VecAggregateTask(this);
        } else {
            task = new Word2VecTransformTask(this);
        }
        return task.doAll(outputTypes, new Vec[]{vec}).outputFrame(Key.make(), (String[]) null, (String[][]) null);
    }

    /**
     * Finds the vocabulary words whose vectors have the highest cosine similarity to the
     * given word.
     *
     * @param str the word to find synonyms for
     * @param i   maximum number of synonyms to return; values larger than the vocabulary are
     *            capped at the vocabulary size
     * @return map from synonym to similarity score; empty if the word is unknown or
     *         {@code i} is not positive
     */
    public Map<String, Float> findSynonyms(String str, int i) {
        float[] target = transform(str);
        if (target == null || i <= 0) {
            return Collections.emptyMap();
        }
        Word2VecOutput output = (Word2VecOutput) this._output;
        int vocabSize = output._vocab.size();
        // Never track more candidates than there are words: seeding past the vocabulary would
        // index beyond the vector and word arrays.
        int count = Math.min(i, vocabSize);

        int[] candidates = new int[count];
        float[] scores = new float[count];
        int weakest = 0;
        // Seed the candidate set with the first `count` words.
        // NOTE(review): seeds are not subject to the < 0.999999 filter below, so the target
        // word itself is returned when its index is below `count` - confirm this is intended.
        for (int word = 0; word < count; word++) {
            candidates[word] = word;
            scores[word] = cosineSimilarity(target, word * target.length, output._vecs);
            if (scores[word] < scores[weakest]) {
                weakest = word;
            }
        }
        // Every remaining word replaces the weakest candidate when it scores higher, unless it
        // is (nearly) identical to the target.
        for (int word = count; word < vocabSize; word++) {
            float score = cosineSimilarity(target, word * target.length, output._vecs);
            if (score > scores[weakest] && score < 0.999999d) {
                candidates[weakest] = word;
                scores[weakest] = score;
                weakest = 0;
                for (int slot = 1; slot < count; slot++) {
                    if (scores[slot] < scores[weakest]) {
                        weakest = slot;
                    }
                }
            }
        }

        Map<String, Float> synonyms = new HashMap<>(count);
        for (int slot = 0; slot < count; slot++) {
            synonyms.put(output._words[candidates[slot]].toString(), Float.valueOf(scores[slot]));
        }
        return synonyms;
    }

    /**
     * Computes the cosine similarity between {@code fArr} and the slice of {@code fArr2} that
     * starts at offset {@code i} and has the same length.
     *
     * @param fArr  the query vector
     * @param i     start offset of the other vector inside {@code fArr2}
     * @param fArr2 flat array holding the other vector
     * @return dot product divided by the product of both vector norms
     */
    private float cosineSimilarity(float[] fArr, int i, float[] fArr2) {
        float dot = 0.0f;
        float queryNormSq = 0.0f;
        float otherNormSq = 0.0f;
        int k = 0;
        while (k < fArr.length) {
            float q = fArr[k];
            float o = fArr2[i + k];
            dot += q * o;
            queryNormSq = (float) (queryNormSq + Math.pow(q, 2.0d));
            otherNormSq = (float) (otherNormSq + Math.pow(o, 2.0d));
            k++;
        }
        double norms = Math.sqrt(queryNormSq) * Math.sqrt(otherNormSq);
        return (float) (dot / norms);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Fills the model output from training state: the vocabulary is fetched from the DKV and
     * the {@code _syn0} weights become the word vectors.
     *
     * @param word2VecModelInfo training state holding the vocabulary key and the weights
     */
    public void buildModelOutput(Word2VecModelInfo word2VecModelInfo) {
        // Fetch into a typed local so the generic lookup resolves to Vocabulary.
        Vocabulary vocabulary = DKV.getGet(word2VecModelInfo._vocabKey);
        IcedHashMapGeneric<BufferedString, Integer> vocab = vocabulary._data;

        // Invert word -> index into an index -> word array.
        BufferedString[] words = new BufferedString[vocab.size()];
        for (Map.Entry<BufferedString, Integer> entry : vocab.entrySet()) {
            words[entry.getValue().intValue()] = entry.getKey();
        }

        Word2VecOutput output = (Word2VecOutput) this._output;
        output._vecSize = ((Word2VecParameters) this._parms)._vec_size;
        output._vecs = word2VecModelInfo._syn0;
        output._words = words;
        output._vocab = vocab;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Fills the model output from an explicit word list and its flat vector array; the
     * vocabulary index of each word is its position in {@code bufferedStringArr}.
     *
     * @param bufferedStringArr vocabulary words in index order
     * @param fArr              vectors of all words, stored back to back
     */
    public void buildModelOutput(BufferedString[] bufferedStringArr, float[] fArr) {
        IcedHashMapGeneric<BufferedString, Integer> wordToIndex = new IcedHashMapGeneric<>();
        int position = 0;
        for (BufferedString word : bufferedStringArr) {
            wordToIndex.put(word, Integer.valueOf(position));
            position++;
        }

        Word2VecOutput output = (Word2VecOutput) this._output;
        output._vecSize = ((Word2VecParameters) this._parms)._vec_size;
        output._vecs = fArr;
        output._words = bufferedStringArr;
        output._vocab = wordToIndex;
    }

    // Decompiler rendering of the assertion switch: true when assertions are disabled for
    // this class, which makes the explicit "if (!$assertionsDisabled && ...)" checks no-ops.
    static {
        $assertionsDisabled = !Word2VecModel.class.desiredAssertionStatus();
    }
}
