Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion example/small_rooms_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ def step(self, action):
reward += 10

# Return (next_state, reward, terminal, info).
return self.current_state, reward, self.is_state_terminal(self.current_state), {}
return (
self.current_state,
reward,
self.is_state_terminal(self.current_state),
{},
)

def get_action_space(self):
# The agent has four actions (up, down, left, right).
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="simpleoptions",
version="0.10.0",
version="0.11.0",
author="Joshua Evans",
author_email="jbe25@bath.ac.uk",
description="A simple and flexible framework for working with Options in Reinforcement Learning.",
Expand Down
11 changes: 11 additions & 0 deletions simpleoptions/environment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import random

import numpy as np
import networkx as nx

from typing import List, Set
Expand Down Expand Up @@ -75,6 +76,16 @@ def render(self, mode: str = "human") -> None:
"""
pass

def seed(self, random_seed: int) -> None:
"""
Seed the environment's random number generator(s).

Args:
random_seed (int): The random seed to use for random number generation.
"""
random.seed(random_seed)
np.random.seed(random_seed)

@abstractmethod
def close(self):
"""
Expand Down
5 changes: 4 additions & 1 deletion simpleoptions/function_approximation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import sys

from simpleoptions.function_approximation.environment import ApproxBaseEnvironment, GymWrapper
from simpleoptions.function_approximation.environment import (
ApproxBaseEnvironment,
GymWrapper,
)
from simpleoptions.function_approximation.primitive_option import PrimitiveOption

__all__ = ["ApproxBaseEnvironment", "GymWrapper", "PrimitiveOption"]
Expand Down
15 changes: 15 additions & 0 deletions simpleoptions/function_approximation/environment.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import random

import numpy as np
import gymnasium as gym

Expand Down Expand Up @@ -59,6 +61,16 @@ def render(self) -> None:
"""
pass

def seed(self, random_seed: int) -> None:
"""
Seed the environment's random number generator(s).

Args:
random_seed (int): The random seed to use for random number generation.
"""
random.seed(random_seed)
np.random.seed(random_seed)

@abstractmethod
def close(self) -> None:
"""
Expand Down Expand Up @@ -218,6 +230,9 @@ def step(self, action: Hashable, state: Hashable = None) -> Tuple[Hashable, floa
def render(self) -> None:
return self.env.render()

def seed(self, random_seed: int) -> None:
return self.env.seed(random_seed)

def close(self) -> None:
return self.env.close()

Expand Down
19 changes: 15 additions & 4 deletions simpleoptions/implementations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
# Generic Option Generators.
from simpleoptions.implementations.generic_option_generator import GenericOptionGenerator
from simpleoptions.implementations.subgoal_option_generator import SubgoalOptionGenerator, SubgoalOption
from simpleoptions.implementations.generic_option_generator import (
GenericOptionGenerator,
)
from simpleoptions.implementations.subgoal_option_generator import (
SubgoalOptionGenerator,
SubgoalOption,
)

# Skill Discovery Algorithm Implementations.
from simpleoptions.implementations.eigenoptions import EigenoptionGenerator, Eigenoption
from simpleoptions.implementations.diffusion_options import DiffusionOptionGenerator, DiffusionOption
from simpleoptions.implementations.betweenness import BetweennessOptionGenerator, BetweennessOption
from simpleoptions.implementations.diffusion_options import (
DiffusionOptionGenerator,
DiffusionOption,
)
from simpleoptions.implementations.betweenness import (
BetweennessOptionGenerator,
BetweennessOption,
)


__all__ = [
Expand Down
10 changes: 8 additions & 2 deletions simpleoptions/implementations/betweenness.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,16 @@ def generate_options(
# Define options for reaching each subgoal.
options = [None for _ in range(len(subgoals))]
for i, subgoal in tqdm(enumerate(subgoals), desc="Training Betweeness Options..."):
initiation_set = sorted(list(nx.single_target_shortest_path_length(stg, subgoal)), key=lambda x: x[1])
initiation_set = sorted(
list(nx.single_target_shortest_path_length(stg, subgoal)),
key=lambda x: x[1],
)
initiation_set = list(list(zip(*initiation_set))[0])[1 : self.initiation_set_size + 1]
options[i] = BetweennessOption(
env=env, subgoal=subgoal, initiation_set=set(initiation_set), betweenness=centralities[subgoal]
env=env,
subgoal=subgoal,
initiation_set=set(initiation_set),
betweenness=centralities[subgoal],
)
self.train_option(options[i])

Expand Down
8 changes: 7 additions & 1 deletion simpleoptions/implementations/diffusion_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,13 @@ def _is_local_maxima(self, node: Hashable, stg: nx.Graph, centralities: Dict):


class DiffusionOption(SubgoalOption):
def __init__(self, env: BaseEnvironment, subgoal: Hashable, initiation_set: Set[Hashable], q_table: Dict = None):
def __init__(
self,
env: BaseEnvironment,
subgoal: Hashable,
initiation_set: Set[Hashable],
q_table: Dict = None,
):
super().__init__(env, subgoal, initiation_set, q_table)

def __str__(self):
Expand Down
8 changes: 7 additions & 1 deletion simpleoptions/implementations/subgoal_option_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,13 @@ def _select_action(self, state: Hashable, option: "SubgoalOption", q_table: Dict


class SubgoalOption(BaseOption):
def __init__(self, env: BaseEnvironment, subgoal: Hashable, initiation_set: Set[Hashable], q_table: Dict = None):
def __init__(
self,
env: BaseEnvironment,
subgoal: Hashable,
initiation_set: Set[Hashable],
q_table: Dict = None,
):
self.env = copy.copy(env)
self.subgoal = subgoal
self.initiation_set = initiation_set
Expand Down
14 changes: 9 additions & 5 deletions simpleoptions/options_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,11 @@ def run_agent(
if not epoch_eval and episodic_eval:
return self.training_log, self.episodic_evaluation_log
if epoch_eval and episodic_eval:
return self.training_log, self.epoch_evaluation_log, self.episodic_evaluation_log
return (
self.training_log,
self.epoch_evaluation_log,
self.episodic_evaluation_log,
)

else:
training_epoch_rewards = [
Expand Down Expand Up @@ -488,11 +492,11 @@ def test_policy(
}
for key, value in transition.items():
if not episodic_eval:
self.epoch_evaluation_log[f"evaluation_{eval_number}"][f"run_{test_run+1}"][key].append(
value
)
self.epoch_evaluation_log[f"evaluation_{eval_number}"][f"run_{test_run + 1}"][
key
].append(value)
else:
self.episodic_evaluation_log[f"evaluation_{eval_number}"][f"run_{test_run+1}"][
self.episodic_evaluation_log[f"evaluation_{eval_number}"][f"run_{test_run + 1}"][
key
].append(value)
# Reset environment and continue evaluation run.
Expand Down