package org.apache.flink.ml.classification.logisticregression;

import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.HashMap;
import java.util.Map;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss;
import org.apache.flink.ml.common.optimizer.SGD;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/ml/classification/logisticregression/LogisticRegression.class */
/**
 * An {@link Estimator} that trains a binary logistic regression model.
 *
 * <p>Training minimizes {@link BinaryLogisticLoss} using {@link SGD}. Only the {@code auto} and
 * {@code binomial} multi-class options are supported, and every training label must be exactly
 * 0 or 1.
 */
public class LogisticRegression
        implements Estimator<LogisticRegression, LogisticRegressionModel>,
                LogisticRegressionParams<LogisticRegression> {

    /** Reported when the configuration or the training labels would require multinomial mode. */
    private static final String MULTINOMIAL_NOT_SUPPORTED =
            "Multinomial classification is not supported yet. Supported options: [auto, binomial].";

    /** Error raised when two training rows have feature vectors of different sizes. */
    private static final String INCONSISTENT_DIMENSIONS =
            "The training data should all have same dimensions.";

    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /** Creates an estimator whose parameters are all initialized to their default values. */
    public LogisticRegression() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    /**
     * Trains a binary logistic regression model on a single input table.
     *
     * <p>The column names and optimizer settings are read from this estimator when {@code fit}
     * is called. Changing parameters afterwards does not affect the returned model's training.
     *
     * @param inputs exactly one table containing the features, label and (optional) weight
     *     columns
     * @return a model whose model data is produced by the SGD optimization and which carries a
     *     copy of this estimator's parameter values
     * @throws IllegalArgumentException if not exactly one table is given, or if the multi-class
     *     option is neither {@code auto} nor {@code binomial}
     */
    @Override
    public LogisticRegressionModel fit(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);
        String multiClass = getMultiClass();
        Preconditions.checkArgument(
                "auto".equals(multiClass) || "binomial".equals(multiClass),
                MULTINOMIAL_NOT_SUPPORTED);

        // TableImpl#getTableEnvironment is declared to return a plain TableEnvironment; the
        // DataStream <-> Table conversions below need the streaming bridge interface.
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();

        DataStream<LabeledPointWithWeight> trainData =
                toLabeledPoints(tEnv, inputs[0], getFeaturesCol(), getLabelCol(), getWeightCol());

        SGD optimizer =
                new SGD(
                        getMaxIter(),
                        getLearningRate(),
                        getGlobalBatchSize(),
                        getTol(),
                        getReg(),
                        getElasticNet());
        DataStream<DenseVector> coefficients =
                optimizer.optimize(
                        initModelData(trainData), trainData, BinaryLogisticLoss.INSTANCE);

        DataStream<LogisticRegressionModelData> modelData =
                coefficients.map(
                        coefficient -> new LogisticRegressionModelData(coefficient, 0L));

        LogisticRegressionModel model =
                new LogisticRegressionModel().setModelData(tEnv.fromDataStream(modelData));
        ParamUtils.updateExistingParams(model, this.paramMap);
        return model;
    }

    /**
     * Converts the rows of {@code table} into weighted labeled points.
     *
     * <p>This is a static method, so the lambda captures only the column names and not the
     * estimator. A row's weight is 1.0 when {@code weightCol} is {@code null}.
     *
     * @throws RuntimeException (when the job runs) if a label is neither 0 nor 1
     */
    private static DataStream<LabeledPointWithWeight> toLabeledPoints(
            StreamTableEnvironment tEnv,
            Table table,
            String featuresCol,
            String labelCol,
            String weightCol) {
        return tEnv.toDataStream(table)
                .map(
                        row -> {
                            double weight =
                                    weightCol == null
                                            ? 1.0d
                                            : ((Number) row.getField(weightCol)).doubleValue();
                            double label = ((Number) row.getField(labelCol)).doubleValue();
                            boolean isBinomial =
                                    Double.compare(0.0d, label) == 0
                                            || Double.compare(1.0d, label) == 0;
                            if (!isBinomial) {
                                throw new RuntimeException(MULTINOMIAL_NOT_SUPPORTED);
                            }
                            DenseVector features =
                                    ((Vector) row.getField(featuresCol)).toDense();
                            return new LabeledPointWithWeight(features, label, weight);
                        });
    }

    /**
     * Builds the initial model for the optimizer. This is a single {@link DenseVector} whose
     * size equals the shared feature dimension.
     *
     * @throws IllegalStateException (when the job runs) if feature vectors differ in size
     */
    private static DataStream<DenseVector> initModelData(
            DataStream<LabeledPointWithWeight> trainData) {
        DataStream<Integer> dimensions = trainData.map(point -> point.getFeatures().size());
        DataStream<Integer> dimension =
                DataStreamUtils.reduce(
                        dimensions,
                        (dim1, dim2) -> {
                            Preconditions.checkState(dim1.equals(dim2), INCONSISTENT_DIMENSIONS);
                            return dim1;
                        });
        return dimension.map(dim -> new DenseVector(dim));
    }

    /**
     * Saves this estimator's parameter metadata under {@code path}.
     *
     * @throws IOException if writing the metadata fails
     */
    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    /**
     * Loads an estimator previously written by {@link #save(String)}.
     *
     * @param tEnv unused here; kept for a uniform {@code load} signature
     * @param path directory the estimator was saved to
     * @throws IOException if reading the metadata fails
     */
    public static LogisticRegression load(StreamTableEnvironment tEnv, String path)
            throws IOException {
        return (LogisticRegression) ReadWriteUtils.loadStageParam(path);
    }

    /** Returns the live, mutable parameter map backing this estimator. */
    @Override
    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }
}
