diff --git a/entity_gym/env/action.py b/entity_gym/env/action.py
index c6d8fa3..9313666 100644
--- a/entity_gym/env/action.py
+++ b/entity_gym/env/action.py
@@ -20,6 +20,19 @@ def __len__(self) -> int:
         return len(self.index_to_label)
 
 
+@dataclass
+class ContinuousActionSpace:
+    """
+    Defines one continuous action that can be taken by multiple entities.
+    """
+
+    index_to_label: List[str]
+    """list of human-readable labels for each dimension of the action"""
+
+    def __len__(self) -> int:
+        return len(self.index_to_label)
+
+
 @dataclass
 class GlobalCategoricalActionSpace:
     """
@@ -47,7 +60,10 @@ class SelectEntityActionSpace:
 
 
 ActionSpace = Union[
-    CategoricalActionSpace, SelectEntityActionSpace, GlobalCategoricalActionSpace
+    CategoricalActionSpace,
+    ContinuousActionSpace,
+    SelectEntityActionSpace,
+    GlobalCategoricalActionSpace,
 ]
 
 
@@ -174,6 +190,27 @@ def labels(self) -> List[str]:
         return [self.index_to_label[i] for i in self.indices]
 
 
+@dataclass
+class ContinuousAction:
+    """
+    Outcome of a continuous action.
+    """
+
+    actors: Sequence[EntityID]
+    """the ids of the entities that chose the actions"""
+
+    values: npt.NDArray[np.float32]
+    """the continuous values that were chosen, one per actor, in the range [0, 1]"""
+
+    index_to_label: List[str]
+    """human readable labels for each dimension of the action"""
+
+    @property
+    def labels(self) -> List[str]:
+        """the human readable labels of the action dimensions"""
+        return self.index_to_label
+
+
 @dataclass
 class SelectEntityAction:
     """
diff --git a/entity_gym/env/env_list.py b/entity_gym/env/env_list.py
index 6ecccd4..d98ca4b 100644
--- a/entity_gym/env/env_list.py
+++ b/entity_gym/env/env_list.py
@@ -11,6 +11,8 @@
     CategoricalAction,
     CategoricalActionMask,
     CategoricalActionSpace,
+    ContinuousAction,
+    ContinuousActionSpace,
     EntityID,
     Environment,
     GlobalCategoricalAction,
@@ -118,6 +120,14 @@ def action_index_to_actions(
                 index_to_label=aspace.index_to_label,
                 probs=probs[atype] if probs is not None else None,
             )
+        elif isinstance(action_spaces[atype], ContinuousActionSpace):
+            # decode the raw int64 samples to floats in [0, 1]
+            values = action[index].as_array().reshape(-1) / np.iinfo(np.int64).max
+            _actions[atype] = ContinuousAction(
+                actors=actors,
+                values=values.astype(np.float32),
+                index_to_label=aspace.index_to_label,
+            )
         elif isinstance(action_spaces[atype], SelectEntityActionSpace):
             assert isinstance(mask, SelectEntityActionMask)
             if mask.actee_types is not None:
diff --git a/entity_gym/env/vec_env.py b/entity_gym/env/vec_env.py
index b03d0ca..1ba5d09 100644
--- a/entity_gym/env/vec_env.py
+++ b/entity_gym/env/vec_env.py
@@ -11,6 +11,7 @@
     ActionName,
     ActionSpace,
     CategoricalActionSpace,
+    ContinuousActionSpace,
     EntityName,
     GlobalCategoricalActionSpace,
     Observation,
@@ -293,6 +294,12 @@
             action_masks[action_name] = VecSelectEntityActionMask(
                 RaggedBufferI64(1), RaggedBufferI64(1)
             )
+        elif isinstance(space, ContinuousActionSpace):
+            # continuous actions reuse the categorical mask machinery
+            action_masks[action_name] = VecCategoricalActionMask(
+                RaggedBufferI64(1),
+                None,
+            )
         else:
             raise NotImplementedError(f"Action space {space} not supported")
     if global_entity:
@@ -377,10 +383,16 @@
                     action_masks[atype] = VecSelectEntityActionMask(
                         empty_ragged_i64(1, i), empty_ragged_i64(1, i)
                     )
+                elif isinstance(space, ContinuousActionSpace):
+                    action_masks[atype] = VecCategoricalActionMask(
+                        empty_ragged_i64(1, i), None
+                    )
                 else:
                     raise ValueError(f"Unknown action space type: {space}")
-            if isinstance(space, CategoricalActionSpace) or isinstance(
-                space, GlobalCategoricalActionSpace
+            if (
+                isinstance(space, CategoricalActionSpace)
+                or isinstance(space, GlobalCategoricalActionSpace)
+                or isinstance(space, ContinuousActionSpace)
             ):
                 vec_action = action_masks[atype]
                 assert isinstance(vec_action, VecCategoricalActionMask)
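The decoding convention used in env_list.py can be sanity-checked in isolation: continuous samples travel over the same int64 ragged buffers as categorical indices and are normalized by the largest int64. A minimal sketch (not part of the diff), assuming the samples are non-negative:

    import numpy as np

    # raw int64 samples as they would come out of the ragged action buffer
    raw = np.array([0, np.iinfo(np.int64).max // 2, np.iinfo(np.int64).max], dtype=np.int64)
    # same normalization as action_index_to_actions in env_list.py
    values = (raw / np.iinfo(np.int64).max).astype(np.float32)
    print(values)  # -> approximately [0.0, 0.5, 1.0]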
diff --git a/entity_gym/examples/__init__.py b/entity_gym/examples/__init__.py
index 6989ea4..edc97eb 100644
--- a/entity_gym/examples/__init__.py
+++ b/entity_gym/examples/__init__.py
@@ -2,6 +2,7 @@
 
 from entity_gym.env import Environment
 from entity_gym.examples.cherry_pick import CherryPick
+from entity_gym.examples.continuous_slider import ContinuousSlider
 from entity_gym.examples.count import Count
 from entity_gym.examples.floor_is_lava import FloorIsLava
 from entity_gym.examples.minefield import Minefield
@@ -12,6 +13,7 @@
 from entity_gym.examples.not_hotdog import NotHotdog
 from entity_gym.examples.pick_matching_balls import PickMatchingBalls
 from entity_gym.examples.rock_paper_scissors import RockPaperScissors
+from entity_gym.examples.slider import Slider
 from entity_gym.examples.tutorial import TreasureHunt
 from entity_gym.examples.xor import Xor
 
@@ -23,12 +25,14 @@
     "MultiSnake": MultiSnake,
     "MultiArmedBandit": MultiArmedBandit,
     "NotHotdog": NotHotdog,
-    "Xor": Xor,
     "Count": Count,
     "FloorIsLava": FloorIsLava,
     "MineSweeper": MineSweeper,
     "RockPaperScissors": RockPaperScissors,
     "TreasureHunt": TreasureHunt,
+    "Xor": Xor,
+    "Slider": Slider,
+    "ContinuousSlider": ContinuousSlider,
 }
 
 __all__ = [
@@ -45,4 +49,6 @@
     "MineSweeper",
     "RockPaperScissors",
     "TreasureHunt",
+    "Slider",
+    "ContinuousSlider",
 ]
diff --git a/entity_gym/examples/continuous_slider.py b/entity_gym/examples/continuous_slider.py
new file mode 100644
index 0000000..cf76377
--- /dev/null
+++ b/entity_gym/examples/continuous_slider.py
@@ -0,0 +1,83 @@
+import random
+from dataclasses import dataclass
+from typing import Dict, Mapping
+
+import numpy as np
+
+from entity_gym.env import (
+    Action,
+    ActionSpace,
+    CategoricalActionMask,
+    ContinuousAction,
+    ContinuousActionSpace,
+    Entity,
+    Environment,
+    Observation,
+    ObsSpace,
+)
+
+
+@dataclass
+class ContinuousSlider(Environment):
+    """
+    On each timestep, a "Slider" entity is placed at a random position between 0 and 1.
+    The "Player" sets a continuous value and is rewarded by 1 minus its distance to the slider.
+    """
+
+    def obs_space(self) -> ObsSpace:
+        return ObsSpace(
+            entities={
+                "Player": Entity(["step"]),
+                "Slider": Entity(["value"]),
+            }
+        )
+
+    def action_space(self) -> Dict[str, ActionSpace]:
+        return {
+            "set": ContinuousActionSpace(["slider_value"]),
+        }
+
+    def reset(self) -> Observation:
+        self.step = 0
+        self.slider = random.random()
+        return self.observe()
+
+    def act(self, actions: Mapping[str, Action]) -> Observation:
+        self.step += 1
+
+        a = actions["set"]
+        assert isinstance(a, ContinuousAction), f"{a} is not a ContinuousAction"
+        # values have already been decoded to floats between 0 and 1
+        val = a.values[0]
+        reward = 1 - abs(val - self.slider)
+        done = True
+
+        return self.observe(done, reward)
+
+    def observe(self, done: bool = False, reward: float = 0) -> Observation:
+        return Observation(
+            features={
+                "Player": np.array(
+                    [
+                        [
+                            self.step,
+                        ]
+                    ],
+                    dtype=np.float32,
+                ),
+                "Slider": np.array(
+                    [
+                        [
+                            self.slider,
+                        ]
+                    ],
+                    dtype=np.float32,
+                ),
+            },
+            actions={
+                "set": CategoricalActionMask(actor_ids=[0]),
+            },
+            ids={"Player": [0]},
+            reward=reward,
+            done=done,
+        )
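A minimal usage sketch for the new environment (not part of the diff): it constructs a ContinuousAction by hand instead of sampling from a policy, and assumes ContinuousAction is re-exported from entity_gym.env, as the example's own imports do:

    import numpy as np

    from entity_gym.env import ContinuousAction
    from entity_gym.examples.continuous_slider import ContinuousSlider

    env = ContinuousSlider()
    obs = env.reset()
    # the target position is visible as the "Slider" entity's single feature
    target = float(obs.features["Slider"][0][0])
    obs = env.act(
        {
            "set": ContinuousAction(
                actors=[0],
                values=np.array([target], dtype=np.float32),
                index_to_label=["slider_value"],
            )
        }
    )
    print(obs.reward)  # -> ~1.0, since reward is 1 - |value - target|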
+ """ + + def obs_space(self) -> ObsSpace: + return ObsSpace( + entities={ + "Player": Entity(["step"]), + "Slider": Entity(["value"]), + } + ) + + def action_space(self) -> Dict[str, ActionSpace]: + return { + "set": ContinuousActionSpace(["slider_value"]), + } + + def reset(self) -> Observation: + self.step = 0 + self.slider = random.random() + return self.observe() + + def act(self, actions: Mapping[str, Action]) -> Observation: + self.step += 1 + + a = actions["set"] + assert isinstance(a, ContinuousAction), f"{a} is not a CategoricalAction" + # divide index by max int64 to get a float between 0 and 1 + val = a.values[0] + reward = 1 - abs(val - self.slider) + done = True + + return self.observe(done, reward) + + def observe(self, done: bool = False, reward: float = 0) -> Observation: + return Observation( + features={ + "Player": np.array( + [ + [ + self.step, + ] + ], + dtype=np.float32, + ), + "Slider": np.array( + [ + [ + self.slider, + ] + ], + dtype=np.float32, + ), + }, + actions={ + "set": CategoricalActionMask(actor_ids=[0]), + }, + ids={"Player": [0]}, + reward=reward, + done=done, + ) diff --git a/entity_gym/examples/slider.py b/entity_gym/examples/slider.py new file mode 100644 index 0000000..9868fa7 --- /dev/null +++ b/entity_gym/examples/slider.py @@ -0,0 +1,82 @@ +import random +from dataclasses import dataclass +from typing import Dict, Mapping + +import numpy as np + +from entity_gym.env import ( + Action, + ActionSpace, + CategoricalAction, + CategoricalActionMask, + CategoricalActionSpace, + Entity, + Environment, + Observation, + ObsSpace, +) + + +@dataclass +class Slider(Environment): + """ + On each timestep, there is either a generic "Object" entity with a `is_hotdog` property, or a "Hotdog" object. + The "Player" entity is always present, and has an action to classify the other entity as hotdog or not hotdog. + """ + + def obs_space(self) -> ObsSpace: + return ObsSpace( + entities={ + "Player": Entity(["step"]), + "Slider": Entity(["value"]), + } + ) + + def action_space(self) -> Dict[str, ActionSpace]: + return { + "classify": CategoricalActionSpace([str(i / 1000) for i in range(1000)]), + } + + def reset(self) -> Observation: + self.step = 0 + self.slider = random.random() + return self.observe() + + def act(self, actions: Mapping[str, Action]) -> Observation: + self.step += 1 + + a = actions["classify"] + assert isinstance(a, CategoricalAction), f"{a} is not a CategoricalAction" + val = float(a.index_to_label[a.indices[0]]) + reward = 1 - abs(val - self.slider) + done = True + + return self.observe(done, reward) + + def observe(self, done: bool = False, reward: float = 0) -> Observation: + return Observation( + features={ + "Player": np.array( + [ + [ + self.step, + ] + ], + dtype=np.float32, + ), + "Slider": np.array( + [ + [ + self.slider, + ] + ], + dtype=np.float32, + ), + }, + actions={ + "classify": CategoricalActionMask(actor_ids=[0]), + }, + ids={"Player": [0]}, + reward=reward, + done=done, + )