package org.apache.mahout.math.stats;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Iterator;
import java.util.Objects;
import java.util.Random;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixWritable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

/**
 * Online estimator of the AUC (area under the ROC curve) for a two-category (0/1) classifier.
 *
 * <p>For each category the estimator retains at most {@link #HISTORY} scores. Every new score is
 * compared against the retained scores of the opposite category; the fraction it beats (ties
 * count as one half) is folded into a running per-category average, and {@link #auc()} combines
 * the two averages.
 *
 * <p>Once a category's history is full, {@link ReplacementPolicy} decides which retained score a
 * new one overwrites. {@link #setWindowSize(int)} caps the divisor of the running averages, which
 * turns them into exponential moving averages with weight {@code 1 / windowSize} once more than
 * {@code windowSize} samples have been seen.
 *
 * <p>This class performs no synchronization.
 */
public class GlobalOnlineAuc implements OnlineAuc {
    /** Maximum number of scores retained per category. */
    public static final int HISTORY = 10;

    // Upper bound on the running-average divisor; the default effectively means a plain mean.
    private int windowSize = Integer.MAX_VALUE;
    private ReplacementPolicy policy = ReplacementPolicy.FIFO;
    private final Random random = RandomUtils.getRandom();
    // Row c holds the retained scores of category c; NaN marks an unused slot.
    private Matrix scores = new DenseMatrix(2, HISTORY);
    // Per category: running fraction of opposite-category scores beaten by this category's scores.
    private Vector averages;
    // Per category: number of samples added so far (stored as doubles).
    private Vector samples;

    /** Strategy for choosing which retained score to overwrite once a category's history is full. */
    public enum ReplacementPolicy {
        /** Overwrite slots round-robin, i.e. always replace the oldest retained score. */
        FIFO,
        /** Reservoir sampling: every score seen so far has an equal chance of being retained. */
        FAIR,
        /** Overwrite a uniformly chosen slot with every new score, favoring recent scores. */
        RANDOM
    }

    /** Creates an estimator with empty histories and both running averages at 0.5 (AUC 0.5). */
    public GlobalOnlineAuc() {
        scores.assign(Double.NaN);
        averages = new DenseVector(2);
        averages.assign(0.5);
        samples = new DenseVector(2);
    }

    /**
     * Adds a sample; {@code groupKey} is ignored because this estimator pools all groups.
     *
     * @param category the true category of the sample, 0 or 1
     * @param groupKey ignored
     * @param score the classifier score for the sample
     * @return the updated AUC estimate
     */
    @Override
    public double addSample(int category, String groupKey, double score) {
        return addSample(category, score);
    }

    /**
     * Records {@code score} for {@code category} and updates that category's running average.
     *
     * <p>The running average only starts moving once both categories have at least one sample.
     * NOTE(review): a NaN score is stored like an empty slot and compares as beating nothing.
     *
     * @param category the true category of the sample, 0 or 1
     * @param score the classifier score for the sample
     * @return the updated AUC estimate
     * @throws IllegalStateException if the replacement policy is not recognized
     */
    @Override
    public double addSample(int category, double score) {
        int seen = (int) samples.get(category);
        if (seen < HISTORY) {
            scores.set(category, seen, score);
        } else {
            switch (policy) {
                case FIFO:
                    scores.set(category, seen % HISTORY, score);
                    break;
                case FAIR: {
                    // Algorithm R: the (seen+1)-th sample is kept with probability HISTORY/(seen+1).
                    int slot = random.nextInt(seen + 1);
                    if (slot < HISTORY) {
                        scores.set(category, slot, score);
                    }
                    break;
                }
                case RANDOM:
                    scores.set(category, random.nextInt(HISTORY), score);
                    break;
                default:
                    throw new IllegalStateException("Unknown policy: " + policy);
            }
        }
        samples.set(category, seen + 1);

        if (samples.minValue() >= 1.0) {
            double wins = 0.0;
            double comparisons = 0.0;
            for (Vector.Element element : scores.viewRow(1 - category).all()) {
                double other = element.get();
                if (Double.isNaN(other)) {
                    continue; // unused slot
                }
                comparisons += 1.0;
                if (score > other) {
                    wins += 1.0;
                } else if (score == other) {
                    wins += 0.5;
                }
            }
            // Without a comparable opposite score the ratio would be 0/0 = NaN, which would
            // permanently poison the running average; skip the update instead.
            if (comparisons > 0.0) {
                double average = averages.get(category);
                double divisor = Math.min(windowSize, samples.get(category));
                averages.set(category, average + (wins / comparisons - average) / divisor);
            }
        }
        return auc();
    }

    /**
     * Returns the current AUC estimate, treating category 1 as the positive class.
     *
     * <p>This is the mean of two estimates: {@code 1 - averages[0]} (the share of category-1
     * scores not beaten by category-0 scores) and {@code averages[1]} (the share of category-0
     * scores beaten by category-1 scores).
     *
     * @return the AUC estimate, 0.5 before any update has happened
     */
    @Override
    public double auc() {
        return ((1.0 - averages.get(0)) + averages.get(1)) / 2.0;
    }

    /**
     * Alias for {@link #auc()}.
     *
     * @return the current AUC estimate
     */
    public double value() {
        return auc();
    }

    /**
     * Sets the policy used to replace retained scores once a category's history is full.
     *
     * @param policy the replacement policy; must not be null
     * @throws NullPointerException if {@code policy} is null
     */
    @Override
    public void setPolicy(ReplacementPolicy policy) {
        this.policy = Objects.requireNonNull(policy, "policy");
    }

    /**
     * Caps the number of samples that the running averages divide by.
     *
     * @param windowSize the cap; must be at least 1
     * @throws IllegalArgumentException if {@code windowSize < 1}, which would divide by zero or
     *     by a negative number
     */
    @Override
    public void setWindowSize(int windowSize) {
        if (windowSize < 1) {
            throw new IllegalArgumentException("Window size must be positive: " + windowSize);
        }
        this.windowSize = windowSize;
    }

    /**
     * Serializes window size, policy ordinal, scores, averages and sample counts, in that order.
     *
     * @param dataOutput destination stream
     * @throws IOException if writing fails
     */
    @Override
    public void write(DataOutput dataOutput) throws IOException {
        dataOutput.writeInt(windowSize);
        dataOutput.writeInt(policy.ordinal());
        MatrixWritable.writeMatrix(dataOutput, scores);
        VectorWritable.writeVector(dataOutput, averages);
        VectorWritable.writeVector(dataOutput, samples);
    }

    /**
     * Restores state written by {@link #write(DataOutput)}.
     *
     * <p>Fields are assigned only after everything has been read and validated, so a failure
     * leaves this instance unchanged.
     *
     * @param dataInput source stream
     * @throws IOException if reading fails or the window size / policy ordinal is invalid
     */
    @Override
    public void readFields(DataInput dataInput) throws IOException {
        int newWindowSize = dataInput.readInt();
        if (newWindowSize < 1) {
            throw new IOException("Invalid window size: " + newWindowSize);
        }
        int ordinal = dataInput.readInt();
        ReplacementPolicy[] policies = ReplacementPolicy.values();
        if (ordinal < 0 || ordinal >= policies.length) {
            throw new IOException("Invalid replacement policy ordinal: " + ordinal);
        }
        Matrix newScores = MatrixWritable.readMatrix(dataInput);
        Vector newAverages = VectorWritable.readVector(dataInput);
        Vector newSamples = VectorWritable.readVector(dataInput);

        this.windowSize = newWindowSize;
        this.policy = policies[ordinal];
        this.scores = newScores;
        this.averages = newAverages;
        this.samples = newSamples;
    }
}
