Skip to content
This repository has been archived by the owner on May 6, 2024. It is now read-only.

Populate database using the files in xlogfile #321

Merged
merged 1 commit into from
May 17, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 21 additions & 23 deletions nle/dataset/populate_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,35 +251,37 @@ def add_nledata_directory(path, name, filename=db.DB):
db.create_dataset(name, root, conn=c, commit=False)

# 2. For each xlogfile, read the games and take only the games that
# correspond to the ttyrecs in the enclosing directory.
# correspond to the ttyrecs that exist in the enclosing directory.
for xlogfile in sorted(glob.iglob(path + "/*/*.xlogfile")):
stem = xlogfile.replace(".xlogfile", ".*.ttyrec*.bz2")

files = list(glob.iglob(stem))
resets = set()
versions = set()
for f in files:
nameparts = f.split(".")
resets.add(nameparts[-3])
ttyrec_version = nameparts[-2].replace("ttyrec", "")
# NLE v0.8.1 or older generated v2 ttyrecs with the "ttyrec.bz2" suffix
v = int(ttyrec_version) if ttyrec_version else 2
versions.add(v)
resets = set(int(i.split(".")[-3]) for i in files)

files = set(glob.iglob(stem))
ttyrecnames = set(f.split("/")[-1] for f in files)
versions = set(f.split("ttyrec")[-1].replace(".bz2", "") for f in files)
assert len(versions) == 1, "Cannot add ttyrecs with different versions"
version = versions.pop()

if version == "":
raise AssertionError(
"Ttyrec version (* in ttyrec*.bz2) must be > 1 for NLE data."
)

c.execute(
"UPDATE roots SET ttyrec_version = ? WHERE dataset_name = ?",
(version, name),
(int(version), name),
)
version = str(version)

ttyrecs = []
ttydir = str(os.path.dirname(xlogfile))

def filter(gen):
# The `xlogfile` may have more rows than files in directory
# due to 'save_ttyrec_every' option in env.py, so filter these out.
for line_no, line in enumerate(gen):
if line_no in resets:
# If we do find a file, we will save it to be added later.
for line in gen:
ttyrecname = line.decode("ascii").split("ttyrecname=")[-1].strip()
if ttyrecname in ttyrecnames:
ttyrecs.append(ttydir + "/" + ttyrecname)
yield line

# 3. Add games to `games` and `datasets` table.
Expand All @@ -293,12 +295,8 @@ def filter(gen):
db.add_games(name, *gameids, conn=conn, commit=False)

# 4. Add ttyrecs to `ttyrecs` table.
valid_resets = list(resets)[: len(gameids)]
ttyrecs = [
stem.replace("*", str(r), 1).replace("*", version, 1)
for r in sorted(valid_resets, reverse=True)
]
ttyrec_gen = ttyrec_data_generator(ttyrecs, gameids, root)
# Note gameids are "most recently added" so must be reversed.
ttyrec_gen = ttyrec_data_generator(ttyrecs, reversed(gameids), root)
c.executemany("INSERT INTO ttyrecs VALUES (?,?,?,?,?)", ttyrec_gen)

mtime = time.time()
Expand Down