Skip to content

Commit

Permalink
fix: Invert case_sensitive logic in StructType (apache#1147)
Browse files Browse the repository at this point in the history
* fix: Invert  logic in StructType

* Add test for StructType.field_by_name

* Remove var I forgot about.

* Fix formatting post-lint
  • Loading branch information
AnthonyLam authored Sep 9, 2024
1 parent d587e67 commit 9b9ed53
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pyiceberg/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,13 +377,13 @@ def field(self, field_id: int) -> Optional[NestedField]:

def field_by_name(self, name: str, case_sensitive: bool = True) -> Optional[NestedField]:
if case_sensitive:
name_lower = name.lower()
for field in self.fields:
if field.name.lower() == name_lower:
if field.name == name:
return field
else:
name_lower = name.lower()
for field in self.fields:
if field.name == name:
if field.name.lower() == name_lower:
return field
return None

Expand Down
14 changes: 14 additions & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,20 @@ def test_struct_type() -> None:
assert type_var == pickle.loads(pickle.dumps(type_var))


def test_struct_field_by_name() -> None:
lower_field = NestedField(1, "lower_case_field", IntegerType(), required=True)
upper_field = NestedField(2, "UPPER_CASE_FIELD", IntegerType(), required=True)
type_var = StructType(lower_field, upper_field)

assert type_var.field_by_name("lower_case_field", case_sensitive=False) == lower_field
assert type_var.field_by_name("upper_case_field", case_sensitive=False) == upper_field
assert type_var.field_by_name("nonexistent_field", case_sensitive=False) is None

assert type_var.field_by_name("lower_case_field", case_sensitive=True) == lower_field
assert type_var.field_by_name("upper_case_field", case_sensitive=True) is None
assert type_var.field_by_name("nonexistent_field", case_sensitive=True) is None


def test_list_type() -> None:
type_var = ListType(
1,
Expand Down

0 comments on commit 9b9ed53

Please sign in to comment.