/*
 * Decompiled with CFR 0.152.
 */
package net.algart.executors.modules.opencv.matrices.ml.training;

import java.util.Locale;
import net.algart.executors.api.data.SNumbers;
import net.algart.executors.api.data.SScalar;
import net.algart.executors.modules.opencv.matrices.ml.AbstractMLTrain;
import net.algart.executors.modules.opencv.matrices.ml.MLKind;
import net.algart.executors.modules.opencv.matrices.ml.MLSamplesType;
import net.algart.executors.modules.opencv.matrices.ml.MLStatModelTrainer;
import net.algart.executors.modules.opencv.matrices.ml.MLTrainer;
import net.algart.executors.modules.opencv.util.O2SMat;
import net.algart.executors.modules.opencv.util.OTools;
import org.bytedeco.opencv.opencv_core.Mat;
import org.bytedeco.opencv.opencv_core.RNG;
import org.bytedeco.opencv.opencv_core.TermCriteria;
import org.bytedeco.opencv.opencv_ml.ANN_MLP;
import org.bytedeco.opencv.opencv_ml.StatModel;
import org.bytedeco.opencv.opencv_ml.TrainData;

public final class MLTrainANNMLP
extends AbstractMLTrain {
    public static final String OUTPUT_LAYER_WEIGHTS = "layer_weights";
    private TrainingMethod trainingMethod = TrainingMethod.RPROP;
    private double trainingMethodParam1 = 0.0;
    private double trainingMethodParam2 = 0.0;
    private ActivationFunction activationFunction = ActivationFunction.SIGMOID_SYM;
    private double activationFunctionParam1 = 0.0;
    private double activationFunctionParam2 = 0.0;
    private int[] hiddenLayerSizes = new int[0];
    private double backpropMomentumScale = 0.1;
    private double backpropWeightScale = 0.1;
    private double rpropDW0 = 0.1;
    private double rpropDWMax = 50.0;
    private double rpropDWMin = 1.4E-45f;
    private double rpropDWMinus = 0.5;
    private double rpropDWPlus = 1.2;
    private double annealCoolingRatio = 0.95;
    private double annealFinalT = 0.1;
    private double annealInitialT = 10.0;
    private int annealItePerStep = 10;
    private Integer annealEnergyRandSeed = null;
    private int terminationMaxCount = 0;
    private double terminationEpsilon = 0.0;
    private int layerIndexToGetWeights = 0;

    private MLTrainANNMLP(MLSamplesType inputType) {
        super(inputType);
        this.addOutputNumbers(OUTPUT_LAYER_WEIGHTS);
    }

    public static MLTrainANNMLP newTrainNumbers() {
        return new MLTrainANNMLP(MLSamplesType.NUMBERS);
    }

    public static MLTrainANNMLP newTrainPixels() {
        return new MLTrainANNMLP(MLSamplesType.PIXELS);
    }

    public TrainingMethod getTrainingMethod() {
        return this.trainingMethod;
    }

    public MLTrainANNMLP setTrainingMethod(TrainingMethod trainingMethod) {
        this.trainingMethod = (TrainingMethod)((Object)MLTrainANNMLP.nonNull((Object)((Object)trainingMethod)));
        return this;
    }

    public double getTrainingMethodParam1() {
        return this.trainingMethodParam1;
    }

    public MLTrainANNMLP setTrainingMethodParam1(double trainingMethodParam1) {
        this.trainingMethodParam1 = trainingMethodParam1;
        return this;
    }

    public double getTrainingMethodParam2() {
        return this.trainingMethodParam2;
    }

    public MLTrainANNMLP setTrainingMethodParam2(double trainingMethodParam2) {
        this.trainingMethodParam2 = trainingMethodParam2;
        return this;
    }

    public ActivationFunction getActivationFunction() {
        return this.activationFunction;
    }

    public MLTrainANNMLP setActivationFunction(ActivationFunction activationFunction) {
        this.activationFunction = (ActivationFunction)((Object)MLTrainANNMLP.nonNull((Object)((Object)activationFunction)));
        return this;
    }

    public double getActivationFunctionParam1() {
        return this.activationFunctionParam1;
    }

    public MLTrainANNMLP setActivationFunctionParam1(double activationFunctionParam1) {
        this.activationFunctionParam1 = activationFunctionParam1;
        return this;
    }

    public double getActivationFunctionParam2() {
        return this.activationFunctionParam2;
    }

    public MLTrainANNMLP setActivationFunctionParam2(double activationFunctionParam2) {
        this.activationFunctionParam2 = activationFunctionParam2;
        return this;
    }

    public int[] getHiddenLayerSizes() {
        return (int[])this.hiddenLayerSizes.clone();
    }

    public MLTrainANNMLP setHiddenLayerSizes(int[] hiddenLayerSizes) {
        this.hiddenLayerSizes = (int[])((int[])MLTrainANNMLP.nonNull((Object)hiddenLayerSizes)).clone();
        return this;
    }

    public MLTrainANNMLP setHiddenLayerSizes(String layerSizes) {
        return this.setHiddenLayerSizes(new SScalar((String)MLTrainANNMLP.nonNull((Object)layerSizes)).toInts());
    }

    public double getBackpropMomentumScale() {
        return this.backpropMomentumScale;
    }

    public MLTrainANNMLP setBackpropMomentumScale(double backpropMomentumScale) {
        this.backpropMomentumScale = backpropMomentumScale;
        return this;
    }

    public double getBackpropWeightScale() {
        return this.backpropWeightScale;
    }

    public MLTrainANNMLP setBackpropWeightScale(double backpropWeightScale) {
        this.backpropWeightScale = backpropWeightScale;
        return this;
    }

    public double getRpropDW0() {
        return this.rpropDW0;
    }

    public MLTrainANNMLP setRpropDW0(double rpropDW0) {
        this.rpropDW0 = rpropDW0;
        return this;
    }

    public double getRpropDWMax() {
        return this.rpropDWMax;
    }

    public MLTrainANNMLP setRpropDWMax(double rpropDWMax) {
        this.rpropDWMax = rpropDWMax;
        return this;
    }

    public double getRpropDWMin() {
        return this.rpropDWMin;
    }

    public MLTrainANNMLP setRpropDWMin(double rpropDWMin) {
        this.rpropDWMin = rpropDWMin;
        return this;
    }

    public double getRpropDWMinus() {
        return this.rpropDWMinus;
    }

    public MLTrainANNMLP setRpropDWMinus(double rpropDWMinus) {
        this.rpropDWMinus = rpropDWMinus;
        return this;
    }

    public double getRpropDWPlus() {
        return this.rpropDWPlus;
    }

    public MLTrainANNMLP setRpropDWPlus(double rpropDWPlus) {
        this.rpropDWPlus = rpropDWPlus;
        return this;
    }

    public double getAnnealCoolingRatio() {
        return this.annealCoolingRatio;
    }

    public MLTrainANNMLP setAnnealCoolingRatio(double annealCoolingRatio) {
        this.annealCoolingRatio = annealCoolingRatio;
        return this;
    }

    public double getAnnealFinalT() {
        return this.annealFinalT;
    }

    public MLTrainANNMLP setAnnealFinalT(double annealFinalT) {
        this.annealFinalT = annealFinalT;
        return this;
    }

    public double getAnnealInitialT() {
        return this.annealInitialT;
    }

    public MLTrainANNMLP setAnnealInitialT(double annealInitialT) {
        this.annealInitialT = annealInitialT;
        return this;
    }

    public int getAnnealItePerStep() {
        return this.annealItePerStep;
    }

    public MLTrainANNMLP setAnnealItePerStep(int annealItePerStep) {
        this.annealItePerStep = annealItePerStep;
        return this;
    }

    public Integer getAnnealEnergyRandSeed() {
        return this.annealEnergyRandSeed;
    }

    public MLTrainANNMLP setAnnealEnergyRandSeed(Integer annealEnergyRandSeed) {
        this.annealEnergyRandSeed = annealEnergyRandSeed;
        return this;
    }

    public boolean isUpdateWeights() {
        return this.getTrainingFlagByMask(1);
    }

    public MLTrainANNMLP setUpdateWeights(boolean updateWeights) {
        this.setTrainingFlagByMask(1, updateWeights);
        return this;
    }

    public boolean isNoInputScale() {
        return this.getTrainingFlagByMask(2);
    }

    public MLTrainANNMLP setNoInputScale(boolean noInputScale) {
        this.setTrainingFlagByMask(2, noInputScale);
        return this;
    }

    public boolean isNoOutputScale() {
        return this.getTrainingFlagByMask(4);
    }

    public MLTrainANNMLP setNoOutputScale(boolean noOutputScale) {
        this.setTrainingFlagByMask(4, noOutputScale);
        return this;
    }

    public int getTerminationMaxCount() {
        return this.terminationMaxCount;
    }

    public MLTrainANNMLP setTerminationMaxCount(int terminationMaxCount) {
        this.terminationMaxCount = MLTrainANNMLP.nonNegative((int)terminationMaxCount);
        return this;
    }

    public double getTerminationEpsilon() {
        return this.terminationEpsilon;
    }

    public MLTrainANNMLP setTerminationEpsilon(double terminationEpsilon) {
        this.terminationEpsilon = MLTrainANNMLP.nonNegative((double)terminationEpsilon);
        return this;
    }

    public int getLayerIndexToGetWeights() {
        return this.layerIndexToGetWeights;
    }

    public MLTrainANNMLP setLayerIndexToGetWeights(int layerIndexToGetWeights) {
        this.layerIndexToGetWeights = MLTrainANNMLP.nonNegative((int)layerIndexToGetWeights);
        return this;
    }

    public void process() {
        try (ANN_MLP model = this.newStatModel();
             TermCriteria termCriteria = OTools.termCriteria(this.terminationMaxCount, this.terminationEpsilon, true);
             RNG rng = this.annealEnergyRandSeed != null ? new RNG((long)this.annealEnergyRandSeed.intValue()) : null;){
            model.setTrainMethod(this.trainingMethod.code(), this.trainingMethodParam1, this.trainingMethodParam2);
            model.setBackpropMomentumScale(this.backpropMomentumScale);
            model.setBackpropWeightScale(this.backpropWeightScale);
            model.setRpropDW0(this.rpropDW0);
            model.setRpropDWMax(this.rpropDWMax);
            model.setRpropDWMin(this.rpropDWMin);
            model.setRpropDWMinus(this.rpropDWMinus);
            model.setRpropDWPlus(this.rpropDWPlus);
            model.setAnnealCoolingRatio(this.annealCoolingRatio);
            model.setAnnealFinalT(this.annealFinalT);
            model.setAnnealInitialT(this.annealInitialT);
            model.setAnnealItePerStep(this.annealItePerStep);
            if (termCriteria != null) {
                model.setTermCriteria(termCriteria);
            }
            if (rng != null) {
                model.setAnnealEnergyRNG(rng);
            }
            MLTrainANNMLP.logDebug(() -> "Training ANN_MLP: " + MLTrainANNMLP.toString(model));
            MLStatModelTrainer trainer = new MLStatModelTrainer((StatModel)model, this.modelKind());
            this.setTrainingFlags(trainer);
            this.train(trainer);
            this.writeTrainer(trainer);
            try (Mat layerWeights = model.getWeights(this.layerIndexToGetWeights);){
                this.getNumbers(OUTPUT_LAYER_WEIGHTS).exchange(O2SMat.toRawNumbers(layerWeights, 1));
            }
        }
    }

    @Override
    protected void doTrain(MLTrainer trainer, TrainData trainData, int sampleLength, int responseLength) {
        assert (trainer instanceof MLStatModelTrainer) : "Illegal usage of doTrain method";
        StatModel model = ((MLStatModelTrainer)trainer).statModel();
        assert (model instanceof ANN_MLP) : "Illegal usage of doTrain method";
        ANN_MLP ann = (ANN_MLP)model;
        try (Mat layerSizesMat = this.layerSizes(sampleLength, responseLength);){
            ann.setLayerSizes(layerSizesMat);
            ann.setActivationFunction(this.activationFunction.code(), this.activationFunctionParam1, this.activationFunctionParam2);
            super.doTrain(trainer, trainData, sampleLength, responseLength);
        }
    }

    public static String toString(ANN_MLP model) {
        return String.format(Locale.US, "method=%s, backpropMomentumScale=%s, backpropWeightScale=%s, rpropDW0=%s, rpropDWMax=%s, rpropDWMin=%s, rpropDWMinus=%s, rpropDWPlus=%s, annealCoolingRatio=%s, annealFinalT=%s, annealInitialT=%s, annealItePerStep=%s, %s", model.getTrainMethod(), model.getBackpropMomentumScale(), model.getBackpropWeightScale(), model.getRpropDW0(), model.getRpropDWMax(), model.getRpropDWMin(), model.getRpropDWMinus(), model.getRpropDWPlus(), model.getAnnealCoolingRatio(), model.getAnnealFinalT(), model.getAnnealInitialT(), model.getAnnealItePerStep(), OTools.toString(model.getTermCriteria()));
    }

    @Override
    protected MLKind modelKind() {
        return MLKind.StatModelBased.ANN_MLP;
    }

    @Override
    protected boolean categoricalResponses() {
        return false;
    }

    private ANN_MLP newStatModel() {
        ANN_MLP result = ANN_MLP.create();
        MLTrainANNMLP.logDebug(() -> "Creating ANN_MLP: " + MLTrainANNMLP.toString(result));
        return result;
    }

    private Mat layerSizes(int sampleLength, int responseLength) {
        int[] layerSizes = new int[this.hiddenLayerSizes.length + 2];
        System.arraycopy(this.hiddenLayerSizes, 0, layerSizes, 1, this.hiddenLayerSizes.length);
        layerSizes[0] = sampleLength;
        layerSizes[layerSizes.length - 1] = responseLength;
        SNumbers result = SNumbers.valueOfArray((Object)layerSizes, (int)1);
        return O2SMat.numbersToMulticolumn32BitMat(result, true);
    }

    public static void main(String[] args) {
        ANN_MLP model = ANN_MLP.create();
        int[] layersSizes = new int[]{5};
        MLTrainANNMLP training = new MLTrainANNMLP(MLSamplesType.NUMBERS);
        training.setHiddenLayerSizes(layersSizes);
        training.setUseGPU(false);
        training.trainNumbers(new MLStatModelTrainer((StatModel)model, MLKind.StatModelBased.ANN_MLP), SNumbers.valueOfArray((Object)new float[]{1.0f, 1.0f}, (int)1), SNumbers.valueOfArray((Object)new float[]{1.0f, 1.0f}, (int)1), null);
        System.out.println("OK 2");
        for (int k = 0; k < layersSizes.length + 2; ++k) {
            Mat weightsMat = model.getWeights(k);
            SNumbers weights = O2SMat.toRawNumbers(weightsMat, weightsMat.cols());
            System.out.printf("Weights for layer #%d: %dx%d%n%s%n", k, weightsMat.cols(), weightsMat.rows(), weights.toString(true));
        }
    }

    public static enum TrainingMethod {
        BACKPROP(0),
        RPROP(1),
        ANNEAL(2);

        private final int code;

        public int code() {
            return this.code;
        }

        private TrainingMethod(int code) {
            this.code = code;
        }
    }

    public static enum ActivationFunction {
        IDENTITY(0),
        SIGMOID_SYM(1),
        GAUSSIAN(2),
        RELU(3),
        LEAKYRELU(4);

        private final int code;

        public int code() {
            return this.code;
        }

        private ActivationFunction(int code) {
            this.code = code;
        }
    }
}

