package org.apache.flink.ml.optimization;

import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.scala.DataSet;
import org.apache.flink.ml.common.WeightVector;
import org.apache.flink.ml.optimization.LearningRateMethod;
import org.apache.flink.ml.package$;
import scala.Serializable;
import scala.Tuple2;
import scala.reflect.ClassTag$;
import scala.runtime.AbstractFunction1;

/* compiled from: GradientDescent.scala */
/* loaded from: input_file:org/apache/flink/ml/optimization/GradientDescent$$anonfun$3.class */
/**
 * Closure class for one step of the gradient-descent loop, as emitted for an anonymous function in
 * {@code GradientDescent.scala} (this file appears to be decompiled bytecode, so changes belong in
 * the Scala source).
 *
 * <p>Given a data set of {@code (weights, loss)} pairs, {@link #apply} builds:
 * <ol>
 *   <li>updated weights via {@code SGDStep} on the enclosing {@link GradientDescent};</li>
 *   <li>the loss of those updated weights via {@code calculateLoss};</li>
 *   <li>a tuple of (updated weights paired with the new loss through a broadcast-variable map,
 *       the previous loss values filtered against the new loss through a broadcast-variable
 *       filter).</li>
 * </ol>
 *
 * <p>NOTE(review): the shape of the result suggests it is the step function of a bulk iteration
 * with a termination criterion, where the second element is the termination data set. Confirm
 * against the caller in GradientDescent.scala.
 *
 * <p>Field names are compiler-generated and are kept verbatim. The class is serializable, and the
 * nested closure classes receive {@code this} and may read these fields.
 */
public class GradientDescent$$anonfun$3 extends AbstractFunction1<DataSet<Tuple2<WeightVector, Object>>, Tuple2<DataSet<Tuple2<WeightVector, Object>>, DataSet<Object>>> implements Serializable {
    public static final long serialVersionUID = 0;

    // Enclosing optimizer instance; supplies SGDStep and calculateLoss.
    private final /* synthetic */ GradientDescent $outer;
    // Training data forwarded to both SGDStep and calculateLoss.
    private final DataSet dataPoints$1;
    private final RegularizationPenalty regularizationPenalty$2;
    private final double regularizationConstant$2;
    private final double learningRate$2;
    // Not read in apply(). Presumably public so the nested filter closure can read it; verify.
    public final double convergenceThreshold$1;
    private final LossFunction lossFunction$2;
    private final LearningRateMethod.LearningRateMethodTrait learningRateMethod$1;

    /**
     * Builds the data-flow for one optimization step.
     *
     * @param state the current {@code (weights, loss)} pairs
     * @return a pair of (updated weights paired with their loss, previous loss values that pass the
     *     broadcast-variable filter against the new loss)
     */
    public final Tuple2<DataSet<Tuple2<WeightVector, Object>>, DataSet<Object>> apply(DataSet<Tuple2<WeightVector, Object>> state) {
        // Project the WeightVector part of each pair (presumably the tuple's first component).
        DataSet<WeightVector> currentWeights = state.map(
                new GradientDescent$$anonfun$3$$anonfun$4(this),
                new GradientDescent$$anonfun$3$$anon$17(this),
                ClassTag$.MODULE$.apply(WeightVector.class));

        // Project a double out of each pair, presumably the previous loss. Left as a raw type
        // because the Scala Double element type erases to Object here.
        DataSet previousLoss = state.map(
                new GradientDescent$$anonfun$3$$anonfun$5(this),
                BasicTypeInfo.getInfoFor(Double.TYPE),
                ClassTag$.MODULE$.Double());

        GradientDescent optimizer = this.$outer;

        // One gradient step over the training data, starting from the current weights.
        DataSet<WeightVector> nextWeights = optimizer.org$apache$flink$ml$optimization$GradientDescent$$SGDStep(
                this.dataPoints$1,
                currentWeights,
                this.lossFunction$2,
                this.regularizationPenalty$2,
                this.regularizationConstant$2,
                this.learningRate$2,
                this.learningRateMethod$1);

        // Loss of the freshly updated weights; broadcast into both result data sets below.
        DataSet<Object> nextLoss = optimizer.org$apache$flink$ml$optimization$GradientDescent$$calculateLoss(
                this.dataPoints$1,
                nextWeights,
                this.lossFunction$2);

        // Kept as a single target-typed expression so generic inference matches the compiled form.
        return new Tuple2<>(
                package$.MODULE$.RichDataSet(nextWeights).mapWithBcVariable(
                        nextLoss,
                        new GradientDescent$$anonfun$3$$anonfun$apply$1(this),
                        new GradientDescent$$anonfun$3$$anon$18(this),
                        ClassTag$.MODULE$.apply(Tuple2.class)),
                package$.MODULE$.RichDataSet(previousLoss).filterWithBcVariable(
                        nextLoss,
                        new GradientDescent$$anonfun$3$$anonfun$1(this)));
    }

    /**
     * Captures the enclosing optimizer and the hyper-parameters used by {@link #apply}.
     *
     * @throws NullPointerException if {@code owner} is null
     */
    public GradientDescent$$anonfun$3(GradientDescent owner, DataSet trainingData, RegularizationPenalty penalty, double regularizationConstant, double learningRate, double convergenceThreshold, LossFunction loss, LearningRateMethod.LearningRateMethodTrait rateMethod) {
        if (owner != null) {
            this.$outer = owner;
        } else {
            throw new NullPointerException();
        }
        this.dataPoints$1 = trainingData;
        this.regularizationPenalty$2 = penalty;
        this.regularizationConstant$2 = regularizationConstant;
        this.learningRate$2 = learningRate;
        this.convergenceThreshold$1 = convergenceThreshold;
        this.lossFunction$2 = loss;
        this.learningRateMethod$1 = rateMethod;
    }
}
