diff --git a/entity_gym/env/__init__.py b/entity_gym/env/__init__.py
index d6b72e9..1d91de1 100644
--- a/entity_gym/env/__init__.py
+++ b/entity_gym/env/__init__.py
@@ -30,6 +30,7 @@
 from .environment import *
 from .parallel_env_list import *
 from .validator import ValidatingEnv
+from .multi_validator import MultiValidatingEnv
 from .vec_env import *
 
 __all__ = [
@@ -64,5 +65,6 @@
     "VecSelectEntityActionMask",
     # Wrappers
     "ValidatingEnv",
+    "MultiValidatingEnv",
     "AddMetricsWrapper",
 ]
diff --git a/entity_gym/env/add_metrics_wrapper.py b/entity_gym/env/add_metrics_wrapper.py
index a361fa7..cb5abb6 100644
--- a/entity_gym/env/add_metrics_wrapper.py
+++ b/entity_gym/env/add_metrics_wrapper.py
@@ -4,7 +4,7 @@
 import numpy.typing as npt
 from ragged_buffer import RaggedBufferI64
 
-from entity_gym.env.environment import ActionName, ActionSpace, ObsSpace
+from entity_gym.env.environment import ActionName, ActionSpace, MultiAgentEnvironment, ObsSpace
 from entity_gym.env.vec_env import Metric, VecEnv, VecObs
 
 
@@ -25,17 +25,38 @@ def __init__(
         self.entity_types = list(env.obs_space().entities.keys())
         if env.has_global_entity():
             self.entity_types.append("__global__")
-        self.total_reward = np.zeros(len(env), dtype=np.float32)
-        self.total_steps = np.zeros(len(env), dtype=np.int64)
-        self.filter = np.ones(len(env), dtype=np.bool8) if filter is None else filter
+        # Detect whether the wrapped VecEnv holds multi-agent environments.
+        # NOTE: this assumes the VecEnv exposes its underlying environments
+        # via `envs` (true for EnvList/MultiEnvList, but not for every VecEnv).
+        if isinstance(env.envs[0], MultiAgentEnvironment):
+            self.is_multi_player = True
+            # One row of reward bookkeeping per player.
+            self.total_reward = np.zeros(
+                (env.envs[0].get_num_players(), len(env)), dtype=np.float32
+            )
+        else:
+            self.is_multi_player = False
+            self.total_reward = np.zeros(len(env), dtype=np.float32)
+        self.total_steps = np.zeros(len(env), dtype=np.int64)
+        self.filter = np.ones(len(env), dtype=np.bool8) if filter is None else filter
 
     def reset(self, obs_config: ObsSpace) -> VecObs:
-        return self.track_metrics(self.env.reset(obs_config))
+        if self.is_multi_player:
+            # One VecObs per player, each with its own metrics.
+            return [
+                self.multi_track_metrics(i, obs)
+                for i, obs in enumerate(self.env.reset(obs_config))
+            ]
+        else:
+            return self.track_metrics(self.env.reset(obs_config))
 
     def act(
         self, actions: Mapping[ActionName, RaggedBufferI64], obs_filter: ObsSpace
     ) -> VecObs:
-        return self.track_metrics(self.env.act(actions, obs_filter))
+        if self.is_multi_player:
+            return [
+                self.multi_track_metrics(i, obs)
+                for i, obs in enumerate(self.env.act(actions, obs_filter))
+            ]
+        else:
+            return self.track_metrics(self.env.act(actions, obs_filter))
 
     def render(self, **kwargs: Any) -> npt.NDArray[np.uint8]:
         return self.env.render(**kwargs)
@@ -46,6 +67,67 @@ def __len__(self) -> int:
     def close(self) -> None:
         self.env.close()
 
+    def multi_track_metrics(self, id: int, obs: VecObs) -> VecObs:
+        """
+        Records metrics for a single player of a multi-agent environment.
+
+        :param id: player id
+        :param obs: the player's VecObs
+        """
+        # Track reward and total steps
+        self.total_reward[id] += obs.reward
+        if id == 0:  # increment total steps only for the first player
+            self.total_steps += 1
+        episodic_reward = Metric()
+        episodic_length = Metric()
+        count = len(self.total_steps)
+        obs.metrics["step"] = Metric(
+            sum=self.total_steps.sum(),
+            count=count,
+            min=self.total_steps.min(),
+            max=self.total_steps.max(),
+        )
+
+        for entity in self.entity_types:
+            if entity in obs.features:
+                _sum = obs.features[entity].items()
+                counts = obs.features[entity].size1()
+                _min = counts.min()
+                _max = counts.max()
+            else:
+                _sum = 0
+                _min = 0
+                _max = 0
+            obs.metrics[f"entity_count/{entity}"] = Metric(
+                sum=_sum, count=count, min=_min, max=_max
+            )
+        if len(obs.features) > 0:
+            combined_counts: Any = sum(
+                features.size1() for features in obs.features.values()
+            )
+        else:
+            combined_counts = np.zeros(count, dtype=np.int64)
+        obs.metrics["entity_count"] = Metric(
+            sum=combined_counts.sum(),
+            count=count,
+            min=combined_counts.min(),
+            max=combined_counts.max(),
+        )
+
+        # Track episodic reward and length for each player
+        for i_env in np.arange(len(self))[obs.done & self.filter]:
+            episodic_reward.push(self.total_reward[id, i_env])
+            episodic_length.push(self.total_steps[i_env])
+            self.total_reward[id, i_env] = 0.0
+            # Reset the shared step counter only on the last player's call so
+            # that every player records the full episode length.
+            if id == self.total_reward.shape[0] - 1:
+                self.total_steps[i_env] = 0
+        obs.metrics["episodic_reward"] = episodic_reward
+        obs.metrics["episode_length"] = episodic_length
+        obs.metrics["reward"] = Metric(
+            sum=obs.reward.sum(),
+            count=obs.reward.size,
+            min=obs.reward.min(),
+            max=obs.reward.max(),
+        )
+        return obs
+
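+    # A sketch of the resulting bookkeeping (the two-player case here is a
+    # hypothetical example, not part of the API): with 2 players and 4 envs,
+    # total_reward has shape (2, 4); each step calls multi_track_metrics once
+    # per player, and only the player-0 call advances the shared total_steps,
+    # so "step" and "episode_length" are counted once per environment step.
+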
     def track_metrics(self, obs: VecObs) -> VecObs:
         self.total_reward += obs.reward
         self.total_steps += 1
diff --git a/entity_gym/env/env_list.py b/entity_gym/env/env_list.py
index 6ecccd4..d40be0a 100644
--- a/entity_gym/env/env_list.py
+++ b/entity_gym/env/env_list.py
@@ -21,7 +21,7 @@
     SelectEntityActionMask,
     SelectEntityActionSpace,
 )
-from entity_gym.env.vec_env import VecEnv, VecObs, batch_obs
+from entity_gym.env.vec_env import VecEnv, VecObs, batch_obs, multi_batch_obs
 
 
 class EnvList(VecEnv):
@@ -77,6 +77,60 @@ def action_space(self) -> Dict[ActionName, ActionSpace]:
         return self._action_space
 
 
+class MultiEnvList(VecEnv):
+    """
+    Like EnvList, but for multi-agent environments: reset and act return one
+    VecObs per player instead of a single VecObs.
+    """
+
+    def __init__(self, create_env: Callable[[], Environment], num_envs: int):
+        self.envs = [create_env() for _ in range(num_envs)]
+        self.last_obs: List[List[Observation]] = []
+        env = self.envs[0] if num_envs > 0 else create_env()
+        self._obs_space = env.obs_space()
+        self._action_space = env.action_space()
+
+    def reset(self, obs_space: ObsSpace) -> List[VecObs]:
+        return self._batch_obs([e.reset_filter(obs_space) for e in self.envs])
+
+    def render(self, **kwargs: Any) -> npt.NDArray[np.uint8]:
+        return np.stack([e.render(**kwargs) for e in self.envs])
+
+    def close(self) -> None:
+        for env in self.envs:
+            env.close()
+
+    def act(
+        self, actions: Mapping[int, Mapping[str, RaggedBufferI64]], obs_space: ObsSpace
+    ) -> List[VecObs]:
+        """
+        :param actions: maps each player id to a Dict of action name to RaggedBufferI64
+        :param obs_space: the ObsSpace used to filter the returned observations
+        """
+        observations = []
+        action_spaces = self.action_space()
+        for i_env, env in enumerate(self.envs):
+            # Only step the environment with the actions of the current player.
+            # The environment must implement get_current_player().
+            current_player = env.get_current_player()
+            _actions = action_index_to_actions(
+                self._obs_space,
+                action_spaces,
+                actions[current_player],
+                self.last_obs[i_env][current_player],
+                i_env,
+            )
+            obs = env.act_filter(_actions, obs_space)
+
+            if any(o.done for o in obs):
+                new_obs = env.reset_filter(obs_space)
+                for player_id, no in enumerate(new_obs):
+                    no.done = True
+                    no.reward = obs[player_id].reward
+                    no.metrics = obs[player_id].metrics
+                observations.append(new_obs)
+            else:
+                observations.append(obs)
+        return self._batch_obs(observations)
+
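+    # Example `actions` value (the action name "move" is hypothetical):
+    #   {0: {"move": RaggedBufferI64(...)},   # player 0's actions for all envs
+    #    1: {"move": RaggedBufferI64(...)}}   # player 1's actions for all envs
+    # Only the sub-mapping of each environment's current player is applied.
+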
+    def _batch_obs(self, obs: List[List[Observation]]) -> List[VecObs]:
+        self.last_obs = obs
+        return multi_batch_obs(obs, self.obs_space(), self.action_space())
+
+    def __len__(self) -> int:
+        return len(self.envs)
+
+    def obs_space(self) -> ObsSpace:
+        return self._obs_space
+
+    def action_space(self) -> Dict[ActionName, ActionSpace]:
+        return self._action_space
+
+
 def action_index_to_actions(
     obs_space: ObsSpace,
     action_spaces: Dict[ActionName, ActionSpace],
diff --git a/entity_gym/env/environment.py b/entity_gym/env/environment.py
index 5a36a36..b6d6a1a 100644
--- a/entity_gym/env/environment.py
+++ b/entity_gym/env/environment.py
@@ -294,3 +294,112 @@ def _compile_feature_filter(self, obs_space: ObsSpace) -> Dict[str, np.ndarray]:
             dtype=np.int32,
         )
         return feature_selection
+
+
+class MultiAgentEnvironment(ABC):
+    """
+    Abstract base class for multi-agent environments. Observations are
+    returned as one Observation per player.
+    """
+
+    @abstractmethod
+    def obs_space(self) -> ObsSpace:
+        """
+        Defines the shape of observations returned by the environment.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def action_space(self) -> Dict[str, ActionSpace]:
+        """
+        Defines the types of actions that can be taken in the environment.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def reset(self) -> List[Observation]:
+        """
+        Resets the environment and returns one initial observation per player.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def act(
+        self,
+        actions: Union[Mapping[ActionName, Action], List[Mapping[ActionName, Action]]],
+    ) -> Union[Observation, List[Observation]]:
+        """
+        Performs the given actions and returns the resulting observations.
+
+        :param actions: maps the name of each action type to the action to
+            perform, either for the current player or as one mapping per player.
+        """
+        raise NotImplementedError
+
+    def reset_filter(self, obs_filter: ObsSpace) -> List[Observation]:
+        """
+        Resets the environment and returns the initial observations.
+        Any entities or features that are not present in the filter are removed
+        from the observations.
+        """
+        return self._filter_obs(self.reset(), obs_filter)
+
+    @abstractmethod
+    def get_num_players(self) -> int:
+        """
+        Returns the number of players in the environment.
+        """
+        raise NotImplementedError
+
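+    def get_current_player(self) -> int:
+        """
+        Returns the id of the player whose turn it is to act.
+
+        Not declared abstract to keep the required interface small, but both
+        MultiEnvList.act and MultiValidatingEnv call it, so concrete
+        multi-agent environments are expected to implement it.
+        """
+        raise NotImplementedError
+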
+ """ + return self._filter_obs(self.act(actions), obs_filter) + + def close(self) -> None: + """Closes the environment.""" + + def _filter_obs(self, obs_list: List[Observation], obs_filter: ObsSpace) -> List[Observation]: + filtered_obs_list = [] + for _, obs in enumerate(obs_list): + selectors = self._compile_feature_filter(obs_filter) + features: Dict[ + EntityName, Union[npt.NDArray[np.float32], Sequence[Sequence[float]]] + ] = {} + for etype, feats in obs.features.items(): + selector = selectors[etype] + if isinstance(feats, np.ndarray): + features[etype] = feats[:, selector].reshape( + feats.shape[0], len(selector) + ) + else: + features[etype] = [[entity[i] for i in selector] for entity in feats] + filtered_obs_list.append(Observation( + global_features=obs.global_features, + features=features, + ids=obs.ids, + actions=obs.actions, + done=obs.done, + reward=obs.reward, + metrics=obs.metrics, + visible=obs.visible, + )) + return filtered_obs_list + + def _compile_feature_filter(self, obs_space: ObsSpace) -> Dict[str, np.ndarray]: + obs_space = self.obs_space() + feature_selection = {} + for entity_name, entity in obs_space.entities.items(): + feature_selection[entity_name] = np.array( + [entity.features.index(f) for f in entity.features], dtype=np.int32 + ) + feature_selection["__global__"] = np.array( + [obs_space.global_features.index(f) for f in obs_space.global_features], + dtype=np.int32, + ) + return feature_selection \ No newline at end of file diff --git a/entity_gym/env/multi_validator.py b/entity_gym/env/multi_validator.py new file mode 100644 index 0000000..01140c7 --- /dev/null +++ b/entity_gym/env/multi_validator.py @@ -0,0 +1,180 @@ +from typing import Any, Dict, Mapping, Union, List + +import numpy as np +import numpy.typing as npt + +from entity_gym.env.environment import ( + Action, + ActionName, + ActionSpace, + CategoricalActionMask, + CategoricalActionSpace, + Environment, + MultiAgentEnvironment, + Observation, + ObsSpace, + SelectEntityActionMask, + SelectEntityActionSpace, +) + + +class MultiValidatingEnv(MultiAgentEnvironment): + def __init__(self, env: Environment) -> None: + self.env = env + self._obs_space = env.obs_space() + self._action_space = env.action_space() + + def act(self, actions: Mapping[ActionName, Action]) -> Observation: + obs = self.env.act(actions) + try: + self._validate(obs) + except AssertionError as e: + print(f"Invalid observation:\n{e}") + raise e + return obs + + def reset(self) -> Observation: + obs = self.env.reset() + try: + self._validate(obs) + except AssertionError as e: + print(f"Invalid observation:\n{e}") + raise e + return obs + + def render(self, **kwargs: Any) -> npt.NDArray[np.uint8]: + return self.env.render(**kwargs) + + def obs_space(self) -> ObsSpace: + return self._obs_space + + def action_space(self) -> Dict[str, ActionSpace]: + return self._action_space + + def get_num_players(self): + return self.env.get_num_players() + + def get_current_player(self): + return self.env.get_current_player() + + def _validate_features(self, obs: Observation) -> None: + for entity_type, entity_features in obs.features.items(): + assert ( + entity_type in self._obs_space.entities + ), f"Features contain entity of type '{entity_type}' which is not in observation space: {list(self._obs_space.entities.keys())}" + if isinstance(entity_features, np.ndarray): + assert ( + entity_features.dtype == np.float32 + ), f"Features of entity of type '{entity_type}' have invalid dtype: {entity_features.dtype}. 
Expected: {np.float32}" + shape = entity_features.shape + assert len(shape) == 2 and shape[1] == len( + self._obs_space.entities[entity_type].features + ), f"Features of entity of type '{entity_type}' have invalid shape: {shape}. Expected: (n, {len(self._obs_space.entities[entity_type].features)})" + else: + for i, entity in enumerate(entity_features): + assert len(entity) == len( + self._obs_space.entities[entity_type].features + ), f"Features of {i}-th entity of type '{entity_type}' have invalid length: {len(entity)}. Expected: {len(self._obs_space.entities[entity_type].features)}" + + if entity_type in obs.ids: + assert len(obs.ids[entity_type]) == len( + entity_features + ), f"Length of ids of entity of type '{entity_type}' does not match length of features: {len(obs.ids[entity_type])} != {len(entity_features)}" + + def _validate_global_features(self, obs: Observation) -> None: + if len(obs.global_features) != len(self._obs_space.global_features): + raise AssertionError( + f"Length of global features does not match length of global features in observation space: {len(obs.global_features)} != {len(self._obs_space.global_features)}" + ) + + def _validate_ids(self, obs: Observation) -> None: + # Validate ids + previous_ids = set() + for entity_type, entity_ids in obs.ids.items(): + assert ( + entity_type in self._obs_space.entities + ), f"IDs contain entity of type '{entity_type}' which is not in observation space: {list(self._obs_space.entities.keys())}" + for id in entity_ids: + assert id not in previous_ids, f"Observation has duplicate id '{id}'" + previous_ids.add(id) + + def _validate_action_masks(self, obs: Observation) -> None: + # Validate actions + ids = obs.id_to_index(self._obs_space) + + for action_type, action_mask in obs.actions.items(): + + assert ( + action_type in self._action_space + ), f"Actions contain action of type '{action_type}' which is not in action space: {list(self._action_space.keys())}" + space = self._action_space[action_type] + if isinstance(space, CategoricalActionSpace): + assert isinstance( + action_mask, CategoricalActionMask + ), f"Action of type '{action_type}' has invalid type: {type(action_mask)}. Expected: CategoricalActionMask" + if action_mask.actor_ids is not None: + for id in action_mask.actor_ids: + assert ( + id in ids + ), f"Action of type '{action_type}' contains invalid actor id {id} which is not in ids: {obs.ids}" + if action_mask.actor_types is not None: + for actor_type in action_mask.actor_types: + assert ( + actor_type in obs.ids + ), f"Action of type '{action_type}' contains invalid actor type {actor_type} which is not in ids: {obs.ids.keys()}" + mask = action_mask.mask + actor_indices = obs._actor_indices(action_type, self._obs_space) + if isinstance(mask, np.ndarray): + assert ( + mask.dtype == np.bool_ + ), f"Action of type '{action_type}' has invalid dtype: {mask.dtype}. Expected: {np.bool_}" + shape = mask.shape + if shape[0] != 0: + assert shape == ( + len(actor_indices), + len(space.index_to_label), + ), f"Action of type '{action_type}' has invalid shape: {shape}. Expected: ({len(actor_indices), len(space.index_to_label)})" + unmasked_count = mask.sum(axis=1) + for i in range(len(unmasked_count)): + assert ( + unmasked_count[i] > 0 + ), f"Action of type '{action_type}' contains invalid mask for {i}-th actor: {mask[i]}. Expected at least one possible action" + elif mask is not None: + assert len(mask) == len( + actor_indices + ), f"Action of type '{action_type}' has invalid length: {len(mask)}. 
+    """
+    vecobs_list = []
+    for obs in zip(*obs_list):
+        vecobs_list.append(batch_obs(list(obs), obs_space, action_space))
+    return vecobs_list
+
 
 def batch_obs(
     obs: List[Observation], obs_space: ObsSpace, action_space: Dict[str, ActionSpace]