package org.apache.kylin.query.mask;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.kylin.common.KylinConfig;
import org.apache.kylin.common.QueryContext;
import org.apache.kylin.common.exception.ErrorCodeSupplier;
import org.apache.kylin.common.exception.KylinException;
import org.apache.kylin.common.exception.ServerErrorCode;
import org.apache.kylin.guava30.shaded.common.collect.Sets;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptUtil;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.RelNode;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.Aggregate;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.AggregateCall;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.Project;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.SetOp;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.TableScan;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.Values;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.Window;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexInputRef;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexNode;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.SqlIdentifier;
import org.apache.kylin.metadata.acl.AclTCRManager;
import org.apache.kylin.metadata.acl.DependentColumn;
import org.apache.kylin.metadata.acl.DependentColumnInfo;
import org.apache.kylin.metadata.model.ColumnDesc;
import org.apache.kylin.metadata.project.NProjectManager;
import org.apache.kylin.query.relnode.KapTableScan;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.expressions.Literal;
import org.apache.spark.sql.catalyst.parser.ParseException;

/**
 * Masks query result columns according to dependent-column ACL rules.
 *
 * <p>The optimized relational plan is walked to find, for every result column, the set of source
 * columns ({@code DATABASE.TABLE.COLUMN}) it is derived from. For each source column that has
 * dependent-column rules, the result column is either
 * <ul>
 *   <li>replaced by a typed {@code NULL} when a rule's dependent column is not present in the
 *       result as a plain (non-calculated, non-aggregated) single-column projection, or</li>
 *   <li>wrapped in {@code CASE WHEN (<dependent col> IN (<allowed values>) AND ...) THEN col
 *       ELSE NULL END}.</li>
 * </ul>
 */
public class QueryDependentColumnMask implements QueryResultMask {
    private RelNode rootRelNode;
    private final String defaultDatabase;
    private DependentColumnInfo dependentInfo;
    private List<ResultColumnMaskInfo> resultColumnMaskInfos;
    private boolean needMask = false;

    /**
     * Lineage of one output column of a plan node: the source columns it reads, and whether any
     * calculation or aggregation was applied along the way.
     */
    public static class ColumnReferences {
        boolean hasCalculation = false;
        boolean hasAggregation = false;
        private final Set<String> references = new HashSet<>();

        public ColumnReferences() {
        }

        /** Creates a lineage consisting of exactly one source column. */
        public ColumnReferences(String reference) {
            this.references.add(reference);
        }

        void addReference(String reference) {
            this.references.add(reference);
        }

        void addReferences(Collection<String> references) {
            this.references.addAll(references);
        }

        /**
         * Returns a new lineage that is the union of this one and {@code other} (references and
         * flags OR-ed together); returns {@code this} when {@code other} is null.
         */
        ColumnReferences merge(ColumnReferences other) {
            if (other == null) {
                return this;
            }
            ColumnReferences merged = new ColumnReferences();
            merged.addReferences(this.references);
            merged.addReferences(other.references);
            merged.hasCalculation = this.hasCalculation || other.hasCalculation;
            merged.hasAggregation = this.hasAggregation || other.hasAggregation;
            return merged;
        }

        /** True when the column is exactly one source column, neither calculated nor aggregated. */
        boolean isSimpleSingleColumnProject() {
            return references.size() == 1 && !hasAggregation && !hasCalculation;
        }
    }

    /** Masking decision for one result column. */
    public static class ResultColumnMaskInfo {
        /** When true the whole column is replaced by NULL. */
        public boolean maskAsNull = false;
        /** Row-level conditions; all of them must hold for the value to be shown. */
        public List<ResultDependentValues> dependentValues = new ArrayList<>();

        public boolean needMask() {
            return !dependentValues.isEmpty() || maskAsNull;
        }

        void addDependentValues(ResultDependentValues resultDependentValues) {
            dependentValues.add(resultDependentValues);
        }
    }

    /** A row-level condition: result column {@code colIdx} must hold one of {@code values}. */
    public static class ResultDependentValues {
        public int colIdx;
        public Set<String> values;

        public ResultDependentValues(int colIdx, String[] values) {
            this.colIdx = colIdx;
            this.values = Sets.newHashSet(values);
        }
    }

    /**
     * Builds a mask for {@code project}. It uses the project's default database, and the dependent
     * columns of the current query user and groups when ACL info is available on the current
     * {@link QueryContext}.
     */
    public QueryDependentColumnMask(String project, KylinConfig kylinConfig) {
        this.defaultDatabase = NProjectManager.getInstance(kylinConfig).getProject(project).getDefaultDatabase();
        QueryContext.AclInfo aclInfo = QueryContext.current().getAclInfo();
        if (aclInfo != null) {
            this.dependentInfo = AclTCRManager.getInstance(kylinConfig, project)
                    .getDependentColumns(aclInfo.getUsername(), aclInfo.getGroups());
        }
    }

    /** Builds a mask with an explicit default database and dependent-column info. */
    public QueryDependentColumnMask(String defaultDatabase, DependentColumnInfo dependentColumnInfo) {
        this.defaultDatabase = defaultDatabase;
        this.dependentInfo = dependentColumnInfo;
    }

    @Override
    public void doSetRootRelNode(RelNode relNode) {
        this.rootRelNode = relNode;
    }

    /** Computes the per-column mask decisions from the root plan node (which must be set). */
    @Override
    public void init() {
        assert rootRelNode != null : "root RelNode must be set before init()";
        this.resultColumnMaskInfos = buildResultColumnMaskInfo(getRefCols(rootRelNode));
        this.needMask = resultColumnMaskInfos.stream().anyMatch(ResultColumnMaskInfo::needMask);
    }

    /**
     * Applies the dependent-column mask to {@code dataset}.
     *
     * <p>The dataset is returned unchanged when there is no dependent-column info, no root plan
     * node, or the info reports nothing to mask. {@link #init()} runs lazily on first use.
     */
    @Override
    public Dataset<Row> doMaskResult(Dataset<Row> dataset) {
        if (dependentInfo == null || rootRelNode == null || !dependentInfo.needMask()) {
            return dataset;
        }
        if (resultColumnMaskInfos == null) {
            init();
        }
        return needMask ? doResultMaskInternal(dataset) : dataset;
    }

    private Dataset<Row> doResultMaskInternal(Dataset<Row> dataset) {
        Dataset<Row> indexed = MaskUtil.dFToDFWithIndexedColumns(dataset);
        String[] columnNames = indexed.columns();
        Column[] maskedColumns = new Column[columnNames.length];
        for (int i = 0; i < columnNames.length; i++) {
            ResultColumnMaskInfo maskInfo = resultColumnMaskInfos.get(i);
            String columnName = columnNames[i];
            if (!maskInfo.needMask()) {
                maskedColumns[i] = indexed.col(columnName);
            } else if (maskInfo.maskAsNull) {
                // Keep the original data type so the result schema is unchanged.
                maskedColumns[i] = new Column(new Literal((Object) null, indexed.schema().fields()[i].dataType()))
                        .as(columnName);
            } else {
                String expression = String.format(Locale.ROOT, "CASE WHEN (%s) THEN %s ELSE NULL END",
                        maskDependentCondition(indexed, maskInfo), quoteIdentifier(columnName));
                try {
                    maskedColumns[i] = new Column(
                            indexed.sparkSession().sessionState().sqlParser().parseExpression(expression))
                            .as(columnName);
                } catch (ParseException e) {
                    throw new KylinException((ErrorCodeSupplier) ServerErrorCode.ACL_DEPENDENT_COLUMN_PARSE_ERROR,
                            (Throwable) e);
                }
            }
        }
        return indexed.select(maskedColumns).toDF(dataset.columns());
    }

    /**
     * Builds the Spark SQL predicate {@code (`c1` IN ('v', ...)) AND (`c2` IN (...))} for the
     * dependent conditions of one column. An empty value set yields {@code (FALSE)}, meaning
     * nothing is allowed, instead of the unparsable {@code IN ()}.
     */
    private String maskDependentCondition(Dataset<Row> dataset, ResultColumnMaskInfo maskInfo) {
        String[] columnNames = dataset.columns();
        List<String> predicates = new ArrayList<>();
        for (ResultDependentValues dependent : maskInfo.dependentValues) {
            if (dependent.values.isEmpty()) {
                predicates.add("(FALSE)");
                continue;
            }
            String valueList = dependent.values.stream()
                    .map(value -> quoteStringLiteral(String.valueOf(value)))
                    .collect(Collectors.joining(","));
            predicates.add("(" + quoteIdentifier(columnNames[dependent.colIdx]) + " IN (" + valueList + "))");
        }
        return String.join(" AND ", predicates);
    }

    /** Wraps a column name in back-ticks for Spark SQL, doubling any embedded back-tick. */
    private static String quoteIdentifier(String name) {
        return "`" + name.replace("`", "``") + "`";
    }

    /**
     * Renders {@code value} as a single-quoted Spark SQL string literal. Backslashes and quotes are
     * escaped so the value cannot terminate the literal early.
     */
    private static String quoteStringLiteral(String value) {
        return "'" + value.replace("\\", "\\\\").replace("'", "\\'") + "'";
    }

    /** Decides, for every result column, whether and how it must be masked. */
    private List<ResultColumnMaskInfo> buildResultColumnMaskInfo(List<ColumnReferences> columnRefs) {
        // Only source columns that appear verbatim in the result can serve as the dependent column
        // of a row-level check; remember their result index (a later occurrence wins).
        Map<String, Integer> plainColumnIndex = new HashMap<>();
        int index = 0;
        for (ColumnReferences refs : columnRefs) {
            if (refs.isSimpleSingleColumnProject()) {
                plainColumnIndex.put(refs.references.iterator().next(), index);
            }
            index++;
        }

        List<ResultColumnMaskInfo> maskInfos = new ArrayList<>(columnRefs.size());
        for (ColumnReferences refs : columnRefs) {
            ResultColumnMaskInfo maskInfo = new ResultColumnMaskInfo();
            for (String reference : refs.references) {
                Collection<DependentColumn> dependents = dependentInfo.get(reference);
                if (dependents == null) {
                    continue;
                }
                for (DependentColumn dependent : dependents) {
                    Integer dependentIdx = plainColumnIndex.get(dependent.getDependentColumnIdentity());
                    if (dependentIdx == null) {
                        // The rule cannot be checked row by row, so hide the column entirely.
                        maskInfo.maskAsNull = true;
                        break;
                    }
                    maskInfo.addDependentValues(new ResultDependentValues(dependentIdx, dependent.getDependentValues()));
                }
            }
            maskInfos.add(maskInfo);
        }
        return maskInfos;
    }

    /** Returns the lineage of every output column of {@code relNode}, in output order. */
    private List<ColumnReferences> getRefCols(RelNode relNode) {
        if (relNode instanceof TableScan) {
            return getTableColRefs((TableScan) relNode);
        }
        if (relNode instanceof Values) {
            // Literal rows reference no source column.
            return relNode.getRowType().getFieldList().stream()
                    .map(field -> new ColumnReferences())
                    .collect(Collectors.toCollection(ArrayList::new));
        }
        if (relNode instanceof Aggregate) {
            return getAggregateColRefs((Aggregate) relNode);
        }
        if (relNode instanceof Project) {
            return getProjectColRefs((Project) relNode);
        }
        if (relNode instanceof SetOp) {
            return getUnionColRefs((SetOp) relNode);
        }
        if (relNode instanceof Window) {
            return getWindowColRefs((Window) relNode);
        }
        // Any other node: output columns are its inputs' columns, concatenated in input order.
        List<ColumnReferences> concatenated = new ArrayList<>();
        for (RelNode input : relNode.getInputs()) {
            concatenated.addAll(getRefCols(input));
        }
        return concatenated;
    }

    private List<ColumnReferences> getWindowColRefs(Window window) {
        List<ColumnReferences> inputRefs = getRefCols(window.getInput(0));
        List<ColumnReferences> windowRefs = new ArrayList<>(inputRefs);
        for (Window.Group group : window.groups) {
            for (RexNode aggCall : group.aggCalls) {
                ColumnReferences refs = new ColumnReferences();
                for (int idx : RelOptUtil.InputFinder.bits(aggCall)) {
                    // Indexes past the input row have no source column (presumably window constants).
                    if (idx < inputRefs.size() && inputRefs.get(idx) != null) {
                        refs = refs.merge(inputRefs.get(idx));
                    }
                }
                // A windowed aggregate is not the raw column value; without this flag it could be
                // used as a dependent column and leak rows that should be masked.
                refs.hasAggregation = true;
                windowRefs.add(refs);
            }
        }
        return windowRefs;
    }

    private List<ColumnReferences> getUnionColRefs(SetOp setOp) {
        List<ColumnReferences> merged = new ArrayList<>();
        for (RelNode input : setOp.getInputs()) {
            List<ColumnReferences> inputRefs = getRefCols(input);
            if (merged.isEmpty()) {
                // Copy so a child's (possibly unmodifiable) list is never mutated.
                merged = new ArrayList<>(inputRefs);
            } else {
                for (int i = 0; i < inputRefs.size(); i++) {
                    merged.set(i, merged.get(i).merge(inputRefs.get(i)));
                }
            }
        }
        return merged;
    }

    private List<ColumnReferences> getProjectColRefs(Project project) {
        List<ColumnReferences> inputRefs = getRefCols(project.getInput(0));
        List<ColumnReferences> projectRefs = new ArrayList<>();
        for (RexNode expr : project.getChildExps()) {
            ColumnReferences refs = new ColumnReferences();
            for (int idx : RelOptUtil.InputFinder.bits(expr)) {
                refs = refs.merge(inputRefs.get(idx));
            }
            if (!(expr instanceof RexInputRef)) {
                refs.hasCalculation = true;
            }
            projectRefs.add(refs);
        }
        return projectRefs;
    }

    private List<ColumnReferences> getAggregateColRefs(Aggregate aggregate) {
        List<ColumnReferences> inputRefs = getRefCols(aggregate.getInput(0));
        List<ColumnReferences> aggRefs = new ArrayList<>();
        // Group keys pass their input lineage through unchanged.
        for (int groupIdx : aggregate.getGroupSet()) {
            aggRefs.add(inputRefs.get(groupIdx));
        }
        for (AggregateCall call : aggregate.getAggCallList()) {
            ColumnReferences refs = new ColumnReferences();
            for (int argIdx : call.getArgList()) {
                refs = refs.merge(inputRefs.get(argIdx));
            }
            refs.hasAggregation = true;
            aggRefs.add(refs);
        }
        return aggRefs;
    }

    private List<ColumnReferences> getTableColRefs(TableScan tableScan) {
        List<String> qualifiedName = tableScan.getTable().getQualifiedName();
        assert qualifiedName.size() == 2 : "expected DATABASE.TABLE but got " + qualifiedName;
        String database = qualifiedName.get(0);
        String table = qualifiedName.get(1);
        KapTableScan olapScan = (KapTableScan) tableScan;
        List<ColumnReferences> tableRefs = new ArrayList<>();
        for (RelDataTypeField field : tableScan.getRowType().getFieldList()) {
            ColumnDesc columnDesc = olapScan.getOlapTable().getSourceColumns().get(field.getIndex());
            if (columnDesc.isComputedColumn()) {
                tableRefs.add(getCCReferences(columnDesc.getComputedColumnExpr()));
            } else {
                tableRefs.add(new ColumnReferences(database + "." + table + "." + field.getName()));
            }
        }
        return tableRefs;
    }

    /**
     * Returns the source columns referenced by a computed-column expression. {@code TABLE.COL}
     * identifiers are prefixed with the default database, and {@code DB.TABLE.COL} identifiers are
     * taken as-is.
     */
    private ColumnReferences getCCReferences(String ccExpression) {
        ColumnReferences refs = new ColumnReferences();
        for (SqlIdentifier identifier : MaskUtil.getCCCols(ccExpression)) {
            int parts = identifier.names.size();
            if (parts == 2) {
                refs.addReference(defaultDatabase + "." + identifier);
            } else if (parts == 3) {
                refs.addReference(identifier.toString());
            }
        }
        refs.hasCalculation = true;
        return refs;
    }

    public List<ResultColumnMaskInfo> getResultColumnMaskInfos() {
        return resultColumnMaskInfos;
    }
}
