package org.apache.flink.ml.classification.logisticregression;

import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
import org.apache.flink.iteration.DataStreamList;
import org.apache.flink.iteration.IterationBody;
import org.apache.flink.iteration.IterationBodyResult;
import org.apache.flink.iteration.Iterations;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.Output;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

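/**
 * An {@link Estimator} that trains an {@link OnlineLogisticRegressionModel} on an unbounded stream,
 * starting from the initial model data set via {@link #setInitialModelData(Table)} and updating it
 * with FTRL-style steps once per global mini-batch (see {@code FtrlIterationBody}).
 *
 * <p>Source recovered by decompiling {@code
 * org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.class}.
 */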
public class OnlineLogisticRegression implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>, OnlineLogisticRegressionParams<OnlineLogisticRegression> {
    /** Parameter values of this estimator; filled with defaults by the constructor. */
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /** Model data the online training starts from; {@code null} until it is set. */
    private Table initModelDataTable;

    /* loaded from: input_file:org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression$CalculateLocalGradient.class */
    private static class CalculateLocalGradient extends AbstractStreamOperator<DenseVector[]> implements TwoInputStreamOperator<Row[], DenseVector, DenseVector[]> {
        private ListState<DenseVector> modelDataState;
        private ListState<Row[]> localBatchDataState;
        private double[] gradient;
        private double[] weightSum;

        private CalculateLocalGradient() {
        }

        public void initializeState(StateInitializationContext stateInitializationContext) throws Exception {
            super.initializeState(stateInitializationContext);
            this.modelDataState = stateInitializationContext.getOperatorStateStore().getListState(new ListStateDescriptor("modelData", DenseVector.class));
            this.localBatchDataState = stateInitializationContext.getOperatorStateStore().getListState(new ListStateDescriptor("localBatch", ObjectArrayTypeInfo.getInfoFor(TypeInformation.of(Row.class))));
        }

        public void processElement1(StreamRecord<Row[]> streamRecord) throws Exception {
            this.localBatchDataState.add(streamRecord.getValue());
            calculateGradient();
        }

        private void calculateGradient() throws Exception {
            if (((Iterable) this.modelDataState.get()).iterator().hasNext() && ((Iterable) this.localBatchDataState.get()).iterator().hasNext()) {
                DenseVector denseVector = (DenseVector) OperatorStateUtils.getUniqueElement(this.modelDataState, "modelData").get();
                this.modelDataState.clear();
                List list = IteratorUtils.toList(((Iterable) this.localBatchDataState.get()).iterator());
                Row[] rowArr = (Row[]) list.remove(0);
                this.localBatchDataState.update(list);
                for (Row row : rowArr) {
                    DenseVector denseVector2 = (Vector) row.getFieldAs(0);
                    double doubleValue = ((Double) row.getFieldAs(1)).doubleValue();
                    double doubleValue2 = row.getArity() == 2 ? 1.0d : ((Double) row.getFieldAs(2)).doubleValue();
                    if (this.gradient == null) {
                        this.gradient = new double[denseVector2.size()];
                        this.weightSum = new double[this.gradient.length];
                    }
                    double exp = 1.0d / (1.0d + Math.exp(-BLAS.dot(denseVector, denseVector2)));
                    if (denseVector2 instanceof DenseVector) {
                        DenseVector denseVector3 = denseVector2;
                        for (int i = 0; i < denseVector.size(); i++) {
                            double[] dArr = this.gradient;
                            int i2 = i;
                            dArr[i2] = dArr[i2] + ((exp - doubleValue) * denseVector3.values[i]);
                            double[] dArr2 = this.weightSum;
                            int i3 = i;
                            dArr2[i3] = dArr2[i3] + 1.0d;
                        }
                    } else {
                        SparseVector sparseVector = (SparseVector) denseVector2;
                        for (int i4 = 0; i4 < sparseVector.indices.length; i4++) {
                            int i5 = sparseVector.indices[i4];
                            double[] dArr3 = this.gradient;
                            dArr3[i5] = dArr3[i5] + ((exp - doubleValue) * sparseVector.values[i4]);
                            double[] dArr4 = this.weightSum;
                            dArr4[i5] = dArr4[i5] + doubleValue2;
                        }
                    }
                }
                if (rowArr.length > 0) {
                    Output output = this.output;
                    DenseVector[] denseVectorArr = new DenseVector[3];
                    denseVectorArr[0] = new DenseVector(this.gradient);
                    denseVectorArr[1] = new DenseVector(this.weightSum);
                    denseVectorArr[2] = getRuntimeContext().getIndexOfThisSubtask() == 0 ? denseVector : null;
                    output.collect(new StreamRecord(denseVectorArr));
                }
                Arrays.fill(this.gradient, 0.0d);
                Arrays.fill(this.weightSum, 0.0d);
            }
        }

        public void processElement2(StreamRecord<DenseVector> streamRecord) throws Exception {
            this.modelDataState.add(streamRecord.getValue());
            calculateGradient();
        }
    }

    /* loaded from: input_file:org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression$CreateLrModelData.class */
    private static class CreateLrModelData implements MapFunction<DenseVector, LogisticRegressionModelData>, CheckpointedFunction {
        private Long modelVersion;
        private transient ListState<Long> modelVersionState;

        private CreateLrModelData() {
            this.modelVersion = 1L;
        }

        public LogisticRegressionModelData map(DenseVector denseVector) throws Exception {
            Long l = this.modelVersion;
            this.modelVersion = Long.valueOf(this.modelVersion.longValue() + 1);
            return new LogisticRegressionModelData(denseVector, l.longValue());
        }

        public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
            this.modelVersionState.update(Collections.singletonList(this.modelVersion));
        }

        public void initializeState(FunctionInitializationContext functionInitializationContext) throws Exception {
            this.modelVersionState = functionInitializationContext.getOperatorStateStore().getListState(new ListStateDescriptor("modelVersionState", Long.class));
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression$FeaturesLabelExtractor.class */
    public static class FeaturesLabelExtractor implements MapFunction<Row, Row> {
        private final String featuresCol;
        private final String labelCol;
        private final String weightCol;

        private FeaturesLabelExtractor(String str, String str2, String str3) {
            this.featuresCol = str;
            this.labelCol = str2;
            this.weightCol = str3;
        }

        public Row map(Row row) throws Exception {
            return this.weightCol == null ? Row.of(new Object[]{row.getField(this.featuresCol), row.getField(this.labelCol)}) : Row.of(new Object[]{row.getField(this.featuresCol), row.getField(this.labelCol), row.getField(this.weightCol)});
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression$FtrlIterationBody.class */
    public static class FtrlIterationBody implements IterationBody {
        private final int batchSize;
        private final double alpha;
        private final double beta;
        private final double l1;
        private final double l2;

        public FtrlIterationBody(int i, double d, double d2, double d3, double d4) {
            this.batchSize = i;
            this.alpha = d;
            this.beta = d2;
            this.l1 = d4 * d3;
            this.l2 = (1.0d - d4) * d3;
        }

        public IterationBodyResult process(DataStreamList dataStreamList, DataStreamList dataStreamList2) {
            DataStream dataStream = dataStreamList.get(0);
            DataStream dataStream2 = dataStreamList2.get(0);
            int parallelism = dataStream2.getParallelism();
            Preconditions.checkState(parallelism <= this.batchSize, "There are more subtasks in the training process than the number of elements in each batch. Some subtasks might be idling forever.");
            DataStream parallelism2 = DataStreamUtils.generateBatchData(dataStream2, parallelism, this.batchSize).connect(dataStream.broadcast()).transform("LocalGradientCalculator", TypeInformation.of(DenseVector[].class), new CalculateLocalGradient()).setParallelism(parallelism).countWindowAll(parallelism).reduce((denseVectorArr, denseVectorArr2) -> {
                BLAS.axpy(1.0d, denseVectorArr[0], denseVectorArr2[0]);
                BLAS.axpy(1.0d, denseVectorArr[1], denseVectorArr2[1]);
                if (denseVectorArr2[2] == null) {
                    denseVectorArr2[2] = denseVectorArr[2];
                }
                return denseVectorArr2;
            }).transform("ModelDataUpdater", TypeInformation.of(DenseVector.class), new UpdateModel(this.alpha, this.beta, this.l1, this.l2)).setParallelism(1);
            return new IterationBodyResult(DataStreamList.of(new DataStream[]{parallelism2}), DataStreamList.of(new DataStream[]{parallelism2.map(new CreateLrModelData()).setParallelism(1)}));
        }

        private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
            String implMethodName = serializedLambda.getImplMethodName();
            boolean z = -1;
            switch (implMethodName.hashCode()) {
                case -550779690:
                    if (implMethodName.equals("lambda$process$545b1b82$1")) {
                        z = false;
                        break;
                    }
                    break;
            }
            switch (z) {
                case false:
                    if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/common/functions/ReduceFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("reduce") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression$FtrlIterationBody") && serializedLambda.getImplMethodSignature().equals("([Lorg/apache/flink/ml/linalg/DenseVector;[Lorg/apache/flink/ml/linalg/DenseVector;)[Lorg/apache/flink/ml/linalg/DenseVector;")) {
                        return (denseVectorArr, denseVectorArr2) -> {
                            BLAS.axpy(1.0d, denseVectorArr[0], denseVectorArr2[0]);
                            BLAS.axpy(1.0d, denseVectorArr[1], denseVectorArr2[1]);
                            if (denseVectorArr2[2] == null) {
                                denseVectorArr2[2] = denseVectorArr[2];
                            }
                            return denseVectorArr2;
                        };
                    }
                    break;
            }
            throw new IllegalArgumentException("Invalid lambda deserialization");
        }
    }

    /* loaded from: input_file:org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression$UpdateModel.class */
    private static class UpdateModel extends AbstractStreamOperator<DenseVector> implements OneInputStreamOperator<DenseVector[], DenseVector> {
        private ListState<double[]> nParamState;
        private ListState<double[]> zParamState;
        private final double alpha;
        private final double beta;
        private final double l1;
        private final double l2;
        private double[] nParam;
        private double[] zParam;

        public UpdateModel(double d, double d2, double d3, double d4) {
            this.alpha = d;
            this.beta = d2;
            this.l1 = d3;
            this.l2 = d4;
        }

        public void initializeState(StateInitializationContext stateInitializationContext) throws Exception {
            super.initializeState(stateInitializationContext);
            this.nParamState = stateInitializationContext.getOperatorStateStore().getListState(new ListStateDescriptor("nParamState", double[].class));
            this.zParamState = stateInitializationContext.getOperatorStateStore().getListState(new ListStateDescriptor("zParamState", double[].class));
        }

        public void processElement(StreamRecord<DenseVector[]> streamRecord) throws Exception {
            DenseVector[] denseVectorArr = (DenseVector[]) streamRecord.getValue();
            double[] dArr = denseVectorArr[2].values;
            double[] dArr2 = denseVectorArr[0].values;
            for (int i = 0; i < dArr2.length; i++) {
                if (denseVectorArr[1].values[i] != 0.0d) {
                    dArr2[i] = dArr2[i] / denseVectorArr[1].values[i];
                }
            }
            if (this.zParam == null) {
                this.zParam = new double[dArr2.length];
                this.nParam = new double[dArr2.length];
                this.nParamState.add(this.nParam);
                this.zParamState.add(this.zParam);
            }
            for (int i2 = 0; i2 < this.zParam.length; i2++) {
                double sqrt = (Math.sqrt(this.nParam[i2] + (dArr2[i2] * dArr2[i2])) - Math.sqrt(this.nParam[i2])) / this.alpha;
                double[] dArr3 = this.zParam;
                int i3 = i2;
                dArr3[i3] = dArr3[i3] + (dArr2[i2] - (sqrt * dArr[i2]));
                double[] dArr4 = this.nParam;
                int i4 = i2;
                dArr4[i4] = dArr4[i4] + (dArr2[i2] * dArr2[i2]);
                if (Math.abs(this.zParam[i2]) <= this.l1) {
                    dArr[i2] = 0.0d;
                } else {
                    dArr[i2] = (((this.zParam[i2] < 0.0d ? -1 : 1) * this.l1) - this.zParam[i2]) / (((this.beta + Math.sqrt(this.nParam[i2])) / this.alpha) + this.l2);
                }
            }
            this.output.collect(new StreamRecord(new DenseVector(dArr)));
        }
    }

    /** Creates an estimator whose parameter map is populated with the default values. */
    public OnlineLogisticRegression() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    /**
     * Trains an {@link OnlineLogisticRegressionModel} on the given unbounded training table.
     *
     * <p>Training starts from the model set via {@link #setInitialModelData(Table)}. The returned
     * model's model data is the stream of {@link LogisticRegressionModelData} produced by the
     * iteration, one entry per model update. This estimator's parameter values are copied onto the
     * model with {@code ReadWriteUtils.updateExistingParams}.
     *
     * @param inputs exactly one table holding the features, label and (optional) weight columns
     * @return the model bound to the stream of updated model data
     * @throws IllegalArgumentException if {@code inputs} does not contain exactly one table
     * @throws NullPointerException if no initial model data has been set
     */
    @Override
    public OnlineLogisticRegressionModel fit(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);
        Preconditions.checkNotNull(
                initModelDataTable,
                "Initial model data must be set via setInitialModelData() before calling fit().");
        StreamTableEnvironment tEnv = ((TableImpl) inputs[0]).getTableEnvironment();

        // The initial model is a single vector; it must enter the iteration from one subtask.
        DataStream<DenseVector> initModelData =
                LogisticRegressionModelData.getModelDataStream(initModelDataTable)
                        .map(modelData -> modelData.coefficient);
        initModelData.getTransformation().setParallelism(1);

        DataStream<Row> trainData =
                tEnv.toDataStream(inputs[0])
                        .map(
                                new FeaturesLabelExtractor(
                                        getFeaturesCol(), getLabelCol(), getWeightCol()));

        DataStream<LogisticRegressionModelData> modelDataStream =
                Iterations.iterateUnboundedStreams(
                                DataStreamList.of(initModelData),
                                DataStreamList.of(trainData),
                                new FtrlIterationBody(
                                        getGlobalBatchSize(),
                                        getAlpha().doubleValue(),
                                        getBeta().doubleValue(),
                                        getReg(),
                                        getElasticNet()))
                        .get(0);

        // NOTE(review): m12setModelData is presumably the decompiler's name for setModelData in
        // OnlineLogisticRegressionModel; switch to setModelData once that class is restored.
        OnlineLogisticRegressionModel model =
                new OnlineLogisticRegressionModel()
                        .m12setModelData(tEnv.fromDataStream(modelDataStream));
        ReadWriteUtils.updateExistingParams(model, paramMap);
        return model;
    }

    /**
     * Alias of {@link #fit(Table...)} under the decompiler-generated name, kept for existing callers.
     *
     * @deprecated use {@link #fit(Table...)}.
     */
    @Deprecated
    public OnlineLogisticRegressionModel m11fit(Table... tableArr) {
        return fit(tableArr);
    }

    /**
     * Saves this estimator's parameters and its initial model data under the given path.
     *
     * @param path directory to write to
     * @throws IOException if writing fails
     * @throws NullPointerException if no initial model data has been set
     */
    @Override
    public void save(String path) throws IOException {
        // Check before writing any metadata so a missing model does not leave partial output.
        Preconditions.checkNotNull(
                initModelDataTable,
                "Initial model data must be set via setInitialModelData() before calling save().");
        ReadWriteUtils.saveMetadata(this, path);
        ReadWriteUtils.saveModelData(
                LogisticRegressionModelData.getModelDataStream(initModelDataTable),
                path,
                new LogisticRegressionModelData.ModelDataEncoder());
    }

    /**
     * Loads an estimator previously written by {@code save}.
     *
     * @param tEnv table environment used to read back the initial model data
     * @param path directory the estimator was saved to
     * @return the restored estimator, with its initial model data set
     * @throws IOException if reading fails
     */
    public static OnlineLogisticRegression load(StreamTableEnvironment tEnv, String path)
            throws IOException {
        OnlineLogisticRegression estimator = ReadWriteUtils.loadStageParam(path);
        Table initModelData =
                ReadWriteUtils.loadModelData(
                        tEnv, path, new LogisticRegressionModelData.ModelDataDecoder());
        return estimator.setInitialModelData(initModelData);
    }

    /** Returns the live (mutable) map holding this estimator's parameter values. */
    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    /**
     * Sets the model data that online training starts from.
     *
     * @param initModelData table read with {@code LogisticRegressionModelData.getModelDataStream}
     * @return this estimator, for call chaining
     */
    public OnlineLogisticRegression setInitialModelData(Table initModelData) {
        initModelDataTable = initModelData;
        return this;
    }

    // NOTE: The decompiler-emitted $deserializeLambda$(SerializedLambda) method that stood here
    // was removed. javac synthesizes an identical private static method for every class that
    // contains serializable lambdas (such as the MapFunction lambda in fit). A hand-written copy
    // both clashes with that compiler-synthesized symbol and was not valid Java
    // ("boolean z = -1", "case false:").
}
