diff --git a/light/data_model/db/base.py b/light/data_model/db/base.py index 5668ec39d..3b401b2ca 100644 --- a/light/data_model/db/base.py +++ b/light/data_model/db/base.py @@ -11,13 +11,22 @@ from enum import Enum from typing import Optional, Union, Dict, Any from dataclasses import dataclass +from tempfile import mkdtemp +import shutil +import os +import json from hydra.core.config_store import ConfigStore +DEFAULT_LOG_PATH = "".join( + [os.path.abspath(os.path.dirname(__file__)), "/../../../logs"] +) + @dataclass class LightDBConfig: backend: str = "test" + file_root: Optional[str] = DEFAULT_LOG_PATH cs = ConfigStore.instance() @@ -61,8 +70,14 @@ def __init__(self, config: "DictConfig"): files and instances. """ # TODO replace with a swappable engine that persists the data + self.backend = config.backend if config.backend == "test": self.engine = create_engine("sqlite+pysqlite:///:memory:", future=True) + self.made_temp_dir = config.file_root is None + if self.made_temp_dir: + self.file_root = mkdtemp() + else: + self.file_root = config.file_root else: raise NotImplementedError() self._complete_init(config) @@ -91,18 +106,38 @@ def _enforce_get_first(self, session, stmt, error_text) -> Any: def write_data_to_file( self, data: Union[str, Dict[str, Any]], filename: str, json_encode: bool = False - ): + ) -> None: """ Write the given data to the provided filename in the correct storage location (local or remote) """ - # Expects data to be a string, unless json_encode is True + if self.backend in ["test", "local"]: + full_path = os.path.join(self.file_root, filename) + os.makedirs(os.path.dirname(full_path), exist_ok=True) + with open(full_path, "w+") as target_file: + if json_encode: + json.dump(data, target_file) + else: + target_file.write(data) + else: + raise NotImplementedError() - def read_data_from_file(self, filename: str, json_encoded: bool = False): + def read_data_from_file( + self, filename: str, json_encoded: bool = False + ) -> Union[str, Dict[str, Any]]: """ Read the data from the given filename from wherever it is currently stored (local or remote) """ + if self.backend in ["test", "local"]: + full_path = os.path.join(self.file_root, filename) + with open(full_path, "r") as target_file: + if json_encoded: + return json.load(target_file) + else: + return target_file.read() + else: + raise NotImplementedError() def open_file(self): try: @@ -110,3 +145,8 @@ def open_file(self): yield file finally: file.close() + + def shutdown(self): + if self.backend == "test": + if self.made_temp_dir: + shutil.rmtree(self.file_root) diff --git a/light/data_model/db/episodes.py b/light/data_model/db/episodes.py index 09f408187..412a31bdb 100644 --- a/light/data_model/db/episodes.py +++ b/light/data_model/db/episodes.py @@ -5,14 +5,22 @@ # LICENSE file in the root directory of this source tree. from light.data_model.db.base import BaseDB, DBStatus, DBSplitType -from light.graph.structured_graph import OOGraph from omegaconf import MISSING, DictConfig -from typing import Optional, Union, Dict, Any -from dataclasses import dataclass -from enum import Enum +from typing import Optional, List, Tuple, Union, Dict, Any, Set, TYPE_CHECKING +from sqlalchemy import insert, select, Enum, Column, Integer, String, Float, ForeignKey +from sqlalchemy.orm import declarative_base, relationship, Session +from light.graph.events.base import GraphEvent +import time +import enum +import os +if TYPE_CHECKING: + from light.graph.structured_graph import OOGraph -class DBGroupName(Enum): +SQLBase = declarative_base() + + +class DBGroupName(enum.Enum): """Edges in the LIGHT Environment DB""" ORIG = "orig" @@ -22,26 +30,130 @@ class DBGroupName(Enum): RELEASE = "full_release" -@dataclass -class DBEpisode: +class EpisodeLogType(enum.Enum): + """Types of episodes in LIGHT""" + + ROOM = "room" + AGENT = "agent" + FULL = "full" + + +class DBEpisode(SQLBase): """Class containing the expected elements for an episode as stored in the db""" - group: DBGroupName - split: DBSplitType - status: DBStatus - actors: List[str] - dump_file: str - after_graph: str - timestamp: float + __tablename__ = "episodes" + + id = Column(Integer, primary_key=True) + group = Column(Enum(DBGroupName), nullable=False, index=True) + split = Column(Enum(DBSplitType), nullable=False, index=True) + status = Column(Enum(DBStatus), nullable=False, index=True) + actors = Column( + String + ) # Comma separated list of actor IDs. Cleared on release data + dump_file_path = Column(String(90), nullable=False) # Path to data + turn_count = Column(Integer, nullable=False) + human_count = Column(Integer, nullable=False) + action_count = Column(Integer, nullable=False) + timestamp = Column(Float, nullable=False) + log_type = Column(Enum(EpisodeLogType), nullable=False) + first_graph_id = Column(ForeignKey("graphs.id")) + final_graph_id = Column(ForeignKey("graphs.id")) + + _cached_map = None + + def get_actors(self) -> List[str]: + """Return the actors in this episode""" + if len(self.actors.strip()) == 0: + return [] + return self.actors.split(",") + + def get_parsed_events( + self, db: "EpisodeDB" + ) -> List[Tuple[str, List["GraphEvent"]]]: + """ + Return all of the actions and turns from this episode, + split by the graph key ID relevant to those actions + """ + # Import deferred as World imports loggers which import the EpisodeDB + from light.world.world import World + + events = db.read_data_from_file(self.dump_file_path, json_encoded=True)[ + "events" + ] + graph_grouped_events: List[Tuple[str, List["GraphEvent"]]] = [] + current_graph_events = None + curr_graph_key = None + curr_graph = None + tmp_world = None + # Extract events to the correct related graphs, initializing the graphs + # as necessary + for event_turn in events: + # See if we've moved onto an event in a new graph + if event_turn["graph_key"] != curr_graph_key: + if current_graph_events is not None: + # There was old state, so lets push it to the list + graph_grouped_events.append((curr_graph_key, current_graph_events)) + # We're on a new graph, have to reset the current graph state + curr_graph_key = event_turn["graph_key"] + current_graph_events: List["GraphEvent"] = [] + curr_graph = self.get_graph(curr_graph_key, db) + tmp_world = World({}, None) + tmp_world.oo_graph = curr_graph + # The current turn is part of the current graph's events, add + current_graph_events.append( + GraphEvent.from_json(event_turn["event_json"], tmp_world) + ) + if current_graph_events is not None: + # Push the last graph's events, which weren't yet added + graph_grouped_events.append((curr_graph_key, current_graph_events)) + return graph_grouped_events def get_before_graph(self, db: "EpisodeDB") -> "OOGraph": """Return the state of the graph before this episode""" + return self.get_graph(self.first_graph_id, db) - def get_parsed_episode(self, db: "EpisodeDB") -> List[Any]: - """Return all of the actions and turns from this episode""" + def get_graph(self, id_or_key: str, db: "EpisodeDB") -> "OOGraph": + """Return a specific graph by id or key""" + return self.get_graph_map()[id_or_key].get_graph(db) def get_after_graph(self, db: "EpisodeDB") -> "OOGraph": """Return the state of the graph after this episode""" + return self.get_graph(self.final_graph_id, db) + + def get_graph_map(self): + """Return a mapping from both graph keys and graph ids to their graph""" + if self._cached_map is None: + key_map = {graph.graph_key_id: graph for graph in self.graphs} + id_map = {graph.id: graph for graph in self.graphs} + key_map.update(id_map) + self._cached_map = key_map + return self._cached_map + + def __repr__(self): + return f"DBEpisode(ids:[{self.id!r}] group/split:[{self.group.value!r}/{self.split.value!r}] File:[{self.dump_file_path!r}])" + + +class DBGraph(SQLBase): + """Class containing expected elements for a stored graph""" + + __tablename__ = "graphs" + + id = Column(Integer, primary_key=True) + episode_id = Column(Integer, ForeignKey("episodes.id"), nullable=False, index=True) + full_path = Column(String(80), nullable=False) + graph_key_id = Column(String(60), nullable=False, index=True) + episode = relationship("DBEpisode", backref="graphs", foreign_keys=[episode_id]) + + def get_graph(self, db: "EpisodeDB") -> "OOGraph": + """Return the initialized graph based on this file""" + from light.graph.structured_graph import OOGraph + + graph_json = db.read_data_from_file(self.full_path) + graph = OOGraph.from_json(graph_json) + return graph + + def __repr__(self): + return f"DBGraph(ids:[{self.id!r},{self.graph_key_id!r}], episode:{self.episode_id!r})" class EpisodeDB(BaseDB): @@ -58,28 +170,99 @@ def _complete_init(self, config: "DictConfig"): Initialize any specific episode-related paths. Populate the list of available splits and datasets. """ - raise NotImplementedError() + SQLBase.metadata.create_all(self.engine) def _validate_init(self): """ Ensure that the episode directory is properly loaded """ - raise NotImplementedError() + # TODO Check the table for any possible consistency issues + # and ensure that the episode directories for listed splits exist - def write_episode(self, args) -> str: + def write_episode( + self, + graphs: List[Dict[str, str]], + events: Tuple[str, List[Dict[str, str]]], + log_type: EpisodeLogType, + action_count: int, + players: Set[str], + group: DBGroupName, + ) -> str: """ Create an entry given the current argument data, store it to file on the database """ + actor_string = ",".join(list(players)) + event_filename = events[0] + event_list = events[1] + # Trim the filename from the left if too long + event_filename = event_filename[-70:] + assert len(event_filename) <= 70 + dump_file_path = os.path.join(group.value, log_type.value, event_filename) + graph_dump_root = os.path.join( + group.value, + log_type.value, + "graphs", + ) + + # File writes + self.write_data_to_file( + {"events": event_list}, dump_file_path, json_encode=True + ) + for graph_info in graphs: + graph_full_path = os.path.join(graph_dump_root, graph_info["filename"]) + self.write_data_to_file(graph_info["graph_json"], graph_full_path) + + # DB Writes + with Session(self.engine) as session: + episode = DBEpisode( + group=group, + split=DBSplitType.UNSET, + status=DBStatus.REVIEW, + actors=actor_string, + dump_file_path=dump_file_path, + turn_count=len(event_list), + human_count=len(players), + action_count=action_count, + timestamp=time.time(), + log_type=log_type, + ) + first_graph = None + for idx, graph_info in enumerate(graphs): + graph_full_path = os.path.join(graph_dump_root, graph_info["filename"]) + db_graph = DBGraph( + graph_key_id=graph_info["key"], + full_path=graph_full_path, + ) + if idx == 0: + first_graph = db_graph + episode.graphs.append(db_graph) + session.add(episode) + session.flush() + episode.first_graph_id = first_graph.id + episode.final_graph_id = db_graph.id + + episode_id = episode.id + session.commit() + + return episode_id def get_episode(self, episode_id: str) -> "DBEpisode": """ Return a specific episode by id, raising an issue if it doesnt exist """ + stmt = select(DBEpisode).where(DBEpisode.id == episode_id) + with Session(self.engine) as session: + episode = self._enforce_get_first(session, stmt, "Episode did not exist") + for graph in episode.graphs: + # Load all the graph keys + assert graph.id is not None + session.expunge_all() + return episode def get_episodes( self, - group: Optional[str] = None, + group: Optional[DBGroupName] = None, split: Optional[DBSplitType] = None, min_turns: Optional[int] = None, min_humans: Optional[int] = None, @@ -88,8 +271,34 @@ def get_episodes( user_id: Optional[str] = None, min_creation_time: Optional[float] = None, max_creation_time: Optional[float] = None, + log_type: Optional[EpisodeLogType] = None, # ... other args ) -> List["DBEpisode"]: """ Return all matching episodes """ + stmt = select(DBEpisode) + if group is not None: + stmt = stmt.where(DBEpisode.group == group) + if split is not None: + stmt = stmt.where(DBEpisode.split == split) + if min_turns is not None: + stmt = stmt.where(DBEpisode.turn_count >= min_turns) + if min_humans is not None: + stmt = stmt.where(DBEpisode.human_count >= min_humans) + if min_actions is not None: + stmt = stmt.where(DBEpisode.action_count >= min_actions) + if status is not None: + stmt = stmt.where(DBEpisode.status == status) + if user_id is not None: + stmt = stmt.where(DBEpisode.actors.contains(user_id)) + if log_type is not None: + stmt = stmt.where(DBEpisode.log_type == log_type) + if min_creation_time is not None: + stmt = stmt.where(DBEpisode.timestamp >= min_creation_time) + if max_creation_time is not None: + stmt = stmt.where(DBEpisode.timestamp <= max_creation_time) + with Session(self.engine) as session: + episodes = session.scalars(stmt).all() + session.expunge_all() + return episodes diff --git a/light/data_model/db/users.py b/light/data_model/db/users.py index cc23dc978..45fc889bf 100644 --- a/light/data_model/db/users.py +++ b/light/data_model/db/users.py @@ -45,7 +45,7 @@ class DBPlayer(SQLBase): scores = relationship("DBScoreEntry") def __repr__(self): - return f"DBPlayer(ids:[{self.id!r},{self.extern_id!r}], preauth:{self.is_preauth!r}, status:{self.account_status!r})" + return f"DBPlayer(ids:[{self.id!r},{self.extern_id!r}], preauth:{self.is_preauth!r}, status:{self.account_status.value!r})" class DBScoreEntry(SQLBase): diff --git a/light/data_model/tests/test_episode_db.py b/light/data_model/tests/test_episode_db.py new file mode 100644 index 000000000..204e683a4 --- /dev/null +++ b/light/data_model/tests/test_episode_db.py @@ -0,0 +1,380 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.abs + +import unittest +import shutil, tempfile +from omegaconf import OmegaConf +import os +import json +import time + +from light.graph.elements.graph_nodes import GraphAgent +from light.graph.structured_graph import OOGraph +from light.world.world import World +from light.graph.events.graph_events import ArriveEvent, LeaveEvent, GoEvent, LookEvent +from light.world.content_loggers import AgentInteractionLogger, RoomInteractionLogger +from light.world.utils.json_utils import read_event_logs +from light.data_model.db.episodes import EpisodeDB, EpisodeLogType +from light.data_model.db.base import LightDBConfig + + +class TestEpisodesDB(unittest.TestCase): + """Unit tests for the EpisodeDB. Leverages Interaction Loggers to generate episodes""" + + def setUp(self): + self.data_dir = tempfile.mkdtemp() + self.config = LightDBConfig(backend="test", file_root=self.data_dir) + + def tearDown(self): + shutil.rmtree(self.data_dir) + + def setUp_single_room_graph(self): + # Set up the graph + opt = {"is_logging": True, "log_path": self.data_dir} + test_graph = OOGraph(opt) + agent_node = test_graph.add_agent("My test agent", {}) + room_node = test_graph.add_room("test room", {}) + agent_node.force_move_to(room_node) + test_world = World({}, None, True) + test_world.oo_graph = test_graph + return (test_graph, test_world, agent_node, room_node) + + def test_initialize_user_db(self): + """Ensure it's possible to initialize the db""" + db = EpisodeDB(self.config) + + def test_simple_room_logger_saves_and_loads_init_graph(self): + """ + Test that the room logger properly saves and reloads the initial + graph + """ + # Set up the graph + pre_time = time.time() + episode_db = EpisodeDB(self.config) + initial = self.setUp_single_room_graph() + test_graph, test_world, agent_node, room_node = initial + room_logger = test_graph.room_id_to_loggers[room_node.node_id] + room_logger.db = episode_db + + # Push a json episode out to the db + test_init_json = test_world.oo_graph.to_json_rv(room_node.node_id) + room_logger._begin_meta_episode() + room_logger._end_meta_episode() + + # Mark the end time to test queries later + episode_id = room_logger._last_episode_logged + post_time = time.time() + + # Ensure an episode was created properly + self.assertIsNotNone(episode_id) + episode = episode_db.get_episode(episode_id) + graph_map = episode.get_graph_map() + self.assertEqual(len(episode.graphs), 2, "Expected an init and final graph") + self.assertIsNotNone(episode.group) + self.assertIsNotNone(episode.split) + self.assertIsNotNone(episode.status) + self.assertEqual( + len(episode.actors), 0, f"No actors expected, found {episode.actors}" + ) + self.assertEqual( + len(episode.get_actors()), + 0, + f"No actors expected, found {episode.get_actors()}", + ) + self.assertEqual( + episode.turn_count, 0, f"No turns excpected, found {episode.turn_count}" + ) + self.assertEqual( + episode.human_count, 0, f"No humans expected, found {episode.human_count}" + ) + self.assertEqual( + episode.action_count, + 0, + f"No actions expected, found {episode.action_count}", + ) + self.assertIn( + episode.first_graph_id, graph_map, f"First graph not present in map" + ) + self.assertIn( + episode.final_graph_id, graph_map, f"Final graph not present in map" + ) + + # Test repr + episode.__repr__() + + # Check graph equivalence + before_graph = episode.get_before_graph(episode_db) + before_graph_json = before_graph.to_json_rv(room_node.node_id) + self.assertEqual(test_init_json, before_graph_json) + + after_graph = episode.get_after_graph(episode_db) + after_graph_json = after_graph.to_json_rv(room_node.node_id) + self.assertEqual(test_init_json, after_graph_json) + + # Check the parsed episode + events = episode.get_parsed_events(episode_db) + self.assertEqual(len(events), 0, f"Expected no events, found {events}") + + # Do some episode queries + episodes = episode_db.get_episodes() + self.assertEqual(len(episodes), 1, f"Expected one episode, found {episodes}") + episodes = episode_db.get_episodes( + min_creation_time=pre_time, max_creation_time=post_time + ) + self.assertEqual(len(episodes), 1, f"Expected one episode, found {episodes}") + episodes = episode_db.get_episodes(max_creation_time=pre_time) + self.assertEqual(len(episodes), 0, f"Expected 0 episodes, found {episodes}") + episodes = episode_db.get_episodes(min_creation_time=post_time) + self.assertEqual(len(episodes), 0, f"Expected 0 episode, found {episodes}") + episodes = episode_db.get_episodes(min_turns=0) + self.assertEqual(len(episodes), 1, f"Expected one episode, found {episodes}") + episodes = episode_db.get_episodes(min_turns=1) + self.assertEqual(len(episodes), 0, f"Expected 0 episodes, found {episodes}") + episodes = episode_db.get_episodes(min_humans=0) + self.assertEqual(len(episodes), 1, f"Expected one episode, found {episodes}") + episodes = episode_db.get_episodes(min_humans=1) + self.assertEqual(len(episodes), 0, f"Expected 0 episodes, found {episodes}") + episodes = episode_db.get_episodes(min_actions=0) + self.assertEqual(len(episodes), 1, f"Expected one episode, found {episodes}") + episodes = episode_db.get_episodes(min_actions=1) + self.assertEqual(len(episodes), 0, f"Expected 0 episodes, found {episodes}") + episodes = episode_db.get_episodes(log_type=EpisodeLogType.ROOM) + self.assertEqual(len(episodes), 1, f"Expected 1 episodes, found {episodes}") + episodes = episode_db.get_episodes(log_type=EpisodeLogType.AGENT) + self.assertEqual(len(episodes), 0, f"Expected 0 episodes, found {episodes}") + + def test_simple_room_logger_saves_and_loads_event(self): + """ + Test that the room logger properly saves and reloads an event + """ + # Set up the graph + episode_db = EpisodeDB(self.config) + initial = self.setUp_single_room_graph() + test_graph, test_world, agent_node, room_node = initial + agent_node.is_player = True + agent_node.user_id = "test" + room2_node = test_graph.add_room("test room2", {}) + room_logger = test_graph.room_id_to_loggers[room_node.node_id] + room_logger.db = episode_db + + # Check an event json was done correctly + test_event = ArriveEvent(agent_node, text_content="") + test_init_json = test_world.oo_graph.to_json_rv(agent_node.get_room().node_id) + room_logger.observe_event(test_event) + test_event2 = LookEvent(agent_node) + room_logger.observe_event(test_event2) + room_logger._end_meta_episode() + + ref_json = test_event2.to_json() + episode_id = room_logger._last_episode_logged + + # Ensure an episode was created properly + self.assertIsNotNone(episode_id) + episode = episode_db.get_episode(episode_id) + events = episode.get_parsed_events(episode_db) + + event_graph = events[0][0] + event_list = events[0][1] + loaded_event = event_list[0] + + # Assert the loaded event is the same as the executed one + self.assertEqual(loaded_event.to_json(), ref_json) + + # Assert that episode queries with users + self.assertEqual(episode.human_count, 1, "Expected one human") + self.assertEqual(episode.get_actors(), ["test"], "Expected one actor") + episodes = episode_db.get_episodes(min_humans=1) + self.assertEqual(len(episodes), 1, f"Expected one episode, found {episodes}") + episodes = episode_db.get_episodes(user_id="test") + self.assertEqual(len(episodes), 1, f"Expected one episode, found {episodes}") + episodes = episode_db.get_episodes(user_id="nonexist") + self.assertEqual(len(episodes), 0, f"Expected 0 episodes, found {episodes}") + + def test_simple_agent_logger_saves_and_loads_init_graph(self): + """ + Test that the agent logger properly saves and reloads the initial + graph + """ + # Set up the graph + episode_db = EpisodeDB(self.config) + initial = self.setUp_single_room_graph() + test_graph, test_world, agent_node, room_node = initial + + # Check the graph json was done correctly from agent's room + test_init_json = test_world.oo_graph.to_json_rv(room_node.node_id) + agent_logger = AgentInteractionLogger(test_graph, agent_node) + agent_logger.db = episode_db + agent_logger._begin_meta_episode() + agent_logger._end_meta_episode() + + # Mark the end time to test queries later + episode_id = agent_logger._last_episode_logged + post_time = time.time() + + # Ensure an episode was created properly + self.assertIsNotNone(episode_id) + episode = episode_db.get_episode(episode_id) + graph_map = episode.get_graph_map() + self.assertEqual(len(episode.graphs), 2, "Expected an init and final graph") + self.assertIsNotNone(episode.group) + self.assertIsNotNone(episode.split) + self.assertIsNotNone(episode.status) + self.assertEqual( + len(episode.actors), 0, f"No actors expected, found {episode.actors}" + ) + self.assertEqual( + len(episode.get_actors()), + 0, + f"No actors expected, found {episode.get_actors()}", + ) + self.assertEqual( + episode.turn_count, 0, f"No turns excpected, found {episode.turn_count}" + ) + self.assertEqual( + episode.human_count, 0, f"No humans expected, found {episode.human_count}" + ) + self.assertEqual( + episode.action_count, + 0, + f"No actions expected, found {episode.action_count}", + ) + self.assertIn( + episode.first_graph_id, graph_map, f"First graph not present in map" + ) + self.assertIn( + episode.final_graph_id, graph_map, f"Final graph not present in map" + ) + + # Test repr + episode.__repr__() + + # Check graph equivalence + before_graph = episode.get_before_graph(episode_db) + before_graph_json = before_graph.to_json_rv(room_node.node_id) + self.assertEqual(test_init_json, before_graph_json) + + after_graph = episode.get_after_graph(episode_db) + after_graph_json = after_graph.to_json_rv(room_node.node_id) + self.assertEqual(test_init_json, after_graph_json) + + # Check the parsed episode + events = episode.get_parsed_events(episode_db) + self.assertEqual(len(events), 0, f"Expected no events, found {events}") + + # Check some episode queries + episodes = episode_db.get_episodes(log_type=EpisodeLogType.AGENT) + self.assertEqual(len(episodes), 1, f"Expected 1 episodes, found {episodes}") + episodes = episode_db.get_episodes(log_type=EpisodeLogType.ROOM) + self.assertEqual(len(episodes), 0, f"Expected 0 episodes, found {episodes}") + + def test_simple_agent_logger_saves_and_loads_event(self): + """ + Test that the agent logger properly saves and reloads an event + """ + # Set up the graph + episode_db = EpisodeDB(self.config) + initial = self.setUp_single_room_graph() + test_graph, test_world, agent_node, room_node = initial + agent_node.is_player = True + agent_node.user_id = "test" + room2_node = test_graph.add_room("test room2", {}) + room_logger = test_graph.room_id_to_loggers[room_node.node_id] + room_logger.db = episode_db + + # Check an event json was done correctly + test_event = ArriveEvent(agent_node, text_content="") + test_init_json = test_world.oo_graph.to_json_rv(agent_node.get_room().node_id) + agent_logger = AgentInteractionLogger(test_graph, agent_node) + agent_logger.db = episode_db + agent_logger._begin_meta_episode() + agent_logger.observe_event(test_event) + test_event2 = LookEvent(agent_node) + agent_logger.observe_event(test_event2) + agent_logger._end_meta_episode() + ref_json = test_event2.to_json() + + episode_id = agent_logger._last_episode_logged + + # Ensure an episode was created properly + self.assertIsNotNone(episode_id) + episode = episode_db.get_episode(episode_id) + events = episode.get_parsed_events(episode_db) + self.assertEqual(len(events), 1, f"Expected 1 graph type, found {events}") + + event_graph = events[0][0] + event_list = events[0][1] + self.assertEqual( + len(event_list), 2, f"Expected 2 logged events, found {event_list}" + ) + loaded_event = event_list[1] + + # Assert the loaded event is the same as the executed one + self.assertEqual(loaded_event.to_json(), ref_json) + + # Assert that episode queries with users + self.assertEqual(episode.human_count, 1, "Expected one human") + self.assertEqual(episode.get_actors(), ["test"], "Expected one actor") + episodes = episode_db.get_episodes(min_humans=1) + self.assertEqual(len(episodes), 1, f"Expected one episode, found {episodes}") + episodes = episode_db.get_episodes(user_id="test") + self.assertEqual(len(episodes), 1, f"Expected one episode, found {episodes}") + episodes = episode_db.get_episodes(user_id="nonexist") + self.assertEqual(len(episodes), 0, f"Expected 0 episodes, found {episodes}") + + def test_simple_room_logger_e2e(self): + """ + Test that the room logger properly saves and reloads the graph and events + """ + # Set up the graph + episode_db = EpisodeDB(self.config) + initial = self.setUp_single_room_graph() + test_graph, test_world, agent_node, room_node = initial + agent_node.is_player = True + agent_node.user_id = "test" + room_node2 = test_graph.add_room("test room2", {}) + room_logger = test_graph.room_id_to_loggers[room_node.node_id] + room_logger.db = episode_db + test_graph.add_paths_between( + room_node, room_node2, "a path to the north", "a path to the south" + ) + test_graph.room_id_to_loggers[room_node.node_id]._add_player() + + # Check the room and event json was done correctly for room_node + event_room_node_observed = LeaveEvent( + agent_node, target_nodes=[room_node2] + ).to_json() + test_init_json = test_world.oo_graph.to_json_rv(room_node.node_id) + + GoEvent(agent_node, target_nodes=[room_node2]).execute(test_world) + + room_logger = test_graph.room_id_to_loggers[room_node.node_id] + + episode_id = room_logger._last_episode_logged + episode = episode_db.get_episode(episode_id) + events = episode.get_parsed_events(episode_db) + self.assertEqual(len(events), 1, f"Expected 1 graph type, found {events}") + + event_graph = events[0][0] + + event_list = events[0][1] + self.assertEqual( + len(event_list), 2, f"Expected 2 logged events, found {event_list}" + ) + loaded_event = event_list[1] + + ref_json = json.loads(event_room_node_observed) + event_ref = json.loads(loaded_event.to_json()) + for k in ref_json: + if k == "event_id": + continue + elif k == "target_nodes": + self.assertEqual(ref_json[k][0]["names"], event_ref[k][0]["names"]) + else: + self.assertEqual( + ref_json[k], + event_ref[k], + f"Event Json should match for LeaveEvent, misses on {k}", + ) diff --git a/light/data_model/tests/test_user_db.py b/light/data_model/tests/test_user_db.py index 4c3bb91be..4d1372e86 100644 --- a/light/data_model/tests/test_user_db.py +++ b/light/data_model/tests/test_user_db.py @@ -10,7 +10,7 @@ from light.data_model.db.base import LightDBConfig from light.data_model.db.users import UserDB, PlayerStatus -config = LightDBConfig(backend="test") +config = LightDBConfig(backend="test", file_root="unused") class TestUserDB(unittest.TestCase): diff --git a/light/graph/elements/graph_nodes.py b/light/graph/elements/graph_nodes.py index 66ddf8360..9262302b3 100644 --- a/light/graph/elements/graph_nodes.py +++ b/light/graph/elements/graph_nodes.py @@ -589,6 +589,7 @@ def __init__(self, node_id, name, props=None, db_id=None): # Flag to resolve when a death event is in the stack, but possibly not processed self._dying = False self.is_player = self._props.get("is_player", self._props.get("_human", False)) + self.user_id = self._props.get("user_id", None) self.usually_npc = self._props.get("usually_npc", False) self.pacifist = self._props.get("pacifist", False) self.tags = self._props.get("tags", self.DEFAULT_TAGS) diff --git a/light/graph/events/base.py b/light/graph/events/base.py index 041ddebf2..dd9aa8d3a 100644 --- a/light/graph/events/base.py +++ b/light/graph/events/base.py @@ -88,7 +88,7 @@ def __init__( """ if event_id is None: event_id = str(uuid4()) - self.executed: bool = False # type: ignore + self.executed: bool = False self.actor = actor self.room = actor.get_room() self.target_nodes = [] if target_nodes is None else target_nodes diff --git a/light/graph/events/graph_events.py b/light/graph/events/graph_events.py index 3fb8e3e1a..e20038e8e 100644 --- a/light/graph/events/graph_events.py +++ b/light/graph/events/graph_events.py @@ -663,9 +663,9 @@ def execute(self, world: "World") -> List[GraphEvent]: health = self.actor.health eps = self.actor.movement_energy_cost if health > eps: - health_text = world.health(self.actor.node_id) + health_text = world.view.get_health_text_for(self.actor.node_id) self.actor.health = max(0, health - eps) - new_health_text = world.health(self.actor.node_id) + new_health_text = world.view.get_health_text_for(self.actor.node_id) if health_text != new_health_text: HealthEvent(self.actor, text_content="HealthOnMoveEvent").execute(world) @@ -2436,7 +2436,7 @@ def execute(self, world: "World") -> List[GraphEvent]: ), f"Can only equip GraphObjects, not {equip_target}" # The current children of EquipObjectEvent have ONLY one name. # Joining for any future possibility that may have more than one. - equip_target.equipped = ','.join(self.NAMES) + equip_target.equipped = ",".join(self.NAMES) for n, s in equip_target.get_prop("stats", {"defense": 1}).items(): self.actor.set_prop(n, self.actor.get_prop(n) + s) if equip_target.wieldable: @@ -2724,9 +2724,9 @@ def execute(self, world: "World") -> List[GraphEvent]: world.broadcast_to_room(self) - health_text = world.health(self.actor.node_id) + health_text = world.view.get_health_text_for(self.actor.node_id) self.actor.health = max(self.actor.health + fe, 0) - new_health_text = world.health(self.actor.node_id) + new_health_text = world.view.get_health_text_for(self.actor.node_id) if self.actor.health <= 0: DeathEvent(self.actor).execute(world) elif health_text != new_health_text: @@ -3599,7 +3599,7 @@ def execute(self, world: "World") -> List[GraphEvent]: """ assert not self.executed self.__actor_name = self.actor.get_prefix_view() - self.__health_text = world.health(self.actor.node_id) + self.__health_text = world.view.get_health_text_for(self.actor.node_id) to_agents = [self.actor] for t in self.target_nodes: to_agents.append(t) diff --git a/light/world/content_loggers.py b/light/world/content_loggers.py index f0021bcd5..83bf4b797 100644 --- a/light/world/content_loggers.py +++ b/light/world/content_loggers.py @@ -9,6 +9,7 @@ import os import time import uuid +from light.data_model.db.episodes import DBGroupName, EpisodeLogType # TODO: Investigate changing the format from 3 line to csv or some other standard from light.graph.events.graph_events import ( @@ -16,9 +17,18 @@ DeathEvent, LeaveEvent, SoulSpawnEvent, + SayEvent, + TellEvent, + ShoutEvent, + WhisperEvent, ) -DEFAULT_LOG_PATH = "".join([os.path.abspath(os.path.dirname(__file__)), "/../../logs"]) +from typing import Optional, List, Set, Dict, Tuple, TYPE_CHECKING + +if TYPE_CHECKING: + from light.data_model.db.episodes import EpisodeDB + from light.graph.structured_graph import OOGraph + from light.graph.elements.graph_nodes import GraphAgent class InteractionLogger(abc.ABC): @@ -27,25 +37,28 @@ class InteractionLogger(abc.ABC): location to write data, as well as defines some methods for interfacing """ - def __init__(self, graph, data_path): - self.data_path = data_path + def __init__(self, graph: "OOGraph", db: Optional["EpisodeDB"]): + self.db = db self.graph = graph + self.players: Set[str] = set() + self.actions: int = 0 + self._last_episode_logged: Optional[int] = None # All loggers should have graph state history and a buffer for events # State history is just the json of the graph the event executed on - self.state_history = [] - # Event buffer is (state_history_idx, event_hash, timestamp, event_json) + self.state_history: List[str] = [] + # Event buffer is (state_history_idx, event_hash, event_json, timestamp) # where state_history_idx is the index of the graph the event executed on - self.event_buffer = [] + self.event_buffer: List[Tuple[int, str, str, float]] = [] - def _begin_meta_episode(self): + def _begin_meta_episode(self) -> None: """ Handles any preprocessing associated with beginning a meta episode such as clearing buffers and recording initial state """ raise NotImplementedError - def _end_meta_episode(self): + def _end_meta_episode(self) -> None: """ Handles any postprocessing associated with the end of a meta episode such as flushing buffers by writing to data location, and updating variables @@ -53,71 +66,66 @@ def _end_meta_episode(self): self._log_interactions() raise NotImplementedError - def _log_interactions(self): - """ - Writes out the buffers to the location specified by data location, - handling any data specific formatting - """ - raise NotImplementedError - - def observe_event(self, event): + def observe_event(self, event) -> None: """ Examine event passed in, deciding how to save it to the logs """ raise NotImplementedError - def _dump_graphs(self): + def _prep_graphs(self) -> List[Dict[str, str]]: """ - This method is responsible for dumping the graphs of the event logger - to file, recording the identifiers used for the graphs + This method is responsible for preparing the graphs for this event logger """ - # First, check graph path, then write the graph dump - if not os.path.exists(self.data_path): - os.mkdir(self.data_path) - graph_path = os.path.join(self.data_path, "light_graph_dumps") - if not os.path.exists(graph_path): - os.mkdir(graph_path) - states = [] - for state in self.state_history: - unique_graph_name = str(uuid.uuid4()) - states.append(unique_graph_name) + for idx, state in enumerate(self.state_history): + rand_id = str(uuid.uuid4())[:8] + unique_graph_name = f"{time.time():.0f}-{idx}-{rand_id}" graph_file_name = f"{unique_graph_name}.json" - file_path = os.path.join(graph_path, graph_file_name) - with open(file_path, "w") as dump_file: - dump_file.write(state) + states.append( + { + "key": unique_graph_name, + "filename": graph_file_name, + "graph_json": state, + } + ) return states - def _dump_events(self, graph_states, pov, id_): + def _prep_events( + self, + graph_states: List[Dict[str, str]], + target_id: str, + ) -> Tuple[str, List[Dict[str, str]]]: """ This method is responsible for dumping the event logs, referencing the - graph files recorded in graph_states. An event log consist of events, where - an event consist of 3 lines: - serialized_graph_filename event_hash - timestamp - event_json - Event logs are named: {id}_{unique_identifier}.log - and are stored in the `pov/` directory - + graph files recorded in graph_states. """ - # Now, do the same for events, dumping in the light_event_dumps/rooms - events_path = os.path.join(self.data_path, "light_event_dumps") - if not os.path.exists(events_path): - os.mkdir(events_path) - events_path_dir = os.path.join(events_path, pov) - if not os.path.exists(events_path_dir): - os.mkdir(events_path_dir) - - unique_event_name = str(uuid.uuid4()) - id_name = f"{id_}".replace(" ", "_") - event_file_name = f"{id_name}_{unique_event_name}_events.log" - events_file_path = os.path.join(events_path_dir, event_file_name) - with open(events_file_path, "w") as dump_file: - for (idx, hashed, event, time_) in self.event_buffer: - dump_file.write("".join([graph_states[idx], " ", str(hashed), "\n"])) - dump_file.write("".join([time_, "\n"])) - dump_file.write("".join([event, "\n"])) - return events_file_path + unique_event_name = str(uuid.uuid4())[:8] + id_name = f"{target_id}".replace(" ", "_")[:20] + event_file_name = f"{id_name}_{time.time():.0f}_{unique_event_name}_events.json" + events = [] + for (graph_idx, hashed, event, timestamp) in self.event_buffer: + events.append( + { + "graph_key": graph_states[graph_idx]["key"], + "hash": hashed, + "event_json": event, + } + ) + return (event_file_name, events) + + def _log_interactions(self, episode_type: "EpisodeLogType", target_id: str) -> None: + if self.db is None: + return # not actually logging + graphs = self._prep_graphs() + events = self._prep_events(graphs, target_id) + self._last_episode_logged = self.db.write_episode( + graphs=graphs, + events=events, + log_type=episode_type, + action_count=self.actions, + players=self.players, + group=DBGroupName.PRE_LAUNCH, # TODO make configurable? + ) class AgentInteractionLogger(InteractionLogger): @@ -125,56 +133,41 @@ class AgentInteractionLogger(InteractionLogger): This interaction logger attaches to human agents in the graph, logging all events the human observes. This logger also requires serializing more rooms, since agent encounters many rooms along its traversal These events go into - the conversation buffer, which is then sent to `.log` files - at the specified path - - context_buffers serve an important role in this class to avoid bloating the - event logs. Context_buffers will log a fixed number of the most recent events - when: - - 1. The player goes afk. This has the potential to avoid logging lots of noise - in the room that does not provide any signal on human player interactions. - When the player comes back to the game, our loggers send some context of - the most recent events to the log + the conversation buffer, which is then stored in the provided EpisodeDB """ def __init__( self, - graph, - agent, - data_path=DEFAULT_LOG_PATH, - is_active=False, - max_context_history=5, - afk_turn_tolerance=25, + graph: "OOGraph", + agent: "GraphAgent", + db: Optional["EpisodeDB"] = None, + is_active: bool = False, + afk_turn_tolerance: int = 30, ): - super().__init__(graph, data_path) + super().__init__(graph, db) self.agent = agent - self.max_context_history = max_context_history self.afk_turn_tolerance = afk_turn_tolerance if graph._opt is None: self.is_active = is_active else: - self.data_path = graph._opt.get("log_path", DEFAULT_LOG_PATH) self.is_active = graph._opt.get("is_logging", False) - self.turns_wo_player_action = ( - 0 # Player is acting by virtue of this initialized! - ) - self.context_buffer = collections.deque(maxlen=max_context_history) + self.turns_wo_player_action = 0 self._logging_intialized = False - def _begin_meta_episode(self): + def _begin_meta_episode(self) -> None: self._clear_buffers() self._add_current_graph_state() self.turns_wo_player_action = 0 + self.actions = 0 self._logging_intialized = True - def _clear_buffers(self): + def _clear_buffers(self) -> None: """Clear the buffers storage for this logger, dumping context""" self.state_history.clear() self.event_buffer.clear() - def _add_current_graph_state(self): + def _add_current_graph_state(self) -> None: """Make a copy of the graph state so we can replay events on top of it""" try: self.state_history.append( @@ -187,63 +180,58 @@ def _add_current_graph_state(self): traceback.print_exc() raise - def _is_player_afk(self): + def _is_player_afk(self) -> bool: return self.turns_wo_player_action >= self.afk_turn_tolerance - def _end_meta_episode(self): + def _end_meta_episode(self) -> None: self._logging_intialized = False - self._log_interactions() - - def _log_interactions(self): - - graph_states = self._dump_graphs() - self._last_graphs = graph_states - events_file_path = self._dump_events(graph_states, "agent", self.agent.node_id) - # Used for testing - self._last_event_log = events_file_path + self._add_current_graph_state() + self._log_interactions(EpisodeLogType.AGENT, self.agent.node_id) - def observe_event(self, event): + def observe_event(self, event) -> None: if not self.is_active: return event_t = type(event) if event_t is SoulSpawnEvent and not self._logging_intialized: self._begin_meta_episode() + elif self._is_player_afk(): + if event.actor is self.agent and not self._logging_intialized: + self._begin_meta_episode() + return # Did not have prior graph state, can't log this event + else: + return # skip events while AFK - # Get new room state + # Get new room state when moving if event_t is ArriveEvent and event.actor is self.agent: # NOTE: If this is before executing event, not reliable! self._add_current_graph_state() + elif event_t not in [TellEvent, SayEvent, ShoutEvent, WhisperEvent]: + self.actions += 1 - # Store context from bots, or store current events - if self._is_player_afk() and event.actor is not self.agent: - self.context_buffer.append( - ( - len(self.state_history) - 1, - event.__hash__(), - event.to_json(), - time.ctime(), - ) - ) + # Keep track of presence + if event.actor is self.agent: + self.turns_wo_player_action = 0 else: - if event.actor is self.agent: - if self._is_player_afk(): - self.event_buffer.extend(self.context_buffer) - self.context_buffer.clear() - self.turns_wo_player_action = 0 - else: - self.turns_wo_player_action += 1 - self.event_buffer.append( - ( - len(self.state_history) - 1, - event.__hash__(), - event.to_json(), - time.ctime(), - ) + self.turns_wo_player_action += 1 + + if event.actor.is_player: + user_id = event.actor.user_id + if user_id is not None and user_id not in self.players: + self.players.add(event.actor.user_id) + + # Append the particular event + self.event_buffer.append( + ( + len(self.state_history) - 1, + event.__hash__(), + event.to_json(), + time.time(), ) + ) - if ( - event_t is DeathEvent and event.actor is self.agent - ): # If agent is exiting or dieing or something, end meta episode + if (event_t is DeathEvent and event.actor is self.agent) or ( + self._is_player_afk() + ): # If agent is exiting or dying or afk, end meta episode self._end_meta_episode() @@ -252,45 +240,27 @@ class RoomInteractionLogger(InteractionLogger): This interaction logger attaches to a room level node in the graph, logging all events which take place with human agents in the room as long as a player is still in the room. These events go into the conversation buffer, which is - then sent to `.log` files at the specified path - - - context_buffers serve an important role in this class to avoid bloating the - event logs. context_buffers will log a fixed number of the most recent events - when: - - 1. There are no players in the room. This is a potential use case when an agent - enters a conversation between 2 or more models, and we want some context for - training purposes - - 2. All players go afk. This has the potential to avoid logging lots of noise - in the room that does not provide any signal on human player interactions. - When players come back to the game, our loggers send context of the most - recent events to the log + then logged in the provided EpisodeDB """ def __init__( self, - graph, - room_id, - data_path=DEFAULT_LOG_PATH, - is_active=False, - max_context_history=5, - afk_turn_tolerance=10, + graph: "OOGraph", + room_id: str, + db: Optional["EpisodeDB"] = None, + is_active: bool = False, + afk_turn_tolerance: int = 30, ): - super().__init__(graph, data_path) - self.room_id = room_id - self.max_context_history = max_context_history + super().__init__(graph, db) + self.room_id: str = room_id self.afk_turn_tolerance = afk_turn_tolerance if graph._opt is None: self.is_active = is_active else: - self.data_path = graph._opt.get("log_path", DEFAULT_LOG_PATH) self.is_active = graph._opt.get("is_logging", False) self.num_players_present = 0 self.turns_wo_players = float("inf") # Technically, we have never had players - self.context_buffer = collections.deque(maxlen=max_context_history) # Initialize player count here (bc sometimes players are force moved) for node_id in self.graph.all_nodes[self.room_id].contained_nodes: @@ -299,19 +269,18 @@ def __init__( ): self._add_player() - def _begin_meta_episode(self): + def _begin_meta_episode(self) -> None: self._clear_buffers() self._add_current_graph_state() self.turns_wo_players = 0 + self.actions = 0 - def _clear_buffers(self): - """Clear the buffers storage for this logger, dumping context""" + def _clear_buffers(self) -> None: + """Clear the buffers storage for this logger""" self.state_history.clear() self.event_buffer.clear() - self.event_buffer.extend(self.context_buffer) - self.context_buffer.clear() - def _add_current_graph_state(self): + def _add_current_graph_state(self) -> None: """Make a copy of the graph state so we can replay events on top of it""" try: self.state_history.append(self.graph.to_json_rv(self.room_id)) @@ -322,24 +291,17 @@ def _add_current_graph_state(self): traceback.print_exc() raise - def _is_logging(self): + def _is_logging(self) -> bool: return self.num_players_present > 0 - def _is_players_afk(self): + def _is_players_afk(self) -> bool: return self.turns_wo_players >= self.afk_turn_tolerance - def _end_meta_episode(self): - self._log_interactions() - self.context_buffer.clear() - - def _log_interactions(self): - graph_states = self._dump_graphs() - self._last_graphs = graph_states - events_file_path = self._dump_events(graph_states, "room", self.room_id) - # Used for testing - self._last_event_log = events_file_path + def _end_meta_episode(self) -> None: + self._add_current_graph_state() + self._log_interactions(EpisodeLogType.ROOM, self.room_id) - def _add_player(self): + def _add_player(self) -> None: """ Record that a player entered the room, updating variables as needed""" if not self.is_active: return @@ -347,7 +309,7 @@ def _add_player(self): self._begin_meta_episode() self.num_players_present += 1 - def _remove_player(self): + def _remove_player(self) -> None: """ Record that a player left the room, updating variables as needed""" if not self.is_active: return @@ -356,7 +318,7 @@ def _remove_player(self): if not self._is_logging(): self._end_meta_episode() - def observe_event(self, event): + def observe_event(self, event) -> None: if not self.is_active: return @@ -365,45 +327,46 @@ def observe_event(self, event): if ( event_t is ArriveEvent or event_t is SoulSpawnEvent ) and self.human_controlled(event): + if not self._is_logging(): + self._add_player() + return # Add and return to start logging self._add_player() - # Store context from bots, or store current events - if not self._is_logging() or ( - self._is_players_afk() and not self.human_controlled(event) - ): - self.context_buffer.append( - ( - len(self.state_history) - 1, - event.__hash__(), - event.to_json(), - time.ctime(), - ) - ) - else: - if self.human_controlled(event): - # Players are back from AFK, dump context - if self._is_players_afk(): - # TODO: Need to handle something related to graph state here(?) - self.event_buffer.extend(self.context_buffer) - self.context_buffer.clear() - self.turns_wo_players = 0 + if self._is_players_afk() or not self._is_logging(): + if not self.human_controlled(event): + return # Skip these events else: - self.turns_wo_players += 1 - self.event_buffer.append( - ( - len(self.state_history) - 1, - event.__hash__(), - event.to_json(), - time.ctime(), - ) + self._begin_meta_episode() + return # Don't have previous context, will start on the next one + + if event_t not in [TellEvent, SayEvent, ShoutEvent, WhisperEvent]: + self.actions += 1 + + # Keep track of human events + if self.human_controlled(event): + user_id = event.actor.user_id + if user_id is not None and user_id not in self.players: + self.players.add(event.actor.user_id) + self.turns_wo_players = 0 + else: + self.turns_wo_players += 1 + + # Add to buffer + self.event_buffer.append( + ( + len(self.state_history) - 1, + event.__hash__(), + event.to_json(), + time.time(), ) + ) - if (event_t is LeaveEvent or event_t is DeathEvent) and self.human_controlled( - event - ): + if (event_t in [LeaveEvent, DeathEvent]) and self.human_controlled(event): self._remove_player() + if self._is_players_afk(): + self._end_meta_episode() - def human_controlled(self, event): + def human_controlled(self, event) -> bool: """ Determines if an event is controlled by a human or not """ diff --git a/light/world/tests/test_loggers.py b/light/world/tests/test_loggers.py index a8eab00d3..17afef539 100644 --- a/light/world/tests/test_loggers.py +++ b/light/world/tests/test_loggers.py @@ -62,7 +62,6 @@ def test_init_room_logger(self): test_graph, test_world, agent_node, room_node = initial logger = test_graph.room_id_to_loggers[room_node.node_id] - self.assertEqual(logger.data_path, self.data_dir) self.assertEqual(logger.graph, test_graph) self.assertEqual(logger.state_history, []) self.assertEqual(logger.event_buffer, []) @@ -70,7 +69,6 @@ def test_init_room_logger(self): self.assertFalse(logger._is_logging()) self.assertTrue(logger._is_players_afk()) self.assertTrue(logger.is_active) - self.assertEqual(len(logger.context_buffer), 0) def test_init_agent_logger(self): """ @@ -80,7 +78,6 @@ def test_init_agent_logger(self): test_graph, test_world, agent_node, room_node = initial logger = AgentInteractionLogger(test_graph, agent_node) - self.assertEqual(logger.data_path, self.data_dir) self.assertEqual(logger.graph, test_graph) self.assertEqual(logger.state_history, []) self.assertEqual(logger.event_buffer, []) @@ -88,12 +85,11 @@ def test_init_agent_logger(self): self.assertFalse(logger._logging_intialized) self.assertFalse(logger._is_player_afk()) self.assertTrue(logger.is_active) - self.assertEqual(len(logger.context_buffer), 0) def test_begin_meta_episode_room_logger(self): """ Test calling begin_meta_episode: - - Clears all the buffers (context into event if nonempty) + - Clears all the buffers - adds the graph state from the room POV - counts as a turn of player action - initializes logger @@ -104,13 +100,11 @@ def test_begin_meta_episode_room_logger(self): logger = test_graph.room_id_to_loggers[room_node.node_id] logger.event_buffer.append("Testing NOT!") logger.state_history.append("Testing") - logger.context_buffer.append("Testing") logger._begin_meta_episode() self.assertFalse(logger._is_players_afk()) - self.assertEqual(len(logger.context_buffer), 0) - self.assertEqual(logger.event_buffer, ["Testing"]) - self.assertEqual(len(logger.state_history), 1) + self.assertEqual(len(logger.state_history), 1, "Had extra in buffer") + self.assertEqual(len(logger.event_buffer), 0, "Had extra in buffer") self.assertEqual( logger.state_history[-1], test_graph.to_json_rv(logger.room_id) ) @@ -133,7 +127,6 @@ def test_begin_meta_episode_agent_logger(self): self.assertFalse(logger._is_player_afk()) self.assertTrue(logger._logging_intialized) - self.assertEqual(len(logger.context_buffer), 0) self.assertEqual(len(logger.event_buffer), 0) self.assertEqual(len(logger.state_history), 1) self.assertEqual( @@ -144,18 +137,14 @@ def test_begin_meta_episode_agent_logger(self): def test_end_meta_episode_room_logger(self): """ Test calling end_meta_episode: - - Clears the context buffer ** Note, future test check that things are written properly """ initial = self.setUp_single_room_graph() test_graph, test_world, agent_node, room_node = initial logger = test_graph.room_id_to_loggers[room_node.node_id] - logger.context_buffer.append("Testing") logger._end_meta_episode() - self.assertEqual(len(logger.context_buffer), 0) - def test_end_meta_episode_agent_logger(self): """ Test calling end_meta_episode: @@ -183,13 +172,11 @@ def test_add_player_room_logger(self): logger.event_buffer.append("Testing NOT!") logger.state_history.append("Testing") - logger.context_buffer.append("Testing") logger._add_player() self.assertTrue(logger._is_logging()) self.assertFalse(logger._is_players_afk()) - self.assertEqual(len(logger.context_buffer), 0) - self.assertEqual(logger.event_buffer, ["Testing"]) + self.assertEqual(len(logger.event_buffer), 0) self.assertEqual(len(logger.state_history), 1) self.assertEqual( logger.state_history[-1], test_graph.to_json_rv(logger.room_id) @@ -198,8 +185,7 @@ def test_add_player_room_logger(self): # Another player just ups the count logger._add_player() - self.assertEqual(len(logger.context_buffer), 0) - self.assertEqual(logger.event_buffer, ["Testing"]) + self.assertEqual(len(logger.event_buffer), 0) self.assertEqual(len(logger.state_history), 1) self.assertEqual( logger.state_history[-1], test_graph.to_json_rv(logger.room_id) @@ -226,7 +212,6 @@ def test_remove_player_room_logger(self): # Another player is 0, end episode logger._remove_player() self.assertFalse(logger._is_logging()) - self.assertEqual(len(logger.context_buffer), 0) self.assertEqual(logger.num_players_present, 0) def test_observer_event_goes_context_room_logger(self): @@ -245,12 +230,10 @@ def test_observer_event_goes_context_room_logger(self): test_event5 = ArriveEvent(agent_node, text_content="hello5") test_event6 = ArriveEvent(agent_node, text_content="hello6") - # No player in room, so this should go to context + # No player in room, so this should be skipped logger.observe_event(test_event1) self.assertFalse(logger._is_logging()) self.assertEqual(len(logger.event_buffer), 0) - self.assertEqual(len(logger.context_buffer), 1) - logger.observe_event(test_event2) logger.observe_event(test_event3) logger.observe_event(test_event4) @@ -258,15 +241,6 @@ def test_observer_event_goes_context_room_logger(self): logger.observe_event(test_event6) self.assertFalse(logger._is_logging()) self.assertEqual(len(logger.event_buffer), 0) - self.assertEqual(len(logger.context_buffer), 5) - events = [json for _, _, json, _ in logger.context_buffer] - self.assertFalse(test_event1.to_json() in events) - - # player added, should be in event buffer - logger._add_player() - self.assertTrue(logger._is_logging()) - self.assertEqual(len(logger.event_buffer), 5) - self.assertEqual(len(logger.context_buffer), 0) def test_observe_event_room_logger(self): """ @@ -284,9 +258,6 @@ def test_observe_event_room_logger(self): self.assertTrue(logger._is_logging()) self.assertEqual(len(logger.event_buffer), 1) - self.assertEqual(len(logger.context_buffer), 0) - events = [json for _, _, json, _ in logger.event_buffer] - self.assertTrue(test_event1.to_json() in events) def test_observe_event_agent_logger(self): """ @@ -302,13 +273,10 @@ def test_observe_event_agent_logger(self): logger.observe_event(test_event1) self.assertEqual(len(logger.event_buffer), 1) - self.assertEqual(len(logger.context_buffer), 0) - events = [json for _, _, json, _ in logger.event_buffer] - self.assertTrue(test_event1.to_json() in events) def test_afk_observe_event_room_logger(self): """ - Test that after 10 turns with no player, fill buffer, then dumps into main! + Test that after 30 turns with no player, clear when returns! """ initial = self.setUp_single_room_graph() test_graph, test_world, agent_node, room_node = initial @@ -316,24 +284,22 @@ def test_afk_observe_event_room_logger(self): logger._add_player() test_event1 = ArriveEvent(agent_node, text_content="hello1") - for i in range(20): + for i in range(30): logger.observe_event(test_event1) # Only up to 5 in buffer, that is the limit self.assertTrue(logger._is_players_afk()) - self.assertEqual(len(logger.event_buffer), 10) - self.assertEqual(len(logger.context_buffer), 5) + self.assertEqual(len(logger.event_buffer), 30) - # Now, player event - dump to buffer! + # Now, player event - clear buffer! agent_node.is_player = True logger.observe_event(test_event1) self.assertFalse(logger._is_players_afk()) - self.assertEqual(len(logger.event_buffer), 16) - self.assertEqual(len(logger.context_buffer), 0) + self.assertEqual(len(logger.event_buffer), 0) def test_afk_observe_event_agent_logger(self): """ - Test that after 25 turns with no player, fill buffer, then dumps into main! + Test that after 30 turns with no player, clear, then start new episode! """ initial = self.setUp_single_room_graph() test_graph, test_world, agent_node, room_node = initial @@ -347,165 +313,13 @@ def test_afk_observe_event_agent_logger(self): logger.observe_event(test_event1) self.assertTrue(logger._is_player_afk()) - self.assertEqual(len(logger.event_buffer), 25) - self.assertEqual(len(logger.context_buffer), 5) + self.assertEqual(len(logger.event_buffer), 30) test_event2 = ArriveEvent(agent_node, text_content="hello2") logger.observe_event(test_event2) + logger.observe_event(test_event2) self.assertFalse(logger._is_player_afk()) - self.assertEqual(len(logger.event_buffer), 31) - self.assertEqual(len(logger.context_buffer), 0) - - def test_simple_room_logger_saves_and_loads_init_graph(self): - """ - Test that the room logger properly saves and reloads the initial - graph - """ - # Set up the graph - initial = self.setUp_single_room_graph() - test_graph, test_world, agent_node, room_node = initial - room_logger = test_graph.room_id_to_loggers[room_node.node_id] - - # Check the room json was done correctly - test_init_json = test_world.oo_graph.to_json_rv(room_node.node_id) - room_logger._begin_meta_episode() - room_logger._end_meta_episode() - graph_file = os.path.join( - self.data_dir, "light_graph_dumps", f"{room_logger._last_graphs[-1]}.json" - ) - with open(graph_file, "r") as graph_json_file: - written_init_json = graph_json_file.read() - self.assertEqual(test_init_json, written_init_json) - - def test_simple_room_logger_saves_and_loads_event(self): - """ - Test that the room logger properly saves and reloads an event - """ - # Set up the graph - initial = self.setUp_single_room_graph() - test_graph, test_world, agent_node, room_node = initial - agent_node.is_player = True - room2_node = test_graph.add_room("test room2", {}) - room_logger = test_graph.room_id_to_loggers[room_node.node_id] - - # Check an event json was done correctly - test_event = ArriveEvent(agent_node, text_content="") - test_init_json = test_world.oo_graph.to_json_rv(agent_node.get_room().node_id) - room_logger.observe_event(test_event) - room_logger._end_meta_episode() - - ref_json = test_event.to_json() - event_file = room_logger._last_event_log - self.assertNotEqual(os.stat(event_file).st_size, 0) - buff = read_event_logs(event_file) - assert len(buff) == 1 - - world_name, hash_, timestamp, written_event = buff[0] - self.assertEqual(world_name, room_logger._last_graphs[-1]) - self.assertEqual(hash_, str(test_event.__hash__())) - ref_json = json.loads(ref_json) - event_ref = json.loads(written_event) - self.assertEqual(event_ref, ref_json) - - def test_simple_agent_logger_saves_and_loads_init_graph(self): - """ - Test that the agent logger properly saves and reloads the initial - graph - """ - # Set up the graph - initial = self.setUp_single_room_graph() - test_graph, test_world, agent_node, room_node = initial - - # Check the graph json was done correctly from agent's room - test_init_json = test_world.oo_graph.to_json_rv(room_node.node_id) - agent_logger = AgentInteractionLogger(test_graph, agent_node) - agent_logger._begin_meta_episode() - agent_logger._end_meta_episode() - graph_file = os.path.join( - self.data_dir, "light_graph_dumps", f"{agent_logger._last_graphs[-1]}.json" - ) - with open(graph_file, "r") as graph_json_file: - written_init_json = graph_json_file.read() - self.assertEqual(test_init_json, written_init_json) - - def test_simple_agent_logger_saves_and_loads_event(self): - """ - Test that the agent logger properly saves and reloads an event - """ - # Set up the graph - initial = self.setUp_single_room_graph() - test_graph, test_world, agent_node, room_node = initial - agent_node.is_player = True - room2_node = test_graph.add_room("test room2", {}) - room_logger = test_graph.room_id_to_loggers[room_node.node_id] - - # Check an event json was done correctly - test_event = ArriveEvent(agent_node, text_content="") - test_init_json = test_world.oo_graph.to_json_rv(agent_node.get_room().node_id) - agent_logger = AgentInteractionLogger(test_graph, agent_node) - agent_logger._begin_meta_episode() - agent_logger.observe_event(test_event) - agent_logger._end_meta_episode() - ref_json = test_event.to_json() - event_file = agent_logger._last_event_log - self.assertNotEqual(os.stat(event_file).st_size, 0) - buff = read_event_logs(event_file) - assert len(buff) == 1 - - world_name, hash_, timestamp, written_event = buff[0] - self.assertEqual(world_name, agent_logger._last_graphs[-1]) - self.assertEqual(hash_, str(test_event.__hash__())) - ref_json = json.loads(ref_json) - event_ref = json.loads(written_event) - self.assertEqual(event_ref, ref_json) - - def test_simple_room_logger_e2e(self): - """ - Test that the room logger properly saves and reloads the graph and events - """ - # Set up the graph - initial = self.setUp_single_room_graph() - test_graph, test_world, agent_node, room_node = initial - agent_node.is_player = True - room_node2 = test_graph.add_room("test room2", {}) - room_logger = test_graph.room_id_to_loggers[room_node.node_id] - test_graph.add_paths_between( - room_node, room_node2, "a path to the north", "a path to the south" - ) - test_graph.room_id_to_loggers[room_node.node_id]._add_player() - - # Check the room and event json was done correctly for room_node - event_room_node_observed = LeaveEvent( - agent_node, target_nodes=[room_node2] - ).to_json() - test_init_json = test_world.oo_graph.to_json_rv(room_node.node_id) - - GoEvent(agent_node, target_nodes=[room_node2]).execute(test_world) - - room_logger = test_graph.room_id_to_loggers[room_node.node_id] - graph_file = os.path.join( - self.data_dir, "light_graph_dumps", f"{room_logger._last_graphs[-1]}.json" - ) - self.assertNotEqual(os.stat(graph_file).st_size, 0) - with open(graph_file, "r") as graph_json_file: - written_init_json = graph_json_file.read() - self.assertEqual(test_init_json, written_init_json) - event_file = room_logger._last_event_log - self.assertNotEqual(os.stat(event_file).st_size, 0) - buff = read_event_logs(event_file) - # Go event triggers a leave event as well! - assert len(buff) == 2 - - world_name, hash_, timestamp, written_event = buff[1] - self.assertEqual(world_name, room_logger._last_graphs[-1]) - ref_json = json.loads(event_room_node_observed) - event_ref = json.loads(written_event) - for k in ref_json: - if k == "event_id": - continue - self.assertEqual( - ref_json[k], event_ref[k], f"Event Json should match for LeaveEvent" - ) + self.assertEqual(len(logger.event_buffer), 1) if __name__ == "__main__": diff --git a/light/world/views.py b/light/world/views.py index 26bbebe78..dd7c9ec84 100644 --- a/light/world/views.py +++ b/light/world/views.py @@ -57,7 +57,7 @@ def get_inventory_text_for(self, id, give_empty=True): def get_health_text_for(self, id): """Return the text description of someone's numeric health""" # TODO get the correct values - health = self.world.get_prop(id, "health") + health = self.world.oo_graph.get_node(id).get_prop("health") if health < 0: health = 0 if health is None or health is False: @@ -217,7 +217,7 @@ def name_prefix(self, node, txt, use_the): def name_prefix_id(self, id, txt, use_the): """Get the prefix to prepend an object with in text form""" # Get the preferred prefix type. - pre = self.world.get_prop(id, "name_prefix") + pre = self.world.oo_graph.get_node(id).get_prop("name_prefix") if pre == "": return pre