package org.apache.druid.query.aggregation.post;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.ibm.icu.text.PluralRules;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.derby.iapi.sql.compile.TypeCompiler;
import org.apache.derby.iapi.store.raw.RowLock;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.query.Druids;
import org.apache.druid.query.aggregation.CountAggregator;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.expression.TestExprMacroTable;
import org.apache.druid.query.timeseries.TimeseriesQueryQueryToolChest;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Test;

/**
 * Tests for {@link ArithmeticPostAggregator}: agreement with equivalent {@link ExpressionPostAggregator}
 * expressions, comparator ordering (including a null/default input), the {@code quotient}, {@code pow} and
 * {@code /} functions, and the column type it contributes to a timeseries result signature.
 */
public class ArithmeticPostAggregatorTest extends InitializedNullHandlingTest {
    @Test
    public void testCompute() {
        CountAggregator rowCounter = new CountAggregator();
        rowCounter.aggregate();
        rowCounter.aggregate();
        rowCounter.aggregate();

        Map<String, Object> values = new HashMap<>();
        values.put("rows", rowCounter.get());

        List<PostAggregator> fields = Lists.newArrayList(
                new ConstantPostAggregator("roku", 6.0),
                new FieldAccessPostAggregator("rows", "rows")
        );
        // Store each field's computed value under its name so "roku" is present in the map
        // for the expression-based aggregators below.
        for (PostAggregator field : fields) {
            values.put(field.getName(), field.compute(values));
        }

        // roku = 6.0, rows = 3 (three aggregate() calls above).
        assertArithmeticMatchesExpression("add", "+", "roku + rows", fields, values, 9.0);
        assertArithmeticMatchesExpression("subtract", "-", "roku - rows", fields, values, 3.0);
        assertArithmeticMatchesExpression("multiply", "*", "roku * rows", fields, values, 18.0);
        assertArithmeticMatchesExpression("divide", "/", "roku / rows", fields, values, 2.0);
    }

    @Test
    public void testComparator() {
        CountAggregator rowCounter = new CountAggregator();
        Map<String, Object> values = new HashMap<>();
        values.put("rows", rowCounter.get());

        ArithmeticPostAggregator add = new ArithmeticPostAggregator(
                "add",
                "+",
                Lists.newArrayList(
                        new ConstantPostAggregator("roku", 6.0),
                        new FieldAccessPostAggregator("rows", "rows")
                )
        );
        Comparator comparator = add.getComparator();

        // Result before counting any rows vs. after counting three rows.
        Object before = add.compute(values);
        rowCounter.aggregate();
        rowCounter.aggregate();
        rowCounter.aggregate();
        values.put("rows", rowCounter.get());
        Object after = add.compute(values);

        assertStrictlyOrdered(comparator, before, after);
    }

    @Test
    public void testComparatorNulls() {
        Map<String, Object> values = new HashMap<>();
        ArithmeticPostAggregator add = new ArithmeticPostAggregator(
                "add",
                "+",
                Lists.newArrayList(
                        new ConstantPostAggregator("roku", 6.0),
                        new FieldAccessPostAggregator("doubleWithNulls", "doubleWithNulls")
                )
        );
        Comparator comparator = add.getComparator();

        // Depending on the null-handling mode, a "missing" double is either the default double value or null.
        values.put("doubleWithNulls", NullHandling.replaceWithDefault() ? NullHandling.defaultDoubleValue() : null);
        Object before = add.compute(values);
        values.put("doubleWithNulls", 1.0);
        Object after = add.compute(values);

        // The result derived from the null/default input must order before the one derived from 1.0.
        assertStrictlyOrdered(comparator, before, after);
    }

    @Test
    public void testQuotient() {
        ArithmeticPostAggregator quotient = new ArithmeticPostAggregator(
                null,
                "quotient",
                ImmutableList.of(
                        new FieldAccessPostAggregator("numerator", "value"),
                        new ConstantPostAggregator("zero", 0)
                ),
                "numericFirst"
        );
        // Unlike "/" (see testDiv), "quotient" yields NaN or a signed infinity when dividing by zero.
        Assert.assertEquals(Double.NaN, quotient.compute(ImmutableMap.of("value", 0)));
        Assert.assertEquals(Double.NaN, quotient.compute(ImmutableMap.of("value", Double.NaN)));
        Assert.assertEquals(Double.POSITIVE_INFINITY, quotient.compute(ImmutableMap.of("value", 1)));
        Assert.assertEquals(Double.NEGATIVE_INFINITY, quotient.compute(ImmutableMap.of("value", -1)));
    }

    @Test
    public void testPow() {
        ArithmeticPostAggregator squareRoot = new ArithmeticPostAggregator(
                null,
                "pow",
                ImmutableList.of(
                        new ConstantPostAggregator("value", 4),
                        new ConstantPostAggregator("power", 0.5)
                ),
                "numericFirst"
        );
        Assert.assertEquals(2.0, squareRoot.compute(ImmutableMap.of("value", 0)));

        ArithmeticPostAggregator toThePowerOfZero = new ArithmeticPostAggregator(
                null,
                "pow",
                ImmutableList.of(
                        new FieldAccessPostAggregator("base", "value"),
                        new ConstantPostAggregator("zero", 0)
                ),
                "numericFirst"
        );
        // Any base raised to the power 0 is 1, including 0 and NaN.
        Assert.assertEquals(1.0, toThePowerOfZero.compute(ImmutableMap.of("value", 0)));
        Assert.assertEquals(1.0, toThePowerOfZero.compute(ImmutableMap.of("value", Double.NaN)));
        Assert.assertEquals(1.0, toThePowerOfZero.compute(ImmutableMap.of("value", 1)));
        Assert.assertEquals(1.0, toThePowerOfZero.compute(ImmutableMap.of("value", -1)));
        Assert.assertEquals(1.0, toThePowerOfZero.compute(ImmutableMap.of("value", 0.5)));
    }

    @Test
    public void testDiv() {
        ArithmeticPostAggregator divide = new ArithmeticPostAggregator(
                null,
                "/",
                ImmutableList.of(
                        new FieldAccessPostAggregator("numerator", "value"),
                        new ConstantPostAggregator("denominator", 0)
                )
        );
        // "/" is expected to return 0.0 for a zero denominator, regardless of the numerator.
        Assert.assertEquals(0.0, divide.compute(ImmutableMap.of("value", 0)));
        Assert.assertEquals(0.0, divide.compute(ImmutableMap.of("value", Double.NaN)));
        Assert.assertEquals(0.0, divide.compute(ImmutableMap.of("value", 1)));
        Assert.assertEquals(0.0, divide.compute(ImmutableMap.of("value", -1)));
    }

    @Test
    public void testNumericFirstOrdering() {
        Comparator numericFirst = new ArithmeticPostAggregator(
                null,
                "quotient",
                ImmutableList.of(
                        new ConstantPostAggregator("zero", 0),
                        new ConstantPostAggregator("zero", 0)
                ),
                "numericFirst"
        ).getComparator();

        // Expected order: -Infinity < +Infinity < NaN < any finite value.
        Assert.assertTrue(numericFirst.compare(Double.NaN, 0.0) < 0);
        Assert.assertTrue(numericFirst.compare(Double.POSITIVE_INFINITY, 0.0) < 0);
        Assert.assertTrue(numericFirst.compare(Double.NEGATIVE_INFINITY, 0.0) < 0);
        Assert.assertTrue(numericFirst.compare(0.0, Double.NaN) > 0);
        Assert.assertTrue(numericFirst.compare(0.0, Double.POSITIVE_INFINITY) > 0);
        Assert.assertTrue(numericFirst.compare(0.0, Double.NEGATIVE_INFINITY) > 0);

        Assert.assertTrue(numericFirst.compare(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY) < 0);
        Assert.assertTrue(numericFirst.compare(Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY) > 0);

        Assert.assertTrue(numericFirst.compare(Double.NaN, Double.POSITIVE_INFINITY) > 0);
        Assert.assertTrue(numericFirst.compare(Double.NaN, Double.NEGATIVE_INFINITY) > 0);
        Assert.assertTrue(numericFirst.compare(Double.POSITIVE_INFINITY, Double.NaN) < 0);
        Assert.assertTrue(numericFirst.compare(Double.NEGATIVE_INFINITY, Double.NaN) < 0);
    }

    @Test
    public void testResultArraySignature() {
        Assert.assertEquals(
                RowSignature.builder()
                        .addTimeColumn()
                        .add("sum", ColumnType.LONG)
                        .add("count", ColumnType.LONG)
                        .add("avg", ColumnType.DOUBLE)
                        .build(),
                new TimeseriesQueryQueryToolChest().resultArraySignature(
                        Druids.newTimeseriesQueryBuilder()
                                .dataSource("dummy")
                                .intervals("2000/3000")
                                .granularity(Granularities.HOUR)
                                .aggregators(
                                        new LongSumAggregatorFactory("sum", "col"),
                                        new CountAggregatorFactory("count")
                                )
                                .postAggregators(
                                        // avg = sum / count; the arithmetic post-aggregator contributes a DOUBLE column.
                                        new ArithmeticPostAggregator(
                                                "avg",
                                                "/",
                                                ImmutableList.of(
                                                        new FieldAccessPostAggregator("_sum", "sum"),
                                                        new FieldAccessPostAggregator("_count", "count")
                                                )
                                        )
                                )
                                .build()
                )
        );
    }

    /**
     * Asserts that an {@link ArithmeticPostAggregator} applying {@code fn} to {@code fields} and an
     * {@link ExpressionPostAggregator} evaluating {@code expression} both compute {@code expected} over
     * {@code values}.
     */
    private static void assertArithmeticMatchesExpression(
            String name,
            String fn,
            String expression,
            List<PostAggregator> fields,
            Map<String, Object> values,
            double expected) {
        ArithmeticPostAggregator arithmetic = new ArithmeticPostAggregator(name, fn, fields);
        ExpressionPostAggregator expressionAggregator =
                new ExpressionPostAggregator(name, expression, (String) null, TestExprMacroTable.INSTANCE);
        Assert.assertEquals(Double.valueOf(expected), arithmetic.compute(values));
        Assert.assertEquals(Double.valueOf(expected), expressionAggregator.compute(values));
    }

    /**
     * Asserts that {@code comparator} orders {@code smaller} strictly before {@code larger}, returning exactly
     * -1, 0 and 1 for the less-than, equal and greater-than cases.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static void assertStrictlyOrdered(Comparator comparator, Object smaller, Object larger) {
        Assert.assertEquals(-1, comparator.compare(smaller, larger));
        Assert.assertEquals(0, comparator.compare(smaller, smaller));
        Assert.assertEquals(0, comparator.compare(larger, larger));
        Assert.assertEquals(1, comparator.compare(larger, smaller));
    }
}
