diff --git a/embodichain/lab/sim/objects/soft_object.py b/embodichain/lab/sim/objects/soft_object.py index 116925e7..d5247222 100644 --- a/embodichain/lab/sim/objects/soft_object.py +++ b/embodichain/lab/sim/objects/soft_object.py @@ -64,18 +64,18 @@ def __init__( self.ps = ps self.num_instances = len(entities) - softbodies = [ + self.softbodies = [ self.entities[i].get_physical_body() for i in range(self.num_instances) ] - self.n_collision_vertices = softbodies[0].get_num_vertices() - self.n_sim_vertices = softbodies[0].get_num_sim_vertices() + self.n_collision_vertices = self.softbodies[0].get_num_vertices() + self.n_sim_vertices = self.softbodies[0].get_num_sim_vertices() self._rest_position_buffer = torch.empty( (self.num_instances, self.n_collision_vertices, 4), device=self.device, dtype=torch.float32, ) - for i, softbody in enumerate(softbodies): + for i, softbody in enumerate(self.softbodies): self._rest_position_buffer[i] = softbody.get_position_inv_mass_buffer() self._rest_sim_position_buffer = torch.empty( @@ -84,23 +84,23 @@ def __init__( dtype=torch.float32, ) - for i, softbody in enumerate(softbodies): + for i, softbody in enumerate(self.softbodies): self._rest_sim_position_buffer[i] = ( softbody.get_sim_position_inv_mass_buffer() ) - self._collision_position_buffer = torch.zeros( - (self.num_instances, self.n_collision_vertices, 4), + self._collision_position = torch.zeros( + (self.num_instances, self.n_collision_vertices, 3), device=self.device, dtype=torch.float32, ) - self._sim_vertex_velocity_buffer = torch.zeros( - (self.num_instances, self.n_sim_vertices, 4), + self._sim_vertex_velocity = torch.zeros( + (self.num_instances, self.n_sim_vertices, 3), device=self.device, dtype=torch.float32, ) - self._sim_vertex_position_buffer = torch.zeros( - (self.num_instances, self.n_sim_vertices, 4), + self._sim_vertex_position = torch.zeros( + (self.num_instances, self.n_sim_vertices, 3), device=self.device, dtype=torch.float32, ) @@ -129,29 +129,29 @@ def rest_sim_vertices(self): return self._rest_sim_position_buffer[:, :, :3].clone() @property - def collision_position_buffer(self): + def collision_position(self): """Get the current vertex position buffer of the soft bodies.""" for i, softbody in enumerate(self.soft_bodies): - self._collision_position_buffer[i] = softbody.get_position_inv_mass_buffer() - return self._collision_position_buffer.clone() + self._collision_position[i] = softbody.get_position_inv_mass_buffer()[:, :3] + return self._collision_position.clone() @property - def sim_vertex_position_buffer(self): + def sim_vertex_position(self): """Get the current sim vertex position buffer of the soft bodies.""" for i, softbody in enumerate(self.soft_bodies): - self._sim_vertex_position_buffer[i] = ( - softbody.get_sim_position_inv_mass_buffer() - ) - return self._sim_vertex_position_buffer.clone() + self._sim_vertex_position[i] = softbody.get_sim_position_inv_mass_buffer()[ + :, :3 + ] + return self._sim_vertex_position.clone() @property - def sim_vertex_velocity_buffer(self): + def sim_vertex_velocity(self): """Get the current vertex velocity buffer of the soft bodies.""" for i, softbody in enumerate(self.soft_bodies): - self._sim_vertex_velocity_buffer[i] = ( - softbody.get_sim_position_inv_mass_buffer() - ) - return self._sim_vertex_velocity_buffer.clone() + self._sim_vertex_velocity[i] = softbody.get_sim_position_inv_mass_buffer()[ + :, :3 + ] + return self._sim_vertex_velocity.clone() class SoftObject(BatchEntity):