Skip to content

Commit

Permalink
Add typing for airflow/configuration.py (#23716)
Browse files Browse the repository at this point in the history
* Add typing for airflow/configuration.py

The configuraiton.py did not have typing information and it made
it rather difficult to reason about it-especially that it went
a few changes in the past that made it rather complex to
understand.

This PR adds typing information all over the configuration file
  • Loading branch information
potiuk authored May 16, 2022
1 parent 741f802 commit 71e4deb
Show file tree
Hide file tree
Showing 24 changed files with 313 additions and 183 deletions.
2 changes: 1 addition & 1 deletion airflow/api/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

def get_current_api_client() -> Client:
"""Return current API Client based on current Airflow configuration"""
api_module = import_module(conf.get('cli', 'api_client')) # type: Any
api_module = import_module(conf.get_mandatory_value('cli', 'api_client')) # type: Any
auth_backends = api.load_auth()
session = None
for backend in auth_backends:
Expand Down
46 changes: 24 additions & 22 deletions airflow/config_templates/airflow_local_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import os
from pathlib import Path
from typing import Any, Dict, Union
from typing import Any, Dict, Optional, Union
from urllib.parse import urlparse

from airflow.configuration import conf
Expand All @@ -29,30 +29,32 @@
# in this file instead of from airflow.cfg. Currently
# there are other log format and level configurations in
# settings.py and cli.py. Please see AIRFLOW-1455.
LOG_LEVEL: str = conf.get('logging', 'LOGGING_LEVEL').upper()
LOG_LEVEL: str = conf.get_mandatory_value('logging', 'LOGGING_LEVEL').upper()


# Flask appbuilder's info level log is very verbose,
# so it's set to 'WARN' by default.
FAB_LOG_LEVEL: str = conf.get('logging', 'FAB_LOGGING_LEVEL').upper()
FAB_LOG_LEVEL: str = conf.get_mandatory_value('logging', 'FAB_LOGGING_LEVEL').upper()

LOG_FORMAT: str = conf.get('logging', 'LOG_FORMAT')
LOG_FORMAT: str = conf.get_mandatory_value('logging', 'LOG_FORMAT')

COLORED_LOG_FORMAT: str = conf.get('logging', 'COLORED_LOG_FORMAT')
COLORED_LOG_FORMAT: str = conf.get_mandatory_value('logging', 'COLORED_LOG_FORMAT')

COLORED_LOG: bool = conf.getboolean('logging', 'COLORED_CONSOLE_LOG')

COLORED_FORMATTER_CLASS: str = conf.get('logging', 'COLORED_FORMATTER_CLASS')
COLORED_FORMATTER_CLASS: str = conf.get_mandatory_value('logging', 'COLORED_FORMATTER_CLASS')

BASE_LOG_FOLDER: str = conf.get('logging', 'BASE_LOG_FOLDER')
BASE_LOG_FOLDER: str = conf.get_mandatory_value('logging', 'BASE_LOG_FOLDER')

PROCESSOR_LOG_FOLDER: str = conf.get('scheduler', 'CHILD_PROCESS_LOG_DIRECTORY')
PROCESSOR_LOG_FOLDER: str = conf.get_mandatory_value('scheduler', 'CHILD_PROCESS_LOG_DIRECTORY')

DAG_PROCESSOR_MANAGER_LOG_LOCATION: str = conf.get('logging', 'DAG_PROCESSOR_MANAGER_LOG_LOCATION')
DAG_PROCESSOR_MANAGER_LOG_LOCATION: str = conf.get_mandatory_value(
'logging', 'DAG_PROCESSOR_MANAGER_LOG_LOCATION'
)

FILENAME_TEMPLATE: str = conf.get('logging', 'LOG_FILENAME_TEMPLATE')
FILENAME_TEMPLATE: str = conf.get_mandatory_value('logging', 'LOG_FILENAME_TEMPLATE')

PROCESSOR_FILENAME_TEMPLATE: str = conf.get('logging', 'LOG_PROCESSOR_FILENAME_TEMPLATE')
PROCESSOR_FILENAME_TEMPLATE: str = conf.get_mandatory_value('logging', 'LOG_PROCESSOR_FILENAME_TEMPLATE')

DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
'version': 1,
Expand Down Expand Up @@ -116,7 +118,7 @@
},
}

EXTRA_LOGGER_NAMES: str = conf.get('logging', 'EXTRA_LOGGER_NAMES', fallback=None)
EXTRA_LOGGER_NAMES: Optional[str] = conf.get('logging', 'EXTRA_LOGGER_NAMES', fallback=None)
if EXTRA_LOGGER_NAMES:
new_loggers = {
logger_name.strip(): {
Expand Down Expand Up @@ -171,15 +173,15 @@

if REMOTE_LOGGING:

ELASTICSEARCH_HOST: str = conf.get('elasticsearch', 'HOST')
ELASTICSEARCH_HOST: Optional[str] = conf.get('elasticsearch', 'HOST')

# Storage bucket URL for remote logging
# S3 buckets should start with "s3://"
# Cloudwatch log groups should start with "cloudwatch://"
# GCS buckets should start with "gs://"
# WASB buckets should start with "wasb"
# just to help Airflow select correct handler
REMOTE_BASE_LOG_FOLDER: str = conf.get('logging', 'REMOTE_BASE_LOG_FOLDER')
REMOTE_BASE_LOG_FOLDER: str = conf.get_mandatory_value('logging', 'REMOTE_BASE_LOG_FOLDER')

if REMOTE_BASE_LOG_FOLDER.startswith('s3://'):
S3_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = {
Expand Down Expand Up @@ -207,7 +209,7 @@

DEFAULT_LOGGING_CONFIG['handlers'].update(CLOUDWATCH_REMOTE_HANDLERS)
elif REMOTE_BASE_LOG_FOLDER.startswith('gs://'):
key_path = conf.get('logging', 'GOOGLE_KEY_PATH', fallback=None)
key_path = conf.get_mandatory_value('logging', 'GOOGLE_KEY_PATH', fallback=None)
GCS_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = {
'task': {
'class': 'airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler',
Expand Down Expand Up @@ -235,7 +237,7 @@

DEFAULT_LOGGING_CONFIG['handlers'].update(WASB_REMOTE_HANDLERS)
elif REMOTE_BASE_LOG_FOLDER.startswith('stackdriver://'):
key_path = conf.get('logging', 'GOOGLE_KEY_PATH', fallback=None)
key_path = conf.get_mandatory_value('logging', 'GOOGLE_KEY_PATH', fallback=None)
# stackdriver:///airflow-tasks => airflow-tasks
log_name = urlparse(REMOTE_BASE_LOG_FOLDER).path[1:]
STACKDRIVER_REMOTE_HANDLERS = {
Expand All @@ -260,14 +262,14 @@
}
DEFAULT_LOGGING_CONFIG['handlers'].update(OSS_REMOTE_HANDLERS)
elif ELASTICSEARCH_HOST:
ELASTICSEARCH_LOG_ID_TEMPLATE: str = conf.get('elasticsearch', 'LOG_ID_TEMPLATE')
ELASTICSEARCH_END_OF_LOG_MARK: str = conf.get('elasticsearch', 'END_OF_LOG_MARK')
ELASTICSEARCH_FRONTEND: str = conf.get('elasticsearch', 'frontend')
ELASTICSEARCH_LOG_ID_TEMPLATE: str = conf.get_mandatory_value('elasticsearch', 'LOG_ID_TEMPLATE')
ELASTICSEARCH_END_OF_LOG_MARK: str = conf.get_mandatory_value('elasticsearch', 'END_OF_LOG_MARK')
ELASTICSEARCH_FRONTEND: str = conf.get_mandatory_value('elasticsearch', 'frontend')
ELASTICSEARCH_WRITE_STDOUT: bool = conf.getboolean('elasticsearch', 'WRITE_STDOUT')
ELASTICSEARCH_JSON_FORMAT: bool = conf.getboolean('elasticsearch', 'JSON_FORMAT')
ELASTICSEARCH_JSON_FIELDS: str = conf.get('elasticsearch', 'JSON_FIELDS')
ELASTICSEARCH_HOST_FIELD: str = conf.get('elasticsearch', 'HOST_FIELD')
ELASTICSEARCH_OFFSET_FIELD: str = conf.get('elasticsearch', 'OFFSET_FIELD')
ELASTICSEARCH_JSON_FIELDS: str = conf.get_mandatory_value('elasticsearch', 'JSON_FIELDS')
ELASTICSEARCH_HOST_FIELD: str = conf.get_mandatory_value('elasticsearch', 'HOST_FIELD')
ELASTICSEARCH_OFFSET_FIELD: str = conf.get_mandatory_value('elasticsearch', 'OFFSET_FIELD')

ELASTIC_REMOTE_HANDLERS: Dict[str, Dict[str, Union[str, bool]]] = {
'task': {
Expand Down
6 changes: 3 additions & 3 deletions airflow/config_templates/default_celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ def _broker_supports_visibility_timeout(url):

try:
if celery_ssl_active:
if 'amqp://' in broker_url:
if broker_url and 'amqp://' in broker_url:
broker_use_ssl = {
'keyfile': conf.get('celery', 'SSL_KEY'),
'certfile': conf.get('celery', 'SSL_CERT'),
'ca_certs': conf.get('celery', 'SSL_CACERT'),
'cert_reqs': ssl.CERT_REQUIRED,
}
elif 'redis://' in broker_url:
elif broker_url and 'redis://' in broker_url:
broker_use_ssl = {
'ssl_keyfile': conf.get('celery', 'SSL_KEY'),
'ssl_certfile': conf.get('celery', 'SSL_CERT'),
Expand All @@ -92,7 +92,7 @@ def _broker_supports_visibility_timeout(url):
f'all necessary certs and key ({e}).'
)

result_backend = DEFAULT_CELERY_CONFIG['result_backend']
result_backend = str(DEFAULT_CELERY_CONFIG['result_backend'])
if 'amqp://' in result_backend or 'redis://' in result_backend or 'rpc://' in result_backend:
log.warning(
"You have configured a result_backend of %s, it is highly recommended "
Expand Down
Loading

0 comments on commit 71e4deb

Please sign in to comment.