package org.apache.flink.ml.common.optimizer;

import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.iteration.DataStreamList;
import org.apache.flink.iteration.IterationBody;
import org.apache.flink.iteration.IterationBodyResult;
import org.apache.flink.iteration.IterationConfig;
import org.apache.flink.iteration.IterationListener;
import org.apache.flink.iteration.Iterations;
import org.apache.flink.iteration.ReplayableDataStreamList;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
import org.apache.flink.ml.common.lossfunc.LossFunc;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.util.Collector;
import org.apache.flink.util.OutputTag;

/**
 * A mini-batch stochastic gradient descent (SGD) {@link Optimizer}.
 *
 * <p>Training runs as a bounded Flink iteration. In every round each parallel subtask computes the
 * summed gradient, summed weight and summed loss of its next local mini-batch. These partial
 * results are summed across subtasks with {@link DataStreamUtils#allReduceSum} and fed back to
 * update the model. Termination is decided by {@link TerminateOnMaxIterOrTol}, which receives the
 * configured {@code maxIter} / {@code tol} and the round's weight-normalized loss.
 */
@Internal
public class SGD implements Optimizer {
    /** Hyper-parameters shared by the iteration body and the training operator. */
    private final SGDParams params;

    /**
     * Creates an SGD optimizer.
     *
     * @param maxIter maximum number of rounds, forwarded to {@link TerminateOnMaxIterOrTol}.
     * @param learningRate step size applied to the weight-normalized gradient sum.
     * @param globalBatchSize total mini-batch size per round, split across parallel subtasks.
     * @param tol loss tolerance, forwarded to {@link TerminateOnMaxIterOrTol}.
     * @param reg regularization parameter, forwarded to {@code RegularizationUtils.regularize}.
     * @param elasticNet elastic-net parameter, forwarded to {@code RegularizationUtils.regularize}.
     */
    public SGD(
            int maxIter,
            double learningRate,
            int globalBatchSize,
            double tol,
            double reg,
            double elasticNet) {
        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
    }

    /**
     * Runs SGD starting from {@code initModelData} over {@code trainData}.
     *
     * @param initModelData stream holding the initial model; it is broadcast to every subtask.
     * @param trainData the weighted labeled training samples.
     * @param lossFunc loss whose value and gradient drive the updates.
     * @return a stream with the final model, emitted by subtask 0 when the iteration terminates.
     */
    @Override
    public DataStream<DenseVector> optimize(
            DataStream<DenseVector> initModelData,
            DataStream<LabeledPointWithWeight> trainData,
            LossFunc lossFunc) {
        // Every subtask needs the whole initial model; only its raw values travel through the
        // iteration as the variable stream.
        DataStream<double[]> initModelValues =
                initModelData.broadcast().map(modelVec -> modelVec.values);
        // NOTE(review): the identity map presumably gives the iteration its own input operator
        // after rebalance(); confirm before removing it.
        DataStream<LabeledPointWithWeight> balancedTrainData = trainData.rebalance().map(x -> x);

        DataStreamList resultList =
                Iterations.iterateBoundedStreamsUntilTermination(
                        DataStreamList.of(initModelValues),
                        ReplayableDataStreamList.notReplay(balancedTrainData),
                        IterationConfig.newBuilder().build(),
                        new TrainIterationBody(lossFunc, params));
        return resultList.get(0);
    }

    /** Builds one SGD round: local training, all-reduce of the feedback, termination check. */
    private static class TrainIterationBody implements IterationBody {
        private final LossFunc lossFunc;
        private final SGDParams params;

        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
            this.lossFunc = lossFunc;
            this.params = params;
        }

        @Override
        public IterationBodyResult process(
                DataStreamList variableStreams, DataStreamList dataStreams) {
            DataStream<double[]> variableStream = variableStreams.get(0);
            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
            // Anonymous subclass so the element type survives erasure for Flink's type extraction.
            final OutputTag<DenseVector> modelDataOutputTag =
                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};

            // getSideOutput is only available on SingleOutputStreamOperator, not on DataStream.
            SingleOutputStreamOperator<double[]> localFeedback =
                    trainData
                            .connect(variableStream)
                            .transform(
                                    "CacheDataAndDoTrain",
                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));

            // Sums the per-subtask feedback arrays in every round.
            DataStreamList reducedFeedbackList =
                    IterationBody.forEachRound(
                            DataStreamList.of(localFeedback),
                            input -> DataStreamList.of(DataStreamUtils.allReduceSum(input.get(0))));
            DataStream<double[]> reducedFeedback = reducedFeedbackList.get(0);

            // Round loss = total loss (last slot) / total weight (second-to-last slot).
            DataStream<Integer> terminationCriteria =
                    reducedFeedback
                            .map(feedback -> feedback[feedback.length - 1] / feedback[feedback.length - 2])
                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));

            return new IterationBodyResult(
                    DataStreamList.of(reducedFeedback),
                    DataStreamList.of(localFeedback.getSideOutput(modelDataOutputTag)),
                    terminationCriteria);
        }
    }

    /**
     * Caches this subtask's training data and, on each epoch, computes the gradient, weight and
     * loss sums of the next local mini-batch.
     *
     * <p>Input 1 carries training samples. Input 2 carries the initial model values on epoch 0 and
     * the all-reduced feedback of the previous round afterwards. The feedback array layout is
     * {@code [gradient sum (coefficientDim entries), total weight, total loss]}.
     */
    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
                    IterationListener<double[]> {
        private final SGDParams params;
        private final LossFunc lossFunc;
        private final OutputTag<DenseVector> modelDataOutputTag;

        /** This subtask's training data, loaded lazily from {@code trainDataState}. */
        private List<LabeledPointWithWeight> trainData;

        private ListState<LabeledPointWithWeight> trainDataState;
        /** Index into {@code trainData} where the next mini-batch starts. */
        private int nextBatchOffset = 0;

        private ListState<Integer> nextBatchOffsetState;
        private DenseVector coefficient;
        private ListState<DenseVector> coefficientState;
        private int coefficientDim;
        /** Last received feedback / buffer for outgoing feedback (layout in the class doc). */
        private double[] feedbackArray;

        private ListState<double[]> feedbackArrayState;
        /** Number of samples this subtask contributes to each global mini-batch. */
        private int localBatchSize;

        private CacheDataAndDoTrain(
                LossFunc lossFunc, SGDParams params, OutputTag<DenseVector> modelDataOutputTag) {
            this.lossFunc = lossFunc;
            this.params = params;
            this.modelDataOutputTag = modelDataOutputTag;
        }

        /**
         * Splits {@code globalBatchSize} evenly over the subtasks. The remainder goes one sample
         * each to the lowest-indexed subtasks.
         */
        @Override
        public void open() {
            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
            int taskId = getRuntimeContext().getIndexOfThisSubtask();
            localBatchSize = params.globalBatchSize / numTasks;
            if (params.globalBatchSize % numTasks > taskId) {
                localBatchSize++;
            }
        }

        private double getTotalWeight() {
            return feedbackArray[coefficientDim];
        }

        private void setTotalWeight(double totalWeight) {
            feedbackArray[coefficientDim] = totalWeight;
        }

        private double getTotalLoss() {
            return feedbackArray[coefficientDim + 1];
        }

        private void setTotalLoss(double totalLoss) {
            feedbackArray[coefficientDim + 1] = totalLoss;
        }

        /**
         * Applies {@code coefficient += (-learningRate / totalWeight) * gradientSum}. It then calls
         * {@code RegularizationUtils.regularize} on the coefficient and adds the returned value to
         * the total loss. Nothing happens when the total weight is not positive.
         */
        private void updateModel() {
            if (getTotalWeight() > 0) {
                BLAS.axpy(
                        -params.learningRate / getTotalWeight(),
                        new DenseVector(feedbackArray),
                        coefficient,
                        coefficientDim);
                double regLoss =
                        RegularizationUtils.regularize(
                                coefficient, params.reg, params.elasticNet, params.learningRate);
                setTotalLoss(getTotalLoss() + regLoss);
            }
        }

        @Override
        public void onEpochWatermarkIncremented(
                int epochWatermark,
                IterationListener.Context context,
                Collector<double[]> collector)
                throws Exception {
            if (epochWatermark == 0) {
                // On epoch 0 the received array is the initial model itself.
                coefficient = new DenseVector(feedbackArray);
                coefficientDim = coefficient.size();
                // Gradient slots, plus one slot for total weight and one for total loss.
                feedbackArray = new double[coefficientDim + 2];
            } else {
                // The received array is the all-reduced feedback of the previous round.
                updateModel();
            }

            if (trainData == null) {
                // IteratorUtils.toList returns a raw List; the state only holds these elements.
                @SuppressWarnings("unchecked")
                List<LabeledPointWithWeight> cached =
                        IteratorUtils.toList(trainDataState.get().iterator());
                trainData = cached;
            }

            if (trainData.isEmpty()) {
                return;
            }

            int batchEnd = Math.min(nextBatchOffset + localBatchSize, trainData.size());
            List<LabeledPointWithWeight> miniBatch = trainData.subList(nextBatchOffset, batchEnd);
            // Advance to the next mini-batch and wrap around at the end of the cached data.
            nextBatchOffset += localBatchSize;
            if (nextBatchOffset >= trainData.size()) {
                nextBatchOffset = 0;
            }

            Arrays.fill(feedbackArray, 0.0);
            double totalLoss = 0.0;
            double totalWeight = 0.0;
            // The vector is built over feedbackArray, so gradients accumulate into the array
            // that is emitted below.
            DenseVector cumGradients = new DenseVector(feedbackArray);
            for (LabeledPointWithWeight dataPoint : miniBatch) {
                totalLoss += lossFunc.computeLoss(dataPoint, coefficient);
                lossFunc.computeGradient(dataPoint, coefficient, cumGradients);
                totalWeight += dataPoint.getWeight();
            }
            setTotalLoss(totalLoss);
            setTotalWeight(totalWeight);

            collector.collect(feedbackArray);
        }

        @Override
        public void onIterationTerminated(
                IterationListener.Context context, Collector<double[]> collector) {
            trainDataState.clear();
            // Every subtask holds the same model; only subtask 0 applies the last feedback and
            // emits the result, so the model is output exactly once.
            if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                updateModel();
                context.output(modelDataOutputTag, coefficient);
            }
        }

        @Override
        public void processElement1(StreamRecord<LabeledPointWithWeight> streamRecord)
                throws Exception {
            trainDataState.add(streamRecord.getValue());
        }

        @Override
        public void processElement2(StreamRecord<double[]> streamRecord) {
            feedbackArray = streamRecord.getValue();
        }

        @Override
        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);

            coefficientState =
                    context.getOperatorStateStore()
                            .getListState(
                                    new ListStateDescriptor<>(
                                            "coefficientState", DenseVectorTypeInfo.INSTANCE));
            OperatorStateUtils.getUniqueElement(coefficientState, "coefficientState")
                    .ifPresent(restored -> coefficient = restored);
            if (coefficient != null) {
                coefficientDim = coefficient.size();
            }

            feedbackArrayState =
                    context.getOperatorStateStore()
                            .getListState(
                                    new ListStateDescriptor<>(
                                            "feedbackArrayState",
                                            PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO));
            OperatorStateUtils.getUniqueElement(feedbackArrayState, "feedbackArrayState")
                    .ifPresent(restored -> feedbackArray = restored);

            trainDataState =
                    context.getOperatorStateStore()
                            .getListState(
                                    new ListStateDescriptor<>(
                                            "trainDataState",
                                            TypeInformation.of(LabeledPointWithWeight.class)));

            nextBatchOffsetState =
                    context.getOperatorStateStore()
                            .getListState(
                                    new ListStateDescriptor<>(
                                            "nextBatchOffsetState", BasicTypeInfo.INT_TYPE_INFO));
            nextBatchOffset =
                    OperatorStateUtils.getUniqueElement(nextBatchOffsetState, "nextBatchOffsetState")
                            .orElse(0);
        }

        @Override
        public void snapshotState(StateSnapshotContext context) throws Exception {
            // Mirror initializeState: let the base operator take part in the snapshot too.
            super.snapshotState(context);

            coefficientState.clear();
            if (coefficient != null) {
                coefficientState.add(coefficient);
            }

            feedbackArrayState.clear();
            if (feedbackArray != null) {
                feedbackArrayState.add(feedbackArray);
            }

            nextBatchOffsetState.clear();
            nextBatchOffsetState.add(nextBatchOffset);
        }
    }

    /**
     * Immutable hyper-parameters of {@link SGD}.
     *
     * <p>Kept {@code public} for compatibility; instances can only be created by {@link SGD}.
     */
    public static class SGDParams implements Serializable {
        public final int maxIter;
        public final double learningRate;
        public final int globalBatchSize;
        public final double tol;
        public final double reg;
        public final double elasticNet;

        private SGDParams(
                int maxIter,
                double learningRate,
                int globalBatchSize,
                double tol,
                double reg,
                double elasticNet) {
            this.maxIter = maxIter;
            this.learningRate = learningRate;
            this.globalBatchSize = globalBatchSize;
            this.tol = tol;
            this.reg = reg;
            this.elasticNet = elasticNet;
        }
    }
}
