/*
 * Decompiled with CFR 0.152.
 */
package net.algart.executors.modules.opencv.matrices.ml;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.util.Objects;
import java.util.function.Function;
import net.algart.executors.modules.opencv.matrices.ml.MLKind;
import net.algart.executors.modules.opencv.matrices.ml.MLTrainer;
import org.bytedeco.opencv.opencv_core.Mat;
import org.bytedeco.opencv.opencv_core.UMat;
import org.bytedeco.opencv.opencv_ml.StatModel;
import org.bytedeco.opencv.opencv_ml.TrainData;

/**
 * {@link MLTrainer} implementation that delegates every operation to a wrapped
 * OpenCV {@link StatModel}.
 *
 * <p>The trainer keeps two mutable integer settings, the prediction flags and the
 * training flags, which are passed unchanged to {@code StatModel.predict} and
 * {@code StatModel.train}. Both are 0 until set explicitly.
 */
public class MLStatModelTrainer implements MLTrainer {
    private final StatModel statModel;
    private final MLKind statModelKind;
    private int predictionFlags = 0;
    private int trainingFlags = 0;

    /**
     * Creates a trainer around an existing model.
     *
     * @param statModel     the model to delegate to.
     * @param statModelKind the kind associated with this model.
     * @throws NullPointerException if either argument is {@code null}.
     */
    public MLStatModelTrainer(StatModel statModel, MLKind statModelKind) {
        this.statModel = Objects.requireNonNull(statModel, "Null statModel");
        this.statModelKind = Objects.requireNonNull(statModelKind, "Null statModelKind");
    }

    /**
     * Returns the wrapped model passed to the constructor.
     *
     * @return the wrapped model (never {@code null}).
     */
    public StatModel statModel() {
        return this.statModel;
    }

    /**
     * Returns the model kind passed to the constructor.
     *
     * @return the model kind (never {@code null}).
     */
    public MLKind modelKind() {
        return this.statModelKind;
    }

    /**
     * Misspelled legacy name of {@link #modelKind()}, retained so that existing callers
     * continue to compile.
     *
     * @return the model kind (never {@code null}).
     * @deprecated use {@link #modelKind()} instead.
     */
    @Deprecated
    public MLKind modelKine() {
        return modelKind();
    }

    /**
     * Returns the flags that are passed to {@code StatModel.predict}.
     *
     * @return the current prediction flags.
     */
    public int getPredictionFlags() {
        return this.predictionFlags;
    }

    /**
     * Sets the flags that are passed to {@code StatModel.predict}.
     *
     * @param predictionFlags the new prediction flags; the value is stored without validation.
     * @return a reference to this object.
     */
    @Override
    public MLStatModelTrainer setPredictionFlags(int predictionFlags) {
        this.predictionFlags = predictionFlags;
        return this;
    }

    /**
     * Returns the flags that are passed to {@code StatModel.train}.
     *
     * @return the current training flags.
     */
    public int getTrainingFlags() {
        return this.trainingFlags;
    }

    /**
     * Sets the flags that are passed to {@code StatModel.train}.
     *
     * @param trainingFlags the new training flags; the value is stored without validation.
     * @return a reference to this object.
     */
    public MLStatModelTrainer setTrainingFlags(int trainingFlags) {
        this.trainingFlags = trainingFlags;
        return this;
    }

    /**
     * Reports whether the wrapped model is a classifier, as answered by the model itself.
     *
     * @return the result of {@code StatModel.isClassifier()}.
     */
    @Override
    public boolean isClassifier() {
        return this.statModel.isClassifier();
    }

    /**
     * Runs prediction on the wrapped model using the current prediction flags.
     *
     * @param samples input samples.
     * @param result  matrix passed to the model as the output argument.
     */
    @Override
    public void predict(Mat samples, Mat result) {
        this.statModel.predict(samples, result, this.predictionFlags);
    }

    /**
     * Runs prediction on the wrapped model using the current prediction flags
     * ({@link UMat} variant).
     *
     * @param samples input samples.
     * @param result  matrix passed to the model as the output argument.
     */
    @Override
    public void predict(UMat samples, UMat result) {
        this.statModel.predict(samples, result, this.predictionFlags);
    }

    /**
     * Trains the wrapped model on the given data using the current training flags.
     *
     * @param trainData the training data.
     */
    @Override
    public void train(TrainData trainData) {
        this.statModel.train(trainData, this.trainingFlags);
    }

    /**
     * Calls {@code StatModel.calcError} for the given data and returns its result.
     *
     * @param trainData the data to evaluate.
     * @param result    matrix passed to the model as the output argument.
     * @return the value returned by {@code StatModel.calcError}.
     */
    @Override
    public double calculateError(TrainData trainData, Mat result) {
        // NOTE(review): the hard-coded "false" is the "test" flag of calcError in the OpenCV API;
        // presumably this selects the training subset rather than the test subset — confirm.
        return this.statModel.calcError(trainData, false, result);
    }

    /**
     * Calls {@code StatModel.calcError} for the given data and returns its result
     * ({@link UMat} variant).
     *
     * @param trainData the data to evaluate.
     * @param result    matrix passed to the model as the output argument.
     * @return the value returned by {@code StatModel.calcError}.
     */
    @Override
    public double calculateError(TrainData trainData, UMat result) {
        // NOTE(review): same hard-coded "test" flag as in the Mat variant — confirm its meaning.
        return this.statModel.calcError(trainData, false, result);
    }

    /**
     * Closes the wrapped model.
     */
    @Override
    public void close() {
        this.statModel.close();
    }

    /**
     * Saves the wrapped model into the given file.
     *
     * <p>The target is validated first so that an obviously wrong location is reported
     * as an {@link IOException} with a readable message before the model is asked to write.
     *
     * @param file the target file; it is converted to an absolute path before saving.
     * @throws NullPointerException if {@code file} is {@code null}.
     * @throws IOException          if {@code file} is an existing directory, or if its parent
     *                              (when the path has one) is not an existing directory.
     */
    @Override
    public void save(Path file) throws IOException {
        Objects.requireNonNull(file, "Null file");
        // No null check of statModel here: the field is final and validated by the constructor.
        if (Files.isDirectory(file)) {
            throw new IOException("Result statistic model file cannot be an existing directory: " + file);
        }
        final Path parent = file.getParent();
        // A path without a parent component (for example a bare file name) skips this check.
        if (parent != null && !Files.isDirectory(parent)) {
            throw new IOException("Result statistic model file cannot be saved inside non-existing directory: " + parent);
        }
        this.statModel.save(file.toAbsolutePath().toString());
    }

    /**
     * Returns a brief description containing the model kind and the wrapped model.
     *
     * @return a string representation of this trainer.
     */
    @Override
    public String toString() {
        return "MLStatModelTrainer for " + this.statModelKind + ", statModel " + this.statModel;
    }

    /**
     * Loads a model from the given file with the supplied loader and wraps it into a trainer.
     *
     * @param file          the file to load; it must be an existing regular file.
     * @param loader        function that receives the absolute path of {@code file} as a string
     *                      and returns the loaded model.
     * @param statModelKind the kind to associate with the loaded model.
     * @return a new {@link MLStatModelTrainer} wrapping the loaded model.
     * @throws NullPointerException  if any argument is {@code null}, or if {@code loader}
     *                               returns {@code null} (rejected by the constructor).
     * @throws FileNotFoundException if {@code file} does not exist or is not a regular file.
     * @throws IOException           declared for callers; the only I/O exception thrown directly
     *                               by this method is {@link FileNotFoundException}.
     */
    public static MLTrainer loadOpenCVTrainer(Path file, Function<String, StatModel> loader, MLKind statModelKind) throws IOException {
        Objects.requireNonNull(file, "Null file");
        Objects.requireNonNull(loader, "Null loader");
        Objects.requireNonNull(statModelKind, "Null statModelKind");
        if (!Files.isRegularFile(file)) {
            throw new FileNotFoundException("Statistic model file does not exist or is not a regular file: " + file);
        }
        final StatModel loadedModel = loader.apply(file.toAbsolutePath().toString());
        return new MLStatModelTrainer(loadedModel, statModelKind);
    }
}

