
Support for Continuous action #24


Open · wants to merge 3 commits into base: main
39 changes: 38 additions & 1 deletion entity_gym/env/action.py
@@ -20,6 +20,19 @@ def __len__(self) -> int:
return len(self.index_to_label)


@dataclass
class ContinuousActionSpace:
Contributor: should this have a min/max value?

"""
Defines one continuous action that can be taken by multiple entities.
"""

index_to_label: List[str]
Contributor: why is there a discrete set of labels?

"""list of human-readable labels for each action"""

def __len__(self) -> int:
return len(self.index_to_label)


@dataclass
class GlobalCategoricalActionSpace:
"""
@@ -47,7 +60,10 @@ class SelectEntityActionSpace:


ActionSpace = Union[
-    CategoricalActionSpace, SelectEntityActionSpace, GlobalCategoricalActionSpace
+    CategoricalActionSpace,
+    ContinuousActionSpace,
+    SelectEntityActionSpace,
+    GlobalCategoricalActionSpace,
]


@@ -174,6 +190,27 @@ def labels(self) -> List[str]:
return [self.index_to_label[i] for i in self.indices]


@dataclass
class ContinuousAction:
"""
Outcome of a continuous action.
"""

actors: Sequence[EntityID]
"""the ids of the entities that chose the actions"""

values: npt.NDArray[np.int64]
Contributor: we should probably have the action converted to a float here

"""the indices of the actions that were chosen"""

index_to_label: List[str]
"""mapping from action indices to human readable labels"""

@property
def labels(self) -> List[str]:
Contributor: probably doesn't make sense for continuous action unless I'm missing something?

"""the human readable labels of the actions that were performed"""
return [self.index_to_label[i] for i in self.values]


@dataclass
class SelectEntityAction:
"""
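Taken together, the review comments above suggest a different shape for these two classes. A minimal sketch of that direction follows; the `low`/`high` bounds, the float-valued `values`, and the removal of the `labels` property are reviewer suggestions rather than code from this PR, and `EntityID` is stubbed out to keep the sketch self-contained.

from dataclasses import dataclass
from typing import List, Sequence

import numpy as np
import numpy.typing as npt

EntityID = int  # stand-in; the real alias lives in entity_gym.env


@dataclass
class ContinuousActionSpace:
    """Defines one continuous action that can be taken by multiple entities."""

    index_to_label: List[str]
    """one human-readable label per continuous output"""

    low: float = 0.0
    """assumed lower bound of each output (reviewer suggestion)"""

    high: float = 1.0
    """assumed upper bound of each output (reviewer suggestion)"""

    def __len__(self) -> int:
        return len(self.index_to_label)


@dataclass
class ContinuousAction:
    """Outcome of a continuous action."""

    actors: Sequence[EntityID]
    """the ids of the entities that chose the actions"""

    values: npt.NDArray[np.float32]
    """the chosen values, already scaled into [low, high]"""

    index_to_label: List[str]
    """mapping from output indices to human-readable labels"""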
9 changes: 9 additions & 0 deletions entity_gym/env/env_list.py
@@ -11,6 +11,8 @@
CategoricalAction,
CategoricalActionMask,
CategoricalActionSpace,
ContinuousAction,
ContinuousActionSpace,
EntityID,
Environment,
GlobalCategoricalAction,
@@ -118,6 +120,13 @@ def action_index_to_actions(
index_to_label=aspace.index_to_label,
probs=probs[atype] if probs is not None else None,
)
elif isinstance(action_spaces[atype], ContinuousActionSpace):
values = action[index].as_array().reshape(-1) / np.iinfo(np.int64).max
_actions[atype] = ContinuousAction(
actors=actors,
values=values,
index_to_label=aspace.index_to_label,
)
elif isinstance(action_spaces[atype], SelectEntityActionSpace):
assert isinstance(mask, SelectEntityActionMask)
if mask.actee_types is not None:
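The new branch divides the raw action by `np.iinfo(np.int64).max`, which maps a non-negative int64 sample onto a float in [0, 1]. A standalone illustration of that scaling (the sample values are invented for the demo):

import numpy as np

# raw non-negative int64 samples: 0, 2**62, and 2**63 - 1 (the int64 maximum)
raw = np.array([0, 4611686018427387904, 9223372036854775807], dtype=np.int64)
scaled = raw / np.iinfo(np.int64).max
print(scaled)  # [0.  0.5 1. ]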
16 changes: 14 additions & 2 deletions entity_gym/env/vec_env.py
@@ -11,6 +11,7 @@
ActionName,
ActionSpace,
CategoricalActionSpace,
ContinuousActionSpace,
EntityName,
GlobalCategoricalActionSpace,
Observation,
@@ -293,6 +294,11 @@ def batch_obs(
action_masks[action_name] = VecSelectEntityActionMask(
RaggedBufferI64(1), RaggedBufferI64(1)
)
elif isinstance(space, ContinuousActionSpace):
action_masks[action_name] = VecCategoricalActionMask(
RaggedBufferI64(1),
None,
)
else:
raise NotImplementedError(f"Action space {space} not supported")
if global_entity:
@@ -377,10 +383,16 @@ def batch_obs(
action_masks[atype] = VecSelectEntityActionMask(
empty_ragged_i64(1, i), empty_ragged_i64(1, i)
)
elif isinstance(space, ContinuousActionSpace):
action_masks[atype] = VecCategoricalActionMask(
empty_ragged_i64(1, i), None
)
else:
raise ValueError(f"Unknown action space type: {space}")
-        if isinstance(space, CategoricalActionSpace) or isinstance(
-            space, GlobalCategoricalActionSpace
+        if (
+            isinstance(space, CategoricalActionSpace)
+            or isinstance(space, GlobalCategoricalActionSpace)
+            or isinstance(space, ContinuousActionSpace)
):
vec_action = action_masks[atype]
assert isinstance(vec_action, VecCategoricalActionMask)
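Since `isinstance` accepts a tuple of classes, the three-way chain added above could be collapsed into a single check. A small stylistic sketch, assuming the three space classes are importable from `entity_gym.env` the way the example environments import them:

from entity_gym.env import (
    CategoricalActionSpace,
    ContinuousActionSpace,
    GlobalCategoricalActionSpace,
)


def uses_categorical_mask(space: object) -> bool:
    # all three space types share VecCategoricalActionMask in batch_obs
    return isinstance(
        space,
        (CategoricalActionSpace, GlobalCategoricalActionSpace, ContinuousActionSpace),
    )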
7 changes: 6 additions & 1 deletion entity_gym/examples/__init__.py
@@ -2,6 +2,7 @@

from entity_gym.env import Environment
from entity_gym.examples.cherry_pick import CherryPick
from entity_gym.examples.continuous_slider import ContinuousSlider
from entity_gym.examples.count import Count
from entity_gym.examples.floor_is_lava import FloorIsLava
from entity_gym.examples.minefield import Minefield
@@ -12,6 +13,7 @@
from entity_gym.examples.not_hotdog import NotHotdog
from entity_gym.examples.pick_matching_balls import PickMatchingBalls
from entity_gym.examples.rock_paper_scissors import RockPaperScissors
from entity_gym.examples.slider import Slider
from entity_gym.examples.tutorial import TreasureHunt
from entity_gym.examples.xor import Xor

@@ -23,12 +25,14 @@
"MultiSnake": MultiSnake,
"MultiArmedBandit": MultiArmedBandit,
"NotHotdog": NotHotdog,
"Xor": Xor,
"Count": Count,
"FloorIsLava": FloorIsLava,
"MineSweeper": MineSweeper,
"RockPaperScissors": RockPaperScissors,
"TreasureHunt": TreasureHunt,
"Xor": Xor,
"Slider": Slider,
"ContinuousSlider": ContinuousSlider,
}

__all__ = [
@@ -45,4 +49,5 @@
"MineSweeper",
"RockPaperScissors",
"TreasureHunt",
"Slider",
Contributor: Missing "ContinuousSlider"?

]
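Applying the reviewer's catch, the tail of `__all__` would presumably read:

__all__ = [
    # ... earlier entries unchanged ...
    "MineSweeper",
    "RockPaperScissors",
    "TreasureHunt",
    "Slider",
    "ContinuousSlider",
]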
83 changes: 83 additions & 0 deletions entity_gym/examples/continuous_slider.py
@@ -0,0 +1,83 @@
import random
from dataclasses import dataclass
from typing import Dict, Mapping

import numpy as np

from entity_gym.env import (
Action,
ActionSpace,
CategoricalActionMask,
ContinuousAction,
ContinuousActionSpace,
Entity,
Environment,
Observation,
ObsSpace,
)


@dataclass
class ContinuousSlider(Environment):
"""
On each timestep, a "Slider" entity has a random `value` between 0 and 1.
The "Player" entity sets a continuous value and is rewarded by how closely it matches the slider's value.
"""
Contributor: update comment

def obs_space(self) -> ObsSpace:
return ObsSpace(
entities={
"Player": Entity(["step"]),
"Slider": Entity(["value"]),
}
)

def action_space(self) -> Dict[str, ActionSpace]:
return {
"set": ContinuousActionSpace(["slider_value"]),
}

def reset(self) -> Observation:
self.step = 0
self.slider = random.random()
return self.observe()

def act(self, actions: Mapping[str, Action]) -> Observation:
self.step += 1

a = actions["set"]
assert isinstance(a, ContinuousAction), f"{a} is not a ContinuousAction"
# the framework has already divided the raw int64 action by the maximum int64, yielding a float between 0 and 1
Contributor: misleading comment? this should also be done automatically by the framework

val = a.values[0]
reward = 1 - abs(val - self.slider)
done = True

return self.observe(done, reward)

def observe(self, done: bool = False, reward: float = 0) -> Observation:
return Observation(
features={
"Player": np.array(
[
[
self.step,
]
],
dtype=np.float32,
),
"Slider": np.array(
[
[
self.slider,
]
],
dtype=np.float32,
),
},
actions={
"set": CategoricalActionMask(actor_ids=[0]),
},
ids={"Player": [0]},
reward=reward,
done=done,
)
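For reference, a minimal interaction with the new environment. This sketch assumes the `values` passed to `ContinuousAction` are already floats in [0, 1], which is the convention the conversion in env_list.py produces:

import numpy as np

from entity_gym.env import ContinuousAction
from entity_gym.examples import ContinuousSlider

env = ContinuousSlider()
obs = env.reset()
target = obs.features["Slider"][0][0]  # the random value to match

obs = env.act(
    {
        "set": ContinuousAction(
            actors=[0],
            values=np.array([0.7]),
            index_to_label=["slider_value"],
        )
    }
)
print(obs.reward)  # 1 - abs(0.7 - target)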
82 changes: 82 additions & 0 deletions entity_gym/examples/slider.py
@@ -0,0 +1,82 @@
import random
from dataclasses import dataclass
from typing import Dict, Mapping

import numpy as np

from entity_gym.env import (
Action,
ActionSpace,
CategoricalAction,
CategoricalActionMask,
CategoricalActionSpace,
Entity,
Environment,
Observation,
ObsSpace,
)


@dataclass
class Slider(Environment):
"""
On each timestep, a "Slider" entity has a random `value` between 0 and 1.
The "Player" entity picks one of 1000 discretized values and is rewarded by how closely it matches the slider's value.
"""
Contributor: update comment
Contributor: could probably unify the slider and continuous slider environments by passing in a bool that determines whether we use continuous or discrete action

def obs_space(self) -> ObsSpace:
return ObsSpace(
entities={
"Player": Entity(["step"]),
"Slider": Entity(["value"]),
}
)

def action_space(self) -> Dict[str, ActionSpace]:
return {
"classify": CategoricalActionSpace([str(i / 1000) for i in range(1000)]),
}

def reset(self) -> Observation:
self.step = 0
self.slider = random.random()
return self.observe()

def act(self, actions: Mapping[str, Action]) -> Observation:
self.step += 1

a = actions["classify"]
assert isinstance(a, CategoricalAction), f"{a} is not a CategoricalAction"
val = float(a.index_to_label[a.indices[0]])
reward = 1 - abs(val - self.slider)
done = True

return self.observe(done, reward)

def observe(self, done: bool = False, reward: float = 0) -> Observation:
return Observation(
features={
"Player": np.array(
[
[
self.step,
]
],
dtype=np.float32,
),
"Slider": np.array(
[
[
self.slider,
]
],
dtype=np.float32,
),
},
actions={
"classify": CategoricalActionMask(actor_ids=[0]),
},
ids={"Player": [0]},
reward=reward,
done=done,
)
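Following the unification suggestion in the review comment above, the two environments could plausibly share one class behind a flag. A hedged sketch; `UnifiedSlider`, the `continuous` field, and the `value_of` helper are hypothetical names, not code from this PR:

from dataclasses import dataclass
from typing import Dict, Union

from entity_gym.env import (
    ActionSpace,
    CategoricalAction,
    CategoricalActionSpace,
    ContinuousAction,
    ContinuousActionSpace,
)


@dataclass
class UnifiedSlider:
    """Slider task whose action space is selected by a constructor flag."""

    continuous: bool = False

    def action_space(self) -> Dict[str, ActionSpace]:
        if self.continuous:
            return {"set": ContinuousActionSpace(["slider_value"])}
        # discrete fallback: 1000 evenly spaced values, as in Slider above
        return {"set": CategoricalActionSpace([str(i / 1000) for i in range(1000)])}

    def value_of(self, a: Union[CategoricalAction, ContinuousAction]) -> float:
        # act() would branch on the action type when decoding the chosen value
        if isinstance(a, ContinuousAction):
            return float(a.values[0])
        return float(a.index_to_label[a.indices[0]])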