package weka.filters.unsupervised.attribute.missingvaluesimputation;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.functions.LinearRegression;
import weka.classifiers.functions.Logistic;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;

/* loaded from: input_file:weka/filters/unsupervised/attribute/missingvaluesimputation/IRMI.class */
/**
 * Missing-value imputation following the IRMI algorithm of Templ, Kowarik and Filzmoser (2011).
 *
 * <p>Missing cells are first filled with the column median (numeric/date attributes) or mode
 * (other attributes). Then, for a number of epochs, every attribute that has both observed and
 * missing values is re-predicted by a classifier trained on the rows where it is observed. An
 * attribute stops being refined once the sum of squared changes of its imputed values in one
 * epoch drops below {@link #getEpsilon()}.
 */
public class IRMI extends AbstractImputation implements TechnicalInformationHandler {
    public static final String NOMINAL_CLASSIFIER = "nominal-classifier";
    public static final String NUMERIC_CLASSIFIER = "numeric-classifier";
    public static final String NUM_EPOCHS = "num-epochs";
    public static final String EPSILON = "epsilon";

    // NOTE(review): these initializers call overridable getDefault*() methods during construction;
    // kept as-is because subclasses may rely on overriding them.
    /** Prototype classifier (copied per attribute) used to impute nominal attributes. */
    protected Classifier m_nominalClassifier = getDefaultNominalClassifier();
    /** Prototype classifier (copied per attribute) used to impute non-nominal attributes. */
    protected Classifier m_numericClassifier = getDefaultNumericClassifier();
    /** Maximum number of refinement passes over the attributes. */
    protected int m_numEpochs = getDefaultNumEpochs();
    /** Per-attribute convergence threshold on the sum of squared imputation changes. */
    protected double m_epsilon = getDefaultEpsilon();
    /** Trained imputation model per attribute index; {@code null} where none was trained. */
    protected Classifier[] m_classifiers;
    /** Empty copy of the training header; used as the dataset while predicting. */
    protected Instances m_Header;

    @Override // weka.filters.unsupervised.attribute.missingvaluesimputation.AbstractImputation, weka.filters.unsupervised.attribute.missingvaluesimputation.Imputation
    public String globalInfo() {
        return "Uses the IRMI algorithm as published by Templ et al in 'Iterative stepwise regression imputation using standard and robust methods'.\n\n" + getTechnicalInformation();
    }

    /**
     * Returns the bibliographic reference for the IRMI paper.
     *
     * @return the technical information describing the publication
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Matthias Templ and Alexander Kowarik and Peter Filzmoser");
        result.setValue(TechnicalInformation.Field.TITLE, "Iterative stepwise regression imputation using standard and robust methods");
        result.setValue(TechnicalInformation.Field.JOURNAL, "Computational Statistics & Data Analysis");
        result.setValue(TechnicalInformation.Field.YEAR, "2011");
        result.setValue(TechnicalInformation.Field.VOLUME, "55");
        result.setValue(TechnicalInformation.Field.NUMBER, "10");
        result.setValue(TechnicalInformation.Field.PAGES, "2793-2806");
        result.setValue(TechnicalInformation.Field.ISSN, "0167-9473");
        result.setValue(TechnicalInformation.Field.HTTP, "http://www.statistik.tuwien.ac.at/public/filz/papers/CSDA11TKF.pdf");
        return result;
    }

    /**
     * Lists this filter's options followed by those of the superclass.
     *
     * @return an enumeration of all supported options
     */
    @Override // weka.filters.unsupervised.attribute.missingvaluesimputation.AbstractImputation
    public Enumeration<Option> listOptions() {
        Vector<Option> options = new Vector<>();
        options.addElement(new Option("\t" + nominalClassifierTipText() + "\n\t(default: " + getDefaultNominalClassifier() + ")", NOMINAL_CLASSIFIER, 1, "-nominal-classifier <classname + options>"));
        // Fixed: previously described with the nominal classifier's tip text.
        options.addElement(new Option("\t" + numericClassifierTipText() + "\n\t(default: " + getDefaultNumericClassifier() + ")", NUMERIC_CLASSIFIER, 1, "-numeric-classifier <classname + options>"));
        options.addElement(new Option("\t" + numEpochsTipText() + "\n\t(default: " + getDefaultNumEpochs() + ")", NUM_EPOCHS, 1, "-num-epochs <int>"));
        options.addElement(new Option("\t" + epsilonTipText() + "\n\t(default: " + getDefaultEpsilon() + ")", EPSILON, 1, "-epsilon <double>"));
        options.addAll(Collections.list(super.listOptions()));
        return options.elements();
    }

    /**
     * Returns the current settings, followed by the superclass's options.
     *
     * @return the option strings
     */
    @Override // weka.filters.unsupervised.attribute.missingvaluesimputation.AbstractImputation
    public String[] getOptions() {
        ArrayList<String> options = new ArrayList<>();
        options.add("-nominal-classifier");
        options.add(Utils.toCommandLine(this.m_nominalClassifier));
        options.add("-numeric-classifier");
        options.add(Utils.toCommandLine(this.m_numericClassifier));
        options.add("-num-epochs");
        options.add("" + this.m_numEpochs);
        options.add("-epsilon");
        options.add("" + this.m_epsilon);
        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[0]);
    }

    /**
     * Parses the options; any option that is absent resets to its default.
     *
     * @param strArr the options to parse (consumed entries are blanked by {@code Utils.getOption})
     * @throws Exception if a classifier cannot be instantiated, a number is malformed, or
     *     unrecognised options remain
     */
    @Override // weka.filters.unsupervised.attribute.missingvaluesimputation.AbstractImputation
    public void setOptions(String[] strArr) throws Exception {
        String nominalSpec = Utils.getOption(NOMINAL_CLASSIFIER, strArr);
        setNominalClassifier(nominalSpec.isEmpty() ? getDefaultNominalClassifier() : parseClassifierSpec(nominalSpec));

        String numericSpec = Utils.getOption(NUMERIC_CLASSIFIER, strArr);
        setNumericClassifier(numericSpec.isEmpty() ? getDefaultNumericClassifier() : parseClassifierSpec(numericSpec));

        String epochs = Utils.getOption(NUM_EPOCHS, strArr);
        setNumEpochs(epochs.isEmpty() ? getDefaultNumEpochs() : Integer.parseInt(epochs));

        String epsilon = Utils.getOption(EPSILON, strArr);
        setEpsilon(epsilon.isEmpty() ? getDefaultEpsilon() : Double.parseDouble(epsilon));

        super.setOptions(strArr);
        Utils.checkForRemainingOptions(strArr);
    }

    /**
     * Instantiates a classifier from a "classname option option ..." command-line spec.
     *
     * @param spec the class name followed by its options
     * @return the configured classifier
     * @throws Exception if the class cannot be found or rejects its options
     */
    private static Classifier parseClassifierSpec(String spec) throws Exception {
        String[] parts = Utils.splitOptions(spec);
        String className = parts[0];
        parts[0] = ""; // drop the class name so only its options are passed on
        return (Classifier) Utils.forName(Classifier.class, className, parts);
    }

    /** @return the default nominal classifier ({@code Logistic}) */
    protected Classifier getDefaultNominalClassifier() {
        return new Logistic();
    }

    /** @param classifier the prototype classifier used for nominal attributes */
    public void setNominalClassifier(Classifier classifier) {
        this.m_nominalClassifier = classifier;
    }

    /** @return the prototype classifier used for nominal attributes */
    public Classifier getNominalClassifier() {
        return this.m_nominalClassifier;
    }

    /** @return tip text for the nominal classifier property */
    public String nominalClassifierTipText() {
        return "Nominal classifier to use";
    }

    /** @return the default numeric classifier ({@code LinearRegression}) */
    protected Classifier getDefaultNumericClassifier() {
        return new LinearRegression();
    }

    /** @param classifier the prototype classifier used for non-nominal attributes */
    public void setNumericClassifier(Classifier classifier) {
        this.m_numericClassifier = classifier;
    }

    /** @return the prototype classifier used for non-nominal attributes */
    public Classifier getNumericClassifier() {
        return this.m_numericClassifier;
    }

    /** @return tip text for the numeric classifier property */
    public String numericClassifierTipText() {
        return "Numeric classifier to use";
    }

    /** @return the default maximum number of epochs (100) */
    protected int getDefaultNumEpochs() {
        return 100;
    }

    /** @param i the maximum number of refinement epochs */
    public void setNumEpochs(int i) {
        this.m_numEpochs = i;
    }

    /** @return the maximum number of refinement epochs */
    public int getNumEpochs() {
        return this.m_numEpochs;
    }

    /** @return tip text for the epochs property */
    public String numEpochsTipText() {
        return "Max number of epochs";
    }

    /** @return the default convergence threshold (5.0) */
    protected double getDefaultEpsilon() {
        return 5.0d;
    }

    /** @param d threshold on an attribute's per-epoch sum of squared imputation changes */
    public void setEpsilon(double d) {
        this.m_epsilon = d;
    }

    /** @return threshold on an attribute's per-epoch sum of squared imputation changes */
    public double getEpsilon() {
        return this.m_epsilon;
    }

    /** @return tip text for the epsilon property */
    public String epsilonTipText() {
        return "Epsilon for early termination";
    }

    /**
     * Accepts numeric, date and nominal attributes with missing values, any class type, missing
     * class values, and data without a class.
     *
     * @return the capabilities of this filter
     */
    @Override // weka.filters.unsupervised.attribute.missingvaluesimputation.AbstractImputation
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAll();
        capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.MISSING_VALUES);
        capabilities.enableAllClasses();
        capabilities.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        capabilities.enable(Capabilities.Capability.NO_CLASS);
        return capabilities;
    }

    /**
     * Computes the median of the non-missing entries.
     *
     * @param dArr the values; missing entries are ignored
     * @return the median (mean of the two middle values for an even count), or 0 when there are
     *     no non-missing values
     */
    protected double median(double[] dArr) {
        double[] present = new double[dArr.length];
        int count = 0;
        for (double value : dArr) {
            if (!Utils.isMissingValue(value)) {
                present[count++] = value;
            }
        }
        if (count == 0) {
            return 0.0d;
        }
        Arrays.sort(present, 0, count);
        int middle = count / 2;
        return count % 2 == 0 ? (present[middle - 1] + present[middle]) / 2.0d : present[middle];
    }

    /**
     * Computes the most frequent non-missing value.
     *
     * @param dArr the values; missing entries are ignored
     * @return the mode (the smallest value on ties, so the result does not depend on hash
     *     iteration order), or 0 when there are no non-missing values
     */
    protected double mode(double[] dArr) {
        HashMap<Double, Integer> counts = new HashMap<>();
        for (double value : dArr) {
            if (!Utils.isMissingValue(value)) {
                Integer seen = counts.get(value);
                counts.put(value, seen == null ? 1 : seen + 1);
            }
        }
        double best = 0.0d;
        int bestCount = 0;
        for (Double key : counts.keySet()) {
            int count = counts.get(key);
            if (count > bestCount || (count == bestCount && key < best)) {
                best = key;
                bestCount = count;
            }
        }
        return best;
    }

    /**
     * Trains one imputation model per attribute that has both observed and missing values.
     *
     * @param instances the training data (not modified; a copy is worked on)
     * @return an empty copy of the input header
     * @throws Exception if a classifier cannot be copied or trained
     */
    @Override // weka.filters.unsupervised.attribute.missingvaluesimputation.AbstractImputation
    protected Instances doBuildImputation(Instances instances) throws Exception {
        Instances data = new Instances(instances);
        int classIndex = data.classIndex();
        int numAttributes = data.numAttributes();

        // Blank the class so it never informs the imputation models. NO_CLASS data is allowed,
        // and setClassValue throws when no class index is set, hence the guard.
        if (classIndex >= 0) {
            for (int row = 0; row < data.numInstances(); row++) {
                data.get(row).setClassValue(Double.NaN);
            }
        }

        // Record, per attribute, which rows are missing and which are observed.
        ArrayList<ArrayList<Integer>> missingRows = new ArrayList<>(numAttributes);
        ArrayList<ArrayList<Integer>> observedRows = new ArrayList<>(numAttributes);
        Pair[] order = new Pair[numAttributes];
        for (int att = 0; att < numAttributes; att++) {
            ArrayList<Integer> missing = new ArrayList<>();
            ArrayList<Integer> observed = new ArrayList<>();
            for (int row = 0; row < data.numInstances(); row++) {
                if (data.get(row).isMissing(att)) {
                    missing.add(row);
                } else {
                    observed.add(row);
                }
            }
            missingRows.add(missing);
            observedRows.add(observed);
            order[att] = new Pair(missing.size(), att);
        }
        // Visit order follows Pair's natural ordering (presumably ascending missing count — TODO confirm).
        Arrays.sort(order);

        fillWithCentralTendency(data, classIndex);

        // Attributes that are never refined (nothing missing, or nothing observed to train on)
        // count as converged from the start; otherwise they would block early termination.
        boolean[] converged = new boolean[numAttributes];
        for (int att = 0; att < numAttributes; att++) {
            converged[att] = missingRows.get(att).isEmpty() || observedRows.get(att).isEmpty();
        }

        this.m_classifiers = new Classifier[numAttributes];
        for (int epoch = 0; epoch < getNumEpochs(); epoch++) {
            for (Pair pair : order) {
                int att = pair.index;
                if (att == classIndex || converged[att]) {
                    continue;
                }
                double change = refineAttribute(instances, data, att, observedRows.get(att), missingRows.get(att));
                if (change < this.m_epsilon) {
                    converged[att] = true;
                }
            }
            if (allConverged(converged, classIndex)) {
                break;
            }
        }
        this.m_Header = new Instances(instances, 0);
        return new Instances(instances, 0);
    }

    /**
     * Replaces NaN/infinite cells of every non-class attribute with the column median (numeric)
     * or mode (otherwise), computed before replacement.
     */
    private void fillWithCentralTendency(Instances data, int classIndex) {
        for (int att = 0; att < data.numAttributes(); att++) {
            if (att == classIndex) {
                continue;
            }
            double[] column = data.attributeToDoubleArray(att);
            double fill = data.attribute(att).isNumeric() ? median(column) : mode(column);
            for (int row = 0; row < column.length; row++) {
                double value = data.get(row).value(att);
                if (Double.isNaN(value) || Double.isInfinite(value)) {
                    data.get(row).setValue(att, fill);
                }
            }
        }
    }

    /**
     * Trains a model for {@code att} on its observed rows, stores it in {@link #m_classifiers},
     * and re-predicts the attribute on its missing rows.
     *
     * @return the sum of squared differences between previous and new imputed values
     */
    private double refineAttribute(Instances template, Instances data, int att,
            ArrayList<Integer> observed, ArrayList<Integer> missing) throws Exception {
        Instances train = new Instances(template, observed.size());
        for (int row : observed) {
            train.add(data.get(row));
        }
        train.setClassIndex(att);
        Classifier prototype = data.attribute(att).isNominal() ? this.m_nominalClassifier : this.m_numericClassifier;
        Classifier model = AbstractClassifier.makeCopy(prototype);
        model.buildClassifier(train);
        this.m_classifiers[att] = model;

        data.setClassIndex(att);
        double sumSquaredChange = 0.0d;
        for (int row : missing) {
            Instance current = data.get(row);
            double previous = current.value(att);
            double predicted = model.classifyInstance(current);
            current.setValue(att, predicted);
            double delta = previous - predicted;
            sumSquaredChange += delta * delta;
        }
        return sumSquaredChange;
    }

    /** @return true when every non-class attribute is flagged as converged */
    private static boolean allConverged(boolean[] converged, int classIndex) {
        for (int att = 0; att < converged.length; att++) {
            if (att != classIndex && !converged[att]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Imputes the missing non-class values of one instance with the trained models, in attribute
     * index order; attributes without a model stay missing.
     *
     * @param instance the instance to impute (not modified)
     * @return an imputed copy attached to the output format
     * @throws Exception if a model fails to classify
     */
    @Override // weka.filters.unsupervised.attribute.missingvaluesimputation.AbstractImputation
    protected Instance doImpute(Instance instance) throws Exception {
        Instance result = (Instance) instance.copy();
        result.setDataset(this.m_Header);
        int classIndex = instance.classIndex();
        int headerClassIndex = this.m_Header.classIndex();

        // The models were trained with the class column blanked, so hide it here too and
        // restore it afterwards.
        double classValue = Double.NaN;
        if (classIndex >= 0) {
            classValue = result.value(classIndex);
            result.setValue(classIndex, Double.NaN);
        }
        try {
            for (int att = 0; att < result.numAttributes(); att++) {
                if (att != classIndex && result.isMissing(att) && this.m_classifiers[att] != null) {
                    this.m_Header.setClassIndex(att);
                    result.setValue(att, this.m_classifiers[att].classifyInstance(result));
                }
            }
        } finally {
            // Do not leave the shared header pointing at the last imputed attribute.
            this.m_Header.setClassIndex(headerClassIndex);
        }
        if (classIndex >= 0) {
            result.setValue(classIndex, classValue);
        }
        result.setDataset(this.m_OutputFormat);
        return result;
    }
}
