Skip to content

Commit

Permalink
Adds getsubscriptioninfo to the CLN plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
sr-gi committed Aug 16, 2022
1 parent 4b70825 commit e671bbe
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 15 deletions.
4 changes: 4 additions & 0 deletions teos-common/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.field_attribute("appointment_data", "#[serde(rename = \"appointment\")]")
.field_attribute("user_id", "#[serde(with = \"hex::serde\")]")
.field_attribute("locator", "#[serde(with = \"hex::serde\")]")
.field_attribute(
"locators",
"#[serde(with = \"crate::ser::serde_vec_bytes\")]",
)
.field_attribute("encrypted_blob", "#[serde(with = \"hex::serde\")]")
.field_attribute(
"GetAppointmentResponse.status",
Expand Down
48 changes: 48 additions & 0 deletions teos-common/src/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,54 @@ where
seq.end()
}

pub mod serde_vec_bytes {
use super::*;
use serde::de::{self, Deserializer, SeqAccess};

pub fn serialize<S>(v: &[Vec<u8>], s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut seq = s.serialize_seq(Some(v.len()))?;
for element in v.iter() {
seq.serialize_element(&hex::encode(element))?;
}
seq.end()
}

pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<Vec<u8>>, D::Error>
where
D: Deserializer<'de>,
{
struct VecVisitor;

impl<'de> de::Visitor<'de> for VecVisitor {
type Value = Vec<Vec<u8>>;

fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a hex encoded string")
}

fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let mut result = Vec::new();
while let Some(v) = seq.next_element::<String>()? {
result
.push(hex::decode(v).map_err(|_| {
de::Error::custom("cannot deserialize the given value")
})?);
}

Ok(result)
}
}

deserializer.deserialize_any(VecVisitor)
}
}

pub mod serde_status {
use serde::de::{self, Deserializer};
use serde::ser::Serializer;
Expand Down
4 changes: 2 additions & 2 deletions teos/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.field_attribute("tower_id", "#[serde(with = \"hex::serde\")]")
.field_attribute(
"user_ids",
"#[serde(serialize_with = \"crate::api::http::serialize_vec_bytes\")]",
"#[serde(serialize_with = \"teos_common::ser::serde_vec_bytes::serialize\")]",
)
.field_attribute(
"GetUserResponse.appointments",
"#[serde(serialize_with = \"crate::api::http::serialize_vec_bytes\")]",
"#[serde(serialize_with = \"teos_common::ser::serde_vec_bytes::serialize\")]",
)
.compile(
&[
Expand Down
13 changes: 1 addition & 12 deletions teos/src/api/http.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer};
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::error::Error;
use std::net::SocketAddr;
Expand Down Expand Up @@ -58,17 +58,6 @@ impl ApiError {
}
}

pub fn serialize_vec_bytes<S>(v: &[Vec<u8>], s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut seq = s.serialize_seq(Some(v.len()))?;
for element in v.iter() {
seq.serialize_element(&hex::encode(element))?;
}
seq.end()
}

fn with_grpc(
grpc_endpoint: PublicTowerServicesClient<Channel>,
) -> impl Filter<Extract = (PublicTowerServicesClient<Channel>,), Error = Infallible> + Clone {
Expand Down
3 changes: 2 additions & 1 deletion watchtower-plugin/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ commitment transaction is generated. It also keeps a summary of the messages sen
The plugin has the following methods:

- `registertower tower_id` : registers the user id (compressed public key) with a given tower.
- `list_towers`: lists all registered towers.
- `listtowers`: lists all registered towers.
- `gettowerinfo tower_id`: gets all the locally stored data about a given tower.
- `getsubscriptioninfo tower_id`: gets the subscription information by querying the tower.
- `retrytower tower_id`: tries to send pending appointment to a (previously) unreachable tower.
- `getappointment tower_id locator`: queries a given tower about an appointment.

Expand Down
47 changes: 47 additions & 0 deletions watchtower-plugin/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,48 @@ async fn get_appointment(
Ok(json!(response))
}

/// Gets the subscription information directly form the tower.
async fn get_subscription_info(
plugin: Plugin<Arc<Mutex<WTClient>>>,
v: serde_json::Value,
) -> Result<serde_json::Value, Error> {
let tower_id = TowerId::try_from(v).map_err(|x| anyhow!(x))?;

let user_sk = plugin.state().lock().unwrap().user_sk;
let tower_net_addr = {
let state = plugin.state().lock().unwrap();
if let Some(info) = state.towers.get(&tower_id) {
Ok(info.net_addr.clone())
} else {
Err(anyhow!("Unknown tower id: {}", tower_id))
}
}?;

let get_subscription_info = format!("{}/get_subscription_info", tower_net_addr);
let signature = cryptography::sign("get subscription info".as_bytes(), &user_sk).unwrap();

let response: common_msgs::GetSubscriptionInfoResponse = process_post_response(
post_request(
&get_subscription_info,
&common_msgs::GetSubscriptionInfoRequest { signature },
)
.await,
)
.await
.map_err(|e| {
if e.is_connection() {
plugin
.state()
.lock()
.unwrap()
.set_tower_status(tower_id, TowerStatus::TemporaryUnreachable);
}
to_cln_error(e)
})?;

Ok(json!(response))
}

/// Lists all the registered towers.
///
/// The given information comes from memory, so it is summarized.
Expand Down Expand Up @@ -408,6 +450,11 @@ async fn main() -> Result<(), Error> {
"Gets appointment data from the tower given the tower id and the locator.",
get_appointment,
)
.rpcmethod(
"getsubscriptioninfo",
"Gets the subscription information directly from the tower.",
get_subscription_info,
)
.rpcmethod("listtowers", "Lists all registered towers.", list_towers)
.rpcmethod(
"gettowerinfo",
Expand Down

0 comments on commit e671bbe

Please sign in to comment.