Skip to content
This repository has been archived by the owner on Jul 13, 2023. It is now read-only.

Commit

Permalink
Merge branch 'master' of github.com:mozilla-services/autopush into fe…
Browse files Browse the repository at this point in the history
…ature/jwt_chan
  • Loading branch information
jrconlin committed Mar 1, 2016
2 parents d4f0c31 + 2a2c95f commit d59cc24
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 10 deletions.
59 changes: 58 additions & 1 deletion autopush/tests/test_websocket.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import time
import uuid
from hashlib import sha256

import twisted.internet.base
from boto.dynamodb2.exceptions import (
Expand Down Expand Up @@ -865,6 +866,35 @@ def check_register_result(msg):
f.addErrback(lambda x: d.errback(x))
return d

def test_register_with_key(self):
self._connect()
self.proto.ps.uaid = str(uuid.uuid4())
chid = str(uuid.uuid4())
key = "SomeRandomStringOfCryptoSignificantStuff"

def echo(str):
return str.encode('hex')

self.proto.ap_settings.fernet = Mock()
self.proto.ap_settings.fernet.encrypt = echo
self.proto.ap_settings.message.register_channel = Mock()
self.proto.sendJSON = Mock()

d = Deferred()

def check_register_result(msg, uaid, chid, key):
sha = sha256(key).hexdigest()
suaid = uaid.replace('-', '')
schid = chid.replace('-', '')
endpoint = self.proto.sendJSON.call_args[0][0]['pushEndpoint']
eq_(endpoint,
'http://localhost/push/v2/' + suaid + schid + sha)
d.callback(True)

res = self.proto.process_register(dict(channelID=chid, key=key))
res.addCallback(check_register_result, self.proto.ps.uaid, chid, key)
return d

def test_register_kill_others(self):
self._connect()
mock_agent = Mock()
Expand Down Expand Up @@ -943,6 +973,8 @@ def test_unregister_with_webpush(self):
assert self.proto.force_retry.called

def test_ws_unregister(self):
patcher = patch("autopush.websocket.log", spec=True)
mock_log = patcher.start()
self._connect()
self._send_message(dict(messageType="hello", channelIDs=[]))

Expand All @@ -953,12 +985,15 @@ def test_ws_unregister(self):
def check_unregister_result(msg):
eq_(msg["status"], 200)
eq_(msg["channelID"], chid)
eq_(len(mock_log.mock_calls), 1)
patcher.stop()
d.callback(True)

def check_hello_result(msg):
eq_(msg["messageType"], "hello")
eq_(msg["status"], 200)
self._send_message(dict(messageType="unregister",
code=104,
channelID=chid))
self._check_response(check_unregister_result)

Expand Down Expand Up @@ -1170,7 +1205,8 @@ def test_ack_with_webpush_from_storage(self, mock_log):
self.proto.force_retry = Mock(return_value=mock_defer)
self.proto.ack_update(dict(
channelID=chid,
version="bleh:jialsdjfilasjdf"
version="bleh:jialsdjfilasjdf",
code=200
))
assert self.proto.force_retry.called
assert mock_defer.addBoth.called
Expand All @@ -1180,6 +1216,27 @@ def test_ack_with_webpush_from_storage(self, mock_log):
eq_(kwargs["router_key"], "webpush")
eq_(kwargs["message_source"], "stored")

@patch('autopush.websocket.log', spec=True)
def test_nack(self, mock_log):
self._connect()
self.proto.ps.uaid = str(uuid.uuid4())
self.proto.onMessage(json.dumps(dict(
messageType="nack",
version="bleh:asdfhjklhjkl",
code=200
)), False)
eq_(len(mock_log.mock_calls), 1)

@patch('autopush.websocket.log', spec=True)
def test_nack_no_version(self, mock_log):
self._connect()
self.proto.ps.uaid = str(uuid.uuid4())
self.proto.onMessage(json.dumps(dict(
messageType="nack",
code=200
)), False)
eq_(len(mock_log.mock_calls), 0)

def test_ack_remove(self):
self._connect()
chid = str(uuid.uuid4())
Expand Down
51 changes: 42 additions & 9 deletions autopush/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@
from autopush.noseplugin import track_object


def extract_code(data):
"""Extracts and converts a code key if found in data dict"""
code = data.get("code", None)
if code and isinstance(code, int):
code = code
else:
code = 0
return code


def ms_time():
"""Return current time.time call as ms and a Python int"""
return int(time.time() * 1000)
Expand Down Expand Up @@ -403,6 +413,8 @@ def onMessage(self, payload, isBinary):
return self.process_unregister(data)
elif cmd == "ack":
return self.process_ack(data)
elif cmd == "nack":
return self.process_nack(data)
else:
self.sendClose()
finally:
Expand Down Expand Up @@ -988,7 +1000,7 @@ def process_register(self, data):
self.transport.pauseProducing()

d = self.deferToThread(self.ap_settings.make_endpoint, self.ps.uaid,
chid)
chid, data.get('key'))
d.addCallback(self.finish_register, chid)
d.addErrback(self.trap_cancel)
d.addErrback(self.error_register)
Expand Down Expand Up @@ -1037,6 +1049,12 @@ def process_unregister(self, data):
self.ps.metrics.increment("updates.client.unregister",
tags=self.base_tags)

# Log out the unregister if it has a code in it
if "code" in data:
code = extract_code(data)
log.msg("Unregister", channelID=chid, uaid_hash=self.ps.uaid_hash,
user_agent=self.ps.user_agent, code=code)

# Clear out any existing tracked messages for this channel
if self.ps.use_webpush:
self.ps.direct_updates[chid] = []
Expand Down Expand Up @@ -1075,12 +1093,14 @@ def ack_update(self, update):
if not chid or not version:
return

code = extract_code(update)

if self.ps.use_webpush:
return self._handle_webpush_ack(chid, version)
return self._handle_webpush_ack(chid, version, code)
else:
return self._handle_simple_ack(chid, version)
return self._handle_simple_ack(chid, version, code)

def _handle_webpush_ack(self, chid, version):
def _handle_webpush_ack(self, chid, version, code):
"""Handle clearing out a webpush ack"""
# Split off the updateid if its not a direct update
version, updateid = version.split(":")
Expand All @@ -1095,7 +1115,7 @@ def ver_filter(update):
log.msg("Ack", router_key="webpush", channelID=chid,
message_id=version, message_source="direct",
message_size=size, uaid_hash=self.ps.uaid_hash,
user_agent=self.ps.user_agent)
user_agent=self.ps.user_agent, code=code)
self.ps.direct_updates[chid].remove(msg)
return

Expand All @@ -1106,7 +1126,7 @@ def ver_filter(update):
log.msg("Ack", router_key="webpush", channelID=chid,
message_id=version, message_source="stored",
message_size=size, uaid_hash=self.ps.uaid_hash,
user_agent=self.ps.user_agent)
user_agent=self.ps.user_agent, code=code)
d = self.force_retry(self.ps.message.delete_message,
uaid=self.ps.uaid,
channel_id=chid,
Expand All @@ -1131,20 +1151,20 @@ def _handle_webpush_update_remove(self, result, chid, notif):
except AttributeError:
pass

def _handle_simple_ack(self, chid, version):
def _handle_simple_ack(self, chid, version, code):
"""Handle clearing out a simple ack"""
if chid in self.ps.direct_updates and \
self.ps.direct_updates[chid] <= version:
del self.ps.direct_updates[chid]
log.msg("Ack", router_key="simplepush", channelID=chid,
message_id=version, message_source="direct",
uaid_hash=self.ps.uaid_hash,
user_agent=self.ps.user_agent)
user_agent=self.ps.user_agent, code=code)
return
log.msg("Ack", router_key="simplepush", channelID=chid,
message_id=version, message_source="stored",
uaid_hash=self.ps.uaid_hash,
user_agent=self.ps.user_agent)
user_agent=self.ps.user_agent, code=code)
if chid in self.ps.updates_sent and \
self.ps.updates_sent[chid] <= version:
del self.ps.updates_sent[chid]
Expand All @@ -1170,6 +1190,19 @@ def process_ack(self, data):
else:
self.check_missed_notifications(None)

def process_nack(self, data):
"""Process a nack message and log its contents"""
code = extract_code(data)
version = data.get("version")
if not version:
return

version, updateid = version.split(":")

log.msg("Nack", uaid_hash=self.ps.uaid_hash,
user_agent=self.ps.user_agent, message_id=version,
code=code)

def check_missed_notifications(self, results, resume=False):
"""Check to see if notifications were missed"""
if resume:
Expand Down

0 comments on commit d59cc24

Please sign in to comment.