package hex.adaboost;

import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import org.apache.log4j.Logger;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Key;
import water.Keyed;

/* loaded from: input_file:hex/adaboost/AdaBoostModel.class */
/**
 * AdaBoost ensemble model.
 *
 * <p>The trained ensemble consists of weak-learner models, referenced by DKV key, plus one voting
 * weight ("alpha") per learner. Scoring is a weighted vote of the weak learners. Only binomial
 * models get a metric builder (see {@link #makeMetricBuilder(String[])}).
 */
public class AdaBoostModel extends Model<AdaBoostModel, AdaBoostParameters, AdaBoostOutput> {
    private static final Logger LOG = Logger.getLogger(AdaBoostModel.class);

    /** Training output: the weak learners and their voting weights. */
    public static class AdaBoostOutput extends Model.Output {
        /** Voting weight of each weak learner; {@code alphas[i]} pairs with {@code models[i]}. */
        public double[] alphas;

        /**
         * DKV keys of the weak-learner models. The constructor does not initialize this field, so
         * it is {@code null} until assigned elsewhere (presumably by the AdaBoost builder).
         */
        public Key<Model>[] models;

        public AdaBoostOutput(AdaBoost adaBoost) {
            super(adaBoost);
        }
    }

    /** User-facing AdaBoost parameters. */
    public static class AdaBoostParameters extends Model.Parameters {
        /** Number of weak learners (default 50); also returned by {@link #progressUnits()}. */
        public int _nlearners = 50;

        /**
         * Weak-learner algorithm (default {@link Algorithm#AUTO}). How AUTO is resolved is not
         * visible in this file.
         */
        public Algorithm _weak_learner = Algorithm.AUTO;

        /** Learning rate (default 0.5). Not read in this file; presumably used by the builder. */
        public double _learn_rate = 0.5d;

        /**
         * Extra weak-learner parameters as a string (default empty).
         * NOTE(review): the expected format is not visible here; confirm against the builder.
         */
        public String _weak_learner_params = "";

        public String algoName() {
            return "AdaBoost";
        }

        public String fullName() {
            return "AdaBoost";
        }

        public String javaName() {
            return AdaBoostModel.class.getName();
        }

        /** Progress is counted in units of weak learners. */
        public long progressUnits() {
            return _nlearners;
        }
    }

    /** Algorithms selectable as the weak learner. */
    public enum Algorithm {
        DRF,
        GLM,
        GBM,
        DEEP_LEARNING,
        AUTO
    }

    public AdaBoostModel(Key<AdaBoostModel> key, AdaBoostParameters parameters, AdaBoostOutput output) {
        super(key, parameters, output);
    }

    /**
     * Creates the metric builder used when scoring this model.
     *
     * <p>Only binomial models are supported. For any other model category this throws the
     * exception produced by {@code H2O.unimpl}.
     *
     * @param domain response domain, passed through to the binomial metric builder
     * @return a binomial metric builder
     */
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        if (_output.getModelCategory() == ModelCategory.Binomial) {
            return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
        }
        throw H2O.unimpl("AdaBoost currently support only binary classification");
    }

    /** Prediction columns: the label followed by the class-0 and class-1 probabilities. */
    protected String[] makeScoringNames() {
        return new String[]{"predict", "p0", "p1"};
    }

    /**
     * Scores one row as an alpha-weighted vote of the weak learners.
     *
     * <p>Each learner's {@code score(data)} result is treated as a class label. Exactly
     * {@code 0.0} is a vote for class 0; any other value is a vote for class 1. Let {@code F} be
     * the sum of {@code -alpha} over class-0 votes plus {@code +alpha} over class-1 votes. Then:
     * <ul>
     *   <li>{@code preds[0]} is 0 when the total alpha voting for class 0 is strictly greater
     *       than the total for class 1, otherwise 1 (ties go to class 1);</li>
     *   <li>{@code preds[2] = 1 / (1 + exp(-2F))}, the "p1" column;</li>
     *   <li>{@code preds[1] = 1 - preds[2]}, the "p0" column.</li>
     * </ul>
     *
     * @param data  the input row, passed unchanged to every weak learner
     * @param preds output buffer (needs at least 3 slots), filled in place
     * @return {@code preds}
     */
    protected double[] score0(double[] data, double[] preds) {
        AdaBoostOutput output = _output;
        double class0Weight = 0.0d;
        double class1Weight = 0.0d;
        double margin = 0.0d;
        for (int i = 0; i < output.alphas.length; i++) {
            double alpha = output.alphas[i];
            // Each weak learner is looked up in the DKV on every call, i.e. once per scored row.
            if (DKV.getGet(output.models[i]).score(data) == 0.0d) {
                margin -= alpha;
                class0Weight += alpha;
            } else {
                margin += alpha;
                class1Weight += alpha;
            }
        }
        preds[0] = class0Weight > class1Weight ? 0.0d : 1.0d;
        preds[2] = 1.0d / (1.0d + Math.exp(-2.0d * margin));
        preds[1] = 1.0d - preds[2];
        return preds;
    }

    /**
     * Returns {@code false}. {@link #score0} already writes the final label into
     * {@code preds[0]}; presumably this stops the base {@code Model} from recomputing it.
     * Confirm against {@code hex.Model}.
     */
    protected boolean needsPostProcess() {
        return false;
    }

    /**
     * Removes this model together with its weak learners.
     *
     * <p>Weak learners are always removed with cascading removal, regardless of {@code cascade}.
     * A {@code models} array that was never assigned ({@code null}) and any {@code null} entries
     * are skipped rather than causing a NullPointerException.
     *
     * @param futures collects the pending removals
     * @param cascade forwarded to the superclass removal
     * @return the result of the superclass removal
     */
    protected Futures remove_impl(Futures futures, boolean cascade) {
        Key<Model>[] learners = _output.models;
        if (learners != null) {
            for (Key<Model> learnerKey : learners) {
                if (learnerKey != null) {
                    Keyed.remove(learnerKey, futures, true);
                }
            }
        }
        return super.remove_impl(futures, cascade);
    }

    /**
     * Writes every weak-learner key via {@link AutoBuffer#putKey}, then delegates to the
     * superclass. This is presumably what exports the weak learners together with this model.
     *
     * <p>If {@code models} was never assigned, nothing is written for it. {@link #readAll_impl}
     * skips the same case, so the stream layout stays symmetric.
     */
    protected AutoBuffer writeAll_impl(AutoBuffer autoBuffer) {
        Key<Model>[] learners = _output.models;
        if (learners != null) {
            for (Key<Model> learnerKey : learners) {
                autoBuffer.putKey(learnerKey);
            }
        }
        return super.writeAll_impl(autoBuffer);
    }

    /**
     * Reads back the weak learners written by {@link #writeAll_impl}, one
     * {@link AutoBuffer#getKey} per key in {@code models}, then delegates to the superclass.
     *
     * <p>A {@code null} {@code models} array is skipped, mirroring {@link #writeAll_impl}.
     */
    protected Keyed readAll_impl(AutoBuffer autoBuffer, Futures futures) {
        Key<Model>[] learners = _output.models;
        if (learners != null) {
            for (Key<Model> learnerKey : learners) {
                autoBuffer.getKey(learnerKey, futures);
            }
        }
        return super.readAll_impl(autoBuffer, futures);
    }
}
