Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/features/toolkits/grasp_generator.rst
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ Configuring GraspGeneratorCfg
* - ``max_deviation_angle``
- ``π / 12``
- Maximum allowed angle (in radians) between the specified approach direction and the axis connecting an antipodal point pair. Pairs that deviate more than this threshold are discarded.
* - ``is_partial_annotate``
- ``True``
- When ``True``, the annotator allows selecting a partial region of the mesh for grasp sampling. If ``False``, the entire mesh is used.
* - ``is_filter_ground_collision``
- ``True``
- Whether to filter out grasp poses that would cause the gripper to collide.

The ``antipodal_sampler_cfg`` field accepts an :class:`~embodichain.toolkits.graspkit.pg_grasp.AntipodalSamplerCfg` instance, which controls how antipodal point pairs are sampled on the mesh surface.

Expand Down
35 changes: 30 additions & 5 deletions embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ class GraspGeneratorCfg:
deviate more than this threshold from perpendicular to the approach are
discarded during grasp pose computation."""

is_partial_annotate: bool = True
"""When ``True``, the annotator allows selecting a partial region of the
mesh for grasp sampling. If ``False``, the entire mesh is used."""

is_filter_ground_collision: bool = True
"""Whether to filter out grasp poses that would cause the gripper to
collide."""


class GraspGenerator:
"""Antipodal grasp-pose generator for parallel-jaw grippers.
Expand Down Expand Up @@ -236,7 +244,12 @@ def annotate(self) -> torch.Tensor:
torch.Tensor: A tensor of shape (N, 2, 3) representing N antipodal point pairs.
Each pair consists of a hit point and its corresponding surface point.
"""

if self.cfg.is_partial_annotate == False:
hit_point_pairs = self._generate_hit_point_pairs(
self.vertices, self.triangles
)
self._cache_hit_point_pairs(hit_point_pairs)
return self._hit_point_pairs
logger.log_info(
f"[Viser] *****Annotate grasp region in http://localhost:{self.cfg.viser_port}"
)
Expand Down Expand Up @@ -343,7 +356,7 @@ def _(event: viser.ScenePointerEvent) -> None:
f"[Selection] Selected {sel_vertex_indices.size} vertices and {sel_face_indices.size} faces."
)

hit_point_pairs = self._antipodal_sampler.sample(
hit_point_pairs = self._generate_hit_point_pairs(
torch.tensor(sel_vertices, device=self.device),
torch.tensor(sel_faces, device=self.device),
)
Expand Down Expand Up @@ -378,13 +391,24 @@ def _(_evt: viser.GuiEvent) -> None:
while True:
if return_flag:
if hit_point_pairs is not None:
self._hit_point_pairs = hit_point_pairs
cache_path = self._get_cache_dir(self.vertices, self.triangles)
self._save_cache(cache_path, hit_point_pairs)
self._cache_hit_point_pairs(hit_point_pairs)
break
time.sleep(0.5)
return self._hit_point_pairs

def _generate_hit_point_pairs(
self, vertices: torch.Tensor, triangles: torch.Tensor
) -> torch.Tensor:
return self._antipodal_sampler.sample(
vertices=vertices,
faces=triangles,
)

def _cache_hit_point_pairs(self, hit_point_pairs: torch.Tensor):
self._hit_point_pairs = hit_point_pairs
cache_path = self._get_cache_dir(self.vertices, self.triangles)
self._save_cache(cache_path, hit_point_pairs)

def _get_cache_dir(self, vertices: torch.Tensor, triangles: torch.Tensor):
vert_bytes = vertices.to("cpu").numpy().tobytes()
face_bytes = triangles.to("cpu").numpy().tobytes()
Expand Down Expand Up @@ -652,6 +676,7 @@ def get_grasp_poses(
object_pose,
valid_grasp_poses,
valid_open_lengths,
is_filter_ground_collision=self.cfg.is_filter_ground_collision,
is_visual=visualize_collision,
collision_threshold=0.0,
)
Expand Down
31 changes: 5 additions & 26 deletions embodichain/toolkits/graspkit/pg_grasp/collision_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,10 @@ def query_batch_points(
collision_threshold: Collision threshold in meters. A point is considered colliding if its signed distance to the hull interior is <= this threshold. This allows for a margin of error in collision checking, where a small positive threshold can be used to consider points near the surface as colliding, and a small negative threshold can be used to allow for slight penetration without considering it a collision.
is_visual: Whether to visualize the collision checking results for debugging purposes. If set to True, the code will generate visualizations of the query points colored by their collision status (e.g., red for colliding points and green for non-colliding points) along with the original mesh. This can help in understanding and verifying the collision checking process, especially during development and testing.
Returns:
is_pose_collide: [B, ] boolean tensor indicating whether each point cloud in the
is_point_collide: [B, n_point] boolean tensor indicating whether a point cloud is collided.
point_signed_distance: [B, n_point] of float. Signed distance from the point cloud to the object surface.
Negative means the point cloud is penetrating into the object,
positive means the point cloud is outside the object.
"""
n_batch = batch_points.shape[0]
point_signed_distance, is_point_collide = (
Expand All @@ -204,31 +207,7 @@ def query_batch_points(
collision_threshold=collision_threshold,
)
)
is_pose_collide = is_point_collide.any(dim=-1) # [B]
pose_surface_distance = point_signed_distance.min(dim=-1).values # [B]
if is_visual:
# visualize result
frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1)
for i in range(n_batch):
query_points_o3d = o3d.geometry.PointCloud()
query_points_np = batch_points[i].cpu().numpy()
query_points_o3d.points = o3d.utility.Vector3dVector(query_points_np)
query_points_color = np.zeros_like(query_points_np)
query_points_color[is_point_collide[i].cpu().numpy()] = [
1.0,
0,
0,
] # red for colliding points
query_points_color[~is_point_collide[i].cpu().numpy()] = [
0,
1.0,
0,
] # green for non-colliding points
query_points_o3d.colors = o3d.utility.Vector3dVector(query_points_color)
o3d.visualization.draw_geometries(
[self.mesh, query_points_o3d, frame], mesh_show_back_face=True
)
return is_pose_collide, pose_surface_distance
return is_point_collide, point_signed_distance

def query(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from __future__ import annotations

import torch

import open3d as o3d
import numpy as np
from typing import Sequence

from embodichain.utils import configclass
Expand Down Expand Up @@ -93,6 +94,7 @@ def __init__(
base_mesh_faces=object_mesh_faces,
max_decomposition_hulls=cfg.max_decomposition_hulls,
)
self.obj_mesh_verts = object_mesh_verts
self.device = object_mesh_verts.device
self.cfg = cfg
self._init_pc_template()
Expand Down Expand Up @@ -152,24 +154,89 @@ def _get_gripper_pc(
gripper_pc = torch.cat([root_pc, left_pc, right_pc], dim=1)
return gripper_pc

def get_ground_height(self, obj_pose: torch.Tensor) -> float:
obj_r = obj_pose[:3, :3]
obj_t = obj_pose[:3, 3]
# obj_verts_world = (obj_r @ self.obj_mesh_verts.T).T + obj_t
obj_verts_world = self.obj_mesh_verts @ obj_r.T + obj_t
min_z = obj_verts_world[:, 2].min().item()
return min_z

def query(
self,
obj_pose: torch.Tensor,
grasp_poses: torch.Tensor,
open_lengths: torch.Tensor,
collision_threshold: float = 0.0,
is_filter_ground_collision: bool = True,
is_visual: bool = False,
) -> torch.Tensor:
"""query the collision status of the gripper with the object.
The gripper is represented as a point cloud generated from the grasp poses and
open lengths, and the collision status is determined by checking the distance
between the gripper points and the object mesh.

Args:
obj_pose (torch.Tensor): [4, 4] of float. The homogeneous transformation matrix of the object pose in the world frame.
grasp_poses (torch.Tensor): [B, 4, 4] of float. The homogeneous transformation matrices of the gripper root frame for B grasp poses.
open_lengths (torch.Tensor): [B, ] of float. The opening lengths of the gripper fingers for B grasp poses.
collision_threshold (float, optional): Collision distance threshold. Defaults to 0.0.
is_visual (bool, optional): whether to visualize collision result. Defaults to False.

Returns:
torch.Tensor: [B, ] boolean tensor indicating whether a grasp pose is collided.
"""
inv_obj_pose = obj_pose.clone()
inv_obj_pose[:3, :3] = obj_pose[:3, :3].T
inv_obj_pose[:3, 3] = -obj_pose[:3, 3] @ obj_pose[:3, :3]
inv_obj_poses = inv_obj_pose[None, :, :].repeat(grasp_poses.shape[0], 1, 1)
grasp_relative_pose = torch.bmm(inv_obj_poses, grasp_poses)
gripper_pc = self._get_gripper_pc(grasp_relative_pose, open_lengths)
return self._checker.query_batch_points(
gripper_pc, collision_threshold=collision_threshold, is_visual=is_visual
gripper_pc_obj = self._get_gripper_pc(grasp_relative_pose, open_lengths)
is_obj_gripper_collided, obj_gripper_dis = self._checker.query_batch_points(
gripper_pc_obj, collision_threshold=collision_threshold, is_visual=is_visual
)

if is_filter_ground_collision:
gripper_pc_world = self._get_gripper_pc(grasp_poses, open_lengths)
ground_height = self.get_ground_height(obj_pose)
gripper_ground_dis = gripper_pc_world[:, :, 2] - ground_height
is_gripper_ground_collided = gripper_ground_dis < collision_threshold

is_gripper_collided = torch.logical_or(
is_obj_gripper_collided, is_gripper_ground_collided
)
gripper_dis = torch.min(obj_gripper_dis, gripper_ground_dis)
else:
is_gripper_collided = is_obj_gripper_collided
gripper_dis = obj_gripper_dis

if is_visual:
n_batch = grasp_poses.shape[0]
# visualize all collision result
frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1)
for i in range(n_batch):
query_points_o3d = o3d.geometry.PointCloud()
query_points_np = gripper_pc_obj[i].cpu().numpy()
query_points_o3d.points = o3d.utility.Vector3dVector(query_points_np)
query_points_color = np.zeros_like(query_points_np)
query_points_color[is_gripper_collided[i].cpu().numpy()] = [
1.0,
0,
0,
] # red for colliding points
query_points_color[~is_gripper_collided[i].cpu().numpy()] = [
0,
1.0,
0,
] # green for non-colliding points
query_points_o3d.colors = o3d.utility.Vector3dVector(query_points_color)
o3d.visualization.draw_geometries(
[self._checker.mesh, query_points_o3d, frame],
mesh_show_back_face=True,
)

return is_obj_gripper_collided.any(dim=1), obj_gripper_dis.min(dim=1).values


def box_surface_grid(
size: Sequence[float] | torch.Tensor,
Expand Down
4 changes: 2 additions & 2 deletions embodichain/utils/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,9 +1219,9 @@ def transform_points_mat(
Returns:
transformed: [B, P, 3] transformed point cloud for each pose.
"""
R = poses[:, :3, :3] # [B, 3, 3]
r = poses[:, :3, :3] # [B, 3, 3]
t = poses[:, :3, 3] # [B, 3]
transformed = torch.einsum("bij, pj -> bpi", R, points) + t.unsqueeze(1)
transformed = torch.einsum("bij, pj -> bpi", r, points) + t.unsqueeze(1)
return transformed


Expand Down
7 changes: 6 additions & 1 deletion scripts/tutorials/grasp/grasp_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tenso
antipodal_sampler_cfg=AntipodalSamplerCfg(
n_sample=20000, max_length=0.088, min_length=0.003
),
is_partial_annotate=True,
is_filter_ground_collision=True,
)
sim.open_window()

Expand Down Expand Up @@ -266,7 +268,10 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tenso
)[0]
for i, obj_pose in enumerate(obj_poses):
is_success, grasp_pose, open_length = grasp_generator.get_grasp_poses(
obj_pose, approach_direction, visualize_pose=False
obj_pose,
approach_direction,
visualize_collision=False,
visualize_pose=False,
)
if is_success:
grasp_xpos_list.append(grasp_pose.unsqueeze(0))
Expand Down
4 changes: 3 additions & 1 deletion tests/toolkits/test_batch_convex_collision.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,11 @@ def batch_convex_collision_query(device=torch.device("cuda")):
obj_faces = torch.tensor(obj_mesh.faces, dtype=torch.int32, device=device)
test_pc = transform_points_mat(obj_verts, poses)

is_pose_collide, pose_surface_distance = collision_checker.query_batch_points(
is_point_collide, point_surface_distance = collision_checker.query_batch_points(
test_pc, collision_threshold=0.003, is_visual=False
)
is_pose_collide = is_point_collide.any(dim=1)
pose_surface_distance = point_surface_distance.min(dim=1).values
assert is_pose_collide.sum().item() == 1
assert abs(pose_surface_distance.max().item() - 0.8492) < 1e-2

Expand Down
Loading