/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.naivebayes.compound;

import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.naivebayes.BayesModel;
import org.apache.ignite.ml.naivebayes.compound.CompoundNaiveBayesModel;
import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel;
import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesTrainer;
import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel;
import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesTrainer;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;

/**
 * Trainer that builds a {@link CompoundNaiveBayesModel} by delegating to an optional
 * {@link GaussianNaiveBayesTrainer} and an optional {@link DiscreteNaiveBayesTrainer}.
 *
 * <p>Each delegate is trained on a reduced feature vector: the feature indices listed in the
 * corresponding "ids to skip" collection are removed before the row is handed to that delegate.
 * A delegate that has not been configured (is {@code null}) is simply not trained.
 */
public class CompoundNaiveBayesTrainer
extends SingleLabelDatasetTrainer<CompoundNaiveBayesModel> {
    /** Prior class probabilities forwarded to the delegates and the resulting model; may be {@code null}. */
    private double[] priorProbabilities;

    /** Delegate trainer for the Gaussian part; {@code null} when not configured. */
    private GaussianNaiveBayesTrainer gaussianNaiveBayesTrainer;

    /** Feature indices removed from each row before it reaches the Gaussian trainer. */
    private Collection<Integer> gaussianFeatureIdsToSkip = Collections.emptyList();

    /** Delegate trainer for the discrete part; {@code null} when not configured. */
    private DiscreteNaiveBayesTrainer discreteNaiveBayesTrainer;

    /** Feature indices removed from each row before it reaches the discrete trainer. */
    private Collection<Integer> discreteFeatureIdsToSkip = Collections.emptyList();

    /**
     * Trains a new model from scratch by calling {@link #updateModel} with no previous model.
     *
     * @param datasetBuilder Dataset builder.
     * @param extractor Preprocessor producing labeled vectors.
     * @return Newly trained compound model.
     */
    @Override
    public <K, V> CompoundNaiveBayesModel fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
        return this.updateModel((CompoundNaiveBayesModel)null, datasetBuilder, extractor);
    }

    /**
     * Reports whether the given model can be updated by this trainer.
     *
     * <p>Only the configured delegates are consulted; a delegate that is {@code null} has nothing
     * to update (it is skipped by {@link #updateModel} as well) and therefore does not veto.
     *
     * @param mdl Model to check.
     * @return {@code true} if every configured delegate accepts its part of the model.
     */
    @Override
    public boolean isUpdateable(CompoundNaiveBayesModel mdl) {
        boolean gaussianUpdateable = this.gaussianNaiveBayesTrainer == null
            || this.gaussianNaiveBayesTrainer.isUpdateable(mdl.getGaussianModel());
        if (!gaussianUpdateable) {
            return false;
        }
        return this.discreteNaiveBayesTrainer == null
            || this.discreteNaiveBayesTrainer.isUpdateable(mdl.getDiscreteModel());
    }

    /**
     * Sets the learning environment builder, narrowing the return type to this trainer.
     *
     * @param envBuilder Learning environment builder.
     * @return This trainer.
     */
    @Override
    public CompoundNaiveBayesTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
        return (CompoundNaiveBayesTrainer)super.withEnvironmentBuilder(envBuilder);
    }

    /**
     * Trains (when {@code mdl} is {@code null}) or updates each configured delegate model and
     * assembles the results into a new {@link CompoundNaiveBayesModel}.
     *
     * <p>When prior probabilities are set on this trainer they are pushed into each configured
     * delegate trainer before it runs. If both delegates are configured, the labels stored in the
     * result are the ones reported by the discrete model, because it is processed last.
     *
     * @param mdl Previous model, or {@code null} to train from scratch.
     * @param datasetBuilder Dataset builder.
     * @param extractor Preprocessor producing labeled vectors.
     * @return New compound model.
     */
    @Override
    protected <K, V> CompoundNaiveBayesModel updateModel(CompoundNaiveBayesModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
        CompoundNaiveBayesModel compoundModel = new CompoundNaiveBayesModel().withPriorProbabilities(this.priorProbabilities);

        if (this.gaussianNaiveBayesTrainer != null) {
            if (this.priorProbabilities != null) {
                this.gaussianNaiveBayesTrainer.setPriorProbabilities(this.priorProbabilities);
            }
            GaussianNaiveBayesModel gaussianModel;
            if (mdl == null) {
                gaussianModel = this.gaussianNaiveBayesTrainer.fit(datasetBuilder, extractor.map(CompoundNaiveBayesTrainer.skipFeatures(this.gaussianFeatureIdsToSkip)));
            } else {
                gaussianModel = this.gaussianNaiveBayesTrainer.update(mdl.getGaussianModel(), datasetBuilder, extractor.map(CompoundNaiveBayesTrainer.skipFeatures(this.gaussianFeatureIdsToSkip)));
            }
            compoundModel.withGaussianModel(gaussianModel)
                .withGaussianFeatureIdsToSkip(this.gaussianFeatureIdsToSkip)
                .withLabels(gaussianModel.getLabels())
                .withPriorProbabilities(this.priorProbabilities);
        }

        if (this.discreteNaiveBayesTrainer != null) {
            if (this.priorProbabilities != null) {
                this.discreteNaiveBayesTrainer.setPriorProbabilities(this.priorProbabilities);
            }
            DiscreteNaiveBayesModel discreteModel;
            if (mdl == null) {
                discreteModel = this.discreteNaiveBayesTrainer.fit(datasetBuilder, extractor.map(CompoundNaiveBayesTrainer.skipFeatures(this.discreteFeatureIdsToSkip)));
            } else {
                discreteModel = this.discreteNaiveBayesTrainer.update(mdl.getDiscreteModel(), datasetBuilder, extractor.map(CompoundNaiveBayesTrainer.skipFeatures(this.discreteFeatureIdsToSkip)));
            }
            compoundModel.withDiscreteModel(discreteModel)
                .withDiscreteFeatureIdsToSkip(this.discreteFeatureIdsToSkip)
                .withLabels(discreteModel.getLabels())
                .withPriorProbabilities(this.priorProbabilities);
        }

        return compoundModel;
    }

    /**
     * Sets the prior class probabilities. The array is copied, so later changes to the argument
     * do not affect this trainer.
     *
     * @param priorProbabilities Prior probabilities; must not be {@code null}.
     * @return This trainer.
     * @throws NullPointerException if {@code priorProbabilities} is {@code null}.
     */
    public CompoundNaiveBayesTrainer withPriorProbabilities(double[] priorProbabilities) {
        this.priorProbabilities = (double[])priorProbabilities.clone();
        return this;
    }

    /**
     * Sets the delegate trainer for the Gaussian part.
     *
     * @param gaussianNaiveBayesTrainer Gaussian trainer, or {@code null} to disable that part.
     * @return This trainer.
     */
    public CompoundNaiveBayesTrainer withGaussianNaiveBayesTrainer(GaussianNaiveBayesTrainer gaussianNaiveBayesTrainer) {
        this.gaussianNaiveBayesTrainer = gaussianNaiveBayesTrainer;
        return this;
    }

    /**
     * Sets the delegate trainer for the discrete part.
     *
     * @param discreteNaiveBayesTrainer Discrete trainer, or {@code null} to disable that part.
     * @return This trainer.
     */
    public CompoundNaiveBayesTrainer withDiscreteNaiveBayesTrainer(DiscreteNaiveBayesTrainer discreteNaiveBayesTrainer) {
        this.discreteNaiveBayesTrainer = discreteNaiveBayesTrainer;
        return this;
    }

    /**
     * Sets the feature indices hidden from the Gaussian trainer. The collection is stored by
     * reference and is also passed on to the resulting model.
     *
     * @param gaussianFeatureIdsToSkip Zero-based feature indices to skip.
     * @return This trainer.
     */
    public CompoundNaiveBayesTrainer withGaussianFeatureIdsToSkip(Collection<Integer> gaussianFeatureIdsToSkip) {
        this.gaussianFeatureIdsToSkip = gaussianFeatureIdsToSkip;
        return this;
    }

    /**
     * Sets the feature indices hidden from the discrete trainer. The collection is stored by
     * reference and is also passed on to the resulting model.
     *
     * @param discreteFeatureIdsToSkip Zero-based feature indices to skip.
     * @return This trainer.
     */
    public CompoundNaiveBayesTrainer withDiscreteFeatureIdsToSkip(Collection<Integer> discreteFeatureIdsToSkip) {
        this.discreteFeatureIdsToSkip = discreteFeatureIdsToSkip;
        return this;
    }

    /**
     * Builds a mapping function that drops the given feature indices from a labeled vector while
     * keeping the label and the relative order of the remaining features.
     *
     * <p>The ids are copied into a set up front: this removes duplicate ids (which previously made
     * the output array too small and caused an {@link ArrayIndexOutOfBoundsException}) and makes the
     * per-feature membership check constant-time.
     *
     * @param featureIdsToSkip Zero-based feature indices to remove.
     * @return Function producing a labeled vector without the skipped features.
     */
    private static IgniteFunction<LabeledVector<Object>, LabeledVector<Object>> skipFeatures(Collection<Integer> featureIdsToSkip) {
        Set<Integer> skipIds = new HashSet<>(featureIdsToSkip);
        return featureValues -> {
            int size = featureValues.features().size();
            int newSize = size - skipIds.size();
            double[] newFeaturesValues = new double[newSize];
            int index = 0;
            for (int j = 0; j < size; ++j) {
                if (skipIds.contains(j)) {
                    continue;
                }
                newFeaturesValues[index] = featureValues.get(j);
                ++index;
            }
            return new LabeledVector<>(VectorUtils.of(newFeaturesValues), featureValues.label());
        };
    }
}

