package org.apache.mahout.vectorizer.encoders;

import com.google.common.collect.Iterables;
import com.google.common.collect.Maps;
import java.util.HashMap;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseVector;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/vectorizer/encoders/InteractionValueEncoderTest.class */
/**
 * Tests for {@link InteractionValueEncoder}, which combines two underlying
 * encoders into a single "interaction" feature that is added to a vector.
 */
public class InteractionValueEncoderTest extends MahoutTestCase {

    /**
     * Adds the same word/continuous interaction to a vector and checks the
     * resulting 1-norm and maximum value against the encoder's probe count.
     */
    @Test
    public void testAddToVector() {
        StaticWordValueEncoder wordEncoder = new StaticWordValueEncoder("word");
        ContinuousValueEncoder continuousEncoder = new ContinuousValueEncoder("cont");
        InteractionValueEncoder encoder = new InteractionValueEncoder("interactions", wordEncoder, continuousEncoder);

        DenseVector vector = new DenseVector(200);
        encoder.addInteractionToVector("a", "1.0", 1.0d, vector);
        // Captured once so both the single and the doubled expectation use the same count.
        int probes = encoder.getProbes();
        // After one insertion the 1-norm must equal the probe count and no cell may exceed 1.
        assertEquals(probes, vector.norm(1.0d), 0.0d);
        assertEquals(1.0d, vector.maxValue(), 0.0d);

        // Adding the identical interaction again must double both the norm and the maximum.
        encoder.addInteractionToVector("a", "1.0", 1.0d, vector);
        assertEquals(probes * 2.0f, vector.norm(1.0d), 0.0d);
        assertEquals(2.0d, vector.maxValue(), 0.0d);

        // NOTE(review): the much larger vector presumably makes hash collisions between the
        // three encoders unlikely, which the additive expectation below relies on - confirm.
        DenseVector largeVector = new DenseVector(20000);
        encoder.addInteractionToVector("a", "1.0", 1.0d, largeVector);
        wordEncoder.addToVector("a", largeVector);
        continuousEncoder.addToVector("1.0", largeVector);
        assertEquals(encoder.getProbes() + wordEncoder.getProbes() + continuousEncoder.getProbes(), largeVector.norm(1.0d), 0.001d);
    }

    /**
     * Checks that the value written for an interaction is the product of the
     * explicit weight (0.5) and the continuous value ("0.9").
     */
    @Test
    public void testAddToVectorUsesProductOfWeights() {
        InteractionValueEncoder encoder = new InteractionValueEncoder("interactions", new StaticWordValueEncoder("word"), new ContinuousValueEncoder("cont"));
        DenseVector vector = new DenseVector(200);
        encoder.addInteractionToVector("a", "0.9", 0.5d, vector);
        assertEquals(encoder.getProbes() * 0.5d * 0.9d, vector.norm(1.0d), 0.0d);
        assertEquals(0.45d, vector.maxValue(), 0.0d);
    }

    /**
     * Checks an interaction whose second encoder is a {@link TextValueEncoder}:
     * the expected 1-norm is three times the probe count for the three-word text.
     */
    @Test
    public void testAddToVectorWithTextValueEncoder() {
        InteractionValueEncoder encoder = new InteractionValueEncoder("interactions", new StaticWordValueEncoder("word"), new TextValueEncoder("text"));
        DenseVector vector = new DenseVector(200);
        encoder.addInteractionToVector("a", "some text here", 1.0d, vector);
        assertEquals(encoder.getProbes() * 3.0f, vector.norm(1.0d), 0.0d);
    }

    /**
     * Checks that, with a single probe and a trace dictionary installed, one
     * vector cell is set and one dictionary entry named
     * {@code interactions=a:b} is recorded.
     */
    @Test
    public void testTraceDictionary() {
        StaticWordValueEncoder firstEncoder = new StaticWordValueEncoder("first");
        StaticWordValueEncoder secondEncoder = new StaticWordValueEncoder("second");
        // NOTE(review): raw type kept because the parameter type of setTraceDictionary is not
        // visible here; parameterize once the signature is confirmed.
        HashMap traceDictionary = Maps.newHashMap();
        InteractionValueEncoder encoder = new InteractionValueEncoder("interactions", firstEncoder, secondEncoder);
        encoder.setProbes(1);
        encoder.setTraceDictionary(traceDictionary);

        // Keep a reference to the vector so it can be inspected after encoding.
        DenseVector vector = new DenseVector(10);
        encoder.addInteractionToVector("a", "b", 1.0d, vector);

        assertEquals(1L, vector.getNumNonZeroElements());
        assertEquals(1L, traceDictionary.size());
        assertEquals("interactions=a:b", Iterables.getFirst(traceDictionary.keySet(), (Object) null));
    }
}
