package org.apache.flink.ml.optimization;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.api.scala.DataSet;
import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo;
import org.apache.flink.api.scala.typeutils.ScalaCaseClassSerializer;
import org.apache.flink.ml.common.LabeledVector;
import org.apache.flink.ml.common.WeightVector;
import org.apache.flink.ml.math.BLAS$;
import org.apache.flink.ml.math.DenseVector;
import org.apache.flink.ml.math.SparseVector;
import org.apache.flink.ml.math.Vector;
import org.apache.flink.ml.optimization.LearningRateMethod;
import org.apache.flink.ml.package$;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.Seq$;
import scala.collection.immutable.$colon;
import scala.collection.immutable.Nil$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;

/* compiled from: GradientDescent.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005\u0015d\u0001B\u0001\u0003\u00015\u0011qb\u0012:bI&,g\u000e\u001e#fg\u000e,g\u000e\u001e\u0006\u0003\u0007\u0011\tAb\u001c9uS6L'0\u0019;j_:T!!\u0002\u0004\u0002\u00055d'BA\u0004\t\u0003\u00151G.\u001b8l\u0015\tI!\"\u0001\u0004ba\u0006\u001c\u0007.\u001a\u0006\u0002\u0017\u0005\u0019qN]4\u0004\u0001M\u0011\u0001A\u0004\t\u0003\u001fAi\u0011AA\u0005\u0003#\t\u0011q\"\u0013;fe\u0006$\u0018N^3T_24XM\u001d\u0005\u0006'\u0001!\t\u0001F\u0001\u0007y%t\u0017\u000e\u001e \u0015\u0003U\u0001\"a\u0004\u0001\t\u000b]\u0001A\u0011\t\r\u0002\u0011=\u0004H/[7ju\u0016$2!G\u0014.!\rQr$I\u0007\u00027)\u0011A$H\u0001\u0006g\u000e\fG.\u0019\u0006\u0003=\u0019\t1!\u00199j\u0013\t\u00013DA\u0004ECR\f7+\u001a;\u0011\u0005\t*S\"A\u0012\u000b\u0005\u0011\"\u0011AB2p[6|g.\u0003\u0002'G\taq+Z5hQR4Vm\u0019;pe\")\u0001F\u0006a\u0001S\u0005!A-\u0019;b!\rQrD\u000b\t\u0003E-J!\u0001L\u0012\u0003\u001b1\u000b'-\u001a7fIZ+7\r^8s\u0011\u0015qc\u00031\u00010\u00039Ig.\u001b;jC2<V-[4iiN\u00042\u0001\r\u001a\u001a\u001b\u0005\t$\"\u0001\u000f\n\u0005M\n$AB(qi&|g\u000eC\u00036\u0001\u0011\u0005a'\u0001\u0011paRLW.\u001b>f/&$\bnQ8om\u0016\u0014x-\u001a8dK\u000e\u0013\u0018\u000e^3sS>tGCC\r8sm\u0002UI\u0013'O'\")\u0001\b\u000ea\u0001S\u0005QA-\u0019;b!>Lg\u000e^:\t\u000bi\"\u0004\u0019A\r\u0002!%t\u0017\u000e^5bY^+\u0017n\u001a5ug\u0012\u001b\u0006\"\u0002\u001f5\u0001\u0004i\u0014A\u00058v[\n,'o\u00144Ji\u0016\u0014\u0018\r^5p]N\u0004\"\u0001\r 
\n\u0005}\n$aA%oi\")\u0011\t\u000ea\u0001\u0005\u0006)\"/Z4vY\u0006\u0014\u0018N_1uS>t\u0007+\u001a8bYRL\bCA\bD\u0013\t!%AA\u000bSK\u001e,H.\u0019:ju\u0006$\u0018n\u001c8QK:\fG\u000e^=\t\u000b\u0019#\u0004\u0019A$\u0002-I,w-\u001e7be&T\u0018\r^5p]\u000e{gn\u001d;b]R\u0004\"\u0001\r%\n\u0005%\u000b$A\u0002#pk\ndW\rC\u0003Li\u0001\u0007q)\u0001\u0007mK\u0006\u0014h.\u001b8h%\u0006$X\rC\u0003Ni\u0001\u0007q)\u0001\u000bd_:4XM]4f]\u000e,G\u000b\u001b:fg\"|G\u000e\u001a\u0005\u0006\u001fR\u0002\r\u0001U\u0001\rY>\u001c8OR;oGRLwN\u001c\t\u0003\u001fEK!A\u0015\u0002\u0003\u00191{7o\u001d$v]\u000e$\u0018n\u001c8\t\u000bQ#\u0004\u0019A+\u0002%1,\u0017M\u001d8j]\u001e\u0014\u0016\r^3NKRDw\u000e\u001a\t\u0003-\u001at!a\u00163\u000f\u0005a\u001bgBA-c\u001d\tQ\u0016M\u0004\u0002\\A:\u0011AlX\u0007\u0002;*\u0011a\fD\u0001\u0007yI|w\u000e\u001e \n\u0003-I!!\u0003\u0006\n\u0005\u001dA\u0011BA\u0003\u0007\u0013\t\u0019A!\u0003\u0002f\u0005\u0005\u0011B*Z1s]&twMU1uK6+G\u000f[8e\u0013\t9\u0007NA\fMK\u0006\u0014h.\u001b8h%\u0006$X-T3uQ>$GK]1ji*\u0011QM\u0001\u0005\u0006U\u0002!\ta[\u0001$_B$\u0018.\\5{K^KG\u000f[8vi\u000e{gN^3sO\u0016t7-Z\"sSR,'/[8o)%IB.\u001c8paF\u00148\u000fC\u0003)S\u0002\u0007\u0011\u0006C\u0003;S\u0002\u0007\u0011\u0004C\u0003=S\u0002\u0007Q\bC\u0003BS\u0002\u0007!\tC\u0003GS\u0002\u0007q\tC\u0003LS\u0002\u0007q\tC\u0003PS\u0002\u0007\u0001\u000bC\u0003uS\u0002\u0007Q+\u0001\npaRLW.\u001b>bi&|g.T3uQ>$\u0007\"\u0002<\u0001\t\u00139\u0018aB*H\tN#X\r\u001d\u000b\t3aL8\u0010`?\u007f\u007f\")\u0001&\u001ea\u0001S!)!0\u001ea\u00013\u0005q1-\u001e:sK:$x+Z5hQR\u001c\b\"B(v\u0001\u0004\u0001\u0006\"B!v\u0001\u0004\u0011\u0005\"\u0002$v\u0001\u00049\u0005\"B&v\u0001\u00049\u0005\"\u0002+v\u0001\u0004)\u0006bBA\u0002\u0001\u0011\u0005\u0011QA\u0001\ti\u0006\\Wm\u0015;faRa\u0011qAA\n\u0003/\tY\"!\b\u0002 
A!\u0011\u0011BA\b\u001b\t\tYAC\u0002\u0002\u000e\u0011\tA!\\1uQ&!\u0011\u0011CA\u0006\u0005\u00191Vm\u0019;pe\"A\u0011QCA\u0001\u0001\u0004\t9!\u0001\u0007xK&<\u0007\u000e\u001e,fGR|'\u000f\u0003\u0005\u0002\u001a\u0005\u0005\u0001\u0019AA\u0004\u0003!9'/\u00193jK:$\bBB!\u0002\u0002\u0001\u0007!\t\u0003\u0004G\u0003\u0003\u0001\ra\u0012\u0005\u0007\u0017\u0006\u0005\u0001\u0019A$\t\u000f\u0005\r\u0002\u0001\"\u0003\u0002&\u0005i1-\u00197dk2\fG/\u001a'pgN$\u0002\"a\n\u0002*\u0005-\u0012q\u0006\t\u00045}9\u0005B\u0002\u0015\u0002\"\u0001\u0007\u0011\u0006C\u0004\u0002.\u0005\u0005\u0002\u0019A\r\u0002\u0011],\u0017n\u001a5u\tNCaaTA\u0011\u0001\u0004\u0001vaBA\u001a\u0005!\u0005\u0011QG\u0001\u0010\u000fJ\fG-[3oi\u0012+7oY3oiB\u0019q\"a\u000e\u0007\r\u0005\u0011\u0001\u0012AA\u001d'\u0019\t9$a\u000f\u0002BA\u0019\u0001'!\u0010\n\u0007\u0005}\u0012G\u0001\u0004B]f\u0014VM\u001a\t\u0004a\u0005\r\u0013bAA#c\ta1+\u001a:jC2L'0\u00192mK\"91#a\u000e\u0005\u0002\u0005%CCAA\u001b\u0011\u001d\ti%a\u000e\u0005\u0002Q\tQ!\u00199qYfD!\"!\u0015\u00028\u0005\u0005I\u0011BA*\u0003-\u0011X-\u00193SKN|GN^3\u0015\u0005\u0005U\u0003\u0003BA,\u0003Cj!!!\u0017\u000b\t\u0005m\u0013QL\u0001\u0005Y\u0006twM\u0003\u0002\u0002`\u0005!!.\u0019<b\u0013\u0011\t\u0019'!\u0017\u0003\r=\u0013'.Z2u\u0001")
/* loaded from: input_file:org/apache/flink/ml/optimization/GradientDescent.class */
public class GradientDescent extends IterativeSolver {
    /**
     * Static forwarder to the factory method of the Scala companion object.
     *
     * @return the {@code GradientDescent} instance produced by {@code GradientDescent$.MODULE$.apply()}
     */
    public static GradientDescent apply() {
        final GradientDescent$ companion = GradientDescent$.MODULE$;
        return companion.apply();
    }

    /**
     * Reads the solver configuration from {@code parameters()} and runs gradient descent on the
     * given data.
     *
     * <p>When no convergence threshold is configured the fixed-iteration variant
     * ({@link #optimizeWithoutConvergenceCriterion}) is used; otherwise the configured threshold is
     * handed to {@link #optimizeWithConvergenceCriterion}.
     *
     * @param dataSet training data
     * @param option  optional initial weights, resolved through {@code createInitialWeightsDS}
     * @return the data set produced by the selected optimisation variant
     * @throws MatchError if the convergence-threshold lookup yields neither {@code None} nor a
     *                    {@code Some}
     */
    @Override // org.apache.flink.ml.optimization.Solver
    public DataSet<WeightVector> optimize(DataSet<LabeledVector> dataSet, Option<DataSet<WeightVector>> option) {
        int numberOfIterations = BoxesRunTime.unboxToInt(parameters().apply(IterativeSolver$Iterations$.MODULE$));
        // The lookup yields an Option (it is compared against None$ below), so it must not be
        // declared as Some; it is narrowed to Some only after the instanceof check.
        Option<?> convergenceThreshold = parameters().get(IterativeSolver$ConvergenceThreshold$.MODULE$);
        LossFunction lossFunction = (LossFunction) parameters().apply(Solver$LossFunction$.MODULE$);
        double learningRate = BoxesRunTime.unboxToDouble(parameters().apply(IterativeSolver$LearningRate$.MODULE$));
        RegularizationPenalty regularizationPenalty = (RegularizationPenalty) parameters().apply(Solver$RegularizationPenaltyValue$.MODULE$);
        double regularizationConstant = BoxesRunTime.unboxToDouble(parameters().apply(Solver$RegularizationConstant$.MODULE$));
        LearningRateMethod.LearningRateMethodTrait learningRateMethodTrait = (LearningRateMethod.LearningRateMethodTrait) parameters().apply(IterativeSolver$LearningRateMethodValue$.MODULE$);
        DataSet<WeightVector> initialWeightsDS = createInitialWeightsDS(option, dataSet);

        if (None$.MODULE$.equals(convergenceThreshold)) {
            return optimizeWithoutConvergenceCriterion(dataSet, initialWeightsDS, numberOfIterations, regularizationPenalty, regularizationConstant, learningRate, lossFunction, learningRateMethodTrait);
        }
        if (convergenceThreshold instanceof Some) {
            double threshold = BoxesRunTime.unboxToDouble(((Some<?>) convergenceThreshold).value());
            return optimizeWithConvergenceCriterion(dataSet, initialWeightsDS, numberOfIterations, regularizationPenalty, regularizationConstant, learningRate, threshold, lossFunction, learningRateMethodTrait);
        }
        // Mirrors the exhaustiveness fallback of the compiled Scala pattern match.
        throw new MatchError(convergenceThreshold);
    }

    /**
     * Runs gradient descent with a loss-based early-stopping criterion.
     *
     * <p>Outline of the visible data flow:
     * <ol>
     *   <li>The initial weights {@code dataSet2} are mapped, with the result of
     *       {@code calculateLoss(dataSet, dataSet2, lossFunction)} as broadcast variable, into a data
     *       set typed as {@code Tuple2<WeightVector, double>}. The mapping function itself
     *       ({@code $anonfun$optimizeWithConvergenceCriterion$1}) is not part of this view; it
     *       presumably pairs each weight vector with the loss - confirm against the Scala source.</li>
     *   <li>{@code iterateWithTermination(i, ...)} then repeats a step that splits the tuples into the
     *       weights ({@code _1}) and the previous loss ({@code _2}), calls {@code SGDStep}, recomputes
     *       the loss for the updated weights and returns a {@code Tuple2} of (a) the updated weights
     *       mapped together with the new loss and (b) a termination data set.</li>
     *   <li>The termination data set keeps the previous loss only while it is positive and
     *       {@code |(previous - current) / previous| >= d3}.</li>
     *   <li>After the iteration every tuple is projected back to its {@code WeightVector}.</li>
     * </ol>
     *
     * <p>NOTE(review): this is decompiler output. The constructors of the anonymous {@code Tuple2}
     * type-information classes could not be decompiled and unconditionally throw
     * {@code UnsupportedOperationException}, so this Java rendering cannot run as written.
     *
     * @param dataSet                 training data, passed to {@code calculateLoss} and {@code SGDStep}
     * @param dataSet2                initial weights
     * @param i                       iteration count handed to {@code iterateWithTermination}
     * @param regularizationPenalty   forwarded unchanged to {@code SGDStep}
     * @param d                       forwarded to {@code SGDStep}; {@code optimize} passes the
     *                                regularization constant here
     * @param d2                      forwarded to {@code SGDStep}; {@code optimize} passes the learning
     *                                rate here
     * @param d3                      lower bound on the relative loss change used by the termination
     *                                filter; {@code optimize} passes the convergence threshold here
     * @param lossFunction            loss function used by {@code calculateLoss} and {@code SGDStep}
     * @param learningRateMethodTrait forwarded unchanged to {@code SGDStep}
     * @return the weight vectors extracted from the result of the iteration
     */
    public DataSet<WeightVector> optimizeWithConvergenceCriterion(DataSet<LabeledVector> dataSet, DataSet<WeightVector> dataSet2, int i, RegularizationPenalty regularizationPenalty, double d, double d2, double d3, LossFunction lossFunction, LearningRateMethod.LearningRateMethodTrait learningRateMethodTrait) {
        // Decompiler artifact: a null placeholder passed as the "outer instance" argument of the
        // anonymous WeightVector type-info class created at the end of this method.
        final GradientDescent gradientDescent = null;
        // Step 1: build the initial (weights, loss) tuples, with the initial loss as broadcast variable.
        return package$.MODULE$.RichDataSet(dataSet2).mapWithBcVariable(calculateLoss(dataSet, dataSet2, lossFunction), (weightVector, obj) -> {
            return $anonfun$optimizeWithConvergenceCriterion$1(weightVector, BoxesRunTime.unboxToDouble(obj));
        }, new CaseClassTypeInfo<Tuple2<WeightVector, Object>>(this) { // from class: org.apache.flink.ml.optimization.GradientDescent$$anon$14
            public /* synthetic */ TypeInformation[] protected$types(GradientDescent$$anon$14 gradientDescent$$anon$14) {
                return gradientDescent$$anon$14.types;
            }

            public TypeSerializer<Tuple2<WeightVector, Object>> createSerializer(ExecutionConfig executionConfig) {
                // One serializer per field, created from the corresponding field type information.
                final TypeSerializer[] typeSerializerArr = new TypeSerializer[getArity()];
                RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), getArity()).foreach$mVc$sp(i2 -> {
                    typeSerializerArr[i2] = this.protected$types(this)[i2].createSerializer(executionConfig);
                });
                // NOTE(review): the anonymous serializer below (the only one overriding createInstance)
                // is instantiated and discarded, and a plain ScalaCaseClassSerializer is returned
                // instead. This looks like a decompilation artifact - confirm against the bytecode.
                new ScalaCaseClassSerializer<Tuple2<WeightVector, Object>>(this, typeSerializerArr) { // from class: org.apache.flink.ml.optimization.GradientDescent$$anon$14$$anon$3
                    /* renamed from: createInstance, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
                    public Tuple2<WeightVector, Object> m153createInstance(Object[] objArr) {
                        return new Tuple2<>((WeightVector) objArr[0], BoxesRunTime.boxToDouble(BoxesRunTime.unboxToDouble(objArr[1])));
                    }

                    {
                        Class typeClass = this.getTypeClass();
                    }
                };
                return new ScalaCaseClassSerializer(getTypeClass(), typeSerializerArr);
            }

            // The constructor was not decompiled; the register dump below is the raw listing and the
            // initializer always throws.
            /* JADX WARN: Illegal instructions before constructor call */
            {
                /*
                    r11 = this;
                    r0 = r11
                    java.lang.Class<scala.Tuple2> r1 = scala.Tuple2.class
                    scala.collection.immutable.$colon$colon r2 = new scala.collection.immutable.$colon$colon
                    r3 = r2
                    org.apache.flink.ml.optimization.GradientDescent$$anon$14$$anon$15 r4 = new org.apache.flink.ml.optimization.GradientDescent$$anon$14$$anon$15
                    r5 = r4
                    r6 = 0
                    r5.<init>(r6)
                    scala.collection.immutable.$colon$colon r5 = new scala.collection.immutable.$colon$colon
                    r6 = r5
                    java.lang.Class r7 = java.lang.Double.TYPE
                    org.apache.flink.api.common.typeinfo.BasicTypeInfo r7 = org.apache.flink.api.common.typeinfo.BasicTypeInfo.getInfoFor(r7)
                    scala.collection.immutable.Nil$ r8 = scala.collection.immutable.Nil$.MODULE$
                    r6.<init>(r7, r8)
                    r3.<init>(r4, r5)
                    scala.Predef$ r3 = scala.Predef$.MODULE$
                    scala.reflect.ClassTag$ r4 = scala.reflect.ClassTag$.MODULE$
                    java.lang.Class<org.apache.flink.api.common.typeinfo.TypeInformation> r5 = org.apache.flink.api.common.typeinfo.TypeInformation.class
                    scala.reflect.ClassTag r4 = r4.apply(r5)
                    java.lang.Object r3 = r3.implicitly(r4)
                    scala.reflect.ClassTag r3 = (scala.reflect.ClassTag) r3
                    java.lang.Object r2 = r2.toArray(r3)
                    org.apache.flink.api.common.typeinfo.TypeInformation[] r2 = (org.apache.flink.api.common.typeinfo.TypeInformation[]) r2
                    scala.collection.immutable.$colon$colon r3 = new scala.collection.immutable.$colon$colon
                    r4 = r3
                    org.apache.flink.ml.optimization.GradientDescent$$anon$14$$anon$16 r5 = new org.apache.flink.ml.optimization.GradientDescent$$anon$14$$anon$16
                    r6 = r5
                    r7 = 0
                    r6.<init>(r7)
                    scala.collection.immutable.$colon$colon r6 = new scala.collection.immutable.$colon$colon
                    r7 = r6
                    java.lang.Class r8 = java.lang.Double.TYPE
                    org.apache.flink.api.common.typeinfo.BasicTypeInfo r8 = org.apache.flink.api.common.typeinfo.BasicTypeInfo.getInfoFor(r8)
                    scala.collection.immutable.Nil$ r9 = scala.collection.immutable.Nil$.MODULE$
                    r7.<init>(r8, r9)
                    r4.<init>(r5, r6)
                    scala.collection.Seq$ r4 = scala.collection.Seq$.MODULE$
                    scala.Predef$ r5 = scala.Predef$.MODULE$
                    r6 = 2
                    java.lang.String[] r6 = new java.lang.String[r6]
                    r7 = r6
                    r8 = 0
                    java.lang.String r9 = "_1"
                    r7[r8] = r9
                    r7 = r6
                    r8 = 1
                    java.lang.String r9 = "_2"
                    r7[r8] = r9
                    java.lang.Object[] r6 = (java.lang.Object[]) r6
                    scala.collection.mutable.WrappedArray r5 = r5.wrapRefArray(r6)
                    scala.collection.GenTraversable r4 = r4.apply(r5)
                    scala.collection.Seq r4 = (scala.collection.Seq) r4
                    r0.<init>(r1, r2, r3, r4)
                    return
                */
                throw new UnsupportedOperationException("Method not decompiled: org.apache.flink.ml.optimization.GradientDescent$$anon$14.<init>(org.apache.flink.ml.optimization.GradientDescent):void");
            }
        }, ClassTag$.MODULE$.apply(Tuple2.class)).iterateWithTermination(i, dataSet3 -> {
            // Step 2: one iteration; dataSet3 holds the (weights, loss) tuples of the previous pass.
            // Decompiler artifact: null placeholder for the outer-instance constructor argument.
            final GradientDescent gradientDescent2 = null;
            // Current weights: first tuple component.
            DataSet<WeightVector> map = dataSet3.map(tuple2 -> {
                return (WeightVector) tuple2._1();
            }, new CaseClassTypeInfo<WeightVector>(gradientDescent2) { // from class: org.apache.flink.ml.optimization.GradientDescent$$anon$17
                public /* synthetic */ TypeInformation[] protected$types(GradientDescent$$anon$17 gradientDescent$$anon$17) {
                    return gradientDescent$$anon$17.types;
                }

                public TypeSerializer<WeightVector> createSerializer(ExecutionConfig executionConfig) {
                    final TypeSerializer[] typeSerializerArr = new TypeSerializer[getArity()];
                    RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), getArity()).foreach$mVc$sp(i2 -> {
                        typeSerializerArr[i2] = this.protected$types(this)[i2].createSerializer(executionConfig);
                    });
                    // NOTE(review): created and discarded, as in the other type-info classes above.
                    new ScalaCaseClassSerializer<WeightVector>(this, typeSerializerArr) { // from class: org.apache.flink.ml.optimization.GradientDescent$$anon$17$$anon$4
                        /* renamed from: createInstance, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
                        public WeightVector m155createInstance(Object[] objArr) {
                            return new WeightVector((Vector) objArr[0], BoxesRunTime.unboxToDouble(objArr[1]));
                        }

                        {
                            Class typeClass = this.getTypeClass();
                        }
                    };
                    return new ScalaCaseClassSerializer(getTypeClass(), typeSerializerArr);
                }

                {
                    super(WeightVector.class, (TypeInformation[]) Nil$.MODULE$.toArray((ClassTag) Predef$.MODULE$.implicitly(ClassTag$.MODULE$.apply(TypeInformation.class))), new $colon.colon(TypeExtractor.createTypeInfo(Vector.class), new $colon.colon(BasicTypeInfo.getInfoFor(Double.TYPE), Nil$.MODULE$)), Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new String[]{"weights", "intercept"})));
                }
            }, ClassTag$.MODULE$.apply(WeightVector.class));
            // Previous loss: second tuple component.
            DataSet map2 = dataSet3.map(tuple22 -> {
                return BoxesRunTime.boxToDouble(tuple22._2$mcD$sp());
            }, BasicTypeInfo.getInfoFor(Double.TYPE), ClassTag$.MODULE$.Double());
            // Updated weights after one SGD step, and the loss evaluated for those updated weights.
            DataSet<WeightVector> SGDStep = this.SGDStep(dataSet, map, lossFunction, regularizationPenalty, d, d2, learningRateMethodTrait);
            DataSet<Object> calculateLoss = this.calculateLoss(dataSet, SGDStep, lossFunction);
            // Result: (next (weights, loss) tuples, termination data set).
            return new Tuple2(package$.MODULE$.RichDataSet(SGDStep).mapWithBcVariable(calculateLoss, (weightVector2, obj2) -> {
                return $anonfun$optimizeWithConvergenceCriterion$6(weightVector2, BoxesRunTime.unboxToDouble(obj2));
            }, new CaseClassTypeInfo<Tuple2<WeightVector, Object>>(this) { // from class: org.apache.flink.ml.optimization.GradientDescent$$anon$18
                public /* synthetic */ TypeInformation[] protected$types(GradientDescent$$anon$18 gradientDescent$$anon$18) {
                    return gradientDescent$$anon$18.types;
                }

                public TypeSerializer<Tuple2<WeightVector, Object>> createSerializer(ExecutionConfig executionConfig) {
                    final TypeSerializer[] typeSerializerArr = new TypeSerializer[getArity()];
                    RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), getArity()).foreach$mVc$sp(i2 -> {
                        typeSerializerArr[i2] = this.protected$types(this)[i2].createSerializer(executionConfig);
                    });
                    new ScalaCaseClassSerializer<Tuple2<WeightVector, Object>>(this, typeSerializerArr) { // from class: org.apache.flink.ml.optimization.GradientDescent$$anon$18$$anon$7
                        /* renamed from: createInstance, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
                        public Tuple2<WeightVector, Object> m161createInstance(Object[] objArr) {
                            return new Tuple2<>((WeightVector) objArr[0], BoxesRunTime.boxToDouble(BoxesRunTime.unboxToDouble(objArr[1])));
                        }

                        {
                            Class typeClass = this.getTypeClass();
                        }
                    };
                    return new ScalaCaseClassSerializer(getTypeClass(), typeSerializerArr);
                }

                // Constructor not decompiled; the initializer always throws (see register dump).
                /* JADX WARN: Illegal instructions before constructor call */
                {
                    /*
                        r11 = this;
                        r0 = r11
                        java.lang.Class<scala.Tuple2> r1 = scala.Tuple2.class
                        scala.collection.immutable.$colon$colon r2 = new scala.collection.immutable.$colon$colon
                        r3 = r2
                        org.apache.flink.ml.optimization.GradientDescent$$anon$18$$anon$19 r4 = new org.apache.flink.ml.optimization.GradientDescent$$anon$18$$anon$19
                        r5 = r4
                        r6 = 0
                        r5.<init>(r6)
                        scala.collection.immutable.$colon$colon r5 = new scala.collection.immutable.$colon$colon
                        r6 = r5
                        java.lang.Class r7 = java.lang.Double.TYPE
                        org.apache.flink.api.common.typeinfo.BasicTypeInfo r7 = org.apache.flink.api.common.typeinfo.BasicTypeInfo.getInfoFor(r7)
                        scala.collection.immutable.Nil$ r8 = scala.collection.immutable.Nil$.MODULE$
                        r6.<init>(r7, r8)
                        r3.<init>(r4, r5)
                        scala.Predef$ r3 = scala.Predef$.MODULE$
                        scala.reflect.ClassTag$ r4 = scala.reflect.ClassTag$.MODULE$
                        java.lang.Class<org.apache.flink.api.common.typeinfo.TypeInformation> r5 = org.apache.flink.api.common.typeinfo.TypeInformation.class
                        scala.reflect.ClassTag r4 = r4.apply(r5)
                        java.lang.Object r3 = r3.implicitly(r4)
                        scala.reflect.ClassTag r3 = (scala.reflect.ClassTag) r3
                        java.lang.Object r2 = r2.toArray(r3)
                        org.apache.flink.api.common.typeinfo.TypeInformation[] r2 = (org.apache.flink.api.common.typeinfo.TypeInformation[]) r2
                        scala.collection.immutable.$colon$colon r3 = new scala.collection.immutable.$colon$colon
                        r4 = r3
                        org.apache.flink.ml.optimization.GradientDescent$$anon$18$$anon$20 r5 = new org.apache.flink.ml.optimization.GradientDescent$$anon$18$$anon$20
                        r6 = r5
                        r7 = 0
                        r6.<init>(r7)
                        scala.collection.immutable.$colon$colon r6 = new scala.collection.immutable.$colon$colon
                        r7 = r6
                        java.lang.Class r8 = java.lang.Double.TYPE
                        org.apache.flink.api.common.typeinfo.BasicTypeInfo r8 = org.apache.flink.api.common.typeinfo.BasicTypeInfo.getInfoFor(r8)
                        scala.collection.immutable.Nil$ r9 = scala.collection.immutable.Nil$.MODULE$
                        r7.<init>(r8, r9)
                        r4.<init>(r5, r6)
                        scala.collection.Seq$ r4 = scala.collection.Seq$.MODULE$
                        scala.Predef$ r5 = scala.Predef$.MODULE$
                        r6 = 2
                        java.lang.String[] r6 = new java.lang.String[r6]
                        r7 = r6
                        r8 = 0
                        java.lang.String r9 = "_1"
                        r7[r8] = r9
                        r7 = r6
                        r8 = 1
                        java.lang.String r9 = "_2"
                        r7[r8] = r9
                        java.lang.Object[] r6 = (java.lang.Object[]) r6
                        scala.collection.mutable.WrappedArray r5 = r5.wrapRefArray(r6)
                        scala.collection.GenTraversable r4 = r4.apply(r5)
                        scala.collection.Seq r4 = (scala.collection.Seq) r4
                        r0.<init>(r1, r2, r3, r4)
                        return
                    */
                    throw new UnsupportedOperationException("Method not decompiled: org.apache.flink.ml.optimization.GradientDescent$$anon$18.<init>(org.apache.flink.ml.optimization.GradientDescent):void");
                }
            }, ClassTag$.MODULE$.apply(Tuple2.class)), package$.MODULE$.RichDataSet(map2).filterWithBcVariable(calculateLoss, (d4, d5) -> {
                // d4 = previous loss (element of map2), d5 = loss after this step (broadcast value).
                // The element is kept only while the relative change is still at least d3.
                // NOTE(review): presumably an empty termination data set ends the iteration, per
                // Flink's iterateWithTermination contract - confirm.
                return d4 > ((double) 0) && scala.math.package$.MODULE$.abs((d4 - d5) / d4) >= d3;
            }));
        }).map(tuple2 -> {
            // Step 4: drop the loss component and keep only the weights.
            return (WeightVector) tuple2._1();
        }, new CaseClassTypeInfo<WeightVector>(gradientDescent) { // from class: org.apache.flink.ml.optimization.GradientDescent$$anon$21
            public /* synthetic */ TypeInformation[] protected$types(GradientDescent$$anon$21 gradientDescent$$anon$21) {
                return gradientDescent$$anon$21.types;
            }

            public TypeSerializer<WeightVector> createSerializer(ExecutionConfig executionConfig) {
                final TypeSerializer[] typeSerializerArr = new TypeSerializer[getArity()];
                RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), getArity()).foreach$mVc$sp(i2 -> {
                    typeSerializerArr[i2] = this.protected$types(this)[i2].createSerializer(executionConfig);
                });
                new ScalaCaseClassSerializer<WeightVector>(this, typeSerializerArr) { // from class: org.apache.flink.ml.optimization.GradientDescent$$anon$21$$anon$8
                    /* renamed from: createInstance, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
                    public WeightVector m163createInstance(Object[] objArr) {
                        return new WeightVector((Vector) objArr[0], BoxesRunTime.unboxToDouble(objArr[1]));
                    }

                    {
                        Class typeClass = this.getTypeClass();
                    }
                };
                return new ScalaCaseClassSerializer(getTypeClass(), typeSerializerArr);
            }

            {
                super(WeightVector.class, (TypeInformation[]) Nil$.MODULE$.toArray((ClassTag) Predef$.MODULE$.implicitly(ClassTag$.MODULE$.apply(TypeInformation.class))), new $colon.colon(TypeExtractor.createTypeInfo(Vector.class), new $colon.colon(BasicTypeInfo.getInfoFor(Double.TYPE), Nil$.MODULE$)), Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new String[]{"weights", "intercept"})));
            }
        }, ClassTag$.MODULE$.apply(WeightVector.class));
    }

    /**
     * Runs gradient descent for a fixed number of iterations, without any early-stopping check.
     *
     * <p>Every iteration applies {@code SGDStep} to the weights produced by the previous one.
     *
     * @param dataSet                 training data handed to every {@code SGDStep}
     * @param dataSet2                initial weights the iteration starts from
     * @param i                       iteration count passed to {@code iterate}
     * @param regularizationPenalty   forwarded unchanged to {@code SGDStep}
     * @param d                       forwarded to {@code SGDStep}; {@code optimize} passes the
     *                                regularization constant here
     * @param d2                      forwarded to {@code SGDStep}; {@code optimize} passes the learning
     *                                rate here
     * @param lossFunction            forwarded unchanged to {@code SGDStep}
     * @param learningRateMethodTrait forwarded unchanged to {@code SGDStep}
     * @return the data set returned by {@code iterate}
     */
    public DataSet<WeightVector> optimizeWithoutConvergenceCriterion(DataSet<LabeledVector> dataSet, DataSet<WeightVector> dataSet2, int i, RegularizationPenalty regularizationPenalty, double d, double d2, LossFunction lossFunction, LearningRateMethod.LearningRateMethodTrait learningRateMethodTrait) {
        return dataSet2.iterate(
                i,
                currentWeights -> SGDStep(dataSet, currentWeights, lossFunction, regularizationPenalty, d, d2, learningRateMethodTrait));
    }

    /* JADX INFO: Access modifiers changed from: private */
    public DataSet<WeightVector> SGDStep(DataSet<LabeledVector> dataSet, DataSet<WeightVector> dataSet2, LossFunction lossFunction, RegularizationPenalty regularizationPenalty, double d, double d2, LearningRateMethod.LearningRateMethodTrait learningRateMethodTrait) {
        final GradientDescent gradientDescent = null;
        return package$.MODULE$.RichDataSet(package$.MODULE$.RichDataSet(dataSet).mapWithBcVariable(dataSet2, (labeledVector, weightVector) -> {
            return new Tuple2(lossFunction.gradient(labeledVector, weightVector), BoxesRunTime.boxToInteger(1));
        }, new CaseClassTypeInfo<Tuple2<WeightVector, Object>>(this) { // from class: org.apache.flink.ml.optimization.GradientDescent$$anon$22
            public /* synthetic */ TypeInformation[] protected$types(GradientDescent$$anon$22 gradientDescent$$anon$22) {
                return gradientDescent$$anon$22.types;
            }

            public TypeSerializer<Tuple2<WeightVector, Object>> createSerializer(ExecutionConfig executionConfig) {
                final TypeSerializer[] typeSerializerArr = new TypeSerializer[getArity()];
                RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), getArity()).foreach$mVc$sp(i -> {
                    typeSerializerArr[i] = this.protected$types(this)[i].createSerializer(executionConfig);
                });
                new ScalaCaseClassSerializer<Tuple2<WeightVector, Object>>(this, typeSerializerArr) { // from class: org.apache.flink.ml.optimization.GradientDescent$$anon$22$$anon$11
                    /* renamed from: createInstance, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
                    public Tuple2<WeightVector, Object> m169createInstance(Object[] objArr) {
                        return new Tuple2<>((WeightVector) objArr[0], BoxesRunTime.boxToInteger(BoxesRunTime.unboxToInt(objArr[1])));
                    }

                    {
                        Class typeClass = this.getTypeClass();
                    }
                };
                return new ScalaCaseClassSerializer(getTypeClass(), typeSerializerArr);
            }

            /* JADX WARN: Illegal instructions before constructor call */
            {
                /*
                    r11 = this;
                    r0 = r11
                    java.lang.Class<scala.Tuple2> r1 = scala.Tuple2.class
                    scala.collection.immutable.$colon$colon r2 = new scala.collection.immutable.$colon$colon
                    r3 = r2
                    org.apache.flink.ml.optimization.GradientDescent$$anon$22$$anon$23 r4 = new org.apache.flink.ml.optimization.GradientDescent$$anon$22$$anon$23
                    r5 = r4
                    r6 = 0
                    r5.<init>(r6)
                    scala.collection.immutable.$colon$colon r5 = new scala.collection.immutable.$colon$colon
                    r6 = r5
                    java.lang.Class r7 = java.lang.Integer.TYPE
                    org.apache.flink.api.common.typeinfo.BasicTypeInfo r7 = org.apache.flink.api.common.typeinfo.BasicTypeInfo.getInfoFor(r7)
                    scala.collection.immutable.Nil$ r8 = scala.collection.immutable.Nil$.MODULE$
                    r6.<init>(r7, r8)
                    r3.<init>(r4, r5)
                    scala.Predef$ r3 = scala.Predef$.MODULE$
                    scala.reflect.ClassTag$ r4 = scala.reflect.ClassTag$.MODULE$
                    java.lang.Class<org.apache.flink.api.common.typeinfo.TypeInformation> r5 = org.apache.flink.api.common.typeinfo.TypeInformation.class
                    scala.reflect.ClassTag r4 = r4.apply(r5)
                    java.lang.Object r3 = r3.implicitly(r4)
                    scala.reflect.ClassTag r3 = (scala.reflect.ClassTag) r3
                    java.lang.Object r2 = r2.toArray(r3)
                    org.apache.flink.api.common.typeinfo.TypeInformation[] r2 = (org.apache.flink.api.common.typeinfo.TypeInformation[]) r2
                    scala.collection.immutable.$colon$colon r3 = new scala.collection.immutable.$colon$colon
                    r4 = r3
                    org.apache.flink.ml.optimization.GradientDescent$$anon$22$$anon$24 r5 = new org.apache.flink.ml.optimization.GradientDescent$$anon$22$$anon$24
                    r6 = r5
                    r7 = 0
                    r6.<init>(r7)
                    scala.collection.immutable.$colon$colon r6 = new scala.collection.immutable.$colon$colon
                    r7 = r6
                    java.lang.Class r8 = java.lang.Integer.TYPE
                    org.apache.flink.api.common.typeinfo.BasicTypeInfo r8 = org.apache.flink.api.common.typeinfo.BasicTypeInfo.getInfoFor(r8)
                    scala.collection.immutable.Nil$ r9 = scala.collection.immutable.Nil$.MODULE$
                    r7.<init>(r8, r9)
                    r4.<init>(r5, r6)
                    scala.collection.Seq$ r4 = scala.collection.Seq$.MODULE$
                    scala.Predef$ r5 = scala.Predef$.MODULE$
                    r6 = 2
                    java.lang.String[] r6 = new java.lang.String[r6]
                    r7 = r6
                    r8 = 0
                    java.lang.String r9 = "_1"
                    r7[r8] = r9
                    r7 = r6
                    r8 = 1
                    java.lang.String r9 = "_2"
                    r7[r8] = r9
                    java.lang.Object[] r6 = (java.lang.Object[]) r6
                    scala.collection.mutable.WrappedArray r5 = r5.wrapRefArray(r6)
                    scala.collection.GenTraversable r4 = r4.apply(r5)
                    scala.collection.Seq r4 = (scala.collection.Seq) r4
                    r0.<init>(r1, r2, r3, r4)
                    return
                */
                throw new UnsupportedOperationException("Method not decompiled: org.apache.flink.ml.optimization.GradientDescent$$anon$22.<init>(org.apache.flink.ml.optimization.GradientDescent):void");
            }
        }, ClassTag$.MODULE$.apply(Tuple2.class)).reduce((tuple2, tuple22) -> {
            DenseVector denseVector;
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            Tuple2 tuple2 = new Tuple2((WeightVector) tuple2._1(), BoxesRunTime.boxToInteger(tuple2._2$mcI$sp()));
            WeightVector weightVector2 = (WeightVector) tuple2._1();
            int _2$mcI$sp = tuple2._2$mcI$sp();
            if (tuple22 == null) {
                throw new MatchError(tuple22);
            }
            Tuple2 tuple22 = new Tuple2((WeightVector) tuple22._1(), BoxesRunTime.boxToInteger(tuple22._2$mcI$sp()));
            WeightVector weightVector3 = (WeightVector) tuple22._1();
            int _2$mcI$sp2 = tuple22._2$mcI$sp();
            Vector weights = weightVector2.weights();
            if (weights instanceof DenseVector) {
                denseVector = (DenseVector) weights;
            } else {
                if (!(weights instanceof SparseVector)) {
                    throw new MatchError(weights);
                }
                denseVector = ((SparseVector) weights).toDenseVector();
            }
            DenseVector denseVector2 = denseVector;
            BLAS$.MODULE$.axpy(1.0d, weightVector3.weights(), denseVector2);
            return new Tuple2(new WeightVector(denseVector2, weightVector2.intercept() + weightVector3.intercept()), BoxesRunTime.boxToInteger(_2$mcI$sp + _2$mcI$sp2));
        })).mapWithBcVariableIteration(dataSet2, (tuple23, weightVector2, obj) -> {
            return $anonfun$SGDStep$3(this, regularizationPenalty, d, d2, learningRateMethodTrait, tuple23, weightVector2, BoxesRunTime.unboxToInt(obj));
        }, new CaseClassTypeInfo<WeightVector>(gradientDescent) { // from class: org.apache.flink.ml.optimization.GradientDescent$$anon$25
            public /* synthetic */ TypeInformation[] protected$types(GradientDescent$$anon$25 gradientDescent$$anon$25) {
                return gradientDescent$$anon$25.types;
            }

            public TypeSerializer<WeightVector> createSerializer(ExecutionConfig executionConfig) {
                final TypeSerializer[] typeSerializerArr = new TypeSerializer[getArity()];
                RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), getArity()).foreach$mVc$sp(i -> {
                    typeSerializerArr[i] = this.protected$types(this)[i].createSerializer(executionConfig);
                });
                new ScalaCaseClassSerializer<WeightVector>(this, typeSerializerArr) { // from class: org.apache.flink.ml.optimization.GradientDescent$$anon$25$$anon$12
                    /* renamed from: createInstance, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
                    public WeightVector m171createInstance(Object[] objArr) {
                        return new WeightVector((Vector) objArr[0], BoxesRunTime.unboxToDouble(objArr[1]));
                    }

                    {
                        Class typeClass = this.getTypeClass();
                    }
                };
                return new ScalaCaseClassSerializer(getTypeClass(), typeSerializerArr);
            }

            {
                super(WeightVector.class, (TypeInformation[]) Nil$.MODULE$.toArray((ClassTag) Predef$.MODULE$.implicitly(ClassTag$.MODULE$.apply(TypeInformation.class))), new $colon.colon(TypeExtractor.createTypeInfo(Vector.class), new $colon.colon(BasicTypeInfo.getInfoFor(Double.TYPE), Nil$.MODULE$)), Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new String[]{"weights", "intercept"})));
            }
        }, ClassTag$.MODULE$.apply(WeightVector.class));
    }

    /**
     * Performs one update step by delegating to the supplied regularization penalty.
     *
     * <p>At the call site visible in {@code $anonfun$SGDStep$3}, {@code vector} is the current
     * weight vector, {@code vector2} is the averaged update vector, and {@code d2} is the value
     * returned by {@code calculateLearningRate}.
     *
     * @param vector first vector argument, forwarded unchanged to the penalty
     * @param vector2 second vector argument, forwarded unchanged to the penalty
     * @param regularizationPenalty penalty whose {@code takeStep} produces the result
     * @param d first scalar forwarded to the penalty (presumably the regularization constant - confirm)
     * @param d2 second scalar forwarded to the penalty (the learning rate at the visible call site)
     * @return whatever {@code regularizationPenalty.takeStep} returns for these arguments
     */
    public Vector takeStep(Vector vector, Vector vector2, RegularizationPenalty regularizationPenalty, double d, double d2) {
        final Vector steppedWeights = regularizationPenalty.takeStep(vector, vector2, d, d2);
        return steppedWeights;
    }

    /**
     * Builds a data set holding the mean loss of {@code dataSet} with respect to the weights in
     * {@code dataSet2}.
     *
     * <p>The pipeline has three stages:
     * <ol>
     *   <li>{@code mapWithBcVariable} turns every labeled vector into a {@code (loss, 1)} pair,
     *       where the loss comes from {@code lossFunction.loss(labeledVector, weightVector)};</li>
     *   <li>{@code reduce} adds the pairs component-wise, giving {@code (sum of losses, count)};</li>
     *   <li>{@code map} divides the sum by the count (see {@code $anonfun$calculateLoss$3}) and
     *       boxes the result.</li>
     * </ol>
     *
     * <p>NOTE(review): this is decompiler output. The {@code super(...)} calls placed inside
     * instance-initializer blocks and the discarded anonymous serializers are not valid or
     * faithful Java; treat them as a rendering of the compiled Scala, not as hand-written logic.
     *
     * @param dataSet labeled vectors to evaluate
     * @param dataSet2 data set supplying the weight vector passed to the mapping function
     *     (by the method name, presumably as a broadcast variable - confirm)
     * @param lossFunction loss function applied to each (labeled vector, weight vector) pair
     * @return a data set whose elements are boxed doubles: summed loss divided by element count
     */
    private DataSet<Object> calculateLoss(DataSet<LabeledVector> dataSet, DataSet<WeightVector> dataSet2, LossFunction lossFunction) {
        // NOTE(review): always null here and only passed to the anonymous type-info constructor
        // below; it looks like the decompiler's stand-in for a captured outer reference.
        final GradientDescent gradientDescent = null;
        return package$.MODULE$.RichDataSet(dataSet).mapWithBcVariable(dataSet2, (labeledVector, weightVector) -> {
            // Emit (loss of this example, 1); the constant 1 is summed later to count examples.
            return new Tuple2.mcDI.sp(lossFunction.loss(labeledVector, weightVector), 1);
        }, new CaseClassTypeInfo<Tuple2<Object, Object>>(gradientDescent) { // from class: org.apache.flink.ml.optimization.GradientDescent$$anon$26
            // Synthetic accessor that exposes the inherited `types` field to the lambda below.
            public /* synthetic */ TypeInformation[] protected$types(GradientDescent$$anon$26 gradientDescent$$anon$26) {
                return gradientDescent$$anon$26.types;
            }

            // Creates one serializer per tuple field and wraps them in a case-class serializer.
            public TypeSerializer<Tuple2<Object, Object>> createSerializer(ExecutionConfig executionConfig) {
                final TypeSerializer[] typeSerializerArr = new TypeSerializer[getArity()];
                // Fill slot i with the serializer of field i, for i in [0, getArity()).
                RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), getArity()).foreach$mVc$sp(i -> {
                    typeSerializerArr[i] = this.protected$types(this)[i].createSerializer(executionConfig);
                });
                // NOTE(review): this anonymous subclass is instantiated and then discarded; the
                // return statement below builds a plain ScalaCaseClassSerializer instead. Likely a
                // decompilation artifact - verify against the compiled class before relying on it.
                new ScalaCaseClassSerializer<Tuple2<Object, Object>>(this, typeSerializerArr) { // from class: org.apache.flink.ml.optimization.GradientDescent$$anon$26$$anon$13
                    /* renamed from: createInstance, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
                    // Rebuilds a (double, int) tuple from the two deserialized field values.
                    public Tuple2<Object, Object> m173createInstance(Object[] objArr) {
                        return new Tuple2.mcDI.sp(BoxesRunTime.unboxToDouble(objArr[0]), BoxesRunTime.unboxToInt(objArr[1]));
                    }

                    {
                        Class typeClass = this.getTypeClass();
                    }
                };
                return new ScalaCaseClassSerializer(getTypeClass(), typeSerializerArr);
            }

            // Decompiler rendering of the superclass constructor call: describes a Tuple2 whose
            // fields "_1" and "_2" have the type information of double and int respectively.
            {
                super(Tuple2.class, (TypeInformation[]) new $colon.colon(BasicTypeInfo.getInfoFor(Double.TYPE), new $colon.colon(BasicTypeInfo.getInfoFor(Integer.TYPE), Nil$.MODULE$)).toArray((ClassTag) Predef$.MODULE$.implicitly(ClassTag$.MODULE$.apply(TypeInformation.class))), new $colon.colon(BasicTypeInfo.getInfoFor(Double.TYPE), new $colon.colon(BasicTypeInfo.getInfoFor(Integer.TYPE), Nil$.MODULE$)), Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new String[]{"_1", "_2"})));
            }
        }, ClassTag$.MODULE$.apply(Tuple2.class)).reduce((tuple2, tuple22) -> {
            // Component-wise sum: losses add up in _1, example counts add up in _2.
            return new Tuple2.mcDI.sp(tuple2._1$mcD$sp() + tuple22._1$mcD$sp(), tuple2._2$mcI$sp() + tuple22._2$mcI$sp());
        }).map(tuple23 -> {
            // Summed loss divided by count, boxed for the DataSet<Object> element type.
            return BoxesRunTime.boxToDouble($anonfun$calculateLoss$3(tuple23));
        }, BasicTypeInfo.getInfoFor(Double.TYPE), ClassTag$.MODULE$.Double());
    }

    /**
     * Synthetic lambda body: pairs a weight vector with a double value.
     *
     * @param weightVector first tuple component, stored as-is
     * @param d second tuple component, boxed before being stored
     * @return a new {@code Tuple2} of {@code (weightVector, boxed d)}
     */
    public static final /* synthetic */ Tuple2 $anonfun$optimizeWithConvergenceCriterion$1(WeightVector weightVector, double d) {
        final Double boxedValue = BoxesRunTime.boxToDouble(d);
        return new Tuple2(weightVector, boxedValue);
    }

    /**
     * Synthetic lambda body: builds a two-element tuple from a weight vector and a double.
     *
     * @param weightVector value placed in the tuple's first slot
     * @param d value boxed and placed in the tuple's second slot
     * @return a new {@code Tuple2} of {@code (weightVector, boxed d)}
     */
    public static final /* synthetic */ Tuple2 $anonfun$optimizeWithConvergenceCriterion$6(WeightVector weightVector, double d) {
        final Double secondComponent = BoxesRunTime.boxToDouble(d);
        return new Tuple2(weightVector, secondComponent);
    }

    /**
     * Synthetic lambda body: averages a summed weight vector over its count and applies one
     * update step to the current weights.
     *
     * <p>{@code tuple2} carries {@code (summed WeightVector, int count)}. Both the summed weights
     * and the summed intercept are divided by the count; the averaged values are then combined
     * with {@code weightVector} using the learning rate computed for iteration {@code i}.
     * (Presumably the summed vector is an accumulated gradient - confirm against the caller.)
     *
     * @param gradientDescent instance whose {@code takeStep} computes the new weights
     * @param regularizationPenalty penalty forwarded to {@code takeStep}
     * @param d scalar forwarded to both {@code calculateLearningRate} and {@code takeStep}
     * @param d2 first argument of {@code calculateLearningRate}
     * @param learningRateMethodTrait strategy that computes the learning rate
     * @param tuple2 pair of (summed weight vector, count)
     * @param weightVector current weights and intercept
     * @param i second argument of {@code calculateLearningRate}
     * @return a new weight vector holding the stepped weights and the updated intercept
     * @throws MatchError if {@code tuple2} or its first component is {@code null}
     */
    public static final /* synthetic */ WeightVector $anonfun$SGDStep$3(GradientDescent gradientDescent, RegularizationPenalty regularizationPenalty, double d, double d2, LearningRateMethod.LearningRateMethodTrait learningRateMethodTrait, Tuple2 tuple2, WeightVector weightVector, int i) {
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        final WeightVector summed = (WeightVector) tuple2._1();
        final int count = tuple2._2$mcI$sp();
        if (summed == null) {
            throw new MatchError(tuple2);
        }

        final Vector summedWeights = summed.weights();
        final double summedIntercept = summed.intercept();

        // The result of scal is not captured, so it evidently rescales summedWeights in place.
        BLAS$.MODULE$.scal(1.0d / count, summedWeights);
        final WeightVector averaged = new WeightVector(summedWeights, summedIntercept / count);

        final double learningRate = learningRateMethodTrait.calculateLearningRate(d2, i, d);
        final Vector steppedWeights = gradientDescent.takeStep(weightVector.weights(), averaged.weights(), regularizationPenalty, d, learningRate);
        final double steppedIntercept = weightVector.intercept() - (learningRate * averaged.intercept());
        return new WeightVector(steppedWeights, steppedIntercept);
    }

    /**
     * Synthetic lambda body: divides the tuple's double component by its int component.
     *
     * @param tuple2 pair read through the specialized accessors as {@code (double, int)}
     * @return {@code _1 / _2} computed in double arithmetic
     */
    public static final /* synthetic */ double $anonfun$calculateLoss$3(Tuple2 tuple2) {
        final double total = tuple2._1$mcD$sp();
        final int count = tuple2._2$mcI$sp();
        return total / count;
    }
}
