package io.kgraph.library.cf;

import io.kgraph.EdgeWithValue;
import io.kgraph.VertexWithValue;
import io.kgraph.library.basic.EdgeCount;
import io.kgraph.pregel.ComputeFunction;
import io.kgraph.pregel.aggregators.DoubleSumAggregator;
import io.kgraph.pregel.aggregators.LongSumAggregator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.TreeMap;
import org.jblas.FloatMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:BOOT-INF/lib/kafka-graphs-core-1.5.0.jar:io/kgraph/library/cf/Svdpp.class */
/**
 * SVD++ matrix-factorisation recommender implemented as a {@link ComputeFunction}.
 *
 * <p>Supersteps are dispatched by {@link #compute}:
 * <ul>
 *   <li>0: count edges via {@link EdgeCount};</li>
 *   <li>1: {@link InitUsersComputation};</li>
 *   <li>2: {@link InitItemsComputation};</li>
 *   <li>odd supersteps after that: {@link UserComputation};</li>
 *   <li>even supersteps after that: {@link ItemComputation}.</li>
 * </ul>
 * Within a superstep every vertex runs the same sub-computation.
 *
 * <p>{@link #masterCompute} halts the job in two cases:
 * <ul>
 *   <li>the training RMSE drops below the {@link #RMSE_TARGET} value (only when that target is
 *       positive);</li>
 *   <li>the superstep number exceeds the {@link #ITERATIONS} value.</li>
 * </ul>
 *
 * <p>Configuration keys are the {@code String} constants below. Values are cast to
 * {@link Integer}, {@link Float} or {@link Long} to match the type of the corresponding
 * {@code *_DEFAULT} constant, so a value of another type causes a {@link ClassCastException}.
 */
public class Svdpp implements ComputeFunction<CfLongId, SvdppValue, Float, FloatMatrixMessage> {
    public static final String OVERALL_RATING_AGGREGATOR = "svd.overall.rating.aggregator";
    public static final String RMSE_TARGET = "rmse";
    public static final float RMSE_TARGET_DEFAULT = -1.0f;
    public static final String ITERATIONS = "iterations";
    public static final int ITERATIONS_DEFAULT = 10;
    public static final String FACTOR_LAMBDA = "lambda.factor";
    public static final float FACTOR_LAMBDA_DEFAULT = 0.01f;
    public static final String FACTOR_GAMMA = "gamma.factor";
    public static final float FACTOR_GAMMA_DEFAULT = 0.005f;
    public static final String BIAS_LAMBDA = "lambda.bias";
    public static final float BIAS_LAMBDA_DEFAULT = 0.01f;
    public static final String BIAS_GAMMA = "gamma.bias";
    public static final float BIAS_GAMMA_DEFAULT = 0.005f;
    public static final String MAX_RATING = "max.rating";
    public static final float MAX_RATING_DEFAULT = 5.0f;
    public static final String MIN_RATING = "min.rating";
    public static final float MIN_RATING_DEFAULT = 0.0f;
    public static final String VECTOR_SIZE = "dim";
    public static final int VECTOR_SIZE_DEFAULT = 50;
    public static final String RANDOM_SEED = "random.seed";
    public static final String RMSE_AGGREGATOR = "svd.rmse.aggregator";
    public static final Long RANDOM_SEED_DEFAULT = null;

    private static final Logger log = LoggerFactory.getLogger(Svdpp.class);

    /** Scale applied to uniform [0, 1) draws when initialising factor and weight vectors. */
    private static final float INIT_FACTOR_SCALE = 0.01f;

    /** Private copy of the configuration passed to {@link #init}. */
    private Map<String, Object> configs;
    private int maxIterations;
    private float rmseTarget;
    /** Seed for initialisation; {@code null} means an unseeded {@link Random} is used. */
    private Long randomSeed;

    // The init sub-computations hold no state, so single instances are shared across vertices.
    private final InitUsersComputation initUsersComputation = new InitUsersComputation();
    private final InitItemsComputation initItemsComputation = new InitItemsComputation();
    private final UserComputation userComputation = new UserComputation();
    private final ItemComputation itemComputation = new ItemComputation();

    /** Returns the configured latent-vector length ({@link #VECTOR_SIZE}). */
    private int configuredVectorSize() {
        return (Integer) configs.getOrDefault(VECTOR_SIZE, VECTOR_SIZE_DEFAULT);
    }

    /** Returns the {@code Float} config value for {@code key}, or {@code defaultValue}. */
    private float configuredFloat(String key, float defaultValue) {
        return (Float) configs.getOrDefault(key, defaultValue);
    }

    /**
     * Creates the {@link Random} used to initialise one vertex.
     *
     * <p>NOTE(review): with a fixed seed every vertex gets a fresh {@code Random(seed)}, so all
     * vertices draw identical initial values. Confirm that this is intended.
     */
    private Random newInitRandom() {
        return randomSeed != null ? new Random(randomSeed) : new Random();
    }

    /**
     * Superstep 2: creates reverse edges, initialises factor and weight vectors, and broadcasts
     * them.
     *
     * <p>For every incoming message an edge to its sender is added, valued with the message score.
     * The vertex value becomes a random baseline plus random factor and weight vectors of length
     * {@link #VECTOR_SIZE}. A two-row matrix (row 0 = factors, row 1 = weight), together with the
     * baseline, is sent along every edge in {@code edges}.
     */
    public class InitItemsComputation implements ComputeFunction<CfLongId, SvdppValue, Float, FloatMatrixMessage> {
        public InitItemsComputation() {
        }

        @Override // io.kgraph.pregel.ComputeFunction
        public void compute(int superstep, VertexWithValue<CfLongId, SvdppValue> vertex, Iterable<FloatMatrixMessage> messages, Iterable<EdgeWithValue<CfLongId, Float>> edges, ComputeFunction.Callback<CfLongId, SvdppValue, Float, FloatMatrixMessage> cb) {
            for (FloatMatrixMessage message : messages) {
                cb.addEdge(message.getSenderId(), Float.valueOf(message.getScore()));
            }

            int dim = configuredVectorSize();
            FloatMatrix factors = new FloatMatrix(1, dim);
            FloatMatrix weight = new FloatMatrix(1, dim);
            Random random = newInitRandom();
            // Draw order (factor, weight, factor, weight, ..., baseline) determines seeded results.
            for (int k = 0; k < factors.length; k++) {
                factors.put(k, INIT_FACTOR_SCALE * random.nextFloat());
                weight.put(k, INIT_FACTOR_SCALE * random.nextFloat());
            }
            float baseline = random.nextFloat();
            cb.setNewVertexValue(new SvdppValue(baseline, factors, weight));

            FloatMatrix packed = new FloatMatrix(2, dim);
            packed.putRow(0, factors);
            packed.putRow(1, weight);
            for (EdgeWithValue<CfLongId, Float> edge : edges) {
                cb.sendMessageTo(edge.target(), new FloatMatrixMessage(vertex.id(), packed, baseline));
            }
            cb.voteToHalt();
        }
    }

    /**
     * Superstep 1: accumulates ratings, initialises factors, and announces ratings to neighbours.
     *
     * <p>The sum of this vertex's edge values is added to {@link #OVERALL_RATING_AGGREGATOR}. The
     * vertex value becomes a random baseline, a random factor vector of length
     * {@link #VECTOR_SIZE} and an empty weight matrix. Each edge target then receives a message
     * whose score is that edge's value.
     */
    public class InitUsersComputation implements ComputeFunction<CfLongId, SvdppValue, Float, FloatMatrixMessage> {
        public InitUsersComputation() {
        }

        @Override // io.kgraph.pregel.ComputeFunction
        public void compute(int superstep, VertexWithValue<CfLongId, SvdppValue> vertex, Iterable<FloatMatrixMessage> messages, Iterable<EdgeWithValue<CfLongId, Float>> edges, ComputeFunction.Callback<CfLongId, SvdppValue, Float, FloatMatrixMessage> cb) {
            // Iterate the edges exactly once. Re-creating the iterator on every check would never
            // advance past the first edge.
            double ratingSum = 0.0d;
            for (EdgeWithValue<CfLongId, Float> edge : edges) {
                ratingSum += edge.value().floatValue();
            }
            cb.aggregate(OVERALL_RATING_AGGREGATOR, Double.valueOf(ratingSum));

            FloatMatrix factors = new FloatMatrix(1, configuredVectorSize());
            Random random = newInitRandom();
            // Factors are drawn first, then the baseline; this order determines seeded results.
            for (int k = 0; k < factors.length; k++) {
                factors.put(k, INIT_FACTOR_SCALE * random.nextFloat());
            }
            cb.setNewVertexValue(new SvdppValue(random.nextFloat(), factors, new FloatMatrix(0)));

            for (EdgeWithValue<CfLongId, Float> edge : edges) {
                cb.sendMessageTo(edge.target(), new FloatMatrixMessage(vertex.id(), new FloatMatrix(0), edge.value().floatValue()));
            }
            cb.voteToHalt();
        }
    }

    /**
     * Even supersteps after 2: applies the updates received from neighbours and re-broadcasts
     * state.
     *
     * <p>Each incoming message carries three things:
     * <ul>
     *   <li>a baseline delta as its score;</li>
     *   <li>a factor delta in row 0 of its matrix;</li>
     *   <li>a weight delta in row 1 of its matrix.</li>
     * </ul>
     * These are applied in place through {@link Svdpp#incrementValue}. The updated factors and
     * weight, packed as a two-row matrix, are then sent with the new baseline along every edge.
     */
    public class ItemComputation implements ComputeFunction<CfLongId, SvdppValue, Float, FloatMatrixMessage> {
        private float biasLambda;
        private float biasGamma;
        private float factorLambda;
        private float factorGamma;
        private int vectorSize;

        public ItemComputation() {
        }

        /** Reloads the learning-rate, regularisation and vector-size settings from the config. */
        @Override // io.kgraph.pregel.ComputeFunction
        public void preSuperstep(int superstep, ComputeFunction.Aggregators aggregators) {
            this.biasLambda = configuredFloat(BIAS_LAMBDA, BIAS_LAMBDA_DEFAULT);
            this.biasGamma = configuredFloat(BIAS_GAMMA, BIAS_GAMMA_DEFAULT);
            this.factorLambda = configuredFloat(FACTOR_LAMBDA, FACTOR_LAMBDA_DEFAULT);
            this.factorGamma = configuredFloat(FACTOR_GAMMA, FACTOR_GAMMA_DEFAULT);
            this.vectorSize = configuredVectorSize();
        }

        @Override // io.kgraph.pregel.ComputeFunction
        public void compute(int superstep, VertexWithValue<CfLongId, SvdppValue> vertex, Iterable<FloatMatrixMessage> messages, Iterable<EdgeWithValue<CfLongId, Float>> edges, ComputeFunction.Callback<CfLongId, SvdppValue, Float, FloatMatrixMessage> cb) {
            float baseline = vertex.value().getBaseline();
            // Both matrices are mutated in place by incrementValue.
            FloatMatrix factors = vertex.value().getFactors();
            FloatMatrix weight = vertex.value().getWeight();
            for (FloatMatrixMessage message : messages) {
                FloatMatrix factorDelta = message.getFactors().getRow(0);
                FloatMatrix weightDelta = message.getFactors().getRow(1);
                baseline = Svdpp.incrementValue(baseline, message.getScore(), biasGamma, biasLambda);
                Svdpp.incrementValue(factors, factorDelta, factorGamma, factorLambda);
                Svdpp.incrementValue(weight, weightDelta, factorGamma, factorLambda);
            }

            FloatMatrix packed = new FloatMatrix(2, vectorSize);
            packed.putRow(0, factors);
            packed.putRow(1, weight);
            for (EdgeWithValue<CfLongId, Float> edge : edges) {
                cb.sendMessageTo(edge.target(), new FloatMatrixMessage(vertex.id(), packed, baseline));
            }
            cb.setNewVertexValue(new SvdppValue(baseline, factors, weight));
            cb.voteToHalt();
        }
    }

    /**
     * Per-vertex state: a scalar baseline plus factor and weight matrices.
     *
     * <p>The fields are final, but the {@link FloatMatrix} instances returned by the getters are
     * the internal objects and are mutated in place by the computations.
     */
    public static class SvdppValue {
        private final float baseline;
        private final FloatMatrix factors;
        private final FloatMatrix weight;

        public SvdppValue(float baseline, FloatMatrix factors, FloatMatrix weight) {
            this.baseline = baseline;
            this.factors = factors;
            this.weight = weight;
        }

        public float getBaseline() {
            return this.baseline;
        }

        /** Returns the internal (mutable) factor matrix. */
        public FloatMatrix getFactors() {
            return this.factors;
        }

        /** Returns the internal (mutable) weight matrix. */
        public FloatMatrix getWeight() {
            return this.weight;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            SvdppValue other = (SvdppValue) obj;
            return Float.compare(other.baseline, this.baseline) == 0
                    && Objects.equals(this.factors, other.factors)
                    && Objects.equals(this.weight, other.weight);
        }

        @Override
        public int hashCode() {
            return Objects.hash(Float.valueOf(this.baseline), this.factors, this.weight);
        }

        /** Renders {@code (baseline, factors)}; the weight matrix is not included. */
        @Override
        public String toString() {
            return "(" + this.baseline + ", " + this.factors.toString() + ")";
        }
    }

    /**
     * Odd supersteps after 2: runs the SVD++ SGD step for this vertex and sends updates to its
     * neighbours.
     *
     * <p>Each incoming message carries a neighbour's baseline (score), its factors (row 0) and its
     * weight vector (row 1). The observed rating for a neighbour is the value of this vertex's
     * edge to it.
     */
    public class UserComputation implements ComputeFunction<CfLongId, SvdppValue, Float, FloatMatrixMessage> {
        private float biasLambda;
        private float biasGamma;
        private float factorLambda;
        private float factorGamma;
        private float minRating;
        private float maxRating;
        private int vectorSize;
        private float meanRating;

        public UserComputation() {
        }

        /**
         * Updates {@code factors} in place:
         * {@code factors += -lambda*gamma*factors + error*gamma*itemFactors}.
         */
        protected void updateValue(FloatMatrix factors, FloatMatrix itemFactors, float error, float gamma, float lambda) {
            factors.addi(factors.mul((-lambda) * gamma).addi(itemFactors.mul(error * gamma)));
        }

        /**
         * Reloads the hyper-parameters from the config and computes the global mean rating.
         *
         * <p>The mean is the overall rating sum divided by twice the edge count.
         * NOTE(review): the factor 2 presumably compensates for edges being counted or added in
         * both directions. Confirm this against {@link EdgeCount}.
         */
        @Override // io.kgraph.pregel.ComputeFunction
        public void preSuperstep(int superstep, ComputeFunction.Aggregators aggregators) {
            this.factorLambda = configuredFloat(FACTOR_LAMBDA, FACTOR_LAMBDA_DEFAULT);
            this.factorGamma = configuredFloat(FACTOR_GAMMA, FACTOR_GAMMA_DEFAULT);
            this.biasLambda = configuredFloat(BIAS_LAMBDA, BIAS_LAMBDA_DEFAULT);
            this.biasGamma = configuredFloat(BIAS_GAMMA, BIAS_GAMMA_DEFAULT);
            this.minRating = configuredFloat(MIN_RATING, MIN_RATING_DEFAULT);
            this.maxRating = configuredFloat(MAX_RATING, MAX_RATING_DEFAULT);
            this.vectorSize = configuredVectorSize();
            double overallRating = ((Double) aggregators.getAggregatedValue(OVERALL_RATING_AGGREGATOR)).doubleValue();
            this.meanRating = (float) (overallRating / (Svdpp.this.getTotalNumEdges(aggregators) * 2));
        }

        @Override // io.kgraph.pregel.ComputeFunction
        public void compute(int superstep, VertexWithValue<CfLongId, SvdppValue> vertex, Iterable<FloatMatrixMessage> messages, Iterable<EdgeWithValue<CfLongId, Float>> edges, ComputeFunction.Callback<CfLongId, SvdppValue, Float, FloatMatrixMessage> cb) {
            float baseline = vertex.value().getBaseline();
            // Mutated in place by updateValue.
            FloatMatrix factors = vertex.value().getFactors();

            // Observed rating per neighbour, plus the neighbour count |N(u)|.
            Map<CfLongId, Float> observedRatings = new HashMap<>();
            int numEdges = 0;
            for (EdgeWithValue<CfLongId, Float> edge : edges) {
                numEdges++;
                observedRatings.put(edge.target(), edge.value());
            }

            // One message per sender, iterated in sender-id order (assumes CfLongId is Comparable).
            TreeMap<CfLongId, FloatMatrixMessage> bySender = new TreeMap<>();
            for (FloatMatrixMessage message : messages) {
                bySender.put(message.getSenderId(), message);
            }

            // Implicit-feedback term: the sum of the neighbours' weight vectors.
            FloatMatrix weightSum = new FloatMatrix(1, vectorSize);
            for (FloatMatrixMessage message : bySender.values()) {
                weightSum.addi(message.getFactors().getRow(1));
            }

            // Pass 1: update the baseline and factors, accumulating the error-weighted neighbour
            // factors used for the weight update sent in pass 2.
            FloatMatrix weightUpdate = new FloatMatrix(1, vectorSize);
            for (FloatMatrixMessage message : bySender.values()) {
                FloatMatrix itemFactors = message.getFactors().getRow(0);
                float observed = observedRatings.get(message.getSenderId()).floatValue();
                float predicted = Svdpp.computePredictedRating(meanRating, baseline, message.getScore(), factors, itemFactors, numEdges, weightSum, minRating, maxRating);
                float error = predicted - observed;
                baseline = Svdpp.computeUpdatedBaseLine(baseline, predicted, observed, biasGamma, biasLambda);
                updateValue(factors, itemFactors, error, factorGamma, factorLambda);
                weightUpdate.addi(itemFactors.mul(error));
            }
            cb.setNewVertexValue(new SvdppValue(baseline, factors, vertex.value().getWeight()));

            float invSqrtNumEdges = 1.0f / ((float) Math.sqrt(numEdges));
            weightUpdate.muli(factorGamma / ((float) Math.sqrt(numEdges)));

            // Pass 2: re-evaluate the error with the updated state, send each neighbour its deltas
            // (score = baseline delta, row 0 = factor delta, row 1 = weight delta), and accumulate
            // the squared error for the RMSE.
            double squaredErrorSum = 0.0d;
            for (FloatMatrixMessage message : bySender.values()) {
                FloatMatrix itemFactors = message.getFactors().getRow(0);
                float error = Svdpp.computePredictedRating(meanRating, baseline, message.getScore(), factors, itemFactors, numEdges, weightSum, minRating, maxRating)
                        - observedRatings.get(message.getSenderId()).floatValue();
                float baselineDelta = biasGamma * error;
                FloatMatrix factorDelta = weightSum.mul(invSqrtNumEdges).add(factors).mul(factorGamma * error);
                FloatMatrix packed = new FloatMatrix(2, vectorSize);
                packed.putRow(0, factorDelta);
                packed.putRow(1, weightUpdate);
                squaredErrorSum += error * error;
                cb.sendMessageTo(message.getSenderId(), new FloatMatrixMessage(vertex.id(), packed, baselineDelta));
            }
            cb.aggregate(RMSE_AGGREGATOR, Double.valueOf(squaredErrorSum));
            cb.voteToHalt();
        }
    }

    /**
     * Computes the SVD++ prediction, clamped to {@code [minRating, maxRating]}.
     *
     * <p>The unclamped value is
     * {@code mean + userBaseline + itemBaseline + itemFactors . (userFactors + weightSum / sqrt(numRatings))}.
     *
     * @param mean         global mean rating
     * @param userBaseline this vertex's baseline
     * @param itemBaseline the neighbour's baseline
     * @param userFactors  this vertex's factor vector
     * @param itemFactors  the neighbour's factor vector
     * @param numRatings   neighbour count used for normalising {@code weightSum}
     * @param weightSum    sum of the neighbours' implicit weight vectors
     * @param minRating    lower clamp
     * @param maxRating    upper clamp
     * @return the clamped predicted rating
     */
    protected static float computePredictedRating(float mean, float userBaseline, float itemBaseline, FloatMatrix userFactors, FloatMatrix itemFactors, int numRatings, FloatMatrix weightSum, float minRating, float maxRating) {
        float raw = mean + userBaseline + itemBaseline
                + itemFactors.dot(userFactors.add(weightSum.mul(1.0f / ((float) Math.sqrt(numRatings)))));
        return Math.max(Math.min(raw, maxRating), minRating);
    }

    /**
     * Returns {@code baseline + gamma * ((predicted - observed) - lambda * baseline)}.
     *
     * <p>NOTE(review): the error term here is {@code predicted - observed} and is added with a
     * positive step. Standard SGD on squared loss uses {@code observed - predicted}. Confirm the
     * intended sign before relying on convergence.
     */
    protected static float computeUpdatedBaseLine(float baseline, float predicted, float observed, float gamma, float lambda) {
        return baseline + (gamma * ((predicted - observed) - (lambda * baseline)));
    }

    /** Returns {@code value + step - gamma * lambda * value}. */
    protected static float incrementValue(float value, float step, float gamma, float lambda) {
        return (value + step) - ((gamma * lambda) * value);
    }

    /** Updates {@code value} in place: {@code value += -gamma*lambda*value + step}. */
    protected static void incrementValue(FloatMatrix value, FloatMatrix step, float gamma, float lambda) {
        value.addi(value.mul((-gamma) * lambda).addi(step));
    }

    /**
     * Reads the configuration and registers the edge-count, RMSE and overall-rating aggregators.
     *
     * @param map configuration values; copied, so later changes to it are not observed
     * @param initCallback used to register aggregators
     * @throws NullPointerException if {@code map} is null
     */
    @Override // io.kgraph.pregel.ComputeFunction
    public final void init(Map<String, ?> map, ComputeFunction.InitCallback initCallback) {
        Objects.requireNonNull(map, "map");
        // Copy into Map<String, Object> so getOrDefault accepts boxed defaults of any type.
        this.configs = new HashMap<>(map);
        this.maxIterations = (Integer) this.configs.getOrDefault(ITERATIONS, ITERATIONS_DEFAULT);
        this.rmseTarget = configuredFloat(RMSE_TARGET, RMSE_TARGET_DEFAULT);
        this.randomSeed = (Long) this.configs.getOrDefault(RANDOM_SEED, RANDOM_SEED_DEFAULT);
        initCallback.registerAggregator(EdgeCount.EDGE_COUNT_AGGREGATOR, LongSumAggregator.class, true);
        initCallback.registerAggregator(RMSE_AGGREGATOR, DoubleSumAggregator.class);
        initCallback.registerAggregator(OVERALL_RATING_AGGREGATOR, DoubleSumAggregator.class, true);
    }

    /**
     * Computes RMSE = sqrt(aggregated squared error / total edge count) and halts the job when it
     * beats a positive {@link #RMSE_TARGET} or when {@code superstep > maxIterations}.
     */
    @Override // io.kgraph.pregel.ComputeFunction
    public final void masterCompute(int superstep, ComputeFunction.MasterCallback masterCallback) {
        double squaredErrorSum = ((Double) masterCallback.getAggregatedValue(RMSE_AGGREGATOR)).doubleValue();
        double rmse = Math.sqrt(squaredErrorSum / getTotalNumEdges(masterCallback));
        log.debug("Superstep {}: RMSE = {}", superstep, rmse);
        boolean targetReached = this.rmseTarget > 0.0f && rmse < this.rmseTarget;
        if (targetReached || superstep > this.maxIterations) {
            masterCallback.haltComputation();
        }
    }

    /** Forwards to the user (odd) or item (even) sub-computation after superstep 2. */
    @Override // io.kgraph.pregel.ComputeFunction
    public void preSuperstep(int superstep, ComputeFunction.Aggregators aggregators) {
        if (superstep <= 2) {
            return;
        }
        if (superstep % 2 != 0) {
            this.userComputation.preSuperstep(superstep, aggregators);
        } else {
            this.itemComputation.preSuperstep(superstep, aggregators);
        }
    }

    /** Dispatches to the sub-computation for {@code superstep} (see the class documentation). */
    @Override // io.kgraph.pregel.ComputeFunction
    public void compute(int superstep, VertexWithValue<CfLongId, SvdppValue> vertexWithValue, Iterable<FloatMatrixMessage> iterable, Iterable<EdgeWithValue<CfLongId, Float>> iterable2, ComputeFunction.Callback<CfLongId, SvdppValue, Float, FloatMatrixMessage> callback) {
        if (superstep == 0) {
            new EdgeCount().compute(superstep, vertexWithValue, iterable, iterable2, callback);
        } else if (superstep == 1) {
            this.initUsersComputation.compute(superstep, vertexWithValue, iterable, iterable2, callback);
        } else if (superstep == 2) {
            this.initItemsComputation.compute(superstep, vertexWithValue, iterable, iterable2, callback);
        } else if (superstep % 2 != 0) {
            this.userComputation.compute(superstep, vertexWithValue, iterable, iterable2, callback);
        } else {
            this.itemComputation.compute(superstep, vertexWithValue, iterable, iterable2, callback);
        }
    }

    /** Returns the value of the {@link EdgeCount#EDGE_COUNT_AGGREGATOR} aggregator. */
    protected long getTotalNumEdges(ComputeFunction.ReadAggregators readAggregators) {
        return ((Long) readAggregators.getAggregatedValue(EdgeCount.EDGE_COUNT_AGGREGATOR)).longValue();
    }
}
