Skip to content

第11章 享元模式

学习目标

完成本章学习后,读者将能够:

  • 理解享元模式的核心概念与形式化定义
  • 精确区分内部状态(Intrinsic State)与外部状态(Extrinsic State)
  • 掌握享元模式的多种Python实现技术
  • 分析享元模式的时间-空间权衡(Time-Space Trade-off)
  • 设计线程安全的享元对象与工厂
  • 识别享元模式的适用场景与反模式

11.1 模式定义

11.1.1 正式定义

享元模式(Flyweight Pattern) 是一种结构型设计模式,通过共享技术有效支持大量细粒度对象,以最小化内存使用和共享成本。

$$\text{Flyweight}: \mathcal{O} \xrightarrow{\text{share}} \mathcal{O}' \text{ where } |\mathcal{O}'| \ll |\mathcal{O}|$$

其中:

  • $\mathcal{O}$ 表示原始对象集合
  • $\mathcal{O}'$ 表示共享后的享元对象集合
  • $|\mathcal{O}'| \ll |\mathcal{O}|$ 表示享元对象数量远小于原始对象数量

内存节省公式

$$\text{Memory Saved} = n \times (S_{total} - S_{intrinsic}) - S_{factory}$$

其中:

  • $n$ 为对象实例数量
  • $S_{total}$ 为单个对象的总状态大小
  • $S_{intrinsic}$ 为内部状态大小
  • $S_{factory}$ 为享元工厂的开销

状态分离原则

$$\text{State}{total} = \text{State} \cup \text{State}_{extrinsic}$$

$$\text{State}{intrinsic} \cap \text{State} = \emptyset$$

11.1.2 历史背景与学术脉络

时期发展阶段关键贡献代表人物/文献
1987概念萌芽Smalltalk中的字符共享技术ParcPlace Systems
1994GoF正式定义《设计模式》收录为结构型模式Gamma, Helm, Johnson, Vlissides
1995理论深化对象池与享元的关系研究Schmidt, Coplien
2000Java应用Java字符串常量池的经典实现Sun Microsystems
2005.NET集成CLR字符串驻留机制Microsoft
2010游戏开发游戏引擎中的粒子系统优化Game Programming Gems
2015云原生容器镜像层的共享存储机制Docker, Kubernetes
2020现代演进Python __slots__ 与享元的结合Python Community

11.1.3 模式动机

问题场景:当系统中存在大量相似对象时,直接创建每个对象会产生以下问题:

  1. 内存爆炸:$n$ 个对象 $\times$ 每个对象大小 $s$ = 总内存 $n \times s$
  2. GC压力:大量短生命周期对象增加垃圾回收负担
  3. 缓存失效:CPU缓存命中率下降,性能降低

解决方案:将对象状态分为可共享的内部状态和不可共享的外部状态:

┌─────────────────────────────────────────────────────────────┐
│                      对象状态分解                            │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│    原始对象                    享元重构                      │
│    ┌──────────────┐           ┌──────────────┐              │
│    │  状态 A      │           │  内部状态    │ ← 共享       │
│    │  状态 B      │    →      │  (Intrinsic) │              │
│    │  状态 C      │           └──────────────┘              │
│    │  状态 D      │           ┌──────────────┐              │
│    └──────────────┘           │  外部状态    │ ← 不共享     │
│                               │  (Extrinsic) │              │
│    每对象独立存储              └──────────────┘              │
│                               分离存储,按需传递             │
│                                                             │
└─────────────────────────────────────────────────────────────┘

11.2 理论基础

11.2.1 内部状态与外部状态

内部状态(Intrinsic State)

  • 定义:存储在享元对象内部,可以共享的状态
  • 特性:不可变(Immutable)、线程安全、与上下文无关
  • 存储:享元对象内部
  • 生命周期:与享元对象相同

外部状态(Extrinsic State)

  • 定义:随上下文变化,不能共享的状态
  • 特性:可变(Mutable)、上下文相关、由客户端管理
  • 存储:客户端或上下文对象
  • 生命周期:由客户端控制

状态分类决策矩阵

状态特征内部状态外部状态
可变性不可变可变
共享性可共享不可共享
存储位置享元对象内客户端
线程安全天然安全需要同步
示例字体、颜色、纹理位置、速度、年龄

11.2.2 UML结构模型

┌─────────────────────────────────────────────────────────────────────────┐
│                         享元模式结构图                                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  ┌─────────────────────┐         ┌─────────────────────────────┐        │
│  │   <<interface>>     │         │     FlyweightFactory        │        │
│  │     Flyweight       │         ├─────────────────────────────┤        │
│  ├─────────────────────┤         │ - flyweights: Dict[K, F]    │        │
│  │ + operation(        │─────────│ - _lock: threading.Lock     │        │
│  │   extrinsic: Any)   │  uses   │ + get_flyweight(key): F     │        │
│  └─────────────────────┘         │ + list_flyweights(): List   │        │
│            △                      │ + count(): int              │        │
│            │                      └─────────────────────────────┘        │
│            │ implements                    │ creates                     │
│  ┌─────────┴──────────┐                    │ manages                      │
│  │ ConcreteFlyweight  │◄───────────────────┘                              │
│  ├────────────────────┤                                                  │
│  │ - intrinsic_state  │  ← 内部状态(共享)                               │
│  │ + operation(       │                                                  │
│  │   extrinsic: Any)  │  ← 外部状态通过参数传入                           │
│  └────────────────────┘                                                  │
│                                                                         │
│  ┌────────────────────┐                                                 │
│  │  UnsharedConcrete  │  ← 非共享具体享元(可选)                         │
│  │     Flyweight      │                                                 │
│  ├────────────────────┤                                                 │
│  │ - all_state        │                                                 │
│  └────────────────────┘                                                 │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

11.2.3 参与者职责

参与者职责Python实现要点
Flyweight声明接口,通过该接口享元可以接受并作用于外部状态ABC或Protocol
ConcreteFlyweight实现Flyweight接口,存储内部状态frozen dataclass确保不可变
UnsharedConcreteFlyweight不需要共享的享元子类普通类
FlyweightFactory创建并管理享元对象,确保合理共享类变量字典 + 线程锁
Client维护外部状态,使用享元对象调用工厂获取享元

11.3 Python实现

11.3.1 标准ABC实现

python
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
import threading
from dataclasses import dataclass


class Flyweight(ABC):
    """享元抽象基类"""
    
    @abstractmethod
    def operation(self, extrinsic_state: Any) -> str:
        """执行操作,接受外部状态"""
        pass
    
    @property
    @abstractmethod
    def intrinsic_state(self) -> Any:
        """获取内部状态"""
        pass


@dataclass(frozen=True)
class ConcreteFlyweight(Flyweight):
    """具体享元:存储内部状态"""
    _shared_state: str
    
    @property
    def intrinsic_state(self) -> str:
        return self._shared_state
    
    def operation(self, extrinsic_state: Any) -> str:
        return (
            f"享元操作: 内部状态[{self._shared_state}], "
            f"外部状态[{extrinsic_state}]"
        )


class FlyweightFactory:
    """享元工厂:管理享元对象池"""
    
    _flyweights: Dict[str, Flyweight] = {}
    _lock = threading.Lock()
    
    @classmethod
    def get_flyweight(cls, key: str) -> Flyweight:
        """获取享元对象(线程安全)"""
        if key not in cls._flyweights:
            with cls._lock:
                if key not in cls._flyweights:
                    cls._flyweights[key] = ConcreteFlyweight(key)
        return cls._flyweights[key]
    
    @classmethod
    def get_flyweight_with_state(cls, key: str, state: str) -> Flyweight:
        """获取或创建指定状态的享元"""
        if key not in cls._flyweights:
            with cls._lock:
                if key not in cls._flyweights:
                    cls._flyweights[key] = ConcreteFlyweight(state)
        return cls._flyweights[key]
    
    @classmethod
    def list_flyweights(cls) -> list:
        return list(cls._flyweights.keys())
    
    @classmethod
    def count(cls) -> int:
        return len(cls._flyweights)
    
    @classmethod
    def clear(cls) -> None:
        with cls._lock:
            cls._flyweights.clear()


class Client:
    """客户端:维护外部状态"""
    
    def __init__(self):
        self._extrinsic_states: Dict[str, Any] = {}
    
    def operation(self, key: str, extrinsic: Any) -> str:
        flyweight = FlyweightFactory.get_flyweight(key)
        return flyweight.operation(extrinsic)


if __name__ == "__main__":
    factory = FlyweightFactory
    
    fw1 = factory.get_flyweight("A")
    fw2 = factory.get_flyweight("A")
    fw3 = factory.get_flyweight("B")
    
    print(fw1.operation("外部状态1"))
    print(fw2.operation("外部状态2"))
    print(f"fw1 is fw2: {fw1 is fw2}")
    print(f"享元数量: {factory.count()}")

11.3.2 Protocol实现(结构化类型)

python
from typing import Protocol, Any, Dict, Tuple, runtime_checkable
import threading
from dataclasses import dataclass
import hashlib


@runtime_checkable
class FlyweightProtocol(Protocol):
    """享元协议:定义享元接口"""
    
    @property
    def intrinsic_state(self) -> Any: ...
    
    def operation(self, extrinsic_state: Any) -> str: ...


@dataclass(frozen=True)
class TreeType:
    """树木类型享元:不可变内部状态"""
    name: str
    color: str
    texture: str
    
    @property
    def intrinsic_state(self) -> Tuple[str, str, str]:
        return (self.name, self.color, self.texture)
    
    def operation(self, extrinsic_state: Tuple[int, int]) -> str:
        x, y = extrinsic_state
        return f"绘制{self.name}树在({x}, {y}), 颜色:{self.color}, 纹理:{self.texture}"
    
    def draw(self, x: int, y: int) -> str:
        return self.operation((x, y))


class TreeTypeFactory:
    """树木类型工厂:享元管理"""
    
    _tree_types: Dict[str, TreeType] = {}
    _lock = threading.Lock()
    
    @classmethod
    def _generate_key(cls, name: str, color: str, texture: str) -> str:
        """生成唯一键"""
        return f"{name}|{color}|{texture}"
    
    @classmethod
    def get_tree_type(cls, name: str, color: str, texture: str) -> TreeType:
        """获取或创建树木类型"""
        key = cls._generate_key(name, color, texture)
        if key not in cls._tree_types:
            with cls._lock:
                if key not in cls._tree_types:
                    cls._tree_types[key] = TreeType(name, color, texture)
        return cls._tree_types[key]
    
    @classmethod
    def count(cls) -> int:
        return len(cls._tree_types)
    
    @classmethod
    def list_types(cls) -> list:
        return list(cls._tree_types.values())


@dataclass
class Tree:
    """树木:外部状态容器"""
    x: int
    y: int
    tree_type: TreeType
    
    def draw(self) -> str:
        return self.tree_type.draw(self.x, self.y)


class Forest:
    """森林:管理大量树木"""
    
    def __init__(self):
        self._trees: list[Tree] = []
    
    def plant_tree(self, x: int, y: int, name: str, color: str, texture: str) -> None:
        tree_type = TreeTypeFactory.get_tree_type(name, color, texture)
        tree = Tree(x, y, tree_type)
        self._trees.append(tree)
    
    def draw(self) -> None:
        for tree in self._trees:
            print(tree.draw())
    
    def tree_count(self) -> int:
        return len(self._trees)


if __name__ == "__main__":
    forest = Forest()
    
    for i in range(1000):
        forest.plant_tree(i % 100, i // 100, "橡树", "绿色", "粗糙")
    
    for i in range(500):
        forest.plant_tree(i % 50, i // 50 + 10, "松树", "深绿", "光滑")
    
    print(f"树木总数: {forest.tree_count()}")
    print(f"树木类型数: {TreeTypeFactory.count()}")

11.3.3 Generic泛型实现

python
from typing import TypeVar, Generic, Dict, Hashable, Any, Optional
from dataclasses import dataclass
from abc import ABC, abstractmethod
import threading
from functools import lru_cache


K = TypeVar('K', bound=Hashable)
S = TypeVar('S')
E = TypeVar('E')


class FlyweightBase(ABC, Generic[S, E]):
    """享元基类(泛型)"""
    
    @property
    @abstractmethod
    def intrinsic_state(self) -> S: ...
    
    @abstractmethod
    def operation(self, extrinsic: E) -> str: ...


class FlyweightFactoryGeneric(Generic[K, S, E]):
    """泛型享元工厂"""
    
    def __init__(self, create_func: callable):
        self._flyweights: Dict[K, FlyweightBase[S, E]] = {}
        self._lock = threading.Lock()
        self._create_func = create_func
    
    def get(self, key: K) -> FlyweightBase[S, E]:
        if key not in self._flyweights:
            with self._lock:
                if key not in self._flyweights:
                    self._flyweights[key] = self._create_func(key)
        return self._flyweights[key]
    
    def get_or_create(self, key: K, state: S) -> FlyweightBase[S, E]:
        if key not in self._flyweights:
            with self._lock:
                if key not in self._flyweights:
                    self._flyweights[key] = self._create_func(key, state)
        return self._flyweights[key]
    
    def contains(self, key: K) -> bool:
        return key in self._flyweights
    
    def count(self) -> int:
        return len(self._flyweights)
    
    def clear(self) -> None:
        with self._lock:
            self._flyweights.clear()


@dataclass(frozen=True)
class CharacterStyle(FlyweightBase[tuple, tuple]):
    """字符样式享元"""
    font: str
    size: int
    color: str
    bold: bool = False
    italic: bool = False
    
    @property
    def intrinsic_state(self) -> tuple:
        return (self.font, self.size, self.color, self.bold, self.italic)
    
    def operation(self, extrinsic: tuple) -> str:
        char, x, y = extrinsic
        style = "bold" if self.bold else ""
        style += " italic" if self.italic else ""
        return f"'{char}' at ({x}, {y}) [{self.font} {self.size}pt {self.color} {style.strip()}]"


def create_character_style(key: tuple, state: Optional[tuple] = None) -> CharacterStyle:
    """字符样式工厂函数"""
    if state:
        return CharacterStyle(*state)
    return CharacterStyle(*key)


if __name__ == "__main__":
    factory = FlyweightFactoryGeneric[bytes, tuple, tuple](create_character_style)
    
    style1 = factory.get_or_create(
        (b'arial_12_black', ('Arial', 12, 'black', False, False))
    )
    style2 = factory.get_or_create(
        (b'arial_12_black', ('Arial', 12, 'black', False, False))
    )
    
    print(style1.operation(('A', 0, 0)))
    print(style1.operation(('B', 10, 0)))
    print(f"style1 is style2: {style1 is style2}")

11.3.4 使用__new__实现享元

python
from typing import Dict, Any, Optional, Tuple
import weakref


class FlyweightMeta(type):
    """享元元类:自动实现享元模式"""
    
    def __new__(mcs, name: str, bases: tuple, namespace: dict):
        cls = super().__new__(mcs, name, bases, namespace)
        cls._instances: Dict[Any, weakref.ref] = {}
        cls._lock = type('_lock', (), {'__enter__': lambda s: None, '__exit__': lambda s, *a: None})()
        return cls


class SharedObject(metaclass=FlyweightMeta):
    """基于元类的享元基类"""
    
    _instances: Dict[Any, 'SharedObject'] = {}
    _lock = None
    
    def __new__(cls, key: Any, *args, **kwargs):
        if key in cls._instances:
            instance = cls._instances[key]()
            if instance is not None:
                return instance
        
        instance = super().__new__(cls)
        instance._key = key
        instance._initialized = False
        cls._instances[key] = weakref.ref(instance)
        return instance
    
    def __init__(self, key: Any, value: Any = None):
        if self._initialized:
            return
        self._value = value
        self._initialized = True
    
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(key={self._key}, value={self._value})"


class Font(SharedObject):
    """字体享元"""
    
    def __init__(self, name: str, size: int = 12, bold: bool = False):
        if hasattr(self, '_initialized') and self._initialized:
            return
        super().__init__(name)
        self.name = name
        self.size = size
        self.bold = bold
    
    def render(self, text: str, x: int, y: int) -> str:
        weight = "bold" if self.bold else "normal"
        return f"[{self.name} {self.size}pt {weight}] '{text}' at ({x}, {y})"


class InternPool:
    """字符串驻留池:类似Java String.intern()"""
    
    _pool: Dict[str, str] = {}
    _lock = None
    
    @classmethod
    def intern(cls, s: str) -> str:
        """返回池中相同字符串的引用"""
        if s not in cls._pool:
            cls._pool[s] = s
        return cls._pool[s]
    
    @classmethod
    def stats(cls) -> dict:
        return {
            'pool_size': len(cls._pool),
            'memory_estimate': sum(len(s) for s in cls._pool)
        }


if __name__ == "__main__":
    font1 = Font("Arial", 12)
    font2 = Font("Arial", 14)
    font3 = Font("Arial", 12)
    
    print(font1.render("Hello", 0, 0))
    print(f"font1 is font3: {font1 is font3}")
    print(f"font1 is font2: {font1 is font2}")
    
    s1 = InternPool.intern("hello world")
    s2 = InternPool.intern("hello world")
    print(f"\n字符串驻留: s1 is s2 = {s1 is s2}")
    print(InternPool.stats())

11.3.5 使用__slots__优化内存

python
from typing import Dict, Any
from dataclasses import dataclass
import sys


@dataclass(frozen=True, slots=True)
class SlottedFlyweight:
    """使用slots的享元:进一步减少内存"""
    key: str
    value: Any
    
    def operation(self, extrinsic: Any) -> str:
        return f"Flyweight({self.key}): {self.value} + {extrinsic}"


class SlottedFlyweightFactory:
    """slots享元工厂"""
    
    __slots__ = ['_flyweights', '_lock']
    
    def __init__(self):
        self._flyweights: Dict[str, SlottedFlyweight] = {}
    
    def get(self, key: str, value: Any = None) -> SlottedFlyweight:
        if key not in self._flyweights:
            self._flyweights[key] = SlottedFlyweight(key, value or key)
        return self._flyweights[key]
    
    def count(self) -> int:
        return len(self._flyweights)


def memory_comparison():
    """内存使用对比"""
    
    class NormalClass:
        def __init__(self, key, value):
            self.key = key
            self.value = value
    
    @dataclass
    class DataclassNormal:
        key: str
        value: Any
    
    @dataclass(frozen=True, slots=True)
    class DataclassSlotted:
        key: str
        value: Any
    
    normal = NormalClass("test", 123)
    dc_normal = DataclassNormal("test", 123)
    dc_slotted = DataclassSlotted("test", 123)
    
    print(f"NormalClass: {sys.getsizeof(normal)} bytes")
    print(f"DataclassNormal: {sys.getsizeof(dc_normal)} bytes")
    print(f"DataclassSlotted: {sys.getsizeof(dc_slotted)} bytes")


if __name__ == "__main__":
    memory_comparison()

11.4 企业级应用示例

11.4.1 数据库连接池享元

python
from typing import Dict, Optional, Any
from dataclasses import dataclass
import threading
import time
from contextlib import contextmanager
from abc import ABC, abstractmethod


@dataclass(frozen=True)
class ConnectionConfig:
    """连接配置享元:不可变配置"""
    host: str
    port: int
    database: str
    user: str
    max_connections: int = 10
    timeout: float = 30.0
    
    @property
    def connection_string(self) -> str:
        return f"postgresql://{self.user}@{self.host}:{self.port}/{self.database}"
    
    @property
    def pool_key(self) -> str:
        return f"{self.host}:{self.port}:{self.database}:{self.user}"


class ConnectionPoolFlyweight:
    """连接池享元:管理共享连接"""
    
    def __init__(self, config: ConnectionConfig):
        self._config = config
        self._pool: list = []
        self._in_use: set = set()
        self._lock = threading.Lock()
        self._created = time.time()
        self._stats = {
            'total_created': 0,
            'total_reused': 0,
            'total_errors': 0
        }
    
    @property
    def config(self) -> ConnectionConfig:
        return self._config
    
    def _create_connection(self) -> Any:
        """模拟创建连接"""
        class MockConnection:
            def __init__(self, conn_str):
                self.conn_str = conn_str
                self.created = time.time()
                self._closed = False
            
            def execute(self, query: str) -> str:
                if self._closed:
                    raise RuntimeError("Connection closed")
                return f"Executed: {query}"
            
            def close(self):
                self._closed = True
            
            def is_alive(self) -> bool:
                return not self._closed
        
        self._stats['total_created'] += 1
        return MockConnection(self._config.connection_string)
    
    def acquire(self) -> Any:
        """获取连接"""
        with self._lock:
            if self._pool:
                conn = self._pool.pop()
                self._in_use.add(conn)
                self._stats['total_reused'] += 1
                return conn
            
            if len(self._in_use) < self._config.max_connections:
                conn = self._create_connection()
                self._in_use.add(conn)
                return conn
            
            raise RuntimeError("Connection pool exhausted")
    
    def release(self, conn: Any) -> None:
        """释放连接"""
        with self._lock:
            if conn in self._in_use:
                self._in_use.remove(conn)
                if conn.is_alive():
                    self._pool.append(conn)
    
    @contextmanager
    def connection(self):
        """上下文管理器"""
        conn = self.acquire()
        try:
            yield conn
        finally:
            self.release(conn)
    
    def stats(self) -> dict:
        return {
            **self._stats,
            'pool_size': len(self._pool),
            'in_use': len(self._in_use),
            'uptime': time.time() - self._created
        }
    
    def close_all(self) -> None:
        with self._lock:
            for conn in self._pool:
                conn.close()
            for conn in self._in_use:
                conn.close()
            self._pool.clear()
            self._in_use.clear()


class ConnectionPoolFactory:
    """连接池工厂:享元管理"""
    
    _pools: Dict[str, ConnectionPoolFlyweight] = {}
    _lock = threading.Lock()
    
    @classmethod
    def get_pool(cls, config: ConnectionConfig) -> ConnectionPoolFlyweight:
        key = config.pool_key
        if key not in cls._pools:
            with cls._lock:
                if key not in cls._pools:
                    cls._pools[key] = ConnectionPoolFlyweight(config)
        return cls._pools[key]
    
    @classmethod
    def get_pool_by_params(
        cls,
        host: str,
        port: int,
        database: str,
        user: str,
        **kwargs
    ) -> ConnectionPoolFlyweight:
        config = ConnectionConfig(host, port, database, user, **kwargs)
        return cls.get_pool(config)
    
    @classmethod
    def list_pools(cls) -> list:
        return list(cls._pools.keys())
    
    @classmethod
    def pool_count(cls) -> int:
        return len(cls._pools)
    
    @classmethod
    def close_all(cls) -> None:
        for pool in cls._pools.values():
            pool.close_all()
        cls._pools.clear()


class DatabaseService:
    """数据库服务:使用享元连接池"""
    
    def __init__(self, config: ConnectionConfig):
        self._pool = ConnectionPoolFactory.get_pool(config)
    
    def execute_query(self, query: str) -> str:
        with self._pool.connection() as conn:
            return conn.execute(query)
    
    def get_stats(self) -> dict:
        return self._pool.stats()


if __name__ == "__main__":
    config1 = ConnectionConfig("localhost", 5432, "mydb", "admin")
    config2 = ConnectionConfig("localhost", 5432, "mydb", "admin")
    config3 = ConnectionConfig("remotehost", 5432, "mydb", "admin")
    
    service1 = DatabaseService(config1)
    service2 = DatabaseService(config2)
    service3 = DatabaseService(config3)
    
    print(service1.execute_query("SELECT * FROM users"))
    print(service2.execute_query("SELECT * FROM orders"))
    
    print(f"\n连接池数量: {ConnectionPoolFactory.pool_count()}")
    print(f"池1统计: {service1.get_stats()}")
    print(f"service1和service2共享同一池: {service1._pool is service2._pool}")

11.4.2 图形渲染资源享元

python
from typing import Dict, Tuple, Optional, Any
from dataclasses import dataclass
from enum import Enum
from abc import ABC, abstractmethod
import hashlib


class TextureFormat(Enum):
    PNG = "png"
    JPEG = "jpeg"
    WEBP = "webp"


@dataclass(frozen=True)
class Texture:
    """纹理享元:GPU纹理资源"""
    name: str
    width: int
    height: int
    format: TextureFormat
    data_hash: str
    
    @property
    def memory_size(self) -> int:
        bytes_per_pixel = 4 if self.format != TextureFormat.JPEG else 3
        return self.width * self.height * bytes_per_pixel
    
    def bind(self, slot: int) -> str:
        return f"绑定纹理 '{self.name}' 到槽位 {slot}"


@dataclass(frozen=True)
class Shader:
    """着色器享元:编译后的着色器程序"""
    name: str
    vertex_source_hash: str
    fragment_source_hash: str
    uniform_count: int
    
    def use(self) -> str:
        return f"使用着色器程序 '{self.name}'"
    
    def set_uniform(self, name: str, value: Any) -> str:
        return f"设置 uniform '{name}' = {value}"


@dataclass(frozen=True)
class Material:
    """材质享元:渲染材质"""
    name: str
    shader: Shader
    albedo_texture: Optional[Texture]
    normal_texture: Optional[Texture]
    roughness: float
    metallic: float
    
    def apply(self) -> list:
        commands = [self.shader.use()]
        if self.albedo_texture:
            commands.append(self.albedo_texture.bind(0))
        if self.normal_texture:
            commands.append(self.normal_texture.bind(1))
        commands.append(self.shader.set_uniform("roughness", self.roughness))
        commands.append(self.shader.set_uniform("metallic", self.metallic))
        return commands


class ResourceFactory:
    """资源工厂:享元管理"""
    
    _textures: Dict[str, Texture] = {}
    _shaders: Dict[str, Shader] = {}
    _materials: Dict[str, Material] = {}
    
    @classmethod
    def create_texture(
        cls,
        name: str,
        width: int,
        height: int,
        format: TextureFormat,
        data: bytes = None
    ) -> Texture:
        data_hash = hashlib.md5(data or b'').hexdigest() if data else "empty"
        key = f"{name}_{width}x{height}_{format.value}_{data_hash}"
        
        if key not in cls._textures:
            cls._textures[key] = Texture(name, width, height, format, data_hash)
        return cls._textures[key]
    
    @classmethod
    def create_shader(
        cls,
        name: str,
        vertex_source: str,
        fragment_source: str
    ) -> Shader:
        v_hash = hashlib.md5(vertex_source.encode()).hexdigest()[:8]
        f_hash = hashlib.md5(fragment_source.encode()).hexdigest()[:8]
        key = f"{name}_{v_hash}_{f_hash}"
        
        if key not in cls._shaders:
            uniform_count = vertex_source.count("uniform") + fragment_source.count("uniform")
            cls._shaders[key] = Shader(name, v_hash, f_hash, uniform_count)
        return cls._shaders[key]
    
    @classmethod
    def create_material(
        cls,
        name: str,
        shader: Shader,
        albedo: Optional[Texture] = None,
        normal: Optional[Texture] = None,
        roughness: float = 0.5,
        metallic: float = 0.0
    ) -> Material:
        key = f"{name}_{shader.name}"
        
        if key not in cls._materials:
            cls._materials[key] = Material(
                name, shader, albedo, normal, roughness, metallic
            )
        return cls._materials[key]
    
    @classmethod
    def stats(cls) -> dict:
        texture_memory = sum(t.memory_size for t in cls._textures.values())
        return {
            'textures': len(cls._textures),
            'texture_memory_bytes': texture_memory,
            'texture_memory_mb': texture_memory / (1024 * 1024),
            'shaders': len(cls._shaders),
            'materials': len(cls._materials)
        }


class Mesh:
    """网格:外部状态"""
    
    def __init__(self, vertices: list, indices: list):
        self.vertices = vertices
        self.indices = indices
        self.position = (0.0, 0.0, 0.0)
        self.rotation = (0.0, 0.0, 0.0)
        self.scale = (1.0, 1.0, 1.0)
        self.material: Optional[Material] = None
    
    def render(self) -> str:
        return f"渲染网格: {len(self.vertices)} 顶点, 位置: {self.position}"


class Scene:
    """场景:管理渲染对象"""
    
    def __init__(self):
        self._meshes: list[Mesh] = []
    
    def add_mesh(self, mesh: Mesh) -> None:
        self._meshes.append(mesh)
    
    def render(self) -> list:
        commands = []
        for mesh in self._meshes:
            if mesh.material:
                commands.extend(mesh.material.apply())
            commands.append(mesh.render())
        return commands


if __name__ == "__main__":
    basic_shader = ResourceFactory.create_shader(
        "basic",
        "uniform mat4 model; void main() {}",
        "uniform vec3 color; void main() {}"
    )
    
    pbr_shader = ResourceFactory.create_shader(
        "pbr",
        "uniform mat4 model; uniform mat4 view; void main() {}",
        "uniform float roughness; uniform float metallic; void main() {}"
    )
    
    wood_texture = ResourceFactory.create_texture(
        "wood", 1024, 1024, TextureFormat.PNG, b"wood_data"
    )
    
    metal_material = ResourceFactory.create_material(
        "metal", pbr_shader, roughness=0.2, metallic=1.0
    )
    
    wood_material = ResourceFactory.create_material(
        "wood", pbr_shader, albedo=wood_texture, roughness=0.8, metallic=0.0
    )
    
    mesh1 = Mesh([(0, 0, 0), (1, 0, 0), (0, 1, 0)], [0, 1, 2])
    mesh1.material = metal_material
    mesh1.position = (10, 0, 5)
    
    mesh2 = Mesh([(0, 0, 0), (1, 0, 0), (0, 1, 0)], [0, 1, 2])
    mesh2.material = wood_material
    mesh2.position = (20, 0, 5)
    
    scene = Scene()
    scene.add_mesh(mesh1)
    scene.add_mesh(mesh2)
    
    for cmd in scene.render():
        print(cmd)
    
    print(f"\n资源统计: {ResourceFactory.stats()}")

11.4.3 缓存系统享元

python
from typing import Dict, Any, Optional, Generic, TypeVar, Callable
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from abc import ABC, abstractmethod
import threading
import json
import hashlib


T = TypeVar('T')


@dataclass(frozen=True)
class CacheKey:
    """缓存键享元"""
    namespace: str
    key: str
    
    @property
    def full_key(self) -> str:
        return f"{self.namespace}:{self.key}"
    
    @property
    def hash(self) -> str:
        return hashlib.md5(self.full_key.encode()).hexdigest()


@dataclass
class CacheEntry(Generic[T]):
    """缓存条目"""
    key: CacheKey
    value: T
    created_at: datetime = field(default_factory=datetime.now)
    expires_at: Optional[datetime] = None
    hits: int = 0
    size_bytes: int = 0
    
    def is_expired(self) -> bool:
        if self.expires_at is None:
            return False
        return datetime.now() > self.expires_at
    
    def touch(self) -> None:
        self.hits += 1


class CachePolicy(ABC):
    """缓存策略接口"""
    
    @abstractmethod
    def should_evict(self, entry: CacheEntry) -> bool:
        pass
    
    @abstractmethod
    def select_victim(self, entries: list) -> Optional[CacheEntry]:
        pass


class LRUPolicy(CachePolicy):
    """LRU策略"""
    
    def should_evict(self, entry: CacheEntry) -> bool:
        return entry.is_expired()
    
    def select_victim(self, entries: list) -> Optional[CacheEntry]:
        if not entries:
            return None
        return min(entries, key=lambda e: e.hits)


class TTLPolicy(CachePolicy):
    """TTL策略"""
    
    def __init__(self, max_age: timedelta):
        self.max_age = max_age
    
    def should_evict(self, entry: CacheEntry) -> bool:
        age = datetime.now() - entry.created_at
        return age > self.max_age or entry.is_expired()
    
    def select_victim(self, entries: list) -> Optional[CacheEntry]:
        expired = [e for e in entries if self.should_evict(e)]
        return expired[0] if expired else None


class FlyweightCache(Generic[T]):
    """享元缓存:共享缓存实例"""
    
    _instances: Dict[str, 'FlyweightCache'] = {}
    _lock = threading.Lock()
    
    def __init__(
        self,
        namespace: str,
        max_size: int = 1000,
        max_memory_mb: float = 100.0,
        policy: Optional[CachePolicy] = None
    ):
        self._namespace = namespace
        self._max_size = max_size
        self._max_memory = max_memory_mb * 1024 * 1024
        self._policy = policy or LRUPolicy()
        self._entries: Dict[str, CacheEntry[T]] = {}
        self._entry_lock = threading.Lock()
        self._stats = {
            'hits': 0,
            'misses': 0,
            'evictions': 0,
            'memory_used': 0
        }
    
    @classmethod
    def get_cache(
        cls,
        namespace: str,
        max_size: int = 1000,
        max_memory_mb: float = 100.0,
        policy: Optional[CachePolicy] = None
    ) -> 'FlyweightCache':
        if namespace not in cls._instances:
            with cls._lock:
                if namespace not in cls._instances:
                    cls._instances[namespace] = FlyweightCache(
                        namespace, max_size, max_memory_mb, policy
                    )
        return cls._instances[namespace]
    
    def _make_key(self, key: str) -> CacheKey:
        return CacheKey(self._namespace, key)
    
    def _estimate_size(self, value: Any) -> int:
        try:
            return len(json.dumps(value))
        except:
            return 64
    
    def _evict_if_needed(self) -> None:
        while (
            len(self._entries) >= self._max_size or
            self._stats['memory_used'] >= self._max_memory
        ):
            entries = list(self._entries.values())
            victim = self._policy.select_victim(entries)
            if victim:
                del self._entries[victim.key.full_key]
                self._stats['memory_used'] -= victim.size_bytes
                self._stats['evictions'] += 1
            else:
                break
    
    def get(self, key: str) -> Optional[T]:
        cache_key = self._make_key(key)
        with self._entry_lock:
            if cache_key.full_key in self._entries:
                entry = self._entries[cache_key.full_key]
                if not entry.is_expired():
                    entry.touch()
                    self._stats['hits'] += 1
                    return entry.value
                else:
                    del self._entries[cache_key.full_key]
            self._stats['misses'] += 1
            return None
    
    def set(
        self,
        key: str,
        value: T,
        ttl: Optional[timedelta] = None
    ) -> None:
        cache_key = self._make_key(key)
        size = self._estimate_size(value)
        
        expires_at = datetime.now() + ttl if ttl else None
        entry = CacheEntry(
            key=cache_key,
            value=value,
            expires_at=expires_at,
            size_bytes=size
        )
        
        with self._entry_lock:
            self._evict_if_needed()
            if cache_key.full_key in self._entries:
                old = self._entries[cache_key.full_key]
                self._stats['memory_used'] -= old.size_bytes
            self._entries[cache_key.full_key] = entry
            self._stats['memory_used'] += size
    
    def delete(self, key: str) -> bool:
        cache_key = self._make_key(key)
        with self._entry_lock:
            if cache_key.full_key in self._entries:
                entry = self._entries.pop(cache_key.full_key)
                self._stats['memory_used'] -= entry.size_bytes
                return True
            return False
    
    def clear(self) -> None:
        with self._entry_lock:
            self._entries.clear()
            self._stats['memory_used'] = 0
    
    def stats(self) -> dict:
        total = self._stats['hits'] + self._stats['misses']
        hit_rate = self._stats['hits'] / total if total > 0 else 0
        return {
            **self._stats,
            'entries': len(self._entries),
            'hit_rate': f"{hit_rate:.2%}",
            'memory_mb': self._stats['memory_used'] / (1024 * 1024)
        }
    
    @classmethod
    def list_caches(cls) -> list:
        return list(cls._instances.keys())
    
    @classmethod
    def global_stats(cls) -> dict:
        return {
            'cache_count': len(cls._instances),
            'caches': {ns: cache.stats() for ns, cache in cls._instances.items()}
        }


if __name__ == "__main__":
    user_cache = FlyweightCache.get_cache("users", max_size=100, policy=LRUPolicy())
    session_cache = FlyweightCache.get_cache(
        "sessions",
        max_size=1000,
        policy=TTLPolicy(timedelta(minutes=30))
    )
    
    user_cache.set("user:1", {"name": "Alice", "email": "alice@example.com"})
    user_cache.set("user:2", {"name": "Bob", "email": "bob@example.com"})
    
    session_cache.set("session:abc", {"user_id": 1, "role": "admin"})
    
    print(f"用户缓存命中: {user_cache.get('user:1')}")
    print(f"用户缓存未命中: {user_cache.get('user:999')}")
    
    print(f"\n用户缓存统计: {user_cache.stats()}")
    print(f"会话缓存统计: {session_cache.stats()}")
    print(f"\n全局统计: {FlyweightCache.global_stats()}")

11.5 模式变体与扩展

11.5.1 复合享元模式

python
from typing import Dict, List, Any, Iterator
from dataclasses import dataclass
from abc import ABC, abstractmethod


class Flyweight(ABC):
    """享元接口"""
    
    @abstractmethod
    def operation(self, extrinsic: Any) -> str:
        pass


@dataclass(frozen=True)
class ConcreteFlyweight(Flyweight):
    """具体享元"""
    _state: str
    
    def operation(self, extrinsic: Any) -> str:
        return f"享元[{self._state}] + 外部[{extrinsic}]"


class CompositeFlyweight(Flyweight):
    """复合享元:组合多个享元"""
    
    def __init__(self):
        self._flyweights: Dict[str, Flyweight] = {}
    
    def add(self, key: str, flyweight: Flyweight) -> None:
        self._flyweights[key] = flyweight
    
    def remove(self, key: str) -> None:
        self._flyweights.pop(key, None)
    
    def operation(self, extrinsic: Dict[str, Any]) -> List[str]:
        results = []
        for key, flyweight in self._flyweights.items():
            ext = extrinsic.get(key, None)
            results.append(flyweight.operation(ext))
        return results
    
    def count(self) -> int:
        return len(self._flyweights)


class FlyweightFactory:
    """享元工厂:支持复合享元"""
    
    _flyweights: Dict[str, Flyweight] = {}
    _composites: Dict[str, CompositeFlyweight] = {}
    
    @classmethod
    def get_flyweight(cls, key: str) -> Flyweight:
        if key not in cls._flyweights:
            cls._flyweights[key] = ConcreteFlyweight(key)
        return cls._flyweights[key]
    
    @classmethod
    def create_composite(cls, name: str, keys: List[str]) -> CompositeFlyweight:
        if name not in cls._composites:
            composite = CompositeFlyweight()
            for key in keys:
                composite.add(key, cls.get_flyweight(key))
            cls._composites[name] = composite
        return cls._composites[name]
    
    @classmethod
    def get_composite(cls, name: str) -> CompositeFlyweight:
        return cls._composites.get(name)


if __name__ == "__main__":
    composite = FlyweightFactory.create_composite(
        "style_group",
        ["font", "color", "size"]
    )
    
    extrinsic = {
        "font": "Arial",
        "color": "Red",
        "size": 14
    }
    
    for result in composite.operation(extrinsic):
        print(result)

11.5.2 享元池模式

python
from typing import Dict, List, Optional, TypeVar, Generic
from dataclasses import dataclass
from abc import ABC, abstractmethod
import threading
import queue


T = TypeVar('T')


class Poolable(ABC):
    """可池化对象接口"""
    
    @abstractmethod
    def reset(self) -> None:
        """重置对象状态"""
        pass
    
    @abstractmethod
    def is_valid(self) -> bool:
        """检查对象是否有效"""
        pass


class FlyweightPool(Generic[T]):
    """享元池:对象池与享元结合"""
    
    def __init__(self, factory: callable, max_size: int = 100):
        self._factory = factory
        self._max_size = max_size
        self._pool: queue.Queue = queue.Queue(maxsize=max_size)
        self._all_objects: Dict[int, T] = {}
        self._lock = threading.Lock()
        self._created = 0
    
    def acquire(self) -> T:
        try:
            obj = self._pool.get_nowait()
            if hasattr(obj, 'is_valid') and not obj.is_valid():
                return self._create_new()
            return obj
        except queue.Empty:
            return self._create_new()
    
    def release(self, obj: T) -> None:
        if hasattr(obj, 'reset'):
            obj.reset()
        try:
            self._pool.put_nowait(obj)
        except queue.Full:
            pass
    
    def _create_new(self) -> T:
        with self._lock:
            if self._created < self._max_size:
                obj = self._factory()
                self._all_objects[id(obj)] = obj
                self._created += 1
                return obj
        raise RuntimeError("Pool exhausted")
    
    def stats(self) -> dict:
        return {
            'created': self._created,
            'available': self._pool.qsize(),
            'in_use': self._created - self._pool.qsize(),
            'max_size': self._max_size
        }


@dataclass
class StringBuilder(Poolable):
    """可池化的字符串构建器"""
    _buffer: List[str] = None
    
    def __post_init__(self):
        if self._buffer is None:
            self._buffer = []
    
    def append(self, s: str) -> 'StringBuilder':
        self._buffer.append(s)
        return self
    
    def build(self) -> str:
        return ''.join(self._buffer)
    
    def reset(self) -> None:
        self._buffer.clear()
    
    def is_valid(self) -> bool:
        return True


if __name__ == "__main__":
    pool = FlyweightPool[StringBuilder](
        factory=lambda: StringBuilder(),
        max_size=10
    )
    
    sb1 = pool.acquire()
    sb1.append("Hello").append(" ").append("World")
    print(sb1.build())
    pool.release(sb1)
    
    sb2 = pool.acquire()
    sb2.append("New").append(" Message")
    print(sb2.build())
    
    print(f"池统计: {pool.stats()}")
    print(f"sb1 is sb2: {sb1 is sb2}")

11.5.3 弱引用享元

python
from typing import Dict, Any, Optional
from dataclasses import dataclass
import weakref
import gc


@dataclass(frozen=True)
class HeavyResource:
    """重型资源享元"""
    resource_id: str
    data: str
    
    def __repr__(self):
        return f"HeavyResource({self.resource_id})"


class WeakRefFlyweightFactory:
    """弱引用享元工厂:自动回收未使用的享元"""
    
    _flyweights: Dict[str, weakref.ref] = {}
    _lock = None
    
    @classmethod
    def get(cls, key: str, data: str = None) -> HeavyResource:
        if key in cls._flyweights:
            ref = cls._flyweights[key]
            obj = ref()
            if obj is not None:
                return obj
        
        obj = HeavyResource(key, data or f"data_for_{key}")
        cls._flyweights[key] = weakref.ref(obj)
        return obj
    
    @classmethod
    def count(cls) -> int:
        alive = sum(1 for ref in cls._flyweights.values() if ref() is not None)
        return alive
    
    @classmethod
    def cleanup(cls) -> int:
        dead_keys = [k for k, ref in cls._flyweights.items() if ref() is None]
        for k in dead_keys:
            del cls._flyweights[k]
        return len(dead_keys)


class StrongRefFlyweightFactory:
    """强引用享元工厂:手动管理生命周期"""
    
    _flyweights: Dict[str, HeavyResource] = {}
    _ref_counts: Dict[str, int] = {}
    
    @classmethod
    def acquire(cls, key: str, data: str = None) -> HeavyResource:
        if key not in cls._flyweights:
            cls._flyweights[key] = HeavyResource(key, data or f"data_for_{key}")
            cls._ref_counts[key] = 0
        cls._ref_counts[key] += 1
        return cls._flyweights[key]
    
    @classmethod
    def release(cls, key: str) -> None:
        if key in cls._ref_counts:
            cls._ref_counts[key] -= 1
            if cls._ref_counts[key] <= 0:
                del cls._flyweights[key]
                del cls._ref_counts[key]
    
    @classmethod
    def count(cls) -> int:
        return len(cls._flyweights)


if __name__ == "__main__":
    print("=== 弱引用享元 ===")
    r1 = WeakRefFlyweightFactory.get("A", "data_A")
    r2 = WeakRefFlyweightFactory.get("A", "data_A")
    print(f"r1 is r2: {r1 is r2}")
    print(f"活跃数量: {WeakRefFlyweightFactory.count()}")
    
    del r1, r2
    gc.collect()
    print(f"删除后活跃数量: {WeakRefFlyweightFactory.count()}")
    
    print("\n=== 强引用享元 ===")
    s1 = StrongRefFlyweightFactory.acquire("B", "data_B")
    s2 = StrongRefFlyweightFactory.acquire("B", "data_B")
    print(f"s1 is s2: {s1 is s2}")
    print(f"引用计数: {StrongRefFlyweightFactory._ref_counts.get('B', 0)}")
    
    StrongRefFlyweightFactory.release("B")
    print(f"释放一次后数量: {StrongRefFlyweightFactory.count()}")
    StrongRefFlyweightFactory.release("B")
    print(f"全部释放后数量: {StrongRefFlyweightFactory.count()}")

11.6 反模式与最佳实践

11.6.1 常见反模式

反模式1:状态泄漏

python
from dataclasses import dataclass
from typing import Dict


@dataclass
class BadFlyweight:
    """错误示例:可变的内部状态"""
    shared_data: str
    _cache: Dict = None
    
    def __post_init__(self):
        if self._cache is None:
            self._cache = {}
    
    def operation(self, key: str, value: str) -> str:
        self._cache[key] = value
        return f"{self.shared_data}: {self._cache}"


@dataclass(frozen=True)
class GoodFlyweight:
    """正确示例:不可变内部状态"""
    shared_data: str
    
    def operation(self, key: str, value: str, external_cache: Dict) -> str:
        return f"{self.shared_data}: {external_cache}"


if __name__ == "__main__":
    bad = BadFlyweight("shared")
    bad.operation("a", "1")
    bad.operation("b", "2")
    print(f"状态泄漏: {bad._cache}")
    
    good = GoodFlyweight("shared")
    cache = {}
    good.operation("a", "1", cache)
    good.operation("b", "2", cache)
    print(f"外部状态: {cache}")

反模式2:过度享元化

python
from typing import Dict, Tuple
from dataclasses import dataclass


@dataclass(frozen=True)
class OverEngineeredPoint:
    """过度设计:每个坐标都享元化"""
    x: float
    y: float


class PointFactory:
    _points: Dict[Tuple[float, float], OverEngineeredPoint] = {}
    
    @classmethod
    def get_point(cls, x: float, y: float) -> OverEngineeredPoint:
        key = (x, y)
        if key not in cls._points:
            cls._points[key] = OverEngineeredPoint(x, y)
        return cls._points[key]


@dataclass
class SimplePoint:
    """简单方案:直接创建"""
    x: float
    y: float


if __name__ == "__main__":
    import sys
    
    points_over = [PointFactory.get_point(i, i) for i in range(1000)]
    points_simple = [SimplePoint(i, i) for i in range(1000)]
    
    print(f"享元化对象数: {len(PointFactory._points)}")
    print(f"简单对象数: {len(points_simple)}")
    print(f"享元工厂开销: {sys.getsizeof(PointFactory._points)} bytes")

反模式3:线程不安全的工厂

python
from typing import Dict
from dataclasses import dataclass
import threading


@dataclass(frozen=True)
class Resource:
    name: str


class UnsafeFactory:
    """错误示例:非线程安全"""
    _resources: Dict[str, Resource] = {}
    
    @classmethod
    def get(cls, name: str) -> Resource:
        if name not in cls._resources:
            cls._resources[name] = Resource(name)
        return cls._resources[name]


class SafeFactory:
    """正确示例:线程安全"""
    _resources: Dict[str, Resource] = {}
    _lock = threading.Lock()
    
    @classmethod
    def get(cls, name: str) -> Resource:
        if name not in cls._resources:
            with cls._lock:
                if name not in cls._resources:
                    cls._resources[name] = Resource(name)
        return cls._resources[name]


def test_thread_safety():
    errors = []
    
    def worker(factory_class, name):
        try:
            for _ in range(1000):
                obj = factory_class.get(name)
                if obj.name != name:
                    errors.append(f"Wrong object: {obj.name}")
        except Exception as e:
            errors.append(str(e))
    
    threads = [
        threading.Thread(target=worker, args=(SafeFactory, "test"))
        for _ in range(10)
    ]
    
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    
    print(f"线程安全测试: {'通过' if not errors else '失败'}")
    print(f"创建对象数: {len(SafeFactory._resources)}")

11.6.2 最佳实践清单

实践说明代码示例
不可变内部状态使用frozen=True确保线程安全@dataclass(frozen=True)
线程安全工厂使用双重检查锁定with lock: if key not in cache:
合理键设计键应唯一标识内部状态`f"
内存监控跟踪享元数量和内存使用factory.stats()
生命周期管理提供清理机制factory.clear()
弱引用可选大对象考虑弱引用weakref.ref(obj)
外部状态分离客户端管理可变状态通过参数传递
文档化状态明确标注内部/外部状态类型注解 + 文档字符串

11.7 决策指南

11.7.1 是否使用享元模式?

┌─────────────────────────────────────────────────────────────┐
│                    享元模式决策树                            │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  问题:是否存在大量相似对象?                                │
│         │                                                   │
│         ├── 否 ──→ 不需要享元模式                           │
│         │                                                   │
│         └── 是                                              │
│              │                                              │
│              ▼                                              │
│  问题:对象状态是否可分离为内部/外部?                       │
│         │                                                   │
│         ├── 否 ──→ 考虑对象池模式                           │
│         │                                                   │
│         └── 是                                              │
│              │                                              │
│              ▼                                              │
│  问题:内部状态是否真正可共享?                              │
│         │                                                   │
│         ├── 否 ──→ 重新设计状态分离                         │
│         │                                                   │
│         └── 是                                              │
│              │                                              │
│              ▼                                              │
│  问题:内存节省是否大于工厂开销?                            │
│         │                                                   │
│         ├── 否 ──→ 不值得使用享元                           │
│         │                                                   │
│         └── 是 ──→ ✓ 使用享元模式                           │
│                                                             │
└─────────────────────────────────────────────────────────────┘

11.7.2 实现技术选择

场景推荐实现理由
简单享元frozen dataclass + 字典工厂简洁、类型安全
需要接口约束ABC + frozen dataclass强类型约束
结构化类型Protocol + frozen dataclass灵活、鸭子类型
自动享元化元类 + __new__透明、自动化
内存敏感__slots__ + frozen dataclass最小内存占用
大对象池弱引用享元自动垃圾回收
高并发线程安全工厂 + 双重检查锁避免竞态条件

11.7.3 与其他模式的关系

模式关系协作方式
工厂方法创建关系工厂创建享元对象
单例特例关系享元工厂通常是单例
组合结构关系复合享元使用组合
对象池替代关系对象池重用实例,享元共享状态
原型对比关系原型复制对象,享元共享对象
策略状态关系策略对象可作为享元

11.8 快速参考卡片

┌─────────────────────────────────────────────────────────────────────────┐
│                        享元模式速查表                                    │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  定义: 运用共享技术有效支持大量细粒度对象                                │
│                                                                         │
│  核心公式:                                                               │
│    Memory Saved = n × (S_total - S_intrinsic) - S_factory              │
│                                                                         │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                         │
│  参与者:                                                                 │
│    • Flyweight         → 声明接口                                       │
│    • ConcreteFlyweight → 存储内部状态(不可变)                         │
│    • FlyweightFactory  → 创建和管理享元                                 │
│    • Client            → 维护外部状态                                   │
│                                                                         │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                         │
│  状态分类:                                                               │
│    内部状态: 可共享、不可变、存储在享元内                                │
│    外部状态: 不可共享、可变、由客户端管理                                │
│                                                                         │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                         │
│  Python实现要点:                                                         │
│    @dataclass(frozen=True)  # 确保不可变                                │
│    class ConcreteFlyweight:                                             │
│        ...                                                              │
│                                                                         │
│    class FlyweightFactory:                                              │
│        _flyweights: Dict[str, Flyweight] = {}                          │
│        _lock = threading.Lock()                                        │
│                                                                         │
│        @classmethod                                                     │
│        def get(cls, key):                                               │
│            if key not in cls._flyweights:                              │
│                with cls._lock:                                          │
│                    if key not in cls._flyweights:                      │
│                        cls._flyweights[key] = Flyweight(key)           │
│            return cls._flyweights[key]                                 │
│                                                                         │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                         │
│  适用场景:                                                               │
│    ✓ 系统中存在大量相似对象                                              │
│    ✓ 对象的大部分状态可以外部化                                          │
│    ✓ 需要缓冲池的场景                                                    │
│    ✓ 内存是关键约束                                                      │
│                                                                         │
│  不适用场景:                                                             │
│    ✗ 对象状态不可分离                                                    │
│    ✗ 对象数量较少                                                        │
│    ✗ 外部状态传递开销大于内存节省                                        │
│                                                                         │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                         │
│  经典案例:                                                               │
│    • Java String.intern()                                               │
│    • Python 字符串驻留                                                   │
│    • 文字编辑器字符样式                                                   │
│    • 游戏粒子系统                                                        │
│    • 数据库连接池                                                        │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

11.9 思考与实践

11.9.1 思考题

  1. 概念辨析:享元模式与单例模式有何本质区别?什么情况下两者可以结合使用?

  2. 状态分离:在设计一个图形编辑器时,如何确定哪些属性属于内部状态,哪些属于外部状态?

  3. 性能权衡:享元模式通过共享减少内存使用,但增加了查找和传递外部状态的开销。如何量化分析这种时间-空间权衡?

  4. 线程安全:为什么享元对象必须是不可变的?如果内部状态需要偶尔更新,应该如何设计?

  5. 内存管理:比较强引用享元和弱引用享元的适用场景,分析各自的优缺点。

11.9.2 实践练习

练习1:实现一个字体缓存系统

设计一个字体缓存系统,支持:

  • 字体名称、大小、样式(粗体、斜体)作为内部状态
  • 渲染位置、文本内容作为外部状态
  • 统计缓存命中率和内存使用

练习2:优化游戏地图渲染

给定一个包含100万个瓦片的游戏地图:

  • 瓦片类型:草地、沙漠、水域、山脉、森林(共5种)
  • 每个瓦片需要存储位置坐标和类型
  • 使用享元模式优化内存使用,计算节省的内存

练习3:实现多级缓存享元

设计一个支持L1/L2两级缓存的享元系统:

  • L1缓存使用强引用,容量有限
  • L2缓存使用弱引用,容量不限
  • 实现缓存淘汰策略

11.10 小结

享元模式是一种以空间换时间的优化模式,通过共享不可变的内部状态来减少内存占用。本章深入探讨了:

  1. 理论基础:内部状态与外部状态的精确区分,形式化定义与内存节省公式
  2. 实现技术:ABC、Protocol、Generic、元类、__slots__等多种Python实现方式
  3. 企业应用:数据库连接池、图形资源管理、缓存系统等实际案例
  4. 模式变体:复合享元、享元池、弱引用享元等扩展形式
  5. 最佳实践:避免状态泄漏、过度享元化等反模式,掌握线程安全设计

享元模式的核心价值在于:当系统中存在大量细粒度对象且状态可分离时,通过共享技术显著降低内存消耗,提升系统性能。


参考资料

  1. Gamma, E., et al. Design Patterns: Elements of Reusable Object-Oriented Software. Addison-Wesley, 1994.
  2. Schmidt, D., et al. Pattern-Oriented Software Architecture. Wiley, 1996.
  3. Python Documentation. Data Classes. https://docs.python.org/3/library/dataclasses.html
  4. PEP 544. Protocols: Structural subtyping. https://www.python.org/dev/peps/pep-0544/
  5. Nystrom, R. Game Programming Patterns. Genever Benning, 2014.

Python技术丛书 - 江苏省宿城中等专业学校