/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.extensions.python.transforms;

import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import org.apache.beam.runners.core.construction.BaseExternalTest;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.extensions.python.transforms.RunInference;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.UsesPythonExpansionService;
import org.apache.beam.sdk.testing.ValidatesRunner;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/**
 * Validates-runner tests for the cross-language {@link RunInference} transform.
 *
 * <p>Each test expands {@link RunInference} through the expansion service at {@code
 * expansionAddr} (inherited from {@link BaseExternalTest}). It runs two small {@code long}
 * examples through an sklearn model handler and asserts the resulting {@link Row}s.
 */
@RunWith(value = JUnit4.class)
public class RunInferenceTransformTest extends BaseExternalTest {

    /** Fully-qualified Python model handler passed to {@link RunInference#of}. */
    private static final String SKLEARN_NUMPY_HANDLER =
            "apache_beam.ml.inference.sklearn_inference.SklearnModelHandlerNumpy";

    @Test
    @Category(value = {ValidatesRunner.class, UsesPythonExpansionService.class})
    public void testRunInference() {
        Schema schema = inferenceSchema();
        PCollection<Row> col =
                testPipeline
                        .apply(Create.of(exampleInputs()))
                        .setCoder(IterableCoder.of(VarLongCoder.of()))
                        .apply(
                                RunInference.of(SKLEARN_NUMPY_HANDLER, schema)
                                        .withKwarg("model_uri", modelUri())
                                        .withExpansionService(expansionAddr));
        PAssert.that(col).containsInAnyOrder(expectedRows(schema));
    }

    /**
     * Returns a Python script defining {@code get_model_handler(model_uri)}. The function wraps
     * {@code SklearnModelHandlerNumpy} in a {@code KeyedModelHandler} so that keyed inputs can be
     * used.
     */
    private String getModelLoaderScriptWithKVs() {
        return "from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy\n"
                + "from apache_beam.ml.inference.base import KeyedModelHandler\n"
                + "def get_model_handler(model_uri):\n"
                + "  return KeyedModelHandler(SklearnModelHandlerNumpy(model_uri))\n";
    }

    @Test
    @Category(value = {ValidatesRunner.class, UsesPythonExpansionService.class})
    public void testRunInferenceWithKVs() {
        Schema schema = inferenceSchema();
        PCollection<Row> col =
                testPipeline
                        .apply(Create.of(exampleInputs()))
                        .apply(MapElements.via(new KVFn()))
                        .setCoder(
                                KvCoder.of(VarLongCoder.of(), IterableCoder.of(VarLongCoder.of())))
                        .apply(
                                RunInference.ofKVs(
                                                getModelLoaderScriptWithKVs(),
                                                schema,
                                                VarLongCoder.of())
                                        .withKwarg("model_uri", modelUri())
                                        .withExpansionService(expansionAddr))
                        // Drop the keys so the assertion compares bare result rows.
                        .apply(Values.<Row>create());
        PAssert.that(col).containsInAnyOrder(expectedRows(schema));
    }

    /**
     * Location of the sklearn model passed as the {@code model_uri} kwarg.
     *
     * <p>The base directory comes from the {@code semiPersistDir} system property and falls back
     * to {@code /tmp}. NOTE(review): the model is presumably staged under {@code <dir>/staged} by
     * the test expansion service. Confirm against the Python side.
     */
    private static String modelUri() {
        String stagingLocation =
                Optional.ofNullable(System.getProperty("semiPersistDir")).orElse("/tmp");
        return String.format("%s/staged/sklearn_model", stagingLocation);
    }

    /** Output schema: the INT64 array that was fed in, plus the INT32 inference result. */
    private static Schema inferenceSchema() {
        return Schema.of(
                Schema.Field.of("example", Schema.FieldType.array(Schema.FieldType.INT64)),
                Schema.Field.of("inference", Schema.FieldType.INT32));
    }

    /** The two examples fed to the model: {@code [0, 0]} and {@code [1, 1]}. */
    private static List<Iterable<Long>> exampleInputs() {
        return Arrays.asList(Arrays.asList(0L, 0L), Arrays.asList(1L, 1L));
    }

    /** Rows expected for {@link #exampleInputs()}: each example paired with its inference. */
    private static Row[] expectedRows(Schema schema) {
        return new Row[] {
            Row.withSchema(schema).addArray(0L, 0L).addValue(0).build(),
            Row.withSchema(schema).addArray(1L, 1L).addValue(1).build()
        };
    }

    /** Keys each example by its first element, keeping the whole example as the value. */
    static class KVFn extends SimpleFunction<Iterable<Long>, KV<Long, Iterable<Long>>> {
        KVFn() {}

        @Override
        public KV<Long, Iterable<Long>> apply(Iterable<Long> input) {
            // The declared type is Iterable, so read through the iterator instead of casting to
            // List: any non-List implementation would otherwise fail with ClassCastException.
            Long key = input.iterator().next();
            return KV.of(key, input);
        }
    }
}

