diff --git a/protovalidate/internal/constraints.py b/protovalidate/internal/constraints.py index 9e86876..0c47603 100644 --- a/protovalidate/internal/constraints.py +++ b/protovalidate/internal/constraints.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses import datetime import typing @@ -81,7 +82,7 @@ def __getitem__(self, name): return super().__getitem__(name) -def _msg_to_cel(msg: message.Message) -> dict[str, celtypes.Value]: +def _msg_to_cel(msg: message.Message) -> celtypes.Value: ctor = _MSG_TYPE_URL_TO_CTOR.get(msg.DESCRIPTOR.full_name) if ctor is not None: return ctor(msg) @@ -230,13 +231,26 @@ def _set_path_element_map_key( raise CompilationError(msg) +class Violation: + """A singular constraint violation.""" + + proto: validate_pb2.Violation + field_value: typing.Any + rule_value: typing.Any + + def __init__(self, *, field_value: typing.Any = None, rule_value: typing.Any = None, **kwargs): + self.proto = validate_pb2.Violation(**kwargs) + self.field_value = field_value + self.rule_value = rule_value + + class ConstraintContext: """The state associated with a single constraint evaluation.""" - def __init__(self, fail_fast: bool = False, violations: validate_pb2.Violations = None): # noqa: FBT001, FBT002 + def __init__(self, fail_fast: bool = False, violations: typing.Optional[list[Violation]] = None): # noqa: FBT001, FBT002 self._fail_fast = fail_fast if violations is None: - violations = validate_pb2.Violations() + violations = [] self._violations = violations @property @@ -244,29 +258,29 @@ def fail_fast(self) -> bool: return self._fail_fast @property - def violations(self) -> validate_pb2.Violations: + def violations(self) -> list[Violation]: return self._violations - def add(self, violation: validate_pb2.Violation): - self._violations.violations.append(violation) + def add(self, violation: Violation): + self._violations.append(violation) def add_errors(self, other_ctx): - self._violations.violations.extend(other_ctx.violations.violations) + self._violations.extend(other_ctx.violations) def add_field_path_element(self, element: validate_pb2.FieldPathElement): - for violation in self._violations.violations: - violation.field.elements.append(element) + for violation in self._violations: + violation.proto.field.elements.append(element) def add_rule_path_elements(self, elements: typing.Iterable[validate_pb2.FieldPathElement]): - for violation in self._violations.violations: - violation.rule.elements.extend(elements) + for violation in self._violations: + violation.proto.rule.elements.extend(elements) @property def done(self) -> bool: return self._fail_fast and self.has_errors() def has_errors(self) -> bool: - return len(self._violations.violations) > 0 + return len(self._violations) > 0 def sub_context(self): return ConstraintContext(self._fail_fast) @@ -277,55 +291,67 @@ class ConstraintRules: def validate(self, ctx: ConstraintContext, message: message.Message): # noqa: ARG002 """Validate the message against the rules in this constraint.""" - ctx.add(validate_pb2.Violation(constraint_id="unimplemented", message="Unimplemented")) + ctx.add(Violation(constraint_id="unimplemented", message="Unimplemented")) + + +@dataclasses.dataclass +class CelRunner: + runner: celpy.Runner + constraint: validate_pb2.Constraint + rule_value: typing.Optional[typing.Any] = None + rule_cel: typing.Optional[celtypes.Value] = None + rule_path: typing.Optional[validate_pb2.FieldPath] = None class CelConstraintRules(ConstraintRules): """A constraint that has rules written in CEL.""" - _runners: list[ - tuple[ - celpy.Runner, - validate_pb2.Constraint, - typing.Optional[celtypes.Value], - typing.Optional[validate_pb2.FieldPath], - ] - ] - _rules_cel: celtypes.Value = None + _cel: list[CelRunner] + _rules: typing.Optional[message.Message] = None + _rules_cel: typing.Optional[celtypes.Value] = None def __init__(self, rules: typing.Optional[message.Message]): - self._runners = [] + self._cel = [] if rules is not None: + self._rules = rules self._rules_cel = _msg_to_cel(rules) def _validate_cel( self, ctx: ConstraintContext, - activation: dict[str, typing.Any], *, + this_value: typing.Optional[typing.Any] = None, + this_cel: typing.Optional[celtypes.Value] = None, for_key: bool = False, ): + activation: dict[str, celtypes.Value] = {} + if this_cel is not None: + activation["this"] = this_cel activation["rules"] = self._rules_cel activation["now"] = celtypes.TimestampType(datetime.datetime.now(tz=datetime.timezone.utc)) - for runner, constraint, rule, rule_path in self._runners: - activation["rule"] = rule - result = runner.evaluate(activation) + for cel in self._cel: + activation["rule"] = cel.rule_cel + result = cel.runner.evaluate(activation) if isinstance(result, celtypes.BoolType): if not result: ctx.add( - validate_pb2.Violation( - rule=rule_path, - constraint_id=constraint.id, - message=constraint.message, + Violation( + field_value=this_value, + rule=cel.rule_path, + rule_value=cel.rule_value, + constraint_id=cel.constraint.id, + message=cel.constraint.message, for_key=for_key, ), ) elif isinstance(result, celtypes.StringType): if result: ctx.add( - validate_pb2.Violation( - rule=rule_path, - constraint_id=constraint.id, + Violation( + field_value=this_value, + rule=cel.rule_path, + rule_value=cel.rule_value, + constraint_id=cel.constraint.id, message=result, for_key=for_key, ), @@ -339,19 +365,32 @@ def add_rule( funcs: dict[str, celpy.CELFunction], rules: validate_pb2.Constraint, *, - rule: typing.Optional[celtypes.Value] = None, + rule_field: typing.Optional[descriptor.FieldDescriptor] = None, rule_path: typing.Optional[validate_pb2.FieldPath] = None, ): ast = env.compile(rules.expression) prog = env.program(ast, functions=funcs) - self._runners.append((prog, rules, rule, rule_path)) + rule_value = None + rule_cel = None + if rule_field is not None and self._rules is not None: + rule_value = _proto_message_get_field(self._rules, rule_field) + rule_cel = _field_to_cel(self._rules, rule_field) + self._cel.append( + CelRunner( + runner=prog, + constraint=rules, + rule_value=rule_value, + rule_cel=rule_cel, + rule_path=rule_path, + ) + ) class MessageConstraintRules(CelConstraintRules): """Message-level rules.""" def validate(self, ctx: ConstraintContext, message: message.Message): - self._validate_cel(ctx, {"this": _msg_to_cel(message)}) + self._validate_cel(ctx, this_cel=_msg_to_cel(message)) def check_field_type(field: descriptor.FieldDescriptor, expected: int, wrapper_name: typing.Optional[str] = None): @@ -445,7 +484,7 @@ def __init__( env, funcs, cel, - rule=_field_to_cel(rules, list_field), + rule_field=list_field, rule_path=validate_pb2.FieldPath( elements=[ _field_to_element(list_field), @@ -465,13 +504,14 @@ def validate(self, ctx: ConstraintContext, message: message.Message): if _is_empty_field(message, self._field): if self._required: ctx.add( - validate_pb2.Violation( + Violation( field=validate_pb2.FieldPath( elements=[ _field_to_element(self._field), ], ), rule=FieldConstraintRules._required_rule_path, + rule_value=self._required, constraint_id="required", message="value is required", ), @@ -485,7 +525,7 @@ def validate(self, ctx: ConstraintContext, message: message.Message): return sub_ctx = ctx.sub_context() self._validate_value(sub_ctx, val) - self._validate_cel(sub_ctx, {"this": cel_val}) + self._validate_cel(sub_ctx, this_value=_proto_message_get_field(message, self._field), this_cel=cel_val) if sub_ctx.has_errors(): element = _field_to_element(self._field) sub_ctx.add_field_path_element(element) @@ -493,7 +533,7 @@ def validate(self, ctx: ConstraintContext, message: message.Message): def validate_item(self, ctx: ConstraintContext, val: typing.Any, *, for_key: bool = False): self._validate_value(ctx, val, for_key=for_key) - self._validate_cel(ctx, {"this": _scalar_field_value_to_cel(val, self._field)}, for_key=for_key) + self._validate_cel(ctx, this_value=val, this_cel=_scalar_field_value_to_cel(val, self._field), for_key=for_key) def _validate_value(self, ctx: ConstraintContext, val: typing.Any, *, for_key: bool = False): pass @@ -546,8 +586,9 @@ def _validate_value(self, ctx: ConstraintContext, value: any_pb2.Any, *, for_key if len(self._in) > 0: if value.type_url not in self._in: ctx.add( - validate_pb2.Violation( + Violation( rule=AnyConstraintRules._in_rule_path, + rule_value=self._in, constraint_id="any.in", message="type URL must be in the allow list", for_key=for_key, @@ -555,8 +596,9 @@ def _validate_value(self, ctx: ConstraintContext, value: any_pb2.Any, *, for_key ) if value.type_url in self._not_in: ctx.add( - validate_pb2.Violation( + Violation( rule=AnyConstraintRules._not_in_rule_path, + rule_value=self._not_in, constraint_id="any.not_in", message="type URL must not be in the block list", for_key=for_key, @@ -603,13 +645,14 @@ def validate(self, ctx: ConstraintContext, message: message.Message): value = getattr(message, self._field.name) if value not in self._field.enum_type.values_by_number: ctx.add( - validate_pb2.Violation( + Violation( field=validate_pb2.FieldPath( elements=[ _field_to_element(self._field), ], ), rule=EnumConstraintRules._defined_only_rule_path, + rule_value=self._defined_only, constraint_id="enum.defined_only", message="value must be one of the defined enum values", ), @@ -742,7 +785,7 @@ def validate(self, ctx: ConstraintContext, message: message.Message): if not message.WhichOneof(self._oneof.name): if self.required: ctx.add( - validate_pb2.Violation( + Violation( field=validate_pb2.FieldPath( elements=[_oneof_to_element(self._oneof)], ), diff --git a/protovalidate/validator.py b/protovalidate/validator.py index 12253cb..0efdc38 100644 --- a/protovalidate/validator.py +++ b/protovalidate/validator.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from google.protobuf import message from buf.validate import validate_pb2 # type: ignore @@ -20,6 +22,7 @@ CompilationError = _constraints.CompilationError Violations = validate_pb2.Violations +Violation = _constraints.Violation class Validator: @@ -54,7 +57,7 @@ def validate( ValidationError: If the message is invalid. """ violations = self.collect_violations(message, fail_fast=fail_fast) - if violations.violations: + if len(violations) > 0: msg = f"invalid {message.DESCRIPTOR.name}" raise ValidationError(msg, violations) @@ -63,8 +66,8 @@ def collect_violations( message: message.Message, *, fail_fast: bool = False, - into: validate_pb2.Violations = None, - ) -> validate_pb2.Violations: + into: typing.Optional[list[Violation]] = None, + ) -> list[Violation]: """ Validates the given message against the static constraints defined in the message's descriptor. Compared to validate, collect_violations is @@ -84,12 +87,12 @@ def collect_violations( constraint.validate(ctx, message) if ctx.done: break - for violation in ctx.violations.violations: - if violation.HasField("field"): - violation.field.elements.reverse() - if violation.HasField("rule"): - violation.rule.elements.reverse() - violation.field_path = field_path.string(violation.field) + for violation in ctx.violations: + if violation.proto.HasField("field"): + violation.proto.field.elements.reverse() + if violation.proto.HasField("rule"): + violation.proto.rule.elements.reverse() + violation.proto.field_path = field_path.string(violation.proto.field) return ctx.violations @@ -98,15 +101,25 @@ class ValidationError(ValueError): An error raised when a message fails to validate. """ - violations: validate_pb2.Violations + _violations: list[_constraints.Violation] - def __init__(self, msg: str, violations: validate_pb2.Violations): + def __init__(self, msg: str, violations: list[_constraints.Violation]): super().__init__(msg) - self.violations = violations + self._violations = violations + + def to_proto(self) -> validate_pb2.Violations: + """ + Provides the Protobuf form of the validation errors. + """ + result = validate_pb2.Violations() + for violation in self._violations: + result.violations.append(violation.proto) + return result - def errors(self) -> list[validate_pb2.Violation]: + @property + def violations(self) -> list[Violation]: """ - Returns the validation errors as a simple Python list, rather than the + Provides the validation errors as a simple Python list, rather than the Protobuf-specific collection type used by Violations. """ - return list(self.violations.violations) + return self._violations diff --git a/tests/conformance/runner.py b/tests/conformance/runner.py index 012fe6b..d15a7a9 100644 --- a/tests/conformance/runner.py +++ b/tests/conformance/runner.py @@ -62,7 +62,9 @@ def run_test_case(tc: typing.Any, result: typing.Optional[harness_pb2.TestResult result = harness_pb2.TestResult() # Run the validator try: - protovalidate.collect_violations(tc, into=result.validation_error) + violations = protovalidate.collect_violations(tc) + for violation in violations: + result.validation_error.violations.append(violation.proto) if len(result.validation_error.violations) == 0: result.success = True except celpy.CELEvalError as e: diff --git a/tests/validate_test.py b/tests/validate_test.py index 9939e77..2642b47 100644 --- a/tests/validate_test.py +++ b/tests/validate_test.py @@ -23,23 +23,27 @@ def test_ninf(self): msg = numbers_pb2.DoubleFinite() msg.val = float("-inf") violations = protovalidate.collect_violations(msg) - self.assertEqual(len(violations.violations), 1) - self.assertEqual(violations.violations[0].constraint_id, "double.finite") + self.assertEqual(len(violations), 1) + self.assertEqual(violations[0].proto.constraint_id, "double.finite") + self.assertEqual(violations[0].field_value, msg.val) + self.assertEqual(violations[0].rule_value, True) def test_map_key(self): msg = maps_pb2.MapKeys() msg.val[1] = "a" violations = protovalidate.collect_violations(msg) - self.assertEqual(len(violations.violations), 1) - self.assertEqual(violations.violations[0].field_path, "val[1]") - self.assertEqual(violations.violations[0].for_key, True) + self.assertEqual(len(violations), 1) + self.assertEqual(violations[0].proto.field_path, "val[1]") + self.assertEqual(violations[0].proto.for_key, True) + self.assertEqual(violations[0].field_value, 1) + self.assertEqual(violations[0].rule_value, 0) def test_sfixed64(self): msg = numbers_pb2.SFixed64ExLTGT(val=11) protovalidate.validate(msg) violations = protovalidate.collect_violations(msg) - self.assertEqual(len(violations.violations), 0) + self.assertEqual(len(violations), 0) def test_oneofs(self): msg1 = oneofs_pb2.Oneof() @@ -52,7 +56,7 @@ def test_oneofs(self): violations = protovalidate.collect_violations(msg1) protovalidate.collect_violations(msg2, into=violations) - assert len(violations.violations) == 0 + assert len(violations) == 0 def test_repeated(self): msg = repeated_pb2.RepeatedEmbedSkip() @@ -60,23 +64,23 @@ def test_repeated(self): protovalidate.validate(msg) violations = protovalidate.collect_violations(msg) - assert len(violations.violations) == 0 + assert len(violations) == 0 def test_maps(self): msg = maps_pb2.MapMinMax() try: protovalidate.validate(msg) except protovalidate.ValidationError as e: - assert len(e.errors()) == 1 - assert len(e.violations.violations) == 1 + assert len(e.violations) == 1 + assert len(e.to_proto().violations) == 1 assert str(e) == "invalid MapMinMax" violations = protovalidate.collect_violations(msg) - assert len(violations.violations) == 1 + assert len(violations) == 1 def test_timestamp(self): msg = wkt_timestamp_pb2.TimestampGTNow() protovalidate.validate(msg) violations = protovalidate.collect_violations(msg) - assert len(violations.violations) == 0 + assert len(violations) == 0