package org.apache.mahout.clustering.dirichlet;

import java.util.ArrayList;
import java.util.List;
import org.apache.mahout.clustering.dirichlet.models.Model;
import org.apache.mahout.clustering.dirichlet.models.ModelDistribution;
import org.apache.mahout.matrix.DenseVector;
import org.apache.mahout.matrix.Vector;

/* loaded from: input_file:WEB-INF/lib/mahout-core-0.1.jar:org/apache/mahout/clustering/dirichlet/DirichletState.class */
/**
 * Mutable state for Dirichlet clustering: the model distribution used to sample models, one
 * {@link DirichletCluster} per cluster, and a mixture vector drawn from a Dirichlet distribution
 * parameterized by the clusters' total counts.
 *
 * <p>All fields are public and are read and written directly. NOTE(review): presumably by the
 * surrounding clustering driver; verify against callers before narrowing visibility.
 *
 * @param <Observation> the type of data point the cluster models evaluate
 */
public class DirichletState<Observation> {
    /** Number of clusters; also the number of models sampled from the prior at construction. */
    public int numClusters;

    /** Distribution the initial cluster models were sampled from. */
    public ModelDistribution<Observation> modelFactory;

    /** One entry per cluster, in cluster-index order. */
    public List<DirichletCluster<Observation>> clusters;

    /** Per-cluster weights, sampled via {@code UncommonDistributions.rDirichlet(totalCounts())}. */
    public Vector mixture;

    /** Per-cluster offset handed to each {@link DirichletCluster}; equals {@code d / numClusters}. */
    public double offset;

    /**
     * Builds an initial state. It samples {@code i} models from the prior of
     * {@code modelDistribution}, wraps each in a {@link DirichletCluster} with offset
     * {@code d / i}, and then samples the initial mixture from the clusters' total counts.
     *
     * @param modelDistribution distribution used to sample the initial models
     * @param i number of clusters
     * @param d total offset, divided evenly among the {@code i} clusters
     * @param i2 not used by this constructor
     * @param i3 not used by this constructor
     */
    public DirichletState(ModelDistribution<Observation> modelDistribution, int i, double d, int i2, int i3) {
        this.numClusters = i;
        this.modelFactory = modelDistribution;
        // Floating-point division: d is spread evenly over the i clusters.
        this.offset = d / i;
        this.clusters = new ArrayList<>();
        for (Model<Observation> model : modelDistribution.sampleFromPrior(i)) {
            this.clusters.add(new DirichletCluster<>(model, this.offset));
        }
        this.mixture = UncommonDistributions.rDirichlet(totalCounts());
    }

    /**
     * Creates an empty state. All fields keep their default values ({@code null} / {@code 0}).
     * NOTE(review): presumably used when the fields are populated afterwards (e.g. on
     * deserialization); confirm with callers.
     */
    public DirichletState() {
    }

    /**
     * Collects each cluster's {@code totalCount} into a dense vector.
     *
     * @return a vector of length {@link #numClusters}; entry {@code k} is the total count of
     *     cluster {@code k}
     * @throws IndexOutOfBoundsException if {@link #clusters} holds fewer than
     *     {@link #numClusters} entries
     */
    public Vector totalCounts() {
        DenseVector counts = new DenseVector(this.numClusters);
        for (int cluster = 0; cluster < this.numClusters; cluster++) {
            counts.set(cluster, this.clusters.get(cluster).totalCount);
        }
        return counts;
    }

    /**
     * Installs new models and then resamples the mixture. For each index {@code k}, this
     * recomputes the parameters of {@code modelArr[k]} and sets it as the model of cluster
     * {@code k}. After all models are installed, {@link #mixture} is redrawn from
     * {@link #totalCounts()}.
     *
     * @param modelArr the new models, indexed by cluster
     * @throws IndexOutOfBoundsException if {@code modelArr} is longer than {@link #clusters}
     */
    public void update(Model<Observation>[] modelArr) {
        for (int cluster = 0; cluster < modelArr.length; cluster++) {
            Model<Observation> model = modelArr[cluster];
            model.computeParameters();
            this.clusters.get(cluster).setModel(model);
        }
        this.mixture = UncommonDistributions.rDirichlet(totalCounts());
    }

    /**
     * Returns the probability of {@code observation} under cluster {@code i}'s model, weighted
     * by that cluster's mixture weight.
     *
     * @param observation the data point to evaluate
     * @param i the cluster index
     * @return {@code mixture.get(i) * clusters.get(i).model.pdf(observation)}
     */
    public double adjustedProbability(Observation observation, int i) {
        return this.mixture.get(i) * this.clusters.get(i).model.pdf(observation);
    }

    /**
     * Returns a fresh array holding the current model of each cluster, in cluster-index order.
     *
     * @return an array of length {@link #numClusters}
     * @throws IndexOutOfBoundsException if {@link #clusters} holds fewer than
     *     {@link #numClusters} entries
     */
    public Model<Observation>[] getModels() {
        // Java forbids generic array creation, so a raw Model[] is allocated. Every element
        // stored is a Model<Observation> taken from clusters, so the unchecked conversion is safe.
        @SuppressWarnings("unchecked")
        Model<Observation>[] modelArr = new Model[this.numClusters];
        for (int cluster = 0; cluster < this.numClusters; cluster++) {
            modelArr[cluster] = this.clusters.get(cluster).model;
        }
        return modelArr;
    }
}
