diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala index 68b15345a..278233ee4 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -215,7 +215,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] override def visit(fetch: relation.Fetch): LogicalPlan = { val child = fetch.getInput.accept(this) - val limit = fetch.getCount.getAsLong.intValue() + val limit = fetch.getCount.orElse(-1).intValue() val offset = fetch.getOffset.intValue() val toLiteral = (i: Int) => Literal(i, IntegerType) if (limit >= 0) { diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index 3451c1166..5cf4cbc97 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -206,12 +206,15 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { } private def fetch(child: LogicalPlan, offset: Long, limit: Long = -1): relation.Fetch = { - relation.Fetch + val builder = relation.Fetch .builder() .input(visit(child)) .offset(offset) - .count(limit) - .build() + if (limit != -1) { + builder.count(limit) + } + + builder.build() } override def visitGlobalLimit(p: GlobalLimit): relation.Rel = {