package hex;

import hex.Model;
import java.util.Arrays;
import jsr166y.CountedCompleter;
import water.Futures;
import water.H2O;
import water.Job;
import water.Key;
import water.Keyed;
import water.Lockable;
import water.api.schemas3.KeyV3;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;
import water.util.TwoDimTable;

/* loaded from: input_file:hex/PartialDependence.class */
/**
 * Computes the partial dependence of a supervised model's response on individual columns.
 *
 * <p>For every column in {@link #_cols}, a grid of at most {@link #_nbins} evenly spaced values
 * between the column's min and max is built. For each grid value, a copy of the frame with that
 * column replaced by the constant is scored. The mean, standard deviation and standard error of
 * the mean of the model response are stored in one {@link TwoDimTable} per column.
 */
public class PartialDependence extends Lockable<PartialDependence> {
    /** Job used for locking, progress reporting and cancellation. */
    public final transient Job _job;
    /** Model whose response is probed. */
    public Key<Model> _model_id;
    /** Frame whose rows are scored with one column replaced by a constant. */
    public Key<Frame> _frame_id;
    /**
     * Columns to compute partial dependence for. When null, the ten most important features
     * reported by the model are used (if the model exposes them).
     */
    public String[] _cols;
    /** Maximum number of grid points per column; must be at least 2 (defaults to 20). */
    public int _nbins;
    /** Result: one table per entry of {@link #_cols}, filled in by the driver. */
    public TwoDimTable[] _partial_dependence_data;

    /**
     * Drives the computation. Columns are processed one after another; within a column, all
     * grid values are scored as parallel tasks. Locks taken in {@code execNested}/{@code execImpl}
     * are released when the driver completes, normally or exceptionally.
     */
    private class PartialDependenceDriver extends H2O.H2OCountedCompleter<PartialDependenceDriver> {

        private PartialDependenceDriver() {
        }

        @Override // water.H2O.H2OCountedCompleter
        public void compute2() {
            assert _job != null;
            Frame frame = _frame_id.get();
            _partial_dependence_data = new TwoDimTable[_cols.length];
            for (int c = 0; c < _cols.length; c++) {
                final String col = _cols[c];
                Log.debug("Computing partial dependence of model on '" + col + "'.");
                Vec vec = frame.vec(col);
                final boolean isCategorical = vec.isCategorical();
                final double[] grid = makeGrid(vec, _nbins);
                Log.debug("Computing PartialDependence for column " + col + " at the following values: ");
                Log.debug(Arrays.toString(grid));

                final double[] meanResponse = new double[grid.length];
                final double[] stddevResponse = new double[grid.length];
                final double[] stdErrMeanResponse = new double[grid.length];
                Futures fs = new Futures();
                for (int k = 0; k < grid.length; k++) {
                    final int row = k;
                    // Each task writes only its own index of the result arrays.
                    fs.add(H2O.submitTask(new H2O.H2OCountedCompleter() {
                        @Override // water.H2O.H2OCountedCompleter
                        public void compute2() {
                            scoreAtValue(col, grid[row], isCategorical, row,
                                    meanResponse, stddevResponse, stdErrMeanResponse);
                            tryComplete();
                        }
                    }));
                }
                fs.blockForPending();

                _partial_dependence_data[c] = makeTable(col, vec, isCategorical, grid,
                        meanResponse, stddevResponse, stdErrMeanResponse);
                _job.update(1L);
                PartialDependence.this.update(_job);
                if (_job.stop_requested()) {
                    break;
                }
            }
            tryComplete();
        }

        @Override // jsr166y.CountedCompleter
        public void onCompletion(CountedCompleter caller) {
            releaseLocks();
        }

        @Override // jsr166y.CountedCompleter
        public boolean onExceptionalCompletion(Throwable ex, CountedCompleter caller) {
            releaseLocks();
            return true;
        }
    }

    public PartialDependence(Key<PartialDependence> key, Job job) {
        super(key);
        this._nbins = 20;
        this._job = job;
    }

    public PartialDependence(Key<PartialDependence> key) {
        this(key, new Job(key, PartialDependence.class.getName(), "PartialDependence"));
    }

    /**
     * Runs the computation synchronously on the calling thread.
     *
     * @return this object, with {@link #_partial_dependence_data} filled in
     * @throws IllegalArgumentException if the parameters fail validation
     */
    public PartialDependence execNested() {
        lockForComputation();
        new PartialDependenceDriver().compute2();
        return this;
    }

    /**
     * Starts the computation asynchronously under {@link #_job}.
     *
     * @return the started job; its work units equal the number of columns
     * @throws IllegalArgumentException if the parameters fail validation
     */
    public Job<PartialDependence> execImpl() {
        lockForComputation();
        this._job.start(new PartialDependenceDriver(), this._cols.length);
        return this._job;
    }

    /** Validates the parameters, then locks this object and write-locks the input frame. */
    @SuppressWarnings("unchecked")
    private void lockForComputation() {
        checkSanityAndFillParams();
        delete_and_lock(this._job);
        this._frame_id.get().write_lock((Key<Job>) this._job._key);
    }

    /** Releases the frame lock and this object's lock taken by {@link #lockForComputation()}. */
    @SuppressWarnings("unchecked")
    private void releaseLocks() {
        Frame frame = _frame_id.get();
        // Guard so that a missing frame does not mask the exception that triggered completion.
        if (frame != null) {
            frame.unlock((Key<Job>) _job._key);
        }
        unlock(_job);
    }

    /**
     * Builds the evaluation grid for one column: evenly spaced values from min to max.
     * Integer columns whose value range is narrower than {@code nbins} get one point per integer.
     * For categorical columns the values are level indices.
     */
    private static double[] makeGrid(Vec vec, int nbins) {
        double min = vec.min();
        double range = vec.max() - min;
        int points = nbins;
        if (vec.isInt() && range + 1.0d < nbins) {
            points = (int) (range + 1.0d);
        }
        // A single point (constant integer column) has no spacing; avoid 0/0.
        double step = points == 1 ? 0.0d : range / (points - 1);
        double[] grid = new double[points];
        for (int i = 0; i < points; i++) {
            grid[i] = min + (i * step);
        }
        return grid;
    }

    /**
     * Scores a copy of the frame with {@code col} replaced by the constant {@code value} and
     * stores the response's mean, standard deviation and standard error of the mean at index
     * {@code row} of the corresponding arrays.
     *
     * <p>Binomial models read prediction column 2 (NOTE(review): presumably the positive-class
     * probability, so confirm); regression models read prediction column 0. Temporary
     * predictions and the constant vec are always removed, even when scoring fails.
     */
    private void scoreAtValue(String col, double value, boolean isCategorical, int row,
                              double[] mean, double[] stddev, double[] stdErr) {
        Frame fr = _frame_id.get();
        Frame test = new Frame(fr.names(), fr.vecs());
        Vec cons = test.remove(col).makeCon(value);
        Frame preds = null;
        try {
            if (isCategorical) {
                cons.setDomain(fr.vec(col).domain());
            }
            test.add(col, cons);
            Model model = _model_id.get();
            preds = model.score(test, Key.make().toString(), _job, false);
            int nclasses = model._output.nclasses();
            Vec response;
            if (nclasses == 2) {
                response = preds.vec(2);
            } else if (nclasses == 1) {
                response = preds.vec(0);
            } else {
                throw H2O.unimpl();
            }
            double sigma = response.sigma();
            mean[row] = response.mean();
            stddev[row] = sigma;
            stdErr[row] = sigma / Math.sqrt(preds.numRows());
        } finally {
            if (preds != null) {
                preds.remove();
            }
            cons.remove();
        }
    }

    /** Builds the result table for one column from the grid and the per-value statistics. */
    private TwoDimTable makeTable(String col, Vec vec, boolean isCategorical, double[] grid,
                                  double[] mean, double[] stddev, double[] stdErr) {
        String description = "Partial Dependence Plot of model " + _model_id + " on column '" + col + "'";
        String[] colHeaders = {col, "mean_response", "stddev_response", "std_error_mean_response"};
        String[] colTypes = {isCategorical ? "string" : "double", "double", "double", "double"};
        String[] colFormats = {isCategorical ? "%s" : "%5f", "%5f", "%5f", "%5f"};
        TwoDimTable table = new TwoDimTable("PartialDependence", description, new String[grid.length],
                colHeaders, colTypes, colFormats, null);
        String[] domain = isCategorical ? vec.domain() : null;
        for (int r = 0; r < grid.length; r++) {
            if (isCategorical) {
                // Categorical grid values are level indices; show the level name.
                table.set(r, 0, domain[(int) grid[r]]);
            } else {
                table.set(r, 0, Double.valueOf(grid[r]));
            }
            table.set(r, 1, Double.valueOf(mean[r]));
            table.set(r, 2, Double.valueOf(stddev[r]));
            table.set(r, 3, Double.valueOf(stdErr[r]));
        }
        return table;
    }

    /**
     * Validates the parameters and, when {@link #_cols} is null, fills it with the model's ten
     * most important features.
     *
     * @throws IllegalArgumentException if the model or frame is missing; if the model is not a
     *     regression or binomial model; if no columns can be determined; if {@code _nbins < 2};
     *     if a column is absent from the frame; or if a categorical column has more levels than
     *     {@code _nbins}
     */
    private void checkSanityAndFillParams() {
        Model model = this._model_id == null ? null : this._model_id.get();
        if (model == null) {
            throw new IllegalArgumentException("Model not found.");
        }
        if (!model._output.isSupervised() || model._output.nclasses() > 2) {
            throw new IllegalArgumentException("Partial dependence plots are only implemented for regression and binomial classification models");
        }
        Frame frame = this._frame_id == null ? null : this._frame_id.get();
        if (frame == null) {
            throw new IllegalArgumentException("Frame not found.");
        }
        if (this._cols == null && model instanceof Model.GetMostImportantFeatures) {
            this._cols = ((Model.GetMostImportantFeatures) model).getMostImportantFeatures(10);
            if (this._cols != null) {
                Log.info("Selecting the top " + this._cols.length + " features from the model's variable importances");
            }
        }
        if (this._cols == null) {
            throw new IllegalArgumentException("No columns given and the model does not provide variable importances to select them from.");
        }
        if (this._nbins < 2) {
            throw new IllegalArgumentException("_nbins must be >=2.");
        }
        for (String col : this._cols) {
            Vec vec = frame.vec(col);
            if (vec == null) {
                throw new IllegalArgumentException("Column " + col + " not found in frame.");
            }
            if (vec.isCategorical() && vec.cardinality() > this._nbins) {
                throw new IllegalArgumentException("Column " + col + "'s cardinality of " + vec.cardinality() + " > nbins of " + this._nbins);
            }
        }
    }

    @Override // water.Keyed
    public Class<KeyV3.PartialDependenceKeyV3> makeSchema() {
        return KeyV3.PartialDependenceKeyV3.class;
    }
}
