Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ssm): Make decrypt an explicit option and refactoring #123

Merged
merged 13 commits into from
Aug 24, 2020
Merged
71 changes: 37 additions & 34 deletions aws_lambda_powertools/utilities/parameters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from abc import ABC, abstractmethod
from collections import namedtuple
from datetime import datetime, timedelta
from typing import Dict, Optional, Union
from typing import Dict, Optional, Tuple, Union

from .exceptions import GetParameterError, TransformParameterError

Expand All @@ -31,6 +31,9 @@ def __init__(self):

self.store = {}

def _has_not_expired(self, key: Tuple[str, Optional[str]]) -> bool:
return key in self.store and self.store[key].ttl >= datetime.now()

def get(
self, name: str, max_age: int = DEFAULT_MAX_AGE_SECS, transform: Optional[str] = None, **sdk_options
) -> Union[str, list, dict, bytes]:
Expand Down Expand Up @@ -70,24 +73,26 @@ def get(
# an acceptable tradeoff.
key = (name, transform)

if key not in self.store or self.store[key].ttl < datetime.now():
try:
value = self._get(name, **sdk_options)
# Encapsulate all errors into a generic GetParameterError
except Exception as exc:
raise GetParameterError(str(exc))
if self._has_not_expired(key):
return self.store[key].value

try:
value = self._get(name, **sdk_options)
# Encapsulate all errors into a generic GetParameterError
except Exception as exc:
raise GetParameterError(str(exc))

if transform is not None:
value = transform_value(value, transform)
if transform is not None:
value = transform_value(value, transform)

self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age),)
self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age),)

return self.store[key].value
return value

@abstractmethod
def _get(self, name: str, **sdk_options) -> str:
"""
Retrieve paramater value from the underlying parameter store
Retrieve parameter value from the underlying parameter store
"""
raise NotImplementedError()

Expand Down Expand Up @@ -129,29 +134,22 @@ def get_multiple(

key = (path, transform)

if key not in self.store or self.store[key].ttl < datetime.now():
try:
values = self._get_multiple(path, **sdk_options)
# Encapsulate all errors into a generic GetParameterError
except Exception as exc:
raise GetParameterError(str(exc))
if self._has_not_expired(key):
return self.store[key].value

if transform is not None:
new_values = {}
for key, value in values.items():
try:
new_values[key] = transform_value(value, transform)
except Exception as exc:
if raise_on_transform_error:
raise exc
else:
new_values[key] = None
try:
values: Dict[str, Union[str, bytes, dict, None]] = self._get_multiple(path, **sdk_options)
# Encapsulate all errors into a generic GetParameterError
except Exception as exc:
raise GetParameterError(str(exc))

values = new_values
if transform is not None:
for (key, value) in values.items():
values[key] = transform_value(value, transform, raise_on_transform_error)

self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age),)
self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age),)

return self.store[key].value
return values

@abstractmethod
def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]:
Expand All @@ -161,16 +159,19 @@ def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]:
raise NotImplementedError()


def transform_value(value: str, transform: str) -> Union[dict, bytes]:
def transform_value(value: str, transform: str, raise_on_transform_error: bool = True) -> Union[dict, bytes, None]:
"""
Apply a transform to a value

Parameters
---------
value: str
Parameter alue to transform
Parameter value to transform
transform: str
Type of transform, supported values are "json" and "binary"
raise_on_transform_error: bool, optional
Raises an exception if any transform fails, otherwise this will
return a None value for each transform that failed

Raises
------
Expand All @@ -187,4 +188,6 @@ def transform_value(value: str, transform: str) -> Union[dict, bytes]:
raise ValueError(f"Invalid transform type '{transform}'")

except Exception as exc:
raise TransformParameterError(str(exc))
if raise_on_transform_error:
raise TransformParameterError(str(exc))
return None
2 changes: 1 addition & 1 deletion aws_lambda_powertools/utilities/parameters/secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _get(self, name: str, **sdk_options) -> str:
----------
name: str
Name of the parameter
sdk_options: dict
sdk_options: dict, optional
Dictionary of options that will be passed to the Secrets Manager get_secret_value API call
"""

Expand Down
62 changes: 56 additions & 6 deletions aws_lambda_powertools/utilities/parameters/ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import boto3
from botocore.config import Config

from .base import DEFAULT_PROVIDERS, BaseProvider
from .base import DEFAULT_MAX_AGE_SECS, DEFAULT_PROVIDERS, BaseProvider


class SSMProvider(BaseProvider):
Expand Down Expand Up @@ -86,6 +86,46 @@ def __init__(

super().__init__()

def get(
self,
name: str,
max_age: int = DEFAULT_MAX_AGE_SECS,
transform: Optional[str] = None,
decrypt: bool = False,
**sdk_options
) -> Union[str, list, dict, bytes]:
"""
Retrieve a parameter value or return the cached value

Parameters
----------
name: str
Parameter name
max_age: int
Maximum age of the cached value
transform: str
Optional transformation of the parameter value. Supported values
are "json" for JSON strings and "binary" for base 64 encoded
values.
decrypt: bool, optional
If the parameter value should be decrypted
sdk_options: dict, optional
Arguments that will be passed directly to the underlying API call

Raises
------
GetParameterError
When the parameter provider fails to retrieve a parameter value for
a given name.
TransformParameterError
When the parameter provider fails to transform a parameter value.
"""

# Add to `decrypt` sdk_options to we can have an explicit option for this
sdk_options["decrypt"] = decrypt

return super().get(name, max_age, transform, **sdk_options)

def _get(self, name: str, decrypt: bool = False, **sdk_options) -> str:
"""
Retrieve a parameter value from AWS Systems Manager Parameter Store
Expand Down Expand Up @@ -144,7 +184,9 @@ def _get_multiple(self, path: str, decrypt: bool = False, recursive: bool = Fals
return parameters


def get_parameter(name: str, transform: Optional[str] = None, **sdk_options) -> Union[str, list, dict, bytes]:
def get_parameter(
name: str, transform: Optional[str] = None, decrypt: bool = False, **sdk_options
) -> Union[str, list, dict, bytes]:
"""
Retrieve a parameter value from AWS Systems Manager (SSM) Parameter Store

Expand All @@ -154,6 +196,8 @@ def get_parameter(name: str, transform: Optional[str] = None, **sdk_options) ->
Name of the parameter
transform: str, optional
Transforms the content from a JSON object ('json') or base64 binary string ('binary')
decrypt: bool, optional
If the parameter values should be decrypted
sdk_options: dict, optional
Dictionary of options that will be passed to the Parameter Store get_parameter API call

Expand Down Expand Up @@ -190,7 +234,10 @@ def get_parameter(name: str, transform: Optional[str] = None, **sdk_options) ->
if "ssm" not in DEFAULT_PROVIDERS:
DEFAULT_PROVIDERS["ssm"] = SSMProvider()

return DEFAULT_PROVIDERS["ssm"].get(name, transform=transform)
# Add to `decrypt` sdk_options to we can have an explicit option for this
sdk_options["decrypt"] = decrypt

return DEFAULT_PROVIDERS["ssm"].get(name, transform=transform, **sdk_options)


def get_parameters(
Expand All @@ -205,10 +252,10 @@ def get_parameters(
Path to retrieve the parameters
transform: str, optional
Transforms the content from a JSON object ('json') or base64 binary string ('binary')
decrypt: bool, optional
If the parameter values should be decrypted
recursive: bool, optional
If this should retrieve the parameter values recursively or not, defaults to True
decrypt: bool, optional
If the parameter values should be decrypted
sdk_options: dict, optional
Dictionary of options that will be passed to the Parameter Store get_parameters_by_path API call

Expand Down Expand Up @@ -245,4 +292,7 @@ def get_parameters(
if "ssm" not in DEFAULT_PROVIDERS:
DEFAULT_PROVIDERS["ssm"] = SSMProvider()

return DEFAULT_PROVIDERS["ssm"].get_multiple(path, transform=transform, recursive=recursive, decrypt=decrypt)
sdk_options["recursive"] = recursive
sdk_options["decrypt"] = decrypt

return DEFAULT_PROVIDERS["ssm"].get_multiple(path, transform=transform, **sdk_options)
13 changes: 13 additions & 0 deletions tests/functional/test_utilities_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,6 +1310,7 @@ def test_get_parameter_new(monkeypatch, mock_name, mock_value):
class TestProvider(BaseProvider):
def _get(self, name: str, **kwargs) -> str:
assert name == mock_name
assert not kwargs["decrypt"]
return mock_value

def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
Expand Down Expand Up @@ -1355,6 +1356,8 @@ def _get(self, name: str, **kwargs) -> str:

def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
assert path == mock_name
assert kwargs["recursive"]
assert not kwargs["decrypt"]
return mock_value

monkeypatch.setattr(parameters.ssm, "DEFAULT_PROVIDERS", {})
Expand Down Expand Up @@ -1468,3 +1471,13 @@ def test_transform_value_wrong(mock_value):
parameters.base.transform_value(mock_value, "INCORRECT")

assert "Invalid transform type" in str(excinfo)


def test_transform_value_ignore_error(mock_value):
"""
Test transform_value() does not raise errors when raise_on_transform_error is False
"""

value = parameters.base.transform_value(mock_value, "INCORRECT", raise_on_transform_error=False)

assert value is None