package it.unimi.dsi.law.stat;

import com.martiansoftware.jsap.FlaggedOption;
import com.martiansoftware.jsap.JSAP;
import com.martiansoftware.jsap.JSAPException;
import com.martiansoftware.jsap.JSAPResult;
import com.martiansoftware.jsap.Parameter;
import com.martiansoftware.jsap.SimpleJSAP;
import com.martiansoftware.jsap.Switch;
import com.martiansoftware.jsap.UnflaggedOption;
import it.unimi.dsi.Util;
import it.unimi.dsi.fastutil.doubles.DoubleArrays;
import it.unimi.dsi.fastutil.ints.AbstractInt2DoubleFunction;
import it.unimi.dsi.fastutil.ints.Int2DoubleFunction;
import it.unimi.dsi.fastutil.ints.IntArrays;
import it.unimi.dsi.law.util.ExchangeWeigher;
import it.unimi.dsi.law.util.Precision;
import java.io.IOException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Computes a weighted τ correlation index between two score vectors, which may contain ties.
 *
 * <p>Every element receives a weight by applying a <em>weigher</em> (an {@link Int2DoubleFunction})
 * to its rank. The weight of a pair of elements combines the weights of its two elements either
 * additively (their sum, the default) or multiplicatively (their product). The index is then obtained
 * from the total weight of all pairs, the weight of pairs tied in each score vector, the weight of
 * pairs tied in both vectors and the weight of exchanged (discordant) pairs, with the same formula
 * Kendall's τ<sub>b</sub> applies to pair counts.
 *
 * <p>Ready-made weighers: {@link #HYPERBOLIC_WEIGHER} (1 / (r + 1)), {@link #QUADRATIC_WEIGHER}
 * (1 / (r + 1)<sup>2</sup>), {@link #LOGARITHMIC_WEIGHER} (1 / ln(r + e)) and {@link #ZERO_WEIGHER}
 * (constantly zero).
 */
public class WeightedTau extends CorrelationIndex {
    private static final Logger LOGGER = LoggerFactory.getLogger(WeightedTau.class);

    /** A weigher assigning weight 1 / (r + 1) to rank r. */
    public static final Int2DoubleFunction HYPERBOLIC_WEIGHER = new HyperbolicWeigher();
    /** A weigher assigning weight 1 / (r + 1)<sup>2</sup> to rank r. */
    public static final Int2DoubleFunction QUADRATIC_WEIGHER = new QuadraticWeigher();
    /** A weigher assigning weight 1 / ln(r + e) to rank r. */
    public static final Int2DoubleFunction LOGARITHMIC_WEIGHER = new LogarithmicWeigher();
    /** A weigher assigning weight 0 to every rank. */
    public static final Int2DoubleFunction ZERO_WEIGHER = new ZeroWeigher();
    /** A weighted τ with {@link #HYPERBOLIC_WEIGHER} and additive combination of weights. */
    // Must be declared after HYPERBOLIC_WEIGHER, which the no-argument constructor reads.
    public static final WeightedTau HYPERBOLIC = new WeightedTau();

    /** Maps ranks to element weights. */
    private final Int2DoubleFunction weigher;
    /** Whether the weight of a pair is the product (rather than the sum) of its elements' weights. */
    private final boolean multiplicative;

    /** Base class for weighers: functions defined on every non-negative rank, with no finite size. */
    public abstract static class AbstractWeigher extends AbstractInt2DoubleFunction {
        private static final long serialVersionUID = 1;

        /** Returns true exactly for non-negative ranks. */
        public boolean containsKey(final int rank) {
            return rank >= 0;
        }

        /** Returns -1, since a weigher has no finite domain. */
        public int size() {
            return -1;
        }
    }

    /** Weigher returning 1 / (r + 1). */
    private static final class HyperbolicWeigher extends AbstractWeigher {
        private static final long serialVersionUID = 1;

        private HyperbolicWeigher() {
        }

        public double get(final int rank) {
            // Floating-point addition avoids int overflow of rank + 1; identical result otherwise.
            return 1.0 / (rank + 1.0);
        }
    }

    /** Weigher returning 1 / ln(r + e), which is 1 at rank 0. */
    private static final class LogarithmicWeigher extends AbstractWeigher {
        private static final long serialVersionUID = 1;

        private LogarithmicWeigher() {
        }

        public double get(final int rank) {
            return 1.0 / Math.log(rank + Math.E);
        }
    }

    /** Weigher returning 1 / (r + 1)<sup>2</sup>. */
    private static final class QuadraticWeigher extends AbstractWeigher {
        private static final long serialVersionUID = 1;

        private QuadraticWeigher() {
        }

        public double get(final int rank) {
            final double shifted = rank + 1.0;
            return 1.0 / (shifted * shifted);
        }
    }

    /** Weigher returning 0 for every rank. */
    private static final class ZeroWeigher extends AbstractWeigher {
        private static final long serialVersionUID = 1;

        private ZeroWeigher() {
        }

        public double get(final int rank) {
            return 0.0;
        }
    }

    /** Creates an additive weighted τ using {@link #HYPERBOLIC_WEIGHER}. */
    public WeightedTau() {
        this(HYPERBOLIC_WEIGHER);
    }

    /**
     * Creates an additive weighted τ using the given weigher.
     *
     * @param weigher maps ranks to element weights.
     */
    public WeightedTau(final Int2DoubleFunction weigher) {
        this(weigher, false);
    }

    /**
     * Creates a weighted τ using the given weigher and combination of weights.
     *
     * @param weigher maps ranks to element weights.
     * @param multiplicative if true, the weight of a pair is the product of its elements' weights;
     *     otherwise it is their sum.
     */
    public WeightedTau(final Int2DoubleFunction weigher, final boolean multiplicative) {
        this.weigher = weigher;
        this.multiplicative = multiplicative;
    }

    /**
     * Computes the symmetrized weighted τ.
     *
     * <p>The result is the average of {@link #compute(double[], double[], int[])} applied to
     * ({@code v0}, {@code v1}) and to ({@code v1}, {@code v0}), each with ranks derived from its
     * second argument, clamped to [-1, 1].
     *
     * @param v0 the first score vector.
     * @param v1 the second score vector.
     * @return the symmetrized weighted τ.
     */
    @Override // it.unimi.dsi.law.stat.CorrelationIndex
    public double compute(final double[] v0, final double[] v1) {
        final double average = (compute(v0, v1, (int[]) null) + compute(v1, v0, (int[]) null)) / 2.0;
        return Math.min(1.0, Math.max(-1.0, average));
    }

    /**
     * Computes the weighted τ between two score vectors, using the given ranks.
     *
     * @param v0 the first score vector.
     * @param v1 the second score vector.
     * @param rank {@code rank[i]} is the rank of element {@code i}, whose weight is then
     *     {@code weigher.get(rank[i])}. If {@code null}, element ranks are their positions in
     *     decreasing lexicographic order by {@code v1} and then by {@code v0}. The array passed by
     *     the caller is not replaced.
     * @return the weighted τ, clamped to [-1, 1]. It is 1 when the total weight equals the weight
     *     of ties in each vector, for example when both vectors are constant.
     * @throws IllegalArgumentException if the vectors have different lengths, are empty, or
     *     {@code rank} has a different length.
     */
    public double compute(final double[] v0, final double[] v1, int[] rank) {
        if (v0.length != v1.length) {
            throw new IllegalArgumentException("Array lengths differ: " + v0.length + ", " + v1.length);
        }
        final int length = v0.length;
        if (length == 0) {
            throw new IllegalArgumentException("The weighted τ is undefined on empty rankings");
        }
        if (rank != null && rank.length != length) {
            throw new IllegalArgumentException("The score array length (" + length + ") and the rank array length (" + rank.length + ") do not match");
        }

        // Sort element indices lexicographically by v1 and then by v0. Afterwards, elements tied in v1,
        // and elements tied in both vectors, are contiguous in perm.
        final int[] perm = Util.identity(length);
        DoubleArrays.radixSortIndirect(perm, v1, v0, true);

        if (rank == null) {
            // Reversing perm yields decreasing (v1, v0) order.
            // Inverting that permutation maps each element to its position in it.
            rank = perm.clone();
            IntArrays.reverse(rank);
            Util.invertPermutationInPlace(rank);
        }

        final double jointTies = tieWeight(perm, rank, v0, v1);
        LOGGER.debug("Weight of joint ties: {}", jointTies);

        final double v1Ties = tieWeight(perm, rank, v1, null);
        LOGGER.debug("Weight of ties in the second score vector: {}", v1Ties);

        // NOTE(review): the exchange weigher is expected to leave perm ordered by v0. The v0-tie
        // computation below relies on v0-tied elements being contiguous, so it must come after this call.
        final double exchanges = new ExchangeWeigher(weigher, perm, v0, rank, multiplicative, new int[length]).weigh();
        LOGGER.debug("Weight of exchanges: {}", exchanges);

        final double v0Ties = tieWeight(perm, rank, v0, null);
        LOGGER.debug("Weight of ties in the first score vector: {}", v0Ties);

        // Total weight of all pairs: all elements form a single group.
        double sum = 0;
        double sumOfSquares = 0;
        for (final int element : perm) {
            final double w = weigher.get(rank[element]);
            sum += w;
            sumOfSquares += w * w;
        }
        final double total = pairWeight(sum, sumOfSquares, length);
        LOGGER.debug("Total weight: {}", total);

        // Every weighted pair is tied in both vectors: treat it as perfect agreement.
        if (total == v0Ties && total == v1Ties) {
            return 1.0;
        }

        // NOTE(review): if only one of the vectors has all pairs tied, the denominator is zero.
        // The result is then NaN, or a clamped ±1 depending on rounding.
        final double tau = (total - v1Ties - v0Ties + jointTies - 2.0 * exchanges)
                / Math.sqrt((total - v0Ties) * (total - v1Ties));
        return Math.min(1.0, Math.max(-1.0, tau));
    }

    /**
     * Returns the total weight of the pairs of elements with equal keys.
     *
     * <p>The key of an element is its value in {@code primary}, together with its value in
     * {@code secondary} when the latter is not {@code null}. Elements with equal keys must be
     * contiguous in {@code perm}, which must not be empty.
     */
    private double tieWeight(final int[] perm, final int[] rank, final double[] primary, final double[] secondary) {
        int first = 0;
        double w = weigher.get(rank[perm[0]]);
        double sum = w;
        double sumOfSquares = w * w;
        double ties = 0;
        for (int i = 1; i < perm.length; i++) {
            final int head = perm[first];
            final int current = perm[i];
            if (primary[head] != primary[current] || (secondary != null && secondary[head] != secondary[current])) {
                // The current group of tied elements ends here: account for its pairs and start afresh.
                ties += pairWeight(sum, sumOfSquares, i - first);
                first = i;
                sum = 0;
                sumOfSquares = 0;
            }
            w = weigher.get(rank[current]);
            sum += w;
            sumOfSquares += w * w;
        }
        return ties + pairWeight(sum, sumOfSquares, perm.length - first);
    }

    /**
     * Returns the total weight of all pairs within a group of {@code n} elements.
     *
     * <p>{@code sum} and {@code sumOfSquares} are the sum and the sum of squares of the elements'
     * weights. Additively, the sum over pairs of w<sub>i</sub> + w<sub>j</sub> is (n − 1)·sum.
     * Multiplicatively, the sum of w<sub>i</sub>w<sub>j</sub> is (sum² − sumOfSquares) / 2.
     */
    private double pairWeight(final double sum, final double sumOfSquares, final int n) {
        return multiplicative ? (sum * sum - sumOfSquares) / 2.0 : sum * (n - 1);
    }

    /**
     * Command-line entry point: prints the symmetrized weighted τ between two score files, once per
     * requested truncation (or once, untruncated, if none is given).
     */
    public static void main(final String[] arg) throws NumberFormatException, IOException, JSAPException {
        final SimpleJSAP jsap = new SimpleJSAP(WeightedTau.class.getName(),
                "Computes a weighted correlation index between two given score files. "
                        + "By default, the index is a symmetric additive hyperbolic τ, but you can set a different choice using the available options. "
                        + "Note that scores need not to be distinct (i.e., you can have an arbitrary number of ties).\n"
                        + "By default, the two files must contain the same number of doubles, written in Java binary (DataOutput) format. "
                        + "The option -t makes it possible to specify a different type (possibly for each input file).\n"
                        + "If one or more truncations are specified with the option -T, the values of specified weighted correlation index for the given files truncated to the given number of binary fractional digits, in the same order, will be printed to standard output."
                        + "If there is more than one value, the vectors will be loaded in memory just once and copied across computations.",
                new Parameter[] {
                    new Switch("reverse", 'r', "reverse", "Use reverse ranks (that is, rank decreases as score increases)."),
                    new Switch("logarithmic", 'l', "logarithmic", "Use a logarithmic (instead of hyperbolic) weight."),
                    new Switch("quadratic", 'q', "quadratic", "Use a quadratic (instead of hyperbolic) weight."),
                    new Switch("multiplicative", 'm', "multiplicative", "Use a multiplicative (instead of additive) combination of weights."),
                    new FlaggedOption("type", JSAP.STRING_PARSER, "double", false, 't', "type", "The type of the input files, of the form type[:type] where type is one of int, long, float, double, text"),
                    new FlaggedOption("digits", JSAP.INTEGER_PARSER, JSAP.NO_DEFAULT, false, 'T', "truncate", "Truncate inputs to the given number of binary fractional digits.").setAllowMultipleDeclarations(true),
                    new UnflaggedOption("file0", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, true, false, "The first score file."),
                    new UnflaggedOption("file1", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, true, false, "The second score file.")
                });

        final JSAPResult jsapResult = jsap.parse(arg);
        if (jsap.messagePrinted()) {
            System.exit(1);
        }

        final String file0 = jsapResult.getString("file0");
        final String file1 = jsapResult.getString("file1");
        final boolean reverse = jsapResult.userSpecified("reverse");
        final boolean logarithmic = jsapResult.userSpecified("logarithmic");
        final boolean quadratic = jsapResult.userSpecified("quadratic");
        final boolean multiplicative = jsapResult.userSpecified("multiplicative");
        if (logarithmic && quadratic) {
            throw new IllegalArgumentException("You cannot specify logarithmic and quadratic weighting at the same time");
        }

        final Class<?>[] inputType = parseInputTypes(jsapResult);
        int[] digits = jsapResult.getIntArray("digits");
        if (digits.length == 0) {
            // No truncation requested: a single, effectively untruncated computation.
            digits = new int[] { Integer.MAX_VALUE };
        }

        final Int2DoubleFunction chosenWeigher = logarithmic ? LOGARITHMIC_WEIGHER : quadratic ? QUADRATIC_WEIGHER : HYPERBOLIC_WEIGHER;
        final WeightedTau weightedTau = new WeightedTau(chosenWeigher, multiplicative);

        if (digits.length == 1) {
            System.out.println(weightedTau.compute(file0, inputType[0], file1, inputType[1], reverse, digits[0]));
            return;
        }

        // Several truncations: load once, and truncate a fresh copy for each computation.
        final double[] v0 = loadAsDoubles(file0, inputType[0], reverse);
        final double[] v1 = loadAsDoubles(file1, inputType[1], reverse);
        for (final int d : digits) {
            System.out.println(weightedTau.compute(Precision.truncate(v0.clone(), d), Precision.truncate(v1.clone(), d)));
        }
    }
}
