/*
 * Decompiled with CFR 0.152.
 */
package net.algart.executors.modules.opencv.matrices.ml.training;

import java.util.Locale;
import net.algart.executors.modules.opencv.matrices.ml.AbstractMLTrain;
import net.algart.executors.modules.opencv.matrices.ml.MLKind;
import net.algart.executors.modules.opencv.matrices.ml.MLSamplesType;
import net.algart.executors.modules.opencv.matrices.ml.MLStatModelTrainer;
import net.algart.executors.modules.opencv.matrices.ml.MLTrainer;
import net.algart.executors.modules.util.opencv.OTools;
import org.bytedeco.opencv.opencv_core.TermCriteria;
import org.bytedeco.opencv.opencv_ml.ParamGrid;
import org.bytedeco.opencv.opencv_ml.SVM;
import org.bytedeco.opencv.opencv_ml.StatModel;
import org.bytedeco.opencv.opencv_ml.TrainData;

public final class MLTrainSVM
extends AbstractMLTrain {
    public static final String OUTPUT_C = "c";
    public static final String OUTPUT_GAMMA = "gamma";
    public static final String OUTPUT_P = "p";
    public static final String OUTPUT_NU = "nu";
    public static final String OUTPUT_COEF = "coef";
    public static final String OUTPUT_DEGREE = "degree";
    private SVMType svmType = SVMType.C_SVC;
    private KernelType kernelType = KernelType.RBF;
    private double c = 1.0;
    private double gamma = 1.0;
    private double p = 0.0;
    private double nu = 0.0;
    private double coef = 0.0;
    private double degree = 0.0;
    private int terminationMaxCount = 0;
    private double terminationEpsilon = 0.0;
    private boolean autoTraining = false;
    private int kFold = 10;
    private boolean cGridCustom = false;
    private double cGridMin = 0.0;
    private double cGridMax = 1.0;
    private double cGridLogStep = 1.0;
    private boolean gammaGridCustom = false;
    private double gammaGridMin = 0.0;
    private double gammaGridMax = 1.0;
    private double gammaGridLogStep = 1.0;
    private boolean pGridCustom = false;
    private double pGridMin = 0.0;
    private double pGridMax = 1.0;
    private double pGridLogStep = 1.0;
    private boolean nuGridCustom = false;
    private double nuGridMin = 0.0;
    private double nuGridMax = 1.0;
    private double nuGridLogStep = 1.0;
    private boolean coefGridCustom = false;
    private double coefGridMin = 0.0;
    private double coefGridMax = 1.0;
    private double coefGridLogStep = 1.0;
    private boolean degreeGridCustom = false;
    private double degreeGridMin = 0.0;
    private double degreeGridMax = 1.0;
    private double degreeGridLogStep = 1.0;
    private boolean balanced = false;

    private MLTrainSVM(MLSamplesType inputType) {
        super(inputType);
        this.addOutputScalar(OUTPUT_C);
        this.addOutputScalar(OUTPUT_GAMMA);
        this.addOutputScalar(OUTPUT_P);
        this.addOutputScalar(OUTPUT_NU);
        this.addOutputScalar(OUTPUT_COEF);
        this.addOutputScalar(OUTPUT_DEGREE);
    }

    public static MLTrainSVM newTrainNumbers() {
        return new MLTrainSVM(MLSamplesType.NUMBERS);
    }

    public static MLTrainSVM newTrainPixels() {
        return new MLTrainSVM(MLSamplesType.PIXELS);
    }

    public SVMType getSvmType() {
        return this.svmType;
    }

    public MLTrainSVM setSvmType(SVMType svmType) {
        this.svmType = (SVMType)((Object)MLTrainSVM.nonNull((Object)((Object)svmType)));
        return this;
    }

    public KernelType getKernelType() {
        return this.kernelType;
    }

    public MLTrainSVM setKernelType(KernelType kernelType) {
        this.kernelType = (KernelType)((Object)MLTrainSVM.nonNull((Object)((Object)kernelType)));
        return this;
    }

    public double getC() {
        return this.c;
    }

    public MLTrainSVM setC(double c) {
        this.c = c;
        return this;
    }

    public double getGamma() {
        return this.gamma;
    }

    public MLTrainSVM setGamma(double gamma) {
        this.gamma = gamma;
        return this;
    }

    public double getP() {
        return this.p;
    }

    public MLTrainSVM setP(double p) {
        this.p = p;
        return this;
    }

    public double getNu() {
        return this.nu;
    }

    public MLTrainSVM setNu(double nu) {
        this.nu = nu;
        return this;
    }

    public double getCoef() {
        return this.coef;
    }

    public MLTrainSVM setCoef(double coef) {
        this.coef = coef;
        return this;
    }

    public double getDegree() {
        return this.degree;
    }

    public MLTrainSVM setDegree(double degree) {
        this.degree = degree;
        return this;
    }

    public int getTerminationMaxCount() {
        return this.terminationMaxCount;
    }

    public MLTrainSVM setTerminationMaxCount(int terminationMaxCount) {
        this.terminationMaxCount = MLTrainSVM.nonNegative((int)terminationMaxCount);
        return this;
    }

    public double getTerminationEpsilon() {
        return this.terminationEpsilon;
    }

    public MLTrainSVM setTerminationEpsilon(double terminationEpsilon) {
        this.terminationEpsilon = MLTrainSVM.nonNegative((double)terminationEpsilon);
        return this;
    }

    public boolean isAutoTraining() {
        return this.autoTraining;
    }

    public MLTrainSVM setAutoTraining(boolean autoTraining) {
        this.autoTraining = autoTraining;
        return this;
    }

    public int getKFold() {
        return this.kFold;
    }

    public MLTrainSVM setKFold(int kFold) {
        this.kFold = MLTrainSVM.nonNegative((int)kFold);
        return this;
    }

    public boolean isCGridCustom() {
        return this.cGridCustom;
    }

    public MLTrainSVM setCGridCustom(boolean cGridCustom) {
        this.cGridCustom = cGridCustom;
        return this;
    }

    public double getCGridMin() {
        return this.cGridMin;
    }

    public MLTrainSVM setCGridMin(double cGridMin) {
        this.cGridMin = cGridMin;
        return this;
    }

    public double getCGridMax() {
        return this.cGridMax;
    }

    public MLTrainSVM setCGridMax(double cGridMax) {
        this.cGridMax = cGridMax;
        return this;
    }

    public double getCGridLogStep() {
        return this.cGridLogStep;
    }

    public MLTrainSVM setCGridLogStep(double cGridLogStep) {
        this.cGridLogStep = MLTrainSVM.positive((double)cGridLogStep);
        return this;
    }

    public boolean isGammaGridCustom() {
        return this.gammaGridCustom;
    }

    public MLTrainSVM setGammaGridCustom(boolean gammaGridCustom) {
        this.gammaGridCustom = gammaGridCustom;
        return this;
    }

    public double getGammaGridMin() {
        return this.gammaGridMin;
    }

    public MLTrainSVM setGammaGridMin(double gammaGridMin) {
        this.gammaGridMin = gammaGridMin;
        return this;
    }

    public double getGammaGridMax() {
        return this.gammaGridMax;
    }

    public MLTrainSVM setGammaGridMax(double gammaGridMax) {
        this.gammaGridMax = gammaGridMax;
        return this;
    }

    public double getGammaGridLogStep() {
        return this.gammaGridLogStep;
    }

    public MLTrainSVM setGammaGridLogStep(double gammaGridLogStep) {
        this.gammaGridLogStep = MLTrainSVM.positive((double)gammaGridLogStep);
        return this;
    }

    public boolean isPGridCustom() {
        return this.pGridCustom;
    }

    public MLTrainSVM setPGridCustom(boolean pGridCustom) {
        this.pGridCustom = pGridCustom;
        return this;
    }

    public double getPGridMin() {
        return this.pGridMin;
    }

    public MLTrainSVM setPGridMin(double pGridMin) {
        this.pGridMin = pGridMin;
        return this;
    }

    public double getPGridMax() {
        return this.pGridMax;
    }

    public MLTrainSVM setPGridMax(double pGridMax) {
        this.pGridMax = pGridMax;
        return this;
    }

    public double getPGridLogStep() {
        return this.pGridLogStep;
    }

    public MLTrainSVM setPGridLogStep(double pGridLogStep) {
        this.pGridLogStep = MLTrainSVM.nonNegative((double)pGridLogStep);
        return this;
    }

    public boolean isNuGridCustom() {
        return this.nuGridCustom;
    }

    public MLTrainSVM setNuGridCustom(boolean nuGridCustom) {
        this.nuGridCustom = nuGridCustom;
        return this;
    }

    public double getNuGridMin() {
        return this.nuGridMin;
    }

    public MLTrainSVM setNuGridMin(double nuGridMin) {
        this.nuGridMin = nuGridMin;
        return this;
    }

    public double getNuGridMax() {
        return this.nuGridMax;
    }

    public MLTrainSVM setNuGridMax(double nuGridMax) {
        this.nuGridMax = nuGridMax;
        return this;
    }

    public double getNuGridLogStep() {
        return this.nuGridLogStep;
    }

    public MLTrainSVM setNuGridLogStep(double nuGridLogStep) {
        this.nuGridLogStep = MLTrainSVM.nonNegative((double)nuGridLogStep);
        return this;
    }

    public boolean isCoefGridCustom() {
        return this.coefGridCustom;
    }

    public MLTrainSVM setCoefGridCustom(boolean coefGridCustom) {
        this.coefGridCustom = coefGridCustom;
        return this;
    }

    public double getCoefGridMin() {
        return this.coefGridMin;
    }

    public MLTrainSVM setCoefGridMin(double coefGridMin) {
        this.coefGridMin = coefGridMin;
        return this;
    }

    public double getCoefGridMax() {
        return this.coefGridMax;
    }

    public MLTrainSVM setCoefGridMax(double coefGridMax) {
        this.coefGridMax = coefGridMax;
        return this;
    }

    public double getCoefGridLogStep() {
        return this.coefGridLogStep;
    }

    public MLTrainSVM setCoefGridLogStep(double coefGridLogStep) {
        this.coefGridLogStep = MLTrainSVM.nonNegative((double)coefGridLogStep);
        return this;
    }

    public boolean isDegreeGridCustom() {
        return this.degreeGridCustom;
    }

    public MLTrainSVM setDegreeGridCustom(boolean degreeGridCustom) {
        this.degreeGridCustom = degreeGridCustom;
        return this;
    }

    public double getDegreeGridMin() {
        return this.degreeGridMin;
    }

    public MLTrainSVM setDegreeGridMin(double degreeGridMin) {
        this.degreeGridMin = degreeGridMin;
        return this;
    }

    public double getDegreeGridMax() {
        return this.degreeGridMax;
    }

    public MLTrainSVM setDegreeGridMax(double degreeGridMax) {
        this.degreeGridMax = degreeGridMax;
        return this;
    }

    public double getDegreeGridLogStep() {
        return this.degreeGridLogStep;
    }

    public MLTrainSVM setDegreeGridLogStep(double degreeGridLogStep) {
        this.degreeGridLogStep = MLTrainSVM.nonNegative((double)degreeGridLogStep);
        return this;
    }

    public boolean isBalanced() {
        return this.balanced;
    }

    public MLTrainSVM setBalanced(boolean balanced) {
        this.balanced = balanced;
        return this;
    }

    public void process() {
        try (SVM model = this.newStatModel();
             TermCriteria termCriteria = OTools.termCriteria(this.terminationMaxCount, this.terminationEpsilon, true);){
            model.setType(this.svmType.code());
            model.setKernel(this.kernelType.code());
            model.setC(this.c);
            model.setGamma(this.gamma);
            model.setP(this.p);
            model.setNu(this.nu);
            model.setCoef0(this.coef);
            model.setDegree(this.degree);
            if (termCriteria != null) {
                model.setTermCriteria(termCriteria);
            }
            MLTrainSVM.logDebug(() -> "Training SVM: " + MLTrainSVM.toString(model));
            MLStatModelTrainer trainer = new MLStatModelTrainer((StatModel)model, this.modelKind());
            this.setTrainingFlags(trainer);
            this.train(trainer);
            this.writeTrainer(trainer);
        }
    }

    public static String toString(SVM model) {
        return String.format(Locale.US, "type=%s, kernel=%s, c=%s, gamma=%s, p=%s, nu=%s, coef=%s, degree=%s, %s", model.getType(), model.getKernelType(), model.getC(), model.getGamma(), model.getP(), model.getNu(), model.getCoef0(), model.getDegree(), OTools.toString(model.getTermCriteria()));
    }

    @Override
    protected MLKind modelKind() {
        return MLKind.StatModelBased.SVM;
    }

    @Override
    protected boolean categoricalResponses() {
        return true;
    }

    @Override
    protected void doTrain(MLTrainer trainer, TrainData trainData, int sampleLength, int responseLength) {
        SVM svm;
        block39: {
            assert (trainer instanceof MLStatModelTrainer) : "Illegal usage of doTrain method";
            StatModel model = ((MLStatModelTrainer)trainer).statModel();
            assert (model instanceof SVM) : "Illegal usage of doTrain method";
            svm = (SVM)model;
            if (this.autoTraining) {
                try (ParamGrid cGrid = MLTrainSVM.gridOrNull(this.cGridCustom, this.cGridMin, this.cGridMax, this.cGridLogStep);
                     ParamGrid gammaGrid = MLTrainSVM.gridOrNull(this.gammaGridCustom, this.gammaGridMin, this.gammaGridMax, this.gammaGridLogStep);
                     ParamGrid pGrid = MLTrainSVM.gridOrNull(this.pGridCustom, this.pGridMin, this.pGridMax, this.pGridLogStep);
                     ParamGrid nuGrid = MLTrainSVM.gridOrNull(this.nuGridCustom, this.nuGridMin, this.nuGridMax, this.nuGridLogStep);
                     ParamGrid coefGrid = MLTrainSVM.gridOrNull(this.coefGridCustom, this.coefGridMin, this.coefGridMax, this.coefGridLogStep);
                     ParamGrid degreeGrid = MLTrainSVM.gridOrNull(this.degreeGridCustom, this.degreeGridMin, this.degreeGridMax, this.degreeGridLogStep);){
                    MLTrainSVM.logDebug(() -> String.format("Auto-training SVM: %s%n  kFold=%d%n  %s%n  %s%n  %s%n  %s%n  %s%n  %s%n  balanced=%s", MLTrainSVM.toString(svm), this.kFold, MLTrainSVM.gridToString("cGrid", cGrid, svm, 0), MLTrainSVM.gridToString("gammaGrid", gammaGrid, svm, 1), MLTrainSVM.gridToString("pGrid", pGrid, svm, 2), MLTrainSVM.gridToString("nuGrid", nuGrid, svm, 3), MLTrainSVM.gridToString("coefGrid", coefGrid, svm, 4), MLTrainSVM.gridToString("degreeGrid", degreeGrid, svm, 5), this.balanced));
                    svm.trainAuto(trainData, this.kFold, cGrid, gammaGrid, pGrid, nuGrid, coefGrid, degreeGrid, this.balanced);
                    MLTrainSVM.logDebug(() -> String.format("Auto-training SVM result: %s", MLTrainSVM.toString(svm)));
                    break block39;
                }
            }
            super.doTrain(trainer, trainData, sampleLength, responseLength);
        }
        this.getScalar(OUTPUT_C).setTo(svm.getC());
        this.getScalar(OUTPUT_GAMMA).setTo(svm.getGamma());
        this.getScalar(OUTPUT_P).setTo(svm.getP());
        this.getScalar(OUTPUT_NU).setTo(svm.getNu());
        this.getScalar(OUTPUT_COEF).setTo(svm.getCoef0());
        this.getScalar(OUTPUT_DEGREE).setTo(svm.getDegree());
    }

    private SVM newStatModel() {
        SVM result = SVM.create();
        MLTrainSVM.logDebug(() -> "Creating SVM: " + MLTrainSVM.toString(result));
        return result;
    }

    private static ParamGrid gridOrNull(boolean custom, double min, double max, double logStep) {
        return custom ? new ParamGrid(min, max, logStep) : null;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private static String gridToString(String name, ParamGrid grid, SVM svm, int paramId) {
        boolean custom = grid != null;
        try {
            if (!custom && (grid = SVM.getDefaultGrid((int)paramId)) == null) {
                String string = "No default grid?";
                return string;
            }
            String string = String.format(Locale.US, "%s: %s min=%.5f, max=%.5f, logStep=%.3f", name, custom ? "(custom)" : "(default)", grid.minVal(), grid.maxVal(), grid.logStep());
            return string;
        }
        finally {
            if (!custom && grid != null) {
                grid.close();
            }
        }
    }

    public static enum SVMType {
        C_SVC(100),
        NU_SVC(101),
        ONE_CLASS(102),
        EPS_SVR(103),
        NU_SVR(104);

        private final int code;

        public int code() {
            return this.code;
        }

        private SVMType(int code) {
            this.code = code;
        }
    }

    public static enum KernelType {
        LINEAR(0),
        POLY(1),
        RBF(2),
        SIGMOID(3),
        CHI2(4),
        INTER(5);

        private final int code;

        public int code() {
            return this.code;
        }

        private KernelType(int code) {
            this.code = code;
        }
    }
}

