From 790be1f8277b7de808448cf3d6bb57e4c8728664 Mon Sep 17 00:00:00 2001 From: Michelle Tran Date: Wed, 30 Oct 2024 23:26:09 -0400 Subject: [PATCH] Refactored timeseries service to dedupe code There was a bunch of duplication and large methods. This refactoring reduces the duplication and hopefully makes things more readable and traceable. --- services/timeseries.py | 191 ++++++++++++++++++++--------------------- 1 file changed, 95 insertions(+), 96 deletions(-) diff --git a/services/timeseries.py b/services/timeseries.py index b7af8a963..4dd0ccddb 100644 --- a/services/timeseries.py +++ b/services/timeseries.py @@ -1,9 +1,11 @@ import logging from datetime import datetime -from typing import Iterable, Mapping, Optional +from typing import Any, Iterable, Mapping, Optional +from shared.components import Component from shared.reports.readonly import ReadOnlyReport from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.orm import Session from database.models import Commit, Dataset, Measurement, MeasurementName from database.models.core import Repository @@ -39,35 +41,28 @@ def save_commit_measurements( db_session = commit.get_db_session() + maybe_upsert_coverage_measurement(commit, dataset_names, db_session, report) + maybe_upsert_components_measurements( + commit, current_yaml, dataset_names, db_session, report + ) + maybe_upsert_flag_measurements(commit, dataset_names, db_session, report) + + +def maybe_upsert_coverage_measurement(commit, dataset_names, db_session, report): if MeasurementName.coverage.value in dataset_names: if report.totals.coverage is not None: - command = insert(Measurement.__table__).values( - name=MeasurementName.coverage.value, - owner_id=commit.repository.ownerid, - repo_id=commit.repoid, - measurable_id=f"{commit.repoid}", - branch=commit.branch, - commit_sha=commit.commitid, - timestamp=commit.timestamp, - value=float(report.totals.coverage), - ) - command = command.on_conflict_do_update( - index_elements=[ - Measurement.name, - Measurement.owner_id, - Measurement.repo_id, - Measurement.measurable_id, - Measurement.commit_sha, - Measurement.timestamp, - ], - set_=dict( - branch=command.excluded.branch, - value=command.excluded.value, - ), - ) - db_session.execute(command) - db_session.flush() + measurements = [ + create_measurement_dict( + MeasurementName.coverage.value, + commit, + measurable_id=f"{commit.repoid}", + value=float(report.totals.coverage), + ) + ] + upsert_measurements(db_session, measurements) + +def maybe_upsert_flag_measurements(commit, dataset_names, db_session, report): if MeasurementName.flag_coverage.value in dataset_names: flag_ids = repository_flag_ids(commit.repository) measurements = [] @@ -89,14 +84,10 @@ def save_commit_measurements( flag_id = repo_flag.id measurements.append( - dict( - name=MeasurementName.flag_coverage.value, - owner_id=commit.repository.ownerid, - repo_id=commit.repoid, + create_measurement_dict( + MeasurementName.flag_coverage.value, + commit, measurable_id=f"{flag_id}", - branch=commit.branch, - commit_sha=commit.commitid, - timestamp=commit.timestamp, value=float(flag.totals.coverage), ) ) @@ -110,28 +101,16 @@ def save_commit_measurements( count=len(measurements), ), ) - command = insert(Measurement.__table__).values(measurements) - command = command.on_conflict_do_update( - index_elements=[ - Measurement.name, - Measurement.owner_id, - Measurement.repo_id, - Measurement.measurable_id, - Measurement.commit_sha, - Measurement.timestamp, - ], - set_=dict( - branch=command.excluded.branch, - value=command.excluded.value, - ), - ) - db_session.execute(command) - db_session.flush() + upsert_measurements(db_session, measurements) + +def maybe_upsert_components_measurements( + commit, current_yaml, dataset_names, db_session, report +): if MeasurementName.component_coverage.value in dataset_names: components = current_yaml.get_components() if components: - measurements = dict() + component_measurements = dict() for component in components: if component.paths or component.flag_regexes: @@ -141,18 +120,17 @@ def save_commit_measurements( filtered_report = report.filter( flags=report_and_component_matching_flags, paths=component.paths ) - if filtered_report.totals.coverage is not None: - measurement_key = ( - MeasurementName.component_coverage.value, - commit.repository.ownerid, - commit.repoid, - f"{component.component_id}", - commit.commitid, - commit.timestamp, + # This measurement key is being used to check for measurement existence and log the warning. + # TODO: see if we can remove this warning message as it's necessary to emit this warning. + # We're currently not doing anything with this information. + measurement_key = create_component_measurement_key( + commit, component ) if ( - existing_measurement := measurements.get(measurement_key) + existing_measurement := component_measurements.get( + measurement_key + ) ) is not None: log.warning( "Duplicate measurement keys being added to measurements", @@ -166,53 +144,74 @@ def save_commit_measurements( ), ) - measurements[ - ( + component_measurements[measurement_key] = ( + create_measurement_dict( MeasurementName.component_coverage.value, - commit.repository.ownerid, - commit.repoid, - f"{component.component_id}", - commit.commitid, - commit.timestamp, + commit, + measurable_id=f"{component.component_id}", + value=float(filtered_report.totals.coverage), ) - ] = dict( - name=MeasurementName.component_coverage.value, - owner_id=commit.repository.ownerid, - repo_id=commit.repoid, - branch=commit.branch, - commit_sha=commit.commitid, - timestamp=commit.timestamp, - measurable_id=f"{component.component_id}", - value=float(filtered_report.totals.coverage), ) - measurements = list(measurements.values()) + measurements = list(component_measurements.values()) if len(measurements) > 0: + upsert_measurements(db_session, measurements) log.info( - "Upserting component coverage measurements", + "Upserted component coverage measurements", extra=dict( repoid=commit.repoid, commit_id=commit.id_, count=len(measurements), ), ) - command = insert(Measurement.__table__).values(measurements) - command = command.on_conflict_do_update( - index_elements=[ - Measurement.name, - Measurement.owner_id, - Measurement.repo_id, - Measurement.measurable_id, - Measurement.commit_sha, - Measurement.timestamp, - ], - set_=dict( - branch=command.excluded.branch, - value=command.excluded.value, - ), - ) - db_session.execute(command) - db_session.flush() + + +def create_measurement_dict( + name: str, commit: Commit, measurable_id: str, value: float +) -> dict[str, Any]: + return dict( + name=name, + owner_id=commit.repository.ownerid, + repo_id=commit.repoid, + measurable_id=measurable_id, + branch=commit.branch, + commit_sha=commit.commitid, + timestamp=commit.timestamp, + value=value, + ) + + +def create_component_measurement_key(commit: Commit, component: Component) -> tuple: + return ( + MeasurementName.component_coverage.value, + commit.repository.ownerid, + commit.repoid, + f"{component.component_id}", + commit.commitid, + commit.timestamp, + ) + + +def upsert_measurements( + db_session: Session, measurements: list[dict[str, Any]] +) -> None: + command = insert(Measurement.__table__).values(measurements) + command = command.on_conflict_do_update( + index_elements=[ + Measurement.name, + Measurement.owner_id, + Measurement.repo_id, + Measurement.measurable_id, + Measurement.commit_sha, + Measurement.timestamp, + ], + set_=dict( + branch=command.excluded.branch, + value=command.excluded.value, + ), + ) + db_session.execute(command) + db_session.flush() def repository_commits_query(