package au.csiro.pathling;

import au.csiro.pathling.encoders.FhirEncoders;
import au.csiro.pathling.fhirpath.FhirPath;
import au.csiro.pathling.fhirpath.NonLiteralPath;
import au.csiro.pathling.fhirpath.literal.LiteralPath;
import au.csiro.pathling.fhirpath.parser.ParserContext;
import au.csiro.pathling.utilities.Preconditions;
import au.csiro.pathling.utilities.Strings;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.hl7.fhir.r4.model.Enumerations;

/**
 * Helper methods for aliasing columns, joining and unioning Spark {@link Dataset}s used during
 * query execution.
 */
public abstract class QueryHelpers {

    /**
     * An immutable pairing of a {@link Dataset} with a {@link Column}.
     */
    public static final class DatasetWithColumn {

        @Nonnull
        private final Dataset<Row> dataset;

        @Nonnull
        private final Column column;

        /**
         * @param dataset the dataset, must not be null
         * @param column the column, must not be null
         * @throws NullPointerException if either argument is null
         */
        public DatasetWithColumn(@Nonnull Dataset<Row> dataset, @Nonnull Column column) {
            this.dataset = Objects.requireNonNull(dataset, "dataset is marked non-null but is null");
            this.column = Objects.requireNonNull(column, "column is marked non-null but is null");
        }

        @Nonnull
        public Dataset<Row> getDataset() {
            return dataset;
        }

        @Nonnull
        public Column getColumn() {
            return column;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof DatasetWithColumn)) {
                return false;
            }
            DatasetWithColumn other = (DatasetWithColumn) obj;
            // Both fields are null-checked in the constructor, so no null handling is needed here.
            return dataset.equals(other.dataset) && column.equals(other.column);
        }

        @Override
        public int hashCode() {
            return Objects.hash(dataset, column);
        }

        @Override
        public String toString() {
            return "QueryHelpers.DatasetWithColumn(dataset=" + dataset + ", column=" + column + ")";
        }
    }

    /**
     * An immutable pairing of a {@link Dataset} with a map from original columns to the columns
     * that replace them within that dataset.
     */
    public static final class DatasetWithColumnMap {

        @Nonnull
        private final Dataset<Row> dataset;

        @Nonnull
        private final Map<Column, Column> columnMap;

        /**
         * @param dataset the dataset, must not be null
         * @param map the column mapping, must not be null; stored by reference, not copied
         * @throws NullPointerException if either argument is null
         */
        public DatasetWithColumnMap(@Nonnull Dataset<Row> dataset, @Nonnull Map<Column, Column> map) {
            this.dataset = Objects.requireNonNull(dataset, "dataset is marked non-null but is null");
            this.columnMap = Objects.requireNonNull(map, "columnMap is marked non-null but is null");
        }

        /**
         * Looks up the replacement for an original column.
         *
         * <p>NOTE(review): this is a plain {@link Map#get} and therefore yields {@code null} for a
         * column that is not a key, despite the {@code @Nonnull} annotation. Callers should only
         * pass columns that were supplied when the map was built.
         *
         * @param column the original column
         * @return the column mapped to {@code column}
         */
        @Nonnull
        public Column getColumn(@Nonnull Column column) {
            return columnMap.get(column);
        }

        @Nonnull
        public Dataset<Row> getDataset() {
            return dataset;
        }

        @Nonnull
        public Map<Column, Column> getColumnMap() {
            return columnMap;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof DatasetWithColumnMap)) {
                return false;
            }
            DatasetWithColumnMap other = (DatasetWithColumnMap) obj;
            // Both fields are null-checked in the constructor, so no null handling is needed here.
            return dataset.equals(other.dataset) && columnMap.equals(other.columnMap);
        }

        @Override
        public int hashCode() {
            return Objects.hash(dataset, columnMap);
        }

        @Override
        public String toString() {
            return "QueryHelpers.DatasetWithColumnMap(dataset=" + dataset + ", columnMap=" + columnMap
                    + ")";
        }
    }

    /**
     * Join types, each carrying the join-type string that is passed to Spark's
     * {@code Dataset.join(Dataset, Column, String)}.
     */
    public enum JoinType {
        INNER("inner"),
        CROSS("cross"),
        OUTER("outer"),
        FULL("full"),
        FULL_OUTER("full_outer"),
        LEFT("left"),
        LEFT_OUTER("left_outer"),
        RIGHT("right"),
        RIGHT_OUTER("right_outer"),
        LEFT_SEMI("left_semi"),
        LEFT_ANTI("left_anti");

        @Nonnull
        private final String sparkName;

        JoinType(@Nonnull String sparkName) {
            this.sparkName = sparkName;
        }

        /**
         * @return the join-type string handed to Spark
         */
        @Nonnull
        public String getSparkName() {
            return sparkName;
        }
    }

    /**
     * Adds a single column to a dataset under a freshly generated alias.
     *
     * <p>The returned dataset keeps only those existing columns whose names satisfy
     * {@link Strings#looksLikeAlias}, plus the newly aliased column.
     *
     * @param dataset the source dataset
     * @param column the column expression to add
     * @return the new dataset together with an unbound reference to the aliased column
     */
    @Nonnull
    public static DatasetWithColumn createColumn(@Nonnull Dataset<Row> dataset, @Nonnull Column column) {
        DatasetWithColumnMap aliased = aliasColumns(dataset, Collections.singletonList(column));
        return new DatasetWithColumn(aliased.getDataset(), aliased.getColumn(column));
    }

    /**
     * Adds several columns to a dataset, each under a freshly generated alias.
     *
     * <p>The returned dataset keeps only those existing columns whose names satisfy
     * {@link Strings#looksLikeAlias}, plus the newly aliased columns.
     *
     * @param dataset the source dataset
     * @param columns the column expressions to add
     * @return the new dataset together with a map from each input column to its aliased reference
     */
    @Nonnull
    public static DatasetWithColumnMap createColumns(@Nonnull Dataset<Row> dataset, @Nonnull Column... columns) {
        return aliasColumns(dataset, Arrays.asList(columns));
    }

    /**
     * Re-selects every column of a dataset under a freshly generated alias.
     *
     * @param dataset the source dataset
     * @return the new dataset together with a map from each original column to its aliased reference
     */
    @Nonnull
    public static DatasetWithColumnMap aliasAllColumns(@Nonnull Dataset<Row> dataset) {
        List<Column> columns = Stream.of(dataset.columns())
                .map(dataset::col)
                .collect(Collectors.toList());
        return aliasColumns(dataset, columns);
    }

    /**
     * Selects already-aliased columns from {@code dataset} plus each of {@code columns} under a new
     * random alias.
     *
     * @param dataset the source dataset
     * @param columns the column expressions to alias
     * @return the new dataset together with an insertion-ordered map from each input column to an
     *     unbound {@link functions#col(String)} reference to its alias
     */
    @Nonnull
    private static DatasetWithColumnMap aliasColumns(@Nonnull Dataset<Row> dataset,
            @Nonnull Iterable<Column> columns) {
        Map<Column, Column> aliasMap = new LinkedHashMap<>();
        // Only columns whose names already look like generated aliases are carried over; anything
        // else is dropped from the selection. The list must be mutable because the new aliased
        // columns are appended below (Collectors.toList() does not guarantee mutability).
        List<Column> selection = Stream.of(dataset.columns())
                .filter(Strings::looksLikeAlias)
                .map(dataset::col)
                .collect(Collectors.toCollection(ArrayList::new));
        for (Column column : columns) {
            String alias = Strings.randomAlias();
            selection.add(column.alias(alias));
            aliasMap.put(column, functions.col(alias));
        }
        return new DatasetWithColumnMap(dataset.select(selection.toArray(new Column[0])), aliasMap);
    }

    /**
     * Joins two datasets on pairwise null-safe equality of the given columns, optionally combined
     * with an additional condition.
     *
     * @param left the left dataset
     * @param leftColumns join columns on the left side
     * @param right the right dataset
     * @param rightColumns join columns on the right side, paired by index with {@code leftColumns}
     * @param additionalCondition an extra condition ANDed onto the join condition, if present
     * @param joinType the type of join to perform
     * @return the joined dataset
     * @throws IllegalArgumentException (via {@link Preconditions#checkArgument}) if the column
     *     lists differ in size
     */
    @Nonnull
    private static Dataset<Row> join(@Nonnull Dataset<Row> left, @Nonnull List<Column> leftColumns,
            @Nonnull Dataset<Row> right, @Nonnull List<Column> rightColumns,
            @Nonnull Optional<Column> additionalCondition, @Nonnull JoinType joinType) {
        Preconditions.checkArgument(leftColumns.size() == rightColumns.size(),
                "Left columns should be same size as right columns");

        Dataset<Row> aliasedLeft = left;
        List<Column> joinConditions = new ArrayList<>();
        for (int i = 0; i < leftColumns.size(); i++) {
            // Alias each left-hand join column so it cannot be confused with a same-named column on
            // the right-hand side.
            DatasetWithColumn leftWithColumn = createColumn(aliasedLeft, leftColumns.get(i));
            aliasedLeft = leftWithColumn.getDataset();
            joinConditions.add(leftWithColumn.getColumn().eqNullSafe(rightColumns.get(i)));
        }
        additionalCondition.ifPresent(joinConditions::add);
        // With no conditions at all, every row pair matches.
        Column joinCondition = joinConditions.stream()
                .reduce(Column::and)
                .orElse(functions.lit(true));

        List<String> leftColumnNames = Arrays.asList(aliasedLeft.columns());
        List<String> rightColumnNames = rightColumns.stream()
                .map(Column::toString)
                .collect(Collectors.toList());

        // Remove from the left side any column named like a right-hand join column.
        Dataset<Row> trimmedLeft = applySelection(aliasedLeft, Collections.emptyList(), rightColumnNames);
        // Keep on the right side only the right-hand join columns and columns absent from the left.
        Dataset<Row> trimmedRight = applySelection(right, rightColumnNames, leftColumnNames);

        return trimmedLeft.join(trimmedRight, joinCondition, joinType.getSparkName());
    }

    /**
     * Joins two datasets on pairwise null-safe equality of the given columns.
     */
    @Nonnull
    public static Dataset<Row> join(@Nonnull Dataset<Row> left, @Nonnull List<Column> leftColumns,
            @Nonnull Dataset<Row> right, @Nonnull List<Column> rightColumns,
            @Nonnull JoinType joinType) {
        return join(left, leftColumns, right, rightColumns, Optional.<Column>empty(), joinType);
    }

    /**
     * Joins two datasets on null-safe equality of a single pair of columns.
     */
    @Nonnull
    public static Dataset<Row> join(@Nonnull Dataset<Row> left, @Nonnull Column leftColumn,
            @Nonnull Dataset<Row> right, @Nonnull Column rightColumn, @Nonnull JoinType joinType) {
        return join(left, Collections.singletonList(leftColumn), right,
                Collections.singletonList(rightColumn), joinType);
    }

    /**
     * Joins two datasets on null-safe equality of a single pair of columns, ANDed with an
     * additional condition.
     */
    @Nonnull
    public static Dataset<Row> join(@Nonnull Dataset<Row> left, @Nonnull Column leftColumn,
            @Nonnull Dataset<Row> right, @Nonnull Column rightColumn,
            @Nonnull Column additionalCondition, @Nonnull JoinType joinType) {
        return join(left, Collections.singletonList(leftColumn), right,
                Collections.singletonList(rightColumn), Optional.of(additionalCondition), joinType);
    }

    /**
     * Joins two datasets on pairwise null-safe equality of the given columns, ANDed with an
     * additional condition.
     */
    @Nonnull
    public static Dataset<Row> join(@Nonnull Dataset<Row> left, @Nonnull List<Column> leftColumns,
            @Nonnull Dataset<Row> right, @Nonnull List<Column> rightColumns,
            @Nonnull Column additionalCondition, @Nonnull JoinType joinType) {
        return join(left, leftColumns, right, rightColumns, Optional.of(additionalCondition), joinType);
    }

    /**
     * Joins two datasets using only the supplied condition.
     */
    @Nonnull
    public static Dataset<Row> join(@Nonnull Dataset<Row> left, @Nonnull Dataset<Row> right,
            @Nonnull Column condition, @Nonnull JoinType joinType) {
        return join(left, Collections.<Column>emptyList(), right, Collections.<Column>emptyList(),
                Optional.of(condition), joinType);
    }

    /**
     * Joins the datasets of two paths; see {@link #join(ParserContext, List, JoinType)}.
     */
    @Nonnull
    public static Dataset<Row> join(@Nonnull ParserContext parserContext, @Nonnull FhirPath left,
            @Nonnull FhirPath right, @Nonnull JoinType joinType) {
        return join(parserContext, Arrays.asList(left, right), joinType);
    }

    /**
     * Joins the datasets of several paths on the context's grouping columns. Where a path's dataset
     * lacks some grouping columns, the columns it does have plus the input context's id column are
     * used instead.
     *
     * <p>Only non-literal paths after the first take part in the join. If the first path is
     * non-literal and no others are, its dataset is returned unchanged. If the first path is a
     * literal and at least one other is non-literal, the first non-literal path's dataset is
     * returned.
     *
     * <p>NOTE(review): in that literal-first case, any further non-literal paths are not joined in.
     * Confirm this is intended.
     *
     * @param parserContext supplies the grouping columns and the input context id column
     * @param fhirPaths the paths to join; must contain more than one element
     * @param joinType the type of join to perform
     * @return the joined dataset
     */
    @Nonnull
    public static Dataset<Row> join(@Nonnull ParserContext parserContext,
            @Nonnull List<FhirPath> fhirPaths, @Nonnull JoinType joinType) {
        Preconditions.checkArgument(fhirPaths.size() > 1, "fhirPaths must contain more than one FhirPath");

        FhirPath left = fhirPaths.get(0);
        List<FhirPath> joinTargets = fhirPaths.subList(1, fhirPaths.size()).stream()
                .filter(path -> path instanceof NonLiteralPath)
                .collect(Collectors.toList());

        if (left instanceof NonLiteralPath && joinTargets.isEmpty()) {
            return left.getDataset();
        }
        if (left instanceof LiteralPath && !joinTargets.isEmpty()) {
            return joinTargets.get(0).getDataset();
        }

        Dataset<Row> result = left.getDataset();
        List<Column> groupingColumns = parserContext.getGroupingColumns();
        Column idColumn = parserContext.getInputContext().getIdColumn();
        List<Column> leftJoinColumns = checkColumnsAndFallback(left.getDataset(), groupingColumns, idColumn);
        for (FhirPath target : joinTargets) {
            List<Column> joinColumns = checkColumnsAndFallback(target.getDataset(), leftJoinColumns, idColumn);
            result = join(result, joinColumns, target.getDataset(), joinColumns, joinType);
        }
        return result;
    }

    /**
     * Joins a path's dataset, keyed on its id column, to another dataset.
     */
    @Nonnull
    public static Dataset<Row> join(@Nonnull FhirPath left, @Nonnull Dataset<Row> right,
            @Nonnull Column rightColumn, @Nonnull JoinType joinType) {
        return join(left.getDataset(), left.getIdColumn(), right, rightColumn, joinType);
    }

    /**
     * Joins a path's dataset, keyed on its id column, to another dataset, ANDed with an additional
     * condition.
     */
    @Nonnull
    public static Dataset<Row> join(@Nonnull FhirPath left, @Nonnull Dataset<Row> right,
            @Nonnull Column rightColumn, @Nonnull Column additionalCondition,
            @Nonnull JoinType joinType) {
        return join(left.getDataset(), Collections.singletonList(left.getIdColumn()), right,
                Collections.singletonList(rightColumn), Optional.of(additionalCondition), joinType);
    }

    /**
     * Joins a dataset to a path's dataset, keyed on the path's id column.
     */
    @Nonnull
    public static Dataset<Row> join(@Nonnull Dataset<Row> left, @Nonnull Column leftColumn,
            @Nonnull FhirPath right, @Nonnull JoinType joinType) {
        return join(left, leftColumn, right.getDataset(), right.getIdColumn(), joinType);
    }

    /**
     * Returns {@code groupingColumns} if all of them (by name) exist in {@code dataset}. Otherwise
     * returns the subset that does exist, in the original order, followed by {@code fallback},
     * each resolved against {@code dataset}.
     */
    @Nonnull
    private static List<Column> checkColumnsAndFallback(@Nonnull Dataset<Row> dataset,
            @Nonnull List<Column> groupingColumns, @Nonnull Column fallback) {
        Set<String> available = new HashSet<>(Arrays.asList(dataset.columns()));
        Set<String> requested = groupingColumns.stream()
                .map(Column::toString)
                .collect(Collectors.toCollection(LinkedHashSet::new));
        if (available.containsAll(requested)) {
            return groupingColumns;
        }
        // A LinkedHashSet keeps the column order deterministic (grouping order, then the fallback).
        Set<String> fallbackNames = new LinkedHashSet<>(requested);
        fallbackNames.retainAll(available);
        fallbackNames.add(fallback.toString());
        return fallbackNames.stream()
                .map(dataset::col)
                .collect(Collectors.toList());
    }

    /**
     * Unions a non-empty collection of datasets, in iteration order.
     *
     * @param datasets the datasets to union
     * @return the union of all datasets
     * @throws NoSuchElementException if {@code datasets} is empty
     */
    @Nonnull
    public static Dataset<Row> union(@Nonnull Collection<Dataset<Row>> datasets) {
        return datasets.stream()
                .reduce((first, second) -> first.union(second))
                .orElseThrow(() -> new NoSuchElementException("Cannot union an empty collection of datasets"));
    }

    /**
     * Selects the columns of {@code dataset} that are either in {@code include} or not in
     * {@code exclude}.
     */
    @Nonnull
    private static Dataset<Row> applySelection(@Nonnull Dataset<Row> dataset,
            @Nonnull Collection<String> include, @Nonnull Collection<String> exclude) {
        Column[] selection = Stream.of(dataset.columns())
                .filter(name -> include.contains(name) || !exclude.contains(name))
                .map(dataset::col)
                .toArray(Column[]::new);
        return dataset.select(selection);
    }

    /**
     * Returns the column names shared by both paths' datasets, as unbound columns sorted by name,
     * followed by the value column of {@code source}.
     *
     * @param source the path whose value column is appended
     * @param target the other path
     * @return a mutable list of columns
     */
    @Nonnull
    public static List<Column> getUnionableColumns(@Nonnull FhirPath source, @Nonnull FhirPath target) {
        Set<String> common = new HashSet<>(Arrays.asList(source.getDataset().columns()));
        common.retainAll(Arrays.asList(target.getDataset().columns()));
        // Collect into an explicitly mutable list, since the value column is appended afterwards.
        List<Column> columns = common.stream()
                .map(functions::col)
                .sorted(Comparator.comparing(Column::toString))
                .collect(Collectors.toCollection(ArrayList::new));
        columns.add(source.getValueColumn());
        return columns;
    }

    /**
     * Creates an empty dataframe with the encoder schema for the given resource type.
     *
     * @param sparkSession the Spark session
     * @param fhirEncoders the encoders used to derive the schema
     * @param resourceType the FHIR resource type
     * @return an empty dataframe
     */
    @Nonnull
    public static Dataset<Row> createEmptyDataset(@Nonnull SparkSession sparkSession,
            @Nonnull FhirEncoders fhirEncoders, @Nonnull Enumerations.ResourceType resourceType) {
        return sparkSession.emptyDataset(fhirEncoders.of(resourceType.toCode())).toDF();
    }
}
