Skip to content

Commit

Permalink
Ensure that extensions respect deterministic serialization.
Browse files Browse the repository at this point in the history
Previously we were not sorting extensions at encode time, even in deterministic mode.

PiperOrigin-RevId: 508217926
  • Loading branch information
haberman authored and copybara-github committed Feb 9, 2023
1 parent 28de62f commit 57a79de
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 14 deletions.
1 change: 1 addition & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ cc_library(
":base",
":hash",
":mem",
":message_internal",
":mini_table_internal",
":port",
],
Expand Down
41 changes: 35 additions & 6 deletions upb/collections/map_sorter.c
Original file line number Diff line number Diff line change
Expand Up @@ -102,24 +102,30 @@ static int (*const compar[kUpb_FieldType_SizeOf])(const void*, const void*) = {
[kUpb_FieldType_Bytes] = _upb_mapsorter_cmpstr,
};

bool _upb_mapsorter_pushmap(_upb_mapsorter* s, upb_FieldType key_type,
const upb_Map* map, _upb_sortedmap* sorted) {
int map_size = _upb_Map_Size(map);
// Reserves `size` consecutive slots on top of the sorter's entry array and
// describes them in `sorted`: `start` is the first reserved slot, `end` is one
// past the last, and the cursor `pos` begins at `start`.
//
// Returns true on success. Returns false if the entry array had to grow and
// the allocation failed; in that case s->entries, s->cap and s->size keep their
// previous values, so the existing buffer stays owned by `s` and is not leaked.
static bool _upb_mapsorter_resize(_upb_mapsorter* s, _upb_sortedmap* sorted,
                                  int size) {
  sorted->start = s->size;
  sorted->pos = sorted->start;
  sorted->end = sorted->start + size;

  // Grow s->entries if necessary.
  if (sorted->end > s->cap) {
    int new_cap = upb_Log2CeilingSize(sorted->end);
    // Hold the result in a temporary: when realloc() fails it returns NULL and
    // leaves the original allocation untouched, so assigning straight to
    // s->entries would drop the only pointer to that buffer.
    void* grown = realloc(s->entries, new_cap * sizeof(*s->entries));
    if (!grown) return false;
    s->entries = grown;
    s->cap = new_cap;
  }

  s->size = sorted->end;
  return true;
}

bool _upb_mapsorter_pushmap(_upb_mapsorter* s, upb_FieldType key_type,
const upb_Map* map, _upb_sortedmap* sorted) {
int map_size = _upb_Map_Size(map);

if (!_upb_mapsorter_resize(s, sorted, map_size)) return false;

// Copy non-empty entries from the table to s->entries.
upb_tabent const** dst = &s->entries[sorted->start];
const void** dst = &s->entries[sorted->start];
const upb_tabent* src = map->table.t.entries;
const upb_tabent* end = src + upb_table_size(&map->table.t);
for (; src < end; src++) {
Expand All @@ -135,3 +141,26 @@ bool _upb_mapsorter_pushmap(_upb_mapsorter* s, upb_FieldType key_type,
compar[key_type]);
return true;
}

// qsort() comparator for an array of `const upb_Message_Extension*`.
// Orders the pointed-to extensions by ascending field number. The assert
// records the expectation that two entries never share a field number, so the
// comparator never has to report equality.
static int _upb_mapsorter_cmpext(const void* _a, const void* _b) {
  const upb_Message_Extension* lhs = *(const upb_Message_Extension* const*)_a;
  const upb_Message_Extension* rhs = *(const upb_Message_Extension* const*)_b;
  const uint32_t lhs_num = lhs->ext->field.number;
  const uint32_t rhs_num = rhs->ext->field.number;
  assert(lhs_num != rhs_num);
  if (lhs_num < rhs_num) return -1;
  return 1;
}

// Pushes a run of `count` extension pointers onto the sorter and sorts that run
// by field number. `sorted` receives the bounds of the run.
//
// The entries point into `exts`; nothing is copied. Returns false if the
// sorter could not reserve room for the run.
bool _upb_mapsorter_pushexts(_upb_mapsorter* s,
                             const upb_Message_Extension* exts, size_t count,
                             _upb_sortedmap* sorted) {
  if (!_upb_mapsorter_resize(s, sorted, count)) return false;

  // First slot of the freshly reserved run.
  const void** run = &s->entries[sorted->start];

  size_t i = 0;
  while (i < count) {
    run[i] = &exts[i];
    i++;
  }

  qsort(run, count, sizeof(*s->entries), _upb_mapsorter_cmpext);
  return true;
}
17 changes: 15 additions & 2 deletions upb/collections/map_sorter_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <stdlib.h>

#include "upb/collections/map_internal.h"
#include "upb/message/extension_internal.h"
#include "upb/mini_table/message_internal.h"

// Must be last.
Expand All @@ -47,7 +48,7 @@ extern "C" {
// maps), _upb_mapsorter can contain a stack of maps.

// Backing storage for sorted iteration: one growable array of entry pointers
// onto which runs are pushed and from which they are popped.
typedef struct {
// NOTE(review): the next two lines declare `entries` twice; they appear to be
// the before/after forms of the same field in this change (typed table-entry
// pointers vs. untyped pointers) — confirm only the second is in the header.
upb_tabent const** entries;
void const** entries;
// Number of slots currently in use; a newly pushed run starts at this index.
int size;
// Number of slots allocated in `entries`.
int cap;
} _upb_mapsorter;
Expand All @@ -71,14 +72,22 @@ UPB_INLINE void _upb_mapsorter_destroy(_upb_mapsorter* s) {
// Fetches the next entry of a sorted map run into `ent`.
// Returns false when the cursor `sorted->pos` has reached `sorted->end`.
// Otherwise it advances the cursor, fills `ent->data.k` and `ent->data.v` from
// the table entry, and returns true.
UPB_INLINE bool _upb_sortedmap_next(_upb_mapsorter* s, const upb_Map* map,
_upb_sortedmap* sorted, upb_MapEntry* ent) {
if (sorted->pos == sorted->end) return false;
// NOTE(review): the next two lines declare `tabent` twice and each advances
// `sorted->pos`; they appear to be the before/after forms of one statement
// (the second adds a cast to upb_tabent) — confirm only the second is in the
// real header.
const upb_tabent* tabent = s->entries[sorted->pos++];
const upb_tabent* tabent = (const upb_tabent*)s->entries[sorted->pos++];
// The key is viewed as a string and handed to _upb_map_fromkey together with
// the map's key size.
upb_StringView key = upb_tabstrview(tabent->key);
_upb_map_fromkey(key, &ent->data.k, map->key_size);
// The stored value is rewrapped as a upb_value and handed to
// _upb_map_fromvalue together with the map's value size.
upb_value val = {tabent->val.val};
_upb_map_fromvalue(val, &ent->data.v, map->val_size);
return true;
}

// Fetches the next extension of a sorted extension run.
// When the cursor has not reached the end of the run, stores the current entry
// in `*ext`, advances the cursor and returns true. Returns false, leaving
// `*ext` untouched, once the run is exhausted.
UPB_INLINE bool _upb_sortedmap_nextext(_upb_mapsorter* s,
                                       _upb_sortedmap* sorted,
                                       const upb_Message_Extension** ext) {
  const bool has_more = (sorted->pos != sorted->end);
  if (has_more) {
    *ext = (const upb_Message_Extension*)s->entries[sorted->pos];
    sorted->pos++;
  }
  return has_more;
}

UPB_INLINE void _upb_mapsorter_popmap(_upb_mapsorter* s,
_upb_sortedmap* sorted) {
s->size = sorted->start;
Expand All @@ -87,6 +96,10 @@ UPB_INLINE void _upb_mapsorter_popmap(_upb_mapsorter* s,
bool _upb_mapsorter_pushmap(_upb_mapsorter* s, upb_FieldType key_type,
const upb_Map* map, _upb_sortedmap* sorted);

bool _upb_mapsorter_pushexts(_upb_mapsorter* s,
const upb_Message_Extension* exts, size_t count,
_upb_sortedmap* sorted);

#ifdef __cplusplus
} /* extern "C" */
#endif
Expand Down
35 changes: 35 additions & 0 deletions upb/test/test_generated_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -961,3 +961,38 @@ TEST(GeneratedCode, ArenaUnaligned) {
EXPECT_EQ(0, reinterpret_cast<uintptr_t>(mem) & low_bits);
upb_Arena_Free(arena);
}

// Deterministic serialization must not depend on the order in which
// extensions were set: two messages holding the same extensions, added in
// opposite orders, must encode to identical bytes.
TEST(GeneratedCode, Extensions) {
  upb::Arena arena;
  upb_test_ModelExtension1* extension1 =
      upb_test_ModelExtension1_new(arena.ptr());
  upb_test_ModelExtension1_set_str(extension1,
                                   upb_StringView_FromString("Hello"));

  upb_test_ModelExtension2* extension2 =
      upb_test_ModelExtension2_new(arena.ptr());
  upb_test_ModelExtension2_set_i(extension2, 5);

  upb_test_ModelWithExtensions* msg1 =
      upb_test_ModelWithExtensions_new(arena.ptr());
  upb_test_ModelWithExtensions* msg2 =
      upb_test_ModelWithExtensions_new(arena.ptr());

  // msg1: [extension1, extension2]
  upb_test_ModelExtension1_set_model_ext(msg1, extension1, arena.ptr());
  upb_test_ModelExtension2_set_model_ext(msg1, extension2, arena.ptr());

  // msg2: [extension2, extension1]
  upb_test_ModelExtension2_set_model_ext(msg2, extension2, arena.ptr());
  upb_test_ModelExtension1_set_model_ext(msg2, extension1, arena.ptr());

  size_t size1 = 0;
  size_t size2 = 0;
  int opts = kUpb_EncodeOption_Deterministic;
  char* pb1 = upb_test_ModelWithExtensions_serialize_ex(msg1, opts, arena.ptr(),
                                                        &size1);
  char* pb2 = upb_test_ModelWithExtensions_serialize_ex(msg2, opts, arena.ptr(),
                                                        &size2);

  // Reject null buffers up front: otherwise two encodes that both produced
  // nothing would compare equal below and the test would pass vacuously, and
  // memcmp() must not be handed a null pointer.
  ASSERT_NE(nullptr, pb1);
  ASSERT_NE(nullptr, pb2);

  ASSERT_EQ(size1, size2);
  ASSERT_EQ(0, memcmp(pb1, pb2, size1));
}
26 changes: 20 additions & 6 deletions upb/wire/encode.c
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,15 @@ static void encode_msgset_item(upb_encstate* e,
encode_tag(e, kUpb_MsgSet_Item, kUpb_WireType_StartGroup);
}

// Encodes a single extension. When `is_message_set` is true the extension is
// written as a message-set item; otherwise it is written as an ordinary field
// using the extension's own field descriptor and sub-table.
static void encode_ext(upb_encstate* e, const upb_Message_Extension* ext,
                       bool is_message_set) {
  if (UPB_UNLIKELY(is_message_set)) {
    encode_msgset_item(e, ext);
    return;
  }
  encode_field(e, &ext->data, &ext->ext->sub, &ext->ext->field);
}

static void encode_message(upb_encstate* e, const upb_Message* msg,
const upb_MiniTable* m, size_t* size) {
size_t pre_len = e->limit - e->ptr;
Expand Down Expand Up @@ -543,12 +552,17 @@ static void encode_message(upb_encstate* e, const upb_Message* msg,
size_t ext_count;
const upb_Message_Extension* ext = _upb_Message_Getexts(msg, &ext_count);
if (ext_count) {
const upb_Message_Extension* end = ext + ext_count;
for (; ext != end; ext++) {
if (UPB_UNLIKELY(m->ext == kUpb_ExtMode_IsMessageSet)) {
encode_msgset_item(e, ext);
} else {
encode_field(e, &ext->data, &ext->ext->sub, &ext->ext->field);
if (e->options & kUpb_EncodeOption_Deterministic) {
_upb_sortedmap sorted;
_upb_mapsorter_pushexts(&e->sorter, ext, ext_count, &sorted);
while (_upb_sortedmap_nextext(&e->sorter, &sorted, &ext)) {
encode_ext(e, ext, m->ext == kUpb_ExtMode_IsMessageSet);
}
_upb_mapsorter_popmap(&e->sorter, &sorted);
} else {
const upb_Message_Extension* end = ext + ext_count;
for (; ext != end; ext++) {
encode_ext(e, ext, m->ext == kUpb_ExtMode_IsMessageSet);
}
}
}
Expand Down

0 comments on commit 57a79de

Please sign in to comment.