Skip to content

Commit

Permalink
Fix black.
Browse files Browse the repository at this point in the history
  • Loading branch information
heiner committed May 3, 2024
1 parent 6d80bb5 commit ab061d3
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 16 deletions.
6 changes: 3 additions & 3 deletions nle/scripts/collect_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,9 @@ def get_env_info():
cuda_available_str = torch.cuda.is_available()
cuda_version_str = torch.version.cuda
else:
torch_version_str = (
torch_debug_mode_str
) = cuda_available_str = cuda_version_str = "N/A"
torch_version_str = torch_debug_mode_str = cuda_available_str = (
cuda_version_str
) = "N/A"

return SystemEnv(
nle_version=nle_version,
Expand Down
48 changes: 36 additions & 12 deletions nle/tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,15 @@ def getfilename(filename):
return os.path.join(os.path.dirname(os.path.realpath(__file__)), filename)


def load_and_convert(converter, ttyrec, chars, colors, cursors, timestamps, actions, scores):
def load_and_convert(
converter, ttyrec, chars, colors, cursors, timestamps, actions, scores
):
converter.load_ttyrec(ttyrec)
remaining = converter.convert(chars, colors, cursors, timestamps, actions, scores)
while remaining == 0:
remaining = converter.convert(chars, colors, cursors, timestamps, actions, scores)
remaining = converter.convert(
chars, colors, cursors, timestamps, actions, scores
)


class TestConverter:
Expand Down Expand Up @@ -125,7 +129,9 @@ def test_ttyrec_with_extra_data(self, seq_length=500):
scores = np.zeros((seq_length), dtype=np.int32)

converter.load_ttyrec(getfilename(TTYREC_2018))
remaining = converter.convert(chars, colors, cursors, timestamps, actions, scores)
remaining = converter.convert(
chars, colors, cursors, timestamps, actions, scores
)
assert remaining == 165

def test_data(self):
Expand All @@ -149,7 +155,9 @@ def test_data(self):

num_frames = 0
while True:
remaining = converter.convert(chars, colors, cursors, timestamps, actions, scores)
remaining = converter.convert(
chars, colors, cursors, timestamps, actions, scores
)
for (row, col), ts in zip(
cursors[: SEQ_LENGTH - remaining], timestamps[: SEQ_LENGTH - remaining]
):
Expand Down Expand Up @@ -203,7 +211,9 @@ def test_illegal_buffers(self):
cursors = np.zeros((10, 2), dtype=np.int16)
with pytest.raises(
ValueError,
match=re.escape("Array has wrong shape (expected [ 10 25 80 ], got [ 10 25 79 ])"),
match=re.escape(
"Array has wrong shape (expected [ 10 25 80 ], got [ 10 25 79 ])"
),
):
converter.convert(chars, colors, cursors, timestamps, actions, scores)

Expand All @@ -212,7 +222,9 @@ def test_illegal_buffers(self):
cursors = np.zeros((10, 2), dtype=np.int16)
with pytest.raises(
ValueError,
match=re.escape("Array has wrong shape (expected [ 10 25 80 ], got [ 10 24 80 ])"),
match=re.escape(
"Array has wrong shape (expected [ 10 25 80 ], got [ 10 24 80 ])"
),
):
converter.convert(chars, colors, cursors, timestamps, actions, scores)

Expand Down Expand Up @@ -288,7 +300,9 @@ def test_ibm_graphics(self):
scores = np.zeros((seq_length), dtype=np.int32)

converter.load_ttyrec(getfilename(TTYREC_IBMGRAPHICS))
assert converter.convert(chars, colors, cursors, timestamps, actions, scores) == 0
assert (
converter.convert(chars, colors, cursors, timestamps, actions, scores) == 0
)

with open(getfilename(TTYREC_IBMGRAPHICS_FRAME_10)) as f:
for row, line in enumerate(f):
Expand All @@ -308,7 +322,9 @@ def test_dec_graphics(self):
scores = np.zeros((seq_length), dtype=np.int32)

converter.load_ttyrec(getfilename(TTYREC_DECGRAPHICS))
assert converter.convert(chars, colors, cursors, timestamps, actions, scores) == 0
assert (
converter.convert(chars, colors, cursors, timestamps, actions, scores) == 0
)

with open(getfilename(TTYREC_DECGRAPHICS_FRAME_10)) as f:
for row, line in enumerate(f):
Expand All @@ -328,7 +344,9 @@ def test_unknown_control_sequence_graphics(self):
scores = np.zeros((seq_length), dtype=np.int32)

converter.load_ttyrec(getfilename(TTYREC_UNKGRAPHICS))
assert converter.convert(chars, colors, cursors, timestamps, actions, scores) == 0
assert (
converter.convert(chars, colors, cursors, timestamps, actions, scores) == 0
)

with open(getfilename(TTYREC_UNKGRAPHICS_FRAME_10)) as f:
for row, line in enumerate(f):
Expand All @@ -348,7 +366,9 @@ def test_shiftin_shiftout_graphics(self):
scores = np.zeros((seq_length), dtype=np.int32)

converter.load_ttyrec(getfilename(TTYREC_SHIFTIN))
assert converter.convert(chars, colors, cursors, timestamps, actions, scores) == 0
assert (
converter.convert(chars, colors, cursors, timestamps, actions, scores) == 0
)

with open(getfilename(TTYREC_SHIFTIN_FRAME_10)) as f:
for row, line in enumerate(f):
Expand All @@ -368,7 +388,9 @@ def test_nle_v2_conversion(self):
scores = np.zeros((seq_length), dtype=np.int32)

converter.load_ttyrec(getfilename(TTYREC_NLE_V2))
assert converter.convert(chars, colors, cursors, timestamps, actions, scores) == 0
assert (
converter.convert(chars, colors, cursors, timestamps, actions, scores) == 0
)

with open(getfilename(TTYREC_NLE_V2_FRAME_150)) as f:
for row, line in enumerate(f):
Expand All @@ -391,7 +413,9 @@ def test_nle_v3_conversion(self):
scores = np.zeros((seq_length), dtype=np.int32)

converter.load_ttyrec(getfilename(TTYREC_NLE_V3))
assert converter.convert(chars, colors, cursors, timestamps, actions, scores) == 1
assert (
converter.convert(chars, colors, cursors, timestamps, actions, scores) == 1
)

with open(getfilename(TTYREC_NLE_V3_FRAME_44)) as f:
for row, line in enumerate(f):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.black]
line-length = 88
target-version = ['py37']
target-version = ['py38']
include = '\.pyi?$'
exclude = '''
/(
Expand Down

0 comments on commit ab061d3

Please sign in to comment.