package io.kgraph.library.cf;

import io.kgraph.EdgeWithValue;
import io.kgraph.VertexWithValue;
import io.kgraph.library.basic.EdgeCount;
import io.kgraph.pregel.ComputeFunction;
import io.kgraph.pregel.aggregators.DoubleSumAggregator;
import io.kgraph.pregel.aggregators.LongSumAggregator;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.jblas.FloatMatrix;
import org.jblas.JavaBlas;
import org.jblas.Solve;

/* loaded from: input_file:BOOT-INF/lib/kafka-graphs-core-1.1.2.jar:io/kgraph/library/cf/Als.class */
/**
 * Alternating Least Squares (ALS) matrix factorisation expressed as a Pregel computation over a
 * bipartite rating graph whose edge values are ratings.
 *
 * <p>Superstep schedule (see {@link #compute}):
 * <ul>
 *   <li>0: {@link EdgeCount} fills the edge-count aggregator later used to normalise the RMSE;</li>
 *   <li>1: {@link InitUsersComputation} initialises factors and sends each rating along its edge;</li>
 *   <li>2: {@link InitItemsComputation} initialises factors, adds reverse edges and replies;</li>
 *   <li>3 and later: {@link #superstepCompute} re-solves each receiving vertex's factor vector.</li>
 * </ul>
 *
 * <p>Configuration keys: {@value #VECTOR_SIZE}, {@value #LAMBDA}, {@value #ITERATIONS} and
 * {@value #RMSE_TARGET}. Values may be any {@link Number}, or a string that parses as one.
 */
public class Als implements ComputeFunction<CfLongId, FloatMatrix, Float, FloatMatrixMessage> {
    public static final String RMSE_TARGET = "rmse";
    public static final float RMSE_TARGET_DEFAULT = -1.0f;
    public static final String ITERATIONS = "iterations";
    public static final int ITERATIONS_DEFAULT = 10;
    public static final String LAMBDA = "lambda";
    public static final float LAMBDA_DEFAULT = 0.01f;
    public static final String VECTOR_SIZE = "dim";
    public static final int VECTOR_SIZE_DEFAULT = 50;
    public static final String RMSE_AGGREGATOR = "als.rmse.aggregator";

    /** First superstep dispatched to {@link #superstepCompute}, i.e. the first one that feeds the RMSE aggregator. */
    private static final int FIRST_TRAINING_SUPERSTEP = 3;

    /** Scale applied to the uniform [0, 1) samples that make up the initial factor entries. */
    private static final float INIT_SCALE = 0.01f;

    private float lambda;
    private int vectorSize;
    private Map<String, ?> configs;
    private int maxIterations;
    private float rmseTarget;

    /**
     * Superstep 2: gives the vertex an initial factor vector, adds an edge back to every vertex
     * that sent it a rating, and replies with its factor vector.
     */
    public class InitItemsComputation implements ComputeFunction<CfLongId, FloatMatrix, Float, FloatMatrixMessage> {
        public InitItemsComputation() {
        }

        @Override
        public void compute(int superstep, VertexWithValue<CfLongId, FloatMatrix> vertex, Iterable<FloatMatrixMessage> messages, Iterable<EdgeWithValue<CfLongId, Float>> edges, ComputeFunction.Callback<CfLongId, FloatMatrix, Float, FloatMatrixMessage> cb) {
            FloatMatrix factors = initialFactors(vertex.id());
            cb.setNewVertexValue(factors);

            // Reply to existing neighbours AND to every sender. Edges added through cb.addEdge
            // are not guaranteed to appear in `edges` during this superstep. The set also
            // guarantees each neighbour receives exactly one message, because the receiver
            // sizes its matrices by message count.
            Set<CfLongId> targets = new LinkedHashSet<>();
            for (EdgeWithValue<CfLongId, Float> edge : edges) {
                targets.add(edge.target());
            }
            for (FloatMatrixMessage message : messages) {
                cb.addEdge(message.getSenderId(), message.getScore());
                targets.add(message.getSenderId());
            }

            // Send the freshly built vector, not vertex.value(): the vertex snapshot may still
            // hold the pre-initialisation value.
            for (CfLongId target : targets) {
                cb.sendMessageTo(target, new FloatMatrixMessage(vertex.id(), factors, 0.0f));
            }
            cb.voteToHalt();
        }
    }

    /**
     * Superstep 1: gives the vertex an initial factor vector and sends, along every out-edge,
     * a message carrying that vector and the edge's rating.
     */
    public class InitUsersComputation implements ComputeFunction<CfLongId, FloatMatrix, Float, FloatMatrixMessage> {
        public InitUsersComputation() {
        }

        @Override
        public void compute(int superstep, VertexWithValue<CfLongId, FloatMatrix> vertex, Iterable<FloatMatrixMessage> messages, Iterable<EdgeWithValue<CfLongId, Float>> edges, ComputeFunction.Callback<CfLongId, FloatMatrix, Float, FloatMatrixMessage> cb) {
            FloatMatrix factors = initialFactors(vertex.id());
            cb.setNewVertexValue(factors);
            for (EdgeWithValue<CfLongId, Float> edge : edges) {
                // The rating travels with the message so the receiver can create the reverse edge.
                cb.sendMessageTo(edge.target(), new FloatMatrixMessage(vertex.id(), factors, edge.value()));
            }
            cb.voteToHalt();
        }
    }

    /** Refreshes the regularisation weight and factor dimension from the configuration. */
    @Override
    public void preSuperstep(int superstep, ComputeFunction.Aggregators aggregators) {
        this.lambda = floatConfig(LAMBDA, LAMBDA_DEFAULT);
        this.vectorSize = intConfig(VECTOR_SIZE, VECTOR_SIZE_DEFAULT);
    }

    /**
     * One ALS half-step for a vertex.
     *
     * <p>The method runs in four steps:
     * <ol>
     *   <li>Stacks the factor vectors received from neighbours as columns and pairs each with
     *   the rating on the edge to its sender.</li>
     *   <li>Re-solves this vertex's factor vector via {@link #updateValue}.</li>
     *   <li>Adds this vertex's squared training error to {@link #RMSE_AGGREGATOR}.</li>
     *   <li>Forwards the new vector along every out-edge.</li>
     * </ol>
     *
     * <p>A vertex that received no messages keeps its current factors. With no observations the
     * normal equations would be empty (and singular without edges).
     *
     * @throws IllegalStateException if a message arrives from a vertex this vertex has no edge to
     */
    public void superstepCompute(int superstep, VertexWithValue<CfLongId, FloatMatrix> vertex, Iterable<FloatMatrixMessage> messages, Iterable<EdgeWithValue<CfLongId, Float>> edges, ComputeFunction.Callback<CfLongId, FloatMatrix, Float, FloatMatrixMessage> cb) {
        Map<CfLongId, Float> ratingByNeighbour = new HashMap<>();
        for (EdgeWithValue<CfLongId, Float> edge : edges) {
            ratingByNeighbour.put(edge.target(), edge.value());
        }

        List<FloatMatrixMessage> received = new ArrayList<>();
        for (FloatMatrixMessage message : messages) {
            received.add(message);
        }
        if (received.isEmpty()) {
            cb.voteToHalt();
            return;
        }

        // Size the system by what actually arrived rather than by edge count, so a missing or
        // extra message cannot leave zero columns or overflow the matrix.
        int n = received.size();
        FloatMatrix neighbourFactors = new FloatMatrix(vectorSize, n);
        FloatMatrix ratings = new FloatMatrix(n, 1);
        for (int c = 0; c < n; c++) {
            FloatMatrixMessage message = received.get(c);
            Float rating = ratingByNeighbour.get(message.getSenderId());
            if (rating == null) {
                throw new IllegalStateException("Vertex " + vertex.id() + " received factors from "
                    + message.getSenderId() + " but has no edge (rating) to it");
            }
            neighbourFactors.putColumn(c, message.getFactors());
            ratings.put(c, 0, rating);
        }

        // updateValue writes the solution into the vertex's own matrix in place.
        FloatMatrix value = vertex.value();
        updateValue(value, neighbourFactors, ratings, lambda);

        double squaredError = 0.0d;
        for (int c = 0; c < n; c++) {
            double residual = value.dot(neighbourFactors.getColumn(c)) - ratings.get(c, 0);
            squaredError += residual * residual;
        }
        cb.aggregate(RMSE_AGGREGATOR, Double.valueOf(squaredError));

        for (EdgeWithValue<CfLongId, Float> edge : edges) {
            cb.sendMessageTo(edge.target(), new FloatMatrixMessage(vertex.id(), value, 0.0f));
        }
        cb.setNewVertexValue(value);
        cb.voteToHalt();
    }

    /**
     * Solves the regularised normal equations {@code (M Mᵀ + regularization·n·I) x = M r} and
     * stores {@code x} into {@code value} in place.
     *
     * @param value            matrix that receives the solution; resized if its shape differs
     * @param neighbourFactors {@code M}: one neighbour factor vector per column
     * @param ratings          {@code r}: column of ratings, one row per column of {@code M}
     * @param regularization   lambda; scaled by the number of ratings {@code n = ratings.rows}
     */
    protected void updateValue(FloatMatrix value, FloatMatrix neighbourFactors, FloatMatrix ratings, float regularization) {
        FloatMatrix rhs = neighbourFactors.mmul(ratings);
        FloatMatrix gram = neighbourFactors.mmul(neighbourFactors.transpose());
        gram.addi(FloatMatrix.eye(neighbourFactors.rows).muli(regularization * ratings.rows));
        FloatMatrix solution = Solve.solve(gram, rhs);
        // resize keeps rows, columns, length and the data array consistent. Assigning only
        // rows/columns would leave `length` stale and the array possibly too small.
        if (!value.sameSize(solution)) {
            value.resize(solution.rows, solution.columns);
        }
        JavaBlas.rcopy(solution.length, solution.data, 0, 1, value.data, 0, 1);
    }

    /** Stores the configuration, reads the stopping criteria and registers the aggregators used. */
    @Override
    public final void init(Map<String, ?> map, ComputeFunction.InitCallback initCallback) {
        this.configs = map;
        this.maxIterations = intConfig(ITERATIONS, ITERATIONS_DEFAULT);
        this.rmseTarget = floatConfig(RMSE_TARGET, RMSE_TARGET_DEFAULT);
        initCallback.registerAggregator(EdgeCount.EDGE_COUNT_AGGREGATOR, LongSumAggregator.class, true);
        initCallback.registerAggregator(RMSE_AGGREGATOR, DoubleSumAggregator.class);
    }

    /**
     * Halts the job after {@code maxIterations} supersteps, or earlier once the training RMSE
     * falls below a positive {@code rmseTarget}.
     */
    @Override
    public final void masterCompute(int superstep, ComputeFunction.MasterCallback masterCallback) {
        // Until a training superstep has fed the RMSE aggregator its sum is 0, which would read
        // as a perfect fit and halt the job before any training happened. "> FIRST" is safe
        // whether masterCompute sees this superstep's aggregates or the previous one's.
        if (this.rmseTarget > 0.0f && superstep > FIRST_TRAINING_SUPERSTEP) {
            long numEdges = getTotalNumEdges(masterCallback);
            if (numEdges > 0) {
                double sumSquaredError = ((Number) masterCallback.getAggregatedValue(RMSE_AGGREGATOR)).doubleValue();
                double rmse = Math.sqrt(sumSquaredError / numEdges);
                if (rmse < this.rmseTarget) {
                    masterCallback.haltComputation();
                    return;
                }
            }
        }
        if (superstep > this.maxIterations) {
            masterCallback.haltComputation();
        }
    }

    /** Dispatches a superstep to edge counting, user initialisation, item initialisation or an ALS update. */
    @Override
    public void compute(int superstep, VertexWithValue<CfLongId, FloatMatrix> vertex, Iterable<FloatMatrixMessage> messages, Iterable<EdgeWithValue<CfLongId, Float>> edges, ComputeFunction.Callback<CfLongId, FloatMatrix, Float, FloatMatrixMessage> cb) {
        switch (superstep) {
            case 0:
                new EdgeCount().compute(superstep, vertex, messages, edges, cb);
                break;
            case 1:
                new InitUsersComputation().compute(superstep, vertex, messages, edges, cb);
                break;
            case 2:
                new InitItemsComputation().compute(superstep, vertex, messages, edges, cb);
                break;
            default:
                superstepCompute(superstep, vertex, messages, edges, cb);
                break;
        }
    }

    /** Returns the edge total from {@link EdgeCount#EDGE_COUNT_AGGREGATOR}, or 0 if it is not yet set. */
    protected long getTotalNumEdges(ComputeFunction.ReadAggregators readAggregators) {
        Object count = readAggregators.getAggregatedValue(EdgeCount.EDGE_COUNT_AGGREGATOR);
        return count == null ? 0L : ((Number) count).longValue();
    }

    /**
     * Builds a small pseudo-random initial factor vector for the vertex {@code id}.
     *
     * <p>The generator is seeded from the vertex id. Runs stay reproducible, while different
     * vertices start from different vectors. A single shared seed gives every vertex the same
     * vector; each ALS solve then returns a multiple of it and the model collapses to rank 1.
     * NOTE(review): assumes CfLongId.hashCode() is value-based and stable across JVMs (it must be
     * value-based already, since CfLongId is used as a HashMap key).
     */
    private FloatMatrix initialFactors(CfLongId id) {
        FloatMatrix factors = new FloatMatrix(intConfig(VECTOR_SIZE, VECTOR_SIZE_DEFAULT));
        Random random = new Random(id.hashCode());
        for (int k = 0; k < factors.length; k++) {
            factors.put(k, INIT_SCALE * random.nextFloat());
        }
        return factors;
    }

    /** Reads {@code key} as a float; absent or null values yield {@code defaultValue}. */
    private float floatConfig(String key, float defaultValue) {
        Object value = configs == null ? null : configs.get(key);
        if (value == null) {
            return defaultValue;
        }
        if (value instanceof Number) {
            return ((Number) value).floatValue();
        }
        return Float.parseFloat(value.toString().trim());
    }

    /** Reads {@code key} as an int; absent or null values yield {@code defaultValue}. */
    private int intConfig(String key, int defaultValue) {
        Object value = configs == null ? null : configs.get(key);
        if (value == null) {
            return defaultValue;
        }
        if (value instanceof Number) {
            return ((Number) value).intValue();
        }
        return Integer.parseInt(value.toString().trim());
    }
}
