Source code for digeo.ops.geodesic

import torch
from torch import Tensor
from typing import Tuple, List, Optional
from torch.autograd.function import once_differentiable

from digeo.ops.geodesic_utils import straightest_geodesic
from digeo.mesh import MeshPoint, Mesh, MeshPointBatch
from digeo.ops.cuda.straightest_geodesic import (
    straightest_geodesics as cuda_straightest_geodesics,
)


class GeodesicInfo:
    """
    Side information produced by a geodesic trace.

    Holds the parallel transport rotation, the end directions and, when a
    debug tensor is supplied, the per-step positions, directions and normals
    of every traced path. Any piece that was not requested stays ``None``.
    """

    def __init__(
        self,
        debug_tensor: Optional[Tensor] = None,
        rotation: Optional[Tensor] = None,
        end_directions: Optional[Tensor] = None,
    ):
        """
        Stores ``rotation`` and ``end_directions`` as given and, if
        ``debug_tensor`` is provided, slices it into its components.

        The debug tensor is indexed as [batch, channel, step, xyz]: channel 0
        carries positions, channel 1 directions and channel 2 normals. Step 0
        of channel 0 is a header holding the path length (xyz index 0) and the
        number of vertex crossings (xyz index 1); the actual samples start at
        step 1.
        """
        has_debug = debug_tensor is not None

        self.rotation_tensor = rotation
        self.end_directions = end_directions

        self.path_len = debug_tensor[:, 0, 0, 0] if has_debug else None
        self.vertex_crossings = debug_tensor[:, 0, 0, 1] if has_debug else None
        self.path = debug_tensor[:, 0, 1:] if has_debug else None
        self.directions = debug_tensor[:, 1, 1:] if has_debug else None
        self.normals = debug_tensor[:, 2, 1:] if has_debug else None

    @property
    def rotation(self) -> Tensor:
        """
        The rotation matrices used for parallel transport.

        Raises a RuntimeError when no rotation was stored.
        """
        stored = self.rotation_tensor
        if stored is not None:
            return stored
        raise RuntimeError(
            "Rotation tensor is not set. Set 'save_parallel_transport=True' in "
            "trace_geodesics to save it."
        )

    def _rotation_checked_against(self, v: Tensor) -> Tensor:
        """
        Returns the stored rotation after validating that ``v`` is a
        [batch_size, 3] tensor whose batch size matches it.
        """
        stored = self.rotation
        if v.dim() != 2 or v.shape[1] != 3:
            raise RuntimeError(
                f"Input tensor should be of shape [batch_size, 3], got {v.shape}."
            )
        if v.shape[0] != stored.shape[0]:
            raise RuntimeError(
                "Input tensor batch size should match rotation tensor batch size, "
                f"got {v.shape[0]} and {stored.shape[0]}."
            )
        return stored

    def _first_steps(self, per_step: Optional[Tensor], i: int, error: str) -> Tensor:
        """
        Returns the first ``path_len[i]`` entries of ``per_step[i]``, or raises
        a RuntimeError carrying ``error`` when the debug data is missing.
        """
        if per_step is None or self.path_len is None:
            raise RuntimeError(error)
        count = int(self.path_len[i].item())
        return per_step[i, :count]

    def transport(self, v: Tensor) -> Tensor:
        """
        Parallel transports the direction vectors ``v`` (shape [batch_size, 3])
        by multiplying each row vector with its rotation matrix.
        """
        rot = self._rotation_checked_against(v)
        row_vectors = v.unsqueeze(1)
        return torch.bmm(row_vectors, rot).squeeze(1)

    def transport_inv(self, v: Tensor) -> Tensor:
        """
        Inverse parallel transports the direction vectors ``v`` (shape
        [batch_size, 3]) by multiplying with the transposed rotation matrices.
        """
        rot = self._rotation_checked_against(v)
        row_vectors = v.unsqueeze(1)
        return torch.bmm(row_vectors, rot.transpose(1, 2)).squeeze(1)

    def get_path(self, i: int) -> Tensor:
        """
        Returns the positions of the points of the geodesic starting from
        starts[i], as a tensor of shape [path_length, 3].
        """
        return self._first_steps(
            self.path,
            i,
            "Path is not set. Set 'debug=True' in trace_geodesics to save it.",
        )

    def get_directions(self, i: int) -> Tensor:
        """
        Returns the directions of the geodesic starting from starts[i], as a
        tensor of shape [path_length, 3].
        """
        return self._first_steps(
            self.directions,
            i,
            "Directions are not set. Set 'debug=True' in trace_geodesics to "
            "save it.",
        )

    def get_normals(self, i: int) -> Tensor:
        """
        Returns the normals of the geodesic starting from starts[i], as a
        tensor of shape [path_length, 3].
        """
        return self._first_steps(
            self.normals,
            i,
            "Normals are not set. Set 'debug=True' in trace_geodesics to save it.",
        )


def trace_geodesics_gfd(
    mesh: Mesh, x: MeshPointBatch, dirs: Tensor, kwargs: dict
) -> MeshPointBatch:
    """
    Traces the straightest geodesics on the mesh that start at the points x
    and leave in the directions dirs.

    The directions are first stripped of their component along the normal of
    each starting face; the trace itself is delegated to the
    StraighestGeodesicsGFD autograd function, which receives ``kwargs``
    unchanged.
    """
    face_normals = mesh.triangle_normals[x.faces]
    # Remove the out-of-plane part so the direction lies in the starting face.
    normal_component = torch.sum(dirs * face_normals, dim=1, keepdim=True)
    tangent_dirs = dirs - normal_component * face_normals

    end_faces, end_uvs = StraighestGeodesicsGFD.apply(
        mesh, x.faces, x.uvs, tangent_dirs, kwargs
    )
    return MeshPointBatch(faces=end_faces, uvs=end_uvs)


def trace_geodesics_abfd(
    mesh: Mesh, x: MeshPointBatch, dirs: Tensor, kwargs: dict
) -> MeshPointBatch:
    """
    Traces the straightest geodesics on the mesh that start at the points x
    and leave in the directions dirs.

    The directions are first stripped of their component along the normal of
    each starting face; the trace itself is delegated to the
    StraighestGeodesicsABFD autograd function, which receives ``kwargs``
    unchanged.
    """
    face_normals = mesh.triangle_normals[x.faces]
    # Remove the out-of-plane part so the direction lies in the starting face.
    normal_component = torch.sum(dirs * face_normals, dim=1, keepdim=True)
    tangent_dirs = dirs - normal_component * face_normals

    end_faces, end_uvs = StraighestGeodesicsABFD.apply(
        mesh, x.faces, x.uvs, tangent_dirs, kwargs
    )
    return MeshPointBatch(faces=end_faces, uvs=end_uvs)


def inv_2x2(x: Tensor) -> Tensor:
    """
    Computes the inverse of a batch of 2x2 matrices in closed form.

    Parameters
    ----------
    x : Tensor
        Tensor of shape [..., 2, 2]; any number of leading batch dimensions
        is accepted.

    Returns
    -------
    Tensor
        Tensor of the same shape as ``x`` holding the inverse of each matrix.
        For (near-)singular matrices the determinant magnitude is floored at
        1e-8, so the result is finite but not a true inverse.
    """
    tiny = 1e-8

    # Normalise by the largest absolute entry so the determinant is computed
    # on values of order one; the scaling is undone at the end.
    scale = x.abs().amax(dim=(-2, -1), keepdim=True).clamp(min=tiny)
    x_scaled = x / scale

    a = x_scaled[..., 0, 0]
    b = x_scaled[..., 0, 1]
    c = x_scaled[..., 1, 0]
    d = x_scaled[..., 1, 1]

    det = a * d - b * c
    # Floor the magnitude only: a plain clamp(min=tiny) would replace every
    # negative determinant by +tiny and return a wildly wrong inverse.
    safe_det = torch.where(det < 0, det.clamp(max=-tiny), det.clamp(min=tiny))
    inv_det = 1.0 / safe_det

    adjugate = torch.stack((d, -b, -c, a), dim=-1).reshape(x.shape)
    inv_x = inv_det[..., None, None] * adjugate
    # inverse(x) = inverse(x / scale) / scale
    return inv_x / scale


class StraighestGeodesicsGFD(torch.autograd.Function):
    """
    Autograd function behind the "gfd" gradient mode of trace_geodesics.

    ``forward`` traces the geodesics without gradients; ``backward`` estimates
    the gradients with respect to the start uvs and the start directions from
    forward finite differences of step ``EPS``, which costs six additional
    traces per backward call.
    """

    # Step size of the finite differences taken in backward().
    EPS = 1e-2

    @staticmethod
    def forward(
        ctx, mesh: Mesh, x0_faces: Tensor, x0_uvs: Tensor, v: Tensor, kwargs: dict
    ):
        """
        Traces the geodesics starting at (x0_faces, x0_uvs) along v.

        ``kwargs`` is forwarded to trace_geodesics as keyword arguments.
        Returns the faces and uvs of the end points. The start point, v, the
        end directions and the end point are saved for backward; the mesh and
        kwargs are kept on ``ctx``.
        """
        x0 = MeshPointBatch(faces=x0_faces, uvs=x0_uvs)
        x1, info = trace_geodesics(
            mesh, x0, v, gradient="none", save_end_direction=True, **kwargs
        )

        ctx.save_for_backward(
            x0.faces, x0_uvs, v, info.end_directions, x1.faces, x1.uvs
        )
        ctx.mesh = mesh
        ctx.kwargs = kwargs

        return x1.faces, x1.uvs

    @staticmethod
    @once_differentiable
    def backward(ctx, gr_x1_faces, grad_x1_uv: Tensor):
        """
        Finite-difference estimate of the gradients.

        ``gr_x1_faces`` is never read. Returns one entry per forward input
        (mesh, x0_faces, x0_uvs, v, kwargs); only x0_uvs and v get a gradient.
        """
        # ignore gr_x1_faces, should be None

        x0_faces, x0_uvs, v, v1, x1_faces, x1_uvs = ctx.saved_tensors
        x0 = MeshPointBatch(
            faces=x0_faces, uvs=x0_uvs
        )  # Reconstruct MeshPointBatch from saved tensors
        x1 = MeshPointBatch(
            faces=x1_faces, uvs=x1_uvs
        )  # Reconstruct MeshPointBatch from saved tensors
        mesh = ctx.mesh
        eps = StraighestGeodesicsGFD.EPS

        # Unit vectors along the two edges leaving vertex 0 of each start face;
        # they are the directions in which the start point is perturbed.
        # NOTE(review): these edges are normalised, whereas
        # StraighestGeodesicsABFD.backward dots with the raw edges — confirm
        # which scaling grad_x0_uvs is meant to have.
        p0 = mesh.vertices[mesh.faces[x0.faces][:, 0]]
        p1 = mesh.vertices[mesh.faces[x0.faces][:, 1]]
        p2 = mesh.vertices[mesh.faces[x0.faces][:, 2]]
        x0_e1 = p1 - p0
        x0_e2 = p2 - p0
        x0_e1 = x0_e1 / torch.norm(x0_e1, dim=1, keepdim=True)  # (B, 3)
        x0_e2 = x0_e2 / torch.norm(x0_e2, dim=1, keepdim=True)  # (B, 3)

        # Raw edges of each end face. T1_inv = (T1^T T1)^-1 T1^T is the left
        # pseudo-inverse of T1: it maps a 3D vector to its coefficients on
        # (x1_e1, x1_e2).
        p0 = mesh.vertices[mesh.faces[x1.faces][:, 0]]
        p1 = mesh.vertices[mesh.faces[x1.faces][:, 1]]
        p2 = mesh.vertices[mesh.faces[x1.faces][:, 2]]
        x1_e1 = p1 - p0
        x1_e2 = p2 - p0
        T1 = torch.stack((x1_e1, x1_e2), dim=-1)  # (B, 3, 2)
        TtT_inv = inv_2x2(torch.bmm(T1.mT, T1))  # [B, 2, 2]
        T1_inv = torch.bmm(TtT_inv, T1.mT)  # [B, 2, 3]

        # --- Gradient with respect to the start point ---
        # Move the start point by eps along each edge direction, keeping the
        # parallel transport rotation of that short move.
        x0_e1_eps, info_xe1 = trace_geodesics(
            mesh,
            x0,
            eps * x0_e1,
            gradient="none",
            save_parallel_transport=True,
            **ctx.kwargs,
        )  # perturbed start points (MeshPointBatch)
        x0_e2_eps, info_xe2 = trace_geodesics(
            mesh,
            x0,
            eps * x0_e2,
            gradient="none",
            save_parallel_transport=True,
            **ctx.kwargs,
        )  # perturbed start points (MeshPointBatch)

        # Carry the original direction to the perturbed start points.
        v_xe1 = info_xe1.transport(v)
        v_xe2 = info_xe2.transport(v)

        # Re-trace from the perturbed start points with the transported direction.
        y1, _ = trace_geodesics(
            mesh, x0_e1_eps, v_xe1, gradient="none", **ctx.kwargs
        )  # perturbed end points (MeshPointBatch)
        y2, _ = trace_geodesics(
            mesh, x0_e2_eps, v_xe2, gradient="none", **ctx.kwargs
        )  # perturbed end points (MeshPointBatch)

        # Jacobian: one column per perturbation, forward difference of the
        # interpolated end positions against the unperturbed end point.
        J_x0 = torch.stack(
            (
                (y1.interpolate(mesh) - x1.interpolate(mesh)) / eps,
                (y2.interpolate(mesh) - x1.interpolate(mesh)) / eps,
            ),
            dim=-1,
        )  # (B, 3, 2)

        # Express the 3D displacements in the edge basis of the end face.
        J_x0_local = torch.bmm(T1_inv, J_x0)  # (B, 2, 2)

        grad_x0_uvs = torch.bmm(grad_x1_uv.unsqueeze(1), J_x0_local).squeeze(
            1
        )  # (B, 1, 2) @ (B, 2, 2) -> (B, 2)

        # --- Gradient with respect to the start direction ---
        # d1 is the unit direction of v, d2 = n0 x d1 (in-plane and
        # perpendicular to d1 when v lies in the start face).
        # NOTE(review): d1 is NaN for rows where v is zero.
        n0 = mesh.triangle_normals[x0.faces]  # (B, 3)
        d1 = v / torch.norm(v, dim=1, keepdim=True)  # (B, 3)
        d2 = torch.cross(n0, d1, dim=1)  # (B, 3)

        # Start directly from x1 since the perturbation is along v
        # (presumably v1, the saved end direction, is unit length — TODO confirm).
        y1, _ = trace_geodesics(
            mesh, x1, eps * v1, gradient="none", **ctx.kwargs
        )  # perturbed end points (MeshPointBatch)
        # Here we need to retrace the geodesic from x0
        y2, _ = trace_geodesics(
            mesh, x0, v + eps * d2, gradient="none", **ctx.kwargs
        )  # perturbed end points (MeshPointBatch)

        # Jacobian: columns are the end point displacements for a perturbation
        # of v along d1 and along d2.
        J_v = torch.stack(
            (
                (y1.interpolate(mesh) - x1.interpolate(mesh)) / eps,
                (y2.interpolate(mesh) - x1.interpolate(mesh)) / eps,
            ),
            dim=-1,
        )  # (B, 3, 2)

        # Upstream uv gradient expressed as a 3D vector at the end point.
        grad_x1 = torch.bmm(grad_x1_uv.unsqueeze(1), T1_inv).squeeze(1)  # (B, 3)

        grad_v_local = torch.bmm(grad_x1.unsqueeze(1), J_v).squeeze(
            1
        )  # (B, 1, 3) @ (B, 3, 2) -> (B, 2)

        # Back from (d1, d2) coefficients to a 3D gradient.
        grad_v = (
            grad_v_local[:, 0:1] * d1 + grad_v_local[:, 1:2] * d2
        )  # (B, 1) * (B, 3) + (B, 1) * (B, 3) -> (B, 3)

        # One entry per forward input: mesh, x0_faces, x0_uvs, v, kwargs.
        return None, None, grad_x0_uvs, grad_v, None


class StraighestGeodesicsABFD(torch.autograd.Function):
    """
    Autograd function behind the "abfd" gradient mode of trace_geodesics.

    ``forward`` traces the geodesics without gradients; ``backward`` estimates
    the gradient from a single finite difference of step ``EPS``: the end
    point is nudged along the incoming gradient and the geodesic is traced
    back towards the start, which costs two additional traces per backward
    call.
    """

    # Step size of the finite difference taken in backward().
    EPS = 1e-2

    @staticmethod
    def forward(
        ctx, mesh: Mesh, x0_faces: Tensor, x0_uvs: Tensor, v: Tensor, kwargs: dict
    ):
        """
        Traces the geodesics starting at (x0_faces, x0_uvs) along v.

        ``kwargs`` is forwarded to trace_geodesics as keyword arguments.
        Returns the faces and uvs of the end points. The start point, v, the
        end directions and the end point are saved for backward; the mesh and
        kwargs are kept on ``ctx``.
        """
        x0 = MeshPointBatch(faces=x0_faces, uvs=x0_uvs)
        x1, info = trace_geodesics(
            mesh, x0, v, gradient="none", save_end_direction=True, **kwargs
        )

        ctx.save_for_backward(
            x0.faces, x0_uvs, v, info.end_directions, x1.faces, x1.uvs
        )
        ctx.mesh = mesh
        ctx.kwargs = kwargs

        return x1.faces, x1.uvs

    @staticmethod
    @once_differentiable
    def backward(ctx, gr_x1_faces, grad_x1_uv: Tensor):
        """
        Finite-difference estimate of the gradients.

        ``gr_x1_faces`` is never read. Returns one entry per forward input
        (mesh, x0_faces, x0_uvs, v, kwargs); only x0_uvs and v get a gradient.
        """
        x0_faces, x0_uvs, v, v1, x1_faces, x1_uvs = ctx.saved_tensors
        mesh = ctx.mesh
        x0 = MeshPointBatch(
            faces=x0_faces, uvs=x0_uvs
        )  # Reconstruct MeshPointBatch from saved tensors
        x1 = MeshPointBatch(
            faces=x1_faces, uvs=x1_uvs
        )  # Reconstruct MeshPointBatch from saved tensors
        eps = StraighestGeodesicsABFD.EPS

        # Give the end direction the length of v
        # (presumably v1 is saved as a unit vector — TODO confirm).
        v_norm = v.norm(dim=-1, keepdim=True)
        v1 = v_norm * v1
        # Re-project v onto the start face and restore its length.
        # NOTE(review): v is not read again below, so these lines currently
        # have no effect on the returned gradients.
        n = mesh.triangle_normals[x0.faces]
        v = v - torch.sum(v * n, dim=-1, keepdim=True) * n
        v = v_norm * (v / v.norm(dim=-1, keepdim=True))

        # Raw edges of each end face. T1_inv = (T1^T T1)^-1 T1^T is the left
        # pseudo-inverse of T1: it maps a 3D vector to its coefficients on
        # (x1_e1, x1_e2).
        p0 = mesh.vertices[mesh.faces[x1.faces][:, 0]]
        p1 = mesh.vertices[mesh.faces[x1.faces][:, 1]]
        p2 = mesh.vertices[mesh.faces[x1.faces][:, 2]]
        x1_e1 = p1 - p0
        x1_e2 = p2 - p0
        T1 = torch.stack((x1_e1, x1_e2), dim=-1)  # (B, 3, 2)
        TtT_inv = inv_2x2(torch.bmm(T1.mT, T1))  # [B, 2, 2]
        T1_inv = torch.bmm(TtT_inv, T1.mT)  # [B, 2, 3]

        # Upstream uv gradient as a 3D vector at the end point, split into its
        # length and its unit direction.
        # NOTE(review): rows whose upstream gradient is exactly zero divide
        # 0 by 0 here and propagate NaN into the traces below — confirm
        # whether that can happen for callers (e.g. masked losses).
        x1_e = torch.bmm(grad_x1_uv.unsqueeze(1), T1_inv).squeeze(
            1
        )  # (B, 1, 2) @ (B, 2, 3) -> (B, 3)
        len_grad_x1 = torch.norm(x1_e, dim=1, keepdim=True)
        x1_e = x1_e / x1_e.norm(dim=-1, keepdim=True)  # (B, 3)

        # Nudge the end point by eps along the gradient direction, keeping
        # the parallel transport rotation of that short move.
        x1_eps, info = trace_geodesics(
            mesh,
            x1,
            eps * x1_e,
            gradient="none",
            save_parallel_transport=True,
            **ctx.kwargs,
        )  # nudged end points (MeshPointBatch)
        # Carry the end direction to the nudged end point and trace backwards.
        v1_e = info.transport(v1)
        x0_e, _ = trace_geodesics(
            mesh, x1_eps, -v1_e, gradient="none", **ctx.kwargs
        )  # back-traced start points (MeshPointBatch)

        # Displacement of the back-traced start point per unit nudge, scaled
        # by the length of the upstream gradient.
        grad_v = (x0_e.interpolate(mesh) - x0.interpolate(mesh)) / eps
        grad_v = len_grad_x1 * grad_v

        # Keep only the component lying in the start face.
        n = mesh.triangle_normals[x0.faces]
        grad_v = grad_v - torch.sum(grad_v * n, dim=-1, keepdim=True) * n

        # Raw (unnormalised) edges of each start face.
        p0 = mesh.vertices[mesh.faces[x0.faces][:, 0]]
        p1 = mesh.vertices[mesh.faces[x0.faces][:, 1]]
        p2 = mesh.vertices[mesh.faces[x0.faces][:, 2]]
        x0_e1 = p1 - p0
        x0_e2 = p2 - p0

        # The same 3D vector doubles as the gradient with respect to the start
        # position; dotting it with the edges gives the uv gradient.
        grad_x0_u = torch.sum(grad_v * x0_e1, dim=1, keepdim=True)  # (B, 1)
        grad_x0_v = torch.sum(grad_v * x0_e2, dim=1, keepdim=True)  # (B, 1)
        grad_x0_uvs = torch.cat([grad_x0_u, grad_x0_v], dim=1)  # (B, 2)

        # One entry per forward input: mesh, x0_faces, x0_uvs, v, kwargs.
        return None, None, grad_x0_uvs, grad_v, None


def project_batch_to_plane(dirs: Tensor, normals: Tensor) -> Tensor:
    """
    Subtracts from each row of ``dirs`` its dot product with the matching row
    of ``normals`` times that normal.

    This is the orthogonal projection onto the plane perpendicular to the
    normal when the normals are unit length (not checked here).
    """
    along_normal = (dirs * normals).sum(dim=1, keepdim=True)
    return dirs - normals * along_normal


def get_tensors_from_path(
    mesh: Mesh,
    end_points: MeshPointBatch,
    end_dirs: Tensor,
    end_normals: Tensor,
    starts: MeshPointBatch,
    start_dirs: Tensor,
    debug=False,
) -> Tuple[Tensor, Optional[Tensor]]:
    """
    Re-expresses traced end points as a differentiable function of the start
    points and start directions.

    A constant (detached) rotation R maps the start frame (direction, normal,
    normal x direction) onto the end frame, and a constant translation T is
    chosen so that the rotated point ``start + projected direction`` lands
    exactly on the traced end position. The returned tensor therefore has the
    value of the end position while gradients flow only through
    ``starts.interpolate(mesh)`` and ``start_dirs``.

    Returns the 3D end positions and the indices of the rows whose end
    direction norm is below 10e-6 (their rotation is replaced by the
    identity). ``debug`` is accepted but not used.
    """
    n_start = mesh.triangle_normals[starts.faces]
    tangent_step = project_batch_to_plane(start_dirs, n_start)
    unit_start = tangent_step / torch.linalg.norm(tangent_step, dim=1, keepdim=True)

    end_len = torch.linalg.norm(end_dirs, dim=1, keepdim=True)
    indices = torch.where(end_len < 10e-6)[0]
    # Rows with a vanishing end direction divide by ~0 here; their rotation is
    # overwritten with the identity below.
    unit_end = end_dirs / end_len

    # Rows of each frame: direction, normal, normal x direction.
    frame_start = torch.stack(
        [unit_start, n_start, torch.cross(n_start, unit_start, dim=1)], dim=1
    )
    frame_end = torch.stack(
        [unit_end, end_normals, torch.cross(end_normals, unit_end, dim=1)], dim=1
    )

    # Rotation taking start-frame coordinates to the end frame.
    R = torch.bmm(frame_start.transpose(1, 2), frame_end).detach()
    R[indices] = torch.eye(3, device=R.device, dtype=R.dtype).unsqueeze(0)

    def rotate(points: Tensor) -> Tensor:
        return torch.bmm(points.unsqueeze(1), R).squeeze(1)

    start_points = starts.interpolate(mesh)

    # Constant offset that makes the rotated target coincide with the traced
    # end position; everything entering it is detached.
    frozen_target = starts.interpolate(mesh).detach() + project_batch_to_plane(
        start_dirs.detach(), n_start
    )
    T = end_points.interpolate(mesh).detach() - rotate(frozen_target).detach()

    # Same expression with gradients enabled.
    tensor_point = rotate(start_points + tangent_step) + T

    return tensor_point, indices


def tri_bary_coords(p0: Tensor, p1: Tensor, p2: Tensor, p: Tensor) -> Tensor:
    """
    Compute barycentric coordinates of p in the triangle (p0, p1, p2).

    All inputs are batched row-wise; the result stacks (u, v, w) along dim 1,
    with u the weight of p0, v of p1 and w of p2. Triangles whose Gram
    determinant is exactly zero get the coordinates (1, 0, 0).
    """

    def rowdot(lhs: Tensor, rhs: Tensor) -> Tensor:
        return (lhs * rhs).sum(dim=1)

    edge_a = p1 - p0
    edge_b = p2 - p0
    offset = p - p0

    aa = rowdot(edge_a, edge_a)
    ab = rowdot(edge_a, edge_b)
    bb = rowdot(edge_b, edge_b)
    pa = rowdot(offset, edge_a)
    pb = rowdot(offset, edge_b)

    gram_det = aa * bb - ab * ab
    degenerate = gram_det == 0
    # Divide by 1 for degenerate triangles; their result is overwritten below.
    safe_det = gram_det.masked_fill(degenerate, 1.0)

    v_raw = (bb * pa - ab * pb) / safe_det
    w_raw = (aa * pb - ab * pa) / safe_det
    u_raw = 1.0 - v_raw - w_raw

    u = u_raw.masked_fill(degenerate, 1.0)
    v = v_raw.masked_fill(degenerate, 0.0)
    w = w_raw.masked_fill(degenerate, 0.0)

    return torch.stack([u, v, w], dim=1)


def get_meshpoints(mesh: Mesh, faces: Tensor, points: Tensor) -> MeshPointBatch:
    """
    Turns 3D positions into mesh points on the given faces.

    The barycentric coordinates of each point are computed against the three
    vertices of its face; the last two of them are stored as the uvs.
    """
    corner_ids = mesh.faces[faces]
    corners = [mesh.vertices[corner_ids[:, k]] for k in range(3)]
    weights = tri_bary_coords(corners[0], corners[1], corners[2], points)
    return MeshPointBatch(faces=faces, uvs=weights[:, 1:])


def trace_geodesics(
    mesh: Mesh,
    starts: List[MeshPoint] | MeshPointBatch,
    dirs: torch.Tensor,
    gradient: str = "gfd",
    use_python: bool = False,
    max_steps: int = 2000,
    save_parallel_transport: bool = False,
    save_end_direction: bool = False,
    debug: bool = False,
    print_warnings: bool = True,
    eps: float = -1.0,
    avoid_holes: bool = False,
) -> Tuple[MeshPointBatch, GeodesicInfo]:
    """
    Computes the straightest geodesic traces.

    For a detailed description, see :ref:`straightest_geodesic` in the user
    guide.

    Parameters
    ----------
    mesh : Mesh
        The mesh to trace the geodesics on.
    starts : List[MeshPoint] | MeshPointBatch
        The starting points from which the geodesics start.
    dirs : torch.Tensor
        The direction tensor (in 3d world coordinates), of shape
        [batch_size, 3].
    gradient : str
        What method to compute the gradient, can be "none", "ep", "abfd", or
        "gfd" (default: "gfd").
    use_python : bool
        If the computation should be made with python, making it easier to
        debug (default: False).
    max_steps : int
        The maximum steps for the traces (triangle hops) (default: 2000).
    save_parallel_transport : bool
        If the parallel transport rotation should be saved in the
        GeodesicInfo (default: False).
    save_end_direction : bool
        If the end direction should be saved in the GeodesicInfo
        (default: False).
    debug : bool
        If the returned GeodesicInfo should contain additional debug info
        (default: False).
    print_warnings : bool
        If warnings should be printed (default: True).
    eps : float
        A small epsilon value to avoid numerical issues (defaults to 10e-4
        for floats and 10e-7 for doubles).
    avoid_holes : bool
        If the geodesic tracing should circumvent holes in the mesh
        (default: False). This is recommended for meshes with small holes,
        but disrupts the geodesic tracing.

    Returns
    -------
    MeshPointBatch, GeodesicInfo
        The endpoints of the geodesic traces and geodesic information.

    Raises
    ------
    RuntimeError
        If ``dirs`` is not of shape [batch_size, 3], if the batch sizes of
        ``starts`` and ``dirs`` differ, or if the mesh, starts and dirs are
        not all on the same device.
    ValueError
        If ``gradient`` is not a known method, or if ``use_python`` is
        combined with a non-CPU mesh or with the "gfd" gradient.

    Examples
    --------
    >>> mesh = load_mesh_from_file(path_to_mesh, device=device)
    >>> start_meshpoints = uniform_sampling(mesh, N).to(device)
    >>> start_directions = torch.randn((N, 3), dtype=torch.float32).to(device)
    >>> meshpoints, geodesic_info = trace_geodesics(
    ...     mesh, start_meshpoints, start_directions
    ... )
    """
    if dirs.dim() != 2 or dirs.shape[1] != 3:
        raise RuntimeError("Dirs tensor should be of shape [batch_size, 3]")
    if len(starts) != dirs.shape[0]:
        raise RuntimeError(
            f"Starts and dirs should have the same batch size, got {len(starts)} "
            f"and {dirs.shape[0]}."
        )
    if use_python and str(mesh.device) != "cpu":
        raise ValueError(
            "Python geodesic tracing is only supported on CPU. Tensors must be on CPU "
            "for Python tracing."
        )
    if use_python and gradient == "gfd":
        raise ValueError(
            "GFD gradient is not supported with Python geodesic tracing. "
            "Use C++/CUDA for this feature."
        )

    if not isinstance(starts, MeshPointBatch):
        starts = MeshPointBatch.from_points(starts).to(device=mesh.device)

    # Check if all tensors are on the same device
    if mesh.device != starts.uvs.device:
        raise RuntimeError(
            f"Tensors must be on the same device. Found mesh on {mesh.device}, start "
            f"uvs on {starts.uvs.device}."
        )
    if mesh.device != starts.faces.device:
        raise RuntimeError(
            f"Tensors must be on the same device. Found mesh on {mesh.device}, start "
            f"faces on {starts.faces.device}."
        )
    if mesh.device != dirs.device:
        raise RuntimeError(
            f"Tensors must be on the same device. Found mesh on {mesh.device}, dirs "
            f"on {dirs.device}."
        )

    # Reject unknown methods before spending time on a trace.
    if gradient not in ("none", "ep", "gfd", "abfd"):
        raise ValueError(
            f"Unknown gradient method: {gradient}. Use 'none', 'ep', "
            f"'gfd', or 'abfd'."
        )

    # The "gfd" and "abfd" autograd functions run their own trace, so the
    # plain trace below is only needed when its by-products are requested.
    needs_plain_trace = (
        gradient not in ("gfd", "abfd")
        or debug
        or save_parallel_transport
        or save_end_direction
    )

    debug_tensor = None
    results = None
    if needs_plain_trace:
        with torch.no_grad():
            if use_python:
                results, debug_tensor = _trace_geodesics_python(
                    mesh,
                    starts,
                    dirs,
                    max_steps=max_steps,
                    debug=debug,
                    print_warnings=print_warnings,
                    eps=eps,
                    avoid_holes=avoid_holes,
                )
            else:
                results = cuda_straightest_geodesics(
                    mesh,
                    starts,
                    dirs.detach(),
                    max_steps=max_steps,
                    debug=debug,
                    print_warnings=print_warnings,
                    eps=eps,
                    avoid_holes=avoid_holes,
                )
                if debug:
                    debug_tensor = results[3]

    kwargs = dict(
        max_steps=max_steps,
        print_warnings=print_warnings,
        eps=eps,
        avoid_holes=avoid_holes,
    )

    # Compute gradients if needed
    if gradient == "none":
        mesh_points = results[0]
    elif gradient == "ep":
        tensor_points, error_indices = get_tensors_from_path(
            mesh, results[0], results[1], results[2], starts, dirs, debug=debug
        )
        mesh_points = get_meshpoints(mesh, results[0].faces, tensor_points)
        if debug and error_indices is not None:
            for idx in error_indices.tolist():
                print(
                    f"Warning: null output for face {starts[idx].face}, "
                    f"uv {starts[idx].uv.tolist()}, dir {dirs[idx].tolist()}."
                )
    elif gradient == "gfd":
        mesh_points = trace_geodesics_gfd(mesh, starts, dirs, kwargs)
    else:  # "abfd", the only remaining validated method
        mesh_points = trace_geodesics_abfd(mesh, starts, dirs, kwargs)

    # Compute parallel transport rotation if needed
    rotation = None
    if save_parallel_transport:
        rotation = _parallel_transport_rotation(
            mesh, starts, dirs, end_dirs=results[1], end_normals=results[2]
        )

    end_dir = results[1] if save_end_direction else None

    geodesic_info = GeodesicInfo(
        debug_tensor=debug_tensor, rotation=rotation, end_directions=end_dir
    )
    return mesh_points, geodesic_info


def _trace_geodesics_python(
    mesh: Mesh,
    starts: MeshPointBatch,
    dirs: Tensor,
    max_steps: int,
    debug: bool,
    print_warnings: bool,
    eps: float,
    avoid_holes: bool,
) -> Tuple[tuple, Optional[Tensor]]:
    """
    Traces every geodesic one by one with the Python implementation.

    Returns ``(end points, end directions, end normals)`` and, when ``debug``
    is set, a debug tensor of shape [batch, 3, max_steps + 1, 3] laid out as
    GeodesicInfo expects (otherwise ``None``).
    """
    batch = len(dirs)
    meshpoints = []
    directions = torch.empty(batch, 3, dtype=dirs.dtype, device=mesh.device)
    normals = torch.empty(batch, 3, dtype=dirs.dtype, device=mesh.device)
    debug_tensor = None
    if debug:
        debug_tensor = torch.empty(
            (batch, 3, max_steps + 1, 3),
            dtype=dirs.dtype,
            device=mesh.device,
        )

    for i, (start, start_dir) in enumerate(zip(starts, dirs)):
        result = straightest_geodesic(
            mesh,
            start.detach(),
            start_dir.detach(),
            max_steps=max_steps,
            debug_path=debug,
            print_warnings=print_warnings,
            eps=eps,
            avoid_holes=avoid_holes,
        )
        meshpoints.append(result[0])
        directions[i] = result[1]
        normals[i] = result[2]
        if debug:
            path, direction_path, normal_path, vertex_crossings = result[3]
            # Step 0 of channel 0 is the header; samples start at step 1.
            debug_tensor[i, 0, 0, 0] = len(path)
            debug_tensor[i, 0, 0, 1] = vertex_crossings
            debug_tensor[i, 0, 1 : len(path) + 1, :] = torch.tensor(path)
            debug_tensor[i, 1, 1 : len(direction_path) + 1, :] = torch.tensor(
                direction_path
            )
            debug_tensor[i, 2, 1 : len(normal_path) + 1, :] = torch.tensor(
                normal_path
            )

    end_points = MeshPointBatch.from_points(meshpoints).to(device=mesh.device)
    return (end_points, directions, normals), debug_tensor


def _parallel_transport_rotation(
    mesh: Mesh,
    starts: MeshPointBatch,
    dirs: Tensor,
    end_dirs: Tensor,
    end_normals: Tensor,
) -> Tensor:
    """
    Builds the rotation taking the start frame (direction, normal,
    normal x direction) of every geodesic onto its end frame.

    Rows whose projected start direction is shorter than 1e-6 get the
    identity.
    """
    start_normals = mesh.triangle_normals[starts.faces]
    start_dirs_unit = (
        dirs - torch.sum(dirs * start_normals, dim=1, keepdim=True) * start_normals
    )
    start_dirs_norm = start_dirs_unit.norm(dim=1, keepdim=True)
    idle_idx = start_dirs_norm[:, 0] < 1e-6
    start_dirs_unit = start_dirs_unit / start_dirs_norm.clamp(min=1e-6)
    end_dirs_unit = end_dirs / torch.linalg.norm(
        end_dirs, dim=1, keepdim=True
    ).clamp(min=1e-6)

    start_cross = torch.cross(start_normals, start_dirs_unit, dim=1)
    end_cross = torch.cross(end_normals, end_dirs_unit, dim=1)

    # create matrix from orthonormal basis
    start_M = torch.stack([start_dirs_unit, start_normals, start_cross], dim=1)
    end_M = torch.stack([end_dirs_unit, end_normals, end_cross], dim=1)

    # rotation matrix from basis to basis
    rotation = torch.bmm(start_M.transpose(1, 2), end_M).detach()
    # Match the dtype explicitly (as get_tensors_from_path does) so the
    # assignment also works for double-precision traces.
    rotation[idle_idx] = torch.eye(
        3, device=rotation.device, dtype=rotation.dtype
    ).unsqueeze(0)
    return rotation