Skip to content

Commit

Permalink
[fix](Nereids) support check authorization for view but skip check in…
Browse files Browse the repository at this point in the history
… the view (#31289)

move UserAuthentication in BindRelation, support check authorization view but skip check in the view

relate pr: #23295
  • Loading branch information
924060929 authored and Doris-Extras committed Feb 23, 2024
1 parent 9a40b6c commit b5ec1e7
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ public class CascadesContext implements ScheduleContext {
private boolean isLeadingJoin = false;

private final Map<String, Hint> hintMap = Maps.newLinkedHashMap();
private final boolean shouldCheckRelationAuthentication;

/**
* Constructor of OptimizerContext.
Expand All @@ -128,7 +129,7 @@ public class CascadesContext implements ScheduleContext {
*/
private CascadesContext(Optional<CascadesContext> parent, Optional<CTEId> currentTree,
StatementContext statementContext, Plan plan, Memo memo,
CTEContext cteContext, PhysicalProperties requireProperties) {
CTEContext cteContext, PhysicalProperties requireProperties, boolean shouldCheckRelationAuthentication) {
this.parent = Objects.requireNonNull(parent, "parent should not null");
this.currentTree = Objects.requireNonNull(currentTree, "currentTree should not null");
this.statementContext = Objects.requireNonNull(statementContext, "statementContext should not null");
Expand All @@ -142,6 +143,7 @@ private CascadesContext(Optional<CascadesContext> parent, Optional<CTEId> curren
this.subqueryExprIsAnalyzed = new HashMap<>();
this.runtimeFilterContext = new RuntimeFilterContext(getConnectContext().getSessionVariable());
this.materializationContexts = new ArrayList<>();
this.shouldCheckRelationAuthentication = shouldCheckRelationAuthentication;
}

/**
Expand All @@ -150,7 +152,13 @@ private CascadesContext(Optional<CascadesContext> parent, Optional<CTEId> curren
public static CascadesContext initContext(StatementContext statementContext,
Plan initPlan, PhysicalProperties requireProperties) {
return newContext(Optional.empty(), Optional.empty(), statementContext,
initPlan, new CTEContext(), requireProperties);
initPlan, new CTEContext(), requireProperties, true);
}

public static CascadesContext initViewContext(StatementContext statementContext,
Plan initPlan, PhysicalProperties requireProperties) {
return newContext(Optional.empty(), Optional.empty(), statementContext,
initPlan, new CTEContext(), requireProperties, false);
}

/**
Expand All @@ -159,13 +167,14 @@ public static CascadesContext initContext(StatementContext statementContext,
public static CascadesContext newContextWithCteContext(CascadesContext cascadesContext,
Plan initPlan, CTEContext cteContext) {
return newContext(Optional.of(cascadesContext), Optional.empty(),
cascadesContext.getStatementContext(), initPlan, cteContext, PhysicalProperties.ANY);
cascadesContext.getStatementContext(), initPlan, cteContext, PhysicalProperties.ANY,
cascadesContext.shouldCheckRelationAuthentication);
}

public static CascadesContext newCurrentTreeContext(CascadesContext context) {
return CascadesContext.newContext(context.getParent(), context.getCurrentTree(), context.getStatementContext(),
context.getRewritePlan(), context.getCteContext(),
context.getCurrentJobContext().getRequiredProperties());
context.getCurrentJobContext().getRequiredProperties(), context.shouldCheckRelationAuthentication);
}

/**
Expand All @@ -174,13 +183,14 @@ public static CascadesContext newCurrentTreeContext(CascadesContext context) {
public static CascadesContext newSubtreeContext(Optional<CTEId> subtree, CascadesContext context,
Plan plan, PhysicalProperties requireProperties) {
return CascadesContext.newContext(Optional.of(context), subtree, context.getStatementContext(),
plan, context.getCteContext(), requireProperties);
plan, context.getCteContext(), requireProperties, context.shouldCheckRelationAuthentication);
}

private static CascadesContext newContext(Optional<CascadesContext> parent, Optional<CTEId> subtree,
StatementContext statementContext, Plan initPlan,
CTEContext cteContext, PhysicalProperties requireProperties) {
return new CascadesContext(parent, subtree, statementContext, initPlan, null, cteContext, requireProperties);
StatementContext statementContext, Plan initPlan, CTEContext cteContext,
PhysicalProperties requireProperties, boolean shouldCheckRelationAuthentication) {
return new CascadesContext(parent, subtree, statementContext, initPlan, null,
cteContext, requireProperties, shouldCheckRelationAuthentication);
}

public CascadesContext getRoot() {
Expand Down Expand Up @@ -636,6 +646,10 @@ public void setLeadingJoin(boolean leadingJoin) {
isLeadingJoin = leadingJoin;
}

public boolean shouldCheckRelationAuthentication() {
return shouldCheckRelationAuthentication;
}

public Map<String, Hint> getHintMap() {
return hintMap;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
import org.apache.doris.nereids.rules.analysis.ReplaceExpressionByChildOutput;
import org.apache.doris.nereids.rules.analysis.ResolveOrdinalInOrderByAndGroupBy;
import org.apache.doris.nereids.rules.analysis.SubqueryToApply;
import org.apache.doris.nereids.rules.analysis.UserAuthentication;
import org.apache.doris.nereids.rules.rewrite.MergeProjects;
import org.apache.doris.nereids.rules.rewrite.SemiJoinCommute;

Expand Down Expand Up @@ -119,8 +118,7 @@ private static List<RewriteJob> buildAnalyzeViewJobs(Optional<CustomTableResolve
topDown(new EliminateLogicalSelectHint()),
bottomUp(
new BindRelation(customTableResolver),
new CheckPolicy(),
new UserAuthentication()
new CheckPolicy()
)
);
}
Expand All @@ -132,8 +130,7 @@ private static List<RewriteJob> buildAnalyzeJobs(Optional<CustomTableResolver> c
topDown(new EliminateLogicalSelectHint()),
bottomUp(
new BindRelation(customTableResolver),
new CheckPolicy(),
new UserAuthentication()
new CheckPolicy()
),
bottomUp(new BindExpression()),
bottomUp(new BindSlotWithPaths()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ private LogicalPlan bindWithCurrentDb(CascadesContext cascadesContext, UnboundRe
}

// TODO: should generate different Scan sub class according to table's type
LogicalPlan scan = getLogicalPlan(table, unboundRelation, tableQualifier, cascadesContext);
LogicalPlan scan = getAndCheckLogicalPlan(table, unboundRelation, tableQualifier, cascadesContext);
if (cascadesContext.isLeadingJoin()) {
LeadingHint leading = (LeadingHint) cascadesContext.getHintMap().get("Leading");
leading.putRelationIdAndTableName(Pair.of(unboundRelation.getRelationId(), tableName));
Expand All @@ -178,7 +178,7 @@ private LogicalPlan bind(CascadesContext cascadesContext, UnboundRelation unboun
if (table == null) {
table = RelationUtil.getTable(tableQualifier, cascadesContext.getConnectContext().getEnv());
}
return getLogicalPlan(table, unboundRelation, tableQualifier, cascadesContext);
return getAndCheckLogicalPlan(table, unboundRelation, tableQualifier, cascadesContext);
}

private LogicalPlan makeOlapScan(TableIf table, UnboundRelation unboundRelation, List<String> tableQualifier) {
Expand Down Expand Up @@ -234,7 +234,17 @@ private LogicalPlan makeOlapScan(TableIf table, UnboundRelation unboundRelation,
return scan;
}

private LogicalPlan getLogicalPlan(TableIf table, UnboundRelation unboundRelation, List<String> tableQualifier,
private LogicalPlan getAndCheckLogicalPlan(TableIf table, UnboundRelation unboundRelation,
List<String> tableQualifier, CascadesContext cascadesContext) {
// if current context is in the view, we can skip check authentication because
// the view already checked authentication
if (cascadesContext.shouldCheckRelationAuthentication()) {
UserAuthentication.checkPermission(table, cascadesContext.getConnectContext());
}
return doGetLogicalPlan(table, unboundRelation, tableQualifier, cascadesContext);
}

private LogicalPlan doGetLogicalPlan(TableIf table, UnboundRelation unboundRelation, List<String> tableQualifier,
CascadesContext cascadesContext) {
switch (table.getType()) {
case OLAP:
Expand Down Expand Up @@ -289,7 +299,7 @@ private Plan parseAndAnalyzeView(String ddlSql, CascadesContext parentContext) {
if (parsedViewPlan instanceof UnboundResultSink) {
parsedViewPlan = (LogicalPlan) ((UnboundResultSink<?>) parsedViewPlan).child();
}
CascadesContext viewContext = CascadesContext.initContext(
CascadesContext viewContext = CascadesContext.initViewContext(
parentContext.getStatementContext(), parsedViewPlan, PhysicalProperties.ANY);
viewContext.newAnalyzer(true, customTableResolver).analyze();
// we should remove all group expression of the plan which in other memo, so the groupId would not conflict
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,44 +23,31 @@
import org.apache.doris.datasource.CatalogIf;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
import org.apache.doris.qe.ConnectContext;

/**
* Check whether a user is permitted to scan specific tables.
*/
public class UserAuthentication extends OneAnalysisRuleFactory {

@Override
public Rule build() {
return logicalRelation()
.when(CatalogRelation.class::isInstance)
.thenApply(ctx -> checkPermission((CatalogRelation) ctx.root, ctx.connectContext))
.toRule(RuleType.RELATION_AUTHENTICATION);
}

private Plan checkPermission(CatalogRelation relation, ConnectContext connectContext) {
public class UserAuthentication {
/** checkPermission. */
public static void checkPermission(TableIf table, ConnectContext connectContext) {
if (table == null) {
return;
}
// do not check priv when replaying dump file
if (connectContext.getSessionVariable().isPlayNereidsDump()) {
return null;
}
TableIf table = relation.getTable();
if (table == null) {
return null;
return;
}
String tableName = table.getName();
DatabaseIf db = table.getDatabase();
// when table inatanceof FunctionGenTable,db will be null
if (db == null) {
return null;
return;
}
String dbName = db.getFullName();
CatalogIf catalog = db.getCatalog();
if (catalog == null) {
return null;
return;
}
String ctlName = catalog.getName();
// TODO: 2023/7/19 checkColumnsPriv
Expand All @@ -71,7 +58,5 @@ private Plan checkPermission(CatalogRelation relation, ConnectContext connectCon
ctlName + ": " + dbName + ": " + tableName);
throw new AnalysisException(message);
}
return null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ private void createDbAndTableForHmsCatalog(HMSExternalCatalog hmsCatalog) {
tbl.getType();
minTimes = 0;
result = TableIf.TableType.HMS_EXTERNAL_TABLE;

tbl.getDatabase();
minTimes = 0;
result = db;
}
};

Expand Down Expand Up @@ -169,6 +173,10 @@ private void createDbAndTableForHmsCatalog(HMSExternalCatalog hmsCatalog) {
view1.isSupportedHmsTable();
minTimes = 0;
result = true;

view1.getDatabase();
minTimes = 0;
result = db;
}
};

Expand Down Expand Up @@ -211,6 +219,10 @@ private void createDbAndTableForHmsCatalog(HMSExternalCatalog hmsCatalog) {
view2.isSupportedHmsTable();
minTimes = 0;
result = true;

view2.getDatabase();
minTimes = 0;
result = db;
}
};

Expand Down Expand Up @@ -253,6 +265,10 @@ private void createDbAndTableForHmsCatalog(HMSExternalCatalog hmsCatalog) {
view3.isSupportedHmsTable();
minTimes = 0;
result = true;

view3.getDatabase();
minTimes = 0;
result = db;
}
};

Expand Down Expand Up @@ -295,6 +311,10 @@ private void createDbAndTableForHmsCatalog(HMSExternalCatalog hmsCatalog) {
view4.isSupportedHmsTable();
minTimes = 0;
result = true;

view4.getDatabase();
minTimes = 0;
result = db;
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ private void init(HMSExternalCatalog hmsCatalog) {
// mock initSchemaAndUpdateTime and do nothing
tbl.initSchemaAndUpdateTime();
minTimes = 0;

tbl.getDatabase();
minTimes = 0;
result = db;
}
};

Expand Down Expand Up @@ -203,6 +207,10 @@ private void init(HMSExternalCatalog hmsCatalog) {
// mock initSchemaAndUpdateTime and do nothing
tbl2.initSchemaAndUpdateTime();
minTimes = 0;

tbl2.getDatabase();
minTimes = 0;
result = db;
}
};

Expand Down Expand Up @@ -254,6 +262,10 @@ private void init(HMSExternalCatalog hmsCatalog) {
view1.getUpdateTime();
minTimes = 0;
result = NOW;

view1.getDatabase();
minTimes = 0;
result = db;
}
};

Expand Down Expand Up @@ -304,6 +316,10 @@ private void init(HMSExternalCatalog hmsCatalog) {
view2.getUpdateTime();
minTimes = 0;
result = NOW;

view2.getDatabase();
minTimes = 0;
result = db;
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,16 @@ suite("test_nereids_row_policy") {
sql "set enable_fallback_to_original_planner = false"
sql "SELECT * FROM ${tableName}"
}
def result3 = connect(user=user, password='123abc!@#', url=url) {
connect(user=user, password='123abc!@#', url=url) {
sql "set enable_nereids_planner = true"
sql "set enable_fallback_to_original_planner = false"
sql "SELECT * FROM ${viewName}"
test {
sql "SELECT * FROM ${viewName}"
exception "SELECT command denied to user"
}
}
assertEquals(size, result1.size())
assertEquals(size, result2.size())
assertEquals(size, result3.size())
}

def createPolicy = { name, predicate, type ->
Expand Down Expand Up @@ -79,6 +81,7 @@ suite("test_nereids_row_policy") {
sql "CREATE USER ${user} IDENTIFIED BY '123abc!@#'"
sql "GRANT SELECT_PRIV ON internal.${dbName}.${tableName} TO ${user}"

sql 'sync'

// no policy
assertQueryResult 3
Expand Down
Loading

0 comments on commit b5ec1e7

Please sign in to comment.