package ai.chalk.internal.arrow;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.lang.reflect.Array;
import java.nio.channels.Channels;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.arrow.compression.CommonsCompressionFactory;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.ArrowFileReader;
import org.apache.arrow.vector.ipc.ArrowFileWriter;
import org.apache.arrow.vector.ipc.SeekableReadChannel;
import org.apache.arrow.vector.table.Table;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
import org.apache.arrow.vector.util.VectorSchemaRootAppender;

/* loaded from: input_file:ai/chalk/internal/arrow/FeatherProcessor.class */
/**
 * Converts between plain Java column data and Arrow IPC file-format bytes (written with
 * {@link ArrowFileWriter}, read back with {@link ArrowFileReader}).
 *
 * <p>Supported input element types are {@code Byte}, {@code Short}, {@code Integer},
 * {@code Long} (all encoded as signed 64-bit integers), {@code Float}, {@code Double},
 * {@code String} (encoded as UTF-8) and {@code Boolean}.
 */
public class FeatherProcessor {
    /** Maps a Java element class to the Arrow column type used to encode it. Read-only. */
    private static final Map<Class<?>, ArrowType> javaToArrowType = buildTypeMapping();

    /** Builds the immutable contents of {@link #javaToArrowType}. */
    private static Map<Class<?>, ArrowType> buildTypeMapping() {
        ArrowType int64 = new ArrowType.Int(64, true);
        Map<Class<?>, ArrowType> mapping = new HashMap<>();
        mapping.put(Byte.class, int64);
        mapping.put(Short.class, int64);
        mapping.put(Integer.class, int64);
        mapping.put(Long.class, int64);
        mapping.put(Float.class, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE));
        mapping.put(Double.class, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE));
        mapping.put(String.class, ArrowType.Utf8.INSTANCE);
        mapping.put(Boolean.class, ArrowType.Bool.INSTANCE);
        return Collections.unmodifiableMap(mapping);
    }

    /**
     * Encodes named columns as an Arrow IPC file containing a single record batch.
     *
     * <p>Each map entry becomes one nullable column, in the map's iteration order. A column's
     * type is inferred from its non-null values, which must all map to the same Arrow type;
     * {@code null} elements are written as Arrow nulls.
     *
     * @param map column name to column values; every list must be non-empty and all lists must
     *     have the same length
     * @return the Arrow IPC file bytes
     * @throws Exception if a column is null or empty, contains an unsupported or mixed element
     *     type, contains only nulls, if the columns differ in length, or if writing fails
     */
    public static byte[] inputsToArrowBytes(Map<String, List<?>> map) throws Exception {
        Map<String, List<?>> columns = snapshotColumns(map);
        int rowCount = commonRowCount(columns);

        List<Field> fields = new ArrayList<>(columns.size());
        for (Map.Entry<String, List<?>> column : columns.entrySet()) {
            ArrowType type = inferArrowType(column.getKey(), column.getValue());
            fields.add(new Field(column.getKey(), FieldType.nullable(type), null));
        }

        // A single allocator backs every vector. Closing the root and then the allocator
        // releases all buffers, on both the success and the failure path.
        try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
                VectorSchemaRoot root = VectorSchemaRoot.create(new Schema(fields), allocator)) {
            for (Field field : fields) {
                fillVector(root.getVector(field.getName()), columns.get(field.getName()));
            }
            root.setRowCount(rowCount);
            return writeSingleBatch(root);
        }
    }

    /**
     * Copies every column into a fresh list, preserving the caller's column order.
     *
     * @throws Exception if a column's value is null or an empty list
     */
    private static Map<String, List<?>> snapshotColumns(Map<String, List<?>> map) throws Exception {
        Map<String, List<?>> columns = new LinkedHashMap<>();
        for (Map.Entry<String, List<?>> entry : map.entrySet()) {
            try {
                List<Object> values = new ArrayList<>(entry.getValue());
                if (values.isEmpty()) {
                    throw new Exception("Input values is an `Array` or a `List` of length 0");
                }
                columns.put(entry.getKey(), values);
            } catch (Exception e) {
                throw new Exception(
                        String.format(
                                "error converting '%s' value to a `List<Object>`: %s",
                                entry.getKey(), e.getMessage()),
                        e);
            }
        }
        return columns;
    }

    /**
     * Returns the length shared by all columns, or 0 when there are no columns.
     *
     * @throws Exception if two columns have different lengths
     */
    private static int commonRowCount(Map<String, List<?>> columns) throws Exception {
        Map.Entry<String, List<?>> reference = null;
        for (Map.Entry<String, List<?>> column : columns.entrySet()) {
            if (reference == null) {
                reference = column;
            } else if (column.getValue().size() != reference.getValue().size()) {
                throw new Exception(
                        String.format(
                                "column '%s' has %d values but column '%s' has %d",
                                column.getKey(), column.getValue().size(),
                                reference.getKey(), reference.getValue().size()));
            }
        }
        return reference == null ? 0 : reference.getValue().size();
    }

    /**
     * Infers the Arrow type of a column from its non-null values.
     *
     * @throws Exception if a value's class is unsupported, if values map to different Arrow
     *     types, or if every value is null
     */
    private static ArrowType inferArrowType(String name, List<?> values) throws Exception {
        ArrowType columnType = null;
        for (Object value : values) {
            if (value == null) {
                continue;
            }
            ArrowType valueType = javaToArrowType.get(value.getClass());
            if (valueType == null) {
                throw new Exception("Unsupported data type: " + value.getClass().getSimpleName());
            }
            if (columnType == null) {
                columnType = valueType;
            } else if (!columnType.equals(valueType)) {
                throw new Exception(
                        String.format(
                                "column '%s' mixes value types %s and %s", name, columnType, valueType));
            }
        }
        if (columnType == null) {
            throw new Exception(
                    String.format("cannot infer a type for column '%s': all values are null", name));
        }
        return columnType;
    }

    /**
     * Allocates {@code vector} and writes {@code values} into it, one row per element.
     * {@code null} elements become Arrow nulls. Assumes the values already passed
     * {@link #inferArrowType}, so every non-null element matches the vector's type.
     */
    private static void fillVector(FieldVector vector, List<?> values) throws Exception {
        int size = values.size();
        ArrowType type = vector.getField().getType();
        switch (type.getTypeID()) {
            case Int: {
                BigIntVector target = (BigIntVector) vector;
                target.allocateNew(size);
                for (int row = 0; row < size; row++) {
                    Object value = values.get(row);
                    if (value == null) {
                        target.setNull(row);
                    } else {
                        // Byte, Short, Integer and Long all widen losslessly to int64.
                        target.set(row, ((Number) value).longValue());
                    }
                }
                break;
            }
            case FloatingPoint: {
                if (((ArrowType.FloatingPoint) type).getPrecision() == FloatingPointPrecision.SINGLE) {
                    Float4Vector target = (Float4Vector) vector;
                    target.allocateNew(size);
                    for (int row = 0; row < size; row++) {
                        Object value = values.get(row);
                        if (value == null) {
                            target.setNull(row);
                        } else {
                            target.set(row, ((Float) value).floatValue());
                        }
                    }
                } else {
                    Float8Vector target = (Float8Vector) vector;
                    target.allocateNew(size);
                    for (int row = 0; row < size; row++) {
                        Object value = values.get(row);
                        if (value == null) {
                            target.setNull(row);
                        } else {
                            target.set(row, ((Double) value).doubleValue());
                        }
                    }
                }
                break;
            }
            case Utf8: {
                VarCharVector target = (VarCharVector) vector;
                target.allocateNew(size);
                for (int row = 0; row < size; row++) {
                    Object value = values.get(row);
                    if (value == null) {
                        target.setNull(row);
                    } else {
                        // allocateNew(valueCount) sizes the data buffer from a default per-value
                        // estimate and set() does not grow it; setSafe() reallocates as needed.
                        target.setSafe(row, ((String) value).getBytes(StandardCharsets.UTF_8));
                    }
                }
                break;
            }
            case Bool: {
                BitVector target = (BitVector) vector;
                target.allocateNew(size);
                for (int row = 0; row < size; row++) {
                    Object value = values.get(row);
                    if (value == null) {
                        target.setNull(row);
                    } else {
                        target.set(row, ((Boolean) value).booleanValue() ? 1 : 0);
                    }
                }
                break;
            }
            default:
                throw new Exception("Unsupported arrow type: " + type.getTypeID());
        }
        vector.setValueCount(size);
    }

    /** Serializes {@code root}'s current contents as one record batch of an Arrow IPC file. */
    private static byte[] writeSingleBatch(VectorSchemaRoot root) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        try (ArrowFileWriter writer =
                new ArrowFileWriter(root, (DictionaryProvider) null, Channels.newChannel(out))) {
            writer.start();
            writer.writeBatch();
            writer.end();
        }
        return out.toByteArray();
    }

    /**
     * Opens a reader over in-memory Arrow IPC file bytes. The reader gets its own allocator,
     * which this class never closes: a Table built from a batch may still reference its memory.
     */
    private static ArrowFileReader openReader(byte[] bytes) {
        return new ArrowFileReader(
                new SeekableReadChannel(new ByteArrayReadableSeekableByteChannel(bytes)),
                new RootAllocator(Long.MAX_VALUE),
                new CommonsCompressionFactory());
    }

    /**
     * Returns the file's contents as a {@link Table} if it holds exactly one record batch.
     *
     * <p>Reading stops as soon as a second batch is found. The caller owns (and should close)
     * the returned Table.
     *
     * @param bArr Arrow IPC file bytes
     * @return the single batch as a Table, or {@code null} if the file has zero or several batches
     * @throws Exception if the bytes cannot be read as an Arrow IPC file
     */
    public static Table getTableIfBatchSizeOne(byte[] bArr) throws Exception {
        try (ArrowFileReader reader = openReader(bArr)) {
            VectorSchemaRoot root = reader.getVectorSchemaRoot();
            if (!reader.loadNextBatch()) {
                return null;
            }
            // Per Arrow's Table(VectorSchemaRoot) contract, the Table takes over the root's data,
            // so it stays valid after the reader (and with it the root) is closed.
            Table table = new Table(root);
            boolean hasMoreBatches;
            try {
                hasMoreBatches = reader.loadNextBatch();
            } catch (Throwable t) {
                table.close();
                throw t;
            }
            if (hasMoreBatches) {
                table.close();
                return null;
            }
            return table;
        }
    }

    /**
     * Reads an Arrow IPC file into a single {@link Table}, concatenating all record batches.
     *
     * <p>A single-batch file is returned without copying. Otherwise every batch (possibly none)
     * is appended into one combined table. The caller owns (and should close) the result.
     *
     * @param bArr Arrow IPC file bytes
     * @return a Table holding every row in the file
     * @throws Exception if the bytes cannot be read as an Arrow IPC file
     */
    public static Table convertBytesToTable(byte[] bArr) throws Exception {
        Table singleBatch = getTableIfBatchSizeOne(bArr);
        if (singleBatch != null) {
            return singleBatch;
        }
        // The combined data lives in its own allocator, so it does not depend on the reader's
        // per-batch buffers that are released when the reader closes.
        BufferAllocator tableAllocator = new RootAllocator(Long.MAX_VALUE);
        try (ArrowFileReader reader = openReader(bArr)) {
            VectorSchemaRoot batch = reader.getVectorSchemaRoot();
            try (VectorSchemaRoot combined =
                    VectorSchemaRoot.create(batch.getSchema(), tableAllocator)) {
                combined.allocateNew();
                while (reader.loadNextBatch()) {
                    VectorSchemaRootAppender.append(combined, batch);
                }
                // The Table takes over combined's data, so closing the emptied root is safe.
                return new Table(combined);
            }
        }
    }
}
