package org.apache.mahout.benchmark;

import java.io.IOException;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.option.DefaultOption;
import org.apache.commons.lang3.StringUtils;
import org.apache.mahout.benchmark.BenchmarkRunner;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.TimingStatistics;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.distance.ChebyshevDistanceMeasure;
import org.apache.mahout.common.distance.CosineDistanceMeasure;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
import org.apache.mahout.common.distance.MinkowskiDistanceMeasure;
import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
import org.apache.mahout.common.distance.TanimotoDistanceMeasure;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/mahout/benchmark/VectorBenchmarks.class */
/**
 * Command-line driver that times vector creation and, through the sibling benchmark classes
 * invoked by {@link #runBenchmark(VectorBenchmarks)}, clone, dot, plus, minus, times,
 * serialization, distance and closest-centroid operations.
 *
 * <p>Every operation is measured for three {@link Vector} implementations: {@link DenseVector},
 * {@link RandomAccessSparseVector} and {@link SequentialAccessSparseVector}. Results are
 * accumulated through {@code printStats} and can be rendered as a padded table
 * ({@link #toString()}) or as CSV (logged by {@link #main(String[])}).
 */
public class VectorBenchmarks {
    private static final int MAX_TIME_MS = 5000;
    private static final int LEAD_TIME_MS = 15000;
    /** Width, in characters, of every column of the table produced by {@link #toString()}. */
    private static final int COLUMN_WIDTH = 24;
    /**
     * Position, inside a stored statistics row, of the field the CSV output is taken from.
     * NOTE(review): presumably the "Speed" field appended by {@code printStats}; this assumes
     * {@code TimingStatistics.toString()} yields seven tab/newline separated fields - confirm.
     */
    private static final int CSV_FIELD = 7;
    private static final int DEFAULT_CARDINALITY = 1000000;
    private static final int DEFAULT_NUM_NON_ZEROS = 1000;
    private static final int DEFAULT_NUM_VECTORS = 25;
    private static final int DEFAULT_NUM_CLUSTERS = 0;
    private static final int DEFAULT_OPS_PER_UNIT = 10;
    public static final String CLUSTERS = "Clusters";
    public static final String CREATE_INCREMENTALLY = "Create (incrementally)";
    public static final String CREATE_COPY = "Create (copy)";
    public static final String DENSE_FN_SEQ = "Dense.fn(Seq)";
    public static final String RAND_FN_DENSE = "Rand.fn(Dense)";
    public static final String SEQ_FN_RAND = "Seq.fn(Rand)";
    public static final String RAND_FN_SEQ = "Rand.fn(Seq)";
    public static final String SEQ_FN_DENSE = "Seq.fn(Dense)";
    public static final String DENSE_FN_RAND = "Dense.fn(Rand)";
    public static final String SEQ_SPARSE_VECTOR = "SeqSparseVector";
    public static final String RAND_SPARSE_VECTOR = "RandSparseVector";
    public static final String DENSE_VECTOR = "DenseVector";
    private static final Logger log = LoggerFactory.getLogger(VectorBenchmarks.class);
    private static final Pattern TAB_NEWLINE_PATTERN = Pattern.compile("[\n\t]");
    /** Splits a {@code "name = value /unit"} field so that element 1 is the value. */
    private static final Pattern CSV_VALUE_DELIMITER = Pattern.compile("=|/");
    /** Placeholder row for implementations that have no result for a given benchmark. */
    private static final String[] EMPTY = new String[0];
    private static final DecimalFormat DF = new DecimalFormat("#.##");
    /** Row 0 holds dense, row 1 random-access sparse, row 2 sequential-access sparse vectors. */
    final Vector[][] vectors;
    final Vector[] clusters;
    final int cardinality;
    final int numNonZeros;
    final int numVectors;
    final int numClusters;
    final int opsPerUnit;
    /** Upper bound on the number of iterations of a timed loop. */
    final int loop = Integer.MAX_VALUE;
    private final List<Vector> randomVectors = new ArrayList<Vector>();
    private final List<int[]> randomVectorIndices = new ArrayList<int[]>();
    private final List<double[]> randomVectorValues = new ArrayList<double[]>();
    /** Implementation name mapped to its column index, assigned in order of first appearance. */
    private final Map<String, Integer> implType = new HashMap<String, Integer>();
    /** Benchmark name mapped to one statistics row per implementation column index. */
    private final Map<String, List<String[]>> statsMap = new HashMap<String, List<String[]>>();
    private final Random r = RandomUtils.getRandom();
    private final BenchmarkRunner runner = new BenchmarkRunner(LEAD_TIME_MS, MAX_TIME_MS);
    /** {@link #MAX_TIME_MS} converted to nanoseconds (the "Usec" in the name is misleading). */
    final long maxTimeUsec = TimeUnit.MILLISECONDS.toNanos(MAX_TIME_MS);
    /** {@link #LEAD_TIME_MS} converted to nanoseconds (the "Usec" in the name is misleading). */
    final long leadTimeUsec = TimeUnit.MILLISECONDS.toNanos(LEAD_TIME_MS);

    /**
     * Creates the benchmark state and generates the random source vectors.
     *
     * @param cardinality number of dimensions of every vector
     * @param numNonZeros number of non-zero entries per vector; must lie in {@code [0, cardinality]}
     * @param numVectors  number of vectors to generate; must be positive
     * @param numClusters number of cluster vectors; {@code 0} disables the cluster benchmarks
     * @param opsPerUnit  stored for the sibling benchmark classes; not used in this class
     * @throws IllegalArgumentException if {@code numNonZeros} or {@code numVectors} is out of range
     */
    public VectorBenchmarks(int cardinality, int numNonZeros, int numVectors, int numClusters, int opsPerUnit) {
        // The generator draws distinct indices until numNonZeros are found, so asking for more
        // non-zeros than there are dimensions would make it spin forever.
        if (numNonZeros < 0 || numNonZeros > cardinality) {
            throw new IllegalArgumentException(
                "numNonZeros must be between 0 and cardinality (" + cardinality + ") but was " + numNonZeros);
        }
        // vIndex() takes indices modulo numVectors, so zero vectors means division by zero.
        if (numVectors <= 0) {
            throw new IllegalArgumentException("numVectors must be positive but was " + numVectors);
        }
        this.cardinality = cardinality;
        this.numNonZeros = numNonZeros;
        this.numVectors = numVectors;
        this.numClusters = numClusters;
        this.opsPerUnit = opsPerUnit;
        setUpVectors(cardinality, numNonZeros, numVectors);
        this.vectors = new Vector[3][numVectors];
        this.clusters = new Vector[numClusters];
    }

    /**
     * Fills {@link #randomVectors}, {@link #randomVectorIndices} and {@link #randomVectorValues}
     * with {@code numVectors} sparse vectors that each have exactly {@code numNonZeros} distinct,
     * non-zero Gaussian entries.
     */
    private void setUpVectors(int cardinality, int numNonZeros, int numVectors) {
        for (int v = 0; v < numVectors; v++) {
            Vector vector = new SequentialAccessSparseVector(cardinality, numNonZeros);
            BitSet usedIndices = new BitSet(cardinality);
            int[] indices = new int[numNonZeros];
            double[] values = new double[numNonZeros];
            int filled = 0;
            while (filled < numNonZeros) {
                // Value first, index second: keeps the order in which the generator is consumed.
                double value = this.r.nextGaussian();
                int index = this.r.nextInt(cardinality);
                // Reject duplicates and exact zeros so every slot is a real non-zero entry.
                if (!usedIndices.get(index) && value != 0.0d) {
                    usedIndices.set(index);
                    indices[filled] = index;
                    values[filled] = value;
                    filled++;
                    vector.set(index, value);
                }
            }
            this.randomVectorIndices.add(indices);
            this.randomVectorValues.add(values);
            this.randomVectors.add(vector);
        }
    }

    /**
     * Logs and records the statistics of one benchmark run with a multiplier of 1.
     *
     * @param timingStatistics timings to report
     * @param str              benchmark name (row key)
     * @param str2             implementation name (column key)
     * @param str3             free-form text included in the log message only
     */
    public void printStats(TimingStatistics timingStatistics, String str, String str2, String str3) {
        printStats(timingStatistics, str, str2, str3, 1);
    }

    /**
     * Logs and records the statistics of one benchmark run with no extra text and a multiplier of 1.
     *
     * @param timingStatistics timings to report
     * @param str              benchmark name (row key)
     * @param str2             implementation name (column key)
     */
    public void printStats(TimingStatistics timingStatistics, String str, String str2) {
        printStats(timingStatistics, str, str2, "", 1);
    }

    /**
     * Logs the timings and stores them in {@link #statsMap} under the benchmark name, at the
     * column index assigned to the implementation name.
     *
     * <p>NOTE(review): the rate formula multiplies by 12 and 1000 and divides by the summed
     * time; presumably 12 is the byte size of one sparse entry (4-byte index + 8-byte value)
     * and the summed time is in nanoseconds - confirm against {@code TimingStatistics}.
     */
    private void printStats(TimingStatistics stats, String benchmarkName, String implName, String content,
                            int multiplier) {
        float rate = multiplier * stats.getNCalls()
            * (((this.numNonZeros * 1000.0f) * 12.0f) / ((float) stats.getSumTime()));
        float speed = (stats.getNCalls() * 1.0E9f) / ((float) stats.getSumTime());
        log.info("{} {} \n{} {} \nOps    = {} Units/sec\nIOps   = {} MBytes/sec",
            new Object[]{benchmarkName, implName, content, stats.toString(), DF.format(speed), DF.format(rate)});

        // Column indices are handed out in order of first appearance.
        Integer column = this.implType.get(implName);
        if (column == null) {
            column = Integer.valueOf(this.implType.size());
            this.implType.put(implName, column);
        }
        List<String[]> rows = this.statsMap.get(benchmarkName);
        if (rows == null) {
            rows = new ArrayList<String[]>();
            this.statsMap.put(benchmarkName, rows);
        }
        // Pad with placeholders so the row can be stored at its implementation's column index.
        while (rows.size() < column.intValue() + 1) {
            rows.add(EMPTY);
        }
        rows.set(column.intValue(), TAB_NEWLINE_PATTERN.split(
            stats + "\tSpeed  = " + DF.format(speed) + " /sec\tRate   = " + DF.format(rate) + " MB/s"));
    }

    /**
     * Populates {@link #vectors} (all three implementations) and, when clusters are enabled,
     * {@link #clusters} with copies of the random source vectors.
     */
    public void createData() {
        int count = Math.max(this.numVectors, this.numClusters);
        for (int i = 0; i < count; i++) {
            int index = vIndex(i);
            Vector source = this.randomVectors.get(index);
            this.vectors[0][index] = new DenseVector(source);
            this.vectors[1][index] = new RandomAccessSparseVector(source);
            this.vectors[2][index] = new SequentialAccessSparseVector(source);
            if (this.numClusters > 0) {
                this.clusters[cIndex(i)] = new RandomAccessSparseVector(source);
            }
        }
    }

    /**
     * Times copy-construction of each vector implementation (and of the cluster vectors when
     * clusters are enabled) from the random source vectors and records the results.
     */
    public void createBenchmark() {
        printStats(this.runner.benchmark(new BenchmarkRunner.BenchmarkFn() {
            public Boolean apply(Integer num) {
                int index = VectorBenchmarks.this.vIndex(num.intValue());
                Vector copy = new DenseVector(VectorBenchmarks.this.randomVectors.get(index));
                VectorBenchmarks.this.vectors[0][index] = copy;
                return Boolean.valueOf(depends(copy));
            }
        }), CREATE_COPY, DENSE_VECTOR);
        printStats(this.runner.benchmark(new BenchmarkRunner.BenchmarkFn() {
            public Boolean apply(Integer num) {
                int index = VectorBenchmarks.this.vIndex(num.intValue());
                Vector copy = new RandomAccessSparseVector(VectorBenchmarks.this.randomVectors.get(index));
                VectorBenchmarks.this.vectors[1][index] = copy;
                return Boolean.valueOf(depends(copy));
            }
        }), CREATE_COPY, RAND_SPARSE_VECTOR);
        printStats(this.runner.benchmark(new BenchmarkRunner.BenchmarkFn() {
            public Boolean apply(Integer num) {
                int index = VectorBenchmarks.this.vIndex(num.intValue());
                Vector copy = new SequentialAccessSparseVector(VectorBenchmarks.this.randomVectors.get(index));
                VectorBenchmarks.this.vectors[2][index] = copy;
                return Boolean.valueOf(depends(copy));
            }
        }), CREATE_COPY, SEQ_SPARSE_VECTOR);
        if (this.numClusters > 0) {
            printStats(this.runner.benchmark(new BenchmarkRunner.BenchmarkFn() {
                public Boolean apply(Integer num) {
                    int clusterIndex = VectorBenchmarks.this.cIndex(num.intValue());
                    int vectorIndex = VectorBenchmarks.this.vIndex(num.intValue());
                    Vector copy = new RandomAccessSparseVector(VectorBenchmarks.this.randomVectors.get(vectorIndex));
                    VectorBenchmarks.this.clusters[clusterIndex] = copy;
                    return Boolean.valueOf(depends(copy));
                }
            }), CREATE_COPY, CLUSTERS);
        }
    }

    /**
     * Times writing the entries of random source vector {@code randomIndex} into {@code vector}
     * one by one, in shuffled order.
     *
     * @param timingStatistics accumulator the timed call is registered with
     * @param randomIndex      index of the source vector whose entries are written
     * @param vector           target vector that receives the entries
     * @param useSetQuick      {@code true} to write with {@code setQuick}, {@code false} for {@code set}
     * @return the result of ending the timed call; callers stop iterating when it is {@code true}
     */
    private boolean buildVectorIncrementally(TimingStatistics timingStatistics, int randomIndex, Vector vector,
                                             boolean useSetQuick) {
        int[] indices = this.randomVectorIndices.get(randomIndex);
        double[] values = this.randomVectorValues.get(randomIndex);
        List<Integer> shuffled = new ArrayList<Integer>(indices.length);
        for (int i = 0; i < indices.length; i++) {
            shuffled.add(Integer.valueOf(i));
        }
        Collections.shuffle(shuffled);
        // Unbox before the clock starts so only the vector writes are measured.
        int[] order = new int[shuffled.size()];
        for (int i = 0; i < order.length; i++) {
            order[i] = shuffled.get(i).intValue();
        }
        TimingStatistics.Call call = timingStatistics.newCall(this.leadTimeUsec);
        if (useSetQuick) {
            for (int position : order) {
                vector.setQuick(indices[position], values[position]);
            }
        } else {
            for (int position : order) {
                vector.set(indices[position], values[position]);
            }
        }
        return call.end(this.maxTimeUsec);
    }

    /**
     * Times entry-by-entry construction of each vector implementation (and of the cluster
     * vectors when clusters are enabled) and records the results.
     */
    public void incrementalCreateBenchmark() {
        TimingStatistics denseStats = new TimingStatistics();
        for (int i = 0; i < this.loop; i++) {
            int index = vIndex(i);
            this.vectors[0][index] = new DenseVector(this.cardinality);
            if (buildVectorIncrementally(denseStats, index, this.vectors[0][index], false)) {
                break;
            }
        }
        printStats(denseStats, CREATE_INCREMENTALLY, DENSE_VECTOR);

        TimingStatistics randStats = new TimingStatistics();
        for (int i = 0; i < this.loop; i++) {
            int index = vIndex(i);
            this.vectors[1][index] = new RandomAccessSparseVector(this.cardinality);
            if (buildVectorIncrementally(randStats, index, this.vectors[1][index], false)) {
                break;
            }
        }
        printStats(randStats, CREATE_INCREMENTALLY, RAND_SPARSE_VECTOR);

        TimingStatistics seqStats = new TimingStatistics();
        for (int i = 0; i < this.loop; i++) {
            int index = vIndex(i);
            this.vectors[2][index] = new SequentialAccessSparseVector(this.cardinality);
            if (buildVectorIncrementally(seqStats, index, this.vectors[2][index], false)) {
                break;
            }
        }
        printStats(seqStats, CREATE_INCREMENTALLY, SEQ_SPARSE_VECTOR);

        if (this.numClusters > 0) {
            TimingStatistics clusterStats = new TimingStatistics();
            for (int i = 0; i < this.loop; i++) {
                int clusterIndex = cIndex(i);
                this.clusters[clusterIndex] = new RandomAccessSparseVector(this.cardinality);
                if (buildVectorIncrementally(clusterStats, vIndex(i), this.clusters[clusterIndex], false)) {
                    break;
                }
            }
            printStats(clusterStats, CREATE_INCREMENTALLY, CLUSTERS);
        }
    }

    /**
     * Maps an iteration counter onto a vector slot.
     *
     * @param i iteration counter
     * @return {@code i} modulo the number of vectors
     */
    public int vIndex(int i) {
        return i % this.numVectors;
    }

    /**
     * Maps an iteration counter onto a cluster slot. Must only be called when the number of
     * clusters is non-zero, otherwise the modulo divides by zero.
     *
     * @param i iteration counter
     * @return {@code i} modulo the number of clusters
     */
    public int cIndex(int i) {
        return i % this.numClusters;
    }

    /**
     * Parses the command line, runs every benchmark and logs the results as CSV. Prints the
     * usage help instead when help is requested or the options cannot be parsed.
     *
     * @param strArr command-line arguments
     * @throws IOException declared by the benchmarks that are run
     */
    public static void main(String[] strArr) throws IOException {
        DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
        ArgumentBuilder argumentBuilder = new ArgumentBuilder();
        GroupBuilder groupBuilder = new GroupBuilder();
        Option vectorSizeOpt = optionBuilder.withLongName("vectorSize").withRequired(false)
            .withArgument(argumentBuilder.withName("vs").withDefault(DEFAULT_CARDINALITY).create())
            .withDescription("Cardinality of the vector. Default: 1000000").withShortName("vs").create();
        Option numNonZeroOpt = optionBuilder.withLongName("numNonZero").withRequired(false)
            .withArgument(argumentBuilder.withName("nz").withDefault(DEFAULT_NUM_NON_ZEROS).create())
            .withDescription("Size of the vector. Default: 1000").withShortName("nz").create();
        Option numVectorsOpt = optionBuilder.withLongName("numVectors").withRequired(false)
            .withArgument(argumentBuilder.withName("nv").withDefault(DEFAULT_NUM_VECTORS).create())
            .withDescription("Number of Vectors to create. Default: 25").withShortName("nv").create();
        Option numClustersOpt = optionBuilder.withLongName("numClusters").withRequired(false)
            .withArgument(argumentBuilder.withName("nc").withDefault(DEFAULT_NUM_CLUSTERS).create())
            .withDescription("Number of clusters to create. Set to non zero to run cluster benchmark. Default: 0")
            .withShortName("nc").create();
        Option numOpsOpt = optionBuilder.withLongName("numOps").withRequired(false)
            .withArgument(argumentBuilder.withName("numOps").withDefault(DEFAULT_OPS_PER_UNIT).create())
            .withDescription("Number of operations to do per timer. E.g In distance measure, the distance is calculated numOps times and the total time is measured. Default: 10")
            .withShortName("no").create();
        Option helpOpt = DefaultOptionCreator.helpOption();
        Group group = groupBuilder.withName("Options").withOption(vectorSizeOpt).withOption(numNonZeroOpt)
            .withOption(numVectorsOpt).withOption(numOpsOpt).withOption(numClustersOpt).withOption(helpOpt)
            .create();
        try {
            Parser parser = new Parser();
            parser.setGroup(group);
            CommandLine cmdLine = parser.parse(strArr);
            if (cmdLine.hasOption(helpOpt)) {
                CommandLineUtil.printHelpWithGenericOptions(group);
                return;
            }
            int cardinality = intOption(cmdLine, vectorSizeOpt, DEFAULT_CARDINALITY);
            int numClusters = intOption(cmdLine, numClustersOpt, DEFAULT_NUM_CLUSTERS);
            int numNonZeros = intOption(cmdLine, numNonZeroOpt, DEFAULT_NUM_NON_ZEROS);
            int numVectors = intOption(cmdLine, numVectorsOpt, DEFAULT_NUM_VECTORS);
            int opsPerUnit = intOption(cmdLine, numOpsOpt, DEFAULT_OPS_PER_UNIT);
            VectorBenchmarks benchmarks =
                new VectorBenchmarks(cardinality, numNonZeros, numVectors, numClusters, opsPerUnit);
            runBenchmark(benchmarks);
            log.info("\n{}", benchmarks.asCsvString());
        } catch (OptionException e) {
            // Report why parsing failed instead of silently falling back to the usage text.
            log.error("Invalid command line options", e);
            CommandLineUtil.printHelp(group);
        }
    }

    /** Returns the integer value given for {@code option}, or {@code defaultValue} when it is absent. */
    private static int intOption(CommandLine cmdLine, Option option, int defaultValue) {
        if (cmdLine.hasOption(option)) {
            return Integer.parseInt((String) cmdLine.getValue(option));
        }
        return defaultValue;
    }

    /**
     * Runs the full benchmark suite against {@code vectorBenchmarks}. Incremental creation is
     * only run for cardinalities below 200000, and the closest-centroid benchmarks only when
     * clusters are enabled.
     */
    private static void runBenchmark(VectorBenchmarks vectorBenchmarks) throws IOException {
        vectorBenchmarks.createData();
        vectorBenchmarks.createBenchmark();
        if (vectorBenchmarks.cardinality < 200000) {
            vectorBenchmarks.incrementalCreateBenchmark();
        }
        new CloneBenchmark(vectorBenchmarks).benchmark();
        new DotBenchmark(vectorBenchmarks).benchmark();
        new PlusBenchmark(vectorBenchmarks).benchmark();
        new MinusBenchmark(vectorBenchmarks).benchmark();
        new TimesBenchmark(vectorBenchmarks).benchmark();
        new SerializationBenchmark(vectorBenchmarks).benchmark();
        DistanceBenchmark distanceBenchmark = new DistanceBenchmark(vectorBenchmarks);
        distanceBenchmark.benchmark(new CosineDistanceMeasure());
        distanceBenchmark.benchmark(new SquaredEuclideanDistanceMeasure());
        distanceBenchmark.benchmark(new EuclideanDistanceMeasure());
        distanceBenchmark.benchmark(new ManhattanDistanceMeasure());
        distanceBenchmark.benchmark(new TanimotoDistanceMeasure());
        distanceBenchmark.benchmark(new ChebyshevDistanceMeasure());
        distanceBenchmark.benchmark(new MinkowskiDistanceMeasure());
        if (vectorBenchmarks.numClusters > 0) {
            ClosestCentroidBenchmark closestCentroidBenchmark = new ClosestCentroidBenchmark(vectorBenchmarks);
            closestCentroidBenchmark.benchmark(new CosineDistanceMeasure());
            closestCentroidBenchmark.benchmark(new SquaredEuclideanDistanceMeasure());
            closestCentroidBenchmark.benchmark(new EuclideanDistanceMeasure());
            closestCentroidBenchmark.benchmark(new ManhattanDistanceMeasure());
            closestCentroidBenchmark.benchmark(new TanimotoDistanceMeasure());
            closestCentroidBenchmark.benchmark(new ChebyshevDistanceMeasure());
            closestCentroidBenchmark.benchmark(new MinkowskiDistanceMeasure());
        }
    }

    /** Inverts {@link #implType}: column index mapped to implementation name. */
    private Map<Integer, String> implementationNamesByIndex() {
        Map<Integer, String> names = new HashMap<Integer, String>();
        for (Map.Entry<String, Integer> entry : this.implType.entrySet()) {
            names.put(entry.getValue(), entry.getKey());
        }
        return names;
    }

    /**
     * Renders one {@code benchmark,implementation,value} line per recorded result, with
     * benchmarks sorted by name, followed by a blank line.
     */
    private String asCsvString() {
        List<String> benchmarkNames = new ArrayList<String>(this.statsMap.keySet());
        Collections.sort(benchmarkNames);
        Map<Integer, String> implNames = implementationNamesByIndex();
        StringBuilder csv = new StringBuilder(1000);
        for (String benchmarkName : benchmarkNames) {
            List<String[]> rows = this.statsMap.get(benchmarkName);
            // A row's position in the list IS its implementation index. Placeholder rows must
            // still advance the index, otherwise later rows get the wrong implementation label.
            for (int implIndex = 0; implIndex < rows.size(); implIndex++) {
                String[] row = rows.get(implIndex);
                if (row.length <= CSV_FIELD) {
                    continue;
                }
                csv.append(benchmarkName).append(',');
                csv.append(implNames.get(Integer.valueOf(implIndex))).append(',');
                csv.append(CSV_VALUE_DELIMITER.split(row[CSV_FIELD].trim())[1].trim());
                csv.append('\n');
            }
        }
        csv.append('\n');
        return csv.toString();
    }

    /**
     * Renders the recorded results as a table of fixed-width columns: a header line naming the
     * implementations, then for each benchmark (sorted by name) one line per statistics field.
     */
    @Override
    public String toString() {
        StringBuilder table = new StringBuilder(1000);
        table.append(StringUtils.rightPad("BenchMarks", COLUMN_WIDTH));
        Map<Integer, String> implNames = implementationNamesByIndex();
        for (int implIndex = 0; implIndex < this.implType.size(); implIndex++) {
            String implName = implNames.get(Integer.valueOf(implIndex));
            if (implName != null) {
                // Header cells are truncated as well as padded so long names stay in their column.
                table.append(StringUtils.rightPad(implName, COLUMN_WIDTH).substring(0, COLUMN_WIDTH));
            }
        }
        table.append('\n');
        List<String> benchmarkNames = new ArrayList<String>(this.statsMap.keySet());
        Collections.sort(benchmarkNames);
        for (String benchmarkName : benchmarkNames) {
            List<String[]> columns = this.statsMap.get(benchmarkName);
            int maxFields = 0;
            for (String[] column : columns) {
                maxFields = Math.max(maxFields, column.length);
            }
            for (int field = 0; field < maxFields; field++) {
                // Only the first line of a benchmark carries its name; the rest are indented.
                table.append(StringUtils.rightPad(field == 0 ? benchmarkName : "", COLUMN_WIDTH));
                for (String[] column : columns) {
                    String cell = field < column.length ? column[field] : "";
                    table.append(StringUtils.rightPad(cell, COLUMN_WIDTH));
                }
                table.append('\n');
            }
            table.append('\n');
        }
        return table.toString();
    }

    /** Returns the runner used to time the benchmark functions. */
    public BenchmarkRunner getRunner() {
        return this.runner;
    }
}
