Skip to content

Commit

Permalink
Add the ability to save licenses for later use
Browse files Browse the repository at this point in the history
  • Loading branch information
zmb3 committed Oct 21, 2024
1 parent b5dc810 commit bf453b1
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 5 deletions.
14 changes: 14 additions & 0 deletions src/core/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ use crate::model::link::{Link, Stream};
use crate::nla::ntlm::Ntlm;
use std::io::{Read, Write};

use super::license::MemoryLicenseStore;
use super::LicenseStore;

impl From<&str> for KeyboardLayout {
fn from(e: &str) -> Self {
match e {
Expand Down Expand Up @@ -164,6 +167,8 @@ pub struct Connector {
/// Use network level authentication
/// default TRUE
use_nla: bool,
/// Stores RDS licenses for reuse.
license_store: Box<dyn LicenseStore>,
}

impl Connector {
Expand Down Expand Up @@ -193,6 +198,7 @@ impl Connector {
check_certificate: false,
name: "rdp-rs".to_string(),
use_nla: true,
license_store: Box::new(MemoryLicenseStore::new()),
}
}

Expand Down Expand Up @@ -255,6 +261,7 @@ impl Connector {
self.auto_logon,
None,
None,
&mut *self.license_store,
)?;
} else {
sec::connect(
Expand All @@ -266,6 +273,7 @@ impl Connector {
self.auto_logon,
None,
None,
&mut *self.license_store,
)?;
}

Expand Down Expand Up @@ -346,4 +354,10 @@ impl Connector {
self.use_nla = use_nla;
self
}

/// Use a custom license store implementation
pub fn use_license_store(mut self, license_store: Box<dyn LicenseStore>) -> Self {
self.license_store = license_store;
self
}
}
100 changes: 96 additions & 4 deletions src/core/license.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult};
use crate::model::rnd::random;
use crate::model::unicode;
use num_enum::TryFromPrimitive;
use std::collections::HashMap;
use std::ffi::CStr;
use std::ffi::CString;
use std::io::{self, Cursor, Read, Write};
use std::sync::Arc;
use std::sync::Mutex;

use crate::core::sec::SecurityFlag;

Expand All @@ -22,6 +25,8 @@ use rsa::{PublicKeyParts, RsaPublicKey};
use uuid::Uuid;
use x509_parser::{certificate::X509Certificate, prelude::FromDer};

use super::LicenseStore;

const SIGNATURE_ALG_RSA: u32 = 0x00000001;
const KEY_EXCHANGE_ALG_RSA: u32 = 0x00000001;
const CERT_CHAIN_VERSION_1: u32 = 0x00000001;
Expand Down Expand Up @@ -1029,14 +1034,12 @@ pub fn client_connect<T: Read + Write>(
mcs: &mut mcs::Client<T>,
client_machine: &str, // must be a UUID
username: &str,
license_store: &mut dyn LicenseStore,
) -> RdpResult<()> {
// We use the UUID that identifies the client as both the client machine name,
// and (in binary form) the hardware identifier for the client.
let client_uuid = Uuid::try_parse(client_machine)?;

// TODO(zmb3): attempt to load an existing license
let existing_license: Option<Vec<u8>> = None;

let (channel, payload) = mcs.read()?;
let session_encryption_data = match LicenseMessage::new(payload)? {
// When we get the `NewLicense` message at the start of the
Expand All @@ -1051,10 +1054,26 @@ pub fn client_connect<T: Read + Write>(
request.certificate,
);

let mut existing_license: Option<Vec<u8>> = None;
for issuer in request.scopes {
let l = license_store.read_license(
request.version_major,
request.version_minor,
&request.company_name,
&issuer,
&request.product_id,
);
if l.is_some() {
existing_license.replace(l.unwrap());
break;
}
}

// we either send information about a previously obtained license
// or a new license request, depending on whether we have a license
// cached from a previous attempt
if let Some(license) = existing_license {
println!("!! [LIC] using existing license");
let license_info = ClientLicenseInfo::new(
&session_encryption_data,
&license,
Expand All @@ -1065,6 +1084,7 @@ pub fn client_connect<T: Read + Write>(
license_response(MessageType::LicenseInfo, license_info.to_bytes()?)?,
)?
} else {
println!("!! [LIC] requesting new license");
let client_new_license_response = ClientNewLicense::new(
&session_encryption_data,
CString::new(username).unwrap_or_else(|_| CString::new("default").unwrap()),
Expand Down Expand Up @@ -1133,11 +1153,83 @@ pub fn client_connect<T: Read + Write>(
}
};

// TODO(zmb3): save the license
license_store.write_license(
license.version_major,
license.version_minor,
&license.company_name,
&license.scope,
&license.product_id,
&license.cert_data,
);

Ok(())
}

#[derive(PartialEq, Eq, Hash)]
struct LicenseStoreKey {
major: u16,
minor: u16,
company: String,
issuer: String,
product_id: String,
}

pub struct MemoryLicenseStore {
licenses: Arc<Mutex<HashMap<LicenseStoreKey, Vec<u8>>>>,
}

impl MemoryLicenseStore {
pub fn new() -> Self {
Self {
licenses: Arc::new(Mutex::new(HashMap::new())),
}
}
}

impl LicenseStore for MemoryLicenseStore {
fn write_license(
&mut self,
major: u16,
minor: u16,
company: &str,
issuer: &str,
product_id: &str,
license: &[u8],
) {
self.licenses.lock().unwrap().insert(
LicenseStoreKey {
major,
minor,
company: company.to_owned(),
issuer: issuer.to_owned(),
product_id: product_id.to_owned(),
},
license.to_vec(),
);
}

fn read_license(
&self,
major: u16,
minor: u16,
company: &str,
issuer: &str,
product_id: &str,
) -> Option<Vec<u8>> {
self.licenses
.lock()
.unwrap()
.get(&LicenseStoreKey {
major,
minor,
company: company.to_owned(),
issuer: issuer.to_owned(),
product_id: product_id.to_owned(),
})
.cloned()
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
22 changes: 22 additions & 0 deletions src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,25 @@ pub mod per;
pub mod sec;
pub mod tpkt;
pub mod x224;

/// LicenseStore provides the ability to save (and later retrieve)
/// RDS licenses.
pub trait LicenseStore {
fn write_license(
&mut self,
major: u16,
minor: u16,
company: &str,
issuer: &str,
product_id: &str,
license: &[u8],
);
fn read_license(
&self,
major: u16,
minor: u16,
company: &str,
issuer: &str,
product_id: &str,
) -> Option<Vec<u8>>;
}
5 changes: 4 additions & 1 deletion src/core/sec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use crate::model::error::RdpResult;
use crate::model::unicode::Unicode;
use std::io::{Read, Write};

use super::LicenseStore;

/// Security flag send as header flage in core ptotocol
/// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/e13405c5-668b-4716-94b2-1c2654ca1ad4?redirectedfrom=MSDN
#[repr(u16)]
Expand Down Expand Up @@ -160,6 +162,7 @@ pub fn connect<T: Read + Write>(
auto_logon: bool,
info_flags: Option<u32>,
extended_info_flags: Option<u32>,
license_store: &mut dyn LicenseStore,
) -> RdpResult<()> {
let perf_flags = if mcs.is_rdp_version_5_plus() {
extended_info_flags
Expand All @@ -176,6 +179,6 @@ pub fn connect<T: Read + Write>(
],
)?;

license::client_connect(mcs, agent_id, username)?;
license::client_connect(mcs, agent_id, username, license_store)?;
Ok(())
}

0 comments on commit bf453b1

Please sign in to comment.