package org.apache.flink.ml.feature.lsh;

import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.Model;
import org.apache.flink.ml.common.broadcast.BroadcastUtils;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.common.typeinfo.PriorityQueueTypeInfo;
import org.apache.flink.ml.feature.lsh.LSHModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/ml/feature/lsh/LSHModel.class */
abstract class LSHModel<T extends LSHModel<T>> implements Model<T>, LSHModelParams<T> {
    /** Name under which the model data is registered as a broadcast variable. */
    private static final String MODEL_DATA_BC_KEY = "modelData";
    /** Parameter values of this model; populated with default values by the constructor. */
    private final Map<Param<?>, Object> paramMap = new HashMap<>();
    /** Class used when converting the model data table into a data stream. */
    private final Class<? extends LSHModelData> modelDataClass;
    /** Table holding the model data; assigned by {@code setModelData}. */
    protected Table modelDataTable;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/ml/feature/lsh/LSHModel$ExplodeHashValuesFunction.class */
    /**
     * Flat-map function that turns one input row into one output row per hash value.
     *
     * <p>Every emitted row has four fields: the value of the id column, the value of the input
     * column, the position of the hash value inside the output-column array, and the hash value
     * itself.
     */
    public static class ExplodeHashValuesFunction implements FlatMapFunction<Row, Row> {
        private final String idCol;
        private final String inputCol;
        private final String outputCol;

        /**
         * @param str name of the id column
         * @param str2 name of the input (vector) column
         * @param str3 name of the output column that holds the {@code DenseVector[]} hash values
         */
        public ExplodeHashValuesFunction(String str, String str2, String str3) {
            idCol = str;
            inputCol = str2;
            outputCol = str3;
        }

        /** Emits {@code (id, input, index, hashValue)} for every hash value stored in {@code row}. */
        public void flatMap(Row row, Collector<Row> collector) throws Exception {
            // The (id, input) prefix is identical for all rows produced from this input row.
            Row prefix = Row.of(row.getField(idCol), row.getField(inputCol));
            DenseVector[] hashValues = (DenseVector[]) row.getFieldAs(outputCol);
            int index = 0;
            for (DenseVector hashValue : hashValues) {
                Row suffix = Row.of(Integer.valueOf(index), hashValue);
                collector.collect(Row.join(prefix, suffix));
                index++;
            }
        }

        public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
            flatMap((Row) obj, (Collector<Row>) collector);
        }
    }

    /* loaded from: input_file:org/apache/flink/ml/feature/lsh/LSHModel$FilterByBucketFunction.class */
    /**
     * Flat-map function that keeps only the rows sharing at least one hash value (at the same
     * position) with the query key, and appends the distance between the key and the row's input
     * vector as an extra trailing field.
     */
    private static class FilterByBucketFunction extends RichFlatMapFunction<Row, Row> {
        private final String inputCol;
        private final String outputCol;
        private final Vector key;
        private LSHModelData modelData;
        private DenseVector[] keyHashes;

        /**
         * @param str name of the input (vector) column
         * @param str2 name of the column holding the row's hash values
         * @param vector the query key
         */
        public FilterByBucketFunction(String str, String str2, Vector vector) {
            inputCol = str;
            outputCol = str2;
            key = vector;
        }

        /** Forwards {@code row} with its distance to the key appended when a hash value matches. */
        public void flatMap(Row row, Collector<Row> collector) throws Exception {
            if (modelData == null) {
                // First record: fetch the broadcast model data and hash the key once.
                modelData = (LSHModelData) getRuntimeContext().getBroadcastVariable(LSHModel.MODEL_DATA_BC_KEY).get(0);
                keyHashes = modelData.hashFunction(key);
            }
            DenseVector[] rowHashes = (DenseVector[]) row.getFieldAs(outputCol);
            if (!sharesHashValue(rowHashes)) {
                return;
            }
            double distance = modelData.keyDistance(key, (Vector) row.getFieldAs(inputCol));
            collector.collect(Row.join(row, Row.of(Double.valueOf(distance))));
        }

        /** Returns true when any hash value of the key equals the row's hash value at the same index. */
        private boolean sharesHashValue(DenseVector[] rowHashes) {
            for (int i = 0; i < keyHashes.length; i++) {
                if (keyHashes[i].equals(rowHashes[i])) {
                    return true;
                }
            }
            return false;
        }

        public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
            flatMap((Row) obj, (Collector<Row>) collector);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/ml/feature/lsh/LSHModel$FilterByDistanceFunction.class */
    /**
     * Flat-map function that computes the distance between the vectors stored in fields 2 and 3 of
     * each row and emits {@code (field 0, field 1, distance)} when the distance does not exceed the
     * threshold.
     */
    public static class FilterByDistanceFunction extends RichFlatMapFunction<Row, Row> {
        private final double threshold;
        private LSHModelData modelData;

        /** @param d largest distance for which a row is still emitted */
        public FilterByDistanceFunction(double d) {
            threshold = d;
        }

        /** Emits the two leading fields and the distance when the distance is within the threshold. */
        public void flatMap(Row row, Collector<Row> collector) throws Exception {
            if (modelData == null) {
                // Fetched from the broadcast variable on the first record only.
                modelData = (LSHModelData) getRuntimeContext().getBroadcastVariable(LSHModel.MODEL_DATA_BC_KEY).get(0);
            }
            Vector left = (Vector) row.getFieldAs(2);
            Vector right = (Vector) row.getFieldAs(3);
            double distance = modelData.keyDistance(left, right);
            // Negated "<=" (rather than ">") so that a NaN distance is still dropped.
            if (!(distance <= threshold)) {
                return;
            }
            Object leftId = row.getFieldAs(0);
            Object rightId = row.getFieldAs(1);
            collector.collect(Row.of(leftId, rightId, Double.valueOf(distance)));
        }

        public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
            flatMap((Row) obj, (Collector<Row>) collector);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/ml/feature/lsh/LSHModel$IndexHashValueKeySelector.class */
    /** Key selector that keys a row by the pair made of its field 2 and its field 3. */
    public static class IndexHashValueKeySelector implements KeySelector<Row, Tuple2<Integer, DenseVector>> {
        private IndexHashValueKeySelector() {
        }

        /** Returns {@code (field 2 as Integer, field 3 as DenseVector)} of {@code row}. */
        public Tuple2<Integer, DenseVector> getKey(Row row) throws Exception {
            Integer index = (Integer) row.getFieldAs(2);
            DenseVector hashValue = (DenseVector) row.getFieldAs(3);
            return Tuple2.of(index, hashValue);
        }
    }

    /* loaded from: input_file:org/apache/flink/ml/feature/lsh/LSHModel$PredictFunction.class */
    /**
     * Map function that appends the hash values of the input-column vector to each row as one
     * additional field.
     */
    private static class PredictFunction extends RichMapFunction<Row, Row> {
        private final String inputCol;
        private LSHModelData modelData;

        /** @param str name of the input (vector) column */
        public PredictFunction(String str) {
            inputCol = str;
        }

        /** Returns {@code row} extended with the result of hashing its input-column vector. */
        public Row map(Row row) throws Exception {
            if (modelData == null) {
                // Fetched from the broadcast variable on the first record only.
                modelData = (LSHModelData) getRuntimeContext().getBroadcastVariable(LSHModel.MODEL_DATA_BC_KEY).get(0);
            }
            Vector input = (Vector) row.getFieldAs(inputCol);
            DenseVector[] hashValues = modelData.hashFunction(input);
            // Cast to Object so the array becomes a single field instead of being spread as varargs.
            return Row.join(row, Row.of((Object) hashValues));
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/ml/feature/lsh/LSHModel$TopKFunction.class */
    /**
     * Aggregate function that retains at most {@code numNearestNeighbors} rows, namely those with
     * the smallest values in the distance column.
     *
     * <p>The accumulator is a heap ordered by descending distance, so its head is always the
     * farthest row currently retained. Once the heap is full, a new row is admitted only when it is
     * closer than that head, which is evicted in exchange. This keeps the accumulator bounded.
     */
    public static class TopKFunction implements AggregateFunction<Row, PriorityQueue<Row>, List<Row>> {
        private final int numNearestNeighbors;
        private final String distCol;

        /**
         * Orders rows by descending value of the distance column, i.e. the row with the largest
         * distance compares as the smallest element and therefore sits at the head of a
         * {@link PriorityQueue} built with this comparator.
         */
        public static class DistColComparator implements Comparator<Row>, Serializable {
            private final String distCol;

            private DistColComparator(String str) {
                this.distCol = str;
            }

            @Override // java.util.Comparator
            public int compare(Row row, Row row2) {
                double distance = ((Double) row.getFieldAs(this.distCol)).doubleValue();
                double distance2 = ((Double) row2.getFieldAs(this.distCol)).doubleValue();
                // Arguments swapped on purpose: larger distance sorts first.
                return Double.compare(distance2, distance);
            }
        }

        /**
         * @param str name of the column holding the distance (a {@code Double})
         * @param i number of nearest rows to retain; must be positive
         * @throws IllegalArgumentException if {@code i} is not positive
         */
        public TopKFunction(String str, int i) {
            // A non-positive capacity would otherwise only fail later, when the accumulator is created.
            Preconditions.checkArgument(i > 0, "numNearestNeighbors must be positive.");
            this.distCol = str;
            this.numNearestNeighbors = i;
        }

        /** Creates an empty accumulator ordered by {@link #getComparator()}. */
        public PriorityQueue<Row> createAccumulator() {
            return new PriorityQueue<>(this.numNearestNeighbors, getComparator());
        }

        /**
         * Alias of {@link #createAccumulator()} under the name the decompiler assigned to it; kept
         * so that existing references to this name keep working.
         */
        public PriorityQueue<Row> m145createAccumulator() {
            return createAccumulator();
        }

        /**
         * Offers {@code row} to the accumulator. The row is stored when the accumulator is not yet
         * full, or when it is closer than the farthest retained row (which is then evicted);
         * otherwise it is discarded.
         */
        public PriorityQueue<Row> add(Row row, PriorityQueue<Row> priorityQueue) {
            if (priorityQueue.size() < this.numNearestNeighbors) {
                priorityQueue.add(row);
            } else if (priorityQueue.comparator().compare(row, priorityQueue.peek()) > 0) {
                // With the descending ordering, "greater than the head" means "closer than the
                // farthest retained row".
                priorityQueue.poll();
                priorityQueue.add(row);
            }
            return priorityQueue;
        }

        /** Returns the retained rows as a list; the list is in heap order, not sorted by distance. */
        public List<Row> getResult(PriorityQueue<Row> priorityQueue) {
            return new ArrayList<>(priorityQueue);
        }

        /** Returns a new accumulator holding the nearest rows of both inputs; inputs are not modified. */
        public PriorityQueue<Row> merge(PriorityQueue<Row> priorityQueue, PriorityQueue<Row> priorityQueue2) {
            // Copying from a PriorityQueue also copies its comparator.
            PriorityQueue<Row> merged = new PriorityQueue<>((PriorityQueue<? extends Row>) priorityQueue);
            for (Row row : priorityQueue2) {
                add(row, merged);
            }
            return merged;
        }

        /** Returns the comparator that accumulators of this function are ordered by. */
        public Comparator<Row> getComparator() {
            return new DistColComparator(this.distCol);
        }
    }

    /**
     * Creates the model and fills its parameter map with default values.
     *
     * @param cls class used when converting the model data table into a data stream
     */
    public LSHModel(Class<? extends LSHModelData> cls) {
        modelDataClass = cls;
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    /**
     * Stores the single model data table of this model.
     *
     * @param tableArr exactly one table holding the model data
     * @return this model
     * @throws IllegalArgumentException if the number of tables is not one
     */
    @Override // org.apache.flink.ml.api.Model
    @SuppressWarnings("unchecked")
    public T setModelData(Table... tableArr) {
        Preconditions.checkArgument(tableArr.length == 1, "Exactly one model data table is expected.");
        this.modelDataTable = tableArr[0];
        // Safe by the self-bounded type parameter: T is the concrete subclass of this model.
        return (T) this;
    }

    /** Returns a one-element array containing the model data table. */
    @Override // org.apache.flink.ml.api.Model
    public Table[] getModelData() {
        Table[] modelData = new Table[1];
        modelData[0] = modelDataTable;
        return modelData;
    }

    /** Returns the parameter map backing this model (the map itself, not a copy). */
    @Override // org.apache.flink.ml.param.WithParams
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    /**
     * Appends the output column, holding the {@code DenseVector[]} hash values of the input column,
     * to the given table.
     *
     * @param tableArr exactly one input table
     * @return a one-element array with the input table extended by the output column
     */
    @Override // org.apache.flink.ml.api.AlgoOperator
    public Table[] transform(Table... tableArr) {
        Preconditions.checkArgument(tableArr.length == 1);
        Table input = tableArr[0];
        StreamTableEnvironment tEnv = ((TableImpl) input).getTableEnvironment();
        DataStream modelData = tEnv.toDataStream(modelDataTable, modelDataClass);

        // Output schema = input schema + one trailing column for the hash values.
        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(input.getResolvedSchema());
        TypeInformation[] outputTypes = (TypeInformation[]) ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), new TypeInformation[]{TypeInformation.of(DenseVector[].class)});
        String[] outputNames = (String[]) ArrayUtils.addAll(inputTypeInfo.getFieldNames(), new String[]{getOutputCol()});
        RowTypeInfo outputTypeInfo = new RowTypeInfo(outputTypes, outputNames);

        DataStream hashed = BroadcastUtils.withBroadcastStream(
                Collections.singletonList(tEnv.toDataStream(input)),
                Collections.singletonMap(MODEL_DATA_BC_KEY, modelData),
                list -> ((DataStream) list.get(0)).map(new PredictFunction(getInputCol()), outputTypeInfo));
        return new Table[]{tEnv.fromDataStream(hashed)};
    }

    /**
     * Approximately searches {@code table} for the rows nearest to {@code vector}.
     *
     * <p>The table is hashed with {@code transform} first unless it already contains the output
     * column. Rows sharing a hash value with the key are kept, their distance to the key is
     * appended as column {@code str}, and the result is aggregated with a {@code TopKFunction}.
     *
     * @param table table to search
     * @param vector the query key
     * @param i number of nearest neighbors requested
     * @param str name of the appended distance column
     * @return table of the selected rows, including the distance column
     */
    public Table approxNearestNeighbors(Table table, Vector vector, int i, String str) {
        StreamTableEnvironment tEnv = ((TableImpl) table).getTableEnvironment();
        Table hashedTable;
        if (table.getResolvedSchema().getColumnNames().contains(getOutputCol())) {
            hashedTable = table;
        } else {
            hashedTable = transform(table)[0];
        }
        DataStream modelData = tEnv.toDataStream(modelDataTable, modelDataClass);

        // Result schema = hashed schema + one trailing DOUBLE column for the distance.
        RowTypeInfo hashedTypeInfo = TableUtils.getRowTypeInfo(hashedTable.getResolvedSchema());
        TypeInformation[] resultTypes = (TypeInformation[]) ArrayUtils.addAll(hashedTypeInfo.getFieldTypes(), new TypeInformation[]{Types.DOUBLE});
        String[] resultNames = (String[]) ArrayUtils.addAll(hashedTypeInfo.getFieldNames(), new String[]{str});
        RowTypeInfo resultTypeInfo = new RowTypeInfo(resultTypes, resultNames);

        DataStream candidates = BroadcastUtils.withBroadcastStream(
                Collections.singletonList(tEnv.toDataStream(hashedTable)),
                Collections.singletonMap(MODEL_DATA_BC_KEY, modelData),
                list -> ((DataStream) list.get(0)).flatMap(new FilterByBucketFunction(getInputCol(), getOutputCol(), vector), resultTypeInfo));

        TopKFunction topK = new TopKFunction(str, i);
        DataStream<List<Row>> nearest = DataStreamUtils.aggregate(candidates, topK, new PriorityQueueTypeInfo(topK.getComparator(), resultTypeInfo), Types.LIST(resultTypeInfo));
        // Unpack the aggregated list back into individual rows.
        SingleOutputStreamOperator flattened = nearest.flatMap((list2, collector) -> {
            for (Iterator it = list2.iterator(); it.hasNext(); ) {
                collector.collect((Row) it.next());
            }
        });
        // The lambda above carries no output type information, so it is set explicitly.
        flattened.getTransformation().setOutputType(resultTypeInfo);
        return tEnv.fromDataStream(flattened);
    }

    /**
     * Same as the four-argument overload, with the distance column named {@code "distCol"}.
     *
     * @param table table to search
     * @param vector the query key
     * @param i number of nearest neighbors requested
     */
    public Table approxNearestNeighbors(Table table, Vector vector, int i) {
        final String defaultDistCol = "distCol";
        return approxNearestNeighbors(table, vector, i, defaultDistCol);
    }

    /**
     * Approximately joins two tables on the pairs of rows whose distance is at most {@code d}.
     *
     * <p>Both tables are exploded into one row per hash value and joined on equal
     * (index, hash value). The joined pairs are keyed by their two ids, reduced to one row per
     * key, and finally filtered by distance.
     *
     * @param table first table
     * @param table2 second table
     * @param d distance threshold
     * @param str name of the id column (the same name is used for both tables)
     * @param str2 name of the distance column in the result
     * @return table with columns {@code "datasetA.id"}, {@code "datasetB.id"} and {@code str2}
     */
    public Table approxSimilarityJoin(Table table, Table table2, double d, String str, String str2) {
        StreamTableEnvironment tEnv = ((TableImpl) table).getTableEnvironment();
        DataStream<Row> explodedA = preprocessData(table, str);
        DataStream<Row> explodedB = preprocessData(table2, str);

        // Pair layout: (idA, idB, vectorA, vectorB).
        RowTypeInfo explodedType = getOutputType(table, str);
        TypeInformation idTypeInfo = explodedType.getTypeAt(0);
        TypeInformation vectorTypeInfo = explodedType.getTypeAt(1);
        RowTypeInfo pairType = new RowTypeInfo(new TypeInformation[]{idTypeInfo, idTypeInfo, vectorTypeInfo, vectorTypeInfo});

        DataStream modelData = tEnv.toDataStream(modelDataTable, modelDataClass);

        DataStream<Row> joinedPairs = explodedA.join(explodedB)
                .where(new IndexHashValueKeySelector())
                .equalTo(new IndexHashValueKeySelector())
                .window(EndOfStreamWindows.get())
                .apply((left, right) -> Row.of(left.getField(0), right.getField(0), left.getField(1), right.getField(1)), pairType);

        // Keep one row per (idA, idB) key.
        DataStream distinctPairs = DataStreamUtils.reduce(
                joinedPairs.keyBy(new KeySelector<Row, Tuple2<Integer, Integer>>() {
                    public Tuple2<Integer, Integer> getKey(Row pair) {
                        return Tuple2.of((Integer) pair.getFieldAs(0), (Integer) pair.getFieldAs(1));
                    }
                }),
                (first, second) -> first,
                (TypeInformation) pairType);

        TypeInformation idType = TableUtils.getRowTypeInfo(table.getResolvedSchema()).getTypeAt(str);
        DataStream result = BroadcastUtils.withBroadcastStream(
                Collections.singletonList(distinctPairs),
                Collections.singletonMap(MODEL_DATA_BC_KEY, modelData),
                list -> {
                    RowTypeInfo resultType = new RowTypeInfo(new TypeInformation[]{idType, idType, Types.DOUBLE}, new String[]{"datasetA.id", "datasetB.id", str2});
                    return ((DataStream) list.get(0)).flatMap(new FilterByDistanceFunction(d), resultType);
                });
        return tEnv.fromDataStream(result);
    }

    /**
     * Same as the five-argument overload, with the distance column named {@code "distCol"}.
     *
     * @param table first table
     * @param table2 second table
     * @param d distance threshold
     * @param str name of the id column
     */
    public Table approxSimilarityJoin(Table table, Table table2, double d, String str) {
        final String defaultDistCol = "distCol";
        return approxSimilarityJoin(table, table2, d, str, defaultDistCol);
    }

    /**
     * Converts a table into a stream with one row per hash value, hashing the table with
     * {@code transform} first unless it already contains the output column.
     *
     * @param table table to explode
     * @param str name of the id column
     */
    private DataStream<Row> preprocessData(Table table, String str) {
        StreamTableEnvironment tEnv = ((TableImpl) table).getTableEnvironment();
        Table hashedTable;
        if (table.getResolvedSchema().getColumnNames().contains(getOutputCol())) {
            hashedTable = table;
        } else {
            hashedTable = transform(table)[0];
        }
        DataStream<Row> rows = tEnv.toDataStream(hashedTable);
        ExplodeHashValuesFunction explode = new ExplodeHashValuesFunction(str, getInputCol(), getOutputCol());
        RowTypeInfo explodedType = getOutputType(hashedTable, str);
        return rows.flatMap(explode, explodedType);
    }

    /**
     * Builds the row type of the exploded stream: the id column (typed as in {@code table}), the
     * input vector column, an INT column {@code "index"} and a dense vector column
     * {@code "hashValue"}.
     *
     * @param table table whose schema supplies the type of the id column
     * @param str name of the id column
     */
    private RowTypeInfo getOutputType(Table table, String str) {
        RowTypeInfo tableTypeInfo = TableUtils.getRowTypeInfo(table.getResolvedSchema());
        TypeInformation[] fieldTypes = new TypeInformation[]{
            tableTypeInfo.getTypeAt(str),
            VectorTypeInfo.INSTANCE,
            Types.INT,
            DenseVectorTypeInfo.INSTANCE
        };
        String[] fieldNames = new String[]{str, getInputCol(), "index", "hashValue"};
        return new RowTypeInfo(fieldTypes, fieldNames);
    }

    /**
     * Synthetic hook used when a serialized lambda of this class is deserialized: it maps the
     * implementation-method name recorded in {@code serializedLambda} back to the matching lambda.
     *
     * <p>Three lambdas are recognized: the join function and the reduce function of
     * {@code approxSimilarityJoin}, and the flat-map function of {@code approxNearestNeighbors}.
     *
     * <p>NOTE(review): this method is marked synthetic and looks compiler-generated; as decompiled
     * here the selector {@code z} is declared {@code boolean} but is assigned {@code -1} and
     * {@code 2}, so it presumably stands for an int selector with values 0, 1 and 2 — confirm
     * before attempting to compile this method as source.
     *
     * @param serializedLambda the serialized form of the lambda to recreate
     * @return the recreated lambda
     * @throws IllegalArgumentException if the serialized form matches none of the known lambdas
     */
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        boolean z = -1;
        // First switch: resolve the method name to a selector, using the hash code as a fast path
        // and equals() to guard against hash collisions.
        switch (implMethodName.hashCode()) {
            case -36062676:
                if (implMethodName.equals("lambda$approxSimilarityJoin$3cd5fa47$1")) {
                    z = false;
                    break;
                }
                break;
            case 998559026:
                if (implMethodName.equals("lambda$approxSimilarityJoin$b5b16867$1")) {
                    z = true;
                    break;
                }
                break;
            case 1847267126:
                if (implMethodName.equals("lambda$approxNearestNeighbors$48f590d9$1")) {
                    z = 2;
                    break;
                }
                break;
        }
        // Second switch: verify the full recorded signature before returning the lambda.
        // Method kind 6 presumably corresponds to MethodHandleInfo.REF_invokeStatic — confirm.
        switch (z) {
            case false:
                // Join function of approxSimilarityJoin: builds (left[0], right[0], left[1], right[1]).
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/common/functions/JoinFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("join") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/flink/ml/feature/lsh/LSHModel") && serializedLambda.getImplMethodSignature().equals("(Lorg/apache/flink/types/Row;Lorg/apache/flink/types/Row;)Lorg/apache/flink/types/Row;")) {
                    return (row, row2) -> {
                        return Row.of(new Object[]{row.getField(0), row2.getField(0), row.getField(1), row2.getField(1)});
                    };
                }
                break;
            case true:
                // Reduce function of approxSimilarityJoin: keeps the first of the two rows.
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/common/functions/ReduceFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("reduce") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/flink/ml/feature/lsh/LSHModel") && serializedLambda.getImplMethodSignature().equals("(Lorg/apache/flink/types/Row;Lorg/apache/flink/types/Row;)Lorg/apache/flink/types/Row;")) {
                    return (row3, row4) -> {
                        return row3;
                    };
                }
                break;
            case true:
                // Flat-map function of approxNearestNeighbors: emits every row of the list.
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/common/functions/FlatMapFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("flatMap") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Lorg/apache/flink/util/Collector;)V") && serializedLambda.getImplClass().equals("org/apache/flink/ml/feature/lsh/LSHModel") && serializedLambda.getImplMethodSignature().equals("(Ljava/util/List;Lorg/apache/flink/util/Collector;)V")) {
                    return (list2, collector) -> {
                        Iterator it = list2.iterator();
                        while (it.hasNext()) {
                            collector.collect((Row) it.next());
                        }
                    };
                }
                break;
        }
        // Unknown name or mismatching signature.
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
