Skip to content
This repository has been archived by the owner on May 6, 2024. It is now read-only.

Commit

Permalink
Log in-game score in ttyrecs and bump ttyrec version to 3.
Browse files Browse the repository at this point in the history
We will now start storing the in-game score in ttyrecs, to provide
greater information for offline learning methods. This is done by
utilising a different 'channel' for score. As it stands the channels
are:
 - 0 - terminal information
 - 1 - action input
 - 2 - in-game score (based on u.urexp)

 Note the in-game score is recorded once, just _after_ step has been
 called in the Python (but before the step is run in C). This is done to
 avoid having to store the in-game score everytime the game flushes to
 screen, and to keep the final ttyrec size low.
  • Loading branch information
cdmatters committed May 9, 2022
1 parent 3f633c7 commit 7319368
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 5 deletions.
2 changes: 1 addition & 1 deletion nle/nethack/nethack.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
)

HACKDIR = pkg_resources.resource_filename("nle", "nethackdir")
TTYREC_VERSION = 2
TTYREC_VERSION = 3


def _new_dl_linux(vardir):
Expand Down
32 changes: 32 additions & 0 deletions src/nle.c
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,14 @@ nle_start(nle_obs *obs, FILE *ttyrec, nle_seeds_init_t *seed_init,
nle_seeds_init =
NULL; /* Don't set to *these* seeds on subsequent reseeds, if any. */

if (obs->blstats) {
/* See comment in `nle_step`. We record the score in line with
* the state to ensure s,r -> a -> s', r'. These lines ensure
* we don't skip the first reward. */
write_ttyrec_header(4, 2);
write_ttyrec_data(&obs->blstats[9], 4);
}

return nle;
}

Expand All @@ -452,6 +460,30 @@ nle_step(nle_ctx_t *nle, nle_obs *obs)
nle->done = (t.data == NULL);
obs->done = nle->done;

if (nle->ttyrec) {
/* NLE ttyrec version 3 stores the action and in-game score in
* different channels of the ttyrec. These channels are:
* - 0: the terminal instructions (classic ttyrec)
* - 1: the keypress/action (1 byte)
* - 2: the in-game score (4 bytes)
*
* We could either the note the in-game score every time we flush the
* terminal instructions to screen, (eg writing [ 0 2 0 2 <step> 1 0 2
* <step> 1 ]) or we can note it _just_ before resuming the game,
* assuming no chicanery has happened to the score after it is written
* to the array `blstats`, (eg writing [ 0 2 <step> 1 0 2 <step> 1 0 2
* <step> ]). We chose the latter for compression & simplicity
* reasons.
*
* Note: blstats[9] == botl_score which is used for score/reward fns.
* see winrl.cc
*/
if (obs->blstats) {
write_ttyrec_header(4, 2);
write_ttyrec_data(&obs->blstats[9], 4);
}
}

return nle;
}

Expand Down
17 changes: 13 additions & 4 deletions third_party/converter/converter.c
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,17 @@ int conversion_convert_frames(Conversion *c) {

if (c->version > 1){
/* NLE-based ttyrecs have a channel which codifies what type of
* information we are encoding.
* information we are encoding.
*
* V2: If Output Channel (0) -> update terminal;
* Else Input Channel (1) -> write (state + action) to buffers.
* NB. Will only end up writing last frame before input is given. */
* V1: Order: 0 - No "channel", write only to update terminal
* V2: [0 1 0 1 ...]
* Channel 0 -> update terminal/state
* Channel 1 -> we have an action: write state + action to buffers
* V3: [0 2 1 0 2 1 ...]
* Channel 0 -> update terminal/state
* Channel 2 -> we have an reward: write reward only
* Channel 1 -> we have an action: write state + action to buffers
* NB. Will only end up writing when an action is given. */
if (c->header.channel == 0) {
tmt_write(c->vt, c->buf, c->header.len);
} else {
Expand All @@ -287,6 +293,9 @@ int conversion_convert_frames(Conversion *c) {
}

void write_to_buffers(Conversion *conv) {
if (conv->version == 3 && conv->header.channel == 2)
return;

const TMTSCREEN *scr = tmt_screen(conv->vt);
for (size_t r = 0; r < conv->rows; ++r) {
for (size_t c = 0; c < conv->cols; ++c) {
Expand Down

0 comments on commit 7319368

Please sign in to comment.