package au.csiro.pathling.encoders;

import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.math.BigDecimal;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.Comparator;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
import org.apache.spark.sql.functions;
import org.hl7.fhir.exceptions.FHIRException;
import org.hl7.fhir.r4.model.Annotation;
import org.hl7.fhir.r4.model.Base;
import org.hl7.fhir.r4.model.Coding;
import org.hl7.fhir.r4.model.Condition;
import org.hl7.fhir.r4.model.Encounter;
import org.hl7.fhir.r4.model.Identifier;
import org.hl7.fhir.r4.model.MedicationRequest;
import org.hl7.fhir.r4.model.Observation;
import org.hl7.fhir.r4.model.Patient;
import org.hl7.fhir.r4.model.Quantity;
import org.hl7.fhir.r4.model.Questionnaire;
import org.hl7.fhir.r4.model.QuestionnaireResponse;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:au/csiro/pathling/encoders/FhirEncodersTest.class */
public class FhirEncodersTest {
    private static SparkSession spark;
    private static Dataset<Patient> patientDataset;
    private static Patient decodedPatient;
    private static Dataset<Condition> conditionsDataset;
    private static Condition decodedCondition;
    private static Dataset<Observation> observationsDataset;
    private static Observation decodedObservation;
    private static Dataset<Condition> conditionsWithVersionDataset;
    private static Dataset<MedicationRequest> medDataset;
    private static MedicationRequest decodedMedRequest;
    private static Condition decodedConditionWithVersion;
    private static Dataset<Encounter> encounterDataset;
    private static Encounter decodedEncounter;
    private static Dataset<Questionnaire> questionnaireDataset;
    private static Questionnaire decodedQuestionnaire;
    private static Dataset<QuestionnaireResponse> questionnaireResponseDataset;
    private static QuestionnaireResponse decodedQuestionnaireResponse;
    private static final FhirEncoders ENCODERS_L0 = FhirEncoders.forR4().getOrCreate();
    private static final int NESTING_LEVEL_3 = 3;
    private static final FhirEncoders ENCODERS_L3 = FhirEncoders.forR4().withMaxNestingLevel(NESTING_LEVEL_3).getOrCreate();
    private static final Patient patient = TestData.newPatient();
    private static final Condition condition = TestData.newCondition();
    private static final Condition conditionWithVersion = TestData.conditionWithVersion();
    private static final Observation observation = TestData.newObservation();
    private static final MedicationRequest medRequest = TestData.newMedRequest();
    private static final Encounter encounter = TestData.newEncounter();
    private static final Questionnaire questionnaire = TestData.newQuestionnaire();
    private static final QuestionnaireResponse questionnaireResponse = TestData.newQuestionnaireResponse();

    @BeforeAll
    public static void setUp() {
        spark = SparkSession.builder().master("local[*]").appName("testing").config("spark.driver.bindAddress", "localhost").config("spark.driver.host", "localhost").getOrCreate();
        patientDataset = spark.createDataset(ImmutableList.of(patient), ENCODERS_L0.of(Patient.class));
        decodedPatient = (Patient) patientDataset.head();
        conditionsDataset = spark.createDataset(ImmutableList.of(condition), ENCODERS_L0.of(Condition.class));
        decodedCondition = (Condition) conditionsDataset.head();
        conditionsWithVersionDataset = spark.createDataset(ImmutableList.of(conditionWithVersion), ENCODERS_L0.of(Condition.class));
        decodedConditionWithVersion = (Condition) conditionsWithVersionDataset.head();
        observationsDataset = spark.createDataset(ImmutableList.of(observation), ENCODERS_L0.of(Observation.class));
        decodedObservation = (Observation) observationsDataset.head();
        medDataset = spark.createDataset(ImmutableList.of(medRequest), ENCODERS_L0.of(MedicationRequest.class));
        decodedMedRequest = (MedicationRequest) medDataset.head();
        encounterDataset = spark.createDataset(ImmutableList.of(encounter), ENCODERS_L0.of(Encounter.class));
        decodedEncounter = (Encounter) encounterDataset.head();
        questionnaireDataset = spark.createDataset(ImmutableList.of(questionnaire), ENCODERS_L0.of(Questionnaire.class));
        decodedQuestionnaire = (Questionnaire) questionnaireDataset.head();
        questionnaireResponseDataset = spark.createDataset(ImmutableList.of(questionnaireResponse), ENCODERS_L0.of(QuestionnaireResponse.class));
        decodedQuestionnaireResponse = (QuestionnaireResponse) questionnaireResponseDataset.head();
    }

    @AfterAll
    public static void tearDown() {
        spark.stop();
    }

    @Test
    public void testResourceId() {
        Assertions.assertEquals(condition.getId(), ((Row) conditionsDataset.select("id", new String[0]).head()).get(0));
        Assertions.assertEquals(condition.getId(), decodedCondition.getId());
    }

    @Test
    public void testResourceWithVersionId() {
        Assertions.assertEquals("with-version", ((Row) conditionsWithVersionDataset.select("id", new String[0]).head()).get(0));
        Assertions.assertEquals(conditionWithVersion.getId(), ((Row) conditionsWithVersionDataset.select("id_versioned", new String[0]).head()).get(0));
        Assertions.assertEquals(conditionWithVersion.getId(), decodedConditionWithVersion.getId());
    }

    @Test
    public void testResourceLanguage() {
        Assertions.assertEquals(condition.getLanguage(), ((Row) conditionsDataset.select("language", new String[0]).head()).get(0));
        Assertions.assertEquals(condition.getLanguage(), decodedCondition.getLanguage());
    }

    @Test
    public void boundCode() {
        GenericRowWithSchema struct = ((Row) conditionsDataset.select("verificationStatus", new String[0]).head()).getStruct(0);
        GenericRowWithSchema genericRowWithSchema = (GenericRowWithSchema) struct.getList(struct.fieldIndex("coding")).get(0);
        Assertions.assertEquals(condition.getVerificationStatus().getCoding().size(), 1);
        Assertions.assertEquals(condition.getVerificationStatus().getCodingFirstRep().getSystem(), genericRowWithSchema.getString(genericRowWithSchema.fieldIndex("system")));
        Assertions.assertEquals(condition.getVerificationStatus().getCodingFirstRep().getCode(), genericRowWithSchema.getString(genericRowWithSchema.fieldIndex("code")));
    }

    @Test
    public void choiceValue() {
        Assertions.assertEquals(condition.getOnset().getValueAsString(), ((Row) conditionsDataset.select("onsetDateTime", new String[0]).head()).get(0));
        Assertions.assertEquals(condition.getOnset().toString(), decodedCondition.getOnset().toString());
    }

    @Test
    public void narrative() {
        Assertions.assertEquals(condition.getText().getStatus().toCode(), ((Row) conditionsDataset.select("text.status", new String[0]).head()).get(0));
        Assertions.assertEquals(condition.getText().getStatus(), decodedCondition.getText().getStatus());
        Assertions.assertEquals(condition.getText().getDivAsString(), ((Row) conditionsDataset.select("text.div", new String[0]).head()).get(0));
        Assertions.assertEquals(condition.getText().getDivAsString(), decodedCondition.getText().getDivAsString());
    }

    @Test
    public void coding() {
        Coding codingFirstRep = condition.getSeverity().getCodingFirstRep();
        Coding codingFirstRep2 = decodedCondition.getSeverity().getCodingFirstRep();
        Dataset cache = conditionsDataset.select(new Column[]{functions.explode(conditionsDataset.col("severity.coding")).alias("coding")}).select("coding.*", new String[0]).cache();
        Assertions.assertEquals(codingFirstRep.getCode(), ((Row) cache.select("code", new String[0]).head()).get(0));
        Assertions.assertEquals(codingFirstRep.getCode(), codingFirstRep2.getCode());
        Assertions.assertEquals(codingFirstRep.getSystem(), ((Row) cache.select("system", new String[0]).head()).get(0));
        Assertions.assertEquals(codingFirstRep.getSystem(), codingFirstRep2.getSystem());
        Assertions.assertEquals(Boolean.valueOf(codingFirstRep.getUserSelected()), ((Row) cache.select("userSelected", new String[0]).head()).get(0));
        Assertions.assertEquals(Boolean.valueOf(codingFirstRep.getUserSelected()), Boolean.valueOf(codingFirstRep2.getUserSelected()));
        Assertions.assertEquals(codingFirstRep.getDisplay(), ((Row) cache.select("display", new String[0]).head()).get(0));
        Assertions.assertEquals(codingFirstRep.getDisplay(), codingFirstRep2.getDisplay());
    }

    @Test
    public void reference() {
        Condition conditionWithReferencesWithIdentifiers = TestData.conditionWithReferencesWithIdentifiers();
        Dataset createDataset = spark.createDataset(ImmutableList.of(conditionWithReferencesWithIdentifiers), ENCODERS_L3.of(Condition.class));
        Condition condition2 = (Condition) createDataset.head();
        Assertions.assertEquals(RowFactory.create(new Object[]{"withReferencesWithIdentifiers", "Patient/example", "http://terminology.hl7.org/CodeSystem/v2-0203", "MR", "https://fhir.example.com/identifiers/mrn", "urn:id"}), createDataset.select(new Column[]{functions.col("id"), functions.col("subject.reference"), functions.col("subject.identifier.type.coding.system").getItem(0), functions.col("subject.identifier.type.coding.code").getItem(0), functions.col("subject.identifier.system"), functions.col("subject.identifier.value")}).head());
        Assertions.assertEquals("Patient/example", condition2.getSubject().getReference());
        Assertions.assertEquals("urn:id", condition2.getSubject().getIdentifier().getValue());
        Assertions.assertTrue(conditionWithReferencesWithIdentifiers.getSubject().getIdentifier().hasAssigner());
        Assertions.assertFalse(condition2.getSubject().getIdentifier().hasAssigner());
    }

    @Test
    public void identifier() {
        Condition conditionWithIdentifiersWithReferences = TestData.conditionWithIdentifiersWithReferences();
        Dataset createDataset = spark.createDataset(ImmutableList.of(conditionWithIdentifiersWithReferences), ENCODERS_L3.of(Condition.class));
        Condition condition2 = (Condition) createDataset.head();
        Assertions.assertEquals(RowFactory.create(new Object[]{"withIdentifiersWithReferences", "http://terminology.hl7.org/CodeSystem/v2-0203", "MR", "https://fhir.example.com/identifiers/mrn", "urn:id01", "Organization/001", "urn:id02"}), createDataset.select(new Column[]{functions.col("id"), functions.col("identifier.type.coding").getItem(0).getField("system").getItem(0), functions.col("identifier.type.coding").getItem(0).getField("code").getItem(0), functions.col("identifier.system").getItem(0), functions.col("identifier.value").getItem(0), functions.col("identifier.assigner.reference").getItem(0), functions.col("identifier.assigner.identifier.value").getItem(0)}).head());
        Assertions.assertTrue(((Identifier) conditionWithIdentifiersWithReferences.getIdentifier().get(0)).getAssigner().getIdentifier().hasAssigner());
        Assertions.assertFalse(((Identifier) condition2.getIdentifier().get(0)).getAssigner().getIdentifier().hasAssigner());
    }

    @Test
    public void integer() {
        Assertions.assertEquals(patient.getMultipleBirth().getValue(), ((Row) patientDataset.select("multipleBirthInteger", new String[0]).head()).get(0));
        Assertions.assertEquals((Integer) patient.getMultipleBirth().getValue(), (Integer) decodedPatient.getMultipleBirth().getValue());
    }

    @Test
    public void bigDecimal() {
        BigDecimal value = observation.getValue().getValue();
        BigDecimal bigDecimal = (BigDecimal) ((Row) observationsDataset.select("valueQuantity.value", new String[0]).head()).get(0);
        int i = ((Row) observationsDataset.select("valueQuantity.value_scale", new String[0]).head()).getInt(0);
        Assertions.assertEquals(0, value.compareTo(bigDecimal));
        Assertions.assertEquals(value.scale(), i);
        Assertions.assertEquals(value, decodedObservation.getValue().getValue());
        Assertions.assertEquals(TestData.TEST_VERY_BIG_DECIMAL, ((Observation.ObservationReferenceRangeComponent) decodedObservation.getReferenceRange().get(0)).getHigh().getValue());
        Assertions.assertEquals(TestData.TEST_VERY_SMALL_DECIMAL_SCALE_6, ((Observation.ObservationReferenceRangeComponent) decodedObservation.getReferenceRange().get(0)).getLow().getValue());
    }

    @Test
    public void choiceBigDecimalInQuestionnaire() {
        BigDecimal bigDecimal = (BigDecimal) questionnaire.getItemFirstRep().getEnableWhenFirstRep().getAnswerDecimalType().getValue();
        BigDecimal bigDecimal2 = (BigDecimal) ((Row) questionnaireDataset.select(new Column[]{functions.col("item").getItem(0).getField("enableWhen").getItem(0).getField("answerDecimal")}).head()).get(0);
        int i = ((Row) questionnaireDataset.select(new Column[]{functions.col("item").getItem(0).getField("enableWhen").getItem(0).getField("answerDecimal_scale")}).head()).getInt(0);
        Assertions.assertEquals(0, bigDecimal.compareTo(bigDecimal2));
        Assertions.assertEquals(bigDecimal.scale(), i);
        Assertions.assertEquals(bigDecimal, (BigDecimal) decodedQuestionnaire.getItemFirstRep().getEnableWhenFirstRep().getAnswerDecimalType().getValue());
        Assertions.assertEquals(TestData.TEST_VERY_BIG_DECIMAL, decodedQuestionnaire.getItemFirstRep().getInitialFirstRep().getValueDecimalType().getValue());
    }

    @Test
    public void choiceBigDecimalInQuestionnaireResponse() {
        BigDecimal bigDecimal = (BigDecimal) questionnaireResponse.getItemFirstRep().getAnswerFirstRep().getValueDecimalType().getValue();
        BigDecimal bigDecimal2 = (BigDecimal) ((Row) questionnaireResponseDataset.select(new Column[]{functions.col("item").getItem(0).getField("answer").getItem(0).getField("valueDecimal")}).head()).get(0);
        int i = ((Row) questionnaireResponseDataset.select(new Column[]{functions.col("item").getItem(0).getField("answer").getItem(0).getField("valueDecimal_scale")}).head()).getInt(0);
        Assertions.assertEquals(0, bigDecimal.compareTo(bigDecimal2));
        Assertions.assertEquals(bigDecimal.scale(), i);
        Assertions.assertEquals(bigDecimal, (BigDecimal) decodedQuestionnaireResponse.getItemFirstRep().getAnswerFirstRep().getValueDecimalType().getValue());
        Assertions.assertEquals(TestData.TEST_VERY_SMALL_DECIMAL_SCALE_6, decodedQuestionnaireResponse.getItemFirstRep().getAnswerFirstRep().getValueDecimalType().getValue());
        Assertions.assertEquals(TestData.TEST_VERY_BIG_DECIMAL, ((QuestionnaireResponse.QuestionnaireResponseItemAnswerComponent) decodedQuestionnaireResponse.getItemFirstRep().getAnswer().get(1)).getValueDecimalType().getValue());
    }

    @Test
    public void instant() {
        Date date = TestData.TEST_DATE;
        Assertions.assertEquals(date, ((Row) observationsDataset.select("issued", new String[0]).head()).get(0));
        Assertions.assertEquals(date, decodedObservation.getIssued());
    }

    @Test
    public void annotation() throws FHIRException {
        Annotation noteFirstRep = medRequest.getNoteFirstRep();
        Annotation noteFirstRep2 = decodedMedRequest.getNoteFirstRep();
        Assertions.assertEquals(noteFirstRep.getText(), ((Row) medDataset.select(new Column[]{functions.expr("note[0].text")}).head()).get(0));
        Assertions.assertEquals(noteFirstRep.getText(), noteFirstRep2.getText());
        Assertions.assertEquals(noteFirstRep.getAuthorReference().getReference(), noteFirstRep2.getAuthorReference().getReference());
    }

    @Test
    public void testCopyDecoded() {
        Assertions.assertEquals(condition.getId(), decodedCondition.copy().getId());
        Assertions.assertEquals(medRequest.getId(), decodedMedRequest.copy().getId());
        Assertions.assertEquals(observation.getId(), decodedObservation.copy().getId());
        Assertions.assertEquals(patient.getId(), decodedPatient.copy().getId());
    }

    @Test
    public void testEmptyAttributes() {
        Map attributes = decodedMedRequest.getText().getDiv().getAttributes();
        Assertions.assertNotNull(attributes);
        Assertions.assertEquals(0, attributes.size());
    }

    @Test
    public void testFromRdd() {
        JavaSparkContext javaSparkContext = new JavaSparkContext(spark.sparkContext());
        try {
            Assertions.assertEquals(condition.getId(), ((Condition) spark.createDataset(javaSparkContext.parallelize(ImmutableList.of(condition)).rdd(), ENCODERS_L0.of(Condition.class)).head()).getId());
            javaSparkContext.close();
        } catch (Throwable th) {
            try {
                javaSparkContext.close();
            } catch (Throwable th2) {
                th.addSuppressed(th2);
            }
            throw th;
        }
    }

    @Test
    public void testFromParquet() throws IOException {
        String path = Files.createTempDirectory("encoder_test", new FileAttribute[0]).resolve("out.parquet").toString();
        conditionsDataset.write().save(path);
        Assertions.assertEquals(condition.getId(), ((Condition) spark.read().parquet(path).as(ENCODERS_L0.of(Condition.class)).head()).getId());
    }

    @Test
    public void testEncoderCached() {
        Assertions.assertSame(ENCODERS_L0.of(Condition.class), ENCODERS_L0.of(Condition.class));
        Assertions.assertSame(ENCODERS_L0.of(Patient.class), ENCODERS_L0.of(Patient.class));
    }

    @Test
    public void testPrimitiveClassDecoding() {
        Assertions.assertEquals(encounter.getClass_().getCode(), ((Row) encounterDataset.select("class.code", new String[0]).head()).get(0));
        Assertions.assertEquals(encounter.getClass_().getCode(), decodedEncounter.getClass_().getCode());
    }

    @Test
    public void testNestedQuestionnaire() {
        List list = (List) IntStream.rangeClosed(0, NESTING_LEVEL_3).mapToObj(i -> {
            return TestData.newNestedQuestionnaire(i, 4);
        }).collect(Collectors.toUnmodifiableList());
        Questionnaire questionnaire2 = (Questionnaire) list.get(0);
        List collectAsList = spark.createDataset(list, ENCODERS_L0.of(Questionnaire.class)).collectAsList();
        for (int i2 = 0; i2 <= NESTING_LEVEL_3; i2++) {
            Assertions.assertTrue(questionnaire2.equalsDeep((Base) collectAsList.get(i2)));
        }
        Dataset createDataset = spark.createDataset(list, ENCODERS_L3.of(Questionnaire.class));
        List collectAsList2 = createDataset.collectAsList();
        for (int i3 = 0; i3 <= NESTING_LEVEL_3; i3++) {
            Assertions.assertTrue(((Questionnaire) list.get(i3)).equalsDeep((Base) collectAsList2.get(i3)));
        }
        Assertions.assertEquals(Stream.of((Object[]) new String[]{"Item/0", "Item/0", "Item/0", "Item/0"}).map(obj -> {
            return RowFactory.create(new Object[]{obj});
        }).collect(Collectors.toUnmodifiableList()), createDataset.select(new Column[]{functions.col("item").getItem(0).getField("linkId")}).collectAsList());
        Assertions.assertEquals(Stream.of((Object[]) new String[]{null, "Item/1.0", "Item/1.0", "Item/1.0"}).map(obj2 -> {
            return RowFactory.create(new Object[]{obj2});
        }).collect(Collectors.toUnmodifiableList()), createDataset.select(new Column[]{functions.col("item").getItem(1).getField("item").getItem(0).getField("linkId")}).collectAsList());
        Assertions.assertEquals(Stream.of((Object[]) new String[]{null, null, "Item/2.1.0", "Item/2.1.0"}).map(obj3 -> {
            return RowFactory.create(new Object[]{obj3});
        }).collect(Collectors.toUnmodifiableList()), createDataset.select(new Column[]{functions.col("item").getItem(2).getField("item").getItem(1).getField("item").getItem(0).getField("linkId")}).collectAsList());
        Assertions.assertEquals(Stream.of((Object[]) new String[]{null, null, null, "Item/3.2.1.0"}).map(obj4 -> {
            return RowFactory.create(new Object[]{obj4});
        }).collect(Collectors.toUnmodifiableList()), createDataset.select(new Column[]{functions.col("item").getItem(Integer.valueOf(NESTING_LEVEL_3)).getField("item").getItem(2).getField("item").getItem(1).getField("item").getItem(0).getField("linkId")}).collectAsList());
    }

    @Test
    public void testQuantityComparator() {
        Quantity.QuantityComparator comparator = observation.getValueQuantity().getComparator();
        Assertions.assertEquals(comparator.toCode(), ((Row) observationsDataset.select("valueQuantity.comparator", new String[0]).head()).getString(0));
    }

    @Test
    public void nullEncoding() {
        Observation observation2 = new Observation();
        Assertions.assertFalse(observation2.hasSubject());
        Row row = (Row) spark.createDataset(ImmutableList.of(observation2), ENCODERS_L0.of(Observation.class)).toDF().select("subject", new String[]{"identifier", "status"}).first();
        Assertions.assertTrue(row.isNullAt(0));
        Assertions.assertTrue(row.isNullAt(1));
        Assertions.assertTrue(row.isNullAt(2));
    }

    @Test
    public void nullEncodingFromJson() {
        Observation parseResource = ENCODERS_L0.getContext().newJsonParser().parseResource(Observation.class, "{ \"resourceType\": \"Observation\"}");
        Assertions.assertFalse(parseResource.hasSubject());
        Row row = (Row) spark.createDataset(ImmutableList.of(parseResource), ENCODERS_L0.of(Observation.class)).toDF().select("subject", new String[]{"identifier", "status"}).first();
        Assertions.assertTrue(row.isNullAt(0));
        Assertions.assertTrue(row.isNullAt(1));
        Assertions.assertTrue(row.isNullAt(2));
    }
}
