Fix total_entries count on the event logs endpoint (#38625)
The `total_entries` count should reflect how many log entries match the
provided filters, not simply the total number of rows in the table.
jedcunningham authored Mar 30, 2024
1 parent 50a4c95 commit 90e7b3f
Showing 2 changed files with 19 additions and 9 deletions.
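
The change itself is small: build the filtered SELECT first, derive total_entries from that same statement, and only then apply ordering and pagination. Below is a minimal sketch of the idea (a hypothetical helper for illustration, not the exact endpoint code; dag_id stands in for any of the endpoint's filters):

from sqlalchemy import func, select

from airflow.models import Log


def count_matching_logs(session, dag_id=None):
    query = select(Log)
    if dag_id:
        query = query.where(Log.dag_id == dag_id)

    # Old behaviour: count every row in the log table, ignoring the filters.
    # total_entries = session.scalar(select(func.count(Log.id)))

    # New behaviour: count only the rows the filtered statement would return,
    # before ordering and OFFSET/LIMIT are applied.
    total_entries = session.scalar(select(func.count()).select_from(query.subquery()))
    return total_entries
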
10 changes: 6 additions & 4 deletions airflow/api_connexion/endpoints/event_log_endpoint.py
@@ -18,7 +18,7 @@
 
 from typing import TYPE_CHECKING
 
-from sqlalchemy import func, select
+from sqlalchemy import select
 
 from airflow.api_connexion import security
 from airflow.api_connexion.exceptions import NotFound
@@ -31,6 +31,7 @@
 from airflow.auth.managers.models.resource_details import DagAccessEntity
 from airflow.models import Log
 from airflow.utils import timezone
+from airflow.utils.db import get_query_count
 from airflow.utils.session import NEW_SESSION, provide_session
 
 if TYPE_CHECKING:
@@ -70,7 +71,7 @@ def get_event_logs(
 ) -> APIResponse:
     """Get all log entries from event log."""
     to_replace = {"event_log_id": "id", "when": "dttm"}
-    allowed_filter_attrs = [
+    allowed_sort_attrs = [
         "event_log_id",
         "when",
         "dag_id",
@@ -81,7 +82,6 @@
         "owner",
         "extra",
     ]
-    total_entries = session.scalars(func.count(Log.id)).one()
     query = select(Log)
 
     if dag_id:
@@ -105,7 +105,9 @@
     if after:
         query = query.where(Log.dttm > timezone.parse(after))
 
-    query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
+    total_entries = get_query_count(query, session=session)
+
+    query = apply_sorting(query, order_by, to_replace, allowed_sort_attrs)
     event_logs = session.scalars(query.offset(offset).limit(limit)).all()
     return event_log_collection_schema.dump(
         EventLogCollection(event_logs=event_logs, total_entries=total_entries)
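
get_query_count comes from airflow.utils.db and returns the number of rows a given SELECT would produce. As a rough sketch of what such a helper can look like (an assumption for illustration, not the verbatim Airflow implementation), it wraps the already-filtered statement in a COUNT(*) subquery so every WHERE clause added above is reflected in total_entries:

from sqlalchemy import func, select
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select


def get_query_count_sketch(statement: Select, *, session: Session) -> int:
    # Illustrative only. ORDER BY does not affect the count, so it can be
    # dropped before wrapping the filtered statement in a COUNT(*) subquery.
    count_statement = select(func.count()).select_from(statement.order_by(None).subquery())
    return session.scalar(count_statement)

In the endpoint above, the count is taken before apply_sorting and before .offset(...).limit(...), so pagination never skews the reported total.
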
18 changes: 13 additions & 5 deletions tests/api_connexion/endpoints/test_event_log_endpoint.py
@@ -281,23 +281,27 @@ def test_should_filter_eventlogs_by_allowed_attributes(self, create_log_model, s
                 f"/api/v1/eventLogs?{attr}={attr_value}", environ_overrides={"REMOTE_USER": "test_granular"}
             )
             assert response.status_code == 200
-            assert {eventlog[attr] for eventlog in response.json["event_logs"]} == {attr_value}
+            assert response.json["total_entries"] == 1
+            assert len(response.json["event_logs"]) == 1
+            assert response.json["event_logs"][0][attr] == attr_value
 
     def test_should_filter_eventlogs_by_when(self, create_log_model, session):
         eventlog1 = create_log_model(event="TEST_EVENT_1", when=self.default_time)
         eventlog2 = create_log_model(event="TEST_EVENT_2", when=self.default_time_2)
         session.add_all([eventlog1, eventlog2])
         session.commit()
-        for when_attr, expected_eventlogs in {
-            "before": {"TEST_EVENT_1"},
-            "after": {"TEST_EVENT_2"},
+        for when_attr, expected_eventlog_event in {
+            "before": "TEST_EVENT_1",
+            "after": "TEST_EVENT_2",
         }.items():
             response = self.client.get(
                 f"/api/v1/eventLogs?{when_attr}=2020-06-10T20%3A00%3A01%2B00%3A00", # self.default_time + 1s
                 environ_overrides={"REMOTE_USER": "test"},
             )
             assert response.status_code == 200
-            assert {eventlog["event"] for eventlog in response.json["event_logs"]} == expected_eventlogs
+            assert response.json["total_entries"] == 1
+            assert len(response.json["event_logs"]) == 1
+            assert response.json["event_logs"][0]["event"] == expected_eventlog_event
 
     def test_should_filter_eventlogs_by_run_id(self, create_log_model, session):
         eventlog1 = create_log_model(event="TEST_EVENT_1", when=self.default_time, run_id="run_1")
@@ -314,6 +318,8 @@ def test_should_filter_eventlogs_by_run_id(self, create_log_model, session):
                 environ_overrides={"REMOTE_USER": "test"},
             )
             assert response.status_code == 200
+            assert response.json["total_entries"] == len(expected_eventlogs)
+            assert len(response.json["event_logs"]) == len(expected_eventlogs)
             assert {eventlog["event"] for eventlog in response.json["event_logs"]} == expected_eventlogs
             assert all({eventlog["run_id"] == run_id for eventlog in response.json["event_logs"]})
 
@@ -327,6 +333,7 @@ def test_should_filter_eventlogs_by_included_events(self, create_log_model):
         assert response.status_code == 200
         response_data = response.json
         assert len(response_data["event_logs"]) == 2
+        assert response_data["total_entries"] == 2
         assert {"TEST_EVENT_1", "TEST_EVENT_2"} == {x["event"] for x in response_data["event_logs"]}
 
     def test_should_filter_eventlogs_by_excluded_events(self, create_log_model):
@@ -339,6 +346,7 @@ def test_should_filter_eventlogs_by_excluded_events(self, create_log_model):
         assert response.status_code == 200
         response_data = response.json
         assert len(response_data["event_logs"]) == 1
+        assert response_data["total_entries"] == 1
         assert {"cli_scheduler"} == {x["event"] for x in response_data["event_logs"]}
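
Taken together, the endpoint change and the updated tests pin down the observable behaviour: for a filtered request, total_entries now agrees with the number of matching rows rather than the table size. A trimmed illustration of the expected payload, mirroring the included-events test above (hypothetical values; the real entries carry more fields):

# GET /api/v1/eventLogs?included_events=TEST_EVENT_1,TEST_EVENT_2
expected_response = {
    "event_logs": [
        {"event": "TEST_EVENT_1"},  # trimmed for illustration
        {"event": "TEST_EVENT_2"},
    ],
    "total_entries": 2,  # matches the filter, not the full table row count
}
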
