Skip to content
This repository has been archived by the owner on Jul 13, 2023. It is now read-only.

Commit

Permalink
feat: port register command to Rust
Browse files Browse the repository at this point in the history
Closes #1190
  • Loading branch information
bbangert committed May 10, 2018
1 parent af272df commit 4ff1743
Show file tree
Hide file tree
Showing 13 changed files with 180 additions and 193 deletions.
3 changes: 2 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ language: python
cache:
directories:
- $HOME/.cargo
- cargo
- autopush_rs/target
- $HOME/.cache/pip
sudo: required
Expand Down Expand Up @@ -29,7 +30,7 @@ install:
- if [ ${WITH_RUST:-true} != "false" ]; then curl https://sh.rustup.rs | sh -s -- -y || travis_terminate 1; fi
- export PATH=$PATH:$HOME/.cargo/bin
script:
- tox -- ${CODECOV:+--with-coverage --cover-xml --cover-package=autopush}
- travis_wait tox -- ${CODECOV:+--with-coverage --cover-xml --cover-package=autopush}
after_success:
- ${CODECOV:+codecov}
notifications:
Expand Down
7 changes: 4 additions & 3 deletions autopush/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def broadcast_subscribe(self, services):
log.debug("Send: %s", msg)
self.ws.send(msg)

def register(self, chid=None, key=None):
def register(self, chid=None, key=None, status=200):
chid = chid or str(uuid.uuid4())
msg = json.dumps(dict(messageType="register",
channelID=chid,
Expand All @@ -146,9 +146,10 @@ def register(self, chid=None, key=None):
rcv = self.ws.recv()
result = json.loads(rcv)
log.debug("Recv: %s", result)
assert result["status"] == 200
assert result["status"] == status
assert result["channelID"] == chid
self.channels[chid] = result["pushEndpoint"]
if status == 200:
self.channels[chid] = result["pushEndpoint"]
return result

def unregister(self, chid):
Expand Down
28 changes: 13 additions & 15 deletions autopush/tests/test_rs_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import uuid
from contextlib import contextmanager
from http.server import BaseHTTPRequestHandler, HTTPServer
from mock import Mock, call, patch
from mock import Mock, patch
from threading import Thread, Event
from unittest.case import SkipTest

Expand All @@ -34,7 +34,6 @@
from autopush.logging import begin_or_register
from autopush.main import EndpointApplication, RustConnectionApplication
from autopush.utils import base64url_encode
from autopush.metrics import SinkMetrics
from autopush.tests.support import TestingLogObserver
from autopush.tests.test_integration import (
Client,
Expand Down Expand Up @@ -652,19 +651,6 @@ def test_message_without_crypto_headers(self):
assert result is None
yield self.shut_down(client)

@inlineCallbacks
def test_message_with_topic(self):
data = str(uuid.uuid4())
self.conn.db.metrics = Mock(spec=SinkMetrics)
client = yield self.quick_register()
yield client.send_notification(data=data, topic="topicname")
self.conn.db.metrics.increment.assert_has_calls([
call('ua.command.register'),
# We can't see Rust metric calls
# call('ua.notification.topic')
])
yield self.shut_down(client)

@inlineCallbacks
def test_empty_message_without_crypto_headers(self):
client = yield self.quick_register()
Expand Down Expand Up @@ -759,6 +745,18 @@ def test_with_key(self):

yield self.shut_down(client)

@inlineCallbacks
def test_with_bad_key(self):
chid = str(uuid.uuid4())
client = Client("ws://localhost:{}/".format(self.connection_port))
yield client.connect()
yield client.hello()
result = yield client.register(chid=chid, key="af1883%&!@#*(",
status=400)
assert result["status"] == 400

yield self.shut_down(client)


class TestRustWebPushBroadcast(unittest.TestCase):
connection_port = 9050
Expand Down
61 changes: 0 additions & 61 deletions autopush/tests/test_webpush_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
DeleteMessage,
DropUser,
MigrateUser,
Register,
StoreMessages,
Unregister,
WebPushMessage,
Expand Down Expand Up @@ -310,66 +309,6 @@ def test_no_migrate(self):
assert db.message.tablename == tablename


class TestRegisterProcessor(BaseSetup):

def _makeFUT(self):
from autopush.webpush_server import RegisterCommand
return RegisterCommand(self.conf, self.db)

def test_register(self):
cmd = self._makeFUT()
chid = str(uuid4())
result = cmd.process(Register(
uaid=uuid4().hex,
channel_id=chid,
message_month=self.db.current_msg_month)
)
assert result.endpoint
assert self.metrics.increment.called
assert self.metrics.increment.call_args[0][0] == 'ua.command.register'
assert self.logs.logged(
lambda e: (e['log_format'] == "Register" and
e['channel_id'] == chid and
e['endpoint'] == result.endpoint)
)

def _test_invalid(self, chid, msg="use lower case, dashed format",
status=401):
cmd = self._makeFUT()
result = cmd.process(Register(
uaid=uuid4().hex,
channel_id=chid,
message_month=self.db.current_msg_month)
)
assert result.error
assert msg in result.error_msg
assert status == result.status

def test_register_bad_chid(self):
self._test_invalid("oof", "Invalid UUID")

def test_register_bad_chid_upper(self):
self._test_invalid(str(uuid4()).upper())

def test_register_bad_chid_nodash(self):
self._test_invalid(uuid4().hex)

def test_register_over_provisioning(self):
import autopush

def raise_condition(*args, **kwargs):
from botocore.exceptions import ClientError
raise ClientError(
{'Error': {'Code': 'ProvisionedThroughputExceededException'}},
'mock_update_item'
)

mock_table = Mock(spec=autopush.db.Message)
mock_table.register_channel = Mock(side_effect=raise_condition)
self.db.message_table = Mock(return_value=mock_table)
self._test_invalid(str(uuid4()), "overloaded", 503)


class TestUnregisterProcessor(BaseSetup):

def _makeFUT(self):
Expand Down
55 changes: 0 additions & 55 deletions autopush/webpush_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,22 +266,19 @@ def __init__(self, conf, db):
self.delete_message_processor = DeleteMessageCommand(conf, db)
self.drop_user_processor = DropUserCommand(conf, db)
self.migrate_user_proocessor = MigrateUserCommand(conf, db)
self.register_process = RegisterCommand(conf, db)
self.unregister_process = UnregisterCommand(conf, db)
self.store_messages_process = StoreMessagesUserCommand(conf, db)
self.deserialize = dict(
delete_message=DeleteMessage,
drop_user=DropUser,
migrate_user=MigrateUser,
register=Register,
unregister=Unregister,
store_messages=StoreMessages,
)
self.command_dict = dict(
delete_message=self.delete_message_processor,
drop_user=self.drop_user_processor,
migrate_user=self.migrate_user_proocessor,
register=self.register_process,
unregister=self.unregister_process,
store_messages=self.store_messages_process,
) # type: Dict[str, ProcessorCommand]
Expand Down Expand Up @@ -434,58 +431,6 @@ def _validate_chid(chid):
return True, None


@attrs(slots=True)
class Register(InputCommand):
channel_id = attrib() # type: str
uaid = attrib(convert=uaid_from_str) # type: Optional[UUID]
message_month = attrib() # type: str
key = attrib(default=None) # type: str


@attrs(slots=True)
class RegisterResponse(OutputCommand):
endpoint = attrib() # type: str


@attrs(slots=True)
class RegisterErrorResponse(OutputCommand):
error_msg = attrib() # type: str
error = attrib(default=True) # type: bool
status = attrib(default=401) # type: int


class RegisterCommand(ProcessorCommand):

def process(self, command):
# type: (Register) -> Union[RegisterResponse, RegisterErrorResponse]
valid, msg = _validate_chid(command.channel_id)
if not valid:
return RegisterErrorResponse(error_msg=msg)

endpoint = self.conf.make_endpoint(
command.uaid.hex,
command.channel_id,
command.key
)
message = self.db.message_table(command.message_month)
try:
message.register_channel(command.uaid.hex,
command.channel_id)
except ClientError as ex:
if (ex.response['Error']['Code'] ==
"ProvisionedThroughputExceededException"):
return RegisterErrorResponse(error_msg="overloaded",
status=503)
self.metrics.increment('ua.command.register')
log.info(
"Register",
channel_id=command.channel_id,
endpoint=endpoint,
uaid_hash=hasher(command.uaid.hex),
)
return RegisterResponse(endpoint=endpoint)


@attrs(slots=True)
class Unregister(InputCommand):
channel_id = attrib() # type: str
Expand Down
15 changes: 15 additions & 0 deletions autopush_rs/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions autopush_rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@ authors = ["Alex Crichton <[email protected]>"]
crate-type = ["cdylib"]

[dependencies]
base64 = "0.9.1"
bytes = "0.4.6"
cadence = "0.13.2"
chrono = "0.4.2"
env_logger = { version = "0.5.6", default-features = false }
error-chain = "0.11.0"
fernet = "0.1"
futures = "0.1.21"
futures-backoff = "0.1"
hex = "0.3.2"
hostname = "0.1.4"
httparse = "1.2.4"
hyper = "0.11.25"
Expand Down
7 changes: 6 additions & 1 deletion autopush_rs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ def __init__(self, conf, message_tables, queue):
cfg.port = conf.port
cfg.router_port = conf.router_port
cfg.router_url = ffi_from_buffer(conf.router_url)
cfg.endpoint_url = ffi_from_buffer(conf.endpoint_url)
self.crypto_key = ','.join(name.encode('utf-8') for name in
conf._crypto_key)
cfg.crypto_key = ffi_from_buffer(self.crypto_key)
cfg.ssl_cert = ffi_from_buffer(conf.ssl.cert)
cfg.ssl_dh_param = ffi_from_buffer(conf.ssl.dh_param)
cfg.ssl_key = ffi_from_buffer(conf.ssl.key)
Expand All @@ -38,7 +42,8 @@ def __init__(self, conf, message_tables, queue):
cfg.statsd_port = conf.statsd_port
cfg.router_table_name = ffi_from_buffer(conf.router_table.tablename)
# XXX: keepalive
self.message_table_names = ','.join(name.encode('utf-8') for name in message_tables)
self.message_table_names = ','.join(name.encode('utf-8') for name in
message_tables)
cfg.message_table_names = ffi_from_buffer(self.message_table_names)
cfg.megaphone_api_url = ffi_from_buffer(conf.megaphone_api_url)
cfg.megaphone_api_token = ffi_from_buffer(conf.megaphone_api_token)
Expand Down
38 changes: 0 additions & 38 deletions autopush_rs/src/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,6 @@ impl<F: FnOnce(&str) + Send> FnBox for F {
#[derive(Serialize)]
#[serde(tag = "command", rename_all = "snake_case")]
enum Call {
Register {
uaid: String,
channel_id: String,
message_month: String,
key: Option<String>,
},

Unregister {
uaid: String,
channel_id: String,
Expand Down Expand Up @@ -156,20 +149,6 @@ struct PythonError {
pub error_msg: String,
}

#[derive(Deserialize)]
#[serde(untagged)]
pub enum RegisterResponse {
Success {
endpoint: String,
},

Error {
error_msg: String,
error: bool,
status: u32,
},
}

#[derive(Deserialize)]
#[serde(untagged)]
pub enum UnRegisterResponse {
Expand Down Expand Up @@ -205,23 +184,6 @@ pub struct StoreMessagesResponse {
}

impl Server {
pub fn register(
&self,
uaid: String,
message_month: String,
channel_id: String,
key: Option<String>,
) -> MyFuture<RegisterResponse> {
let (call, fut) = PythonCall::new(&Call::Register {
uaid: uaid,
message_month: message_month,
channel_id: channel_id,
key: key,
});
self.send_to_python(call);
return fut;
}

pub fn unregister(
&self,
uaid: String,
Expand Down
Loading

0 comments on commit 4ff1743

Please sign in to comment.