From 491f9c35d060303561acbd81a8525ffd0b439ad3 Mon Sep 17 00:00:00 2001 From: Ryan Barrett Date: Wed, 11 Dec 2024 14:37:23 -0800 Subject: [PATCH] add new Storage.store_repo method, run test_storage tests on DatastoreStorage too for https://console.cloud.google.com/errors/detail/COSMr7uQuqb2zQE;locations=global;time=P30D?project=bridgy-federated&inv=1&invt=Abj3dg --- arroba/datastore_storage.py | 13 +++ arroba/storage.py | 17 +++ arroba/tests/test_datastore_storage.py | 1 + arroba/tests/test_storage.py | 141 +++++++++++++------------ 4 files changed, 105 insertions(+), 67 deletions(-) diff --git a/arroba/datastore_storage.py b/arroba/datastore_storage.py index dc19b40..fb0228d 100644 --- a/arroba/datastore_storage.py +++ b/arroba/datastore_storage.py @@ -506,6 +506,19 @@ def update(): update() + def store_repo(self, repo): + @ndb.transactional() + def store(): + atp_repo = AtpRepo.get_by_id(repo.did) + atp_repo.populate( + handles=[repo.handle] if repo.handle else [], + status=repo.status, + ) + atp_repo.put() + logger.info(f'Stored repo {atp_repo}') + + store() + @ndb_context def read(self, cid): block = AtpBlock.get_by_id(cid.encode('base32')) diff --git a/arroba/storage.py b/arroba/storage.py index 418ba2c..47ae19f 100644 --- a/arroba/storage.py +++ b/arroba/storage.py @@ -165,6 +165,18 @@ def load_repo(self, did_or_handle): """ raise NotImplementedError() + def store_repo(self, repo): + """Writes a repo to storage. + + Right now only writes some metadata: + * handle + * status + + Args: + repo (Repo) + """ + raise NotImplementedError() + def load_repos(self, after=None, limit=500): """Loads multiple repos from storage. @@ -461,6 +473,11 @@ def load_repo(self, did_or_handle): if did_or_handle in (repo.did, repo.handle): return repo + def store_repo(self, repo): + stored = self.repos[repo.did] + stored.handle = repo.handle + stored.statue = repo.status + def load_repos(self, after=None, limit=500): it = iter(sorted(self.repos.values(), key=lambda repo: repo.did)) diff --git a/arroba/tests/test_datastore_storage.py b/arroba/tests/test_datastore_storage.py index 80413bc..26c0336 100644 --- a/arroba/tests/test_datastore_storage.py +++ b/arroba/tests/test_datastore_storage.py @@ -39,6 +39,7 @@ ] BLOB_CID = CID.decode('bafkreicqpqncshdd27sgztqgzocd3zhhqnnsv6slvzhs5uz6f57cq6lmtq') + class DatastoreStorageTest(DatastoreTest): def test_atpsequence_allocate_new(self): diff --git a/arroba/tests/test_storage.py b/arroba/tests/test_storage.py index c7c0cf6..3a28279 100644 --- a/arroba/tests/test_storage.py +++ b/arroba/tests/test_storage.py @@ -5,10 +5,11 @@ from multiformats import CID from ..repo import Repo, Write +from ..datastore_storage import DatastoreStorage from ..storage import Action, Block, MemoryStorage, SUBSCRIBE_REPOS_NSID from ..util import dag_cbor_cid, next_tid, DEACTIVATED, TOMBSTONED -from .testutil import NOW, TestCase +from .testutil import DatastoreTest, NOW, TestCase DECODED = {'foo': 'bar'} ENCODED = b'\xa1cfoocbar' @@ -16,6 +17,12 @@ class StorageTest(TestCase): + STORAGE_CLS = MemoryStorage + + def setUp(self): + super().setUp() + self.storage = self.STORAGE_CLS() + def test_block_encoded(self): block = Block(encoded=ENCODED) self.assertEqual(DECODED, block.decoded) @@ -33,8 +40,7 @@ def test_block_hash(self): self.assertEqual(id(Block(decoded=DECODED)), id(Block(encoded=ENCODED))) def test_read_events_by_seq(self): - storage = MemoryStorage() - repo = Repo.create(storage, 'did:web:user.com', signing_key=self.key) + repo = Repo.create(self.storage, 'did:web:user.com', signing_key=self.key) init = repo.head.cid tid = next_tid() @@ -46,7 +52,7 @@ def test_read_events_by_seq(self): repo.apply_writes([delete]) delete = repo.head.cid - events = list(storage.read_events_by_seq()) + events = list(self.storage.read_events_by_seq()) self.assertEqual(5, len(events)) self.assertEqual(init, events[0].commit.cid) self.assertEqual('com.atproto.sync.subscribeRepos#identity', @@ -56,15 +62,14 @@ def test_read_events_by_seq(self): self.assertEqual(create, events[3].commit.cid) self.assertEqual(delete, events[4].commit.cid) - events = storage.read_events_by_seq(start=4) + events = self.storage.read_events_by_seq(start=4) self.assertEqual([create, delete], [cd.commit.cid for cd in events]) def test_read_events_by_seq_repo(self): - storage = MemoryStorage() - alice = Repo.create(storage, 'did:alice', signing_key=self.key) + alice = Repo.create(self.storage, 'did:alice', signing_key=self.key) alice_init = alice.head.cid - bob = Repo.create(storage, 'did:bob', signing_key=self.key) + bob = Repo.create(self.storage, 'did:bob', signing_key=self.key) create = Write(Action.CREATE, 'co.ll', next_tid(), {'foo': 'bar'}) alice.apply_writes([create]) @@ -72,7 +77,7 @@ def test_read_events_by_seq_repo(self): create = Write(Action.CREATE, 'co.ll', next_tid(), {'baz': 'biff'}) bob.apply_writes([create]) - events = list(storage.read_events_by_seq(repo='did:alice')) + events = list(self.storage.read_events_by_seq(repo='did:alice')) self.assertEqual(4, len(events)) self.assertEqual(alice_init, events[0].commit.cid) self.assertEqual('com.atproto.sync.subscribeRepos#identity', @@ -81,15 +86,14 @@ def test_read_events_by_seq_repo(self): events[2]['$type']) self.assertEqual(alice.head.cid, events[3].commit.cid) - events = storage.read_events_by_seq(repo='did:alice', start=4) + events = self.storage.read_events_by_seq(repo='did:alice', start=4) self.assertEqual([alice.head.cid], [cd.commit.cid for cd in events]) def test_read_events_by_seq_include_record_block_even_if_preexisting(self): # https://github.com/snarfed/bridgy-fed/issues/1016#issuecomment-2109276344 commit_cids = [] - storage = MemoryStorage() - repo = Repo.create(storage, 'did:web:user.com', signing_key=self.key) + repo = Repo.create(self.storage, 'did:web:user.com', signing_key=self.key) commit_cids.append(repo.head.cid) prev_prev = repo.head.cid @@ -101,7 +105,7 @@ def test_read_events_by_seq_include_record_block_even_if_preexisting(self): second = Write(Action.CREATE, 'co.ll', next_tid(), {'foo': 'bar'}) commit_cid = repo.apply_writes([second]) - commits = list(storage.read_events_by_seq(start=4)) + commits = list(self.storage.read_events_by_seq(start=4)) self.assertEqual(2, len(commits)) record = Block(decoded={'foo': 'bar'}) @@ -114,14 +118,13 @@ def test_read_events_by_seq_include_record_block_even_if_preexisting(self): self.assertEqual(record, commits[1].blocks[record.cid]) def test_read_events_tombstone_then_commit(self): - storage = MemoryStorage() - alice = Repo.create(storage, 'did:alice', signing_key=self.key) + alice = Repo.create(self.storage, 'did:alice', signing_key=self.key) - storage.tombstone_repo(alice) + self.storage.tombstone_repo(alice) - bob = Repo.create(storage, 'did:bob', signing_key=self.key) + bob = Repo.create(self.storage, 'did:bob', signing_key=self.key) - events = list(storage.read_events_by_seq()) + events = list(self.storage.read_events_by_seq()) self.assertEqual(alice.head.cid, events[0].commit.cid) self.assertEqual(1, events[0].commit.seq) @@ -136,11 +139,10 @@ def test_read_events_tombstone_then_commit(self): self.assertEqual(5, events[4].commit.seq) def test_read_events_commit_then_tombstone(self): - storage = MemoryStorage() - alice = Repo.create(storage, 'did:alice', signing_key=self.key) - storage.tombstone_repo(alice) + alice = Repo.create(self.storage, 'did:alice', signing_key=self.key) + self.storage.tombstone_repo(alice) - events = list(storage.read_events_by_seq()) + events = list(self.storage.read_events_by_seq()) self.assertEqual(4, len(events)) self.assertEqual(alice.head.cid, events[0].commit.cid) self.assertEqual(1, events[0].commit.seq) @@ -153,21 +155,27 @@ def test_read_events_commit_then_tombstone(self): }, events[3]) def test_load_repo(self): - storage = MemoryStorage() - created = Repo.create(storage, 'did:web:user.com', signing_key=self.key) + created = Repo.create(self.storage, 'did:web:user.com', signing_key=self.key) - got = storage.load_repo('did:web:user.com') + got = self.storage.load_repo('did:web:user.com') self.assertEqual('did:web:user.com', got.did) self.assertEqual(created.head, got.head) self.assertIsNone(got.status) + def test_store_repo(self): + repo = Repo.create(self.storage, 'did:web:user.com', signing_key=self.key) + repo.handle = 'foo.bar' + self.storage.store_repo(repo) + + got = self.storage.load_repo('did:web:user.com') + self.assertEqual('foo.bar', repo.handle) + def test_load_repos(self): - storage = MemoryStorage() - alice = Repo.create(storage, 'did:web:alice', signing_key=self.key) - bob = Repo.create(storage, 'did:plc:bob', signing_key=self.key) - storage.tombstone_repo(bob) + alice = Repo.create(self.storage, 'did:web:alice', signing_key=self.key) + bob = Repo.create(self.storage, 'did:plc:bob', signing_key=self.key) + self.storage.tombstone_repo(bob) - got_bob, got_alice = storage.load_repos() + got_bob, got_alice = self.storage.load_repos() self.assertEqual('did:web:alice', got_alice.did) self.assertEqual(alice.head, got_alice.head) self.assertIsNone(got_alice.status) @@ -177,45 +185,42 @@ def test_load_repos(self): self.assertEqual('tombstoned', got_bob.status) def test_load_repos_after(self): - storage = MemoryStorage() - Repo.create(storage, 'did:web:alice', signing_key=self.key) - Repo.create(storage, 'did:plc:bob', signing_key=self.key) + Repo.create(self.storage, 'did:web:alice', signing_key=self.key) + Repo.create(self.storage, 'did:plc:bob', signing_key=self.key) - got = storage.load_repos(after='did:plc:bob') + got = self.storage.load_repos(after='did:plc:bob') self.assertEqual(1, len(got)) self.assertEqual('did:web:alice', got[0].did) - got = storage.load_repos(after='did:web:a') + got = self.storage.load_repos(after='did:web:a') self.assertEqual(1, len(got)) self.assertEqual('did:web:alice', got[0].did) - got = storage.load_repos(after='did:web:alice') + got = self.storage.load_repos(after='did:web:alice') self.assertEqual([], got) def test_load_repos_limit(self): - storage = MemoryStorage() - Repo.create(storage, 'did:web:alice', signing_key=self.key) - Repo.create(storage, 'did:plc:bob', signing_key=self.key) + Repo.create(self.storage, 'did:web:alice', signing_key=self.key) + Repo.create(self.storage, 'did:plc:bob', signing_key=self.key) - got = storage.load_repos(limit=2) + got = self.storage.load_repos(limit=2) self.assertEqual(2, len(got)) - got = storage.load_repos(limit=1) + got = self.storage.load_repos(limit=1) self.assertEqual(1, len(got)) self.assertEqual('did:plc:bob', got[0].did) def test_tombstone_repo(self): seen = [] - storage = MemoryStorage() - repo = Repo.create(storage, 'did:user', signing_key=self.key) - self.assertEqual(3, storage.last_seq(SUBSCRIBE_REPOS_NSID)) + repo = Repo.create(self.storage, 'did:user', signing_key=self.key) + self.assertEqual(3, self.storage.last_seq(SUBSCRIBE_REPOS_NSID)) repo.callback = lambda event: seen.append(event) - storage.tombstone_repo(repo) + self.storage.tombstone_repo(repo) self.assertEqual(TOMBSTONED, repo.status) - self.assertEqual(4, storage.last_seq(SUBSCRIBE_REPOS_NSID)) + self.assertEqual(4, self.storage.last_seq(SUBSCRIBE_REPOS_NSID)) expected = { '$type': 'com.atproto.sync.subscribeRepos#tombstone', 'seq': 4, @@ -223,22 +228,21 @@ def test_tombstone_repo(self): 'time': NOW.isoformat(), } self.assertEqual([expected], seen) - self.assertEqual(expected, storage.read(dag_cbor_cid(expected)).decoded) + self.assertEqual(expected, self.storage.read(dag_cbor_cid(expected)).decoded) self.assertEqual(TOMBSTONED, repo.status) - self.assertEqual(TOMBSTONED, storage.load_repo('did:user').status) + self.assertEqual(TOMBSTONED, self.storage.load_repo('did:user').status) def test_deactivate_repo(self): seen = [] - storage = MemoryStorage() - repo = Repo.create(storage, 'did:user', signing_key=self.key) - self.assertEqual(3, storage.last_seq(SUBSCRIBE_REPOS_NSID)) + repo = Repo.create(self.storage, 'did:user', signing_key=self.key) + self.assertEqual(3, self.storage.last_seq(SUBSCRIBE_REPOS_NSID)) repo.callback = lambda event: seen.append(event) - storage.deactivate_repo(repo) + self.storage.deactivate_repo(repo) self.assertEqual(DEACTIVATED, repo.status) - self.assertEqual(DEACTIVATED, storage.load_repo('did:user').status) + self.assertEqual(DEACTIVATED, self.storage.load_repo('did:user').status) - self.assertEqual(4, storage.last_seq(SUBSCRIBE_REPOS_NSID)) + self.assertEqual(4, self.storage.last_seq(SUBSCRIBE_REPOS_NSID)) expected = { '$type': 'com.atproto.sync.subscribeRepos#account', 'seq': 4, @@ -248,20 +252,19 @@ def test_deactivate_repo(self): 'status': 'deactivated', } self.assertEqual([expected], seen) - self.assertEqual(expected, storage.read(dag_cbor_cid(expected)).decoded) + self.assertEqual(expected, self.storage.read(dag_cbor_cid(expected)).decoded) def test_activate_repo(self): seen = [] - storage = MemoryStorage() - repo = Repo.create(storage, 'did:user', signing_key=self.key, + repo = Repo.create(self.storage, 'did:user', signing_key=self.key, status=DEACTIVATED) - self.assertEqual(3, storage.last_seq(SUBSCRIBE_REPOS_NSID)) + self.assertEqual(3, self.storage.last_seq(SUBSCRIBE_REPOS_NSID)) repo.callback = lambda event: seen.append(event) - storage.activate_repo(repo) + self.storage.activate_repo(repo) self.assertIsNone(repo.status) - self.assertEqual(4, storage.last_seq(SUBSCRIBE_REPOS_NSID)) + self.assertEqual(4, self.storage.last_seq(SUBSCRIBE_REPOS_NSID)) expected = { '$type': 'com.atproto.sync.subscribeRepos#account', 'seq': 4, @@ -270,16 +273,15 @@ def test_activate_repo(self): 'active': True, } self.assertEqual([expected], seen) - self.assertEqual(expected, storage.read(dag_cbor_cid(expected)).decoded) + self.assertEqual(expected, self.storage.read(dag_cbor_cid(expected)).decoded) self.assertIsNone(repo.status) - self.assertIsNone(storage.load_repo('did:user').status) + self.assertIsNone(self.storage.load_repo('did:user').status) def test_write_event(self): - storage = MemoryStorage() - repo = Repo.create(storage, 'did:user', signing_key=self.key) - self.assertEqual(3, storage.last_seq(SUBSCRIBE_REPOS_NSID)) + repo = Repo.create(self.storage, 'did:user', signing_key=self.key) + self.assertEqual(3, self.storage.last_seq(SUBSCRIBE_REPOS_NSID)) - block = storage.write_event(repo=repo, type='identity', + block = self.storage.write_event(repo=repo, type='identity', active=False, status='foo') self.assertEqual({ '$type': 'com.atproto.sync.subscribeRepos#identity', @@ -289,4 +291,9 @@ def test_write_event(self): 'active': False, 'status': 'foo', }, block.decoded) - self.assertEqual(block, storage.read(block.cid)) + self.assertEqual(block, self.storage.read(block.cid)) + + +class DatastoreStorageTest(StorageTest, DatastoreTest): + """Run all of StorageTest's tests with DatastoreStorage.""" + STORAGE_CLS = DatastoreStorage