Skip to content

Commit

Permalink
Add rule and field value to violations (#224)
Browse files Browse the repository at this point in the history
Adds the ability to access the captured rule and field value from a
`Violation`.

**This is a breaking change.** The API changes in the following ways:
- `ValidationError` has changed:
    - Code that previously read `ValidationError.violations` should call
`ValidationError.to_proto()` instead to get a `buf.validate.Violations`
message.
    - `ValidationError.errors` was removed. Switch to using
`ValidationError.violations` instead.
    - `ValidationError.violations` provides a list of the new `Violation`
wrapper type instead of a list of `buf.validate.Violation`.
    - The new `Violation` wrapper type contains the `buf.validate.Violation`
message under the `proto` field, as well as `field_value` and
`rule_value` properties that capture the field and rule values,
respectively.
- `Validator.collect_violations` now operates on and returns
`list[Violation]` instead of the protobuf `buf.validate.Violations`
message.

This API mirrors the changes being made in protovalidate-go in
bufbuild/protovalidate-go#154.
  • Loading branch information
jchadwick-buf authored Dec 4, 2024
1 parent 41a4661 commit 5e6ac7b
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 73 deletions.
133 changes: 88 additions & 45 deletions protovalidate/internal/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import datetime
import typing

Expand Down Expand Up @@ -81,7 +82,7 @@ def __getitem__(self, name):
return super().__getitem__(name)


def _msg_to_cel(msg: message.Message) -> dict[str, celtypes.Value]:
def _msg_to_cel(msg: message.Message) -> celtypes.Value:
ctor = _MSG_TYPE_URL_TO_CTOR.get(msg.DESCRIPTOR.full_name)
if ctor is not None:
return ctor(msg)
Expand Down Expand Up @@ -230,43 +231,56 @@ def _set_path_element_map_key(
raise CompilationError(msg)


class Violation:
    """A single constraint violation plus the values that triggered it.

    The protobuf ``Violation`` message is built from the extra keyword
    arguments and stored on ``proto``; the field value and rule value are
    kept alongside it as plain Python attributes.
    """

    # Protobuf message constructed from the remaining keyword arguments.
    proto: validate_pb2.Violation
    # Value of the field the violation was reported for (None when not given).
    field_value: typing.Any
    # Value of the rule the violation was reported for (None when not given).
    rule_value: typing.Any

    def __init__(
        self,
        *,
        field_value: typing.Any = None,
        rule_value: typing.Any = None,
        **kwargs,
    ):
        """Build the wrapped message and record the captured values.

        Args:
            field_value: Captured field value; defaults to None.
            rule_value: Captured rule value; defaults to None.
            **kwargs: Forwarded unchanged to ``validate_pb2.Violation``.
        """
        self.field_value = field_value
        self.rule_value = rule_value
        # Everything that is not a captured value belongs to the proto message.
        self.proto = validate_pb2.Violation(**kwargs)


class ConstraintContext:
"""The state associated with a single constraint evaluation."""

def __init__(self, fail_fast: bool = False, violations: validate_pb2.Violations = None): # noqa: FBT001, FBT002
def __init__(self, fail_fast: bool = False, violations: typing.Optional[list[Violation]] = None): # noqa: FBT001, FBT002
self._fail_fast = fail_fast
if violations is None:
violations = validate_pb2.Violations()
violations = []
self._violations = violations

@property
def fail_fast(self) -> bool:
return self._fail_fast

@property
def violations(self) -> validate_pb2.Violations:
def violations(self) -> list[Violation]:
return self._violations

def add(self, violation: validate_pb2.Violation):
self._violations.violations.append(violation)
def add(self, violation: Violation):
self._violations.append(violation)

def add_errors(self, other_ctx):
self._violations.violations.extend(other_ctx.violations.violations)
self._violations.extend(other_ctx.violations)

def add_field_path_element(self, element: validate_pb2.FieldPathElement):
for violation in self._violations.violations:
violation.field.elements.append(element)
for violation in self._violations:
violation.proto.field.elements.append(element)

def add_rule_path_elements(self, elements: typing.Iterable[validate_pb2.FieldPathElement]):
for violation in self._violations.violations:
violation.rule.elements.extend(elements)
for violation in self._violations:
violation.proto.rule.elements.extend(elements)

@property
def done(self) -> bool:
return self._fail_fast and self.has_errors()

def has_errors(self) -> bool:
return len(self._violations.violations) > 0
return len(self._violations) > 0

def sub_context(self):
return ConstraintContext(self._fail_fast)
Expand All @@ -277,55 +291,67 @@ class ConstraintRules:

def validate(self, ctx: ConstraintContext, message: message.Message): # noqa: ARG002
"""Validate the message against the rules in this constraint."""
ctx.add(validate_pb2.Violation(constraint_id="unimplemented", message="Unimplemented"))
ctx.add(Violation(constraint_id="unimplemented", message="Unimplemented"))


@dataclasses.dataclass
class CelRunner:
    """A compiled CEL program bundled with the rule data used when it is evaluated."""

    # Program produced by env.program(...) in add_rule; evaluated against the activation.
    runner: celpy.Runner
    # Constraint whose expression was compiled; supplies the violation id and message.
    constraint: validate_pb2.Constraint
    # Rule field value read from the rules message; reported as Violation.rule_value.
    rule_value: typing.Optional[typing.Any] = None
    # CEL form of the same rule field; bound to "rule" in the activation.
    rule_cel: typing.Optional[celtypes.Value] = None
    # Rule path passed as the `rule` of any violation this runner produces.
    rule_path: typing.Optional[validate_pb2.FieldPath] = None


class CelConstraintRules(ConstraintRules):
"""A constraint that has rules written in CEL."""

_runners: list[
tuple[
celpy.Runner,
validate_pb2.Constraint,
typing.Optional[celtypes.Value],
typing.Optional[validate_pb2.FieldPath],
]
]
_rules_cel: celtypes.Value = None
_cel: list[CelRunner]
_rules: typing.Optional[message.Message] = None
_rules_cel: typing.Optional[celtypes.Value] = None

def __init__(self, rules: typing.Optional[message.Message]):
self._runners = []
self._cel = []
if rules is not None:
self._rules = rules
self._rules_cel = _msg_to_cel(rules)

def _validate_cel(
self,
ctx: ConstraintContext,
activation: dict[str, typing.Any],
*,
this_value: typing.Optional[typing.Any] = None,
this_cel: typing.Optional[celtypes.Value] = None,
for_key: bool = False,
):
activation: dict[str, celtypes.Value] = {}
if this_cel is not None:
activation["this"] = this_cel
activation["rules"] = self._rules_cel
activation["now"] = celtypes.TimestampType(datetime.datetime.now(tz=datetime.timezone.utc))
for runner, constraint, rule, rule_path in self._runners:
activation["rule"] = rule
result = runner.evaluate(activation)
for cel in self._cel:
activation["rule"] = cel.rule_cel
result = cel.runner.evaluate(activation)
if isinstance(result, celtypes.BoolType):
if not result:
ctx.add(
validate_pb2.Violation(
rule=rule_path,
constraint_id=constraint.id,
message=constraint.message,
Violation(
field_value=this_value,
rule=cel.rule_path,
rule_value=cel.rule_value,
constraint_id=cel.constraint.id,
message=cel.constraint.message,
for_key=for_key,
),
)
elif isinstance(result, celtypes.StringType):
if result:
ctx.add(
validate_pb2.Violation(
rule=rule_path,
constraint_id=constraint.id,
Violation(
field_value=this_value,
rule=cel.rule_path,
rule_value=cel.rule_value,
constraint_id=cel.constraint.id,
message=result,
for_key=for_key,
),
Expand All @@ -339,19 +365,32 @@ def add_rule(
funcs: dict[str, celpy.CELFunction],
rules: validate_pb2.Constraint,
*,
rule: typing.Optional[celtypes.Value] = None,
rule_field: typing.Optional[descriptor.FieldDescriptor] = None,
rule_path: typing.Optional[validate_pb2.FieldPath] = None,
):
ast = env.compile(rules.expression)
prog = env.program(ast, functions=funcs)
self._runners.append((prog, rules, rule, rule_path))
rule_value = None
rule_cel = None
if rule_field is not None and self._rules is not None:
rule_value = _proto_message_get_field(self._rules, rule_field)
rule_cel = _field_to_cel(self._rules, rule_field)
self._cel.append(
CelRunner(
runner=prog,
constraint=rules,
rule_value=rule_value,
rule_cel=rule_cel,
rule_path=rule_path,
)
)


class MessageConstraintRules(CelConstraintRules):
"""Message-level rules."""

def validate(self, ctx: ConstraintContext, message: message.Message):
self._validate_cel(ctx, {"this": _msg_to_cel(message)})
self._validate_cel(ctx, this_cel=_msg_to_cel(message))


def check_field_type(field: descriptor.FieldDescriptor, expected: int, wrapper_name: typing.Optional[str] = None):
Expand Down Expand Up @@ -445,7 +484,7 @@ def __init__(
env,
funcs,
cel,
rule=_field_to_cel(rules, list_field),
rule_field=list_field,
rule_path=validate_pb2.FieldPath(
elements=[
_field_to_element(list_field),
Expand All @@ -465,13 +504,14 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
if _is_empty_field(message, self._field):
if self._required:
ctx.add(
validate_pb2.Violation(
Violation(
field=validate_pb2.FieldPath(
elements=[
_field_to_element(self._field),
],
),
rule=FieldConstraintRules._required_rule_path,
rule_value=self._required,
constraint_id="required",
message="value is required",
),
Expand All @@ -485,15 +525,15 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
return
sub_ctx = ctx.sub_context()
self._validate_value(sub_ctx, val)
self._validate_cel(sub_ctx, {"this": cel_val})
self._validate_cel(sub_ctx, this_value=_proto_message_get_field(message, self._field), this_cel=cel_val)
if sub_ctx.has_errors():
element = _field_to_element(self._field)
sub_ctx.add_field_path_element(element)
ctx.add_errors(sub_ctx)

def validate_item(self, ctx: ConstraintContext, val: typing.Any, *, for_key: bool = False):
self._validate_value(ctx, val, for_key=for_key)
self._validate_cel(ctx, {"this": _scalar_field_value_to_cel(val, self._field)}, for_key=for_key)
self._validate_cel(ctx, this_value=val, this_cel=_scalar_field_value_to_cel(val, self._field), for_key=for_key)

def _validate_value(self, ctx: ConstraintContext, val: typing.Any, *, for_key: bool = False):
pass
Expand Down Expand Up @@ -546,17 +586,19 @@ def _validate_value(self, ctx: ConstraintContext, value: any_pb2.Any, *, for_key
if len(self._in) > 0:
if value.type_url not in self._in:
ctx.add(
validate_pb2.Violation(
Violation(
rule=AnyConstraintRules._in_rule_path,
rule_value=self._in,
constraint_id="any.in",
message="type URL must be in the allow list",
for_key=for_key,
)
)
if value.type_url in self._not_in:
ctx.add(
validate_pb2.Violation(
Violation(
rule=AnyConstraintRules._not_in_rule_path,
rule_value=self._not_in,
constraint_id="any.not_in",
message="type URL must not be in the block list",
for_key=for_key,
Expand Down Expand Up @@ -603,13 +645,14 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
value = getattr(message, self._field.name)
if value not in self._field.enum_type.values_by_number:
ctx.add(
validate_pb2.Violation(
Violation(
field=validate_pb2.FieldPath(
elements=[
_field_to_element(self._field),
],
),
rule=EnumConstraintRules._defined_only_rule_path,
rule_value=self._defined_only,
constraint_id="enum.defined_only",
message="value must be one of the defined enum values",
),
Expand Down Expand Up @@ -742,7 +785,7 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
if not message.WhichOneof(self._oneof.name):
if self.required:
ctx.add(
validate_pb2.Violation(
Violation(
field=validate_pb2.FieldPath(
elements=[_oneof_to_element(self._oneof)],
),
Expand Down
43 changes: 28 additions & 15 deletions protovalidate/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import typing

from google.protobuf import message

from buf.validate import validate_pb2 # type: ignore
Expand All @@ -20,6 +22,7 @@

CompilationError = _constraints.CompilationError
Violations = validate_pb2.Violations
Violation = _constraints.Violation


class Validator:
Expand Down Expand Up @@ -54,7 +57,7 @@ def validate(
ValidationError: If the message is invalid.
"""
violations = self.collect_violations(message, fail_fast=fail_fast)
if violations.violations:
if len(violations) > 0:
msg = f"invalid {message.DESCRIPTOR.name}"
raise ValidationError(msg, violations)

Expand All @@ -63,8 +66,8 @@ def collect_violations(
message: message.Message,
*,
fail_fast: bool = False,
into: validate_pb2.Violations = None,
) -> validate_pb2.Violations:
into: typing.Optional[list[Violation]] = None,
) -> list[Violation]:
"""
Validates the given message against the static constraints defined in
the message's descriptor. Compared to validate, collect_violations is
Expand All @@ -84,12 +87,12 @@ def collect_violations(
constraint.validate(ctx, message)
if ctx.done:
break
for violation in ctx.violations.violations:
if violation.HasField("field"):
violation.field.elements.reverse()
if violation.HasField("rule"):
violation.rule.elements.reverse()
violation.field_path = field_path.string(violation.field)
for violation in ctx.violations:
if violation.proto.HasField("field"):
violation.proto.field.elements.reverse()
if violation.proto.HasField("rule"):
violation.proto.rule.elements.reverse()
violation.proto.field_path = field_path.string(violation.proto.field)
return ctx.violations


Expand All @@ -98,15 +101,25 @@ class ValidationError(ValueError):
An error raised when a message fails to validate.
"""

violations: validate_pb2.Violations
_violations: list[_constraints.Violation]

def __init__(self, msg: str, violations: validate_pb2.Violations):
def __init__(self, msg: str, violations: list[_constraints.Violation]):
super().__init__(msg)
self.violations = violations
self._violations = violations

def to_proto(self) -> validate_pb2.Violations:
"""
Provides the Protobuf form of the validation errors.
"""
result = validate_pb2.Violations()
for violation in self._violations:
result.violations.append(violation.proto)
return result

def errors(self) -> list[validate_pb2.Violation]:
@property
def violations(self) -> list[Violation]:
"""
Returns the validation errors as a simple Python list, rather than the
Provides the validation errors as a simple Python list, rather than the
Protobuf-specific collection type used by Violations.
"""
return list(self.violations.violations)
return self._violations
4 changes: 3 additions & 1 deletion tests/conformance/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def run_test_case(tc: typing.Any, result: typing.Optional[harness_pb2.TestResult
result = harness_pb2.TestResult()
# Run the validator
try:
protovalidate.collect_violations(tc, into=result.validation_error)
violations = protovalidate.collect_violations(tc)
for violation in violations:
result.validation_error.violations.append(violation.proto)
if len(result.validation_error.violations) == 0:
result.success = True
except celpy.CELEvalError as e:
Expand Down
Loading

0 comments on commit 5e6ac7b

Please sign in to comment.