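Chapter 8 Composite Pattern
Learning Objectives
After completing this chapter, readers will be able to:
- Understand the core concepts of the Composite pattern and its formal definition
- Master organizing objects into tree structures and the principle of recursive composition
- Implement the two variants of the pattern: transparent and safe
- Use modern Python features such as Protocol and Generic to build type-safe composite structures
- Recognize the situations where the Composite pattern applies, as well as its anti-patterns
- Design enterprise-grade solutions based on composite structures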
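8.1 Pattern Definition
8.1.1 Core Definition
The Composite Pattern composes objects into tree structures to represent part-whole hierarchies. It lets clients treat individual objects (Leaf) and compositions of objects (Composite) uniformly.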
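To make "uniform treatment" concrete, here is a minimal sketch. The names `Product`, `Box`, `price()` and `checkout()` are hypothetical and are not used by this chapter's later code. A product is a leaf and a box is a composite, and the client prices either one with the same call.

```python
from dataclasses import dataclass, field
from typing import List, Union


@dataclass
class Product:
    """Leaf: a single item with its own price (hypothetical example)."""
    name: str
    unit_price: float

    def price(self) -> float:
        return self.unit_price


@dataclass
class Box:
    """Composite: packaging cost plus the price of everything inside."""
    name: str
    packaging_cost: float = 0.0
    contents: List[Union[Product, "Box"]] = field(default_factory=list)

    def price(self) -> float:
        # Recursively delegate to the contents, whether products or nested boxes
        return self.packaging_cost + sum(item.price() for item in self.contents)


def checkout(item: Union[Product, Box]) -> float:
    # The client never checks whether it received a leaf or a composite
    return item.price()


order = Box("order", 1.0, [Product("pen", 2.5), Box("gift", 0.5, [Product("mug", 8.0)])])
print(checkout(order))                # 12.0
print(checkout(Product("pen", 2.5)))  # 2.5
```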
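8.1.2 Formal Definition
Let $\mathcal{C}$ be the set of components, $\mathcal{L}$ the set of leaves, and $\mathcal{K}$ the set of composites. The Composite pattern can then be formalized as:
$$\mathcal{C} = \mathcal{L} \cup \mathcal{K}, \qquad \mathcal{L} \cap \mathcal{K} = \emptyset$$
where the set of composites is defined recursively:
$$\mathcal{K} = \{\, k \mid k = (id, \{c_1, c_2, \ldots, c_n\}),\ c_i \in \mathcal{C} \,\}$$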
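Tree constraint: the structure built by the pattern must satisfy the acyclicity constraint of a tree, and every component has at most one parent:
$$\forall c \in \mathcal{C}: \operatorname{ancestors}(c) \cap \operatorname{descendants}(c) = \emptyset$$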
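One way to enforce this constraint at run time is sketched below. It assumes each node keeps a `parent` reference; the `Node` class and its methods are hypothetical. Before attaching a child, `add()` walks up the parent chain and rejects the child if it is the node itself or one of its ancestors. It also detaches the child from any previous parent, so that every node keeps at most one parent. The check costs $O(d)$, where $d$ is the depth of the insertion point.

```python
from typing import Iterator, List, Optional


class Node:
    """Hypothetical composite node that keeps a parent pointer to enforce the tree constraint."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.parent: Optional["Node"] = None
        self.children: List["Node"] = []

    def ancestors(self) -> Iterator["Node"]:
        # Follow parent pointers up to the root
        node = self.parent
        while node is not None:
            yield node
            node = node.parent

    def add(self, child: "Node") -> None:
        # Attaching the node itself or one of its ancestors would create a cycle
        if child is self or any(a is child for a in self.ancestors()):
            raise ValueError(f"adding {child.name!r} under {self.name!r} would create a cycle")
        # Single-parent rule: detach the child from its previous parent first
        if child.parent is not None:
            child.parent.children.remove(child)
        child.parent = self
        self.children.append(child)


root, a, b = Node("root"), Node("a"), Node("b")
root.add(a)
a.add(b)
try:
    b.add(root)  # root is an ancestor of b
except ValueError as err:
    print(err)  # adding 'root' under 'b' would create a cycle
```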
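Uniformity principle: for any operation $op$, leaves and composites behave as follows:
$$op(l) = f_l(l), \quad l \in \mathcal{L}$$
$$op(k) = \bigoplus_{c \in \operatorname{children}(k)} op(c), \quad k \in \mathcal{K}$$
where $\bigoplus$ is an aggregation operator (e.g., summation, concatenation, logical AND).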
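The recursive definition translates almost directly into code. The sketch below assumes a hypothetical encoding in which a leaf is a plain integer and a composite is an `(id, children)` tuple. `op` applies $f_l$ to leaves and folds the child results with the supplied $\bigoplus$. An empty composite returns whatever the aggregator yields for an empty list, so $\bigoplus$ should have an identity element (0 for `sum`, or an explicit `default` for `max`).

```python
from typing import Callable, List, Tuple, Union

# Assumed encoding of the formal definition: a leaf is an int,
# a composite k = (id, [c1, ..., cn]) is a tuple of an id and a child list.
Tree = Union[int, Tuple[str, List["Tree"]]]


def op(c: Tree, f_leaf: Callable[[int], int], combine: Callable[[List[int]], int]) -> int:
    if isinstance(c, tuple):  # c is a composite: aggregate child results with ⊕
        _id, children = c
        return combine([op(child, f_leaf, combine) for child in children])
    return f_leaf(c)          # c is a leaf: apply f_l directly


tree: Tree = ("root", [3, ("k1", [4, 5]), ("k2", [])])
total = op(tree, lambda v: v, sum)                               # ⊕ = +
largest = op(tree, lambda v: v, lambda xs: max(xs, default=0))  # ⊕ = max
count = op(tree, lambda v: 1, sum)                               # number of leaves
print(total, largest, count)  # 12 5 3
```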
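Complexity analysis ($n$ = total number of nodes, $k$ = number of direct children of the composite being modified):
- Add: $O(1)$ amortized for a plain list append; implementations that first check `component not in children` to skip duplicates pay $O(k)$
- Remove: $O(k)$ (linear search in the child list)
- Traversal: $O(n)$
- Space: $O(n)$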
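Removal is linear because it scans a Python list. If siblings can be identified by a unique key, a dict gives average $O(1)$ add and remove while still preserving insertion order (guaranteed since Python 3.7). Below is a minimal sketch that assumes child names are unique among siblings. `IndexedComposite`, `Leaf`, `ChildNode` and `total_size()` are illustrative names.

```python
from typing import Dict, Iterator, Protocol


class ChildNode(Protocol):
    """Anything with a name and a total_size() can be a child."""
    name: str

    def total_size(self) -> int: ...


class Leaf:
    def __init__(self, name: str, size: int) -> None:
        self.name = name
        self.size = size

    def total_size(self) -> int:
        return self.size


class IndexedComposite:
    """Hypothetical composite whose children live in a dict keyed by name."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._children: Dict[str, ChildNode] = {}

    def add(self, child: ChildNode) -> None:
        # Average O(1); a child with an existing name replaces the old one
        self._children[child.name] = child

    def remove(self, name: str) -> None:
        # Average O(1): no list scan; unknown names are ignored
        self._children.pop(name, None)

    def __iter__(self) -> Iterator[ChildNode]:
        # dict preserves insertion order, so iteration follows add() order
        return iter(self._children.values())

    def total_size(self) -> int:
        # Traversal is still O(n) over the whole subtree
        return sum(child.total_size() for child in self)


docs = IndexedComposite("docs")
docs.add(Leaf("a.txt", 10))
docs.add(Leaf("b.txt", 5))
docs.remove("a.txt")
print(docs.total_size())  # 5
```

The trade-off is that siblings need unique keys and removal happens by key rather than by object identity.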
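8.1.3 Historical Background and Academic Lineage
The Composite pattern grew out of design practice in graphical editors. Its development is summarized below:
| Year | Milestone | Contributors |
|---|---|---|
| 1987 | First applied in the ET++ framework | Weinand, Gamma |
| 1991 | Pattern concept formally proposed | Gamma, Helm, Johnson, Vlissides |
| 1994 | Included in the GoF book *Design Patterns* | Gang of Four |
| 1995 | Discussion of the transparent vs. safe classification | Gamma et al. |
| 2002 | Studied in combination with the Visitor pattern | Gamma, Beck |
| 2010 | Research on functional composite patterns | Martin et al. |
Academic significance: the Composite pattern is a classic application of recursive data structures in object-oriented design. It embodies the principle of "favor composition over inheritance" and provides a standard approach to handling tree structures uniformly.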
8.2 Structure and Participants
8.2.1 UML Class Diagram
```text
                ┌────────────────────────────────┐
                │         <<interface>>          │
                │           Component            │
                ├────────────────────────────────┤
                │ + operation(): Result          │
                │ + add(c: Component): void      │
                │ + remove(c: Component): void   │
                │ + get_child(i: int): Component │
                │ + is_composite(): bool         │
                └───────────────┬────────────────┘
                                │
                ┌───────────────┴───────────────┐
                │                               │
   ┌────────────┴────────────┐    ┌─────────────┴──────────────┐
   │          Leaf           │    │         Composite          │
   ├─────────────────────────┤    ├────────────────────────────┤
   │ - name: str             │    │ - children: List[Component]│
   │ - value: Any            │    │ - name: str                │
   ├─────────────────────────┤    ├────────────────────────────┤
   │ + operation(): Result   │    │ + operation(): Result      │
   │ + is_composite(): bool  │    │ + add(c: Component)        │
   │                         │    │ + remove(c: Component)     │
   │                         │    │ + get_child(i): Component  │
   │                         │    │ + is_composite(): bool     │
   └─────────────────────────┘    └────────────────────────────┘
```
8.2.2 Participant Responsibilities
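| Participant | Responsibility | Key methods |
|---|---|---|
| Component | Declares the common interface and defines default behavior | operation(), add(), remove() |
| Leaf | Represents a leaf node, which has no children | Implements operation() |
| Composite | Represents a composite node and stores its child components | Manages children; forwards calls to them recursively |
| Client | Manipulates objects through the Component interface | Treats all components uniformly |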
8.2.3 Collaborations
```text
Client ─────> Component <─────────────────┐
                  │                       │
                  ▼                       │
           ┌─────────────┐                │
           │ operation() │                │
           └──────┬──────┘                │
                  │                       │
          ┌───────┴────────┐              │
          ▼                ▼              │
      ┌───────┐      ┌───────────┐        │
      │ Leaf  │      │ Composite │────────┘
      └───────┘      └─────┬─────┘
                           │
                  ┌────────┼────────┐
                  ▼        ▼        ▼
               Child1   Child2   Child3
```
8.3 Python Implementation
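As a warm-up for the implementations in this section, the delegation flow in the collaboration diagram can be traced with a small instrumented sketch: the client calls the top-level component, and each composite forwards the call to its children. The class names mirror the diagram. The `trace` list and `depth` parameter are illustrative additions and are not part of the pattern.

```python
from typing import Iterable, List


class Component:
    def __init__(self, name: str) -> None:
        self.name = name

    def operation(self, trace: List[str], depth: int = 0) -> None:
        raise NotImplementedError


class Leaf(Component):
    def operation(self, trace: List[str], depth: int = 0) -> None:
        # A leaf does the work itself; the recursion stops here
        trace.append("  " * depth + f"Leaf {self.name}: handles the request")


class Composite(Component):
    def __init__(self, name: str, children: Iterable[Component] = ()) -> None:
        super().__init__(name)
        self.children = list(children)

    def operation(self, trace: List[str], depth: int = 0) -> None:
        # A composite records the call, then forwards it to each child in order
        trace.append("  " * depth + f"Composite {self.name}: forwards to {len(self.children)} children")
        for child in self.children:
            child.operation(trace, depth + 1)


trace: List[str] = []
root = Composite("root", [Leaf("a"), Composite("group", [Leaf("b"), Leaf("c")])])
root.operation(trace)  # the client only ever talks to the top-level Component
print("\n".join(trace))
```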
8.3.1 Standard Implementation (ABC Abstract Base Class)
```python
from abc import ABC, abstractmethod
from typing import Any, Iterator, List


class Component(ABC):
    """Abstract component: declares the interface shared by leaves and composites."""

    @abstractmethod
    def operation(self) -> str:
        """Core operation."""

    def add(self, component: "Component") -> None:
        """Add a child. Default: leaves do not support children, so raise."""
        raise NotImplementedError("Leaf cannot add children")

    def remove(self, component: "Component") -> None:
        """Remove a child. Default: leaves do not support children, so raise."""
        raise NotImplementedError("Leaf cannot remove children")

    def get_child(self, index: int) -> "Component":
        """Return a child by index. Default: leaves have no children, so raise."""
        raise NotImplementedError("Leaf has no children")

    def is_composite(self) -> bool:
        """Whether this node is a composite."""
        return False

    def __iter__(self) -> Iterator["Component"]:
        """Iteration support: a leaf yields no children."""
        return iter([])


class Leaf(Component):
    """Leaf: a terminal node with no children."""

    def __init__(self, name: str, value: Any = None) -> None:
        self._name = name
        self._value = value

    @property
    def name(self) -> str:
        return self._name

    @property
    def value(self) -> Any:
        return self._value

    def operation(self) -> str:
        return f"Leaf({self._name}: {self._value})"


class Composite(Component):
    """Composite: a container node that can hold children."""

    def __init__(self, name: str) -> None:
        self._name = name
        self._children: List[Component] = []

    @property
    def name(self) -> str:
        return self._name

    def operation(self) -> str:
        # Delegate to every child and combine the results
        results = [child.operation() for child in self._children]
        return f"Composite({self._name}): [{', '.join(results)}]"

    def add(self, component: Component) -> None:
        if component not in self._children:  # ignore duplicates
            self._children.append(component)

    def remove(self, component: Component) -> None:
        if component in self._children:
            self._children.remove(component)

    def get_child(self, index: int) -> Component:
        return self._children[index]

    def is_composite(self) -> bool:
        return True

    def __iter__(self) -> Iterator[Component]:
        return iter(self._children)

    def __len__(self) -> int:
        return len(self._children)

    def clear(self) -> None:
        self._children.clear()


def client_code(component: Component) -> None:
    """Client code: handles every component through the same interface."""
    print(f"RESULT: {component.operation()}")
    if component.is_composite():
        print(f"  Children count: {len(list(component))}")


if __name__ == "__main__":
    leaf = Leaf("Simple", 42)
    client_code(leaf)

    root = Composite("Root")
    branch1 = Composite("Branch1")
    branch1.add(Leaf("Leaf1", 1))
    branch1.add(Leaf("Leaf2", 2))
    branch2 = Composite("Branch2")
    branch2.add(Leaf("Leaf3", 3))
    root.add(branch1)
    root.add(branch2)
    root.add(Leaf("Leaf4", 4))
    client_code(root)
```
8.3.2 Protocol Implementation (Structural Typing)
```python
from dataclasses import dataclass
from typing import Iterator, List, Protocol, runtime_checkable


@runtime_checkable
class ComponentProtocol(Protocol):
    """Component protocol: a structural type definition."""

    def operation(self) -> str: ...

    def is_composite(self) -> bool: ...


@dataclass
class LeafNode:
    """Leaf node implemented as a dataclass."""
    name: str
    value: int = 0

    def operation(self) -> str:
        return f"Leaf({self.name}: {self.value})"

    def is_composite(self) -> bool:
        return False


class CompositeNode:
    """Composite node: satisfies the protocol without inheriting from it."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._children: List[ComponentProtocol] = []

    def operation(self) -> str:
        results = [child.operation() for child in self._children]
        return f"Composite({self.name}): [{', '.join(results)}]"

    def is_composite(self) -> bool:
        return True

    def add(self, component: ComponentProtocol) -> None:
        self._children.append(component)

    def remove(self, component: ComponentProtocol) -> None:
        self._children.remove(component)

    def __iter__(self) -> Iterator[ComponentProtocol]:
        return iter(self._children)


def process_component(comp: ComponentProtocol) -> None:
    """Accepts anything that structurally matches ComponentProtocol."""
    print(f"Processing: {comp.operation()}")
    print(f"Is Composite: {comp.is_composite()}")


if __name__ == "__main__":
    leaf = LeafNode("Data", 100)
    composite = CompositeNode("Container")
    composite.add(LeafNode("Item1", 1))
    composite.add(LeafNode("Item2", 2))
    process_component(leaf)
    process_component(composite)
    print("\nType checking:")
    # isinstance() with a runtime_checkable Protocol only checks that the
    # methods exist, not their signatures
    print(f"leaf is ComponentProtocol: {isinstance(leaf, ComponentProtocol)}")
    print(f"composite is ComponentProtocol: {isinstance(composite, ComponentProtocol)}")
```
8.3.3 Generic Implementation (Type-Safe)
```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Generic, List, TypeVar

T = TypeVar("T")  # type of the data stored in leaves
R = TypeVar("R")  # type of the operation's result


class GenericComponent(ABC, Generic[T, R]):
    """Generic component interface."""

    @abstractmethod
    def operation(self) -> R: ...

    @abstractmethod
    def get_data(self) -> T: ...


@dataclass
class GenericLeaf(GenericComponent[T, R]):
    """Generic leaf: maps its data to a result with `transformer`."""
    name: str
    data: T
    transformer: Callable[[T], R]

    def operation(self) -> R:
        return self.transformer(self.data)

    def get_data(self) -> T:
        return self.data


class GenericComposite(GenericComponent[T, R]):
    """Generic composite: combines child results with `aggregator`."""

    def __init__(
        self,
        name: str,
        aggregator: Callable[[List[R]], R]
    ) -> None:
        self._name = name
        self._children: List[GenericComponent[T, R]] = []
        self._aggregator = aggregator

    def operation(self) -> R:
        results = [child.operation() for child in self._children]
        return self._aggregator(results)

    def get_data(self) -> T:
        raise NotImplementedError("Composite has no single data")

    def add(self, component: GenericComponent[T, R]) -> None:
        self._children.append(component)

    def remove(self, component: GenericComponent[T, R]) -> None:
        self._children.remove(component)


if __name__ == "__main__":
    def sum_aggregator(results: List[int]) -> int:
        return sum(results)

    def str_aggregator(results: List[str]) -> str:
        return " | ".join(results)

    int_leaf1 = GenericLeaf("A", 10, lambda x: x * 2)
    int_leaf2 = GenericLeaf("B", 20, lambda x: x * 2)
    int_composite = GenericComposite("Numbers", sum_aggregator)
    int_composite.add(int_leaf1)
    int_composite.add(int_leaf2)
    print(f"Integer composite: {int_composite.operation()}")  # 60

    str_leaf1 = GenericLeaf("X", "hello", lambda x: x.upper())
    str_leaf2 = GenericLeaf("Y", "world", lambda x: x.upper())
    str_composite = GenericComposite("Strings", str_aggregator)
    str_composite.add(str_leaf1)
    str_composite.add(str_leaf2)
    print(f"String composite: {str_composite.operation()}")  # HELLO | WORLD
```
8.3.4 Transparent Implementation
```python
from abc import ABC, abstractmethod
from typing import List, Optional


class TransparentComponent(ABC):
    """Transparent component: every method, including child management, is declared here."""

    @abstractmethod
    def operation(self) -> str: ...

    def add(self, component: "TransparentComponent") -> None:
        """Default: silently ignore (leaves keep this no-op)."""

    def remove(self, component: "TransparentComponent") -> None:
        """Default: silently ignore (leaves keep this no-op)."""

    def get_child(self, index: int) -> Optional["TransparentComponent"]:
        return None

    def is_composite(self) -> bool:
        return False


class TransparentLeaf(TransparentComponent):
    """Transparent leaf: inherits the no-op add/remove."""

    def __init__(self, name: str) -> None:
        self._name = name

    def operation(self) -> str:
        return f"Leaf: {self._name}"


class TransparentComposite(TransparentComponent):
    """Transparent composite."""

    def __init__(self, name: str) -> None:
        self._name = name
        self._children: List[TransparentComponent] = []

    def operation(self) -> str:
        results = [child.operation() for child in self._children]
        return f"Branch({self._name}): [{', '.join(results)}]"

    def add(self, component: TransparentComponent) -> None:
        self._children.append(component)

    def remove(self, component: TransparentComponent) -> None:
        self._children.remove(component)

    def get_child(self, index: int) -> Optional[TransparentComponent]:
        if 0 <= index < len(self._children):
            return self._children[index]
        return None

    def is_composite(self) -> bool:
        return True


def transparent_client(component: TransparentComponent) -> None:
    """The client need not distinguish leaves from composites."""
    print(component.operation())
    # On a leaf this call is a silent no-op, so misuse goes unnoticed
    component.add(TransparentLeaf("New"))
    if component.is_composite():
        print("  Added child successfully")
```
8.3.5 Safe Implementation
```python
from abc import ABC, abstractmethod
from typing import List


class SafeComponent(ABC):
    """Safe component: declares only the operations common to all nodes."""

    @abstractmethod
    def operation(self) -> str: ...


class SafeLeaf(SafeComponent):
    """Safe leaf: has no add/remove methods at all."""

    def __init__(self, name: str) -> None:
        self._name = name

    def operation(self) -> str:
        return f"Leaf: {self._name}"


class SafeComposite(SafeComponent):
    """Safe composite: child-management methods exist only here."""

    def __init__(self, name: str) -> None:
        self._name = name
        self._children: List[SafeComponent] = []

    def operation(self) -> str:
        results = [child.operation() for child in self._children]
        return f"Branch({self._name}): [{', '.join(results)}]"

    def add(self, component: SafeComponent) -> None:
        self._children.append(component)

    def remove(self, component: SafeComponent) -> None:
        self._children.remove(component)

    def get_children(self) -> List[SafeComponent]:
        return self._children.copy()  # defensive copy


def safe_client(component: SafeComponent) -> None:
    """The client can call only the common interface."""
    print(component.operation())


def safe_composite_client(composite: SafeComposite) -> None:
    """Composite-specific client: may also use the management methods."""
    composite.add(SafeLeaf("New"))
    print(composite.operation())
```
8.4 Transparent vs. Safe: A Comparison
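8.4.1 Comparative Analysis
| Aspect | Transparent | Safe |
|---|---|---|
| Interface design | All methods declared on the base class | Only common methods on the base class |
| Type safety | Misuse may surface only at run time | Misuse caught by static type checking (e.g., mypy) |
| Client complexity | Low: uniform handling | Higher: must distinguish types |
| Extensibility | High: new components are easy to add | Medium: the interface must be changed |
| Error handling | Runtime exceptions (or silent no-ops) | Static checks before the program runs |
| Suitable for | Stable structures with simple clients | High type-safety requirements |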
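The "type safety" and "error handling" rows can be compared side by side in the sketch below (all names are hypothetical). In the transparent hierarchy `add()` exists on the base class, so `attach_transparent(TLeaf(), ...)` passes a static type checker and fails only when executed. In the safe hierarchy `SComponent` has no `add()`, so the client must narrow with `isinstance` first, and a call such as `SLeaf().add(...)` would be reported by mypy before the program runs.

```python
from abc import ABC, abstractmethod
from typing import List


# --- Transparent: add() is declared on the shared base class ---
class TComponent(ABC):
    @abstractmethod
    def render(self) -> str: ...

    def add(self, child: "TComponent") -> None:
        raise TypeError(f"{type(self).__name__} cannot hold children")


class TLeaf(TComponent):
    def render(self) -> str:
        return "leaf"


class TGroup(TComponent):
    def __init__(self) -> None:
        self.children: List[TComponent] = []

    def add(self, child: TComponent) -> None:
        self.children.append(child)

    def render(self) -> str:
        return "group(" + ", ".join(c.render() for c in self.children) + ")"


# --- Safe: only render() is shared; add() exists on SGroup alone ---
class SComponent(ABC):
    @abstractmethod
    def render(self) -> str: ...


class SLeaf(SComponent):
    def render(self) -> str:
        return "leaf"


class SGroup(SComponent):
    def __init__(self) -> None:
        self.children: List[SComponent] = []

    def add(self, child: SComponent) -> None:
        self.children.append(child)

    def render(self) -> str:
        return "group(" + ", ".join(c.render() for c in self.children) + ")"


def attach_transparent(parent: TComponent, child: TComponent) -> None:
    # Accepted by a type checker for any TComponent; a leaf parent fails only at run time
    parent.add(child)


def attach_safe(parent: SComponent, child: SComponent) -> None:
    # Calling parent.add(child) directly would be a static error: SComponent has no add()
    if isinstance(parent, SGroup):  # narrowing makes add() visible to the checker
        parent.add(child)
    else:
        raise TypeError(f"{type(parent).__name__} cannot hold children")


try:
    attach_transparent(TLeaf(), TLeaf())
except TypeError as err:
    print(err)  # TLeaf cannot hold children
```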
8.4.2 选择决策树
需要类型安全?
/ \
是 否
/ \
安全式 客户端需要统一处理?
/ \
是 否
/ \
透明式 安全式8.5 实际应用示例
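The three examples that follow all walk their trees recursively. For very deep trees, CPython's default recursion limit (1000, see `sys.getrecursionlimit()`) can become a problem, and an explicit stack or queue avoids it. Below is a minimal sketch of generic pre-order depth-first and breadth-first walkers. It assumes the caller supplies a `children` function; `walk_dfs` and `walk_bfs` are hypothetical helpers.

```python
from collections import deque
from typing import Callable, Deque, Dict, Iterable, Iterator, List, Tuple, TypeVar

N = TypeVar("N")


def walk_dfs(root: N, children: Callable[[N], Iterable[N]]) -> Iterator[Tuple[int, N]]:
    """Pre-order depth-first walk yielding (depth, node), using an explicit stack."""
    stack: List[Tuple[int, N]] = [(0, root)]
    while stack:
        depth, node = stack.pop()
        yield depth, node
        # Push children in reverse so they are popped in their original order
        for child in reversed(list(children(node))):
            stack.append((depth + 1, child))


def walk_bfs(root: N, children: Callable[[N], Iterable[N]]) -> Iterator[Tuple[int, N]]:
    """Level-order (breadth-first) walk yielding (depth, node), using a queue."""
    queue: Deque[Tuple[int, N]] = deque([(0, root)])
    while queue:
        depth, node = queue.popleft()
        yield depth, node
        for child in children(node):
            queue.append((depth + 1, child))


tree: Dict[str, List[str]] = {"project": ["src", "docs"], "src": ["main.py"], "docs": [], "main.py": []}


def kids(name: str) -> List[str]:
    return tree.get(name, [])


print([n for _, n in walk_dfs("project", kids)])  # ['project', 'src', 'main.py', 'docs']
print([n for _, n in walk_bfs("project", kids)])  # ['project', 'src', 'docs', 'main.py']
```

For the file-system example below, a hypothetical adapter such as `lambda n: n.get_children() if isinstance(n, Directory) else []` would plug `Directory` into either walker.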
8.5.1 File System Implementation
```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import List, Optional


class FileType(Enum):
    FILE = "file"
    DIRECTORY = "directory"


@dataclass
class FileMetadata:
    name: str
    size: int  # size in bytes
    created_at: datetime
    modified_at: datetime
    permissions: str = "rw-r--r--"


class FileSystemNode(ABC):
    """Abstract file-system node (the Component)."""

    @abstractmethod
    def get_name(self) -> str: ...

    @abstractmethod
    def get_size(self) -> int: ...

    @abstractmethod
    def get_type(self) -> FileType: ...

    @abstractmethod
    def list_contents(self, indent: int = 0) -> str: ...

    def get_path(self) -> str:
        return self.get_name()


class File(FileSystemNode):
    """File node (a leaf)."""

    def __init__(
        self,
        name: str,
        content: str = "",
        permissions: str = "rw-r--r--"
    ) -> None:
        now = datetime.now()
        self._metadata = FileMetadata(
            name=name,
            size=len(content.encode("utf-8")),  # bytes, not characters
            created_at=now,
            modified_at=now,
            permissions=permissions
        )
        self._content = content

    def get_name(self) -> str:
        return self._metadata.name

    def get_size(self) -> int:
        return self._metadata.size

    def get_type(self) -> FileType:
        return FileType.FILE

    def get_metadata(self) -> FileMetadata:
        return self._metadata

    def read(self) -> str:
        return self._content

    def write(self, content: str) -> None:
        self._content = content
        self._metadata.size = len(content.encode("utf-8"))
        self._metadata.modified_at = datetime.now()

    def append(self, content: str) -> None:
        self._content += content
        self._metadata.size = len(self._content.encode("utf-8"))
        self._metadata.modified_at = datetime.now()

    def list_contents(self, indent: int = 0) -> str:
        prefix = "  " * indent
        size_kb = self._metadata.size / 1024
        return f"{prefix}📄 {self._metadata.name} ({size_kb:.1f} KB)"


class Directory(FileSystemNode):
    """Directory node (a composite)."""

    def __init__(self, name: str, permissions: str = "rwxr-xr-x") -> None:
        self._name = name
        self._permissions = permissions
        self._children: List[FileSystemNode] = []
        self._created_at = datetime.now()
        self._modified_at = datetime.now()

    def get_name(self) -> str:
        return self._name

    def get_size(self) -> int:
        # A directory's size is the recursive sum of its children's sizes
        return sum(child.get_size() for child in self._children)

    def get_type(self) -> FileType:
        return FileType.DIRECTORY

    def add(self, node: FileSystemNode) -> None:
        if node not in self._children:
            self._children.append(node)
            self._modified_at = datetime.now()

    def remove(self, node: FileSystemNode) -> None:
        if node in self._children:
            self._children.remove(node)
            self._modified_at = datetime.now()

    def get_children(self) -> List[FileSystemNode]:
        return self._children.copy()

    def find(self, name: str) -> Optional[FileSystemNode]:
        """Depth-first search by name; returns the first match or None."""
        for child in self._children:
            if child.get_name() == name:
                return child
            if isinstance(child, Directory):
                result = child.find(name)
                if result is not None:
                    return result
        return None

    def list_contents(self, indent: int = 0) -> str:
        prefix = "  " * indent
        lines = [f"{prefix}📁 {self._name}/"]
        for child in sorted(self._children, key=lambda x: x.get_name()):
            lines.append(child.list_contents(indent + 1))
        return "\n".join(lines)

    def tree(self, max_depth: int = -1, current_depth: int = 0, prefix: str = "") -> str:
        """Render a tree view; a negative max_depth means no depth limit."""
        if max_depth >= 0 and current_depth > max_depth:
            return ""
        lines = []
        children = sorted(self._children, key=lambda x: x.get_name())
        for i, child in enumerate(children):
            is_last = i == len(children) - 1
            connector = "└── " if is_last else "├── "
            if isinstance(child, Directory):
                lines.append(f"{prefix}{connector}{child.get_name()}/")
                # Continue the vertical guide only if more siblings follow
                extension = "    " if is_last else "│   "
                lines.append(child.tree(max_depth, current_depth + 1, prefix + extension))
            else:
                lines.append(f"{prefix}{connector}{child.get_name()}")
        return "\n".join(filter(None, lines))

    def get_file_count(self) -> int:
        count = 0
        for child in self._children:
            if isinstance(child, Directory):
                count += child.get_file_count()
            else:
                count += 1
        return count

    def get_directory_count(self) -> int:
        count = 0
        for child in self._children:
            if isinstance(child, Directory):
                count += 1 + child.get_directory_count()
        return count


def create_sample_filesystem() -> Directory:
    root = Directory("project")

    src = Directory("src")
    src.add(File("main.py", "print('Hello')"))
    src.add(File("utils.py", "def helper(): pass"))
    models = Directory("models")
    models.add(File("user.py", "class User: pass"))
    models.add(File("product.py", "class Product: pass"))
    src.add(models)

    tests = Directory("tests")
    tests.add(File("test_main.py", "def test_main(): pass"))
    tests.add(File("test_utils.py", "def test_utils(): pass"))

    docs = Directory("docs")
    docs.add(File("README.md", "# Project Documentation"))
    docs.add(File("API.md", "## API Reference"))

    root.add(src)
    root.add(tests)
    root.add(docs)
    root.add(File("requirements.txt", "pytest\nrequests"))
    root.add(File("setup.py", "# setup configuration"))
    return root


if __name__ == "__main__":
    fs = create_sample_filesystem()
    print("=== File System Structure ===")
    print(fs.list_contents())

    print("\n=== Tree View ===")
    print("project/")
    print(fs.tree())

    print("\n=== Statistics ===")
    print(f"Total size: {fs.get_size() / 1024:.2f} KB")
    print(f"File count: {fs.get_file_count()}")
    print(f"Directory count: {fs.get_directory_count()}")

    print("\n=== Find File ===")
    found = fs.find("main.py")
    if found is not None:
        print(f"Found: {found.get_name()}, size: {found.get_size()} bytes")
```
8.5.2 Organizational Structure System
```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import date
from enum import Enum
from typing import Any, Dict, List, Optional


class DepartmentType(Enum):
    HEADQUARTERS = "Headquarters"
    DIVISION = "Division"
    DEPARTMENT = "Department"
    TEAM = "Team"


@dataclass
class EmployeeInfo:
    """Employee information."""
    id: str
    name: str
    position: str
    salary: float
    hire_date: date
    skills: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        return {
            "id": self.id,
            "name": self.name,
            "position": self.position,
            "salary": self.salary,
            "hire_date": self.hire_date.isoformat(),
            "skills": self.skills
        }


class OrganizationComponent(ABC):
    """Organization component (the Component)."""

    @abstractmethod
    def get_name(self) -> str: ...

    @abstractmethod
    def get_budget(self) -> float: ...

    @abstractmethod
    def get_employee_count(self) -> int: ...

    @abstractmethod
    def get_structure(self, indent: int = 0) -> str: ...

    @abstractmethod
    def to_dict(self) -> Dict[str, Any]: ...

    def get_average_salary(self) -> float:
        # For a single employee the budget is exactly the salary
        count = self.get_employee_count()
        return self.get_budget() / count if count > 0 else 0


class Employee(OrganizationComponent):
    """Employee node (a leaf)."""

    def __init__(self, info: EmployeeInfo) -> None:
        self._info = info

    @property
    def info(self) -> EmployeeInfo:
        return self._info

    def get_name(self) -> str:
        return f"{self._info.name} ({self._info.position})"

    def get_budget(self) -> float:
        return self._info.salary

    def get_employee_count(self) -> int:
        return 1

    def get_structure(self, indent: int = 0) -> str:
        prefix = "  " * indent
        return f"{prefix}👤 {self._info.name} - {self._info.position} (¥{self._info.salary:,.0f}/month)"

    def to_dict(self) -> Dict[str, Any]:
        return self._info.to_dict()


class Department(OrganizationComponent):
    """Department node (a composite)."""

    def __init__(
        self,
        name: str,
        dept_type: DepartmentType,
        manager: Optional[str] = None,
        budget_allocation: float = 0
    ) -> None:
        self._name = name
        self._dept_type = dept_type
        self._manager = manager
        self._budget_allocation = budget_allocation
        self._components: List[OrganizationComponent] = []

    @property
    def dept_type(self) -> DepartmentType:
        return self._dept_type

    @property
    def manager(self) -> Optional[str]:
        return self._manager

    def get_name(self) -> str:
        return self._name

    def get_budget(self) -> float:
        # Own allocation plus everything below this department
        return self._budget_allocation + sum(
            comp.get_budget() for comp in self._components
        )

    def get_employee_count(self) -> int:
        return sum(comp.get_employee_count() for comp in self._components)

    def get_average_salary(self) -> float:
        # The inherited version would divide the whole budget (including every
        # department's own allocation) by headcount; average the salaries only.
        return self.get_salary_statistics()["avg"]

    def add(self, component: OrganizationComponent) -> None:
        self._components.append(component)

    def remove(self, component: OrganizationComponent) -> None:
        self._components.remove(component)

    def get_components(self) -> List[OrganizationComponent]:
        return self._components.copy()

    def get_structure(self, indent: int = 0) -> str:
        prefix = "  " * indent
        lines = [f"{prefix}🏢 {self._name} [{self._dept_type.value}]"]
        if self._manager:
            lines[-1] += f" (Head: {self._manager})"
        for comp in self._components:
            lines.append(comp.get_structure(indent + 1))
        return "\n".join(lines)

    def to_dict(self) -> Dict[str, Any]:
        return {
            "name": self._name,
            "type": self._dept_type.value,
            "manager": self._manager,
            "budget": self.get_budget(),
            "employee_count": self.get_employee_count(),
            "children": [comp.to_dict() for comp in self._components]
        }

    def find_employee(self, name: str) -> Optional[Employee]:
        """Find an employee by name (depth-first)."""
        for comp in self._components:
            if isinstance(comp, Employee):
                if comp.info.name == name:
                    return comp
            elif isinstance(comp, Department):
                result = comp.find_employee(name)
                if result is not None:
                    return result
        return None

    def get_all_employees(self) -> List[Employee]:
        """Collect every employee in this subtree."""
        employees: List[Employee] = []
        for comp in self._components:
            if isinstance(comp, Employee):
                employees.append(comp)
            elif isinstance(comp, Department):
                employees.extend(comp.get_all_employees())
        return employees

    def get_salary_statistics(self) -> Dict[str, float]:
        """Salary statistics over all employees in this subtree."""
        employees = self.get_all_employees()
        if not employees:
            return {"min": 0, "max": 0, "avg": 0, "total": 0}
        salaries = [e.info.salary for e in employees]
        return {
            "min": min(salaries),
            "max": max(salaries),
            "avg": sum(salaries) / len(salaries),
            "total": sum(salaries)
        }

    def get_skill_matrix(self) -> Dict[str, List[str]]:
        """Map each skill to the names of employees who have it."""
        matrix: Dict[str, List[str]] = {}
        for emp in self.get_all_employees():
            for skill in emp.info.skills:
                matrix.setdefault(skill, []).append(emp.info.name)
        return matrix


def create_organization() -> Department:
    company = Department("Tech Innovation Co.", DepartmentType.HEADQUARTERS, "张总", 500000)

    tech_division = Department("Technology Division", DepartmentType.DIVISION, "李副总", 200000)

    backend_dept = Department("Backend Development", DepartmentType.DEPARTMENT, "王经理", 50000)
    backend_dept.add(Employee(EmployeeInfo(
        "E001", "赵工", "Senior Architect", 45000,
        date(2020, 3, 15), ["Python", "Go", "Kubernetes"]
    )))
    backend_dept.add(Employee(EmployeeInfo(
        "E002", "钱工", "Senior Engineer", 30000,
        date(2021, 6, 20), ["Python", "Django", "PostgreSQL"]
    )))
    backend_dept.add(Employee(EmployeeInfo(
        "E003", "孙工", "Engineer", 22000,
        date(2022, 1, 10), ["Python", "FastAPI"]
    )))

    frontend_dept = Department("Frontend Development", DepartmentType.DEPARTMENT, "周经理", 40000)
    frontend_dept.add(Employee(EmployeeInfo(
        "E004", "吴工", "Senior Frontend Engineer", 28000,
        date(2021, 4, 5), ["React", "TypeScript", "Node.js"]
    )))
    frontend_dept.add(Employee(EmployeeInfo(
        "E005", "郑工", "Frontend Engineer", 20000,
        date(2022, 8, 15), ["Vue", "JavaScript"]
    )))

    devops_team = Department("DevOps Team", DepartmentType.TEAM, "冯组长", 30000)
    devops_team.add(Employee(EmployeeInfo(
        "E006", "陈工", "DevOps Engineer", 25000,
        date(2021, 9, 1), ["Docker", "Kubernetes", "CI/CD"]
    )))

    tech_division.add(backend_dept)
    tech_division.add(frontend_dept)
    tech_division.add(devops_team)

    hr_division = Department("Human Resources", DepartmentType.DIVISION, "林经理", 80000)
    hr_division.add(Employee(EmployeeInfo(
        "E007", "黄专员", "HRBP", 15000,
        date(2022, 3, 20), ["Recruiting", "Training", "Performance Management"]
    )))
    hr_division.add(Employee(EmployeeInfo(
        "E008", "许专员", "Recruiter", 12000,
        date(2023, 1, 15), ["Recruiting", "Interviewing"]
    )))

    finance_division = Department("Finance", DepartmentType.DIVISION, "何经理", 60000)
    finance_division.add(Employee(EmployeeInfo(
        "E009", "沈会计", "Senior Accountant", 18000,
        date(2021, 5, 10), ["Financial Analysis", "Tax Planning"]
    )))
    finance_division.add(Employee(EmployeeInfo(
        "E010", "韩出纳", "Cashier", 10000,
        date(2023, 2, 1), ["Cash Management", "Bank Settlement"]
    )))

    company.add(tech_division)
    company.add(hr_division)
    company.add(finance_division)
    return company


if __name__ == "__main__":
    org = create_organization()
    print("=== Organization Chart ===")
    print(org.get_structure())

    print("\n=== Statistics ===")
    print(f"Headcount: {org.get_employee_count()}")
    print(f"Total budget: ¥{org.get_budget():,.0f}")
    print(f"Average salary: ¥{org.get_average_salary():,.0f}")

    print("\n=== Salary Statistics ===")
    stats = org.get_salary_statistics()
    print(f"Minimum salary: ¥{stats['min']:,.0f}")
    print(f"Maximum salary: ¥{stats['max']:,.0f}")
    print(f"Average salary: ¥{stats['avg']:,.0f}")

    print("\n=== Skill Matrix ===")
    skills = org.get_skill_matrix()
    for skill, employees in sorted(skills.items()):
        print(f"{skill}: {', '.join(employees)}")

    print("\n=== Find Employee ===")
    emp = org.find_employee("赵工")
    if emp is not None:
        print(f"Found: {emp.info.name}, position: {emp.info.position}")
```
8.5.3 Graphics Editor
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, Dict, Any
from dataclasses import dataclass, field
from enum import Enum
import json
class ShapeType(Enum):
CIRCLE = "circle"
RECTANGLE = "rectangle"
LINE = "line"
GROUP = "group"
@dataclass
class Point:
x: float
y: float
def translate(self, dx: float, dy: float) -> 'Point':
return Point(self.x + dx, self.y + dy)
def distance_to(self, other: 'Point') -> float:
return ((self.x - other.x) ** 2 + (self.y - other.y) ** 2) ** 0.5
@dataclass
class Color:
r: int
g: int
b: int
a: float = 1.0
def to_hex(self) -> str:
return f"#{self.r:02x}{self.g:02x}{self.b:02x}"
def to_rgba(self) -> str:
return f"rgba({self.r}, {self.g}, {self.b}, {self.a})"
@dataclass
class Style:
fill_color: Optional[Color] = None
stroke_color: Optional[Color] = None
stroke_width: float = 1.0
def to_dict(self) -> Dict[str, Any]:
return {
"fill": self.fill_color.to_hex() if self.fill_color else None,
"stroke": self.stroke_color.to_hex() if self.stroke_color else None,
"stroke_width": self.stroke_width
}
class Graphic(ABC):
"""图形组件抽象"""
@abstractmethod
def draw(self) -> str:
pass
@abstractmethod
def get_bounds(self) -> Tuple[Point, Point]:
pass
@abstractmethod
def move(self, dx: float, dy: float) -> None:
pass
@abstractmethod
def contains_point(self, point: Point) -> bool:
pass
@abstractmethod
def get_type(self) -> ShapeType:
pass
@abstractmethod
def to_dict(self) -> Dict[str, Any]:
pass
def get_center(self) -> Point:
min_pt, max_pt = self.get_bounds()
return Point(
(min_pt.x + max_pt.x) / 2,
(min_pt.y + max_pt.y) / 2
)
def get_width(self) -> float:
min_pt, max_pt = self.get_bounds()
return max_pt.x - min_pt.x
def get_height(self) -> float:
min_pt, max_pt = self.get_bounds()
return max_pt.y - min_pt.y
class Circle(Graphic):
"""圆形"""
def __init__(self, center: Point, radius: float, style: Style = None):
self._center = center
self._radius = radius
self._style = style or Style()
@property
def center(self) -> Point:
return self._center
@property
def radius(self) -> float:
return self._radius
def draw(self) -> str:
return f"Circle(center=({self._center.x}, {self._center.y}), r={self._radius})"
def get_bounds(self) -> Tuple[Point, Point]:
return (
Point(self._center.x - self._radius, self._center.y - self._radius),
Point(self._center.x + self._radius, self._center.y + self._radius)
)
def move(self, dx: float, dy: float) -> None:
self._center = self._center.translate(dx, dy)
def contains_point(self, point: Point) -> bool:
return self._center.distance_to(point) <= self._radius
def get_type(self) -> ShapeType:
return ShapeType.CIRCLE
def to_dict(self) -> Dict[str, Any]:
return {
"type": "circle",
"center": {"x": self._center.x, "y": self._center.y},
"radius": self._radius,
"style": self._style.to_dict()
}
def scale(self, factor: float) -> None:
self._radius *= factor
class Rectangle(Graphic):
"""矩形"""
def __init__(
self,
top_left: Point,
width: float,
height: float,
style: Style = None
):
self._top_left = top_left
self._width = width
self._height = height
self._style = style or Style()
@property
def top_left(self) -> Point:
return self._top_left
@property
def width(self) -> float:
return self._width
@property
def height(self) -> float:
return self._height
def draw(self) -> str:
return f"Rectangle(pos=({self._top_left.x}, {self._top_left.y}), size={self._width}x{self._height})"
def get_bounds(self) -> Tuple[Point, Point]:
return (
self._top_left,
Point(self._top_left.x + self._width, self._top_left.y + self._height)
)
def move(self, dx: float, dy: float) -> None:
self._top_left = self._top_left.translate(dx, dy)
def contains_point(self, point: Point) -> bool:
min_pt, max_pt = self.get_bounds()
return (min_pt.x <= point.x <= max_pt.x and
min_pt.y <= point.y <= max_pt.y)
def get_type(self) -> ShapeType:
return ShapeType.RECTANGLE
def to_dict(self) -> Dict[str, Any]:
return {
"type": "rectangle",
"top_left": {"x": self._top_left.x, "y": self._top_left.y},
"width": self._width,
"height": self._height,
"style": self._style.to_dict()
}
class Line(Graphic):
"""线段"""
def __init__(self, start: Point, end: Point, style: Style = None):
self._start = start
self._end = end
self._style = style or Style(stroke_color=Color(0, 0, 0))
def draw(self) -> str:
return f"Line(({self._start.x}, {self._start.y}) -> ({self._end.x}, {self._end.y}))"
def get_bounds(self) -> Tuple[Point, Point]:
return (
Point(min(self._start.x, self._end.x), min(self._start.y, self._end.y)),
Point(max(self._start.x, self._end.x), max(self._start.y, self._end.y))
)
def move(self, dx: float, dy: float) -> None:
self._start = self._start.translate(dx, dy)
self._end = self._end.translate(dx, dy)
def contains_point(self, point: Point, threshold: float = 5.0) -> bool:
dist = abs(
(self._end.y - self._start.y) * point.x -
(self._end.x - self._start.x) * point.y +
self._end.x * self._start.y -
self._end.y * self._start.x
) / ((self._end.y - self._start.y) ** 2 + (self._end.x - self._start.x) ** 2) ** 0.5
return dist <= threshold
def get_type(self) -> ShapeType:
return ShapeType.LINE
def to_dict(self) -> Dict[str, Any]:
return {
"type": "line",
"start": {"x": self._start.x, "y": self._start.y},
"end": {"x": self._end.x, "y": self._end.y},
"style": self._style.to_dict()
}
@property
def length(self) -> float:
return self._start.distance_to(self._end)
class GraphicGroup(Graphic):
"""图形组合"""
def __init__(self, name: str = "Group"):
self._name = name
self._graphics: List[Graphic] = []
self._style = Style()
@property
def name(self) -> str:
return self._name
def add(self, graphic: Graphic) -> None:
self._graphics.append(graphic)
def remove(self, graphic: Graphic) -> None:
self._graphics.remove(graphic)
def get_graphics(self) -> List[Graphic]:
return self._graphics.copy()
def draw(self) -> str:
results = [f"Group '{self._name}':"]
for g in self._graphics:
for line in g.draw().split('\n'):
results.append(f" {line}")
return '\n'.join(results)
def get_bounds(self) -> Tuple[Point, Point]:
if not self._graphics:
return (Point(0, 0), Point(0, 0))
all_bounds = [g.get_bounds() for g in self._graphics]
min_x = min(b[0].x for b in all_bounds)
min_y = min(b[0].y for b in all_bounds)
max_x = max(b[1].x for b in all_bounds)
max_y = max(b[1].y for b in all_bounds)
return (Point(min_x, min_y), Point(max_x, max_y))
def move(self, dx: float, dy: float) -> None:
for graphic in self._graphics:
graphic.move(dx, dy)
def contains_point(self, point: Point) -> bool:
return any(g.contains_point(point) for g in self._graphics)
def get_type(self) -> ShapeType:
return ShapeType.GROUP
def to_dict(self) -> Dict[str, Any]:
return {
"type": "group",
"name": self._name,
"children": [g.to_dict() for g in self._graphics]
}
def flatten(self) -> List[Graphic]:
"""展平组合,返回所有基本图形"""
result = []
for g in self._graphics:
if g.get_type() == ShapeType.GROUP:
result.extend(g.flatten())
else:
result.append(g)
return result
def find_at_point(self, point: Point) -> Optional[Graphic]:
"""查找指定位置的图形"""
for g in reversed(self._graphics):
if g.contains_point(point):
if g.get_type() == ShapeType.GROUP:
found = g.find_at_point(point)
if found:
return found
return g
return None
def scale(self, factor: float, center: Point = None) -> None:
"""缩放组合"""
if center is None:
center = self.get_center()
for g in self._graphics:
if g.get_type() == ShapeType.GROUP:
g.scale(factor, center)
elif g.get_type() == ShapeType.CIRCLE:
g.scale(factor)
elif hasattr(g, '_top_left'):
g._top_left = Point(
center.x + (g._top_left.x - center.x) * factor,
center.y + (g._top_left.y - center.y) * factor
)
g._width *= factor
g._height *= factor


def create_drawing() -> GraphicGroup:
    canvas = GraphicGroup("Canvas")
    red = Color(255, 0, 0)
    blue = Color(0, 0, 255)
    green = Color(0, 255, 0)

    # Face: head, two eyes, mouth
    face = GraphicGroup("Face")
    face.add(Circle(Point(200, 200), 100, Style(fill_color=Color(255, 220, 180))))
    face.add(Circle(Point(165, 175), 10, Style(fill_color=blue)))
    face.add(Circle(Point(235, 175), 10, Style(fill_color=blue)))
    face.add(Rectangle(Point(175, 220), 50, 10, Style(fill_color=red)))

    # House: wall, door, two windows
    house = GraphicGroup("House")
    house.add(Rectangle(Point(400, 250), 120, 100, Style(fill_color=Color(200, 150, 100))))
    house.add(Rectangle(Point(440, 280), 40, 70, Style(fill_color=Color(139, 90, 43))))
    house.add(Rectangle(Point(410, 270), 25, 25, Style(fill_color=Color(135, 206, 235))))
    house.add(Rectangle(Point(485, 270), 25, 25, Style(fill_color=Color(135, 206, 235))))

    # Groups nest inside the canvas alongside a primitive line (the ground)
    canvas.add(face)
    canvas.add(house)
    canvas.add(Line(Point(50, 350), Point(550, 350), Style(stroke_color=green, stroke_width=2)))
    return canvas


if __name__ == "__main__":
    drawing = create_drawing()
    print("=== Graphic structure ===")
    print(drawing.draw())

    print("\n=== Bounds ===")
    bounds = drawing.get_bounds()
    print(f"Top-left: ({bounds[0].x}, {bounds[0].y})")
    print(f"Bottom-right: ({bounds[1].x}, {bounds[1].y})")

    print("\n=== Shape statistics ===")
    shapes = drawing.flatten()
    print(f"Primitive shapes: {len(shapes)}")
    for shape_type in ShapeType:
        if shape_type != ShapeType.GROUP:
            count = sum(1 for s in shapes if s.get_type() == shape_type)
            print(f"  {shape_type.value}: {count}")

    print("\n=== JSON export ===")
    print(json.dumps(drawing.to_dict(), indent=2, ensure_ascii=False))

    print("\n=== Hit test ===")
    test_point = Point(200, 200)
    found = drawing.find_at_point(test_point)
    if found is not None:
        print(f"Found at ({test_point.x}, {test_point.y}): {found.draw()}")
8.6 Enterprise Application Examples
8.6.1 Configuration Management System
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set
import json


class ConfigNode(ABC):
    """Abstract configuration node (the Component role)."""

    @abstractmethod
    def get(self, key: str, default: Any = None) -> Any:
        pass

    @abstractmethod
    def set(self, key: str, value: Any) -> None:
        pass

    @abstractmethod
    def to_dict(self) -> Dict[str, Any]:
        pass

    @abstractmethod
    def keys(self) -> Set[str]:
        pass

    def merge(self, other: 'ConfigNode') -> 'ConfigNode':
        raise NotImplementedError


@dataclass
class ConfigValue(ConfigNode):
    """Configuration value (the Leaf role): a value plus optional metadata."""
    _value: Any
    _metadata: Dict[str, Any] = field(default_factory=dict)

    def get(self, key: str, default: Any = None) -> Any:
        # "value" returns the value itself; any other key is looked up in the metadata
        if key == "value":
            return self._value
        return self._metadata.get(key, default)

    def set(self, key: str, value: Any) -> None:
        if key == "value":
            self._value = value
        else:
            self._metadata[key] = value

    def to_dict(self) -> Dict[str, Any]:
        result = {"value": self._value}
        if self._metadata:
            result["metadata"] = self._metadata
        return result

    def keys(self) -> Set[str]:
        return {"value"} | set(self._metadata.keys())

    @property
    def value(self) -> Any:
        return self._value


class ConfigSection(ConfigNode):
    """Configuration section (the Composite role): named children plus optional inheritance."""

    def __init__(self, name: str, parent: Optional['ConfigSection'] = None):
        self._name = name
        self._parent = parent
        self._children: Dict[str, ConfigNode] = {}
        # Dotted paths (from the root) of sections consulted for keys missing here
        self._inherit_from: List[str] = []

    @property
    def name(self) -> str:
        return self._name

    @property
    def path(self) -> str:
        if self._parent is not None:
            return f"{self._parent.path}.{self._name}"
        return self._name

    def get(self, key: str, default: Any = None) -> Any:
        # Dotted keys are resolved child by child (inheritance is not consulted on this path)
        if '.' in key:
            first, rest = key.split('.', 1)
            child = self._children.get(first)
            if child is not None:
                return child.get(rest, default)
            return default
        if key in self._children:
            child = self._children[key]
            if isinstance(child, ConfigValue):
                return child.value
            return child
        # Fall back on inherited sections in the order they were added
        for inherit_path in self._inherit_from:
            inherited = self._resolve_inheritance(inherit_path)
            if inherited is not None:
                result = inherited.get(key, None)
                if result is not None:
                    return result
        return default

    def set(self, key: str, value: Any) -> None:
        if '.' in key:
            # Create intermediate sections on demand
            first, rest = key.split('.', 1)
            if first not in self._children:
                self._children[first] = ConfigSection(first, self)
            self._children[first].set(rest, value)
        elif isinstance(value, ConfigNode):
            self._children[key] = value
        else:
            self._children[key] = ConfigValue(value)

    def add_section(self, name: str) -> 'ConfigSection':
        section = ConfigSection(name, self)
        self._children[name] = section
        return section

    def add_inheritance(self, path: str) -> None:
        self._inherit_from.append(path)

    def _resolve_inheritance(self, path: str) -> Optional['ConfigSection']:
        # Inheritance paths are resolved from the root of the tree
        current: Optional[ConfigSection] = self._get_root()
        for part in path.split('.'):
            if current is None:
                return None
            child = current._children.get(part)
            current = child if isinstance(child, ConfigSection) else None
        return current

    def _get_root(self) -> 'ConfigSection':
        current = self
        while current._parent is not None:
            current = current._parent
        return current

    def to_dict(self) -> Dict[str, Any]:
        result: Dict[str, Any] = {key: child.to_dict() for key, child in self._children.items()}
        if self._inherit_from:
            result["_inherit"] = self._inherit_from
        return result

    def keys(self) -> Set[str]:
        result = set(self._children.keys())
        for inherit_path in self._inherit_from:
            inherited = self._resolve_inheritance(inherit_path)
            if inherited is not None:
                result.update(inherited.keys())
        return result

    def merge(self, other: ConfigNode) -> 'ConfigSection':
        """Recursively merge `other` into this section; values from `other` win."""
        if isinstance(other, ConfigSection):
            for key, child in other._children.items():
                self_child = self._children.get(key)
                if isinstance(self_child, ConfigSection) and isinstance(child, ConfigSection):
                    self_child.merge(child)
                else:
                    copied = deepcopy(child)
                    if isinstance(copied, ConfigSection):
                        # deepcopy also copies the _parent chain; re-attach the copy to this tree
                        copied._parent = self
                    self._children[key] = copied
        return self

    def flatten(self) -> Dict[str, Any]:
        """Flatten the configuration into dot-separated key/value pairs."""
        result: Dict[str, Any] = {}
        for key, child in self._children.items():
            if isinstance(child, ConfigSection):
                for sub_key, value in child.flatten().items():
                    result[f"{key}.{sub_key}"] = value
            elif isinstance(child, ConfigValue):
                result[key] = child.value
        return result

    def validate(self, schema: Dict[str, Any]) -> List[str]:
        """Validate against {key: {"required": bool, "type": type}}; return error messages."""
        errors = []
        for key, rules in schema.items():
            value = self.get(key)
            if rules.get("required", False) and value is None:
                errors.append(f"Missing required key: {key}")
            if value is not None and "type" in rules:
                if not isinstance(value, rules["type"]):
                    errors.append(f"Invalid type for {key}: expected {rules['type']}, got {type(value)}")
        return errors


def create_app_config() -> ConfigSection:
    root = ConfigSection("app")
    root.set("name", "MyApplication")
    root.set("version", "1.0.0")
    root.set("debug", False)

    database = root.add_section("database")
    database.set("host", "localhost")
    database.set("port", 5432)
    database.set("name", "myapp_db")
    database.set("pool_size", 10)

    cache = root.add_section("cache")
    cache.set("backend", "redis")
    cache.set("host", "localhost")
    cache.set("port", 6379)
    cache.set("ttl", 3600)

    logging_config = root.add_section("logging")
    logging_config.set("level", "INFO")
    logging_config.set("format", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    handlers = logging_config.add_section("handlers")
    handlers.set("console", {"enabled": True, "level": "DEBUG"})
    handlers.set("file", {"enabled": True, "level": "INFO", "path": "/var/log/app.log"})

    features = root.add_section("features")
    features.set("new_ui", True)
    features.set("beta_api", False)
    features.set("analytics", True)
    return root


if __name__ == "__main__":
    config = create_app_config()
    print("=== Configuration structure ===")
    print(json.dumps(config.to_dict(), indent=2))

    print("\n=== Reading values ===")
    print(f"Application name: {config.get('name')}")
    print(f"Database host: {config.get('database.host')}")
    print(f"Cache TTL: {config.get('cache.ttl')}")
    print(f"Log level: {config.get('logging.level')}")
    print(f"Missing key: {config.get('nonexistent.key', 'default')}")

    print("\n=== Flattened configuration ===")
    flat = config.flatten()
    for key, value in sorted(flat.items()):
        print(f"{key}: {value}")

    print("\n=== Validation ===")
    schema = {
        "name": {"required": True, "type": str},
        "version": {"required": True, "type": str},
        "database.host": {"required": True, "type": str},
        "database.port": {"required": True, "type": int},
    }
    errors = config.validate(schema)
    if errors:
        for error in errors:
            print(f"Error: {error}")
    else:
        print("Configuration is valid")
8.6.2 Expression Parsing and Evaluation
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Dict, List, Set
import math
import re


class Expression(ABC):
    """Abstract expression node (the Component role)."""

    @abstractmethod
    def evaluate(self, context: Dict[str, Any]) -> Any:
        pass

    @abstractmethod
    def to_string(self) -> str:
        pass

    @abstractmethod
    def get_variables(self) -> Set[str]:
        pass

    def __repr__(self) -> str:
        return self.to_string()


# repr=False keeps Expression.__repr__ (the infix form) instead of the generated dataclass repr
@dataclass(repr=False)
class Number(Expression):
    """Numeric literal (Leaf)."""
    value: float

    def evaluate(self, context: Dict[str, Any]) -> float:
        return self.value

    def to_string(self) -> str:
        return str(self.value)

    def get_variables(self) -> Set[str]:
        return set()


@dataclass(repr=False)
class Variable(Expression):
    """Variable reference (Leaf), resolved from the evaluation context."""
    name: str

    def evaluate(self, context: Dict[str, Any]) -> Any:
        if self.name not in context:
            raise NameError(f"Variable '{self.name}' not defined")
        return context[self.name]

    def to_string(self) -> str:
        return self.name

    def get_variables(self) -> Set[str]:
        return {self.name}


@dataclass(repr=False)
class BinaryOp(Expression):
    """Binary operation (Composite with two children)."""
    left: Expression
    operator: str
    right: Expression

    # ClassVar keeps these shared tables out of the dataclass fields; a plain dict
    # default would make @dataclass raise "mutable default ... is not allowed"
    _operators: ClassVar[Dict[str, Callable[[Any, Any], Any]]] = {
        '+': lambda a, b: a + b,
        '-': lambda a, b: a - b,
        '*': lambda a, b: a * b,
        # Division by zero returns +inf instead of raising (a choice made for this example)
        '/': lambda a, b: a / b if b != 0 else float('inf'),
        '%': lambda a, b: a % b,
        '**': lambda a, b: a ** b,
        '//': lambda a, b: a // b,
    }
    _precedence: ClassVar[Dict[str, int]] = {
        '+': 1, '-': 1,
        '*': 2, '/': 2, '%': 2, '//': 2,
        '**': 3,
    }

    def evaluate(self, context: Dict[str, Any]) -> Any:
        left_val = self.left.evaluate(context)
        right_val = self.right.evaluate(context)
        if self.operator not in self._operators:
            raise ValueError(f"Unknown operator: {self.operator}")
        return self._operators[self.operator](left_val, right_val)

    def to_string(self) -> str:
        prec = self._precedence.get(self.operator, 0)
        left_str = self.left.to_string()
        right_str = self.right.to_string()
        if isinstance(self.left, BinaryOp):
            left_prec = self._precedence.get(self.left.operator, 0)
            # ** is right-associative, so an equal-precedence left operand needs parentheses
            if left_prec < prec or (left_prec == prec and self.operator == '**'):
                left_str = f"({left_str})"
        if isinstance(self.right, BinaryOp):
            right_prec = self._precedence.get(self.right.operator, 0)
            # The other operators are left-associative, so an equal-precedence right operand needs them
            if right_prec < prec or (right_prec == prec and self.operator != '**'):
                right_str = f"({right_str})"
        return f"{left_str} {self.operator} {right_str}"

    def get_variables(self) -> Set[str]:
        return self.left.get_variables() | self.right.get_variables()


@dataclass(repr=False)
class UnaryOp(Expression):
    """Unary operation (Composite with one child)."""
    operator: str
    operand: Expression

    def evaluate(self, context: Dict[str, Any]) -> Any:
        val = self.operand.evaluate(context)
        if self.operator == '-':
            return -val
        if self.operator == '+':
            return +val
        raise ValueError(f"Unknown unary operator: {self.operator}")

    def to_string(self) -> str:
        operand_str = self.operand.to_string()
        # Parenthesize compound operands so that -(a + b) is not printed as -a + b
        if isinstance(self.operand, BinaryOp):
            operand_str = f"({operand_str})"
        return f"{self.operator}{operand_str}"

    def get_variables(self) -> Set[str]:
        return self.operand.get_variables()


@dataclass(repr=False)
class FunctionCall(Expression):
    """Function call (Composite with any number of argument children)."""
    name: str
    args: List[Expression]

    _functions: ClassVar[Dict[str, Callable[..., Any]]] = {
        'abs': abs,
        'round': round,
        'min': min,
        'max': max,
        # Built-in sum() expects a single iterable, so wrap it to accept sum(1, 2, 3)
        'sum': lambda *xs: sum(xs),
        'sqrt': math.sqrt,
        'sin': math.sin,
        'cos': math.cos,
    }

    def evaluate(self, context: Dict[str, Any]) -> Any:
        if self.name not in self._functions:
            raise NameError(f"Unknown function: {self.name}")
        arg_values = [arg.evaluate(context) for arg in self.args]
        return self._functions[self.name](*arg_values)

    def to_string(self) -> str:
        args_str = ", ".join(arg.to_string() for arg in self.args)
        return f"{self.name}({args_str})"

    def get_variables(self) -> Set[str]:
        result: Set[str] = set()
        for arg in self.args:
            result.update(arg.get_variables())
        return result


class ExpressionParser:
    """Recursive-descent parser that uses precedence climbing for binary operators."""

    def __init__(self):
        self._tokens: List[str] = []
        self._pos: int = 0

    def parse(self, expr: str) -> Expression:
        self._tokens = self._tokenize(expr)
        self._pos = 0
        result = self._parse_expression()
        if self._pos < len(self._tokens):
            raise SyntaxError(f"Unexpected token: {self._tokens[self._pos]}")
        return result

    def _tokenize(self, expr: str) -> List[str]:
        # Two-character operators (** and //) must come before the single-character class,
        # and ',' is needed to separate function arguments. re.findall skips characters
        # that match no alternative, including whitespace.
        pattern = r'\d+\.?\d*|\w+|\*\*|//|[+\-*/%(),]'
        return re.findall(pattern, expr)

    def _parse_expression(self, min_prec: int = 0) -> Expression:
        left = self._parse_primary()
        while self._pos < len(self._tokens):
            op = self._tokens[self._pos]
            # Non-operators such as ')' and ',' get -1 and end the loop
            prec = BinaryOp._precedence.get(op, -1)
            if prec < min_prec:
                break
            self._pos += 1
            # ** is right-associative; all other operators are left-associative
            next_min = prec if op == '**' else prec + 1
            right = self._parse_expression(next_min)
            left = BinaryOp(left, op, right)
        return left

    def _parse_primary(self) -> Expression:
        if self._pos >= len(self._tokens):
            raise SyntaxError("Unexpected end of expression")
        token = self._tokens[self._pos]
        if token == '(':
            self._pos += 1
            expr = self._parse_expression()
            if self._pos >= len(self._tokens) or self._tokens[self._pos] != ')':
                raise SyntaxError("Missing closing parenthesis")
            self._pos += 1
            return expr
        if token in ('+', '-'):
            self._pos += 1
            return UnaryOp(token, self._parse_primary())
        if re.match(r'^\d+\.?\d*$', token):
            self._pos += 1
            return Number(float(token))
        if re.match(r'^\w+$', token):
            self._pos += 1
            if self._pos < len(self._tokens) and self._tokens[self._pos] == '(':
                return self._parse_function_call(token)
            return Variable(token)
        raise SyntaxError(f"Unexpected token: {token}")

    def _parse_function_call(self, name: str) -> FunctionCall:
        self._pos += 1  # skip '('
        args: List[Expression] = []
        if self._pos < len(self._tokens) and self._tokens[self._pos] != ')':
            args.append(self._parse_expression())
            while self._pos < len(self._tokens) and self._tokens[self._pos] == ',':
                self._pos += 1
                args.append(self._parse_expression())
        if self._pos >= len(self._tokens) or self._tokens[self._pos] != ')':
            raise SyntaxError("Missing closing parenthesis in function call")
        self._pos += 1
        return FunctionCall(name, args)


if __name__ == "__main__":
    parser = ExpressionParser()
    expressions = [
        "1 + 2 * 3",
        "(1 + 2) * 3",
        "x + y * z",
        "2 ** 3 + 4",
        "abs(-5) + max(1, 2, 3)",
        "sqrt(16) + sin(0)",
        "a * b + c * d",
    ]
    context = {
        'x': 10, 'y': 20, 'z': 5,
        'a': 2, 'b': 3, 'c': 4, 'd': 5,
    }
    print("=== Expression evaluation ===")
    for expr_str in expressions:
        expr = parser.parse(expr_str)
        result = expr.evaluate(context)
        variables = expr.get_variables()
        print(f"{expr_str:30} => {expr.to_string():30} = {result}")
        if variables:
            print(f"    variables: {sorted(variables)}")
8.7 Pattern Variants and Extensions
8.7.1 Cached Composite (Memoization)
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List
import hashlib
import json


class CacheableComponent(ABC):
    """Component whose result can be memoized in a shared cache dict."""

    @abstractmethod
    def compute(self) -> Any:
        pass

    @abstractmethod
    def get_cache_key(self) -> str:
        pass

    def compute_cached(self, cache: Dict[str, Any]) -> Any:
        key = self.get_cache_key()
        if key not in cache:
            cache[key] = self.compute()
        return cache[key]


class CacheableLeaf(CacheableComponent):
    """Cacheable leaf that wraps an expensive function of its value."""

    def __init__(self, name: str, value: Any, expensive_func: Callable[[Any], Any]):
        self._name = name
        self._value = value
        self._func = expensive_func

    def compute(self) -> Any:
        return self._func(self._value)

    def get_cache_key(self) -> str:
        # The key depends on name and value only, so a changed value yields a new key
        data = json.dumps({"name": self._name, "value": str(self._value)})
        return hashlib.md5(data.encode()).hexdigest()


class CacheableComposite(CacheableComponent):
    """Cacheable composite that aggregates its children's results."""

    def __init__(self, name: str, aggregator: Callable[[List[Any]], Any]):
        self._name = name
        self._aggregator = aggregator
        self._children: List[CacheableComponent] = []

    def add(self, component: CacheableComponent) -> None:
        self._children.append(component)

    def compute(self) -> Any:
        results = [child.compute() for child in self._children]
        return self._aggregator(results)

    def compute_cached(self, cache: Dict[str, Any]) -> Any:
        key = self.get_cache_key()
        if key not in cache:
            # Children consult the cache first; the aggregate itself is then stored too
            results = [child.compute_cached(cache) for child in self._children]
            cache[key] = self._aggregator(results)
        return cache[key]

    def get_cache_key(self) -> str:
        # Derived from the children's keys, so any change below changes this key as well
        child_keys = [child.get_cache_key() for child in self._children]
        data = json.dumps({"name": self._name, "children": child_keys})
        return hashlib.md5(data.encode()).hexdigest()

    def invalidate_cache(self, cache: Dict[str, Any]) -> None:
        """Remove the entries of this composite and all of its descendants."""
        cache.pop(self.get_cache_key(), None)
        for child in self._children:
            if isinstance(child, CacheableComposite):
                child.invalidate_cache(cache)
            else:
                cache.pop(child.get_cache_key(), None)
8.7.2 Observer Composite (Event Propagation)
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, List


class EventType(Enum):
    ADD = "add"
    REMOVE = "remove"
    UPDATE = "update"
    MOVE = "move"


@dataclass
class Event:
    type: EventType
    source: Any
    data: Dict[str, Any]


class ObservableComponent(ABC):
    """Observable component: keeps a list of listeners and notifies them of events."""

    def __init__(self):
        self._listeners: List[Callable[[Event], None]] = []

    def add_listener(self, listener: Callable[[Event], None]) -> None:
        self._listeners.append(listener)

    def remove_listener(self, listener: Callable[[Event], None]) -> None:
        self._listeners.remove(listener)

    def _emit(self, event: Event) -> None:
        for listener in self._listeners:
            listener(event)

    @abstractmethod
    def operation(self) -> Any:
        pass


class ObservableLeaf(ObservableComponent):
    """Observable leaf: emits an UPDATE event whenever its value changes."""

    def __init__(self, name: str, value: Any):
        super().__init__()
        self._name = name
        self._value = value

    @property
    def value(self) -> Any:
        return self._value

    @value.setter
    def value(self, new_value: Any) -> None:
        old_value = self._value
        self._value = new_value
        self._emit(Event(
            type=EventType.UPDATE,
            source=self,
            data={"old": old_value, "new": new_value},
        ))

    def operation(self) -> Any:
        return self._value


class ObservableComposite(ObservableComponent):
    """Observable composite: emits ADD/REMOVE events and forwards its children's events."""

    def __init__(self, name: str):
        super().__init__()
        self._name = name
        self._children: List[ObservableComponent] = []

    def add(self, component: ObservableComponent) -> None:
        self._children.append(component)
        # Forward the child's events to this composite's listeners so that events
        # bubble up toward the root; event.source still identifies the originating node
        component.add_listener(self._emit)
        self._emit(Event(
            type=EventType.ADD,
            source=self,
            data={"component": component},
        ))

    def remove(self, component: ObservableComponent) -> None:
        self._children.remove(component)
        component.remove_listener(self._emit)
        self._emit(Event(
            type=EventType.REMOVE,
            source=self,
            data={"component": component},
        ))

    def operation(self) -> Any:
        return [child.operation() for child in self._children]
8.7.3 Persistent Composite (Serialization)
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Type
import json


class SerializableComponent(ABC):
    """Component that can be converted to and from plain dicts and JSON."""

    @abstractmethod
    def to_dict(self) -> Dict[str, Any]:
        pass

    @classmethod
    @abstractmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'SerializableComponent':
        pass

    def to_json(self, indent: int = 2) -> str:
        return json.dumps(self.to_dict(), indent=indent, ensure_ascii=False)

    @classmethod
    def from_json(cls, json_str: str) -> 'SerializableComponent':
        # Call this on a concrete class (e.g. SerializableComposite), not on the abstract base
        data = json.loads(json_str)
        return cls.from_dict(data)


class SerializableLeaf(SerializableComponent):
    """Serializable leaf; subclasses can register themselves under a type name."""
    _registry: Dict[str, Type['SerializableLeaf']] = {}
    _type_name: str = "SerializableLeaf"

    def __init__(self, name: str, value: Any):
        self._name = name
        self._value = value

    @classmethod
    def register(cls, type_name: str) -> Callable[[Type['SerializableLeaf']], Type['SerializableLeaf']]:
        def decorator(subclass: Type['SerializableLeaf']) -> Type['SerializableLeaf']:
            cls._registry[type_name] = subclass
            # Store the registered name so that to_dict() writes the key from_dict() looks up
            subclass._type_name = type_name
            return subclass
        return decorator

    def to_dict(self) -> Dict[str, Any]:
        return {
            "type": self._type_name,
            "name": self._name,
            "value": self._value,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'SerializableLeaf':
        type_name = data.get("type", "SerializableLeaf")
        target = cls._registry.get(type_name)
        # Delegate only to a *different* registered class; a subclass that inherits
        # this method would otherwise call itself forever
        if target is not None and target is not cls:
            return target.from_dict(data)
        return cls(data["name"], data["value"])


class SerializableComposite(SerializableComponent):
    """Serializable composite: children are serialized recursively."""

    def __init__(self, name: str):
        self._name = name
        self._children: List[SerializableComponent] = []

    def add(self, component: SerializableComponent) -> None:
        self._children.append(component)

    def to_dict(self) -> Dict[str, Any]:
        return {
            "type": "Composite",
            "name": self._name,
            "children": [child.to_dict() for child in self._children],
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'SerializableComposite':
        composite = cls(data["name"])
        for child_data in data.get("children", []):
            if child_data.get("type") == "Composite":
                child = cls.from_dict(child_data)
            else:
                child = SerializableLeaf.from_dict(child_data)
            composite.add(child)
        return composite
8.8 Anti-Patterns and Best Practices
8.8.1 Common Anti-Patterns
Anti-pattern 1: Excessive nesting
from typing import Any, Dict


class BadNestedComposite:
    """Anti-pattern: deeply nested dicts; every access hard-codes the full path."""

    def __init__(self):
        self.level1 = {
            "level2": {
                "level3": {
                    "level4": {
                        "level5": {
                            "value": "too deep"
                        }
                    }
                }
            }
        }

    def get_value(self):
        return self.level1["level2"]["level3"]["level4"]["level5"]["value"]


class GoodFlattenedComposite:
    """Better: store nodes under dotted path keys in a flat dict."""

    def __init__(self):
        self._nodes: Dict[str, Any] = {}

    def add(self, path: str, value: Any) -> None:
        self._nodes[path] = value

    def get(self, path: str, default: Any = None) -> Any:
        return self._nodes.get(path, default)
Anti-pattern 2: Circular references
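from typing import List, Optional, Set


class CircularComposite:
    """Anti-pattern: two nodes become each other's child, violating the acyclic tree constraint."""

    def __init__(self, name: str):
        self.name = name
        self.children: List['CircularComposite'] = []

    def add_circular(self, other: 'CircularComposite') -> None:
        self.children.append(other)
        other.children.append(self)


def detect_cycle(node: CircularComposite, visited: Optional[Set[int]] = None) -> bool:
    """Return True if a cycle is reachable from `node` (depth-first search).

    `visited` holds the ids of the nodes on the current search path; each id is
    removed again on backtracking, so only genuine cycles (back edges) are reported.
    """
    if visited is None:
        visited = set()
    if id(node) in visited:
        return True
    visited.add(id(node))
    for child in node.children:
        if detect_cycle(child, visited):
            return True
    visited.remove(id(node))
    return False
Anti-pattern 3: God component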
from typing import List


class GodComponent:
    """Anti-pattern: a god component that takes on too many responsibilities."""

    def __init__(self):
        self.children = []
        self.database_connection = None
        self.cache = {}
        self.logger = None
        self.config = {}
        self.event_handlers = []

    def operation(self): pass
    def save_to_db(self): pass
    def load_from_cache(self): pass
    def log_activity(self): pass
    def validate_config(self): pass
    def handle_event(self): pass


class FocusedComponent:
    """Better: a single responsibility, managing children and the operation only."""

    def __init__(self, name: str):
        self._name = name
        self._children: List['FocusedComponent'] = []

    def add(self, child: 'FocusedComponent') -> None:
        self._children.append(child)

    def operation(self) -> str:
        results = [child.operation() for child in self._children]
        return f"{self._name}: {results}"
8.8.2 Best Practices
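| Practice | Description | Example |
|---|---|---|
| Minimal interface | The Component interface declares only the methods every node needs | Only operation() |
| Type checks | Use is_composite() or isinstance() | Runtime type discrimination |
| Defensive copying | Return a copy of the child collection | get_children() returns a copy |
| Iterator support | Implement __iter__ for traversal | for child in composite |
| Context management | Implement __enter__/__exit__ | Resource management |
| Immutable design | Make leaf nodes immutable | @dataclass(frozen=True) |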
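A minimal sketch of three of these practices together (immutable leaves, defensive copying, iterator support), assuming a simple file-size tree; Item and Folder are hypothetical names, not classes defined earlier in this chapter.

```python
from dataclasses import dataclass
from typing import Iterator, List, Tuple, Union


@dataclass(frozen=True)
class Item:
    """Immutable leaf: assigning to a field raises dataclasses.FrozenInstanceError."""
    name: str
    size: int

    def total_size(self) -> int:
        return self.size


class Folder:
    """Composite with defensive copying and iterator support."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._children: List[Union[Item, "Folder"]] = []

    def add(self, child: Union[Item, "Folder"]) -> None:
        self._children.append(child)

    def get_children(self) -> Tuple[Union[Item, "Folder"], ...]:
        # Defensive copy: callers receive a tuple and cannot mutate the internal list
        return tuple(self._children)

    def __iter__(self) -> Iterator[Union[Item, "Folder"]]:
        # Iterates over direct children, enabling `for child in folder`
        return iter(self._children)

    def total_size(self) -> int:
        # Recursion happens through the uniform total_size() interface
        return sum(child.total_size() for child in self)


root = Folder("root")
docs = Folder("docs")
docs.add(Item("a.txt", 10))
docs.add(Item("b.txt", 5))
root.add(docs)
root.add(Item("c.txt", 1))
print(root.total_size())  # 16
```

Returning a tuple rather than list(self._children) is one design choice among several; both prevent callers from changing the tree behind the composite's back.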
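8.9 Pattern Comparison
8.9.1 Comparison with Related Patterns
| Pattern | Intent | Structure | Key difference |
|---|---|---|---|
| Composite | Part-whole hierarchies | Recursive tree | Treats single objects and compositions uniformly |
| Decorator | Add responsibilities dynamically | Linear wrapping | Wraps one object without changing its interface |
| Chain of Responsibility | Pass a request along a chain of handlers | Linear chain | Focuses on the handling flow |
| Iterator | Traverse an aggregate | Sequential traversal | Focuses on how elements are accessed |
| Visitor | Separate operations from the structure | Double dispatch | Focuses on adding new operations |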
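Visitor's double dispatch can be illustrated with a minimal sketch. It assumes a hypothetical file tree (File, Directory) and two hypothetical visitors: a node's accept() selects the visitor method for its own class, and the visitor's method selects the operation. Adding an operation means adding a visitor class rather than editing the node classes.

```python
from abc import ABC, abstractmethod
from typing import List


class Node(ABC):
    @abstractmethod
    def accept(self, visitor: "Visitor"):
        """First dispatch: the node chooses the visitor method for its own class."""


class File(Node):
    def __init__(self, name: str, size: int) -> None:
        self.name, self.size = name, size

    def accept(self, visitor: "Visitor"):
        return visitor.visit_file(self)


class Directory(Node):
    def __init__(self, name: str, children: List[Node]) -> None:
        self.name, self.children = name, children

    def accept(self, visitor: "Visitor"):
        return visitor.visit_directory(self)


class Visitor(ABC):
    """Second dispatch: each visitor implements one operation for every node type."""

    @abstractmethod
    def visit_file(self, node: File): ...

    @abstractmethod
    def visit_directory(self, node: Directory): ...


class SizeVisitor(Visitor):
    def visit_file(self, node: File) -> int:
        return node.size

    def visit_directory(self, node: Directory) -> int:
        return sum(child.accept(self) for child in node.children)


class NameVisitor(Visitor):
    def visit_file(self, node: File) -> List[str]:
        return [node.name]

    def visit_directory(self, node: Directory) -> List[str]:
        # Pre-order listing: the directory first, then its children
        names = [node.name]
        for child in node.children:
            names.extend(child.accept(self))
        return names


tree = Directory("root", [File("a", 3), Directory("sub", [File("b", 4)])])
print(tree.accept(SizeVisitor()))  # 7
print(tree.accept(NameVisitor()))  # ['root', 'a', 'sub', 'b']
```

The trade-off mirrors the table: new operations are cheap, but adding a new node type requires a new visit_* method in every visitor.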
8.9.2 How the Composite Pattern Collaborates with Other Patterns
┌─────────────────────────────────────────────────────────────────────┐
│                  Composite Pattern Collaborations                   │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  ┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐  │
│  │    Composite    │───>│     Visitor     │───>│  Op separation  │  │
│  └─────────────────┘    └─────────────────┘    └─────────────────┘  │
│           │                                                         │
│           ▼                                                         │
│  ┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐  │
│  │    Iterator     │───>│    Traversal    │───>│ Uniform access  │  │
│  └─────────────────┘    └─────────────────┘    └─────────────────┘  │
│           │                                                         │
│           ▼                                                         │
│  ┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐  │
│  │    Decorator    │───>│  Add behavior   │───>│Dynamic extension│  │
│  └─────────────────┘    └─────────────────┘    └─────────────────┘  │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘
8.10 Decision Guide
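8.10.1 Applicability Checklist
- [ ] You need to represent a part-whole hierarchy
- [ ] You want to treat individual objects and compositions uniformly
- [ ] The structure is recursive
- [ ] Client code should be able to ignore the difference between leaves and composites
- [ ] Objects need to be composed dynamically at run time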
8.10.2 Implementation Decision Tree
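The decision comes down to where child management is declared. A minimal side-by-side sketch, with hypothetical class names rather than the chapter's earlier implementations: in the transparent variant every node has add(), so leaves must reject it at run time; in the safe variant only the composite has add(), so misuse surfaces as an AttributeError at run time, or earlier as an error from a static type checker.

```python
from abc import ABC, abstractmethod
from typing import List


# Transparent variant: child management is declared on the base class
class TransparentComponent(ABC):
    @abstractmethod
    def operation(self) -> str: ...

    def add(self, child: "TransparentComponent") -> None:
        # Leaves inherit this default and reject the call at run time
        raise NotImplementedError(f"{type(self).__name__} cannot have children")


class TransparentLeaf(TransparentComponent):
    def operation(self) -> str:
        return "leaf"


class TransparentComposite(TransparentComponent):
    def __init__(self) -> None:
        self._children: List[TransparentComponent] = []

    def add(self, child: TransparentComponent) -> None:
        self._children.append(child)

    def operation(self) -> str:
        return "(" + " ".join(c.operation() for c in self._children) + ")"


# Safe variant: only the composite declares add()
class SafeComponent(ABC):
    @abstractmethod
    def operation(self) -> str: ...


class SafeLeaf(SafeComponent):
    def operation(self) -> str:
        return "leaf"


class SafeComposite(SafeComponent):
    def __init__(self) -> None:
        self._children: List[SafeComponent] = []

    def add(self, child: SafeComponent) -> None:
        self._children.append(child)

    def operation(self) -> str:
        return "(" + " ".join(c.operation() for c in self._children) + ")"


t = TransparentComposite()
t.add(TransparentLeaf())
print(t.operation())  # (leaf)
```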
            Need type safety?
              /         \
           Yes           No
            /              \
        Safe     Need runtime add/remove?
                        /        \
                      Yes         No
                      /             \
               Transparent         Safe
8.10.3 Quick Reference Card
┌───────────────────────────────────────────────────────────────┐
│               Composite Pattern Quick Reference               │
├───────────────────────────────────────────────────────────────┤
│ Definition: compose objects into trees so that single objects │
│             and composites are used uniformly                 │
├───────────────────────────────────────────────────────────────┤
│ Participants:                                                 │
│   • Component - common interface                              │
│   • Leaf - node without children                              │
│   • Composite - node that holds children                      │
│   • Client - works through Component                          │
├───────────────────────────────────────────────────────────────┤
│ Formal: C = L ∪ K, K = {(id, {c₁, c₂, ..., cₙ})}              │
├───────────────────────────────────────────────────────────────┤
│ Variants:                                                     │
│   • Transparent - all methods declared in the base class      │
│   • Safe - child management only in Composite                 │
├───────────────────────────────────────────────────────────────┤
│ Typical applications:                                         │
│   • File systems                                              │
│   • Organization charts                                       │
│   • Graphics editors                                          │
│   • Menu systems                                              │
│   • Configuration management                                  │
├───────────────────────────────────────────────────────────────┤
│ Python features:                                              │
│   • ABC abstract base classes                                 │
│   • Protocol structural typing                                │
│   • Generic type parameters                                   │
│   • dataclass data classes                                    │
├───────────────────────────────────────────────────────────────┤
│ Best practices:                                               │
│   • Minimal interface                                         │
│   • Defensive copying                                         │
│   • Avoid circular references                                 │
│   • Provide iterator support                                  │
└───────────────────────────────────────────────────────────────┘
8.11 Summary
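The Composite pattern is well suited to tree-structured data. Because individual objects and compositions share one interface, client code becomes simpler. Its core value lies in:
- Uniform abstraction: clients do not need to distinguish leaves from composites
- Recursive structure: hierarchical data is supported naturally
- Open/Closed Principle: new component types can be added without changing client code
- Simpler clients: fewer conditional type checks
In Python, abstract base classes, Protocol, and Generic can be combined to build type-safe composite structures, while protocols such as iteration and context management add further capabilities.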
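A minimal sketch of the Protocol and Generic approach, assuming the aggregation operator ⊕ from the formal definition is passed in as a plain function (sum, max, all, ...). Component, Leaf, and Composite here are simplified standalone versions for illustration, not the chapter's earlier classes.

```python
from dataclasses import dataclass, field
from typing import Callable, Generic, Iterable, List, Protocol, TypeVar

T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)


class Component(Protocol[T_co]):
    """Structural interface: any object with a matching operation() conforms, no inheritance needed."""

    def operation(self) -> T_co: ...


@dataclass
class Leaf(Generic[T]):
    value: T

    def operation(self) -> T:
        return self.value


@dataclass
class Composite(Generic[T]):
    # `combine` plays the role of the aggregation operator over the children's results
    combine: Callable[[Iterable[T]], T]
    children: List[Component[T]] = field(default_factory=list)

    def add(self, child: Component[T]) -> "Composite[T]":
        self.children.append(child)
        return self  # returning self allows chained add() calls

    def operation(self) -> T:
        return self.combine(child.operation() for child in self.children)


inner = Composite[int](sum).add(Leaf(2)).add(Leaf(3))
total = Composite[int](sum).add(Leaf(1)).add(inner)
largest = Composite[int](max).add(Leaf(4)).add(inner)
print(total.operation())    # 6
print(largest.operation())  # 5
```

The Protocol's type variable is declared covariant because it appears only in return position; a static checker such as mypy asks for this in protocols. At run time the Protocol imposes no checks, so the benefit is purely static.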
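Questions and Exercises
Basic exercise: Implement a composite that supports undo by recording the history of all modifications.
Intermediate exercise: Design a composite that computes in parallel, using multiple threads or processes to speed up the evaluation of composite nodes.
Challenge exercise: Implement a composite that supports incremental updates: when a leaf changes, only the composite nodes affected by that change are recomputed.
Design question: How do the Composite and Visitor patterns work together? Design a composite structure that supports several kinds of visitor operations.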