package org.apache.mahout.clustering.dirichlet.models;

import com.google.gson.GsonBuilder;
import com.google.gson.reflect.TypeToken;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.lang.reflect.Type;
import org.apache.mahout.clustering.ClusterBase;
import org.apache.mahout.clustering.dirichlet.JsonModelAdapter;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.SquareRootFunction;

/**
 * A Dirichlet-clustering {@link Model} describing a normal distribution with a mean vector and a
 * single standard deviation shared by every component.
 *
 * <p>Observations are folded into running statistics: {@code s0} is the observation count,
 * {@code s1} the sum of observed vectors and {@code s2} the sum of their element-wise squares.
 * {@link #computeParameters()} derives {@code mean} and {@code stdDev} from those statistics.
 */
public class NormalModel implements Model<VectorWritable> {

    /** Normalising constant {@code sqrt(2 * pi)} used by {@link #pdf(VectorWritable)}. */
    private static final double SQRT_2PI = Math.sqrt(2.0 * Math.PI);

    /** Gson type used when serialising this model in {@link #asJsonString()}. */
    private static final Type MODEL_TYPE = new TypeToken<Model<Vector>>() { }.getType();

    /** Current mean of the model; may be null for a default-constructed, unread model. */
    private Vector mean;
    /** Current standard deviation, shared by all components. */
    private double stdDev;
    /** Number of vectors observed so far. */
    private int s0;
    /** Running sum of observed vectors (null until first observation after default construction). */
    private Vector s1;
    /** Running sum of element-wise squared observed vectors (null like {@code s1}). */
    private Vector s2;

    /** Creates an empty model with no mean and no accumulated statistics. */
    public NormalModel() {
        this.s0 = 0;
    }

    /**
     * Creates a model with the given parameters and zeroed statistics.
     *
     * @param vector the mean; also used as the prototype ({@code like()}) for the statistic vectors
     * @param d      the standard deviation
     */
    public NormalModel(Vector vector, double d) {
        this.mean = vector;
        this.stdDev = d;
        this.s0 = 0;
        this.s1 = vector.like();
        this.s2 = vector.like();
    }

    /**
     * Returns the number of observations accumulated so far.
     *
     * @return the observation count
     */
    public int getS0() {
        return this.s0;
    }

    /** @return the current mean vector (may be null for a default-constructed model) */
    public Vector getMean() {
        return this.mean;
    }

    /** @return the current standard deviation */
    public double getStdDev() {
        return this.stdDev;
    }

    /**
     * Returns a fresh model with this model's mean and standard deviation and empty statistics.
     * The mean vector instance is shared with the new model, not copied.
     *
     * @return the new model
     */
    public NormalModel sample() {
        return new NormalModel(this.mean, this.stdDev);
    }

    /**
     * Accumulates one observation into the running count, sum and sum of squares.
     *
     * @param vectorWritable the observed vector
     */
    @Override
    public void observe(VectorWritable vectorWritable) {
        this.s0++;
        Vector x = vectorWritable.get();
        // a default-constructed model has no statistic vectors yet; seed them from the first sample
        this.s1 = (this.s1 == null) ? x.clone() : this.s1.plus(x);
        Vector squared = x.times(x);
        this.s2 = (this.s2 == null) ? squared : this.s2.plus(squared);
    }

    /**
     * Recomputes {@code mean} and {@code stdDev} from the accumulated statistics. Does nothing
     * when no observations have been made.
     */
    @Override
    public void computeParameters() {
        if (this.s0 == 0) {
            return;
        }
        this.mean = this.s1.divide(this.s0);
        if (this.s0 > 1) {
            // per-component std = sqrt(n * sum(x^2) - (sum x)^2) / n; the model's single stdDev
            // is the average of these component values.
            // NOTE(review): floating-point rounding can make the radicand slightly negative for
            // near-constant components, which would yield NaN; confirm whether clamping is needed.
            Vector componentStd = this.s2.times(this.s0)
                    .minus(this.s1.times(this.s1))
                    .assign(new SquareRootFunction())
                    .divide(this.s0);
            this.stdDev = componentStd.zSum() / componentStd.size();
        } else {
            // a single observation carries no spread; use the smallest positive double
            this.stdDev = Double.MIN_VALUE;
        }
    }

    /**
     * Evaluates the model's density-like score at the given point:
     * {@code exp(-||x - mean||^2 / (2 * sd^2)) / (sd * sqrt(2 * pi))}.
     *
     * <p>Note: the normaliser is the one-dimensional {@code 1 / (sd * sqrt(2 * pi))}; it is not
     * raised to the vector's dimension.
     *
     * @param vectorWritable the point to evaluate
     * @return the score
     */
    @Override
    public double pdf(VectorWritable vectorWritable) {
        Vector x = vectorWritable.get();
        // ||x - mean||^2 expanded via dot products to avoid materialising x - mean
        double squaredDistance = x.dot(x) - 2.0d * x.dot(this.mean) + this.mean.dot(this.mean);
        double variance = this.stdDev * this.stdDev;
        return Math.exp(-squaredDistance / (2.0d * variance)) / (this.stdDev * SQRT_2PI);
    }

    /** @return the number of observations accumulated so far */
    @Override
    public int count() {
        return this.s0;
    }

    @Override
    public String toString() {
        return asFormatString(null);
    }

    /**
     * Formats the model as {@code nm{n=<count> m=<mean> sd=<stdDev>}}, with the standard deviation
     * to two decimal places in the default locale.
     *
     * @param strArr optional term dictionary passed through to {@code ClusterBase.formatVector}
     * @return the formatted string
     */
    @Override
    public String asFormatString(String[] strArr) {
        StringBuilder sb = new StringBuilder();
        sb.append("nm{n=").append(this.s0).append(" m=");
        if (this.mean != null) {
            sb.append(ClusterBase.formatVector(this.mean, strArr));
        }
        sb.append(" sd=").append(String.format("%.2f", this.stdDev)).append('}');
        return sb.toString();
    }

    /**
     * Reads the model in the order written by {@link #write(DataOutput)}:
     * mean, stdDev, s0, s1, s2.
     *
     * @param dataInput the source
     * @throws IOException if reading fails
     */
    public void readFields(DataInput dataInput) throws IOException {
        VectorWritable vectorWritable = new VectorWritable();
        vectorWritable.readFields(dataInput);
        this.mean = vectorWritable.get();
        this.stdDev = dataInput.readDouble();
        this.s0 = dataInput.readInt();
        vectorWritable.readFields(dataInput);
        this.s1 = vectorWritable.get();
        vectorWritable.readFields(dataInput);
        this.s2 = vectorWritable.get();
    }

    /**
     * Writes mean, stdDev, s0, s1 and s2, in that order.
     *
     * <p>NOTE(review): {@code mean}, {@code s1} and {@code s2} are null for a default-constructed
     * model with no observations; confirm how {@code VectorWritable.writeVector} handles null.
     *
     * @param dataOutput the sink
     * @throws IOException if writing fails
     */
    public void write(DataOutput dataOutput) throws IOException {
        VectorWritable.writeVector(dataOutput, this.mean);
        dataOutput.writeDouble(this.stdDev);
        dataOutput.writeInt(this.s0);
        VectorWritable.writeVector(dataOutput, this.s1);
        VectorWritable.writeVector(dataOutput, this.s2);
    }

    /**
     * Serialises this model to JSON with Gson, using {@code JsonModelAdapter} for {@code Model}.
     *
     * @return the JSON representation
     */
    @Override
    public String asJsonString() {
        GsonBuilder gsonBuilder = new GsonBuilder();
        gsonBuilder.registerTypeAdapter(Model.class, new JsonModelAdapter());
        return gsonBuilder.create().toJson(this, MODEL_TYPE);
    }
}
