/*
 * Decompiled with CFR 0.152.
 */
package net.algart.executors.modules.opencv.matrices.ml.training;

import java.util.Locale;
import net.algart.executors.modules.opencv.matrices.ml.AbstractMLTrain;
import net.algart.executors.modules.opencv.matrices.ml.MLKind;
import net.algart.executors.modules.opencv.matrices.ml.MLSamplesType;
import net.algart.executors.modules.opencv.matrices.ml.MLStatModelTrainer;
import net.algart.executors.modules.util.opencv.OTools;
import org.bytedeco.opencv.opencv_core.TermCriteria;
import org.bytedeco.opencv.opencv_ml.SVMSGD;
import org.bytedeco.opencv.opencv_ml.StatModel;

/**
 * Trains an OpenCV {@link SVMSGD} (SVM optimized by stochastic gradient descent) model.
 *
 * <p>Instances are created via {@link #newTrainNumbers()} or {@link #newTrainPixels()}, which
 * select the {@link MLSamplesType} of the training samples. When
 * {@link #isOptimalParameters() optimalParameters} is {@code true} (the default), only the
 * {@link SVMSGDType} and {@link MarginType} are passed to OpenCV through
 * {@code SVMSGD.setOptimalParameters}; the explicit margin regularization, initial step size,
 * step-decreasing power and termination criteria configured here are NOT applied in that mode.
 */
public final class MLTrainSVMSGD extends AbstractMLTrain {
    private SVMSGDType svmSgdType = SVMSGDType.ASGD;
    private MarginType marginType = MarginType.SOFT_MARGIN;
    private boolean optimalParameters = true;
    private double marginRegularization = 1.0E-5;
    private double initialStepSize = 0.05;
    private double stepDecreasingPower = 0.75;
    // 0 / 0.0 are the defaults; converted by OTools.termCriteria in process()
    // (presumably meaning "no explicit limit" - TODO confirm against OTools).
    private int terminationMaxCount = 0;
    private double terminationEpsilon = 0.0;

    private MLTrainSVMSGD(MLSamplesType inputType) {
        super(inputType);
    }

    /**
     * Creates a trainer whose samples are of type {@link MLSamplesType#NUMBERS}.
     *
     * @return new trainer instance.
     */
    public static MLTrainSVMSGD newTrainNumbers() {
        return new MLTrainSVMSGD(MLSamplesType.NUMBERS);
    }

    /**
     * Creates a trainer whose samples are of type {@link MLSamplesType#PIXELS}.
     *
     * @return new trainer instance.
     */
    public static MLTrainSVMSGD newTrainPixels() {
        return new MLTrainSVMSGD(MLSamplesType.PIXELS);
    }

    public SVMSGDType getSvmSgdType() {
        return svmSgdType;
    }

    /**
     * Sets the SGD variant; must not be {@code null} (checked by {@code nonNull}).
     *
     * @param svmSgdType SGD algorithm variant.
     * @return this instance.
     */
    public MLTrainSVMSGD setSvmSgdType(SVMSGDType svmSgdType) {
        this.svmSgdType = (SVMSGDType) nonNull(svmSgdType);
        return this;
    }

    public MarginType getMarginType() {
        return marginType;
    }

    /**
     * Sets the margin type; must not be {@code null} (checked by {@code nonNull}).
     *
     * @param marginType soft or hard margin.
     * @return this instance.
     */
    public MLTrainSVMSGD setMarginType(MarginType marginType) {
        this.marginType = (MarginType) nonNull(marginType);
        return this;
    }

    public boolean isOptimalParameters() {
        return optimalParameters;
    }

    /**
     * Chooses between OpenCV's "optimal" parameter preset and the explicit parameters of this
     * object.
     *
     * @param optimalParameters {@code true} to call {@code setOptimalParameters} and ignore the
     *                          explicit regularization/step/termination settings.
     * @return this instance.
     */
    public MLTrainSVMSGD setOptimalParameters(boolean optimalParameters) {
        this.optimalParameters = optimalParameters;
        return this;
    }

    public double getMarginRegularization() {
        return marginRegularization;
    }

    /**
     * Sets the margin regularization; used only when {@code optimalParameters} is {@code false}.
     * The value is passed to OpenCV as {@code float}.
     *
     * @param marginRegularization regularization parameter.
     * @return this instance.
     */
    public MLTrainSVMSGD setMarginRegularization(double marginRegularization) {
        this.marginRegularization = marginRegularization;
        return this;
    }

    public double getInitialStepSize() {
        return initialStepSize;
    }

    /**
     * Sets the initial step size; used only when {@code optimalParameters} is {@code false}.
     * The value is passed to OpenCV as {@code float}.
     *
     * @param initialStepSize initial SGD step size.
     * @return this instance.
     */
    public MLTrainSVMSGD setInitialStepSize(double initialStepSize) {
        this.initialStepSize = initialStepSize;
        return this;
    }

    public double getStepDecreasingPower() {
        return stepDecreasingPower;
    }

    /**
     * Sets the step-decreasing power; used only when {@code optimalParameters} is {@code false}.
     * The value is passed to OpenCV as {@code float}.
     *
     * @param stepDecreasingPower power used to decrease the step size.
     * @return this instance.
     */
    public MLTrainSVMSGD setStepDecreasingPower(double stepDecreasingPower) {
        this.stepDecreasingPower = stepDecreasingPower;
        return this;
    }

    public int getTerminationMaxCount() {
        return terminationMaxCount;
    }

    /**
     * Sets the maximal iteration count of the termination criteria; used only when
     * {@code optimalParameters} is {@code false}.
     *
     * @param terminationMaxCount non-negative count (checked by {@code nonNegative}).
     * @return this instance.
     */
    public MLTrainSVMSGD setTerminationMaxCount(int terminationMaxCount) {
        this.terminationMaxCount = nonNegative(terminationMaxCount);
        return this;
    }

    public double getTerminationEpsilon() {
        return terminationEpsilon;
    }

    /**
     * Sets the epsilon of the termination criteria; used only when {@code optimalParameters}
     * is {@code false}.
     *
     * @param terminationEpsilon non-negative epsilon (checked by {@code nonNegative}).
     * @return this instance.
     */
    public MLTrainSVMSGD setTerminationEpsilon(double terminationEpsilon) {
        this.terminationEpsilon = nonNegative(terminationEpsilon);
        return this;
    }

    /**
     * Creates a new SVMSGD model, configures it from the current parameters, trains it and
     * writes the resulting trainer. Both the model and the termination criteria are native
     * objects and are released by try-with-resources.
     */
    public void process() {
        try (SVMSGD model = newStatModel();
             TermCriteria termCriteria = OTools.termCriteria(
                     terminationMaxCount, terminationEpsilon, true)) {
            if (optimalParameters) {
                // Only the algorithm/margin types are forwarded; the explicit numeric
                // parameters and termCriteria below are deliberately not applied.
                model.setOptimalParameters(svmSgdType.code(), marginType.code());
            } else {
                model.setSvmsgdType(svmSgdType.code());
                model.setMarginType(marginType.code());
                model.setMarginRegularization((float) marginRegularization);
                model.setInitialStepSize((float) initialStepSize);
                model.setStepDecreasingPower((float) stepDecreasingPower);
                // OTools.termCriteria may return null; in that case the model keeps its
                // current termination criteria.
                if (termCriteria != null) {
                    model.setTermCriteria(termCriteria);
                }
            }
            logDebug(() -> "Training SVMSGD: " + toString(model));
            MLStatModelTrainer trainer = new MLStatModelTrainer((StatModel) model, modelKind());
            setTrainingFlags(trainer);
            train(trainer);
            writeTrainer(trainer);
        }
    }

    /**
     * Returns a human-readable description of the SVMSGD model parameters (for logging).
     *
     * @param model the model to describe.
     * @return description string, formatted with {@link Locale#US}.
     */
    public static String toString(SVMSGD model) {
        // getTermCriteria() hands back a separate native object: close it promptly instead
        // of leaving it to GC-driven deallocation.
        try (TermCriteria termCriteria = model.getTermCriteria()) {
            return String.format(Locale.US,
                    "type=%s, margin=%s, marginRegularization=%s, initialStepSize=%s, "
                            + "stepDecreasingPower=%s, %s; result shift=%s",
                    model.getSvmsgdType(),
                    model.getMarginType(),
                    model.getMarginRegularization(),
                    model.getInitialStepSize(),
                    model.getStepDecreasingPower(),
                    OTools.toString(termCriteria),
                    model.getShift());
        }
    }

    @Override
    protected MLKind modelKind() {
        return MLKind.StatModelBased.SVM_SGD;
    }

    @Override
    protected boolean categoricalResponses() {
        return true;
    }

    /**
     * Creates a fresh, untrained SVMSGD model and logs its initial parameters.
     * The caller owns the returned native object and must close it.
     */
    private SVMSGD newStatModel() {
        SVMSGD result = SVMSGD.create();
        logDebug(() -> "Creating SVMSGD: " + toString(result));
        return result;
    }

    /**
     * SGD algorithm variant; {@link #code()} is the integer passed to
     * {@code SVMSGD.setSvmsgdType} / {@code setOptimalParameters}
     * (OpenCV {@code SVMSGD::SGD} = 0, {@code SVMSGD::ASGD} = 1).
     */
    public enum SVMSGDType {
        SGD(0),
        ASGD(1);

        private final int code;

        SVMSGDType(int code) {
            this.code = code;
        }

        public int code() {
            return code;
        }
    }

    /**
     * Margin type; {@link #code()} is the integer passed to {@code SVMSGD.setMarginType} /
     * {@code setOptimalParameters}
     * (OpenCV {@code SVMSGD::SOFT_MARGIN} = 0, {@code SVMSGD::HARD_MARGIN} = 1).
     */
    public enum MarginType {
        SOFT_MARGIN(0),
        HARD_MARGIN(1);

        private final int code;

        MarginType(int code) {
            this.code = code;
        }

        public int code() {
            return code;
        }
    }
}

