package org.apache.mahout.clustering.dirichlet;

import java.util.ArrayList;
import java.util.List;
import org.apache.mahout.clustering.dirichlet.models.Model;
import org.apache.mahout.clustering.dirichlet.models.ModelDistribution;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.TimesFunction;

/* loaded from: input_file:WEB-INF/lib/mahout-core-0.3.jar:org/apache/mahout/clustering/dirichlet/DirichletClusterer.class */
/**
 * Clusters a list of observations by repeatedly sampling mixture models through a
 * {@link DirichletState}.
 *
 * <p>Each iteration does the following:
 * <ol>
 *   <li>Draws a fresh array of models from the {@link ModelDistribution}'s posterior given the
 *       state's current models.</li>
 *   <li>Assigns every observation to one of those models. The model index is drawn from a
 *       multinomial over the state's adjusted probabilities.</li>
 *   <li>Updates the state with the resulting models.</li>
 * </ol>
 * An iteration with index {@code i} has its model array recorded when {@code i >= burnin} and
 * {@code i % thin == 0}.
 *
 * <p>Instances are stateful. Repeated calls to {@link #cluster(int)} continue from the current
 * state and append to the same list of recorded samples. Iteration indices (and therefore the
 * burn-in/thinning test) restart at 0 on each call.
 *
 * @param <O> the type of observation being clustered
 */
public class DirichletClusterer<O> {
    /** The observations to cluster. The caller's list is retained, not copied. */
    private final List<O> sampleData;
    /** Supplies posterior model samples each iteration. */
    private final ModelDistribution<O> modelFactory;
    /** Mutable sampler state; updated at the end of every iteration. */
    private final DirichletState<O> state;
    /** Record only iterations whose index is a multiple of this value; must be non-zero. */
    private final int thin;
    /** Iterations with an index below this value are never recorded. */
    private final int burnin;
    /** Number of mixture components (models) sampled per iteration. */
    private final int numClusters;
    /** Model arrays recorded so far, in the order they were sampled. */
    private final List<Model<O>[]> clusterSamples = new ArrayList<>();

    /**
     * Creates a clusterer over the given observations.
     *
     * @param sampleData   the observations to cluster; the list is retained, not copied
     * @param modelFactory the distribution used both to build the initial state and to draw
     *                     posterior models each iteration
     * @param alpha0       forwarded unchanged to the {@link DirichletState} constructor
     *                     (presumably the Dirichlet concentration parameter; confirm against
     *                     {@code DirichletState})
     * @param numClusters  number of models sampled per iteration
     * @param thin         record only iterations whose index is a multiple of this value; a value
     *                     of 0 makes {@link #cluster(int)} fail with {@link ArithmeticException}
     *                     once an iteration reaches {@code burnin}
     * @param burnin       iterations with an index below this value are not recorded
     */
    public DirichletClusterer(List<O> sampleData, ModelDistribution<O> modelFactory, double alpha0,
                              int numClusters, int thin, int burnin) {
        this.sampleData = sampleData;
        this.modelFactory = modelFactory;
        this.thin = thin;
        this.burnin = burnin;
        this.numClusters = numClusters;
        this.state = new DirichletState<>(modelFactory, numClusters, alpha0);
    }

    /**
     * Runs the sampler for the given number of iterations.
     *
     * @param numIterations how many iterations to run; indices run from 0 to
     *                      {@code numIterations - 1} on every call
     * @return the recorded model samples. This is the clusterer's internal list, so later calls
     *         keep appending to the same instance.
     */
    public List<Model<O>[]> cluster(int numIterations) {
        for (int iteration = 0; iteration < numIterations; iteration++) {
            iterate(iteration, state);
        }
        return clusterSamples;
    }

    /**
     * Performs one sampling iteration. It draws posterior models, assigns each observation to one
     * of them, optionally records the models, and finally updates {@code state}.
     */
    private void iterate(int iteration, DirichletState<O> state) {
        Model<O>[] newModels = modelFactory.sampleFromPosterior(state.getModels());
        for (O observation : sampleData) {
            // Assignment probabilities come from the state *before* this iteration's update.
            int k = UncommonDistributions.rMultinom(normalizedProbabilities(state, observation));
            newModels[k].observe(observation);
        }
        if (iteration >= burnin && iteration % thin == 0) {
            clusterSamples.add(newModels);
        }
        state.update(newModels);
    }

    /**
     * Computes the state's adjusted probability of {@code observation} for every cluster and
     * scales the resulting vector so that its largest entry is 1.
     *
     * <p>The result is not normalized to sum to 1; it is passed as-is to
     * {@code UncommonDistributions.rMultinom}.
     */
    private Vector normalizedProbabilities(DirichletState<O> state, O observation) {
        DenseVector pi = new DenseVector(numClusters);
        double max = 0.0;
        for (int k = 0; k < numClusters; k++) {
            double p = state.adjustedProbability(observation, k);
            pi.set(k, p);
            if (max < p) {
                max = p;
            }
        }
        // When every probability is 0 (e.g. underflow), scaling by 1/max would compute
        // 0 * Infinity and fill the vector with NaN. In that case, leave the zeros unscaled.
        if (max > 0.0) {
            pi.assign(new TimesFunction(), 1.0 / max);
        }
        return pi;
    }

    /**
     * Convenience entry point: builds a {@code DirichletClusterer<Vector>} and runs it.
     *
     * @param points        the vectors to cluster
     * @param modelFactory  the model distribution to sample from
     * @param alpha0        forwarded to {@link DirichletState}
     * @param numClusters   number of models sampled per iteration
     * @param thin          thinning interval (see the constructor)
     * @param burnin        burn-in iteration count (see the constructor)
     * @param numIterations number of iterations to run
     * @return the recorded model samples
     */
    public static List<Model<Vector>[]> clusterPoints(List<Vector> points,
                                                      ModelDistribution<Vector> modelFactory,
                                                      double alpha0, int numClusters, int thin,
                                                      int burnin, int numIterations) {
        DirichletClusterer<Vector> clusterer =
            new DirichletClusterer<>(points, modelFactory, alpha0, numClusters, thin, burnin);
        return clusterer.cluster(numIterations);
    }
}
