package com.whylogs.core.metrics;

import com.google.common.base.Preconditions;
import com.whylogs.core.message.RegressionMetricsMessage;
import java.util.Map;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/whylogs/core/metrics/RegressionMetrics.class */
/**
 * Accumulates error statistics for a regression model by comparing a prediction column against a
 * target column, one row at a time.
 *
 * <p>For every tracked row the difference {@code prediction - target} is computed and folded into
 * running sums of the absolute difference, the signed difference and the squared difference, plus a
 * row count. Accumulators for the same column pair can be merged and converted to/from
 * {@link RegressionMetricsMessage}.
 *
 * <p>Non-finite inputs (NaN, infinity) are not filtered and propagate into the sums. Instances are
 * mutable and perform no synchronization.
 */
public class RegressionMetrics {
    private static final Logger log = LoggerFactory.getLogger(RegressionMetrics.class);

    /** Name of the column holding the model's prediction. */
    private final String predictionField;
    /** Name of the column holding the ground-truth target. */
    private final String targetField;

    /** Running sum of {@code |prediction - target|}. */
    private double sumAbsDiff;
    /** Running sum of {@code prediction - target}. */
    private double sumDiff;
    /** Running sum of {@code (prediction - target)^2}. */
    private double sum2Diff;
    /** Number of rows tracked (including rows contributed by merged accumulators). */
    private long count;

    /**
     * Creates an empty accumulator for the given column pair.
     *
     * @param predictionField name of the prediction column; if {@code null}, {@link #track} throws
     *     and {@link #toProtobuf} returns {@code null}
     * @param targetField name of the target column; same {@code null} semantics as above
     */
    public RegressionMetrics(String predictionField, String targetField) {
        this.predictionField = predictionField;
        this.targetField = targetField;
    }

    /**
     * Folds one row into the running statistics.
     *
     * @param map row keyed by column name; must hold non-null {@link Number} values for both the
     *     prediction and the target field
     * @throws IllegalStateException if either field name is {@code null}
     * @throws NullPointerException if the row has no (or a {@code null}) value for either field
     * @throws ClassCastException if either value is not a {@link Number}
     */
    public void track(Map<String, ?> map) {
        Preconditions.checkState(this.predictionField != null, "Prediction field is not set");
        Preconditions.checkState(this.targetField != null, "Target field is not set");
        // Both values are read before any field is mutated, so a bad row leaves the statistics intact.
        final double diff = numericValue(map, this.predictionField) - numericValue(map, this.targetField);
        this.sumAbsDiff += Math.abs(diff);
        this.sumDiff += diff;
        this.sum2Diff += diff * diff;
        this.count++;
    }

    /** Reads {@code row.get(field)} as a double, failing with a message that names the field. */
    private static double numericValue(Map<String, ?> row, String field) {
        final Object value = row.get(field);
        Preconditions.checkNotNull(value, "Missing value for field: %s", field);
        if (!(value instanceof Number)) {
            throw new ClassCastException(
                    "Value for field " + field + " is not a Number: " + value.getClass().getName());
        }
        return ((Number) value).doubleValue();
    }

    /**
     * Returns an independent accumulator with the same column names and the same running sums.
     *
     * @return a new instance; later changes to either object do not affect the other
     */
    public RegressionMetrics copy() {
        final RegressionMetrics duplicate = new RegressionMetrics(this.predictionField, this.targetField);
        duplicate.sumAbsDiff = this.sumAbsDiff;
        duplicate.sumDiff = this.sumDiff;
        duplicate.sum2Diff = this.sum2Diff;
        duplicate.count = this.count;
        return duplicate;
    }

    /**
     * Combines this accumulator with another one for the same column pair. Neither input is modified.
     *
     * @param other accumulator to add; {@code null} is treated as empty
     * @return a new accumulator whose sums and count are the element-wise totals of both inputs
     * @throws IllegalStateException if the prediction or target field names differ
     */
    public RegressionMetrics merge(RegressionMetrics other) {
        if (other == null) {
            return copy();
        }
        Preconditions.checkState(
                Objects.equals(this.predictionField, other.predictionField),
                "Mismatched prediction fields: %s vs %s",
                this.predictionField,
                other.predictionField);
        Preconditions.checkState(
                Objects.equals(this.targetField, other.targetField),
                "Mismatched target fields: %s vs %s",
                this.targetField,
                other.targetField);
        final RegressionMetrics merged = copy();
        merged.sumAbsDiff += other.sumAbsDiff;
        merged.sumDiff += other.sumDiff;
        merged.sum2Diff += other.sum2Diff;
        merged.count += other.count;
        return merged;
    }

    /**
     * Serializes the current state into a protobuf builder.
     *
     * @return a populated builder, or {@code null} when either field name is {@code null}
     */
    public RegressionMetricsMessage.Builder toProtobuf() {
        if (this.predictionField == null || this.targetField == null) {
            return null;
        }
        return RegressionMetricsMessage.newBuilder()
                .setPredictionField(this.predictionField)
                .setTargetField(this.targetField)
                .setSumAbsDiff(this.sumAbsDiff)
                .setSumDiff(this.sumDiff)
                .setSum2Diff(this.sum2Diff)
                .setCount(this.count);
    }

    /**
     * Rebuilds an accumulator from its protobuf form.
     *
     * @param regressionMetricsMessage serialized state; may be {@code null}
     * @return the restored accumulator, or {@code null} when the message is {@code null} or either
     *     field name is empty (a warning is logged in the latter case)
     */
    public static RegressionMetrics fromProtobuf(RegressionMetricsMessage regressionMetricsMessage) {
        if (regressionMetricsMessage == null) {
            return null;
        }
        final String prediction = regressionMetricsMessage.getPredictionField();
        final String target = regressionMetricsMessage.getTargetField();
        if ("".equals(prediction) || "".equals(target)) {
            log.warn("Skipping Regression metrics: prediction or target field not set");
            return null;
        }
        final RegressionMetrics restored = new RegressionMetrics(prediction, target);
        restored.sumAbsDiff = regressionMetricsMessage.getSumAbsDiff();
        restored.sumDiff = regressionMetricsMessage.getSumDiff();
        restored.sum2Diff = regressionMetricsMessage.getSum2Diff();
        restored.count = regressionMetricsMessage.getCount();
        return restored;
    }

    /** @return the prediction column name (may be {@code null}) */
    public String getPredictionField() {
        return this.predictionField;
    }

    /** @return the target column name (may be {@code null}) */
    public String getTargetField() {
        return this.targetField;
    }

    /** @return running sum of {@code |prediction - target|} */
    public double getSumAbsDiff() {
        return this.sumAbsDiff;
    }

    /** @return running sum of {@code prediction - target} */
    public double getSumDiff() {
        return this.sumDiff;
    }

    /** @return running sum of {@code (prediction - target)^2} */
    public double getSum2Diff() {
        return this.sum2Diff;
    }

    /** @return number of rows accumulated */
    public long getCount() {
        return this.count;
    }
}
