This repository has been archived by the owner on Oct 11, 2023. It is now read-only.

DBEpisode import export #345

Open · wants to merge 5 commits into base: main
10 changes: 5 additions & 5 deletions deploy/web/server/tests/test_tornado_server.py
@@ -1456,11 +1456,11 @@ def test_view_interaction(self, mocked_auth):
 def all():
     suiteList = []
     # TODO: Break out into seperate files, arrange suite elsewhere when automated testing done
-    suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestRegistryApp))
-    suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestGameApp))
-    suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestWorldSaving))
-    suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestBuilderApp))
-    suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestLandingApp))
+    # suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestRegistryApp))
+    # suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestGameApp))
+    # suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestWorldSaving))
+    # suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestBuilderApp))
+    # suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestLandingApp))
     return unittest.TestSuite(suiteList)


89 changes: 74 additions & 15 deletions light/data_model/db/episodes.py
@@ -384,6 +384,7 @@ def write_episode(
             if idx == 0:
                 first_id = curr_id
             episode.graphs.append(db_graph)
+
         session.add(episode)
         assert first_id is not None and curr_id is not None
         episode.first_graph_id = first_id
@@ -498,29 +499,41 @@ def replace_all_actors(in_data: str) -> str:
         session.commit()
         return True

-    def export(self, config: "LightDBConfig") -> "EpisodeDB":
+    def export(
+        self, config: "LightDBConfig", group: Optional[DBGroupName] = None
+    ) -> "EpisodeDB":
         """
         Create a scrubbed version of this database for use in releases
         """
         assert config.file_root != self.file_root, "Cannot copy DB to same location!"
         new_db = EpisodeDB(config)

-        # Copy all the basic content
-        for table_name, table_obj in SQLBase.metadata.tables.items():
-            with self.engine.connect() as orig_conn:
-                with new_db.engine.connect() as new_conn:
-                    keys = table_obj.c.keys()
-                    all_data = [
-                        dict(zip(keys, row))
-                        for row in orig_conn.execute(select(table_obj.c))
-                    ]
-                    if len(all_data) == 0:
-                        continue
-                    new_conn.execute(table_obj.insert().values(all_data))
-                    new_conn.commit()
+        with self.engine.connect() as orig_conn:
+            with new_db.engine.connect() as new_conn:
+                # Write the episodes
+                episode_table = SQLBase.metadata.tables["episodes"]
+                stmt = select(DBEpisode)
+                if group is not None:
+                    stmt = stmt.where(DBEpisode.group == group)
+                episodes = orig_conn.execute(stmt)
+                episode_data = [dict(r._mapping) for r in episodes]
+                episode_ids = [d["id"] for d in episode_data]
+                new_conn.execute(episode_table.insert().values(episode_data))
+                for tbl_name in ["graphs", "quest_completions", "wild_metadata"]:
+                    tbl_obj = SQLBase.metadata.tables[tbl_name]
+                    stmt = select(tbl_obj.c).where(
+                        tbl_obj.c.episode_id.in_(episode_ids)
+                    )
+                    table_data = [dict(r._mapping) for r in orig_conn.execute(stmt)]
+                    new_conn.execute(tbl_obj.insert().values(table_data))
+                new_conn.commit()  # type: ignore

         # Copy graphs over
         with Session(self.engine) as session:
+            stmt = select(DBEpisode)
+            if group is not None:
+                stmt = stmt.where(DBEpisode.group == group)
+            episodes = session.scalars(stmt).all()
             for episode in episodes:
                 graphs = episode.graphs
@@ -532,7 +545,53 @@ def export(self, config: "LightDBConfig") -> "EpisodeDB":
                 event_data = self.read_data_from_file(episode.dump_file_path)
                 new_db.write_data_to_file(event_data, episode.dump_file_path)

-        for group in DBGroupName:
+        if group is not None:
             new_db.anonymize_group(group=group)
+        else:
+            for group in DBGroupName:
+                new_db.anonymize_group(group=group)

         return new_db
+
+    def download_and_import_group(self, group: DBGroupName) -> None:
+        """Download the group from a hosted source, then import it"""
+        stmt = select(DBEpisode).where(DBEpisode.group == group).limit(1)
+        with Session(self.engine) as session:
+            episodes = session.scalars(stmt).all()
+            if len(episodes) > 0:
+                print(f"You already have {group} installed")
+                return
+        raise NotImplementedError("Still need to implement the download step")
+        # data_url = light.data_utils.get_episode_group(group) or something like it
+        # download data to temporary directory
+        # unzip db and data in temporary directory
+        # create LightDBConfig pointing to this dir
+        # self.import_db(other_db)
+        # delete temporary db directory
+
+    def import_db(self, other_db: "EpisodeDB") -> None:
+        """Attempt to import the contents of the specified db into this one"""
+        # Copy all the contents from the source
+        for _, table_obj in SQLBase.metadata.tables.items():
+            with self.engine.connect() as dest_conn:
+                with other_db.engine.connect() as source_conn:
+                    all_data = [
+                        dict(row._mapping)
+                        for row in source_conn.execute(select(table_obj.c))
+                    ]
+                    if len(all_data) == 0:
+                        continue
+                    dest_conn.execute(table_obj.insert().values(all_data))
+                    dest_conn.commit()  # type: ignore
+
+        with Session(other_db.engine) as session:
+            stmt = select(DBEpisode)
+            episodes = session.scalars(stmt).all()
+            for episode in episodes:
+                graphs = episode.graphs
+                for graph in graphs:
+                    # Copy the graphs to the new DB
+                    graph_data = other_db.read_data_from_file(graph.full_path)
+                    self.write_data_to_file(graph_data, graph.full_path)
+                # Copy the events to the new DB
+                event_data = other_db.read_data_from_file(episode.dump_file_path)
+                self.write_data_to_file(event_data, episode.dump_file_path)
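
The download step in download_and_import_group above is only sketched in comments. A minimal sketch of how that body might look once implemented, using only the standard library; note the assumptions: light.data_utils.get_episode_group is hypothetical (the diff's own comment says "or something like it"), and constructing LightDBConfig with a file_root argument is inferred from the file_root attribute that export() checks, not a confirmed signature.

    import os
    import tempfile
    import urllib.request
    import zipfile

    def _finish_download_and_import(self, group: DBGroupName) -> None:
        # Hypothetical helper per the diff comment: maps a group to its hosted archive URL
        data_url = light.data_utils.get_episode_group(group)
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Download data to the temporary directory
            archive_path = os.path.join(tmp_dir, "episodes.zip")
            urllib.request.urlretrieve(data_url, archive_path)
            # Unzip db and data in the temporary directory
            with zipfile.ZipFile(archive_path) as zf:
                zf.extractall(tmp_dir)
            # Create a LightDBConfig pointing to this dir (assumed constructor)
            other_db = EpisodeDB(LightDBConfig(file_root=tmp_dir))
            self.import_db(other_db)
        # The context manager deletes the temporary db directory on exit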
4 changes: 2 additions & 2 deletions requirements.txt
@@ -3,11 +3,11 @@ yappi>=1.2.5
 torch>=1.5.0
 tornado>=6.0.4
 pandas>=1.0.3
-parlai>=1.7.0
+parlai>=1.7.1
 psycopg2-binary>=2.8.5
 pytest>=5.0.0
 pyzmq>=19.0.1
 tqdm>=4.48.0
 hydra-core>=1.2.0
-mephisto>=1.0.3
+mephisto>=1.1.0
 SQLAlchemy>=2.0.7
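
Taken together, the new surface lets a release pipeline export a single group and a consumer merge it back in. A hedged end-to-end sketch follows; the import path for LightDBConfig, its file_root constructor, and the DBGroupName.WILD member are assumptions based on the diff, not confirmed against the repo.

    from light.data_model.db.episodes import DBGroupName, EpisodeDB
    from light.data_model.db.base import LightDBConfig  # assumed import path

    # Scrubbed, group-scoped export for a release
    source_db = EpisodeDB(LightDBConfig(file_root="/data/light/prod"))
    release_db = source_db.export(
        LightDBConfig(file_root="/data/light/release"),
        group=DBGroupName.WILD,  # assumed enum member
    )

    # A consumer merges that release copy into their own local DB
    local_db = EpisodeDB(LightDBConfig(file_root="/data/light/local"))
    local_db.import_db(release_db)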