Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some type hints and docstrings for app.py #677

Open
wants to merge 4 commits into
base: devel
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 25 additions & 25 deletions lambda/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def wrapper(*args, **kwargs):


@ttl_cache(maxsize=2, ttl=10 * 60, timer=time.time)
def get_black_list():
def get_black_list() -> dict:
endpoint = os.getenv('BLACKLIST_ENDPOINT', '')
if endpoint:
response = urllib.request.urlopen(endpoint).read().decode('utf-8')
Expand All @@ -108,7 +108,7 @@ def get_black_list():


@app.middleware('http')
def initialize(event, get_response):
def initialize(event, get_response) -> Response:
JWT_MANAGER.black_list = get_black_list()
jwt_keys = retrieve_secret(os.getenv('JWT_KEY_SECRET_NAME'))
JWT_MANAGER.public_key = base64.b64decode(jwt_keys.get('rsa_pub_key', '')).decode()
Expand All @@ -118,7 +118,7 @@ def initialize(event, get_response):


@app.middleware('http')
def set_log_context(event: chalice.app.Request, get_response):
def set_log_context(event: chalice.app.Request, get_response) -> Response:
origin_request_id = event.headers.get('x-origin-request-id')

log_context(
Expand All @@ -139,7 +139,7 @@ def set_log_context(event: chalice.app.Request, get_response):


@app.middleware('http')
def forward_origin_request_id(event: chalice.app.Request, get_response):
def forward_origin_request_id(event: chalice.app.Request, get_response) -> Response:
response = get_response(event)

origin_request_id = event.headers.get('x-origin-request-id')
Expand Down Expand Up @@ -268,7 +268,7 @@ def get_origin_request_id() -> Optional[str]:


@with_trace()
def get_aux_request_headers():
def get_aux_request_headers() -> dict:
req_headers = {"x-request-id": get_request_id()}
origin_request_id = get_origin_request_id()

Expand All @@ -282,12 +282,12 @@ def get_aux_request_headers():


@with_trace()
def check_for_browser(hdrs):
def check_for_browser(hdrs) -> bool:
return 'user-agent' in hdrs and hdrs['user-agent'].lower().startswith('mozilla')


@with_trace()
def get_user_from_token(token):
def get_user_from_token(token) -> Optional[str]:
"""
This may be moved to rain-api-core.urs_util.py once things stabilize.
Will query URS for user ID of requesting user based on token sent with request
Expand Down Expand Up @@ -405,7 +405,7 @@ def restore_bucket_vars():


@with_trace()
def do_auth_and_return(ctxt):
def do_auth_and_return(ctxt) -> Response:
log.debug('context: {}'.format(ctxt))
here = ctxt['path']
if os.getenv('DOMAIN_NAME'):
Expand Down Expand Up @@ -438,7 +438,7 @@ def add_cors_headers(headers):


@with_trace()
def make_redirect(to_url, headers=None, status_code=301):
def make_redirect(to_url, headers=None, status_code=301) -> Response:
if headers is None:
headers = {}
headers['Location'] = to_url
Expand All @@ -450,7 +450,7 @@ def make_redirect(to_url, headers=None, status_code=301):


@with_trace()
def make_html_response(t_vars: dict, headers: dict, status_code: int = 200, template_file: str = 'root.html'):
def make_html_response(t_vars: dict, headers: dict, status_code: int = 200, template_file: str = 'root.html') -> Response:
yuvipanda marked this conversation as resolved.
Show resolved Hide resolved
template_vars = {
'STAGE': STAGE if not os.getenv('DOMAIN_NAME') else None,
'status_code': status_code,
Expand Down Expand Up @@ -507,7 +507,7 @@ def get_bucket_region(session, bucketname) -> str:


@with_trace()
def get_user_ip():
def get_user_ip() -> str:
assert app.current_request is not None

x_forwarded_for = app.current_request.headers.get('x-forwarded-for')
Expand All @@ -522,7 +522,7 @@ def get_user_ip():


@with_trace()
def try_download_from_bucket(bucket, filename, user_profile, headers: dict):
def try_download_from_bucket(bucket, filename, user_profile, headers: dict) -> Response:
timer = Timer()
timer.mark()
user_id = None
Expand Down Expand Up @@ -627,13 +627,13 @@ def try_download_from_bucket(bucket, filename, user_profile, headers: dict):


@with_trace()
def get_jwt_field(cookievar: dict, fieldname: str):
def get_jwt_field(cookievar: dict, fieldname: str) -> Optional[str]:
return cookievar.get(JWT_COOKIE_NAME, {}).get(fieldname, None)


@app.route('/')
@with_trace(context={})
def root():
def root() -> Response:
template_vars = {'title': 'Welcome'}
user_profile = JWT_MANAGER.get_profile_from_headers(app.current_request.headers)
if user_profile is not None:
Expand All @@ -648,7 +648,7 @@ def root():

@app.route('/logout')
@with_trace(context={})
def logout():
def logout() -> Response:
user_profile = JWT_MANAGER.get_profile_from_headers(app.current_request.headers)
template_vars = {'title': 'Logged Out', 'URS_URL': get_urs_url(app.current_request.context)}

Expand All @@ -667,7 +667,7 @@ def logout():

@app.route('/login')
@with_trace(context={})
def login():
def login() -> Response:
try:
headers = {}
aux_headers = get_aux_request_headers()
Expand All @@ -694,7 +694,7 @@ def login():

@app.route('/version')
@with_trace(context={})
def version():
def version() -> str:
log.info("Got a version request!")
version_return = {'version_id': '<BUILD_ID>'}

Expand All @@ -707,7 +707,7 @@ def version():

@app.route('/locate')
@with_trace(context={})
def locate():
def locate() -> Response:
query_params = app.current_request.query_params
if query_params is None or query_params.get('bucket_name') is None:
return Response(body='Required "bucket_name" query paramater not specified',
Expand All @@ -727,7 +727,7 @@ def locate():


@with_trace()
def collapse_bucket_configuration(bucket_map):
def collapse_bucket_configuration(bucket_map) -> dict:
for k, v in bucket_map.items():
if isinstance(v, dict):
if 'bucket' in v:
Expand All @@ -738,7 +738,7 @@ def collapse_bucket_configuration(bucket_map):


@with_trace()
def get_range_header_val():
def get_range_header_val() -> Optional[str]:
if 'Range' in app.current_request.headers:
return app.current_request.headers['Range']
if 'range' in app.current_request.headers:
Expand Down Expand Up @@ -1026,7 +1026,7 @@ def s3credentials():


@with_trace()
def get_role_session_name(user_id: str, app_name: str):
def get_role_session_name(user_id: str, app_name: str) -> str:
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html#API_AssumeRole_RequestParameters
if not re.match(r"[\w+,.=@-]*", app_name):
app_name = ""
Expand All @@ -1035,7 +1035,7 @@ def get_role_session_name(user_id: str, app_name: str):


@with_trace()
def get_s3_credentials(user_id: str, role_session_name: str, policy: dict):
def get_s3_credentials(user_id: str, role_session_name: str, policy: dict) -> dict:
client = boto3.client("sts")
arn = os.getenv("EGRESS_APP_DOWNLOAD_ROLE_INREGION_ARN")
response = client.assume_role(
Expand All @@ -1053,20 +1053,20 @@ def get_s3_credentials(user_id: str, role_session_name: str, policy: dict):

@app.route('/s3credentialsREADME', methods=['GET'])
@with_trace(context={})
def s3credentials_readme():
def s3credentials_readme() -> Response:
return make_html_response({}, {}, 200, "s3credentials_readme.html")


@app.route('/profile')
@with_trace(context={})
def profile():
def profile() -> Response:
return Response(body='Profile not available.',
status_code=200, headers={})


@app.route('/pubkey', methods=['GET'])
@with_trace(context={})
def pubkey():
def pubkey() -> Response:
thebody = json.dumps({
'rsa_pub_key': JWT_MANAGER.public_key,
'algorithm': JWT_MANAGER.algorithm
Expand Down