Блог компании

⚡ HyperTwistSGD: оптимизатор с автоматической устойчивостью

⚡ HyperTwistSGD: оптимизатор с автоматической устойчивостью

Аннотация

Мы предлагаем новый метод обучения нейросетей — HyperTwistSGD,
основанный на строгой теореме устойчивости HyperTwist.
Метод вводит метрику F и использует её для автоматической подстройки learning rate,
устраняя необходимость в ручных расписаниях и обеспечивая устойчивость обучения.


1. Проблема

Современные оптимизаторы (SGD, Adam, OneCycleLR и др.):

  • требуют ручного подбора learning rate,
  • зависят от эмпирических расписаний,
  • не имеют строгого критерия устойчивости.

Это делает процесс обучения «искусством тюнинга», а не строгой наукой.


2. Метод HyperTwistSGD

Идея:
вместо эвристик использовать физико-математический критерий устойчивости.

  • На каждом шаге вычисляется:

    • r=∥∇θ — норма градиента,
    • V=∥Δθ — шаг параметров,
    • F=rV2 ,
    • F=drdF .
  • Условие устойчивости: F0 .

  • Если F выходит за пределы коридора [Fmin,Fmax] , learning rate корректируется PID-контуром.


3. Реализация (PyTorch)

# twistcoder.py
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter

# ---------- Метрика HyperTwist ----------
class HyperTwistMeter:
    def __init__(self, writer: SummaryWriter=None, tag="ht"):
        self.w, self.tag = writer, tag
        self.prev = None
        self.theta_before = None
        self.step = 0

    @torch.no_grad()
    def _flat(self, model):
        return torch.cat([p.data.view(-1) for p in model.parameters() if p.requires_grad])

    @torch.no_grad()
    def _grad_norm(self, model):
        s = 0.0
        for p in model.parameters():
            if p.grad is not None:
                s += p.grad.detach().pow(2).sum().item()
        return s**0.5

    @torch.no_grad()
    def begin(self, model):
        self.theta_before = self._flat(model)

    @torch.no_grad()
    def end(self, model):
        if self.theta_before is None:
            return None
        theta_after = self._flat(model)
        r = self._grad_norm(model)
        V = (theta_after - self.theta_before).norm().item()
        F = r * (V**2)

        Fp = float("nan")
        if self.prev is not None:
            r0, F0 = self.prev
            dr = r - r0
            if abs(dr) > 1e-12:
                Fp = (F - F0) / dr

        if self.w is not None:
            self.w.add_scalar(f"{self.tag}/r_grad", r, self.step)
            self.w.add_scalar(f"{self.tag}/V_step", V, self.step)
            self.w.add_scalar(f"{self.tag}/F", F, self.step)
            if np.isfinite(Fp):
                self.w.add_scalar(f"{self.tag}/Fprime", Fp, self.step)
            self.w.flush()

        self.prev = (r, F)
        self.theta_before = None
        self.step += 1
        return {"r": r, "V": V, "F": F, "Fprime": Fp, "step": self.step-1}

# ---------- Оптимизатор с контролем по F' ----------
class HyperTwistSGD(torch.optim.SGD):
    def __init__(self, params, lr=1e-1, momentum=0.0,
                 F_min=1e-3, F_max=4e-3, alpha=0.9,
                 k_p=0.5, k_i=0.05, k_d=0.0,
                 low=1e-5, high=3e-1):
        super().__init__(params, lr=lr, momentum=momentum)
        self.F_min, self.F_max, self.alpha = F_min, F_max, alpha
        self.F_target = 0.5*(F_min+F_max)
        self.k_p, self.k_i, self.k_d = k_p, k_i, k_d
        self.e_prev, self.e_int = 0.0, 0.0
        self.e_int_cap = 5.0
        self.low, self.high = low, high
        self.Fp_sm = None

    @torch.no_grad()
    def adapt(self, Fprime: float) -> float:
        if not np.isfinite(Fprime):
            return self.param_groups[0]["lr"]
        self.Fp_sm = Fprime if self.Fp_sm is None else self.alpha*self.Fp_sm + (1-self.alpha)*Fprime
        if self.F_min <= self.Fp_sm <= self.F_max:
            return self.param_groups[0]["lr"]

        e = self.Fp_sm - self.F_target
        self.e_int = np.clip(0.99*self.e_int + e, -self.e_int_cap, self.e_int_cap)
        de = e - self.e_prev
        gain = (self.k_p*e + self.k_i*self.e_int + self.k_d*de)
        scale = float(np.exp(-gain))
        for g in self.param_groups:
            g["lr"] = float(np.clip(g["lr"]*scale, self.low, self.high))
        self.e_prev = e
        return self.param_groups[0]["lr"]

# ---------- Демонстрация ----------
def main():
    torch.manual_seed(42)
    X = torch.randn(1000, 10)
    y = (X.sum(dim=1, keepdim=True) > 0).float()
    loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

    model = nn.Sequential(nn.Linear(10, 1), nn.Sigmoid())
    loss_fn = nn.BCELoss()

    opt = HyperTwistSGD(model.parameters(), lr=1e-1)
    writer = SummaryWriter()
    ht = HyperTwistMeter(writer, tag="demo")

    epochs = 5
    for epoch in range(epochs):
        for xb, yb in loader:
            opt.zero_grad()
            loss = loss_fn(model(xb), yb)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            ht.begin(model)
            opt.step()
            stats = ht.end(model)

            if stats and np.isfinite(stats["Fprime"]):
                lr_now = opt.adapt(stats["Fprime"])
                print(f"step={stats['step']:4d}  loss={loss.item():.4f}  "
                      f"F'={stats['Fprime']:+.3e}  F'sm={opt.Fp_sm:+.3e}  lr={lr_now:.3e}")
        print(f"epoch {epoch+1}, last_loss={loss.item():.4f}")

if __name__ == "__main__":
    main()

Наука