Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
cteDef.id,
anchor,
rewriteRecursiveCTERefs(recursion, anchor, cteDef.id, None),
anchor.output.map(_.newInstance().exprId),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we can add a new apply method in object UnionLoop that takes the same parameters as before and additionally passes anchor.output.map(_.newInstance().exprId) to construct the UnionLoop

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just adding another apply method in object UnionLoop doesn't work because of the default arguments.
The compiler rejects it with: "in object UnionLoop, multiple overloaded alternatives of method apply define default arguments."
There are other ways to make this work, but I don't think any single change can work without modifying something in ResolveWithCTE.scala.

maxDepth = cteDef.maxDepth)
cteDef.copy(child = alias.copy(child = loop))
}
Expand All @@ -99,6 +100,7 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
cteDef.id,
anchor,
rewriteRecursiveCTERefs(recursion, anchor, cteDef.id, None),
anchor.output.map(_.newInstance().exprId),
maxDepth = cteDef.maxDepth)
cteDef.copy(child = alias.copy(child = withCTE.copy(
plan = loop, cteDefs = newInnerCteDefs)))
Expand All @@ -118,6 +120,7 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
cteDef.id,
anchor,
rewriteRecursiveCTERefs(recursion, anchor, cteDef.id, Some(colNames)),
anchor.output.map(_.newInstance().exprId),
maxDepth = cteDef.maxDepth)
cteDef.copy(child = alias.copy(child = columnAlias.copy(child = loop)))
}
Expand All @@ -142,6 +145,7 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
cteDef.id,
anchor,
rewriteRecursiveCTERefs(recursion, anchor, cteDef.id, Some(colNames)),
anchor.output.map(_.newInstance().exprId),
maxDepth = cteDef.maxDepth)
cteDef.copy(child = alias.copy(child = columnAlias.copy(
child = withCTE.copy(plan = loop, cteDefs = newInnerCteDefs))))
Expand All @@ -166,6 +170,7 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
UnionLoopRef(cteDef.id, anchor.output, true),
isAll = false
),
anchor.output.map(_.newInstance().exprId),
maxDepth = cteDef.maxDepth
)
cteDef.copy(child = alias.copy(child = loop))
Expand Down Expand Up @@ -194,6 +199,7 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
UnionLoopRef(cteDef.id, anchor.output, true),
isAll = false
),
anchor.output.map(_.newInstance().exprId),
maxDepth = cteDef.maxDepth
)
cteDef.copy(child = alias.copy(child = withCTE.copy(
Expand All @@ -220,6 +226,7 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
UnionLoopRef(cteDef.id, anchor.output, true),
isAll = false
),
anchor.output.map(_.newInstance().exprId),
maxDepth = cteDef.maxDepth
)
cteDef.copy(child = alias.copy(child = columnAlias.copy(child = loop)))
Expand Down Expand Up @@ -251,6 +258,7 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
UnionLoopRef(cteDef.id, anchor.output, true),
isAll = false
),
anchor.output.map(_.newInstance().exprId),
maxDepth = cteDef.maxDepth
)
cteDef.copy(child = alias.copy(child = columnAlias.copy(
Expand Down Expand Up @@ -298,7 +306,7 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
columnNames: Option[Seq[String]]) = {
recursion.transformUpWithSubqueriesAndPruning(_.containsPattern(CTE)) {
case r: CTERelationRef if r.recursive && r.cteId == cteDefId =>
val ref = UnionLoopRef(r.cteId, anchor.output, false)
val ref = UnionLoopRef(r.cteId, anchor.output.map(_.newInstance()), false)
columnNames.map(UnresolvedSubqueryColumnAliases(_, ref)).getOrElse(ref)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ abstract class UnionBase extends LogicalPlan {

private lazy val lazyOutput: Seq[Attribute] = computeOutput()

private def computeOutput(): Seq[Attribute] = Union.mergeChildOutputs(children.map(_.output))
protected def computeOutput(): Seq[Attribute] = Union.mergeChildOutputs(children.map(_.output))

/**
* Maps the constraints containing a given (original) sequence of attributes to those with a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,31 @@ import org.apache.spark.sql.internal.SQLConf
* @param id The id of the loop, inherited from [[CTERelationDef]] within which the Union lived.
* @param anchor The plan of the initial element of the loop.
* @param recursion The plan that describes the recursion with an [[UnionLoopRef]] node.
* @param outputAttrIds The ids of UnionLoop's output attributes.
* @param limit An optional limit that can be pushed down to the node to stop the loop earlier.
* @param maxDepth Maximal number of iterations before we report an error.
*/
case class UnionLoop(
    id: Long,
    anchor: LogicalPlan,
    recursion: LogicalPlan,
    outputAttrIds: Seq[ExprId],
    limit: Option[Int] = None,
    maxDepth: Option[Int] = None) extends UnionBase {

  // Children are always ordered as: anchor first, recursion second.
  override def children: Seq[LogicalPlan] = anchor :: recursion :: Nil

  // Rebuilds the node from replacement children, relying on the same
  // (anchor, recursion) ordering that `children` exposes.
  override protected def withNewChildrenInternal(
      newChildren: IndexedSeq[LogicalPlan]): UnionLoop = {
    val newAnchor = newChildren(0)
    val newRecursion = newChildren(1)
    copy(anchor = newAnchor, recursion = newRecursion)
  }

  // Merges the children's outputs the same way the base class does, then
  // re-stamps each merged attribute, position by position, with the matching
  // id from `outputAttrIds`.
  override protected def computeOutput(): Seq[Attribute] = {
    val mergedAttrs = Union.mergeChildOutputs(children.map(_.output))
    mergedAttrs.zip(outputAttrIds).map { case (attr, exprId) => attr.withExprId(exprId) }
  }

  // Renders the loop id, followed by `limit` and `maxDepth` when they are
  // defined, separated by ", ". `outputAttrIds` is not part of the string.
  override def argString(maxFields: Int): String = {
    val optionalParts = limit.toSeq ++ maxDepth.toSeq
    (id.toString +: optionalParts.map(_.toString)).mkString(", ")
  }
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,20 @@ class ResolveRecursiveCTESuite extends AnalysisTest {
Seq(CTERelationDef(anchor.union(recursion).subquery("t"), cteId)))
}

val analyzed = getAnalyzer.execute(getBeforePlan())

val outputExprIds = analyzed match {
case WithCTE(_, cteDefs) =>
cteDefs.head.child match {
case SubqueryAlias(_, UnionLoop(_, _, _, exprIds, _, _)) =>
exprIds
}
}

def getAfterPlan(): LogicalPlan = {
val recursion = UnionLoopRef(cteId, anchor.output, accumulated = false).subquery("t")
val cteDef = CTERelationDef(UnionLoop(cteId, anchor, recursion).subquery("t"), cteId)
val cteDef = CTERelationDef(UnionLoop(cteId, anchor, recursion,
outputExprIds).subquery("t"), cteId)
val cteRef = CTERelationRef(
cteId,
_resolved = true,
Expand All @@ -53,7 +64,7 @@ class ResolveRecursiveCTESuite extends AnalysisTest {
WithCTE(cteRef, Seq(cteDef))
}

comparePlans(getAnalyzer.execute(getBeforePlan()), getAfterPlan())
comparePlans(analyzed, getAfterPlan())
}

// Motivated by:
Expand All @@ -75,14 +86,24 @@ class ResolveRecursiveCTESuite extends AnalysisTest {
WithCTE(cteRef.copy(recursive = false), Seq(cteDef))
}

val analyzed = getAnalyzer.execute(getBeforePlan())

val outputExprIds = analyzed match {
case WithCTE(_, cteDefs) =>
cteDefs.head.child match {
case SubqueryAlias(_, Project(_, UnionLoop(_, _, _, exprIds, _, _))) =>
exprIds
}
}

def getAfterPlan(): LogicalPlan = {
val col = anchor.output.head
val recursion = UnionLoopRef(cteId, anchor.output, accumulated = false)
.select(col.as("n"))
.subquery("t")
val cteDef = CTERelationDef(
UnionLoop(cteId, anchor, recursion).select(col.as("n")).subquery("t"),
cteId)
UnionLoop(cteId, anchor, recursion, outputExprIds)
.select(col.as("n")).subquery("t"), cteId)
val cteRef = CTERelationRef(
cteId,
_resolved = true,
Expand All @@ -91,6 +112,6 @@ class ResolveRecursiveCTESuite extends AnalysisTest {
WithCTE(cteRef, Seq(cteDef))
}

comparePlans(getAnalyzer.execute(getBeforePlan()), getAfterPlan())
comparePlans(analyzed, getAfterPlan())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1045,7 +1045,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
GlobalLimitExec(child = planLater(child), offset = offset) :: Nil
case union: logical.Union =>
execution.UnionExec(union.children.map(planLater)) :: Nil
case u @ logical.UnionLoop(id, anchor, recursion, limit, maxDepth) =>
case u @ logical.UnionLoop(id, anchor, recursion, _, limit, maxDepth) =>
execution.UnionLoopExec(id, anchor, recursion, u.output, limit, maxDepth) :: Nil
case g @ logical.Generate(generator, _, outer, _, _, child) =>
execution.GenerateExec(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1631,6 +1631,72 @@ WithCTE
+- CTERelationRef xxxx, true, [n#x], false, false


-- !query
WITH RECURSIVE tmp(x) AS (
values (1), (2), (3), (4), (5)
), rcte(x, y) AS (
SELECT x, x FROM tmp WHERE x = 1
UNION ALL
SELECT x + 1, x FROM rcte WHERE x < 5
)
SELECT * FROM rcte
-- !query analysis
WithCTE
:- CTERelationDef xxxx, false
: +- SubqueryAlias tmp
: +- Project [col1#x AS x#x]
: +- LocalRelation [col1#x]
:- CTERelationDef xxxx, false
: +- SubqueryAlias rcte
: +- Project [x#x AS x#x, x#x AS y#x]
: +- UnionLoop xxxx
: :- Project [x#x, x#x]
: : +- Filter (x#x = 1)
: : +- SubqueryAlias tmp
: : +- CTERelationRef xxxx, true, [x#x], false, false, 5
: +- Project [(x#x + 1) AS (x + 1)#x, x#x]
: +- Filter (x#x < 5)
: +- SubqueryAlias rcte
: +- Project [x#x AS x#x, x#x AS y#x]
: +- UnionLoopRef xxxx, [x#x, x#x], false
+- Project [x#x, y#x]
+- SubqueryAlias rcte
+- CTERelationRef xxxx, true, [x#x, y#x], false, false


-- !query
WITH RECURSIVE tmp(x) AS (
values (1), (2), (3), (4), (5)
), rcte(x, y, z, t) AS (
SELECT x, x, x, x FROM tmp WHERE x = 1
UNION ALL
SELECT x + 1, x, y + 1, y FROM rcte WHERE x < 5
)
SELECT * FROM rcte
-- !query analysis
WithCTE
:- CTERelationDef xxxx, false
: +- SubqueryAlias tmp
: +- Project [col1#x AS x#x]
: +- LocalRelation [col1#x]
:- CTERelationDef xxxx, false
: +- SubqueryAlias rcte
: +- Project [x#x AS x#x, x#x AS y#x, x#x AS z#x, x#x AS t#x]
: +- UnionLoop xxxx
: :- Project [x#x, x#x, x#x, x#x]
: : +- Filter (x#x = 1)
: : +- SubqueryAlias tmp
: : +- CTERelationRef xxxx, true, [x#x], false, false, 5
: +- Project [(x#x + 1) AS (x + 1)#x, x#x, (y#x + 1) AS (y + 1)#x, y#x]
: +- Filter (x#x < 5)
: +- SubqueryAlias rcte
: +- Project [x#x AS x#x, x#x AS y#x, x#x AS z#x, x#x AS t#x]
: +- UnionLoopRef xxxx, [x#x, x#x, x#x, x#x], false
+- Project [x#x, y#x, z#x, t#x]
+- SubqueryAlias rcte
+- CTERelationRef xxxx, true, [x#x, y#x, z#x, t#x], false, false


-- !query
WITH RECURSIVE randoms(val) AS (
SELECT CAST(floor(rand(82374) * 5 + 1) AS INT)
Expand Down
20 changes: 20 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,26 @@ WITH RECURSIVE t1 AS (
SELECT n+1 FROM t2 WHERE n < 5)
SELECT * FROM t1;

-- Recursive CTE with multiple of the same reference in the anchor, which get referenced differently in subsequent iterations.
WITH RECURSIVE tmp(x) AS (
values (1), (2), (3), (4), (5)
), rcte(x, y) AS (
SELECT x, x FROM tmp WHERE x = 1
UNION ALL
SELECT x + 1, x FROM rcte WHERE x < 5
)
SELECT * FROM rcte;

-- Recursive CTE with multiple of the same reference in the anchor, which get referenced as different variables in subsequent iterations.
WITH RECURSIVE tmp(x) AS (
values (1), (2), (3), (4), (5)
), rcte(x, y, z, t) AS (
SELECT x, x, x, x FROM tmp WHERE x = 1
UNION ALL
SELECT x + 1, x, y + 1, y FROM rcte WHERE x < 5
)
SELECT * FROM rcte;

-- Non-deterministic query with rand with seed
WITH RECURSIVE randoms(val) AS (
SELECT CAST(floor(rand(82374) * 5 + 1) AS INT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1477,6 +1477,44 @@ struct<n:int>
5


-- !query
WITH RECURSIVE tmp(x) AS (
values (1), (2), (3), (4), (5)
), rcte(x, y) AS (
SELECT x, x FROM tmp WHERE x = 1
UNION ALL
SELECT x + 1, x FROM rcte WHERE x < 5
)
SELECT * FROM rcte
-- !query schema
struct<x:int,y:int>
-- !query output
1 1
2 1
3 2
4 3
5 4


-- !query
WITH RECURSIVE tmp(x) AS (
values (1), (2), (3), (4), (5)
), rcte(x, y, z, t) AS (
SELECT x, x, x, x FROM tmp WHERE x = 1
UNION ALL
SELECT x + 1, x, y + 1, y FROM rcte WHERE x < 5
)
SELECT * FROM rcte
-- !query schema
struct<x:int,y:int,z:int,t:int>
-- !query output
1 1 1 1
2 1 2 1
3 2 2 1
4 3 3 2
5 4 4 3


-- !query
WITH RECURSIVE randoms(val) AS (
SELECT CAST(floor(rand(82374) * 5 + 1) AS INT)
Expand Down