Fix for long add_gtfs_headways calls (#55)
* wip

* working but still not a huge jump in performance

* ready to go?

* commented out code

* reset state.json on GTFS change

* Update src/gobble.py

Co-authored-by: Devin Matte <[email protected]>

---------

Co-authored-by: Devin Matte <[email protected]>
idreyn and devinmatte authored Jan 19, 2024
1 parent eeb3a2a commit b75f9a1
Showing 4 changed files with 140 additions and 64 deletions.
25 changes: 12 additions & 13 deletions src/event.py
@@ -1,5 +1,5 @@
 import json
-from datetime import date, datetime
+from datetime import datetime
 from typing import Tuple
 import pandas as pd
 from ddtrace import tracer
@@ -72,14 +72,7 @@ def reduce_update_event(update: dict) -> Tuple:
 
 
 @tracer.wrap()
-def process_event(
-    update: dict,
-    current_stop_state: dict,
-    gtfs_service_date: date,
-    scheduled_trips: pd.DataFrame,
-    scheduled_stop_times: pd.DataFrame,
-    stops: pd.DataFrame,
-) -> None:
+def process_event(update, current_stop_state: dict):
     """Process a single event from the MBTA's realtime API."""
     (
         current_status,
@@ -122,7 +115,8 @@ def process_event(
     if is_departure_event:
         stop_id = prev["stop_id"]
 
-        stop_name = get_stop_name(stops, stop_id)
+        gtfs_archive = gtfs.get_current_gtfs_archive()
+        stop_name = get_stop_name(gtfs_archive.stops, stop_id)
         service_date = util.service_date(updated_at)
 
     # store all commuter rail/subway stops, but only some bus stops
@@ -150,7 +144,7 @@ def process_event(
             index=[0],
         )
 
-        event = enrich_event(df, scheduled_trips, scheduled_stop_times)
+        event = enrich_event(df, gtfs_archive)
         disk.write_event(event)
 
         current_stop_state[trip_id] = {
@@ -164,14 +158,19 @@ def process_event(
     disk.write_state(current_stop_state)
 
 
-def enrich_event(df: pd.DataFrame, scheduled_trips: pd.DataFrame, scheduled_stop_times: pd.DataFrame) -> pd.DataFrame:
+def enrich_event(df: pd.DataFrame, gtfs_archive: gtfs.GtfsArchive):
     """
     Given a dataframe with a single event, enrich it with headway information and return a single event dict
     """
     # ensure timestamp is always in local time to match the rest of the data
     df["event_time"] = df["event_time"].dt.tz_convert(util.EASTERN_TIME)
 
-    headway_adjusted_df = gtfs.add_gtfs_headways(df, scheduled_trips, scheduled_stop_times)
+    # get trips and stop times for this route specifically (slow to scan them all)
+    route_id = df["route_id"].iloc[0]
+    scheduled_trips_for_route = gtfs_archive.trips_by_route_id(route_id)
+    scheduled_stop_times_for_route = gtfs_archive.stop_times_by_route_id(route_id)
+
+    headway_adjusted_df = gtfs.add_gtfs_headways(df, scheduled_trips_for_route, scheduled_stop_times_for_route)
     # future warning: returning a series is actually the correct future behavior of to_pydatetime(), can drop the
     # context manager later
     with warnings.catch_warnings():
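For context (not part of the commit): the change above speeds up long add_gtfs_headways calls because enrich_event previously handed the entire schedule to add_gtfs_headways on every event, which then filtered it by route each time. GtfsArchive (added in src/gtfs.py below) groups trips and stop times by route_id once at load time, so each event does a dict lookup instead. A rough sketch of that trade-off, with hypothetical toy data:

import numpy as np
import pandas as pd

# Hypothetical schedule: one million stop_times rows across 200 routes
stop_times = pd.DataFrame(
    {
        "trip_id": np.arange(1_000_000) % 50_000,
        "route_id": (np.arange(1_000_000) % 200).astype(str),
    }
)

# Old shape of the per-event work: a boolean scan over every row
route_66_stop_times = stop_times[stop_times.route_id == "66"]

# New shape: group once at startup, then a constant-time lookup per event
stop_times_by_route_id = {route_id: group for route_id, group in stop_times.groupby("route_id")}
route_66_stop_times = stop_times_by_route_id["66"]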
52 changes: 9 additions & 43 deletions src/gobble.py
@@ -1,8 +1,6 @@
-from datetime import date, datetime
 from ddtrace import tracer
 import json
 import logging
-import pandas as pd
 import requests
 import sseclient
 import threading
@@ -14,7 +12,6 @@
 from logger import set_up_logging
 import gtfs
 import disk
-import util
 
 logging.basicConfig(level=logging.INFO, filename="gobble.log")
 tracer.enabled = CONFIG["DATADOG_TRACE_ENABLED"]
@@ -24,19 +21,18 @@
 
 
 def main():
-    # Download the gtfs bundle before we proceed so we don't have to wait
-    logger.info("Downloading GTFS bundle if necessary...")
-    gtfs_service_date = util.service_date(datetime.now(util.EASTERN_TIME))
+    # Start downloading GTFS bundles immediately
+    gtfs.start_watching_gtfs()
 
     rapid_thread = threading.Thread(
         target=client_thread,
-        args=(gtfs_service_date, ROUTES_RAPID),
+        args=(ROUTES_RAPID,),
         name="rapid_routes",
     )
 
     cr_thread = threading.Thread(
         target=client_thread,
-        args=(gtfs_service_date, ROUTES_CR),
+        args=(ROUTES_CR,),
         name="cr_routes",
     )
 
@@ -48,7 +44,7 @@ def main():
         routes_bus_chunk = list(ROUTES_BUS)[i : i + 10]
         bus_thread = threading.Thread(
             target=client_thread,
-            args=(gtfs_service_date, set(routes_bus_chunk)),
+            args=(set(routes_bus_chunk),),
            name=f"routes_bus_chunk{i}",
        )
        bus_threads.append(bus_thread)
@@ -60,51 +56,21 @@
         bus_thread.join()
 
 
-def client_thread(
-    gtfs_service_date: date,
-    routes_filter: Set[str],
-):
+def client_thread(routes_filter: Set[str]):
     url = f'https://api-v3.mbta.com/vehicles?filter[route]={",".join(routes_filter)}'
     logger.info(f"Connecting to {url}...")
     client = sseclient.SSEClient(requests.get(url, headers=HEADERS, stream=True))
 
     current_stop_state: Dict = disk.read_state()
 
-    scheduled_trips, scheduled_stop_times, stops = gtfs.read_gtfs(gtfs_service_date, routes_filter=routes_filter)
-    process_events(
-        client, current_stop_state, gtfs_service_date, scheduled_trips, scheduled_stop_times, stops, routes_filter
-    )
+    process_events(client, current_stop_state)
 
 
-def process_events(
-    client: sseclient.SSEClient,
-    current_stop_state: dict,
-    gtfs_service_date: date,
-    scheduled_trips: pd.DataFrame,
-    scheduled_stop_times: pd.DataFrame,
-    stops: pd.DataFrame,
-    routes_filter: Set[str],
-):
+def process_events(client: sseclient.SSEClient, current_stop_state: dict):
     for event in client.events():
         try:
             if event.event != "update":
                 continue
 
             update = json.loads(event.data)
-
-            # check for new day
-            updated_at = datetime.fromisoformat(update["attributes"]["updated_at"])
-            service_date = util.service_date(updated_at)
-
-            if gtfs_service_date != service_date:
-                logger.info(
-                    f"New day! Refreshing GTFS bundle from {gtfs_service_date} to {service_date} and clearing state..."
-                )
-                disk.write_state({})
-                scheduled_trips, scheduled_stop_times, stops = gtfs.read_gtfs(gtfs_service_date, routes_filter)
-                gtfs_service_date = service_date
-
-            process_event(update, current_stop_state, gtfs_service_date, scheduled_trips, scheduled_stop_times, stops)
+            process_event(update, current_stop_state)
         except Exception:
             logger.exception("Encountered an exception when processing an event", stack_info=True, exc_info=True)
             continue
97 changes: 89 additions & 8 deletions src/gtfs.py
@@ -3,14 +3,20 @@
 import pathlib
 import shutil
 import urllib.request
+import time
 from urllib.parse import urljoin
+from dataclasses import dataclass
 from ddtrace import tracer
-from typing import List, Tuple, Optional, Set
+from threading import Lock, Thread
+from typing import List, Dict, Optional, Set
 
 from config import CONFIG
 from constants import ALL_ROUTES
 from logger import set_up_logging
 from util import EASTERN_TIME
 
+import util
+import disk
+
 logger = set_up_logging(__name__)
 tracer.enabled = CONFIG["DATADOG_TRACE_ENABLED"]
@@ -29,6 +35,43 @@
 STOP_TIMES_COLS = ["stop_id", "trip_id", "arrival_time", "departure_time", "stop_id", "stop_sequence"]
 
 
+def _group_df_by_column(df: pd.DataFrame, column_name: str) -> Dict[str, pd.DataFrame]:
+    return {key: df_group for key, df_group in df.groupby(column_name)}
+
+
+def _get_empty_df_with_same_columns(df: pd.DataFrame) -> pd.DataFrame:
+    empty_df = df.copy(deep=False)
+    empty_df.drop(empty_df.index, inplace=True)
+    return empty_df
+
+
+@dataclass
+class GtfsArchive:
+    # All trips on all routes
+    trips: pd.DataFrame
+    # All stop times on all trips
+    stop_times: pd.DataFrame
+    # All stops
+    stops: pd.DataFrame
+    # The current service date
+    service_date: datetime.date
+
+    def __post_init__(self):
+        self._trips_empty = _get_empty_df_with_same_columns(self.trips)
+        self._stop_times_empty = _get_empty_df_with_same_columns(self.stop_times)
+        self._trips_by_route_id = _group_df_by_column(self.trips, "route_id")
+        self._stop_times_by_route_id = {}
+        for route_id in self._trips_by_route_id.keys():
+            trip_ids_for_route = self._trips_by_route_id[route_id].trip_id
+            self._stop_times_by_route_id[route_id] = self.stop_times[self.stop_times.trip_id.isin(trip_ids_for_route)]
+
+    def stop_times_by_route_id(self, route_id: str):
+        return self._stop_times_by_route_id.get(route_id, self._stop_times_empty)
+
+    def trips_by_route_id(self, route_id: str):
+        return self._trips_by_route_id.get(route_id, self._trips_empty)
+
+
 @tracer.wrap()
 def _download_gtfs_archives_list() -> pd.DataFrame:
     """Downloads list of GTFS archive urls. This file will get overwritten."""
@@ -101,9 +144,7 @@ def get_services(date: datetime.date, archive_dir: pathlib.Path) -> List[str]:
 
 
 @tracer.wrap()
-def read_gtfs(
-    date: datetime.date, routes_filter: Optional[Set[str]] = None
-) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
+def read_gtfs(date: datetime.date, routes_filter: Optional[Set[str]] = None) -> GtfsArchive:
     """
     Given a date, this function will:
     - Find the appropriate gtfs archive (downloading if necessary)
@@ -112,6 +153,7 @@ def read_gtfs(
     If a route filter is applied, only return trips and stop information relevent to supplied routes. Otherwise, return all services.
     """
     dateint = to_dateint(date)
+    logger.info(f"Reading GTFS archive for {date}")
 
     archive_dir = get_gtfs_archive(dateint)
     services = get_services(date, archive_dir)
@@ -132,22 +174,22 @@ def read_gtfs(
     stop_times.arrival_time = pd.to_timedelta(stop_times.arrival_time)
     stop_times.departure_time = pd.to_timedelta(stop_times.departure_time)
 
-    return trips, stop_times, stops
+    return GtfsArchive(trips=trips, stop_times=stop_times, stops=stops, service_date=date)
 
 
 @tracer.wrap()
-def batch_add_gtfs_headways(events_df: pd.DataFrame, all_trips: pd.DataFrame, all_stops: pd.DataFrame) -> pd.DataFrame:
+def batch_add_gtfs_headways(events_df: pd.DataFrame, trips: pd.DataFrame, stop_times: pd.DataFrame) -> pd.DataFrame:
     """A batch implementation of add_gtfs_headways--this will probably never be used, but we include it just in case."""
     results = []
 
     # we have to do this day-by-day because gtfs changes so often
     for service_date, days_events in events_df.groupby("service_date"):
         # filter out the trips of interest
-        relevant_trips = all_trips[all_trips.route_id.isin(days_events.route_id)]
+        relevant_trips = trips[trips.route_id.isin(days_events.route_id)]
 
         # take only the stops from those trips (adding route and dir info)
         trip_info = relevant_trips[["trip_id", "route_id", "direction_id"]]
-        gtfs_stops = all_stops.merge(trip_info, on="trip_id", how="right")
+        gtfs_stops = stop_times.merge(trip_info, on="trip_id", how="right")
 
         # calculate gtfs headways
         gtfs_stops = gtfs_stops.sort_values(by="arrival_time")
@@ -277,3 +319,42 @@ def add_gtfs_headways(event_df: pd.DataFrame, all_trips: pd.DataFrame, all_stops
     )
 
     return augmented_event
+
+
+current_gtfs_archive = None
+write_gtfs_archive_lock = Lock()
+
+
+def update_current_gtfs_archive_if_necessary():
+    global current_gtfs_archive
+    global write_gtfs_archive_lock
+    with write_gtfs_archive_lock:
+        gtfs_service_date = util.service_date(datetime.datetime.now(util.EASTERN_TIME))
+        needs_update = current_gtfs_archive is None or current_gtfs_archive.service_date != gtfs_service_date
+        if needs_update:
+            if current_gtfs_archive is None:
+                logger.info(f"Downloading GTFS archive for {gtfs_service_date}")
+            else:
+                logger.info(f"Updating GTFS archive from {current_gtfs_archive.service_date} to {gtfs_service_date}")
+            current_gtfs_archive = read_gtfs(gtfs_service_date, routes_filter=ALL_ROUTES)
+            # TODO(ian): This will become a per-trip concern in a future change
+            # See https://transitmatters.slack.com/archives/GSJ6F35DW/p1705680401311829?thread_ts=1705677890.833879&cid=GSJ6F35DW
+            disk.write_state({})
+
+
+def get_current_gtfs_archive():
+    global current_gtfs_archive
+    if current_gtfs_archive is None:
+        update_current_gtfs_archive_if_necessary()
+    return current_gtfs_archive
+
+
+def update_gtfs_thread():
+    while True:
+        update_current_gtfs_archive_if_necessary()
+        time.sleep(60)
+
+
+def start_watching_gtfs():
+    gtfs_thread = Thread(target=update_gtfs_thread, name="update_gtfs")
+    gtfs_thread.start()
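A hedged usage sketch for GtfsArchive as defined above. The toy frames here are hypothetical (real ones come from read_gtfs); the point is the precomputed per-route grouping and the empty-frame fallback for unknown routes:

import datetime
import pandas as pd
from gtfs import GtfsArchive

trips = pd.DataFrame({"trip_id": ["t1", "t2"], "route_id": ["Red", "66"]})
stop_times = pd.DataFrame({"trip_id": ["t1", "t1", "t2"], "stop_id": ["s1", "s2", "s3"], "stop_sequence": [1, 2, 1]})
stops = pd.DataFrame({"stop_id": ["s1", "s2", "s3"], "stop_name": ["Alewife", "Davis", "Nubian"]})

archive = GtfsArchive(trips=trips, stop_times=stop_times, stops=stops, service_date=datetime.date(2024, 1, 19))

red_trips = archive.trips_by_route_id("Red")            # one row: t1
red_stop_times = archive.stop_times_by_route_id("Red")  # two rows: s1, s2
missing = archive.trips_by_route_id("no-such-route")    # empty frame with the same columns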
30 changes: 30 additions & 0 deletions src/timing.py
@@ -0,0 +1,30 @@
+from functools import wraps
+from time import time
+from random import random
+import numpy as np
+
+
+def measure_time(report_frequency: float = 1.0, trail_length=1000):
+    def decorator(fn):
+        exec_times = []
+
+        @wraps(fn)
+        def wrap(*args, **kw):
+            nonlocal exec_times
+            ts = time()
+            result = fn(*args, **kw)
+            te = time()
+            exec_times.append(te - ts)
+            if random() < report_frequency:
+                last = exec_times[-1]
+                exec_times = exec_times[-trail_length:]
+                avg = np.mean(exec_times)
+                std = np.std(exec_times)
+                min = np.min(exec_times)
+                max = np.max(exec_times)
+                print(f"func {fn.__name__}: last={last:.3f}s min={min:.3f} max={max:.3f} avg={avg:.3f}s std={std:.3f}s")
+            return result
+
+        return wrap
+
+    return decorator
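A possible use of the new decorator (the decorated function and its arguments are illustrative): wrap a hot function, and roughly report_frequency of calls will print rolling timing stats over the last trail_length samples.

from timing import measure_time

@measure_time(report_frequency=0.01, trail_length=500)  # print stats on ~1% of calls
def parse_update(raw: str) -> dict:
    return {"length": len(raw)}

for _ in range(10_000):
    parse_update("vehicle update payload")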
