diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index f3bdad4..78655dd 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -49,10 +49,10 @@ jobs: #---------------------------------------------- - name: Install dependencies if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry run pip install torch + run: pip install torch==1.12.0+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html - name: Install dependencies if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry run pip install torch-scatter + run: pip install torch-scatter -f https://data.pyg.org/whl/torch-1.12.0+cpu.html #---------------------------------------------- # install your root project, if required #---------------------------------------------- @@ -109,10 +109,10 @@ jobs: #---------------------------------------------- - name: Install dependencies if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry run pip install torch + run: pip install torch==1.12.0+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html - name: Install dependencies if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry run pip install torch-scatter + run: pip install torch-scatter -f https://data.pyg.org/whl/torch-1.12.0+cpu.html #---------------------------------------------- # install your root project, if required #---------------------------------------------- diff --git a/entity_gym/env/vec_env.py b/entity_gym/env/vec_env.py index b03d0ca..eb9c528 100644 --- a/entity_gym/env/vec_env.py +++ b/entity_gym/env/vec_env.py @@ -347,7 +347,9 @@ def batch_obs( for atype, space in action_space.items(): if atype not in o.actions: if atype in action_masks: - if isinstance(space, CategoricalActionSpace): + if isinstance(space, CategoricalActionSpace) or isinstance( + space, GlobalCategoricalActionSpace + ): vec_action = action_masks[atype] assert isinstance(vec_action, VecCategoricalActionMask) vec_action.actors.push(np.zeros((0, 1), dtype=np.int64))