This repository has been archived by the owner on Oct 11, 2023. It is now read-only.

Commit

New Episode logging (#293)
* Filling out UserDB

* Abstract enforce get first

* Clearer argument - num_turns

* Using enums in DB

* Initial episode data model

* Update content loggers to use episode formatting

* Updating tables to work with testing

* Fixing some test changes

* Fixing small warnings that were noise during tests

* Moving default log path

* Test fix

* Correcting math thanks to Kurt

* Addressing comments, clarifying code
JackUrb authored Aug 22, 2022
1 parent 183449b commit 84843af
Showing 11 changed files with 847 additions and 440 deletions.
46 changes: 43 additions & 3 deletions light/data_model/db/base.py
@@ -11,13 +11,22 @@
from enum import Enum
from typing import Optional, Union, Dict, Any
from dataclasses import dataclass
from tempfile import mkdtemp
from contextlib import contextmanager
import shutil
import os
import json

from hydra.core.config_store import ConfigStore

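# Default on-disk log root: resolves to <repo_root>/logs relative to this file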
DEFAULT_LOG_PATH = "".join(
[os.path.abspath(os.path.dirname(__file__)), "/../../../logs"]
)


@dataclass
class LightDBConfig:
backend: str = "test"
file_root: Optional[str] = DEFAULT_LOG_PATH


cs = ConfigStore.instance()
@@ -61,8 +70,14 @@ def __init__(self, config: "DictConfig"):
files and instances.
"""
# TODO replace with a swappable engine that persists the data
self.backend = config.backend
if config.backend == "test":
self.engine = create_engine("sqlite+pysqlite:///:memory:", future=True)
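            # No explicit file_root: stage files in a temp dir that shutdown() removes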
self.made_temp_dir = config.file_root is None
if self.made_temp_dir:
self.file_root = mkdtemp()
else:
self.file_root = config.file_root
else:
raise NotImplementedError()
self._complete_init(config)
@@ -91,22 +106,47 @@ def _enforce_get_first(self, session, stmt, error_text) -> Any:

def write_data_to_file(
self, data: Union[str, Dict[str, Any]], filename: str, json_encode: bool = False
    ) -> None:
"""
Write the given data to the provided filename
in the correct storage location (local or remote)
"""
# Expects data to be a string, unless json_encode is True
if self.backend in ["test", "local"]:
full_path = os.path.join(self.file_root, filename)
os.makedirs(os.path.dirname(full_path), exist_ok=True)
with open(full_path, "w+") as target_file:
if json_encode:
json.dump(data, target_file)
else:
target_file.write(data)
else:
raise NotImplementedError()

    def read_data_from_file(
        self, filename: str, json_encoded: bool = False
    ) -> Union[str, Dict[str, Any]]:
"""
Read the data from the given filename from wherever it
is currently stored (local or remote)
"""
if self.backend in ["test", "local"]:
full_path = os.path.join(self.file_root, filename)
with open(full_path, "r") as target_file:
if json_encoded:
return json.load(target_file)
else:
return target_file.read()
else:
raise NotImplementedError()

    @contextmanager
    def open_file(self):
        # Assumes self.file_name has been set by the caller; open outside the
        # try block so a failed open() cannot reach the finally clause unbound
        file = open(self.file_name, "w")
        try:
            yield file
        finally:
            file.close()

def shutdown(self):
if self.backend == "test":
if self.made_temp_dir:
shutil.rmtree(self.file_root)
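
A minimal sketch of how the new file_root plumbing behaves with the "test" backend, assuming the concrete EpisodeDB subclass from the next file (paths and payloads here are illustrative only):

from light.data_model.db.base import LightDBConfig
from light.data_model.db.episodes import EpisodeDB

# backend="test" pairs an in-memory SQLite engine with a temporary file_root
config = LightDBConfig(backend="test", file_root=None)
db = EpisodeDB(config)

# Round-trip a JSON payload through the storage helpers
db.write_data_to_file({"hello": "world"}, "example/data.json", json_encode=True)
assert db.read_data_from_file("example/data.json", json_encoded=True) == {"hello": "world"}

db.shutdown()  # removes the temporary directory, since made_temp_dir is True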
249 changes: 229 additions & 20 deletions light/data_model/db/episodes.py
@@ -5,14 +5,22 @@
# LICENSE file in the root directory of this source tree.

from light.data_model.db.base import BaseDB, DBStatus, DBSplitType
from omegaconf import MISSING, DictConfig
from typing import Optional, List, Tuple, Union, Dict, Any, Set, TYPE_CHECKING
from sqlalchemy import insert, select, Enum, Column, Integer, String, Float, ForeignKey
from sqlalchemy.orm import declarative_base, relationship, Session
from light.graph.events.base import GraphEvent
import time
import enum
import os

if TYPE_CHECKING:
from light.graph.structured_graph import OOGraph

SQLBase = declarative_base()


class DBGroupName(enum.Enum):
"""Edges in the LIGHT Environment DB"""

ORIG = "orig"
@@ -22,26 +30,130 @@ class DBGroupName(Enum):
RELEASE = "full_release"


class EpisodeLogType(enum.Enum):
"""Types of episodes in LIGHT"""

ROOM = "room"
AGENT = "agent"
FULL = "full"


class DBEpisode(SQLBase):
"""Class containing the expected elements for an episode as stored in the db"""

    __tablename__ = "episodes"

id = Column(Integer, primary_key=True)
group = Column(Enum(DBGroupName), nullable=False, index=True)
split = Column(Enum(DBSplitType), nullable=False, index=True)
status = Column(Enum(DBStatus), nullable=False, index=True)
    actors = Column(
        String
    )  # Comma-separated list of actor IDs; cleared on data release
dump_file_path = Column(String(90), nullable=False) # Path to data
turn_count = Column(Integer, nullable=False)
human_count = Column(Integer, nullable=False)
action_count = Column(Integer, nullable=False)
timestamp = Column(Float, nullable=False)
log_type = Column(Enum(EpisodeLogType), nullable=False)
first_graph_id = Column(ForeignKey("graphs.id"))
final_graph_id = Column(ForeignKey("graphs.id"))

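    # Lazily-built lookup from graph key or id to DBGraph; see get_graph_map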
_cached_map = None

def get_actors(self) -> List[str]:
"""Return the actors in this episode"""
if len(self.actors.strip()) == 0:
return []
return self.actors.split(",")

def get_parsed_events(
self, db: "EpisodeDB"
) -> List[Tuple[str, List["GraphEvent"]]]:
"""
Return all of the actions and turns from this episode,
split by the graph key ID relevant to those actions
"""
# Import deferred as World imports loggers which import the EpisodeDB
from light.world.world import World

events = db.read_data_from_file(self.dump_file_path, json_encoded=True)[
"events"
]
graph_grouped_events: List[Tuple[str, List["GraphEvent"]]] = []
current_graph_events = None
curr_graph_key = None
curr_graph = None
tmp_world = None
# Extract events to the correct related graphs, initializing the graphs
# as necessary
for event_turn in events:
# See if we've moved onto an event in a new graph
if event_turn["graph_key"] != curr_graph_key:
if current_graph_events is not None:
                    # There was old state, so let's push it to the list
graph_grouped_events.append((curr_graph_key, current_graph_events))
# We're on a new graph, have to reset the current graph state
curr_graph_key = event_turn["graph_key"]
current_graph_events: List["GraphEvent"] = []
curr_graph = self.get_graph(curr_graph_key, db)
tmp_world = World({}, None)
tmp_world.oo_graph = curr_graph
            # The current turn is part of the current graph's events, so add it
current_graph_events.append(
GraphEvent.from_json(event_turn["event_json"], tmp_world)
)
if current_graph_events is not None:
# Push the last graph's events, which weren't yet added
graph_grouped_events.append((curr_graph_key, current_graph_events))
return graph_grouped_events

def get_before_graph(self, db: "EpisodeDB") -> "OOGraph":
"""Return the state of the graph before this episode"""
return self.get_graph(self.first_graph_id, db)

    def get_graph(self, id_or_key: str, db: "EpisodeDB") -> "OOGraph":
        """Return a specific graph by id or key"""
        return self.get_graph_map()[id_or_key].get_graph(db)

def get_after_graph(self, db: "EpisodeDB") -> "OOGraph":
"""Return the state of the graph after this episode"""
return self.get_graph(self.final_graph_id, db)

def get_graph_map(self):
"""Return a mapping from both graph keys and graph ids to their graph"""
if self._cached_map is None:
key_map = {graph.graph_key_id: graph for graph in self.graphs}
id_map = {graph.id: graph for graph in self.graphs}
key_map.update(id_map)
self._cached_map = key_map
return self._cached_map

def __repr__(self):
return f"DBEpisode(ids:[{self.id!r}] group/split:[{self.group.value!r}/{self.split.value!r}] File:[{self.dump_file_path!r}])"


class DBGraph(SQLBase):
"""Class containing expected elements for a stored graph"""

__tablename__ = "graphs"

id = Column(Integer, primary_key=True)
episode_id = Column(Integer, ForeignKey("episodes.id"), nullable=False, index=True)
full_path = Column(String(80), nullable=False)
graph_key_id = Column(String(60), nullable=False, index=True)
episode = relationship("DBEpisode", backref="graphs", foreign_keys=[episode_id])

def get_graph(self, db: "EpisodeDB") -> "OOGraph":
"""Return the initialized graph based on this file"""
from light.graph.structured_graph import OOGraph

graph_json = db.read_data_from_file(self.full_path)
graph = OOGraph.from_json(graph_json)
return graph

def __repr__(self):
return f"DBGraph(ids:[{self.id!r},{self.graph_key_id!r}], episode:{self.episode_id!r})"


class EpisodeDB(BaseDB):
Expand All @@ -58,28 +170,99 @@ def _complete_init(self, config: "DictConfig"):
Initialize any specific episode-related paths. Populate
the list of available splits and datasets.
"""
        SQLBase.metadata.create_all(self.engine)

def _validate_init(self):
"""
Ensure that the episode directory is properly loaded
"""
        # TODO Check the table for any possible consistency issues
        # and ensure that the episode directories for listed splits exist

    def write_episode(
self,
graphs: List[Dict[str, str]],
events: Tuple[str, List[Dict[str, str]]],
log_type: EpisodeLogType,
action_count: int,
players: Set[str],
group: DBGroupName,
) -> str:
"""
Create an entry given the current argument data, store it
to file on the database
"""
actor_string = ",".join(list(players))
event_filename = events[0]
event_list = events[1]
# Trim the filename from the left if too long
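        # presumably so the stored path fits the String(90) dump_file_path column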
event_filename = event_filename[-70:]
assert len(event_filename) <= 70
dump_file_path = os.path.join(group.value, log_type.value, event_filename)
graph_dump_root = os.path.join(
group.value,
log_type.value,
"graphs",
)

# File writes
self.write_data_to_file(
{"events": event_list}, dump_file_path, json_encode=True
)
for graph_info in graphs:
graph_full_path = os.path.join(graph_dump_root, graph_info["filename"])
self.write_data_to_file(graph_info["graph_json"], graph_full_path)

# DB Writes
with Session(self.engine) as session:
episode = DBEpisode(
group=group,
split=DBSplitType.UNSET,
status=DBStatus.REVIEW,
actors=actor_string,
dump_file_path=dump_file_path,
turn_count=len(event_list),
human_count=len(players),
action_count=action_count,
timestamp=time.time(),
log_type=log_type,
)
first_graph = None
for idx, graph_info in enumerate(graphs):
graph_full_path = os.path.join(graph_dump_root, graph_info["filename"])
db_graph = DBGraph(
graph_key_id=graph_info["key"],
full_path=graph_full_path,
)
if idx == 0:
first_graph = db_graph
episode.graphs.append(db_graph)
session.add(episode)
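            # Flush so the autoincrement ids exist before we reference them below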
session.flush()
episode.first_graph_id = first_graph.id
episode.final_graph_id = db_graph.id

episode_id = episode.id
session.commit()

return episode_id

def get_episode(self, episode_id: str) -> "DBEpisode":
"""
        Return a specific episode by id, raising an exception if it doesn't exist
"""
stmt = select(DBEpisode).where(DBEpisode.id == episode_id)
with Session(self.engine) as session:
episode = self._enforce_get_first(session, stmt, "Episode did not exist")
            for graph in episode.graphs:
                # Force-load the lazy graphs relationship before detaching
                assert graph.id is not None
session.expunge_all()
return episode

    def get_episodes(
        self,
        group: Optional[DBGroupName] = None,
split: Optional[DBSplitType] = None,
min_turns: Optional[int] = None,
min_humans: Optional[int] = None,
@@ -88,8 +271,34 @@ def get_episodes(
user_id: Optional[str] = None,
min_creation_time: Optional[float] = None,
max_creation_time: Optional[float] = None,
log_type: Optional[EpisodeLogType] = None,
# ... other args
) -> List["DBEpisode"]:
"""
Return all matching episodes
"""
stmt = select(DBEpisode)
if group is not None:
stmt = stmt.where(DBEpisode.group == group)
if split is not None:
stmt = stmt.where(DBEpisode.split == split)
if min_turns is not None:
stmt = stmt.where(DBEpisode.turn_count >= min_turns)
if min_humans is not None:
stmt = stmt.where(DBEpisode.human_count >= min_humans)
if min_actions is not None:
stmt = stmt.where(DBEpisode.action_count >= min_actions)
if status is not None:
stmt = stmt.where(DBEpisode.status == status)
if user_id is not None:
stmt = stmt.where(DBEpisode.actors.contains(user_id))
if log_type is not None:
stmt = stmt.where(DBEpisode.log_type == log_type)
if min_creation_time is not None:
stmt = stmt.where(DBEpisode.timestamp >= min_creation_time)
if max_creation_time is not None:
stmt = stmt.where(DBEpisode.timestamp <= max_creation_time)
with Session(self.engine) as session:
episodes = session.scalars(stmt).all()
session.expunge_all()
return episodes
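
Taken together, a sketch of the new logging flow end to end; the graph and event payloads below are illustrative stand-ins for what the updated content loggers produce:

from light.data_model.db.base import LightDBConfig
from light.data_model.db.episodes import DBGroupName, EpisodeDB, EpisodeLogType

db = EpisodeDB(LightDBConfig(backend="test", file_root=None))

# One graph snapshot plus the events that occurred within it
graphs = [{"key": "graph-1", "filename": "graph-1.json", "graph_json": "{}"}]
events = ("episode-1-events.json", [{"graph_key": "graph-1", "event_json": "{}"}])

episode_id = db.write_episode(
    graphs=graphs,
    events=events,
    log_type=EpisodeLogType.ROOM,
    action_count=1,
    players={"user-1"},
    group=DBGroupName.ORIG,
)

# Episodes can be fetched directly or filtered on the columns indexed above
episode = db.get_episode(episode_id)
matches = db.get_episodes(group=DBGroupName.ORIG, log_type=EpisodeLogType.ROOM)
db.shutdown()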
