diff --git a/docs/source/features/toolkits/grasp_generator.rst b/docs/source/features/toolkits/grasp_generator.rst index 7eea272a..7fde03b9 100644 --- a/docs/source/features/toolkits/grasp_generator.rst +++ b/docs/source/features/toolkits/grasp_generator.rst @@ -109,6 +109,12 @@ Configuring GraspGeneratorCfg * - ``max_deviation_angle`` - ``π / 12`` - Maximum allowed angle (in radians) between the specified approach direction and the axis connecting an antipodal point pair. Pairs that deviate more than this threshold are discarded. + * - ``is_partial_annotate`` + - ``True`` + - When ``True``, the annotator allows selecting a partial region of the mesh for grasp sampling. If ``False``, the entire mesh is used. + * - ``is_filter_ground_collision`` + - ``True`` + - Whether to filter out grasp poses that would cause the gripper to collide with the ground plane (the lowest point of the posed object). The ``antipodal_sampler_cfg`` field accepts an :class:`~embodichain.toolkits.graspkit.pg_grasp.AntipodalSamplerCfg` instance, which controls how antipodal point pairs are sampled on the mesh surface. diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py index 9ec009bc..6b620ff7 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py @@ -79,6 +79,14 @@ class GraspGeneratorCfg: deviate more than this threshold from perpendicular to the approach are discarded during grasp pose computation.""" + is_partial_annotate: bool = True + """When ``True``, the annotator allows selecting a partial region of the + mesh for grasp sampling. If ``False``, the entire mesh is used.""" + + is_filter_ground_collision: bool = True + """Whether to filter out grasp poses that would cause the gripper to + collide with the ground plane (the lowest point of the posed object).""" + class GraspGenerator: """Antipodal grasp-pose generator for parallel-jaw grippers.
@@ -236,7 +244,12 @@ def annotate(self) -> torch.Tensor: torch.Tensor: A tensor of shape (N, 2, 3) representing N antipodal point pairs. Each pair consists of a hit point and its corresponding surface point. """ - + if not self.cfg.is_partial_annotate: + hit_point_pairs = self._generate_hit_point_pairs( + self.vertices, self.triangles + ) + self._cache_hit_point_pairs(hit_point_pairs) + return self._hit_point_pairs logger.log_info( f"[Viser] *****Annotate grasp region in http://localhost:{self.cfg.viser_port}" ) @@ -343,7 +356,7 @@ def _(event: viser.ScenePointerEvent) -> None: f"[Selection] Selected {sel_vertex_indices.size} vertices and {sel_face_indices.size} faces." ) - hit_point_pairs = self._antipodal_sampler.sample( + hit_point_pairs = self._generate_hit_point_pairs( torch.tensor(sel_vertices, device=self.device), torch.tensor(sel_faces, device=self.device), ) @@ -378,13 +391,24 @@ def _(_evt: viser.GuiEvent) -> None: while True: if return_flag: if hit_point_pairs is not None: - self._hit_point_pairs = hit_point_pairs - cache_path = self._get_cache_dir(self.vertices, self.triangles) - self._save_cache(cache_path, hit_point_pairs) + self._cache_hit_point_pairs(hit_point_pairs) break time.sleep(0.5) return self._hit_point_pairs + def _generate_hit_point_pairs( + self, vertices: torch.Tensor, triangles: torch.Tensor + ) -> torch.Tensor: + return self._antipodal_sampler.sample( + vertices=vertices, + faces=triangles, + ) + + def _cache_hit_point_pairs(self, hit_point_pairs: torch.Tensor): + self._hit_point_pairs = hit_point_pairs + cache_path = self._get_cache_dir(self.vertices, self.triangles) + self._save_cache(cache_path, hit_point_pairs) + def _get_cache_dir(self, vertices: torch.Tensor, triangles: torch.Tensor): vert_bytes = vertices.to("cpu").numpy().tobytes() face_bytes = triangles.to("cpu").numpy().tobytes() @@ -652,6 +676,7 @@ def get_grasp_poses( object_pose, valid_grasp_poses, valid_open_lengths, +
is_filter_ground_collision=self.cfg.is_filter_ground_collision, is_visual=visualize_collision, collision_threshold=0.0, ) diff --git a/embodichain/toolkits/graspkit/pg_grasp/collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/collision_checker.py index fcbfb850..f3b09014 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/collision_checker.py @@ -192,7 +192,10 @@ def query_batch_points( collision_threshold: Collision threshold in meters. A point is considered colliding if its signed distance to the hull interior is <= this threshold. This allows for a margin of error in collision checking, where a small positive threshold can be used to consider points near the surface as colliding, and a small negative threshold can be used to allow for slight penetration without considering it a collision. is_visual: Whether to visualize the collision checking results for debugging purposes. If set to True, the code will generate visualizations of the query points colored by their collision status (e.g., red for colliding points and green for non-colliding points) along with the original mesh. This can help in understanding and verifying the collision checking process, especially during development and testing. Returns: - is_pose_collide: [B, ] boolean tensor indicating whether each point cloud in the + is_point_collide: [B, n_point] boolean tensor indicating whether a point cloud is collided. + point_signed_distance: [B, n_point] of float. Signed distance from the point cloud to the object surface. + Negative means the point cloud is penetrating into the object, + positive means the point cloud is outside the object. 
""" n_batch = batch_points.shape[0] point_signed_distance, is_point_collide = ( @@ -204,31 +207,7 @@ def query_batch_points( collision_threshold=collision_threshold, ) ) - is_pose_collide = is_point_collide.any(dim=-1) # [B] - pose_surface_distance = point_signed_distance.min(dim=-1).values # [B] - if is_visual: - # visualize result - frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1) - for i in range(n_batch): - query_points_o3d = o3d.geometry.PointCloud() - query_points_np = batch_points[i].cpu().numpy() - query_points_o3d.points = o3d.utility.Vector3dVector(query_points_np) - query_points_color = np.zeros_like(query_points_np) - query_points_color[is_point_collide[i].cpu().numpy()] = [ - 1.0, - 0, - 0, - ] # red for colliding points - query_points_color[~is_point_collide[i].cpu().numpy()] = [ - 0, - 1.0, - 0, - ] # green for non-colliding points - query_points_o3d.colors = o3d.utility.Vector3dVector(query_points_color) - o3d.visualization.draw_geometries( - [self.mesh, query_points_o3d, frame], mesh_show_back_face=True - ) - return is_pose_collide, pose_surface_distance + return is_point_collide, point_signed_distance def query( self, diff --git a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py index 5f02176c..b4d77c43 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py @@ -17,7 +17,8 @@ from __future__ import annotations import torch - +import open3d as o3d +import numpy as np from typing import Sequence from embodichain.utils import configclass @@ -93,6 +94,7 @@ def __init__( base_mesh_faces=object_mesh_faces, max_decomposition_hulls=cfg.max_decomposition_hulls, ) + self.obj_mesh_verts = object_mesh_verts self.device = object_mesh_verts.device self.cfg = cfg self._init_pc_template() @@ -152,24 +154,89 @@ def _get_gripper_pc( gripper_pc = 
torch.cat([root_pc, left_pc, right_pc], dim=1) return gripper_pc + def get_ground_height(self, obj_pose: torch.Tensor) -> float: + obj_r = obj_pose[:3, :3] + obj_t = obj_pose[:3, 3] + obj_verts_world = self.obj_mesh_verts @ obj_r.T + obj_t + min_z = obj_verts_world[:, 2].min().item() + return min_z def query( self, obj_pose: torch.Tensor, grasp_poses: torch.Tensor, open_lengths: torch.Tensor, collision_threshold: float = 0.0, + is_filter_ground_collision: bool = True, is_visual: bool = False, ) -> torch.Tensor: + """Query the collision status of the gripper with the object. + The gripper is represented as a point cloud generated from the grasp poses and + open lengths, and the collision status is determined by checking the distance + between the gripper points and the object mesh. + + Args: + obj_pose (torch.Tensor): [4, 4] of float. The homogeneous transformation matrix of the object pose in the world frame. + grasp_poses (torch.Tensor): [B, 4, 4] of float. The homogeneous transformation matrices of the gripper root frame for B grasp poses. + open_lengths (torch.Tensor): [B, ] of float. The opening lengths of the gripper fingers for B grasp poses. + collision_threshold (float, optional): Collision distance threshold. Defaults to 0.0. + is_filter_ground_collision (bool, optional): Whether to also reject grasp poses whose gripper point cloud would collide with the ground plane, taken as the lowest point of the posed object in the world frame. Defaults to True. + is_visual (bool, optional): Whether to visualize collision result. Defaults to False. + + Returns: + tuple[torch.Tensor, torch.Tensor]: a [B, ] boolean tensor indicating whether each grasp pose is in collision, and a [B, ] float tensor giving the minimum signed distance between the gripper points and the obstacles for each grasp pose.
+ """ inv_obj_pose = obj_pose.clone() inv_obj_pose[:3, :3] = obj_pose[:3, :3].T inv_obj_pose[:3, 3] = -obj_pose[:3, 3] @ obj_pose[:3, :3] inv_obj_poses = inv_obj_pose[None, :, :].repeat(grasp_poses.shape[0], 1, 1) grasp_relative_pose = torch.bmm(inv_obj_poses, grasp_poses) - gripper_pc = self._get_gripper_pc(grasp_relative_pose, open_lengths) - return self._checker.query_batch_points( - gripper_pc, collision_threshold=collision_threshold, is_visual=is_visual + gripper_pc_obj = self._get_gripper_pc(grasp_relative_pose, open_lengths) + is_obj_gripper_collided, obj_gripper_dis = self._checker.query_batch_points( + gripper_pc_obj, collision_threshold=collision_threshold, is_visual=is_visual ) + if is_filter_ground_collision: + gripper_pc_world = self._get_gripper_pc(grasp_poses, open_lengths) + ground_height = self.get_ground_height(obj_pose) + gripper_ground_dis = gripper_pc_world[:, :, 2] - ground_height + is_gripper_ground_collided = gripper_ground_dis < collision_threshold + + is_gripper_collided = torch.logical_or( + is_obj_gripper_collided, is_gripper_ground_collided + ) + gripper_dis = torch.min(obj_gripper_dis, gripper_ground_dis) + else: + is_gripper_collided = is_obj_gripper_collided + gripper_dis = obj_gripper_dis + + if is_visual: + n_batch = grasp_poses.shape[0] + # visualize all collision result + frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1) + for i in range(n_batch): + query_points_o3d = o3d.geometry.PointCloud() + query_points_np = gripper_pc_obj[i].cpu().numpy() + query_points_o3d.points = o3d.utility.Vector3dVector(query_points_np) + query_points_color = np.zeros_like(query_points_np) + query_points_color[is_gripper_collided[i].cpu().numpy()] = [ + 1.0, + 0, + 0, + ] # red for colliding points + query_points_color[~is_gripper_collided[i].cpu().numpy()] = [ + 0, + 1.0, + 0, + ] # green for non-colliding points + query_points_o3d.colors = o3d.utility.Vector3dVector(query_points_color) + o3d.visualization.draw_geometries( + 
[self._checker.mesh, query_points_o3d, frame], + mesh_show_back_face=True, + ) + + return is_gripper_collided.any(dim=1), gripper_dis.min(dim=1).values + def box_surface_grid( size: Sequence[float] | torch.Tensor, diff --git a/embodichain/utils/math.py b/embodichain/utils/math.py index caaa39d2..fbbe75f6 100644 --- a/embodichain/utils/math.py +++ b/embodichain/utils/math.py @@ -1219,9 +1219,9 @@ def transform_points_mat( Returns: transformed: [B, P, 3] transformed point cloud for each pose. """ - R = poses[:, :3, :3] # [B, 3, 3] + r = poses[:, :3, :3] # [B, 3, 3] t = poses[:, :3, 3] # [B, 3] - transformed = torch.einsum("bij, pj -> bpi", R, points) + t.unsqueeze(1) + transformed = torch.einsum("bij, pj -> bpi", r, points) + t.unsqueeze(1) return transformed diff --git a/scripts/tutorials/grasp/grasp_generator.py b/scripts/tutorials/grasp/grasp_generator.py index db4a79ac..1bfdeda6 100644 --- a/scripts/tutorials/grasp/grasp_generator.py +++ b/scripts/tutorials/grasp/grasp_generator.py @@ -227,6 +227,8 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tensor antipodal_sampler_cfg=AntipodalSamplerCfg( n_sample=20000, max_length=0.088, min_length=0.003 ), + is_partial_annotate=True, + is_filter_ground_collision=True, ) sim.open_window() @@ -266,7 +268,10 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tensor )[0] for i, obj_pose in enumerate(obj_poses): is_success, grasp_pose, open_length = grasp_generator.get_grasp_poses( - obj_pose, approach_direction, visualize_pose=False + obj_pose, + approach_direction, + visualize_collision=False, + visualize_pose=False, ) if is_success: grasp_xpos_list.append(grasp_pose.unsqueeze(0)) diff --git a/tests/toolkits/test_batch_convex_collision.py b/tests/toolkits/test_batch_convex_collision.py index 4bf852c8..291e15e1 100644 --- a/tests/toolkits/test_batch_convex_collision.py +++ b/tests/toolkits/test_batch_convex_collision.py @@ -60,9 +60,11 @@ def
batch_convex_collision_query(device=torch.device("cuda")): obj_faces = torch.tensor(obj_mesh.faces, dtype=torch.int32, device=device) test_pc = transform_points_mat(obj_verts, poses) - is_pose_collide, pose_surface_distance = collision_checker.query_batch_points( + is_point_collide, point_surface_distance = collision_checker.query_batch_points( test_pc, collision_threshold=0.003, is_visual=False ) + is_pose_collide = is_point_collide.any(dim=1) + pose_surface_distance = point_surface_distance.min(dim=1).values assert is_pose_collide.sum().item() == 1 assert abs(pose_surface_distance.max().item() - 0.8492) < 1e-2