Skip to content

Commit

Permalink
Add unknown field checking
Browse files Browse the repository at this point in the history
  • Loading branch information
jchadwick-buf committed Sep 17, 2024
1 parent ab3d16e commit 5f1d8aa
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 14 deletions.
1 change: 1 addition & 0 deletions buf/validate/conformance/runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class TestRunner {
google::protobuf::DescriptorPool::generated_pool())
: descriptorPool_(descriptorPool), validatorFactory_(ValidatorFactory::New().value()) {
validatorFactory_->SetMessageFactory(&messageFactory_, descriptorPool_);
validatorFactory_->SetAllowUnknownFields(false);
}

harness::TestConformanceResponse runTest(const harness::TestConformanceRequest& request);
Expand Down
12 changes: 11 additions & 1 deletion buf/validate/internal/cel_rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@ namespace buf::validate::internal {
template <typename R>
absl::Status BuildCelRules(
std::unique_ptr<MessageFactory>& messageFactory,
bool allowUnknownFields,
google::protobuf::Arena* arena,
google::api::expr::runtime::CelExpressionBuilder& builder,
const R& rules,
CelConstraintRules& result) {
// Look for constraints on the set fields.
std::vector<const google::protobuf::FieldDescriptor*> fields;
google::protobuf::Message* reparsedRules{};
if (messageFactory) {
if (messageFactory && rules.unknown_fields().field_count() > 0) {
reparsedRules = messageFactory->messageFactory()
->GetPrototype(messageFactory->descriptorPool()->FindMessageTypeByName(
rules.GetTypeName()))
Expand All @@ -43,9 +44,18 @@ absl::Status BuildCelRules(
}
}
if (reparsedRules) {
if (!allowUnknownFields &&
!reparsedRules->GetReflection()->GetUnknownFields(*reparsedRules).empty()) {
return absl::FailedPreconditionError(
absl::StrCat("unknown constraints in ", reparsedRules->GetTypeName()));
}
result.setRules(reparsedRules, arena);
reparsedRules->GetReflection()->ListFields(*reparsedRules, &fields);
} else {
if (!allowUnknownFields && !R::GetReflection()->GetUnknownFields(rules).empty()) {
return absl::FailedPreconditionError(
absl::StrCat("unknown constraints in ", rules.GetTypeName()));
}
result.setRules(&rules, arena);
R::GetReflection()->ListFields(rules, &fields);
}
Expand Down
49 changes: 41 additions & 8 deletions buf/validate/internal/field_rules.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
namespace buf::validate::internal {
absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
std::unique_ptr<MessageFactory>& messageFactory,
bool allowUnknownFields,
google::protobuf::Arena* arena,
google::api::expr::runtime::CelExpressionBuilder& builder,
const google::protobuf::FieldDescriptor* field,
Expand All @@ -32,6 +33,7 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
case FieldConstraints::kBool:
rules_or = NewScalarFieldRules(
messageFactory,
allowUnknownFields,
arena,
builder,
field,
Expand All @@ -43,6 +45,7 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
case FieldConstraints::kFloat:
rules_or = NewScalarFieldRules(
messageFactory,
allowUnknownFields,
arena,
builder,
field,
Expand All @@ -54,6 +57,7 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
case FieldConstraints::kDouble:
rules_or = NewScalarFieldRules(
messageFactory,
allowUnknownFields,
arena,
builder,
field,
Expand All @@ -65,6 +69,7 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
case FieldConstraints::kInt32:
rules_or = NewScalarFieldRules(
messageFactory,
allowUnknownFields,
arena,
builder,
field,
Expand All @@ -76,6 +81,7 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
case FieldConstraints::kInt64:
rules_or = NewScalarFieldRules(
messageFactory,
allowUnknownFields,
arena,
builder,
field,
Expand All @@ -87,6 +93,7 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
case FieldConstraints::kUint32:
rules_or = NewScalarFieldRules(
messageFactory,
allowUnknownFields,
arena,
builder,
field,
Expand All @@ -98,6 +105,7 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
case FieldConstraints::kUint64:
rules_or = NewScalarFieldRules(
messageFactory,
allowUnknownFields,
arena,
builder,
field,
Expand All @@ -109,6 +117,7 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
case FieldConstraints::kSint32:
rules_or = NewScalarFieldRules(
messageFactory,
allowUnknownFields,
arena,
builder,
field,
Expand All @@ -119,6 +128,7 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
case FieldConstraints::kSint64:
rules_or = NewScalarFieldRules(
messageFactory,
allowUnknownFields,
arena,
builder,
field,
Expand All @@ -129,6 +139,7 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
case FieldConstraints::kFixed32:
rules_or = NewScalarFieldRules(
messageFactory,
allowUnknownFields,
arena,
builder,
field,
Expand All @@ -139,6 +150,7 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
case FieldConstraints::kFixed64:
rules_or = NewScalarFieldRules(
messageFactory,
allowUnknownFields,
arena,
builder,
field,
Expand All @@ -149,6 +161,7 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
case FieldConstraints::kSfixed32:
rules_or = NewScalarFieldRules(
messageFactory,
allowUnknownFields,
arena,
builder,
field,
Expand All @@ -159,6 +172,7 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
case FieldConstraints::kSfixed64:
rules_or = NewScalarFieldRules(
messageFactory,
allowUnknownFields,
arena,
builder,
field,
Expand All @@ -169,6 +183,7 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
case FieldConstraints::kString:
rules_or = NewScalarFieldRules(
messageFactory,
allowUnknownFields,
arena,
builder,
field,
Expand All @@ -180,6 +195,7 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
case FieldConstraints::kBytes:
rules_or = NewScalarFieldRules(
messageFactory,
allowUnknownFields,
arena,
builder,
field,
Expand All @@ -193,6 +209,7 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
auto status = BuildScalarFieldRules(
*rules_or.value(),
messageFactory,
allowUnknownFields,
arena,
builder,
field,
Expand All @@ -211,7 +228,8 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
return absl::InvalidArgumentError("duration field validator on non-duration field");
} else {
auto result = std::make_unique<FieldConstraintRules>(field, fieldLvl);
auto status = BuildCelRules(messageFactory, arena, builder, fieldLvl.duration(), *result);
auto status = BuildCelRules(
messageFactory, allowUnknownFields, arena, builder, fieldLvl.duration(), *result);
if (!status.ok()) {
rules_or = status;
} else {
Expand All @@ -226,7 +244,8 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
return absl::InvalidArgumentError("timestamp field validator on non-timestamp field");
} else {
auto result = std::make_unique<FieldConstraintRules>(field, fieldLvl);
auto status = BuildCelRules(messageFactory, arena, builder, fieldLvl.timestamp(), *result);
auto status = BuildCelRules(
messageFactory, allowUnknownFields, arena, builder, fieldLvl.timestamp(), *result);
if (!status.ok()) {
rules_or = status;
} else {
Expand All @@ -242,15 +261,21 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
} else {
std::unique_ptr<FieldConstraintRules> items;
if (fieldLvl.repeated().has_items()) {
auto items_or =
NewFieldRules(messageFactory, arena, builder, field, fieldLvl.repeated().items());
auto items_or = NewFieldRules(
messageFactory,
allowUnknownFields,
arena,
builder,
field,
fieldLvl.repeated().items());
if (!items_or.ok()) {
return items_or.status();
}
items = std::move(items_or).value();
}
auto result = std::make_unique<RepeatedConstraintRules>(field, fieldLvl, std::move(items));
auto status = BuildCelRules(messageFactory, arena, builder, fieldLvl.repeated(), *result);
auto status = BuildCelRules(
messageFactory, allowUnknownFields, arena, builder, fieldLvl.repeated(), *result);
if (!status.ok()) {
rules_or = status;
} else {
Expand All @@ -263,12 +288,18 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
return absl::InvalidArgumentError("map field validator on non-map field");
} else {
auto keyRulesOr = NewFieldRules(
messageFactory, arena, builder, field->message_type()->field(0), fieldLvl.map().keys());
messageFactory,
allowUnknownFields,
arena,
builder,
field->message_type()->field(0),
fieldLvl.map().keys());
if (!keyRulesOr.ok()) {
return keyRulesOr.status();
}
auto valueRulesOr = NewFieldRules(
messageFactory,
allowUnknownFields,
arena,
builder,
field->message_type()->field(1),
Expand All @@ -278,7 +309,8 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
}
auto result = std::make_unique<MapConstraintRules>(
field, fieldLvl, std::move(keyRulesOr).value(), std::move(valueRulesOr).value());
auto status = BuildCelRules(messageFactory, arena, builder, fieldLvl.map(), *result);
auto status = BuildCelRules(
messageFactory, allowUnknownFields, arena, builder, fieldLvl.map(), *result);
if (!status.ok()) {
rules_or = status;
} else {
Expand All @@ -292,7 +324,8 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
return absl::InvalidArgumentError("any field validator on non-any field");
} else {
auto result = std::make_unique<FieldConstraintRules>(field, fieldLvl, &fieldLvl.any());
auto status = BuildCelRules(messageFactory, arena, builder, fieldLvl.any(), *result);
auto status = BuildCelRules(
messageFactory, allowUnknownFields, arena, builder, fieldLvl.any(), *result);
if (!status.ok()) {
rules_or = status;
} else {
Expand Down
16 changes: 14 additions & 2 deletions buf/validate/internal/field_rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ template <typename R>
absl::Status BuildScalarFieldRules(
FieldConstraintRules& result,
std::unique_ptr<MessageFactory>& messageFactory,
bool allowUnknownFields,
google::protobuf::Arena* arena,
google::api::expr::runtime::CelExpressionBuilder& builder,
const google::protobuf::FieldDescriptor* field,
Expand All @@ -48,12 +49,13 @@ absl::Status BuildScalarFieldRules(
google::protobuf::FieldDescriptor::TypeName(expectedType)));
}
}
return BuildCelRules(messageFactory, arena, builder, rules, result);
return BuildCelRules(messageFactory, allowUnknownFields, arena, builder, rules, result);
}

template <typename R>
absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewScalarFieldRules(
std::unique_ptr<MessageFactory>& messageFactory,
bool allowUnknownFields,
google::protobuf::Arena* arena,
google::api::expr::runtime::CelExpressionBuilder& builder,
const google::protobuf::FieldDescriptor* field,
Expand All @@ -63,7 +65,16 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewScalarFieldRules(
std::string_view wrapperName = "") {
auto result = std::make_unique<FieldConstraintRules>(field, fieldLvl);
auto status = BuildScalarFieldRules(
*result, messageFactory, arena, builder, field, fieldLvl, rules, expectedType, wrapperName);
*result,
messageFactory,
allowUnknownFields,
arena,
builder,
field,
fieldLvl,
rules,
expectedType,
wrapperName);
if (!status.ok()) {
return status;
}
Expand All @@ -72,6 +83,7 @@ absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewScalarFieldRules(

absl::StatusOr<std::unique_ptr<FieldConstraintRules>> NewFieldRules(
std::unique_ptr<MessageFactory>& messageFactory,
bool allowUnknownFields,
google::protobuf::Arena* arena,
google::api::expr::runtime::CelExpressionBuilder& builder,
const google::protobuf::FieldDescriptor* field,
Expand Down
4 changes: 3 additions & 1 deletion buf/validate/internal/message_rules.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ absl::StatusOr<std::unique_ptr<MessageConstraintRules>> BuildMessageRules(

Constraints NewMessageConstraints(
std::unique_ptr<MessageFactory>& messageFactory,
bool allowUnknownFields,
google::protobuf::Arena* arena,
google::api::expr::runtime::CelExpressionBuilder& builder,
const google::protobuf::Descriptor* descriptor) {
Expand All @@ -54,7 +55,8 @@ Constraints NewMessageConstraints(
continue;
}
const auto& fieldLvl = field->options().GetExtension(buf::validate::field);
auto rules_or = NewFieldRules(messageFactory, arena, builder, field, fieldLvl);
auto rules_or =
NewFieldRules(messageFactory, allowUnknownFields, arena, builder, field, fieldLvl);
if (!rules_or.ok()) {
return rules_or.status();
}
Expand Down
1 change: 1 addition & 0 deletions buf/validate/internal/message_rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ using Constraints = absl::StatusOr<std::vector<std::unique_ptr<ConstraintRules>>

Constraints NewMessageConstraints(
std::unique_ptr<MessageFactory>& messageFactory,
bool allowUnknownFields,
google::protobuf::Arena* arena,
google::api::expr::runtime::CelExpressionBuilder& builder,
const google::protobuf::Descriptor* descriptor);
Expand Down
4 changes: 2 additions & 2 deletions buf/validate/validator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ absl::Status ValidatorFactory::Add(const google::protobuf::Descriptor* desc) {
auto status =
constraints_
.emplace(
desc, internal::NewMessageConstraints(messageFactory_, &arena_, *builder_, desc))
desc, internal::NewMessageConstraints(messageFactory_, allowUnknownFields_, &arena_, *builder_, desc))
.first->second.status();
if (!status.ok()) {
return status;
Expand Down Expand Up @@ -179,7 +179,7 @@ const internal::Constraints* ValidatorFactory::GetMessageConstraints(
}
return &constraints_
.emplace(
desc, internal::NewMessageConstraints(messageFactory_, &arena_, *builder_, desc))
desc, internal::NewMessageConstraints(messageFactory_, allowUnknownFields_, &arena_, *builder_, desc))
.first->second;
}

Expand Down
6 changes: 6 additions & 0 deletions buf/validate/validator.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,17 @@ class ValidatorFactory {
messageFactory_ = std::make_unique<internal::MessageFactory>(messageFactory, descriptorPool);
}

/// Set whether or not unknown constraint fields will be tolerated. Defaults to false.
void SetAllowUnknownFields(bool allowUnknownFields) {
allowUnknownFields_ = allowUnknownFields;
}

private:
friend class Validator;
google::protobuf::Arena arena_;
absl::Mutex mutex_;
std::unique_ptr<internal::MessageFactory> messageFactory_;
bool allowUnknownFields_;
absl::flat_hash_map<const google::protobuf::Descriptor*, internal::Constraints> constraints_
ABSL_GUARDED_BY(mutex_);
std::unique_ptr<google::api::expr::runtime::CelExpressionBuilder> builder_
Expand Down

0 comments on commit 5f1d8aa

Please sign in to comment.