Skip to content

Commit

Permalink
Move messages implementation to dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
elchupanebrej committed Dec 8, 2024
1 parent 21a6c90 commit 2c0b576
Show file tree
Hide file tree
Showing 6 changed files with 1,208 additions and 564 deletions.
15 changes: 9 additions & 6 deletions python/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,22 @@ generate: ## Stub for the ancestor Makefile

generate-real: require install-deps
datamodel-codegen \
--output-model-type "pydantic_v2.BaseModel" \
--output-model-type "dataclasses.dataclass" \
--base-class ".json_converter.JSONDataclassMixin" \
--input $(HERE)../jsonschema/Envelope.json \
--output $(HERE)src/cucumber_messages/messages.py \
--use-generic-container-types \
--input-file-type=jsonschema \
--class-name Envelope \
--target-python-version=3.9 \
--allow-extra-fields \
--allow-population-by-field-name \
--snake-case-field \
--input-file-type=jsonschema \
--class-name Envelope \
--use-generic-container-types \
--use-double-quotes \
--use-exact-imports \
--use-field-description \
--use-union-operator \
--disable-timestamp \
--target-python-version=3.8
--disable-timestamp

require: ## Check requirements for the code generation (python is required)
@python --version >/dev/null 2>&1 || (echo "ERROR: python is required."; exit 1)
Expand Down
1 change: 1 addition & 0 deletions python/src/cucumber_messages/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import json_converter
from .messages import *

ExpressionType = Type1
Expand Down
303 changes: 303 additions & 0 deletions python/src/cucumber_messages/json_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,303 @@
import json
import re
import sys
import types
import typing
from dataclasses import fields, is_dataclass, Field, MISSING
from datetime import datetime, date
from enum import Enum
from typing import Any, Dict, List, Optional, Union, get_args, get_origin, TypeVar, Type, Tuple

T = TypeVar('T')


def camel_to_snake(name: str) -> str:
"""Convert string from camelCase to snake_case."""
# Validate field name - must start with letter or underscore and contain only alphanumeric and underscore
if not name or not (name[0].isalpha() or name[0] == '_') or not all(c.isalnum() or c == '_' for c in name):
raise ValueError(f"Invalid field name: {name}")

# Convert camelCase to snake_case
pattern = re.compile(r'(?<!^)(?=[A-Z])')
return pattern.sub('_', name).lower()


def snake_to_camel(name: str) -> str:
"""Convert string from snake_case to camelCase."""
# Validate field name - must start with letter or underscore and contain only alphanumeric and underscore
if not name or not (name[0].isalpha() or name[0] == '_') or not all(c.isalnum() or c == '_' for c in name):
raise ValueError(f"Invalid field name: {name}")

# Convert snake_case to camelCase
components = name.split('_')
return components[0] + ''.join(x.title() for x in components[1:])


class GenericTypeResolver:
"""Handles resolution of generic type hints and container types."""

def __init__(self, module_scope: types.ModuleType):
self.module_scope = module_scope
self.type_cache = {}

def resolve_container_type(self, container_name: str) -> Type:
container_types = {
'Sequence': typing.Sequence,
'List': list,
'Set': typing.Set,
'FrozenSet': typing.FrozenSet,
'Deque': typing.Deque
}
if container_name not in container_types:
raise ValueError(f"Unsupported container type: {container_name}")
return container_types[container_name]

def parse_generic_type(self, type_str: str) -> Tuple[Type, List[Type]]:
container_name = type_str[:type_str.index('[')].strip()
args_str = type_str[type_str.index('[') + 1:type_str.rindex(']')]

container_type = self.resolve_container_type(container_name)

if ',' in args_str:
arg_types = [self.resolve_type(arg.strip()) for arg in args_str.split(',')]
else:
arg_types = [self.resolve_type(args_str.strip())]

return container_type, arg_types

def resolve_type(self, type_str: str) -> Type:
if type_str in self.type_cache:
return self.type_cache[type_str]

if type_str in {'str', 'int', 'float', 'bool', 'Any'}:
resolved = Any if type_str == 'Any' else eval(type_str)
elif '[' in type_str:
container_type, arg_types = self.parse_generic_type(type_str)
resolved = container_type[tuple(arg_types) if len(arg_types) > 1 else arg_types[0]]
elif '|' in type_str:
resolved = self._resolve_union_type(type_str)
elif hasattr(self.module_scope, type_str):
resolved = getattr(self.module_scope, type_str)
else:
raise ValueError(f"Could not resolve type: {type_str}")

self.type_cache[type_str] = resolved
return resolved

def _resolve_union_type(self, type_str: str) -> Type:
types_str = [t.strip() for t in type_str.split('|')]
resolved_types = [
self.resolve_type(t)
for t in types_str
if t != 'None'
]
return resolved_types[0] if len(resolved_types) == 1 else Union[tuple(resolved_types)]


class DataclassJSONEncoder:
"""Handles encoding of dataclass instances to JSON-compatible dictionaries."""

@classmethod
def encode(cls, obj: Any) -> Any:
if is_dataclass(obj):
return cls._encode_dataclass(obj)
return cls._encode_value(obj)

@classmethod
def _encode_value(cls, value: Any) -> Any:
if isinstance(value, list):
return [cls.encode(item) for item in value]
if isinstance(value, dict):
return {key: cls.encode(val) for key, val in value.items()}
if isinstance(value, (datetime, date)):
return value.isoformat()
if isinstance(value, Enum):
return value.value
if isinstance(value, (str, int, float, bool, type(None))):
return value
raise TypeError(f"Object of type {type(value)} is not JSON serializable")

@classmethod
def _encode_dataclass(cls, obj: Any) -> Dict[str, Any]:
result = {}
for field in fields(obj):
value = getattr(obj, field.name)
if value is not None:
try:
result[snake_to_camel(field.name)] = cls.encode(value)
except ValueError as e:
raise ValueError(f"Error encoding field {field.name}: {str(e)}")
return result


class DataclassJSONDecoder:
"""Handles decoding of JSON data to dataclass instances."""

def __init__(self, module_scope: Optional[types.ModuleType] = None):
self.module_scope = module_scope or sys.modules[__name__]
self.type_resolver = GenericTypeResolver(self.module_scope)

def decode(self, data: Any, target_type: Any) -> Any:
if data is None:
return None

if isinstance(target_type, str):
target_type = self.type_resolver.resolve_type(target_type)

origin, type_args = self._get_type_info(target_type)

if origin is Union:
return self._decode_union(data, type_args)

if self._is_sequence_type(target_type):
return self._decode_sequence(data, type_args)

if origin is dict:
return self._decode_dict(data, type_args)

if is_dataclass(target_type):
return self._decode_dataclass(data, target_type)

if origin is datetime or target_type is datetime:
return datetime.fromisoformat(data)

if origin is date or target_type is date:
return date.fromisoformat(data)

if isinstance(target_type, type) and issubclass(target_type, Enum):
try:
return target_type(data)
except ValueError as e:
raise ValueError(f"{str(e)}. Valid values are: {[e.value for e in target_type]}")

return data

def _get_type_info(self, type_hint: Any) -> Tuple[Any, Tuple]:
origin = get_origin(type_hint)
args = get_args(type_hint)
if origin is None and isinstance(type_hint, type):
origin = type_hint
return origin, args

def _is_sequence_type(self, type_hint: Any) -> bool:
origin = get_origin(type_hint)
return (
origin is list or
origin is typing.Sequence or
(isinstance(origin, type) and issubclass(origin, typing.Sequence))
)

def _decode_sequence(self, data: Any, type_args: Tuple) -> List:
item_type = type_args[0] if type_args else Any
if not isinstance(data, list):
data = [data] if data is not None else []
return [self.decode(item, item_type) for item in data]

def _decode_union(self, data: Any, union_types: Tuple) -> Any:
types = [t for t in union_types if t is not type(None)]
if not types:
return data
if len(types) == 1:
return self.decode(data, types[0])

for t in types:
try:
return self.decode(data, t)
except (ValueError, TypeError):
continue
raise ValueError(f"Could not decode {data} as any type in {types}")

def _decode_dict(self, data: Dict, type_args: Tuple) -> Dict:
if not isinstance(data, dict):
raise TypeError(f"Expected dict but got {type(data)}")
key_type = type_args[0] if type_args else str
value_type = type_args[1] if len(type_args) > 1 else Any
return {
self.decode(k, key_type): self.decode(v, value_type)
for k, v in data.items()
}

def _decode_dataclass(self, data: Dict[str, Any], target_class: Any) -> Any:
"""Decode dictionary into a dataclass instance."""
if not isinstance(data, dict):
raise TypeError(f"Expected dict but got {type(data)}")

field_values = {}
class_fields = {field.name: field for field in fields(target_class)}

# Create a mapping for both camelCase and original field names
field_mapping = {}
for field_name in class_fields:
try:
camel_name = snake_to_camel(field_name)
field_mapping[camel_name] = field_name
field_mapping[field_name] = field_name
except ValueError:
continue

# Process input fields
for key, value in data.items():
# Try to map the input key to a field name
field_name = None

# First try direct mapping
if key in field_mapping:
field_name = field_mapping[key]
else:
# Try converting the key
try:
snake_key = camel_to_snake(key)
if snake_key in class_fields:
field_name = snake_key
except ValueError:
continue

if field_name and field_name in class_fields:
field = class_fields[field_name]
try:
field_values[field_name] = self.decode(value, field.type)
except Exception as e:
raise TypeError(f"Error decoding field {key}: {str(e)}")

# Check for missing required fields
missing_required = [
name for name, field in class_fields.items()
if name not in field_values
and field.default is MISSING
and field.default_factory is MISSING
]

if missing_required:
raise TypeError(f"Missing required arguments: {', '.join(missing_required)}")

# Add default values for missing optional fields
self._apply_default_values(field_values, class_fields)
return target_class(**field_values)

def _apply_default_values(self, field_values: Dict[str, Any], class_fields: Dict[str, Field]) -> None:
for field_name, field in class_fields.items():
if field_name not in field_values:
if field.default is not Field:
field_values[field_name] = field.default
elif field.default_factory is not Field:
field_values[field_name] = field.default_factory()


class JSONDataclassMixin:
"""Mixin class to add JSON conversion capabilities to dataclasses."""

def to_json(self) -> str:
return json.dumps(self.to_dict())

def to_dict(self) -> Dict[str, Any]:
return DataclassJSONEncoder.encode(self)

@classmethod
def from_json(cls: Type[T], json_str: str, module_scope: Optional[types.ModuleType] = None) -> T:
data = json.loads(json_str)
return cls.from_dict(data, module_scope)

@classmethod
def from_dict(cls: Type[T], data: Dict[str, Any], module_scope: Optional[types.ModuleType] = None) -> T:
decoder = DataclassJSONDecoder(module_scope or sys.modules[cls.__module__])
return decoder.decode(data, cls)
Loading

0 comments on commit 2c0b576

Please sign in to comment.