Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make id fields non-optional but retain auto-numbering #31

Merged
merged 1 commit into from
Aug 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions src/ome_autogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,8 @@ def make_dataclass(component: Union[XsdComponent, XsdType]) -> List[str]:
cannot_have_required_args = base_type and members.has_non_default_args()
if cannot_have_required_args:
lines[0] += ", EMPTY"
if members.has_nonref_id():
lines[0] += ", AUTO_SEQUENCE"

lines += ["@ome_dataclass", f"class {component.local_name}{base_name}:"]
# FIXME: Refactor to remove BinData special-case.
Expand Down Expand Up @@ -436,6 +438,20 @@ def is_decimal(self) -> bool:
self.component.schema.builtin_types()["decimal"]
)

@property
def is_nonref_id(self) -> bool:
if self.identifier == "id":
gp = self.component.parent.parent
if not gp.base_type or gp.base_type.local_name != "Reference":
return True
return False

@property
def is_ref_id(self) -> bool:
if self.identifier == "id":
return not self.is_nonref_id
return False

@property
def parent_name(self) -> str:
"""Local name of component's first named ancestor."""
Expand Down Expand Up @@ -553,7 +569,9 @@ def default_val_str(self) -> str:
if self.key in OVERRIDES:
default = OVERRIDES[self.key].default
return f" = {default}" if default else ""
if not self.is_optional:
elif self.is_nonref_id:
return " = AUTO_SEQUENCE # type: ignore"
elif not self.is_optional:
return ""

if not self.max_occurs:
Expand All @@ -578,8 +596,8 @@ def max_occurs(self) -> bool:
@property
def is_optional(self) -> bool:
# FIXME: hack. doesn't fully capture the restriction
if self.identifier == "id":
return True
if self.is_ref_id:
return False
if getattr(self.component.parent, "model", "") == "choice":
return True
if hasattr(self.component, "min_occurs"):
Expand Down Expand Up @@ -637,6 +655,9 @@ def body(self) -> List[str]:
def has_non_default_args(self) -> bool:
return any(not m.default_val_str for m in self._members)

def has_nonref_id(self) -> bool:
return any(m.is_nonref_id for m in self._members)

@property
def non_defaults(self) -> "MemberSet":
return MemberSet(m for m in self._members if not m.default_val_str)
Expand Down
15 changes: 6 additions & 9 deletions src/ome_types/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
if TYPE_CHECKING:
from pydantic.dataclasses import DataclassType

# Sentinel default value to support optional fields in dataclass subclasses.
EMPTY = object()
# Sentinel default value to support automatic numbering for id field values.
AUTO_SEQUENCE = object()


@validator("id", pre=True, always=True)
Expand All @@ -22,13 +25,10 @@ def validate_id(cls: Type[Any], value: Any) -> str:
If no value is provided, this validator provides and integer ID, and stores the
maximum previously-seen value on the class.
"""
from typing import ClassVar, Union
from typing import ClassVar

# get the required LSID type from the annotation
id_type = cls.__annotations__.get("id")
# (it will likely be an Optional[LSID])
if getattr(id_type, "__origin__", None) is Union:
id_type = getattr(id_type, "__args__")[0]
if not id_type:
return value

Expand All @@ -37,7 +37,7 @@ def validate_id(cls: Type[Any], value: Any) -> str:
cls._max_id = 0
cls.__annotations__["_max_id"] = ClassVar[int]

if not value:
if value is AUTO_SEQUENCE:
value = cls._max_id + 1
if isinstance(value, int):
v_id = value
Expand Down Expand Up @@ -137,11 +137,8 @@ def ome_dataclass(
"""

def wrap(cls: Type[Any]) -> DataclassType:
if "id" in getattr(cls, "__annotations__", {}):
if getattr(cls, "id", None) is AUTO_SEQUENCE:
setattr(cls, "validate_id", validate_id)
if not hasattr(cls, "id"):
setattr(cls, "id", None)

modify_post_init(cls)
if not repr:
modify_repr(cls)
Expand Down