Source code for digeo.optim.utils

import torch
from abc import ABC, abstractmethod
from typing import Any, Tuple
from torch import Tensor

from digeo import Mesh, MeshPointBatch
from digeo.ops import trace_geodesics
from digeo.ops.geodesic import GeodesicInfo


class MeshLossFunc(ABC):
    """
    Abstract base for losses driven by the mesh optimisers.

    Subclasses implement :meth:`compute`. Calling an instance directly is a
    shorthand that forwards to :meth:`compute`.
    """

    @abstractmethod
    def compute(self, mesh: Mesh, points: MeshPointBatch) -> Tuple[float, Tensor]:
        """
        Evaluate the loss used by the mesh optimisers.

        Parameters
        ----------
        mesh : Mesh
            The mesh the points live on
        points: MeshPointBatch
            The mesh points currently being optimised

        Returns
        -------
        loss: float
            The value of the loss
        gradient: Tensor
            The 3D-space gradient associated with each point
        """

    def __call__(self, mesh: Mesh, points: MeshPointBatch) -> Tuple[float, Tensor]:
        """Evaluate the loss; equivalent to ``self.compute(mesh, points)``."""
        result = self.compute(mesh, points)
        return result

    def get_logs(self) -> dict[str, Any]:
        """
        Optional hook exposing extra values to log during optimisation.

        Subclasses can override this to report, for example, the value of
        individual loss terms. The base implementation reports nothing.

        Returns
        -------
        logs: dict[str, Any]
            Mapping from metric name to the value that should be logged.
        """
        return dict()
@torch.no_grad()
def line_search(
    mesh: Mesh,
    x: MeshPointBatch,
    loss_func: MeshLossFunc,
    direction: Tensor,
    curr_loss: float,
    curr_grad: Tensor,
    use_wolfe=True,
    c1=1e-4,
    c2=0.9,
    alpha_init=1.0,
    tau=0.1,
    max_iter=3,
    expand=2.0,
) -> Tuple[MeshPointBatch, Tensor, float, float, int, GeodesicInfo]:
    """
    Geodesic line search satisfying the Armijo and (weak) Wolfe curvature conditions.

    Each trial traces ``alpha * direction`` from ``x`` along geodesics on the
    mesh and evaluates ``loss_func`` at the resulting points.

    - A trial that violates the Armijo (sufficient decrease) condition is too
      long, so the step is shrunk.
    - A trial that satisfies Armijo but violates the curvature condition is
      too short, so the step is enlarged.
    - Once both a too-short and a too-long step are known, the next step is
      the midpoint between them.

    The directional derivative is computed as ``sum(-grad * direction)``,
    i.e. gradients are used with a negated sign. NOTE(review): this implies
    ``loss_func`` returns the descent direction (negated loss gradient);
    confirm against the MeshLossFunc implementations.

    Parameters
    ----------
    mesh: Mesh
        The mesh used
    x: MeshPointBatch
        The current points on the mesh
    loss_func: MeshLossFunc
        The loss function
    direction: Tensor
        The direction used for the line search to optimize x
    curr_loss: float
        The current loss for x
    curr_grad: Tensor
        The current gradients for x for loss_func
    use_wolfe: bool
        If the Wolfe curvature condition should be checked in addition to
        the Armijo condition (default: True)
    c1: float
        Armijo (sufficient decrease) constant (default: 1e-4)
    c2: float
        Curvature condition constant (default: 0.9)
    alpha_init: float
        The initial step size (default: 1.0)
    tau: float
        The backtracking reduction factor applied while no too-short step is
        known (default: 0.1)
    max_iter: int
        The maximum number of trial steps, must be >= 1 (default: 3)
    expand: float
        Growth factor applied when the curvature condition fails and no
        too-long step is known yet (default: 2.0)

    Returns
    -------
    x': MeshPointBatch
        The new meshpoints
    gradient': Tensor
        The gradients for x'
    loss: float
        The loss of the loss function for x'
    alpha: float
        The step size that produced x'. If all ``max_iter`` trials are used
        without meeting every condition, this is the step of the most recent
        trial that satisfied the Armijo condition, or of the last trial if
        none did.
    function_calls: int
        The number of times the loss function was called
    info: GeodesicInfo
        The geodesic info of the trace from x to x'

    Raises
    ------
    ValueError
        If ``max_iter`` is smaller than 1 (no trial step would be produced).
    """
    if max_iter < 1:
        raise ValueError(f"max_iter must be at least 1, got {max_iter}")

    f_x = curr_loss
    # Clamped to <= 0: for a non-descent direction, the Armijo test then only
    # requires that the loss does not increase.
    directional_derivative = min(torch.sum(-curr_grad * direction).item(), 0.0)

    alpha = alpha_init
    lo = None  # largest step known to be too short (curvature failed)
    hi = None  # smallest step known to be too long (Armijo failed)
    function_calls = 0
    last_trial = None
    accepted_trial = None  # most recent trial that satisfied Armijo

    for _ in range(max_iter):
        x_new, info = trace_geodesics(
            mesh,
            x,
            alpha * direction,
            gradient="none",
            save_parallel_transport=True,
            max_steps=10000,
            print_warnings=False,
        )
        loss, new_grad = loss_func(mesh, x_new)
        function_calls += 1
        # Record the step together with its results so the returned alpha
        # always matches the returned points.
        last_trial = (x_new, new_grad, loss, alpha, info)

        # Armijo condition failed: the step is too long.
        if loss > f_x + c1 * alpha * directional_derivative:
            hi = alpha
            alpha = tau * alpha if lo is None else 0.5 * (lo + hi)
            continue

        accepted_trial = last_trial
        if not use_wolfe:
            break

        # Move `direction` to x_new using the rotation recorded by
        # trace_geodesics (requested via save_parallel_transport=True).
        transported_dir = torch.bmm(direction.unsqueeze(1), info.rotation).squeeze(1)
        new_derivative = torch.sum(-new_grad * transported_dir).item()

        # Curvature condition satisfied: accept this step.
        if new_derivative >= c2 * directional_derivative:
            break

        # Curvature condition failed: the step is too short, so grow it.
        lo = alpha
        alpha = expand * alpha if hi is None else 0.5 * (lo + hi)

    x_new, new_grad, loss, alpha, info = (
        accepted_trial if accepted_trial is not None else last_trial
    )
    return x_new, new_grad, loss, alpha, function_calls, info