Skip to content

Commit 8957e93

Browse files
committed
add documentations and tests
1 parent 986d36e commit 8957e93

File tree

2 files changed

+152
-1
lines changed

2 files changed

+152
-1
lines changed

source/isaaclab/isaaclab/envs/mdp/curriculums.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,33 @@ class modify_env_param(ManagerTermBase):
4646
`cfg.params["address"]`) the first time it is called, then on each invocation
4747
reads the current value, applies a user-provided `modify_fn`, and writes back
4848
the result.
49+
50+
usage example, modify physics material bucket:
51+
52+
def resample_bucket_range(
53+
env, env_id, data, static_friction_range, dynamic_friction_range, restitution_range, num_steps
54+
):
55+
if env.common_step_counter > num_steps:
56+
range_list = [static_friction_range, dynamic_friction_range, restitution_range]
57+
ranges = torch.tensor(range_list, device="cpu")
58+
new_buckets = math_utils.sample_uniform(ranges[:, 0], ranges[:, 1], (len(data), 3), device="cpu")
59+
return new_buckets
60+
61+
object_physics_material_curriculum = CurrTerm(
62+
func=mdp.modify_env_param,
63+
params={
64+
"address": "event_manager.cfg.object_physics_material.func.material_buckets",
65+
"modify_fn": resample_bucket_range,
66+
"modify_params": {
67+
"static_friction_range": [.5, 1.],
68+
"dynamic_friction_range": [.3, 1.],
69+
"restitution_range": [0.0, 0.5],
70+
"num_step": 120000
71+
}
72+
}
73+
)
4974
"""
75+
NO_CHANGE = object()
5076

5177
def __init__(self, cfg, env):
5278
"""
@@ -91,7 +117,8 @@ def __call__(
91117

92118
data = self.get_fn()
93119
new_val = modify_fn(self._env, env_ids, data, **modify_params)
94-
self.set_fn(new_val)
120+
if new_val is not self.NO_CHANGE:
121+
self.set_fn(new_val)
95122

96123
def _compile_accessors(self, root, path: str):
97124
"""
@@ -163,6 +190,21 @@ class modify_term_cfg(modify_env_param):
163190
with "_manager.cfg.", and then behaves identically to ModifyEnvParam.
164191
165192
for example: command_manager.cfg.object_pose.ranges.xpos -> commands.object_pose.ranges.xpos
193+
194+
usage example:
195+
196+
def update_value_after_step(env, env_ids, data, value, num_steps):
197+
if env.common_step_counter > num_steps:
198+
return value
199+
200+
command_object_pose_xrange_adr = CurrTerm(
201+
func=mdp.modify_term_cfg,
202+
params={
203+
"address": "commands.object_pose.ranges.pos_x", # note that `_manager.cfg` is omitted
204+
"modify_fn": mdp.update_value_after_step,
205+
"modify_params": {"value": (-.75, -.25), "num_steps": 12000}
206+
}
207+
)
166208
"""
167209

168210
def __init__(self, cfg, env):
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
2+
# All rights reserved.
3+
#
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
6+
"""
7+
This script tests the functionality of texture randomization applied to the cartpole scene.
8+
"""
9+
10+
"""Launch Isaac Sim Simulator first."""
11+
12+
from isaaclab.app import AppLauncher
13+
14+
# launch omniverse app
15+
app_launcher = AppLauncher(headless=True, enable_cameras=True)
16+
simulation_app = app_launcher.app
17+
18+
"""Rest everything follows."""
19+
20+
import math
21+
import torch
22+
import unittest
23+
from unittest.mock import patch
24+
25+
import omni.usd
26+
27+
import isaaclab.envs.mdp as mdp
28+
import isaaclab.utils.math as math_utils
29+
from isaaclab.envs import ManagerBasedRLEnv
30+
from isaaclab_tasks.manager_based.classic.cartpole.cartpole_env_cfg import CartpoleEnvCfg
31+
from isaaclab.managers import CurriculumTermCfg as CurrTerm
32+
from isaaclab.utils import configclass
33+
from isaaclab.utils.assets import NVIDIA_NUCLEUS_DIR
34+
35+
36+
def resample_bucket_range(
37+
env, env_id, data, static_friction_range, dynamic_friction_range, restitution_range, num_steps
38+
):
39+
if env.common_step_counter > num_steps:
40+
range_list = [static_friction_range, dynamic_friction_range, restitution_range]
41+
ranges = torch.tensor(range_list, device="cpu")
42+
new_buckets = math_utils.sample_uniform(ranges[:, 0], ranges[:, 1], (len(data), 3), device="cpu")
43+
return new_buckets
44+
45+
46+
def replace_term(
47+
env, env_id, data, term, num_steps
48+
):
49+
if env.common_step_counter > num_steps and data != term:
50+
return term
51+
return mdp.modify_env_param.NO_CHANGE
52+
53+
54+
@configclass
55+
class CurriculumsCfg:
56+
modify_observation_joint_pos = CurrTerm(
57+
func=mdp.modify_term_cfg,
58+
params={
59+
"address": "observations.policy.joint_pos_rel.func",
60+
"modify_fn": replace_term,
61+
"modify_params": {
62+
"term": mdp.joint_pos,
63+
"num_steps": 1
64+
}
65+
}
66+
)
67+
68+
class TestCurriculumModifyEnvParam(unittest.TestCase):
69+
"""Test for texture randomization"""
70+
71+
"""
72+
Tests
73+
"""
74+
75+
def test_curriculum_modify_env_param(self):
76+
"""Test texture randomization fallback mechanism when /visuals pattern doesn't match."""
77+
78+
for device in ["cpu", "cuda"]:
79+
with self.subTest(device=device):
80+
# create a new stage
81+
omni.usd.get_context().new_stage()
82+
83+
# set the arguments - use fallback config
84+
env_cfg = CartpoleEnvCfg()
85+
env_cfg.scene.num_envs = 16
86+
env_cfg.scene.replicate_physics = False
87+
env_cfg.curriculum = CurriculumsCfg()
88+
env_cfg.sim.device = device
89+
90+
# This should trigger the fallback mechanism and log the fallback message
91+
env = ManagerBasedRLEnv(cfg=env_cfg)
92+
93+
# simulate physics
94+
with torch.inference_mode():
95+
for count in range(2): # shorter test for fallback
96+
# reset every few steps to check nothing breaks
97+
env.reset()
98+
# sample random actions
99+
joint_efforts = torch.randn_like(env.action_manager.action)
100+
# step the environment
101+
env.step(joint_efforts)
102+
103+
if count == 0:
104+
assert env.observation_manager.cfg.policy.joint_pos_rel.func == mdp.joint_pos_rel
105+
106+
if count == 1:
107+
assert env.observation_manager.cfg.policy.joint_pos_rel.func == mdp.joint_pos
108+
109+
env.close()

0 commit comments

Comments
 (0)