package weka.classifiers.functions;

import com.microsoft.ml.lightgbm.PredictionType;
import io.github.metarank.lightgbm4j.LGBMBooster;
import io.github.metarank.lightgbm4j.LGBMDataset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.Locale;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.RandomizableClassifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;

/* loaded from: input_file:weka/classifiers/functions/LightGBM.class */
public class LightGBM extends RandomizableClassifier implements TechnicalInformationHandler, AutoCloseable {
    private static final long serialVersionUID = -6138516902729782286L;
    public static final String VERSION = "3.3.2";
    public static final String PARAMETERS_URL = "https://lightgbm.readthedocs.io/en/v3.3.2/Parameters.html";
    public static final int OBJECTIVE_REGRESSION = 0;
    public static final int OBJECTIVE_REGRESSION_L1 = 1;
    public static final int OBJECTIVE_HUBER = 2;
    public static final int OBJECTIVE_FAIR = 3;
    public static final int OBJECTIVE_POISSON = 4;
    public static final int OBJECTIVE_QUANTILE = 5;
    public static final int OBJECTIVE_MAPE = 6;
    public static final int OBJECTIVE_GAMMA = 7;
    public static final int OBJECTIVE_TWEEDIE = 8;
    public static final int OBJECTIVE_BINARY = 9;
    public static final int OBJECTIVE_MULTICLASS = 10;
    public static final int OBJECTIVE_MULTICLASSOVA = 11;
    public static final int OBJECTIVE_CROSS_ENTROPY = 12;
    public static final int OBJECTIVE_CROSS_ENTROPY_LAMBDA = 13;
    public static final int OBJECTIVE_LAMBDARANK = 14;
    public static final int OBJECTIVE_RANK_XENDCG = 15;
    public static final Tag[] TAGS_OBJECTIVE = {new Tag(0, "REGRESSION", "Regression"), new Tag(1, "REGRESSION_L1", "Regression L1"), new Tag(2, "HUBER", "Huber loss"), new Tag(3, "FAIR", "Fair loss"), new Tag(4, "POISSON", "Poisson regression"), new Tag(5, "QUANTILE", "Quantile regression"), new Tag(6, "MAPE", "MAPE loss"), new Tag(7, "GAMMA", "Gamma regression with log-link"), new Tag(8, "TWEEDIE", "Tweedie regression with log-link"), new Tag(9, "BINARY", "Binary log loss classification"), new Tag(10, "MULTICLASS", "Multi-class (softmax)"), new Tag(11, "MULTICLASSOVA", "Multi-class (one-vs-all)"), new Tag(12, "CROSSENTROPY", "Cross-entropy"), new Tag(13, "CROSSENTROPY_LAMBDA", "Cross-entropy Lambda"), new Tag(14, "LAMBDA_RANK", "Lambda rank"), new Tag(15, "RANK_XENDCG", "Rank Xendcg")};
    protected String m_ActualParameters;
    protected boolean m_NumericClass;
    protected int m_Objective = 0;
    protected String m_Parameters = "";
    protected int m_NumIterations = 1000;
    protected double m_ValidationPercentage = 0.0d;
    protected boolean m_RandomizeBeforeSplit = false;
    protected transient LGBMBooster m_Booster = null;
    protected String m_Model = null;

    public String globalInfo() {
        return "LightGBM (https://github.com/microsoft/LightGBM) is a gradient boosting framework that uses tree based learning algorithms. It is designed to be distributed and efficient.\n\nInformation on parameters:\nhttps://lightgbm.readthedocs.io/en/v3.3.2/Parameters.html\nThe following parameters get filled in automatically:\n- objective\n- categorical_features\n\nFor more information see:\n\n" + getTechnicalInformation().toString();
    }

    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation technicalInformation = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        technicalInformation.setValue(TechnicalInformation.Field.AUTHOR, "Ke, Guolin and Meng, Qi and Finley, Thomas and Wang, Taifeng and Chen, Wei and Ma, Weidong and Ye, Qiwei and Liu, Tie-Yan");
        technicalInformation.setValue(TechnicalInformation.Field.YEAR, "2017");
        technicalInformation.setValue(TechnicalInformation.Field.TITLE, "LightGBM: A Highly Efficient Gradient Boosting Decision Tree");
        technicalInformation.setValue(TechnicalInformation.Field.BOOKTITLE, "Advances in Neural Information Processing Systems");
        technicalInformation.setValue(TechnicalInformation.Field.PUBLISHER, "Curran Associates, Inc.");
        technicalInformation.setValue(TechnicalInformation.Field.EDITOR, "I. Guyon and U. Von Luxburg and S. Bengio and H. Wallach and R. Fergus and S. Vishwanathan and R. Garnett");
        technicalInformation.setValue(TechnicalInformation.Field.PAGES, "3149-3157");
        technicalInformation.setValue(TechnicalInformation.Field.URL, "https://proceedings.neurips.cc/paper/2017/file/6449f44a102fde848669bdd9eb6b76fa-Paper.pdf");
        return technicalInformation;
    }

    public Enumeration listOptions() {
        Vector vector = new Vector();
        String str = "";
        for (int i = 0; i < TAGS_OBJECTIVE.length; i++) {
            SelectedTag selectedTag = new SelectedTag(TAGS_OBJECTIVE[i].getID(), TAGS_OBJECTIVE);
            str = str + "\t" + selectedTag.getSelectedTag().getIDStr() + " = " + selectedTag.getSelectedTag().getReadable() + "\n";
        }
        vector.addElement(new Option("\tThe type of booster to use:\n" + str + "\t(default: " + new SelectedTag(0, TAGS_OBJECTIVE) + ")", "O", 1, "-O " + Tag.toOptionList(TAGS_OBJECTIVE)));
        vector.addElement(new Option("\tThe parameters for the booster (blank-separated key=value pairs).\n\tSee: https://lightgbm.readthedocs.io/en/v3.3.2/Parameters.html\n\t(default: none)\n", "P", 1, "-P <parameters>"));
        vector.addElement(new Option("\tThe number of iterations to train for.\n\t(default: 1000)\n", "I", 1, "-I <iterations>"));
        vector.addElement(new Option("\tThe size of the validation set to split off from the training set.\n\t(default: 0.0)\n", "V", 1, "-V <0-100>"));
        vector.addElement(new Option("\tTurns on randomization before splitting off the validation set.\n\t(default: off)\n", "R", 0, "-R"));
        return vector.elements();
    }

    public void setOptions(String[] strArr) throws Exception {
        String option = Utils.getOption('O', strArr);
        if (option.length() != 0) {
            setObjective(new SelectedTag(option, TAGS_OBJECTIVE));
        } else {
            setObjective(new SelectedTag(0, TAGS_OBJECTIVE));
        }
        setParameters(Utils.getOption('P', strArr));
        String option2 = Utils.getOption('I', strArr);
        if (option2.length() != 0) {
            setNumIterations(Integer.parseInt(option2));
        } else {
            setNumIterations(1000);
        }
        String option3 = Utils.getOption('V', strArr);
        if (option3.length() != 0) {
            setValidationPercentage(Double.parseDouble(option3));
        } else {
            setValidationPercentage(0.0d);
        }
        setRandomizeBeforeSplit(Utils.getFlag('R', strArr));
        super.setOptions(strArr);
    }

    public String[] getOptions() {
        ArrayList arrayList = new ArrayList(Arrays.asList(super.getOptions()));
        arrayList.add("-O");
        arrayList.add("" + getObjective());
        if (!getParameters().trim().isEmpty()) {
            arrayList.add("-P");
            arrayList.add("" + getParameters());
        }
        arrayList.add("-I");
        arrayList.add("" + getNumIterations());
        if (getValidationPercentage() > 0.0d) {
            arrayList.add("-V");
            arrayList.add("" + getValidationPercentage());
        }
        if (getRandomizeBeforeSplit()) {
            arrayList.add("-R");
        }
        return (String[]) arrayList.toArray(new String[0]);
    }

    public void setObjective(SelectedTag selectedTag) {
        if (selectedTag.getTags() == TAGS_OBJECTIVE) {
            this.m_Objective = selectedTag.getSelectedTag().getID();
        }
    }

    public SelectedTag getObjective() {
        return new SelectedTag(this.m_Objective, TAGS_OBJECTIVE);
    }

    public String objectiveTipText() {
        return "Sets the type of booster to use.";
    }

    public void setParameters(String str) {
        this.m_Parameters = str;
    }

    public String getParameters() {
        return this.m_Parameters;
    }

    public String parametersTipText() {
        return "Sets the parameters to use (blank-separated key=value pairs), see: https://lightgbm.readthedocs.io/en/v3.3.2/Parameters.html";
    }

    public void setNumIterations(int i) {
        if (i > 0) {
            this.m_NumIterations = i;
        }
    }

    public int getNumIterations() {
        return this.m_NumIterations;
    }

    public String numIterationsTipText() {
        return "Sets the number of iterations to train for.";
    }

    public void setValidationPercentage(double d) {
        if (d < 0.0d || d >= 100.0d) {
            return;
        }
        this.m_ValidationPercentage = d;
    }

    public double getValidationPercentage() {
        return this.m_ValidationPercentage;
    }

    public String validationPercentageTipText() {
        return "Sets the percentage to split off the training set for using as validation set during training (0 <= x < 100).";
    }

    public void setRandomizeBeforeSplit(boolean z) {
        this.m_RandomizeBeforeSplit = z;
    }

    public boolean getRandomizeBeforeSplit() {
        return this.m_RandomizeBeforeSplit;
    }

    public String randomizeBeforeSplitTipText() {
        return "If enabled, the data gets randomized before splitting off the validation set.";
    }

    public Capabilities getCapabilities() {
        Capabilities capabilities = new Capabilities(this);
        capabilities.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        switch (this.m_Objective) {
            case OBJECTIVE_BINARY /* 9 */:
                capabilities.enable(Capabilities.Capability.BINARY_CLASS);
                capabilities.disable(Capabilities.Capability.UNARY_CLASS);
                break;
            case OBJECTIVE_MULTICLASS /* 10 */:
            case OBJECTIVE_MULTICLASSOVA /* 11 */:
                capabilities.enable(Capabilities.Capability.NOMINAL_CLASS);
                capabilities.disable(Capabilities.Capability.BINARY_CLASS);
                capabilities.disable(Capabilities.Capability.UNARY_CLASS);
                break;
            default:
                capabilities.enable(Capabilities.Capability.NUMERIC_CLASS);
                break;
        }
        capabilities.setMinimumNumberInstances(1);
        return capabilities;
    }

    public void buildClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);
        Instances instances2 = new Instances(instances);
        instances2.deleteWithMissingClass();
        this.m_NumericClass = instances2.classAttribute().isNumeric();
        Instances instances3 = instances2;
        Instances instances4 = null;
        if (this.m_ValidationPercentage > 0.0d) {
            if (this.m_RandomizeBeforeSplit) {
                instances2.randomize(new Random(this.m_Seed));
            }
            int round = (int) Math.round((instances2.size() * this.m_ValidationPercentage) / 100.0d);
            instances3 = new Instances(instances2, instances2.numInstances() - round);
            instances4 = new Instances(instances2, round);
            for (int i = 0; i < instances2.numInstances(); i++) {
                if (i < instances2.numInstances() - round) {
                    instances3.add((Instance) instances2.instance(i).copy());
                } else {
                    instances4.add((Instance) instances2.instance(i).copy());
                }
            }
            if (getDebug()) {
                System.out.println("train size: " + instances3.numInstances() + ", validation size: " + instances4.numInstances());
            }
        }
        StringBuilder sb = new StringBuilder();
        for (int i2 = 0; i2 < instances2.numAttributes(); i2++) {
            if (i2 != instances2.classIndex() && instances2.attribute(i2).isNominal()) {
                if (sb.length() > 0) {
                    sb.append(",");
                }
                sb.append(i2);
            }
        }
        LGBMDataset fromInstances = LightGBMUtils.fromInstances(instances3);
        LGBMDataset fromInstances2 = instances4 != null ? LightGBMUtils.fromInstances(instances4, fromInstances) : null;
        this.m_ActualParameters = "objective=" + getObjective().getSelectedTag().getIDStr().toLowerCase() + " label=name:" + instances2.classAttribute().name();
        if (sb.length() > 0) {
            this.m_ActualParameters += " categorical_features=" + sb.toString();
        }
        if (!this.m_Parameters.isEmpty()) {
            this.m_ActualParameters += " " + this.m_Parameters;
        }
        if (getDebug()) {
            System.out.println("Actual parameters: " + this.m_ActualParameters);
        }
        try {
            try {
                this.m_Booster = LGBMBooster.create(fromInstances, this.m_ActualParameters);
                if (fromInstances2 != null) {
                    this.m_Booster.addValidData(fromInstances2);
                }
                int i3 = 0;
                while (true) {
                    if (i3 >= this.m_NumIterations) {
                        break;
                    }
                    if (this.m_Booster.updateOneIter()) {
                        System.out.println("No more splits possible, stopping training at iteration " + (i3 + 1) + " out of " + this.m_NumIterations);
                        break;
                    }
                    i3++;
                }
                this.m_Model = this.m_Booster.saveModelToString(0, 0, LGBMBooster.FeatureImportanceType.GAIN);
                fromInstances.close();
                if (fromInstances2 != null) {
                    fromInstances2.close();
                }
            } catch (Exception e) {
                if (this.m_Booster != null) {
                    this.m_Booster.close();
                }
                fromInstances.close();
                if (fromInstances2 != null) {
                    fromInstances2.close();
                }
            }
        } catch (Throwable th) {
            fromInstances.close();
            if (fromInstances2 != null) {
                fromInstances2.close();
            }
            throw th;
        }
    }

    protected void initBooster() throws Exception {
        if (this.m_Booster == null) {
            if (this.m_Model == null) {
                throw new IllegalStateException("No model trained?");
            }
            this.m_Booster = LGBMBooster.loadModelFromString(this.m_Model);
        }
    }

    public double classifyInstance(Instance instance) throws Exception {
        initBooster();
        double predictForMatSingleRow = this.m_Booster.predictForMatSingleRow(LightGBMUtils.fromInstance(instance), PredictionType.C_API_PREDICT_NORMAL);
        if (!this.m_NumericClass) {
            predictForMatSingleRow = Math.round(predictForMatSingleRow);
        }
        return predictForMatSingleRow;
    }

    public String toString() {
        StringBuilder sb = new StringBuilder();
        if (this.m_Model == null) {
            sb.append("No model built yet.");
        } else {
            sb.append("LightGBM\n");
            sb.append("========\n\n");
            sb.append("Actual parameters: ").append(this.m_ActualParameters).append("\n");
            sb.append("Model:\n");
            sb.append(this.m_Model);
        }
        return sb.toString();
    }

    @Override // java.lang.AutoCloseable
    public void close() throws Exception {
        if (this.m_Booster != null) {
            this.m_Booster.close();
            this.m_Booster = null;
        }
    }

    public static void main(String[] strArr) {
        runClassifier(new LightGBM(), strArr);
    }
}
