Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(idempotent): Include function name in the idempotent key #326

Merged
merged 9 commits into from
Mar 12, 2021
4 changes: 4 additions & 0 deletions aws_lambda_powertools/utilities/idempotency/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def __init__(
use_local_cache: bool = False,
local_cache_max_items: int = 256,
hash_function: str = "md5",
include_function_name: bool = True,
michaelbrewer marked this conversation as resolved.
Show resolved Hide resolved
):
"""
Initialize the base persistence layer
Expand All @@ -32,6 +33,8 @@ def __init__(
Max number of items to store in local cache, by default 1024
hash_function: str, optional
Function to use for calculating hashes, by default md5.
include_function_name: bool, optional
michaelbrewer marked this conversation as resolved.
Show resolved Hide resolved
Whether to include the function name in the idempotent key
"""
self.event_key_jmespath = event_key_jmespath
self.payload_validation_jmespath = payload_validation_jmespath
Expand All @@ -41,3 +44,4 @@ def __init__(
self.use_local_cache = use_local_cache
self.local_cache_max_items = local_cache_max_items
self.hash_function = hash_function
self.include_function_name = include_function_name
10 changes: 6 additions & 4 deletions aws_lambda_powertools/utilities/idempotency/idempotency.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def handle(self) -> Any:
try:
# We call save_inprogress first as an optimization for the most common case where no idempotent record
# already exists. If it succeeds, there's no need to call get_record.
self.persistence_store.save_inprogress(event=self.event)
self.persistence_store.save_inprogress(event=self.event, context=self.context)
except IdempotencyItemAlreadyExistsError:
# Now we know the item already exists, we can retrieve it
record = self._get_idempotency_record()
Expand All @@ -151,7 +151,7 @@ def _get_idempotency_record(self) -> DataRecord:

"""
try:
event_record = self.persistence_store.get_record(self.event)
event_record = self.persistence_store.get_record(event=self.event, context=self.context)
except IdempotencyItemNotFoundError:
# This code path will only be triggered if the record is removed between save_inprogress and get_record.
logger.debug(
Expand Down Expand Up @@ -219,7 +219,9 @@ def _call_lambda_handler(self) -> Any:
# We need these nested blocks to preserve lambda handler exception in case the persistence store operation
# also raises an exception
try:
self.persistence_store.delete_record(event=self.event, exception=handler_exception)
self.persistence_store.delete_record(
event=self.event, context=self.context, exception=handler_exception
)
except Exception as delete_exception:
raise IdempotencyPersistenceLayerError(
"Failed to delete record from idempotency store"
Expand All @@ -228,7 +230,7 @@ def _call_lambda_handler(self) -> Any:

else:
try:
self.persistence_store.save_success(event=self.event, result=handler_response)
self.persistence_store.save_success(event=self.event, context=self.context, result=handler_response)
except Exception as save_exception:
raise IdempotencyPersistenceLayerError(
"Failed to update record state to success in idempotency store"
Expand Down
46 changes: 31 additions & 15 deletions aws_lambda_powertools/utilities/idempotency/persistence/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
IdempotencyKeyError,
IdempotencyValidationError,
)
from aws_lambda_powertools.utilities.typing import LambdaContext

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -121,6 +122,7 @@ def __init__(self):
self.use_local_cache = False
self._cache: Optional[LRUDict] = None
self.hash_function = None
self.include_function_name = True

def configure(self, config: IdempotencyConfig) -> None:
"""
Expand Down Expand Up @@ -151,35 +153,40 @@ def configure(self, config: IdempotencyConfig) -> None:
if self.use_local_cache:
self._cache = LRUDict(max_items=config.local_cache_max_items)
self.hash_function = getattr(hashlib, config.hash_function)
self.include_function_name = config.include_function_name

def _get_hashed_idempotency_key(self, lambda_event: Dict[str, Any]) -> str:
def _get_hashed_idempotency_key(self, event: Dict[str, Any], context: LambdaContext) -> str:
"""
Extract data from lambda event using event key jmespath, and return a hashed representation

Parameters
----------
lambda_event: Dict[str, Any]
event: Dict[str, Any]
Lambda event
context: LambdaContext
Lambda context

Returns
-------
str
Hashed representation of the data extracted by the jmespath expression

"""
data = lambda_event
data = event

if self.event_key_jmespath:
data = self.event_key_compiled_jmespath.search(
lambda_event, options=jmespath.Options(**self.jmespath_options)
)
data = self.event_key_compiled_jmespath.search(event, options=jmespath.Options(**self.jmespath_options))

if self.is_missing_idempotency_key(data):
if self.raise_on_no_idempotency_key:
raise IdempotencyKeyError("No data found to create a hashed idempotency_key")
warnings.warn(f"No value found for idempotency_key. jmespath: {self.event_key_jmespath}")

return self._generate_hash(data)
generated_hard = self._generate_hash(data)
heitorlessa marked this conversation as resolved.
Show resolved Hide resolved
if self.include_function_name:
return f"{context.function_name}#{generated_hard}"
else:
return generated_hard

@staticmethod
def is_missing_idempotency_key(data) -> bool:
Expand Down Expand Up @@ -298,21 +305,23 @@ def _delete_from_cache(self, idempotency_key: str):
if idempotency_key in self._cache:
del self._cache[idempotency_key]

def save_success(self, event: Dict[str, Any], result: dict) -> None:
def save_success(self, event: Dict[str, Any], context: LambdaContext, result: dict) -> None:
"""
Save record of function's execution completing successfully

Parameters
----------
event: Dict[str, Any]
Lambda event
context: LambdaContext
Lambda context
result: dict
The response from lambda handler
"""
response_data = json.dumps(result, cls=Encoder)

data_record = DataRecord(
idempotency_key=self._get_hashed_idempotency_key(event),
idempotency_key=self._get_hashed_idempotency_key(event, context),
status=STATUS_CONSTANTS["COMPLETED"],
expiry_timestamp=self._get_expiry_timestamp(),
response_data=response_data,
Expand All @@ -326,17 +335,19 @@ def save_success(self, event: Dict[str, Any], result: dict) -> None:

self._save_to_cache(data_record)

def save_inprogress(self, event: Dict[str, Any]) -> None:
def save_inprogress(self, event: Dict[str, Any], context: LambdaContext) -> None:
"""
Save record of function's execution being in progress

Parameters
----------
event: Dict[str, Any]
Lambda event
context: LambdaContext
Lambda context
"""
data_record = DataRecord(
idempotency_key=self._get_hashed_idempotency_key(event),
idempotency_key=self._get_hashed_idempotency_key(event, context),
status=STATUS_CONSTANTS["INPROGRESS"],
expiry_timestamp=self._get_expiry_timestamp(),
payload_hash=self._get_hashed_payload(event),
Expand All @@ -349,18 +360,20 @@ def save_inprogress(self, event: Dict[str, Any]) -> None:

self._put_record(data_record)

def delete_record(self, event: Dict[str, Any], exception: Exception):
def delete_record(self, event: Dict[str, Any], context: LambdaContext, exception: Exception):
"""
Delete record from the persistence store

Parameters
----------
event: Dict[str, Any]
Lambda event
context: LambdaContext
Lambda context
exception
The exception raised by the lambda handler
"""
data_record = DataRecord(idempotency_key=self._get_hashed_idempotency_key(event))
data_record = DataRecord(idempotency_key=self._get_hashed_idempotency_key(event, context))

logger.debug(
f"Lambda raised an exception ({type(exception).__name__}). Clearing in progress record in persistence "
Expand All @@ -370,14 +383,17 @@ def delete_record(self, event: Dict[str, Any], exception: Exception):

self._delete_from_cache(data_record.idempotency_key)

def get_record(self, event: Dict[str, Any]) -> DataRecord:
def get_record(self, event: Dict[str, Any], context: LambdaContext) -> DataRecord:
"""
Calculate idempotency key for lambda_event, then retrieve item from persistence store using idempotency key
and return it as a DataRecord instance.and return it as a DataRecord instance.

Parameters
----------
event: Dict[str, Any]
Lambda event
context: LambdaContext
Lambda context

Returns
-------
Expand All @@ -392,7 +408,7 @@ def get_record(self, event: Dict[str, Any]) -> DataRecord:
Event payload doesn't match the stored record for the given idempotency key
"""

idempotency_key = self._get_hashed_idempotency_key(event)
idempotency_key = self._get_hashed_idempotency_key(event, context)

cached_record = self._retrieve_from_cache(idempotency_key=idempotency_key)
if cached_record:
Expand Down
19 changes: 16 additions & 3 deletions tests/functional/idempotency/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import hashlib
import json
import os
from collections import namedtuple
from decimal import Decimal
from unittest import mock

Expand Down Expand Up @@ -34,6 +35,18 @@ def lambda_apigw_event():
return event


@pytest.fixture
def lambda_context():
lambda_context = {
"function_name": "test-func",
"memory_limit_in_mb": 128,
"invoked_function_arn": "arn:aws:lambda:eu-west-1:809313241234:function:test-func",
"aws_request_id": "52fdfc07-2182-154f-163f-5f0f9a621d72",
}

return namedtuple("LambdaContext", lambda_context.keys())(*lambda_context.values())


@pytest.fixture
def timestamp_future():
return str(int((datetime.datetime.now() + datetime.timedelta(seconds=3600)).timestamp()))
Expand Down Expand Up @@ -132,18 +145,18 @@ def expected_params_put_item_with_validation(hashed_idempotency_key, hashed_vali


@pytest.fixture
def hashed_idempotency_key(lambda_apigw_event, default_jmespath):
def hashed_idempotency_key(lambda_apigw_event, default_jmespath, lambda_context):
compiled_jmespath = jmespath.compile(default_jmespath)
data = compiled_jmespath.search(lambda_apigw_event)
return hashlib.md5(json.dumps(data).encode()).hexdigest()
return "test-func#" + hashlib.md5(json.dumps(data).encode()).hexdigest()


@pytest.fixture
def hashed_idempotency_key_with_envelope(lambda_apigw_event):
event = unwrap_event_from_envelope(
data=lambda_apigw_event, envelope=envelopes.API_GATEWAY_HTTP, jmespath_options={}
)
return hashlib.md5(json.dumps(event).encode()).hexdigest()
return "test-func#" + hashlib.md5(json.dumps(event).encode()).hexdigest()


@pytest.fixture
Expand Down
Loading