package hex.adaboost;

import com.google.gson.FieldNamingPolicy;
import com.google.gson.FieldNamingStrategy;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonObject;
import com.google.gson.JsonSyntaxException;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.adaboost.AdaBoostModel;
import hex.deeplearning.DeepLearning;
import hex.deeplearning.DeepLearningModel;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.tree.drf.DRF;
import hex.tree.drf.DRFModel;
import hex.tree.dt.binning.NumericBin;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import org.apache.log4j.Logger;
import water.DKV;
import water.Key;
import water.Scope;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Timer;
import water.util.TwoDimTable;

/* loaded from: input_file:hex/adaboost/AdaBoost.class */
public class AdaBoost extends ModelBuilder<AdaBoostModel, AdaBoostModel.AdaBoostParameters, AdaBoostModel.AdaBoostOutput> {
    private static final Logger LOG = Logger.getLogger(AdaBoost.class);
    private static final int MAX_LEARNERS = 100000;
    private AdaBoostModel _model;
    private String _weightsName;
    private Gson _gsonParser;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: hex.adaboost.AdaBoost$1, reason: invalid class name */
    /* loaded from: input_file:hex/adaboost/AdaBoost$1.class */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$hex$adaboost$AdaBoostModel$Algorithm = new int[AdaBoostModel.Algorithm.values().length];

        static {
            try {
                $SwitchMap$hex$adaboost$AdaBoostModel$Algorithm[AdaBoostModel.Algorithm.GLM.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$hex$adaboost$AdaBoostModel$Algorithm[AdaBoostModel.Algorithm.GBM.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$hex$adaboost$AdaBoostModel$Algorithm[AdaBoostModel.Algorithm.DEEP_LEARNING.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$hex$adaboost$AdaBoostModel$Algorithm[AdaBoostModel.Algorithm.DRF.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
        }
    }

    /* loaded from: input_file:hex/adaboost/AdaBoost$AdaBoostDriver.class */
    private class AdaBoostDriver extends ModelBuilder<AdaBoostModel, AdaBoostModel.AdaBoostParameters, AdaBoostModel.AdaBoostOutput>.Driver {
        private AdaBoostDriver() {
            super(AdaBoost.this);
        }

        public void computeImpl() {
            AdaBoost.this._model = null;
            try {
                AdaBoost.this.init(true);
                if (AdaBoost.this.error_count() > 0) {
                    throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(AdaBoost.this);
                }
                AdaBoost.this._model = new AdaBoostModel(AdaBoost.this.dest(), (AdaBoostModel.AdaBoostParameters) AdaBoost.this._parms, new AdaBoostModel.AdaBoostOutput(AdaBoost.this));
                AdaBoost.this._model.delete_and_lock(AdaBoost.this._job);
                buildAdaboost();
                AdaBoost.LOG.info(AdaBoost.this._model.toString());
            } finally {
                if (AdaBoost.this._model != null) {
                    AdaBoost.this._model.unlock(AdaBoost.this._job);
                }
            }
        }

        private void buildAdaboost() {
            Frame train;
            ((AdaBoostModel.AdaBoostOutput) AdaBoost.this._model._output).alphas = new double[((AdaBoostModel.AdaBoostParameters) AdaBoost.this._parms)._nlearners];
            ((AdaBoostModel.AdaBoostOutput) AdaBoost.this._model._output).models = new Key[((AdaBoostModel.AdaBoostParameters) AdaBoost.this._parms)._nlearners];
            if (((AdaBoostModel.AdaBoostParameters) AdaBoost.this._parms)._weights_column == null) {
                train = new Frame(AdaBoost.this.train());
                Vec vec = train.anyVec().makeCons(1, 1L, (String[][]) null, (byte[]) null)[0];
                AdaBoost.this._weightsName = train.uniquify(AdaBoost.this._weightsName);
                train.add(AdaBoost.this._weightsName, vec);
                DKV.put(train);
                Scope.track(vec);
            } else {
                train = ((AdaBoostModel.AdaBoostParameters) AdaBoost.this._parms).train();
            }
            for (int i = 0; i < ((AdaBoostModel.AdaBoostParameters) AdaBoost.this._parms)._nlearners; i++) {
                Timer timer = new Timer();
                ModelBuilder chooseWeakLearner = AdaBoost.this.chooseWeakLearner(train);
                chooseWeakLearner._parms._seed += i;
                Model model = chooseWeakLearner.trainModel().get();
                DKV.put(model);
                Scope.untrack(new Key[]{model._key});
                ((AdaBoostModel.AdaBoostOutput) AdaBoost.this._model._output).models[i] = model._key;
                Frame score = model.score(train);
                Scope.track(new Frame[]{score});
                CountWeTask countWeTask = (CountWeTask) new CountWeTask().doAll(new Vec[]{train.vec(AdaBoost.this._weightsName), train.vec(((AdaBoostModel.AdaBoostParameters) AdaBoost.this._parms)._response_column), score.vec("predict")});
                double d = countWeTask.We / countWeTask.W;
                double log = ((AdaBoostModel.AdaBoostParameters) AdaBoost.this._parms)._learn_rate * Math.log((1.0d - d) / d);
                ((AdaBoostModel.AdaBoostOutput) AdaBoost.this._model._output).alphas[i] = log;
                new UpdateWeightsTask(log).doAll(new Vec[]{train.vec(AdaBoost.this._weightsName), train.vec(((AdaBoostModel.AdaBoostParameters) AdaBoost.this._parms)._response_column), score.vec("predict")});
                AdaBoost.this._job.update(1L);
                AdaBoost.this._model.update(AdaBoost.this._job);
                AdaBoost.LOG.info((i + 1) + ". estimator was built in " + timer.toString());
                AdaBoost.LOG.info("*********************************************************************");
            }
            if (train != ((AdaBoostModel.AdaBoostParameters) AdaBoost.this._parms).train()) {
                DKV.remove(train._key);
            }
            ((AdaBoostModel.AdaBoostOutput) AdaBoost.this._model._output)._model_summary = AdaBoost.this.createModelSummaryTable();
        }

        /* synthetic */ AdaBoostDriver(AdaBoost adaBoost, AnonymousClass1 anonymousClass1) {
            this();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:hex/adaboost/AdaBoost$PrecedingUnderscoreNamingStrategy.class */
    public class PrecedingUnderscoreNamingStrategy implements FieldNamingStrategy {
        private PrecedingUnderscoreNamingStrategy() {
        }

        public String translateName(Field field) {
            String translateName = FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES.translateName(field);
            if (translateName.startsWith("_")) {
                translateName = translateName.substring(1);
            }
            return translateName;
        }

        /* synthetic */ PrecedingUnderscoreNamingStrategy(AdaBoost adaBoost, AnonymousClass1 anonymousClass1) {
            this();
        }
    }

    public AdaBoost(AdaBoostModel.AdaBoostParameters adaBoostParameters) {
        super(adaBoostParameters);
        this._weightsName = "weights";
        init(false);
    }

    public AdaBoost(boolean z) {
        super(new AdaBoostModel.AdaBoostParameters(), z);
        this._weightsName = "weights";
    }

    public boolean havePojo() {
        return false;
    }

    public boolean haveMojo() {
        return false;
    }

    public void init(boolean z) {
        super.init(z);
        if (((AdaBoostModel.AdaBoostParameters) this._parms)._nlearners < 1 || ((AdaBoostModel.AdaBoostParameters) this._parms)._nlearners > 100000) {
            error("n_estimators", "Parameter n_estimators must be in interval [1, 100000] but it is " + ((AdaBoostModel.AdaBoostParameters) this._parms)._nlearners);
        }
        if (((AdaBoostModel.AdaBoostParameters) this._parms)._weak_learner == AdaBoostModel.Algorithm.AUTO) {
            ((AdaBoostModel.AdaBoostParameters) this._parms)._weak_learner = AdaBoostModel.Algorithm.DRF;
        }
        if (((AdaBoostModel.AdaBoostParameters) this._parms)._weights_column != null) {
            this._weightsName = ((AdaBoostModel.AdaBoostParameters) this._parms)._weights_column;
        }
        if (0.0d >= ((AdaBoostModel.AdaBoostParameters) this._parms)._learn_rate || ((AdaBoostModel.AdaBoostParameters) this._parms)._learn_rate > 1.0d) {
            error("learn_rate", "learn_rate must be between 0 and 1");
        }
        if (useCustomWeakLearnerParameters()) {
            try {
                this._gsonParser = new GsonBuilder().setFieldNamingStrategy(new PrecedingUnderscoreNamingStrategy(this, null)).create();
                this._gsonParser.fromJson(((AdaBoostModel.AdaBoostParameters) this._parms)._weak_learner_params, JsonObject.class);
            } catch (JsonSyntaxException e) {
                error("weak_learner_params", "Provided parameters are not in the valid json format. Got error: " + e.getMessage());
            }
        }
    }

    private boolean useCustomWeakLearnerParameters() {
        return (((AdaBoostModel.AdaBoostParameters) this._parms)._weak_learner_params == null || ((AdaBoostModel.AdaBoostParameters) this._parms)._weak_learner_params.isEmpty()) ? false : true;
    }

    protected ModelBuilder<AdaBoostModel, AdaBoostModel.AdaBoostParameters, AdaBoostModel.AdaBoostOutput>.Driver trainModelImpl() {
        return new AdaBoostDriver(this, null);
    }

    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Experimental;
    }

    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.Binomial};
    }

    public boolean isSupervised() {
        return true;
    }

    /* JADX INFO: Access modifiers changed from: private */
    public ModelBuilder chooseWeakLearner(Frame frame) {
        switch (AnonymousClass1.$SwitchMap$hex$adaboost$AdaBoostModel$Algorithm[((AdaBoostModel.AdaBoostParameters) this._parms)._weak_learner.ordinal()]) {
            case 1:
                return getGLMWeakLearner(frame);
            case 2:
                return getGBMWeakLearner(frame);
            case NumericBin.MIN_INDEX /* 3 */:
                return getDeepLearningWeakLearner(frame);
            case NumericBin.MAX_INDEX /* 4 */:
            default:
                return getDRFWeakLearner(frame);
        }
    }

    private DRF getDRFWeakLearner(Frame frame) {
        DRFModel.DRFParameters dRFParameters = useCustomWeakLearnerParameters() ? (DRFModel.DRFParameters) this._gsonParser.fromJson(((AdaBoostModel.AdaBoostParameters) this._parms)._weak_learner_params, DRFModel.DRFParameters.class) : new DRFModel.DRFParameters();
        dRFParameters._train = frame._key;
        dRFParameters._response_column = ((AdaBoostModel.AdaBoostParameters) this._parms)._response_column;
        dRFParameters._weights_column = this._weightsName;
        dRFParameters._seed = ((AdaBoostModel.AdaBoostParameters) this._parms)._seed;
        if (!useCustomWeakLearnerParameters()) {
            dRFParameters._mtries = 1;
            dRFParameters._min_rows = 1.0d;
            dRFParameters._ntrees = 1;
            dRFParameters._sample_rate = 1.0d;
            dRFParameters._max_depth = 1;
        }
        return new DRF(dRFParameters);
    }

    private GLM getGLMWeakLearner(Frame frame) {
        GLMModel.GLMParameters gLMParameters = useCustomWeakLearnerParameters() ? (GLMModel.GLMParameters) this._gsonParser.fromJson(((AdaBoostModel.AdaBoostParameters) this._parms)._weak_learner_params, GLMModel.GLMParameters.class) : new GLMModel.GLMParameters();
        gLMParameters._train = frame._key;
        gLMParameters._response_column = ((AdaBoostModel.AdaBoostParameters) this._parms)._response_column;
        gLMParameters._weights_column = this._weightsName;
        gLMParameters._seed = ((AdaBoostModel.AdaBoostParameters) this._parms)._seed;
        return new GLM(gLMParameters);
    }

    private GBM getGBMWeakLearner(Frame frame) {
        GBMModel.GBMParameters gBMParameters = useCustomWeakLearnerParameters() ? (GBMModel.GBMParameters) this._gsonParser.fromJson(((AdaBoostModel.AdaBoostParameters) this._parms)._weak_learner_params, GBMModel.GBMParameters.class) : new GBMModel.GBMParameters();
        gBMParameters._train = frame._key;
        gBMParameters._response_column = ((AdaBoostModel.AdaBoostParameters) this._parms)._response_column;
        gBMParameters._weights_column = this._weightsName;
        if (!useCustomWeakLearnerParameters()) {
            gBMParameters._min_rows = 1.0d;
            gBMParameters._ntrees = 1;
            gBMParameters._sample_rate = 1.0d;
            gBMParameters._max_depth = 1;
            gBMParameters._seed = ((AdaBoostModel.AdaBoostParameters) this._parms)._seed;
        }
        return new GBM(gBMParameters);
    }

    private DeepLearning getDeepLearningWeakLearner(Frame frame) {
        DeepLearningModel.DeepLearningParameters deepLearningParameters = useCustomWeakLearnerParameters() ? (DeepLearningModel.DeepLearningParameters) this._gsonParser.fromJson(((AdaBoostModel.AdaBoostParameters) this._parms)._weak_learner_params, DeepLearningModel.DeepLearningParameters.class) : new DeepLearningModel.DeepLearningParameters();
        deepLearningParameters._train = frame._key;
        deepLearningParameters._response_column = ((AdaBoostModel.AdaBoostParameters) this._parms)._response_column;
        deepLearningParameters._weights_column = this._weightsName;
        deepLearningParameters._seed = ((AdaBoostModel.AdaBoostParameters) this._parms)._seed;
        if (!useCustomWeakLearnerParameters()) {
            deepLearningParameters._epochs = 10.0d;
            deepLearningParameters._hidden = new int[]{2};
        }
        return new DeepLearning(deepLearningParameters);
    }

    public TwoDimTable createModelSummaryTable() {
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        ArrayList arrayList3 = new ArrayList();
        arrayList.add("Number of weak learners");
        arrayList2.add("int");
        arrayList3.add("%d");
        arrayList.add("Learn rate");
        arrayList2.add("int");
        arrayList3.add("%d");
        arrayList.add("Weak learner");
        arrayList2.add("int");
        arrayList3.add("%d");
        arrayList.add("Seed");
        arrayList2.add("long");
        arrayList3.add("%d");
        TwoDimTable twoDimTable = new TwoDimTable("Model Summary", (String) null, new String[1], (String[]) arrayList.toArray(new String[0]), (String[]) arrayList2.toArray(new String[0]), (String[]) arrayList3.toArray(new String[0]), "");
        int i = 0 + 1;
        twoDimTable.set(0, 0, Integer.valueOf(((AdaBoostModel.AdaBoostParameters) this._parms)._nlearners));
        int i2 = i + 1;
        twoDimTable.set(0, i, Double.valueOf(((AdaBoostModel.AdaBoostParameters) this._parms)._learn_rate));
        twoDimTable.set(0, i2, ((AdaBoostModel.AdaBoostParameters) this._parms)._weak_learner.toString());
        twoDimTable.set(0, i2 + 1, Long.valueOf(((AdaBoostModel.AdaBoostParameters) this._parms)._seed));
        return twoDimTable;
    }
}
