package au.csiro.pathling.extract;

import au.csiro.pathling.QueryExecutor;
import au.csiro.pathling.QueryHelpers;
import au.csiro.pathling.config.QueryConfiguration;
import au.csiro.pathling.fhirpath.FhirPath;
import au.csiro.pathling.fhirpath.Materializable;
import au.csiro.pathling.fhirpath.ResourcePath;
import au.csiro.pathling.fhirpath.parser.ParserContext;
import au.csiro.pathling.io.source.DataSource;
import au.csiro.pathling.query.ExpressionWithLabel;
import au.csiro.pathling.terminology.TerminologyServiceFactory;
import au.csiro.pathling.utilities.Preconditions;
import ca.uhn.fhir.context.FhirContext;
import jakarta.annotation.Nonnull;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Executes extract queries: evaluates each column expression of an {@link ExtractRequest} against
 * the request's subject resource and produces a single Spark {@link Dataset} with one (labelled)
 * column per expression, optionally filtered and limited.
 */
public class ExtractQueryExecutor extends QueryExecutor {

    private static final Logger log = LoggerFactory.getLogger(ExtractQueryExecutor.class);

    /**
     * Immutable holder for the state accumulated while joining column expressions together: the
     * most recently joined {@link FhirPath}, the {@link ParserContext} it was parsed with, and the
     * dataset produced by the joins performed so far.
     */
    private static final class FhirPathContextAndResult {

        @Nonnull
        private final FhirPath fhirPath;

        @Nonnull
        private final ParserContext context;

        @Nonnull
        private final Dataset<Row> result;

        /**
         * @param fhirPath the most recently joined expression
         * @param context the parser context that {@code fhirPath} was parsed with
         * @param result the dataset accumulated so far
         * @throws NullPointerException if any argument is {@code null}
         */
        FhirPathContextAndResult(@Nonnull final FhirPath fhirPath,
                @Nonnull final ParserContext context, @Nonnull final Dataset<Row> result) {
            this.fhirPath = Objects.requireNonNull(fhirPath,
                    "fhirPath is marked non-null but is null");
            this.context = Objects.requireNonNull(context,
                    "context is marked non-null but is null");
            this.result = Objects.requireNonNull(result,
                    "result is marked non-null but is null");
        }

        @Nonnull
        public FhirPath getFhirPath() {
            return fhirPath;
        }

        @Nonnull
        public ParserContext getContext() {
            return context;
        }

        @Nonnull
        public Dataset<Row> getResult() {
            return result;
        }

        @Override
        public boolean equals(final Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof FhirPathContextAndResult)) {
                return false;
            }
            final FhirPathContextAndResult other = (FhirPathContextAndResult) obj;
            // All fields are guaranteed non-null by the constructor.
            return fhirPath.equals(other.fhirPath)
                    && context.equals(other.context)
                    && result.equals(other.result);
        }

        @Override
        public int hashCode() {
            return Objects.hash(fhirPath, context, result);
        }

        @Override
        public String toString() {
            return "ExtractQueryExecutor.FhirPathContextAndResult(fhirPath=" + fhirPath
                    + ", context=" + context + ", result=" + result + ")";
        }
    }

    /**
     * Creates an executor; all arguments are passed straight through to {@link QueryExecutor}.
     *
     * @param queryConfiguration query configuration
     * @param fhirContext the FHIR context
     * @param sparkSession the Spark session
     * @param dataSource the source of resource data
     * @param terminologyServiceFactory an optional factory for terminology services
     */
    public ExtractQueryExecutor(@Nonnull final QueryConfiguration queryConfiguration,
            @Nonnull final FhirContext fhirContext, @Nonnull final SparkSession sparkSession,
            @Nonnull final DataSource dataSource,
            @Nonnull final Optional<TerminologyServiceFactory> terminologyServiceFactory) {
        super(queryConfiguration, fhirContext, sparkSession, dataSource,
                terminologyServiceFactory);
    }

    /**
     * Builds the Spark query for an extract request.
     *
     * <p>The column expressions are parsed against the subject resource, their datasets are
     * joined together, rows with trailing nulls are trimmed, the request's filters are applied
     * (passing {@link Column#and} as the combiner), the labelled extractable columns are selected,
     * rows with a null resource ID are dropped, and the optional limit is applied.
     *
     * @param extractRequest the request to execute
     * @return the (lazily evaluated) result dataset
     */
    @Nonnull
    public Dataset<Row> buildQuery(@Nonnull final ExtractRequest extractRequest) {
        // Parse every column expression against the subject resource.
        final ResourcePath inputContext = ResourcePath.build(getFhirContext(), getDataSource(),
                extractRequest.getSubjectResource(), extractRequest.getSubjectResource().toCode(),
                true);
        final Column idColumn = inputContext.getIdColumn();
        final ParserContext parserContext = buildParserContext(inputContext,
                Collections.singletonList(idColumn));
        final List<QueryExecutor.FhirPathAndContext> parsedColumns =
                parseMaterializableExpressions(parserContext, extractRequest.getColumns(),
                        "Column");
        final List<FhirPath> columnPaths = parsedColumns.stream()
                .map(QueryExecutor.FhirPathAndContext::getFhirPath)
                .collect(Collectors.toUnmodifiableList());

        // Join the column datasets together and trim rows with trailing nulls.
        final Dataset<Row> joined = joinColumns(parsedColumns).getResult();
        final Dataset<Row> trimmed = trimTrailingNulls(idColumn, columnPaths, joined);

        // Apply the filters.
        final Dataset<Row> filtered = filterDataset(inputContext, extractRequest.getFilters(),
                trimmed, Column::and);

        // Select the labelled column values, dropping rows without a resource ID.
        final Column[] columnValues = labelColumns(
                columnPaths.stream()
                        .map(path -> ((Materializable) path).getExtractableColumn()),
                ExpressionWithLabel.labelsAsStream(extractRequest.getColumnsWithLabels()))
                .toArray(Column[]::new);
        final Dataset<Row> selected = filtered.select(columnValues)
                .filter(idColumn.isNotNull());

        // Apply the limit, if one was requested.
        final Optional<Integer> limit = extractRequest.getLimit();
        return limit.isPresent() ? selected.limit(limit.get()) : selected;
    }

    /**
     * Left-outer joins the datasets of all parsed column expressions into a single dataset.
     *
     * <p>Expressions are processed in descending order of expression length (stable for ties).
     * The first expression's dataset seeds the result; each subsequent one is joined onto the
     * accumulated result by {@link #joinOnto}.
     *
     * @param columnsAndContexts the parsed column expressions; must not be empty
     * @return the final accumulated join state
     */
    @Nonnull
    private FhirPathContextAndResult joinColumns(
            @Nonnull final Collection<QueryExecutor.FhirPathAndContext> columnsAndContexts) {
        // The explicit type witness is required: without it, comparingInt(...).reversed()
        // infers Object and the lambda cannot call getFhirPath().
        final List<QueryExecutor.FhirPathAndContext> sorted = columnsAndContexts.stream()
                .sorted(Comparator.<QueryExecutor.FhirPathAndContext>comparingInt(
                        p -> p.getFhirPath().getExpression().length()).reversed())
                .collect(Collectors.toList());
        Preconditions.check(!sorted.isEmpty());

        FhirPathContextAndResult result = null;
        for (final QueryExecutor.FhirPathAndContext current : sorted) {
            result = result == null
                     ? new FhirPathContextAndResult(current.getFhirPath(), current.getContext(),
                             current.getFhirPath().getDataset())
                     : joinOnto(result, current);
        }
        return result;
    }

    /**
     * Left-outer joins one column expression's dataset onto the accumulated result.
     *
     * <p>Among all node-ID keys known to either parser context, the longest one that is a prefix
     * of both expressions is selected. If both contexts have a node-ID column for that key, the
     * join uses the resource ID plus that node-ID column on each side; otherwise it uses the
     * resource ID only.
     *
     * <p>NOTE(review): only the single longest matching key is considered. If that key is present
     * in just one of the two contexts, the join falls back to ID-only even when a shorter key is
     * shared by both. Confirm this is intended.
     *
     * @param previous the join state accumulated so far
     * @param current the expression to join next
     * @return the new join state, positioned on {@code current}
     */
    @Nonnull
    private static FhirPathContextAndResult joinOnto(
            @Nonnull final FhirPathContextAndResult previous,
            @Nonnull final QueryExecutor.FhirPathAndContext current) {
        final Map<String, Column> previousNodeIds = previous.getContext().getNodeIdColumns();
        final Map<String, Column> currentNodeIds = current.getContext().getNodeIdColumns();
        final String previousExpression = previous.getFhirPath().getExpression();
        final String currentExpression = current.getFhirPath().getExpression();

        final Set<String> candidatePrefixes = new HashSet<>(previousNodeIds.keySet());
        candidatePrefixes.addAll(currentNodeIds.keySet());
        // Two distinct prefixes of the same string cannot have equal length, so the maximum is
        // unique.
        final Optional<String> sharedNode = candidatePrefixes.stream()
                .filter(p -> previousExpression.startsWith(p) && currentExpression.startsWith(p))
                .max(Comparator.comparingInt(String::length))
                .filter(p -> previousNodeIds.containsKey(p) && currentNodeIds.containsKey(p));

        final Dataset<Row> joined;
        if (sharedNode.isPresent()) {
            final String node = sharedNode.get();
            final List<Column> previousJoinColumns = Arrays.asList(
                    previous.getFhirPath().getIdColumn(), previousNodeIds.get(node));
            final List<Column> currentJoinColumns = Arrays.asList(
                    current.getFhirPath().getIdColumn(), currentNodeIds.get(node));
            joined = QueryHelpers.join(previous.getResult(), previousJoinColumns,
                    current.getFhirPath().getDataset(), currentJoinColumns,
                    QueryHelpers.JoinType.LEFT_OUTER);
        } else {
            joined = QueryHelpers.join(previous.getResult(), previous.getFhirPath().getIdColumn(),
                    current.getFhirPath().getDataset(), current.getFhirPath().getIdColumn(),
                    QueryHelpers.JoinType.LEFT_OUTER);
        }
        return new FhirPathContextAndResult(current.getFhirPath(), current.getContext(), joined);
    }

    /**
     * Trims rows whose non-singular expression values are all null.
     *
     * <p>If every expression is singular, {@code dataset} is returned unchanged. Otherwise
     * {@code dataset} is right-outer joined against its own distinct projection of
     * {@code idColumn} plus the singular value columns. The join uses those columns plus the
     * extra condition that at least one non-singular value column is non-null. Presumably this
     * keeps one row per distinct ID/singular-value combination even when every non-singular value
     * is null; confirm against {@code QueryHelpers.join}.
     *
     * @param idColumn the resource ID column
     * @param expressions the column expressions; must not be empty
     * @param dataset the joined column dataset
     * @return the trimmed dataset
     */
    @Nonnull
    private Dataset<Row> trimTrailingNulls(@Nonnull final Column idColumn,
            @Nonnull final List<FhirPath> expressions, @Nonnull final Dataset<Row> dataset) {
        Preconditions.checkArgument(!expressions.isEmpty(), "At least one expression is required");

        final Column[] nonSingularColumns = expressions.stream()
                .filter(path -> !path.isSingular())
                .map(FhirPath::getValueColumn)
                .toArray(Column[]::new);
        if (nonSingularColumns.length == 0) {
            return dataset;
        }

        // At least one non-singular value must be present; the array is non-empty here.
        final Column additionalCondition = Arrays.stream(nonSingularColumns)
                .map(Column::isNotNull)
                .reduce(Column::or)
                .orElseThrow();

        final List<Column> filteringColumns = new ArrayList<>();
        filteringColumns.add(idColumn);
        filteringColumns.addAll(expressions.stream()
                .filter(FhirPath::isSingular)
                .map(FhirPath::getValueColumn)
                .collect(Collectors.toList()));

        final Dataset<Row> filteringDataset = dataset
                .select(filteringColumns.toArray(new Column[0]))
                .distinct();
        return QueryHelpers.join(dataset, filteringColumns, filteringDataset, filteringColumns,
                additionalCondition, QueryHelpers.JoinType.RIGHT_OUTER);
    }
}
