Skip to content

Commit

Permalink
Add more return type hints for app.py
Browse files Browse the repository at this point in the history
I loved reading
https://github.com/asfadmin/thin-egress-app/blob/master/docs/vision.md!
Hopefully this PR is useful and not just noise.

I couldn't figure out how exactly to type a return for the S3 botocore
client, but otherwise I think this is accurate.
  • Loading branch information
yuvipanda committed Jan 2, 2023
1 parent 7b0f711 commit 8753952
Showing 1 changed file with 25 additions and 25 deletions.
50 changes: 25 additions & 25 deletions lambda/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def wrapper(*args, **kwargs):


@ttl_cache(maxsize=2, ttl=10 * 60, timer=time.time)
def get_black_list():
def get_black_list() -> dict:
endpoint = os.getenv('BLACKLIST_ENDPOINT', '')
if endpoint:
response = urllib.request.urlopen(endpoint).read().decode('utf-8')
Expand All @@ -108,7 +108,7 @@ def get_black_list():


@app.middleware('http')
def initialize(event, get_response):
def initialize(event, get_response) -> Response:
JWT_MANAGER.black_list = get_black_list()
jwt_keys = retrieve_secret(os.getenv('JWT_KEY_SECRET_NAME'))
JWT_MANAGER.public_key = base64.b64decode(jwt_keys.get('rsa_pub_key', '')).decode()
Expand All @@ -118,7 +118,7 @@ def initialize(event, get_response):


@app.middleware('http')
def set_log_context(event: chalice.app.Request, get_response):
def set_log_context(event: chalice.app.Request, get_response) -> Response:
origin_request_id = event.headers.get('x-origin-request-id')

log_context(
Expand All @@ -139,7 +139,7 @@ def set_log_context(event: chalice.app.Request, get_response):


@app.middleware('http')
def forward_origin_request_id(event: chalice.app.Request, get_response):
def forward_origin_request_id(event: chalice.app.Request, get_response) -> Response:
response = get_response(event)

origin_request_id = event.headers.get('x-origin-request-id')
Expand Down Expand Up @@ -268,7 +268,7 @@ def get_origin_request_id() -> Optional[str]:


@with_trace()
def get_aux_request_headers():
def get_aux_request_headers() -> dict:
req_headers = {"x-request-id": get_request_id()}
origin_request_id = get_origin_request_id()

Expand All @@ -282,12 +282,12 @@ def get_aux_request_headers():


@with_trace()
def check_for_browser(hdrs):
def check_for_browser(hdrs) -> bool:
return 'user-agent' in hdrs and hdrs['user-agent'].lower().startswith('mozilla')


@with_trace()
def get_user_from_token(token):
def get_user_from_token(token) -> Optional[str]:
"""
This may be moved to rain-api-core.urs_util.py once things stabilize.
Will query URS for user ID of requesting user based on token sent with request
Expand Down Expand Up @@ -405,7 +405,7 @@ def restore_bucket_vars():


@with_trace()
def do_auth_and_return(ctxt):
def do_auth_and_return(ctxt) -> Response:
log.debug('context: {}'.format(ctxt))
here = ctxt['path']
if os.getenv('DOMAIN_NAME'):
Expand Down Expand Up @@ -438,7 +438,7 @@ def add_cors_headers(headers):


@with_trace()
def make_redirect(to_url, headers=None, status_code=301):
def make_redirect(to_url, headers=None, status_code=301) -> Response:
if headers is None:
headers = {}
headers['Location'] = to_url
Expand All @@ -450,7 +450,7 @@ def make_redirect(to_url, headers=None, status_code=301):


@with_trace()
def make_html_response(t_vars: dict, headers: dict, status_code: int = 200, template_file: str = 'root.html'):
def make_html_response(t_vars: dict, headers: dict, status_code: int = 200, template_file: str = 'root.html') -> Response:
template_vars = {
'STAGE': STAGE if not os.getenv('DOMAIN_NAME') else None,
'status_code': status_code,
Expand Down Expand Up @@ -507,7 +507,7 @@ def get_bucket_region(session, bucketname) -> str:


@with_trace()
def get_user_ip():
def get_user_ip() -> str:
assert app.current_request is not None

x_forwarded_for = app.current_request.headers.get('x-forwarded-for')
Expand All @@ -522,7 +522,7 @@ def get_user_ip():


@with_trace()
def try_download_from_bucket(bucket, filename, user_profile, headers: dict):
def try_download_from_bucket(bucket, filename, user_profile, headers: dict) -> Response:
timer = Timer()
timer.mark()
user_id = None
Expand Down Expand Up @@ -627,13 +627,13 @@ def try_download_from_bucket(bucket, filename, user_profile, headers: dict):


@with_trace()
def get_jwt_field(cookievar: dict, fieldname: str):
def get_jwt_field(cookievar: dict, fieldname: str) -> Optional[str]:
return cookievar.get(JWT_COOKIE_NAME, {}).get(fieldname, None)


@app.route('/')
@with_trace(context={})
def root():
def root() -> Response:
template_vars = {'title': 'Welcome'}
user_profile = JWT_MANAGER.get_profile_from_headers(app.current_request.headers)
if user_profile is not None:
Expand All @@ -648,7 +648,7 @@ def root():

@app.route('/logout')
@with_trace(context={})
def logout():
def logout() -> Response:
user_profile = JWT_MANAGER.get_profile_from_headers(app.current_request.headers)
template_vars = {'title': 'Logged Out', 'URS_URL': get_urs_url(app.current_request.context)}

Expand All @@ -667,7 +667,7 @@ def logout():

@app.route('/login')
@with_trace(context={})
def login():
def login() -> Response:
try:
headers = {}
aux_headers = get_aux_request_headers()
Expand All @@ -694,7 +694,7 @@ def login():

@app.route('/version')
@with_trace(context={})
def version():
def version() -> str:
log.info("Got a version request!")
version_return = {'version_id': '<BUILD_ID>'}

Expand All @@ -707,7 +707,7 @@ def version():

@app.route('/locate')
@with_trace(context={})
def locate():
def locate() -> Response:
query_params = app.current_request.query_params
if query_params is None or query_params.get('bucket_name') is None:
return Response(body='Required "bucket_name" query paramater not specified',
Expand All @@ -727,7 +727,7 @@ def locate():


@with_trace()
def collapse_bucket_configuration(bucket_map):
def collapse_bucket_configuration(bucket_map) -> dict:
for k, v in bucket_map.items():
if isinstance(v, dict):
if 'bucket' in v:
Expand All @@ -738,7 +738,7 @@ def collapse_bucket_configuration(bucket_map):


@with_trace()
def get_range_header_val():
def get_range_header_val() -> Optional[str]:
if 'Range' in app.current_request.headers:
return app.current_request.headers['Range']
if 'range' in app.current_request.headers:
Expand Down Expand Up @@ -1026,7 +1026,7 @@ def s3credentials():


@with_trace()
def get_role_session_name(user_id: str, app_name: str):
def get_role_session_name(user_id: str, app_name: str) -> str:
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html#API_AssumeRole_RequestParameters
if not re.match(r"[\w+,.=@-]*", app_name):
app_name = ""
Expand All @@ -1035,7 +1035,7 @@ def get_role_session_name(user_id: str, app_name: str):


@with_trace()
def get_s3_credentials(user_id: str, role_session_name: str, policy: dict):
def get_s3_credentials(user_id: str, role_session_name: str, policy: dict) -> dict:
client = boto3.client("sts")
arn = os.getenv("EGRESS_APP_DOWNLOAD_ROLE_INREGION_ARN")
response = client.assume_role(
Expand All @@ -1053,20 +1053,20 @@ def get_s3_credentials(user_id: str, role_session_name: str, policy: dict):

@app.route('/s3credentialsREADME', methods=['GET'])
@with_trace(context={})
def s3credentials_readme():
def s3credentials_readme() -> Response:
return make_html_response({}, {}, 200, "s3credentials_readme.html")


@app.route('/profile')
@with_trace(context={})
def profile():
def profile() -> Response:
return Response(body='Profile not available.',
status_code=200, headers={})


@app.route('/pubkey', methods=['GET'])
@with_trace(context={})
def pubkey():
def pubkey() -> Response:
thebody = json.dumps({
'rsa_pub_key': JWT_MANAGER.public_key,
'algorithm': JWT_MANAGER.algorithm
Expand Down

0 comments on commit 8753952

Please sign in to comment.