/*
 * Decompiled with CFR 0.152.
 */
package tech.ydb.yoj.repository.ydb.statement;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
import com.google.common.collect.Streams;
import java.lang.reflect.Type;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import tech.ydb.proto.ValueProtos;
import tech.ydb.yoj.databind.expression.FilterExpression;
import tech.ydb.yoj.databind.expression.OrderExpression;
import tech.ydb.yoj.databind.schema.ObjectSchema;
import tech.ydb.yoj.databind.schema.Schema;
import tech.ydb.yoj.repository.db.Entity;
import tech.ydb.yoj.repository.db.EntityIdSchema;
import tech.ydb.yoj.repository.db.EntitySchema;
import tech.ydb.yoj.repository.db.TableDescriptor;
import tech.ydb.yoj.repository.ydb.statement.MultipleVarsYqlStatement;
import tech.ydb.yoj.repository.ydb.statement.PredicateStatement;
import tech.ydb.yoj.repository.ydb.statement.Statement;
import tech.ydb.yoj.repository.ydb.statement.YqlStatementParam;
import tech.ydb.yoj.repository.ydb.yql.YqlListingQuery;
import tech.ydb.yoj.repository.ydb.yql.YqlPredicate;
import tech.ydb.yoj.repository.ydb.yql.YqlType;

/**
 * Selects rows whose key columns match any of a given collection of entity ids or global-index keys.
 *
 * <p>The generated YQL joins the input list {@code AS_TABLE($Input)} (alias {@code k}) with the table
 * (alias {@code t}, read through {@code VIEW <index>} when looking up by index) on the key fields that
 * are non-null in the inputs. An optional filter, {@code ORDER BY} and {@code LIMIT} are appended.
 *
 * @param <IN>     statement input type; both factories produce a {@code Set} of ids/keys
 * @param <T>      entity type stored in the table
 * @param <RESULT> type each selected row is read into
 */
public final class FindInStatement<IN, T extends Entity<T>, RESULT>
        extends MultipleVarsYqlStatement<IN, T, RESULT> {
    private static final Logger log = LoggerFactory.getLogger(FindInStatement.class);

    /** Global index to read through ({@code VIEW}), or {@code null} when looking up by primary key. */
    private final String indexName;
    /** Schema used to flatten each input id/key into column values. */
    private final Schema<?> keySchema;
    /** Column names the table is joined on (the fields that are non-null in the inputs). */
    private final Set<String> keyFields;
    /** Optional filter, or {@code null}. */
    private final PredicateClause<T> predicate;
    /** Optional ordering, or {@code null}. */
    private final OrderExpression<T> orderBy;
    /** Optional row limit, or {@code null}. */
    private final Integer limit;

    /**
     * Creates a statement that looks entities up by their ids.
     *
     * <p>Every id must have the same set of non-null fields; that set becomes the join key. A key
     * that is not a prefix of the primary key is accepted, but a warning about a full scan is logged.
     *
     * @param tableDescriptor table to read from
     * @param schema          entity schema of the table
     * @param resultSchema    schema each selected row is read into
     * @param ids             ids to look up; must be non-empty
     * @param filter          optional additional filter
     * @param orderBy         optional ordering; when {@code resultSchema} differs from {@code schema},
     *                        every ordered column must be present in {@code resultSchema}
     * @param limit           optional maximum number of rows
     * @throws IllegalArgumentException if {@code ids} is empty, the ids differ in which fields are
     *                                  null, all id fields are null, or {@code orderBy} is invalid
     */
    public static <ID extends Entity.Id<T>, T extends Entity<T>, RESULT> FindInStatement<Set<ID>, T, RESULT> from(
            TableDescriptor<T> tableDescriptor,
            EntitySchema<T> schema,
            Schema<RESULT> resultSchema,
            Iterable<ID> ids,
            @Nullable FilterExpression<T> filter,
            @Nullable OrderExpression<T> orderBy,
            @Nullable Integer limit
    ) {
        var keySchema = schema.getIdSchema();
        Set<String> keyFields = collectKeyFieldsFromIds(keySchema, ids);
        return new FindInStatement<>(tableDescriptor, schema, resultSchema, keySchema, keyFields, null,
                filter, orderBy, limit);
    }

    /**
     * Creates a statement that looks entities up by values of the global index {@code indexName}.
     *
     * <p>The key schema is derived from the runtime class of the first key. The non-null fields, which
     * must be the same for every key, must form a prefix of the index's fields. Each must also have the
     * same Java type as the entity field of the same name.
     *
     * @param tableDescriptor table to read from
     * @param schema          entity schema of the table
     * @param resultSchema    schema each selected row is read into
     * @param indexName       name of a global index declared in {@code schema}
     * @param keys            index keys to look up; must be non-empty
     * @param filter          optional additional filter
     * @param orderBy         optional ordering; when {@code resultSchema} differs from {@code schema},
     *                        every ordered column must be present in {@code resultSchema}
     * @param limit           optional maximum number of rows
     * @throws IllegalArgumentException if the keys are empty or inconsistent, the index does not exist,
     *                                  the key fields are not an index prefix or have mismatched types,
     *                                  or {@code orderBy} is invalid
     */
    public static <K, T extends Entity<T>, RESULT> FindInStatement<Set<K>, T, RESULT> from(
            TableDescriptor<T> tableDescriptor,
            EntitySchema<T> schema,
            Schema<RESULT> resultSchema,
            String indexName,
            Iterable<K> keys,
            @Nullable FilterExpression<T> filter,
            @Nullable OrderExpression<T> orderBy,
            @Nullable Integer limit
    ) {
        Schema<K> keySchema = getKeySchemaFromValues(keys);
        Set<String> keyFields = collectKeyFieldsFromKeys(tableDescriptor, schema, indexName, keySchema, keys);
        return new FindInStatement<>(tableDescriptor, schema, resultSchema, keySchema, keyFields, indexName,
                filter, orderBy, limit);
    }

    private <PARAMS> FindInStatement(
            TableDescriptor<T> tableDescriptor,
            EntitySchema<T> schema,
            Schema<RESULT> resultSchema,
            Schema<PARAMS> keySchema,
            Set<String> keyFields,
            @Nullable String indexName,
            @Nullable FilterExpression<T> filter,
            @Nullable OrderExpression<T> orderBy,
            @Nullable Integer limit
    ) {
        super(tableDescriptor, schema, resultSchema);
        this.indexName = indexName;
        this.orderBy = orderBy;
        this.limit = limit;
        this.keySchema = keySchema;
        this.keyFields = keyFields;
        this.predicate = filter == null
                ? null
                : new PredicateClause<>(tableDescriptor, schema, YqlListingQuery.toYqlPredicate(filter));
        validateOrderByFields();
    }

    /**
     * Determines the join fields for an id lookup: the set of non-null id fields shared by all ids.
     *
     * @throws IllegalArgumentException if {@code ids} is empty, the ids differ in which fields are
     *                                  non-null, or no field is non-null
     */
    private static <T extends Entity<T>> Set<String> collectKeyFieldsFromIds(
            Schema<Entity.Id<T>> idSchema,
            Iterable<? extends Entity.Id<T>> ids
    ) {
        Preconditions.checkArgument(!Iterables.isEmpty(ids), "ids should be non empty");

        Set<Set<String>> nonNullFieldsSet = Streams.stream(ids)
                .map(id -> nonNullKeyFieldNames(idSchema, id))
                .collect(Collectors.toUnmodifiableSet());
        Preconditions.checkArgument(nonNullFieldsSet.size() == 1, "ids should have nulls in the same fields");

        Set<String> keyFields = Iterables.getOnlyElement(nonNullFieldsSet);
        // Checked on the key set itself: for non-empty ids nonNullFieldsSet always has one element, so an
        // all-null id would otherwise slip through and produce an empty join condition.
        Preconditions.checkArgument(!keyFields.isEmpty(), "ids should have at least one non-null field");

        if (!isPrefixedFields(idSchema.flattenFieldNames(), keyFields)) {
            log.warn("FindIn(ids) not by the primary key prefix will result in a FullScan, PK: {}, query uses the fields: {}",
                    idSchema.flattenFieldNames(), keyFields);
        }
        return keyFields;
    }

    /**
     * Builds an {@link ObjectSchema} from the runtime class of the first key.
     *
     * @throws IllegalArgumentException if {@code keys} is empty or its first element is {@code null}
     */
    private static <V> Schema<V> getKeySchemaFromValues(Iterable<V> keys) {
        V key = Iterables.getFirst(keys, null);
        Preconditions.checkArgument(key != null, "keys should be non empty");
        // getClass() returns Class<? extends V>; the schema is used only to read values of type V.
        @SuppressWarnings("unchecked")
        Class<V> keyType = (Class<V>) key.getClass();
        return ObjectSchema.of(keyType);
    }

    /**
     * Determines the join fields for an index lookup and validates them against the index.
     *
     * @return the first {@code n} field names of the index, in index order, where {@code n} is the
     *         number of non-null key fields
     * @throws IllegalArgumentException if the keys differ in which fields are non-null, no field is
     *                                  non-null, the index does not exist, the key fields are not a
     *                                  prefix of the index, or a key field's type differs from the
     *                                  entity field's type
     */
    private static <E extends Entity<E>, K> Set<String> collectKeyFieldsFromKeys(
            TableDescriptor<E> tableDescriptor,
            Schema<E> entitySchema,
            String indexName,
            Schema<K> keySchema,
            Iterable<K> keys
    ) {
        Set<Set<String>> nonNullFieldsSet = Streams.stream(keys)
                .map(key -> nonNullKeyFieldNames(keySchema, key))
                .collect(Collectors.toUnmodifiableSet());
        Preconditions.checkArgument(nonNullFieldsSet.size() == 1, "keys should have nulls in the same fields");

        Set<String> keyFields = Iterables.getOnlyElement(nonNullFieldsSet);
        // See collectKeyFieldsFromIds: emptiness must be checked on the key set, not on nonNullFieldsSet.
        Preconditions.checkArgument(!keyFields.isEmpty(), "keys should have at least one non-null field");

        Schema.Index globalIndex = entitySchema.getGlobalIndexes().stream()
                .filter(index -> indexName.equals(index.getIndexName()))
                .findAny()
                .orElseThrow(() -> new IllegalArgumentException(
                        "Table `%s` doesn't have index `%s`".formatted(tableDescriptor.toDebugString(), indexName)));

        Set<String> indexKeys = Set.copyOf(globalIndex.getFieldNames());
        Set<String> missingInIndexKeys = Sets.difference(keyFields, indexKeys);
        Preconditions.checkArgument(missingInIndexKeys.isEmpty(),
                "Index `%s` of table `%s` doesn't contain key(s): [%s]".formatted(
                        indexName, tableDescriptor.toDebugString(), String.join(", ", missingInIndexKeys)));
        Preconditions.checkArgument(isPrefixedFields(globalIndex.getFieldNames(), keyFields),
                "FindIn(keys) is allowed only by the prefix of the index key fields, index key: %s, query uses the fields: %s"
                        .formatted(globalIndex.getFieldNames(), keyFields));

        Map<String, Type> keyFieldTypes = getKeyFieldTypeMap(keySchema, keyFields);
        Map<String, Type> entityFieldTypes = getKeyFieldTypeMap(entitySchema, keyFields);
        for (Map.Entry<String, Type> keyFieldType : keyFieldTypes.entrySet()) {
            Type keyType = keyFieldType.getValue();
            Type entityFieldType = entityFieldTypes.get(keyFieldType.getKey());
            // Compare from the key side: entityFieldType may be null when the entity schema has no
            // flattened field of that name, which must surface as an IllegalArgumentException, not an NPE.
            Preconditions.checkArgument(keyType.equals(entityFieldType),
                    "Table `%s` has column `%s` of type `%s`, but corresponding key field is `%s`".formatted(
                            tableDescriptor.toDebugString(), keyFieldType.getKey(), entityFieldType, keyType));
        }

        // Return the fields in index order so the join condition follows the index's column order.
        return globalIndex.getFieldNames().stream()
                .limit(keyFields.size())
                .collect(Collectors.toCollection(LinkedHashSet::new));
    }

    /**
     * Returns the field names present in {@code schema.flatten(key)}. NOTE(review): relies on
     * {@code flatten} omitting null-valued fields, as the method name implies; confirm in {@code Schema}.
     */
    private static <V> Set<String> nonNullKeyFieldNames(Schema<V> schema, V key) {
        return schema.flatten(key).keySet();
    }

    /**
     * Returns {@code true} if the first {@code fields.size()} entries of {@code keyFields} are all in
     * {@code fields}, i.e. {@code fields} covers a leading prefix of the key. A field set larger than the
     * key cannot be a prefix and yields {@code false} rather than an {@link IndexOutOfBoundsException}.
     */
    private static boolean isPrefixedFields(List<String> keyFields, Set<String> fields) {
        if (fields.size() > keyFields.size()) {
            return false;
        }
        return fields.containsAll(keyFields.subList(0, fields.size()));
    }

    /** Maps each flattened field of {@code schema} whose name is in {@code keyFields} to its Java type. */
    private static Map<String, Type> getKeyFieldTypeMap(Schema<?> schema, Set<String> keyFields) {
        return schema.flattenFields().stream()
                .filter(f -> keyFields.contains(f.getName()))
                .collect(Collectors.toUnmodifiableMap(Schema.JavaField::getName, Schema.JavaField::getType));
    }

    /**
     * When a distinct result schema is used, checks that every flattened column in {@code ORDER BY} is
     * one of the result schema's columns. NOTE(review): presumably because {@code ORDER BY} can only
     * reference the projected output columns; confirm against YQL semantics.
     *
     * @throws IllegalArgumentException if an ordered column is missing from the result schema
     */
    private void validateOrderByFields() {
        if (!hasOrderBy() || schema.equals(resultSchema)) {
            return;
        }
        Set<String> resultColumns = resultSchema.flattenFields().stream()
                .map(Schema.JavaField::getName)
                .collect(Collectors.toUnmodifiableSet());
        List<String> missingColumns = orderBy.getKeys().stream()
                .map(OrderExpression.SortKey::getField)
                .flatMap(Schema.JavaField::flatten)
                .map(Schema.JavaField::getName)
                .filter(column -> !resultColumns.contains(column))
                .toList();
        Preconditions.checkArgument(missingColumns.isEmpty(),
                "Result schema of '%s' does not contain field(s): [%s] by which the result is ordered: %s"
                        .formatted(resultSchema.getTypeName(), String.join(", ", missingColumns), orderBy));
    }

    @Override
    public Statement.QueryType getQueryType() {
        return Statement.QueryType.SELECT;
    }

    /**
     * Renders the full query. With a filter, the join is wrapped in a sub-select that re-exposes all
     * entity columns under their plain names, so the filter clause and the outer projection can refer
     * to them unqualified.
     */
    @Override
    public String getQuery(String tablespace) {
        return declarations()
                + "SELECT " + outNames() + "\n"
                + (hasPredicate() ? "FROM (\nSELECT " + allColumnNames() + "\n" : "")
                + "FROM AS_TABLE($Input) AS k\n"
                + "JOIN " + table(tablespace) + indexUsage() + " AS t\n"
                + "ON " + joinExpression() + "\n"
                + (hasPredicate() ? ")\n" : "")
                + predicateClause()
                + orderByClause()
                + limitClause();
    }

    /** Declares one required parameter per entity column that is part of the join key. */
    @Override
    public List<YqlStatementParam> getParams() {
        return schema.flattenFields().stream()
                .filter(c -> keyFields.contains(c.getName()))
                .map(c -> YqlStatementParam.required(YqlType.of(c), c.getName()))
                .toList();
    }

    /** Returns the input-list parameters, merged with the filter's parameters when a filter is set. */
    @Override
    public Map<String, ValueProtos.TypedValue> toQueryParameters(IN in) {
        Map<String, ValueProtos.TypedValue> inputParams = super.toQueryParameters(in);
        if (!hasPredicate()) {
            return inputParams;
        }
        // Explicit type witness: a bare ImmutableMap.builder() in a call chain infers <Object, Object>.
        return ImmutableMap.<String, ValueProtos.TypedValue>builder()
                .putAll(inputParams)
                .putAll(predicate.toQueryParameters())
                .build();
    }

    @Override
    protected String declarations() {
        return super.declarations() + predicateClauseDeclarations();
    }

    @Override
    protected String outNames() {
        return resultSchema.flattenFields().stream()
                .map(this::getOutName)
                .collect(Collectors.joining(", "));
    }

    /** All entity columns, aliased from {@code t}; used as the projection of the filtered sub-select. */
    private String allColumnNames() {
        return schema.flattenFields().stream()
                .map(this::getAliasedName)
                .collect(Collectors.joining(", "));
    }

    /** Outer projection item: plain name when selecting from the sub-select, else aliased from {@code t}. */
    private String getOutName(Schema.JavaField field) {
        return hasPredicate() ? escape(field.getName()) : getAliasedName(field);
    }

    private String getAliasedName(Schema.JavaField field) {
        String escapedName = escape(field.getName());
        return "t." + escapedName + " AS " + escapedName;
    }

    @Override
    protected Function<IN, Map<String, Object>> flattenInputVariables() {
        // keySchema is stored as Schema<?>; the values handed to this function are the ids/keys it was
        // built for. NOTE(review): assumed to be the elements of IN; confirm in MultipleVarsYqlStatement.
        @SuppressWarnings("unchecked")
        Schema<Object> inputSchema = (Schema<Object>) keySchema;
        return inputSchema::flatten;
    }

    private String indexUsage() {
        return isFindByIndex() ? " VIEW " + escape(indexName) : "";
    }

    /** Builds {@code t.<col> = k.<col>} for every key field, joined with {@code AND}. */
    private String joinExpression() {
        return keyFields.stream()
                .map(n -> "t.%1$s = k.%1$s".formatted(escape(n)))
                .collect(Collectors.joining(" AND "));
    }

    private String orderByClause() {
        return hasOrderBy() ? YqlListingQuery.toYqlOrderBy(orderBy).toFullYql(schema) + "\n" : "";
    }

    private String limitClause() {
        return hasLimit() ? "LIMIT " + limit + "\n" : "";
    }

    private String predicateClauseDeclarations() {
        return hasPredicate() ? predicate.declarations() : "";
    }

    private String predicateClause() {
        return hasPredicate() ? predicate.getClause() : "";
    }

    @Override
    public String toDebugString(IN in) {
        return "findIn(" + toDebugParams(in)
                + (isFindByIndex() ? " by index " + escape(indexName) : "")
                + (hasPredicate() ? ", filter [" + predicate.toDebugString() + "]" : "")
                + (hasOrderBy() ? ", orderBy [" + orderBy + "]" : "")
                + (hasLimit() ? ", limit [" + limit + "]" : "")
                + ")";
    }

    private boolean isFindByIndex() {
        return indexName != null;
    }

    private boolean hasLimit() {
        return limit != null;
    }

    private boolean hasOrderBy() {
        return orderBy != null;
    }

    private boolean hasPredicate() {
        return predicate != null;
    }

    /**
     * Adapter that reuses {@link PredicateStatement} to render a filter's WHERE-style clause, its
     * declarations and its parameters for embedding into the enclosing statement. Its own
     * {@link #getQuery(String)} returns only a placeholder {@code SELECT 1}.
     */
    private static class PredicateClause<T extends Entity<T>>
            extends PredicateStatement<Class<Void>, T, T> {
        private final YqlPredicate predicate;

        public PredicateClause(TableDescriptor<T> tableDescriptor, EntitySchema<T> schema, YqlPredicate predicate) {
            super(tableDescriptor, schema, schema, Void.class, ignored -> predicate);
            this.predicate = predicate;
        }

        @Override
        public Statement.QueryType getQueryType() {
            return Statement.QueryType.UNTYPED;
        }

        /** Renders the predicate against the entity schema with parameter names resolved, plus a newline. */
        public String getClause() {
            return resolveParamNames(predicate.toFullYql(schema)) + "\n";
        }

        @Override
        public String getQuery(String tablespace) {
            return "SELECT 1";
        }

        public String toDebugString() {
            return toDebugString(Void.TYPE);
        }

        public Map<String, ValueProtos.TypedValue> toQueryParameters() {
            return toQueryParameters(Void.TYPE);
        }

        @Override
        public String toDebugString(Class<Void> in) {
            return predicate.toString();
        }
    }
}

