package org.apache.mahout.cf.taste.impl.recommender.svd;

import java.util.Arrays;

import org.apache.mahout.cf.taste.impl.TasteTestCase;
import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.model.GenericDataModel;
import org.apache.mahout.cf.taste.impl.model.GenericPreference;
import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
import org.apache.mahout.cf.taste.impl.recommender.svd.ALSWRFactorizer;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.class */
public class ALSWRFactorizerTest extends TasteTestCase {
    private ALSWRFactorizer factorizer;
    private DataModel dataModel;

    @Override // org.apache.mahout.common.MahoutTestCase
    @Before
    public void setUp() throws Exception {
        super.setUp();
        FastByIDMap fastByIDMap = new FastByIDMap();
        fastByIDMap.put(1L, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(1L, 1L, 5.0f), new GenericPreference(1L, 2L, 5.0f), new GenericPreference(1L, 3L, 2.0f))));
        fastByIDMap.put(2L, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(2L, 1L, 2.0f), new GenericPreference(2L, 3L, 3.0f), new GenericPreference(2L, 4L, 5.0f))));
        fastByIDMap.put(3L, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(3L, 2L, 5.0f), new GenericPreference(3L, 4L, 3.0f))));
        fastByIDMap.put(4L, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(4L, 1L, 3.0f), new GenericPreference(4L, 4L, 5.0f))));
        this.dataModel = new GenericDataModel(fastByIDMap);
        this.factorizer = new ALSWRFactorizer(this.dataModel, 3, 0.065d, 10);
    }

    @Test
    public void setFeatureColumn() throws Exception {
        ALSWRFactorizer.Features features = new ALSWRFactorizer.Features(this.factorizer);
        DenseVector denseVector = new DenseVector(new double[]{0.5d, 2.0d, 1.5d});
        features.setFeatureColumnInM(1, denseVector);
        double[][] m = features.getM();
        assertEquals(denseVector.get(0), m[1][0], 1.0E-6d);
        assertEquals(denseVector.get(1), m[1][1], 1.0E-6d);
        assertEquals(denseVector.get(2), m[1][2], 1.0E-6d);
    }

    @Test
    public void ratingVector() throws Exception {
        Vector ratingVector = this.factorizer.ratingVector(this.dataModel.getPreferencesFromUser(1L));
        assertEquals(r0.length(), ratingVector.getNumNondefaultElements());
        assertEquals(r0.get(0).getValue(), ratingVector.get(0), 1.0E-6d);
        assertEquals(r0.get(1).getValue(), ratingVector.get(1), 1.0E-6d);
        assertEquals(r0.get(2).getValue(), ratingVector.get(2), 1.0E-6d);
    }

    @Test
    public void averageRating() throws Exception {
        assertEquals(2.5d, new ALSWRFactorizer.Features(this.factorizer).averateRating(3L), 1.0E-6d);
    }

    @Test
    public void initializeM() throws Exception {
        double[][] m = new ALSWRFactorizer.Features(this.factorizer).getM();
        assertEquals(3.333333333d, m[0][0], 1.0E-6d);
        assertEquals(5.0d, m[1][0], 1.0E-6d);
        assertEquals(2.5d, m[2][0], 1.0E-6d);
        assertEquals(4.333333333d, m[3][0], 1.0E-6d);
        for (int i = 0; i < this.dataModel.getNumItems(); i++) {
            for (int i2 = 1; i2 < 3; i2++) {
                assertTrue(m[i][i2] >= 0.0d);
                assertTrue(m[i][i2] <= 0.1d);
            }
        }
    }

    @Test
    public void toyExample() throws Exception {
        SVDRecommender sVDRecommender = new SVDRecommender(this.dataModel, this.factorizer);
        FullRunningAverage fullRunningAverage = new FullRunningAverage();
        LongPrimitiveIterator userIDs = this.dataModel.getUserIDs();
        while (userIDs.hasNext()) {
            for (Preference preference : this.dataModel.getPreferencesFromUser(userIDs.nextLong())) {
                double value = preference.getValue() - sVDRecommender.estimatePreference(r0, preference.getItemID());
                fullRunningAverage.addDatum(value * value);
            }
        }
        assertTrue(Math.sqrt(fullRunningAverage.getAverage()) < 0.2d);
    }
}
