package hex.generic;

import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsAutoEncoder;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsClustering;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsOrdinal;
import hex.ModelMetricsRegression;
import hex.ModelMetricsRegressionCoxPH;
import hex.genmodel.CategoricalEncoding;
import hex.genmodel.GenModel;
import hex.genmodel.ModelMojoReader;
import hex.genmodel.MojoModel;
import hex.genmodel.MojoReaderBackendFactory;
import hex.genmodel.algos.kmeans.KMeansMojoModel;
import hex.genmodel.descriptor.ModelDescriptor;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import hex.tree.isofor.ModelMetricsAnomaly;
import java.io.IOException;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import water.Futures;
import water.H2O;
import water.Iced;
import water.Job;
import water.JobUpdatePostMap;
import water.Key;
import water.MRTask;
import water.fvec.ByteVec;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.RowDataUtils;

/**
 * A model backed by an imported MOJO file.
 *
 * <p>Scoring, feature-contribution computation and metric-builder selection are delegated to a
 * {@link MojoModel}. The raw MOJO bytes live in a {@link Frame} referenced by key; the parsed
 * {@link MojoModel} is kept in a transient field and is re-read from those bytes on demand when
 * it is absent (e.g. after this object has been deserialized, since transient state is not
 * carried along).
 */
public class GenericModel extends Model<GenericModel, GenericModelParameters, GenericModelOutput> implements Model.Contributions {
    /** Source of the MOJO: key of the byte frame plus a lazily (re)built {@link MojoModel}. */
    private final MojoModelSource _mojoModelSource;

    /**
     * Computes per-row feature contributions by feeding every row of the input frame to a
     * contribution-enabled {@link EasyPredictModelWrapper}.
     *
     * <p>Deliberately an inner (non-static) class: when the transient wrapper is missing (it is
     * not serialized with the task), {@link #setupLocal()} rebuilds it from the enclosing model.
     */
    public class GenericScoreContributionsTask extends MRTask<GenericScoreContributionsTask> {
        /** Not serialized; recreated in {@link #setupLocal()} when null. */
        private transient EasyPredictModelWrapper _wrapper;

        GenericScoreContributionsTask(EasyPredictModelWrapper wrapper) {
            this._wrapper = wrapper;
        }

        @Override // water.MRTask
        public void setupLocal() {
            if (_wrapper == null) {
                _wrapper = GenericModel.this.makeWrapperWithContributions();
            }
        }

        @Override // water.MRTask
        public void map(Chunk[] chunks, NewChunk[] outputs) {
            try {
                predict(chunks, outputs);
            } catch (PredictException e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Converts each row of the chunk set into {@link RowData} and appends the row's
         * contributions to {@code outputs}.
         */
        private void predict(Chunk[] chunks, NewChunk[] outputs) throws PredictException {
            // NOTE(review): one RowData instance is reused for all rows; this assumes
            // RowDataUtils.extractChunkRow overwrites every entry of it — confirm.
            RowData row = new RowData();
            byte[] types = _fr.types();
            for (int i = 0; i < chunks[0]._len; i++) {
                RowDataUtils.extractChunkRow(chunks, _fr._names, types, i, row);
                NewChunk.addNums(outputs, _wrapper.predictContributions(row));
            }
        }
    }

    /**
     * Placeholder metric builder for model categories whose metrics cannot be computed. It passes
     * per-row predictions through unchanged and never produces a {@link ModelMetrics} object.
     */
    public static class MetricBuilderGeneric extends ModelMetrics.MetricBuilder<MetricBuilderGeneric> {
        /** @param predsSize length of the per-row work buffer */
        private MetricBuilderGeneric(int predsSize) {
            this._work = new double[predsSize];
        }

        @Override // hex.ModelMetrics.MetricBuilder
        public double[] perRow(double[] ds, float[] yact, Model model) {
            return ds;
        }

        @Override // hex.ModelMetrics.MetricBuilder
        public ModelMetrics makeModelMetrics(Model model, Frame frame, Frame frame2, Frame frame3) {
            // No metrics are available for this model category.
            return null;
        }
    }

    /**
     * Holds the key of the frame containing the raw MOJO bytes, plus a lazily built, transient
     * {@link MojoModel} parsed from those bytes.
     */
    public static class MojoModelSource extends Iced<MojoModelSource> {
        private final Key<Frame> _mojoSource;

        /**
         * Not serialized. Must be {@code volatile}: {@link #get()} uses double-checked locking,
         * which only publishes the instance safely to other threads through a volatile field.
         */
        private transient volatile MojoModel _mojoModel;

        MojoModelSource(Key<Frame> mojoSource, MojoModel mojoModel) {
            this._mojoSource = mojoSource;
            this._mojoModel = mojoModel;
        }

        /**
         * Returns the byte vector holding the MOJO file.
         *
         * @throws IllegalStateException if the frame referenced by the source key is not available
         */
        public ByteVec mojoByteVec() {
            Frame source = _mojoSource.get();
            if (source == null) {
                throw new IllegalStateException("MOJO source frame is not available: " + _mojoSource);
            }
            return (ByteVec) source.anyVec();
        }

        /** Returns the MOJO model, parsing it from the source bytes on first use. */
        MojoModel get() {
            MojoModel model = _mojoModel; // single volatile read on the fast path
            if (model == null) {
                synchronized (this) {
                    model = _mojoModel;
                    if (model == null) {
                        model = GenericModel.reconstructMojo(mojoByteVec());
                        _mojoModel = model;
                    }
                }
            }
            assert model != null;
            return model;
        }
    }

    /**
     * Creates a generic model around an already parsed MOJO.
     *
     * @param key key of this model
     * @param genericModelParameters model parameters; {@code _modelParameters} is filled from the
     *     MOJO's model attributes when the MOJO provides them
     * @param genericModelOutput passed to the superclass, then replaced by an output rebuilt from
     *     the MOJO's descriptor, attributes and reproducibility information
     * @param mojoModel the parsed MOJO
     * @param key2 key of the frame holding the raw MOJO bytes
     */
    public GenericModel(Key<GenericModel> key, GenericModelParameters genericModelParameters, GenericModelOutput genericModelOutput, MojoModel mojoModel, Key<Frame> key2) {
        super(key, genericModelParameters, genericModelOutput);
        this._mojoModelSource = new MojoModelSource(key2, mojoModel);
        this._output = new GenericModelOutput(mojoModel._modelDescriptor, mojoModel._modelAttributes, mojoModel._reproducibilityInformation);
        if (mojoModel._modelAttributes != null && mojoModel._modelAttributes.getModelParameters() != null) {
            _parms._modelParameters = GenericModelParameters.convertParameters(mojoModel._modelAttributes.getModelParameters());
        }
    }

    /**
     * Parses a MOJO from the given byte vector, buffering its content in memory.
     *
     * @throws IllegalStateException if the MOJO cannot be read
     */
    public static MojoModel reconstructMojo(ByteVec byteVec) {
        try {
            return ModelMojoReader.readFrom(MojoReaderBackendFactory.createReaderBackend(byteVec.openStream(null), MojoReaderBackendFactory.CachingStrategy.MEMORY), true);
        } catch (IOException e) {
            throw new IllegalStateException("Unreachable MOJO file: " + byteVec._key, e);
        }
    }

    /**
     * Picks the metric builder matching the MOJO's model category.
     *
     * @throws IllegalStateException for the {@code Unknown} category
     * @throws UnsupportedOperationException for unsupported categories unless the algorithm check
     *     is disabled in the parameters
     */
    @Override // hex.Model
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        switch (_output.getModelCategory()) {
            case Unknown:
                throw new IllegalStateException("Model category is unknown");
            case Binomial:
                return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
            case Multinomial:
                return new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(), domain, _parms._auc_type);
            case Ordinal:
                return new ModelMetricsOrdinal.MetricBuilderOrdinal(_output.nclasses(), domain);
            case Regression:
                return new ModelMetricsRegression.MetricBuilderRegression();
            case Clustering: {
                // Clustering metrics need the cluster count, which only KMeans MOJOs expose here.
                MojoModel mojo = mojoModel();
                if (!(mojo instanceof KMeansMojoModel)) {
                    return unsupportedMetricsBuilder();
                }
                return new ModelMetricsClustering.MetricBuilderClustering(_output.nfeatures(), ((KMeansMojoModel) mojo).getNumClusters());
            }
            case AutoEncoder:
                return new ModelMetricsAutoEncoder.MetricBuilderAutoEncoder(_output.nfeatures());
            case DimReduction:
                return unsupportedMetricsBuilder();
            case WordEmbedding:
                return unsupportedMetricsBuilder();
            case CoxPH:
                // NOTE(review): start/stop column names are hard-coded here — confirm they match
                // the names CoxPH MOJOs are scored with.
                return new ModelMetricsRegressionCoxPH.MetricBuilderRegressionCoxPH("start", "stop", false, new String[0]);
            case AnomalyDetection:
                return new ModelMetricsAnomaly.MetricBuilderAnomaly();
            default:
                throw H2O.unimpl();
        }
    }

    /**
     * Fails for categories without metric support. When {@code _disable_algo_check} is set, it
     * logs a warning and returns a builder that produces no metrics.
     */
    private ModelMetrics.MetricBuilder unsupportedMetricsBuilder() {
        if (!_parms._disable_algo_check) {
            throw new UnsupportedOperationException(_output._modelCategory + " is not supported.");
        }
        Log.warn("Model category `" + _output._modelCategory + "` currently doesn't support calculating model metrics. Model metrics will not be available.");
        return new MetricBuilderGeneric(mojoModel().getPredsSize(_output._modelCategory));
    }

    @Override // hex.Model
    protected double[] score0(double[] data, double[] preds) {
        return mojoModel().score0(data, preds);
    }

    /**
     * Scores a row with an offset. A zero offset is treated as "no offset" and uses the plain
     * scoring path.
     */
    @Override // hex.Model
    protected double[] score0(double[] data, double[] preds, double offset) {
        return offset == 0 ? score0(data, preds) : mojoModel().score0(data, offset, preds);
    }

    /**
     * Describes how input frames are adapted for this model, using the categorical encoding and
     * special-column names recorded in the MOJO.
     *
     * @throws UnsupportedOperationException if the MOJO uses a parametrized categorical encoding
     */
    @Override // hex.Model
    protected Model.AdaptFrameParameters makeAdaptFrameParameters() {
        final MojoModel mojoModel = mojoModel();
        CategoricalEncoding categoricalEncoding = mojoModel.getCategoricalEncoding();
        if (categoricalEncoding.isParametrized()) {
            throw new UnsupportedOperationException("Models with categorical encoding '" + categoricalEncoding + "' are not currently supported for predicting and/or calculating metrics.");
        }
        final Model.Parameters.CategoricalEncodingScheme encodingScheme = Model.Parameters.CategoricalEncodingScheme.fromGenModel(categoricalEncoding);
        final ModelDescriptor modelDescriptor = mojoModel._modelDescriptor;
        return new Model.AdaptFrameParameters() {
            @Override // hex.Model.AdaptFrameParameters
            public Model.Parameters.CategoricalEncodingScheme getCategoricalEncoding() {
                return encodingScheme;
            }

            @Override // hex.Model.AdaptFrameParameters
            public String getWeightsColumn() {
                return modelDescriptor.weightsColumn();
            }

            @Override // hex.Model.AdaptFrameParameters
            public String getOffsetColumn() {
                return modelDescriptor.offsetColumn();
            }

            @Override // hex.Model.AdaptFrameParameters
            public String getFoldColumn() {
                return modelDescriptor.foldColumn();
            }

            @Override // hex.Model.AdaptFrameParameters
            public String getResponseColumn() {
                // Unsupervised MOJOs have no response column.
                return mojoModel.isSupervised() ? mojoModel.getResponseName() : null;
            }

            @Override // hex.Model.AdaptFrameParameters
            public double missingColumnsType() {
                return Double.NaN;
            }

            @Override // hex.Model.AdaptFrameParameters
            public int getMaxCategoricalLevels() {
                return -1;
            }
        };
    }

    /** Returns the output column names declared by the MOJO. */
    @Override // hex.Model
    public String[] makeScoringNames() {
        return mojoModel().getOutputNames();
    }

    @Override // hex.Model
    protected boolean needsPostProcess() {
        return false;
    }

    /** Returns a writer that re-exports the original MOJO bytes. */
    @Override // hex.Model
    public GenericModelMojoWriter getMojo() {
        return new GenericModelMojoWriter(_mojoModelSource.mojoByteVec());
    }

    private MojoModel mojoModel() {
        return _mojoModelSource.get();
    }

    /**
     * Removes this model. When the model was created from a {@code _path}, the frame holding the
     * MOJO bytes is removed as well (presumably because the model imported that frame itself —
     * confirm).
     */
    @Override // hex.Model, water.Keyed
    public Futures remove_impl(Futures futures, boolean cascade) {
        if (_parms._path != null) {
            Frame mojoFrame = _mojoModelSource._mojoSource.get();
            if (mojoFrame != null) {
                mojoFrame.remove(futures, cascade);
            }
        }
        return super.remove_impl(futures, cascade);
    }

    @Override // hex.Model.Contributions
    public Frame scoreContributions(Frame frame, Key<Frame> key) {
        return scoreContributions(frame, key, null);
    }

    /**
     * Computes feature contributions for every row of {@code frame}. Only the columns known to
     * the MOJO are passed to the scoring task.
     */
    @Override // hex.Model.Contributions
    public Frame scoreContributions(Frame frame, Key<Frame> key, Job<Frame> job) {
        EasyPredictModelWrapper wrapper = makeWrapperWithContributions();
        GenModel genModel = wrapper.getModel();
        String[] modelColumns = genModel.getOrigNames() != null ? genModel.getOrigNames() : genModel.getNames();
        // Work on a shallow copy so the caller's frame keeps all of its columns.
        Frame scoringFrame = new Frame(frame);
        scoringFrame.remove(ArrayUtils.difference(frame._names, modelColumns));
        String[] contributionNames = wrapper.getContributionNames();
        // (byte) 3 is the output column type; presumably water.fvec.Vec.T_NUM (numeric) — confirm.
        return new GenericScoreContributionsTask(wrapper)
                .withPostMapAction(JobUpdatePostMap.forJob(job))
                .doAll(contributionNames.length, (byte) 3, scoringFrame)
                .outputFrame(key, contributionNames, (String[][]) null);
    }

    /**
     * Builds an {@link EasyPredictModelWrapper} with contributions enabled. Unknown categorical
     * levels are converted to NA.
     */
    EasyPredictModelWrapper makeWrapperWithContributions() {
        try {
            return new EasyPredictModelWrapper(new EasyPredictModelWrapper.Config()
                    .setModel(mojoModel())
                    .setConvertUnknownCategoricalLevelsToNa(true)
                    .setEnableContributions(true));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
}
