
Chapter 23 The Visitor Pattern

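Learning Objectives

After completing this chapter, you will be able to:

  • Understand the core concepts and the formal definition of the Visitor pattern
  • Master the principle and implementation of double dispatch
  • Fully separate operations from the data structures they act on
  • Recognize where the Visitor pattern applies and where it falls short
  • Apply the Visitor pattern to operations over complex object structures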

23.1 Pattern Definition

23.1.1 Core Definition

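The Visitor Pattern is a behavioral design pattern that represents an operation to be performed on the elements of an object structure. It lets you define a new operation without changing the classes of the elements on which it operates.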

23.1.2 Formal Definition

Formally, the Visitor pattern can be described as a six-tuple:

$$\mathcal{V} = \langle E, V, O, \alpha, \beta, \gamma \rangle$$

where:

  • $E = \{e_1, e_2, \ldots, e_n\}$ is the set of element types
  • $V = \{v_1, v_2, \ldots, v_m\}$ is the set of visitors
  • $O$ is the set of result (output) types
  • $\alpha: E \times V \rightarrow O$ is the visit function
  • $\beta: E \rightarrow \text{accept}$ maps each element type to its accept method
  • $\gamma: V \rightarrow \{\text{visit}_e \mid e \in E\}$ maps each visitor to its visit methods, one per element type

Double dispatch

The Visitor pattern implements double dispatch:

$$\text{dispatch}(e, v) = v.\text{visit}_{\text{type}(e)}(e) = e.\text{accept}(v)$$
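To make the equation concrete, here is a minimal sketch. The element classes Circle and Square, the visitors NameVisitor and SidesVisitor, and the helper dispatch() are illustrative names, not part of any library. The string returned depends on the runtime types of both the element and the visitor: Python selects accept() from the element's class, and the call inside accept() is then resolved on the visitor's class.

```python
from abc import ABC, abstractmethod


class Visitor(ABC):
    @abstractmethod
    def visit_circle(self, e: "Circle") -> str: ...

    @abstractmethod
    def visit_square(self, e: "Square") -> str: ...


class Element(ABC):
    @abstractmethod
    def accept(self, v: Visitor) -> str: ...


class Circle(Element):
    def accept(self, v: Visitor) -> str:
        # 1st dispatch has already happened: Circle.accept was chosen from type(self).
        # 2nd dispatch: visit_circle is looked up on type(v).
        return v.visit_circle(self)


class Square(Element):
    def accept(self, v: Visitor) -> str:
        return v.visit_square(self)


class NameVisitor(Visitor):
    def visit_circle(self, e: Circle) -> str:
        return "name:circle"

    def visit_square(self, e: Square) -> str:
        return "name:square"


class SidesVisitor(Visitor):
    def visit_circle(self, e: Circle) -> str:
        return "sides:0"

    def visit_square(self, e: Square) -> str:
        return "sides:4"


def dispatch(e: Element, v: Visitor) -> str:
    """dispatch(e, v) = e.accept(v) = v.visit_<type(e)>(e)"""
    return e.accept(v)


results = [dispatch(e, v) for e in (Circle(), Square()) for v in (NameVisitor(), SidesVisitor())]
print(results)  # ['name:circle', 'sides:0', 'name:square', 'sides:4']
```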

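Type-safety constraint

For every element type $e_i \in E$ and every visitor $v_j \in V$, the visitor must provide a visit method for that element type:

$$\exists\, \text{visit}_{e_i} \in v_j : \text{visit}_{e_i} : e_i \rightarrow O$$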

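Extensibility properties

  1. Adding a new operation: add a new visitor $v_{new}$; no element class has to change.
  2. Adding a new element: add a new element type $e_{new}$; every existing visitor must be modified to implement $\text{visit}_{e_{new}}$.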
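The asymmetry between these two kinds of change can be observed directly with Python's abc module. In this sketch (all class names are hypothetical), editing the visitor interface is simulated by subclassing it with an extra abstract method; a visitor that has not yet implemented that method cannot be instantiated.

```python
from abc import ABC, abstractmethod


class Visitor(ABC):
    """Original interface: one abstract visit method per element type."""

    @abstractmethod
    def visit_leaf(self, e: "Leaf") -> int: ...


class Leaf:
    def __init__(self, value: int) -> None:
        self.value = value

    def accept(self, v: Visitor) -> int:
        return v.visit_leaf(self)


# (1) New operation: just add a visitor; Leaf is not modified.
class DoubleVisitor(Visitor):
    def visit_leaf(self, e: Leaf) -> int:
        return e.value * 2


# (2) New element type Pair: the interface must grow a visit_pair method.
#     Subclassing stands in for editing the interface in place.
class VisitorV2(Visitor):
    @abstractmethod
    def visit_pair(self, e: "Pair") -> int: ...


class Pair:
    def __init__(self, a: Leaf, b: Leaf) -> None:
        self.a, self.b = a, b

    def accept(self, v: VisitorV2) -> int:
        return v.visit_pair(self)


# An existing visitor ported to the new interface but still missing visit_pair...
class DoubleVisitorV2(VisitorV2):
    def visit_leaf(self, e: Leaf) -> int:
        return e.value * 2


# ...has to be completed before it can be used.
class CompleteDoubleVisitorV2(DoubleVisitorV2):
    def visit_pair(self, e: Pair) -> int:
        return e.a.accept(self) + e.b.accept(self)


print(Leaf(21).accept(DoubleVisitor()))                          # 42
print(Pair(Leaf(1), Leaf(2)).accept(CompleteDoubleVisitorV2()))  # 6
try:
    DoubleVisitorV2()
except TypeError as exc:  # abstract visit_pair is still unimplemented
    print("TypeError:", exc)
```

The TypeError comes from abc refusing to instantiate a class with unimplemented abstract methods, which is exactly what forces every visitor to be revisited when an element type is added.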

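Stability condition

The Visitor pattern fits scenarios in which the set of element types $E$ is stable:

$$\frac{d|E|}{dt} \ll \frac{d|V|}{dt}$$

That is, element types change far more slowly than the set of operations.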

23.1.3 The Double Dispatch Mechanism

Single dispatch vs. double dispatch

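| Dispatch kind | Selected by | Example |
|---|---|---|
| Single dispatch | Runtime type of one object (the receiver) | obj.method(arg) |
| Double dispatch | Runtime types of two objects | obj.accept(visitor) |
| Multiple dispatch | Runtime types of several objects | Not supported directly by most mainstream languages |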
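Python itself performs single dispatch: obj.method(arg) is looked up on type(obj) only, and the type of arg plays no part. Multiple dispatch can be emulated by hand, for instance with a table keyed on the runtime types of all arguments. The sketch below contrasts the two; Asteroid, Ship, Greeter and collide are hypothetical names chosen for illustration. The Visitor pattern reaches the same two-type selection without a table by chaining two single dispatches.

```python
from typing import Callable, Dict, Tuple


class Asteroid:
    pass


class Ship:
    pass


class Greeter:
    # Single dispatch: greet() is chosen from type(self) alone;
    # the runtime type of `other` does not influence which method runs.
    def greet(self, other: object) -> str:
        return f"Greeter greets a {type(other).__name__}"


# Multiple dispatch emulated by hand: handlers keyed on the runtime types of both
# arguments. Lookup is by exact type, so subclasses are not matched in this sketch.
Handler = Callable[[object, object], str]
_RULES: Dict[Tuple[type, type], Handler] = {
    (Asteroid, Asteroid): lambda a, b: "asteroids bounce apart",
    (Asteroid, Ship): lambda a, b: "ship destroyed",
    (Ship, Asteroid): lambda a, b: "ship destroyed",
    (Ship, Ship): lambda a, b: "ships dock",
}


def collide(a: object, b: object) -> str:
    handler = _RULES.get((type(a), type(b)))
    if handler is None:
        raise TypeError(f"no rule for ({type(a).__name__}, {type(b).__name__})")
    return handler(a, b)


print(Greeter().greet(Ship()))      # Greeter greets a Ship
print(collide(Ship(), Asteroid()))  # ship destroyed
print(collide(Ship(), Ship()))      # ships dock
```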

How double dispatch works

┌─────────────────────────────────────────────────────────────────┐
│                 Double Dispatch Execution Flow                  │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Client                                                         │
│    │                                                            │
│    │ element.accept(visitor)     ← 1st dispatch: element type   │
│    ▼                                                            │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ ConcreteElementA                                         │   │
│  │                                                          │   │
│  │  def accept(self, visitor):                              │   │
│  │      visitor.visit_concrete_element_a(self)              │   │
│  │           │                                              │   │
│  │           │  ← 2nd dispatch: visitor type                │   │
│  │           ▼                                              │   │
│  │  ┌───────────────────────────────────────────────────┐   │   │
│  │  │ ConcreteVisitor1                                  │   │   │
│  │  │                                                   │   │   │
│  │  │  def visit_concrete_element_a(self, element):     │   │   │
│  │  │      # processing logic                           │   │   │
│  │  └───────────────────────────────────────────────────┘   │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

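23.1.4 Design Principles Embodied

| Principle | How the Visitor pattern embodies it |
|---|---|
| Open/Closed Principle | New operations only require a new visitor; element classes stay unchanged (adding a new element type, however, changes every visitor) |
| Single Responsibility Principle | Operation logic is gathered in the visitors; element classes are responsible only for their data |
| Dependency Inversion Principle | Elements depend on the abstract Visitor interface, not on concrete visitors |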

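23.2 Historical Background and Evolution

23.2.1 Historical Development

| Period | Milestone | Description |
|---|---|---|
| 1980s | Smalltalk | Smalltalk's double dispatch technique |
| 1994 | GoF | Design Patterns catalogs Visitor as one of the classic patterns |
| 1996 | Cyclic Visitor | The cyclic Visitor variant is described |
| 1998 | Acyclic Visitor | The Acyclic Visitor variant is proposed |
| 2000s | Compiler applications | Compilers and IDEs make extensive use of the Visitor pattern |
| 2010s | Functional variants | Functional programming brings new implementation styles |
| 2020s | Pattern matching | Pattern matching in modern languages simplifies Visitor implementations |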

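23.2.2 Theoretical Foundations

The Visitor pattern draws on three bodies of theory:

  1. Multiple dispatch: how calls to polymorphic methods are dispatched
  2. Compiler design: traversal of abstract syntax trees (ASTs)
  3. Type theory: pattern matching over algebraic data types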

23.3 UML Structure Diagram

23.3.1 Standard Structure

┌─────────────────────────────────────────────────────────────────┐
│                         Visitor <<interface>>                   │
├─────────────────────────────────────────────────────────────────┤
│ + visit_concrete_element_a(element: ConcreteElementA): void    │
│ + visit_concrete_element_b(element: ConcreteElementB): void    │
└───────────────────────────────┬─────────────────────────────────┘

                ┌───────────────┴───────────────┐
                │                               │
┌───────────────┴───────────┐   ┌───────────────┴───────────┐
│    ConcreteVisitor1       │   │    ConcreteVisitor2       │
├───────────────────────────┤   ├───────────────────────────┤
│ + visit_concrete_         │   │ + visit_concrete_         │
│   element_a(element)      │   │   element_a(element)      │
│ + visit_concrete_         │   │ + visit_concrete_         │
│   element_b(element)      │   │   element_b(element)      │
└───────────────────────────┘   └───────────────────────────┘

┌─────────────────────────────────────────────────────────────────┐
│                      Element <<interface>>                      │
├─────────────────────────────────────────────────────────────────┤
│ + accept(visitor: Visitor): void                                │
└───────────────────────────────┬─────────────────────────────────┘

                ┌───────────────┴───────────────┐
                │                               │
┌───────────────┴───────────┐   ┌───────────────┴───────────┐
│    ConcreteElementA       │   │    ConcreteElementB       │
├───────────────────────────┤   ├───────────────────────────┤
│ - attribute_a             │   │ - attribute_b             │
├───────────────────────────┤   ├───────────────────────────┤
│ + accept(visitor)         │   │ + accept(visitor)         │
│ + operation_a()           │   │ + operation_b()           │
└───────────────────────────┘   └───────────────────────────┘

┌─────────────────────────────────────────────────────────────────┐
│                      ObjectStructure                            │
├─────────────────────────────────────────────────────────────────┤
│ - elements: List[Element]                                       │
├─────────────────────────────────────────────────────────────────┤
│ + attach(element: Element)                                      │
│ + detach(element: Element)                                      │
│ + accept(visitor: Visitor)                                      │
└─────────────────────────────────────────────────────────────────┘

23.3.2 Collaboration Diagram

┌─────────────────────────────────────────────────────────────────┐
│                          Collaboration                          │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Client                                                         │
│    │                                                            │
│    │ 1. Create a visitor                                        │
│    │    visitor = ConcreteVisitor()                             │
│    │                                                            │
│    │ 2. Traverse the object structure                           │
│    │    for element in object_structure:                        │
│    │        element.accept(visitor)                             │
│    │                                                            │
│    ▼                                                            │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ ObjectStructure                                          │   │
│  │                                                          │   │
│  │  elements = [ElementA, ElementB, ElementA]               │   │
│  │                                                          │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                 │
│  Execution sequence:                                            │
│  ElementA.accept(visitor) → visitor.visit_element_a(a)          │
│  ElementB.accept(visitor) → visitor.visit_element_b(b)          │
│  ElementA.accept(visitor) → visitor.visit_element_a(a)          │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

23.4 Python Implementation

23.4.1 Classic Implementation

python
from abc import ABC, abstractmethod
from typing import Any, List
from dataclasses import dataclass


class Visitor(ABC):
    """Abstract visitor: one visit method per concrete element type."""
    
    @abstractmethod
    def visit_concrete_element_a(self, element: 'ConcreteElementA') -> Any:
        pass
    
    @abstractmethod
    def visit_concrete_element_b(self, element: 'ConcreteElementB') -> Any:
        pass


class Element(ABC):
    """Abstract element."""
    
    @abstractmethod
    def accept(self, visitor: Visitor) -> Any:
        """Accept a visitor (the first dispatch)."""
        pass


@dataclass
class ConcreteElementA(Element):
    """Concrete element A."""
    
    name: str
    value: int
    
    def accept(self, visitor: Visitor) -> Any:
        return visitor.visit_concrete_element_a(self)
    
    def exclusive_method_a(self) -> str:
        """Operation specific to A."""
        return f"ElementA[{self.name}] = {self.value}"


@dataclass
class ConcreteElementB(Element):
    """Concrete element B."""
    
    title: str
    data: List[str]
    
    def accept(self, visitor: Visitor) -> Any:
        return visitor.visit_concrete_element_b(self)
    
    def special_method_b(self) -> str:
        """Operation specific to B."""
        return f"ElementB[{self.title}]: {', '.join(self.data)}"


class ConcreteVisitor1(Visitor):
    """Concrete visitor 1: collects a description of each element."""
    
    def __init__(self):
        self.results: List[str] = []
    
    def visit_concrete_element_a(self, element: ConcreteElementA) -> Any:
        result = f"Visitor1 handled A: {element.name} = {element.value}"
        self.results.append(result)
        return result
    
    def visit_concrete_element_b(self, element: ConcreteElementB) -> Any:
        result = f"Visitor1 handled B: {element.title} ({len(element.data)} items)"
        self.results.append(result)
        return result


class ConcreteVisitor2(Visitor):
    """Concrete visitor 2: computes running totals."""
    
    def __init__(self):
        self.total_value = 0
        self.total_items = 0
    
    def visit_concrete_element_a(self, element: ConcreteElementA) -> Any:
        self.total_value += element.value
        return element.value
    
    def visit_concrete_element_b(self, element: ConcreteElementB) -> Any:
        self.total_items += len(element.data)
        return len(element.data)


class ObjectStructure:
    """Object structure: manages the collection of elements."""
    
    def __init__(self):
        self._elements: List[Element] = []
    
    def attach(self, element: Element) -> None:
        self._elements.append(element)
    
    def detach(self, element: Element) -> None:
        self._elements.remove(element)
    
    def accept(self, visitor: Visitor) -> List[Any]:
        """Let the visitor visit every element; return the per-element results."""
        return [element.accept(visitor) for element in self._elements]


structure = ObjectStructure()
structure.attach(ConcreteElementA("alpha", 10))
structure.attach(ConcreteElementB("beta", ["x", "y", "z"]))
structure.attach(ConcreteElementA("gamma", 20))

visitor1 = ConcreteVisitor1()
structure.accept(visitor1)
print("Visitor1 results:", visitor1.results)

visitor2 = ConcreteVisitor2()
structure.accept(visitor2)
print(f"Visitor2 totals: value={visitor2.total_value}, items={visitor2.total_items}")

23.4.2 Simplifying with singledispatch

python
from functools import singledispatch
from dataclasses import dataclass
from typing import List, Any
import math


@dataclass
class Rectangle:
    width: float
    height: float


@dataclass
class Circle:
    radius: float


@dataclass
class Triangle:
    base: float
    height: float


Shape = Rectangle | Circle | Triangle  # PEP 604 union syntax requires Python 3.10+


@singledispatch
def area(shape) -> float:
    """Compute the area: a single-dispatch generic function."""
    raise NotImplementedError(f"Unknown shape: {type(shape)}")


@area.register
def _(shape: Rectangle) -> float:
    return shape.width * shape.height


@area.register
def _(shape: Circle) -> float:
    return math.pi * shape.radius ** 2


@area.register
def _(shape: Triangle) -> float:
    return 0.5 * shape.base * shape.height


@singledispatch
def perimeter(shape) -> float:
    """Compute the perimeter: another operation."""
    raise NotImplementedError(f"Unknown shape: {type(shape)}")


@perimeter.register
def _(shape: Rectangle) -> float:
    return 2 * (shape.width + shape.height)


@perimeter.register
def _(shape: Circle) -> float:
    return 2 * math.pi * shape.radius


@perimeter.register
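def _(shape: Triangle) -> float:
    # base and height alone do not fix the perimeter; assume an isosceles
    # triangle (apex above the midpoint of the base), as to_svg below also does
    side = math.sqrt((shape.base / 2) ** 2 + shape.height ** 2)
    return shape.base + 2 * side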


@singledispatch
def to_svg(shape) -> str:
    """Export to SVG: a new operation added without touching the shape classes."""
    raise NotImplementedError(f"Unknown shape: {type(shape)}")


@to_svg.register
def _(shape: Rectangle) -> str:
    return f'<rect width="{shape.width}" height="{shape.height}" />'


@to_svg.register
def _(shape: Circle) -> str:
    return f'<circle r="{shape.radius}" />'


@to_svg.register
def _(shape: Triangle) -> str:
    half_base = shape.base / 2
    points = f"0,{shape.height} {shape.base},{shape.height} {half_base},0"
    return f'<polygon points="{points}" />'


shapes: List[Shape] = [
    Rectangle(4, 5),
    Circle(3),
    Triangle(6, 4),
]

print("=== Area ===")
for shape in shapes:
    print(f"{type(shape).__name__}: {area(shape):.2f}")

print("\n=== Perimeter ===")
for shape in shapes:
    print(f"{type(shape).__name__}: {perimeter(shape):.2f}")

print("\n=== SVG export ===")
for shape in shapes:
    print(f"{type(shape).__name__}: {to_svg(shape)}")
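Free functions decorated with singledispatch cannot carry per-visitor state. When a visitor needs configuration or accumulated state, functools.singledispatchmethod (Python 3.8+) dispatches on the first argument after self, so a class keeps the one-method-per-type structure without any accept() method on the elements. A minimal sketch; ScaleVisitor and the shape classes are illustrative.

```python
from dataclasses import dataclass
from functools import singledispatchmethod


@dataclass
class Rectangle:
    width: float
    height: float


@dataclass
class Circle:
    radius: float


class ScaleVisitor:
    """Visitor-style class whose visit() dispatches on the element's type."""

    def __init__(self, factor: float) -> None:
        self.factor = factor  # per-visitor state that a plain function would lack

    @singledispatchmethod
    def visit(self, shape: object) -> object:
        raise NotImplementedError(f"unsupported shape: {type(shape).__name__}")

    # register() reads the type from the annotation of the first parameter after self.
    @visit.register
    def _(self, shape: Rectangle) -> Rectangle:
        return Rectangle(shape.width * self.factor, shape.height * self.factor)

    @visit.register
    def _(self, shape: Circle) -> Circle:
        return Circle(shape.radius * self.factor)


doubler = ScaleVisitor(2)
scaled = [doubler.visit(s) for s in (Rectangle(1, 2), Circle(3))]
print(scaled)  # [Rectangle(width=2, height=4), Circle(radius=6)]
```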

23.4.3 A Type-Safe Visitor

python
from abc import ABC, abstractmethod
from typing import Any, Generic, Protocol, TypeVar, runtime_checkable
from dataclasses import dataclass
from enum import Enum, auto


class ElementType(Enum):
    """Element kinds (used as counter keys)."""
    DOCUMENT = auto()
    PARAGRAPH = auto()
    IMAGE = auto()
    TABLE = auto()


@runtime_checkable
class Visitable(Protocol):
    """Structural protocol: anything that has an accept method."""
    
    def accept(self, visitor: 'Visitor') -> Any:
        ...


TResult = TypeVar('TResult')


class Visitor(ABC, Generic[TResult]):
    """Generic visitor base class; TResult is the return type of every visit method."""
    
    @abstractmethod
    def visit_document(self, element: 'Document') -> TResult:
        pass
    
    @abstractmethod
    def visit_paragraph(self, element: 'Paragraph') -> TResult:
        pass
    
    @abstractmethod
    def visit_image(self, element: 'Image') -> TResult:
        pass
    
    @abstractmethod
    def visit_table(self, element: 'Table') -> TResult:
        pass


@dataclass
class Document:
    """Document element (composite: holds child elements)."""
    title: str
    children: list[Visitable]
    
    def accept(self, visitor: Visitor) -> Any:
        return visitor.visit_document(self)


@dataclass
class Paragraph:
    """Paragraph element."""
    text: str
    
    def accept(self, visitor: Visitor) -> Any:
        return visitor.visit_paragraph(self)


@dataclass
class Image:
    """Image element."""
    url: str
    alt: str
    
    def accept(self, visitor: Visitor) -> Any:
        return visitor.visit_image(self)


@dataclass
class Table:
    """Table element."""
    headers: list[str]
    rows: list[list[str]]
    
    def accept(self, visitor: Visitor) -> Any:
        return visitor.visit_table(self)


class CountVisitor(Visitor[int]):
    """Counting visitor: tallies elements by kind."""
    
    def __init__(self):
        self.counts = {
            ElementType.DOCUMENT: 0,
            ElementType.PARAGRAPH: 0,
            ElementType.IMAGE: 0,
            ElementType.TABLE: 0,
        }
    
    def visit_document(self, element: Document) -> int:
        self.counts[ElementType.DOCUMENT] += 1
        for child in element.children:
            child.accept(self)
        return sum(self.counts.values())
    
    def visit_paragraph(self, element: Paragraph) -> int:
        self.counts[ElementType.PARAGRAPH] += 1
        return 1
    
    def visit_image(self, element: Image) -> int:
        self.counts[ElementType.IMAGE] += 1
        return 1
    
    def visit_table(self, element: Table) -> int:
        self.counts[ElementType.TABLE] += 1
        return 1


class HTMLExportVisitor(Visitor[str]):
    """HTML export visitor."""
    
    def visit_document(self, element: Document) -> str:
        children_html = '\n'.join(
            child.accept(self) for child in element.children
        )
        return f"""<!DOCTYPE html>
<html>
<head><title>{element.title}</title></head>
<body>
<h1>{element.title}</h1>
{children_html}
</body>
</html>"""
    
    def visit_paragraph(self, element: Paragraph) -> str:
        return f"<p>{element.text}</p>"
    
    def visit_image(self, element: Image) -> str:
        return f'<img src="{element.url}" alt="{element.alt}" />'
    
    def visit_table(self, element: Table) -> str:
        headers = ''.join(f'<th>{h}</th>' for h in element.headers)
        rows = ''.join(
            '<tr>' + ''.join(f'<td>{cell}</td>' for cell in row) + '</tr>'
            for row in element.rows
        )
        return f'<table><thead><tr>{headers}</tr></thead><tbody>{rows}</tbody></table>'


doc = Document(
    title="Sample Document",
    children=[
        Paragraph("This is the first paragraph."),
        Image("image.png", "Sample image"),
        Table(
            headers=["Name", "Age"],
            rows=[["Zhang San", "25"], ["Li Si", "30"]]
        ),
        Paragraph("This is the second paragraph."),
    ]
)

counter = CountVisitor()
total = doc.accept(counter)
print(f"Element counts: {counter.counts}, total: {total}")

print("\n=== HTML export ===")
html_exporter = HTMLExportVisitor()
print(doc.accept(html_exporter))

23.4.4 Reflective Visitor

python
from typing import Any
from dataclasses import dataclass


class ReflectiveVisitor:
    """Reflective visitor: dispatches by naming convention to visit_<lowercased class name>."""
    
    def visit(self, element: Any) -> Any:
        """Generic entry point: look up and call the matching visit_* method."""
        method_name = f"visit_{self._get_type_name(element)}"
        visitor_method = getattr(self, method_name, self.default_visit)
        return visitor_method(element)
    
    def _get_type_name(self, element: Any) -> str:
        """Return the lowercased class name of the element."""
        return type(element).__name__.lower()
    
    def default_visit(self, element: Any) -> Any:
        """Fallback used when no matching visit_* method exists."""
        raise NotImplementedError(
            f"No visit method implemented for {type(element).__name__}"
        )


@dataclass
class Product:
    id: str
    name: str
    price: float


@dataclass
class Customer:
    id: str
    name: str
    email: str


@dataclass
class Order:
    id: str
    customer: Customer
    products: list[Product]


class ReportVisitor(ReflectiveVisitor):
    """Report visitor: builds a plain-text order report."""
    
    def __init__(self):
        self.report_lines: list[str] = []
    
    def visit_product(self, element: Product) -> Any:
        self.report_lines.append(
            f"  Product: {element.name} (¥{element.price:.2f})"
        )
    
    def visit_customer(self, element: Customer) -> Any:
        self.report_lines.append(
            f"Customer: {element.name} <{element.email}>"
        )
    
    def visit_order(self, element: Order) -> Any:
        self.report_lines.append(f"Order #{element.id}")
        # Customer and Product have no visit() method, so recurse through self.visit().
        self.visit(element.customer)
        self.report_lines.append("Products:")
        for product in element.products:
            self.visit(product)
        total = sum(p.price for p in element.products)
        self.report_lines.append(f"Total: ¥{total:.2f}")
    
    def get_report(self) -> str:
        return '\n'.join(self.report_lines)


class JSONVisitor(ReflectiveVisitor):
    """JSON serialization visitor: converts elements to plain dicts."""
    
    def visit_product(self, element: Product) -> dict:
        return {
            'type': 'product',
            'id': element.id,
            'name': element.name,
            'price': element.price
        }
    
    def visit_customer(self, element: Customer) -> dict:
        return {
            'type': 'customer',
            'id': element.id,
            'name': element.name,
            'email': element.email
        }
    
    def visit_order(self, element: Order) -> dict:
        return {
            'type': 'order',
            'id': element.id,
            'customer': self.visit(element.customer),
            'products': [self.visit(p) for p in element.products]
        }


def make_visitable(cls):
    """Class decorator: adds visit(visitor), which delegates to visitor.visit(self)."""
    def visit(self, visitor):
        return visitor.visit(self)
    cls.visit = visit
    return cls


@make_visitable
@dataclass
class Book:
    title: str
    author: str
    isbn: str


class SummaryVisitor(ReflectiveVisitor):
    def visit_book(self, element: Book) -> str:
        return f"'{element.title}' by {element.author}"


order = Order(
    id="ORD001",
    customer=Customer("C001", "Zhang San", "zhang@example.com"),
    products=[
        Product("P001", "Python Programming", 89.0),
        Product("P002", "Design Patterns", 99.0),
    ]
)

# Product, Customer and Order are not decorated with make_visitable,
# so the visitor is called directly instead of order.visit(...).
report_visitor = ReportVisitor()
report_visitor.visit(order)
print(report_visitor.get_report())

print("\n=== JSON serialization ===")
import json
json_visitor = JSONVisitor()
print(json.dumps(json_visitor.visit(order), ensure_ascii=False, indent=2))

book = Book("Design Patterns", "GoF", "978-0201633610")
summary = SummaryVisitor()
print(f"\nBook summary: {book.visit(summary)}")
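One limitation of _get_type_name above is that it only lowercases the class name: a class named OrderItem maps to visit_orderitem, and a subclass of Product without its own method falls through to default_visit. The sketch below is one possible refinement, not part of the original design: it converts CamelCase to snake_case and walks the class's method resolution order so that a visit method written for a base class also handles its subclasses. MroVisitor, snake, LineItem, DiscountedLineItem and PriceVisitor are hypothetical names.

```python
import re
from typing import Any

# Zero-width match in front of every capital letter except the first character.
_CAPITAL = re.compile(r"(?<!^)(?=[A-Z])")


def snake(name: str) -> str:
    """OrderItem -> order_item (simple heuristic: 'HTTPRequest' becomes 'h_t_t_p_request')."""
    return _CAPITAL.sub("_", name).lower()


class MroVisitor:
    """Reflective visitor that also falls back to visit methods of base classes."""

    def visit(self, element: Any) -> Any:
        # Most specific class first, then its bases in method resolution order.
        for cls in type(element).__mro__:
            method = getattr(self, f"visit_{snake(cls.__name__)}", None)
            if method is not None:
                return method(element)
        raise NotImplementedError(f"no visit method for {type(element).__name__}")


class LineItem:
    def __init__(self, price: float) -> None:
        self.price = price


class DiscountedLineItem(LineItem):
    """No dedicated visit method: handled by visit_line_item via the MRO."""


class PriceVisitor(MroVisitor):
    def visit_line_item(self, element: LineItem) -> float:
        return element.price


print(snake("DiscountedLineItem"))                    # discounted_line_item
print(PriceVisitor().visit(DiscountedLineItem(9.5)))  # 9.5
```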

23.5 Enterprise Application Examples

23.5.1 Compiler AST Traversal

python
from abc import ABC, abstractmethod
from typing import Any, List, Optional
from dataclasses import dataclass, field


class ASTNode(ABC):
    """Base class for abstract syntax tree nodes."""
    
    @abstractmethod
    def accept(self, visitor: 'ASTVisitor') -> Any:
        pass


@dataclass
class Program(ASTNode):
    """Program node: a sequence of top-level statements."""
    statements: List[ASTNode]
    
    def accept(self, visitor: 'ASTVisitor') -> Any:
        return visitor.visit_program(self)


@dataclass
class FunctionDef(ASTNode):
    """Function definition."""
    name: str
    params: List[str]
    body: List[ASTNode]
    
    def accept(self, visitor: 'ASTVisitor') -> Any:
        return visitor.visit_function_def(self)


@dataclass
class Return(ASTNode):
    """Return statement."""
    value: Optional[ASTNode]
    
    def accept(self, visitor: 'ASTVisitor') -> Any:
        return visitor.visit_return(self)


@dataclass
class BinaryOp(ASTNode):
    """Binary operation."""
    left: ASTNode
    op: str
    right: ASTNode
    
    def accept(self, visitor: 'ASTVisitor') -> Any:
        return visitor.visit_binary_op(self)


@dataclass
class Number(ASTNode):
    """Numeric literal."""
    value: float
    
    def accept(self, visitor: 'ASTVisitor') -> Any:
        return visitor.visit_number(self)


@dataclass
class Variable(ASTNode):
    """Variable reference."""
    name: str
    
    def accept(self, visitor: 'ASTVisitor') -> Any:
        return visitor.visit_variable(self)


@dataclass
class FunctionCall(ASTNode):
    """Function call."""
    name: str
    args: List[ASTNode]
    
    def accept(self, visitor: 'ASTVisitor') -> Any:
        return visitor.visit_function_call(self)


class ASTVisitor(ABC):
    """Base class for AST visitors."""
    
    @abstractmethod
    def visit_program(self, node: Program) -> Any:
        pass
    
    @abstractmethod
    def visit_function_def(self, node: FunctionDef) -> Any:
        pass
    
    @abstractmethod
    def visit_return(self, node: Return) -> Any:
        pass
    
    @abstractmethod
    def visit_binary_op(self, node: BinaryOp) -> Any:
        pass
    
    @abstractmethod
    def visit_number(self, node: Number) -> Any:
        pass
    
    @abstractmethod
    def visit_variable(self, node: Variable) -> Any:
        pass
    
    @abstractmethod
    def visit_function_call(self, node: FunctionCall) -> Any:
        pass


class Interpreter(ASTVisitor):
    """Interpreter visitor: evaluates the AST."""
    
    def __init__(self):
        self.globals: dict[str, Any] = {}
        self.functions: dict[str, FunctionDef] = {}
        self.call_stack: list[dict[str, Any]] = [{}]
    
    @property
    def env(self) -> dict[str, Any]:
        return self.call_stack[-1]
    
    def visit_program(self, node: Program) -> Any:
        result = None
        for stmt in node.statements:
            result = stmt.accept(self)
        return result
    
    def visit_function_def(self, node: FunctionDef) -> Any:
        self.functions[node.name] = node
        return None
    
    def visit_return(self, node: Return) -> Any:
        if node.value is not None:
            return node.value.accept(self)
        return None
    
    def visit_binary_op(self, node: BinaryOp) -> Any:
        left = node.left.accept(self)
        right = node.right.accept(self)
        
        ops = {
            '+': lambda a, b: a + b,
            '-': lambda a, b: a - b,
            '*': lambda a, b: a * b,
            '/': lambda a, b: a / b,
            '==': lambda a, b: a == b,
            '<': lambda a, b: a < b,
            '>': lambda a, b: a > b,
        }
        return ops[node.op](left, right)
    
    def visit_number(self, node: Number) -> Any:
        return node.value
    
    def visit_variable(self, node: Variable) -> Any:
        if node.name in self.env:
            return self.env[node.name]
        if node.name in self.globals:
            return self.globals[node.name]
        raise NameError(f"Undefined variable: {node.name}")
    
    def visit_function_call(self, node: FunctionCall) -> Any:
        if node.name not in self.functions:
            raise NameError(f"Undefined function: {node.name}")
        
        func = self.functions[node.name]
        
        if len(node.args) != len(func.params):
            raise TypeError(
                f"Function {node.name} expects {len(func.params)} argument(s), "
                f"got {len(node.args)}"
            )
        
        new_env = {}
        for param, arg in zip(func.params, node.args):
            new_env[param] = arg.accept(self)
        
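        self.call_stack.append(new_env)
        try:
            # Simplification: every statement in the body runs and the value of
            # the last one is returned; Return does not stop execution early.
            result = None
            for stmt in func.body:
                result = stmt.accept(self)
        finally:
            # Pop the frame even if the body raises.
            self.call_stack.pop()
        
        return result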


class PrettyPrinter(ASTVisitor):
    """Pretty-printer visitor: regenerates Python-like source code."""
    
    def __init__(self, indent_size: int = 4):
        self.indent_size = indent_size
        self.indent_level = 0
        self.lines: List[str] = []
    
    def _indent(self) -> str:
        return ' ' * (self.indent_level * self.indent_size)
    
    def visit_program(self, node: Program) -> str:
        self.lines = []
        for stmt in node.statements:
            stmt.accept(self)
        return '\n'.join(self.lines)
    
    def visit_function_def(self, node: FunctionDef) -> str:
        params = ', '.join(node.params)
        self.lines.append(f"{self._indent()}def {node.name}({params}):")
        self.indent_level += 1
        for stmt in node.body:
            stmt.accept(self)
        self.indent_level -= 1
        return ''
    
    def visit_return(self, node: Return) -> str:
        if node.value is not None:
            value = node.value.accept(self)
            self.lines.append(f"{self._indent()}return {value}")
        else:
            self.lines.append(f"{self._indent()}return")
        return ''
    
    def visit_binary_op(self, node: BinaryOp) -> str:
        left = node.left.accept(self)
        right = node.right.accept(self)
        return f"({left} {node.op} {right})"
    
    def visit_number(self, node: Number) -> str:
        return str(node.value)
    
    def visit_variable(self, node: Variable) -> str:
        return node.name
    
    def visit_function_call(self, node: FunctionCall) -> str:
        args = ', '.join(arg.accept(self) for arg in node.args)
        return f"{node.name}({args})"


ast = Program([
    FunctionDef('add', ['a', 'b'], [
        Return(BinaryOp(Variable('a'), '+', Variable('b')))
    ]),
    FunctionDef('square', ['x'], [
        Return(BinaryOp(Variable('x'), '*', Variable('x')))
    ]),
    FunctionDef('main', [], [
        Return(FunctionCall('add', [
            FunctionCall('square', [Number(3)]),
            Number(4)
        ]))
    ])
])

printer = PrettyPrinter()
print("=== Source code ===")
print(ast.accept(printer))  # PrettyPrinter has no visit(); dispatch through accept()

interpreter = Interpreter()
ast.accept(interpreter)     # registers the function definitions
result = interpreter.visit_function_call(FunctionCall('main', []))
print("\n=== Execution result ===")
print(f"main() = {result}")
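Python's standard library ships a reflective visitor for its own AST: ast.NodeVisitor dispatches to visit_<ClassName> methods and falls back to generic_visit, and ast.NodeTransformer additionally replaces each node with whatever the visit method returns. The sketch below applies both to source code equivalent to the example program above. CallCounter and ConstantFolder are illustrative names, and the folder deliberately handles only + and *. It assumes Python 3.9+ for ast.unparse.

```python
import ast


class CallCounter(ast.NodeVisitor):
    """Counts calls by function name; NodeVisitor dispatches on visit_<NodeClassName>."""

    def __init__(self) -> None:
        self.calls: dict[str, int] = {}

    def visit_Call(self, node: ast.Call) -> None:
        if isinstance(node.func, ast.Name):
            self.calls[node.func.id] = self.calls.get(node.func.id, 0) + 1
        # Defining visit_Call replaces the default traversal for this node,
        # so descend explicitly to reach calls nested in the arguments.
        self.generic_visit(node)


class ConstantFolder(ast.NodeTransformer):
    """Replaces `constant + constant` and `constant * constant` with the result."""

    def visit_BinOp(self, node: ast.BinOp) -> ast.AST:
        self.generic_visit(node)  # fold the operands first (post-order)
        left, right = node.left, node.right
        if isinstance(left, ast.Constant) and isinstance(right, ast.Constant):
            if isinstance(node.op, ast.Add):
                return ast.copy_location(ast.Constant(left.value + right.value), node)
            if isinstance(node.op, ast.Mult):
                return ast.copy_location(ast.Constant(left.value * right.value), node)
        return node


SOURCE = """
def add(a, b):
    return a + b

def square(x):
    return x * x

def main():
    return add(square(3), 4) + 2 * 5
"""

tree = ast.parse(SOURCE)

counter = CallCounter()
counter.visit(tree)
print(counter.calls)  # {'add': 1, 'square': 1}

folded = ast.fix_missing_locations(ConstantFolder().visit(tree))
print(ast.unparse(folded.body[2]))
```

The last print should show main() with 2 * 5 folded into 10, while add(square(3), 4) is left alone because its operands are calls rather than constants.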

23.5.2 Database Query Builder

python
from abc import ABC, abstractmethod
from typing import Any, List, Optional
from dataclasses import dataclass
from enum import Enum, auto


class JoinType(Enum):
    INNER = auto()
    LEFT = auto()
    RIGHT = auto()


@dataclass
class SQLNode(ABC):
    """Base class for SQL AST nodes."""
    
    @abstractmethod
    def accept(self, visitor: 'SQLVisitor') -> str:
        pass


@dataclass
class Select(SQLNode):
    """SELECT statement."""
    columns: List[SQLNode]
    from_clause: SQLNode
    where: Optional[SQLNode] = None
    order_by: Optional[SQLNode] = None
    limit: Optional[int] = None
    
    def accept(self, visitor: 'SQLVisitor') -> str:
        return visitor.visit_select(self)


@dataclass
class Column(SQLNode):
    """Column reference."""
    name: str
    table: Optional[str] = None
    alias: Optional[str] = None
    
    def accept(self, visitor: 'SQLVisitor') -> str:
        return visitor.visit_column(self)


@dataclass
class Table(SQLNode):
    """Table reference."""
    name: str
    alias: Optional[str] = None
    
    def accept(self, visitor: 'SQLVisitor') -> str:
        return visitor.visit_table(self)


@dataclass
class Join(SQLNode):
    """JOIN clause."""
    left: SQLNode
    right: SQLNode
    condition: SQLNode
    join_type: JoinType = JoinType.INNER
    
    def accept(self, visitor: 'SQLVisitor') -> str:
        return visitor.visit_join(self)


@dataclass
class BinaryExpr(SQLNode):
    """Binary expression."""
    left: SQLNode
    op: str
    right: SQLNode
    
    def accept(self, visitor: 'SQLVisitor') -> str:
        return visitor.visit_binary_expr(self)


@dataclass
class Literal(SQLNode):
    """Literal value."""
    value: Any
    
    def accept(self, visitor: 'SQLVisitor') -> str:
        return visitor.visit_literal(self)


@dataclass
class FunctionCall(SQLNode):
    """Function call."""
    name: str
    args: List[SQLNode]
    
    def accept(self, visitor: 'SQLVisitor') -> str:
        return visitor.visit_function_call(self)


@dataclass
class OrderBy(SQLNode):
    """ORDER BY clause."""
    columns: List[SQLNode]
    descending: bool = False
    
    def accept(self, visitor: 'SQLVisitor') -> str:
        return visitor.visit_order_by(self)


class SQLVisitor(ABC):
    """Base class for SQL visitors (one subclass per SQL dialect)."""
    
    @abstractmethod
    def visit_select(self, node: Select) -> str:
        pass
    
    @abstractmethod
    def visit_column(self, node: Column) -> str:
        pass
    
    @abstractmethod
    def visit_table(self, node: Table) -> str:
        pass
    
    @abstractmethod
    def visit_join(self, node: Join) -> str:
        pass
    
    @abstractmethod
    def visit_binary_expr(self, node: BinaryExpr) -> str:
        pass
    
    @abstractmethod
    def visit_literal(self, node: Literal) -> str:
        pass
    
    @abstractmethod
    def visit_function_call(self, node: FunctionCall) -> str:
        pass
    
    @abstractmethod
    def visit_order_by(self, node: OrderBy) -> str:
        pass


class MySQLVisitor(SQLVisitor):
    """MySQL dialect visitor."""
    
    def visit_select(self, node: Select) -> str:
        columns = ', '.join(col.accept(self) for col in node.columns)
        from_clause = node.from_clause.accept(self)
        
        sql = f"SELECT {columns}\nFROM {from_clause}"
        
        if node.where:
            sql += f"\nWHERE {node.where.accept(self)}"
        
        if node.order_by:
            sql += f"\nORDER BY {node.order_by.accept(self)}"
        
        if node.limit is not None:
            sql += f"\nLIMIT {node.limit}"
        
        return sql
    
    def visit_column(self, node: Column) -> str:
        if node.table:
            col = f"{node.table}.{node.name}"
        else:
            col = node.name
        
        if node.alias:
            col += f" AS {node.alias}"
        
        return col
    
    def visit_table(self, node: Table) -> str:
        if node.alias:
            return f"{node.name} {node.alias}"
        return node.name
    
    def visit_join(self, node: Join) -> str:
        join_types = {
            JoinType.INNER: "INNER JOIN",
            JoinType.LEFT: "LEFT JOIN",
            JoinType.RIGHT: "RIGHT JOIN",
        }
        
        left = node.left.accept(self)
        right = node.right.accept(self)
        condition = node.condition.accept(self)
        
        return f"{left}\n{join_types[node.join_type]} {right} ON {condition}"
    
    def visit_binary_expr(self, node: BinaryExpr) -> str:
        left = node.left.accept(self)
        right = node.right.accept(self)
        return f"{left} {node.op} {right}"
    
    def visit_literal(self, node: Literal) -> str:
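        if node.value is None:
            return "NULL"
        if isinstance(node.value, str):
            # Double embedded single quotes so the literal stays well-formed.
            # This is not a substitute for parameterized queries.
            return "'" + node.value.replace("'", "''") + "'"
        return str(node.value)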
    
    def visit_function_call(self, node: FunctionCall) -> str:
        args = ', '.join(arg.accept(self) for arg in node.args)
        return f"{node.name}({args})"
    
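    def visit_order_by(self, node: OrderBy) -> str:
        # Apply the direction to every column; appending " DESC" once to the
        # joined list would only make the last column descending
        direction = " DESC" if node.descending else ""
        return ', '.join(f"{col.accept(self)}{direction}" for col in node.columns)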


class PostgreSQLVisitor(SQLVisitor):
    """PostgreSQL dialect visitor"""
    
    def visit_select(self, node: Select) -> str:
        columns = ', '.join(col.accept(self) for col in node.columns)
        from_clause = node.from_clause.accept(self)
        
        sql = f"SELECT {columns}\nFROM {from_clause}"
        
        if node.where:
            sql += f"\nWHERE {node.where.accept(self)}"
        
        if node.order_by:
            sql += f"\nORDER BY {node.order_by.accept(self)}"
        
        # `is not None` so that LIMIT 0 is still emitted
        if node.limit is not None:
            sql += f"\nLIMIT {node.limit}"
        
        return sql
    
    def visit_column(self, node: Column) -> str:
        if node.table:
            col = f'"{node.table}"."{node.name}"'
        else:
            col = f'"{node.name}"'
        
        if node.alias:
            col += f' AS "{node.alias}"'
        
        return col
    
    def visit_table(self, node: Table) -> str:
        if node.alias:
            return f'"{node.name}" AS "{node.alias}"'
        return f'"{node.name}"'
    
    def visit_join(self, node: Join) -> str:
        join_types = {
            JoinType.INNER: "INNER JOIN",
            JoinType.LEFT: "LEFT JOIN",
            JoinType.RIGHT: "RIGHT JOIN",
        }
        
        left = node.left.accept(self)
        right = node.right.accept(self)
        condition = node.condition.accept(self)
        
        return f"{left}\n{join_types[node.join_type]} {right} ON {condition}"
    
    def visit_binary_expr(self, node: BinaryExpr) -> str:
        left = node.left.accept(self)
        right = node.right.accept(self)
        return f"{left} {node.op} {right}"
    
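    def visit_literal(self, node: Literal) -> str:
        if node.value is None:
            return "NULL"
        if isinstance(node.value, bool):
            return "TRUE" if node.value else "FALSE"
        if isinstance(node.value, str):
            # With standard_conforming_strings on (the default since PostgreSQL 9.1),
            # only single quotes need escaping, by doubling them.
            # Production code should prefer bound query parameters.
            escaped = node.value.replace("'", "''")
            return f"'{escaped}'"
        return str(node.value)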
    
    def visit_function_call(self, node: FunctionCall) -> str:
        args = ', '.join(arg.accept(self) for arg in node.args)
        return f"{node.name}({args})"
    
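    def visit_order_by(self, node: OrderBy) -> str:
        # Apply the direction to every column; appending " DESC" once to the
        # joined list would only make the last column descending
        direction = " DESC" if node.descending else ""
        return ', '.join(f"{col.accept(self)}{direction}" for col in node.columns)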


query = Select(
    columns=[
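        Column('id', 'u'),  # qualified: both users and orders have an id column
        Column('name', 'u', 'user_name'),
        # COUNT alongside plain columns needs GROUP BY u.id, u.name on a real database;
        # this sample AST has no GROUP BY node, so the output only illustrates the syntax
        FunctionCall('COUNT', [Column('id', 'o')]),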
    ],
    from_clause=Join(
        left=Table('users', 'u'),
        right=Table('orders', 'o'),
        condition=BinaryExpr(
            Column('id', 'u'), '=', Column('user_id', 'o')
        ),
        join_type=JoinType.LEFT
    ),
    where=BinaryExpr(
        Column('status', 'o'), '=', Literal('active')
    ),
    order_by=OrderBy([Column('name', 'u')]),
    limit=10
)

print("=== MySQL ===")
print(query.accept(MySQLVisitor()))

print("\n=== PostgreSQL ===")
print(query.accept(PostgreSQLVisitor()))
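The two dialect visitors above repeat almost every method and differ mainly in how identifiers are quoted. One way to remove that duplication is to move the shared generation logic into a base class and leave a single dialect hook, in the style of Template Method. The sketch below rests on some assumptions. It uses a pared-down hypothetical AST (`Col`, `Tbl`, `Query`) rather than the chapter's `Column`/`Table`/`Select`, and `quote` is an illustrative hook name. It also quotes MySQL identifiers with backticks, which the chapter's `MySQLVisitor` does not do.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional


# Hypothetical, pared-down AST nodes used only in this sketch
@dataclass
class Col:
    name: str
    table: Optional[str] = None

    def accept(self, visitor: "BaseSQLVisitor") -> str:
        return visitor.visit_col(self)


@dataclass
class Tbl:
    name: str

    def accept(self, visitor: "BaseSQLVisitor") -> str:
        return visitor.visit_tbl(self)


@dataclass
class Query:
    columns: List[Col]
    table: Tbl

    def accept(self, visitor: "BaseSQLVisitor") -> str:
        return visitor.visit_query(self)


class BaseSQLVisitor(ABC):
    """Shared SQL generation; subclasses only supply identifier quoting."""

    @abstractmethod
    def quote(self, identifier: str) -> str:
        """Hook: quote one identifier in this dialect."""

    def visit_col(self, node: Col) -> str:
        name = self.quote(node.name)
        return f"{self.quote(node.table)}.{name}" if node.table else name

    def visit_tbl(self, node: Tbl) -> str:
        return self.quote(node.name)

    def visit_query(self, node: Query) -> str:
        cols = ", ".join(col.accept(self) for col in node.columns)
        return f"SELECT {cols}\nFROM {node.table.accept(self)}"


class MySQLDialect(BaseSQLVisitor):
    def quote(self, identifier: str) -> str:
        # MySQL quotes identifiers with backticks; an embedded backtick is doubled
        return "`" + identifier.replace("`", "``") + "`"


class PostgreSQLDialect(BaseSQLVisitor):
    def quote(self, identifier: str) -> str:
        # PostgreSQL follows standard SQL: double quotes, embedded quotes doubled
        return '"' + identifier.replace('"', '""') + '"'


sample = Query([Col("id", "u"), Col("name")], Tbl("users"))
print(sample.accept(MySQLDialect()))
print(sample.accept(PostgreSQLDialect()))
```

With this split, a new dialect costs one small subclass. A new node type still requires a new `visit_*` method, but only in the shared base.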

23.5.3 File System Traversal

python
from abc import ABC, abstractmethod
from typing import Any, List, Optional
from dataclasses import dataclass, field
from datetime import datetime


@dataclass
class FileSystemNode(ABC):
    """Base class for file system nodes"""
    name: str
    path: str
    modified_time: datetime
    
    @abstractmethod
    def accept(self, visitor: 'FileSystemVisitor') -> Any:
        pass


@dataclass
class File(FileSystemNode):
    """File node"""
    size: int
    extension: str
    
    def accept(self, visitor: 'FileSystemVisitor') -> Any:
        return visitor.visit_file(self)


@dataclass
class Directory(FileSystemNode):
    """Directory node"""
    children: List[FileSystemNode] = field(default_factory=list)
    
    def accept(self, visitor: 'FileSystemVisitor') -> Any:
        return visitor.visit_directory(self)
    
    def add_child(self, child: FileSystemNode) -> None:
        self.children.append(child)


@dataclass
class Symlink(FileSystemNode):
    """Symbolic link node"""
    target: str
    
    def accept(self, visitor: 'FileSystemVisitor') -> Any:
        return visitor.visit_symlink(self)


class FileSystemVisitor(ABC):
    """Base class for file system visitors"""
    
    @abstractmethod
    def visit_file(self, node: File) -> Any:
        pass
    
    @abstractmethod
    def visit_directory(self, node: Directory) -> Any:
        pass
    
    @abstractmethod
    def visit_symlink(self, node: Symlink) -> Any:
        pass


class SizeCalculator(FileSystemVisitor):
    """Visitor that computes total size and file/directory counts"""
    
    def __init__(self):
        self.total_size = 0
        self.file_count = 0
        self.dir_count = 0
    
    def visit_file(self, node: File) -> int:
        self.total_size += node.size
        self.file_count += 1
        return node.size
    
    def visit_directory(self, node: Directory) -> int:
        self.dir_count += 1
        dir_size = 0
        for child in node.children:
            dir_size += child.accept(self)
        return dir_size
    
    def visit_symlink(self, node: Symlink) -> int:
        return 0
    
    def get_summary(self) -> dict:
        return {
            'total_size': self.total_size,
            'file_count': self.file_count,
            'dir_count': self.dir_count,
            'size_formatted': self._format_size(self.total_size)
        }
    
    def _format_size(self, size: float) -> str:
        for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
            if size < 1024:
                return f"{size:.2f} {unit}"
            size /= 1024
        return f"{size:.2f} PB"


class SearchVisitor(FileSystemVisitor):
    """Search visitor"""
    
    def __init__(
        self,
        name_pattern: Optional[str] = None,
        extension: Optional[str] = None,
        min_size: Optional[int] = None,
        max_size: Optional[int] = None
    ):
        self.name_pattern = name_pattern
        self.extension = extension
        self.min_size = min_size
        self.max_size = max_size
        self.results: List[FileSystemNode] = []
    
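    def _matches(self, node: FileSystemNode) -> bool:
        if self.name_pattern and self.name_pattern.lower() not in node.name.lower():
            return False

        # Extension and size criteria only apply to files: when any of them is
        # given, directories and symlinks must not match (otherwise a search for
        # .py files would also return every directory)
        has_file_criteria = (
            self.extension is not None
            or self.min_size is not None
            or self.max_size is not None
        )
        if not isinstance(node, File):
            return not has_file_criteria

        if self.extension is not None and node.extension.lower() != self.extension.lower():
            return False
        # `is not None` so that a size bound of 0 is still honored
        if self.min_size is not None and node.size < self.min_size:
            return False
        if self.max_size is not None and node.size > self.max_size:
            return False
        return True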
    
    def visit_file(self, node: File) -> List[FileSystemNode]:
        if self._matches(node):
            self.results.append(node)
        return self.results
    
    def visit_directory(self, node: Directory) -> List[FileSystemNode]:
        if self._matches(node):
            self.results.append(node)
        for child in node.children:
            child.accept(self)
        return self.results
    
    def visit_symlink(self, node: Symlink) -> List[FileSystemNode]:
        if self._matches(node):
            self.results.append(node)
        return self.results


class TreePrinter(FileSystemVisitor):
    """Visitor that prints the tree with box-drawing connectors"""

    def __init__(self, show_size: bool = True):
        self.show_size = show_size
        self.lines: List[str] = []
        self.depth = 0       # 0 while the root itself is being printed
        self.is_last = True  # whether the current node is its parent's last child
        # One flag per ancestor below the root: True if that ancestor was a last child
        self.stack: List[bool] = []

    def visit_file(self, node: File) -> str:
        prefix = self._get_prefix()
        size_str = f" ({self._format_size(node.size)})" if self.show_size else ""
        line = f"{prefix}{node.name}{size_str}"
        self.lines.append(line)
        return line

    def visit_directory(self, node: Directory) -> str:
        prefix = self._get_prefix()
        line = f"{prefix}{node.name}/"
        self.lines.append(line)

        # The root gets no connector, so only non-root directories are pushed
        if self.depth > 0:
            self.stack.append(self.is_last)
        self.depth += 1

        for i, child in enumerate(node.children):
            self.is_last = (i == len(node.children) - 1)
            child.accept(self)

        self.depth -= 1
        if self.depth > 0:
            self.stack.pop()
        return line

    def visit_symlink(self, node: Symlink) -> str:
        prefix = self._get_prefix()
        line = f"{prefix}{node.name} -> {node.target}"
        self.lines.append(line)
        return line

    def _get_prefix(self) -> str:
        if self.depth == 0:
            return ""
        # Draw a vertical guide under each ancestor that still has siblings below it
        indent = "".join("    " if last else "│   " for last in self.stack)
        connector = "└── " if self.is_last else "├── "
        return indent + connector
    
    def _format_size(self, size: int) -> str:
        if size < 1024:
            return f"{size}B"
        elif size < 1024 * 1024:
            return f"{size/1024:.1f}KB"
        else:
            return f"{size/1024/1024:.1f}MB"
    
    def get_tree(self) -> str:
        return '\n'.join(self.lines)


root = Directory("project", "/project", datetime.now())
src = Directory("src", "/project/src", datetime.now())
tests = Directory("tests", "/project/tests", datetime.now())

src.add_child(File("main.py", "/project/src/main.py", datetime.now(), 2048, "py"))
src.add_child(File("utils.py", "/project/src/utils.py", datetime.now(), 1024, "py"))
src.add_child(File("config.json", "/project/src/config.json", datetime.now(), 512, "json"))

tests.add_child(File("test_main.py", "/project/tests/test_main.py", datetime.now(), 1536, "py"))
tests.add_child(File("test_utils.py", "/project/tests/test_utils.py", datetime.now(), 768, "py"))

root.add_child(src)
root.add_child(tests)
root.add_child(File("README.md", "/project/README.md", datetime.now(), 256, "md"))
root.add_child(Symlink("link_to_src", "/project/link_to_src", datetime.now(), "/project/src"))

print("=== Directory tree ===")
printer = TreePrinter()
root.accept(printer)
print(printer.get_tree())

print("\n=== Size summary ===")
size_calc = SizeCalculator()
root.accept(size_calc)
print(size_calc.get_summary())

print("\n=== Search for .py files ===")
searcher = SearchVisitor(extension="py")
root.accept(searcher)
for result in searcher.results:
    print(f"  {result.path}")
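The example above shows the Visitor pattern's main selling point: each new operation is one more visitor class, and the node classes stay untouched. As a further illustration, here is a minimal sketch of one more operation, total bytes per file extension. To keep it runnable on its own, it restates simplified node classes with fewer fields than the chapter's `File`/`Directory`/`Symlink`. `ExtensionStats` is a hypothetical name.

```python
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List


# Simplified stand-ins for the chapter's File/Directory/Symlink nodes
@dataclass
class File:
    name: str
    size: int
    extension: str

    def accept(self, visitor):
        return visitor.visit_file(self)


@dataclass
class Directory:
    name: str
    children: List[object] = field(default_factory=list)

    def accept(self, visitor):
        return visitor.visit_directory(self)


@dataclass
class Symlink:
    name: str
    target: str

    def accept(self, visitor):
        return visitor.visit_symlink(self)


class ExtensionStats:
    """New operation: total bytes per file extension; no node class changes."""

    def __init__(self) -> None:
        self.bytes_by_ext: Dict[str, int] = defaultdict(int)

    def visit_file(self, node: File) -> None:
        self.bytes_by_ext[node.extension] += node.size

    def visit_directory(self, node: Directory) -> None:
        for child in node.children:
            child.accept(self)

    def visit_symlink(self, node: Symlink) -> None:
        pass  # links are not followed, which also rules out cycles


root = Directory("project", [
    Directory("src", [File("main.py", 2048, "py"), File("config.json", 512, "json")]),
    File("README.md", 256, "md"),
    Symlink("link_to_src", "/project/src"),
])
stats = ExtensionStats()
root.accept(stats)
print(dict(stats.bytes_by_ext))
```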

23.6 Pattern Variants and Extensions

23.6.1 Acyclic Visitor

python
from abc import ABC, abstractmethod
from typing import Any
from dataclasses import dataclass


class Element(ABC):
    """Base class for elements"""
    
    @abstractmethod
    def accept(self, visitor: 'VisitorBase') -> Any:
        pass


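class VisitorBase(ABC):
    """Visitor base class (acyclic: it declares no per-element visit methods)"""

    def visit(self, element: Element) -> Any:
        # Find visit_<ClassName> by reflection, walking the MRO so that a
        # subclass of a supported element falls back to its parent's handler
        for cls in type(element).__mro__:
            method = getattr(self, f"visit_{cls.__name__}", None)
            if method is not None:
                return method(element)
        return self.default_visit(element)

    def default_visit(self, element: Element) -> Any:
        raise NotImplementedError(
            f"{type(self).__name__} does not support {type(element).__name__}"
        )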


@dataclass
class PDF(Element):
    filename: str
    pages: int
    
    def accept(self, visitor: VisitorBase) -> Any:
        return visitor.visit(self)


@dataclass
class Image(Element):
    filename: str
    width: int
    height: int
    
    def accept(self, visitor: VisitorBase) -> Any:
        return visitor.visit(self)


@dataclass
class Video(Element):
    filename: str
    duration: float
    
    def accept(self, visitor: VisitorBase) -> Any:
        return visitor.visit(self)


class MetadataVisitor(VisitorBase):
    """Metadata visitor: handles only PDF and Image"""
    
    def visit_PDF(self, element: PDF) -> dict:
        return {
            'type': 'PDF',
            'filename': element.filename,
            'pages': element.pages
        }
    
    def visit_Image(self, element: Image) -> dict:
        return {
            'type': 'Image',
            'filename': element.filename,
            'resolution': f"{element.width}x{element.height}"
        }


class DurationVisitor(VisitorBase):
    """Duration visitor: handles only Video"""
    
    def visit_Video(self, element: Video) -> str:
        minutes = int(element.duration // 60)
        seconds = int(element.duration % 60)
        return f"{minutes:02d}:{seconds:02d}"


elements: list[Element] = [
    PDF("document.pdf", 25),
    Image("photo.jpg", 1920, 1080),
    Video("clip.mp4", 125.5),
]

print("=== Metadata extraction ===")
metadata_visitor = MetadataVisitor()
for elem in elements:
    try:
        result = elem.accept(metadata_visitor)
        print(result)
    except NotImplementedError:
        print(f"Skipped: {type(elem).__name__}")

print("\n=== Duration calculation ===")
duration_visitor = DurationVisitor()
for elem in elements:
    try:
        result = elem.accept(duration_visitor)
        print(f"{elem.filename}: {result}")
    except NotImplementedError:
        print(f"Skipped: {type(elem).__name__}")
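The version above dispatches by reflection, looking up a method name with `getattr`. In Robert C. Martin's original Acyclic Visitor, the shape is different: the base `Visitor` is an empty marker, each element type gets its own small visitor interface, and `accept` checks at runtime whether the visitor implements that interface. The sketch below shows that shape in Python. It uses `isinstance` where C++ would use `dynamic_cast`, and the class names are illustrative only. Returning a marker string for unsupported visitors is a choice made here; raising an exception or doing nothing would be equally valid.

```python
from abc import ABC, abstractmethod


class Visitor(ABC):
    """Degenerate base: declares no visit methods, so it depends on no element."""


class PDFVisitor(ABC):
    @abstractmethod
    def visit_pdf(self, element: "PDF") -> str: ...


class VideoVisitor(ABC):
    @abstractmethod
    def visit_video(self, element: "Video") -> str: ...


class PDF:
    def __init__(self, filename: str, pages: int) -> None:
        self.filename, self.pages = filename, pages

    def accept(self, visitor: Visitor) -> str:
        # A runtime check replaces the dependency on one all-knowing Visitor interface
        if isinstance(visitor, PDFVisitor):
            return visitor.visit_pdf(self)
        return "unsupported"


class Video:
    def __init__(self, filename: str, duration: float) -> None:
        self.filename, self.duration = filename, duration

    def accept(self, visitor: Visitor) -> str:
        if isinstance(visitor, VideoVisitor):
            return visitor.visit_video(self)
        return "unsupported"


class PageCounter(Visitor, PDFVisitor):
    """Implements only the PDF interface, so videos are simply not supported."""

    def visit_pdf(self, element: PDF) -> str:
        return f"{element.filename}: {element.pages} pages"


counter = PageCounter()
print(PDF("document.pdf", 25).accept(counter))
print(Video("clip.mp4", 125.5).accept(counter))
```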

23.6.2 Visitor Combinators

python
from typing import Any, Callable, List
from dataclasses import dataclass, field


@dataclass
class CompositeVisitor:
    """Composite visitor: applies several visitors to the same element"""
    
    visitors: List[Callable[[Any], Any]] = field(default_factory=list)
    
    def add(self, visitor: Callable[[Any], Any]) -> 'CompositeVisitor':
        self.visitors.append(visitor)
        return self
    
    def visit(self, element: Any) -> List[Any]:
        return [visitor(element) for visitor in self.visitors]


@dataclass
class Product:
    id: str
    name: str
    price: float
    category: str


def extract_basic_info(product: Product) -> dict:
    return {'id': product.id, 'name': product.name}


def calculate_tax(product: Product, tax_rate: float = 0.1) -> dict:
    tax = product.price * tax_rate
    return {'price': product.price, 'tax': tax}


def categorize(product: Product) -> dict:
    return {'category': product.category}


composite = CompositeVisitor()
composite.add(extract_basic_info)
composite.add(lambda p: calculate_tax(p, 0.08))
composite.add(categorize)

product = Product("P001", "Python Programming", 89.0, "Books")
results = composite.visit(product)

print("=== Composite visitor results ===")
for i, result in enumerate(results, start=1):
    print(f"Visitor {i}: {result}")
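`CompositeVisitor` above combines plain functions. The combined parts can also be full double-dispatch visitors. In that case the composite can implement the same `visit_*` methods and forward each call to every child visitor, so a single `element.accept(composite)` runs all of them. A sketch of that variant follows. `Book`, `Course`, `PriceVisitor`, `LabelVisitor`, and `FanOutVisitor` are hypothetical names, and the hourly rate is an arbitrary illustrative value.

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass
class Book:
    title: str
    price: float

    def accept(self, visitor: Any) -> Any:
        return visitor.visit_book(self)


@dataclass
class Course:
    title: str
    hours: int

    def accept(self, visitor: Any) -> Any:
        return visitor.visit_course(self)


class PriceVisitor:
    def visit_book(self, book: Book) -> float:
        return book.price

    def visit_course(self, course: Course) -> float:
        return course.hours * 10.0  # illustrative hourly rate


class LabelVisitor:
    def visit_book(self, book: Book) -> str:
        return f"Book: {book.title}"

    def visit_course(self, course: Course) -> str:
        return f"Course: {course.title}"


class FanOutVisitor:
    """Composite visitor: same interface as its children, forwards to each one."""

    def __init__(self, *visitors: Any) -> None:
        self.visitors: List[Any] = list(visitors)

    def visit_book(self, book: Book) -> List[Any]:
        return [v.visit_book(book) for v in self.visitors]

    def visit_course(self, course: Course) -> List[Any]:
        return [v.visit_course(course) for v in self.visitors]


both = FanOutVisitor(PriceVisitor(), LabelVisitor())
print(Book("Design Patterns", 59.0).accept(both))
print(Course("Python Basics", 12).accept(both))
```

Because `FanOutVisitor` has the same interface as its children, composites can also be nested inside other composites.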

23.6.3 Stateful Visitor

python
from typing import Any, Dict
from dataclasses import dataclass, field
from enum import Enum, auto


class NodeType(Enum):
    ROOT = auto()
    BRANCH = auto()
    LEAF = auto()


@dataclass
class TreeNode:
    """Tree node"""
    name: str
    node_type: NodeType
    value: Any = None
    children: list['TreeNode'] = field(default_factory=list)
    
    def accept(self, visitor: 'StatefulVisitor') -> Any:
        return visitor.visit(self)


class StatefulVisitor:
    """Visitor with state.

    Scoped state (depth, path) is saved before each node and restored after it;
    cumulative state (visited, results) persists across the whole traversal.
    """

    def __init__(self):
        self.state: Dict[str, Any] = {
            'depth': 0,
            'path': [],
            'visited': 0,
            'results': []
        }
        self._context_stack: list[Dict[str, Any]] = []

    def visit(self, node: TreeNode) -> Any:
        self._push_context()

        self.state['depth'] += 1
        self.state['path'].append(node.name)
        self.state['visited'] += 1

        result = self._process_node(node)
        self.state['results'].append(result)

        for child in node.children:
            child.accept(self)

        # Restores depth and path; visited and results are left untouched
        # (restoring the whole state would reset the visited counter)
        self._pop_context()

        return result

    def _push_context(self) -> None:
        # Save only the scoped entries, copying the path list so that later
        # appends do not leak into the saved snapshot
        self._context_stack.append({
            'depth': self.state['depth'],
            'path': list(self.state['path']),
        })

    def _pop_context(self) -> None:
        if self._context_stack:
            self.state.update(self._context_stack.pop())

    def _process_node(self, node: TreeNode) -> dict:
        # The root is at depth 1, so indent by depth - 1 to keep it flush left
        indent = '  ' * (self.state['depth'] - 1)
        path = '/'.join(self.state['path'])

        return {
            'name': node.name,
            'type': node.node_type.name,
            'depth': self.state['depth'],
            'path': path,
            'value': node.value,
            'display': f"{indent}{node.name} ({node.node_type.name})"
        }

    def get_summary(self) -> dict:
        return {
            'total_visited': self.state['visited'],
            'results': self.state['results']
        }


root = TreeNode("root", NodeType.ROOT, children=[
    TreeNode("branch1", NodeType.BRANCH, value=100, children=[
        TreeNode("leaf1", NodeType.LEAF, value=10),
        TreeNode("leaf2", NodeType.LEAF, value=20),
    ]),
    TreeNode("branch2", NodeType.BRANCH, children=[
        TreeNode("leaf3", NodeType.LEAF, value=30),
        TreeNode("branch3", NodeType.BRANCH, children=[
            TreeNode("leaf4", NodeType.LEAF, value=40),
        ]),
    ]),
])

visitor = StatefulVisitor()
root.accept(visitor)

print("=== Traversal result ===")
for result in visitor.state['results']:
    print(result['display'])

print(f"\nTotal visited: {visitor.state['visited']} nodes")
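Saving and restoring scoped state by hand is fragile. If a child visit raises an exception, the restore step never runs, and it is easy to restore cumulative counters by mistake. The sketch below shows one alternative, assuming `contextlib.contextmanager` is acceptable in the codebase. The path is restored in a `finally` block even when an exception escapes, and the cumulative counter is deliberately never touched. `Node`, `PathCollector`, and `_entering` are hypothetical names.

```python
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Iterator, List


@dataclass
class Node:
    name: str
    children: List["Node"] = field(default_factory=list)

    def accept(self, visitor: "PathCollector") -> None:
        visitor.visit(self)


class PathCollector:
    def __init__(self) -> None:
        self.path: List[str] = []   # scoped: restored after each node
        self.visited = 0            # cumulative: never restored
        self.paths: List[str] = []

    @contextmanager
    def _entering(self, name: str) -> Iterator[None]:
        self.path.append(name)
        try:
            yield
        finally:
            self.path.pop()  # runs even if a child visit raises

    def visit(self, node: Node) -> None:
        with self._entering(node.name):
            self.visited += 1
            self.paths.append("/".join(self.path))
            for child in node.children:
                child.accept(self)


tree = Node("root", [Node("a", [Node("x")]), Node("b")])
collector = PathCollector()
tree.accept(collector)
print(collector.paths)
print(collector.visited)
```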

23.7 Anti-Patterns and Best Practices

23.7.1 Common Anti-Patterns

Anti-pattern 1: Frequently adding new element types

python
# ❌ Wrong: using Visitor when the element types change frequently
class ElementA: pass
class ElementB: pass
class ElementC: pass  # each new element type forces a change to every visitor

class Visitor:
    def visit_a(self, a): pass
    def visit_b(self, b): pass
    def visit_c(self, c): pass  # has to be added to every visitor


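# ✅ Better: use Visitor only when the element types are stable.
# If element types change often, put each operation on the element classes as an
# ordinary polymorphic method, or consider other patterns such as Strategy or Command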
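When element types change often, the usual alternative to Visitor is to give each element class its own polymorphic method for the operation. This mirrors Visitor's trade-off. A new element type becomes one new class that touches no existing code, but a new operation requires editing every class. The sketch below illustrates this with hypothetical `Shape` classes.

```python
import math
from abc import ABC, abstractmethod


class Shape(ABC):
    @abstractmethod
    def area(self) -> float: ...


class Square(Shape):
    def __init__(self, side: float) -> None:
        self.side = side

    def area(self) -> float:
        return self.side ** 2


class Circle(Shape):
    def __init__(self, radius: float) -> None:
        self.radius = radius

    def area(self) -> float:
        return math.pi * self.radius ** 2


# Adding a new element type touches no existing class
class Rectangle(Shape):
    def __init__(self, width: float, height: float) -> None:
        self.width, self.height = width, height

    def area(self) -> float:
        return self.width * self.height


shapes = [Square(2), Circle(1), Rectangle(2, 3)]
print([round(s.area(), 2) for s in shapes])
```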

Anti-pattern 2: Visitors reaching into internal state

python
# ❌ Wrong: the visitor depends on the element's internal implementation
class BadVisitor:
    def visit_order(self, order):
        # Reads a private attribute directly, breaking encapsulation
        return order._internal_data['secret_field']


# ✅ Right: go through the public interface
class GoodVisitor:
    def visit_order(self, order):
        return order.get_summary()  # public method

Anti-pattern 3: Circular dependencies

python
# ❌ Wrong: the visitor and the element depend on each other
class ElementA:
    def accept(self, visitor):
        return visitor.visit_a(self)
    
    def process_with_visitor(self, visitor):
        # The element calls into the visitor, which in turn calls back into the element
        return visitor.complex_operation(self)


# ✅ Right: one-way dependency
class ElementA:
    def accept(self, visitor):
        return visitor.visit_a(self)
    # Beyond accept, the element never calls into the visitor; only the visitor calls element methods

23.7.2 Best Practices

python
from abc import ABC, abstractmethod
from typing import Any, List, TypeVar, Generic
from dataclasses import dataclass

TResult = TypeVar('TResult')


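class BestPracticeVisitor(ABC, Generic[TResult]):
    """Best-practice visitor base: one abstract visit method per element type plus a result accessor"""

    @abstractmethod
    def visit_invoice(self, invoice: 'Invoice') -> None:
        pass

    @abstractmethod
    def visit_receipt(self, receipt: 'Receipt') -> None:
        pass

    @abstractmethod
    def get_result(self) -> TResult:
        """Return the result accumulated during the visit"""
        pass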


class Element(ABC):
    """Best-practice element base class"""
    
    @abstractmethod
    def accept(self, visitor: BestPracticeVisitor) -> Any:
        pass
    
    def get_visitable_data(self) -> dict:
        """Expose the data visitors need through a public method (encapsulation-friendly)"""
        return {}


@dataclass
class Invoice(Element):
    """Invoice element"""
    invoice_id: str
    amount: float
    items: List[str]
    
    def accept(self, visitor: BestPracticeVisitor) -> Any:
        return visitor.visit_invoice(self)
    
    def get_visitable_data(self) -> dict:
        return {
            'id': self.invoice_id,
            'amount': self.amount,
            'item_count': len(self.items)
        }


@dataclass
class Receipt(Element):
    """Receipt element"""
    receipt_id: str
    total: float
    
    def accept(self, visitor: BestPracticeVisitor) -> Any:
        return visitor.visit_receipt(self)
    
    def get_visitable_data(self) -> dict:
        return {
            'id': self.receipt_id,
            'total': self.total
        }


class SummaryVisitor(BestPracticeVisitor[dict]):
    """Summary visitor"""
    
    def __init__(self):
        self._total_amount = 0.0
        self._document_count = 0
        self._details: List[dict] = []
    
    def visit_invoice(self, invoice: Invoice) -> None:
        data = invoice.get_visitable_data()
        self._total_amount += data['amount']
        self._document_count += 1
        self._details.append({
            'type': 'Invoice',
            'id': data['id'],
            'amount': data['amount']
        })
    
    def visit_receipt(self, receipt: Receipt) -> None:
        data = receipt.get_visitable_data()
        self._total_amount += data['total']
        self._document_count += 1
        self._details.append({
            'type': 'Receipt',
            'id': data['id'],
            'amount': data['total']
        })
    
    def get_result(self) -> dict:
        return {
            'total_amount': self._total_amount,
            'document_count': self._document_count,
            'details': self._details
        }


documents: List[Element] = [
    Invoice("INV001", 1000.0, ["Item1", "Item2"]),
    Receipt("REC001", 500.0),
    Invoice("INV002", 750.0, ["Item3"]),
]

visitor = SummaryVisitor()
for doc in documents:
    doc.accept(visitor)

result = visitor.get_result()
print(f"Document count: {result['document_count']}")
print(f"Total amount: ¥{result['total_amount']:.2f}")
print("Details:")
for detail in result['details']:
    print(f"  {detail['type']} {detail['id']}: ¥{detail['amount']:.2f}")
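The quick-reference card in 23.9.1 lists an ObjectStructure role. The example above leaves that role to a plain list and a `for` loop. Below is a sketch of making the role explicit: a small container owns the elements, drives the traversal, and returns the visitor's result. It is restated self-contained, and `DocumentArchive`, `Memo`, and `WordCountVisitor` are hypothetical names.

```python
from abc import ABC, abstractmethod
from typing import Generic, List, TypeVar

TResult = TypeVar("TResult")


class Visitor(ABC, Generic[TResult]):
    @abstractmethod
    def visit_memo(self, memo: "Memo") -> None: ...

    @abstractmethod
    def get_result(self) -> TResult: ...


class Memo:
    def __init__(self, words: int) -> None:
        self.words = words

    def accept(self, visitor: Visitor) -> None:
        visitor.visit_memo(self)


class DocumentArchive:
    """ObjectStructure role: owns the elements and drives the traversal."""

    def __init__(self) -> None:
        self._elements: List[Memo] = []

    def add(self, element: Memo) -> None:
        self._elements.append(element)

    def run(self, visitor: Visitor[TResult]) -> TResult:
        for element in self._elements:
            element.accept(visitor)
        return visitor.get_result()


class WordCountVisitor(Visitor[int]):
    def __init__(self) -> None:
        self._total = 0

    def visit_memo(self, memo: Memo) -> None:
        self._total += memo.words

    def get_result(self) -> int:
        return self._total


archive = DocumentArchive()
archive.add(Memo(120))
archive.add(Memo(80))
print(archive.run(WordCountVisitor()))
```

Keeping the loop inside the structure means clients no longer depend on how the elements are stored.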

23.8 Decision Guide

23.8.1 Should You Use the Visitor Pattern?

┌─────────────────────────────────────────────────────────────────┐
│               Should you use the Visitor pattern?               │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Is the object structure stable?                                │
│         │                                                       │
│         ├── No ───→ Visitor is not a good fit                   │
│         │                                                       │
│         └── Yes ──┐                                             │
│                   │                                             │
│                   ▼                                             │
│  Do you need multiple operations?                               │
│         │                                                       │
│         ├── No ───→ Consider a plain method on the class        │
│         │                                                       │
│         └── Yes ──┐                                             │
│                   │                                             │
│                   ▼                                             │
│  Do the operations change often?                                │
│         │                                                       │
│         ├── No ───→ Consider the Strategy pattern               │
│         │                                                       │
│         └── Yes ──→ ✓ Use the Visitor pattern                   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

23.8.2 Pattern Selection Decision Tree

┌─────────────────────────────────────────────────────────────────┐
│                  Choosing a behavioral pattern                  │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  What do you need to separate?                                  │
│         │                                                       │
│         ├── Operations from structure ──→ Visitor ✓             │
│         │                                                       │
│         ├── Whole algorithm ──→ Strategy                        │
│         │                                                       │
│         ├── Algorithm skeleton ──→ Template Method              │
│         │                                                       │
│         ├── Request handling ──→ Chain of Responsibility        │
│         │                                                       │
│         └── State-dependent behavior ──→ State                  │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

23.8.3 Comparison with Related Patterns

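Aspect | Visitor | Strategy | Command | Iterator
Intent | Separate operations from structure | Encapsulate an algorithm | Encapsulate a request | Traverse a collection
What varies | Operations | Algorithm | Request execution | Traversal method
Dependence on structure | Depends on element structure | Independent | Independent | Depends on collection structure
Adding an operation | Easy | Easy | Easy | Not applicable
Adding an element type | Hard | Not applicable | Not applicable | Not applicable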

23.9 Quick Reference Card

23.9.1 Core Concepts at a Glance

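┌─────────────────────────────────────────────────────────────────┐
│                 Visitor pattern quick reference                 │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Definition: move operations out of the object structure, so    │
│              new operations can be added without changing it    │
│                                                                 │
│  Core roles:                                                    │
│  ├── Visitor: visitor interface, declares the visit methods     │
│  ├── ConcreteVisitor: implements a specific operation           │
│  ├── Element: element interface, declares accept                │
│  ├── ConcreteElement: implements accept                         │
│  └── ObjectStructure: manages the element collection            │
│                                                                 │
│  Double dispatch:                                               │
│  element.accept(visitor) → visitor.visit_element(element)       │
│                                                                 │
│  When to use:                                                   │
│  ├── Stable object structure, frequently changing operations    │
│  ├── Many different operations on the structure                 │
│  └── Keep operations from cluttering the element classes        │
│                                                                 │
│  Pros: open/closed, single responsibility, operations grouped   │
│  Cons: hard to add elements, weaker encapsulation, complexity   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘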
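To see that the result depends on the runtime types of both objects, consider a minimal sketch with two hypothetical element types and two hypothetical visitors. Dispatch happens in two steps. `accept` is the first dispatch, selected by the element's type. The `visit_*` call is the second dispatch, selected by the visitor's type. Each of the four combinations therefore reaches a different method.

```python
class Circle:
    def accept(self, visitor):
        # 1st dispatch: Python chose Circle.accept from the element's type
        return visitor.visit_circle(self)


class Square:
    def accept(self, visitor):
        # 1st dispatch: Python chose Square.accept from the element's type
        return visitor.visit_square(self)


class Drawer:
    # 2nd dispatch: the visitor's type decides which visit_* body runs
    def visit_circle(self, circle):
        return "draw circle"

    def visit_square(self, square):
        return "draw square"


class Exporter:
    def visit_circle(self, circle):
        return "export circle"

    def visit_square(self, square):
        return "export square"


for shape in (Circle(), Square()):
    for visitor in (Drawer(), Exporter()):
        print(shape.accept(visitor))
```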

23.9.2 Python Implementation Essentials

python
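from abc import ABC, abstractmethod
from functools import singledispatch

# Classic implementation: one visit method per element type
class Visitor(ABC):
    @abstractmethod
    def visit_element_a(self, element): pass
    @abstractmethod
    def visit_element_b(self, element): pass

class Element(ABC):
    @abstractmethod
    def accept(self, visitor): pass

class ElementA(Element):
    def accept(self, visitor):
        return visitor.visit_element_a(self)

class ElementB(Element):
    def accept(self, visitor):
        return visitor.visit_element_b(self)

# Simplified implementation with singledispatch: dispatches on the type of the
# first argument, so the elements need no accept method for this operation
@singledispatch
def operation(element):
    raise NotImplementedError(f"unsupported element: {type(element).__name__}")

@operation.register
def _(element: ElementA):
    return "operation on ElementA"

@operation.register
def _(element: ElementB):
    return "operation on ElementB"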
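`functools.singledispatch` turns each operation into a standalone function. Sometimes an operation is better kept as a visitor class, for example because it holds state. For that case, `functools.singledispatchmethod` (Python 3.8+) dispatches on the type of the first argument after `self`. A sketch follows; `Num`, `Add`, and `Evaluator` are hypothetical names.

```python
from dataclasses import dataclass
from functools import singledispatchmethod


@dataclass
class Num:
    value: int


@dataclass
class Add:
    left: object
    right: object


class Evaluator:
    """Visitor-like class: Num and Add need no accept methods."""

    def __init__(self) -> None:
        self.nodes_seen = 0  # state kept on the visitor instance

    @singledispatchmethod
    def visit(self, node: object) -> int:
        raise NotImplementedError(f"unsupported node: {type(node).__name__}")

    @visit.register
    def _(self, node: Num) -> int:
        self.nodes_seen += 1
        return node.value

    @visit.register
    def _(self, node: Add) -> int:
        self.nodes_seen += 1
        return self.visit(node.left) + self.visit(node.right)


ev = Evaluator()
print(ev.visit(Add(Num(2), Add(Num(3), Num(4)))))
print(ev.nodes_seen)
```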

23.9.3 Common Application Scenarios

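Scenario | Examples
Compiler | AST traversal, code generation
Document processing | Format conversion, export
Database | SQL generation, query optimization
File system | Search, statistics, backup
Graphics editing | Rendering, export, transformation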
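As an illustration of the document-processing row, the sketch below exports one document structure to two formats with two visitors. `Heading`, `Paragraph`, `Exporter`, and the two exporter subclasses are hypothetical names, and the Markdown output is deliberately minimal.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from html import escape
from typing import List


@dataclass
class Heading:
    text: str
    level: int

    def accept(self, visitor: "Exporter") -> str:
        return visitor.visit_heading(self)


@dataclass
class Paragraph:
    text: str

    def accept(self, visitor: "Exporter") -> str:
        return visitor.visit_paragraph(self)


class Exporter(ABC):
    @abstractmethod
    def visit_heading(self, node: Heading) -> str: ...

    @abstractmethod
    def visit_paragraph(self, node: Paragraph) -> str: ...

    def export(self, nodes: List[object]) -> str:
        return "\n".join(node.accept(self) for node in nodes)


class HTMLExporter(Exporter):
    def visit_heading(self, node: Heading) -> str:
        return f"<h{node.level}>{escape(node.text)}</h{node.level}>"

    def visit_paragraph(self, node: Paragraph) -> str:
        return f"<p>{escape(node.text)}</p>"  # escape &, <, > for HTML


class MarkdownExporter(Exporter):
    def visit_heading(self, node: Heading) -> str:
        return "#" * node.level + " " + node.text

    def visit_paragraph(self, node: Paragraph) -> str:
        return node.text


doc = [Heading("Visitor", 1), Paragraph("Tom & Jerry")]
print(HTMLExporter().export(doc))
print(MarkdownExporter().export(doc))
```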

23.10 Summary

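The Visitor pattern uses double dispatch to separate operations completely from the data structure they act on, which makes it a powerful tool for operating on complex object structures. In Python, functools.singledispatch can simplify the implementation. It selects a function by the type of its first argument, so elements need no accept method and each operation becomes a standalone generic function. Strictly speaking, this is single dispatch on the element: the operation is chosen by which function you call.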

Key points

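  1. Double dispatch: two method calls (accept on the element, then visit on the visitor) select behavior by the runtime types of both objects
  2. Stability prerequisite: best suited to stable element types with frequently changing operations
  3. Open/closed principle: adding an operation only requires a new visitor
  4. Encapsulation trade-off: visitors need access to element state, which can weaken encapsulation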

Practical advice

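  1. Make sure the element types are stable before adopting the Visitor pattern
  2. Expose the data visitors need through public interfaces to protect encapsulation
  3. Consider functools.singledispatch to simplify the implementation
  4. For simple cases, prefer other patterns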

Design pattern series wrap-up

This completes our coverage of all the GoF design patterns:

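  • Creational patterns (5): Singleton, Factory Method, Abstract Factory, Builder, Prototype
  • Structural patterns (7): Adapter, Bridge, Composite, Decorator, Facade, Flyweight, Proxy
  • Behavioral patterns (11): Chain of Responsibility, Command, Interpreter, Iterator, Mediator, Memento, Observer, State, Strategy, Template Method, Visitor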

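These patterns are building blocks of software design. Understanding them and applying them flexibly will substantially improve code quality and maintainability.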

Python Technology Book Series - Jiangsu Sucheng Secondary Vocational School (江苏省宿城中等专业学校)