
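Chapter 25 Concurrency Design Patterns

Learning Objectives

After completing this chapter, readers will be able to:

  • Understand the core concepts and formal definitions of concurrent programming
  • Master the design principles and synchronization mechanisms behind thread safety
  • Apply concurrency patterns such as producer-consumer, read-write locks, and Actors
  • Implement efficient asynchronous programming and coroutine systems
  • Recognize common pitfalls and anti-patterns in concurrent programming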

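25.1 Overview of Concurrent Programming

25.1.1 Core Definition

Concurrent programming is a programming paradigm in which multiple computational tasks execute during overlapping time periods. Its goal is to improve a system's responsiveness and resource utilization.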

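25.1.2 Formal Definition

Formally, a concurrent system can be defined as a five-tuple:

$$\mathcal{C} = \langle P, S, T, \sigma, \phi \rangle$$

where:

  • $P = \{p_1, p_2, \ldots, p_n\}$ is the set of processes/threads
  • $S$ is the shared state space
  • $T$ is the set of transition functions
  • $\sigma: P \times S \rightarrow S$ is the state transition function
  • $\phi: S \rightarrow \{\mathit{safe}, \mathit{unsafe}\}$ is the safety predicate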

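Thread-safety definition

An operation $op$ is thread-safe if and only if, for every concurrent execution order:

$$\forall s \in S, \forall p_i, p_j \in P: op(p_i, op(p_j, s)) = op(p_j, op(p_i, s))$$

Here each application of $op$ is treated as atomic. The condition therefore requires that the resulting state not depend on the order in which the calls by $p_i$ and $p_j$ take effect. This is a simplified, order-independence view. The standard correctness criterion for concurrent data structures is linearizability (see 25.2.2).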
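To make the equation concrete, here is a minimal sketch under two assumptions: $S$ is the integers, and each $op$ is an atomic function on $S$. The helper names `commutes`, `add`, and `assign` are hypothetical and chosen only for illustration. Increments satisfy the equation for every state. Plain overwrites do not: with two concurrent writes, the final state depends on which write takes effect last.

```python
from typing import Callable

State = int                      # assumption: the shared state is a single integer
Op = Callable[[State], State]    # an atomic operation maps a state to a new state


def commutes(op_a: Op, op_b: Op, s: State) -> bool:
    """Check op_a(op_b(s)) == op_b(op_a(s)) for one state s.

    Each operation is assumed atomic; only the two possible orders in which
    whole operations take effect are compared.
    """
    return op_a(op_b(s)) == op_b(op_a(s))


def add(k: int) -> Op:
    return lambda s: s + k       # increment: the order does not matter


def assign(v: int) -> Op:
    return lambda s: v           # overwrite: the last writer wins


states = range(-5, 6)
print(all(commutes(add(1), add(2), s) for s in states))        # True
print(all(commutes(assign(1), assign(2), s) for s in states))  # False
```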

Mutual exclusion condition

Mutually exclusive access to a critical section $CS$ is defined as:

$$\forall t: |\{p \in P : p \text{ in } CS \text{ at } t\}| \leq 1$$

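Liveness condition

Every process that requests the critical section eventually enters it. Here $\Diamond$ is the temporal-logic "eventually" operator: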

$$\forall p \in P: p \text{ requests } CS \Rightarrow \Diamond(p \text{ enters } CS)$$
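Both conditions can be observed, though not proved, by instrumenting a critical section. The sketch below assumes Python's `threading.Lock` as the mutex and uses hypothetical global counters. It records how many threads are inside the critical section at once. Because `inside` is only touched while the lock is held, `max_inside` can never exceed 1.

All `join()` calls returning is consistent with the liveness condition for this run. However, `threading.Lock` does not define which waiting thread acquires it next, so it makes no fairness guarantee.

```python
import threading

lock = threading.Lock()
inside = 0          # number of threads currently in the critical section
max_inside = 0      # largest value of `inside` ever observed
counter = 0         # shared state protected by `lock`


def worker(iterations: int) -> None:
    global inside, max_inside, counter
    for _ in range(iterations):
        with lock:                      # entry protocol: acquire the mutex
            inside += 1
            max_inside = max(max_inside, inside)
            counter += 1                # the protected read-modify-write
            inside -= 1
        # exit protocol: leaving the `with` block releases the mutex


threads = [threading.Thread(target=worker, args=(10_000,)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(counter, max_inside)  # 40000 1
```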

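25.1.3 Comparison of Concurrency Models

| Feature | Multithreading | Multiprocessing | Async I/O | Actor model |
|---|---|---|---|---|
| Memory model | Shared memory | Separate memory | Shared memory | Message passing |
| Communication | Shared variables | IPC | Event loop | Messages |
| Typical use case | CPU-bound work | Isolation requirements | I/O-bound work | Distributed systems |

Overhead: Medium, Medium (incomplete)
Debugging difficulty: Medium, Medium (incomplete)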
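As a rough illustration of the "Multithreading" and "Async I/O" columns, the sketch below runs the same simulated I/O-bound workload both ways. `time.sleep` and `asyncio.sleep` stand in for real blocking and non-blocking I/O (an assumption). The function names are hypothetical.

```python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor


def blocking_fetch(i: int) -> int:
    time.sleep(0.05)            # stands in for a blocking I/O call
    return i * i


async def async_fetch(i: int) -> int:
    await asyncio.sleep(0.05)   # stands in for a non-blocking I/O call
    return i * i


def with_threads(n: int) -> list:
    # Multithreading: each in-flight blocking call occupies one worker thread
    with ThreadPoolExecutor(max_workers=n) as pool:
        return list(pool.map(blocking_fetch, range(n)))


async def with_event_loop(n: int) -> list:
    # Async I/O: one thread; the event loop switches tasks at every await
    return await asyncio.gather(*(async_fetch(i) for i in range(n)))


start = time.perf_counter()
r1 = with_threads(10)
t1 = time.perf_counter() - start

start = time.perf_counter()
r2 = asyncio.run(with_event_loop(10))
t2 = time.perf_counter() - start

print(r1 == r2, f"threads: {t1:.2f}s, asyncio: {t2:.2f}s")
```

Both variants finish in roughly the time of one call rather than ten, because the waits overlap. The thread version uses up to one OS thread per in-flight call. The event-loop version multiplexes all ten waits on a single thread.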

25.1.4 Concurrency vs. Parallelism

┌─────────────────────────────────────────────────────────────────┐
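│                    Concurrency vs. Parallelism                  │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Concurrency:                                                   │
│  - Multiple tasks make progress in overlapping time periods     │
│  - A single core achieves it via time slicing                   │
│  - Concerned with how tasks are structured and scheduled        │
│                                                                 │
│  Parallelism:                                                   │
│  - Multiple tasks execute at the same instant                   │
│  - Requires multiple CPU cores                                  │
│  - Concerned with execution throughput                          │
│                                                                 │
│  Timeline example:                                              │
│                                                                 │
│  Concurrency (single core):                                     │
│  Task A: ████████░░░░░░░░████████░░░░░░░░████████              │
│  Task B: ░░░░░░░░████████░░░░░░░░████████░░░░░░░░              │
│           └─ time slices ─┘                                     │
│                                                                 │
│  Parallelism (multi-core):                                      │
│  Core 1: ████████████████████████████████████████              │
│  Core 2: ████████████████████████████████████████              │
│           └─ simultaneous ─┘                                    │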
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
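The single-core timeline can be reproduced in Python with `asyncio`, where one thread interleaves tasks that voluntarily yield. In this sketch the task names and step counts are arbitrary. `await asyncio.sleep(0)` hands control back to the event loop after every step, so the two tasks alternate while every step runs on the same OS thread. This is concurrency without parallelism.

```python
import asyncio
import threading

log = []


async def task(name: str, steps: int) -> None:
    for i in range(steps):
        # record which task ran, which step, and on which OS thread
        log.append((name, i, threading.get_ident()))
        await asyncio.sleep(0)   # yield control back to the event loop


async def main() -> None:
    await asyncio.gather(task("A", 3), task("B", 3))


asyncio.run(main())
print([f"{name}{i}" for name, i, _ in log])  # ['A0', 'B0', 'A1', 'B1', 'A2', 'B2']
print(len({entry[2] for entry in log}))      # 1: every step ran on the same thread
```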

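25.2 Historical Background and Evolution

25.2.1 Historical Development

| Year | Milestone | Description |
|---|---|---|
| 1965 | Dijkstra's semaphores | Edsger Dijkstra introduced the semaphore mechanism |
| 1971 | Monitors | Hoare and Brinch Hansen proposed the monitor concept |
| 1978 | CSP | Hoare introduced Communicating Sequential Processes |
| 1985 | Actor model | Hewitt's Actor model matured |
| 1995 | Java threads | Java shipped with built-in multithreading support |
| 2000s | Concurrent collections | Java's concurrency package, .NET parallel libraries |
| 2009 | Go goroutines | Lightweight concurrency in the Go language |
| 2010s | async/await | Asynchronous syntax in C#, Python, and JavaScript |
| 2020s | Structured concurrency | The structured concurrency concept became widespread |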

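25.2.2 Theoretical Foundations

The theory of concurrent programming draws on:

  1. Process algebras: CSP, CCS, the π-calculus
  2. Temporal logics: LTL and CTL, used to verify properties of concurrent systems
  3. Petri nets: graphical modeling of concurrent systems
  4. Linearizability: the correctness criterion for concurrent data structures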

25.3 Thread-Safety Patterns

25.3.1 Immutable Object Pattern

python
from dataclasses import dataclass, field
from typing import Any, Dict, FrozenSet, Tuple
from functools import reduce


@dataclass(frozen=True)
class ImmutablePoint:
    """不可变点:天然线程安全"""
    x: float
    y: float
    
    def move(self, dx: float, dy: float) -> 'ImmutablePoint':
        return ImmutablePoint(self.x + dx, self.y + dy)
    
    def distance_to(self, other: 'ImmutablePoint') -> float:
        return ((self.x - other.x) ** 2 + (self.y - other.y) ** 2) ** 0.5


@dataclass(frozen=True)
class ImmutableConfig:
    """不可变配置"""
    host: str
    port: int
    database: str
    options: FrozenSet[str] = field(default_factory=frozenset)
    
    @staticmethod
    def create(host: str, port: int, database: str, *options: str) -> 'ImmutableConfig':
        return ImmutableConfig(host, port, database, frozenset(options))
    
    def with_port(self, new_port: int) -> 'ImmutableConfig':
        return ImmutableConfig(self.host, new_port, self.database, self.options)
    
    def with_option(self, option: str) -> 'ImmutableConfig':
        return ImmutableConfig(self.host, self.port, self.database, 
                               self.options | {option})


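class ImmutableList:
    """
    Immutable list: every operation returns a new instance
    Thread-safe: there is no shared mutable state
    """
    
    def __init__(self, items: Tuple = ()):
        # Defensive copy: a caller's list passed in here cannot change the instance later
        self._items = tuple(items)
        self._hash: int | None = None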
    
    @staticmethod
    def of(*items) -> 'ImmutableList':
        return ImmutableList(tuple(items))
    
    @staticmethod
    def from_list(items: list) -> 'ImmutableList':
        return ImmutableList(tuple(items))
    
    def append(self, item: Any) -> 'ImmutableList':
        return ImmutableList(self._items + (item,))
    
    def prepend(self, item: Any) -> 'ImmutableList':
        return ImmutableList((item,) + self._items)
    
    def remove(self, item: Any) -> 'ImmutableList':
        return ImmutableList(tuple(x for x in self._items if x != item))
    
    def map(self, func: callable) -> 'ImmutableList':
        return ImmutableList(tuple(func(x) for x in self._items))
    
    def filter(self, predicate: callable) -> 'ImmutableList':
        return ImmutableList(tuple(x for x in self._items if predicate(x)))
    
    def reduce(self, func: callable, initial: Any) -> Any:
        return reduce(func, self._items, initial)
    
    def __len__(self) -> int:
        return len(self._items)
    
    def __getitem__(self, index: int) -> Any:
        return self._items[index]
    
    def __iter__(self):
        return iter(self._items)
    
    def __hash__(self) -> int:
        if self._hash is None:
            self._hash = hash(self._items)
        return self._hash
    
    def __eq__(self, other) -> bool:
        if isinstance(other, ImmutableList):
            return self._items == other._items
        return False
    
    def __repr__(self) -> str:
        return f"ImmutableList({list(self._items)})"


class ImmutableBuilder:
    """
    Immutable object builder: assembles complex immutable objects step by step
    """
    
    def __init__(self):
        self._data: Dict[str, Any] = {}
    
    def set(self, key: str, value: Any) -> 'ImmutableBuilder':
        self._data[key] = value
        return self
    
    def update(self, data: Dict[str, Any]) -> 'ImmutableBuilder':
        self._data.update(data)
        return self
    
    def build(self) -> 'ImmutableRecord':
        return ImmutableRecord(self._data.copy())


class ImmutableRecord:
    """
    Immutable record: a general-purpose immutable data container
    """
    
    def __init__(self, data: Dict[str, Any]):
        object.__setattr__(self, '_data', data.copy())
    
    def get(self, key: str, default: Any = None) -> Any:
        return self._data.get(key, default)
    
    def with_value(self, key: str, value: Any) -> 'ImmutableRecord':
        new_data = self._data.copy()
        new_data[key] = value
        return ImmutableRecord(new_data)
    
    def without(self, key: str) -> 'ImmutableRecord':
        new_data = self._data.copy()
        new_data.pop(key, None)
        return ImmutableRecord(new_data)
    
    def keys(self) -> Tuple:
        return tuple(self._data.keys())
    
    def values(self) -> Tuple:
        return tuple(self._data.values())
    
    def to_dict(self) -> Dict[str, Any]:
        return self._data.copy()
    
    def __getattr__(self, name: str) -> Any:
        if name.startswith('_'):
            return object.__getattribute__(self, name)
        return self._data.get(name)
    
    def __setattr__(self, name: str, value: Any) -> None:
        raise AttributeError("ImmutableRecord is immutable")
    
    def __repr__(self) -> str:
        return f"ImmutableRecord({self._data})"


point1 = ImmutablePoint(0, 0)
point2 = point1.move(10, 20)
print(f"原始点: ({point1.x}, {point1.y})")
print(f"移动后: ({point2.x}, {point2.y})")

config = ImmutableConfig.create("localhost", 3306, "mydb", "ssl", "timeout")
config2 = config.with_port(5432).with_option("compress")
print(f"\n原配置: {config.host}:{config.port}")
print(f"新配置: {config2.host}:{config2.port}, options: {config2.options}")

nums = ImmutableList.of(1, 2, 3, 4, 5)
nums2 = nums.append(6).map(lambda x: x * 2)
print(f"\n原始列表: {nums}")
print(f"修改后列表: {nums2}")

record = (ImmutableBuilder()
    .set('name', '张三')
    .set('age', 30)
    .set('email', 'zhang@example.com')
    .build())

record2 = record.with_value('age', 31)
print(f"\n原始记录: {record}")
print(f"修改后记录: {record2}")

25.3.2 Thread-Local Storage Pattern

python
import threading
from typing import Any, Optional, Callable, Dict
from contextlib import contextmanager
from dataclasses import dataclass


class ThreadLocalStorage:
    """
    Thread-local storage: each thread gets its own independent data store
    """
    
    def __init__(self):
        self._local = threading.local()
    
    def set(self, key: str, value: Any) -> None:
        if not hasattr(self._local, 'data'):
            self._local.data = {}
        self._local.data[key] = value
    
    def get(self, key: str, default: Any = None) -> Any:
        if not hasattr(self._local, 'data'):
            return default
        return self._local.data.get(key, default)
    
    def delete(self, key: str) -> None:
        if hasattr(self._local, 'data') and key in self._local.data:
            del self._local.data[key]
    
    def clear(self) -> None:
        if hasattr(self._local, 'data'):
            self._local.data.clear()
    
    def keys(self) -> list:
        if not hasattr(self._local, 'data'):
            return []
        return list(self._local.data.keys())
    
    def has(self, key: str) -> bool:
        if not hasattr(self._local, 'data'):
            return False
        return key in self._local.data


class RequestContext:
    """
    Request context: manages per-request data through thread-local storage
    """
    
    _storage = ThreadLocalStorage()
    
    @classmethod
    def set_request_id(cls, request_id: str) -> None:
        cls._storage.set('request_id', request_id)
    
    @classmethod
    def get_request_id(cls) -> Optional[str]:
        return cls._storage.get('request_id')
    
    @classmethod
    def set_user(cls, user: dict) -> None:
        cls._storage.set('user', user)
    
    @classmethod
    def get_user(cls) -> Optional[dict]:
        return cls._storage.get('user')
    
    @classmethod
    def clear(cls) -> None:
        cls._storage.clear()
    
    @classmethod
    @contextmanager
    def context(cls, request_id: str, user: dict = None):
        """请求上下文管理器"""
        cls.set_request_id(request_id)
        if user:
            cls.set_user(user)
        try:
            yield
        finally:
            cls.clear()


class DatabaseConnection:
    """
    Thread-safe database connection management: each thread lazily creates and reuses its own connection
    """
    
    _connections = ThreadLocalStorage()
    
    @classmethod
    def get_connection(cls, db_name: str) -> 'Connection':
        if not cls._connections.has(db_name):
            conn = cls._create_connection(db_name)
            cls._connections.set(db_name, conn)
        return cls._connections.get(db_name)
    
    @classmethod
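    def _create_connection(cls, db_name: str) -> 'Connection':
        print(f"Creating new connection: {db_name}")
        # Placeholder dict standing in for a real driver connection object
        return {'name': db_name, 'connected': True}
    
    @classmethod
    def close_all(cls) -> None:
        # Thread-local: this only sees the calling thread's connections,
        # so every thread has to close its own
        for key in cls._connections.keys():
            conn = cls._connections.get(key)
            if conn:
                print(f"Closing connection: {key}")
        cls._connections.clear()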


def worker(thread_id: int):
    with RequestContext.context(f"req-{thread_id}", {'id': thread_id, 'name': f'User{thread_id}'}):
        print(f"线程 {thread_id}: request_id = {RequestContext.get_request_id()}")
        print(f"线程 {thread_id}: user = {RequestContext.get_user()}")
        
        conn = DatabaseConnection.get_connection('main_db')
        print(f"线程 {thread_id}: connection = {conn}")

threads = [threading.Thread(target=worker, args=(i,)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()
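`threading.local` isolates data per OS thread. An `asyncio` server, however, runs many requests on one thread, so a thread-local value would be shared by all of them. The standard library's `contextvars.ContextVar` (Python 3.7+) is the per-task counterpart. In the sketch below, the `handle` coroutine and the request IDs are illustrative, and each task sees only the value it set itself.

```python
import asyncio
import contextvars

# Plays the role of RequestContext's 'request_id' slot, but each asyncio task
# runs in its own copy of the context instead of each OS thread
request_id = contextvars.ContextVar("request_id", default=None)


async def handle(req: str) -> str:
    token = request_id.set(req)      # bind the value in this task's context
    try:
        await asyncio.sleep(0)       # other tasks run here, on the same thread
        return f"{req} sees {request_id.get()}"
    finally:
        request_id.reset(token)      # restore the previous value


async def main() -> list:
    return await asyncio.gather(*(handle(f"req-{i}") for i in range(3)))


results = asyncio.run(main())
print(results)  # ['req-0 sees req-0', 'req-1 sees req-1', 'req-2 sees req-2']
```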

25.3.3 Synchronized Wrapper Pattern

python
import threading
from typing import Any, Callable, Dict, List, Set, Optional
from functools import wraps
from contextlib import contextmanager


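def synchronized(lock: threading.Lock = None):
    """
    Synchronizing decorator: serializes calls to the decorated function.
    Apply it as @synchronized() or @synchronized(shared_lock). Without an argument,
    all calls to that function (across every instance) share one newly created lock.
    """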
    if lock is None:
        lock = threading.Lock()
    
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            with lock:
                return func(*args, **kwargs)
        wrapper._lock = lock
        return wrapper
    return decorator


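class SynchronizedDict:
    """
    Thread-safe dictionary
    """
    
    def __init__(self, initial: Dict = None):
        # Copy the input so no outside reference can mutate the data without the lock
        self._data: Dict = dict(initial) if initial else {}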
        self._lock = threading.RLock()
    
    def get(self, key: str, default: Any = None) -> Any:
        with self._lock:
            return self._data.get(key, default)
    
    def set(self, key: str, value: Any) -> None:
        with self._lock:
            self._data[key] = value
    
    def delete(self, key: str) -> bool:
        with self._lock:
            if key in self._data:
                del self._data[key]
                return True
            return False
    
    def contains(self, key: str) -> bool:
        with self._lock:
            return key in self._data
    
    def keys(self) -> List:
        with self._lock:
            return list(self._data.keys())
    
    def values(self) -> List:
        with self._lock:
            return list(self._data.values())
    
    def items(self) -> List:
        with self._lock:
            return list(self._data.items())
    
    def update(self, other: Dict) -> None:
        with self._lock:
            self._data.update(other)
    
    def clear(self) -> None:
        with self._lock:
            self._data.clear()
    
    def compute_if_absent(self, key: str, func: Callable) -> Any:
        with self._lock:
            if key not in self._data:
                self._data[key] = func()
            return self._data[key]
    
    def __len__(self) -> int:
        with self._lock:
            return len(self._data)
    
    def __repr__(self) -> str:
        with self._lock:
            return f"SynchronizedDict({self._data})"


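class SynchronizedList:
    """
    Thread-safe list
    """
    
    def __init__(self, initial: List = None):
        # Copy the input so no outside reference can mutate the data without the lock
        self._data: List = list(initial) if initial else []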
        self._lock = threading.RLock()
    
    def append(self, item: Any) -> None:
        with self._lock:
            self._data.append(item)
    
    def insert(self, index: int, item: Any) -> None:
        with self._lock:
            self._data.insert(index, item)
    
    def remove(self, item: Any) -> None:
        with self._lock:
            self._data.remove(item)
    
    def pop(self, index: int = -1) -> Any:
        with self._lock:
            return self._data.pop(index)
    
    def get(self, index: int) -> Any:
        with self._lock:
            return self._data[index]
    
    def set(self, index: int, value: Any) -> None:
        with self._lock:
            self._data[index] = value
    
    def contains(self, item: Any) -> bool:
        with self._lock:
            return item in self._data
    
    def index(self, item: Any) -> int:
        with self._lock:
            return self._data.index(item)
    
    def extend(self, items: List) -> None:
        with self._lock:
            self._data.extend(items)
    
    def clear(self) -> None:
        with self._lock:
            self._data.clear()
    
    def to_list(self) -> List:
        with self._lock:
            return self._data.copy()
    
    def __len__(self) -> int:
        with self._lock:
            return len(self._data)
    
    def __iter__(self):
        with self._lock:
            return iter(self._data.copy())
    
    def __repr__(self) -> str:
        with self._lock:
            return f"SynchronizedList({self._data})"


class AtomicReference:
    """
    Atomic reference: a thread-safe holder for an object reference
    """
    
    def __init__(self, value: Any = None):
        self._value = value
        self._lock = threading.Lock()
    
    def get(self) -> Any:
        with self._lock:
            return self._value
    
    def set(self, new_value: Any) -> None:
        with self._lock:
            self._value = new_value
    
    def get_and_set(self, new_value: Any) -> Any:
        with self._lock:
            old_value = self._value
            self._value = new_value
            return old_value
    
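    def compare_and_set(self, expected: Any, new_value: Any) -> bool:
        with self._lock:
            # Compares with == (value equality); Java's AtomicReference compares
            # references, which would correspond to `is` here
            if self._value == expected:
                self._value = new_value
                return True
            return False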
    
    def update(self, func: Callable[[Any], Any]) -> Any:
        with self._lock:
            self._value = func(self._value)
            return self._value
    
    def __repr__(self) -> str:
        with self._lock:
            return f"AtomicReference({self._value})"


shared_dict = SynchronizedDict({'initial': 'value'})
shared_dict.set('new_key', 'new_value')
print(f"同步字典: {shared_dict}")

shared_list = SynchronizedList([1, 2, 3])
shared_list.append(4)
print(f"同步列表: {shared_list}")

counter = AtomicReference(0)
counter.update(lambda x: x + 1)
print(f"原子计数器: {counter}")

25.4 Synchronization Patterns

25.4.1 Read-Write Lock Pattern

python
import threading
from contextlib import contextmanager
from typing import Optional
import time


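class ReadWriteLock:
    """
    Read-write lock: multiple readers may proceed concurrently; a writer has exclusive access
    
    Properties:
    - Read/read sharing: multiple readers can hold the lock at the same time
    - Read/write exclusion: a writer needs exclusive access
    - Write/write exclusion: writers exclude each other
    
    This is the classic readers-preference scheme. While at least one reader
    holds the lock, new readers keep entering, so a steady stream of readers
    can starve writers.
    """
    
    def __init__(self):
        # Held by an active writer, or by the readers as a group. A plain Lock
        # (not an RLock) is required: the first reader acquires it and the last
        # reader, possibly a different thread, releases it.
        self._write_lock = threading.Lock()
        self._readers = 0
        self._readers_count_lock = threading.Lock()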
    
    @contextmanager
    def read_lock(self):
        self._acquire_read()
        try:
            yield
        finally:
            self._release_read()
    
    @contextmanager
    def write_lock(self):
        self._acquire_write()
        try:
            yield
        finally:
            self._release_write()
    
    def _acquire_read(self) -> None:
        with self._readers_count_lock:
            self._readers += 1
            if self._readers == 1:
                self._write_lock.acquire()
    
    def _release_read(self) -> None:
        with self._readers_count_lock:
            self._readers -= 1
            if self._readers == 0:
                self._write_lock.release()
    
    def _acquire_write(self) -> None:
        self._write_lock.acquire()
    
    def _release_write(self) -> None:
        self._write_lock.release()


class ReentrantReadWriteLock:
    """
    Reentrant read-write lock: the same thread may acquire the lock repeatedly.
    A writer may also take the read lock. Upgrading from read to write is not
    supported and would deadlock.
    """
    
    def __init__(self):
        self._cond = threading.Condition(threading.Lock())
        self._readers = 0                       # total read holds across all threads
        self._writer: Optional[int] = None      # ident of the thread holding the write lock
        self._write_count = 0                   # reentrant write depth of that thread
        self._read_holders: dict = {}           # thread ident -> read depth
    
    @contextmanager
    def read_lock(self):
        thread_id = threading.get_ident()
        with self._cond:
            # Block only while ANOTHER thread holds the write lock
            while self._writer is not None and self._writer != thread_id:
                self._cond.wait()
            self._readers += 1
            self._read_holders[thread_id] = self._read_holders.get(thread_id, 0) + 1
        
        try:
            yield
        finally:
            with self._cond:
                self._read_holders[thread_id] -= 1
                if self._read_holders[thread_id] == 0:
                    del self._read_holders[thread_id]
                self._readers -= 1
                if self._readers == 0:
                    self._cond.notify_all()     # a waiting writer may proceed
    
    @contextmanager
    def write_lock(self):
        thread_id = threading.get_ident()
        with self._cond:
            if self._writer != thread_id:
                # Wait until there is no other writer AND no active readers
                while self._writer is not None or self._readers > 0:
                    self._cond.wait()
                self._writer = thread_id
            self._write_count += 1
        
        try:
            yield
        finally:
            with self._cond:
                self._write_count -= 1
                if self._write_count == 0:
                    self._writer = None
                    self._cond.notify_all()


class SharedCache:
    """
    Shared cache protected by a read-write lock
    """
    
    def __init__(self):
        self._data: dict = {}
        self._lock = ReadWriteLock()
        self._stats = {'reads': 0, 'writes': 0, 'hits': 0, 'misses': 0}
        # Many readers can run at once under read_lock(), so the counters they
        # update need a lock of their own to avoid lost increments
        self._stats_lock = threading.Lock()
    
    def _count(self, *names: str) -> None:
        with self._stats_lock:
            for name in names:
                self._stats[name] += 1
    
    def get(self, key: str) -> object:
        with self._lock.read_lock():
            if key in self._data:
                self._count('reads', 'hits')
                return self._data[key]
            self._count('reads', 'misses')
            return None
    
    def set(self, key: str, value: object) -> None:
        with self._lock.write_lock():
            self._count('writes')
            self._data[key] = value
    
    def delete(self, key: str) -> bool:
        with self._lock.write_lock():
            if key in self._data:
                del self._data[key]
                return True
            return False
    
    def get_all(self) -> dict:
        with self._lock.read_lock():
            return self._data.copy()
    
    def clear(self) -> None:
        with self._lock.write_lock():
            self._data.clear()
    
    def get_stats(self) -> dict:
        with self._stats_lock:
            return self._stats.copy()


cache = SharedCache()
cache.set('user:1', {'name': 'Zhang San', 'age': 30})
cache.set('user:2', {'name': 'Li Si', 'age': 25})

print(f"Cache hit: {cache.get('user:1')}")
print(f"Cache miss: {cache.get('user:3')}")
print(f"Stats: {cache.get_stats()}")

25.4.2 Barrier Pattern

python
import threading
from typing import Callable, List, Optional
from dataclasses import dataclass
import time
import random


@dataclass
class BarrierStats:
    parties: int
    waiting: int
    generation: int


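class CyclicBarrier:
    """
    Cyclic barrier: a group of threads wait for each other until all of them reach the barrier point
    
    Properties:
    - Reusable: the barrier can be used again after it trips
    - Callback support: an action runs once all threads have arrived
    """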
    
    def __init__(self, parties: int, action: Callable = None):
        self._parties = parties
        self._action = action
        self._lock = threading.Lock()
        self._condition = threading.Condition(self._lock)
        self._waiting = 0
        self._generation = 0
        self._broken = False
    
    def wait(self, timeout: float = None) -> int:
        """等待其他线程到达屏障"""
        with self._condition:
            generation = self._generation
            
            if self._broken:
                raise RuntimeError("Barrier is broken")
            
            index = self._waiting
            self._waiting += 1
            
            if index == self._parties - 1:
                self._generation += 1
                self._waiting = 0
                self._broken = False
                
                if self._action:
                    try:
                        self._action()
                    except Exception:
                        self._broken = True
                        self._condition.notify_all()
                        raise
                
                self._condition.notify_all()
                return 0
            else:
                while generation == self._generation and not self._broken:
                    if not self._condition.wait(timeout):
                        self._broken = True
                        self._condition.notify_all()
                        raise RuntimeError("Barrier timeout")
                
                if self._broken:
                    raise RuntimeError("Barrier is broken")
                
                return self._parties - self._waiting
    
    def reset(self) -> None:
        """重置屏障"""
        with self._condition:
            self._broken = False
            self._waiting = 0
            self._generation += 1
            self._condition.notify_all()
    
    def get_stats(self) -> BarrierStats:
        with self._lock:
            return BarrierStats(
                parties=self._parties,
                waiting=self._waiting,
                generation=self._generation
            )


class Phaser:
    """
    Phaser: a barrier whose participants can register and deregister dynamically
    """
    
    def __init__(self, parties: int = 0):
        self._lock = threading.Lock()
        self._condition = threading.Condition(self._lock)
        self._registered = parties
        self._arrived = 0
        self._phase = 0
    
    def register(self, parties: int = 1) -> int:
        """注册参与者"""
        with self._lock:
            self._registered += parties
            return self._phase
    
    def arrive(self) -> int:
        """到达并继续"""
        with self._condition:
            phase = self._phase
            self._arrived += 1
            
            if self._arrived == self._registered:
                self._arrived = 0
                self._phase += 1
                self._condition.notify_all()
            
            return phase
    
    def arrive_and_await_advance(self) -> int:
        """到达并等待其他参与者"""
        with self._condition:
            phase = self._phase
            self._arrived += 1
            
            if self._arrived == self._registered:
                self._arrived = 0
                self._phase += 1
                self._condition.notify_all()
                return self._phase
            else:
                while self._phase == phase:
                    self._condition.wait()
                return self._phase
    
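    def arrive_and_deregister(self) -> int:
        """Arrive and deregister without waiting.

        The departing party is removed from the registered count rather than
        counted as an arrival, so the others wait only for the remaining parties.
        """
        with self._condition:
            phase = self._phase
            self._registered -= 1
            
            if self._registered > 0 and self._arrived == self._registered:
                self._arrived = 0
                self._phase += 1
                self._condition.notify_all()
            
            return phase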
    
    def await_advance(self, phase: int) -> int:
        """等待指定阶段完成"""
        with self._condition:
            while self._phase == phase:
                self._condition.wait()
            return self._phase
    
    def get_phase(self) -> int:
        with self._lock:
            return self._phase
    
    def get_registered(self) -> int:
        with self._lock:
            return self._registered


def parallel_computation_example():
    """并行计算示例:使用屏障同步多个阶段"""
    
    def phase_action():
        print(f"=== 屏障触发 ===")
    
    barrier = CyclicBarrier(3, phase_action)
    results = []
    
    def worker(worker_id: int):
        for phase in range(3):
            print(f"工作者 {worker_id}: 阶段 {phase} 开始")
            time.sleep(random.random() * 0.5)
            print(f"工作者 {worker_id}: 阶段 {phase} 完成,等待其他线程")
            
            barrier.wait()
            
            print(f"工作者 {worker_id}: 进入下一阶段")
        
        results.append(f"工作者 {worker_id} 完成")
    
    threads = [threading.Thread(target=worker, args=(i,)) for i in range(3)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    
    print(f"\n所有工作者完成: {results}")


parallel_computation_example()
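Python's standard library ships a reusable barrier, `threading.Barrier(parties, action=None, timeout=None)`, that covers the same ground as `CyclicBarrier`:

- `wait()` returns a distinct arrival index in `range(parties)` to each thread.
- The optional `action` runs in one thread each time the barrier trips.
- A timeout or a failing action puts the barrier in a broken state, signalled by `threading.BrokenBarrierError`.

The short sketch below uses three threads and three phases; the names are illustrative.

```python
import threading

PARTIES = 3
phases_done = []            # appended by the barrier action, once per trip
lock = threading.Lock()
arrival_indices = []


def on_trip() -> None:
    # Runs in exactly one thread each time all parties have arrived
    phases_done.append(len(phases_done))


barrier = threading.Barrier(PARTIES, action=on_trip, timeout=5)


def worker() -> None:
    for _ in range(3):           # three phases, reusing the same barrier
        idx = barrier.wait()     # an int in range(PARTIES), unique per trip
        with lock:
            arrival_indices.append(idx)


threads = [threading.Thread(target=worker) for _ in range(PARTIES)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(phases_done)               # [0, 1, 2]
print(sorted(arrival_indices))   # [0, 0, 0, 1, 1, 1, 2, 2, 2]
```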

25.4.3 Semaphore Pattern

python
import threading
from typing import Optional
from contextlib import contextmanager
import random
import time


class Semaphore:
    """
    Semaphore: limits how many threads can access a resource at the same time
    """
    
    def __init__(self, permits: int):
        self._permits = permits
        self._lock = threading.Lock()
        self._condition = threading.Condition(self._lock)
    
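    def acquire(self, permits: int = 1, timeout: float = None) -> bool:
        """Acquire permits; return False if the timeout expires first"""
        with self._condition:
            # wait_for re-checks the predicate after every wakeup and applies
            # `timeout` to the whole wait rather than to each individual wait
            if not self._condition.wait_for(lambda: self._permits >= permits, timeout):
                return False
            
            self._permits -= permits
            return True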
    
    def release(self, permits: int = 1) -> None:
        """释放许可"""
        with self._condition:
            self._permits += permits
            self._condition.notify_all()
    
    @contextmanager
    def acquire_context(self, permits: int = 1):
        """上下文管理器方式获取许可"""
        self.acquire(permits)
        try:
            yield
        finally:
            self.release(permits)
    
    def available_permits(self) -> int:
        with self._lock:
            return self._permits


class BoundedSemaphore:
    """
    Bounded semaphore: the number of permits can never exceed the initial value
    """
    
    def __init__(self, permits: int):
        self._max_permits = permits
        self._permits = permits
        self._lock = threading.Lock()
        self._condition = threading.Condition(self._lock)
    
    def acquire(self, timeout: float = None) -> bool:
        with self._condition:
            # Total timeout, with the predicate re-checked after each wakeup
            if not self._condition.wait_for(lambda: self._permits > 0, timeout):
                return False
            self._permits -= 1
            return True
    
    def release(self) -> None:
        with self._condition:
            if self._permits >= self._max_permits:
                raise ValueError("Semaphore permits exceeded maximum")
            self._permits += 1
            self._condition.notify()
    
    @contextmanager
    def context(self):
        self.acquire()
        try:
            yield
        finally:
            self.release()


class ConnectionPool:
    """
    Connection pool: uses a semaphore to bound access to the connections
    """
    
    def __init__(self, max_connections: int, create_connection: callable):
        self._max_connections = max_connections
        self._create_connection = create_connection
        self._semaphore = BoundedSemaphore(max_connections)
        self._pool: list = []
        self._pool_lock = threading.Lock()
        self._created = 0
        self._created_lock = threading.Lock()
    
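    def get_connection(self, timeout: float = None) -> any:
        """Acquire a connection, reusing an idle one when available"""
        if not self._semaphore.acquire(timeout):
            raise TimeoutError("Timed out waiting for a connection")
        
        with self._pool_lock:
            if self._pool:
                return self._pool.pop()
        
        try:
            conn = self._create_connection()
        except Exception:
            # Give the permit back so a failed creation does not shrink the pool
            self._semaphore.release()
            raise
        with self._created_lock:
            self._created += 1
        return conn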
    
    def release_connection(self, conn: any) -> None:
        """释放连接"""
        with self._pool_lock:
            self._pool.append(conn)
        self._semaphore.release()
    
    @contextmanager
    def connection(self, timeout: float = None):
        """上下文管理器方式获取连接"""
        conn = self.get_connection(timeout)
        try:
            yield conn
        finally:
            self.release_connection(conn)
    
    def get_stats(self) -> dict:
        return {
            'max_connections': self._max_connections,
            'created': self._created,
            'available': len(self._pool),
            'in_use': self._max_connections - self._semaphore._permits
        }


def create_db_connection():
    return {'id': random.randint(1000, 9999), 'connected': True}

pool = ConnectionPool(3, create_db_connection)

def use_connection(worker_id: int):
    with pool.connection() as conn:
        print(f"工作者 {worker_id}: 使用连接 {conn['id']}")
        time.sleep(0.5)
        print(f"工作者 {worker_id}: 释放连接 {conn['id']}")

threads = [threading.Thread(target=use_connection, args=(i,)) for i in range(5)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(f"\n连接池统计: {pool.get_stats()}")

25.5 Producer-Consumer Pattern

25.5.1 Basic Implementation

python
import threading
import queue
from typing import Any, Callable, Optional, List
from dataclasses import dataclass, field
from enum import Enum, auto
import time
import random


class TaskStatus(Enum):
    PENDING = auto()
    PROCESSING = auto()
    COMPLETED = auto()
    FAILED = auto()


@dataclass
class Task:
    """任务定义"""
    id: int
    data: Any
    status: TaskStatus = TaskStatus.PENDING
    result: Any = None
    error: Optional[str] = None
    created_at: float = field(default_factory=time.time)
    completed_at: Optional[float] = None


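class Producer:
    """Producer: creates tasks and puts them on the shared queue"""
    
    def __init__(self, task_queue: queue.Queue, name: str = "Producer"):
        self._queue = task_queue
        self._name = name
        self._task_counter = 0
        self._counter_lock = threading.Lock()  # one producer may be shared by several threads
    
    def produce(self, data: Any, priority: int = 0) -> Task:
        """Create a task and enqueue it (blocks while a bounded queue is full)"""
        with self._counter_lock:
            self._task_counter += 1
            task = Task(id=self._task_counter, data=data)
        
        # Priority is honored only if the queue supports it; a plain queue.Queue is FIFO
        if hasattr(self._queue, 'put_with_priority'):
            self._queue.put_with_priority(task, priority)
        else:
            self._queue.put(task)
        
        print(f"[{self._name}] produced task {task.id}")
        return task
    
    def produce_batch(self, data_list: List[Any]) -> List[Task]:
        """Produce one task per item in data_list"""
        return [self.produce(data) for data in data_list]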


class Consumer:
    """Consumer: takes tasks off the queue and processes them on its own thread"""
    
    def __init__(
        self,
        task_queue: queue.Queue,
        name: str,
        processor: Callable[[Task], Any]
    ):
        self._queue = task_queue
        self._name = name
        self._processor = processor
        self._running = False
        self._thread: Optional[threading.Thread] = None
        self._processed_count = 0
        self._error_count = 0
    
    def start(self) -> None:
        self._running = True
        self._thread = threading.Thread(target=self._consume, daemon=True)
        self._thread.start()
        print(f"[{self._name}] consumer started")
    
    def stop(self) -> None:
        # The worker polls with a 1 s timeout, so it notices the flag within about a second
        self._running = False
        if self._thread:
            self._thread.join(timeout=2)
        print(f"[{self._name}] consumer stopped: {self._processed_count} tasks processed, {self._error_count} errors")
    
    def _consume(self) -> None:
        while self._running:
            try:
                task = self._queue.get(timeout=1)
            except queue.Empty:
                continue
            
            try:
                if task is None:  # poison pill: shut this consumer down
                    break
                self._process_task(task)
            except Exception as e:
                print(f"[{self._name}] error: {e}")
                self._error_count += 1
            finally:
                # Always balance get() with task_done(), otherwise Queue.join() hangs
                self._queue.task_done()
    
    def _process_task(self, task: Task) -> None:
        task.status = TaskStatus.PROCESSING
        
        try:
            result = self._processor(task)
            task.result = result
            task.status = TaskStatus.COMPLETED
            task.completed_at = time.time()
            self._processed_count += 1
            print(f"[{self._name}] completed task {task.id}: {result}")
            
        except Exception as e:
            task.status = TaskStatus.FAILED
            task.error = str(e)
            self._error_count += 1
            print(f"[{self._name}] task {task.id} failed: {e}")


class ProducerConsumerSystem:
    """Producer-consumer system: one bounded queue shared by producers and a set of consumers"""
    
    def __init__(
        self,
        queue_size: int = 10,
        consumer_count: int = 2,
        processor: Optional[Callable[[Task], Any]] = None
    ):
        # A bounded queue provides back-pressure: producers block while it is full
        self._queue = queue.Queue(maxsize=queue_size)
        self._consumers: List[Consumer] = []
        self._producers: List[Producer] = []
        self._default_processor = processor or (lambda t: f"processed: {t.data}")
        
        for i in range(consumer_count):
            consumer = Consumer(self._queue, f"Consumer-{i+1}", self._default_processor)
            self._consumers.append(consumer)
    
    def create_producer(self, name: Optional[str] = None) -> Producer:
        producer = Producer(self._queue, name or f"Producer-{len(self._producers)+1}")
        self._producers.append(producer)
        return producer
    
    def start(self) -> None:
        for consumer in self._consumers:
            consumer.start()
    
    def stop(self) -> None:
        for consumer in self._consumers:
            consumer.stop()
    
    def wait_completion(self, timeout: Optional[float] = None) -> bool:
        """Wait until every queued task is done; return False if the timeout expires first"""
        # Queue.join() takes no timeout, so wait for it on a helper thread
        waiter = threading.Thread(target=self._queue.join, daemon=True)
        waiter.start()
        waiter.join(timeout)
        return not waiter.is_alive()


def process_task(task: Task) -> str:
    time.sleep(random.random() * 0.5)  # simulate variable processing time
    return f"result-{task.data}"

system = ProducerConsumerSystem(queue_size=20, consumer_count=3, processor=process_task)
system.start()

producer = system.create_producer("MainProducer")
for i in range(10):
    producer.produce(f"data-{i}")
    time.sleep(0.1)

system.wait_completion()  # block until all 10 tasks have been processed
system.stop()
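`Consumer._consume` already treats `None` as a stop signal, yet `ProducerConsumerSystem.stop()` relies on a polled flag instead. A common alternative is the "poison pill": after the real work, enqueue one sentinel per consumer. Because the queue is FIFO, each consumer drains the remaining tasks before it meets its sentinel, and no polling timeout is needed. The sketch below uses only the standard library; `STOP`, `consumer`, and `run` are illustrative names, and squaring stands in for real processing.

```python
import queue
import threading
from typing import List

STOP = object()  # unique sentinel; cannot collide with real task data


def consumer(tasks: queue.Queue, results: List[int], lock: threading.Lock) -> None:
    while True:
        item = tasks.get()  # block without polling
        try:
            if item is STOP:
                return  # poison pill: this consumer is done
            with lock:
                results.append(item * item)  # stand-in for real work
        finally:
            tasks.task_done()  # balance every get(), including the sentinel


def run(n_items: int, n_consumers: int) -> List[int]:
    tasks: queue.Queue = queue.Queue(maxsize=10)
    results: List[int] = []
    lock = threading.Lock()
    workers = [threading.Thread(target=consumer, args=(tasks, results, lock))
               for _ in range(n_consumers)]
    for w in workers:
        w.start()
    for i in range(n_items):
        tasks.put(i)  # blocks while the bounded queue is full (back-pressure)
    for _ in workers:
        tasks.put(STOP)  # one pill per consumer, queued behind the real work
    for w in workers:
        w.join()
    return sorted(results)


print(run(20, 3)[:5])  # → [0, 1, 4, 9, 16]
```

A consumer stops taking items after its first pill, so the remaining pills necessarily reach the other consumers.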

25.5.2 Priority Queues

python
import threading
import queue
from typing import Any, Callable, Optional, List, Tuple
from dataclasses import dataclass, field
import heapq
import time


@dataclass(order=True)
class PriorityTask:
    """Heap entry ordered by (priority, sequence); the task payload is never compared"""
    priority: int
    sequence: int = field(compare=True)
    task: Any = field(compare=False)


class PriorityTaskQueue:
    """
    Priority task queue: smaller priority values are served first (min-heap);
    equal priorities come out in FIFO order thanks to the sequence number
    """
    
    def __init__(self, maxsize: int = 0):
        self._maxsize = maxsize
        self._queue: List[PriorityTask] = []
        self._lock = threading.Lock()
        # All conditions share self._lock. Never re-acquire self._lock inside them:
        # threading.Lock is not reentrant, so doing so deadlocks
        self._not_empty = threading.Condition(self._lock)
        self._not_full = threading.Condition(self._lock)
        self._all_tasks_done = threading.Condition(self._lock)
        self._sequence = 0
        self._unfinished_tasks = 0
    
    def put(self, task: Any, priority: int = 0, block: bool = True, timeout: Optional[float] = None) -> None:
        """Add a task; raise queue.Full if it cannot be added in time"""
        with self._not_full:
            if self._maxsize > 0:
                while len(self._queue) >= self._maxsize:
                    if not block:
                        raise queue.Full()
                    if not self._not_full.wait(timeout):
                        raise queue.Full()
            
            self._sequence += 1
            heapq.heappush(self._queue, PriorityTask(priority, self._sequence, task))
            self._unfinished_tasks += 1
            self._not_empty.notify()
    
    def get(self, block: bool = True, timeout: Optional[float] = None) -> Any:
        """Remove and return the highest-priority task; raise queue.Empty on timeout"""
        with self._not_empty:
            while not self._queue:
                if not block:
                    raise queue.Empty()
                if not self._not_empty.wait(timeout):
                    raise queue.Empty()
            
            priority_task = heapq.heappop(self._queue)
            self._not_full.notify()  # wake a producer blocked on a full queue
            return priority_task.task
    
    def task_done(self) -> None:
        with self._all_tasks_done:
            if self._unfinished_tasks <= 0:
                raise ValueError("task_done() called too many times")
            self._unfinished_tasks -= 1
            if self._unfinished_tasks == 0:
                self._all_tasks_done.notify_all()
    
    def join(self) -> None:
        """Block until task_done() has been called for every task that was put"""
        with self._all_tasks_done:
            while self._unfinished_tasks:
                self._all_tasks_done.wait()
    
    def qsize(self) -> int:
        with self._lock:
            return len(self._queue)
    
    def empty(self) -> bool:
        with self._lock:
            return len(self._queue) == 0


class DelayedTaskQueue:
    """
    Delayed task queue: a task can be retrieved only after its delay has elapsed
    """
    
    def __init__(self):
        self._queue: List[Tuple[float, int, Any]] = []
        self._lock = threading.Lock()
        self._condition = threading.Condition(self._lock)
        self._sequence = 0
    
    def put(self, task: Any, delay_seconds: float = 0) -> None:
        """Add a task that becomes ready delay_seconds from now"""
        # Monotonic clock: unaffected by system clock adjustments
        execute_time = time.monotonic() + delay_seconds
        
        with self._condition:
            self._sequence += 1
            heapq.heappush(self._queue, (execute_time, self._sequence, task))
            # Wake a waiter: the new task may be due earlier than the current head
            self._condition.notify()
    
    def get(self, timeout: Optional[float] = None) -> Any:
        """Block until the earliest task is due, then return it"""
        with self._condition:
            while True:
                while not self._queue:
                    if not self._condition.wait(timeout):
                        raise queue.Empty()
                
                execute_time, _, task = self._queue[0]
                now = time.monotonic()
                
                if execute_time <= now:
                    heapq.heappop(self._queue)
                    return task
                
                wait_time = execute_time - now
                if timeout is not None and wait_time > timeout:
                    raise queue.Empty()
                
                self._condition.wait(wait_time)
    
    def qsize(self) -> int:
        with self._lock:
            return len(self._queue)


priority_queue = PriorityTaskQueue(maxsize=10)

priority_queue.put("low-priority task", priority=10)
priority_queue.put("high-priority task", priority=1)
priority_queue.put("medium-priority task", priority=5)

print("Priority queue processing order:")
while not priority_queue.empty():
    task = priority_queue.get()
    print(f"  processing: {task}")
    priority_queue.task_done()


delayed_queue = DelayedTaskQueue()
delayed_queue.put("immediate task", delay_seconds=0)
delayed_queue.put("task due in 2 s", delay_seconds=2)
delayed_queue.put("task due in 1 s", delay_seconds=1)

print("\nDelayed queue processing order:")
for _ in range(3):
    task = delayed_queue.get()
    print(f"  processing: {task}")
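When blocking `put`/`get`, `task_done`, and `join` are all that is needed, the standard library's `queue.PriorityQueue` already provides them. It pops the smallest entry first, so the sketch below stores `(priority, sequence, item)` tuples: the sequence number from `itertools.count()` keeps equal priorities in FIFO order and stops Python from ever comparing the items themselves, the same role `PriorityTask.sequence` plays above. `submit` and `drain` are illustrative helper names.

```python
import itertools
import queue
from typing import Any, List

tasks: queue.PriorityQueue = queue.PriorityQueue()
sequence = itertools.count()


def submit(item: Any, priority: int) -> None:
    # Smaller number = higher priority; the counter breaks ties in insertion order
    tasks.put((priority, next(sequence), item))


def drain() -> List[Any]:
    order = []
    while not tasks.empty():
        _, _, item = tasks.get()
        order.append(item)
        tasks.task_done()
    return order


submit("low", 10)
submit("high", 1)
submit("medium", 5)
submit("high-2", 1)
print(drain())  # → ['high', 'high-2', 'medium', 'low']
```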

25.6 Thread Pool Pattern

25.6.1 Custom Thread Pool

python
import threading
import queue
from typing import Callable, Any, List, Optional, Dict
from dataclasses import dataclass, field
from concurrent.futures import Future
from enum import Enum, auto
import time
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class PoolState(Enum):
    RUNNING = auto()
    SHUTDOWN = auto()
    TERMINATED = auto()


@dataclass
class WorkerStats:
    tasks_completed: int = 0
    tasks_failed: int = 0
    total_time: float = 0.0


class Worker(threading.Thread):
    """Worker thread: runs tasks from the pool until it receives the None sentinel"""
    
    def __init__(self, pool: 'ThreadPool', name: str):
        super().__init__(name=name, daemon=True)
        self._pool = pool
        self._stats = WorkerStats()
        self._current_task = None
    
    def run(self) -> None:
        while True:
            task = self._pool._get_task()
            if task is None:
                break
            
            self._current_task = task
            self._execute_task(task)
            self._current_task = None
    
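    def _execute_task(self, task: 'ThreadPoolTask') -> None:
        # Executor protocol: skip tasks whose Future was cancelled while still queued
        # (calling set_result on a cancelled Future would raise InvalidStateError)
        if not task.future.set_running_or_notify_cancel():
            return
        start_time = time.perf_counter()
        
        try:
            result = task.func(*task.args, **task.kwargs)
            task.future.set_result(result)
            self._stats.tasks_completed += 1
            
        except Exception as e:
            task.future.set_exception(e)
            self._stats.tasks_failed += 1
            logger.error(f"Task failed: {e}")
        
        finally:
            self._stats.total_time += time.perf_counter() - start_time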
    
    def get_stats(self) -> WorkerStats:
        return self._stats


@dataclass
class ThreadPoolTask:
    func: Callable
    args: tuple
    kwargs: dict
    future: Future


class ThreadPool:
    """
    Thread pool: starts a fixed set of worker threads and reuses them for submitted tasks
    """
    
    def __init__(
        self,
        max_workers: int = 4,
        queue_size: int = 100,
        thread_name_prefix: str = "Worker"
    ):
        self._max_workers = max_workers
        self._queue_size = queue_size
        self._thread_name_prefix = thread_name_prefix
        
        self._task_queue: queue.Queue[Optional[ThreadPoolTask]] = queue.Queue(maxsize=queue_size)
        self._workers: List[Worker] = []
        self._state = PoolState.RUNNING
        self._lock = threading.Lock()
        self._submitted_count = 0
        self._completed_count = 0
    
    def start(self) -> None:
        """Start the worker threads"""
        with self._lock:
            if self._state != PoolState.RUNNING:
                return
            
            for i in range(self._max_workers):
                worker = Worker(self, f"{self._thread_name_prefix}-{i+1}")
                worker.start()
                self._workers.append(worker)
            
            logger.info(f"Thread pool started with {self._max_workers} workers")
    
    def submit(self, func: Callable, *args, **kwargs) -> Future:
        """Submit a task and return a Future for its result"""
        with self._lock:
            if self._state != PoolState.RUNNING:
                raise RuntimeError("Thread pool has been shut down")
            
            future = Future()
            task = ThreadPoolTask(func, args, kwargs, future)
            
            # put() runs under the lock so no task can land behind shutdown's sentinels;
            # it blocks while the queue is full, which throttles fast submitters
            self._task_queue.put(task)
            self._submitted_count += 1
            
            return future
    
    def map(self, func: Callable, iterable: List, timeout: float = None) -> List:
        """Submit func(item) for every item and return the results in input order"""
        futures = [self.submit(func, item) for item in iterable]
        return [f.result(timeout) for f in futures]
    
    def shutdown(self, wait: bool = True) -> None:
        """Shut down the pool: reject new tasks; workers finish queued tasks, then exit"""
        with self._lock:
            if self._state == PoolState.TERMINATED:
                return
            
            self._state = PoolState.SHUTDOWN
        
        # One None sentinel per worker, queued behind every pending task
        for _ in self._workers:
            self._task_queue.put(None)
        
        if wait:
            for worker in self._workers:
                worker.join()
        
        with self._lock:
            self._state = PoolState.TERMINATED
        
        logger.info("Thread pool shut down")
    
    def _get_task(self) -> Optional[ThreadPoolTask]:
        return self._task_queue.get()
    
    def get_stats(self) -> Dict:
        """Return pool statistics"""
        with self._lock:
            worker_stats = [w.get_stats() for w in self._workers]
            
            return {
                'state': self._state.name,
                'max_workers': self._max_workers,
                'active_workers': sum(1 for w in self._workers if w.is_alive()),
                'queue_size': self._task_queue.qsize(),
                'submitted': self._submitted_count,
                'completed': sum(s.tasks_completed for s in worker_stats),
                'failed': sum(s.tasks_failed for s in worker_stats),
                'total_time': sum(s.total_time for s in worker_stats)
            }


def compute_square(n: int) -> int:
    time.sleep(0.1)
    return n ** 2

pool = ThreadPool(max_workers=3, queue_size=10)
pool.start()

futures = [pool.submit(compute_square, i) for i in range(6)]

for i, future in enumerate(futures):
    print(f"Task {i}: {future.result()}")

print(f"\nThread pool stats: {pool.get_stats()}")

pool.shutdown()

25.6.2 Using concurrent.futures

python
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from typing import Callable, List, Any, Dict
import time
import os


class ExecutorManager:
    """
    Executor manager: a thin wrapper around concurrent.futures thread and process pools
    """
    
    def __init__(self, max_workers: int = None, use_process: bool = False):
        self._max_workers = max_workers or os.cpu_count()
        self._use_process = use_process
        self._executor = None
    
    def __enter__(self):
        if self._use_process:
            self._executor = ProcessPoolExecutor(max_workers=self._max_workers)
        else:
            self._executor = ThreadPoolExecutor(max_workers=self._max_workers)
        return self
    
    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._executor:
            self._executor.shutdown(wait=True)
    
    def submit(self, func: Callable, *args, **kwargs):
        return self._executor.submit(func, *args, **kwargs)
    
    def map(self, func: Callable, iterable: List, timeout: float = None) -> List:
        return list(self._executor.map(func, iterable, timeout=timeout))
    
    def map_with_progress(
        self,
        func: Callable,
        iterable: List,
        callback: Callable[[int, int, Any], None] = None
    ) -> List[Any]:
        """Run func over every item, calling callback(completed, total, result) as each one finishes"""
        futures = {self._executor.submit(func, item): i for i, item in enumerate(iterable)}
        results = [None] * len(iterable)
        completed = 0
        
        for future in as_completed(futures):
            index = futures[future]
            results[index] = future.result()
            completed += 1
            
            if callback:
                callback(completed, len(iterable), results[index])
        
        return results
    
    def submit_with_timeout(
        self,
        func: Callable,
        args: tuple = (),
        kwargs: dict = None,
        timeout: float = None
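    ) -> Any:
        """Submit one task and wait at most `timeout` seconds for its result"""
        kwargs = kwargs or {}
        future = self._executor.submit(func, *args, **kwargs)
        # On timeout this raises concurrent.futures.TimeoutError, but the task itself keeps running
        return future.result(timeout=timeout)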


def cpu_intensive_task(n: int) -> int:
    """CPU-bound task"""
    result = 0
    for i in range(n):
        result += i ** 2
    return result


def io_intensive_task(url: str) -> dict:
    """IO-bound task (simulated with sleep)"""
    time.sleep(0.1)
    return {'url': url, 'status': 'success'}


if __name__ == "__main__":
    # The guard is required for ProcessPoolExecutor where workers are started with
    # "spawn" (Windows, and macOS by default): each child process re-imports this module
    print("=== Thread pool running IO-bound tasks ===")
    with ExecutorManager(max_workers=4) as executor:
        urls = [f"http://example.com/{i}" for i in range(10)]
        
        def progress_callback(completed, total, result):
            print(f"Progress: {completed}/{total}")
        
        results = executor.map_with_progress(io_intensive_task, urls, progress_callback)
        print(f"Completed {len(results)} tasks")
    
    print("\n=== Process pool running CPU-bound tasks ===")
    with ExecutorManager(max_workers=4, use_process=True) as executor:
        numbers = [100000, 200000, 300000, 400000]
        results = executor.map(cpu_intensive_task, numbers)
        print(f"Results: {results}")
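`ExecutorManager.map` and `map_with_progress` both return results in input order, but they observe completion differently. `Executor.map` yields results in input order, waiting on a slow early item before yielding later ones, while `as_completed` yields futures as they finish, which is why `map_with_progress` keeps each future's index. A small thread-pool sketch with deliberately uneven sleeps makes the difference visible; the delays are arbitrary.

```python
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Tuple


def slow_echo(delay: float) -> float:
    time.sleep(delay)
    return delay


delays = [0.3, 0.1, 0.2]

with ThreadPoolExecutor(max_workers=3) as executor:
    # map: results always come back in the order of the inputs
    mapped: List[float] = list(executor.map(slow_echo, delays))

    # as_completed: futures arrive in completion order; keep the index to restore order
    futures = {executor.submit(slow_echo, d): i for i, d in enumerate(delays)}
    finished: List[Tuple[int, float]] = [(futures[f], f.result()) for f in as_completed(futures)]

print(mapped)  # → [0.3, 0.1, 0.2]
print([i for i, _ in finished])  # completion order, not input order
```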

25.7 Asynchronous Programming Patterns

25.7.1 Async Context Managers

python
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator, Any, Optional
import time


class AsyncResource:
    """Async resource with asynchronous acquire/release steps"""
    
    def __init__(self, name: str):
        self.name = name
        self._acquired = False
    
    async def acquire(self) -> None:
        print(f"Acquiring resource: {self.name}")
        await asyncio.sleep(0.1)  # simulate connection setup
        self._acquired = True
    
    async def release(self) -> None:
        print(f"Releasing resource: {self.name}")
        await asyncio.sleep(0.05)
        self._acquired = False
    
    async def use(self) -> str:
        if not self._acquired:
            raise RuntimeError("Resource has not been acquired")
        return f"Using resource: {self.name}"


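class AsyncLock:
    """Async lock with FIFO hand-off (for illustration; asyncio.Lock is the production choice)"""
    
    def __init__(self):
        self._locked = False
        self._waiters: asyncio.Queue = asyncio.Queue()  # futures of waiting coroutines, FIFO
    
    async def acquire(self) -> None:
        if not self._locked:
            self._locked = True
            return
        
        future = asyncio.get_running_loop().create_future()
        self._waiters.put_nowait(future)
        try:
            await future  # release() hands ownership over by resolving this future
        except asyncio.CancelledError:
            # If ownership had already been handed to us, pass it on instead of leaking it
            if future.done() and not future.cancelled():
                self._hand_off()
            raise
    
    def _hand_off(self) -> None:
        # Wake the first waiter that is still waiting; skip cancelled ones
        while not self._waiters.empty():
            future = self._waiters.get_nowait()
            if not future.done():
                future.set_result(None)
                return
        self._locked = False
    
    async def release(self) -> None:
        self._hand_off()
    
    async def __aenter__(self):
        await self.acquire()
        return self
    
    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.release()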


class AsyncSemaphore:
    """Async semaphore built on asyncio.Condition"""
    
    def __init__(self, permits: int):
        self._permits = permits
        self._lock = asyncio.Lock()
        self._condition = asyncio.Condition(self._lock)
    
    async def acquire(self) -> None:
        async with self._condition:
            while self._permits <= 0:
                await self._condition.wait()
            self._permits -= 1
    
    async def release(self) -> None:
        async with self._condition:
            self._permits += 1
            self._condition.notify()
    
    async def __aenter__(self):
        await self.acquire()
        return self
    
    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.release()


@asynccontextmanager
async def async_resource(name: str) -> AsyncIterator[AsyncResource]:
    """Async context manager for a resource: release() runs even if the body raises"""
    resource = AsyncResource(name)
    await resource.acquire()
    try:
        yield resource
    finally:
        await resource.release()


async def async_resource_example():
    async with async_resource("database connection") as conn:
        result = await conn.use()
        print(result)


asyncio.run(async_resource_example())
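`AsyncLock` and `AsyncSemaphore` above are teaching implementations; the standard library's `asyncio.Lock` and `asyncio.Semaphore` offer the same `async with` interface. A minimal sketch, with made-up job durations, that uses `asyncio.Semaphore` to cap how many coroutines are inside a section at once and records the peak to confirm the cap holds (`limited_jobs` is an illustrative name):

```python
import asyncio


async def limited_jobs(n_jobs: int, limit: int) -> int:
    semaphore = asyncio.Semaphore(limit)
    active = 0
    peak = 0

    async def job(i: int) -> None:
        nonlocal active, peak
        async with semaphore:  # at most `limit` jobs get past this line at a time
            active += 1
            peak = max(peak, active)
            await asyncio.sleep(0.01 * (i % 3 + 1))  # simulated IO
            active -= 1

    await asyncio.gather(*(job(i) for i in range(n_jobs)))
    return peak


print(asyncio.run(limited_jobs(10, 3)))  # → 3
```

No locking is needed around `active` because all jobs run on one event-loop thread and only switch at `await` points.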

25.7.2 Async Iterators and Generators

python
import asyncio
import random
import time
from typing import AsyncIterator, AsyncGenerator, List, Any


class AsyncRange:
    """Async range iterator"""
    
    def __init__(self, start: int, end: int, delay: float = 0.1):
        self._start = start
        self._end = end
        self._delay = delay
    
    def __aiter__(self) -> 'AsyncRange':
        self._current = self._start
        return self
    
    async def __anext__(self) -> int:
        if self._current >= self._end:
            raise StopAsyncIteration
        await asyncio.sleep(self._delay)
        value = self._current
        self._current += 1
        return value


async def async_data_stream(count: int, delay: float = 0.1) -> AsyncGenerator[dict, None]:
    """Async generator producing a stream of data records"""
    loop = asyncio.get_running_loop()
    for i in range(count):
        await asyncio.sleep(delay)
        yield {
            'id': i,
            'value': random.random() * 100,
            'timestamp': loop.time()
        }


class AsyncBatchProcessor:
    """Async batch processor: flushes when the batch is full or flush_interval has elapsed"""
    
    def __init__(self, batch_size: int = 10, flush_interval: float = 1.0):
        self._batch_size = batch_size
        self._flush_interval = flush_interval
        self._batch: List[Any] = []
        self._lock = asyncio.Lock()
        self._last_flush = time.monotonic()
    
    async def add(self, item: Any) -> None:
        async with self._lock:
            self._batch.append(item)
            # The interval is only checked when an item arrives; there is no background timer
            if (len(self._batch) >= self._batch_size
                    or time.monotonic() - self._last_flush >= self._flush_interval):
                await self._flush()
    
    async def _flush(self) -> None:
        self._last_flush = time.monotonic()
        if not self._batch:
            return
        
        batch = self._batch.copy()
        self._batch.clear()
        
        print(f"Processing a batch of {len(batch)} items")
        await asyncio.sleep(0.1)  # simulate a bulk write
    
    async def close(self) -> None:
        """Flush whatever is left in the buffer"""
        async with self._lock:
            await self._flush()


async def process_async_stream():
    """Consume the async streams defined above"""
    print("=== Async iterator ===")
    result: List[int] = []
    async for num in AsyncRange(0, 5):
        print(f"Processing: {num}")
        result.append(num)
    
    print(f"Result: {result}")
    
    print("\n=== Async generator ===")
    async for data in async_data_stream(5):
        print(f"Received: id={data['id']}, value={data['value']:.2f}")
    
    print("\n=== Async batch processing ===")
    processor = AsyncBatchProcessor(batch_size=3)
    
    async for data in async_data_stream(10, delay=0.05):
        await processor.add(data)
    
    await processor.close()


asyncio.run(process_async_stream())
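`AsyncBatchProcessor` pushes items into a buffer that it flushes itself. The same size-based batching can also be written in pull style, as an async generator that consumes any async iterable and yields lists, emitting the final partial batch at the end. A sketch under that assumption; `batched` and `numbers` are illustrative names, not library functions.

```python
import asyncio
from typing import AsyncIterable, AsyncIterator, List, TypeVar

T = TypeVar("T")


async def batched(source: AsyncIterable[T], size: int) -> AsyncIterator[List[T]]:
    """Group items from an async iterable into lists of at most `size` items."""
    if size < 1:
        raise ValueError("size must be >= 1")
    batch: List[T] = []
    async for item in source:
        batch.append(item)
        if len(batch) == size:
            yield batch
            batch = []  # start a fresh list; the yielded one now belongs to the consumer
    if batch:
        yield batch  # flush the trailing partial batch


async def numbers(n: int) -> AsyncIterator[int]:
    for i in range(n):
        await asyncio.sleep(0)  # give control back to the event loop, as real IO would
        yield i


async def main() -> List[List[int]]:
    return [batch async for batch in batched(numbers(7), 3)]


print(asyncio.run(main()))  # → [[0, 1, 2], [3, 4, 5], [6]]
```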

25.7.3 Async Task Orchestration

python
import asyncio
import random
import time
from typing import List, Any, Callable, Coroutine, Optional
from dataclasses import dataclass


@dataclass
class TaskResult:
    index: int
    result: Any
    error: Optional[Exception] = None
    duration: float = 0.0


class AsyncTaskOrchestrator:
    """Async task orchestrator: common patterns for combining coroutines"""
    
    def __init__(self):
        self._tasks: List[asyncio.Task] = []
    
    async def gather_with_exceptions(
        self,
        coroutines: List[Coroutine],
        return_exceptions: bool = True
    ) -> List[TaskResult]:
        """Run all coroutines concurrently; each task's exception is captured in its TaskResult"""
        
        async def run_with_timing(index: int, coro: Coroutine) -> TaskResult:
            start = time.time()
            try:
                result = await coro
                return TaskResult(index=index, result=result, duration=time.time() - start)
            except Exception as e:
                return TaskResult(index=index, result=None, error=e, duration=time.time() - start)
        
        tasks = [run_with_timing(i, coro) for i, coro in enumerate(coroutines)]
        results = await asyncio.gather(*tasks, return_exceptions=return_exceptions)
        
        return results
    
    async def run_with_timeout(
        self,
        coro: Coroutine,
        timeout: float
    ) -> TaskResult:
        """Run a coroutine with a timeout; wait_for cancels it if the timeout expires"""
        start = time.time()
        try:
            result = await asyncio.wait_for(coro, timeout=timeout)
            return TaskResult(index=0, result=result, duration=time.time() - start)
        except asyncio.TimeoutError:
            return TaskResult(index=0, result=None, error=asyncio.TimeoutError(), duration=timeout)
        except Exception as e:
            return TaskResult(index=0, result=None, error=e, duration=time.time() - start)
    
    async def run_with_retry(
        self,
        coro_factory: Callable[[], Coroutine],
        max_retries: int = 3,
        delay: float = 1.0
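    ) -> TaskResult:
        """Retry on failure; coro_factory must build a new coroutine per attempt, since a coroutine can be awaited only once"""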
        last_error = None
        
        for attempt in range(max_retries + 1):
            start = time.time()
            try:
                result = await coro_factory()
                return TaskResult(index=attempt, result=result, duration=time.time() - start)
            except Exception as e:
                last_error = e
                if attempt < max_retries:
                    await asyncio.sleep(delay)
        
        return TaskResult(index=max_retries, result=None, error=last_error)
    
    async def run_pipeline(
        self,
        initial_value: Any,
        stages: List[Callable[[Any], Coroutine]]
    ) -> Any:
        """Pipeline: each stage receives the previous stage's output"""
        value = initial_value
        for i, stage in enumerate(stages):
            value = await stage(value)
            print(f"Stage {i+1} done: {value}")
        return value
    
    async def run_parallel_with_limit(
        self,
        coroutines: List[Coroutine],
        max_concurrent: int
    ) -> List[TaskResult]:
        """Run concurrently with at most max_concurrent coroutines in flight"""
        semaphore = asyncio.Semaphore(max_concurrent)
        
        async def run_with_semaphore(index: int, coro: Coroutine) -> TaskResult:
            async with semaphore:
                start = time.time()
                try:
                    result = await coro
                    return TaskResult(index=index, result=result, duration=time.time() - start)
                except Exception as e:
                    return TaskResult(index=index, result=None, error=e, duration=time.time() - start)
        
        tasks = [run_with_semaphore(i, coro) for i, coro in enumerate(coroutines)]
        results = await asyncio.gather(*tasks)
        
        return results


async def simulate_api_call(call_id: int, delay: float = 0.5) -> dict:
    await asyncio.sleep(delay)
    if random.random() < 0.2:  # simulate a 20% failure rate
        raise ValueError(f"API call {call_id} failed")
    return {'id': call_id, 'status': 'success'}


async def orchestrator_example():
    orchestrator = AsyncTaskOrchestrator()
    
    print("=== Parallel execution (with exception capture) ===")
    coroutines = [simulate_api_call(i, random.uniform(0.1, 0.5)) for i in range(5)]
    results = await orchestrator.gather_with_exceptions(coroutines)
    
    for r in results:
        if r.error is not None:
            print(f"Task {r.index}: error - {r.error}")
        else:
            print(f"Task {r.index}: {r.result}")
    
    print("\n=== Execution with a timeout ===")
    result = await orchestrator.run_with_timeout(simulate_api_call(1, 2.0), timeout=1.0)
    print(f"Timeout result: {result}")
    
    print("\n=== Pipeline execution ===")
    async def stage1(x):
        return x * 2
    
    async def stage2(x):
        return x + 10
    
    async def stage3(x):
        return f"final result: {x}"
    
    final = await orchestrator.run_pipeline(5, [stage1, stage2, stage3])
    print(f"Pipeline result: {final}")


asyncio.run(orchestrator_example())
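`run_with_retry` waits a fixed `delay` between attempts. A common variant is exponential backoff: double the wait after each failure, up to a cap, so a struggling service is not hit at a constant rate. A minimal sketch under these assumptions: `retry_with_backoff` and `flaky` are hypothetical names, it retries on any `Exception`, and it adds no random jitter (production code often does). Like `run_with_retry`, it takes a factory because a coroutine object can be awaited only once.

```python
import asyncio
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")


async def retry_with_backoff(
    factory: Callable[[], Awaitable[T]],
    max_retries: int = 3,
    base_delay: float = 0.01,
    max_delay: float = 1.0,
) -> T:
    """Await factory() until it succeeds; wait base_delay * 2**attempt (capped) between tries."""
    if max_retries < 0:
        raise ValueError("max_retries must be >= 0")
    for attempt in range(max_retries + 1):
        try:
            return await factory()
        except Exception:
            if attempt == max_retries:
                raise  # retries exhausted: propagate the last error
            await asyncio.sleep(min(base_delay * 2 ** attempt, max_delay))
    raise RuntimeError("unreachable")  # the loop always returns or raises


calls = 0


async def flaky() -> str:
    """Fails twice, then succeeds (a deterministic stand-in for an unreliable API)."""
    global calls
    calls += 1
    if calls < 3:
        raise ConnectionError(f"attempt {calls} failed")
    return "ok"


print(asyncio.run(retry_with_backoff(flaky)), calls)  # → ok 3
```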

25.8 Actor Pattern

python
import threading
import queue
from typing import Any, Callable, Dict, Optional, List
from dataclasses import dataclass, field
from enum import Enum, auto
import time
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class MessagePriority(Enum):
    LOW = 0
    NORMAL = 1
    HIGH = 2
    URGENT = 3


@dataclass
class Message:
    """Actor message"""
    sender: str
    content: Any
    priority: MessagePriority = MessagePriority.NORMAL
    reply_to: Optional['Actor'] = None
    timestamp: float = field(default_factory=time.time)


class Actor:
    """
    Actor base class: a concurrency model built on message passing
    
    Properties:
    - Isolated state: each actor owns its state; no other thread touches it
    - Message driven: actors interact only by sending messages
    - Serial processing: each actor handles one message at a time on its own thread
    """
    
    def __init__(self, name: str):
        self.name = name
        self._mailbox: queue.PriorityQueue = queue.PriorityQueue()
        self._running = False
        self._thread: Optional[threading.Thread] = None
        self._handlers: Dict[str, Callable] = {}
        self._state: Dict[str, Any] = {}
        self._message_count = 0
        self._seq = 0                      # tie-breaker for equal priorities
        self._seq_lock = threading.Lock()  # senders on different threads share the counter
    
    def on(self, message_type: str, handler: Callable) -> None:
        """Register a handler for a message type"""
        self._handlers[message_type] = handler
    
    def start(self) -> None:
        """Start the actor's thread"""
        self._running = True
        self._thread = threading.Thread(target=self._run, name=f"Actor-{self.name}", daemon=True)
        self._thread.start()
        logger.info(f"Actor {self.name} started")
    
    def stop(self) -> None:
        """Stop the actor (the mailbox is polled every second, so this can take about 1 s)"""
        self._running = False
        if self._thread:
            self._thread.join(timeout=2)
        logger.info(f"Actor {self.name} stopped, messages handled: {self._message_count}")
    
    def _enqueue(self, message: Message) -> None:
        """Put a message into this actor's mailbox"""
        with self._seq_lock:
            self._seq += 1
            seq = self._seq
        # PriorityQueue pops the smallest entry, so negate the priority to serve URGENT first;
        # the sequence number keeps FIFO order within a priority and avoids comparing Messages
        self._mailbox.put((-message.priority.value, seq, message))
    
    def send(self, recipient: 'Actor', content: Any, priority: MessagePriority = MessagePriority.NORMAL) -> None:
        """Send a message (fire and forget)"""
        recipient._enqueue(Message(sender=self.name, content=content, priority=priority))
    
    def send_and_wait(self, recipient: 'Actor', content: Any, timeout: float = 5.0) -> Any:
        """Send a message and block until the reply arrives (raises queue.Empty on timeout)"""
        # Never call this from inside a handler with recipient=self: the actor would wait on itself
        reply_queue: queue.Queue = queue.Queue()
        message = Message(
            sender=self.name,
            content=content,
            reply_to=ReplyActor(f"Reply-{self.name}", reply_queue)
        )
        recipient._enqueue(message)
        return reply_queue.get(timeout=timeout)
    
    def _run(self) -> None:
        """Actor main loop"""
        while self._running:
            try:
                _, _, message = self._mailbox.get(timeout=1)
            except queue.Empty:
                continue
            try:
                self._handle(message)
                self._message_count += 1
            except Exception as e:
                logger.error(f"Actor {self.name} failed to handle a message: {e}")
    
    def _handle(self, message: Message) -> None:
        """Dispatch a message to its handler and reply if the sender is waiting"""
        content = message.content
        
        if isinstance(content, dict):
            msg_type = content.get('type')
            if msg_type and msg_type in self._handlers:
                result = self._handlers[msg_type](message)
                if message.reply_to:
                    # The reply goes from this actor TO the waiting reply_to actor
                    self.send(message.reply_to, {'type': 'reply', 'result': result})
        elif callable(self._handlers.get('default')):
            self._handlers['default'](message)


class ReplyActor(Actor):
    """Reply actor: a one-shot mailbox used by send_and_wait (needs no thread of its own)"""
    
    def __init__(self, name: str, reply_queue: queue.Queue):
        super().__init__(name)
        self._reply_queue = reply_queue
    
    def _enqueue(self, message: Message) -> None:
        # Deliver the result straight to the caller blocked in send_and_wait
        self._reply_queue.put(message.content.get('result'))


class CounterActor(Actor):
    """Example actor: a counter whose state is only touched by its own thread"""
    
    def __init__(self, name: str):
        super().__init__(name)
        self._count = 0
        self.on('increment', self._increment)
        self.on('decrement', self._decrement)
        self.on('get', self._get_count)
        self.on('reset', self._reset)
    
    def _increment(self, message: Message) -> int:
        amount = message.content.get('amount', 1)
        self._count += amount
        logger.info(f"{self.name}: count increased to {self._count}")
        return self._count
    
    def _decrement(self, message: Message) -> int:
        amount = message.content.get('amount', 1)
        self._count -= amount
        logger.info(f"{self.name}: count decreased to {self._count}")
        return self._count
    
    def _get_count(self, message: Message) -> int:
        logger.info(f"{self.name}: current count is {self._count}")
        return self._count
    
    def _reset(self, message: Message) -> int:
        self._count = 0
        logger.info(f"{self.name}: count reset to 0")
        return self._count


class ActorSystem:
    """Actor system: creates, looks up, and stops a set of actors"""
    
    def __init__(self):
        self._actors: Dict[str, Actor] = {}
        self._lock = threading.Lock()
    
    def create_actor(self, actor_class: type, name: str, *args, **kwargs) -> Actor:
        """Create, register, and start an actor"""
        with self._lock:
            if name in self._actors:
                raise ValueError(f"Actor {name} already exists")
            
            actor = actor_class(name, *args, **kwargs)
            self._actors[name] = actor
            actor.start()
            return actor
    
    def get_actor(self, name: str) -> Optional[Actor]:
        """Look up an actor by name"""
        return self._actors.get(name)
    
    def stop_all(self) -> None:
        """Stop all actors"""
        for actor in self._actors.values():
            actor.stop()
        self._actors.clear()
    
    def broadcast(self, content: Any) -> None:
        """Send the same message to every actor"""
        for actor in self._actors.values():
            actor.send(actor, content)


system = ActorSystem()

counter1 = system.create_actor(CounterActor, "Counter1")
counter2 = system.create_actor(CounterActor, "Counter2")

counter1.send(counter1, {'type': 'increment', 'amount': 5})
counter1.send(counter1, {'type': 'increment', 'amount': 3})
counter1.send(counter1, {'type': 'get'})

counter2.send(counter2, {'type': 'increment', 'amount': 10})
counter2.send(counter2, {'type': 'decrement', 'amount': 2})
counter2.send(counter2, {'type': 'get'})

time.sleep(1)  # crude wait: give the actors time to drain their mailboxes before stopping
system.stop_all()
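The `ReplyActor` class points toward a request-reply ("ask") interaction. The usage code above, however, only sends one-way messages and then sleeps for a second. The sketch below shows the ask idea in a self-contained form. It assumes a simplified actor and does not use the chapter's `Actor` class. Each request carries its own `queue.Queue` as the reply channel, and the caller blocks on that queue with a timeout instead of sleeping. The names `MiniCounterActor`, `tell`, and `ask` are hypothetical.

```python
import queue
import threading
from typing import Any, Optional


class MiniCounterActor:
    """Hypothetical minimal actor: one thread, one mailbox, private state."""

    _STOP = object()  # sentinel that tells the actor's thread to exit

    def __init__(self) -> None:
        self._mailbox: queue.Queue = queue.Queue()
        self._count = 0  # read and written only by the actor's own thread
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def tell(self, msg: dict, reply_to: Optional[queue.Queue] = None) -> None:
        """Fire-and-forget send; optionally attach a queue for the reply."""
        self._mailbox.put((msg, reply_to))

    def ask(self, msg: dict, timeout: float = 1.0) -> Any:
        """Send a message and block until the actor replies or the timeout expires."""
        reply_q: queue.Queue = queue.Queue(maxsize=1)
        self.tell(msg, reply_q)
        return reply_q.get(timeout=timeout)  # raises queue.Empty on timeout

    def stop(self) -> None:
        self._mailbox.put((self._STOP, None))
        self._thread.join()

    def _run(self) -> None:
        while True:
            msg, reply_to = self._mailbox.get()
            if msg is self._STOP:
                break
            kind = msg.get("type")
            if kind == "increment":
                self._count += msg.get("amount", 1)
            elif kind == "decrement":
                self._count -= msg.get("amount", 1)
            # "get" (or any other type) just reports the current value
            if reply_to is not None:
                reply_to.put(self._count)


actor = MiniCounterActor()
actor.tell({"type": "increment", "amount": 10})
actor.tell({"type": "decrement", "amount": 2})
print(actor.ask({"type": "get"}))  # the mailbox is FIFO, so this prints 8
actor.stop()
```

No lock is needed because only the actor's thread reads or writes `_count`. The mailbox queue is the single synchronization point. The timeout on `ask` keeps a stopped or overloaded actor from blocking the caller forever.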

25.9 Anti-Patterns and Best Practices

25.9.1 Common Anti-Patterns

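```python
import threading
import time


class ConcurrencyAntiPatterns:
    """Examples of concurrency anti-patterns"""
    
    @staticmethod
    def deadlock_example():
        """
        Anti-pattern: deadlock
        Two threads each wait for a lock held by the other.
        """
        lock1 = threading.Lock()
        lock2 = threading.Lock()
        
        def thread1():
            with lock1:              # holds lock1 ...
                time.sleep(0.1)
                with lock2:          # ... and waits for lock2
                    print("Thread 1 finished")
        
        def thread2():
            with lock2:              # holds lock2 ...
                time.sleep(0.1)
                with lock1:          # ... and waits for lock1: circular wait
                    print("Thread 2 finished")
        
        # Starting both threads would almost certainly hang forever,
        # so the functions are defined but deliberately not run.
        print("Deadlock example (not actually executed)")
    
    @staticmethod
    def race_condition_example():
        """
        Anti-pattern: race condition
        Several threads update a shared variable without a lock.
        """
        counter = 0
        
        def increment():
            nonlocal counter
            for _ in range(10000):
                # Read-modify-write is not atomic: another thread may run
                # between the read and the write, and its update is then lost.
                temp = counter
                temp += 1
                counter = temp
        
        threads = [threading.Thread(target=increment) for _ in range(5)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        
        # The result can vary between runs and may fall below 50000. It can also
        # happen to come out right, which is exactly what makes races hard to find.
        print(f"Race condition example: expected 50000, got {counter}")
    
    @staticmethod
    def correct_counter():
        """
        Correct approach: protect the shared variable with a lock
        """
        counter = 0
        lock = threading.Lock()
        
        def increment():
            nonlocal counter
            for _ in range(10000):
                with lock:           # the whole read-modify-write is one critical section
                    counter += 1
        
        threads = [threading.Thread(target=increment) for _ in range(5)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        
        print(f"Correct counter: {counter}")  # always 50000


ConcurrencyAntiPatterns.race_condition_example()
ConcurrencyAntiPatterns.correct_counter()
```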
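`deadlock_example` describes the problem but does not fix it. A common fix is to impose one global acquisition order, so that every thread takes the locks in the same sequence and a circular wait cannot form. The sketch below orders the locks by `id()`. That ordering key and the helper name `acquire_in_order` are illustrative assumptions; any fixed total order works.

```python
import threading
import time
from contextlib import ExitStack, contextmanager
from typing import Iterator


@contextmanager
def acquire_in_order(*locks: threading.Lock) -> Iterator[None]:
    """Acquire all locks in a fixed global order (by id) and release them on exit."""
    with ExitStack() as stack:
        for lock in sorted(locks, key=id):
            stack.enter_context(lock)  # acquires now; ExitStack releases in reverse order
        yield


lock1 = threading.Lock()
lock2 = threading.Lock()
finished = []


def thread1() -> None:
    with acquire_in_order(lock1, lock2):
        time.sleep(0.1)
        finished.append("thread1")


def thread2() -> None:
    # The locks are listed in the opposite order, but the helper sorts them,
    # so both threads acquire them in the same sequence and cannot deadlock.
    with acquire_in_order(lock2, lock1):
        time.sleep(0.1)
        finished.append("thread2")


threads = [threading.Thread(target=thread1), threading.Thread(target=thread2)]
for t in threads:
    t.start()
for t in threads:
    t.join(timeout=2)
print(sorted(finished))  # ['thread1', 'thread2']
```

A fixed order only helps if every code path that takes more than one of these locks goes through the same helper. A single hand-written `with lock2: with lock1:` elsewhere brings the deadlock back.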

25.9.2 Best Practices

```python
import threading
from contextlib import contextmanager


class ConcurrencyBestPractices:
    """Concurrency best practices"""
    
    @staticmethod
    @contextmanager
    def timeout_lock(lock: threading.Lock, timeout: float):
        """Acquire a lock with a timeout; raise TimeoutError instead of blocking forever"""
        acquired = lock.acquire(timeout=timeout)
        if not acquired:
            raise TimeoutError("Timed out while acquiring the lock")
        try:
            yield
        finally:
            lock.release()
    
    @staticmethod
    def safe_shared_counter():
        """Thread-safe counter"""
        class SafeCounter:
            def __init__(self, initial: int = 0):
                self._value = initial
                self._lock = threading.Lock()  # every read and write goes through this lock
            
            def increment(self, delta: int = 1) -> int:
                with self._lock:
                    self._value += delta
                    return self._value
            
            def decrement(self, delta: int = 1) -> int:
                with self._lock:
                    self._value -= delta
                    return self._value
            
            def get(self) -> int:
                with self._lock:
                    return self._value
            
            def set(self, value: int) -> None:
                with self._lock:
                    self._value = value
        
        return SafeCounter()
    
    @staticmethod
    def graceful_shutdown_example():
        """Graceful shutdown example"""
        class Service:
            def __init__(self, interval: float = 0.1):
                self._interval = interval
                self._thread = None
                self._stop_event = threading.Event()
            
            def start(self):
                self._stop_event.clear()
                self._thread = threading.Thread(target=self._run, daemon=True)
                self._thread.start()
            
            def stop(self, timeout: float = 5.0):
                # Signal the worker, then wait (bounded) for it to finish its current step
                self._stop_event.set()
                if self._thread:
                    self._thread.join(timeout=timeout)
            
            def _run(self):
                while not self._stop_event.is_set():
                    try:
                        pass  # one unit of real work goes here
                    except Exception:
                        pass  # in real code, log the error instead of swallowing it
                    # Wait until the next step but wake immediately on stop();
                    # without this wait the loop would spin at 100% CPU
                    self._stop_event.wait(self._interval)
        
        return Service()


print("=== Best practice examples ===")
counter = ConcurrencyBestPractices.safe_shared_counter()

def worker(counter, worker_id):
    for _ in range(1000):
        counter.increment()
    print(f"Worker {worker_id} finished")

threads = [threading.Thread(target=worker, args=(counter, i)) for i in range(5)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(f"Final count: {counter.get()}")  # always 5000

# Timeout lock: re-acquiring a held (non-reentrant) lock fails instead of hanging
lock = threading.Lock()
with ConcurrencyBestPractices.timeout_lock(lock, timeout=1.0):
    try:
        with ConcurrencyBestPractices.timeout_lock(lock, timeout=0.1):
            pass
    except TimeoutError as exc:
        print(f"Expected: {exc}")

# Graceful shutdown: stop() sets the event and joins the worker thread
service = ConcurrencyBestPractices.graceful_shutdown_example()
service.start()
service.stop()
```
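The `Service` class above shuts down through a `threading.Event`. When work arrives as discrete items, a common alternative relies on a queue instead of explicit locks. A pool of workers reads from `queue.Queue`, and shutdown is signalled by putting one sentinel per worker. Because the queue is FIFO, every item queued before the sentinels is still processed. The sketch below is one such arrangement; `run_pool`, `SENTINEL`, and the squaring workload are hypothetical illustrations.

```python
import queue
import threading
from typing import List

SENTINEL = None  # hypothetical marker meaning "no more work"


def run_pool(items: List[int], num_workers: int = 4) -> List[int]:
    """Square each item on a pool of worker threads, then shut the pool down cleanly."""
    tasks: queue.Queue = queue.Queue()
    results: queue.Queue = queue.Queue()

    def worker() -> None:
        while True:
            item = tasks.get()
            try:
                if item is SENTINEL:
                    return  # graceful exit: all items queued earlier were already taken
                results.put(item * item)
            finally:
                tasks.task_done()

    threads = [threading.Thread(target=worker) for _ in range(num_workers)]
    for t in threads:
        t.start()
    for item in items:
        tasks.put(item)
    for _ in threads:
        tasks.put(SENTINEL)  # one sentinel per worker, each worker consumes exactly one
    for t in threads:
        t.join()

    # All workers have exited, so draining the results queue is now race-free
    out = []
    while not results.empty():
        out.append(results.get())
    return sorted(out)


print(run_pool([1, 2, 3, 4, 5]))  # [1, 4, 9, 16, 25]
```

The workers share no mutable state except the two queues, which do their own locking internally. The sentinels give each worker an unambiguous point at which to stop.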

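25.10 Summary

The core problem that concurrency design patterns solve is how to coordinate task execution safely and efficiently in multithreaded or multiprocess environments.

Key Points

  1. Thread safety: ensured through immutable objects, locking, and thread-local storage
  2. Synchronization: threads are coordinated with barriers, semaphores, and read-write locks
  3. Task distribution: work is distributed with the producer-consumer and thread-pool patterns
  4. Asynchronous processing: non-blocking I/O is achieved with async/await and asynchronous generators
  5. Actor model: share-nothing concurrency through message passing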

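Practical Recommendations

  1. Prefer immutable objects and avoid shared mutable state
  2. Use higher-level synchronization tools (such as queues and semaphores) rather than hand-managed low-level locks
  3. Choose the concurrency model (threads, processes, or coroutines) that fits the workload
  4. Set sensible timeouts on blocking operations so that a deadlock surfaces as an error instead of a silent hang
  5. Implement a graceful shutdown mechanism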

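Coming Next

The next chapter introduces architectural design patterns and explores how to design scalable, maintainable software architectures.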

Python Technical Book Series - 江苏省宿城中等专业学校 (Jiangsu Sucheng Secondary Vocational School)