package org.apache.mahout.clustering.iterator;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.mahout.clustering.classify.ClusterClassifier;
import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

/**
 * Clustering policy that assigns each observation to exactly one model by sampling. The
 * sampling is weighted by a mixture vector drawn with {@code UncommonDistributions.rDirichlet}.
 *
 * <p>The policy's state is {@code alpha0} plus the current {@code mixture}. Both are written by
 * {@link #write(DataOutput)} and read back, in the same order, by {@link #readFields(DataInput)}.
 */
public class DirichletClusteringPolicy extends AbstractClusteringPolicy {

    /** Current mixture weights; replaced by a fresh draw on every {@link #update(ClusterClassifier)}. */
    private Vector mixture;

    /** The {@code alpha0} argument handed to {@code rDirichlet} whenever the mixture is drawn. */
    private double alpha0;

    /**
     * Creates a policy with no mixture yet ({@code mixture} stays {@code null}).
     * NOTE(review): presumably used for deserialization, with state expected to be filled in by
     * {@link #readFields(DataInput)} before use.
     */
    public DirichletClusteringPolicy() {
        // intentionally empty: fields are populated by readFields()
    }

    /**
     * Creates a policy whose initial mixture is drawn from {@code rDirichlet}.
     *
     * @param k      size of the mixture vector (presumably the number of models)
     * @param alpha0 value forwarded to {@code rDirichlet} now and on every later update
     */
    public DirichletClusteringPolicy(int k, double alpha0) {
        this.alpha0 = alpha0;
        // A freshly allocated DenseVector holds zeros, i.e. no observation counts yet.
        Vector noCounts = new DenseVector(k);
        mixture = UncommonDistributions.rDirichlet(noCounts, alpha0);
    }

    /**
     * Samples a single model index for one observation.
     *
     * @param probabilities per-model weights for the observation (assumed to have the same
     *                      cardinality as the mixture — TODO confirm)
     * @return a sparse vector the size of {@code probabilities} containing {@code 1.0} at the
     *         sampled index and nothing elsewhere
     */
    @Override
    public Vector select(Vector probabilities) {
        // Scale each entry by the corresponding mixture weight, then draw one index from it.
        Vector weighted = probabilities.times(mixture);
        int chosen = UncommonDistributions.rMultinom(weighted);

        Vector oneHot = new SequentialAccessSparseVector(probabilities.size());
        oneHot.set(chosen, 1.0d);
        return oneHot;
    }

    /**
     * Redraws the mixture from the per-model total observation counts held by
     * {@code posterior}.
     *
     * @param posterior classifier whose models supply {@code getTotalObservations()}
     */
    @Override
    public void update(ClusterClassifier posterior) {
        int modelCount = posterior.getModels().size();
        DenseVector observationCounts = new DenseVector(modelCount);
        for (int model = 0; model < modelCount; model++) {
            observationCounts.set(model, posterior.getModels().get(model).getTotalObservations());
        }
        mixture = UncommonDistributions.rDirichlet(observationCounts, alpha0);
    }

    /**
     * Serializes {@code alpha0} followed by the mixture vector.
     *
     * @param out destination stream
     * @throws IOException if writing fails
     */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeDouble(alpha0);
        VectorWritable.writeVector(out, mixture);
    }

    /**
     * Restores {@code alpha0} and the mixture in the same order {@link #write(DataOutput)}
     * emits them.
     *
     * @param in source stream
     * @throws IOException if reading fails
     */
    @Override
    public void readFields(DataInput in) throws IOException {
        alpha0 = in.readDouble();
        mixture = VectorWritable.readVector(in);
    }
}
