# Source code for digeo.optim.mesh_lbfgs

import torch
import time
from tqdm import tqdm
from typing import Tuple, Any

from digeo import Mesh, MeshPointBatch
from digeo.optim.utils import MeshLossFunc, line_search


def mesh_lbfgs(
    mesh: Mesh,
    x: MeshPointBatch,
    loss_func: MeshLossFunc,
    max_iter: int = 100,
    list_size: int = 20,
    lr: float = 1.0,
    tol: float = 1e-6,
    min_rel_improvement: float = 1e-5,
    patience: int = 3,
    eps: float = 1e-6,
) -> Tuple[MeshPointBatch, dict[str, Any]]:
    """
    Mesh-LBFGS optimiser with early stopping.

    Each iteration builds a search direction with :func:`desc` from the
    stored curvature pairs, runs ``line_search`` along it, and records a
    new curvature pair expressed through ``info.rotation`` as returned by
    the line search.

    The loop stops early when any of the following holds:

    - the norm of ``grad`` drops below ``tol``;
    - the mean per-point norm of the step ``alpha * direction`` drops
      below ``tol``;
    - the relative loss change stays below ``min_rel_improvement`` for
      ``patience`` consecutive iterations.

    Parameters
    ----------
    mesh: Mesh
        The mesh used for the optimisation.
    x: MeshPointBatch
        The initial points on the mesh to optimize.
    loss_func: MeshLossFunc
        The loss function to optimize. It is called as
        ``loss_func(mesh, x)`` and must return ``(loss, grad)``; its
        ``get_logs()`` output is merged into the returned logs.
    max_iter: int
        The maximum number of iterations (default: 100).
    list_size: int
        The size of the history list for LBFGS (default: 20).
    lr: float
        The initial learning rate for the line search (default: 1.0).
    tol: float
        The tolerance for the stopping criteria based on the gradient norm
        and on the mean step norm (default: 1e-6).
    min_rel_improvement: float
        The minimum relative improvement in the loss for early stopping
        (default: 1e-5).
    patience: int
        The number of consecutive iterations without sufficient relative
        improvement after which the optimisation stops (default: 3).
    eps: float
        A small constant to prevent division by zero in the LBFGS update
        (default: 1e-6).

    Returns
    -------
    x': MeshPointBatch
        The optimized points on the mesh. When the loop stops on the step
        norm or patience criterion, this is the iterate from *before* the
        last line search.
    logs: dict[str, Any]
        A dictionary containing the information to log. It starts from
        ``loss_func.get_logs()`` and adds (overriding on key clashes):

        - ``"loss"``: The loss value returned by the line search at each
          iteration.
        - ``"time"`` (float): CPU time (``time.process_time``) spent in the
          optimisation loop; the initial loss evaluation is not included.
        - ``"function_calls"`` (List[int]): The total number of function
          calls at each iteration.
        - ``"step_size"``: The step size at each iteration.
        - ``"mean_dir"`` (List[float]): The mean step norm
          ``|alpha * direction|`` at each iteration.
    """
    loss, grad = loss_func(mesh, x)

    H_diag = 1.0
    s_list, y_list, sy_list, rot_list = [], [], [], []

    total_loss = []
    step_size = []
    mean_dir = []
    function_calls = []
    total_function_calls = 1  # the initial loss_func evaluation above

    t0 = time.process_time()
    # Start at 0 so the very first stall window also lasts exactly
    # ``patience`` iterations, like every window after a reset.
    no_improve_counter = 0

    # The context manager guarantees the bar is closed on every early exit.
    with tqdm(range(max_iter)) as progress_bar:
        for _ in progress_bar:
            grad_norm = grad.norm().item()
            if grad_norm < tol:
                print(f"[Early Stopping] Gradient norm {grad_norm:.2e} < tol={tol}")
                break

            direction = desc(
                grad, len(s_list) - 1, s_list, y_list, sy_list, rot_list, H_diag, eps
            )

            x_new, grad_new, loss_new, alpha, c, info = line_search(
                mesh=mesh,
                x=x,
                loss_func=loss_func,
                direction=direction,
                curr_loss=loss,
                curr_grad=grad,
                alpha_init=lr,
                use_wolfe=True,
            )

            total_function_calls += c
            function_calls.append(total_function_calls)
            step_size.append(alpha)
            total_loss.append(loss_new)

            with torch.no_grad():
                step = alpha * direction
            mean_dir.append(step.norm(dim=1).mean().item())
            if mean_dir[-1] < tol:
                print(
                    f"[Early Stopping] Mean direction norm {mean_dir[-1]:.2e} < tol={tol}"
                )
                break

            with torch.no_grad():
                # Map the step and the old gradient through the rotation
                # reported by the line search so they can be combined with
                # grad_new.
                s = torch.bmm(step.unsqueeze(1), info.rotation).squeeze(1)
                transported_grad = torch.bmm(
                    grad.unsqueeze(1), info.rotation
                ).squeeze(1)
                # NOTE(review): the sign flip suggests loss_func returns a
                # descent direction (negative gradient) - confirm.
                y = -(grad_new - transported_grad)

                s_list.append(s)
                y_list.append(y)
                sy_list.append(torch.sum(s * y))
                rot_list.append(info.rotation)

                # Keep at most ``list_size`` curvature pairs.
                if len(s_list) > list_size:
                    del s_list[0], y_list[0], sy_list[0], rot_list[0]

                # Scaling of the initial inverse-Hessian approximation.
                H_diag = sy_list[-1] / (torch.sum(y * y) + eps)

            # Early stopping based on relative loss improvement
            rel_improvement = abs(loss_new - loss) / (abs(loss_new) + 1e-12)
            if rel_improvement < min_rel_improvement:
                no_improve_counter += 1
                if no_improve_counter >= patience:
                    print(
                        f"[Early Stopping] No improvement for {patience} steps "
                        f"(rel_impr < {min_rel_improvement})"
                    )
                    break
            else:
                no_improve_counter = 0

            x = x_new
            grad = grad_new
            loss = loss_new
            progress_bar.set_description(f"Loss: {loss:.6f}")

    logs = {
        **loss_func.get_logs(),
        "loss": total_loss,
        "time": time.process_time() - t0,
        "function_calls": function_calls,
        "step_size": step_size,
        "mean_dir": mean_dir,
    }
    return x, logs
@torch.no_grad()
def desc(p, t, s_list, y_list, sy_list, rot_list, H_diag, eps=1e-6):
    """
    Apply the L-BFGS inverse-Hessian approximation to ``p`` recursively.

    Level ``t`` applies the BFGS inverse update built from the curvature
    pair ``(s_list[t], y_list[t])`` on top of the approximation of level
    ``t - 1``::

        H_t p = V^T H_{t-1} V p + rho * (s . p) * s,
        V p   = p - rho * (s . p) * y,      rho = 1 / (s . y + eps)

    Before recursing, the intermediate vector is multiplied by the
    transpose of ``rot_list[t]``; the result of the recursion is multiplied
    by ``rot_list[t]`` again. All inner products are sums over every
    element of the batch.

    Parameters
    ----------
    p: torch.Tensor
        Batch of vectors to transform; used as a 2-D tensor
        (it is expanded with ``unsqueeze(1)`` for ``torch.bmm``).
    t: int
        Index of the most recent curvature pair to use. A negative value
        ends the recursion and returns ``H_diag * p``.
    s_list, y_list: list of torch.Tensor
        Stored step / gradient-difference pairs, same shape as ``p``.
    sy_list: list
        Stored inner products ``sum(s * y)`` for each pair.
    rot_list: list of torch.Tensor
        Batched matrices (3-D tensors) applied with ``torch.bmm``.
    H_diag: float or torch.Tensor
        Scaling used as the initial inverse-Hessian approximation.
    eps: float
        Small constant added to ``s . y`` to avoid division by zero
        (default: 1e-6).

    Returns
    -------
    torch.Tensor
        The transformed batch of vectors, same shape as ``p``.
    """
    if t < 0:
        return H_diag * p

    s, y, rot = s_list[t], y_list[t], rot_list[t]
    denom = sy_list[t] + eps

    # rho * (s . p): shared by the first projection and the rank-one term.
    sp_coeff = torch.sum(s * p) / denom

    # q = V p
    q = p - sp_coeff * y

    # Apply rot^T, recurse on the older pairs, then apply rot again.
    q_prev = torch.bmm(q.unsqueeze(1), rot.transpose(1, 2)).squeeze(1)
    r = desc(q_prev, t - 1, s_list, y_list, sy_list, rot_list, H_diag, eps)
    r = torch.bmm(r.unsqueeze(1), rot).squeeze(1)

    # V^T r + rho * s s^T p. The rank-one term must be (s . p) * s so that
    # the secant condition H y = s holds for the newest pair.
    return r - (torch.sum(y * r) / denom) * s + sp_coeff * s