Skip to content

第2章 工厂方法模式

学习目标

完成本章学习后,读者将能够:

  • 深入理解工厂方法模式的理论基础、设计动机与形式化定义
  • 掌握Python语言特性下的多种工厂实现机制及其设计权衡
  • 分析简单工厂、工厂方法与抽象工厂模式的本质区别
  • 评估工厂方法模式在软件架构中的利弊与适用场景
  • 运用工厂方法模式设计可扩展、可维护的对象创建体系

2.1 理论基础与模式定义

2.1.1 形式化定义

工厂方法模式(Factory Method Pattern) 是一种创建型设计模式,其核心思想可形式化表述为:

$$\text{FactoryMethod}: \text{Creator} \rightarrow \text{Product}$$

满足以下约束:

$$\forall c \in \text{ConcreteCreator}, \exists! p \in \text{ConcreteProduct}: c.\text{factory_method}() = p$$

即对于每个具体创建者 $c$,存在唯一的具体产品 $p$,使得工厂方法返回该产品实例。该模式实现了创建者与产品的多态关联

$$\text{Client} \xrightarrow{\text{depends on}} \text{Creator} \xrightarrow{\text{creates}} \text{Product}$$

而非直接依赖:

$$\text{Client} \xrightarrow{\text{depends on}} \text{ConcreteProduct}$$

2.1.2 历史演进与学术背景

工厂方法模式源于对对象创建责任的抽象化处理。在面向对象设计早期,开发者发现直接使用new关键字实例化对象存在以下问题:

  1. 编译时绑定:客户端代码与具体类紧密耦合
  2. 扩展困难:新增产品类型需要修改客户端代码
  3. 测试障碍:难以在测试中替换真实对象

GoF在1994年将工厂方法模式系统化,定义为"定义一个创建对象的接口,让子类决定实例化哪一个类"。该模式体现了依赖倒置原则(DIP)

传统方式:
高层模块 ──────> 低层模块(具体类)
    违反DIP

工厂方法:
高层模块 ──────> 抽象接口 <────── 低层模块
    遵循DIP

2.1.3 模式动机与应用场景

工厂方法模式解决的核心问题是对象创建的延迟绑定。考虑以下典型场景:

┌─────────────────────────────────────────────────────────────────┐
│                    工厂方法模式应用场景                          │
├─────────────────────────────────────────────────────────────────┤
│  框架开发      │  插件系统 │ 组件扩展 │ 库的内部对象创建        │
├─────────────────────────────────────────────────────────────────┤
│  多态创建      │  不同数据库连接 │ 不同协议处理 │ 不同文件解析 │
├─────────────────────────────────────────────────────────────────┤
│  测试隔离      │  Mock对象创建 │ 测试替身 │ 依赖注入            │
├─────────────────────────────────────────────────────────────────┤
│  配置驱动      │  运行时决定类型 │ 外部配置创建 │ 动态加载       │
└─────────────────────────────────────────────────────────────────┘

核心设计动机

  1. 解耦创建与使用:客户端无需知道具体产品类名
  2. 开闭原则支持:新增产品类型无需修改现有代码
  3. 单一职责:创建逻辑集中管理
  4. 延迟决策:将实例化推迟到运行时

2.1.4 UML结构模型

┌─────────────────────────────────────────────────────────────────┐
│                      Product «interface»                         │
├─────────────────────────────────────────────────────────────────┤
│  + operation(): Result                                          │
└─────────────────────────────────────────────────────────────────┘

                           │ implements
              ┌────────────┴────────────┐
              │                         │
┌─────────────┴─────────┐   ┌──────────┴──────────┐
│   ConcreteProductA    │   │   ConcreteProductB  │
├───────────────────────┤   ├─────────────────────┤
│  + operation(): Result│   │  + operation()      │
└───────────────────────┘   └─────────────────────┘

┌─────────────────────────────────────────────────────────────────┐
│                      Creator «abstract»                          │
├─────────────────────────────────────────────────────────────────┤
│  + factory_method(): Product «abstract»                         │
│  + some_operation(): Result                                     │
│                                                                 │
│  // 模板方法,调用工厂方法                                       │
│  some_operation():                                              │
│      product = this.factory_method()                            │
│      return product.operation()                                 │
└─────────────────────────────────────────────────────────────────┘

                           │ extends
              ┌────────────┴────────────┐
              │                         │
┌─────────────┴─────────┐   ┌──────────┴──────────┐
│   ConcreteCreatorA    │   │   ConcreteCreatorB  │
├───────────────────────┤   ├─────────────────────┤
│  + factory_method():  │   │  + factory_method() │
│      return ProductA()│   │      return ProductB│
└───────────────────────┘   └─────────────────────┘

结构要素解析

角色职责Python实现方式
Product定义产品接口ProtocolABC
ConcreteProduct实现产品接口具体类
Creator声明工厂方法抽象基类
ConcreteCreator实现工厂方法具体子类

2.2 Python实现机制深度解析

2.2.1 基于抽象基类的经典实现

python
from abc import ABC, abstractmethod
from typing import TypeVar, Generic, Optional, Dict, Type, Any
from dataclasses import dataclass
from enum import Enum, auto

T = TypeVar('T')

class Product(ABC):
    """产品抽象基类"""
    
    @abstractmethod
    def operation(self) -> str:
        """产品操作接口"""
        pass
    
    @abstractmethod
    def get_info(self) -> Dict[str, Any]:
        """获取产品信息"""
        pass


class ConcreteProductA(Product):
    """具体产品A"""
    
    def __init__(self, name: str = "ProductA"):
        self._name = name
        self._created_at = __import__('datetime').datetime.now()
    
    def operation(self) -> str:
        return f"[{self._name}] 执行产品A的操作"
    
    def get_info(self) -> Dict[str, Any]:
        return {
            "name": self._name,
            "type": "ConcreteProductA",
            "created_at": self._created_at.isoformat()
        }


class ConcreteProductB(Product):
    """具体产品B"""
    
    def __init__(self, version: str = "1.0"):
        self._version = version
        self._created_at = __import__('datetime').datetime.now()
    
    def operation(self) -> str:
        return f"[ProductB v{self._version}] 执行产品B的操作"
    
    def get_info(self) -> Dict[str, Any]:
        return {
            "version": self._version,
            "type": "ConcreteProductB",
            "created_at": self._created_at.isoformat()
        }


class Creator(ABC):
    """
    创建者抽象基类
    
    职责:
        - 声明工厂方法
        - 提供默认的业务逻辑(模板方法)
    """
    
    @abstractmethod
    def factory_method(self) -> Product:
        """工厂方法:由子类实现具体产品创建"""
        pass
    
    def some_operation(self) -> str:
        """
        模板方法:使用工厂方法创建产品并执行操作
        
        注意:此方法不关心具体产品类型,只依赖Product接口
        """
        product = self.factory_method()
        result = product.operation()
        return f"Creator: {result}"
    
    def create_and_configure(self, **kwargs) -> Product:
        """创建并配置产品"""
        product = self.factory_method()
        return product


class ConcreteCreatorA(Creator):
    """具体创建者A:负责创建产品A"""
    
    def __init__(self, product_name: str = "DefaultA"):
        self._product_name = product_name
    
    def factory_method(self) -> Product:
        return ConcreteProductA(self._product_name)


class ConcreteCreatorB(Creator):
    """具体创建者B:负责创建产品B"""
    
    def __init__(self, version: str = "1.0"):
        self._version = version
    
    def factory_method(self) -> Product:
        return ConcreteProductB(self._version)


def client_code(creator: Creator) -> None:
    """
    客户端代码
    
    关键点:客户端只依赖Creator抽象,不依赖具体产品类
    """
    print(f"客户端:使用 {creator.__class__.__name__}")
    print(creator.some_operation())
    print(f"产品信息: {creator.factory_method().get_info()}")


client_code(ConcreteCreatorA("CustomA"))
print("-" * 50)
client_code(ConcreteCreatorB("2.0"))

2.2.2 基于Protocol的现代实现

Python 3.8+引入的Protocol提供了结构化子类型(鸭子类型的静态检查),是接口定义的现代选择:

python
from typing import Protocol, runtime_checkable, TypeVar, Generic, ClassVar
from dataclasses import dataclass, field

@runtime_checkable
class Transport(Protocol):
    """运输工具协议(结构化接口)"""
    
    capacity: ClassVar[int]
    
    def deliver(self, destination: str) -> str:
        """执行配送"""
        ...
    
    def estimate_time(self, distance: float) -> float:
        """估算配送时间"""
        ...


@dataclass
class Truck:
    """卡车运输"""
    
    capacity: ClassVar[int] = 10000
    license_plate: str
    driver: str = "Unknown"
    
    def deliver(self, destination: str) -> str:
        return f"卡车 [{self.license_plate}] 配送到 {destination},司机: {self.driver}"
    
    def estimate_time(self, distance: float) -> float:
        return distance / 60.0


@dataclass
class Ship:
    """轮船运输"""
    
    capacity: ClassVar[int] = 100000
    vessel_name: str
    port: str = "Default Port"
    
    def deliver(self, destination: str) -> str:
        return f"轮船 [{self.vessel_name}] 从 {self.port} 配送到 {destination}"
    
    def estimate_time(self, distance: float) -> float:
        return distance / 30.0


@dataclass
class Airplane:
    """飞机运输"""
    
    capacity: ClassVar[int] = 5000
    flight_number: str
    airline: str = "Default Airline"
    
    def deliver(self, destination: str) -> str:
        return f"航班 [{self.flight_number}] 空运到 {destination}"
    
    def estimate_time(self, distance: float) -> float:
        return distance / 500.0


class Logistics(Protocol):
    """物流抽象协议"""
    
    def create_transport(self) -> Transport:
        """工厂方法:创建运输工具"""
        ...
    
    def plan_delivery(self, destination: str, distance: float) -> Dict[str, Any]:
        """规划配送"""
        ...


@dataclass
class RoadLogistics:
    """公路物流"""
    
    default_driver: str = "Default Driver"
    
    def create_transport(self) -> Transport:
        import uuid
        return Truck(
            license_plate=f"TRUCK-{uuid.uuid4().hex[:6].upper()}",
            driver=self.default_driver
        )
    
    def plan_delivery(self, destination: str, distance: float) -> Dict[str, Any]:
        transport = self.create_transport()
        return {
            "transport_type": "Truck",
            "delivery_info": transport.deliver(destination),
            "estimated_hours": transport.estimate_time(distance),
            "capacity": transport.capacity
        }


@dataclass
class SeaLogistics:
    """海运物流"""
    
    default_port: str = "Shanghai"
    
    def create_transport(self) -> Transport:
        import uuid
        return Ship(
            vessel_name=f"VESSEL-{uuid.uuid4().hex[:6].upper()}",
            port=self.default_port
        )
    
    def plan_delivery(self, destination: str, distance: float) -> Dict[str, Any]:
        transport = self.create_transport()
        return {
            "transport_type": "Ship",
            "delivery_info": transport.deliver(destination),
            "estimated_hours": transport.estimate_time(distance),
            "capacity": transport.capacity
        }


@dataclass
class AirLogistics:
    """空运物流"""
    
    default_airline: str = "Global Air"
    
    def create_transport(self) -> Transport:
        import uuid
        return Airplane(
            flight_number=f"GA{uuid.uuid4().hex[:4].upper()}",
            airline=self.default_airline
        )
    
    def plan_delivery(self, destination: str, distance: float) -> Dict[str, Any]:
        transport = self.create_transport()
        return {
            "transport_type": "Airplane",
            "delivery_info": transport.deliver(destination),
            "estimated_hours": transport.estimate_time(distance),
            "capacity": transport.capacity
        }


def process_logistics(logistics: Logistics, destination: str, distance: float):
    """处理物流配送"""
    plan = logistics.plan_delivery(destination, distance)
    print(f"物流类型: {plan['transport_type']}")
    print(f"配送信息: {plan['delivery_info']}")
    print(f"预计时间: {plan['estimated_hours']:.2f} 小时")
    print(f"运载能力: {plan['capacity']} kg")


road = RoadLogistics("张师傅")
process_logistics(road, "北京", 1200)
print("-" * 50)

air = AirLogistics("东方航空")
process_logistics(air, "纽约", 12000)

2.2.3 参数化工厂方法

参数化工厂方法允许一个创建者根据参数创建不同产品:

python
from enum import Enum, auto
from typing import Dict, Type, Callable, Optional, Any
from dataclasses import dataclass
import threading

class PaymentType(Enum):
    """支付类型枚举"""
    CREDIT_CARD = auto()
    DEBIT_CARD = auto()
    PAYPAL = auto()
    WECHAT = auto()
    ALIPAY = auto()
    BANK_TRANSFER = auto()


@dataclass
class PaymentResult:
    """支付结果"""
    success: bool
    transaction_id: str
    amount: float
    currency: str
    message: str
    timestamp: float = field(default_factory=lambda: __import__('time').time())


class PaymentMethod(ABC):
    """支付方式抽象基类"""
    
    @abstractmethod
    def pay(self, amount: float, currency: str = "CNY") -> PaymentResult:
        pass
    
    @abstractmethod
    def refund(self, transaction_id: str) -> PaymentResult:
        pass
    
    @property
    @abstractmethod
    def fee_rate(self) -> float:
        """手续费率"""
        pass


class CreditCardPayment(PaymentMethod):
    """信用卡支付"""
    
    def __init__(self, card_number: str, cvv: str, expiry: str):
        self._card_number = card_number
        self._cvv = cvv
        self._expiry = expiry
    
    @property
    def fee_rate(self) -> float:
        return 0.025
    
    def pay(self, amount: float, currency: str = "CNY") -> PaymentResult:
        import uuid
        return PaymentResult(
            success=True,
            transaction_id=f"CC-{uuid.uuid4().hex[:12].upper()}",
            amount=amount,
            currency=currency,
            message=f"信用卡支付成功 (尾号: {self._card_number[-4:]})"
        )
    
    def refund(self, transaction_id: str) -> PaymentResult:
        return PaymentResult(
            success=True,
            transaction_id=transaction_id,
            amount=0,
            currency="CNY",
            message="信用卡退款成功"
        )


class WeChatPayment(PaymentMethod):
    """微信支付"""
    
    def __init__(self, openid: str):
        self._openid = openid
    
    @property
    def fee_rate(self) -> float:
        return 0.006
    
    def pay(self, amount: float, currency: str = "CNY") -> PaymentResult:
        import uuid
        return PaymentResult(
            success=True,
            transaction_id=f"WX-{uuid.uuid4().hex[:12].upper()}",
            amount=amount,
            currency=currency,
            message=f"微信支付成功"
        )
    
    def refund(self, transaction_id: str) -> PaymentResult:
        return PaymentResult(
            success=True,
            transaction_id=transaction_id,
            amount=0,
            currency="CNY",
            message="微信退款成功"
        )


class AlipayPayment(PaymentMethod):
    """支付宝支付"""
    
    def __init__(self, user_id: str):
        self._user_id = user_id
    
    @property
    def fee_rate(self) -> float:
        return 0.006
    
    def pay(self, amount: float, currency: str = "CNY") -> PaymentResult:
        import uuid
        return PaymentResult(
            success=True,
            transaction_id=f"ALI-{uuid.uuid4().hex[:12].upper()}",
            amount=amount,
            currency=currency,
            message="支付宝支付成功"
        )
    
    def refund(self, transaction_id: str) -> PaymentResult:
        return PaymentResult(
            success=True,
            transaction_id=transaction_id,
            amount=0,
            currency="CNY",
            message="支付宝退款成功"
        )


class PaymentFactory:
    """
    支付工厂
    
    特性:
        - 参数化工厂方法
        - 注册机制支持扩展
        - 线程安全的注册表
    """
    
    _registry: Dict[PaymentType, Type[PaymentMethod]] = {}
    _lock = threading.RLock()
    
    @classmethod
    def register(cls, payment_type: PaymentType, 
                payment_class: Type[PaymentMethod]) -> None:
        """注册支付方式"""
        with cls._lock:
            cls._registry[payment_type] = payment_class
    
    @classmethod
    def unregister(cls, payment_type: PaymentType) -> None:
        """注销支付方式"""
        with cls._lock:
            cls._registry.pop(payment_type, None)
    
    @classmethod
    def create(cls, payment_type: PaymentType, **kwargs) -> PaymentMethod:
        """
        创建支付方式实例
        
        参数:
            payment_type: 支付类型
            **kwargs: 传递给支付方式构造函数的参数
        """
        with cls._lock:
            if payment_type not in cls._registry:
                raise ValueError(f"不支持的支付类型: {payment_type}")
            
            payment_class = cls._registry[payment_type]
            return payment_class(**kwargs)
    
    @classmethod
    def get_supported_types(cls) -> list:
        """获取支持的支付类型"""
        with cls._lock:
            return list(cls._registry.keys())


PaymentFactory.register(PaymentType.CREDIT_CARD, CreditCardPayment)
PaymentFactory.register(PaymentType.WECHAT, WeChatPayment)
PaymentFactory.register(PaymentType.ALIPAY, AlipayPayment)

credit_card = PaymentFactory.create(
    PaymentType.CREDIT_CARD,
    card_number="6225881234567890",
    cvv="123",
    expiry="12/25"
)
result = credit_card.pay(999.99)
print(f"支付结果: {result.message}, 交易号: {result.transaction_id}")

wechat = PaymentFactory.create(PaymentType.WECHAT, openid="wx_openid_123")
result = wechat.pay(199.00)
print(f"支付结果: {result.message}, 手续费率: {wechat.fee_rate}")

2.2.4 注册表模式与动态发现

python
from typing import Dict, Type, Callable, Any, Optional, List
import importlib
import inspect
from dataclasses import dataclass, field

@dataclass
class HandlerInfo:
    """处理器信息"""
    name: str
    handler_class: Type
    description: str = ""
    tags: List[str] = field(default_factory=list)


class HandlerRegistry:
    """
    处理器注册表
    
    特性:
        - 自动发现和注册
        - 支持元数据
        - 线程安全
    """
    
    _handlers: Dict[str, HandlerInfo] = {}
    _factories: Dict[str, Callable] = {}
    _lock = threading.RLock()
    
    @classmethod
    def register(cls, name: str, description: str = "", 
                tags: List[str] = None) -> Callable:
        """
        装饰器方式注册处理器
        
        用法:
            @HandlerRegistry.register("my_handler", "描述", ["tag1"])
            class MyHandler:
                pass
        """
        def decorator(handler_class: Type) -> Type:
            with cls._lock:
                cls._handlers[name] = HandlerInfo(
                    name=name,
                    handler_class=handler_class,
                    description=description,
                    tags=tags or []
                )
            return handler_class
        return decorator
    
    @classmethod
    def register_factory(cls, name: str, factory: Callable) -> None:
        """注册工厂函数"""
        with cls._lock:
            cls._factories[name] = factory
    
    @classmethod
    def get(cls, name: str) -> Optional[HandlerInfo]:
        """获取处理器信息"""
        return cls._handlers.get(name)
    
    @classmethod
    def create(cls, name: str, *args, **kwargs) -> Any:
        """创建处理器实例"""
        with cls._lock:
            if name in cls._factories:
                return cls._factories[name](*args, **kwargs)
            
            if name in cls._handlers:
                return cls._handlers[name].handler_class(*args, **kwargs)
            
            raise ValueError(f"未注册的处理器: {name}")
    
    @classmethod
    def list_handlers(cls, tag: str = None) -> List[HandlerInfo]:
        """列出所有处理器"""
        with cls._lock:
            handlers = list(cls._handlers.values())
            if tag:
                handlers = [h for h in handlers if tag in h.tags]
            return handlers
    
    @classmethod
    def auto_discover(cls, module_name: str) -> int:
        """
        自动发现模块中的处理器
        
        发现规则:
            - 类名以 Handler 结尾
            - 实现特定接口
        """
        count = 0
        try:
            module = importlib.import_module(module_name)
            for name, obj in inspect.getmembers(module, inspect.isclass):
                if name.endswith('Handler') and obj.__module__ == module_name:
                    handler_name = name.replace('Handler', '').lower()
                    cls._handlers[handler_name] = HandlerInfo(
                        name=handler_name,
                        handler_class=obj
                    )
                    count += 1
        except ImportError:
            pass
        return count


class MessageHandler(ABC):
    """消息处理器基类"""
    
    @abstractmethod
    def handle(self, message: dict) -> dict:
        pass


@HandlerRegistry.register("email", "邮件处理器", ["notification", "async"])
class EmailHandler(MessageHandler):
    """邮件处理器"""
    
    def __init__(self, smtp_server: str = "localhost"):
        self._smtp_server = smtp_server
    
    def handle(self, message: dict) -> dict:
        return {
            "status": "sent",
            "channel": "email",
            "recipient": message.get("email"),
            "smtp": self._smtp_server
        }


@HandlerRegistry.register("sms", "短信处理器", ["notification"])
class SMSHandler(MessageHandler):
    """短信处理器"""
    
    def handle(self, message: dict) -> dict:
        return {
            "status": "sent",
            "channel": "sms",
            "recipient": message.get("phone")
        }


@HandlerRegistry.register("push", "推送处理器", ["notification", "mobile"])
class PushHandler(MessageHandler):
    """推送处理器"""
    
    def handle(self, message: dict) -> dict:
        return {
            "status": "sent",
            "channel": "push",
            "device": message.get("device_id")
        }


print("已注册处理器:")
for info in HandlerRegistry.list_handlers():
    print(f"  - {info.name}: {info.description} [{', '.join(info.tags)}]")

email_handler = HandlerRegistry.create("email", smtp_server="smtp.example.com")
result = email_handler.handle({"email": "user@example.com", "subject": "测试"})
print(f"\n处理结果: {result}")

2.3 实际应用案例

2.3.1 企业级数据库连接工厂

python
from typing import Dict, Type, Any, Optional, ContextManager
from dataclasses import dataclass, field
from contextlib import contextmanager
from enum import Enum, auto
import threading
import time

class DatabaseType(Enum):
    MYSQL = auto()
    POSTGRESQL = auto()
    SQLITE = auto()
    MONGODB = auto()
    REDIS = auto()


@dataclass
class ConnectionConfig:
    """连接配置"""
    host: str = "localhost"
    port: int = 3306
    database: str = "test"
    username: str = "root"
    password: str = ""
    charset: str = "utf8mb4"
    pool_size: int = 5
    timeout: float = 30.0
    
    def to_connection_string(self, db_type: DatabaseType) -> str:
        """生成连接字符串"""
        if db_type == DatabaseType.MYSQL:
            return f"mysql://{self.username}:{self.password}@{self.host}:{self.port}/{self.database}"
        elif db_type == DatabaseType.POSTGRESQL:
            return f"postgresql://{self.username}:{self.password}@{self.host}:{self.port}/{self.database}"
        elif db_type == DatabaseType.SQLITE:
            return f"sqlite:///{self.database}.db"
        elif db_type == DatabaseType.MONGODB:
            return f"mongodb://{self.username}:{self.password}@{self.host}:{self.port}/{self.database}"
        elif db_type == DatabaseType.REDIS:
            return f"redis://{self.host}:{self.port}/{self.database}"
        return ""


@dataclass
class QueryResult:
    """查询结果"""
    success: bool
    data: Any = None
    affected_rows: int = 0
    error: Optional[str] = None
    execution_time: float = 0.0


class Connection(ABC):
    """数据库连接抽象基类"""
    
    def __init__(self, config: ConnectionConfig):
        self._config = config
        self._connected = False
        self._last_query_time = 0.0
    
    @abstractmethod
    def connect(self) -> bool:
        """建立连接"""
        pass
    
    @abstractmethod
    def disconnect(self) -> None:
        """断开连接"""
        pass
    
    @abstractmethod
    def execute(self, query: str, params: tuple = None) -> QueryResult:
        """执行查询"""
        pass
    
    @abstractmethod
    def begin_transaction(self) -> None:
        """开始事务"""
        pass
    
    @abstractmethod
    def commit(self) -> None:
        """提交事务"""
        pass
    
    @abstractmethod
    def rollback(self) -> None:
        """回滚事务"""
        pass
    
    @property
    def is_connected(self) -> bool:
        return self._connected
    
    @contextmanager
    def transaction(self):
        """事务上下文管理器"""
        self.begin_transaction()
        try:
            yield self
            self.commit()
        except Exception as e:
            self.rollback()
            raise


class MySQLConnection(Connection):
    """MySQL连接"""
    
    def connect(self) -> bool:
        print(f"[MySQL] 连接到 {self._config.host}:{self._config.port}")
        self._connected = True
        return True
    
    def disconnect(self) -> None:
        print("[MySQL] 断开连接")
        self._connected = False
    
    def execute(self, query: str, params: tuple = None) -> QueryResult:
        start = time.time()
        print(f"[MySQL] 执行: {query}")
        return QueryResult(
            success=True,
            data=[{"id": 1, "name": "test"}],
            affected_rows=1,
            execution_time=time.time() - start
        )
    
    def begin_transaction(self) -> None:
        print("[MySQL] 开始事务")
    
    def commit(self) -> None:
        print("[MySQL] 提交事务")
    
    def rollback(self) -> None:
        print("[MySQL] 回滚事务")


class PostgreSQLConnection(Connection):
    """PostgreSQL连接"""
    
    def connect(self) -> bool:
        print(f"[PostgreSQL] 连接到 {self._config.host}:{self._config.port}")
        self._connected = True
        return True
    
    def disconnect(self) -> None:
        print("[PostgreSQL] 断开连接")
        self._connected = False
    
    def execute(self, query: str, params: tuple = None) -> QueryResult:
        start = time.time()
        print(f"[PostgreSQL] 执行: {query}")
        return QueryResult(
            success=True,
            data=[{"id": 1, "name": "test"}],
            affected_rows=1,
            execution_time=time.time() - start
        )
    
    def begin_transaction(self) -> None:
        print("[PostgreSQL] 开始事务")
    
    def commit(self) -> None:
        print("[PostgreSQL] 提交事务")
    
    def rollback(self) -> None:
        print("[PostgreSQL] 回滚事务")


class SQLiteDatabaseConnection(Connection):
    """SQLite连接"""
    
    def connect(self) -> bool:
        print(f"[SQLite] 连接到 {self._config.database}.db")
        self._connected = True
        return True
    
    def disconnect(self) -> None:
        print("[SQLite] 断开连接")
        self._connected = False
    
    def execute(self, query: str, params: tuple = None) -> QueryResult:
        start = time.time()
        print(f"[SQLite] 执行: {query}")
        return QueryResult(
            success=True,
            data=[],
            affected_rows=0,
            execution_time=time.time() - start
        )
    
    def begin_transaction(self) -> None:
        print("[SQLite] 开始事务")
    
    def commit(self) -> None:
        print("[SQLite] 提交事务")
    
    def rollback(self) -> None:
        print("[SQLite] 回滚事务")


class DatabaseFactory:
    """
    数据库连接工厂
    
    特性:
        - 支持多种数据库类型
        - 连接池管理
        - 配置驱动创建
    """
    
    _connection_classes: Dict[DatabaseType, Type[Connection]] = {
        DatabaseType.MYSQL: MySQLConnection,
        DatabaseType.POSTGRESQL: PostgreSQLConnection,
        DatabaseType.SQLITE: SQLiteDatabaseConnection,
    }
    
    _connection_pools: Dict[str, list] = {}
    _lock = threading.RLock()
    
    @classmethod
    def create_connection(cls, db_type: DatabaseType, 
                         config: Optional[ConnectionConfig] = None) -> Connection:
        """创建数据库连接"""
        if db_type not in cls._connection_classes:
            raise ValueError(f"不支持的数据库类型: {db_type}")
        
        connection_class = cls._connection_classes[db_type]
        connection = connection_class(config or ConnectionConfig())
        connection.connect()
        return connection
    
    @classmethod
    def register_database(cls, db_type: DatabaseType, 
                         connection_class: Type[Connection]) -> None:
        """注册新的数据库类型"""
        cls._connection_classes[db_type] = connection_class
    
    @classmethod
    @contextmanager
    def connection(cls, db_type: DatabaseType, 
                  config: Optional[ConnectionConfig] = None):
        """连接上下文管理器"""
        conn = None
        try:
            conn = cls.create_connection(db_type, config)
            yield conn
        finally:
            if conn:
                conn.disconnect()
    
    @classmethod
    def get_supported_databases(cls) -> List[DatabaseType]:
        """获取支持的数据库类型"""
        return list(cls._connection_classes.keys())


config = ConnectionConfig(
    host="192.168.1.100",
    port=3306,
    database="production",
    username="admin",
    password="secret"
)

with DatabaseFactory.connection(DatabaseType.MYSQL, config) as conn:
    result = conn.execute("SELECT * FROM users WHERE status = %s", ("active",))
    print(f"查询结果: {result.data}")

print(f"支持的数据库: {[db.name for db in DatabaseFactory.get_supported_databases()]}")

2.3.2 插件架构中的工厂方法

python
from typing import Dict, Type, Any, Optional, List, Callable
from dataclasses import dataclass, field
from pathlib import Path
import importlib.util
import sys

@dataclass
class PluginInfo:
    """插件信息"""
    name: str
    version: str
    description: str
    author: str = ""
    dependencies: List[str] = field(default_factory=list)
    enabled: bool = True


class Plugin(ABC):
    """插件抽象基类"""
    
    @property
    @abstractmethod
    def info(self) -> PluginInfo:
        """插件信息"""
        pass
    
    @abstractmethod
    def initialize(self, context: Dict[str, Any]) -> None:
        """初始化插件"""
        pass
    
    @abstractmethod
    def execute(self, *args, **kwargs) -> Any:
        """执行插件功能"""
        pass
    
    @abstractmethod
    def shutdown(self) -> None:
        """关闭插件"""
        pass


class PluginFactory:
    """
    插件工厂
    
    特性:
        - 动态加载插件
        - 插件生命周期管理
        - 依赖解析
    """
    
    _plugins: Dict[str, Type[Plugin]] = {}
    _instances: Dict[str, Plugin] = {}
    _lock = threading.RLock()
    
    @classmethod
    def register(cls, plugin_class: Type[Plugin]) -> None:
        """注册插件类"""
        with cls._lock:
            temp_instance = plugin_class()
            info = temp_instance.info
            cls._plugins[info.name] = plugin_class
            print(f"[PluginFactory] 注册插件: {info.name} v{info.version}")
    
    @classmethod
    def create(cls, name: str, context: Dict[str, Any] = None) -> Plugin:
        """创建插件实例"""
        with cls._lock:
            if name not in cls._plugins:
                raise ValueError(f"未注册的插件: {name}")
            
            if name in cls._instances:
                return cls._instances[name]
            
            plugin_class = cls._plugins[name]
            instance = plugin_class()
            instance.initialize(context or {})
            cls._instances[name] = instance
            return instance
    
    @classmethod
    def load_from_file(cls, file_path: str) -> Optional[Type[Plugin]]:
        """从文件加载插件"""
        path = Path(file_path)
        if not path.exists():
            print(f"[PluginFactory] 文件不存在: {file_path}")
            return None
        
        module_name = f"plugin_{path.stem}"
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        if spec and spec.loader:
            module = importlib.util.module_from_spec(spec)
            sys.modules[module_name] = module
            spec.loader.exec_module(module)
            
            for name, obj in vars(module).items():
                if isinstance(obj, type) and issubclass(obj, Plugin) and obj is not Plugin:
                    cls.register(obj)
                    return obj
        
        return None
    
    @classmethod
    def get_plugin(cls, name: str) -> Optional[Plugin]:
        """获取已创建的插件实例"""
        return cls._instances.get(name)
    
    @classmethod
    def list_plugins(cls) -> List[PluginInfo]:
        """列出所有插件"""
        with cls._lock:
            return [
                cls._plugins[name]().info 
                for name in cls._plugins
            ]
    
    @classmethod
    def shutdown_all(cls) -> None:
        """关闭所有插件"""
        with cls._lock:
            for name, instance in cls._instances.items():
                try:
                    instance.shutdown()
                    print(f"[PluginFactory] 关闭插件: {name}")
                except Exception as e:
                    print(f"[PluginFactory] 关闭插件 {name} 失败: {e}")
            cls._instances.clear()


class DataProcessorPlugin(Plugin):
    """数据处理器插件"""
    
    @property
    def info(self) -> PluginInfo:
        return PluginInfo(
            name="data_processor",
            version="1.0.0",
            description="数据处理插件",
            author="Developer"
        )
    
    def initialize(self, context: Dict[str, Any]) -> None:
        self._config = context.get("config", {})
        print(f"[DataProcessor] 初始化完成")
    
    def execute(self, data: List[Any], operation: str = "transform") -> Any:
        if operation == "transform":
            return [item * 2 for item in data]
        elif operation == "filter":
            return [item for item in data if item > 0]
        return data
    
    def shutdown(self) -> None:
        print("[DataProcessor] 关闭")


class NotificationPlugin(Plugin):
    """通知插件"""
    
    @property
    def info(self) -> PluginInfo:
        return PluginInfo(
            name="notification",
            version="2.0.0",
            description="通知发送插件",
            author="Developer",
            dependencies=["data_processor"]
        )
    
    def initialize(self, context: Dict[str, Any]) -> None:
        self._channels = context.get("channels", ["email"])
        print(f"[Notification] 初始化完成,通道: {self._channels}")
    
    def execute(self, message: str, recipients: List[str]) -> Dict:
        return {
            "message": message,
            "recipients": recipients,
            "channels": self._channels,
            "status": "sent"
        }
    
    def shutdown(self) -> None:
        print("[Notification] 关闭")


PluginFactory.register(DataProcessorPlugin)
PluginFactory.register(NotificationPlugin)

print("已注册插件:")
for info in PluginFactory.list_plugins():
    print(f"  - {info.name} v{info.version}: {info.description}")

processor = PluginFactory.create("data_processor", {"config": {"mode": "advanced"}})
result = processor.execute([1, 2, 3, 4, 5], operation="transform")
print(f"\n处理结果: {result}")

notification = PluginFactory.create("notification", {"channels": ["email", "sms"]})
result = notification.execute("测试消息", ["user@example.com"])
print(f"通知结果: {result}")

PluginFactory.shutdown_all()

2.4 简单工厂 vs 工厂方法 vs 抽象工厂

2.4.1 三种工厂模式对比

┌─────────────────────────────────────────────────────────────────┐
│                    工厂模式层次结构                              │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  简单工厂                                                       │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │  一个工厂类,根据参数创建不同产品                         │   │
│  │  违反开闭原则,适合产品种类固定的场景                     │   │
│  └─────────────────────────────────────────────────────────┘   │
│                           │                                     │
│                           ▼                                     │
│  工厂方法                                                       │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │  工厂抽象 + 多个具体工厂                                 │   │
│  │  每个工厂只创建一种产品                                  │   │
│  │  遵循开闭原则,适合产品种类可扩展                        │   │
│  └─────────────────────────────────────────────────────────┘   │
│                           │                                     │
│                           ▼                                     │
│  抽象工厂                                                       │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │  工厂抽象 + 多个具体工厂                                 │   │
│  │  每个工厂创建一系列相关产品                              │   │
│  │  产品族概念,适合跨平台/主题场景                         │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

2.4.2 简单工厂实现

python
from typing import Dict, Type, Any

class SimpleFactory:
    """
    简单工厂
    
    特点:
        - 集中创建逻辑
        - 通过参数选择产品
        - 违反开闭原则
    """
    
    _products: Dict[str, Type] = {}
    
    @classmethod
    def register(cls, product_type: str, product_class: Type) -> None:
        cls._products[product_type] = product_class
    
    @classmethod
    def create(cls, product_type: str, *args, **kwargs) -> Any:
        if product_type not in cls._products:
            raise ValueError(f"未知产品类型: {product_type}")
        return cls._products[product_type](*args, **kwargs)


class SimpleFactoryWithSwitch:
    """
    简单工厂(switch版本)
    
    问题: 新增产品需要修改工厂类
    """
    
    @staticmethod
    def create(product_type: str) -> 'Product':
        if product_type == "A":
            return ConcreteProductA()
        elif product_type == "B":
            return ConcreteProductB()
        else:
            raise ValueError(f"未知产品类型: {product_type}")

2.4.3 模式选择决策表

场景推荐模式理由
产品种类固定且少简单工厂实现简单,够用即可
产品种类可能扩展工厂方法遵循开闭原则
需要创建产品族抽象工厂保证产品一致性
框架/库开发工厂方法允许用户扩展
配置驱动创建参数化工厂灵活配置

2.5 反模式与最佳实践

2.5.1 常见反模式

反模式1:工厂承担过多职责

python
class BadFactory:
    """错误示例:工厂承担了创建、验证、持久化等多个职责"""
    
    @staticmethod
    def create_and_save(product_type: str, data: dict):
        product = create_product(product_type)
        validate_product(product)
        save_to_database(product)
        send_notification(product)
        return product


class GoodFactory:
    """正确示例:工厂只负责创建"""
    
    @staticmethod
    def create(product_type: str, **kwargs) -> Product:
        return create_product(product_type, **kwargs)

反模式2:过度使用工厂

python
class User:
    """错误示例:简单值对象不需要工厂"""
    
    def __init__(self, name: str, email: str):
        self.name = name
        self.email = email


class UserFactory:
    """过度设计:简单对象不需要工厂"""
    
    @staticmethod
    def create(name: str, email: str) -> User:
        return User(name, email)

2.5.2 最佳实践

python
from typing import Protocol, TypeVar, Generic, Dict, Type, Callable, Any
from dataclasses import dataclass
import threading

T = TypeVar('T')

class Factory(Protocol[T]):
    """工厂协议"""
    
    def create(self, **kwargs) -> T:
        ...


@dataclass
class FactoryConfig:
    """工厂配置"""
    default_args: Dict[str, Any] = None
    validators: Dict[str, Callable] = None
    post_processors: List[Callable] = None


class ConfigurableFactory(Generic[T]):
    """
    可配置的泛型工厂
    
    特性:
        - 支持默认参数
        - 支持验证器
        - 支持后处理器
        - 线程安全
    """
    
    def __init__(self, product_class: Type[T], config: FactoryConfig = None):
        self._product_class = product_class
        self._config = config or FactoryConfig()
        self._lock = threading.RLock()
    
    def create(self, **kwargs) -> T:
        """创建产品实例"""
        with self._lock:
            merged_kwargs = self._merge_args(kwargs)
            
            if self._config.validators:
                self._validate(merged_kwargs)
            
            instance = self._product_class(**merged_kwargs)
            
            if self._config.post_processors:
                for processor in self._config.post_processors:
                    instance = processor(instance)
            
            return instance
    
    def _merge_args(self, kwargs: Dict) -> Dict:
        """合并默认参数和传入参数"""
        if self._config.default_args:
            return {**self._config.default_args, **kwargs}
        return kwargs
    
    def _validate(self, kwargs: Dict) -> None:
        """验证参数"""
        for key, validator in self._config.validators.items():
            if key in kwargs and not validator(kwargs[key]):
                raise ValueError(f"参数 {key} 验证失败")


class ProductRegistry:
    """
    产品注册表
    
    集中管理产品类型和工厂
    """
    
    _factories: Dict[str, ConfigurableFactory] = {}
    _lock = threading.RLock()
    
    @classmethod
    def register(cls, name: str, factory: ConfigurableFactory) -> None:
        with cls._lock:
            cls._factories[name] = factory
    
    @classmethod
    def create(cls, name: str, **kwargs) -> Any:
        with cls._lock:
            if name not in cls._factories:
                raise ValueError(f"未注册的产品: {name}")
            return cls._factories[name].create(**kwargs)
    
    @classmethod
    def get_factory(cls, name: str) -> Optional[ConfigurableFactory]:
        return cls._factories.get(name)

2.6 模式评估与决策指南

2.6.1 适用性检查清单

检查项说明
产品类型可能变化?✓ 使用✗ 评估需要扩展性
创建逻辑复杂?✓ 使用✗ 评估需要封装
客户端不应知道具体类?✓ 使用✗ 评估需要解耦
需要延迟创建决策?✓ 使用✗ 评估运行时决定
产品种类很少且固定?✗ 评估✓ 使用简单工厂足够

2.6.2 性能考量

实现方式创建开销内存占用适用场景
直接实例化最低最低简单场景
简单工厂产品固定
工厂方法需要扩展
注册表工厂动态发现
反射工厂插件系统

2.7 小结

工厂方法模式通过延迟绑定多态创建实现了创建者与产品的解耦。在Python中,可以利用动态特性简化实现:

┌─────────────────────────────────────────────────────────────────┐
│                    工厂方法实现选择指南                          │
│                                                                 │
│                    产品种类是否固定?                            │
│                         │                                       │
│              ┌──────────┴──────────┐                            │
│              ▼                     ▼                            │
│            是的                   不确定                         │
│              │                     │                            │
│              ▼                     ▼                            │
│        简单工厂              需要创建产品族?                     │
│                                   │                             │
│                         ┌─────────┴─────────┐                   │
│                         ▼                   ▼                   │
│                       是的               不需要                  │
│                         │                   │                   │
│                         ▼                   ▼                   │
│                    抽象工厂           工厂方法                   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

关键要点

  1. 核心价值:将对象创建与使用分离,支持开闭原则
  2. Python特性:利用Protocol、装饰器、注册表简化实现
  3. 避免过度设计:简单场景使用简单工厂或直接实例化
  4. 测试友好:工厂方法便于注入Mock对象
  5. 扩展机制:注册表模式支持动态扩展

在下一章中,我们将探讨抽象工厂模式,学习如何创建相关的对象家族。

Python技术丛书 - 江苏省宿城中等专业学校