My personal blog

Machine learning, computer vision, languages

Riemannian SGD in PyTorch

23 Jul 2020

Many recent papers work with spaces other than the familiar Euclidean space, a trend sometimes called geometric deep learning. Interest is growing particularly in the domains of word embeddings and graphs.

Since these geometric neural networks optimize parameters that live on a manifold rather than in a flat vector space, plain stochastic gradient descent cannot be applied directly.

The following two equations show what changes are necessary:

\[\begin{aligned} \text{SGD: } \theta_{t+1} &\gets \theta_t - \lambda \nabla \mathcal{L}\\ \text{RSGD: } \theta_{t+1} &\gets \exp_{\theta_t}\left(- \lambda \nabla_R \mathcal{L}\right) \end{aligned}\]

where \(\exp_{\theta_t} : \mathcal{T}_{\theta_t}M \to M\) is the exponential map. It maps a small change along the tangent vector \(v \in \mathcal{T}_{\theta_t}M\) to a point on the manifold \(M\), and \(\lambda\) is the learning rate.

\(\nabla_R\) is the Riemannian gradient, given by \(g_{\theta_t}^{-1} \nabla \mathcal{L}\) where \(g_{\theta_t}\) is the metric tensor. This gradient is also called the natural gradient. A derivation can be found in [1].
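Written as code, one step of RSGD could look like the following minimal sketch. The functions metric_inverse and exp_map are hypothetical placeholders for the manifold-specific pieces implemented further below; for the diagonal metrics used in this post the inverse metric acts as a simple elementwise rescaling.

import torch

def rsgd_step(theta: torch.Tensor, euclidean_grad: torch.Tensor, lr: float, metric_inverse, exp_map):
    # Riemannian gradient: the Euclidean gradient rescaled by the inverse metric tensor
    riemannian_grad = metric_inverse(theta) * euclidean_grad
    # move along the manifold via the exponential map instead of taking a straight line
    return exp_map(theta, -lr * riemannian_grad)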

In Euclidean space there is essentially one model. Hyperbolic space, on the other hand, can be represented by several equivalent models; the two used below are the Poincaré unit ball and the hyperboloid (Lorentz model).

For elliptic geometry see the paper [2].

Implementation

Poincaré unit ball

The following code also contains the exact exponential map, but I commented those lines out because, empirically, the first-order approximation (a simple retraction) produces slightly better results. Refer to [4]; the authors ran more extensive tests.

import torch
from torch.optim.optimizer import Optimizer

@torch.jit.script
def lambda_x(x: torch.Tensor):
    # conformal factor of the Poincaré ball metric at x
    return 2 / (1 - torch.sum(x ** 2, dim=-1, keepdim=True))

@torch.jit.script
def mobius_add(x: torch.Tensor, y: torch.Tensor):
    # Möbius addition x ⊕ y on the Poincaré unit ball
    x2 = torch.sum(x ** 2, dim=-1, keepdim=True)
    y2 = torch.sum(y ** 2, dim=-1, keepdim=True)
    xy = torch.sum(x * y, dim=-1, keepdim=True)

    num = (1 + 2 * xy + y2) * x + (1 - x2) * y
    denom = 1 + 2 * xy + x2 * y2

    return num / denom.clamp_min(1e-15)

@torch.jit.script
def expm(p: torch.Tensor, u: torch.Tensor):
    # first-order approximation (retraction) of the exponential map
    return p + u
    # for exact exponential mapping
    #norm = torch.sqrt(torch.sum(u ** 2, dim=-1, keepdim=True))
    #return mobius_add(p, torch.tanh(0.5 * lambda_x(p) * norm) * u / norm.clamp_min(1e-15))

@torch.jit.script
def grad(p: torch.Tensor):
    # Riemannian gradient: rescale the Euclidean gradient by the inverse metric ((1 - ||p||^2)^2 / 4)
    p_sqnorm = torch.sum(p.data ** 2, dim=-1, keepdim=True)
    return p.grad.data * ((1 - p_sqnorm) ** 2 / 4).expand_as(p.grad.data)

class RiemannianSGD(Optimizer):
    def __init__(self, params):
        super(RiemannianSGD, self).__init__(params, {})

    def step(self, lr=0.3):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                d_p = grad(p)
                d_p.mul_(-lr)

                p.data.copy_(expm(p.data, d_p))
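
As a usage sketch (not from the original post), the optimizer can be plugged into a toy embedding as follows; the embedding size, the loss, and the learning rate are arbitrary placeholders.

import torch
import torch.nn as nn

emb = nn.Embedding(100, 2)               # 100 points in the 2-dimensional Poincaré ball
emb.weight.data.uniform_(-1e-3, 1e-3)    # start close to the origin, well inside the ball

optimizer = RiemannianSGD(emb.parameters())

idx = torch.tensor([0, 1])
x = emb(idx)
loss = torch.sum((x[0] - x[1]) ** 2)     # placeholder loss: pull the two points together

optimizer.zero_grad()
loss.backward()
optimizer.step(lr=0.3)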

Hyperboloid

According to [5], the hyperboloid (Lorentz) model is more stable during training. However, my tests showed similar results for both models; sometimes the Poincaré unit ball was actually more stable.

import torch
from torch.optim.optimizer import Optimizer

def expm(p : torch.Tensor, u : torch.Tensor):
    # exponential map on the hyperboloid (Lorentz model)
    ldv = lorentzian_inner_product(u, u, keepdim=True).clamp_(min=1e-15).sqrt_()
    return torch.cosh(ldv) * p + torch.sinh(ldv) * u / ldv

def lorentzian_inner_product(u : torch.Tensor, v : torch.Tensor, keepdim=False):
    # Minkowski inner product with signature (-, +, ..., +)
    uv = u * v
    uv.narrow(-1, 0, 1).mul_(-1)
    return torch.sum(uv, dim=-1, keepdim=keepdim)

@torch.jit.script
def grad(p : torch.Tensor):
    # Riemannian gradient: flip the sign of the time-like component (the Lorentzian metric is its own inverse)
    d_p = p.grad
    d_p.narrow(-1, 0, 1).mul_(-1)
    return d_p

def proj(p : torch.Tensor, d_p : torch.Tensor):
    # project d_p onto the tangent space at p (uses <p, p>_L = -1)
    return d_p + lorentzian_inner_product(p.data, d_p, keepdim=True) * p.data

class RiemannianSGD(Optimizer):
    def __init__(self, params):
        super(RiemannianSGD, self).__init__(params, {})

    def step(self, lr=0.3):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                d_p = grad(p)
                d_p = proj(p, d_p)
                d_p.mul_(-lr)

                p.data.copy_(expm(p.data, d_p))

My preliminary tests with word embeddings showed the following disadvantages of the hyperboloid:

  1. It requires one more dimension due to the constraint \(x_0 = \sqrt{1 + \sum_{i=1}^{n} x_i^2}\) (an \(n\)-dimensional hyperbolic space is embedded in \(\mathbb{R}^{n+1}\)). After training, one can get rid of the additional dimension by mapping to the Poincaré unit ball (see the sketch after this list).
  2. It is less convenient for visualizations, but this too can be solved by mapping to the Poincaré unit ball.
  3. The network weights have to satisfy the equality constraint \(x_0 = \sqrt{1 + \sum_{i=1}^{n} x_i^2}\) (see d_p = proj(p, d_p)). This is more difficult than an inequality constraint, because an equality constraint is always active. In comparison, the Poincaré unit ball is defined by \(\lVert x\rVert < 1\); as long as the learning rate is reasonable, points will never fall off the unit ball.
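
The mapping from the hyperboloid to the Poincaré unit ball mentioned in points 1 and 2 is the standard projection; a short sketch (not part of the original code) could look like this:

def lorentz_to_poincare(x: torch.Tensor):
    # drop the time-like coordinate x_0 and rescale: y_i = x_i / (1 + x_0)
    return x.narrow(-1, 1, x.size(-1) - 1) / (1 + x.narrow(-1, 0, 1))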

A random initialization of the weights will violate the equality constraint. Hence, before training one should set the first coordinate \(w_0\) of every weight vector \(w\) to torch.sqrt(1 + torch.sum((w.narrow(-1, 1, w.size(-1) - 1)) ** 2, dim=-1, keepdim=True)).
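
For an embedding weight matrix this initialization could be done as in the following sketch; emb is a hypothetical nn.Embedding and the scale 1e-3 is an arbitrary choice.

with torch.no_grad():
    w = emb.weight
    # random initialization of the space-like coordinates x_1, ..., x_n
    w.narrow(-1, 1, w.size(-1) - 1).uniform_(-1e-3, 1e-3)
    # set x_0 so that every point lies exactly on the hyperboloid
    w.narrow(-1, 0, 1).copy_(torch.sqrt(1 + torch.sum(w.narrow(-1, 1, w.size(-1) - 1) ** 2, dim=-1, keepdim=True)))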

References

[1] S.-I. Amari, Natural Gradient Works Efficiently in Learning, 1998.

[2] Y. Meng, J. Huang, G. Wang, C. Zhang, H. Zhuang, L. Kaplan and J. Han, Spherical Text Embedding, 2019.

[3] O. Ganea, G. Bécigneul and T. Hofmann, Hyperbolic Neural Networks, 2018.

[4] G. Bécigneul and O.-E. Ganea, Riemannian Adaptive Optimization Methods, 2018.

[5] M. Nickel and D. Kiela, Learning Continuous Hierarchies in the Lorentz Model of Hyperbolic Geometry, 2018.