removed checks that made pyright less unhappy
EpsilonPrime committed Oct 9, 2024
1 parent bb49c19 commit 92652f7
Showing 2 changed files with 0 additions and 70 deletions.
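Every hunk below deletes the same guard pattern: a symbol-table lookup followed by an `if not …: raise InternalError(…)` check (plus the `InternalError` class itself and one pytest marker). Here is a minimal sketch of that pattern at a typical call site, using simplified stand-ins for the real `Symbol` and `SymbolTable` in gateway.converter.symbol_table; whether pyright objects after the removal depends on how `get_symbol` is actually annotated, which this page does not show.

# before_after.py -- sketch of the guard pattern this commit deletes.
# Symbol, SymbolTable, and the helper functions are simplified stand-ins,
# not the real gateway classes.
from dataclasses import dataclass, field


class InternalError(Exception):
    """Raised when there is an internal gateway issue (should not occur)."""


@dataclass
class Symbol:
    output_fields: list[str] = field(default_factory=list)


class SymbolTable:
    """Maps plan ids to symbols; get_symbol returns None for unknown ids."""

    def __init__(self) -> None:
        self._symbols: dict[int, Symbol] = {}

    def add_symbol(self, plan_id: int) -> Symbol:
        return self._symbols.setdefault(plan_id, Symbol())

    def get_symbol(self, plan_id: int) -> Symbol | None:
        return self._symbols.get(plan_id)


def output_fields_before(table: SymbolTable, plan_id: int) -> list[str]:
    """A call site as it looked before this commit."""
    symbol = table.get_symbol(plan_id)
    if not symbol:  # the deleted guard: narrows Symbol | None to Symbol
        raise InternalError(f'Could not find plan id {plan_id} constructed earlier.')
    return symbol.output_fields


def output_fields_after(table: SymbolTable, plan_id: int) -> list[str]:
    """The same call site as it looks after this commit."""
    symbol = table.get_symbol(plan_id)
    # Under a `Symbol | None` annotation pyright flags this attribute access;
    # under a plain `Symbol` annotation the deleted guard was unreachable.
    return symbol.output_fields


if __name__ == '__main__':
    table = SymbolTable()
    table.add_symbol(42).output_fields.append('show_string')
    print(output_fields_before(table, 42))  # ['show_string']
    print(output_fields_after(table, 42))  # ['show_string']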
1 change: 0 additions & 1 deletion src/backends/tests/arrow_tools_test.py
@@ -60,7 +60,6 @@ class TestCase:
 ]
 
 
-@pytest.mark.interesting
 class TestArrowTools:
     """Tests the functionality of the arrow tools package."""
 
69 changes: 0 additions & 69 deletions src/gateway/converter/spark_to_substrait.py
@@ -60,10 +60,6 @@
 from gateway.converter.symbol_table import Field, SymbolTable
 
 
-class InternalError(Exception):
-    """Raised when there is an internal gateway issue (should not occur)."""
-
-
 class ExpressionProcessingMode(Enum):
     """The mode of processing expressions."""
 
@@ -144,12 +140,7 @@ def lookup_function_by_name(self, name: str) -> ExtensionFunction:
     def update_field_references(self, plan_id: int) -> None:
         """Use the field references using the specified portion of the plan."""
         source_symbol = self._symbol_table.get_symbol(plan_id)
-        if not source_symbol:
-            raise InternalError(f'Could not find plan id {plan_id} constructed earlier.')
         current_symbol = self._symbol_table.get_symbol(self._current_plan_id)
-        if not current_symbol:
-            raise InternalError(
-                f'Could not find plan id {self._current_plan_id} constructed earlier.')
         original_output_fields = current_symbol.output_fields
         for symbol in source_symbol.output_fields:
             new_symbol = symbol
@@ -161,9 +152,6 @@ def update_field_references(self, plan_id: int) -> None:
     def find_field_by_name(self, field_name: str) -> int | None:
         """Look up the field name in the current set of field references."""
         current_symbol = self._symbol_table.get_symbol(self._current_plan_id)
-        if not current_symbol:
-            raise InternalError(
-                f'Could not find plan id {self._current_plan_id} constructed earlier.')
         try:
             return _find_reference_number(current_symbol.input_fields, field_name)
         except ValueError:
@@ -896,9 +884,6 @@ def convert_read_named_table_relation(
         schema = self.convert_arrow_schema(arrow_schema)
 
         symbol = self._symbol_table.get_symbol(self._current_plan_id)
-        if not symbol:
-            raise InternalError(
-                f'Could not find plan id {self._current_plan_id} constructed earlier.')
         if self._conversion_options.use_duckdb_struct_name_behavior:
             for field_name in self.get_primary_names(schema):
                 symbol.output_fields.append(Field(field_name))
@@ -1156,9 +1141,6 @@ def convert_read_data_source_relation(self, rel: spark_relations_pb2.Read) -> al
         arrow_schema = self._backend.describe_files(paths)
         schema = self.convert_arrow_schema(arrow_schema)
         symbol = self._symbol_table.get_symbol(self._current_plan_id)
-        if not symbol:
-            raise InternalError(
-                f'Could not find plan id {self._current_plan_id} constructed earlier.')
         for field_name in schema.names:
             symbol.output_fields.append(Field(field_name))
         if self._conversion_options.use_named_table_workaround:
@@ -1229,9 +1211,6 @@ def create_common_relation(self, emit_overrides=None) -> algebra_pb2.RelCommon:
         if not self._conversion_options.use_emits_instead_of_direct:
             return algebra_pb2.RelCommon(direct=algebra_pb2.RelCommon.Direct())
         symbol = self._symbol_table.get_symbol(self._current_plan_id)
-        if not symbol:
-            raise InternalError(
-                f'Could not find plan id {self._current_plan_id} constructed earlier.')
         emit = algebra_pb2.RelCommon.Emit()
         if emit_overrides:
             for field_number in emit_overrides:
@@ -1314,9 +1293,6 @@ def convert_aggregate_relation(self, rel: spark_relations_pb2.Aggregate) -> alge
         self.update_field_references(rel.input.common.plan_id)
         aggregate.common.CopyFrom(self.create_common_relation())
         symbol = self._symbol_table.get_symbol(self._current_plan_id)
-        if not symbol:
-            raise InternalError(
-                f'Could not find plan id {self._current_plan_id} constructed earlier.')
 
         # Start tracking the parts of the expressions we are interested in.
         self._top_level_projects = []
@@ -1498,9 +1474,6 @@ def convert_show_string_relation(self, rel: spark_relations_pb2.ShowString) -> a
         # Now that we've processed the input, do the bookkeeping.
         self.update_field_references(rel.input.common.plan_id)
         symbol = self._symbol_table.get_symbol(self._current_plan_id)
-        if not symbol:
-            raise InternalError(
-                f'Could not find plan id {self._current_plan_id} constructed earlier.')
 
         # Find the length of each column in every row.
         project1 = project_relation(
@@ -1660,9 +1633,6 @@ def full_line(fields: list[Field]) -> list[algebra_pb2.Expression]:
         join2 = join_relation(project3, aggregate2)
 
         symbol = self._symbol_table.get_symbol(self._current_plan_id)
-        if not symbol:
-            raise InternalError(
-                f'Could not find plan id {self._current_plan_id} constructed earlier.')
         symbol.output_fields.clear()
         symbol.output_fields.append(Field("show_string"))
 
@@ -1703,9 +1673,6 @@ def convert_with_columns_relation(
         project = algebra_pb2.ProjectRel(input=input_rel)
         self.update_field_references(rel.input.common.plan_id)
         symbol = self._symbol_table.get_symbol(self._current_plan_id)
-        if not symbol:
-            raise InternalError(
-                f'Could not find plan id {self._current_plan_id} constructed earlier.')
         proposed_expressions = [field_reference(i) for i in range(len(symbol.input_fields))]
         for alias in rel.aliases:
             if len(alias.name) != 1:
@@ -1732,9 +1699,6 @@ def convert_with_columns_renamed_relation(
         """Update the columns names based on the Spark with columns renamed relation."""
         input_rel = self.convert_relation(rel.input)
         symbol = self._symbol_table.get_symbol(self._current_plan_id)
-        if not symbol:
-            raise InternalError(
-                f'Could not find plan id {self._current_plan_id} constructed earlier.')
         self.update_field_references(rel.input.common.plan_id)
         symbol.output_fields.clear()
         if hasattr(rel, "renames"):
@@ -1754,9 +1718,6 @@ def convert_drop_relation(self, rel: spark_relations_pb2.Drop) -> algebra_pb2.Re
         project = algebra_pb2.ProjectRel(input=input_rel)
         self.update_field_references(rel.input.common.plan_id)
         symbol = self._symbol_table.get_symbol(self._current_plan_id)
-        if not symbol:
-            raise InternalError(
-                f'Could not find plan id {self._current_plan_id} constructed earlier.')
         if rel.columns:
             column_names = [c.unresolved_attribute.unparsed_identifier for c in rel.columns]
         else:
@@ -1779,9 +1740,6 @@ def convert_to_df_relation(self, rel: spark_relations_pb2.ToDF) -> algebra_pb2.R
         input_rel = self.convert_relation(rel.input)
         self.update_field_references(rel.input.common.plan_id)
         symbol = self._symbol_table.get_symbol(self._current_plan_id)
-        if not symbol:
-            raise InternalError(
-                f'Could not find plan id {self._current_plan_id} constructed earlier.')
         if len(rel.column_names) != len(symbol.input_fields):
             raise ValueError(
                 "column_names does not match the number of input fields at "
Expand Down Expand Up @@ -1858,9 +1816,6 @@ def convert_local_relation(self, rel: spark_relations_pb2.LocalRelation) -> alge
if not schema:
raise ValueError(f'Received an empty schema in plan id {self._current_plan_id}')
symbol = self._symbol_table.get_symbol(self._current_plan_id)
if not symbol:
raise InternalError(
f'Could not find plan id {self._current_plan_id} constructed earlier.')
for field_name in schema.names:
symbol.output_fields.append(Field(field_name))
read.base_schema.CopyFrom(schema)
@@ -1872,9 +1827,6 @@ def convert_sql_relation(self, rel: spark_relations_pb2.SQL) -> algebra_pb2.Rel:
         # TODO -- Handle multithreading in the case with a persistent backend.
         plan = self._sql_backend.convert_sql(rel.query)
         symbol = self._symbol_table.get_symbol(self._current_plan_id)
-        if not symbol:
-            raise InternalError(
-                f'Could not find plan id {self._current_plan_id} constructed earlier.')
         for field_name in plan.relations[0].root.names:
             symbol.output_fields.append(Field(field_name))
         # TODO -- Correctly capture all the used functions and extensions.
@@ -1940,16 +1892,10 @@ def convert_join_relation(self, rel: spark_relations_pb2.Join) -> algebra_pb2.Re
         if rel.using_columns:
             column_name = rel.using_columns[0]
             left_symbol = self._symbol_table.get_symbol(rel.left.common.plan_id)
-            if not left_symbol:
-                raise InternalError(
-                    f'Could not find plan id {rel.left.common.plan_id} constructed earlier.')
             left_fields = left_symbol.output_fields
             left_column_count = len(left_fields)
             left_column_reference = _find_reference_number(left_fields, column_name)
             right_symbol = self._symbol_table.get_symbol(rel.right.common.plan_id)
-            if not right_symbol:
-                raise InternalError(
-                    f'Could not find plan id {rel.right.common.plan_id} constructed earlier.')
             right_fields = right_symbol.output_fields
             right_column_reference = _find_reference_number(right_fields, column_name)
             join.expression.CopyFrom(
@@ -1962,9 +1908,6 @@ def convert_join_relation(self, rel: spark_relations_pb2.Join) -> algebra_pb2.Re
 
         # Avoid emitting the join column twice.
         symbol = self._symbol_table.get_symbol(self._current_plan_id)
-        if not symbol:
-            raise InternalError(
-                f'Could not find plan id {self._current_plan_id} constructed earlier.')
         if self._conversion_options.join_not_honoring_emit_workaround:
             project = algebra_pb2.ProjectRel(input=algebra_pb2.Rel(join=join))
             for column_number in range(len(symbol.output_fields)):
@@ -1996,9 +1939,6 @@ def convert_project_relation(self, rel: spark_relations_pb2.Project) -> algebra_
         project = algebra_pb2.ProjectRel(input=input_rel)
         self.update_field_references(rel.input.common.plan_id)
         symbol = self._symbol_table.get_symbol(self._current_plan_id)
-        if not symbol:
-            raise InternalError(
-                f'Could not find plan id {self._current_plan_id} constructed earlier.')
         for field_number, expr in enumerate(rel.expressions):
             if expr.WhichOneof("expr_type") == "unresolved_regex":
                 regex = expr.unresolved_regex.col_name.replace("`", "")
@@ -2052,9 +1992,6 @@ def convert_deduplicate_relation(self, rel: spark_relations_pb2.Deduplicate) ->
         self.update_field_references(rel.input.common.plan_id)
         aggregate.common.CopyFrom(self.create_common_relation())
         symbol = self._symbol_table.get_symbol(self._current_plan_id)
-        if not symbol:
-            raise InternalError(
-                f'Could not find plan id {self._current_plan_id} constructed earlier.')
         grouping = aggregate.groupings.add()
         for idx, field in enumerate(symbol.input_fields):
             grouping.grouping_expressions.append(field_reference(idx))
@@ -2136,9 +2073,6 @@ def convert_dropna_relation(self, rel: spark_relations_pb2.NADrop) -> algebra_pb
         self.update_field_references(rel.input.common.plan_id)
         filter_rel.common.CopyFrom(self.create_common_relation())
         symbol = self._symbol_table.get_symbol(self._current_plan_id)
-        if not symbol:
-            raise InternalError(
-                f'Could not find plan id {self._current_plan_id} constructed earlier.')
         if rel.cols:
             cols = [_find_reference_number(symbol.input_fields, col) for col in rel.cols]
         else:
@@ -2234,9 +2168,6 @@ def convert_plan(self, plan: spark_pb2.Plan) -> plan_pb2.Plan:
         if plan.HasField("root"):
             rel_root = algebra_pb2.RelRoot(input=self.convert_relation(plan.root))
             symbol = self._symbol_table.get_symbol(plan.root.common.plan_id)
-            if not symbol:
-                raise InternalError(
-                    f'Could not find plan id {self._current_plan_id} constructed earlier.')
             for field in symbol.output_fields:
                 rel_root.names.extend(field.output_names())
             result.relations.append(plan_pb2.PlanRel(root=rel_root))
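If `get_symbol` can still return None at runtime, one conventional way to keep both pyright and the call sites quiet, without restoring dozens of per-call-site guards, is a single raising accessor. A hypothetical sketch, reusing the stand-in classes from the example above (`get_required_symbol` is not part of this commit or the repository):

def get_required_symbol(table: SymbolTable, plan_id: int) -> Symbol:
    """Return the symbol for plan_id, raising if it was never constructed.

    Hypothetical helper: the single `is None` check narrows the type once,
    so every caller gets a plain Symbol instead of Symbol | None.
    """
    symbol = table.get_symbol(plan_id)
    if symbol is None:
        raise RuntimeError(f'Could not find plan id {plan_id} constructed earlier.')
    return symbol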
