package org.apache.pinot.query.aggregation.groupby;

import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.apache.commons.lang.RandomStringUtils;
import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.BrokerRequest;
import org.apache.pinot.common.response.broker.GroupByResult;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionFactory;
import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByTrimmingService;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

/* loaded from: input_file:org/apache/pinot/query/aggregation/groupby/AggregationGroupByTrimmingServiceTest.class */
/**
 * Tests {@link AggregationGroupByTrimmingService} for a SUM and a DISTINCTCOUNT aggregation, both on the intermediate
 * results (server side) and on the final results turned into {@link GroupByResult} lists (broker side).
 *
 * <p>Group {@code i} is given aggregation values that grow with {@code i}, so the expected survivors of trimming are
 * always the groups with the highest indices. The random seed is time-based and embedded in every assertion message so
 * that a failing run can be reproduced.
 */
public class AggregationGroupByTrimmingServiceTest {
    private static final long RANDOM_SEED = System.currentTimeMillis();
    private static final Random RANDOM = new Random(RANDOM_SEED);
    private static final String ERROR_MESSAGE = "Random seed: " + RANDOM_SEED;
    private static final AggregationFunction SUM = createAggregationFunction("SUM", "sumColumn");
    private static final AggregationFunction DISTINCTCOUNT = createAggregationFunction("DISTINCTCOUNT", "distinctColumn");
    private static final AggregationFunction[] AGGREGATION_FUNCTIONS = {SUM, DISTINCTCOUNT};
    /**
     * Separator placed between the keys of one group in its string form. Presumably this must match the delimiter the
     * trimming service splits group strings on (the broker-side check relies on the round trip) — confirm.
     */
    private static final String GROUP_KEY_DELIMITER = "\t";
    private static final int NUM_GROUP_KEYS = 3;
    private static final int GROUP_BY_TOP_N = 100;
    private static final int NUM_GROUPS = 50000;
    private static final int MAX_SIZE_OF_SET = 50;
    /** Number of consecutive group indices that share the same DISTINCTCOUNT set size (1000). */
    private static final int GROUPS_PER_SET_SIZE = NUM_GROUPS / MAX_SIZE_OF_SET;

    /** {@code NUM_GROUPS} distinct group strings; the all-empty group is always the last element. */
    private List<String> _groups;
    private AggregationGroupByTrimmingService _trimmingService;

    /**
     * Generates {@link #NUM_GROUPS} distinct random group strings and creates the service under test.
     *
     * <p>The group made only of empty keys is reserved before the random generation. A random draw produces an empty
     * key whenever {@code RANDOM.nextInt(10)} is 0, so an all-empty random group is very likely among ~50000 draws.
     * Overwriting a list slot with it afterwards could therefore duplicate a string and make the test flaky. Instead it
     * is appended as the last (highest-ranked) group, which guarantees every group is distinct.
     */
    @BeforeClass
    public void setUp() {
        String emptyGroup = buildGroupString(Collections.nCopies(NUM_GROUP_KEYS, ""));

        Set<String> groupSet = new HashSet<>(NUM_GROUPS);
        // Reserve the empty group so no random group can collide with it.
        groupSet.add(emptyGroup);
        while (groupSet.size() < NUM_GROUPS) {
            List<String> group = new ArrayList<>(NUM_GROUP_KEYS);
            for (int i = 0; i < NUM_GROUP_KEYS; i++) {
                // Strip the delimiter so a key can never introduce an extra key boundary.
                group.add(RandomStringUtils.random(RANDOM.nextInt(10)).replace(GROUP_KEY_DELIMITER, ""));
            }
            groupSet.add(buildGroupString(group));
        }
        groupSet.remove(emptyGroup);

        _groups = new ArrayList<>(groupSet);
        // Explicitly place the empty group at the last index (NUM_GROUPS - 1).
        _groups.add(emptyGroup);

        _trimmingService = new AggregationGroupByTrimmingService(AGGREGATION_FUNCTIONS, GROUP_BY_TOP_N);
    }

    /**
     * Checks server-side trimming of intermediate results, then broker-side trimming of the final results.
     *
     * <p>Group {@code i} gets SUM value {@code i} and a DISTINCTCOUNT set of size {@code i / GROUPS_PER_SET_SIZE + 1},
     * so both aggregations rank the groups by index.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    @Test
    public void testTrimming() {
        // Server side: build one intermediate result {sum, distinctSet} per group.
        Map<String, Object[]> intermediateResultsMap = new HashMap<>(NUM_GROUPS);
        for (int i = 0; i < NUM_GROUPS; i++) {
            IntOpenHashSet distinctValues = new IntOpenHashSet();
            for (int value = 0; value <= i; value += GROUPS_PER_SET_SIZE) {
                distinctValues.add(value);
            }
            intermediateResultsMap.put(_groups.get(i), new Object[]{Double.valueOf(i), distinctValues});
        }

        List<Map<String, Object>> trimmedIntermediateResultMaps =
                _trimmingService.trimIntermediateResultsMap(intermediateResultsMap);
        Map<String, Object> trimmedSumResultMap = trimmedIntermediateResultMaps.get(0);
        Map<String, Object> trimmedDistinctCountResultMap = trimmedIntermediateResultMaps.get(1);
        int trimSize = trimmedSumResultMap.size();
        Assert.assertEquals(trimmedDistinctCountResultMap.size(), trimSize, ERROR_MESSAGE);
        // The surviving groups must be exactly the trimSize highest-ranked ones.
        for (int i = NUM_GROUPS - trimSize; i < NUM_GROUPS; i++) {
            String group = _groups.get(i);
            Assert.assertEquals(((Double) trimmedSumResultMap.get(group)).intValue(), i, ERROR_MESSAGE);
            Assert.assertEquals(((IntOpenHashSet) trimmedDistinctCountResultMap.get(group)).size(),
                    (i / GROUPS_PER_SET_SIZE) + 1, ERROR_MESSAGE);
        }

        // Broker side: the final DISTINCTCOUNT value is the size of the intermediate set.
        Map<String, Integer> finalDistinctCountResultMap = new HashMap<>(trimSize);
        for (Map.Entry<String, Object> entry : trimmedDistinctCountResultMap.entrySet()) {
            finalDistinctCountResultMap.put(entry.getKey(), ((IntOpenHashSet) entry.getValue()).size());
        }

        List<GroupByResult>[] groupByResultLists =
                _trimmingService.trimFinalResults(new Map[]{trimmedSumResultMap, finalDistinctCountResultMap});
        List<GroupByResult> sumGroupByResults = groupByResultLists[0];
        List<GroupByResult> distinctCountGroupByResults = groupByResultLists[1];
        // Results are expected in descending order, so position i holds group NUM_GROUPS - 1 - i.
        for (int i = 0; i < GROUP_BY_TOP_N; i++) {
            int expectedGroupIndex = NUM_GROUPS - 1 - i;
            GroupByResult sumGroupByResult = sumGroupByResults.get(i);
            List<String> group = sumGroupByResult.getGroup();
            Assert.assertEquals(group.size(), NUM_GROUP_KEYS, ERROR_MESSAGE);
            Assert.assertEquals(buildGroupString(group), _groups.get(expectedGroupIndex), ERROR_MESSAGE);
            Assert.assertEquals(((Double) sumGroupByResult.getValue()).intValue(), expectedGroupIndex, ERROR_MESSAGE);
            Assert.assertEquals(distinctCountGroupByResults.get(i).getValue(),
                    Integer.valueOf((expectedGroupIndex / GROUPS_PER_SET_SIZE) + 1), ERROR_MESSAGE);
        }
    }

    /**
     * Joins the first {@link #NUM_GROUP_KEYS} keys with {@link #GROUP_KEY_DELIMITER}.
     *
     * @param groupKeys the group keys; must contain at least {@code NUM_GROUP_KEYS} elements
     * @return the group string, e.g. {@code "a\tb\tc"}
     * @throws IndexOutOfBoundsException if {@code groupKeys} has fewer than {@code NUM_GROUP_KEYS} elements
     */
    private static String buildGroupString(List<String> groupKeys) {
        return String.join(GROUP_KEY_DELIMITER, groupKeys.subList(0, NUM_GROUP_KEYS));
    }

    /**
     * Creates an aggregation function of the given type over a single column expression.
     *
     * @param aggregationType aggregation type name, e.g. {@code "SUM"}
     * @param column the single expression the function aggregates
     * @return the function produced by {@link AggregationFunctionFactory} for an empty broker request
     */
    private static AggregationFunction createAggregationFunction(String aggregationType, String column) {
        AggregationInfo aggregationInfo = new AggregationInfo();
        aggregationInfo.setAggregationType(aggregationType);
        aggregationInfo.setExpressions(Collections.singletonList(column));
        return AggregationFunctionFactory.getAggregationFunction(aggregationInfo, new BrokerRequest());
    }
}
