设计原则详解
学习目标
- 深入理解SOLID原则的形式化定义与数学基础
- 掌握各类设计原则的理论依据与应用场景
- 学会在实际项目中权衡和应用设计原则
- 理解设计原则与设计模式之间的内在联系
历史背景
软件设计原则的演进
| 时期 | 里程碑 | 代表性贡献 |
|---|---|---|
| 1970s | 结构化编程 | Dijkstra的GOTO有害论,模块化设计 |
| 1980s | 面向对象 | 抽象、封装、继承、多态的确立 |
| 1990s | 设计模式 | GoF 23种设计模式,SOLID原则形成 |
| 2000s | 敏捷原则 | DRY、KISS、YAGNI等原则普及 |
| 2010s | 函数式原则 | 不可变性、纯函数、副作用管理 |
| 2020s | 云原生原则 | 微服务、容错、可观测性设计 |
SOLID原则的诞生
SOLID原则由Robert C. Martin(Uncle Bob)在2000年代初期系统化整理,其五个原则首字母组成了SOLID这个助记符:
- Single Responsibility Principle(单一职责原则)
- Open/Closed Principle(开闭原则)
- Liskov Substitution Principle(里氏替换原则)
- Interface Segregation Principle(接口隔离原则)
- Dependency Inversion Principle(依赖倒置原则)
形式化定义
内聚与耦合
定义A.1(内聚度) 模块内聚度 $C$ 定义为模块内各元素相关程度的度量:
$$C = \frac{\sum_{i,j \in M} r_{ij}}{|M| \cdot (|M| - 1)}$$
其中 $r_{ij}$ 表示元素 $i$ 和 $j$ 的相关度,$|M|$ 是模块大小。
定义A.2(耦合度) 模块间耦合度 $D$ 定义为模块间依赖关系的强度:
$$D = \frac{\sum_{i \neq j} dep(M_i, M_j)}{n \cdot (n-1)}$$
其中 $dep(M_i, M_j)$ 表示模块 $M_i$ 对 $M_j$ 的依赖强度。
设计目标:最大化内聚,最小化耦合。
单一职责原则(SRP)
定义A.3(单一职责原则) 一个类 $C$ 应该只有一个变化原因 $R$:
$$\forall C: |{R : R \text{ causes change in } C}| = 1$$
形式化表述:如果类 $C$ 有职责集合 $Resp(C) = {r_1, r_2, ..., r_n}$,则:
$$SRP(C) \iff |Resp(C)| = 1$$
定理A.1(SRP复杂度) 违反SRP的类,其修改复杂度随职责数量平方增长:
$$Complexity(C) = O(|Resp(C)|^2)$$
开闭原则(OCP)
定义A.4(开闭原则) 软件实体 $E$ 应满足:
$$\forall ext: Extension(ext) \implies \neg Modification(E)$$
即对于任何扩展 $ext$,实体 $E$ 无需修改即可支持。
形式化表述:
$$OCP(E) \iff \forall f' \in Extension(f): \exists g: f' = g \circ f \land \neg Modified(E)$$
里氏替换原则(LSP)
定义A.5(里氏替换原则) 设 $\phi(x)$ 是关于类型 $T$ 的可证明性质,则对于所有 $S \subseteq T$:
$$\forall o \in S: \phi(o)$$
即子类对象必须能够替换父类对象而不破坏程序正确性。
行为子类型:
$$LSP(S, T) \iff \forall p: {pre_T(p) \implies pre_S(p)} \land {post_S(p) \implies post_T(p)}$$
接口隔离原则(ISP)
定义A.6(接口隔离原则) 客户端 $C$ 不应依赖它不使用的方法:
$$\forall C, I: Methods(C, I) = Methods(Used(C, I))$$
形式化表述:
$$ISP(I) \iff \forall C: deps(C, I) = \bigcap_{m \in used(C, I)} {m}$$
依赖倒置原则(DIP)
定义A.7(依赖倒置原则) 高层模块 $H$ 和低层模块 $L$ 都应依赖抽象 $A$:
$$deps(H, L) = \emptyset \land deps(H, A) \land deps(L, A)$$
依赖方向:
$$DIP \iff \forall H, L: H \not\to L \land H \to A \land L \to A$$
SOLID原则详解
单一职责原则(SRP)
核心思想:一个类应该只有一个引起它变化的原因。
┌─────────────────────────────────────────────────────────────────────────┐
│ 单一职责原则示意图 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 违反SRP: 符合SRP: │
│ ┌──────────────────┐ ┌──────────────┐ │
│ │ User │ │ User │ │
│ ├──────────────────┤ ├──────────────┤ │
│ │ - name │ │ - name │ │
│ │ - email │ │ - email │ │
│ ├──────────────────┤ └──────────────┘ │
│ │ + save() │ │ │
│ │ + validate() │ │ │
│ │ + notify() │ ┌──────┴──────┬──────────┐ │
│ │ + generateReport │ │ │ │ │
│ └──────────────────┘ ┌────▼───┐ ┌────▼───┐ ┌───▼────┐ │
│ │ Repo │ │Validator│ │Notifier│ │
│ ├────────┤ ├────────┤ ├────────┤ │
│ │+save() │ │+valid()│ │+notify()│ │
│ └────────┘ └────────┘ └────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘from typing import Protocol, Optional
from dataclasses import dataclass
from abc import ABC, abstractmethod
@dataclass
class User:
name: str
email: str
def __post_init__(self):
if not self.name:
raise ValueError("用户名不能为空")
class UserRepository(Protocol):
def save(self, user: User) -> None: ...
def find(self, user_id: int) -> Optional[User]: ...
def delete(self, user_id: int) -> None: ...
class InMemoryUserRepository:
def __init__(self):
self._users: dict[int, User] = {}
self._next_id = 1
def save(self, user: User) -> int:
user_id = self._next_id
self._users[user_id] = user
self._next_id += 1
return user_id
def find(self, user_id: int) -> Optional[User]:
return self._users.get(user_id)
def delete(self, user_id: int) -> None:
self._users.pop(user_id, None)
class UserValidator:
def validate(self, user: User) -> bool:
if not user.name or len(user.name) < 2:
return False
if '@' not in user.email:
return False
return True
def validate_email(self, email: str) -> bool:
import re
pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
return bool(re.match(pattern, email))
class UserNotifier:
def __init__(self, email_sender: 'EmailSender'):
self._email_sender = email_sender
def notify(self, user: User, message: str) -> None:
self._email_sender.send(user.email, message)
def notify_welcome(self, user: User) -> None:
self.notify(user, f"欢迎 {user.name}!")
class EmailSender(Protocol):
def send(self, to: str, message: str) -> None: ...
class ConsoleEmailSender:
def send(self, to: str, message: str) -> None:
print(f"[Email] To: {to}\n{message}")
class UserService:
def __init__(self,
repository: UserRepository,
validator: UserValidator,
notifier: UserNotifier):
self._repository = repository
self._validator = validator
self._notifier = notifier
def register_user(self, name: str, email: str) -> Optional[int]:
user = User(name=name, email=email)
if not self._validator.validate(user):
return None
user_id = self._repository.save(user)
self._notifier.notify_welcome(user)
return user_id开闭原则(OCP)
核心思想:软件实体应该对扩展开放,对修改关闭。
┌─────────────────────────────────────────────────────────────────────────┐
│ 开闭原则示意图 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 违反OCP: 符合OCP: │
│ ┌──────────────────┐ ┌──────────────────┐ │
│ │ PriceCalculator │ │ <<interface>> │ │
│ ├──────────────────┤ │ DiscountStrategy │ │
│ │ - type: str │ ├──────────────────┤ │
│ ├──────────────────┤ │ + calculate() │ │
│ │ + calculate() │ └────────┬─────────┘ │
│ │ if type==A: │ │ │
│ │ ... │ ┌───────────┼───────────┐ │
│ │ elif type==B: │ │ │ │ │
│ │ ... │ ┌────▼───┐ ┌────▼───┐ ┌────▼───┐ │
│ │ # 新增需修改 │ │NoDisc │ │Percent │ │Fixed │ │
│ └──────────────────┘ └────────┘ └────────┘ └────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘from abc import ABC, abstractmethod
from typing import List, Callable
from dataclasses import dataclass
from enum import Enum, auto
class DiscountStrategy(ABC):
@abstractmethod
def calculate(self, price: float) -> float:
pass
@property
@abstractmethod
def name(self) -> str:
pass
class NoDiscount(DiscountStrategy):
@property
def name(self) -> str:
return "无折扣"
def calculate(self, price: float) -> float:
return price
class PercentageDiscount(DiscountStrategy):
def __init__(self, percentage: float):
if not 0 <= percentage <= 100:
raise ValueError("折扣百分比必须在0-100之间")
self._percentage = percentage
@property
def name(self) -> str:
return f"{self._percentage}%折扣"
def calculate(self, price: float) -> float:
return price * (1 - self._percentage / 100)
class FixedDiscount(DiscountStrategy):
def __init__(self, amount: float):
if amount < 0:
raise ValueError("折扣金额不能为负")
self._amount = amount
@property
def name(self) -> str:
return f"固定折扣{self._amount}元"
def calculate(self, price: float) -> float:
return max(0, price - self._amount)
class TieredDiscount(DiscountStrategy):
def __init__(self, tiers: List[tuple[float, float]]):
self._tiers = sorted(tiers, key=lambda x: x[0], reverse=True)
@property
def name(self) -> str:
return "阶梯折扣"
def calculate(self, price: float) -> float:
for threshold, rate in self._tiers:
if price >= threshold:
return price * (1 - rate / 100)
return price
class BuyXGetYDiscount(DiscountStrategy):
def __init__(self, buy_count: int, get_count: int, unit_price: float):
self._buy = buy_count
self._get = get_count
self._unit_price = unit_price
@property
def name(self) -> str:
return f"买{self._buy}送{self._get}"
def calculate(self, price: float) -> float:
total_items = int(price / self._unit_price)
free_items = (total_items // (self._buy + self._get)) * self._get
return price - free_items * self._unit_price
class PriceCalculator:
def __init__(self, strategy: DiscountStrategy):
self._strategy = strategy
def set_strategy(self, strategy: DiscountStrategy) -> None:
self._strategy = strategy
def calculate(self, price: float) -> float:
return self._strategy.calculate(price)
def calculate_with_details(self, price: float) -> dict:
final_price = self._strategy.calculate(price)
return {
'original_price': price,
'discount_name': self._strategy.name,
'final_price': final_price,
'savings': price - final_price
}
class DiscountStrategyFactory:
_strategies: dict[str, type[DiscountStrategy]] = {}
@classmethod
def register(cls, name: str, strategy_class: type[DiscountStrategy]):
cls._strategies[name] = strategy_class
@classmethod
def create(cls, name: str, *args, **kwargs) -> DiscountStrategy:
if name not in cls._strategies:
raise ValueError(f"未知策略: {name}")
return cls._strategies[name](*args, **kwargs)
DiscountStrategyFactory.register('none', NoDiscount)
DiscountStrategyFactory.register('percentage', PercentageDiscount)
DiscountStrategyFactory.register('fixed', FixedDiscount)里氏替换原则(LSP)
核心思想:子类对象必须能够替换掉所有父类对象而不破坏程序正确性。
┌─────────────────────────────────────────────────────────────────────────┐
│ 里氏替换原则示意图 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 违反LSP: 符合LSP: │
│ ┌──────────────────┐ ┌──────────────────┐ │
│ │ Bird │ │ Bird │ │
│ ├──────────────────┤ ├──────────────────┤ │
│ │ + fly() │ │ + move() │ │
│ └────────┬─────────┘ └────────┬─────────┘ │
│ │ │ │
│ ┌─────┴─────┐ ┌─────┴─────┐ │
│ │ │ │ │ │
│ ┌──▼───┐ ┌───▼───┐ ┌───▼───┐ ┌───▼───┐ │
│ │Sparrow│ │Penguin│ │Flying │ │Swimming│ │
│ │ │ │ │ │ Bird │ │ Bird │ │
│ │+fly() │ │+fly() │◄──错误! ├───────┤ ├────────┤ │
│ └───────┘ │抛出异常│ │+fly() │ │+swim() │ │
│ └────────┘ └───────┘ └────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘from abc import ABC, abstractmethod
from typing import Protocol
class Bird(ABC):
def __init__(self, name: str):
self.name = name
@abstractmethod
def move(self) -> str:
pass
def eat(self) -> str:
return f"{self.name} 正在进食"
class FlyingBird(Bird):
@abstractmethod
def fly(self) -> str:
pass
def move(self) -> str:
return self.fly()
class SwimmingBird(Bird):
@abstractmethod
def swim(self) -> str:
pass
def move(self) -> str:
return self.swim()
class Sparrow(FlyingBird):
def fly(self) -> str:
return f"{self.name} 在空中飞翔"
class Eagle(FlyingBird):
def __init__(self, name: str, wingspan: float):
super().__init__(name)
self.wingspan = wingspan
def fly(self) -> str:
return f"{self.name} 展开{self.wingspan}米翅膀高飞"
class Penguin(SwimmingBird):
def swim(self) -> str:
return f"{self.name} 在水中游泳"
class Duck(Bird):
def move(self) -> str:
return f"{self.name} 可以飞也可以游"
def fly(self) -> str:
return f"{self.name} 在低空飞行"
def swim(self) -> str:
return f"{self.name} 在水面游泳"
def make_bird_move(bird: Bird) -> str:
return bird.move()
def make_flying_bird_fly(bird: FlyingBird) -> str:
return bird.fly()
sparrow = Sparrow("小麻雀")
eagle = Eagle("雄鹰", 2.0)
penguin = Penguin("企鹅")
duck = Duck("鸭子")
print(make_bird_move(sparrow))
print(make_bird_move(penguin))
print(make_bird_move(duck))
print(make_flying_bird_fly(sparrow))
print(make_flying_bird_fly(eagle))接口隔离原则(ISP)
核心思想:客户端不应该依赖它不需要的接口。
┌─────────────────────────────────────────────────────────────────────────┐
│ 接口隔离原则示意图 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 违反ISP: 符合ISP: │
│ ┌──────────────────┐ ┌──────────────┐ │
│ │ <<interface>> │ │ <<interface>>│ │
│ │ IMachine │ │ Printer │ │
│ ├──────────────────┤ ├──────────────┤ │
│ │ + print() │ │ + print() │ │
│ │ + scan() │ └──────────────┘ │
│ │ + fax() │ │ │
│ └────────┬─────────┘ ┌──────┴──────┐ │
│ │ │ │ │
│ ┌─────┴─────┐ ┌────▼───┐ ┌─────▼────┐ │
│ │ │ │Simple │ │MultiFunc │ │
│ ┌──▼───┐ ┌───▼───┐ │Printer │ │ Device │ │
│ │Simple│ │Multi │ └────────┘ └────┬─────┘ │
│ │Printer│ │Function│ │ │
│ │ │ │ │ ┌──────┴──────┐ │
│ │#scan │ │ │ │ <<interface>>│ │
│ │抛异常│ │ │ │ Scanner │ │
│ └──────┘ └───────┘ ├─────────────┤ │
│ │ + scan() │ │
│ └─────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘from abc import ABC, abstractmethod
from typing import Protocol, Optional
class Printer(Protocol):
def print(self, document: str) -> None: ...
class Scanner(Protocol):
def scan(self) -> str: ...
class Fax(Protocol):
def send_fax(self, number: str, document: str) -> None: ...
def receive_fax(self) -> str: ...
class Copier(Protocol):
def copy(self, document: str) -> str: ...
class SimplePrinter:
def __init__(self, name: str):
self._name = name
self._paper_count = 0
def print(self, document: str) -> None:
self._paper_count += 1
print(f"[{self._name}] 打印: {document}")
def get_paper_count(self) -> int:
return self._paper_count
def add_paper(self, count: int) -> None:
self._paper_count += count
class SimpleScanner:
def __init__(self, name: str):
self._name = name
def scan(self) -> str:
return f"[{self._name}] 扫描的文档内容"
class MultiFunctionDevice:
def __init__(self, name: str):
self._name = name
self._fax_buffer: list[str] = []
def print(self, document: str) -> None:
print(f"[{self._name}] 打印: {document}")
def scan(self) -> str:
return f"[{self._name}] 扫描的文档内容"
def send_fax(self, number: str, document: str) -> None:
print(f"[{self._name}] 发送传真到 {number}: {document}")
def receive_fax(self) -> str:
if self._fax_buffer:
return self._fax_buffer.pop(0)
return "无传真"
def copy(self, document: str) -> str:
scanned = self.scan()
self.print(f"副本: {document}")
return scanned
class OfficeWorker:
def __init__(self, printer: Printer):
self._printer = printer
def print_document(self, document: str) -> None:
self._printer.print(document)
class DocumentArchiver:
def __init__(self, scanner: Scanner):
self._scanner = scanner
def archive_document(self) -> str:
return self._scanner.scan()
class FaxOperator:
def __init__(self, fax: Fax):
self._fax = fax
def send_document(self, number: str, document: str) -> None:
self._fax.send_fax(number, document)
simple_printer = SimplePrinter("办公室打印机")
worker = OfficeWorker(simple_printer)
worker.print_document("会议记录")
scanner = SimpleScanner("档案扫描仪")
archiver = DocumentArchiver(scanner)
print(archiver.archive_document())
mfd = MultiFunctionDevice("多功能一体机")
fax_operator = FaxOperator(mfd)
fax_operator.send_document("12345678", "合同文件")依赖倒置原则(DIP)
核心思想:高层模块不应该依赖低层模块,两者都应该依赖其抽象。
┌─────────────────────────────────────────────────────────────────────────┐
│ 依赖倒置原则示意图 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 违反DIP: 符合DIP: │
│ │
│ ┌──────────────┐ ┌──────────────────┐ │
│ │ UserService │ │ <<interface>> │ │
│ ├──────────────┤ │ Database │ │
│ │ - db: MySQL │◄──直接依赖 ├──────────────────┤ │
│ ├──────────────┤ │ + save() │ │
│ │ + saveUser() │ │ + find() │ │
│ └──────────────┘ └────────┬─────────┘ │
│ │ │
│ ┌──────────────┐ ┌────────┴────────┐ │
│ │ MySQL │ │ │ │
│ └──────────────┘ ┌─────▼─────┐ ┌──────▼─────┐ │
│ │ UserService│ │ MySQL │ │
│ 问题: 更换数据库需 ├───────────┤ ├────────────┤ │
│ 修改UserService │ - db: │ │ + save() │ │
│ │ Database│ │ + find() │ │
│ ├───────────┤ └────────────┘ │
│ │+saveUser()│ │
│ └───────────┘ ┌────────────┐ │
│ │ PostgreSQL │ │
│ ├────────────┤ │
│ │ + save() │ │
│ │ + find() │ │
│ └────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘from abc import ABC, abstractmethod
from typing import Protocol, Optional, Any
from dataclasses import dataclass
class Database(Protocol):
def save(self, table: str, data: dict) -> int: ...
def find(self, table: str, id: int) -> Optional[dict]: ...
def delete(self, table: str, id: int) -> bool: ...
def find_all(self, table: str) -> list[dict]: ...
class Logger(Protocol):
def log(self, message: str, level: str = "INFO") -> None: ...
def error(self, message: str) -> None: ...
def debug(self, message: str) -> None: ...
@dataclass
class User:
id: Optional[int] = None
name: str = ""
email: str = ""
class MySQLDatabase:
def __init__(self, connection_string: str):
self._connection_string = connection_string
self._data: dict[str, dict[int, dict]] = {}
self._next_id: dict[str, int] = {}
def save(self, table: str, data: dict) -> int:
if table not in self._next_id:
self._next_id[table] = 1
self._data[table] = {}
id = self._next_id[table]
data['id'] = id
self._data[table][id] = data.copy()
self._next_id[table] += 1
return id
def find(self, table: str, id: int) -> Optional[dict]:
return self._data.get(table, {}).get(id)
def delete(self, table: str, id: int) -> bool:
if table in self._data and id in self._data[table]:
del self._data[table][id]
return True
return False
def find_all(self, table: str) -> list[dict]:
return list(self._data.get(table, {}).values())
class PostgreSQLDatabase:
def __init__(self, connection_string: str):
self._connection_string = connection_string
self._data: dict[str, dict[int, dict]] = {}
self._next_id: dict[str, int] = {}
def save(self, table: str, data: dict) -> int:
if table not in self._next_id:
self._next_id[table] = 1
self._data[table] = {}
id = self._next_id[table]
data['id'] = id
self._data[table][id] = data.copy()
self._next_id[table] += 1
print(f"[PostgreSQL] INSERT INTO {table} VALUES (...)")
return id
def find(self, table: str, id: int) -> Optional[dict]:
print(f"[PostgreSQL] SELECT * FROM {table} WHERE id = {id}")
return self._data.get(table, {}).get(id)
def delete(self, table: str, id: int) -> bool:
print(f"[PostgreSQL] DELETE FROM {table} WHERE id = {id}")
if table in self._data and id in self._data[table]:
del self._data[table][id]
return True
return False
def find_all(self, table: str) -> list[dict]:
print(f"[PostgreSQL] SELECT * FROM {table}")
return list(self._data.get(table, {}).values())
class ConsoleLogger:
def log(self, message: str, level: str = "INFO") -> None:
print(f"[{level}] {message}")
def error(self, message: str) -> None:
self.log(message, "ERROR")
def debug(self, message: str) -> None:
self.log(message, "DEBUG")
class UserService:
TABLE = "users"
def __init__(self, database: Database, logger: Logger):
self._database = database
self._logger = logger
def create_user(self, name: str, email: str) -> User:
self._logger.log(f"创建用户: {name}")
data = {'name': name, 'email': email}
user_id = self._database.save(self.TABLE, data)
self._logger.debug(f"用户创建成功, ID: {user_id}")
return User(id=user_id, name=name, email=email)
def get_user(self, user_id: int) -> Optional[User]:
self._logger.debug(f"查询用户: {user_id}")
data = self._database.find(self.TABLE, user_id)
if data:
return User(**data)
return None
def delete_user(self, user_id: int) -> bool:
self._logger.log(f"删除用户: {user_id}")
return self._database.delete(self.TABLE, user_id)
mysql_db = MySQLDatabase("mysql://localhost:3306/mydb")
logger = ConsoleLogger()
user_service = UserService(mysql_db, logger)
user = user_service.create_user("张三", "zhangsan@example.com")
print(f"创建的用户: {user}")
pg_db = PostgreSQLDatabase("postgresql://localhost:5432/mydb")
user_service_pg = UserService(pg_db, logger)
user_pg = user_service_pg.create_user("李四", "lisi@example.com")其他设计原则
迪米特法则(LoD)
定义A.8(迪米特法则) 一个对象 $O$ 应该只与以下对象通信:
- $O$ 本身
- $O$ 的成员对象
- $O$ 创建的对象
- $O$ 的参数对象
形式化表述:
$$\forall m \in Methods(O): \forall T \in Types(m): T \in {O, Fields(O), Params(m), Locals(m)}$$
from dataclasses import dataclass
from typing import Optional
@dataclass
class Department:
name: str
manager_name: str
def get_manager_name(self) -> str:
return self.manager_name
class Company:
def __init__(self, name: str):
self._name = name
self._departments: dict[str, Department] = {}
def add_department(self, name: str, manager: str) -> None:
self._departments[name] = Department(name, manager)
def get_department(self, name: str) -> Optional[Department]:
return self._departments.get(name)
def get_manager_of_department(self, dept_name: str) -> Optional[str]:
dept = self._departments.get(dept_name)
return dept.get_manager_name() if dept else None
class Employee:
def __init__(self, name: str, company: Company):
self._name = name
self._company = company
def get_manager(self, department_name: str) -> Optional[str]:
return self._company.get_manager_of_department(department_name)
company = Company("科技公司")
company.add_department("研发部", "王经理")
company.add_department("市场部", "李经理")
employee = Employee("张三", company)
print(f"研发部经理: {employee.get_manager('研发部')}")组合复用原则(CRP)
定义A.9(组合复用原则) 优先使用对象组合而非继承来实现代码复用。
组合 vs 继承:
$$Inheritance: \quad IsA \implies \forall m \in M_{parent}: m \in M_{child}$$
$$Composition: \quad HasA \implies \forall m \in M_{component}: delegate(m) \in M_{container}$$
from abc import ABC, abstractmethod
from typing import Protocol
class Engine(Protocol):
def start(self) -> str: ...
def stop(self) -> str: ...
def get_type(self) -> str: ...
class GasolineEngine:
def start(self) -> str:
return "汽油引擎启动"
def stop(self) -> str:
return "汽油引擎停止"
def get_type(self) -> str:
return "汽油"
class ElectricEngine:
def start(self) -> str:
return "电动引擎启动"
def stop(self) -> str:
return "电动引擎停止"
def get_type(self) -> str:
return "电动"
class HybridEngine:
def __init__(self):
self._gasoline = GasolineEngine()
self._electric = ElectricEngine()
self._current_mode = "electric"
def start(self) -> str:
if self._current_mode == "electric":
return self._electric.start()
return self._gasoline.start()
def stop(self) -> str:
return f"{self._electric.stop()} 和 {self._gasoline.stop()}"
def get_type(self) -> str:
return "混合动力"
def switch_mode(self) -> None:
self._current_mode = "gasoline" if self._current_mode == "electric" else "electric"
class Transmission(Protocol):
def shift(self, gear: int) -> str: ...
class AutomaticTransmission:
def shift(self, gear: int) -> str:
return f"自动变速箱切换到 {gear} 档"
class ManualTransmission:
def shift(self, gear: int) -> str:
return f"手动变速箱切换到 {gear} 档"
class Car:
def __init__(self, brand: str, engine: Engine, transmission: Transmission):
self._brand = brand
self._engine = engine
self._transmission = transmission
self._speed = 0
def start(self) -> str:
return f"{self._brand}: {self._engine.start()}"
def stop(self) -> str:
self._speed = 0
return f"{self._brand}: {self._engine.stop()}"
def accelerate(self, speed: int) -> str:
self._speed += speed
return f"{self._brand}: 加速到 {self._speed} km/h"
def shift_gear(self, gear: int) -> str:
return f"{self._brand}: {self._transmission.shift(gear)}"
def get_info(self) -> str:
return f"{self._brand} ({self._engine.get_type()}引擎)"
gasoline_car = Car("丰田", GasolineEngine(), AutomaticTransmission())
electric_car = Car("特斯拉", ElectricEngine(), AutomaticTransmission())
hybrid_car = Car("普锐斯", HybridEngine(), ManualTransmission())
print(gasoline_car.start())
print(electric_car.start())
print(hybrid_car.start())KISS原则
定义A.10(KISS原则) Keep It Simple, Stupid - 保持简单愚蠢。
from typing import List, Optional
from dataclasses import dataclass
@dataclass
class Product:
id: int
name: str
price: float
class SimpleInventory:
def __init__(self):
self._products: dict[int, Product] = {}
self._quantities: dict[int, int] = {}
def add_product(self, product: Product, quantity: int) -> None:
self._products[product.id] = product
self._quantities[product.id] = self._quantities.get(product.id, 0) + quantity
def get_quantity(self, product_id: int) -> int:
return self._quantities.get(product_id, 0)
def reduce_quantity(self, product_id: int, amount: int) -> bool:
if self._quantities.get(product_id, 0) >= amount:
self._quantities[product_id] -= amount
return True
return False
def get_total_value(self) -> float:
return sum(
self._products[pid].price * qty
for pid, qty in self._quantities.items()
)DRY原则
定义A.11(DRY原则) Don't Repeat Yourself - 不要重复自己。
from typing import Callable, TypeVar, Generic
from dataclasses import dataclass
from functools import wraps
import time
T = TypeVar('T')
@dataclass
class ValidationResult:
is_valid: bool
errors: list[str]
class Validator:
@staticmethod
def validate_email(email: str) -> ValidationResult:
errors = []
if '@' not in email:
errors.append("邮箱必须包含@符号")
if '.' not in email.split('@')[-1]:
errors.append("邮箱域名无效")
return ValidationResult(len(errors) == 0, errors)
@staticmethod
def validate_password(password: str) -> ValidationResult:
errors = []
if len(password) < 8:
errors.append("密码长度至少8位")
if not any(c.isupper() for c in password):
errors.append("密码必须包含大写字母")
if not any(c.isdigit() for c in password):
errors.append("密码必须包含数字")
return ValidationResult(len(errors) == 0, errors)
@staticmethod
def validate_username(username: str) -> ValidationResult:
errors = []
if len(username) < 3:
errors.append("用户名长度至少3位")
if not username.isalnum():
errors.append("用户名只能包含字母和数字")
return ValidationResult(len(errors) == 0, errors)
def retry(max_attempts: int = 3, delay: float = 1.0):
def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs):
last_exception = None
for attempt in range(max_attempts):
try:
return func(*args, **kwargs)
except Exception as e:
last_exception = e
if attempt < max_attempts - 1:
time.sleep(delay)
raise last_exception
return wrapper
return decoratorYAGNI原则
定义A.12(YAGNI原则) You Aren't Gonna Need It - 你不会需要它。
class UserRepository:
def __init__(self, database):
self._database = database
def find_by_id(self, user_id: int):
return self._database.find('users', user_id)
def save(self, user: dict) -> int:
return self._database.save('users', user)
def delete(self, user_id: int) -> bool:
return self._database.delete('users', user_id)
class UserService:
def __init__(self, repository: UserRepository):
self._repository = repository
def get_user(self, user_id: int):
return self._repository.find_by_id(user_id)
def create_user(self, name: str, email: str) -> dict:
user = {'name': name, 'email': email}
user_id = self._repository.save(user)
user['id'] = user_id
return user原则之间的关系
┌─────────────────────────────────────────────────────────────────────────┐
│ 设计原则关系图 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────┐ │
│ │ 高内聚低耦合 │ │
│ │ (设计目标) │ │
│ └────────┬────────┘ │
│ │ │
│ ┌──────────────────────┼──────────────────────┐ │
│ │ │ │ │
│ ┌──────▼──────┐ ┌──────▼──────┐ ┌──────▼──────┐ │
│ │ SOLID原则 │ │ 其他原则 │ │ 实践原则 │ │
│ ├─────────────┤ ├─────────────┤ ├─────────────┤ │
│ │ SRP ──────► │ │ LoD ──────► │ │ KISS ─────► │ │
│ │ 高内聚 │ │ 低耦合 │ │ 简洁性 │ │
│ │ │ │ │ │ │ │
│ │ OCP ──────► │ │ CRP ──────► │ │ DRY ──────► │ │
│ │ 可扩展 │ │ 灵活复用 │ │ 可维护 │ │
│ │ │ │ │ │ │ │
│ │ LSP ──────► │ │ │ │ YAGNI ────► │ │
│ │ 可替换 │ │ │ │ 避免过度 │ │
│ │ │ │ │ │ │ │
│ │ ISP ──────► │ │ │ │ │ │
│ │ 接口精简 │ │ │ │ │ │
│ │ │ │ │ │ │ │
│ │ DIP ──────► │ │ │ │ │ │
│ │ 解耦 │ │ │ │ │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘原则应用指南
场景决策表
| 场景 | 推荐原则 | 具体应用 |
|---|---|---|
| 类职责过多 | 单一职责原则 | 拆分为多个专注的类 |
| 需要扩展功能 | 开闭原则 | 使用策略模式、装饰器模式 |
| 继承层次设计 | 里氏替换原则 | 确保子类行为一致 |
| 接口设计 | 接口隔离原则 | 拆分大接口为小接口 |
| 模块依赖 | 依赖倒置原则 | 引入抽象层解耦 |
| 类间通信 | 迪米特法则 | 减少直接依赖 |
| 代码复用 | 组合复用原则 | 优先组合而非继承 |
| 系统复杂度 | KISS原则 | 保持简单直接 |
| 重复代码 | DRY原则 | 提取公共逻辑 |
| 功能预测 | YAGNI原则 | 只实现当前需要 |
原则冲突处理
| 冲突场景 | 权衡建议 |
|---|---|
| SRP vs 性能 | 适度合并高频调用的职责 |
| OCP vs 简洁性 | 预测稳定的变化点再抽象 |
| DRY vs YAGNI | 只提取当前存在的重复 |
| 组合 vs 继承 | 行为复用用组合,类型关系用继承 |
快速参考卡片
SOLID原则速查
| 原则 | 核心思想 | 检验方法 |
|---|---|---|
| SRP | 一个类只有一个变化原因 | 能否用一句话描述类的职责? |
| OCP | 对扩展开放,对修改关闭 | 新增功能是否需要修改现有代码? |
| LSP | 子类可替换父类 | 子类是否违反父类契约? |
| ISP | 接口职责单一 | 客户端是否依赖不需要的方法? |
| DIP | 依赖抽象而非具体 | 高层模块是否直接依赖低层? |
设计原则检查清单
- [ ] 每个类是否有单一明确的职责?
- [ ] 新功能是否可以通过扩展而非修改实现?
- [ ] 子类是否可以安全替换父类?
- [ ] 接口是否足够精简?
- [ ] 是否依赖抽象而非具体实现?
- [ ] 类之间的耦合是否最小化?
- [ ] 是否优先使用组合而非继承?
- [ ] 代码是否足够简单?
- [ ] 是否存在重复代码?
- [ ] 是否只实现当前需要的功能?
总结
设计原则是软件设计的基石,它们提供了指导性的思想而非僵化的规则。在实际应用中:
- 理解原则本质:掌握每个原则背后的设计哲学
- 权衡取舍:根据具体场景灵活应用原则
- 持续重构:通过重构逐步改进设计
- 团队共识:建立团队对原则的共同理解
记住:设计原则是工具,不是目的。好的设计应该服务于业务需求,同时保持代码的可维护性和可扩展性。