package io.kgraph.library.cf;

import io.kgraph.EdgeWithValue;
import io.kgraph.VertexWithValue;
import io.kgraph.library.basic.EdgeCount;
import io.kgraph.pregel.ComputeFunction;
import io.kgraph.pregel.aggregators.DoubleSumAggregator;
import io.kgraph.pregel.aggregators.LongSumAggregator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import org.jblas.FloatMatrix;

/* loaded from: input_file:BOOT-INF/lib/kafka-graphs-core-1.2.0.jar:io/kgraph/library/cf/Sgd.class */
/**
 * Collaborative-filtering matrix factorization trained with stochastic gradient descent (SGD)
 * over a user/item rating graph.
 *
 * <p>Superstep schedule (see {@link #compute}):
 * <ul>
 *   <li>superstep 0 &ndash; runs {@link EdgeCount} (presumably filling
 *       {@link EdgeCount#EDGE_COUNT_AGGREGATOR}, which {@link #getTotalNumEdges} reads);</li>
 *   <li>superstep 1 &ndash; {@link InitUsersComputation};</li>
 *   <li>superstep 2 &ndash; {@link InitItemsComputation};</li>
 *   <li>superstep 3 onward &ndash; {@link #superstepCompute}.</li>
 * </ul>
 * {@link #masterCompute} halts once the training RMSE drops below {@link #RMSE_TARGET} (when that
 * target is positive) or once the superstep number exceeds {@link #ITERATIONS}.
 */
public class Sgd implements ComputeFunction<CfLongId, FloatMatrix, Float, FloatMatrixMessage> {
    /** Config key: halt once the training RMSE drops below this value; ignored when {@code <= 0}. */
    public static final String RMSE_TARGET = "rmse";
    public static final float RMSE_TARGET_DEFAULT = -1.0f;
    /**
     * Config key: threshold on the Euclidean distance a vertex's factors move in one superstep.
     * Negative: always re-broadcast; positive: re-broadcast only when the move exceeds it.
     * NOTE(review): exactly zero means a vertex never re-broadcasts — confirm that is intended.
     */
    public static final String TOLERANCE = "tolerance";
    public static final float TOLERANCE_DEFAULT = -1.0f;
    /** Config key: halt once the superstep number exceeds this value. */
    public static final String ITERATIONS = "iterations";
    public static final int ITERATIONS_DEFAULT = 10;
    /** Config key: L2 regularization weight used by {@link #updateValue}. */
    public static final String LAMBDA = "lambda";
    public static final float LAMBDA_DEFAULT = 0.01f;
    /** Config key: SGD step size (learning rate) used by {@link #updateValue}. */
    public static final String GAMMA = "gamma";
    public static final float GAMMA_DEFAULT = 0.005f;
    /** Config key: length of each vertex's factor vector. */
    public static final String VECTOR_SIZE = "dim";
    public static final int VECTOR_SIZE_DEFAULT = 50;
    /** Config key: upper clamp applied to predicted ratings. */
    public static final String MAX_RATING = "max.rating";
    public static final float MAX_RATING_DEFAULT = 5.0f;
    /** Config key: lower clamp applied to predicted ratings. */
    public static final String MIN_RATING = "min.rating";
    public static final float MIN_RATING_DEFAULT = 0.0f;
    /** Config key: optional {@code Long} seed for factor initialization; {@code null} = unseeded. */
    public static final String RANDOM_SEED = "random.seed";
    public static final Long RANDOM_SEED_DEFAULT = null;
    /** Aggregator that sums every vertex's squared (clamped) prediction errors. */
    public static final String RMSE_AGGREGATOR = "sgd.rmse.aggregator";

    private float tolerance;
    private float lambda;
    private float gamma;
    protected float minRating;
    protected float maxRating;
    private Long randomSeed;
    private Map<String, Object> configs;
    private int maxIterations;
    private float rmseTarget;

    /**
     * Superstep 2 (item side): assigns a random factor vector, creates an edge back to every
     * message sender weighted by the message score, then sends this vertex's factors along
     * the edges it was given.
     */
    public class InitItemsComputation implements ComputeFunction<CfLongId, FloatMatrix, Float, FloatMatrixMessage> {
        @Override
        public void compute(int superstep, VertexWithValue<CfLongId, FloatMatrix> vertex,
                            Iterable<FloatMatrixMessage> messages,
                            Iterable<EdgeWithValue<CfLongId, Float>> edges,
                            ComputeFunction.Callback<CfLongId, FloatMatrix, Float, FloatMatrixMessage> callback) {
            FloatMatrix factors = newRandomFactors();
            callback.setNewVertexValue(factors);
            for (FloatMatrixMessage message : messages) {
                callback.addEdge(message.getSenderId(), message.getScore());
            }
            for (EdgeWithValue<CfLongId, Float> edge : edges) {
                callback.sendMessageTo(edge.target(), new FloatMatrixMessage(vertex.id(), factors, 0.0f));
            }
            callback.voteToHalt();
        }
    }

    /**
     * Superstep 1 (user side): assigns a random factor vector and sends it, together with
     * the edge value (the rating), along every out-edge.
     */
    public class InitUsersComputation implements ComputeFunction<CfLongId, FloatMatrix, Float, FloatMatrixMessage> {
        @Override
        public void compute(int superstep, VertexWithValue<CfLongId, FloatMatrix> vertex,
                            Iterable<FloatMatrixMessage> messages,
                            Iterable<EdgeWithValue<CfLongId, Float>> edges,
                            ComputeFunction.Callback<CfLongId, FloatMatrix, Float, FloatMatrixMessage> callback) {
            FloatMatrix factors = newRandomFactors();
            callback.setNewVertexValue(factors);
            for (EdgeWithValue<CfLongId, Float> edge : edges) {
                callback.sendMessageTo(edge.target(),
                    new FloatMatrixMessage(vertex.id(), factors, edge.value().floatValue()));
            }
            callback.voteToHalt();
        }
    }

    /**
     * Builds a factor vector of length {@link #VECTOR_SIZE} whose entries are
     * {@code 0.01f * nextFloat()}, i.e. in [0, 0.01).
     *
     * <p>NOTE(review): when {@link #RANDOM_SEED} is set, a fresh {@link Random} with that same
     * seed is created on every call, so every vertex starts from an identical vector. Confirm
     * this is intended, or mix the vertex id into the seed.
     */
    private FloatMatrix newRandomFactors() {
        int dim = (Integer) configs.getOrDefault(VECTOR_SIZE, VECTOR_SIZE_DEFAULT);
        FloatMatrix factors = new FloatMatrix(dim);
        Random random = randomSeed != null ? new Random(randomSeed) : new Random();
        for (int i = 0; i < factors.length; i++) {
            factors.put(i, 0.01f * random.nextFloat());
        }
        return factors;
    }

    /** Reads a {@code Float}-typed config value, falling back to {@code defaultValue} when absent. */
    private float floatConfig(String key, float defaultValue) {
        return (Float) configs.getOrDefault(key, defaultValue);
    }

    /** Refreshes the hyper-parameters from the configuration before each superstep. */
    @Override
    public void preSuperstep(int superstep, ComputeFunction.Aggregates aggregates) {
        this.lambda = floatConfig(LAMBDA, LAMBDA_DEFAULT);
        this.gamma = floatConfig(GAMMA, GAMMA_DEFAULT);
        this.tolerance = floatConfig(TOLERANCE, TOLERANCE_DEFAULT);
        this.minRating = floatConfig(MIN_RATING, MIN_RATING_DEFAULT);
        this.maxRating = floatConfig(MAX_RATING, MAX_RATING_DEFAULT);
        this.randomSeed = (Long) configs.getOrDefault(RANDOM_SEED, RANDOM_SEED_DEFAULT);
    }

    /**
     * One SGD superstep for a vertex.
     *
     * <ol>
     *   <li>Apply {@link #updateValue} once per incoming message, using the rating on the edge
     *       to the sender.</li>
     *   <li>Aggregate the squared clamped prediction error against every sender into
     *       {@link #RMSE_AGGREGATOR}.</li>
     *   <li>Re-broadcast the updated factors along all edges, subject to {@link #TOLERANCE}.</li>
     * </ol>
     *
     * <p>The vertex value is updated in place and then re-set through the callback.
     * {@code messages} is iterated twice.
     */
    public void superstepCompute(int superstep, VertexWithValue<CfLongId, FloatMatrix> vertex,
                                 Iterable<FloatMatrixMessage> messages,
                                 Iterable<EdgeWithValue<CfLongId, Float>> edges,
                                 ComputeFunction.Callback<CfLongId, FloatMatrix, Float, FloatMatrixMessage> callback) {
        FloatMatrix value = vertex.value();

        // Snapshot BEFORE updateValue mutates 'value' in place. dup() copies the backing array.
        // FloatMatrix(rows, cols, data) would share it, making the snapshot follow every update
        // so the tolerance check always measured a distance of 0.
        FloatMatrix oldValue = tolerance > 0.0f ? value.dup() : null;

        Map<CfLongId, Float> ratings = new HashMap<>();
        for (EdgeWithValue<CfLongId, Float> edge : edges) {
            ratings.put(edge.target(), edge.value());
        }

        // Assumes every sender has a matching edge; otherwise unboxing the missing rating throws NPE.
        for (FloatMatrixMessage message : messages) {
            updateValue(value, message.getFactors(), ratings.get(message.getSenderId()),
                minRating, maxRating, lambda, gamma);
        }

        // Error is measured with the fully updated vector, hence the second pass.
        double squaredErrorSum = 0.0d;
        for (FloatMatrixMessage message : messages) {
            float error = clamp(value.dot(message.getFactors()), minRating, maxRating)
                - ratings.get(message.getSenderId());
            squaredErrorSum += error * error;
        }
        callback.aggregate(RMSE_AGGREGATOR, Double.valueOf(squaredErrorSum));

        float moved = tolerance > 0.0f ? value.distance2(oldValue) : 0.0f;
        if (tolerance < 0.0f || (tolerance > 0.0f && moved > tolerance)) {
            for (EdgeWithValue<CfLongId, Float> edge : edges) {
                callback.sendMessageTo(edge.target(), new FloatMatrixMessage(vertex.id(), value, 0.0f));
            }
        }
        callback.setNewVertexValue(value);
        callback.voteToHalt();
    }

    /** Returns {@code Math.max(Math.min(x, max), min)}. */
    private static float clamp(float x, float min, float max) {
        return Math.max(Math.min(x, max), min);
    }

    /**
     * Performs one in-place SGD step on {@code value}:
     * {@code value += -gamma * (lambda * value + (clamp(value . other) - rating) * other)}.
     *
     * @param value  factor vector to update (mutated in place)
     * @param other  the neighbour's factor vector
     * @param rating observed rating between the two vertices
     * @param min    lower clamp for the predicted rating
     * @param max    upper clamp for the predicted rating
     * @param lambda L2 regularization weight
     * @param gamma  step size
     */
    protected final void updateValue(FloatMatrix value, FloatMatrix other, float rating,
                                     float min, float max, float lambda, float gamma) {
        float error = clamp(value.dot(other), min, max) - rating;
        value.addi(value.mul(lambda).add(other.mul(error)).mul(-gamma));
    }

    /** Stores the configuration, reads the stopping criteria and registers the aggregators. */
    @Override
    public final void init(Map<String, ?> config, ComputeFunction.InitCallback initCallback) {
        // Values are read back with explicit casts, so an Object-valued view is sufficient.
        @SuppressWarnings("unchecked")
        Map<String, Object> objectConfig = (Map<String, Object>) config;
        this.configs = objectConfig;
        this.maxIterations = (Integer) configs.getOrDefault(ITERATIONS, ITERATIONS_DEFAULT);
        this.rmseTarget = floatConfig(RMSE_TARGET, RMSE_TARGET_DEFAULT);
        initCallback.registerAggregator(RMSE_AGGREGATOR, DoubleSumAggregator.class);
        initCallback.registerAggregator(EdgeCount.EDGE_COUNT_AGGREGATOR, LongSumAggregator.class, true);
    }

    /**
     * Computes {@code sqrt(sum of squared errors / total edge count)}. Halts when that value is
     * below a positive {@link #RMSE_TARGET}, or when {@code superstep > maxIterations}.
     */
    @Override
    public final void masterCompute(int superstep, ComputeFunction.MasterCallback masterCallback) {
        double squaredErrorSum = (Double) masterCallback.getAggregatedValue(RMSE_AGGREGATOR);
        double rmse = Math.sqrt(squaredErrorSum / getTotalNumEdges(masterCallback));
        boolean targetReached = rmseTarget > 0.0f && rmse < rmseTarget;
        if (targetReached || superstep > maxIterations) {
            masterCallback.haltComputation();
        }
    }

    /** Dispatches to the phase for {@code superstep}; see the class documentation. */
    @Override
    public void compute(int superstep, VertexWithValue<CfLongId, FloatMatrix> vertex,
                        Iterable<FloatMatrixMessage> messages,
                        Iterable<EdgeWithValue<CfLongId, Float>> edges,
                        ComputeFunction.Callback<CfLongId, FloatMatrix, Float, FloatMatrixMessage> callback) {
        switch (superstep) {
            case 0:
                new EdgeCount().compute(superstep, vertex, messages, edges, callback);
                break;
            case 1:
                new InitUsersComputation().compute(superstep, vertex, messages, edges, callback);
                break;
            case 2:
                new InitItemsComputation().compute(superstep, vertex, messages, edges, callback);
                break;
            default:
                superstepCompute(superstep, vertex, messages, edges, callback);
                break;
        }
    }

    /** Returns the value of {@link EdgeCount#EDGE_COUNT_AGGREGATOR}. */
    protected long getTotalNumEdges(ComputeFunction.Aggregates aggregates) {
        return ((Long) aggregates.getAggregatedValue(EdgeCount.EDGE_COUNT_AGGREGATOR)).longValue();
    }
}
