From 365315062a74932c9ac05c7d328be22117ecc5b1 Mon Sep 17 00:00:00 2001 From: YuviPanda <yuvipanda@gmail.com> Date: Fri, 6 Jan 2023 16:09:11 -0800 Subject: [PATCH] Add docstrings to most functions --- lambda/app.py | 116 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 114 insertions(+), 2 deletions(-) diff --git a/lambda/app.py b/lambda/app.py index e8879da1..9df263db 100644 --- a/lambda/app.py +++ b/lambda/app.py @@ -24,6 +24,8 @@ from opentelemetry import trace from opentelemetry.propagate import inject except ImportError: + # If opentelemetry is not installed, we make dummy objects / functions here + # so we do not need to add conditionls throughout our codebase trace = None def inject(obj): @@ -97,6 +99,12 @@ def wrapper(*args, **kwargs): @ttl_cache(maxsize=2, ttl=10 * 60, timer=time.time) def get_black_list() -> dict: + """ + Return blacklist if configured + + Looks at the BLACKLIST_ENDPOINT environment variable for a URL to + contact for fetching the blacklist from. + """ endpoint = os.getenv('BLACKLIST_ENDPOINT', '') if endpoint: response = urllib.request.urlopen(endpoint).read().decode('utf-8') @@ -108,7 +116,12 @@ def get_black_list() -> dict: @app.middleware('http') -def initialize(event, get_response) -> Response: +def initialize(event: chalice.app.Request, get_response) -> Response: + """ + Initialize various properties needed for each request. + + This function is called on *every* request before any of the handlers are called + """ JWT_MANAGER.black_list = get_black_list() jwt_keys = retrieve_secret(os.getenv('JWT_KEY_SECRET_NAME')) JWT_MANAGER.public_key = base64.b64decode(jwt_keys.get('rsa_pub_key', '')).decode() @@ -119,6 +132,11 @@ def initialize(event, get_response) -> Response: @app.middleware('http') def set_log_context(event: chalice.app.Request, get_response) -> Response: + """ + Set context about current request for all log statements + + This function is called for each request before the request handlers are called. + """ origin_request_id = event.headers.get('x-origin-request-id') log_context( @@ -135,11 +153,19 @@ def set_log_context(event: chalice.app.Request, get_response) -> Response: try: return get_response(event) finally: + # Reset log context after our request is completed log_context(user_id=None, route=None, request_id=None) @app.middleware('http') def forward_origin_request_id(event: chalice.app.Request, get_response) -> Response: + """ + Normalize x-request-id header to point to aws_request_id for all requests + + The original request_id, if present, is put in an x-origin-request-id + + This function is called for each request before the request handlers are called. + """ response = get_response(event) origin_request_id = event.headers.get('x-origin-request-id') @@ -157,17 +183,29 @@ class TeaException(Exception): class EulaException(TeaException): + """ + Exception indicating that the authorized user has not accepted the EULA for accessing the dataset + """ def __init__(self, payload: dict): self.payload = payload class RequestAuthorizer: + """ + Handle authorization of incoming requests. + + Supports handling traditional OAuth2 with appropriate redirects, as well as using + bearer tokens. + """ def __init__(self): self._response = None self._headers = {} @with_trace() def get_profile(self) -> Optional[UserProfile]: + """ + Return user profile if the user is authenticated + """ user_profile = JWT_MANAGER.get_profile_from_headers(app.current_request.headers) if user_profile is not None: return user_profile @@ -255,6 +293,9 @@ def get_success_response_headers(self) -> dict: @with_trace() def get_request_id() -> str: + """ + Return AWS Lambda Request ID + """ assert app.lambda_context is not None return app.lambda_context.aws_request_id @@ -262,6 +303,9 @@ def get_request_id() -> str: @with_trace() def get_origin_request_id() -> Optional[str]: + """ + Return the *original* AWS Lambda Request ID + """ assert app.current_request is not None return app.current_request.headers.get("x-origin-request-id") @@ -269,6 +313,12 @@ def get_origin_request_id() -> Optional[str]: @with_trace() def get_aux_request_headers() -> dict: + """ + Return common HTTP headers used when making requests to EarthData login servers. + + These are headers we send when making requests as a *client*, not when making + responses as a server. + """ req_headers = {"x-request-id": get_request_id()} origin_request_id = get_origin_request_id() @@ -283,6 +333,9 @@ def get_aux_request_headers() -> dict: @with_trace() def check_for_browser(hdrs) -> bool: + """ + Return True if request is being sent by a browser + """ return 'user-agent' in hdrs and hdrs['user-agent'].lower().startswith('mozilla') @@ -367,12 +420,18 @@ def get_user_from_token(token) -> Optional[str]: @with_trace() def cumulus_log_message(outcome: str, code: int, http_method: str, k_v: dict): + """ + Emit log message to stdout in a format that cumulus understands + """ k_v.update({'code': code, 'http_method': http_method, 'status': outcome, 'requestid': get_request_id()}) print(json.dumps(k_v)) @with_trace() def restore_bucket_vars(): + """ + Update bucket config by re-fetching it from configured S3 object + """ global b_map # pylint: disable=global-statement log.debug('conf bucket: %s, bucket_map_file: %s', conf_bucket, bucket_map_file) @@ -424,6 +483,9 @@ def do_auth_and_return(ctxt) -> Response: @with_trace() def add_cors_headers(headers): + """ + Add CORS headers to allow requests from all configured domains + """ assert app.current_request is not None # send CORS headers if we're configured to use them @@ -439,6 +501,9 @@ def add_cors_headers(headers): @with_trace() def make_redirect(to_url, headers=None, status_code=301) -> Response: + """ + Return a HTTP Response redirecting users with appropriate headers + """ if headers is None: headers = {} headers['Location'] = to_url @@ -451,6 +516,9 @@ def make_redirect(to_url, headers=None, status_code=301) -> Response: @with_trace() def make_html_response(t_vars: dict, headers: dict, status_code: int = 200, template_file: str = 'root.html') -> Response: + """ + Return a HTTP response with rendered HTML from the given template + """ template_vars = { 'STAGE': STAGE if not os.getenv('DOMAIN_NAME') else None, 'status_code': status_code, @@ -489,6 +557,9 @@ def get_bcconfig(user_id: str) -> dict: key=lambda _, bucketname: hashkey(bucketname) ) def get_bucket_region(session, bucketname) -> str: + """ + Get the region of the given bucket + """ try: _time = time.time() bucket_region = session.client('s3').get_bucket_location(Bucket=bucketname)['LocationConstraint'] or 'us-east-1' @@ -508,6 +579,9 @@ def get_bucket_region(session, bucketname) -> str: @with_trace() def get_user_ip() -> str: + """ + Return IP of the user making the request + """ assert app.current_request is not None x_forwarded_for = app.current_request.headers.get('x-forwarded-for') @@ -523,6 +597,12 @@ def get_user_ip() -> str: @with_trace() def try_download_from_bucket(bucket, filename, user_profile, headers: dict) -> Response: + """ + Attempt to redirect to given file from given bucket. + + Returns a redirect response with presigned S3 URL if successful, + or an appropriate error response if unsuccessful. + """ timer = Timer() timer.mark() user_id = None @@ -634,6 +714,9 @@ def get_jwt_field(cookievar: dict, fieldname: str) -> Optional[str]: @app.route('/') @with_trace(context={}) def root() -> Response: + """ + Render human readable root page + """ template_vars = {'title': 'Welcome'} user_profile = JWT_MANAGER.get_profile_from_headers(app.current_request.headers) if user_profile is not None: @@ -739,6 +822,9 @@ def collapse_bucket_configuration(bucket_map) -> dict: @with_trace() def get_range_header_val() -> Optional[str]: + """ + Return value of range header if present + """ if 'Range' in app.current_request.headers: return app.current_request.headers['Range'] if 'range' in app.current_request.headers: @@ -775,7 +861,13 @@ def get_data_dl_s3_client(): @with_trace() -def try_download_head(bucket, filename): +def try_download_head(bucket, filename) -> Response: + """ + Try to handle a HEAD request for given filename in given bucket + + Return a redirect response if given filename exists in the bucket, provide + an error message otherwise. + """ timer = Timer() timer.mark("get_data_dl_s3_client()") @@ -839,6 +931,13 @@ def try_download_head(bucket, filename): @app.route('/{proxy+}', methods=['HEAD']) @with_trace(context={}) def dynamic_url_head(): + """ + Handle HEAD requests for arbitrary files in arbitrary buckets + + The name of the bucket and filename is parsed out of the URL. If the + file is found in the bucket and the request is authenticated properly, + a signed s3 URL is returned. If not, an error response is returned. + """ timer = Timer() timer.mark("restore_bucket_vars()") log.debug('attempting to HEAD a thing') @@ -872,6 +971,13 @@ def dynamic_url_head(): @app.route('/{proxy+}', methods=['GET']) @with_trace(context={}) def dynamic_url(): + """ + Handle GET requests for arbitrary files in arbitrary buckets + + The name of the bucket and filename is parsed out of the URL. If the + file is found in the bucket and the request is authenticated properly, + a signed s3 URL is returned. If not, an error response is returned. + """ timer = Timer() timer.mark("restore_bucket_vars()") @@ -963,6 +1069,9 @@ def dynamic_url(): @app.route('/s3credentials', methods=['GET']) @with_trace(context={}) def s3credentials(): + """ + Return temporary AWS credentials to calling user, with ability to access S3 + """ timer = Timer() timer.mark("restore_bucket_vars()") @@ -1054,6 +1163,9 @@ def get_s3_credentials(user_id: str, role_session_name: str, policy: dict) -> di @app.route('/s3credentialsREADME', methods=['GET']) @with_trace(context={}) def s3credentials_readme() -> Response: + """ + Return a human readable README for how to use /s3credentials + """ return make_html_response({}, {}, 200, "s3credentials_readme.html")