package org.apache.mahout.cf.taste.hadoop.als;

import java.io.File;
import java.util.Arrays;
import java.util.Iterator;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
import org.apache.mahout.cf.taste.hadoop.als.ParallelALSFactorizationJob;
import org.apache.mahout.cf.taste.impl.TasteTestCase;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.SparseRowMatrix;
import org.apache.mahout.math.VarIntWritable;
import org.apache.mahout.math.VarLongWritable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.als.AlternateLeastSquaresSolver;
import org.apache.mahout.math.hadoop.MathHelper;
import org.easymock.IArgumentMatcher;
import org.easymock.classextension.EasyMock;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.class */
/**
 * Unit tests for the mappers and reducers of {@link ParallelALSFactorizationJob}, plus an end-to-end
 * run of the whole job on a small 4x4 ratings matrix.
 */
public class ParallelALSFactorizationJobTest extends TasteTestCase {

    private static final Logger logger = LoggerFactory.getLogger(ParallelALSFactorizationJobTest.class);

    /**
     * Feeds the line "123,456,2.35" to {@code PrefsToRatingsMapper}. Expects a single write keyed by
     * the index of 456, whose value carries the index of 123 and the rating 2.35.
     */
    @Test
    public void prefsToRatingsMapper() throws Exception {
        Mapper.Context context = (Mapper.Context) EasyMock.createMock(Mapper.Context.class);
        context.write(new VarIntWritable(TasteHadoopUtils.idToIndex(456L)),
                new FeatureVectorWithRatingWritable(TasteHadoopUtils.idToIndex(123L), 2.35f));
        EasyMock.replay(context);

        new ParallelALSFactorizationJob.PrefsToRatingsMapper().map(null, new Text("123,456,2.35"), context);

        EasyMock.verify(context);
    }

    /**
     * Same input as {@link #prefsToRatingsMapper()}, but with the mapper's {@code transpose} field set.
     * The expected key and value indices are therefore swapped.
     */
    @Test
    public void prefsToRatingsMapperTranspose() throws Exception {
        Mapper.Context context = (Mapper.Context) EasyMock.createMock(Mapper.Context.class);
        context.write(new VarIntWritable(TasteHadoopUtils.idToIndex(123L)),
                new FeatureVectorWithRatingWritable(TasteHadoopUtils.idToIndex(456L), 2.35f));
        EasyMock.replay(context);

        ParallelALSFactorizationJob.PrefsToRatingsMapper mapper = new ParallelALSFactorizationJob.PrefsToRatingsMapper();
        setField(mapper, "transpose", true);
        mapper.map(null, new Text("123,456,2.35"), context);

        EasyMock.verify(context);
    }

    /**
     * Reduces the ratings 4.0 and 2.0 for ID 123 with {@code numFeatures = 3}. Expects a feature vector
     * whose first entry is 3.0 (the mean of the two ratings) and whose remaining entries lie in [0, 1].
     */
    @Test
    public void initializeMReducer() throws Exception {
        Reducer.Context context = (Reducer.Context) EasyMock.createMock(Reducer.Context.class);
        context.write(EasyMock.eq(new VarIntWritable(TasteHadoopUtils.idToIndex(123L))),
                matchInitializedFeatureVector(3.0, 3));
        EasyMock.replay(context);

        ParallelALSFactorizationJob.InitializeMReducer reducer = new ParallelALSFactorizationJob.InitializeMReducer();
        setField(reducer, "numFeatures", 3);
        reducer.reduce(new VarLongWritable(123L),
                Arrays.asList(new FloatWritable(4.0f), new FloatWritable(2.0f)), context);

        EasyMock.verify(context);
    }

    /**
     * Registers an EasyMock argument matcher for a {@link FeatureVectorWithRatingWritable}. The matcher
     * requires the feature vector's entry 0 to equal {@code expectedFirstFeature} exactly. It also
     * requires entries 1 through {@code numFeatures - 1} to lie in [0, 1].
     *
     * @param expectedFirstFeature exact expected value of feature 0
     * @param numFeatures number of leading features to inspect
     * @return always {@code null}, as required by EasyMock's matcher-reporting convention
     */
    static FeatureVectorWithRatingWritable matchInitializedFeatureVector(final double expectedFirstFeature,
                                                                        final int numFeatures) {
        EasyMock.reportMatcher(new IArgumentMatcher() {
            @Override
            public boolean matches(Object argument) {
                if (!(argument instanceof FeatureVectorWithRatingWritable)) {
                    return false;
                }
                Vector featureVector = ((FeatureVectorWithRatingWritable) argument).getFeatureVector();
                if (featureVector.get(0) != expectedFirstFeature) {
                    return false;
                }
                for (int feature = 1; feature < numFeatures; feature++) {
                    double value = featureVector.get(feature);
                    if (value < 0.0 || value > 1.0) {
                        return false;
                    }
                }
                return true;
            }

            @Override
            public void appendTo(StringBuffer buffer) {
                // Shown by EasyMock when the expectation is not met.
                buffer.append("initializedFeatureVector(first=").append(expectedFirstFeature)
                        .append(", features 1..").append(numFeatures - 1).append(" in [0,1])");
            }
        });
        return null;
    }

    /**
     * Feeds the line "123,456,2.35" to {@code ItemIDRatingMapper}. Expects the raw item ID 456 as key
     * and the rating as value.
     */
    @Test
    public void itemIDRatingMapper() throws Exception {
        Mapper.Context context = (Mapper.Context) EasyMock.createMock(Mapper.Context.class);
        context.write(new VarLongWritable(456L), new FloatWritable(2.35f));
        EasyMock.replay(context);

        new ParallelALSFactorizationJob.ItemIDRatingMapper().map(null, new Text("123,456,2.35"), context);

        EasyMock.verify(context);
    }

    /**
     * Gives the reducer, under key 123, a feature-vector-only writable and a rating-only writable, both
     * for index 456. Expects a single joined record carrying the rating and the feature vector.
     */
    @Test
    public void joinFeatureVectorAndRatingsReducer() throws Exception {
        DenseVector featureVector = new DenseVector(new double[] {4.5, 1.2});
        Reducer.Context context = (Reducer.Context) EasyMock.createMock(Reducer.Context.class);
        context.write(new IndexedVarIntWritable(456, 123),
                new FeatureVectorWithRatingWritable(123, Float.valueOf(2.35f), featureVector));
        EasyMock.replay(context);

        new ParallelALSFactorizationJob.JoinFeatureVectorAndRatingsReducer().reduce(new VarIntWritable(123),
                Arrays.asList(new FeatureVectorWithRatingWritable(456, featureVector),
                        new FeatureVectorWithRatingWritable(456, 2.35f)),
                context);

        EasyMock.verify(context);
    }

    /**
     * Runs {@code SolvingReducer} on two (rating, feature vector) pairs. Checks that it emits the same
     * vector that {@link AlternateLeastSquaresSolver} computes directly for those inputs.
     */
    @Test
    public void solvingReducer() throws Exception {
        double lambda = 0.01;
        int numFeatures = 2;
        AlternateLeastSquaresSolver solver = new AlternateLeastSquaresSolver();
        DenseVector ratings = new DenseVector(new double[] {2.0, 1.0});
        Vector featuresOf456 = new DenseVector(new double[] {1.0, 2.0});
        Vector featuresOf789 = new DenseVector(new double[] {3.0, 4.0});

        // Expected output: solve the same system directly.
        Vector expected = solver.solve(Arrays.asList(featuresOf456, featuresOf789), ratings, lambda, numFeatures);
        Vector.Element[] expectedElements = new Vector.Element[expected.size()];
        for (int index = 0; index < expected.size(); index++) {
            expectedElements[index] = expected.getElement(index);
        }

        Reducer.Context context = (Reducer.Context) EasyMock.createMock(Reducer.Context.class);
        context.write(EasyMock.eq(new VarIntWritable(123)), matchFeatureVector(expectedElements));
        EasyMock.replay(context);

        ParallelALSFactorizationJob.SolvingReducer reducer = new ParallelALSFactorizationJob.SolvingReducer();
        setField(reducer, "numFeatures", numFeatures);
        setField(reducer, "lambda", Double.valueOf(lambda));
        setField(reducer, "solver", solver);
        reducer.reduce(new IndexedVarIntWritable(123, 1),
                Arrays.asList(
                        new FeatureVectorWithRatingWritable(456, Float.valueOf((float) ratings.get(0)), featuresOf456),
                        new FeatureVectorWithRatingWritable(789, Float.valueOf((float) ratings.get(1)), featuresOf789)),
                context);

        EasyMock.verify(context);
    }

    /**
     * Registers an EasyMock argument matcher for a {@link FeatureVectorWithRatingWritable}. The matcher
     * accepts the argument when {@code MathHelper.consistsOf} reports that its feature vector consists
     * of the given elements.
     *
     * @param elementArr expected vector elements
     * @return always {@code null}, as required by EasyMock's matcher-reporting convention
     */
    static FeatureVectorWithRatingWritable matchFeatureVector(final Vector.Element... elementArr) {
        EasyMock.reportMatcher(new IArgumentMatcher() {
            @Override
            public boolean matches(Object argument) {
                return argument instanceof FeatureVectorWithRatingWritable
                        && MathHelper.consistsOf(((FeatureVectorWithRatingWritable) argument).getFeatureVector(),
                                elementArr);
            }

            @Override
            public void appendTo(StringBuffer buffer) {
                // Shown by EasyMock when the expectation is not met.
                buffer.append("featureVector{");
                for (int position = 0; position < elementArr.length; position++) {
                    if (position > 0) {
                        buffer.append(',');
                    }
                    buffer.append(elementArr[position].index()).append(':').append(elementArr[position].get());
                }
                buffer.append('}');
            }
        });
        return null;
    }

    /**
     * Runs the full job on a 4x4 ratings matrix in which NaN marks a missing preference. It then reads
     * back the U and M factor matrices and asserts that the RMSE over the known ratings is below 0.2.
     */
    @Test
    public void completeJobToyExample() throws Exception {
        File inputFile = getTestTempFile("prefs.txt");
        File outputDir = getTestTempDir("output");
        // NOTE(review): presumably the job requires its output directory not to exist yet — confirm.
        outputDir.delete();
        File tmpDir = getTestTempDir("tmp");

        double na = Double.NaN;
        int numFeatures = 3;
        SparseRowMatrix preferences = new SparseRowMatrix(new int[] {4, 4}, new Vector[] {
            new DenseVector(new double[] {5.0, 5.0, 2.0, na}),
            new DenseVector(new double[] {2.0, na, 3.0, 5.0}),
            new DenseVector(new double[] {na, 5.0, na, 3.0}),
            new DenseVector(new double[] {3.0, na, na, 5.0})
        });

        String preferencesCsv = toPreferencesCsv(preferences);
        logger.info("Input matrix:\n{}", preferencesCsv);
        writeLines(inputFile, preferencesCsv);

        ParallelALSFactorizationJob job = new ParallelALSFactorizationJob();
        Configuration conf = new Configuration();
        conf.set("mapred.input.dir", inputFile.getAbsolutePath());
        conf.set("mapred.output.dir", outputDir.getAbsolutePath());
        conf.setBoolean("mapred.output.compress", false);
        job.setConf(conf);
        job.run(new String[] {
            "--tempDir", tmpDir.getAbsolutePath(),
            "--lambda", String.valueOf(0.065d),
            "--numFeatures", String.valueOf(numFeatures),
            "--numIterations", String.valueOf(5)
        });

        Matrix u = MathHelper.readEntries(conf, new Path(outputDir.getAbsolutePath(), "U/part-r-00000"),
                preferences.numRows(), numFeatures);
        Matrix m = MathHelper.readEntries(conf, new Path(outputDir.getAbsolutePath(), "M/part-r-00000"),
                preferences.numCols(), numFeatures);

        double rmse = rootMeanSquaredError(preferences, u, m);
        logger.info("RMSE: {}", rmse);
        assertTrue("RMSE too high: " + rmse, rmse < 0.2d);
    }

    /**
     * Renders every non-NaN entry of {@code preferences} as a "row,column,value" line. Lines are joined
     * by '\n' with no trailing newline.
     */
    private static String toPreferencesCsv(Matrix preferences) {
        StringBuilder csv = new StringBuilder();
        String separator = "";
        Iterator<MatrixSlice> rows = preferences.iterateAll();
        while (rows.hasNext()) {
            MatrixSlice row = rows.next();
            Iterator<Vector.Element> elements = row.vector().iterateNonZero();
            while (elements.hasNext()) {
                Vector.Element element = elements.next();
                if (!Double.isNaN(element.get())) {
                    csv.append(separator).append(row.index()).append(',')
                            .append(element.index()).append(',').append(element.get());
                    separator = "\n";
                }
            }
        }
        return csv.toString();
    }

    /**
     * Computes the root-mean-squared error over the non-NaN entries of {@code preferences}. Each
     * estimate is the dot product of the user's row of {@code u} and the item's row of {@code m}.
     * Every comparison is logged.
     */
    private static double rootMeanSquaredError(Matrix preferences, Matrix u, Matrix m) {
        FullRunningAverage squaredErrors = new FullRunningAverage();
        Iterator<MatrixSlice> rows = preferences.iterateAll();
        while (rows.hasNext()) {
            MatrixSlice row = rows.next();
            Iterator<Vector.Element> elements = row.vector().iterateNonZero();
            while (elements.hasNext()) {
                Vector.Element element = elements.next();
                double actual = element.get();
                if (Double.isNaN(actual)) {
                    continue;
                }
                double estimate = u.getRow(row.index()).dot(m.getRow(element.index()));
                double error = actual - estimate;
                squaredErrors.addDatum(error * error);
                logger.info("Comparing preference of user [{}] towards item [{}], was [{}] estimate is [{}]",
                        new Object[] {row.index(), element.index(), actual, estimate});
            }
        }
        return Math.sqrt(squaredErrors.getAverage());
    }
}
