Блог компании
⚡ HyperTwistSGD: оптимизатор с автоматической устойчивостью

Аннотация
Мы предлагаем новый метод обучения нейросетей — HyperTwistSGD,
основанный на строгой теореме устойчивости HyperTwist.
Метод вводит метрику F′ и использует её для автоматической подстройки learning rate,
устраняя необходимость в ручных расписаниях и обеспечивая устойчивость обучения.
1. Проблема
Современные оптимизаторы (SGD
, Adam
, OneCycleLR
и др.):
- требуют ручного подбора learning rate,
- зависят от эмпирических расписаний,
- не имеют строгого критерия устойчивости.
Это делает процесс обучения «искусством тюнинга», а не строгой наукой.
2. Метод HyperTwistSGD
Идея:
вместо эвристик использовать физико-математический критерий устойчивости.
На каждом шаге вычисляется:
- r=∥∇θ∥ — норма градиента,
- V=∥Δθ∥ — шаг параметров,
- F=rV2 ,
- F′=drdF .
Условие устойчивости: F′≥0 .
Если F′ выходит за пределы коридора [Fmin,Fmax] , learning rate корректируется PID-контуром.
3. Реализация (PyTorch)
# twistcoder.py
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
# ---------- Метрика HyperTwist ----------
class HyperTwistMeter:
def __init__(self, writer: SummaryWriter=None, tag="ht"):
self.w, self.tag = writer, tag
self.prev = None
self.theta_before = None
self.step = 0
@torch.no_grad()
def _flat(self, model):
return torch.cat([p.data.view(-1) for p in model.parameters() if p.requires_grad])
@torch.no_grad()
def _grad_norm(self, model):
s = 0.0
for p in model.parameters():
if p.grad is not None:
s += p.grad.detach().pow(2).sum().item()
return s**0.5
@torch.no_grad()
def begin(self, model):
self.theta_before = self._flat(model)
@torch.no_grad()
def end(self, model):
if self.theta_before is None:
return None
theta_after = self._flat(model)
r = self._grad_norm(model)
V = (theta_after - self.theta_before).norm().item()
F = r * (V**2)
Fp = float("nan")
if self.prev is not None:
r0, F0 = self.prev
dr = r - r0
if abs(dr) > 1e-12:
Fp = (F - F0) / dr
if self.w is not None:
self.w.add_scalar(f"{self.tag}/r_grad", r, self.step)
self.w.add_scalar(f"{self.tag}/V_step", V, self.step)
self.w.add_scalar(f"{self.tag}/F", F, self.step)
if np.isfinite(Fp):
self.w.add_scalar(f"{self.tag}/Fprime", Fp, self.step)
self.w.flush()
self.prev = (r, F)
self.theta_before = None
self.step += 1
return {"r": r, "V": V, "F": F, "Fprime": Fp, "step": self.step-1}
# ---------- Оптимизатор с контролем по F' ----------
class HyperTwistSGD(torch.optim.SGD):
def __init__(self, params, lr=1e-1, momentum=0.0,
F_min=1e-3, F_max=4e-3, alpha=0.9,
k_p=0.5, k_i=0.05, k_d=0.0,
low=1e-5, high=3e-1):
super().__init__(params, lr=lr, momentum=momentum)
self.F_min, self.F_max, self.alpha = F_min, F_max, alpha
self.F_target = 0.5*(F_min+F_max)
self.k_p, self.k_i, self.k_d = k_p, k_i, k_d
self.e_prev, self.e_int = 0.0, 0.0
self.e_int_cap = 5.0
self.low, self.high = low, high
self.Fp_sm = None
@torch.no_grad()
def adapt(self, Fprime: float) -> float:
if not np.isfinite(Fprime):
return self.param_groups[0]["lr"]
self.Fp_sm = Fprime if self.Fp_sm is None else self.alpha*self.Fp_sm + (1-self.alpha)*Fprime
if self.F_min <= self.Fp_sm <= self.F_max:
return self.param_groups[0]["lr"]
e = self.Fp_sm - self.F_target
self.e_int = np.clip(0.99*self.e_int + e, -self.e_int_cap, self.e_int_cap)
de = e - self.e_prev
gain = (self.k_p*e + self.k_i*self.e_int + self.k_d*de)
scale = float(np.exp(-gain))
for g in self.param_groups:
g["lr"] = float(np.clip(g["lr"]*scale, self.low, self.high))
self.e_prev = e
return self.param_groups[0]["lr"]
# ---------- Демонстрация ----------
def main():
torch.manual_seed(42)
X = torch.randn(1000, 10)
y = (X.sum(dim=1, keepdim=True) > 0).float()
loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)
model = nn.Sequential(nn.Linear(10, 1), nn.Sigmoid())
loss_fn = nn.BCELoss()
opt = HyperTwistSGD(model.parameters(), lr=1e-1)
writer = SummaryWriter()
ht = HyperTwistMeter(writer, tag="demo")
epochs = 5
for epoch in range(epochs):
for xb, yb in loader:
opt.zero_grad()
loss = loss_fn(model(xb), yb)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
ht.begin(model)
opt.step()
stats = ht.end(model)
if stats and np.isfinite(stats["Fprime"]):
lr_now = opt.adapt(stats["Fprime"])
print(f"step={stats['step']:4d} loss={loss.item():.4f} "
f"F'={stats['Fprime']:+.3e} F'sm={opt.Fp_sm:+.3e} lr={lr_now:.3e}")
print(f"epoch {epoch+1}, last_loss={loss.item():.4f}")
if __name__ == "__main__":
main()