Skip to content
This repository has been archived by the owner on May 6, 2024. It is now read-only.

Commit

Permalink
Fix up documentation. Add a docstring to TtyrecDataset.
Browse files Browse the repository at this point in the history
  • Loading branch information
cdmatters committed May 16, 2022
1 parent f6896d1 commit 19954e4
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 2 deletions.
42 changes: 41 additions & 1 deletion nle/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ def convert_frames(
:param timestamps: Array of timestamps - np.array(np.int64) [ SEQ ]
:param actions: Array of actions at t in response to output at t
- np.array(np.uint8) [ SEQ ]
:param scores: Array of in-game scores - np.array(np.int32) [ SEQ ]
:param resets: Array of resets - np.array(np.uint8) [ SEQ ]
:param gameids: Array of the gameid of each frame - np.array(np.int32) [ SEQ ]
:param load_fn: A callback that loads the next file into a converter:
sig: load_fn(converter) -> bool is_success
Note actions will only be non-null if the ttyrec read_actions flag is set.
"""

resets[0] = 0
Expand Down Expand Up @@ -161,6 +161,46 @@ def __init__(
subselect_sql=None,
subselect_sql_args=None,
):
"""
An iterable dataset to load minibatches of NetHack games from compressed
ttyrec*.bz2 files into numpy arrays. (shape: [batch_size, seq_length, ...])
This class makes use of a sqlite3 database at `dbfilename` to find the
metadata and the location of files in a dataset. It then uses these to
create generators which convert the ttyrecs on the fly. Note that the
dataset generators always reuse their numpy arrays, writing into the
arrays instead of generating new ones. Methods to create and populate the
db from an NLE directory can be found in `populate_db.py`.
Example
-------
```
import nle.dataset as nld
if not os.path.exists(nld.db.DB):
nld.db.create()
nld.populate_db.add_nledata_directory('path/to/nle_data', "data1")
dataset = nld.TtyrecDataset("data1")
for mb in dataset:
# NB: dataset reuses np arrays, for performance reasons
print(mb)
```
:param batch_size: Number of parallel games to load.
:param seq_length: Number of frames to load per game.
:param rows: Row size of the terminal screen.
:param cols: Column size of the terminal screen.
:param dbfilename: Path to the database file
:param gameids: Use a subselection of games (gameids) only.
:param shuffle: Shuffle the order of gameids before iterating through them.
:param loop_forever: If true, cycle through gameids forever,
instead of padding empty batch dims with 0's.
:param subselect_sql: SQL Query to subselect games (gameids) using metadata
:param subselect_sql_args: SQL Query Args to subselect games (gameids)
using metadata.
"""
self.batch_size = batch_size
self.seq_length = seq_length
self.rows = rows
Expand Down
1 change: 0 additions & 1 deletion nle/scripts/read_tty.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ def main():
arrow = "->"
elif channel == 2:
score, *_ = struct.unpack("<i", data)
# data = chr(score).encode("ascii", "backslashreplace")
data = f" {score} "
arrow = "->"

Expand Down

0 comments on commit 19954e4

Please sign in to comment.