package org.apache.beam.sdk.extensions.python;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.UUID;
import org.apache.beam.model.pipeline.v1.ExternalTransforms;
import org.apache.beam.repackaged.core.org.apache.commons.lang3.ClassUtils;
import org.apache.beam.runners.core.construction.External;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.schemas.JavaFieldSchema;
import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.schemas.SchemaTranslation;
import org.apache.beam.sdk.schemas.logicaltypes.PythonCallable;
import org.apache.beam.sdk.schemas.utils.StaticSchemaInference;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transformservice.launcher.TransformServiceLauncher;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.PythonCallableSource;
import org.apache.beam.sdk.util.ReleaseInfo;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PInput;
import org.apache.beam.sdk.values.POutput;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Charsets;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Splitter;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;

/* loaded from: input_file:org/apache/beam/sdk/extensions/python/PythonExternalTransform.class */
/**
 * A {@link PTransform} that wraps a Python transform, identified by its fully qualified name,
 * so that it can be applied from a Java pipeline as a cross-language transform.
 *
 * <p>Positional arguments ({@link #withArgs}) and keyword arguments ({@link #withKwarg},
 * {@link #withKwargs(Map)}, {@link #withKwargs(Row)}) are encoded into an
 * {@link ExternalTransforms.ExternalConfigurationPayload} and expanded under the URN
 * {@code beam:transforms:python:fully_qualified_named}. If no expansion service address was
 * supplied, {@link #expand} starts one for the duration of the expansion: the Transform Service
 * when {@code PythonExternalTransformOptions#getUseTransformService()} is set, otherwise a local
 * Python expansion service process.
 *
 * <p>The {@code with*} methods mutate this instance and return it for chaining.
 *
 * @param <InputT> input type of the transform
 * @param <OutputT> output type; a single {@code PCollection} when the expanded transform has
 *     exactly one output, otherwise a {@link PCollectionTuple}
 */
public class PythonExternalTransform<InputT extends PInput, OutputT extends POutput>
        extends PTransform<InputT, OutputT> {

    /** Registry used to derive Row converters for argument values of custom classes. */
    private static final SchemaRegistry SCHEMA_REGISTRY = SchemaRegistry.createDefault();

    /** URN sent to the expansion service for the wrapped Python transform. */
    private static final String PYTHON_FQN_URN = "beam:transforms:python:fully_qualified_named";

    /** Python module launched when an expansion service must be started locally. */
    private static final String PYTHON_EXPANSION_SERVICE_MODULE =
            "apache_beam.runners.portability.expansion_service_main";

    // Timeout values handed to the service wait helpers (presumably milliseconds — confirm
    // against PythonService.waitForPort / TransformServiceLauncher.waitTillUp).
    private static final int PROVIDED_SERVICE_TIMEOUT = 15000;
    private static final int TRANSFORM_SERVICE_TIMEOUT = 15000;
    private static final int PYTHON_SERVICE_TIMEOUT = 60000;

    private static final String KWARGS_CONFLICT_MESSAGE =
            "Kwargs were specified both directly and as a Row object";
    private static final String OUTPUT_CODERS_SET_MESSAGE = "Output coders were already specified";

    private final String fullyQualifiedName;
    /** {@code host:port} of a running expansion service; null or empty means auto-start one. */
    private final String expansionService;
    private List<String> extraPackages = new ArrayList<>();
    /** Keyword arguments, kept sorted so the generated schema has a deterministic field order. */
    private final SortedMap<String, Object> kwargsMap = new TreeMap<>();
    /** Explicit schema field types for argument values of the given classes. */
    private final Map<Class<?>, Schema.FieldType> typeHints = new HashMap<>();
    private Object[] argsArray = new Object[0];
    /** Kwargs supplied as a ready-made Row; null when kwargs come from {@link #kwargsMap}. */
    private Row providedKwargsRow;
    /** Output coders forwarded to the expansion request, keyed by output tag name. */
    Map<String, Coder<?>> outputCoders = new HashMap<>();

    private PythonExternalTransform(String fullyQualifiedName, String expansionService) {
        this.fullyQualifiedName = fullyQualifiedName;
        this.expansionService = expansionService;
        // PythonCallableSource values are always encoded as the PythonCallable logical type.
        this.typeHints.put(
                PythonCallableSource.class, Schema.FieldType.logicalType(new PythonCallable()));
    }

    /**
     * Creates a transform for {@code tranformName} that auto-starts an expansion service.
     *
     * @param tranformName fully qualified name of the Python transform
     */
    public static <InputT extends PInput, OutputT extends POutput>
            PythonExternalTransform<InputT, OutputT> from(String tranformName) {
        return new PythonExternalTransform<>(tranformName, "");
    }

    /**
     * Creates a transform for {@code tranformName} that is expanded by the given service.
     *
     * @param tranformName fully qualified name of the Python transform
     * @param expansionService {@code host:port} of a running expansion service; null or empty
     *     means one is started automatically during {@link #expand}
     */
    public static <InputT extends PInput, OutputT extends POutput>
            PythonExternalTransform<InputT, OutputT> from(
                    String tranformName, String expansionService) {
        return new PythonExternalTransform<>(tranformName, expansionService);
    }

    /** Appends positional arguments after any previously added ones. */
    public PythonExternalTransform<InputT, OutputT> withArgs(Object... args) {
        int existing = argsArray.length;
        Object[] combined = Arrays.copyOf(argsArray, existing + args.length);
        System.arraycopy(args, 0, combined, existing, args.length);
        argsArray = combined;
        return this;
    }

    /**
     * Adds one keyword argument.
     *
     * @throws IllegalArgumentException if kwargs were already supplied as a Row
     */
    public PythonExternalTransform<InputT, OutputT> withKwarg(String name, Object value) {
        if (providedKwargsRow != null) {
            throw new IllegalArgumentException(KWARGS_CONFLICT_MESSAGE);
        }
        kwargsMap.put(name, value);
        return this;
    }

    /**
     * Adds all entries of {@code kwargs} as keyword arguments.
     *
     * @throws IllegalArgumentException if kwargs were already supplied as a Row
     */
    public PythonExternalTransform<InputT, OutputT> withKwargs(Map<String, Object> kwargs) {
        if (providedKwargsRow != null) {
            throw new IllegalArgumentException(KWARGS_CONFLICT_MESSAGE);
        }
        kwargsMap.putAll(kwargs);
        return this;
    }

    /**
     * Supplies all keyword arguments as a single Row, replacing any earlier Row.
     *
     * @throws IllegalArgumentException if kwargs were already supplied individually
     */
    public PythonExternalTransform<InputT, OutputT> withKwargs(Row kwargs) {
        if (!kwargsMap.isEmpty()) {
            throw new IllegalArgumentException(KWARGS_CONFLICT_MESSAGE);
        }
        providedKwargsRow = kwargs;
        return this;
    }

    /**
     * Declares the schema field type to use for argument values of class {@code argType}.
     *
     * @throws IllegalArgumentException if a hint for {@code argType} already exists (including
     *     the built-in hint for {@code PythonCallableSource})
     */
    public PythonExternalTransform<InputT, OutputT> withTypeHint(
            Class<?> argType, Schema.FieldType fieldType) {
        if (typeHints.containsKey(argType)) {
            throw new IllegalArgumentException(
                    String.format("typehint for arg type %s already exists", argType));
        }
        typeHints.put(argType, fieldType);
        return this;
    }

    /**
     * Sets coders for the transform's outputs, keyed by output name.
     *
     * @throws IllegalArgumentException if output coders were already specified
     */
    public PythonExternalTransform<InputT, OutputT> withOutputCoders(
            Map<String, Coder<?>> outputCoders) {
        if (!this.outputCoders.isEmpty()) {
            throw new IllegalArgumentException(OUTPUT_CODERS_SET_MESSAGE);
        }
        this.outputCoders.putAll(outputCoders);
        return this;
    }

    /**
     * Sets the coder for a transform with a single output.
     *
     * @throws IllegalArgumentException if output coders were already specified
     */
    public PythonExternalTransform<InputT, OutputT> withOutputCoder(Coder<?> outputCoder) {
        if (!outputCoders.isEmpty()) {
            throw new IllegalArgumentException(OUTPUT_CODERS_SET_MESSAGE);
        }
        outputCoders.put("random_output_key", outputCoder);
        return this;
    }

    /**
     * Sets Python packages to install into the auto-started expansion service. An empty list is
     * a no-op.
     *
     * @throws IllegalStateException if an explicit expansion service address was given
     */
    public PythonExternalTransform<InputT, OutputT> withExtraPackages(List<String> extraPackages) {
        if (extraPackages.isEmpty()) {
            return this;
        }
        Preconditions.checkState(
                Strings.isNullOrEmpty(expansionService),
                "Extra packages only apply to auto-started expansion service.");
        // Defensive copy: later changes to the caller's list must not alter this transform.
        this.extraPackages = new ArrayList<>(extraPackages);
        return this;
    }

    /** Returns the Row supplied via {@link #withKwargs(Row)}, or one built from the kwargs map. */
    @VisibleForTesting
    Row buildOrGetKwargsRow() {
        if (providedKwargsRow != null) {
            return providedKwargsRow;
        }
        // Both arrays iterate the same sorted map, so names and values line up index by index.
        String[] names = kwargsMap.keySet().toArray(new String[0]);
        Object[] values = kwargsMap.values().toArray();
        return buildRow(values, names);
    }

    /** Builds a Row of the positional arguments with fields named {@code field0, field1, ...}. */
    @VisibleForTesting
    Row buildOrGetArgsRow() {
        return buildRow(argsArray, null);
    }

    /**
     * Builds a Row whose schema is inferred from {@code values} (tagged with a random UUID),
     * converting custom-typed values to nested Rows.
     *
     * @param names field names, or null to generate {@code field<i>} names
     */
    private Row buildRow(Object[] values, String[] names) {
        Schema schema = generateSchemaFromFieldValues(values, names);
        schema.setUUID(UUID.randomUUID());
        return Row.withSchema(schema).addValues(convertComplexTypesToRows(values)).build();
    }

    /**
     * True when values of {@code type} need conversion to a Row: i.e. not a primitive/wrapper,
     * not String, not covered by a type hint, and not already a Row.
     */
    private boolean isCustomType(Class<?> type) {
        boolean builtIn = ClassUtils.isPrimitiveOrWrapper(type) || type == String.class;
        return !(builtIn || typeHints.containsKey(type) || Row.class.isAssignableFrom(type));
    }

    /**
     * Converts {@code value} to a Row via {@link #SCHEMA_REGISTRY}, registering a
     * {@link JavaFieldSchema} provider for its class if no schema is known yet.
     */
    private <T> Row convertCustomValue(T value) {
        @SuppressWarnings("unchecked") // getClass() on a T always yields a Class<? extends T>.
        Class<T> type = (Class<T>) value.getClass();
        SerializableFunction<T, Row> toRow;
        try {
            toRow = SCHEMA_REGISTRY.getToRowFunction(type);
        } catch (NoSuchSchemaException e) {
            SCHEMA_REGISTRY.registerSchemaProvider(type, new JavaFieldSchema());
            try {
                toRow = SCHEMA_REGISTRY.getToRowFunction(type);
            } catch (NoSuchSchemaException e2) {
                throw new RuntimeException(e2);
            }
        }
        return toRow.apply(value);
    }

    /** Returns a copy of {@code values} with custom-typed entries replaced by Rows. */
    private Object[] convertComplexTypesToRows(Object[] values) {
        Object[] converted = new Object[values.length];
        for (int i = 0; i < values.length; i++) {
            Object value = values[i];
            if (value == null) {
                throw new RuntimeException("Null values are not supported");
            }
            converted[i] = isCustomType(value.getClass()) ? convertCustomValue(value) : value;
        }
        return converted;
    }

    /**
     * Infers a schema with one field per value: Rows become row fields, type-hinted classes use
     * their hint, and everything else is inferred from the value's runtime class.
     *
     * @param names field names, or null to generate {@code field<i>} names
     * @throws RuntimeException if any value is null
     */
    private Schema generateSchemaFromFieldValues(Object[] values, String[] names) {
        Schema.Builder builder = Schema.builder();
        for (int i = 0; i < values.length; i++) {
            Object value = values[i];
            if (value == null) {
                throw new RuntimeException("Null field values are not supported");
            }
            String name = names != null ? names[i] : "field" + i;
            Class<?> type = value.getClass();
            if (value instanceof Row) {
                builder.addRowField(name, ((Row) value).getSchema());
            } else if (typeHints.containsKey(type)) {
                builder.addField(name, typeHints.get(type));
            } else {
                builder.addField(
                        name,
                        StaticSchemaInference.fieldFromType(
                                TypeDescriptor.of(type),
                                JavaFieldSchema.JavaFieldTypeSupplier.INSTANCE));
            }
        }
        return builder.build();
    }

    /**
     * Encodes the constructor name plus any non-empty args/kwargs Rows into the expansion
     * payload. Empty args or kwargs are omitted from both the schema and the encoded row.
     */
    @VisibleForTesting
    ExternalTransforms.ExternalConfigurationPayload generatePayload() {
        Row argsRow = buildOrGetArgsRow();
        Row kwargsRow = buildOrGetKwargsRow();
        boolean hasArgs = !argsRow.getValues().isEmpty();
        boolean hasKwargs = !kwargsRow.getValues().isEmpty();

        Schema.Builder schemaBuilder = Schema.builder();
        schemaBuilder.addStringField("constructor");
        if (hasArgs) {
            schemaBuilder.addRowField("args", argsRow.getSchema());
        }
        if (hasKwargs) {
            schemaBuilder.addRowField("kwargs", kwargsRow.getSchema());
        }
        Schema payloadSchema = schemaBuilder.build();
        payloadSchema.setUUID(UUID.randomUUID());

        Row.Builder payloadRow = Row.withSchema(payloadSchema);
        payloadRow.addValue(fullyQualifiedName);
        if (hasArgs) {
            payloadRow.addValue(argsRow);
        }
        if (hasKwargs) {
            payloadRow.addValue(kwargsRow);
        }

        try {
            byte[] encoded =
                    CoderUtils.encodeToByteArray(RowCoder.of(payloadSchema), payloadRow.build());
            return ExternalTransforms.ExternalConfigurationPayload.newBuilder()
                    .setSchema(SchemaTranslation.schemaToProto(payloadSchema, true))
                    .setPayload(ByteString.copyFrom(encoded))
                    .build();
        } catch (CoderException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Expands the Python transform using the configured expansion service, or an auto-started
     * one. Checked exceptions are rethrown wrapped in a {@link RuntimeException}.
     */
    public OutputT expand(InputT input) {
        try {
            ExternalTransforms.ExternalConfigurationPayload payload = generatePayload();
            if (!Strings.isNullOrEmpty(expansionService)) {
                return expandWithProvidedService(input, payload);
            }
            int port = PythonService.findAvailablePort();
            PipelineOptionsFactory.register(PythonExternalTransformOptions.class);
            boolean useTransformService =
                    input.getPipeline()
                            .getOptions()
                            .as(PythonExternalTransformOptions.class)
                            .getUseTransformService();
            return useTransformService
                    ? expandWithTransformService(input, payload, port)
                    : expandWithPythonService(input, payload, port);
        } catch (RuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Waits for the user-supplied {@code host:port} service, then expands against it. */
    private OutputT expandWithProvidedService(
            InputT input, ExternalTransforms.ExternalConfigurationPayload payload)
            throws Exception {
        // Split on the LAST ':' so a host portion containing ':' is kept intact, and reject
        // addresses without a host or port up front instead of failing with an opaque error.
        int portIndex = expansionService.lastIndexOf(':');
        if (portIndex <= 0 || portIndex == expansionService.length() - 1) {
            throw new IllegalArgumentException(
                    "Unexpected expansion service address \""
                            + expansionService
                            + "\". Expected to be in the format \"<host>:<port>\"");
        }
        String host = expansionService.substring(0, portIndex);
        int port = Integer.parseInt(expansionService.substring(portIndex + 1));
        PythonService.waitForPort(host, port, PROVIDED_SERVICE_TIMEOUT);
        return apply(input, expansionService, payload);
    }

    /** Starts a Transform Service on {@code port}, expands against it, and always shuts it down. */
    private OutputT expandWithTransformService(
            InputT input, ExternalTransforms.ExternalConfigurationPayload payload, int port)
            throws Exception {
        // Reject before creating a launcher that would never be started or shut down.
        if (!extraPackages.isEmpty()) {
            throw new RuntimeException(
                    "Transform Service does not support installing extra packages yet");
        }
        TransformServiceLauncher service =
                TransformServiceLauncher.forProject(UUID.randomUUID().toString(), port);
        service.setBeamVersion(ReleaseInfo.getReleaseInfo().getSdkVersion());
        try {
            service.start();
            service.waitTillUp(TRANSFORM_SERVICE_TIMEOUT);
            return apply(input, localAddress(port), payload);
        } finally {
            service.shutdown();
        }
    }

    /** Starts a local Python expansion service on {@code port} for the duration of expansion. */
    private OutputT expandWithPythonService(
            InputT input, ExternalTransforms.ExternalConfigurationPayload payload, int port)
            throws Exception {
        ImmutableList.Builder<String> args = ImmutableList.builder();
        args.add("--port=" + port, "--fully_qualified_name_glob=*", "--pickle_library=cloudpickle");
        if (!extraPackages.isEmpty()) {
            args.add("--requirements_file=" + writeRequirementsFile(extraPackages));
        }
        try (AutoCloseable running =
                new PythonService(PYTHON_EXPANSION_SERVICE_MODULE, args.build())
                        .withExtraPackages(extraPackages)
                        .start()) {
            PythonService.waitForPort("localhost", port, PYTHON_SERVICE_TIMEOUT);
            return apply(input, localAddress(port), payload);
        }
    }

    /**
     * Writes {@code packages} one per line (UTF-8) to a temp file deleted on JVM exit.
     *
     * @return the absolute path of the written file
     */
    private static String writeRequirementsFile(List<String> packages) throws IOException {
        File requirementsFile = File.createTempFile("requirements", ".txt");
        requirementsFile.deleteOnExit();
        try (OutputStreamWriter writer =
                new OutputStreamWriter(new FileOutputStream(requirementsFile), Charsets.UTF_8)) {
            for (String pkg : packages) {
                writer.write(pkg);
                writer.write('\n');
            }
        }
        return requirementsFile.getAbsolutePath();
    }

    private static String localAddress(int port) {
        return String.format("localhost:%s", port);
    }

    /**
     * Applies the external transform to {@code input}. A single output is unwrapped to its
     * PCollection; multiple outputs are returned as the {@link PCollectionTuple}.
     *
     * @throws RuntimeException if {@code input} is not a PCollection, PCollectionTuple or PBegin
     */
    @SuppressWarnings("unchecked") // OutputT is chosen by the caller to match the expanded outputs.
    private OutputT apply(
            InputT input,
            String expansionServiceAddress,
            ExternalTransforms.ExternalConfigurationPayload payload) {
        External.MultiOutputExpandableTransform transform =
                External.of(PYTHON_FQN_URN, payload.toByteArray(), expansionServiceAddress)
                        .withMultiOutputs()
                        .withOutputCoder(outputCoders);
        PCollectionTuple outputs;
        if (input instanceof PCollection) {
            outputs = ((PCollection<?>) input).apply(transform);
        } else if (input instanceof PCollectionTuple) {
            outputs = ((PCollectionTuple) input).apply(transform);
        } else if (input instanceof PBegin) {
            outputs = ((PBegin) input).apply(transform);
        } else {
            throw new RuntimeException("Unhandled input type " + input.getClass());
        }
        Set<TupleTag<?>> tags = outputs.getAll().keySet();
        if (tags.size() == 1) {
            return (OutputT) outputs.get(Iterables.getOnlyElement(tags));
        }
        return (OutputT) outputs;
    }
}
