package dev.langchain4j.community.store.embedding.duckdb;

import static java.lang.String.format;

import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.comparison.*;
import dev.langchain4j.store.embedding.filter.logical.And;
import dev.langchain4j.store.embedding.filter.logical.Not;
import dev.langchain4j.store.embedding.filter.logical.Or;
import java.util.AbstractMap;
import java.util.Collection;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class DuckDBMetadataFilterMapper {

    static final Map<Class<?>, String> SQL_TYPE_MAP = Stream.of(
                    new AbstractMap.SimpleEntry<>(Integer.class, "int"),
                    new AbstractMap.SimpleEntry<>(Long.class, "bigint"),
                    new AbstractMap.SimpleEntry<>(Float.class, "float"),
                    new AbstractMap.SimpleEntry<>(Double.class, "double"),
                    new AbstractMap.SimpleEntry<>(String.class, "text"),
                    new AbstractMap.SimpleEntry<>(UUID.class, "uuid"),
                    new AbstractMap.SimpleEntry<>(Boolean.class, "boolean"),
                    new AbstractMap.SimpleEntry<>(Object.class, "text"))
            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));

    public String map(Filter filter) {
        if (filter instanceof IsEqualTo eq) {
            return mapEqual(eq);
        } else if (filter instanceof IsNotEqualTo neq) {
            return mapNotEqual(neq);
        } else if (filter instanceof IsGreaterThan gt) {
            return mapGreaterThan(gt);
        } else if (filter instanceof IsGreaterThanOrEqualTo gte) {
            return mapGreaterThanOrEqual(gte);
        } else if (filter instanceof IsLessThan lt) {
            return mapLessThan(lt);
        } else if (filter instanceof IsLessThanOrEqualTo lte) {
            return mapLessThanOrEqual(lte);
        } else if (filter instanceof IsIn in) {
            return mapIn(in);
        } else if (filter instanceof IsNotIn nin) {
            return mapNotIn(nin);
        } else if (filter instanceof And and) {
            return mapAnd(and);
        } else if (filter instanceof Not not) {
            return mapNot(not);
        } else if (filter instanceof Or or) {
            return mapOr(or);
        } else {
            throw new UnsupportedOperationException(
                    "Unsupported filter type: " + filter.getClass().getName());
        }
    }

    private String mapEqual(IsEqualTo isEqualTo) {
        String key = formatKey(isEqualTo.key(), isEqualTo.comparisonValue().getClass());
        return format("%s is not null and %s = %s", key, key, formatValue(isEqualTo.comparisonValue()));
    }

    private String mapNotEqual(IsNotEqualTo isNotEqualTo) {
        String key =
                formatKey(isNotEqualTo.key(), isNotEqualTo.comparisonValue().getClass());
        return format("%s is null or %s != %s", key, key, formatValue(isNotEqualTo.comparisonValue()));
    }

    private String mapGreaterThan(IsGreaterThan isGreaterThan) {
        return format(
                "%s > %s",
                formatKey(isGreaterThan.key(), isGreaterThan.comparisonValue().getClass()),
                formatValue(isGreaterThan.comparisonValue()));
    }

    private String mapGreaterThanOrEqual(IsGreaterThanOrEqualTo isGreaterThanOrEqualTo) {
        return format(
                "%s >= %s",
                formatKey(
                        isGreaterThanOrEqualTo.key(),
                        isGreaterThanOrEqualTo.comparisonValue().getClass()),
                formatValue(isGreaterThanOrEqualTo.comparisonValue()));
    }

    private String mapLessThan(IsLessThan isLessThan) {
        return format(
                "%s < %s",
                formatKey(isLessThan.key(), isLessThan.comparisonValue().getClass()),
                formatValue(isLessThan.comparisonValue()));
    }

    private String mapLessThanOrEqual(IsLessThanOrEqualTo isLessThanOrEqualTo) {
        return format(
                "%s <= %s",
                formatKey(
                        isLessThanOrEqualTo.key(),
                        isLessThanOrEqualTo.comparisonValue().getClass()),
                formatValue(isLessThanOrEqualTo.comparisonValue()));
    }

    private String mapIn(IsIn isIn) {
        return format("%s in %s", formatKeyAsString(isIn.key()), formatValuesAsString(isIn.comparisonValues()));
    }

    private String mapNotIn(IsNotIn isNotIn) {
        String key = formatKeyAsString(isNotIn.key());
        return format("%s is null or %s not in %s", key, key, formatValuesAsString(isNotIn.comparisonValues()));
    }

    private String mapAnd(And and) {
        return format("%s and %s", map(and.left()), map(and.right()));
    }

    private String mapNot(Not not) {
        return format("not(%s)", map(not.expression()));
    }

    private String mapOr(Or or) {
        return format("(%s or %s)", map(or.left()), map(or.right()));
    }

    String formatKey(String key, Class<?> valueType) {
        return format("(metadata->>'%s')::%s", key, SQL_TYPE_MAP.get(valueType));
    }

    String formatKeyAsString(String key) {
        return format("(metadata->>'%s')", key);
    }

    String formatValue(Object value) {
        if (value instanceof String || value instanceof UUID) {
            return "'" + value + "'";
        } else {
            return value.toString();
        }
    }

    String formatValuesAsString(Collection<?> values) {
        return "(" + values.stream().map(v -> format("'%s'", v)).collect(Collectors.joining(",")) + ")";
    }
}
