/*
 * Decompiled with CFR 0.152.
 */
package net.algart.executors.modules.opencv.matrices.ml;

import java.io.IOError;
import java.io.IOException;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Objects;
import net.algart.executors.api.Port;
import net.algart.executors.api.data.DataType;
import net.algart.executors.api.data.SMat;
import net.algart.executors.api.data.SNumbers;
import net.algart.executors.api.parameters.Parameters;
import net.algart.executors.modules.cv.matrices.pixels.GetLabelledPixels;
import net.algart.executors.modules.cv.matrices.pixels.SetPixels;
import net.algart.executors.modules.opencv.matrices.ml.AbstractMLOperation;
import net.algart.executors.modules.opencv.matrices.ml.MLKind;
import net.algart.executors.modules.opencv.matrices.ml.MLMetadataJson;
import net.algart.executors.modules.opencv.matrices.ml.MLSamplesType;
import net.algart.executors.modules.opencv.matrices.ml.MLStatModelTrainer;
import net.algart.executors.modules.opencv.matrices.ml.MLTrainer;
import net.algart.executors.modules.opencv.util.O2SMat;
import net.algart.executors.modules.opencv.util.OTools;
import net.algart.multimatrix.MultiMatrix;
import net.algart.multimatrix.MultiMatrix2D;
import org.bytedeco.opencv.global.opencv_core;
import org.bytedeco.opencv.opencv_core.Mat;
import org.bytedeco.opencv.opencv_core.UMat;
import org.bytedeco.opencv.opencv_ml.TrainData;

/**
 * Base class for executors that train an OpenCV machine-learning model.
 *
 * <p>Subclasses supply the model kind ({@link #modelKind()}) and whether responses are categorical
 * ({@link #categoricalResponses()}). This class converts incoming samples/responses into OpenCV
 * matrices ({@link Mat}, or {@link UMat} when {@link #isUseGPU()} is set), builds a {@link TrainData},
 * trains the supplied {@link MLTrainer}, optionally calculates the training error, and can save the
 * trained model plus its JSON metadata via {@link #writeTrainer(MLTrainer)}.
 */
public abstract class AbstractMLTrain
extends AbstractMLOperation {
    public static final String OUTPUT_ACTUAL_TRAINING_RESPONSES = "training_responses";
    public static final String OUTPUT_TRAINING_MODEL_FILE = "model_file";
    public static final String OUTPUT_TRAINING_METADATA = "metadata";
    public static final String OUTPUT_TRAINING_ERROR = "error";
    public static final String OUTPUT_IS_CLASSIFIER = "is_classifier";

    /** Element of an OpenCV {@code varType} vector marking an ordered (numerical) variable. */
    private static final byte VAR_ORDERED = 0;
    /** Element of an OpenCV {@code varType} vector marking a categorical variable. */
    private static final byte VAR_CATEGORICAL = 1;

    private boolean convertCategoricalResponses = false;
    private boolean calculateError = false;
    private boolean testPredictTrainedSamples = false;
    private boolean trainingCombinedSamplesAndResponses = false;
    private int trainingFlags = 0;

    protected AbstractMLTrain(MLSamplesType samplesType) {
        super(samplesType);
        addPort(Port.newInput("samples", samplesType.portDataType));
        addPort(Port.newOutput(DEFAULT_OUTPUT_PORT, samplesType.portDataType));
        // NOTE(review): the training-responses INPUT port reuses the name of the
        // OUTPUT_ACTUAL_TRAINING_RESPONSES output port ("training_responses"); confirm this is intended.
        addPort(Port.newInput(OUTPUT_ACTUAL_TRAINING_RESPONSES, samplesType.portDataType));
        addOutputNumbers(OUTPUT_ACTUAL_TRAINING_RESPONSES);
        addOutputScalar(OUTPUT_TRAINING_MODEL_FILE);
        addOutputScalar(OUTPUT_TRAINING_METADATA);
        addOutputScalar(OUTPUT_TRAINING_ERROR);
        addOutputScalar(OUTPUT_IS_CLASSIFIER);
    }

    public final boolean isConvertCategoricalResponses() {
        return convertCategoricalResponses;
    }

    /**
     * If set, categorical responses are expanded into multi-column binary responses
     * (via {@code categoricalToMultiBinaryResponses}) before training.
     */
    public final void setConvertCategoricalResponses(boolean convertCategoricalResponses) {
        this.convertCategoricalResponses = convertCategoricalResponses;
    }

    public final boolean isCalculateError() {
        return calculateError;
    }

    /**
     * If set, the training error is calculated after training and the per-sample results
     * are returned through the "auto test result" argument of the training methods.
     */
    public final void setCalculateError(boolean calculateError) {
        this.calculateError = calculateError;
    }

    public final boolean isTestPredictTrainedSamples() {
        return testPredictTrainedSamples;
    }

    public final void setTestPredictTrainedSamples(boolean testPredictTrainedSamples) {
        this.testPredictTrainedSamples = testPredictTrainedSamples;
    }

    public final boolean isTrainingCombinedSamplesAndResponses() {
        return trainingCombinedSamplesAndResponses;
    }

    public final void setTrainingCombinedSamplesAndResponses(boolean trainingCombinedSamplesAndResponses) {
        this.trainingCombinedSamplesAndResponses = trainingCombinedSamplesAndResponses;
    }

    public final int getTrainingFlags() {
        return trainingFlags;
    }

    public final void setTrainingFlags(int trainingFlags) {
        this.trainingFlags = trainingFlags;
    }

    /**
     * Sets or clears the given bits of the training flags.
     *
     * @param bitMask bits to modify.
     * @param value   {@code true} to set the bits, {@code false} to clear them.
     */
    public final void setTrainingFlagByMask(int bitMask, boolean value) {
        if (value) {
            trainingFlags |= bitMask;
        } else {
            trainingFlags &= ~bitMask;
        }
    }

    /** Returns {@code true} if any bit of {@code bitMask} is set in the training flags. */
    public final boolean getTrainingFlagByMask(int bitMask) {
        return (trainingFlags & bitMask) != 0;
    }

    /** Trains the given trainer, delegating to the samples type of this executor. */
    public final void train(MLTrainer trainer) {
        Objects.requireNonNull(trainer, "Null trainer");
        samplesType().train(this, trainer);
    }

    /**
     * Trains the model on number arrays: each block of {@code samples} is one sample, and the
     * corresponding block of {@code responses} is its response.
     *
     * @param trainer        the trainer to train.
     * @param samples        training samples.
     * @param responses      training responses; must have the same number of blocks as {@code samples}.
     * @param autoTestResult receives the per-sample results of error calculation when
     *                       {@link #isCalculateError()} is set; may be {@code null} if not needed.
     * @return the training error, or {@code Double.NaN} when the error is not calculated
     * (with the default {@code doTrain} implementations).
     * @throws NullPointerException     if {@code trainer}, {@code samples} or {@code responses} is {@code null}.
     * @throws IllegalArgumentException if samples and responses have different number of blocks.
     */
    public double trainNumbers(MLTrainer trainer, SNumbers samples, SNumbers responses, SNumbers autoTestResult) {
        Objects.requireNonNull(trainer, "Null trainer");
        Objects.requireNonNull(samples, "Null samples");
        Objects.requireNonNull(responses, "Null responses");
        if (samples.n() != responses.n()) {
            throw new IllegalArgumentException("Source and training responses arrays have different length: "
                    + samples + " and " + responses);
        }
        return isUseGPU()
                ? trainNumbersOnUMat(trainer, samples, responses, autoTestResult)
                : trainNumbersOnMat(trainer, samples, responses, autoTestResult);
    }

    // UMat variant of trainNumbers, used when isUseGPU() is set; arguments are already validated.
    private double trainNumbersOnUMat(
            MLTrainer trainer,
            SNumbers samples,
            SNumbers responses,
            SNumbers autoTestResult) {
        try (UMat samplesMat = O2SMat.numbersToMulticolumn32BitUMat(samples, false);
             UMat responsesMat = O2SMat.numbersToMulticolumn32BitUMat(responses, categoricalResponses());
             UMat binary = convertCategoricalResponses ? categoricalToMultiBinaryResponses(responsesMat) : null;
             UMat autoTestResultsMat = calculateError ? new UMat() : null) {
            final UMat actualResponses = binary != null ? binary : responsesMat;
            final double error = doTrain(trainer, samplesMat, actualResponses, autoTestResultsMat);
            if (isOutputNecessary(OUTPUT_ACTUAL_TRAINING_RESPONSES)) {
                getNumbers(OUTPUT_ACTUAL_TRAINING_RESPONSES).exchange(
                        O2SMat.multicolumnMatToNumbers(actualResponses));
            }
            if (autoTestResultsMat != null && autoTestResult != null) {
                autoTestResult.exchange(O2SMat.multicolumnMatToNumbers(autoTestResultsMat));
            }
            return error;
        }
    }

    // Mat variant of trainNumbers, used when isUseGPU() is not set; arguments are already validated.
    private double trainNumbersOnMat(
            MLTrainer trainer,
            SNumbers samples,
            SNumbers responses,
            SNumbers autoTestResult) {
        try (Mat samplesMat = O2SMat.numbersToMulticolumn32BitMat(samples, false);
             Mat responsesMat = O2SMat.numbersToMulticolumn32BitMat(responses, categoricalResponses());
             Mat binary = convertCategoricalResponses ? categoricalToMultiBinaryResponses(responsesMat) : null;
             Mat autoTestResultsMat = calculateError ? new Mat() : null) {
            final Mat actualResponses = binary != null ? binary : responsesMat;
            final double error = doTrain(trainer, samplesMat, actualResponses, autoTestResultsMat);
            if (isOutputNecessary(OUTPUT_ACTUAL_TRAINING_RESPONSES)) {
                getNumbers(OUTPUT_ACTUAL_TRAINING_RESPONSES).exchange(
                        O2SMat.multicolumnMatToNumbers(actualResponses));
            }
            if (autoTestResultsMat != null && autoTestResult != null) {
                autoTestResult.exchange(O2SMat.multicolumnMatToNumbers(autoTestResultsMat));
            }
            return error;
        }
    }

    /**
     * Trains the model on pixels: labelled pixels of {@code samples} (selected by the label matrix
     * {@code responses}) are extracted with {@link GetLabelledPixels} and passed to
     * {@link #trainNumbers}.
     *
     * @param autoTestResult receives the error-calculation results placed back at the labelled pixel
     *                       positions (via {@link SetPixels}); may be {@code null} if not needed.
     * @return the value returned by {@link #trainNumbers}.
     */
    public double trainPixels(MLTrainer trainer, SMat samples, SMat responses, SMat autoTestResult) {
        Objects.requireNonNull(trainer, "Null trainer");
        Objects.requireNonNull(samples, "Null samples");
        Objects.requireNonNull(responses, "Null responses");
        final SNumbers sampleNumbers = new SNumbers();
        final SNumbers responseNumbers = new SNumbers();
        final MultiMatrix2D labels = responses.toMultiMatrix2D();
        new GetLabelledPixels().process(samples.toMultiMatrix2D(), labels, sampleNumbers, responseNumbers);
        final SNumbers autoTestResultNumbers = new SNumbers();
        final double error = trainNumbers(trainer, sampleNumbers, responseNumbers, autoTestResultNumbers);
        if (autoTestResult != null && autoTestResultNumbers.isInitialized()) {
            autoTestResult.setTo((MultiMatrix) new SetPixels().process(autoTestResultNumbers, labels, null));
        }
        return error;
    }

    /**
     * Saves the trained model to {@link #statModelFile()} and fills the output scalars
     * (classifier flag, model file, metadata); metadata, if any, is also written next to the model.
     *
     * @throws IOError wrapping any {@link IOException} thrown while saving.
     */
    public final void writeTrainer(MLTrainer trainer) {
        Objects.requireNonNull(trainer, "Null trainer");
        try {
            getScalar(OUTPUT_IS_CLASSIFIER).setTo(trainer.isClassifier());
            // Resolve the path once, so that the reported file is exactly the one written.
            final Path file = statModelFile();
            trainer.save(file);
            getScalar(OUTPUT_TRAINING_MODEL_FILE).setTo((Object) file);
            final MLMetadataJson metadata = metadata(trainer);
            if (metadata != null) {
                getScalar(OUTPUT_TRAINING_METADATA).setTo(metadata.jsonString());
                metadata.write(MLMetadataJson.metadataFile(file), new OpenOption[0]);
            }
        } catch (IOException e) {
            throw new IOError(e);
        }
    }

    /** Copies the training flags of this executor into the given stat-model trainer. */
    public final void setTrainingFlags(MLStatModelTrainer trainer) {
        Objects.requireNonNull(trainer, "Null trainer");
        trainer.setTrainingFlags(trainingFlags);
    }

    protected abstract MLKind modelKind();

    /**
     * Builds metadata describing the trained model: model kind, creating class and a copy
     * of the executor parameters. Subclasses may return {@code null} to skip metadata.
     */
    protected MLMetadataJson metadata(MLTrainer trainer) {
        // Raw type kept: the exact generic type returned by parameters() is declared elsewhere.
        @SuppressWarnings({"rawtypes", "unchecked"})
        final LinkedHashMap parametersCopy = new LinkedHashMap(parameters());
        return new MLMetadataJson()
                .setModelKind(modelKind())
                .setCreatedBy(getClass().getCanonicalName())
                .setParameters(Parameters.toJson(parametersCopy));
    }

    protected abstract boolean categoricalResponses();

    /**
     * Creates OpenCV {@link TrainData} from the given matrices, trains the model and,
     * if {@link #isCalculateError()} is set, calculates the training error.
     *
     * @return the calculated error, or {@code Double.NaN} if error calculation is disabled.
     */
    protected double doTrain(MLTrainer trainer, Mat samples, Mat responses, Mat autoTestResults) {
        final int sampleLength = samples.cols();
        final int responseLength = responses.cols();
        // Resources are closed in reverse order: trainData first, then varType.
        try (Mat varType = toMat(varType(sampleLength, responseLength));
             TrainData trainData = TrainData.create(
                     samples, O2SMat.ML_LAYOUT, responses, null, null, null, varType)) {
            doTrain(trainer, trainData, sampleLength, responseLength);
            return calculateError ? doCalculateError(trainer, trainData, autoTestResults) : Double.NaN;
        }
    }

    /**
     * {@link UMat} analogue of {@link #doTrain(MLTrainer, Mat, Mat, Mat)}.
     *
     * @return the calculated error, or {@code Double.NaN} if error calculation is disabled.
     */
    protected double doTrain(MLTrainer trainer, UMat samples, UMat responses, UMat autoTestResults) {
        final int sampleLength = samples.cols();
        final int responseLength = responses.cols();
        // Resources are closed in reverse order: trainData first, then varType.
        try (UMat varType = toUMat(varType(sampleLength, responseLength));
             TrainData trainData = TrainData.create(
                     samples, O2SMat.ML_LAYOUT, responses, null, null, null, varType)) {
            doTrain(trainer, trainData, sampleLength, responseLength);
            return calculateError ? doCalculateError(trainer, trainData, autoTestResults) : Double.NaN;
        }
    }

    /** Trains the model on prepared train data; subclasses may override to customize training. */
    protected void doTrain(MLTrainer trainer, TrainData trainData, int sampleLength, int responseLength) {
        trainer.train(trainData);
    }

    protected double doCalculateError(MLTrainer trainer, TrainData trainData, Mat result) {
        return trainer.calculateError(trainData, result);
    }

    protected double doCalculateError(MLTrainer trainer, TrainData trainData, UMat result) {
        return trainer.calculateError(trainData, result);
    }

    /**
     * Builds the OpenCV {@code varType} vector: one byte per sample variable followed by one byte
     * per response variable. All variables are ordered, except that a single response column is
     * marked categorical when {@link #categoricalResponses()} is {@code true}.
     *
     * @param numberOfSamples   number of variables (columns) in each sample, despite the name.
     * @param numberOfResponses number of response columns.
     * @return the {@code varType} bytes.
     */
    protected byte[] varType(int numberOfSamples, int numberOfResponses) {
        final byte[] result = new byte[numberOfSamples + numberOfResponses];
        Arrays.fill(result, VAR_ORDERED);
        if (categoricalResponses() && numberOfResponses == 1) {
            Arrays.fill(result, numberOfSamples, result.length, VAR_CATEGORICAL);
        }
        return result;
    }

    private static Mat toMat(byte[] array) {
        return OTools.toMat(1, array.length, opencv_core.CV_8UC1, array);
    }

    private static UMat toUMat(byte[] array) {
        return OTools.toUMat(1, array.length, opencv_core.CV_8UC1, array);
    }
}

