
Chapter 9 The Decorator Pattern

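Learning Objectives

After completing this chapter, readers will be able to:

  • Understand the core concepts of the decorator pattern and its formal definition
  • Master how Python's decorator syntax sugar and class decorators work
  • Implement dynamic feature extension and multi-layer decorator composition
  • Use modern Python features such as Protocol and Generic to build type-safe decorators
  • Recognize where the decorator pattern applies and which anti-patterns to avoid
  • Design enterprise-grade decorator solutions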

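9.1 Pattern Definition

9.1.1 Core Definition

The Decorator Pattern attaches additional responsibilities to an object dynamically. For extending functionality, it is more flexible than subclassing. It works by wrapping: a wrapper implements the same interface as the object it wraps, and it extends that object's behavior at runtime without changing the object's structure.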

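9.1.2 Formal Definition

Let $\mathcal{C}$ be the set of objects implementing the component interface, $c \in \mathcal{C}$ a concrete component, and $d: \mathcal{C} \rightarrow \mathcal{C}$ a decorator. Applying decorators $d_1$ first and $d_n$ last can then be written as:

$$c' = (d_n \circ d_{n-1} \circ \cdots \circ d_1)(c) = d_n(d_{n-1}(\cdots d_1(c) \cdots)), \quad c' \in \mathcal{C}$$

Here $c'$ is the decorated component. Suppose each decorator $d_i$ post-processes the result of the wrapped call with a function $f_i$. Then:

$$\text{operation}(c') = f_n(f_{n-1}(\cdots f_1(\text{operation}(c)) \cdots))$$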
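A small Python sketch makes the nesting order concrete. It models a component as a zero-argument callable and a decorator as a function from component to component. The helper names `compose_decorators` and `tag` are hypothetical and exist only for this illustration. `decorators[0]` plays the role of $d_1$ and is applied first, so the outermost wrapper is $d_n$:

```python
from functools import reduce
from typing import Callable, List

# Model a component as a zero-argument callable returning operation()'s result.
Component = Callable[[], str]
Decorator = Callable[[Component], Component]


def compose_decorators(decorators: List[Decorator]) -> Decorator:
    """Return d_n ∘ ... ∘ d_1, where decorators[0] is d_1 (applied first, innermost)."""
    def apply_all(component: Component) -> Component:
        # reduce wraps left to right: d_1(c), then d_2(d_1(c)), and so on.
        return reduce(lambda comp, d: d(comp), decorators, component)
    return apply_all


def tag(name: str) -> Decorator:
    """Build a decorator whose f_i wraps the inner result as name(...)."""
    def decorator(component: Component) -> Component:
        def wrapped() -> str:
            return f"{name}({component()})"
        return wrapped
    return decorator


def base() -> str:
    return "c"


decorated = compose_decorators([tag("d1"), tag("d2"), tag("d3")])(base)
print(decorated())  # d3(d2(d1(c)))
```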

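Properties of decorator chains

  1. Closure: $\forall d: \mathcal{C} \rightarrow \mathcal{C},\ \forall c \in \mathcal{C}: d(c) \in \mathcal{C}$ (a decorated component is still a component)
  2. Associativity: $(d_1 \circ d_2) \circ d_3 = d_1 \circ (d_2 \circ d_3)$
  3. Identity: $\exists I: \forall c \in \mathcal{C},\ I(c) = c$ (the undecorated original component)

Composition is generally not commutative ($d_1 \circ d_2$ and $d_2 \circ d_1$ can differ), so the order in which decorators are applied matters.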
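The three properties can be checked on a small example, which also shows that swapping two decorators can change the result. This sketch models each decorator only by its effect $f_i$ on the result of `operation()`. That assumption fits decorators that post-process the wrapped result. The functions `d1`, `d2`, and `d3` are hypothetical examples:

```python
from typing import Callable

Transform = Callable[[str], str]


def compose(f: Transform, g: Transform) -> Transform:
    """Function composition: (f ∘ g)(x) = f(g(x))."""
    return lambda x: f(g(x))


def identity(x: str) -> str:
    return x


def d1(x: str) -> str:
    return f"<{x}>"   # wrap in angle brackets


def d2(x: str) -> str:
    return x.upper()  # upper-case


def d3(x: str) -> str:
    return x + "!"    # append an exclamation mark


c = "coffee"

# Associativity: (d1 ∘ d2) ∘ d3 and d1 ∘ (d2 ∘ d3) give the same result.
left = compose(compose(d1, d2), d3)(c)
right = compose(d1, compose(d2, d3))(c)
print(left, right)  # <COFFEE!> <COFFEE!>

# Identity: I(c) == c
print(identity(c) == c)  # True

# No commutativity: swapping d1 and d3 changes the result.
print(compose(d1, d3)(c), compose(d3, d1)(c))  # <coffee!> <coffee>!
```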

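Feature composition

$$\text{Features}(d_n(\cdots d_1(c) \cdots)) = \text{Features}(c) \cup \bigcup_{i=1}^{n} \text{AddedFeatures}(d_i)$$

Complexity analysis

  • Decorator creation: $O(1)$ (wrapping a single object)
  • Method call: $O(n)$, where $n$ is the number of decorator layers, assuming each layer adds constant work
  • Space: $O(n)$ (one wrapper object per layer)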
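The $O(n)$ call cost follows from each layer forwarding the call exactly once. The sketch below wraps a component in `n` layers and counts how many `operation()` calls a single top-level call triggers. The `Leaf` and `Layer` classes are hypothetical pass-through examples:

```python
from typing import List


class Leaf:
    """Concrete component that counts its calls in a shared counter."""

    def __init__(self, counter: List[int]) -> None:
        self._counter = counter

    def operation(self) -> str:
        self._counter[0] += 1
        return "leaf"


class Layer:
    """Pass-through decorator that also counts its calls."""

    def __init__(self, component, counter: List[int]) -> None:
        self._component = component
        self._counter = counter

    def operation(self) -> str:
        self._counter[0] += 1
        return self._component.operation()


def calls_for_one_operation(n: int) -> int:
    """Wrap a Leaf in n Layers and count operation() calls for one top-level call."""
    counter = [0]
    component = Leaf(counter)
    for _ in range(n):
        component = Layer(component, counter)
    component.operation()
    return counter[0]


for n in (0, 1, 10, 100):
    print(n, calls_for_one_operation(n))  # n + 1: one call per layer plus the leaf
```

The count is $n + 1$: one call per layer plus one for the concrete component. Each layer also adds a Python stack frame, so very deep chains can hit the interpreter's recursion limit (`sys.getrecursionlimit()`, 1000 by default in CPython).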

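9.1.3 Historical Background

The decorator pattern grew out of design practice in user-interface toolkits. Its development is summarized below:

| Year | Milestone | Contributor |
|------|-----------|-------------|
| 1988 | View decoration in the ET++ framework | Weinand, Gamma |
| 1994 | Formally catalogued in *Design Patterns* | Gang of Four |
| 1996 | Decorator-based I/O streams in JDK 1.0 (`java.io.FilterInputStream` and its subclasses) | Sun Microsystems |
| 2004 | `@` decorator syntax introduced in Python 2.4 (PEP 318) | Python community |
| 2015 | Research on typed decorators in Python | Python community |
| 2020 | Research on functional decorator patterns | Martin et al. |

Significance: the decorator pattern embodies the **Open-Closed Principle (OCP)**. It extends functionality through composition rather than inheritance, which makes it a classic application of the "favor composition over inheritance" principle. In Python, decorators also became a language-level feature (the `@` syntax) that has deeply influenced Python programming style. Python's `@` decorators wrap functions and classes at definition time. They are related to the object-wrapping pattern described by the GoF, but they are not the same thing.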


9.2 Structure and Participants

9.2.1 UML Class Diagram

                    ┌─────────────────────────────┐
                    │       <<interface>>         │
                    │        Component            │
                    ├─────────────────────────────┤
                    │ + operation(): Result       │
                    └───────────────┬─────────────┘
                                    │
                    ┌───────────────┴──────────────┐
                    │                              │
        ┌───────────┴───────────┐     ┌────────────┴────────────┐
        │   ConcreteComponent   │     │       Decorator         │
        ├───────────────────────┤     ├─────────────────────────┤
        │ - state: Any          │     │ - component: Component  │
        ├───────────────────────┤     ├─────────────────────────┤
        │ + operation(): Result │     │ + operation(): Result   │
        └───────────────────────┘     └────────────┬────────────┘
                                                   │
                                       ┌───────────┴───────────┐
                                       │                       │
                           ┌───────────┴───────────┐   ┌───────┴───────────┐
                           │  ConcreteDecoratorA   │   │ ConcreteDecoratorB│
                           ├───────────────────────┤   ├───────────────────┤
                           │ - added_state: Any    │   │ - added_behavior()│
                           ├───────────────────────┤   ├───────────────────┤
                           │ + operation(): Result │   │ + operation()     │
                           └───────────────────────┘   └───────────────────┘

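9.2.2 Participant Responsibilities

| Participant | Responsibility | Key characteristic |
|-------------|----------------|--------------------|
| Component | Defines the object interface and declares `operation` | Abstract interface |
| ConcreteComponent | Defines the concrete object and implements the base behavior | The object being decorated |
| Decorator | Holds a reference to a Component and implements the same interface | Base wrapper class |
| ConcreteDecorator | Adds specific behavior by extending `operation` | Feature extender |
| Client | Works with objects only through the Component interface | Unaware of decoration |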

9.2.3 Decorator Chain Structure

Client ──────> Component.operation()
                   │
                   ▼
            ┌──────────────┐
            │ Decorator B  │
            └──────┬───────┘
                   │
                   ▼
            ┌──────────────┐
            │ Decorator A  │
            └──────┬───────┘
                   │
                   ▼
            ┌──────────────┐
            │ConcreteComp  │
            └──────────────┘
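The diagram describes a call that travels inward through every wrapper while the result travels back outward. The minimal trace below shows this order. It assumes a hypothetical `Traced` decorator that logs when a call enters and leaves its layer:

```python
from typing import List


class Concrete:
    """Concrete component that records when it runs."""

    def __init__(self, log: List[str]) -> None:
        self._log = log

    def operation(self) -> str:
        self._log.append("ConcreteComp")
        return "result"


class Traced:
    """Decorator that logs when a call enters and leaves its layer."""

    def __init__(self, name: str, component, log: List[str]) -> None:
        self._name = name
        self._component = component
        self._log = log

    def operation(self) -> str:
        self._log.append(f"enter {self._name}")
        result = self._component.operation()
        self._log.append(f"exit {self._name}")
        return result


log: List[str] = []
chain = Traced("Decorator B", Traced("Decorator A", Concrete(log), log), log)
print(chain.operation())  # result
print(log)
# ['enter Decorator B', 'enter Decorator A', 'ConcreteComp', 'exit Decorator A', 'exit Decorator B']
```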

9.3 Python Implementation

9.3.1 Standard Implementation (ABC Abstract Base Class)

python
from abc import ABC, abstractmethod
from typing import Any, Optional

class Component(ABC):
    """Abstract base class for components."""
    
    @abstractmethod
    def operation(self) -> str:
        """Core operation."""
        pass


class ConcreteComponent(Component):
    """Concrete component: implements the base behavior."""
    
    def __init__(self, name: str = "Base"):
        self._name = name
    
    def operation(self) -> str:
        return f"ConcreteComponent({self._name})"


class Decorator(Component):
    """Base decorator: holds a reference to the wrapped component."""
    
    def __init__(self, component: Component):
        self._component = component
    
    def operation(self) -> str:
        return self._component.operation()


class ConcreteDecoratorA(Decorator):
    """Concrete decorator A: adds state."""
    
    def __init__(self, component: Component):
        super().__init__(component)
        self._added_state = "StateA"
    
    def operation(self) -> str:
        return f"DecoratorA[{self._added_state}]({self._component.operation()})"


class ConcreteDecoratorB(Decorator):
    """Concrete decorator B: adds behavior."""
    
    def operation(self) -> str:
        return f"DecoratorB({self._component.operation()})"
    
    def added_behavior(self) -> str:
        """Extra method specific to this decorator."""
        return "Added behavior from DecoratorB"


def client_code(component: Component) -> None:
    """Client code: handles every component through the same interface."""
    print(f"RESULT: {component.operation()}")


if __name__ == "__main__":
    simple = ConcreteComponent("Simple")
    client_code(simple)
    
    decorator_a = ConcreteDecoratorA(simple)
    decorator_b = ConcreteDecoratorB(decorator_a)
    client_code(decorator_b)
    
    print(f"Extra: {decorator_b.added_behavior()}")
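In the example above, `added_behavior()` is reachable only because `decorator_b` is the outermost wrapper. Wrapping it again would hide the method behind the `Component` interface. One common Python workaround, which is not part of the GoF description, is to forward unknown attribute lookups to the wrapped object with `__getattr__`. The sketch below uses hypothetical classes:

```python
class Component:
    def operation(self) -> str:
        return "Component"


class Decorator:
    """Base decorator that forwards unknown attributes to the wrapped object."""

    def __init__(self, component) -> None:
        self._component = component

    def operation(self) -> str:
        return self._component.operation()

    def __getattr__(self, name: str):
        # Called only when normal lookup fails, so methods defined on the
        # decorator itself still take precedence over the wrapped object's.
        return getattr(self._component, name)


class WithExtra(Decorator):
    def operation(self) -> str:
        return f"WithExtra({self._component.operation()})"

    def added_behavior(self) -> str:
        return "extra"


class Outer(Decorator):
    def operation(self) -> str:
        return f"Outer({self._component.operation()})"


wrapped = Outer(WithExtra(Component()))
print(wrapped.operation())       # Outer(WithExtra(Component))
print(wrapped.added_behavior())  # extra (found through __getattr__ forwarding)
```

The trade-off is that the decorator no longer presents a strictly narrow interface. Copying or pickling such objects may also need extra care, because `__getattr__` can be called before `_component` exists.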

9.3.2 Protocol Implementation (Structural Typing)

python
from typing import Protocol, runtime_checkable, TypeVar, Generic
from dataclasses import dataclass

T = TypeVar('T')


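@runtime_checkable  # needed: isinstance() against this protocol raises TypeError otherwise
class ComponentProtocol(Protocol):
    """Component protocol (structural type)."""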
    
    def operation(self) -> str: ...


@runtime_checkable
class DecoratableProtocol(Protocol):
    """Decoratable protocol."""
    
    def operation(self) -> str: ...
    def get_metadata(self) -> dict: ...


@dataclass
class BaseComponent:
    """Base component implemented as a dataclass."""
    name: str
    value: int = 0
    
    def operation(self) -> str:
        return f"Base({self.name}: {self.value})"
    
    def get_metadata(self) -> dict:
        return {"type": "base", "name": self.name}


class LoggingDecorator:
    """Logging decorator (satisfies the protocol structurally)."""
    
    def __init__(self, component: DecoratableProtocol):
        self._component = component
    
    def operation(self) -> str:
        print(f"[LOG] Calling operation on {self._component.get_metadata()}")
        result = self._component.operation()
        print(f"[LOG] Result: {result}")
        return result
    
    def get_metadata(self) -> dict:
        metadata = self._component.get_metadata()
        metadata["decorators"] = metadata.get("decorators", []) + ["logging"]
        return metadata


class CachingDecorator:
    """Caching decorator (operation() takes no arguments, so one cached result suffices)."""
    
    def __init__(self, component: DecoratableProtocol):
        self._component = component
        self._cache: dict = {}
    
    def operation(self) -> str:
        key = "operation_result"
        if key not in self._cache:
            self._cache[key] = self._component.operation()
        return f"[CACHED] {self._cache[key]}"
    
    def get_metadata(self) -> dict:
        metadata = self._component.get_metadata()
        metadata["decorators"] = metadata.get("decorators", []) + ["caching"]
        return metadata


def process_component(comp: ComponentProtocol) -> None:
    """Accepts anything that structurally matches ComponentProtocol."""
    print(f"Processing: {comp.operation()}")


if __name__ == "__main__":
    base = BaseComponent("Test", 42)
    
    logged = LoggingDecorator(base)
    cached = CachingDecorator(logged)
    
    process_component(cached)
    print(f"Metadata: {cached.get_metadata()}")
    
    print(f"\nType checking:")
    print(f"base is ComponentProtocol: {isinstance(base, ComponentProtocol)}")
    print(f"cached is ComponentProtocol: {isinstance(cached, ComponentProtocol)}")
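Calling `isinstance()` against a `@runtime_checkable` protocol only checks that the required attributes exist. It does not check parameter lists or return types. A short sketch with hypothetical classes:

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class ComponentProtocol(Protocol):
    def operation(self) -> str: ...


class Good:
    def operation(self) -> str:
        return "ok"


class WrongSignature:
    # Different parameters and return type, yet isinstance() still accepts it.
    def operation(self, x: int) -> int:
        return x


class Missing:
    def run(self) -> str:
        return "no operation() here"


print(isinstance(Good(), ComponentProtocol))            # True
print(isinstance(WrongSignature(), ComponentProtocol))  # True: only presence is checked
print(isinstance(Missing(), ComponentProtocol))         # False
```

A signature mismatch like the one in `WrongSignature` is caught by a static type checker such as mypy, not at runtime.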

9.3.3 Generic Implementation (Type-Safe)

python
from abc import ABC, abstractmethod
from typing import TypeVar, Generic, Callable, List, Optional
from dataclasses import dataclass

T = TypeVar('T')
R = TypeVar('R')


class GenericComponent(ABC, Generic[T, R]):
    """Generic component interface."""
    
    @abstractmethod
    def operation(self, input_data: T) -> R:
        pass


class GenericDecorator(GenericComponent[T, R], Generic[T, R]):
    """Generic base decorator."""
    
    def __init__(self, component: GenericComponent[T, R]):
        self._component = component
    
    def operation(self, input_data: T) -> R:
        return self._component.operation(input_data)


@dataclass
class StringProcessor(GenericComponent[str, str]):
    """String processor."""
    
    prefix: str = ""
    suffix: str = ""
    
    def operation(self, input_data: str) -> str:
        return f"{self.prefix}{input_data}{self.suffix}"


class UpperCaseDecorator(GenericDecorator[str, str]):
    """Upper-case decorator (transforms the input before delegating)."""
    
    def operation(self, input_data: str) -> str:
        result = self._component.operation(input_data.upper())
        return result


class ReverseDecorator(GenericDecorator[str, str]):
    """Reverse decorator."""
    
    def operation(self, input_data: str) -> str:
        result = self._component.operation(input_data[::-1])
        return result


class TrimDecorator(GenericDecorator[str, str]):
    """Trim decorator."""
    
    def operation(self, input_data: str) -> str:
        result = self._component.operation(input_data.strip())
        return result


@dataclass
class NumberProcessor(GenericComponent[int, int]):
    """Number processor."""
    
    multiplier: int = 1
    
    def operation(self, input_data: int) -> int:
        return input_data * self.multiplier


class DoubleDecorator(GenericDecorator[int, int]):
    """Doubling decorator (transforms the input before delegating)."""
    
    def operation(self, input_data: int) -> int:
        return self._component.operation(input_data * 2)


class SquareDecorator(GenericDecorator[int, int]):
    """Squaring decorator (transforms the result after delegating)."""
    
    def operation(self, input_data: int) -> int:
        result = self._component.operation(input_data)
        return result ** 2


if __name__ == "__main__":
    print("=== String processing pipeline ===")
    processor = StringProcessor("[", "]")
    processor = TrimDecorator(processor)
    processor = UpperCaseDecorator(processor)
    
    result = processor.operation("  hello world  ")
    print(f"Result: {result}")
    
    print("\n=== Number processing pipeline ===")
    num_processor = NumberProcessor(10)
    num_processor = DoubleDecorator(num_processor)
    num_processor = SquareDecorator(num_processor)
    
    num_result = num_processor.operation(5)
    print(f"Result: {num_result}")

9.3.4 Python Function Decorators

python
from functools import wraps
from typing import Callable, TypeVar, ParamSpec, Any, Optional  # ParamSpec requires Python 3.10+
import time
import inspect

P = ParamSpec('P')
T = TypeVar('T')


def timer(func: Callable[P, T]) -> Callable[P, T]:
    """Timing decorator."""
    @wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        start = time.perf_counter()
        result = func(*args, **kwargs)
        end = time.perf_counter()
        print(f"[TIMER] {func.__name__} took {end - start:.6f}s")
        return result
    return wrapper


def log_calls(level: str = "INFO") -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Logging decorator factory."""
    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            print(f"[{level}] calling {func.__name__}")
            print(f"  arguments: args={args}, kwargs={kwargs}")
            result = func(*args, **kwargs)
            print(f"[{level}] {func.__name__} returned: {result}")
            return result
        return wrapper
    return decorator


def retry(
    max_attempts: int = 3,
    delay: float = 1.0,
    exceptions: tuple = (Exception,)
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Retry decorator."""
    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            last_exception = None
            for attempt in range(max_attempts):
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    last_exception = e
                    print(f"[RETRY] {func.__name__} attempt {attempt + 1} failed: {e}")
                    if attempt < max_attempts - 1:
                        time.sleep(delay)
            raise last_exception
        return wrapper
    return decorator


def validate_types(**type_hints) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Runtime type-validation decorator."""
    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        # Compute the signature once per decorated function, not on every call.
        sig = inspect.signature(func)
        
        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            
            for param_name, expected_type in type_hints.items():
                if param_name in bound.arguments:
                    value = bound.arguments[param_name]
                    if not isinstance(value, expected_type):
                        raise TypeError(
                            f"parameter '{param_name}' should be {expected_type.__name__}, "
                            f"got {type(value).__name__}"
                        )
            return func(*args, **kwargs)
        return wrapper
    return decorator


def memoize(func: Callable[P, T]) -> Callable[P, T]:
    """Memoization decorator (all arguments must be hashable)."""
    cache: dict = {}
    
    @wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        key = (args, frozenset(kwargs.items()))
        if key not in cache:
            cache[key] = func(*args, **kwargs)
        return cache[key]
    
    wrapper.cache = cache
    wrapper.cache_clear = lambda: cache.clear()
    return wrapper


def deprecated(
    reason: str = "",
    version: str = ""
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Deprecation-warning decorator."""
    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            message = f"[DEPRECATED] {func.__name__} is deprecated"
            if version:
                message += f" (since version {version})"
            if reason:
                message += f": {reason}"
            print(message)
            return func(*args, **kwargs)
        return wrapper
    return decorator


if __name__ == "__main__":
    @timer
    @log_calls(level="DEBUG")
    def calculate_sum(n: int) -> int:
        return sum(range(n))
    
    print(calculate_sum(100000))
    
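    @retry(max_attempts=3, delay=0.1, exceptions=(ValueError,))
    def unreliable_function(x: int) -> int:
        import random
        if random.random() < 0.7:
            raise ValueError("random failure")
        return x * 2
    
    try:
        print(unreliable_function(21))
    except ValueError as e:
        print(f"all retries failed: {e}")
    
    @validate_types(name=str, age=int)
    def create_user(name: str, age: int) -> dict:
        return {"name": name, "age": age}
    
    print(create_user("Alice", 30))
    try:
        create_user("Bob", "thirty")
    except TypeError as e:
        print(f"validation error: {e}")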
    
    @memoize
    def fibonacci(n: int) -> int:
        if n <= 1:
            return n
        return fibonacci(n - 1) + fibonacci(n - 2)
    
    print(f"\nFibonacci(30) = {fibonacci(30)}")
    print(f"Cache size: {len(fibonacci.cache)}")
    
    @deprecated(reason="use new_function() instead", version="2.0")
    def old_function() -> str:
        return "old result"
    
    print(old_function())
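The examples above rely on two behaviors that can be checked directly. First, stacked decorators are applied bottom-up, so `@timer` above `@log_calls(level="DEBUG")` is equivalent to `timer(log_calls(level="DEBUG")(calculate_sum))`. Second, `functools.wraps` copies metadata such as `__name__` and `__doc__` and sets `__wrapped__`, which `inspect.unwrap` follows back to the original function. The `tag` factory below is a hypothetical illustration:

```python
import inspect
from functools import wraps
from typing import Callable

StrFunc = Callable[[], str]


def tag(label: str) -> Callable[[StrFunc], StrFunc]:
    """Hypothetical decorator factory that wraps the result as label(...)."""
    def decorator(func: StrFunc) -> StrFunc:
        @wraps(func)
        def wrapper() -> str:
            return f"{label}({func()})"
        return wrapper
    return decorator


@tag("outer")
@tag("inner")
def greet() -> str:
    """Return a greeting."""
    return "hi"


def plain() -> str:
    return "hi"


# Equivalent to the stacked form: the decorator closest to the function is applied first.
manual = tag("outer")(tag("inner")(plain))

print(greet())                  # outer(inner(hi))
print(manual())                 # outer(inner(hi))
print(greet.__name__)           # greet (copied by functools.wraps)
print(greet.__doc__)            # Return a greeting.
print(inspect.unwrap(greet)())  # hi (follows the __wrapped__ chain)
```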

9.3.5 Class Decorators

python
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar
from dataclasses import dataclass, field
from functools import wraps
import inspect

T = TypeVar('T')


def singleton(cls: Type[T]) -> Callable[..., T]:
    """Singleton class decorator.

    Note: it returns a factory function rather than the class, so the decorated
    name can no longer be used with isinstance() or as a base class.
    """
    instances: Dict[Type, Any] = {}
    
    @wraps(cls)
    def wrapper(*args, **kwargs) -> T:
        # Arguments passed on later calls are ignored once the instance exists.
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]
    
    wrapper._is_singleton = True
    wrapper._instances = instances
    return wrapper


def add_logging(cls: Type[T]) -> Type[T]:
    """Class decorator that logs calls to public methods."""
    # Iterate over a snapshot because setattr() below modifies the class namespace.
    for name, method in list(vars(cls).items()):
        # Wrap only plain functions; nested classes, staticmethods and
        # classmethods would break with the (self, ...) signature below.
        if inspect.isfunction(method) and not name.startswith('_'):
            @wraps(method)
            def logged_method(self, *args, _method=method, **kwargs):
                print(f"[LOG] {cls.__name__}.{_method.__name__} called")
                result = _method(self, *args, **kwargs)
                print(f"[LOG] {cls.__name__}.{_method.__name__} returned: {result}")
                return result
            
            setattr(cls, name, logged_method)
    
    return cls


def add_repr(cls: Type[T]) -> Type[T]:
    """Add an automatic __repr__ if the class does not define one."""
    if '__repr__' not in cls.__dict__:
        def __repr__(self) -> str:
            fields = []
            for key, value in vars(self).items():
                if not key.startswith('_'):
                    fields.append(f"{key}={value!r}")
            return f"{cls.__name__}({', '.join(fields)})"
        cls.__repr__ = __repr__
    return cls


def add_comparison(cls: Type[T]) -> Type[T]:
    """Add value-based __eq__ and __hash__ (attribute values must be hashable; avoid mutating hashed objects)."""
    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, cls):
            return NotImplemented
        return vars(self) == vars(other)
    
    def __hash__(self) -> int:
        return hash(tuple(sorted(vars(self).items())))
    
    cls.__eq__ = __eq__
    cls.__hash__ = __hash__
    return cls


def validate_attributes(**validators: Callable[[Any], bool]) -> Callable[[Type[T]], Type[T]]:
    """Attribute-validation decorator."""
    def decorator(cls: Type[T]) -> Type[T]:
        original_init = cls.__init__
        
        @wraps(original_init)
        def new_init(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            for attr_name, validator in validators.items():
                if hasattr(self, attr_name):
                    value = getattr(self, attr_name)
                    if not validator(value):
                        raise ValueError(
                            f"validation failed for attribute '{attr_name}' with value {value!r}"
                        )
        
        cls.__init__ = new_init
        return cls
    
    return decorator


def frozen(cls: Type[T]) -> Type[T]:
    """Immutable-class decorator."""
    original_init = cls.__init__
    original_setattr = cls.__setattr__ if '__setattr__' in cls.__dict__ else object.__setattr__
    
    @wraps(original_init)
    def new_init(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        object.__setattr__(self, '_frozen', True)
    
    def new_setattr(self, name: str, value: Any) -> None:
        if getattr(self, '_frozen', False):
            raise AttributeError(f"cannot modify attribute '{name}' of immutable {cls.__name__} object")
        original_setattr(self, name, value)
    
    cls.__init__ = new_init
    cls.__setattr__ = new_setattr
    return cls


if __name__ == "__main__":
    @singleton
    @add_repr
    @add_comparison
    class Database:
        def __init__(self, connection_string: str):
            self.connection_string = connection_string
            self.connected = False
        
        def connect(self) -> None:
            self.connected = True
            print(f"Connected to {self.connection_string}")
        
        def disconnect(self) -> None:
            self.connected = False
            print("Disconnected")
    
    db1 = Database("mysql://localhost")
    db2 = Database("mysql://localhost")
    print(f"singleton check: {db1 is db2}")
    print(db1)
    
    @validate_attributes(
        name=lambda x: len(x) >= 2,
        age=lambda x: 0 <= x <= 150
    )
    @add_repr
    class Person:
        def __init__(self, name: str, age: int):
            self.name = name
            self.age = age
    
    try:
        p = Person("A", 200)
    except ValueError as e:
        print(f"validation error: {e}")
    
    @frozen
    @add_repr
    class Point:
        def __init__(self, x: float, y: float):
            self.x = x
            self.y = y
    
    pt = Point(1.0, 2.0)
    print(pt)
    try:
        pt.x = 3.0
    except AttributeError as e:
        print(f"immutability error: {e}")
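Several of the hand-written class decorators above overlap with the standard library. `@dataclass` already generates `__repr__` and `__eq__`. Adding `frozen=True` blocks attribute assignment and, together with the default `eq=True`, also generates `__hash__`. Below is a sketch of `Point` written that way. It is offered as an alternative, not as a replacement for the examples:

```python
from dataclasses import dataclass, FrozenInstanceError


@dataclass(frozen=True)
class Point:
    x: float
    y: float


pt = Point(1.0, 2.0)
print(pt)                                 # Point(x=1.0, y=2.0)
print(pt == Point(1.0, 2.0))              # True
print(hash(pt) == hash(Point(1.0, 2.0)))  # True

try:
    pt.x = 3.0
except FrozenInstanceError as e:          # FrozenInstanceError subclasses AttributeError
    print(f"immutability error: {e}")
```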

9.4 Practical Examples

9.4.1 Coffee Shop Ordering System

python
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
from enum import Enum

class Size(Enum):
    TALL = "Tall"
    GRANDE = "Grande"
    VENTI = "Venti"


class Beverage(ABC):
    """Abstract base class for beverages."""
    
    @abstractmethod
    def get_description(self) -> str:
        pass
    
    @abstractmethod
    def get_cost(self) -> float:
        pass
    
    @abstractmethod
    def get_size(self) -> Size:
        pass
    
    def get_calories(self) -> int:
        return 0


@dataclass
class Espresso(Beverage):
    """Espresso."""
    _size: Size = Size.TALL
    
    def get_description(self) -> str:
        return "Espresso"
    
    def get_cost(self) -> float:
        prices = {Size.TALL: 15.0, Size.GRANDE: 18.0, Size.VENTI: 21.0}
        return prices[self._size]
    
    def get_size(self) -> Size:
        return self._size
    
    def get_calories(self) -> int:
        return 5


@dataclass
class HouseBlend(Beverage):
    """House blend coffee."""
    _size: Size = Size.TALL
    
    def get_description(self) -> str:
        return "House Blend"
    
    def get_cost(self) -> float:
        prices = {Size.TALL: 12.0, Size.GRANDE: 15.0, Size.VENTI: 18.0}
        return prices[self._size]
    
    def get_size(self) -> Size:
        return self._size
    
    def get_calories(self) -> int:
        return 10


@dataclass
class DarkRoast(Beverage):
    """Dark roast coffee."""
    _size: Size = Size.TALL
    
    def get_description(self) -> str:
        return "Dark Roast"
    
    def get_cost(self) -> float:
        prices = {Size.TALL: 14.0, Size.GRANDE: 17.0, Size.VENTI: 20.0}
        return prices[self._size]
    
    def get_size(self) -> Size:
        return self._size
    
    def get_calories(self) -> int:
        return 8


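class CondimentDecorator(Beverage):
    """Base class for condiment decorators."""
    
    def __init__(self, beverage: Beverage):
        self._beverage = beverage
    
    def get_size(self) -> Size:
        return self._beverage.get_size()
    
    def get_calories(self) -> int:
        # Forward by default so a condiment that does not override this method
        # does not silently reset the total to Beverage's default of 0.
        return self._beverage.get_calories()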


class Milk(CondimentDecorator):
    """Milk condiment."""
    
    def get_description(self) -> str:
        return f"{self._beverage.get_description()}, Milk"
    
    def get_cost(self) -> float:
        size_prices = {Size.TALL: 3.0, Size.GRANDE: 4.0, Size.VENTI: 5.0}
        return self._beverage.get_cost() + size_prices[self._beverage.get_size()]
    
    def get_calories(self) -> int:
        return self._beverage.get_calories() + 50


class Mocha(CondimentDecorator):
    """Mocha condiment."""
    
    def get_description(self) -> str:
        return f"{self._beverage.get_description()}, Mocha"
    
    def get_cost(self) -> float:
        size_prices = {Size.TALL: 4.0, Size.GRANDE: 5.0, Size.VENTI: 6.0}
        return self._beverage.get_cost() + size_prices[self._beverage.get_size()]
    
    def get_calories(self) -> int:
        return self._beverage.get_calories() + 80


class Whip(CondimentDecorator):
    """Whipped cream condiment."""
    
    def get_description(self) -> str:
        return f"{self._beverage.get_description()}, Whip"
    
    def get_cost(self) -> float:
        size_prices = {Size.TALL: 2.5, Size.GRANDE: 3.0, Size.VENTI: 3.5}
        return self._beverage.get_cost() + size_prices[self._beverage.get_size()]
    
    def get_calories(self) -> int:
        return self._beverage.get_calories() + 100


class Soy(CondimentDecorator):
    """Soy milk condiment."""
    
    def get_description(self) -> str:
        return f"{self._beverage.get_description()}, Soy"
    
    def get_cost(self) -> float:
        size_prices = {Size.TALL: 3.5, Size.GRANDE: 4.5, Size.VENTI: 5.5}
        return self._beverage.get_cost() + size_prices[self._beverage.get_size()]
    
    def get_calories(self) -> int:
        return self._beverage.get_calories() + 40


class Caramel(CondimentDecorator):
    """Caramel condiment."""
    
    def get_description(self) -> str:
        return f"{self._beverage.get_description()}, Caramel"
    
    def get_cost(self) -> float:
        size_prices = {Size.TALL: 3.0, Size.GRANDE: 4.0, Size.VENTI: 5.0}
        return self._beverage.get_cost() + size_prices[self._beverage.get_size()]
    
    def get_calories(self) -> int:
        return self._beverage.get_calories() + 60


class Order:
    """Order containing one or more beverages."""
    
    def __init__(self):
        self._items: List[Beverage] = []
        self._order_id = id(self)
    
    def add_beverage(self, beverage: Beverage) -> None:
        self._items.append(beverage)
    
    def get_total(self) -> float:
        return sum(item.get_cost() for item in self._items)
    
    def get_total_calories(self) -> int:
        return sum(item.get_calories() for item in self._items)
    
    def print_receipt(self) -> str:
        lines = ["=" * 40, "Coffee Shop Order", "=" * 40]
        for i, item in enumerate(self._items, 1):
            lines.append(f"{i}. {item.get_description()}")
            lines.append(f"   Size: {item.get_size().value}")
            lines.append(f"   Price: ¥{item.get_cost():.2f}")
            lines.append(f"   Calories: {item.get_calories()} kcal")
        lines.extend([
            "-" * 40,
            f"Total: ¥{self.get_total():.2f}",
            f"Total calories: {self.get_total_calories()} kcal",
            "=" * 40
        ])
        return "\n".join(lines)


if __name__ == "__main__":
    order = Order()
    
    beverage1 = Espresso(Size.GRANDE)
    beverage1 = Mocha(beverage1)
    beverage1 = Whip(beverage1)
    order.add_beverage(beverage1)
    
    beverage2 = DarkRoast(Size.VENTI)
    beverage2 = Mocha(beverage2)
    beverage2 = Mocha(beverage2)
    beverage2 = Whip(beverage2)
    order.add_beverage(beverage2)
    
    beverage3 = HouseBlend(Size.TALL)
    beverage3 = Soy(beverage3)
    beverage3 = Milk(beverage3)
    order.add_beverage(beverage3)
    
    print(order.print_receipt())
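The five condiment classes differ only in their label, per-size surcharge, and calorie count. One possible refactoring keeps a single decorator class and moves those values into data. This is a sketch: `CondimentSpec`, `CONDIMENTS`, and `Condiment` are hypothetical names, and `Size` and `Espresso` are re-declared minimally so the block runs on its own:

```python
from dataclasses import dataclass
from enum import Enum
from typing import Dict


class Size(Enum):
    TALL = "Tall"
    GRANDE = "Grande"
    VENTI = "Venti"


@dataclass(frozen=True)
class CondimentSpec:
    label: str
    surcharge: Dict[Size, float]  # price added per cup size
    calories: int


CONDIMENTS: Dict[str, CondimentSpec] = {
    "mocha": CondimentSpec("Mocha", {Size.TALL: 4.0, Size.GRANDE: 5.0, Size.VENTI: 6.0}, 80),
    "whip": CondimentSpec("Whip", {Size.TALL: 2.5, Size.GRANDE: 3.0, Size.VENTI: 3.5}, 100),
}


class Espresso:
    def __init__(self, size: Size = Size.TALL) -> None:
        self._size = size

    def get_description(self) -> str:
        return "Espresso"

    def get_cost(self) -> float:
        return {Size.TALL: 15.0, Size.GRANDE: 18.0, Size.VENTI: 21.0}[self._size]

    def get_size(self) -> Size:
        return self._size

    def get_calories(self) -> int:
        return 5


class Condiment:
    """A single decorator class parameterized by a CondimentSpec."""

    def __init__(self, beverage, spec: CondimentSpec) -> None:
        self._beverage = beverage
        self._spec = spec

    def get_description(self) -> str:
        return f"{self._beverage.get_description()}, {self._spec.label}"

    def get_cost(self) -> float:
        return self._beverage.get_cost() + self._spec.surcharge[self.get_size()]

    def get_size(self) -> Size:
        return self._beverage.get_size()

    def get_calories(self) -> int:
        return self._beverage.get_calories() + self._spec.calories


drink = Condiment(Condiment(Espresso(Size.GRANDE), CONDIMENTS["mocha"]), CONDIMENTS["whip"])
print(drink.get_description())  # Espresso, Mocha, Whip
print(drink.get_cost())         # 26.0
print(drink.get_calories())     # 185
```

The cost matches the first drink in the order above: 18.0 for a grande espresso, plus 5.0 for mocha and 3.0 for whip. With this design, adding a condiment means adding a table entry instead of a class. The cost is that each condiment no longer has its own type.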

9.4.2 Data Processing Pipeline

python
from abc import ABC, abstractmethod
from typing import Any, Callable, List, Dict, Optional, TypeVar, Generic
from dataclasses import dataclass, field
import json
import hashlib

T = TypeVar('T')


class DataProcessor(ABC, Generic[T]):
    """Abstract data processor."""
    
    @abstractmethod
    def process(self, data: T) -> T:
        pass
    
    def get_name(self) -> str:
        return self.__class__.__name__


class IdentityProcessor(DataProcessor[T]):
    """Identity processor: returns its input unchanged."""
    
    def process(self, data: T) -> T:
        return data


class ProcessorDecorator(DataProcessor[T], Generic[T]):
    """Base processor decorator."""
    
    def __init__(self, processor: DataProcessor[T]):
        self._processor = processor
    
    def process(self, data: T) -> T:
        return self._processor.process(data)


class ValidationProcessor(ProcessorDecorator[T]):
    """Validation processor: runs every validator before delegating."""
    
    def __init__(
        self, 
        processor: DataProcessor[T], 
        validators: List[Callable[[T], bool]],
        error_message: str = "Validation failed"
    ):
        super().__init__(processor)
        self._validators = validators
        self._error_message = error_message
    
    def process(self, data: T) -> T:
        for validator in self._validators:
            if not validator(data):
                raise ValueError(self._error_message)
        return self._processor.process(data)


class TransformProcessor(ProcessorDecorator[T]):
    """Transform processor: transforms the input, then delegates."""
    
    def __init__(
        self, 
        processor: DataProcessor[T], 
        transformer: Callable[[T], T],
        name: str = "Transform"
    ):
        super().__init__(processor)
        self._transformer = transformer
        self._name = name
    
    def process(self, data: T) -> T:
        transformed = self._transformer(data)
        return self._processor.process(transformed)
    
    def get_name(self) -> str:
        return self._name


class LoggingProcessor(ProcessorDecorator[T]):
    """Logging processor."""
    
    def __init__(
        self, 
        processor: DataProcessor[T], 
        prefix: str = "[LOG]",
        verbose: bool = True
    ):
        super().__init__(processor)
        self._prefix = prefix
        self._verbose = verbose
    
    def process(self, data: T) -> T:
        if self._verbose:
            print(f"{self._prefix} before: {data}")
        result = self._processor.process(data)
        if self._verbose:
            print(f"{self._prefix} after: {result}")
        return result


class CacheProcessor(ProcessorDecorator[T]):
    """Caching processor (FIFO eviction; keys are MD5 digests of str(data), so 1 and "1" collide)."""
    
    def __init__(
        self, 
        processor: DataProcessor[T],
        max_size: int = 100
    ):
        super().__init__(processor)
        self._cache: Dict[str, T] = {}
        self._max_size = max_size
        self._hits = 0
        self._misses = 0
    
    def _get_key(self, data: T) -> str:
        return hashlib.md5(str(data).encode()).hexdigest()
    
    def process(self, data: T) -> T:
        key = self._get_key(data)
        if key in self._cache:
            self._hits += 1
            print(f"[CACHE] hit: {key[:8]}...")
            return self._cache[key]
        
        self._misses += 1
        result = self._processor.process(data)
        
        if len(self._cache) >= self._max_size:
            oldest_key = next(iter(self._cache))
            del self._cache[oldest_key]
        
        self._cache[key] = result
        return result
    
    def get_stats(self) -> Dict[str, Any]:
        total = self._hits + self._misses
        hit_rate = self._hits / total if total > 0 else 0
        return {
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": f"{hit_rate:.2%}",
            "cache_size": len(self._cache)
        }


class RetryProcessor(ProcessorDecorator[T]):
    """Retry processor."""
    
    def __init__(
        self, 
        processor: DataProcessor[T],
        max_attempts: int = 3,
        delay: float = 0.1
    ):
        super().__init__(processor)
        self._max_attempts = max_attempts
        self._delay = delay
    
    def process(self, data: T) -> T:
        import time
        last_error = None
        for attempt in range(self._max_attempts):
            try:
                return self._processor.process(data)
            except Exception as e:
                last_error = e
                print(f"[RETRY] attempt {attempt + 1}/{self._max_attempts} failed: {e}")
                if attempt < self._max_attempts - 1:
                    time.sleep(self._delay)
        raise last_error


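class BatchProcessor(DataProcessor[List[T]]):
    """Batch processor: applies an item-level processor to every element of a list.

    It wraps a DataProcessor[T] but exposes DataProcessor[List[T]], so it
    subclasses DataProcessor directly instead of ProcessorDecorator.
    """
    
    def __init__(
        self, 
        processor: DataProcessor[T],
        batch_size: int = 10
    ):
        self._processor = processor
        self._batch_size = batch_size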
    
    def process(self, data: List[T]) -> List[T]:
        results = []
        for i in range(0, len(data), self._batch_size):
            batch = data[i:i + self._batch_size]
            for item in batch:
                results.append(self._processor.process(item))
        return results


class Pipeline:
    """Processing pipeline: runs processors in the order they were added."""
    
    def __init__(self, name: str = "Pipeline"):
        self._name = name
        self._processors: List[DataProcessor] = []
    
    def add(self, processor: DataProcessor) -> 'Pipeline':
        self._processors.append(processor)
        return self
    
    def process(self, data: Any) -> Any:
        result = data
        for processor in self._processors:
            result = processor.process(result)
        return result
    
    def get_processor_names(self) -> List[str]:
        return [p.get_name() for p in self._processors]


if __name__ == "__main__":
    print("=== String processing pipeline ===")
    pipeline = Pipeline("StringProcessor")
    pipeline.add(ValidationProcessor(
        IdentityProcessor(),
        [lambda x: isinstance(x, str), lambda x: len(x) > 0],
        "Input must be a non-empty string"
    ))
    pipeline.add(TransformProcessor(
        IdentityProcessor(),
        lambda x: x.strip(),
        "Trim"
    ))
    pipeline.add(TransformProcessor(
        IdentityProcessor(),
        lambda x: x.lower(),
        "Lowercase"
    ))
    pipeline.add(LoggingProcessor(IdentityProcessor(), "[PROCESS]"))
    
    result = pipeline.process("  HELLO WORLD  ")
    print(f"Result: {result}")
    print(f"Processor chain: {' -> '.join(pipeline.get_processor_names())}")
    
    print("\n=== Cached processing pipeline ===")
    cache_pipeline = Pipeline("CachedProcessor")
    cache_pipeline.add(CacheProcessor(
        TransformProcessor(IdentityProcessor(), lambda x: x ** 2, "Square")
    ))
    
    for i in [1, 2, 1, 3, 2, 4, 1]:
        cache_pipeline.process(i)
    
    cache_processor = cache_pipeline._processors[0]
    print(f"缓存统计: {cache_processor.get_stats()}")

9.4.3 Web Middleware System

python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Callable, TypeVar
from enum import Enum
import time
import json
from collections import defaultdict

class HTTPStatus(Enum):
    OK = 200
    CREATED = 201
    BAD_REQUEST = 400
    UNAUTHORIZED = 401
    FORBIDDEN = 403
    NOT_FOUND = 404
    INTERNAL_ERROR = 500
    SERVICE_UNAVAILABLE = 503


@dataclass
class Request:
    path: str
    method: str
    headers: Dict[str, str] = field(default_factory=dict)
    body: Any = None
    params: Dict[str, str] = field(default_factory=dict)
    cookies: Dict[str, str] = field(default_factory=dict)
    ip: str = "127.0.0.1"
    user: Optional[Dict[str, Any]] = None


@dataclass
class Response:
    status_code: int = 200
    body: Any = None
    headers: Dict[str, str] = field(default_factory=dict)
    
    def json(self) -> str:
        return json.dumps(self.body, ensure_ascii=False)


class RequestHandler(ABC):
    """Abstract request handler"""
    
    @abstractmethod
    def handle(self, request: Request) -> Response:
        pass


class BaseHandler(RequestHandler):
    """Base handler: the innermost component that builds the actual response"""
    
    def handle(self, request: Request) -> Response:
        return Response(
            status_code=HTTPStatus.OK.value,
            body={"message": f"Handled {request.method} {request.path}"}
        )


class Middleware(RequestHandler):
    """Middleware base class: delegates to the wrapped handler"""
    
    def __init__(self, handler: RequestHandler):
        self._handler = handler
    
    def handle(self, request: Request) -> Response:
        return self._handler.handle(request)


class AuthMiddleware(Middleware):
    """Authentication middleware"""
    
    def __init__(
        self, 
        handler: RequestHandler,
        token_validator: Callable[[str], Optional[Dict]],
        exclude_paths: Optional[List[str]] = None
    ):
        super().__init__(handler)
        self._token_validator = token_validator
        self._exclude_paths = exclude_paths or []
    
    def handle(self, request: Request) -> Response:
        if request.path in self._exclude_paths:
            return self._handler.handle(request)
        
        token = request.headers.get("Authorization", "")
        if token.startswith("Bearer "):
            token = token[7:]
        
        user = self._token_validator(token)
        if not user:
            return Response(
                status_code=HTTPStatus.UNAUTHORIZED.value,
                body={"error": "Unauthorized"}
            )
        
        request.user = user
        return self._handler.handle(request)


class CORSMiddleware(Middleware):
    """CORS (cross-origin resource sharing) middleware"""
    
    def __init__(
        self,
        handler: RequestHandler,
        allowed_origins: Optional[List[str]] = None,
        allowed_methods: Optional[List[str]] = None,
        allowed_headers: Optional[List[str]] = None,
        allow_credentials: bool = True
    ):
        super().__init__(handler)
        self._allowed_origins = allowed_origins or ["*"]
        self._allowed_methods = allowed_methods or ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
        self._allowed_headers = allowed_headers or ["Content-Type", "Authorization"]
        self._allow_credentials = allow_credentials
    
    def handle(self, request: Request) -> Response:
        # Answer preflight requests directly without calling the inner handler
        if request.method == "OPTIONS":
            response = Response(status_code=HTTPStatus.OK.value)
        else:
            response = self._handler.handle(request)
        
        origin = request.headers.get("Origin")
        if origin is None:
            return response  # not a cross-origin request: no CORS headers needed
        if "*" not in self._allowed_origins and origin not in self._allowed_origins:
            return response  # origin not allowed: the browser will block the response
        
        # Echo the concrete origin: "*" is not permitted together with credentials
        response.headers["Access-Control-Allow-Origin"] = origin
        vary = response.headers.get("Vary")
        response.headers["Vary"] = f"{vary}, Origin" if vary else "Origin"
        response.headers["Access-Control-Allow-Methods"] = ", ".join(self._allowed_methods)
        response.headers["Access-Control-Allow-Headers"] = ", ".join(self._allowed_headers)
        if self._allow_credentials:
            response.headers["Access-Control-Allow-Credentials"] = "true"
        
        return response


class RateLimitMiddleware(Middleware):
    """Rate-limiting middleware (sliding window per client key)"""
    
    def __init__(
        self,
        handler: RequestHandler,
        max_requests: int = 100,
        window_seconds: int = 60,
        key_func: Optional[Callable[[Request], str]] = None
    ):
        super().__init__(handler)
        self._max_requests = max_requests
        self._window = window_seconds
        self._key_func = key_func or (lambda r: r.ip)
        self._requests: Dict[str, List[float]] = defaultdict(list)
    
    def handle(self, request: Request) -> Response:
        key = self._key_func(request)
        current_time = time.monotonic()  # monotonic clock: unaffected by wall-clock changes
        
        # Keep only the timestamps that still fall inside the window
        self._requests[key] = [
            t for t in self._requests[key]
            if current_time - t < self._window
        ]
        
        if len(self._requests[key]) >= self._max_requests:
            return Response(
                status_code=429,  # 429 Too Many Requests (RFC 6585) rather than 503
                body={"error": "Too many requests, please try again later"},
                headers={"Retry-After": str(self._window)}
            )
        
        self._requests[key].append(current_time)
        return self._handler.handle(request)


class LoggingMiddleware(Middleware):
    """Logging middleware"""
    
    def __init__(
        self,
        handler: RequestHandler,
        log_request: bool = True,
        log_response: bool = True
    ):
        super().__init__(handler)
        self._log_request = log_request
        self._log_response = log_response
    
    def handle(self, request: Request) -> Response:
        start_time = time.perf_counter()
        
        if self._log_request:
            # Mask credentials so bearer tokens do not end up in the logs
            safe_headers = {
                k: ("***" if k.lower() == "authorization" else v)
                for k, v in request.headers.items()
            }
            print(f"[REQUEST] {request.method} {request.path}")
            print(f"  IP: {request.ip}")
            print(f"  Headers: {safe_headers}")
        
        response = self._handler.handle(request)
        
        elapsed = (time.perf_counter() - start_time) * 1000
        
        if self._log_response:
            print(f"[RESPONSE] {response.status_code} ({elapsed:.2f}ms)")
        
        return response


class CompressionMiddleware(Middleware):
    """Response compression (gzip) middleware"""
    
    def __init__(
        self,
        handler: RequestHandler,
        min_size: int = 1024,
        compression_level: int = 6
    ):
        super().__init__(handler)
        self._min_size = min_size
        self._compression_level = compression_level
    
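    def handle(self, request: Request) -> Response:
        response = self._handler.handle(request)
        
        accept_encoding = request.headers.get("Accept-Encoding", "")
        if "gzip" not in accept_encoding or "Content-Encoding" in response.headers:
            return response
        
        # Measure the size threshold in encoded bytes, not in characters
        if isinstance(response.body, (dict, list)):
            body_bytes = response.json().encode("utf-8")
        elif isinstance(response.body, bytes):
            body_bytes = response.body
        else:
            body_bytes = str(response.body).encode("utf-8")
        if len(body_bytes) < self._min_size:
            return response
        
        import gzip
        compressed = gzip.compress(body_bytes, compresslevel=self._compression_level)
        
        response.body = compressed
        response.headers["Content-Encoding"] = "gzip"
        response.headers["Content-Length"] = str(len(compressed))
        vary = response.headers.get("Vary")
        response.headers["Vary"] = f"{vary}, Accept-Encoding" if vary else "Accept-Encoding"
        
        return response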


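class CacheMiddleware(Middleware):
    """Response cache middleware.

    Place it inside AuthMiddleware. The cache key does not include the caller's
    identity, so a cache in front of authentication would serve protected
    responses to unauthenticated clients.
    """
    
    def __init__(
        self,
        handler: RequestHandler,
        ttl: int = 300,
        cacheable_methods: Optional[List[str]] = None
    ):
        super().__init__(handler)
        self._ttl = ttl
        self._cacheable_methods = cacheable_methods or ["GET"]
        self._cache: Dict[str, tuple] = {}
    
    def _get_cache_key(self, request: Request) -> str:
        # Sort the params so that the same query in a different order hits the same entry
        return f"{request.method}:{request.path}:{sorted(request.params.items())}"
    
    @staticmethod
    def _copy(response: Response) -> Response:
        # Hand out copies so outer middleware cannot mutate the cached entry
        return Response(response.status_code, response.body, dict(response.headers))
    
    def handle(self, request: Request) -> Response:
        if request.method not in self._cacheable_methods:
            return self._handler.handle(request)
        
        key = self._get_cache_key(request)
        current_time = time.time()
        
        if key in self._cache:
            cached_response, timestamp = self._cache[key]
            if current_time - timestamp < self._ttl:
                response = self._copy(cached_response)
                response.headers["X-Cache"] = "HIT"
                return response
            del self._cache[key]  # drop the expired entry
        
        response = self._handler.handle(request)
        
        if response.status_code == HTTPStatus.OK.value:
            self._cache[key] = (self._copy(response), current_time)
        
        response.headers["X-Cache"] = "MISS"
        return response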


def create_mock_token_validator():
    valid_tokens = {
        "admin_token": {"id": 1, "name": "Admin", "role": "admin"},
        "user_token": {"id": 2, "name": "User", "role": "user"},
    }
    return lambda token: valid_tokens.get(token)


if __name__ == "__main__":
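    # The last wrapper applied is the outermost layer and sees each request first.
    # CacheMiddleware sits inside AuthMiddleware so cached responses are only
    # returned to callers that have already been authenticated.
    handler: RequestHandler = BaseHandler()
    handler = CacheMiddleware(handler, ttl=60)
    handler = AuthMiddleware(
        handler, 
        create_mock_token_validator(),
        exclude_paths=["/health", "/login"]
    )
    handler = RateLimitMiddleware(handler, max_requests=5, window_seconds=60)
    handler = CORSMiddleware(handler, ["https://example.com"])
    handler = LoggingMiddleware(handler)
    
    print("=== Valid request ===")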
    valid_request = Request(
        path="/api/users",
        method="GET",
        headers={"Authorization": "Bearer admin_token"},
        ip="192.168.1.1"
    )
    response = handler.handle(valid_request)
    print(f"Response: {response.status_code}, Body: {response.body}")
    
    print("\n=== Request without valid credentials ===")
    invalid_request = Request(
        path="/api/users",
        method="GET",
        headers={},
        ip="192.168.1.1"
    )
    response = handler.handle(invalid_request)
    print(f"Response: {response.status_code}, Body: {response.body}")
    
    print("\n=== Excluded path ===")
    health_request = Request(
        path="/health",
        method="GET",
        headers={},
        ip="192.168.1.1"
    )
    response = handler.handle(health_request)
    print(f"Response: {response.status_code}, Body: {response.body}")
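Wrapping order decides which middleware sees a request first: the last wrapper applied is the outermost layer. The sketch below uses hypothetical plain functions (base_handler, auth, cache) rather than the classes above. It shows why a cache placed outside authentication can hand a cached 200 response to an unauthenticated client.

```python
from typing import Callable, Dict, Tuple

# Hypothetical minimal model: a handler maps (path, token) to (status, body)
Handler = Callable[[str, str], Tuple[int, str]]


def base_handler(path: str, token: str) -> Tuple[int, str]:
    return 200, f"data for {path}"


def auth(inner: Handler) -> Handler:
    def handle(path: str, token: str) -> Tuple[int, str]:
        if token != "valid":
            return 401, "unauthorized"
        return inner(path, token)
    return handle


def cache(inner: Handler) -> Handler:
    store: Dict[str, Tuple[int, str]] = {}  # keyed by path only, like CacheMiddleware

    def handle(path: str, token: str) -> Tuple[int, str]:
        if path in store:
            return store[path]
        result = inner(path, token)
        if result[0] == 200:
            store[path] = result
        return result
    return handle


# Cache outermost: the unauthenticated second call is answered from the cache
unsafe = cache(auth(base_handler))
unsafe("/api/users", "valid")
print(unsafe("/api/users", ""))  # (200, 'data for /api/users')

# Auth outermost: every request is checked before the cache is consulted
safe = auth(cache(base_handler))
safe("/api/users", "valid")
print(safe("/api/users", ""))  # (401, 'unauthorized')
```

The same reasoning applies to any decorator that can short-circuit the call. It belongs inside every decorator whose checks must never be skipped.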

9.5 Enterprise Application Examples

9.5.1 Database Connection Pool Decorators

python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Callable, TypeVar
from dataclasses import dataclass, field
from contextlib import contextmanager
import time
import threading
from queue import Queue

T = TypeVar('T')


class Connection(ABC):
    """Abstract database connection"""
    
    @abstractmethod
    def execute(self, query: str, params: tuple = ()) -> Any:
        pass
    
    @abstractmethod
    def close(self) -> None:
        pass
    
    @abstractmethod
    def is_alive(self) -> bool:
        pass


@dataclass
class ConnectionConfig:
    host: str
    port: int
    database: str
    username: str
    password: str


class MockConnection(Connection):
    """Mock database connection"""
    
    def __init__(self, config: ConnectionConfig):
        self._config = config
        self._alive = True
        self._id = id(self)
        print(f"[Connection] Created connection {self._id}")
    
    def execute(self, query: str, params: tuple = ()) -> Any:
        if not self._alive:
            raise ConnectionError("Connection is closed")
        print(f"[Connection {self._id}] Executing: {query}")
        return [{"id": 1, "name": "test"}]
    
    def close(self) -> None:
        self._alive = False
        print(f"[Connection] Closed connection {self._id}")
    
    def is_alive(self) -> bool:
        return self._alive


class ConnectionDecorator(Connection):
    """Base connection decorator: forwards every call to the wrapped connection"""
    
    def __init__(self, connection: Connection):
        self._connection = connection
    
    def execute(self, query: str, params: tuple = ()) -> Any:
        return self._connection.execute(query, params)
    
    def close(self) -> None:
        self._connection.close()
    
    def is_alive(self) -> bool:
        return self._connection.is_alive()


class LoggingConnection(ConnectionDecorator):
    """Logging connection decorator"""
    
    def execute(self, query: str, params: tuple = ()) -> Any:
        start = time.perf_counter()
        print(f"[SQL] {query}")
        if params:
            print(f"[PARAMS] {params}")
        result = self._connection.execute(query, params)
        elapsed = (time.perf_counter() - start) * 1000
        print(f"[SQL] Execution time: {elapsed:.2f}ms")
        return result


class RetryConnection(ConnectionDecorator):
    """Retrying connection decorator"""
    
    def __init__(
        self, 
        connection: Connection,
        max_retries: int = 3,
        retry_delay: float = 0.1
    ):
        if max_retries < 1:
            raise ValueError("max_retries must be at least 1")
        super().__init__(connection)
        self._max_retries = max_retries
        self._retry_delay = retry_delay
    
    def execute(self, query: str, params: tuple = ()) -> Any:
        for attempt in range(1, self._max_retries + 1):
            try:
                return self._connection.execute(query, params)
            except ConnectionError:
                print(f"[RETRY] Database operation failed, attempt {attempt}/{self._max_retries}")
                if attempt == self._max_retries:
                    raise  # out of attempts: re-raise the last error
                time.sleep(self._retry_delay)  # no pointless sleep after the final attempt


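class TimeoutConnection(ConnectionDecorator):
    """Timeout connection decorator.

    Runs the query in a worker thread and waits at most timeout seconds.
    Python threads cannot be killed, so after a timeout the query keeps
    running in the background; only the caller stops waiting.
    """
    
    def __init__(self, connection: Connection, timeout: float = 30.0):
        super().__init__(connection)
        self._timeout = timeout
    
    def execute(self, query: str, params: tuple = ()) -> Any:
        result: List[Any] = [None]
        error: List[Optional[BaseException]] = [None]
        
        def worker() -> None:
            try:
                result[0] = self._connection.execute(query, params)
            except Exception as e:
                error[0] = e
        
        # daemon=True so an abandoned query cannot keep the interpreter alive
        thread = threading.Thread(target=worker, daemon=True)
        thread.start()
        thread.join(timeout=self._timeout)
        
        if thread.is_alive():
            raise TimeoutError(f"Query timed out ({self._timeout}s)")
        
        if error[0] is not None:
            raise error[0]
        
        return result[0]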


class ConnectionPool:
    """Connection pool"""
    
    def __init__(
        self,
        config: ConnectionConfig,
        max_connections: int = 10,
        min_connections: int = 2
    ):
        self._config = config
        self._max_connections = max_connections
        self._min_connections = min_connections
        self._pool: Queue = Queue(maxsize=max_connections)
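        self._created = 0  # connections created so far, pooled or checked out
        self._lock = threading.Lock()
        
        for _ in range(min_connections):
            self._pool.put(self._create_connection())
            self._created += 1  # pre-created connections count toward max_connections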
    
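    def _create_connection(self) -> Connection:
        # Result: Timeout(Retry(Logging(Mock))). The outermost timeout
        # therefore bounds the total time, including all retries.
        conn = MockConnection(self._config)
        conn = LoggingConnection(conn)
        conn = RetryConnection(conn)
        conn = TimeoutConnection(conn)
        return conn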
    
    @contextmanager
    def get_connection(self) -> Connection:
        conn = self._acquire()
        try:
            yield conn
        finally:
            self._release(conn)
    
    def _acquire(self) -> Connection:
        with self._lock:
            if not self._pool.empty():
                return self._pool.get()
            if self._created < self._max_connections:
                self._created += 1
                return self._create_connection()
        
        return self._pool.get(timeout=30)
    
    def _release(self, conn: Connection) -> None:
        if conn.is_alive():
            self._pool.put(conn)
        else:
            with self._lock:
                self._created -= 1
    
    def close_all(self) -> None:
        while not self._pool.empty():
            conn = self._pool.get()
            conn.close()


if __name__ == "__main__":
    config = ConnectionConfig(
        host="localhost",
        port=5432,
        database="testdb",
        username="admin",
        password="secret"
    )
    
    pool = ConnectionPool(config, max_connections=5, min_connections=2)
    
    print("=== Run a query ===")
    with pool.get_connection() as conn:
        result = conn.execute("SELECT * FROM users WHERE id = %s", (1,))
        print(f"Result: {result}")
    
    print("\n=== Concurrent queries ===")
    def worker(worker_id: int):
        with pool.get_connection() as conn:
            # Pass values as parameters instead of formatting them into the SQL string
            conn.execute("SELECT * FROM orders WHERE user_id = %s", (worker_id,))
    
    threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    
    pool.close_all()

9.5.2 Cache Decorator System

python
from typing import Any, Callable, Dict, List, Optional, TypeVar, Generic, Hashable, ParamSpec
from dataclasses import dataclass, field
from functools import wraps
from abc import ABC, abstractmethod
import time
import hashlib
import json
from collections import OrderedDict

P = ParamSpec('P')  # parameters of the decorated function (Python 3.10+)
T = TypeVar('T')
K = TypeVar('K')


class CacheBackend(ABC):
    """Abstract cache backend"""
    
    @abstractmethod
    def get(self, key: str) -> Optional[Any]:
        pass
    
    @abstractmethod
    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        pass
    
    @abstractmethod
    def delete(self, key: str) -> None:
        pass
    
    @abstractmethod
    def clear(self) -> None:
        pass


class MemoryCache(CacheBackend):
    """In-memory LRU cache with an optional per-entry TTL"""
    
    def __init__(self, max_size: int = 1000):
        self._cache: OrderedDict[str, tuple] = OrderedDict()
        self._max_size = max_size
    
    def get(self, key: str) -> Optional[Any]:
        if key in self._cache:
            value, expiry = self._cache[key]
            if expiry and time.time() > expiry:
                del self._cache[key]
                return None
            self._cache.move_to_end(key)
            return value
        return None
    
    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        expiry = time.time() + ttl if ttl else None
        if key in self._cache:
            del self._cache[key]
        elif len(self._cache) >= self._max_size:
            self._cache.popitem(last=False)
        self._cache[key] = (value, expiry)
    
    def delete(self, key: str) -> None:
        if key in self._cache:
            del self._cache[key]
    
    def clear(self) -> None:
        self._cache.clear()


@dataclass
class CacheStats:
    hits: int = 0
    misses: int = 0
    sets: int = 0
    deletes: int = 0
    
    @property
    def hit_rate(self) -> float:
        total = self.hits + self.misses
        return self.hits / total if total > 0 else 0


class CacheDecorator:
    """Caching decorator with named, pluggable backends"""
    
    _backends: Dict[str, CacheBackend] = {}
    _stats: Dict[str, CacheStats] = {}
    
    @classmethod
    def register_backend(cls, name: str, backend: CacheBackend) -> None:
        cls._backends[name] = backend
        cls._stats[name] = CacheStats()
    
    @classmethod
    def get_backend(cls, name: str = "default") -> CacheBackend:
        if name not in cls._backends:
            cls._backends[name] = MemoryCache()
            cls._stats[name] = CacheStats()
        return cls._backends[name]
    
    @classmethod
    def get_stats(cls, name: str = "default") -> CacheStats:
        return cls._stats.get(name, CacheStats())
    
    def __init__(
        self,
        ttl: int = 300,
        backend: str = "default",
        key_prefix: str = "",
        key_builder: Optional[Callable] = None
    ):
        self._ttl = ttl
        self._backend_name = backend
        self._key_prefix = key_prefix
        self._key_builder = key_builder
    
    def _build_key(self, func: Callable, args: tuple, kwargs: dict) -> str:
        if self._key_builder:
            return self._key_builder(func, args, kwargs)
        
        key_data = {
            "func": func.__qualname__,
            "args": args,
            "kwargs": sorted(kwargs.items())
        }
        key_str = json.dumps(key_data, sort_keys=True, default=str)
        key_hash = hashlib.md5(key_str.encode()).hexdigest()
        return f"{self._key_prefix}{func.__qualname__}:{key_hash}"
    
    def __call__(self, func: Callable[P, T]) -> Callable[P, T]:
        backend = self.get_backend(self._backend_name)
        stats = self._stats[self._backend_name]
        issued_keys: set[str] = set()  # keys written by this wrapper, used by cache_clear
        
        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            key = self._build_key(func, args, kwargs)
            
            # Note: with this backend API a cached None looks exactly like a miss
            cached = backend.get(key)
            if cached is not None:
                stats.hits += 1
                return cached
            
            stats.misses += 1
            result = func(*args, **kwargs)
            backend.set(key, result, self._ttl)
            issued_keys.add(key)
            stats.sets += 1
            return result
        
        def cache_clear() -> None:
            # Real keys end in an argument hash, so delete every key this wrapper issued
            for key in issued_keys:
                backend.delete(key)
                stats.deletes += 1
            issued_keys.clear()
        
        wrapper.cache_clear = cache_clear
        wrapper.cache_info = lambda: {
            "backend": self._backend_name,
            "ttl": self._ttl,
            "stats": {  # shared by all functions that use this backend
                "hits": stats.hits,
                "misses": stats.misses,
                "hit_rate": f"{stats.hit_rate:.2%}"
            }
        }
        
        return wrapper


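def invalidate_cache(*keys: str, backend: str = "default"):
    """Cache-invalidation decorator: after a successful call, delete the given keys.

    The backend API has no pattern matching, so each entry in keys must be a
    complete cache key.
    """
    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            result = func(*args, **kwargs)
            cache = CacheDecorator.get_backend(backend)
            for key in keys:
                cache.delete(key)
            return result
        return wrapper
    return decorator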


CacheDecorator.register_backend("default", MemoryCache(max_size=500))
CacheDecorator.register_backend("session", MemoryCache(max_size=100))


if __name__ == "__main__":
    @CacheDecorator(ttl=60, backend="default", key_prefix="user:")
    def get_user(user_id: int) -> dict:
        print(f"[DB] Querying user {user_id}")
        return {"id": user_id, "name": f"User{user_id}", "email": f"user{user_id}@example.com"}
    
    @CacheDecorator(ttl=300, backend="session")
    def get_session(session_id: str) -> dict:
        print(f"[DB] Querying session {session_id}")
        return {"session_id": session_id, "user_id": 1, "expires": time.time() + 3600}
    
    print("=== First query ===")
    user1 = get_user(1)
    print(f"Result: {user1}")
    
    print("\n=== Cache hit ===")
    user1_cached = get_user(1)
    print(f"Result: {user1_cached}")
    
    print("\n=== Cache statistics ===")
    print(get_user.cache_info())
    
    print("\n=== Session cache ===")
    session = get_session("abc123")
    print(f"Result: {session}")
    
    print("\n=== Per-backend statistics ===")
    for name in ["default", "session"]:
        stats = CacheDecorator.get_stats(name)
        print(f"{name}: hits={stats.hits}, misses={stats.misses}, rate={stats.hit_rate:.2%}")
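CacheDecorator treats a None coming back from the backend as a miss, so a function that legitimately returns None runs again on every call. A common workaround is to compare against a private sentinel object instead of None. It is shown here as a standalone sketch, not as a change to the classes above. memoize and find_user are hypothetical.

```python
from functools import wraps
from typing import Any, Callable, Dict, List

_MISSING = object()  # sentinel: distinguishes "not cached" from a cached None


def memoize(func: Callable[..., Any]) -> Callable[..., Any]:
    store: Dict[Any, Any] = {}

    @wraps(func)
    def wrapper(*args: Any) -> Any:
        result = store.get(args, _MISSING)
        if result is _MISSING:
            result = func(*args)
            store[args] = result  # None is cached like any other value
        return result
    return wrapper


calls: List[int] = []


@memoize
def find_user(user_id: int) -> None:
    calls.append(user_id)
    return None  # hypothetical lookup that finds nothing


find_user(7)
find_user(7)
print(len(calls))  # 1
```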

9.6 Pattern Variants and Extensions

9.6.1 Transparent Decorator

python
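from typing import Any

class TransparentDecorator:
    """Transparent decorator: forwards attribute access to the wrapped object.

    Only ordinary attribute lookup is forwarded. Special methods such as
    __len__ or __iter__ are looked up on the type and must be defined explicitly.
    """
    
    def __init__(self, obj: Any):
        self._obj = obj
    
    def __getattr__(self, name: str) -> Any:
        # Called only when normal lookup fails; guard against recursion if _obj is unset
        if name == "_obj":
            raise AttributeError(name)
        return getattr(self._obj, name)
    
    def __setattr__(self, name: str, value: Any) -> None:
        # Private names and the decorator's own attributes stay on the decorator
        if name.startswith('_') or name in self.__dict__:
            super().__setattr__(name, value)
        else:
            setattr(self._obj, name, value)
    
    def __repr__(self) -> str:
        return f"TransparentDecorator({self._obj!r})"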
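__getattr__ is consulted only for ordinary attribute lookup. Python looks up special methods such as __len__ on the type, so built-ins like len() do not reach the wrapped object unless the decorator defines them. A minimal demonstration with hypothetical Wrapper and SizedWrapper classes:

```python
from typing import Any


class Wrapper:
    """Forwards ordinary attribute access to the wrapped object."""

    def __init__(self, obj: Any) -> None:
        self._obj = obj

    def __getattr__(self, name: str) -> Any:
        return getattr(self._obj, name)


class SizedWrapper(Wrapper):
    # Special methods must be defined on the class itself to be found by len()
    def __len__(self) -> int:
        return len(self._obj)


w = Wrapper([1, 2, 3])
print(w.count(2))  # 1: list.count reached through __getattr__
try:
    len(w)
except TypeError as exc:
    print("len() failed:", exc)  # object of type 'Wrapper' has no len()

print(len(SizedWrapper([1, 2, 3])))  # 3
```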

9.6.2 Conditional Decorator

python
from typing import Callable, TypeVar, ParamSpec

P = ParamSpec('P')
T = TypeVar('T')

def conditional(
    condition: Callable[[], bool],
    decorator: Callable[[Callable[P, T]], Callable[P, T]]
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Conditional decorator: apply decorator only if condition() is true at decoration time"""
    def wrapper(func: Callable[P, T]) -> Callable[P, T]:
        if condition():
            return decorator(func)
        return func
    return wrapper

9.6.3 Parameterized Decorator Factory

python
from typing import Any, Callable, Dict, ParamSpec, TypeVar

P = ParamSpec('P')
T = TypeVar('T')

class DecoratorFactory:
    """Decorator registry: look up and configure decorators by name"""
    
    def __init__(self):
        # name -> factory; calling factory(**kwargs) returns a decorator
        self._decorators: Dict[str, Callable[..., Callable]] = {}
    
    def register(self, name: str) -> Callable:
        def decorator(factory: Callable[..., Callable]) -> Callable[..., Callable]:
            self._decorators[name] = factory
            return factory
        return decorator
    
    def create(self, name: str, **kwargs: Any) -> Callable:
        if name not in self._decorators:
            raise ValueError(f"Unknown decorator: {name}")
        return self._decorators[name](**kwargs)
    
    def chain(self, *names: str) -> Callable:
        """Apply factories with default arguments; the first name becomes the outermost layer"""
        def decorator(func: Callable[P, T]) -> Callable[P, T]:
            result = func
            for name in reversed(names):
                result = self.create(name)(result)
            return result
        return decorator
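Entries registered with DecoratorFactory are decorator factories. create() calls them with keyword arguments and chain() calls them with none, so each registered factory needs defaults for all of its parameters. Below is a usage sketch with a condensed copy of the class and two hypothetical factories, prefix and upper.

```python
from typing import Any, Callable, Dict


class DecoratorFactory:
    """Condensed version of the registry above."""

    def __init__(self) -> None:
        self._decorators: Dict[str, Callable[..., Callable]] = {}

    def register(self, name: str) -> Callable:
        def decorator(fn: Callable[..., Callable]) -> Callable[..., Callable]:
            self._decorators[name] = fn
            return fn
        return decorator

    def create(self, name: str, **kwargs: Any) -> Callable:
        if name not in self._decorators:
            raise ValueError(f"Unknown decorator: {name}")
        return self._decorators[name](**kwargs)

    def chain(self, *names: str) -> Callable:
        def decorator(func: Callable) -> Callable:
            for name in reversed(names):  # the first name ends up outermost
                func = self.create(name)(func)
            return func
        return decorator


registry = DecoratorFactory()


@registry.register("prefix")
def prefix(tag: str = "A") -> Callable:
    # Hypothetical factory: prepend a tag to a string result
    def decorator(func: Callable[[], str]) -> Callable[[], str]:
        return lambda: f"{tag}{func()}"
    return decorator


@registry.register("upper")
def upper() -> Callable:
    # Hypothetical factory: upper-case a string result
    def decorator(func: Callable[[], str]) -> Callable[[], str]:
        return lambda: func().upper()
    return decorator


@registry.create("prefix", tag=">")
def greet() -> str:
    return "hi"


@registry.chain("prefix", "upper")  # same as stacking @prefix() above @upper()
def shout() -> str:
    return "hi"


print(greet(), shout())  # >hi AHI
```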

9.7 Anti-Patterns and Best Practices

9.7.1 Common Anti-Patterns

Anti-pattern 1: Decorator hell

python
@decorator_a
@decorator_b
@decorator_c
@decorator_d
@decorator_e
def function():
    pass

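Problem: deep nesting makes the code hard to debug and understand. Each layer adds a stack frame, and because decorators are applied bottom-up (decorator_e first, decorator_a outermost), the effective order is easy to misread.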

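Solution: limit the number of layers, and replace long stacks with explicit composition, for example by bundling a recurring combination into a single, named decorator.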
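One way to shorten a long decorator stack is to bundle a recurring combination under a single name. The compose helper below is a hypothetical sketch. It applies decorators in the same order as stacking them with @, so compose(a, b, c)(f) equals a(b(c(f))).

```python
from functools import reduce
from typing import Callable, TypeVar

F = TypeVar("F", bound=Callable)


def compose(*decorators: Callable[[F], F]) -> Callable[[F], F]:
    """compose(a, b, c)(f) == a(b(c(f))), matching the @a @b @c stacking order."""
    def apply(func: F) -> F:
        return reduce(lambda acc, dec: dec(acc), reversed(decorators), func)
    return apply


def tag(label: str) -> Callable[[Callable[[], str]], Callable[[], str]]:
    # Hypothetical decorator that wraps the result in a label
    def decorator(func: Callable[[], str]) -> Callable[[], str]:
        return lambda: f"{label}({func()})"
    return decorator


# One named bundle instead of repeating the same three-line stack everywhere
service_endpoint = compose(tag("a"), tag("b"), tag("c"))


@service_endpoint
def handler() -> str:
    return "x"


print(handler())  # a(b(c(x)))
```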

Anti-pattern 2: State leakage

python
def bad_decorator(func):
    cache = {}
    
    def wrapper(*args):
        if args in cache:
            return cache[args]
        result = func(*args)
        cache[args] = result
        return result
    return wrapper

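Problem: the cache dictionary is hidden in a closure and shared by every caller, and every thread, for as long as the decorated function exists. It grows without bound, cannot be inspected or cleared, and only works with hashable positional arguments.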

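Solution: make the state explicit and bounded. Keep it at instance level or expose it through attributes such as cache_clear(), and cap its size.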
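For memoization specifically, the standard library's functools.lru_cache addresses these problems. maxsize bounds the memory, and cache_info() and cache_clear() make the hidden state observable and resettable.

```python
from functools import lru_cache


@lru_cache(maxsize=2)  # keep at most two results; the least recently used is evicted
def square(n: int) -> int:
    return n * n


square(1)  # miss
square(2)  # miss
square(1)  # hit: 1 becomes the most recently used entry
square(3)  # miss: the cache is full, so the entry for 2 is evicted
print(square.cache_info())  # CacheInfo(hits=1, misses=3, maxsize=2, currsize=2)

square.cache_clear()  # explicit reset of both the cache and its statistics
print(square.cache_info().currsize)  # 0
```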

Anti-pattern 3: Lost signature

python
def bad_decorator(func):
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

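Problem: the decorated function's metadata is lost. __name__, __doc__, and the signature reported by inspect.signature all describe wrapper instead of the original function.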

Solution: apply @functools.wraps(func) to the wrapper to copy the metadata.
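functools.wraps copies __module__, __name__, __qualname__, __doc__ and related attributes onto the wrapper and sets __wrapped__. inspect.signature follows __wrapped__ to report the original parameters. Below is a side-by-side comparison; bare, wrapped and area are illustrative names.

```python
import inspect
from functools import wraps


def bare(func):
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


def wrapped(func):
    @wraps(func)  # copy metadata and set wrapper.__wrapped__ = func
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


def area(width: float, height: float) -> float:
    """Return the area of a rectangle."""
    return width * height


a, b = bare(area), wrapped(area)
print(a.__name__, a.__doc__, inspect.signature(a))
# wrapper None (*args, **kwargs)
print(b.__name__, b.__doc__, inspect.signature(b))
# area Return the area of a rectangle. (width: float, height: float) -> float
```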

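9.7.2 Best Practices

| Practice | Description | Example |
|---|---|---|
| Use wraps | Preserve the decorated function's metadata | @wraps(func) |
| Type annotations | Add type hints to the decorator | Callable[P, T] |
| Docstrings | Document what the decorator does | docstring |
| Parameter validation | Validate the decorator's own arguments | type checks |
| Exception handling | Handle exceptions inside the decorator deliberately | try/except |
| Configurability | Let callers configure the decorator through parameters | decorator factory |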

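9.8 Pattern Comparison

9.8.1 Comparison with Related Patterns

| Pattern | Intent | Structure | Key difference |
|---|---|---|---|
| Decorator | Add responsibilities dynamically | Chain of wrappers | Keeps the component's interface unchanged |
| Proxy | Control access to an object | Usually a single wrapper | Focuses on access control |
| Adapter | Convert one interface into another | Single wrapper | Focuses on interface compatibility |
| Composite | Represent part-whole hierarchies | Tree | Focuses on hierarchical organization |
| Chain of Responsibility | Pass a request along a chain of handlers | Linked chain | Focuses on the request-handling flow |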

9.8.2 Decorator Pattern vs Inheritance

┌────────────────────────────────────────────────────────────┐
│              Decorator Pattern vs Inheritance              │
├────────────────────────────────────────────────────────────┤
│                                                            │
│  Inheritance:                    Decorators:               │
│                                                            │
│  Component                       Component                 │
│  ├── ComponentWithA              └── DecoratorA            │
│  │   ├── ComponentWithAB             └── DecoratorB        │
│  │   └── ComponentWithAC                                   │
│  └── ComponentWithB              (composed at runtime)     │
│      ├── ComponentWithBA                                   │
│      └── ComponentWithBC                                   │
│                                                            │
│  (fixed at class definition)     n decorator classes give  │
│  one subclass per combination    up to 2^n feature subsets │
│  → class explosion                                         │
│                                                            │
└────────────────────────────────────────────────────────────┘
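The counts in the diagram can be checked directly. If each decorator is treated as an independent on/off feature, n decorators yield 2^n feature subsets. Covering them by inheritance takes one subclass per non-empty subset, and more once wrapping order matters (ComponentWithAB vs ComponentWithBA). A quick enumeration for three hypothetical features:

```python
from itertools import combinations, permutations

features = ["A", "B", "C"]  # hypothetical decorator names
n = len(features)

# Unordered feature sets, including the empty set (the undecorated component)
subsets = [c for k in range(n + 1) for c in combinations(features, k)]
# Ordered stacks of distinct decorators, relevant when wrapping order changes behavior
stacks = [p for k in range(1, n + 1) for p in permutations(features, k)]

print(len(subsets))      # 8: 2 ** 3 feature subsets
print(len(subsets) - 1)  # 7: subclasses needed to cover every non-empty subset
print(len(stacks))       # 15: distinct ordered stacks (3 + 6 + 6)
print(n)                 # 3: decorator classes needed
```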

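9.9 Decision Guide

9.9.1 Applicability Checklist

  • [ ] Responsibilities must be added to or removed from individual objects at runtime
  • [ ] Several feature variants need to be combined
  • [ ] Extension by subclassing is impossible or impractical (e.g. final classes)
  • [ ] Features must be configured at runtime
  • [ ] You want to avoid a class explosion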

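9.9.2 Choosing an Implementation

                     What needs decorating?
                      /                  \
             Function/method            Class
                    |                      |
          Function decorator    Modify the class definition?
             /        \                /          \
       No params   Has params        Yes           No
           |            |             |             |
      @decorator    Decorator       Class        Object
                     factory      decorator     decorator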
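Below are the four leaves of the tree, side by side in one minimal sketch. All names (shout, repeat, add_repr, Greeter, LoudGreeter) are hypothetical.

```python
from functools import wraps
from typing import Callable


# 1. Function decorator without parameters: @shout
def shout(func: Callable[[], str]) -> Callable[[], str]:
    @wraps(func)
    def wrapper() -> str:
        return func().upper()
    return wrapper


# 2. Parameterized decorator built by a factory: @repeat(2)
def repeat(times: int) -> Callable[[Callable[[], str]], Callable[[], str]]:
    def decorator(func: Callable[[], str]) -> Callable[[], str]:
        @wraps(func)
        def wrapper() -> str:
            return func() * times
        return wrapper
    return decorator


# 3. Class decorator: modifies the class definition itself
def add_repr(cls: type) -> type:
    cls.__repr__ = lambda self: f"<{cls.__name__}>"
    return cls


# 4. Object decorator (GoF style): wraps an instance, the class stays unchanged
class LoudGreeter:
    def __init__(self, inner: "Greeter") -> None:
        self._inner = inner

    def greet(self) -> str:
        return self._inner.greet() + "!"


@add_repr
class Greeter:
    def greet(self) -> str:
        return "hi"


@shout
def hello() -> str:
    return "hello"


@repeat(2)
def ha() -> str:
    return "ha"


print(hello(), ha(), repr(Greeter()), LoudGreeter(Greeter()).greet())
# HELLO haha <Greeter> hi!
```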

9.9.3 Quick Reference Card

┌────────────────────────────────────────────────────────────┐
│             Decorator Pattern Quick Reference              │
├────────────────────────────────────────────────────────────┤
│ Definition: dynamically add responsibilities to an object  │
│             (more flexible than subclassing)               │
├────────────────────────────────────────────────────────────┤
│ Participants:                                              │
│   • Component          - common interface                  │
│   • ConcreteComponent  - the decorated object              │
│   • Decorator          - decorator base class              │
│   • ConcreteDecorator  - a specific decorator              │
├────────────────────────────────────────────────────────────┤
│ Formalization: d_n ∘ d_{n-1} ∘ ... ∘ d_1(c) = c'           │
├────────────────────────────────────────────────────────────┤
│ Python features:                                           │
│   • @ decorator syntax                                     │
│   • functools.wraps                                        │
│   • ParamSpec / TypeVar                                    │
│   • Class decorators                                       │
├────────────────────────────────────────────────────────────┤
│ Typical applications:                                      │
│   • Logging / timing                                       │
│   • Caching / memoization                                  │
│   • Authorization checks                                   │
│   • Middleware systems                                     │
│   • Data-processing pipelines                              │
├────────────────────────────────────────────────────────────┤
│ Best practices:                                            │
│   • Preserve metadata with @wraps                          │
│   • Add type annotations                                   │
│   • Limit the number of layers                             │
│   • Document decorator behavior                            │
└────────────────────────────────────────────────────────────┘

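9.10 Summary

The Decorator pattern is one of the most characteristic design patterns in Python. Its core value lies in:

  1. Open-closed principle: functionality is extended without modifying existing code
  2. Single responsibility: each decorator focuses on one concern
  3. Dynamic composition: features are combined flexibly at runtime
  4. Avoiding inheritance: it avoids the class explosion caused by creating a subclass for every feature combination

Python also supports decorators at the language level through the @ syntax. Strictly speaking, a Python decorator is any callable that takes a function or class and returns a replacement. This overlaps with the object-wrapping Decorator pattern described by the GoF but is broader than it. Decorators appear throughout the Python ecosystem, from function and class decorators to complete middleware systems. A solid grasp of both the language feature and the underlying pattern is a key step toward advanced Python development.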


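Questions and Exercises

  1. Basic: implement a caching decorator with pluggable eviction policies such as LRU, LFU, and FIFO.

  2. Intermediate: design an asynchronous decorator system whose decorators work with async/await functions.

  3. Challenge: implement a decorator-composition framework that supports declaratively configured decorator chains.

  4. Design question: how does the Decorator pattern relate to AOP (aspect-oriented programming)? Analyze the similarities and differences.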

Python Technical Book Series - Jiangsu Province Sucheng Secondary Vocational School (江苏省宿城中等专业学校)