Improve ElasticsearchTaskHandler #21942

Merged 2 commits on Jul 22, 2022
48 changes: 48 additions & 0 deletions airflow/providers/elasticsearch/log/es_json_formatter.py
@@ -0,0 +1,48 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import pendulum

from airflow.utils.log.json_formatter import JSONFormatter


class ElasticsearchJSONFormatter(JSONFormatter):
"""
    ElasticsearchJSONFormatter instances are used to convert a log record
    to JSON with an ISO 8601 date and time format.
"""

default_time_format = '%Y-%m-%dT%H:%M:%S'
default_msec_format = '%s.%03d'
default_tz_format = '%z'

def formatTime(self, record, datefmt=None):
"""
Returns the creation time of the specified LogRecord in ISO 8601 date and time format
in the local time zone.
"""
dt = pendulum.from_timestamp(record.created, tz=pendulum.local_timezone())
if datefmt:
s = dt.strftime(datefmt)
else:
s = dt.strftime(self.default_time_format)

if self.default_msec_format:
s = self.default_msec_format % (s, record.msecs)
if self.default_tz_format:
s += dt.strftime(self.default_tz_format)
return s
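For reference, a minimal sketch (not part of the diff) that replays the formatTime() steps above outside Airflow. It only assumes pendulum is installed, and shows the timestamp shape written into the log document: ISO 8601 local time with milliseconds and a numeric UTC offset.

import logging

import pendulum

# A bare LogRecord; 'created' and 'msecs' are populated automatically.
record = logging.makeLogRecord({'msg': 'demo'})

# Same steps as ElasticsearchJSONFormatter.formatTime() with its default formats.
dt = pendulum.from_timestamp(record.created, tz=pendulum.local_timezone())
s = dt.strftime('%Y-%m-%dT%H:%M:%S')
s = '%s.%03d' % (s, record.msecs)
s += dt.strftime('%z')
print(s)  # e.g. '2022-07-22T10:15:30.123+0200', offset depends on the local timezone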
45 changes: 17 additions & 28 deletions airflow/providers/elasticsearch/log/es_task_handler.py
@@ -34,12 +34,13 @@
from airflow.configuration import conf
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.providers.elasticsearch.log.es_json_formatter import ElasticsearchJSONFormatter
from airflow.utils import timezone
from airflow.utils.log.file_task_handler import FileTaskHandler
from airflow.utils.log.json_formatter import JSONFormatter
from airflow.utils.log.logging_mixin import ExternalLoggingMixin, LoggingMixin
from airflow.utils.session import create_session

LOG_LINE_DEFAULTS = {'exc_text': '', 'stack_info': ''}
# Elasticsearch hosted log type
EsLogMsgType = List[Tuple[str, str]]

@@ -95,7 +96,7 @@ def __init__(
super().__init__(base_log_folder, filename_template)
self.closed = False

self.client = elasticsearch.Elasticsearch([host], **es_kwargs) # type: ignore[attr-defined]
self.client = elasticsearch.Elasticsearch(host.split(';'), **es_kwargs) # type: ignore[attr-defined]

if USE_PER_RUN_LOG_ID and log_id_template is not None:
warnings.warn(
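Because the configured host string is now split on ';' instead of being wrapped in a single-element list, several Elasticsearch nodes can be supplied in one value (typically the host option under [elasticsearch] in airflow.cfg). A tiny illustration with made-up hostnames:

# Hypothetical value; the resulting list is passed straight to elasticsearch.Elasticsearch().
host = 'es-node-1:9200;es-node-2:9200'
print(host.split(';'))  # ['es-node-1:9200', 'es-node-2:9200']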
@@ -106,7 +107,7 @@
self.log_id_template = log_id_template # Only used on Airflow < 2.3.2.
self.frontend = frontend
self.mark_end_on_close = True
self.end_of_log_mark = end_of_log_mark
self.end_of_log_mark = end_of_log_mark.strip()
self.write_stdout = write_stdout
self.json_format = json_format
self.json_fields = [label.strip() for label in json_fields.split(",")]
@@ -178,10 +179,7 @@ def _group_logs_by_host(self, logs):
key = getattr(log, self.host_field, 'default_host')
grouped_logs[key].append(log)

# return items sorted by timestamp.
result = sorted(grouped_logs.items(), key=lambda kv: getattr(kv[1][0], 'message', '_'))

return result
return grouped_logs

def _read_grouped_logs(self):
return True
@@ -218,10 +216,10 @@ def _read(

# end_of_log_mark may contain characters like '\n' which is needed to
# have the log uploaded but will not be stored in elasticsearch.
loading_hosts = [
item[0] for item in logs_by_host if item[-1][-1].message != self.end_of_log_mark.strip()
]
metadata['end_of_log'] = False if not logs else len(loading_hosts) == 0
metadata['end_of_log'] = False
for logs in logs_by_host.values():
if logs[-1].message == self.end_of_log_mark:
metadata['end_of_log'] = True

cur_ts = pendulum.now()
if 'last_log_timestamp' in metadata:
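Taken together with the _group_logs_by_host() change above, which now returns the grouped mapping directly instead of a sorted list of items, end_of_log is set as soon as any host's last message equals the marker. A small sketch of that decision outside Airflow (SimpleNamespace objects with illustrative field values stand in for Elasticsearch hits):

from collections import defaultdict
from types import SimpleNamespace

end_of_log_mark = 'end_of_log'

hits = [
    SimpleNamespace(host='worker-1', message='task started'),
    SimpleNamespace(host='worker-1', message='end_of_log'),
    SimpleNamespace(host='worker-2', message='still running'),
]

# Group hits by host, as the reworked _group_logs_by_host() does.
grouped_logs = defaultdict(list)
for hit in hits:
    grouped_logs[getattr(hit, 'host', 'default_host')].append(hit)

# The flag flips to True because worker-1's last message is the marker.
end_of_log = False
for logs in grouped_logs.values():
    if logs[-1].message == end_of_log_mark:
        end_of_log = True

print(end_of_log)  # True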
@@ -251,10 +249,10 @@ def _read(
# If we hit the end of the log, remove the actual end_of_log message
# to prevent it from showing in the UI.
def concat_logs(lines):
log_range = (len(lines) - 1) if lines[-1].message == self.end_of_log_mark.strip() else len(lines)
log_range = (len(lines) - 1) if lines[-1].message == self.end_of_log_mark else len(lines)
return '\n'.join(self._format_msg(lines[i]) for i in range(log_range))

message = [(host, concat_logs(hosted_log)) for host, hosted_log in logs_by_host]
message = [(host, concat_logs(hosted_log)) for host, hosted_log in logs_by_host.items()]

return message, metadata

@@ -264,8 +262,9 @@ def _format_msg(self, log_line):
# if we change the formatter style from '%' to '{' or '$', this will still work
if self.json_format:
try:

return self.formatter._style.format(_ESJsonLogFmt(self.json_fields, **log_line.to_dict()))
return self.formatter._style.format(
logging.makeLogRecord({**LOG_LINE_DEFAULTS, **log_line.to_dict()})
)
except Exception:
pass
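With _ESJsonLogFmt removed, _format_msg() re-hydrates the stored JSON document into a LogRecord via logging.makeLogRecord(), and LOG_LINE_DEFAULTS pre-fills attributes (exc_text, stack_info) the stored document may not carry. A minimal sketch of that mechanism, using a made-up format string and document in place of self.formatter._style and log_line.to_dict():

import logging

LOG_LINE_DEFAULTS = {'exc_text': '', 'stack_info': ''}

# Stands in for log_line.to_dict(): the fields that were stored in Elasticsearch.
doc = {'asctime': '2022-07-22T10:15:30.123+0000', 'levelname': 'INFO', 'message': 'hello'}

# Stands in for self.formatter._style: a %-style format applied to the record's __dict__.
style = logging.PercentStyle('[%(asctime)s] %(levelname)s - %(message)s')

record = logging.makeLogRecord({**LOG_LINE_DEFAULTS, **doc})
print(style.format(record))  # [2022-07-22T10:15:30.123+0000] INFO - hello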

@@ -309,7 +308,7 @@ def es_read(self, log_id: str, offset: str, metadata: dict) -> list:

def emit(self, record):
if self.handler:
record.offset = int(time() * (10**9))
setattr(record, self.offset_field, int(time() * (10**9)))
self.handler.emit(record)

def set_context(self, ti: TaskInstance) -> None:
@@ -321,7 +320,7 @@ def set_context(self, ti: TaskInstance) -> None:
self.mark_end_on_close = not ti.raw

if self.json_format:
self.formatter = JSONFormatter(
self.formatter = ElasticsearchJSONFormatter(
fmt=self.formatter._fmt,
json_fields=self.json_fields + [self.offset_field],
extras={
@@ -370,7 +369,7 @@ def close(self) -> None:

# Mark the end of file using end of log mark,
# so we know where to stop while auto-tailing.
self.handler.stream.write(self.end_of_log_mark)
self.emit(logging.makeLogRecord({'msg': self.end_of_log_mark}))

if self.write_stdout:
self.handler.close()
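Emitting the marker through self.emit() rather than writing it raw to the handler's stream means it passes through the same formatter (and offset stamping) as every other record. A tiny illustration of the record close() now hands to emit():

import logging

rec = logging.makeLogRecord({'msg': 'end_of_log'})
print(rec.getMessage())  # 'end_of_log', formatted like any ordinary record when emitted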
@@ -402,13 +401,3 @@ def get_external_log_url(self, task_instance: TaskInstance, try_number: int) ->
def supports_external_link(self) -> bool:
"""Whether we can support external links"""
return bool(self.frontend)


class _ESJsonLogFmt:
"""Helper class to read ES Logs and re-format it to match settings.LOG_FORMAT"""

# A separate class is needed because 'self.formatter._style.format' uses '.__dict__'
def __init__(self, json_fields: List, **kwargs):
for field in json_fields:
self.__setattr__(field, '')
self.__dict__.update(kwargs)
10 changes: 5 additions & 5 deletions tests/providers/elasticsearch/log/test_es_task_handler.py
@@ -397,7 +397,7 @@ def test_close(self, ti):
# have the log uploaded but will not be stored in elasticsearch.
# so apply the strip() to log_file.read()
log_line = log_file.read().strip()
assert self.end_of_log_mark.strip() == log_line
assert log_line.endswith(self.end_of_log_mark.strip())
assert self.es_task_handler.closed

def test_close_no_mark_end(self, ti):
@@ -518,7 +518,7 @@ def test_dynamic_offset(self, stdout_mock, ti):
ti._log = logger
handler.set_context(ti)

t1 = pendulum.naive(year=2017, month=1, day=1, hour=1, minute=1, second=15)
t1 = pendulum.local(year=2017, month=1, day=1, hour=1, minute=1, second=15)
t2, t3 = t1 + pendulum.duration(seconds=5), t1 + pendulum.duration(seconds=10)

# act
@@ -532,6 +532,6 @@ def test_dynamic_offset(self, stdout_mock, ti):
# assert
first_log, second_log, third_log = map(json.loads, stdout_mock.getvalue().strip().split("\n"))
assert first_log['offset'] < second_log['offset'] < third_log['offset']
assert first_log['asctime'] == t1.format("YYYY-MM-DD HH:mm:ss,SSS")
assert second_log['asctime'] == t2.format("YYYY-MM-DD HH:mm:ss,SSS")
assert third_log['asctime'] == t3.format("YYYY-MM-DD HH:mm:ss,SSS")
assert first_log['asctime'] == t1.format("YYYY-MM-DDTHH:mm:ss.SSSZZ")
assert second_log['asctime'] == t2.format("YYYY-MM-DDTHH:mm:ss.SSSZZ")
assert third_log['asctime'] == t3.format("YYYY-MM-DDTHH:mm:ss.SSSZZ")
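As a sanity check (not part of the PR), the pendulum tokens the updated assertions use produce the same shape that ElasticsearchJSONFormatter.formatTime() emits, assuming pendulum is installed:

import pendulum

t1 = pendulum.local(year=2017, month=1, day=1, hour=1, minute=1, second=15)
# Same token string as the assertions above; the offset depends on the local timezone.
print(t1.format("YYYY-MM-DDTHH:mm:ss.SSSZZ"))  # e.g. '2017-01-01T01:01:15.000+0100'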