diff --git a/nle/dataset/populate_db.py b/nle/dataset/populate_db.py index 7797d7832..b08090d89 100644 --- a/nle/dataset/populate_db.py +++ b/nle/dataset/populate_db.py @@ -251,35 +251,37 @@ def add_nledata_directory(path, name, filename=db.DB): db.create_dataset(name, root, conn=c, commit=False) # 2. For each xlogfile, read the games and take only the games that - # correspond to the ttyrecs in the enclosing directory. + # correspond to the ttyrecs that exist in the enclosing directory. for xlogfile in sorted(glob.iglob(path + "/*/*.xlogfile")): stem = xlogfile.replace(".xlogfile", ".*.ttyrec*.bz2") - files = list(glob.iglob(stem)) - resets = set() - versions = set() - for f in files: - nameparts = f.split(".") - resets.add(nameparts[-3]) - ttyrec_version = nameparts[-2].replace("ttyrec", "") - # NLE v0.8.1 or older generated v2 ttyrecs with the "ttyrec.bz2" suffix - v = int(ttyrec_version) if ttyrec_version else 2 - versions.add(v) - resets = set(int(i.split(".")[-3]) for i in files) - + files = set(glob.iglob(stem)) + ttyrecnames = set(f.split("/")[-1] for f in files) + versions = set(f.split("ttyrec")[-1].replace(".bz2", "") for f in files) assert len(versions) == 1, "Cannot add ttyrecs with different versions" version = versions.pop() + + if version == "": + raise AssertionError( + "Ttyrec version (* in ttyrec*.bz2) must be > 1 for NLE data." + ) + c.execute( "UPDATE roots SET ttyrec_version = ? WHERE dataset_name = ?", - (version, name), + (int(version), name), ) - version = str(version) + + ttyrecs = [] + ttydir = str(os.path.dirname(xlogfile)) def filter(gen): # The `xlogfile` may have more rows than files in directory # due to 'save_ttyrec_every' option in env.py, so filter these out. - for line_no, line in enumerate(gen): - if line_no in resets: + # If we do find a file, we will save it to be added later. + for line in gen: + ttyrecname = line.decode("ascii").split("ttyrecname=")[-1].strip() + if ttyrecname in ttyrecnames: + ttyrecs.append(ttydir + "/" + ttyrecname) yield line # 3. Add games to `games` and `datasets` table. @@ -293,12 +295,8 @@ def filter(gen): db.add_games(name, *gameids, conn=conn, commit=False) # 4. Add ttyrecs to `ttyrecs` table. - valid_resets = list(resets)[: len(gameids)] - ttyrecs = [ - stem.replace("*", str(r), 1).replace("*", version, 1) - for r in sorted(valid_resets, reverse=True) - ] - ttyrec_gen = ttyrec_data_generator(ttyrecs, gameids, root) + # Note gameids are "most recently added" so must be reversed. + ttyrec_gen = ttyrec_data_generator(ttyrecs, reversed(gameids), root) c.executemany("INSERT INTO ttyrecs VALUES (?,?,?,?,?)", ttyrec_gen) mtime = time.time()