package org.apache.arrow.gandiva.evaluator;

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.IntStream;
import org.apache.arrow.gandiva.exceptions.GandivaException;
import org.apache.arrow.gandiva.expression.TreeBuilder;
import org.apache.arrow.gandiva.expression.TreeNode;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/arrow/gandiva/evaluator/FilterTest.class */
/**
 * Tests for {@link Filter}.
 *
 * <p>Each case builds a condition over nullable columns and evaluates it into a 16-bit selection
 * vector. It then compares the selected row indices with the expected ones. Most cases use the
 * validity bitmap {@code {-1, 0}}: the 0xFF byte marks rows 0-7 as valid and the zero byte marks
 * rows 8-15 as null. The expected results show that null rows are never selected, even when their
 * values would satisfy the condition.
 */
public class FilterTest extends BaseEvaluatorTest {
    private final Charset utf8Charset = StandardCharsets.UTF_8;
    // Not referenced by any test in this class.
    private final Charset utf16Charset = StandardCharsets.UTF_16;

    /** Copies the record indices held by {@code selectionVector} into a new int array. */
    private int[] selectionVectorToArray(SelectionVector selectionVector) {
        int count = selectionVector.getRecordCount();
        int[] indices = new int[count];
        for (int pos = 0; pos < count; pos++) {
            indices[pos] = selectionVector.getIndex(pos);
        }
        return indices;
    }

    /**
     * Builds the offsets and data buffers of a variable-width column holding {@code strArr}.
     *
     * <p>The offsets buffer contains {@code strArr.length + 1} ints: the start offset of every value
     * followed by the end offset of the last one. The data buffer holds the concatenated encoded
     * bytes and is sized exactly, because every string is encoded before allocation. Its writer
     * index is left at 0, as before.
     *
     * @param strArr  values to encode
     * @param charset charset used to encode each value
     * @return a two-element list: {@code [offsets, data]}
     */
    List<ArrowBuf> varBufs(String[] strArr, Charset charset) {
        // Encode up front so the data buffer never needs to grow. The previous realloc guard was
        // keyed off writerIndex(), which setBytes() never advances, so it could not protect
        // against writes past the initial capacity.
        byte[][] encoded = new byte[strArr.length][];
        int totalBytes = 0;
        for (int idx = 0; idx < strArr.length; idx++) {
            encoded[idx] = strArr[idx].getBytes(charset);
            totalBytes += encoded[idx].length;
        }

        ArrowBuf offsets = this.allocator.buffer((strArr.length + 1) * 4);
        ArrowBuf data = this.allocator.buffer(totalBytes);
        int offset = 0;
        for (byte[] bytes : encoded) {
            offsets.writeInt(offset);
            data.setBytes(offset, bytes, 0, bytes.length);
            offset += bytes.length;
        }
        // End offset of the last value.
        offsets.writeInt(offset);
        return Arrays.asList(offsets, data);
    }

    /** UTF-8 convenience wrapper around {@link #varBufs(String[], Charset)}. */
    List<ArrowBuf> stringBufs(String[] strArr) {
        return varBufs(strArr, this.utf8Charset);
    }

    /** Returns two nullable int32 fields named "a" and "b". */
    private List<Field> twoNullableInt32Fields() {
        return Lists.newArrayList(Field.nullable("a", this.int32), Field.nullable("b", this.int32));
    }

    /**
     * Evaluates {@code filter} over {@code batch} into a freshly allocated 16-bit selection vector.
     *
     * <p>The selection buffer, the record batch and the filter are all released before returning,
     * including when evaluation throws.
     *
     * @param filter  filter to evaluate; closed by this method
     * @param batch   input batch; released by this method
     * @param numRows number of rows in {@code batch}; sizes the selection buffer (2 bytes per row)
     * @return the selected row indices
     * @throws GandivaException if evaluation or closing the filter fails
     */
    private int[] evaluateSv16(Filter filter, ArrowRecordBatch batch, int numRows) throws GandivaException {
        try {
            ArrowBuf selectionBuf = buf(numRows * 2);
            try {
                SelectionVectorInt16 selectionVector = new SelectionVectorInt16(selectionBuf);
                filter.evaluate(batch, selectionVector);
                return selectionVectorToArray(selectionVector);
            } finally {
                selectionBuf.close();
            }
        } finally {
            releaseRecordBatch(batch);
            filter.close();
        }
    }

    @Test
    public void testSimpleInString() throws GandivaException, Exception {
        Field c1 = Field.nullable("c1", new ArrowType.Utf8());
        // substr(c1, 1, 3) IN ("one", "two", "thr", "fou")
        TreeNode prefix = TreeBuilder.makeFunction(
                "substr",
                Lists.newArrayList(TreeBuilder.makeField(c1), TreeBuilder.makeLiteral(1L), TreeBuilder.makeLiteral(3L)),
                new ArrowType.Utf8());
        Filter filter = Filter.make(
                new Schema(Lists.newArrayList(c1)),
                TreeBuilder.makeCondition(
                        TreeBuilder.makeInExpressionString(prefix, Sets.newHashSet("one", "two", "thr", "fou"))));

        int numRows = 16;
        byte[] validity = {-1, 0};
        ArrowBuf validityBuf = buf(validity);
        // NOTE(review): the schema has a single field, yet a fourth buffer is appended after the
        // validity/offsets/data triple; kept as-is, confirm it is intentionally ignored.
        ArrowBuf trailingBuf = buf(validity);
        List<ArrowBuf> c1Bufs = stringBufs(new String[]{
            "one", "two", "three", "four", "five", "six", "seven", "eight",
            "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", "sixteen"});
        ArrowRecordBatch batch = new ArrowRecordBatch(
                numRows,
                Lists.newArrayList(new ArrowFieldNode(numRows, 0L)),
                Lists.newArrayList(validityBuf, c1Bufs.get(0), c1Bufs.get(1), trailingBuf));

        // "fourteen" (row 13) also starts with "fou", but it is null per the validity bitmap.
        int[] selected = evaluateSv16(filter, batch, numRows);
        Assert.assertArrayEquals(new int[]{0, 1, 2, 3}, selected);
    }

    @Test
    public void testSimpleInInt() throws GandivaException, Exception {
        Field c1 = Field.nullable("c1", this.int32);
        Filter filter = Filter.make(
                new Schema(Lists.newArrayList(c1)),
                TreeBuilder.makeCondition(
                        TreeBuilder.makeInExpressionInt32(TreeBuilder.makeField(c1), Sets.newHashSet(1, 2, 3, 4))));

        int numRows = 16;
        byte[] validity = {-1, 0};
        ArrowRecordBatch batch = new ArrowRecordBatch(
                numRows,
                Lists.newArrayList(new ArrowFieldNode(numRows, 0L)),
                Lists.newArrayList(
                        buf(validity),
                        intBuf(new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}),
                        buf(validity)));

        int[] selected = evaluateSv16(filter, batch, numRows);
        Assert.assertArrayEquals(new int[]{0, 1, 2, 3}, selected);
    }

    @Test
    public void testSimpleSV16() throws GandivaException, Exception {
        List<Field> fields = twoNullableInt32Fields();
        Filter filter = Filter.make(new Schema(fields), TreeBuilder.makeCondition("less_than", fields));
        verifyTestCase(filter, 16, new byte[]{-1, 0},
                new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
                new int[]{2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 14, 15},
                new int[]{0, 2, 4, 6});
    }

    @Test
    public void testSimpleSV16_AllMatched() throws GandivaException, Exception {
        List<Field> fields = twoNullableInt32Fields();
        Filter filter = Filter.make(new Schema(fields), TreeBuilder.makeCondition("less_than", fields));

        int numRows = 32;
        byte[] validity = new byte[numRows / 8];
        Arrays.fill(validity, (byte) -1);
        int[] aValues = IntStream.range(0, numRows).toArray();
        int[] bValues = IntStream.rangeClosed(1, numRows).toArray();
        // a[i] = i < i + 1 = b[i] for every row, so every row is selected.
        int[] expected = IntStream.range(0, numRows).toArray();
        verifyTestCase(filter, numRows, validity, aValues, bValues, expected);
    }

    @Test
    public void testSimpleSV16_GreaterThan64Recs() throws GandivaException, Exception {
        List<Field> fields = twoNullableInt32Fields();
        Filter filter = Filter.make(new Schema(fields), TreeBuilder.makeCondition("greater_than", fields));

        int numRows = 1000; // divisible by 8, so numRows / 8 validity bytes cover every row
        byte[] validity = new byte[numRows / 8];
        Arrays.fill(validity, (byte) -1);
        int[] aValues = IntStream.range(0, numRows).toArray();
        int[] bValues = IntStream.rangeClosed(1, numRows).toArray();
        // Make row 0 the only row where a > b.
        aValues[0] = 5;
        bValues[0] = 0;
        verifyTestCase(filter, numRows, validity, aValues, bValues, new int[]{0});
    }

    @Test
    public void testSimpleSV32() throws GandivaException, Exception {
        // NOTE(review): verifyTestCase always evaluates into a SelectionVectorInt16, so this test is
        // currently identical to testSimpleSV16 and does not exercise a 32-bit selection vector.
        List<Field> fields = twoNullableInt32Fields();
        Filter filter = Filter.make(new Schema(fields), TreeBuilder.makeCondition("less_than", fields));
        verifyTestCase(filter, 16, new byte[]{-1, 0},
                new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
                new int[]{2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 14, 15},
                new int[]{0, 2, 4, 6});
    }

    @Test
    public void testSimpleFilterWithNoOptimisation() throws GandivaException, Exception {
        List<Field> fields = twoNullableInt32Fields();
        // The trailing `false` is the flag that, per this test's name, disables optimisation.
        Filter filter = Filter.make(new Schema(fields), TreeBuilder.makeCondition("less_than", fields), false);
        verifyTestCase(filter, 16, new byte[]{-1, 0},
                new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
                new int[]{2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 14, 15},
                new int[]{0, 2, 4, 6});
    }

    /**
     * Runs {@code filter} over a two-column int32 batch and asserts the selected row indices.
     *
     * <p>Both columns share the same validity bitmap. The filter is closed by this method.
     *
     * @param filter   filter over the two int32 columns; closed here
     * @param numRows  number of rows in the batch
     * @param validity validity bitmap used for both columns
     * @param aValues  values of the first column
     * @param bValues  values of the second column
     * @param expected expected selected row indices, in order
     * @throws GandivaException if evaluation or closing the filter fails
     */
    private void verifyTestCase(Filter filter, int numRows, byte[] validity, int[] aValues, int[] bValues, int[] expected) throws GandivaException {
        ArrowRecordBatch batch = new ArrowRecordBatch(
                numRows,
                Lists.newArrayList(new ArrowFieldNode(numRows, 0L), new ArrowFieldNode(numRows, 0L)),
                Lists.newArrayList(buf(validity), intBuf(aValues), buf(validity), intBuf(bValues)));
        int[] selected = evaluateSv16(filter, batch, numRows);
        Assert.assertArrayEquals(expected, selected);
    }

    // The two lifecycle overrides below only delegate to BaseEvaluatorTest. The decompiler marked
    // them as compiler-generated bridge methods; they are kept so the class's public surface is
    // unchanged.

    @Override
    @After
    public void tearDown() {
        super.tearDown();
    }

    @Override
    @Before
    public void init() {
        super.init();
    }
}
