package hex.word2vec;

import hex.word2vec.Word2Vec;
import hex.word2vec.Word2VecModel;
import java.util.Random;
import water.H2O;
import water.MRTask;
import water.fvec.CStrChunk;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.nbhm.NonBlockingHashMap;
import water.parser.BufferedString;
import water.util.Log;

/* loaded from: input_file:hex/word2vec/WordVectorTrainer.class */
/**
 * Trains word2vec embeddings as an H2O {@link MRTask}.
 *
 * <p>Lifecycle as implemented here:
 * <ul>
 *   <li>{@link #setupLocal()} takes the model passed to the constructor and builds this node's lookup
 *       tables.</li>
 *   <li>{@link #map(Chunk[])} walks every string ({@link CStrChunk}) chunk, cuts it into sentences and
 *       runs SkipGram or CBOW updates with either negative sampling or hierarchical softmax.</li>
 *   <li>{@link #reduce(WordVectorTrainer)} sums model copies coming from different nodes.</li>
 *   <li>{@link #postGlobal()} divides that sum by the number of contributing nodes.</li>
 * </ul>
 *
 * <p>Weights live in flat arrays indexed as {@code word * vecSize + component}.
 * {@code _syn0} holds the rows for context words; they are summed for CBOW and receive the accumulated
 * gradient. {@code _syn1} holds the output-layer rows for target words, negative samples or tree
 * nodes. Both arrays, together with the vocabulary map, the sigmoid table and the random seed, are
 * {@code static}. They are updated in place without any locking, so every task running in this JVM
 * shares them.
 */
public class WordVectorTrainer extends MRTask<WordVectorTrainer> {
    /** Upper bound on the number of vocabulary words collected into one sentence. */
    static final int MAX_SENTENCE_LEN = 1000;
    /** getSentence returns 0 when fewer than this many tokens (as it measures them) remain. */
    static final int MIN_SENTENCE_LEN = 10;
    /** Number of entries in the precomputed sigmoid table. */
    static final int EXP_TABLE_SIZE = 1000;
    /** Dot products outside [-MAX_EXP, MAX_EXP] are treated as saturated and bypass the table. */
    static final int MAX_EXP = 6;
    /** Multiplier turning {@code dot + MAX_EXP} into a table index (integer division gives 83). */
    private static final int EXP_TABLE_SCALE = EXP_TABLE_SIZE / MAX_EXP / 2;
    /** The learning rate is recomputed every this many words in map. */
    private static final int ALPHA_UPDATE_INTERVAL = 10000;
    /** Lower bound for the decayed learning rate, as a fraction of the initial rate. */
    private static final float MIN_LEARNING_RATE_FRACTION = 1.0E-4f;

    /** Model handed to the constructor; moved to {@link #_output} in setupLocal. */
    private Word2VecModel.Word2VecModelInfo _input;
    /** Model trained on this node and, after reduce, the combined model. */
    Word2VecModel.Word2VecModelInfo _output;
    /**
     * Vocabulary frame. Column 0 holds the words. Column 1 is read with at8 in the sub-sampling
     * formula (presumably each word's count).
     */
    Frame _vocab;
    /** Word to row index in {@link #_vocab}. */
    static NonBlockingHashMap<BufferedString, Integer> _vocabHM;
    final Word2Vec.WordModel _wordModel;
    final Word2Vec.NormModel _normModel;
    final int _vocabSize;
    final int _wordVecSize;
    final int _windowSize;
    final int _epochs;
    /** Negative samples drawn per target word (0 when hierarchical softmax is used). */
    final int _negExCnt;
    final float _initLearningRate;
    /** Sub-sampling rate applied while building sentences; values {@code <= 0} disable it. */
    final float _sentSampleRate;
    static float[] _syn0;
    static float[] _syn1;
    /** Precomputed logistic sigmoid e^x / (e^x + 1) sampled over [-MAX_EXP, MAX_EXP). */
    static float[] _expTable;
    /** Table sampled uniformly to draw negative examples (null for hierarchical softmax). */
    final int[] _unigramTable;
    /** Per-word code bits used as labels in hierarchical softmax (null for negative sampling). */
    final int[][] _HBWTCode;
    /** Per-word {@code _syn1} row indices visited in hierarchical softmax (null for negative sampling). */
    final int[][] _HBWTPoint;
    /** Number of nodes whose updates have been summed into {@link #_output}. */
    int _chunkNodeCount = 1;
    transient float _curLearningRate;
    /** Read cursor inside the chunk currently being processed by map. */
    transient int _chkIdx = 0;
    transient Random _rand;
    static transient long _seed;
    static long _lastWarn;
    static long _warnCount;

    /**
     * Creates a trainer for the given model. Training parameters are copied into final fields, and
     * the model's weight arrays are published to the static {@code _syn0}/{@code _syn1}.
     *
     * @param word2VecModelInfo model state and parameters to train; its vocabulary frame must be
     *     non-empty (checked by an assertion)
     */
    public WordVectorTrainer(Word2VecModel.Word2VecModelInfo word2VecModelInfo) {
        super((H2O.H2OCountedCompleter) null);
        this._input = word2VecModelInfo;
        this._wordModel = word2VecModelInfo.getParams()._wordModel;
        this._normModel = word2VecModelInfo.getParams()._normModel;
        this._vocab = word2VecModelInfo.getParams()._vocabKey.get();
        this._vocabSize = (int) this._vocab.numRows();
        this._wordVecSize = word2VecModelInfo.getParams()._vecSize;
        this._windowSize = word2VecModelInfo.getParams()._windowSize;
        _syn0 = word2VecModelInfo._syn0;
        _syn1 = word2VecModelInfo._syn1;
        this._initLearningRate = word2VecModelInfo.getParams()._initLearningRate;
        this._sentSampleRate = word2VecModelInfo.getParams()._sentSampleRate;
        this._epochs = word2VecModelInfo.getParams()._epochs;
        _seed = System.nanoTime();
        assert this._output == null;
        assert this._vocab.numRows() > 0;
        if (this._normModel == Word2Vec.NormModel.NegSampling) {
            this._negExCnt = word2VecModelInfo.getParams()._negSampleCnt;
            this._unigramTable = word2VecModelInfo._uniTable;
            this._HBWTCode = null;
            this._HBWTPoint = null;
        } else {
            this._negExCnt = 0;
            this._unigramTable = null;
            this._HBWTCode = word2VecModelInfo._HBWTCode;
            this._HBWTPoint = word2VecModelInfo._HBWTPoint;
        }
    }

    /** Returns the trained model (null until setupLocal has run). */
    public final Word2VecModel.Word2VecModelInfo getModelInfo() {
        return this._output;
    }

    /** Moves the model from {@link #_input} to {@link #_output} and builds this node's lookup tables. */
    protected void setupLocal() {
        _syn0 = this._input._syn0;
        _syn1 = this._input._syn1;
        this._output = this._input;
        this._input = null;
        this._rand = new Random();
        initExpTable();
        buildVocabHashMap();
        this._curLearningRate = this._output._curLearningRate;
        this._output.setLocallyProcessed(0);
    }

    /** Rebuilds {@link #_vocabHM} from column 0 of the vocabulary frame (word to row index). */
    private void buildVocabHashMap() {
        Vec words = this._vocab.vec(0);
        long rows = this._vocab.numRows();
        _vocabHM = new NonBlockingHashMap<>((int) rows);
        for (int row = 0; row < rows; row++) {
            _vocabHM.put(words.atStr(new BufferedString(), row), row);
        }
    }

    /**
     * Linearly decays the learning rate with overall training progress, clamped to
     * {@code MIN_LEARNING_RATE_FRACTION * _initLearningRate}.
     *
     * @param i words processed so far by the current map call
     */
    private void updateAlpha(int i) {
        float progress = ((float) (this._output.getGloballyProcessed() + i))
                / ((float) ((this._epochs * this._output._trainFrameSize) + 1));
        this._curLearningRate = this._initLearningRate * (1.0f - progress);
        float floor = this._initLearningRate * MIN_LEARNING_RATE_FRACTION;
        if (this._curLearningRate < floor) {
            this._curLearningRate = floor;
        }
    }

    /**
     * Reads the next sentence from {@code cStrChunk}, starting at {@link #_chkIdx}, and advances the
     * cursor past every token it consumes.
     *
     * <p>Tokens that are not in the vocabulary are skipped. When {@code _sentSampleRate > 0}, tokens
     * are also randomly dropped by {@link #keepWord}. At most
     * {@code min(MAX_SENTENCE_LEN, _len - 1 - _chkIdx)} indices are collected.
     *
     * @param iArr destination for vocabulary indices; must hold at least MAX_SENTENCE_LEN entries
     * @param cStrChunk string chunk being read
     * @return number of indices actually written to {@code iArr}. Returns 0 when
     *     {@code _len - 1 - _chkIdx} is below MIN_SENTENCE_LEN, or when no word could be collected
     *     before the chunk ran out.
     */
    private int getSentence(int[] iArr, CStrChunk cStrChunk) {
        Vec counts = this._vocab.vec(1);
        BufferedString token = new BufferedString();
        int maxLen = (cStrChunk._len - 1) - this._chkIdx;
        if (maxLen >= MAX_SENTENCE_LEN) {
            maxLen = MAX_SENTENCE_LEN;
        } else if (maxLen < MIN_SENTENCE_LEN) {
            return 0;
        }
        int sentLen = 0;
        while (this._chkIdx < cStrChunk._len) {
            cStrChunk.atStr(token, this._chkIdx);
            // Advance before a possible break so the word that closes this sentence is not
            // re-read as the first word of the next one.
            this._chkIdx++;
            Integer wordIdx = _vocabHM.get(token);
            if (wordIdx == null || !keepWord(counts.at8(wordIdx))) {
                continue;
            }
            iArr[sentLen++] = wordIdx;
            if (sentLen >= maxLen) {
                break;
            }
        }
        // Report what was actually written; entries past sentLen are stale from earlier sentences.
        return sentLen;
    }

    /**
     * Sub-sampling test. It always keeps the word when {@code _sentSampleRate <= 0}. Otherwise it
     * keeps the word when {@code (sqrt(c / t) + 1) * t / c} is at least a uniform random float, where
     * {@code t = _sentSampleRate * trainFrameSize} and {@code c} is the word's column-1 value.
     */
    private boolean keepWord(long count) {
        if (this._sentSampleRate <= 0.0f) {
            return true;
        }
        float threshold = this._sentSampleRate * ((float) this._output._trainFrameSize);
        float keep = (((float) Math.sqrt(((float) count) / threshold)) + 1.0f) * threshold / ((float) count);
        return keep >= this._rand.nextFloat();
    }

    /**
     * Builds the sigmoid table and publishes it to {@link #_expTable}. Entry {@code i} holds
     * {@code sigmoid((2 * i / EXP_TABLE_SIZE - 1) * MAX_EXP)}.
     */
    private void initExpTable() {
        float[] table = new float[EXP_TABLE_SIZE];
        for (int i = 0; i < EXP_TABLE_SIZE; i++) {
            float e = (float) Math.exp((((i / (float) EXP_TABLE_SIZE) * 2.0f) - 1.0f) * MAX_EXP);
            table[i] = e / (e + 1.0f);
        }
        _expTable = table;
    }

    /** Maps a dot product in [-MAX_EXP, MAX_EXP] to its {@link #_expTable} index. */
    private static int expIndex(float dot) {
        return (int) ((dot + MAX_EXP) * EXP_TABLE_SCALE);
    }

    /**
     * Trains on every {@link CStrChunk} in {@code chunkArr}; chunks of any other type are skipped.
     *
     * <p>Each string chunk is cut into sentences by {@link #getSentence}. For every word a random
     * shrink {@code b} in {@code [0, windowSize)} is drawn. The context is the words at offsets
     * {@code -(windowSize - b) .. (windowSize - b)} around the word, excluding the word itself.
     * SkipGram updates each (word, context) pair immediately; CBOW averages the context vectors and
     * updates once per word. The number of words processed is added to the model's local counter.
     */
    public void map(Chunk[] chunkArr) {
        final int winSize = this._windowSize;
        final int vecSize = this._wordVecSize;
        final boolean cbow = this._wordModel == Word2Vec.WordModel.CBOW;
        float[] neu1 = new float[vecSize];
        float[] neu1e = new float[vecSize];
        int[] sentence = new int[MAX_SENTENCE_LEN];
        int wordCount = 0;

        for (Chunk chunk : chunkArr) {
            if (!(chunk instanceof CStrChunk)) {
                continue; // only string chunks carry words
            }
            CStrChunk strChunk = (CStrChunk) chunk;
            this._chkIdx = 0; // each chunk is read from its first row
            int sentLen;
            while ((sentLen = getSentence(sentence, strChunk)) > 0) {
                for (int sentIdx = 0; sentIdx < sentLen; sentIdx++) {
                    if (wordCount % ALPHA_UPDATE_INTERVAL == 0) {
                        updateAlpha(wordCount);
                    }
                    int curWord = sentence[sentIdx];
                    wordCount++;
                    int bagSize = 0;
                    if (cbow) {
                        for (int j = 0; j < vecSize; j++) {
                            neu1[j] = 0.0f;
                            neu1e[j] = 0.0f;
                        }
                    }

                    int winSizeMod = cheapRandInt(winSize);
                    int windowEnd = winSize * 2 + 1 - winSizeMod;
                    for (int winIdx = winSizeMod; winIdx < windowEnd; winIdx++) {
                        if (winIdx == winSize) {
                            continue; // the current word is not its own context
                        }
                        int ctxIdx = sentIdx - winSize + winIdx;
                        if (ctxIdx < 0 || ctxIdx >= sentLen) {
                            continue;
                        }
                        int ctxWord = sentence[ctxIdx];
                        if (this._wordModel == Word2Vec.WordModel.SkipGram) {
                            skipGram(curWord, ctxWord, neu1e);
                        } else {
                            int ctxOff = ctxWord * vecSize;
                            for (int j = 0; j < vecSize; j++) {
                                neu1[j] += _syn0[j + ctxOff];
                            }
                            bagSize++;
                        }
                    }
                    if (cbow && bagSize > 0) {
                        CBOW(curWord, sentence, sentIdx, sentLen, winSizeMod, bagSize, neu1, neu1e);
                    }
                }
            }
        }
        this._output.addLocallyProcessed(wordCount);
    }

    /**
     * Merges another node's model into this one. Models with no local work, or the very same model
     * object, are ignored. Otherwise the models are summed and the node counts are added.
     */
    public void reduce(WordVectorTrainer wordVectorTrainer) {
        if (wordVectorTrainer._output.getLocallyProcessed() <= 0 || wordVectorTrainer._output == this._output) {
            return;
        }
        if (this._output.getLocallyProcessed() == 0) {
            this._output = wordVectorTrainer._output;
            this._chunkNodeCount = wordVectorTrainer._chunkNodeCount;
        } else {
            this._output.add(wordVectorTrainer._output);
            this._chunkNodeCount += wordVectorTrainer._chunkNodeCount;
        }
    }

    /** Drops the reference to the vocabulary frame once this node is done. */
    protected void closeLocal() {
        this._vocab = null;
    }

    /**
     * Divides the summed model by the number of contributing nodes and moves the local word count into
     * the global one. On multi-node clouds where some nodes contributed nothing, it logs a
     * rate-limited warning: at most 3 in total, at least 5 s apart.
     */
    protected void postGlobal() {
        if (H2O.CLOUD.size() > 1) {
            long now = System.currentTimeMillis();
            if (this._chunkNodeCount < H2O.CLOUD.size() && now - _lastWarn > 5000 && _warnCount < 3) {
                Log.warn(new Object[]{(H2O.CLOUD.size() - this._chunkNodeCount) + " node(s) (out of " + H2O.CLOUD.size() + ") are not contributing to model updates. Consider setting replicate_training_data to true or using a larger training dataset (or fewer H2O nodes)."});
                _lastWarn = now;
                _warnCount++;
            }
        }
        this._output.div(this._chunkNodeCount);
        this._output.addGloballyProcessed(this._output.getLocallyProcessed());
        this._output.setLocallyProcessed(0);
        assert this._input == null;
    }

    /**
     * SkipGram step for one (target, context) pair. The context row of {@code _syn0} is the hidden
     * vector. The gradient is accumulated into {@code fArr} (cleared first) and then added to that row.
     */
    private void skipGram(int i, int i2, float[] fArr) {
        final int vecSize = this._wordVecSize;
        final int ctxOff = i2 * vecSize;
        for (int j = 0; j < vecSize; j++) {
            fArr[j] = 0.0f;
        }
        if (this._normModel == Word2Vec.NormModel.NegSampling) {
            negSampling(i, _syn0, ctxOff, fArr);
        } else {
            hierarchicalSoftmax(i, _syn0, ctxOff, fArr);
        }
        for (int j = 0; j < vecSize; j++) {
            _syn0[j + ctxOff] += fArr[j];
        }
    }

    /**
     * CBOW step for one target word. It averages the summed context vector {@code neu1}, accumulates
     * the output-layer gradient into {@code neu1e}, and adds that gradient to every context word's
     * {@code _syn0} row.
     *
     * @param winSizeMod the window shrink drawn in map; the same window must be used here
     * @param bagSize number of context vectors summed into {@code neu1} (> 0)
     */
    private void CBOW(int curWord, int[] sentence, int sentIdx, int sentLen, int winSizeMod, int bagSize,
                      float[] neu1, float[] neu1e) {
        final int vecSize = this._wordVecSize;
        final int winSize = this._windowSize;
        for (int j = 0; j < vecSize; j++) {
            neu1[j] /= bagSize;
        }
        if (this._normModel == Word2Vec.NormModel.NegSampling) {
            negSampling(curWord, neu1, 0, neu1e);
        } else {
            hierarchicalSoftmax(curWord, neu1, 0, neu1e);
        }
        // Propagate to exactly the window map summed into neu1. The old bound, 2*winSize+1-winSize,
        // stopped at the current word and never updated the right-hand context.
        final int windowEnd = winSize * 2 + 1 - winSizeMod;
        for (int winIdx = winSizeMod; winIdx < windowEnd; winIdx++) {
            if (winIdx == winSize) {
                continue;
            }
            int ctxIdx = sentIdx - winSize + winIdx;
            if (ctxIdx < 0 || ctxIdx >= sentLen) {
                continue;
            }
            int ctxOff = sentence[ctxIdx] * vecSize;
            for (int j = 0; j < vecSize; j++) {
                _syn0[j + ctxOff] += neu1e[j];
            }
        }
    }

    /**
     * Negative-sampling update. The target {@code word} is treated as label 1, and {@code _negExCnt}
     * words drawn from the unigram table are treated as label 0; draws equal to {@code word} are skipped.
     *
     * @param hidden array holding the hidden vector starting at {@code hiddenOff} (read only)
     * @param grad receives the gradient with respect to the hidden vector (accumulated, not cleared)
     */
    private void negSampling(int word, float[] hidden, int hiddenOff, float[] grad) {
        final int vecSize = this._wordVecSize;
        final float alpha = this._curLearningRate;

        int outOff = word * vecSize;
        float dot = dotWithSyn1(hidden, hiddenOff, outOff);
        float g = dot > MAX_EXP ? 0.0f : dot < -MAX_EXP ? alpha : (1.0f - _expTable[expIndex(dot)]) * alpha;
        applyGradient(g, hidden, hiddenOff, outOff, grad);

        final int tableLen = this._unigramTable.length;
        for (int n = 0; n < this._negExCnt; n++) {
            int target = this._unigramTable[cheapRandInt(tableLen)];
            if (target == word) {
                continue;
            }
            int negOff = target * vecSize;
            float negDot = dotWithSyn1(hidden, hiddenOff, negOff);
            float negG = negDot > MAX_EXP ? -alpha : negDot < -MAX_EXP ? 0.0f : (-_expTable[expIndex(negDot)]) * alpha;
            applyGradient(negG, hidden, hiddenOff, negOff, grad);
        }
    }

    /**
     * Hierarchical-softmax update along {@code word}'s path. For each step d, the label is
     * {@code 1 - _HBWTCode[word][d]} and the output row is {@code _HBWTPoint[word][d]}. Steps whose
     * dot product is not strictly inside (-MAX_EXP, MAX_EXP) are skipped.
     *
     * @param hidden array holding the hidden vector starting at {@code hiddenOff} (read only)
     * @param grad receives the gradient with respect to the hidden vector (accumulated, not cleared)
     */
    private void hierarchicalSoftmax(int word, float[] hidden, int hiddenOff, float[] grad) {
        final int[] code = this._HBWTCode[word];
        final int[] point = this._HBWTPoint[word];
        final float alpha = this._curLearningRate;
        for (int d = 0; d < code.length; d++) {
            int outOff = point[d] * this._wordVecSize;
            float dot = dotWithSyn1(hidden, hiddenOff, outOff);
            // Written as a positive range test so NaN dot products are skipped.
            if (dot > -MAX_EXP && dot < MAX_EXP) {
                float g = ((1 - code[d]) - _expTable[expIndex(dot)]) * alpha;
                applyGradient(g, hidden, hiddenOff, outOff, grad);
            }
        }
    }

    /** Dot product of the hidden vector with the {@code _syn1} row starting at {@code outOff}. */
    private float dotWithSyn1(float[] hidden, int hiddenOff, int outOff) {
        float sum = 0.0f;
        for (int j = 0; j < this._wordVecSize; j++) {
            sum += hidden[j + hiddenOff] * _syn1[j + outOff];
        }
        return sum;
    }

    /**
     * Adds {@code g * syn1Row} to {@code grad}, then {@code g * hidden} to the {@code _syn1} row. The
     * gradient therefore uses the row's value from before this update.
     */
    private void applyGradient(float g, float[] hidden, int hiddenOff, int outOff, float[] grad) {
        final int vecSize = this._wordVecSize;
        for (int j = 0; j < vecSize; j++) {
            grad[j] += g * _syn1[j + outOff];
        }
        for (int j = 0; j < vecSize; j++) {
            _syn1[j + outOff] += g * hidden[j + hiddenOff];
        }
    }

    /**
     * Advances the shared xorshift state {@link #_seed} (shifts 21, 35, 4) and returns a value in
     * {@code [0, i)}. Requires {@code i > 0}.
     */
    private int cheapRandInt(int i) {
        _seed ^= _seed << 21;
        _seed ^= _seed >>> 35;
        _seed ^= _seed << 4;
        return Math.abs(((int) _seed) % i);
    }
}
