package au.csiro.pathling.library;

import au.csiro.pathling.config.EncodingConfiguration;
import au.csiro.pathling.config.HttpClientCachingConfiguration;
import au.csiro.pathling.config.HttpClientCachingStorageType;
import au.csiro.pathling.config.HttpClientConfiguration;
import au.csiro.pathling.config.TerminologyAuthConfiguration;
import au.csiro.pathling.config.TerminologyConfiguration;
import au.csiro.pathling.encoders.FhirEncoders;
import au.csiro.pathling.fhirpath.encoding.CodingEncoding;
import au.csiro.pathling.terminology.DefaultTerminologyServiceFactory;
import au.csiro.pathling.terminology.TerminologyService;
import au.csiro.pathling.terminology.TerminologyServiceFactory;
import au.csiro.pathling.test.SchemaAsserts;
import au.csiro.pathling.test.helpers.TerminologyServiceHelpers;
import ca.uhn.fhir.context.FhirVersionEnum;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.attribute.FileAttribute;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.function.Predicate;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import javax.validation.ConstraintViolationException;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.streaming.OutputMode;
import org.apache.spark.sql.streaming.StreamingQuery;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.hl7.fhir.r4.model.Coding;
import org.hl7.fhir.r4.model.Condition;
import org.hl7.fhir.r4.model.Enumerations;
import org.hl7.fhir.r4.model.codesystems.ConceptMapEquivalence;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.collection.mutable.WrappedArray;

/* loaded from: input_file:au/csiro/pathling/library/PathlingContextTest.class */
/**
 * Tests for {@link PathlingContext}: encoding of FHIR bundles and resources into Spark datasets,
 * the terminology operations ({@code memberOf}, {@code translate}, {@code subsumes}), and the
 * validation of configuration objects passed to {@code PathlingContext.create}.
 */
public class PathlingContextTest {

    private static final Logger log = LoggerFactory.getLogger(PathlingContextTest.class);

    /** Root directory of the test data; each test appends its own sub-path. */
    private static final String TEST_DATA_URL = "target/encoders-tests/data";

    private static final String FHIR_JSON = "application/fhir+json";
    private static final String FHIR_XML = "application/fhir+xml";

    /** Upper bound on the number of offending values listed in an assertion failure message. */
    private static final long MAX_REPORTED_FAILURES = 7L;

    /** Un-anchored pattern for a GUID: 8-4-4-4-12 hexadecimal digits. */
    private static final String GUID_REG_SUBEXPRESSION = "[0-9a-fA-F]{8}-([0-9a-fA-F]{4}-){3}[0-9a-fA-F]{12}";

    /** Matches a string that consists of exactly one GUID. */
    public static final Pattern GUID_REGEX = Pattern.compile("^" + GUID_REG_SUBEXPRESSION + "$");

    /** Matches {@code <Name>/<GUID>} where the name starts with an upper-case letter. */
    private static final Pattern RELATIVE_REF_REGEX =
            Pattern.compile("^[A-Z][A-Za-z]+/" + GUID_REG_SUBEXPRESSION + "$");

    private static SparkSession spark;

    private TerminologyServiceFactory terminologyServiceFactory;
    private TerminologyService terminologyService;

    /** Obtains the Spark session shared by all tests in this class. */
    @BeforeAll
    public static void setUpAll() {
        spark = TestHelpers.spark();
    }

    /** Stops the shared Spark session, if one was obtained. */
    @AfterAll
    public static void tearDownAll() {
        // Guard against a failed setUpAll so that the original failure is not masked by an NPE.
        if (spark != null) {
            spark.stop();
        }
    }

    /**
     * @param str the value to check
     * @return {@code true} if the whole of {@code str} matches {@link #GUID_REGEX}
     */
    public static boolean isValidGUID(@Nonnull String str) {
        return GUID_REGEX.matcher(str).matches();
    }

    /**
     * @param str the value to check
     * @return {@code true} if the whole of {@code str} has the form {@code <Name>/<GUID>}
     */
    public static boolean isValidRelativeReference(@Nonnull String str) {
        return RELATIVE_REF_REGEX.matcher(str).matches();
    }

    /**
     * Asserts that every value is a GUID; the failure message lists up to
     * {@link #MAX_REPORTED_FAILURES} of the offending values.
     *
     * @param collection the values to check
     */
    public void assertAllGUIDs(@Nonnull Collection<String> collection) {
        List<String> invalid = firstInvalid(collection, PathlingContextTest::isValidGUID);
        Assertions.assertTrue(invalid.isEmpty(), "All values should be GUIDs, but some are not: " + invalid);
    }

    /**
     * Asserts that every value is a relative reference; the failure message lists up to
     * {@link #MAX_REPORTED_FAILURES} of the offending values.
     *
     * @param collection the values to check
     */
    public void assertAllRelativeReferences(@Nonnull Collection<String> collection) {
        List<String> invalid = firstInvalid(collection, PathlingContextTest::isValidRelativeReference);
        Assertions.assertTrue(invalid.isEmpty(),
                "All values should be relative references, but some are not: " + invalid);
    }

    /**
     * Asserts that the {@code id} column holds GUIDs and the {@code id_versioned} column holds
     * relative references.
     *
     * @param dataset the dataset whose id columns are checked
     * @param <T> the element type of the dataset
     */
    public <T> void assertValidIdColumns(@Nonnull Dataset<T> dataset) {
        assertAllGUIDs(dataset.select("id").as(Encoders.STRING()).collectAsList());
        assertAllRelativeReferences(dataset.select("id_versioned").as(Encoders.STRING()).collectAsList());
    }

    /**
     * Asserts that the {@code reference} field of each given column holds relative references.
     *
     * @param dataset the dataset that contains the columns
     * @param columnArr the columns whose {@code reference} field is checked
     * @param <T> the element type of the dataset
     */
    public <T> void assertValidRelativeRefColumns(@Nonnull Dataset<T> dataset, Column... columnArr) {
        // Iterate with the static type Column; streaming the array as Object[] would lose access
        // to getField().
        for (Column column : columnArr) {
            assertAllRelativeReferences(
                    dataset.select(column.getField("reference")).as(Encoders.STRING()).collectAsList());
        }
    }

    /** Creates serializable terminology mocks and resets the default terminology service factory. */
    @BeforeEach
    public void setUp() {
        terminologyServiceFactory =
                Mockito.mock(TerminologyServiceFactory.class, Mockito.withSettings().serializable());
        terminologyService = Mockito.mock(TerminologyService.class, Mockito.withSettings().serializable());
        Mockito.when(terminologyServiceFactory.build()).thenReturn(terminologyService);
        DefaultTerminologyServiceFactory.reset();
    }

    /** Encodes Patient and Condition resources from the JSON bundles and checks counts and ids. */
    @Test
    public void testEncodeResourcesFromJsonBundle() {
        Dataset<String> bundles =
                spark.read().option("wholetext", true).textFile(TEST_DATA_URL + "/bundles/R4/json");
        PathlingContext pathling = PathlingContext.create(spark);

        Dataset<Row> patients = pathling.encodeBundle(bundles.toDF(), "Patient", FHIR_JSON);
        Assertions.assertEquals(5L, patients.count());
        assertValidIdColumns(patients);

        // Same encoding, using the overload that takes no MIME type.
        Dataset<Row> patientsNoMimeType = pathling.encodeBundle(bundles.toDF(), "Patient");
        Assertions.assertEquals(5L, patientsNoMimeType.count());
        assertValidIdColumns(patientsNoMimeType);

        Dataset<Condition> conditions = pathling.encodeBundle(bundles, Condition.class, FHIR_JSON);
        Assertions.assertEquals(107L, conditions.count());
        assertValidIdColumns(conditions);
        assertValidRelativeRefColumns(conditions, functions.col("subject"), functions.col("encounter"));
    }

    /** Encodes Condition resources from the XML bundles and checks count, ids and references. */
    @Test
    public void testEncodeResourcesFromXmlBundle() {
        Dataset<String> bundles =
                spark.read().option("wholetext", true).textFile(TEST_DATA_URL + "/bundles/R4/xml");

        Dataset<Condition> conditions =
                PathlingContext.create(spark).encodeBundle(bundles, Condition.class, FHIR_XML);
        Assertions.assertEquals(107L, conditions.count());
        assertValidIdColumns(conditions);
        assertValidRelativeRefColumns(conditions, functions.col("subject"), functions.col("encounter"));
    }

    /** Encodes Patient and Condition resources from JSON resource files and checks counts and ids. */
    @Test
    public void testEncodeResourcesFromJson() {
        Dataset<String> jsonResources = spark.read().textFile(TEST_DATA_URL + "/resources/R4/json");
        PathlingContext pathling = PathlingContext.create(spark);

        Dataset<Row> patients = pathling.encode(jsonResources.toDF(), "Patient", FHIR_JSON);
        Assertions.assertEquals(9L, patients.count());
        assertValidIdColumns(patients);

        Dataset<Condition> conditions = pathling.encode(jsonResources, Condition.class, FHIR_JSON);
        Assertions.assertEquals(71L, conditions.count());
        assertValidIdColumns(conditions);
        assertValidRelativeRefColumns(conditions, functions.col("subject"), functions.col("encounter"));
    }

    /**
     * Checks the encoded schema under three configurations: the context created without an
     * encoding configuration, one with extensions disabled and a maximum nesting level of 1, and
     * one with extensions enabled and a restricted set of open types.
     */
    @Test
    public void testEncoderOptions() {
        Dataset<Row> jsonResources = spark.read().text(TEST_DATA_URL + "/resources/R4/json");

        // No explicit encoding configuration: "_extension" and a nested "item" are expected.
        Row defaultQuestionnaire = PathlingContext.create(spark).encode(jsonResources, "Questionnaire").head();
        SchemaAsserts.assertFieldPresent("_extension", defaultQuestionnaire.schema());
        SchemaAsserts.assertFieldPresent("item", firstItem(defaultQuestionnaire).schema());

        // Extensions disabled, nesting level 1: one level of nested "item" but not a second.
        EncodingConfiguration shallowConfig =
                EncodingConfiguration.builder().enableExtensions(false).maxNestingLevel(1).build();
        Row shallowQuestionnaire =
                PathlingContext.create(spark, shallowConfig).encode(jsonResources, "Questionnaire").head();
        SchemaAsserts.assertFieldNotPresent("_extension", shallowQuestionnaire.schema());
        Row levelOneItem = firstItem(shallowQuestionnaire);
        SchemaAsserts.assertFieldPresent("item", levelOneItem.schema());
        SchemaAsserts.assertFieldNotPresent("item", firstItem(levelOneItem).schema());

        // Open types restricted to boolean, string and Address: only those value fields appear.
        EncodingConfiguration openTypesConfig = EncodingConfiguration.builder()
                .enableExtensions(true)
                .openTypes(Set.of("boolean", "string", "Address"))
                .build();
        Row patient = PathlingContext.create(spark, openTypesConfig).encode(jsonResources, "Patient").head();
        SchemaAsserts.assertFieldPresent("_extension", patient.schema());
        WrappedArray<?>[] extensionArrays =
                patient.getJavaMap(patient.fieldIndex("_extension")).values().toArray(WrappedArray[]::new);
        Row extension = (Row) extensionArrays[0].apply(0);
        SchemaAsserts.assertFieldPresent("valueString", extension.schema());
        SchemaAsserts.assertFieldPresent("valueAddress", extension.schema());
        SchemaAsserts.assertFieldPresent("valueBoolean", extension.schema());
        SchemaAsserts.assertFieldNotPresent("valueInteger", extension.schema());
    }

    /**
     * Encodes a streaming source of JSON resources and checks the row counts written to the
     * in-memory sinks. Each streaming query is stopped once its assertion has run.
     *
     * @throws Exception if a streaming query fails to start, process its input or stop
     */
    @Test
    public void testEncodeResourceStream() throws Exception {
        PathlingContext pathling =
                PathlingContext.create(spark, EncodingConfiguration.builder().enableExtensions(true).build());
        Dataset<Row> jsonResources = spark.readStream().text(TEST_DATA_URL + "/resources/R4/json");
        Assertions.assertTrue(jsonResources.isStreaming());

        Dataset<Row> patients = pathling.encode(jsonResources, "Patient", FHIR_JSON);
        Assertions.assertTrue(patients.isStreaming());

        StreamingQuery patientsQuery = patients.writeStream().queryName("patients").format("memory").start();
        try {
            patientsQuery.processAllAvailable();
            Assertions.assertEquals(9L, spark.sql("select count(*) from patients").head().getLong(0));
        } finally {
            // Do not leave the query running against the shared session.
            patientsQuery.stop();
        }

        StreamingQuery conditionCountQuery = pathling.encode(jsonResources, "Condition", FHIR_JSON)
                .groupBy()
                .count()
                .writeStream()
                .outputMode(OutputMode.Complete())
                .queryName("countCondition")
                .format("memory")
                .start();
        try {
            conditionCountQuery.processAllAvailable();
            Assertions.assertEquals(71L, spark.sql("select * from countCondition").head().getLong(0));
        } finally {
            conditionCountQuery.stop();
        }
    }

    /** Checks that {@code memberOf} flags only the coding that the mocked value set contains. */
    @Test
    void testMemberOf() {
        Coding codingAbc = new Coding("urn:test:123", "ABC", "abc");
        Coding codingDef = new Coding("urn:test:123", "DEF", "def");
        TerminologyServiceHelpers.setupValidate(terminologyService)
                .withValueSet("urn:test:456", new Coding[]{codingAbc});

        Dataset<Row> input = spark.createDataFrame(
                List.of(
                        RowFactory.create("foo", CodingEncoding.encode(codingAbc)),
                        RowFactory.create("bar", CodingEncoding.encode(codingDef))),
                DataTypes.createStructType(new StructField[]{
                        DataTypes.createStructField("id", DataTypes.StringType, true),
                        DataTypes.createStructField("coding", CodingEncoding.codingStructType(), true)}));

        List<Row> rows = PathlingContext.create(spark, FhirEncoders.forR4().getOrCreate(), terminologyServiceFactory)
                .memberOf(input, functions.col("coding"), "urn:test:456", "result")
                .select("id", "result")
                .collectAsList();

        Assertions.assertEquals(RowFactory.create("foo", true), rows.get(0));
        Assertions.assertEquals(RowFactory.create("bar", false), rows.get(1));
    }

    /** Checks that {@code translate} returns the mocked translation for the first input row. */
    @Test
    void testTranslate() {
        Coding codingAbc = new Coding("urn:test:123", "ABC", "abc");
        Coding codingDef = new Coding("urn:test:123", "DEF", "def");
        TerminologyServiceHelpers.setupTranslate(terminologyService).withTranslations(codingAbc, "urn:test:456",
                new TerminologyService.Translation[]{
                        TerminologyService.Translation.of(ConceptMapEquivalence.EQUIVALENT, codingDef)});

        Dataset<Row> input = spark.createDataFrame(
                List.of(
                        RowFactory.create("foo", CodingEncoding.encode(codingAbc)),
                        RowFactory.create("bar", CodingEncoding.encode(codingDef))),
                DataTypes.createStructType(new StructField[]{
                        DataTypes.createStructField("id", DataTypes.StringType, true),
                        DataTypes.createStructField("coding", CodingEncoding.codingStructType(), true)}));

        Row actual = PathlingContext.create(spark, FhirEncoders.forR4().getOrCreate(), terminologyServiceFactory)
                .translate(input, functions.col("coding"), "urn:test:456", false,
                        Enumerations.ConceptMapEquivalence.EQUIVALENT.toCode(), (String) null, "result")
                .select("id", "result")
                .collectAsList()
                .get(0);

        Row expected = RowFactory.create("foo", WrappedArray.make(new Row[]{CodingEncoding.encode(codingDef)}));
        Assertions.assertEquals(expected, actual);
    }

    /** Checks that {@code subsumes} is true only for the coding pair set up on the mock. */
    @Test
    void testSubsumes() {
        Coding codingAbc = new Coding("urn:test:123", "ABC", "abc");
        Coding codingDef = new Coding("urn:test:123", "DEF", "def");
        Coding codingGhi = new Coding("urn:test:123", "GHI", "ghi");
        TerminologyServiceHelpers.setupSubsumes(terminologyService).withSubsumes(codingAbc, codingDef);

        Dataset<Row> input = spark.createDataFrame(
                List.of(
                        RowFactory.create("foo", CodingEncoding.encode(codingAbc), CodingEncoding.encode(codingDef)),
                        RowFactory.create("bar", CodingEncoding.encode(codingAbc), CodingEncoding.encode(codingGhi))),
                DataTypes.createStructType(new StructField[]{
                        DataTypes.createStructField("id", DataTypes.StringType, true),
                        DataTypes.createStructField("leftCoding", CodingEncoding.codingStructType(), true),
                        DataTypes.createStructField("rightCoding", CodingEncoding.codingStructType(), true)}));

        List<Row> rows = PathlingContext.create(spark, FhirEncoders.forR4().getOrCreate(), terminologyServiceFactory)
                .subsumes(input, functions.col("leftCoding"), functions.col("rightCoding"), "result")
                .select("id", "result")
                .collectAsList();

        Assertions.assertEquals(RowFactory.create("foo", true), rows.get(0));
        Assertions.assertEquals(RowFactory.create("bar", false), rows.get(1));
    }

    /** Builds a context from a terminology configuration that only sets the server URL. */
    @Test
    void testBuildContextWithTerminologyDefaults() {
        TerminologyConfiguration configuration = TerminologyConfiguration.builder()
                .serverUrl("https://tx.ontoserver.csiro.au/fhir")
                .build();
        assertContextUsesEquivalentFactory(configuration);
    }

    /** Builds a context from a terminology configuration with HTTP caching disabled. */
    @Test
    void testBuildContextWithTerminologyNoCache() {
        TerminologyConfiguration configuration = TerminologyConfiguration.builder()
                .serverUrl("https://tx.ontoserver.csiro.au/fhir")
                .cache(HttpClientCachingConfiguration.builder().enabled(false).build())
                .build();
        assertContextUsesEquivalentFactory(configuration);
    }

    /**
     * Builds a context from a terminology configuration with customised client, disk cache and
     * authentication settings.
     *
     * @throws IOException if the temporary cache directory cannot be created
     */
    @Test
    void testBuildContextWithCustomizedTerminology() throws IOException {
        File cacheDirectory = Files.createTempDirectory("pathling-cache").toFile();
        cacheDirectory.deleteOnExit();

        HttpClientConfiguration clientConfiguration = HttpClientConfiguration.builder()
                .maxConnectionsTotal(66)
                .maxConnectionsPerRoute(33)
                .socketTimeout(123)
                .build();
        HttpClientCachingConfiguration cacheConfiguration = HttpClientCachingConfiguration.builder()
                .maxEntries(1233)
                .storageType(HttpClientCachingStorageType.DISK)
                .storagePath(cacheDirectory.getAbsolutePath())
                .build();
        TerminologyAuthConfiguration authConfiguration = TerminologyAuthConfiguration.builder()
                .tokenEndpoint("https://auth.ontoserver.csiro.au/auth/realms/aehrc/protocol/openid-connect/token")
                .clientId("some-client")
                .clientSecret("some-secret")
                .scope("openid")
                .tokenExpiryTolerance(300L)
                .build();
        TerminologyConfiguration configuration = TerminologyConfiguration.builder()
                .serverUrl("https://r4.ontoserver.csiro.au/fhir")
                .verboseLogging(true)
                .client(clientConfiguration)
                .cache(cacheConfiguration)
                .authentication(authConfiguration)
                .build();

        assertContextUsesEquivalentFactory(configuration);
    }

    /** Expects context creation to fail and report every violated terminology constraint. */
    @Test
    public void failsOnInvalidTerminologyConfiguration() {
        TerminologyConfiguration invalidConfiguration = TerminologyConfiguration.builder()
                .serverUrl("not-a-URL")
                .client((HttpClientConfiguration) null)
                .cache(HttpClientCachingConfiguration.builder().storageType(HttpClientCachingStorageType.DISK).build())
                .build();

        ConstraintViolationException thrown = Assertions.assertThrows(ConstraintViolationException.class,
                () -> PathlingContext.create(spark, invalidConfiguration));
        Assertions.assertEquals("Invalid terminology configuration: cache: If the storage type is disk, then a storage path must be supplied., client: must not be null, serverUrl: must be a valid URL", thrown.getMessage());
    }

    /** Expects context creation to fail and report every violated encoding constraint. */
    @Test
    public void failsOnInvalidEncodingConfiguration() {
        TerminologyConfiguration terminologyConfiguration = TerminologyConfiguration.builder().build();
        EncodingConfiguration invalidConfiguration = EncodingConfiguration.builder()
                .maxNestingLevel(-10)
                .openTypes((Set<String>) null)
                .build();

        ConstraintViolationException thrown = Assertions.assertThrows(ConstraintViolationException.class,
                () -> PathlingContext.create(spark, invalidConfiguration, terminologyConfiguration));
        Assertions.assertEquals("Invalid encoding configuration: maxNestingLevel: must be greater than or equal to 0, openTypes: must not be null", thrown.getMessage());
    }

    /** Returns up to {@link #MAX_REPORTED_FAILURES} of the values rejected by {@code isValid}. */
    private static List<String> firstInvalid(Collection<String> values, Predicate<String> isValid) {
        return values.stream()
                .filter(Predicate.not(isValid))
                .limit(MAX_REPORTED_FAILURES)
                .collect(Collectors.toUnmodifiableList());
    }

    /** Returns the first element of the {@code item} list field of the given row. */
    private static Row firstItem(Row row) {
        return (Row) row.getList(row.fieldIndex("item")).get(0);
    }

    /**
     * Creates a context from the configuration and asserts that its terminology service factory
     * equals a {@link DefaultTerminologyServiceFactory} built from the same configuration for R4,
     * and that the factory can build a service.
     */
    private static void assertContextUsesEquivalentFactory(TerminologyConfiguration configuration) {
        PathlingContext context = PathlingContext.create(spark, configuration);
        Assertions.assertNotNull(context);

        DefaultTerminologyServiceFactory expectedFactory =
                new DefaultTerminologyServiceFactory(FhirVersionEnum.R4, configuration);
        TerminologyServiceFactory actualFactory = context.getTerminologyServiceFactory();
        Assertions.assertEquals(expectedFactory, actualFactory);
        Assertions.assertNotNull(actualFactory.build());
    }
}
