package hex.word2vec;

import hex.word2vec.Word2Vec;
import hex.word2vec.Word2VecModel;
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;
import water.DKV;
import water.H2O;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.parser.BufferedString;
import water.util.ArrayUtils;
import water.util.IcedHashMap;
import water.util.IcedHashMapGeneric;
import water.util.IcedLong;

/**
 * MRTask that trains word vectors with hierarchical softmax.
 *
 * <p>Each {@link #map(Chunk)} call reads the chunk's rows as strings and groups them into sentences.
 * A sentence ends at an NA row or after MAX_SENTENCE_LEN in-vocabulary words. Each word in a
 * sentence is then trained with either the Skip-Gram or the CBOW model, updating the input vectors
 * ({@code _syn0}) and the tree-node vectors ({@code _syn1}). {@link #reduce} combines tasks that
 * hold different arrays as a weighted average, weighted by the number of processed words.
 */
public class WordVectorTrainer extends MRTask<WordVectorTrainer> {
    private static final int MAX_SENTENCE_LEN = 1000;
    private static final int EXP_TABLE_SIZE = 1000;
    private static final int MAX_EXP = 6;
    /** Maps a dot product in (-MAX_EXP, MAX_EXP) to an _expTable index (integer division: 83). */
    private static final int EXP_TABLE_INDEX_SCALE = EXP_TABLE_SIZE / MAX_EXP / 2;
    private static final float[] _expTable = calcExpTable();
    private static final float LEARNING_RATE_MIN_FACTOR = 1.0E-4f;
    /** Number of trained words between learning-rate updates. */
    private static final int PROGRESS_UPDATE_INTERVAL = 10000;

    private final Job<Word2VecModel> _job;
    private final Word2Vec.WordModel _wordModel;
    private final int _wordVecSize;
    private final int _windowSize;
    private final int _epochs;
    private final float _initLearningRate;
    private final float _sentSampleRate;
    private final long _vocabWordCount;
    private final Key<Word2VecModel.Vocabulary> _vocabKey;
    private final Key<Word2VecModel.WordCounts> _wordCountsKey;
    private final Key<HBWTree> _treeKey;
    private final long _prevTotalProcessedWords;
    /** Input word vectors, laid out as {@code word * _wordVecSize + dim}. */
    float[] _syn0;
    /** Tree-node vectors, laid out as {@code node * _wordVecSize + dim}. */
    float[] _syn1;
    /** Words trained by this task (summed in reduce). */
    long _processedWords;
    /** Per-node word counter created in setupLocal(); drives the learning-rate decay. */
    IcedLong _nodeProcessedWords;
    private transient IcedHashMapGeneric<BufferedString, Integer> _vocab;
    private transient IcedHashMap<BufferedString, IcedLong> _wordCounts;
    private transient int[][] _HBWTCode;
    private transient int[][] _HBWTPoint;
    private float _curLearningRate;
    private long _seed;

    /**
     * Iterates over the sentences of a chunk.
     *
     * <p>Rows are read as strings. An NA row ends a sentence, and a sentence is also cut after
     * MAX_SENTENCE_LEN words. Words missing from the vocabulary are skipped. When sub-sampling is
     * enabled, randomly dropped frequent words are skipped too. The returned array is reused
     * between calls; it holds vocabulary indices followed by a {@code -1} terminator.
     */
    private class ChunkSentenceIterator implements Iterator<int[]> {
        private final Chunk _chk;
        /** Next row of the chunk to read. */
        private int _pos = 0;
        /** Length of the buffered sentence, or -1 when no sentence is buffered. */
        private int _len = -1;
        /** Reused sentence buffer, with one extra slot for the -1 terminator. */
        private final int[] _sent = new int[MAX_SENTENCE_LEN + 1];

        private ChunkSentenceIterator(Chunk chunk) {
            _chk = chunk;
        }

        @Override
        public boolean hasNext() {
            return nextLength() >= 0;
        }

        /**
         * Reads the next sentence into the buffer, unless one is already buffered.
         *
         * @return number of words in the buffered sentence (possibly 0), or -1 when the chunk is
         *     exhausted
         */
        private int nextLength() {
            if (_len >= 0) {
                return _len;
            }
            if (_pos >= _chk._len) {
                return -1;
            }
            _len = 0;
            BufferedString tmp = new BufferedString();
            while (_pos < _chk._len && !_chk.isNA(_pos) && _len < MAX_SENTENCE_LEN) {
                BufferedString word = _chk.atStr(tmp, _pos);
                _pos++;
                if (!_vocab.containsKey(word)) {
                    continue; // not in the vocabulary
                }
                if (_sentSampleRate > 0 && isSubsampledOut(word)) {
                    continue; // frequent word randomly dropped
                }
                _sent[_len++] = _vocab.get(word);
            }
            _sent[_len] = -1;
            if (_len < MAX_SENTENCE_LEN) {
                // Stopped on an NA row (or at the chunk end): consume it.
                _pos++;
            }
            // Otherwise the buffer is full, and _pos already points at the first unread row.
            // That row starts the next sentence instead of being silently dropped.
            return _len;
        }

        @Override
        public int[] next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            _len = -1;
            return _sent;
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException("Remove is not supported");
        }
    }

    /**
     * Creates a trainer from the current model state. The learning rate is initialized from the
     * number of words processed in earlier iterations.
     */
    public WordVectorTrainer(Job<Word2VecModel> job, Word2VecModel.Word2VecModelInfo word2VecModelInfo) {
        super((H2O.H2OCountedCompleter) null);
        _processedWords = 0L;
        _seed = System.nanoTime();
        _job = job;
        _treeKey = word2VecModelInfo._treeKey;
        _vocabKey = word2VecModelInfo._vocabKey;
        _wordCountsKey = word2VecModelInfo._wordCountsKey;
        _wordModel = word2VecModelInfo.getParams()._word_model;
        _wordVecSize = word2VecModelInfo.getParams()._vec_size;
        _windowSize = word2VecModelInfo.getParams()._window_size;
        _sentSampleRate = word2VecModelInfo.getParams()._sent_sample_rate;
        _epochs = word2VecModelInfo.getParams()._epochs;
        _initLearningRate = word2VecModelInfo.getParams()._init_learning_rate;
        _vocabWordCount = word2VecModelInfo._vocabWordCount;
        _prevTotalProcessedWords = word2VecModelInfo._totalProcessedWords;
        _syn0 = word2VecModelInfo._syn0;
        _syn1 = word2VecModelInfo._syn1;
        _curLearningRate = calcLearningRate(_initLearningRate, _epochs, _prevTotalProcessedWords, _vocabWordCount);
    }

    /** Loads the vocabulary, word counts and tree codes/points from the DKV, and resets the node counter. */
    @Override
    public void setupLocal() {
        _vocab = ((Word2VecModel.Vocabulary) DKV.getGet(_vocabKey))._data;
        _wordCounts = ((Word2VecModel.WordCounts) DKV.getGet(_wordCountsKey))._data;
        HBWTree tree = (HBWTree) DKV.getGet(_treeKey);
        _HBWTCode = tree._code;
        _HBWTPoint = tree._point;
        _nodeProcessedWords = new IcedLong(0L);
    }

    /** Precomputes e^x / (e^x + 1) for EXP_TABLE_SIZE evenly spaced x in [-MAX_EXP, MAX_EXP). */
    private static float[] calcExpTable() {
        float[] table = new float[EXP_TABLE_SIZE];
        for (int i = 0; i < EXP_TABLE_SIZE; i++) {
            float e = (float) Math.exp((((i / (float) EXP_TABLE_SIZE) * 2.0f) - 1.0f) * MAX_EXP);
            table[i] = e / (e + 1.0f);
        }
        return table;
    }

    /** Table-based sigmoid lookup; {@code f} must lie in (-MAX_EXP, MAX_EXP). */
    private static float sigmoid(float f) {
        return _expTable[(int) ((f + MAX_EXP) * EXP_TABLE_INDEX_SCALE)];
    }

    /**
     * Decides whether a frequent word is randomly dropped (sub-sampling).
     *
     * <p>Consumes one value of this task's random sequence.
     */
    private boolean isSubsampledOut(BufferedString word) {
        long count = _wordCounts.get(word)._val;
        float sampledTotal = _sentSampleRate * _vocabWordCount;
        float keepScore = (float) ((Math.sqrt(count / sampledTotal) + 1) * sampledTotal / count);
        return keepScore * 65536 < cheapRandInt(0xFFFF);
    }

    /**
     * Trains on every sentence of {@code chunk}.
     *
     * <p>For each word, a random window shrink {@code b = cheapRandInt(_windowSize)} is drawn. The
     * context positions {@code [b, 2 * _windowSize + 1 - b)}, excluding the centre word, are used.
     * The learning rate is refreshed every PROGRESS_UPDATE_INTERVAL words from the per-node counter.
     */
    @Override
    public void map(Chunk chunk) {
        final int winSize = _windowSize;
        final int vecSize = _wordVecSize;
        final boolean useCbow = _wordModel == Word2Vec.WordModel.CBOW;
        final boolean useSkipGram = _wordModel == Word2Vec.WordModel.SkipGram;
        float[] neu1 = new float[vecSize];  // CBOW: summed (later averaged) context vector
        float[] neu1e = new float[vecSize]; // accumulated gradient for the input vectors
        ChunkSentenceIterator sentences = new ChunkSentenceIterator(chunk);
        int wordCount = 0;
        while (sentences.hasNext()) {
            int sentLen = sentences.nextLength();
            int[] sentence = sentences.next();
            for (int sentIdx = 0; sentIdx < sentLen; sentIdx++) {
                int curWord = sentence[sentIdx];
                int cw = 0; // number of context words summed into neu1
                if (useCbow) {
                    Arrays.fill(neu1, 0.0f);
                    Arrays.fill(neu1e, 0.0f);
                }
                int winSizeMod = cheapRandInt(winSize);
                for (int winIdx = winSizeMod; winIdx < winSize * 2 + 1 - winSizeMod; winIdx++) {
                    if (winIdx == winSize) {
                        continue; // the centre word itself
                    }
                    int ctxIdx = sentIdx - winSize + winIdx;
                    if (ctxIdx < 0 || ctxIdx >= sentLen) {
                        continue; // outside the sentence
                    }
                    int lastWord = sentence[ctxIdx];
                    if (useSkipGram) {
                        skipGram(curWord, lastWord, neu1e);
                    } else {
                        int l1 = lastWord * vecSize;
                        for (int i = 0; i < vecSize; i++) {
                            neu1[i] += _syn0[l1 + i];
                        }
                        cw++;
                    }
                }
                if (useCbow && cw > 0) {
                    CBOW(curWord, sentence, sentIdx, sentLen, winSizeMod, cw, neu1, neu1e);
                }
                wordCount++;
                if (wordCount % PROGRESS_UPDATE_INTERVAL == 0) {
                    _nodeProcessedWords._val += PROGRESS_UPDATE_INTERVAL;
                    _curLearningRate = calcLearningRate(_initLearningRate, _epochs,
                            _prevTotalProcessedWords + _nodeProcessedWords._val, _vocabWordCount);
                }
            }
        }
        _processedWords = wordCount;
        _nodeProcessedWords._val += wordCount % PROGRESS_UPDATE_INTERVAL;
        _job.update(1L);
    }

    /**
     * Merges another task's result into this one.
     *
     * <p>If the two tasks hold different {@code _syn0} arrays, their vectors are combined as a
     * weighted average, each side weighted by the number of words it processed. If they share the
     * same arrays, there is nothing to merge.
     */
    @Override
    public void reduce(WordVectorTrainer other) {
        _processedWords += other._processedWords;
        if (_syn0 != other._syn0) {
            // With zero words on both sides the weight would be 0/0 = NaN and would poison the
            // model. In that case neither side trained anything, so keep this side's vectors.
            if (_processedWords > 0) {
                float otherWeight = ((float) other._processedWords) / ((float) _processedWords);
                ArrayUtils.add(1.0f - otherWeight, _syn0, otherWeight, other._syn0);
                ArrayUtils.add(1.0f - otherWeight, _syn1, otherWeight, other._syn1);
            }
            _nodeProcessedWords._val += other._nodeProcessedWords._val;
        }
    }

    /** Skip-Gram step: trains the pair (curWord, lastWord) and applies the gradient to lastWord's vector. */
    private void skipGram(int curWord, int lastWord, float[] neu1e) {
        final int vecSize = _wordVecSize;
        final int l1 = lastWord * vecSize;
        Arrays.fill(neu1e, 0, vecSize, 0.0f);
        hierarchicalSoftmaxSG(curWord, l1, neu1e);
        for (int i = 0; i < vecSize; i++) {
            _syn0[l1 + i] += neu1e[i];
        }
    }

    /**
     * Hierarchical-softmax update along {@code targetWord}'s tree path.
     *
     * <p>The input vector starts at {@code _syn0[l1]}. The input gradient is accumulated into
     * {@code neu1e}, and the tree-node vectors in {@code _syn1} are updated in place.
     */
    private void hierarchicalSoftmaxSG(final int targetWord, final int l1, float[] neu1e) {
        final int vecSize = _wordVecSize;
        final int[] code = _HBWTCode[targetWord];
        final int[] point = _HBWTPoint[targetWord];
        final float alpha = _curLearningRate;
        for (int d = 0; d < code.length; d++) {
            final int l2 = point[d] * vecSize;
            float f = 0.0f;
            for (int i = 0; i < vecSize; i++) {
                f += _syn0[l1 + i] * _syn1[l2 + i];
            }
            // Positive form on purpose: a NaN dot product is skipped rather than looked up.
            if (f > -MAX_EXP && f < MAX_EXP) {
                final float g = (1 - code[d] - sigmoid(f)) * alpha;
                for (int i = 0; i < vecSize; i++) {
                    neu1e[i] += g * _syn1[l2 + i];
                }
                for (int i = 0; i < vecSize; i++) {
                    _syn1[l2 + i] += g * _syn0[l1 + i];
                }
            }
        }
    }

    /**
     * CBOW step: averages the summed context vector {@code neu1} over {@code cw} words and runs
     * hierarchical softmax for {@code curWord}.
     *
     * <p>The resulting gradient is then added to every context word. The context window is
     * exactly the one {@link #map} summed into {@code neu1}:
     * {@code [winSizeMod, 2 * _windowSize + 1 - winSizeMod)}, excluding the centre word.
     */
    private void CBOW(int curWord, int[] sentence, int sentIdx, int sentLen, int winSizeMod, int cw,
                      float[] neu1, float[] neu1e) {
        final int vecSize = _wordVecSize;
        final int winSize = _windowSize;
        // Must mirror the forward loop in map(). The previous bound (2*w+1-w) only reached the
        // left context, so right-context words never received their gradient.
        final int windowEnd = winSize * 2 + 1 - winSizeMod;
        for (int i = 0; i < vecSize; i++) {
            neu1[i] /= cw;
        }
        hierarchicalSoftmaxCBOW(curWord, neu1, neu1e);
        for (int winIdx = winSizeMod; winIdx < windowEnd; winIdx++) {
            if (winIdx == winSize) {
                continue;
            }
            int ctxIdx = sentIdx - winSize + winIdx;
            if (ctxIdx < 0 || ctxIdx >= sentLen) {
                continue;
            }
            int l1 = sentence[ctxIdx] * vecSize;
            for (int i = 0; i < vecSize; i++) {
                _syn0[l1 + i] += neu1e[i];
            }
        }
    }

    /**
     * Hierarchical-softmax update along {@code targetWord}'s tree path, using {@code neu1} as the
     * input vector.
     *
     * <p>The input gradient is accumulated into {@code neu1e}, and {@code _syn1} is updated in place.
     */
    private void hierarchicalSoftmaxCBOW(final int targetWord, float[] neu1, float[] neu1e) {
        final int vecSize = _wordVecSize;
        final int[] code = _HBWTCode[targetWord];
        final int[] point = _HBWTPoint[targetWord];
        final float alpha = _curLearningRate;
        for (int d = 0; d < code.length; d++) {
            final int l2 = point[d] * vecSize;
            float f = 0.0f;
            for (int i = 0; i < vecSize; i++) {
                f += neu1[i] * _syn1[l2 + i];
            }
            if (f > -MAX_EXP && f < MAX_EXP) {
                final float g = (1 - code[d] - sigmoid(f)) * alpha;
                for (int i = 0; i < vecSize; i++) {
                    neu1e[i] += g * _syn1[l2 + i];
                }
                for (int i = 0; i < vecSize; i++) {
                    _syn1[l2 + i] += g * neu1[i];
                }
            }
        }
    }

    /**
     * Computes the current learning rate.
     *
     * <p>The rate decays linearly from {@code initLearningRate} over {@code epochs * vocabWordCount}
     * processed words. It is floored at {@code initLearningRate * LEARNING_RATE_MIN_FACTOR}.
     */
    private static float calcLearningRate(float initLearningRate, int epochs, long totalProcessed, long vocabWordCount) {
        float rate = initLearningRate * (1.0f - (((float) totalProcessed) / ((float) ((epochs * vocabWordCount) + 1))));
        float minRate = initLearningRate * LEARNING_RATE_MIN_FACTOR;
        if (rate < minRate) {
            rate = minRate;
        }
        return rate;
    }

    /** Copies the trained vectors into {@code word2VecModelInfo} and adds this task's word count. */
    public void updateModelInfo(Word2VecModel.Word2VecModelInfo word2VecModelInfo) {
        word2VecModelInfo._syn0 = _syn0;
        word2VecModelInfo._syn1 = _syn1;
        word2VecModelInfo._totalProcessedWords += _processedWords;
    }

    /**
     * Cheap xorshift pseudo-random number drawn from {@code _seed}.
     *
     * @param bound exclusive upper bound; a zero bound throws ArithmeticException
     * @return a value in {@code [0, |bound|)}
     */
    public int cheapRandInt(int bound) {
        _seed ^= _seed << 21;
        _seed ^= _seed >>> 35;
        _seed ^= _seed << 4;
        int r = ((int) _seed) % bound;
        return r > 0 ? r : -r;
    }
}
