From 502710c3b7bb8414f9322a5bb4ec2322ac98fa08 Mon Sep 17 00:00:00 2001 From: YuviPanda Date: Mon, 2 Jan 2023 15:01:28 -0800 Subject: [PATCH 1/4] Add more return type hints for app.py I loved reading https://github.com/asfadmin/thin-egress-app/blob/master/docs/vision.md! Hopefully this PR is useful and not just noise. I couldn't figure out how exactly to type a return for the S3 botocore client, but otherwise I think this is accurate. --- lambda/app.py | 50 +++++++++++++++++++++++++------------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/lambda/app.py b/lambda/app.py index ec317326..e8879da1 100644 --- a/lambda/app.py +++ b/lambda/app.py @@ -96,7 +96,7 @@ def wrapper(*args, **kwargs): @ttl_cache(maxsize=2, ttl=10 * 60, timer=time.time) -def get_black_list(): +def get_black_list() -> dict: endpoint = os.getenv('BLACKLIST_ENDPOINT', '') if endpoint: response = urllib.request.urlopen(endpoint).read().decode('utf-8') @@ -108,7 +108,7 @@ def get_black_list(): @app.middleware('http') -def initialize(event, get_response): +def initialize(event, get_response) -> Response: JWT_MANAGER.black_list = get_black_list() jwt_keys = retrieve_secret(os.getenv('JWT_KEY_SECRET_NAME')) JWT_MANAGER.public_key = base64.b64decode(jwt_keys.get('rsa_pub_key', '')).decode() @@ -118,7 +118,7 @@ def initialize(event, get_response): @app.middleware('http') -def set_log_context(event: chalice.app.Request, get_response): +def set_log_context(event: chalice.app.Request, get_response) -> Response: origin_request_id = event.headers.get('x-origin-request-id') log_context( @@ -139,7 +139,7 @@ def set_log_context(event: chalice.app.Request, get_response): @app.middleware('http') -def forward_origin_request_id(event: chalice.app.Request, get_response): +def forward_origin_request_id(event: chalice.app.Request, get_response) -> Response: response = get_response(event) origin_request_id = event.headers.get('x-origin-request-id') @@ -268,7 +268,7 @@ def get_origin_request_id() -> Optional[str]: @with_trace() -def get_aux_request_headers(): +def get_aux_request_headers() -> dict: req_headers = {"x-request-id": get_request_id()} origin_request_id = get_origin_request_id() @@ -282,12 +282,12 @@ def get_aux_request_headers(): @with_trace() -def check_for_browser(hdrs): +def check_for_browser(hdrs) -> bool: return 'user-agent' in hdrs and hdrs['user-agent'].lower().startswith('mozilla') @with_trace() -def get_user_from_token(token): +def get_user_from_token(token) -> Optional[str]: """ This may be moved to rain-api-core.urs_util.py once things stabilize. Will query URS for user ID of requesting user based on token sent with request @@ -405,7 +405,7 @@ def restore_bucket_vars(): @with_trace() -def do_auth_and_return(ctxt): +def do_auth_and_return(ctxt) -> Response: log.debug('context: {}'.format(ctxt)) here = ctxt['path'] if os.getenv('DOMAIN_NAME'): @@ -438,7 +438,7 @@ def add_cors_headers(headers): @with_trace() -def make_redirect(to_url, headers=None, status_code=301): +def make_redirect(to_url, headers=None, status_code=301) -> Response: if headers is None: headers = {} headers['Location'] = to_url @@ -450,7 +450,7 @@ def make_redirect(to_url, headers=None, status_code=301): @with_trace() -def make_html_response(t_vars: dict, headers: dict, status_code: int = 200, template_file: str = 'root.html'): +def make_html_response(t_vars: dict, headers: dict, status_code: int = 200, template_file: str = 'root.html') -> Response: template_vars = { 'STAGE': STAGE if not os.getenv('DOMAIN_NAME') else None, 'status_code': status_code, @@ -507,7 +507,7 @@ def get_bucket_region(session, bucketname) -> str: @with_trace() -def get_user_ip(): +def get_user_ip() -> str: assert app.current_request is not None x_forwarded_for = app.current_request.headers.get('x-forwarded-for') @@ -522,7 +522,7 @@ def get_user_ip(): @with_trace() -def try_download_from_bucket(bucket, filename, user_profile, headers: dict): +def try_download_from_bucket(bucket, filename, user_profile, headers: dict) -> Response: timer = Timer() timer.mark() user_id = None @@ -627,13 +627,13 @@ def try_download_from_bucket(bucket, filename, user_profile, headers: dict): @with_trace() -def get_jwt_field(cookievar: dict, fieldname: str): +def get_jwt_field(cookievar: dict, fieldname: str) -> Optional[str]: return cookievar.get(JWT_COOKIE_NAME, {}).get(fieldname, None) @app.route('/') @with_trace(context={}) -def root(): +def root() -> Response: template_vars = {'title': 'Welcome'} user_profile = JWT_MANAGER.get_profile_from_headers(app.current_request.headers) if user_profile is not None: @@ -648,7 +648,7 @@ def root(): @app.route('/logout') @with_trace(context={}) -def logout(): +def logout() -> Response: user_profile = JWT_MANAGER.get_profile_from_headers(app.current_request.headers) template_vars = {'title': 'Logged Out', 'URS_URL': get_urs_url(app.current_request.context)} @@ -667,7 +667,7 @@ def logout(): @app.route('/login') @with_trace(context={}) -def login(): +def login() -> Response: try: headers = {} aux_headers = get_aux_request_headers() @@ -694,7 +694,7 @@ def login(): @app.route('/version') @with_trace(context={}) -def version(): +def version() -> str: log.info("Got a version request!") version_return = {'version_id': ''} @@ -707,7 +707,7 @@ def version(): @app.route('/locate') @with_trace(context={}) -def locate(): +def locate() -> Response: query_params = app.current_request.query_params if query_params is None or query_params.get('bucket_name') is None: return Response(body='Required "bucket_name" query paramater not specified', @@ -727,7 +727,7 @@ def locate(): @with_trace() -def collapse_bucket_configuration(bucket_map): +def collapse_bucket_configuration(bucket_map) -> dict: for k, v in bucket_map.items(): if isinstance(v, dict): if 'bucket' in v: @@ -738,7 +738,7 @@ def collapse_bucket_configuration(bucket_map): @with_trace() -def get_range_header_val(): +def get_range_header_val() -> Optional[str]: if 'Range' in app.current_request.headers: return app.current_request.headers['Range'] if 'range' in app.current_request.headers: @@ -1026,7 +1026,7 @@ def s3credentials(): @with_trace() -def get_role_session_name(user_id: str, app_name: str): +def get_role_session_name(user_id: str, app_name: str) -> str: # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html#API_AssumeRole_RequestParameters if not re.match(r"[\w+,.=@-]*", app_name): app_name = "" @@ -1035,7 +1035,7 @@ def get_role_session_name(user_id: str, app_name: str): @with_trace() -def get_s3_credentials(user_id: str, role_session_name: str, policy: dict): +def get_s3_credentials(user_id: str, role_session_name: str, policy: dict) -> dict: client = boto3.client("sts") arn = os.getenv("EGRESS_APP_DOWNLOAD_ROLE_INREGION_ARN") response = client.assume_role( @@ -1053,20 +1053,20 @@ def get_s3_credentials(user_id: str, role_session_name: str, policy: dict): @app.route('/s3credentialsREADME', methods=['GET']) @with_trace(context={}) -def s3credentials_readme(): +def s3credentials_readme() -> Response: return make_html_response({}, {}, 200, "s3credentials_readme.html") @app.route('/profile') @with_trace(context={}) -def profile(): +def profile() -> Response: return Response(body='Profile not available.', status_code=200, headers={}) @app.route('/pubkey', methods=['GET']) @with_trace(context={}) -def pubkey(): +def pubkey() -> Response: thebody = json.dumps({ 'rsa_pub_key': JWT_MANAGER.public_key, 'algorithm': JWT_MANAGER.algorithm From 365315062a74932c9ac05c7d328be22117ecc5b1 Mon Sep 17 00:00:00 2001 From: YuviPanda Date: Fri, 6 Jan 2023 16:09:11 -0800 Subject: [PATCH 2/4] Add docstrings to most functions --- lambda/app.py | 116 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 114 insertions(+), 2 deletions(-) diff --git a/lambda/app.py b/lambda/app.py index e8879da1..9df263db 100644 --- a/lambda/app.py +++ b/lambda/app.py @@ -24,6 +24,8 @@ from opentelemetry import trace from opentelemetry.propagate import inject except ImportError: + # If opentelemetry is not installed, we make dummy objects / functions here + # so we do not need to add conditionls throughout our codebase trace = None def inject(obj): @@ -97,6 +99,12 @@ def wrapper(*args, **kwargs): @ttl_cache(maxsize=2, ttl=10 * 60, timer=time.time) def get_black_list() -> dict: + """ + Return blacklist if configured + + Looks at the BLACKLIST_ENDPOINT environment variable for a URL to + contact for fetching the blacklist from. + """ endpoint = os.getenv('BLACKLIST_ENDPOINT', '') if endpoint: response = urllib.request.urlopen(endpoint).read().decode('utf-8') @@ -108,7 +116,12 @@ def get_black_list() -> dict: @app.middleware('http') -def initialize(event, get_response) -> Response: +def initialize(event: chalice.app.Request, get_response) -> Response: + """ + Initialize various properties needed for each request. + + This function is called on *every* request before any of the handlers are called + """ JWT_MANAGER.black_list = get_black_list() jwt_keys = retrieve_secret(os.getenv('JWT_KEY_SECRET_NAME')) JWT_MANAGER.public_key = base64.b64decode(jwt_keys.get('rsa_pub_key', '')).decode() @@ -119,6 +132,11 @@ def initialize(event, get_response) -> Response: @app.middleware('http') def set_log_context(event: chalice.app.Request, get_response) -> Response: + """ + Set context about current request for all log statements + + This function is called for each request before the request handlers are called. + """ origin_request_id = event.headers.get('x-origin-request-id') log_context( @@ -135,11 +153,19 @@ def set_log_context(event: chalice.app.Request, get_response) -> Response: try: return get_response(event) finally: + # Reset log context after our request is completed log_context(user_id=None, route=None, request_id=None) @app.middleware('http') def forward_origin_request_id(event: chalice.app.Request, get_response) -> Response: + """ + Normalize x-request-id header to point to aws_request_id for all requests + + The original request_id, if present, is put in an x-origin-request-id + + This function is called for each request before the request handlers are called. + """ response = get_response(event) origin_request_id = event.headers.get('x-origin-request-id') @@ -157,17 +183,29 @@ class TeaException(Exception): class EulaException(TeaException): + """ + Exception indicating that the authorized user has not accepted the EULA for accessing the dataset + """ def __init__(self, payload: dict): self.payload = payload class RequestAuthorizer: + """ + Handle authorization of incoming requests. + + Supports handling traditional OAuth2 with appropriate redirects, as well as using + bearer tokens. + """ def __init__(self): self._response = None self._headers = {} @with_trace() def get_profile(self) -> Optional[UserProfile]: + """ + Return user profile if the user is authenticated + """ user_profile = JWT_MANAGER.get_profile_from_headers(app.current_request.headers) if user_profile is not None: return user_profile @@ -255,6 +293,9 @@ def get_success_response_headers(self) -> dict: @with_trace() def get_request_id() -> str: + """ + Return AWS Lambda Request ID + """ assert app.lambda_context is not None return app.lambda_context.aws_request_id @@ -262,6 +303,9 @@ def get_request_id() -> str: @with_trace() def get_origin_request_id() -> Optional[str]: + """ + Return the *original* AWS Lambda Request ID + """ assert app.current_request is not None return app.current_request.headers.get("x-origin-request-id") @@ -269,6 +313,12 @@ def get_origin_request_id() -> Optional[str]: @with_trace() def get_aux_request_headers() -> dict: + """ + Return common HTTP headers used when making requests to EarthData login servers. + + These are headers we send when making requests as a *client*, not when making + responses as a server. + """ req_headers = {"x-request-id": get_request_id()} origin_request_id = get_origin_request_id() @@ -283,6 +333,9 @@ def get_aux_request_headers() -> dict: @with_trace() def check_for_browser(hdrs) -> bool: + """ + Return True if request is being sent by a browser + """ return 'user-agent' in hdrs and hdrs['user-agent'].lower().startswith('mozilla') @@ -367,12 +420,18 @@ def get_user_from_token(token) -> Optional[str]: @with_trace() def cumulus_log_message(outcome: str, code: int, http_method: str, k_v: dict): + """ + Emit log message to stdout in a format that cumulus understands + """ k_v.update({'code': code, 'http_method': http_method, 'status': outcome, 'requestid': get_request_id()}) print(json.dumps(k_v)) @with_trace() def restore_bucket_vars(): + """ + Update bucket config by re-fetching it from configured S3 object + """ global b_map # pylint: disable=global-statement log.debug('conf bucket: %s, bucket_map_file: %s', conf_bucket, bucket_map_file) @@ -424,6 +483,9 @@ def do_auth_and_return(ctxt) -> Response: @with_trace() def add_cors_headers(headers): + """ + Add CORS headers to allow requests from all configured domains + """ assert app.current_request is not None # send CORS headers if we're configured to use them @@ -439,6 +501,9 @@ def add_cors_headers(headers): @with_trace() def make_redirect(to_url, headers=None, status_code=301) -> Response: + """ + Return a HTTP Response redirecting users with appropriate headers + """ if headers is None: headers = {} headers['Location'] = to_url @@ -451,6 +516,9 @@ def make_redirect(to_url, headers=None, status_code=301) -> Response: @with_trace() def make_html_response(t_vars: dict, headers: dict, status_code: int = 200, template_file: str = 'root.html') -> Response: + """ + Return a HTTP response with rendered HTML from the given template + """ template_vars = { 'STAGE': STAGE if not os.getenv('DOMAIN_NAME') else None, 'status_code': status_code, @@ -489,6 +557,9 @@ def get_bcconfig(user_id: str) -> dict: key=lambda _, bucketname: hashkey(bucketname) ) def get_bucket_region(session, bucketname) -> str: + """ + Get the region of the given bucket + """ try: _time = time.time() bucket_region = session.client('s3').get_bucket_location(Bucket=bucketname)['LocationConstraint'] or 'us-east-1' @@ -508,6 +579,9 @@ def get_bucket_region(session, bucketname) -> str: @with_trace() def get_user_ip() -> str: + """ + Return IP of the user making the request + """ assert app.current_request is not None x_forwarded_for = app.current_request.headers.get('x-forwarded-for') @@ -523,6 +597,12 @@ def get_user_ip() -> str: @with_trace() def try_download_from_bucket(bucket, filename, user_profile, headers: dict) -> Response: + """ + Attempt to redirect to given file from given bucket. + + Returns a redirect response with presigned S3 URL if successful, + or an appropriate error response if unsuccessful. + """ timer = Timer() timer.mark() user_id = None @@ -634,6 +714,9 @@ def get_jwt_field(cookievar: dict, fieldname: str) -> Optional[str]: @app.route('/') @with_trace(context={}) def root() -> Response: + """ + Render human readable root page + """ template_vars = {'title': 'Welcome'} user_profile = JWT_MANAGER.get_profile_from_headers(app.current_request.headers) if user_profile is not None: @@ -739,6 +822,9 @@ def collapse_bucket_configuration(bucket_map) -> dict: @with_trace() def get_range_header_val() -> Optional[str]: + """ + Return value of range header if present + """ if 'Range' in app.current_request.headers: return app.current_request.headers['Range'] if 'range' in app.current_request.headers: @@ -775,7 +861,13 @@ def get_data_dl_s3_client(): @with_trace() -def try_download_head(bucket, filename): +def try_download_head(bucket, filename) -> Response: + """ + Try to handle a HEAD request for given filename in given bucket + + Return a redirect response if given filename exists in the bucket, provide + an error message otherwise. + """ timer = Timer() timer.mark("get_data_dl_s3_client()") @@ -839,6 +931,13 @@ def try_download_head(bucket, filename): @app.route('/{proxy+}', methods=['HEAD']) @with_trace(context={}) def dynamic_url_head(): + """ + Handle HEAD requests for arbitrary files in arbitrary buckets + + The name of the bucket and filename is parsed out of the URL. If the + file is found in the bucket and the request is authenticated properly, + a signed s3 URL is returned. If not, an error response is returned. + """ timer = Timer() timer.mark("restore_bucket_vars()") log.debug('attempting to HEAD a thing') @@ -872,6 +971,13 @@ def dynamic_url_head(): @app.route('/{proxy+}', methods=['GET']) @with_trace(context={}) def dynamic_url(): + """ + Handle GET requests for arbitrary files in arbitrary buckets + + The name of the bucket and filename is parsed out of the URL. If the + file is found in the bucket and the request is authenticated properly, + a signed s3 URL is returned. If not, an error response is returned. + """ timer = Timer() timer.mark("restore_bucket_vars()") @@ -963,6 +1069,9 @@ def dynamic_url(): @app.route('/s3credentials', methods=['GET']) @with_trace(context={}) def s3credentials(): + """ + Return temporary AWS credentials to calling user, with ability to access S3 + """ timer = Timer() timer.mark("restore_bucket_vars()") @@ -1054,6 +1163,9 @@ def get_s3_credentials(user_id: str, role_session_name: str, policy: dict) -> di @app.route('/s3credentialsREADME', methods=['GET']) @with_trace(context={}) def s3credentials_readme() -> Response: + """ + Return a human readable README for how to use /s3credentials + """ return make_html_response({}, {}, 200, "s3credentials_readme.html") From 146c503a28acf19bb2dba9e6894b8724e60a4ebe Mon Sep 17 00:00:00 2001 From: YuviPanda Date: Fri, 6 Jan 2023 16:15:19 -0800 Subject: [PATCH 3/4] Add some more type hints --- lambda/app.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/lambda/app.py b/lambda/app.py index 9df263db..67608cd4 100644 --- a/lambda/app.py +++ b/lambda/app.py @@ -239,7 +239,7 @@ def get_profile(self) -> Optional[UserProfile]: return None @with_trace() - def _handle_auth_bearer_header(self, token) -> Optional[UserProfile]: + def _handle_auth_bearer_header(self, token: str) -> Optional[UserProfile]: """ Will handle the output from get_user_from_token in context of a chalice function. If user_id is determined, returns it. If user_id is not determined returns data to be returned @@ -332,7 +332,7 @@ def get_aux_request_headers() -> dict: @with_trace() -def check_for_browser(hdrs) -> bool: +def check_for_browser(hdrs: dict) -> bool: """ Return True if request is being sent by a browser """ @@ -340,7 +340,7 @@ def check_for_browser(hdrs) -> bool: @with_trace() -def get_user_from_token(token) -> Optional[str]: +def get_user_from_token(token: str) -> Optional[str]: """ This may be moved to rain-api-core.urs_util.py once things stabilize. Will query URS for user ID of requesting user based on token sent with request @@ -482,7 +482,7 @@ def do_auth_and_return(ctxt) -> Response: @with_trace() -def add_cors_headers(headers): +def add_cors_headers(headers: dict): """ Add CORS headers to allow requests from all configured domains """ @@ -500,7 +500,7 @@ def add_cors_headers(headers): @with_trace() -def make_redirect(to_url, headers=None, status_code=301) -> Response: +def make_redirect(to_url: str, headers: Optional[dict] = None, status_code: str = 301) -> Response: """ Return a HTTP Response redirecting users with appropriate headers """ @@ -556,7 +556,7 @@ def get_bcconfig(user_id: str) -> dict: # Cache by bucketname only key=lambda _, bucketname: hashkey(bucketname) ) -def get_bucket_region(session, bucketname) -> str: +def get_bucket_region(session, bucketname: str) -> str: """ Get the region of the given bucket """ @@ -596,7 +596,7 @@ def get_user_ip() -> str: @with_trace() -def try_download_from_bucket(bucket, filename, user_profile, headers: dict) -> Response: +def try_download_from_bucket(bucket: str, filename: str, user_profile: UserProfile, headers: dict) -> Response: """ Attempt to redirect to given file from given bucket. @@ -861,7 +861,7 @@ def get_data_dl_s3_client(): @with_trace() -def try_download_head(bucket, filename) -> Response: +def try_download_head(bucket: str, filename: str) -> Response: """ Try to handle a HEAD request for given filename in given bucket From 917c096bd48e03a239482f08e282467c03830ac3 Mon Sep 17 00:00:00 2001 From: YuviPanda Date: Fri, 20 Jan 2023 15:27:06 -0800 Subject: [PATCH 4/4] Split up line that was too long --- lambda/app.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/lambda/app.py b/lambda/app.py index 67608cd4..a259fd22 100644 --- a/lambda/app.py +++ b/lambda/app.py @@ -515,7 +515,12 @@ def make_redirect(to_url: str, headers: Optional[dict] = None, status_code: str @with_trace() -def make_html_response(t_vars: dict, headers: dict, status_code: int = 200, template_file: str = 'root.html') -> Response: +def make_html_response( + t_vars: dict, + headers: dict, + status_code: int = 200, + template_file: str = 'root.html' +) -> Response: """ Return a HTTP response with rendered HTML from the given template """