package org.apache.beam.sdk.extensions.python.transforms;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.extensions.python.PythonExternalTransform;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.util.PythonCallableSource;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/beam/sdk/extensions/python/transforms/RunInference.class */
/**
 * Java-side wrapper that applies the Python transform
 * {@code apache_beam.ml.inference.base.RunInference.from_callable} through
 * {@link PythonExternalTransform}.
 *
 * <p>Instances are created through the static factories {@code of} (unkeyed, output elements are
 * {@link Row}s) and {@code ofKVs} (keyed, output elements are {@code KV<KeyT, Row>}), and then
 * refined with the {@code with*} methods.
 *
 * @param <OutputT> element type of the output collection
 */
public class RunInference<OutputT> extends PTransform<PCollection<?>, PCollection<OutputT>> {
    private static final Logger LOG = LoggerFactory.getLogger(RunInference.class);

    /** Python callable source handed to the Python side as {@code model_handler_provider}. */
    private final String modelLoader;

    /** Schema of the rows produced by the transform; used to build the output coder. */
    private final Schema schema;

    /** Additional keyword arguments forwarded to the Python transform. */
    private final Map<String, Object> kwargs;

    /** Expansion service address; the empty string means none was provided. */
    private final String expansionService;

    /** Coder for keys of keyed output, or {@code null} when the output is unkeyed. */
    private final Coder<?> keyCoder;

    /** Extra Python packages explicitly requested through {@link #withExtraPackages(List)}. */
    private final List<String> extraPackages = new ArrayList<>();

    /**
     * Creates an unkeyed transform whose output schema has the two fields {@code "example"} and
     * {@code "inference"}.
     *
     * @param str source of the Python model handler provider callable
     * @param fieldType type of the {@code "example"} field
     * @param fieldType2 type of the {@code "inference"} field
     * @return a transform producing {@link Row}s
     */
    public static RunInference<Row> of(String str, Schema.FieldType fieldType, Schema.FieldType fieldType2) {
        return new RunInference<>(str, exampleInferenceSchema(fieldType, fieldType2), ImmutableMap.of(), null, "");
    }

    /**
     * Creates a keyed transform whose output row schema has the two fields {@code "example"} and
     * {@code "inference"}.
     *
     * @param str source of the Python model handler provider callable
     * @param fieldType type of the {@code "example"} field
     * @param fieldType2 type of the {@code "inference"} field
     * @param coder coder for the keys
     * @param <KeyT> key type
     * @return a transform producing {@code KV<KeyT, Row>} elements
     */
    public static <KeyT> RunInference<KV<KeyT, Row>> ofKVs(String str, Schema.FieldType fieldType, Schema.FieldType fieldType2, Coder<KeyT> coder) {
        return new RunInference<>(str, exampleInferenceSchema(fieldType, fieldType2), ImmutableMap.of(), coder, "");
    }

    /**
     * Creates an unkeyed transform with an explicit output schema.
     *
     * @param str source of the Python model handler provider callable
     * @param schema schema of the output rows
     * @return a transform producing {@link Row}s
     */
    public static RunInference<Row> of(String str, Schema schema) {
        return new RunInference<>(str, schema, ImmutableMap.of(), null, "");
    }

    /**
     * Creates a keyed transform with an explicit output row schema.
     *
     * @param str source of the Python model handler provider callable
     * @param schema schema of the output rows
     * @param coder coder for the keys
     * @param <KeyT> key type
     * @return a transform producing {@code KV<KeyT, Row>} elements
     */
    public static <KeyT> RunInference<KV<KeyT, Row>> ofKVs(String str, Schema schema, Coder<KeyT> coder) {
        return new RunInference<>(str, schema, ImmutableMap.of(), coder, "");
    }

    /**
     * Returns a copy of this transform with one more keyword argument for the Python transform.
     * Previously configured extra packages are carried over to the copy.
     *
     * <p>NOTE(review): the map is built with Guava's {@code ImmutableMap.Builder}, which is
     * documented to reject duplicate keys at {@code build()}; setting the same key twice is
     * therefore expected to fail.
     *
     * @param str keyword argument name
     * @param obj keyword argument value
     * @return a new transform including the additional keyword argument
     */
    public RunInference<OutputT> withKwarg(String str, Object obj) {
        // Explicit type witnesses are required: without them the builder is inferred as
        // <Object, Object> and its result is not assignable to Map<String, Object>.
        Map<String, Object> merged = ImmutableMap.<String, Object>builder().putAll(this.kwargs).put(str, obj).build();
        return copyWith(merged, this.expansionService);
    }

    /**
     * Sets the extra Python packages passed to the external transform. Unlike the other
     * {@code with*} methods this one modifies and returns the current instance.
     *
     * @param list package names
     * @return this transform
     * @throws IllegalArgumentException if extra packages were already specified
     */
    public RunInference<OutputT> withExtraPackages(List<String> list) {
        if (!this.extraPackages.isEmpty()) {
            throw new IllegalArgumentException("Extra packages were already specified");
        }
        this.extraPackages.addAll(list);
        return this;
    }

    /**
     * Returns a copy of this transform that uses the given expansion service. Previously
     * configured extra packages are carried over to the copy.
     *
     * @param str expansion service address
     * @return a new transform bound to the expansion service
     */
    public RunInference<OutputT> withExpansionService(String str) {
        return copyWith(this.kwargs, str);
    }

    private RunInference(String str, Schema schema, Map<String, Object> map, Coder<?> coder, String str2) {
        this.modelLoader = str;
        this.schema = schema;
        this.kwargs = map;
        this.keyCoder = coder;
        this.expansionService = str2;
    }

    /** Builds the two-field ("example", "inference") schema shared by the field-type factories. */
    private static Schema exampleInferenceSchema(Schema.FieldType exampleType, Schema.FieldType inferenceType) {
        return Schema.of(Schema.Field.of("example", exampleType), Schema.Field.of("inference", inferenceType));
    }

    /**
     * Creates a copy with the given kwargs and expansion service. The constructor always starts
     * with an empty package list, so the packages are copied explicitly; otherwise they would be
     * silently lost whenever a {@code with*} call follows {@link #withExtraPackages(List)}.
     */
    private RunInference<OutputT> copyWith(Map<String, Object> newKwargs, String newExpansionService) {
        RunInference<OutputT> copy = new RunInference<>(this.modelLoader, this.schema, newKwargs, this.keyCoder, newExpansionService);
        copy.extraPackages.addAll(this.extraPackages);
        return copy;
    }

    /**
     * Guesses the Python packages needed by the model handler from the text of
     * {@link #modelLoader}: "sklearn" maps to scikit-learn and pandas, "pytorch" maps to torch.
     * The match is case-insensitive and the "sklearn" check takes precedence.
     *
     * @return the inferred package names, possibly empty
     */
    private List<String> inferExtraPackagesFromModelHandler() {
        List<String> inferred = new ArrayList<>();
        String loader = this.modelLoader.toLowerCase();
        if (loader.contains("sklearn")) {
            inferred.add("scikit-learn");
            inferred.add("pandas");
        } else if (loader.contains("pytorch")) {
            inferred.add("torch");
        }
        if (!inferred.isEmpty()) {
            LOG.info("Automatically inferred dependencies {} from the provided model handler.", inferred);
        }
        return inferred;
    }

    /**
     * Applies the external Python RunInference transform to the input.
     *
     * <p>The output coder is a {@link RowCoder} for {@link #schema}, wrapped in a {@link KvCoder}
     * when a key coder was supplied. When neither an expansion service nor extra packages were
     * configured, packages are inferred from the model handler. The inferred packages are used for
     * this expansion only and are not stored on the transform.
     *
     * @param pCollection input collection
     * @return the collection produced by the Python transform
     */
    @Override
    public PCollection<OutputT> expand(PCollection<?> pCollection) {
        Coder<?> outputCoder = this.keyCoder == null ? RowCoder.of(this.schema) : KvCoder.of(this.keyCoder, RowCoder.of(this.schema));

        // Work on a snapshot so that expansion never changes this transform's configuration.
        List<String> packages = new ArrayList<>(this.extraPackages);
        if (this.expansionService.isEmpty() && packages.isEmpty()) {
            packages.addAll(inferExtraPackagesFromModelHandler());
        }

        // The public factories only bind OutputT to Row (unkeyed) or KV<KeyT, Row> (keyed), and
        // outputCoder above is built to match; the element type is erased at runtime.
        @SuppressWarnings("unchecked")
        PCollection<OutputT> output = (PCollection<OutputT>) pCollection.apply(
                PythonExternalTransform.<PCollection<?>, PCollection<Row>>from("apache_beam.ml.inference.base.RunInference.from_callable", this.expansionService)
                        .withKwarg("model_handler_provider", PythonCallableSource.of(this.modelLoader))
                        .withOutputCoder(outputCoder)
                        .withExtraPackages(packages)
                        .withKwargs(this.kwargs));
        return output;
    }
}
