第46章 分布式系统
学习目标
完成本章学习后,你将能够:
- 理解分布式系统原理:CAP定理、BASE理论、分布式一致性
- 实现消息队列:RabbitMQ、Kafka、消息模式、可靠传输
- 设计缓存策略:Redis、缓存模式、缓存穿透、缓存雪崩
- 实现负载均衡:轮询、加权、一致性哈希、健康检查
- 构建高可用系统:故障转移、服务降级、限流熔断
- 处理分布式事务:两阶段提交、Saga模式、最终一致性
- 实现分布式锁:Redis锁、Zookeeper锁、分布式协调
- 设计分布式ID:雪花算法、UUID、数据库序列
46.1 分布式系统基础
46.1.1 核心理论
python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any, Callable
from enum import Enum
import asyncio
import time
import hashlib
class ConsistencyLevel(Enum):
    """Consistency guarantee a replicated store can be configured for."""

    STRONG = "strong"
    EVENTUAL = "eventual"
    CAUSAL = "causal"
@dataclass
class Node:
    """A cluster member together with the time its last heartbeat arrived."""

    id: str
    host: str
    port: int
    status: str = "healthy"
    # Wall-clock seconds (compared against time.time()); 0 means "never heard from".
    last_heartbeat: float = 0
    metadata: Dict[str, str] = field(default_factory=dict)

    def is_alive(self, timeout: float = 30.0) -> bool:
        """Return True if the last heartbeat is less than *timeout* seconds old."""
        seconds_since = time.time() - self.last_heartbeat
        return seconds_since < timeout
@dataclass
class DistributedConfig:
    """Replication, quorum and timing settings for a cluster."""

    # N: number of replicas kept for each item.
    replication_factor: int = 3
    consistency_level: ConsistencyLevel = ConsistencyLevel.EVENTUAL
    # R / W quorum sizes. With the defaults R + W = 4 > N = 3, so every read
    # quorum overlaps every write quorum.
    read_quorum: int = 2
    write_quorum: int = 2
    # Timing knobs; presumably seconds (HeartbeatManager uses seconds) — confirm.
    heartbeat_interval: float = 5.0
    election_timeout: float = 10.0
class CAPTheorem:
    """Reference text describing the CAP theorem."""

    @staticmethod
    def explain() -> Dict:
        """Return the C/A/P definitions plus the classic two-out-of-three trade-offs."""
        definitions = {
            "C": "Consistency - 所有节点在同一时间看到相同的数据",
            "A": "Availability - 每个请求都能获得响应(成功或失败)",
            "P": "Partition Tolerance - 网络分区时系统仍能运行",
        }
        tradeoffs = {
            "CP": "牺牲可用性,保证一致性和分区容错(如HBase)",
            "AP": "牺牲一致性,保证可用性和分区容错(如Cassandra)",
            "CA": "单机系统,无分区容错(如传统RDBMS)",
        }
        return {**definitions, "tradeoffs": tradeoffs}
class BASETheory:
    """Reference text describing BASE, the availability-first counterpart of ACID."""

    @staticmethod
    def explain() -> Dict:
        """Return the three BASE properties and a one-line comparison with ACID."""
        comparison = {
            "ACID": "强一致性,适合金融等关键业务",
            "BASE": "最终一致性,适合高并发互联网应用",
        }
        return {
            "BA": "Basically Available - 基本可用,允许损失部分可用性",
            "S": "Soft State - 软状态,允许中间状态存在",
            "E": "Eventually Consistent - 最终一致,经过时间后达到一致",
            "vs_ACID": comparison,
        }
class ConsistentHashing:
    """Consistent-hash ring with virtual nodes.

    Each physical node is placed on the ring ``virtual_nodes`` times, at the MD5
    hash of ``"<node>#<i>"``. A key belongs to the first ring position whose
    hash is >= the key's hash, wrapping around to the smallest position.
    """

    def __init__(self, virtual_nodes: int = 150):
        self.virtual_nodes = virtual_nodes
        self.ring: Dict[int, str] = {}
        self.nodes: List[str] = []
        # Ring positions kept sorted, so a lookup is a binary search instead
        # of re-sorting the whole ring on every call.
        self._sorted_keys: List[int] = []

    def _hash(self, key: str) -> int:
        return int(hashlib.md5(key.encode()).hexdigest(), 16)

    def _virtual_hashes(self, node: str) -> List[int]:
        """Return the ring positions of *node*'s virtual nodes."""
        return [self._hash(f"{node}#{i}") for i in range(self.virtual_nodes)]

    def add_node(self, node: str) -> None:
        """Place *node* on the ring. Adding a node that is already present is a no-op."""
        if node in self.nodes:
            return
        self.nodes.append(node)
        for hash_val in self._virtual_hashes(node):
            self.ring[hash_val] = node
        self._sorted_keys = sorted(self.ring)

    def remove_node(self, node: str) -> None:
        """Remove *node* from the ring.

        Raises:
            ValueError: if *node* is not on the ring.
        """
        self.nodes.remove(node)
        for hash_val in self._virtual_hashes(node):
            # Only drop positions this node still owns. On a hash collision,
            # the position may have been taken over by another node.
            if self.ring.get(hash_val) == node:
                del self.ring[hash_val]
        self._sorted_keys = sorted(self.ring)

    def _start_index(self, key: str) -> int:
        """Index in ``_sorted_keys`` of the first position >= hash(key), wrapping to 0."""
        from bisect import bisect_left

        index = bisect_left(self._sorted_keys, self._hash(key))
        return index % len(self._sorted_keys)

    def get_node(self, key: str) -> Optional[str]:
        """Return the node owning *key*, or None if the ring is empty."""
        if not self._sorted_keys:
            return None
        return self.ring[self._sorted_keys[self._start_index(key)]]

    def get_nodes(self, key: str, count: int = 3) -> List[str]:
        """Return up to *count* distinct nodes, walking clockwise from *key*.

        The first element is the same node get_node() returns. Fewer than
        *count* nodes are returned when the ring holds fewer distinct nodes.
        """
        if not self._sorted_keys:
            return []
        start = self._start_index(key)
        total = len(self._sorted_keys)
        nodes: List[str] = []
        for offset in range(total):
            node = self.ring[self._sorted_keys[(start + offset) % total]]
            if node not in nodes:
                nodes.append(node)
                if len(nodes) >= count:
                    break
        return nodes
class DistributedCounter:
    """Grow-only replicated counter (G-Counter).

    Each replica increments only its own slot in ``_vector_clock``. ``merge``
    takes the element-wise maximum of two vectors, and the counter's value is
    the sum of all slots. Replicas that have exchanged state therefore report
    the same total. Previously ``merge`` updated the vector but ``get_count``
    returned only local increments, so merging had no visible effect.
    """

    def __init__(self, node_id: str, nodes: List[str]):
        self.node_id = node_id
        self.nodes = nodes
        self._local_count = 0  # increments performed on this replica only
        self._vector_clock: Dict[str, int] = {n: 0 for n in nodes}
        # Make sure our own slot exists even if node_id was not listed in nodes.
        self._vector_clock.setdefault(node_id, 0)

    def increment(self) -> None:
        """Add one to this replica's slot."""
        self._local_count += 1
        self._vector_clock[self.node_id] += 1

    def get_count(self) -> int:
        """Return the total known to this replica (the sum over all nodes' slots)."""
        return sum(self._vector_clock.values())

    def merge(self, other: "DistributedCounter") -> None:
        """Fold *other*'s state into ours. Idempotent, commutative and associative.

        Slots for nodes we have not seen before are adopted rather than raising KeyError.
        """
        for node, count in other._vector_clock.items():
            self._vector_clock[node] = max(self._vector_clock.get(node, 0), count)

    def get_vector_clock(self) -> Dict[str, int]:
        """Return a copy of the per-node counts."""
        return self._vector_clock.copy()
# 46.1.2 分布式协调
python
from typing import Optional
import asyncio
class DistributedLock:
    """In-process stand-in for an async distributed lock.

    ``_try_acquire`` is where a real backend call (for example Redis
    ``SET NX``) would go. In this stub it succeeds only while this instance
    does not already hold the lock.
    """

    def __init__(self, lock_name: str, ttl: int = 30):
        self.lock_name = lock_name
        self.ttl = ttl  # not enforced by this stub
        self._locked = False
        self._lock_value: Optional[str] = None

    async def acquire(self, timeout: float = 10.0) -> bool:
        """Try to take the lock, retrying every 0.1 s until *timeout* seconds pass.

        At least one attempt is always made, so ``timeout=0`` means "try once".
        The deadline uses a monotonic clock, so wall-clock adjustments cannot
        shorten or stretch it.

        Returns:
            True if the lock was acquired, False on timeout.
        """
        deadline = time.monotonic() + timeout
        while True:
            if await self._try_acquire():
                self._locked = True
                return True
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False
            await asyncio.sleep(min(0.1, remaining))

    async def _try_acquire(self) -> bool:
        """Single acquisition attempt; refuses if this instance already holds the lock."""
        import uuid

        if self._locked:
            return False
        self._lock_value = str(uuid.uuid4())
        return True

    async def release(self) -> bool:
        """Release the lock. Returns False if it was not held."""
        if not self._locked:
            return False
        self._locked = False
        self._lock_value = None
        return True

    @property
    def locked(self) -> bool:
        return self._locked
class LeaderElection:
    """Simplified Raft-style leader election; vote requests are stubbed out."""

    def __init__(
        self,
        node_id: str,
        nodes: List[str],
        election_timeout: float = 5.0
    ):
        self.node_id = node_id
        self.nodes = nodes
        self.election_timeout = election_timeout
        self._leader_id: Optional[str] = None
        self._term = 0
        self._is_candidate = False
        self._votes_received = 0

    @property
    def leader_id(self) -> Optional[str]:
        return self._leader_id

    @property
    def is_leader(self) -> bool:
        return self._leader_id == self.node_id

    async def start_election(self) -> bool:
        """Start a new term, vote for ourselves and ask every peer for a vote.

        Returns:
            True if a majority of ``nodes`` granted a vote and we became leader.
        """
        self._term += 1
        self._is_candidate = True
        self._votes_received = 1  # our own vote
        for node in self.nodes:
            if node != self.node_id and await self._request_vote(node):
                self._votes_received += 1
        majority = len(self.nodes) // 2 + 1
        won = self._votes_received >= majority
        if won:
            self._leader_id = self.node_id
        self._is_candidate = False
        return won

    async def _request_vote(self, node: str) -> bool:
        """Placeholder RPC: always grants the vote."""
        return True

    def update_leader(self, leader_id: str, term: int) -> None:
        """Adopt *leader_id* as leader unless *term* is older than ours.

        Equal terms are accepted, as in Raft: a node that stood for (or voted
        in) term T must still recognise the winner announced for that same
        term. Previously only strictly newer terms were accepted.
        """
        if term >= self._term:
            self._term = term
            self._leader_id = leader_id
            self._is_candidate = False
class HeartbeatManager:
    """Sends periodic heartbeats and tracks when peers were last heard from."""

    def __init__(
        self,
        node_id: str,
        interval: float = 5.0,
        timeout: float = 15.0
    ):
        self.node_id = node_id
        self.interval = interval
        self.timeout = timeout
        # node_id -> time.monotonic() of its most recent heartbeat.
        self._last_heartbeat: Dict[str, float] = {}
        self._running = False
        self._task: Optional[asyncio.Task] = None

    async def start(self) -> None:
        """Start the background heartbeat loop. Does nothing if it is already running."""
        if self._task is not None and not self._task.done():
            return
        self._running = True
        # Keep a strong reference: the event loop holds only weak references
        # to tasks, so an unreferenced task can disappear mid-execution.
        self._task = asyncio.create_task(self._heartbeat_loop())

    async def stop(self) -> None:
        """Stop the loop and wait for its task to finish."""
        self._running = False
        task, self._task = self._task, None
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass  # expected: we just cancelled it

    async def _heartbeat_loop(self) -> None:
        while self._running:
            await self._send_heartbeat()
            await asyncio.sleep(self.interval)

    async def _send_heartbeat(self) -> None:
        """Placeholder for the network call that broadcasts our heartbeat."""
        pass

    def receive_heartbeat(self, node_id: str) -> None:
        """Record that *node_id* just sent a heartbeat."""
        self._last_heartbeat[node_id] = time.monotonic()

    def is_node_alive(self, node_id: str) -> bool:
        """True if *node_id* sent a heartbeat within ``timeout`` seconds.

        Nodes never heard from are reported as not alive.
        """
        last = self._last_heartbeat.get(node_id)
        if last is None:
            return False
        return time.monotonic() - last < self.timeout

    def get_dead_nodes(self) -> List[str]:
        """Return the known nodes whose last heartbeat is older than ``timeout``."""
        return [
            node_id
            for node_id in self._last_heartbeat
            if not self.is_node_alive(node_id)
        ]
# 46.2 消息队列
46.2.1 消息队列实现
python
from dataclasses import dataclass
from typing import Any, Optional, Callable
from enum import Enum
import asyncio
import json
from datetime import datetime
class MessageStatus(Enum):
    """Lifecycle state of a queued message."""

    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"


def _utc_now() -> datetime:
    """Return the current UTC time as a naive datetime.

    This is the value ``datetime.utcnow()`` returned, without the
    DeprecationWarning that call emits on Python 3.12+.
    """
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None)


@dataclass
class Message:
    """A unit of work plus its delivery and retry bookkeeping."""

    id: str
    queue: str
    payload: Any
    status: MessageStatus = MessageStatus.PENDING
    created_at: Optional[datetime] = None  # set to naive UTC "now" by __post_init__
    processed_at: Optional[datetime] = None  # set by MessageQueue.ack
    retry_count: int = 0
    max_retries: int = 3
    error_message: Optional[str] = None

    def __post_init__(self):
        if self.created_at is None:
            self.created_at = _utc_now()


class MessageQueue:
    """Bounded in-memory queue with ack/nack and a dead-letter list."""

    def __init__(self, name: str, max_size: int = 10000):
        self.name = name
        self.max_size = max_size
        self._queue: asyncio.Queue = asyncio.Queue(maxsize=max_size)
        self._dead_letter_queue: List[Message] = []
        self._consumers: List[Callable] = []  # appended to by MessageBroker.subscribe

    async def publish(self, message: Message) -> bool:
        """Enqueue *message*, waiting up to 5 s for space. Returns False on timeout."""
        try:
            await asyncio.wait_for(self._queue.put(message), timeout=5.0)
            return True
        except asyncio.TimeoutError:
            return False

    async def consume(self) -> Optional[Message]:
        """Dequeue the next message, or return None if none arrives within 1 s."""
        try:
            return await asyncio.wait_for(self._queue.get(), timeout=1.0)
        except asyncio.TimeoutError:
            return None

    async def ack(self, message: Message) -> None:
        """Mark *message* as successfully processed."""
        message.status = MessageStatus.COMPLETED
        message.processed_at = _utc_now()

    async def nack(self, message: Message, error: Optional[str] = None) -> None:
        """Record a failed delivery, then requeue or dead-letter the message.

        The message is dead-lettered when its retries are exhausted, or when it
        cannot be requeued because the queue is still full after the publish
        timeout. Previously a failed requeue dropped the message silently.
        """
        message.retry_count += 1
        message.error_message = error
        if message.retry_count < message.max_retries:
            message.status = MessageStatus.PENDING
            if await self.publish(message):
                return
        message.status = MessageStatus.FAILED
        self._dead_letter_queue.append(message)

    def get_queue_size(self) -> int:
        return self._queue.qsize()

    def get_dead_letter_count(self) -> int:
        return len(self._dead_letter_queue)
class MessageBroker:
    """Registry of named MessageQueues plus background consumer loops."""

    def __init__(self):
        self._queues: Dict[str, MessageQueue] = {}
        self._exchanges: Dict[str, Dict] = {}  # not used by the code in this section
        # Strong references to consumer tasks. asyncio keeps only weak
        # references, so an unreferenced task can be garbage-collected.
        self._consumer_tasks: set = set()

    def create_queue(self, name: str, max_size: int = 10000) -> MessageQueue:
        """Return the queue called *name*, creating it if needed."""
        if name not in self._queues:
            self._queues[name] = MessageQueue(name, max_size)
        return self._queues[name]

    def get_queue(self, name: str) -> Optional[MessageQueue]:
        return self._queues.get(name)

    async def publish(
        self,
        queue_name: str,
        payload: Any,
        max_retries: int = 3
    ) -> bool:
        """Wrap *payload* in a Message and enqueue it; the queue is created on demand."""
        import uuid

        queue = self.get_queue(queue_name)
        if not queue:
            queue = self.create_queue(queue_name)
        message = Message(
            id=str(uuid.uuid4()),
            queue=queue_name,
            payload=payload,
            max_retries=max_retries
        )
        return await queue.publish(message)

    async def subscribe(
        self,
        queue_name: str,
        handler: Callable
    ) -> None:
        """Register *handler* and start a background loop feeding it messages.

        *handler* receives the message payload. It may be a plain function or
        a coroutine function.
        """
        queue = self.get_queue(queue_name)
        if not queue:
            queue = self.create_queue(queue_name)
        queue._consumers.append(handler)
        task = asyncio.create_task(self._consume_loop(queue, handler))
        self._consumer_tasks.add(task)
        task.add_done_callback(self._consumer_tasks.discard)

    async def _consume_loop(
        self,
        queue: MessageQueue,
        handler: Callable
    ) -> None:
        """Deliver messages forever: ack on success, nack when the handler raises."""
        import inspect

        while True:
            message = await queue.consume()
            if message is None:
                continue  # consume() timed out; poll again
            message.status = MessageStatus.PROCESSING
            try:
                result = handler(message.payload)
                # Previously a sync handler was always awaited, raising TypeError.
                if inspect.isawaitable(result):
                    await result
            except Exception as e:
                await queue.nack(message, str(e))
            else:
                # Ack outside the try, so a failing ack is not mistaken for a
                # handler failure.
                await queue.ack(message)
class PubSub:
    """Minimal in-process topic-based publish/subscribe dispatcher."""

    def __init__(self):
        self._subscribers: Dict[str, List[Callable]] = {}

    def subscribe(self, topic: str, handler: Callable) -> None:
        self._subscribers.setdefault(topic, []).append(handler)

    def unsubscribe(self, topic: str, handler: Callable) -> None:
        """Remove *handler* from *topic*. Unknown topics or handlers are ignored."""
        handlers = self._subscribers.get(topic)
        if handlers and handler in handlers:
            handlers.remove(handler)

    async def publish(self, topic: str, message: Any) -> int:
        """Deliver *message* to every handler of *topic*.

        Handlers may be plain or async callables. A failing handler does not
        stop delivery to the others; its exception is logged, not swallowed.

        Returns:
            The number of handlers that completed without raising.
        """
        import inspect
        import logging

        delivered = 0
        # Iterate a snapshot, so handlers can (un)subscribe during delivery
        # without skipping their neighbours.
        for handler in list(self._subscribers.get(topic, ())):
            try:
                result = handler(message)
                if inspect.isawaitable(result):
                    await result
            except Exception:
                logging.getLogger(__name__).exception(
                    "PubSub handler %r failed for topic %r", handler, topic
                )
            else:
                delivered += 1
        return delivered
class DelayedMessageQueue:
    """Holds payloads until their delay elapses, then publishes them to a broker."""

    def __init__(self):
        # Min-heap of (execute_at, message_id, queue_name, payload). The unique
        # message_id breaks ties, so payloads are never compared.
        self._messages: List[tuple] = []

    async def schedule(
        self,
        payload: Any,
        delay_seconds: float,
        queue_name: str = "default"
    ) -> str:
        """Schedule *payload* for *queue_name* after *delay_seconds*; return its id."""
        import heapq
        import uuid

        message_id = str(uuid.uuid4())
        execute_at = time.time() + delay_seconds
        heapq.heappush(self._messages, (execute_at, message_id, queue_name, payload))
        return message_id

    async def process_due_messages(
        self,
        broker: "MessageBroker"
    ) -> int:
        """Publish every message whose time has come.

        Messages the broker refuses (``publish`` returns a falsy value) are put
        back and retried on the next call, instead of being dropped.

        Returns:
            The number of messages actually published.
        """
        import heapq

        now = time.time()
        due = []
        while self._messages and self._messages[0][0] <= now:
            due.append(heapq.heappop(self._messages))
        published = 0
        for entry in due:
            _, _, queue_name, payload = entry
            if await broker.publish(queue_name, payload):
                published += 1
            else:
                heapq.heappush(self._messages, entry)
        return published
# 46.3 缓存系统
46.3.1 缓存策略
python
from typing import Any, Optional, Dict, List
from dataclasses import dataclass
import time
import hashlib
import asyncio
@dataclass
class CacheEntry:
    """A cached value with its creation time, optional TTL and hit counter."""

    key: str
    value: Any
    created_at: float  # time.time() at insertion
    ttl: Optional[float] = None  # seconds; None means the entry never expires
    hits: int = 0

    def is_expired(self) -> bool:
        if self.ttl is None:
            return False
        return time.time() - self.created_at > self.ttl


class CacheStrategy(Enum):
    """Eviction policy applied when the cache is full."""

    LRU = "lru"    # evict the key least recently read or written
    LFU = "lfu"    # evict the key with the fewest hits
    FIFO = "fifo"  # evict the key inserted first


class InMemoryCache:
    """Async-safe, size-bounded in-memory cache with TTLs and pluggable eviction."""

    def __init__(
        self,
        max_size: int = 1000,
        default_ttl: Optional[float] = None,
        strategy: CacheStrategy = CacheStrategy.LRU
    ):
        from collections import OrderedDict

        self.max_size = max_size
        self.default_ttl = default_ttl
        self.strategy = strategy
        # Dict insertion order doubles as FIFO order. Assigning to an existing
        # key keeps its original slot.
        self._cache: Dict[str, CacheEntry] = {}
        # Recency order for LRU, oldest first; move_to_end() is O(1), unlike
        # the list remove() used before.
        self._access_order: "OrderedDict[str, None]" = OrderedDict()
        self._lock = asyncio.Lock()

    async def get(self, key: str) -> Optional[Any]:
        """Return the cached value, or None if it is missing or expired."""
        async with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                return None
            if entry.is_expired():
                await self._remove(key)
                return None
            entry.hits += 1
            self._update_access_order(key)
            return entry.value

    async def set(
        self,
        key: str,
        value: Any,
        ttl: Optional[float] = None
    ) -> None:
        """Store *value*. Passing ``ttl=None`` uses ``default_ttl``; an explicit 0 is honoured."""
        async with self._lock:
            if key not in self._cache and len(self._cache) >= self.max_size:
                await self._evict()
            self._cache[key] = CacheEntry(
                key=key,
                value=value,
                created_at=time.time(),
                # Previously `ttl or default_ttl` silently replaced ttl=0.
                ttl=self.default_ttl if ttl is None else ttl
            )
            self._update_access_order(key)

    async def delete(self, key: str) -> bool:
        async with self._lock:
            return await self._remove(key)

    async def _remove(self, key: str) -> bool:
        if key not in self._cache:
            return False
        del self._cache[key]
        self._access_order.pop(key, None)
        return True

    async def _evict(self) -> None:
        """Remove one entry according to ``strategy``. The caller holds the lock."""
        if not self._cache:
            return
        if self.strategy == CacheStrategy.LRU:
            victim = next(iter(self._access_order))
        elif self.strategy == CacheStrategy.LFU:
            victim = min(self._cache, key=lambda k: self._cache[k].hits)
        else:
            # FIFO: oldest insertion. Previously this used the LRU recency
            # list, so FIFO behaved exactly like LRU.
            victim = next(iter(self._cache))
        await self._remove(victim)

    def _update_access_order(self, key: str) -> None:
        self._access_order[key] = None
        self._access_order.move_to_end(key)

    async def clear(self) -> None:
        async with self._lock:
            self._cache.clear()
            self._access_order.clear()

    def get_stats(self) -> Dict:
        total_hits = sum(e.hits for e in self._cache.values())
        return {
            "size": len(self._cache),
            "max_size": self.max_size,
            "total_hits": total_hits,
            "keys": list(self._cache.keys())
        }
class CacheAside:
    """Cache-aside (lazy loading): read the cache first, then fall back to the loader.

    A None result from the loader is returned but never cached.
    """

    def __init__(
        self,
        cache: InMemoryCache,
        data_loader: Callable
    ):
        self.cache = cache
        self.data_loader = data_loader  # awaited as data_loader(key)

    async def get(self, key: str) -> Any:
        cached = await self.cache.get(key)
        if cached is None:
            loaded = await self.data_loader(key)
            if loaded is not None:
                await self.cache.set(key, loaded)
            return loaded
        return cached

    async def invalidate(self, key: str) -> None:
        """Drop *key* from the cache; the next get() reloads it."""
        await self.cache.delete(key)

    async def refresh(self, key: str) -> Any:
        """Evict *key*, then reload it through get()."""
        await self.cache.delete(key)
        fresh = await self.get(key)
        return fresh
class WriteThrough:
    """Write-through caching: persist first, and mirror into the cache only on success."""

    def __init__(
        self,
        cache: InMemoryCache,
        data_writer: Callable
    ):
        self.cache = cache
        # Awaited as data_writer(key, value); a truthy result means "written".
        self.data_writer = data_writer

    async def set(self, key: str, value: Any) -> bool:
        written = await self.data_writer(key, value)
        if not written:
            return written
        await self.cache.set(key, value)
        return written

    async def get(self, key: str) -> Any:
        """Read from the cache only; misses are not loaded from storage."""
        return await self.cache.get(key)
class CachePenetrationProtection:
    """Guards a cache against penetration (repeated lookups for keys that never exist).

    Keys known to exist are registered in a Bloom filter of
    ``bloom_filter_size`` bits. Lookups for keys the filter has never seen are
    rejected without touching the cache. A Bloom filter may give false
    positives but never false negatives. Previously the "filter" was an
    unbounded ``set`` of full hashes and ``bloom_filter_size`` was ignored.
    """

    def __init__(
        self,
        cache: "InMemoryCache",
        bloom_filter_size: int = 10000,
        num_hashes: int = 3
    ):
        if bloom_filter_size <= 0:
            raise ValueError("bloom_filter_size must be positive")
        self.cache = cache
        self._bloom_size = bloom_filter_size
        self._num_hashes = max(1, num_hashes)
        self._bloom_filter = bytearray((bloom_filter_size + 7) // 8)
        self._null_cache_ttl = 60  # seconds a "known missing" marker is kept

    def _hash_key(self, key: str) -> int:
        return int(hashlib.md5(key.encode()).hexdigest(), 16)

    def _bit_positions(self, key: str) -> List[int]:
        # Double hashing (Kirsch–Mitzenmacher): derive k bit indexes from the
        # two 64-bit halves of one 128-bit MD5 digest. "| 1" keeps the step
        # non-zero.
        digest = self._hash_key(key)
        h1 = digest >> 64
        h2 = (digest & 0xFFFFFFFFFFFFFFFF) | 1
        return [(h1 + i * h2) % self._bloom_size for i in range(self._num_hashes)]

    def add_to_filter(self, key: str) -> None:
        """Record that *key* exists in the backing store."""
        for pos in self._bit_positions(key):
            self._bloom_filter[pos >> 3] |= 1 << (pos & 7)

    def might_exist(self, key: str) -> bool:
        """False means *key* was definitely never added; True means "probably added"."""
        return all(
            self._bloom_filter[pos >> 3] & (1 << (pos & 7))
            for pos in self._bit_positions(key)
        )

    async def get(self, key: str) -> Optional[Any]:
        """Return the cached value, short-circuiting to None for keys not in the filter."""
        if not self.might_exist(key):
            return None
        return await self.cache.get(key)

    async def set_null(self, key: str) -> None:
        """Cache a None marker for *key* for ``_null_cache_ttl`` seconds."""
        await self.cache.set(key, None, ttl=self._null_cache_ttl)
class CacheAvalancheProtection:
    """Spreads expirations out, so many keys do not expire at the same instant."""

    def __init__(
        self,
        cache: "InMemoryCache",
        base_ttl: float = 300
    ):
        self.cache = cache
        self.base_ttl = base_ttl

    async def set(self, key: str, value: Any) -> None:
        """Store with a TTL of ``base_ttl`` plus up to 10% random jitter."""
        import random

        jitter = random.uniform(0, self.base_ttl * 0.1)
        await self.cache.set(key, value, ttl=self.base_ttl + jitter)

    async def warm_up(self, keys: List[str], loader: Callable) -> None:
        """Pre-load *keys* in random order via the awaited ``loader(key)``.

        Shuffling works on a private copy. Previously the caller's list was
        shuffled in place. None results are not cached.
        """
        import random

        ordered = list(keys)
        random.shuffle(ordered)
        for key in ordered:
            value = await loader(key)
            if value is not None:
                await self.set(key, value)
# 46.4 负载均衡
46.4.1 负载均衡算法
python
from typing import List, Dict, Optional
from dataclasses import dataclass
import random
import time
@dataclass
class BackendServer:
    """An upstream server as seen by the load balancers in this section."""

    id: str
    host: str
    port: int
    weight: int = 1  # relative share used by WeightedRoundRobinLoadBalancer
    connections: int = 0  # compared by LeastConnectionsLoadBalancer
    healthy: bool = True  # unhealthy servers are skipped by every select()
    response_time: float = 0.0  # not read by the balancers in this section
class LoadBalancer:
    """Base balancer that owns the backend pool and its health flags.

    ``select`` returns the first healthy server; subclasses override it.
    """

    def __init__(self, servers: List[BackendServer]):
        self.servers = servers
        self._current_index = 0  # rotation cursor used by round-robin subclasses

    def select(self) -> Optional[BackendServer]:
        return next((server for server in self.servers if server.healthy), None)

    def _set_health(self, server_id: str, healthy: bool) -> None:
        # Every server with a matching id is updated; ids are not assumed unique.
        for server in self.servers:
            if server.id == server_id:
                server.healthy = healthy

    def mark_unhealthy(self, server_id: str) -> None:
        self._set_health(server_id, False)

    def mark_healthy(self, server_id: str) -> None:
        self._set_health(server_id, True)
class RoundRobinLoadBalancer(LoadBalancer):
    """Hand out healthy servers in turn, wrapping around the list."""

    def select(self) -> Optional[BackendServer]:
        pool = [server for server in self.servers if server.healthy]
        if len(pool) == 0:
            return None
        turn, self._current_index = self._current_index, self._current_index + 1
        return pool[turn % len(pool)]
class WeightedRoundRobinLoadBalancer(LoadBalancer):
    """Smooth weighted round-robin, the algorithm nginx uses for upstreams.

    On every call, each healthy server's weight is added to its running score.
    The server with the highest score wins, and the total weight is subtracted
    from the winner's score. Starting from a fresh state, each cycle of
    ``sum(weights)`` calls picks every server exactly ``weight`` times, and the
    picks are interleaved (weights 5/1/1 yield a, a, b, a, c, a, a).
    Previously this class made a weighted *random* choice, which is not
    round-robin and never used ``_current_weights``.
    """

    def __init__(self, servers: List[BackendServer]):
        super().__init__(servers)
        self._current_weights: Dict[str, int] = {}  # server id -> running score

    def select(self) -> Optional[BackendServer]:
        healthy_servers = [s for s in self.servers if s.healthy]
        if not healthy_servers:
            return None
        # Negative weights are treated as 0.
        total_weight = sum(max(s.weight, 0) for s in healthy_servers)
        if total_weight <= 0:
            return random.choice(healthy_servers)
        best = None
        best_score = 0
        for server in healthy_servers:
            score = self._current_weights.get(server.id, 0) + max(server.weight, 0)
            self._current_weights[server.id] = score
            if best is None or score > best_score:
                best, best_score = server, score
        self._current_weights[best.id] -= total_weight
        return best
class LeastConnectionsLoadBalancer(LoadBalancer):
    """Pick the healthy server with the fewest connections; the earliest listed wins ties."""

    def select(self) -> Optional[BackendServer]:
        best = None
        for server in self.servers:
            if not server.healthy:
                continue
            if best is None or server.connections < best.connections:
                best = server
        return best
class IPHashLoadBalancer(LoadBalancer):
    """Pin each client IP to a healthy server via MD5(ip) mod N.

    Any change in the healthy set changes N, which remaps most clients.
    Without an IP, a random healthy server is returned.
    """

    def select(self, client_ip: str = None) -> Optional[BackendServer]:
        candidates = [s for s in self.servers if s.healthy]
        if not candidates:
            return None
        if client_ip is None:
            return random.choice(candidates)
        digest = hashlib.md5(client_ip.encode()).hexdigest()
        return candidates[int(digest, 16) % len(candidates)]
class ConsistentHashLoadBalancer(LoadBalancer):
    """Route a key to a server via a consistent-hash ring with virtual nodes.

    A key goes to the first healthy server clockwise from its hash. After
    changing ``servers``, call ``_build_ring()`` to rebuild the ring.
    """

    def __init__(
        self,
        servers: List[BackendServer],
        virtual_nodes: int = 150
    ):
        super().__init__(servers)
        self.virtual_nodes = virtual_nodes
        self._ring: Dict[int, BackendServer] = {}
        self._sorted_keys: List[int] = []  # ring positions, kept sorted for bisect
        self._build_ring()

    @staticmethod
    def _hash(value: str) -> int:
        return int(hashlib.md5(value.encode()).hexdigest(), 16)

    def _build_ring(self) -> None:
        """(Re)build the ring from ``servers``, discarding any previous ring."""
        self._ring.clear()
        for server in self.servers:
            for i in range(self.virtual_nodes):
                self._ring[self._hash(f"{server.id}#{i}")] = server
        # Sort once here rather than on every select(), as before.
        self._sorted_keys = sorted(self._ring)

    def select(self, key: str = None) -> Optional[BackendServer]:
        """Return the server for *key*, or a random healthy server if *key* is None."""
        from bisect import bisect_left

        healthy_servers = [s for s in self.servers if s.healthy]
        if not healthy_servers:
            return None
        if key is None:
            return random.choice(healthy_servers)
        if not self._sorted_keys:
            return None
        total = len(self._sorted_keys)
        start = bisect_left(self._sorted_keys, self._hash(key))
        # Walk clockwise from the key's position, skipping unhealthy servers.
        for offset in range(total):
            server = self._ring[self._sorted_keys[(start + offset) % total]]
            if server.healthy:
                return server
        return None
class HealthChecker:
    """Periodically probes ``GET /health`` on every server and flips its health flag.

    A server is marked unhealthy after ``unhealthy_threshold`` consecutive
    failed probes, and healthy again after ``healthy_threshold`` consecutive
    successes.
    """

    def __init__(
        self,
        load_balancer: "LoadBalancer",
        check_interval: float = 10.0,
        timeout: float = 5.0,
        unhealthy_threshold: int = 3,
        healthy_threshold: int = 2
    ):
        self.load_balancer = load_balancer
        self.check_interval = check_interval
        self.timeout = timeout
        self.unhealthy_threshold = unhealthy_threshold
        self.healthy_threshold = healthy_threshold
        self._failure_counts: Dict[str, int] = {}
        self._success_counts: Dict[str, int] = {}

    async def start(self) -> None:
        """Run check rounds forever, ``check_interval`` seconds apart."""
        while True:
            await self._check_all()
            await asyncio.sleep(self.check_interval)

    async def _check_all(self) -> None:
        """Probe every server concurrently, then update the consecutive counters."""
        servers = list(self.load_balancer.servers)
        # Concurrent probes bound a round to about one timeout; sequential
        # probes could take len(servers) * timeout.
        results = await asyncio.gather(*(self._check_server(s) for s in servers))
        for server, is_healthy in zip(servers, results):
            self._record(server.id, is_healthy)

    def _record(self, server_id: str, is_healthy: bool) -> None:
        """Update the streak counters for one probe result and apply the thresholds."""
        if is_healthy:
            self._failure_counts[server_id] = 0
            self._success_counts[server_id] = self._success_counts.get(server_id, 0) + 1
            if self._success_counts[server_id] >= self.healthy_threshold:
                self.load_balancer.mark_healthy(server_id)
        else:
            self._success_counts[server_id] = 0
            self._failure_counts[server_id] = self._failure_counts.get(server_id, 0) + 1
            if self._failure_counts[server_id] >= self.unhealthy_threshold:
                self.load_balancer.mark_unhealthy(server_id)

    async def _check_server(self, server: "BackendServer") -> bool:
        """Return True if ``/health`` answers HTTP 200 within ``timeout`` seconds."""
        # Import outside the try. A missing aiohttp is a deployment error; it
        # used to be swallowed, marking every server unhealthy.
        import aiohttp

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"http://{server.host}:{server.port}/health",
                    timeout=aiohttp.ClientTimeout(total=self.timeout)
                ) as response:
                    return response.status == 200
        except Exception:
            # Connection errors, timeouts, etc. all count as a failed probe.
            return False
# 46.5 知识图谱
46.5.1 分布式系统架构
┌─────────────────────────────────────────────────────────────────────┐
│ 分布式系统架构全景图 │
├─────────────────────────────────────────────────────────────────────┤
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ 客户端层 (Client) │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ 负载均衡 │ │ CDN │ │ DNS │ │ API网关 │ │ │
│ │ │ Nginx │ │ CloudFlare│ │ Route53 │ │ Kong │ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────┐ │
│ │ 服务层 (Services) │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ 服务发现 │ │ 配置中心 │ │ 服务网格 │ │ 容器编排 │ │ │
│ │ │ Consul │ │ Apollo │ │ Istio │ │ K8s │ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────┐ │
│ │ 数据层 (Data) │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ 分布式DB │ │ 缓存 │ │ 消息队列 │ │ 搜索引擎 │ │ │
│ │ │ TiDB │ │ Redis │ │ Kafka │ │ ES │ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────┐ │
│ │ 协调层 (Coordination) │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ 分布式锁 │ │ 选举 │ │ 事务 │ │ 一致性 │ │ │
│ │ │ Redis │ │ Raft │ │ 2PC/Saga│ │ Paxos │ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────┐ │
│ │ 可观测性 (Observability) │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ 监控 │ │ 日志 │ │ 追踪 │ │ 告警 │ │ │
│ │ │Prometheus│ │ ELK │ │ Jaeger │ │AlertMngr│ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │
│ └─────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────┘
46.5.2 CAP定理与BASE理论
┌─────────────────────────────────────────────────────────────────────┐
│ CAP定理图解 │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ Consistency │
│ (一致性) │
│ /\ │
│ / \ │
│ / \ │
│ / CA \ ┌────────────────────────┐ │
│ / \ │ CA: 传统关系数据库 │ │
│ /──────────\ │ PostgreSQL, MySQL │ │
│ / \ └────────────────────────┘ │
│ / CP \ │
│ / \ ┌────────────────────────┐ │
│ /──────────────────\ │ CP: 分布式数据库 │ │
│ / \ │ MongoDB, HBase │ │
│ / AP \└────────────────────────┘ │
│ / \ │
│ /──────────────────────────\ │
│ Partition Tolerance Availability │
│ (分区容错) (可用性) │
│ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ AP系统: 高可用、最终一致 (Cassandra, DynamoDB, CouchDB, Redis Cluster) │ │
│ │ CP系统: 强一致、可能不可用 (Zookeeper, etcd, HBase) │ │
│ │ CA系统: 无分区容错 (单机数据库) │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ BASE理论 │ │
│ │ Basically Available: 基本可用 │ │
│ │ Soft State: 软状态(中间状态) │ │
│ │ Eventually Consistent: 最终一致 │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────┘
46.6 技术选型指南
46.6.1 分布式协调服务选型
| 服务 | 一致性模型 | 性能 | 功能丰富度 | 学习曲线 | 推荐指数 |
|---|---|---|---|---|---|
| Zookeeper | CP | 中 | 高 | 高 | ★★★★☆ |
| etcd | CP | 高 | 中 | 中 | ★★★★★ |
| Consul | CP | 高 | 极高 | 低 | ★★★★★ |
| Redis | AP | 极高 | 中 | 低 | ★★★★☆ |
46.6.2 消息队列选型
| 消息队列 | 吞吐量 | 延迟 | 持久化 | 顺序保证 | 推荐场景 |
|---|---|---|---|---|---|
| Kafka | 极高 | 中 | ✅ | 分区内有序 | 日志、大数据 |
| RabbitMQ | 高 | 低 | ✅ | ✅ | 企业应用 |
| RocketMQ | 极高 | 低 | ✅ | ✅ | 金融交易 |
| Redis Stream | 高 | 极低 | 可选 | ✅ | 实时消息 |
| Pulsar | 极高 | 低 | ✅ | ✅ | 云原生应用 |
46.6.3 分布式缓存选型
| 缓存方案 | 性能 | 持久化 | 集群支持 | 数据结构 | 推荐指数 |
|---|---|---|---|---|---|
| Redis Cluster | 极高 | ✅ | ✅ | 丰富 | ★★★★★ |
| Memcached | 极高 | ❌ | ✅ | 简单 | ★★★☆☆ |
| Hazelcast | 高 | ✅ | ✅ | 丰富 | ★★★★☆ |
| Aerospike | 极高 | ✅ | ✅ | 中等 | ★★★★☆ |
46.7 常见问题与解决方案
46.7.1 分布式锁实现
python
import time
import uuid
import threading
from typing import Optional, Callable
from dataclasses import dataclass
@dataclass
class LockOptions:
    """Timing options shared by the distributed lock implementations (seconds)."""

    lock_timeout: float = 30.0  # how long the lock lives before auto-expiring
    wait_timeout: float = 10.0  # how long acquire() keeps retrying
    retry_interval: float = 0.1  # pause between acquisition attempts
class DistributedLock:
    """Abstract distributed lock; subclasses implement acquire/release/extend.

    As a context manager, entering acquires the lock (raising TimeoutError if
    it cannot be obtained). Exiting always releases it and never suppresses
    the body's exception.
    """

    def __init__(self, name: str, options: LockOptions = None):
        self.name = name
        self.options = options or LockOptions()
        # Random per-instance token stored as the lock value; subclasses
        # compare it so they only release or extend a lock they own.
        self._lock_value = str(uuid.uuid4())
        self._acquired = False

    def acquire(self) -> bool:
        """Acquire the lock; return True on success."""
        raise NotImplementedError

    def release(self) -> bool:
        """Release the lock; return True if it was held and released."""
        raise NotImplementedError

    def extend(self, additional_time: float) -> bool:
        """Prolong the lock's lifetime; return True on success."""
        raise NotImplementedError

    def __enter__(self):
        acquired = self.acquire()
        if acquired:
            return self
        raise TimeoutError(f"Failed to acquire lock: {self.name}")

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.release()
        return False
class RedisDistributedLock(DistributedLock):
    """Redis lock: ``SET name token NX PX ttl``; release/extend use compare-and-act Lua."""

    # Delete the key only if it still holds our token, so we never release a
    # lock that expired and was then acquired by someone else.
    _RELEASE_SCRIPT = """
if redis.call("get", KEYS[1]) == ARGV[1] then
    return redis.call("del", KEYS[1])
else
    return 0
end
"""
    # Reset the TTL (milliseconds) only while we still own the lock.
    _EXTEND_SCRIPT = """
if redis.call("get", KEYS[1]) == ARGV[1] then
    return redis.call("pexpire", KEYS[1], ARGV[2])
else
    return 0
end
"""

    def __init__(self, redis_client, name: str, options: LockOptions = None):
        super().__init__(name, options)
        self.redis = redis_client

    @staticmethod
    def _to_millis(seconds: float) -> int:
        # Millisecond precision. int(seconds) turned e.g. 0.5 s into 0, which
        # Redis rejects as an invalid expire time.
        return max(1, int(seconds * 1000))

    def acquire(self) -> bool:
        """Retry ``SET NX`` every ``retry_interval`` until ``wait_timeout`` elapses.

        The deadline uses a monotonic clock, so wall-clock jumps cannot
        shorten or stretch it.
        """
        deadline = time.monotonic() + self.options.wait_timeout
        ttl_ms = self._to_millis(self.options.lock_timeout)
        while True:
            acquired = self.redis.set(
                self.name,
                self._lock_value,
                nx=True,
                px=ttl_ms
            )
            if acquired:
                self._acquired = True
                return True
            if time.monotonic() >= deadline:
                return False
            time.sleep(self.options.retry_interval)

    def release(self) -> bool:
        """Release atomically via Lua; return True if our token was deleted."""
        if not self._acquired:
            return False
        result = self.redis.eval(self._RELEASE_SCRIPT, 1, self.name, self._lock_value)
        self._acquired = False
        return bool(result)

    def extend(self, additional_time: float) -> bool:
        """Set the lock's remaining lifetime to *additional_time* seconds.

        This resets the TTL rather than adding to it, which is the existing
        behaviour, kept for compatibility. Returns False if we no longer own
        the lock.
        """
        if not self._acquired:
            return False
        result = self.redis.eval(
            self._EXTEND_SCRIPT, 1,
            self.name,
            self._lock_value,
            self._to_millis(additional_time)
        )
        return bool(result)
class ZookeeperDistributedLock(DistributedLock):
    """ZooKeeper lock recipe using ephemeral sequential nodes under ``/locks/<name>``.

    The lowest sequence number holds the lock. Each waiter watches only its
    immediate predecessor, which avoids a thundering herd on release.
    """

    def __init__(self, zk_client, name: str, options: LockOptions = None):
        super().__init__(name, options)
        self.zk = zk_client
        self._lock_path = f"/locks/{name}"
        self._node_path = None  # our queue entry, while we have one

    def acquire(self) -> bool:
        """Queue for the lock and wait up to ``wait_timeout`` seconds for our turn."""
        deadline = time.monotonic() + self.options.wait_timeout
        while True:
            try:
                if self._node_path is None:
                    # Create our queue entry once. The old code created a new
                    # node on every loop iteration, leaving its own older
                    # entry ahead of it, so it could end up waiting on itself.
                    self._ensure_path(self._lock_path)
                    self._node_path = self.zk.create(
                        f"{self._lock_path}/lock-",
                        value=self._lock_value.encode(),
                        sequence=True,
                        ephemeral=True
                    )
                if self._is_lowest_sequence():
                    self._acquired = True
                    return True
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    self._cleanup()
                    return False
                self._watch_previous(remaining)
            except Exception:
                # ZooKeeper/connection error, or our node vanished: drop our
                # entry best-effort and start over until the deadline.
                try:
                    self._cleanup()
                except Exception:
                    pass  # ephemeral node dies with the session anyway
                if time.monotonic() >= deadline:
                    return False
                time.sleep(self.options.retry_interval)

    def _is_lowest_sequence(self) -> bool:
        children = self.zk.get_children(self._lock_path)
        if not children:
            return True
        my_sequence = self._node_path.split('/')[-1]
        return min(children) == my_sequence

    def _watch_previous(self, timeout: Optional[float] = None):
        """Block until the node just ahead of ours goes away, or *timeout* seconds pass.

        *timeout* defaults to ``options.wait_timeout``. The caller passes the
        time left before its own deadline.
        """
        if timeout is None:
            timeout = self.options.wait_timeout
        children = sorted(self.zk.get_children(self._lock_path))
        my_sequence = self._node_path.split('/')[-1]
        my_index = children.index(my_sequence)  # ValueError if our node is gone
        if my_index == 0:
            return
        previous = f"{self._lock_path}/{children[my_index - 1]}"
        event = threading.Event()
        # exists() returns a falsy value when the predecessor is already gone.
        # The watch would then only fire on re-creation, so don't wait at all.
        if self.zk.exists(previous, watch=lambda _event: event.set()):
            event.wait(timeout=timeout)

    def _ensure_path(self, path):
        if not self.zk.exists(path):
            self.zk.ensure_path(path)

    def _cleanup(self):
        """Delete our queue entry if it still exists, and forget it."""
        try:
            if self._node_path and self.zk.exists(self._node_path):
                self.zk.delete(self._node_path)
        finally:
            self._node_path = None

    def release(self) -> bool:
        if not self._acquired:
            return False
        self._cleanup()
        self._acquired = False
        return True
# 46.7.2 分布式ID生成
python
import time
import threading
from typing import Optional
from dataclasses import dataclass
@dataclass
class SnowflakeConfig:
    """Bit layout and identity of a Snowflake ID generator."""

    worker_id_bits: int = 5
    datacenter_id_bits: int = 5
    sequence_bits: int = 12
    worker_id: int = 0
    datacenter_id: int = 0
    epoch: int = 1704067200000  # 2024-01-01 00:00:00 UTC, in Unix milliseconds


class SnowflakeIDGenerator:
    """Thread-safe Snowflake ID generator.

    Layout, from most to least significant bits: milliseconds since ``epoch``
    | datacenter id | worker id | per-millisecond sequence. IDs are kept
    within 63 bits, so they fit a signed 64-bit integer.
    """

    def __init__(self, config: SnowflakeConfig = None):
        self.config = config or SnowflakeConfig()
        cfg = self.config
        self._max_worker_id = (1 << cfg.worker_id_bits) - 1
        self._max_datacenter_id = (1 << cfg.datacenter_id_bits) - 1
        self._sequence_mask = (1 << cfg.sequence_bits) - 1
        self._worker_id_shift = cfg.sequence_bits
        self._datacenter_id_shift = cfg.sequence_bits + cfg.worker_id_bits
        self._timestamp_shift = (
            cfg.sequence_bits + cfg.worker_id_bits + cfg.datacenter_id_bits
        )
        if self._timestamp_shift >= 63:
            raise ValueError("sequence/worker/datacenter bits leave no room for a timestamp")
        # Largest elapsed-ms value that still fits in 63 bits (41 bits by default).
        self._max_elapsed = (1 << (63 - self._timestamp_shift)) - 1
        self._sequence = 0
        self._last_timestamp = -1
        self._lock = threading.Lock()
        if not 0 <= cfg.worker_id <= self._max_worker_id:
            raise ValueError(f"Worker ID must be 0-{self._max_worker_id}")
        if not 0 <= cfg.datacenter_id <= self._max_datacenter_id:
            raise ValueError(f"Datacenter ID must be 0-{self._max_datacenter_id}")

    def generate(self) -> int:
        """Return a new, strictly increasing ID.

        Raises:
            RuntimeError: if the clock moved backwards, or if the current time
                is before ``epoch`` or beyond the representable range.
        """
        with self._lock:
            timestamp = self._current_timestamp()
            if timestamp < self._last_timestamp:
                raise RuntimeError(
                    f"Clock moved backwards. Refusing to generate id for "
                    f"{self._last_timestamp - timestamp} milliseconds"
                )
            if timestamp == self._last_timestamp:
                self._sequence = (self._sequence + 1) & self._sequence_mask
                if self._sequence == 0:
                    # Sequence exhausted for this millisecond: spin to the next one.
                    timestamp = self._wait_next_millis(self._last_timestamp)
            else:
                self._sequence = 0
            self._last_timestamp = timestamp
            elapsed = timestamp - self.config.epoch
            if not 0 <= elapsed <= self._max_elapsed:
                raise RuntimeError(
                    f"Timestamp {timestamp} is outside the range of epoch {self.config.epoch}"
                )
            return (
                (elapsed << self._timestamp_shift) |
                (self.config.datacenter_id << self._datacenter_id_shift) |
                (self.config.worker_id << self._worker_id_shift) |
                self._sequence
            )

    def _current_timestamp(self) -> int:
        return int(time.time() * 1000)

    def _wait_next_millis(self, last_timestamp: int) -> int:
        timestamp = self._current_timestamp()
        while timestamp <= last_timestamp:
            timestamp = self._current_timestamp()
        return timestamp

    @staticmethod
    def parse_id(id: int, config: Optional[SnowflakeConfig] = None) -> dict:
        """Split an ID back into its fields.

        *config* must describe the layout the ID was generated with. It
        defaults to ``SnowflakeConfig()`` (5/5/12 bits, 2024 epoch), matching
        the previous hard-coded behaviour, which decoded IDs from any custom
        layout incorrectly.

        Returns:
            dict with ``timestamp`` (absolute Unix ms), ``datacenter_id``,
            ``worker_id`` and ``sequence``.
        """
        cfg = config if config is not None else SnowflakeConfig()
        seq_bits = cfg.sequence_bits
        worker_bits = cfg.worker_id_bits
        dc_bits = cfg.datacenter_id_bits
        return {
            "timestamp": (id >> (seq_bits + worker_bits + dc_bits)) + cfg.epoch,
            "datacenter_id": (id >> (seq_bits + worker_bits)) & ((1 << dc_bits) - 1),
            "worker_id": (id >> seq_bits) & ((1 << worker_bits) - 1),
            "sequence": id & ((1 << seq_bits) - 1),
        }
class UUIDGenerator:
    """UUID helpers: random v4 and time-ordered v7."""

    @staticmethod
    def v4() -> str:
        """Generate a random UUID v4 string."""
        import uuid

        return str(uuid.uuid4())

    @staticmethod
    def v7() -> str:
        """Generate a time-ordered UUID v7 string (RFC 9562 §5.7).

        Layout: a 48-bit Unix timestamp in ms, the 4-bit version (7), 12 random
        bits, the 2-bit variant (0b10), then 62 random bits. The previous
        version concatenated a timestamp with random bytes but never set the
        version or variant bits, so its output was not a valid v7 UUID.
        """
        import secrets
        import time
        import uuid

        unix_ms = int(time.time() * 1000) & ((1 << 48) - 1)
        rand = secrets.randbits(74)
        rand_a = rand >> 62  # top 12 random bits
        rand_b = rand & ((1 << 62) - 1)  # remaining 62 random bits
        value = (
            (unix_ms << 80)
            | (0x7 << 76)
            | (rand_a << 64)
            | (0b10 << 62)
            | rand_b
        )
        return str(uuid.UUID(int=value))
class LeafIDGenerator:
    """Segment-mode ID allocator (Leaf style).

    Reserves ``step`` IDs at a time by bumping ``leaf_alloc.max_id`` for
    ``biz_tag``, then hands them out from memory.
    """

    def __init__(self, db_connection, biz_tag: str, step: int = 1000):
        self.db = db_connection
        self.biz_tag = biz_tag
        self.step = step
        self._current_id = 0  # last ID handed out
        self._max_id = 0  # last ID of the reserved segment
        self._lock = threading.Lock()

    def generate(self) -> int:
        """Return the next ID, reserving a new segment when the current one is used up."""
        with self._lock:
            if self._current_id >= self._max_id:
                self._load_segment()
            self._current_id += 1
            return self._current_id

    def _load_segment(self):
        """Reserve the next ``step`` IDs.

        Raises:
            RuntimeError: if ``leaf_alloc`` has no row for ``biz_tag``.
                Previously the generator kept handing out IDs outside any
                reserved segment.
        """
        # NOTE(review): with several generator instances, the UPDATE and
        # SELECT must run in one transaction (or use UPDATE ... RETURNING) to
        # avoid two instances reading the same max_id. Whether self.db does
        # that is not visible here.
        self.db.execute(
            "UPDATE leaf_alloc SET max_id = max_id + ? WHERE biz_tag = ?",
            (self.step, self.biz_tag)
        )
        result = self.db.fetch_one(
            "SELECT max_id FROM leaf_alloc WHERE biz_tag = ?",
            (self.biz_tag,)
        )
        if not result:
            raise RuntimeError(f"No leaf_alloc row for biz_tag {self.biz_tag!r}")
        self._max_id = result[0]
        self._current_id = self._max_id - self.step
# 46.7.3 分布式事务模式
python
from typing import List, Dict, Any, Callable, Optional
from dataclasses import dataclass, field
from enum import Enum
import uuid
class TransactionState(Enum):
    """Lifecycle state shared by Saga steps and TCC participants."""

    # Not yet executed (Saga) / not yet confirmed (TCC).
    PENDING = "pending"
    # Executed successfully (Saga) / confirmed (TCC).
    COMMITTED = "committed"
    # Compensated (Saga) / cancelled (TCC).
    ROLLED_BACK = "rolled_back"
@dataclass
class TransactionStep:
    """A single Saga step: a forward action paired with its compensating action."""

    # Human-readable step name, reported when the step fails.
    name: str
    # Forward action; invoked with no arguments.
    execute: Callable
    # Undo action; invoked with no arguments during rollback.
    compensate: Callable
    # Progress marker updated by the coordinator.
    state: TransactionState = field(default=TransactionState.PENDING)
class SagaTransaction:
    """Saga transaction coordinator.

    Runs steps in insertion order. If a step raises, the steps that already
    succeeded are compensated in reverse order.
    """

    def __init__(self, name: str):
        self.name = name
        self.transaction_id = str(uuid.uuid4())
        self.steps: List[TransactionStep] = []
        # Indices into self.steps of the steps that succeeded in the current run.
        self.executed_steps: List[int] = []

    def add_step(
        self,
        name: str,
        execute: Callable,
        compensate: Callable
    ) -> "SagaTransaction":
        """Append a step (forward action plus compensation) and return self for chaining."""
        self.steps.append(TransactionStep(
            name=name,
            execute=execute,
            compensate=compensate
        ))
        return self

    def execute(self) -> Dict:
        """Run every step in order.

        Returns:
            dict with "success" and "transaction_id". On failure it also holds:
              - "failed_step": name of the step that raised
              - "error": that exception's message
              - "compensation_errors": list of {"step", "error"} dicts for
                compensations that themselves raised (empty when the rollback
                completed cleanly)
        """
        # Start clean so a repeated execute() never compensates steps from an
        # earlier run a second time.
        self.executed_steps = []
        for i, step in enumerate(self.steps):
            try:
                step.execute()
            except Exception as e:
                compensation_errors = self._compensate()
                return {
                    "success": False,
                    "transaction_id": self.transaction_id,
                    "failed_step": step.name,
                    "error": str(e),
                    "compensation_errors": compensation_errors,
                }
            step.state = TransactionState.COMMITTED
            self.executed_steps.append(i)
        return {
            "success": True,
            "transaction_id": self.transaction_id
        }

    def _compensate(self) -> List[Dict[str, str]]:
        """Compensate the succeeded steps in reverse order.

        Returns:
            One {"step", "error"} dict per compensation that raised.
        """
        errors: List[Dict[str, str]] = []
        for i in reversed(self.executed_steps):
            step = self.steps[i]
            try:
                step.compensate()
            except Exception as e:
                # Keep going: the remaining compensations must still run.
                print(f"Compensation failed for {step.name}: {e}")
                errors.append({"step": step.name, "error": str(e)})
            else:
                step.state = TransactionState.ROLLED_BACK
        return errors
class TCCTransaction:
    """TCC (Try-Confirm-Cancel) transaction coordinator.

    Every participant's try runs in order.
    - If a try raises, only the participants whose try was attempted
      (including the failing one) are cancelled.
    - If all tries succeed, the outcome is commit, so every participant's
      confirm is attempted even if an earlier confirm raised.
    """

    def __init__(self, name: str):
        self.name = name
        self.transaction_id = str(uuid.uuid4())
        self.participants: List[Dict] = []

    def add_participant(
        self,
        name: str,
        try_func: Callable,
        confirm_func: Callable,
        cancel_func: Callable
    ) -> "TCCTransaction":
        """Register a participant's try/confirm/cancel callables; returns self for chaining."""
        self.participants.append({
            "name": name,
            "try": try_func,
            "confirm": confirm_func,
            "cancel": cancel_func,
            "state": TransactionState.PENDING
        })
        return self

    def execute(self) -> Dict:
        """Run the TCC protocol.

        Returns:
            dict with "success" and "transaction_id".
            - On a try failure it also holds "phase": "try",
              "failed_participant" and "error".
            - On a confirm failure it also holds "phase": "confirm" and
              "failed_participants" (the names whose confirm raised).
        """
        attempted, failure = self._try_phase()
        if failure is not None:
            failed_name, error = failure
            self._cancel_phase(attempted)
            return {
                "success": False,
                "transaction_id": self.transaction_id,
                "phase": "try",
                "failed_participant": failed_name,
                "error": str(error),
            }
        failed_confirms = self._confirm_phase()
        if failed_confirms:
            return {
                "success": False,
                "transaction_id": self.transaction_id,
                "phase": "confirm",
                "failed_participants": failed_confirms,
            }
        return {
            "success": True,
            "transaction_id": self.transaction_id
        }

    def _try_phase(self) -> tuple:
        """Run each participant's try in order, stopping at the first failure.

        Returns:
            (attempted, failure). ``attempted`` lists the participants whose
            try was called, including a failing one, which may have partially
            reserved resources. ``failure`` is ``(name, exception)``, or None
            when every try succeeded.
        """
        attempted: List[Dict] = []
        for participant in self.participants:
            attempted.append(participant)
            try:
                participant["try"]()
            except Exception as e:
                return attempted, (participant["name"], e)
        return attempted, None

    def _confirm_phase(self) -> List[str]:
        """Confirm every participant and return the names whose confirm raised.

        A failing confirm does not stop the others. Otherwise later
        participants would stay reserved forever, neither confirmed nor
        cancelled.
        """
        failed: List[str] = []
        for participant in self.participants:
            try:
                participant["confirm"]()
            except Exception as e:
                print(f"Confirm failed for {participant['name']}: {e}")
                failed.append(participant["name"])
            else:
                participant["state"] = TransactionState.COMMITTED
        return failed

    def _cancel_phase(self, participants: Optional[List[Dict]] = None) -> None:
        """Cancel the given participants (all registered ones when None), best-effort."""
        targets = self.participants if participants is None else participants
        for participant in targets:
            try:
                participant["cancel"]()
                participant["state"] = TransactionState.ROLLED_BACK
            except Exception as e:
                print(f"Cancel failed for {participant['name']}: {e}")
class LocalMessageTable:
    """Local message table pattern.

    A business write and an outgoing message row are committed in one local
    database transaction. Pending rows are then polled with
    get_pending_messages() and settled with mark_sent() or increment_retry().
    """

    def __init__(self, db_connection):
        """
        Args:
            db_connection: Object exposing ``execute``, ``fetch_all`` and a
                ``transaction()`` context manager.
        """
        self.db = db_connection
        self._create_table()

    def _create_table(self):
        """Create the local_message table if it does not exist yet."""
        self.db.execute("""
CREATE TABLE IF NOT EXISTS local_message (
id VARCHAR(36) PRIMARY KEY,
topic VARCHAR(100),
payload TEXT,
status VARCHAR(20) DEFAULT 'pending',
retry_count INT DEFAULT 0,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
processed_at TIMESTAMP NULL
)
""")

    def begin_transaction(self, business_func: Callable, topic: str, payload: Dict) -> str:
        """Run ``business_func`` and record an outgoing message in the same transaction.

        The payload is stored as JSON, so consumers can decode it with
        ``json.loads``. Previously it was stored as a Python repr, which is
        not a portable format.

        Returns:
            The new message id.

        Raises:
            TypeError: if ``payload`` is not JSON-serializable. This is raised
                before ``business_func`` runs.
        """
        import json

        # Serialize first so a bad payload fails before any business write happens.
        payload_json = json.dumps(payload, ensure_ascii=False)
        message_id = str(uuid.uuid4())
        with self.db.transaction():
            business_func()
            self.db.execute(
                """INSERT INTO local_message (id, topic, payload)
VALUES (?, ?, ?)""",
                (message_id, topic, payload_json)
            )
        return message_id

    def get_pending_messages(self, limit: int = 100, max_retries: int = 5) -> List[Dict]:
        """Return up to ``limit`` pending messages, oldest first.

        Messages whose retry_count has reached ``max_retries`` are excluded.
        """
        return self.db.fetch_all(
            """SELECT * FROM local_message
WHERE status = 'pending' AND retry_count < ?
ORDER BY created_at LIMIT ?""",
            (max_retries, limit)
        )

    def mark_sent(self, message_id: str):
        """Mark a message as sent and stamp processed_at."""
        self.db.execute(
            "UPDATE local_message SET status = 'sent', processed_at = CURRENT_TIMESTAMP WHERE id = ?",
            (message_id,)
        )

    def increment_retry(self, message_id: str):
        """Increment a message's retry counter after a failed delivery attempt."""
        self.db.execute(
            "UPDATE local_message SET retry_count = retry_count + 1 WHERE id = ?",
            (message_id,)
        )
# 46.8 本章小结
本章详细介绍了Python分布式系统的核心概念和实践:
- 分布式系统基础:CAP定理、BASE理论、一致性哈希
- 分布式协调:分布式锁、领导选举、心跳检测
- 消息队列:消息模式、可靠传输、延迟队列
- 缓存系统:缓存策略、缓存穿透、缓存雪崩
- 负载均衡:轮询、加权、一致性哈希、健康检查
- 分布式ID与分布式事务:雪花算法、UUID、号段模式(Leaf)、Saga、TCC、本地消息表
练习题
- 实现一个分布式锁服务,支持自动续期和故障恢复
- 开发一个消息队列系统,支持消息持久化和重试机制
- 实现一个缓存系统,支持多种缓存策略和过期策略
- 开发一个负载均衡器,支持多种算法和健康检查
- 实现一个分布式ID生成器,支持高并发和唯一性保证