<h1 id="sigmoid-activation-is-not-optimal-with-binary-segmentation">Sigmoid activation is not optimal with binary segmentation</h1>
<p>The standard activation function for binary outputs is the sigmoid function. However, <a href="http://arxiv.org/abs/2109.00903">in a recent paper</a>, I show empirically on several medical segmentation datasets that other functions can be better.</p>
<!--more-->
<p>Two important results of this work are:</p>
<ul>
<li>Dice loss gives better results with the arctangent function than with the sigmoid function.</li>
<li>Binary cross entropy together with the normal CDF can lead to better results than the sigmoid function.</li>
</ul>
<p>In this blog post, I will implement the two results in PyTorch.</p>
<h2 id="arctangent-and-dice-loss">Arctangent and Dice loss</h2>
<p>Dice loss is a common loss function in segmentation. It is defined as follows:</p>
\[\text{DL} = 1 - \frac{2\sum_{i}f(x_{i})y_{i}}{\sum_{i}f(x_{i}) + \sum_{i}y_{i}}\,,\]
<p>where \(x_i\) are the inputs and \(y_i \in \{0, 1\}\) is the ground truth. \(f(x)\) defines the activation function, usually \(f(x)\) is the sigmoid activation function.</p>
<p>The following code implements this loss function:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">class</span> <span class="nc">DiceLoss</span><span class="p">():</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">pass</span>
<span class="k">def</span> <span class="nf">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">,</span> <span class="n">y_true</span><span class="p">):</span>
<span class="n">y_pred</span> <span class="o">=</span> <span class="n">activation</span><span class="p">(</span><span class="n">y_pred</span><span class="p">)</span>
<span class="n">numerator</span> <span class="o">=</span> <span class="p">(</span><span class="n">y_pred</span> <span class="o">*</span> <span class="n">y_true</span><span class="p">).</span><span class="nb">sum</span><span class="p">()</span>
<span class="n">denominator</span> <span class="o">=</span> <span class="n">y_pred</span><span class="p">.</span><span class="nb">sum</span><span class="p">()</span> <span class="o">+</span> <span class="n">y_true</span><span class="p">.</span><span class="nb">sum</span><span class="p">()</span>
<span class="k">return</span> <span class="mi">1</span> <span class="o">-</span> <span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">numerator</span><span class="p">)</span> <span class="o">/</span> <span class="n">denominator</span>
<span class="k">def</span> <span class="nf">activation</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">loss_func</span> <span class="o">=</span> <span class="n">DiceLoss</span><span class="p">()</span>
<span class="p">...</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">loss_func</span><span class="p">(</span><span class="n">predictions</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span></code></pre></figure>
<p>On four different datasets, the sigmoid activation achieved an average dice coefficient of \(0.726575\). Replacing the sigmoid activation with the following arctangent function increased this average by about 2%.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">activation</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="k">return</span> <span class="mf">1e-7</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="mi">2</span> <span class="o">*</span> <span class="mf">1e-7</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="mf">0.5</span> <span class="o">+</span> <span class="n">torch</span><span class="p">.</span><span class="n">arctan</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">/</span><span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">pi</span><span class="p">))</span></code></pre></figure>
<p>The reason why arctangent works better than sigmoid is that sigmoid saturates too quickly: it approaches the extreme probabilities 0 and 1 faster, whereas arctangent changes more slowly and therefore gives the network more freedom of action. In the paper, this is made clear by comparing the cross entropy error and the rate of change of the activation functions. Since dice loss computes all predictions at the same time, we need a slower function. Binary cross entropy, on the other hand, considers each pixel individually.</p>
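<p>To make this concrete, we can simply evaluate both activations at a few logits. The following snippet is only an illustrative sketch (not code from the paper); it shows that the sigmoid reaches values close to 1 much earlier than the scaled arctangent:</p>

<figure class="highlight"><pre><code class="language-python" data-lang="python">import numpy as np
import torch

def sigmoid_act(x):
    return torch.sigmoid(x)

def arctan_act(x):
    # the arctangent activation from above
    return 1e-7 + (1 - 2 * 1e-7) * (0.5 + torch.arctan(x) / torch.tensor(np.pi))

x = torch.tensor([0.0, 2.0, 4.0, 8.0])
print(sigmoid_act(x))  # approx. [0.500, 0.881, 0.982, 0.9997]
print(arctan_act(x))   # approx. [0.500, 0.852, 0.922, 0.960]</code></pre></figure>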
<h2 id="normal-cdf-and-cross-entropy">Normal CDF and Cross Entropy</h2>
<p>Binary cross entropy can be defined mathematically as follows:</p>
\[\text{BCE} = -\frac{1}{n}\sum_{i}y_{i} \log f(x_{i}) + \left(1-y_{i}\right) \log\left(1-f(x_{i})\right)\,.\]
<p>where \(x_i\) are the inputs and \(y_i \in \{0, 1\}\) is the ground truth. Again, \(f(x)\) is the sigmoid function. In PyTorch, we usually use the more numerically stable functions <code class="language-plaintext highlighter-rouge">F.binary_cross_entropy_with_logits(y_hat, y_true)</code> or <code class="language-plaintext highlighter-rouge">BCEWithLogitsLoss()</code>. These two functions combine the sigmoid function with cross entropy.</p>
<p>In the paper, I propose the normal CDF for \(f(x)\) instead. On average, the normal CDF is about 0.1% better than sigmoid. For some datasets, it can be up to 1% better than sigmoid.</p>
<p>The following code implements the normal CDF together with BCE:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">BCELoss</span>
<span class="k">def</span> <span class="nf">activation</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="k">return</span> <span class="p">(</span><span class="mf">0.5</span> <span class="o">-</span> <span class="mf">1e-7</span><span class="p">)</span> <span class="o">*</span> <span class="n">torch</span><span class="p">.</span><span class="n">erf</span><span class="p">(</span><span class="n">x</span><span class="o">/</span><span class="n">torch</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="mi">2</span><span class="p">)))</span> <span class="o">+</span> <span class="mf">0.5</span>
<span class="n">loss_func</span> <span class="o">=</span> <span class="n">BCELoss</span><span class="p">()</span>
<span class="p">...</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">loss_func</span><span class="p">(</span><span class="n">activation</span><span class="p">(</span><span class="n">predictions</span><span class="p">),</span> <span class="n">targets</span><span class="p">)</span></code></pre></figure>
<p>The normal CDF is a function that reaches the probabilities 0% and 100% faster than the sigmoid function. Using the normal CDF reduces the freedom of action and forces the network to make faster decisions. This leads to less uncertainty and a better dice coefficient.</p>
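<p>As a quick sanity check (not from the paper), the two activations can be compared at a few logits; the normal CDF is noticeably closer to 0 or 1 for the same input:</p>

<figure class="highlight"><pre><code class="language-python" data-lang="python">import torch

def sigmoid_act(x):
    return torch.sigmoid(x)

def normal_cdf_act(x):
    # the normal CDF activation from above
    return (0.5 - 1e-7) * torch.erf(x / torch.sqrt(torch.tensor(2.0))) + 0.5

x = torch.tensor([0.0, 1.0, 2.0, 3.0])
print(sigmoid_act(x))     # approx. [0.500, 0.731, 0.881, 0.953]
print(normal_cdf_act(x))  # approx. [0.500, 0.841, 0.977, 0.999]</code></pre></figure>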
<h2 id="conclusion">Conclusion</h2>
<p>The output activation function of neural networks has hardly been analyzed so far. In the paper, the rate of change of the activation function was related to the resulting segmentation errors (dice coefficient). It was shown that the sigmoid function is not always the best output function. Since I was limited by the available GPU resources, the tests were only performed on medical segmentation datasets. It would be interesting to see results in other domains as well.</p>
<h2 id="references">References</h2>
<p>[1] Lars Nieradzik, Gerik Scheuermann, Dorothee Saur, and Christina Gillmann. (2021). Effect of the output activation function on the probabilities and errors in medical image segmentation.</p>
<h1 id="operations-on-contextual-word-embeddings">Operations on contextual word embeddings</h1>
<p>A variety of operations can be performed on (contextual) word vectors. In this blog post, I will implement some common operations using PyTorch and Python.</p>
<!--more-->
<p>First, we need to create the word embeddings. This is done in the section “Introduction” using the <a href="https://huggingface.co/">transformers</a> library. The section “Operations” then deals with the actual topic of this blog post.</p>
<h2 id="introduction">Introduction</h2>
<p>Word embeddings project high-dimensional word vectors onto a low-dimensional space. More mathematically, we could describe a word embedding as an injective function \(f : \{0, 1\}^n \to \mathbb{R}^k\) where \(n \gg k\). For example, <a href="https://nlp.stanford.edu/projects/glove/">GloVe</a> (with Wikipedia 2014 + Gigaword 5) uses \(n = 4 \cdot 10^{5}\) and \(k = 300\).</p>
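<p>As a toy illustration (with made-up sizes, not the actual GloVe matrices), such an embedding is simply a lookup of a one-hot vector in a dense matrix:</p>

<figure class="highlight"><pre><code class="language-python" data-lang="python">import torch

n, k = 10000, 300         # vocabulary size and embedding dimension (toy values)
E = torch.randn(n, k)     # embedding matrix, one row per word
one_hot = torch.zeros(n)  # one-hot encoding of the word with index 42
one_hot[42] = 1.0
vec = one_hot @ E         # the k-dimensional word vector, equivalent to E[42]
print(torch.allclose(vec, E[42]))  # True</code></pre></figure>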
<p>There are mainly two types of word embeddings:</p>
<ol>
<li>
<p>Regular word embeddings like SGNS (word2vec) or GloVe look at each word individually without considering the context. It was shown by <a href="#6">[1]</a> that this particular case can be viewed as factorizing a word-context matrix.</p>
</li>
<li>
<p>Sentence embeddings like BERT require the whole text as input. The same two words \(v, w \in \{0, 1\}^n\) at different positions are mapped to different \(f(v), f(w) \in \mathbb{R}^k\).</p>
</li>
</ol>
<p>Since there are already good introductions to regular embeddings, we will only consider the case of contextual word embeddings here.</p>
<p>There are many transformer models but most of them use \(k = 768\) for the output space. This is too big for doing quick tests. For this reason, we choose a more compact model called <code class="language-plaintext highlighter-rouge">google/bert_uncased_L-12_H-128_A-2</code> with only \(k = 128\) <a href="#6">[2]</a>.</p>
<p>In addition to a transformer model, embeddings also require a large number of sentences to extract individual words. As a placeholder, I came up with a few sentences and put them in the list <code class="language-plaintext highlighter-rouge">sentences</code>. This variable should be filled with real sentences.</p>
<p>Finally, we can look at some code. We want to do three things: encode each sentence, remove the padding “[PAD]” and assign a sentence ID to each token.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">from</span> <span class="nn">transformers</span> <span class="kn">import</span> <span class="n">AutoTokenizer</span><span class="p">,</span> <span class="n">AutoModel</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="n">sentences</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span><span class="s">"This is cold."</span><span class="p">,</span> <span class="s">"This is warm."</span><span class="p">,</span> <span class="s">"This is a test."</span><span class="p">])</span>
<span class="n">tokenizer</span> <span class="o">=</span> <span class="n">AutoTokenizer</span><span class="p">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s">"google/bert_uncased_L-12_H-128_A-2"</span><span class="p">)</span>
<span class="n">tokens_encoded</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="p">(</span><span class="n">sentences</span><span class="p">.</span><span class="n">tolist</span><span class="p">(),</span> <span class="n">padding</span><span class="o">=</span><span class="s">'longest'</span><span class="p">,</span> <span class="n">truncation</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">max_length</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="n">return_tensors</span><span class="o">=</span><span class="s">'pt'</span><span class="p">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">AutoModel</span><span class="p">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s">"google/bert_uncased_L-12_H-128_A-2"</span><span class="p">,</span>
<span class="n">return_dict</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
<span class="n">output_hidden_states</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">model</span><span class="p">.</span><span class="nb">eval</span><span class="p">()</span>
<span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="n">X</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="o">**</span><span class="n">tokens_encoded</span><span class="p">).</span><span class="n">last_hidden_state</span>
<span class="n">X</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">vstack</span><span class="p">([</span><span class="n">X</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">tokens_encoded</span><span class="p">.</span><span class="n">attention_mask</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">].</span><span class="n">numpy</span><span class="p">()</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">sentences</span><span class="p">))])</span>
<span class="n">tokens</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span><span class="n">tokenizer</span><span class="p">.</span><span class="n">convert_ids_to_tokens</span><span class="p">(</span><span class="n">token</span><span class="p">)</span> <span class="k">for</span> <span class="n">token</span> <span class="ow">in</span> <span class="n">tokens_encoded</span><span class="p">[</span><span class="s">"input_ids"</span><span class="p">]])</span>
<span class="n">sent_ids</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">hstack</span><span class="p">([[</span><span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">token</span><span class="p">.</span><span class="n">numpy</span><span class="p">()</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">token</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">tokens_encoded</span><span class="p">[</span><span class="s">"attention_mask"</span><span class="p">])])</span>
<span class="n">tokens</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span><span class="n">token</span> <span class="k">for</span> <span class="n">token</span> <span class="ow">in</span> <span class="n">tokens</span><span class="p">.</span><span class="n">flatten</span><span class="p">()</span> <span class="k">if</span> <span class="n">token</span> <span class="o">!=</span> <span class="s">"[PAD]"</span><span class="p">])</span></code></pre></figure>
<p>The output of the code is an \(m \times 128\) embedding matrix \(X\) that contains all \(m\) words from each sentence. Each sentence in \(X\) starts with the token [CLS] and ends with [SEP]. It is also possible to remove these two placeholder tokens to save memory. Note that the variable <code class="language-plaintext highlighter-rouge">sentences</code> contains only three debug sentences; I am using a corpus of about 3000 sentences to get good results. Then \(m \gg 128\) (more words than embedding dimensions).</p>
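<p>For example, a minimal sketch for dropping the two placeholder tokens (assuming the variables <code class="language-plaintext highlighter-rouge">tokens</code>, <code class="language-plaintext highlighter-rouge">sent_ids</code> and <code class="language-plaintext highlighter-rouge">X</code> from the code above) could look like this:</p>

<figure class="highlight"><pre><code class="language-python" data-lang="python"># keep only rows that belong neither to [CLS] nor to [SEP]
keep = ~np.isin(tokens, ["[CLS]", "[SEP]"])
X = X[keep]
sent_ids = sent_ids[keep]
tokens = tokens[keep]</code></pre></figure>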
<p>Having created the embeddings \(X\), we can look at some operations.</p>
<h2 id="operations">Operations</h2>
<p>The outputs <code class="language-plaintext highlighter-rouge">tokens</code>, <code class="language-plaintext highlighter-rouge">sent_ids</code>, <code class="language-plaintext highlighter-rouge">X</code> of the previous section are needed for this section. I am going to present the following operations:</p>
<ul>
<li>Dot product: compute the similarity between two words</li>
<li>Linear combination: decompose a word into its subcomponents</li>
<li>Vector addition/subtraction: reason about the relations between words (e.g. analogy)</li>
<li>Isotropy: postprocess the embeddings to make the words more “spread out” in space</li>
</ul>
<h3 id="dot-product">Dot product</h3>
<p>The first operation is defined as follows:</p>
\[a^Tb = ||a|| ||b|| \cos\theta \iff \frac{a^Tb}{||a|| ||b||} = \cos\theta\]
<p>This operation is also called cosine similarity. It measures the similarity between two words e.g. a = “cold” and b = “warm”.</p>
<p>The following code measures the similarity between the word “cold” (in a particular sentence) and all other words in the embedding matrix \(X\) (all sentences).</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">X</span> <span class="o">/=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">word</span> <span class="o">=</span> <span class="s">"cold"</span>
<span class="n">word_mask</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">flatnonzero</span><span class="p">(</span><span class="n">tokens</span> <span class="o">==</span> <span class="n">word</span><span class="p">)</span>
<span class="n">similarity</span> <span class="o">=</span> <span class="n">X</span> <span class="o">@</span> <span class="n">X</span><span class="p">[</span><span class="n">word_mask</span><span class="p">].</span><span class="n">T</span>
<span class="n">top_k</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">similarity</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)[</span><span class="o">-</span><span class="mi">5</span><span class="p">:][::</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="c1"># top-5 results
</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Top 5 words for the '</span><span class="si">{</span><span class="n">word</span><span class="si">}</span><span class="s">'.</span><span class="se">\n</span><span class="s">"</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">similarity</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]):</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Source sentence: </span><span class="si">{</span><span class="n">sentences</span><span class="p">[</span><span class="n">sent_ids</span><span class="p">[</span><span class="n">word_mask</span><span class="p">[</span><span class="n">i</span><span class="p">]]]</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Closest tokens: </span><span class="si">{</span><span class="n">tokens</span><span class="p">[</span><span class="n">top_k</span><span class="p">[:,</span> <span class="n">i</span><span class="p">]]</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Closest similarities: </span><span class="si">{</span><span class="n">similarity</span><span class="p">[:,</span> <span class="n">i</span><span class="p">][</span><span class="n">top_k</span><span class="p">[:,</span> <span class="n">i</span><span class="p">]]</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Closest sentences: </span><span class="si">{</span><span class="n">sentences</span><span class="p">[</span><span class="n">sent_ids</span><span class="p">[</span><span class="n">top_k</span><span class="p">[:,</span> <span class="n">i</span><span class="p">]]]</span><span class="si">}</span><span class="se">\n</span><span class="s">"</span><span class="p">)</span></code></pre></figure>
<p>An output for a text dataset with around 700000 tokens is as follows:</p>
<ul>
<li>Source sentence: “It was alone in the <strong>cold</strong> world of frost and snow.”</li>
<li>Closest token to “cold”: “cold” (similarity 0.92475)</li>
<li>Found sentence: “It is a very <strong>cold</strong> country, and very rocky; and there are a great many small islands all around it.”</li>
</ul>
<p>Since words tend to occur more than once, the highest similarity is achieved by the word “cold” from another sentence.</p>
<p>For the next test, I removed all occurrences of “cold” and used another sentence.</p>
<ul>
<li>Source sentence: “Examples of the former case are giant molecular clouds, the <strong>cold</strong>est, densest phase of interstellar gas, which can form by the cooling and condensation of more diffuse gas.”</li>
<li>Closest token to “cold”: “warm” (similarity 0.8814)</li>
<li>Found sentence: “Paleontologists used to believe that dinosaurs lived only in the <strong>warm</strong>est parts of the world.”</li>
</ul>
<h3 id="linear-combination">Linear combination</h3>
<p>Writing a vector in terms of other vectors is called a linear combination or a superposition. We want to decompose a word into its subcomponents. More formally, we can write</p>
\[w = a_{1}v_{1} + a_2v_{2}+a_{3}v_{3}+\cdots +a_mv_{m}\,,\]
<p>where \(w, v_1, \dots, v_m \in \mathbb{R}^{128}\) are words and \(a_1, \dots, a_m \in \mathbb{R}\). For example: \(\text{winter} = 0.7\cdot\text{autumn} + 0.3\cdot\text{summer}\). Note that the factors \(a_1, \dots, a_m\) do not have to sum to \(1\) (except if one enforces it).</p>
<p>The first step for finding the components is to apply the transpose operation on the embedding matrix \(X\). Then the original \(m \times 128\) matrix turns into a \(128 \times m\) matrix. The objective is to multiply this \(X^T\) matrix with an \(m \times 1\) vector \(x\) where most entries are zero (sparse):</p>
\[X^Tx = y\]
<p>The equation expresses the above \(0.7\cdot\text{autumn} + 0.3\cdot\text{summer} = \text{winter}\) relationship in matrix form.</p>
<p>We are dealing here with an underdetermined system of equations because we have fewer equations than unknowns. Such a system tends to have an infinite number of solutions, that is, there is more than one way to decompose a word. For the same reason, ordinary linear regression will not work here. Instead, the usual approach to finding \(x\) (or, in other words, \(a_1, \dots, a_m\)) is “sparse dictionary learning” <a href="#6">[3, 4]</a>. The paper <a href="#6">[4]</a> uses the FISTA algorithm for this purpose. This basically amounts to regression with \(L_1\) regularization, which is also known as LASSO.</p>
<p>We simplify matters by decomposing only a single word into its subcomponents and by applying regular LASSO regression.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">from</span> <span class="nn">sklearn.linear_model</span> <span class="kn">import</span> <span class="n">Lasso</span>
<span class="n">word_mask</span> <span class="o">=</span> <span class="n">tokens</span> <span class="o">==</span> <span class="s">"cold"</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">Lasso</span><span class="p">(</span><span class="n">alpha</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">fit_intercept</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">positive</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">tol</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">max_iter</span><span class="o">=</span><span class="mf">1e4</span><span class="p">)</span>
<span class="n">model</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X</span><span class="p">[</span><span class="o">~</span><span class="n">word_mask</span><span class="p">].</span><span class="n">T</span><span class="p">,</span> <span class="n">X</span><span class="p">[</span><span class="n">word_mask</span><span class="p">].</span><span class="n">T</span><span class="p">)</span>
<span class="n">model</span><span class="p">.</span><span class="n">coef_</span> <span class="o">/=</span> <span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">coef_</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> <span class="c1"># normalize to 1
</span>
<span class="n">top_k</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">coef_</span><span class="p">.</span><span class="n">T</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)[</span><span class="o">-</span><span class="mi">5</span><span class="p">:][::</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="c1"># top-5 results
</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">top_k</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]):</span>
<span class="n">string</span> <span class="o">=</span> <span class="sa">f</span><span class="s">"</span><span class="si">{</span><span class="n">word</span><span class="si">}</span><span class="s"> = "</span>
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">5</span><span class="p">):</span>
<span class="n">coef</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">coef_</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">top_k</span><span class="p">[</span><span class="n">k</span><span class="p">,</span> <span class="n">i</span><span class="p">]]</span>
<span class="k">if</span> <span class="n">coef</span> <span class="o">!=</span> <span class="mf">0.0</span><span class="p">:</span>
<span class="n">string</span> <span class="o">+=</span> <span class="sa">f</span><span class="s">"</span><span class="si">{</span><span class="n">coef</span><span class="si">}</span><span class="s"> * </span><span class="si">{</span><span class="n">tokens</span><span class="p">[</span><span class="o">~</span><span class="n">word_mask</span><span class="p">][</span><span class="n">top_k</span><span class="p">[</span><span class="n">k</span><span class="p">,</span> <span class="n">i</span><span class="p">]]</span><span class="si">}</span><span class="s"> + "</span>
<span class="k">print</span><span class="p">(</span><span class="n">string</span><span class="p">[:</span><span class="o">-</span><span class="mi">3</span><span class="p">])</span></code></pre></figure>
<p>The results are as follows:</p>
\[\begin{aligned}
\text{cold} &= 0.66\cdot\text{warm} + 0.23\cdot\text{heavy} + 0.11\cdot\text{hot}\\
\text{cold} &= 0.60\cdot\text{hot} + 0.30\cdot\text{cool} + 0.1\cdot\text{heavy}\\
\text{cold} &= 0.84\cdot\text{dry} + 0.16\cdot\text{warm}\\
\cdots
\end{aligned}\]
<p>Depending on the sentence, the same word “cold” has different components. In this example, “cold” consists of 2 or 3 components.</p>
<h3 id="vector-additionsubtraction">Vector addition/subtraction</h3>
<p>Vector addition and subtraction can correspond linguistically to word analogies. The typical example for an analogy is \(\text{king} - \text{man} + \text{woman} = \text{queen}\) (see word2vec).</p>
<p>The implementation is quite simple. The following code finds the relevant words in the embedding matrix \(X\) and then computes the dot product to get the word with the highest similarity:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">word_mask</span> <span class="o">=</span> <span class="n">tokens</span> <span class="o">==</span> <span class="s">"king"</span>
<span class="n">word_mask2</span> <span class="o">=</span> <span class="n">tokens</span> <span class="o">==</span> <span class="s">"man"</span>
<span class="n">word_mask3</span> <span class="o">=</span> <span class="n">tokens</span> <span class="o">==</span> <span class="s">"woman"</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">X</span><span class="p">[</span><span class="n">word_mask</span><span class="p">][:</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">X</span><span class="p">[</span><span class="n">word_mask2</span><span class="p">][:</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">X</span><span class="p">[</span><span class="n">word_mask3</span><span class="p">][:</span><span class="mi">1</span><span class="p">]</span>
<span class="n">z</span> <span class="o">/=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">X</span> <span class="o">/=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">similarity</span> <span class="o">=</span> <span class="n">X</span> <span class="o">@</span> <span class="n">z</span><span class="p">.</span><span class="n">T</span>
<span class="n">top_k</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">similarity</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)[</span><span class="o">-</span><span class="mi">5</span><span class="p">:][::</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Closest tokens: </span><span class="si">{</span><span class="n">tokens</span><span class="p">[</span><span class="n">top_k</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]]</span><span class="si">}</span><span class="s">"</span><span class="p">)</span></code></pre></figure>
<p>The top-5 results for \(\text{king} - \text{man} + \text{woman}\) are: queen, king, king, king, king. Another way of writing the above code as an equation is:</p>
\[\arg\max_{v_i} \left(v_{\text{king}} - v_{\text{man}} + v_{\text{woman}}\right)^Tv_i = v_\text{queen}\]
<p>We found that \(i = \text{queen}\) for the top-1 result. However, the result also depends on the sentence, because “king” is geometrically not far from “queen”.</p>
<h3 id="isotropy">Isotropy</h3>
<p>While the last sections were focused on describing the relationship between words in the vector space, this section is about operations on the space itself.</p>
<p>When words are clustered together, it is harder to distinguish them from one another. Intuitively, words should be spread out across the whole space <a href="#6">[5, 6, 7, 8, 9]</a>. The desired property is called “isotropy”. A common theme in recent papers is to postprocess the word embeddings so that they become more isotropic.</p>
<p>The postprocessing methods only require changing the embedding matrix itself. Some important papers are the following:</p>
<ul>
<li><a href="#6">[5]</a> applied principal component analysis (PCA) on the word embeddings and showed that some components have a stronger influence than others. Removing the leading components gave better results on common datasets. An implementation can be found <a href="https://gist.github.com/lgalke/febaaa1313d9c11f3bc8240defed8390">here</a>.</li>
<li><a href="#6">[8]</a> improved the previous results by softly filtering out components of the word vectors (no complete removal). An implementation can be found <a href="https://github.com/liutianlin0121/Conceptor-Negation-WV/blob/master/CN_demo.ipynb">here</a>.</li>
<li><a href="#6">[9]</a> proposed removing the variance from the embeddings (whitening). An implementation can be found <a href="https://github.com/bojone/BERT-whitening">here</a>.</li>
</ul>
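<p>To give an impression of the first method, here is a minimal sketch of the “all-but-the-top” idea applied to the embedding matrix \(X\) from above: mean-center the vectors and remove their projection onto the top principal components. The number of removed components <code class="language-plaintext highlighter-rouge">D</code> is a hyperparameter; this is only a rough sketch and not the reference implementation linked above.</p>

<figure class="highlight"><pre><code class="language-python" data-lang="python">from sklearn.decomposition import PCA

D = 2  # number of leading principal components to remove (hyperparameter)

# center the embeddings and fit PCA
X_centered = X - X.mean(axis=0, keepdims=True)
U = PCA(n_components=D).fit(X_centered).components_  # shape (D, 128)

# subtract the projection of every word vector onto the leading components
X_isotropic = X_centered - (X_centered @ U.T) @ U</code></pre></figure>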
<h2 id="conclusion">Conclusion</h2>
<p>The operations I introduced in this post mainly focus on things you do with vector spaces: calculate the angle (dot product), combine vectors (linear combination), add vectors (vector addition), or change the basis (e.g. PCA). However, there are also more general operations like clustering words or visualizing the dimensions. I might come back to this post in the future and add more general operations as well.</p>
<h2 id="references">References</h2>
<p>[1] O. Levy and Y. Goldberg, “Neural Word Embedding as Implicit Matrix Factorization”, 2014.</p>
<p>[2] I. Turc, M.-W. Chang, K. Lee, K. Toutanova, “Well-Read Students Learn Better: On the Importance of Pre-training Compact Models”, 2019.</p>
<p>[3] J. Zhang, Y. Chen, B. Cheung, B. A. Olshausen, “Word Embedding Visualization Via Dictionary Learning”, 2021.</p>
<p>[4] Z. Yun, Y. Chen, B. A. Olshausen, Y. LeCun, “Transformer visualization via dictionary learning: contextualized embedding as a linear superposition of transformer factors”, 2021.</p>
<p>[5] J. Mu, P. Viswanath, “All-but-the-top: Simple and effective postprocessing for word representations”, 2018.</p>
<p>[6] X. Cai, J. Huang, Y. Bian, K. Church, “Isotropy in the contextual embedding space: Clusters and manifolds”, 2021.</p>
<p>[7] S. Arora, Y. Li, Y. Liang, T. Ma, A. Risteski, “A Latent Variable Model Approach to PMI-based Word Embeddings”, 2016.</p>
<p>[8] T. Liu, L. Ungar, J. Sedoc, “Unsupervised Post-processing of Word Vectors via Conceptor Negation”, 2018.</p>
<p>[9] J. Su, J. Cao, W. Liu, Y. Ou, “Whitening Sentence Representations for Better Semantics and Faster Retrieval”, 2021.</p>
<h1 id="uncertainty-estimation-in-neural-networks">Uncertainty estimation in neural networks</h1>
<p>In this blog post, I will implement some common methods for uncertainty estimation. My main focus lies on classification and segmentation. Therefore, regression-specific methods such as Pinball loss are not covered here.</p>
<!--more-->
<p>Recently, there has been a lot of development in Gaussian processes. Google has published a library for <a href="https://ai.googleblog.com/2020/03/fast-and-easy-infinitely-wide-networks.html">infinitely wide networks</a> (Neural network Gaussian process). There are also <a href="https://arxiv.org/abs/1902.05888">Deep Convolutional Gaussian Processes</a>, which seem to handle uncertainty really well. However, these methods require large amounts of <a href="https://github.com/google/neural-tangents/issues/18">memory</a>, do not scale up to big datasets or do not work with other architectures such as LSTMs.</p>
<p>The methods here are more practical and can be adapted to most architectures.</p>
<p><strong>01.09.2020 update: changed vector/matrix scaling to use a vector-valued bias term</strong></p>
<p><strong>16.08.2020 update: added implementation for temperature/vector/matrix scaling</strong></p>
<h2 id="resampling-methods">Resampling methods</h2>
<p>The methods in this section sample the input to the neural network or to certain layers:</p>
<ul>
<li>Monte Carlo dropout <a href="#8">[1]</a>: sampling of dropout masks</li>
<li>heteroscedastic uncertainty <a href="#8">[1]</a>: sampling of Gaussian noise</li>
<li>ensemble <a href="#8">[2]</a>: sampling of input</li>
<li>TTA <a href="#8">[3]</a>: sampling of image transformations (augmentations)</li>
</ul>
<p>Note that this is only a rough categorization from a practical point of view. I am not considering Bayesian interpretations here.</p>
<h3 id="monte-carlo-dropout">Monte Carlo dropout</h3>
<p>In practice, Monte Carlo dropout (MC dropout) consists of running an image multiple times through a neural network with dropout and calculating the mean of the results. Dropout is not deactivated during prediction, as would normally be the case. MC dropout models “epistemic uncertainty”, that is, uncertainty in the parameters. The prediction \(\hat{p}\) is given by:</p>
\[\hat{p} = \frac{1}{M} \sum_{i=1}^M \text{softmax}(f^{W_i}(x))\]
<p>where \(W_1, \dots, W_M\) are sampled dropout masks. The number of samples \(M\) is usually \(40\) or \(50\).</p>
<p>MC dropout is often interpreted as a kind of approximate inference algorithm for Bayesian neural networks. But another (perhaps more plausible) interpretation is that dropout produces different combinations of models. A neuron can either be off or on, so there are \(2^n\) possible networks.</p>
<p>Since dropout is normally disabled at test time, we have to use <code class="language-plaintext highlighter-rouge">torch.nn.functional.dropout</code> instead of <code class="language-plaintext highlighter-rouge">torch.nn.Dropout</code>. This function allows us to specify if we are in training or test mode.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="n">F</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">dropout</span><span class="p">(</span><span class="nb">input</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="n">p</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">inplace</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span></code></pre></figure>
<p>Place <code class="language-plaintext highlighter-rouge">F.dropout</code> with <code class="language-plaintext highlighter-rouge">training=True</code> after some layers and then during prediction (not training) run the following code.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">samples</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">dataset_size</span><span class="p">,</span> <span class="n">classes</span><span class="p">))</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">M</span><span class="p">):</span>
<span class="n">samples</span> <span class="o">+=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="n">samples</span> <span class="o">/=</span> <span class="n">M</span></code></pre></figure>
<h3 id="heteroscedastic-uncertainty">Heteroscedastic uncertainty</h3>
<p>Aleatoric uncertainty captures noise inherent in the observations. It is similar to MC dropout, but instead of deactivating neurons, noise is added to the final pre-activation output.</p>
\[\hat{p} = \frac{1}{M} \sum_{i=1}^M \text{softmax}(x + z\epsilon_i)\]
<p>where \(\epsilon_i \sim \mathcal{N}(0, 1)\) and \(z\) is an additional output.</p>
<p>The loss function changes to some degree, because a sum inside the logarithm function cannot be simplified.</p>
\[\begin{aligned}
H(p, \hat{p}) &= -p_j\log\left(\frac{1}{M} \sum_{i=1}^M \text{softmax}\left(x + z\epsilon_i\right)_j\right)\\
&= \log M - \log\left(\sum_{i=1}^M \text{softmax}\left(x + z\epsilon_i\right)_j\right)\\
&= \log M - \log\left(\sum_{i=1}^M \frac{\exp\left((x + z\epsilon_i)_j\right)}{\sum_{k=1}^N \exp\left((x + z\epsilon_i)_k\right)}\right)\\
&= \log M - \log\left(\sum_{i=1}^M \exp\left((x + z\epsilon_i)_j - \log\left(\sum_{k=1}^N \exp\left((x + z\epsilon_i)_k\right)\right)\right)\right)
\end{aligned}\]
<p>where \(H(\cdot, \cdot)\) denotes cross entropy and \(p_j = 1\).</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">M</span><span class="p">):</span>
<span class="p">...</span>
<span class="bp">self</span><span class="p">.</span><span class="n">sigma</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">feat</span><span class="p">,</span> <span class="n">classes</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="bp">self</span><span class="p">.</span><span class="n">M</span> <span class="o">=</span> <span class="n">M</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="p">...</span>
<span class="n">sigma</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">sigma</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">M</span><span class="p">):</span>
<span class="n">k</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="o">*</span> <span class="n">sigma</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span> <span class="o">+</span> <span class="n">torch</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">k</span> <span class="o">-</span> <span class="n">torch</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">k</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">)))</span>
<span class="k">return</span> <span class="o">-</span><span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">M</span><span class="p">).</span><span class="nb">float</span><span class="p">()</span> <span class="o">+</span> <span class="n">torch</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">y</span><span class="p">),</span> <span class="n">y</span> <span class="o">/</span> <span class="bp">self</span><span class="p">.</span><span class="n">M</span>
<span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">nll_loss</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s">"mean"</span><span class="p">)</span></code></pre></figure>
<p>The first output of <code class="language-plaintext highlighter-rouge">forward</code> is the loss function, the second output is the prediction. Since <code class="language-plaintext highlighter-rouge">nll_loss</code> computes the negative log likelihood, I changed the sign of the first output.</p>
<h3 id="ensemble">Ensemble</h3>
<p>Ensembles are multiple neural networks that are trained on subsets of the original dataset. Predictions of each model are combined via averaging. There are different ways to build ensembles. However, cross validation tends to produce good results. For example, bagging would leave out on average \((1 - \frac{1}{n})^n \to \frac{1}{e} \approx 0.37\) of the original data as \(n \to \infty\).</p>
<p>The following code shows how to create the data loaders for cross validation.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">StratifiedKFold</span>
<span class="kn">from</span> <span class="nn">torchvision</span> <span class="kn">import</span> <span class="n">datasets</span><span class="p">,</span> <span class="n">transforms</span>
<span class="n">train_loaders</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">val_loaders</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">skf</span> <span class="o">=</span> <span class="n">StratifiedKFold</span><span class="p">(</span><span class="n">n_splits</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span>
<span class="n">base_dataset</span> <span class="o">=</span> <span class="n">datasets</span><span class="p">.</span><span class="n">MNIST</span><span class="p">(</span><span class="n">base_folder</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
<span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="p">.</span><span class="n">Compose</span><span class="p">(</span><span class="n">arr</span><span class="p">))</span>
<span class="k">for</span> <span class="n">train_index</span><span class="p">,</span> <span class="n">test_index</span> <span class="ow">in</span> <span class="n">skf</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">base_dataset</span><span class="p">.</span><span class="n">data</span><span class="p">,</span> <span class="n">base_dataset</span><span class="p">.</span><span class="n">targets</span><span class="p">):</span>
<span class="n">dataset</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="n">base_dataset</span><span class="p">)</span>
<span class="n">dataset</span><span class="p">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">dataset</span><span class="p">.</span><span class="n">data</span><span class="p">[</span><span class="n">train_index</span><span class="p">]</span>
<span class="n">dataset</span><span class="p">.</span><span class="n">targets</span> <span class="o">=</span> <span class="n">dataset</span><span class="p">.</span><span class="n">targets</span><span class="p">[</span><span class="n">train_index</span><span class="p">]</span>
<span class="n">train_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">utils</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span>
<span class="n">pin_memory</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">is_available</span><span class="p">(),</span>
<span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
<span class="n">batch_size</span><span class="o">=</span><span class="mi">64</span><span class="p">)</span>
<span class="n">dataset</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="n">base_dataset</span><span class="p">)</span>
<span class="n">dataset</span><span class="p">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">dataset</span><span class="p">.</span><span class="n">data</span><span class="p">[</span><span class="n">test_index</span><span class="p">]</span>
<span class="n">dataset</span><span class="p">.</span><span class="n">targets</span> <span class="o">=</span> <span class="n">dataset</span><span class="p">.</span><span class="n">targets</span><span class="p">[</span><span class="n">test_index</span><span class="p">]</span>
<span class="n">val_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">utils</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span>
<span class="n">pin_memory</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">is_available</span><span class="p">(),</span>
<span class="n">shuffle</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span>
<span class="n">batch_size</span><span class="o">=</span><span class="mi">64</span><span class="p">)</span>
<span class="n">train_loaders</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">train_loader</span><span class="p">)</span>
<span class="n">val_loaders</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">val_loader</span><span class="p">)</span></code></pre></figure>
<p>Next, we train the models.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">for</span> <span class="n">fold</span><span class="p">,</span> <span class="n">train_loader</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">train_loaders</span><span class="p">):</span>
<span class="n">train</span> <span class="n">model</span>
<span class="p">...</span>
<span class="k">if</span> <span class="n">loss</span> <span class="o"><</span> <span class="n">best_loss</span><span class="p">:</span>
<span class="n">best_loss</span> <span class="o">=</span> <span class="n">loss</span>
<span class="n">torch</span><span class="p">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="sa">f</span><span class="s">"model_</span><span class="si">{</span><span class="n">fold</span><span class="si">}</span><span class="s">.pt"</span><span class="p">)</span></code></pre></figure>
<p>Finally, during prediction all outputs are averaged.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">out</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">dataset_size</span><span class="p">,</span> <span class="n">classes</span><span class="p">)</span>
<span class="k">for</span> <span class="n">fold</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">5</span><span class="p">):</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="sa">f</span><span class="s">"model_</span><span class="si">{</span><span class="n">fold</span><span class="si">}</span><span class="s">.pt"</span><span class="p">)</span>
<span class="k">for</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="ow">in</span> <span class="n">val_loader</span><span class="p">:</span>
<span class="n">out</span> <span class="o">+=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="n">out</span> <span class="o">/=</span> <span class="mi">5</span></code></pre></figure>
<h3 id="test-time-augmentation">Test time augmentation</h3>
<p>Test time augmentation (TTA) consists of applying different transformations to an image and then running each input through a neural network. The results are usually averaged. The transformations are data augmentations, e.g. rotations or reflections.</p>
<p>Instead of reinventing the wheel, one can use the library <a href="https://github.com/qubvel/ttach">TTAch</a>. It already contains some common pipelines and transformations.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">transforms</span> <span class="o">=</span> <span class="n">tta</span><span class="p">.</span><span class="n">Compose</span><span class="p">(</span>
<span class="p">[</span>
<span class="n">tta</span><span class="p">.</span><span class="n">HorizontalFlip</span><span class="p">(),</span>
<span class="n">tta</span><span class="p">.</span><span class="n">Rotate90</span><span class="p">(</span><span class="n">angles</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">180</span><span class="p">]),</span>
<span class="n">tta</span><span class="p">.</span><span class="n">Scale</span><span class="p">(</span><span class="n">scales</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">]),</span>
<span class="n">tta</span><span class="p">.</span><span class="n">Multiply</span><span class="p">(</span><span class="n">factors</span><span class="o">=</span><span class="p">[</span><span class="mf">0.9</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mf">1.1</span><span class="p">]),</span>
<span class="p">]</span>
<span class="p">)</span>
<span class="n">tta_model</span> <span class="o">=</span> <span class="n">tta</span><span class="p">.</span><span class="n">SegmentationTTAWrapper</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">transforms</span><span class="p">,</span> <span class="n">merge_mode</span><span class="o">=</span><span class="s">'mean'</span><span class="p">)</span></code></pre></figure>
<p>TTA is a common trick in ML competitions. There are also a couple of papers which analyze it with respect to uncertainty estimation. Recently, a paper on <a href="https://github.com/bayesgroup/gps-augment">learnable test-time augmentation</a> was also published.</p>
<h2 id="loss-functions">Loss functions</h2>
<h3 id="learned-confidence-estimates">Learned confidence estimates</h3>
<p>Learned confidence estimates (LCE) adds an extra output \(0 \leq c \leq 1\) to the neural network and modifies the loss function as follows:</p>
\[\mathcal{L} = H(p, c\hat{p} + (1 - c)p) + \lambda H(1, c)\]
<p>where \(H(\cdot, \cdot)\) computes cross entropy and \(\lambda\) is the amount of confidence penalty. As training progresses, one can adjust \(\lambda\).</p>
<p>The basic idea of LCE is to give the neural network hints during training. For example, when the network is very confident, then \(c\hat{p} + (1 - c)p = 1\hat{p} + (1 - 1)p = \hat{p}\) (no hints necessary). If the network is not sure, then \(c\hat{p} + (1 - c)p = 0\hat{p} + (1 - 0)p = p\) (all hints are necessary).</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="p">...</span>
<span class="bp">self</span><span class="p">.</span><span class="n">comb</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">feat</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="bp">self</span><span class="p">.</span><span class="n">sigmoid</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sigmoid</span><span class="p">()</span>
<span class="bp">self</span><span class="p">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="p">...</span>
<span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="bp">self</span><span class="p">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">comb</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
<span class="n">output</span><span class="p">,</span> <span class="n">c</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="n">new_p</span> <span class="o">=</span> <span class="n">c</span> <span class="o">*</span> <span class="n">output</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">c</span><span class="p">)</span> <span class="o">*</span> <span class="n">torch</span><span class="p">.</span><span class="n">eye</span><span class="p">(</span><span class="n">classes</span><span class="p">)[</span><span class="n">target</span><span class="p">]</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">nll_loss</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">new_p</span><span class="p">),</span> <span class="n">target</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s">"mean"</span><span class="p">)</span> <span class="o">-</span> <span class="k">lambda</span> <span class="o">*</span> <span class="n">torch</span><span class="p">.</span><span class="n">mean</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">c</span><span class="p">))</span></code></pre></figure>
<p>In the original paper <a href="#8">[4]</a>, the authors used \(c\) also for out-of-distribution detection.</p>
<h2 id="calibration">Calibration</h2>
<p>The three most common methods for calibration are:</p>
<ul>
<li>temperature scaling: scale by a scalar \(\hat{p} = \text{softmax}\left(\frac{z}{T}\right)\)</li>
<li>vector scaling: scale by a vector \(\hat{p} = \text{softmax}\left(w \circ z + b\right)\) where \(\circ\) denotes element-wise multiplication.</li>
<li>matrix scaling: scale by a matrix \(\hat{p} = \text{softmax}\left(Wz + b\right)\)</li>
</ul>
<p>\(z\) is the final pre-activation layer (i.e. logits) and \(T\), \(w\), \(W\), \(b\) are learnable parameters. The parameters are tuned on the test set. The authors in <a href="#8">[5]</a> proposed using repeated 2-fold cross validation on the test set to reduce the variance of the result.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="nn">torch</span>
<span class="k">def</span> <span class="nf">temp_scaling</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">,</span> <span class="n">train_index</span><span class="p">,</span> <span class="n">test_index</span><span class="p">):</span>
<span class="n">torch</span><span class="p">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">loss_func</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">CrossEntropyLoss</span><span class="p">()</span>
<span class="n">y_true</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">y_true</span><span class="p">)</span>
<span class="n">y_pred</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">y_pred</span><span class="p">)</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">ones</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="mf">1.5</span><span class="p">)</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">LBFGS</span><span class="p">([</span><span class="n">t</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span> <span class="n">max_iter</span><span class="o">=</span><span class="mi">5000</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">closure</span><span class="p">():</span>
<span class="n">optimizer</span><span class="p">.</span><span class="n">zero_grad</span><span class="p">()</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">loss_func</span><span class="p">(</span><span class="n">y_pred</span><span class="p">[</span><span class="n">train_index</span><span class="p">]</span> <span class="o">/</span> <span class="n">t</span><span class="p">,</span> <span class="n">y_true</span><span class="p">[</span><span class="n">train_index</span><span class="p">])</span>
<span class="n">loss</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>
<span class="k">return</span> <span class="n">loss</span>
<span class="n">optimizer</span><span class="p">.</span><span class="n">step</span><span class="p">(</span><span class="n">closure</span><span class="p">)</span>
<span class="k">return</span> <span class="n">loss_func</span><span class="p">(</span><span class="n">y_pred</span><span class="p">[</span><span class="n">test_index</span><span class="p">]</span> <span class="o">/</span> <span class="n">t</span><span class="p">,</span> <span class="n">y_true</span><span class="p">[</span><span class="n">test_index</span><span class="p">]).</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="p">.</span><span class="n">detach</span><span class="p">().</span><span class="n">numpy</span><span class="p">()</span>
<span class="n">nll_losses</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">5</span><span class="p">):</span>
<span class="n">kf</span> <span class="o">=</span> <span class="n">KFold</span><span class="p">(</span><span class="n">n_splits</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="n">i</span><span class="p">)</span>
<span class="k">for</span> <span class="n">train_index</span><span class="p">,</span> <span class="n">test_index</span> <span class="ow">in</span> <span class="n">kf</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">y_pred_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">):</span>
<span class="n">nll_loss</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">temp_scaling</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">y_pred_test</span><span class="p">,</span> <span class="n">train_index</span><span class="p">,</span> <span class="n">test_index</span><span class="p">)</span>
<span class="n">nll_losses</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">nll_loss</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"NLL: </span><span class="si">{</span><span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">(</span><span class="n">nll_losses</span><span class="p">)</span><span class="si">}</span><span class="s"> (std: </span><span class="si">{</span><span class="n">np</span><span class="p">.</span><span class="n">std</span><span class="p">(</span><span class="n">nll_losses</span><span class="p">)</span><span class="si">}</span><span class="s">)"</span><span class="p">)</span></code></pre></figure>
<p><code class="language-plaintext highlighter-rouge">y_pred_test</code> are all logits of some test dataset and <code class="language-plaintext highlighter-rouge">y_test</code> is the ground truth.</p>
<p>For vector scaling, add the parameters <code class="language-plaintext highlighter-rouge">w = nn.Parameter(torch.randn(y_pred.shape[-1],) * 1e-3)</code> and <code class="language-plaintext highlighter-rouge">b = nn.Parameter(torch.zeros(y_pred.shape[-1],))</code>. Then, where necessary, change the code to <code class="language-plaintext highlighter-rouge">y_pred[train_index] * w + b</code>.</p>
<p>For matrix scaling, add the parameters <code class="language-plaintext highlighter-rouge">W = nn.Parameter(torch.randn(y_pred.shape[-1], y_pred.shape[-1]) * 1e-3)</code> and <code class="language-plaintext highlighter-rouge">b = nn.Parameter(torch.zeros(y_pred.shape[-1],))</code>. Then, where necessary, change the code to <code class="language-plaintext highlighter-rouge">y_pred[train_index] @ W + b</code>.</p>
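<p>For example, the optimized part of <code class="language-plaintext highlighter-rouge">temp_scaling</code> could then look like this for vector scaling (only a sketch, the rest of the function stays the same):</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python">w = nn.Parameter(torch.randn(y_pred.shape[-1]) * 1e-3)
b = nn.Parameter(torch.zeros(y_pred.shape[-1]))
optimizer = torch.optim.LBFGS([w, b], lr=0.01, max_iter=5000)

def closure():
    optimizer.zero_grad()
    loss = loss_func(y_pred[train_index] * w + b, y_true[train_index])
    loss.backward()
    return loss

optimizer.step(closure)
nll = loss_func(y_pred[test_index] * w + b, y_true[test_index]).item()</code></pre></figure>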
<p>Another implementation for temperature scaling can be found <a href="https://github.com/gpleiss/temperature_scaling">here</a>. <code class="language-plaintext highlighter-rouge">temperature_scaling.py</code> contains the relevant code.</p>
<p>Calibration methods are applied after the regular training of the neural network. The scaling parameters are optimized based on the cross entropy loss and the metric is often the expected calibration error (ECE) or NLL. Some other metrics can be found in my <a href="/2020/08/07/metrics-for-uncertainty-estimation.html">last blog post</a>.</p>
<h2 id="references">References</h2>
<p>[1] A. Kendall and Y. Gal, <em>What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision?</em>, 2017.</p>
<p>[2] B. Lakshminarayanan, A. Pritzel and C. Blundell, <em>Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles</em>, 2017.</p>
<p>[3] Guotai Wang, Wenqi Li et al., <em>Aleatoric uncertainty estimation with test-time augmentation for medical image segmentation with convolutional neural networks</em>, 2019.</p>
<p>[4] T. DeVries and G. W. Taylor, <em>Learning Confidence for Out-of-Distribution Detection in Neural Networks</em>, 2018.</p>
<p>[5] A. Ashuskha, A. Lyzhov, D. Molchanov et al., <em>Pitfalls of in-domain uncertainty estimation and ensembling in deep learning</em>, 2020.</p>In this blog post, I will implement some common methods for uncertainty estimation. My main focus lies on classification and segmentation. Therefore, regression-specific methods such as Pinball loss are not covered here.Metrics for uncertainty estimation2020-08-07T00:00:00+00:002020-08-07T00:00:00+00:00https://lars76.github.io/2020/08/07/metrics-for-uncertainty-estimation<p>Predictions are not just about accuracy, but also about probability. In lots of applications it is important to know how sure a neural network is of a prediction. However, the softmax probabilities in neural networks are not always calibrated and don’t necessarily measure uncertainty.</p>
<p>In this blog post, I will implement the most common metrics to evaluate the output probabilities of neural networks.</p>
<!--more-->
<p>There are in general two types of metrics:</p>
<ol>
<li>Proper scoring rules estimate the deviation from the true probability distribution. A high value indicates that the predicted probability \(0 \leq \hat{p} \leq 1\) is far away from the true probability \(p \in \{0, 1\}\). Whether \(\hat{p}\) equals \(0.6\) or \(\hat{p}\) equals \(0.8\) is not as important as the distance from \(1\) or \(0\).</li>
<li>Calibration metrics measure the difference between “true confidence” and “predicted confidence”. If \(\hat{p}\) equals \(0.6\), then it should mean that the neural network is 60% sure. A model is calibrated if \(\mathbf{P}\left(\hat{Y} = y \mid \hat{P} = p\right) = p\). Then the difference is \(\left\lvert \mathbf{P}\left(\hat{Y} = y \mid \hat{P} = p\right) - p\right\rvert\). The predicted confidence is the output probability of the neural network, while the true confidence is estimated by the corresponding accuracy. Calibration metrics are computed on the whole dataset in order to group different probabilities (e.g. 0% - 10%, 10% - 20%, …). In contrast, proper scoring rules compare individual probabilities.</li>
</ol>
<p><strong>14/08/20 update: added recommendations, static calibration error and thresholding</strong></p>
<p><strong>30/08/21 update: the traditional reliability diagram, as it is known from weather forecasts, has the relative frequency and not the accuracy on the y-axis. However, papers like “On Calibration of Modern Neural Networks” have the accuracy on the y-axis. I use the latter convention here but would recommend the traditional definition as it gives better results.</strong></p>
<h2 id="proper-scoring-rules">Proper scoring rules</h2>
<h3 id="negative-log-likelihood">Negative log likelihood</h3>
<p>Negative log likelihood (NLL) is the usual method for optimizing neural networks for classification tasks. However, this loss function can also be used as an uncertainty metric. For example, the <a href="https://www.kaggle.com/c/deepfake-detection-challenge/overview/evaluation">Deepfake Detection Challenge</a> scored submissions on NLL.</p>
\[H(p, \hat{p}) = -\mathbf{E}_{p}[\log \hat{p}] = -\sum_{i=1}^n p_i\log\left(\hat{p}_i\right) = -\log\left(\hat{p}_j\right)\]
<p>where \(p_j = 1\) is the ground truth and \(\hat{p}_j = \text{softmax}_j\left(x\right)\). PyTorch’s <code class="language-plaintext highlighter-rouge">CrossEntropyLoss</code> applies the softmax function and computes \(H(p, \hat{p})\).</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="n">nn</span>
<span class="n">loss_func</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">CrossEntropyLoss</span><span class="p">(</span><span class="n">reduction</span><span class="o">=</span><span class="s">"mean"</span><span class="p">)</span>
<span class="n">nll</span> <span class="o">=</span> <span class="n">loss_func</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span></code></pre></figure>
<p>We can also rewrite the code above using <code class="language-plaintext highlighter-rouge">nll_loss</code>. This shows more of what happens internally.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="n">F</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">logits</span> <span class="o">-</span> <span class="n">torch</span><span class="p">.</span><span class="nb">max</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">))</span>
<span class="n">pred</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">logits</span><span class="p">)</span> <span class="o">-</span> <span class="n">torch</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">))</span>
<span class="n">nll</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">nll_loss</span><span class="p">(</span><span class="n">pred</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s">"mean"</span><span class="p">)</span></code></pre></figure>
<p>To ensure numerical stability, \(\max(x)\) is subtracted from the logits before exponentiation; this shift cancels out and does not change \(\log\left(\text{softmax}_j\left(x\right)\right)\).</p>
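<p>A quick sanity check shows that the manual computation agrees with PyTorch's built-in functions:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python">import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)
target = torch.tensor([0, 2, 1, 1])

shifted = torch.exp(logits - torch.max(logits, dim=-1, keepdim=True).values)
pred = torch.log(shifted) - torch.log(torch.sum(shifted, dim=-1, keepdim=True))

print(torch.allclose(pred, F.log_softmax(logits, dim=-1)))                       # True
print(torch.allclose(F.nll_loss(pred, target), F.cross_entropy(logits, target)))  # True</code></pre></figure>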
<h3 id="brier-score">Brier score</h3>
<p>The Brier score is the mean squared error of a forecast. For a single output it is defined as follows:</p>
\[BS(p, \hat{p}) = \sum_{i=1}^{c}(\hat{p}_{i}-p_{i})^2 = 1 - 2\hat{p}_{j} + \sum_{i=1}^{c} \hat{p}_{i}^2\]
<p>For multiple values it is possible to sum over all outputs. The code is then</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">brier_score</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span>
<span class="k">return</span> <span class="mi">1</span> <span class="o">+</span> <span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">y_pred</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span> <span class="o">-</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">y_pred</span><span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="n">y_pred</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="n">y_true</span><span class="p">]))</span> <span class="o">/</span> <span class="n">y_true</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></code></pre></figure>
<p><code class="language-plaintext highlighter-rouge">y_true</code> should be a one dimensional array, while <code class="language-plaintext highlighter-rouge">y_pred</code> should be a two dimensional array. When predicting multiple classes, sometimes each class is considered individually (one-vs.-rest / one-against-all strategy).</p>
<h2 id="calibration-metrics">Calibration metrics</h2>
<h3 id="expected-calibration-error">Expected calibration error</h3>
\[\begin{aligned}
ECE &= \mathbf{E}_{\hat{P}}\left[\left\lvert \mathbf{P}(\hat{Y} = y \mid \hat{P} = p) - p\right\rvert\right]\\
&= \sum_{p} \mathbf{P}(\hat{P} = p) \left\lvert \mathbf{P}(\hat{Y} = y \mid \hat{P} = p) - p\right\rvert
\end{aligned}\]
<p>We approximate the probability distribution by a histogram with \(B\) bins. Then \(\mathbf{P}(\hat{P} = p) = \frac{n_b}{N}\) where \(n_b\) is the number of probabilities in bin \(b\) and \(N\) is the size of the dataset. Since we put \(n_b\) probabilities into one bin, \(p\) is not a single value. Therefore, a representative value \(p = \sum_{\hat{p_i} \in b} \frac{\hat{p_i}}{n_b} = \text{conf}(b)\) is necessary. Similarly, we can set \(\mathbf{P}(\hat{Y} = y \mid \hat{P} = p) = \sum_{\hat{y}_i \in b} \frac{\mathbf{1}\left(y_i = \hat{y_i}\right)}{n_b} = \text{acc}(b)\) where \(\hat{y_i}\) is obtained from the highest probability (arg max). \(\hat{p_i}\) is also the highest probability (max).</p>
<p>ECE is then defined as follows:</p>
\[\begin{aligned}\text{ECE}(B) &= \sum_{b=1}^{B} \frac{n_b}{N}\lvert\text{acc}(b) - \text{conf}(b)\rvert\\
&= \frac{1}{N}\sum_{b \in B}\left\lvert\sum_{(\hat{p_i}, \hat{y_i}) \in b} \mathbf{1}\left(y_i = \hat{y_i}\right) - \hat{p_i}\right\rvert\end{aligned}\]
<p>The accuracy \(\text{acc}(b)\) is also called “observed relative frequency”, while the confidence \(\text{conf}(b)\) is a synonym for “average predicted frequency”.</p>
<p>The implementation is:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">expected_calibration_error</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">,</span> <span class="n">num_bins</span><span class="o">=</span><span class="mi">15</span><span class="p">):</span>
<span class="n">pred_y</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">correct</span> <span class="o">=</span> <span class="p">(</span><span class="n">pred_y</span> <span class="o">==</span> <span class="n">y_true</span><span class="p">).</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">prob_y</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nb">max</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="n">start</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">stop</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">num</span><span class="o">=</span><span class="n">num_bins</span><span class="p">)</span>
<span class="n">bins</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">digitize</span><span class="p">(</span><span class="n">prob_y</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">b</span><span class="p">,</span> <span class="n">right</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">o</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_bins</span><span class="p">):</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">bins</span> <span class="o">==</span> <span class="n">b</span>
<span class="k">if</span> <span class="n">np</span><span class="p">.</span><span class="nb">any</span><span class="p">(</span><span class="n">mask</span><span class="p">):</span>
<span class="n">o</span> <span class="o">+=</span> <span class="n">np</span><span class="p">.</span><span class="nb">abs</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">correct</span><span class="p">[</span><span class="n">mask</span><span class="p">]</span> <span class="o">-</span> <span class="n">prob_y</span><span class="p">[</span><span class="n">mask</span><span class="p">]))</span>
<span class="k">return</span> <span class="n">o</span> <span class="o">/</span> <span class="n">y_pred</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></code></pre></figure>
<p><code class="language-plaintext highlighter-rouge">y_true</code> should be a one dimensional array like <code class="language-plaintext highlighter-rouge">np.array([0,1,0,1,0,0])</code>, while <code class="language-plaintext highlighter-rouge">y_pred</code> requires a two dimensional array e.g. <code class="language-plaintext highlighter-rouge">np.array([[0.9, 0.1],[0.1, 0.9],[0.4, 0.6],[0.6, 0.4]], dtype=np.float32)</code>. Since most papers use between 10 and 20 bins <a href="#6">[1]</a>, I set <code class="language-plaintext highlighter-rouge">num_bins=15</code>. More bins reduce the bias, but increase the variance (<a href="https://en.wikipedia.org/wiki/Bias%E2%80%93variance_tradeoff">bias-variance tradeoff</a>).</p>
<p>If you have TensorFlow Probability installed, you can also use the following function (which produces the same results):</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="nn">tensorflow_probability</span> <span class="k">as</span> <span class="n">tfp</span>
<span class="n">tfp</span><span class="p">.</span><span class="n">stats</span><span class="p">.</span><span class="n">expected_calibration_error</span><span class="p">(</span><span class="n">num_bins</span><span class="o">=</span><span class="mi">15</span><span class="p">,</span> <span class="n">labels_true</span><span class="o">=</span><span class="n">gt</span><span class="p">,</span> <span class="n">logits</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">pred</span><span class="p">))</span></code></pre></figure>
<p>Note if <code class="language-plaintext highlighter-rouge">pred</code> are logits, then <code class="language-plaintext highlighter-rouge">np.log</code> is not necessary.</p>
<p>There are a few problems with the standard ECE. <code class="language-plaintext highlighter-rouge">np.linspace</code> will create evenly spaced bins, which are likely to be empty. In statistics, bins are often chosen so that each bin contains an equal number of probability outcomes <a href="#6">[2]</a>. This is called <em>Adaptive Calibration Error</em> (ACE) in <a href="#6">[1]</a>.</p>
<p>One can change the variable <code class="language-plaintext highlighter-rouge">b</code> in <code class="language-plaintext highlighter-rouge">expected_calibration_error</code> to obtain ACE.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">b</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="n">start</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">stop</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">num</span><span class="o">=</span><span class="n">num_bins</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">quantile</span><span class="p">(</span><span class="n">prob_y</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">unique</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>
<span class="n">num_bins</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">b</span><span class="p">)</span></code></pre></figure>
<p>However, the adaptivity can also cause the number of bins to decrease. At the start of a neural network I trained, there were \(15\) bins. After 10 epochs the number of bins reduced to \(11\). The sigmoid function tends to over-emphasize probabilities near \(1\) or \(0\). For example, one training run produced the bins \(\{0.4786461, 0.99776319, 0.99977307, \dots, 0.99995485, 0.99999988, 1., 1.\}\).</p>
<p>It is also important to note that only the highest probability is considered for ECE/ACE i.e. <code class="language-plaintext highlighter-rouge">pred_y = np.argmax(y_pred, axis=-1)</code>. <a href="#6">[2]</a> proposes <em>Static Calibration Error</em> (SCE) which bins the predictions separately for each class probability. This should be considered, when all probabilities in a multi-class setting are equally important.</p>
\[\begin{aligned}\text{SCE}(B) &= \frac{1}{NC}\sum_{c=0}^{C-1}\sum_{b \in B}\left\lvert\sum_{\hat{p_i} \in b} \mathbf{1}\left(y_i = c\right) - \hat{p_i}\right\rvert\end{aligned}\]
<p>The implementation is:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">static_calibration_error</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">,</span> <span class="n">num_bins</span><span class="o">=</span><span class="mi">15</span><span class="p">):</span>
<span class="n">classes</span> <span class="o">=</span> <span class="n">y_pred</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="n">o</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">cur_class</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">classes</span><span class="p">):</span>
<span class="n">correct</span> <span class="o">=</span> <span class="p">(</span><span class="n">cur_class</span> <span class="o">==</span> <span class="n">y_true</span><span class="p">).</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">prob_y</span> <span class="o">=</span> <span class="n">y_pred</span><span class="p">[...,</span> <span class="n">cur_class</span><span class="p">]</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="n">start</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">stop</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">num</span><span class="o">=</span><span class="n">num_bins</span><span class="p">)</span>
<span class="n">bins</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">digitize</span><span class="p">(</span><span class="n">prob_y</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">b</span><span class="p">,</span> <span class="n">right</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_bins</span><span class="p">):</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">bins</span> <span class="o">==</span> <span class="n">b</span>
<span class="k">if</span> <span class="n">np</span><span class="p">.</span><span class="nb">any</span><span class="p">(</span><span class="n">mask</span><span class="p">):</span>
<span class="n">o</span> <span class="o">+=</span> <span class="n">np</span><span class="p">.</span><span class="nb">abs</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">correct</span><span class="p">[</span><span class="n">mask</span><span class="p">]</span> <span class="o">-</span> <span class="n">prob_y</span><span class="p">[</span><span class="n">mask</span><span class="p">]))</span>
<span class="k">return</span> <span class="n">o</span> <span class="o">/</span> <span class="p">(</span><span class="n">y_pred</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">classes</span><span class="p">)</span></code></pre></figure>
<p>If there are a lot of classes, adaptive SCE will assign too many bins to predictions close to 0% (e.g. 999 classes \(\approx 0.01\), 1 class \(\approx 0.99\)). ECE does not have the same problem, because it only evaluates the class with the highest probability. <a href="#6">[1]</a> suggests thresholding the predictions in this case (e.g. \(10^{-3}\)). Change the code as follows:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="p">...</span>
<span class="n">prob_y</span> <span class="o">=</span> <span class="n">y_pred</span><span class="p">[...,</span> <span class="n">cur_class</span><span class="p">]</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">prob_y</span> <span class="o">></span> <span class="n">threshold</span>
<span class="n">correct</span> <span class="o">=</span> <span class="n">correct</span><span class="p">[</span><span class="n">mask</span><span class="p">]</span>
<span class="n">prob_y</span> <span class="o">=</span> <span class="n">prob_y</span><span class="p">[</span><span class="n">mask</span><span class="p">]</span>
<span class="p">...</span>
<span class="n">o</span> <span class="o">+=</span> <span class="n">np</span><span class="p">.</span><span class="nb">abs</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">correct</span><span class="p">[</span><span class="n">mask</span><span class="p">]</span> <span class="o">-</span> <span class="n">prob_y</span><span class="p">[</span><span class="n">mask</span><span class="p">]))</span> <span class="o">/</span> <span class="n">prob_y</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="p">...</span>
<span class="k">return</span> <span class="n">o</span> <span class="o">/</span> <span class="n">classes</span></code></pre></figure>
<p>Some other things to keep in mind are:</p>
<ul>
<li>optimizing ECE: using <code class="language-plaintext highlighter-rouge">scipy.optimize</code> it is possible to directly optimize this non-differentiable metric (see the sketch after this list). However, according to <a href="#6">[1]</a> “ECE is very strongly influenced by measures of calibration error that adhere to its own properties, rather than capturing a more general concept of the calibration error.”</li>
<li>norm: most papers use the \(L_1\) norm, but \(L_2\) is also an option.</li>
</ul>
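<p>A minimal sketch of the direct optimization mentioned in the first list item: tune a temperature so that the ECE computed on <code class="language-plaintext highlighter-rouge">y_pred_test</code> (logits) becomes as small as possible. The bounds are an arbitrary choice here.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python">from scipy.optimize import minimize_scalar
from scipy.special import softmax

def ece_for_temperature(t):
    probs = softmax(y_pred_test / t, axis=-1)
    return expected_calibration_error(y_test, probs, num_bins=15)

res = minimize_scalar(ece_for_temperature, bounds=(0.05, 10.0), method="bounded")
print(res.x, res.fun)  # temperature and the corresponding ECE</code></pre></figure>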
<h3 id="reliability-diagram">Reliability diagram</h3>
<p>The x-axis is <code class="language-plaintext highlighter-rouge">np.sum(prob_y[mask]) / count</code> (confidence or avg predicted frequency) and the y-axis is <code class="language-plaintext highlighter-rouge">np.sum(correct[mask]) / count</code> (accuracy). It is important to note that the traditional reliability diagram plots the “observed relative frequency” <code class="language-plaintext highlighter-rouge">np.sum(y_true[mask]) / count</code> on the y-axis and NOT the accuracy. I would also recommend using the “observed relative frequency” as this is the standard approach.</p>
<p>First, we change the function <code class="language-plaintext highlighter-rouge">expected_calibration_error</code> to return both values. Then the following function will produce a reliability diagram:</p>
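<p>A minimal sketch of the changed function: the binning stays the same, but instead of summing the differences, the per-bin confidence and accuracy are collected and returned.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python">def expected_calibration_error(y_true, y_pred, num_bins=15):
    pred_y = np.argmax(y_pred, axis=-1)
    correct = (pred_y == y_true).astype(np.float32)
    prob_y = np.max(y_pred, axis=-1)

    edges = np.linspace(start=0, stop=1.0, num=num_bins)
    bins = np.digitize(prob_y, bins=edges, right=True)

    confidences, accuracies = [], []
    for b in range(num_bins):
        mask = bins == b
        if np.any(mask):
            confidences.append(np.mean(prob_y[mask]))  # x-axis
            accuracies.append(np.mean(correct[mask]))  # y-axis
    return np.array(confidences), np.array(accuracies)</code></pre></figure>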
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">import</span> <span class="nn">seaborn</span> <span class="k">as</span> <span class="n">sns</span>
<span class="n">sns</span><span class="p">.</span><span class="nb">set</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">reliability_diagram</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span>
<span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">expected_calibration_error</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="s">"k:"</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">"Perfectly calibrated"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="s">"s-"</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">"CNN"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">"Confidence"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">"Accuracy"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="s">"lower right"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">title</span><span class="p">(</span><span class="s">"Reliability diagram / calibration curve"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span></code></pre></figure>
<p>The reliability diagram itself looks like this:</p>
<figure class="figure text-center" style="width: 75%;">
<img src="/assets/images/reliability_diagram.jpg" class="figure-img img-fluid rounded" alt="..." />
<figcaption class="figure-caption">Reliability diagram of some CNN</figcaption>
</figure>
<h2 id="recommendations">Recommendations</h2>
<p>Using an unsuitable metric can lead to wrong conclusions. According to <a href="#6">[3]</a>, calibration metrics should not be used to compare different models. Expected calibration error is sensitive to the number of bins and the thresholding. Furthermore, it does not provide a consistent ranking of different models.</p>
<p>Instead, better metrics are the Brier score and the log likelihood, provided temperature scaling has been applied to the logits. ECE is more useful for measuring the calibration of a specific model.</p>
<h2 id="references">References</h2>
<p>[1] J. Nixon, M. Dusenberry et al., <em>Measuring Calibration in Deep Learning</em>, 2020.</p>
<p>[2] Hyukjun Gweon and Hao Yu, <em>How reliable is your reliability diagram?</em>, 2019.</p>
<p>[3] A. Ashuskha, A. Lyzhov, D. Molchanov et al., <em>Pitfalls of in-domain uncertainty estimation and ensembling in deep learning</em>, 2020.</p>Implementing Poincaré Embeddings in PyTorch2020-07-24T00:00:00+00:002020-07-24T00:00:00+00:00https://lars76.github.io/2020/07/24/implementing-poincare-embedding<p>After having introduced Riemannian SGD in the last blog post, here I will give a concrete application for this optimization method. Poincaré embeddings <a href="#5">[1]</a><a href="#5">[2]</a> are hierarchical word embeddings which map integer-encoded words to the hyperbolic space.</p>
<!--more-->
<p>Even though the original paper used the Poincaré unit ball, any reasonable manifold can work. For example, when the data is not really a tree, the Euclidean space can produce better embeddings.</p>
<p>Poincaré Embeddings consist of the following components:</p>
<ul>
<li>dataset: \(\{(w_1, v_1), \dots, (w_n, v_n)\}\) where \(w_i\) is the parent and \(v_i\) is the child (e.g. hypernym and hyponym)</li>
<li>distance function: the Poincaré ball needs \(d(x, y) = \text{arcosh}\left(1 + 2\frac{\lvert\lvert x - y\rvert\rvert^2}{(1 - \lvert\lvert x\rvert\rvert^2)(1 - \lvert\lvert y\rvert\rvert^2)}\right)\), while the Euclidean space uses \(d(x, y) = \lVert x - y\rVert_2^2\)</li>
<li>loss function: \(-\log\left(\text{softmax}\left(-d(x, y)\right)\right)\) (“categorical crossentropy”)</li>
<li>weights: a single \(m \times n\) matrix where \(m\) is the input vocabulary and \(n\) is the output vocabulary (projection)</li>
<li>metric: words that are neighbors should be closer together in a ranking than words that have no connection (“mean rank”, see the sketch after this list).</li>
</ul>
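<p>A rough sketch of the mean rank metric mentioned in the list above, using the <code class="language-plaintext highlighter-rouge">Model</code> class defined below. For every pair, the true child is ranked against all other words by distance to the parent; the original evaluation additionally excludes the other true neighbors of the parent from the ranking, which is omitted here.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python">import torch

def mean_rank(model, pairs):
    # pairs: LongTensor of shape (num_pairs, 2) with (parent, child) indices
    ranks = []
    with torch.no_grad():
        all_words = model.embedding.weight
        for parent, child in pairs:
            parent_vec = all_words[parent].expand_as(all_words)
            dists = model.dist(parent_vec, all_words)
            # rank of the true child among all words (1 = closest)
            rank = (dists < dists[child]).sum().item() + 1
            ranks.append(rank)
    return sum(ranks) / len(ranks)</code></pre></figure>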
<h2 id="implementation">Implementation</h2>
<h3 id="model">Model</h3>
<p>First, we create the weights using the function <code class="language-plaintext highlighter-rouge">Embedding</code>. Then they are initialized close to \(0\). Since the Poincaré ball requires \(\lvert\lvert x\rvert\rvert < 1\), this won’t cause any trouble.</p>
<p>During forward propagation the input is split into two parts: parent (0 to 1) and children (1 to n). Next, we compute the distance between all nodes.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">init_weights</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">,</span> <span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-7</span><span class="p">):</span>
<span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="p">.</span><span class="n">embedding</span> <span class="o">=</span> <span class="n">Embedding</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">sparse</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="bp">self</span><span class="p">.</span><span class="n">embedding</span><span class="p">.</span><span class="n">weight</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">uniform_</span><span class="p">(</span><span class="o">-</span><span class="n">init_weights</span><span class="p">,</span> <span class="n">init_weights</span><span class="p">)</span>
<span class="bp">self</span><span class="p">.</span><span class="n">epsilon</span> <span class="o">=</span> <span class="n">epsilon</span>
<span class="k">def</span> <span class="nf">dist</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">u</span><span class="p">,</span> <span class="n">v</span><span class="p">):</span>
<span class="n">sqdist</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nb">sum</span><span class="p">((</span><span class="n">u</span> <span class="o">-</span> <span class="n">v</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">squnorm</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">u</span> <span class="o">**</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">sqvnorm</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">v</span> <span class="o">**</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">sqdist</span> <span class="o">/</span> <span class="p">((</span><span class="mi">1</span> <span class="o">-</span> <span class="n">squnorm</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">sqvnorm</span><span class="p">))</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">epsilon</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="n">z</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span>
<span class="n">e</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">embedding</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span>
<span class="n">o</span> <span class="o">=</span> <span class="n">e</span><span class="p">.</span><span class="n">narrow</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">start</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">length</span><span class="o">=</span><span class="n">e</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">s</span> <span class="o">=</span> <span class="n">e</span><span class="p">.</span><span class="n">narrow</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">start</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">length</span><span class="o">=</span><span class="mi">1</span><span class="p">).</span><span class="n">expand_as</span><span class="p">(</span><span class="n">o</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">dist</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="n">o</span><span class="p">)</span></code></pre></figure>
<p>The line <code class="language-plaintext highlighter-rouge">x = 1 + 2 * sqdist / ((1 - squnorm) * (1 - sqvnorm))</code> can become numerically unstable: due to rounding, <code class="language-plaintext highlighter-rouge">x</code> may fall slightly below \(1\), in which case <code class="language-plaintext highlighter-rouge">torch.sqrt(x ** 2 - 1)</code> returns NaN. Adding <code class="language-plaintext highlighter-rouge">epsilon</code> guards against this. When using double-precision floating-point, <code class="language-plaintext highlighter-rouge">epsilon</code> can often be set to \(0\).</p>
<h3 id="training">Training</h3>
<p>First, set <code class="language-plaintext highlighter-rouge">torch.set_default_dtype(torch.float64)</code>. This is not strictly necessary, but gives slightly better results and makes the network more stable.</p>
<p>Next, we need two distributions:</p>
<ul>
<li>categorical distribution: during the first \(20\) epochs, words/relations that occur more often are drawn with a higher frequency.</li>
<li>uniform distribution: every word/relation has the same chance of being drawn.</li>
</ul>
<p>It is not mentioned in the original paper, but the official code follows this approach.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">cat_dist</span> <span class="o">=</span> <span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">weights</span><span class="p">))</span>
<span class="n">unif_dist</span> <span class="o">=</span> <span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">ones</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">names</span><span class="p">),)</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">names</span><span class="p">))</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">Model</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="n">DIMENSIONS</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">names</span><span class="p">))</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">RiemannianSGD</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">())</span>
<span class="n">loss_func</span> <span class="o">=</span> <span class="n">CrossEntropyLoss</span><span class="p">()</span>
<span class="n">batch_X</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">NEG_SAMPLES</span> <span class="o">+</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="nb">long</span><span class="p">)</span>
<span class="n">batch_y</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="nb">long</span><span class="p">)</span>
<span class="k">while</span> <span class="bp">True</span><span class="p">:</span>
<span class="k">if</span> <span class="n">epoch</span> <span class="o"><</span> <span class="mi">20</span><span class="p">:</span>
<span class="n">lr</span> <span class="o">=</span> <span class="mf">0.003</span>
<span class="n">sampler</span> <span class="o">=</span> <span class="n">cat_dist</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">lr</span> <span class="o">=</span> <span class="mf">0.3</span>
<span class="n">sampler</span> <span class="o">=</span> <span class="n">unif_dist</span>
<span class="n">perm</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randperm</span><span class="p">(</span><span class="n">dataset</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
<span class="n">dataset_rnd</span> <span class="o">=</span> <span class="n">dataset</span><span class="p">[</span><span class="n">perm</span><span class="p">]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dataset</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">-</span> <span class="n">dataset</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">%</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">)):</span>
<span class="n">batch_X</span><span class="p">[:,:</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">dataset_rnd</span><span class="p">[</span><span class="n">i</span> <span class="p">:</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">10</span><span class="p">]</span>
<span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">10</span><span class="p">):</span>
<span class="n">a</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">sampler</span><span class="p">.</span><span class="n">sample</span><span class="p">([</span><span class="mi">2</span> <span class="o">*</span> <span class="n">NEG_SAMPLES</span><span class="p">]).</span><span class="n">numpy</span><span class="p">())</span>
<span class="n">negatives</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">a</span> <span class="o">-</span> <span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">neighbors</span><span class="p">[</span><span class="n">batch_X</span><span class="p">[</span><span class="n">j</span><span class="p">,</span> <span class="mi">0</span><span class="p">]])</span> <span class="o">|</span> <span class="nb">set</span><span class="p">(</span><span class="n">neighbors</span><span class="p">[</span><span class="n">batch_X</span><span class="p">[</span><span class="n">j</span><span class="p">,</span> <span class="mi">1</span><span class="p">]])))</span>
<span class="n">batch_X</span><span class="p">[</span><span class="n">j</span><span class="p">,</span> <span class="mi">2</span> <span class="p">:</span> <span class="nb">len</span><span class="p">(</span><span class="n">negatives</span><span class="p">)</span><span class="o">+</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">negatives</span><span class="p">[:</span><span class="n">NEG_SAMPLES</span><span class="p">])</span>
<span class="n">optimizer</span><span class="p">.</span><span class="n">zero_grad</span><span class="p">()</span>
<span class="n">preds</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">batch_X</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">loss_func</span><span class="p">(</span><span class="n">preds</span><span class="p">.</span><span class="n">neg</span><span class="p">(),</span> <span class="n">batch_y</span><span class="p">)</span>
<span class="n">loss</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>
<span class="n">optimizer</span><span class="p">.</span><span class="n">step</span><span class="p">(</span><span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">)</span></code></pre></figure>
<p>Note that only the first two entries in each row of <code class="language-plaintext highlighter-rouge">batch_X</code> form a real relation. All other words are randomly drawn from either the categorical or the uniform distribution. The ground truth for <code class="language-plaintext highlighter-rouge">CrossEntropyLoss</code> is therefore always the first element, i.e. <code class="language-plaintext highlighter-rouge">batch_y = torch.zeros(10, dtype=torch.long)</code>. All other entries act as negative samples.</p>
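<p>To make the role of <code class="language-plaintext highlighter-rouge">preds.neg()</code> concrete, here is a toy example (the shapes are assumptions, not the exact output of the model above): the model returns one distance per candidate, column \(0\) being the true relation, and negating turns the smallest distance into the largest logit, which is what <code class="language-plaintext highlighter-rouge">CrossEntropyLoss</code> rewards for the target class.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python">import torch

# toy example: one row, one true relation (index 0) and two negatives
preds = torch.tensor([[0.2, 1.5, 2.0]])      # small distance = strongly related
batch_y = torch.zeros(1, dtype=torch.long)   # the positive is always class index 0
loss = torch.nn.CrossEntropyLoss()(preds.neg(), batch_y)
print(loss)  # small, because the smallest distance sits at the target index</code></pre></figure>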
<h3 id="visualizations">Visualizations</h3>
<p>After having trained the neural network, we can visualize our embeddings.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="s">"poincare_model_dim_2.pt"</span><span class="p">)</span>
<span class="n">coordinates</span> <span class="o">=</span> <span class="n">model</span><span class="p">[</span><span class="s">"state_dict"</span><span class="p">][</span><span class="s">"embedding.weight"</span><span class="p">].</span><span class="n">numpy</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xlim</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylim</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">axis</span><span class="p">(</span><span class="s">'off'</span><span class="p">)</span>
<span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">coordinates</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
<span class="n">plt</span><span class="p">.</span><span class="n">annotate</span><span class="p">(</span><span class="n">model</span><span class="p">[</span><span class="s">"names"</span><span class="p">][</span><span class="n">x</span><span class="p">],</span> <span class="p">(</span><span class="n">coordinates</span><span class="p">[</span><span class="n">x</span><span class="p">,</span><span class="mi">0</span><span class="p">],</span> <span class="n">coordinates</span><span class="p">[</span><span class="n">x</span><span class="p">,</span><span class="mi">1</span><span class="p">]),</span>
<span class="n">bbox</span><span class="o">=</span><span class="p">{</span><span class="s">"fc"</span><span class="p">:</span><span class="s">"white"</span><span class="p">,</span> <span class="s">"alpha"</span><span class="p">:</span><span class="mf">0.9</span><span class="p">})</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span></code></pre></figure>
<p>If you made no mistakes, the result for WordNet mammals should look like this:</p>
<figure class="figure text-center">
<img src="/assets/images/plot3.png" class="figure-img img-fluid rounded" alt="..." />
<figcaption class="figure-caption">WordNet mammals</figcaption>
</figure>
<p>WordNet mammals has a low Gromov hyperbolicity, i.e. the dataset is almost tree-like.</p>
<p>I also tried OpenThesaurus, a dataset with a rather high Gromov hyperbolicity. It is a German dataset, but even if you do not understand the language, you should see the difference in the structure of the embedding.</p>
<figure class="figure text-center">
<img src="/assets/images/hyperbolic.png" class="figure-img img-fluid rounded" alt="..." />
<figcaption class="figure-caption">OpenThesaurus</figcaption>
</figure>
<h2 id="applications">Applications</h2>
<p>Besides visualizations, various uses can be found for these embeddings:</p>
<ul>
<li>clustering with <a href="https://github.com/drewwilimitis/hyperbolic-learning">hyperbolic k-means</a></li>
<li>finding wrong relations in graphs <a href="#5">[3]</a></li>
<li>hypernym-hyponym detection <a href="#5">[4]</a></li>
<li>inside another network for other applications</li>
<li>…</li>
</ul>
<p>However, if you need word similarity or analogy, there are better word embeddings. For example, SGNS (skip-gram with negative sampling) produces quite good results for these tasks.</p>
<h2 id="references">References</h2>
<p>[1] M. Nickel and D. Kiela, <em>Poincaré Embeddings for Learning Hierarchical Representations</em>, 2017.</p>
<p>[2] M. Nickel and D. Kiela, <em>Learning Continuous Hierarchies in the Lorentz Model of Hyperbolic Geometry</em>, 2018.</p>
<p>[3] S. Roller, D. Kiela and M. Nickel, <em>Hearst Patterns Revisited: Automatic Hypernym Detection from Large Text Corpora</em>, 2018.</p>
<p>[4] M. Le, S. Roller, L. Papaxanthos, D. Kiela and M. Nickel, <em>Inferring Concept Hierarchies from Text Corpora via Hyperbolic Embeddings</em>, 2019.</p>
<h1 id="riemannian-sgd-in-pytorch">Riemannian SGD in PyTorch</h1>
<p><em>2020-07-23</em></p>
<p>A lot of recent papers use different spaces than the regular Euclidean space. This trend is sometimes called geometric deep learning. There is a growing interest particularly in the domain of word embeddings and graphs.</p>
<p>Since geometric neural networks perform optimization in a different space, it is not possible to simply apply stochastic gradient descent.</p>
<!--more-->
<p>The following two equations show what changes are necessary:</p>
\[\begin{aligned}
\text{SGD: } \theta_{t+1} &\gets \theta_t - \lambda \nabla \mathcal{L}\\
\text{RSGD: } \theta_{t+1} &\gets \exp_{\theta_t}\left(- \lambda \nabla_R \mathcal{L}\right)
\end{aligned}\]
<p>where \(\exp_{\theta_t} : \mathcal{T}_{\theta_t}M \to M\) is the exponential map. It maps a small change by the vector \(v \in \mathcal{T}_{\theta_t}M\) on a point of the manifold \(M\). \(\lambda\) is the learning rate.</p>
<p>\(\nabla_R\) is the Riemannian gradient, given by \(g_{\theta_t}^{-1} \nabla \mathcal{L}\) where \(g_{\theta_t}\) is the metric tensor. This gradient is also called the natural gradient. A derivation can be found in <a href="#3">[1]</a>.</p>
<p>In the Euclidean space there is one model:</p>
<ul>
<li>The exponential map is \(\exp_{\theta_t}\left(v\right) = \theta_t + v\) and the metric tensor \(g_{\theta_t}^{-1}\) is the identity matrix.</li>
</ul>
<p>In the hyperbolic space there are multiple models:</p>
<ul>
<li>
<p><strong>Poincaré unit ball:</strong> \(\exp_{\theta_t}(v) = \theta_t \oplus \tanh(\frac{\lambda_{\theta_t} \lvert\lvert v\rvert\rvert}{2})\frac{v}{\lvert\lvert v\rvert\rvert}\) where \(\oplus\) is the Möbius addition and \(\lambda_{\theta_t} = \frac{2}{1 - \lvert\lvert \theta_t\rvert\rvert^2}\) is the conformal factor. However, the approximation \(\exp_{\theta_t}\left(v\right) = \theta_t + v\) tends to work better and is faster. The metric tensor is \(g_{\theta_t} = \lambda_{\theta_t}^2 I_n = \left(\frac{2}{1 - \lvert\lvert \theta_t\rvert\rvert^2}\right)^2 I_n\) where \(I_n\) is the identity matrix.</p>
</li>
<li>
<p><strong>Hyperboloid:</strong> the metric tensor is \(g_{\theta_t} = \begin{bmatrix}-1 & 0 & \cdots & 0\\0 & 1 & \cdots & 0\\\vdots\\0 & 0 & \cdots & 1\end{bmatrix}\). The exponential map is \(\exp_{\theta_t}(v) = \text{cosh}(\lVert v\rVert_{\mathcal{L}})\theta_t + \text{sinh}(\lVert v\rVert_{\mathcal{L}})\frac{v}{\lVert v\rVert_{\mathcal{L}}}\) where \(\lVert v\rVert_{\mathcal{L}} = \sqrt{-v_0^2 + \sum_{i=1}^n v_i^2}\).</p>
</li>
<li>
<p>There are other models, but they are not as common. The models are mathematically equivalent, but one also has to consider the numerical stability. Furthermore, it is possible to scale the models to change the Gaussian curvature. For example, the conformal factor of the Poincaré ball would change to \(\lambda_{\theta_t} = \frac{2}{1 - \left(\frac{\lvert\lvert \theta_t\rvert\rvert}{r}\right)^2} = \frac{2r^2}{r^2 - \lvert\lvert \theta_t\rvert\rvert^2}\) (see <a href="https://math.stackexchange.com/questions/2882024/conformal-factor-between-euclidean-metric-and-metric-on-poincar%C3%A9-ball-of-arbitra">this</a>). The paper <a href="#3">[3]</a> gives a good introduction to hyperbolic geometry with neural networks.</p>
</li>
</ul>
<p>For elliptic geometry see the paper <a href="#3">[2]</a>.</p>
<h2 id="implementation">Implementation</h2>
<h3 id="poincaré-unit-ball">Poincaré unit ball</h3>
<p>The following code also contains the exact exponential map. I commented the relevant lines out because, empirically, the approximation produces slightly better results; the authors of <a href="#3">[4]</a> ran more extensive tests.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="o">@</span><span class="n">torch</span><span class="p">.</span><span class="n">jit</span><span class="p">.</span><span class="n">script</span>
<span class="k">def</span> <span class="nf">lambda_x</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">return</span> <span class="mi">2</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">torch</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">))</span>
<span class="o">@</span><span class="n">torch</span><span class="p">.</span><span class="n">jit</span><span class="p">.</span><span class="n">script</span>
<span class="k">def</span> <span class="nf">mobius_add</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="n">x2</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">y2</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">y</span> <span class="o">**</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">xy</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">x</span> <span class="o">*</span> <span class="n">y</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">num</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">xy</span> <span class="o">+</span> <span class="n">y2</span><span class="p">)</span> <span class="o">*</span> <span class="n">x</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">x2</span><span class="p">)</span> <span class="o">*</span> <span class="n">y</span>
<span class="n">denom</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">xy</span> <span class="o">+</span> <span class="n">x2</span> <span class="o">*</span> <span class="n">y2</span>
<span class="k">return</span> <span class="n">num</span> <span class="o">/</span> <span class="n">denom</span><span class="p">.</span><span class="n">clamp_min</span><span class="p">(</span><span class="mf">1e-15</span><span class="p">)</span>
<span class="o">@</span><span class="n">torch</span><span class="p">.</span><span class="n">jit</span><span class="p">.</span><span class="n">script</span>
<span class="k">def</span> <span class="nf">expm</span><span class="p">(</span><span class="n">p</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">u</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">return</span> <span class="n">p</span> <span class="o">+</span> <span class="n">u</span>
<span class="c1"># for exact exponential mapping
</span> <span class="c1">#norm = torch.sqrt(torch.sum(u ** 2, dim=-1, keepdim=True))
</span> <span class="c1">#return mobius_add(p, torch.tanh(0.5 * lambda_x(p) * norm) * u / norm.clamp_min(1e-15))
</span>
<span class="o">@</span><span class="n">torch</span><span class="p">.</span><span class="n">jit</span><span class="p">.</span><span class="n">script</span>
<span class="k">def</span> <span class="nf">grad</span><span class="p">(</span><span class="n">p</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="n">p_sqnorm</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">p</span><span class="p">.</span><span class="n">data</span> <span class="o">**</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="k">return</span> <span class="n">p</span><span class="p">.</span><span class="n">grad</span><span class="p">.</span><span class="n">data</span> <span class="o">*</span> <span class="p">((</span><span class="mi">1</span> <span class="o">-</span> <span class="n">p_sqnorm</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">/</span> <span class="mi">4</span><span class="p">).</span><span class="n">expand_as</span><span class="p">(</span><span class="n">p</span><span class="p">.</span><span class="n">grad</span><span class="p">.</span><span class="n">data</span><span class="p">)</span>
<span class="k">class</span> <span class="nc">RiemannianSGD</span><span class="p">(</span><span class="n">Optimizer</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">RiemannianSGD</span><span class="p">,</span> <span class="bp">self</span><span class="p">).</span><span class="n">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="p">{})</span>
<span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.3</span><span class="p">):</span>
<span class="k">for</span> <span class="n">group</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">.</span><span class="n">param_groups</span><span class="p">:</span>
<span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">group</span><span class="p">[</span><span class="s">'params'</span><span class="p">]:</span>
<span class="k">if</span> <span class="n">p</span><span class="p">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
<span class="k">continue</span>
<span class="n">d_p</span> <span class="o">=</span> <span class="n">grad</span><span class="p">(</span><span class="n">p</span><span class="p">)</span>
<span class="n">d_p</span><span class="p">.</span><span class="n">mul_</span><span class="p">(</span><span class="o">-</span><span class="n">lr</span><span class="p">)</span>
<span class="n">p</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">copy_</span><span class="p">(</span><span class="n">expm</span><span class="p">(</span><span class="n">p</span><span class="p">.</span><span class="n">data</span><span class="p">,</span> <span class="n">d_p</span><span class="p">))</span></code></pre></figure>
<h3 id="hyperboloid">Hyperboloid</h3>
<p>According to <a href="#3">[5]</a>, the hyperboloid / Lorentz model is more stable during training. However, my tests showed similar results. Sometimes the Poincaré unit ball was actually more stable.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">expm</span><span class="p">(</span><span class="n">p</span> <span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">u</span> <span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="n">ldv</span> <span class="o">=</span> <span class="n">lorentzian_inner_product</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="n">u</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">).</span><span class="n">clamp_</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mf">1e-15</span><span class="p">).</span><span class="n">sqrt_</span><span class="p">()</span>
<span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">cosh</span><span class="p">(</span><span class="n">ldv</span><span class="p">)</span> <span class="o">*</span> <span class="n">p</span> <span class="o">+</span> <span class="n">torch</span><span class="p">.</span><span class="n">sinh</span><span class="p">(</span><span class="n">ldv</span><span class="p">)</span> <span class="o">*</span> <span class="n">u</span> <span class="o">/</span> <span class="n">ldv</span>
<span class="k">def</span> <span class="nf">lorentzian_inner_product</span><span class="p">(</span><span class="n">u</span> <span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">v</span> <span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">False</span><span class="p">):</span>
<span class="n">uv</span> <span class="o">=</span> <span class="n">u</span> <span class="o">*</span> <span class="n">v</span>
<span class="n">uv</span><span class="p">.</span><span class="n">narrow</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">).</span><span class="n">mul_</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">uv</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="n">keepdim</span><span class="p">)</span>
<span class="o">@</span><span class="n">torch</span><span class="p">.</span><span class="n">jit</span><span class="p">.</span><span class="n">script</span>
<span class="k">def</span> <span class="nf">grad</span><span class="p">(</span><span class="n">p</span> <span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="n">d_p</span> <span class="o">=</span> <span class="n">p</span><span class="p">.</span><span class="n">grad</span>
<span class="n">d_p</span><span class="p">.</span><span class="n">narrow</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">).</span><span class="n">mul_</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">d_p</span>
<span class="k">def</span> <span class="nf">proj</span><span class="p">(</span><span class="n">p</span> <span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">d_p</span> <span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">return</span> <span class="n">d_p</span> <span class="o">+</span> <span class="n">lorentzian_inner_product</span><span class="p">(</span><span class="n">p</span><span class="p">.</span><span class="n">data</span><span class="p">,</span> <span class="n">d_p</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> <span class="o">*</span> <span class="n">p</span><span class="p">.</span><span class="n">data</span>
<span class="k">class</span> <span class="nc">RiemannianSGD</span><span class="p">(</span><span class="n">Optimizer</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">RiemannianSGD</span><span class="p">,</span> <span class="bp">self</span><span class="p">).</span><span class="n">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="p">{})</span>
<span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.3</span><span class="p">):</span>
<span class="k">for</span> <span class="n">group</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">.</span><span class="n">param_groups</span><span class="p">:</span>
<span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">group</span><span class="p">[</span><span class="s">'params'</span><span class="p">]:</span>
<span class="k">if</span> <span class="n">p</span><span class="p">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
<span class="k">continue</span>
<span class="n">d_p</span> <span class="o">=</span> <span class="n">grad</span><span class="p">(</span><span class="n">p</span><span class="p">)</span>
<span class="n">d_p</span> <span class="o">=</span> <span class="n">proj</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">d_p</span><span class="p">)</span>
<span class="n">d_p</span><span class="p">.</span><span class="n">mul_</span><span class="p">(</span><span class="o">-</span><span class="n">lr</span><span class="p">)</span>
<span class="n">p</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">copy_</span><span class="p">(</span><span class="n">expm</span><span class="p">(</span><span class="n">p</span><span class="p">.</span><span class="n">data</span><span class="p">,</span> <span class="n">d_p</span><span class="p">))</span></code></pre></figure>
<p>My preliminary tests with word embeddings showed the following disadvantages of the hyperboloid:</p>
<ol>
<li>It requires one more dimension due to the constraint \(x_0 = \sqrt{1 + \sum_{i=1}^{n} x_i^2}\). After training, one can get rid of the additional dimension by mapping to the Poincaré unit ball.</li>
<li>It is worse for visualizations, but it is again possible to map to the Poincaré unit ball (a sketch of this map follows the list).</li>
<li>The network weights have to satisfy the equality constraint \(x_0 = \sqrt{1 + \sum_{i=1}^{n} x_i^2}\) (see <code class="language-plaintext highlighter-rouge">d_p = proj(p, d_p)</code>). This is more difficult than an inequality constraint, because the equality constraint is always active. In comparison, the Poincaré unit ball is defined by \(\lVert x\rVert < 1\). As long as the learning rate is reasonable, points will never fall off the unit ball.</li>
</ol>
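<p>For points 1 and 2, the standard map from the hyperboloid to the Poincaré unit ball can be used. This is a sketch (the coordinate layout follows the convention above, with \(x_0\) as the first column):</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python">import torch

# maps hyperboloid coordinates (x_0, x_1, ..., x_n) to the Poincaré ball
def hyperboloid_to_poincare(x: torch.Tensor):
    return x[..., 1:] / (1 + x[..., 0:1])</code></pre></figure>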
<p>A random initialization of the weights will violate the equality constraint. Hence, before training one should set the first coordinate of each weight \(w\) to <code class="language-plaintext highlighter-rouge">torch.sqrt(1 + torch.sum((w.narrow(-1, 1, w.size(-1) - 1)) ** 2, dim=-1, keepdim=True))</code>.</p>
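<p>A small sketch of this initialization (the function name and the uniform range are assumptions):</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python">import torch

# w is the weight matrix of the embedding layer; column 0 holds the coordinate x_0
def init_on_hyperboloid(w: torch.Tensor):
    with torch.no_grad():
        w[:, 1:].uniform_(-1e-3, 1e-3)
        w[:, 0] = torch.sqrt(1 + torch.sum(w[:, 1:] ** 2, dim=-1))</code></pre></figure>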
<h2 id="references">References</h2>
<p>[1] Shun-ichi Amari, <em>Natural Gradient Works Efficiently in Learning</em>, 1998.</p>
<p>[2] Y. Meng, J. Huang, G. Wang, C. Zhang, H. Zhuang, L. Kaplan and J. Han, <em>Spherical Text Embedding</em>, 2019.</p>
<p>[3] O. Ganea, G. Bécigneul and T. Hofmann, <em>Hyperbolic Neural Networks</em>, 2018.</p>
<p>[4] G. Bécigneul and O.-E. Ganea, <em>Riemannian Adaptive Optimization Methods</em>, 2018.</p>
<p>[5] M. Nickel and D. Kiela, <em>Learning Continuous Hierarchies in the Lorentz Model of Hyperbolic Geometry</em>, 2018.</p>
<h1 id="computing-gromov-hyperbolicity">Computing Gromov Hyperbolicity</h1>
<p><em>2020-07-22</em></p>
<p>Gromov Hyperbolicity measures the “tree-likeness” of a dataset. This metric is an indicator of how well hierarchical embeddings such as Poincaré embeddings <a href="#2">[1]</a> would work on a dataset. Some papers which use this metric are <a href="#2">[2]</a> and <a href="#2">[3]</a>. A Gromov Hyperbolicity of approximately zero means a high tree-likeness.</p>
<!--more-->
<p>Here are the results of some NLP datasets:</p>
<ul>
<li>
<p>WordNet mammals (6540 relations): \(\delta_{avg} \approx 0.00178\) (tree)</p>
</li>
<li>
<p>WordNet nouns (743086 relations): \(\delta_{avg} \approx 4 \cdot 10^{-4}\) (tree)</p>
</li>
<li>
<p>OpenThesaurus nouns (321877 relations): \(\delta_{avg} \approx 0.307\) (not a tree)</p>
</li>
<li>
<p>Hyperlex (2228 relations): \(\delta_{avg} \approx 4 \cdot 10^{-4}\) (tree)</p>
</li>
</ul>
<p>There are different ways to define Gromov hyperbolicity. I will present two variants here: a sampled average \(\delta_{avg}\) and the exact worst case \(\delta_{worst}\). Most papers seem to use sampling, because it is much faster and the averaged value tends to be the more useful indicator in practice.</p>
\[\delta_{worst}(G) = \max_{x,y,u,v \in V} \frac{S_1 - S_2}{2}\]
<p>where \(S_1\) and \(S_2\) are the largest and second-largest of the three sums \(d(x, y) + d(u, v)\), \(d(x, u) + d(y, v)\) and \(d(x, v) + d(y, u)\).</p>
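<p>As a quick example (not from the original post): for four vertices \(a, b, c, d\) forming a cycle of length four, the three sums are \(d(a, b) + d(c, d) = 2\), \(d(a, c) + d(b, d) = 4\) and \(d(a, d) + d(b, c) = 2\), so this quadruple contributes \(\delta = (4 - 2)/2 = 1\). In a tree, the two largest sums are always equal (the four-point condition), so every quadruple gives \(\delta = 0\).</p>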
<h2 id="exact-gromov-hyperbolicity">Exact Gromov hyperbolicity</h2>
<p><a href="https://doc.sagemath.org/html/en/reference/graphs/sage/graphs/hyperbolicity.html">SageMath</a> includes already a fast algorithm for computing exact Gromov hyperbolicity. The input is an adjacency matrix. It is a good idea to use a sparse matrix if the dataset is too big.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="nn">sage.all</span>
<span class="kn">from</span> <span class="nn">sage.graphs.hyperbolicity</span> <span class="kn">import</span> <span class="n">hyperbolicity</span>
<span class="kn">from</span> <span class="nn">sage.graphs.graph</span> <span class="kn">import</span> <span class="n">Graph</span>
<span class="kn">from</span> <span class="nn">sage.matrix.constructor</span> <span class="kn">import</span> <span class="n">matrix</span>
<span class="k">def</span> <span class="nf">exact_hyperbolicity</span><span class="p">(</span><span class="n">adjacency_matrix</span><span class="p">):</span>
<span class="n">G</span> <span class="o">=</span> <span class="n">Graph</span><span class="p">(</span><span class="n">matrix</span><span class="p">(</span><span class="n">adjacency_matrix</span><span class="p">),</span> <span class="nb">format</span><span class="o">=</span><span class="s">'adjacency_matrix'</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">G</span><span class="p">.</span><span class="n">is_connected</span><span class="p">()</span>
<span class="n">h</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">hyperbolicity</span><span class="p">(</span><span class="n">G</span><span class="p">,</span> <span class="n">algorithm</span><span class="o">=</span><span class="s">'BCCM'</span><span class="p">)</span>
<span class="k">return</span> <span class="n">h</span></code></pre></figure>
<h2 id="sample-gromov-hyperbolicity">Sample Gromov hyperbolicity</h2>
<p>It is not hard to implement the formula from above. This time we will use the library <a href="https://networkx.github.io/">NetworkX</a>. One can also change the distance function \(d(u, v)\) if the dataset is already embedded in a space (see <a href="#2">[3]</a>).</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="nn">networkx</span> <span class="k">as</span> <span class="n">nx</span>
<span class="k">def</span> <span class="nf">sample_hyperbolicity</span><span class="p">(</span><span class="n">adjacency_matrix</span><span class="p">,</span> <span class="n">num_samples</span><span class="o">=</span><span class="mi">50000</span><span class="p">):</span>
<span class="n">G</span> <span class="o">=</span> <span class="n">nx</span><span class="p">.</span><span class="n">from_scipy_sparse_matrix</span><span class="p">(</span><span class="n">adjacency_matrix</span><span class="p">)</span>
<span class="n">hyps</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_samples</span><span class="p">):</span>
<span class="n">node_tuple</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">choice</span><span class="p">(</span><span class="n">G</span><span class="p">.</span><span class="n">nodes</span><span class="p">(),</span> <span class="mi">4</span><span class="p">,</span> <span class="n">replace</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="k">try</span><span class="p">:</span>
<span class="n">d01</span> <span class="o">=</span> <span class="n">nx</span><span class="p">.</span><span class="n">shortest_path_length</span><span class="p">(</span><span class="n">G</span><span class="p">,</span> <span class="n">source</span><span class="o">=</span><span class="n">node_tuple</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">target</span><span class="o">=</span><span class="n">node_tuple</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">weight</span><span class="o">=</span><span class="bp">None</span><span class="p">)</span>
<span class="n">d23</span> <span class="o">=</span> <span class="n">nx</span><span class="p">.</span><span class="n">shortest_path_length</span><span class="p">(</span><span class="n">G</span><span class="p">,</span> <span class="n">source</span><span class="o">=</span><span class="n">node_tuple</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">target</span><span class="o">=</span><span class="n">node_tuple</span><span class="p">[</span><span class="mi">3</span><span class="p">],</span> <span class="n">weight</span><span class="o">=</span><span class="bp">None</span><span class="p">)</span>
<span class="n">d02</span> <span class="o">=</span> <span class="n">nx</span><span class="p">.</span><span class="n">shortest_path_length</span><span class="p">(</span><span class="n">G</span><span class="p">,</span> <span class="n">source</span><span class="o">=</span><span class="n">node_tuple</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">target</span><span class="o">=</span><span class="n">node_tuple</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">weight</span><span class="o">=</span><span class="bp">None</span><span class="p">)</span>
<span class="n">d13</span> <span class="o">=</span> <span class="n">nx</span><span class="p">.</span><span class="n">shortest_path_length</span><span class="p">(</span><span class="n">G</span><span class="p">,</span> <span class="n">source</span><span class="o">=</span><span class="n">node_tuple</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">target</span><span class="o">=</span><span class="n">node_tuple</span><span class="p">[</span><span class="mi">3</span><span class="p">],</span> <span class="n">weight</span><span class="o">=</span><span class="bp">None</span><span class="p">)</span>
<span class="n">d03</span> <span class="o">=</span> <span class="n">nx</span><span class="p">.</span><span class="n">shortest_path_length</span><span class="p">(</span><span class="n">G</span><span class="p">,</span> <span class="n">source</span><span class="o">=</span><span class="n">node_tuple</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">target</span><span class="o">=</span><span class="n">node_tuple</span><span class="p">[</span><span class="mi">3</span><span class="p">],</span> <span class="n">weight</span><span class="o">=</span><span class="bp">None</span><span class="p">)</span>
<span class="n">d12</span> <span class="o">=</span> <span class="n">nx</span><span class="p">.</span><span class="n">shortest_path_length</span><span class="p">(</span><span class="n">G</span><span class="p">,</span> <span class="n">source</span><span class="o">=</span><span class="n">node_tuple</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">target</span><span class="o">=</span><span class="n">node_tuple</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">weight</span><span class="o">=</span><span class="bp">None</span><span class="p">)</span>
<span class="n">s</span> <span class="o">=</span> <span class="p">[</span><span class="n">d01</span> <span class="o">+</span> <span class="n">d23</span><span class="p">,</span> <span class="n">d02</span> <span class="o">+</span> <span class="n">d13</span><span class="p">,</span> <span class="n">d03</span> <span class="o">+</span> <span class="n">d12</span><span class="p">]</span>
<span class="n">s</span><span class="p">.</span><span class="n">sort</span><span class="p">()</span>
<span class="n">hyps</span><span class="p">.</span><span class="n">append</span><span class="p">((</span><span class="n">s</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">s</span><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">])</span> <span class="o">/</span> <span class="mi">2</span><span class="p">)</span>
<span class="k">except</span> <span class="nb">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
<span class="k">continue</span>
<span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="nb">max</span><span class="p">(</span><span class="n">hyps</span><span class="p">),</span> <span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">(</span><span class="n">hyps</span><span class="p">)</span></code></pre></figure>
<h2 id="references">References</h2>
<p>[1] M. Nickel and D. Kiela, <em>Poincaré Embeddings for Learning Hierarchical Representations</em>, 2017.</p>
<p>[2] I. Chami, Z. Ying, C. Ré and J. Leskovec, <em>Hyperbolic Graph Convolutional Neural Networks</em>, 2019.</p>
<p>[3] A. Tifrea, G. Bécigneul and O.-E. Ganea, <em>Poincaré GloVe: Hyperbolic Word Embeddings</em>, 2018.</p>
<h1 id="new-blog">New Blog</h1>
<p><em>2020-07-21</em></p>
<p>I decided to update my blog and replace <a href="https://github.com/mmistakes/minimal-mistakes">minimal-mistakes</a> by my own Jekyll theme. My goal was to increase the space for content and reduce the amount of personal information the reader sees. I took inspiration from the Bootstrap theme <a href="https://github.com/StartBootstrap/startbootstrap-clean-blog">Clean Blog</a>. Some of my older posts need updating, so I removed them for now.</p>
<!--more-->
<p>An advantage of not depending on external themes is that it is easier to add new features. For example, when a text is too long, it is good to have a table of contents. I created the table of contents in pure javascript (a bit “hackish” but it works). It should appear on any post where there are <code class="language-plaintext highlighter-rouge"><h2></code> and <code class="language-plaintext highlighter-rouge"><h3></code> tags (not on this one).</p>
<p>The new website was created using the following CSS/JS libraries (and of course <a href="https://jekyllrb.com/">Jekyll</a>):</p>
<ul>
<li>Bootstrap (general design)</li>
<li>medium-zoom (images)</li>
<li>KaTeX (math typesetting)</li>
<li>Font Awesome (icons)</li>
<li>jQuery (because some libraries still require it)</li>
</ul>
<p>Then I added my own CSS and JavaScript code. Building this website was surprisingly easy; it took me a few days at most.</p>
<h1 id="loss-functions-for-segmentation">Loss Functions For Segmentation</h1>
<p><em>2018-09-27</em></p>
<p>In this post, I will implement some of the most common loss functions for image segmentation in Keras/TensorFlow. I will only consider the case of two classes (i.e. binary).</p>
<!--more-->
<p><strong>01.09.2020</strong>: rewrote lots of parts, fixed mistakes, updated to TensorFlow 2.3</p>
<p><strong>16.08.2019</strong>: improved overlap measures, added CE+DL loss</p>
<h2 id="cross-entropy">Cross Entropy</h2>
<p>We have two probability distributions:</p>
<ol>
<li>The prediction can either be \(\mathbf{P}(\hat{Y} = 1) = \hat{p}\) or \(\mathbf{P}(\hat{Y} = 0) = 1 - \hat{p}\).</li>
<li>The ground truth can either be \(\mathbf{P}(Y = 1) = p\) or \(\mathbf{P}(Y = 0) = 1 - p\).</li>
</ol>
<p>The predictions are given by the logistic/sigmoid function \(\hat{p} = \frac{1}{1 + e^{-x}}\) and the ground truth is \(p \in \{0,1\}\).</p>
<p>Then cross entropy (CE) can be defined as follows:</p>
\[\text{CE}\left(p, \hat{p}\right) = -\left(p \log\left(\hat{p}\right) + (1-p) \log\left(1 - \hat{p}\right)\right)\]
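<p>For example, if the ground truth is \(p = 1\) and the network predicts \(\hat{p} = 0.9\), then \(\text{CE} = -\log(0.9) \approx 0.105\), whereas a confident wrong prediction of \(\hat{p} = 0.1\) gives \(\text{CE} = -\log(0.1) \approx 2.303\).</p>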
<p>In Keras, the loss function is <code class="language-plaintext highlighter-rouge">BinaryCrossentropy</code> and in TensorFlow, it is <code class="language-plaintext highlighter-rouge">sigmoid_cross_entropy_with_logits</code>. For multiple classes, it is <code class="language-plaintext highlighter-rouge">softmax_cross_entropy_with_logits_v2</code> and <code class="language-plaintext highlighter-rouge">CategoricalCrossentropy</code>/<code class="language-plaintext highlighter-rouge">SparseCategoricalCrossentropy</code>. For numerical stability, it is always better to use <code class="language-plaintext highlighter-rouge">BinaryCrossentropy</code> with <code class="language-plaintext highlighter-rouge">from_logits=True</code>. However, the model should then not end with a sigmoid or softmax activation layer (e.g. <code class="language-plaintext highlighter-rouge">tf.keras.layers.Activation("sigmoid")</code> or <code class="language-plaintext highlighter-rouge">tf.keras.layers.Softmax()</code>).</p>
<p>You can see in the <a href="https://github.com/tensorflow/tensorflow/blob/r2.3/tensorflow/python/keras/backend.py#L4797">original code</a> that TensorFlow sometimes tries to compute cross entropy from probabilities (when <code class="language-plaintext highlighter-rouge">from_logits=False</code>). Due to numerical instabilities <code class="language-plaintext highlighter-rouge">clip_by_value</code> becomes then necessary.</p>
<p>In this post, I will always assume that no sigmoid activation is applied inside the model (or only during prediction).</p>
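<p>As a minimal sketch (the layer and input shape are placeholders, not the model from this post), the network outputs raw logits and the sigmoid is only applied when predicting:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python">import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(1, 1, padding="same", input_shape=(128, 128, 3))  # logits, no activation
])
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer="adam")
# at prediction time: probabilities = tf.math.sigmoid(model(images))</code></pre></figure>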
<h3 id="weighted-cross-entropy">Weighted cross entropy</h3>
<p>Weighted cross entropy (WCE) is a variant of CE where all positive examples get weighted by some coefficient. It is used in the case of class imbalance. In segmentation, it is often not necessary. However, it can be beneficial when the training of the neural network is unstable. In classification, it is mostly used for multiple classes. This is why TensorFlow has no function <code class="language-plaintext highlighter-rouge">tf.nn.weighted_binary_entropy_with_logits</code>. There is only <code class="language-plaintext highlighter-rouge">tf.nn.weighted_cross_entropy_with_logits</code>.</p>
<p>WCE can be defined as follows:</p>
\[\text{WCE}\left(p, \hat{p}\right) = -\left(\beta p \log\left(\hat{p}\right) + (1-p) \log\left(1 - \hat{p}\right)\right)\]
<p>To decrease the number of false negatives, set \(\beta > 1\). To decrease the number of false positives, set \(\beta < 1\).</p>
<p>The implementation looks as follows</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">weighted_cross_entropy</span><span class="p">(</span><span class="n">beta</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">loss</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span>
<span class="n">weight_a</span> <span class="o">=</span> <span class="n">beta</span> <span class="o">*</span> <span class="n">tf</span><span class="p">.</span><span class="n">cast</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">weight_b</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">tf</span><span class="p">.</span><span class="n">cast</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">o</span> <span class="o">=</span> <span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">log1p</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">tf</span><span class="p">.</span><span class="nb">abs</span><span class="p">(</span><span class="n">y_pred</span><span class="p">)))</span> <span class="o">+</span> <span class="n">tf</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">relu</span><span class="p">(</span><span class="o">-</span><span class="n">y_pred</span><span class="p">))</span> <span class="o">*</span> <span class="p">(</span><span class="n">weight_a</span> <span class="o">+</span> <span class="n">weight_b</span><span class="p">)</span> <span class="o">+</span> <span class="n">y_pred</span> <span class="o">*</span> <span class="n">weight_b</span>
<span class="k">return</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">o</span><span class="p">)</span>
<span class="k">return</span> <span class="n">loss</span></code></pre></figure>
<p>Loss functions can be set when compiling the model (Keras):</p>
<p><code class="language-plaintext highlighter-rouge">model.compile(loss=weighted_cross_entropy(beta=beta), optimizer=optimizer, metrics=metrics)</code></p>
<p>If you are wondering why there is a ReLU function, this follows from simplifications. I derive the formula in the section on focal loss.</p>
<p>The result of a loss function is always a scalar. Some deep learning libraries will automatically apply <code class="language-plaintext highlighter-rouge">reduce_mean</code> or <code class="language-plaintext highlighter-rouge">reduce_sum</code> if you don’t do it. When combining different loss functions, sometimes the <code class="language-plaintext highlighter-rouge">axis</code> argument of <code class="language-plaintext highlighter-rouge">reduce_mean</code> can become important. Since TensorFlow 2.0, the class <code class="language-plaintext highlighter-rouge">BinaryCrossentropy</code> has the argument <code class="language-plaintext highlighter-rouge">reduction=losses_utils.ReductionV2.AUTO</code>.</p>
<h3 id="balanced-cross-entropy">Balanced cross entropy</h3>
<p>Balanced cross entropy (BCE) is similar to WCE. The only difference is that we weight also the negative examples.</p>
<p>BCE can be defined as follows:</p>
\[\text{BCE}\left(p, \hat{p}\right) = -\left(\beta p \log\left(\hat{p}\right) + (1 - \beta)(1-p) \log\left(1 - \hat{p}\right)\right)\]
<p>It can be implemented as follows:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">balanced_cross_entropy</span><span class="p">(</span><span class="n">beta</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">loss</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span>
<span class="n">weight_a</span> <span class="o">=</span> <span class="n">beta</span> <span class="o">*</span> <span class="n">tf</span><span class="p">.</span><span class="n">cast</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">weight_b</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">beta</span><span class="p">)</span> <span class="o">*</span> <span class="n">tf</span><span class="p">.</span><span class="n">cast</span><span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">y_true</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">o</span> <span class="o">=</span> <span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">log1p</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">tf</span><span class="p">.</span><span class="nb">abs</span><span class="p">(</span><span class="n">y_pred</span><span class="p">)))</span> <span class="o">+</span> <span class="n">tf</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">relu</span><span class="p">(</span><span class="o">-</span><span class="n">y_pred</span><span class="p">))</span> <span class="o">*</span> <span class="p">(</span><span class="n">weight_a</span> <span class="o">+</span> <span class="n">weight_b</span><span class="p">)</span> <span class="o">+</span> <span class="n">y_pred</span> <span class="o">*</span> <span class="n">weight_b</span>
<span class="k">return</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">o</span><span class="p">)</span>
<span class="k">return</span> <span class="n">loss</span></code></pre></figure>
<p>Instead of using a fixed value like <code class="language-plaintext highlighter-rouge">beta = 0.3</code>, it is also possible to dynamically adjust the value of <code class="language-plaintext highlighter-rouge">beta</code>. For example, the paper <a href="#10">[1]</a> uses: <code class="language-plaintext highlighter-rouge">beta = tf.reduce_mean(1 - y_true)</code></p>
<h3 id="focal-loss">Focal loss</h3>
<p>Focal loss (FL) <a href="#10">[2]</a> tries to down-weight the contribution of easy examples so that the CNN focuses more on hard examples.</p>
<p>FL can be defined as follows:</p>
\[\text{FL}\left(p, \hat{p}\right) = -\left(\alpha (1 - \hat{p})^{\gamma} p \log\left(\hat{p}\right) + (1 - \alpha) \hat{p}^{\gamma} (1-p) \log\left(1 - \hat{p}\right)\right)\]
<p>When \(\gamma = 0\), we obtain BCE.</p>
<p>There are a lot of simplifications possible when implementing FL. TensorFlow uses the same simplifications for <code class="language-plaintext highlighter-rouge">sigmoid_cross_entropy_with_logits</code> (see the <a href="https://github.com/tensorflow/tensorflow/blob/926c08624849abda617b5e0330b33d94365c08dc/tensorflow/python/ops/nn_impl.py#L115">original code</a>)</p>
\[\begin{aligned}
\text{FL}\left(p, \hat{p}\right) &= \alpha(1 - \hat{p})^{\gamma} p \log\left(1 + e^{-x}\right) - \left(1 - \alpha\right)\hat{p}^{\gamma}(1-p) \log\left(\frac{e^{-x}}{1 + e^{-x}}\right)\\
&= \alpha(1 - \hat{p})^{\gamma}p \log\left(1 + e^{-x}\right) - \left(1 - \alpha\right)\hat{p}^{\gamma}\left(1-p\right)\left(-x - \log\left(1 + e^{-x}\right)\right)\\
&= \alpha(1 - \hat{p})^{\gamma}p \log\left(1 + e^{-x}\right) + \left(1 - \alpha\right)\hat{p}^{\gamma}\left(1-p\right)\left(x + \log\left(1 + e^{-x}\right)\right)\\
&= \log\left(1 + e^{-x}\right)\left(\alpha (1 - \hat{p})^{\gamma} p + (1-\alpha)\hat{p}^{\gamma}(1-p)\right) + x(1 - \alpha)\hat{p}^{\gamma}(1 - p)\\
&= \log\left(e^{-x}(1 + e^{x})\right)\left(\alpha (1 - \hat{p})^{\gamma} p + (1-\alpha)\hat{p}^{\gamma}(1-p)\right) + x(1 - \alpha)\hat{p}^{\gamma}(1 - p)\\
&= \left(\log\left(1 + e^{x}\right) - x\right)\left(\alpha (1 - \hat{p})^{\gamma} p + (1-\alpha)\hat{p}^{\gamma}(1-p)\right) + x(1 - \alpha)\hat{p}^{\gamma}(1 - p)\\
&= \left(\log\left(1 + e^{-|x|}\right) + \max(-x, 0)\right)\left(\alpha (1 - \hat{p})^{\gamma} p + (1-\alpha)\hat{p}^{\gamma}(1-p)\right) + x(1 - \alpha)\hat{p}^{\gamma}(1 - p)\\
\end{aligned}\]
<p>And the implementation is then:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">focal_loss</span><span class="p">(</span><span class="n">alpha</span><span class="o">=</span><span class="mf">0.25</span><span class="p">,</span> <span class="n">gamma</span><span class="o">=</span><span class="mi">2</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">focal_loss_with_logits</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">alpha</span><span class="p">,</span> <span class="n">gamma</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span>
<span class="n">targets</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">cast</span><span class="p">(</span><span class="n">targets</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">weight_a</span> <span class="o">=</span> <span class="n">alpha</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">y_pred</span><span class="p">)</span> <span class="o">**</span> <span class="n">gamma</span> <span class="o">*</span> <span class="n">targets</span>
<span class="n">weight_b</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">alpha</span><span class="p">)</span> <span class="o">*</span> <span class="n">y_pred</span> <span class="o">**</span> <span class="n">gamma</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">targets</span><span class="p">)</span>
<span class="k">return</span> <span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">log1p</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">tf</span><span class="p">.</span><span class="nb">abs</span><span class="p">(</span><span class="n">logits</span><span class="p">)))</span> <span class="o">+</span> <span class="n">tf</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">relu</span><span class="p">(</span><span class="o">-</span><span class="n">logits</span><span class="p">))</span> <span class="o">*</span> <span class="p">(</span><span class="n">weight_a</span> <span class="o">+</span> <span class="n">weight_b</span><span class="p">)</span> <span class="o">+</span> <span class="n">logits</span> <span class="o">*</span> <span class="n">weight_b</span>
<span class="k">def</span> <span class="nf">loss</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">logits</span><span class="p">):</span>
<span class="n">y_pred</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">logits</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">focal_loss_with_logits</span><span class="p">(</span><span class="n">logits</span><span class="o">=</span><span class="n">logits</span><span class="p">,</span> <span class="n">targets</span><span class="o">=</span><span class="n">y_true</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="n">alpha</span><span class="p">,</span> <span class="n">gamma</span><span class="o">=</span><span class="n">gamma</span><span class="p">,</span> <span class="n">y_pred</span><span class="o">=</span><span class="n">y_pred</span><span class="p">)</span>
<span class="k">return</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span>
<span class="k">return</span> <span class="n">loss</span></code></pre></figure>
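<p>As a quick sanity check of the claim above (a sketch, assuming TensorFlow 2 and the <code class="language-plaintext highlighter-rouge">focal_loss</code> defined in this section), setting \(\gamma = 0\) and \(\alpha = \frac{1}{2}\) should give exactly half of TensorFlow's built-in cross entropy:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python">import tensorflow as tf

logits = tf.constant([[2.0, -1.0], [0.5, -3.0]])
labels = tf.constant([[1.0, 0.0], [0.0, 1.0]])

# gamma = 0 and alpha = 0.5 should reduce FL to 0.5 * BCE
fl = focal_loss(alpha=0.5, gamma=0)(labels, logits)
bce = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))
print(fl.numpy(), (0.5 * bce).numpy())  # both printed values should match</code></pre></figure>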
<h3 id="distance-to-the-nearest-cell">Distance to the nearest cell</h3>
<p>The paper <a href="#10">[3]</a> adds a distance-based weight to the cross entropy to force the CNN to learn the separation border between touching objects. In other words, this is BCE with an additional distance-dependent weighting term:</p>
\[\text{DNC}\left(p, \hat{p}\right) = -\left(w(p) p \log\left(\hat{p}\right) + w(p)(1-p) \log\left(1 - \hat{p}\right)\right)\]
<p>where</p>
\[w(p) = w_c(p) + w_0\cdot\exp\left(-\frac{(d_1(p) + d_2(p))^2}{2\sigma^2}\right)\]
<p>\(d_1(p)\) and \(d_2(p)\) are the distances from pixel \(p\) to the nearest and to the second nearest cell, and \(w_c(p) = \beta\) or \(w_c(p) = 1 - \beta\) depending on the class of the pixel. If we had multiple classes, then \(w_c(p)\) would return a different \(\beta_i\) depending on the class \(i\). The values \(w_0\), \(\sigma\) and \(\beta\) are constants, i.e. hyperparameters of the loss function.</p>
<p>Calculating the exponential term inside the loss function would slow down the training considerably. Hence, it is better to precompute the distance map and pass it to the neural network together with the image input.</p>
<p>The following code is a simplified variation that only calculates \(d_1\), the distance to the nearest object.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">from</span> <span class="nn">scipy.spatial</span> <span class="kn">import</span> <span class="n">distance_matrix</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="p">...</span>
<span class="n">not_zeros</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">argwhere</span><span class="p">(</span><span class="n">img</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">zeros</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">argwhere</span><span class="p">(</span><span class="n">img</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">dist_matrix</span> <span class="o">=</span> <span class="n">distance_matrix</span><span class="p">(</span><span class="n">zeros</span><span class="p">,</span> <span class="n">not_zeros</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">HEIGHT</span><span class="p">,</span> <span class="n">WIDTH</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">uint8</span><span class="p">)</span>
<span class="n">i</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">dist</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nb">min</span><span class="p">(</span><span class="n">dist_matrix</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="k">for</span> <span class="n">y</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">HEIGHT</span><span class="p">):</span>
<span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">WIDTH</span><span class="p">):</span>
<span class="k">if</span> <span class="n">img</span><span class="p">[</span><span class="n">y</span><span class="p">,</span><span class="n">x</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">output</span><span class="p">[</span><span class="n">y</span><span class="p">,</span><span class="n">x</span><span class="p">]</span> <span class="o">=</span> <span class="n">dist</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">i</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="p">...</span></code></pre></figure>
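<p>The snippet above only computes the distance to the nearest object. Below is a sketch of the full weight map \(w(p)\) from the formula above; it assumes SciPy and a labelled mask (0 = background, 1..N = object ids), and the function name <code class="language-plaintext highlighter-rouge">unet_weight_map</code> as well as the default values for <code class="language-plaintext highlighter-rouge">w0</code>, <code class="language-plaintext highlighter-rouge">sigma</code> and <code class="language-plaintext highlighter-rouge">beta</code> are my own choices:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python">import numpy as np
from scipy import ndimage

def unet_weight_map(labels, w0=10.0, sigma=5.0, beta=0.5):
    ids = [i for i in np.unique(labels) if i != 0]
    # class weights: here objects get beta and background 1 - beta (a choice)
    w_c = np.where(labels > 0, beta, 1 - beta).astype(np.float32)
    if len(ids) < 2:
        # d2 is undefined with fewer than two objects, keep only the class weights
        return w_c
    # distance from every pixel to the nearest pixel of each object
    dists = np.stack([ndimage.distance_transform_edt(labels != i) for i in ids], axis=-1)
    dists.sort(axis=-1)
    d1, d2 = dists[..., 0], dists[..., 1]
    return w_c + w0 * np.exp(-((d1 + d2) ** 2) / (2 * sigma ** 2))</code></pre></figure>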
<p>For example, on the left is a mask and on the right is the corresponding weight map.</p>
<figure class="figure text-center">
<img src="/assets/images/mask.png" class="figure-img img-fluid rounded" alt="..." />
<figcaption class="figure-caption">Mask and weight map in comparison</figcaption>
</figure>
<p>The darker the pixel, the higher the weight of the exponential term. To pass the weight matrix as an input, one could use:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
<span class="k">def</span> <span class="nf">loss_function</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">,</span> <span class="n">weights</span><span class="p">):</span>
<span class="p">...</span>
<span class="n">weight_input</span> <span class="o">=</span> <span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">HEIGHT</span><span class="p">,</span> <span class="n">WIDTH</span><span class="p">))</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">loss_function</span><span class="p">,</span> <span class="n">weights</span><span class="o">=</span><span class="n">weight_input</span><span class="p">)</span></code></pre></figure>
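<p>The <code class="language-plaintext highlighter-rouge">partial</code> trick closes over a symbolic tensor, which can be brittle with newer TensorFlow versions. An alternative (not from the original post) is to make the mask and the weight map additional model inputs and register the weighted cross entropy with <code class="language-plaintext highlighter-rouge">add_loss</code>. A minimal sketch, assuming TensorFlow 2 and a placeholder network in place of the real U-Net:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python">import tensorflow as tf

HEIGHT, WIDTH = 128, 128  # placeholder image size

image_input = tf.keras.Input(shape=(HEIGHT, WIDTH, 1))
mask_input = tf.keras.Input(shape=(HEIGHT, WIDTH, 1))
weight_input = tf.keras.Input(shape=(HEIGHT, WIDTH, 1))

# stand-in for the real segmentation network (e.g. a U-Net) producing logits
logits = tf.keras.layers.Conv2D(1, 3, padding="same")(image_input)

# per-pixel cross entropy scaled by the precomputed weight map w(p)
ce = tf.nn.sigmoid_cross_entropy_with_logits(labels=mask_input, logits=logits)

model = tf.keras.Model([image_input, mask_input, weight_input], logits)
model.add_loss(tf.reduce_mean(weight_input * ce))
model.compile(optimizer="adam")
# model.fit([images, masks, weight_maps], epochs=10)</code></pre></figure>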
<h2 id="overlap-measures">Overlap measures</h2>
<h3 id="dice-loss--f1-score">Dice Loss / F1 score</h3>
<p>The Dice coefficient is similar to the Jaccard Index (Intersection over Union, IoU):</p>
\[\text{DC} = \frac{2 TP}{2 TP + FP + FN} = \frac{2|X \cap Y|}{|X| + |Y|}\]
\[\text{IoU} = \frac{TP}{TP + FP + FN} = \frac{|X \cap Y|}{|X| + |Y| - |X \cap Y|}\]
<p>where TP are the true positives, FP false positives and FN false negatives. We can see that \(\text{DC} \geq \text{IoU}\).</p>
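<p>For example, with \(TP = 8\), \(FP = 2\) and \(FN = 2\), we get \(\text{DC} = \frac{16}{20} = 0.8\) but \(\text{IoU} = \frac{8}{12} \approx 0.67\), so indeed \(\text{DC} \geq \text{IoU}\).</p>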
<p>The dice coefficient can also be defined as a loss function:</p>
\[\text{DL}\left(p, \hat{p}\right) = 1 - \frac{2\sum p_{h,w}\hat{p}_{h,w}}{\sum p_{h,w} + \sum \hat{p}_{h,w}}\]
<p>where \(p_{h,w} \in \{0,1\}\) and \(0 \leq \hat{p}_{h,w} \leq 1\).</p>
<p>The code is then</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">dice_loss</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span>
<span class="n">y_true</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">cast</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">y_pred</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">y_pred</span><span class="p">)</span>
<span class="n">numerator</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">y_true</span> <span class="o">*</span> <span class="n">y_pred</span><span class="p">)</span>
<span class="n">denominator</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">y_true</span> <span class="o">+</span> <span class="n">y_pred</span><span class="p">)</span>
<span class="k">return</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">numerator</span> <span class="o">/</span> <span class="n">denominator</span></code></pre></figure>
<p>In general, Dice loss works better when it is applied to whole images (or batches) than to single pixels. For this reason, the per-pixel form \(1 - \frac{2p\hat{p}}{p + \hat{p}}\) is practically never used for segmentation.</p>
<h3 id="tversky-loss">Tversky loss</h3>
<p>The Tversky index (TI) is a generalization of the Dice coefficient: it adds separate weights to FP (false positives) and FN (false negatives). The corresponding Tversky loss (TL) is:</p>
\[\text{TL}\left(p, \hat{p}\right) = 1 - \frac{p\hat{p}}{p\hat{p} + \beta(1 - p)\hat{p} + (1 - \beta)p(1 - \hat{p})}\]
<p>Let \(\beta = \frac{1}{2}\). Then</p>
\[\begin{aligned}
&= 1 - \frac{2 p\hat{p}}{2p\hat{p} + (1 - p)\hat{p} + p (1 - \hat{p})}\\
&= 1 - \frac{2 p\hat{p}}{\hat{p} + p}
\end{aligned}\]
<p>which is just the regular Dice loss. Since we are interested in sets of pixels rather than in single pixels, the following implementation sums over all pixels <a href="#10">[5]</a>:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">tversky_loss</span><span class="p">(</span><span class="n">beta</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">loss</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span>
<span class="n">y_true</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">cast</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">y_pred</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">y_pred</span><span class="p">)</span>
<span class="n">numerator</span> <span class="o">=</span> <span class="n">y_true</span> <span class="o">*</span> <span class="n">y_pred</span>
<span class="n">denominator</span> <span class="o">=</span> <span class="n">y_true</span> <span class="o">*</span> <span class="n">y_pred</span> <span class="o">+</span> <span class="n">beta</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">y_true</span><span class="p">)</span> <span class="o">*</span> <span class="n">y_pred</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">beta</span><span class="p">)</span> <span class="o">*</span> <span class="n">y_true</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">y_pred</span><span class="p">)</span>
<span class="k">return</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">numerator</span><span class="p">)</span> <span class="o">/</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">denominator</span><span class="p">)</span>
<span class="k">return</span> <span class="n">loss</span></code></pre></figure>
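<p>As with the other losses, the returned function can be passed directly to Keras; the value of <code class="language-plaintext highlighter-rouge">beta</code> below is only an example (values below 0.5 put more weight on false negatives and therefore favour recall):</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python">model.compile(loss=tversky_loss(beta=0.3), optimizer="adam")</code></pre></figure>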
<h3 id="lovász-softmax">Lovász-Softmax</h3>
<p>DL and TL simply relax the hard constraint \(\hat{p} \in \{0,1\}\) on the predictions in order to obtain a function on the domain \([0, 1]\). The paper <a href="#10">[6]</a> instead derives a surrogate loss function.</p>
<p>An implementation of Lovász-Softmax can be found on <a href="https://github.com/bermanmaxim/LovaszSoftmax/blob/master/tensorflow/lovasz_losses_tf.py">github</a>. Note that this loss does not rely on the sigmoid function but works directly on the logits (“hinge loss”): a negative logit means class A and a positive logit means class B.</p>
<p>In Keras the loss function can be used as follows:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">lovasz_softmax</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span>
<span class="k">return</span> <span class="n">lovasz_hinge</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="n">y_true</span><span class="p">,</span> <span class="n">logits</span><span class="o">=</span><span class="n">y_pred</span><span class="p">)</span>
<span class="n">model</span><span class="p">.</span><span class="nb">compile</span><span class="p">(</span><span class="n">loss</span><span class="o">=</span><span class="n">lovasz_softmax</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="n">pixel_iou</span><span class="p">])</span></code></pre></figure>
<h2 id="combinations">Combinations</h2>
<p>It is also possible to combine multiple loss functions. The following function is quite popular in data competitions:</p>
\[\text{CE}\left(p, \hat{p}\right) + \text{DL}\left(p, \hat{p}\right)\]
<p>Note that \(\text{CE}\) returns a per-pixel tensor, while \(\text{DL}\) returns a single scalar (computed here over the whole batch). This way we combine local (\(\text{CE}\)) with global (\(\text{DL}\)) information.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">loss</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">dice_loss</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span>
<span class="n">y_pred</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">y_pred</span><span class="p">)</span>
<span class="n">numerator</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">y_true</span> <span class="o">*</span> <span class="n">y_pred</span><span class="p">)</span>
<span class="n">denominator</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">y_true</span> <span class="o">+</span> <span class="n">y_pred</span><span class="p">)</span>
<span class="k">return</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">numerator</span> <span class="o">/</span> <span class="n">denominator</span>
<span class="n">y_true</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">cast</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">o</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">sigmoid_cross_entropy_with_logits</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">)</span> <span class="o">+</span> <span class="n">dice_loss</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">)</span>
<span class="k">return</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">o</span><span class="p">)</span></code></pre></figure>
<p>Some people additionally apply the logarithm function to <code class="language-plaintext highlighter-rouge">dice_loss</code>.</p>
<p><strong>Example:</strong> Let \(\mathbf{P}\) be our real image, \(\mathbf{\hat{P}}\) the prediction and \(\mathbf{L}\) the result of the loss function.</p>
\[\mathbf{P} = \begin{bmatrix}1 & 1\\0 & 0\end{bmatrix}\]
\[\mathbf{\hat{P}} = \begin{bmatrix}0.5 & 0.6\\0.2 & 0.1\end{bmatrix}\]
<p>Then \(\mathbf{L} = \begin{bmatrix}-1\log(0.5) + l_2 & -1\log(0.6) + l_2\\-(1 - 0)\log(1 - 0.2) + l_2 & -(1 - 0)\log(1 - 0.1) + l_2\end{bmatrix}\), where</p>
\[l_2 = 1 - \frac{2(1 \cdot 0.5 + 1 \cdot 0.6 + 0 \cdot 0.2 + 0 \cdot 0.1)}{(1 + 1 + 0 + 0) + (0.5 + 0.6 + 0.2 + 0.1)} \approx 0.3529\]
<p>The result is:</p>
\[\mathbf{L} \approx \begin{bmatrix}0.6931 + 0.3529 & 0.5108 + 0.3529\\0.2231 + 0.3529 & 0.1054 + 0.3529\end{bmatrix} = \begin{bmatrix}1.046 & 0.8637\\0.576 & 0.4583\end{bmatrix}\]
<p>Next, we compute the mean via <code class="language-plaintext highlighter-rouge">tf.reduce_mean</code> which results in \(\frac{1}{4}(1.046 + 0.8637 + 0.576 + 0.4583) = 0.736\)</p>
<p>Let’s check the result:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">c</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">constant</span><span class="p">([[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">]])</span>
<span class="n">d</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">constant</span><span class="p">([[</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">]])</span>
<span class="k">print</span><span class="p">(</span><span class="n">loss</span><span class="p">(</span><span class="n">c</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">d</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">d</span><span class="p">))))</span>
<span class="c1"># tf.Tensor(0.7360604, shape=(), dtype=float32)</span></code></pre></figure>
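<p>As mentioned above, some people apply the logarithm to the Dice term. One way to read this (my interpretation, not taken from a specific source) is to subtract the log of the soft Dice coefficient instead of adding the Dice loss; a sketch, reusing the <code class="language-plaintext highlighter-rouge">dice_loss</code> from the Dice section and adding a small epsilon for numerical stability:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python">def bce_log_dice_loss(y_true, y_pred):
    y_true = tf.cast(y_true, tf.float32)
    bce = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(y_true, y_pred))
    # 1 - dice_loss(...) is the soft Dice coefficient
    dice_coefficient = 1 - dice_loss(y_true, y_pred)
    return bce - tf.math.log(dice_coefficient + 1e-7)</code></pre></figure>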
<h2 id="references">References</h2>
<p>[1] S. Xie and Z. Tu. <em>Holistically-Nested Edge Detection</em>, 2015.</p>
<p>[2] T.-Y. Lin, P. Goyal, R. Girshick, K. He, and P. Dollar. <em>Focal Loss for Dense Object Detection</em>, 2017.</p>
<p>[3] O. Ronneberger, P. Fischer, and T. Brox. <em>U-Net: Convolutional Networks for Biomedical Image Segmentation</em>, 2015.</p>
<p>[4] F. Milletari, N. Navab, and S.-A. Ahmadi. <em>V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation</em>, 2016.</p>
<p>[5] S. S. M. Salehi, D. Erdogmus, and A. Gholipour. <em>Tversky loss function for image segmentation using 3D fully convolutional deep networks</em>, 2017.</p>
<p>[6] M. Berman, A. R. Triki, M. B. Blaschko. <em>The Lovász-Softmax loss: A tractable surrogate for the optimization of the intersection-over-union measure in neural networks</em>, 2018.</p>In this post, I will implement some of the most common loss functions for image segmentation in Keras/TensorFlow. I will only consider the case of two classes (i.e. binary).Portuguese Lemmatizers (2020 update)2018-05-08T00:00:00+00:002018-05-08T00:00:00+00:00https://lars76.github.io/2018/05/08/portuguese-lemmatizers<p>In this post, I will compare some lemmatizers for Portuguese. In order to do the comparison, I downloaded subtitles from various television programs. The sentences are written in European Portuguese (EP).</p>
<!--more-->
<p><strong>01.09.2020: I have migrated the post from my old blog and updated it to reflect the current state of lemmatization.</strong></p>
<h2 id="rule-based">Rule-based</h2>
<h3 id="hunspell">Hunspell</h3>
<p>There exists a Python binding for Hunspell called “CyHunspell”. This library contains a function <em>stem</em> which can be used to get the root of a word. In contrast to Spacy, it is not possible to consider the context in which the word occurs. A dictionary can be downloaded <a href="http://natura.di.uminho.pt/download/TGZ/Dictionaries/hunspell/LATEST/">here</a>.</p>
<p>It is also necessary to use beforehand a tokenizer. If we don’t consider special cases like mesoclisis, it’s easy to write our own.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="nn">hunspell</span>
<span class="kn">import</span> <span class="nn">re</span>
<span class="k">def</span> <span class="nf">tokenize</span><span class="p">(</span><span class="n">sentence</span><span class="p">):</span>
<span class="n">tokens_regex</span> <span class="o">=</span> <span class="n">re</span><span class="p">.</span><span class="nb">compile</span><span class="p">(</span><span class="sa">r</span><span class="s">"([., :;\n()\"#!?1234567890/&%+])"</span><span class="p">,</span> <span class="n">flags</span><span class="o">=</span><span class="n">re</span><span class="p">.</span><span class="n">IGNORECASE</span><span class="p">)</span>
<span class="n">tokens</span> <span class="o">=</span> <span class="n">re</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">tokens_regex</span><span class="p">,</span> <span class="n">sentence</span><span class="p">)</span>
<span class="n">postprocess</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">postprocess_regex</span> <span class="o">=</span> <span class="n">re</span><span class="p">.</span><span class="nb">compile</span><span class="p">(</span><span class="sa">r</span><span class="s">"\b(\w+)-(me|te|se|nos|vos|o|os|a|as|lo|los|la|las|lhe|lhes|lha|lhas|lho|lhos|no|na|nas|mo|ma|mos|mas|to|ta|tos|tas)\b"</span><span class="p">,</span> <span class="n">flags</span><span class="o">=</span><span class="n">re</span><span class="p">.</span><span class="n">IGNORECASE</span><span class="p">)</span>
<span class="k">for</span> <span class="n">token</span> <span class="ow">in</span> <span class="n">tokens</span><span class="p">:</span>
<span class="k">for</span> <span class="n">token2</span> <span class="ow">in</span> <span class="n">re</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">postprocess_regex</span><span class="p">,</span> <span class="n">token</span><span class="p">):</span>
<span class="k">if</span> <span class="n">token2</span><span class="p">.</span><span class="n">strip</span><span class="p">():</span>
<span class="n">postprocess</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">token2</span><span class="p">)</span>
<span class="k">return</span> <span class="n">postprocess</span>
<span class="n">tokens</span> <span class="o">=</span> <span class="n">tokenize</span><span class="p">(</span><span class="s">"Estás bem ?"</span><span class="p">)</span>
<span class="n">h</span> <span class="o">=</span> <span class="n">hunspell</span><span class="p">.</span><span class="n">Hunspell</span><span class="p">(</span><span class="s">"pt_PT"</span><span class="p">,</span> <span class="n">hunspell_data_dir</span><span class="o">=</span><span class="s">"/usr/share/hunspell/"</span><span class="p">)</span>
<span class="n">text</span> <span class="o">=</span> <span class="s">""</span>
<span class="n">lemmas</span> <span class="o">=</span> <span class="s">""</span>
<span class="k">for</span> <span class="n">token</span> <span class="ow">in</span> <span class="n">tokens</span><span class="p">:</span>
<span class="n">text</span> <span class="o">+=</span> <span class="n">token</span> <span class="o">+</span> <span class="s">"</span><span class="se">\t</span><span class="s">"</span>
<span class="n">lemma</span> <span class="o">=</span> <span class="n">h</span><span class="p">.</span><span class="n">stem</span><span class="p">(</span><span class="n">token</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">lemma</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">lemmas</span> <span class="o">+=</span> <span class="n">lemma</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="s">"</span><span class="se">\t</span><span class="s">"</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">lemmas</span> <span class="o">+=</span> <span class="n">token</span> <span class="o">+</span> <span class="s">"</span><span class="se">\t</span><span class="s">"</span></code></pre></figure>
<p>The results are:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">Estás</span> <span class="n">bem</span> <span class="err">?</span>
<span class="n">estás</span> <span class="n">bem</span> <span class="err">?</span>
<span class="n">Está</span> <span class="n">bem</span> <span class="err">?</span>
<span class="n">está</span> <span class="n">bem</span> <span class="err">?</span>
<span class="n">Não</span> <span class="p">,</span> <span class="n">minha</span> <span class="n">miúda</span> <span class="n">no</span> <span class="n">sentido</span> <span class="n">que</span> <span class="n">és</span> <span class="n">como</span> <span class="n">uma</span> <span class="n">irmã</span> <span class="n">para</span> <span class="n">mim</span> <span class="p">.</span>
<span class="n">não</span> <span class="p">,</span> <span class="n">minha</span> <span class="n">miúdo</span> <span class="n">no</span> <span class="n">sentido</span> <span class="n">que</span> <span class="n">és</span> <span class="n">como</span> <span class="n">um</span> <span class="n">irmão</span> <span class="n">para</span> <span class="n">mim</span> <span class="p">.</span></code></pre></figure>
<p>Not every word gets assigned a lemma, because some tokens don’t seem to have entries in the dictionary.</p>
<p>Another problem is the context. The dictionary has for example two different stems for the word “sentido”: “sentir” and “sentido”. In the first case, it could be a verb conjugated in <em>pretérito perfeito composto</em> (tenho sentido etc.). In the second case, the word is a noun. Hence, we need a Part-of-Speech (POS) Tagger to decide which case is the right one.</p>
<h3 id="lemport">LemPORT</h3>
<p>This library is written in Java and requires an external tokenizer and POS Tagger.</p>
<figure class="highlight"><pre><code class="language-java" data-lang="java"><span class="kn">import</span> <span class="nn">lemma.Lemmatizer</span><span class="o">;</span>
<span class="kd">public</span> <span class="kd">class</span> <span class="nc">Main</span> <span class="o">{</span>
<span class="kd">public</span> <span class="kd">static</span> <span class="kt">void</span> <span class="nf">main</span><span class="o">(</span><span class="kd">final</span> <span class="nc">String</span><span class="o">[]</span> <span class="n">args</span><span class="o">)</span> <span class="o">{</span>
<span class="kd">final</span> <span class="nc">String</span><span class="o">[]</span> <span class="n">tokens</span> <span class="o">=</span> <span class="o">{</span><span class="s">"Estás"</span><span class="o">,</span> <span class="s">"bem"</span><span class="o">,</span> <span class="s">"?"</span><span class="o">};</span>
<span class="kd">final</span> <span class="nc">String</span><span class="o">[]</span> <span class="n">tags</span> <span class="o">=</span> <span class="o">{</span><span class="s">"v-fin"</span><span class="o">,</span> <span class="s">"adv"</span><span class="o">,</span> <span class="s">"punc"</span><span class="o">};</span>
<span class="kd">final</span> <span class="nc">Lemmatizer</span> <span class="n">lemmatizer</span><span class="o">;</span>
<span class="kd">final</span> <span class="nc">String</span><span class="o">[]</span> <span class="n">lemmas</span><span class="o">;</span>
<span class="k">try</span> <span class="o">{</span>
<span class="n">lemmatizer</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">Lemmatizer</span><span class="o">();</span>
<span class="n">lemmas</span> <span class="o">=</span> <span class="n">lemmatizer</span><span class="o">.</span><span class="na">lemmatize</span><span class="o">(</span><span class="n">tokens</span><span class="o">,</span> <span class="n">tags</span><span class="o">);</span>
<span class="o">}</span> <span class="k">catch</span> <span class="o">(</span><span class="nc">Exception</span> <span class="n">e</span><span class="o">)</span> <span class="o">{</span>
<span class="n">e</span><span class="o">.</span><span class="na">printStackTrace</span><span class="o">();</span>
<span class="k">return</span><span class="o">;</span>
<span class="o">}</span>
<span class="kd">final</span> <span class="nc">StringBuilder</span> <span class="n">token</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">StringBuilder</span><span class="o">();</span>
<span class="kd">final</span> <span class="nc">StringBuilder</span> <span class="n">lemma</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">StringBuilder</span><span class="o">();</span>
<span class="k">for</span> <span class="o">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="o">;</span> <span class="n">i</span> <span class="o"><</span> <span class="n">tokens</span><span class="o">.</span><span class="na">length</span><span class="o">;</span> <span class="n">i</span><span class="o">++)</span> <span class="o">{</span>
<span class="n">token</span><span class="o">.</span><span class="na">append</span><span class="o">(</span><span class="n">tokens</span><span class="o">[</span><span class="n">i</span><span class="o">]).</span><span class="na">append</span><span class="o">(</span><span class="s">"\t"</span><span class="o">);</span>
<span class="n">lemma</span><span class="o">.</span><span class="na">append</span><span class="o">(</span><span class="n">lemmas</span><span class="o">[</span><span class="n">i</span><span class="o">]).</span><span class="na">append</span><span class="o">(</span><span class="s">"\t"</span><span class="o">);</span>
<span class="o">}</span>
<span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="n">token</span><span class="o">);</span>
<span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="n">lemma</span><span class="o">);</span>
<span class="o">}</span>
<span class="o">}</span></code></pre></figure>
<p>When I used the right annotations, the lemmas were generated correctly. However, there is an issue with the size of the dictionary. Using the full dictionary “resources/acdc/lemas.total.txt” will result in a “java.lang.OutOfMemoryError: GC overhead limit exceeded” exception. One can either give Java more memory (e.g. via the <code class="language-plaintext highlighter-rouge">-Xmx</code> flag) or use a smaller dictionary to fix this.</p>
<h3 id="nltk">NLTK</h3>
<p>NLTK is one of the most popular libraries for NLP-related tasks. However, it does not contain a lemmatizer for Portuguese. There are only two stemmers: RSLPStemmer and snowball.</p>
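<p>For reference, stemming (not lemmatization) with the RSLP stemmer looks like this; a small sketch, assuming the <code class="language-plaintext highlighter-rouge">rslp</code> data package has been downloaded:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python">import nltk
from nltk.stem import RSLPStemmer

nltk.download("rslp")  # data for the Portuguese RSLP stemmer
stemmer = RSLPStemmer()
print(stemmer.stem("sentido"), stemmer.stem("miúda"))</code></pre></figure>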
<h2 id="neural-network-based">Neural network based</h2>
<h3 id="spacy">Spacy</h3>
<p>Spacy is a relatively new NLP library for Python. A language model for Portuguese can be downloaded <a href="https://github.com/explosion/spacy-models/releases">here</a>. This model was trained with a CNN on the Universal Dependencies and WikiNER corpus.</p>
<p>Let us try some sentences.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="nn">spacy</span>
<span class="n">nlp</span> <span class="o">=</span> <span class="n">spacy</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="s">"pt_core_news_lg"</span><span class="p">)</span>
<span class="n">text</span> <span class="o">=</span> <span class="s">""</span>
<span class="n">pos</span> <span class="o">=</span> <span class="s">""</span>
<span class="n">lemma</span> <span class="o">=</span> <span class="s">""</span>
<span class="k">for</span> <span class="n">token</span> <span class="ow">in</span> <span class="n">nlp</span><span class="p">(</span><span class="s">"Estás bem ?"</span><span class="p">):</span>
<span class="n">text</span> <span class="o">+=</span> <span class="n">token</span><span class="p">.</span><span class="n">text</span> <span class="o">+</span> <span class="s">"</span><span class="se">\t</span><span class="s">"</span>
<span class="n">pos</span> <span class="o">+=</span> <span class="n">token</span><span class="p">.</span><span class="n">pos_</span> <span class="o">+</span> <span class="s">"</span><span class="se">\t</span><span class="s">"</span>
<span class="n">lemma</span> <span class="o">+=</span> <span class="n">token</span><span class="p">.</span><span class="n">lemma_</span> <span class="o">+</span> <span class="s">"</span><span class="se">\t</span><span class="s">"</span></code></pre></figure>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">Estás</span> <span class="n">bem</span> <span class="err">?</span>
<span class="n">AUX</span> <span class="n">ADV</span> <span class="n">PUNCT</span>
<span class="n">Estás</span> <span class="n">bem</span> <span class="err">?</span> </code></pre></figure>
<p>There is a mistake with the word “Estás”. The lemma should be “estar”. Most Portuguese-speaking countries don’t use the second-person singular. Thus, the problem could be that the corpus does not contain enough texts written in EP.</p>
<p>To verify this, let us consider the sentences “Está bem ?” and “Você está bem ?”.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">Está</span> <span class="n">bem</span> <span class="err">?</span>
<span class="n">VERB</span> <span class="n">ADV</span> <span class="n">PUNCT</span>
<span class="n">Está</span> <span class="n">bem</span> <span class="err">?</span>
<span class="n">Você</span> <span class="n">está</span> <span class="n">bem</span> <span class="err">?</span>
<span class="n">PRON</span> <span class="n">VERB</span> <span class="n">ADV</span> <span class="n">PUNCT</span>
<span class="n">Você</span> <span class="n">estar</span> <span class="n">bem</span> <span class="err">?</span> </code></pre></figure>
<p>The library still doesn’t find the correct lemma. Only by explicitly adding a pronoun do we get the right result.</p>
<p>Maybe we have more success with longer sentences.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">Não</span> <span class="p">,</span> <span class="n">minha</span> <span class="n">miúda</span> <span class="n">no</span> <span class="n">sentido</span> <span class="n">que</span> <span class="n">és</span>
<span class="n">ADV</span> <span class="n">PUNCT</span> <span class="n">DET</span> <span class="n">NOUN</span> <span class="n">DET</span> <span class="n">NOUN</span> <span class="n">PRON</span> <span class="n">AUX</span>
<span class="n">Não</span> <span class="p">,</span> <span class="n">meu</span> <span class="n">miúdo</span> <span class="n">o</span> <span class="n">sentir</span> <span class="n">que</span> <span class="n">ser</span>
<span class="n">como</span> <span class="n">uma</span> <span class="n">irmã</span> <span class="n">para</span> <span class="n">mim</span> <span class="p">.</span>
<span class="n">ADP</span> <span class="n">DET</span> <span class="n">NOUN</span> <span class="n">ADP</span> <span class="n">PRON</span> <span class="n">PUNCT</span>
<span class="n">comer</span> <span class="n">umar</span> <span class="n">irmão</span> <span class="n">parir</span> <span class="n">mim</span> <span class="p">.</span> </code></pre></figure>
<p>The lemmas are a bit strange:</p>
<ul>
<li>“no” is a contraction of “em + o”</li>
<li>“como”: the lemma should not be “comer”</li>
<li>“uma” should not be “umar”</li>
<li>“para” should not be a verb</li>
</ul>
<p>Even if the lemmas were intended to be written in this way, they should at least be consistent. But Spacy sometimes assigns “para” as the lemma and not “parir” (for example in the sentence “Para mim estão boas !”).</p>
<h3 id="stanza">Stanza</h3>
<p>Stanza just came out this year (2020). Like Spacy, it is quite easy to use and also provides pretrained Portuguese models. If there are any problems installing the library, try installing directly from GitHub: <code class="language-plaintext highlighter-rouge">pip install git+https://github.com/stanfordnlp/stanza.git</code>.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="nn">stanza</span>
<span class="n">stanza</span><span class="p">.</span><span class="n">download</span><span class="p">(</span><span class="s">'pt'</span><span class="p">)</span>
<span class="n">nlp</span> <span class="o">=</span> <span class="n">stanza</span><span class="p">.</span><span class="n">Pipeline</span><span class="p">(</span><span class="s">'pt'</span><span class="p">)</span>
<span class="n">text</span> <span class="o">=</span> <span class="s">""</span>
<span class="n">pos</span> <span class="o">=</span> <span class="s">""</span>
<span class="n">lemma</span> <span class="o">=</span> <span class="s">""</span>
<span class="k">for</span> <span class="n">sent</span> <span class="ow">in</span> <span class="n">nlp</span><span class="p">(</span><span class="s">"Não, minha miúda no sentido que és como uma irmã para mim."</span><span class="p">).</span><span class="n">sentences</span><span class="p">:</span>
<span class="k">for</span> <span class="n">word</span> <span class="ow">in</span> <span class="n">sent</span><span class="p">.</span><span class="n">words</span><span class="p">:</span>
<span class="n">text</span> <span class="o">+=</span> <span class="n">word</span><span class="p">.</span><span class="n">text</span> <span class="o">+</span> <span class="s">"</span><span class="se">\t</span><span class="s">"</span>
<span class="n">pos</span> <span class="o">+=</span> <span class="n">word</span><span class="p">.</span><span class="n">upos</span> <span class="o">+</span> <span class="s">"</span><span class="se">\t</span><span class="s">"</span>
<span class="n">lemma</span> <span class="o">+=</span> <span class="n">word</span><span class="p">.</span><span class="n">lemma</span> <span class="o">+</span> <span class="s">"</span><span class="se">\t</span><span class="s">"</span>
<span class="k">print</span><span class="p">(</span><span class="n">text</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">pos</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">lemma</span><span class="p">)</span></code></pre></figure>
<p>The results are</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">Estás</span> <span class="n">bem</span> <span class="err">?</span>
<span class="n">AUX</span> <span class="n">ADV</span> <span class="n">PUNCT</span>
<span class="n">estar</span> <span class="n">bem</span> <span class="err">?</span></code></pre></figure>
<p>and</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">Não</span> <span class="p">,</span> <span class="n">minha</span> <span class="n">miúda</span> <span class="n">em</span> <span class="n">o</span> <span class="n">sentido</span> <span class="n">que</span>
<span class="n">ADV</span> <span class="n">PUNCT</span> <span class="n">DET</span> <span class="n">NOUN</span> <span class="n">ADP</span> <span class="n">DET</span> <span class="n">NOUN</span> <span class="n">PRON</span>
<span class="n">não</span> <span class="p">,</span> <span class="n">meu</span> <span class="n">miúda</span> <span class="n">em</span> <span class="n">o</span> <span class="n">sentido</span> <span class="n">que</span>
<span class="n">és</span> <span class="n">como</span> <span class="n">uma</span> <span class="n">irmã</span> <span class="n">para</span> <span class="n">mim</span> <span class="p">.</span>
<span class="n">AUX</span> <span class="n">ADP</span> <span class="n">DET</span> <span class="n">NOUN</span> <span class="n">ADP</span> <span class="n">PRON</span> <span class="n">PUNCT</span>
<span class="n">ser</span> <span class="n">como</span> <span class="n">um</span> <span class="n">irmã</span> <span class="n">para</span> <span class="n">eu</span> <span class="p">.</span> </code></pre></figure>
<p>The results are good, but still not perfect:</p>
<ul>
<li>the contraction <code class="language-plaintext highlighter-rouge">no</code> was split into <code class="language-plaintext highlighter-rouge">em</code> and <code class="language-plaintext highlighter-rouge">o</code>, which is the Universal Dependencies convention for contractions (see <code class="language-plaintext highlighter-rouge">word.text</code>)</li>
<li>miúda: should be miúdo</li>
<li>irmã: should be irmão</li>
</ul>
<h3 id="universal-lemmatizer">Universal Lemmatizer</h3>
<p>This library is a little more complex to install than Stanza and Spacy.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">git</span> <span class="n">clone</span> <span class="n">https</span><span class="p">:</span><span class="o">//</span><span class="n">github</span><span class="p">.</span><span class="n">com</span><span class="o">/</span><span class="n">TurkuNLP</span><span class="o">/</span><span class="n">Turku</span><span class="o">-</span><span class="n">neural</span><span class="o">-</span><span class="n">parser</span><span class="o">-</span><span class="n">pipeline</span><span class="p">.</span><span class="n">git</span>
<span class="n">cd</span> <span class="n">Turku</span><span class="o">-</span><span class="n">neural</span><span class="o">-</span><span class="n">parser</span><span class="o">-</span><span class="n">pipeline</span></code></pre></figure>
<p>Then start Docker with <code class="language-plaintext highlighter-rouge">sudo systemctl start docker</code> and run</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">docker</span> <span class="n">build</span> <span class="o">-</span><span class="n">t</span> <span class="s">"my_portuguese_parser"</span> <span class="o">--</span><span class="n">build</span><span class="o">-</span><span class="n">arg</span> <span class="n">models</span><span class="o">=</span><span class="n">pt_bosque</span> <span class="o">--</span><span class="n">build</span><span class="o">-</span><span class="n">arg</span> <span class="n">hardware</span><span class="o">=</span><span class="n">cpu</span> <span class="o">-</span><span class="n">f</span> <span class="n">Dockerfile</span><span class="o">-</span><span class="n">lang</span> <span class="p">.</span></code></pre></figure>
<p>Alternatively, instead of <code class="language-plaintext highlighter-rouge">pt_bosque</code>, there are also <code class="language-plaintext highlighter-rouge">pt_gsd</code> and <code class="language-plaintext highlighter-rouge">pt_pud</code>. I used <code class="language-plaintext highlighter-rouge">pt_bosque</code> because it contains both European (CETEMPúblico) and Brazilian (CETENFolha) texts.</p>
<p>Then we can feed texts to the docker image</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">echo</span> <span class="s">"Não, minha miúda no sentido que és como uma irmã para mim."</span> <span class="o">|</span> <span class="n">docker</span> <span class="n">run</span> <span class="o">-</span><span class="n">i</span> <span class="n">my_portuguese_parser</span> <span class="n">stream</span> <span class="n">pt_bosque</span> <span class="n">parse_plaintext</span></code></pre></figure>
<p>The result is in the CoNLL-U format. The library <code class="language-plaintext highlighter-rouge">pyconll</code> can be used for parsing the following output:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="c1"># newdoc
# newpar
# sent_id = 1
# text = Não, minha miúda no sentido que és como uma irmã para mim.
</span><span class="mi">1</span> <span class="n">Não</span> <span class="n">não</span> <span class="n">INTJ</span> <span class="n">_</span> <span class="n">_</span> <span class="mi">4</span> <span class="n">advmod</span> <span class="n">_</span> <span class="n">SpaceAfter</span><span class="o">=</span><span class="n">No</span>
<span class="mi">2</span> <span class="p">,</span> <span class="p">,</span> <span class="n">PUNCT</span> <span class="n">_</span> <span class="n">_</span> <span class="mi">1</span> <span class="n">punct</span> <span class="n">_</span> <span class="n">_</span>
<span class="mi">3</span> <span class="n">minha</span> <span class="n">meu</span> <span class="n">DET</span> <span class="n">_</span> <span class="n">Gender</span><span class="o">=</span><span class="n">Fem</span><span class="o">|</span><span class="n">Number</span><span class="o">=</span><span class="n">Sing</span><span class="o">|</span><span class="n">PronType</span><span class="o">=</span><span class="n">Prs</span> <span class="mi">4</span> <span class="n">det</span> <span class="n">_</span> <span class="n">_</span>
<span class="mi">4</span> <span class="n">miúda</span> <span class="n">miúda</span> <span class="n">NOUN</span> <span class="n">_</span> <span class="n">Gender</span><span class="o">=</span><span class="n">Fem</span><span class="o">|</span><span class="n">Number</span><span class="o">=</span><span class="n">Sing</span> <span class="mi">0</span> <span class="n">root</span> <span class="n">_</span> <span class="n">_</span>
<span class="mi">5</span><span class="o">-</span><span class="mi">6</span> <span class="n">no</span> <span class="n">_</span> <span class="n">_</span> <span class="n">_</span> <span class="n">_</span> <span class="n">_</span> <span class="n">_</span> <span class="n">_</span> <span class="n">_</span>
<span class="mi">5</span> <span class="n">em</span> <span class="n">em</span> <span class="n">ADP</span> <span class="n">_</span> <span class="n">_</span> <span class="mi">7</span> <span class="n">case</span> <span class="n">_</span> <span class="n">_</span>
<span class="mi">6</span> <span class="n">o</span> <span class="n">o</span> <span class="n">DET</span> <span class="n">_</span> <span class="n">Definite</span><span class="o">=</span><span class="n">Def</span><span class="o">|</span><span class="n">Gender</span><span class="o">=</span><span class="n">Masc</span><span class="o">|</span><span class="n">Number</span><span class="o">=</span><span class="n">Sing</span><span class="o">|</span><span class="n">PronType</span><span class="o">=</span><span class="n">Art</span> <span class="mi">7</span> <span class="n">det</span> <span class="n">_</span> <span class="n">_</span>
<span class="mi">7</span> <span class="n">sentido</span> <span class="n">sentido</span> <span class="n">NOUN</span> <span class="n">_</span> <span class="n">Gender</span><span class="o">=</span><span class="n">Masc</span><span class="o">|</span><span class="n">Number</span><span class="o">=</span><span class="n">Sing</span> <span class="mi">4</span> <span class="n">nmod</span> <span class="n">_</span> <span class="n">_</span>
<span class="mi">8</span> <span class="n">que</span> <span class="n">que</span> <span class="n">PRON</span> <span class="n">_</span> <span class="n">Gender</span><span class="o">=</span><span class="n">Masc</span><span class="o">|</span><span class="n">Number</span><span class="o">=</span><span class="n">Sing</span><span class="o">|</span><span class="n">PronType</span><span class="o">=</span><span class="n">Rel</span> <span class="mi">9</span> <span class="n">nsubj</span> <span class="n">_</span> <span class="n">_</span>
<span class="mi">9</span> <span class="n">és</span> <span class="n">ser</span> <span class="n">VERB</span> <span class="n">_</span> <span class="n">Mood</span><span class="o">=</span><span class="n">Ind</span><span class="o">|</span><span class="n">Number</span><span class="o">=</span><span class="n">Sing</span><span class="o">|</span><span class="n">Person</span><span class="o">=</span><span class="mi">1</span><span class="o">|</span><span class="n">Tense</span><span class="o">=</span><span class="n">Pres</span><span class="o">|</span><span class="n">VerbForm</span><span class="o">=</span><span class="n">Fin</span> <span class="mi">7</span> <span class="n">acl</span><span class="p">:</span><span class="n">relcl</span> <span class="n">_</span> <span class="n">_</span>
<span class="mi">10</span> <span class="n">como</span> <span class="n">como</span> <span class="n">ADP</span> <span class="n">_</span> <span class="n">_</span> <span class="mi">12</span> <span class="n">case</span> <span class="n">_</span> <span class="n">_</span>
<span class="mi">11</span> <span class="n">uma</span> <span class="n">um</span> <span class="n">DET</span> <span class="n">_</span> <span class="n">Definite</span><span class="o">=</span><span class="n">Ind</span><span class="o">|</span><span class="n">Gender</span><span class="o">=</span><span class="n">Fem</span><span class="o">|</span><span class="n">Number</span><span class="o">=</span><span class="n">Sing</span><span class="o">|</span><span class="n">PronType</span><span class="o">=</span><span class="n">Art</span> <span class="mi">12</span> <span class="n">det</span> <span class="n">_</span> <span class="n">_</span>
<span class="mi">12</span> <span class="n">irmã</span> <span class="n">irmã</span> <span class="n">NOUN</span> <span class="n">_</span> <span class="n">Gender</span><span class="o">=</span><span class="n">Fem</span><span class="o">|</span><span class="n">Number</span><span class="o">=</span><span class="n">Sing</span> <span class="mi">9</span> <span class="n">obl</span> <span class="n">_</span> <span class="n">_</span>
<span class="mi">13</span> <span class="n">para</span> <span class="n">para</span> <span class="n">ADP</span> <span class="n">_</span> <span class="n">_</span> <span class="mi">14</span> <span class="n">case</span> <span class="n">_</span> <span class="n">_</span>
<span class="mi">14</span> <span class="n">mim</span> <span class="n">eu</span> <span class="n">PRON</span> <span class="n">_</span> <span class="n">Gender</span><span class="o">=</span><span class="n">Unsp</span><span class="o">|</span><span class="n">Number</span><span class="o">=</span><span class="n">Sing</span><span class="o">|</span><span class="n">Person</span><span class="o">=</span><span class="mi">1</span><span class="o">|</span><span class="n">PronType</span><span class="o">=</span><span class="n">Prs</span> <span class="mi">12</span> <span class="n">nmod</span> <span class="n">_</span> <span class="n">SpaceAfter</span><span class="o">=</span><span class="n">No</span>
<span class="mi">15</span> <span class="p">.</span> <span class="p">.</span> <span class="n">PUNCT</span> <span class="n">_</span> <span class="n">_</span> <span class="mi">4</span> <span class="n">punct</span> <span class="n">_</span> <span class="n">SpacesAfter</span><span class="o">=</span>\<span class="n">n</span></code></pre></figure>
<p>But since Stanza also uses <code class="language-plaintext highlighter-rouge">pt_bosque</code>, the results are approximately the same. Still, the <a href="https://arxiv.org/pdf/1902.00972.pdf">original paper</a> shows slight improvements on almost all treebanks.</p>
<p>It is also possible to lemmatize entire texts by sending POST requests to the docker image. The following bash script splits a file <code class="language-plaintext highlighter-rouge">../feed.txt</code> in lines of 10000 and appends the output to a split/parsed.conllu file.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">mkdir</span> <span class="n">split</span>
<span class="n">cd</span> <span class="n">split</span>
<span class="n">split</span> <span class="o">-</span><span class="n">l</span> <span class="mi">10000</span> <span class="p">..</span><span class="o">/</span><span class="n">feed</span><span class="p">.</span><span class="n">txt</span>
<span class="n">cd</span> <span class="p">..</span>
<span class="k">for</span> <span class="n">filename</span> <span class="ow">in</span> <span class="n">split</span><span class="o">/*</span><span class="p">;</span> <span class="n">do</span>
<span class="n">echo</span> <span class="err">$</span><span class="n">filename</span>
<span class="k">if</span> <span class="p">[[</span> <span class="err">$</span><span class="n">filename</span> <span class="o">==</span> <span class="n">split</span><span class="o">/</span><span class="n">x</span><span class="o">*</span> <span class="p">]]</span>
<span class="n">then</span>
<span class="n">curl</span> <span class="o">--</span><span class="n">request</span> <span class="n">POST</span> <span class="o">--</span><span class="n">header</span> <span class="s">'Content-Type: text/plain; charset=utf-8'</span> <span class="o">--</span><span class="n">data</span><span class="o">-</span><span class="n">binary</span> <span class="o">@</span><span class="s">"$filename"</span> <span class="n">http</span><span class="p">:</span><span class="o">//</span><span class="n">localhost</span><span class="p">:</span><span class="mi">15000</span> <span class="o">>></span> <span class="s">"split/parsed.conllu"</span>
<span class="n">fi</span>
<span class="n">done</span></code></pre></figure>
<p>The splitting is necessary, when millions of sentences need to be lemmatized. This could be a bug or I might simply not have enough RAM.</p>
<p>Instead of using the library <code class="language-plaintext highlighter-rouge">pyconll</code>, manual parsing can be performed as follows:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">train</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="s">"split/parsed.conllu"</span><span class="p">,</span> <span class="s">"r"</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="n">k</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">sent</span> <span class="o">=</span> <span class="s">""</span>
<span class="n">pattern</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">line</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="n">f</span><span class="p">.</span><span class="n">readlines</span><span class="p">()):</span>
<span class="k">if</span> <span class="s">"# text"</span> <span class="ow">in</span> <span class="n">line</span><span class="p">:</span>
<span class="n">sent</span> <span class="o">=</span> <span class="n">line</span>
<span class="n">train</span><span class="p">.</span><span class="n">append</span><span class="p">([</span><span class="n">sent</span><span class="p">.</span><span class="n">replace</span><span class="p">(</span><span class="s">"# text ="</span><span class="p">,</span> <span class="s">""</span><span class="p">).</span><span class="n">strip</span><span class="p">(),</span> <span class="p">[]])</span>
<span class="k">if</span> <span class="s">"#"</span> <span class="ow">in</span> <span class="n">line</span> <span class="ow">or</span> <span class="s">"</span><span class="se">\n</span><span class="s">"</span> <span class="o">==</span> <span class="n">line</span><span class="p">:</span>
<span class="k">continue</span>
<span class="n">s</span> <span class="o">=</span> <span class="n">line</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="s">"</span><span class="se">\t</span><span class="s">"</span><span class="p">)</span>
<span class="k">if</span> <span class="s">"1"</span> <span class="o">==</span> <span class="n">s</span><span class="p">[</span><span class="mi">0</span><span class="p">]:</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">train</span><span class="p">)</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
<span class="n">train</span><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">][</span><span class="mi">1</span><span class="p">].</span><span class="n">extend</span><span class="p">(</span><span class="n">k</span><span class="p">)</span>
<span class="n">k</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">k</span><span class="p">.</span><span class="n">append</span><span class="p">((</span><span class="n">s</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">s</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">s</span><span class="p">[</span><span class="mi">3</span><span class="p">]))</span>
<span class="n">out</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">sent</span><span class="p">,</span> <span class="n">sentence</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="n">train</span><span class="p">):</span>
<span class="k">for</span> <span class="n">token</span><span class="p">,</span> <span class="n">lemma</span><span class="p">,</span> <span class="n">pos</span> <span class="ow">in</span> <span class="n">sentence</span><span class="p">:</span>
<span class="p">...</span></code></pre></figure>
<h2 id="summary">Summary</h2>
<p>The neural network based lemmatizers have gotten much better. Personally, I often use “Universal Lemmatizer” because it also works well in other languages such as German. The main alternative is stanza. This library also offers other tools such as NER (Named Entity Recognition).</p>
<p>However, no lemmatizer is perfect. It is easy to find sentences where there are obvious mistakes.</p>In this post, I will compare some lemmatizers for Portuguese. In order to do the comparison, I downloaded subtitles from various television programs. The sentences are written in European Portuguese (EP).