Skip to content

Commit

Permalink
User.enable_protocol: always call Protocol.create_for
Browse files Browse the repository at this point in the history
...even if there's already a copy, since it might need to be reactivated

for #1130
  • Loading branch information
snarfed committed Oct 22, 2024
1 parent aaf49e0 commit 059c7ed
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 10 deletions.
4 changes: 3 additions & 1 deletion models.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,9 @@ def enable_protocol(self, to_proto):
"""
added = False

if to_proto.LABEL in ids.COPIES_PROTOCOLS and not self.get_copy(to_proto):
if to_proto.LABEL in ids.COPIES_PROTOCOLS:
# do this even if there's an existing copy since we might need to
# reactivate it, which create_for should do
to_proto.create_for(self)

@ndb.transactional()
Expand Down
3 changes: 1 addition & 2 deletions tests/test_dms.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def test_receive_no_yes_sets_enabled_protocols(self):
user = user.key.get()
self.assertEqual(['fake'], user.enabled_protocols)
self.assertTrue(user.is_enabled(Fake))
self.assertEqual([], Fake.created_for)
self.assertEqual(['efake:user'], Fake.created_for)

# "no" DM should remove from enabled_protocols
Follower.get_or_create(to=user, from_=alice)
Expand All @@ -164,7 +164,6 @@ def test_receive_no_yes_sets_enabled_protocols(self):
self.assertEqual(('OK', 200), receive(from_user=user, obj=dm))
user = user.key.get()
self.assertEqual([], user.enabled_protocols)
self.assertEqual([], Fake.created_for)
self.assertFalse(user.is_enabled(Fake))

# ...and delete copy actor
Expand Down
41 changes: 39 additions & 2 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2923,15 +2923,14 @@ def test_follow_and_block_protocol_user_sets_enabled_protocols(self):
user = user.key.get()
self.assertEqual(['fake'], user.enabled_protocols)
self.assertTrue(user.is_enabled(Fake))
self.assertEqual([], Fake.created_for)
self.assertEqual(['efake:user'], Fake.created_for)

# block should remove from enabled_protocols
Follower.get_or_create(to=user, from_=self.user)
block['id'] += '2'
self.assertEqual(('OK', 200), ExplicitFake.receive_as1(block))
user = user.key.get()
self.assertEqual([], user.enabled_protocols)
self.assertEqual([], Fake.created_for)
self.assertFalse(user.is_enabled(Fake))

# ...and delete copy actor
Expand Down Expand Up @@ -3042,6 +3041,44 @@ def test_follow_bot_user_overrides_nobot(self):
self.assertTrue(user.is_enabled(Fake))
self.assertEqual(['efake:user'], ExplicitFake.fetched)

def test_block_then_follow_protocol_user_recreates_copy(self):
# bot user
self.make_user('fa.brid.gy', cls=Web)

follow = {
'objectType': 'activity',
'verb': 'follow',
'id': 'efake:follow',
'actor': 'efake:user',
'object': 'fa.brid.gy',
}
block = {
'objectType': 'activity',
'verb': 'block',
'id': 'efake:block',
'actor': 'efake:user',
'object': 'fa.brid.gy',
}

copy = Target(uri='fake:user', protocol='fake')
user = self.make_user('efake:user', cls=ExplicitFake,
enabled_protocols=['fake'], copies=[copy])
self.assertTrue(user.is_enabled(Fake))
self.assertEqual([copy], user.copies)

self.assertEqual(('OK', 200), ExplicitFake.receive_as1(block))
user = user.key.get()
self.assertFalse(user.is_enabled(Fake))
self.assertEqual([copy], user.copies)

# fake protocol isn't enabled yet, block should be a noop
ExplicitFake.fetchable = {'efake:user': {'profile': 'info'}}
_, code = ExplicitFake.receive_as1(follow)
self.assertEqual(204, code)
user = user.key.get()
self.assertEqual(['fake'], user.enabled_protocols)
self.assertEqual(['efake:user'], Fake.created_for)

def test_receive_activity_lease(self):
Follower.get_or_create(to=self.user, from_=self.alice)

Expand Down
12 changes: 7 additions & 5 deletions tests/testutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,17 @@ def bridged_web_url_for(cls, user, fallback=False):

@classmethod
def create_for(cls, user):
assert not user.get_copy(cls)
id = user.key.id()
logger.info(f'{cls.__name__}.create_for {id}')
cls.created_for.append(id)
add(user.copies, Target(uri=ids.translate_user_id(id=id, from_=user, to=cls),
protocol=cls.LABEL))
user.put()

if user.obj_key:
if not user.get_copy(cls):
copy = Target(uri=ids.translate_user_id(id=id, from_=user, to=cls),
protocol=cls.LABEL)
add(user.copies, copy)
user.put()

if user.obj and not user.obj.get_copy(cls):
profile_copy_id = ids.translate_object_id(
id=user.profile_id(), from_=user, to=cls)
user.obj.add('copies', Target(uri=profile_copy_id, protocol=cls.LABEL))
Expand Down

0 comments on commit 059c7ed

Please sign in to comment.