Skip to content

Commit

Permalink
refactor: Tokenserver: Rewrite inlined Python code in Rust (#1053)
Browse files Browse the repository at this point in the history
Closes #1049
  • Loading branch information
ethowitz authored Apr 28, 2021
1 parent 2ce4570 commit 34fe585
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 123 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ grpcio = { version = "0.8" }
lazy_static = "1.4.0"
pyo3 = "0.13"
hawk = "3.2"
hex = "0.4.3"
hostname = "0.3.1"
hkdf = "0.10"
hmac = "0.10"
Expand Down
228 changes: 105 additions & 123 deletions src/web/tokenserver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ use diesel::mysql::MysqlConnection;
use diesel::prelude::*;
use diesel::sql_types::*;
use diesel::RunQueryDsl;
use hmac::{Hmac, Mac, NewMac};
use sha2::Sha256;
use std::env;

use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use pyo3::prelude::*;
use pyo3::types::IntoPyDict;
use pyo3::types::{IntoPyDict, PyDict};

#[derive(Debug)]
enum MyError {
Expand All @@ -30,6 +32,55 @@ impl From<env::VarError> for MyError {
}
}

pub struct Tokenlib<'a> {
py: Python<'a>,
inner: &'a PyModule,
}

impl<'a> Tokenlib<'a> {
pub fn new(py: Python<'a>) -> Result<Self, PyErr> {
let inner = PyModule::import(py, "tokenlib").map_err(|e| {
e.print_and_set_sys_last_vars(py);
e
})?;

Ok(Self { py, inner })
}

pub fn make_token(&self, plaintext: &PyDict, shared_secret: &str) -> Result<String, PyErr> {
let kwargs = PyDict::new(self.py);
kwargs.set_item("secret", shared_secret)?;

match self.inner.call("make_token", (plaintext,), Some(kwargs)) {
Err(e) => {
e.print_and_set_sys_last_vars(self.py);
Err(e)
}
Ok(x) => Ok(x.extract::<String>().unwrap()),
}
}

pub fn get_derived_secret(
&self,
plaintext: &str,
shared_secret: &str,
) -> Result<String, PyErr> {
let kwargs = PyDict::new(self.py);
kwargs.set_item("secret", shared_secret)?;

match self
.inner
.call("get_derived_secret", (plaintext,), Some(kwargs))
{
Err(e) => {
e.print_and_set_sys_last_vars(self.py);
Err(e)
}
Ok(x) => Ok(x.extract::<String>().unwrap()),
}
}
}

#[derive(Debug, QueryableByName)]
struct TokenserverUser {
#[sql_type = "Bigint"]
Expand Down Expand Up @@ -138,128 +189,43 @@ pub fn get_sync(
.bind::<Text, _>(&email)
.load::<TokenserverUser>(&connection)
.unwrap();
let (python_result, python_derived_result) = Python::with_gil(|py| {
let tokenlib = PyModule::from_code(
py,
r###"
import base64
from hashlib import sha256
import hmac
import tokenlib
def make_token(plaintext, shared_secret):
return tokenlib.make_token(plaintext, secret=shared_secret)
def get_derived_secret(plaintext, shared_secret):
return tokenlib.get_derived_secret(plaintext, secret=shared_secret)
def encode_bytes(value):
"""Encode BrowserID's base64 encoding format.
BrowserID likes to strip padding characters off of base64-encoded strings,
meaning we can't use the stdlib routines to encode them directly. This
is a simple wrapper that strips the padding.
"""
if isinstance(value, str):
value = value.encode("ascii")
return base64.urlsafe_b64encode(value).rstrip(b"=").decode("ascii")
def fxa_metrics_hash(value, hmac_key):
"""Derive FxA metrics id from user's FxA email address or whatever.
This is used to obfuscate the id before logging it with the metrics
data, as a simple privacy measure.
"""
hasher = hmac.new(hmac_key.encode("ascii"), ''.encode("ascii"), sha256)
hasher.update(value.split("@", 1)[0].encode("ascii"))
return hasher.hexdigest()
def hash_device_id(fxa_uid, device, secret):
return fxa_metrics_hash(fxa_uid[:32] + device, secret)[:32]
"###,
"main.py",
"main",
)
.map_err(|e| {
e.print_and_set_sys_last_vars(py);
e
})?;
let client_state_b64 = match tokenlib.call1("encode_bytes", (&user_record[0].client_state,))
{
Err(e) => {
e.print_and_set_sys_last_vars(py);
return Err(e);
}
Ok(x) => x.extract::<String>().unwrap(),
};
let hashed_fxa_uid = match tokenlib.call1(
"fxa_metrics_hash",
(
&email,
env::var("FXA_METRICS_HASH_SECRET").unwrap_or_else(|_| "insecure".to_string()),
),
) {
Err(e) => {
e.print_and_set_sys_last_vars(py);
return Err(e);
}
Ok(x) => x.extract::<String>().unwrap(),
};
let device_id = "none".to_string();
let fxa_metrics_hash_secret =
env::var("FXA_METRICS_HASH_SECRET").unwrap_or_else(|_| "insecure".to_string());
let hashed_device_id = match tokenlib.call1(
"hash_device_id",
(&hashed_fxa_uid, device_id, &fxa_metrics_hash_secret),
) {
Err(e) => {
e.print_and_set_sys_last_vars(py);
return Err(e);
}
Ok(x) => x.extract::<String>().unwrap(),
};

let fxa_kid = format!(
"{:013}-{:}",
user_record[0].keys_changed_at.unwrap_or(0),
client_state_b64
);
let thedict = [
("node", &user_record[0].node),
("fxa_kid", &fxa_kid), // userid component of authorization email
("fxa_uid", &token_data.claims.sub),
("hashed_device_id", &hashed_device_id),
("hashed_fxa_uid", &hashed_fxa_uid),
]
.into_py_dict(py);
// todo don't hardcode
// we're supposed to check the "duration" query
// param and use that if present (for testing)
thedict.set_item("expires", 300).unwrap(); // todo this needs to be converted to timestamp int (now + value * 1000)
thedict.set_item("uid", user_record[0].uid).unwrap();
let result = match tokenlib.call1("make_token", (thedict, &shared_secret)) {
Err(e) => {
e.print_and_set_sys_last_vars(py);
return Err(e);
}
Ok(x) => x.extract::<String>().unwrap(),
};
let derived_result = match tokenlib.call1("get_derived_secret", (&result, &shared_secret)) {
Err(e) => {
e.print_and_set_sys_last_vars(py);
return Err(e);
}
Ok(x) => x.extract::<String>().unwrap(),
};
//assert_eq!(result, false);
Ok((result, derived_result))
})
.unwrap();
let client_state_b64 =
base64::encode_config(&user_record[0].client_state, base64::URL_SAFE_NO_PAD);
let fxa_metrics_hash_secret = env::var("FXA_METRICS_HASH_SECRET")
.unwrap_or_else(|_| "insecure".to_string())
.into_bytes();
let hashed_fxa_uid = fxa_metrics_hash(&email, &fxa_metrics_hash_secret);
let device_id = "none".to_string();
let hashed_device_id = hash_device_id(&hashed_fxa_uid, &device_id, &fxa_metrics_hash_secret);

let fxa_kid = format!(
"{:013}-{:}",
user_record[0].keys_changed_at.unwrap_or(0),
client_state_b64
);
let (python_result, python_derived_result) =
Python::with_gil(|py| -> Result<(String, String), PyErr> {
let thedict = [
("node", &user_record[0].node),
("fxa_kid", &fxa_kid), // userid component of authorization email
("fxa_uid", &token_data.claims.sub),
("hashed_device_id", &hashed_device_id),
("hashed_fxa_uid", &hashed_fxa_uid),
]
.into_py_dict(py);
// todo don't hardcode
// we're supposed to check the "duration" query
// param and use that if present (for testing)
thedict.set_item("expires", 300).unwrap(); // todo this needs to be converted to timestamp int (now + value * 1000)
thedict.set_item("uid", user_record[0].uid).unwrap();

let tokenlib = Tokenlib::new(py)?;
let result = tokenlib.make_token(thedict, &shared_secret)?;
let derived_result = tokenlib.get_derived_secret(&result, &shared_secret)?;
//assert_eq!(result, false);
Ok((result, derived_result))
})
.unwrap();
let api_endpoint = format!("{:}/1.5/{:}", user_record[0].node, user_record[0].uid);
Ok(TokenServerResult {
id: python_result,
Expand All @@ -269,3 +235,19 @@ def hash_device_id(fxa_uid, device, secret):
duration: "300".to_string(),
})
}

fn fxa_metrics_hash(value: &str, hmac_key: &[u8]) -> String {
let mut mac = Hmac::<Sha256>::new_varkey(hmac_key).unwrap();
let v = value.split('@').next().unwrap();
mac.update(v.as_bytes());

let result = mac.finalize().into_bytes();
hex::encode(result)
}

fn hash_device_id(fxa_uid: &str, device: &str, hmac_key: &[u8]) -> String {
let mut to_hash = String::from(&fxa_uid[0..32]);
to_hash.push_str(device);

String::from(&fxa_metrics_hash(&to_hash, hmac_key)[0..32])
}

0 comments on commit 34fe585

Please sign in to comment.