Skip to content

Commit

Permalink
Adds named arguments to CoreLN plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
sr-gi committed Dec 14, 2022
1 parent ae91cd5 commit d0a476d
Show file tree
Hide file tree
Showing 3 changed files with 249 additions and 18 deletions.
102 changes: 102 additions & 0 deletions teos-common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use std::fmt;
use std::{convert::TryFrom, str::FromStr};

use serde::{Deserialize, Serialize};
use serde_json::json;

use bitcoin::secp256k1::{Error, PublicKey};

Expand Down Expand Up @@ -79,10 +80,111 @@ impl TryFrom<serde_json::Value> for UserId {
))
}
}
serde_json::Value::Object(mut m) => {
let param_count = m.len();
if param_count > 1 {
Err(format!(
"Unexpected json format. Expected a single parameter. Received: {}",
param_count
))
} else {
UserId::try_from(json!(m
.remove("user_id")
.or_else(|| m.remove("tower_id"))
.ok_or("user_id or tower_id not found")?))
}
}
_ => Err(format!(
"Unexpected request format. Expected: user_id/tower_id. Received: '{}'",
value
)),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use std::collections::HashMap;

use crate::test_utils::get_random_user_id;

#[test]
fn try_from_json_string() {
let user_id = get_random_user_id();
assert_eq!(UserId::try_from(json!(user_id.to_string())), Ok(user_id));
}

#[test]
fn try_from_json_wrong_string() {
let user_id = "not_a_user_id";
assert!(matches!(
UserId::try_from(json!(user_id.to_string())),
Err(..)
));
}

#[test]
fn try_from_json_array() {
let user_id = get_random_user_id();
assert_eq!(UserId::try_from(json!([user_id.to_string()])), Ok(user_id));
}

#[test]
fn try_from_json_array_empty() {
assert!(matches!(UserId::try_from(json!([])), Err(..)));
}

#[test]
fn try_from_json_array_too_many_elements() {
let user_id = get_random_user_id();
assert!(matches!(
UserId::try_from(json!([user_id.to_string(), user_id.to_string()])),
Err(..)
));
}

#[test]
fn try_from_json_dict() {
let user_id = get_random_user_id();
assert_eq!(
UserId::try_from(json!(HashMap::from([("tower_id", user_id.to_string())]))),
Ok(user_id)
);
assert_eq!(
UserId::try_from(json!(HashMap::from([("user_id", user_id.to_string())]))),
Ok(user_id)
);
}

#[test]
fn try_from_json_empty_dict() {
assert!(matches!(
UserId::try_from(json!(HashMap::<String, serde_json::Value>::new())),
Err(..)
));
}

#[test]
fn try_from_json_wrong_dict() {
let user_id = get_random_user_id();
assert!(matches!(
UserId::try_from(json!(HashMap::from([("random_key", user_id.to_string())]))),
Err(..)
));
}

#[test]
fn try_from_json_dict_too_many_keys() {
let user_id = get_random_user_id();

assert!(matches!(
UserId::try_from(json!(HashMap::from([
("tower_id", user_id.to_string()),
("user_id", user_id.to_string())
]))),
Err(..)
));
}
}
163 changes: 146 additions & 17 deletions watchtower-plugin/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::{convert::TryFrom, str::FromStr};

use hex::FromHex;
use serde::{Deserialize, Serialize};
use serde_json::json;

use bitcoin::{Transaction, Txid};

Expand Down Expand Up @@ -120,33 +121,44 @@ impl TryFrom<serde_json::Value> for RegisterParams {
},
serde_json::Value::Array(mut a) => {
let param_count = a.len();

match param_count {
1 => RegisterParams::try_from(a.pop().unwrap()),
2 | 3 => {
let tower_id = a.get(0).unwrap();
let host = a.get(1).unwrap();

if !tower_id.is_string() {
return Err(RegisterError::InvalidId(format!("tower_id must be a string. Received: {}", tower_id)));
}
if !host.is_string() {
return Err(RegisterError::InvalidHost(format!("host must be a string. Received: {}", host)));
}
let port = if param_count == 3 {
let p = a.get(2).unwrap();
if !p.is_u64() {
return Err(RegisterError::InvalidPort(format!("port must be a number. Received: {}", p)));
}
p.as_u64()
} else{
let tower_id = a.get(0).unwrap().as_str().ok_or_else(|| RegisterError::InvalidId("tower_id must be a string".to_string()))?;
let host = Some(a.get(1).unwrap().as_str().ok_or_else(|| RegisterError::InvalidHost("host must be a string".to_string()))?);
let port = if let Some(p) = a.get(2) {
Some(p.as_u64().ok_or_else(|| RegisterError::InvalidPort(format!("port must be a number. Received: {}", p)))?)
} else {
None
};

RegisterParams::new(tower_id.as_str().unwrap(), host.as_str(), port)
RegisterParams::new(tower_id, host, port)
}
_ => Err(RegisterError::InvalidFormat(format!("Unexpected request format. The request needs 1-3 parameters. Received: {}", param_count))),
}
},
serde_json::Value::Object(mut m) => {
let allowed_keys = ["tower_id", "host", "port"];
let param_count = m.len();

if m.is_empty() || param_count > allowed_keys.len() {
Err(RegisterError::InvalidFormat(format!("Unexpected request format. The request needs 1-3 parameters. Received: {}", param_count)))
} else if !m.contains_key(allowed_keys[0]){
Err(RegisterError::InvalidId(format!("{} is mandatory", allowed_keys[0])))
} else if !m.iter().all(|(k, _)| allowed_keys.contains(&k.as_str())) {
Err(RegisterError::InvalidFormat("Invalid named parameter found in request".to_owned()))
} else {
let mut params = Vec::with_capacity(allowed_keys.len());
for k in allowed_keys {
if let Some(v) = m.remove(k) {
params.push(v);
}
}

RegisterParams::try_from(json!(params))
}
},
_ => Err(RegisterError::InvalidFormat(
format!("Unexpected request format. Expected: 'tower_id[@host][:port]' or 'tower_id [host] [port]'. Received: '{}'", value),
)),
Expand Down Expand Up @@ -215,6 +227,33 @@ impl TryFrom<serde_json::Value> for GetAppointmentParams {
Ok(Self { tower_id, locator })
}
}
serde_json::Value::Object(mut m) => {
let allowed_keys = ["tower_id", "locator"];

if m.len() > allowed_keys.len() {
return Err(GetAppointmentError::InvalidFormat(
"Invalid named argument found in request".to_owned(),
));
}

// DISCUSS: There may be a more idiomatic way of doing this
for k in allowed_keys.iter() {
if !m.contains_key(*k) {
return Err(GetAppointmentError::InvalidFormat(format!(
"{} is mandatory",
k
)));
}
}

let mut params = Vec::with_capacity(allowed_keys.len());
for k in allowed_keys {
if let Some(v) = m.remove(k) {
params.push(v);
}
}
GetAppointmentParams::try_from(json!(params))
}
_ => Err(GetAppointmentError::InvalidFormat(format!(
"Unexpected request format. Expected: tower_id locator. Received: '{}'",
value
Expand All @@ -238,6 +277,7 @@ pub struct CommitmentRevocation {
mod tests {
use super::*;
use serde_json::json;
use std::collections::HashMap;

const VALID_ID: &str = "020dea894c967319407265764aba31bdef75d463f96800f34dd6df61380d82dfc0";

Expand Down Expand Up @@ -372,6 +412,61 @@ mod tests {
assert!(matches!(p, Err(RegisterError::InvalidFormat(..))));
}

#[test]
fn test_try_from_json_dict() {
let id = json!(VALID_ID);
let host = json!("host");
let port = json!(80);

for v in [
HashMap::from([("tower_id", &id), ("host", &host), ("port", &port)]),
HashMap::from([("tower_id", &id), ("host", &host)]),
HashMap::from([("tower_id", &id)]),
] {
let p = RegisterParams::try_from(json!(v));
assert!(matches!(p, Ok(..)));
}

// Id key missing
let p =
RegisterParams::try_from(json!(HashMap::from([("host", &host), ("port", &port)])));
assert!(matches!(p, Err(RegisterError::InvalidId(..))));

// Wrong id key
let p = RegisterParams::try_from(json!(HashMap::from([
("wrong_tower_id", &id),
("tower_id", &id),
("host", &host),
("port", &port)
])));
assert!(matches!(p, Err(RegisterError::InvalidFormat(..))));

// Wrong host key
let p = RegisterParams::try_from(json!(HashMap::from([
("tower_id", &id),
("wrong_host", &host),
("port", &port)
])));
assert!(matches!(p, Err(RegisterError::InvalidFormat(..))));

// Wrong port key
let p = RegisterParams::try_from(json!(HashMap::from([
("tower_id", &id),
("host", &host),
("wrong_port", &port)
])));
assert!(matches!(p, Err(RegisterError::InvalidFormat(..))));

// Wrong param count (params should be 1-3)
let p = RegisterParams::try_from(json!(HashMap::from([
("tower_id", &id),
("host", &host),
("port", &port),
("another_param", &json!(0))
])));
assert!(matches!(p, Err(RegisterError::InvalidFormat(..))));
}

#[test]
fn test_try_from_other_json() {
// Unexpected json object (it must be either String or Array)
Expand Down Expand Up @@ -415,6 +510,40 @@ mod tests {
assert!(matches!(p, Err(GetAppointmentError::InvalidLocator(..))));
}

#[test]
fn test_try_from_dict() {
let id = json!(VALID_ID);
let locator = json!("c69517f00d9482e6b1c41639f9bdfd5c");

// Valid params
let p = GetAppointmentParams::try_from(json!(HashMap::from([
("tower_id", &id),
("locator", &locator)
])));
assert!(matches!(p, Ok(..)));

// Wrong keys
let p = GetAppointmentParams::try_from(json!(HashMap::from([
("wrong_tower_id", &id),
("locator", &locator)
])));
assert!(matches!(p, Err(GetAppointmentError::InvalidFormat(..))));

let p = GetAppointmentParams::try_from(json!(HashMap::from([
("tower_id", &id),
("wrong_locator", &locator)
])));
assert!(matches!(p, Err(GetAppointmentError::InvalidFormat(..))));

// Too many parameters
let p = GetAppointmentParams::try_from(json!(HashMap::from([
("tower_id", &id),
("locator", &locator),
("another_param", &json!(0))
])));
assert!(matches!(p, Err(GetAppointmentError::InvalidFormat(..))));
}

#[test]
fn test_try_from_other_json() {
// Unexpected json object (it must be either String or Array)
Expand Down
2 changes: 1 addition & 1 deletion watchtower-plugin/src/net/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ pub async fn post_request<S: Serialize>(
};

client.post(endpoint).json(&data).send().await.map_err(|e| {
log::debug!("POST request failed: {:?}", e);
log::debug!("An error ocurred when sending data to the tower: {}", e);
if e.is_connect() | e.is_timeout() {
RequestError::ConnectionError(
"Cannot connect to the tower. Connection refused".to_owned(),
Expand Down

0 comments on commit d0a476d

Please sign in to comment.