Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions nle/tests/test_nethack.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,18 @@
class TestNetHack:
@pytest.fixture
def game(self): # Make sure we close even on test failure.
g = nethack.Nethack(observation_keys=("chars", "blstats"))
g = nethack.Nethack(ttyrec=None, observation_keys=("chars", "blstats"))
try:
yield g
finally:
g.close()

def test_close_and_restart(self):
game = nethack.Nethack()
game = nethack.Nethack(ttyrec=None)
game.reset()
game.close()

game = nethack.Nethack()
game = nethack.Nethack(ttyrec=None)
game.reset()
game.close()

Expand Down Expand Up @@ -103,7 +103,7 @@ def test_run_n_episodes(self, tmpdir, game, episodes=3):

def test_several_nethacks(self, game):
game.reset()
game1 = nethack.Nethack()
game1 = nethack.Nethack(ttyrec=None)
game1.reset()

try:
Expand All @@ -119,7 +119,7 @@ def test_several_nethacks(self, game):
game1.close()

def test_set_initial_seeds(self):
game = nethack.Nethack(copy=True)
game = nethack.Nethack(ttyrec=None, copy=True)
game.set_initial_seeds(core=42, disp=666)
obs0 = game.reset()
try:
Expand All @@ -141,9 +141,11 @@ def test_set_seed_after_reset(self, game):


class TestNetHackFurther:
def test_run(self):
def test_run(self, tmpdir):
ttyrec_path = os.path.join(tmpdir, "nle.ttyrec%i.bz2" % nethack.TTYREC_VERSION)
game = nethack.Nethack(
observation_keys=("glyphs", "chars", "colors", "blstats", "program_state")
ttyrec=ttyrec_path,
observation_keys=("glyphs", "chars", "colors", "blstats", "program_state"),
)
_, _, _, _, program_state = game.reset()
actions = [
Expand Down Expand Up @@ -186,35 +188,34 @@ def test_run(self):
assert class_sym.explain == "human or elf"

game.close()
assert os.path.isfile(
os.path.join(os.getcwd(), "nle.ttyrec%i.bz2" % nethack.TTYREC_VERSION)
)
assert os.path.isfile(ttyrec_path)

def test_illegal_filename(self):
with pytest.raises(IOError):
nethack.Nethack(ttyrec="")
game = nethack.Nethack()
game = nethack.Nethack(ttyrec=None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test shouldn't have the ttyrec override as it's testing an incorrect filename.

with pytest.raises(IOError):
game.reset("")

def test_set_buffers_after_reset(self):
game = nethack.Nethack()
game = nethack.Nethack(ttyrec=None)
game.reset()
with pytest.raises(RuntimeError, match=r"set_buffers called after reset()"):
game._pynethack.set_buffers()

def test_nethack_random_character(self):
game = nethack.Nethack(playername="Hugo-@")
game = nethack.Nethack(ttyrec=None, playername="Hugo-@")
assert "race:random" in game.options
assert "gender:random" in game.options
assert "align:random" in game.options

game = nethack.Nethack(playername="Jurgen-wiz-gno-cha-mal")
game = nethack.Nethack(ttyrec=None, playername="Jurgen-wiz-gno-cha-mal")
assert "race:random" not in game.options
assert "gender:random" not in game.options
assert "align:random" not in game.options

game = nethack.Nethack(
ttyrec=None,
playername="Albert-@",
options=list(nethack.NETHACKOPTIONS) + ["align:lawful"],
)
Expand All @@ -224,6 +225,7 @@ def test_nethack_random_character(self):
assert "align:lawful" in game.options

game = nethack.Nethack(
ttyrec=None,
playername="Rachel",
options=list(nethack.NETHACKOPTIONS) + ["gender:female"],
)
Expand All @@ -236,7 +238,9 @@ def test_nethack_random_character(self):
class TestNethackSomeObs:
@pytest.fixture
def game(self): # Make sure we close even on test failure.
g = nethack.Nethack(observation_keys=("program_state", "message", "internal"))
g = nethack.Nethack(
ttyrec=None, observation_keys=("program_state", "message", "internal")
)
try:
yield g
finally:
Expand Down Expand Up @@ -577,6 +581,7 @@ class TestNethackGlanceObservation:
@pytest.fixture
def game(self): # Make sure we close even on test failure.
g = nethack.Nethack(
ttyrec=None,
playername="MonkBot-mon-hum-neu-mal",
observation_keys=("screen_descriptions", "glyphs", "chars"),
)
Expand Down Expand Up @@ -628,6 +633,7 @@ class TestNethackTerminalObservation:
@pytest.fixture
def game(self): # Make sure we close even on test failure.
g = nethack.Nethack(
ttyrec=None,
playername="MonkBot-mon-hum-neu-mal",
observation_keys=(
"tty_chars",
Expand Down Expand Up @@ -681,7 +687,9 @@ class TestNethackMiscObservation:
@pytest.fixture
def game(self): # Make sure we close even on test failure.
g = nethack.Nethack(
playername="MonkBot-mon-hum-neu-mal", observation_keys=("misc", "internal")
ttyrec=None,
playername="MonkBot-mon-hum-neu-mal",
observation_keys=("misc", "internal"),
)
try:
yield g
Expand Down
Loading