package au.csiro.pathling.fhirpath.operator;

import au.csiro.pathling.fhirpath.FhirPath;
import au.csiro.pathling.fhirpath.element.ElementPath;
import au.csiro.pathling.test.SpringBootUnitTest;
import au.csiro.pathling.test.assertions.Assertions;
import au.csiro.pathling.test.builders.DatasetBuilder;
import au.csiro.pathling.test.builders.ElementPathBuilder;
import au.csiro.pathling.test.builders.ParserContextBuilder;
import au.csiro.pathling.test.helpers.SparkHelpers;
import ca.uhn.fhir.context.FhirContext;
import com.google.common.collect.ImmutableSet;
import jakarta.annotation.Nonnull;
import java.math.BigDecimal;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.fhir.ucum.UcumService;
import org.hl7.fhir.r4.model.Enumerations;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Tests that the equality, non-equality and comparison operators on Quantity values keep their
 * decimal precision when applied across every single-character UCUM prefix of a base unit.
 *
 * <p>Each test builds two Quantity paths (one row per prefixed unit), applies an operator and
 * expects {@code true} for every row. The exception is an explicitly listed set of units, for
 * which a {@code null} result is expected instead.
 */
@SpringBootUnitTest
public class QuantityOperatorsPrecisionTest {

    @Autowired
    SparkSession spark;

    @Autowired
    FhirContext fhirContext;

    @Autowired
    UcumService ucumService;

    static final String ID_ALIAS = "_abc123";

    // 9 * 10^3 + 1 * 10^-6 = 9000.000001: a value with both integer and fractional digits.
    static final String REASONABLE_DECIMAL_01 = createSpanningDecimal(9, 3, 1, 6).toString();
    // 9000.000002: differs from REASONABLE_DECIMAL_01 only in the last fractional digit.
    static final String REASONABLE_DECIMAL_02 = createSpanningDecimal(9, 3, 2, 6).toString();
    // 9 * 10^26 + 1 * 10^-6: 27 integer digits followed by 6 fractional digits.
    static final String FULL_DECIMAL_01 = createSpanningDecimal(9, 26, 1, 6).toString();
    // As FULL_DECIMAL_01, but differing only in the last fractional digit.
    static final String FULL_DECIMAL_02 = createSpanningDecimal(9, 26, 2, 6).toString();

    // Prefixed metre units whose operator result is expected to be null for the FULL_DECIMAL values.
    // NOTE(review): presumably because the normalised value exceeds the supported precision.
    static final Set<String> UNSUPPORTED_FULL_DECIMAL_UNITS =
            ImmutableSet.of("Ym", "Zm", "Em", "Pm", "Tm");
    // Prefixed mole units whose operator result is expected to be null for the REASONABLE_DECIMAL
    // values. NOTE(review): presumably a range limitation after normalisation; confirm.
    static final Set<String> UNSUPPORTED_REASONABLE_DECIMAL_MOL_UNITS =
            ImmutableSet.of("Ymol", "Zmol", "Emol", "Pmol", "Tmol");

    /**
     * Builds the resource ID used for the row that holds a given unit.
     *
     * @param unit the UCUM unit code
     * @return the row ID, {@code "unit-"} followed by the unit code
     */
    @Nonnull
    private static String unitToRowId(@Nonnull final String unit) {
        return "unit-" + unit;
    }

    /**
     * Builds a singular Quantity element path with one row per unit.
     *
     * @param value the decimal value (as a string) given to every row
     * @param units the UCUM unit codes, one row each, keyed by {@link #unitToRowId(String)}
     * @return an element path over the generated dataset
     */
    @Nonnull
    private ElementPath buildQuantityPathForUnits(@Nonnull final String value,
            @Nonnull final List<String> units) {
        DatasetBuilder datasetBuilder = new DatasetBuilder(this.spark)
                .withIdColumn(ID_ALIAS)
                .withStructTypeColumns(SparkHelpers.quantityStructType());
        for (final String unit : units) {
            datasetBuilder = datasetBuilder.withRow(unitToRowId(unit),
                    SparkHelpers.rowForUcumQuantity(value, unit));
        }
        return new ElementPathBuilder(this.spark)
                .fhirType(Enumerations.FHIRDefinedType.QUANTITY)
                .singular(true)
                .dataset(datasetBuilder.buildWithStructValue())
                .idAndValueColumns()
                .build();
    }

    /**
     * Lists the base unit combined with every UCUM prefix whose code is a single character.
     *
     * @param baseUnit the unprefixed UCUM unit code, e.g. {@code "m"}
     * @return an unmodifiable list of prefixed unit codes
     */
    @Nonnull
    private List<String> getAllPrefixedUnits(@Nonnull final String baseUnit) {
        return this.ucumService.getModel().getPrefixes().stream()
                .map(prefix -> prefix.getCode())
                // Multi-character prefixes are deliberately excluded.
                .filter(code -> code.length() == 1)
                .map(code -> code + baseUnit)
                .collect(Collectors.toUnmodifiableList());
    }

    /**
     * Creates the decimal {@code integerDigit * 10^integerShift + fractionDigit * 10^-fractionShift}.
     *
     * @param integerDigit the value placed in the integer part
     * @param integerShift how many places to shift {@code integerDigit} left of the decimal point
     * @param fractionDigit the value placed in the fractional part
     * @param fractionShift how many places to shift {@code fractionDigit} right of the decimal point
     * @return the combined decimal
     */
    @Nonnull
    private static BigDecimal createSpanningDecimal(final int integerDigit, final int integerShift,
            final int fractionDigit, final int fractionShift) {
        return new BigDecimal(integerDigit).movePointRight(integerShift)
                .add(new BigDecimal(fractionDigit).movePointLeft(fractionShift));
    }

    /**
     * Builds the expected {@code (id, result)} rows for a set of units.
     *
     * @param units the units in the dataset
     * @param expected the result expected for every unit not in {@code nullResultUnits}
     * @param nullResultUnits units whose expected result is {@code null}
     * @return one row per unit, in the order of {@code units}
     */
    @Nonnull
    private static List<Row> createResult(@Nonnull final List<String> units,
            final boolean expected, @Nonnull final Set<String> nullResultUnits) {
        return units.stream()
                .map(unit -> RowFactory.create(unitToRowId(unit),
                        nullResultUnits.contains(unit) ? null : Boolean.valueOf(expected)))
                .collect(Collectors.toList());
    }

    /**
     * Invokes the named binary operator on two paths, grouping by the left path's ID column.
     *
     * @param left the left operand
     * @param operator the operator symbol, e.g. {@code "="}
     * @param right the right operand
     * @return the path produced by the operator
     */
    @Nonnull
    private FhirPath callOperator(@Nonnull final ElementPath left, @Nonnull final String operator,
            @Nonnull final ElementPath right) {
        return Operator.getInstance(operator).invoke(new OperatorInput(
                new ParserContextBuilder(this.spark, this.fhirContext)
                        .groupingColumns(Collections.singletonList(left.getIdColumn()))
                        .build(),
                left, right));
    }

    /**
     * Applies an operator to two Quantity paths built over the same units. It then asserts that
     * each row is {@code true}, or {@code null} for units in {@code nullResultUnits}.
     *
     * @param leftValue the decimal value for the left operand
     * @param operator the operator symbol
     * @param rightValue the decimal value for the right operand
     * @param units the units to test (one row per unit on both sides)
     * @param nullResultUnits units whose result is expected to be {@code null}
     */
    private void assertOperatorResult(@Nonnull final String leftValue,
            @Nonnull final String operator, @Nonnull final String rightValue,
            @Nonnull final List<String> units, @Nonnull final Set<String> nullResultUnits) {
        final ElementPath left = buildQuantityPathForUnits(leftValue, units);
        final ElementPath right = buildQuantityPathForUnits(rightValue, units);
        Assertions.assertThat(callOperator(left, operator, right))
                .selectResult()
                .hasRows(createResult(units, true, nullResultUnits));
    }

    @Test
    void equalityPrecisionForReasonableDecimals() {
        assertOperatorResult(REASONABLE_DECIMAL_01, "=", REASONABLE_DECIMAL_01,
                getAllPrefixedUnits("m"), Collections.emptySet());
    }

    @Test
    void nonEqualityPrecisionForReasonableDecimals() {
        assertOperatorResult(REASONABLE_DECIMAL_01, "!=", REASONABLE_DECIMAL_02,
                getAllPrefixedUnits("m"), Collections.emptySet());
    }

    @Test
    void comparisonPrecisionForReasonableDecimals() {
        assertOperatorResult(REASONABLE_DECIMAL_01, "<", REASONABLE_DECIMAL_02,
                getAllPrefixedUnits("m"), Collections.emptySet());
    }

    @Test
    void equalityPrecisionForFullDecimals() {
        assertOperatorResult(FULL_DECIMAL_01, "=", FULL_DECIMAL_01,
                getAllPrefixedUnits("m"), UNSUPPORTED_FULL_DECIMAL_UNITS);
    }

    @Test
    void nonEqualityPrecisionForFullDecimals() {
        assertOperatorResult(FULL_DECIMAL_01, "!=", FULL_DECIMAL_02,
                getAllPrefixedUnits("m"), UNSUPPORTED_FULL_DECIMAL_UNITS);
    }

    @Test
    void comparisonPrecisionForFullDecimals() {
        assertOperatorResult(FULL_DECIMAL_01, "<", FULL_DECIMAL_02,
                getAllPrefixedUnits("m"), UNSUPPORTED_FULL_DECIMAL_UNITS);
    }

    @Test
    void equalityPrecisionForReasonableDecimalsWithMoles() {
        assertOperatorResult(REASONABLE_DECIMAL_01, "=", REASONABLE_DECIMAL_01,
                getAllPrefixedUnits("mol"), UNSUPPORTED_REASONABLE_DECIMAL_MOL_UNITS);
    }
}
