package org.apache.wayang.apps.sgd;

import java.util.ArrayList;
import java.util.List;
import org.apache.wayang.core.function.ExecutionContext;
import org.apache.wayang.core.function.FunctionDescriptor;

/* JADX INFO: Access modifiers changed from: package-private */
/* compiled from: SGDImprovedImpl.java */
/* loaded from: input_file:org/apache/wayang/apps/sgd/ComputeLogisticGradientPerPartition.class */
public class ComputeLogisticGradientPerPartition implements FunctionDescriptor.ExtendedSerializableFunction<Iterable<double[]>, Iterable<double[]>> {
    double[] weights;
    double[] sumGradOfPartition;
    int features;

    public ComputeLogisticGradientPerPartition(int i) {
        this.features = i;
        this.sumGradOfPartition = new double[i + 1];
    }

    public Iterable<double[]> apply(Iterable<double[]> iterable) {
        ArrayList arrayList = new ArrayList(1);
        iterable.forEach(dArr -> {
            double d = 0.0d;
            for (int i = 0; i < this.weights.length; i++) {
                d += this.weights[i] * dArr[i + 1];
            }
            for (int i2 = 0; i2 < this.weights.length; i2++) {
                double[] dArr = this.sumGradOfPartition;
                int i3 = i2 + 1;
                dArr[i3] = dArr[i3] + (((1.0d / (1.0d + Math.exp((-1.0d) * d))) - dArr[0]) * dArr[i2 + 1]);
            }
            double[] dArr2 = this.sumGradOfPartition;
            dArr2[0] = dArr2[0] + 1.0d;
        });
        arrayList.add(this.sumGradOfPartition);
        return arrayList;
    }

    public void open(ExecutionContext executionContext) {
        this.weights = (double[]) executionContext.getBroadcast("weights").iterator().next();
        this.sumGradOfPartition = new double[this.features + 1];
    }
}
