package org.apache.mahout.classifier.sgd;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Collection;
import java.util.HashSet;
import java.util.Random;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.Functions;

/* loaded from: input_file:org/apache/mahout/classifier/sgd/GradientMachine.class */
/**
 * A two-layer neural network classifier trained online with a pairwise ranking update.
 *
 * <p>An input vector is mapped to {@code numHidden} hidden units through {@code hiddenWeights}
 * and {@code hiddenBias}, clamped to [-40, 40] and passed through a sigmoid; the hidden
 * activation is then mapped linearly to one score per output category through
 * {@code outputWeights} and {@code outputBias}. Each training example pushes the correct
 * category's score up relative to the best-scoring of a few randomly sampled incorrect
 * categories (see {@link #updateRanking}).
 *
 * <p>No synchronization is performed; training mutates the weight vectors in place.
 */
public class GradientMachine extends AbstractVectorClassifier implements OnlineLearner, Writable {

    /** Version tag written first by {@link #write} and checked by {@link #readFields}. */
    public static final int WRITABLE_VERSION = 1;

    /** Number of incorrect labels sampled per example by {@link #train(long, String, int, Vector)}. */
    private static final int NUM_RANKING_TRIALS = 2;

    private double learningRate = 0.1d;
    private double regularization = 0.1d;
    // sparsity and sparsityLearningRate are configured, copied and serialized, but not otherwise
    // read by this class.
    private double sparsity = 0.1d;
    private double sparsityLearningRate = 0.1d;
    private int numFeatures;
    private int numHidden;
    private int numOutput;
    /** One weight vector of length {@code numFeatures} per hidden unit. */
    private Vector[] hiddenWeights;
    /** One weight vector of length {@code numHidden} per output category. */
    private Vector[] outputWeights;
    private Vector hiddenBias;
    private Vector outputBias;
    /** Randomness used to sample incorrect labels during training. */
    private final Random rnd;

    /**
     * Creates a machine whose weights and biases are all zero.
     * {@link #initWeights(Random)} can be used afterwards to randomize the weights.
     *
     * @param numFeatures size of the input vectors
     * @param numHidden   number of hidden units
     * @param numOutput   number of output categories
     */
    public GradientMachine(int numFeatures, int numHidden, int numOutput) {
        this.numFeatures = numFeatures;
        this.numHidden = numHidden;
        this.numOutput = numOutput;
        this.hiddenWeights = zeroVectors(numHidden, numFeatures);
        this.hiddenBias = zeroVector(numHidden);
        this.outputWeights = zeroVectors(numOutput, numHidden);
        this.outputBias = zeroVector(numOutput);
        this.rnd = RandomUtils.getRandom();
    }

    /** Allocates a dense vector of the given size with every entry set to zero. */
    private static Vector zeroVector(int size) {
        Vector vector = new DenseVector(size);
        vector.assign(0.0d);
        return vector;
    }

    /** Allocates {@code count} zero vectors of the given size. */
    private static Vector[] zeroVectors(int count, int size) {
        // Typed as Vector[] (not DenseVector[]) so any Vector implementation can be stored later.
        Vector[] vectors = new Vector[count];
        for (int i = 0; i < count; i++) {
            vectors[i] = zeroVector(size);
        }
        return vectors;
    }

    /**
     * Fills every weight with a uniform value in [-1/sqrt(fanIn), 1/sqrt(fanIn)), where fanIn is
     * {@code numFeatures} for hidden weights and {@code numHidden} for output weights.
     * Biases are left untouched.
     *
     * @param random source of the uniform draws (hidden weights are drawn first, row by row)
     */
    public void initWeights(Random random) {
        double hiddenScale = 1.0d / Math.sqrt(numFeatures);
        for (int h = 0; h < numHidden; h++) {
            for (int f = 0; f < numFeatures; f++) {
                hiddenWeights[h].setQuick(f, (2.0d * random.nextDouble() - 1.0d) * hiddenScale);
            }
        }
        double outputScale = 1.0d / Math.sqrt(numHidden);
        for (int o = 0; o < numOutput; o++) {
            for (int h = 0; h < numHidden; h++) {
                outputWeights[o].setQuick(h, (2.0d * random.nextDouble() - 1.0d) * outputScale);
            }
        }
    }

    /** Sets the step size used by {@link #updateRanking}; returns {@code this} for chaining. */
    public GradientMachine learningRate(double learningRate) {
        this.learningRate = learningRate;
        return this;
    }

    /** Sets the weight-decay coefficient used by {@link #updateRanking}; returns {@code this}. */
    public GradientMachine regularization(double regularization) {
        this.regularization = regularization;
        return this;
    }

    /** Sets the sparsity setting (stored and serialized only); returns {@code this}. */
    public GradientMachine sparsity(double sparsity) {
        this.sparsity = sparsity;
        return this;
    }

    /** Sets the sparsity learning rate (stored and serialized only); returns {@code this}. */
    public GradientMachine sparsityLearningRate(double sparsityLearningRate) {
        this.sparsityLearningRate = sparsityLearningRate;
        return this;
    }

    /**
     * Replaces this machine's dimensions, hyper-parameters and parameters with deep copies of
     * {@code other}'s. The random source of this machine is kept.
     *
     * @param other machine to copy from
     */
    public void copyFrom(GradientMachine other) {
        this.numFeatures = other.numFeatures;
        this.numHidden = other.numHidden;
        this.numOutput = other.numOutput;
        this.learningRate = other.learningRate;
        this.regularization = other.regularization;
        this.sparsity = other.sparsity;
        this.sparsityLearningRate = other.sparsityLearningRate;
        // Vector[] rather than DenseVector[]: clone() is typed to return Vector, and storing a
        // non-dense vector in a DenseVector[] would throw ArrayStoreException.
        this.hiddenWeights = new Vector[this.numHidden];
        for (int h = 0; h < this.numHidden; h++) {
            this.hiddenWeights[h] = other.hiddenWeights[h].mo2352clone();
        }
        this.hiddenBias = other.hiddenBias.mo2352clone();
        this.outputWeights = new Vector[this.numOutput];
        for (int o = 0; o < this.numOutput; o++) {
            this.outputWeights[o] = other.outputWeights[o].mo2352clone();
        }
        this.outputBias = other.outputBias.mo2352clone();
    }

    /** Returns the number of output categories. */
    @Override
    public int numCategories() {
        return numOutput;
    }

    /** Returns the expected size of input vectors. */
    public int numFeatures() {
        return numFeatures;
    }

    /** Returns the number of hidden units. */
    public int numHidden() {
        return numHidden;
    }

    /**
     * Computes the hidden-layer activation for an input vector.
     *
     * @param input feature vector (dotted against each hidden weight vector)
     * @return a new vector of size {@code numHidden}: sigmoid(clamp(W_h . input + b_h, -40, 40))
     */
    public DenseVector inputToHidden(Vector input) {
        DenseVector hidden = new DenseVector(numHidden);
        for (int h = 0; h < numHidden; h++) {
            hidden.setQuick(h, hiddenWeights[h].dot(input));
        }
        hidden.assign(hiddenBias, Functions.PLUS);
        // Clamp pre-activations, presumably to keep the sigmoid away from overflow/saturation.
        hidden.assign(Functions.min(40.0d)).assign(Functions.max(-40.0d));
        hidden.assign(Functions.SIGMOID);
        return hidden;
    }

    /**
     * Computes raw (unnormalized) output scores from a hidden activation.
     *
     * @param hidden hidden-layer activation, normally from {@link #inputToHidden}
     * @return a new vector of size {@code numOutput}: W_o . hidden + b_o
     */
    public DenseVector hiddenToOutput(Vector hidden) {
        DenseVector output = new DenseVector(numOutput);
        for (int o = 0; o < numOutput; o++) {
            output.setQuick(o, outputWeights[o].dot(hidden));
        }
        output.assign(outputBias, Functions.PLUS);
        return output;
    }

    /**
     * Applies one pairwise ranking update for every label in {@code goodLabels}.
     *
     * <p>For each good label, {@code numTrials} labels not in {@code goodLabels} are drawn
     * uniformly with replacement and the highest-scoring one is kept as the "bad" label. When
     * {@code 1 - goodScore + badScore >= 0} the output weights, output biases and hidden weights
     * are adjusted; otherwise (including NaN) the pair is skipped.
     *
     * @param hiddenActivation hidden activation of the example, normally from {@link #inputToHidden}
     * @param goodLabels       labels considered correct; nothing happens if it has at least
     *                         {@code numOutput} elements
     * @param numTrials        number of incorrect labels sampled per good label
     * @param random           source of the label samples
     */
    public void updateRanking(Vector hiddenActivation, Collection<Integer> goodLabels, int numTrials, Random random) {
        // With at least numOutput good labels there may be no incorrect label left to sample.
        if (goodLabels.size() >= numOutput) {
            return;
        }
        for (Integer goodLabel : goodLabels) {
            int good = goodLabel.intValue();
            double goodScore = outputWeights[good].dot(hiddenActivation);

            int bad = -1;
            double badScore = Double.NEGATIVE_INFINITY;
            for (int trial = 0; trial < numTrials; trial++) {
                int candidate;
                do {
                    candidate = random.nextInt(numOutput);
                } while (goodLabels.contains(candidate));
                double candidateScore = outputWeights[candidate].dot(hiddenActivation);
                if (candidateScore > badScore) {
                    badScore = candidateScore;
                    bad = candidate;
                }
            }

            // Written as !(loss >= 0) so that NaN losses (and the -inf loss produced when no
            // candidate was kept) are skipped, exactly like a ">= 0" guard.
            double loss = 1.0d - goodScore + badScore;
            if (!(loss >= 0.0d)) {
                continue;
            }

            // NOTE(review): the output-layer step direction below is built from the output weight
            // vectors themselves rather than from hiddenActivation; kept unchanged, but confirm
            // against the intended gradient derivation.
            Vector gradGood = outputWeights[good].mo2352clone();
            gradGood.assign(Functions.NEGATE);
            // (bad - good) of the pre-update output weights, propagated to the hidden layer below.
            Vector propHidden = gradGood.mo2352clone();
            Vector gradBad = outputWeights[bad].mo2352clone();
            propHidden.assign(gradBad, Functions.PLUS);
            gradGood.assign(Functions.mult(-learningRate * (1.0d - regularization)));
            outputWeights[good].assign(gradGood, Functions.PLUS);
            gradBad.assign(Functions.mult(-learningRate * (1.0d + regularization)));
            outputWeights[bad].assign(gradBad, Functions.PLUS);
            outputBias.setQuick(good, outputBias.get(good) + learningRate);
            outputBias.setQuick(bad, outputBias.get(bad) - learningRate);

            Vector gradSig = hiddenActivation.mo2352clone();
            gradSig.assign(Functions.SIGMOIDGRADIENT);
            for (int h = 0; h < numHidden; h++) {
                gradSig.setQuick(h, gradSig.get(h) * propHidden.get(h));
            }
            // NOTE(review): the input features are not available here, so every weight of hidden
            // unit h receives the same term gradSig[h] (plus L2 decay) rather than
            // gradSig[h] * x[f]; hiddenBias is never updated. Confirm this is intended.
            for (int h = 0; h < numHidden; h++) {
                double delta = gradSig.get(h);
                for (int f = 0; f < numFeatures; f++) {
                    double w = hiddenWeights[h].get(f);
                    hiddenWeights[h].setQuick(f, w - learningRate * (delta + regularization * w));
                }
            }
        }
    }

    /**
     * Returns a one-hot encoding of the highest-scoring category with category 0 omitted, i.e. a
     * vector of size {@code numCategories() - 1} (all zeros when category 0 wins).
     */
    @Override
    public Vector classify(Vector instance) {
        Vector scores = classifyNoLink(instance);
        int best = scores.maxValueIndex();
        scores.assign(0.0d);
        scores.setQuick(best, 1.0d);
        return scores.viewPart(1, scores.size() - 1);
    }

    /** Returns the raw output scores for {@code instance}, one per category. */
    @Override
    public Vector classifyNoLink(Vector instance) {
        return hiddenToOutput(inputToHidden(instance));
    }

    /**
     * Returns 0.0 if category 0 scores strictly higher than category 1, otherwise 1.0.
     * Only categories 0 and 1 are compared, so this requires at least two outputs.
     */
    @Override
    public double classifyScalar(Vector instance) {
        Vector scores = classifyNoLink(instance);
        return scores.get(0) > scores.get(1) ? 0.0d : 1.0d;
    }

    /**
     * Returns an independent deep copy of this machine (with its own random source).
     * {@link #close()} is called on this instance first.
     */
    public GradientMachine copy() {
        close();
        GradientMachine copy = new GradientMachine(numFeatures(), numHidden(), numCategories());
        copy.copyFrom(this);
        return copy;
    }

    /**
     * Serializes the version tag, hyper-parameters, dimensions, then hidden bias and weights,
     * then output bias and weights.
     */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeInt(WRITABLE_VERSION);
        out.writeDouble(learningRate);
        out.writeDouble(regularization);
        out.writeDouble(sparsity);
        out.writeDouble(sparsityLearningRate);
        out.writeInt(numFeatures);
        out.writeInt(numHidden);
        out.writeInt(numOutput);
        VectorWritable.writeVector(out, hiddenBias);
        for (int h = 0; h < numHidden; h++) {
            VectorWritable.writeVector(out, hiddenWeights[h]);
        }
        VectorWritable.writeVector(out, outputBias);
        for (int o = 0; o < numOutput; o++) {
            VectorWritable.writeVector(out, outputWeights[o]);
        }
    }

    /**
     * Reads a record produced by {@link #write}. Fields are replaced only after the whole record
     * has been read successfully, so a failed read leaves this object unchanged.
     *
     * @throws IOException if the version tag is wrong, a dimension is negative, or the
     *                     underlying stream fails
     */
    @Override
    public void readFields(DataInput in) throws IOException {
        int version = in.readInt();
        if (version != WRITABLE_VERSION) {
            throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version);
        }
        double newLearningRate = in.readDouble();
        double newRegularization = in.readDouble();
        double newSparsity = in.readDouble();
        double newSparsityLearningRate = in.readDouble();
        int newNumFeatures = in.readInt();
        int newNumHidden = in.readInt();
        int newNumOutput = in.readInt();
        if (newNumFeatures < 0 || newNumHidden < 0 || newNumOutput < 0) {
            throw new IOException("Corrupt GradientMachine record: numFeatures=" + newNumFeatures
                    + ", numHidden=" + newNumHidden + ", numOutput=" + newNumOutput);
        }
        Vector newHiddenBias = VectorWritable.readVector(in);
        Vector[] newHiddenWeights = readVectors(in, newNumHidden);
        Vector newOutputBias = VectorWritable.readVector(in);
        Vector[] newOutputWeights = readVectors(in, newNumOutput);

        this.learningRate = newLearningRate;
        this.regularization = newRegularization;
        this.sparsity = newSparsity;
        this.sparsityLearningRate = newSparsityLearningRate;
        this.numFeatures = newNumFeatures;
        this.numHidden = newNumHidden;
        this.numOutput = newNumOutput;
        this.hiddenBias = newHiddenBias;
        this.hiddenWeights = newHiddenWeights;
        this.outputBias = newOutputBias;
        this.outputWeights = newOutputWeights;
    }

    /** Reads {@code count} consecutive vectors written by {@code VectorWritable.writeVector}. */
    private static Vector[] readVectors(DataInput in, int count) throws IOException {
        // Vector[] rather than DenseVector[]: readVector is typed to return Vector, and a
        // non-dense result stored in a DenseVector[] would throw ArrayStoreException.
        Vector[] vectors = new Vector[count];
        for (int i = 0; i < count; i++) {
            vectors[i] = VectorWritable.readVector(in);
        }
        return vectors;
    }

    /** Does nothing; this learner holds no resources that need releasing. */
    @Override
    public void close() {
    }

    /**
     * Trains on one example: computes its hidden activation and applies
     * {@link #updateRanking} with {@code actual} as the only good label.
     * {@code trackingKey} and {@code groupKey} are ignored.
     */
    @Override
    public void train(long trackingKey, String groupKey, int actual, Vector instance) {
        Vector hiddenActivation = inputToHidden(instance);
        Collection<Integer> goodLabels = new HashSet<Integer>();
        goodLabels.add(actual);
        updateRanking(hiddenActivation, goodLabels, NUM_RANKING_TRIALS, rnd);
    }

    /** Same as {@code train(trackingKey, null, actual, instance)}. */
    @Override
    public void train(long trackingKey, int actual, Vector instance) {
        train(trackingKey, null, actual, instance);
    }

    /** Same as {@code train(0L, null, actual, instance)}. */
    @Override
    public void train(int actual, Vector instance) {
        train(0L, null, actual, instance);
    }
}
