package io.trino.plugin.functions.python;

import com.dylibso.chicory.runtime.ExportFunction;
import com.dylibso.chicory.runtime.HostFunction;
import com.dylibso.chicory.runtime.ImportFunction;
import com.dylibso.chicory.runtime.ImportValues;
import com.dylibso.chicory.runtime.Instance;
import com.dylibso.chicory.runtime.Memory;
import com.dylibso.chicory.wasi.WasiOptions;
import com.dylibso.chicory.wasi.WasiPreview1;
import com.dylibso.chicory.wasm.ChicoryException;
import com.dylibso.chicory.wasm.WasmModule;
import com.dylibso.chicory.wasm.types.ValueType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.io.Closer;
import com.google.common.jimfs.Configuration;
import com.google.common.jimfs.Jimfs;
import io.airlift.log.Logger;
import io.airlift.slice.BasicSliceInput;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.units.DataSize;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.type.Type;
import io.trino.wasm.python.PythonModule;
import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.IOException;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.FileSystem;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Stream;

/**
 * Runs a Python guest program inside a Chicory WebAssembly instance of {@code PythonModule} and
 * invokes it with Trino values.
 *
 * <p>The guest source is written to {@code /guest/guest.py} in a private in-memory (Jimfs) file
 * system, which is exposed to the guest as a WASI directory. Guest stdout is forwarded to a
 * {@code LoggingOutputStream} over this class's logger. The first 4096 bytes of guest stderr are
 * captured and appended to the message of fatal errors.
 *
 * <p>All resources created by the constructor are released by {@link #close()}.
 *
 * <p>NOTE(review): per-call state ({@code error}) is written by a host callback during
 * {@link #execute(Object[])}, so an instance presumably must be confined to one thread — confirm.
 */
public final class PythonEngine implements Closeable {
    private static final Logger log = Logger.get(PythonEngine.class);
    private static final com.dylibso.chicory.log.Logger logger = JdkLogger.get(PythonEngine.class);

    // Unix-style in-memory file system for the guest; total size is capped at 8 MB.
    private static final Configuration FS_CONFIG = Configuration.unix().toBuilder()
            .setAttributeViews("unix")
            .setMaxSize(DataSize.of(8, DataSize.Unit.MEGABYTE).toBytes())
            .build();

    // Numeric error code reported by the guest -> matching StandardErrorCode.
    private static final Map<Integer, ErrorCodeSupplier> ERROR_CODES = Stream.of(StandardErrorCode.values())
            .collect(ImmutableMap.<StandardErrorCode, Integer, ErrorCodeSupplier>toImmutableMap(
                    code -> code.toErrorCode().getCode(),
                    Function.identity()));

    private static final WasmModule PYTHON_MODULE = PythonModule.load();

    private final Closer closer = Closer.create();
    private final LimitedOutputStream stderr = new LimitedOutputStream();
    private final ExportFunction allocate;
    private final ExportFunction deallocate;
    private final ExportFunction setup;
    private final ExportFunction execute;
    private final Memory memory;

    private Type returnType;
    private List<Type> argumentTypes;
    // Set by the "trino.return_error" host callback while the guest runs; cleared before each call.
    private TrinoException error;

    /**
     * In-memory stream that keeps at most {@code LIMIT} bytes and silently discards the rest.
     * Used to capture guest stderr without letting it grow without bound.
     */
    public static class LimitedOutputStream extends ByteArrayOutputStream {
        private static final int LIMIT = 4096;

        private LimitedOutputStream() {
        }

        @Override
        public synchronized void write(int b) {
            // Single-byte writes must respect the cap too; ByteArrayOutputStream would append unconditionally.
            if (count < LIMIT) {
                super.write(b);
            }
        }

        @Override
        public synchronized void write(byte[] bArr, int i, int i2) {
            if (count < LIMIT) {
                // Keep only as much of this chunk as still fits under the cap.
                super.write(bArr, i, Math.min(i2, LIMIT - count));
            }
        }
    }

    /**
     * Creates the guest file system, writes the Python source into it, and instantiates the
     * WebAssembly module with WASI and the {@code trino.return_error} host function.
     *
     * @param str Python source code, written to {@code /guest/guest.py}
     * @throws UncheckedIOException if the guest file system cannot be populated
     */
    public PythonEngine(String str) {
        try {
            Path guestDir = closer.register(Jimfs.newFileSystem(FS_CONFIG)).getPath("/guest");
            Files.createDirectories(guestDir);
            Files.writeString(guestDir.resolve("guest.py"), str);

            WasiOptions wasiOptions = WasiOptions.builder()
                    .withStdout(closer.register(new LoggingOutputStream(log)))
                    .withStderr(stderr)
                    .withDirectory(guestDir.toString(), guestDir)
                    .build();
            WasiPreview1 wasi = closer.register(WasiPreview1.builder()
                    .withLogger(logger)
                    .withOptions(wasiOptions)
                    .build());
            ImportValues imports = ImportValues.builder()
                    .addFunction(wasi.toHostFunctions())
                    .addFunction(returnErrorHostFunction())
                    .build();
            Instance instance = Instance.builder(PYTHON_MODULE)
                    .withMachineFactory(PythonModule::create)
                    .withImportValues(imports)
                    .build();

            this.allocate = instance.export("allocate");
            this.deallocate = instance.export("deallocate");
            this.setup = instance.export("setup");
            this.execute = instance.export("execute");
            this.memory = instance.memory();
        } catch (IOException e) {
            UncheckedIOException failure = new UncheckedIOException(e);
            closeAfterFailure(failure);
            throw failure;
        } catch (RuntimeException | Error e) {
            // Do not leak the file system / WASI resources already registered with the closer.
            closeAfterFailure(e);
            throw e;
        }
    }

    /** Closes everything registered so far, attaching any close failure to {@code failure}. */
    private void closeAfterFailure(Throwable failure) {
        try {
            closer.close();
        } catch (IOException | RuntimeException e) {
            failure.addSuppressed(e);
        }
    }

    /**
     * Initializes the guest for the given signature by calling its {@code setup} export.
     *
     * @param type the function's return type
     * @param list the function's argument types
     * @param str string passed to the guest as a NUL-terminated UTF-8 buffer (presumably the
     *     Python handler name — confirm against callers)
     * @throws TrinoException with {@code FUNCTION_IMPLEMENTATION_ERROR} if the guest fails
     */
    public void setup(Type type, List<Type> list, String str) {
        try {
            doSetup(type, list, str);
        } catch (ChicoryException e) {
            throw fatalError("Python error", e);
        }
    }

    private void doSetup(Type returnType, List<Type> argumentTypes, String str) {
        // Validate before touching guest memory so a bad call does not leave allocations behind.
        Objects.requireNonNull(returnType, "returnType is null");
        Objects.requireNonNull(argumentTypes, "argumentTypes is null");

        byte[] nameBytes = str.getBytes(StandardCharsets.UTF_8);
        int nameAddress = allocate(nameBytes.length + 1);
        memory.write(nameAddress, nameBytes);
        memory.writeByte(nameAddress + nameBytes.length, (byte) 0);

        Slice argumentDescriptor = TrinoTypes.toRowTypeDescriptor(argumentTypes);
        int argumentDescriptorAddress = allocate(argumentDescriptor.length());
        writeSliceTo(argumentDescriptor, argumentDescriptorAddress);

        Slice returnDescriptor = TrinoTypes.toTypeDescriptor(returnType);
        int returnDescriptorAddress = allocate(returnDescriptor.length());
        writeSliceTo(returnDescriptor, returnDescriptorAddress);

        setup.apply(nameAddress, argumentDescriptorAddress, returnDescriptorAddress);
        // NOTE(review): only the name buffer is freed here; the two descriptor buffers are left
        // allocated — presumably the guest keeps or owns them. Confirm before freeing them.
        deallocate(nameAddress);

        this.returnType = returnType;
        this.argumentTypes = ImmutableList.copyOf(argumentTypes);
    }

    /** Copies the readable bytes of {@code slice} into guest memory starting at {@code address}. */
    private void writeSliceTo(Slice slice, int address) {
        memory.write(address, slice.byteArray(), slice.byteArrayOffset(), slice.length());
    }

    /** Calls the guest {@code allocate} export; returns the guest address of the new buffer. */
    private int allocate(int size) {
        return Math.toIntExact(allocate.apply(size)[0]);
    }

    /** Calls the guest {@code deallocate} export for a buffer previously returned by {@link #allocate(int)}. */
    private void deallocate(int address) {
        deallocate.apply(address);
    }

    /** Calls the guest {@code execute} export; returns the guest address of the result, or 0. */
    private int execute(int argumentsAddress) {
        return Math.toIntExact(execute.apply(argumentsAddress)[0]);
    }

    /**
     * Invokes the Python function with the given argument values.
     *
     * @param objArr argument values, encoded with the argument types given to {@code setup}
     * @return the decoded result, using the return type given to {@code setup}
     * @throws TrinoException the error reported by the guest through {@code return_error}, or
     *     {@code FUNCTION_IMPLEMENTATION_ERROR} if no result is returned or the guest fails
     */
    public Object execute(Object[] objArr) {
        Slice arguments = TrinoTypes.javaToBinary(argumentTypes, objArr);
        int argumentsAddress = allocate(arguments.length());
        writeSliceTo(arguments, argumentsAddress);

        error = null;
        try {
            int resultAddress = execute(argumentsAddress);
            deallocate(argumentsAddress);

            TrinoException guestError = error;
            if (guestError != null) {
                // Re-create the reported error so the thrown exception is raised from this call site.
                throw new TrinoException(guestError::getErrorCode, guestError.getMessage(), guestError.getCause());
            }
            if (resultAddress == 0) {
                throw new TrinoException(StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR, "Python function did not return a result");
            }

            // Result buffer layout: 4-byte length prefix followed by the encoded value.
            int length = memory.readInt(resultAddress);
            byte[] resultBytes = memory.readBytes(resultAddress + 4, length);
            deallocate(resultAddress);

            return TrinoTypes.binaryToJava(returnType, new BasicSliceInput(Slices.wrappedBuffer(resultBytes)));
        } catch (ChicoryException e) {
            throw fatalError("Failed to invoke Python function", e);
        }
    }

    /**
     * Builds a {@code FUNCTION_IMPLEMENTATION_ERROR} for a WebAssembly-level failure, appending any
     * captured guest stderr to the message. Multi-line stderr starts on a new line.
     *
     * @param str base error message
     * @param chicoryException the underlying failure, kept as the cause
     * @return the exception to throw
     */
    public TrinoException fatalError(String str, ChicoryException chicoryException) {
        String message = str;
        String guestStderr = stderr.toString(StandardCharsets.UTF_8).strip();
        if (!guestStderr.isEmpty()) {
            String separator = guestStderr.contains("\n") ? ":\n" : ": ";
            message = message + separator + guestStderr;
        }
        return new TrinoException(StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR, message, chicoryException);
    }

    /** Releases the guest file system, the stdout stream and the WASI instance. */
    @Override
    public void close() {
        try {
            closer.close();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Handler for the guest-imported {@code trino.return_error} function. It records the error in
     * {@link #error} for {@link #execute(Object[])} to rethrow.
     *
     * <p>Arguments: error code, message address, message length, traceback address (0 if none),
     * traceback length.
     */
    private long[] returnError(Instance instance, long... args) {
        int code = Math.toIntExact(args[0]);
        int messageAddress = Math.toIntExact(args[1]);
        int messageLength = Math.toIntExact(args[2]);
        int tracebackAddress = Math.toIntExact(args[3]);
        int tracebackLength = Math.toIntExact(args[4]);

        Memory guestMemory = instance.memory();
        String message = guestMemory.readString(messageAddress, messageLength);
        RuntimeException cause = null;
        if (tracebackAddress != 0) {
            cause = new RuntimeException("Python traceback:\n" + guestMemory.readString(tracebackAddress, tracebackLength).stripTrailing());
        }

        ErrorCodeSupplier errorCode = ERROR_CODES.get(code);
        if (errorCode == null) {
            errorCode = StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR;
            message = "Unknown error code (%s): %s".formatted(code, message);
        }
        error = new TrinoException(errorCode, message, cause);
        // The import declares no results.
        return null;
    }

    /** Declares {@code trino.return_error(i32, i32, i32, i32, i32) -> ()}, bound to {@link #returnError}. */
    private HostFunction returnErrorHostFunction() {
        return new HostFunction("trino", "return_error", List.of(ValueType.I32, ValueType.I32, ValueType.I32, ValueType.I32, ValueType.I32), List.of(), this::returnError);
    }
}
