package net.myrrix.online.factorizer.als;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import net.myrrix.common.LangUtils;
import net.myrrix.common.collection.FastByIDFloatMap;
import net.myrrix.common.collection.FastByIDMap;
import net.myrrix.common.math.MatrixUtils;
import net.myrrix.common.math.SimpleVectorMath;
import net.myrrix.common.parallel.ExecutorUtils;
import net.myrrix.common.random.RandomManager;
import net.myrrix.common.random.RandomUtils;
import net.myrrix.common.stats.JVMEnvironment;
import net.myrrix.online.factorizer.MatrixFactorizer;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.FastMath;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.common.WeightedRunningAverage;
import org.apache.mahout.common.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:net/myrrix/online/factorizer/als/AlternatingLeastSquares.class */
/**
 * Factorizes a sparse matrix R, supplied both by row and by column, into two sets of
 * feature vectors X (keyed by row ID) and Y (keyed by column ID) by alternately
 * recomputing X from Y and Y from X.
 *
 * <p>Behavior is tuned by these system properties, all read in this class:
 * {@code model.threads}, {@code model.als.iterate}, {@code model.als.alpha},
 * {@code model.als.lambda}, {@code model.reconstructRMatrix} and
 * {@code model.lossIgnoresUnspecified}.</p>
 */
public final class AlternatingLeastSquares implements MatrixFactorizer {

    private static final Logger log = LoggerFactory.getLogger(AlternatingLeastSquares.class);

    /** Alpha used when system property {@code model.als.alpha} is not set. */
    public static final double DEFAULT_ALPHA = 1.0d;
    /** Lambda used when system property {@code model.als.lambda} is not set. */
    public static final double DEFAULT_LAMBDA = 0.1d;
    /** Convergence threshold used by the constructors that do not take one. */
    public static final double DEFAULT_CONVERGENCE_THRESHOLD = 0.001d;
    /** Iteration limit used by the constructors that do not take one. */
    public static final int DEFAULT_MAX_ITERATIONS = 30;

    /** Feature count used by the two-argument constructor. */
    private static final int DEFAULT_NUM_FEATURES = 30;
    /** Number of rows of R handed to a single {@link Worker}. */
    private static final int WORK_UNIT_SIZE = 100;
    /** Approximate number of row IDs and of column IDs sampled to measure convergence. */
    private static final int NUM_USER_ITEMS_TO_TEST_CONVERGENCE = 100;
    /** Progress is logged roughly every this many rows. */
    private static final long LOG_INTERVAL = 100000;
    /** Cap on how many existing vectors are passed to {@code RandomUtils.randomUnitVectorFarFrom}. */
    private static final int MAX_FAR_FROM_VECTORS = 100000;

    private static final boolean RECONSTRUCT_R_MATRIX =
        Boolean.parseBoolean(System.getProperty("model.reconstructRMatrix", "false"));
    private static final boolean LOSS_IGNORES_UNSPECIFIED =
        Boolean.parseBoolean(System.getProperty("model.lossIgnoresUnspecified", "false"));

    private final FastByIDMap<FastByIDFloatMap> RbyRow;
    private final FastByIDMap<FastByIDFloatMap> RbyColumn;
    private final int features;
    private final double estimateErrorConvergenceThreshold;
    private final int maxIterations;
    private FastByIDMap<float[]> X;
    private FastByIDMap<float[]> Y;
    private FastByIDMap<float[]> previousY;

    /**
     * Computes, for every (ID, sparse row) pair in one work unit, a new feature vector
     * by solving a {@code features x features} linear system, and stores it in the
     * output map.
     */
    public static final class Worker implements Callable<Void> {
        private final int features;
        private final FastByIDMap<float[]> Y;
        private final RealMatrix YTY;
        private final FastByIDMap<float[]> X;
        private final List<Pair<Long, FastByIDFloatMap>> workUnit;

        /**
         * @param features number of features in each vector
         * @param Y vectors that are held fixed and read from
         * @param YTY precomputed product of Y's transpose with Y
         * @param X output map that solved vectors are written into
         * @param workUnit the (ID, sparse row of R) pairs to solve for
         */
        private Worker(int features,
                       FastByIDMap<float[]> Y,
                       RealMatrix YTY,
                       FastByIDMap<float[]> X,
                       List<Pair<Long, FastByIDFloatMap>> workUnit) {
            this.features = features;
            this.Y = Y;
            this.YTY = YTY;
            this.X = X;
            this.workUnit = workUnit;
        }

        /**
         * Solves for each entry of the work unit in turn.
         *
         * @return always {@code null}
         */
        @Override
        public Void call() {
            double alpha = getAlpha();
            double lambda = getLambda() * alpha;
            int numFeatures = this.features;

            for (Pair<Long, FastByIDFloatMap> work : workUnit) {
                FastByIDFloatMap ru = work.getSecond();

                // Start from either the product over only the keys present in this row,
                // or a private copy of the full precomputed product.
                RealMatrix Wu = LOSS_IGNORES_UNSPECIFIED
                    ? partialTransposeTimesSelf(Y, YTY.getRowDimension(), ru.keySetIterator())
                    : YTY.copy();
                double[][] WuData = MatrixUtils.accessMatrixDataDirectly(Wu);
                double[] YTCupu = new double[numFeatures];

                for (FastByIDFloatMap.MapEntry entry : ru.entrySet()) {
                    double xu = entry.getValue();
                    float[] vector = Y.get(entry.getKey());
                    if (vector == null) {
                        log.warn("No vector for {}. This should not happen. Continuing...", entry.getKey());
                        continue;
                    }
                    if (RECONSTRUCT_R_MATRIX) {
                        // Only the right-hand side is adjusted in this mode.
                        for (int row = 0; row < numFeatures; row++) {
                            YTCupu[row] += xu * vector[row];
                        }
                    } else {
                        double cu = 1.0d + (alpha * FastMath.abs(xu));
                        for (int row = 0; row < numFeatures; row++) {
                            float vectorAtRow = vector[row];
                            double rowValue = vectorAtRow * (cu - 1.0d);
                            double[] WuDataRow = WuData[row];
                            for (int col = 0; col < numFeatures; col++) {
                                WuDataRow[col] += rowValue * vector[col];
                            }
                            // Only positive values contribute to the right-hand side.
                            if (xu > 0.0d) {
                                YTCupu[row] += vectorAtRow * cu;
                            }
                        }
                    }
                }

                // Add lambda, scaled by the number of entries in this row, to the diagonal.
                double lambdaTimesCount = lambda * ru.size();
                for (int diag = 0; diag < numFeatures; diag++) {
                    WuData[diag][diag] += lambdaTimesCount;
                }

                float[] xu = MatrixUtils.getSolver(Wu).solveDToF(YTCupu);

                // The output map is shared by all workers of an iteration.
                synchronized (X) {
                    X.put(work.getFirst(), xu);
                }
            }
            return null;
        }

        /** @return value of system property {@code model.als.alpha}, or {@link #DEFAULT_ALPHA} if unset */
        private static double getAlpha() {
            String alphaProperty = System.getProperty("model.als.alpha");
            return alphaProperty == null ? DEFAULT_ALPHA : LangUtils.parseDouble(alphaProperty);
        }

        /** @return value of system property {@code model.als.lambda}, or {@link #DEFAULT_LAMBDA} if unset */
        private static double getLambda() {
            String lambdaProperty = System.getProperty("model.als.lambda");
            return lambdaProperty == null ? DEFAULT_LAMBDA : LangUtils.parseDouble(lambdaProperty);
        }

        /**
         * Sums the outer product of each selected vector with itself.
         *
         * @param M map of vectors
         * @param dimension number of leading elements of each vector to use; also the size of the result
         * @param keys IDs of the vectors in {@code M} to include
         * @return a {@code dimension x dimension} matrix
         */
        private static RealMatrix partialTransposeTimesSelf(FastByIDMap<float[]> M,
                                                            int dimension,
                                                            LongPrimitiveIterator keys) {
            RealMatrix result = new Array2DRowRealMatrix(dimension, dimension);
            while (keys.hasNext()) {
                // nextLong() avoids boxing every key.
                float[] vector = M.get(keys.nextLong());
                if (vector == null) {
                    // call() warns about and skips the same missing vector; do not fail here first.
                    continue;
                }
                for (int row = 0; row < dimension; row++) {
                    float rowValue = vector[row];
                    for (int col = 0; col < dimension; col++) {
                        result.addToEntry(row, col, rowValue * vector[col]);
                    }
                }
            }
            return result;
        }
    }

    /**
     * Uses 30 features, {@link #DEFAULT_CONVERGENCE_THRESHOLD} and {@link #DEFAULT_MAX_ITERATIONS}.
     *
     * @param RbyRow the input matrix, keyed by row ID
     * @param RbyColumn the same matrix, keyed by column ID
     */
    public AlternatingLeastSquares(FastByIDMap<FastByIDFloatMap> RbyRow,
                                   FastByIDMap<FastByIDFloatMap> RbyColumn) {
        this(RbyRow, RbyColumn, DEFAULT_NUM_FEATURES, DEFAULT_CONVERGENCE_THRESHOLD, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Uses {@link #DEFAULT_CONVERGENCE_THRESHOLD} and {@link #DEFAULT_MAX_ITERATIONS}.
     *
     * @param RbyRow the input matrix, keyed by row ID
     * @param RbyColumn the same matrix, keyed by column ID
     * @param features number of features per vector; must be positive
     */
    public AlternatingLeastSquares(FastByIDMap<FastByIDFloatMap> RbyRow,
                                   FastByIDMap<FastByIDFloatMap> RbyColumn,
                                   int features) {
        this(RbyRow, RbyColumn, features, DEFAULT_CONVERGENCE_THRESHOLD, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * @param RbyRow the input matrix, keyed by row ID; must not be null
     * @param RbyColumn the same matrix, keyed by column ID; must not be null
     * @param features number of features per vector; must be positive
     * @param estimateErrorConvergenceThreshold iteration stops once the sampled average change in
     *  estimates falls below this value; must be in (0,1)
     * @param maxIterations iteration limit; a value that is not positive disables the limit
     * @throws NullPointerException if either matrix is null
     * @throws IllegalArgumentException if {@code features} or the threshold is out of range
     */
    public AlternatingLeastSquares(FastByIDMap<FastByIDFloatMap> RbyRow,
                                   FastByIDMap<FastByIDFloatMap> RbyColumn,
                                   int features,
                                   double estimateErrorConvergenceThreshold,
                                   int maxIterations) {
        Preconditions.checkNotNull(RbyRow);
        Preconditions.checkNotNull(RbyColumn);
        Preconditions.checkArgument(features > 0, "features must be positive: %s", features);
        Preconditions.checkArgument(
            estimateErrorConvergenceThreshold > 0.0d && estimateErrorConvergenceThreshold < 1.0d,
            "threshold must be in (0,1): %s", estimateErrorConvergenceThreshold);
        this.RbyRow = RbyRow;
        this.RbyColumn = RbyColumn;
        this.features = features;
        this.estimateErrorConvergenceThreshold = estimateErrorConvergenceThreshold;
        this.maxIterations = maxIterations;
    }

    /** @return the row-side vectors; {@code null} until {@link #call()} has started */
    @Override
    public FastByIDMap<float[]> getX() {
        return X;
    }

    /** @return the column-side vectors; {@code null} until {@link #call()} has started */
    @Override
    public FastByIDMap<float[]> getY() {
        return Y;
    }

    /** Does nothing; this implementation does not use a previous X. */
    @Override
    public void setPreviousX(FastByIDMap<float[]> previousX) {
    }

    /**
     * @param previousY Y vectors from an earlier run, used to seed the initial Y. When its vectors
     *  already have the configured feature count, this same map instance is reused and added to.
     */
    @Override
    public void setPreviousY(FastByIDMap<float[]> previousY) {
        this.previousY = previousY;
    }

    /**
     * Runs the factorization. Creates a thread pool sized by system property
     * {@code model.threads} (default: available processors). If {@code model.als.iterate} is
     * {@code false}, computes X once from the initial Y and returns; otherwise alternates
     * until the iteration limit is reached, the convergence measure is not finite, or it falls
     * below the configured threshold.
     *
     * @return always {@code null}
     * @throws ExecutionException if a worker task fails
     * @throws InterruptedException if interrupted while waiting on worker tasks
     */
    @Override
    public Void call() throws ExecutionException, InterruptedException {
        X = new FastByIDMap<>(RbyRow.size(), 1.25f);

        boolean randomY = previousY == null || previousY.isEmpty();
        Y = constructInitialY(previousY);

        String threadsString = System.getProperty("model.threads");
        int numThreads = threadsString == null
            ? Runtime.getRuntime().availableProcessors()
            : Integer.parseInt(threadsString);
        ExecutorService executor = Executors.newFixedThreadPool(
            numThreads,
            new ThreadFactoryBuilder().setNameFormat("ALS-%d").setDaemon(true).build());

        log.info("Iterating using {} threads", numThreads);

        // Everything that uses the executor runs inside try/finally so the pool is
        // always shut down, including when a worker throws.
        try {
            if (!Boolean.parseBoolean(System.getProperty("model.als.iterate", "true"))) {
                iterateXFromY(executor);
                return null;
            }

            RandomGenerator random = RandomManager.getRandom();
            long[] testUserIDs = RandomUtils.chooseAboutNFromStream(
                NUM_USER_ITEMS_TO_TEST_CONVERGENCE, RbyRow.keySetIterator(), RbyRow.size(), random);
            long[] testItemIDs = RandomUtils.chooseAboutNFromStream(
                NUM_USER_ITEMS_TO_TEST_CONVERGENCE, RbyColumn.keySetIterator(), RbyColumn.size(), random);
            double[][] estimates = new double[testUserIDs.length][testItemIDs.length];

            // NOTE(review): X is freshly created above, so it appears to always be empty here and
            // the estimates start at zero; guard retained from the original logic.
            if (!X.isEmpty()) {
                for (int i = 0; i < testUserIDs.length; i++) {
                    for (int j = 0; j < testItemIDs.length; j++) {
                        estimates[i][j] = SimpleVectorMath.dot(X.get(testUserIDs[i]), Y.get(testItemIDs[j]));
                    }
                }
            }

            int iterationNumber = 0;
            while (true) {
                iterateXFromY(executor);
                iterateYFromX(executor);

                // Measure how much the sampled estimates moved since the previous iteration.
                WeightedRunningAverage avgDiff = new WeightedRunningAverage();
                for (int i = 0; i < testUserIDs.length; i++) {
                    for (int j = 0; j < testItemIDs.length; j++) {
                        double newValue = SimpleVectorMath.dot(X.get(testUserIDs[i]), Y.get(testItemIDs[j]));
                        double oldValue = estimates[i][j];
                        estimates[i][j] = newValue;
                        avgDiff.addDatum(FastMath.abs(newValue - oldValue), FastMath.max(0.0d, newValue));
                    }
                }

                iterationNumber++;
                log.info("Finished iteration {}", iterationNumber);
                if (maxIterations > 0 && iterationNumber >= maxIterations) {
                    log.info("Reached iteration limit");
                    break;
                }

                log.info("Avg absolute difference in estimate vs prior iteration: {}", avgDiff);
                double convergenceValue = avgDiff.getAverage();
                if (!LangUtils.isFinite(convergenceValue)) {
                    log.warn("Invalid convergence value, aborting iteration! {}", convergenceValue);
                    break;
                }
                // The convergence test is skipped on the first iteration when Y started out random.
                boolean firstIterationFromRandom = randomY && iterationNumber == 1;
                if (!firstIterationFromRandom && convergenceValue < estimateErrorConvergenceThreshold) {
                    log.info("Converged");
                    break;
                }
            }
        } finally {
            ExecutorUtils.shutdownNowAndAwait(executor);
        }
        return null;
    }

    /**
     * Builds the starting Y. With no previous Y, starts from an empty map. Otherwise previous
     * vectors are truncated (when the feature count decreased), extended with Gaussian values
     * (when it increased), or reused as-is. Every column ID of R without a vector then gets one
     * from {@code RandomUtils.randomUnitVectorFarFrom}.
     *
     * @param previousY Y from an earlier run; may be null or empty
     * @return the initial Y; this is {@code previousY} itself when its feature count already matches
     */
    private FastByIDMap<float[]> constructInitialY(FastByIDMap<float[]> previousY) {
        RandomGenerator random = RandomManager.getRandom();

        FastByIDMap<float[]> initialY;
        if (previousY == null || previousY.isEmpty()) {
            log.info("Starting from new, random Y matrix");
            initialY = new FastByIDMap<>(RbyColumn.size(), 1.25f);
        } else {
            // The first entry's length is taken as the previous feature count.
            int oldFeatureCount = previousY.entrySet().iterator().next().getValue().length;
            if (oldFeatureCount > features) {
                log.info("Feature count has decreased to {}, projecting down previous generation's Y matrix",
                         features);
                initialY = new FastByIDMap<>(previousY.size(), 1.25f);
                for (FastByIDMap.MapEntry<float[]> entry : previousY.entrySet()) {
                    float[] oldVector = entry.getValue();
                    float[] newVector = new float[features];
                    System.arraycopy(oldVector, 0, newVector, 0, newVector.length);
                    SimpleVectorMath.normalize(newVector);
                    initialY.put(entry.getKey(), newVector);
                }
            } else if (oldFeatureCount < features) {
                log.info("Feature count has increased to {}, using previous generation's Y matrix as subspace",
                         features);
                initialY = new FastByIDMap<>(previousY.size(), 1.25f);
                for (FastByIDMap.MapEntry<float[]> entry : previousY.entrySet()) {
                    float[] oldVector = entry.getValue();
                    float[] newVector = new float[features];
                    System.arraycopy(oldVector, 0, newVector, 0, oldVector.length);
                    // Fill the added dimensions with Gaussian values.
                    for (int i = oldVector.length; i < newVector.length; i++) {
                        newVector[i] = (float) random.nextGaussian();
                    }
                    SimpleVectorMath.normalize(newVector);
                    initialY.put(entry.getKey(), newVector);
                }
            } else {
                log.info("Starting from previous generation's Y matrix");
                initialY = previousY;
            }
        }

        // Collect a bounded sample of existing vectors to pass to randomUnitVectorFarFrom.
        List<float[]> recentVectors = Lists.newArrayList();
        for (FastByIDMap.MapEntry<float[]> entry : initialY.entrySet()) {
            if (recentVectors.size() >= MAX_FAR_FROM_VECTORS) {
                break;
            }
            recentVectors.add(entry.getValue());
        }

        LongPrimitiveIterator columnIDs = RbyColumn.keySetIterator();
        long count = 0;
        while (columnIDs.hasNext()) {
            long id = columnIDs.nextLong();
            if (!initialY.containsKey(id)) {
                float[] vector = RandomUtils.randomUnitVectorFarFrom(features, recentVectors, random);
                initialY.put(id, vector);
                if (recentVectors.size() < MAX_FAR_FROM_VECTORS) {
                    recentVectors.add(vector);
                }
            }
            if (++count % LOG_INTERVAL == 0) {
                log.info("Computed {} initial Y rows", count);
            }
        }

        log.info("Constructed initial Y");
        return initialY;
    }

    /**
     * Recomputes X from the current Y, blocking until all workers finish.
     *
     * @param executor pool on which workers are run
     * @throws ExecutionException if a worker task fails
     * @throws InterruptedException if interrupted while waiting
     */
    private void iterateXFromY(ExecutorService executor) throws ExecutionException, InterruptedException {
        RealMatrix YTY = MatrixUtils.transposeTimesSelf(Y);
        List<Future<?>> futures = Lists.newArrayList();
        addWorkers(RbyRow, Y, YTY, X, executor, futures);
        awaitWorkers(futures, "{} X/tag rows computed ({}MB heap)");
    }

    /**
     * Recomputes Y from the current X, blocking until all workers finish.
     *
     * @param executor pool on which workers are run
     * @throws ExecutionException if a worker task fails
     * @throws InterruptedException if interrupted while waiting
     */
    private void iterateYFromX(ExecutorService executor) throws ExecutionException, InterruptedException {
        RealMatrix XTX = MatrixUtils.transposeTimesSelf(X);
        List<Future<?>> futures = Lists.newArrayList();
        addWorkers(RbyColumn, X, XTX, Y, executor, futures);
        awaitWorkers(futures, "{} Y/tag rows computed ({}MB heap)");
    }

    /**
     * Waits for each future in order, logging progress and heap usage about every
     * {@link #LOG_INTERVAL} rows. Each future is counted as {@link #WORK_UNIT_SIZE} rows, so the
     * logged total is approximate when the last work unit is smaller.
     *
     * @param futures worker results to wait on
     * @param progressFormat log format taking the row count and the used heap in MB
     * @throws ExecutionException if a worker task fails
     * @throws InterruptedException if interrupted while waiting
     */
    private static void awaitWorkers(Iterable<Future<?>> futures, String progressFormat)
        throws ExecutionException, InterruptedException {
        int sinceLastLog = 0;
        long total = 0;
        for (Future<?> future : futures) {
            future.get();
            sinceLastLog += WORK_UNIT_SIZE;
            if (sinceLastLog >= LOG_INTERVAL) {
                total += sinceLastLog;
                JVMEnvironment env = new JVMEnvironment();
                log.info(progressFormat, total, env.getUsedMemoryMB());
                if (env.getPercentUsedMemory() > 95) {
                    log.warn("Memory is low. Increase heap size with -Xmx, decrease new generation size with larger -XX:NewRatio value, and/or use -XX:+UseCompressedOops");
                }
                sinceLastLog = 0;
            }
        }
    }

    /**
     * Splits the rows of {@code R} into work units of {@link #WORK_UNIT_SIZE} and submits one
     * {@link Worker} per unit. Does nothing when {@code R} is null.
     *
     * @param R sparse rows to solve for
     * @param M vectors held fixed
     * @param MTM precomputed product of M's transpose with M
     * @param MTags output map for the solved vectors
     * @param executor pool to submit workers to
     * @param futures receives the future of every submitted worker
     */
    private void addWorkers(FastByIDMap<FastByIDFloatMap> R,
                            FastByIDMap<float[]> M,
                            RealMatrix MTM,
                            FastByIDMap<float[]> MTags,
                            ExecutorService executor,
                            Collection<Future<?>> futures) {
        if (R == null) {
            return;
        }
        List<Pair<Long, FastByIDFloatMap>> workUnit = Lists.newArrayListWithCapacity(WORK_UNIT_SIZE);
        for (FastByIDMap.MapEntry<FastByIDFloatMap> entry : R.entrySet()) {
            workUnit.add(new Pair<Long, FastByIDFloatMap>(entry.getKey(), entry.getValue()));
            if (workUnit.size() == WORK_UNIT_SIZE) {
                futures.add(executor.submit(new Worker(features, M, MTM, MTags, workUnit)));
                // The submitted worker owns the old list; start a new one.
                workUnit = Lists.newArrayListWithCapacity(WORK_UNIT_SIZE);
            }
        }
        // Submit the final, partially filled unit.
        if (!workUnit.isEmpty()) {
            futures.add(executor.submit(new Worker(features, M, MTM, MTags, workUnit)));
        }
    }
}
