第2章 工厂方法模式
学习目标
完成本章学习后,读者将能够:
- 深入理解工厂方法模式的理论基础、设计动机与形式化定义
- 掌握Python语言特性下的多种工厂实现机制及其设计权衡
- 分析简单工厂、工厂方法与抽象工厂模式的本质区别
- 评估工厂方法模式在软件架构中的利弊与适用场景
- 运用工厂方法模式设计可扩展、可维护的对象创建体系
2.1 理论基础与模式定义
2.1.1 形式化定义
工厂方法模式(Factory Method Pattern) 是一种创建型设计模式,其核心思想可形式化表述为:
$$\text{FactoryMethod}: \text{Creator} \rightarrow \text{Product}$$
满足以下约束:
$$\forall c \in \text{ConcreteCreator}, \exists! p \in \text{ConcreteProduct}: c.\text{factory_method}() = p$$
即对于每个具体创建者 $c$,存在唯一的具体产品 $p$,使得工厂方法返回该产品实例。该模式实现了创建者与产品的多态关联:
$$\text{Client} \xrightarrow{\text{depends on}} \text{Creator} \xrightarrow{\text{creates}} \text{Product}$$
而非直接依赖:
$$\text{Client} \xrightarrow{\text{depends on}} \text{ConcreteProduct}$$
2.1.2 历史演进与学术背景
工厂方法模式源于对对象创建责任的抽象化处理。在面向对象设计早期,开发者发现直接使用new关键字实例化对象存在以下问题:
- 编译时绑定:客户端代码与具体类紧密耦合
- 扩展困难:新增产品类型需要修改客户端代码
- 测试障碍:难以在测试中替换真实对象
GoF在1994年将工厂方法模式系统化,定义为"定义一个创建对象的接口,让子类决定实例化哪一个类"。该模式体现了依赖倒置原则(DIP):
传统方式:
高层模块 ──────> 低层模块(具体类)
违反DIP
工厂方法:
高层模块 ──────> 抽象接口 <────── 低层模块
遵循DIP2.1.3 模式动机与应用场景
工厂方法模式解决的核心问题是对象创建的延迟绑定。考虑以下典型场景:
┌─────────────────────────────────────────────────────────────────┐
│ 工厂方法模式应用场景 │
├─────────────────────────────────────────────────────────────────┤
│ 框架开发 │ 插件系统 │ 组件扩展 │ 库的内部对象创建 │
├─────────────────────────────────────────────────────────────────┤
│ 多态创建 │ 不同数据库连接 │ 不同协议处理 │ 不同文件解析 │
├─────────────────────────────────────────────────────────────────┤
│ 测试隔离 │ Mock对象创建 │ 测试替身 │ 依赖注入 │
├─────────────────────────────────────────────────────────────────┤
│ 配置驱动 │ 运行时决定类型 │ 外部配置创建 │ 动态加载 │
└─────────────────────────────────────────────────────────────────┘核心设计动机:
- 解耦创建与使用:客户端无需知道具体产品类名
- 开闭原则支持:新增产品类型无需修改现有代码
- 单一职责:创建逻辑集中管理
- 延迟决策:将实例化推迟到运行时
2.1.4 UML结构模型
┌─────────────────────────────────────────────────────────────────┐
│ Product «interface» │
├─────────────────────────────────────────────────────────────────┤
│ + operation(): Result │
└─────────────────────────────────────────────────────────────────┘
△
│ implements
┌────────────┴────────────┐
│ │
┌─────────────┴─────────┐ ┌──────────┴──────────┐
│ ConcreteProductA │ │ ConcreteProductB │
├───────────────────────┤ ├─────────────────────┤
│ + operation(): Result│ │ + operation() │
└───────────────────────┘ └─────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ Creator «abstract» │
├─────────────────────────────────────────────────────────────────┤
│ + factory_method(): Product «abstract» │
│ + some_operation(): Result │
│ │
│ // 模板方法,调用工厂方法 │
│ some_operation(): │
│ product = this.factory_method() │
│ return product.operation() │
└─────────────────────────────────────────────────────────────────┘
△
│ extends
┌────────────┴────────────┐
│ │
┌─────────────┴─────────┐ ┌──────────┴──────────┐
│ ConcreteCreatorA │ │ ConcreteCreatorB │
├───────────────────────┤ ├─────────────────────┤
│ + factory_method(): │ │ + factory_method() │
│ return ProductA()│ │ return ProductB│
└───────────────────────┘ └─────────────────────┘结构要素解析:
| 角色 | 职责 | Python实现方式 |
|---|---|---|
| Product | 定义产品接口 | Protocol 或 ABC |
| ConcreteProduct | 实现产品接口 | 具体类 |
| Creator | 声明工厂方法 | 抽象基类 |
| ConcreteCreator | 实现工厂方法 | 具体子类 |
2.2 Python实现机制深度解析
2.2.1 基于抽象基类的经典实现
from abc import ABC, abstractmethod
from typing import TypeVar, Generic, Optional, Dict, Type, Any
from dataclasses import dataclass
from enum import Enum, auto
T = TypeVar('T')
class Product(ABC):
"""产品抽象基类"""
@abstractmethod
def operation(self) -> str:
"""产品操作接口"""
pass
@abstractmethod
def get_info(self) -> Dict[str, Any]:
"""获取产品信息"""
pass
class ConcreteProductA(Product):
"""具体产品A"""
def __init__(self, name: str = "ProductA"):
self._name = name
self._created_at = __import__('datetime').datetime.now()
def operation(self) -> str:
return f"[{self._name}] 执行产品A的操作"
def get_info(self) -> Dict[str, Any]:
return {
"name": self._name,
"type": "ConcreteProductA",
"created_at": self._created_at.isoformat()
}
class ConcreteProductB(Product):
"""具体产品B"""
def __init__(self, version: str = "1.0"):
self._version = version
self._created_at = __import__('datetime').datetime.now()
def operation(self) -> str:
return f"[ProductB v{self._version}] 执行产品B的操作"
def get_info(self) -> Dict[str, Any]:
return {
"version": self._version,
"type": "ConcreteProductB",
"created_at": self._created_at.isoformat()
}
class Creator(ABC):
"""
创建者抽象基类
职责:
- 声明工厂方法
- 提供默认的业务逻辑(模板方法)
"""
@abstractmethod
def factory_method(self) -> Product:
"""工厂方法:由子类实现具体产品创建"""
pass
def some_operation(self) -> str:
"""
模板方法:使用工厂方法创建产品并执行操作
注意:此方法不关心具体产品类型,只依赖Product接口
"""
product = self.factory_method()
result = product.operation()
return f"Creator: {result}"
def create_and_configure(self, **kwargs) -> Product:
"""创建并配置产品"""
product = self.factory_method()
return product
class ConcreteCreatorA(Creator):
"""具体创建者A:负责创建产品A"""
def __init__(self, product_name: str = "DefaultA"):
self._product_name = product_name
def factory_method(self) -> Product:
return ConcreteProductA(self._product_name)
class ConcreteCreatorB(Creator):
"""具体创建者B:负责创建产品B"""
def __init__(self, version: str = "1.0"):
self._version = version
def factory_method(self) -> Product:
return ConcreteProductB(self._version)
def client_code(creator: Creator) -> None:
"""
客户端代码
关键点:客户端只依赖Creator抽象,不依赖具体产品类
"""
print(f"客户端:使用 {creator.__class__.__name__}")
print(creator.some_operation())
print(f"产品信息: {creator.factory_method().get_info()}")
client_code(ConcreteCreatorA("CustomA"))
print("-" * 50)
client_code(ConcreteCreatorB("2.0"))2.2.2 基于Protocol的现代实现
Python 3.8+引入的Protocol提供了结构化子类型(鸭子类型的静态检查),是接口定义的现代选择:
from typing import Protocol, runtime_checkable, TypeVar, Generic, ClassVar
from dataclasses import dataclass, field
@runtime_checkable
class Transport(Protocol):
"""运输工具协议(结构化接口)"""
capacity: ClassVar[int]
def deliver(self, destination: str) -> str:
"""执行配送"""
...
def estimate_time(self, distance: float) -> float:
"""估算配送时间"""
...
@dataclass
class Truck:
"""卡车运输"""
capacity: ClassVar[int] = 10000
license_plate: str
driver: str = "Unknown"
def deliver(self, destination: str) -> str:
return f"卡车 [{self.license_plate}] 配送到 {destination},司机: {self.driver}"
def estimate_time(self, distance: float) -> float:
return distance / 60.0
@dataclass
class Ship:
"""轮船运输"""
capacity: ClassVar[int] = 100000
vessel_name: str
port: str = "Default Port"
def deliver(self, destination: str) -> str:
return f"轮船 [{self.vessel_name}] 从 {self.port} 配送到 {destination}"
def estimate_time(self, distance: float) -> float:
return distance / 30.0
@dataclass
class Airplane:
"""飞机运输"""
capacity: ClassVar[int] = 5000
flight_number: str
airline: str = "Default Airline"
def deliver(self, destination: str) -> str:
return f"航班 [{self.flight_number}] 空运到 {destination}"
def estimate_time(self, distance: float) -> float:
return distance / 500.0
class Logistics(Protocol):
"""物流抽象协议"""
def create_transport(self) -> Transport:
"""工厂方法:创建运输工具"""
...
def plan_delivery(self, destination: str, distance: float) -> Dict[str, Any]:
"""规划配送"""
...
@dataclass
class RoadLogistics:
"""公路物流"""
default_driver: str = "Default Driver"
def create_transport(self) -> Transport:
import uuid
return Truck(
license_plate=f"TRUCK-{uuid.uuid4().hex[:6].upper()}",
driver=self.default_driver
)
def plan_delivery(self, destination: str, distance: float) -> Dict[str, Any]:
transport = self.create_transport()
return {
"transport_type": "Truck",
"delivery_info": transport.deliver(destination),
"estimated_hours": transport.estimate_time(distance),
"capacity": transport.capacity
}
@dataclass
class SeaLogistics:
"""海运物流"""
default_port: str = "Shanghai"
def create_transport(self) -> Transport:
import uuid
return Ship(
vessel_name=f"VESSEL-{uuid.uuid4().hex[:6].upper()}",
port=self.default_port
)
def plan_delivery(self, destination: str, distance: float) -> Dict[str, Any]:
transport = self.create_transport()
return {
"transport_type": "Ship",
"delivery_info": transport.deliver(destination),
"estimated_hours": transport.estimate_time(distance),
"capacity": transport.capacity
}
@dataclass
class AirLogistics:
"""空运物流"""
default_airline: str = "Global Air"
def create_transport(self) -> Transport:
import uuid
return Airplane(
flight_number=f"GA{uuid.uuid4().hex[:4].upper()}",
airline=self.default_airline
)
def plan_delivery(self, destination: str, distance: float) -> Dict[str, Any]:
transport = self.create_transport()
return {
"transport_type": "Airplane",
"delivery_info": transport.deliver(destination),
"estimated_hours": transport.estimate_time(distance),
"capacity": transport.capacity
}
def process_logistics(logistics: Logistics, destination: str, distance: float):
"""处理物流配送"""
plan = logistics.plan_delivery(destination, distance)
print(f"物流类型: {plan['transport_type']}")
print(f"配送信息: {plan['delivery_info']}")
print(f"预计时间: {plan['estimated_hours']:.2f} 小时")
print(f"运载能力: {plan['capacity']} kg")
road = RoadLogistics("张师傅")
process_logistics(road, "北京", 1200)
print("-" * 50)
air = AirLogistics("东方航空")
process_logistics(air, "纽约", 12000)2.2.3 参数化工厂方法
参数化工厂方法允许一个创建者根据参数创建不同产品:
from enum import Enum, auto
from typing import Dict, Type, Callable, Optional, Any
from dataclasses import dataclass
import threading
class PaymentType(Enum):
"""支付类型枚举"""
CREDIT_CARD = auto()
DEBIT_CARD = auto()
PAYPAL = auto()
WECHAT = auto()
ALIPAY = auto()
BANK_TRANSFER = auto()
@dataclass
class PaymentResult:
"""支付结果"""
success: bool
transaction_id: str
amount: float
currency: str
message: str
timestamp: float = field(default_factory=lambda: __import__('time').time())
class PaymentMethod(ABC):
"""支付方式抽象基类"""
@abstractmethod
def pay(self, amount: float, currency: str = "CNY") -> PaymentResult:
pass
@abstractmethod
def refund(self, transaction_id: str) -> PaymentResult:
pass
@property
@abstractmethod
def fee_rate(self) -> float:
"""手续费率"""
pass
class CreditCardPayment(PaymentMethod):
"""信用卡支付"""
def __init__(self, card_number: str, cvv: str, expiry: str):
self._card_number = card_number
self._cvv = cvv
self._expiry = expiry
@property
def fee_rate(self) -> float:
return 0.025
def pay(self, amount: float, currency: str = "CNY") -> PaymentResult:
import uuid
return PaymentResult(
success=True,
transaction_id=f"CC-{uuid.uuid4().hex[:12].upper()}",
amount=amount,
currency=currency,
message=f"信用卡支付成功 (尾号: {self._card_number[-4:]})"
)
def refund(self, transaction_id: str) -> PaymentResult:
return PaymentResult(
success=True,
transaction_id=transaction_id,
amount=0,
currency="CNY",
message="信用卡退款成功"
)
class WeChatPayment(PaymentMethod):
"""微信支付"""
def __init__(self, openid: str):
self._openid = openid
@property
def fee_rate(self) -> float:
return 0.006
def pay(self, amount: float, currency: str = "CNY") -> PaymentResult:
import uuid
return PaymentResult(
success=True,
transaction_id=f"WX-{uuid.uuid4().hex[:12].upper()}",
amount=amount,
currency=currency,
message=f"微信支付成功"
)
def refund(self, transaction_id: str) -> PaymentResult:
return PaymentResult(
success=True,
transaction_id=transaction_id,
amount=0,
currency="CNY",
message="微信退款成功"
)
class AlipayPayment(PaymentMethod):
"""支付宝支付"""
def __init__(self, user_id: str):
self._user_id = user_id
@property
def fee_rate(self) -> float:
return 0.006
def pay(self, amount: float, currency: str = "CNY") -> PaymentResult:
import uuid
return PaymentResult(
success=True,
transaction_id=f"ALI-{uuid.uuid4().hex[:12].upper()}",
amount=amount,
currency=currency,
message="支付宝支付成功"
)
def refund(self, transaction_id: str) -> PaymentResult:
return PaymentResult(
success=True,
transaction_id=transaction_id,
amount=0,
currency="CNY",
message="支付宝退款成功"
)
class PaymentFactory:
"""
支付工厂
特性:
- 参数化工厂方法
- 注册机制支持扩展
- 线程安全的注册表
"""
_registry: Dict[PaymentType, Type[PaymentMethod]] = {}
_lock = threading.RLock()
@classmethod
def register(cls, payment_type: PaymentType,
payment_class: Type[PaymentMethod]) -> None:
"""注册支付方式"""
with cls._lock:
cls._registry[payment_type] = payment_class
@classmethod
def unregister(cls, payment_type: PaymentType) -> None:
"""注销支付方式"""
with cls._lock:
cls._registry.pop(payment_type, None)
@classmethod
def create(cls, payment_type: PaymentType, **kwargs) -> PaymentMethod:
"""
创建支付方式实例
参数:
payment_type: 支付类型
**kwargs: 传递给支付方式构造函数的参数
"""
with cls._lock:
if payment_type not in cls._registry:
raise ValueError(f"不支持的支付类型: {payment_type}")
payment_class = cls._registry[payment_type]
return payment_class(**kwargs)
@classmethod
def get_supported_types(cls) -> list:
"""获取支持的支付类型"""
with cls._lock:
return list(cls._registry.keys())
PaymentFactory.register(PaymentType.CREDIT_CARD, CreditCardPayment)
PaymentFactory.register(PaymentType.WECHAT, WeChatPayment)
PaymentFactory.register(PaymentType.ALIPAY, AlipayPayment)
credit_card = PaymentFactory.create(
PaymentType.CREDIT_CARD,
card_number="6225881234567890",
cvv="123",
expiry="12/25"
)
result = credit_card.pay(999.99)
print(f"支付结果: {result.message}, 交易号: {result.transaction_id}")
wechat = PaymentFactory.create(PaymentType.WECHAT, openid="wx_openid_123")
result = wechat.pay(199.00)
print(f"支付结果: {result.message}, 手续费率: {wechat.fee_rate}")2.2.4 注册表模式与动态发现
from typing import Dict, Type, Callable, Any, Optional, List
import importlib
import inspect
from dataclasses import dataclass, field
@dataclass
class HandlerInfo:
"""处理器信息"""
name: str
handler_class: Type
description: str = ""
tags: List[str] = field(default_factory=list)
class HandlerRegistry:
"""
处理器注册表
特性:
- 自动发现和注册
- 支持元数据
- 线程安全
"""
_handlers: Dict[str, HandlerInfo] = {}
_factories: Dict[str, Callable] = {}
_lock = threading.RLock()
@classmethod
def register(cls, name: str, description: str = "",
tags: List[str] = None) -> Callable:
"""
装饰器方式注册处理器
用法:
@HandlerRegistry.register("my_handler", "描述", ["tag1"])
class MyHandler:
pass
"""
def decorator(handler_class: Type) -> Type:
with cls._lock:
cls._handlers[name] = HandlerInfo(
name=name,
handler_class=handler_class,
description=description,
tags=tags or []
)
return handler_class
return decorator
@classmethod
def register_factory(cls, name: str, factory: Callable) -> None:
"""注册工厂函数"""
with cls._lock:
cls._factories[name] = factory
@classmethod
def get(cls, name: str) -> Optional[HandlerInfo]:
"""获取处理器信息"""
return cls._handlers.get(name)
@classmethod
def create(cls, name: str, *args, **kwargs) -> Any:
"""创建处理器实例"""
with cls._lock:
if name in cls._factories:
return cls._factories[name](*args, **kwargs)
if name in cls._handlers:
return cls._handlers[name].handler_class(*args, **kwargs)
raise ValueError(f"未注册的处理器: {name}")
@classmethod
def list_handlers(cls, tag: str = None) -> List[HandlerInfo]:
"""列出所有处理器"""
with cls._lock:
handlers = list(cls._handlers.values())
if tag:
handlers = [h for h in handlers if tag in h.tags]
return handlers
@classmethod
def auto_discover(cls, module_name: str) -> int:
"""
自动发现模块中的处理器
发现规则:
- 类名以 Handler 结尾
- 实现特定接口
"""
count = 0
try:
module = importlib.import_module(module_name)
for name, obj in inspect.getmembers(module, inspect.isclass):
if name.endswith('Handler') and obj.__module__ == module_name:
handler_name = name.replace('Handler', '').lower()
cls._handlers[handler_name] = HandlerInfo(
name=handler_name,
handler_class=obj
)
count += 1
except ImportError:
pass
return count
class MessageHandler(ABC):
"""消息处理器基类"""
@abstractmethod
def handle(self, message: dict) -> dict:
pass
@HandlerRegistry.register("email", "邮件处理器", ["notification", "async"])
class EmailHandler(MessageHandler):
"""邮件处理器"""
def __init__(self, smtp_server: str = "localhost"):
self._smtp_server = smtp_server
def handle(self, message: dict) -> dict:
return {
"status": "sent",
"channel": "email",
"recipient": message.get("email"),
"smtp": self._smtp_server
}
@HandlerRegistry.register("sms", "短信处理器", ["notification"])
class SMSHandler(MessageHandler):
"""短信处理器"""
def handle(self, message: dict) -> dict:
return {
"status": "sent",
"channel": "sms",
"recipient": message.get("phone")
}
@HandlerRegistry.register("push", "推送处理器", ["notification", "mobile"])
class PushHandler(MessageHandler):
"""推送处理器"""
def handle(self, message: dict) -> dict:
return {
"status": "sent",
"channel": "push",
"device": message.get("device_id")
}
print("已注册处理器:")
for info in HandlerRegistry.list_handlers():
print(f" - {info.name}: {info.description} [{', '.join(info.tags)}]")
email_handler = HandlerRegistry.create("email", smtp_server="smtp.example.com")
result = email_handler.handle({"email": "user@example.com", "subject": "测试"})
print(f"\n处理结果: {result}")2.3 实际应用案例
2.3.1 企业级数据库连接工厂
from typing import Dict, Type, Any, Optional, ContextManager
from dataclasses import dataclass, field
from contextlib import contextmanager
from enum import Enum, auto
import threading
import time
class DatabaseType(Enum):
MYSQL = auto()
POSTGRESQL = auto()
SQLITE = auto()
MONGODB = auto()
REDIS = auto()
@dataclass
class ConnectionConfig:
"""连接配置"""
host: str = "localhost"
port: int = 3306
database: str = "test"
username: str = "root"
password: str = ""
charset: str = "utf8mb4"
pool_size: int = 5
timeout: float = 30.0
def to_connection_string(self, db_type: DatabaseType) -> str:
"""生成连接字符串"""
if db_type == DatabaseType.MYSQL:
return f"mysql://{self.username}:{self.password}@{self.host}:{self.port}/{self.database}"
elif db_type == DatabaseType.POSTGRESQL:
return f"postgresql://{self.username}:{self.password}@{self.host}:{self.port}/{self.database}"
elif db_type == DatabaseType.SQLITE:
return f"sqlite:///{self.database}.db"
elif db_type == DatabaseType.MONGODB:
return f"mongodb://{self.username}:{self.password}@{self.host}:{self.port}/{self.database}"
elif db_type == DatabaseType.REDIS:
return f"redis://{self.host}:{self.port}/{self.database}"
return ""
@dataclass
class QueryResult:
"""查询结果"""
success: bool
data: Any = None
affected_rows: int = 0
error: Optional[str] = None
execution_time: float = 0.0
class Connection(ABC):
"""数据库连接抽象基类"""
def __init__(self, config: ConnectionConfig):
self._config = config
self._connected = False
self._last_query_time = 0.0
@abstractmethod
def connect(self) -> bool:
"""建立连接"""
pass
@abstractmethod
def disconnect(self) -> None:
"""断开连接"""
pass
@abstractmethod
def execute(self, query: str, params: tuple = None) -> QueryResult:
"""执行查询"""
pass
@abstractmethod
def begin_transaction(self) -> None:
"""开始事务"""
pass
@abstractmethod
def commit(self) -> None:
"""提交事务"""
pass
@abstractmethod
def rollback(self) -> None:
"""回滚事务"""
pass
@property
def is_connected(self) -> bool:
return self._connected
@contextmanager
def transaction(self):
"""事务上下文管理器"""
self.begin_transaction()
try:
yield self
self.commit()
except Exception as e:
self.rollback()
raise
class MySQLConnection(Connection):
"""MySQL连接"""
def connect(self) -> bool:
print(f"[MySQL] 连接到 {self._config.host}:{self._config.port}")
self._connected = True
return True
def disconnect(self) -> None:
print("[MySQL] 断开连接")
self._connected = False
def execute(self, query: str, params: tuple = None) -> QueryResult:
start = time.time()
print(f"[MySQL] 执行: {query}")
return QueryResult(
success=True,
data=[{"id": 1, "name": "test"}],
affected_rows=1,
execution_time=time.time() - start
)
def begin_transaction(self) -> None:
print("[MySQL] 开始事务")
def commit(self) -> None:
print("[MySQL] 提交事务")
def rollback(self) -> None:
print("[MySQL] 回滚事务")
class PostgreSQLConnection(Connection):
"""PostgreSQL连接"""
def connect(self) -> bool:
print(f"[PostgreSQL] 连接到 {self._config.host}:{self._config.port}")
self._connected = True
return True
def disconnect(self) -> None:
print("[PostgreSQL] 断开连接")
self._connected = False
def execute(self, query: str, params: tuple = None) -> QueryResult:
start = time.time()
print(f"[PostgreSQL] 执行: {query}")
return QueryResult(
success=True,
data=[{"id": 1, "name": "test"}],
affected_rows=1,
execution_time=time.time() - start
)
def begin_transaction(self) -> None:
print("[PostgreSQL] 开始事务")
def commit(self) -> None:
print("[PostgreSQL] 提交事务")
def rollback(self) -> None:
print("[PostgreSQL] 回滚事务")
class SQLiteDatabaseConnection(Connection):
"""SQLite连接"""
def connect(self) -> bool:
print(f"[SQLite] 连接到 {self._config.database}.db")
self._connected = True
return True
def disconnect(self) -> None:
print("[SQLite] 断开连接")
self._connected = False
def execute(self, query: str, params: tuple = None) -> QueryResult:
start = time.time()
print(f"[SQLite] 执行: {query}")
return QueryResult(
success=True,
data=[],
affected_rows=0,
execution_time=time.time() - start
)
def begin_transaction(self) -> None:
print("[SQLite] 开始事务")
def commit(self) -> None:
print("[SQLite] 提交事务")
def rollback(self) -> None:
print("[SQLite] 回滚事务")
class DatabaseFactory:
"""
数据库连接工厂
特性:
- 支持多种数据库类型
- 连接池管理
- 配置驱动创建
"""
_connection_classes: Dict[DatabaseType, Type[Connection]] = {
DatabaseType.MYSQL: MySQLConnection,
DatabaseType.POSTGRESQL: PostgreSQLConnection,
DatabaseType.SQLITE: SQLiteDatabaseConnection,
}
_connection_pools: Dict[str, list] = {}
_lock = threading.RLock()
@classmethod
def create_connection(cls, db_type: DatabaseType,
config: Optional[ConnectionConfig] = None) -> Connection:
"""创建数据库连接"""
if db_type not in cls._connection_classes:
raise ValueError(f"不支持的数据库类型: {db_type}")
connection_class = cls._connection_classes[db_type]
connection = connection_class(config or ConnectionConfig())
connection.connect()
return connection
@classmethod
def register_database(cls, db_type: DatabaseType,
connection_class: Type[Connection]) -> None:
"""注册新的数据库类型"""
cls._connection_classes[db_type] = connection_class
@classmethod
@contextmanager
def connection(cls, db_type: DatabaseType,
config: Optional[ConnectionConfig] = None):
"""连接上下文管理器"""
conn = None
try:
conn = cls.create_connection(db_type, config)
yield conn
finally:
if conn:
conn.disconnect()
@classmethod
def get_supported_databases(cls) -> List[DatabaseType]:
"""获取支持的数据库类型"""
return list(cls._connection_classes.keys())
config = ConnectionConfig(
host="192.168.1.100",
port=3306,
database="production",
username="admin",
password="secret"
)
with DatabaseFactory.connection(DatabaseType.MYSQL, config) as conn:
result = conn.execute("SELECT * FROM users WHERE status = %s", ("active",))
print(f"查询结果: {result.data}")
print(f"支持的数据库: {[db.name for db in DatabaseFactory.get_supported_databases()]}")2.3.2 插件架构中的工厂方法
from typing import Dict, Type, Any, Optional, List, Callable
from dataclasses import dataclass, field
from pathlib import Path
import importlib.util
import sys
@dataclass
class PluginInfo:
"""插件信息"""
name: str
version: str
description: str
author: str = ""
dependencies: List[str] = field(default_factory=list)
enabled: bool = True
class Plugin(ABC):
"""插件抽象基类"""
@property
@abstractmethod
def info(self) -> PluginInfo:
"""插件信息"""
pass
@abstractmethod
def initialize(self, context: Dict[str, Any]) -> None:
"""初始化插件"""
pass
@abstractmethod
def execute(self, *args, **kwargs) -> Any:
"""执行插件功能"""
pass
@abstractmethod
def shutdown(self) -> None:
"""关闭插件"""
pass
class PluginFactory:
"""
插件工厂
特性:
- 动态加载插件
- 插件生命周期管理
- 依赖解析
"""
_plugins: Dict[str, Type[Plugin]] = {}
_instances: Dict[str, Plugin] = {}
_lock = threading.RLock()
@classmethod
def register(cls, plugin_class: Type[Plugin]) -> None:
"""注册插件类"""
with cls._lock:
temp_instance = plugin_class()
info = temp_instance.info
cls._plugins[info.name] = plugin_class
print(f"[PluginFactory] 注册插件: {info.name} v{info.version}")
@classmethod
def create(cls, name: str, context: Dict[str, Any] = None) -> Plugin:
"""创建插件实例"""
with cls._lock:
if name not in cls._plugins:
raise ValueError(f"未注册的插件: {name}")
if name in cls._instances:
return cls._instances[name]
plugin_class = cls._plugins[name]
instance = plugin_class()
instance.initialize(context or {})
cls._instances[name] = instance
return instance
@classmethod
def load_from_file(cls, file_path: str) -> Optional[Type[Plugin]]:
"""从文件加载插件"""
path = Path(file_path)
if not path.exists():
print(f"[PluginFactory] 文件不存在: {file_path}")
return None
module_name = f"plugin_{path.stem}"
spec = importlib.util.spec_from_file_location(module_name, file_path)
if spec and spec.loader:
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
for name, obj in vars(module).items():
if isinstance(obj, type) and issubclass(obj, Plugin) and obj is not Plugin:
cls.register(obj)
return obj
return None
@classmethod
def get_plugin(cls, name: str) -> Optional[Plugin]:
"""获取已创建的插件实例"""
return cls._instances.get(name)
@classmethod
def list_plugins(cls) -> List[PluginInfo]:
"""列出所有插件"""
with cls._lock:
return [
cls._plugins[name]().info
for name in cls._plugins
]
@classmethod
def shutdown_all(cls) -> None:
"""关闭所有插件"""
with cls._lock:
for name, instance in cls._instances.items():
try:
instance.shutdown()
print(f"[PluginFactory] 关闭插件: {name}")
except Exception as e:
print(f"[PluginFactory] 关闭插件 {name} 失败: {e}")
cls._instances.clear()
class DataProcessorPlugin(Plugin):
"""数据处理器插件"""
@property
def info(self) -> PluginInfo:
return PluginInfo(
name="data_processor",
version="1.0.0",
description="数据处理插件",
author="Developer"
)
def initialize(self, context: Dict[str, Any]) -> None:
self._config = context.get("config", {})
print(f"[DataProcessor] 初始化完成")
def execute(self, data: List[Any], operation: str = "transform") -> Any:
if operation == "transform":
return [item * 2 for item in data]
elif operation == "filter":
return [item for item in data if item > 0]
return data
def shutdown(self) -> None:
print("[DataProcessor] 关闭")
class NotificationPlugin(Plugin):
"""通知插件"""
@property
def info(self) -> PluginInfo:
return PluginInfo(
name="notification",
version="2.0.0",
description="通知发送插件",
author="Developer",
dependencies=["data_processor"]
)
def initialize(self, context: Dict[str, Any]) -> None:
self._channels = context.get("channels", ["email"])
print(f"[Notification] 初始化完成,通道: {self._channels}")
def execute(self, message: str, recipients: List[str]) -> Dict:
return {
"message": message,
"recipients": recipients,
"channels": self._channels,
"status": "sent"
}
def shutdown(self) -> None:
print("[Notification] 关闭")
PluginFactory.register(DataProcessorPlugin)
PluginFactory.register(NotificationPlugin)
print("已注册插件:")
for info in PluginFactory.list_plugins():
print(f" - {info.name} v{info.version}: {info.description}")
processor = PluginFactory.create("data_processor", {"config": {"mode": "advanced"}})
result = processor.execute([1, 2, 3, 4, 5], operation="transform")
print(f"\n处理结果: {result}")
notification = PluginFactory.create("notification", {"channels": ["email", "sms"]})
result = notification.execute("测试消息", ["user@example.com"])
print(f"通知结果: {result}")
PluginFactory.shutdown_all()2.4 简单工厂 vs 工厂方法 vs 抽象工厂
2.4.1 三种工厂模式对比
┌─────────────────────────────────────────────────────────────────┐
│ 工厂模式层次结构 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 简单工厂 │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ 一个工厂类,根据参数创建不同产品 │ │
│ │ 违反开闭原则,适合产品种类固定的场景 │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ 工厂方法 │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ 工厂抽象 + 多个具体工厂 │ │
│ │ 每个工厂只创建一种产品 │ │
│ │ 遵循开闭原则,适合产品种类可扩展 │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ 抽象工厂 │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ 工厂抽象 + 多个具体工厂 │ │
│ │ 每个工厂创建一系列相关产品 │ │
│ │ 产品族概念,适合跨平台/主题场景 │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘2.4.2 简单工厂实现
from typing import Dict, Type, Any
class SimpleFactory:
"""
简单工厂
特点:
- 集中创建逻辑
- 通过参数选择产品
- 违反开闭原则
"""
_products: Dict[str, Type] = {}
@classmethod
def register(cls, product_type: str, product_class: Type) -> None:
cls._products[product_type] = product_class
@classmethod
def create(cls, product_type: str, *args, **kwargs) -> Any:
if product_type not in cls._products:
raise ValueError(f"未知产品类型: {product_type}")
return cls._products[product_type](*args, **kwargs)
class SimpleFactoryWithSwitch:
"""
简单工厂(switch版本)
问题: 新增产品需要修改工厂类
"""
@staticmethod
def create(product_type: str) -> 'Product':
if product_type == "A":
return ConcreteProductA()
elif product_type == "B":
return ConcreteProductB()
else:
raise ValueError(f"未知产品类型: {product_type}")2.4.3 模式选择决策表
| 场景 | 推荐模式 | 理由 |
|---|---|---|
| 产品种类固定且少 | 简单工厂 | 实现简单,够用即可 |
| 产品种类可能扩展 | 工厂方法 | 遵循开闭原则 |
| 需要创建产品族 | 抽象工厂 | 保证产品一致性 |
| 框架/库开发 | 工厂方法 | 允许用户扩展 |
| 配置驱动创建 | 参数化工厂 | 灵活配置 |
2.5 反模式与最佳实践
2.5.1 常见反模式
反模式1:工厂承担过多职责
class BadFactory:
"""错误示例:工厂承担了创建、验证、持久化等多个职责"""
@staticmethod
def create_and_save(product_type: str, data: dict):
product = create_product(product_type)
validate_product(product)
save_to_database(product)
send_notification(product)
return product
class GoodFactory:
"""正确示例:工厂只负责创建"""
@staticmethod
def create(product_type: str, **kwargs) -> Product:
return create_product(product_type, **kwargs)反模式2:过度使用工厂
class User:
"""错误示例:简单值对象不需要工厂"""
def __init__(self, name: str, email: str):
self.name = name
self.email = email
class UserFactory:
"""过度设计:简单对象不需要工厂"""
@staticmethod
def create(name: str, email: str) -> User:
return User(name, email)2.5.2 最佳实践
from typing import Protocol, TypeVar, Generic, Dict, Type, Callable, Any
from dataclasses import dataclass
import threading
T = TypeVar('T')
class Factory(Protocol[T]):
"""工厂协议"""
def create(self, **kwargs) -> T:
...
@dataclass
class FactoryConfig:
"""工厂配置"""
default_args: Dict[str, Any] = None
validators: Dict[str, Callable] = None
post_processors: List[Callable] = None
class ConfigurableFactory(Generic[T]):
"""
可配置的泛型工厂
特性:
- 支持默认参数
- 支持验证器
- 支持后处理器
- 线程安全
"""
def __init__(self, product_class: Type[T], config: FactoryConfig = None):
self._product_class = product_class
self._config = config or FactoryConfig()
self._lock = threading.RLock()
def create(self, **kwargs) -> T:
"""创建产品实例"""
with self._lock:
merged_kwargs = self._merge_args(kwargs)
if self._config.validators:
self._validate(merged_kwargs)
instance = self._product_class(**merged_kwargs)
if self._config.post_processors:
for processor in self._config.post_processors:
instance = processor(instance)
return instance
def _merge_args(self, kwargs: Dict) -> Dict:
"""合并默认参数和传入参数"""
if self._config.default_args:
return {**self._config.default_args, **kwargs}
return kwargs
def _validate(self, kwargs: Dict) -> None:
"""验证参数"""
for key, validator in self._config.validators.items():
if key in kwargs and not validator(kwargs[key]):
raise ValueError(f"参数 {key} 验证失败")
class ProductRegistry:
"""
产品注册表
集中管理产品类型和工厂
"""
_factories: Dict[str, ConfigurableFactory] = {}
_lock = threading.RLock()
@classmethod
def register(cls, name: str, factory: ConfigurableFactory) -> None:
with cls._lock:
cls._factories[name] = factory
@classmethod
def create(cls, name: str, **kwargs) -> Any:
with cls._lock:
if name not in cls._factories:
raise ValueError(f"未注册的产品: {name}")
return cls._factories[name].create(**kwargs)
@classmethod
def get_factory(cls, name: str) -> Optional[ConfigurableFactory]:
return cls._factories.get(name)2.6 模式评估与决策指南
2.6.1 适用性检查清单
| 检查项 | 是 | 否 | 说明 |
|---|---|---|---|
| 产品类型可能变化? | ✓ 使用 | ✗ 评估 | 需要扩展性 |
| 创建逻辑复杂? | ✓ 使用 | ✗ 评估 | 需要封装 |
| 客户端不应知道具体类? | ✓ 使用 | ✗ 评估 | 需要解耦 |
| 需要延迟创建决策? | ✓ 使用 | ✗ 评估 | 运行时决定 |
| 产品种类很少且固定? | ✗ 评估 | ✓ 使用 | 简单工厂足够 |
2.6.2 性能考量
| 实现方式 | 创建开销 | 内存占用 | 适用场景 |
|---|---|---|---|
| 直接实例化 | 最低 | 最低 | 简单场景 |
| 简单工厂 | 低 | 低 | 产品固定 |
| 工厂方法 | 中 | 中 | 需要扩展 |
| 注册表工厂 | 中 | 中 | 动态发现 |
| 反射工厂 | 高 | 中 | 插件系统 |
2.7 小结
工厂方法模式通过延迟绑定和多态创建实现了创建者与产品的解耦。在Python中,可以利用动态特性简化实现:
┌─────────────────────────────────────────────────────────────────┐
│ 工厂方法实现选择指南 │
│ │
│ 产品种类是否固定? │
│ │ │
│ ┌──────────┴──────────┐ │
│ ▼ ▼ │
│ 是的 不确定 │
│ │ │ │
│ ▼ ▼ │
│ 简单工厂 需要创建产品族? │
│ │ │
│ ┌─────────┴─────────┐ │
│ ▼ ▼ │
│ 是的 不需要 │
│ │ │ │
│ ▼ ▼ │
│ 抽象工厂 工厂方法 │
│ │
└─────────────────────────────────────────────────────────────────┘关键要点:
- 核心价值:将对象创建与使用分离,支持开闭原则
- Python特性:利用Protocol、装饰器、注册表简化实现
- 避免过度设计:简单场景使用简单工厂或直接实例化
- 测试友好:工厂方法便于注入Mock对象
- 扩展机制:注册表模式支持动态扩展
在下一章中,我们将探讨抽象工厂模式,学习如何创建相关的对象家族。