Skip to content

第23章 高级搜索技术

学习目标

  • 掌握状态空间搜索的形式化定义
  • 理解回溯算法的正确性与剪枝优化
  • 掌握分支限界算法的理论基础
  • 理解A*算法的可采纳性与最优性证明
  • 掌握启发式函数的设计原则

23.1 状态空间搜索理论

23.1.1 形式化定义

定义 23.1(状态空间) 状态空间是一个四元组 $(S, A, T, G)$:

  • $S$:状态集合
  • $A$:动作集合
  • $T: S \times A \rightarrow S$:转移函数
  • $G \subseteq S$:目标状态集合

定义 23.2(搜索问题) 给定初始状态 $s_0 \in S$,找到从 $s_0$ 到某个 $g \in G$ 的路径。

定义 23.3(搜索树) 搜索树是从初始状态出发,通过扩展节点生成的树结构。

23.1.2 搜索策略分类

定义 23.4(搜索策略)

  • 无信息搜索:不使用启发信息,如BFS、DFS
  • 有信息搜索:使用启发函数,如A*、贪心最佳优先
  • 系统搜索:保证找到解(若存在),如回溯
  • 随机搜索:引入随机性,如蒙特卡洛树搜索

23.2 回溯算法

23.2.1 基本原理

定义 23.5(回溯) 回溯是一种系统搜索方法,通过深度优先遍历搜索树,遇到死路时回退。

算法 23.1(回溯框架)

Backtrack(state):
    if IsComplete(state):
        RecordSolution(state)
        return
    
    for each candidate in GetCandidates(state):
        if IsValid(state, candidate):
            Apply(state, candidate)
            Backtrack(state)
            Undo(state, candidate)

23.2.2 正确性证明

定理 23.1(回溯完备性) 若解存在,回溯算法必能找到。

证明:回溯算法遍历搜索树的所有节点。若解存在,对应搜索树中某条路径,必被遍历到。 ∎

定理 23.2(回溯复杂度) 设搜索树深度为 $d$,分支因子为 $b$,回溯的时间复杂度为 $O(b^d)$。

23.2.3 剪枝优化

定义 23.6(剪枝) 剪枝是在搜索过程中提前排除不可能产生解的子树。

定理 23.3(剪枝正确性) 正确的剪枝不丢失任何解。

证明:剪枝条件基于问题约束。若某状态违反约束,其所有后继状态也违反约束(单调性),故可安全剪除。 ∎

23.2.4 Python实现

python
from typing import List, TypeVar, Callable, Optional, Generator

T = TypeVar('T')

class Backtracking:
    def __init__(self):
        self.solutions = []
    
    def permute(self, nums: List[T]) -> List[List[T]]:
        result = []
        
        def backtrack(path: List[T], remaining: List[T]) -> None:
            if not remaining:
                result.append(path[:])
                return
            
            for i, num in enumerate(remaining):
                path.append(num)
                backtrack(path, remaining[:i] + remaining[i+1:])
                path.pop()
        
        backtrack([], nums)
        return result
    
    def combine(self, n: int, k: int) -> List[List[int]]:
        result = []
        
        def backtrack(start: int, path: List[int]) -> None:
            if len(path) == k:
                result.append(path[:])
                return
            
            for i in range(start, n + 1):
                path.append(i)
                backtrack(i + 1, path)
                path.pop()
        
        backtrack(1, [])
        return result
    
    def subsets(self, nums: List[T]) -> List[List[T]]:
        result = []
        
        def backtrack(start: int, path: List[T]) -> None:
            result.append(path[:])
            
            for i in range(start, len(nums)):
                path.append(nums[i])
                backtrack(i + 1, path)
                path.pop()
        
        backtrack(0, [])
        return result

def n_queens(n: int) -> List[List[int]]:
    result = []
    
    def is_valid(board: List[int], row: int, col: int) -> bool:
        for i in range(row):
            if board[i] == col:
                return False
            if abs(board[i] - col) == row - i:
                return False
        return True
    
    def backtrack(row: int, board: List[int]) -> None:
        if row == n:
            result.append(board[:])
            return
        
        for col in range(n):
            if is_valid(board, row, col):
                board.append(col)
                backtrack(row + 1, board)
                board.pop()
    
    backtrack(0, [])
    return result

def n_queens_count(n: int) -> int:
    count = 0
    cols = [False] * n
    diag1 = [False] * (2 * n - 1)
    diag2 = [False] * (2 * n - 1)
    
    def backtrack(row: int) -> None:
        nonlocal count
        if row == n:
            count += 1
            return
        
        for col in range(n):
            d1, d2 = row + col, row - col + n - 1
            if not cols[col] and not diag1[d1] and not diag2[d2]:
                cols[col] = diag1[d1] = diag2[d2] = True
                backtrack(row + 1)
                cols[col] = diag1[d1] = diag2[d2] = False
    
    backtrack(0)
    return count

def solve_sudoku(board: List[List[int]]) -> bool:
    def is_valid(row: int, col: int, num: int) -> bool:
        for i in range(9):
            if board[row][i] == num or board[i][col] == num:
                return False
        
        box_row, box_col = 3 * (row // 3), 3 * (col // 3)
        for i in range(box_row, box_row + 3):
            for j in range(box_col, box_col + 3):
                if board[i][j] == num:
                    return False
        return True
    
    def solve() -> bool:
        for i in range(9):
            for j in range(9):
                if board[i][j] == 0:
                    for num in range(1, 10):
                        if is_valid(i, j, num):
                            board[i][j] = num
                            if solve():
                                return True
                            board[i][j] = 0
                    return False
        return True
    
    return solve()

def generate_parentheses(n: int) -> List[str]:
    result = []
    
    def backtrack(current: str, open_count: int, close_count: int) -> None:
        if len(current) == 2 * n:
            result.append(current)
            return
        
        if open_count < n:
            backtrack(current + '(', open_count + 1, close_count)
        
        if close_count < open_count:
            backtrack(current + ')', open_count, close_count + 1)
    
    backtrack('', 0, 0)
    return result

23.3 分支限界算法

23.3.1 基本原理

定义 23.7(分支限界) 分支限界是一种系统搜索方法,使用界限函数剪除不可能产生最优解的子树。

定义 23.8(下界函数) 下界函数 $LB(s)$ 给出从状态 $s$ 出发能得到的解的最小代价。

定义 23.9(上界函数) 上界函数 $UB(s)$ 给出从状态 $s$ 出发能得到的解的最大代价。

23.3.2 正确性证明

定理 23.4(分支限界正确性) 若下界函数正确,分支限界算法返回最优解。

证明:设当前最优解代价为 $c^$。对于状态 $s$,若 $LB(s) > c^$,则从 $s$ 出发的任何解代价都大于 $c^*$,不可能更优,可安全剪除。 ∎

23.3.3 FIFO vs LC分支限界

定义 23.10(FIFO分支限界) 使用队列存储活节点,按广度优先顺序扩展。

定义 23.11(LC分支限界) 使用优先队列存储活节点,按代价估计值排序扩展。

23.3.4 Python实现

python
from typing import List, Tuple, Optional, Generic, TypeVar
from dataclasses import dataclass, field
from heapq import heappush, heappop
import math

T = TypeVar('T')

@dataclass(order=True)
class TSPNode:
    lower_bound: float
    path: List[int] = field(compare=False)
    visited: set = field(compare=False)
    current: int = field(compare=False)
    cost: float = field(compare=False)

def tsp_branch_and_bound(distances: List[List[float]]) -> Tuple[float, List[int]]:
    n = len(distances)
    if n == 0:
        return 0, []
    if n == 1:
        return 0, [0]
    
    def compute_lower_bound(path: List[int], visited: set) -> float:
        lb = 0
        
        for i in range(len(path) - 1):
            lb += distances[path[i]][path[i + 1]]
        
        if len(path) > 0:
            last = path[-1]
            min_out = min((distances[last][j] for j in range(n) if j not in visited), default=0)
            lb += min_out
        
        for i in range(n):
            if i not in visited:
                min_edges = sorted((distances[i][j] for j in range(n) if i != j))[:2]
                if len(min_edges) >= 2:
                    lb += (min_edges[0] + min_edges[1]) / 2
                elif len(min_edges) == 1:
                    lb += min_edges[0]
        
        return lb
    
    initial_lb = compute_lower_bound([0], {0})
    initial_node = TSPNode(initial_lb, [0], {0}, 0, 0)
    
    pq = [initial_node]
    best_cost = math.inf
    best_path = []
    
    while pq:
        node = heappop(pq)
        
        if node.lower_bound >= best_cost:
            continue
        
        if len(node.path) == n:
            total_cost = node.cost + distances[node.current][0]
            if total_cost < best_cost:
                best_cost = total_cost
                best_path = node.path + [0]
            continue
        
        for next_city in range(n):
            if next_city not in node.visited:
                new_path = node.path + [next_city]
                new_visited = node.visited | {next_city}
                new_cost = node.cost + distances[node.current][next_city]
                new_lb = compute_lower_bound(new_path, new_visited)
                
                if new_lb < best_cost:
                    new_node = TSPNode(new_lb, new_path, new_visited, next_city, new_cost)
                    heappush(pq, new_node)
    
    return best_cost, best_path

def knapsack_branch_and_bound(values: List[int], weights: List[int], capacity: int) -> Tuple[int, List[int]]:
    n = len(values)
    
    @dataclass(order=True)
    class Node:
        upper_bound: float
        level: int = field(compare=False)
        value: int = field(compare=False)
        weight: int = field(compare=False)
        selected: List[int] = field(compare=False)
    
    def compute_upper_bound(level: int, current_value: int, current_weight: int) -> float:
        if current_weight > capacity:
            return -1
        
        ub = current_value
        remaining = capacity - current_weight
        i = level
        
        while i < n and weights[i] <= remaining:
            ub += values[i]
            remaining -= weights[i]
            i += 1
        
        if i < n:
            ub += (remaining / weights[i]) * values[i]
        
        return ub
    
    items = sorted(zip(values, weights, range(n)), key=lambda x: x[0]/x[1] if x[1] > 0 else 0, reverse=True)
    sorted_values = [v for v, w, i in items]
    sorted_weights = [w for v, w, i in items]
    indices = [i for v, w, i in items]
    
    initial_ub = compute_upper_bound(0, 0, 0)
    initial_node = Node(initial_ub, 0, 0, 0, [])
    
    pq = [initial_node]
    best_value = 0
    best_selection = []
    
    while pq:
        node = heappop(pq)
        
        if node.upper_bound <= best_value:
            continue
        
        if node.level == n:
            if node.value > best_value:
                best_value = node.value
                best_selection = node.selected
            continue
        
        idx = node.level
        
        new_weight = node.weight + sorted_weights[idx]
        if new_weight <= capacity:
            new_value = node.value + sorted_values[idx]
            new_ub = compute_upper_bound(idx + 1, new_value, new_weight)
            if new_ub > best_value:
                new_node = Node(-new_ub, idx + 1, new_value, new_weight, node.selected + [indices[idx]])
                heappush(pq, new_node)
        
        new_ub = compute_upper_bound(idx + 1, node.value, node.weight)
        if new_ub > best_value:
            new_node = Node(-new_ub, idx + 1, node.value, node.weight, node.selected[:])
            heappush(pq, new_node)
    
    return best_value, sorted(best_selection)

23.4 A*算法

23.4.1 基本原理

定义 23.12(A*算法) A*算法是一种最佳优先搜索,使用评估函数:

$$f(n) = g(n) + h(n)$$

其中 $g(n)$ 是从起点到 $n$ 的实际代价,$h(n)$ 是从 $n$ 到目标的启发式估计。

23.4.2 可采纳性与一致性

定义 23.13(可采纳性) 启发函数 $h$ 是可采纳的,若对所有状态 $n$:

$$h(n) \leq h^*(n)$$

其中 $h^*(n)$ 是从 $n$ 到目标的真实最小代价。

定义 23.14(一致性/单调性) 启发函数 $h$ 是一致的,若对所有节点 $n$ 和其后继 $n'$:

$$h(n) \leq c(n, n') + h(n')$$

定理 23.5 一致性蕴含可采纳性。

证明:设 $n$ 为任意节点,$n^*$ 为目标节点。沿最优路径展开一致性不等式:

$$h(n) \leq c(n, n_1) + h(n_1) \leq c(n, n_1) + c(n_1, n_2) + h(n_2) \leq \ldots \leq h^*(n)$$

23.4.3 最优性证明

定理 23.6(A*最优性) 若启发函数可采纳,A*算法找到最优解。

证明:设 $G$ 为A找到的目标节点,代价为 $g(G)$。设 $G^$ 为最优目标,代价为 $g(G^*)$。

假设 $g(G) > g(G^)$。当A选择 $G$ 扩展时,$G^*$ 的某个祖先 $n$ 必在开放表中。

由可采纳性,$h(n) \leq h^*(n)$,故:

$$f(n) = g(n) + h(n) \leq g(n) + h^(n) = g(G^)$$

而 $f(G) = g(G) + h(G) = g(G) > g(G^*) \geq f(n)$。

这与A选择 $f$ 值最小的节点矛盾。故 $g(G) \leq g(G^)$,A*找到最优解。 ∎

定理 23.7(A*最优效率) 若启发函数一致,A是最优高效的:不会扩展 $f(n) > g(G^)$ 的节点。

23.4.4 Python实现

python
from typing import List, Tuple, Dict, Set, Callable, Generic, TypeVar, Optional
from heapq import heappush, heappop
from dataclasses import dataclass, field
import math

T = TypeVar('T')

@dataclass(order=True)
class AStarNode:
    f_score: float
    g_score: float = field(compare=False)
    state: T = field(compare=False)
    parent: Optional['AStarNode'] = field(compare=False, default=None)

def astar_search(
    start: T,
    is_goal: Callable[[T], bool],
    get_neighbors: Callable[[T], List[Tuple[T, float]]],
    heuristic: Callable[[T], float]
) -> Optional[Tuple[List[T], float]]:
    open_set: List[AStarNode] = []
    closed_set: Set[T] = set()
    g_scores: Dict[T, float] = {start: 0}
    
    start_node = AStarNode(heuristic(start), 0, start)
    heappush(open_set, start_node)
    
    while open_set:
        current = heappop(open_set)
        
        if is_goal(current.state):
            path = []
            node = current
            while node:
                path.append(node.state)
                node = node.parent
            return path[::-1], current.g_score
        
        if current.state in closed_set:
            continue
        closed_set.add(current.state)
        
        for neighbor, cost in get_neighbors(current.state):
            if neighbor in closed_set:
                continue
            
            tentative_g = current.g_score + cost
            
            if neighbor not in g_scores or tentative_g < g_scores[neighbor]:
                g_scores[neighbor] = tentative_g
                f_score = tentative_g + heuristic(neighbor)
                new_node = AStarNode(f_score, tentative_g, neighbor, current)
                heappush(open_set, new_node)
    
    return None

class GridAStar:
    def __init__(self, grid: List[List[int]]):
        self.grid = grid
        self.rows = len(grid)
        self.cols = len(grid[0]) if grid else 0
    
    def is_valid(self, row: int, col: int) -> bool:
        return 0 <= row < self.rows and 0 <= col < self.cols and self.grid[row][col] == 0
    
    def heuristic(self, a: Tuple[int, int], b: Tuple[int, int]) -> float:
        return abs(a[0] - b[0]) + abs(a[1] - b[1])
    
    def find_path(self, start: Tuple[int, int], goal: Tuple[int, int]) -> Optional[List[Tuple[int, int]]]:
        def is_goal(state):
            return state == goal
        
        def get_neighbors(state):
            row, col = state
            neighbors = []
            for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
                nr, nc = row + dr, col + dc
                if self.is_valid(nr, nc):
                    neighbors.append(((nr, nc), 1))
            return neighbors
        
        def h(state):
            return self.heuristic(state, goal)
        
        result = astar_search(start, is_goal, get_neighbors, h)
        return result[0] if result else None

def ida_star(
    start: T,
    is_goal: Callable[[T], bool],
    get_neighbors: Callable[[T], List[Tuple[T, float]]],
    heuristic: Callable[[T], float]
) -> Optional[Tuple[List[T], float]]:
    def search(path: List[T], g: float, bound: float) -> Tuple[float, Optional[List[T]]]:
        node = path[-1]
        f = g + heuristic(node)
        
        if f > bound:
            return f, None
        
        if is_goal(node):
            return f, path[:]
        
        min_bound = math.inf
        
        for neighbor, cost in get_neighbors(node):
            if neighbor not in path:
                path.append(neighbor)
                t, result = search(path, g + cost, bound)
                
                if result is not None:
                    return t, result
                
                if t < min_bound:
                    min_bound = t
                
                path.pop()
        
        return min_bound, None
    
    bound = heuristic(start)
    path = [start]
    
    while True:
        t, result = search(path, 0, bound)
        
        if result is not None:
            return result, t
        
        if t == math.inf:
            return None
        
        bound = t

23.5 启发式函数设计

23.5.1 设计原则

定理 23.8(启发函数比较) 若 $h_1$ 和 $h_2$ 都可采纳且 $h_1(n) \geq h_2(n)$ 对所有 $n$ 成立,则使用 $h_1$ 的A扩展节点数不多于使用 $h_2$ 的A

证明:设 $n$ 被使用 $h_1$ 的A扩展。则 $f_1(n) = g(n) + h_1(n) \leq g(G^)$。由 $h_1 \geq h_2$:

$$f_2(n) = g(n) + h_2(n) \leq g(n) + h_1(n) = f_1(n) \leq g(G^*)$$

故 $n$ 也被使用 $h_2$ 的A*扩展。 ∎

23.5.2 常见启发函数

曼哈顿距离

$$h(n) = |x_n - x_g| + |y_n - y_g|$$

欧几里得距离

$$h(n) = \sqrt{(x_n - x_g)^2 + (y_n - y_g)^2}$$

切比雪夫距离

$$h(n) = \max(|x_n - x_g|, |y_n - y_g|)$$

23.5.3 松弛问题

定义 23.15(松弛问题) 松弛问题放宽原问题的约束,其最优解代价是原问题的可采纳启发值。

定理 23.9 松弛问题的最优解代价是可采纳的启发函数。

证明:松弛问题允许更多解,故其最优代价不超过原问题最优代价。 ∎


23.6 算法比较

23.6.1 复杂度对比

算法时间复杂度空间复杂度完备性最优性
回溯$O(b^d)$$O(d)$
分支限界$O(b^d)$$O(b^d)$
A*$O(b^d)$$O(b^d)$是(可采纳h)
IDA*$O(b^d)$$O(d)$是(可采纳h)

23.6.2 选择策略

场景推荐算法原因
约束满足问题回溯系统搜索解空间
组合优化分支限界保证最优解
路径规划A*高效找到最优路径
内存受限IDA*线性空间复杂度

23.7 本章小结

本章介绍了高级搜索技术:

  1. 状态空间搜索:问题的形式化表示
  2. 回溯算法:系统搜索,剪枝优化
  3. 分支限界:界限函数剪枝,保证最优
  4. A*算法:可采纳启发函数保证最优性
  5. 启发式设计:松弛问题导出可采纳启发

参考文献

  1. Russell, S., & Norvig, P. (2020). Artificial Intelligence: A Modern Approach, 4th ed. Pearson.
  2. Hart, P. E., Nilsson, N. J., & Raphael, B. (1968). A formal basis for the heuristic determination of minimum cost paths. IEEE Transactions on Systems Science and Cybernetics, 4(2), 100-107.
  3. Korf, R. E. (1985). Depth-first iterative-deepening: An optimal admissible tree search. Artificial Intelligence, 27(1), 97-109.
  4. Cormen, T. H., et al. (2009). Introduction to Algorithms, 3rd ed. MIT Press.

下一章:第24章 高级数据结构

Python技术丛书 - 江苏省宿城中等专业学校