Skip to content

Commit

Permalink
rfctr: improve xmlchemy typing
Browse files Browse the repository at this point in the history
  • Loading branch information
scanny committed Nov 3, 2023
1 parent a1c6b4f commit 523328c
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 105 deletions.
4 changes: 3 additions & 1 deletion src/docx/enum/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def from_xml(cls, xml_value: str | None) -> Self:
@classmethod
def to_xml(cls: Type[_T], value: int | _T | None) -> str | None:
"""XML value of this enum member, generally an XML attribute value."""
return cls(value).xml_value
# -- presence of multi-arg `__new__()` method fools type-checker, but getting a
# -- member by its value using EnumCls(val) works as usual.
return cls(value).xml_value # pyright: ignore[reportGeneralTypeIssues]


class DocsPageFormatter:
Expand Down
93 changes: 56 additions & 37 deletions src/docx/oxml/simpletypes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# pyright: reportImportCycles=false

"""Simple-type classes, corresponding to ST_* schema items.
These provide validation and format translation for values stored in XML element
Expand All @@ -7,40 +9,49 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Tuple

from docx.exceptions import InvalidXmlError
from docx.shared import Emu, Pt, RGBColor, Twips

if TYPE_CHECKING:
from docx import types as t
from docx.shared import Length


class BaseSimpleType:
"""Base class for simple-types."""

@classmethod
def from_xml(cls, xml_value: str):
def from_xml(cls, xml_value: str) -> Any:
return cls.convert_from_xml(xml_value)

@classmethod
def to_xml(cls, value):
def to_xml(cls, value: Any) -> str:
cls.validate(value)
str_value = cls.convert_to_xml(value)
return str_value

@classmethod
def convert_from_xml(cls, str_value: str) -> t.AbstractSimpleTypeMember:
def convert_from_xml(cls, str_value: str) -> Any:
return int(str_value)

@classmethod
def validate_int(cls, value):
def convert_to_xml(cls, value: Any) -> str:
...

@classmethod
def validate(cls, value: Any) -> None:
...

@classmethod
def validate_int(cls, value: object):
if not isinstance(value, int):
raise TypeError("value must be <type 'int'>, got %s" % type(value))

@classmethod
def validate_int_in_range(cls, value, min_inclusive, max_inclusive):
def validate_int_in_range(
cls, value: int, min_inclusive: int, max_inclusive: int
) -> None:
cls.validate_int(value)
if value < min_inclusive or value > max_inclusive:
raise ValueError(
Expand All @@ -57,15 +68,15 @@ def validate_string(cls, value: Any) -> str:

class BaseIntType(BaseSimpleType):
@classmethod
def convert_from_xml(cls, str_value):
def convert_from_xml(cls, str_value: str) -> int:
return int(str_value)

@classmethod
def convert_to_xml(cls, value):
def convert_to_xml(cls, value: int) -> str:
return str(value)

@classmethod
def validate(cls, value):
def validate(cls, value: Any) -> None:
cls.validate_int(value)


Expand All @@ -84,34 +95,38 @@ def validate(cls, value: str):


class BaseStringEnumerationType(BaseStringType):
_members: Tuple[str, ...]

@classmethod
def validate(cls, value):
def validate(cls, value: Any) -> None:
cls.validate_string(value)
if value not in cls._members:
raise ValueError("must be one of %s, got '%s'" % (cls._members, value))


class XsdAnyUri(BaseStringType):
"""There's a regular expression this is supposed to meet but so far thinking
spending cycles on validating wouldn't be worth it for the number of programming
errors it would catch."""
"""There's a regex in the spec this is supposed to meet...
but current assessment is that spending cycles on validating wouldn't be worth it
for the number of programming errors it would catch.
"""


class XsdBoolean(BaseSimpleType):
@classmethod
def convert_from_xml(cls, str_value):
def convert_from_xml(cls, str_value: str) -> bool:
if str_value not in ("1", "0", "true", "false"):
raise InvalidXmlError(
"value must be one of '1', '0', 'true' or 'false', got '%s'" % str_value
)
return str_value in ("1", "true")

@classmethod
def convert_to_xml(cls, value):
def convert_to_xml(cls, value: bool) -> str:
return {True: "1", False: "0"}[value]

@classmethod
def validate(cls, value):
def validate(cls, value: Any) -> None:
if value not in (True, False):
raise TypeError(
"only True or False (and possibly None) may be assigned, got"
Expand All @@ -130,13 +145,13 @@ class XsdId(BaseStringType):

class XsdInt(BaseIntType):
@classmethod
def validate(cls, value):
def validate(cls, value: Any) -> None:
cls.validate_int_in_range(value, -2147483648, 2147483647)


class XsdLong(BaseIntType):
@classmethod
def validate(cls, value):
def validate(cls, value: Any) -> None:
cls.validate_int_in_range(value, -9223372036854775808, 9223372036854775807)


Expand All @@ -157,13 +172,13 @@ class XsdToken(BaseStringType):

class XsdUnsignedInt(BaseIntType):
@classmethod
def validate(cls, value):
def validate(cls, value: Any) -> None:
cls.validate_int_in_range(value, 0, 4294967295)


class XsdUnsignedLong(BaseIntType):
@classmethod
def validate(cls, value):
def validate(cls, value: Any) -> None:
cls.validate_int_in_range(value, 0, 18446744073709551615)


Expand All @@ -178,7 +193,7 @@ def validate(cls, value: str) -> None:

class ST_BrType(XsdString):
@classmethod
def validate(cls, value):
def validate(cls, value: Any) -> None:
cls.validate_string(value)
valid_values = ("page", "column", "textWrapping")
if value not in valid_values:
Expand All @@ -187,19 +202,19 @@ def validate(cls, value):

class ST_Coordinate(BaseIntType):
@classmethod
def convert_from_xml(cls, str_value):
def convert_from_xml(cls, str_value: str) -> Length:
if "i" in str_value or "m" in str_value or "p" in str_value:
return ST_UniversalMeasure.convert_from_xml(str_value)
return Emu(int(str_value))

@classmethod
def validate(cls, value):
def validate(cls, value: Any) -> None:
ST_CoordinateUnqualified.validate(value)


class ST_CoordinateUnqualified(XsdLong):
@classmethod
def validate(cls, value):
def validate(cls, value: Any) -> None:
cls.validate_int_in_range(value, -27273042329600, 27273042316900)


Expand All @@ -213,19 +228,23 @@ class ST_DrawingElementId(XsdUnsignedInt):

class ST_HexColor(BaseStringType):
@classmethod
def convert_from_xml(cls, str_value):
def convert_from_xml( # pyright: ignore[reportIncompatibleMethodOverride]
cls, str_value: str
) -> RGBColor | str:
if str_value == "auto":
return ST_HexColorAuto.AUTO
return RGBColor.from_string(str_value)

@classmethod
def convert_to_xml(cls, value):
def convert_to_xml( # pyright: ignore[reportIncompatibleMethodOverride]
cls, value: RGBColor
) -> str:
"""Keep alpha hex numerals all uppercase just for consistency."""
# expecting 3-tuple of ints in range 0-255
return "%02X%02X%02X" % value

@classmethod
def validate(cls, value):
def validate(cls, value: Any) -> None:
# must be an RGBColor object ---
if not isinstance(value, RGBColor):
raise ValueError(
Expand Down Expand Up @@ -269,7 +288,7 @@ class ST_Merge(XsdStringEnumeration):

class ST_OnOff(XsdBoolean):
@classmethod
def convert_from_xml(cls, str_value):
def convert_from_xml(cls, str_value: str) -> bool:
if str_value not in ("1", "0", "true", "false", "on", "off"):
raise InvalidXmlError(
"value must be one of '1', '0', 'true', 'false', 'on', or 'o"
Expand All @@ -280,11 +299,11 @@ def convert_from_xml(cls, str_value):

class ST_PositiveCoordinate(XsdLong):
@classmethod
def convert_from_xml(cls, str_value):
def convert_from_xml(cls, str_value: str) -> Length:
return Emu(int(str_value))

@classmethod
def validate(cls, value):
def validate(cls, value: Any) -> None:
cls.validate_int_in_range(value, 0, 27273042316900)


Expand All @@ -294,13 +313,13 @@ class ST_RelationshipId(XsdString):

class ST_SignedTwipsMeasure(XsdInt):
@classmethod
def convert_from_xml(cls, str_value):
def convert_from_xml(cls, str_value: str) -> Length:
if "i" in str_value or "m" in str_value or "p" in str_value:
return ST_UniversalMeasure.convert_from_xml(str_value)
return Twips(int(str_value))

@classmethod
def convert_to_xml(cls, value):
def convert_to_xml(cls, value: int | Length) -> str:
emu = Emu(value)
twips = emu.twips
return str(twips)
Expand All @@ -312,7 +331,7 @@ class ST_String(XsdString):

class ST_TblLayoutType(XsdString):
@classmethod
def validate(cls, value):
def validate(cls, value: Any) -> None:
cls.validate_string(value)
valid_values = ("fixed", "autofit")
if value not in valid_values:
Expand All @@ -321,7 +340,7 @@ def validate(cls, value):

class ST_TblWidth(XsdString):
@classmethod
def validate(cls, value):
def validate(cls, value: Any) -> None:
cls.validate_string(value)
valid_values = ("auto", "dxa", "nil", "pct")
if value not in valid_values:
Expand All @@ -330,13 +349,13 @@ def validate(cls, value):

class ST_TwipsMeasure(XsdUnsignedLong):
@classmethod
def convert_from_xml(cls, str_value):
def convert_from_xml(cls, str_value: str) -> Length:
if "i" in str_value or "m" in str_value or "p" in str_value:
return ST_UniversalMeasure.convert_from_xml(str_value)
return Twips(int(str_value))

@classmethod
def convert_to_xml(cls, value):
def convert_to_xml(cls, value: int | Length) -> str:
emu = Emu(value)
twips = emu.twips
return str(twips)
Expand Down
Loading

0 comments on commit 523328c

Please sign in to comment.