package au.csiro.pathling.aggregate;

import au.csiro.pathling.QueryExecutor;
import au.csiro.pathling.QueryHelpers;
import au.csiro.pathling.config.QueryConfiguration;
import au.csiro.pathling.fhirpath.FhirPath;
import au.csiro.pathling.fhirpath.Materializable;
import au.csiro.pathling.fhirpath.ResourcePath;
import au.csiro.pathling.fhirpath.parser.Parser;
import au.csiro.pathling.fhirpath.parser.ParserContext;
import au.csiro.pathling.io.source.DataSource;
import au.csiro.pathling.query.ExpressionWithLabel;
import au.csiro.pathling.sql.SqlExpressions;
import au.csiro.pathling.terminology.TerminologyServiceFactory;
import au.csiro.pathling.utilities.Preconditions;
import ca.uhn.fhir.context.FhirContext;
import jakarta.annotation.Nonnull;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.Generated;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Builds Spark queries for aggregate requests.
 *
 * <p>The filter, grouping and aggregation expressions of an {@link AggregateRequest} are parsed as
 * FHIRPath against the request's subject resource. The resulting dataset is returned together with
 * the parsed expressions.
 */
public class AggregateQueryExecutor extends QueryExecutor {

    @Generated
    private static final Logger log = LoggerFactory.getLogger(AggregateQueryExecutor.class);

    /**
     * The result of {@link AggregateQueryExecutor#buildQuery(AggregateRequest)}.
     *
     * <p>It holds the result dataset together with the parsed aggregation, grouping and filter
     * expressions that were used to build it. All fields are non-null (enforced by the constructor).
     * The lists and collection passed in are stored as-is, not copied.
     */
    public static final class ResultWithExpressions {

        @Nonnull
        private final Dataset<Row> dataset;

        @Nonnull
        private final List<FhirPath> parsedAggregations;

        @Nonnull
        private final List<FhirPath> parsedGroupings;

        @Nonnull
        private final Collection<FhirPath> parsedFilters;

        /**
         * Creates a new result.
         *
         * @param dataset the result dataset
         * @param parsedAggregations the parsed aggregation expressions
         * @param parsedGroupings the parsed grouping expressions
         * @param parsedFilters the parsed filter expressions
         * @throws NullPointerException if any argument is null
         */
        @Generated
        public ResultWithExpressions(@Nonnull Dataset<Row> dataset,
                @Nonnull List<FhirPath> parsedAggregations,
                @Nonnull List<FhirPath> parsedGroupings,
                @Nonnull Collection<FhirPath> parsedFilters) {
            if (dataset == null) {
                throw new NullPointerException("dataset is marked non-null but is null");
            }
            if (parsedAggregations == null) {
                throw new NullPointerException("parsedAggregations is marked non-null but is null");
            }
            if (parsedGroupings == null) {
                throw new NullPointerException("parsedGroupings is marked non-null but is null");
            }
            if (parsedFilters == null) {
                throw new NullPointerException("parsedFilters is marked non-null but is null");
            }
            this.dataset = dataset;
            this.parsedAggregations = parsedAggregations;
            this.parsedGroupings = parsedGroupings;
            this.parsedFilters = parsedFilters;
        }

        /** @return the result dataset */
        @Nonnull
        @Generated
        public Dataset<Row> getDataset() {
            return this.dataset;
        }

        /** @return the parsed aggregation expressions */
        @Nonnull
        @Generated
        public List<FhirPath> getParsedAggregations() {
            return this.parsedAggregations;
        }

        /** @return the parsed grouping expressions */
        @Nonnull
        @Generated
        public List<FhirPath> getParsedGroupings() {
            return this.parsedGroupings;
        }

        /** @return the parsed filter expressions */
        @Nonnull
        @Generated
        public Collection<FhirPath> getParsedFilters() {
            return this.parsedFilters;
        }

        @Override
        @Generated
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof ResultWithExpressions)) {
                return false;
            }
            final ResultWithExpressions other = (ResultWithExpressions) obj;
            // The constructor rejects nulls, so every field can be compared directly.
            return dataset.equals(other.dataset)
                    && parsedAggregations.equals(other.parsedAggregations)
                    && parsedGroupings.equals(other.parsedGroupings)
                    && parsedFilters.equals(other.parsedFilters);
        }

        @Override
        @Generated
        public int hashCode() {
            // Lombok-style 59-multiplier combination; fields are never null (see constructor).
            int result = 59 + dataset.hashCode();
            result = result * 59 + parsedAggregations.hashCode();
            result = result * 59 + parsedGroupings.hashCode();
            return result * 59 + parsedFilters.hashCode();
        }

        @Override
        @Generated
        public String toString() {
            return "AggregateQueryExecutor.ResultWithExpressions(dataset=" + String.valueOf(getDataset())
                    + ", parsedAggregations=" + String.valueOf(getParsedAggregations())
                    + ", parsedGroupings=" + String.valueOf(getParsedGroupings())
                    + ", parsedFilters=" + String.valueOf(getParsedFilters()) + ")";
        }
    }

    /**
     * Creates a new executor; all arguments are passed straight to {@link QueryExecutor}.
     *
     * @param queryConfiguration the query configuration
     * @param fhirContext the FHIR context
     * @param sparkSession the Spark session
     * @param dataSource the source of resource data
     * @param terminologyServiceFactory an optional terminology service factory
     */
    public AggregateQueryExecutor(@Nonnull QueryConfiguration queryConfiguration,
            @Nonnull FhirContext fhirContext, @Nonnull SparkSession sparkSession,
            @Nonnull DataSource dataSource,
            @Nonnull Optional<TerminologyServiceFactory> terminologyServiceFactory) {
        super(queryConfiguration, fhirContext, sparkSession, dataSource, terminologyServiceFactory);
    }

    /**
     * Builds the query for an aggregate request.
     *
     * <ol>
     *   <li>Filters and groupings are parsed against the subject resource, using its ID column as
     *       the grouping context.</li>
     *   <li>The groupings and filters are joined onto the resource and the filters are applied.</li>
     *   <li>The grouping values are normalised with {@code SqlExpressions.pruneSyntheticFields}.</li>
     *   <li>Aggregations are parsed against a copy of the subject resource backed by that dataset,
     *       with the normalised grouping columns as the grouping context.</li>
     *   <li>The final dataset selects the labelled grouping columns followed by the labelled
     *       aggregation columns, with duplicate rows removed.</li>
     * </ol>
     *
     * @param aggregateRequest the request to build a query for
     * @return the result dataset, along with the parsed aggregation, grouping and filter expressions
     */
    @Nonnull
    public ResultWithExpressions buildQuery(@Nonnull AggregateRequest aggregateRequest) {
        log.info("Executing request: {}", aggregateRequest);

        final ResourcePath inputContext = ResourcePath.build(getFhirContext(), getDataSource(),
                aggregateRequest.getSubjectResource(),
                aggregateRequest.getSubjectResource().toCode(), true);
        final Column idColumn = inputContext.getIdColumn();

        // Filters and groupings are parsed with the resource identity as the grouping column.
        final ParserContext groupingAndFilterContext = buildParserContext(inputContext,
                Collections.singletonList(idColumn));
        final List<FhirPath> filters = parseFilters(new Parser(groupingAndFilterContext),
                aggregateRequest.getFilters());
        final List<FhirPath> groupings = parseMaterializableExpressions(groupingAndFilterContext,
                aggregateRequest.getGroupings(), "Grouping").stream()
                .map(parsed -> parsed.getFhirPath())
                .collect(Collectors.toList());

        // Join the grouping and filter expressions onto the resource, then apply the filters.
        final Dataset<Row> filtered = applyFilters(
                joinExpressionsAndFilters(inputContext, groupings, filters, idColumn), filters);

        // Normalise the grouping values (pruneSyntheticFields) into new columns on the dataset.
        final QueryHelpers.DatasetWithColumnMap normalized = QueryHelpers.createColumns(filtered,
                groupings.stream()
                        .map(FhirPath::getValueColumn)
                        .map(SqlExpressions::pruneSyntheticFields)
                        .toArray(Column[]::new));
        final Dataset<Row> groupedDataset = normalized.getDataset();
        final List<Column> groupingColumns = new ArrayList<>(normalized.getColumnMap().values());

        // Aggregations are parsed against a copy of the input context that is backed by the
        // filtered, normalised dataset, so they can group by the normalised grouping columns.
        final Parser aggregationParser = new Parser(buildParserContext(
                inputContext.copy(inputContext.getExpression(), groupedDataset, idColumn,
                        inputContext.getEidColumn(), inputContext.getValueColumn(),
                        inputContext.isSingular(), Optional.empty()),
                groupingColumns));
        final List<FhirPath> aggregations = parseAggregations(aggregationParser,
                aggregateRequest.getAggregations());
        final List<Column> aggregationColumns = aggregations.stream()
                .map(FhirPath::getValueColumn)
                .collect(Collectors.toList());

        Dataset<Row> joined = joinExpressionsByColumns(aggregations, groupingColumns);
        if (groupingColumns.isEmpty()) {
            // With no groupings only a single row is kept.
            // NOTE(review): presumably every row carries the same aggregate values in this case.
            joined = joined.limit(1);
        }

        // Grouping columns come first, followed by aggregation columns, each with its label.
        final Column[] selection = Stream.concat(
                labelColumns(groupingColumns.stream(),
                        ExpressionWithLabel.labelsAsStream(aggregateRequest.getGroupingsWithLabels())),
                labelColumns(aggregationColumns.stream(),
                        ExpressionWithLabel.labelsAsStream(aggregateRequest.getAggregationsWithLabels())))
                .toArray(Column[]::new);
        final Dataset<Row> result = joined.select(selection).distinct();

        return new ResultWithExpressions(result, aggregations, groupings, filters);
    }

    /**
     * Parses each aggregation expression.
     *
     * <p>Each parsed expression is validated with {@code Preconditions.checkUserInput}: it must be
     * {@link Materializable} and must evaluate to a singular value.
     *
     * @param parser the parser to use
     * @param collection the aggregation expressions to parse
     * @return the parsed expressions, in the collection's iteration order
     */
    @Nonnull
    private List<FhirPath> parseAggregations(@Nonnull Parser parser,
            @Nonnull Collection<String> collection) {
        return collection.stream()
                .map(expression -> {
                    final FhirPath parsed = parser.parse(expression);
                    Preconditions.checkUserInput(parsed instanceof Materializable,
                            "Aggregation expression is not of a supported type: " + expression);
                    Preconditions.checkUserInput(parsed.isSingular(),
                            "Aggregation expression does not evaluate to a singular value: " + expression);
                    return parsed;
                })
                .collect(Collectors.toList());
    }
}
