package org.apache.jackrabbit.oak.plugins.index.search.util;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.io.IOUtils;
import org.apache.jackrabbit.oak.query.SQL2Parser;
import org.apache.jackrabbit.oak.query.SQL2ParserTest;
import org.apache.jackrabbit.oak.spi.query.Filter;
import org.apache.jackrabbit.oak.spi.query.fulltext.FullTextTerm;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;

/* loaded from: input_file:org/apache/jackrabbit/oak/plugins/index/search/util/LMSEstimatorTest.class */
/**
 * Unit tests for {@link LMSEstimator}.
 *
 * <p>Covers single and repeated updates on a mocked {@link Filter}, the estimate of an
 * untrained estimator, and that the mean squared error over the {@code /lms-data.tsv}
 * fixture never increases across repeated training passes.
 */
public class LMSEstimatorTest {

    /**
     * Classpath location of the training fixture. Each line is tab-separated; column index 1
     * holds the numeric target passed to {@link LMSEstimator#update} and column index 2 holds
     * the SQL-2 query the filter is built from. Column 0 is not read by these tests.
     */
    private static final String DATA_RESOURCE = "/lms-data.tsv";

    private static final SQL2Parser PARSER = SQL2ParserTest.createTestSQL2Parser();

    @Test
    public void testUpdate() throws Exception {
        new LMSEstimator().update(Mockito.mock(Filter.class), 100L);
    }

    @Test
    public void testMultipleUpdates() throws Exception {
        LMSEstimator estimator = new LMSEstimator();
        Filter filter = Mockito.mock(Filter.class);
        Mockito.when(filter.getFullTextConstraint())
                .thenReturn(new FullTextTerm("foo", "bar", false, false, ""));

        estimator.update(filter, 0L);
        long firstEstimate = estimator.estimate(filter);
        // estimate() must be side-effect free: asking twice yields the same value
        Assert.assertEquals(firstEstimate, estimator.estimate(filter));
        long firstError = 10 - firstEstimate;

        estimator.update(filter, 10L);
        long secondEstimate = estimator.estimate(filter);
        Assert.assertEquals(secondEstimate, estimator.estimate(filter));
        long secondError = 10 - secondEstimate;
        // training towards 10 should move the estimate closer to 10
        Assert.assertTrue(secondError < firstError);

        estimator.update(filter, 10L);
        long thirdEstimate = estimator.estimate(filter);
        Assert.assertEquals(thirdEstimate, estimator.estimate(filter));
        Assert.assertTrue(10 - thirdEstimate < secondError);
    }

    @Test
    public void testEstimate() throws Exception {
        Assert.assertEquals(0L, new LMSEstimator().estimate(Mockito.mock(Filter.class)));
    }

    @Test
    public void testConvergence() throws Exception {
        // read the fixture once; every pass below re-uses the same lines
        List<String> lines = readDataLines();
        LMSEstimator estimator = new LMSEstimator();
        long mse = getMSE(estimator, lines);
        for (int i = 1; i <= 15; i++) {
            train(estimator, lines);
            long nextMse = getMSE(estimator, lines);
            Assert.assertTrue("MSE increased at iteration " + i + ": " + nextMse + " > " + mse,
                    nextMse <= mse);
            mse = nextMse;
        }
    }

    /**
     * Computes the integer mean squared error of the estimator over the given fixture lines.
     *
     * @param lMSEstimator estimator under test (only {@code estimate} is called)
     * @param lines non-empty fixture lines as described on {@link #DATA_RESOURCE}
     * @return sum of squared (target - estimate) over all lines, divided by the line count
     * @throws ArithmeticException if a squared error or the running sum overflows {@code long}
     */
    private long getMSE(LMSEstimator lMSEstimator, List<String> lines) throws Exception {
        Assert.assertFalse("no training data in " + DATA_RESOURCE, lines.isEmpty());
        long sumOfSquares = 0;
        for (String line : lines) {
            String[] columns = line.split("\t");
            long expected = Long.parseLong(columns[1]);
            long error = expected - lMSEstimator.estimate(parseFilter(columns[2]));
            // exact integer arithmetic: fail loudly instead of silently saturating on overflow
            sumOfSquares = Math.addExact(sumOfSquares, Math.multiplyExact(error, error));
        }
        return sumOfSquares / lines.size();
    }

    /**
     * Runs one training pass: calls {@link LMSEstimator#update} once per fixture line, in order.
     *
     * @param lMSEstimator estimator to train
     * @param lines fixture lines as described on {@link #DATA_RESOURCE}
     */
    private void train(LMSEstimator lMSEstimator, List<String> lines) throws Exception {
        for (String line : lines) {
            String[] columns = line.split("\t");
            lMSEstimator.update(parseFilter(columns[2]), Long.parseLong(columns[1]));
        }
    }

    /** Parses an SQL-2 query with the shared test parser and builds its filter. */
    private static Filter parseFilter(String query) throws Exception {
        return PARSER.parse(query).getSource().createFilter(true);
    }

    /**
     * Reads every line of {@link #DATA_RESOURCE} as UTF-8 and closes the stream.
     *
     * @throws IOException if the resource cannot be read
     */
    private List<String> readDataLines() throws IOException {
        try (InputStream in = getClass().getResourceAsStream(DATA_RESOURCE)) {
            Assert.assertNotNull("missing test resource " + DATA_RESOURCE, in);
            return IOUtils.readLines(in, StandardCharsets.UTF_8);
        }
    }
}
