diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index aa57bff0..0c81ec71 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -696,6 +696,39 @@ impl SslCurve { #[cfg(feature = "pq-experimental")] pub const P256_KYBER768_DRAFT00: SslCurve = SslCurve(ffi::NID_P256Kyber768Draft00); +} + +/// A TLS Curve group ID. +#[repr(transparent)] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct SslCurveId(u16); + +impl SslCurveId { + pub const SECP224R1: SslCurveId = SslCurveId(ffi::SSL_CURVE_SECP224R1 as _); + + pub const SECP256R1: SslCurveId = SslCurveId(ffi::SSL_CURVE_SECP256R1 as _); + + pub const SECP384R1: SslCurveId = SslCurveId(ffi::SSL_CURVE_SECP384R1 as _); + + pub const SECP521R1: SslCurveId = SslCurveId(ffi::SSL_CURVE_SECP521R1 as _); + + pub const X25519: SslCurveId = SslCurveId(ffi::SSL_CURVE_X25519 as _); + + #[cfg(not(feature = "fips"))] + pub const X25519_KYBER768_DRAFT00: SslCurveId = + SslCurveId(ffi::SSL_CURVE_X25519_KYBER768_DRAFT00 as _); + + #[cfg(feature = "pq-experimental")] + pub const X25519_KYBER768_DRAFT00_OLD: SslCurveId = + SslCurveId(ffi::SSL_CURVE_X25519_KYBER768_DRAFT00_OLD as _); + + #[cfg(feature = "pq-experimental")] + pub const X25519_KYBER512_DRAFT00: SslCurveId = + SslCurveId(ffi::SSL_CURVE_X25519_KYBER512_DRAFT00 as _); + + #[cfg(feature = "pq-experimental")] + pub const P256_KYBER768_DRAFT00: SslCurveId = + SslCurveId(ffi::SSL_CURVE_P256_KYBER768_DRAFT00 as _); /// Returns the curve name /// @@ -704,7 +737,7 @@ impl SslCurve { /// [`SSL_get_curve_name`]: https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#SSL_get_curve_name pub fn name(&self) -> Option<&'static str> { unsafe { - let ptr = ffi::SSL_get_curve_name(self.0 as u16); + let ptr = ffi::SSL_get_curve_name(self.0); if ptr.is_null() { return None; } @@ -2766,12 +2799,12 @@ impl SslRef { /// This corresponds to [`SSL_get_curve_id`] /// /// [`SSL_get_curve_id`]: https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#SSL_get_curve_id - pub fn curve(&self) -> Option { + pub fn curve(&self) -> Option { let curve_id = unsafe { ffi::SSL_get_curve_id(self.as_ptr()) }; if curve_id == 0 { return None; } - Some(SslCurve(curve_id.into())) + Some(SslCurveId(curve_id)) } /// Returns an `ErrorCode` value for the most recent operation on this `SslRef`. diff --git a/boring/src/ssl/test/mod.rs b/boring/src/ssl/test/mod.rs index 1abdde3a..f40cb7f4 100644 --- a/boring/src/ssl/test/mod.rs +++ b/boring/src/ssl/test/mod.rs @@ -11,9 +11,9 @@ use crate::error::ErrorStack; use crate::hash::MessageDigest; use crate::pkey::PKey; use crate::srtp::SrtpProfileId; -use crate::ssl; use crate::ssl::test::server::Server; use crate::ssl::SslVersion; +use crate::ssl::{self, SslCurveId}; use crate::ssl::{ ExtensionType, ShutdownResult, ShutdownState, Ssl, SslAcceptor, SslAcceptorBuilder, SslConnector, SslContext, SslFiletype, SslMethod, SslOptions, SslStream, SslVerifyMode, @@ -929,6 +929,15 @@ fn get_curve() { assert!(curve.name().is_some()); } +#[test] +fn get_curve_name() { + assert_eq!(SslCurveId::SECP224R1.name(), Some("P-224")); + assert_eq!(SslCurveId::SECP256R1.name(), Some("P-256")); + assert_eq!(SslCurveId::SECP384R1.name(), Some("P-384")); + assert_eq!(SslCurveId::SECP521R1.name(), Some("P-521")); + assert_eq!(SslCurveId::X25519.name(), Some("X25519")); +} + #[test] fn test_get_ciphers() { let ctx_builder = SslContext::builder(SslMethod::tls()).unwrap();