package io.dingodb.exec.operator;

import io.dingodb.codec.CodecService;
import io.dingodb.codec.KeyValueCodec;
import io.dingodb.common.CommonId;
import io.dingodb.common.profile.OperatorProfile;
import io.dingodb.common.store.KeyValue;
import io.dingodb.common.type.ListType;
import io.dingodb.common.type.TupleMapping;
import io.dingodb.common.util.Optional;
import io.dingodb.common.vector.TxnVectorSearchResponse;
import io.dingodb.common.vector.VectorSearchResponse;
import io.dingodb.exec.Services;
import io.dingodb.exec.dag.Vertex;
import io.dingodb.exec.fun.vector.VectorCosineDistanceFun;
import io.dingodb.exec.fun.vector.VectorIPDistanceFun;
import io.dingodb.exec.fun.vector.VectorL2DistanceFun;
import io.dingodb.exec.operator.params.TxnPartVectorParam;
import io.dingodb.meta.entity.Column;
import io.dingodb.partition.DingoPartitionServiceProvider;
import io.dingodb.partition.PartitionService;
import io.dingodb.store.api.StoreService;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/dingodb/exec/operator/TxnPartVectorOperator.class */
/**
 * Source operator for a transactional vector search on a single partition.
 *
 * <p>Every search hit becomes one output tuple whose last slot holds the hit's distance. In look-up mode
 * ({@code TxnPartVectorParam#isLookUp()}) the row is re-read by the hit's key: from
 * {@link TxnGetByKeysOperator#getLocalStore} when that yields rows, otherwise via {@code txnGet} at the scan
 * timestamp. Otherwise the tuple is assembled directly from the table key/value carried in the search response,
 * according to the param's result selection.
 */
public class TxnPartVectorOperator extends FilterProjectSourceOperator {
    private static final Logger log = LoggerFactory.getLogger(TxnPartVectorOperator.class);
    public static final TxnPartVectorOperator INSTANCE = new TxnPartVectorOperator();

    /**
     * Runs the vector search for the vertex's partition and materializes every hit as an output tuple.
     *
     * @param vertex vertex whose param is a {@link TxnPartVectorParam}
     * @return iterator over the fully materialized result tuples
     */
    @Override
    protected Iterator<Object[]> createSourceIterator(Vertex vertex) {
        TxnPartVectorParam param = (TxnPartVectorParam) vertex.getParam();
        OperatorProfile profile = param.getProfile("partVector");
        long start = System.currentTimeMillis();
        // Codec for the table key/value pairs carried inside each vector search response.
        KeyValueCodec responseCodec = CodecService.getDefault().createKeyValueCodec(
            param.getTable().version, param.getTableDataSchema(), param.tableDataKeyMapping());
        List<VectorSearchResponse> responses = Services.KV_STORE
            .getInstance(param.getTableId(), param.getPartId())
            .vectorSearch(param.getScanTs(), param.getIndexId(), param.getFloatArray(), param.getTopN(),
                param.getParameterMap(), param.getCoprocessor());
        List<Object[]> rows = param.isLookUp()
            ? lookUpRows(vertex, param, responseCodec, responses)
            : coveringRows(param, responseCodec, responses);
        profile.incrTime(start);
        return rows.iterator();
    }

    /**
     * Builds output rows in look-up mode by re-reading each hit's row by its key.
     *
     * <p>Rows returned by {@link TxnGetByKeysOperator#getLocalStore} (presumably rows written earlier in the
     * current transaction — confirm) get their distance recomputed from the stored vector. Otherwise the row is
     * read with {@code txnGet}; its distance and vector come from the search response, and the non-primary,
     * non-list columns are copied over from the response's own table key/value.
     */
    private static List<Object[]> lookUpRows(
        Vertex vertex,
        TxnPartVectorParam param,
        KeyValueCodec responseCodec,
        List<VectorSearchResponse> responses
    ) {
        List<Object[]> rows = new ArrayList<>();
        Map<Integer, Integer> vecPriIdxMapping = getVecPriIdxMapping(param);
        int vectorIndex = param.getVectorIndex();
        String distanceType = param.getDistanceType();
        // The partition strategy is a table property, so resolve the service once rather than per hit.
        PartitionService partitionService = PartitionService.getService(
            Optional.ofNullable(param.getTable().getPartitionStrategy())
                .orElse(DingoPartitionServiceProvider.RANGE_FUNC_NAME));
        for (VectorSearchResponse response : responses) {
            TxnVectorSearchResponse txnResponse = (TxnVectorSearchResponse) response;
            byte[] key = txnResponse.getKey();
            // Private copy of the key that gets re-stamped with the target partition id below;
            // the response's own key is left untouched and reused for txnGet.
            byte[] localKey = Arrays.copyOf(key, key.length);
            CommonId partId = partitionService.calcPartId(key, param.getDistributions());
            CodecService.getDefault().setId(localKey, partId.domain);
            Iterator<Object[]> localRows = TxnGetByKeysOperator.getLocalStore(
                partId, param.getCodec(), localKey, param.getTableId(),
                vertex.getTask().getTxnId(), partId.encode(), vertex.getTask().getTransactionType());
            if (localRows != null) {
                while (localRows.hasNext()) {
                    Object[] row = localRows.next();
                    fillDistance(row, vectorIndex, distanceType, param);
                    rows.add(row);
                }
                continue;
            }
            KeyValue stored = StoreService.getDefault()
                .getInstance(param.getTableId(), partId)
                .txnGet(param.getScanTs(), key, param.getTimeOut());
            if (stored == null || stored.getValue() == null) {
                continue;
            }
            // Decoded only on this path: the response's table columns are needed solely to fill the mapping below.
            Object[] responseRow = responseCodec.decode(
                new KeyValue(txnResponse.getTableKey(), txnResponse.getTableVal()));
            Object[] row = param.getCodec().decode(stored);
            row[row.length - 1] = Float.valueOf(response.getDistance());
            row[vectorIndex] = response.getFloatValues();
            vecPriIdxMapping.forEach((from, to) -> row[to] = responseRow[from]);
            rows.add(row);
        }
        return rows;
    }

    /**
     * Builds output rows without a look-up, projecting the response's table key/value through the result
     * selection and appending the search distance as the last column.
     */
    private static List<Object[]> coveringRows(
        TxnPartVectorParam param,
        KeyValueCodec responseCodec,
        List<VectorSearchResponse> responses
    ) {
        TupleMapping sourceSelection = mapping2VecSelection(param);
        TupleMapping resultSelection = param.getResultSelection();
        // One slot per table column plus the trailing distance slot.
        int rowWidth = param.getTable().columns.size() + 1;
        List<Object[]> rows = new ArrayList<>(responses.size());
        for (VectorSearchResponse response : responses) {
            TxnVectorSearchResponse txnResponse = (TxnVectorSearchResponse) response;
            Object[] responseRow = responseCodec.decode(
                new KeyValue(txnResponse.getTableKey(), txnResponse.getTableVal()));
            Object[] row = new Object[rowWidth];
            for (int i = 0; i < resultSelection.size(); i++) {
                row[resultSelection.get(i)] = responseRow[sourceSelection.get(i)];
            }
            // Written last so it replaces whatever placeholder the loop copied into the distance slot.
            row[rowWidth - 1] = Float.valueOf(response.getDistance());
            rows.add(row);
        }
        return rows;
    }

    /**
     * Stores into the row's last slot the distance between its vector column and the query vector.
     *
     * <p>The slot is set to {@code 0.0f} when the vector index is out of range, no distance type is configured,
     * or the vector column does not hold a {@link List}. The distance type is matched by substring
     * ("L2", "INNER_PRODUCT", "COSINE"); an unrecognized type also yields {@code 0.0f}. The value is always a
     * {@link Float}, matching every other place this class fills the distance slot.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static void fillDistance(Object[] row, int vectorIndex, String distanceType, TxnPartVectorParam param) {
        int distanceSlot = row.length - 1;
        if (vectorIndex >= row.length || distanceType == null) {
            row[distanceSlot] = Float.valueOf(0.0f);
            return;
        }
        Object vector = row[vectorIndex];
        if (!(vector instanceof List)) {
            row[distanceSlot] = Float.valueOf(0.0f);
            return;
        }
        // Raw List mirrors how the distance functions are invoked elsewhere; element type is not checked here.
        List queryVector = Arrays.asList(param.getFloatArray());
        float distance = 0.0f;
        if (distanceType.contains("L2")) {
            distance = (float) VectorL2DistanceFun.l2DistanceCombine((List) vector, queryVector);
        } else if (distanceType.contains("INNER_PRODUCT")) {
            distance = (float) VectorIPDistanceFun.innerProduct((List) vector, queryVector);
        } else if (distanceType.contains("COSINE")) {
            distance = VectorCosineDistanceFun.cosine((List) vector, queryVector);
        }
        row[distanceSlot] = Float.valueOf(distance);
    }

    /**
     * Maps positions in the index's table-data columns to positions in the full table row.
     *
     * <p>Primary-key columns and {@link ListType} columns are skipped. The key is the index into
     * {@code getTableDataColList()}, and the value is that column's index in {@code getTable().getColumns()}.
     */
    private static Map<Integer, Integer> getVecPriIdxMapping(TxnPartVectorParam txnPartVectorParam) {
        List<Column> dataColumns = txnPartVectorParam.getTableDataColList();
        List<Column> tableColumns = txnPartVectorParam.getTable().getColumns();
        Map<Integer, Integer> mapping = new HashMap<>();
        for (int i = 0; i < dataColumns.size(); i++) {
            Column column = dataColumns.get(i);
            if (!column.isPrimary() && !(column.type instanceof ListType)) {
                mapping.put(i, tableColumns.indexOf(column));
            }
        }
        return mapping;
    }

    /**
     * For each result-selection entry, finds the position of the same-named column (case-insensitive) in the
     * index's table-data column list.
     *
     * <p>The entry equal to the table's column count is the distance column, which has no source. Its slot is
     * left as 0 and the value copied from it is overwritten with the real distance when rows are built.
     *
     * @throws RuntimeException "not found vector selection" when a selected column has no counterpart in the
     *     index table, or the matched index column is absent from the table-data column list
     */
    private static TupleMapping mapping2VecSelection(TxnPartVectorParam txnPartVectorParam) {
        TupleMapping resultSelection = txnPartVectorParam.getResultSelection();
        List<Column> tableColumns = txnPartVectorParam.getTable().getColumns();
        List<Column> tableDataColList = txnPartVectorParam.getTableDataColList();
        int distanceIndex = txnPartVectorParam.getTable().columns.size();
        int[] sourceIndexes = new int[resultSelection.size()];
        for (int i = 0; i < sourceIndexes.length; i++) {
            int target = resultSelection.get(i);
            if (target == distanceIndex) {
                continue;
            }
            String name = tableColumns.get(target).getName();
            sourceIndexes[i] = txnPartVectorParam.getIndexTable().getColumns().stream()
                .filter(column -> column.getName().equalsIgnoreCase(name))
                .findFirst()
                .map(tableDataColList::indexOf)
                // indexOf reports a miss as -1; treat it as "not found" instead of failing later on array access.
                .filter(index -> index >= 0)
                .orElseThrow(() -> new RuntimeException("not found vector selection"));
        }
        return TupleMapping.of(sourceIndexes);
    }
}
