package org.apache.arrow.compression;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.channels.Channels;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.GenerateSampleData;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.compression.CompressionUtil;
import org.apache.arrow.vector.compression.NoCompressionCodec;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.ArrowFileReader;
import org.apache.arrow.vector.ipc.ArrowFileWriter;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
import org.apache.arrow.vector.util.CallBack;
import org.junit.After;
import org.junit.Assert;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/**
 * Round-trip tests for the Arrow IPC file and stream formats written with ZSTD buffer compression.
 *
 * <p>Each test writes a single record batch through {@link CommonsCompressionFactory}, reads it back
 * with the same factory and checks it equals the written data, then re-reads it with
 * {@link NoCompressionCodec.Factory} and checks that loading the batch fails with a message that
 * points at the arrow-compression module.
 */
public class TestArrowReaderWriterWithCompression {
    /** Message expected when a ZSTD-compressed batch is loaded without a ZSTD-capable factory. */
    private static final String MISSING_ZSTD_CODEC_MESSAGE =
            "Please add arrow-compression module to use CommonsCompressionFactory for ZSTD";

    /** Number of rows written into the single record batch. */
    private static final int ROW_COUNT = 10;

    /** Compression level handed to the writers. */
    private static final int COMPRESSION_LEVEL = 7;

    private BufferAllocator allocator;
    private ByteArrayOutputStream out;
    private VectorSchemaRoot root;

    @BeforeEach
    public void setup() {
        // Always start from a fresh allocator: tearDown closes it, so it must never be reused.
        allocator = new RootAllocator(Integer.MAX_VALUE);
        out = new ByteArrayOutputStream();
        root = null;
    }

    // Must be the Jupiter @AfterEach: a JUnit 4 @After is ignored by the Jupiter engine, which
    // left the root and the allocator open after every test.
    @AfterEach
    public void tearDown() {
        // Release the root's buffers before closing the allocator they were allocated from.
        if (root != null) {
            root.close();
        }
        if (allocator != null) {
            allocator.close();
        }
        if (out != null) {
            out.reset();
        }
    }

    /**
     * Creates {@link #root} with a single non-nullable UTF-8 column named "col", fills it with
     * generated sample data, and sets its row count to {@link #ROW_COUNT}.
     */
    private void createAndPopulateRoot() {
        List<Field> fields = new ArrayList<>();
        fields.add(new Field("col", FieldType.notNullable(new ArrowType.Utf8()), new ArrayList<>()));
        root = VectorSchemaRoot.create(new Schema(fields), allocator);
        GenerateSampleData.generateTestData(root.getVector(0), ROW_COUNT);
        root.setRowCount(ROW_COUNT);
    }

    /**
     * Populates {@link #root} and writes it to {@link #out} as an Arrow IPC file containing one batch.
     *
     * @param dictionaryProvider dictionaries handed to the writer; may be {@code null}
     * @param codecType compression codec used for the record batch buffers
     * @throws IOException if writing to the output channel fails
     */
    private void createAndWriteArrowFile(DictionaryProvider dictionaryProvider, CompressionUtil.CodecType codecType)
            throws IOException {
        createAndPopulateRoot();
        try (ArrowFileWriter writer = new ArrowFileWriter(root, dictionaryProvider, Channels.newChannel(out),
                new HashMap<>(), IpcOption.DEFAULT, CommonsCompressionFactory.INSTANCE, codecType,
                Optional.of(COMPRESSION_LEVEL))) {
            writer.start();
            writer.writeBatch();
            writer.end();
        }
    }

    /**
     * Populates {@link #root} and writes it to {@link #out} as an Arrow IPC stream containing one batch.
     *
     * @param dictionaryProvider dictionaries handed to the writer; may be {@code null}
     * @param codecType compression codec used for the record batch buffers
     * @throws IOException if writing to the output channel fails
     */
    private void createAndWriteArrowStream(DictionaryProvider dictionaryProvider, CompressionUtil.CodecType codecType)
            throws IOException {
        createAndPopulateRoot();
        try (ArrowStreamWriter writer = new ArrowStreamWriter(root, dictionaryProvider, Channels.newChannel(out),
                IpcOption.DEFAULT, CommonsCompressionFactory.INSTANCE, codecType, Optional.of(COMPRESSION_LEVEL))) {
            writer.start();
            writer.writeBatch();
            writer.end();
        }
    }

    /**
     * Fills {@code dictionaryVector} with "foo", "bar", "baz" and wraps it as an unordered dictionary
     * with id 1.
     */
    private Dictionary createDictionary(VarCharVector dictionaryVector) {
        setVector(dictionaryVector,
                "foo".getBytes(StandardCharsets.UTF_8),
                "bar".getBytes(StandardCharsets.UTF_8),
                "baz".getBytes(StandardCharsets.UTF_8));
        return new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, (ArrowType.Int) null));
    }

    /**
     * Reads {@link #out} as an IPC file: with the compression factory it must yield exactly one batch
     * equal to {@link #root}; with the no-compression factory loading must fail.
     */
    private void verifyFileRoundTrip() throws IOException {
        try (ArrowFileReader reader = new ArrowFileReader(
                new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
                CommonsCompressionFactory.INSTANCE)) {
            Assertions.assertEquals(1, reader.getRecordBlocks().size());
            Assertions.assertTrue(reader.loadNextBatch());
            Assertions.assertTrue(root.equals(reader.getVectorSchemaRoot()));
            Assertions.assertFalse(reader.loadNextBatch());
        }
        try (ArrowFileReader reader = new ArrowFileReader(
                new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
                NoCompressionCodec.Factory.INSTANCE)) {
            Assertions.assertEquals(1, reader.getRecordBlocks().size());
            IllegalArgumentException e =
                    Assertions.assertThrows(IllegalArgumentException.class, reader::loadNextBatch);
            Assertions.assertEquals(MISSING_ZSTD_CODEC_MESSAGE, e.getMessage());
        }
    }

    /**
     * Reads {@link #out} as an IPC stream: with the compression factory it must yield exactly one batch
     * equal to {@link #root}; with the no-compression factory loading must fail.
     */
    private void verifyStreamRoundTrip() throws IOException {
        try (ArrowStreamReader reader = new ArrowStreamReader(
                new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
                CommonsCompressionFactory.INSTANCE)) {
            Assertions.assertTrue(reader.loadNextBatch());
            Assertions.assertTrue(root.equals(reader.getVectorSchemaRoot()));
            Assertions.assertFalse(reader.loadNextBatch());
        }
        try (ArrowStreamReader reader = new ArrowStreamReader(
                new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
                NoCompressionCodec.Factory.INSTANCE)) {
            IllegalArgumentException e =
                    Assertions.assertThrows(IllegalArgumentException.class, reader::loadNextBatch);
            Assertions.assertEquals(MISSING_ZSTD_CODEC_MESSAGE, e.getMessage());
        }
    }

    @Test
    public void testArrowFileZstdRoundTrip() throws Exception {
        createAndWriteArrowFile(null, CompressionUtil.CodecType.ZSTD);
        verifyFileRoundTrip();
    }

    @Test
    public void testArrowStreamZstdRoundTrip() throws Exception {
        createAndWriteArrowStream(null, CompressionUtil.CodecType.ZSTD);
        verifyStreamRoundTrip();
    }

    @Test
    public void testArrowFileZstdRoundTripWithDictionary() throws Exception {
        // try-with-resources closes the dictionary vector even when an assertion fails.
        try (VarCharVector dictionaryVector = (VarCharVector) FieldType.nullable(new ArrowType.Utf8())
                .createNewSingleVector("f1_file", allocator, (CallBack) null)) {
            DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider();
            provider.put(createDictionary(dictionaryVector));
            createAndWriteArrowFile(provider, CompressionUtil.CodecType.ZSTD);
            verifyFileRoundTrip();
        }
    }

    @Test
    public void testArrowStreamZstdRoundTripWithDictionary() throws Exception {
        // try-with-resources closes the dictionary vector even when an assertion fails.
        try (VarCharVector dictionaryVector = (VarCharVector) FieldType.nullable(new ArrowType.Utf8())
                .createNewSingleVector("f1_stream", allocator, (CallBack) null)) {
            DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider();
            provider.put(createDictionary(dictionaryVector));
            createAndWriteArrowStream(provider, CompressionUtil.CodecType.ZSTD);
            verifyStreamRoundTrip();
        }
    }

    /**
     * Allocates {@code vector} and writes {@code values} into it by index, skipping {@code null}
     * entries, then sets its value count to {@code values.length}.
     *
     * @param vector the vector to fill
     * @param values the byte values to store; {@code null} elements are not written
     */
    public static void setVector(VarCharVector vector, byte[]... values) {
        vector.allocateNewSafe();
        for (int i = 0; i < values.length; i++) {
            if (values[i] != null) {
                vector.set(i, values[i]);
            }
        }
        vector.setValueCount(values.length);
    }
}
