package org.apache.mahout.clustering.dirichlet;

import java.util.ArrayList;
import java.util.List;
import org.apache.mahout.clustering.dirichlet.models.Model;
import org.apache.mahout.clustering.dirichlet.models.ModelDistribution;
import org.apache.mahout.matrix.DenseVector;
import org.apache.mahout.matrix.Vector;

/* loaded from: input_file:WEB-INF/lib/mahout-core-0.2.jar:org/apache/mahout/clustering/dirichlet/DirichletState.class */
/**
 * Mutable state of a Dirichlet clustering run.
 *
 * <p>Holds a number of {@link DirichletCluster}s, the {@link ModelDistribution} their models were
 * sampled from, a per-cluster offset, and a mixture vector. The mixture is recomputed with
 * {@code UncommonDistributions.rDirichlet(totalCounts())} on construction and after each
 * {@link #update(Model[])}.
 *
 * <p>This class performs no synchronization.
 *
 * @param <O> type of the observations the cluster models evaluate
 */
public class DirichletState<O> {
    /** Number of clusters; also the length of {@link #totalCounts()} and {@link #getModels()}. */
    private int numClusters;
    /** Distribution the initial cluster models are sampled from. */
    private ModelDistribution<O> modelFactory;
    /** The clusters, indexed 0..numClusters-1. */
    private List<DirichletCluster<O>> clusters;
    /** Mixture weights, one entry per cluster (see {@link #adjustedProbability}). */
    private Vector mixture;
    /** Value each cluster is constructed with: alpha0 / numClusters in the full constructor. */
    private double offset;

    /**
     * Creates a state with {@code numClusters} clusters whose models are sampled from the prior of
     * {@code modelDistribution}.
     *
     * <p>Every cluster is created with the same offset, {@code alpha0 / numClusters}. The mixture
     * is then initialised from the clusters' total counts.
     *
     * <p>NOTE(review): {@code numClusters == 0} is not guarded and yields a non-finite offset.
     *
     * @param modelDistribution distribution used to sample the initial cluster models
     * @param numClusters number of clusters, and of models requested from the prior
     * @param alpha0 value divided evenly across the clusters to form the offset
     * @param thin not read by this constructor. NOTE(review): the name presumes a thinning setting
     *     kept for signature compatibility; confirm against callers.
     * @param burnin not read by this constructor. NOTE(review): the name presumes a burn-in
     *     setting kept for signature compatibility; confirm against callers.
     */
    public DirichletState(ModelDistribution<O> modelDistribution, int numClusters, double alpha0,
            int thin, int burnin) {
        this.numClusters = numClusters;
        this.modelFactory = modelDistribution;
        this.offset = alpha0 / numClusters;
        // Diamond instead of the raw ArrayList: same runtime behavior, no unchecked conversion.
        this.clusters = new ArrayList<>();
        for (Model<O> model : modelDistribution.sampleFromPrior(numClusters)) {
            this.clusters.add(new DirichletCluster<>(model, this.offset));
        }
        this.mixture = UncommonDistributions.rDirichlet(totalCounts());
    }

    /**
     * Creates an empty state. Every field keeps its default value (null or 0) until it is
     * populated through the setters.
     * NOTE(review): presumably used by serialization or bean-style frameworks; confirm.
     */
    public DirichletState() {
    }

    /** @return the number of clusters */
    public int getNumClusters() {
        return this.numClusters;
    }

    /** @param numClusters the new number of clusters; the cluster list is not resized */
    public void setNumClusters(int numClusters) {
        this.numClusters = numClusters;
    }

    /** @return the distribution the cluster models were sampled from */
    public ModelDistribution<O> getModelFactory() {
        return this.modelFactory;
    }

    /** @param modelDistribution the distribution to record as the model factory */
    public void setModelFactory(ModelDistribution<O> modelDistribution) {
        this.modelFactory = modelDistribution;
    }

    /** @return the live cluster list (not a copy); mutations affect this state */
    public List<DirichletCluster<O>> getClusters() {
        return this.clusters;
    }

    /** @param clusters the list to use as this state's clusters (stored without copying) */
    public void setClusters(List<DirichletCluster<O>> clusters) {
        this.clusters = clusters;
    }

    /** @return the current mixture vector (not a copy) */
    public Vector getMixture() {
        return this.mixture;
    }

    /** @param mixture the vector to use as the mixture (stored without copying) */
    public void setMixture(Vector mixture) {
        this.mixture = mixture;
    }

    /** @return the offset each cluster was constructed with */
    public double getOffset() {
        return this.offset;
    }

    /** @param offset the new offset; existing clusters are not updated */
    public void setOffset(double offset) {
        this.offset = offset;
    }

    /**
     * Collects the total count of each cluster.
     *
     * @return a new vector of length {@code numClusters} whose entry {@code k} is
     *     {@code clusters.get(k).getTotalCount()}
     * @throws IndexOutOfBoundsException if the cluster list has fewer than {@code numClusters}
     *     entries
     */
    public Vector totalCounts() {
        DenseVector counts = new DenseVector(this.numClusters);
        for (int k = 0; k < this.numClusters; k++) {
            counts.set(k, this.clusters.get(k).getTotalCount());
        }
        return counts;
    }

    /**
     * Installs a new set of models and refreshes the mixture.
     *
     * <p>Each {@code models[k]} first has {@code computeParameters()} called on it. It then becomes
     * the model of cluster {@code k}. Finally the mixture is recomputed from the clusters' total
     * counts.
     *
     * @param models the new models, indexed like the clusters
     * @throws IndexOutOfBoundsException if {@code models} has more entries than there are clusters
     */
    public void update(Model<O>[] models) {
        for (int k = 0; k < models.length; k++) {
            Model<O> model = models[k];
            model.computeParameters();
            this.clusters.get(k).setModel(model);
        }
        this.mixture = UncommonDistributions.rDirichlet(totalCounts());
    }

    /**
     * Weights a cluster model's density by that cluster's mixture entry.
     *
     * @param observation the observation to evaluate
     * @param clusterIndex index of the cluster
     * @return {@code mixture.get(clusterIndex) * clusters.get(clusterIndex).getModel().pdf(observation)}
     */
    public double adjustedProbability(O observation, int clusterIndex) {
        return this.mixture.get(clusterIndex)
                * this.clusters.get(clusterIndex).getModel().pdf(observation);
    }

    /**
     * Returns the current models of the first {@code numClusters} clusters.
     *
     * @return a new array whose entry {@code k} is {@code clusters.get(k).getModel()}
     */
    public Model<O>[] getModels() {
        // Java forbids generic array creation. This is safe only because the array is filled
        // exclusively with this state's Model<O> instances.
        @SuppressWarnings("unchecked")
        Model<O>[] models = new Model[this.numClusters];
        for (int k = 0; k < this.numClusters; k++) {
            models[k] = this.clusters.get(k).getModel();
        }
        return models;
    }
}
