package io.kgraph.tools.library;

import io.kgraph.library.basic.EdgeCount;
import io.kgraph.library.cf.CfLongId;
import io.kgraph.library.cf.Svdpp;
import io.kgraph.rest.server.graph.GraphAlgorithmResultRequest;
import io.kgraph.rest.server.graph.GraphAlgorithmStatus;
import io.kgraph.rest.server.graph.KeyValue;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
import org.jblas.FloatMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.propertyeditors.CustomBooleanEditor;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.MediaType;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.reactive.function.client.WebClientResponseException;
import picocli.CommandLine;
import reactor.core.publisher.Mono;

/**
 * Command-line tool that predicts the rating a user would give an item, using the
 * results of a Pregel computation exposed by a REST app server.
 *
 * <p>The tool reads the algorithm status and configuration from the server, fetches the
 * stored float values for the user vertex and the item vertex, and combines them as
 * {@code mean + userValues[0] + itemValues[0] + dot(userValues[1..], itemValues[1..])}.
 * The result is clamped to the configured {@code min.rating}/{@code max.rating} range.
 *
 * <p>All failures are logged and cause {@link #call()} to return without printing a
 * prediction.
 */
@CommandLine.Command(description = {"Predicts a rating for a given user and item."}, name = "svdpp-predict", mixinStandardHelpOptions = true, version = {"svdpp-predict 1.0"})
public class SvdppPredictor implements Callable<Void> {
    private static final Logger log = LoggerFactory.getLogger(SvdppPredictor.class);

    /** Type tag passed to {@link CfLongId} when looking up the user vertex. */
    private static final byte USER_TYPE = 0;

    /** Type tag passed to {@link CfLongId} when looking up the item vertex. */
    private static final byte ITEM_TYPE = 1;

    @CommandLine.Parameters(index = "0", description = {"Rest app server."})
    private String restAppServer;

    @CommandLine.Parameters(index = "1", description = {"Pregel graph ID."})
    private String id;

    @CommandLine.Option(names = {"-u", "--user"}, description = {"The user id."})
    private Long user;

    @CommandLine.Option(names = {"-i", "--item"}, description = {"The item id."})
    private Long item;

    /** Creates an instance whose fields are populated by picocli from the command line. */
    public SvdppPredictor() {
    }

    /**
     * Creates an instance with the server, user and item preset.
     *
     * <p>Note that this constructor does not set the Pregel graph ID; {@link #call()}
     * logs an error and returns if the ID is still unset when it runs.
     *
     * @param str the REST app server address
     * @param l the user id
     * @param l2 the item id
     */
    public SvdppPredictor(String str, Long l, Long l2) {
        this.restAppServer = str;
        this.user = l;
        this.item = l2;
    }

    /**
     * Fetches the model data from the REST app server and prints the predicted rating.
     *
     * @return always {@code null}; the prediction is logged and written to standard output
     */
    @Override
    public Void call() {
        if (restAppServer == null || id == null) {
            log.error("Error: the REST app server and the Pregel graph ID are required");
            return null;
        }
        if (user == null || item == null) {
            log.error("Error: both --user and --item must be specified");
            return null;
        }
        try {
            String url = restAppServer;
            // Only add a scheme when none was given; an explicit https:// must be left alone.
            if (!url.startsWith("http://") && !url.startsWith("https://")) {
                url = "http://" + url;
            }
            WebClient client = WebClient.create(url);

            GraphAlgorithmStatus status = client.get()
                .uri("/pregel/{id}", id)
                .retrieve()
                .bodyToMono(GraphAlgorithmStatus.class)
                .block();
            if (status == null) {
                log.error("Error: no status found");
                return null;
            }

            Map<String, Object> configs = client.get()
                .uri("/pregel/{id}/configs", id)
                .retrieve()
                .bodyToMono(new ParameterizedTypeReference<Map<String, Object>>() { })
                .block();
            if (configs == null) {
                log.error("Error: no configs found");
                return null;
            }

            Map<String, String> aggregates = status.getAggregates();
            String ratingSum = aggregates == null ? null : aggregates.get(Svdpp.OVERALL_RATING_AGGREGATOR);
            String edgeCountText = aggregates == null ? null : aggregates.get(EdgeCount.EDGE_COUNT_AGGREGATOR);
            if (ratingSum == null || edgeCountText == null) {
                log.error("Error: required aggregates not found");
                return null;
            }
            long edgeCount = Long.parseLong(edgeCountText);
            if (edgeCount <= 0) {
                // Dividing by zero would silently yield a NaN/Infinity rating.
                log.error("Error: edge count must be positive but was {}", edgeCount);
                return null;
            }
            // The divisor is twice the edge count, as in the original computation.
            float meanRating = (float) (Double.parseDouble(ratingSum) / (2.0 * edgeCount));

            float minRating = ratingBound(configs, "min.rating", 0.0f);
            float maxRating = ratingBound(configs, "max.rating", 5.0f);

            List<Float> userValues = getFloats(client, USER_TYPE, user);
            List<Float> itemValues = getFloats(client, ITEM_TYPE, item);
            if (userValues.isEmpty()) {
                log.error("Error: no result found for user {}", user);
                return null;
            }
            if (itemValues.isEmpty()) {
                log.error("Error: no result found for item {}", item);
                return null;
            }
            if (userValues.size() != itemValues.size()) {
                log.error("Error: user and item results differ in length ({} vs {})",
                    userValues.size(), itemValues.size());
                return null;
            }

            // The first value is added directly; the remaining values form the vectors
            // whose dot product is added. Sublist views avoid mutating the fetched lists.
            FloatMatrix userVector = new FloatMatrix(userValues.subList(1, userValues.size()));
            FloatMatrix itemVector = new FloatMatrix(itemValues.subList(1, itemValues.size()));
            float rawRating = meanRating + userValues.get(0) + itemValues.get(0) + itemVector.dot(userVector);
            log.info("Raw rating: {}", rawRating);

            float predicted = Math.max(Math.min(rawRating, maxRating), minRating);
            log.info("Predicted rating: {}", predicted);
            System.out.println("Predicted rating: " + predicted);
            return null;
        } catch (WebClientResponseException | NumberFormatException e) {
            // Pass the exception as the last argument so the stack trace is preserved.
            log.error("Error: {}", e.getMessage(), e);
            return null;
        }
    }

    /**
     * Reads a rating bound from the configuration map.
     *
     * @param configs the configuration map returned by the server
     * @param key the configuration key
     * @param defaultValue the value to use when the key is absent or maps to {@code null}
     * @return the numeric value, parsing its string form when it is not a {@link Number}
     * @throws NumberFormatException if a non-numeric value cannot be parsed as a float
     */
    private static float ratingBound(Map<String, Object> configs, String key, float defaultValue) {
        Object value = configs.get(key);
        if (value instanceof Number) {
            return ((Number) value).floatValue();
        }
        if (value != null) {
            return Float.parseFloat(value.toString());
        }
        return defaultValue;
    }

    /**
     * Fetches the stored result for a vertex and parses it into floats.
     *
     * <p>The result string is split on parentheses, square brackets and {@code ", "}
     * separators; empty tokens are dropped and each remaining token is parsed as a float.
     *
     * @param webClient the client bound to the REST app server
     * @param b the type tag used to build the {@link CfLongId} key
     * @param j the numeric id used to build the {@link CfLongId} key
     * @return the parsed values, or an empty list when the server returns no result
     * @throws NumberFormatException if a token is not a valid float
     */
    private List<Float> getFloats(WebClient webClient, byte b, long j) {
        GraphAlgorithmResultRequest request = new GraphAlgorithmResultRequest();
        request.setKey(new CfLongId(b, j).toString());
        KeyValue keyValue = webClient.post()
            .uri("/pregel/{id}/result", id)
            .accept(MediaType.TEXT_EVENT_STREAM)
            .body(Mono.just(request), GraphAlgorithmResultRequest.class)
            .retrieve()
            .bodyToFlux(KeyValue.class)
            .next()
            .block();
        if (keyValue == null || keyValue.getValue() == null) {
            return Collections.emptyList();
        }
        return Arrays.stream(keyValue.getValue().split("(\\(|\\)|\\[|\\]|,\\s)"))
            .filter(token -> !token.isEmpty())
            .map(Float::parseFloat)
            .collect(Collectors.toList());
    }

    /**
     * Command-line entry point.
     *
     * @param strArr the command-line arguments, parsed by picocli
     */
    public static void main(String[] strArr) {
        CommandLine.call(new SvdppPredictor(), strArr);
    }
}
