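Chapter 9 The Decorator Pattern
Learning Objectives
After completing this chapter, readers will be able to:
- Understand the core concepts of the Decorator pattern and its formal definition
- Master the implementation principles behind Python's decorator syntax and class decorators
- Implement dynamic feature extension and the composition of multiple decorator layers
- Use modern Python features such as Protocol and Generic to write type-safe decorators
- Recognize the scenarios the Decorator pattern suits, as well as its anti-patterns
- Design enterprise-grade decorator solutions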
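9.1 Pattern Definition
9.1.1 Core Definition
The Decorator Pattern attaches additional responsibilities to an object dynamically. For extending functionality, it is a more flexible alternative to subclassing. The pattern relies on a wrapper. The decorator implements the same interface as the object it wraps and forwards calls to that object, adding behavior before or after each forwarded call. This extends the object's functionality at runtime without changing its class.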
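The following minimal sketch shows what "dynamic, per-object" extension means in practice. It assumes a simple notification component; `Notifier`, `EmailNotifier` and `SMSDecorator` are hypothetical names, not part of this chapter's code. Only the wrapped instance gains the extra responsibility, and other instances of the same class are unaffected. A subclass, by contrast, changes every object created from it and cannot be applied to an object that already exists.

```python
from typing import Protocol


class Notifier(Protocol):
    def send(self, message: str) -> list[str]: ...


class EmailNotifier:
    """Concrete component: 'sends' a message by e-mail (simulated)."""

    def send(self, message: str) -> list[str]:
        return [f"email: {message}"]


class SMSDecorator:
    """Decorator: same interface as the component, adds an SMS after delegating."""

    def __init__(self, inner: Notifier) -> None:
        self._inner = inner

    def send(self, message: str) -> list[str]:
        # Forward the call, then add the extra responsibility
        return self._inner.send(message) + [f"sms: {message}"]


plain = EmailNotifier()
urgent = SMSDecorator(EmailNotifier())  # only this object gains SMS, at runtime

print(plain.send("hi"))   # ['email: hi']
print(urgent.send("hi"))  # ['email: hi', 'sms: hi']
```

`SMSDecorator` exposes the same `send` method, so client code cannot tell the two objects apart. The decorated object can also be wrapped again to add further channels.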
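9.1.2 Formal Definition
Let $\mathcal{C}$ be the component interface (the set of objects that implement it), $c \in \mathcal{C}$ a concrete component, and $d_i: \mathcal{C} \rightarrow \mathcal{C}$ ($i = 1, \dots, n$) decorators. If $d_1$ is applied first and $d_n$ last, the decorated component is:
$$c' = (d_n \circ d_{n-1} \circ \cdots \circ d_1)(c) = d_n(d_{n-1}(\cdots d_1(c)\cdots)) \in \mathcal{C}$$
Consider decorators that post-process the result of the call they forward, and let $f_i$ denote the behavior that $d_i$ adds. The decorated component then satisfies:
$$\text{operation}(c') = f_n(f_{n-1}(\cdots f_1(\text{operation}(c))\cdots))$$
Properties of the decoration chain:
- Closure: $\forall d: \mathcal{C} \rightarrow \mathcal{C},\ \forall c \in \mathcal{C}: d(c) \in \mathcal{C}$. A decorated component can therefore be decorated again.
- Associativity: $(d_1 \circ d_2) \circ d_3 = d_1 \circ (d_2 \circ d_3)$, because function composition is associative. Composition is generally not commutative, however, so the order of decorators matters.
- Identity: there is an identity decorator $I$ with $I(c) = c$ for all $c \in \mathcal{C}$ (the original, undecorated component).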
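The three properties can be checked directly. This illustrative sketch models a component as a zero-argument function and a decorator as a function from components to components; all names are hypothetical. `compose(d1, d2, d3)` follows the mathematical convention $d_1 \circ d_2 \circ d_3$, so the rightmost decorator wraps the component first.

```python
from functools import reduce
from typing import Callable

# Model a component by its operation: a zero-argument callable returning str
Component = Callable[[], str]
Decorator = Callable[[Component], Component]


def identity(c: Component) -> Component:
    return c


def compose(*ds: Decorator) -> Decorator:
    """compose(d1, d2, d3) is d1 ∘ d2 ∘ d3: d3 wraps the component first."""
    return reduce(lambda f, g: (lambda c: f(g(c))), ds, identity)


def bracket(c: Component) -> Component:
    return lambda: f"[{c()}]"


def shout(c: Component) -> Component:
    return lambda: c().upper()


def exclaim(c: Component) -> Component:
    return lambda: c() + "!"


def base() -> str:
    return "coffee"


# Closure: every decorated value is again a Component (callable returning str)
# Associativity: both groupings build bracket(shout(exclaim(base)))
left = compose(compose(bracket, shout), exclaim)(base)
right = compose(bracket, compose(shout, exclaim))(base)
print(left(), right())                    # [COFFEE!] [COFFEE!]

# Identity: wrapping with I changes nothing
print(identity(base)())                   # coffee

# Order matters: composition is not commutative
print(compose(bracket, exclaim)(base)())  # [coffee!]
print(compose(exclaim, bracket)(base)())  # [coffee]!
```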
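Feature composition:
$$\text{Features}(d_n(\cdots d_1(c)\cdots)) = \text{Features}(c) \cup \bigcup_{i=1}^{n} \text{AddedFeatures}(d_i)$$
Complexity analysis (assuming each decorator layer does a constant amount of its own work):
- Decorator creation: $O(1)$ per layer (wrapping one object)
- Method call: $O(n)$, where $n$ is the number of decorator layers, because each call is forwarded through every layer
- Space: $O(n)$ (one wrapper object per layer)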
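The sketch below makes the $O(n)$ call cost visible. It assumes each layer does constant work and then forwards exactly one call. `Base`, `Layer` and `build` are illustrative names, and the `trace` list records every object that a single call passes through.

```python
from typing import Protocol


class Op(Protocol):
    def operation(self, trace: list[str]) -> int: ...


class Base:
    """Concrete component: constant-time operation."""

    def operation(self, trace: list[str]) -> int:
        trace.append("base")
        return 0


class Layer:
    """Decorator: constant work of its own, then one forwarded call."""

    def __init__(self, inner: Op, name: str) -> None:
        self._inner = inner
        self._name = name

    def operation(self, trace: list[str]) -> int:
        trace.append(self._name)
        return self._inner.operation(trace) + 1


def build(n: int) -> Op:
    component: Op = Base()
    for i in range(n):
        component = Layer(component, f"d{i + 1}")  # O(1) per wrap, O(n) objects total
    return component


for n in (1, 10, 100):
    trace: list[str] = []
    depth = build(n).operation(trace)
    print(n, depth, len(trace))  # len(trace) == n + 1: n layers plus the base
```

Each layer also adds one Python stack frame per call. Extremely deep chains will therefore eventually hit the interpreter's recursion limit (`sys.getrecursionlimit()`, 1000 by default in CPython).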
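9.1.3 Historical Background and Academic Lineage
The Decorator pattern grew out of design practice in user-interface toolkits. Its key milestones are:
| Year | Milestone | Contributors |
|---|---|---|
| 1988 | View decoration in the ET++ application framework | Weinand, Gamma |
| 1994 | Formally catalogued in the GoF book "Design Patterns" | Gang of Four |
| 1996 | Stream decorators in Java's java.io (FilterInputStream and its subclasses) | Sun Microsystems |
| 2004 | Python 2.4 introduces the @ decorator syntax (PEP 318) | Python community |
| 2015 | PEP 484 type hints (Python 3.5) lay the groundwork for typed decorators | Python community |
| 2020 | Research on functional decorator patterns | Martin et al. |
Academic significance: the Decorator pattern embodies the **Open/Closed Principle (OCP)**. It extends behavior through composition rather than inheritance, which makes it a classic application of the "favor composition over inheritance" principle. In Python, decorators have also become a language-level feature that has deeply shaped the language's programming style. Python's @ syntax is, however, more general than the GoF pattern: it applies any callable to a function or class at definition time, usually, but not necessarily, to wrap it.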
9.2 Pattern Structure and Participants
9.2.1 UML Class Diagram
                ┌─────────────────────────────┐
                │        <<interface>>        │
                │          Component          │
                ├─────────────────────────────┤
                │ + operation(): Result       │
                └──────────────┬──────────────┘
                               │
              ┌────────────────┴─────────────┐
              │                              │
  ┌───────────┴───────────┐     ┌────────────┴────────────┐
  │   ConcreteComponent   │     │        Decorator        │
  ├───────────────────────┤     ├─────────────────────────┤
  │ - state: Any          │     │ - component: Component  │
  ├───────────────────────┤     ├─────────────────────────┤
  │ + operation(): Result │     │ + operation(): Result   │
  └───────────────────────┘     └────────────┬────────────┘
                                             │
                                  ┌──────────┴──────────────┐
                                  │                         │
                      ┌───────────┴───────────┐ ┌───────────┴──────────┐
                      │  ConcreteDecoratorA   │ │  ConcreteDecoratorB  │
                      ├───────────────────────┤ ├──────────────────────┤
                      │ - added_state: Any    │ │ + added_behavior()   │
                      ├───────────────────────┤ ├──────────────────────┤
                      │ + operation(): Result │ │ + operation(): Result│
                      └───────────────────────┘ └──────────────────────┘
9.2.2 Participant Responsibilities
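| Participant | Responsibility | Key trait |
|---|---|---|
| Component | Defines the object interface and declares the operation method | Abstract interface |
| ConcreteComponent | Defines the concrete object and implements the base behavior | The object being decorated |
| Decorator | Holds a reference to a Component (aggregation) and implements the same interface | Wrapper base class |
| ConcreteDecorator | Adds a specific feature by extending the behavior of operation | Feature extender |
| Client | Works with objects only through the Component interface | Unaware of decoration |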
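9.2.3 Decoration Chain Structure
Each decorator forwards the call inward. The return value travels back outward through the same layers.
Client ──────> Component.operation()
                      │
                      ▼
               ┌──────────────┐
               │ Decorator B  │
               └──────┬───────┘
                      │
                      ▼
               ┌──────────────┐
               │ Decorator A  │
               └──────┬───────┘
                      │
                      ▼
               ┌──────────────┐
               │ ConcreteComp │
               └──────────────┘
9.3 Python Implementation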
9.3.1 Standard Implementation (ABC Abstract Base Class)
from abc import ABC, abstractmethod

class Component(ABC):
    """Abstract component."""
    @abstractmethod
    def operation(self) -> str:
        """Core operation."""

class ConcreteComponent(Component):
    """Concrete component: provides the base behavior."""
    def __init__(self, name: str = "Base"):
        self._name = name
    def operation(self) -> str:
        return f"ConcreteComponent({self._name})"

class Decorator(Component):
    """Decorator base class: holds a reference to the wrapped component."""
    def __init__(self, component: Component):
        self._component = component
    def operation(self) -> str:
        # Default behavior: delegate unchanged to the wrapped component
        return self._component.operation()

class ConcreteDecoratorA(Decorator):
    """Concrete decorator A: adds state."""
    def __init__(self, component: Component):
        super().__init__(component)
        self._added_state = "StateA"
    def operation(self) -> str:
        return f"DecoratorA[{self._added_state}]({self._component.operation()})"

class ConcreteDecoratorB(Decorator):
    """Concrete decorator B: adds behavior."""
    def operation(self) -> str:
        return f"DecoratorB({self._component.operation()})"
    def added_behavior(self) -> str:
        """Extra method specific to this decorator (not part of Component)."""
        return "Added behavior from DecoratorB"

def client_code(component: Component) -> None:
    """Client code: treats every component uniformly through the interface."""
    print(f"RESULT: {component.operation()}")

if __name__ == "__main__":
    simple = ConcreteComponent("Simple")
    client_code(simple)  # RESULT: ConcreteComponent(Simple)
    decorator_a = ConcreteDecoratorA(simple)
    decorator_b = ConcreteDecoratorB(decorator_a)
    client_code(decorator_b)  # RESULT: DecoratorB(DecoratorA[StateA](ConcreteComponent(Simple)))
    print(f"Extra: {decorator_b.added_behavior()}")
9.3.2 Protocol Implementation (Structural Typing)
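The listing that follows relies on structural typing. This small sketch illustrates the rules involved; the class names are illustrative only. A class satisfies a `Protocol` simply by having the right methods. `isinstance()` is allowed only for protocols marked `@runtime_checkable`, and even then it checks only that the attributes exist, not their signatures.

```python
from typing import Protocol, runtime_checkable


class Plain(Protocol):
    def operation(self) -> str: ...


@runtime_checkable
class Checked(Protocol):
    def operation(self) -> str: ...


class Duck:
    """Inherits from neither protocol, but has a matching operation()."""
    def operation(self) -> str:
        return "quack"


class WrongSignature:
    # isinstance() only checks that 'operation' exists, not its signature
    def operation(self, extra: int) -> int:
        return extra


def isinstance_plain_raises() -> bool:
    """Non-runtime_checkable protocols reject isinstance() with TypeError."""
    try:
        isinstance(Duck(), Plain)
    except TypeError:
        return True
    return False


print(isinstance(Duck(), Checked))            # True
print(isinstance(WrongSignature(), Checked))  # True (signature not checked)
print(isinstance(object(), Checked))          # False
print(isinstance_plain_raises())              # True
```

Static type checkers such as mypy still verify signatures. The runtime check is only a coarse filter.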
from dataclasses import dataclass
from typing import Protocol, runtime_checkable

@runtime_checkable  # required for the isinstance() checks at the end of the demo
class ComponentProtocol(Protocol):
    """Component protocol: a structural type."""
    def operation(self) -> str: ...

@runtime_checkable
class DecoratableProtocol(Protocol):
    """Decoratable protocol: an operation plus metadata."""
    def operation(self) -> str: ...
    def get_metadata(self) -> dict: ...

@dataclass
class BaseComponent:
    """Base component implemented as a dataclass (no inheritance needed)."""
    name: str
    value: int = 0
    def operation(self) -> str:
        return f"Base({self.name}: {self.value})"
    def get_metadata(self) -> dict:
        return {"type": "base", "name": self.name}

class LoggingDecorator:
    """Logging decorator: satisfies DecoratableProtocol structurally."""
    def __init__(self, component: DecoratableProtocol):
        self._component = component
    def operation(self) -> str:
        print(f"[LOG] Calling operation on {self._component.get_metadata()}")
        result = self._component.operation()
        print(f"[LOG] Result: {result}")
        return result
    def get_metadata(self) -> dict:
        metadata = self._component.get_metadata()
        metadata["decorators"] = metadata.get("decorators", []) + ["logging"]
        return metadata

class CachingDecorator:
    """Caching decorator: computes the wrapped result once, then reuses it."""
    def __init__(self, component: DecoratableProtocol):
        self._component = component
        self._cache: dict = {}
    def operation(self) -> str:
        key = "operation_result"  # operation() takes no arguments, so one key suffices
        if key not in self._cache:
            self._cache[key] = self._component.operation()
        # The prefix marks results served through the cache layer
        return f"[CACHED] {self._cache[key]}"
    def get_metadata(self) -> dict:
        metadata = self._component.get_metadata()
        metadata["decorators"] = metadata.get("decorators", []) + ["caching"]
        return metadata

def process_component(comp: ComponentProtocol) -> None:
    """Accepts anything that structurally matches ComponentProtocol."""
    print(f"Processing: {comp.operation()}")

if __name__ == "__main__":
    base = BaseComponent("Test", 42)
    logged = LoggingDecorator(base)
    cached = CachingDecorator(logged)
    process_component(cached)
    print(f"Metadata: {cached.get_metadata()}")
    print("\nType checking:")
    print(f"base is ComponentProtocol: {isinstance(base, ComponentProtocol)}")
    print(f"cached is ComponentProtocol: {isinstance(cached, ComponentProtocol)}")
9.3.3 Generic Implementation (Type Safety)
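In the generic listing below, some decorators change the input before delegating (`UpperCaseDecorator`, `TrimDecorator`, `DoubleDecorator`). `SquareDecorator`, by contrast, changes the output after delegating. The two styles run in opposite orders, as this sketch shows (`Step`, `Pre` and `Post` are hypothetical names). The outermost pre-processing decorator sees the input first, whereas the outermost post-processing decorator acts on the result last.

```python
from typing import Callable, Generic, TypeVar

T = TypeVar("T")
R = TypeVar("R")


class Step(Generic[T, R]):
    """Any component mapping T -> R."""
    def __init__(self, fn: Callable[[T], R]) -> None:
        self._fn = fn

    def __call__(self, x: T) -> R:
        return self._fn(x)


class Pre(Step[T, R]):
    """Decorator that transforms the INPUT, then delegates."""
    def __init__(self, inner: Step[T, R], f: Callable[[T], T]) -> None:
        super().__init__(lambda x: inner(f(x)))


class Post(Step[T, R]):
    """Decorator that delegates, then transforms the OUTPUT."""
    def __init__(self, inner: Step[T, R], g: Callable[[R], R]) -> None:
        super().__init__(lambda x: g(inner(x)))


base: Step[int, int] = Step(lambda x: x * 10)
# Outer Pre doubles first, inner Pre adds 1, then base: (2x + 1) * 10
pre_chain = Pre(Pre(base, lambda x: x + 1), lambda x: x * 2)
# Base first, inner Post adds 1, outer Post doubles last: (10x + 1) * 2
post_chain = Post(Post(base, lambda y: y + 1), lambda y: y * 2)

print(pre_chain(5))   # 110
print(post_chain(5))  # 102
```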
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar

T = TypeVar('T')  # input type
R = TypeVar('R')  # result type

class GenericComponent(ABC, Generic[T, R]):
    """Generic component interface: maps an input of type T to a result of type R."""
    @abstractmethod
    def operation(self, input_data: T) -> R:
        pass

class GenericDecorator(GenericComponent[T, R]):
    """Generic decorator base class: wraps a component with the same T and R."""
    def __init__(self, component: GenericComponent[T, R]):
        self._component = component
    def operation(self, input_data: T) -> R:
        return self._component.operation(input_data)

@dataclass
class StringProcessor(GenericComponent[str, str]):
    """String processor: surrounds the input with a prefix and a suffix."""
    prefix: str = ""
    suffix: str = ""
    def operation(self, input_data: str) -> str:
        return f"{self.prefix}{input_data}{self.suffix}"

class UpperCaseDecorator(GenericDecorator[str, str]):
    """Upper-case decorator: converts the input to upper case, then delegates."""
    def operation(self, input_data: str) -> str:
        return self._component.operation(input_data.upper())

class ReverseDecorator(GenericDecorator[str, str]):
    """Reverse decorator: reverses the input, then delegates."""
    def operation(self, input_data: str) -> str:
        return self._component.operation(input_data[::-1])

class TrimDecorator(GenericDecorator[str, str]):
    """Trim decorator: strips surrounding whitespace, then delegates."""
    def operation(self, input_data: str) -> str:
        return self._component.operation(input_data.strip())

@dataclass
class NumberProcessor(GenericComponent[int, int]):
    """Number processor: multiplies the input by a fixed factor."""
    multiplier: int = 1
    def operation(self, input_data: int) -> int:
        return input_data * self.multiplier

class DoubleDecorator(GenericDecorator[int, int]):
    """Doubling decorator: doubles the input, then delegates."""
    def operation(self, input_data: int) -> int:
        return self._component.operation(input_data * 2)

class SquareDecorator(GenericDecorator[int, int]):
    """Squaring decorator: delegates, then squares the result."""
    def operation(self, input_data: int) -> int:
        result = self._component.operation(input_data)
        return result ** 2

if __name__ == "__main__":
    print("=== String processing pipeline ===")
    processor: GenericComponent[str, str] = StringProcessor("[", "]")
    processor = TrimDecorator(processor)
    processor = UpperCaseDecorator(processor)
    result = processor.operation("  hello world  ")
    print(f"Result: {result}")  # Result: [HELLO WORLD]
    print("\n=== Number processing pipeline ===")
    num_processor: GenericComponent[int, int] = NumberProcessor(10)
    num_processor = DoubleDecorator(num_processor)
    num_processor = SquareDecorator(num_processor)
    num_result = num_processor.operation(5)
    print(f"Result: {num_result}")  # Result: 10000, i.e. ((5 * 2) * 10) ** 2
9.3.4 Python Function Decorators
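Before the listing, it helps to recall what the `@` syntax does. Writing `@d` above a `def` is shorthand for rebinding the name to `d(func)`. Stacked decorators are applied bottom-up, so the top one ends up outermost and runs first at call time. The sketch below (the `tag` factory is hypothetical) checks this equivalence. It also shows that `functools.wraps` copies the name and docstring and records the wrapped function, which `inspect.unwrap` can follow back to the original.

```python
import inspect
from functools import wraps
from typing import Callable

calls: list[str] = []


def tag(label: str) -> Callable[[Callable[[], str]], Callable[[], str]]:
    """Decorator factory: tag(label) returns the decorator itself."""
    def decorator(func: Callable[[], str]) -> Callable[[], str]:
        @wraps(func)  # copies __name__, __doc__, ... and sets __wrapped__
        def wrapper() -> str:
            calls.append(label)  # record the order in which layers run
            return f"{label}({func()})"
        return wrapper
    return decorator


@tag("a")
@tag("b")
def greet() -> str:
    """Say hello."""
    return "hello"


def greet_plain() -> str:
    return "hello"


# Explicit equivalent of the stacked @ form: the bottom decorator wraps first
greet_manual = tag("a")(tag("b")(greet_plain))

print(greet(), greet_manual())        # a(b(hello)) a(b(hello))
print(calls)                          # ['a', 'b', 'a', 'b']
print(greet.__name__, greet.__doc__)  # greet Say hello.
print(inspect.unwrap(greet)())        # hello
```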
import inspect
import random
import time
import warnings
from functools import wraps
from typing import Callable, ParamSpec, TypeVar  # ParamSpec requires Python 3.10+

P = ParamSpec('P')
T = TypeVar('T')

def timer(func: Callable[P, T]) -> Callable[P, T]:
    """Timing decorator."""
    @wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        start = time.perf_counter()
        result = func(*args, **kwargs)
        end = time.perf_counter()
        print(f"[TIMER] {func.__name__} took {end - start:.6f}s")
        return result
    return wrapper

def log_calls(level: str = "INFO") -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Logging decorator factory: log_calls(level) returns the actual decorator."""
    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            print(f"[{level}] calling {func.__name__}")
            print(f"  arguments: args={args}, kwargs={kwargs}")
            result = func(*args, **kwargs)
            print(f"[{level}] {func.__name__} returned: {result}")
            return result
        return wrapper
    return decorator

def retry(
    max_attempts: int = 3,
    delay: float = 1.0,
    exceptions: tuple[type[BaseException], ...] = (Exception,)
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Retry decorator: calls the function again when one of `exceptions` is raised."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    print(f"[RETRY] {func.__name__} attempt {attempt} failed: {e}")
                    if attempt == max_attempts:
                        raise  # re-raise the last failure with its traceback
                    time.sleep(delay)
            raise AssertionError("unreachable")
        return wrapper
    return decorator

def validate_types(**type_hints: type) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Type-validation decorator: checks the named parameters at runtime."""
    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        sig = inspect.signature(func)  # computed once, at decoration time
        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            for param_name, expected_type in type_hints.items():
                if param_name in bound.arguments:
                    value = bound.arguments[param_name]
                    if not isinstance(value, expected_type):
                        raise TypeError(
                            f"Parameter '{param_name}' should be {expected_type.__name__}, "
                            f"got {type(value).__name__}"
                        )
            return func(*args, **kwargs)
        return wrapper
    return decorator

def memoize(func: Callable[P, T]) -> Callable[P, T]:
    """Memoization decorator (all arguments must be hashable).

    The standard library's functools.lru_cache and functools.cache implement the same idea.
    """
    cache: dict = {}
    @wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        key = (args, frozenset(kwargs.items()))
        if key not in cache:
            cache[key] = func(*args, **kwargs)
        return cache[key]
    wrapper.cache = cache  # type: ignore[attr-defined]
    wrapper.cache_clear = cache.clear  # type: ignore[attr-defined]
    return wrapper

def deprecated(
    reason: str = "",
    version: str = ""
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Deprecation-warning decorator."""
    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            message = f"{func.__name__} is deprecated"
            if version:
                message += f" (since version {version})"
            if reason:
                message += f": {reason}"
            # stacklevel=2 attributes the warning to the caller, not to this wrapper
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)
        return wrapper
    return decorator

if __name__ == "__main__":
    # Stacked decorators apply bottom-up: timer(log_calls(level="DEBUG")(calculate_sum))
    @timer
    @log_calls(level="DEBUG")
    def calculate_sum(n: int) -> int:
        return sum(range(n))
    print(calculate_sum(100000))

    @retry(max_attempts=3, delay=0.1, exceptions=(ValueError,))
    def unreliable_function(x: int) -> int:
        if random.random() < 0.7:
            raise ValueError("random failure")
        return x * 2
    try:
        print(unreliable_function(21))
    except ValueError as e:
        print(f"Gave up after 3 attempts: {e}")

    @validate_types(name=str, age=int)
    def create_user(name: str, age: int) -> dict:
        return {"name": name, "age": age}
    print(create_user("Alice", 30))
    try:
        create_user("Bob", "30")
    except TypeError as e:
        print(f"Type error: {e}")

    @memoize
    def fibonacci(n: int) -> int:
        if n <= 1:
            return n
        return fibonacci(n - 1) + fibonacci(n - 2)
    print(f"\nFibonacci(30) = {fibonacci(30)}")  # 832040
    print(f"Cache size: {len(fibonacci.cache)}")  # 31 entries, for n = 0..30

    @deprecated(reason="use new_function() instead", version="2.0")
    def old_function() -> str:
        return "old result"
    print(old_function())
9.3.5 Class Decorators
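A class decorator is simply a callable applied to the class object once the `class` statement has run: `@d class C` rebinds `C` to `d(C)`. It matters whether the decorator returns the class itself or something else. This sketch contrasts the two cases using hypothetical `register` and `as_factory` decorators. The `singleton` decorator in the listing below falls into the second case, because it returns a wrapper function rather than a class.

```python
from typing import Any, Callable, TypeVar

C = TypeVar("C", bound=type)

REGISTRY: dict[str, type] = {}


def register(cls: C) -> C:
    """Class decorator that records the class and returns it unchanged."""
    REGISTRY[cls.__name__] = cls
    return cls


def as_factory(cls: type) -> Callable[..., Any]:
    """Class decorator that replaces the class with a factory function."""
    def factory(*args: Any, **kwargs: Any) -> Any:
        return cls(*args, **kwargs)
    return factory


@register    # same as: Plugin = register(Plugin)
class Plugin:
    pass


@as_factory  # same as: Wrapped = as_factory(Wrapped)
class Wrapped:
    pass


print(list(REGISTRY))                # ['Plugin']
print(isinstance(Plugin(), Plugin))  # True: the name still refers to the class
print(isinstance(Wrapped, type))     # False: the name now refers to a function
try:
    isinstance(Wrapped(), Wrapped)
except TypeError:
    print("isinstance() needs a class, but Wrapped is now a function")
```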
import inspect
from functools import wraps
from typing import Any, Callable, Dict, Type, TypeVar

T = TypeVar('T')

def singleton(cls: Type[T]) -> Callable[..., T]:
    """Singleton class decorator.

    Note: it returns a factory function, not a class, so after decoration the
    name can no longer be used with isinstance() or as a base class.
    """
    instances: Dict[type, Any] = {}
    @wraps(cls)
    def wrapper(*args: Any, **kwargs: Any) -> T:
        # Only the first call's arguments are used; later calls return the same instance
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]
    wrapper._is_singleton = True  # type: ignore[attr-defined]
    wrapper._instances = instances  # type: ignore[attr-defined]
    return wrapper

def add_logging(cls: Type[T]) -> Type[T]:
    """Class decorator that logs calls to the class's public methods."""
    # Iterate over a snapshot, because the loop replaces class attributes
    for name, method in list(vars(cls).items()):
        if inspect.isfunction(method) and not name.startswith('_'):
            # _method is bound as a default argument so each wrapper keeps its own method
            @wraps(method)
            def logged_method(self, *args, _method=method, **kwargs):
                print(f"[LOG] {cls.__name__}.{_method.__name__} called")
                result = _method(self, *args, **kwargs)
                print(f"[LOG] {cls.__name__}.{_method.__name__} returned: {result}")
                return result
            setattr(cls, name, logged_method)
    return cls

def add_repr(cls: Type[T]) -> Type[T]:
    """Adds a __repr__ listing public instance attributes, unless the class defines one."""
    if '__repr__' not in cls.__dict__:
        def __repr__(self) -> str:
            fields = [f"{key}={value!r}" for key, value in vars(self).items()
                      if not key.startswith('_')]
            return f"{cls.__name__}({', '.join(fields)})"
        cls.__repr__ = __repr__
    return cls

def add_comparison(cls: Type[T]) -> Type[T]:
    """Adds attribute-based __eq__ and __hash__ methods."""
    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, cls):
            return NotImplemented
        return vars(self) == vars(other)
    def __hash__(self) -> int:
        # Attribute values must be hashable; hashing mutable objects is risky if they
        # change while stored in a set or used as a dict key
        return hash(tuple(sorted(vars(self).items())))
    cls.__eq__ = __eq__
    cls.__hash__ = __hash__
    return cls

def validate_attributes(**validators: Callable[[Any], bool]) -> Callable[[Type[T]], Type[T]]:
    """Attribute-validation decorator: runs the validators after __init__."""
    def decorator(cls: Type[T]) -> Type[T]:
        original_init = cls.__init__
        @wraps(original_init)
        def new_init(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            for attr_name, validator in validators.items():
                if hasattr(self, attr_name):
                    value = getattr(self, attr_name)
                    if not validator(value):
                        raise ValueError(
                            f"Attribute '{attr_name}' with value {value!r} failed validation"
                        )
        cls.__init__ = new_init
        return cls
    return decorator

def frozen(cls: Type[T]) -> Type[T]:
    """Immutable-class decorator: forbids attribute assignment once __init__ finishes."""
    original_init = cls.__init__
    original_setattr = cls.__setattr__ if '__setattr__' in cls.__dict__ else object.__setattr__
    @wraps(original_init)
    def new_init(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        # Bypass our own __setattr__ to set the flag
        object.__setattr__(self, '_frozen', True)
    def new_setattr(self, name: str, value: Any) -> None:
        if getattr(self, '_frozen', False):
            raise AttributeError(f"Cannot modify attribute '{name}' of immutable {cls.__name__}")
        original_setattr(self, name, value)
    cls.__init__ = new_init
    cls.__setattr__ = new_setattr
    return cls

if __name__ == "__main__":
    # Applied bottom-up: singleton(add_repr(add_comparison(Database)))
    @singleton
    @add_repr
    @add_comparison
    class Database:
        def __init__(self, connection_string: str):
            self.connection_string = connection_string
            self.connected = False
        def connect(self) -> None:
            self.connected = True
            print(f"Connected to {self.connection_string}")
        def disconnect(self) -> None:
            self.connected = False
            print("Disconnected")
    db1 = Database("mysql://localhost")
    db2 = Database("mysql://localhost")
    print(f"Singleton check: {db1 is db2}")  # True
    print(db1)  # Database(connection_string='mysql://localhost', connected=False)

    @validate_attributes(
        name=lambda x: len(x) >= 2,
        age=lambda x: 0 <= x <= 150
    )
    @add_repr
    class Person:
        def __init__(self, name: str, age: int):
            self.name = name
            self.age = age
    try:
        p = Person("A", 200)
    except ValueError as e:
        print(f"Validation error: {e}")

    @frozen
    @add_repr
    class Point:
        def __init__(self, x: float, y: float):
            self.x = x
            self.y = y
    pt = Point(1.0, 2.0)
    print(pt)  # Point(x=1.0, y=2.0)
    try:
        pt.x = 3.0
    except AttributeError as e:
        print(f"Immutability error: {e}")
9.4 Practical Application Examples
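Python's standard library layers its I/O stack in the same way, much like the Java I/O streams mentioned in 9.1.3. `gzip.GzipFile` wraps any binary file object and decompresses on read. `io.TextIOWrapper` wraps a binary buffer and decodes it to text. This short sketch uses an in-memory buffer as the innermost component:

```python
import gzip
import io

# Concrete component: an in-memory binary stream holding gzip-compressed bytes
raw = io.BytesIO(gzip.compress("line one\nline two\n".encode("utf-8")))

# Decorator 1: GzipFile wraps a binary file object and decompresses on read
unzipped = gzip.GzipFile(fileobj=raw, mode="rb")

# Decorator 2: TextIOWrapper wraps a binary buffer and decodes bytes to str
text = io.TextIOWrapper(unzipped, encoding="utf-8")

# The client just iterates over lines, unaware of the layers underneath
lines = [line.rstrip("\n") for line in text]
print(lines)  # ['line one', 'line two']
```

In text mode, `gzip.open` builds this same `TextIOWrapper`-over-`GzipFile` stack for you.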
9.4.1 Coffee Shop Ordering System
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import List

class Size(Enum):
    TALL = "Tall"
    GRANDE = "Grande"
    VENTI = "Venti"

class Beverage(ABC):
    """Abstract beverage (the Component)."""
    @abstractmethod
    def get_description(self) -> str:
        pass
    @abstractmethod
    def get_cost(self) -> float:
        pass
    @abstractmethod
    def get_size(self) -> Size:
        pass
    def get_calories(self) -> int:
        return 0

@dataclass
class Espresso(Beverage):
    """Espresso."""
    _size: Size = Size.TALL
    def get_description(self) -> str:
        return "Espresso"
    def get_cost(self) -> float:
        prices = {Size.TALL: 15.0, Size.GRANDE: 18.0, Size.VENTI: 21.0}
        return prices[self._size]
    def get_size(self) -> Size:
        return self._size
    def get_calories(self) -> int:
        return 5

@dataclass
class HouseBlend(Beverage):
    """House blend coffee."""
    _size: Size = Size.TALL
    def get_description(self) -> str:
        return "House Blend"
    def get_cost(self) -> float:
        prices = {Size.TALL: 12.0, Size.GRANDE: 15.0, Size.VENTI: 18.0}
        return prices[self._size]
    def get_size(self) -> Size:
        return self._size
    def get_calories(self) -> int:
        return 10

@dataclass
class DarkRoast(Beverage):
    """Dark roast coffee."""
    _size: Size = Size.TALL
    def get_description(self) -> str:
        return "Dark Roast"
    def get_cost(self) -> float:
        prices = {Size.TALL: 14.0, Size.GRANDE: 17.0, Size.VENTI: 20.0}
        return prices[self._size]
    def get_size(self) -> Size:
        return self._size
    def get_calories(self) -> int:
        return 8

class CondimentDecorator(Beverage):
    """Condiment decorator base class: wraps a Beverage and is itself a Beverage."""
    def __init__(self, beverage: Beverage):
        self._beverage = beverage
    def get_size(self) -> Size:
        # Condiments do not change the size, so forward it
        return self._beverage.get_size()

class Milk(CondimentDecorator):
    """Milk."""
    def get_description(self) -> str:
        return f"{self._beverage.get_description()}, Milk"
    def get_cost(self) -> float:
        size_prices = {Size.TALL: 3.0, Size.GRANDE: 4.0, Size.VENTI: 5.0}
        return self._beverage.get_cost() + size_prices[self._beverage.get_size()]
    def get_calories(self) -> int:
        return self._beverage.get_calories() + 50

class Mocha(CondimentDecorator):
    """Mocha."""
    def get_description(self) -> str:
        return f"{self._beverage.get_description()}, Mocha"
    def get_cost(self) -> float:
        size_prices = {Size.TALL: 4.0, Size.GRANDE: 5.0, Size.VENTI: 6.0}
        return self._beverage.get_cost() + size_prices[self._beverage.get_size()]
    def get_calories(self) -> int:
        return self._beverage.get_calories() + 80

class Whip(CondimentDecorator):
    """Whipped cream."""
    def get_description(self) -> str:
        return f"{self._beverage.get_description()}, Whip"
    def get_cost(self) -> float:
        size_prices = {Size.TALL: 2.5, Size.GRANDE: 3.0, Size.VENTI: 3.5}
        return self._beverage.get_cost() + size_prices[self._beverage.get_size()]
    def get_calories(self) -> int:
        return self._beverage.get_calories() + 100

class Soy(CondimentDecorator):
    """Soy milk."""
    def get_description(self) -> str:
        return f"{self._beverage.get_description()}, Soy"
    def get_cost(self) -> float:
        size_prices = {Size.TALL: 3.5, Size.GRANDE: 4.5, Size.VENTI: 5.5}
        return self._beverage.get_cost() + size_prices[self._beverage.get_size()]
    def get_calories(self) -> int:
        return self._beverage.get_calories() + 40

class Caramel(CondimentDecorator):
    """Caramel."""
    def get_description(self) -> str:
        return f"{self._beverage.get_description()}, Caramel"
    def get_cost(self) -> float:
        size_prices = {Size.TALL: 3.0, Size.GRANDE: 4.0, Size.VENTI: 5.0}
        return self._beverage.get_cost() + size_prices[self._beverage.get_size()]
    def get_calories(self) -> int:
        return self._beverage.get_calories() + 60

class Order:
    """An order holding decorated beverages."""
    def __init__(self):
        self._items: List[Beverage] = []
        self._order_id = id(self)
    def add_beverage(self, beverage: Beverage) -> None:
        self._items.append(beverage)
    def get_total(self) -> float:
        return sum(item.get_cost() for item in self._items)
    def get_total_calories(self) -> int:
        return sum(item.get_calories() for item in self._items)
    def print_receipt(self) -> str:
        """Builds the receipt text (the caller prints it)."""
        lines = ["=" * 40, "Coffee Shop Order", "=" * 40]
        for i, item in enumerate(self._items, 1):
            lines.append(f"{i}. {item.get_description()}")
            lines.append(f"   Size: {item.get_size().value}")
            lines.append(f"   Price: ¥{item.get_cost():.2f}")
            lines.append(f"   Calories: {item.get_calories()} kcal")
        lines.extend([
            "-" * 40,
            f"Total: ¥{self.get_total():.2f}",
            f"Total calories: {self.get_total_calories()} kcal",
            "=" * 40
        ])
        return "\n".join(lines)

if __name__ == "__main__":
    order = Order()
    beverage1 = Espresso(Size.GRANDE)
    beverage1 = Mocha(beverage1)
    beverage1 = Whip(beverage1)      # 18.0 + 5.0 + 3.0 = ¥26.00
    order.add_beverage(beverage1)
    beverage2 = DarkRoast(Size.VENTI)
    beverage2 = Mocha(beverage2)
    beverage2 = Mocha(beverage2)     # the same decorator can be applied twice
    beverage2 = Whip(beverage2)      # 20.0 + 6.0 + 6.0 + 3.5 = ¥35.50
    order.add_beverage(beverage2)
    beverage3 = HouseBlend(Size.TALL)
    beverage3 = Soy(beverage3)
    beverage3 = Milk(beverage3)      # 12.0 + 3.5 + 3.0 = ¥18.50
    order.add_beverage(beverage3)
    print(order.print_receipt())     # Total: ¥80.00
9.4.2 Data Processing Pipeline
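The `Pipeline` in the listing below runs separate stages one after another, and each stage is a decorator around its own `IdentityProcessor`. The same stages could instead be nested into a single decorator chain. This sketch checks that both arrangements produce the same result. The function names are hypothetical, and plain functions stand in for the processor classes.

```python
from functools import reduce
from typing import Callable

Processor = Callable[[str], str]


def stage(transform: Processor) -> Callable[[Processor], Processor]:
    """Turn a transform into a decorator: apply it, then call the inner processor."""
    def decorate(inner: Processor) -> Processor:
        return lambda data: inner(transform(data))
    return decorate


def identity(data: str) -> str:
    return data


stages: list[Processor] = [str.strip, str.lower, lambda s: s.replace(" ", "_")]


def run_sequential(data: str) -> str:
    """Sequential style: each stage runs on the previous stage's output."""
    for transform in stages:
        data = transform(data)
    return data


# Nested style: wrap in reverse order so the first stage ends up outermost
nested: Processor = reduce(lambda inner, t: stage(t)(inner), reversed(stages), identity)

print(run_sequential("  Hello World "))  # hello_world
print(nested("  Hello World "))          # hello_world
```

Sequential stages are easier to inspect and reorder at runtime. Nesting produces one object that can be passed wherever a single processor is expected.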
import hashlib
import time
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Generic, List, TypeVar

T = TypeVar('T')

class DataProcessor(ABC, Generic[T]):
    """Abstract data processor (the Component)."""
    @abstractmethod
    def process(self, data: T) -> T:
        pass
    def get_name(self) -> str:
        return self.__class__.__name__

class IdentityProcessor(DataProcessor[T]):
    """Identity processor: returns its input unchanged (the innermost component)."""
    def process(self, data: T) -> T:
        return data

class ProcessorDecorator(DataProcessor[T]):
    """Base class for processor decorators."""
    def __init__(self, processor: DataProcessor[T]):
        self._processor = processor
    def process(self, data: T) -> T:
        return self._processor.process(data)

class ValidationProcessor(ProcessorDecorator[T]):
    """Validation processor: rejects data that fails any validator."""
    def __init__(
        self,
        processor: DataProcessor[T],
        validators: List[Callable[[T], bool]],
        error_message: str = "Validation failed"
    ):
        super().__init__(processor)
        self._validators = validators
        self._error_message = error_message
    def process(self, data: T) -> T:
        for validator in self._validators:
            if not validator(data):
                raise ValueError(self._error_message)
        return self._processor.process(data)

class TransformProcessor(ProcessorDecorator[T]):
    """Transform processor: transforms the data, then delegates."""
    def __init__(
        self,
        processor: DataProcessor[T],
        transformer: Callable[[T], T],
        name: str = "Transform"
    ):
        super().__init__(processor)
        self._transformer = transformer
        self._name = name
    def process(self, data: T) -> T:
        transformed = self._transformer(data)
        return self._processor.process(transformed)
    def get_name(self) -> str:
        return self._name

class LoggingProcessor(ProcessorDecorator[T]):
    """Logging processor."""
    def __init__(
        self,
        processor: DataProcessor[T],
        prefix: str = "[LOG]",
        verbose: bool = True
    ):
        super().__init__(processor)
        self._prefix = prefix
        self._verbose = verbose
    def process(self, data: T) -> T:
        if self._verbose:
            print(f"{self._prefix} before: {data}")
        result = self._processor.process(data)
        if self._verbose:
            print(f"{self._prefix} after: {result}")
        return result

class CacheProcessor(ProcessorDecorator[T]):
    """Cache processor with first-in, first-out eviction."""
    def __init__(
        self,
        processor: DataProcessor[T],
        max_size: int = 100
    ):
        super().__init__(processor)
        self._cache: Dict[str, T] = {}
        self._max_size = max_size
        self._hits = 0
        self._misses = 0
    def _get_key(self, data: T) -> str:
        # Keyed on str(data): values with the same string form (e.g. 1 and "1") collide
        return hashlib.md5(str(data).encode()).hexdigest()
    def process(self, data: T) -> T:
        key = self._get_key(data)
        if key in self._cache:
            self._hits += 1
            print(f"[CACHE] hit: {key[:8]}...")
            return self._cache[key]
        self._misses += 1
        result = self._processor.process(data)
        if len(self._cache) >= self._max_size:
            # Dicts preserve insertion order, so the first key is the oldest entry
            oldest_key = next(iter(self._cache))
            del self._cache[oldest_key]
        self._cache[key] = result
        return result
    def get_stats(self) -> Dict[str, Any]:
        total = self._hits + self._misses
        hit_rate = self._hits / total if total > 0 else 0
        return {
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": f"{hit_rate:.2%}",
            "cache_size": len(self._cache)
        }

class RetryProcessor(ProcessorDecorator[T]):
    """Retry processor."""
    def __init__(
        self,
        processor: DataProcessor[T],
        max_attempts: int = 3,
        delay: float = 0.1
    ):
        super().__init__(processor)
        self._max_attempts = max_attempts
        self._delay = delay
    def process(self, data: T) -> T:
        for attempt in range(1, self._max_attempts + 1):
            try:
                return self._processor.process(data)
            except Exception as e:
                print(f"[RETRY] attempt {attempt}/{self._max_attempts} failed: {e}")
                if attempt == self._max_attempts:
                    raise  # re-raise the last failure
                time.sleep(self._delay)
        raise ValueError("max_attempts must be at least 1")

class BatchProcessor(DataProcessor[List[T]]):
    """Batch processor: applies an item-level processor to a list, batch by batch.

    Strictly speaking this is an adapter rather than a decorator: it wraps a
    DataProcessor[T] but exposes a DataProcessor[List[T]] interface.
    """
    def __init__(
        self,
        processor: DataProcessor[T],
        batch_size: int = 10
    ):
        self._processor = processor
        self._batch_size = batch_size
    def process(self, data: List[T]) -> List[T]:
        results: List[T] = []
        for i in range(0, len(data), self._batch_size):
            batch = data[i:i + self._batch_size]
            results.extend(self._processor.process(item) for item in batch)
        return results

class Pipeline:
    """Processing pipeline: runs its processors in sequence."""
    def __init__(self, name: str = "Pipeline"):
        self._name = name
        self._processors: List[DataProcessor] = []
    def add(self, processor: DataProcessor) -> 'Pipeline':
        self._processors.append(processor)
        return self  # allows chained add() calls
    def process(self, data: Any) -> Any:
        result = data
        for processor in self._processors:
            result = processor.process(result)
        return result
    def get_processor_names(self) -> List[str]:
        return [p.get_name() for p in self._processors]

if __name__ == "__main__":
    print("=== String processing pipeline ===")
    pipeline = Pipeline("StringProcessor")
    pipeline.add(ValidationProcessor(
        IdentityProcessor(),
        [lambda x: isinstance(x, str), lambda x: len(x) > 0],
        "Input must be a non-empty string"
    ))
    pipeline.add(TransformProcessor(
        IdentityProcessor(),
        lambda x: x.strip(),
        "Trim"
    ))
    pipeline.add(TransformProcessor(
        IdentityProcessor(),
        lambda x: x.lower(),
        "Lowercase"
    ))
    pipeline.add(LoggingProcessor(IdentityProcessor(), "[PROCESS]"))
    result = pipeline.process("  HELLO WORLD  ")
    print(f"Result: {result}")  # Result: hello world
    # ValidationProcessor -> Trim -> Lowercase -> LoggingProcessor
    print(f"Processor chain: {' -> '.join(pipeline.get_processor_names())}")
    print("\n=== Cached processing pipeline ===")
    cache_processor = CacheProcessor(
        TransformProcessor(IdentityProcessor(), lambda x: x ** 2, "Square")
    )
    cache_pipeline = Pipeline("CachedProcessor").add(cache_processor)
    for i in [1, 2, 1, 3, 2, 4, 1]:
        cache_pipeline.process(i)
    # 3 hits, 4 misses, hit rate 42.86%, 4 cached entries
    print(f"Cache stats: {cache_processor.get_stats()}")
9.4.3 Web Middleware System
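Python web servers expose a standard form of this layering. In WSGI (PEP 3333), a middleware is itself a WSGI application that wraps another one, which makes it a decorator. The minimal sketch below assumes a hand-written stand-in for the server's `start_response`. The header names `X-Inner` and `X-Outer` are made up for the demonstration.

```python
from typing import Any, Callable, Iterable, Optional

StartResponse = Callable[..., Any]
WSGIApp = Callable[[dict, StartResponse], Iterable[bytes]]


def hello_app(environ: dict, start_response: StartResponse) -> Iterable[bytes]:
    """Concrete component: a minimal WSGI application."""
    start_response("200 OK", [("Content-Type", "text/plain")])
    return [b"hello"]


class HeaderMiddleware:
    """Decorator: itself a WSGI app, wrapping another WSGI app."""

    def __init__(self, app: WSGIApp, name: str, value: str) -> None:
        self._app = app
        self._header = (name, value)

    def __call__(self, environ: dict, start_response: StartResponse) -> Iterable[bytes]:
        def add_header(status: str, headers: list, exc_info: Optional[Any] = None) -> Any:
            # Append our header, then forward to the caller's start_response
            return start_response(status, headers + [self._header], exc_info)
        return self._app(environ, add_header)


app = HeaderMiddleware(HeaderMiddleware(hello_app, "X-Inner", "1"), "X-Outer", "2")

captured: dict = {}


def fake_start_response(status: str, headers: list, exc_info: Optional[Any] = None) -> None:
    """Stand-in for the server: records what the application reported."""
    captured["status"], captured["headers"] = status, headers


body = b"".join(app({"REQUEST_METHOD": "GET", "PATH_INFO": "/"}, fake_start_response))
print(captured["status"], body)  # 200 OK b'hello'
print(captured["headers"])
```

Each layer sees the response on its way back out, so the outermost middleware adds its header last.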
import gzip
import json
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional

class HTTPStatus(Enum):
    # A small subset; the standard library's http.HTTPStatus provides the full set
    OK = 200
    CREATED = 201
    BAD_REQUEST = 400
    UNAUTHORIZED = 401
    FORBIDDEN = 403
    NOT_FOUND = 404
    TOO_MANY_REQUESTS = 429
    INTERNAL_ERROR = 500
    SERVICE_UNAVAILABLE = 503

@dataclass
class Request:
    path: str
    method: str
    headers: Dict[str, str] = field(default_factory=dict)
    body: Any = None
    params: Dict[str, str] = field(default_factory=dict)
    cookies: Dict[str, str] = field(default_factory=dict)
    ip: str = "127.0.0.1"
    user: Optional[Dict[str, Any]] = None

@dataclass
class Response:
    status_code: int = 200
    body: Any = None
    headers: Dict[str, str] = field(default_factory=dict)
    def json(self) -> str:
        return json.dumps(self.body, ensure_ascii=False)

class RequestHandler(ABC):
    """Abstract request handler (the Component)."""
    @abstractmethod
    def handle(self, request: Request) -> Response:
        pass

class BaseHandler(RequestHandler):
    """Base handler (the ConcreteComponent)."""
    def handle(self, request: Request) -> Response:
        return Response(
            status_code=HTTPStatus.OK.value,
            body={"message": f"Handled {request.method} {request.path}"}
        )

class Middleware(RequestHandler):
    """Middleware base class (the Decorator): forwards to the wrapped handler by default."""
    def __init__(self, handler: RequestHandler):
        self._handler = handler
    def handle(self, request: Request) -> Response:
        return self._handler.handle(request)

class AuthMiddleware(Middleware):
    """Authentication middleware."""
    def __init__(
        self,
        handler: RequestHandler,
        token_validator: Callable[[str], Optional[Dict]],
        exclude_paths: Optional[List[str]] = None
    ):
        super().__init__(handler)
        self._token_validator = token_validator
        self._exclude_paths = exclude_paths or []
    def handle(self, request: Request) -> Response:
        if request.path in self._exclude_paths:
            return self._handler.handle(request)
        token = request.headers.get("Authorization", "")
        if token.startswith("Bearer "):
            token = token[len("Bearer "):]
        user = self._token_validator(token)
        if not user:
            return Response(
                status_code=HTTPStatus.UNAUTHORIZED.value,
                body={"error": "Unauthorized"}
            )
        request.user = user
        return self._handler.handle(request)

class CORSMiddleware(Middleware):
    """Cross-origin resource sharing (CORS) middleware."""
    def __init__(
        self,
        handler: RequestHandler,
        allowed_origins: Optional[List[str]] = None,
        allowed_methods: Optional[List[str]] = None,
        allowed_headers: Optional[List[str]] = None,
        allow_credentials: bool = True
    ):
        super().__init__(handler)
        self._allowed_origins = allowed_origins or ["*"]
        self._allowed_methods = allowed_methods or ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
        self._allowed_headers = allowed_headers or ["Content-Type", "Authorization"]
        self._allow_credentials = allow_credentials
    def handle(self, request: Request) -> Response:
        if request.method == "OPTIONS":
            # Answer preflight requests directly
            response = Response(status_code=HTTPStatus.OK.value)
        else:
            response = self._handler.handle(request)
        origin = request.headers.get("Origin")
        if origin and ("*" in self._allowed_origins or origin in self._allowed_origins):
            # Echo the concrete origin: browsers reject a literal "*" when credentials are allowed
            response.headers["Access-Control-Allow-Origin"] = origin
            response.headers["Vary"] = "Origin"
            response.headers["Access-Control-Allow-Methods"] = ", ".join(self._allowed_methods)
            response.headers["Access-Control-Allow-Headers"] = ", ".join(self._allowed_headers)
            response.headers["Access-Control-Allow-Credentials"] = str(self._allow_credentials).lower()
        return response

class RateLimitMiddleware(Middleware):
    """Rate-limiting middleware (sliding window per client key)."""
    def __init__(
        self,
        handler: RequestHandler,
        max_requests: int = 100,
        window_seconds: int = 60,
        key_func: Optional[Callable[[Request], str]] = None
    ):
        super().__init__(handler)
        self._max_requests = max_requests
        self._window = window_seconds
        self._key_func = key_func or (lambda r: r.ip)
        self._requests: Dict[str, List[float]] = defaultdict(list)
    def handle(self, request: Request) -> Response:
        key = self._key_func(request)
        current_time = time.time()
        # Keep only the timestamps that are still inside the window
        self._requests[key] = [
            t for t in self._requests[key]
            if current_time - t < self._window
        ]
        if len(self._requests[key]) >= self._max_requests:
            return Response(
                status_code=HTTPStatus.TOO_MANY_REQUESTS.value,
                body={"error": "Too many requests, please try again later"}
            )
        self._requests[key].append(current_time)
        return self._handler.handle(request)

class LoggingMiddleware(Middleware):
    """Logging middleware."""
    def __init__(
        self,
        handler: RequestHandler,
        log_request: bool = True,
        log_response: bool = True
    ):
        super().__init__(handler)
        self._log_request = log_request
        self._log_response = log_response
    def handle(self, request: Request) -> Response:
        start_time = time.perf_counter()
        if self._log_request:
            print(f"[REQUEST] {request.method} {request.path}")
            print(f"  IP: {request.ip}")
            print(f"  Headers: {request.headers}")
        response = self._handler.handle(request)
        elapsed = (time.perf_counter() - start_time) * 1000
        if self._log_response:
            print(f"[RESPONSE] {response.status_code} ({elapsed:.2f}ms)")
        return response

class CompressionMiddleware(Middleware):
    """Compression middleware: gzips large bodies for clients that accept gzip."""
    def __init__(
        self,
        handler: RequestHandler,
        min_size: int = 1024,
        compression_level: int = 6
    ):
        super().__init__(handler)
        self._min_size = min_size
        self._compression_level = compression_level
    def handle(self, request: Request) -> Response:
        response = self._handler.handle(request)
        accept_encoding = request.headers.get("Accept-Encoding", "")
        if "gzip" not in accept_encoding:
            return response
        body_str = response.json() if isinstance(response.body, dict) else str(response.body)
        body_bytes = body_str.encode("utf-8")
        if len(body_bytes) < self._min_size:
            return response
        compressed = gzip.compress(body_bytes, compresslevel=self._compression_level)
        response.body = compressed
        response.headers["Content-Encoding"] = "gzip"
        response.headers["Content-Length"] = str(len(compressed))
        return response

class CacheMiddleware(Middleware):
    """Caching middleware with a time-to-live (TTL)."""
    def __init__(
        self,
        handler: RequestHandler,
        ttl: int = 300,
        cacheable_methods: Optional[List[str]] = None
    ):
        super().__init__(handler)
        self._ttl = ttl
        self._cacheable_methods = cacheable_methods or ["GET"]
        self._cache: Dict[str, tuple] = {}
    def _get_cache_key(self, request: Request) -> str:
        # Include the Authorization header so one caller's cached response is never
        # served to another (or to an unauthenticated) caller
        auth = request.headers.get("Authorization", "")
        return f"{request.method}:{request.path}:{sorted(request.params.items())}:{auth}"
    def handle(self, request: Request) -> Response:
        if request.method not in self._cacheable_methods:
            return self._handler.handle(request)
        key = self._get_cache_key(request)
        current_time = time.time()
        if key in self._cache:
            cached_response, timestamp = self._cache[key]
            if current_time - timestamp < self._ttl:
                cached_response.headers["X-Cache"] = "HIT"
                return cached_response
        response = self._handler.handle(request)
        if response.status_code == HTTPStatus.OK.value:
            self._cache[key] = (response, current_time)
        response.headers["X-Cache"] = "MISS"
        return response

def create_mock_token_validator():
    valid_tokens = {
        "admin_token": {"id": 1, "name": "Admin", "role": "admin"},
        "user_token": {"id": 2, "name": "User", "role": "user"},
    }
    return lambda token: valid_tokens.get(token)

if __name__ == "__main__":
    # Built inside-out: the last wrapper (CacheMiddleware) is the outermost layer
    handler: RequestHandler = BaseHandler()
    handler = AuthMiddleware(
        handler,
        create_mock_token_validator(),
        exclude_paths=["/health", "/login"]
    )
    handler = RateLimitMiddleware(handler, max_requests=5, window_seconds=60)
    handler = LoggingMiddleware(handler)
    handler = CORSMiddleware(handler, ["https://example.com"])
    handler = CacheMiddleware(handler, ttl=60)
    print("=== Valid request ===")
    valid_request = Request(
        path="/api/users",
        method="GET",
        headers={"Authorization": "Bearer admin_token"},
        ip="192.168.1.1"
    )
    response = handler.handle(valid_request)
    print(f"Response: {response.status_code}, Body: {response.body}")  # 200
    print("\n=== Request with invalid credentials ===")
    invalid_request = Request(
        path="/api/users",
        method="GET",
        headers={},
        ip="192.168.1.1"
    )
    response = handler.handle(invalid_request)
    print(f"Response: {response.status_code}, Body: {response.body}")  # 401
    print("\n=== Request to an excluded path ===")
    health_request = Request(
        path="/health",
        method="GET",
        headers={},
        ip="192.168.1.1"
    )
    response = handler.handle(health_request)
    print(f"Response: {response.status_code}, Body: {response.body}")  # 200
9.5 Enterprise Application Examples
9.5.1 Database Connection Pool Decorators
```python
from abc import ABC, abstractmethod
from typing import Any, Iterator
from dataclasses import dataclass
from contextlib import contextmanager
import time
import threading
from queue import Queue, Empty


class Connection(ABC):
    """Abstract database connection (the Component interface)."""

    @abstractmethod
    def execute(self, query: str, params: tuple = ()) -> Any:
        pass

    @abstractmethod
    def close(self) -> None:
        pass

    @abstractmethod
    def is_alive(self) -> bool:
        pass


@dataclass
class ConnectionConfig:
    host: str
    port: int
    database: str
    username: str
    password: str


class MockConnection(Connection):
    """Simulated database connection (the ConcreteComponent)."""

    def __init__(self, config: ConnectionConfig):
        self._config = config
        self._alive = True
        self._id = id(self)
        print(f"[Connection] opened connection {self._id}")

    def execute(self, query: str, params: tuple = ()) -> Any:
        if not self._alive:
            raise ConnectionError("connection is closed")
        print(f"[Connection {self._id}] executing: {query}")
        return [{"id": 1, "name": "test"}]

    def close(self) -> None:
        self._alive = False
        print(f"[Connection] closed connection {self._id}")

    def is_alive(self) -> bool:
        return self._alive


class ConnectionDecorator(Connection):
    """Base decorator: forwards every call to the wrapped connection."""

    def __init__(self, connection: Connection):
        self._connection = connection

    def execute(self, query: str, params: tuple = ()) -> Any:
        return self._connection.execute(query, params)

    def close(self) -> None:
        self._connection.close()

    def is_alive(self) -> bool:
        return self._connection.is_alive()


class LoggingConnection(ConnectionDecorator):
    """Logs each SQL statement, its parameters, and its execution time."""

    def execute(self, query: str, params: tuple = ()) -> Any:
        start = time.perf_counter()
        print(f"[SQL] {query}")
        if params:
            print(f"[PARAMS] {params}")
        result = self._connection.execute(query, params)
        elapsed = (time.perf_counter() - start) * 1000
        print(f"[SQL] execution time: {elapsed:.2f}ms")
        return result


class RetryConnection(ConnectionDecorator):
    """Retries execute() when the wrapped connection raises ConnectionError."""

    def __init__(
        self,
        connection: Connection,
        max_retries: int = 3,
        retry_delay: float = 0.1
    ):
        if max_retries < 1:
            raise ValueError("max_retries must be >= 1")
        super().__init__(connection)
        self._max_retries = max_retries
        self._retry_delay = retry_delay

    def execute(self, query: str, params: tuple = ()) -> Any:
        last_error = None
        for attempt in range(1, self._max_retries + 1):
            try:
                return self._connection.execute(query, params)
            except ConnectionError as e:
                last_error = e
                print(f"[RETRY] query failed, attempt {attempt}/{self._max_retries}")
                if attempt < self._max_retries:  # no pointless sleep after the last attempt
                    time.sleep(self._retry_delay)
        raise last_error


class TimeoutConnection(ConnectionDecorator):
    """Bounds how long the caller waits for execute().

    Python cannot forcibly stop a thread: on timeout the (daemon) worker keeps
    running in the background; only the caller stops waiting.
    """

    def __init__(self, connection: Connection, timeout: float = 30.0):
        super().__init__(connection)
        self._timeout = timeout

    def execute(self, query: str, params: tuple = ()) -> Any:
        result: list = [None]
        error: list = [None]

        def worker() -> None:
            try:
                result[0] = self._connection.execute(query, params)
            except Exception as e:
                error[0] = e

        thread = threading.Thread(target=worker, daemon=True)
        thread.start()
        thread.join(timeout=self._timeout)
        if thread.is_alive():
            raise TimeoutError(f"query timed out ({self._timeout}s)")
        if error[0] is not None:
            raise error[0]
        return result[0]


class ConnectionPool:
    """Thread-safe pool that hands out decorated connections."""

    def __init__(
        self,
        config: ConnectionConfig,
        max_connections: int = 10,
        min_connections: int = 2
    ):
        self._config = config
        self._max_connections = max_connections
        self._min_connections = min_connections
        self._pool: Queue = Queue(maxsize=max_connections)
        self._created = 0
        self._lock = threading.Lock()
        for _ in range(min_connections):
            self._pool.put(self._create_connection())
            self._created += 1  # pre-created connections count toward the limit

    def _create_connection(self) -> Connection:
        # Result: Timeout(Retry(Logging(MockConnection))); the outermost layer runs first
        conn: Connection = MockConnection(self._config)
        conn = LoggingConnection(conn)
        conn = RetryConnection(conn)
        conn = TimeoutConnection(conn)
        return conn

    @contextmanager
    def get_connection(self) -> Iterator[Connection]:
        conn = self._acquire()
        try:
            yield conn
        finally:
            self._release(conn)

    def _acquire(self) -> Connection:
        try:
            return self._pool.get_nowait()  # reuse an idle connection if there is one
        except Empty:
            pass
        with self._lock:
            if self._created < self._max_connections:
                self._created += 1
                return self._create_connection()
        # Pool exhausted: wait outside the lock so other threads can still release
        # connections; raises queue.Empty after 30 seconds
        return self._pool.get(timeout=30)

    def _release(self, conn: Connection) -> None:
        if conn.is_alive():
            self._pool.put(conn)
        else:
            with self._lock:
                self._created -= 1

    def close_all(self) -> None:
        while not self._pool.empty():
            conn = self._pool.get()
            conn.close()


if __name__ == "__main__":
    config = ConnectionConfig(
        host="localhost",
        port=5432,
        database="testdb",
        username="admin",
        password="secret"
    )
    pool = ConnectionPool(config, max_connections=5, min_connections=2)

    print("=== Single query ===")
    with pool.get_connection() as conn:
        result = conn.execute("SELECT * FROM users WHERE id = %s", (1,))
        print(f"Result: {result}")

    print("\n=== Concurrent queries ===")

    def worker(worker_id: int) -> None:
        with pool.get_connection() as conn:
            # Pass values as parameters rather than formatting them into the SQL text
            conn.execute("SELECT * FROM orders WHERE user_id = %s", (worker_id,))

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    pool.close_all()
```
9.5.2 Cache Decorator System
```python
from typing import Any, Callable, Dict, Optional, ParamSpec, Set, TypeVar
from dataclasses import dataclass
from functools import wraps
from abc import ABC, abstractmethod
from collections import OrderedDict
import hashlib
import json
import time

# ParamSpec (Python 3.10+) lets the wrapper keep the decorated function's signature
P = ParamSpec('P')
T = TypeVar('T')


class CacheBackend(ABC):
    """Abstract cache backend."""

    @abstractmethod
    def get(self, key: str) -> Optional[Any]:
        pass

    @abstractmethod
    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        pass

    @abstractmethod
    def delete(self, key: str) -> None:
        pass

    @abstractmethod
    def clear(self) -> None:
        pass


class MemoryCache(CacheBackend):
    """In-memory cache with LRU eviction and an optional per-entry TTL."""

    def __init__(self, max_size: int = 1000):
        self._cache: OrderedDict[str, tuple] = OrderedDict()
        self._max_size = max_size

    def get(self, key: str) -> Optional[Any]:
        if key in self._cache:
            value, expiry = self._cache[key]
            if expiry and time.time() > expiry:
                del self._cache[key]  # expired: drop the entry and report a miss
                return None
            self._cache.move_to_end(key)  # mark as most recently used
            return value
        return None

    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        expiry = time.time() + ttl if ttl else None
        if key in self._cache:
            del self._cache[key]
        elif len(self._cache) >= self._max_size:
            self._cache.popitem(last=False)  # evict the least recently used entry
        self._cache[key] = (value, expiry)

    def delete(self, key: str) -> None:
        self._cache.pop(key, None)

    def clear(self) -> None:
        self._cache.clear()


@dataclass
class CacheStats:
    hits: int = 0
    misses: int = 0
    sets: int = 0
    deletes: int = 0

    @property
    def hit_rate(self) -> float:
        total = self.hits + self.misses
        return self.hits / total if total > 0 else 0.0


class CacheDecorator:
    """Parameterized caching decorator backed by a named CacheBackend."""

    _backends: Dict[str, CacheBackend] = {}
    _stats: Dict[str, CacheStats] = {}
    # Keys written through this decorator, per backend, so they can be invalidated later
    _keys: Dict[str, Set[str]] = {}

    @classmethod
    def register_backend(cls, name: str, backend: CacheBackend) -> None:
        cls._backends[name] = backend
        cls._stats[name] = CacheStats()
        cls._keys[name] = set()

    @classmethod
    def get_backend(cls, name: str = "default") -> CacheBackend:
        if name not in cls._backends:
            cls.register_backend(name, MemoryCache())
        return cls._backends[name]

    @classmethod
    def get_stats(cls, name: str = "default") -> CacheStats:
        return cls._stats.get(name, CacheStats())

    @classmethod
    def invalidate(cls, prefix: str, backend: str = "default") -> int:
        """Delete every tracked key in `backend` that starts with `prefix`."""
        store = cls.get_backend(backend)
        keys = cls._keys[backend]
        matched = [k for k in keys if k.startswith(prefix)]
        for k in matched:
            store.delete(k)
            keys.discard(k)
        cls._stats[backend].deletes += len(matched)
        return len(matched)

    def __init__(
        self,
        ttl: int = 300,
        backend: str = "default",
        key_prefix: str = "",
        key_builder: Optional[Callable] = None
    ):
        self._ttl = ttl
        self._backend_name = backend
        self._key_prefix = key_prefix
        self._key_builder = key_builder

    def _build_key(self, func: Callable, args: tuple, kwargs: dict) -> str:
        if self._key_builder:
            return self._key_builder(func, args, kwargs)
        key_data = {
            "func": func.__qualname__,
            "args": args,
            "kwargs": sorted(kwargs.items())
        }
        # Default key format: "<key_prefix><qualname>:<md5 of the call arguments>"
        key_str = json.dumps(key_data, sort_keys=True, default=str)
        key_hash = hashlib.md5(key_str.encode()).hexdigest()
        return f"{self._key_prefix}{func.__qualname__}:{key_hash}"

    def __call__(self, func: Callable[P, T]) -> Callable[P, T]:
        backend = self.get_backend(self._backend_name)
        stats = self._stats[self._backend_name]
        keys = self._keys[self._backend_name]
        func_prefix = f"{self._key_prefix}{func.__qualname__}:"

        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            key = self._build_key(func, args, kwargs)
            cached = backend.get(key)
            # Note: a cached value of None is indistinguishable from a miss
            if cached is not None:
                stats.hits += 1
                return cached
            stats.misses += 1
            result = func(*args, **kwargs)
            backend.set(key, result, self._ttl)
            keys.add(key)
            stats.sets += 1
            return result

        # Remove every entry this function has cached (assumes the default key format)
        wrapper.cache_clear = lambda: CacheDecorator.invalidate(func_prefix, self._backend_name)
        wrapper.cache_info = lambda: {
            "backend": self._backend_name,
            "ttl": self._ttl,
            "stats": {
                "hits": stats.hits,
                "misses": stats.misses,
                "hit_rate": f"{stats.hit_rate:.2%}"
            }
        }
        return wrapper


def invalidate_cache(*prefixes: str, backend: str = "default"):
    """Cache-invalidation decorator: after the wrapped call returns, delete every
    tracked key in `backend` that starts with one of `prefixes`
    (for example "user:get_user:" for everything cached by get_user below)."""
    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            result = func(*args, **kwargs)
            for prefix in prefixes:
                CacheDecorator.invalidate(prefix, backend)
            return result
        return wrapper
    return decorator


CacheDecorator.register_backend("default", MemoryCache(max_size=500))
CacheDecorator.register_backend("session", MemoryCache(max_size=100))

if __name__ == "__main__":
    @CacheDecorator(ttl=60, backend="default", key_prefix="user:")
    def get_user(user_id: int) -> dict:
        print(f"[DB] fetching user {user_id}")
        return {"id": user_id, "name": f"User{user_id}", "email": f"user{user_id}@example.com"}

    @CacheDecorator(ttl=300, backend="session")
    def get_session(session_id: str) -> dict:
        print(f"[DB] fetching session {session_id}")
        return {"session_id": session_id, "user_id": 1, "expires": time.time() + 3600}

    print("=== First query ===")
    user1 = get_user(1)
    print(f"Result: {user1}")

    print("\n=== Cache hit ===")
    user1_cached = get_user(1)
    print(f"Result: {user1_cached}")

    print("\n=== Cache statistics ===")
    print(get_user.cache_info())

    print("\n=== Session cache ===")
    session = get_session("abc123")
    print(f"Result: {session}")

    print("\n=== Per-backend statistics ===")
    for name in ["default", "session"]:
        stats = CacheDecorator.get_stats(name)
        print(f"{name}: hits={stats.hits}, misses={stats.misses}, rate={stats.hit_rate:.2%}")
```
9.6 Variants and Extensions
9.6.1 Transparent Decorator
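```python
from typing import Any


class TransparentDecorator:
    """Transparent decorator: exposes the wrapped object's attributes as its own.

    Attribute reads that fail on the decorator fall through to the wrapped object.
    Writes go to the wrapped object unless the name is private (leading underscore)
    or already set on the decorator. Implicit special-method lookups such as len()
    or iteration bypass __getattr__ and are therefore not forwarded.
    """

    def __init__(self, obj: Any):
        self._obj = obj

    def __getattr__(self, name: str) -> Any:
        # Guard "_obj" so a half-initialized instance (e.g. during copy or
        # unpickling) raises AttributeError instead of recursing forever
        if name == "_obj":
            raise AttributeError(name)
        return getattr(self._obj, name)

    def __setattr__(self, name: str, value: Any) -> None:
        if name.startswith('_') or name in self.__dict__:
            super().__setattr__(name, value)
        else:
            setattr(self._obj, name, value)

    def __repr__(self) -> str:
        return f"TransparentDecorator({self._obj!r})"
```
9.6.2 Conditional Decorator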
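```python
from typing import Callable, TypeVar, ParamSpec

P = ParamSpec('P')
T = TypeVar('T')


def conditional(
    condition: Callable[[], bool],
    decorator: Callable[[Callable[P, T]], Callable[P, T]]
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Conditional decorator: apply `decorator` only if `condition()` is true.

    The condition is evaluated once, when the function is decorated (usually at
    import time), not on every call.
    """
    def wrapper(func: Callable[P, T]) -> Callable[P, T]:
        if condition():
            return decorator(func)
        return func
    return wrapper
```
9.6.3 Parameterized Decorator Factory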
```python
from typing import Any, Callable, Dict, TypeVar, ParamSpec

P = ParamSpec('P')
T = TypeVar('T')


class DecoratorFactory:
    """Registry of named, configurable decorator factories."""

    def __init__(self) -> None:
        # name -> decorator factory; factory(**kwargs) returns a decorator
        self._decorators: Dict[str, Callable[..., Callable]] = {}

    def register(self, name: str) -> Callable:
        def decorator(factory: Callable[..., Callable]) -> Callable[..., Callable]:
            self._decorators[name] = factory
            return factory
        return decorator

    def create(self, name: str, **kwargs: Any) -> Callable:
        if name not in self._decorators:
            raise ValueError(f"Unknown decorator: {name}")
        return self._decorators[name](**kwargs)

    def chain(self, *names: str) -> Callable:
        """Apply the named decorators with default options, in @-stacking order:
        the first name becomes the outermost layer."""
        def decorator(func: Callable[P, T]) -> Callable[P, T]:
            result = func
            for name in reversed(names):
                result = self.create(name)(result)
            return result
        return decorator
```
9.7 Anti-Patterns and Best Practices
9.7.1 Common Anti-Patterns
Anti-pattern 1: Decorator hell
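```python
@decorator_a
@decorator_b
@decorator_c
@decorator_d
@decorator_e
def function():
    pass
```
Problem: every extra layer adds a stack frame and hides the order in which behavior runs, so tracebacks and call flow become hard to follow.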
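Solution: limit the number of layers, and prefer composition: bundle decorators that are always used together into a single named decorator.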
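Here is a minimal sketch of that bundling idea. compose and tag are hypothetical helpers written for this example; they are not standard-library functions. compose applies decorators in the same order as stacked @ lines, so its first argument becomes the outermost layer.

```python
from functools import reduce, wraps
from typing import Any, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


def compose(*decorators: Callable[[F], F]) -> Callable[[F], F]:
    """Bundle several decorators into one (hypothetical helper).

    compose(a, b, c)(f) == a(b(c(f))), the same as stacking @a, @b, @c above f.
    """
    def apply(func: F) -> F:
        # Apply right-to-left so the first decorator wraps all the others
        return reduce(lambda acc, dec: dec(acc), reversed(decorators), func)
    return apply


def tag(label: str) -> Callable[[F], F]:
    """Toy decorator factory that records which layers ran, outermost first."""
    def decorator(func):
        @wraps(func)
        def wrapper(trace: list) -> list:
            trace.append(label)
            return func(trace)
        return wrapper
    return decorator


# One named bundle replaces a tall stack of @ lines at every use site
api_endpoint = compose(tag("auth"), tag("log"), tag("retry"))


@api_endpoint
def handler(trace: list) -> list:
    trace.append("handler")
    return trace


print(handler([]))  # ['auth', 'log', 'retry', 'handler']
```

Each call site now shows a single decorator whose name states its purpose. The layer order is defined in exactly one place.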
Anti-pattern 2: State leakage
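```python
def bad_decorator(func):
    cache = {}

    def wrapper(*args):
        if args in cache:
            return cache[args]
        result = func(*args)
        cache[args] = result
        return result
    return wrapper
```
Problem: the cache dictionary lives in the closure. Every call, every caller, and every thread shares it for the lifetime of the process. It grows without bound and cannot be inspected or cleared from outside. Its check-then-store sequence is not synchronized, and unhashable arguments raise TypeError.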
Solution: keep the state on an instance, or manage it explicitly (a bounded size and a way to inspect and clear it).
Anti-pattern 3: Lost signature
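```python
def bad_decorator(func):
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper
```
Problem: wrapper replaces the original function. As a result, `__name__`, `__doc__`, `__module__` and the signature reported by `inspect.signature` describe wrapper rather than the decorated function.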
Solution: decorate the wrapper with `@functools.wraps(func)` to copy that metadata over.
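The before/after comparison below uses nothing beyond the standard library; add and sub are toy functions. functools.wraps copies attributes such as `__name__` and `__doc__` onto the wrapper and sets `__wrapped__`. inspect.signature follows `__wrapped__` to report the original signature.

```python
import functools
import inspect


def plain(func):
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


def preserving(func):
    @functools.wraps(func)  # copies __name__, __doc__, __qualname__, ... and sets __wrapped__
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


@plain
def add(a: int, b: int) -> int:
    """Return a + b."""
    return a + b


@preserving
def sub(a: int, b: int) -> int:
    """Return a - b."""
    return a - b


print(add.__name__, add.__doc__, inspect.signature(add))  # wrapper None (*args, **kwargs)
print(sub.__name__, sub.__doc__, inspect.signature(sub))  # sub Return a - b. (a: int, b: int) -> int
```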
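9.7.2 Best Practices
| Practice | Description | Example |
|---|---|---|
| Use wraps | Preserve the decorated function's metadata | @wraps(func) |
| Type annotations | Add type hints to decorators | Callable[P, T] |
| Docstrings | Document what the decorator does | docstring |
| Argument validation | Validate the decorator's own arguments | type/range checks |
| Exception handling | Handle exceptions inside the decorator deliberately | try/except |
| Configurability | Accept configuration through decorator arguments | decorator factory |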
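9.8 Pattern Comparison
9.8.1 Comparison with Related Patterns
| Pattern | Intent | Structure | Key difference |
|---|---|---|---|
| Decorator | Add responsibilities dynamically | Chain of wrappers | Keeps the interface; does not change what the object is |
| Proxy | Control access to an object | Single wrapper | Focuses on access control |
| Adapter | Convert an interface | Single wrapper | Focuses on interface compatibility |
| Composite | Represent part-whole hierarchies | Tree | Focuses on hierarchical organization |
| Chain of Responsibility | Pass a request along a series of handlers | Linked chain | Focuses on the processing flow |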
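The Decorator and Proxy rows are the easiest to confuse, because both wrap an object behind the same interface. The sketch below uses hypothetical names (DataSource, FileSource and the wrapper classes). The decorators add behavior and stack freely. The proxy's job is to decide whether the call reaches the real object at all.

```python
from abc import ABC, abstractmethod


class DataSource(ABC):
    """Component interface shared by every wrapper below (hypothetical)."""

    @abstractmethod
    def read(self) -> str: ...


class FileSource(DataSource):
    def read(self) -> str:
        return "data"


class UpperCaseDecorator(DataSource):
    """Decorator: same interface, adds behavior, can wrap other decorators."""

    def __init__(self, inner: DataSource) -> None:
        self._inner = inner

    def read(self) -> str:
        return self._inner.read().upper()


class BracketDecorator(DataSource):
    """Decorator: surrounds whatever the inner source returns with brackets."""

    def __init__(self, inner: DataSource) -> None:
        self._inner = inner

    def read(self) -> str:
        return f"[{self._inner.read()}]"


class AccessProxy(DataSource):
    """Proxy: same interface, but controls whether the call reaches the real object."""

    def __init__(self, real: DataSource, allowed: bool) -> None:
        self._real = real
        self._allowed = allowed

    def read(self) -> str:
        if not self._allowed:
            raise PermissionError("access denied")
        return self._real.read()


print(BracketDecorator(UpperCaseDecorator(FileSource())).read())  # [DATA]
print(AccessProxy(FileSource(), allowed=True).read())             # data
```

Structurally the two are nearly identical. The distinction in the table is about intent: enriching the result versus guarding access.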
9.8.2 Decorator Pattern vs. Inheritance
```
┌────────────────────────────────────────────────────────────┐
│             Decorator Pattern vs. Inheritance              │
├────────────────────────────────────────────────────────────┤
│                                                            │
│  Inheritance:                  Decorators:                 │
│                                                            │
│  Component                     Component                   │
│  ├── ComponentWithA              │                         │
│  │   ├── ComponentWithAB         └── DecoratorA            │
│  │   └── ComponentWithAC             │                     │
│  └── ComponentWithB                  └── DecoratorB        │
│      ├── ComponentWithBA                                   │
│      └── ComponentWithBC       (composed at run time)      │
│                                                            │
│  One subclass per feature      n decorator classes give    │
│  combination: class explosion  2^n feature subsets         │
│  (fixed when the code is       (more if order matters)     │
│   written)                                                 │
│                                                            │
└────────────────────────────────────────────────────────────┘
```
9.9 Decision Guide
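9.9.1 Applicability Checklist
- [ ] You need to add or remove an object's responsibilities dynamically
- [ ] You need to combine several feature variants
- [ ] Subclassing is not an option (e.g., a class marked final)
- [ ] Features must be configured at run time
- [ ] You want to avoid a class explosion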
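Several checklist items (adding or removing features, combining variants, configuring at run time) come down to one move: building the wrapper chain from data rather than from a fixed class hierarchy. The minimal sketch below uses hypothetical Text, Bold and Italic classes. The enabled feature list could come from a config file or a request flag.

```python
from typing import Callable, Dict, List


class Text:
    """Concrete component (hypothetical)."""

    def __init__(self, content: str) -> None:
        self._content = content

    def render(self) -> str:
        return self._content


class TextDecorator(Text):
    """Base decorator: keeps the Text interface and forwards render()."""

    def __init__(self, inner: Text) -> None:
        self._inner = inner

    def render(self) -> str:
        return self._inner.render()


class Bold(TextDecorator):
    def render(self) -> str:
        return f"<b>{self._inner.render()}</b>"


class Italic(TextDecorator):
    def render(self) -> str:
        return f"<i>{self._inner.render()}</i>"


# Feature name -> decorator class; the active set is chosen at run time
FEATURES: Dict[str, Callable[[Text], Text]] = {"bold": Bold, "italic": Italic}


def build(content: str, enabled: List[str]) -> Text:
    """Wrap the component with each enabled feature, innermost first."""
    component: Text = Text(content)
    for name in enabled:
        component = FEATURES[name](component)
    return component


print(build("hi", ["bold", "italic"]).render())  # <i><b>hi</b></i>
print(build("hi", []).render())                  # hi
```

The two decorator classes cover four feature subsets (none, bold, italic, both), plus ordering variants, without a dedicated subclass for any of them.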
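9.9.2 Implementation Decision Tree
```
                    What needs decorating?
                   /                      \
         Function/method                 Class
                |                          |
       Function decorator       Modify the class definition?
          /          \                /              \
      No args     Has args           Yes              No
         |            |               |                |
    @decorator    Decorator    Class decorator   Instance decorator
                  factory
```
9.9.3 Quick Reference Card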
```
┌────────────────────────────────────────────────────────────┐
│             Decorator Pattern Quick Reference              │
├────────────────────────────────────────────────────────────┤
│ Definition: add responsibilities to an object dynamically; │
│             more flexible than subclassing                 │
├────────────────────────────────────────────────────────────┤
│ Participants:                                              │
│   • Component          - common interface                  │
│   • ConcreteComponent  - the object being decorated        │
│   • Decorator          - decorator base class              │
│   • ConcreteDecorator  - concrete decorator                │
├────────────────────────────────────────────────────────────┤
│ Formalization: d_n ∘ d_{n-1} ∘ ... ∘ d_1(c) = c'           │
├────────────────────────────────────────────────────────────┤
│ Python features:                                           │
│   • @ syntactic sugar                                      │
│   • functools.wraps                                        │
│   • ParamSpec/TypeVar                                      │
│   • Class decorators                                       │
├────────────────────────────────────────────────────────────┤
│ Typical uses:                                              │
│   • Logging/timing                                         │
│   • Caching/memoization                                    │
│   • Permission checks                                      │
│   • Middleware systems                                     │
│   • Data-processing pipelines                              │
├────────────────────────────────────────────────────────────┤
│ Best practices:                                            │
│   • Preserve metadata with @wraps                          │
│   • Add type annotations                                   │
│   • Limit the number of layers                             │
│   • Document decorator behavior                            │
└────────────────────────────────────────────────────────────┘
```
9.10 Summary
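The Decorator pattern is one of the most characteristic design patterns in Python. Its core benefits are:
- Open-closed principle: behavior is extended without modifying existing code
- Single responsibility: each decorator handles one concern
- Dynamic composition: features are combined flexibly at run time
- Avoiding inheritance: no subclass per feature combination, so no class explosion
In Python, decorators are also a language-level feature. From function decorators to class decorators, and from simple wrappers to full middleware systems, the pattern appears throughout the Python ecosystem. Mastering it is an essential step toward writing advanced Python.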
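Questions and Exercises
Basic exercise: implement a caching decorator with pluggable eviction policies, supporting LRU, LFU, and FIFO.
Intermediate exercise: design an asynchronous decorator system whose decorators work with async/await functions.
Challenge exercise: implement a decorator-composition framework that supports declarative configuration of decorator chains.
Design question: how does the Decorator pattern relate to AOP (aspect-oriented programming)? Analyze the similarities and differences.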