import numpy as np
import torch
from torch import Tensor
from typing import Tuple, Union, Optional, Iterable, List, overload
from numpy.typing import NDArray
import trimesh
def get_device_dtype(*args, **kwargs) -> Tuple[Optional[str], Optional[torch.dtype]]:
    """
    Parse PyTorch-style ``.to(...)`` arguments into a ``(device, dtype)`` pair.

    Supported call forms (mirroring ``torch.Tensor.to``):

    - ``get_device_dtype(device)`` with ``device`` a ``str`` or ``torch.device``
    - ``get_device_dtype(dtype)`` with ``dtype`` a ``torch.dtype``
    - ``get_device_dtype(tensor)`` for any object exposing ``.device`` and ``.dtype``
    - ``get_device_dtype(device, dtype)``
    - ``get_device_dtype(device=..., dtype=...)``

    Keyword arguments override values parsed from positional arguments.

    Returns
    -------
    Tuple[Optional[str], Optional[torch.dtype]]
        The device as a string (``None`` if not given) and the dtype
        (``None`` if not given).

    Raises
    ------
    TypeError
        If a positional argument has an unsupported type, or more than two
        positional arguments are given.
    ValueError
        If neither a device nor a dtype was specified.
    """
    device: Optional[Union[str, torch.device]] = None
    dtype: Optional[torch.dtype] = None

    if len(args) == 1:
        (arg,) = args
        if isinstance(arg, (torch.device, str)):
            device = str(arg)
        elif isinstance(arg, torch.dtype):
            dtype = arg
        elif hasattr(arg, "device") and hasattr(arg, "dtype"):  # tensor-like
            device = str(arg.device)
            dtype = arg.dtype
        else:
            raise TypeError(f"Unsupported argument type: {type(arg)}")
    elif len(args) == 2:
        device, dtype = args
        if not isinstance(device, (torch.device, str)):
            raise TypeError(
                f"Expected str or torch.device for first argument, got {type(device)}"
            )
        if not isinstance(dtype, torch.dtype):
            raise TypeError(
                f"Expected torch.dtype for second argument, got {type(dtype)}"
            )
    elif len(args) > 2:
        # Reject surplus positional arguments instead of silently ignoring them.
        raise TypeError(f"Expected at most 2 positional arguments, got {len(args)}")

    # Keyword arguments override positional ones.
    device = kwargs.get("device", device)
    dtype = kwargs.get("dtype", dtype)
    if device is None and dtype is None:
        raise ValueError("At least one of device or dtype must be specified.")
    # Normalize any torch.device (positional or keyword) to its string form.
    if isinstance(device, torch.device):
        device = str(device)
    return device, dtype
[docs]
class Mesh:
def __init__(
self,
vertices: Union[NDArray, Tensor],
faces: Union[NDArray, Tensor],
adjacencies: Optional[Union[NDArray, Tensor]] = None,
triangle_normals: Optional[Union[NDArray, Tensor]] = None,
v2t: Optional[Union[NDArray, Tensor]] = None,
vertex_normals: Optional[Union[NDArray, Tensor]] = None,
device: str | torch.device = "cpu",
dtype: torch.dtype = torch.float32,
):
"""
Initialize a Mesh object with vertices, faces, and optional precomputed data.
Parameters
----------
vertices: NDArray | Tensor
(V, 3) array of vertex positions.
faces: NDArray | Tensor
(F, 3) array of triangle vertex indices.
adjacencies: Optional[NDArray | Tensor]
(F, 3) stores the adjacent triangle indices for each triangle edge.
triangle_normals: Optional[NDArray | Tensor]
(F, 3) array of triangle normal vectors.
v2t: Optional[NDArray | Tensor]
(V, max_tris_per_vertex + 1) array mapping vertices to incident triangles.
The first column stores the number of incident triangles, followed by the
triangle indices.
vertex_normals: Optional[NDArray | Tensor]
(V, 3) array of vertex normal vectors.
device: str | torch.device
The device to store the mesh on.
dtype: torch.dtype
The data type for the vertex positions and normals.
"""
self.dtype: torch.dtype = dtype
self.vertices = torch.as_tensor(
vertices, dtype=dtype, device=device
).requires_grad_(False)
self.faces = torch.as_tensor(
faces, dtype=torch.int32, device=device
).requires_grad_(False)
if vertex_normals is None:
mesh_trimesh = trimesh.Trimesh(
vertices=self.vertices.cpu().numpy(),
faces=self.faces.cpu().numpy(),
)
vertex_normals = mesh_trimesh.vertex_normals
if triangle_normals is None:
triangle_normals = mesh_trimesh.face_normals
self.adjacencies = torch.as_tensor(
adjacencies if adjacencies is not None else self._compute_adjacencies(),
dtype=torch.int32,
device=device,
).requires_grad_(False)
self.triangle_normals = torch.as_tensor(
triangle_normals, dtype=dtype, device=device
).requires_grad_(False)
self.v2t = torch.as_tensor(
v2t if v2t is not None else self._compute_vertex_to_triangle_map(),
dtype=torch.int32,
device=device,
).requires_grad_(False)
self.vertex_normals = torch.as_tensor(
vertex_normals, dtype=dtype, device=device
).requires_grad_(False)
self.device: torch.device = self.vertices.device
@overload
def to(self, device: Union[str, torch.device]) -> "Mesh": ...
@overload
def to(self, dtype: torch.dtype) -> "Mesh": ...
@overload
def to(self, device: Union[str, torch.device], dtype: torch.dtype) -> "Mesh": ...
@overload
def to(self, tensor: "Tensor") -> "Mesh": ...
@overload
def to(
self,
*,
device: Optional[Union[str, torch.device]] = ...,
dtype: Optional[torch.dtype] = ...,
) -> "Mesh": ...
[docs]
def to(self, *args, **kwargs) -> "Mesh":
"""
Move the mesh to a different device and/or dtype.
Supports PyTorch-like overloads:
- to(device)
- to(dtype)
- to(device, dtype)
- to(device=..., dtype=...)
"""
device, dtype = get_device_dtype(*args, **kwargs)
if device is None:
device = self.device
if dtype is None:
dtype = self.dtype
self.vertices = self.vertices.to(device=device, dtype=dtype)
self.faces = self.faces.to(device=device)
self.adjacencies = self.adjacencies.to(device=device)
self.triangle_normals = self.triangle_normals.to(device=device, dtype=dtype)
self.v2t = self.v2t.to(device=device)
self.vertex_normals = self.vertex_normals.to(device=device, dtype=dtype)
self.device = self.vertices.device
self.dtype = dtype
return self
def _compute_adjacencies(self) -> NDArray[np.int32]:
triangles = self.faces.cpu().numpy()
num_triangles = len(triangles)
adjacencies = -np.ones((num_triangles, 3), dtype=np.int32)
edge_to_triangle = {}
for tri_idx, tri in enumerate(triangles):
for local_edge_idx, (i, j) in enumerate([(0, 1), (1, 2), (2, 0)]):
edge = tuple(sorted((tri[i], tri[j])))
if edge in edge_to_triangle:
other_tri_idx, other_local_edge_idx = edge_to_triangle[edge]
adjacencies[tri_idx, local_edge_idx] = other_tri_idx
adjacencies[other_tri_idx, other_local_edge_idx] = tri_idx
else:
edge_to_triangle[edge] = (tri_idx, local_edge_idx)
return adjacencies
def _compute_vertex_to_triangle_map(self) -> NDArray[np.int32]:
triangles = self.faces.cpu().numpy()
num_vertices = self.vertices.shape[0]
v2t = [[] for _ in range(num_vertices)]
max_len = 0
for tri_idx, tri in enumerate(triangles):
for v in tri:
v2t[v].append(tri_idx)
max_len = max(max_len, len(v2t[v]))
v2t_array = -np.ones((num_vertices, max_len + 1), dtype=np.int32)
for i, tris in enumerate(v2t):
v2t_array[i, 0] = len(tris)
v2t_array[i, 1 : len(tris) + 1] = tris
return v2t_array
class MeshBatch(Mesh):
    """
    Several meshes concatenated into a single Mesh so they can be processed
    together, with bookkeeping to recover the individual meshes and points.
    """

    def __init__(self, meshes: List[Mesh]):
        """
        Allows batched operations of meshes.

        Vertex/face indices of mesh ``i`` are offset by ``vertex_idx[i]`` and
        triangle indices (in ``adjacencies`` and ``v2t``) by ``triangle_idx[i]``.
        The -1 "no entry" sentinel in ``adjacencies`` and ``v2t`` is kept as -1,
        so it never points into a neighbouring mesh.
        MeshPointBatch objects must also be batched using the MeshBatch with
        batch_points and unbatch_points.

        Parameters
        ----------
        meshes: List[Mesh]
            A list of Mesh objects to batch together. They must live on the same
            device (that of ``meshes[0]``), since their tensors are concatenated.

        Raises
        ------
        ValueError
            If ``meshes`` is empty.

        >>> mesh1 = load_mesh_from_file(path_to_mesh1, device=device)
        >>> mesh2 = load_mesh_from_file(path_to_mesh2, device=device)
        >>> start_points1 = uniform_sampling(mesh1, N).to(device)
        >>> start_points2 = uniform_sampling(mesh2, N).to(device)
        >>> mesh_batch = MeshBatch([mesh1, mesh2])
        >>> start_batched_points = mesh_batch.batch_points(
        ...     [start_points1, start_points2]
        ... )
        >>> start_dirs = torch.randn(
        ...     (2*N, 3), dtype=torch.float32
        ... ).to(device)
        >>> end_batched_points, geodesic_info = trace_geodesics(
        ...     mesh_batch, start_batched_points, start_dirs
        ... )
        >>> [end_points1, end_points2] = mesh_batch.unbatch_points(
        ...     end_batched_points
        ... )
        """
        if not meshes:
            raise ValueError("MeshBatch must contain at least one mesh.")
        device = meshes[0].device
        # Prefix sums: mesh i owns vertices [vertex_idx[i], vertex_idx[i + 1])
        # and triangles [triangle_idx[i], triangle_idx[i + 1]).
        self.vertex_idx = torch.cumsum(
            torch.tensor(
                [0] + [mesh.vertices.shape[0] for mesh in meshes],
                dtype=torch.int32,
                device=device,
            ),
            dim=0,
        ).requires_grad_(False)
        self.triangle_idx = torch.cumsum(
            torch.tensor(
                [0] + [mesh.faces.shape[0] for mesh in meshes],
                dtype=torch.int32,
                device=device,
            ),
            dim=0,
        ).requires_grad_(False)
        # Plain Python offsets avoid device mismatches and repeated syncs.
        vertex_offsets = self.vertex_idx.tolist()
        triangle_offsets = self.triangle_idx.tolist()

        vertices = torch.cat([mesh.vertices for mesh in meshes], dim=0)
        faces = torch.cat(
            [mesh.faces + offset for mesh, offset in zip(meshes, vertex_offsets)],
            dim=0,
        )
        adjacencies = torch.cat(
            [
                self._offset_valid(mesh.adjacencies, offset)
                for mesh, offset in zip(meshes, triangle_offsets)
            ],
            dim=0,
        )
        triangle_normals = torch.cat([mesh.triangle_normals for mesh in meshes], dim=0)

        max_v2t_width = max(mesh.v2t.shape[1] for mesh in meshes)
        padded_v2t = []
        for mesh, offset in zip(meshes, triangle_offsets):
            shifted = torch.cat(
                [
                    mesh.v2t[:, :1],  # incident-triangle count: unchanged
                    self._offset_valid(mesh.v2t[:, 1:], offset),
                ],
                dim=1,
            )
            # Pad with the same -1 sentinel Mesh uses for unused v2t slots.
            padded_v2t.append(
                torch.nn.functional.pad(
                    shifted, (0, max_v2t_width - shifted.shape[1]), value=-1
                )
            )
        v2t = torch.cat(padded_v2t, dim=0)
        vertex_normals = torch.cat([mesh.vertex_normals for mesh in meshes], dim=0)
        super().__init__(
            vertices=vertices,
            faces=faces,
            adjacencies=adjacencies,
            triangle_normals=triangle_normals,
            v2t=v2t,
            vertex_normals=vertex_normals,
            device=device,
            dtype=meshes[0].dtype,
        )

    @staticmethod
    def _offset_valid(indices: Tensor, offset: int) -> Tensor:
        """Add ``offset`` to non-negative entries, leaving -1 sentinels intact."""
        return torch.where(indices >= 0, indices + offset, indices)

    def _mesh_at(self, i: int) -> Mesh:
        """Rebuild mesh ``i`` (0 <= i < len(self)) with mesh-local indices."""
        start_vertex = int(self.vertex_idx[i])
        end_vertex = int(self.vertex_idx[i + 1])
        start_triangle = int(self.triangle_idx[i])
        end_triangle = int(self.triangle_idx[i + 1])

        v2t = self.v2t[start_vertex:end_vertex]
        v2t = torch.cat(
            [
                v2t[:, :1],  # incident-triangle count: unchanged
                self._offset_valid(v2t[:, 1:], -start_triangle),
            ],
            dim=1,
        )
        if v2t.shape[0] > 0:
            # Drop padding columns that were only added to match wider meshes.
            v2t = v2t[:, : int(v2t[:, 0].max()) + 1]

        return Mesh(
            vertices=self.vertices[start_vertex:end_vertex],
            faces=self.faces[start_triangle:end_triangle] - start_vertex,
            adjacencies=self._offset_valid(
                self.adjacencies[start_triangle:end_triangle], -start_triangle
            ),
            triangle_normals=self.triangle_normals[start_triangle:end_triangle],
            v2t=v2t,
            vertex_normals=self.vertex_normals[start_vertex:end_vertex],
            device=self.device,
            dtype=self.dtype,
        )

    def unbatch(self) -> List[Mesh]:
        """
        Unbatch the mesh batch into a list of Mesh objects.

        Returns
        -------
        List[Mesh]
            A list of Mesh objects corresponding to the original meshes used to
            create the batch, with mesh-local indices and the batch's dtype.
        """
        return [self._mesh_at(i) for i in range(len(self))]

    def batch_points(self, meshpoints: "List[MeshPointBatch]") -> "MeshPointBatch":
        """
        Batch a list of MeshPointBatch objects into a single MeshPointBatch.

        The output MeshPointBatch will be detached from the original MeshPointBatch
        objects.

        Parameters
        ----------
        meshpoints: List[MeshPointBatch]
            A list of MeshPointBatch objects to batch together, where meshpoints[i]
            corresponds to the points on mesh i in the MeshBatch.

        Returns
        -------
        MeshPointBatch
            A single MeshPointBatch containing all the points from the input batches,
            with face indices adjusted to match the concatenated faces of the MeshBatch.

        Raises
        ------
        ValueError
            If ``meshpoints`` is empty or its length differs from ``len(self)``.
        """
        if not meshpoints:
            raise ValueError("meshpoints must contain at least one MeshPointBatch.")
        if len(meshpoints) != len(self):
            raise ValueError(
                "Number of MeshPointBatch objects must match the number of meshes in "
                "the MeshBatch."
            )
        triangle_offsets = self.triangle_idx.tolist()
        faces = torch.cat(
            [mp.faces + offset for mp, offset in zip(meshpoints, triangle_offsets)],
            dim=0,
        )
        uvs = torch.cat([mp.uvs.detach() for mp in meshpoints], dim=0)
        return MeshPointBatch(faces=faces, uvs=uvs).to(self.device)

    def unbatch_points(self, meshpoints: "MeshPointBatch") -> "List[MeshPointBatch]":
        """
        Unbatch a MeshPointBatch into a list of MeshPointBatch objects.

        Parameters
        ----------
        meshpoints: MeshPointBatch
            A MeshPointBatch containing points on the MeshBatch. The face indices in
            meshpoints should correspond to the concatenated faces of the MeshBatch.

        Returns
        -------
        List[MeshPointBatch]
            A list of MeshPointBatch objects, where the i-th batch contains the points
            corresponding to the i-th mesh in the MeshBatch. The face indices in each
            batch will be adjusted to correspond to the original faces of the individual
            meshes.
        """
        triangle_offsets = self.triangle_idx.tolist()
        meshpoint_batches = []
        for start_triangle, end_triangle in zip(
            triangle_offsets[:-1], triangle_offsets[1:]
        ):
            mask = (meshpoints.faces >= start_triangle) & (
                meshpoints.faces < end_triangle
            )
            meshpoint_batches.append(
                MeshPointBatch(
                    faces=meshpoints.faces[mask] - start_triangle,
                    uvs=meshpoints.uvs[mask],
                ).to(self.device)
            )
        return meshpoint_batches

    @overload
    def to(self, device: Union[str, torch.device]) -> "MeshBatch": ...
    @overload
    def to(self, dtype: torch.dtype) -> "MeshBatch": ...
    @overload
    def to(
        self, device: Union[str, torch.device], dtype: torch.dtype
    ) -> "MeshBatch": ...
    @overload
    def to(self, tensor: "Tensor") -> "MeshBatch": ...
    @overload
    def to(
        self,
        *,
        device: Optional[Union[str, torch.device]] = ...,
        dtype: Optional[torch.dtype] = ...,
    ) -> "MeshBatch": ...

    def to(self, *args, **kwargs) -> "MeshBatch":
        """
        Move the mesh batch to a different device and/or dtype, in place.

        Supports PyTorch-like overloads:
            - to(device)
            - to(dtype)
            - to(device, dtype)
            - to(tensor)
            - to(device=..., dtype=...)

        Delegates to ``Mesh.to`` (so ``self.device`` stays a ``torch.device``),
        then moves the batch offset tensors. Returns ``self``.
        """
        super().to(*args, **kwargs)
        self.vertex_idx = self.vertex_idx.to(device=self.device)
        self.triangle_idx = self.triangle_idx.to(device=self.device)
        return self

    def __len__(self) -> int:
        """Return the number of meshes in the batch."""
        return len(self.vertex_idx) - 1

    def __getitem__(self, idx: int) -> Mesh:
        """
        Return mesh ``idx`` with mesh-local indices (negative indices allowed).

        Raises
        ------
        TypeError
            If ``idx`` is not an integer.
        IndexError
            If ``idx`` is out of range.
        """
        if not isinstance(idx, int):
            raise TypeError("Index must be an integer.")
        num_meshes = len(self)
        if idx < 0:
            idx += num_meshes
        if idx < 0 or idx >= num_meshes:
            raise IndexError("Index out of bounds for MeshBatch.")
        return self._mesh_at(idx)
class MeshPoint:
    """A single point on a triangle mesh: a face index plus (u, v) coordinates."""

    def __init__(self, face: int, uv: Tensor):
        """
        Initialize a MeshPoint with a face index and UV coordinates.

        This should generally not be used directly, instead use MeshPointBatch for
        batching.

        Parameters
        ----------
        face: int
            Index of the triangle containing the point.
        uv: Tensor
            Tensor ``(u, v)``; the point's position is
            ``(1 - u - v) * p0 + u * p1 + v * p2`` for the face's vertices.
        """
        self.face = face
        self.uv = uv

    def __str__(self) -> str:
        return f"MeshPoint(face={self.face}, uv={self.uv.tolist()})"

    __repr__ = __str__

    def interpolate(self, mesh: "Mesh") -> Tensor:
        """Return the 3D position of this point on ``mesh``."""
        u, v = self.uv[0], self.uv[1]
        p0, p1, p2 = (mesh.vertices[mesh.faces[self.face, k]] for k in range(3))
        return (1 - u - v) * p0 + u * p1 + v * p2

    def detach(self) -> "MeshPoint":
        """Return a copy whose ``uv`` is detached from the autograd graph."""
        return MeshPoint(self.face, self.uv.detach())

    def get_barycentric_coords(self) -> Tensor:
        """
        Return the barycentric coordinates ``(1 - u - v, u, v)`` as a tensor.

        ``torch.stack`` keeps the uv's device, dtype and autograd graph;
        building a new tensor from a Python list would drop the device and graph.
        """
        u, v = self.uv[0], self.uv[1]
        return torch.stack((1.0 - u - v, u, v))
[docs]
class MeshPointBatch:
def __init__(self, faces: Tensor, uvs: Tensor):
"""
Initialize a batch of MeshPoints.
"""
if faces.dim() != 1 or faces.dtype != torch.int32:
raise ValueError("faces must be a 1D tensor of dtype torch.int32.")
if uvs.dim() != 2 or uvs.size(1) != 2:
raise ValueError("uvs must be a 2D tensor with shape (N, 2).")
if faces.size(0) != uvs.size(0):
raise ValueError(
f"faces and uvs must have the same length. Found {faces.size(0)} and " \
f"{uvs.size(0)}."
)
self.faces: Tensor = faces
self.uvs: Tensor = uvs
self.uvs.requires_grad_()
[docs]
@classmethod
def from_points(cls, points: Iterable[MeshPoint]) -> "MeshPointBatch":
"""
Create a MeshPointBatch from an iterable of MeshPoint objects.
"""
faces = torch.tensor([p.face for p in points], dtype=torch.int32)
uvs = torch.stack([p.uv for p in points], dim=0)
return cls(faces, uvs)
@overload
def to(self, device: Union[str, torch.device]) -> "MeshPointBatch": ...
@overload
def to(self, dtype: torch.dtype) -> "MeshPointBatch": ...
@overload
def to(
self, device: Union[str, torch.device], dtype: torch.dtype
) -> "MeshPointBatch": ...
@overload
def to(self, tensor: "Tensor") -> "MeshPointBatch": ...
@overload
def to(
self,
*,
device: Optional[Union[str, torch.device]] = ...,
dtype: Optional[torch.dtype] = ...,
) -> "MeshPointBatch": ...
[docs]
def to(self, *args, **kwargs) -> "MeshPointBatch":
"""
Move the mesh to a different device and/or dtype.
Supports PyTorch-like overloads:
- to(device)
- to(dtype)
- to(device, dtype)
- to(device=..., dtype=...)
"""
device, dtype = get_device_dtype(*args, **kwargs)
if device is None:
device = str(self.uvs.device)
if dtype is None:
dtype = self.uvs.dtype
return MeshPointBatch(
faces=self.faces.to(device=device),
uvs=self.uvs.to(device=device, dtype=dtype),
)
[docs]
def detach(self) -> "MeshPointBatch":
return MeshPointBatch(faces=self.faces.detach(), uvs=self.uvs.detach())
[docs]
def clone(self) -> "MeshPointBatch":
return MeshPointBatch(faces=self.faces.clone(), uvs=self.uvs.clone())
[docs]
def interpolate(self, mesh: Mesh, return_batch: bool = False) -> Tensor:
"""
Interpolate the positions of the mesh points in the batch.
Parameters
----------
mesh: Mesh
The mesh on which the points lie. The face indices in self.faces should
correspond to the faces of this mesh.
return_batch: bool (default=False)
If True, return the interpolated positions as a batch of shape (B, N, 3),
where B is the number of meshes in the batch (if mesh is a MeshBatch)
and N is the number of points in the batch. If False, return a tensor of
shape (N, 3) containing the interpolated positions of the points.
Returns
-------
Tensor
The interpolated positions of the mesh points. Shape is (N, 3) if
return_batch is False, or (B, N, 3) if return_batch is True and
mesh is a MeshBatch.
"""
p0 = mesh.vertices[mesh.faces[self.faces, 0]]
p1 = mesh.vertices[mesh.faces[self.faces, 1]]
p2 = mesh.vertices[mesh.faces[self.faces, 2]]
pos = (
(1 - self.uvs[:, 0:1] - self.uvs[:, 1:2]) * p0
+ self.uvs[:, 0:1] * p1
+ self.uvs[:, 1:2] * p2
)
if return_batch:
if not isinstance(mesh, MeshBatch):
raise ValueError(
"Mesh must be a MeshBatch to return a batch of positions."
)
batch_size = mesh.vertex_idx.shape[0] - 1
return pos.view(batch_size, -1, 3)
else:
return pos
[docs]
def get_barycentric_coords(self) -> Tensor:
"""
Get the barycentric coordinates of the mesh points in the batch.
Returns
-------
Tensor
A tensor of shape (N, 3) containing the barycentric coordinates of each
point in the batch.
"""
return torch.cat(
[
1.0 - self.uvs[:, 0:1] - self.uvs[:, 1:2],
self.uvs[:, 0:1],
self.uvs[:, 1:2],
],
dim=1,
)
[docs]
def to_list(self) -> list[MeshPoint]:
return [
MeshPoint(int(face.item()), uv) for face, uv in zip(self.faces, self.uvs)
]
@overload
def __getitem__(self, idx: int) -> MeshPoint: ...
@overload
def __getitem__(self, idx: slice | Tensor) -> "MeshPointBatch": ...
def __getitem__(
self, idx: int | slice | Tensor
) -> "Union[MeshPoint,MeshPointBatch]":
if isinstance(idx, int):
return MeshPoint(int(self.faces[idx].item()), self.uvs[idx])
elif isinstance(idx, slice | Tensor):
return MeshPointBatch(faces=self.faces[idx], uvs=self.uvs[idx])
else:
raise TypeError(f"Invalid index type: {type(idx)}")
def __len__(self) -> int:
return len(self.faces)
def __iter__(self):
for i in range(len(self)):
yield self[i]