package weka.classifiers.functions;

import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.functions.nearestCentroid.ICentroidFinder;
import weka.classifiers.functions.nearestCentroid.prototypeFinders.MeanCentroidFinder;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.DistanceFunction;
import weka.core.EuclideanDistance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.NormalizableDistance;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/* loaded from: input_file:weka/classifiers/functions/NearestCentroidClassifier.class */
/**
 * Nearest-centroid classifier.
 *
 * <p>Training delegates to an {@link ICentroidFinder}, which produces the centroids. Prediction
 * measures the distance from the instance to every centroid with the configured
 * {@link DistanceFunction} and turns those distances into a normalised distribution using the
 * weight {@code exp(-2 * distance)}.
 *
 * <p>Command-line options: {@code -A} selects the distance function and {@code -CF} selects the
 * centroid finder; each takes a class name optionally followed by that object's own options.
 */
public class NearestCentroidClassifier extends AbstractClassifier implements WeightedInstancesHandler {
    private static final long serialVersionUID = 8462836067571523903L;

    /** Distance function used to compare an instance with each centroid. */
    protected DistanceFunction distFun;

    /** Strategy object that computes and stores the centroids. */
    protected ICentroidFinder centFinder;

    /**
     * Creates a classifier that uses a non-normalising Euclidean distance and a
     * {@link MeanCentroidFinder}.
     */
    public NearestCentroidClassifier() {
        EuclideanDistance euclideanDistance = new EuclideanDistance();
        euclideanDistance.setDontNormalize(true);
        this.distFun = euclideanDistance;
        this.centFinder = new MeanCentroidFinder();
    }

    /**
     * Builds the model: checks the data against {@link #getCapabilities()}, hands the data to the
     * distance function and lets the centroid finder compute the centroids.
     *
     * @param instances the training data
     * @throws Exception if the data violates the capabilities or the centroid finder fails
     */
    public void buildClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);
        // NOTE(review): MISSING_CLASS_VALUES is enabled in getCapabilities(), but instances with a
        // missing class are not removed here - presumably the centroid finder copes with them; verify.
        this.distFun.setInstances(instances);
        this.centFinder.findCentroids(instances);
    }

    /**
     * Computes the membership distribution of an instance over the centroids.
     *
     * <p>Each active centroid receives the weight {@code exp(-2 * distance)}; inactive centroids
     * receive 0. The weights are normalised to sum to one. If the weight sum is (numerically) zero
     * or the normalisation produces a missing value (NaN), the result falls back to a one-hot
     * distribution on the nearest active centroid.
     *
     * @param instance the instance to classify
     * @return an array with one entry per centroid
     * @throws Exception if the centroid finder or the distance function fails
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        int centroidNum = this.centFinder.getCentroidNum();
        double[] distribution = new double[centroidNum];
        double weightSum = 0.0d;
        int nearestIndex = 0;
        double nearestDistance = Double.POSITIVE_INFINITY;

        for (int i = 0; i < centroidNum; i++) {
            double distance = this.distFun.distance(this.centFinder.getCentroid(i), instance);
            if (!this.centFinder.isCentroidActive(i)) {
                // Inactive centroids get no probability mass and must not win the fallback either.
                continue;
            }
            // Remember the NEAREST centroid (smallest distance) for the degenerate-case fallback.
            if (distance < nearestDistance) {
                nearestDistance = distance;
                nearestIndex = i;
            }
            double weight = Math.exp((-2.0d) * distance);
            distribution[i] = weight;
            weightSum += weight;
        }

        // Large distances make every exp() weight vanish; normalising would then divide by ~0.
        boolean degenerate = Utils.eq(weightSum, 0.0d);
        if (!degenerate) {
            for (int i = 0; i < centroidNum; i++) {
                distribution[i] /= weightSum;
                if (Utils.isMissingValue(distribution[i])) {
                    degenerate = true;
                    break;
                }
            }
        }
        if (degenerate) {
            // Discard the partially normalised values and vote for the nearest centroid only.
            distribution = new double[centroidNum];
            distribution[nearestIndex] = 1.0d;
        }
        return distribution;
    }

    /**
     * Lists the options of this classifier followed by those of the superclass.
     *
     * @return an enumeration of all available options
     */
    public Enumeration<Option> listOptions() {
        Vector<Option> options = new Vector<Option>(2);
        // Both options consume one argument (a class specification), so their arity is 1.
        options.addElement(new Option("\tThe distance function to use (default: weka.core.EuclideanDistance).\n", "A", 1, "-A <classname and options>"));
        options.addElement(new Option("\tThe centroid calculator object to use (default: weka.classifiers.functions.nearestCentroid.prototypeFinders.MeanCentroidFinder).\n", "CF", 1, "-CF <classname and options>"));
        options.addAll(Collections.list(super.listOptions()));
        return options.elements();
    }

    /**
     * Returns the current settings as an option array: {@code -A} and {@code -CF}, each followed by
     * the class name and options of the configured object, then the superclass options.
     *
     * @return the current option settings
     */
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        options.add("-A");
        options.add((this.distFun.getClass().getName() + " " + Utils.joinOptions(this.distFun.getOptions())).trim());
        options.add("-CF");
        options.add((this.centFinder.getClass().getName() + " " + Utils.joinOptions(this.centFinder.getOptions())).trim());
        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[0]);
    }

    /**
     * Parses the options. {@code -A <spec>} sets the distance function and {@code -CF <spec>} sets
     * the centroid finder, where {@code <spec>} is a class name optionally followed by that class's
     * own options. A missing option resets the corresponding component to a new default object.
     *
     * @param strArr the options to parse
     * @throws Exception if a specification string is invalid or a class cannot be instantiated
     */
    public void setOptions(String[] strArr) throws Exception {
        String distanceSpec = Utils.getOption('A', strArr);
        if (distanceSpec.length() != 0) {
            String[] distanceOptions = Utils.splitOptions(distanceSpec);
            if (distanceOptions.length == 0) {
                throw new Exception("Invalid Distance function specification string.");
            }
            String distanceClassName = distanceOptions[0];
            // Blank out the class name so that only the component's own options are passed on.
            distanceOptions[0] = "";
            // Accept any DistanceFunction, matching the field type and setDistFun().
            setDistFun((DistanceFunction) Utils.forName(DistanceFunction.class, distanceClassName, distanceOptions));
        } else {
            // NOTE(review): unlike the constructor, this default does not call setDontNormalize(true);
            // kept as-is to preserve existing behaviour - confirm which default is intended.
            setDistFun(new EuclideanDistance());
        }

        String finderSpec = Utils.getOption("CF", strArr);
        if (finderSpec.length() != 0) {
            String[] finderOptions = Utils.splitOptions(finderSpec);
            if (finderOptions.length == 0) {
                throw new Exception("Invalid centroid finder specification string.");
            }
            String finderClassName = finderOptions[0];
            finderOptions[0] = "";
            setCentFinder((ICentroidFinder) Utils.forName(ICentroidFinder.class, finderClassName, finderOptions));
        } else {
            setCentFinder(new MeanCentroidFinder());
        }
        super.setOptions(strArr);
    }

    /**
     * Returns the capabilities of this classifier: nominal class (missing class values allowed),
     * numeric and binary attributes, and at least two training instances.
     *
     * @return the capabilities
     */
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAll();
        capabilities.enable(Capabilities.Capability.NOMINAL_CLASS);
        capabilities.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.BINARY_ATTRIBUTES);
        capabilities.setMinimumNumberInstances(2);
        return capabilities;
    }

    /**
     * Returns a textual description of the model: one line per centroid, or a short notice when
     * the model has not been built yet.
     *
     * @return the model description
     */
    public String toString() {
        if (!this.centFinder.isModelBuilt()) {
            return "The model has not been built yet.";
        }
        StringBuilder text = new StringBuilder();
        try {
            text.append("Nearest Centroid Classifier: \n\nCentroids:\n");
            Attribute classAttribute = this.centFinder.getCentroid(0).classAttribute();
            int centroidNum = this.centFinder.getCentroidNum();
            for (int i = 0; i < centroidNum; i++) {
                // Centroid i is labelled with the i-th value of the class attribute.
                text.append("Class " + classAttribute.value(i) + ":" + this.centFinder.getCentroid(i) + "\n");
            }
        } catch (Exception e) {
            // toString() must not throw; report the failure instead of silently truncating the output.
            text.append("(could not print all centroids: ").append(e.getMessage()).append(")\n");
        }
        return text.toString();
    }

    /**
     * Returns the centroids of the built model.
     *
     * @return one entry per centroid, or {@code null} if the model has not been built; an entry
     *         is {@code null} when the centroid finder failed to deliver that centroid
     */
    public Instance[] getCentroids() {
        if (!this.centFinder.isModelBuilt()) {
            return null;
        }
        int centroidNum = this.centFinder.getCentroidNum();
        Instance[] centroids = new Instance[centroidNum];
        for (int i = 0; i < centroidNum; i++) {
            try {
                centroids[i] = this.centFinder.getCentroid(i);
            } catch (Exception ignored) {
                // Best effort: a centroid that cannot be retrieved is left as null.
            }
        }
        return centroids;
    }

    /**
     * Returns the distance function in use.
     *
     * @return the distance function
     */
    public DistanceFunction getDistFun() {
        return this.distFun;
    }

    /**
     * Sets the distance function to use.
     *
     * @param distanceFunction the new distance function
     */
    public void setDistFun(DistanceFunction distanceFunction) {
        this.distFun = distanceFunction;
    }

    /**
     * Returns the tip text for the distance function property.
     *
     * @return the tip text
     */
    public String distFunTipText() {
        return "Distance function to use with the classifier";
    }

    /**
     * Returns the centroid finder in use.
     *
     * @return the centroid finder
     */
    public ICentroidFinder getCentFinder() {
        return this.centFinder;
    }

    /**
     * Sets the centroid finder to use.
     *
     * @param iCentroidFinder the new centroid finder
     */
    public void setCentFinder(ICentroidFinder iCentroidFinder) {
        this.centFinder = iCentroidFinder;
    }

    /**
     * Returns the tip text for the centroid finder property.
     *
     * @return the tip text
     */
    public String centFinderTipText() {
        return "Centroid finder to use with the classifier";
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 2 $");
    }

    /**
     * Runs the classifier from the command line.
     *
     * @param strArr the command-line options
     */
    public static void main(String[] strArr) {
        runClassifier(new NearestCentroidClassifier(), strArr);
    }

    /**
     * Returns a short description of this classifier.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Performs the Nearest Centroid classification";
    }
}
