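Chapter 21 Strategy Pattern
Learning Objectives
After completing this chapter, readers will be able to:
- Understand the core concepts, formal definition, and theoretical foundations of the Strategy pattern
- Encapsulate a family of algorithms and switch between them dynamically
- Implement several variants of the Strategy pattern in Python
- Identify the situations where the Strategy pattern applies and where it falls short
- Explain the essential differences between the Strategy pattern and the State and Command patterns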
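21.1 Pattern Definition
21.1.1 Core Definition
The Strategy pattern defines a family of algorithms, encapsulates each one in its own class, and makes them interchangeable. It lets the algorithm vary independently of the clients that use it, which makes it one of the main tools for applying the Open/Closed Principle.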
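A minimal sketch of this definition, assuming plain Python functions as the interchangeable algorithms. The names `checkout`, `regular_price`, and `member_discount` (and the 10% discount) are hypothetical; the point is that `checkout` never changes when a pricing rule is added or swapped.

```python
from typing import Callable

# Every strategy has the same signature: amount in, price out
PricingStrategy = Callable[[float], float]


def regular_price(amount: float) -> float:
    return amount


def member_discount(amount: float) -> float:
    return amount * 0.9  # hypothetical 10% member discount


def checkout(amount: float, pricing: PricingStrategy) -> float:
    # The client depends only on the strategy's signature, not on a concrete rule
    return round(pricing(amount), 2)


print(checkout(100.0, regular_price))    # 100.0
print(checkout(100.0, member_discount))  # 90.0
```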
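21.1.2 Formal Definition
The Strategy pattern can be stated formally as follows.
Definition 21.1 (Strategy system) A strategy system is a four-tuple:
$$\mathcal{S} = \langle C, P, \Sigma, \phi \rangle$$
where:
- $C$: the state space of the context
- $P = \{p_1, p_2, \ldots, p_n\}$: the set of strategies
- $\Sigma$: the input domain
- $\phi: P \times \Sigma \rightarrow \Omega$: the strategy execution function, which maps a strategy and an input to an output in the output domain $\Omega$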
Definition 21.2 (Strategy equivalence) Two strategies $p_i$ and $p_j$ are equivalent on the input domain $\Sigma$ if and only if:
$$\forall \sigma \in \Sigma: \phi(p_i, \sigma) = \phi(p_j, \sigma)$$
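Definition 21.2 quantifies over all of $\Sigma$, which is usually infinite, so in practice equivalence can only be checked on a finite sample. The sketch below assumes $\Sigma$ is approximated by every permutation of a small list and compares two sorting functions by their outputs $\phi(p, \sigma)$; `insertion_sort`, `builtin_sort`, and `equivalent_on` are illustrative names. A passing check is evidence of equivalence on the sample, not a proof.

```python
from itertools import permutations
from typing import Callable, Iterable, List

SortStrategy = Callable[[List[int]], List[int]]


def insertion_sort(items: List[int]) -> List[int]:
    result: List[int] = []
    for x in items:
        # Insert x after every element that is <= x
        i = 0
        while i < len(result) and result[i] <= x:
            i += 1
        result.insert(i, x)
    return result


def builtin_sort(items: List[int]) -> List[int]:
    return sorted(items)


def equivalent_on(p: SortStrategy, q: SortStrategy, domain: Iterable[List[int]]) -> bool:
    # Checks phi(p, sigma) == phi(q, sigma) for every sigma in the finite sample
    return all(p(list(s)) == q(list(s)) for s in domain)


sample = [list(s) for s in permutations([3, 1, 2, 2])]
print(equivalent_on(insertion_sort, builtin_sort, sample))                          # True
print(equivalent_on(builtin_sort, lambda xs: sorted(xs, reverse=True), sample))     # False
```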
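Definition 21.3 (Strategy complexity) The worst-case time complexity of a strategy $p$ is:
$$T_p(n) = \max_{\sigma \in \Sigma,\ |\sigma| = n} \text{time}(\phi(p, \sigma))$$
where $|\sigma|$ denotes the size of the input $\sigma$.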
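Definition 21.3 can be approximated empirically by counting a basic operation rather than measuring wall-clock time. As a sketch, assuming key comparisons are the cost measure, the bubble sort used later in this chapter (which has no early-exit check) performs exactly $n(n-1)/2$ comparisons on every input of size $n$, so $T_p(n) \in \Theta(n^2)$. The helper name `bubble_sort_comparisons` is illustrative.

```python
from typing import List


def bubble_sort_comparisons(items: List[int]) -> int:
    """Counts the comparisons made by a bubble sort without early exit."""
    arr = list(items)
    n = len(arr)
    comparisons = 0
    for i in range(n):
        for j in range(n - i - 1):
            comparisons += 1
            if arr[j] > arr[j + 1]:
                arr[j], arr[j + 1] = arr[j + 1], arr[j]
    return comparisons


for n in (10, 100, 1000):
    worst = bubble_sort_comparisons(list(range(n, 0, -1)))  # reversed input
    print(n, worst, n * (n - 1) // 2)  # the last two columns agree
```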
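Theorem 21.1 (Strategy independence) Provided strategies share no mutable state, adding or removing a strategy does not affect the correctness of the existing strategies, because the correctness of the strategy set decomposes into the correctness of its members: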
$$\text{Correctness}(P) = \bigwedge_{p \in P} \text{Correctness}(p)$$
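21.1.3 Strategy Pattern vs. Conditional Branching
| Aspect | Strategy pattern | Conditional branching |
|---|---|---|
| Extensibility | New strategies require no changes to existing code | The conditional logic must be edited |
| Testability | Each strategy can be tested in isolation | Tests must cover every branch |
| Reusability | Strategies can be reused across contexts | Logic is tied to the enclosing method |
| Runtime switching | Supported by swapping the strategy object | Limited to branches already written; a new branch requires recompiling/redeploying |
| Code complexity | More classes | Longer methods |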
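The extensibility and runtime-switching rows can be seen side by side in a few lines. This sketch uses hypothetical shipping-fee rules: the conditional version has to be edited for every new rule, while the strategy version looks the rule up in a mapping that can grow without touching the dispatch function.

```python
from typing import Callable, Dict


# Conditional-branch version: every new rule means editing this function
def shipping_fee_if(method: str, weight_kg: float) -> float:
    if method == "standard":
        return 5.0 + 1.0 * weight_kg
    elif method == "express":
        return 12.0 + 2.0 * weight_kg
    else:
        raise ValueError(f"Unknown method: {method}")


# Strategy version: rules are values in a table
FeeStrategy = Callable[[float], float]
FEE_STRATEGIES: Dict[str, FeeStrategy] = {
    "standard": lambda w: 5.0 + 1.0 * w,
    "express": lambda w: 12.0 + 2.0 * w,
}


def shipping_fee(method: str, weight_kg: float) -> float:
    strategy = FEE_STRATEGIES.get(method)
    if strategy is None:
        raise ValueError(f"Unknown method: {method}")
    return strategy(weight_kg)


# Adding a rule at run time does not touch shipping_fee()
FEE_STRATEGIES["pickup"] = lambda w: 0.0

print(shipping_fee_if("express", 2.0))  # 16.0
print(shipping_fee("express", 2.0))     # 16.0
print(shipping_fee("pickup", 2.0))      # 0.0
```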
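21.1.4 Strategy Pattern vs. State Pattern
| Aspect | Strategy pattern | State pattern |
|---|---|---|
| Purpose | Interchangeable algorithms/behaviors | Behavior that changes with internal state |
| Relationship among strategies/states | Strategies are independent of one another | States have transitions between them |
| Who switches | The client | State classes can trigger transitions |
| Lifecycle | Strategies are usually stateless | States may have their own lifecycle |
| Typical uses | Payment methods, sorting algorithms | Order status, workflows |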
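The "who switches" row is the sharpest difference. In this sketch (all class names are hypothetical), each order state object returns its own successor, whereas the sorter's key function changes only when the client assigns a new one.

```python
from typing import Any, Callable, List


class OrderState:
    def next(self) -> "OrderState":
        raise NotImplementedError


class Created(OrderState):
    def next(self) -> OrderState:
        return Paid()


class Paid(OrderState):
    def next(self) -> OrderState:
        return Shipped()


class Shipped(OrderState):
    def next(self) -> OrderState:
        return self  # terminal state: no further transition


class Order:
    """State pattern: the current state object decides its successor."""

    def __init__(self) -> None:
        self.state: OrderState = Created()

    def advance(self) -> None:
        self.state = self.state.next()


class WordSorter:
    """Strategy pattern: the client picks the ordering; strategies never switch themselves."""

    def __init__(self, key: Callable[[str], Any]) -> None:
        self.key = key

    def sort(self, words: List[str]) -> List[str]:
        return sorted(words, key=self.key)


order = Order()
order.advance()
order.advance()
print(type(order.state).__name__)       # Shipped

sorter = WordSorter(key=len)
print(sorter.sort(["ccc", "a", "bb"]))  # ['a', 'bb', 'ccc']
sorter.key = str.lower                  # the client swaps the strategy
print(sorter.sort(["b", "A", "c"]))     # ['A', 'b', 'c']
```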
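21.2 Historical Background and Theoretical Origins
21.2.1 Timeline
| Year | Milestone | Contributor | Significance |
|---|---|---|---|
| 1970s | Polymorphism and dynamic binding | Smalltalk community | Theoretical basis of the Strategy pattern |
| 1980s | Encapsulation of algorithm families | | Algorithm abstraction in object-oriented design |
| 1994 | GoF Strategy pattern | Gamma et al. | Catalogued as one of the 23 GoF design patterns |
| 1995 | STL algorithms | Alexander Stepanov | Strategy-style customization in the C++ standard library (e.g. comparator arguments) |
| 2000s | Rise of functional programming | Various | Strategies expressed naturally as functions in functional languages |
| 2010s | Dependency-injection frameworks | Various | Strategy injection in frameworks such as Spring |
| 2015 | Python type hints | Guido van Rossum et al. | Type-safe Strategy implementations |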
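21.2.2 Theoretical Foundations
The Strategy pattern rests on the following principles:
- Open/Closed Principle (OCP): open for extension, closed for modification
- Dependency Inversion Principle (DIP): the high-level context depends on the abstract strategy interface
- Single Responsibility Principle (SRP): each strategy class is responsible for exactly one algorithm
- Liskov Substitution Principle (LSP): any strategy can replace another wherever the interface is expected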
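These principles map directly onto Python code. A sketch, assuming `typing.Protocol` (Python 3.8+) as the abstract strategy interface; `CompressionStrategy`, `NoCompression`, `RunLengthEncoding`, and `Archiver` are hypothetical names. `Archiver` depends only on the protocol (DIP), a new algorithm is a new class rather than an edit to `Archiver` (OCP), each class implements one algorithm (SRP), and any conforming object can be substituted (LSP).

```python
from typing import List, Protocol


class CompressionStrategy(Protocol):
    def compress(self, data: str) -> str: ...


class NoCompression:
    def compress(self, data: str) -> str:
        return data


class RunLengthEncoding:
    """Encodes each run as <count><char>, e.g. 'aaab' -> '3a1b'."""

    def compress(self, data: str) -> str:
        if not data:
            return ""
        out: List[str] = []
        current, count = data[0], 1
        for ch in data[1:]:
            if ch == current:
                count += 1
            else:
                out.append(f"{count}{current}")
                current, count = ch, 1
        out.append(f"{count}{current}")
        return "".join(out)


class Archiver:
    # The high-level module depends on the abstraction only (DIP)
    def __init__(self, strategy: CompressionStrategy) -> None:
        self.strategy = strategy

    def archive(self, data: str) -> str:
        return self.strategy.compress(data)


print(Archiver(NoCompression()).archive("aaab"))      # aaab
print(Archiver(RunLengthEncoding()).archive("aaab"))  # 3a1b
```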
21.3 UML Structure Diagrams
21.3.1 Standard Structure
```text
┌─────────────────────────────────────────────────────────────────────┐
│                            <<interface>>                            │
│                              Strategy                               │
├─────────────────────────────────────────────────────────────────────┤
│ + execute(data: T): R                                               │
│ + get_name(): str                                                   │
│ + get_complexity(): str                                             │
└─────────────────────────────────────────────────────────────────────┘
                                   △
                                   │
           ┌───────────────────────┼───────────────────────┐
           │                       │                       │
┌─────────────────────┐ ┌─────────────────────┐ ┌─────────────────────┐
│  ConcreteStrategyA  │ │  ConcreteStrategyB  │ │  ConcreteStrategyC  │
├─────────────────────┤ ├─────────────────────┤ ├─────────────────────┤
│ - config: Config    │ │ - config: Config    │ │ - config: Config    │
├─────────────────────┤ ├─────────────────────┤ ├─────────────────────┤
│ + execute(data): R  │ │ + execute(data): R  │ │ + execute(data): R  │
│ + get_name(): str   │ │ + get_name(): str   │ │ + get_name(): str   │
└─────────────────────┘ └─────────────────────┘ └─────────────────────┘
                                   │
                                   │ uses
                                   ↓
┌─────────────────────────────────────────────────────────────────────┐
│                               Context                               │
├─────────────────────────────────────────────────────────────────────┤
│ - strategy: Strategy                                                │
│ - strategy_registry: Dict[str, Strategy]                            │
├─────────────────────────────────────────────────────────────────────┤
│ + set_strategy(strategy: Strategy): void                            │
│ + execute_strategy(data: T): R                                      │
│ + register_strategy(name: str, strategy: Strategy): void            │
│ + get_available_strategies(): List[str]                             │
└─────────────────────────────────────────────────────────────────────┘
```
21.3.2 Strategy Injection (Registry) Pattern
```text
┌─────────────────────────────────────────────────────────────────┐
│                        Strategy Registry                        │
├─────────────────────────────────────────────────────────────────┤
│ - strategies: Dict[str, Strategy]                               │
│ - default_strategy: str                                         │
├─────────────────────────────────────────────────────────────────┤
│ + register(name: str, strategy: Strategy): void                 │
│ + get(name: str): Strategy                                      │
│ + get_all(): Dict[str, Strategy]                                │
└─────────────────────────────────────────────────────────────────┘
         │                             │
         │ registers                   │ looks up
         │                             │
┌─────────────────┐           ┌─────────────────┐
│ StrategyA       │           │ Context         │
├─────────────────┤           ├─────────────────┤
│ + execute()     │           │ - registry      │
└─────────────────┘           │ + execute()     │
┌─────────────────┐           └─────────────────┘
│ StrategyB       │
├─────────────────┤
│ + execute()     │
└─────────────────┘
```
21.3.3 Strategy Composition Pattern
```text
┌─────────────────────────────────────────────────────────────────┐
│                        CompositeStrategy                        │
├─────────────────────────────────────────────────────────────────┤
│ - strategies: List[Strategy]                                    │
│ - combiner: Callable[[List[R]], R]                              │
├─────────────────────────────────────────────────────────────────┤
│ + execute(data: T): R                                           │
│ + add_strategy(strategy: Strategy): void                        │
│ + remove_strategy(strategy: Strategy): void                     │
└─────────────────────────────────────────────────────────────────┘
                                 │
                                 │ composes
                                 ↓
┌─────────────────┐     ┌─────────────────┐     ┌─────────────────┐
│ StrategyA       │     │ StrategyB       │     │ StrategyC       │
├─────────────────┤     ├─────────────────┤     ├─────────────────┤
│ + execute()     │     │ + execute()     │     │ + execute()     │
└─────────────────┘     └─────────────────┘     └─────────────────┘
```
21.4 Python Implementation
21.4.1 Standard ABC-Based Implementation
```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Generic, List, Optional, Tuple, TypeVar

T = TypeVar('T')  # input type of a strategy
R = TypeVar('R')  # result type of a strategy


class Strategy(ABC, Generic[T, R]):
    """Abstract strategy: one interchangeable algorithm from a family."""

    @abstractmethod
    def execute(self, data: T) -> R:
        pass

    @property
    @abstractmethod
    def name(self) -> str:
        pass

    @property
    def description(self) -> str:
        return ""

    @property
    def complexity(self) -> str:
        return "O(n)"


class StrategyRegistry(Generic[T, R]):
    """Looks up strategies by name and remembers a default."""

    def __init__(self) -> None:
        self._strategies: Dict[str, Strategy[T, R]] = {}
        self._default: Optional[str] = None

    def register(self, strategy: Strategy[T, R], is_default: bool = False) -> None:
        self._strategies[strategy.name] = strategy
        # The first registered strategy becomes the default unless another is flagged
        if is_default or self._default is None:
            self._default = strategy.name

    def get(self, name: Optional[str] = None) -> Optional[Strategy[T, R]]:
        key = name or self._default
        return self._strategies.get(key) if key else None

    def get_all(self) -> Dict[str, Strategy[T, R]]:
        return self._strategies.copy()

    def get_names(self) -> List[str]:
        return list(self._strategies.keys())

    def unregister(self, name: str) -> bool:
        if name not in self._strategies:
            return False
        del self._strategies[name]
        if self._default == name:
            # Fall back to any remaining strategy, or None if the registry is empty
            self._default = next(iter(self._strategies), None)
        return True


class Context(Generic[T, R]):
    """Holds the current strategy and delegates work to it."""

    def __init__(self, strategy: Optional[Strategy[T, R]] = None) -> None:
        self._strategy: Optional[Strategy[T, R]] = strategy
        self._registry: StrategyRegistry[T, R] = StrategyRegistry()
        self._history: List[Tuple[str, T, R]] = []

    @property
    def strategy(self) -> Optional[Strategy[T, R]]:
        return self._strategy

    @strategy.setter
    def strategy(self, strategy: Strategy[T, R]) -> None:
        self._strategy = strategy

    def set_strategy_by_name(self, name: str) -> bool:
        strategy = self._registry.get(name)
        if strategy is None:
            return False
        self._strategy = strategy
        return True

    def register_strategy(self, strategy: Strategy[T, R], is_default: bool = False) -> None:
        self._registry.register(strategy, is_default)

    def execute(self, data: T) -> R:
        # Use the explicitly chosen strategy, otherwise the registry's default
        strategy = self._strategy or self._registry.get()
        if strategy is None:
            raise ValueError("No strategy set")
        result = strategy.execute(data)
        self._history.append((strategy.name, data, result))
        return result

    def get_available_strategies(self) -> List[str]:
        return self._registry.get_names()

    def get_history(self) -> List[Tuple[str, T, R]]:
        return self._history.copy()


@dataclass
class SortData:
    items: List[int]


@dataclass
class SortResult:
    sorted_items: List[int]
    comparisons: int
    swaps: int


class BubbleSortStrategy(Strategy[SortData, SortResult]):
    @property
    def name(self) -> str:
        return "bubble_sort"

    @property
    def description(self) -> str:
        return "Bubble sort: simple but inefficient"

    @property
    def complexity(self) -> str:
        return "O(n²)"

    def execute(self, data: SortData) -> SortResult:
        items = data.items.copy()  # never mutate the caller's list
        n = len(items)
        comparisons = 0
        swaps = 0
        for i in range(n):
            # After pass i, the largest i + 1 elements are in their final positions
            for j in range(0, n - i - 1):
                comparisons += 1
                if items[j] > items[j + 1]:
                    items[j], items[j + 1] = items[j + 1], items[j]
                    swaps += 1
        return SortResult(sorted_items=items, comparisons=comparisons, swaps=swaps)


class QuickSortStrategy(Strategy[SortData, SortResult]):
    @property
    def name(self) -> str:
        return "quick_sort"

    @property
    def description(self) -> str:
        return "Quick sort: efficient divide-and-conquer sort"

    @property
    def complexity(self) -> str:
        return "O(n log n) average, O(n²) worst case"

    def execute(self, data: SortData) -> SortResult:
        items = data.items.copy()
        comparisons = 0
        swaps = 0

        def partition(arr: List[int], low: int, high: int) -> int:
            # Lomuto partition with the last element as the pivot
            nonlocal comparisons, swaps
            pivot = arr[high]
            i = low - 1
            for j in range(low, high):
                comparisons += 1
                if arr[j] <= pivot:
                    i += 1
                    arr[i], arr[j] = arr[j], arr[i]
                    swaps += 1
            arr[i + 1], arr[high] = arr[high], arr[i + 1]
            swaps += 1
            return i + 1

        def quicksort(arr: List[int], low: int, high: int) -> None:
            if low < high:
                pi = partition(arr, low, high)
                quicksort(arr, low, pi - 1)
                quicksort(arr, pi + 1, high)

        quicksort(items, 0, len(items) - 1)
        return SortResult(sorted_items=items, comparisons=comparisons, swaps=swaps)


class MergeSortStrategy(Strategy[SortData, SortResult]):
    @property
    def name(self) -> str:
        return "merge_sort"

    @property
    def description(self) -> str:
        return "Merge sort: stable divide-and-conquer sort"

    @property
    def complexity(self) -> str:
        return "O(n log n)"

    def execute(self, data: SortData) -> SortResult:
        comparisons = 0

        def merge(left: List[int], right: List[int]) -> List[int]:
            nonlocal comparisons
            result = []
            i = j = 0
            while i < len(left) and j < len(right):
                comparisons += 1
                # "<=" keeps equal elements in their original order (stability)
                if left[i] <= right[j]:
                    result.append(left[i])
                    i += 1
                else:
                    result.append(right[j])
                    j += 1
            result.extend(left[i:])
            result.extend(right[j:])
            return result

        def mergesort(arr: List[int]) -> List[int]:
            if len(arr) <= 1:
                return arr
            mid = len(arr) // 2
            return merge(mergesort(arr[:mid]), mergesort(arr[mid:]))

        sorted_items = mergesort(data.items.copy())
        # Merge sort builds new lists instead of swapping, so swaps is always 0
        return SortResult(sorted_items=sorted_items, comparisons=comparisons, swaps=0)


context = Context[SortData, SortResult]()
context.register_strategy(BubbleSortStrategy())
context.register_strategy(QuickSortStrategy(), is_default=True)
context.register_strategy(MergeSortStrategy())
print(f"Available strategies: {context.get_available_strategies()}")

data = SortData([64, 34, 25, 12, 22, 11, 90])
for strategy_name in context.get_available_strategies():
    context.set_strategy_by_name(strategy_name)
    result = context.execute(data)
    print(f"\n{strategy_name}:")
    print(f"  Sorted: {result.sorted_items}")
    print(f"  Comparisons: {result.comparisons}")
    print(f"  Swaps: {result.swaps}")
```
21.4.2 Functional Strategy Implementation
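Because Python functions are first-class objects, the smallest functional form of the pattern is a dictionary of callables, with `functools.partial` binding configuration such as a tax rate in advance. A minimal sketch before the decorator-based registry; the names `rate_tax`, `TAX_STRATEGIES`, and `compute_tax` and the rates are illustrative:

```python
from functools import partial
from typing import Callable, Dict


def rate_tax(amount: float, rate: float) -> float:
    return amount * rate


# Each entry is a ready-to-call strategy; partial() fixes its configuration
TAX_STRATEGIES: Dict[str, Callable[[float], float]] = {
    "simple": partial(rate_tax, rate=0.1),
    "flat": partial(rate_tax, rate=0.2),
}


def compute_tax(amount: float, strategy: str = "simple") -> float:
    return TAX_STRATEGIES[strategy](amount)


print(compute_tax(50000))          # 5000.0
print(compute_tax(50000, "flat"))  # 10000.0
```

The trade-off versus the class-based version is that metadata such as a description or complexity has to live elsewhere, which is what the registry below adds.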
```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple


@dataclass
class StrategyInfo:
    name: str
    func: Callable
    description: str = ""
    complexity: str = ""


class FunctionalStrategyRegistry:
    """Registers plain functions as strategies through a decorator."""

    def __init__(self) -> None:
        self._strategies: Dict[str, StrategyInfo] = {}
        self._default: Optional[str] = None

    def __contains__(self, name: str) -> bool:
        return name in self._strategies

    def strategy(
        self,
        name: Optional[str] = None,
        description: str = "",
        complexity: str = "",
        is_default: bool = False
    ) -> Callable:
        def decorator(func: Callable) -> Callable:
            strategy_name = name or func.__name__
            self._strategies[strategy_name] = StrategyInfo(
                name=strategy_name,
                func=func,
                description=description,
                complexity=complexity
            )
            if is_default or self._default is None:
                self._default = strategy_name
            # Returned unchanged, so the function can still be called directly
            return func
        return decorator

    def get(self, name: Optional[str] = None) -> Optional[Callable]:
        key = name or self._default
        info = self._strategies.get(key) if key else None
        return info.func if info else None

    def get_info(self, name: str) -> Optional[StrategyInfo]:
        return self._strategies.get(name)

    def list_strategies(self) -> List[StrategyInfo]:
        return list(self._strategies.values())


class FunctionalContext:
    def __init__(self, registry: Optional[FunctionalStrategyRegistry] = None) -> None:
        self._registry = registry or FunctionalStrategyRegistry()
        self._current_strategy: Optional[str] = None

    @property
    def registry(self) -> FunctionalStrategyRegistry:
        return self._registry

    def use(self, strategy_name: str) -> 'FunctionalContext':
        if strategy_name not in self._registry:
            raise ValueError(f"Unknown strategy: {strategy_name}")
        self._current_strategy = strategy_name
        return self  # enables chaining: context.use(...).execute(...)

    def execute(self, *args: Any, **kwargs: Any) -> Any:
        # With no explicit choice, the registry falls back to its default strategy
        func = self._registry.get(self._current_strategy)
        if func is None:
            raise ValueError("No strategy available")
        return func(*args, **kwargs)


registry = FunctionalStrategyRegistry()


@registry.strategy(name="simple_tax", description="Simple rate tax", complexity="O(1)", is_default=True)
def calculate_simple_tax(amount: float, rate: float = 0.1) -> float:
    return amount * rate


@registry.strategy(name="progressive_tax", description="Progressive (bracketed) tax", complexity="O(n)")
def calculate_progressive_tax(
    amount: float,
    brackets: Optional[List[Tuple[float, float]]] = None
) -> float:
    # Each bracket is (lower bound, marginal rate) and ends where the next one starts;
    # the last bracket is open-ended, so income above it is still taxed
    if brackets is None:
        brackets = [(0, 0.1), (10000, 0.15), (30000, 0.2), (50000, 0.25)]
    tax = 0.0
    for idx, (lower, rate) in enumerate(brackets):
        if amount <= lower:
            break
        upper = brackets[idx + 1][0] if idx + 1 < len(brackets) else float("inf")
        tax += (min(amount, upper) - lower) * rate
    return tax


@registry.strategy(name="flat_tax", description="Flat rate tax", complexity="O(1)")
def calculate_flat_tax(amount: float, rate: float = 0.2) -> float:
    return amount * rate


context = FunctionalContext(registry)
print("Available strategies:")
for info in registry.list_strategies():
    print(f"  - {info.name}: {info.description} ({info.complexity})")

amount = 50000
print(f"\nIncome: ¥{amount}")
print(f"Simple tax: ¥{context.use('simple_tax').execute(amount)}")
print(f"Progressive tax: ¥{context.use('progressive_tax').execute(amount)}")
print(f"Flat tax: ¥{context.use('flat_tax').execute(amount)}")
```
21.4.3 Strategy Factory Pattern
```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Callable, Dict, List, Optional, Type


class StrategyCategory(Enum):
    SORTING = auto()
    SEARCH = auto()
    ENCRYPTION = auto()
    COMPRESSION = auto()
    VALIDATION = auto()


@dataclass
class StrategyMetadata:
    name: str
    category: StrategyCategory
    description: str
    complexity: str
    stable: bool = True
    in_place: bool = False


class StrategyFactory:
    # Class-level registries, shared by every user of the factory
    _registry: Dict[StrategyCategory, Dict[str, Type]] = {}
    _metadata: Dict[str, StrategyMetadata] = {}

    @classmethod
    def register(
        cls,
        category: StrategyCategory,
        name: Optional[str] = None,
        description: str = "",
        complexity: str = "",
        stable: bool = True,
        in_place: bool = False
    ) -> Callable[[Type], Type]:
        """Class decorator that records a strategy class and its metadata."""
        def decorator(strategy_class: Type) -> Type:
            strategy_name = name or strategy_class.__name__
            cls._registry.setdefault(category, {})[strategy_name] = strategy_class
            cls._metadata[strategy_name] = StrategyMetadata(
                name=strategy_name,
                category=category,
                description=description,
                complexity=complexity,
                stable=stable,
                in_place=in_place
            )
            return strategy_class
        return decorator

    @classmethod
    def create(cls, category: StrategyCategory, name: str, *args: Any, **kwargs: Any) -> Any:
        if category not in cls._registry:
            raise ValueError(f"Unknown category: {category}")
        if name not in cls._registry[category]:
            raise ValueError(f"Unknown strategy: {name}")
        strategy_class = cls._registry[category][name]
        # Constructor arguments are forwarded, e.g. pivot_strategy for QuickSort
        return strategy_class(*args, **kwargs)

    @classmethod
    def get_available(cls, category: StrategyCategory) -> List[str]:
        return list(cls._registry.get(category, {}).keys())

    @classmethod
    def get_metadata(cls, name: str) -> Optional[StrategyMetadata]:
        return cls._metadata.get(name)

    @classmethod
    def get_all_metadata(cls, category: StrategyCategory) -> List[StrategyMetadata]:
        return [cls._metadata[name] for name in cls._registry.get(category, {})]


@StrategyFactory.register(
    StrategyCategory.SORTING,
    name="quick_sort",
    description="Quick sort",
    complexity="O(n log n) average",
    stable=False,
    in_place=False  # this list-comprehension version builds new lists
)
class QuickSort:
    def __init__(self, pivot_strategy: str = "middle"):
        self.pivot_strategy = pivot_strategy

    def sort(self, data: List[int]) -> List[int]:
        if len(data) <= 1:
            return list(data)
        if self.pivot_strategy == "first":
            pivot = data[0]
        elif self.pivot_strategy == "last":
            pivot = data[-1]
        else:
            pivot = data[len(data) // 2]
        # Three-way partition keeps all copies of the pivot together
        left = [x for x in data if x < pivot]
        middle = [x for x in data if x == pivot]
        right = [x for x in data if x > pivot]
        return self.sort(left) + middle + self.sort(right)


@StrategyFactory.register(
    StrategyCategory.SORTING,
    name="merge_sort",
    description="Merge sort",
    complexity="O(n log n)",
    stable=True,
    in_place=False
)
class MergeSort:
    def sort(self, data: List[int]) -> List[int]:
        if len(data) <= 1:
            return list(data)
        mid = len(data) // 2
        left = self.sort(data[:mid])
        right = self.sort(data[mid:])
        return self._merge(left, right)

    def _merge(self, left: List[int], right: List[int]) -> List[int]:
        result = []
        i = j = 0
        while i < len(left) and j < len(right):
            if left[i] <= right[j]:
                result.append(left[i])
                i += 1
            else:
                result.append(right[j])
                j += 1
        result.extend(left[i:])
        result.extend(right[j:])
        return result


@StrategyFactory.register(
    StrategyCategory.SORTING,
    name="heap_sort",
    description="Heap sort",
    complexity="O(n log n)",
    stable=False,
    in_place=True
)
class HeapSort:
    def sort(self, data: List[int]) -> List[int]:
        arr = data.copy()  # sorts a copy in place, leaving the input untouched
        n = len(arr)
        # Build a max-heap bottom-up
        for i in range(n // 2 - 1, -1, -1):
            self._heapify(arr, n, i)
        # Repeatedly move the maximum to the end and restore the heap
        for i in range(n - 1, 0, -1):
            arr[0], arr[i] = arr[i], arr[0]
            self._heapify(arr, i, 0)
        return arr

    def _heapify(self, arr: List[int], n: int, i: int) -> None:
        largest = i
        left = 2 * i + 1
        right = 2 * i + 2
        if left < n and arr[left] > arr[largest]:
            largest = left
        if right < n and arr[right] > arr[largest]:
            largest = right
        if largest != i:
            arr[i], arr[largest] = arr[largest], arr[i]
            self._heapify(arr, n, largest)


print("Sorting strategies:")
for metadata in StrategyFactory.get_all_metadata(StrategyCategory.SORTING):
    print(f"  - {metadata.name}: {metadata.description}")
    print(f"    complexity: {metadata.complexity}, stable: {metadata.stable}, in-place: {metadata.in_place}")

data = [64, 34, 25, 12, 22, 11, 90]
for strategy_name in StrategyFactory.get_available(StrategyCategory.SORTING):
    strategy = StrategyFactory.create(StrategyCategory.SORTING, strategy_name)
    result = strategy.sort(data)
    print(f"\n{strategy_name}: {result}")
```
21.4.4 Composite Strategy Pattern
```python
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, Generic, List, Optional, TypeVar

# Strategy is the abstract base class defined in 21.4.1

T = TypeVar('T')
R = TypeVar('R')


class Combiner(ABC, Generic[R]):
    """Merges the results of several strategies into one value."""

    @abstractmethod
    def combine(self, results: List[R]) -> Any:
        pass


class FirstResultCombiner(Combiner[Any]):
    def combine(self, results: List[Any]) -> Any:
        return results[0] if results else None


class LastResultCombiner(Combiner[Any]):
    def combine(self, results: List[Any]) -> Any:
        return results[-1] if results else None


class ListCombiner(Combiner[Any]):
    def combine(self, results: List[Any]) -> List[Any]:
        return results


class VotingCombiner(Combiner[Any]):
    """Returns the most frequent result; ties go to the result seen first."""

    def combine(self, results: List[Any]) -> Any:
        if not results:
            return None
        counts: Dict[str, int] = {}
        first_seen: Dict[str, Any] = {}
        for result in results:
            key = repr(result)  # also works for unhashable results such as lists
            counts[key] = counts.get(key, 0) + 1
            first_seen.setdefault(key, result)
        winner = max(counts, key=lambda k: counts[k])
        return first_seen[winner]  # the original result, not its string form


class CompositeStrategy(Strategy[T, Any], Generic[T, R]):
    """Runs several strategies on the same input and combines their results."""

    def __init__(
        self,
        strategies: Optional[List[Strategy[T, R]]] = None,
        combiner: Optional[Combiner[R]] = None
    ):
        self._strategies: List[Strategy[T, R]] = list(strategies or [])
        self._combiner: Combiner = combiner or ListCombiner()

    @property
    def name(self) -> str:
        return "composite"

    def add_strategy(self, strategy: Strategy[T, R]) -> None:
        self._strategies.append(strategy)

    def remove_strategy(self, strategy: Strategy[T, R]) -> None:
        if strategy in self._strategies:
            self._strategies.remove(strategy)

    def execute(self, data: T) -> Any:
        results = [s.execute(data) for s in self._strategies]
        return self._combiner.combine(results)


@dataclass
class ValidationResult:
    is_valid: bool
    errors: List[str] = field(default_factory=list)


class ValidationStrategy(Strategy[str, ValidationResult]):
    """Base class for string validators."""


class LengthValidationStrategy(ValidationStrategy):
    def __init__(self, min_length: int = 0, max_length: int = 100):
        self._min = min_length
        self._max = max_length

    @property
    def name(self) -> str:
        return "length"

    def execute(self, data: str) -> ValidationResult:
        errors = []
        if len(data) < self._min:
            errors.append(f"shorter than {self._min} characters")
        if len(data) > self._max:
            errors.append(f"longer than {self._max} characters")
        return ValidationResult(len(errors) == 0, errors)


class PatternValidationStrategy(ValidationStrategy):
    def __init__(self, pattern: str, description: str = ""):
        self._pattern = re.compile(pattern)
        self._description = description

    @property
    def name(self) -> str:
        return "pattern"

    def execute(self, data: str) -> ValidationResult:
        if self._pattern.match(data):
            return ValidationResult(True)
        return ValidationResult(False, [f"does not satisfy rule: {self._description}"])


class AllValidCombiner(Combiner[ValidationResult]):
    """Valid only if every validator passed; collects all error messages."""

    def combine(self, results: List[ValidationResult]) -> ValidationResult:
        all_errors: List[str] = []
        for r in results:
            all_errors.extend(r.errors)
        return ValidationResult(all(r.is_valid for r in results), all_errors)


class CompositeValidationStrategy(CompositeStrategy[str, ValidationResult]):
    # Inherits execute(): runs every validator, then merges with AllValidCombiner
    def __init__(self, strategies: Optional[List[ValidationStrategy]] = None):
        super().__init__(strategies, AllValidCombiner())

    @property
    def name(self) -> str:
        return "composite_validation"


validator = CompositeValidationStrategy([
    LengthValidationStrategy(min_length=8, max_length=20),
    PatternValidationStrategy(r'.*[A-Z].*', "uppercase letter"),
    PatternValidationStrategy(r'.*[0-9].*', "digit"),
])

test_passwords = ["weak", "StrongPass123", "short", "ValidPass99"]
for password in test_passwords:
    result = validator.execute(password)
    status = "valid" if result.is_valid else "invalid"
    print(f"'{password}': {status}")
    if result.errors:
        print(f"  errors: {', '.join(result.errors)}")
```
21.4.5 Strategy Caching Pattern
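When a strategy is a pure function of hashable arguments, the standard library's `functools.lru_cache` already provides a bounded least-recently-used cache and reports hits and misses through `cache_info()`. It has no time-based expiry, which is what the hand-written `CachedStrategy` below adds. A sketch with a hypothetical `slow_square` strategy:

```python
import time
from functools import lru_cache


@lru_cache(maxsize=10)
def slow_square(n: int) -> int:
    time.sleep(0.01)  # stands in for an expensive computation
    return n * n


for i in range(5):
    slow_square(i)  # first call for each argument: 5 misses
for i in range(5):
    slow_square(i)  # repeated arguments: 5 hits

info = slow_square.cache_info()
print(info.hits, info.misses, info.currsize)  # 5 5 5
```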
```python
import hashlib
import json
import time
from dataclasses import dataclass
from typing import Any, Dict, Optional, TypeVar

# Strategy is the abstract base class defined in 21.4.1

T = TypeVar('T')
R = TypeVar('R')


@dataclass
class CacheEntry:
    result: Any
    timestamp: float
    hits: int = 0


class CachedStrategy(Strategy[T, R]):
    """Wraps another strategy and caches its results with a size limit and a TTL."""

    def __init__(
        self,
        strategy: Strategy[T, R],
        max_size: int = 100,
        ttl: float = 3600
    ):
        self._strategy = strategy
        self._max_size = max_size
        self._ttl = ttl  # seconds
        self._cache: Dict[str, CacheEntry] = {}
        self._hits = 0
        self._misses = 0

    @property
    def name(self) -> str:
        return f"cached_{self._strategy.name}"

    @property
    def strategy(self) -> Strategy[T, R]:
        return self._strategy

    def _make_key(self, data: T) -> Optional[str]:
        # Prefix the type name so that 1 and "1" do not share a cache entry
        if isinstance(data, (str, int, float, bool)):
            return f"{type(data).__name__}:{data}"
        try:
            payload = json.dumps(data, sort_keys=True)
        except (TypeError, ValueError):
            # Not JSON-serializable: skip caching rather than key on id(data),
            # which a different object may reuse after garbage collection
            return None
        return hashlib.sha256(payload.encode()).hexdigest()

    def execute(self, data: T) -> R:
        key = self._make_key(data)
        if key is None:
            self._misses += 1
            return self._strategy.execute(data)

        now = time.monotonic()  # a monotonic clock is safe for measuring TTLs
        entry = self._cache.get(key)
        if entry is not None:
            if now - entry.timestamp < self._ttl:
                entry.hits += 1
                self._hits += 1
                return entry.result
            del self._cache[key]  # expired

        result = self._strategy.execute(data)
        self._misses += 1
        if len(self._cache) >= self._max_size:
            self._evict_oldest()
        self._cache[key] = CacheEntry(result=result, timestamp=now)
        return result

    def _evict_oldest(self) -> None:
        # Evicts the entry stored earliest (FIFO by insertion time, not LRU)
        if not self._cache:
            return
        oldest_key = min(self._cache, key=lambda k: self._cache[k].timestamp)
        del self._cache[oldest_key]

    def clear_cache(self) -> None:
        self._cache.clear()
        self._hits = 0
        self._misses = 0

    def get_stats(self) -> Dict[str, Any]:
        total = self._hits + self._misses
        hit_rate = self._hits / total if total > 0 else 0
        return {
            "cache_size": len(self._cache),
            "max_size": self._max_size,
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": f"{hit_rate:.2%}"
        }


class ExpensiveCalculationStrategy(Strategy[int, int]):
    @property
    def name(self) -> str:
        return "expensive_calc"

    def execute(self, data: int) -> int:
        time.sleep(0.01)  # simulates an expensive computation
        return data * data


cached_strategy = CachedStrategy(ExpensiveCalculationStrategy(), max_size=10)
print("Computing...")
for i in range(5):
    result = cached_strategy.execute(i)
    print(f"  {i}² = {result}")

print("\nRepeating the computation (cache hits)...")
for i in range(5):
    result = cached_strategy.execute(i)
    print(f"  {i}² = {result}")

print(f"\nCache stats: {cached_strategy.get_stats()}")
```
21.5 Enterprise Application Examples
21.5.1 Payment Gateway Strategy
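The gateway below requires the caller to name a gateway explicitly. A common variation, sketched here under the assumption that each strategy exposes the currencies it supports, lets the context pick the first registered gateway that can handle the request's currency. `SimpleGateway` and `GatewaySelector` are hypothetical, simplified stand-ins for the classes that follow.

```python
from dataclasses import dataclass
from decimal import Decimal
from typing import FrozenSet, List, Optional


@dataclass(frozen=True)
class SimpleGateway:
    name: str
    currencies: FrozenSet[str]
    fee_rate: Decimal


class GatewaySelector:
    def __init__(self) -> None:
        self._gateways: List[SimpleGateway] = []  # registration order = priority

    def register(self, gateway: SimpleGateway) -> None:
        self._gateways.append(gateway)

    def select(self, currency: str) -> Optional[SimpleGateway]:
        # The first registered gateway that supports the currency wins
        return next((g for g in self._gateways if currency in g.currencies), None)


selector = GatewaySelector()
selector.register(SimpleGateway("alipay", frozenset({"CNY"}), Decimal("0.006")))
selector.register(SimpleGateway("stripe", frozenset({"USD", "EUR"}), Decimal("0.029")))

for currency in ("CNY", "USD", "CHF"):
    chosen = selector.select(currency)
    print(currency, chosen.name if chosen else None)
```

Returning `None` for an unsupported currency mirrors how the `PaymentGateway` below reports a failed response instead of raising.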
```python
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from decimal import Decimal
from enum import Enum, auto
from typing import Any, Dict, List, Optional


class PaymentStatus(Enum):
    PENDING = auto()
    SUCCESS = auto()
    FAILED = auto()
    CANCELLED = auto()
    REFUNDED = auto()


@dataclass
class PaymentRequest:
    order_id: str
    amount: Decimal  # Decimal avoids binary floating-point rounding for money
    currency: str = "CNY"
    description: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class PaymentResponse:
    transaction_id: str
    status: PaymentStatus
    amount: Decimal
    currency: str
    gateway: str
    message: str = ""
    timestamp: datetime = field(default_factory=datetime.now)
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class RefundResponse:
    refund_id: str
    transaction_id: str
    amount: Decimal
    status: PaymentStatus
    message: str = ""


class PaymentStrategy(ABC):
    @property
    @abstractmethod
    def name(self) -> str:
        pass

    @property
    @abstractmethod
    def supported_currencies(self) -> List[str]:
        pass

    @property
    def fee_rate(self) -> Decimal:
        return Decimal("0")

    @abstractmethod
    def pay(self, request: PaymentRequest) -> PaymentResponse:
        pass

    @abstractmethod
    def refund(self, transaction_id: str, amount: Decimal) -> RefundResponse:
        pass

    def supports_currency(self, currency: str) -> bool:
        return currency in self.supported_currencies

    def calculate_fee(self, amount: Decimal) -> Decimal:
        return amount * self.fee_rate


class AlipayStrategy(PaymentStrategy):
    # Simulated gateway: no real API call is made
    def __init__(self, app_id: str, private_key: str):
        self._app_id = app_id
        self._private_key = private_key

    @property
    def name(self) -> str:
        return "alipay"

    @property
    def supported_currencies(self) -> List[str]:
        return ["CNY"]

    @property
    def fee_rate(self) -> Decimal:
        return Decimal("0.006")

    def pay(self, request: PaymentRequest) -> PaymentResponse:
        transaction_id = f"ALI{uuid.uuid4().hex[:12].upper()}"
        print(f"  [Alipay] Processing payment: {request.amount} {request.currency}")
        return PaymentResponse(
            transaction_id=transaction_id,
            status=PaymentStatus.SUCCESS,
            amount=request.amount,
            currency=request.currency,
            gateway=self.name,
            message="Alipay payment successful",
            metadata={"app_id": self._app_id}
        )

    def refund(self, transaction_id: str, amount: Decimal) -> RefundResponse:
        refund_id = f"ALIR{uuid.uuid4().hex[:12].upper()}"
        print(f"  [Alipay] Processing refund: {amount}")
        return RefundResponse(
            refund_id=refund_id,
            transaction_id=transaction_id,
            amount=amount,
            status=PaymentStatus.REFUNDED,
            message="Refund successful"
        )


class WeChatPayStrategy(PaymentStrategy):
    # Simulated gateway: no real API call is made
    def __init__(self, app_id: str, mch_id: str, api_key: str):
        self._app_id = app_id
        self._mch_id = mch_id
        self._api_key = api_key

    @property
    def name(self) -> str:
        return "wechat"

    @property
    def supported_currencies(self) -> List[str]:
        return ["CNY"]

    @property
    def fee_rate(self) -> Decimal:
        return Decimal("0.006")

    def pay(self, request: PaymentRequest) -> PaymentResponse:
        transaction_id = f"WX{uuid.uuid4().hex[:12].upper()}"
        print(f"  [WeChat Pay] Processing payment: {request.amount} {request.currency}")
        return PaymentResponse(
            transaction_id=transaction_id,
            status=PaymentStatus.SUCCESS,
            amount=request.amount,
            currency=request.currency,
            gateway=self.name,
            message="WeChat Pay payment successful",
            metadata={"mch_id": self._mch_id}
        )

    def refund(self, transaction_id: str, amount: Decimal) -> RefundResponse:
        refund_id = f"WXR{uuid.uuid4().hex[:12].upper()}"
        print(f"  [WeChat Pay] Processing refund: {amount}")
        return RefundResponse(
            refund_id=refund_id,
            transaction_id=transaction_id,
            amount=amount,
            status=PaymentStatus.REFUNDED,
            message="Refund successful"
        )


class StripeStrategy(PaymentStrategy):
    # Simulated gateway: no real API call is made
    def __init__(self, api_key: str):
        self._api_key = api_key

    @property
    def name(self) -> str:
        return "stripe"

    @property
    def supported_currencies(self) -> List[str]:
        return ["USD", "EUR", "GBP", "JPY"]

    @property
    def fee_rate(self) -> Decimal:
        return Decimal("0.029")

    def pay(self, request: PaymentRequest) -> PaymentResponse:
        transaction_id = f"STRIPE{uuid.uuid4().hex[:12].upper()}"
        print(f"  [Stripe] Processing payment: {request.amount} {request.currency}")
        return PaymentResponse(
            transaction_id=transaction_id,
            status=PaymentStatus.SUCCESS,
            amount=request.amount,
            currency=request.currency,
            gateway=self.name,
            message="Stripe payment successful",
            # Only a short prefix is recorded, never the full secret key
            metadata={"api_key_prefix": self._api_key[:8]}
        )

    def refund(self, transaction_id: str, amount: Decimal) -> RefundResponse:
        refund_id = f"STRIPER{uuid.uuid4().hex[:12].upper()}"
        print(f"  [Stripe] Processing refund: {amount}")
        return RefundResponse(
            refund_id=refund_id,
            transaction_id=transaction_id,
            amount=amount,
            status=PaymentStatus.REFUNDED,
            message="Refund successful"
        )


class PaymentGateway:
    """Context: routes each request to the selected payment strategy."""

    def __init__(self) -> None:
        self._strategies: Dict[str, PaymentStrategy] = {}
        self._default_strategy: Optional[str] = None

    def register(self, strategy: PaymentStrategy, is_default: bool = False) -> None:
        self._strategies[strategy.name] = strategy
        if is_default or self._default_strategy is None:
            self._default_strategy = strategy.name

    def get_strategy(self, name: Optional[str] = None) -> Optional[PaymentStrategy]:
        key = name or self._default_strategy
        return self._strategies.get(key) if key else None

    def get_available_gateways(self) -> List[str]:
        return list(self._strategies.keys())

    def get_supported_currencies(self, gateway: str) -> List[str]:
        strategy = self._strategies.get(gateway)
        return strategy.supported_currencies if strategy else []

    def pay(
        self,
        request: PaymentRequest,
        gateway: Optional[str] = None
    ) -> PaymentResponse:
        strategy = self.get_strategy(gateway)
        if not strategy:
            return PaymentResponse(
                transaction_id="",
                status=PaymentStatus.FAILED,
                amount=request.amount,
                currency=request.currency,
                gateway="unknown",
                message="Payment gateway unavailable"
            )
        # Reject currencies the selected gateway cannot settle
        if not strategy.supports_currency(request.currency):
            return PaymentResponse(
                transaction_id="",
                status=PaymentStatus.FAILED,
                amount=request.amount,
                currency=request.currency,
                gateway=strategy.name,
                message=f"Unsupported currency: {request.currency}"
            )
        return strategy.pay(request)

    def refund(
        self,
        transaction_id: str,
        amount: Decimal,
        gateway: str
    ) -> RefundResponse:
        strategy = self._strategies.get(gateway)
        if not strategy:
            return RefundResponse(
                refund_id="",
                transaction_id=transaction_id,
                amount=amount,
                status=PaymentStatus.FAILED,
                message="Payment gateway unavailable"
            )
        return strategy.refund(transaction_id, amount)

    def calculate_fee(self, amount: Decimal, gateway: Optional[str] = None) -> Decimal:
        strategy = self.get_strategy(gateway)
        return strategy.calculate_fee(amount) if strategy else Decimal("0")


gateway = PaymentGateway()
gateway.register(AlipayStrategy("alipay_app_id", "private_key"), is_default=True)
gateway.register(WeChatPayStrategy("wx_app_id", "wx_mch_id", "wx_api_key"))
gateway.register(StripeStrategy("sk_test_xxx"))

print("Available payment gateways:")
for name in gateway.get_available_gateways():
    strategy = gateway.get_strategy(name)
    print(f"  - {name}: currencies {strategy.supported_currencies}, fee rate {strategy.fee_rate}")

print("\n=== Domestic payment ===")
request1 = PaymentRequest(
    order_id="ORD001",
    amount=Decimal("99.00"),
    currency="CNY",
    description="Test order"
)
response1 = gateway.pay(request1, "alipay")
print(f"Payment result: {response1.status.name}, transaction: {response1.transaction_id}")
print(f"Fee: {gateway.calculate_fee(request1.amount, 'alipay')}")

print("\n=== International payment ===")
request2 = PaymentRequest(
    order_id="ORD002",
    amount=Decimal("19.99"),
    currency="USD",
    description="International order"
)
response2 = gateway.pay(request2, "stripe")
print(f"Payment result: {response2.status.name}, Transaction: {response2.transaction_id}")
print(f"Fee: {gateway.calculate_fee(request2.amount, 'stripe')}")
```
21.5.2 Data Export Strategy
```python
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, TypeVar, Generic
from dataclasses import dataclass
from datetime import datetime
from decimal import Decimal
from io import StringIO
from xml.sax.saxutils import escape
import json
import csv

T = TypeVar('T')


@dataclass
class ExportConfig:
    include_headers: bool = True
    date_format: str = "%Y-%m-%d %H:%M:%S"
    encoding: str = "utf-8"
    delimiter: str = ","
    sheet_name: str = "Sheet1"


@dataclass
class ExportResult:
    content: bytes
    filename: str
    content_type: str
    size: int

    @property
    def size_kb(self) -> float:
        return self.size / 1024


class ExportStrategy(ABC, Generic[T]):
    @property
    @abstractmethod
    def name(self) -> str:
        pass

    @property
    @abstractmethod
    def file_extension(self) -> str:
        pass

    @property
    @abstractmethod
    def content_type(self) -> str:
        pass

    @abstractmethod
    def export(self, data: List[T], config: Optional[ExportConfig] = None) -> ExportResult:
        pass

    def _get_field_value(self, obj: Any, field_name: str) -> Any:
        # Works for both dict records and plain objects
        if isinstance(obj, dict):
            return obj.get(field_name, "")
        return getattr(obj, field_name, "")

    def _empty_result(self) -> ExportResult:
        return ExportResult(content=b"", filename=f"export{self.file_extension}",
                            content_type=self.content_type, size=0)

    def _build_result(self, content: bytes) -> ExportResult:
        # Shared by every format: timestamped filename plus metadata
        filename = f"export_{datetime.now().strftime('%Y%m%d_%H%M%S')}{self.file_extension}"
        return ExportResult(content=content, filename=filename,
                            content_type=self.content_type, size=len(content))


class CSVExportStrategy(ExportStrategy[Dict[str, Any]]):
    @property
    def name(self) -> str:
        return "csv"

    @property
    def file_extension(self) -> str:
        return ".csv"

    @property
    def content_type(self) -> str:
        return "text/csv"

    def export(self, data: List[Dict[str, Any]],
               config: Optional[ExportConfig] = None) -> ExportResult:
        config = config or ExportConfig()
        if not data:
            return self._empty_result()
        output = StringIO()
        # Column order is taken from the first record
        fieldnames = list(data[0].keys())
        writer = csv.DictWriter(output, fieldnames=fieldnames, delimiter=config.delimiter)
        if config.include_headers:
            writer.writeheader()
        for row in data:
            # Format datetime values; everything else is written as-is
            writer.writerow({
                key: value.strftime(config.date_format) if isinstance(value, datetime) else value
                for key, value in row.items()
            })
        return self._build_result(output.getvalue().encode(config.encoding))


class JSONExportStrategy(ExportStrategy[Dict[str, Any]]):
    @property
    def name(self) -> str:
        return "json"

    @property
    def file_extension(self) -> str:
        return ".json"

    @property
    def content_type(self) -> str:
        return "application/json"

    def export(self, data: List[Dict[str, Any]],
               config: Optional[ExportConfig] = None) -> ExportResult:
        config = config or ExportConfig()

        def serialize(obj: Any) -> str:
            # json.dumps calls this only for types it cannot encode natively
            if isinstance(obj, datetime):
                return obj.strftime(config.date_format)
            if isinstance(obj, Decimal):
                return str(obj)
            raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

        content = json.dumps(data, ensure_ascii=False, indent=2,
                             default=serialize).encode(config.encoding)
        return self._build_result(content)


class XMLExportStrategy(ExportStrategy[Dict[str, Any]]):
    @property
    def name(self) -> str:
        return "xml"

    @property
    def file_extension(self) -> str:
        return ".xml"

    @property
    def content_type(self) -> str:
        return "application/xml"

    def export(self, data: List[Dict[str, Any]],
               config: Optional[ExportConfig] = None) -> ExportResult:
        config = config or ExportConfig()
        # Declare the encoding that is actually used for the byte output
        xml_parts = [f'<?xml version="1.0" encoding="{config.encoding}"?>', '<records>']
        for record in data:
            xml_parts.append('  <record>')
            for key, value in record.items():
                if isinstance(value, datetime):
                    value = value.strftime(config.date_format)
                # Escape &, < and > so values cannot break the document;
                # keys are assumed to be valid XML element names
                xml_parts.append(f'    <{key}>{escape(str(value))}</{key}>')
            xml_parts.append('  </record>')
        xml_parts.append('</records>')
        return self._build_result('\n'.join(xml_parts).encode(config.encoding))


class MarkdownExportStrategy(ExportStrategy[Dict[str, Any]]):
    @property
    def name(self) -> str:
        return "markdown"

    @property
    def file_extension(self) -> str:
        return ".md"

    @property
    def content_type(self) -> str:
        return "text/markdown"

    def export(self, data: List[Dict[str, Any]],
               config: Optional[ExportConfig] = None) -> ExportResult:
        config = config or ExportConfig()
        if not data:
            return self._empty_result()
        fieldnames = list(data[0].keys())
        lines = [
            "| " + " | ".join(fieldnames) + " |",
            "| " + " | ".join(["---"] * len(fieldnames)) + " |",
        ]
        for record in data:
            values = []
            for key in fieldnames:
                value = record.get(key, "")
                if isinstance(value, datetime):
                    value = value.strftime(config.date_format)
                # Escape pipes so a value cannot split a table cell
                values.append(str(value).replace("|", "\\|"))
            lines.append("| " + " | ".join(values) + " |")
        return self._build_result('\n'.join(lines).encode(config.encoding))


class ExportService:
    def __init__(self):
        self._strategies: Dict[str, ExportStrategy] = {}
        self._default: Optional[str] = None

    def register(self, strategy: ExportStrategy, is_default: bool = False) -> None:
        self._strategies[strategy.name] = strategy
        # The first registered strategy becomes the default unless one is marked explicitly
        if is_default or self._default is None:
            self._default = strategy.name

    def get_available_formats(self) -> List[str]:
        return list(self._strategies.keys())

    def export(self, data: List[Dict[str, Any]], format: Optional[str] = None,
               config: Optional[ExportConfig] = None) -> ExportResult:
        format_key = format or self._default
        strategy = self._strategies.get(format_key)
        if not strategy:
            raise ValueError(f"Unknown export format: {format_key}")
        return strategy.export(data, config)


export_service = ExportService()
export_service.register(CSVExportStrategy(), is_default=True)
export_service.register(JSONExportStrategy())
export_service.register(XMLExportStrategy())
export_service.register(MarkdownExportStrategy())

data = [
    {"id": 1, "name": "张三", "email": "zhangsan@example.com", "created_at": datetime.now()},
    {"id": 2, "name": "李四", "email": "lisi@example.com", "created_at": datetime.now()},
    {"id": 3, "name": "王五", "email": "wangwu@example.com", "created_at": datetime.now()},
]

print(f"Available export formats: {export_service.get_available_formats()}")
for format_name in export_service.get_available_formats():
    result = export_service.export(data, format_name)
    print(f"\n=== {format_name.upper()} ===")
    print(f"Filename: {result.filename}")
    print(f"Size: {result.size_kb:.2f} KB")
    print(f"Content type: {result.content_type}")
    # Decode before slicing: cutting raw bytes could split a multi-byte UTF-8 character
    print(f"Preview:\n{result.content.decode('utf-8')[:200]}...")
```
21.5.3 Routing Strategy Pattern
```python
from abc import ABC, abstractmethod
from typing import Dict, List, Any, Optional, Callable, Tuple
from dataclasses import dataclass
from enum import Enum, auto
import re


class HttpMethod(Enum):
    GET = auto()
    POST = auto()
    PUT = auto()
    DELETE = auto()
    PATCH = auto()


@dataclass
class RouteMatch:
    handler: Callable
    params: Dict[str, str]
    route: str


class RouteStrategy(ABC):
    @property
    @abstractmethod
    def name(self) -> str:
        pass

    @abstractmethod
    def add_route(self, method: HttpMethod, path: str, handler: Callable) -> None:
        pass

    @abstractmethod
    def match(self, method: HttpMethod, path: str) -> Optional[RouteMatch]:
        pass

    @abstractmethod
    def get_routes(self) -> List[str]:
        pass


class ExactMatchStrategy(RouteStrategy):
    # Dictionary lookup on the full path: fast, but no path parameters
    def __init__(self):
        self._routes: Dict[HttpMethod, Dict[str, Callable]] = {}

    @property
    def name(self) -> str:
        return "exact"

    def add_route(self, method: HttpMethod, path: str, handler: Callable) -> None:
        self._routes.setdefault(method, {})[path] = handler

    def match(self, method: HttpMethod, path: str) -> Optional[RouteMatch]:
        handler = self._routes.get(method, {}).get(path)
        if handler:
            return RouteMatch(handler=handler, params={}, route=path)
        return None

    def get_routes(self) -> List[str]:
        return [f"{method.name} {path}"
                for method, paths in self._routes.items()
                for path in paths]


class PrefixMatchStrategy(RouteStrategy):
    # Longest-prefix match; the unmatched tail is passed to the handler as `remaining`
    def __init__(self):
        self._routes: Dict[HttpMethod, List[Tuple[str, Callable]]] = {}

    @property
    def name(self) -> str:
        return "prefix"

    def add_route(self, method: HttpMethod, path: str, handler: Callable) -> None:
        routes = self._routes.setdefault(method, [])
        routes.append((path, handler))
        # Keep longer prefixes first so the most specific route wins
        routes.sort(key=lambda x: len(x[0]), reverse=True)

    def match(self, method: HttpMethod, path: str) -> Optional[RouteMatch]:
        for prefix, handler in self._routes.get(method, []):
            if path.startswith(prefix):
                return RouteMatch(handler=handler,
                                  params={"remaining": path[len(prefix):]},
                                  route=prefix)
        return None

    def get_routes(self) -> List[str]:
        return [f"{method.name} {path}*"
                for method, path_list in self._routes.items()
                for path, _ in path_list]


class RegexMatchStrategy(RouteStrategy):
    # Named groups such as (?P<user_id>\d+) become handler keyword arguments
    def __init__(self):
        self._routes: Dict[HttpMethod, List[Tuple[re.Pattern, Callable, str]]] = {}

    @property
    def name(self) -> str:
        return "regex"

    def add_route(self, method: HttpMethod, path: str, handler: Callable) -> None:
        # Anchor the pattern so it must match the whole path
        pattern = re.compile(f"^{path}$")
        self._routes.setdefault(method, []).append((pattern, handler, path))

    def match(self, method: HttpMethod, path: str) -> Optional[RouteMatch]:
        # The first registered pattern that matches wins
        for pattern, handler, route in self._routes.get(method, []):
            m = pattern.match(path)
            if m:
                return RouteMatch(handler=handler, params=m.groupdict(), route=route)
        return None

    def get_routes(self) -> List[str]:
        return [f"{method.name} {path}"
                for method, patterns in self._routes.items()
                for _, _, path in patterns]


class Router:
    def __init__(self, strategy: Optional[RouteStrategy] = None):
        self._strategy = strategy or RegexMatchStrategy()
        self._middleware: List[Callable] = []

    @property
    def strategy(self) -> RouteStrategy:
        return self._strategy

    @strategy.setter
    def strategy(self, strategy: RouteStrategy) -> None:
        # Routes are stored inside the strategy object, so swapping strategies
        # discards every route registered with the previous one
        self._strategy = strategy

    def add_middleware(self, middleware: Callable) -> None:
        self._middleware.append(middleware)

    def route(self, method: HttpMethod, path: str) -> Callable:
        def decorator(handler: Callable) -> Callable:
            self._strategy.add_route(method, path, handler)
            return handler
        return decorator

    def get(self, path: str) -> Callable:
        return self.route(HttpMethod.GET, path)

    def post(self, path: str) -> Callable:
        return self.route(HttpMethod.POST, path)

    def put(self, path: str) -> Callable:
        return self.route(HttpMethod.PUT, path)

    def delete(self, path: str) -> Callable:
        return self.route(HttpMethod.DELETE, path)

    def dispatch(self, method: HttpMethod, path: str) -> Any:
        match = self._strategy.match(method, path)
        if not match:
            return {"error": "Not Found", "path": path}
        for middleware in self._middleware:
            middleware(method, path, match)
        # Captured parameters are passed to the handler as keyword arguments
        return match.handler(**match.params)

    def list_routes(self) -> List[str]:
        return self._strategy.get_routes()


router = Router(RegexMatchStrategy())


@router.get("/users")
def list_users():
    return {"users": ["user1", "user2", "user3"]}


@router.get(r"/users/(?P<user_id>\d+)")
def get_user(user_id: str):
    return {"user_id": user_id, "name": f"User {user_id}"}


@router.post("/users")
def create_user():
    return {"message": "User created"}


@router.get(r"/posts/(?P<post_id>\d+)/comments/(?P<comment_id>\d+)")
def get_comment(post_id: str, comment_id: str):
    return {"post_id": post_id, "comment_id": comment_id}


print(f"Routing strategy: {router.strategy.name}")
print("Registered routes:\n  " + "\n  ".join(router.list_routes()))
print("\n=== Route tests ===")
print(f"GET /users: {router.dispatch(HttpMethod.GET, '/users')}")
print(f"GET /users/123: {router.dispatch(HttpMethod.GET, '/users/123')}")
print(f"POST /users: {router.dispatch(HttpMethod.POST, '/users')}")
print(f"GET /posts/1/comments/5: {router.dispatch(HttpMethod.GET, '/posts/1/comments/5')}")
print(f"GET /unknown: {router.dispatch(HttpMethod.GET, '/unknown')}")

print("\n=== Switching strategy ===")
# The new strategy starts empty: the routes registered above are no longer reachable
router.strategy = ExactMatchStrategy()


@router.get("/api/health")
def health_check():
    return {"status": "ok"}


print(f"Routing strategy: {router.strategy.name}")
print(f"GET /api/health: {router.dispatch(HttpMethod.GET, '/api/health')}")
print(f"GET /api/health/detail: {router.dispatch(HttpMethod.GET, '/api/health/detail')}")
```
21.6 Pattern Variants
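21.6.1 Comparison of Strategy Pattern Variants
| Variant | Characteristics | Use case | Complexity |
|---|---|---|---|
| Classic strategy | One class per strategy | Standard OOP design | ★☆☆ |
| Functional strategy | Functions serve as strategies | Pythonic style, simple cases | ★☆☆ |
| Strategy factory | Strategies are created by a factory | Deferred instantiation is needed | ★★☆ |
| Composite strategy | Several strategies run together | Multiple strategies must cooperate | ★★★ |
| Caching strategy | Caches strategy results | Compute-intensive strategies | ★★☆ |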
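The composite variant in the table can be sketched as a strategy that holds other strategies behind the same interface and runs all of them. This is a minimal illustration, not the chapter's own implementation: `ValidationStrategy`, `NotEmptyRule`, `MaxLengthRule`, and `CompositeValidation` are hypothetical names, and the composite is assumed to collect every child's error messages rather than stop at the first failure.

```python
from abc import ABC, abstractmethod
from typing import List


class ValidationStrategy(ABC):
    """Hypothetical strategy interface: returns a list of error messages."""

    @abstractmethod
    def validate(self, value: str) -> List[str]:
        ...


class NotEmptyRule(ValidationStrategy):
    def validate(self, value: str) -> List[str]:
        return [] if value.strip() else ["value must not be empty"]


class MaxLengthRule(ValidationStrategy):
    def __init__(self, limit: int):
        self.limit = limit  # immutable configuration, not per-call state

    def validate(self, value: str) -> List[str]:
        if len(value) > self.limit:
            return [f"value longer than {self.limit} characters"]
        return []


class CompositeValidation(ValidationStrategy):
    """Composite strategy: same interface as its children, runs all of them."""

    def __init__(self, *rules: ValidationStrategy):
        self.rules = list(rules)

    def validate(self, value: str) -> List[str]:
        errors: List[str] = []
        for rule in self.rules:
            errors.extend(rule.validate(value))
        return errors


validator = CompositeValidation(NotEmptyRule(), MaxLengthRule(5))
print(validator.validate("hello"))     # []
print(validator.validate("   "))       # ['value must not be empty']
print(validator.validate("too long"))  # ['value longer than 5 characters']
```

Because `CompositeValidation` is itself a `ValidationStrategy`, composites can be nested, and a context holding one cannot tell it apart from a single rule.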
21.6.2 Strategy Selector Pattern
```python
from abc import ABC, abstractmethod
from typing import Dict, List, Callable, Any, Optional, Tuple, TypeVar, Generic

T = TypeVar('T')
R = TypeVar('R')


class Strategy(ABC, Generic[T, R]):
    # Same interface as the Strategy template in 21.10.2, repeated so this listing runs on its own
    @abstractmethod
    def execute(self, data: T) -> R:
        pass

    @property
    def name(self) -> str:
        return self.__class__.__name__


class StrategySelector(ABC, Generic[T]):
    @abstractmethod
    def select(self, context: T) -> Optional[str]:
        pass


class RuleBasedSelector(StrategySelector[Dict[str, Any]]):
    def __init__(self):
        self._rules: List[Tuple[Callable[[Dict[str, Any]], bool], str]] = []
        self._default: Optional[str] = None

    def add_rule(self, condition: Callable[[Dict[str, Any]], bool],
                 strategy_name: str) -> None:
        self._rules.append((condition, strategy_name))

    def set_default(self, strategy_name: str) -> None:
        self._default = strategy_name

    def select(self, context: Dict[str, Any]) -> Optional[str]:
        # Rules are checked in registration order; the first match wins
        for condition, strategy_name in self._rules:
            if condition(context):
                return strategy_name
        return self._default


class SmartContext(Generic[T, R]):
    def __init__(self):
        self._strategies: Dict[str, Strategy[T, R]] = {}
        self._selector: Optional[StrategySelector] = None

    def register(self, strategy: Strategy[T, R]) -> None:
        self._strategies[strategy.name] = strategy

    def set_selector(self, selector: StrategySelector) -> None:
        self._selector = selector

    def execute(self, data: T, context: Optional[Dict[str, Any]] = None) -> R:
        strategy_name = None
        if self._selector and context:
            strategy_name = self._selector.select(context)
        # Fall back to the first registered strategy
        if not strategy_name:
            strategy_name = next(iter(self._strategies), None)
        strategy = self._strategies.get(strategy_name)
        if not strategy:
            raise ValueError("No strategy available")
        return strategy.execute(data)


selector = RuleBasedSelector()
# Hard constraints go first: otherwise a mid-sized input that requires a
# stable sort would be routed to quick sort, which is not stable
selector.add_rule(lambda ctx: ctx.get("stable_required", False), "merge_sort")
selector.add_rule(lambda ctx: ctx.get("data_size", 0) < 100, "bubble_sort")
selector.add_rule(lambda ctx: ctx.get("data_size", 0) < 10000, "quick_sort")
selector.set_default("quick_sort")

print("Strategy selection rules (checked in order):")
print("  - Stable sort required: merge sort")
print("  - Data size < 100: bubble sort")
print("  - Data size < 10000: quick sort")
print("  - Default: quick sort")

test_contexts = [
    {"data_size": 50},
    {"data_size": 5000},
    {"data_size": 50000, "stable_required": True},
    {"data_size": 100000},
]
for ctx in test_contexts:
    selected = selector.select(ctx)
    print(f"\nContext {ctx} -> selected strategy: {selected}")
```
21.7 Anti-Patterns and Best Practices
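21.7.1 Common Anti-Patterns
Anti-pattern 1: Strategy bloat
```python
# Bad: the context hard-codes every concrete strategy (StrategyA to StrategyE are
# placeholders), so each new strategy means editing this class, and all of them
# are instantiated up front whether or not they are used
class BadExample:
    def __init__(self):
        self._strategies = {
            "strategy_a": StrategyA(),
            "strategy_b": StrategyB(),
            "strategy_c": StrategyC(),
            "strategy_d": StrategyD(),
            "strategy_e": StrategyE(),
        }


# Good: strategies are looked up in a registry (StrategyRegistry is a placeholder),
# so new strategies can be added without modifying this class
class GoodExample:
    def __init__(self):
        self._registry = StrategyRegistry()

    def get_strategy(self, name: str):
        return self._registry.get(name)
```
Anti-pattern 2: Strategy leakage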
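```python
# Bad: the context writes its own data into the strategy's internals, coupling the
# two and leaving hidden state in a strategy instance that may be shared
class BadExample:
    def __init__(self, strategy):
        self._strategy = strategy
        self._strategy._internal_state = {}


# Good: context-specific state is kept in the context
class GoodExample:
    def __init__(self, strategy):
        self._strategy = strategy
        self._context_state = {}
```
Anti-pattern 3: Over-abstraction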
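```python
# Bad: every call is pushed through extra validator and preprocessor layers,
# adding indirection that most strategies do not need
class BadExample:
    def execute(self, data):
        return self._strategy.execute(
            self._preprocessor.process(
                self._validator.validate(data)
            )
        )


# Good: the context simply delegates to the strategy
class GoodExample:
    def execute(self, data):
        return self._strategy.execute(data)
```
21.7.2 Best Practices Checklist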
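| Practice | Description | Importance |
|---|---|---|
| Stateless strategies | Strategies should not keep mutable state between calls; immutable configuration (e.g. credentials) is fine | ★★★ |
| Single responsibility | Each strategy does exactly one thing | ★★★ |
| Strategy registration | Manage strategies through a registry | ★★☆ |
| Strategy documentation | Give each strategy a clear description | ★★☆ |
| Strategy testing | Test each strategy in isolation | ★★★ |
| Performance annotation | Document each strategy's time and space complexity | ★★☆ |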
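The "strategy registration" and "performance annotation" practices can be combined in one small registry. The sketch below shows one possible shape under stated assumptions: `StrategyRegistry` echoes the placeholder from anti-pattern 1, but its API (`register`, `get`, `describe`) is hypothetical. Strategies here are plain sorting functions, and the complexity strings are documentation only, not measurements.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List

SortFn = Callable[[List[int]], List[int]]


@dataclass(frozen=True)
class StrategyEntry:
    func: SortFn
    complexity: str    # documented time complexity, e.g. "O(n log n)"
    description: str


class StrategyRegistry:
    """Hypothetical registry: strategy name -> function plus documented metadata."""

    def __init__(self) -> None:
        self._entries: Dict[str, StrategyEntry] = {}

    def register(self, name: str, complexity: str,
                 description: str = "") -> Callable[[SortFn], SortFn]:
        # Used as a decorator: records the function and returns it unchanged
        def decorator(func: SortFn) -> SortFn:
            if name in self._entries:
                raise ValueError(f"strategy {name!r} is already registered")
            self._entries[name] = StrategyEntry(func, complexity, description)
            return func
        return decorator

    def get(self, name: str) -> SortFn:
        if name not in self._entries:
            raise KeyError(f"unknown strategy {name!r}")
        return self._entries[name].func

    def describe(self) -> Dict[str, str]:
        return {name: entry.complexity for name, entry in self._entries.items()}


registry = StrategyRegistry()


@registry.register("builtin", "O(n log n)", "delegates to sorted()")
def builtin_sort(data: List[int]) -> List[int]:
    return sorted(data)


@registry.register("insertion", "O(n^2)", "stable; fine for very small inputs")
def insertion_sort(data: List[int]) -> List[int]:
    result = list(data)  # copy: a strategy should not mutate its input
    for i in range(1, len(result)):
        key, j = result[i], i - 1
        while j >= 0 and result[j] > key:
            result[j + 1] = result[j]
            j -= 1
        result[j + 1] = key
    return result


print(registry.describe())                   # {'builtin': 'O(n log n)', 'insertion': 'O(n^2)'}
print(registry.get("insertion")([3, 1, 2]))  # [1, 2, 3]
```

Registering through a decorator keeps each strategy's metadata next to its code. Rejecting duplicate names prevents one module from silently overriding another module's strategy.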
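21.8 Decision Guide
21.8.1 Whether to Use the Strategy Pattern
```
              Need to switch algorithms/behavior?
                               │
                     ┌─────────┴─────────┐
                     │                   │
                     No                 Yes
                     │                   │
                     ↓                   ↓
             No pattern needed When does it switch?
                                         │
                       ┌─────────────────┼─────────────────┐
                       │                 │                 │
                    Runtime     Compile/deploy time      Never
                       │                 │                 │
                       ↓                 ↓                 ↓
               Strategy pattern   Config/factory  Direct implementation
```
21.8.2 Choosing an Implementation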
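```
                      Choose an implementation
                                  │
          ┌───────────────────────┼───────────────────────┐
          │                       │                       │
          ↓                       ↓                       ↓
     Simple case            Standard case           Complex case
          │                       │                       │
          ↓                       ↓                       ↓
 Functional strategy     ABC-based strategy      Factory + registry
          │                       │                       │
          │                       │                 ┌─────┴─────┐
          │                       │                 │           │
          ↓                       ↓                 ↓           ↓
   Lambda/function         Strategy class       Composite    Caching
                                                strategy    strategy
```
21.8.3 Technology Selection Reference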
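| Scenario | Recommended implementation | Rationale |
|---|---|---|
| Payment methods | ABC-based strategy | Needs per-gateway configuration (credentials, fee rates) |
| Sorting algorithms | Functional strategy | Stateless and simple |
| Data export | Strategy factory | Lazy loading, configurable |
| Validation rules | Composite strategy | Combines multiple rules |
| Caching | Decorator pattern | Transparent caching |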
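The last row ("decorator pattern, transparent caching") can be sketched as a wrapper that implements the strategy interface itself and memoizes the wrapped strategy's results. `PricingStrategy`, `SlowTaxStrategy`, and `CachingStrategy` are hypothetical names, and the 13% tax is an arbitrary stand-in for an expensive computation. Caching is only safe under the assumption that the wrapped strategy is deterministic and its input is hashable.

```python
from abc import ABC, abstractmethod
from typing import Dict


class PricingStrategy(ABC):
    @abstractmethod
    def execute(self, amount: int) -> int:
        ...


class SlowTaxStrategy(PricingStrategy):
    """Stands in for an expensive computation."""

    def __init__(self) -> None:
        self.calls = 0  # counter exists only to demonstrate the cache

    def execute(self, amount: int) -> int:
        self.calls += 1
        return amount * 113 // 100  # illustrative 13% tax, integer arithmetic


class CachingStrategy(PricingStrategy):
    """Decorator: same interface as the wrapped strategy, adds memoization."""

    def __init__(self, inner: PricingStrategy) -> None:
        self._inner = inner
        self._cache: Dict[int, int] = {}

    def execute(self, amount: int) -> int:
        if amount not in self._cache:
            self._cache[amount] = self._inner.execute(amount)
        return self._cache[amount]


slow = SlowTaxStrategy()
cached = CachingStrategy(slow)
print(cached.execute(100), cached.execute(100), slow.calls)  # 113 113 1
```

The context only ever sees a `PricingStrategy`. The cache can therefore be added or removed without touching either the context or the concrete strategy.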
21.9 Relationships with Other Patterns
21.9.1 Pattern Combinations
Strategy + Factory:
```
┌─────────────────────────────────────────────────────────────┐
│                       StrategyFactory                       │
│   ┌─────────────────────────────────────────────────────┐   │
│   │ create(name: str): Strategy                         │   │
│   └─────────────────────────────────────────────────────┘   │
│                              │                              │
│                              │ creates                      │
│                              ↓                              │
│   ┌─────────────┐     ┌─────────────┐     ┌─────────────┐   │
│   │  StrategyA  │     │  StrategyB  │     │  StrategyC  │   │
│   └─────────────┘     └─────────────┘     └─────────────┘   │
└─────────────────────────────────────────────────────────────┘
```
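The Strategy + Factory combination in the diagram can be sketched with a name-to-class mapping, so that a strategy object is instantiated only the first time it is requested. `create(name)` mirrors the diagram's `create(name: str): Strategy`. Everything else is an assumption made for the sketch: `TextStrategy`, the two concrete strategies, and the reuse of one shared instance per name. Sharing instances presumes the strategies are stateless.

```python
from abc import ABC, abstractmethod
from typing import Dict, Type


class TextStrategy(ABC):
    @abstractmethod
    def execute(self, data: str) -> str:
        ...


class UpperStrategy(TextStrategy):
    def execute(self, data: str) -> str:
        return data.upper()


class ReverseStrategy(TextStrategy):
    def execute(self, data: str) -> str:
        return data[::-1]


class StrategyFactory:
    """Maps names to strategy classes and instantiates them lazily."""

    def __init__(self) -> None:
        self._classes: Dict[str, Type[TextStrategy]] = {}
        self._instances: Dict[str, TextStrategy] = {}

    def register(self, name: str, cls: Type[TextStrategy]) -> None:
        self._classes[name] = cls

    def create(self, name: str) -> TextStrategy:
        if name not in self._classes:
            raise ValueError(f"Unknown strategy: {name}")
        # Instantiate on first use only; later calls reuse the same stateless object
        if name not in self._instances:
            self._instances[name] = self._classes[name]()
        return self._instances[name]


factory = StrategyFactory()
factory.register("upper", UpperStrategy)
factory.register("reverse", ReverseStrategy)
print(factory.create("reverse").execute("abc"))  # cba
```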
Strategy + Decorator:
```
┌─────────────────────────────────────────────────────────────┐
│ Context                                                     │
│   ┌─────────────────────────────────────────────────────┐   │
│   │ Strategy                                            │   │
│   │   ┌─────────────────────────────────────────────┐   │   │
│   │   │ LoggingDecorator                            │   │   │
│   │   │   ┌─────────────────────────────────────┐   │   │   │
│   │   │   │ CachingDecorator                    │   │   │   │
│   │   │   │   ┌─────────────────────────────┐   │   │   │   │
│   │   │   │   │ ConcreteStrategy            │   │   │   │   │
│   │   │   │   └─────────────────────────────┘   │   │   │   │
│   │   │   └─────────────────────────────────────┘   │   │   │
│   │   └─────────────────────────────────────────────┘   │   │
│   └─────────────────────────────────────────────────────┘   │
└─────────────────────────────────────────────────────────────┘
```
21.9.2 Pattern Comparison
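| Pattern | Relationship | Difference |
|---|---|---|
| State | Similar structure | States transition between one another; strategies are independent |
| Command | Both encapsulate behavior | A command encapsulates a request; a strategy encapsulates an algorithm |
| Template Method | Both define algorithms | Template Method uses inheritance; Strategy uses composition |
| Factory | Often used together | The factory creates strategy instances |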
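The Template Method row ("inheritance vs. composition") is easiest to see by writing the same varying step both ways. All names below (`ReportTemplate`, `UpperReport`, `ReportGenerator`) are hypothetical. The point is only to show where the variation lives.

```python
from abc import ABC, abstractmethod
from typing import Callable, List


# Template Method: the varying step is supplied by a subclass (inheritance)
class ReportTemplate(ABC):
    def render(self, rows: List[str]) -> str:
        # The skeleton is fixed here; subclasses only provide format_row
        return "\n".join(self.format_row(r) for r in rows)

    @abstractmethod
    def format_row(self, row: str) -> str:
        ...


class UpperReport(ReportTemplate):
    def format_row(self, row: str) -> str:
        return row.upper()


# Strategy: the varying step is an object passed in (composition)
class ReportGenerator:
    def __init__(self, format_row: Callable[[str], str]) -> None:
        self.format_row = format_row

    def render(self, rows: List[str]) -> str:
        return "\n".join(self.format_row(r) for r in rows)


rows = ["a", "b"]
print(UpperReport().render(rows))       # chosen when the subclass is written
generator = ReportGenerator(str.upper)
print(generator.render(rows))           # chosen at construction time
generator.format_row = lambda r: f"- {r}"
print(generator.render(rows))           # and replaceable at runtime
```

With Template Method, the variation is fixed when the subclass is written. With Strategy, the variation is an object held by the context and can be replaced at runtime.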
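21.10 Quick Reference Card
21.10.1 Core Concepts at a Glance
```
┌─────────────────────────────────────────────────────────────────┐
│ Strategy Pattern Quick Reference                                │
├─────────────────────────────────────────────────────────────────┤
│ Definition: encapsulate a family of interchangeable algorithms  │
├─────────────────────────────────────────────────────────────────┤
│ Core roles:                                                     │
│   • Strategy (interface): declares the algorithm interface      │
│   • ConcreteStrategy: a concrete algorithm implementation       │
│   • Context: holds a strategy reference and delegates to it     │
├─────────────────────────────────────────────────────────────────┤
│ Key methods:                                                    │
│   • execute(data): run the strategy                             │
│   • set_strategy(strategy): switch strategies                   │
│   • register_strategy(name, strategy): register a strategy      │
├─────────────────────────────────────────────────────────────────┤
│ When strategies are switched:                                   │
│   • Explicitly by the client                                    │
│   • Automatically, based on context                             │
│   • Driven by configuration files                               │
├─────────────────────────────────────────────────────────────────┤
│ Use when:                                                       │
│   ✓ Algorithms must be switched at runtime                      │
│   ✓ The same task can be done in several ways                   │
│   ✓ Complex conditional logic should be avoided                 │
│   ✓ Algorithms change often or need to be extended              │
├─────────────────────────────────────────────────────────────────┤
│ Caveats:                                                        │
│   • Keep strategies stateless, or keep their state independent  │
│   • Clients must understand how strategies differ               │
│   • Avoid an excessive number of strategies                     │
└─────────────────────────────────────────────────────────────────┘
```
21.10.2 Python Implementation Template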
```python
from abc import ABC, abstractmethod
from typing import TypeVar, Generic, Dict, Optional

T = TypeVar('T')
R = TypeVar('R')


class Strategy(ABC, Generic[T, R]):
    @abstractmethod
    def execute(self, data: T) -> R:
        pass

    @property
    def name(self) -> str:
        # Defaults to the class name; override to use a custom registry key
        return self.__class__.__name__


class Context(Generic[T, R]):
    def __init__(self, strategy: Optional[Strategy[T, R]] = None):
        self._strategy = strategy
        self._registry: Dict[str, Strategy[T, R]] = {}

    def register(self, strategy: Strategy[T, R]) -> None:
        self._registry[strategy.name] = strategy

    def use(self, name: str) -> None:
        # Fail fast on unknown names instead of silently clearing the current strategy
        if name not in self._registry:
            raise KeyError(f"Unknown strategy: {name}")
        self._strategy = self._registry[name]

    def execute(self, data: T) -> R:
        if self._strategy is None:
            raise ValueError("No strategy set")
        return self._strategy.execute(data)
```
21.11 Summary
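The Strategy pattern is the core pattern for encapsulating a family of algorithms: it makes algorithms interchangeable through composition rather than inheritance. Starting from a formal definition, this chapter examined the pattern's theoretical foundations and several ways to implement it.
Key points:
- Theoretical foundations: the Strategy pattern builds on the Open-Closed Principle; an abstract strategy interface lets algorithms vary independently of their clients.
- Implementation options: from classic ABC-based classes to functional strategies, Python offers flexible ways to implement strategies.
- Enterprise applications: the pattern plays an important role in payment gateways, data export, route matching, and similar scenarios.
- Pattern variants: composite strategies, caching strategies, strategy factories, and other variants extend the pattern's range of application.
- Best practices: keep strategies stateless, manage them with a registry, and document their complexity clearly.
The Strategy pattern is an important means of applying the Open-Closed Principle. Understanding how it works and how to implement it is essential for building extensible, maintainable systems.
Discussion Questions
- The Strategy and State patterns are structurally very similar. What is the essential difference between them, and under what circumstances can one replace the other?
- How would you design a strategy framework that supports hot reloading and dynamic registration of strategies?
- How should dependencies between strategies be handled in the Strategy pattern?
- How can strategy selection be automated so that clients need no knowledge of the concrete strategies?
- In a microservice architecture, how can the Strategy pattern be combined with a configuration center to switch strategies dynamically?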