From 57a79de7cc34cc2ca4436483834aebd44e8f4f4b Mon Sep 17 00:00:00 2001 From: Joshua Haberman Date: Wed, 8 Feb 2023 16:42:15 -0800 Subject: [PATCH] Ensure that extensions respect deterministic serialization. Previously we were not sorting extensions at encode time, even in deterministic mode. PiperOrigin-RevId: 508217926 --- BUILD | 1 + upb/collections/map_sorter.c | 41 +++++++++++++++++++++++---- upb/collections/map_sorter_internal.h | 17 +++++++++-- upb/test/test_generated_code.cc | 35 +++++++++++++++++++++++ upb/wire/encode.c | 26 +++++++++++++---- 5 files changed, 106 insertions(+), 14 deletions(-) diff --git a/BUILD b/BUILD index 27ddc9b9ce..79844a80f5 100644 --- a/BUILD +++ b/BUILD @@ -481,6 +481,7 @@ cc_library( ":base", ":hash", ":mem", + ":message_internal", ":mini_table_internal", ":port", ], diff --git a/upb/collections/map_sorter.c b/upb/collections/map_sorter.c index aa8a2416b3..c18e03c784 100644 --- a/upb/collections/map_sorter.c +++ b/upb/collections/map_sorter.c @@ -102,14 +102,12 @@ static int (*const compar[kUpb_FieldType_SizeOf])(const void*, const void*) = { [kUpb_FieldType_Bytes] = _upb_mapsorter_cmpstr, }; -bool _upb_mapsorter_pushmap(_upb_mapsorter* s, upb_FieldType key_type, - const upb_Map* map, _upb_sortedmap* sorted) { - int map_size = _upb_Map_Size(map); +static bool _upb_mapsorter_resize(_upb_mapsorter* s, _upb_sortedmap* sorted, + int size) { sorted->start = s->size; sorted->pos = sorted->start; - sorted->end = sorted->start + map_size; + sorted->end = sorted->start + size; - // Grow s->entries if necessary. 
if (sorted->end > s->cap) { s->cap = upb_Log2CeilingSize(sorted->end); s->entries = realloc(s->entries, s->cap * sizeof(*s->entries)); @@ -117,9 +115,17 @@ bool _upb_mapsorter_pushmap(_upb_mapsorter* s, upb_FieldType key_type, } s->size = sorted->end; + return true; +} + +bool _upb_mapsorter_pushmap(_upb_mapsorter* s, upb_FieldType key_type, + const upb_Map* map, _upb_sortedmap* sorted) { + int map_size = _upb_Map_Size(map); + + if (!_upb_mapsorter_resize(s, sorted, map_size)) return false; // Copy non-empty entries from the table to s->entries. - upb_tabent const** dst = &s->entries[sorted->start]; + const void** dst = &s->entries[sorted->start]; const upb_tabent* src = map->table.t.entries; const upb_tabent* end = src + upb_table_size(&map->table.t); for (; src < end; src++) { @@ -135,3 +141,26 @@ bool _upb_mapsorter_pushmap(_upb_mapsorter* s, upb_FieldType key_type, compar[key_type]); return true; } + +static int _upb_mapsorter_cmpext(const void* _a, const void* _b) { + const upb_Message_Extension* const* a = _a; + const upb_Message_Extension* const* b = _b; + uint32_t a_num = (*a)->ext->field.number; + uint32_t b_num = (*b)->ext->field.number; + assert(a_num != b_num); + return a_num < b_num ? 
-1 : 1; +} + +bool _upb_mapsorter_pushexts(_upb_mapsorter* s, + const upb_Message_Extension* exts, size_t count, + _upb_sortedmap* sorted) { + if (!_upb_mapsorter_resize(s, sorted, count)) return false; + + for (size_t i = 0; i < count; i++) { + s->entries[sorted->start + i] = &exts[i]; + } + + qsort(&s->entries[sorted->start], count, sizeof(*s->entries), + _upb_mapsorter_cmpext); + return true; +} diff --git a/upb/collections/map_sorter_internal.h b/upb/collections/map_sorter_internal.h index a8822bdbc2..8e9f1294ac 100644 --- a/upb/collections/map_sorter_internal.h +++ b/upb/collections/map_sorter_internal.h @@ -33,6 +33,7 @@ #include <stdlib.h> #include "upb/collections/map_internal.h" +#include "upb/message/extension_internal.h" #include "upb/mini_table/message_internal.h" // Must be last. @@ -47,7 +48,7 @@ extern "C" { // maps), _upb_mapsorter can contain a stack of maps. typedef struct { - upb_tabent const** entries; + void const** entries; int size; int cap; } _upb_mapsorter; @@ -71,7 +72,7 @@ UPB_INLINE void _upb_mapsorter_destroy(_upb_mapsorter* s) { UPB_INLINE bool _upb_sortedmap_next(_upb_mapsorter* s, const upb_Map* map, _upb_sortedmap* sorted, upb_MapEntry* ent) { if (sorted->pos == sorted->end) return false; - const upb_tabent* tabent = s->entries[sorted->pos++]; + const upb_tabent* tabent = (const upb_tabent*)s->entries[sorted->pos++]; upb_StringView key = upb_tabstrview(tabent->key); _upb_map_fromkey(key, &ent->data.k, map->key_size); upb_value val = {tabent->val.val}; @@ -79,6 +80,14 @@ UPB_INLINE bool _upb_sortedmap_next(_upb_mapsorter* s, const upb_Map* map, return true; } +UPB_INLINE bool _upb_sortedmap_nextext(_upb_mapsorter* s, + _upb_sortedmap* sorted, + const upb_Message_Extension** ext) { + if (sorted->pos == sorted->end) return false; + *ext = (const upb_Message_Extension*)s->entries[sorted->pos++]; + return true; +} + UPB_INLINE void _upb_mapsorter_popmap(_upb_mapsorter* s, _upb_sortedmap* sorted) { s->size = sorted->start; @@ -87,6 +96,10 @@
UPB_INLINE void _upb_mapsorter_popmap(_upb_mapsorter* s, bool _upb_mapsorter_pushmap(_upb_mapsorter* s, upb_FieldType key_type, const upb_Map* map, _upb_sortedmap* sorted); +bool _upb_mapsorter_pushexts(_upb_mapsorter* s, + const upb_Message_Extension* exts, size_t count, + _upb_sortedmap* sorted); + #ifdef __cplusplus } /* extern "C" */ #endif diff --git a/upb/test/test_generated_code.cc b/upb/test/test_generated_code.cc index 4b9d2b8627..e2db7d97dc 100644 --- a/upb/test/test_generated_code.cc +++ b/upb/test/test_generated_code.cc @@ -961,3 +961,38 @@ TEST(GeneratedCode, ArenaUnaligned) { EXPECT_EQ(0, reinterpret_cast<uintptr_t>(mem) & low_bits); upb_Arena_Free(arena); } + +TEST(GeneratedCode, Extensions) { + upb::Arena arena; + upb_test_ModelExtension1* extension1 = + upb_test_ModelExtension1_new(arena.ptr()); + upb_test_ModelExtension1_set_str(extension1, + upb_StringView_FromString("Hello")); + + upb_test_ModelExtension2* extension2 = + upb_test_ModelExtension2_new(arena.ptr()); + upb_test_ModelExtension2_set_i(extension2, 5); + + upb_test_ModelWithExtensions* msg1 = + upb_test_ModelWithExtensions_new(arena.ptr()); + upb_test_ModelWithExtensions* msg2 = + upb_test_ModelWithExtensions_new(arena.ptr()); + + // msg1: [extension1, extension2] + upb_test_ModelExtension1_set_model_ext(msg1, extension1, arena.ptr()); + upb_test_ModelExtension2_set_model_ext(msg1, extension2, arena.ptr()); + + // msg2: [extension2, extension1] + upb_test_ModelExtension2_set_model_ext(msg2, extension2, arena.ptr()); + upb_test_ModelExtension1_set_model_ext(msg2, extension1, arena.ptr()); + + size_t size1, size2; + int opts = kUpb_EncodeOption_Deterministic; + char* pb1 = upb_test_ModelWithExtensions_serialize_ex(msg1, opts, arena.ptr(), + &size1); + char* pb2 = upb_test_ModelWithExtensions_serialize_ex(msg2, opts, arena.ptr(), + &size2); + + ASSERT_EQ(size1, size2); + ASSERT_EQ(0, memcmp(pb1, pb2, size1)); +} diff --git a/upb/wire/encode.c b/upb/wire/encode.c index 497f04dac5..ef3b4adadf 100644
--- a/upb/wire/encode.c +++ b/upb/wire/encode.c @@ -514,6 +514,15 @@ static void encode_msgset_item(upb_encstate* e, encode_tag(e, kUpb_MsgSet_Item, kUpb_WireType_StartGroup); } +static void encode_ext(upb_encstate* e, const upb_Message_Extension* ext, + bool is_message_set) { + if (UPB_UNLIKELY(is_message_set)) { + encode_msgset_item(e, ext); + } else { + encode_field(e, &ext->data, &ext->ext->sub, &ext->ext->field); + } +} + static void encode_message(upb_encstate* e, const upb_Message* msg, const upb_MiniTable* m, size_t* size) { size_t pre_len = e->limit - e->ptr; @@ -543,12 +552,17 @@ static void encode_message(upb_encstate* e, const upb_Message* msg, size_t ext_count; const upb_Message_Extension* ext = _upb_Message_Getexts(msg, &ext_count); if (ext_count) { - const upb_Message_Extension* end = ext + ext_count; - for (; ext != end; ext++) { - if (UPB_UNLIKELY(m->ext == kUpb_ExtMode_IsMessageSet)) { - encode_msgset_item(e, ext); - } else { - encode_field(e, &ext->data, &ext->ext->sub, &ext->ext->field); + if (e->options & kUpb_EncodeOption_Deterministic) { + _upb_sortedmap sorted; + _upb_mapsorter_pushexts(&e->sorter, ext, ext_count, &sorted); + while (_upb_sortedmap_nextext(&e->sorter, &sorted, &ext)) { + encode_ext(e, ext, m->ext == kUpb_ExtMode_IsMessageSet); + } + _upb_mapsorter_popmap(&e->sorter, &sorted); + } else { + const upb_Message_Extension* end = ext + ext_count; + for (; ext != end; ext++) { + encode_ext(e, ext, m->ext == kUpb_ExtMode_IsMessageSet); } } }