Skip to content

Commit

Permalink
add IN subquery decorrelation support & fix not handle NOT IN bug (#242)
Browse files Browse the repository at this point in the history
This check-in adds IN subquery decorrelation support for #220. And fix not handle NOT IN bug in #224.
  • Loading branch information
9DemonFox authored Oct 20, 2020
1 parent 7b96cb4 commit ff70dd9
Show file tree
Hide file tree
Showing 8 changed files with 246 additions and 43 deletions.
30 changes: 22 additions & 8 deletions qpmodel/ExprSubquery.cs
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,17 @@ public override Value ExecNonDistributed(ExecContext context, Row input)

public class InSubqueryExpr : SubqueryExpr
{
internal bool hasNot_;

// children_[0] is the expr of in-query
internal Expr expr_() => children_[0];

public override string ToString() => $"{expr_()} in @{subqueryid_}";
public InSubqueryExpr(Expr expr, SelectStmt query) : base(query) { children_.Add(expr); }
public override string ToString()
{
string ifnot = hasNot_ ? " not" : "";
return $"{expr_()}{ifnot} in @{subqueryid_}";
}
public InSubqueryExpr(Expr expr, SelectStmt query, bool hasNot) : base(query) { hasNot_ = hasNot; children_.Add(expr); }

public override void Bind(BindContext context)
{
Expand All @@ -268,7 +274,10 @@ public override Value ExecNonDistributed(ExecContext context, Row input)
// is also not copied thus multiple threads may racing updating cacheVal_. Lock the
// code section to prevent it. This also redu
if (isCacheable_ && cachedValSet_)
return (cachedVal_ as HashSet<Value>).Contains(expr);
{
var in_cache_flag = (cachedVal_ as HashSet<Value>).Contains(expr);
return hasNot_ ? !in_cache_flag : in_cache_flag;
}

var set = new HashSet<Value>();
query_.physicPlan_.Exec(l =>
Expand All @@ -279,7 +288,8 @@ public override Value ExecNonDistributed(ExecContext context, Row input)

cachedVal_ = set;
cachedValSet_ = true;
return set.Contains(expr);
var in_flag = set.Contains(expr);
return hasNot_ ? !in_flag : in_flag;
}
}

Expand All @@ -288,10 +298,12 @@ public override Value ExecNonDistributed(ExecContext context, Row input)
//
public class InListExpr : Expr
{
internal bool hasNot_;
internal Expr expr_() => children_[0];
internal List<Expr> inlist_() => children_.GetRange(1, children_.Count - 1);
public InListExpr(Expr expr, List<Expr> inlist)
public InListExpr(Expr expr, List<Expr> inlist, bool hasNot)
{
hasNot_ = hasNot;
children_.Add(expr); children_.AddRange(inlist);
type_ = new BoolType();
Debug.Assert(Clone().Equals(this));
Expand All @@ -317,17 +329,19 @@ public override Value Exec(ExecContext context, Row input)
return null;
List<Value> inlist = new List<Value>();
inlist_().ForEach(x => { inlist.Add(x.Exec(context, input)); });
return inlist.Exists(v.Equals);
var in_flag = inlist.Exists(v.Equals);
return hasNot_ ? !in_flag : in_flag;
}

public override string ToString()
{
var inlist = inlist_();
string ifnot = hasNot_ ? " not" : "";
if (inlist_().Count < 5)
return $"{expr_()} in ({string.Join(",", inlist)})";
return $"{expr_()}{ifnot} in ({string.Join(",", inlist)})";
else
{
return $"{expr_()} in ({string.Join(",", inlist.GetRange(0, 3))}, ... <Total: {inlist.Count}> )";
return $"{expr_()}{ifnot} in ({string.Join(",", inlist.GetRange(0, 3))}, ... <Total: {inlist.Count}> )";
}
}
}
Expand Down
13 changes: 8 additions & 5 deletions qpmodel/SQLParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ namespace qpmodel.sqlparser
// antlr requires user defined exception
public class AntlrParserException : Exception
{
public AntlrParserException(string msg) : base(msg) {}
public AntlrParserException(string msg) : base(msg) { }
}

public class SyntaxErrorListener : BaseErrorListener
Expand All @@ -51,7 +51,7 @@ public override void SyntaxError(IRecognizer recognizer, IToken offendingSymbol,
RecognitionException e)
{
var stack = ((Parser)recognizer).GetRuleInvocationStack();
string errormsg = $@"{string.Join("->", stack)} : {line}|{charPositionInLine}|{offendingSymbol}|{msg}";
string errormsg = $@"{string.Join("->", stack)} : {line}|{charPositionInLine}|{offendingSymbol}|{msg}";
throw new AntlrParserException(errormsg);
}
}
Expand Down Expand Up @@ -198,7 +198,8 @@ public override object VisitLogicOrExpr([NotNull] SQLiteParser.LogicOrExprContex
public override object VisitLogicNotExpr([NotNull] SQLiteParser.LogicNotExprContext context)
{
var expr = (Expr)Visit(context.logical_expr());
if (expr is ExistSubqueryExpr ee) {
if (expr is ExistSubqueryExpr ee)
{
// to simplify EXISTS subquery handling, we don't want an extra unary on top
ee.hasNot_ = !ee.hasNot_;
return ee;
Expand Down Expand Up @@ -235,17 +236,19 @@ public override object VisitUnaryexpr([NotNull] SQLiteParser.UnaryexprContext co
return new UnaryExpr(op, Visit(context.arith_expr()) as Expr);
}

// TODO add in subquery
public override object VisitInSubqueryExpr([NotNull] SQLiteParser.InSubqueryExprContext context)
{
Debug.Assert(context.K_IN() != null);

SelectStmt select = null;
List<Expr> inlist = null;
bool hasNot = (context.K_NOT() != null) ? true : false;
if (context.select_stmt() != null)
{
Debug.Assert(context.arith_expr().Count() == 1);
select = Visit(context.select_stmt()) as SelectStmt;
return new InSubqueryExpr(Visit(context.arith_expr(0)) as Expr, select);
return new InSubqueryExpr(Visit(context.arith_expr(0)) as Expr, select, hasNot);
}
else
{
Expand All @@ -254,7 +257,7 @@ public override object VisitInSubqueryExpr([NotNull] SQLiteParser.InSubqueryExpr
inlist.Add(Visit(v) as Expr);
Expr expr = inlist[0];
inlist.RemoveAt(0);
return new InListExpr(expr, inlist);
return new InListExpr(expr, inlist, hasNot);
}
}
public override object VisitCaseExpr([NotNull] SQLiteParser.CaseExprContext context)
Expand Down
77 changes: 76 additions & 1 deletion qpmodel/subquery.cs
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,78 @@ LogicNode scalarToSingleJoin(LogicNode planWithSubExpr, ScalarSubqueryExpr scala
return newplan;
}

LogicNode inToMarkJoin(LogicNode planWithSubExpr,InSubqueryExpr inExpr)
{
LogicNode nodeA = planWithSubExpr;
var nodeAIsOnMarkJoin =
nodeA is LogicFilter && (nodeA.child_() is LogicMarkJoin || nodeA.child_() is LogicSingleJoin);

// nodeB contains the join filter
var nodeB = inExpr.query_.logicPlan_;
var nodeBFilter = nodeB.filter_;
nodeB.NullifyFilter();

// nullify nodeA's filter: the rest is push to top filter. However,
// if nodeA is a Filter|MarkJoin, keep its mark filter.
var markerFilter = new ExprRef(new MarkerExpr(), 0);
var nodeAFilter = nodeA.filter_;

// consider SQL ...a1 in select b1 from...
// a1 is outerExpr and b1 is selectExpr
Expr outerExpr = inExpr.child_();
Debug.Assert(inExpr.query_.selection_.Count == 1);
Expr selectExpr = inExpr.query_.selection_[0];
BinExpr inToEqual = BinExpr.MakeBooleanExpr(outerExpr, selectExpr, "=");

if (nodeAIsOnMarkJoin)
nodeA.filter_ = markerFilter;
else
{
if (nodeAFilter != null)
{
// a1 > @1 and a2 > @2 and a3 > 2, scalarExpr = @1
// keeplist: a1 > @1 and a3 > 2
// andlist after removal: a2 > @2
// nodeAFilter = a1 > @1 and a3 > 2
//
var andlist = nodeAFilter.FilterToAndList();
var keeplist = andlist.Where(x => x.VisitEachExists(e => e.Equals(inExpr))).ToList();
andlist.RemoveAll(x => x.VisitEachExists(e => e.Equals(inExpr)));
if (andlist.Count == 0)
nodeA.NullifyFilter();
else
{
nodeA.filter_ = andlist.AndListToExpr();
if (keeplist.Count > 0)
nodeAFilter = keeplist.AndListToExpr();
else
nodeAFilter = markerFilter;
}
}
}
// make a mark join
LogicMarkJoin markjoin;
if (inExpr.hasNot_)
markjoin = new LogicMarkAntiSemiJoin(nodeA, nodeB);
else
markjoin = new LogicMarkSemiJoin(nodeA, nodeB);

// make a filter on top of the mark join collecting all filters
Expr topfilter;
if (nodeAIsOnMarkJoin)
topfilter = nodeAFilter.SearchAndReplace(inExpr, ConstExpr.MakeConstBool(true));
else
topfilter = nodeAFilter.SearchAndReplace(inExpr, markerFilter);
nodeBFilter.DeParameter(nodeA.InclusiveTableRefs());
topfilter = topfilter.AddAndFilter(nodeBFilter);
// TODO mutiple nested insubquery subquery
// seperate the overlapping code with existsToSubquery to a new method
// when the PR in #support nestted exist subquery pass
LogicFilter Filter = new LogicFilter(markjoin, topfilter);
Filter = new LogicFilter(Filter, inToEqual);
return Filter;
}

// A Xs B => A LOJ B if max1row is assured
LogicJoin singleJoin2OuterJoin(LogicSingleJoin singJoinNode)
{
Expand Down Expand Up @@ -396,7 +468,7 @@ bool nullRejectingSingleCondition(Expr condition)

// exists|quantified subquery => mark join
// scalar subquery => single join or LOJ if max1row output is assured
//
//
LogicNode oneSubqueryToJoin(LogicNode planWithSubExpr, SubqueryExpr subexpr)
{
LogicNode oldplan = planWithSubExpr;
Expand All @@ -413,6 +485,9 @@ LogicNode oneSubqueryToJoin(LogicNode planWithSubExpr, SubqueryExpr subexpr)
case ScalarSubqueryExpr ss:
newplan = scalarToSingleJoin(planWithSubExpr, ss);
break;
case InSubqueryExpr si:
newplan = inToMarkJoin(planWithSubExpr, si);
break;
default:
break;
}
Expand Down
14 changes: 4 additions & 10 deletions test/NistTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,6 @@ FROM PROJ
Assert.AreEqual(1, stmtResult.Count);
Assert.AreEqual(stmtResult[0][0].ToString(), "Alice");

/* BUG/Unsupported ? */
/* should return 1 row with 12 but returns 11 rows */
sql = @"
SELECT WORKS.HOURS
FROM WORKS
Expand All @@ -337,8 +335,8 @@ FROM PROJ
WHERE PROJ.BUDGET BETWEEN 5000 AND 40000);";
stmtResult = TU.ExecuteSQL(sql);
Assert.AreEqual("", TU.error_);
// Assert.AreEqual(1, stmtResult.Count);
// Assert.AreEqual(stmtResult[0][0].ToString(), "12");
Assert.AreEqual(1, stmtResult.Count);
Assert.AreEqual(stmtResult[0][0].ToString(), "12");

sql = @"
SELECT WORKS.HOURS
Expand All @@ -352,10 +350,6 @@ FROM PROJ
Assert.AreEqual(1, stmtResult.Count);
Assert.AreEqual(stmtResult[0][0].ToString(), "12");

/*
* BUG/Unsupported?
* Should return one row with 80 but returns 11 rows.
*/
sql = @"
SELECT HOURS
FROM WORKS
Expand All @@ -365,8 +359,8 @@ FROM WORKS
WHERE PNUM IN ('P1','P2','P4','P5','P6'));";
stmtResult = TU.ExecuteSQL(sql);
Assert.AreEqual("", TU.error_);
// Assert.AreEqual(1, stmtResult.Count);
// Assert.AreEqual(stmtResult[0][0].ToString(), "80");
Assert.AreEqual(1, stmtResult.Count);
Assert.AreEqual(stmtResult[0][0].ToString(), "80");

sql = @"
SELECT HOURS
Expand Down
61 changes: 56 additions & 5 deletions test/UnitTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,10 @@ public void TestCSVReader()
[TestMethod]
public void TestStringLike()
{
Debug.Assert(Utils.StringLike("ABCDEF","a%")==false);
Debug.Assert(Utils.StringLike("ABCDEF", "A%")==true);
Debug.Assert(Utils.StringLike("ABCDEF","%A%")==true);
Debug.Assert(Utils.StringLike("ABCDEF","A")==false);
Debug.Assert(Utils.StringLike("ABCDEF", "a%") == false);
Debug.Assert(Utils.StringLike("ABCDEF", "A%") == true);
Debug.Assert(Utils.StringLike("ABCDEF", "%A%") == true);
Debug.Assert(Utils.StringLike("ABCDEF", "A") == false);
Debug.Assert(Utils.StringLike("ABCDEF", "%EF") == true);
Debug.Assert(Utils.StringLike("ABCDEF", "%DE") == false);
Debug.Assert(Utils.StringLike("ABCDEF", "A_C%") == true);
Expand Down Expand Up @@ -601,7 +601,8 @@ void TestTpchWithData()
Assert.AreEqual(1, result.Count);
Assert.AreEqual(true, result[0].ToString().Contains("15.23"));
// q15 cte
TU.ExecuteSQL(File.ReadAllText(files[15]), "", out _, option);
result = TU.ExecuteSQL(File.ReadAllText(files[15]), out _, option);
Assert.AreEqual(34, result.Count);
TU.ExecuteSQL(File.ReadAllText(files[16]), "", out _, option);
TU.ExecuteSQL(File.ReadAllText(files[17]), "", out _, option);
TU.ExecuteSQL(File.ReadAllText(files[18]), out _, option); // FIXME: .. or ... or ...
Expand Down Expand Up @@ -872,6 +873,56 @@ public void TestExistsSubquery()
}
}

[TestMethod]
public void TestInSubquery()
{
QueryOption option = new QueryOption();

for (int i = 0; i < 2; i++)
{
option.optimize_.use_memo_ = i == 0;
var phyplan = "";

// many NOT test, there are only IN and NOT IN supported in SQL.
var sql = "select a1 from a where a2 not not in (1,2)";
var result = ExecuteSQL(sql); Assert.IsNull(result);
Assert.IsTrue(TU.error_.Contains(@"no viable alternative at input 'a2 not not'"));

sql = "select a1 from a where a2 not not not in (1,2)";
result = ExecuteSQL(sql); Assert.IsNull(result);
Assert.IsTrue(TU.error_.Contains(@"no viable alternative at input 'a2 not not'"));

// List InSubquery
sql = "select a1 from a where a2 not in (1,2)";
TU.ExecuteSQL(sql, "2", out phyplan, option);
Assert.AreEqual(1, TU.CountStr(phyplan, "not in"));
sql = "select a1 from a where a2 in (1,2)";
TU.ExecuteSQL(sql, "0;1", out phyplan, option);
Assert.AreEqual(0, TU.CountStr(phyplan, "not in"));

// non-corelated InSubquery
sql = "select a1 from a where a2 not in (select b1 from b where b2>1)"; // not in (1,2)
TU.ExecuteSQL(sql, "2", out phyplan, option);
Assert.AreEqual(1, TU.CountStr(phyplan, "not in"));

sql = "select a1 from a where a2 in (select b1 from b where b2>1)"; // in (1,2)
TU.ExecuteSQL(sql, "0;1", out phyplan, option);
Assert.AreEqual(0, TU.CountStr(phyplan, "not in"));

// corelated InSubquery
sql = "select a1 from a where a2 in (select b2 from b where b2 = a1)";
TU.ExecuteSQL(sql, "", out phyplan, option);
Assert.AreEqual(2, TU.CountStr(phyplan, "#marker"));

sql = "select a1 from a where a2 not in (select b2 from b where b2 = a1)";
TU.ExecuteSQL(sql, "0;1;2", out phyplan, option);
Assert.AreEqual(2, TU.CountStr(phyplan, "#marker"));

sql = "select a1 from a where a2 in (select b2 from b where b1 = a1 and b3 > 2 ) and a1 > 0";
TU.ExecuteSQL(sql, "1;2", out phyplan, option);
}
}

[TestMethod]
public void TestScalarSubquery()
{
Expand Down
Loading

0 comments on commit ff70dd9

Please sign in to comment.