
Chapter 8 The Composite Pattern

Learning Objectives

After completing this chapter, readers will be able to:

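  • Understand the core concepts of the Composite pattern and its formal mathematical definition
  • Master how objects are organized into tree structures and the principle of recursive composition
  • Implement both the transparent and the safe variants of the Composite pattern
  • Use modern Python features such as Protocol and Generic to build type-safe composite structures
  • Recognize where the Composite pattern applies, as well as its anti-patterns
  • Design enterprise-grade solutions based on composite structures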

8.1 Pattern Definition

8.1.1 Core Definition

The Composite Pattern composes objects into tree structures to represent part-whole hierarchies. It lets clients treat individual objects (Leaf) and compositions of objects (Composite) uniformly.

8.1.2 Formal Definition

Let $\mathcal{C}$ be the set of components, $\mathcal{L}$ the set of leaves, and $\mathcal{K}$ the set of composites. The Composite pattern can then be formalized as:

$$\text{Composite}: \mathcal{C} = \mathcal{L} \cup \mathcal{K}$$

where the set of composites is defined recursively:

$$\mathcal{K} = \{\, k \mid k = (id, \{c_1, c_2, \ldots, c_n\}),\ c_i \in \mathcal{C} \,\}$$

Tree constraint: the structure built by the Composite pattern must be acyclic:

$$\forall c \in \mathcal{C}: \text{ancestors}(c) \cap \text{descendants}(c) = \emptyset$$
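A plain list append does not enforce this constraint: with the `Composite.add` shown in 8.3.1, `root.add(root)` succeeds, and any recursive operation on `root` then recurses until Python raises `RecursionError`. A minimal sketch of enforcing the constraint, assuming each node keeps a `parent` back-reference (the `Node`, `parent`, and `ancestors` names are illustrative, not part of the pattern's standard interface), rejects a child that is the receiving node itself or one of its ancestors:

```python
from typing import Iterator, List, Optional


class Node:
    """Illustrative composite node that keeps a parent back-reference."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.parent: Optional["Node"] = None
        self.children: List["Node"] = []

    def ancestors(self) -> Iterator["Node"]:
        # Walk up the parent chain: parent, grandparent, ... up to the root.
        node = self.parent
        while node is not None:
            yield node
            node = node.parent

    def add(self, child: "Node") -> None:
        # Linking this node, or one of its ancestors, below itself would close a cycle.
        if child is self or any(a is child for a in self.ancestors()):
            raise ValueError(f"adding {child.name!r} under {self.name!r} would create a cycle")
        # A tree node has at most one parent, so detach the child from its old parent first.
        if child.parent is not None:
            child.parent.children.remove(child)
        child.parent = self
        self.children.append(child)


root, a, b = Node("root"), Node("a"), Node("b")
root.add(a)
a.add(b)
try:
    b.add(root)  # root is an ancestor of b
except ValueError as exc:
    print(exc)   # adding 'root' under 'b' would create a cycle
```

The ancestor walk costs $O(d)$ for a node at depth $d$. Because every node keeps a single parent, checking the ancestors alone is enough to rule out cycles.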

Uniformity principle: for any operation $op$, leaves and composites behave as follows:

$$op(l) = f_l(l), \quad l \in \mathcal{L}$$
$$op(k) = \bigoplus_{c \in \text{children}(k)} op(c), \quad k \in \mathcal{K}$$

where $\bigoplus$ is an aggregation operator (e.g., summation, concatenation, or logical AND).
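The two equations amount to a recursive fold over the tree. The sketch below uses simplified, illustrative `Leaf`/`Composite` dataclasses rather than the classes defined later in this chapter: `f_leaf` plays the role of $f_l$, and `combine` together with `identity` plays the role of $\bigoplus$. It assumes $\bigoplus$ is associative with an identity element (0 for addition, the empty string for concatenation, `True` for logical AND), so an empty composite evaluates to that identity.

```python
import operator
from dataclasses import dataclass, field
from functools import reduce
from typing import Callable, List, TypeVar, Union

R = TypeVar("R")


@dataclass
class Leaf:
    value: int


@dataclass
class Composite:
    children: List[Union["Leaf", "Composite"]] = field(default_factory=list)


def op(node: Union[Leaf, Composite],
       f_leaf: Callable[[Leaf], R],
       combine: Callable[[R, R], R],
       identity: R) -> R:
    """op(l) = f_leaf(l) for a leaf; op(k) folds op(c) over k's children with combine."""
    if isinstance(node, Leaf):
        return f_leaf(node)
    results = (op(child, f_leaf, combine, identity) for child in node.children)
    return reduce(combine, results, identity)


tree = Composite([Leaf(1), Composite([Leaf(2), Leaf(3)]), Leaf(4)])
print(op(tree, lambda l: l.value, operator.add, 0))          # 10
print(op(tree, lambda l: str(l.value), operator.add, ""))    # 1234
print(op(tree, lambda l: l.value > 0, operator.and_, True))  # True
```

Only the leaf function and the operator change between the three calls; the traversal itself is written once.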

Complexity Analysis

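  • Add: $O(1)$ amortized for a plain list append; $O(k)$ when the composite first checks for duplicates, as Composite.add in 8.3.1 does ($k$ = number of direct children)
  • Remove: $O(k)$ (linear search of the child list, then deletion)
  • Traversal: $O(n)$, where $n$ is the total number of nodes
  • Space: $O(n)$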
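The linear removal cost comes from `list.remove` scanning the child list. If removal cost matters, one alternative (an assumption beyond the designs in this chapter) stores children in a dict keyed by `id(child)`. Since Python 3.7 dicts preserve insertion order, so iteration order is kept while add and remove become average $O(1)$. The trade-off is that positional access such as `get_child(i)` becomes $O(k)$.

```python
from typing import Dict, Iterator


class Component:
    """Minimal illustrative component interface."""

    def operation(self) -> str:
        raise NotImplementedError


class Leaf(Component):
    def __init__(self, name: str) -> None:
        self.name = name

    def operation(self) -> str:
        return self.name


class Composite(Component):
    """Stores children in a dict keyed by id(); dicts keep insertion order (Python 3.7+)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._children: Dict[int, Component] = {}

    def add(self, child: Component) -> None:
        # Average O(1); adding the same object twice keeps the first entry.
        self._children.setdefault(id(child), child)

    def remove(self, child: Component) -> None:
        # Average O(1); removing an absent child is a no-op.
        self._children.pop(id(child), None)

    def __iter__(self) -> Iterator[Component]:
        return iter(self._children.values())

    def operation(self) -> str:
        return f"{self.name}[{', '.join(c.operation() for c in self)}]"


root = Composite("root")
x, y = Leaf("x"), Leaf("y")
root.add(x)
root.add(y)
root.add(x)     # duplicate: ignored
root.remove(y)
print(root.operation())  # root[x]
```

Keying by `id()` is safe here because the dict holds a strong reference to each child, so a child's id cannot be reused while it is still in the composite.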

8.1.3 Historical Background and Academic Lineage

The Composite pattern grew out of design practice for graphical editors. Its development is summarized below:

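| Year | Milestone | Contributors |
|------|-----------|--------------|
| 1987 | First applied in the ET++ framework | Weinand, Gamma |
| 1991 | Composite pattern formally proposed as a concept | Gamma, Helm, Johnson, Vlissides |
| 1994 | Included in the GoF book Design Patterns | Gang of Four |
| 1995 | Discussion of the transparent vs. safe classification | Gamma et al. |
| 2002 | Studied in combination with the Visitor pattern | Gamma, Beck |
| 2010 | Research on functional-style composites | Martin et al. |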

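Academic significance: the Composite pattern is a classic application of recursive data structures in object-oriented design. It embodies the principle of favoring composition over inheritance and offers a standard way to handle tree structures uniformly.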


8.2 Structure and Participants

8.2.1 UML Class Diagram

                    ┌────────────────────────────────┐
                    │         <<interface>>          │
                    │           Component            │
                    ├────────────────────────────────┤
                    │ + operation(): Result          │
                    │ + add(c: Component): void      │
                    │ + remove(c: Component): void   │
                    │ + get_child(i: int): Component │
                    │ + is_composite(): bool         │
                    └────────────────┬───────────────┘
                                     │
                   ┌─────────────────┴───────────────┐
                   │                                 │
      ┌────────────┴───────────┐    ┌────────────────┴───────────────┐
      │          Leaf          │    │           Composite            │
      ├────────────────────────┤    ├────────────────────────────────┤
      │ - name: str            │    │ - children: List[Component]    │
      │ - value: Any           │    │ - name: str                    │
      ├────────────────────────┤    ├────────────────────────────────┤
      │ + operation(): Result  │    │ + operation(): Result          │
      │ + is_composite(): bool │    │ + add(c: Component)            │
      │                        │    │ + remove(c: Component)         │
      │                        │    │ + get_child(i): Component      │
      │                        │    │ + is_composite(): bool         │
      └────────────────────────┘    └────────────────────────────────┘

8.2.2 Participant Responsibilities

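| Participant | Responsibility | Key methods |
|-------------|----------------|-------------|
| Component | Declares the uniform interface and defines default behavior | operation(), add(), remove() |
| Leaf | Represents a leaf node, which has no children | Implements operation() |
| Composite | Represents a composite node and stores its child components | Manages children; forwards calls to them recursively |
| Client | Manipulates objects through the Component interface | Treats all components uniformly |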

8.2.3 Collaborations

Client ──────> Component <──────────┐
                   │                │
                   ▼                │
            ┌─────────────┐         │
            │ operation() │         │
            └──────┬──────┘         │
                   │                │
           ┌───────┴───────┐        │
           ▼               ▼        │
       ┌───────┐     ┌───────────┐  │
       │ Leaf  │     │ Composite │──┘
       └───────┘     └─────┬─────┘
                           │
                  ┌────────┼────────┐
                  ▼        ▼        ▼
               Child1   Child2   Child3

8.3 Python Implementation

8.3.1 Standard Implementation (ABC Abstract Base Class)

python
from abc import ABC, abstractmethod
from typing import Iterator, List

class Component(ABC):
    """Abstract component base class: declares the uniform interface."""
    
    @abstractmethod
    def operation(self) -> str:
        """Core operation."""
        pass
    
    def add(self, component: 'Component') -> None:
        """Add a child. Raises by default, so a leaf rejects the call at runtime."""
        raise NotImplementedError("Leaf cannot add children")
    
    def remove(self, component: 'Component') -> None:
        """Remove a child. Raises by default, so a leaf rejects the call at runtime."""
        raise NotImplementedError("Leaf cannot remove children")
    
    def get_child(self, index: int) -> 'Component':
        """Get a child by index. Raises by default."""
        raise NotImplementedError("Leaf has no children")
    
    def is_composite(self) -> bool:
        """Return whether this node is a composite."""
        return False
    
    def __iter__(self) -> Iterator['Component']:
        """Iteration support: a leaf yields no children."""
        return iter([])


class Leaf(Component):
    """Leaf node: a terminal node with no children."""
    
    def __init__(self, name: str, value: object = None):  # builtin any() is not a type
        self._name = name
        self._value = value
    
    @property
    def name(self) -> str:
        return self._name
    
    @property
    def value(self) -> object:
        return self._value
    
    def operation(self) -> str:
        return f"Leaf({self._name}: {self._value})"


class Composite(Component):
    """Composite node: a container that can hold child components."""
    
    def __init__(self, name: str):
        self._name = name
        self._children: List[Component] = []
    
    @property
    def name(self) -> str:
        return self._name
    
    def operation(self) -> str:
        results = [child.operation() for child in self._children]
        return f"Composite({self._name}): [{', '.join(results)}]"
    
    def add(self, component: Component) -> None:
        if component not in self._children:
            self._children.append(component)
    
    def remove(self, component: Component) -> None:
        if component in self._children:
            self._children.remove(component)
    
    def get_child(self, index: int) -> Component:
        return self._children[index]
    
    def is_composite(self) -> bool:
        return True
    
    def __iter__(self) -> Iterator[Component]:
        return iter(self._children)
    
    def __len__(self) -> int:
        return len(self._children)
    
    def clear(self) -> None:
        self._children.clear()


def client_code(component: Component) -> None:
    """Client code: handles every component through the same interface."""
    print(f"RESULT: {component.operation()}")
    if component.is_composite():
        print(f"  Children count: {len(list(component))}")


if __name__ == "__main__":
    leaf = Leaf("Simple", 42)
    client_code(leaf)
    
    root = Composite("Root")
    branch1 = Composite("Branch1")
    branch1.add(Leaf("Leaf1", 1))
    branch1.add(Leaf("Leaf2", 2))
    
    branch2 = Composite("Branch2")
    branch2.add(Leaf("Leaf3", 3))
    
    root.add(branch1)
    root.add(branch2)
    root.add(Leaf("Leaf4", 4))
    
    client_code(root)
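`Composite.operation()` above recurses once per level of the tree, so a very deep tree can exceed CPython's recursion limit (reported by `sys.getrecursionlimit()`, 1000 by default) and raise `RecursionError`. A sketch of a pre-order traversal that uses an explicit stack instead; the minimal `Node` class and the `walk` function are illustrative stand-ins, not part of the implementation above:

```python
from typing import Iterator, List, Optional, Tuple


class Node:
    """Illustrative stand-in for Leaf/Composite: a name plus a possibly empty child list."""

    def __init__(self, name: str, children: Optional[List["Node"]] = None) -> None:
        self.name = name
        self.children: List["Node"] = list(children) if children else []


def walk(root: Node) -> Iterator[Tuple[int, Node]]:
    """Pre-order depth-first traversal using an explicit stack; yields (depth, node)."""
    stack: List[Tuple[int, Node]] = [(0, root)]
    while stack:
        depth, node = stack.pop()
        yield depth, node
        # Push children in reverse so that the leftmost child is popped first.
        for child in reversed(node.children):
            stack.append((depth + 1, child))


tree = Node("Root", [
    Node("Branch1", [Node("Leaf1"), Node("Leaf2")]),
    Node("Branch2", [Node("Leaf3")]),
    Node("Leaf4"),
])
for depth, node in walk(tree):
    print("  " * depth + node.name)

# A 5000-level chain: a naive recursive traversal would exceed the default limit.
deep = Node("n0")
current = deep
for i in range(1, 5000):
    child = Node(f"n{i}")
    current.children.append(child)
    current = child
print(sum(1 for _ in walk(deep)))  # 5000
```

The stack lives on the heap, so traversal depth is bounded by memory rather than by the interpreter's call-stack limit.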

8.3.2 Protocol Implementation (Structural Typing)

python
from typing import Protocol, List, Iterator, runtime_checkable
from dataclasses import dataclass

@runtime_checkable
class ComponentProtocol(Protocol):
    """Component protocol: a structural type definition."""
    
    def operation(self) -> str: ...
    def is_composite(self) -> bool: ...


@dataclass
class LeafNode:
    """Leaf node implemented as a dataclass."""
    name: str
    value: int = 0
    
    def operation(self) -> str:
        return f"Leaf({self.name}: {self.value})"
    
    def is_composite(self) -> bool:
        return False


class CompositeNode:
    """Composite node: satisfies the protocol without inheriting from it."""
    
    def __init__(self, name: str):
        self.name = name
        self._children: List[ComponentProtocol] = []
    
    def operation(self) -> str:
        results = [child.operation() for child in self._children]
        return f"Composite({self.name}): [{', '.join(results)}]"
    
    def is_composite(self) -> bool:
        return True
    
    def add(self, component: ComponentProtocol) -> None:
        self._children.append(component)
    
    def remove(self, component: ComponentProtocol) -> None:
        self._children.remove(component)
    
    def __iter__(self) -> Iterator[ComponentProtocol]:
        return iter(self._children)


def process_component(comp: ComponentProtocol) -> None:
    """Accepts any object that structurally matches ComponentProtocol."""
    print(f"Processing: {comp.operation()}")
    print(f"Is Composite: {comp.is_composite()}")


if __name__ == "__main__":
    leaf = LeafNode("Data", 100)
    composite = CompositeNode("Container")
    composite.add(LeafNode("Item1", 1))
    composite.add(LeafNode("Item2", 2))
    
    process_component(leaf)
    process_component(composite)
    
    print(f"\nType checking:")
    print(f"leaf is ComponentProtocol: {isinstance(leaf, ComponentProtocol)}")
    print(f"composite is ComponentProtocol: {isinstance(composite, ComponentProtocol)}")
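One caveat about the `isinstance` checks above: with `@runtime_checkable`, `isinstance` only verifies that the protocol's members are present on the object. It does not compare parameter lists or return types; that remains the job of a static type checker. The illustrative classes `Impostor` and `Partial` below show both outcomes:

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class ComponentProtocol(Protocol):
    def operation(self) -> str: ...
    def is_composite(self) -> bool: ...


class Impostor:
    """Has both members, but with signatures and return types that do not match."""

    def operation(self, a: int, b: int) -> int:
        return a + b

    def is_composite(self) -> str:
        return "maybe"


class Partial:
    """Lacks is_composite entirely."""

    def operation(self) -> str:
        return "Partial"


print(isinstance(Impostor(), ComponentProtocol))  # True: only member presence is checked
print(isinstance(Partial(), ComponentProtocol))   # False: a member is missing
```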

8.3.3 Generic Implementation (Type-Safe)

python
from abc import ABC, abstractmethod
from typing import TypeVar, Generic, List, Callable
from dataclasses import dataclass

T = TypeVar('T')
R = TypeVar('R')


class GenericComponent(ABC, Generic[T, R]):
    """Generic component interface: T is the stored data type, R the result type."""
    
    @abstractmethod
    def operation(self) -> R:
        pass
    
    @abstractmethod
    def get_data(self) -> T:
        pass


@dataclass
class GenericLeaf(GenericComponent[T, R]):
    """Generic leaf node."""
    name: str
    data: T
    transformer: Callable[[T], R]
    
    def operation(self) -> R:
        return self.transformer(self.data)
    
    def get_data(self) -> T:
        return self.data


class GenericComposite(GenericComponent[T, R]):
    """Generic composite node."""
    
    def __init__(
        self, 
        name: str,
        aggregator: Callable[[List[R]], R]
    ):
        self._name = name
        self._children: List[GenericComponent[T, R]] = []
        self._aggregator = aggregator
    
    def operation(self) -> R:
        results = [child.operation() for child in self._children]
        return self._aggregator(results)
    
    def get_data(self) -> T:
        raise NotImplementedError("Composite has no single data")
    
    def add(self, component: GenericComponent[T, R]) -> None:
        self._children.append(component)
    
    def remove(self, component: GenericComponent[T, R]) -> None:
        self._children.remove(component)


if __name__ == "__main__":
    def sum_aggregator(results: List[int]) -> int:
        return sum(results)
    
    def str_aggregator(results: List[str]) -> str:
        return " | ".join(results)
    
    int_leaf1 = GenericLeaf("A", 10, lambda x: x * 2)
    int_leaf2 = GenericLeaf("B", 20, lambda x: x * 2)
    int_composite = GenericComposite("Numbers", sum_aggregator)
    int_composite.add(int_leaf1)
    int_composite.add(int_leaf2)
    
    print(f"Integer composite: {int_composite.operation()}")
    
    str_leaf1 = GenericLeaf("X", "hello", lambda x: x.upper())
    str_leaf2 = GenericLeaf("Y", "world", lambda x: x.upper())
    str_composite = GenericComposite("Strings", str_aggregator)
    str_composite.add(str_leaf1)
    str_composite.add(str_leaf2)
    
    print(f"String composite: {str_composite.operation()}")

8.3.4 Transparent Implementation

python
from abc import ABC, abstractmethod
from typing import List, Optional

class TransparentComponent(ABC):
    """Transparent component: every method, including child management, is declared in the base class."""
    
    @abstractmethod
    def operation(self) -> str:
        pass
    
    def add(self, component: 'TransparentComponent') -> None:
        pass
    
    def remove(self, component: 'TransparentComponent') -> None:
        pass
    
    def get_child(self, index: int) -> Optional['TransparentComponent']:
        return None
    
    def is_composite(self) -> bool:
        return False


class TransparentLeaf(TransparentComponent):
    """Transparent leaf: inherits the no-op add/remove from the base class."""
    
    def __init__(self, name: str):
        self._name = name
    
    def operation(self) -> str:
        return f"Leaf: {self._name}"


class TransparentComposite(TransparentComponent):
    """Transparent composite."""
    
    def __init__(self, name: str):
        self._name = name
        self._children: List[TransparentComponent] = []
    
    def operation(self) -> str:
        results = [child.operation() for child in self._children]
        return f"Branch({self._name}): [{', '.join(results)}]"
    
    def add(self, component: TransparentComponent) -> None:
        self._children.append(component)
    
    def remove(self, component: TransparentComponent) -> None:
        self._children.remove(component)
    
    def get_child(self, index: int) -> Optional[TransparentComponent]:
        if 0 <= index < len(self._children):
            return self._children[index]
        return None
    
    def is_composite(self) -> bool:
        return True


def transparent_client(component: TransparentComponent) -> None:
    """The client need not distinguish leaves from composites; on a leaf, add() is a silent no-op."""
    print(component.operation())
    
    component.add(TransparentLeaf("New"))
    if component.is_composite():
        print("  Added child successfully")
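Because `TransparentLeaf` inherits the no-op `add`, `transparent_client` silently discards the new child whenever it receives a leaf. A fail-fast variant, sketched below with illustrative `Node`/`Leaf`/`Branch` classes, keeps the transparent interface but has the base-class default raise `TypeError` (the choice of exception type is an assumption; 8.3.1 uses `NotImplementedError`):

```python
from abc import ABC, abstractmethod
from typing import List


class Node(ABC):
    """Transparent-style base: child management is declared for every node."""

    @abstractmethod
    def operation(self) -> str: ...

    def add(self, child: "Node") -> None:
        # Fail fast instead of silently ignoring the call on a leaf.
        raise TypeError(f"{type(self).__name__} cannot have children")


class Leaf(Node):
    def __init__(self, name: str) -> None:
        self.name = name

    def operation(self) -> str:
        return self.name


class Branch(Node):
    def __init__(self, name: str) -> None:
        self.name = name
        self.children: List[Node] = []

    def operation(self) -> str:
        return f"{self.name}[{', '.join(c.operation() for c in self.children)}]"

    def add(self, child: Node) -> None:
        self.children.append(child)


branch = Branch("b")
branch.add(Leaf("x"))
print(branch.operation())  # b[x]
try:
    Leaf("y").add(Leaf("z"))
except TypeError as exc:
    print(exc)  # Leaf cannot have children
```

The interface stays uniform, but misuse now surfaces at the call site instead of as a missing child discovered later.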

8.3.5 Safe Implementation

python
from abc import ABC, abstractmethod
from typing import List

class SafeComponent(ABC):
    """Safe component: declares only the operations common to all nodes."""
    
    @abstractmethod
    def operation(self) -> str:
        pass


class SafeLeaf(SafeComponent):
    """Safe leaf: has no add/remove methods."""
    
    def __init__(self, name: str):
        self._name = name
    
    def operation(self) -> str:
        return f"Leaf: {self._name}"


class SafeComposite(SafeComponent):
    """Safe composite: the child-management methods exist only here."""
    
    def __init__(self, name: str):
        self._name = name
        self._children: List[SafeComponent] = []
    
    def operation(self) -> str:
        results = [child.operation() for child in self._children]
        return f"Branch({self._name}): [{', '.join(results)}]"
    
    def add(self, component: SafeComponent) -> None:
        self._children.append(component)
    
    def remove(self, component: SafeComponent) -> None:
        self._children.remove(component)
    
    def get_children(self) -> List[SafeComponent]:
        return self._children.copy()


def safe_client(component: SafeComponent) -> None:
    """A general client can call only the common operations."""
    print(component.operation())


def safe_composite_client(composite: SafeComposite) -> None:
    """A composite-specific client can also call the child-management methods."""
    composite.add(SafeLeaf("New"))
    print(composite.operation())

8.4 Transparent vs. Safe

8.4.1 Comparison

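| Aspect | Transparent | Safe |
|--------|-------------|------|
| Interface design | All methods declared in the base class | Only the common methods are in the base class |
| Type safety | Misuse may fail at runtime | Misuse is caught by a static type checker (Python itself has no compile-time type check) |
| Client complexity | Low: all nodes handled uniformly | Higher: the client must distinguish node types |
| Extensibility | High: new component types are easy to add | Medium: the interface may need changes |
| Error handling | Runtime exceptions (or silent no-ops) | Static checking before execution |
| Best suited for | Stable structures with simple clients | Strong type-safety requirements |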
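Python has no separate compile-time type check, so the safe variant's advantage shows up only under a static type checker such as mypy. The sketch below (condensed copies of the 8.3.5 classes plus a hypothetical `attach` helper) shows the type distinction a safe-style client has to make: `isinstance` narrows the static type, so a checker accepts the `add` call inside the branch, whereas calling `add` on a value typed `SafeComponent` would be flagged, and on a `SafeLeaf` it fails at runtime with `AttributeError`.

```python
from abc import ABC, abstractmethod
from typing import List


class SafeComponent(ABC):
    @abstractmethod
    def operation(self) -> str: ...


class SafeLeaf(SafeComponent):
    def __init__(self, name: str) -> None:
        self._name = name

    def operation(self) -> str:
        return f"Leaf: {self._name}"


class SafeComposite(SafeComponent):
    def __init__(self, name: str) -> None:
        self._name = name
        self._children: List[SafeComponent] = []

    def operation(self) -> str:
        return f"Branch({self._name}): [{', '.join(c.operation() for c in self._children)}]"

    def add(self, component: SafeComponent) -> None:
        self._children.append(component)


def attach(parent: SafeComponent, child: SafeComponent) -> bool:
    """Add child only if parent is a composite; return whether it was added."""
    # A bare parent.add(child) here would be rejected by a static checker,
    # because SafeComponent declares no add().
    if isinstance(parent, SafeComposite):
        parent.add(child)  # accepted: isinstance narrows parent to SafeComposite
        return True
    return False


root = SafeComposite("root")
print(attach(root, SafeLeaf("a")))           # True
print(attach(SafeLeaf("x"), SafeLeaf("b")))  # False
print(root.operation())                      # Branch(root): [Leaf: a]
```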

8.4.2 Decision Tree

                  Need type safety?
                   /            \
                 Yes             No
                 /                 \
              Safe       Must clients handle all nodes uniformly?
                                  /             \
                                Yes              No
                                /                  \
                         Transparent              Safe

8.5 Practical Examples

8.5.1 File System

python
from abc import ABC, abstractmethod
from typing import List, Optional
from datetime import datetime
from dataclasses import dataclass
from enum import Enum

class FileType(Enum):
    FILE = "file"
    DIRECTORY = "directory"


@dataclass
class FileMetadata:
    name: str
    size: int
    created_at: datetime
    modified_at: datetime
    permissions: str = "rw-r--r--"


class FileSystemNode(ABC):
    """Abstract file-system node."""
    
    @abstractmethod
    def get_name(self) -> str:
        pass
    
    @abstractmethod
    def get_size(self) -> int:
        pass
    
    @abstractmethod
    def get_type(self) -> FileType:
        pass
    
    @abstractmethod
    def list_contents(self, indent: int = 0) -> str:
        pass
    
    def get_path(self) -> str:
        return self.get_name()


class File(FileSystemNode):
    """File node (leaf)."""
    
    def __init__(
        self, 
        name: str, 
        content: str = "",
        permissions: str = "rw-r--r--"
    ):
        self._metadata = FileMetadata(
            name=name,
            size=len(content),
            created_at=datetime.now(),
            modified_at=datetime.now(),
            permissions=permissions
        )
        self._content = content
    
    def get_name(self) -> str:
        return self._metadata.name
    
    def get_size(self) -> int:
        return self._metadata.size
    
    def get_type(self) -> FileType:
        return FileType.FILE
    
    def get_metadata(self) -> FileMetadata:
        return self._metadata
    
    def read(self) -> str:
        return self._content
    
    def write(self, content: str) -> None:
        self._content = content
        self._metadata.size = len(content)
        self._metadata.modified_at = datetime.now()
    
    def append(self, content: str) -> None:
        self._content += content
        self._metadata.size = len(self._content)
        self._metadata.modified_at = datetime.now()
    
    def list_contents(self, indent: int = 0) -> str:
        prefix = "  " * indent
        size_kb = self._metadata.size / 1024
        return f"{prefix}📄 {self._metadata.name} ({size_kb:.1f} KB)"


class Directory(FileSystemNode):
    """Directory node (composite)."""
    
    def __init__(self, name: str, permissions: str = "rwxr-xr-x"):
        self._name = name
        self._permissions = permissions
        self._children: List[FileSystemNode] = []
        self._created_at = datetime.now()
        self._modified_at = datetime.now()
    
    def get_name(self) -> str:
        return self._name
    
    def get_size(self) -> int:
        return sum(child.get_size() for child in self._children)
    
    def get_type(self) -> FileType:
        return FileType.DIRECTORY
    
    def add(self, node: FileSystemNode) -> None:
        if node not in self._children:
            self._children.append(node)
            self._modified_at = datetime.now()
    
    def remove(self, node: FileSystemNode) -> None:
        if node in self._children:
            self._children.remove(node)
            self._modified_at = datetime.now()
    
    def get_children(self) -> List[FileSystemNode]:
        return self._children.copy()
    
    def find(self, name: str) -> Optional[FileSystemNode]:
        for child in self._children:
            if child.get_name() == name:
                return child
            if isinstance(child, Directory):
                # isinstance also tells a type checker that child has find()
                result = child.find(name)
                if result is not None:
                    return result
        return None
    
    def list_contents(self, indent: int = 0) -> str:
        prefix = "  " * indent
        lines = [f"{prefix}📁 {self._name}/"]
        for child in sorted(self._children, key=lambda x: x.get_name()):
            lines.append(child.list_contents(indent + 1))
        return "\n".join(lines)
    
    def tree(self, max_depth: int = -1, current_depth: int = 0, prefix: str = "") -> str:
        """Render a tree view; a negative max_depth means unlimited depth."""
        if max_depth >= 0 and current_depth > max_depth:
            return ""

        lines = []
        children = sorted(self._children, key=lambda x: x.get_name())

        for i, child in enumerate(children):
            is_last = i == len(children) - 1
            connector = "└── " if is_last else "├── "

            if isinstance(child, Directory):
                lines.append(f"{prefix}{connector}{child.get_name()}/")
                # Below the last entry, continue with blanks instead of a vertical bar
                extension = "    " if is_last else "│   "
                lines.append(child.tree(max_depth, current_depth + 1, prefix + extension))
            else:
                lines.append(f"{prefix}{connector}{child.get_name()}")

        return "\n".join(filter(None, lines))
    
    def get_file_count(self) -> int:
        count = 0
        for child in self._children:
            if child.get_type() == FileType.FILE:
                count += 1
            else:
                count += child.get_file_count()
        return count
    
    def get_directory_count(self) -> int:
        count = 0
        for child in self._children:
            if child.get_type() == FileType.DIRECTORY:
                count += 1 + child.get_directory_count()
        return count


def create_sample_filesystem() -> Directory:
    root = Directory("project")
    
    src = Directory("src")
    src.add(File("main.py", "print('Hello')"))
    src.add(File("utils.py", "def helper(): pass"))
    
    models = Directory("models")
    models.add(File("user.py", "class User: pass"))
    models.add(File("product.py", "class Product: pass"))
    src.add(models)
    
    tests = Directory("tests")
    tests.add(File("test_main.py", "def test_main(): pass"))
    tests.add(File("test_utils.py", "def test_utils(): pass"))
    
    docs = Directory("docs")
    docs.add(File("README.md", "# Project Documentation"))
    docs.add(File("API.md", "## API Reference"))
    
    root.add(src)
    root.add(tests)
    root.add(docs)
    root.add(File("requirements.txt", "pytest\nrequests"))
    root.add(File("setup.py", "# setup configuration"))
    
    return root


if __name__ == "__main__":
    fs = create_sample_filesystem()
    
    print("=== File system structure ===")
    print(fs.list_contents())
    
    print("\n=== Tree view ===")
    print("project/")
    print(fs.tree())
    
    print("\n=== Statistics ===")
    print(f"Total size: {fs.get_size() / 1024:.2f} KB")
    print(f"Files: {fs.get_file_count()}")
    print(f"Directories: {fs.get_directory_count()}")
    
    print("\n=== Find a file ===")
    found = fs.find("main.py")
    if found:
        print(f"Found: {found.get_name()}, size: {found.get_size()} bytes")
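`get_path()` above returns only the node's own name, because nodes keep no reference to their parent. One way to obtain full paths without adding parent pointers is to build each path top-down during traversal. The sketch below uses simplified, illustrative stand-ins for `File` and `Directory` and a hypothetical `walk_paths` generator:

```python
from typing import Iterator, List, Optional, Tuple, Union


class File:
    """Illustrative stand-in for the File class above."""

    def __init__(self, name: str, content: str = "") -> None:
        self.name = name
        self.size = len(content)


class Directory:
    """Illustrative stand-in for the Directory class above."""

    def __init__(self, name: str, children: Optional[List[Union[File, "Directory"]]] = None) -> None:
        self.name = name
        self.children = list(children) if children else []


def walk_paths(node: Union[File, Directory], parent: str = "") -> Iterator[Tuple[str, Union[File, Directory]]]:
    """Yield (path, node) pairs in pre-order, building each path from its parent's."""
    path = f"{parent}/{node.name}" if parent else node.name
    yield path, node
    if isinstance(node, Directory):
        for child in node.children:
            yield from walk_paths(child, path)


fs = Directory("project", [
    Directory("src", [File("main.py", "print('Hello')"),
                      Directory("models", [File("user.py", "class User: pass")])]),
    File("setup.py", "# setup configuration"),
])
for path, node in walk_paths(fs):
    if isinstance(node, File):
        print(f"{path} ({node.size} bytes)")
```

Passing the parent's path down as an argument keeps the nodes themselves free of back-references, at the cost of paths existing only during traversal.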

8.5.2 Organization Chart

python
from abc import ABC, abstractmethod
from typing import List, Optional, Dict, Any
from dataclasses import dataclass, field
from datetime import date
from enum import Enum

class DepartmentType(Enum):
    HEADQUARTERS = "Headquarters"
    DIVISION = "Division"
    DEPARTMENT = "Department"
    TEAM = "Team"


@dataclass
class EmployeeInfo:
    """Employee information."""
    id: str
    name: str
    position: str
    salary: float
    hire_date: date
    skills: List[str] = field(default_factory=list)
    
    def to_dict(self) -> Dict[str, Any]:
        return {
            "id": self.id,
            "name": self.name,
            "position": self.position,
            "salary": self.salary,
            "hire_date": self.hire_date.isoformat(),
            "skills": self.skills
        }


class OrganizationComponent(ABC):
    """Organization-chart component."""
    
    @abstractmethod
    def get_name(self) -> str:
        pass
    
    @abstractmethod
    def get_budget(self) -> float:
        pass
    
    @abstractmethod
    def get_employee_count(self) -> int:
        pass
    
    @abstractmethod
    def get_structure(self, indent: int = 0) -> str:
        pass
    
    @abstractmethod
    def to_dict(self) -> Dict[str, Any]:
        pass
    
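    def get_average_salary(self) -> float:
        """Budget per employee. For a Department the budget includes its own
        allocation, so this is NOT the mean salary (see get_salary_statistics)."""
        count = self.get_employee_count()
        return self.get_budget() / count if count > 0 else 0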


class Employee(OrganizationComponent):
    """Employee node (leaf)."""
    
    def __init__(self, info: EmployeeInfo):
        self._info = info
    
    @property
    def info(self) -> EmployeeInfo:
        return self._info
    
    def get_name(self) -> str:
        return f"{self._info.name} ({self._info.position})"
    
    def get_budget(self) -> float:
        return self._info.salary
    
    def get_employee_count(self) -> int:
        return 1
    
    def get_structure(self, indent: int = 0) -> str:
        prefix = "  " * indent
        return f"{prefix}👤 {self._info.name} - {self._info.position} (¥{self._info.salary:,.0f}/month)"
    
    def to_dict(self) -> Dict[str, Any]:
        return self._info.to_dict()


class Department(OrganizationComponent):
    """Department node (composite)."""
    
    def __init__(
        self, 
        name: str, 
        dept_type: DepartmentType,
        manager: Optional[str] = None,
        budget_allocation: float = 0
    ):
        self._name = name
        self._dept_type = dept_type
        self._manager = manager
        self._budget_allocation = budget_allocation
        self._components: List[OrganizationComponent] = []
    
    @property
    def dept_type(self) -> DepartmentType:
        return self._dept_type
    
    @property
    def manager(self) -> Optional[str]:
        return self._manager
    
    def get_name(self) -> str:
        return self._name
    
    def get_budget(self) -> float:
        return self._budget_allocation + sum(
            comp.get_budget() for comp in self._components
        )
    
    def get_employee_count(self) -> int:
        return sum(comp.get_employee_count() for comp in self._components)
    
    def add(self, component: OrganizationComponent) -> None:
        self._components.append(component)
    
    def remove(self, component: OrganizationComponent) -> None:
        self._components.remove(component)
    
    def get_components(self) -> List[OrganizationComponent]:
        return self._components.copy()
    
    def get_structure(self, indent: int = 0) -> str:
        prefix = "  " * indent
        lines = [f"{prefix}🏢 {self._name} [{self._dept_type.value}]"]
        if self._manager:
            lines[-1] += f" (manager: {self._manager})"
        
        for comp in self._components:
            lines.append(comp.get_structure(indent + 1))
        
        return "\n".join(lines)
    
    def to_dict(self) -> Dict[str, Any]:
        return {
            "name": self._name,
            "type": self._dept_type.value,
            "manager": self._manager,
            "budget": self.get_budget(),
            "employee_count": self.get_employee_count(),
            "children": [comp.to_dict() for comp in self._components]
        }
    
    def find_employee(self, name: str) -> Optional[Employee]:
        """Find an employee by name."""
        for comp in self._components:
            if isinstance(comp, Employee):
                if comp.info.name == name:
                    return comp
            elif isinstance(comp, Department):
                result = comp.find_employee(name)
                if result:
                    return result
        return None
    
    def get_all_employees(self) -> List[Employee]:
        """Collect all employees in this subtree."""
        employees = []
        for comp in self._components:
            if isinstance(comp, Employee):
                employees.append(comp)
            elif isinstance(comp, Department):
                employees.extend(comp.get_all_employees())
        return employees
    
    def get_salary_statistics(self) -> Dict[str, float]:
        """Salary statistics."""
        employees = self.get_all_employees()
        if not employees:
            return {"min": 0, "max": 0, "avg": 0, "total": 0}
        
        salaries = [e.info.salary for e in employees]
        return {
            "min": min(salaries),
            "max": max(salaries),
            "avg": sum(salaries) / len(salaries),
            "total": sum(salaries)
        }
    
    def get_skill_matrix(self) -> Dict[str, List[str]]:
        """Skill matrix: skill -> names of the employees who have it."""
        matrix: Dict[str, List[str]] = {}
        for emp in self.get_all_employees():
            for skill in emp.info.skills:
                if skill not in matrix:
                    matrix[skill] = []
                matrix[skill].append(emp.info.name)
        return matrix


def create_organization() -> Department:
    company = Department("Tech Innovation Co.", DepartmentType.HEADQUARTERS, "张总", 500000)
    
    tech_division = Department("Technology Division", DepartmentType.DIVISION, "李副总", 200000)
    
    backend_dept = Department("Backend Development", DepartmentType.DEPARTMENT, "王经理", 50000)
    backend_dept.add(Employee(EmployeeInfo(
        "E001", "赵工", "Senior Architect", 45000,
        date(2020, 3, 15), ["Python", "Go", "Kubernetes"]
    )))
    backend_dept.add(Employee(EmployeeInfo(
        "E002", "钱工", "Senior Engineer", 30000,
        date(2021, 6, 20), ["Python", "Django", "PostgreSQL"]
    )))
    backend_dept.add(Employee(EmployeeInfo(
        "E003", "孙工", "Engineer", 22000,
        date(2022, 1, 10), ["Python", "FastAPI"]
    )))
    
    frontend_dept = Department("Frontend Development", DepartmentType.DEPARTMENT, "周经理", 40000)
    frontend_dept.add(Employee(EmployeeInfo(
        "E004", "吴工", "Senior Frontend Engineer", 28000,
        date(2021, 4, 5), ["React", "TypeScript", "Node.js"]
    )))
    frontend_dept.add(Employee(EmployeeInfo(
        "E005", "郑工", "Frontend Engineer", 20000,
        date(2022, 8, 15), ["Vue", "JavaScript"]
    )))
    
    devops_team = Department("DevOps Team", DepartmentType.TEAM, "冯组长", 30000)
    devops_team.add(Employee(EmployeeInfo(
        "E006", "陈工", "DevOps Engineer", 25000,
        date(2021, 9, 1), ["Docker", "Kubernetes", "CI/CD"]
    )))
    
    tech_division.add(backend_dept)
    tech_division.add(frontend_dept)
    tech_division.add(devops_team)
    
    hr_division = Department("Human Resources", DepartmentType.DIVISION, "林经理", 80000)
    hr_division.add(Employee(EmployeeInfo(
        "E007", "黄专员", "HRBP", 15000,
        date(2022, 3, 20), ["Recruiting", "Training", "Performance Management"]
    )))
    hr_division.add(Employee(EmployeeInfo(
        "E008", "许专员", "Recruiter", 12000,
        date(2023, 1, 15), ["Recruiting", "Interviewing"]
    )))
    
    finance_division = Department("Finance", DepartmentType.DIVISION, "何经理", 60000)
    finance_division.add(Employee(EmployeeInfo(
        "E009", "沈会计", "Senior Accountant", 18000,
        date(2021, 5, 10), ["Financial Analysis", "Tax Planning"]
    )))
    finance_division.add(Employee(EmployeeInfo(
        "E010", "韩出纳", "Cashier", 10000,
        date(2023, 2, 1), ["Cash Management", "Bank Settlement"]
    )))
    
    company.add(tech_division)
    company.add(hr_division)
    company.add(finance_division)
    
    return company


if __name__ == "__main__":
    org = create_organization()
    
    print("=== Organization chart ===")
    print(org.get_structure())
    
    print("\n=== Statistics ===")
    print(f"Headcount: {org.get_employee_count()}")
    print(f"Total budget: ¥{org.get_budget():,.0f}")
    # get_budget() includes department allocations, so this is not the mean salary
    print(f"Budget per employee: ¥{org.get_average_salary():,.0f}")
    
    print("\n=== Salary statistics ===")
    stats = org.get_salary_statistics()
    print(f"Minimum salary: ¥{stats['min']:,.0f}")
    print(f"Maximum salary: ¥{stats['max']:,.0f}")
    print(f"Average salary: ¥{stats['avg']:,.0f}")
    
    print("\n=== Skill matrix ===")
    skills = org.get_skill_matrix()
    for skill, employees in sorted(skills.items()):
        print(f"{skill}: {', '.join(employees)}")
    
    print("\n=== Find an employee ===")
    emp = org.find_employee("赵工")
    if emp:
        print(f"Found: {emp.info.name}, position: {emp.info.position}")
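Because `get_budget()` adds each department's own `budget_allocation` to everything beneath it, dividing it by headcount yields budget per employee rather than the mean salary; the mean salary comes from summing salaries alone, as `get_salary_statistics()` does. The two recursive aggregations can be contrasted with simplified, illustrative stand-in classes and helper functions (`headcount`, `total_salary`, `total_budget` are hypothetical names; the figures are the backend department's from `create_organization()`):

```python
from typing import List, Union


class Employee:
    """Illustrative stand-in for the Employee leaf."""

    def __init__(self, name: str, salary: float) -> None:
        self.name = name
        self.salary = salary


class Department:
    """Illustrative stand-in for the Department composite."""

    def __init__(self, name: str, allocation: float = 0.0) -> None:
        self.name = name
        self.allocation = allocation  # department-level budget on top of salaries
        self.members: List[Union[Employee, "Department"]] = []

    def add(self, member: Union[Employee, "Department"]) -> "Department":
        self.members.append(member)
        return self  # allow chaining


def headcount(node: Union[Employee, Department]) -> int:
    if isinstance(node, Employee):
        return 1
    return sum(headcount(m) for m in node.members)


def total_salary(node: Union[Employee, Department]) -> float:
    if isinstance(node, Employee):
        return node.salary
    return sum(total_salary(m) for m in node.members)


def total_budget(node: Union[Employee, Department]) -> float:
    # Mirrors Department.get_budget(): own allocation plus everything below it.
    if isinstance(node, Employee):
        return node.salary
    return node.allocation + sum(total_budget(m) for m in node.members)


backend = (Department("Backend Development", 50000)
           .add(Employee("赵工", 45000))
           .add(Employee("钱工", 30000))
           .add(Employee("孙工", 22000)))
print(total_budget(backend) / headcount(backend))  # budget per employee: 49000.0
print(total_salary(backend) / headcount(backend))  # average salary: about 32333.33
```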

8.5.3 Graphics Editor

python
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, Dict, Any
from dataclasses import dataclass, field
from enum import Enum
import json

class ShapeType(Enum):
    CIRCLE = "circle"
    RECTANGLE = "rectangle"
    LINE = "line"
    GROUP = "group"


@dataclass
class Point:
    x: float
    y: float
    
    def translate(self, dx: float, dy: float) -> 'Point':
        return Point(self.x + dx, self.y + dy)
    
    def distance_to(self, other: 'Point') -> float:
        return ((self.x - other.x) ** 2 + (self.y - other.y) ** 2) ** 0.5


@dataclass
class Color:
    r: int
    g: int
    b: int
    a: float = 1.0
    
    def to_hex(self) -> str:
        return f"#{self.r:02x}{self.g:02x}{self.b:02x}"
    
    def to_rgba(self) -> str:
        return f"rgba({self.r}, {self.g}, {self.b}, {self.a})"


@dataclass
class Style:
    fill_color: Optional[Color] = None
    stroke_color: Optional[Color] = None
    stroke_width: float = 1.0
    
    def to_dict(self) -> Dict[str, Any]:
        return {
            "fill": self.fill_color.to_hex() if self.fill_color else None,
            "stroke": self.stroke_color.to_hex() if self.stroke_color else None,
            "stroke_width": self.stroke_width
        }


class Graphic(ABC):
    """Abstract graphic component."""
    
    @abstractmethod
    def draw(self) -> str:
        pass
    
    @abstractmethod
    def get_bounds(self) -> Tuple[Point, Point]:
        pass
    
    @abstractmethod
    def move(self, dx: float, dy: float) -> None:
        pass
    
    @abstractmethod
    def contains_point(self, point: Point) -> bool:
        pass
    
    @abstractmethod
    def get_type(self) -> ShapeType:
        pass
    
    @abstractmethod
    def to_dict(self) -> Dict[str, Any]:
        pass
    
    def get_center(self) -> Point:
        min_pt, max_pt = self.get_bounds()
        return Point(
            (min_pt.x + max_pt.x) / 2,
            (min_pt.y + max_pt.y) / 2
        )
    
    def get_width(self) -> float:
        min_pt, max_pt = self.get_bounds()
        return max_pt.x - min_pt.x
    
    def get_height(self) -> float:
        min_pt, max_pt = self.get_bounds()
        return max_pt.y - min_pt.y


class Circle(Graphic):
    """Circle (leaf)."""
    
    def __init__(self, center: Point, radius: float, style: Optional[Style] = None):
        self._center = center
        self._radius = radius
        self._style = style or Style()
    
    @property
    def center(self) -> Point:
        return self._center
    
    @property
    def radius(self) -> float:
        return self._radius
    
    def draw(self) -> str:
        return f"Circle(center=({self._center.x}, {self._center.y}), r={self._radius})"
    
    def get_bounds(self) -> Tuple[Point, Point]:
        return (
            Point(self._center.x - self._radius, self._center.y - self._radius),
            Point(self._center.x + self._radius, self._center.y + self._radius)
        )
    
    def move(self, dx: float, dy: float) -> None:
        self._center = self._center.translate(dx, dy)
    
    def contains_point(self, point: Point) -> bool:
        return self._center.distance_to(point) <= self._radius
    
    def get_type(self) -> ShapeType:
        return ShapeType.CIRCLE
    
    def to_dict(self) -> Dict[str, Any]:
        return {
            "type": "circle",
            "center": {"x": self._center.x, "y": self._center.y},
            "radius": self._radius,
            "style": self._style.to_dict()
        }
    
    def scale(self, factor: float) -> None:
        self._radius *= factor


class Rectangle(Graphic):
    """Axis-aligned rectangle (leaf), anchored at its top-left corner."""
    
    def __init__(
        self,
        top_left: Point,
        width: float,
        height: float,
        style: Optional[Style] = None
    ):
        self._top_left = top_left
        self._width = width
        self._height = height
        self._style = style or Style()
    
    @property
    def top_left(self) -> Point:
        return self._top_left
    
    @property
    def width(self) -> float:
        return self._width
    
    @property
    def height(self) -> float:
        return self._height
    
    def draw(self) -> str:
        return f"Rectangle(pos=({self._top_left.x}, {self._top_left.y}), size={self._width}x{self._height})"
    
    def get_bounds(self) -> Tuple[Point, Point]:
        return (
            self._top_left,
            Point(self._top_left.x + self._width, self._top_left.y + self._height)
        )
    
    def move(self, dx: float, dy: float) -> None:
        self._top_left = self._top_left.translate(dx, dy)
    
    def contains_point(self, point: Point) -> bool:
        min_pt, max_pt = self.get_bounds()
        return (min_pt.x <= point.x <= max_pt.x and 
                min_pt.y <= point.y <= max_pt.y)
    
    def get_type(self) -> ShapeType:
        return ShapeType.RECTANGLE
    
    def to_dict(self) -> Dict[str, Any]:
        return {
            "type": "rectangle",
            "top_left": {"x": self._top_left.x, "y": self._top_left.y},
            "width": self._width,
            "height": self._height,
            "style": self._style.to_dict()
        }


class Line(Graphic):
    """Line segment (leaf)."""
    
    def __init__(self, start: Point, end: Point, style: Optional[Style] = None):
        self._start = start
        self._end = end
        self._style = style or Style(stroke_color=Color(0, 0, 0))
    
    def draw(self) -> str:
        return f"Line(({self._start.x}, {self._start.y}) -> ({self._end.x}, {self._end.y}))"
    
    def get_bounds(self) -> Tuple[Point, Point]:
        return (
            Point(min(self._start.x, self._end.x), min(self._start.y, self._end.y)),
            Point(max(self._start.x, self._end.x), max(self._start.y, self._end.y))
        )
    
    def move(self, dx: float, dy: float) -> None:
        self._start = self._start.translate(dx, dy)
        self._end = self._end.translate(dx, dy)
    
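    def contains_point(self, point: Point, threshold: float = 5.0) -> bool:
        # Distance from the point to the *segment*, not to the infinite line:
        # project onto the segment and clamp the parameter t to [0, 1].
        dx = self._end.x - self._start.x
        dy = self._end.y - self._start.y
        length_sq = dx * dx + dy * dy
        if length_sq == 0:
            # Degenerate segment (start == end): treat it as a single point
            return self._start.distance_to(point) <= threshold
        t = ((point.x - self._start.x) * dx + (point.y - self._start.y) * dy) / length_sq
        t = max(0.0, min(1.0, t))
        nearest = Point(self._start.x + t * dx, self._start.y + t * dy)
        return nearest.distance_to(point) <= threshold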
    
    def get_type(self) -> ShapeType:
        return ShapeType.LINE
    
    def to_dict(self) -> Dict[str, Any]:
        return {
            "type": "line",
            "start": {"x": self._start.x, "y": self._start.y},
            "end": {"x": self._end.x, "y": self._end.y},
            "style": self._style.to_dict()
        }
    
    @property
    def length(self) -> float:
        return self._start.distance_to(self._end)


class GraphicGroup(Graphic):
    """Group of graphics (the Composite role)."""
    
    def __init__(self, name: str = "Group"):
        self._name = name
        self._graphics: List[Graphic] = []
        self._style = Style()
    
    @property
    def name(self) -> str:
        return self._name
    
    def add(self, graphic: Graphic) -> None:
        self._graphics.append(graphic)
    
    def remove(self, graphic: Graphic) -> None:
        self._graphics.remove(graphic)
    
    def get_graphics(self) -> List[Graphic]:
        return self._graphics.copy()
    
    def draw(self) -> str:
        results = [f"Group '{self._name}':"]
        for g in self._graphics:
            for line in g.draw().split('\n'):
                results.append(f"  {line}")
        return '\n'.join(results)
    
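    def get_bounds(self) -> Tuple[Point, Point]:
        # Ignore sub-groups that contain no shapes; their placeholder (0, 0) box
        # would otherwise stretch the bounds toward the origin.
        children = [
            g for g in self._graphics
            if not (isinstance(g, GraphicGroup) and not g.flatten())
        ]
        if not children:
            return (Point(0, 0), Point(0, 0))
        
        all_bounds = [g.get_bounds() for g in children]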
        min_x = min(b[0].x for b in all_bounds)
        min_y = min(b[0].y for b in all_bounds)
        max_x = max(b[1].x for b in all_bounds)
        max_y = max(b[1].y for b in all_bounds)
        
        return (Point(min_x, min_y), Point(max_x, max_y))
    
    def move(self, dx: float, dy: float) -> None:
        for graphic in self._graphics:
            graphic.move(dx, dy)
    
    def contains_point(self, point: Point) -> bool:
        return any(g.contains_point(point) for g in self._graphics)
    
    def get_type(self) -> ShapeType:
        return ShapeType.GROUP
    
    def to_dict(self) -> Dict[str, Any]:
        return {
            "type": "group",
            "name": self._name,
            "children": [g.to_dict() for g in self._graphics]
        }
    
    def flatten(self) -> List[Graphic]:
        """Flatten the group recursively and return every primitive (non-group) shape."""
        result = []
        for g in self._graphics:
            if g.get_type() == ShapeType.GROUP:
                result.extend(g.flatten())
            else:
                result.append(g)
        return result
    
    def find_at_point(self, point: Point) -> Optional[Graphic]:
        """Return the topmost shape containing the point (later children are checked first)."""
        for g in reversed(self._graphics):
            if g.contains_point(point):
                if g.get_type() == ShapeType.GROUP:
                    found = g.find_at_point(point)
                    if found:
                        return found
                return g
        return None
    
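    def scale(self, factor: float, center: Optional[Point] = None) -> None:
        """Scale the group about `center` (defaults to the group's own center)."""
        if center is None:
            center = self.get_center()
        
        def scale_point(p: Point) -> Point:
            # Move a point toward or away from the pivot by `factor`
            return Point(center.x + (p.x - center.x) * factor,
                         center.y + (p.y - center.y) * factor)
        
        for g in self._graphics:
            if isinstance(g, GraphicGroup):
                g.scale(factor, center)
            elif isinstance(g, Circle):
                # Reposition the center as well; scaling only the radius
                # would leave circles in place while the rest of the group moves
                g._center = scale_point(g._center)
                g.scale(factor)
            elif isinstance(g, Rectangle):
                g._top_left = scale_point(g._top_left)
                g._width *= factor
                g._height *= factor
            elif isinstance(g, Line):
                g._start = scale_point(g._start)
                g._end = scale_point(g._end)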


def create_drawing() -> GraphicGroup:
    canvas = GraphicGroup("Canvas")
    
    red = Color(255, 0, 0)
    blue = Color(0, 0, 255)
    green = Color(0, 255, 0)
    
    face = GraphicGroup("Face")
    face.add(Circle(Point(200, 200), 100, Style(fill_color=Color(255, 220, 180))))
    face.add(Circle(Point(165, 175), 10, Style(fill_color=blue)))
    face.add(Circle(Point(235, 175), 10, Style(fill_color=blue)))
    face.add(Rectangle(Point(175, 220), 50, 10, Style(fill_color=red)))
    
    house = GraphicGroup("House")
    house.add(Rectangle(Point(400, 250), 120, 100, Style(fill_color=Color(200, 150, 100))))
    house.add(Rectangle(Point(440, 280), 40, 70, Style(fill_color=Color(139, 90, 43))))
    house.add(Rectangle(Point(410, 270), 25, 25, Style(fill_color=Color(135, 206, 235))))
    house.add(Rectangle(Point(485, 270), 25, 25, Style(fill_color=Color(135, 206, 235))))
    
    canvas.add(face)
    canvas.add(house)
    canvas.add(Line(Point(50, 350), Point(550, 350), Style(stroke_color=green, stroke_width=2)))
    
    return canvas


if __name__ == "__main__":
    drawing = create_drawing()
    
    print("=== Graphic structure ===")
    print(drawing.draw())
    
    print("\n=== Bounds ===")
    bounds = drawing.get_bounds()
    print(f"Top-left: ({bounds[0].x}, {bounds[0].y})")
    print(f"Bottom-right: ({bounds[1].x}, {bounds[1].y})")
    
    print("\n=== Shape statistics ===")
    shapes = drawing.flatten()
    print(f"Primitive shapes: {len(shapes)}")
    for shape_type in ShapeType:
        if shape_type != ShapeType.GROUP:
            count = sum(1 for s in shapes if s.get_type() == shape_type)
            print(f"  {shape_type.value}: {count}")
    
    print("\n=== JSON export ===")
    print(json.dumps(drawing.to_dict(), indent=2, ensure_ascii=False))
    
    print("\n=== Hit test ===")
    test_point = Point(200, 200)
    found = drawing.find_at_point(test_point)
    if found:
        print(f"Found at ({test_point.x}, {test_point.y}): {found.draw()}")
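
Every query on `GraphicGroup` above has the shape given in 8.1.2: a leaf computes its own value, and a group combines its children's values with an aggregation operator ⊕. The sketch below factors that shape into one reusable `fold` function. It is a minimal illustration under stated assumptions: `Leaf`, `Group`, `fold`, and `union` are hypothetical names, and boxes are plain `(min_x, min_y, max_x, max_y)` tuples rather than the `Point` pairs used above. Choosing `None` as the identity element of the bounds union means an empty group contributes nothing instead of a `(0, 0)` box.

```python
from dataclasses import dataclass, field
from functools import reduce
from typing import Callable, List, Optional, Tuple, TypeVar, Union

T = TypeVar("T")
Box = Tuple[float, float, float, float]  # (min_x, min_y, max_x, max_y)


@dataclass
class Leaf:
    box: Box


@dataclass
class Group:
    children: List["Node"] = field(default_factory=list)


Node = Union[Leaf, Group]


def fold(node: Node, leaf_op: Callable[[Leaf], T],
         combine: Callable[[T, T], T], identity: T) -> T:
    """op(leaf) = leaf_op(leaf); op(group) = identity ⊕ op(c1) ⊕ ... ⊕ op(cn)."""
    if isinstance(node, Leaf):
        return leaf_op(node)
    child_values = (fold(c, leaf_op, combine, identity) for c in node.children)
    return reduce(combine, child_values, identity)


def union(a: Optional[Box], b: Optional[Box]) -> Optional[Box]:
    """Bounding-box union; None is the identity element (an empty group)."""
    if a is None:
        return b
    if b is None:
        return a
    return (min(a[0], b[0]), min(a[1], b[1]), max(a[2], b[2]), max(a[3], b[3]))


scene = Group([
    Leaf((100, 100, 300, 300)),
    Group([]),  # empty group: must not pull the bounds toward (0, 0)
    Group([Leaf((400, 250, 520, 350))]),
])

leaf_count = fold(scene, lambda leaf: 1, lambda a, b: a + b, 0)
bounds = fold(scene, lambda leaf: leaf.box, union, None)
print(leaf_count, bounds)  # 2 (100, 100, 520, 350)
```

Counting shapes and computing bounds differ only in the leaf function and the operator; any operator with an identity element (sum with 0, union with `None`, logical and with `True`) fits the same `fold`.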

8.6 Enterprise Application Examples

8.6.1 Configuration Management System

python
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Optional, Set
from dataclasses import dataclass, field
from copy import deepcopy
import json

class ConfigNode(ABC):
    """Abstract configuration node (the Component role)."""
    
    @abstractmethod
    def get(self, key: str, default: Any = None) -> Any:
        pass
    
    @abstractmethod
    def set(self, key: str, value: Any) -> None:
        pass
    
    @abstractmethod
    def to_dict(self) -> Dict[str, Any]:
        pass
    
    @abstractmethod
    def keys(self) -> Set[str]:
        pass
    
    def merge(self, other: 'ConfigNode') -> 'ConfigNode':
        raise NotImplementedError


@dataclass
class ConfigValue(ConfigNode):
    """Leaf node: a single configuration value plus optional metadata."""
    _value: Any
    _metadata: Dict[str, Any] = field(default_factory=dict)
    
    def get(self, key: str, default: Any = None) -> Any:
        if key == "value":
            return self._value
        return self._metadata.get(key, default)
    
    def set(self, key: str, value: Any) -> None:
        if key == "value":
            self._value = value
        else:
            self._metadata[key] = value
    
    def to_dict(self) -> Dict[str, Any]:
        result = {"value": self._value}
        if self._metadata:
            result["metadata"] = self._metadata
        return result
    
    def keys(self) -> Set[str]:
        return {"value"} | set(self._metadata.keys())
    
    @property
    def value(self) -> Any:
        return self._value


class ConfigSection(ConfigNode):
    """Section node (composite): named children that are values or sub-sections."""
    
    def __init__(self, name: str, parent: Optional['ConfigSection'] = None):
        self._name = name
        self._parent = parent
        self._children: Dict[str, ConfigNode] = {}
        self._inherit_from: List[str] = []
    
    @property
    def name(self) -> str:
        return self._name
    
    @property
    def path(self) -> str:
        if self._parent:
            return f"{self._parent.path}.{self._name}"
        return self._name
    
    def get(self, key: str, default: Any = None) -> Any:
        if '.' in key:
            first, rest = key.split('.', 1)
            child = self._children.get(first)
            if child:
                return child.get(rest, default)
            return default
        
        if key in self._children:
            child = self._children[key]
            if isinstance(child, ConfigValue):
                return child.value
            return child
        
        for inherit_path in self._inherit_from:
            inherited = self._resolve_inheritance(inherit_path)
            if inherited:
                result = inherited.get(key, None)
                if result is not None:
                    return result
        
        return default
    
    def set(self, key: str, value: Any) -> None:
        if '.' in key:
            first, rest = key.split('.', 1)
            if first not in self._children:
                self._children[first] = ConfigSection(first, self)
            self._children[first].set(rest, value)
        else:
            if isinstance(value, ConfigNode):
                self._children[key] = value
            else:
                self._children[key] = ConfigValue(value)
    
    def add_section(self, name: str) -> 'ConfigSection':
        section = ConfigSection(name, self)
        self._children[name] = section
        return section
    
    def add_inheritance(self, path: str) -> None:
        self._inherit_from.append(path)
    
    def _resolve_inheritance(self, path: str) -> Optional['ConfigSection']:
        parts = path.split('.')
        current: Optional[ConfigSection] = self._get_root()
        
        for part in parts:
            if current is None:
                return None
            child = current._children.get(part)
            if isinstance(child, ConfigSection):
                current = child
            else:
                return None
        
        return current
    
    def _get_root(self) -> 'ConfigSection':
        current = self
        while current._parent:
            current = current._parent
        return current
    
    def to_dict(self) -> Dict[str, Any]:
        result = {}
        for key, child in self._children.items():
            result[key] = child.to_dict()
        if self._inherit_from:
            result["_inherit"] = self._inherit_from
        return result
    
    def keys(self) -> Set[str]:
        result = set(self._children.keys())
        for inherit_path in self._inherit_from:
            inherited = self._resolve_inheritance(inherit_path)
            if inherited:
                result.update(inherited.keys())
        return result
    
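    def merge(self, other: ConfigNode) -> 'ConfigSection':
        """Merge `other` into this section in place; sub-sections merge recursively."""
        if isinstance(other, ConfigSection):
            for key, child in other._children.items():
                self_child = self._children.get(key)
                if isinstance(self_child, ConfigSection) and isinstance(child, ConfigSection):
                    self_child.merge(child)
                else:
                    self._children[key] = self._adopt(child)
        return self
    
    def _adopt(self, node: ConfigNode) -> ConfigNode:
        # deepcopy also copies the `_parent` chain of the source tree; re-parent
        # the copy so that `path` and inheritance lookups resolve in this tree.
        copied = deepcopy(node)
        if isinstance(copied, ConfigSection):
            copied._parent = self
        return copied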
    
    def flatten(self) -> Dict[str, Any]:
        """Flatten the section into dot-separated key/value pairs."""
        result = {}
        for key, child in self._children.items():
            if isinstance(child, ConfigSection):
                for sub_key, value in child.flatten().items():
                    result[f"{key}.{sub_key}"] = value
            elif isinstance(child, ConfigValue):
                result[key] = child.value
        return result
    
    def validate(self, schema: Dict[str, Any]) -> List[str]:
        """Check required keys and value types against a schema; return error messages."""
        errors = []
        for key, rules in schema.items():
            value = self.get(key)
            if rules.get("required", False) and value is None:
                errors.append(f"Missing required key: {key}")
            if value is not None and "type" in rules:
                if not isinstance(value, rules["type"]):
                    errors.append(f"Invalid type for {key}: expected {rules['type']}, got {type(value)}")
        return errors


def create_app_config() -> ConfigSection:
    root = ConfigSection("app")
    
    root.set("name", "MyApplication")
    root.set("version", "1.0.0")
    root.set("debug", False)
    
    database = root.add_section("database")
    database.set("host", "localhost")
    database.set("port", 5432)
    database.set("name", "myapp_db")
    database.set("pool_size", 10)
    
    cache = root.add_section("cache")
    cache.set("backend", "redis")
    cache.set("host", "localhost")
    cache.set("port", 6379)
    cache.set("ttl", 3600)
    
    logging_config = root.add_section("logging")
    logging_config.set("level", "INFO")
    logging_config.set("format", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    
    handlers = logging_config.add_section("handlers")
    handlers.set("console", {"enabled": True, "level": "DEBUG"})
    handlers.set("file", {"enabled": True, "level": "INFO", "path": "/var/log/app.log"})
    
    features = root.add_section("features")
    features.set("new_ui", True)
    features.set("beta_api", False)
    features.set("analytics", True)
    
    return root


if __name__ == "__main__":
    config = create_app_config()
    
    print("=== Configuration structure ===")
    print(json.dumps(config.to_dict(), indent=2))
    
    print("\n=== Reading values ===")
    print(f"Application name: {config.get('name')}")
    print(f"Database host: {config.get('database.host')}")
    print(f"Cache TTL: {config.get('cache.ttl')}")
    print(f"Log level: {config.get('logging.level')}")
    print(f"Missing key: {config.get('nonexistent.key', 'default')}")
    
    print("\n=== Flattened configuration ===")
    flat = config.flatten()
    for key, value in sorted(flat.items()):
        print(f"{key}: {value}")
    
    print("\n=== Validation ===")
    schema = {
        "name": {"required": True, "type": str},
        "version": {"required": True, "type": str},
        "database.host": {"required": True, "type": str},
        "database.port": {"required": True, "type": int},
    }
    errors = config.validate(schema)
    if errors:
        for error in errors:
            print(f"Error: {error}")
    else:
        print("Configuration is valid")
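
`ConfigSection.merge` applies one recursive rule: when both sides hold a section under the same key, merge the two sections; otherwise the incoming node replaces the existing one. A typical use is layering environment-specific overrides on top of base defaults. A minimal sketch of the same rule on plain nested dicts, assuming dicts play the role of sections and every other value is a leaf; `deep_merge` and the sample keys are hypothetical. Unlike `merge`, which mutates the receiving section, this version returns a new dict and leaves both inputs untouched.

```python
from copy import deepcopy
from typing import Any, Dict


def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict: nested dicts merge recursively, other values are replaced."""
    result = deepcopy(base)
    for key, value in override.items():
        if isinstance(result.get(key), dict) and isinstance(value, dict):
            # Both sides are "sections": recurse instead of overwriting
            result[key] = deep_merge(result[key], value)
        else:
            # A leaf (or a section/leaf mismatch): the override wins
            result[key] = deepcopy(value)
    return result


base = {
    "debug": False,
    "database": {"host": "localhost", "port": 5432, "pool_size": 10},
}
production = {
    "database": {"host": "db.internal", "pool_size": 50},
}

merged = deep_merge(base, production)
print(merged["database"])        # {'host': 'db.internal', 'port': 5432, 'pool_size': 50}
print(base["database"]["host"])  # localhost (inputs are not mutated)
```

Keys the override does not mention (`port`, `debug`) survive from the base, which is the property that makes layered configuration practical.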

8.6.2 Expression Parsing and Evaluation

python
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Optional, Callable, ClassVar
from dataclasses import dataclass
import math
import re

class Expression(ABC):
    """Abstract expression node (the Component role)."""
    
    @abstractmethod
    def evaluate(self, context: Dict[str, Any]) -> Any:
        pass
    
    @abstractmethod
    def to_string(self) -> str:
        pass
    
    @abstractmethod
    def get_variables(self) -> set:
        pass
    
    def __repr__(self) -> str:
        return self.to_string()


@dataclass
class Number(Expression):
    """Numeric literal (leaf)."""
    value: float
    
    def evaluate(self, context: Dict[str, Any]) -> float:
        return self.value
    
    def to_string(self) -> str:
        return str(self.value)
    
    def get_variables(self) -> set:
        return set()


@dataclass
class Variable(Expression):
    """Variable reference (leaf), resolved from the evaluation context."""
    name: str
    
    def evaluate(self, context: Dict[str, Any]) -> Any:
        if self.name not in context:
            raise NameError(f"Variable '{self.name}' not defined")
        return context[self.name]
    
    def to_string(self) -> str:
        return self.name
    
    def get_variables(self) -> set:
        return {self.name}


@dataclass
class BinaryOp(Expression):
    """Binary operation (composite with two operands)."""
    left: Expression
    operator: str
    right: Expression
    
    # ClassVar keeps these tables out of the dataclass fields; a plain dict
    # default here makes @dataclass raise ValueError (mutable default).
    _operators: ClassVar[Dict[str, Callable]] = {
        '+': lambda a, b: a + b,
        '-': lambda a, b: a - b,
        '*': lambda a, b: a * b,
        '/': lambda a, b: a / b if b != 0 else float('inf'),
        '%': lambda a, b: a % b,
        '**': lambda a, b: a ** b,
        '//': lambda a, b: a // b,
    }
    
    _precedence: ClassVar[Dict[str, int]] = {
        '+': 1, '-': 1,
        '*': 2, '/': 2, '%': 2, '//': 2,
        '**': 3,
    }
    
    def evaluate(self, context: Dict[str, Any]) -> Any:
        left_val = self.left.evaluate(context)
        right_val = self.right.evaluate(context)
        
        if self.operator not in self._operators:
            raise ValueError(f"Unknown operator: {self.operator}")
        
        return self._operators[self.operator](left_val, right_val)
    
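    def to_string(self) -> str:
        left_str = self.left.to_string()
        right_str = self.right.to_string()
        my_prec = self._precedence.get(self.operator, 0)
        # '**' is right-associative; every other operator is left-associative
        right_assoc = self.operator == '**'
        
        if isinstance(self.left, BinaryOp):
            left_prec = self._precedence.get(self.left.operator, 0)
            if left_prec < my_prec or (right_assoc and left_prec == my_prec):
                left_str = f"({left_str})"
        if isinstance(self.right, BinaryOp):
            right_prec = self._precedence.get(self.right.operator, 0)
            # Equal precedence on the right needs parentheses for
            # left-associative operators, e.g. a - (b - c)
            if right_prec < my_prec or (not right_assoc and right_prec == my_prec):
                right_str = f"({right_str})"
        
        return f"{left_str} {self.operator} {right_str}"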
    
    def get_variables(self) -> set:
        return self.left.get_variables() | self.right.get_variables()


@dataclass
class UnaryOp(Expression):
    """Unary operation (composite with a single operand)."""
    operator: str
    operand: Expression
    
    def evaluate(self, context: Dict[str, Any]) -> Any:
        val = self.operand.evaluate(context)
        if self.operator == '-':
            return -val
        elif self.operator == '+':
            return +val
        raise ValueError(f"Unknown unary operator: {self.operator}")
    
    def to_string(self) -> str:
        return f"{self.operator}{self.operand.to_string()}"
    
    def get_variables(self) -> set:
        return self.operand.get_variables()


@dataclass
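class FunctionCall(Expression):
    """Function call (composite whose children are the arguments)."""
    name: str
    args: List[Expression]
    
    # ClassVar for the same reason as BinaryOp._operators
    _functions: ClassVar[Dict[str, Callable]] = {
        'abs': abs,
        'round': round,
        'min': min,
        'max': max,
        # Variadic form sum(1, 2, 3); Python's built-in sum expects an iterable
        'sum': lambda *xs: sum(xs),
        'sqrt': math.sqrt,
        'sin': math.sin,
        'cos': math.cos,
    }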
    
    def evaluate(self, context: Dict[str, Any]) -> Any:
        if self.name not in self._functions:
            raise NameError(f"Unknown function: {self.name}")
        
        arg_values = [arg.evaluate(context) for arg in self.args]
        return self._functions[self.name](*arg_values)
    
    def to_string(self) -> str:
        args_str = ", ".join(arg.to_string() for arg in self.args)
        return f"{self.name}({args_str})"
    
    def get_variables(self) -> set:
        result = set()
        for arg in self.args:
            result.update(arg.get_variables())
        return result


class ExpressionParser:
    """Expression parser (precedence climbing) that builds the expression tree."""
    
    def __init__(self):
        self._tokens: List[str] = []
        self._pos: int = 0
    
    def parse(self, expr: str) -> Expression:
        self._tokens = self._tokenize(expr)
        self._pos = 0
        result = self._parse_expression()
        if self._pos < len(self._tokens):
            raise SyntaxError(f"Unexpected token: {self._tokens[self._pos]}")
        return result
    
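    def _tokenize(self, expr: str) -> List[str]:
        # Multi-character operators must come before single characters in the
        # alternation, otherwise '**' splits into '*', '*'. The ',' token
        # separates function arguments.
        pattern = r'\d+\.?\d*|\*\*|//|\w+|[+\-*/%(),]'
        tokens = re.findall(pattern, expr.replace(' ', ''))
        return tokens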
    
    def _parse_expression(self, min_prec: int = 0) -> Expression:
        left = self._parse_primary()
        
        while self._pos < len(self._tokens):
            op = self._tokens[self._pos]
            prec = BinaryOp._precedence.get(op, -1)
            
            if prec < min_prec:
                break
            
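            self._pos += 1
            # '**' is right-associative: let an equal-precedence operator bind on the right
            next_min = prec if op == '**' else prec + 1
            right = self._parse_expression(next_min)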
            left = BinaryOp(left, op, right)
        
        return left
    
    def _parse_primary(self) -> Expression:
        if self._pos >= len(self._tokens):
            raise SyntaxError("Unexpected end of expression")
        
        token = self._tokens[self._pos]
        
        if token == '(':
            self._pos += 1
            expr = self._parse_expression()
            if self._pos >= len(self._tokens) or self._tokens[self._pos] != ')':
                raise SyntaxError("Missing closing parenthesis")
            self._pos += 1
            return expr
        
        if token in ('+', '-'):
            self._pos += 1
            return UnaryOp(token, self._parse_primary())
        
        if re.match(r'^\d+\.?\d*$', token):
            self._pos += 1
            return Number(float(token))
        
        if re.match(r'^\w+$', token):
            self._pos += 1
            if self._pos < len(self._tokens) and self._tokens[self._pos] == '(':
                return self._parse_function_call(token)
            return Variable(token)
        
        raise SyntaxError(f"Unexpected token: {token}")
    
    def _parse_function_call(self, name: str) -> FunctionCall:
        self._pos += 1
        args = []
        
        if self._pos < len(self._tokens) and self._tokens[self._pos] != ')':
            args.append(self._parse_expression())
            while self._pos < len(self._tokens) and self._tokens[self._pos] == ',':
                self._pos += 1
                args.append(self._parse_expression())
        
        if self._pos >= len(self._tokens) or self._tokens[self._pos] != ')':
            raise SyntaxError("Missing closing parenthesis in function call")
        self._pos += 1
        
        return FunctionCall(name, args)


if __name__ == "__main__":
    parser = ExpressionParser()
    
    expressions = [
        "1 + 2 * 3",
        "(1 + 2) * 3",
        "x + y * z",
        "2 ** 3 + 4",
        "abs(-5) + max(1, 2, 3)",
        "sqrt(16) + sin(0)",
        "a * b + c * d",
    ]
    
    context = {
        'x': 10, 'y': 20, 'z': 5,
        'a': 2, 'b': 3, 'c': 4, 'd': 5,
    }
    
    print("=== Expression evaluation ===")
    for expr_str in expressions:
        expr = parser.parse(expr_str)
        result = expr.evaluate(context)
        variables = expr.get_variables()
        print(f"{expr_str:30} => {expr.to_string():30} = {result}")
        if variables:
            print(f"  variables: {variables}")
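
`_parse_expression` is a precedence-climbing parser: it parses one operand, then keeps absorbing operators whose precedence is at least `min_prec`, parsing each right-hand side with a raised minimum. Raising the minimum by one makes an operator left-associative; keeping it unchanged makes it right-associative, which is what `**` needs. A minimal self-contained sketch that evaluates directly instead of building nodes, restricted (as a simplifying assumption) to non-negative integer literals, parentheses, and the operators `+`, `-`, `*`, `**`; `tokenize` and `evaluate` are hypothetical helpers, not part of `ExpressionParser`.

```python
import re
from typing import List

PRECEDENCE = {"+": 1, "-": 1, "*": 2, "**": 3}
RIGHT_ASSOC = {"**"}


def tokenize(text: str) -> List[str]:
    # '**' is listed before the single-character class so it stays one token
    return re.findall(r"\d+|\*\*|[+\-*()]", text)


def evaluate(text: str) -> int:
    tokens = tokenize(text)
    pos = 0

    def primary() -> int:
        nonlocal pos
        tok = tokens[pos]
        pos += 1
        if tok == "(":
            value = climb(0)
            pos += 1  # skip ')'
            return value
        return int(tok)

    def climb(min_prec: int) -> int:
        nonlocal pos
        left = primary()
        while pos < len(tokens) and PRECEDENCE.get(tokens[pos], -1) >= min_prec:
            op = tokens[pos]
            pos += 1
            prec = PRECEDENCE[op]
            # Left-associative: the right side only takes tighter operators.
            # Right-associative: it may also take operators of equal precedence.
            right = climb(prec if op in RIGHT_ASSOC else prec + 1)
            if op == "+":
                left = left + right
            elif op == "-":
                left = left - right
            elif op == "*":
                left = left * right
            else:
                left = left ** right
        return left

    return climb(0)


print(evaluate("10 - 4 - 3"))   # 3   ((10 - 4) - 3)
print(evaluate("2 ** 3 ** 2"))  # 512 (2 ** (3 ** 2))
print(evaluate("(1 + 2) * 3"))  # 9
```

Building `BinaryOp` nodes instead of computing `left op right` on the spot turns this evaluator back into a tree-building parser; the control flow is identical.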

8.7 Pattern Variants and Extensions

8.7.1 Caching Composite (Memoization)

python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List
import hashlib
import json

class CacheableComponent(ABC):
    """Component whose result can be memoized in a shared cache dict."""
    
    @abstractmethod
    def compute(self) -> Any:
        pass
    
    @abstractmethod
    def get_cache_key(self) -> str:
        pass
    
    def compute_cached(self, cache: Dict[str, Any]) -> Any:
        key = self.get_cache_key()
        if key not in cache:
            cache[key] = self.compute()
        return cache[key]


class CacheableLeaf(CacheableComponent):
    """Cacheable leaf: applies an expensive function to its value."""
    
    def __init__(self, name: str, value: Any, expensive_func: Callable[[Any], Any]):
        self._name = name
        self._value = value
        self._func = expensive_func
    
    def compute(self) -> Any:
        return self._func(self._value)
    
    def get_cache_key(self) -> str:
        data = json.dumps({"name": self._name, "value": str(self._value)})
        return hashlib.md5(data.encode()).hexdigest()


class CacheableComposite(CacheableComponent):
    """Cacheable composite: aggregates its children's results."""
    
    def __init__(self, name: str, aggregator: Callable[[List[Any]], Any]):
        self._name = name
        self._aggregator = aggregator
        self._children: List[CacheableComponent] = []
    
    def add(self, component: CacheableComponent) -> None:
        self._children.append(component)
    
    def compute(self) -> Any:
        results = [child.compute() for child in self._children]
        return self._aggregator(results)
    
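    def compute_cached(self, cache: Dict[str, Any]) -> Any:
        # Cache the aggregated result under this composite's own key too; the
        # key is derived from the children's keys, so any change below yields
        # a new key rather than a stale hit.
        key = self.get_cache_key()
        if key not in cache:
            results = [child.compute_cached(cache) for child in self._children]
            cache[key] = self._aggregator(results)
        return cache[key]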
    
    def get_cache_key(self) -> str:
        child_keys = [child.get_cache_key() for child in self._children]
        data = json.dumps({"name": self._name, "children": child_keys})
        return hashlib.md5(data.encode()).hexdigest()
    
    def invalidate_cache(self, cache: Dict[str, Any]) -> None:
        """Remove the cached entries of this composite and all of its descendants."""
        cache.pop(self.get_cache_key(), None)
        for child in self._children:
            if isinstance(child, CacheableComposite):
                child.invalidate_cache(cache)
            else:
                cache.pop(child.get_cache_key(), None)
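
Because a composite's cache key is derived from its children's keys, changing any leaf changes every key on the path to the root, while untouched siblings keep their keys and still hit the cache. A minimal sketch that makes this observable with a call counter; `Leaf`, `Node`, `key`, `run`, and `expensive_square` are hypothetical stand-ins for the classes above, and SHA-256 is used in place of MD5 (both only serve as content fingerprints here).

```python
import hashlib
import json
from typing import Any, Callable, Dict, List

calls = {"expensive": 0}


def expensive_square(x: int) -> int:
    calls["expensive"] += 1  # count real computations
    return x * x


class Leaf:
    def __init__(self, name: str, value: int):
        self.name, self.value = name, value

    def key(self) -> str:
        payload = json.dumps({"name": self.name, "value": self.value})
        return hashlib.sha256(payload.encode()).hexdigest()

    def run(self, cache: Dict[str, Any]) -> int:
        k = self.key()
        if k not in cache:
            cache[k] = expensive_square(self.value)
        return cache[k]


class Node:
    def __init__(self, name: str, children: List[Any],
                 agg: Callable[[List[int]], int] = sum):
        self.name, self.children, self.agg = name, children, agg

    def key(self) -> str:
        # Content-derived: changes whenever any descendant changes
        payload = json.dumps({"name": self.name,
                              "children": [c.key() for c in self.children]})
        return hashlib.sha256(payload.encode()).hexdigest()

    def run(self, cache: Dict[str, Any]) -> int:
        k = self.key()
        if k not in cache:
            cache[k] = self.agg([c.run(cache) for c in self.children])
        return cache[k]


cache: Dict[str, Any] = {}
a, b = Leaf("a", 3), Leaf("b", 4)
root = Node("root", [a, Node("inner", [b])])

print(root.run(cache), calls["expensive"])  # 25 2
print(root.run(cache), calls["expensive"])  # 25 2  (every lookup hits)
b.value = 5
print(root.run(cache), calls["expensive"])  # 34 3  (only leaf b recomputed)
```

The entries for the old value of `b` stay in the dict (seven entries after the change). Content-derived keys never return a stale result, but they do accumulate garbage, so long-running processes usually need a bounded cache or periodic cleanup.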

8.7.2 Observer Composite (Event Propagation)

python
from abc import ABC, abstractmethod
from typing import List, Callable, Dict, Any
from dataclasses import dataclass
from enum import Enum

class EventType(Enum):
    ADD = "add"
    REMOVE = "remove"
    UPDATE = "update"
    MOVE = "move"


@dataclass
class Event:
    type: EventType
    source: Any
    data: Dict[str, Any]


class ObservableComponent(ABC):
    """Observable component: holds event listeners and notifies them."""
    
    def __init__(self):
        self._listeners: List[Callable[[Event], None]] = []
    
    def add_listener(self, listener: Callable[[Event], None]) -> None:
        self._listeners.append(listener)
    
    def remove_listener(self, listener: Callable[[Event], None]) -> None:
        self._listeners.remove(listener)
    
    def _emit(self, event: Event) -> None:
        for listener in self._listeners:
            listener(event)
    
    @abstractmethod
    def operation(self) -> Any:
        pass


class ObservableLeaf(ObservableComponent):
    """Observable leaf: emits an UPDATE event when its value changes."""
    
    def __init__(self, name: str, value: Any):
        super().__init__()
        self._name = name
        self._value = value
    
    @property
    def value(self) -> Any:
        return self._value
    
    @value.setter
    def value(self, new_value: Any) -> None:
        old_value = self._value
        self._value = new_value
        self._emit(Event(
            type=EventType.UPDATE,
            source=self,
            data={"old": old_value, "new": new_value}
        ))
    
    def operation(self) -> Any:
        return self._value


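class ObservableComposite(ObservableComponent):
    """Observable composite: emits ADD/REMOVE and re-emits its children's events."""
    
    def __init__(self, name: str):
        super().__init__()
        self._name = name
        self._children: List[ObservableComponent] = []
    
    def add(self, component: ObservableComponent) -> None:
        self._children.append(component)
        # Subscribe to the child so its events propagate to this composite's listeners
        component.add_listener(self._emit)
        self._emit(Event(
            type=EventType.ADD,
            source=self,
            data={"component": component}
        ))
    
    def remove(self, component: ObservableComponent) -> None:
        self._children.remove(component)
        # Bound methods of the same object compare equal, so this removes
        # the listener registered in add()
        component.remove_listener(self._emit)
        self._emit(Event(
            type=EventType.REMOVE,
            source=self,
            data={"component": component}
        ))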
    
    def operation(self) -> Any:
        return [child.operation() for child in self._children]
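
For an event to reach listeners registered higher up the tree, each composite can subscribe to its children and re-emit what it receives. A minimal sketch of that bubbling in which every level also prepends its own name, so a listener on the root sees the full path of the change; `Node`, `on`, `emit`, `set_value`, and the `path` field are hypothetical names, not part of the classes above.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Tuple


@dataclass(frozen=True)
class Event:
    kind: str
    path: Tuple[str, ...]  # names from the receiving level down to the source
    data: Any = None


class Node:
    def __init__(self, name: str, value: Any = None):
        self.name = name
        self.value = value
        self.children: List["Node"] = []
        self._listeners: List[Callable[[Event], None]] = []

    def on(self, listener: Callable[[Event], None]) -> None:
        self._listeners.append(listener)

    def emit(self, event: Event) -> None:
        for listener in self._listeners:
            listener(event)

    def add(self, child: "Node") -> "Node":
        self.children.append(child)
        # Subscribe to the child; re-emit its events with this node's name prepended
        child.on(lambda event: self.emit(
            Event(event.kind, (self.name,) + event.path, event.data)))
        return child

    def set_value(self, new_value: Any) -> None:
        old, self.value = self.value, new_value
        self.emit(Event("update", (self.name,), {"old": old, "new": new_value}))


received: List[Event] = []
root = Node("settings")
audio = root.add(Node("audio"))
volume = audio.add(Node("volume", 5))
root.on(received.append)

volume.set_value(8)
print(received[0].path, received[0].data)
# ('settings', 'audio', 'volume') {'old': 5, 'new': 8}
```

Since every level contributes its own name, a root listener can route on the path (for example, react only to changes under `audio`) without holding references to individual leaves.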

8.7.3 Persistent Composite (Serialization)

python
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Type
import json

class SerializableComponent(ABC):
    """Serializable component: converts to and from dicts and JSON."""
    
    @abstractmethod
    def to_dict(self) -> Dict[str, Any]:
        pass
    
    @classmethod
    @abstractmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'SerializableComponent':
        pass
    
    def to_json(self, indent: int = 2) -> str:
        return json.dumps(self.to_dict(), indent=indent, ensure_ascii=False)
    
    @classmethod
    def from_json(cls, json_str: str) -> 'SerializableComponent':
        data = json.loads(json_str)
        return cls.from_dict(data)


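class SerializableLeaf(SerializableComponent):
    """Serializable leaf with a registry for custom leaf subclasses."""
    
    _registry: Dict[str, Type['SerializableLeaf']] = {}
    _type_name: str = "SerializableLeaf"
    
    def __init__(self, name: str, value: Any):
        self._name = name
        self._value = value
    
    @classmethod
    def register(cls, type_name: str):
        def decorator(subclass):
            # Store the tag on the subclass so to_dict() writes the same
            # string that from_dict() looks up
            subclass._type_name = type_name
            cls._registry[type_name] = subclass
            return subclass
        return decorator
    
    def to_dict(self) -> Dict[str, Any]:
        return {
            "type": self._type_name,
            "name": self._name,
            "value": self._value
        }
    
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'SerializableLeaf':
        type_name = data.get("type", "SerializableLeaf")
        target = cls._registry.get(type_name, cls)
        # Delegate only to a *different* class: re-dispatching to the same class
        # would recurse forever when a subclass does not override from_dict
        if target is not cls:
            return target.from_dict(data)
        return cls(data["name"], data["value"])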


class SerializableComposite(SerializableComponent):
    """Serializable composite: children are serialized recursively."""
    
    def __init__(self, name: str):
        self._name = name
        self._children: List[SerializableComponent] = []
    
    def add(self, component: SerializableComponent) -> None:
        self._children.append(component)
    
    def to_dict(self) -> Dict[str, Any]:
        return {
            "type": "Composite",
            "name": self._name,
            "children": [child.to_dict() for child in self._children]
        }
    
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'SerializableComposite':
        composite = cls(data["name"])
        for child_data in data.get("children", []):
            if child_data.get("type") == "Composite":
                child = cls.from_dict(child_data)
            else:
                child = SerializableLeaf.from_dict(child_data)
            composite.add(child)
        return composite
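
SerializableComposite.to_dict always writes the tag "Composite", so a subclass of it comes back as the base class after a round trip. A minimal sketch of one alternative, assuming a single module-level registry shared by leaves and composites; `NODE_TYPES`, `node_type`, `Node`, `Item`, `Group`, `Folder`, and the module-level `from_dict` are hypothetical names, not part of the classes above:

```python
import json
from typing import Any, Callable, Dict, List, Type

# Hypothetical registry: type tag -> node class (leaves and composites alike)
NODE_TYPES: Dict[str, Type["Node"]] = {}


def node_type(tag: str) -> Callable[[Type["Node"]], Type["Node"]]:
    """Class decorator: store the class under `tag` and remember the tag."""
    def decorator(klass: Type["Node"]) -> Type["Node"]:
        klass.tag = tag
        NODE_TYPES[tag] = klass
        return klass
    return decorator


class Node:
    tag = "node"

    def __init__(self, name: str) -> None:
        self.name = name

    def to_dict(self) -> Dict[str, Any]:
        return {"type": self.tag, "name": self.name}


@node_type("item")
class Item(Node):
    """Leaf carrying a JSON-compatible value."""

    def __init__(self, name: str, value: Any = None) -> None:
        super().__init__(name)
        self.value = value

    def to_dict(self) -> Dict[str, Any]:
        return {**super().to_dict(), "value": self.value}


@node_type("group")
class Group(Node):
    """Composite; subclasses registered under their own tag keep their type."""

    def __init__(self, name: str) -> None:
        super().__init__(name)
        self.children: List[Node] = []

    def add(self, child: Node) -> "Group":
        self.children.append(child)
        return self

    def to_dict(self) -> Dict[str, Any]:
        return {**super().to_dict(),
                "children": [c.to_dict() for c in self.children]}


@node_type("folder")
class Folder(Group):
    """A composite subclass, used to show that its type survives a round trip."""


def from_dict(data: Dict[str, Any]) -> Node:
    """Rebuild a tree, dispatching on the "type" tag at every level."""
    klass = NODE_TYPES[data["type"]]
    if issubclass(klass, Group):
        group = klass(data["name"])
        for child in data.get("children", []):
            group.add(from_dict(child))
        return group
    return klass(data["name"], data.get("value"))


root = Folder("root").add(Item("a", 1)).add(Group("g").add(Item("b", [2, 3])))
clone = from_dict(json.loads(json.dumps(root.to_dict())))
print(type(clone).__name__)  # Folder
```

Because the tag is looked up at every level, `Folder` is rebuilt as `Folder` rather than as its base class. The trade-off is that every leaf constructor must accept `(name, value)` and every composite constructor must accept `name`.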

8.8 Anti-Patterns and Best Practices

8.8.1 Common Anti-Patterns

Anti-pattern 1: Excessive nesting

python
from typing import Any, Dict


class BadNestedComposite:
    """Anti-pattern: excessive nesting"""
    
    def __init__(self):
        self.level1 = {
            "level2": {
                "level3": {
                    "level4": {
                        "level5": {
                            "value": "too deep"
                        }
                    }
                }
            }
        }
    
    def get_value(self):
        return self.level1["level2"]["level3"]["level4"]["level5"]["value"]


class GoodFlattenedComposite:
    """Recommended: a flat mapping keyed by path"""
    
    def __init__(self):
        self._nodes: Dict[str, Any] = {}
    
    def add(self, path: str, value: Any) -> None:
        self._nodes[path] = value
    
    def get(self, path: str, default: Any = None) -> Any:
        return self._nodes.get(path, default)
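
GoodFlattenedComposite assumes the caller already has path keys. When data arrives as nested dicts, a small helper can produce them; this is a sketch, and both the `flatten_paths` name and the "/" separator are assumptions:

```python
from typing import Any, Dict


def flatten_paths(nested: Dict[str, Any], sep: str = "/",
                  prefix: str = "") -> Dict[str, Any]:
    """Turn {"a": {"b": 1}} into {"a/b": 1}; non-dict values become leaves."""
    flat: Dict[str, Any] = {}
    for key, value in nested.items():
        path = f"{prefix}{sep}{key}" if prefix else key
        if isinstance(value, dict) and value:
            # Recurse into non-empty sub-dicts, extending the path
            flat.update(flatten_paths(value, sep, path))
        else:
            # Scalars, lists, and empty dicts are stored as-is
            flat[path] = value
    return flat


nested = {"level1": {"level2": {"level3": {"value": "too deep"}}}, "top": 1}
print(flatten_paths(nested))
# {'level1/level2/level3/value': 'too deep', 'top': 1}
```

Each resulting entry can then be passed to `GoodFlattenedComposite.add(path, value)`.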

Anti-pattern 2: Circular references

python
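from typing import List, Optional, Set


class CircularComposite:
    """Anti-pattern: circular references"""
    
    def __init__(self, name: str):
        self.name = name
        self.children: List['CircularComposite'] = []
    
    def add_circular(self, other: 'CircularComposite') -> None:
        # Each node becomes the other's child: the structure is no longer a tree
        self.children.append(other)
        other.children.append(self)


def detect_cycle(node: 'CircularComposite', visited: Optional[Set[int]] = None) -> bool:
    """Detect a cycle with depth-first search.

    visited holds the ids of the nodes on the current path only; each id is
    removed on the way back up, so a subtree shared by two parents is not
    mistaken for a cycle.
    """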
    if visited is None:
        visited = set()
    
    if id(node) in visited:
        return True
    
    visited.add(id(node))
    for child in node.children:
        if detect_cycle(child, visited):
            return True
    visited.remove(id(node))
    return False
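
detect_cycle finds a cycle after it already exists. A complementary sketch rejects the offending edge inside `add()` instead, enforcing the acyclicity constraint ancestors(c) ∩ descendants(c) = ∅ from 8.1.2 at insertion time. It assumes each node stores a `parent` reference; `TreeNode`, `ancestors`, and `add` are illustrative names:

```python
from typing import Iterator, List, Optional


class TreeNode:
    """Composite node whose add() refuses edges that would close a cycle."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.parent: Optional["TreeNode"] = None
        self.children: List["TreeNode"] = []

    def ancestors(self) -> Iterator["TreeNode"]:
        """Yield the parent chain, nearest parent first, up to the root."""
        node = self.parent
        while node is not None:
            yield node
            node = node.parent

    def add(self, child: "TreeNode") -> None:
        # Making self or one of its ancestors a child would create a loop
        if child is self or any(a is child for a in self.ancestors()):
            raise ValueError(
                f"adding {child.name!r} to {self.name!r} would create a cycle")
        # Keep a single parent per node: detach from any previous parent
        if child.parent is not None:
            child.parent.children.remove(child)
        child.parent = self
        self.children.append(child)


a, b = TreeNode("a"), TreeNode("b")
a.add(b)
try:
    b.add(a)
except ValueError as exc:
    print(exc)  # adding 'a' to 'b' would create a cycle
```

The ancestor check costs O(depth) per insertion, and because every node has at most one parent, the structure stays a tree rather than a general graph.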

Anti-pattern 3: God component

python
from typing import List


class GodComponent:
    """God-component anti-pattern: takes on too many responsibilities"""
    
    def __init__(self):
        self.children = []
        self.database_connection = None
        self.cache = {}
        self.logger = None
        self.config = {}
        self.event_handlers = []
    
    def operation(self): pass
    def save_to_db(self): pass
    def load_from_cache(self): pass
    def log_activity(self): pass
    def validate_config(self): pass
    def handle_event(self): pass


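class FocusedComponent:
    """Recommended: a component with a single responsibility"""
    
    def __init__(self, name: str):
        self._name = name
        self._children: List['FocusedComponent'] = []
    
    def add(self, child: 'FocusedComponent') -> None:
        self._children.append(child)
    
    def operation(self) -> str:
        # Only the composite operation lives here; persistence, caching,
        # and logging belong in separate collaborators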
        results = [child.operation() for child in self._children]
        return f"{self._name}: {results}"
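
One way to keep FocusedComponent focused is to attach cross-cutting concerns from the outside, in the style of the Decorator pattern listed in 8.9.1. This sketch wraps a tree with logging; `Component`, `Leaf`, `Group`, and `LoggingDecorator` are hypothetical names, and the standard `logging` module stands in for GodComponent's `logger` field:

```python
import logging
from typing import List, Protocol


class Component(Protocol):
    def operation(self) -> str: ...


class Leaf:
    def __init__(self, name: str) -> None:
        self.name = name

    def operation(self) -> str:
        return self.name


class Group:
    def __init__(self, name: str, children: List[Component]) -> None:
        self.name = name
        self.children = children

    def operation(self) -> str:
        inner = ", ".join(child.operation() for child in self.children)
        return f"{self.name}({inner})"


class LoggingDecorator:
    """Adds logging around any Component without changing its class."""

    def __init__(self, inner: Component, logger: logging.Logger) -> None:
        self._inner = inner
        self._logger = logger

    def operation(self) -> str:
        result = self._inner.operation()
        self._logger.info("operation -> %s", result)
        return result


tree = Group("root", [Leaf("a"), Group("sub", [Leaf("b")])])
logged = LoggingDecorator(tree, logging.getLogger("composite"))
print(logged.operation())  # root(a, sub(b))
```

The wrapper satisfies the same `Component` protocol, so it can be placed anywhere a node is expected, including inside another `Group`.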

8.8.2 Best Practices

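| Practice | Description | Example |
|---|---|---|
| Minimal interface | Component declares only the methods every node needs | Declare only operation() |
| Type checking | Use is_composite() | Run-time type test with isinstance() |
| Defensive copying | Return a copy of the child collection | get_children() returns a copy |
| Iterator support | Implement __iter__ for traversal | for child in composite |
| Context management | Implement __enter__/__exit__ | Resource management |
| Immutable design | Make leaf nodes immutable | @dataclass(frozen=True) |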
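
Several rows of the table fit in one small sketch: a frozen-dataclass leaf, a `get_children()` that returns a tuple copy, `is_composite()`, and an `__iter__` that walks the tree. `File`, `Folder`, and `total_size` are illustrative names for a file-size aggregation, in which the aggregation operator ⊕ from 8.1.2 is addition:

```python
from dataclasses import dataclass, field
from typing import Iterator, List, Tuple, Union


@dataclass(frozen=True)
class File:
    """Immutable leaf: assigning to a field raises FrozenInstanceError."""
    name: str
    size: int

    def is_composite(self) -> bool:
        return False

    def total_size(self) -> int:
        return self.size


@dataclass
class Folder:
    """Mutable composite holding File and Folder children."""
    name: str
    _children: List["Node"] = field(default_factory=list, init=False)

    def is_composite(self) -> bool:
        return True

    def add(self, node: "Node") -> "Folder":
        self._children.append(node)
        return self  # allows chaining

    def get_children(self) -> Tuple["Node", ...]:
        # Defensive copy: a tuple snapshot, so callers cannot mutate the list
        return tuple(self._children)

    def total_size(self) -> int:
        # op(k) = sum of op(c) over the children of k
        return sum(child.total_size() for child in self._children)

    def __iter__(self) -> Iterator["Node"]:
        # Depth-first pre-order over all descendants (excluding self)
        for child in self._children:
            yield child
            if child.is_composite():
                yield from child


Node = Union[File, Folder]


root = Folder("root").add(File("a.txt", 10)).add(Folder("docs").add(File("b.txt", 5)))
print(root.total_size())          # 15
print([n.name for n in root])     # ['a.txt', 'docs', 'b.txt']
```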

8.9 Pattern Comparison

8.9.1 Comparison with Related Patterns

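| Pattern | Intent | Structure | Key difference |
|---|---|---|---|
| Composite | Represent part-whole hierarchies | Recursive tree | Treats individual and composite objects uniformly |
| Decorator | Add responsibilities dynamically | Linear wrapping | Does not change what the wrapped object is |
| Chain of Responsibility | Pass a request along for handling | Chained forwarding | Focuses on the handling flow |
| Iterator | Traverse an aggregate object | Sequential traversal | Focuses on how elements are accessed |
| Visitor | Separate operations from structure | Double dispatch | Focuses on adding new operations |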

8.9.2 How Composite Collaborates with Other Patterns

┌─────────────────────────────────────────────────────────────────────────────┐
│                      Composite Pattern Collaborations                       │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  ┌───────────────────┐     ┌───────────────────┐     ┌───────────────────┐  │
│  │     Composite     │────>│      Visitor      │────>│   Separated ops   │  │
│  └───────────────────┘     └───────────────────┘     └───────────────────┘  │
│            │                                                                │
│            ▼                                                                │
│  ┌───────────────────┐     ┌───────────────────┐     ┌───────────────────┐  │
│  │     Iterator      │────>│     Traversal     │────>│  Uniform access   │  │
│  └───────────────────┘     └───────────────────┘     └───────────────────┘  │
│            │                                                                │
│            ▼                                                                │
│  ┌───────────────────┐     ┌───────────────────┐     ┌───────────────────┐  │
│  │     Decorator     │────>│   Added duties    │────>│ Dynamic extension │  │
│  └───────────────────┘     └───────────────────┘     └───────────────────┘  │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
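
The Composite-plus-Visitor collaboration at the top of the diagram can be sketched with double dispatch: each node calls the visitor method for its own type, so a new operation is a new visitor class rather than a new method on every node class. `Leaf`, `Branch`, `Visitor`, `SizeVisitor`, and `RenderVisitor` are illustrative names:

```python
from typing import List, Union


class Leaf:
    def __init__(self, name: str, size: int) -> None:
        self.name, self.size = name, size

    def accept(self, visitor: "Visitor"):
        # Double dispatch: the node picks the visitor method for its type
        return visitor.visit_leaf(self)


class Branch:
    def __init__(self, name: str, children: List["Node"]) -> None:
        self.name, self.children = name, children

    def accept(self, visitor: "Visitor"):
        return visitor.visit_branch(self)


Node = Union[Leaf, Branch]


class Visitor:
    """One method per node type; each subclass is one operation."""

    def visit_leaf(self, leaf: Leaf):
        raise NotImplementedError

    def visit_branch(self, branch: Branch):
        raise NotImplementedError


class SizeVisitor(Visitor):
    """Aggregates by addition: op(k) = sum of op(c) over children(k)."""

    def visit_leaf(self, leaf: Leaf) -> int:
        return leaf.size

    def visit_branch(self, branch: Branch) -> int:
        return sum(child.accept(self) for child in branch.children)


class RenderVisitor(Visitor):
    """Indented outline, added without modifying Leaf or Branch."""

    def __init__(self) -> None:
        self.depth = 0

    def visit_leaf(self, leaf: Leaf) -> str:
        return "  " * self.depth + leaf.name

    def visit_branch(self, branch: Branch) -> str:
        lines = ["  " * self.depth + branch.name + "/"]
        self.depth += 1
        lines += [child.accept(self) for child in branch.children]
        self.depth -= 1
        return "\n".join(lines)


tree = Branch("root", [Leaf("a", 3), Branch("sub", [Leaf("b", 4)])])
print(tree.accept(SizeVisitor()))    # 7
print(tree.accept(RenderVisitor()))
```

`RenderVisitor` was added without editing `Leaf` or `Branch`; the usual Visitor trade-off is that adding a new node type means touching every visitor.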

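8.10 Decision Guide

8.10.1 Applicability Checklist

  • [ ] You need to represent a "part-whole" hierarchy
  • [ ] You want to treat individual objects and compositions uniformly
  • [ ] The structure is recursive
  • [ ] Client code should be able to ignore the difference between leaves and composites
  • [ ] Objects must be composed dynamically

8.10.2 Implementation Decision Tree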

                    Need type safety?
                    /          \
                  Yes          No
                  /              \
             Safe      Need add/remove on Component?
                              /          \
                            Yes          No
                            /              \
                 Transparent               Safe
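
The two outcomes of the decision tree differ mainly in where the child-management methods live. A side-by-side sketch with illustrative class names (`TComponent`/`TLeaf`/`TComposite` for the transparent variant, `SComponent`/`SLeaf`/`SComposite` for the safe one): the transparent leaf inherits `add` and rejects it at run time, while the safe leaf has no `add` at all, so a static checker such as mypy can report `leaf.add(...)` before the program runs.

```python
from abc import ABC, abstractmethod
from typing import List


# Transparent variant: child management is part of the common interface
class TComponent(ABC):
    @abstractmethod
    def operation(self) -> int: ...

    def add(self, child: "TComponent") -> None:
        # Leaves inherit this default and reject children at run time
        raise TypeError(f"{type(self).__name__} cannot have children")


class TLeaf(TComponent):
    def __init__(self, value: int) -> None:
        self.value = value

    def operation(self) -> int:
        return self.value


class TComposite(TComponent):
    def __init__(self) -> None:
        self.children: List[TComponent] = []

    def add(self, child: TComponent) -> None:
        self.children.append(child)

    def operation(self) -> int:
        return sum(c.operation() for c in self.children)


# Safe variant: only the composite declares add()
class SComponent(ABC):
    @abstractmethod
    def operation(self) -> int: ...


class SLeaf(SComponent):
    def __init__(self, value: int) -> None:
        self.value = value

    def operation(self) -> int:
        return self.value


class SComposite(SComponent):
    def __init__(self) -> None:
        self.children: List[SComponent] = []

    def add(self, child: SComponent) -> None:
        self.children.append(child)

    def operation(self) -> int:
        return sum(c.operation() for c in self.children)


t = TComposite()
t.add(TLeaf(2))
t.add(TLeaf(3))
print(t.operation())  # 5
try:
    TLeaf(1).add(TLeaf(2))
except TypeError as exc:
    print(exc)  # TLeaf cannot have children
print(hasattr(SLeaf(1), "add"))  # False
```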

8.10.3 Quick Reference Card

┌───────────────────────────────────────────────────────────────┐
│               Composite Pattern Quick Reference               │
├───────────────────────────────────────────────────────────────┤
│ Definition: compose objects into tree structures and treat    │
│             individual and composite objects uniformly        │
├───────────────────────────────────────────────────────────────┤
│ Participants:                                                 │
│   • Component   - common interface                            │
│   • Leaf        - leaf node, no children                      │
│   • Composite   - node that holds child components            │
│   • Client      - works through the Component interface       │
├───────────────────────────────────────────────────────────────┤
│ Formalism: C = L ∪ K, K = {(id, {c₁, c₂, ..., cₙ}) | cᵢ ∈ C}  │
├───────────────────────────────────────────────────────────────┤
│ Variants:                                                     │
│   • Transparent - all methods declared on Component           │
│   • Safe        - management methods only on Composite        │
├───────────────────────────────────────────────────────────────┤
│ Typical applications:                                         │
│   • File systems                                              │
│   • Organization charts                                       │
│   • Graphics editors                                          │
│   • Menu systems                                              │
│   • Configuration management                                  │
├───────────────────────────────────────────────────────────────┤
│ Python features:                                              │
│   • ABC abstract base classes                                 │
│   • Protocol structural typing                                │
│   • Generic type parameters                                   │
│   • dataclass data classes                                    │
├───────────────────────────────────────────────────────────────┤
│ Best practices:                                               │
│   • Minimal interface                                         │
│   • Defensive copying                                         │
│   • Avoid circular references                                 │
│   • Iterator support                                          │
└───────────────────────────────────────────────────────────────┘

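8.11 Summary

The Composite pattern is a natural fit for tree structures. By giving individual objects and compositions a single interface, it simplifies client code. Its core benefits are:

  1. Uniform abstraction: clients need not distinguish leaves from composites
  2. Recursive structure: hierarchical data is supported naturally
  3. Open/Closed Principle: new component types are easy to add
  4. Simpler clients: fewer conditional branches on node type

In Python, abstract base classes, Protocol, and Generic can be used to build type-safe composite structures, while the iterator and context-manager protocols add richer behavior such as traversal and resource management.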


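Exercises and Discussion

  1. Basic exercise: Implement a composite that supports undo by recording the full history of modifications.

  2. Intermediate exercise: Design a composite that supports parallel computation, using multiple threads or processes to speed up evaluation of composite nodes.

  3. Challenge exercise: Implement a composite that supports incremental updates, so that when a leaf changes, only the composite nodes it affects are recomputed.

  4. Design question: How do the Composite and Visitor patterns work together? Design a composite structure that supports several different visitor operations.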

Python Technology Book Series - Sucheng Secondary Specialized School, Jiangsu Province