package tech.tablesaw.joining;

import com.google.common.collect.Streams;
import com.google.common.primitives.Ints;
import it.unimi.dsi.fastutil.ints.IntIterator;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import tech.tablesaw.api.BooleanColumn;
import tech.tablesaw.api.ColumnType;
import tech.tablesaw.api.DateColumn;
import tech.tablesaw.api.DateTimeColumn;
import tech.tablesaw.api.DoubleColumn;
import tech.tablesaw.api.FloatColumn;
import tech.tablesaw.api.InstantColumn;
import tech.tablesaw.api.IntColumn;
import tech.tablesaw.api.LongColumn;
import tech.tablesaw.api.Row;
import tech.tablesaw.api.ShortColumn;
import tech.tablesaw.api.StringColumn;
import tech.tablesaw.api.Table;
import tech.tablesaw.api.TimeColumn;
import tech.tablesaw.columns.Column;
import tech.tablesaw.columns.booleans.BooleanColumnType;
import tech.tablesaw.columns.dates.DateColumnType;
import tech.tablesaw.columns.datetimes.DateTimeColumnType;
import tech.tablesaw.columns.instant.InstantColumnType;
import tech.tablesaw.columns.numbers.DoubleColumnType;
import tech.tablesaw.columns.numbers.FloatColumnType;
import tech.tablesaw.columns.numbers.IntColumnType;
import tech.tablesaw.columns.numbers.LongColumnType;
import tech.tablesaw.columns.numbers.ShortColumnType;
import tech.tablesaw.columns.strings.StringColumnType;
import tech.tablesaw.columns.strings.TextColumnType;
import tech.tablesaw.columns.times.TimeColumnType;
import tech.tablesaw.index.ByteIndex;
import tech.tablesaw.index.DoubleIndex;
import tech.tablesaw.index.FloatIndex;
import tech.tablesaw.index.Index;
import tech.tablesaw.index.IntIndex;
import tech.tablesaw.index.LongIndex;
import tech.tablesaw.index.ShortIndex;
import tech.tablesaw.index.StringIndex;
import tech.tablesaw.selection.Selection;

/* loaded from: input_file:tech/tablesaw/joining/DataFrameJoiner.class */
public class DataFrameJoiner {
    /** Prefix of the alias ("T2", "T3", ...) used to rename clashing right-table columns. */
    private static final String TABLE_ALIAS = "T";
    /** The left-hand table of every join started from this joiner. */
    private final Table table;
    /** Names of the key columns; looked up in the left table and, by default, in each right table. */
    private final String[] joinColumnNames;
    /** Positions of {@link #joinColumnNames} in {@link #table}. */
    private final List<Integer> joinColumnIndexes;
    /** Source of alias suffixes; the first alias handed out is "T2". */
    private final AtomicInteger joinTableId = new AtomicInteger(2);

    /** The kinds of join this joiner can perform. */
    public enum JoinType {
        INNER,
        LEFT_OUTER,
        RIGHT_OUTER,
        FULL_OUTER
    }

    /**
     * Creates a joiner whose left-hand side is {@code table}, keyed on the named columns.
     *
     * @param table the left-hand table
     * @param joinColumnNames names of the key columns in {@code table}
     */
    public DataFrameJoiner(Table table, String... joinColumnNames) {
        this.table = table;
        this.joinColumnNames = joinColumnNames;
        this.joinColumnIndexes = getJoinIndexes(table, joinColumnNames);
    }

    /** Resolves each column name to its position in {@code table} via {@code Table.columnIndex}. */
    private List<Integer> getJoinIndexes(Table table, String[] columnNames) {
        return Arrays.stream(columnNames).map(table::columnIndex).collect(Collectors.toList());
    }

    /**
     * Inner-joins this joiner's table with each of {@code tables} in turn, using the joiner's key
     * column names on every table. Duplicate non-key column names are not renamed.
     */
    public Table inner(Table... tables) {
        return inner(false, tables);
    }

    /**
     * Inner-joins this joiner's table with each of {@code tables} in turn; each join's result is
     * the left side of the next.
     *
     * @param allowDuplicateColumnNames when true, right-table columns whose names match a left-table
     *     column (case-insensitively) are renamed with an alias prefix such as {@code "T2."}
     */
    public Table inner(boolean allowDuplicateColumnNames, Table... tables) {
        return joinAll(JoinType.INNER, allowDuplicateColumnNames, tables);
    }

    /** Inner-joins with {@code table2}, whose key column is named {@code col2Name}. */
    public Table inner(Table table2, String col2Name) {
        return inner(table2, false, col2Name);
    }

    /** Inner-joins with {@code table2}, whose key columns are named {@code col2Names}. */
    public Table inner(Table table2, String[] col2Names) {
        return inner(table2, false, col2Names);
    }

    /** Inner-joins with {@code table2} on {@code col2Name}, optionally renaming clashing columns. */
    public Table inner(Table table2, String col2Name, boolean allowDuplicateColumnNames) {
        return inner(table2, allowDuplicateColumnNames, col2Name);
    }

    /**
     * Inner-joins this joiner's table with {@code table2}.
     *
     * @param table2 the right-hand table
     * @param allowDuplicateColumnNames when true, clashing right-table column names are aliased
     * @param col2Names key column names in {@code table2}, matched positionally with this joiner's
     *     key columns
     */
    public Table inner(Table table2, boolean allowDuplicateColumnNames, String... col2Names) {
        return joinInternal(this.table, table2, JoinType.INNER, allowDuplicateColumnNames, col2Names);
    }

    /**
     * Folds {@code tables} into this joiner's table one join at a time. Each join uses
     * {@link #joinColumnNames} for both sides and feeds its result in as the next left side.
     */
    private Table joinAll(JoinType joinType, boolean allowDuplicateColumnNames, Table... tables) {
        Table joined = this.table;
        for (Table right : tables) {
            joined = joinInternal(joined, right, joinType, allowDuplicateColumnNames, this.joinColumnNames);
        }
        return joined;
    }

    /**
     * Performs a single join of {@code left} with {@code right}.
     *
     * <p>The result starts with all of {@code left}'s columns followed by all of {@code right}'s.
     * The redundant copy of each key column is then removed: the right copy, or the left copy for
     * a right outer join.
     *
     * @param left the left-hand table; its key columns are looked up by {@link #joinColumnNames}
     * @param right the right-hand table
     * @param joinType which unmatched rows to keep
     * @param allowDuplicateColumnNames when true, clashing right-table column names are aliased
     * @param rightJoinColumnNames key column names in {@code right}
     * @return a new table named after {@code left}
     * @throws IllegalArgumentException if the key columns are of unsupported or mismatched types
     */
    private Table joinInternal(
            Table left, Table right, JoinType joinType, boolean allowDuplicateColumnNames, String... rightJoinColumnNames) {
        // In a chained join, `left` is an earlier join result, so the key positions are resolved
        // again by name. After a right outer join the key columns have moved to the right-hand
        // block, so the original positions would be wrong.
        List<Integer> leftJoinIndexes = left == this.table
                ? this.joinColumnIndexes
                : getJoinIndexes(left, this.joinColumnNames);
        List<Integer> rightJoinIndexes = getJoinIndexes(right, rightJoinColumnNames);
        Table result = Table.create(left.name());
        Set<Integer> placeholders = emptyTableFromColumns(
                result, left, right, joinType, allowDuplicateColumnNames, leftJoinIndexes, rightJoinIndexes);
        List<Index> leftIndexes = buildIndexesForJoinColumns(leftJoinIndexes, left);
        List<Index> rightIndexes = buildIndexesForJoinColumns(rightJoinIndexes, right);
        validateIndexes(leftIndexes, rightIndexes);

        boolean keepUnmatchedLeft = joinType == JoinType.LEFT_OUTER || joinType == JoinType.FULL_OUTER;
        boolean keepUnmatchedRight = joinType == JoinType.RIGHT_OUTER || joinType == JoinType.FULL_OUTER;
        if (left.rowCount() == 0 && !keepUnmatchedRight) {
            return dropPlaceholders(result, placeholders);
        }

        Selection leftDone = Selection.with(new int[0]);
        Selection rightMatched = Selection.with(new int[0]);
        Iterator<Row> rows = left.iterator();
        while (rows.hasNext()) {
            int rowNumber = rows.next().getRowNumber();
            if (leftDone.contains(rowNumber)) {
                // Already emitted together with an earlier row that has the same key.
                continue;
            }
            // All left rows sharing this row's key, and all right rows matching it.
            Selection leftRows =
                    createMultiColSelection(left, rowNumber, leftJoinIndexes, leftIndexes, left.rowCount());
            Selection rightRows =
                    createMultiColSelection(left, rowNumber, leftJoinIndexes, rightIndexes, right.rowCount());
            if (keepUnmatchedLeft && rightRows.isEmpty()) {
                withMissingLeftJoin(result, left, leftRows, placeholders);
            } else {
                crossProduct(result, left, right, leftRows, rightRows, placeholders);
            }
            leftDone = leftDone.or(leftRows);
            if (keepUnmatchedRight) {
                rightMatched = rightMatched.or(rightRows);
            } else if (leftDone.size() == left.rowCount()) {
                return dropPlaceholders(result, placeholders);
            }
        }
        // Only right/full outer joins emit right rows that never matched a left row.
        if (keepUnmatchedRight) {
            withMissingRight(result, left.columnCount(), right, rightMatched.flip(0, right.rowCount()),
                    joinType, leftJoinIndexes, rightJoinIndexes, placeholders);
        }
        return dropPlaceholders(result, placeholders);
    }

    /** Removes the placeholder (redundant key) columns from {@code result} and returns it. */
    private static Table dropPlaceholders(Table result, Set<Integer> placeholders) {
        result.removeColumns(Ints.toArray(placeholders));
        return result;
    }

    /**
     * Checks that both sides have the same number of key indexes and that the index classes
     * match pairwise.
     *
     * @throws IllegalArgumentException if the counts or index classes differ
     */
    private void validateIndexes(List<Index> leftIndexes, List<Index> rightIndexes) {
        if (leftIndexes.size() != rightIndexes.size()) {
            throw new IllegalArgumentException("Cannot join using a different number of indices on each table: " + leftIndexes + " and " + rightIndexes);
        }
        for (int k = 0; k < leftIndexes.size(); k++) {
            if (!leftIndexes.get(k).getClass().equals(rightIndexes.get(k).getClass())) {
                throw new IllegalArgumentException("Cannot join using different index types: " + leftIndexes + " and " + rightIndexes);
            }
        }
    }

    /** Builds one {@link Index} per key column position of {@code table}. */
    private List<Index> buildIndexesForJoinColumns(List<Integer> columnIndexes, Table table) {
        return columnIndexes.stream()
                .map(col -> indexFor(table, col))
                .collect(Collectors.toList());
    }

    /**
     * Creates an index over column {@code i} of {@code table}. Dates and times use an
     * {@link IntIndex}, date-times and instants a {@link LongIndex}, and booleans a
     * {@link ByteIndex}.
     *
     * @throws IllegalArgumentException for unsupported column types
     */
    private Index indexFor(Table table, int i) {
        ColumnType type = table.column(i).type();
        if (type instanceof DateColumnType) {
            return new IntIndex(table.dateColumn(i));
        }
        if (type instanceof DateTimeColumnType) {
            return new LongIndex(table.dateTimeColumn(i));
        }
        if (type instanceof InstantColumnType) {
            return new LongIndex(table.instantColumn(i));
        }
        if (type instanceof TimeColumnType) {
            return new IntIndex(table.timeColumn(i));
        }
        if ((type instanceof StringColumnType) || (type instanceof TextColumnType)) {
            return new StringIndex(table.stringColumn(i));
        }
        if (type instanceof IntColumnType) {
            return new IntIndex(table.intColumn(i));
        }
        if (type instanceof LongColumnType) {
            return new LongIndex(table.longColumn(i));
        }
        if (type instanceof ShortColumnType) {
            return new ShortIndex(table.shortColumn(i));
        }
        if (type instanceof BooleanColumnType) {
            return new ByteIndex(table.booleanColumn(i));
        }
        if (type instanceof DoubleColumnType) {
            return new DoubleIndex(table.doubleColumn(i));
        }
        if (type instanceof FloatColumnType) {
            return new FloatIndex(table.floatColumn(i));
        }
        throw new IllegalArgumentException("Joining attempted on unsupported column type " + type);
    }

    /**
     * Looks up the value at row {@code i} of {@code column} in {@code index}. The value is read in
     * the same internal representation {@link #indexFor} used to build the index.
     *
     * @return the rows of the indexed table holding that value
     * @throws IllegalArgumentException for unsupported column types
     */
    private Selection selectionForColumn(Column<?> column, int i, Index index) {
        ColumnType type = column.type();
        if (type instanceof DateColumnType) {
            return ((IntIndex) index).get(((DateColumn) column).getIntInternal(i));
        }
        if (type instanceof TimeColumnType) {
            return ((IntIndex) index).get(((TimeColumn) column).getIntInternal(i));
        }
        if (type instanceof DateTimeColumnType) {
            return ((LongIndex) index).get(((DateTimeColumn) column).getLongInternal(i));
        }
        if (type instanceof InstantColumnType) {
            return ((LongIndex) index).get(((InstantColumn) column).getLongInternal(i));
        }
        if ((type instanceof StringColumnType) || (type instanceof TextColumnType)) {
            return ((StringIndex) index).get(((StringColumn) column).get(i));
        }
        if (type instanceof IntColumnType) {
            return ((IntIndex) index).get(((IntColumn) column).getInt(i));
        }
        if (type instanceof LongColumnType) {
            return ((LongIndex) index).get(((LongColumn) column).getLong(i));
        }
        if (type instanceof ShortColumnType) {
            return ((ShortIndex) index).get(((ShortColumn) column).getShort(i));
        }
        if (type instanceof BooleanColumnType) {
            return ((ByteIndex) index).get(((BooleanColumn) column).getByte(i));
        }
        if (type instanceof DoubleColumnType) {
            return ((DoubleIndex) index).get(((DoubleColumn) column).getDouble(i));
        }
        if (type instanceof FloatColumnType) {
            return ((FloatIndex) index).get(((FloatColumn) column).getFloat(i));
        }
        throw new IllegalArgumentException("Joining is supported on numeric, string, and date-like columns. Column " + column.name() + " is of type " + column.type());
    }

    /**
     * Returns the rows (within {@code [0, rowCount)}) of the table behind {@code indexes} whose
     * keys equal the key of {@code row} in {@code table}. The per-column matches are intersected.
     *
     * @param table the table that supplies the key values (always the left table)
     * @param row the row of {@code table} whose key is looked up
     * @param columnIndexes key column positions in {@code table}
     * @param indexes one index per key column, in the same order as {@code columnIndexes}
     * @param rowCount row count of the indexed table
     */
    private Selection createMultiColSelection(
            Table table, int row, List<Integer> columnIndexes, List<Index> indexes, int rowCount) {
        Selection matches = Selection.withRange(0, rowCount);
        for (int k = 0; k < columnIndexes.size(); k++) {
            int col = columnIndexes.get(k);
            matches = matches.and(selectionForColumn(table.column(col), row, indexes.get(k)));
        }
        return matches;
    }

    /** Returns {@code alias + "." + name}, e.g. {@code "T2.price"}. */
    private String newName(String alias, String name) {
        return alias + "." + name;
    }

    /** Full-outer-joins this joiner's table with each of {@code tables} in turn. */
    public Table fullOuter(Table... tables) {
        return fullOuter(false, tables);
    }

    /**
     * Full-outer-joins this joiner's table with each of {@code tables} in turn; each join's result
     * is the left side of the next.
     *
     * @param allowDuplicateColumnNames when true, clashing right-table column names are aliased
     */
    public Table fullOuter(boolean allowDuplicateColumnNames, Table... tables) {
        return joinAll(JoinType.FULL_OUTER, allowDuplicateColumnNames, tables);
    }

    /** Full-outer-joins with {@code table2}, whose key column is named {@code col2Name}. */
    public Table fullOuter(Table table2, String col2Name) {
        return joinInternal(this.table, table2, JoinType.FULL_OUTER, false, col2Name);
    }

    /** Left-outer-joins this joiner's table with each of {@code tables} in turn. */
    public Table leftOuter(Table... tables) {
        return leftOuter(false, tables);
    }

    /**
     * Left-outer-joins this joiner's table with each of {@code tables} in turn; each join's result
     * is the left side of the next.
     *
     * @param allowDuplicateColumnNames when true, clashing right-table column names are aliased
     */
    public Table leftOuter(boolean allowDuplicateColumnNames, Table... tables) {
        // Previously each iteration re-joined this.table, silently discarding earlier joins.
        return joinAll(JoinType.LEFT_OUTER, allowDuplicateColumnNames, tables);
    }

    /** Left-outer-joins with {@code table2}, whose key columns are named {@code col2Names}. */
    public Table leftOuter(Table table2, String[] col2Names) {
        return leftOuter(table2, false, col2Names);
    }

    /** Left-outer-joins with {@code table2}, whose key column is named {@code col2Name}. */
    public Table leftOuter(Table table2, String col2Name) {
        return leftOuter(table2, false, col2Name);
    }

    /**
     * Left-outer-joins this joiner's table with {@code table2}. Left rows with no match get
     * missing values in the right-table columns.
     *
     * @param allowDuplicateColumnNames when true, clashing right-table column names are aliased
     * @param col2Names key column names in {@code table2}
     */
    public Table leftOuter(Table table2, boolean allowDuplicateColumnNames, String... col2Names) {
        return joinInternal(this.table, table2, JoinType.LEFT_OUTER, allowDuplicateColumnNames, col2Names);
    }

    /** Right-outer-joins this joiner's table with each of {@code tables} in turn. */
    public Table rightOuter(Table... tables) {
        return rightOuter(false, tables);
    }

    /**
     * Right-outer-joins this joiner's table with each of {@code tables} in turn; each join's result
     * is the left side of the next.
     *
     * @param allowDuplicateColumnNames when true, clashing right-table column names are aliased
     */
    public Table rightOuter(boolean allowDuplicateColumnNames, Table... tables) {
        // Previously each iteration re-joined this.table, silently discarding earlier joins.
        return joinAll(JoinType.RIGHT_OUTER, allowDuplicateColumnNames, tables);
    }

    /** Right-outer-joins with {@code table2}, whose key column is named {@code col2Name}. */
    public Table rightOuter(Table table2, String col2Name) {
        return rightOuter(table2, false, col2Name);
    }

    /** Right-outer-joins with {@code table2}, whose key columns are named {@code col2Names}. */
    public Table rightOuter(Table table2, String[] col2Names) {
        return rightOuter(table2, false, col2Names);
    }

    /**
     * Right-outer-joins this joiner's table with {@code table2}. The right table's key columns are
     * kept, and the left table's key columns are dropped.
     *
     * @param allowDuplicateColumnNames when true, clashing right-table column names are aliased
     * @param col2Names key column names in {@code table2}
     */
    public Table rightOuter(Table table2, boolean allowDuplicateColumnNames, String... col2Names) {
        return joinInternal(this.table, table2, JoinType.RIGHT_OUTER, allowDuplicateColumnNames, col2Names);
    }

    /**
     * Adds empty copies of every column of {@code left} followed by every column of {@code right}
     * to {@code destination}.
     *
     * <p>The redundant copy of each key column is renamed {@code "Placeholder_<n>"} and its
     * position recorded so the caller can remove it later. That is the right copy, or the left
     * copy for {@link JoinType#RIGHT_OUTER}. When {@code allowDuplicateColumnNames} is true, any
     * right column whose name equals a left column name (ignoring case) is prefixed with a fresh
     * alias such as {@code "T2."}.
     *
     * @return positions in {@code destination} of the placeholder columns
     */
    private Set<Integer> emptyTableFromColumns(Table destination, Table left, Table right, JoinType joinType,
            boolean allowDuplicateColumnNames, List<Integer> leftJoinIndexes, List<Integer> rightJoinIndexes) {
        Column<?>[] columns = Streams.concat(left.columns().stream(), right.columns().stream())
                .map(column -> column.emptyCopy2())
                .toArray(Column[]::new);
        int leftColumnCount = left.columnCount();
        Set<Integer> placeholders = new HashSet<>();
        for (int col = 0; col < columns.length; col++) {
            boolean redundantKey = joinType == JoinType.RIGHT_OUTER
                    ? col < leftColumnCount && leftJoinIndexes.contains(col)
                    : col >= leftColumnCount && rightJoinIndexes.contains(col - leftColumnCount);
            if (redundantKey) {
                columns[col].setName("Placeholder_" + placeholders.size());
                placeholders.add(col);
            }
        }
        if (allowDuplicateColumnNames) {
            Set<String> leftNames = Arrays.stream(columns)
                    .limit(leftColumnCount)
                    .map(column -> column.name().toLowerCase())
                    .collect(Collectors.toSet());
            String alias = TABLE_ALIAS + this.joinTableId.getAndIncrement();
            for (int col = leftColumnCount; col < columns.length; col++) {
                String name = columns[col].name();
                if (leftNames.contains(name.toLowerCase())) {
                    columns[col].setName(newName(alias, name));
                }
            }
        }
        destination.addColumns(columns);
        return placeholders;
    }

    /**
     * Appends the cross product of {@code leftRows} and {@code rightRows} to {@code destination}.
     * Left columns repeat each left value once per right row; right columns cycle through the
     * right rows for each left row. Placeholder columns are skipped.
     */
    private void crossProduct(Table destination, Table left, Table right, Selection leftRows,
            Selection rightRows, Set<Integer> placeholders) {
        int leftColumnCount = left.columnCount();
        int totalColumns = leftColumnCount + right.columnCount();
        for (int col = 0; col < totalColumns; col++) {
            if (placeholders.contains(col)) {
                continue;
            }
            boolean fromLeft = col < leftColumnCount;
            Column<?> source = fromLeft ? left.column(col) : right.column(col - leftColumnCount);
            Column<?> target = destination.column(col);
            IntIterator leftIt = leftRows.iterator();
            while (leftIt.hasNext()) {
                int leftRow = leftIt.nextInt();
                IntIterator rightIt = rightRows.iterator();
                while (rightIt.hasNext()) {
                    int rightRow = rightIt.nextInt();
                    target.append2(source, fromLeft ? leftRow : rightRow);
                }
            }
        }
    }

    /**
     * Appends {@code leftRows} of {@code left} to {@code destination}, filling every right-table
     * column with missing values. Placeholder columns are skipped.
     */
    private void withMissingLeftJoin(Table destination, Table left, Selection leftRows, Set<Integer> placeholders) {
        for (int col = 0; col < destination.columnCount(); col++) {
            if (placeholders.contains(col)) {
                continue;
            }
            Column<?> target = destination.column(col);
            if (col < left.columnCount()) {
                Column<?> source = left.column(col);
                IntIterator it = leftRows.iterator();
                while (it.hasNext()) {
                    target.append2(source, it.nextInt());
                }
            } else {
                for (int k = 0; k < leftRows.size(); k++) {
                    target.appendMissing2();
                }
            }
        }
    }

    /**
     * Appends the unmatched right rows to {@code destination}. Non-key left columns get missing
     * values, and right columns get the right table's values. For a full outer join the left key
     * columns are filled from the right key columns, since the right copies are placeholders.
     *
     * @param leftColumnCount number of leading columns in {@code destination} that came from the
     *     left table
     * @param unmatchedRightRows rows of {@code right} that matched no left row
     */
    private void withMissingRight(Table destination, int leftColumnCount, Table right, Selection unmatchedRightRows,
            JoinType joinType, List<Integer> leftJoinIndexes, List<Integer> rightJoinIndexes, Set<Integer> placeholders) {
        if (joinType == JoinType.FULL_OUTER) {
            for (int k = 0; k < rightJoinIndexes.size(); k++) {
                Column<?> source = right.column(rightJoinIndexes.get(k).intValue());
                Column<?> target = destination.column(leftJoinIndexes.get(k).intValue());
                IntIterator it = unmatchedRightRows.iterator();
                while (it.hasNext()) {
                    target.append2(source, it.nextInt());
                }
            }
        }
        for (int col = 0; col < destination.columnCount(); col++) {
            if (placeholders.contains(col) || leftJoinIndexes.contains(col)) {
                continue;
            }
            Column<?> target = destination.column(col);
            if (col < leftColumnCount) {
                for (int k = 0; k < unmatchedRightRows.size(); k++) {
                    target.appendMissing2();
                }
            } else {
                Column<?> source = right.column(col - leftColumnCount);
                IntIterator it = unmatchedRightRows.iterator();
                while (it.hasNext()) {
                    target.append2(source, it.nextInt());
                }
            }
        }
    }
}
