package org.apache.druid.math.expr;

import com.google.common.collect.ImmutableMap;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.BooleanSupplier;
import java.util.function.DoubleSupplier;
import java.util.function.LongSupplier;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.antlr.v4.runtime.tree.xpath.XPath;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.NonnullPair;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.vector.ExprEvalVector;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.skife.jdbi.org.antlr.runtime.debug.DebugEventListener;

/**
 * Sanity checks that vectorized expression evaluation agrees with row-at-a-time evaluation.
 *
 * <p>Every expression built from the templates below is parsed once, then evaluated both through
 * {@link Expr#buildVectorized} and through {@link Expr#eval} on each row. It is checked against
 * {@link #NUM_ITERATIONS} randomized binding sets plus one sequential binding set, each holding
 * {@link #VECTOR_SIZE} rows. Values must match row by row. When the expression reports an output type,
 * the vector result must have that type, and so must every row result that is not a numeric null.
 */
public class VectorExprSanityTest extends InitializedNullHandlingTest {
    private static final Logger log = new Logger(VectorExprSanityTest.class);

    /** Number of randomized binding sets each expression is checked against. */
    private static final int NUM_ITERATIONS = 10;

    /** Number of rows in every generated binding set. */
    private static final int VECTOR_SIZE = 512;

    /** Type of every column identifier the expression templates may reference. */
    final Map<String, ExpressionType> types = ImmutableMap.<String, ExpressionType>builder()
            .put("l1", ExpressionType.LONG)
            .put("l2", ExpressionType.LONG)
            .put("d1", ExpressionType.DOUBLE)
            .put("d2", ExpressionType.DOUBLE)
            .put("s1", ExpressionType.STRING)
            .put("s2", ExpressionType.STRING)
            .put("boolString1", ExpressionType.STRING)
            .put("boolString2", ExpressionType.STRING)
            .build();

    /**
     * A {@link Expr.VectorInputBinding} backed by caller-supplied arrays, one per column.
     *
     * <p>A column that was never added reads as follows:
     * <ul>
     *   <li>its long and double vectors are zero-filled;</li>
     *   <li>its object vector is all {@code null};</li>
     *   <li>its null vector is filled with {@link NullHandling#sqlCompatible()};</li>
     *   <li>its type is {@code null}.</li>
     * </ul>
     * The current and maximum vector size are both the size passed to the constructor.
     */
    public static class SettableVectorInputBinding implements Expr.VectorInputBinding {
        private final int vectorSize;
        private final Map<String, boolean[]> nulls = new HashMap<>();
        private final Map<String, long[]> longs = new HashMap<>();
        private final Map<String, double[]> doubles = new HashMap<>();
        private final Map<String, Object[]> objects = new HashMap<>();
        private final Map<String, ExpressionType> types = new HashMap<>();

        /** Counter backing {@link #getCurrentVectorId()}. */
        private int id = 0;

        /**
         * @param vectorSize size reported by this binding; every column array added through the
         *                   type-specific {@code add*} methods must have exactly this length
         */
        public SettableVectorInputBinding(int vectorSize) {
            this.vectorSize = vectorSize;
        }

        /**
         * Records the type and null vector of a column. The values themselves are stored by the
         * type-specific {@code add*} methods.
         *
         * @return this binding, for chaining
         */
        public SettableVectorInputBinding addBinding(String name, ExpressionType type, boolean[] nullVector) {
            nulls.put(name, nullVector);
            types.put(name, type);
            return this;
        }

        /** Adds a LONG column with an all-{@code false} null vector. */
        public SettableVectorInputBinding addLong(String name, long[] values) {
            return addLong(name, values, new boolean[values.length]);
        }

        /** Adds a LONG column; {@code nullVector[i]} is reported as the null flag of row {@code i}. */
        public SettableVectorInputBinding addLong(String name, long[] values, boolean[] nullVector) {
            assert values.length == vectorSize : "expected " + vectorSize + " values for column " + name;
            longs.put(name, values);
            return addBinding(name, ExpressionType.LONG, nullVector);
        }

        /** Adds a DOUBLE column with an all-{@code false} null vector. */
        public SettableVectorInputBinding addDouble(String name, double[] values) {
            return addDouble(name, values, new boolean[values.length]);
        }

        /** Adds a DOUBLE column; {@code nullVector[i]} is reported as the null flag of row {@code i}. */
        public SettableVectorInputBinding addDouble(String name, double[] values, boolean[] nullVector) {
            assert values.length == vectorSize : "expected " + vectorSize + " values for column " + name;
            doubles.put(name, values);
            return addBinding(name, ExpressionType.DOUBLE, nullVector);
        }

        /**
         * Adds a STRING column. Its registered null vector is all {@code false}, so only {@code null}
         * entries inside {@code values} mark missing rows.
         */
        public SettableVectorInputBinding addString(String name, String[] values) {
            assert values.length == vectorSize : "expected " + vectorSize + " values for column " + name;
            objects.put(name, values);
            return addBinding(name, ExpressionType.STRING, new boolean[values.length]);
        }

        @Override
        @SuppressWarnings("unchecked")
        public <T> T[] getObjectVector(String name) {
            // Callers choose T; the stored arrays are String[] or the Object[] default.
            return (T[]) objects.getOrDefault(name, new Object[getCurrentVectorSize()]);
        }

        @Override
        public ExpressionType getType(String name) {
            return types.get(name);
        }

        @Override
        public long[] getLongVector(String name) {
            return longs.getOrDefault(name, new long[getCurrentVectorSize()]);
        }

        @Override
        public double[] getDoubleVector(String name) {
            return doubles.getOrDefault(name, new double[getCurrentVectorSize()]);
        }

        @Override
        @Nullable
        public boolean[] getNullVector(String name) {
            // Unknown columns are all-null in SQL-compatible mode and never null otherwise.
            final boolean[] defaultNulls = new boolean[getCurrentVectorSize()];
            Arrays.fill(defaultNulls, NullHandling.sqlCompatible());
            return nulls.getOrDefault(name, defaultNulls);
        }

        @Override
        public int getMaxVectorSize() {
            return vectorSize;
        }

        @Override
        public int getCurrentVectorSize() {
            return vectorSize;
        }

        /** Returns a new id on every call (0, 1, 2, ...). */
        @Override
        public int getCurrentVectorId() {
            return id++;
        }
    }

    @BeforeClass
    public static void setupTests() {
        ExpressionProcessing.initializeForStrictBooleansTests(true);
    }

    @AfterClass
    public static void teardownTests() {
        ExpressionProcessing.initializeForTests(null);
    }

    @Test
    public void testUnaryOperators() {
        testFunctions(types, new String[]{"%sd1", "%sl1"}, new String[]{"-"});
    }

    @Test
    public void testBinaryMathOperators() {
        final String[] lefts = {"d1", "d2", "l1", "l2", "1", "1.0", "nonexistent", "null", "s1"};
        final String[] rights = {"d1", "d2", "l1", "l2", "1", "1.0"};
        testFunctions(types, binaryTemplates(lefts, rights), new String[]{"+", "-", "*", "/", "^", "%"});
    }

    @Test
    public void testBinaryComparisonOperators() {
        final String[] lefts = {"d1", "d2", "l1", "l2", "1", "1.0", "s1", "s2", "nonexistent", "null"};
        final String[] rights = {"d1", "d2", "l1", "l2", "1", "1.0", "s1", "s2", "null"};
        testFunctions(types, binaryTemplates(lefts, rights), new String[]{">", ">=", "<", "<=", "==", "!="});
    }

    @Test
    public void testUnaryLogicOperators() {
        testFunctions(types, new String[]{"%sd1", "%sl1", "%sboolString1"}, new String[]{"!"});
    }

    @Test
    public void testBinaryLogicOperators() {
        testFunctions(types, new String[]{"d1 %s d2", "l1 %s l2", "boolString1 %s boolString2"}, new String[]{"&&", "||"});
    }

    @Test
    public void testBinaryOperatorTrees() {
        final String[] lefts = {"d1", "l1", "1", "1.0", "nonexistent", "null"};
        final String[] rights = {"d2", "l2", "2", "2.0"};
        // Each template keeps two "%s" slots for the operators, filled from the operator pairs below.
        final String[] templates = Arrays.stream(makeTemplateArgs(lefts, rights, rights))
                .map(args -> StringUtils.format("(%s %s %s) %s %s", args[0], "%s", args[1], "%s", args[2]))
                .toArray(String[]::new);
        final String[] operators = {"+", "-", "*", "/"};
        testFunctions(types, templates, makeTemplateArgs(operators, operators));
    }

    @Test
    public void testUnivariateFunctions() {
        testFunctions(
                types,
                new String[]{"%s(s1)", "%s(l1)", "%s(d1)", "%s(nonexistent)", "%s(null)"},
                new String[]{"parse_long", "isNull", "notNull"}
        );
    }

    @Test
    public void testUnivariateMathFunctions() {
        testFunctions(
                types,
                new String[]{"%s(l1)", "%s(d1)", "%s(pi())", "%s(null)"},
                new String[]{
                        "abs", "acos", "asin", "atan", "cbrt", "ceil", "cos", "cosh", "cot", "exp", "expm1",
                        "floor", "getExponent", "log", "log10", "log1p", "nextUp", "rint", "signum", "sin",
                        "sinh", "sqrt", "tan", "tanh", "toDegrees", "toRadians", "ulp", "bitwiseComplement",
                        "bitwiseConvertDoubleToLongBits", "bitwiseConvertLongBitsToDouble"
                }
        );
    }

    @Test
    public void testBivariateMathFunctions() {
        testFunctions(
                types,
                new String[]{
                        "%s(d1, d2)", "%s(d1, l1)", "%s(l1, d1)", "%s(l1, l2)", "%s(nonexistent, l1)",
                        "%s(nonexistent, d1)"
                },
                new String[]{
                        "atan2", "copySign", "div", "hypot", "remainder", "max", "min", "nextAfter", "scalb",
                        "pow", "bitwiseAnd", "bitwiseOr", "bitwiseXor", "bitwiseShiftLeft", "bitwiseShiftRight"
                }
        );
    }

    @Test
    public void testSymmetricalBivariateFunctions() {
        testFunctions(
                types,
                new String[]{
                        "%s(d1, d2)", "%s(l1, l2)", "%s(s1, s2)", "%s(nonexistent, l1)", "%s(nonexistent, d1)",
                        "%s(nonexistent, s1)"
                },
                new String[]{"nvl"}
        );
    }

    @Test
    public void testCast() {
        testFunctions(
                types,
                new String[]{"cast(%s, %s)"},
                makeTemplateArgs(new String[]{"d1", "l1", "s1"}, new String[]{"'STRING'", "'LONG'", "'DOUBLE'"})
        );
    }

    @Test
    public void testStringFns() {
        testExpression("s1 + s2", types);
        testExpression("s1 + '-' + s2", types);
        testExpression("concat(s1, s2)", types);
        testExpression("concat(s1,'-',s2,'-',l1,'-',d1)", types);
    }

    /**
     * Builds one "left %s right" template for every (left, right) pair. The middle "%s" is kept
     * literally, so it can later be filled with an operator.
     */
    private static String[] binaryTemplates(String[] lefts, String[] rights) {
        return Arrays.stream(makeTemplateArgs(lefts, rights))
                .map(args -> StringUtils.format("%s %s %s", args[0], "%s", args[1]))
                .toArray(String[]::new);
    }

    /** Checks every template formatted with every single argument (the template has one "%s"). */
    static void testFunctions(Map<String, ExpressionType> types, String[] templates, String[] args) {
        for (String template : templates) {
            for (String arg : args) {
                testExpression(StringUtils.format(template, arg), types);
            }
        }
    }

    /** Checks every template formatted with every argument tuple (one "%s" per tuple element). */
    static void testFunctions(Map<String, ExpressionType> types, String[] templates, String[][] args) {
        for (String template : templates) {
            for (String[] argTuple : args) {
                // Explicit Object[] cast: the tuple is the varargs array, not a single argument.
                testExpression(StringUtils.format(template, (Object[]) argTuple), types);
            }
        }
    }

    /**
     * Parses {@code expr} once. It is then checked against {@link #NUM_ITERATIONS} randomized binding
     * sets and one sequential binding set.
     */
    static void testExpression(String expr, Map<String, ExpressionType> types) {
        log.debug("[%s]", expr);
        final Expr parsed = Parser.parse(expr, ExprMacroTable.nil());
        for (int iteration = 0; iteration < NUM_ITERATIONS; iteration++) {
            testExpressionWithBindings(expr, parsed, makeRandomizedBindings(VECTOR_SIZE, types));
        }
        testExpressionWithBindings(expr, parsed, makeSequentialBinding(VECTOR_SIZE, types));
    }

    /**
     * Asserts that {@code parsed} can be vectorized against {@code bindings.rhs}. It also asserts that
     * its vectorized result equals the row-wise {@code eval} result for every row of {@code bindings.lhs}.
     *
     * @param expr     source text, used only in assertion messages
     * @param parsed   the parsed expression
     * @param bindings per-row object bindings (lhs) and the equivalent vector binding (rhs); lhs must
     *                 have at least {@code rhs.getCurrentVectorSize()} entries
     */
    public static void testExpressionWithBindings(
            String expr,
            Expr parsed,
            NonnullPair<Expr.ObjectBinding[], Expr.VectorInputBinding> bindings
    ) {
        Assert.assertTrue(StringUtils.format("Cannot vectorize %s", expr), parsed.canVectorize(bindings.rhs));
        final ExpressionType outputType = parsed.getOutputType(bindings.rhs);
        final ExprEvalVector<?> vectorEval = parsed.buildVectorized(bindings.rhs).evalVector(bindings.rhs);
        // The output type may be null (presumably for expressions such as the literal null); type checks
        // are skipped in that case.
        if (outputType != null) {
            Assert.assertEquals(expr, outputType, vectorEval.getType());
        }
        final int rows = bindings.rhs.getCurrentVectorSize();
        for (int row = 0; row < rows; row++) {
            final ExprEval<?> eval = parsed.eval(bindings.lhs[row]);
            if (outputType != null && !eval.isNumericNull()) {
                Assert.assertEquals(expr, outputType, eval.type());
            }
            Assert.assertEquals(
                    StringUtils.format("Values do not match for row %s for expression %s", row, expr),
                    eval.value(),
                    vectorEval.get(row)
            );
        }
    }

    /**
     * Builds {@code vectorSize} rows of random values using {@link ThreadLocalRandom}:
     * <ul>
     *   <li>longs in {@code [0, Integer.MAX_VALUE - 1)};</li>
     *   <li>doubles in {@code [0, 1)};</li>
     *   <li>strings made from random ints.</li>
     * </ul>
     */
    public static NonnullPair<Expr.ObjectBinding[], Expr.VectorInputBinding> makeRandomizedBindings(
            int vectorSize,
            Map<String, ExpressionType> types
    ) {
        final ThreadLocalRandom random = ThreadLocalRandom.current();
        return makeBindings(
                vectorSize,
                types,
                () -> random.nextLong(Integer.MAX_VALUE - 1),
                random::nextDouble,
                random::nextBoolean,
                () -> String.valueOf(random.nextInt())
        );
    }

    /**
     * Builds {@code vectorSize} rows whose longs, doubles and strings count up from 1, each type with its
     * own counter. Null placement and "boolString" values still come from {@link ThreadLocalRandom}, so
     * the result is not fully deterministic.
     */
    public static NonnullPair<Expr.ObjectBinding[], Expr.VectorInputBinding> makeSequentialBinding(
            int vectorSize,
            Map<String, ExpressionType> types
    ) {
        return makeBindings(
                vectorSize,
                types,
                new LongSupplier() {
                    int counter = 1;

                    @Override
                    public long getAsLong() {
                        return counter++;
                    }
                },
                new DoubleSupplier() {
                    int counter = 1;

                    @Override
                    public double getAsDouble() {
                        return counter++;
                    }
                },
                () -> ThreadLocalRandom.current().nextBoolean(),
                new Supplier<String>() {
                    int counter = 1;

                    @Override
                    public String get() {
                        return String.valueOf(counter++);
                    }
                }
        );
    }

    /**
     * Builds matching row-wise and vectorized bindings for every column in {@code types}.
     *
     * <p>In SQL-compatible mode, a row is null whenever {@code nullsFn} returns true. Otherwise no nulls
     * are generated. Non-null values are filled as follows:
     * <ul>
     *   <li>LONG and DOUBLE rows come from the corresponding supplier;</li>
     *   <li>STRING columns whose name starts with "boolString" hold "true"/"false" drawn from
     *       {@code nullsFn};</li>
     *   <li>other STRING columns come from {@code stringsFn}.</li>
     * </ul>
     * Null long/double rows are stored as 0 in the vector and as {@code null} in the row binding.
     * Columns of any other type are not bound.
     */
    static NonnullPair<Expr.ObjectBinding[], Expr.VectorInputBinding> makeBindings(
            int vectorSize,
            Map<String, ExpressionType> types,
            LongSupplier longsFn,
            DoubleSupplier doublesFn,
            BooleanSupplier nullsFn,
            Supplier<String> stringsFn
    ) {
        final SettableVectorInputBinding vectorBinding = new SettableVectorInputBinding(vectorSize);
        final SettableObjectBinding[] objectBindings = new SettableObjectBinding[vectorSize];
        // Create every row binding up front, so no row is left null even if no column gets bound.
        for (int row = 0; row < vectorSize; row++) {
            objectBindings[row] = new SettableObjectBinding();
        }
        final boolean hasNulls = NullHandling.sqlCompatible();

        for (Map.Entry<String, ExpressionType> entry : types.entrySet()) {
            final String name = entry.getKey();
            final boolean[] nulls = new boolean[vectorSize];
            switch (entry.getValue().getType()) {
                case LONG: {
                    final long[] longs = new long[vectorSize];
                    for (int row = 0; row < vectorSize; row++) {
                        nulls[row] = hasNulls && nullsFn.getAsBoolean();
                        longs[row] = nulls[row] ? 0L : longsFn.getAsLong();
                        objectBindings[row].withBinding(name, nulls[row] ? null : Long.valueOf(longs[row]));
                    }
                    // In default-value mode 'nulls' is all false, so passing it always is equivalent.
                    vectorBinding.addLong(name, longs, nulls);
                    break;
                }
                case DOUBLE: {
                    final double[] doubles = new double[vectorSize];
                    for (int row = 0; row < vectorSize; row++) {
                        nulls[row] = hasNulls && nullsFn.getAsBoolean();
                        doubles[row] = nulls[row] ? 0.0 : doublesFn.getAsDouble();
                        objectBindings[row].withBinding(name, nulls[row] ? null : Double.valueOf(doubles[row]));
                    }
                    vectorBinding.addDouble(name, doubles, nulls);
                    break;
                }
                case STRING: {
                    final String[] strings = new String[vectorSize];
                    final boolean booleanColumn = name.startsWith("boolString");
                    for (int row = 0; row < vectorSize; row++) {
                        nulls[row] = hasNulls && nullsFn.getAsBoolean();
                        if (nulls[row]) {
                            strings[row] = null;
                        } else if (booleanColumn) {
                            strings[row] = String.valueOf(nullsFn.getAsBoolean());
                        } else {
                            strings[row] = String.valueOf(stringsFn.get());
                        }
                        objectBindings[row].withBinding(name, strings[row]);
                    }
                    vectorBinding.addString(name, strings);
                    break;
                }
                default:
                    // Other types are not generated; such columns stay unbound.
                    break;
            }
        }
        return new NonnullPair<>(objectBindings, vectorBinding);
    }

    /** Returns the Cartesian product of {@code first} x {@code second} as two-element tuples. */
    static String[][] makeTemplateArgs(String[] first, String[] second) {
        return Arrays.stream(first)
                .flatMap(a -> Arrays.stream(second).map(b -> new String[]{a, b}))
                .toArray(String[][]::new);
    }

    /** Returns the Cartesian product of {@code first} x {@code second} x {@code third} as three-element tuples. */
    static String[][] makeTemplateArgs(String[] first, String[] second, String[] third) {
        return Arrays.stream(first)
                .flatMap(a -> Arrays.stream(second)
                        .flatMap(b -> Arrays.stream(third).map(c -> new String[]{a, b, c})))
                .toArray(String[][]::new);
    }
}
