Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions molmo_spaces/configs/task_sampler_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class BaseMujocoTaskSamplerConfig(Config):
robot_placement_exclusion_threshold: float = 0.15

robot_placement_rotation_range_rad: float = 0.25 # +/- approx 15 degrees
render_device: int = 0 # Device ID for rendering

# Scene configuration
enable_texture_randomization: bool = False
Expand Down
4 changes: 2 additions & 2 deletions molmo_spaces/env/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,8 @@ def _initialize_with_model(self, mj_model: MjModel, mj_base_scene_path: str) ->
width, height = self.config.camera_config.img_resolution
else:
width, height = (640, 480) # Default resolution
self._renderer = MjOpenGLRenderer(model=self.mj_model, width=width, height=height)

device_id = getattr(self.config.task_sampler_config, "render_device", 0)
self._renderer = MjOpenGLRenderer(model=self.mj_model, width=width, height=height, device_id=device_id)
if self._parallelize and self._n_batch > 1:
self._executor = ThreadPoolExecutor(max_workers=self._n_batch)
else:
Expand Down
10 changes: 3 additions & 7 deletions molmo_spaces/renderer/opengl_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,9 @@ def __init__(
`height` exceed the dimensions of MuJoCo's offscreen framebuffer.
"""
if device_id is None:
try:
import torch

if torch.cuda.is_available():
device_id = 0
except ImportError:
pass
import os
if os.path.exists("/dev/nvidia0"):
device_id = 0

super().__init__(**prepare_locals_for_super(locals()))

Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ dependencies = [
"boto3>=1.35.0",
"coacd>=1.0.7",
"compress-json>=1.1.1",
"decord; sys_platform == 'linux'", # For Linux
"decord; sys_platform == 'linux' and platform_machine == 'x86_64'", # decord wheels are not available on linux aarch64
"eva-decord; sys_platform == 'darwin'", # For macOS
"einops>=0.8.1",
"ffmpeg-python>=0.2.0",
Expand Down Expand Up @@ -54,8 +54,8 @@ dependencies = [
"teledex",
"termcolor>=2.0.0",
"toppra>=0.6.3",
"torch~=2.7.0", # torch 2.8.x is brand new and has some regressions vs 2.7.x - prefer to hold back for now - this installs nvidia_cuda_nvrtc_cu12-12.6.77 by default, not 12.8
"torchvision>=0.22.0,<0.23.0",
"torch>=2.7.0,<2.9", # Allow torch 2.8 for compatibility with jax[cuda12] dependency stack.
"torchvision>=0.22.0,<0.24.0",
"tqdm>=4.65.0",
"trimesh>=4.7.3",
"wandb>=0.18.10",
Expand Down