package org.apache.arrow.vector.ipc;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.channels.Channels;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.function.ToIntBiFunction;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.UInt1Vector;
import org.apache.arrow.vector.UInt2Vector;
import org.apache.arrow.vector.UInt4Vector;
import org.apache.arrow.vector.UInt8Vector;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.testing.ValueVectorDataPopulator;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Round-trip tests for dictionary-encoded vectors whose index type is an unsigned integer
 * (8, 16, 32 or 64 bits wide) and whose dictionary is a {@link VarCharVector}.
 *
 * <p>Each test is executed twice by the {@link Parameterized} runner: once writing/reading with
 * the IPC stream classes and once with the IPC file classes.
 */
@RunWith(Parameterized.class)
public class TestUIntDictionaryRoundTrip {

    /** {@code true} to use the stream writer/reader, {@code false} to use the file writer/reader. */
    private final boolean streamMode;

    private BufferAllocator allocator;

    private DictionaryProvider.MapDictionaryProvider dictionaryProvider;

    /**
     * Creates the test instance for one parameter set.
     *
     * @param streamMode whether the stream (rather than file) writer and reader are used
     */
    public TestUIntDictionaryRoundTrip(boolean streamMode) {
        this.streamMode = streamMode;
    }

    /** Creates a fresh allocator and an empty dictionary provider before every test. */
    @Before
    public void init() {
        allocator = new RootAllocator(Long.MAX_VALUE);
        dictionaryProvider = new DictionaryProvider.MapDictionaryProvider();
    }

    /** Closes the allocator created by {@link #init()} after every test. */
    @After
    public void terminate() throws Exception {
        allocator.close();
    }

    /**
     * Writes the given vector as a single record batch, together with the dictionaries registered
     * in {@link #dictionaryProvider}, into an in-memory buffer.
     *
     * @param encodedVector the dictionary-encoded index vector to serialize
     * @return the bytes produced by the writer selected through {@link #streamMode}
     * @throws IOException if the writer fails
     */
    private byte[] writeData(FieldVector encodedVector) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        VectorSchemaRoot root = new VectorSchemaRoot(
                Arrays.asList(encodedVector.getField()),
                Arrays.asList(encodedVector),
                encodedVector.getValueCount());

        // The two writers are siblings, so the variable must use their common base type.
        try (ArrowWriter writer = streamMode
                ? new ArrowStreamWriter(root, dictionaryProvider, out)
                : new ArrowFileWriter(root, dictionaryProvider, Channels.newChannel(out))) {
            writer.start();
            writer.writeBatch();
            writer.end();
            // Snapshot the buffer after end() and before the writer is closed.
            return out.toByteArray();
        }
    }

    /**
     * Reads the serialized data back and asserts that the schema, the index values and the
     * dictionary contents match the expectations.
     *
     * @param data the bytes produced by {@link #writeData(FieldVector)}
     * @param expectedField the field the single column of the schema must equal
     * @param valGetter extracts the index stored at a given position of the decoded vector
     * @param dictionaryID the id under which the dictionary must be found
     * @param expectedIndices the index values expected in the decoded vector
     * @param expectedDictItems the strings expected in the dictionary vector
     * @throws IOException if the reader fails
     */
    private void readData(
            byte[] data,
            Field expectedField,
            ToIntBiFunction<ValueVector, Integer> valGetter,
            long dictionaryID,
            int[] expectedIndices,
            String[] expectedDictItems) throws IOException {
        // The two readers are siblings, so the variable must use their common base type.
        try (ArrowReader reader = streamMode
                ? new ArrowStreamReader(new ByteArrayInputStream(data), allocator)
                : new ArrowFileReader(
                        new SeekableReadChannel(new ByteArrayReadableSeekableByteChannel(data)),
                        allocator)) {

            // Verify the schema.
            Schema readSchema = reader.getVectorSchemaRoot().getSchema();
            Assert.assertEquals(1L, readSchema.getFields().size());
            Assert.assertEquals(expectedField, readSchema.getFields().get(0));

            // Verify the index vector of the single batch.
            Assert.assertTrue(reader.loadNextBatch());
            VectorSchemaRoot root = reader.getVectorSchemaRoot();
            Assert.assertEquals(1L, root.getFieldVectors().size());
            FieldVector decodedVector = root.getVector(0);
            Assert.assertEquals(expectedIndices.length, decodedVector.getValueCount());
            for (int i = 0; i < expectedIndices.length; i++) {
                Assert.assertEquals(expectedIndices[i], valGetter.applyAsInt(decodedVector, i));
            }

            // Verify the dictionary.
            Map<Long, Dictionary> dictionaries = reader.getDictionaryVectors();
            Assert.assertEquals(1L, dictionaries.size());
            Dictionary dictionary = dictionaries.get(dictionaryID);
            Assert.assertNotNull(dictionary);

            Assert.assertTrue(dictionary.getVector() instanceof VarCharVector);
            VarCharVector dictionaryVector = (VarCharVector) dictionary.getVector();
            Assert.assertEquals(expectedDictItems.length, dictionaryVector.getValueCount());
            for (int i = 0; i < dictionaryVector.getValueCount(); i++) {
                Assert.assertArrayEquals(
                        expectedDictItems[i].getBytes(StandardCharsets.UTF_8),
                        dictionaryVector.get(i));
            }
        }
    }

    /**
     * Registers {@code dictionaryVector} with the dictionary provider and creates an empty,
     * nullable index vector named {@code "encoded"} that refers to it.
     *
     * <p>The bit width doubles as the dictionary id, so tests look the dictionary up by width.
     *
     * @param bitWidth width in bits of the unsigned index type; also used as the dictionary id
     * @param dictionaryVector the vector holding the dictionary values
     * @return the newly created index vector; the caller is responsible for closing it
     */
    private ValueVector createEncodedVector(int bitWidth, VarCharVector dictionaryVector) {
        DictionaryEncoding encoding =
                new DictionaryEncoding(bitWidth, false, new ArrowType.Int(bitWidth, false));
        dictionaryProvider.put(new Dictionary(dictionaryVector, encoding));

        FieldType fieldType = new FieldType(true, encoding.getIndexType(), encoding, null);
        Field field = new Field("encoded", fieldType, null);
        return field.createVector(allocator);
    }

    /** Round-trips 255 unsigned 8-bit indices (0..254) with a 255-entry dictionary. */
    @Test
    public void testUInt1RoundTrip() throws IOException {
        final int vectorLength = 255;
        try (VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator);
                UInt1Vector encodedVector =
                        (UInt1Vector) createEncodedVector(8, dictionaryVector)) {
            int[] indices = new int[vectorLength];
            String[] dictionaryItems = new String[vectorLength];
            for (int i = 0; i < vectorLength; i++) {
                // Values above 127 wrap in the byte cast and are read back as unsigned.
                encodedVector.setSafe(i, (byte) i);
                indices[i] = i;
                dictionaryItems[i] = String.valueOf(i);
            }
            encodedVector.setValueCount(vectorLength);
            ValueVectorDataPopulator.setVector(dictionaryVector, dictionaryItems);

            byte[] data = writeData(encodedVector);
            readData(
                    data,
                    encodedVector.getField(),
                    (vector, index) -> (int) ((UInt1Vector) vector).getValueAsLong(index),
                    8L,
                    indices,
                    dictionaryItems);
        }
    }

    /** Round-trips a handful of unsigned 16-bit indices, including the maximum value 65535. */
    @Test
    public void testUInt2RoundTrip() throws IOException {
        final int dictionaryLength = 65535;
        try (VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator);
                UInt2Vector encodedVector =
                        (UInt2Vector) createEncodedVector(16, dictionaryVector)) {
            int[] indices = {1, 3, 5, 7, 9, 65535};
            String[] dictionaryItems = new String[dictionaryLength];
            for (int i = 0; i < dictionaryLength; i++) {
                dictionaryItems[i] = String.valueOf(i);
            }
            ValueVectorDataPopulator.setVector(
                    encodedVector, (char) 1, (char) 3, (char) 5, (char) 7, (char) 9, (char) 65535);
            ValueVectorDataPopulator.setVector(dictionaryVector, dictionaryItems);

            byte[] data = writeData(encodedVector);
            readData(
                    data,
                    encodedVector.getField(),
                    (vector, index) -> (int) ((UInt2Vector) vector).getValueAsLong(index),
                    16L,
                    indices,
                    dictionaryItems);
        }
    }

    /** Round-trips five unsigned 32-bit indices with a ten-entry dictionary. */
    @Test
    public void testUInt4RoundTrip() throws IOException {
        final int dictionaryLength = 10;
        try (VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator);
                UInt4Vector encodedVector =
                        (UInt4Vector) createEncodedVector(32, dictionaryVector)) {
            int[] indices = {1, 3, 5, 7, 9};
            String[] dictionaryItems = new String[dictionaryLength];
            for (int i = 0; i < dictionaryLength; i++) {
                dictionaryItems[i] = String.valueOf(i);
            }
            ValueVectorDataPopulator.setVector(encodedVector, 1, 3, 5, 7, 9);
            ValueVectorDataPopulator.setVector(dictionaryVector, dictionaryItems);

            byte[] data = writeData(encodedVector);
            readData(
                    data,
                    encodedVector.getField(),
                    (vector, index) -> (int) ((UInt4Vector) vector).getValueAsLong(index),
                    32L,
                    indices,
                    dictionaryItems);
        }
    }

    /** Round-trips five unsigned 64-bit indices with a ten-entry dictionary. */
    @Test
    public void testUInt8RoundTrip() throws IOException {
        final int dictionaryLength = 10;
        try (VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator);
                UInt8Vector encodedVector =
                        (UInt8Vector) createEncodedVector(64, dictionaryVector)) {
            int[] indices = {1, 3, 5, 7, 9};
            String[] dictionaryItems = new String[dictionaryLength];
            for (int i = 0; i < dictionaryLength; i++) {
                dictionaryItems[i] = String.valueOf(i);
            }
            ValueVectorDataPopulator.setVector(encodedVector, 1L, 3L, 5L, 7L, 9L);
            ValueVectorDataPopulator.setVector(dictionaryVector, dictionaryItems);

            byte[] data = writeData(encodedVector);
            readData(
                    data,
                    encodedVector.getField(),
                    (vector, index) -> (int) ((UInt8Vector) vector).getValueAsLong(index),
                    64L,
                    indices,
                    dictionaryItems);
        }
    }

    /**
     * Supplies the constructor arguments for the parameterized runner.
     *
     * @return one parameter set with stream mode enabled and one with it disabled
     */
    @Parameterized.Parameters(name = "stream mode = {0}")
    public static Collection<Object[]> getRepeat() {
        return Arrays.asList(new Object[]{true}, new Object[]{false});
    }
}
