Multi Agent #23

Open · wants to merge 1 commit into base: main

2 changes: 2 additions & 0 deletions entity_gym/env/__init__.py
@@ -30,6 +30,7 @@
from .environment import *
from .parallel_env_list import *
from .validator import ValidatingEnv
from .multi_validator import MultiValidatingEnv
from .vec_env import *

__all__ = [
Expand Down Expand Up @@ -64,5 +65,6 @@
"VecSelectEntityActionMask",
# Wrappers
"ValidatingEnv",
"MultiValidatingEnv",
"AddMetricsWrapper",
]
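With the wrapper exported, callers can pick the matching validator at runtime. A minimal sketch, assuming `MultiValidatingEnv` takes the wrapped environment as its only constructor argument, mirroring `ValidatingEnv` (its constructor is not shown in this diff):

```python
# Assumption: MultiValidatingEnv(env) mirrors ValidatingEnv(env); the
# constructor itself is not part of this diff.
from entity_gym.env import MultiValidatingEnv, ValidatingEnv
from entity_gym.env.environment import MultiAgentEnvironment


def wrap_with_validator(env):
    """Pick the validator that matches the environment's agent model."""
    if isinstance(env, MultiAgentEnvironment):
        return MultiValidatingEnv(env)
    return ValidatingEnv(env)
```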
94 changes: 88 additions & 6 deletions entity_gym/env/add_metrics_wrapper.py
@@ -4,7 +4,7 @@
import numpy.typing as npt
from ragged_buffer import RaggedBufferI64

from entity_gym.env.environment import ActionName, ActionSpace, ObsSpace
from entity_gym.env.environment import ActionName, ActionSpace, ObsSpace, MultiAgentEnvironment
from entity_gym.env.vec_env import Metric, VecEnv, VecObs


@@ -25,17 +25,38 @@ def __init__(
self.entity_types = list(env.obs_space().entities.keys())
if env.has_global_entity():
self.entity_types.append("__global__")
self.total_reward = np.zeros(len(env), dtype=np.float32)
self.total_steps = np.zeros(len(env), dtype=np.int64)
self.filter = np.ones(len(env), dtype=np.bool8) if filter is None else filter
        # isinstance is the idiomatic way to detect a multi-agent env.
        if isinstance(env.envs[0], MultiAgentEnvironment):
            self.is_multi_player = True
            # One reward accumulator per (player, env) pair; all players
            # share the same episode step counter.
            self.total_reward = np.zeros(
                (env.envs[0].get_num_players(), len(env)), dtype=np.float32
            )
        else:
            self.is_multi_player = False
            self.total_reward = np.zeros(len(env), dtype=np.float32)
        self.total_steps = np.zeros(len(env), dtype=np.int64)
        self.filter = np.ones(len(env), dtype=np.bool8) if filter is None else filter

def reset(self, obs_config: ObsSpace) -> VecObs:
return self.track_metrics(self.env.reset(obs_config))
        if self.is_multi_player:
            # In multi-agent mode, reset returns one VecObs per player.
            return [
                self.multi_track_metrics(i, obs)
                for i, obs in enumerate(self.env.reset(obs_config))
            ]
        else:
            return self.track_metrics(self.env.reset(obs_config))

def act(
self, actions: Mapping[ActionName, RaggedBufferI64], obs_filter: ObsSpace
) -> VecObs:
return self.track_metrics(self.env.act(actions, obs_filter))
        if self.is_multi_player:
            return [
                self.multi_track_metrics(i, obs)
                for i, obs in enumerate(self.env.act(actions, obs_filter))
            ]
        else:
            return self.track_metrics(self.env.act(actions, obs_filter))

def render(self, **kwargs: Any) -> npt.NDArray[np.uint8]:
return self.env.render(**kwargs)
Expand All @@ -46,6 +67,67 @@ def __len__(self) -> int:
def close(self) -> None:
self.env.close()

    def multi_track_metrics(self, player_id: int, obs: VecObs) -> VecObs:
        """
        Tracks metrics for a single player's observation batch.

        :param player_id: index of the player this observation belongs to
        :param obs: the player's VecObs
        """
        # Track reward and total steps
        self.total_reward[player_id] += obs.reward
        if player_id == 0:  # all players share the step counter; tick it once
            self.total_steps += 1
episodic_reward = Metric()
episodic_length = Metric()
count = len(self.total_steps)
obs.metrics["step"] = Metric(
sum=self.total_steps.sum(),
count=count,
min=self.total_steps.min(),
max=self.total_steps.max(),
)

        for entity in self.entity_types:
            if entity in obs.features:
                # items() is the total entity count across all envs; size1()
                # gives the per-env counts used for min/max.
                _sum = obs.features[entity].items()
                counts = obs.features[entity].size1()
                _min = counts.min()
                _max = counts.max()
            else:
                _sum = 0
                _min = 0
                _max = 0
            obs.metrics[f"entity_count/{entity}"] = Metric(
                sum=_sum, count=count, min=_min, max=_max
            )
if len(obs.features) > 0:
combined_counts: Any = sum(
features.size1() for features in obs.features.values()
)
else:
combined_counts = np.zeros(count, dtype=np.int64)
obs.metrics["entity_count"] = Metric(
sum=combined_counts.sum(),
count=count,
min=combined_counts.min(),
max=combined_counts.max(),
)

        # Track episodic reward and length for each player
        for i_env in np.arange(len(self))[obs.done & self.filter]:
            episodic_reward.push(self.total_reward[player_id, i_env])
            episodic_length.push(self.total_steps[i_env])
            self.total_reward[player_id, i_env] = 0.0
            if player_id == self.total_reward.shape[0] - 1:
                # Reset the shared step counter only after the last player
                # has recorded its episode length; resetting earlier would
                # report a length of 0 for the remaining players.
                self.total_steps[i_env] = 0
obs.metrics["episodic_reward"] = episodic_reward
obs.metrics["episode_length"] = episodic_length
obs.metrics["reward"] = Metric(
sum=obs.reward.sum(),
count=obs.reward.size,
min=obs.reward.min(),
max=obs.reward.max(),
)
return obs

def track_metrics(self, obs: VecObs) -> VecObs:
self.total_reward += obs.reward
self.total_steps += 1
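To make the new bookkeeping concrete, here is a standalone numpy sketch with made-up reward values. Shapes follow the wrapper: one reward row per (player, env) pair, and a single per-env step counter shared by all players:

```python
import numpy as np

num_players, num_envs = 2, 4
total_reward = np.zeros((num_players, num_envs), dtype=np.float32)
total_steps = np.zeros(num_envs, dtype=np.int64)

# One joint step: each player receives its own vector of per-env rewards.
step_rewards = {
    0: np.array([0.5, 0.0, 1.0, 0.25], dtype=np.float32),
    1: np.array([0.0, 1.0, 0.0, 0.75], dtype=np.float32),
}
for player_id in range(num_players):
    total_reward[player_id] += step_rewards[player_id]
    if player_id == 0:
        total_steps += 1  # the episode clock ticks once per joint step

# When an env finishes, each player's episodic reward comes from its own row,
# while every player reports the same episode length.
done = np.array([False, True, False, False])
for i_env in np.arange(num_envs)[done]:
    for player_id in range(num_players):
        print(
            f"env {i_env}, player {player_id}: "
            f"reward {total_reward[player_id, i_env]}, "
            f"length {total_steps[i_env]}"
        )
    total_reward[:, i_env] = 0.0
    total_steps[i_env] = 0
```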
56 changes: 55 additions & 1 deletion entity_gym/env/env_list.py
@@ -21,7 +21,7 @@
SelectEntityActionMask,
SelectEntityActionSpace,
)
from entity_gym.env.vec_env import VecEnv, VecObs, batch_obs
from entity_gym.env.vec_env import VecEnv, VecObs, batch_obs, multi_batch_os


class EnvList(VecEnv):
@@ -77,6 +77,60 @@ def action_space(self) -> Dict[ActionName, ActionSpace]:
return self._action_space


class MultiEnvList(VecEnv):
def __init__(self, create_env: Callable[[], Environment], num_envs: int):
self.envs = [create_env() for _ in range(num_envs)]
        # last_obs holds, for each env, one Observation per player.
        self.last_obs: List[List[Observation]] = []
        env = self.envs[0] if num_envs > 0 else create_env()
        self._obs_space = env.obs_space()
        self._action_space = env.action_space()

    def reset(self, obs_space: ObsSpace) -> VecObs:
        return self._batch_obs([e.reset_filter(obs_space) for e in self.envs])

    def render(self, **kwargs: Any) -> npt.NDArray[np.uint8]:
        return np.stack([e.render(**kwargs) for e in self.envs])

    def close(self) -> None:
        for env in self.envs:
            env.close()

    def act(
self, actions: Mapping[int, Mapping[str, RaggedBufferI64]], obs_space: ObsSpace
) -> VecObs:
"""
:param actions: Dict of player_id to Dict of action_name to RaggedBufferI64
:param obs_space: ObsSpace
"""
observations = []
action_spaces = self.action_space()
        for i_env, env in enumerate(self.envs):
            # Only the current player's actions are needed to step the env;
            # the environment must implement get_current_player().
            current_player = env.get_current_player()
            _actions = action_index_to_actions(
                self._obs_space,
                action_spaces,
                actions[current_player],
                self.last_obs[i_env][current_player],
                i_env,
            )
            obs = env.act_filter(_actions, obs_space)

            if any(o.done for o in obs):
                # Auto-reset the env, but carry over each player's terminal
                # reward and metrics so they are not lost.
                new_obs = env.reset_filter(obs_space)
                for player_id, new_ob in enumerate(new_obs):
                    new_ob.done = True
                    new_ob.reward = obs[player_id].reward
                    new_ob.metrics = obs[player_id].metrics
                observations.append(new_obs)
            else:
                observations.append(obs)
        return self._batch_obs(observations)

    def _batch_obs(self, obs: List[List[Observation]]) -> VecObs:
        self.last_obs = obs
        return multi_batch_os(obs, self.obs_space(), self.action_space())

    def __len__(self) -> int:
        return len(self.envs)

    def obs_space(self) -> ObsSpace:
        return self._obs_space

    def action_space(self) -> Dict[ActionName, ActionSpace]:
        return self._action_space

def action_index_to_actions(
obs_space: ObsSpace,
action_spaces: Dict[ActionName, ActionSpace],
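To make the stepping contract concrete, a hedged usage sketch of `MultiEnvList` follows; `MyTurnBasedEnv` is a hypothetical subclass (a skeleton appears after the environment.py diff below) and `actions` stands for already-selected per-player action buffers:

```python
# Hypothetical usage; MyTurnBasedEnv and `actions` are placeholders, not part
# of this PR.
from entity_gym.env.env_list import MultiEnvList

envs = MultiEnvList(lambda: MyTurnBasedEnv(), num_envs=8)
obs = envs.reset(envs.obs_space())

# actions maps player_id -> {action_name -> RaggedBufferI64}; act() consumes
# only the current player's entry when stepping each environment.
obs = envs.act(actions, envs.obs_space())
```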
109 changes: 109 additions & 0 deletions entity_gym/env/environment.py
@@ -294,3 +294,112 @@ def _compile_feature_filter(self, obs_space: ObsSpace) -> Dict[str, np.ndarray]:
dtype=np.int32,
)
return feature_selection


class MultiAgentEnvironment(ABC):
    """
    Abstract base class for multi-agent environments.
    """

@abstractmethod
def obs_space(self) -> ObsSpace:
"""
Defines the shape of observations returned by the environment.
"""
raise NotImplementedError

@abstractmethod
def action_space(self) -> Dict[str, ActionSpace]:
"""
Defines the types of actions that can be taken in the environment.
"""
raise NotImplementedError

@abstractmethod
    def reset(self) -> List[Observation]:
        """
        Resets the environment and returns the initial observation for each player.
        """
        raise NotImplementedError

    @abstractmethod
    def act(
        self,
        actions: Union[
            Mapping[ActionName, Action], List[Mapping[ActionName, Action]]
        ],
    ) -> Union[Observation, List[Observation]]:
        """
        Performs the given actions and returns the resulting observations.

        :param actions: Maps the name of each action type to the action to
            perform, either for a single player or one mapping per player.
        """
        raise NotImplementedError

    def reset_filter(self, obs_filter: ObsSpace) -> List[Observation]:
        """
        Resets the environment and returns the initial observation for each
        player. Any entities or features that are not present in the filter
        are removed from the observations.
        """
        return self._filter_obs(self.reset(), obs_filter)

    @abstractmethod
    def get_num_players(self) -> int:
        """
        Returns the number of players in the environment.
        """
        raise NotImplementedError

def render(self, **kwargs: Any) -> npt.NDArray[np.uint8]:
"""
Renders the environment.

:param kwargs: a dictionary of arguments to send to the rendering process
"""
raise NotImplementedError

    def act_filter(
        self, actions: Mapping[ActionName, Action], obs_filter: ObsSpace
    ) -> List[Observation]:
        """
        Performs the given action and returns the resulting observations.
        Any entities or features that are not present in the filter are
        removed from the observations.
        """
        return self._filter_obs(self.act(actions), obs_filter)

def close(self) -> None:
"""Closes the environment."""

    def _filter_obs(
        self, obs_list: List[Observation], obs_filter: ObsSpace
    ) -> List[Observation]:
        filtered_obs_list = []
        # The selectors are identical for every player, so compile them once.
        selectors = self._compile_feature_filter(obs_filter)
        for obs in obs_list:
features: Dict[
EntityName, Union[npt.NDArray[np.float32], Sequence[Sequence[float]]]
] = {}
for etype, feats in obs.features.items():
selector = selectors[etype]
if isinstance(feats, np.ndarray):
features[etype] = feats[:, selector].reshape(
feats.shape[0], len(selector)
)
else:
features[etype] = [[entity[i] for i in selector] for entity in feats]
filtered_obs_list.append(Observation(
global_features=obs.global_features,
features=features,
ids=obs.ids,
actions=obs.actions,
done=obs.done,
reward=obs.reward,
metrics=obs.metrics,
visible=obs.visible,
))
return filtered_obs_list

def _compile_feature_filter(self, obs_space: ObsSpace) -> Dict[str, np.ndarray]:
obs_space = self.obs_space()
feature_selection = {}
for entity_name, entity in obs_space.entities.items():
feature_selection[entity_name] = np.array(
[entity.features.index(f) for f in entity.features], dtype=np.int32
)
feature_selection["__global__"] = np.array(
[obs_space.global_features.index(f) for f in obs_space.global_features],
dtype=np.int32,
)
return feature_selection
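Since `MultiEnvList.act()` calls `get_current_player()` but the ABC only declares `get_num_players()`, turn-based environments must supply both. A skeleton of what a concrete subclass provides, with bodies elided (`MyTurnBasedEnv` is a placeholder name; only the imported symbols come from this PR):

```python
from typing import Dict, List, Mapping

from entity_gym.env.environment import (
    Action,
    ActionName,
    ActionSpace,
    MultiAgentEnvironment,
    Observation,
    ObsSpace,
)


class MyTurnBasedEnv(MultiAgentEnvironment):
    def obs_space(self) -> ObsSpace:
        ...

    def action_space(self) -> Dict[str, ActionSpace]:
        ...

    def get_num_players(self) -> int:
        return 2

    def get_current_player(self) -> int:
        # Index of the player whose actions will be consumed next.
        ...

    def reset(self) -> List[Observation]:
        # One Observation per player.
        ...

    def act(self, actions: Mapping[ActionName, Action]) -> List[Observation]:
        # Step with the current player's actions; return one Observation per
        # player carrying per-player reward and done flags.
        ...
```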