package org.apache.flink.ml.feature.vectorindexer;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.Model;
import org.apache.flink.ml.common.broadcast.BroadcastUtils;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.common.param.HasHandleInvalid;
import org.apache.flink.ml.feature.vectorindexer.VectorIndexerModelData;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

/**
 * Model produced by a vector indexer.
 *
 * <p>For every input row, it copies the vector in the input column. It then replaces the value at
 * each indexed column (the keys of the model data's {@code categoryMaps}) with the category index
 * that the model data maps that value to. Finally, it appends the resulting vector as a new output
 * column.
 *
 * <p>Values that have no mapping are handled according to the {@code handleInvalid} parameter:
 * <ul>
 *   <li>{@code skip} drops the row;
 *   <li>{@code error} throws;
 *   <li>{@code keep} maps the value to the size of that column's category map.
 * </ul>
 */
public class VectorIndexerModel implements Model<VectorIndexerModel>, VectorIndexerModelParams<VectorIndexerModel> {

    /** Name under which the model data stream is broadcast to {@link FindIndex}. */
    private static final String BROADCAST_MODEL_KEY = "broadcastModelKey";

    private final Map<Param<?>, Object> paramMap = new HashMap<>();
    private Table modelDataTable;

    /**
     * Replaces the indexed columns of each input vector with their category indices and appends
     * the result as an extra field of the row.
     */
    private static class FindIndex extends RichFlatMapFunction<Row, Row> {
        private final String broadcastModelKey;
        private final String inputCol;
        private final String handleInvalid;

        /** Column index -> (feature value -> category index); loaded lazily from the broadcast. */
        private Map<Integer, Map<Double, Integer>> categoryMaps;

        /**
         * @param broadcastModelKey name of the broadcast variable holding the model data
         * @param inputCol name of the vector column to transform
         * @param handleInvalid strategy for values without a mapping (see {@link HasHandleInvalid})
         */
        public FindIndex(String broadcastModelKey, String inputCol, String handleInvalid) {
            this.broadcastModelKey = broadcastModelKey;
            this.inputCol = inputCol;
            this.handleInvalid = handleInvalid;
        }

        @Override
        public void flatMap(Row row, Collector<Row> collector) {
            if (categoryMaps == null) {
                // The broadcast variable is only reachable at runtime, so it is read on first use.
                VectorIndexerModelData modelData =
                        (VectorIndexerModelData)
                                getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
                categoryMaps = modelData.categoryMaps;
            }

            // Work on a copy so the vector held by the input row is not mutated.
            // m185clone() is presumably the decompiler-assigned name of Vector#clone().
            Vector outputVector = ((Vector) row.getField(inputCol)).m185clone();
            for (Map.Entry<Integer, Map<Double, Integer>> entry : categoryMaps.entrySet()) {
                int columnId = entry.getKey();
                Integer categoryIndex =
                        getMapping(outputVector.get(columnId), entry.getValue(), handleInvalid);
                if (categoryIndex == null) {
                    // A null mapping means "skip": drop the whole row.
                    return;
                }
                outputVector.set(columnId, categoryIndex);
            }
            collector.collect(Row.join(row, Row.of(outputVector)));
        }
    }

    public VectorIndexerModel() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    /**
     * Appends an indexed copy of the input vector column to the single input table.
     *
     * @param inputs exactly one input table
     * @return a single table holding all input columns plus the output vector column
     */
    @Override // org.apache.flink.ml.api.AlgoOperator
    public Table[] transform(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);
        String inputCol = getInputCol();
        String outputCol = getOutputCol();
        String handleInvalid = getHandleInvalid();
        StreamTableEnvironment tableEnvironment = this.modelDataTable.getTableEnvironment();

        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
        RowTypeInfo outputTypeInfo =
                new RowTypeInfo(
                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE),
                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), outputCol));

        DataStream<Row> output =
                BroadcastUtils.withBroadcastStream(
                        Collections.singletonList(tableEnvironment.toDataStream(inputs[0])),
                        Collections.singletonMap(
                                BROADCAST_MODEL_KEY,
                                VectorIndexerModelData.getModelDataStream(this.modelDataTable)),
                        inputList -> {
                            @SuppressWarnings("unchecked") // the only input is the Row stream above
                            DataStream<Row> inputStream = (DataStream<Row>) inputList.get(0);
                            return inputStream.flatMap(
                                    new FindIndex(BROADCAST_MODEL_KEY, inputCol, handleInvalid),
                                    outputTypeInfo);
                        });
        return new Table[] {tableEnvironment.fromDataStream(output)};
    }

    @Override // org.apache.flink.ml.api.Stage
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
        ReadWriteUtils.saveModelData(
                VectorIndexerModelData.getModelDataStream(this.modelDataTable),
                path,
                new VectorIndexerModelData.ModelDataEncoder());
    }

    /**
     * Loads a model previously written by {@link #save(String)}.
     *
     * @param streamTableEnvironment environment used to read the model data
     * @param path directory the model was saved to
     * @return the restored model with its parameters and model data set
     * @throws IOException if reading the metadata fails
     */
    public static VectorIndexerModel load(StreamTableEnvironment streamTableEnvironment, String path)
            throws IOException {
        VectorIndexerModel model = (VectorIndexerModel) ReadWriteUtils.loadStageParam(path);
        return model.setModelData(
                ReadWriteUtils.loadModelData(
                        streamTableEnvironment,
                        path,
                        new VectorIndexerModelData.ModelDataDecoder()));
    }

    @Override // org.apache.flink.ml.param.WithParams
    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    /** Sets the model data; only the first table is used. */
    @Override // org.apache.flink.ml.api.Model
    public VectorIndexerModel setModelData(Table... inputs) {
        this.modelDataTable = inputs[0];
        return this;
    }

    @Override // org.apache.flink.ml.api.Model
    public Table[] getModelData() {
        return new Table[] {this.modelDataTable};
    }

    /**
     * Looks up the category index of a feature value.
     *
     * @param feature the feature value to map
     * @param categoryMap mapping from seen feature values to category indices for one column
     * @param handleInvalid strategy for unseen values: {@link HasHandleInvalid#SKIP_INVALID},
     *     {@link HasHandleInvalid#ERROR_INVALID} or {@link HasHandleInvalid#KEEP_INVALID}
     * @return the mapped index; {@code categoryMap.size()} for an unseen value under "keep";
     *     {@code null} for an unseen value under "skip"
     * @throws RuntimeException if the value is unseen and {@code handleInvalid} is "error"
     * @throws UnsupportedOperationException if {@code handleInvalid} is not a known strategy
     */
    public static Integer getMapping(
            double feature, Map<Double, Integer> categoryMap, String handleInvalid) {
        if (categoryMap.containsKey(feature)) {
            return categoryMap.get(feature);
        }
        switch (handleInvalid) {
            case HasHandleInvalid.SKIP_INVALID:
                return null;
            case HasHandleInvalid.ERROR_INVALID:
                throw new RuntimeException(
                        "The input contains unseen double: "
                                + feature
                                + ". See "
                                + HANDLE_INVALID
                                + " parameter for more options.");
            case HasHandleInvalid.KEEP_INVALID:
                // Unseen values share one extra bucket placed after all known categories.
                return categoryMap.size();
            default:
                throw new UnsupportedOperationException(
                        "Unsupported " + HANDLE_INVALID + "type: " + handleInvalid);
        }
    }
}
