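Chapter 25 Concurrency Design Patterns
Learning Objectives
After completing this chapter, readers will be able to:
- Understand the core concepts and formal definitions of concurrent programming
- Apply the design principles and synchronization mechanisms behind thread safety
- Use concurrency patterns such as producer-consumer, read-write locks, and actors
- Implement efficient asynchronous programs and coroutine systems
- Recognize common pitfalls and anti-patterns in concurrent programming
25.1 Overview of Concurrent Programming
25.1.1 Core Definition
Concurrent programming is a programming paradigm in which multiple computational tasks execute during overlapping time periods in order to improve a system's responsiveness and resource utilization.
25.1.2 Formal Definition
Formally, a concurrent system can be defined as a five-tuple:
$$\mathcal{C} = \langle P, S, T, \sigma, \phi \rangle$$
where:
- $P = \{p_1, p_2, \ldots, p_n\}$ is the set of processes/threads
- $S$ is the shared state space
- $T$ is the set of transition functions
- $\sigma: P \times S \rightarrow S$ is the state transition function
- $\phi: S \rightarrow \{\text{safe}, \text{unsafe}\}$ is the safety predicate
Thread safety:
As a simplified criterion, an operation $op$ is treated as thread-safe when its effect does not depend on the order in which concurrent threads apply it:
$$\forall s \in S, \forall p_i, p_j \in P: op(p_i, op(p_j, s)) = op(p_j, op(p_i, s))$$
This commutativity condition captures order independence only; the standard correctness criterion for concurrent objects is linearizability (see Section 25.2.2).
Mutual exclusion:
Mutually exclusive access to a critical section $CS$ means that at any time $t$ at most one process is inside it:
$$\forall t: |\{p \in P : p \text{ in } CS \text{ at } t\}| \leq 1$$
Liveness:
Every process that requests entry to the critical section eventually enters it, where $\Diamond$ is the temporal operator "eventually":
$$\forall p \in P: p \text{ requests } CS \Rightarrow \Diamond(p \text{ enters } CS)$$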
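The mutual exclusion and liveness conditions above can be checked empirically on a small scale. The following is a minimal sketch, not a proof: the class `MutexMonitor` and the function `run` are hypothetical names introduced here for illustration. It records the largest number of threads ever observed inside the critical section at once, which should never exceed 1.

```python
import threading


class MutexMonitor:
    """Hypothetical helper: observes how many threads are in the critical section."""

    def __init__(self):
        self._cs_lock = threading.Lock()    # enforces mutual exclusion
        self._meta_lock = threading.Lock()  # guards the bookkeeping counters
        self.inside = 0                     # threads currently in the critical section
        self.max_inside = 0                 # largest value of `inside` ever observed
        self.total = 0                      # shared state updated in the critical section

    def critical_section(self):
        with self._cs_lock:
            with self._meta_lock:
                self.inside += 1
                self.max_inside = max(self.max_inside, self.inside)
            self.total += 1                 # the protected update
            with self._meta_lock:
                self.inside -= 1


def run(n_threads=8, iterations=1000):
    monitor = MutexMonitor()

    def worker():
        for _ in range(iterations):
            monitor.critical_section()

    threads = [threading.Thread(target=worker) for _ in range(n_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()                            # every thread finished: each request was served
    return monitor


monitor = run()
print(monitor.max_inside, monitor.total)
```

If `_cs_lock` were removed, `max_inside` could exceed 1 and `total` could lose updates, which is exactly the violation the mutual exclusion formula rules out.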
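25.1.3 Comparison of Concurrency Models
| Property | Multithreading | Multiprocessing | Async I/O | Actor model |
|---|---|---|---|---|
| Memory model | Shared memory | Separate memory | Shared memory (single thread) | Message passing |
| Communication | Shared variables | IPC | Event loop | Messages |
| Overhead | Medium | High | Low | Medium |
| Typical use | CPU-bound work (limited by the GIL in CPython) | Isolation requirements | I/O-bound work | Distributed systems |
| Debugging difficulty | High | Medium | Low | Medium |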
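To make two columns of the table concrete, here is a minimal sketch that runs the same I/O-bound workload once with threads and once with an event loop. The function names (`blocking_io`, `run_threads`, `async_io`, `run_async`) are illustrative, and the sleeps merely stand in for real network calls.

```python
import asyncio
import threading
import time


def blocking_io(i, out):
    time.sleep(0.05)             # stands in for a blocking network call
    out[i] = i * 2               # each thread writes to its own slot


def run_threads(n):
    out = [None] * n
    threads = [threading.Thread(target=blocking_io, args=(i, out)) for i in range(n)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return out


async def async_io(i):
    await asyncio.sleep(0.05)    # yields to the event loop instead of blocking
    return i * 2


async def run_async(n):
    # gather() returns results in the order the awaitables were passed in
    return await asyncio.gather(*(async_io(i) for i in range(n)))


thread_results = run_threads(5)
async_results = asyncio.run(run_async(5))
print(thread_results, async_results)
```

Both versions overlap the waits, but the asyncio version does so on a single thread, which is why the table lists its overhead as low.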
25.1.4 Concurrency vs. Parallelism
┌──────────────────────────────────────────────────────────────┐
│                 Concurrency vs. Parallelism                  │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  Concurrency:                                                │
│  - Multiple tasks make progress in overlapping time periods  │
│  - A single-core CPU achieves it by time slicing             │
│  - Concerned with task structure and scheduling              │
│                                                              │
│  Parallelism:                                                │
│  - Multiple tasks execute at the same instant                │
│  - Requires multiple CPU cores                               │
│  - Concerned with execution efficiency                       │
│                                                              │
│  Timeline example:                                           │
│                                                              │
│  Concurrency (single core):                                  │
│  Task A: ████████░░░░░░░░████████░░░░░░░░████████            │
│  Task B: ░░░░░░░░████████░░░░░░░░████████░░░░░░░░            │
│          └─ time slicing ─┘                                  │
│                                                              │
│  Parallelism (multiple cores):                               │
│  Core 1: ████████████████████████████████████████            │
│  Core 2: ████████████████████████████████████████            │
│          └─ simultaneous ─┘                                  │
│                                                              │
└──────────────────────────────────────────────────────────────┘
25.2 Historical Background and Evolution
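25.2.1 Historical Development
| Period | Milestone | Description |
|---|---|---|
| 1965 | Dijkstra's semaphores | Edsger Dijkstra proposes the semaphore mechanism |
| 1973-1974 | Monitors | Brinch Hansen and Hoare introduce the monitor concept |
| 1978 | CSP | Hoare proposes Communicating Sequential Processes |
| 1985 | Actor model | The actor model (proposed by Hewitt et al. in 1973) matures |
| 1995 | Java threads | Java ships with built-in multithreading support |
| 2000s | Concurrent collections | Java concurrency package, .NET parallel libraries |
| 2009 | Go goroutines | Lightweight concurrency in the Go language |
| 2010s | async/await | Asynchronous syntax in C#, Python, and JavaScript |
| 2020s | Structured concurrency | The structured concurrency concept becomes widespread |
25.2.2 Theoretical Foundations
The theory of concurrent programming draws on:
- Process algebras: CSP, CCS, and the π-calculus
- Temporal logics: LTL and CTL, used to verify properties of concurrent systems
- Petri nets: graphical modeling of concurrent systems
- Linearizability: the correctness criterion for concurrent data structures
25.3 Thread-Safety Patterns
25.3.1 Immutable Object Pattern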
python
from dataclasses import dataclass, field
from functools import reduce
from typing import Any, Callable, Dict, FrozenSet, Optional, Tuple


@dataclass(frozen=True)
class ImmutablePoint:
    """Immutable point: inherently thread-safe."""
    x: float
    y: float

    def move(self, dx: float, dy: float) -> 'ImmutablePoint':
        return ImmutablePoint(self.x + dx, self.y + dy)

    def distance_to(self, other: 'ImmutablePoint') -> float:
        return ((self.x - other.x) ** 2 + (self.y - other.y) ** 2) ** 0.5


@dataclass(frozen=True)
class ImmutableConfig:
    """Immutable configuration."""
    host: str
    port: int
    database: str
    options: FrozenSet[str] = field(default_factory=frozenset)

    @staticmethod
    def create(host: str, port: int, database: str, *options: str) -> 'ImmutableConfig':
        return ImmutableConfig(host, port, database, frozenset(options))

    def with_port(self, new_port: int) -> 'ImmutableConfig':
        return ImmutableConfig(self.host, new_port, self.database, self.options)

    def with_option(self, option: str) -> 'ImmutableConfig':
        return ImmutableConfig(self.host, self.port, self.database,
                               self.options | {option})


class ImmutableList:
    """
    Immutable list: every operation returns a new instance.
    Thread-safe because there is no shared mutable state.
    """

    def __init__(self, items: Tuple = ()):
        self._items = tuple(items)
        # Lazily cached hash; recomputing it in two threads is harmless.
        self._hash: Optional[int] = None

    @staticmethod
    def of(*items) -> 'ImmutableList':
        return ImmutableList(tuple(items))

    @staticmethod
    def from_list(items: list) -> 'ImmutableList':
        return ImmutableList(tuple(items))

    def append(self, item: Any) -> 'ImmutableList':
        return ImmutableList(self._items + (item,))

    def prepend(self, item: Any) -> 'ImmutableList':
        return ImmutableList((item,) + self._items)

    def remove(self, item: Any) -> 'ImmutableList':
        # Removes every element equal to `item`.
        return ImmutableList(tuple(x for x in self._items if x != item))

    def map(self, func: Callable[[Any], Any]) -> 'ImmutableList':
        return ImmutableList(tuple(func(x) for x in self._items))

    def filter(self, predicate: Callable[[Any], bool]) -> 'ImmutableList':
        return ImmutableList(tuple(x for x in self._items if predicate(x)))

    def reduce(self, func: Callable[[Any, Any], Any], initial: Any) -> Any:
        return reduce(func, self._items, initial)

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: int) -> Any:
        return self._items[index]

    def __iter__(self):
        return iter(self._items)

    def __hash__(self) -> int:
        if self._hash is None:
            self._hash = hash(self._items)
        return self._hash

    def __eq__(self, other) -> bool:
        if isinstance(other, ImmutableList):
            return self._items == other._items
        return False

    def __repr__(self) -> str:
        return f"ImmutableList({list(self._items)})"


class ImmutableBuilder:
    """
    Builder for immutable objects. The builder itself is mutable and
    should be confined to a single thread until build() is called.
    """

    def __init__(self):
        self._data: Dict[str, Any] = {}

    def set(self, key: str, value: Any) -> 'ImmutableBuilder':
        self._data[key] = value
        return self

    def update(self, data: Dict[str, Any]) -> 'ImmutableBuilder':
        self._data.update(data)
        return self

    def build(self) -> 'ImmutableRecord':
        return ImmutableRecord(self._data)


class ImmutableRecord:
    """
    Immutable record: a general-purpose immutable data container.
    Immutability is shallow: mutable values stored in it can still change.
    """

    def __init__(self, data: Dict[str, Any]):
        # Bypass our own __setattr__, which always raises.
        object.__setattr__(self, '_data', data.copy())

    def get(self, key: str, default: Any = None) -> Any:
        return self._data.get(key, default)

    def with_value(self, key: str, value: Any) -> 'ImmutableRecord':
        new_data = self._data.copy()
        new_data[key] = value
        return ImmutableRecord(new_data)

    def without(self, key: str) -> 'ImmutableRecord':
        new_data = self._data.copy()
        new_data.pop(key, None)
        return ImmutableRecord(new_data)

    def keys(self) -> Tuple:
        return tuple(self._data.keys())

    def values(self) -> Tuple:
        return tuple(self._data.values())

    def to_dict(self) -> Dict[str, Any]:
        return self._data.copy()

    def __getattr__(self, name: str) -> Any:
        # Only called when normal lookup fails; private names raise AttributeError.
        if name.startswith('_'):
            return object.__getattribute__(self, name)
        return self._data.get(name)

    def __setattr__(self, name: str, value: Any) -> None:
        raise AttributeError("ImmutableRecord is immutable")

    def __repr__(self) -> str:
        return f"ImmutableRecord({self._data})"


point1 = ImmutablePoint(0, 0)
point2 = point1.move(10, 20)
print(f"Original point: ({point1.x}, {point1.y})")
print(f"After move: ({point2.x}, {point2.y})")

config = ImmutableConfig.create("localhost", 3306, "mydb", "ssl", "timeout")
config2 = config.with_port(5432).with_option("compress")
print(f"\nOriginal config: {config.host}:{config.port}")
print(f"New config: {config2.host}:{config2.port}, options: {config2.options}")

nums = ImmutableList.of(1, 2, 3, 4, 5)
nums2 = nums.append(6).map(lambda x: x * 2)
print(f"\nOriginal list: {nums}")
print(f"Modified list: {nums2}")

record = (ImmutableBuilder()
          .set('name', '张三')
          .set('age', 30)
          .set('email', 'zhang@example.com')
          .build())
record2 = record.with_value('age', 31)
print(f"\nOriginal record: {record}")
print(f"Modified record: {record2}")
25.3.2 Thread-Local Storage Pattern
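Before the fuller classes in this section, a minimal sketch of the standard-library primitive they build on may help. `threading.local()` gives every thread its own independent set of attributes. The names `context`, `seen`, and `handle` are illustrative only.

```python
import threading

context = threading.local()      # attributes set here are private to each thread
seen = {}                        # thread label -> value that thread observed
seen_lock = threading.Lock()     # protects the shared `seen` dict


def handle(label):
    context.request_id = f"req-{label}"   # visible only to the current thread
    with seen_lock:
        seen[label] = context.request_id


threads = [threading.Thread(target=handle, args=(label,)) for label in ("a", "b", "c")]
for t in threads:
    t.start()
for t in threads:
    t.join()

# The main thread never assigned request_id, so it has no such attribute.
main_has_value = hasattr(context, "request_id")
print(seen, main_has_value)
```

Each worker sees only the value it stored itself, and the main thread sees none, which is the isolation property the pattern relies on.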
python
import threading
from contextlib import contextmanager
from typing import Any, Dict, Optional

# Placeholder connection type used by this example.
Connection = Dict[str, Any]


class ThreadLocalStorage:
    """
    Thread-local storage: each thread gets its own independent data store.
    """

    def __init__(self):
        self._local = threading.local()

    def _data(self) -> Dict[str, Any]:
        # Create the per-thread dict on first use in each thread.
        if not hasattr(self._local, 'data'):
            self._local.data = {}
        return self._local.data

    def set(self, key: str, value: Any) -> None:
        self._data()[key] = value

    def get(self, key: str, default: Any = None) -> Any:
        return self._data().get(key, default)

    def delete(self, key: str) -> None:
        self._data().pop(key, None)

    def clear(self) -> None:
        self._data().clear()

    def keys(self) -> list:
        return list(self._data().keys())

    def has(self, key: str) -> bool:
        return key in self._data()


class RequestContext:
    """
    Request context: manages request-scoped data in thread-local storage.
    """
    _storage = ThreadLocalStorage()

    @classmethod
    def set_request_id(cls, request_id: str) -> None:
        cls._storage.set('request_id', request_id)

    @classmethod
    def get_request_id(cls) -> Optional[str]:
        return cls._storage.get('request_id')

    @classmethod
    def set_user(cls, user: dict) -> None:
        cls._storage.set('user', user)

    @classmethod
    def get_user(cls) -> Optional[dict]:
        return cls._storage.get('user')

    @classmethod
    def clear(cls) -> None:
        cls._storage.clear()

    @classmethod
    @contextmanager
    def context(cls, request_id: str, user: Optional[dict] = None):
        """Context manager that sets up and tears down the request context."""
        cls.set_request_id(request_id)
        if user is not None:
            cls.set_user(user)
        try:
            yield
        finally:
            cls.clear()


class DatabaseConnection:
    """
    Per-thread database connections: no connection is shared between threads.
    """
    _connections = ThreadLocalStorage()

    @classmethod
    def get_connection(cls, db_name: str) -> Connection:
        if not cls._connections.has(db_name):
            conn = cls._create_connection(db_name)
            cls._connections.set(db_name, conn)
        return cls._connections.get(db_name)

    @classmethod
    def _create_connection(cls, db_name: str) -> Connection:
        print(f"Creating new connection: {db_name}")
        return {'name': db_name, 'connected': True}

    @classmethod
    def close_all(cls) -> None:
        """Close the connections owned by the calling thread only."""
        for key in cls._connections.keys():
            conn = cls._connections.get(key)
            if conn:
                conn['connected'] = False
                print(f"Closing connection: {key}")
        cls._connections.clear()


def worker(thread_id: int):
    with RequestContext.context(f"req-{thread_id}", {'id': thread_id, 'name': f'User{thread_id}'}):
        print(f"Thread {thread_id}: request_id = {RequestContext.get_request_id()}")
        print(f"Thread {thread_id}: user = {RequestContext.get_user()}")
        conn = DatabaseConnection.get_connection('main_db')
        print(f"Thread {thread_id}: connection = {conn}")
        DatabaseConnection.close_all()


threads = [threading.Thread(target=worker, args=(i,)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()
25.3.3 Synchronized Wrapper Pattern
python
import threading
from functools import wraps
from typing import Any, Callable, Dict, List, Optional


def synchronized(lock: Optional[threading.Lock] = None):
    """
    Synchronization decorator: serializes all calls to the decorated function.
    The lock belongs to the function, so when applied to a method it is
    shared by every instance. The default Lock is not reentrant.
    """
    if lock is None:
        lock = threading.Lock()

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            with lock:
                return func(*args, **kwargs)
        wrapper._lock = lock
        return wrapper
    return decorator


class SynchronizedDict:
    """
    Thread-safe dictionary.
    """

    def __init__(self, initial: Optional[Dict] = None):
        # Copy so the caller cannot mutate our state without the lock.
        self._data: Dict = dict(initial) if initial else {}
        self._lock = threading.RLock()

    def get(self, key: str, default: Any = None) -> Any:
        with self._lock:
            return self._data.get(key, default)

    def set(self, key: str, value: Any) -> None:
        with self._lock:
            self._data[key] = value

    def delete(self, key: str) -> bool:
        with self._lock:
            if key in self._data:
                del self._data[key]
                return True
            return False

    def contains(self, key: str) -> bool:
        with self._lock:
            return key in self._data

    def keys(self) -> List:
        with self._lock:
            return list(self._data.keys())

    def values(self) -> List:
        with self._lock:
            return list(self._data.values())

    def items(self) -> List:
        with self._lock:
            return list(self._data.items())

    def update(self, other: Dict) -> None:
        with self._lock:
            self._data.update(other)

    def clear(self) -> None:
        with self._lock:
            self._data.clear()

    def compute_if_absent(self, key: str, func: Callable) -> Any:
        # Atomic check-then-act; the RLock lets func() call back into this dict.
        with self._lock:
            if key not in self._data:
                self._data[key] = func()
            return self._data[key]

    def __len__(self) -> int:
        with self._lock:
            return len(self._data)

    def __repr__(self) -> str:
        with self._lock:
            return f"SynchronizedDict({self._data})"


class SynchronizedList:
    """
    Thread-safe list.
    """

    def __init__(self, initial: Optional[List] = None):
        # Copy so the caller cannot mutate our state without the lock.
        self._data: List = list(initial) if initial else []
        self._lock = threading.RLock()

    def append(self, item: Any) -> None:
        with self._lock:
            self._data.append(item)

    def insert(self, index: int, item: Any) -> None:
        with self._lock:
            self._data.insert(index, item)

    def remove(self, item: Any) -> None:
        with self._lock:
            self._data.remove(item)

    def pop(self, index: int = -1) -> Any:
        with self._lock:
            return self._data.pop(index)

    def get(self, index: int) -> Any:
        with self._lock:
            return self._data[index]

    def set(self, index: int, value: Any) -> None:
        with self._lock:
            self._data[index] = value

    def contains(self, item: Any) -> bool:
        with self._lock:
            return item in self._data

    def index(self, item: Any) -> int:
        with self._lock:
            return self._data.index(item)

    def extend(self, items: List) -> None:
        with self._lock:
            self._data.extend(items)

    def clear(self) -> None:
        with self._lock:
            self._data.clear()

    def to_list(self) -> List:
        with self._lock:
            return self._data.copy()

    def __len__(self) -> int:
        with self._lock:
            return len(self._data)

    def __iter__(self):
        # Iterate over a snapshot so concurrent writers cannot break iteration.
        with self._lock:
            return iter(self._data.copy())

    def __repr__(self) -> str:
        with self._lock:
            return f"SynchronizedList({self._data})"


class AtomicReference:
    """
    Atomic reference: a thread-safe object reference.
    """

    def __init__(self, value: Any = None):
        self._value = value
        self._lock = threading.Lock()

    def get(self) -> Any:
        with self._lock:
            return self._value

    def set(self, new_value: Any) -> None:
        with self._lock:
            self._value = new_value

    def get_and_set(self, new_value: Any) -> Any:
        with self._lock:
            old_value = self._value
            self._value = new_value
            return old_value

    def compare_and_set(self, expected: Any, new_value: Any) -> bool:
        # Compares by equality (==), not by identity.
        with self._lock:
            if self._value == expected:
                self._value = new_value
                return True
            return False

    def update(self, func: Callable[[Any], Any]) -> Any:
        # func runs while the lock is held, so it must not call back into this object.
        with self._lock:
            self._value = func(self._value)
            return self._value

    def __repr__(self) -> str:
        with self._lock:
            return f"AtomicReference({self._value})"


shared_dict = SynchronizedDict({'initial': 'value'})
shared_dict.set('new_key', 'new_value')
print(f"Synchronized dict: {shared_dict}")

shared_list = SynchronizedList([1, 2, 3])
shared_list.append(4)
print(f"Synchronized list: {shared_list}")

counter = AtomicReference(0)
counter.update(lambda x: x + 1)
print(f"Atomic counter: {counter}")
25.4 Synchronization Patterns
25.4.1 Read-Write Lock Pattern
python
import threading
from contextlib import contextmanager
from typing import Any, Optional


class ReadWriteLock:
    """
    Read-write lock: many concurrent readers, exclusive writers.
    Properties:
    - Read-read sharing: multiple readers may hold the lock at once
    - Read-write exclusion: a writer needs exclusive access
    - Write-write exclusion: writers exclude each other
    This variant prefers readers, so a steady stream of readers can starve writers.
    """

    def __init__(self):
        self._write_lock = threading.Lock()
        self._readers = 0
        self._readers_count_lock = threading.Lock()

    @contextmanager
    def read_lock(self):
        self._acquire_read()
        try:
            yield
        finally:
            self._release_read()

    @contextmanager
    def write_lock(self):
        self._acquire_write()
        try:
            yield
        finally:
            self._release_write()

    def _acquire_read(self) -> None:
        with self._readers_count_lock:
            self._readers += 1
            if self._readers == 1:
                # The first reader locks writers out on behalf of all readers.
                self._write_lock.acquire()

    def _release_read(self) -> None:
        with self._readers_count_lock:
            self._readers -= 1
            if self._readers == 0:
                # The last reader lets writers in again. A plain Lock may be
                # released by a different thread than the one that acquired it.
                self._write_lock.release()

    def _acquire_write(self) -> None:
        self._write_lock.acquire()

    def _release_write(self) -> None:
        self._write_lock.release()


class ReentrantReadWriteLock:
    """
    Reentrant read-write lock: the same thread may acquire it repeatedly.
    A thread holding the write lock may also take the read lock.
    Upgrading from read to write is not supported and raises RuntimeError.
    """

    def __init__(self):
        self._cond = threading.Condition(threading.Lock())
        self._writer: Optional[int] = None   # ident of the thread holding the write lock
        self._write_count = 0                # write reentrancy depth
        self._read_holders: dict = {}        # thread ident -> read reentrancy depth

    @contextmanager
    def read_lock(self):
        thread_id = threading.get_ident()
        with self._cond:
            # Block only while a different thread holds the write lock.
            while self._writer is not None and self._writer != thread_id:
                self._cond.wait()
            self._read_holders[thread_id] = self._read_holders.get(thread_id, 0) + 1
        try:
            yield
        finally:
            with self._cond:
                self._read_holders[thread_id] -= 1
                if self._read_holders[thread_id] == 0:
                    del self._read_holders[thread_id]
                    self._cond.notify_all()

    @contextmanager
    def write_lock(self):
        thread_id = threading.get_ident()
        with self._cond:
            if self._writer != thread_id:
                if thread_id in self._read_holders:
                    raise RuntimeError("cannot upgrade a read lock to a write lock")
                # Wait until there is no other writer and no active reader.
                while self._writer is not None or self._read_holders:
                    self._cond.wait()
                self._writer = thread_id
            self._write_count += 1
        try:
            yield
        finally:
            with self._cond:
                self._write_count -= 1
                if self._write_count == 0:
                    self._writer = None
                    self._cond.notify_all()


class SharedCache:
    """
    Shared cache protected by a read-write lock.
    """

    def __init__(self):
        self._data: dict = {}
        self._lock = ReadWriteLock()
        self._stats = {'reads': 0, 'writes': 0, 'hits': 0, 'misses': 0}
        # Readers run concurrently, so the counters need their own lock.
        self._stats_lock = threading.Lock()

    def get(self, key: str) -> Any:
        with self._lock.read_lock():
            found = key in self._data
            value = self._data.get(key)
        with self._stats_lock:
            self._stats['reads'] += 1
            self._stats['hits' if found else 'misses'] += 1
        return value

    def set(self, key: str, value: Any) -> None:
        with self._lock.write_lock():
            self._data[key] = value
        with self._stats_lock:
            self._stats['writes'] += 1

    def delete(self, key: str) -> bool:
        with self._lock.write_lock():
            if key in self._data:
                del self._data[key]
                return True
            return False

    def get_all(self) -> dict:
        with self._lock.read_lock():
            return self._data.copy()

    def clear(self) -> None:
        with self._lock.write_lock():
            self._data.clear()

    def get_stats(self) -> dict:
        with self._stats_lock:
            return self._stats.copy()


cache = SharedCache()
cache.set('user:1', {'name': '张三', 'age': 30})
cache.set('user:2', {'name': '李四', 'age': 25})
print(f"Cache hit: {cache.get('user:1')}")
print(f"Cache miss: {cache.get('user:3')}")
print(f"Statistics: {cache.get_stats()}")
25.4.2 Barrier Pattern
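For orientation before the hand-written barrier in this section, here is a minimal sketch using the standard library's `threading.Barrier`, which provides the same rendezvous behavior. The names `log`, `log_lock`, and `worker` are illustrative. The sketch records which phase each thread is in and relies on the barrier to keep the phases from interleaving.

```python
import threading

log = []                          # (phase, worker_id) entries in arrival order
log_lock = threading.Lock()
barrier = threading.Barrier(3)    # releases once three threads are waiting


def worker(worker_id):
    for phase in range(2):
        with log_lock:
            log.append((phase, worker_id))
        barrier.wait()            # nobody starts the next phase until all arrive


threads = [threading.Thread(target=worker, args=(i,)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()

phases = [phase for phase, _ in log]
print(phases)
```

Because every thread must reach `barrier.wait()` before any of them continues, all phase 0 entries are logged before the first phase 1 entry, regardless of scheduling.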
python
import threading
from dataclasses import dataclass
from typing import Callable, Optional
import time
import random


@dataclass
class BarrierStats:
    parties: int
    waiting: int
    generation: int


class _Generation:
    """State shared by the threads waiting in one barrier cycle."""

    def __init__(self, number: int):
        self.number = number
        self.tripped = False   # all parties arrived and the cycle completed
        self.broken = False    # timeout, failed action, or reset()


class CyclicBarrier:
    """
    Cyclic barrier: a group of threads wait for each other at a barrier point.
    Properties:
    - Reusable: the barrier can be used for many cycles
    - Callback support: an action runs once all threads have arrived
    """

    def __init__(self, parties: int, action: Optional[Callable[[], None]] = None):
        self._parties = parties
        self._action = action
        self._lock = threading.Lock()
        self._condition = threading.Condition(self._lock)
        self._waiting = 0
        self._current = _Generation(0)

    def wait(self, timeout: Optional[float] = None) -> int:
        """Wait for the other threads; return the arrival index (0 = last to arrive)."""
        deadline = None if timeout is None else time.monotonic() + timeout
        with self._condition:
            gen = self._current
            if gen.broken:
                raise RuntimeError("Barrier is broken")
            index = self._waiting
            self._waiting += 1
            if index == self._parties - 1:
                # Last arrival: run the action, then release this generation.
                try:
                    if self._action:
                        self._action()
                except Exception:
                    self._break(gen)
                    raise
                gen.tripped = True
                self._start_new_generation()
                return 0
            while not gen.tripped and not gen.broken:
                remaining = None if deadline is None else deadline - time.monotonic()
                if remaining is not None and remaining <= 0:
                    self._break(gen)
                    raise RuntimeError("Barrier timeout")
                self._condition.wait(remaining)
            if gen.broken:
                raise RuntimeError("Barrier is broken")
            return self._parties - 1 - index

    def _break(self, gen: _Generation) -> None:
        # Caller must hold the lock. Wakes all waiters of `gen` with an error.
        gen.broken = True
        self._waiting = 0
        self._condition.notify_all()

    def _start_new_generation(self) -> None:
        # Caller must hold the lock.
        self._current = _Generation(self._current.number + 1)
        self._waiting = 0
        self._condition.notify_all()

    def reset(self) -> None:
        """Reset the barrier; threads currently waiting get a RuntimeError."""
        with self._condition:
            self._break(self._current)
            self._start_new_generation()

    def get_stats(self) -> BarrierStats:
        with self._lock:
            return BarrierStats(
                parties=self._parties,
                waiting=self._waiting,
                generation=self._current.number
            )


class Phaser:
    """
    Phaser: a barrier that supports dynamic registration and deregistration.
    """

    def __init__(self, parties: int = 0):
        self._lock = threading.Lock()
        self._condition = threading.Condition(self._lock)
        self._registered = parties
        self._arrived = 0
        self._phase = 0

    def _advance_if_complete(self) -> None:
        # Caller must hold the lock.
        if self._arrived >= self._registered:
            self._arrived = 0
            self._phase += 1
            self._condition.notify_all()

    def register(self, parties: int = 1) -> int:
        """Register additional parties."""
        with self._lock:
            self._registered += parties
            return self._phase

    def arrive(self) -> int:
        """Arrive without waiting for the others."""
        with self._condition:
            phase = self._phase
            self._arrived += 1
            self._advance_if_complete()
            return phase

    def arrive_and_await_advance(self) -> int:
        """Arrive and wait for the other parties."""
        with self._condition:
            phase = self._phase
            self._arrived += 1
            self._advance_if_complete()
            while self._phase == phase:
                self._condition.wait()
            return self._phase

    def arrive_and_deregister(self) -> int:
        """Arrive and deregister."""
        with self._condition:
            phase = self._phase
            # A departing party simply stops being counted: the current phase
            # completes once all remaining registered parties have arrived.
            self._registered -= 1
            self._advance_if_complete()
            return phase

    def await_advance(self, phase: int) -> int:
        """Wait until the given phase has completed."""
        with self._condition:
            while self._phase == phase:
                self._condition.wait()
            return self._phase

    def get_phase(self) -> int:
        with self._lock:
            return self._phase

    def get_registered(self) -> int:
        with self._lock:
            return self._registered


def parallel_computation_example():
    """Parallel computation example: a barrier synchronizes several phases."""

    def phase_action():
        print("=== Barrier tripped ===")

    barrier = CyclicBarrier(3, phase_action)
    results = []

    def worker(worker_id: int):
        for phase in range(3):
            print(f"Worker {worker_id}: phase {phase} started")
            time.sleep(random.random() * 0.5)
            print(f"Worker {worker_id}: phase {phase} done, waiting for the others")
            barrier.wait()
            print(f"Worker {worker_id}: entering the next phase")
        results.append(f"Worker {worker_id} finished")

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(3)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(f"\nAll workers finished: {results}")


parallel_computation_example()
25.4.3 Semaphore Pattern
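As a reference point for the hand-written semaphores in this section, the following minimal sketch uses the standard library's `threading.BoundedSemaphore` to cap how many threads use a resource at once. The names `limit`, `active`, `peak`, and `use_resource` are illustrative, and the sleep stands in for real work.

```python
import threading
import time

limit = threading.BoundedSemaphore(2)   # at most two holders at any time
state_lock = threading.Lock()           # protects the two counters below
active = 0                              # threads currently holding a permit
peak = 0                                # highest value of `active` observed


def use_resource():
    global active, peak
    with limit:                         # acquire on entry, release on exit
        with state_lock:
            active += 1
            peak = max(peak, active)
        time.sleep(0.02)                # simulated work while holding the permit
        with state_lock:
            active -= 1


threads = [threading.Thread(target=use_resource) for _ in range(6)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(peak, active)
```

Six threads compete for two permits, so `peak` never exceeds 2. A bounded semaphore additionally raises `ValueError` if it is released more often than it was acquired, which helps catch bookkeeping bugs.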
python
import threading
from contextlib import contextmanager
from typing import Any, Callable, Optional
import random
import time


class Semaphore:
    """
    Semaphore: limits how many threads may access a resource at once.
    """

    def __init__(self, permits: int):
        self._permits = permits
        self._lock = threading.Lock()
        self._condition = threading.Condition(self._lock)

    def acquire(self, permits: int = 1, timeout: Optional[float] = None) -> bool:
        """Acquire permits; return False if the timeout expires first."""
        with self._condition:
            # wait_for() re-checks the predicate and tracks the overall timeout.
            if not self._condition.wait_for(lambda: self._permits >= permits, timeout):
                return False
            self._permits -= permits
            return True

    def release(self, permits: int = 1) -> None:
        """Release permits."""
        with self._condition:
            self._permits += permits
            self._condition.notify_all()

    @contextmanager
    def acquire_context(self, permits: int = 1):
        """Acquire permits for the duration of a with-block."""
        self.acquire(permits)
        try:
            yield
        finally:
            self.release(permits)

    def available_permits(self) -> int:
        with self._lock:
            return self._permits


class BoundedSemaphore:
    """
    Bounded semaphore: the permit count can never exceed its initial value.
    """

    def __init__(self, permits: int):
        self._max_permits = permits
        self._permits = permits
        self._lock = threading.Lock()
        self._condition = threading.Condition(self._lock)

    def acquire(self, timeout: Optional[float] = None) -> bool:
        with self._condition:
            if not self._condition.wait_for(lambda: self._permits > 0, timeout):
                return False
            self._permits -= 1
            return True

    def release(self) -> None:
        with self._condition:
            if self._permits >= self._max_permits:
                raise ValueError("Semaphore permits exceeded maximum")
            self._permits += 1
            self._condition.notify()

    def available_permits(self) -> int:
        with self._lock:
            return self._permits

    @contextmanager
    def context(self):
        self.acquire()
        try:
            yield
        finally:
            self.release()


class ConnectionPool:
    """
    Connection pool: a semaphore limits access to the pooled resources.
    """

    def __init__(self, max_connections: int, create_connection: Callable[[], Any]):
        self._max_connections = max_connections
        self._create_connection = create_connection
        self._semaphore = BoundedSemaphore(max_connections)
        self._pool: list = []
        self._pool_lock = threading.Lock()
        self._created = 0
        self._created_lock = threading.Lock()

    def get_connection(self, timeout: Optional[float] = None) -> Any:
        """Get a connection, reusing an idle one when available."""
        if not self._semaphore.acquire(timeout):
            raise TimeoutError("Timed out waiting for a connection")
        try:
            with self._pool_lock:
                if self._pool:
                    return self._pool.pop()
            conn = self._create_connection()
            with self._created_lock:
                self._created += 1
            return conn
        except Exception:
            # Do not leak the permit if creating the connection fails.
            self._semaphore.release()
            raise

    def release_connection(self, conn: Any) -> None:
        """Return a connection to the pool."""
        with self._pool_lock:
            self._pool.append(conn)
        self._semaphore.release()

    @contextmanager
    def connection(self, timeout: Optional[float] = None):
        """Borrow a connection for the duration of a with-block."""
        conn = self.get_connection(timeout)
        try:
            yield conn
        finally:
            self.release_connection(conn)

    def get_stats(self) -> dict:
        return {
            'max_connections': self._max_connections,
            'created': self._created,
            'available': len(self._pool),
            'in_use': self._max_connections - self._semaphore.available_permits()
        }


def create_db_connection():
    return {'id': random.randint(1000, 9999), 'connected': True}


pool = ConnectionPool(3, create_db_connection)


def use_connection(worker_id: int):
    with pool.connection() as conn:
        print(f"Worker {worker_id}: using connection {conn['id']}")
        time.sleep(0.5)
        print(f"Worker {worker_id}: releasing connection {conn['id']}")


threads = [threading.Thread(target=use_connection, args=(i,)) for i in range(5)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(f"\nConnection pool statistics: {pool.get_stats()}")
25.5 Producer-Consumer Pattern
25.5.1 Basic Implementation
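The core of the pattern fits in a few lines with the standard library's `queue.Queue`. The sketch below is a simplified illustration, separate from the classes that follow: `SENTINEL`, `consumer`, and `results` are illustrative names. Consumers shut down through a sentinel value, and every `get()` is paired with a `task_done()` so that `join()` can return.

```python
import queue
import threading

SENTINEL = object()               # unique marker telling a consumer to stop


def consumer(q, results, lock):
    while True:
        item = q.get()            # blocks until an item is available
        try:
            if item is SENTINEL:
                return
            with lock:
                results.append(item * item)
        finally:
            q.task_done()         # runs for real items and for the sentinel


q = queue.Queue(maxsize=4)        # bounded: put() blocks when full (backpressure)
results = []
lock = threading.Lock()
workers = [threading.Thread(target=consumer, args=(q, results, lock)) for _ in range(2)]
for w in workers:
    w.start()

for n in range(10):               # the producer
    q.put(n)
for _ in workers:                 # one sentinel per consumer
    q.put(SENTINEL)

q.join()                          # returns once every item has been marked done
for w in workers:
    w.join()
print(sorted(results))
```

The bounded queue is what couples the two sides: a fast producer is slowed to the consumers' pace instead of filling memory.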
python
import threading
import queue
from typing import Any, Callable, Optional, List
from dataclasses import dataclass, field
from enum import Enum, auto
import time
import random


class TaskStatus(Enum):
    PENDING = auto()
    PROCESSING = auto()
    COMPLETED = auto()
    FAILED = auto()


@dataclass
class Task:
    """Task definition."""
    id: int
    data: Any
    status: TaskStatus = TaskStatus.PENDING
    result: Any = None
    error: Optional[str] = None
    created_at: float = field(default_factory=time.time)
    completed_at: Optional[float] = None


class Producer:
    """Producer: creates tasks."""

    def __init__(self, task_queue: queue.Queue, name: str = "Producer"):
        self._queue = task_queue
        self._name = name
        self._task_counter = 0
        # Protects the counter when one producer is shared by several threads.
        self._counter_lock = threading.Lock()

    def produce(self, data: Any, priority: int = 0) -> Task:
        """Produce one task. IDs are unique per producer, not globally."""
        with self._counter_lock:
            self._task_counter += 1
            task_id = self._task_counter
        task = Task(id=task_id, data=data)
        if hasattr(self._queue, 'put_with_priority'):
            self._queue.put_with_priority(task, priority)
        else:
            self._queue.put(task)
        print(f"[{self._name}] produced task {task.id}")
        return task

    def produce_batch(self, data_list: List[Any]) -> List[Task]:
        """Produce a batch of tasks."""
        return [self.produce(data) for data in data_list]


class Consumer:
    """Consumer: processes tasks."""

    def __init__(
        self,
        task_queue: queue.Queue,
        name: str,
        processor: Callable[[Task], Any]
    ):
        self._queue = task_queue
        self._name = name
        self._processor = processor
        self._running = False
        self._thread: Optional[threading.Thread] = None
        self._processed_count = 0
        self._error_count = 0

    def start(self) -> None:
        self._running = True
        self._thread = threading.Thread(target=self._consume, daemon=True)
        self._thread.start()
        print(f"[{self._name}] consumer started")

    def stop(self) -> None:
        self._running = False
        if self._thread:
            self._thread.join(timeout=2)
        print(f"[{self._name}] consumer stopped: {self._processed_count} tasks processed, {self._error_count} errors")

    def _consume(self) -> None:
        while self._running:
            try:
                task = self._queue.get(timeout=1)
            except queue.Empty:
                continue
            try:
                if task is None:   # sentinel: stop this consumer
                    break
                self._process_task(task)
            except Exception as e:
                print(f"[{self._name}] error: {e}")
                self._error_count += 1
            finally:
                # Always pair get() with task_done(), or join() never returns.
                self._queue.task_done()

    def _process_task(self, task: Task) -> None:
        task.status = TaskStatus.PROCESSING
        try:
            result = self._processor(task)
            task.result = result
            task.status = TaskStatus.COMPLETED
            task.completed_at = time.time()
            self._processed_count += 1
            print(f"[{self._name}] completed task {task.id}: {result}")
        except Exception as e:
            task.status = TaskStatus.FAILED
            task.error = str(e)
            self._error_count += 1
            print(f"[{self._name}] task {task.id} failed: {e}")


class ProducerConsumerSystem:
    """Producer-consumer system."""

    def __init__(
        self,
        queue_size: int = 10,
        consumer_count: int = 2,
        processor: Optional[Callable[[Task], Any]] = None
    ):
        self._queue = queue.Queue(maxsize=queue_size)
        self._consumers: List[Consumer] = []
        self._producers: List[Producer] = []
        self._default_processor = processor or (lambda t: f"processed: {t.data}")
        for i in range(consumer_count):
            consumer = Consumer(self._queue, f"Consumer-{i+1}", self._default_processor)
            self._consumers.append(consumer)

    def create_producer(self, name: Optional[str] = None) -> Producer:
        producer = Producer(self._queue, name or f"Producer-{len(self._producers)+1}")
        self._producers.append(producer)
        return producer

    def start(self) -> None:
        for consumer in self._consumers:
            consumer.start()

    def stop(self) -> None:
        for consumer in self._consumers:
            consumer.stop()

    def wait_completion(self, timeout: Optional[float] = None) -> bool:
        """Wait until all tasks are done; return False if the timeout expires."""
        if timeout is None:
            self._queue.join()
            return True
        # Queue.join() takes no timeout, so wait on the queue's own condition.
        # This relies on the all_tasks_done and unfinished_tasks attributes
        # of CPython's queue.Queue.
        with self._queue.all_tasks_done:
            return self._queue.all_tasks_done.wait_for(
                lambda: self._queue.unfinished_tasks == 0, timeout)


def process_task(task: Task) -> str:
    time.sleep(random.random() * 0.5)
    return f"result-{task.data}"


system = ProducerConsumerSystem(queue_size=20, consumer_count=3, processor=process_task)
system.start()
producer = system.create_producer("MainProducer")
for i in range(10):
    producer.produce(f"data-{i}")
    time.sleep(0.1)
system.wait_completion()
system.stop()
25.5.2 Priority Queues
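For comparison with the custom queues in this section, the standard library already provides `queue.PriorityQueue`. A minimal sketch follows; the helper `put` and the names `pq`, `counter`, and `order` are illustrative. A running sequence number is placed between the priority and the payload so that entries with equal priority come out in insertion order and payloads are never compared with each other.

```python
import itertools
import queue

pq = queue.PriorityQueue()
counter = itertools.count()       # tie-breaker that preserves insertion order


def put(priority, payload):
    # Tuples compare element by element: priority first, then sequence number.
    pq.put((priority, next(counter), payload))


put(10, "low")
put(1, "high")
put(5, "medium")
put(1, "high-2")

order = []
while not pq.empty():             # safe here because only one thread uses pq
    _, _, payload = pq.get()
    order.append(payload)
print(order)
```

Smaller numbers are served first, matching the convention used by the custom queue in this section.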
python
import threading
import queue
from typing import Any, List, Optional, Tuple
from dataclasses import dataclass, field
import heapq
import time


@dataclass(order=True)
class PriorityTask:
    """Task with a priority; lower numbers are served first."""
    priority: int
    sequence: int = field(compare=True)   # tie-breaker: FIFO within a priority
    task: Any = field(compare=False)


class PriorityTaskQueue:
    """
    Priority task queue: higher-priority tasks are processed first.
    """

    def __init__(self, maxsize: int = 0):
        self._maxsize = maxsize
        self._queue: List[PriorityTask] = []
        # One lock shared by three conditions, as in queue.Queue.
        self._lock = threading.Lock()
        self._not_empty = threading.Condition(self._lock)
        self._not_full = threading.Condition(self._lock)
        self._all_done = threading.Condition(self._lock)
        self._sequence = 0
        self._unfinished_tasks = 0

    def put(self, task: Any, priority: int = 0, block: bool = True,
            timeout: Optional[float] = None) -> None:
        """Add a task."""
        with self._not_full:
            if self._maxsize > 0:
                if not block:
                    if len(self._queue) >= self._maxsize:
                        raise queue.Full()
                elif not self._not_full.wait_for(
                        lambda: len(self._queue) < self._maxsize, timeout):
                    raise queue.Full()
            self._sequence += 1
            heapq.heappush(self._queue, PriorityTask(priority, self._sequence, task))
            self._unfinished_tasks += 1
            self._not_empty.notify()

    def get(self, block: bool = True, timeout: Optional[float] = None) -> Any:
        """Remove and return the highest-priority task."""
        with self._not_empty:
            if not block:
                if not self._queue:
                    raise queue.Empty()
            elif not self._not_empty.wait_for(lambda: bool(self._queue), timeout):
                raise queue.Empty()
            priority_task = heapq.heappop(self._queue)
            self._not_full.notify()   # wake a producer blocked on a full queue
            return priority_task.task

    def task_done(self) -> None:
        with self._all_done:
            if self._unfinished_tasks <= 0:
                raise ValueError("task_done() called too many times")
            self._unfinished_tasks -= 1
            if self._unfinished_tasks == 0:
                self._all_done.notify_all()

    def join(self) -> None:
        with self._all_done:
            while self._unfinished_tasks:
                self._all_done.wait()

    def qsize(self) -> int:
        with self._lock:
            return len(self._queue)

    def empty(self) -> bool:
        with self._lock:
            return len(self._queue) == 0


class DelayedTaskQueue:
    """
    Delayed task queue: a task can be retrieved only after its delay has elapsed.
    """

    def __init__(self):
        self._queue: List[Tuple[float, int, Any]] = []
        self._lock = threading.Lock()
        self._condition = threading.Condition(self._lock)
        self._sequence = 0

    def put(self, task: Any, delay_seconds: float = 0) -> None:
        """Add a delayed task."""
        execute_time = time.time() + delay_seconds
        with self._condition:
            self._sequence += 1
            heapq.heappush(self._queue, (execute_time, self._sequence, task))
            self._condition.notify()

    def get(self, timeout: Optional[float] = None) -> Any:
        """Return the next due task; raise queue.Empty if none is due in time."""
        deadline = None if timeout is None else time.monotonic() + timeout
        with self._condition:
            while True:
                now = time.time()
                if self._queue and self._queue[0][0] <= now:
                    return heapq.heappop(self._queue)[2]
                # Sleep until the earliest task is due, a new task arrives,
                # or the caller's deadline passes, whichever comes first.
                wait_time = self._queue[0][0] - now if self._queue else None
                if deadline is not None:
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        raise queue.Empty()
                    wait_time = remaining if wait_time is None else min(wait_time, remaining)
                self._condition.wait(wait_time)

    def qsize(self) -> int:
        with self._lock:
            return len(self._queue)


priority_queue = PriorityTaskQueue(maxsize=10)
priority_queue.put("low-priority task", priority=10)
priority_queue.put("high-priority task", priority=1)
priority_queue.put("medium-priority task", priority=5)
print("Priority queue processing order:")
while not priority_queue.empty():
    task = priority_queue.get()
    print(f"  processing: {task}")
    priority_queue.task_done()

delayed_queue = DelayedTaskQueue()
delayed_queue.put("immediate task", delay_seconds=0)
delayed_queue.put("task due in 2 seconds", delay_seconds=2)
delayed_queue.put("task due in 1 second", delay_seconds=1)
print("\nDelayed queue processing order:")
for _ in range(3):
    task = delayed_queue.get()
    print(f"  processing: {task}")
25.6 Thread Pool Pattern
25.6.1 Custom Thread Pool
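Before building a pool by hand, it helps to see the interface it imitates. The following minimal sketch uses the standard library's `concurrent.futures.ThreadPoolExecutor`; the function `square` and the variable names are illustrative. It shows the two submission styles, `map` and `submit`, and how an exception raised inside a task is captured by its future rather than crashing the worker.

```python
from concurrent.futures import ThreadPoolExecutor


def square(n):
    return n * n


# The with-block calls shutdown(wait=True) on exit.
with ThreadPoolExecutor(max_workers=4, thread_name_prefix="Worker") as pool:
    # map() yields results in input order, whatever the completion order.
    mapped = list(pool.map(square, range(5)))

    # submit() returns a Future; a failure inside the task is stored on it.
    future = pool.submit(divmod, 1, 0)
    error = future.exception()    # blocks until the task has finished

print(mapped, type(error).__name__)
```

The custom pool in this section follows the same shape: tasks travel through a queue, workers store the outcome on a `Future`, and shutdown drains the workers.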
python
import threading
import queue
from typing import Callable, Any, List, Optional, Dict
from dataclasses import dataclass, field
from concurrent.futures import Future
from enum import Enum, auto
import time
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class PoolState(Enum):
    RUNNING = auto()
    SHUTDOWN = auto()
    TERMINATED = auto()


@dataclass
class WorkerStats:
    tasks_completed: int = 0
    tasks_failed: int = 0
    total_time: float = 0.0


class Worker(threading.Thread):
    """Worker thread."""

    def __init__(self, pool: 'ThreadPool', name: str):
        super().__init__(name=name, daemon=True)
        self._pool = pool
        self._stats = WorkerStats()
        self._current_task = None

    def run(self) -> None:
        while True:
            task = self._pool._get_task()
            if task is None:   # sentinel: the pool is shutting down
                break
            self._current_task = task
            self._execute_task(task)
            self._current_task = None

    def _execute_task(self, task: 'ThreadPoolTask') -> None:
        # Skip tasks whose future was cancelled before they started.
        if not task.future.set_running_or_notify_cancel():
            return
        start_time = time.time()
        try:
            result = task.func(*task.args, **task.kwargs)
            task.future.set_result(result)
            self._stats.tasks_completed += 1
        except Exception as e:
            task.future.set_exception(e)
            self._stats.tasks_failed += 1
            logger.error(f"Task execution failed: {e}")
        finally:
            self._stats.total_time += time.time() - start_time

    def get_stats(self) -> WorkerStats:
        return self._stats


@dataclass
class ThreadPoolTask:
    func: Callable
    args: tuple
    kwargs: dict
    future: Future


class ThreadPool:
    """
    Thread pool: manages and reuses worker threads.
    """

    def __init__(
        self,
        max_workers: int = 4,
        queue_size: int = 100,
        thread_name_prefix: str = "Worker"
    ):
        self._max_workers = max_workers
        self._queue_size = queue_size
        self._thread_name_prefix = thread_name_prefix
        self._task_queue: queue.Queue[Optional[ThreadPoolTask]] = queue.Queue(maxsize=queue_size)
        self._workers: List[Worker] = []
        self._state = PoolState.RUNNING
        self._lock = threading.Lock()
        self._submitted_count = 0

    def start(self) -> None:
        """Start the worker threads."""
        with self._lock:
            if self._state != PoolState.RUNNING:
                return
            for i in range(self._max_workers):
                worker = Worker(self, f"{self._thread_name_prefix}-{i+1}")
                worker.start()
                self._workers.append(worker)
        logger.info(f"Thread pool started with {self._max_workers} workers")

    def submit(self, func: Callable, *args, **kwargs) -> Future:
        """Submit a task and return its Future."""
        with self._lock:
            if self._state != PoolState.RUNNING:
                raise RuntimeError("Thread pool is shut down")
            future = Future()
            task = ThreadPoolTask(func, args, kwargs, future)
            # put() runs under the lock so that no task can be enqueued after
            # shutdown() has changed the state. If the queue is full, this
            # blocks other callers until a worker takes a task.
            self._task_queue.put(task)
            self._submitted_count += 1
            return future

    def map(self, func: Callable, iterable: List, timeout: float = None) -> List:
        """Submit one task per item and wait for all results."""
        futures = [self.submit(func, item) for item in iterable]
        return [f.result(timeout) for f in futures]

    def shutdown(self, wait: bool = True) -> None:
        """Shut down the pool after the queued tasks have run."""
        with self._lock:
            if self._state == PoolState.TERMINATED:
                return
            self._state = PoolState.SHUTDOWN
        # One sentinel per worker, enqueued behind all pending tasks.
        for _ in self._workers:
            self._task_queue.put(None)
        if wait:
            for worker in self._workers:
                worker.join()
        with self._lock:
            self._state = PoolState.TERMINATED
        logger.info("Thread pool shut down")

    def _get_task(self) -> Optional[ThreadPoolTask]:
        return self._task_queue.get()

    def get_stats(self) -> Dict:
        """Return pool statistics."""
        with self._lock:
            worker_stats = [w.get_stats() for w in self._workers]
            return {
                'state': self._state.name,
                'max_workers': self._max_workers,
                'active_workers': sum(1 for w in self._workers if w.is_alive()),
                'queue_size': self._task_queue.qsize(),
                'submitted': self._submitted_count,
                'completed': sum(s.tasks_completed for s in worker_stats),
                'failed': sum(s.tasks_failed for s in worker_stats),
'total_time': sum(s.total_time for s in worker_stats)
}
def compute_square(n: int) -> int:
time.sleep(0.1)
return n ** 2
pool = ThreadPool(max_workers=3, queue_size=10)
pool.start()
futures = [pool.submit(compute_square, i) for i in range(6)]
for i, future in enumerate(futures):
print(f"任务 {i}: {future.result()}")
print(f"\n线程池统计: {pool.get_stats()}")
pool.shutdown()25.6.2 使用concurrent.futures
```python
import os
import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from typing import Any, Callable, List, Optional


class ExecutorManager:
    """
    Executor manager: a thin wrapper around concurrent.futures.
    """

    def __init__(self, max_workers: Optional[int] = None, use_process: bool = False):
        self._max_workers = max_workers or os.cpu_count()
        self._use_process = use_process
        self._executor = None

    def __enter__(self):
        if self._use_process:
            self._executor = ProcessPoolExecutor(max_workers=self._max_workers)
        else:
            self._executor = ThreadPoolExecutor(max_workers=self._max_workers)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._executor:
            self._executor.shutdown(wait=True)

    def submit(self, func: Callable, *args, **kwargs):
        return self._executor.submit(func, *args, **kwargs)

    def map(self, func: Callable, iterable: List, timeout: Optional[float] = None) -> List:
        return list(self._executor.map(func, iterable, timeout=timeout))

    def map_with_progress(
        self,
        func: Callable,
        iterable: List,
        callback: Optional[Callable[[int, int, Any], None]] = None
    ) -> List[Any]:
        """Run a batch and report progress as each task completes."""
        # Map each future back to its input index so results keep input order.
        futures = {self._executor.submit(func, item): i for i, item in enumerate(iterable)}
        results = [None] * len(iterable)
        completed = 0
        for future in as_completed(futures):
            index = futures[future]
            results[index] = future.result()
            completed += 1
            if callback:
                callback(completed, len(iterable), results[index])
        return results

    def submit_with_timeout(
        self,
        func: Callable,
        args: tuple = (),
        kwargs: Optional[dict] = None,
        timeout: Optional[float] = None
    ) -> Any:
        """Submit a task and wait for its result, up to `timeout` seconds."""
        kwargs = kwargs or {}
        future = self._executor.submit(func, *args, **kwargs)
        return future.result(timeout=timeout)


def cpu_intensive_task(n: int) -> int:
    """CPU-bound task."""
    result = 0
    for i in range(n):
        result += i ** 2
    return result


def io_intensive_task(url: str) -> dict:
    """I/O-bound task (simulated with a sleep)."""
    time.sleep(0.1)
    return {'url': url, 'status': 'success'}


# The __main__ guard is required for ProcessPoolExecutor on platforms that
# start worker processes with "spawn" (Windows, macOS): each worker re-imports
# this module and must not re-run the demo.
if __name__ == "__main__":
    print("=== Thread pool running I/O-bound tasks ===")
    with ExecutorManager(max_workers=4) as executor:
        urls = [f"http://example.com/{i}" for i in range(10)]

        def progress_callback(completed, total, result):
            print(f"Progress: {completed}/{total}")

        results = executor.map_with_progress(io_intensive_task, urls, progress_callback)
        print(f"Completed {len(results)} tasks")

    print("\n=== Process pool running CPU-bound tasks ===")
    with ExecutorManager(max_workers=4, use_process=True) as executor:
        numbers = [100000, 200000, 300000, 400000]
        results = executor.map(cpu_intensive_task, numbers)
        print(f"Results: {results}")
```
25.7 Asynchronous Programming Patterns
25.7.1 Async Context Managers
```python
import asyncio
from collections import deque
from contextlib import asynccontextmanager
from typing import AsyncIterator, Deque


class AsyncResource:
    """An asynchronous resource."""

    def __init__(self, name: str):
        self.name = name
        self._acquired = False

    async def acquire(self) -> None:
        print(f"Acquiring resource: {self.name}")
        await asyncio.sleep(0.1)
        self._acquired = True

    async def release(self) -> None:
        print(f"Releasing resource: {self.name}")
        await asyncio.sleep(0.05)
        self._acquired = False

    async def use(self) -> str:
        if not self._acquired:
            raise RuntimeError("Resource not acquired")
        return f"Using resource: {self.name}"


class AsyncLock:
    """Asynchronous lock with FIFO hand-off (illustrative; prefer asyncio.Lock in real code)."""

    def __init__(self):
        self._locked = False
        self._waiters: Deque[asyncio.Future] = deque()

    async def acquire(self) -> None:
        if not self._locked:
            self._locked = True
            return
        future = asyncio.get_running_loop().create_future()
        self._waiters.append(future)
        await future  # resumed by release(), which hands the lock over

    async def release(self) -> None:
        # Hand the lock to the next waiter that has not been cancelled;
        # _locked stays True during the hand-off.
        while self._waiters:
            future = self._waiters.popleft()
            if not future.done():
                future.set_result(None)
                return
        self._locked = False

    async def __aenter__(self):
        await self.acquire()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.release()


class AsyncSemaphore:
    """Asynchronous semaphore built on asyncio.Condition."""

    def __init__(self, permits: int):
        self._permits = permits
        self._lock = asyncio.Lock()
        self._condition = asyncio.Condition(self._lock)

    async def acquire(self) -> None:
        async with self._condition:
            while self._permits <= 0:
                await self._condition.wait()
            self._permits -= 1

    async def release(self) -> None:
        async with self._condition:
            self._permits += 1
            self._condition.notify()

    async def __aenter__(self):
        await self.acquire()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.release()


@asynccontextmanager
async def async_resource(name: str) -> AsyncIterator[AsyncResource]:
    """Async context manager that acquires the resource and always releases it."""
    resource = AsyncResource(name)
    await resource.acquire()
    try:
        yield resource
    finally:
        await resource.release()


async def async_resource_example():
    async with async_resource("database connection") as conn:
        result = await conn.use()
        print(result)


asyncio.run(async_resource_example())
```
25.7.2 Async Iterators and Generators
```python
import asyncio
import random
import time
from typing import Any, AsyncGenerator, List


class AsyncRange:
    """Asynchronous range iterator."""

    def __init__(self, start: int, end: int, delay: float = 0.1):
        self._start = start
        self._end = end
        self._delay = delay

    def __aiter__(self) -> 'AsyncRange':
        self._current = self._start
        return self

    async def __anext__(self) -> int:
        if self._current >= self._end:
            raise StopAsyncIteration
        await asyncio.sleep(self._delay)
        value = self._current
        self._current += 1
        return value


async def async_data_stream(count: int, delay: float = 0.1) -> AsyncGenerator[dict, None]:
    """Asynchronous generator that yields a stream of records."""
    for i in range(count):
        await asyncio.sleep(delay)
        yield {
            'id': i,
            'value': random.random() * 100,
            'timestamp': asyncio.get_running_loop().time()
        }


class AsyncBatchProcessor:
    """Asynchronous batch processor."""

    def __init__(self, batch_size: int = 10, flush_interval: float = 1.0):
        self._batch_size = batch_size
        self._flush_interval = flush_interval
        self._batch: List[Any] = []
        self._lock = asyncio.Lock()
        self._last_flush = time.monotonic()

    async def add(self, item: Any) -> None:
        async with self._lock:
            self._batch.append(item)
            # Flush when the batch is full or the flush interval has elapsed.
            interval_elapsed = time.monotonic() - self._last_flush >= self._flush_interval
            if len(self._batch) >= self._batch_size or interval_elapsed:
                await self._flush()

    async def _flush(self) -> None:
        """Process the pending batch. The caller must hold the lock."""
        self._last_flush = time.monotonic()
        if not self._batch:
            return
        batch = self._batch.copy()
        self._batch.clear()
        print(f"Batch processing {len(batch)} records")
        await asyncio.sleep(0.1)

    async def close(self) -> None:
        async with self._lock:
            await self._flush()


async def process_async_stream():
    """Consume the asynchronous streams defined above."""
    print("=== Async iterator ===")
    result: List[int] = []
    async for num in AsyncRange(0, 5):
        print(f"Processing: {num}")
        result.append(num)
    print(f"Result: {result}")

    print("\n=== Async generator ===")
    async for data in async_data_stream(5):
        print(f"Received: id={data['id']}, value={data['value']:.2f}")

    print("\n=== Async batch processing ===")
    processor = AsyncBatchProcessor(batch_size=3)
    async for data in async_data_stream(10, delay=0.05):
        await processor.add(data)
    await processor.close()


asyncio.run(process_async_stream())
```
25.7.3 Async Task Orchestration
```python
import asyncio
import random
import time
from dataclasses import dataclass
from typing import Any, Callable, Coroutine, List, Optional


@dataclass
class TaskResult:
    index: int
    result: Any
    error: Optional[Exception] = None
    duration: float = 0.0


class AsyncTaskOrchestrator:
    """Asynchronous task orchestrator."""

    def __init__(self):
        self._tasks: List[asyncio.Task] = []

    async def gather_with_exceptions(
        self,
        coroutines: List[Coroutine],
        return_exceptions: bool = True
    ) -> List[TaskResult]:
        """Run all coroutines concurrently and capture each exception in its TaskResult."""

        async def run_with_timing(index: int, coro: Coroutine) -> TaskResult:
            start = time.time()
            try:
                result = await coro
                return TaskResult(index=index, result=result, duration=time.time() - start)
            except Exception as e:
                return TaskResult(index=index, result=None, error=e, duration=time.time() - start)

        tasks = [run_with_timing(i, coro) for i, coro in enumerate(coroutines)]
        results = await asyncio.gather(*tasks, return_exceptions=return_exceptions)
        return results

    async def run_with_timeout(
        self,
        coro: Coroutine,
        timeout: float
    ) -> TaskResult:
        """Run a coroutine with a timeout."""
        start = time.time()
        try:
            result = await asyncio.wait_for(coro, timeout=timeout)
            return TaskResult(index=0, result=result, duration=time.time() - start)
        except asyncio.TimeoutError:
            return TaskResult(index=0, result=None, error=asyncio.TimeoutError(), duration=timeout)
        except Exception as e:
            return TaskResult(index=0, result=None, error=e, duration=time.time() - start)

    async def run_with_retry(
        self,
        coro_factory: Callable[[], Coroutine],
        max_retries: int = 3,
        delay: float = 1.0
    ) -> TaskResult:
        """Run a task with retries. A factory is needed because a coroutine can be awaited only once."""
        last_error = None
        for attempt in range(max_retries + 1):
            start = time.time()
            try:
                result = await coro_factory()
                return TaskResult(index=attempt, result=result, duration=time.time() - start)
            except Exception as e:
                last_error = e
                if attempt < max_retries:
                    await asyncio.sleep(delay)
        return TaskResult(index=max_retries, result=None, error=last_error)

    async def run_pipeline(
        self,
        initial_value: Any,
        stages: List[Callable[[Any], Coroutine]]
    ) -> Any:
        """Run stages sequentially, feeding each stage's output into the next."""
        value = initial_value
        for i, stage in enumerate(stages):
            value = await stage(value)
            print(f"Stage {i+1} done: {value}")
        return value

    async def run_parallel_with_limit(
        self,
        coroutines: List[Coroutine],
        max_concurrent: int
    ) -> List[TaskResult]:
        """Run coroutines concurrently with at most `max_concurrent` in flight."""
        semaphore = asyncio.Semaphore(max_concurrent)

        async def run_with_semaphore(index: int, coro: Coroutine) -> TaskResult:
            async with semaphore:
                start = time.time()
                try:
                    result = await coro
                    return TaskResult(index=index, result=result, duration=time.time() - start)
                except Exception as e:
                    return TaskResult(index=index, result=None, error=e, duration=time.time() - start)

        tasks = [run_with_semaphore(i, coro) for i, coro in enumerate(coroutines)]
        results = await asyncio.gather(*tasks)
        return results


async def simulate_api_call(call_id: int, delay: float = 0.5) -> dict:
    await asyncio.sleep(delay)
    if random.random() < 0.2:  # fail roughly 20% of the time
        raise ValueError(f"API call {call_id} failed")
    return {'id': call_id, 'status': 'success'}


async def orchestrator_example():
    orchestrator = AsyncTaskOrchestrator()

    print("=== Concurrent execution (with exception handling) ===")
    coroutines = [simulate_api_call(i, random.uniform(0.1, 0.5)) for i in range(5)]
    results = await orchestrator.gather_with_exceptions(coroutines)
    for r in results:
        if r.error:
            print(f"Task {r.index}: error - {r.error}")
        else:
            print(f"Task {r.index}: {r.result}")

    print("\n=== Execution with timeout ===")
    result = await orchestrator.run_with_timeout(simulate_api_call(1, 2.0), timeout=1.0)
    print(f"Timeout result: {result}")

    print("\n=== Pipeline execution ===")

    async def stage1(x):
        return x * 2

    async def stage2(x):
        return x + 10

    async def stage3(x):
        return f"Final result: {x}"

    final = await orchestrator.run_pipeline(5, [stage1, stage2, stage3])
    print(f"Pipeline result: {final}")


asyncio.run(orchestrator_example())
```
25.8 Actor Pattern
```python
import itertools
import logging
import queue
import threading
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, Optional

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Sequence number used as a tie-breaker in the mailbox: messages with equal
# priority stay in FIFO order and Message objects are never compared.
_sequence = itertools.count()


class MessagePriority(Enum):
    LOW = 0
    NORMAL = 1
    HIGH = 2
    URGENT = 3


@dataclass
class Message:
    """Actor message."""
    sender: str
    content: Any
    priority: MessagePriority = MessagePriority.NORMAL
    reply_to: Optional['Actor'] = None
    timestamp: float = field(default_factory=time.time)


class Actor:
    """
    Actor base class: a concurrency model based on message passing.
    Properties:
    - Isolated state: each actor maintains its own state
    - Message-driven: actors communicate through messages
    - Single-threaded processing: each actor handles its messages serially
    """

    def __init__(self, name: str):
        self.name = name
        self._mailbox: queue.PriorityQueue = queue.PriorityQueue()
        self._running = False
        self._thread: Optional[threading.Thread] = None
        self._handlers: Dict[str, Callable] = {}
        self._state: Dict[str, Any] = {}
        self._message_count = 0

    def on(self, message_type: str, handler: Callable) -> None:
        """Register a message handler."""
        self._handlers[message_type] = handler

    def start(self) -> None:
        """Start the actor."""
        self._running = True
        self._thread = threading.Thread(target=self._run, name=f"Actor-{self.name}", daemon=True)
        self._thread.start()
        logger.info(f"Actor {self.name} started")

    def stop(self) -> None:
        """Stop the actor. Messages still in the mailbox are not processed."""
        self._running = False
        if self._thread:
            self._thread.join(timeout=2)
        logger.info(f"Actor {self.name} stopped, messages handled: {self._message_count}")

    def send(self, recipient: 'Actor', content: Any, priority: MessagePriority = MessagePriority.NORMAL) -> None:
        """Send a message to `recipient`."""
        message = Message(
            sender=self.name,
            content=content,
            priority=priority
        )
        recipient._deliver(message)

    def send_and_wait(self, recipient: 'Actor', content: Any, timeout: float = 5.0) -> Any:
        """Send a message and block until the reply arrives (raises queue.Empty on timeout)."""
        reply_queue: queue.Queue = queue.Queue()
        message = Message(
            sender=self.name,
            content=content,
            reply_to=ReplyActor(f"Reply-{self.name}", reply_queue)
        )
        recipient._deliver(message)
        return reply_queue.get(timeout=timeout)

    def _deliver(self, message: Message) -> None:
        """Put a message in the mailbox. PriorityQueue pops the smallest entry
        first, so the priority is negated to make URGENT come out before LOW."""
        self._mailbox.put((-message.priority.value, next(_sequence), message))

    def _run(self) -> None:
        """Actor main loop."""
        while self._running:
            try:
                _, _, message = self._mailbox.get(timeout=1)
                self._handle(message)
                self._message_count += 1
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"Actor {self.name} failed to handle message: {e}")

    def _handle(self, message: Message) -> None:
        """Dispatch a message to its handler and send the result to reply_to, if set."""
        content = message.content
        msg_type = content.get('type') if isinstance(content, dict) else None
        if msg_type in self._handlers:
            result = self._handlers[msg_type](message)
            if message.reply_to is not None:
                # Reply from this actor to the reply target, not the reverse.
                self.send(message.reply_to, {'type': 'reply', 'result': result})
        elif callable(self._handlers.get('default')):
            self._handlers['default'](message)


class ReplyActor(Actor):
    """Reply actor: lets a caller wait synchronously for a reply."""

    def __init__(self, name: str, reply_queue: queue.Queue):
        super().__init__(name)
        self._reply_queue = reply_queue

    def _deliver(self, message: Message) -> None:
        # No thread is started for a reply actor; the reply goes straight
        # to the queue that send_and_wait() is blocked on.
        self._reply_queue.put(message.content.get('result'))


class CounterActor(Actor):
    """Example counter actor."""

    def __init__(self, name: str):
        super().__init__(name)
        self._count = 0
        self.on('increment', self._increment)
        self.on('decrement', self._decrement)
        self.on('get', self._get_count)
        self.on('reset', self._reset)

    def _increment(self, message: Message) -> int:
        amount = message.content.get('amount', 1)
        self._count += amount
        logger.info(f"{self.name}: count increased to {self._count}")
        return self._count

    def _decrement(self, message: Message) -> int:
        amount = message.content.get('amount', 1)
        self._count -= amount
        logger.info(f"{self.name}: count decreased to {self._count}")
        return self._count

    def _get_count(self, message: Message) -> int:
        logger.info(f"{self.name}: current count is {self._count}")
        return self._count

    def _reset(self, message: Message) -> int:
        self._count = 0
        logger.info(f"{self.name}: count reset to 0")
        return self._count


class ActorSystem:
    """Actor system: manages multiple actors."""

    def __init__(self):
        self._actors: Dict[str, Actor] = {}
        self._lock = threading.Lock()

    def create_actor(self, actor_class: type, name: str, *args, **kwargs) -> Actor:
        """Create and start an actor."""
        with self._lock:
            if name in self._actors:
                raise ValueError(f"Actor {name} already exists")
            actor = actor_class(name, *args, **kwargs)
            self._actors[name] = actor
        actor.start()
        return actor

    def get_actor(self, name: str) -> Optional[Actor]:
        """Look up an actor by name."""
        with self._lock:
            return self._actors.get(name)

    def stop_all(self) -> None:
        """Stop all actors."""
        with self._lock:
            actors = list(self._actors.values())
            self._actors.clear()
        for actor in actors:
            actor.stop()

    def broadcast(self, content: Any) -> None:
        """Deliver a message to every actor (each actor appears as its own sender)."""
        with self._lock:
            actors = list(self._actors.values())
        for actor in actors:
            actor.send(actor, content)


system = ActorSystem()
counter1 = system.create_actor(CounterActor, "Counter1")
counter2 = system.create_actor(CounterActor, "Counter2")
counter1.send(counter1, {'type': 'increment', 'amount': 5})
counter1.send(counter1, {'type': 'increment', 'amount': 3})
counter1.send(counter1, {'type': 'get'})
counter2.send(counter2, {'type': 'increment', 'amount': 10})
counter2.send(counter2, {'type': 'decrement', 'amount': 2})
counter2.send(counter2, {'type': 'get'})
# Synchronous request/reply from the main thread.
print(f"Counter1 reply: {counter1.send_and_wait(counter1, {'type': 'get'})}")
time.sleep(1)
system.stop_all()
```
25.9 Anti-Patterns and Best Practices
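25.9.1 Common Anti-Patterns
```python
import threading
import time


class ConcurrencyAntiPatterns:
    """Examples of concurrency anti-patterns."""

    @staticmethod
    def deadlock_example():
        """
        Anti-pattern: deadlock.
        Two threads each wait for a lock held by the other.
        """
        lock1 = threading.Lock()
        lock2 = threading.Lock()

        def thread1():
            with lock1:
                time.sleep(0.1)
                with lock2:  # acquires lock1 then lock2
                    print("Thread 1 done")

        def thread2():
            with lock2:
                time.sleep(0.1)
                with lock1:  # acquires lock2 then lock1: opposite order
                    print("Thread 2 done")

        # The two functions are intentionally never started: running them
        # together would hang.
        print("Deadlock example (actual execution skipped)")

    @staticmethod
    def race_condition_example():
        """
        Anti-pattern: race condition.
        Multiple threads update a shared variable without a lock.
        """
        counter = 0

        def increment():
            nonlocal counter
            for _ in range(10000):
                # Read-modify-write is not atomic: another thread can run
                # between the read and the write, and its update is lost.
                temp = counter
                temp += 1
                counter = temp

        threads = [threading.Thread(target=increment) for _ in range(5)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        # The actual value varies from run to run and may occasionally be 50000.
        print(f"Race condition example: expected 50000, actual {counter}")

    @staticmethod
    def correct_counter():
        """
        Correct approach: protect the shared variable with a lock.
        """
        counter = 0
        lock = threading.Lock()

        def increment():
            nonlocal counter
            for _ in range(10000):
                with lock:
                    counter += 1

        threads = [threading.Thread(target=increment) for _ in range(5)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        print(f"Correct counter: {counter}")


ConcurrencyAntiPatterns.race_condition_example()
ConcurrencyAntiPatterns.correct_counter()
```
25.9.2 Best Practices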
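One common remedy for the deadlock anti-pattern in 25.9.1 is to make every thread acquire locks in the same global order. The following is a minimal sketch under that assumption; the helper names `ordered` and `use_both` are hypothetical, and sorting by `id()` is just one convenient way to get a consistent order within a single process.

```python
import threading

lock_a = threading.Lock()
lock_b = threading.Lock()


def ordered(*locks):
    """Sort locks by id() so that every thread acquires them in the same order."""
    return sorted(locks, key=id)


def use_both(first, second, label, done):
    # The caller's argument order does not matter: acquisition order is global.
    low, high = ordered(first, second)
    with low:
        with high:
            done.append(label)


done = []
# The two threads name the locks in opposite orders, as in the deadlock example.
t1 = threading.Thread(target=use_both, args=(lock_a, lock_b, "thread-1", done))
t2 = threading.Thread(target=use_both, args=(lock_b, lock_a, "thread-2", done))
for t in (t1, t2):
    t.start()
for t in (t1, t2):
    t.join(timeout=5)

print(sorted(done))
```

Because both threads always take the lower-ranked lock first, neither can hold one lock while waiting for a thread that holds the other, so the circular wait cannot form.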
```python
import logging
import threading
from contextlib import contextmanager


class ConcurrencyBestPractices:
    """Concurrency best practices."""

    @staticmethod
    @contextmanager
    def timeout_lock(lock: threading.Lock, timeout: float):
        """Acquire a lock with a timeout instead of blocking indefinitely."""
        acquired = lock.acquire(timeout=timeout)
        if not acquired:
            raise TimeoutError("Timed out acquiring lock")
        try:
            yield
        finally:
            lock.release()

    @staticmethod
    def safe_shared_counter():
        """Return a thread-safe counter."""
        class SafeCounter:
            def __init__(self, initial: int = 0):
                self._value = initial
                self._lock = threading.Lock()

            def increment(self, delta: int = 1) -> int:
                with self._lock:
                    self._value += delta
                    return self._value

            def decrement(self, delta: int = 1) -> int:
                with self._lock:
                    self._value -= delta
                    return self._value

            def get(self) -> int:
                with self._lock:
                    return self._value

            def set(self, value: int) -> None:
                with self._lock:
                    self._value = value

        return SafeCounter()

    @staticmethod
    def graceful_shutdown_example():
        """Graceful shutdown example."""
        class Service:
            def __init__(self):
                self._thread = None
                self._stop_event = threading.Event()

            def start(self):
                self._stop_event.clear()
                self._thread = threading.Thread(target=self._run, daemon=True)
                self._thread.start()

            def stop(self, timeout: float = 5.0):
                self._stop_event.set()
                if self._thread:
                    self._thread.join(timeout=timeout)

            def _run(self):
                while not self._stop_event.is_set():
                    try:
                        # Placeholder for one unit of work. wait() acts as an
                        # interruptible sleep, so the loop does not spin and
                        # stop() takes effect promptly.
                        self._stop_event.wait(timeout=0.1)
                    except Exception:
                        # Log instead of silently swallowing the error.
                        logging.exception("Service loop error")

        return Service()


print("=== Best practices example ===")
counter = ConcurrencyBestPractices.safe_shared_counter()


def worker(counter, worker_id):
    for _ in range(1000):
        counter.increment()
    print(f"Worker {worker_id} done")


threads = [threading.Thread(target=worker, args=(counter, i)) for i in range(5)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(f"Final count: {counter.get()}")
```
25.10 Summary
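The core problem that concurrency design patterns address is how to coordinate task execution safely and efficiently in a multi-threaded or multi-process environment.
Key Points
- Thread safety: ensured through immutable objects, locks, and thread-local storage
- Synchronization: threads are coordinated through barriers, semaphores, and read-write locks
- Task distribution: work is handed out through the producer-consumer and thread pool patterns
- Asynchronous processing: non-blocking I/O is implemented with async/await and async generators
- Actor model: shared-nothing concurrency is achieved through message passing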
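The first two points can be combined in a small sketch: an immutable configuration object that is safe to share, per-thread scratch state in `threading.local`, and a `threading.Barrier` that makes every worker wait until all have finished. The names `Config`, `shared_config`, and `worker` are hypothetical and chosen only for illustration.

```python
import threading
from dataclasses import dataclass


@dataclass(frozen=True)
class Config:
    """Immutable settings: safe to read from any thread without a lock."""
    name: str
    rounds: int


shared_config = Config(name="demo", rounds=3)
local_state = threading.local()   # each thread sees its own attributes
barrier = threading.Barrier(3)    # releases once 3 threads have arrived
totals = {}
totals_lock = threading.Lock()


def worker(worker_id):
    local_state.count = 0         # private to this thread
    for _ in range(shared_config.rounds):
        local_state.count += worker_id
    barrier.wait(timeout=5)       # wait until every worker has finished counting
    with totals_lock:             # the shared dict is still guarded by a lock
        totals[worker_id] = local_state.count


threads = [threading.Thread(target=worker, args=(i,)) for i in range(1, 4)]
for t in threads:
    t.start()
for t in threads:
    t.join(timeout=5)

print(totals)
```

Only `totals` is mutable shared state, so it is the only object that needs a lock; the frozen dataclass and the thread-local counter need none.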
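Practical Recommendations
- Prefer immutable objects and avoid shared state
- Use high-level synchronization primitives (such as queues and semaphores) rather than low-level locks
- Choose the concurrency model that fits the workload (threads, processes, or coroutines)
- Set reasonable timeouts on blocking operations so that a deadlock or stalled task surfaces as an error instead of hanging forever
- Implement a graceful shutdown mechanism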
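Several of these recommendations fit in one short sketch: a `queue.Queue` replaces explicit locking, every blocking call has a timeout, and a sentinel value shuts the consumer down gracefully. The names `STOP` and `consumer` are assumptions made for this example, and the doubling step merely stands in for real work.

```python
import queue
import threading

STOP = object()  # sentinel that tells the consumer to exit


def consumer(tasks, results, stop_event):
    while not stop_event.is_set():
        try:
            item = tasks.get(timeout=0.5)   # bounded wait, never blocks forever
        except queue.Empty:
            continue                        # re-check the stop flag
        if item is STOP:
            break
        results.append(item * 2)            # stand-in for real work


tasks = queue.Queue(maxsize=8)              # bounded queue applies backpressure
results = []
stop_event = threading.Event()
worker = threading.Thread(target=consumer, args=(tasks, results, stop_event))
worker.start()

for n in range(5):
    tasks.put(n, timeout=1.0)               # raises queue.Full if the consumer stalls
tasks.put(STOP, timeout=1.0)

worker.join(timeout=5.0)
if worker.is_alive():                       # fallback if the sentinel was not consumed
    stop_event.set()

print(results)
```

The sentinel lets the consumer finish everything queued before it exits, while the event gives the producer a fallback way to request a stop if the consumer is stuck.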
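Next Chapter Preview
The next chapter introduces architectural design patterns and discusses how to design scalable, maintainable software architectures.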