Skip to content

设计原则详解

学习目标

  • 深入理解SOLID原则的形式化定义与数学基础
  • 掌握各类设计原则的理论依据与应用场景
  • 学会在实际项目中权衡和应用设计原则
  • 理解设计原则与设计模式之间的内在联系

历史背景

软件设计原则的演进

时期里程碑代表性贡献
1970s结构化编程Dijkstra的GOTO有害论,模块化设计
1980s面向对象抽象、封装、继承、多态的确立
1990s设计模式GoF 23种设计模式,SOLID原则形成
2000s敏捷原则DRY、KISS、YAGNI等原则普及
2010s函数式原则不可变性、纯函数、副作用管理
2020s云原生原则微服务、容错、可观测性设计

SOLID原则的诞生

SOLID原则由Robert C. Martin(Uncle Bob)在2000年代初期系统化整理,其五个原则首字母组成了SOLID这个助记符:

  • Single Responsibility Principle(单一职责原则)
  • Open/Closed Principle(开闭原则)
  • Liskov Substitution Principle(里氏替换原则)
  • Interface Segregation Principle(接口隔离原则)
  • Dependency Inversion Principle(依赖倒置原则)

形式化定义

内聚与耦合

定义A.1(内聚度) 模块内聚度 $C$ 定义为模块内各元素相关程度的度量:

$$C = \frac{\sum_{i,j \in M} r_{ij}}{|M| \cdot (|M| - 1)}$$

其中 $r_{ij}$ 表示元素 $i$ 和 $j$ 的相关度,$|M|$ 是模块大小。

定义A.2(耦合度) 模块间耦合度 $D$ 定义为模块间依赖关系的强度:

$$D = \frac{\sum_{i \neq j} dep(M_i, M_j)}{n \cdot (n-1)}$$

其中 $dep(M_i, M_j)$ 表示模块 $M_i$ 对 $M_j$ 的依赖强度。

设计目标:最大化内聚,最小化耦合。

单一职责原则(SRP)

定义A.3(单一职责原则) 一个类 $C$ 应该只有一个变化原因 $R$:

$$\forall C: |{R : R \text{ causes change in } C}| = 1$$

形式化表述:如果类 $C$ 有职责集合 $Resp(C) = {r_1, r_2, ..., r_n}$,则:

$$SRP(C) \iff |Resp(C)| = 1$$

定理A.1(SRP复杂度) 违反SRP的类,其修改复杂度随职责数量平方增长:

$$Complexity(C) = O(|Resp(C)|^2)$$

开闭原则(OCP)

定义A.4(开闭原则) 软件实体 $E$ 应满足:

$$\forall ext: Extension(ext) \implies \neg Modification(E)$$

即对于任何扩展 $ext$,实体 $E$ 无需修改即可支持。

形式化表述

$$OCP(E) \iff \forall f' \in Extension(f): \exists g: f' = g \circ f \land \neg Modified(E)$$

里氏替换原则(LSP)

定义A.5(里氏替换原则) 设 $\phi(x)$ 是关于类型 $T$ 的可证明性质,则对于所有 $S \subseteq T$:

$$\forall o \in S: \phi(o)$$

即子类对象必须能够替换父类对象而不破坏程序正确性。

行为子类型

$$LSP(S, T) \iff \forall p: {pre_T(p) \implies pre_S(p)} \land {post_S(p) \implies post_T(p)}$$

接口隔离原则(ISP)

定义A.6(接口隔离原则) 客户端 $C$ 不应依赖它不使用的方法:

$$\forall C, I: Methods(C, I) = Methods(Used(C, I))$$

形式化表述

$$ISP(I) \iff \forall C: deps(C, I) = \bigcap_{m \in used(C, I)} {m}$$

依赖倒置原则(DIP)

定义A.7(依赖倒置原则) 高层模块 $H$ 和低层模块 $L$ 都应依赖抽象 $A$:

$$deps(H, L) = \emptyset \land deps(H, A) \land deps(L, A)$$

依赖方向

$$DIP \iff \forall H, L: H \not\to L \land H \to A \land L \to A$$

SOLID原则详解

单一职责原则(SRP)

核心思想:一个类应该只有一个引起它变化的原因。

┌─────────────────────────────────────────────────────────────────────────┐
│                     单一职责原则示意图                                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  违反SRP:                          符合SRP:                            │
│  ┌──────────────────┐              ┌──────────────┐                    │
│  │     User         │              │    User      │                    │
│  ├──────────────────┤              ├──────────────┤                    │
│  │ - name           │              │ - name       │                    │
│  │ - email          │              │ - email      │                    │
│  ├──────────────────┤              └──────────────┘                    │
│  │ + save()         │                     │                           │
│  │ + validate()     │                     │                           │
│  │ + notify()       │              ┌──────┴──────┬──────────┐         │
│  │ + generateReport │              │             │          │         │
│  └──────────────────┘         ┌────▼───┐   ┌────▼───┐  ┌───▼────┐    │
│                               │ Repo   │   │Validator│  │Notifier│    │
│                               ├────────┤   ├────────┤  ├────────┤    │
│                               │+save() │   │+valid()│  │+notify()│   │
│                               └────────┘   └────────┘  └────────┘    │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘
python
from typing import Protocol, Optional
from dataclasses import dataclass
from abc import ABC, abstractmethod

@dataclass
class User:
    name: str
    email: str
    
    def __post_init__(self):
        if not self.name:
            raise ValueError("用户名不能为空")

class UserRepository(Protocol):
    def save(self, user: User) -> None: ...
    def find(self, user_id: int) -> Optional[User]: ...
    def delete(self, user_id: int) -> None: ...

class InMemoryUserRepository:
    def __init__(self):
        self._users: dict[int, User] = {}
        self._next_id = 1
    
    def save(self, user: User) -> int:
        user_id = self._next_id
        self._users[user_id] = user
        self._next_id += 1
        return user_id
    
    def find(self, user_id: int) -> Optional[User]:
        return self._users.get(user_id)
    
    def delete(self, user_id: int) -> None:
        self._users.pop(user_id, None)

class UserValidator:
    def validate(self, user: User) -> bool:
        if not user.name or len(user.name) < 2:
            return False
        if '@' not in user.email:
            return False
        return True
    
    def validate_email(self, email: str) -> bool:
        import re
        pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
        return bool(re.match(pattern, email))

class UserNotifier:
    def __init__(self, email_sender: 'EmailSender'):
        self._email_sender = email_sender
    
    def notify(self, user: User, message: str) -> None:
        self._email_sender.send(user.email, message)
    
    def notify_welcome(self, user: User) -> None:
        self.notify(user, f"欢迎 {user.name}!")

class EmailSender(Protocol):
    def send(self, to: str, message: str) -> None: ...

class ConsoleEmailSender:
    def send(self, to: str, message: str) -> None:
        print(f"[Email] To: {to}\n{message}")

class UserService:
    def __init__(self, 
                 repository: UserRepository,
                 validator: UserValidator,
                 notifier: UserNotifier):
        self._repository = repository
        self._validator = validator
        self._notifier = notifier
    
    def register_user(self, name: str, email: str) -> Optional[int]:
        user = User(name=name, email=email)
        
        if not self._validator.validate(user):
            return None
        
        user_id = self._repository.save(user)
        self._notifier.notify_welcome(user)
        return user_id

开闭原则(OCP)

核心思想:软件实体应该对扩展开放,对修改关闭。

┌─────────────────────────────────────────────────────────────────────────┐
│                       开闭原则示意图                                     │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  违反OCP:                          符合OCP:                            │
│  ┌──────────────────┐              ┌──────────────────┐                │
│  │ PriceCalculator  │              │ <<interface>>    │                │
│  ├──────────────────┤              │ DiscountStrategy │                │
│  │ - type: str      │              ├──────────────────┤                │
│  ├──────────────────┤              │ + calculate()    │                │
│  │ + calculate()    │              └────────┬─────────┘                │
│  │   if type==A:    │                       │                          │
│  │     ...          │           ┌───────────┼───────────┐              │
│  │   elif type==B:  │           │           │           │              │
│  │     ...          │      ┌────▼───┐  ┌────▼───┐  ┌────▼───┐         │
│  │   # 新增需修改    │      │NoDisc  │  │Percent │  │Fixed   │         │
│  └──────────────────┘      └────────┘  └────────┘  └────────┘         │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘
python
from abc import ABC, abstractmethod
from typing import List, Callable
from dataclasses import dataclass
from enum import Enum, auto

class DiscountStrategy(ABC):
    @abstractmethod
    def calculate(self, price: float) -> float:
        pass
    
    @property
    @abstractmethod
    def name(self) -> str:
        pass

class NoDiscount(DiscountStrategy):
    @property
    def name(self) -> str:
        return "无折扣"
    
    def calculate(self, price: float) -> float:
        return price

class PercentageDiscount(DiscountStrategy):
    def __init__(self, percentage: float):
        if not 0 <= percentage <= 100:
            raise ValueError("折扣百分比必须在0-100之间")
        self._percentage = percentage
    
    @property
    def name(self) -> str:
        return f"{self._percentage}%折扣"
    
    def calculate(self, price: float) -> float:
        return price * (1 - self._percentage / 100)

class FixedDiscount(DiscountStrategy):
    def __init__(self, amount: float):
        if amount < 0:
            raise ValueError("折扣金额不能为负")
        self._amount = amount
    
    @property
    def name(self) -> str:
        return f"固定折扣{self._amount}元"
    
    def calculate(self, price: float) -> float:
        return max(0, price - self._amount)

class TieredDiscount(DiscountStrategy):
    def __init__(self, tiers: List[tuple[float, float]]):
        self._tiers = sorted(tiers, key=lambda x: x[0], reverse=True)
    
    @property
    def name(self) -> str:
        return "阶梯折扣"
    
    def calculate(self, price: float) -> float:
        for threshold, rate in self._tiers:
            if price >= threshold:
                return price * (1 - rate / 100)
        return price

class BuyXGetYDiscount(DiscountStrategy):
    def __init__(self, buy_count: int, get_count: int, unit_price: float):
        self._buy = buy_count
        self._get = get_count
        self._unit_price = unit_price
    
    @property
    def name(self) -> str:
        return f"买{self._buy}{self._get}"
    
    def calculate(self, price: float) -> float:
        total_items = int(price / self._unit_price)
        free_items = (total_items // (self._buy + self._get)) * self._get
        return price - free_items * self._unit_price

class PriceCalculator:
    def __init__(self, strategy: DiscountStrategy):
        self._strategy = strategy
    
    def set_strategy(self, strategy: DiscountStrategy) -> None:
        self._strategy = strategy
    
    def calculate(self, price: float) -> float:
        return self._strategy.calculate(price)
    
    def calculate_with_details(self, price: float) -> dict:
        final_price = self._strategy.calculate(price)
        return {
            'original_price': price,
            'discount_name': self._strategy.name,
            'final_price': final_price,
            'savings': price - final_price
        }

class DiscountStrategyFactory:
    _strategies: dict[str, type[DiscountStrategy]] = {}
    
    @classmethod
    def register(cls, name: str, strategy_class: type[DiscountStrategy]):
        cls._strategies[name] = strategy_class
    
    @classmethod
    def create(cls, name: str, *args, **kwargs) -> DiscountStrategy:
        if name not in cls._strategies:
            raise ValueError(f"未知策略: {name}")
        return cls._strategies[name](*args, **kwargs)

DiscountStrategyFactory.register('none', NoDiscount)
DiscountStrategyFactory.register('percentage', PercentageDiscount)
DiscountStrategyFactory.register('fixed', FixedDiscount)

里氏替换原则(LSP)

核心思想:子类对象必须能够替换掉所有父类对象而不破坏程序正确性。

┌─────────────────────────────────────────────────────────────────────────┐
│                     里氏替换原则示意图                                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  违反LSP:                          符合LSP:                            │
│  ┌──────────────────┐              ┌──────────────────┐                │
│  │     Bird         │              │     Bird         │                │
│  ├──────────────────┤              ├──────────────────┤                │
│  │ + fly()          │              │ + move()         │                │
│  └────────┬─────────┘              └────────┬─────────┘                │
│           │                                 │                          │
│     ┌─────┴─────┐                     ┌─────┴─────┐                   │
│     │           │                     │           │                   │
│  ┌──▼───┐   ┌───▼───┐             ┌───▼───┐   ┌───▼───┐              │
│  │Sparrow│  │Penguin│             │Flying │   │Swimming│              │
│  │       │  │       │             │ Bird  │   │ Bird   │              │
│  │+fly() │  │+fly() │◄──错误!     ├───────┤   ├────────┤              │
│  └───────┘  │抛出异常│             │+fly() │   │+swim() │              │
│             └────────┘             └───────┘   └────────┘              │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘
python
from abc import ABC, abstractmethod
from typing import Protocol

class Bird(ABC):
    def __init__(self, name: str):
        self.name = name
    
    @abstractmethod
    def move(self) -> str:
        pass
    
    def eat(self) -> str:
        return f"{self.name} 正在进食"

class FlyingBird(Bird):
    @abstractmethod
    def fly(self) -> str:
        pass
    
    def move(self) -> str:
        return self.fly()

class SwimmingBird(Bird):
    @abstractmethod
    def swim(self) -> str:
        pass
    
    def move(self) -> str:
        return self.swim()

class Sparrow(FlyingBird):
    def fly(self) -> str:
        return f"{self.name} 在空中飞翔"

class Eagle(FlyingBird):
    def __init__(self, name: str, wingspan: float):
        super().__init__(name)
        self.wingspan = wingspan
    
    def fly(self) -> str:
        return f"{self.name} 展开{self.wingspan}米翅膀高飞"

class Penguin(SwimmingBird):
    def swim(self) -> str:
        return f"{self.name} 在水中游泳"

class Duck(Bird):
    def move(self) -> str:
        return f"{self.name} 可以飞也可以游"
    
    def fly(self) -> str:
        return f"{self.name} 在低空飞行"
    
    def swim(self) -> str:
        return f"{self.name} 在水面游泳"

def make_bird_move(bird: Bird) -> str:
    return bird.move()

def make_flying_bird_fly(bird: FlyingBird) -> str:
    return bird.fly()

sparrow = Sparrow("小麻雀")
eagle = Eagle("雄鹰", 2.0)
penguin = Penguin("企鹅")
duck = Duck("鸭子")

print(make_bird_move(sparrow))
print(make_bird_move(penguin))
print(make_bird_move(duck))

print(make_flying_bird_fly(sparrow))
print(make_flying_bird_fly(eagle))

接口隔离原则(ISP)

核心思想:客户端不应该依赖它不需要的接口。

┌─────────────────────────────────────────────────────────────────────────┐
│                     接口隔离原则示意图                                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  违反ISP:                          符合ISP:                            │
│  ┌──────────────────┐              ┌──────────────┐                    │
│  │ <<interface>>    │              │ <<interface>>│                    │
│  │   IMachine       │              │   Printer    │                    │
│  ├──────────────────┤              ├──────────────┤                    │
│  │ + print()        │              │ + print()    │                    │
│  │ + scan()         │              └──────────────┘                    │
│  │ + fax()          │                     │                           │
│  └────────┬─────────┘              ┌──────┴──────┐                    │
│           │                        │             │                    │
│     ┌─────┴─────┐             ┌────▼───┐   ┌─────▼────┐               │
│     │           │             │Simple  │   │MultiFunc │               │
│  ┌──▼───┐   ┌───▼───┐         │Printer │   │  Device  │               │
│  │Simple│   │Multi  │         └────────┘   └────┬─────┘               │
│  │Printer│  │Function│                           │                     │
│  │      │  │       │                      ┌──────┴──────┐            │
│  │#scan │  │       │                      │ <<interface>>│            │
│  │抛异常│  │       │                      │   Scanner   │            │
│  └──────┘  └───────┘                      ├─────────────┤            │
│                                           │ + scan()    │            │
│                                           └─────────────┘            │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘
python
from abc import ABC, abstractmethod
from typing import Protocol, Optional

class Printer(Protocol):
    def print(self, document: str) -> None: ...

class Scanner(Protocol):
    def scan(self) -> str: ...

class Fax(Protocol):
    def send_fax(self, number: str, document: str) -> None: ...
    def receive_fax(self) -> str: ...

class Copier(Protocol):
    def copy(self, document: str) -> str: ...

class SimplePrinter:
    def __init__(self, name: str):
        self._name = name
        self._paper_count = 0
    
    def print(self, document: str) -> None:
        self._paper_count += 1
        print(f"[{self._name}] 打印: {document}")
    
    def get_paper_count(self) -> int:
        return self._paper_count
    
    def add_paper(self, count: int) -> None:
        self._paper_count += count

class SimpleScanner:
    def __init__(self, name: str):
        self._name = name
    
    def scan(self) -> str:
        return f"[{self._name}] 扫描的文档内容"

class MultiFunctionDevice:
    def __init__(self, name: str):
        self._name = name
        self._fax_buffer: list[str] = []
    
    def print(self, document: str) -> None:
        print(f"[{self._name}] 打印: {document}")
    
    def scan(self) -> str:
        return f"[{self._name}] 扫描的文档内容"
    
    def send_fax(self, number: str, document: str) -> None:
        print(f"[{self._name}] 发送传真到 {number}: {document}")
    
    def receive_fax(self) -> str:
        if self._fax_buffer:
            return self._fax_buffer.pop(0)
        return "无传真"
    
    def copy(self, document: str) -> str:
        scanned = self.scan()
        self.print(f"副本: {document}")
        return scanned

class OfficeWorker:
    def __init__(self, printer: Printer):
        self._printer = printer
    
    def print_document(self, document: str) -> None:
        self._printer.print(document)

class DocumentArchiver:
    def __init__(self, scanner: Scanner):
        self._scanner = scanner
    
    def archive_document(self) -> str:
        return self._scanner.scan()

class FaxOperator:
    def __init__(self, fax: Fax):
        self._fax = fax
    
    def send_document(self, number: str, document: str) -> None:
        self._fax.send_fax(number, document)

simple_printer = SimplePrinter("办公室打印机")
worker = OfficeWorker(simple_printer)
worker.print_document("会议记录")

scanner = SimpleScanner("档案扫描仪")
archiver = DocumentArchiver(scanner)
print(archiver.archive_document())

mfd = MultiFunctionDevice("多功能一体机")
fax_operator = FaxOperator(mfd)
fax_operator.send_document("12345678", "合同文件")

依赖倒置原则(DIP)

核心思想:高层模块不应该依赖低层模块,两者都应该依赖其抽象。

┌─────────────────────────────────────────────────────────────────────────┐
│                     依赖倒置原则示意图                                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  违反DIP:                          符合DIP:                            │
│                                                                         │
│  ┌──────────────┐                 ┌──────────────────┐                 │
│  │ UserService  │                 │   <<interface>>  │                 │
│  ├──────────────┤                 │     Database     │                 │
│  │ - db: MySQL  │◄──直接依赖      ├──────────────────┤                 │
│  ├──────────────┤                 │ + save()         │                 │
│  │ + saveUser() │                 │ + find()         │                 │
│  └──────────────┘                 └────────┬─────────┘                 │
│                                            │                           │
│  ┌──────────────┐                 ┌────────┴────────┐                  │
│  │    MySQL     │                 │                 │                  │
│  └──────────────┘           ┌─────▼─────┐    ┌──────▼─────┐           │
│                             │ UserService│    │  MySQL     │           │
│  问题: 更换数据库需          ├───────────┤    ├────────────┤           │
│  修改UserService            │ - db:     │    │ + save()   │           │
│                             │   Database│    │ + find()   │           │
│                             ├───────────┤    └────────────┘           │
│                             │+saveUser()│                             │
│                             └───────────┘    ┌────────────┐           │
│                                              │ PostgreSQL │           │
│                                              ├────────────┤           │
│                                              │ + save()   │           │
│                                              │ + find()   │           │
│                                              └────────────┘           │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘
python
from abc import ABC, abstractmethod
from typing import Protocol, Optional, Any
from dataclasses import dataclass

class Database(Protocol):
    def save(self, table: str, data: dict) -> int: ...
    def find(self, table: str, id: int) -> Optional[dict]: ...
    def delete(self, table: str, id: int) -> bool: ...
    def find_all(self, table: str) -> list[dict]: ...

class Logger(Protocol):
    def log(self, message: str, level: str = "INFO") -> None: ...
    def error(self, message: str) -> None: ...
    def debug(self, message: str) -> None: ...

@dataclass
class User:
    id: Optional[int] = None
    name: str = ""
    email: str = ""

class MySQLDatabase:
    def __init__(self, connection_string: str):
        self._connection_string = connection_string
        self._data: dict[str, dict[int, dict]] = {}
        self._next_id: dict[str, int] = {}
    
    def save(self, table: str, data: dict) -> int:
        if table not in self._next_id:
            self._next_id[table] = 1
            self._data[table] = {}
        
        id = self._next_id[table]
        data['id'] = id
        self._data[table][id] = data.copy()
        self._next_id[table] += 1
        return id
    
    def find(self, table: str, id: int) -> Optional[dict]:
        return self._data.get(table, {}).get(id)
    
    def delete(self, table: str, id: int) -> bool:
        if table in self._data and id in self._data[table]:
            del self._data[table][id]
            return True
        return False
    
    def find_all(self, table: str) -> list[dict]:
        return list(self._data.get(table, {}).values())

class PostgreSQLDatabase:
    def __init__(self, connection_string: str):
        self._connection_string = connection_string
        self._data: dict[str, dict[int, dict]] = {}
        self._next_id: dict[str, int] = {}
    
    def save(self, table: str, data: dict) -> int:
        if table not in self._next_id:
            self._next_id[table] = 1
            self._data[table] = {}
        
        id = self._next_id[table]
        data['id'] = id
        self._data[table][id] = data.copy()
        self._next_id[table] += 1
        print(f"[PostgreSQL] INSERT INTO {table} VALUES (...)")
        return id
    
    def find(self, table: str, id: int) -> Optional[dict]:
        print(f"[PostgreSQL] SELECT * FROM {table} WHERE id = {id}")
        return self._data.get(table, {}).get(id)
    
    def delete(self, table: str, id: int) -> bool:
        print(f"[PostgreSQL] DELETE FROM {table} WHERE id = {id}")
        if table in self._data and id in self._data[table]:
            del self._data[table][id]
            return True
        return False
    
    def find_all(self, table: str) -> list[dict]:
        print(f"[PostgreSQL] SELECT * FROM {table}")
        return list(self._data.get(table, {}).values())

class ConsoleLogger:
    def log(self, message: str, level: str = "INFO") -> None:
        print(f"[{level}] {message}")
    
    def error(self, message: str) -> None:
        self.log(message, "ERROR")
    
    def debug(self, message: str) -> None:
        self.log(message, "DEBUG")

class UserService:
    TABLE = "users"
    
    def __init__(self, database: Database, logger: Logger):
        self._database = database
        self._logger = logger
    
    def create_user(self, name: str, email: str) -> User:
        self._logger.log(f"创建用户: {name}")
        
        data = {'name': name, 'email': email}
        user_id = self._database.save(self.TABLE, data)
        
        self._logger.debug(f"用户创建成功, ID: {user_id}")
        return User(id=user_id, name=name, email=email)
    
    def get_user(self, user_id: int) -> Optional[User]:
        self._logger.debug(f"查询用户: {user_id}")
        data = self._database.find(self.TABLE, user_id)
        
        if data:
            return User(**data)
        return None
    
    def delete_user(self, user_id: int) -> bool:
        self._logger.log(f"删除用户: {user_id}")
        return self._database.delete(self.TABLE, user_id)

mysql_db = MySQLDatabase("mysql://localhost:3306/mydb")
logger = ConsoleLogger()
user_service = UserService(mysql_db, logger)

user = user_service.create_user("张三", "zhangsan@example.com")
print(f"创建的用户: {user}")

pg_db = PostgreSQLDatabase("postgresql://localhost:5432/mydb")
user_service_pg = UserService(pg_db, logger)
user_pg = user_service_pg.create_user("李四", "lisi@example.com")

其他设计原则

迪米特法则(LoD)

定义A.8(迪米特法则) 一个对象 $O$ 应该只与以下对象通信:

  1. $O$ 本身
  2. $O$ 的成员对象
  3. $O$ 创建的对象
  4. $O$ 的参数对象

形式化表述

$$\forall m \in Methods(O): \forall T \in Types(m): T \in {O, Fields(O), Params(m), Locals(m)}$$

python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Department:
    name: str
    manager_name: str
    
    def get_manager_name(self) -> str:
        return self.manager_name

class Company:
    def __init__(self, name: str):
        self._name = name
        self._departments: dict[str, Department] = {}
    
    def add_department(self, name: str, manager: str) -> None:
        self._departments[name] = Department(name, manager)
    
    def get_department(self, name: str) -> Optional[Department]:
        return self._departments.get(name)
    
    def get_manager_of_department(self, dept_name: str) -> Optional[str]:
        dept = self._departments.get(dept_name)
        return dept.get_manager_name() if dept else None

class Employee:
    def __init__(self, name: str, company: Company):
        self._name = name
        self._company = company
    
    def get_manager(self, department_name: str) -> Optional[str]:
        return self._company.get_manager_of_department(department_name)

company = Company("科技公司")
company.add_department("研发部", "王经理")
company.add_department("市场部", "李经理")

employee = Employee("张三", company)
print(f"研发部经理: {employee.get_manager('研发部')}")

组合复用原则(CRP)

定义A.9(组合复用原则) 优先使用对象组合而非继承来实现代码复用。

组合 vs 继承

$$Inheritance: \quad IsA \implies \forall m \in M_{parent}: m \in M_{child}$$

$$Composition: \quad HasA \implies \forall m \in M_{component}: delegate(m) \in M_{container}$$

python
from abc import ABC, abstractmethod
from typing import Protocol

class Engine(Protocol):
    def start(self) -> str: ...
    def stop(self) -> str: ...
    def get_type(self) -> str: ...

class GasolineEngine:
    def start(self) -> str:
        return "汽油引擎启动"
    
    def stop(self) -> str:
        return "汽油引擎停止"
    
    def get_type(self) -> str:
        return "汽油"

class ElectricEngine:
    def start(self) -> str:
        return "电动引擎启动"
    
    def stop(self) -> str:
        return "电动引擎停止"
    
    def get_type(self) -> str:
        return "电动"

class HybridEngine:
    def __init__(self):
        self._gasoline = GasolineEngine()
        self._electric = ElectricEngine()
        self._current_mode = "electric"
    
    def start(self) -> str:
        if self._current_mode == "electric":
            return self._electric.start()
        return self._gasoline.start()
    
    def stop(self) -> str:
        return f"{self._electric.stop()}{self._gasoline.stop()}"
    
    def get_type(self) -> str:
        return "混合动力"
    
    def switch_mode(self) -> None:
        self._current_mode = "gasoline" if self._current_mode == "electric" else "electric"

class Transmission(Protocol):
    def shift(self, gear: int) -> str: ...

class AutomaticTransmission:
    def shift(self, gear: int) -> str:
        return f"自动变速箱切换到 {gear} 档"

class ManualTransmission:
    def shift(self, gear: int) -> str:
        return f"手动变速箱切换到 {gear} 档"

class Car:
    def __init__(self, brand: str, engine: Engine, transmission: Transmission):
        self._brand = brand
        self._engine = engine
        self._transmission = transmission
        self._speed = 0
    
    def start(self) -> str:
        return f"{self._brand}: {self._engine.start()}"
    
    def stop(self) -> str:
        self._speed = 0
        return f"{self._brand}: {self._engine.stop()}"
    
    def accelerate(self, speed: int) -> str:
        self._speed += speed
        return f"{self._brand}: 加速到 {self._speed} km/h"
    
    def shift_gear(self, gear: int) -> str:
        return f"{self._brand}: {self._transmission.shift(gear)}"
    
    def get_info(self) -> str:
        return f"{self._brand} ({self._engine.get_type()}引擎)"

gasoline_car = Car("丰田", GasolineEngine(), AutomaticTransmission())
electric_car = Car("特斯拉", ElectricEngine(), AutomaticTransmission())
hybrid_car = Car("普锐斯", HybridEngine(), ManualTransmission())

print(gasoline_car.start())
print(electric_car.start())
print(hybrid_car.start())

KISS原则

定义A.10(KISS原则) Keep It Simple, Stupid - 保持简单愚蠢。

python
from typing import List, Optional
from dataclasses import dataclass

@dataclass
class Product:
    id: int
    name: str
    price: float

class SimpleInventory:
    def __init__(self):
        self._products: dict[int, Product] = {}
        self._quantities: dict[int, int] = {}
    
    def add_product(self, product: Product, quantity: int) -> None:
        self._products[product.id] = product
        self._quantities[product.id] = self._quantities.get(product.id, 0) + quantity
    
    def get_quantity(self, product_id: int) -> int:
        return self._quantities.get(product_id, 0)
    
    def reduce_quantity(self, product_id: int, amount: int) -> bool:
        if self._quantities.get(product_id, 0) >= amount:
            self._quantities[product_id] -= amount
            return True
        return False
    
    def get_total_value(self) -> float:
        return sum(
            self._products[pid].price * qty 
            for pid, qty in self._quantities.items()
        )

DRY原则

定义A.11(DRY原则) Don't Repeat Yourself - 不要重复自己。

python
from typing import Callable, TypeVar, Generic
from dataclasses import dataclass
from functools import wraps
import time

T = TypeVar('T')

@dataclass
class ValidationResult:
    is_valid: bool
    errors: list[str]

class Validator:
    @staticmethod
    def validate_email(email: str) -> ValidationResult:
        errors = []
        if '@' not in email:
            errors.append("邮箱必须包含@符号")
        if '.' not in email.split('@')[-1]:
            errors.append("邮箱域名无效")
        return ValidationResult(len(errors) == 0, errors)
    
    @staticmethod
    def validate_password(password: str) -> ValidationResult:
        errors = []
        if len(password) < 8:
            errors.append("密码长度至少8位")
        if not any(c.isupper() for c in password):
            errors.append("密码必须包含大写字母")
        if not any(c.isdigit() for c in password):
            errors.append("密码必须包含数字")
        return ValidationResult(len(errors) == 0, errors)
    
    @staticmethod
    def validate_username(username: str) -> ValidationResult:
        errors = []
        if len(username) < 3:
            errors.append("用户名长度至少3位")
        if not username.isalnum():
            errors.append("用户名只能包含字母和数字")
        return ValidationResult(len(errors) == 0, errors)

def retry(max_attempts: int = 3, delay: float = 1.0):
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            last_exception = None
            for attempt in range(max_attempts):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    last_exception = e
                    if attempt < max_attempts - 1:
                        time.sleep(delay)
            raise last_exception
        return wrapper
    return decorator

YAGNI原则

定义A.12(YAGNI原则) You Aren't Gonna Need It - 你不会需要它。

python
class UserRepository:
    def __init__(self, database):
        self._database = database
    
    def find_by_id(self, user_id: int):
        return self._database.find('users', user_id)
    
    def save(self, user: dict) -> int:
        return self._database.save('users', user)
    
    def delete(self, user_id: int) -> bool:
        return self._database.delete('users', user_id)

class UserService:
    def __init__(self, repository: UserRepository):
        self._repository = repository
    
    def get_user(self, user_id: int):
        return self._repository.find_by_id(user_id)
    
    def create_user(self, name: str, email: str) -> dict:
        user = {'name': name, 'email': email}
        user_id = self._repository.save(user)
        user['id'] = user_id
        return user

原则之间的关系

┌─────────────────────────────────────────────────────────────────────────┐
│                         设计原则关系图                                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│                        ┌─────────────────┐                             │
│                        │   高内聚低耦合   │                             │
│                        │   (设计目标)     │                             │
│                        └────────┬────────┘                             │
│                                 │                                       │
│          ┌──────────────────────┼──────────────────────┐               │
│          │                      │                      │               │
│   ┌──────▼──────┐        ┌──────▼──────┐       ┌──────▼──────┐        │
│   │  SOLID原则   │        │  其他原则   │       │  实践原则   │        │
│   ├─────────────┤        ├─────────────┤       ├─────────────┤        │
│   │ SRP ──────► │        │ LoD ──────► │       │ KISS ─────► │        │
│   │    高内聚   │        │   低耦合    │       │   简洁性    │        │
│   │             │        │             │       │             │        │
│   │ OCP ──────► │        │ CRP ──────► │       │ DRY ──────► │        │
│   │    可扩展   │        │   灵活复用  │       │   可维护    │        │
│   │             │        │             │       │             │        │
│   │ LSP ──────► │        │             │       │ YAGNI ────► │        │
│   │    可替换   │        │             │       │   避免过度  │        │
│   │             │        │             │       │             │        │
│   │ ISP ──────► │        │             │       │             │        │
│   │    接口精简 │        │             │       │             │        │
│   │             │        │             │       │             │        │
│   │ DIP ──────► │        │             │       │             │        │
│   │    解耦     │        │             │       │             │        │
│   └─────────────┘        └─────────────┘       └─────────────┘        │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

原则应用指南

场景决策表

场景推荐原则具体应用
类职责过多单一职责原则拆分为多个专注的类
需要扩展功能开闭原则使用策略模式、装饰器模式
继承层次设计里氏替换原则确保子类行为一致
接口设计接口隔离原则拆分大接口为小接口
模块依赖依赖倒置原则引入抽象层解耦
类间通信迪米特法则减少直接依赖
代码复用组合复用原则优先组合而非继承
系统复杂度KISS原则保持简单直接
重复代码DRY原则提取公共逻辑
功能预测YAGNI原则只实现当前需要

原则冲突处理

冲突场景权衡建议
SRP vs 性能适度合并高频调用的职责
OCP vs 简洁性预测稳定的变化点再抽象
DRY vs YAGNI只提取当前存在的重复
组合 vs 继承行为复用用组合,类型关系用继承

快速参考卡片

SOLID原则速查

原则核心思想检验方法
SRP一个类只有一个变化原因能否用一句话描述类的职责?
OCP对扩展开放,对修改关闭新增功能是否需要修改现有代码?
LSP子类可替换父类子类是否违反父类契约?
ISP接口职责单一客户端是否依赖不需要的方法?
DIP依赖抽象而非具体高层模块是否直接依赖低层?

设计原则检查清单

  • [ ] 每个类是否有单一明确的职责?
  • [ ] 新功能是否可以通过扩展而非修改实现?
  • [ ] 子类是否可以安全替换父类?
  • [ ] 接口是否足够精简?
  • [ ] 是否依赖抽象而非具体实现?
  • [ ] 类之间的耦合是否最小化?
  • [ ] 是否优先使用组合而非继承?
  • [ ] 代码是否足够简单?
  • [ ] 是否存在重复代码?
  • [ ] 是否只实现当前需要的功能?

总结

设计原则是软件设计的基石,它们提供了指导性的思想而非僵化的规则。在实际应用中:

  1. 理解原则本质:掌握每个原则背后的设计哲学
  2. 权衡取舍:根据具体场景灵活应用原则
  3. 持续重构:通过重构逐步改进设计
  4. 团队共识:建立团队对原则的共同理解

记住:设计原则是工具,不是目的。好的设计应该服务于业务需求,同时保持代码的可维护性和可扩展性。

Python技术丛书 - 江苏省宿城中等专业学校