package com.linkedin.dagli.fasttext.anonymized;

import com.linkedin.dagli.fasttext.anonymized.Args;
import com.linkedin.dagli.math.distribution.AliasSampler;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Random;

/* loaded from: input_file:com/linkedin/dagli/fasttext/anonymized/Model.class */
/**
 * Shallow linear model at the core of this fastText port.
 *
 * <p>The hidden vector is the average of the selected rows of the input matrix {@code wi_}; scores are dot products
 * of that hidden vector with rows of the output matrix {@code wo_}. The training loss is selected by
 * {@code args.loss}: negative sampling ({@code ns}), hierarchical softmax ({@code hs}), or full softmax (any other
 * value).
 *
 * <p>Training and prediction calls write to per-instance scratch vectors ({@code hidden_}, {@code output_},
 * {@code grad_}) and update the matrices in place without any synchronization, so one instance must not be used by
 * several threads at the same time.
 */
public class Model {
    /** Number of intervals in the sigmoid lookup table; the table itself holds {@code SIGMOID_TABLE_SIZE + 1} values. */
    static final int SIGMOID_TABLE_SIZE = 512;
    /** {@link #sigmoid(float)} saturates to 0 / 1 outside {@code [-MAX_SIGMOID, MAX_SIGMOID]}. */
    static final int MAX_SIGMOID = 8;
    /** Number of intervals of the log lookup table over {@code [0, 1]}; the table holds {@code LOG_TABLE_SIZE + 1} values. */
    static final int LOG_TABLE_SIZE = 512;
    /** Scale factor applied to the negative-sampling weights in {@link #initTableNegatives(long[])}. */
    static final int NEGATIVE_TABLE_SIZE = 10000000;

    private Matrix wi_;   // input matrix; row i is the embedding of input id i
    private Matrix wo_;   // output matrix; rows are labels, or internal tree nodes under hierarchical softmax
    private Args args_;
    private Vector hidden_; // scratch: averaged input embedding
    private Vector output_; // scratch: softmax output, one entry per output row
    private Vector grad_;   // scratch: gradient w.r.t. the hidden vector, applied to the input rows in update()
    private int hsz_;     // hidden size (args.dim)
    private int isz_;     // number of input rows
    private int osz_;     // number of output rows
    private float[] t_sigmoid;
    private float[] t_log;
    private AliasSampler<?> _randomNegativeSampler;
    private List<List<Integer>> paths; // per label: wo_ rows of its ancestors in the HS tree (leaf to root)
    private List<List<Boolean>> codes; // per label: the "binary" bit of each node along the same path
    private List<Node> tree;
    public transient Random rng;
    // Orders pairs by descending key. Not referenced elsewhere in this class; retained for compatibility.
    private Comparator<Pair<Float, Integer>> comparePairs = new Comparator<Pair<Float, Integer>>() {
        @Override
        public int compare(Pair<Float, Integer> pair, Pair<Float, Integer> pair2) {
            return pair2.getKey().compareTo(pair.getKey());
        }
    };
    // Not read anywhere in this class; retained so the field layout is unchanged.
    private int negpos = 0;
    private float loss_ = 0.0f;
    // Starts at 1 (not 0) so getLoss() never divides by zero.
    private long nexamples_ = 1;

    /** A node of the hierarchical-softmax tree built by {@link #buildTree(long[])}; leaves occupy the first indices. */
    public class Node {
        int parent;     // index of the parent node, or -1 for the root
        int left;       // index of the left child, or -1 for a leaf
        int right;      // index of the right child, or -1 for a leaf
        long count;     // leaf: its target count; internal node: sum of its children's counts
        boolean binary; // true when this node is its parent's right child

        public Node() {
        }
    }

    /**
     * Creates a model over the given input and output matrices (which are used directly, not copied).
     *
     * @param matrix input matrix {@code wi_}; its row count becomes the input size
     * @param matrix2 output matrix {@code wo_}; its row count becomes the output size
     * @param args hyper-parameters; {@code args.dim} sizes the hidden and gradient vectors
     * @param i seed for {@link #rng}
     */
    public Model(Matrix matrix, Matrix matrix2, Args args, int i) {
        this.hidden_ = new Vector(args.dim);
        this.output_ = new Vector(matrix2.m_);
        this.grad_ = new Vector(args.dim);
        this.rng = new Random(i);
        this.wi_ = matrix;
        this.wo_ = matrix2;
        this.args_ = args;
        this.isz_ = matrix.m_;
        this.osz_ = matrix2.m_;
        this.hsz_ = args.dim;
        initSigmoid();
        initLog();
    }

    /** Returns the output matrix's backing row arrays (not a copy). */
    public float[][] getLabelEmbeddings() {
        return this.wo_.data_;
    }

    /** Returns row {@code i} of the output matrix (the backing array, not a copy). */
    public float[] getLabelEmbedding(int i) {
        return this.wo_.data_[i];
    }

    /** Returns row {@code i} of the input matrix (the backing array, not a copy). */
    public float[] getInputEmbedding(int i) {
        return this.wi_.data_[i];
    }

    /**
     * Returns the input-matrix rows from {@code i} to the end. The outer array is new, but the row arrays are shared
     * with the model.
     */
    public float[][] getInputEmbeddingsStartingAtRow(int i) {
        return Arrays.copyOfRange(this.wi_.data_, i, this.wi_.data_.length);
    }

    /**
     * Runs one logistic-regression step against output row {@code i} and returns its loss.
     *
     * <p>Adds the hidden-side gradient to {@code grad_} and updates row {@code i} of {@code wo_} in place.
     *
     * @param i output row to score
     * @param z {@code true} for a positive example, {@code false} for a negative one
     * @param f step size scaling the update
     * @return {@code -log(p)} for a positive or {@code -log(1 - p)} for a negative, where {@code p} is the
     *     (table-approximated) sigmoid of the score
     */
    public float binaryLogistic(int i, boolean z, float f) {
        float score = sigmoid(this.wo_.dotRow(this.hidden_, i));
        float alpha = f * ((z ? 1.0f : 0.0f) - score);
        // grad_ must read row i of wo_ before that row is updated on the next line.
        this.grad_.addRow(this.wo_, i, alpha);
        this.wo_.addRow(this.hidden_, i, alpha);
        return z ? -log(score) : -log(1.0f - score);
    }

    /**
     * Negative-sampling loss: one positive step on {@code i}, then {@code args.neg} negative steps on sampled rows.
     *
     * @param i target output row
     * @param f step size
     * @return summed loss of all steps
     */
    public float negativeSampling(int i, float f) {
        this.grad_.zero();
        float loss = 0.0f;
        loss += binaryLogistic(i, true, f);
        for (int n = 0; n < this.args_.neg; n++) {
            loss += binaryLogistic(getNegative(i), false, f);
        }
        return loss;
    }

    /**
     * Hierarchical-softmax loss: one logistic step per node on label {@code i}'s path to the root, using that path's
     * code bits as targets. Requires {@link #buildTree(long[])} to have run.
     *
     * @param i target label
     * @param f step size
     * @return summed loss along the path
     */
    public float hierarchicalSoftmax(int i, float f) {
        this.grad_.zero();
        float loss = 0.0f;
        List<Boolean> code = this.codes.get(i);
        List<Integer> path = this.paths.get(i);
        for (int k = 0; k < path.size(); k++) {
            loss += binaryLogistic(path.get(k), code.get(k), f);
        }
        return loss;
    }

    /**
     * Computes {@code vector2 = softmax(wo_ * vector)} in place.
     *
     * @param vector hidden vector
     * @param vector2 receives the probabilities; its first {@code osz_} entries are normalized, so it is expected to
     *     have at least that many entries
     * @param fArr if non-null, receives the sigmoid of the raw (pre-softmax) scores for its first {@code fArr.length}
     *     entries
     */
    public void computeOutputSoftmax(Vector vector, Vector vector2, float[] fArr) {
        vector2.mul(this.wo_, vector);
        if (fArr != null) {
            for (int k = 0; k < fArr.length; k++) {
                fArr[k] = calculateSigmoid(vector2.get(k));
            }
        }
        // Subtract the maximum score before exponentiating to avoid overflow.
        float max = vector2.get(0);
        for (int k = 1; k < this.osz_; k++) {
            max = Math.max(vector2.get(k), max);
        }
        float sum = 0.0f;
        for (int k = 0; k < this.osz_; k++) {
            vector2.set(k, (float) Math.exp(vector2.get(k) - max));
            sum += vector2.get(k);
        }
        for (int k = 0; k < this.osz_; k++) {
            vector2.set(k, vector2.get(k) / sum);
        }
    }

    /** Computes the softmax of the current hidden vector into {@code output_}. */
    public void computeOutputSoftmax() {
        computeOutputSoftmax(this.hidden_, this.output_, null);
    }

    /**
     * Full-softmax loss and update over every output row.
     *
     * @param i target output row
     * @param f step size
     * @return {@code -log(p_i)} using the log lookup table
     */
    public float softmax(int i, float f) {
        this.grad_.zero();
        computeOutputSoftmax();
        for (int row = 0; row < this.osz_; row++) {
            float label = (row == i) ? 1.0f : 0.0f;
            float alpha = f * (label - this.output_.get(row));
            // As in binaryLogistic: read the row into grad_ before updating it.
            this.grad_.addRow(this.wo_, row, alpha);
            this.wo_.addRow(this.hidden_, row, alpha);
        }
        return -log(this.output_.get(i));
    }

    /**
     * Sets {@code vector} to the average of the input rows listed in both id lists.
     *
     * <p>If both lists are empty, {@code vector} is left as all zeros (previously this divided by zero and produced
     * NaNs).
     *
     * @param intArrayList first list of input row ids
     * @param intArrayList2 second list of input row ids
     * @param vector destination; must have size {@code hsz_}
     */
    public void computeHidden(IntArrayList intArrayList, IntArrayList intArrayList2, Vector vector) {
        Utils.checkArgument(vector.size() == this.hsz_);
        vector.zero();
        intArrayList.forEach(id -> {
            vector.addRow(this.wi_, id);
        });
        intArrayList2.forEach(id -> {
            vector.addRow(this.wi_, id);
        });
        int count = intArrayList.size() + intArrayList2.size();
        if (count > 0) {
            vector.mul(1.0f / count);
        }
    }

    /**
     * Runs one training step for target {@code i} and updates the loss statistics.
     *
     * <p>Does nothing when {@code intArrayList} is empty. Otherwise, computes the hidden vector and applies the
     * configured loss (which updates {@code wo_}). It then adds the resulting gradient to every listed input row;
     * for supervised models the gradient is averaged over the inputs first.
     *
     * @param intArrayList first list of input row ids
     * @param intArrayList2 second list of input row ids
     * @param i target output row, in {@code [0, osz_)}
     * @param f step size
     */
    public void update(IntArrayList intArrayList, IntArrayList intArrayList2, int i, float f) {
        Utils.checkArgument(i >= 0);
        Utils.checkArgument(i < this.osz_);
        if (intArrayList.isEmpty()) {
            return;
        }
        computeHidden(intArrayList, intArrayList2, this.hidden_);
        if (this.args_.loss == Args.loss_name.ns) {
            this.loss_ += negativeSampling(i, f);
        } else if (this.args_.loss == Args.loss_name.hs) {
            this.loss_ += hierarchicalSoftmax(i, f);
        } else {
            this.loss_ += softmax(i, f);
        }
        this.nexamples_++;
        if (this.args_.model == Args.model_name.sup) {
            this.grad_.mul(1.0f / (intArrayList.size() + intArrayList2.size()));
        }
        intArrayList.forEach(id -> {
            this.wi_.addRow(this.grad_, id, 1.0f);
        });
        intArrayList2.forEach(id -> {
            this.wi_.addRow(this.grad_, id, 1.0f);
        });
    }

    /**
     * Prepares loss-specific structures from per-target counts: the negative sampler for {@code ns}, the tree for
     * {@code hs}; other losses need nothing.
     *
     * @param jArr one count per output row; its length must equal {@code osz_}
     */
    public void setTargetCounts(long[] jArr) {
        Utils.checkArgument(jArr.length == this.osz_);
        if (this.args_.loss == Args.loss_name.ns) {
            initTableNegatives(jArr);
        }
        if (this.args_.loss == Args.loss_name.hs) {
            buildTree(jArr);
        }
    }

    /**
     * Builds the negative sampler with weights proportional to {@code sqrt(count)}.
     *
     * <p>Weights are normalized by the sum of square roots and scaled by {@link #NEGATIVE_TABLE_SIZE}. The scale
     * presumably does not affect sampling if {@code AliasSampler} normalizes its weights (TODO confirm).
     *
     * @param jArr one count per output row
     */
    public void initTableNegatives(long[] jArr) {
        double[] weights = new double[jArr.length];
        float total = 0.0f;
        for (long count : jArr) {
            total += (float) Math.pow(count, 0.5d);
        }
        for (int k = 0; k < jArr.length; k++) {
            weights[k] = (((float) Math.pow(jArr[k], 0.5d)) * NEGATIVE_TABLE_SIZE) / total;
        }
        // Only indices are sampled, so the item array is left full of nulls.
        this._randomNegativeSampler = new AliasSampler<>(new Object[weights.length], weights);
    }

    /**
     * Draws a negative output row different from {@code i}, by rejection sampling.
     *
     * <p>Requires {@link #initTableNegatives(long[])} to have run.
     *
     * @param i row to exclude (the positive target)
     * @return a sampled row index {@code != i}
     * @throws IllegalStateException if there are fewer than two output rows, since no other row could be drawn
     */
    public int getNegative(int i) {
        if (this.osz_ < 2) {
            throw new IllegalStateException("Negative sampling requires at least two output rows, but there are "
                + this.osz_);
        }
        int sampleIndex;
        do {
            sampleIndex = this._randomNegativeSampler.sampleIndex(this.rng.nextDouble());
        } while (sampleIndex == i);
        return sampleIndex;
    }

    /**
     * Builds the hierarchical-softmax tree and each label's root path.
     *
     * <p>The tree is built by repeatedly merging the two lowest-count nodes, using two queues: leaves scanned from
     * the last index backwards, and internal nodes in creation order. This yields a Huffman tree only if
     * {@code jArr} is sorted in non-increasing order (presumably guaranteed by callers — confirm). Leaves occupy
     * indices {@code [0, osz_)} and internal nodes {@code [osz_, 2 * osz_ - 1)}. Each internal node {@code n} uses
     * output row {@code n - osz_}.
     *
     * @param jArr one count per label; at least {@code osz_} entries are read
     */
    public void buildTree(long[] jArr) {
        final int leafCount = this.osz_;
        // Clamp so an empty output layer does not request a negative ArrayList capacity.
        final int nodeCount = Math.max(0, 2 * leafCount - 1);
        this.paths = new ArrayList<>(leafCount);
        this.codes = new ArrayList<>(leafCount);
        this.tree = new ArrayList<>(nodeCount);
        for (int k = 0; k < nodeCount; k++) {
            Node fresh = new Node();
            fresh.parent = -1;
            fresh.left = -1;
            fresh.right = -1;
            // Large sentinel count; overwritten for leaves below and for internal nodes as they are merged.
            fresh.count = 1000000000000000L;
            fresh.binary = false;
            this.tree.add(fresh);
        }
        for (int k = 0; k < leafCount; k++) {
            this.tree.get(k).count = jArr[k];
        }

        int nextLeaf = leafCount - 1;  // next unmerged leaf, scanning towards index 0
        int nextInternal = leafCount;  // next unmerged internal node, in creation order
        for (int current = leafCount; current < nodeCount; current++) {
            int[] smallest = new int[2];
            for (int pick = 0; pick < 2; pick++) {
                if (nextLeaf >= 0 && this.tree.get(nextLeaf).count < this.tree.get(nextInternal).count) {
                    smallest[pick] = nextLeaf--;
                } else {
                    smallest[pick] = nextInternal++;
                }
            }
            Node parent = this.tree.get(current);
            parent.left = smallest[0];
            parent.right = smallest[1];
            parent.count = this.tree.get(smallest[0]).count + this.tree.get(smallest[1]).count;
            this.tree.get(smallest[0]).parent = current;
            this.tree.get(smallest[1]).parent = current;
            this.tree.get(smallest[1]).binary = true;
        }

        // Walk each leaf up to the root (parent == -1). The decompiled version had no exit here and looped forever.
        for (int leaf = 0; leaf < leafCount; leaf++) {
            List<Integer> path = new ArrayList<>();
            List<Boolean> code = new ArrayList<>();
            for (int node = leaf; this.tree.get(node).parent != -1; node = this.tree.get(node).parent) {
                Node step = this.tree.get(node);
                path.add(step.parent - leafCount);
                code.add(step.binary);
            }
            this.paths.add(path);
            this.codes.add(code);
        }
    }

    /** Returns the accumulated loss divided by the example counter (which starts at 1). */
    public float getLoss() {
        return this.loss_ / ((float) this.nexamples_);
    }

    /** Fills the sigmoid table with values at evenly spaced points over {@code [-MAX_SIGMOID, MAX_SIGMOID]}. */
    private void initSigmoid() {
        this.t_sigmoid = new float[SIGMOID_TABLE_SIZE + 1];
        for (int k = 0; k <= SIGMOID_TABLE_SIZE; k++) {
            float x = ((k * 2 * MAX_SIGMOID) / (float) SIGMOID_TABLE_SIZE) - MAX_SIGMOID;
            this.t_sigmoid[k] = calculateSigmoid(x);
        }
    }

    /** Returns the exact logistic sigmoid {@code 1 / (1 + e^-f)}, computed in double precision. */
    public static float calculateSigmoid(float f) {
        return (float) (1.0d / (1.0d + Math.exp(-f)));
    }

    /** Fills the log table with {@code ln((k + 1e-5) / LOG_TABLE_SIZE)} for {@code k = 0..LOG_TABLE_SIZE}. */
    private void initLog() {
        this.t_log = new float[LOG_TABLE_SIZE + 1];
        for (int k = 0; k <= LOG_TABLE_SIZE; k++) {
            this.t_log[k] = (float) Math.log((k + 1.0E-5f) / LOG_TABLE_SIZE);
        }
    }

    /**
     * Returns a table-based approximation of {@code ln(f)}.
     *
     * <p>Intended for {@code f} in {@code [0, 1]}; values above 1 return 0. Values below about
     * {@code -1/LOG_TABLE_SIZE} index outside the table and throw {@link ArrayIndexOutOfBoundsException}.
     */
    public float log(float f) {
        if (f > 1.0f) {
            return 0.0f;
        }
        return this.t_log[(int) (f * LOG_TABLE_SIZE)];
    }

    /**
     * Returns a table-based approximation of the logistic sigmoid.
     *
     * <p>Returns exactly 0 below {@code -MAX_SIGMOID} and exactly 1 above {@code MAX_SIGMOID}.
     */
    public float sigmoid(float f) {
        if (f < -MAX_SIGMOID) {
            return 0.0f;
        }
        if (f > MAX_SIGMOID) {
            return 1.0f;
        }
        return this.t_sigmoid[(int) ((f + MAX_SIGMOID) * SIGMOID_TABLE_SIZE / MAX_SIGMOID / 2)];
    }
}
