package weka.classifiers.meta;

import java.io.File;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.RandomizableSingleClassifierEnhancer;
import weka.classifiers.functions.LinearRegression;
import weka.classifiers.meta.multisearch.AbstractEvaluationFactory;
import weka.classifiers.meta.multisearch.AbstractEvaluationMetrics;
import weka.classifiers.meta.multisearch.AbstractSearch;
import weka.classifiers.meta.multisearch.DefaultEvaluationFactory;
import weka.classifiers.meta.multisearch.DefaultSearch;
import weka.classifiers.meta.multisearch.Performance;
import weka.classifiers.meta.multisearch.PerformanceComparator;
import weka.classifiers.meta.multisearch.TraceableOptimizer;
import weka.core.AdditionalMeasureProducer;
import weka.core.Capabilities;
import weka.core.Debug;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.SerializedObject;
import weka.core.SetupGenerator;
import weka.core.Summarizable;
import weka.core.Tag;
import weka.core.Utils;
import weka.core.setupgenerator.AbstractParameter;
import weka.core.setupgenerator.MathParameter;
import weka.core.setupgenerator.ParameterGroup;
import weka.core.setupgenerator.Point;
import weka.core.setupgenerator.Space;

/* loaded from: input_file:weka/classifiers/meta/MultiSearch.class */
/**
 * Meta-classifier that searches an arbitrary number of properties of a base classifier, picks the
 * best-performing setup according to the selected evaluation metric and trains that setup on the
 * full training data.
 *
 * <p>Parameters may be grouped with {@link ParameterGroup}; in that case every top-level parameter
 * must be a group, each group is searched independently and the best result across all groups wins.
 */
public class MultiSearch extends RandomizableSingleClassifierEnhancer implements AdditionalMeasureProducer, Summarizable, TraceableOptimizer {
    private static final long serialVersionUID = -5129316523575906233L;

    /** The best setup found by the last search; before any search, it holds a copy of the base classifier. */
    protected AbstractSearch.SearchResult m_BestClassifier;

    /** The default search parameters; used by setOptions when no -search option is supplied. */
    protected AbstractParameter[] m_DefaultParameters;

    /** The parameters (or parameter groups) to search. */
    protected AbstractParameter[] m_Parameters;

    /** The search algorithm. */
    protected AbstractSearch m_Algorithm;

    /** The setup generator of the most recently searched parameter group. */
    protected SetupGenerator m_Generator;

    /** The trace of evaluated setups: key = number of folds, value = performance. */
    protected List<Map.Entry<Integer, Performance>> m_Trace;

    /** The log file; a directory (the default is the working directory) disables logging to file. */
    protected File m_LogFile = new File(System.getProperty("user.dir"));

    /** The factory creating the evaluation metrics. */
    protected AbstractEvaluationFactory m_Factory = newFactory();

    /** The metrics that can be used for evaluation. */
    protected AbstractEvaluationMetrics m_Metrics = this.m_Factory.newMetrics();

    /** The ID of the metric used for choosing the best setup. */
    protected int m_Evaluation = this.m_Metrics.getDefaultMetric();

    /**
     * Initializes the classifier with the default base classifier, search parameters and algorithm.
     */
    public MultiSearch() {
        this.m_Classifier = defaultClassifier();
        this.m_DefaultParameters = defaultSearchParameters();
        this.m_Parameters = defaultSearchParameters();
        this.m_Algorithm = defaultAlgorithm();
        this.m_Trace = new ArrayList<Map.Entry<Integer, Performance>>();
        try {
            this.m_BestClassifier = new AbstractSearch.SearchResult();
            this.m_BestClassifier.classifier = AbstractClassifier.makeCopy(this.m_Classifier);
        } catch (Exception e) {
            // best effort: the object stays usable, only the pre-search "best" classifier is missing
            System.err.println("Failed to create copy of default classifier!");
            e.printStackTrace();
        }
    }

    /**
     * Returns a string describing the classifier.
     *
     * @return a description suitable for displaying in the GUI
     */
    public String globalInfo() {
        return "Performs a search of an arbitrary number of parameters of a classifier and chooses the best pair found for the actual filtering and training.\nThe default MultiSearch is using the following Classifier setup:\n  LinearRegression, searching for the \"Ridge\"\nThe properties being explored are totally up to the user.\n\nE.g., if you have a FilteredClassifier selected as base classifier, sporting a PLSFilter and you want to explore the number of PLS components, then your property will be made up of the following components:\n - filter: referring to the FilteredClassifier's property (= PLSFilter)\n - numComponents: the actual property of the PLSFilter that we want to modify\nAnd assembled, the property looks like this:\n  filter.numComponents\n\n\nThe best classifier setup can be accessed after the buildClassifier call via the getBestClassifier method.\n\nThe trace of setups evaluated can be accessed after the buildClassifier call as well, using the following methods:\n- getTrace()\n- getTraceSize()\n- getTraceValue(int)\n- getTraceFolds(int)\n- getTraceClassifierAsCli(int)\nUsing the " + ParameterGroup.class.getName() + " parameter, it is possible to group dependent parameters. In this case, all top-level parameters must be of type " + ParameterGroup.class.getName() + ".";
    }

    /**
     * Returns the classname of the default base classifier.
     *
     * @return the classname of {@link #defaultClassifier()}
     */
    @Override
    protected String defaultClassifierString() {
        return defaultClassifier().getClass().getName();
    }

    /**
     * Returns the default base classifier: LinearRegression without attribute selection and without
     * elimination of colinear attributes.
     *
     * @return the default classifier
     */
    protected Classifier defaultClassifier() {
        LinearRegression result = new LinearRegression();
        result.setAttributeSelectionMethod(new SelectedTag(LinearRegression.SELECTION_NONE, LinearRegression.TAGS_SELECTION));
        result.setEliminateColinearAttributes(false);
        return result;
    }

    /**
     * Returns the default search parameters: the "ridge" property explored as pow(10, I) for
     * I = -10..5 in steps of 1.
     *
     * @return a serialized (deep) copy of the default parameters, or an empty array if copying failed
     */
    protected AbstractParameter[] defaultSearchParameters() {
        MathParameter ridge = new MathParameter();
        ridge.setProperty("ridge");
        ridge.setMin(-10.0d);
        ridge.setMax(5.0d);
        ridge.setStep(1.0d);
        ridge.setBase(10.0d);
        ridge.setExpression("pow(BASE,I)");
        try {
            return (AbstractParameter[]) new SerializedObject(new AbstractParameter[]{ridge}).getObject();
        } catch (Exception e) {
            System.err.println("Failed to create copy of default parameters!");
            e.printStackTrace();
            return new AbstractParameter[0];
        }
    }

    /**
     * Returns an enumeration of the available command-line options, followed by those of the superclass.
     *
     * @return the options
     */
    @Override
    public Enumeration listOptions() {
        Vector<Option> result = new Vector<Option>();

        StringBuilder metrics = new StringBuilder();
        for (int i = 0; i < this.m_Metrics.getTags().length; i++) {
            SelectedTag tag = new SelectedTag(this.m_Metrics.getTags()[i].getID(), this.m_Metrics.getTags());
            metrics.append("\t").append(tag.getSelectedTag().getIDStr()).append(" = ").append(tag.getSelectedTag().getReadable()).append("\n");
        }
        result.addElement(new Option("\tDetermines the parameter used for evaluation:\n" + metrics + "\t(default: " + new SelectedTag(this.m_Metrics.getDefaultMetric(), this.m_Metrics.getTags()) + ")", "E", 1, "-E " + Tag.toOptionList(this.m_Metrics.getTags())));
        result.addElement(new Option("\tA property search setup.\n", "search", 1, "-search \"<classname options>\""));
        result.addElement(new Option("\tA search algorithm.\n", "algorithm", 1, "-algorithm \"<classname options>\""));
        result.addElement(new Option("\tThe log file to log the messages to.\n\t(default: none)", "log-file", 1, "-log-file <filename>"));

        Enumeration superOptions = super.listOptions();
        while (superOptions.hasMoreElements()) {
            result.addElement((Option) superOptions.nextElement());
        }
        return result.elements();
    }

    /**
     * Returns the current setup as command-line options.
     *
     * @return the options, including those of the superclass
     */
    @Override
    public String[] getOptions() {
        Vector<String> result = new Vector<String>();

        result.add("-E");
        result.add("" + getEvaluation());

        // one -search option per parameter
        for (int i = 0; i < getSearchParameters().length; i++) {
            result.add("-search");
            result.add(getCommandline(getSearchParameters()[i]));
        }

        result.add("-algorithm");
        result.add(getCommandline(this.m_Algorithm));

        result.add("-log-file");
        result.add("" + getLogFile());

        for (String option : super.getOptions()) {
            result.add(option);
        }
        return result.toArray(new String[result.size()]);
    }

    /**
     * Parses the given command-line options. Missing options revert to their defaults.
     *
     * @param options the options to parse
     * @throws Exception if parsing fails, e.g. a setup does not name a class
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String tmpStr = Utils.getOption('E', options);
        if (tmpStr.length() != 0) {
            setEvaluation(new SelectedTag(tmpStr, this.m_Metrics.getTags()));
        } else {
            setEvaluation(new SelectedTag(this.m_Metrics.getDefaultMetric(), this.m_Metrics.getTags()));
        }

        // -search can occur multiple times: collect every occurrence
        List<String> setups = new ArrayList<String>();
        while (true) {
            tmpStr = Utils.getOption("search", options);
            if (tmpStr.length() == 0) {
                break;
            }
            setups.add(tmpStr);
        }
        if (setups.isEmpty()) {
            for (int i = 0; i < this.m_DefaultParameters.length; i++) {
                setups.add(getCommandline(this.m_DefaultParameters[i]));
            }
        }
        AbstractParameter[] parameters = new AbstractParameter[setups.size()];
        for (int i = 0; i < setups.size(); i++) {
            parameters[i] = (AbstractParameter) instantiate(AbstractParameter.class, setups.get(i));
        }
        setSearchParameters(parameters);

        tmpStr = Utils.getOption("algorithm", options);
        if (tmpStr.isEmpty()) {
            // use the same (overridable) default as the constructor
            setAlgorithm(defaultAlgorithm());
        } else {
            setAlgorithm((AbstractSearch) instantiate(AbstractSearch.class, tmpStr));
        }

        tmpStr = Utils.getOption("log-file", options);
        if (tmpStr.length() != 0) {
            setLogFile(new File(tmpStr));
        } else {
            setLogFile(new File(System.getProperty("user.dir")));
        }

        super.setOptions(options);
    }

    /**
     * Instantiates an object from a "classname [options]" setup string.
     *
     * @param type the class the new object must be assignable to
     * @param setup the classname, optionally followed by options
     * @return the configured object
     * @throws IllegalArgumentException if the setup string contains no classname
     * @throws Exception if instantiation or option parsing fails
     */
    private static Object instantiate(Class<?> type, String setup) throws Exception {
        String[] parts = Utils.splitOptions(setup);
        if (parts.length == 0) {
            throw new IllegalArgumentException("No classname in setup: '" + setup + "'");
        }
        String classname = parts[0];
        parts[0] = "";
        return Utils.forName(type, classname, parts);
    }

    /**
     * Sets the base classifier and resets the "best" classifier to a copy of it.
     *
     * @param classifier the base classifier
     */
    @Override
    public void setClassifier(Classifier classifier) {
        super.setClassifier(classifier);
        try {
            this.m_BestClassifier.classifier = AbstractClassifier.makeCopy(this.m_Classifier);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for the GUI
     */
    public String searchParametersTipText() {
        return "Defines the search parameters.";
    }

    /**
     * Sets the search parameters.
     *
     * @param abstractParameterArr the parameters (stored by reference, not copied)
     */
    public void setSearchParameters(AbstractParameter[] abstractParameterArr) {
        this.m_Parameters = abstractParameterArr;
    }

    /**
     * Returns the search parameters.
     *
     * @return the parameters (the internal array, not a copy)
     */
    public AbstractParameter[] getSearchParameters() {
        return this.m_Parameters;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for the GUI
     */
    public String algorithmTipText() {
        return "Defines the search algorithm.";
    }

    /**
     * Sets the search algorithm.
     *
     * @param abstractSearch the algorithm
     */
    public void setAlgorithm(AbstractSearch abstractSearch) {
        this.m_Algorithm = abstractSearch;
    }

    /**
     * Returns the search algorithm.
     *
     * @return the algorithm
     */
    public AbstractSearch getAlgorithm() {
        return this.m_Algorithm;
    }

    /**
     * Returns the default search algorithm.
     *
     * @return a new {@link DefaultSearch} instance
     */
    public AbstractSearch defaultAlgorithm() {
        return new DefaultSearch();
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for the GUI
     */
    public String evaluationTipText() {
        return "Sets the criterion for evaluating the classifier performance and choosing the best one.";
    }

    /**
     * Returns the tags of the available evaluation metrics.
     *
     * @return the tags
     */
    public Tag[] getMetricsTags() {
        return this.m_Metrics.getTags();
    }

    /**
     * Sets the evaluation metric. Tags that do not belong to this object's metrics (checked by array
     * identity) are silently ignored.
     *
     * @param selectedTag the metric to use
     */
    public void setEvaluation(SelectedTag selectedTag) {
        if (selectedTag.getTags() == this.m_Metrics.getTags()) {
            this.m_Evaluation = selectedTag.getSelectedTag().getID();
        }
    }

    /**
     * Returns the evaluation metric.
     *
     * @return the currently selected metric
     */
    public SelectedTag getEvaluation() {
        return new SelectedTag(this.m_Evaluation, this.m_Metrics.getTags());
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for the GUI
     */
    public String logFileTipText() {
        return "The log file to log the messages to.";
    }

    /**
     * Returns the log file; a directory means messages are not written to a file.
     *
     * @return the log file
     */
    public File getLogFile() {
        return this.m_LogFile;
    }

    /**
     * Sets the log file; a directory (or {@code null}) disables logging to file.
     *
     * @param file the log file
     */
    public void setLogFile(File file) {
        this.m_LogFile = file;
    }

    /**
     * Returns the best classifier setup found by the last search.
     *
     * @return the best classifier (before any search: a copy of the base classifier)
     */
    public Classifier getBestClassifier() {
        return this.m_BestClassifier.classifier;
    }

    /**
     * Returns the setup generator of the most recently searched parameter group.
     *
     * @return the generator, null before the first build
     */
    public SetupGenerator getGenerator() {
        return this.m_Generator;
    }

    /**
     * Returns the names of the additional measures: "measure-i" for every numeric dimension i of the
     * best values.
     *
     * @return the measure names (empty if no search has been performed)
     */
    @Override
    public Enumeration enumerateMeasures() {
        Vector<String> result = new Vector<String>();
        Point<Object> values = getBestValues();
        if (values != null) {
            for (int i = 0; i < values.dimensions(); i++) {
                if (values.getValue(i) instanceof Double) {
                    result.add("measure-" + i);
                }
            }
        }
        return result.elements();
    }

    /**
     * Returns the value of the named additional measure.
     *
     * @param str the measure name, of the form "measure-i"
     * @return the numeric value of dimension i of the best values
     * @throws IllegalArgumentException if the name is unsupported, no search has been performed yet,
     *     the index is out of range or the value is not numeric
     */
    @Override
    public double getMeasure(String str) {
        final String prefix = "measure-";
        if (!str.startsWith(prefix)) {
            throw new IllegalArgumentException("Measure '" + str + "' not supported!");
        }
        Point<Object> values = getBestValues();
        if (values == null) {
            throw new IllegalArgumentException("Measure '" + str + "' not available, no search performed yet!");
        }
        // NumberFormatException is an IllegalArgumentException, matching the contract
        int index = Integer.parseInt(str.substring(prefix.length()));
        if (index < 0 || index >= values.dimensions()) {
            throw new IllegalArgumentException("Measure '" + str + "' out of range (dimensions: " + values.dimensions() + ")!");
        }
        Object value = values.getValue(index);
        if (!(value instanceof Double)) {
            throw new IllegalArgumentException("Measure '" + str + "' is not numeric!");
        }
        return ((Double) value).doubleValue();
    }

    /**
     * Creates the factory for the evaluation metrics; subclasses may override it.
     *
     * @return the factory
     */
    protected AbstractEvaluationFactory newFactory() {
        return new DefaultEvaluationFactory();
    }

    /**
     * Returns the factory for the evaluation metrics.
     *
     * @return the factory
     */
    public AbstractEvaluationFactory getFactory() {
        return this.m_Factory;
    }

    /**
     * Returns the evaluation metrics.
     *
     * @return the metrics
     */
    public AbstractEvaluationMetrics getMetrics() {
        return this.m_Metrics;
    }

    /**
     * Returns the parameter values of the best setup.
     *
     * @return the values, null if no search has been performed yet
     */
    public Point<Object> getBestValues() {
        return this.m_BestClassifier.values;
    }

    /**
     * Returns the coordinates (values stored in the performance) of the best setup.
     *
     * @return the coordinates; only available after a search
     */
    public Point<Object> getBestCoordinates() {
        return this.m_BestClassifier.performance.getValues();
    }

    /**
     * Returns the capabilities of the base classifier, restricted to binary, nominal, numeric and date
     * classes and requiring at least one instance.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();

        // only nominal/binary, numeric and date classes are supported.
        // Iterate over a separate class-capabilities object so disabling does not disturb the iterator.
        Iterator iter = result.getClassCapabilities().capabilities();
        while (iter.hasNext()) {
            Capabilities.Capability capability = (Capabilities.Capability) iter.next();
            if (capability != Capabilities.Capability.BINARY_CLASS && capability != Capabilities.Capability.NOMINAL_CLASS && capability != Capabilities.Capability.NUMERIC_CLASS && capability != Capabilities.Capability.DATE_CLASS) {
                result.disable(capability);
            }
        }

        for (Capabilities.Capability capability : Capabilities.Capability.values()) {
            result.enableDependency(capability);
        }

        if (result.getMinimumNumberInstances() < 1) {
            result.setMinimumNumberInstances(1);
        }

        result.setOwner(this);
        return result;
    }

    /**
     * Returns the classname of the object, followed by its options if it is an {@link OptionHandler}.
     *
     * @param obj the object to turn into a command line; must not be null
     * @return the command line
     */
    public String getCommandline(Object obj) {
        String result = obj.getClass().getName();
        if (obj instanceof OptionHandler) {
            result = result + " " + Utils.joinOptions(((OptionHandler) obj).getOptions());
        }
        return result.trim();
    }

    /**
     * Logs the message to stdout (in debug mode) and to the log file (if one is set).
     *
     * @param str the message
     */
    public void log(String str) {
        log(str, false);
    }

    /**
     * Logs the message.
     *
     * @param str the message
     * @param onlyLog if true, the message is never printed to stdout, only written to the log file
     */
    public void log(String str, boolean onlyLog) {
        if (getDebug() && !onlyLog) {
            System.out.println(str);
        }
        File logFile = getLogFile();
        // no file (or a directory, the default) means: don't log to file
        if (logFile == null || logFile.isDirectory()) {
            return;
        }
        Debug.writeToFile(logFile.getAbsolutePath(), str, true);
    }

    /**
     * Formats the performances for the given metric, one value per line, after the search space.
     *
     * @param space the search space
     * @param vector the performances
     * @param tag the metric to output
     * @return the formatted string
     */
    public String logPerformances(Space space, Vector<Performance> vector, Tag tag) {
        StringBuilder result = new StringBuilder(tag.getReadable() + ":\n");
        result.append(space.toString());
        result.append("\n");
        for (int i = 0; i < vector.size(); i++) {
            result.append(vector.get(i).getPerformance(tag.getID()));
            result.append("\n");
        }
        result.append("\n");
        return result.toString();
    }

    /**
     * Writes the performances for all available metrics to the log file only.
     *
     * @param space the search space
     * @param vector the performances
     */
    public void logPerformances(Space space, Vector<Performance> vector) {
        for (int i = 0; i < this.m_Metrics.getTags().length; i++) {
            log("\n" + logPerformances(space, vector, this.m_Metrics.getTags()[i]), true);
        }
    }

    @Override // weka.classifiers.meta.multisearch.TraceableOptimizer
    public int getTraceSize() {
        return this.m_Trace.size();
    }

    @Override // weka.classifiers.meta.multisearch.TraceableOptimizer
    public String getTraceClassifierAsCli(int i) {
        return getCommandline(this.m_Trace.get(i).getValue().getClassifier());
    }

    @Override // weka.classifiers.meta.multisearch.TraceableOptimizer
    public Double getTraceValue(int i) {
        return Double.valueOf(this.m_Trace.get(i).getValue().getPerformance());
    }

    @Override // weka.classifiers.meta.multisearch.TraceableOptimizer
    public Integer getTraceFolds(int i) {
        return this.m_Trace.get(i).getKey();
    }

    @Override // weka.classifiers.meta.multisearch.TraceableOptimizer
    public List<Map.Entry<Integer, Performance>> getTrace() {
        return this.m_Trace;
    }

    /**
     * Splits the parameters into independently searched groups: either one group per
     * {@link ParameterGroup}, or a single group containing all parameters.
     *
     * @return the groups; never empty
     * @throws IllegalStateException if ParameterGroup is mixed with other parameter types
     */
    protected List<AbstractParameter[]> groupParameters() {
        List<AbstractParameter[]> result = new ArrayList<AbstractParameter[]>();

        int numGroups = 0;
        for (int i = 0; i < this.m_Parameters.length; i++) {
            if (this.m_Parameters[i] instanceof ParameterGroup) {
                numGroups++;
            }
        }
        if (numGroups > 0 && this.m_Parameters.length != numGroups) {
            throw new IllegalStateException("Cannot mix " + ParameterGroup.class.getName() + " with other parameter types!");
        }

        if (numGroups > 0) {
            for (int i = 0; i < this.m_Parameters.length; i++) {
                result.add(((ParameterGroup) this.m_Parameters[i]).getParameters());
            }
        } else {
            result.add(this.m_Parameters);
        }
        return result;
    }

    /**
     * Searches every parameter group, selects the overall best setup and trains it on the data.
     *
     * @param instances the training data (not modified; instances with missing class are ignored)
     * @throws Exception if the data cannot be handled, or searching or training fails
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);

        Instances data = new Instances(instances);
        data.deleteWithMissingClass();

        this.m_Trace.clear();

        List<AbstractParameter[]> groups = groupParameters();
        List<AbstractSearch.SearchResult> results = new ArrayList<AbstractSearch.SearchResult>();
        for (int i = 0; i < groups.size(); i++) {
            if (groups.size() > 1) {
                log("\n---> group #" + (i + 1));
            }
            this.m_Generator = new SetupGenerator();
            // NOTE(review): the base object is set twice; the classifier set last is what stays in
            // effect. Kept as-is in case setParameters relies on the first call - confirm.
            this.m_Generator.setBaseObject(this);
            this.m_Generator.setParameters((AbstractParameter[]) groups.get(i).clone());
            this.m_Generator.setBaseObject((Serializable) getClassifier());
            this.m_Algorithm.setOwner(this);
            results.add(this.m_Algorithm.search(data));
            this.m_Trace.addAll(this.m_Algorithm.getTrace());
        }

        // choose the best result across all groups (comparator < 0 means "better")
        AbstractSearch.SearchResult best = results.get(0);
        if (results.size() > 1) {
            PerformanceComparator comparator = new PerformanceComparator(getEvaluation().getSelectedTag().getID(), getMetrics());
            for (int i = 1; i < results.size(); i++) {
                AbstractSearch.SearchResult candidate = results.get(i);
                if (comparator.compare(candidate.performance, best.performance) < 0) {
                    best = candidate;
                }
            }
        }
        this.m_BestClassifier = best;

        log("\n---> train best - start");
        // log the best classifier's setup (the SearchResult holder itself has no options)
        log(getCommandline(this.m_BestClassifier.classifier));
        this.m_Classifier = AbstractClassifier.makeCopy(this.m_BestClassifier.classifier);
        this.m_Classifier.buildClassifier(data);
        log("\n---> train best - end");

        if (this.m_Debug) {
            log("\n---> Trace (format: #. folds/performance - setup)");
            for (int i = 0; i < getTraceSize(); i++) {
                log((i + 1) + ". " + getTraceFolds(i) + "/" + getTraceValue(i) + " - " + getTraceClassifierAsCli(i));
            }
        }
    }

    /**
     * Returns the class distribution predicted by the trained best classifier.
     *
     * @param instance the instance to classify
     * @return the class distribution
     * @throws Exception if prediction fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        return this.m_Classifier.distributionForInstance(instance);
    }

    /**
     * Returns a description of the best setup and the trained model (plus the trace in debug mode).
     *
     * @return the description
     */
    @Override
    public String toString() {
        StringBuilder result = new StringBuilder();
        if (getBestValues() == null) {
            result.append("No search performed yet.");
            return result.toString();
        }

        result.append(getClass().getName()).append(":\nClassifier: ").append(getCommandline(getBestClassifier())).append("\n\n");
        for (int i = 0; i < this.m_Parameters.length; i++) {
            result.append(i + 1).append(". parameter: ").append(this.m_Parameters[i]).append("\n");
        }
        result.append("Evaluation: ").append(getEvaluation().getSelectedTag().getReadable()).append("\nCoordinates: ").append(getBestCoordinates()).append("\n");
        result.append("Values: ").append(getBestValues()).append("\n\n").append(this.m_Classifier.toString());

        if (this.m_Debug) {
            result.append("\n\nTrace (format: #. folds/performance - setup):\n");
            for (int i = 0; i < getTraceSize(); i++) {
                result.append("\n").append(i + 1).append(". ").append(getTraceFolds(i)).append("/").append(getTraceValue(i)).append(" - ").append(getTraceClassifierAsCli(i));
            }
        }
        return result.toString();
    }

    /**
     * Returns a one-line summary with the best classifier's command line.
     *
     * @return the summary
     */
    @Override
    public String toSummaryString() {
        return "Best classifier: " + getCommandline(getBestClassifier());
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 4521 $");
    }

    /**
     * Runs the classifier from the command line.
     *
     * @param strArr the command-line options
     */
    public static void main(String[] strArr) {
        runClassifier(new MultiSearch(), strArr);
    }
}
