package org.apache.flink.ml.classification.logisticregression;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.servable.api.DataFrame;
import org.apache.flink.ml.servable.api.ModelServable;
import org.apache.flink.ml.servable.api.Row;
import org.apache.flink.ml.servable.types.BasicType;
import org.apache.flink.ml.servable.types.DataTypes;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ServableReadWriteUtils;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.class */
/**
 * Servable that applies a logistic regression model to the rows of a {@link DataFrame}.
 *
 * <p>For every input row the features vector is scored against the model coefficient; the
 * predicted label and a two-element raw prediction vector are appended to the data frame as new
 * columns.
 *
 * <p>Model data has to be supplied through the package-private constructor,
 * {@code m0setModelData} or {@link #load(String)} before a vector can be scored.
 */
public class LogisticRegressionModelServable implements ModelServable<LogisticRegressionModelServable>, LogisticRegressionModelParams<LogisticRegressionModelServable> {
    /** Parameter values of this servable, keyed by parameter definition. */
    private final Map<Param<?>, Object> paramMap;

    /** Model data used for scoring; {@code null} until it has been set or loaded. */
    private LogisticRegressionModelData modelData;

    /** Creates a servable without model data, with all parameters at their default values. */
    public LogisticRegressionModelServable() {
        this.paramMap = new HashMap<>();
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    /**
     * Creates a servable with default parameters that scores with the given model data.
     *
     * @param logisticRegressionModelData the model data to score with
     */
    LogisticRegressionModelServable(LogisticRegressionModelData logisticRegressionModelData) {
        this();
        this.modelData = logisticRegressionModelData;
    }

    /**
     * Scores every row of the given data frame and appends the results as two new columns: the
     * prediction column (double) and the raw prediction column (vector of doubles).
     *
     * <p>The given data frame is modified in place and returned.
     *
     * @param dataFrame the data frame whose features column holds the vectors to score
     * @return the same {@code dataFrame} instance, with the two result columns added
     */
    public DataFrame transform(DataFrame dataFrame) {
        int featuresIndex = dataFrame.getIndex(getFeaturesCol());
        List<Row> rows = dataFrame.collect();

        List<Double> predictions = new ArrayList<>(rows.size());
        List<DenseVector> rawPredictions = new ArrayList<>(rows.size());
        for (Row row : rows) {
            Tuple2<Double, DenseVector> scored = transform((Vector) row.get(featuresIndex));
            predictions.add(scored.f0);
            rawPredictions.add(scored.f1);
        }

        dataFrame.addColumn(getPredictionCol(), DataTypes.DOUBLE, predictions);
        dataFrame.addColumn(getRawPredictionCol(), DataTypes.VECTOR(BasicType.DOUBLE), rawPredictions);
        return dataFrame;
    }

    /**
     * Replaces the model data of this servable with the data decoded from the given stream.
     *
     * <p>The {@code m0} prefix is a decompiler rename of {@code setModelData} (the method was
     * merged with its bridge method); the name is kept so that existing callers still resolve.
     * The stream is not closed by this method.
     *
     * @param inputStreamArr exactly one stream containing the encoded model data
     * @return this servable
     * @throws IllegalArgumentException if the number of streams is not exactly one
     * @throws IOException if decoding the model data fails
     */
    /* renamed from: setModelData, reason: merged with bridge method [inline-methods] */
    public LogisticRegressionModelServable m0setModelData(InputStream... inputStreamArr) throws IOException {
        Preconditions.checkArgument(
                inputStreamArr.length == 1,
                "Expected exactly one model data input stream, but got %s.",
                inputStreamArr.length);
        this.modelData = LogisticRegressionModelData.decode(inputStreamArr[0]);
        return this;
    }

    /**
     * Loads a servable, i.e. its parameters and its model data, from the given path.
     *
     * @param str the path the servable was saved to
     * @return the loaded servable with its model data set
     * @throws IOException if reading the parameters or the model data fails
     */
    public static LogisticRegressionModelServable load(String str) throws IOException {
        LogisticRegressionModelServable servable =
                ServableReadWriteUtils.loadServableParam(str, LogisticRegressionModelServable.class);
        // try-with-resources closes the stream on both the normal and the exceptional path and,
        // like the previous hand-written handling, skips close() when the resource is null.
        try (InputStream modelDataStream = ServableReadWriteUtils.loadModelData(str)) {
            servable.m0setModelData(modelDataStream);
            return servable;
        }
    }

    /**
     * Scores a single features vector.
     *
     * @param vector the features vector
     * @return a tuple of the predicted label ({@code 1.0} when the dot product of the vector and
     *     the model coefficient is non-negative, {@code 0.0} otherwise) and the raw prediction
     *     vector {@code [1 - p, p]}, where {@code p} is the logistic function of that dot product
     * @throws NullPointerException if no model data has been set
     */
    protected Tuple2<Double, DenseVector> transform(Vector vector) {
        Preconditions.checkNotNull(
                this.modelData,
                "Model data has not been set; set or load the model data before calling transform.");
        double dot = BLAS.dot(vector, this.modelData.coefficient);
        // Logistic function written as 1 - 1 / (1 + e^dot).
        double positiveProbability = 1.0d - (1.0d / (1.0d + Math.exp(dot)));
        double label = dot >= 0.0d ? 1.0d : 0.0d;
        return Tuple2.of(label, Vectors.dense(1.0d - positiveProbability, positiveProbability));
    }

    /**
     * Returns the parameter map backing this servable.
     *
     * @return the live (not copied) parameter map
     */
    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }
}
