Skip to content

Commit

Permalink
add keyvault decrypt (#333)
Browse files Browse the repository at this point in the history
* add keyvault decrypt

* refactor according to comments

* pub decrypt_parameters_encryption

* change to_string to serde_json::to_value

* optional version taken care of in the API

* simplify decrypt params

* rename decrypt param structs to *DecryptParameters
  • Loading branch information
vincentserpoul authored Aug 12, 2021
1 parent 881f5c5 commit dd07e20
Show file tree
Hide file tree
Showing 2 changed files with 266 additions and 0 deletions.
263 changes: 263 additions & 0 deletions sdk/key_vault/src/key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,14 @@ pub struct JsonWebKey {

const BASE64_URL_SAFE: Config = Config::new(CharacterSet::UrlSafe, false);

fn ser_base64<S>(bytes: &Vec<u8>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let base_64 = base64::encode_config(bytes, BASE64_URL_SAFE);
serializer.serialize_str(&base_64)
}

fn ser_base64_opt<S>(bytes: &Option<Vec<u8>>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
Expand Down Expand Up @@ -239,6 +247,148 @@ impl Display for SignatureAlgorithm {
}
}

#[derive(Debug, Serialize, Deserialize)]
pub enum EncryptionAlgorithm {
#[serde(rename = "A128CBC")]
A128Cbc,
#[serde(rename = "A128CBCPAD")]
A128CbcPad,
#[serde(rename = "A128GCM")]
A128Gcm,
#[serde(rename = "A192CBC")]
A192Cbc,
#[serde(rename = "A192CBCPAD")]
A192CbcPad,
#[serde(rename = "A192GCM")]
A192Gcm,
#[serde(rename = "A256CBC")]
A256Cbc,
#[serde(rename = "A256CBCPAD")]
A256CbcPad,
#[serde(rename = "A256GCM")]
A256Gcm,
#[serde(rename = "RSA-OAEP")]
RsaOaep,
#[serde(rename = "RSA-OAEP-256")]
RsaOaep256,
#[serde(rename = "RSA1_5")]
Rsa15,
}

impl Default for EncryptionAlgorithm {
fn default() -> Self {
EncryptionAlgorithm::A128Cbc
}
}

impl Display for EncryptionAlgorithm {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Debug::fmt(self, f)
}
}

#[derive(Debug, Serialize, Deserialize)]
pub struct DecryptParameters {
pub decrypt_parameters_encryption: DecryptParametersEncryption,
#[serde(serialize_with = "ser_base64", deserialize_with = "deser_base64")]
pub ciphertext: Vec<u8>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum DecryptParametersEncryption {
Rsa(RsaDecryptParameters),
AesGcm(AesGcmDecryptParameters),
AesCbc(AesCbcDecryptParameters),
}

#[derive(Debug, Serialize, Deserialize)]
pub struct RsaDecryptParameters {
algorithm: EncryptionAlgorithm,
}

impl RsaDecryptParameters {
pub fn new(algorithm: EncryptionAlgorithm) -> Result<Self, Error> {
match algorithm {
EncryptionAlgorithm::Rsa15
| EncryptionAlgorithm::RsaOaep
| EncryptionAlgorithm::RsaOaep256 => Ok(Self { algorithm }),
_ => Err(Error::EncryptionAlgorithmMismatch),
}
}
}

#[derive(Debug, Serialize, Deserialize)]
pub struct AesGcmDecryptParameters {
algorithm: EncryptionAlgorithm,
#[serde(serialize_with = "ser_base64", deserialize_with = "deser_base64")]
pub iv: Vec<u8>,
#[serde(serialize_with = "ser_base64", deserialize_with = "deser_base64")]
pub authentication_tag: Vec<u8>,
#[serde(
serialize_with = "ser_base64_opt",
deserialize_with = "deser_base64_opt"
)]
pub additional_authenticated_data: Option<Vec<u8>>,
}

impl AesGcmDecryptParameters {
pub fn new(
algorithm: EncryptionAlgorithm,
iv: Vec<u8>,
authentication_tag: Vec<u8>,
additional_authenticated_data: Option<Vec<u8>>,
) -> Result<Self, Error> {
match algorithm {
EncryptionAlgorithm::A128Gcm
| EncryptionAlgorithm::A192Gcm
| EncryptionAlgorithm::A256Gcm => Ok(Self {
algorithm,
iv,
authentication_tag,
additional_authenticated_data,
}),
_ => Err(Error::EncryptionAlgorithmMismatch),
}
}
}

#[derive(Debug, Serialize, Deserialize)]
pub struct AesCbcDecryptParameters {
algorithm: EncryptionAlgorithm,
#[serde(serialize_with = "ser_base64", deserialize_with = "deser_base64")]
pub iv: Vec<u8>,
}

impl AesCbcDecryptParameters {
pub fn new(algorithm: EncryptionAlgorithm, iv: Vec<u8>) -> Result<Self, Error> {
match algorithm {
EncryptionAlgorithm::A128Cbc
| EncryptionAlgorithm::A192Cbc
| EncryptionAlgorithm::A256Cbc
| EncryptionAlgorithm::A128CbcPad
| EncryptionAlgorithm::A192CbcPad
| EncryptionAlgorithm::A256CbcPad => Ok(Self { algorithm, iv }),
_ => Err(Error::EncryptionAlgorithmMismatch),
}
}
}

#[derive(Debug, Deserialize, Getters)]
#[getset(get = "pub")]
pub struct DecryptResult {
#[serde(skip)]
algorithm: EncryptionAlgorithm,
#[serde(rename = "kid")]
key_id: String,
#[serde(
rename = "value",
serialize_with = "ser_base64",
deserialize_with = "deser_base64"
)]
result: Vec<u8>,
}

impl<'a, T: TokenCredential> KeyClient<'a, T> {
/// Gets the public part of a stored key.
/// The get key operation is applicable to all key types.
Expand Down Expand Up @@ -296,6 +446,71 @@ impl<'a, T: TokenCredential> KeyClient<'a, T> {
result.algorithm = algorithm;
Ok(result)
}

/// Decrypt a single block of encrypted data.
/// The DECRYPT operation decrypts a well-formed block of ciphertext using the target encryption key and specified algorithm.
/// This operation is the reverse of the ENCRYPT operation; only a single block of data may be decrypted, the size of this block is dependent on the target key and the algorithm to be used.
/// The DECRYPT operation applies to asymmetric and symmetric keys stored in Vault or HSM since it uses the private portion of the key. This operation requires the keys/decrypt permission.
pub async fn decrypt(
&mut self,
key_name: &str,
key_version: Option<&str>,
decrypt_parameters: DecryptParameters,
) -> Result<DecryptResult, Error> {
// POST {vaultBaseUrl}/keys/{key-name}/{key-version}/decrypt?api-version=7.2

let mut uri = self.vault_url.clone();
let path = format!("keys/{}/{}/decrypt", key_name, key_version.unwrap_or(""));

uri.set_path(&path);
uri.set_query(Some(API_VERSION_PARAM));

let mut request_body = Map::new();
request_body.insert(
"value".to_owned(),
Value::String(base64::encode(decrypt_parameters.ciphertext.to_owned())),
);

let algorithm = match decrypt_parameters.decrypt_parameters_encryption {
DecryptParametersEncryption::Rsa(RsaDecryptParameters { algorithm }) => {
request_body.insert("alg".to_owned(), serde_json::to_value(&algorithm).unwrap());
algorithm
}
DecryptParametersEncryption::AesGcm(AesGcmDecryptParameters {
algorithm,
iv,
authentication_tag,
additional_authenticated_data,
}) => {
request_body.insert("alg".to_owned(), serde_json::to_value(&algorithm).unwrap());
request_body.insert("iv".to_owned(), serde_json::to_value(iv).unwrap());
request_body.insert(
"tag".to_owned(),
serde_json::to_value(authentication_tag).unwrap(),
);
if let Some(aad) = additional_authenticated_data {
request_body.insert("aad".to_owned(), serde_json::to_value(aad).unwrap());
};
algorithm
}
DecryptParametersEncryption::AesCbc(AesCbcDecryptParameters { algorithm, iv }) => {
request_body.insert("alg".to_owned(), serde_json::to_value(&algorithm).unwrap());
request_body.insert("iv".to_owned(), serde_json::to_value(iv).unwrap());
algorithm
}
};

let response = self
.post_authed(
uri.to_string(),
Some(Value::Object(request_body).to_string()),
)
.await?;

let mut result = serde_json::from_str::<DecryptResult>(&response)?;
result.algorithm = algorithm;
Ok(result)
}
}

#[cfg(test)]
Expand Down Expand Up @@ -435,4 +650,52 @@ mod tests {
assert_eq!(expected_sig, sig.to_owned());
assert!(matches!(alg, SignatureAlgorithm::RS512));
}

#[tokio::test]
async fn can_decrypt() {
let _m = mock("POST", "/keys/test-key/78deebed173b48e48f55abf87ed4cf71/decrypt")
.match_query(Matcher::UrlEncoded("api-version".into(), API_VERSION.into()))
.with_header("content-type", "application/json")
.with_body(
json!({
"kid": "https://myvault.vault.azure.net/keys/test-key/78deebed173b48e48f55abf87ed4cf71",
"value": "dvDmrSBpjRjtYg"
})
.to_string(),
)
.with_status(200)
.create();

let creds = MockCredential;
let mut client = mock_client!(&"test-keyvault", &creds,);

let decrypt_parameters = DecryptParameters {
ciphertext: base64::decode("dvDmrSBpjRjtYg").unwrap(),
decrypt_parameters_encryption: DecryptParametersEncryption::Rsa(
RsaDecryptParameters::new(EncryptionAlgorithm::RsaOaep256).unwrap(),
),
};

let res = client
.decrypt(
"test-key",
Some("78deebed173b48e48f55abf87ed4cf71"),
decrypt_parameters,
)
.await
.unwrap();

let kid = res.key_id();
let val = res.result();
let alg = res.algorithm();

assert_eq!(
kid,
"https://myvault.vault.azure.net/keys/test-key/78deebed173b48e48f55abf87ed4cf71"
);
let expected_val = base64::decode_config("dvDmrSBpjRjtYg", BASE64_URL_SAFE).unwrap();
assert_eq!(expected_val, val.to_owned());

assert!(matches!(alg, &EncryptionAlgorithm::RsaOaep256));
}
}
3 changes: 3 additions & 0 deletions sdk/key_vault/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ pub enum Error {
secret_name: String,
response_body: String,
},

#[error("Encryption algorithm mismatch")]
EncryptionAlgorithmMismatch,
}

#[cfg(test)]
Expand Down

0 comments on commit dd07e20

Please sign in to comment.