package org.apache.flink.ml.feature;

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.ml.feature.lsh.MinHashLSH;
import org.apache.flink.ml.feature.lsh.MinHashLSHModel;
import org.apache.flink.ml.feature.lsh.MinHashLSHModelData;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Expressions;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.test.util.TestBaseUtils;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/**
 * Tests {@link MinHashLSH} and {@link MinHashLSHModel}: parameter handling, the MinHash hash
 * function in {@link MinHashLSHModelData}, fit/transform output, save/load round trips, model-data
 * get/set, approximate nearest-neighbor search and approximate similarity join.
 */
public class MinHashLSHTest extends AbstractTestBase {

    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();

    /**
     * Expected content of the output column for the three rows of {@link #inputTable} when the
     * estimator uses seed 2022, 5 hash tables and 3 hash functions per table. Each expected row
     * holds 5 vectors of size 3.
     */
    private final List<Row> outputRows =
            convertToOutputFormat(
                    Arrays.asList(
                            new double[][] {
                                {1.73046954E8, 1.57275425E8, 6.90717571E8},
                                {5.02301169E8, 7.967141E8, 4.06089319E8},
                                {2.83652171E8, 1.97714719E8, 6.04731316E8},
                                {5.2181506E8, 6.36933726E8, 6.13894128E8},
                                {3.04301769E8, 1.113672955E9, 6.1388711E8}
                            },
                            new double[][] {
                                {1.73046954E8, 1.57275425E8, 6.7798584E7},
                                {6.38582806E8, 1.78703694E8, 4.06089319E8},
                                {6.232638E8, 9.28867E7, 9.92010642E8},
                                {2.461064E8, 1.12787481E8, 1.92180297E8},
                                {2.38162496E8, 1.552933319E9, 2.77995137E8}
                            },
                            new double[][] {
                                {1.73046954E8, 1.57275425E8, 6.90717571E8},
                                {1.453197722E9, 7.967141E8, 4.06089319E8},
                                {6.232638E8, 1.97714719E8, 6.04731316E8},
                                {2.461064E8, 1.12787481E8, 1.92180297E8},
                                {1.224130231E9, 1.113672955E9, 2.77995137E8}
                            }));

    private StreamExecutionEnvironment env;
    private StreamTableEnvironment tEnv;
    private Table inputTable;

    /**
     * Orders {@link DenseVector} arrays by length first, then element by element using {@link
     * TestUtils#compare}.
     */
    private static class DenseVectorArrayComparator implements Comparator<DenseVector[]> {
        @Override
        public int compare(DenseVector[] first, DenseVector[] second) {
            if (first.length != second.length) {
                return Integer.compare(first.length, second.length);
            }
            for (int i = 0; i < first.length; i++) {
                int cmp = TestUtils.compare(first[i], second[i]);
                if (cmp != 0) {
                    return cmp;
                }
            }
            return 0;
        }
    }

    /**
     * Converts raw expected hash values into rows whose single field is a {@code DenseVector[]}.
     *
     * @param hashesPerRow one {@code double[][]} per expected row; each inner array becomes one
     *     {@link DenseVector}
     * @return one single-field {@link Row} per element of {@code hashesPerRow}
     */
    private static List<Row> convertToOutputFormat(List<double[][]> hashesPerRow) {
        return hashesPerRow.stream()
                .map(
                        hashes -> {
                            DenseVector[] vectors =
                                    Arrays.stream(hashes)
                                            .map(Vectors::dense)
                                            .toArray(DenseVector[]::new);
                            // Cast to Object so the array becomes ONE field instead of being
                            // spread across Row.of's varargs.
                            return Row.of((Object) vectors);
                        })
                .collect(Collectors.toList());
    }

    /**
     * Collects {@code output} and compares it with {@code expected}, ordering rows by the
     * {@code DenseVector[]} stored in field 0.
     *
     * @param output a single-column table produced on a {@link StreamTableEnvironment}
     * @param expected expected rows, in any order
     */
    private static void verifyPredictionResult(Table output, List<Row> expected)
            throws Exception {
        // getTableEnvironment() is declared as TableEnvironment; toDataStream lives on the
        // streaming sub-interface.
        StreamTableEnvironment tableEnv =
                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
        @SuppressWarnings("unchecked") // IteratorUtils.toList returns a raw List.
        List<Row> results =
                IteratorUtils.toList(tableEnv.toDataStream(output).executeAndCollect());
        TestBaseUtils.compareResultCollections(
                expected,
                results,
                Comparator.comparing(
                        (Row r) -> r.<DenseVector[]>getFieldAs(0),
                        new DenseVectorArrayComparator()));
    }

    /**
     * Builds a table with columns {@code id} (INT) and {@code vec} (SparseVector) from rows of the
     * form {@code (id, sparseVector)}.
     */
    private Table createInputTable(List<Row> rows) {
        return tEnv.fromDataStream(
                        env.fromCollection(rows),
                        Schema.newBuilder()
                                .column("f0", DataTypes.INT())
                                .column("f1", DataTypes.of(SparseVector.class))
                                .build())
                .as("id", "vec");
    }

    /**
     * Returns an estimator with the configuration shared by most tests: input column "vec", output
     * column "hashes", seed 2022 and 5 hash tables. numHashFunctionsPerTable is left unset.
     */
    private static MinHashLSH createMinHashLSH() {
        return new MinHashLSH()
                .setInputCol("vec")
                .setOutputCol("hashes")
                .setSeed(2022L)
                .setNumHashTables(5);
    }

    @Before
    public void before() {
        env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        inputTable =
                createInputTable(
                        Arrays.asList(
                                Row.of(
                                        0,
                                        Vectors.sparse(
                                                6, new int[] {0, 1, 2}, new double[] {1.0, 1.0, 1.0})),
                                Row.of(
                                        1,
                                        Vectors.sparse(
                                                6, new int[] {2, 3, 4}, new double[] {1.0, 1.0, 1.0})),
                                Row.of(
                                        2,
                                        Vectors.sparse(
                                                6, new int[] {0, 2, 4}, new double[] {1.0, 1.0, 1.0}))));
    }

    @Test
    public void testHashFunction() {
        MinHashLSHModelData modelData =
                new MinHashLSHModelData(3, 1, new int[] {0, 1, 3}, new int[] {1, 2, 0});
        DenseVector[] hashes =
                modelData.hashFunction(
                        Vectors.sparse(
                                10, new int[] {2, 3, 5, 7}, new double[] {1.0, 1.0, 1.0, 1.0}));
        Assert.assertEquals(3, hashes.length);
        Assert.assertEquals(Vectors.dense(1.0), hashes[0]);
        Assert.assertEquals(Vectors.dense(5.0), hashes[1]);
        Assert.assertEquals(Vectors.dense(9.0), hashes[2]);
    }

    @Test
    public void testHashFunctionEqualWithSparseDenseVector() {
        MinHashLSHModelData modelData = MinHashLSHModelData.generateModelData(3, 1, 10, 2022L);
        SparseVector vector =
                Vectors.sparse(10, new int[] {2, 3, 5, 7}, new double[] {1.0, 1.0, 1.0, 1.0});
        Assert.assertArrayEquals(
                modelData.hashFunction(vector.toDense()), modelData.hashFunction(vector.toSparse()));
    }

    @Test(expected = IllegalArgumentException.class)
    public void testHashFunctionWithEmptyVector() {
        MinHashLSHModelData modelData =
                new MinHashLSHModelData(3, 1, new int[] {0, 1, 3}, new int[] {1, 2, 0});
        modelData.hashFunction(Vectors.sparse(10, new int[0], new double[0]));
    }

    @Test
    public void testParam() {
        MinHashLSH minHashLSH = new MinHashLSH();
        Assert.assertEquals("input", minHashLSH.getInputCol());
        Assert.assertEquals("output", minHashLSH.getOutputCol());
        Assert.assertEquals(MinHashLSH.class.getName().hashCode(), minHashLSH.getSeed());
        Assert.assertEquals(1, minHashLSH.getNumHashTables());
        Assert.assertEquals(1, minHashLSH.getNumHashFunctionsPerTable());

        minHashLSH
                .setInputCol("vec")
                .setOutputCol("hashes")
                .setSeed(2022L)
                .setNumHashTables(3)
                .setNumHashFunctionsPerTable(4);

        Assert.assertEquals("vec", minHashLSH.getInputCol());
        Assert.assertEquals("hashes", minHashLSH.getOutputCol());
        Assert.assertEquals(2022L, minHashLSH.getSeed());
        Assert.assertEquals(3, minHashLSH.getNumHashTables());
        Assert.assertEquals(4, minHashLSH.getNumHashFunctionsPerTable());
    }

    @Test
    public void testOutputSchema() throws Exception {
        MinHashLSH minHashLSH = createMinHashLSH().setNumHashFunctionsPerTable(3);
        Table output = minHashLSH.fit(inputTable).transform(inputTable)[0];
        Assert.assertEquals(
                Arrays.asList("id", "vec", "hashes"), output.getResolvedSchema().getColumnNames());
    }

    @Test
    public void testFitAndPredict() throws Exception {
        MinHashLSH minHashLSH = createMinHashLSH().setNumHashFunctionsPerTable(3);
        Table output =
                minHashLSH
                        .fit(inputTable)
                        .transform(inputTable)[0]
                        .select(Expressions.$(minHashLSH.getOutputCol()));
        verifyPredictionResult(output, outputRows);
    }

    @Test
    public void testFitAndPredictWithNumHashFunctionPerTableIsOne() throws Exception {
        // Each expected row holds 5 single-element vectors (5 tables x 1 function).
        List<Row> expected =
                convertToOutputFormat(
                        Arrays.asList(
                                new double[][] {
                                    {1.73046954E8},
                                    {1.57275425E8},
                                    {6.7798584E7},
                                    {6.38582806E8},
                                    {1.78703694E8}
                                },
                                new double[][] {
                                    {1.73046954E8},
                                    {1.57275425E8},
                                    {6.90717571E8},
                                    {5.02301169E8},
                                    {7.967141E8}
                                },
                                new double[][] {
                                    {1.73046954E8},
                                    {1.57275425E8},
                                    {6.90717571E8},
                                    {1.453197722E9},
                                    {7.967141E8}
                                }));
        // numHashFunctionsPerTable is deliberately left at its default (1).
        MinHashLSH minHashLSH = createMinHashLSH();
        Table output =
                minHashLSH
                        .fit(inputTable)
                        .transform(inputTable)[0]
                        .select(Expressions.$(minHashLSH.getOutputCol()));
        verifyPredictionResult(output, expected);
    }

    @Test
    public void testEstimatorSaveLoadAndPredict() throws Exception {
        MinHashLSH minHashLSH = createMinHashLSH().setNumHashFunctionsPerTable(3);
        MinHashLSHModel model =
                TestUtils.saveAndReload(
                                tEnv,
                                minHashLSH,
                                tempFolder.newFolder().getAbsolutePath(),
                                MinHashLSH::load)
                        .fit(inputTable);
        Assert.assertEquals(
                Arrays.asList(
                        "numHashTables",
                        "numHashFunctionsPerTable",
                        "randCoefficientA",
                        "randCoefficientB"),
                model.getModelData()[0].getResolvedSchema().getColumnNames());
        Table output =
                model.transform(inputTable)[0].select(Expressions.$(minHashLSH.getOutputCol()));
        verifyPredictionResult(output, outputRows);
    }

    @Test
    public void testModelSaveLoadAndPredict() throws Exception {
        MinHashLSH minHashLSH = createMinHashLSH().setNumHashFunctionsPerTable(3);
        MinHashLSHModel model = minHashLSH.fit(inputTable);
        MinHashLSHModel loadedModel =
                TestUtils.saveAndReload(
                        tEnv, model, tempFolder.newFolder().getAbsolutePath(), MinHashLSHModel::load);
        Table output =
                loadedModel
                        .transform(inputTable)[0]
                        .select(Expressions.$(minHashLSH.getOutputCol()));
        verifyPredictionResult(output, outputRows);
    }

    @Test
    public void testGetModelData() throws Exception {
        MinHashLSH minHashLSH = createMinHashLSH().setNumHashFunctionsPerTable(3);
        Table modelDataTable = minHashLSH.fit(inputTable).getModelData()[0];

        List<String> columnNames = modelDataTable.getResolvedSchema().getColumnNames();
        DataStream<Row> modelDataStream = tEnv.toDataStream(modelDataTable);
        Assert.assertArrayEquals(
                new String[] {
                    "numHashTables", "numHashFunctionsPerTable", "randCoefficientA", "randCoefficientB"
                },
                columnNames.toArray(new String[0]));

        Row row = (Row) IteratorUtils.toList(modelDataStream.executeAndCollect()).get(0);
        MinHashLSHModelData modelData =
                new MinHashLSHModelData(
                        row.<Integer>getFieldAs(0),
                        row.<Integer>getFieldAs(1),
                        row.<int[]>getFieldAs(2),
                        row.<int[]>getFieldAs(3));

        // The decoded model data must agree with the estimator's configuration.
        int expectedNumCoefficients =
                minHashLSH.getNumHashTables() * minHashLSH.getNumHashFunctionsPerTable();
        Assert.assertEquals(minHashLSH.getNumHashTables(), modelData.numHashTables);
        Assert.assertEquals(
                minHashLSH.getNumHashFunctionsPerTable(), modelData.numHashFunctionsPerTable);
        Assert.assertEquals(expectedNumCoefficients, modelData.randCoefficientA.length);
        Assert.assertEquals(expectedNumCoefficients, modelData.randCoefficientB.length);
    }

    @Test
    public void testSetModelData() throws Exception {
        MinHashLSH minHashLSH = createMinHashLSH().setNumHashFunctionsPerTable(3);
        MinHashLSHModel fittedModel = minHashLSH.fit(inputTable);

        MinHashLSHModel newModel =
                new MinHashLSHModel().setModelData(fittedModel.getModelData()[0]);
        ParamUtils.updateExistingParams(newModel, fittedModel.getParamMap());

        Table output =
                newModel.transform(inputTable)[0].select(Expressions.$(minHashLSH.getOutputCol()));
        verifyPredictionResult(output, outputRows);
    }

    @Test
    public void testApproxNearestNeighbors() {
        MinHashLSHModel model = createMinHashLSH().setNumHashFunctionsPerTable(1).fit(inputTable);
        Table output =
                model.approxNearestNeighbors(
                                inputTable,
                                Vectors.sparse(6, new int[] {1, 3}, new double[] {1.0, 1.0}),
                                2)
                        .select(Expressions.$("id"), Expressions.$("distCol"));

        @SuppressWarnings("unchecked") // IteratorUtils.toList returns a raw List.
        List<Row> results = IteratorUtils.toList(output.execute().collect());
        List<Row> expected = Arrays.asList(Row.of(0, 0.75), Row.of(1, 0.75));
        TestBaseUtils.compareResultCollections(
                expected, results, Comparator.comparingInt((Row r) -> r.<Integer>getFieldAs(0)));
    }

    @Test
    public void testApproxSimilarityJoin() {
        MinHashLSHModel model = createMinHashLSH().setNumHashFunctionsPerTable(1).fit(inputTable);
        Table datasetB =
                createInputTable(
                        Arrays.asList(
                                Row.of(
                                        3,
                                        Vectors.sparse(
                                                6, new int[] {1, 3, 5}, new double[] {1.0, 1.0, 1.0})),
                                Row.of(
                                        4,
                                        Vectors.sparse(
                                                6, new int[] {2, 3, 5}, new double[] {1.0, 1.0, 1.0})),
                                Row.of(
                                        5,
                                        Vectors.sparse(
                                                6, new int[] {1, 2, 4}, new double[] {1.0, 1.0, 1.0}))));
        Table output = model.approxSimilarityJoin(inputTable, datasetB, 0.6, "id");

        @SuppressWarnings("unchecked") // IteratorUtils.toList returns a raw List.
        List<Row> results = IteratorUtils.toList(output.execute().collect());
        List<Row> expected =
                Arrays.asList(
                        Row.of(1, 4, 0.5), Row.of(0, 5, 0.5), Row.of(1, 5, 0.5), Row.of(2, 5, 0.5));
        // The head of the chain has no target type, so the lambda parameter must be typed
        // explicitly; otherwise it is inferred as Object and getFieldAs does not resolve.
        TestBaseUtils.compareResultCollections(
                expected,
                results,
                Comparator.comparingInt((Row r) -> r.<Integer>getFieldAs(0))
                        .thenComparingInt(r -> r.<Integer>getFieldAs(1))
                        .thenComparingDouble(r -> r.<Double>getFieldAs(2)));
    }
}
