Skip to content

Commit

Permalink
Add docstrings to most functions
Browse files Browse the repository at this point in the history
  • Loading branch information
yuvipanda committed Jan 7, 2023
1 parent 502710c commit 3653150
Showing 1 changed file with 114 additions and 2 deletions.
116 changes: 114 additions & 2 deletions lambda/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from opentelemetry import trace
from opentelemetry.propagate import inject
except ImportError:
# If opentelemetry is not installed, we make dummy objects / functions here
# so we do not need to add conditionals throughout our codebase
trace = None

def inject(obj):
Expand Down Expand Up @@ -97,6 +99,12 @@ def wrapper(*args, **kwargs):

@ttl_cache(maxsize=2, ttl=10 * 60, timer=time.time)
def get_black_list() -> dict:
"""
Return blacklist if configured
Looks at the BLACKLIST_ENDPOINT environment variable for a URL to
contact for fetching the blacklist from.
"""
endpoint = os.getenv('BLACKLIST_ENDPOINT', '')
if endpoint:
response = urllib.request.urlopen(endpoint).read().decode('utf-8')
Expand All @@ -108,7 +116,12 @@ def get_black_list() -> dict:


@app.middleware('http')
def initialize(event, get_response) -> Response:
def initialize(event: chalice.app.Request, get_response) -> Response:
"""
Initialize various properties needed for each request.
This function is called on *every* request before any of the handlers are called
"""
JWT_MANAGER.black_list = get_black_list()
jwt_keys = retrieve_secret(os.getenv('JWT_KEY_SECRET_NAME'))
JWT_MANAGER.public_key = base64.b64decode(jwt_keys.get('rsa_pub_key', '')).decode()
Expand All @@ -119,6 +132,11 @@ def initialize(event, get_response) -> Response:

@app.middleware('http')
def set_log_context(event: chalice.app.Request, get_response) -> Response:
"""
Set context about current request for all log statements
This function is called for each request before the request handlers are called.
"""
origin_request_id = event.headers.get('x-origin-request-id')

log_context(
Expand All @@ -135,11 +153,19 @@ def set_log_context(event: chalice.app.Request, get_response) -> Response:
try:
return get_response(event)
finally:
# Reset log context after our request is completed
log_context(user_id=None, route=None, request_id=None)


@app.middleware('http')
def forward_origin_request_id(event: chalice.app.Request, get_response) -> Response:
"""
Normalize x-request-id header to point to aws_request_id for all requests
The original request_id, if present, is put in an x-origin-request-id
This function is called for each request before the request handlers are called.
"""
response = get_response(event)

origin_request_id = event.headers.get('x-origin-request-id')
Expand All @@ -157,17 +183,29 @@ class TeaException(Exception):


class EulaException(TeaException):
    """
    Raised when the authorized user has not accepted the EULA required for
    accessing the dataset.

    The dict given at construction time is kept on the instance as
    ``payload`` so that whoever handles the exception can read it back.
    """

    def __init__(self, payload: dict):
        # Keep the caller-supplied dict as-is on the exception instance
        self.payload = payload


class RequestAuthorizer:
"""
Handle authorization of incoming requests.
Supports handling traditional OAuth2 with appropriate redirects, as well as using
bearer tokens.
"""
def __init__(self):
self._response = None
self._headers = {}

@with_trace()
def get_profile(self) -> Optional[UserProfile]:
"""
Return user profile if the user is authenticated
"""
user_profile = JWT_MANAGER.get_profile_from_headers(app.current_request.headers)
if user_profile is not None:
return user_profile
Expand Down Expand Up @@ -255,20 +293,32 @@ def get_success_response_headers(self) -> dict:

@with_trace()
def get_request_id() -> str:
    """
    Return the AWS Lambda Request ID of the current invocation

    Reads ``aws_request_id`` from the lambda context attached to ``app``.
    """
    lambda_context = app.lambda_context
    # The lambda context must be available by the time this is called
    assert lambda_context is not None

    return lambda_context.aws_request_id


@with_trace()
def get_origin_request_id() -> Optional[str]:
    """
    Return the *original* request ID, if one was forwarded to us

    This is the value of the ``x-origin-request-id`` header on the current
    request, looked up with ``headers.get`` (so a missing header yields the
    mapping's default rather than raising).
    """
    current_request = app.current_request
    # A request must be in flight for the header lookup to make sense
    assert current_request is not None

    return current_request.headers.get("x-origin-request-id")


@with_trace()
def get_aux_request_headers() -> dict:
"""
Return common HTTP headers used when making requests to EarthData login servers.
These are headers we send when making requests as a *client*, not when making
responses as a server.
"""
req_headers = {"x-request-id": get_request_id()}
origin_request_id = get_origin_request_id()

Expand All @@ -283,6 +333,9 @@ def get_aux_request_headers() -> dict:

@with_trace()
def check_for_browser(hdrs) -> bool:
    """
    Return True if request is being sent by a browser

    A request is treated as coming from a browser when its ``user-agent``
    header, compared case-insensitively, starts with ``mozilla``.

    :param hdrs: mapping of request header names to values
    :return: True if the user-agent looks like a browser, False otherwise
        (including when the header is absent or its value is not a string)
    """
    if 'user-agent' not in hdrs:
        return False

    user_agent = hdrs['user-agent']
    # A non-string value (e.g. None) cannot identify a browser; report False
    # rather than failing with AttributeError on .lower()
    return isinstance(user_agent, str) and user_agent.lower().startswith('mozilla')


Expand Down Expand Up @@ -367,12 +420,18 @@ def get_user_from_token(token) -> Optional[str]:

@with_trace()
def cumulus_log_message(outcome: str, code: int, http_method: str, k_v: dict):
    """
    Emit log message to stdout in a format that cumulus understands

    The message is printed as a single JSON object made of the entries of
    ``k_v`` plus the keys ``code``, ``http_method``, ``status`` and
    ``requestid``. These four keys override same-named entries in ``k_v``.

    :param outcome: value emitted under the ``status`` key
    :param code: value emitted under the ``code`` key
    :param http_method: value emitted under the ``http_method`` key
    :param k_v: additional key/value pairs to include; it is NOT modified
    """
    # Build a fresh dict instead of calling k_v.update(), so that logging does
    # not mutate the caller's dict. Key order (and therefore the JSON output)
    # is the same as update() would have produced.
    message = {
        **k_v,
        'code': code,
        'http_method': http_method,
        'status': outcome,
        'requestid': get_request_id(),
    }
    print(json.dumps(message))


@with_trace()
def restore_bucket_vars():
"""
Update bucket config by re-fetching it from configured S3 object
"""
global b_map # pylint: disable=global-statement

log.debug('conf bucket: %s, bucket_map_file: %s', conf_bucket, bucket_map_file)
Expand Down Expand Up @@ -424,6 +483,9 @@ def do_auth_and_return(ctxt) -> Response:

@with_trace()
def add_cors_headers(headers):
"""
Add CORS headers to allow requests from all configured domains
"""
assert app.current_request is not None

# send CORS headers if we're configured to use them
Expand All @@ -439,6 +501,9 @@ def add_cors_headers(headers):

@with_trace()
def make_redirect(to_url, headers=None, status_code=301) -> Response:
"""
Return an HTTP Response redirecting users with appropriate headers
"""
if headers is None:
headers = {}
headers['Location'] = to_url
Expand All @@ -451,6 +516,9 @@ def make_redirect(to_url, headers=None, status_code=301) -> Response:

@with_trace()
def make_html_response(t_vars: dict, headers: dict, status_code: int = 200, template_file: str = 'root.html') -> Response:
"""
Return an HTTP response with rendered HTML from the given template
"""
template_vars = {
'STAGE': STAGE if not os.getenv('DOMAIN_NAME') else None,
'status_code': status_code,
Expand Down Expand Up @@ -489,6 +557,9 @@ def get_bcconfig(user_id: str) -> dict:
key=lambda _, bucketname: hashkey(bucketname)
)
def get_bucket_region(session, bucketname) -> str:
"""
Get the region of the given bucket
"""
try:
_time = time.time()
bucket_region = session.client('s3').get_bucket_location(Bucket=bucketname)['LocationConstraint'] or 'us-east-1'
Expand All @@ -508,6 +579,9 @@ def get_bucket_region(session, bucketname) -> str:

@with_trace()
def get_user_ip() -> str:
"""
Return IP of the user making the request
"""
assert app.current_request is not None

x_forwarded_for = app.current_request.headers.get('x-forwarded-for')
Expand All @@ -523,6 +597,12 @@ def get_user_ip() -> str:

@with_trace()
def try_download_from_bucket(bucket, filename, user_profile, headers: dict) -> Response:
"""
Attempt to redirect to given file from given bucket.
Returns a redirect response with presigned S3 URL if successful,
or an appropriate error response if unsuccessful.
"""
timer = Timer()
timer.mark()
user_id = None
Expand Down Expand Up @@ -634,6 +714,9 @@ def get_jwt_field(cookievar: dict, fieldname: str) -> Optional[str]:
@app.route('/')
@with_trace(context={})
def root() -> Response:
"""
Render human readable root page
"""
template_vars = {'title': 'Welcome'}
user_profile = JWT_MANAGER.get_profile_from_headers(app.current_request.headers)
if user_profile is not None:
Expand Down Expand Up @@ -739,6 +822,9 @@ def collapse_bucket_configuration(bucket_map) -> dict:

@with_trace()
def get_range_header_val() -> Optional[str]:
"""
Return value of range header if present
"""
if 'Range' in app.current_request.headers:
return app.current_request.headers['Range']
if 'range' in app.current_request.headers:
Expand Down Expand Up @@ -775,7 +861,13 @@ def get_data_dl_s3_client():


@with_trace()
def try_download_head(bucket, filename):
def try_download_head(bucket, filename) -> Response:
"""
Try to handle a HEAD request for given filename in given bucket
Return a redirect response if given filename exists in the bucket, provide
an error message otherwise.
"""
timer = Timer()

timer.mark("get_data_dl_s3_client()")
Expand Down Expand Up @@ -839,6 +931,13 @@ def try_download_head(bucket, filename):
@app.route('/{proxy+}', methods=['HEAD'])
@with_trace(context={})
def dynamic_url_head():
"""
Handle HEAD requests for arbitrary files in arbitrary buckets
The name of the bucket and filename is parsed out of the URL. If the
file is found in the bucket and the request is authenticated properly,
a signed s3 URL is returned. If not, an error response is returned.
"""
timer = Timer()
timer.mark("restore_bucket_vars()")
log.debug('attempting to HEAD a thing')
Expand Down Expand Up @@ -872,6 +971,13 @@ def dynamic_url_head():
@app.route('/{proxy+}', methods=['GET'])
@with_trace(context={})
def dynamic_url():
"""
Handle GET requests for arbitrary files in arbitrary buckets
The name of the bucket and filename is parsed out of the URL. If the
file is found in the bucket and the request is authenticated properly,
a signed s3 URL is returned. If not, an error response is returned.
"""
timer = Timer()
timer.mark("restore_bucket_vars()")

Expand Down Expand Up @@ -963,6 +1069,9 @@ def dynamic_url():
@app.route('/s3credentials', methods=['GET'])
@with_trace(context={})
def s3credentials():
"""
Return temporary AWS credentials to calling user, with ability to access S3
"""
timer = Timer()

timer.mark("restore_bucket_vars()")
Expand Down Expand Up @@ -1054,6 +1163,9 @@ def get_s3_credentials(user_id: str, role_session_name: str, policy: dict) -> di
@app.route('/s3credentialsREADME', methods=['GET'])
@with_trace(context={})
def s3credentials_readme() -> Response:
    """
    Serve the human readable README explaining how to use /s3credentials

    Renders the ``s3credentials_readme.html`` template with no template
    variables and no extra headers, using status code 200.
    """
    template_vars = {}
    extra_headers = {}
    return make_html_response(template_vars, extra_headers, 200, "s3credentials_readme.html")


Expand Down

0 comments on commit 3653150

Please sign in to comment.