diff --git a/custom_components/gtfs2/gtfs_rt_helper.py b/custom_components/gtfs2/gtfs_rt_helper.py index 04caaac..b551dde 100644 --- a/custom_components/gtfs2/gtfs_rt_helper.py +++ b/custom_components/gtfs2/gtfs_rt_helper.py @@ -86,10 +86,11 @@ def get_gtfs_feed_entities(url: str, headers, label: str): except ValueError as e: if label == "vehicle_positions": feed = convert_gtfs_realtime_positions_to_json(response.content) - elif label == "alerts": - feed = convert_gtfs_realtime_alerts_to_json(response.content) - else: + elif label == "trip_data": feed = convert_gtfs_realtime_to_json(response.content) + else: # not yet converted to json + feed.ParseFromString(response.content) + return feed.entity return feed.get('entity') @@ -166,7 +167,7 @@ def get_rt_route_trip_statuses(self): vehicle_positions = get_rt_vehicle_positions(self) feed_entities = get_gtfs_feed_entities( - url=self._trip_update_url, headers=self._headers, label="trip data" + url=self._trip_update_url, headers=self._headers, label="trip_data" ) self._feed_entities = feed_entities _LOGGER.debug("Search departure times for route: %s, trip: %s, type: %s, direction: %s", self._route_id, self._trip_id, self._rt_group, self._direction) @@ -303,6 +304,38 @@ def get_rt_vehicle_positions(self): return geojson_body def get_rt_alerts(self): + rt_alerts = {} + if (self._alerts_url)[:4] == "http": + feed_entities = get_gtfs_feed_entities( + url=self._alerts_url, + headers=self._headers, + label="alerts", + ) + for entity in feed_entities: + if entity.HasField("alert"): + for x in entity.alert.informed_entity: + if x.HasField("stop_id"): + stop_id = x.stop_id + else: + stop_id = "unknown" + if x.HasField("stop_id"): + route_id = x.route_id + else: + route_id = "unknown" + if stop_id == self._stop_id and (route_id == "unknown" or route_id == self._route_id): + _LOGGER.debug("RT Alert for route: %s, stop: %s, alert: %s", route_id, stop_id, entity.alert.header_text) + rt_alerts["origin_stop_alert"] = (str(entity.alert.header_text).split('text: "')[1]).split('"',1)[0].replace(':','').replace('\n','') + if stop_id == self._destination_id and (route_id == "unknown" or route_id == self._route_id): + _LOGGER.debug("RT Alert for route: %s, stop: %s, alert: %s", route_id, stop_id, entity.alert.header_text) + rt_alerts["destination_stop_alert"] = (str(entity.alert.header_text).split('text: "')[1]).split('"',1)[0].replace(':','').replace('\n','') + if stop_id == "unknown" and route_id == self._route_id: + _LOGGER.debug("RT Alert for route: %s, stop: %s, alert: %s", route_id, stop_id, entity.alert.header_text) + rt_alerts["origin_stop_alert"] = (str(entity.alert.header_text).split('text: "')[1]).split('"',1)[0].replace(':','').replace('\n','') + rt_alerts["destination_stop_alert"] = (str(entity.alert.header_text).split('text: "')[1]).split('"',1)[0].replace(':','').replace('\n','') + + return rt_alerts + +def get_rt_alerts_json(self): rt_alerts = {} if (self._alerts_url)[:4] == "http": feed_entities = get_gtfs_feed_entities( @@ -507,6 +540,6 @@ def convert_gtfs_realtime_alerts_to_json(gtfs_realtime_data): "description_text": entity.alert.description_text } } - json_data["entity"].append(entity_dict) + json_data["entity"].append(entity_dict) _LOGGER.debug("Alert entity JSON: %s", json_data["entity"]) return json_data \ No newline at end of file