-
-
Notifications
You must be signed in to change notification settings - Fork 202
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
340 additions
and
79 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,14 +4,16 @@ use std::io; | |
use std::fmt; | ||
use std::error; | ||
use std::sync::Arc; | ||
use self::schannel::crypt_prov::{AcquireOptions, CryptProv, ProviderType}; | ||
use self::schannel::cert_store::{PfxImportOptions, Memory, CertStore, CertAdd}; | ||
use self::schannel::cert_context::CertContext; | ||
use self::schannel::cert_context::{CertContext, KeySpec}; | ||
use self::schannel::schannel_cred::{Direction, SchannelCred, Protocol}; | ||
use self::schannel::tls_stream; | ||
|
||
const CONTAINER_NAME: &'static str = "native-tls"; | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
sfackler
Author
Owner
|
||
|
||
fn convert_protocols(protocols: &[::Protocol]) -> Vec<Protocol> { | ||
protocols | ||
.iter() | ||
protocols.iter() | ||
.map(|p| match *p { | ||
::Protocol::Sslv3 => Protocol::Ssl3, | ||
::Protocol::Tlsv10 => Protocol::Tls10, | ||
|
@@ -66,21 +68,17 @@ impl Pkcs12 { | |
.silent(true) | ||
.compare_key(true) | ||
.acquire() | ||
.is_ok() | ||
{ | ||
.is_ok() { | ||
identity = Some(cert); | ||
} | ||
} | ||
|
||
let identity = match identity { | ||
Some(identity) => identity, | ||
None => { | ||
return Err( | ||
io::Error::new( | ||
io::ErrorKind::InvalidInput, | ||
"No identity found in PKCS #12 archive", | ||
).into(), | ||
); | ||
return Err(io::Error::new(io::ErrorKind::InvalidInput, | ||
"No identity found in PKCS #12 archive") | ||
.into()); | ||
} | ||
}; | ||
|
||
|
@@ -97,20 +95,49 @@ impl Certificate { | |
} | ||
} | ||
|
||
pub struct PrivateKey(CryptProv); | ||
|
||
impl PrivateKey { | ||
pub fn from_der(buf: &[u8]) -> Result<PrivateKey, Error> { | ||
let mut options = AcquireOptions::new(); | ||
options.container(CONTAINER_NAME) | ||
.new_keyset(true); | ||
let type_ = ProviderType::rsa_full(); | ||
|
||
// this is kind of a mess - we have to tell WinAPI to either open an | ||
// existing container or create a new one, but there's no "open or | ||
// create" option. If you try to create it and it exists it'll error | ||
// and if you try to open it and it doesn't exist it'll error. We first | ||
// try to open an existing one, then try to create it, then finally try | ||
// to open it in case a parallel caller created it concurrently. | ||
let mut container = match options.acquire(type_) { | ||
Ok(container) => container, | ||
Err(_) => { | ||
match options.new_keyset(true).acquire(type_) { | ||
Ok(container) => container, | ||
Err(_) => options.new_keyset(false).acquire(type_)?, | ||
} | ||
} | ||
}; | ||
|
||
container.import().import(buf)?; | ||
|
||
Ok(PrivateKey(container)) | ||
} | ||
} | ||
|
||
pub struct MidHandshakeTlsStream<S>(tls_stream::MidHandshakeTlsStream<S>); | ||
|
||
impl<S> fmt::Debug for MidHandshakeTlsStream<S> | ||
where | ||
S: fmt::Debug, | ||
where S: fmt::Debug | ||
{ | ||
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { | ||
fmt::Debug::fmt(&self.0, fmt) | ||
} | ||
} | ||
|
||
impl<S> MidHandshakeTlsStream<S> | ||
where | ||
S: io::Read + io::Write, | ||
where S: io::Read + io::Write | ||
{ | ||
pub fn get_ref(&self) -> &S { | ||
self.0.get_ref() | ||
|
@@ -192,26 +219,22 @@ impl TlsConnector { | |
} | ||
|
||
pub fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>> | ||
where | ||
S: io::Read + io::Write, | ||
where S: io::Read + io::Write | ||
{ | ||
self._connect(Some(domain), stream) | ||
} | ||
|
||
pub fn connect_no_domain<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>> | ||
where | ||
S: io::Read + io::Write, | ||
where S: io::Read + io::Write | ||
{ | ||
self._connect(None, stream) | ||
} | ||
|
||
fn _connect<S>( | ||
&self, | ||
domain: Option<&str>, | ||
stream: S, | ||
) -> Result<TlsStream<S>, HandshakeError<S>> | ||
where | ||
S: io::Read + io::Write, | ||
fn _connect<S>(&self, | ||
domain: Option<&str>, | ||
stream: S) | ||
-> Result<TlsStream<S>, HandshakeError<S>> | ||
where S: io::Read + io::Write | ||
{ | ||
let mut builder = SchannelCred::builder(); | ||
builder.enabled_protocols(&self.protocols); | ||
|
@@ -262,9 +285,31 @@ impl TlsAcceptor { | |
})) | ||
} | ||
|
||
pub fn builder2(_key: PrivateKey, | ||
cert: Certificate, | ||
chain: Vec<Certificate>) | ||
-> Result<TlsAcceptorBuilder, Error> { | ||
let mut store = try!(Memory::new()).into_store(); | ||
for cert in chain { | ||
try!(store.add_cert(&cert.0, CertAdd::ReplaceExisting)); | ||
} | ||
let cert = try!(store.add_cert(&cert.0, CertAdd::ReplaceExisting)); | ||
|
||
try!(cert.set_key_prov_info() | ||
.container(CONTAINER_NAME) | ||
.type_(ProviderType::rsa_full()) | ||
.keep_open(true) | ||
.key_spec(KeySpec::key_exchange()) | ||
.set()); | ||
|
||
Ok(TlsAcceptorBuilder(TlsAcceptor { | ||
cert: cert, | ||
protocols: vec![Protocol::Tls10, Protocol::Tls11, Protocol::Tls12], | ||
})) | ||
} | ||
|
||
pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>> | ||
where | ||
S: io::Read + io::Write, | ||
where S: io::Read + io::Write | ||
{ | ||
let mut builder = SchannelCred::builder(); | ||
builder.enabled_protocols(&self.protocols); | ||
|
@@ -326,14 +371,12 @@ pub trait TlsConnectorBuilderExt { | |
/// Sets a callback function which decides if the server's certificate chain | ||
/// is to be trusted. | ||
fn verify_callback<F>(&mut self, callback: F) | ||
where | ||
F: Fn(tls_stream::CertValidationResult) -> io::Result<()> + 'static + Send + Sync; | ||
where F: Fn(tls_stream::CertValidationResult) -> io::Result<()> + 'static + Send + Sync; | ||
} | ||
|
||
impl TlsConnectorBuilderExt for ::TlsConnectorBuilder { | ||
fn verify_callback<F>(&mut self, callback: F) | ||
where | ||
F: Fn(tls_stream::CertValidationResult) -> io::Result<()> + 'static + Send + Sync, | ||
where F: Fn(tls_stream::CertValidationResult) -> io::Result<()> + 'static + Send + Sync | ||
{ | ||
(self.0).0.callback = Some(Arc::new(callback)); | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
How does namespacing work? Will using a constant work if you have multiple servers with different keys active at the same time?