Source code for digeo.optim.mesh_gd

import time
from tqdm import tqdm
from typing import Tuple

from digeo import Mesh, MeshPointBatch
from digeo.ops import trace_geodesics
from digeo.optim.utils import MeshLossFunc, line_search


def mesh_gd(
    mesh: Mesh,
    x: MeshPointBatch,
    loss_func: MeshLossFunc,
    use_line_search: bool = True,
    max_iter: int = 100,
    lr: float = 1.0,
    tol: float = 1e-6,
    min_rel_improvement: float = 1e-5,
    patience: int = 3,
) -> Tuple[MeshPointBatch, dict]:
    """
    Gradient descent optimiser with line search and early stopping.

    Parameters
    ----------
    mesh: Mesh
        The mesh used for the optimisation.
    x: MeshPointBatch
        The initial points on the mesh to optimize.
    loss_func: MeshLossFunc
        The loss function to optimize. It is called as ``loss_func(mesh, x)``
        and must return ``(loss, grad)``; it must also provide ``get_logs()``.
    use_line_search: bool
        Whether to use line search to find the step size at each iteration
        (default: True). When False, a fixed step of ``lr`` is used.
    max_iter: int
        The maximum number of iterations (default: 100).
    lr: float
        The initial step size for the line search, or the fixed step size
        when ``use_line_search`` is False (default: 1.0).
    tol: float
        The tolerance for the gradient-based stopping criterion: iteration
        stops when ``step_size * mean(per-point gradient norm)`` falls below
        this value (default: 1e-6).
    min_rel_improvement: float
        The minimum relative change in the loss between two consecutive
        iterations that counts as an improvement (default: 1e-5).
    patience: int
        The number of consecutive iterations without improvement after which
        the optimisation stops (default: 3).

    Returns
    -------
    x': MeshPointBatch
        The optimized points on the mesh.
    logs: dict[str, Any]
        A dictionary containing the information to log, where the keys are
        the names of the metrics and the values are the corresponding values
        to log. Entries returned by ``loss_func.get_logs()`` are included;
        on a key collision the entries below take precedence:

        - ``"loss"``: The loss value after each iteration.
        - ``"time"`` (float): The CPU time (``time.process_time``) spent in
          the iteration loop; the initial loss evaluation is not included.
        - ``"function_calls"`` (List[int]): The cumulative number of function
          calls after each iteration (the initial evaluation counts as one).
        - ``"step_size"``: The step size used at each iteration.
        - ``"mean_dir"`` (List[float]): The mean per-point norm of
          ``step_size * grad`` at each iteration, where ``grad`` is the
          gradient returned for the updated points.
    """
    # NOTE(review): the value returned by loss_func as ``grad`` is used
    # directly as the movement direction (both for line_search and for
    # trace_geodesics); presumably it is already a descent direction --
    # confirm against MeshLossFunc.
    loss, grad = loss_func(mesh, x)
    total_function_calls = 1
    function_calls = []

    progress_bar = tqdm(range(max_iter))
    total_loss = []
    total_step_size = []
    total_mean_dir = []

    # Timer starts after the initial evaluation, so "time" covers the loop only.
    t0 = time.process_time()
    no_improve_counter = 0

    for _ in progress_bar:
        if use_line_search:
            x, grad, loss_new, alpha, c, _ = line_search(
                mesh=mesh,
                x=x,
                loss_func=loss_func,
                direction=grad,
                curr_loss=loss,
                curr_grad=grad,
                alpha_init=lr,
                use_wolfe=False,
            )
            # ``c`` is added to the running call count; presumably it is the
            # number of loss evaluations the line search performed.
            total_function_calls += c
        else:
            alpha = lr
            x, _ = trace_geodesics(mesh, x, alpha * grad, gradient="none")
            loss_new, grad = loss_func(mesh, x)
            total_function_calls += 1

        total_step_size.append(alpha)
        total_mean_dir.append((alpha * grad).norm(dim=1).mean().item())
        total_loss.append(loss_new)
        function_calls.append(total_function_calls)

        progress_bar.set_description(f"Loss: {loss_new:.6f}")

        # Early stopping based on relative loss improvement. The small
        # constant guards against division by zero when the loss is 0.
        rel_improvement = abs(loss_new - loss) / (abs(loss_new) + 1e-12)
        if rel_improvement < min_rel_improvement:
            no_improve_counter += 1
            if no_improve_counter >= patience:
                print(
                    f"[Early Stopping] No improvement for {patience} steps "
                    f"(rel_impr < {min_rel_improvement})"
                )
                break
        else:
            no_improve_counter = 0

        # Must be refreshed every iteration: it is passed to line_search as
        # the loss at the current ``x``.
        loss = loss_new

        # Compute the stopping quantity once so that the value reported in
        # the message is exactly the one compared against ``tol`` (previously
        # the message printed the norm over the whole gradient tensor, which
        # differs from the mean per-point norm that is actually tested).
        step_norm = alpha * grad.norm(dim=-1).mean()
        if step_norm < tol:
            print(
                f"[Early Stopping] Gradient norm {step_norm:.2e} < tol={tol}"
            )
            break

    # Loss-function logs go first so the standard keys below win on collision.
    logs = {
        **loss_func.get_logs(),
        "loss": total_loss,
        "function_calls": function_calls,
        "time": time.process_time() - t0,
        "step_size": total_step_size,
        "mean_dir": total_mean_dir,
    }

    return x, logs