Skip to content

第46章 分布式系统

学习目标

完成本章学习后,你将能够:

  1. 理解分布式系统原理:CAP定理、BASE理论、分布式一致性
  2. 实现消息队列:RabbitMQ、Kafka、消息模式、可靠传输
  3. 设计缓存策略:Redis、缓存模式、缓存穿透、缓存雪崩
  4. 实现负载均衡:轮询、加权、一致性哈希、健康检查
  5. 构建高可用系统:故障转移、服务降级、限流熔断
  6. 处理分布式事务:两阶段提交、Saga模式、最终一致性
  7. 实现分布式锁:Redis锁、Zookeeper锁、分布式协调
  8. 设计分布式ID:雪花算法、UUID、数据库序列

46.1 分布式系统基础

46.1.1 核心理论

python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any, Callable
from enum import Enum
import asyncio
import time
import hashlib


class ConsistencyLevel(Enum):
    """Consistency guarantees a replicated store may offer.

    Members (declaration order matters for iteration):
        STRONG   -- value "strong"
        EVENTUAL -- value "eventual"
        CAUSAL   -- value "causal"
    """

    STRONG = "strong"
    EVENTUAL = "eventual"
    CAUSAL = "causal"


@dataclass
class Node:
    """A cluster member together with its last known liveness data."""

    id: str
    host: str
    port: int
    status: str = "healthy"
    # time.time() of the most recent heartbeat. Defaults to 0, so a node that
    # has never reported is considered not alive by is_alive().
    last_heartbeat: float = 0
    metadata: Dict[str, str] = field(default_factory=dict)

    def is_alive(self, timeout: float = 30.0) -> bool:
        """Return True if the last heartbeat is less than *timeout* seconds old."""
        elapsed = time.time() - self.last_heartbeat
        return elapsed < timeout


@dataclass
class DistributedConfig:
    """Tunable parameters for a replicated cluster.

    Field meanings are inferred from their names. Nothing in this listing
    enforces relationships between them. For example, overlapping quorums
    would need read_quorum + write_quorum > replication_factor, which the
    defaults (2 + 2 > 3) happen to satisfy.
    """

    # Copies kept of each item (presumed meaning).
    replication_factor: int = 3
    consistency_level: ConsistencyLevel = ConsistencyLevel.EVENTUAL
    # Replicas that must answer a read / acknowledge a write (presumed).
    read_quorum: int = 2
    write_quorum: int = 2
    # Presumably seconds, matching the time.time()-based timers elsewhere.
    heartbeat_interval: float = 5.0
    election_timeout: float = 10.0


class CAPTheorem:
    """Static, printable summary of the CAP theorem and its trade-offs."""

    @staticmethod
    def explain() -> Dict:
        """Return a fresh dict describing C, A, P and the CP/AP/CA trade-offs."""
        tradeoffs = dict(
            CP="牺牲可用性,保证一致性和分区容错(如HBase)",
            AP="牺牲一致性,保证可用性和分区容错(如Cassandra)",
            CA="单机系统,无分区容错(如传统RDBMS)",
        )
        return dict(
            C="Consistency - 所有节点在同一时间看到相同的数据",
            A="Availability - 每个请求都能获得响应(成功或失败)",
            P="Partition Tolerance - 网络分区时系统仍能运行",
            tradeoffs=tradeoffs,
        )


class BASETheory:
    """Static, printable summary of BASE and how it contrasts with ACID."""

    @staticmethod
    def explain() -> Dict:
        """Return a fresh dict describing BA, S, E and a BASE-vs-ACID note."""
        comparison = dict(
            ACID="强一致性,适合金融等关键业务",
            BASE="最终一致性,适合高并发互联网应用",
        )
        return dict(
            BA="Basically Available - 基本可用,允许损失部分可用性",
            S="Soft State - 软状态,允许中间状态存在",
            E="Eventually Consistent - 最终一致,经过时间后达到一致",
            vs_ACID=comparison,
        )


class ConsistentHashing:
    """Consistent-hash ring that maps string keys to node names.

    Each physical node is placed on the ring ``virtual_nodes`` times, at the
    MD5 hash of ``"{node}#{i}"``, to spread load evenly. A key belongs to the
    first ring position whose hash is >= the key's hash, wrapping around to
    the lowest position.
    """

    def __init__(self, virtual_nodes: int = 150):
        self.virtual_nodes = virtual_nodes
        # ring position (hash) -> physical node name
        self.ring: Dict[int, str] = {}
        self.nodes: List[str] = []
        # Ring positions kept sorted so lookups can binary-search instead of
        # re-sorting the whole ring on every call.
        self._sorted_keys: List[int] = []

    def _hash(self, key: str) -> int:
        """Return the MD5 digest of *key* as a 128-bit integer."""
        return int(hashlib.md5(key.encode()).hexdigest(), 16)

    def _virtual_hashes(self, node: str) -> List[int]:
        """Return the ring positions of all virtual nodes of *node*."""
        return [self._hash(f"{node}#{i}") for i in range(self.virtual_nodes)]

    def add_node(self, node: str) -> None:
        """Place *node* on the ring; adding an existing node is a no-op."""
        if node in self.nodes:
            return
        self.nodes.append(node)
        for hash_val in self._virtual_hashes(node):
            self.ring[hash_val] = node
        self._sorted_keys = sorted(self.ring)

    def remove_node(self, node: str) -> None:
        """Remove *node* from the ring.

        Raises:
            ValueError: if *node* was never added.
        """
        self.nodes.remove(node)
        for hash_val in self._virtual_hashes(node):
            # Only drop positions this node actually owns (guards against a
            # hash collision with another node's virtual node).
            if self.ring.get(hash_val) == node:
                del self.ring[hash_val]
        self._sorted_keys = sorted(self.ring)

    def get_node(self, key: str) -> Optional[str]:
        """Return the node responsible for *key*, or None if the ring is empty."""
        from bisect import bisect_left

        if not self.ring:
            return None
        keys = self._sorted_keys
        # First position >= hash(key); wrap to index 0 past the end.
        index = bisect_left(keys, self._hash(key)) % len(keys)
        return self.ring[keys[index]]

    def get_nodes(self, key: str, count: int = 3) -> List[str]:
        """Return up to *count* distinct nodes for *key*, walking the ring clockwise.

        The first element equals ``get_node(key)``. At least one node is
        returned whenever the ring is non-empty.
        """
        from bisect import bisect_left

        if not self.ring:
            return []
        keys = self._sorted_keys
        start = bisect_left(keys, self._hash(key))
        result: List[str] = []
        for offset in range(len(keys)):
            node = self.ring[keys[(start + offset) % len(keys)]]
            if node not in result:
                result.append(node)
                if len(result) >= count:
                    break
        return result


class DistributedCounter:
    """Grow-only replicated counter (G-Counter CRDT).

    Each replica increments only its own slot of ``_vector_clock``. ``merge``
    takes the element-wise maximum of two replicas' vectors. The counter's
    value is the sum of all slots, so it reflects every increment the replica
    has learned about.
    """

    def __init__(self, node_id: str, nodes: List[str]):
        self.node_id = node_id
        self.nodes = nodes
        # Increments performed by this replica itself.
        self._local_count = 0
        self._vector_clock: Dict[str, int] = {n: 0 for n in nodes}
        # Always keep a slot for ourselves, even if node_id was omitted from nodes.
        self._vector_clock.setdefault(node_id, 0)

    def increment(self) -> None:
        """Add one to this replica's slot."""
        self._local_count += 1
        self._vector_clock[self.node_id] += 1

    def get_count(self) -> int:
        """Return the counter value: the sum of every known replica's increments."""
        return sum(self._vector_clock.values())

    def merge(self, other: "DistributedCounter") -> None:
        """Fold *other*'s state into this replica (idempotent, commutative).

        Nodes known only to *other* are added to this replica's vector.
        """
        for node, count in other._vector_clock.items():
            self._vector_clock[node] = max(self._vector_clock.get(node, 0), count)

    def get_vector_clock(self) -> Dict[str, int]:
        """Return a copy of the per-node increment vector."""
        return self._vector_clock.copy()

46.1.2 分布式协调

python
from typing import Optional
import asyncio


class DistributedLock:
    """Asynchronous distributed lock (simulated).

    ``_try_acquire`` is a placeholder that always succeeds after generating a
    fresh token. A real backend would presumably perform an atomic remote
    operation there, honouring ``ttl``; TODO confirm against the real store.
    The lock can be used as ``async with DistributedLock(...)``.
    """

    def __init__(self, lock_name: str, ttl: int = 30):
        self.lock_name = lock_name
        self.ttl = ttl
        self._locked = False
        # Random token identifying this holder (set by _try_acquire).
        self._lock_value: Optional[str] = None

    async def acquire(self, timeout: float = 10.0) -> bool:
        """Try to take the lock, retrying every 0.1 s for up to *timeout* seconds.

        At least one attempt is always made, even when *timeout* <= 0.

        Returns:
            True once the lock is held, False if the deadline passed.
        """
        # Monotonic clock: wall-clock adjustments must not stretch/shrink the wait.
        deadline = time.monotonic() + timeout
        while True:
            if await self._try_acquire():
                self._locked = True
                return True
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False
            await asyncio.sleep(min(0.1, remaining))

    async def _try_acquire(self) -> bool:
        """Single acquisition attempt (stub: always succeeds)."""
        import uuid
        self._lock_value = str(uuid.uuid4())
        return True

    async def release(self) -> bool:
        """Release the lock; return False if it was not held."""
        if not self._locked:
            return False

        self._locked = False
        self._lock_value = None
        return True

    @property
    def locked(self) -> bool:
        """Whether this instance currently believes it holds the lock."""
        return self._locked

    async def __aenter__(self) -> "DistributedLock":
        if not await self.acquire():
            raise TimeoutError(f"Failed to acquire lock: {self.lock_name}")
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool:
        await self.release()
        return False


class LeaderElection:
    """Simplified Raft-style leader election from one node's point of view.

    ``_request_vote`` is a stub that always grants the vote; a real
    implementation would presumably send a RequestVote RPC to the peer.
    """

    def __init__(
        self,
        node_id: str,
        nodes: List[str],
        election_timeout: float = 5.0
    ):
        self.node_id = node_id
        self.nodes = nodes
        self.election_timeout = election_timeout
        self._leader_id: Optional[str] = None
        self._term = 0
        self._is_candidate = False
        self._votes_received = 0

    @property
    def leader_id(self) -> Optional[str]:
        """The currently known leader, or None."""
        return self._leader_id

    @property
    def is_leader(self) -> bool:
        """True if this node is the known leader."""
        return self._leader_id == self.node_id

    async def start_election(self) -> bool:
        """Run one election round in a new term.

        Increments the term, votes for itself and asks every other cluster
        member for its vote concurrently. A vote request that raises is
        counted as a refusal. The cluster is ``nodes`` plus this node,
        deduplicated.

        Returns:
            True (and records this node as leader) on a strict majority.
        """
        self._term += 1
        self._is_candidate = True
        self._votes_received = 1  # a candidate always votes for itself

        members = list(dict.fromkeys([*self.nodes, self.node_id]))
        peers = [member for member in members if member != self.node_id]
        replies = await asyncio.gather(
            *(self._request_vote(peer) for peer in peers),
            return_exceptions=True,
        )
        self._votes_received += sum(
            1 for reply in replies
            if not isinstance(reply, BaseException) and reply
        )

        majority = len(members) // 2 + 1
        self._is_candidate = False
        if self._votes_received >= majority:
            self._leader_id = self.node_id
            return True
        return False

    async def _request_vote(self, node: str) -> bool:
        """Ask *node* for its vote (stub: always granted)."""
        return True

    def update_leader(self, leader_id: str, term: int) -> None:
        """Accept *leader_id* as leader if *term* is not older than ours.

        An equal term is accepted too, so a candidate that lost the election
        of that term learns who won (Raft §5.2).
        """
        if term >= self._term:
            self._term = term
            self._leader_id = leader_id
            self._is_candidate = False


class HeartbeatManager:
    """Sends periodic heartbeats and tracks when peers were last heard from.

    ``_send_heartbeat`` is a no-op stub in this listing.
    """

    def __init__(
        self,
        node_id: str,
        interval: float = 5.0,
        timeout: float = 15.0
    ):
        self.node_id = node_id
        self.interval = interval
        self.timeout = timeout
        # peer id -> time.time() of its last received heartbeat
        self._last_heartbeat: Dict[str, float] = {}
        self._running = False
        # Strong reference to the loop task: asyncio only keeps weak ones,
        # so an unreferenced task can be garbage-collected mid-run.
        self._task: Optional[asyncio.Task] = None

    async def start(self) -> None:
        """Start the background heartbeat loop (no-op if already running)."""
        if self._task is not None and not self._task.done():
            return
        self._running = True
        self._task = asyncio.create_task(self._heartbeat_loop())

    async def stop(self) -> None:
        """Stop the heartbeat loop and wait for it to finish."""
        self._running = False
        task, self._task = self._task, None
        if task is None or task is asyncio.current_task():
            # Called from inside the loop: clearing _running is enough.
            return
        task.cancel()
        # gather(..., return_exceptions=True) absorbs the task's CancelledError
        # while still propagating a cancellation of stop() itself.
        await asyncio.gather(task, return_exceptions=True)

    async def _heartbeat_loop(self) -> None:
        while self._running:
            await self._send_heartbeat()
            await asyncio.sleep(self.interval)

    async def _send_heartbeat(self) -> None:
        """Broadcast a heartbeat to peers (stub)."""
        pass

    def receive_heartbeat(self, node_id: str) -> None:
        """Record that *node_id* was just heard from."""
        self._last_heartbeat[node_id] = time.time()

    def is_node_alive(self, node_id: str) -> bool:
        """True if *node_id* sent a heartbeat within the last ``timeout`` seconds."""
        last = self._last_heartbeat.get(node_id, 0)
        return time.time() - last < self.timeout

    def get_dead_nodes(self) -> List[str]:
        """Return known peers whose last heartbeat is older than ``timeout``."""
        return [
            node_id
            for node_id in self._last_heartbeat
            if not self.is_node_alive(node_id)
        ]

46.2 消息队列

46.2.1 消息队列实现

python
from dataclasses import dataclass
from typing import Any, Optional, Callable
from enum import Enum
import asyncio
import json
from datetime import datetime


class MessageStatus(Enum):
    """Lifecycle state of a queued message.

    PENDING    -- waiting in a queue ("pending")
    PROCESSING -- handed to a consumer ("processing")
    COMPLETED  -- acknowledged ("completed")
    FAILED     -- retries exhausted ("failed")
    """

    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"


@dataclass
class Message:
    """A queued message plus its delivery bookkeeping."""

    id: str
    queue: str
    payload: Any
    status: MessageStatus = MessageStatus.PENDING
    # Naive UTC timestamps; created_at is filled in by __post_init__.
    created_at: datetime = None
    processed_at: datetime = None
    retry_count: int = 0
    max_retries: int = 3
    error_message: str = None

    def __post_init__(self):
        # Keep a caller-supplied creation time; otherwise stamp "now".
        if self.created_at is not None:
            return
        self.created_at = datetime.utcnow()


class MessageQueue:
    """Bounded in-process queue with retry and a dead-letter list.

    Args:
        name: Queue name.
        max_size: Capacity of the underlying ``asyncio.Queue``; <= 0 means unbounded.
        publish_timeout: Seconds ``publish`` waits for free space before giving up.
    """

    def __init__(
        self,
        name: str,
        max_size: int = 10000,
        publish_timeout: float = 5.0
    ):
        self.name = name
        self.max_size = max_size
        self.publish_timeout = publish_timeout
        self._queue: asyncio.Queue = asyncio.Queue(maxsize=max_size)
        self._dead_letter_queue: List[Message] = []
        # Handlers registered via MessageBroker.subscribe; not read by this class.
        self._consumers: List[Callable] = []

    async def publish(self, message: Message) -> bool:
        """Enqueue *message*; return False if the queue stayed full too long."""
        try:
            await asyncio.wait_for(
                self._queue.put(message),
                timeout=self.publish_timeout
            )
        except asyncio.TimeoutError:
            return False
        return True

    async def consume(self) -> Optional[Message]:
        """Return the next message, or None if none arrives within 1 second."""
        try:
            return await asyncio.wait_for(
                self._queue.get(),
                timeout=1.0
            )
        except asyncio.TimeoutError:
            return None

    async def ack(self, message: Message) -> None:
        """Mark *message* as successfully processed."""
        message.status = MessageStatus.COMPLETED
        message.processed_at = datetime.utcnow()

    async def nack(self, message: Message, error: str = None) -> None:
        """Record a processing failure and either re-queue or dead-letter.

        The message is re-queued while ``retry_count < max_retries``. If it
        is out of retries, or the re-queue fails because the queue is full,
        it is marked FAILED and moved to the dead-letter list instead of
        being dropped.
        """
        message.retry_count += 1
        message.error_message = error

        if message.retry_count < message.max_retries:
            message.status = MessageStatus.PENDING
            if await self.publish(message):
                return

        message.status = MessageStatus.FAILED
        self._dead_letter_queue.append(message)

    def get_queue_size(self) -> int:
        """Number of messages currently waiting."""
        return self._queue.qsize()

    def get_dead_letter_count(self) -> int:
        """Number of messages that were dead-lettered."""
        return len(self._dead_letter_queue)


class MessageBroker:
    """Registry of named MessageQueues with background consumers."""

    def __init__(self):
        self._queues: Dict[str, MessageQueue] = {}
        self._exchanges: Dict[str, Dict] = {}
        # Strong references to consumer tasks; asyncio keeps only weak ones,
        # so unreferenced tasks may be garbage-collected while running.
        self._consumer_tasks: set = set()

    def create_queue(self, name: str, max_size: int = 10000) -> MessageQueue:
        """Return queue *name*, creating it with *max_size* if it does not exist."""
        if name not in self._queues:
            self._queues[name] = MessageQueue(name, max_size)
        return self._queues[name]

    def get_queue(self, name: str) -> Optional[MessageQueue]:
        """Return queue *name* or None."""
        return self._queues.get(name)

    async def publish(
        self,
        queue_name: str,
        payload: Any,
        max_retries: int = 3
    ) -> bool:
        """Wrap *payload* in a new Message and publish it (creating the queue if needed)."""
        import uuid
        queue = self.create_queue(queue_name)
        message = Message(
            id=str(uuid.uuid4()),
            queue=queue_name,
            payload=payload,
            max_retries=max_retries
        )
        return await queue.publish(message)

    async def subscribe(
        self,
        queue_name: str,
        handler: Callable
    ) -> None:
        """Start a background consumer feeding *queue_name*'s payloads to *handler*.

        *handler* may be a plain function or a coroutine function.
        """
        queue = self.create_queue(queue_name)
        queue._consumers.append(handler)
        task = asyncio.create_task(self._consume_loop(queue, handler))
        self._consumer_tasks.add(task)
        task.add_done_callback(self._consumer_tasks.discard)

    async def _consume_loop(
        self,
        queue: MessageQueue,
        handler: Callable
    ) -> None:
        """Forever: take a message, run *handler*, then ack on success or nack on error."""
        import inspect

        while True:
            message = await queue.consume()
            if message is None:
                continue
            try:
                message.status = MessageStatus.PROCESSING
                outcome = handler(message.payload)
                # Support sync handlers too; previously they ran and then failed on
                # `await`, which nacked (and re-ran) already-processed messages.
                if inspect.isawaitable(outcome):
                    await outcome
            except Exception as e:
                await queue.nack(message, str(e))
            else:
                await queue.ack(message)


class PubSub:
    """In-process topic-based publish/subscribe."""

    def __init__(self):
        self._subscribers: Dict[str, List[Callable]] = {}

    def subscribe(self, topic: str, handler: Callable) -> None:
        """Register *handler* (sync or async callable) for *topic*."""
        self._subscribers.setdefault(topic, []).append(handler)

    def unsubscribe(self, topic: str, handler: Callable) -> None:
        """Remove *handler* from *topic*; unknown topics/handlers are ignored."""
        handlers = self._subscribers.get(topic)
        if not handlers or handler not in handlers:
            return
        handlers.remove(handler)
        if not handlers:
            del self._subscribers[topic]

    async def publish(self, topic: str, message: Any) -> int:
        """Deliver *message* to every handler of *topic*.

        Handlers run sequentially; awaitable results are awaited. A failing
        handler is logged and skipped without affecting the others.

        Returns:
            The number of handlers that completed without raising.
        """
        import inspect
        import logging

        # Iterate a snapshot: a handler may (un)subscribe while we deliver.
        handlers = list(self._subscribers.get(topic, ()))
        delivered = 0
        for handler in handlers:
            try:
                result = handler(message)
                if inspect.isawaitable(result):
                    await result
            except Exception:
                logging.getLogger(__name__).exception(
                    "PubSub handler %r failed for topic %r", handler, topic
                )
                continue
            delivered += 1

        return delivered


class DelayedMessageQueue:
    """Holds payloads until their delay elapses, then publishes them to a broker."""

    def __init__(self):
        # Min-heap of (execute_at, message_id, queue_name, payload). message_id
        # is a unique uuid4 string, so ties on execute_at never fall through to
        # comparing payloads.
        self._messages: List[tuple] = []

    async def schedule(
        self,
        payload: Any,
        delay_seconds: float,
        queue_name: str = "default"
    ) -> str:
        """Schedule *payload* for *queue_name* after *delay_seconds*; return its id."""
        import heapq
        import uuid
        message_id = str(uuid.uuid4())
        # Monotonic clock: delays must not be affected by wall-clock changes.
        execute_at = time.monotonic() + delay_seconds
        heapq.heappush(self._messages, (execute_at, message_id, queue_name, payload))
        return message_id

    async def process_due_messages(
        self,
        broker: "MessageBroker"
    ) -> int:
        """Publish every message whose time has come, earliest first.

        Returns:
            The number of messages handed to ``broker.publish``.
        """
        import heapq
        now = time.monotonic()
        due = []
        while self._messages and self._messages[0][0] <= now:
            due.append(heapq.heappop(self._messages))

        for _, _msg_id, queue_name, payload in due:
            await broker.publish(queue_name, payload)

        return len(due)

46.3 缓存系统

46.3.1 缓存策略

python
from typing import Any, Optional, Dict, List
from dataclasses import dataclass
import time
import hashlib
import asyncio


@dataclass
class CacheEntry:
    """A cached value with its creation time, optional TTL and hit counter."""

    key: str
    value: Any
    # time.time() at insertion.
    created_at: float
    # Lifetime in seconds; None means the entry never expires.
    ttl: Optional[float] = None
    hits: int = 0

    def is_expired(self) -> bool:
        """True once more than ``ttl`` seconds have passed since ``created_at``."""
        return self.ttl is not None and time.time() - self.created_at > self.ttl


class CacheStrategy(Enum):
    """Eviction policy used by InMemoryCache when it is full.

    LRU  -- least recently used ("lru")
    LFU  -- least frequently used ("lfu")
    FIFO -- first in, first out ("fifo")
    """

    LRU = "lru"
    LFU = "lfu"
    FIFO = "fifo"


class InMemoryCache:
    """Coroutine-safe in-process cache with TTL and a bounded size.

    When full, ``set`` of a new key evicts one entry according to ``strategy``:

    * LRU  - the key least recently read or written;
    * LFU  - the key with the fewest read hits;
    * FIFO - the key inserted (or last overwritten) earliest; reads do not
      change its position.

    An explicit ``ttl`` (including 0) overrides ``default_ttl``; ``ttl=None``
    uses the default.
    """

    def __init__(
        self,
        max_size: int = 1000,
        default_ttl: Optional[float] = None,
        strategy: CacheStrategy = CacheStrategy.LRU
    ):
        from collections import OrderedDict

        self.max_size = max_size
        self.default_ttl = default_ttl
        self.strategy = strategy
        self._cache: Dict[str, CacheEntry] = {}
        # key -> None, ordered oldest first (by access for LRU, by insertion
        # for FIFO). OrderedDict gives O(1) move/remove instead of list scans.
        self._access_order = OrderedDict()
        self._lock = asyncio.Lock()

    async def get(self, key: str) -> Optional[Any]:
        """Return the cached value, or None if missing or expired (expired entries are dropped)."""
        async with self._lock:
            entry = self._cache.get(key)

            if entry is None:
                return None

            if entry.is_expired():
                await self._remove(key)
                return None

            entry.hits += 1
            self._update_access_order(key)

            return entry.value

    async def set(
        self,
        key: str,
        value: Any,
        ttl: Optional[float] = None
    ) -> None:
        """Store *value* under *key*, evicting one entry first if the cache is full."""
        async with self._lock:
            if len(self._cache) >= self.max_size and key not in self._cache:
                await self._evict()

            self._cache[key] = CacheEntry(
                key=key,
                value=value,
                created_at=time.time(),
                # `is not None` so an explicit ttl of 0 is not replaced by the default.
                ttl=ttl if ttl is not None else self.default_ttl
            )

            # A write counts as a (re)insertion for every strategy.
            self._access_order[key] = None
            self._access_order.move_to_end(key)

    async def delete(self, key: str) -> bool:
        """Remove *key*; return True if it was present."""
        async with self._lock:
            return await self._remove(key)

    async def _remove(self, key: str) -> bool:
        """Remove *key* from both structures (caller holds the lock)."""
        if key not in self._cache:
            return False
        del self._cache[key]
        self._access_order.pop(key, None)
        return True

    async def _evict(self) -> None:
        """Evict one entry according to the strategy (caller holds the lock)."""
        if not self._cache:
            return

        if self.strategy == CacheStrategy.LFU:
            key_to_evict = min(
                self._cache.keys(),
                key=lambda k: self._cache[k].hits
            )
        else:
            # LRU and FIFO both evict the oldest key in _access_order; they
            # differ only in whether reads refresh a key's position.
            key_to_evict = next(iter(self._access_order))

        await self._remove(key_to_evict)

    def _update_access_order(self, key: str) -> None:
        """Record a read of *key*; a no-op under FIFO."""
        if self.strategy == CacheStrategy.FIFO:
            return
        self._access_order[key] = None
        self._access_order.move_to_end(key)

    async def clear(self) -> None:
        """Drop every entry."""
        async with self._lock:
            self._cache.clear()
            self._access_order.clear()

    def get_stats(self) -> Dict:
        """Return size, capacity, total read hits and current keys."""
        total_hits = sum(e.hits for e in self._cache.values())
        return {
            "size": len(self._cache),
            "max_size": self.max_size,
            "total_hits": total_hits,
            "keys": list(self._cache.keys())
        }


class CacheAside:
    """Cache-aside (lazy loading) reader.

    Reads hit the cache first. On a miss the async ``data_loader(key)`` is
    called, and a non-None result is written back to the cache.
    """

    def __init__(
        self,
        cache: InMemoryCache,
        data_loader: Callable
    ):
        self.cache = cache
        self.data_loader = data_loader

    async def get(self, key: str) -> Any:
        """Return the value for *key*, loading and caching it on a miss."""
        cached = await self.cache.get(key)
        if cached is None:
            loaded = await self.data_loader(key)
            # None is never cached, so a missing key is reloaded every time.
            if loaded is not None:
                await self.cache.set(key, loaded)
            return loaded
        return cached

    async def invalidate(self, key: str) -> None:
        """Drop *key* from the cache."""
        await self.cache.delete(key)

    async def refresh(self, key: str) -> Any:
        """Force a reload of *key* from the loader and return the fresh value."""
        await self.cache.delete(key)
        fresh = await self.get(key)
        return fresh


class WriteThrough:
    """Write-through cache: writes go to the store first, then the cache.

    Args:
        cache: The cache to populate.
        data_writer: ``async (key, value) -> bool`` persisting the value.
        data_loader: Optional ``async (key) -> value`` used by ``get`` on a
            cache miss. Without it, values evicted or expired from the cache
            are unreachable through ``get``.
    """

    def __init__(
        self,
        cache: InMemoryCache,
        data_writer: Callable,
        data_loader: Optional[Callable] = None
    ):
        self.cache = cache
        self.data_writer = data_writer
        self.data_loader = data_loader

    async def set(self, key: str, value: Any) -> bool:
        """Persist *value*; cache it only if the write succeeded. Returns the writer's result."""
        success = await self.data_writer(key, value)

        if success:
            await self.cache.set(key, value)

        return success

    async def get(self, key: str) -> Any:
        """Return the cached value, falling back to ``data_loader`` on a miss if configured."""
        value = await self.cache.get(key)
        if value is None and self.data_loader is not None:
            value = await self.data_loader(key)
            if value is not None:
                await self.cache.set(key, value)
        return value


class CachePenetrationProtection:
    """Guards the cache against lookups of keys that do not exist.

    Known keys are registered in a Bloom filter of ``bloom_filter_size``
    bits with ``num_hashes`` hash functions. ``get`` returns None without
    touching the cache for keys the filter has definitely never seen. The
    filter may report false positives but never false negatives, and its
    memory use is fixed regardless of how many keys are added.
    """

    def __init__(
        self,
        cache: InMemoryCache,
        bloom_filter_size: int = 10000,
        num_hashes: int = 3
    ):
        self.cache = cache
        self._bloom_bits = max(1, bloom_filter_size)
        self._num_hashes = max(1, num_hashes)
        # Packed bit array backing the Bloom filter.
        self._bloom_filter = bytearray((self._bloom_bits + 7) // 8)
        self._null_cache_ttl = 60

    def _hash_key(self, key: str) -> int:
        """Return the 128-bit MD5 digest of *key* as an int."""
        return int(hashlib.md5(key.encode()).hexdigest(), 16)

    def _bit_positions(self, key: str) -> List[int]:
        """Return the filter bit indexes for *key*.

        Uses double hashing: k indexes h1 + i*h2 are derived from the two
        64-bit halves of the MD5 digest (h2 forced odd so it is non-zero).
        """
        digest = self._hash_key(key)
        h1 = digest & 0xFFFFFFFFFFFFFFFF
        h2 = (digest >> 64) | 1
        return [(h1 + i * h2) % self._bloom_bits for i in range(self._num_hashes)]

    def add_to_filter(self, key: str) -> None:
        """Mark *key* as possibly existing."""
        for pos in self._bit_positions(key):
            self._bloom_filter[pos >> 3] |= 1 << (pos & 7)

    def might_exist(self, key: str) -> bool:
        """False means *key* was definitely never added; True means it may have been."""
        return all(
            self._bloom_filter[pos >> 3] & (1 << (pos & 7))
            for pos in self._bit_positions(key)
        )

    async def get(self, key: str) -> Optional[Any]:
        """Return the cached value, short-circuiting to None for unknown keys."""
        if not self.might_exist(key):
            return None

        return await self.cache.get(key)

    async def set_null(self, key: str) -> None:
        """Cache a None placeholder for *key* for ``_null_cache_ttl`` seconds.

        NOTE(review): InMemoryCache.get returns None both for a miss and for a
        cached None, so callers of get() cannot tell the two apart.
        """
        await self.cache.set(key, None, ttl=self._null_cache_ttl)


class CacheAvalancheProtection:
    """Mitigates cache avalanches by spreading entry expiry times.

    Every entry gets ``base_ttl`` plus a random jitter of up to 10% of
    ``base_ttl``, so keys written together do not all expire together.
    """

    def __init__(
        self,
        cache: InMemoryCache,
        base_ttl: float = 300
    ):
        self.cache = cache
        self.base_ttl = base_ttl

    async def set(self, key: str, value: Any) -> None:
        """Cache *value* with a jittered TTL in [base_ttl, 1.1 * base_ttl]."""
        import random
        jitter = random.uniform(0, self.base_ttl * 0.1)
        ttl = self.base_ttl + jitter
        await self.cache.set(key, value, ttl=ttl)

    async def warm_up(self, keys: List[str], loader: Callable) -> None:
        """Preload *keys* in random order via ``await loader(key)``.

        The caller's *keys* sequence is not modified (any iterable is accepted).
        None results are not cached.
        """
        import random
        # Shuffle a private copy; shuffling `keys` itself reordered the caller's list.
        shuffled = list(keys)
        random.shuffle(shuffled)

        for key in shuffled:
            value = await loader(key)
            if value is not None:
                await self.set(key, value)

46.4 负载均衡

46.4.1 负载均衡算法

python
from typing import List, Dict, Optional
from dataclasses import dataclass
import random
import time


@dataclass
class BackendServer:
    """A backend instance as seen by the load balancers."""

    id: str
    host: str
    port: int
    # Relative share used by WeightedRoundRobinLoadBalancer.
    weight: int = 1
    # Active connection count read by LeastConnectionsLoadBalancer; callers
    # are expected to maintain it (not updated in this listing).
    connections: int = 0
    # Toggled by LoadBalancer.mark_healthy / mark_unhealthy.
    healthy: bool = True
    response_time: float = 0.0


class LoadBalancer:
    """Base load balancer: always picks the first healthy backend.

    Subclasses override ``select`` with a real policy.
    """

    def __init__(self, servers: List[BackendServer]):
        self.servers = servers
        # Rotating cursor used by round-robin subclasses.
        self._current_index = 0

    def select(self) -> Optional[BackendServer]:
        """Return the first healthy server, or None if none is healthy."""
        return next((server for server in self.servers if server.healthy), None)

    def _set_health(self, server_id: str, healthy: bool) -> None:
        """Set the health flag on every server whose id equals *server_id*."""
        for server in self.servers:
            if server.id == server_id:
                server.healthy = healthy

    def mark_unhealthy(self, server_id: str) -> None:
        """Exclude *server_id* from selection."""
        self._set_health(server_id, False)

    def mark_healthy(self, server_id: str) -> None:
        """Make *server_id* eligible for selection again."""
        self._set_health(server_id, True)


class RoundRobinLoadBalancer(LoadBalancer):
    """Cycles through healthy servers in list order."""

    def select(self) -> Optional[BackendServer]:
        """Return the next healthy server in rotation, or None if none is healthy."""
        candidates = [server for server in self.servers if server.healthy]
        if not candidates:
            return None

        position = self._current_index
        self._current_index = position + 1
        # Modulo over the *current* healthy set, so the rotation adapts as
        # servers go up or down.
        return candidates[position % len(candidates)]


class WeightedRoundRobinLoadBalancer(LoadBalancer):
    """Smooth weighted round-robin, as used by nginx.

    Each pick adds every server's weight to its running "current weight",
    selects the largest, then subtracts the total weight from the winner.
    This gives a deterministic, evenly interleaved sequence in which each
    server appears in proportion to its weight (e.g. weights 5/1/1 yield
    a a b a c a a). Servers with weight <= 0 are never chosen unless no
    healthy server has a positive weight; in that case one is picked
    uniformly at random.
    """

    def __init__(self, servers: List[BackendServer]):
        super().__init__(servers)
        # server id -> running current weight of the smooth WRR algorithm
        self._current_weights: Dict[str, int] = {}

    def select(self) -> Optional[BackendServer]:
        """Return the next server in the weighted rotation, or None if none is healthy."""
        healthy_servers = [s for s in self.servers if s.healthy]
        if not healthy_servers:
            return None

        weighted = [s for s in healthy_servers if s.weight > 0]
        if not weighted:
            return random.choice(healthy_servers)

        # Keep state only for servers currently in the pool, so a server that
        # comes back after being unhealthy starts fresh.
        self._current_weights = {
            s.id: self._current_weights.get(s.id, 0) for s in weighted
        }

        total_weight = 0
        best: Optional[BackendServer] = None
        for server in weighted:
            self._current_weights[server.id] += server.weight
            total_weight += server.weight
            if best is None or self._current_weights[server.id] > self._current_weights[best.id]:
                best = server

        self._current_weights[best.id] -= total_weight
        return best


class LeastConnectionsLoadBalancer(LoadBalancer):
    """Picks the healthy server with the fewest active connections.

    Ties go to the server listed first. ``connections`` must be maintained
    by the caller.
    """

    def select(self) -> Optional[BackendServer]:
        """Return the least-loaded healthy server, or None if none is healthy."""
        candidates = [server for server in self.servers if server.healthy]
        return min(candidates, key=lambda srv: srv.connections) if candidates else None


class IPHashLoadBalancer(LoadBalancer):
    """Sticky routing: the same client IP maps to the same healthy server.

    The mapping is MD5(ip) modulo the number of healthy servers, so it
    reshuffles whenever the healthy set changes.
    """

    def select(self, client_ip: str = None) -> Optional[BackendServer]:
        """Return the server for *client_ip* (random if no IP is given), or None."""
        candidates = [server for server in self.servers if server.healthy]
        if not candidates:
            return None
        if client_ip is None:
            return random.choice(candidates)

        digest = hashlib.md5(client_ip.encode()).hexdigest()
        return candidates[int(digest, 16) % len(candidates)]


class ConsistentHashLoadBalancer(LoadBalancer):
    """Routes a key to a backend using a consistent-hash ring.

    Each server owns ``virtual_nodes`` ring positions (MD5 of
    ``"{server.id}#{i}"``). ``select(key)`` walks clockwise from the key's
    hash to the first position owned by a healthy server. Keys of an
    unhealthy server therefore spill over only to its ring successors.
    """

    def __init__(
        self,
        servers: List[BackendServer],
        virtual_nodes: int = 150
    ):
        super().__init__(servers)
        self.virtual_nodes = virtual_nodes
        self._ring: Dict[int, BackendServer] = {}
        self._sorted_keys: List[int] = []
        self._build_ring()

    @staticmethod
    def _hash(value: str) -> int:
        """Return the 128-bit MD5 digest of *value* as an int."""
        return int(hashlib.md5(value.encode()).hexdigest(), 16)

    def _build_ring(self) -> None:
        """(Re)build the ring from ``self.servers``; safe to call again."""
        self._ring = {}
        for server in self.servers:
            for i in range(self.virtual_nodes):
                self._ring[self._hash(f"{server.id}#{i}")] = server
        # Sorted once here so select() can binary-search instead of sorting
        # the whole ring on every request.
        self._sorted_keys = sorted(self._ring)

    def select(self, key: str = None) -> Optional[BackendServer]:
        """Return the healthy server owning *key* (random if no key), or None."""
        from bisect import bisect_left

        healthy_servers = [s for s in self.servers if s.healthy]
        if not healthy_servers:
            return None

        if key is None:
            return random.choice(healthy_servers)

        ring_size = len(self._sorted_keys)
        start = bisect_left(self._sorted_keys, self._hash(key))
        for offset in range(ring_size):
            server = self._ring[self._sorted_keys[(start + offset) % ring_size]]
            if server.healthy:
                return server

        return None


class HealthChecker:
    """Periodically probes backends over HTTP and updates their health flag.

    A server is marked unhealthy after ``unhealthy_threshold`` consecutive
    failed probes, and healthy again after ``healthy_threshold`` consecutive
    successful ones. All probes of a round run concurrently.
    """

    def __init__(
        self,
        load_balancer: LoadBalancer,
        check_interval: float = 10.0,
        timeout: float = 5.0,
        unhealthy_threshold: int = 3,
        healthy_threshold: int = 2
    ):
        self.load_balancer = load_balancer
        self.check_interval = check_interval
        self.timeout = timeout
        self.unhealthy_threshold = unhealthy_threshold
        self.healthy_threshold = healthy_threshold
        # server id -> consecutive failed / successful probes
        self._failure_counts: Dict[str, int] = {}
        self._success_counts: Dict[str, int] = {}
        self._running = False

    async def start(self) -> None:
        """Run a check round every ``check_interval`` seconds until stop() is called."""
        self._running = True
        while self._running:
            await self._check_all()
            if self._running:
                await asyncio.sleep(self.check_interval)

    async def stop(self) -> None:
        """Make start() return after the round or sleep currently in progress."""
        self._running = False

    async def _check_all(self) -> None:
        """Probe every server once and apply the thresholds."""
        servers = list(self.load_balancer.servers)
        # Concurrent probes: sequential ones could take len(servers) * timeout
        # per round, far longer than check_interval.
        results = await asyncio.gather(*(self._check_server(s) for s in servers))
        for server, is_healthy in zip(servers, results):
            self._record_result(server.id, is_healthy)

    def _record_result(self, server_id: str, is_healthy: bool) -> None:
        """Update the consecutive counters for *server_id* and flip its flag when a threshold is reached."""
        if is_healthy:
            self._failure_counts[server_id] = 0
            self._success_counts[server_id] = self._success_counts.get(server_id, 0) + 1

            if self._success_counts[server_id] >= self.healthy_threshold:
                self.load_balancer.mark_healthy(server_id)
        else:
            self._success_counts[server_id] = 0
            self._failure_counts[server_id] = self._failure_counts.get(server_id, 0) + 1

            if self._failure_counts[server_id] >= self.unhealthy_threshold:
                self.load_balancer.mark_unhealthy(server_id)

    async def _check_server(self, server: BackendServer) -> bool:
        """Return True iff GET http://host:port/health answers 200 within ``timeout``."""
        # Imported outside the try: a missing aiohttp install must surface as an
        # ImportError instead of silently marking every backend unhealthy.
        import aiohttp
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"http://{server.host}:{server.port}/health",
                    timeout=aiohttp.ClientTimeout(total=self.timeout)
                ) as response:
                    return response.status == 200
        except Exception:
            # Any connection error, timeout or bad response counts as a failed probe.
            return False

46.5 知识图谱

46.5.1 分布式系统架构

┌─────────────────────────────────────────────────────────────────────┐
│                      分布式系统架构全景图                             │
├─────────────────────────────────────────────────────────────────────┤
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                      客户端层 (Client)                        │   │
│  │  ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐       │   │
│  │  │ 负载均衡  │ │ CDN      │ │ DNS      │ │ API网关  │       │   │
│  │  │ Nginx   │ │ CloudFlare│ │ Route53 │ │ Kong    │       │   │
│  │  └──────────┘ └──────────┘ └──────────┘ └──────────┘       │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                                │                                    │
│  ┌─────────────────────────────┴───────────────────────────────┐   │
│  │                      服务层 (Services)                        │   │
│  │  ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐       │   │
│  │  │ 服务发现  │ │ 配置中心  │ │ 服务网格  │ │ 容器编排  │       │   │
│  │  │ Consul  │ │ Apollo  │ │ Istio   │ │ K8s     │       │   │
│  │  └──────────┘ └──────────┘ └──────────┘ └──────────┘       │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                                │                                    │
│  ┌─────────────────────────────┴───────────────────────────────┐   │
│  │                      数据层 (Data)                            │   │
│  │  ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐       │   │
│  │  │ 分布式DB │ │ 缓存     │ │ 消息队列  │ │ 搜索引擎  │       │   │
│  │  │ TiDB    │ │ Redis   │ │ Kafka   │ │ ES      │       │   │
│  │  └──────────┘ └──────────┘ └──────────┘ └──────────┘       │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                                │                                    │
│  ┌─────────────────────────────┴───────────────────────────────┐   │
│  │                      协调层 (Coordination)                    │   │
│  │  ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐       │   │
│  │  │ 分布式锁  │ │ 选举     │ │ 事务     │ │ 一致性   │       │   │
│  │  │ Redis   │ │ Raft    │ │ 2PC/Saga│ │ Paxos   │       │   │
│  │  └──────────┘ └──────────┘ └──────────┘ └──────────┘       │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                                │                                    │
│  ┌─────────────────────────────┴───────────────────────────────┐   │
│  │                      可观测性 (Observability)                 │   │
│  │  ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐       │   │
│  │  │ 监控     │ │ 日志     │ │ 追踪     │ │ 告警     │       │   │
│  │  │Prometheus│ │ ELK     │ │ Jaeger  │ │AlertMngr│       │   │
│  │  └──────────┘ └──────────┘ └──────────┘ └──────────┘       │   │
│  └─────────────────────────────────────────────────────────────┘   │
└─────────────────────────────────────────────────────────────────────┘

46.5.2 CAP定理与BASE理论

┌─────────────────────────────────────────────────────────────────────┐
│                      CAP定理图解                                     │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│                          Consistency                                │
│                          (一致性)                                    │
│                             /\                                      │
│                            /  \                                     │
│                           /    \                                    │
│                          /  CA  \      ┌────────────────────────┐  │
│                         /        \     │ CA: 传统关系数据库      │  │
│                        /──────────\    │    PostgreSQL, MySQL    │  │
│                       /            \   └────────────────────────┘  │
│                      /      CP      \                              │
│                     /                \   ┌────────────────────────┐ │
│                    /──────────────────\  │ CP: 分布式数据库       │ │
│                   /                    \ │    MongoDB, HBase      │ │
│                  /        AP           \└────────────────────────┘ │
│                 /                        \                          │
│                /──────────────────────────\                         │
│               Partition Tolerance         Availability             │
│               (分区容错)                   (可用性)                  │
│                                                                     │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │  AP系统: 高可用、最终一致 (Cassandra, DynamoDB, CouchDB)       │  │
│  │  CP系统: 强一致、可能不可用 (Zookeeper, etcd, HBase)           │  │
│  │  CA系统: 无分区容错 (单机数据库)                               │  │
│  └──────────────────────────────────────────────────────────────┘  │
│                                                                     │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │                      BASE理论                                 │  │
│  │  Basically Available: 基本可用                                │  │
│  │  Soft State: 软状态(中间状态)                                │  │
│  │  Eventually Consistent: 最终一致                              │  │
│  └──────────────────────────────────────────────────────────────┘  │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

46.6 技术选型指南

46.6.1 分布式协调服务选型

服务一致性模型性能功能丰富度学习曲线推荐指数
ZookeeperCP★★★★☆
etcdCP★★★★★
ConsulCP极高★★★★★
RedisAP极高★★★★☆

46.6.2 消息队列选型

消息队列吞吐量延迟持久化顺序保证推荐场景
Kafka极高分区内有序日志、大数据
RabbitMQ企业应用
RocketMQ极高金融交易
Redis Stream极低可选实时消息
Pulsar极高云原生应用

46.6.3 分布式缓存选型

缓存方案性能持久化集群支持数据结构推荐指数
Redis Cluster极高丰富★★★★★
Memcached极高简单★★★☆☆
Hazelcast丰富★★★★☆
Aerospike极高中等★★★★☆

46.7 常见问题与解决方案

46.7.1 分布式锁实现

python
import time
import uuid
import threading
from typing import Optional, Callable
from dataclasses import dataclass

@dataclass
class LockOptions:
    """Timing options for distributed locks; all values are in seconds.

    Attributes:
        lock_timeout: lease length for backends whose locks expire on their own
            (used as the Redis key TTL). Must be positive.
        wait_timeout: total time ``acquire`` may spend waiting for the lock.
            0 means "try once". Must not be negative.
        retry_interval: pause between acquisition attempts. Must not be negative
            (``time.sleep`` rejects negative values).

    Raises:
        ValueError: if any value is out of range.
    """
    lock_timeout: float = 30.0
    wait_timeout: float = 10.0
    retry_interval: float = 0.1

    def __post_init__(self):
        # Fail at construction rather than deep inside acquire().
        if self.lock_timeout <= 0:
            raise ValueError(f"lock_timeout must be positive, got {self.lock_timeout}")
        if self.wait_timeout < 0:
            raise ValueError(f"wait_timeout must be >= 0, got {self.wait_timeout}")
        if self.retry_interval < 0:
            raise ValueError(f"retry_interval must be >= 0, got {self.retry_interval}")


class DistributedLock:
    """Abstract interface shared by every distributed lock backend.

    Subclasses provide ``acquire``, ``release`` and ``extend``. An instance can
    be used in a ``with`` statement: entering acquires the lock (raising
    ``TimeoutError`` when that fails) and leaving always releases it without
    swallowing any exception raised in the body.
    """

    def __init__(self, name: str, options: LockOptions = None):
        self.name = name
        self.options = options or LockOptions()
        # Random token that identifies this holder to the backend.
        self._lock_value = f"{uuid.uuid4()}"
        # True while this instance believes it holds the lock.
        self._acquired = False

    def acquire(self) -> bool:
        """Take the lock; return True on success, False otherwise."""
        raise NotImplementedError

    def release(self) -> bool:
        """Give the lock back; return True if it was actually released."""
        raise NotImplementedError

    def extend(self, additional_time: float) -> bool:
        """Prolong the lock by backend-specific means; return True on success."""
        raise NotImplementedError

    def __enter__(self):
        got_lock = self.acquire()
        if got_lock:
            return self
        raise TimeoutError(f"Failed to acquire lock: {self.name}")

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.release()
        # A falsy return lets any exception from the with-body propagate.
        return False


class RedisDistributedLock(DistributedLock):
    """Redis-backed distributed lock built on ``SET key token NX PX ttl``.

    Ownership is tracked by this instance's random token (``self._lock_value``).
    Release and extend run as Lua scripts that act only while the key still
    holds that token. A client whose lease has already expired therefore can
    neither delete nor prolong a lock now held by someone else.
    """

    # Atomic compare-and-delete: remove the key only if it still holds our token.
    _RELEASE_SCRIPT = """
    if redis.call("get", KEYS[1]) == ARGV[1] then
        return redis.call("del", KEYS[1])
    else
        return 0
    end
    """

    # Atomic compare-and-expire: reset the TTL (milliseconds) only if the key
    # still holds our token.
    _EXTEND_SCRIPT = """
    if redis.call("get", KEYS[1]) == ARGV[1] then
        return redis.call("pexpire", KEYS[1], ARGV[2])
    else
        return 0
    end
    """

    def __init__(self, redis_client, name: str, options: LockOptions = None):
        super().__init__(name, options)
        self.redis = redis_client

    @staticmethod
    def _to_millis(seconds: float) -> int:
        """Convert seconds to a positive whole number of milliseconds.

        Millisecond precision keeps sub-second timeouts meaningful. The old
        ``int(seconds)`` with EX turned 0.5 s into 0, which Redis rejects.
        """
        return max(1, int(seconds * 1000))

    def acquire(self) -> bool:
        """Poll for the lock until acquired or ``options.wait_timeout`` elapses.

        Returns:
            True if the key was set by this instance, False on timeout.
        """
        ttl_ms = self._to_millis(self.options.lock_timeout)
        # Monotonic clock: the wait budget must not jump with wall-clock changes.
        deadline = time.monotonic() + self.options.wait_timeout

        while True:
            acquired = self.redis.set(
                self.name,
                self._lock_value,
                nx=True,
                px=ttl_ms,
            )
            if acquired:
                self._acquired = True
                return True

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False
            # Never sleep past the deadline.
            time.sleep(min(self.options.retry_interval, remaining))

    def release(self) -> bool:
        """Release the lock if this instance still owns it.

        Returns:
            True if the key was deleted. False if this instance never acquired
            the lock, or the key had expired or belongs to another holder.
        """
        if not self._acquired:
            return False

        result = self.redis.eval(self._RELEASE_SCRIPT, 1, self.name, self._lock_value)
        self._acquired = False
        return bool(result)

    def extend(self, additional_time: float) -> bool:
        """Reset the lock's remaining lifetime to *additional_time* seconds.

        This sets a new TTL; it does not add to the current one. That makes it
        suitable for periodic lease renewal.

        Returns:
            True if the TTL was updated. False if this instance does not hold
            the lock, or the key no longer holds our token.
        """
        if not self._acquired:
            return False

        result = self.redis.eval(
            self._EXTEND_SCRIPT,
            1,
            self.name,
            self._lock_value,
            self._to_millis(additional_time),
        )
        return bool(result)


class ZookeeperDistributedLock(DistributedLock):
    """ZooKeeper-backed distributed lock (ephemeral sequential node recipe).

    Each contender creates one ephemeral, sequential child under
    ``/locks/<name>``. The child with the lowest sequence number holds the
    lock. Every other contender watches only its immediate predecessor, so a
    release wakes a single waiter. Because the node is ephemeral, ZooKeeper
    removes it, and so releases the lock, when the client's session ends.
    """

    def __init__(self, zk_client, name: str, options: LockOptions = None):
        super().__init__(name, options)
        self.zk = zk_client
        self._lock_path = f"/locks/{name}"
        # Full path of our queue node while we are queued or holding the lock.
        self._node_path = None

    def acquire(self) -> bool:
        """Queue for the lock and wait up to ``options.wait_timeout`` seconds.

        Returns:
            True once our node is first in the queue. False if the wait budget
            runs out; our node is then deleted so it cannot block later
            contenders. Errors from the ZooKeeper client (for example, our node
            disappearing) drop our node and requeue every
            ``options.retry_interval`` seconds until the budget is spent.
        """
        deadline = time.monotonic() + self.options.wait_timeout

        while True:
            try:
                if self._wait_for_turn(deadline):
                    self._acquired = True
                    return True
                self._cleanup()
                return False
            except Exception:
                # Start over with a fresh node; never leave a stale one queued.
                self._cleanup()
                if time.monotonic() >= deadline:
                    return False

            time.sleep(self.options.retry_interval)

    def _wait_for_turn(self, deadline: float) -> bool:
        """Create our queue node (once) and block until it is first or *deadline* passes."""
        if self._node_path is None:
            self._ensure_path(self._lock_path)
            self._node_path = self.zk.create(
                f"{self._lock_path}/lock-",
                value=self._lock_value.encode(),
                sequence=True,
                ephemeral=True,
            )

        while True:
            predecessor = self._predecessor_path()
            if predecessor is None:
                return True
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False
            self._wait_for_change(predecessor, remaining)

    def _predecessor_path(self) -> Optional[str]:
        """Return the path of the node queued just ahead of ours, or None if ours is first.

        Children are ordered by ZooKeeper's fixed-width, 10-digit sequence
        suffix rather than by full name. Raises ValueError if our own node is
        no longer among the children.
        """
        children = sorted(
            self.zk.get_children(self._lock_path),
            key=lambda child: child[-10:],
        )
        my_name = self._node_path.rsplit("/", 1)[-1]
        position = children.index(my_name)
        if position == 0:
            return None
        return f"{self._lock_path}/{children[position - 1]}"

    def _wait_for_change(self, path: str, timeout: float) -> None:
        """Block until a watch on *path* fires or *timeout* seconds pass."""
        changed = threading.Event()
        # exists() returns a falsy value when the node is already gone. No
        # deletion event would then arrive, so skip the wait and re-check.
        if self.zk.exists(path, watch=lambda _event: changed.set()):
            changed.wait(timeout=timeout)

    def _ensure_path(self, path):
        """Create the lock's parent path if it does not exist yet."""
        if not self.zk.exists(path):
            self.zk.ensure_path(path)

    def _cleanup(self):
        """Delete our queue node if present, and forget it so the next attempt starts fresh."""
        node_path, self._node_path = self._node_path, None
        if node_path and self.zk.exists(node_path):
            self.zk.delete(node_path)

    def release(self) -> bool:
        """Release the lock by deleting our node.

        Returns:
            False if this instance did not hold the lock, else True.
        """
        if not self._acquired:
            return False

        self._cleanup()
        self._acquired = False
        return True

    def extend(self, additional_time: float) -> bool:
        """Do nothing and report whether the lock is held.

        The lock node is ephemeral and has no TTL, so there is nothing to
        extend. It persists for the lifetime of the ZooKeeper session.

        Returns:
            Whether this instance currently holds the lock.
        """
        return self._acquired

46.7.2 分布式ID生成

python
import time
import threading
from typing import Optional
from dataclasses import dataclass

@dataclass()
class SnowflakeConfig:
    """Bit layout, node identity and epoch consumed by SnowflakeIDGenerator.

    Field order defines the positional constructor signature.
    """

    # Bit widths of the ID components.
    worker_id_bits: int = 5
    datacenter_id_bits: int = 5
    sequence_bits: int = 12

    # Identity of this generator; each must fit in its bit width above.
    worker_id: int = 0
    datacenter_id: int = 0

    # Custom epoch in Unix milliseconds: 2024-01-01T00:00:00Z.
    epoch: int = 1_704_067_200_000


class SnowflakeIDGenerator:
    """Thread-safe Snowflake ID generator.

    Each ID packs, from most to least significant bits:
    - milliseconds elapsed since ``config.epoch``;
    - the datacenter id;
    - the worker id;
    - a sequence number counting IDs issued within the same millisecond.

    The bit widths come from the config (5/5/12 by default).
    """

    def __init__(self, config: Optional["SnowflakeConfig"] = None):
        self.config = config or SnowflakeConfig()

        # Largest value each field can hold, e.g. 31 for 5 bits.
        self._max_worker_id = -1 ^ (-1 << self.config.worker_id_bits)
        self._max_datacenter_id = -1 ^ (-1 << self.config.datacenter_id_bits)
        self._sequence_mask = -1 ^ (-1 << self.config.sequence_bits)

        self._worker_id_shift = self.config.sequence_bits
        self._datacenter_id_shift = (
            self.config.sequence_bits + self.config.worker_id_bits
        )
        self._timestamp_shift = (
            self.config.sequence_bits +
            self.config.worker_id_bits +
            self.config.datacenter_id_bits
        )

        self._sequence = 0
        self._last_timestamp = -1
        self._lock = threading.Lock()

        if self.config.worker_id > self._max_worker_id or self.config.worker_id < 0:
            raise ValueError(f"Worker ID must be 0-{self._max_worker_id}")

        if self.config.datacenter_id > self._max_datacenter_id or self.config.datacenter_id < 0:
            raise ValueError(f"Datacenter ID must be 0-{self._max_datacenter_id}")

    def generate(self) -> int:
        """Return a new ID, strictly greater than any previously returned by this instance.

        If the per-millisecond sequence is exhausted, this spins until the next
        millisecond.

        Raises:
            RuntimeError: if the system clock moved backwards since the last ID.
        """
        with self._lock:
            timestamp = self._current_timestamp()

            if timestamp < self._last_timestamp:
                raise RuntimeError(
                    f"Clock moved backwards. Refusing to generate id for "
                    f"{self._last_timestamp - timestamp} milliseconds"
                )

            if timestamp == self._last_timestamp:
                self._sequence = (self._sequence + 1) & self._sequence_mask
                if self._sequence == 0:
                    # Sequence wrapped within this millisecond: move to the next one.
                    timestamp = self._wait_next_millis(self._last_timestamp)
            else:
                self._sequence = 0

            self._last_timestamp = timestamp

            return (
                ((timestamp - self.config.epoch) << self._timestamp_shift) |
                (self.config.datacenter_id << self._datacenter_id_shift) |
                (self.config.worker_id << self._worker_id_shift) |
                self._sequence
            )

    def _current_timestamp(self) -> int:
        """Current Unix time in whole milliseconds."""
        return int(time.time() * 1000)

    def _wait_next_millis(self, last_timestamp: int) -> int:
        """Busy-wait until the clock passes *last_timestamp*; return the new time."""
        timestamp = self._current_timestamp()
        while timestamp <= last_timestamp:
            timestamp = self._current_timestamp()
        return timestamp

    @staticmethod
    def parse_id(id: int, config: Optional["SnowflakeConfig"] = None) -> dict:
        """Decode an ID into its components.

        Args:
            id: an ID produced by ``generate``.
            config: the layout and epoch the ID was generated with. When omitted,
                the default 5/5/12-bit layout and 2024-01-01T00:00:00Z epoch
                are assumed. These match ``SnowflakeConfig()``. Pass
                ``generator.config`` for IDs from a customised generator.

        Returns:
            dict with "timestamp" (Unix ms), "datacenter_id", "worker_id" and
            "sequence".
        """
        if config is None:
            epoch = 1704067200000
            datacenter_bits, worker_bits, sequence_bits = 5, 5, 12
        else:
            epoch = config.epoch
            datacenter_bits = config.datacenter_id_bits
            worker_bits = config.worker_id_bits
            sequence_bits = config.sequence_bits

        sequence_mask = (1 << sequence_bits) - 1
        worker_mask = (1 << worker_bits) - 1
        datacenter_mask = (1 << datacenter_bits) - 1

        sequence = id & sequence_mask
        worker_id = (id >> sequence_bits) & worker_mask
        datacenter_id = (id >> (sequence_bits + worker_bits)) & datacenter_mask
        timestamp = id >> (sequence_bits + worker_bits + datacenter_bits)

        return {
            "timestamp": timestamp + epoch,
            "datacenter_id": datacenter_id,
            "worker_id": worker_id,
            "sequence": sequence
        }


class UUIDGenerator:
    """Helpers that return UUIDs as canonical hyphenated strings."""

    @staticmethod
    def v4() -> str:
        """Generate a random UUID (version 4)."""
        import uuid
        return str(uuid.uuid4())

    @staticmethod
    def v7() -> str:
        """Generate a UUID version 7 (RFC 9562), ordered by creation millisecond.

        Bit layout:
        - 48-bit big-endian Unix timestamp in milliseconds;
        - 4-bit version (0b0111);
        - 12 random bits;
        - 2-bit RFC variant (0b10);
        - 62 random bits.

        UUIDs created within the same millisecond are not ordered relative to
        each other.
        """
        import uuid
        import time

        timestamp = int(time.time() * 1000)
        timestamp_bytes = (timestamp & 0xFFFFFFFFFFFF).to_bytes(6, 'big')

        # Borrow the random tail of a v4 UUID. Its variant bits (byte 8) are
        # already the RFC value, but its version nibble says 4.
        random_tail = bytearray(uuid.uuid4().bytes[6:])
        # Overwrite the version nibble with 7; the previous code left it at 4.
        random_tail[0] = (random_tail[0] & 0x0F) | 0x70

        uuid_bytes = timestamp_bytes + bytes(random_tail)
        return str(uuid.UUID(bytes=uuid_bytes))


class LeafIDGenerator:
    """Leaf-style segment ("号段") ID generator.

    IDs are reserved from the ``leaf_alloc`` table in blocks of ``step``. Each
    reserved block is then handed out from memory, one ID at a time, under a
    thread lock.
    """

    def __init__(self, db_connection, biz_tag: str, step: int = 1000):
        self.db = db_connection
        self.biz_tag = biz_tag
        self.step = step

        # The current segment is (_max_id - step, _max_id]; _current_id is the last ID issued.
        self._current_id = 0
        self._max_id = 0
        self._lock = threading.Lock()

    def generate(self) -> int:
        """Return the next ID, reserving a new segment when the current one is used up.

        Raises:
            LookupError: if ``leaf_alloc`` has no row for ``biz_tag``.
        """
        with self._lock:
            if self._current_id >= self._max_id:
                self._load_segment()

            self._current_id += 1
            return self._current_id

    def _load_segment(self):
        """Reserve the next ``step`` IDs and make them the current segment.

        If the connection exposes ``transaction()``, the UPDATE and the SELECT
        run inside it. A concurrent allocator then cannot advance ``max_id``
        between the two statements and receive an overlapping segment.
        """
        begin = getattr(self.db, "transaction", None)
        if begin is None:
            # NOTE(review): without a transaction the UPDATE/SELECT pair is not
            # atomic, so concurrent allocators may get overlapping segments.
            max_id = self._allocate_segment()
        else:
            with begin():
                max_id = self._allocate_segment()

        # Only touch in-memory state once the reservation has succeeded.
        self._max_id = max_id
        self._current_id = max_id - self.step

    def _allocate_segment(self) -> int:
        """Advance ``max_id`` for ``biz_tag`` by ``step`` and return its new value."""
        self.db.execute(
            "UPDATE leaf_alloc SET max_id = max_id + ? WHERE biz_tag = ?",
            (self.step, self.biz_tag)
        )

        result = self.db.fetch_one(
            "SELECT max_id FROM leaf_alloc WHERE biz_tag = ?",
            (self.biz_tag,)
        )

        if not result:
            # Previously this silently produced IDs 1, 2, 3, ... that were never reserved.
            raise LookupError(f"No leaf_alloc row for biz_tag {self.biz_tag!r}")
        return result[0]

46.7.3 分布式事务模式

python
from typing import List, Dict, Any, Callable, Optional
from dataclasses import dataclass, field
from enum import Enum
import uuid

class TransactionState(Enum):
    """Lifecycle state of a saga step or a TCC participant."""

    # Initial state. Also kept by a step whose action raised, and by a
    # participant that has only completed Try.
    PENDING = "pending"
    # The forward action (saga) or Confirm (TCC) completed.
    COMMITTED = "committed"
    # The compensation (saga) or Cancel (TCC) completed.
    ROLLED_BACK = "rolled_back"


@dataclass()
class TransactionStep:
    """One saga step: a forward action paired with its compensating action.

    Both callables are invoked with no arguments.
    """

    name: str  # label reported in results and failure messages
    execute: Callable  # forward action
    compensate: Callable  # undoes ``execute`` when a later step fails
    state: TransactionState = TransactionState.PENDING  # updated as the saga runs


class SagaTransaction:
    """Saga transaction: run steps in order, compensating completed ones on failure.

    If a step raises, every step that already succeeded in the current run is
    compensated in reverse order. The saga then reports failure.
    """

    def __init__(self, name: str):
        self.name = name
        self.transaction_id = str(uuid.uuid4())
        self.steps: List[TransactionStep] = []
        # Indices into self.steps of the steps that succeeded during the current run.
        self.executed_steps: List[int] = []

    def add_step(
        self,
        name: str,
        execute: Callable,
        compensate: Callable
    ) -> "SagaTransaction":
        """Append a step and return self so calls can be chained."""
        self.steps.append(TransactionStep(
            name=name,
            execute=execute,
            compensate=compensate
        ))
        return self

    def execute(self) -> Dict:
        """Run every step; on the first failure, compensate and report.

        Returns:
            ``{"success": True, "transaction_id": ...}`` when all steps succeed.
            On failure, the dict also carries:
            - "failed_step" and "error";
            - "compensation_failed": names of steps whose compensation raised.
              An empty list means the rollback was clean; a non-empty list
              means manual repair is needed.
        """
        # Reset per run so a re-run never compensates an earlier run's steps again.
        self.executed_steps = []

        for index, step in enumerate(self.steps):
            try:
                step.execute()
            except Exception as e:
                compensation_failed = self._compensate()
                return {
                    "success": False,
                    "transaction_id": self.transaction_id,
                    "failed_step": step.name,
                    "error": str(e),
                    "compensation_failed": compensation_failed,
                }
            step.state = TransactionState.COMMITTED
            self.executed_steps.append(index)

        return {
            "success": True,
            "transaction_id": self.transaction_id
        }

    def _compensate(self) -> List[str]:
        """Compensate executed steps in reverse order, best-effort.

        A failing compensation is printed and does not stop the others.

        Returns:
            Names of the steps whose compensation raised.
        """
        failed: List[str] = []
        for index in reversed(self.executed_steps):
            step = self.steps[index]
            try:
                step.compensate()
                step.state = TransactionState.ROLLED_BACK
            except Exception as e:
                print(f"Compensation failed for {step.name}: {e}")
                failed.append(step.name)
        return failed


class TCCTransaction:
    """TCC (Try-Confirm-Cancel) transaction coordinator.

    Phase 1 runs each participant's Try in registration order. If a Try
    raises, Cancel runs on the participants whose Try was invoked. That
    includes the failing one, which may have reserved resources before
    raising. The transaction then fails.

    Otherwise phase 2 runs each Confirm. A Confirm failure is reported without
    cancelling; the caller must retry or reconcile it.
    """

    def __init__(self, name: str):
        self.name = name
        self.transaction_id = str(uuid.uuid4())
        self.participants: List[Dict] = []

    def add_participant(
        self,
        name: str,
        try_func: Callable,
        confirm_func: Callable,
        cancel_func: Callable
    ) -> "TCCTransaction":
        """Register a participant's Try/Confirm/Cancel callables; returns self for chaining."""
        self.participants.append({
            "name": name,
            "try": try_func,
            "confirm": confirm_func,
            "cancel": cancel_func,
            "state": TransactionState.PENDING
        })
        return self

    def execute(self) -> Dict:
        """Run the transaction.

        Returns:
            ``{"success": True, "transaction_id": ...}`` on success. Otherwise
            ``{"success": False, "transaction_id": ..., "phase": "try" | "confirm",
            "failed_participant": <name>, "error": <message>}``.
        """
        failure = self._try_phase()
        if failure is not None:
            attempted = failure.pop("attempted")
            # Never Cancel participants whose Try was never invoked.
            self._cancel_phase(self.participants[:attempted])
            return {
                "success": False,
                "transaction_id": self.transaction_id,
                "phase": "try",
                **failure,
            }

        failure = self._confirm_phase()
        if failure is not None:
            return {
                "success": False,
                "transaction_id": self.transaction_id,
                "phase": "confirm",
                **failure,
            }

        return {
            "success": True,
            "transaction_id": self.transaction_id
        }

    def _try_phase(self) -> Optional[Dict]:
        """Run each Try in order, stopping at the first failure.

        Returns:
            None if every Try succeeded. Otherwise a dict with
            "failed_participant", "error", and "attempted", the number of
            leading participants whose Try was invoked.
        """
        for index, participant in enumerate(self.participants):
            try:
                participant["try"]()
            except Exception as e:
                return {
                    "failed_participant": participant["name"],
                    "error": str(e),
                    "attempted": index + 1,
                }
        return None

    def _confirm_phase(self) -> Optional[Dict]:
        """Run each Confirm in order, stopping at the first failure.

        Returns:
            None on success, else a dict with "failed_participant" and "error".
        """
        for participant in self.participants:
            try:
                participant["confirm"]()
            except Exception as e:
                print(f"Confirm failed for {participant['name']}: {e}")
                return {"failed_participant": participant["name"], "error": str(e)}
            participant["state"] = TransactionState.COMMITTED
        return None

    def _cancel_phase(self, participants: Optional[List[Dict]] = None):
        """Invoke Cancel on *participants* (default: all), best-effort.

        A failing Cancel is printed and does not stop the remaining cancels.
        """
        targets = self.participants if participants is None else participants
        for participant in targets:
            try:
                participant["cancel"]()
                participant["state"] = TransactionState.ROLLED_BACK
            except Exception as e:
                print(f"Cancel failed for {participant['name']}: {e}")


class LocalMessageTable:
    """Local message table (transactional outbox) pattern.

    A business change and the message announcing it are written in the same
    local database transaction. A separate relay later reads the pending rows
    (``get_pending_messages``), publishes them, and marks them sent or bumps
    their retry count.
    """

    def __init__(self, db_connection):
        self.db = db_connection
        self._create_table()

    def _create_table(self):
        """Create the ``local_message`` table if it does not exist."""
        self.db.execute("""
            CREATE TABLE IF NOT EXISTS local_message (
                id VARCHAR(36) PRIMARY KEY,
                topic VARCHAR(100),
                payload TEXT,
                status VARCHAR(20) DEFAULT 'pending',
                retry_count INT DEFAULT 0,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                processed_at TIMESTAMP NULL
            )
        """)

    def begin_transaction(self, business_func: Callable, topic: str, payload: Dict):
        """Run *business_func* and record an outgoing message in one transaction.

        The payload is stored as JSON so that relays and consumers can parse
        it; previously it was stored as a Python ``repr``. Serialization
        happens before the transaction starts. An unserializable payload
        therefore raises ``TypeError`` without running *business_func*.

        Returns:
            The new message id (a UUID4 string).
        """
        import json

        message_id = str(uuid.uuid4())
        serialized = json.dumps(payload, ensure_ascii=False)

        with self.db.transaction():
            business_func()

            self.db.execute(
                """INSERT INTO local_message (id, topic, payload)
                   VALUES (?, ?, ?)""",
                (message_id, topic, serialized)
            )

        return message_id

    def get_pending_messages(self, limit: int = 100, max_retries: int = 5) -> List[Dict]:
        """Return up to *limit* pending messages, oldest first.

        Messages that have already been retried *max_retries* times or more
        are skipped.
        """
        return self.db.fetch_all(
            """SELECT * FROM local_message 
               WHERE status = 'pending' AND retry_count < ?
               ORDER BY created_at LIMIT ?""",
            (max_retries, limit)
        )

    def mark_sent(self, message_id: str):
        """Mark a message as sent and stamp its processing time."""
        self.db.execute(
            "UPDATE local_message SET status = 'sent', processed_at = CURRENT_TIMESTAMP WHERE id = ?",
            (message_id,)
        )

    def increment_retry(self, message_id: str):
        """Record one more failed delivery attempt for a message."""
        self.db.execute(
            "UPDATE local_message SET retry_count = retry_count + 1 WHERE id = ?",
            (message_id,)
        )

46.8 本章小结

本章详细介绍了Python分布式系统的核心概念和实践:

  1. 分布式系统基础:CAP定理、BASE理论、一致性哈希
  2. 分布式协调:分布式锁、领导选举、心跳检测
  3. 消息队列:消息模式、可靠传输、延迟队列
  4. 缓存系统:缓存策略、缓存穿透、缓存雪崩
  5. 负载均衡:轮询、加权、一致性哈希、健康检查
  6. 分布式ID:雪花算法、UUID v4/v7、号段模式
  7. 分布式事务:Saga、TCC、本地消息表

练习题

  1. 实现一个分布式锁服务,支持自动续期和故障恢复
  2. 开发一个消息队列系统,支持消息持久化和重试机制
  3. 实现一个缓存系统,支持多种缓存策略和过期策略
  4. 开发一个负载均衡器,支持多种算法和健康检查
  5. 实现一个分布式ID生成器,支持高并发和唯一性保证

扩展阅读

Python技术丛书 - 江苏省宿城中等专业学校