package ai.h2o.sparkling.ml.models;

import ai.h2o.sparkling.ml.params.H2OSupervisedMOJOParams;
import ai.h2o.sparkling.ml.params.NullableStringParam;
import hex.ModelCategory;
import hex.genmodel.MojoModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import java.io.InputStream;
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.Function1;
import scala.Predef$;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: H2OSupervisedMOJOModel.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005%b\u0001B\u0006\r\u0001]A\u0001B\t\u0001\u0003\u0006\u0004%\te\t\u0005\nc\u0001\u0011\t\u0011)A\u0005IIBQa\r\u0001\u0005\u0002QBQa\u000e\u0001\u0005BaBQa\u0011\u0001\u0005B\u0011CQ\u0001\u0018\u0001\u0005Ru;q!!\u0001\r\u0011\u0003\t\u0019A\u0002\u0004\f\u0019!\u0005\u0011Q\u0001\u0005\u0007g!!\t!a\u0005\t\u0013\u0005U\u0001\"!A\u0005\n\u0005]!A\u0006%3\u001fN+\b/\u001a:wSN,G-T(K\u001f6{G-\u001a7\u000b\u00055q\u0011AB7pI\u0016d7O\u0003\u0002\u0010!\u0005\u0011Q\u000e\u001c\u0006\u0003#I\t\u0011b\u001d9be.d\u0017N\\4\u000b\u0005M!\u0012a\u000153_*\tQ#\u0001\u0002bS\u000e\u00011c\u0001\u0001\u00199A\u0011\u0011DG\u0007\u0002\u0019%\u00111\u0004\u0004\u0002\r\u0011JzUj\u0014&P\u001b>$W\r\u001c\t\u0003;\u0001j\u0011A\b\u0006\u0003?9\ta\u0001]1sC6\u001c\u0018BA\u0011\u001f\u0005]A%gT*va\u0016\u0014h/[:fI6{%j\u0014)be\u0006l7/A\u0002vS\u0012,\u0012\u0001\n\t\u0003K9r!A\n\u0017\u0011\u0005\u001dRS\"\u0001\u0015\u000b\u0005%2\u0012A\u0002\u001fs_>$hHC\u0001,\u0003\u0015\u00198-\u00197b\u0013\ti#&\u0001\u0004Qe\u0016$WMZ\u0005\u0003_A\u0012aa\u0015;sS:<'BA\u0017+\u0003\u0011)\u0018\u000e\u001a\u0011\n\u0005\tR\u0012A\u0002\u001fj]&$h\b\u0006\u00026mA\u0011\u0011\u0004\u0001\u0005\u0006E\r\u0001\r\u0001J\u0001\u0012g\u0016$8\u000b]3dS\u001aL7\rU1sC6\u001cHCA\u001b:\u0011\u0015QD\u00011\u0001<\u0003%iwN[8N_\u0012,G\u000e\u0005\u0002=\u00036\tQH\u0003\u0002?\u007f\u0005Aq-\u001a8n_\u0012,GNC\u0001A\u0003\rAW\r_\u0005\u0003\u0005v\u0012\u0011\"T8k_6{G-\u001a7\u0002\u001fQ\u0014\u0018M\\:g_Jl7k\u00195f[\u0006$\"!R*\u0011\u0005\u0019\u000bV\"A$\u000b\u0005!K\u0015!\u0002;za\u0016\u001c(B\u0001&L\u0003\r\u0019\u0018\u000f\u001c\u0006\u0003\u00196\u000bQa\u001d9be.T!AT(\u0002\r\u0005\u0004\u0018m\u00195f\u0015\u0005\u0001\u0016aA8sO&\u0011!k\u0012\u0002\u000b'R\u0014Xo\u0019;UsB,\u0007\"\u0002+\u0006\u0001\u0004)\u0015AB:dQ\u0016l\u0017\r\u000b\u0002\u0006-B\u0011qKW\u0007\u00021*\u0011\u0011lS\u0001\u000bC:tw\u000e^1uS>t\u0017BA.Y\u00051!UM
^3m_B,'/\u00119j\u0003\u0005\n\u0007\u000f\u001d7z!J,G-[2uS>tW\u000b\u001a4U_\u001ac\u0017\r\u001e#bi\u00064%/Y7f)\u0011qVn\u001c@\u0011\u0005}SgB\u00011i\u001d\t\twM\u0004\u0002cM:\u00111-\u001a\b\u0003O\u0011L\u0011\u0001U\u0005\u0003\u001d>K!\u0001T'\n\u0005)[\u0015BA5J\u0003\u001d\u0001\u0018mY6bO\u0016L!a\u001b7\u0003\u0013\u0011\u000bG/\u0019$sC6,'BA5J\u0011\u0015qg\u00011\u0001_\u000351G.\u0019;ECR\fgI]1nK\")\u0001O\u0002a\u0001c\u0006qQ\u000f\u001a4D_:\u001cHO];di>\u0014\b\u0003\u0002:tkbl\u0011AK\u0005\u0003i*\u0012\u0011BR;oGRLwN\\\u0019\u0011\u0007I4H%\u0003\u0002xU\t)\u0011I\u001d:bsB\u0011\u0011\u0010`\u0007\u0002u*\u001110S\u0001\fKb\u0004(/Z:tS>t7/\u0003\u0002~u\n\u0019Rk]3s\t\u00164\u0017N\\3e\rVt7\r^5p]\")qP\u0002a\u0001k\u00061\u0011N\u001c9viN\fa\u0003\u0013\u001aP'V\u0004XM\u001d<jg\u0016$Wj\u0014&P\u001b>$W\r\u001c\t\u00033!\u0019R\u0001CA\u0004\u0003\u001b\u0001B!GA\u0005k%\u0019\u00111\u0002\u0007\u0003+!\u0013tj\u00159fG&4\u0017nY'P\u0015>cu.\u00193feB\u0019!/a\u0004\n\u0007\u0005E!F\u0001\u0007TKJL\u0017\r\\5{C\ndW\r\u0006\u0002\u0002\u0004\u0005Y!/Z1e%\u0016\u001cx\u000e\u001c<f)\t\tI\u0002\u0005\u0003\u0002\u001c\u0005\u0015RBAA\u000f\u0015\u0011\ty\"!\t\u0002\t1\fgn\u001a\u0006\u0003\u0003G\tAA[1wC&!\u0011qEA\u000f\u0005\u0019y%M[3di\u0002")
/* loaded from: input_file:ai/h2o/sparkling/ml/models/H2OSupervisedMOJOModel.class */
public class H2OSupervisedMOJOModel extends H2OMOJOModel implements H2OSupervisedMOJOParams {
    /** Error text shared by schema validation and scoring when the offset column is missing. */
    private static final String MISSING_OFFSET_COLUMN_MESSAGE = "Offset column must be present within the dataset!";

    /**
     * Backing field for the {@code offsetCol} parameter.
     *
     * <p>Not {@code final}: it is assigned through the trait setter
     * {@code ai$h2o$sparkling$ml$params$H2OSupervisedMOJOParams$_setter_$offsetCol_$eq}
     * (invoked from {@code H2OSupervisedMOJOParams.$init$} in the constructor), not by a
     * constructor assignment. Java forbids assigning a final field from a method.
     */
    private NullableStringParam offsetCol;

    /** Static forwarder to the companion object's {@code createFromMojo(InputStream, String, H2OMOJOSettings)}. */
    public static HasMojo createFromMojo(InputStream inputStream, String str, H2OMOJOSettings h2OMOJOSettings) {
        return H2OSupervisedMOJOModel$.MODULE$.createFromMojo(inputStream, str, h2OMOJOSettings);
    }

    /** Static forwarder to the companion object's {@code createFromMojo(InputStream, String)}. */
    public static Object createFromMojo(InputStream inputStream, String str) {
        return H2OSupervisedMOJOModel$.MODULE$.createFromMojo(inputStream, str);
    }

    /** Static forwarder to the companion object's {@code createFromMojo(String, String, H2OMOJOSettings)}. */
    public static Object createFromMojo(String str, String str2, H2OMOJOSettings h2OMOJOSettings) {
        return H2OSupervisedMOJOModel$.MODULE$.createFromMojo(str, str2, h2OMOJOSettings);
    }

    /** Static forwarder to the companion object's {@code createFromMojo(String, String)}. */
    public static Object createFromMojo(String str, String str2) {
        return H2OSupervisedMOJOModel$.MODULE$.createFromMojo(str, str2);
    }

    /** Static forwarder to the companion object's {@code createFromMojo(String, H2OMOJOSettings)}. */
    public static Object createFromMojo(String str, H2OMOJOSettings h2OMOJOSettings) {
        return H2OSupervisedMOJOModel$.MODULE$.createFromMojo(str, h2OMOJOSettings);
    }

    /** Static forwarder to the companion object's {@code createFromMojo(String)}. */
    public static Object createFromMojo(String str) {
        return H2OSupervisedMOJOModel$.MODULE$.createFromMojo(str);
    }

    /** Returns the Spark ML reader provided by the companion object. */
    public static MLReader<H2OSupervisedMOJOModel> read() {
        return H2OSupervisedMOJOModel$.MODULE$.read();
    }

    /** Static forwarder to the companion object's {@code load(String)}. */
    public static Object load(String str) {
        return H2OSupervisedMOJOModel$.MODULE$.load(str);
    }

    /**
     * Returns the configured offset column name, or {@code null} when none is set.
     *
     * <p>Delegates to the default implementation in {@link H2OSupervisedMOJOParams}. The
     * previous body called {@code getOffsetCol()} on itself and recursed until stack overflow.
     */
    @Override
    public String getOffsetCol() {
        return H2OSupervisedMOJOParams.super.getOffsetCol();
    }

    /** Returns the {@code offsetCol} parameter object. */
    @Override
    public final NullableStringParam offsetCol() {
        return this.offsetCol;
    }

    /** Trait-field setter generated by the Scala trait encoding; stores the {@code offsetCol} parameter. */
    @Override
    public final void ai$h2o$sparkling$ml$params$H2OSupervisedMOJOParams$_setter_$offsetCol_$eq(NullableStringParam nullableStringParam) {
        this.offsetCol = nullableStringParam;
    }

    /** Returns the identifier held by the superclass (presumably the value passed to the constructor). */
    @Override
    public String uid() {
        return super.uid();
    }

    /**
     * Copies model-specific settings from {@code mojoModel}, including its offset column.
     *
     * @param mojoModel the loaded MOJO whose {@code _offsetColumn} becomes this model's offset column
     * @return this model, for chaining
     */
    @Override
    public H2OSupervisedMOJOModel setSpecificParams(MojoModel mojoModel) {
        super.setSpecificParams(mojoModel);
        set(offsetCol().$minus$greater(mojoModel._offsetColumn));
        return this;
    }

    /**
     * Validates that a configured offset column exists, then defers to the parent schema transform.
     *
     * @throws IllegalArgumentException (from Scala {@code require}) if an offset column is set but
     *     absent from {@code structType}
     */
    @Override
    @DeveloperApi
    public StructType transformSchema(StructType structType) {
        String offsetColumnName = getOffsetCol();
        if (offsetColumnName != null) {
            Predef$.MODULE$.require(containsColumn(structType.fieldNames(), offsetColumnName), () -> MISSING_OFFSET_COLUMN_MESSAGE);
        }
        return super.transformSchema(structType);
    }

    /**
     * Appends the prediction column produced by the UDF built from {@code function1}.
     *
     * <p>For Binomial, Regression, Multinomial and Ordinal models, the UDF receives a second
     * argument: the offset column cast to double, or the literal {@code 0.0} when no offset
     * column is configured. All other categories receive only the feature struct.
     *
     * @param dataset the flattened input frame
     * @param function1 builds the prediction UDF from the relevant column names
     * @param strArr forwarded to {@code getRelevantColumnNames} to select the feature columns
     * @return {@code dataset} with the output column added
     * @throws RuntimeException if an offset column is configured but missing from {@code dataset}
     */
    @Override
    public Dataset<Row> applyPredictionUdfToFlatDataFrame(Dataset<Row> dataset, Function1<String[], UserDefinedFunction> function1, String[] strArr) {
        String[] relevantColumnNames = getRelevantColumnNames(dataset, strArr);
        Column[] featureColumns = new Column[relevantColumnNames.length];
        for (int i = 0; i < relevantColumnNames.length; i++) {
            featureColumns[i] = dataset.apply(quoteColumnName(relevantColumnNames[i]));
        }
        UserDefinedFunction predictionUdf = function1.apply(relevantColumnNames);
        ModelCategory category = ((EasyPredictModelWrapper) H2OMOJOCache$.MODULE$.getMojoBackend(uid(), () -> {
            return this.getMojo();
        }, this)).getModelCategory();

        Column features = functions$.MODULE$.struct(Predef$.MODULE$.wrapRefArray(featureColumns));
        Column[] udfArguments = supportsOffset(category)
                ? new Column[]{features, offsetColumn(dataset)}
                : new Column[]{features};
        return dataset.withColumn(outputColumnName(), predictionUdf.apply(Predef$.MODULE$.wrapRefArray(udfArguments)));
    }

    /** Whether models of {@code category} take an offset argument in their prediction UDF. */
    private static boolean supportsOffset(ModelCategory category) {
        return category == ModelCategory.Binomial
                || category == ModelCategory.Regression
                || category == ModelCategory.Multinomial
                || category == ModelCategory.Ordinal;
    }

    /**
     * Builds the offset argument for the prediction UDF.
     *
     * <p>The column name is backtick-quoted, like the feature columns, so names containing dots
     * refer to the top-level column that {@link #containsColumn} checked, not a nested field.
     */
    private Column offsetColumn(Dataset<Row> dataset) {
        String offsetColumnName = getOffsetCol();
        if (offsetColumnName == null) {
            // No offset configured: the UDF still expects a second argument, so pass a literal 0.0.
            return functions$.MODULE$.lit(BoxesRunTime.boxToDouble(0.0d));
        }
        if (!containsColumn(dataset.columns(), offsetColumnName)) {
            throw new RuntimeException(MISSING_OFFSET_COLUMN_MESSAGE);
        }
        return functions$.MODULE$.col(quoteColumnName(offsetColumnName)).cast(DoubleType$.MODULE$);
    }

    /** Wraps a column name in backticks so Spark treats it as a single literal identifier. */
    private static String quoteColumnName(String columnName) {
        return "`" + columnName + "`";
    }

    /** Exact-match membership test of {@code columnName} among top-level {@code columnNames}. */
    private static boolean containsColumn(String[] columnNames, String columnName) {
        for (String name : columnNames) {
            if (columnName.equals(name)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Creates a model with the given identifier and runs the {@link H2OSupervisedMOJOParams}
     * trait initializer, which installs the {@code offsetCol} parameter.
     */
    public H2OSupervisedMOJOModel(String str) {
        super(str);
        H2OSupervisedMOJOParams.$init$((H2OSupervisedMOJOParams) this);
    }
}
