Support for Continuous action #24
@@ -20,6 +20,19 @@ def __len__(self) -> int:
        return len(self.index_to_label)


+@dataclass
+class ContinuousActionSpace:
+    """
+    Defines one continuous action that can be taken by multiple entities.
+    """
+
+    index_to_label: List[str]

Review comment: why is there a discrete set of labels?

+    """list of human-readable labels for each action"""
+
+    def __len__(self) -> int:
+        return len(self.index_to_label)
+
+
@dataclass
class GlobalCategoricalActionSpace:
    """
@@ -47,7 +60,10 @@ class SelectEntityActionSpace:


ActionSpace = Union[
-    CategoricalActionSpace, SelectEntityActionSpace, GlobalCategoricalActionSpace
+    CategoricalActionSpace,
+    ContinuousActionSpace,
+    SelectEntityActionSpace,
+    GlobalCategoricalActionSpace,
]

@@ -174,6 +190,27 @@ def labels(self) -> List[str]:
        return [self.index_to_label[i] for i in self.indices]


+@dataclass
+class ContinuousAction:
+    """
+    Outcome of a continuous action.
+    """
+
+    actors: Sequence[EntityID]
+    """the ids of the entities that chose the actions"""
+
+    values: npt.NDArray[np.int64]

Review comment: we should probably have the action converted to a float here

+    """the indices of the actions that were chosen"""
+
+    index_to_label: List[str]
+    """mapping from action indices to human readable labels"""
+
+    @property
+    def labels(self) -> List[str]:

Review comment: probably doesn't make sense for continuous action unless I'm missing something?

+        """the human readable labels of the actions that were performed"""
+        return [self.index_to_label[i] for i in self.values]
+
+
@dataclass
class SelectEntityAction:
    """
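The review comments above question both the int64-typed values and the per-index labels on ContinuousAction. A minimal sketch of what a float-valued variant could look like, purely illustrative: the FloatContinuousAction name and the EntityID alias are assumptions, not part of this PR.

# Illustrative only: a float-valued take on ContinuousAction, following the
# review comments above. FloatContinuousAction and the EntityID alias are
# assumed names, not names from this PR.
from dataclasses import dataclass
from typing import List, Sequence

import numpy as np
import numpy.typing as npt

EntityID = int  # assumed to be a plain int here; entity_gym may define it differently


@dataclass
class FloatContinuousAction:
    actors: Sequence[EntityID]
    """the ids of the entities that chose the actions"""

    values: npt.NDArray[np.float32]
    """one float per actor, already rescaled into the action's range"""

    labels: List[str]
    """one human-readable label per continuous dimension, not per discrete index"""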
@@ -2,6 +2,7 @@

from entity_gym.env import Environment
from entity_gym.examples.cherry_pick import CherryPick
+from entity_gym.examples.contious_slider import ContinuousSlider
from entity_gym.examples.count import Count
from entity_gym.examples.floor_is_lava import FloorIsLava
from entity_gym.examples.minefield import Minefield
@@ -12,6 +13,7 @@

from entity_gym.examples.not_hotdog import NotHotdog
from entity_gym.examples.pick_matching_balls import PickMatchingBalls
from entity_gym.examples.rock_paper_scissors import RockPaperScissors
+from entity_gym.examples.slider import Slider
from entity_gym.examples.tutorial import TreasureHunt
from entity_gym.examples.xor import Xor
@@ -23,12 +25,14 @@
    "MultiSnake": MultiSnake,
    "MultiArmedBandit": MultiArmedBandit,
    "NotHotdog": NotHotdog,
-    "Xor": Xor,
    "Count": Count,
    "FloorIsLava": FloorIsLava,
    "MineSweeper": MineSweeper,
    "RockPaperScissors": RockPaperScissors,
    "TreasureHunt": TreasureHunt,
+    "Xor": Xor,
+    "Slider": Slider,
+    "ContinuousSlider": ContinuousSlider,
}

__all__ = [
@@ -45,4 +49,5 @@
    "MineSweeper",
    "RockPaperScissors",
    "TreasureHunt",
+    "Slider",

Review comment: Missing "ContinuousSlider"?

]
@@ -0,0 +1,83 @@
import random
from dataclasses import dataclass
from typing import Dict, Mapping

import numpy as np

from entity_gym.env import (
    Action,
    ActionSpace,
    CategoricalActionMask,
    ContinuousAction,
    ContinuousActionSpace,
    Entity,
    Environment,
    Observation,
    ObsSpace,
)


@dataclass
class ContinuousSlider(Environment):
    """
    On each timestep, there is either a generic "Object" entity with a `is_hotdog` property, or a "Hotdog" object.

Review comment: update comment

    The "Player" entity is always present, and has an action to classify the other entity as hotdog or not hotdog.
    """

    def obs_space(self) -> ObsSpace:
        return ObsSpace(
            entities={
                "Player": Entity(["step"]),
                "Slider": Entity(["value"]),
            }
        )

    def action_space(self) -> Dict[str, ActionSpace]:
        return {
            "set": ContinuousActionSpace(["slider_value"]),
        }

    def reset(self) -> Observation:
        self.step = 0
        self.slider = random.random()
        return self.observe()

    def act(self, actions: Mapping[str, Action]) -> Observation:
        self.step += 1

        a = actions["set"]
        assert isinstance(a, ContinuousAction), f"{a} is not a ContinuousAction"
        # divide index by max int64 to get a float between 0 and 1

Review comment: misleading comment? this should also be done automatically by the framework

        val = a.values[0]
        reward = 1 - abs(val - self.slider)
        done = True

        return self.observe(done, reward)

    def observe(self, done: bool = False, reward: float = 0) -> Observation:
        return Observation(
            features={
                "Player": np.array(
                    [
                        [
                            self.step,
                        ]
                    ],
                    dtype=np.float32,
                ),
                "Slider": np.array(
                    [
                        [
                            self.slider,
                        ]
                    ],
                    dtype=np.float32,
                ),
            },
            actions={
                "set": CategoricalActionMask(actor_ids=[0]),
            },
            ids={"Player": [0]},
            reward=reward,
            done=done,
        )
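The in-code comment above says the raw value is meant to be an int64 sample divided by the maximum int64 to land in [0, 1], and the review suggests the framework should eventually do this automatically. A minimal sketch of that conversion as a standalone helper, assuming the raw samples really are non-negative int64 values; the helper name is illustrative, not an entity_gym API.

# Sketch of the int64 -> [0, 1] conversion hinted at above; `to_unit_interval`
# is an assumed helper name. Assumes non-negative int64 samples.
import numpy as np
import numpy.typing as npt


def to_unit_interval(raw: npt.NDArray[np.int64]) -> npt.NDArray[np.float64]:
    """Map raw int64 action samples onto floats in the unit interval."""
    return raw.astype(np.float64) / np.iinfo(np.int64).max


# to_unit_interval(np.array([0, np.iinfo(np.int64).max], dtype=np.int64))
# ≈ array([0.0, 1.0])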
@@ -0,0 +1,82 @@
import random
from dataclasses import dataclass
from typing import Dict, Mapping

import numpy as np

from entity_gym.env import (
    Action,
    ActionSpace,
    CategoricalAction,
    CategoricalActionMask,
    CategoricalActionSpace,
    Entity,
    Environment,
    Observation,
    ObsSpace,
)


@dataclass
class Slider(Environment):
    """
    On each timestep, there is either a generic "Object" entity with a `is_hotdog` property, or a "Hotdog" object.
    The "Player" entity is always present, and has an action to classify the other entity as hotdog or not hotdog.

Review comment: update comment

    """

Review comment: could probably unify the slider and continuous slider environments by passing in a bool that determines whether we use continuous or discrete action

(A sketch of this unification follows after the file.)

    def obs_space(self) -> ObsSpace:
        return ObsSpace(
            entities={
                "Player": Entity(["step"]),
                "Slider": Entity(["value"]),
            }
        )

    def action_space(self) -> Dict[str, ActionSpace]:
        return {
            "classify": CategoricalActionSpace([str(i / 1000) for i in range(1000)]),
        }

    def reset(self) -> Observation:
        self.step = 0
        self.slider = random.random()
        return self.observe()

    def act(self, actions: Mapping[str, Action]) -> Observation:
        self.step += 1

        a = actions["classify"]
        assert isinstance(a, CategoricalAction), f"{a} is not a CategoricalAction"
        val = float(a.index_to_label[a.indices[0]])
        reward = 1 - abs(val - self.slider)
        done = True

        return self.observe(done, reward)

    def observe(self, done: bool = False, reward: float = 0) -> Observation:
        return Observation(
            features={
                "Player": np.array(
                    [
                        [
                            self.step,
                        ]
                    ],
                    dtype=np.float32,
                ),
                "Slider": np.array(
                    [
                        [
                            self.slider,
                        ]
                    ],
                    dtype=np.float32,
                ),
            },
            actions={
                "classify": CategoricalActionMask(actor_ids=[0]),
            },
            ids={"Player": [0]},
            reward=reward,
            done=done,
        )
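As the review comment in the file above suggests, the two environments could be collapsed into one by a constructor flag. A minimal sketch under that assumption, reusing only the entity_gym names that already appear in this PR; the UnifiedSlider class name and the shared "set" action name are illustrative choices, and the continuous branch assumes the value arrives already rescaled to [0, 1].

# Hypothetical unified environment combining Slider and ContinuousSlider via a
# `continuous` flag, as suggested in the review. Not part of this PR.
import random
from dataclasses import dataclass
from typing import Dict, Mapping

import numpy as np

from entity_gym.env import (
    Action,
    ActionSpace,
    CategoricalAction,
    CategoricalActionMask,
    CategoricalActionSpace,
    ContinuousAction,
    ContinuousActionSpace,
    Entity,
    Environment,
    Observation,
    ObsSpace,
)


@dataclass
class UnifiedSlider(Environment):
    """Slider task with either a continuous or a discretised `set` action."""

    continuous: bool = True

    def obs_space(self) -> ObsSpace:
        return ObsSpace(
            entities={"Player": Entity(["step"]), "Slider": Entity(["value"])}
        )

    def action_space(self) -> Dict[str, ActionSpace]:
        if self.continuous:
            return {"set": ContinuousActionSpace(["slider_value"])}
        return {"set": CategoricalActionSpace([str(i / 1000) for i in range(1000)])}

    def reset(self) -> Observation:
        self.step = 0
        self.slider = random.random()
        return self.observe()

    def act(self, actions: Mapping[str, Action]) -> Observation:
        self.step += 1
        a = actions["set"]
        if isinstance(a, ContinuousAction):
            # assumes the framework already rescaled the raw sample to [0, 1]
            val = float(a.values[0])
        elif isinstance(a, CategoricalAction):
            val = float(a.index_to_label[a.indices[0]])
        else:
            raise TypeError(f"unexpected action type: {type(a)}")
        reward = 1 - abs(val - self.slider)
        return self.observe(done=True, reward=reward)

    def observe(self, done: bool = False, reward: float = 0) -> Observation:
        return Observation(
            features={
                "Player": np.array([[self.step]], dtype=np.float32),
                "Slider": np.array([[self.slider]], dtype=np.float32),
            },
            actions={"set": CategoricalActionMask(actor_ids=[0])},
            ids={"Player": [0]},
            reward=reward,
            done=done,
        )

Whether the rescaling of the continuous value happens inside the environment or in the framework remains the open question from the earlier review comment.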
Review comment: should this have a min/max value?
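One way to answer the question above would be to carry explicit bounds on the space itself. A minimal sketch of a bounded variant, with assumed names (BoundedContinuousActionSpace, low, high are not part of this PR):

# Illustrative bounded variant of ContinuousActionSpace; names are assumptions.
from dataclasses import dataclass
from typing import List


@dataclass
class BoundedContinuousActionSpace:
    labels: List[str]
    """one human-readable label per continuous action dimension"""

    low: float = 0.0
    """smallest value the environment will accept"""

    high: float = 1.0
    """largest value the environment will accept"""

    def __len__(self) -> int:
        return len(self.labels)

    def clip(self, value: float) -> float:
        """Clamp a proposed value into [low, high]."""
        return min(max(value, self.low), self.high)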