Skip to content

第24章 高级数据结构

学习目标

  • 掌握并查集的摊还分析与最优实现
  • 理解线段树的正确性与懒惰传播
  • 掌握字典树的形式化定义与复杂度分析
  • 理解树状数组的二进制索引原理
  • 掌握高级数据结构的选择策略

24.1 并查集深入分析

24.1.1 形式化定义

定义 24.1(并查集) 并查集(Disjoint Set Union, DSU)是一种维护动态不相交集合的数据结构,支持:

  • $\text{Make-Set}(x)$:创建单元素集合 ${x}$
  • $\text{Find}(x)$:返回包含 $x$ 的集合代表
  • $\text{Union}(x, y)$:合并包含 $x$ 和 $y$ 的集合

定义 24.2(不相交集合森林) 使用森林表示,每棵树对应一个集合,树根为代表元素。

24.1.2 摊还分析

定义 24.3(秩) 节点 $x$ 的秩 $rank(x)$ 是以 $x$ 为根的子树高度的上界。

定义 24.4(阿克曼函数)

$$A(i, j) = \begin{cases} 2^j & i = 0 \ A(i-1, 1) & i > 0, j = 0 \ A(i-1, A(i, j-1)) & i > 0, j > 0 \end{cases}$$

定义 24.5(反阿克曼函数)

$$\alpha(n) = \min{i : A(i, 1) \geq n}$$

定理 24.1(Tarjan定理) 使用路径压缩和按秩合并,$m$ 次操作的总时间为 $O(m \cdot \alpha(n))$。

证明概要

  1. 定义势能函数 $\Phi$ 基于节点的"等级"和"组号"
  2. 每次Find操作的摊还代价为 $O(\alpha(n))$
  3. 路径压缩使得后续操作更快

详细证明参见Tarjan的原始论文。 ∎

推论 24.1 对于所有实际应用,$\alpha(n) \leq 4$,故单次操作摊还时间为 $O(1)$。

24.1.3 Python实现

python
from typing import Dict, List, Set, TypeVar, Generic, Optional, Tuple

T = TypeVar('T')

class UnionFind(Generic[T]):
    def __init__(self):
        self._parent: Dict[T, T] = {}
        self._rank: Dict[T, int] = {}
        self._size: Dict[T, int] = {}
        self._count: int = 0
    
    def make_set(self, x: T) -> None:
        if x not in self._parent:
            self._parent[x] = x
            self._rank[x] = 0
            self._size[x] = 1
            self._count += 1
    
    def find(self, x: T) -> T:
        if x not in self._parent:
            self.make_set(x)
        
        if self._parent[x] != x:
            self._parent[x] = self.find(self._parent[x])
        return self._parent[x]
    
    def find_iterative(self, x: T) -> T:
        if x not in self._parent:
            self.make_set(x)
            return x
        
        root = x
        while self._parent[root] != root:
            root = self._parent[root]
        
        while self._parent[x] != root:
            next_node = self._parent[x]
            self._parent[x] = root
            x = next_node
        
        return root
    
    def union(self, x: T, y: T) -> bool:
        px, py = self.find(x), self.find(y)
        
        if px == py:
            return False
        
        if self._rank[px] < self._rank[py]:
            px, py = py, px
        
        self._parent[py] = px
        self._size[px] += self._size[py]
        
        if self._rank[px] == self._rank[py]:
            self._rank[px] += 1
        
        self._count -= 1
        return True
    
    def connected(self, x: T, y: T) -> bool:
        return self.find(x) == self.find(y)
    
    def component_size(self, x: T) -> int:
        return self._size[self.find(x)]
    
    def count(self) -> int:
        return self._count
    
    def components(self) -> Dict[T, Set[T]]:
        result: Dict[T, Set[T]] = {}
        for x in self._parent:
            root = self.find(x)
            if root not in result:
                result[root] = set()
            result[root].add(x)
        return result

class UnionFindArray:
    def __init__(self, n: int):
        self._parent = list(range(n))
        self._rank = [0] * n
        self._size = [1] * n
        self._count = n
    
    def find(self, x: int) -> int:
        if self._parent[x] != x:
            self._parent[x] = self.find(self._parent[x])
        return self._parent[x]
    
    def union(self, x: int, y: int) -> bool:
        px, py = self.find(x), self.find(y)
        
        if px == py:
            return False
        
        if self._rank[px] < self._rank[py]:
            px, py = py, px
        
        self._parent[py] = px
        self._size[px] += self._size[py]
        
        if self._rank[px] == self._rank[py]:
            self._rank[px] += 1
        
        self._count -= 1
        return True
    
    def connected(self, x: int, y: int) -> bool:
        return self.find(x) == self.find(y)
    
    def component_size(self, x: int) -> int:
        return self._size[self.find(x)]
    
    def count(self) -> int:
        return self._count

24.2 线段树

24.2.1 形式化定义

定义 24.6(线段树) 线段树是一种二叉树结构,用于维护区间信息。每个节点表示一个区间 $[l, r]$,叶子节点表示单点区间。

定义 24.7(线段树性质)

  • 根节点表示整个区间 $[1, n]$
  • 对于节点 $u$ 表示 $[l, r]$:
    • 若 $l = r$,为叶子节点
    • 否则,左孩子表示 $[l, mid]$,右孩子表示 $[mid+1, r]$,其中 $mid = \lfloor(l+r)/2\rfloor$

定理 24.2(线段树空间) 线段树需要 $O(n)$ 空间,具体为 $2n$ 到 $4n$ 个节点。

证明:线段树高度为 $\lceil \log_2 n \rceil$。节点总数:

$$\sum_{i=0}^{\lceil \log_2 n \rceil} 2^i \leq 2^{\lceil \log_2 n \rceil + 1} \leq 4n$$

实际使用 $2n$ 个节点(对于完全二叉树)到 $4n$(考虑数组实现)。 ∎

24.2.2 操作复杂度

定理 24.3 线段树支持:

  • 单点更新:$O(\log n)$
  • 区间查询:$O(\log n)$
  • 区间更新(懒惰传播):$O(\log n)$

证明:树高度为 $O(\log n)$,每次操作沿树的一条路径进行。 ∎

24.2.3 懒惰传播

定义 24.8(懒惰传播) 懒惰传播延迟区间更新操作,仅在需要时下推标记。

定理 24.4(懒惰传播正确性) 懒惰传播保证查询结果正确。

证明:懒惰标记记录未下推的更新。查询时,沿路径下推所有相关标记,确保访问的节点信息正确。 ∎

24.2.4 Python实现

python
from typing import List, TypeVar, Generic, Callable, Optional

T = TypeVar('T')

class SegmentTree(Generic[T]):
    def __init__(self, data: List[T], merge: Callable[[T, T], T], identity: T):
        self.n = len(data)
        self.merge = merge
        self.identity = identity
        self.size = 1
        while self.size < self.n:
            self.size *= 2
        self.tree = [identity] * (2 * self.size)
        
        for i in range(self.n):
            self.tree[self.size + i] = data[i]
        
        for i in range(self.size - 1, 0, -1):
            self.tree[i] = merge(self.tree[2 * i], self.tree[2 * i + 1])
    
    def update(self, index: int, value: T) -> None:
        if index < 0 or index >= self.n:
            return
        
        pos = self.size + index
        self.tree[pos] = value
        
        while pos > 1:
            pos //= 2
            self.tree[pos] = self.merge(self.tree[2 * pos], self.tree[2 * pos + 1])
    
    def query(self, left: int, right: int) -> T:
        if left < 0:
            left = 0
        if right > self.n:
            right = self.n
        if left >= right:
            return self.identity
        
        left += self.size
        right += self.size
        
        result_left = self.identity
        result_right = self.identity
        
        while left < right:
            if left % 2 == 1:
                result_left = self.merge(result_left, self.tree[left])
                left += 1
            if right % 2 == 1:
                right -= 1
                result_right = self.merge(self.tree[right], result_right)
            left //= 2
            right //= 2
        
        return self.merge(result_left, result_right)
    
    def get(self, index: int) -> T:
        if 0 <= index < self.n:
            return self.tree[self.size + index]
        return self.identity

class LazySegmentTree:
    def __init__(self, data: List[int]):
        self.n = len(data)
        self.size = 1
        while self.size < self.n:
            self.size *= 2
        self.tree = [0] * (2 * self.size)
        self.lazy = [0] * (2 * self.size)
        
        for i in range(self.n):
            self.tree[self.size + i] = data[i]
        
        for i in range(self.size - 1, 0, -1):
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]
    
    def _apply(self, node: int, start: int, end: int, value: int) -> None:
        self.tree[node] += value * (end - start)
        self.lazy[node] += value
    
    def _push(self, node: int, start: int, end: int) -> None:
        if self.lazy[node] != 0:
            mid = (start + end) // 2
            self._apply(2 * node, start, mid, self.lazy[node])
            self._apply(2 * node + 1, mid, end, self.lazy[node])
            self.lazy[node] = 0
    
    def range_add(self, left: int, right: int, value: int) -> None:
        self._range_add(1, 0, self.size, left, right, value)
    
    def _range_add(self, node: int, start: int, end: int, left: int, right: int, value: int) -> None:
        if left >= end or right <= start:
            return
        
        if left <= start and end <= right:
            self._apply(node, start, end, value)
            return
        
        self._push(node, start, end)
        
        mid = (start + end) // 2
        self._range_add(2 * node, start, mid, left, right, value)
        self._range_add(2 * node + 1, mid, end, left, right, value)
        
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]
    
    def range_sum(self, left: int, right: int) -> int:
        return self._range_sum(1, 0, self.size, left, right)
    
    def _range_sum(self, node: int, start: int, end: int, left: int, right: int) -> int:
        if left >= end or right <= start:
            return 0
        
        if left <= start and end <= right:
            return self.tree[node]
        
        self._push(node, start, end)
        
        mid = (start + end) // 2
        left_sum = self._range_sum(2 * node, start, mid, left, right)
        right_sum = self._range_sum(2 * node + 1, mid, end, left, right)
        
        return left_sum + right_sum

24.3 字典树(Trie)

24.3.1 形式化定义

定义 24.9(字典树) 字典树(Trie)是一种树形数据结构,用于存储字符串集合。每条边标记一个字符,从根到某节点的路径表示一个字符串。

定义 24.10(Trie节点) 每个节点包含:

  • 子节点映射:$\delta: \Sigma \rightarrow \text{Node}$
  • 终止标记:指示是否有字符串在此结束

定理 24.5(Trie空间复杂度) 存储 $n$ 个总长度为 $L$ 的字符串,Trie空间复杂度为 $O(L \cdot |\Sigma|)$。

证明:每个字符对应一条边,共 $L$ 条边。每个节点最多 $|\Sigma|$ 个子节点指针。 ∎

24.3.2 操作复杂度

定理 24.6 Trie支持:

  • 插入:$O(m)$,其中 $m$ 为字符串长度
  • 查找:$O(m)$
  • 前缀查找:$O(m)$

证明:操作沿树的一条路径进行,路径长度为字符串长度 $m$。 ∎

24.3.3 Python实现

python
from typing import Dict, List, Set, Optional, Tuple
from collections import defaultdict

class TrieNode:
    def __init__(self):
        self.children: Dict[str, 'TrieNode'] = {}
        self.is_end: bool = False
        self.count: int = 0

class Trie:
    def __init__(self):
        self.root = TrieNode()
        self._size = 0
    
    def insert(self, word: str) -> None:
        node = self.root
        for char in word:
            if char not in node.children:
                node.children[char] = TrieNode()
            node = node.children[char]
            node.count += 1
        
        if not node.is_end:
            node.is_end = True
            self._size += 1
    
    def search(self, word: str) -> bool:
        node = self._find_node(word)
        return node is not None and node.is_end
    
    def starts_with(self, prefix: str) -> bool:
        return self._find_node(prefix) is not None
    
    def count_prefix(self, prefix: str) -> int:
        node = self._find_node(prefix)
        return node.count if node else 0
    
    def _find_node(self, prefix: str) -> Optional[TrieNode]:
        node = self.root
        for char in prefix:
            if char not in node.children:
                return None
            node = node.children[char]
        return node
    
    def delete(self, word: str) -> bool:
        if not self.search(word):
            return False
        
        node = self.root
        path = [node]
        
        for char in word:
            node = node.children[char]
            node.count -= 1
            path.append(node)
        
        node.is_end = False
        self._size -= 1
        
        for i in range(len(path) - 1, 0, -1):
            current = path[i]
            if current.count == 0 and not current.is_end:
                parent = path[i - 1]
                char = word[i - 1]
                del parent.children[char]
        
        return True
    
    def words_with_prefix(self, prefix: str) -> List[str]:
        node = self._find_node(prefix)
        if node is None:
            return []
        
        result = []
        self._collect_words(node, prefix, result)
        return result
    
    def _collect_words(self, node: TrieNode, prefix: str, result: List[str]) -> None:
        if node.is_end:
            result.append(prefix)
        
        for char, child in node.children.items():
            self._collect_words(child, prefix + char, result)
    
    def longest_common_prefix(self) -> str:
        if self._size == 0:
            return ""
        
        prefix = []
        node = self.root
        
        while len(node.children) == 1 and not node.is_end:
            char = next(iter(node.children))
            prefix.append(char)
            node = node.children[char]
        
        return "".join(prefix)
    
    def __len__(self) -> int:
        return self._size

class CompressedTrie:
    def __init__(self):
        self.root: Dict[str, any] = {}
        self._size = 0
    
    def insert(self, word: str) -> None:
        node = self.root
        
        for i, char in enumerate(word):
            if char not in node:
                node[char] = {}
            
            if i == len(word) - 1:
                if '_end' not in node[char]:
                    node[char]['_end'] = True
                    self._size += 1
            
            node = node[char]
    
    def search(self, word: str) -> bool:
        node = self.root
        for char in word:
            if char not in node:
                return False
            node = node[char]
        return node.get('_end', False)

24.4 树状数组

24.4.1 形式化定义

定义 24.11(树状数组) 树状数组(Binary Indexed Tree, BIT)利用整数的二进制表示,实现高效的前缀和查询与单点更新。

定义 24.12(最低有效位) 对于整数 $i$,其最低有效位(LSB)定义为:

$$\text{LSB}(i) = i \land (-i)$$

其中 $\land$ 表示按位与运算。

定义 24.13(BIT数组定义) 对于原数组 $A[1..n]$,BIT数组 $B[1..n]$ 定义为:

$$B[i] = \sum_{j=i-\text{LSB}(i)+1}^{i} A[j]$$

定理 24.7(BIT正确性) BIT正确维护前缀和。

证明:对于前缀和 $S[k] = \sum_{i=1}^{k} A[i]$,通过累加 $B[k], B[k-\text{LSB}(k)], \ldots$ 直到 $0$,覆盖所有需要的 $A[i]$。 ∎

24.4.2 复杂度分析

定理 24.8 BIT支持:

  • 单点更新:$O(\log n)$
  • 前缀查询:$O(\log n)$
  • 空间:$O(n)$

证明:每次操作沿树的祖先/后代链移动,链长为 $O(\log n)$。 ∎

24.4.3 Python实现

python
from typing import List, Optional

class BinaryIndexedTree:
    def __init__(self, n: int):
        self.n = n
        self.tree = [0] * (n + 1)
    
    def update(self, index: int, delta: int) -> None:
        if index < 1 or index > self.n:
            return
        
        while index <= self.n:
            self.tree[index] += delta
            index += index & (-index)
    
    def query(self, index: int) -> int:
        if index < 0:
            return 0
        if index > self.n:
            index = self.n
        
        result = 0
        while index > 0:
            result += self.tree[index]
            index -= index & (-index)
        return result
    
    def range_query(self, left: int, right: int) -> int:
        if left > right:
            return 0
        return self.query(right) - self.query(left - 1)
    
    @staticmethod
    def from_array(arr: List[int]) -> 'BinaryIndexedTree':
        n = len(arr)
        bit = BinaryIndexedTree(n)
        for i, val in enumerate(arr, 1):
            bit.update(i, val)
        return bit

class BinaryIndexedTree2D:
    def __init__(self, rows: int, cols: int):
        self.rows = rows
        self.cols = cols
        self.tree = [[0] * (cols + 1) for _ in range(rows + 1)]
    
    def update(self, row: int, col: int, delta: int) -> None:
        if row < 1 or row > self.rows or col < 1 or col > self.cols:
            return
        
        i = row
        while i <= self.rows:
            j = col
            while j <= self.cols:
                self.tree[i][j] += delta
                j += j & (-j)
            i += i & (-i)
    
    def query(self, row: int, col: int) -> int:
        if row < 0 or col < 0:
            return 0
        if row > self.rows:
            row = self.rows
        if col > self.cols:
            col = self.cols
        
        result = 0
        i = row
        while i > 0:
            j = col
            while j > 0:
                result += self.tree[i][j]
                j -= j & (-j)
            i -= i & (-i)
        return result
    
    def range_query(self, r1: int, c1: int, r2: int, c2: int) -> int:
        return (self.query(r2, c2) - self.query(r1 - 1, c2) - 
                self.query(r2, c1 - 1) + self.query(r1 - 1, c1 - 1))

24.5 稀疏表

24.5.1 形式化定义

定义 24.14(稀疏表) 稀疏表是一种静态数据结构,支持 $O(1)$ 的区间最小/最大值查询(RMQ)。

定义 24.15(ST表定义) $st[i][j]$ 表示从位置 $i$ 开始长度为 $2^j$ 的区间的最值:

$$st[i][j] = \min_{k=i}^{i+2^j-1} A[k]$$

定理 24.9(ST表预处理) 预处理时间 $O(n \log n)$,空间 $O(n \log n)$。

证明:共有 $\lceil \log_2 n \rceil$ 层,每层 $O(n)$ 个元素。 ∎

24.5.2 Python实现

python
from typing import List, Callable, TypeVar, Generic
import math

T = TypeVar('T')

class SparseTable(Generic[T]):
    def __init__(self, data: List[T], merge: Callable[[T, T], T]):
        self.n = len(data)
        self.merge = merge
        self.log = [0] * (self.n + 1)
        
        for i in range(2, self.n + 1):
            self.log[i] = self.log[i // 2] + 1
        
        self.k = self.log[self.n] + 1
        self.st = [[data[i] for i in range(self.n)] for _ in range(self.k)]
        
        for j in range(1, self.k):
            for i in range(self.n - (1 << j) + 1):
                self.st[j][i] = merge(self.st[j-1][i], self.st[j-1][i + (1 << (j-1))])
    
    def query(self, left: int, right: int) -> T:
        if left < 0:
            left = 0
        if right >= self.n:
            right = self.n - 1
        if left > right:
            raise ValueError("Invalid range")
        
        j = self.log[right - left + 1]
        return self.merge(self.st[j][left], self.st[j][right - (1 << j) + 1])

class RMQ(SparseTable[int]):
    def __init__(self, data: List[int], query_min: bool = True):
        merge = min if query_min else max
        super().__init__(data, merge)

24.6 数据结构比较

24.6.1 复杂度对比

数据结构预处理查询更新空间
并查集$O(n)$$O(\alpha(n))$$O(\alpha(n))$$O(n)$
线段树$O(n)$$O(\log n)$$O(\log n)$$O(n)$
树状数组$O(n \log n)$$O(\log n)$$O(\log n)$$O(n)$
稀疏表$O(n \log n)$$O(1)$不支持$O(n \log n)$
字典树$O(L)$$O(m)$$O(m)$$O(L \cdot |\Sigma|)$

24.6.2 选择策略

场景推荐结构原因
连通性查询并查集最优摊还复杂度
区间求和树状数组实现简单,高效
区间最值线段树/稀疏表支持更新选线段树
字符串前缀字典树天然前缀结构

24.7 本章小结

本章介绍了高级数据结构:

  1. 并查集:路径压缩与按秩合并,摊还 $O(\alpha(n))$
  2. 线段树:区间查询与更新,懒惰传播
  3. 字典树:字符串前缀匹配,$O(m)$ 查询
  4. 树状数组:利用二进制索引,高效前缀和
  5. 稀疏表:静态RMQ,$O(1)$ 查询

参考文献

  1. Tarjan, R. E. (1975). Efficiency of a good but not linear set union algorithm. Journal of the ACM, 22(2), 215-225.
  2. Fredman, M. L., & Willard, D. E. (1993). Trans-dichotomous algorithms for minimum spanning trees and shortest paths. Journal of Computer and System Sciences, 48(3), 533-551.
  3. Bentley, J. L. (1980). Decomposable searching problems. Information Processing Letters, 8(5), 244-251.
  4. Cormen, T. H., et al. (2009). Introduction to Algorithms, 3rd ed. MIT Press.

下一章:第25章 算法设计范式

Python技术丛书 - 江苏省宿城中等专业学校