/*
 * Decompiled with CFR 0.152.
 */
package net.algart.executors.modules.opencv.matrices.ml.training;

import java.util.Locale;
import net.algart.executors.modules.opencv.matrices.ml.MLKind;
import net.algart.executors.modules.opencv.matrices.ml.MLSamplesType;
import net.algart.executors.modules.opencv.matrices.ml.MLStatModelTrainer;
import net.algart.executors.modules.opencv.matrices.ml.training.MLTrainDTrees;
import org.bytedeco.opencv.opencv_core.Mat;
import org.bytedeco.opencv.opencv_ml.Boost;
import org.bytedeco.opencv.opencv_ml.DTrees;
import org.bytedeco.opencv.opencv_ml.StatModel;

/**
 * Executor that trains an OpenCV {@link Boost} model.
 *
 * <p>This class holds only the Boost-specific parameters: the {@link BoostType}, the weak count and the
 * weight trim rate. Decision-tree settings and priors come from the inherited {@code customizeDTrees}
 * and {@code priors()} members. Instances are obtained through {@link #newTrainNumbers()} or
 * {@link #newTrainPixels()}.
 */
public final class MLTrainBoost extends MLTrainDTrees {
    // Default: REAL.
    private BoostType boostType = BoostType.REAL;
    // Default: 100.
    private int weakCount = 100;
    // Default: 0.95.
    private double weightTrimRate = 0.95;

    private MLTrainBoost(MLSamplesType inputType) {
        super(inputType);
    }

    /**
     * Creates a trainer for samples of type {@link MLSamplesType#NUMBERS}.
     *
     * @return new trainer instance
     */
    public static MLTrainBoost newTrainNumbers() {
        return new MLTrainBoost(MLSamplesType.NUMBERS);
    }

    /**
     * Creates a trainer for samples of type {@link MLSamplesType#PIXELS}.
     *
     * @return new trainer instance
     */
    public static MLTrainBoost newTrainPixels() {
        return new MLTrainBoost(MLSamplesType.PIXELS);
    }

    /**
     * @return boosting variant that will be applied to the model
     */
    public BoostType getBoostType() {
        return boostType;
    }

    /**
     * Sets the boosting variant.
     *
     * @param boostType new variant; checked for {@code null} by the inherited {@code nonNull} helper
     * @return this instance
     */
    public MLTrainBoost setBoostType(BoostType boostType) {
        // The explicit cast keeps this compiling whether nonNull is generic or returns Object.
        this.boostType = (BoostType) MLTrainBoost.nonNull(boostType);
        return this;
    }

    /**
     * @return value that will be passed to {@link Boost#setWeakCount(int)}
     */
    public int getWeakCount() {
        return weakCount;
    }

    /**
     * Sets the value passed to {@link Boost#setWeakCount(int)}.
     * The value is stored without any validation.
     *
     * @param weakCount new weak count
     * @return this instance
     */
    public MLTrainBoost setWeakCount(int weakCount) {
        this.weakCount = weakCount;
        return this;
    }

    /**
     * @return value that will be passed to {@link Boost#setWeightTrimRate(double)}
     */
    public double getWeightTrimRate() {
        return weightTrimRate;
    }

    /**
     * Sets the value passed to {@link Boost#setWeightTrimRate(double)}.
     * The value is stored without any validation.
     *
     * @param weightTrimRate new weight trim rate
     * @return this instance
     */
    public MLTrainBoost setWeightTrimRate(double weightTrimRate) {
        this.weightTrimRate = weightTrimRate;
        return this;
    }

    /**
     * Trains a fresh Boost model and writes the resulting trainer.
     *
     * <p>The steps run in this order:
     * <ol>
     *     <li>create the model;</li>
     *     <li>apply the Boost parameters, then the inherited decision-tree settings (with priors);</li>
     *     <li>log the configuration;</li>
     *     <li>train through an {@link MLStatModelTrainer};</li>
     *     <li>write the trainer.</li>
     * </ol>
     * Both the model and the priors matrix are closed by try-with-resources on exit.
     */
    public void process() {
        try (Boost boost = newStatModel(); Mat priorsMatrix = this.priors()) {
            applyBoostParameters(boost);
            customizeDTrees(boost, priorsMatrix);
            MLTrainBoost.logDebug(() -> "Training " + modelKind().modelName() + ": "
                    + MLTrainBoost.toString(boost));
            final MLStatModelTrainer trainer = new MLStatModelTrainer(boost, modelKind());
            setTrainingFlags(trainer);
            train(trainer);
            writeTrainer(trainer);
        }
    }

    /**
     * Describes a Boost model for logging.
     *
     * <p>The result is the decision-tree description from {@code MLTrainDTrees.toString}, followed by the
     * model's boost type (as its integer code), weak count and weight trim rate.
     *
     * @param model model to describe
     * @return human-readable description, formatted with {@link Locale#US}
     */
    public static String toString(Boost model) {
        final String treesPart = MLTrainDTrees.toString(model);
        return String.format(Locale.US, "%s, boostType=%s, weakCount=%s, weightTrimRate=%s",
                treesPart,
                model.getBoostType(),
                model.getWeakCount(),
                model.getWeightTrimRate());
    }

    @Override
    protected MLKind modelKind() {
        return MLKind.StatModelBased.BOOST;
    }

    @Override
    protected boolean categoricalResponses() {
        return true;
    }

    // Copies the three Boost-specific fields of this executor into the model, in a fixed order.
    private void applyBoostParameters(Boost boost) {
        boost.setBoostType(boostType.code());
        boost.setWeakCount(weakCount);
        boost.setWeightTrimRate(weightTrimRate);
    }

    // Creates a new Boost model via Boost.create() and logs its initial configuration.
    private Boost newStatModel() {
        final Boost created = Boost.create();
        MLTrainBoost.logDebug(() -> "Creating Boost: " + MLTrainBoost.toString(created));
        return created;
    }

    /**
     * Boosting variants. {@link #code()} is the integer handed to {@link Boost#setBoostType(int)}.
     * The codes are presumably intended to mirror OpenCV's {@code cv::ml::Boost::Types} values.
     */
    public enum BoostType {
        DISCRETE(0),
        REAL(1),
        LOGIT(2),
        GENTLE(3);

        private final int code;

        BoostType(int code) {
            this.code = code;
        }

        /**
         * @return integer code passed to {@link Boost#setBoostType(int)}
         */
        public int code() {
            return code;
        }
    }
}

