Wavelet Trees and full-text search indices


The wavelet tree is a useful data structure in many areas of computer science. One of its applications is full-text search; see the articles [1] and [2] for more details.

In this blog post, I will implement a simple wavelet tree (WT) based on [3] and apply it to document retrieval. The WT from [3] does not use Huffman coding and replaces the bitmap in each node with an integer array of prefix counts.

Construction

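Let $A$ be an array and $\Sigma$ the corresponding alphabet. Each node $v$ of a WT is responsible for a contiguous range $[s_1, s_2]$ of the alphabet (the root for all of $\Sigma$) and for the subsequence $A_v$ of $A$ whose values lie in that range. The node stores an array $C_v[i] = \lvert\{x \leq i : A_v[x] \leq m_v\}\rvert$, where $m_v = \lfloor\frac{s_1 + s_2}{2}\rfloor$ is the middle of the alphabet range $[s_1, s_2]$.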

Example: Let $A = [0, 1, 2, 3, 4]$. Then $m_{\text{root}} = \frac{0 + 4}{2} = 2$ and $C_{\text{root}} = [1, 2, 3, 3, 3]$. The following code should make this clearer:

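A = [0, 1, 2, 3, 4]
sigma = sorted(set(A))                 # the alphabet, sorted
m = (sigma[0] + sigma[-1]) // 2        # m_root = floor((s_1 + s_2) / 2)

C = []
s = 0                                  # running count of entries <= m
for num in A:
    if num <= m:
        s += 1
    C.append(s)
print(C)                               # [1, 2, 3, 3, 3]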

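The next step is to create the child nodes. Two new arrays $A_{\text{left}}$ and $A_{\text{right}}$ are needed: $A_{\text{left}}$ contains all values less than or equal to $m_{\text{root}}$ and $A_{\text{right}}$ all values greater than $m_{\text{root}}$, both in their original order. The above code is then applied to both arrays, with the alphabet ranges $[s_1, m_{\text{root}}]$ and $[m_{\text{root}} + 1, s_2]$, to generate $C_{\text{left}}$ and $C_{\text{right}}$. This continues recursively until a range contains a single symbol.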
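To make the split concrete, here is a minimal sketch of one level of the recursion for the example $A = [0, 1, 2, 3, 4]$ with alphabet range $[0, 4]$. The helper names prefix_counts and split are hypothetical and not part of the original implementation.

```python
def prefix_counts(arr, mid):
    """Return C with C[i] = number of entries in arr[0..i] that are <= mid."""
    counts, s = [], 0
    for num in arr:
        if num <= mid:
            s += 1
        counts.append(s)
    return counts


def split(arr, lo, hi):
    """Split arr for the alphabet range [lo, hi], keeping the original order."""
    mid = (lo + hi) // 2
    left = [num for num in arr if num <= mid]   # alphabet range [lo, mid]
    right = [num for num in arr if num > mid]   # alphabet range [mid + 1, hi]
    return mid, left, right


A = [0, 1, 2, 3, 4]
m_root, A_left, A_right = split(A, 0, 4)

# Each child uses the middle of its own alphabet range.
C_left = prefix_counts(A_left, (0 + m_root) // 2)        # range [0, 2], m = 1
C_right = prefix_counts(A_right, (m_root + 1 + 4) // 2)  # range [3, 4], m = 3

print(m_root, A_left, A_right)  # 2 [0, 1, 2] [3, 4]
print(C_left, C_right)          # [1, 2, 2] [1, 1]
```

Applying the same split to $A_{\text{left}}$ and $A_{\text{right}}$ again, until every range holds a single symbol, yields the whole tree.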

The full implementation requires only about 30 lines of code.

class Node(object):
    def __init__(self):
        self.value = [0]   # C_v with a leading 0, so value[i + 1] == C_v[i]
        self.left = None
        self.right = None

class WaveletTree(object):
    def __init__(self, arr, alphabet_size):
        # The alphabet is the inclusive range [0, alphabet_size].
        self.root = Node()
        self._build(0, alphabet_size, arr, self.root)

    def _build(self, start, stop, arr, node):
        mid = (start + stop) // 2   # m_v for the alphabet range [start, stop]

        s = 0
        left = []
        right = []
        for num in arr:
            if num <= mid:
                s += 1
                left.append(num)    # A_left keeps the original order
            else:
                right.append(num)   # A_right keeps the original order
            node.value.append(s)

        if start == stop:   # leaf: the range holds a single symbol
            return node

        node.left = self._build(start, mid, left, Node())
        node.right = self._build(mid+1, stop, right, Node())

        return node

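A leaf covers a single symbol ($s_1 = s_2$), so every entry of its $A_v$ counts as "left" and its array simply counts up; the last entry is the number of times that symbol appears in $A$.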

The following image from [3] gives a visual example of a WT. Note that the image shows only $A_v$ and not $C_v$; our implementation stores only $C_v$ in the nodes, never $A_v$. The numbers $(4)$, $(2)$, etc. correspond to $m_v$.

[Figure from [3]: example of a wavelet tree]

Rank

Before we can apply the WT to a real-world problem, we need the rank operation. $\text{rank}_{q}(x)$ returns the number of times a symbol $q$ appears in an array $A$ up to and including position $x$.

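The rank operation can be defined recursively. Starting at the root with $p = x + 1$ (the length of the prefix $A[0..x]$):

$$\text{rank}_q(v, p) = \begin{cases} p & \text{if } s_1 = s_2, \\ \text{rank}_q(v_{\text{left}}, C_v[p]) & \text{if } q \leq m_v, \\ \text{rank}_q(v_{\text{right}}, p - C_v[p]) & \text{otherwise,} \end{cases}$$

where $[s_1, s_2]$ is the alphabet range of $v$, $m_v = \lfloor\frac{s_1 + s_2}{2}\rfloor$, and $C_v$ starts with a leading $0$ (as node.value does in the code), so that $C_v[p]$ counts how many of the first $p$ entries of $A_v$ are at most $m_v$.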

Example: In the image, $\text{rank}_{3}(14) = 3$, because the number $3$ appears exactly three times up to position $14$. This can be calculated as follows (each $C$ array starts with a leading $0$, as in the code, so $C_v[p]$ counts the first $p$ entries):

  1. $m_\text{root} = \lfloor\frac{s_1 + s_2}{2}\rfloor = \lfloor\frac{0 + 9}{2}\rfloor = 4$ and, with $p = 14 + 1 = 15$, $C_{\text{root}}[15] = 8$. Since $q = 3 \leq m_\text{root}$, we choose $v_{1} = \text{left}$ with $p = 8$. Set $s_2 = m_\text{root}$.
  2. $m_\text{left} = \lfloor\frac{0 + 4}{2}\rfloor = 2$ and $p = 8 - C_{\text{left}}[8] = 5$. Since $q = 3 > m_\text{left}$, we choose $v_{2} = \text{right}$. Set $s_1 = m_\text{left} + 1$.
  3. $m_\text{right} = \lfloor\frac{3 + 4}{2}\rfloor = 3$ and $C_{\text{right}}[5] = 3$. Since $q = 3 \leq m_\text{right}$, we choose $v_{3} = \text{left}$ with $p = 3$. Set $s_2 = m_\text{right}$.
  4. Now $s_1 = s_2 = 3$, so we have reached the leaf of symbol $3$. Stop. Since $p = 3$, the number $3$ appears three times.
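The recursion can also be written as a short loop. The following is a minimal sketch, assuming the node layout of the WaveletTree class above (every array starts with a leading 0). To keep the block self-contained it re-declares a compact build helper that mirrors _build; build, rank and the sample array are hypothetical, and the sample array is not the one from the figure.

```python
class Node:
    def __init__(self):
        self.value = [0]  # leading 0, as in the WaveletTree class
        self.left = None
        self.right = None


def build(start, stop, arr):
    """Compact equivalent of WaveletTree._build for the alphabet [start, stop]."""
    node = Node()
    mid = (start + stop) // 2
    left, right = [], []
    for num in arr:
        if num <= mid:
            left.append(num)
        else:
            right.append(num)
        node.value.append(len(left))  # entries so far that went left
    if start < stop:
        node.left = build(start, mid, left)
        node.right = build(mid + 1, stop, right)
    return node


def rank(root, start, stop, q, x):
    """Number of occurrences of q in arr[0..x] (x inclusive, 0-indexed)."""
    node, p = root, x + 1  # p = length of the prefix we are looking at
    while start < stop:   # stop at the leaf, where s_1 == s_2
        mid = (start + stop) // 2
        if q <= mid:
            p = node.value[p]      # entries of the prefix that went left
            node, stop = node.left, mid
        else:
            p = p - node.value[p]  # entries of the prefix that went right
            node, start = node.right, mid + 1
    return p


arr = [5, 3, 9, 0, 3, 7, 3, 1, 8, 2]
root = build(0, 9, arr)
print(rank(root, 0, 9, 3, 5))  # 2

# Cross-check against a brute-force count over every prefix.
assert all(
    rank(root, 0, 9, q, x) == arr[: x + 1].count(q)
    for q in range(10)
    for x in range(len(arr))
)
```

Each step touches one node, so a query costs $O(\log \sigma)$ lookups regardless of $x$.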

Application

Search engines are an interesting application of WTs. For example, [1] developed the following method for searching in documents:

  1. Concatenate all texts into one large string.
  2. Generate a suffix array (or compressed suffix tree) from this large text.
  3. Associate each suffix position with the document it belongs to (the so-called document array).
  4. Construct a wavelet tree from the document array.
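Steps 1 to 3 can be sketched in a few lines. This is a minimal sketch, assuming each text is terminated by a separator character "\x00" that does not occur in the texts (a hypothetical choice; [1] may use a different convention). The suffix array is built by naively sorting all suffixes, which is only suitable for small examples, and the function name build_document_array is hypothetical.

```python
def build_document_array(texts, sep="\x00"):
    """Steps 1-3: concatenate, build a suffix array, map suffixes to documents."""
    # Step 1: concatenate the texts, ending each one with the separator.
    text = "".join(t + sep for t in texts)

    # Remember which document every character position belongs to.
    owner = []
    for doc_id, t in enumerate(texts):
        owner.extend([doc_id] * (len(t) + 1))  # +1 for the separator

    # Step 2: suffix array = start positions of all suffixes in sorted order.
    # Naive O(n^2 log n) construction by sorting the suffixes themselves.
    sa = sorted(range(len(text)), key=lambda i: text[i:])

    # Step 3: document array, doc_array[i] = document containing suffix sa[i].
    doc_array = [owner[i] for i in sa]
    return text, sa, doc_array


text, sa, doc_array = build_document_array(["banana", "ananas"])
print(doc_array)
```

The resulting doc_array is exactly the kind of integer array the WaveletTree class is built on in step 4.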

When a user enters a search term, the program performs a binary search on the suffix array. The result is the interval of suffixes that start with the search term. The next step is a depth-first traversal of the wavelet tree over the same interval of the document array, which yields the documents containing the term and how often it occurs in each. Finally, tf-idf is calculated on the fly and the results are returned to the user.
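The query side can be sketched as follows. This is a minimal sketch, assuming the same hypothetical "\x00" separator and naive suffix array as in a small example; suffix_interval is a hypothetical helper, and Counter over the interval of the document array stands in for the wavelet-tree traversal (the dft function from this post would replace it).

```python
from collections import Counter


def suffix_interval(text, sa, pattern):
    """Half-open interval [lo, hi) of suffix-array rows starting with pattern."""
    m = len(pattern)

    # First row whose length-m prefix is >= pattern.
    lo, hi = 0, len(sa)
    while lo < hi:
        mid = (lo + hi) // 2
        if text[sa[mid]:sa[mid] + m] < pattern:
            lo = mid + 1
        else:
            hi = mid
    start = lo

    # First row whose length-m prefix is > pattern.
    hi = len(sa)
    while lo < hi:
        mid = (lo + hi) // 2
        if text[sa[mid]:sa[mid] + m] <= pattern:
            lo = mid + 1
        else:
            hi = mid
    return start, lo


texts = ["banana", "ananas"]
text = "".join(t + "\x00" for t in texts)
owner = [d for d, t in enumerate(texts) for _ in range(len(t) + 1)]
sa = sorted(range(len(text)), key=lambda i: text[i:])
doc_array = [owner[i] for i in sa]

lo, hi = suffix_interval(text, sa, "ana")
# Stand-in for the wavelet-tree traversal: term frequency per document.
print(sorted(Counter(doc_array[lo:hi]).items()))  # [(0, 2), (1, 2)]
```

Both binary searches rely on the fact that truncating sorted suffixes to their first $m$ characters keeps them in sorted order, so all matches form one contiguous interval.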

Here I implement the depth-first traversal (DFT) and compare it with a naive approach based on counting. Let $\sigma$ be the size of the alphabet and $\lvert A\rvert$ the length of the array. In other words, there are $\sigma$ documents with $\lvert A\rvert$ characters in total.

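The DFT works like $\text{rank}$, except that every node is visited and $m_v$ is ignored. Since an interval $[b_1, b_2)$ is given (and not a single position), both endpoints are mapped at every node: the left child receives $[C_v[b_1], C_v[b_2])$ and the right child receives $[b_1 - C_v[b_1], b_2 - C_v[b_2])$. At a leaf, $b_2 - b_1$ is the number of occurrences of its symbol in the interval.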

from timeit import default_timer as timer
from random import randint
from collections import Counter

def dft(stop1, stop2, node, out, length):
    # [stop1, stop2) is a half-open interval of this node's array; length is
    # the number of leaves (symbols) already written to out.
    if not node or not node.left or not node.right:
        # Leaf: the size of the interval is the count of this leaf's symbol.
        out[length] = stop2 - stop1
        return length+1

    # Entries <= m_v go to the left child, all others to the right child.
    length = dft(node.value[stop1], node.value[stop2], node.left, out, length)
    length = dft(stop1 - node.value[stop1], stop2 - node.value[stop2], node.right, out, length)

    return length

start_id = 0      # WaveletTree assumes the alphabet starts at 0
stop_id = 300
sigma = stop_id - start_id   # document ids are 0..sigma, i.e. sigma + 1 leaves
num_chars = 10 ** 6

arr = [randint(start_id, stop_id) for _ in range(num_chars)]

interval = (0, len(arr) - 1)   # half-open, matches arr[interval[0]:interval[1]]

def test_wt():
    wt = WaveletTree(arr, sigma)
    t = timer()   # only the query is timed, not the construction
    out = [0] * (sigma + 1)
    dft(interval[0], interval[1], wt.root, out, 0)
    print(timer() - t)

    # Skip ids with zero occurrences, because Counter omits them too.
    return [(i, x) for i, x in enumerate(out) if x > 0]

def test_naive():
    t = timer()
    ctn = Counter(arr[interval[0]:interval[1]])
    print(timer() - t)

    return sorted(ctn.items())

out1 = test_wt()
out2 = test_naive()

assert out1 == out2
print(out1)
print()
print(out2)

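In these measurements, the WT query is faster than Counter in every configuration, and the gap widens as the number of characters grows: the traversal visits $O(\sigma)$ nodes regardless of the interval length, whereas Counter scans the whole interval. Note that only the query is timed; building the tree is not included.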

Documents  Characters  WT query (s)  Counter (s)
10         $10^5$      1.96e-05      0.005
10         $10^6$      3.10e-05      0.05
10         $10^7$      3.91e-05      0.48
300        $10^5$      0.0003        0.005
300        $10^6$      0.0004        0.06

A word of caution: this implementation in Python requires lots of memory. There are $\lceil\log_2(\sigma)\rceil$ levels, and the arrays of all nodes in one level together contain about as many entries as the original array $A$. For 300 documents, there are $\lceil\log_2(300)\rceil = 9$ levels. A (long) integer requires about 24 bytes in Python (28 bytes on 64-bit CPython 3), while in C it is typically just 4 bytes. For 10 million characters, we need roughly $9 \cdot 10^7 \cdot 24$ bytes $\approx 2.16$ gigabytes, plus some additional temporary variables and pointers.
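The estimate is simple arithmetic and can be reproduced as below. This is a rough sketch: estimate_bytes is a hypothetical helper that, like the text, counts $\lceil\log_2(\sigma)\rceil$ levels of $\lvert A\rvert$ entries each and ignores per-node overhead and the 8-byte pointer every list slot holds on 64-bit CPython.

```python
import math
import sys


def estimate_bytes(num_chars, sigma, bytes_per_entry):
    """Rough size of all C_v arrays: ceil(log2(sigma)) levels of num_chars entries."""
    levels = math.ceil(math.log2(sigma))
    return levels * num_chars * bytes_per_entry


print(estimate_bytes(10**7, 300, 24) / 1e9)  # 2.16 (GB with 24-byte integers)
print(estimate_bytes(10**7, 300, 4) / 1e9)   # 0.36 (GB with 4-byte C ints)

# Size of a single Python int object in bytes; depends on version and platform.
print(sys.getsizeof(10**6))
```

The factor of six between the two results is only the per-entry size; compression as used by dedicated libraries reduces memory further.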

Hence, for long texts it is better to use a C binding and some compression. A good C++ library that also uses Huffman coding can be found here. And here is a direct translation of my Python code into C.

References

[1] J. Culpepper, G. Navarro et al. “Top-k Ranked Document Search in General Text Databases”. http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.180.5473&rep=rep1&type=pdf

[2] S. Gog, T. Beller et al. “From Theory to Practice: Plug and Play with Succinct Data Structures”. https://arxiv.org/abs/1311.1249

[3] R. Castro, N. Lehmann et al. “Wavelet Trees for Competitive Programming”. https://ioinformatics.org/journal/v10_2016_19_37.pdf
