Skip to content

第24章 函数式编程模式

学习目标

完成本章学习后,读者将能够:

  • 理解函数式编程的核心概念与形式化定义
  • 掌握不可变数据和纯函数的设计原则
  • 使用高阶函数、函子和单子等抽象结构
  • 在Python中应用函数式编程模式解决实际问题
  • 结合面向对象与函数式编程发挥两者优势

24.1 函数式编程概述

24.1.1 核心定义

函数式编程(Functional Programming) 是一种编程范式,它将计算视为数学函数的求值,强调使用纯函数、不可变数据和函数组合来构建软件系统。

24.1.2 形式化定义

从数学角度,函数式编程基于Lambda演算和范畴论。

Lambda演算基础

Lambda表达式定义为:

$$\lambda x. E$$

其中 $x$ 是变量,$E$ 是表达式。应用操作为:

$$(\lambda x. E_1) E_2 \rightarrow E_1[x := E_2]$$

纯函数定义

函数 $f: A \rightarrow B$ 是纯函数,当且仅当:

  1. 确定性:$\forall x \in A, \exists! y \in B: f(x) = y$
  2. 无副作用:$f$ 不修改任何外部状态

引用透明性

表达式 $E$ 是引用透明的,当且仅当:

$$\forall \text{上下文 } C, \forall E' \text{ 与 } E \text{ 等价}: C[E] = C[E']$$

不可变性约束

对于不可变数据结构 $D$,任何操作 $op$ 满足:

$$op(D) = D' \land D \text{ 保持不变}$$

24.1.3 设计原则

| 原则 | 描述 | 数学表示 |
| --- | --- | --- |
| 纯函数 | 无副作用,引用透明 | $f: A \rightarrow B$ |
| 不可变性 | 数据一旦创建不可修改 | $D' = op(D) \Rightarrow D \text{ unchanged}$ |
| 一等函数 | 函数作为值传递 | $f \in Values$ |
| 高阶函数 | 函数接受/返回函数 | $h: (A \rightarrow B) \rightarrow C$ |
| 惰性求值 | 按需计算 | $eval(E) \text{ when needed}$ |

24.1.4 与面向对象对比

| 特性 | 函数式编程 | 面向对象编程 |
| --- | --- | --- |
| 核心抽象 | 函数 | 对象 |
| 状态管理 | 不可变 | 可变 |
| 数据行为 | 数据与行为分离 | 数据与行为封装 |
| 组合方式 | 函数组合 | 对象组合 |
| 副作用 | 避免/隔离 | 可接受 |
| 测试性 | 天然可测试 | 需要模拟 |

24.2 历史背景与演进

24.2.1 历史发展

| 年代 | 里程碑 | 描述 |
| --- | --- | --- |
| 1930s | Lambda演算 | Alonzo Church提出Lambda演算 |
| 1958 | LISP | John McCarthy创建LISP语言 |
| 1970s | ML语言 | Robin Milner开发ML类型系统 |
| 1977 | FP语言 | John Backus提出函数式编程概念 |
| 1987 | Haskell | 纯函数式语言Haskell诞生 |
| 1990s | Monads | Philip Wadler将单子引入函数式编程 |
| 2000s | F#, Scala | 函数式特性融入主流语言 |
| 2010s | React/Redux | 函数式思想影响前端开发 |
| 2020s | 普及化 | 函数式特性成为现代语言标配 |

24.2.2 理论基础

函数式编程的理论基础来源于:

  1. Lambda演算:计算的形式化模型
  2. 范畴论:函子、单子等抽象结构
  3. 类型理论:类型系统与类型推导
  4. 组合子逻辑:函数组合的理论基础

24.3 纯函数与不可变性

24.3.1 纯函数

python
from typing import TypeVar, Generic, Callable, List, Any
from dataclasses import dataclass
from functools import wraps
import hashlib
import json

T = TypeVar('T')
R = TypeVar('R')


def pure_function_example():
    """Print the results of a few pure helper functions.

    Each helper depends only on its arguments and touches no outside state,
    so the same call always produces the same value.
    """

    def add(a: int, b: int) -> int:
        total = a + b
        return total

    def square(x: int) -> int:
        return x ** 2

    def format_name(first: str, last: str) -> str:
        return "{} {}".format(first, last)

    def calculate_discount(price: float, rate: float) -> float:
        # Kept as an illustration of a pure function; it is not called below.
        remaining_share = 1 - rate
        return price * remaining_share

    report = (
        f"add(2, 3) = {add(2, 3)}",
        f"square(5) = {square(5)}",
        f"format_name('张', '三') = {format_name('张', '三')}",
    )
    for line in report:
        print(line)


def impure_vs_pure():
    """Contrast a stateful accumulator with a pure equivalent.

    Both halves print the running totals 1, 2, 3; the difference is where
    the running total lives.
    """
    total = 0

    def impure_add(a: int) -> int:
        # Impure: reads and rewrites the enclosing ``total`` on every call.
        nonlocal total
        total = total + a
        return total

    def pure_add(total: int, a: int) -> tuple[int, int]:
        # Pure: the running total is passed in and handed back explicitly.
        updated = total + a
        return updated, updated

    print("=== 不纯函数 ===")
    for _ in range(3):
        print(impure_add(1))

    print("\n=== 纯函数 ===")
    total = 0
    for _ in range(3):
        total, _ignored = pure_add(total, 1)
        print(total)


# Run both demonstrations, separated by a blank line of output.
pure_function_example()
print()
impure_vs_pure()

24.3.2 不可变数据结构

python
from dataclasses import dataclass, field
from typing import Tuple, FrozenSet, List, Dict, Any
from copy import deepcopy
import json


@dataclass(frozen=True)
class Point:
    """Immutable 2-D point; every "modifying" method returns a new Point."""
    x: float
    y: float

    def move(self, dx: float, dy: float) -> 'Point':
        """Return a new point shifted by ``dx`` and ``dy``."""
        shifted_x = self.x + dx
        shifted_y = self.y + dy
        return Point(shifted_x, shifted_y)

    def distance_to(self, other: 'Point') -> float:
        """Return the Euclidean distance between this point and ``other``."""
        delta_x = self.x - other.x
        delta_y = self.y - other.y
        squared = delta_x ** 2 + delta_y ** 2
        return squared ** 0.5


@dataclass(frozen=True)
class Rectangle:
    """Immutable rectangle described by its top-left corner and its size."""
    top_left: Point
    width: float
    height: float

    def with_width(self, width: float) -> 'Rectangle':
        """Return a copy of this rectangle with a different width."""
        return Rectangle(top_left=self.top_left, width=width, height=self.height)

    def with_height(self, height: float) -> 'Rectangle':
        """Return a copy of this rectangle with a different height."""
        return Rectangle(top_left=self.top_left, width=self.width, height=height)

    def move(self, dx: float, dy: float) -> 'Rectangle':
        """Return a copy whose top-left corner is shifted by ``dx``/``dy``."""
        new_corner = self.top_left.move(dx, dy)
        return Rectangle(top_left=new_corner, width=self.width, height=self.height)

    def area(self) -> float:
        """Return width times height."""
        w, h = self.width, self.height
        return w * h


class ImmutableList:
    """Immutable sequence wrapper; every operation returns a new instance.

    The elements are stored in a tuple. The constructor copies whatever
    iterable it receives into a fresh tuple, so later changes to the
    caller's object (for example a list) cannot leak into this instance.
    """

    def __init__(self, items: tuple = ()):
        # Coerce to a tuple: storing a caller-owned list would stay mutable
        # and would also break ``append``/``prepend`` (list + tuple fails).
        self._items = tuple(items)

    @staticmethod
    def from_list(items: list) -> 'ImmutableList':
        """Build an ImmutableList from a regular list."""
        return ImmutableList(tuple(items))

    def append(self, item: Any) -> 'ImmutableList':
        """Return a new list with ``item`` added at the end."""
        return ImmutableList(self._items + (item,))

    def prepend(self, item: Any) -> 'ImmutableList':
        """Return a new list with ``item`` added at the front."""
        return ImmutableList((item,) + self._items)

    def remove(self, item: Any) -> 'ImmutableList':
        """Return a new list without any element equal to ``item``."""
        return ImmutableList(tuple(x for x in self._items if x != item))

    def map(self, func: Callable[[Any], Any]) -> 'ImmutableList':
        """Return a new list holding ``func(x)`` for every element."""
        return ImmutableList(tuple(func(x) for x in self._items))

    def filter(self, predicate: Callable[[Any], bool]) -> 'ImmutableList':
        """Return a new list with the elements for which ``predicate`` is truthy."""
        return ImmutableList(tuple(x for x in self._items if predicate(x)))

    def reduce(self, func: Callable[[Any, Any], Any], initial: Any) -> Any:
        """Fold the elements left to right, starting from ``initial``."""
        result = initial
        for item in self._items:
            result = func(result, item)
        return result

    def to_list(self) -> list:
        """Return the elements as a new regular list."""
        return list(self._items)

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: int) -> Any:
        return self._items[index]

    def __repr__(self) -> str:
        return f"ImmutableList({self._items})"


class ImmutableDict:
    """Immutable mapping wrapper; every update returns a new instance.

    The constructor copies the given dict, so the caller's dict is never
    shared. Instances compare equal when their contents are equal, which
    keeps ``==`` consistent with ``__hash__``.
    """

    def __init__(self, data: dict = None):
        self._data = dict(data) if data else {}
        # Hash is computed lazily and cached; contents never change afterwards.
        self._hash = None

    def set(self, key: str, value: Any) -> 'ImmutableDict':
        """Return a new mapping with ``key`` bound to ``value``."""
        new_data = self._data.copy()
        new_data[key] = value
        return ImmutableDict(new_data)

    def delete(self, key: str) -> 'ImmutableDict':
        """Return a new mapping without ``key`` (no error if it is absent)."""
        new_data = self._data.copy()
        new_data.pop(key, None)
        return ImmutableDict(new_data)

    def update(self, other: dict) -> 'ImmutableDict':
        """Return a new mapping with the entries of ``other`` merged in."""
        new_data = self._data.copy()
        new_data.update(other)
        return ImmutableDict(new_data)

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value for ``key``, or ``default`` when it is missing."""
        return self._data.get(key, default)

    def keys(self) -> tuple:
        """Return the keys as a tuple."""
        return tuple(self._data.keys())

    def values(self) -> tuple:
        """Return the values as a tuple."""
        return tuple(self._data.values())

    def items(self) -> tuple:
        """Return the ``(key, value)`` pairs as a tuple."""
        return tuple(self._data.items())

    def __contains__(self, key: str) -> bool:
        return key in self._data

    def __getitem__(self, key: str) -> Any:
        return self._data[key]

    def __repr__(self) -> str:
        return f"ImmutableDict({self._data})"

    def __eq__(self, other: object) -> bool:
        # Value equality: without this, two instances with identical contents
        # would share a hash yet never compare equal.
        if not isinstance(other, ImmutableDict):
            return NotImplemented
        return self._data == other._data

    def __hash__(self) -> int:
        # Raises TypeError if any stored value is unhashable.
        if self._hash is None:
            self._hash = hash(frozenset(self._data.items()))
        return self._hash


# Moving a point yields a new Point; p1 keeps its original coordinates.
p1 = Point(0, 0)
p2 = p1.move(10, 20)
print(f"原始点: ({p1.x}, {p1.y})")
print(f"移动后: ({p2.x}, {p2.y})")

# with_width() and move() each return a new Rectangle; rect itself is not modified.
rect = Rectangle(Point(0, 0), 10, 20)
rect2 = rect.with_width(15).move(5, 5)
print(f"\n原始矩形: {rect.area()} 平方单位")
print(f"修改后矩形: {rect2.area()} 平方单位")

# append() and map() build new lists; nums still holds 1..5.
nums = ImmutableList.from_list([1, 2, 3, 4, 5])
nums2 = nums.append(6).map(lambda x: x * 2)
print(f"\n原始列表: {nums}")
print(f"修改后列表: {nums2}")

# set() returns a new mapping that includes 'debug'; config is left as it was.
config = ImmutableDict({'host': 'localhost', 'port': 8080})
config2 = config.set('debug', True)
print(f"\n原始配置: {config}")
print(f"修改后配置: {config2}")

24.3.3 引用透明性与等式推理

python
from typing import Callable, Any


def demonstrate_referential_transparency():
    """Show that a pure call can be replaced by its value.

    Evaluates the same sum three ways (nested calls, calls with the inner
    results substituted, plain arithmetic) and prints that all agree.
    """

    def double(x: int) -> int:
        doubled = x * 2
        return doubled

    def add(a: int, b: int) -> int:
        return a + b

    # Step 1: the original expression.
    result1 = add(double(3), double(4))
    # Step 2: inner calls replaced by their values.
    result2 = add(6, 8)
    # Step 3: the outer call replaced by its definition.
    result3 = 6 + 8

    all_equal = result1 == result2 == result3
    for line in (
        f"add(double(3), double(4)) = {result1}",
        f"add(6, 8) = {result2}",
        f"6 + 8 = {result3}",
        f"所有结果相等: {all_equal}",
    ):
        print(line)


def equational_reasoning():
    """Check an algebraic simplification on sample inputs.

    ``(x + 1) * 2 - 2`` simplifies to ``x * 2``; for x in 0..4 both forms
    are evaluated and printed side by side with an equality flag.
    """

    def process(x: int) -> int:
        incremented = x + 1
        return incremented * 2 - 2

    def simplified_process(x: int) -> int:
        return x * 2

    i = 0
    while i < 5:
        original, simplified = process(i), simplified_process(i)
        same = original == simplified
        print(f"process({i}) = {original}, simplified({i}) = {simplified}, 相等: {same}")
        i += 1


# Run both demonstrations, separated by a blank line of output.
demonstrate_referential_transparency()
print()
equational_reasoning()

24.4 高阶函数与函数组合

24.4.1 高阶函数

python
from typing import Callable, TypeVar, List, Any, Iterator
from functools import reduce, partial, wraps
from dataclasses import dataclass

T = TypeVar('T')
R = TypeVar('R')


class HigherOrderFunctions:
    """Namespace of list-based higher-order helpers (all static methods)."""

    @staticmethod
    def map_func(func: Callable[[T], R], items: List[T]) -> List[R]:
        """Return ``func(x)`` for every element, in order."""
        return list(map(func, items))

    @staticmethod
    def filter_func(predicate: Callable[[T], bool], items: List[T]) -> List[T]:
        """Return the elements for which ``predicate`` is truthy."""
        kept = []
        for x in items:
            if predicate(x):
                kept.append(x)
        return kept

    @staticmethod
    def reduce_func(func: Callable[[R, T], R], items: List[T], initial: R) -> R:
        """Fold ``items`` left to right, starting from ``initial``."""
        return reduce(func, items, initial)

    @staticmethod
    def flat_map(func: Callable[[T], List[R]], items: List[T]) -> List[R]:
        """Map each element to a list and concatenate the results."""
        return [y for x in items for y in func(x)]

    @staticmethod
    def take_while(predicate: Callable[[T], bool], items: List[T]) -> List[T]:
        """Return the leading elements up to the first one failing ``predicate``."""
        prefix = []
        for x in items:
            if not predicate(x):
                break
            prefix.append(x)
        return prefix

    @staticmethod
    def drop_while(predicate: Callable[[T], bool], items: List[T]) -> List[T]:
        """Skip leading elements satisfying ``predicate`` and return the rest."""
        position = 0
        for x in items:
            if not predicate(x):
                return items[position:]
            position += 1
        # Every element satisfied the predicate (or the input was empty).
        return []

    @staticmethod
    def partition(predicate: Callable[[T], bool], items: List[T]) -> tuple[List[T], List[T]]:
        """Split ``items`` into (matching, non-matching), preserving order."""
        matching, rejected = [], []
        for x in items:
            bucket = matching if predicate(x) else rejected
            bucket.append(x)
        return matching, rejected

    @staticmethod
    def zip_with(func: Callable[[T, R], Any], items1: List[T], items2: List[R]) -> List[Any]:
        """Combine two lists pairwise with ``func``; stops at the shorter one."""
        return list(map(func, items1, items2))


# Sample input shared by the demonstrations below.
numbers = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

# All helpers are static methods, so the instance is only a short alias.
hof = HigherOrderFunctions()

squared = hof.map_func(lambda x: x ** 2, numbers)
print(f"平方: {squared}")

evens = hof.filter_func(lambda x: x % 2 == 0, numbers)
print(f"偶数: {evens}")

total = hof.reduce_func(lambda acc, x: acc + x, numbers, 0)
print(f"总和: {total}")

nested = hof.flat_map(lambda x: [x, x * 10], [1, 2, 3])
print(f"扁平映射: {nested}")

taken = hof.take_while(lambda x: x < 5, numbers)
print(f"取到<5: {taken}")

dropped = hof.drop_while(lambda x: x < 5, numbers)
print(f"丢弃<5: {dropped}")

# Rebinds `evens` from the filter example above.
evens, odds = hof.partition(lambda x: x % 2 == 0, numbers)
print(f"分区: 偶数={evens}, 奇数={odds}")

sums = hof.zip_with(lambda a, b: a + b, [1, 2, 3], [10, 20, 30])
print(f"zipWith: {sums}")

24.4.2 函数组合

python
from typing import Callable, TypeVar, Any
from functools import reduce, wraps

T = TypeVar('T')
R = TypeVar('R')


def compose(*functions: Callable) -> Callable:
    """Compose functions right to left.

    ``compose(f, g, h)(x)`` evaluates ``f(g(h(x)))``. With no functions the
    returned callable hands its argument back unchanged.
    """
    def composed(x: Any) -> Any:
        # Fold over the functions in reverse so the last one listed runs first.
        return reduce(lambda acc, fn: fn(acc), reversed(functions), x)
    return composed


def pipe(*functions: Callable) -> Callable:
    """Chain functions left to right.

    ``pipe(f, g, h)(x)`` evaluates ``h(g(f(x)))``. With no functions the
    returned callable hands its argument back unchanged.
    """
    def piped(x: Any) -> Any:
        # Fold in listing order so the first function runs first.
        return reduce(lambda acc, fn: fn(acc), functions, x)
    return piped


class Flow:
    """Fluent wrapper that threads a value through chained operations.

    When the wrapped value is a list, ``map``, ``filter``, ``flat_map`` and
    ``reduce`` work element-wise; for any other value they act on the value
    itself. Every operation returns a new ``Flow`` (``tap`` returns the
    same one), and ``value()`` unwraps the result.
    """

    def __init__(self, value: Any):
        self._value = value

    def map(self, func: Callable) -> 'Flow':
        """Apply ``func`` to each list element, or to the value if not a list."""
        # Element-wise for lists, consistent with filter() and flat_map().
        if isinstance(self._value, list):
            return Flow([func(x) for x in self._value])
        return Flow(func(self._value))

    def filter(self, predicate: Callable) -> 'Flow':
        """Keep list elements passing ``predicate``; a failing scalar becomes None."""
        if isinstance(self._value, list):
            return Flow([x for x in self._value if predicate(x)])
        return Flow(self._value if predicate(self._value) else None)

    def flat_map(self, func: Callable) -> 'Flow':
        """Map each list element and flatten list results by one level."""
        if isinstance(self._value, list):
            result = []
            for x in self._value:
                mapped = func(x)
                if isinstance(mapped, list):
                    result.extend(mapped)
                else:
                    result.append(mapped)
            return Flow(result)
        return Flow(func(self._value))

    def reduce(self, func: Callable, initial: Any = None) -> 'Flow':
        """Fold a list into a single value; non-list values pass through.

        ``initial=None`` means "no seed": the first element starts the fold.
        An empty list without a seed yields ``Flow(None)``.
        """
        if not isinstance(self._value, list):
            return Flow(self._value)
        from functools import reduce as _reduce
        if initial is None:
            # Passing None on as the seed would make func receive None first.
            if not self._value:
                return Flow(None)
            return Flow(_reduce(func, self._value))
        return Flow(_reduce(func, self._value, initial))

    def tap(self, func: Callable) -> 'Flow':
        """Call ``func`` with the current value for its side effect; return self."""
        func(self._value)
        return self

    def value(self) -> Any:
        """Return the wrapped value."""
        return self._value

    def __repr__(self) -> str:
        return f"Flow({self._value})"


def trim(s: str) -> str:
    """Return ``s`` without leading or trailing whitespace."""
    stripped = s.strip()
    return stripped

def lower(s: str) -> str:
    """Return ``s`` converted to lower case."""
    lowered = s.lower()
    return lowered

def capitalize(s: str) -> str:
    """Return ``s`` with its first character upper-cased and the rest lower-cased."""
    capitalized = s.capitalize()
    return capitalized

def exclaim(s: str) -> str:
    """Return ``s`` followed by an exclamation mark."""
    return "{}!".format(s)

# compose() runs right to left: trim, then lower, then capitalize, then exclaim.
normalize = compose(exclaim, capitalize, lower, trim)
print(f"compose结果: {normalize('  HELLO WORLD  ')}")

# pipe() lists the same steps in execution order.
normalize_pipe = pipe(trim, lower, capitalize, exclaim)
print(f"pipe结果: {normalize_pipe('  HELLO WORLD  ')}")

# Fluent chain: filter the list, map it, print the intermediate value via tap(),
# then fold it with a seed of 0.
result = (Flow([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    .filter(lambda x: x % 2 == 0)
    .map(lambda x: x ** 2)
    .tap(lambda x: print(f"中间结果: {x}"))
    .reduce(lambda acc, x: acc + x, 0)
    .value())

print(f"Flow结果: {result}")

24.4.3 柯里化与偏函数

python
from functools import partial, wraps
from typing import Callable, Any, get_type_hints
import inspect


def curry(func: Callable) -> Callable:
    """Decorator that lets ``func`` receive its arguments a few at a time.

    The wrapper collects positional and keyword arguments across calls and
    invokes ``func`` once the number supplied reaches the number of
    parameters in its signature. Until then each call returns another
    callable that accepts more arguments.

    The parameter count is read once, at decoration time, so
    ``inspect.signature`` errors surface when the decorator is applied.
    """
    # The signature cannot change between calls, so inspect it only once.
    needed = len(inspect.signature(func).parameters)

    @wraps(func)
    def wrapped(*args, **kwargs):
        if len(args) + len(kwargs) >= needed:
            return func(*args, **kwargs)

        @wraps(func)
        def partial_func(*more_args, **more_kwargs):
            # Later keyword arguments override earlier ones with the same name.
            new_args = args + more_args
            new_kwargs = {**kwargs, **more_kwargs}
            return wrapped(*new_args, **new_kwargs)

        return partial_func

    return wrapped


def curry_explicit(func: Callable, arity: int) -> Callable:
    """Curry ``func`` for an explicitly given number of arguments.

    The result takes one argument per call and invokes ``func`` once
    ``arity`` arguments have been gathered. At least one argument is always
    consumed, even when ``arity`` is 0.
    """
    def curried(args):
        if len(args) < arity:
            # Not enough yet: wait for one more, extending a fresh list so
            # partially applied callables can be reused safely.
            return lambda x: curried(args + [x])
        return func(*args)

    def start(x):
        return curried([x])

    return start


@curry
def add_three(a: int, b: int, c: int) -> int:
    """Return the sum of three numbers (curried)."""
    first_two = a + b
    return first_two + c

@curry
def format_message(greeting: str, name: str, punctuation: str) -> str:
    """Return ``"<greeting>, <name><punctuation>"`` (curried)."""
    return "{}, {}{}".format(greeting, name, punctuation)

# Supply the three arguments one call at a time.
add_5 = add_three(5)
add_5_3 = add_5(3)
print(f"柯里化 add_three(5)(3)(2) = {add_5_3(2)}")

# Same pattern with strings: fix the greeting, then the name, then the punctuation.
say_hello = format_message("Hello")
say_hello_world = say_hello("World")
print(f"柯里化 format_message: {say_hello_world('!')}")


def create_validator(field: str, rule: Callable, message: str) -> Callable:
    """Build a validator for one field.

    The returned callable takes a value and returns ``(True, "")`` when
    ``rule(value)`` is truthy, otherwise ``(False, text)`` where ``text`` is
    ``message`` formatted with ``field`` and ``value``.
    """
    def validator(value: Any) -> tuple[bool, str]:
        if not rule(value):
            return False, message.format(field=field, value=value)
        return True, ""
    return validator

# Pre-bind the rule and message; only the field name is still missing.
is_positive = partial(create_validator, rule=lambda x: x > 0, message="{field} must be positive")
# Call the partial with the field name to obtain the actual validator.
# Wrapping it in another partial would leave `field` bound by keyword while
# the value under test arrives positionally in the same slot, which raises
# "got multiple values for argument 'field'".
is_valid_age = is_positive(field="age")

print(f"\n验证年龄 -5: {is_valid_age(-5)}")
print(f"验证年龄 25: {is_valid_age(25)}")

24.5 函子与单子

24.5.1 函子(Functor)

python
from typing import TypeVar, Generic, Callable, List, Iterator
from dataclasses import dataclass
from abc import ABC, abstractmethod

T = TypeVar('T')
R = TypeVar('R')


class Functor(ABC, Generic[T]):
    """Abstract base class for containers that can be mapped over.

    Implementations are expected to obey the functor laws:
    1. Identity:    fmap(id) = id
    2. Composition: fmap(f . g) = fmap(f) . fmap(g)

    ``container >> func`` is shorthand for ``container.map(func)``.
    """

    @abstractmethod
    def map(self, func: Callable[[T], R]) -> 'Functor[R]':
        """Apply ``func`` to the wrapped value(s) and return a new functor."""
        pass

    def __rshift__(self, func: Callable[[T], R]) -> 'Functor[R]':
        mapped = self.map(func)
        return mapped


class ListFunctor(Functor[T]):
    """Functor over a list: ``map`` applies the function to every element."""

    def __init__(self, values: List[T]):
        # Copy so later changes to the caller's list do not affect this instance.
        self._values = list(values)

    def map(self, func: Callable[[T], R]) -> 'ListFunctor[R]':
        """Return a new ListFunctor holding ``func(x)`` for each element."""
        transformed = []
        for x in self._values:
            transformed.append(func(x))
        return ListFunctor(transformed)

    def __iter__(self) -> Iterator[T]:
        return iter(self._values)

    def __repr__(self) -> str:
        return f"ListFunctor({self._values})"


class Maybe(Functor[T]):
    """Functor for a value that may be absent.

    Maybe a = Just a | Nothing

    Absence is encoded as a wrapped ``None``: an instance holding ``None``
    is Nothing, anything else is Just.
    """

    def __init__(self, value: T = None):
        self._value = value

    @staticmethod
    def just(value: T) -> 'Maybe[T]':
        """Wrap ``value``."""
        return Maybe(value)

    @staticmethod
    def nothing() -> 'Maybe[T]':
        """Return the empty Maybe."""
        return Maybe(None)

    def is_just(self) -> bool:
        """Return True when a value is present."""
        return self._value is not None

    def is_nothing(self) -> bool:
        """Return True when no value is present."""
        return self._value is None

    def map(self, func: Callable[[T], R]) -> 'Maybe[R]':
        """Apply ``func`` to the value if present; Nothing stays Nothing."""
        if not self.is_just():
            return Maybe.nothing()
        return Maybe.just(func(self._value))

    def get_or_else(self, default: T) -> T:
        """Return the wrapped value, or ``default`` when empty."""
        if self.is_just():
            return self._value
        return default

    def __repr__(self) -> str:
        return f"Just({self._value})" if self.is_just() else "Nothing"


class Either(Functor[T]):
    """Functor for a computation that may fail.

    Either a b = Left a | Right b

    ``Left`` carries an error, ``Right`` carries a result, and ``map`` only
    touches a ``Right``. Which side an instance represents is recorded in
    explicit flags, so ``Either.right(None)`` is still a Right and
    ``Either.left(None)`` is still a Left. Direct construction keeps the
    earlier rule: a side counts as present when its argument is not None.
    """

    def __init__(self, left=None, right=None):
        self._left = left
        self._right = right
        # Tags inferred for direct construction; the factories override them
        # so that a None payload does not change which side this is.
        self._is_left = left is not None
        self._is_right = right is not None

    @staticmethod
    def left(value) -> 'Either':
        """Create a Left (failure) holding ``value``."""
        either = Either(left=value)
        either._is_left, either._is_right = True, False
        return either

    @staticmethod
    def right(value: T) -> 'Either[T]':
        """Create a Right (success) holding ``value``."""
        either = Either(right=value)
        either._is_left, either._is_right = False, True
        return either

    def is_left(self) -> bool:
        """Return True when this instance is a Left."""
        return self._is_left

    def is_right(self) -> bool:
        """Return True when this instance is a Right."""
        return self._is_right

    def map(self, func: Callable[[T], R]) -> 'Either[R]':
        """Apply ``func`` to a Right's value; a Left is passed through unchanged."""
        if self.is_right():
            return Either.right(func(self._right))
        return Either.left(self._left)

    def get_or_else(self, default: T) -> T:
        """Return the Right value, or ``default`` for anything else."""
        return self._right if self.is_right() else default

    def __repr__(self) -> str:
        if self.is_left():
            return f"Left({self._left})"
        return f"Right({self._right})"


def safe_divide(a: float, b: float) -> Maybe[float]:
    """Divide ``a`` by ``b``; return Nothing instead of dividing by zero."""
    return Maybe.nothing() if b == 0 else Maybe.just(a / b)

# Chain transformations on a present value, then unwrap with a fallback.
result = (Maybe.just(10)
    .map(lambda x: x * 2)
    .map(lambda x: x + 5)
    .get_or_else(0))
print(f"Maybe链式调用: {result}")

# Mapping over Nothing never calls the function; the fallback is returned.
result2 = (Maybe.nothing()
    .map(lambda x: x * 2)
    .get_or_else("默认值"))
print(f"Maybe空值处理: {result2}")

def parse_int(s: str) -> Either[int]:
    """Parse ``s`` as an integer.

    Returns a Right with the parsed number, or a Left carrying the
    ``ValueError`` message when ``s`` is not a valid integer literal.
    """
    try:
        parsed = int(s)
    except ValueError as e:
        return Either.left(str(e))
    return Either.right(parsed)

# Successful parse: both map() steps are applied to the Right value.
result3 = (parse_int("42")
    .map(lambda x: x * 2)
    .map(lambda x: x + 10))
print(f"Either成功: {result3}")

# Failed parse: map() leaves the Left untouched, so the error message survives.
result4 = (parse_int("abc")
    .map(lambda x: x * 2))
print(f"Either失败: {result4}")

24.5.2 应用函子(Applicative Functor)

python
from typing import TypeVar, Generic, Callable, List
from dataclasses import dataclass

T = TypeVar('T')
R = TypeVar('R')


class Applicative(Generic[T]):
    """Applicative functor: applies a wrapped function to a wrapped value.

    Applicative laws:
    1. Identity:     pure(id) <*> v = v
    2. Homomorphism: pure(f) <*> pure(x) = pure(f(x))
    3. Interchange:  u <*> pure(y) = pure($y) <*> u
    4. Composition:  pure(.) <*> u <*> v <*> w = u <*> (v <*> w)

    A wrapped ``None`` stands for "no value" and short-circuits ``map`` and ``ap``.
    """

    def __init__(self, value: T = None):
        self._value = value

    @staticmethod
    def pure(value: T) -> 'Applicative[T]':
        """Lift a plain value into the applicative."""
        return Applicative(value)

    def map(self, func: Callable[[T], R]) -> 'Applicative[R]':
        """Apply ``func`` to the wrapped value unless it is None."""
        current = self._value
        if current is not None:
            return Applicative(func(current))
        return Applicative(None)

    def ap(self, other: 'Applicative[Callable[[T], R]]') -> 'Applicative[R]':
        """Apply the function wrapped in ``other`` to this instance's value.

        Note the direction: ``value.ap(wrapped_function)``. If either side
        holds None the result holds None.
        """
        argument, function = self._value, other._value
        if argument is not None and function is not None:
            return Applicative(function(argument))
        return Applicative(None)

    def __repr__(self) -> str:
        return f"Applicative({self._value})"


class Validation(Generic[T]):
    """Applicative for validation that accumulates every error.

    An instance is a success when its error list is empty, otherwise a
    failure. Error lists are copied on the way in and on the way out, so
    instances never share or expose mutable state.
    """

    def __init__(self, value: T = None, errors: List[str] = None):
        self._value = value
        # Copy: otherwise a failure derived from another one (or from a
        # caller's list) would alias the same list object.
        self._errors = list(errors) if errors else []

    @staticmethod
    def success(value: T) -> 'Validation[T]':
        """Create a successful validation holding ``value``."""
        return Validation(value=value)

    @staticmethod
    def pure(value: T) -> 'Validation[T]':
        """Lift a plain value into a success (applicative ``pure``)."""
        return Validation.success(value)

    @staticmethod
    def failure(error: str) -> 'Validation[T]':
        """Create a failed validation with a single error message."""
        return Validation(errors=[error])

    @staticmethod
    def failures(errors: List[str]) -> 'Validation[T]':
        """Create a failed validation from a list of error messages."""
        return Validation(errors=errors)

    def is_success(self) -> bool:
        """Return True when no errors have been recorded."""
        return not self._errors

    def map(self, func: Callable[[T], R]) -> 'Validation[R]':
        """Apply ``func`` to a success value; failures keep their errors."""
        if self.is_success():
            return Validation.success(func(self._value))
        return Validation.failures(self._errors)

    def ap(self, other: 'Validation[Callable[[T], R]]') -> 'Validation[R]':
        """Apply the function wrapped in ``other`` to this instance's value.

        Direction is ``value.ap(wrapped_function)``. If either side failed,
        the result is a failure carrying this instance's errors followed by
        those of ``other``.
        """
        if self.is_success() and other.is_success():
            return Validation.success(other._value(self._value))

        all_errors = self._errors + other._errors
        return Validation.failures(all_errors)

    def get_or_else(self, default: T) -> T:
        """Return the value on success, otherwise ``default``."""
        return self._value if self.is_success() else default

    def get_errors(self) -> List[str]:
        """Return a copy of the accumulated error messages."""
        return list(self._errors)

    def __repr__(self) -> str:
        if self.is_success():
            return f"Success({self._value})"
        return f"Failure({self._errors})"


def validate_email(email: str) -> Validation[str]:
    """Accept ``email`` when it contains an '@'; otherwise report a format error."""
    if '@' in email:
        return Validation.success(email)
    return Validation.failure("邮箱格式无效")

def validate_age(age: int) -> Validation[int]:
    """Accept ages from 0 to 150 inclusive; otherwise report why it was rejected."""
    if age < 0:
        problem = "年龄不能为负数"
    elif age > 150:
        problem = "年龄不合理"
    else:
        return Validation.success(age)
    return Validation.failure(problem)

def validate_name(name: str) -> Validation[str]:
    """Accept names of at least two characters; empty or shorter ones fail."""
    long_enough = bool(name) and len(name) >= 2
    if long_enough:
        return Validation.success(name)
    return Validation.failure("姓名长度不足")

def create_user(name: str, email: str, age: int) -> dict:
    """Bundle the three fields into a user dict with keys name, email, age."""
    user = {}
    user['name'] = name
    user['email'] = email
    user['age'] = age
    return user

name_v = validate_name("张三")
email_v = validate_email("test@example.com")
age_v = validate_age(25)

# `value.ap(wrapped_function)` applies the wrapped function to the value and
# feeds it exactly one argument, so the constructor is curried and every
# validation is applied to the partially applied result of the previous one.
curried_create_user = Validation.success(
    lambda n: lambda e: lambda a: create_user(n, e, a)
)

user = age_v.ap(email_v.ap(name_v.ap(curried_create_user)))

print(f"验证成功: {user}")

name_v2 = validate_name("A")
email_v2 = validate_email("invalid-email")
age_v2 = validate_age(-5)

# All three inputs are invalid: every error is collected, outermost step first.
user2 = age_v2.ap(email_v2.ap(name_v2.ap(curried_create_user)))

print(f"验证失败: {user2}")

24.5.3 单子(Monad)

python
from typing import TypeVar, Generic, Callable, List, Any
from dataclasses import dataclass
from abc import ABC, abstractmethod

T = TypeVar('T')
R = TypeVar('R')


class Monad(ABC, Generic[T]):
    """Abstract base class for monads.

    Implementations are expected to obey the monad laws:
    1. Left identity:  return a >>= f = f a
    2. Right identity: m >>= return = m
    3. Associativity:  (m >>= f) >>= g = m >>= (\\x -> f x >>= g)

    ``m >> f`` is shorthand for ``m.bind(f)``.
    """

    @staticmethod
    @abstractmethod
    def return_(value: T) -> 'Monad[T]':
        """Wrap a plain value in the monad."""
        pass

    @abstractmethod
    def bind(self, func: Callable[[T], 'Monad[R]']) -> 'Monad[R]':
        """Feed the wrapped value to ``func``, which returns a new monad."""
        pass

    def __rshift__(self, func: Callable[[T], 'Monad[R]']) -> 'Monad[R]':
        bound = self.bind(func)
        return bound


class MaybeMonad(Monad[T]):
    """Maybe monad: chains steps that may produce no value.

    Absence is encoded as a wrapped ``None``.
    """

    def __init__(self, value: T = None):
        self._value = value

    @staticmethod
    def return_(value: T) -> 'MaybeMonad[T]':
        """Wrap ``value``."""
        return MaybeMonad(value)

    @staticmethod
    def nothing() -> 'MaybeMonad[T]':
        """Return the empty MaybeMonad."""
        return MaybeMonad(None)

    def is_just(self) -> bool:
        """Return True when a value is present."""
        return self._value is not None

    def bind(self, func: Callable[[T], 'Monad[R]']) -> 'MaybeMonad[R]':
        """Pass the value to ``func`` if present; an empty monad stays empty."""
        if not self.is_just():
            return MaybeMonad.nothing()
        return func(self._value)

    def get_or_else(self, default: T) -> T:
        """Return the wrapped value, or ``default`` when empty."""
        if self.is_just():
            return self._value
        return default

    def __repr__(self) -> str:
        return f"Just({self._value})" if self.is_just() else "Nothing"


class EitherMonad(Monad[T]):
    """Either monad: chains steps that may fail with an error value.

    A Right carries a result and keeps the chain going; a Left carries an
    error and short-circuits every later ``bind``. Which side an instance
    represents is recorded in an explicit flag, so ``right(None)`` and
    ``return_(None)`` are still Rights. Direct construction keeps the
    earlier rule: the instance is a Right when ``right`` is not None.
    """

    def __init__(self, left=None, right=None):
        self._left = left
        self._right = right
        # Tag inferred for direct construction; the factories override it so
        # that a None payload does not turn a Right into a Left.
        self._is_right = right is not None

    @staticmethod
    def return_(value: T) -> 'EitherMonad[T]':
        """Wrap ``value`` as a Right."""
        return EitherMonad.right(value)

    @staticmethod
    def left(value: Any) -> 'EitherMonad[T]':
        """Create a Left (failure) holding ``value``."""
        either = EitherMonad(left=value)
        either._is_right = False
        return either

    @staticmethod
    def right(value: T) -> 'EitherMonad[T]':
        """Create a Right (success) holding ``value``."""
        either = EitherMonad(right=value)
        either._is_right = True
        return either

    def is_right(self) -> bool:
        """Return True when this instance is a Right."""
        return self._is_right

    def bind(self, func: Callable[[T], 'Monad[R]']) -> 'EitherMonad[R]':
        """Pass a Right's value to ``func``; a Left is propagated unchanged."""
        if self.is_right():
            return func(self._right)
        return EitherMonad.left(self._left)

    def get_or_else(self, default: T) -> T:
        """Return the Right value, or ``default`` for a Left."""
        return self._right if self.is_right() else default

    def __repr__(self) -> str:
        if self.is_right():
            return f"Right({self._right})"
        return f"Left({self._left})"


class ListMonad(Monad[T]):
    """List monad: ``bind`` maps every element to a list and concatenates."""

    def __init__(self, values: List[T]):
        # Accepts any iterable and stores its elements in a fresh list.
        self._values = list(values)

    @staticmethod
    def return_(value: T) -> 'ListMonad[T]':
        """Wrap ``value`` in a one-element ListMonad."""
        return ListMonad([value])

    def bind(self, func: Callable[[T], 'Monad[R]']) -> 'ListMonad[R]':
        """Apply ``func`` to each element and flatten the resulting ListMonads.

        Results that are not ListMonad instances are ignored.
        """
        flattened = []
        for value in self._values:
            produced = func(value)
            if not isinstance(produced, ListMonad):
                continue
            flattened += produced._values
        return ListMonad(flattened)

    def __iter__(self):
        return iter(self._values)

    def __repr__(self) -> str:
        return f"ListMonad({self._values})"


def safe_divide_m(a: float, b: float) -> MaybeMonad[float]:
    """Divide ``a`` by ``b``; return the empty monad instead of dividing by zero."""
    return MaybeMonad.nothing() if b == 0 else MaybeMonad.return_(a / b)

def safe_sqrt(x: float) -> MaybeMonad[float]:
    """Return the square root of ``x``, or the empty monad for negative input."""
    import math
    if not x < 0:
        return MaybeMonad.return_(math.sqrt(x))
    return MaybeMonad.nothing()

# Both steps succeed, so the chain ends with a wrapped value.
result = (MaybeMonad.return_(16)
    .bind(lambda x: safe_divide_m(x, 4))
    .bind(lambda x: safe_sqrt(x)))

print(f"单子链式计算: {result}")

# Division by zero yields the empty monad; the following bind is skipped.
result2 = (MaybeMonad.return_(16)
    .bind(lambda x: safe_divide_m(x, 0))
    .bind(lambda x: safe_sqrt(x)))

print(f"单子链式计算(失败): {result2}")


def pythagorean_triples(n: int) -> ListMonad[tuple]:
    """Enumerate Pythagorean triples (x, y, z) with 1 <= x <= y <= z <= n.

    Built from three nested ``bind`` calls on the list monad: candidates
    that fail ``x*x + y*y == z*z`` contribute an empty list and vanish.
    """
    upper = n + 1

    def choose_y(x):
        def choose_z(y):
            def keep_if_triple(z):
                if x*x + y*y == z*z:
                    return ListMonad.return_((x, y, z))
                return ListMonad([])
            return ListMonad(range(y, upper)).bind(keep_if_triple)
        return ListMonad(range(x, upper)).bind(choose_z)

    return ListMonad(range(1, upper)).bind(choose_y)

# Compute all triples with sides up to 20 and show the first five.
triples = pythagorean_triples(20)
print(f"毕达哥拉斯三元组: {list(triples)[:5]}")

24.6 惰性求值与无限数据结构

24.6.1 惰性求值

python
from typing import Callable, TypeVar, Generic, Iterator, Any
from functools import wraps

T = TypeVar('T')
R = TypeVar('R')


class Lazy(Generic[T]):
    """Deferred value: the supplier runs on first ``evaluate()`` only.

    The result is cached, so later ``evaluate()`` calls return it without
    calling the supplier again.
    """

    def __init__(self, func: Callable[[], T]):
        self._func = func
        self._value: T = None
        self._evaluated = False

    def evaluate(self) -> T:
        """Return the value, computing and caching it on first use."""
        if self._evaluated:
            return self._value
        self._value = self._func()
        self._evaluated = True
        return self._value

    def map(self, func: Callable[[T], R]) -> 'Lazy[R]':
        """Return a new Lazy that applies ``func`` to this value when evaluated."""
        def deferred():
            return func(self.evaluate())
        return Lazy(deferred)

    def flat_map(self, func: Callable[[T], 'Lazy[R]']) -> 'Lazy[R]':
        """Return a new Lazy that evaluates the Lazy produced by ``func``."""
        def deferred():
            inner = func(self.evaluate())
            return inner.evaluate()
        return Lazy(deferred)

    def __repr__(self) -> str:
        return f"Lazy({self._value})" if self._evaluated else "Lazy(<unevaluated>)"


def lazy_fibonacci() -> Iterator[int]:
    """Yield the Fibonacci numbers 0, 1, 1, 2, 3, ... without end."""
    current, upcoming = 0, 1
    while True:
        yield current
        # Slide the window forward by one position.
        current, upcoming = upcoming, current + upcoming


def lazy_primes() -> Iterator[int]:
    """Yield the prime numbers 2, 3, 5, 7, ... without end (trial division)."""
    def is_prime(n: int) -> bool:
        if n < 2:
            return False
        # A composite n has a divisor no larger than its square root.
        limit = int(n ** 0.5) + 1
        return all(n % divisor != 0 for divisor in range(2, limit))

    candidate = 2
    while True:
        if is_prime(candidate):
            yield candidate
        candidate += 1


def take(n: int, iterator: Iterator[T]) -> list:
    """Return up to the first ``n`` elements of ``iterator`` as a list.

    If fewer than ``n`` elements are available, the shorter list is
    returned instead of leaking ``StopIteration`` to the caller. A
    non-positive ``n`` yields an empty list.
    """
    from itertools import islice

    # islice rejects negative stops, while the old range(n) loop treated
    # them as "take nothing"; clamp to keep that behaviour.
    return list(islice(iterator, max(n, 0)))


def take_while(predicate: Callable[[T], bool], iterator: Iterator[T]) -> list:
    """Collect leading elements until the first one that fails ``predicate``.

    The first failing element is consumed from ``iterator`` but not returned.
    """
    collected = []
    for x in iterator:
        if not predicate(x):
            break
        collected.append(x)
    return collected


def drop_while(predicate: Callable[[T], bool], iterator: Iterator[T]) -> Iterator[T]:
    """Skip leading elements satisfying ``predicate`` and yield the rest.

    Yields the first element that fails ``predicate`` and everything after
    it. Any iterable is accepted, not only one-shot iterators.
    """
    # Normalise to a single iterator: for a re-iterable such as a list,
    # `yield from` would otherwise start again from the first element.
    remaining = iter(iterator)
    for x in remaining:
        if not predicate(x):
            yield x
            break
    yield from remaining


class LazySequence:
    """Re-iterable lazy sequence built from an iterator factory.

    Each terminal operation asks the factory for a fresh iterator, so the
    sequence can be consumed more than once. ``map`` and ``filter`` only
    stack up transformations; nothing runs until a terminal call.
    """

    def __init__(self, iterator_factory: Callable[[], Iterator[T]]):
        self._iterator_factory = iterator_factory

    def map(self, func: Callable[[T], R]) -> 'LazySequence[R]':
        """Return a sequence that applies ``func`` to each element on demand."""
        def mapped_source():
            return map(func, self._iterator_factory())
        return LazySequence(mapped_source)

    def filter(self, predicate: Callable[[T], bool]) -> 'LazySequence[T]':
        """Return a sequence that keeps only elements passing ``predicate``."""
        def filtered_source():
            return filter(predicate, self._iterator_factory())
        return LazySequence(filtered_source)

    def take(self, n: int) -> list:
        """Return the first ``n`` elements via the module-level ``take``."""
        fresh = self._iterator_factory()
        return take(n, fresh)

    def take_while(self, predicate: Callable[[T], bool]) -> list:
        """Return the leading elements that satisfy ``predicate``."""
        fresh = self._iterator_factory()
        return take_while(predicate, fresh)

    def to_list(self) -> list:
        """Materialise the whole sequence as a list."""
        return list(self._iterator_factory())

# Only the requested number of elements is pulled from each infinite generator.
fib = lazy_fibonacci()
print(f"前10个斐波那契数: {take(10, fib)}")

primes = lazy_primes()
print(f"前10个质数: {take(10, primes)}")

# The source covers a million numbers, but take(10) stops pulling after ten results.
lazy_nums = LazySequence(lambda: iter(range(1000000)))
result = (lazy_nums
    .filter(lambda x: x % 2 == 0)
    .map(lambda x: x ** 2)
    .take(10))
print(f"惰性序列处理: {result}")

24.6.2 无限数据结构

python
from typing import Callable, TypeVar, Generic, Iterator, Any
from dataclasses import dataclass
from functools import wraps

T = TypeVar('T')


class Stream(Generic[T]):
    """Lazy, possibly infinite, singly linked sequence.

    A stream cell stores its head eagerly and a zero-argument factory for
    its tail. The tail is built on first access and then cached. A cell
    without a factory (or whose factory returns ``None``) ends the stream.
    """

    def __init__(self, head: T, tail_factory: Callable[[], 'Stream[T]'] = None):
        self._head = head
        self._tail_factory = tail_factory
        self._tail: 'Stream[T]' = None
        self._tail_evaluated = False

    @property
    def head(self) -> T:
        """First element of the stream."""
        return self._head

    @property
    def tail(self) -> 'Stream[T]':
        """Rest of the stream, or ``None`` at the end; computed once."""
        if not self._tail_evaluated and self._tail_factory:
            self._tail = self._tail_factory()
            self._tail_evaluated = True
        return self._tail

    @staticmethod
    def iterate(func: Callable[[T], T], initial: T) -> 'Stream[T]':
        """Build the infinite stream ``initial, func(initial), func(func(initial)), ...``."""
        return Stream(initial, lambda: Stream.iterate(func, func(initial)))

    @staticmethod
    def repeat(value: T) -> 'Stream[T]':
        """Build an infinite stream that yields ``value`` forever."""
        return Stream(value, lambda: Stream.repeat(value))

    @staticmethod
    def cycle(values: list) -> 'Stream[T]':
        """Build an infinite stream cycling through ``values``.

        Raises:
            IndexError: If ``values`` is empty.
        """
        def make_stream(index: int) -> Stream[T]:
            return Stream(values[index], lambda: make_stream((index + 1) % len(values)))
        return make_stream(0)

    def map(self, func: Callable[[T], Any]) -> 'Stream[Any]':
        """Return a stream with ``func`` applied lazily to every element."""
        def mapped_tail():
            rest = self.tail
            return rest.map(func) if rest is not None else None
        return Stream(func(self._head), mapped_tail)

    def filter(self, predicate: Callable[[T], bool]) -> 'Stream[T]':
        """Return the stream of elements accepted by ``predicate``.

        Returns ``None`` when no element of a finite stream matches. On an
        infinite stream with no further match, the search does not terminate.
        """
        # Advance with a loop, not recursion: a long run of rejected elements
        # would otherwise use one stack frame each and overflow the stack.
        current = self
        while current is not None and not predicate(current.head):
            current = current.tail
        if current is None:
            return None

        def filtered_tail():
            rest = current.tail
            return rest.filter(predicate) if rest is not None else None
        return Stream(current.head, filtered_tail)

    def take(self, n: int) -> list:
        """Return up to the first ``n`` elements as a list."""
        result = []
        current = self
        while current is not None and len(result) < n:
            result.append(current.head)
            # Only force the tail when another element is needed; forcing it
            # past the last one could start an unbounded filter search.
            if len(result) < n:
                current = current.tail
        return result

    def zip_with(self, other: 'Stream[T]', func: Callable[[T, T], Any]) -> 'Stream[Any]':
        """Combine two streams element-wise with ``func``; ends with the shorter one."""
        def zipped_tail():
            mine, theirs = self.tail, other.tail
            if mine is None or theirs is None:
                return None
            return mine.zip_with(theirs, func)
        return Stream(func(self.head, other.head), zipped_tail)

    def __repr__(self) -> str:
        return f"Stream({self._head}, ...)"


# Demo: build several infinite streams and print a finite prefix of each.
natural_numbers = Stream.iterate(lambda x: x + 1, 0)
first_naturals = natural_numbers.take(10)
print(f"自然数前10个: {first_naturals}")

powers_of_2 = Stream.iterate(lambda x: x * 2, 1)
first_powers = powers_of_2.take(10)
print(f"2的幂前10个: {first_powers}")

repeated = Stream.repeat("Hello")
first_repeats = repeated.take(5)
print(f"重复值前5个: {first_repeats}")

cycled = Stream.cycle([1, 2, 3])
first_cycled = cycled.take(10)
print(f"循环值前10个: {first_cycled}")

# Derived streams share the cells already computed for their sources.
evens = natural_numbers.filter(lambda x: x % 2 == 0)
first_evens = evens.take(10)
print(f"偶数前10个: {first_evens}")

sums = natural_numbers.zip_with(powers_of_2, lambda a, b: a + b)
first_sums = sums.take(10)
print(f"自然数+2的幂 前10个: {first_sums}")

24.7 记忆化与性能优化

24.7.1 记忆化

python
from functools import lru_cache, wraps
from typing import Callable, Dict, Any, Tuple
import hashlib
import json
import time


def memoize(func: Callable) -> Callable:
    """Decorate ``func`` with an unbounded result cache.

    Results are keyed by the positional arguments plus the keyword
    arguments sorted by name, so every argument must be hashable. The
    wrapper exposes ``cache`` (the dict), ``cache_clear()`` and
    ``cache_info()``, which returns ``{'size': <entry count>}``.
    """
    cache: Dict[Tuple, Any] = {}

    def cache_info():
        return {'size': len(cache)}

    @wraps(func)
    def wrapper(*args, **kwargs):
        key = (args, tuple(sorted(kwargs.items())))
        if key in cache:
            return cache[key]
        value = func(*args, **kwargs)
        cache[key] = value
        return value

    wrapper.cache = cache
    wrapper.cache_clear = cache.clear
    wrapper.cache_info = cache_info
    return wrapper


def memoize_with_ttl(ttl_seconds: float):
    """Build a memoizing decorator whose entries expire after ``ttl_seconds``.

    Args:
        ttl_seconds: Maximum age, in seconds, at which a cached result is
            still returned. An entry whose age is equal or greater is
            recomputed and replaced.

    Returns:
        A decorator. The decorated function exposes ``cache``, a dict that
        maps the argument key to ``(timestamp, value)``.
    """
    def decorator(func: Callable) -> Callable:
        cache: Dict[Tuple, Tuple[float, Any]] = {}

        @wraps(func)
        def wrapper(*args, **kwargs):
            key = (args, tuple(sorted(kwargs.items())))
            # Monotonic clock: entry age must not be affected by wall-clock
            # adjustments (NTP corrections, manual changes), which could make
            # entries expire early or never.
            current_time = time.monotonic()

            entry = cache.get(key)
            if entry is not None:
                cached_time, cached_value = entry
                if current_time - cached_time < ttl_seconds:
                    return cached_value

            result = func(*args, **kwargs)
            cache[key] = (current_time, result)
            return result

        wrapper.cache = cache
        return wrapper
    return decorator


def memoize_method(func: Callable) -> Callable:
    """Memoize an instance method with one cache per instance.

    The cache dict is stored on the instance itself under the attribute
    ``_memoized_<method name>`` and is created on the first call. Keys are
    built from the positional arguments and the name-sorted keyword
    arguments, so those must be hashable.
    """
    cache_attr = f'_memoized_{func.__name__}'

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            cache = getattr(self, cache_attr)
        except AttributeError:
            # First call on this instance: attach a fresh cache.
            setattr(self, cache_attr, {})
            cache = getattr(self, cache_attr)

        key = (args, tuple(sorted(kwargs.items())))
        if key in cache:
            return cache[key]
        value = func(self, *args, **kwargs)
        cache[key] = value
        return value

    return wrapper


@memoize
def fibonacci(n: int) -> int:
    """Return the n-th Fibonacci number (``n <= 1`` is returned unchanged).

    The naive double recursion is made practical by ``memoize``, which
    caches each sub-result.
    """
    return n if n <= 1 else fibonacci(n - 1) + fibonacci(n - 2)


@lru_cache(maxsize=128)
def expensive_computation(n: int) -> int:
    """Return ``n`` squared after a 10 ms pause standing in for real work.

    Cached by ``lru_cache`` (128 entries), so the pause is only paid on a
    cache miss.
    """
    time.sleep(0.01)  # simulated cost
    squared = n ** 2
    return squared


class Calculator:
    """Raises numbers to a fixed power, caching results per instance."""

    def __init__(self, factor: int):
        # Exponent applied by ``compute``.
        self.factor = factor

    @memoize_method
    def compute(self, x: int) -> int:
        """Return ``x ** factor`` after a 10 ms pause simulating real work."""
        time.sleep(0.01)
        exponent = self.factor
        return x ** exponent


# Demo: memoized recursion, lru_cache statistics and per-instance caching.
fib_50 = fibonacci(50)
print(f"斐波那契(50): {fib_50}")

start = time.time()
for i in range(100):
    expensive_computation(5)
elapsed = time.time() - start
print(f"带缓存计算100次耗时: {elapsed:.3f}秒")
cache_stats = expensive_computation.cache_info()
print(f"缓存信息: {cache_stats}")

calc = Calculator(3)
start = time.time()
for _ in range(10):
    calc.compute(10)
elapsed = time.time() - start
print(f"方法记忆化计算10次耗时: {elapsed:.3f}秒")

24.7.2 尾递归优化

python
from typing import Callable, TypeVar, Any
from functools import wraps

T = TypeVar('T')


def tail_recursive(func: Callable) -> Callable:
    """Run ``func`` on a trampoline so tail calls do not grow the stack.

    The decorated function should return ``<wrapper>.tail_call(*args,
    **kwargs)`` instead of calling itself. The wrapper keeps invoking the
    undecorated function in a loop until it returns something that is not
    a pending tail call.
    """
    class TailCall:
        """Record of a call still to be made by the trampoline loop."""

        def __init__(self, func, *args, **kwargs):
            self.func, self.args, self.kwargs = func, args, kwargs

    def tail_call(*args, **kwargs):
        return TailCall(func, *args, **kwargs)

    @wraps(func)
    def wrapper(*args, **kwargs):
        outcome = func(*args, **kwargs)
        while True:
            if not isinstance(outcome, TailCall):
                return outcome
            outcome = outcome.func(*outcome.args, **outcome.kwargs)

    wrapper.tail_call = tail_call
    return wrapper


def factorial_tr(n: int, acc: int = 1) -> int:
    """Factorial written in tail-recursive (accumulator-passing) form.

    ``acc`` carries the running product. The recursion is a plain Python
    call, so the depth is still bounded by the interpreter's recursion limit.
    """
    return acc if n <= 1 else factorial_tr(n - 1, n * acc)


def factorial_iterative(n: int) -> int:
    """Factorial computed with a loop; returns 1 for every ``n`` below 2."""
    product = 1
    # Multiply from n down to 2; integer multiplication is order-independent.
    for factor in range(n, 1, -1):
        product *= factor
    return product


def fibonacci_tr(n: int, a: int = 0, b: int = 1) -> int:
    """Fibonacci in tail-recursive form with two sliding accumulators.

    ``a`` and ``b`` hold consecutive sequence values. The recursion is a
    plain Python call, so the depth is bounded by the recursion limit.
    """
    if n == 0:
        return a
    return b if n == 1 else fibonacci_tr(n - 1, b, a + b)


def fibonacci_iterative(n: int) -> int:
    """Fibonacci computed with a loop over a sliding pair of values."""
    if n == 0:
        return 0
    previous, current = 0, 1
    # n - 1 advances take the pair from (F0, F1) to (F(n-1), F(n)).
    for _ in range(n - 1):
        previous, current = current, previous + current
    return current


# Demo: the recursive and iterative variants agree on the same inputs.
factorial_recursive_10 = factorial_tr(10)
print(f"尾递归阶乘(10): {factorial_recursive_10}")
factorial_loop_10 = factorial_iterative(10)
print(f"迭代阶乘(10): {factorial_loop_10}")
fibonacci_recursive_50 = fibonacci_tr(50)
print(f"尾递归斐波那契(50): {fibonacci_recursive_50}")
fibonacci_loop_50 = fibonacci_iterative(50)
print(f"迭代斐波那契(50): {fibonacci_loop_50}")

24.8 模式匹配与代数数据类型

24.8.1 代数数据类型

python
from dataclasses import dataclass
from typing import Any, Union, List, Callable, TypeVar, Generic
from abc import ABC, abstractmethod
from enum import Enum, auto

T = TypeVar('T')
R = TypeVar('R')


class Shape(ABC):
    """Abstract base of the shape ADT; each variant supplies its own geometry."""

    @abstractmethod
    def area(self) -> float:
        """Return the enclosed area."""

    @abstractmethod
    def perimeter(self) -> float:
        """Return the length of the boundary."""


@dataclass(frozen=True)
class Circle(Shape):
    """Immutable circle variant of ``Shape``."""

    radius: float

    def area(self) -> float:
        """Return pi * radius squared."""
        from math import pi
        return pi * self.radius ** 2

    def perimeter(self) -> float:
        """Return the circumference, 2 * pi * radius."""
        from math import pi
        return 2 * pi * self.radius


@dataclass(frozen=True)
class Rectangle(Shape):
    """Immutable axis-aligned rectangle variant of ``Shape``."""

    width: float
    height: float

    def area(self) -> float:
        """Return width times height."""
        w, h = self.width, self.height
        return w * h

    def perimeter(self) -> float:
        """Return twice the sum of width and height."""
        w, h = self.width, self.height
        return 2 * (w + h)


@dataclass(frozen=True)
class Triangle(Shape):
    """Immutable triangle variant of ``Shape``, described by its side lengths."""

    a: float
    b: float
    c: float

    def area(self) -> float:
        """Return the area using Heron's formula.

        ``math.sqrt`` raises ``ValueError`` when the product under the root
        is negative.
        """
        from math import sqrt
        a, b, c = self.a, self.b, self.c
        s = (a + b + c) / 2  # semi-perimeter
        return sqrt(s * (s - a) * (s - b) * (s - c))

    def perimeter(self) -> float:
        """Return the sum of the three sides."""
        a, b, c = self.a, self.b, self.c
        return a + b + c


class Option(ABC, Generic[T]):
    """Abstract base of the Option ADT: a value that may be absent."""

    @abstractmethod
    def is_some(self) -> bool:
        """Return True when a value is present."""

    @abstractmethod
    def is_none(self) -> bool:
        """Return True when no value is present."""

    @abstractmethod
    def get_or_else(self, default: T) -> T:
        """Return the contained value, or ``default`` when there is none."""


@dataclass(frozen=True)
class Some(Option[T]):
    """``Option`` variant that holds a value."""

    value: T

    def is_some(self) -> bool:
        """Always True: a value is present."""
        return True

    def is_none(self) -> bool:
        """Always False: a value is present."""
        return False

    def get_or_else(self, default: T) -> T:
        """Return the wrapped value; ``default`` is ignored."""
        held = self.value
        return held


@dataclass(frozen=True)
class None_(Option[T]):
    """``Option`` variant that holds nothing (named to avoid the keyword)."""

    def is_some(self) -> bool:
        """Always False: there is no value."""
        return False

    def is_none(self) -> bool:
        """Always True: there is no value."""
        return True

    def get_or_else(self, default: T) -> T:
        """Return ``default``, since nothing is stored."""
        fallback = default
        return fallback


# Type parameter for the failure side of Result. It is defined here because
# this snippet's preamble only declares T and R; without it the class
# statement below fails with NameError.
E = TypeVar('E')


class Result(ABC, Generic[T, E]):
    """Abstract base of the Result ADT: a success of ``T`` or a failure of ``E``."""

    @abstractmethod
    def is_ok(self) -> bool:
        """Return True for the success variant."""
        pass

    @abstractmethod
    def is_error(self) -> bool:
        """Return True for the failure variant."""
        pass


@dataclass(frozen=True)
class Ok(Result[T, E]):
    """Success variant of ``Result``; carries the produced value."""

    value: T

    def is_ok(self) -> bool:
        """Always True for a success."""
        return True

    def is_error(self) -> bool:
        """Always False for a success."""
        return False


@dataclass(frozen=True)
class Error(Result[T, E]):
    """Failure variant of ``Result``; carries the error payload."""

    error: E

    def is_ok(self) -> bool:
        """Always False for a failure."""
        return False

    def is_error(self) -> bool:
        """Always True for a failure."""
        return True


# Demo: every variant is handled through the common Shape interface.
shapes: List[Shape] = [Circle(5), Rectangle(4, 6), Triangle(3, 4, 5)]

for shape in shapes:
    variant = type(shape).__name__
    print(f"{variant}: 面积={shape.area():.2f}, 周长={shape.perimeter():.2f}")

24.8.2 模式匹配

python
from dataclasses import dataclass
from typing import Any, Union, List, Callable, TypeVar, Generic, Type
from abc import ABC

T = TypeVar('T')
R = TypeVar('R')


def match(value: Any, patterns: dict) -> Any:
    """Dispatch ``value`` to the handler of the first pattern it matches.

    ``patterns`` maps pattern -> handler and is scanned in insertion order.
    A pattern may be a class (``isinstance`` test), any other callable
    (used as a predicate), or a plain value (compared with ``==``). When
    nothing matches, the handler stored under the ``None`` key, if any, is
    used as the fallback.

    Raises:
        ValueError: If no pattern matches and there is no ``None`` entry.
    """
    for pattern, handler in patterns.items():
        if isinstance(pattern, type):
            hit = isinstance(value, pattern)
        elif callable(pattern):
            hit = pattern(value)
        else:
            hit = value == pattern
        if hit:
            return handler(value)

    if None not in patterns:
        raise ValueError(f"未匹配的值: {value}")
    fallback = patterns[None]
    return fallback(value)


class PatternMatcher:
    """Fluent, first-match-wins matcher over a single value.

    Chain ``case(pattern, handler)`` calls, optionally finish with
    ``default(handler)``, then read the outcome with ``result()``.
    Supported patterns: ``typing.Any`` (wildcard), classes, predicates,
    dicts (per-key sub-patterns), lists/tuples (positional sub-patterns)
    and plain values compared with ``==``.
    """

    def __init__(self, value: Any):
        self._value = value
        self._matched = False
        self._result = None

    def case(self, pattern: Any, handler: Callable) -> 'PatternMatcher':
        """Run ``handler`` if nothing matched yet and ``pattern`` matches."""
        if not self._matched and self._matches(pattern):
            self._result = handler(self._value)
            self._matched = True
        return self

    def _matches(self, pattern: Any) -> bool:
        """Test the held value against one pattern."""
        subject = self._value
        if pattern is Any:
            return True
        if isinstance(pattern, type):
            return isinstance(subject, pattern)
        if callable(pattern):
            return pattern(subject)
        if isinstance(pattern, dict):
            return self._match_dict(pattern)
        if isinstance(pattern, (list, tuple)):
            return self._match_sequence(pattern)
        return subject == pattern

    def _match_dict(self, pattern: dict) -> bool:
        """Match a dict value: each pattern key must exist and match its sub-pattern."""
        subject = self._value
        if not isinstance(subject, dict):
            return False
        return all(
            key in subject and PatternMatcher(subject[key])._matches(sub_pattern)
            for key, sub_pattern in pattern.items()
        )

    def _match_sequence(self, pattern: Union[list, tuple]) -> bool:
        """Match a list/tuple value of equal length, position by position."""
        subject = self._value
        if not isinstance(subject, (list, tuple)):
            return False
        if len(pattern) != len(subject):
            return False
        return all(
            PatternMatcher(item)._matches(sub_pattern)
            for sub_pattern, item in zip(pattern, subject)
        )

    def default(self, handler: Callable) -> 'PatternMatcher':
        """Run ``handler`` as the fallback when no earlier case matched."""
        if self._matched:
            return self
        self._result = handler(self._value)
        self._matched = True
        return self

    def result(self) -> Any:
        """Return the matching handler's result.

        Raises:
            ValueError: If no case matched and no default was supplied.
        """
        if self._matched:
            return self._result
        raise ValueError("模式匹配失败")


def describe(value: Any) -> str:
    """Return a short description of ``value``.

    Cases are tried in order: literal zero, positive, negative, list,
    dict, str, then a fallback that reports the type name.
    """
    def guarded(predicate: Callable) -> Callable:
        # The numeric predicates are tried before the type cases, so they
        # also see lists, dicts and strings. Comparing those with 0 raises
        # TypeError; treat that as "no match" so later cases can run.
        def safe(x):
            try:
                return predicate(x)
            except TypeError:
                return False
        return safe

    return (PatternMatcher(value)
        .case(0, lambda _: "零")
        .case(guarded(lambda x: x > 0), lambda x: f"正数: {x}")
        .case(guarded(lambda x: x < 0), lambda x: f"负数: {x}")
        .case(list, lambda x: f"列表: {x}")
        .case(dict, lambda x: f"字典: {x}")
        .case(str, lambda x: f"字符串: {x}")
        .default(lambda x: f"未知类型: {type(x).__name__}")
        .result())


# Demo: one sample per branch of ``describe``, printed in the same order.
for sample in (0, 42, -10, [1, 2, 3], {'a': 1}, "hello"):
    print(describe(sample))

24.9 函数式设计模式应用

24.9.1 函数式错误处理

python
from typing import Callable, TypeVar, Generic, Optional, List, Any
from dataclasses import dataclass
from functools import wraps

T = TypeVar('T')
E = TypeVar('E')
R = TypeVar('R')


@dataclass(frozen=True)
class Result(Generic[T, E]):
    """Immutable result of an operation: a success value or an error.

    Build instances with ``Result.ok(value)`` or ``Result.error(error)``.
    Which side is populated is tracked by an explicit flag, so ``None`` is
    a legal success value and a legal error payload.
    """
    _value: Optional[T] = None
    _error: Optional[E] = None
    # True for success, False for failure. None means "not specified" and is
    # resolved in __post_init__ for instances built directly by field.
    _ok: Optional[bool] = None

    def __post_init__(self):
        if self._ok is None:
            # Direct construction: a supplied error marks a failure.
            # object.__setattr__ is required because the dataclass is frozen.
            object.__setattr__(self, '_ok', self._error is None)

    @staticmethod
    def ok(value: T) -> 'Result[T, E]':
        """Wrap a success value."""
        return Result(_value=value, _ok=True)

    @staticmethod
    def error(error: E) -> 'Result[T, E]':
        """Wrap an error payload."""
        return Result(_error=error, _ok=False)

    def is_ok(self) -> bool:
        """Return True when this is a success."""
        return bool(self._ok)

    def is_error(self) -> bool:
        """Return True when this is a failure."""
        return not self._ok

    def unwrap(self) -> T:
        """Return the success value.

        Raises:
            ValueError: If this is a failure; the message is ``str(error)``.
        """
        if self.is_error():
            raise ValueError(str(self._error))
        return self._value

    def unwrap_or(self, default: T) -> T:
        """Return the success value, or ``default`` for a failure."""
        return self._value if self.is_ok() else default

    def map(self, func: Callable[[T], R]) -> 'Result[R, E]':
        """Apply ``func`` to a success value; pass a failure through."""
        if self.is_ok():
            return Result.ok(func(self._value))
        return Result.error(self._error)

    def map_error(self, func: Callable[[E], R]) -> 'Result[T, R]':
        """Apply ``func`` to an error payload; pass a success through."""
        if self.is_error():
            return Result.error(func(self._error))
        return Result.ok(self._value)

    def and_then(self, func: Callable[[T], 'Result[R, E]']) -> 'Result[R, E]':
        """Chain a Result-returning step onto a success (monadic bind)."""
        if self.is_ok():
            return func(self._value)
        return Result.error(self._error)

    def or_else(self, func: Callable[[E], 'Result[T, R]']) -> 'Result[T, R]':
        """Recover from a failure with a Result-returning handler."""
        if self.is_error():
            return func(self._error)
        return Result.ok(self._value)

    def __repr__(self) -> str:
        if self.is_ok():
            return f"Ok({self._value})"
        return f"Error({self._error})"


def try_catch(func: Callable) -> Callable:
    """Turn a function that may raise into one that returns a ``Result``.

    The wrapper returns ``Result.ok(<return value>)`` on success and
    ``Result.error(str(exc))`` for any ``Exception``.
    """
    @wraps(func)
    def wrapper(*args, **kwargs) -> Result:
        try:
            outcome = func(*args, **kwargs)
            return Result.ok(outcome)
        except Exception as exc:
            message = str(exc)
            return Result.error(message)
    return wrapper


@try_catch
def parse_int(s: str) -> int:
    """Parse ``s`` with ``int``; a failure surfaces as an error ``Result``."""
    parsed = int(s)
    return parsed


@try_catch
def divide(a: float, b: float) -> float:
    """Divide ``a`` by ``b``; a failure surfaces as an error ``Result``."""
    quotient = a / b
    return quotient


def validate_user(data: dict) -> Result[dict, List[str]]:
    """Validate a user record, collecting every problem found.

    Checks: ``name`` is present and at least 2 characters long; ``email``
    is present and contains ``@``; ``age``, when given, is an ``int``
    between 0 and 150 inclusive.

    Returns:
        ``Result.ok(data)`` when all checks pass, otherwise
        ``Result.error`` holding the list of messages.
    """
    problems = []

    name = data.get('name')
    if not name:
        problems.append("姓名不能为空")
    elif len(name) < 2:
        problems.append("姓名长度不足")

    email = data.get('email')
    if not email:
        problems.append("邮箱不能为空")
    elif '@' not in email:
        problems.append("邮箱格式无效")

    age = data.get('age')
    if age is not None:
        if not isinstance(age, int):
            problems.append("年龄必须是整数")
        elif not 0 <= age <= 150:
            problems.append("年龄范围无效")

    return Result.error(problems) if problems else Result.ok(data)


# Demo: a successful chain, a chain that fails at its first step, and
# validation that accumulates several errors.
result = parse_int("42")
result = result.and_then(lambda x: divide(x, 2))
result = result.map(lambda x: x * 3)

print(f"成功链: {result}")

result2 = parse_int("abc")
result2 = result2.and_then(lambda x: divide(x, 2))
result2 = result2.map(lambda x: x * 3)

print(f"失败链: {result2}")

user_data = {'name': '张', 'email': 'invalid', 'age': -5}
validation = validate_user(user_data)
print(f"验证结果: {validation}")

24.9.2 函数式状态管理

python
from typing import Callable, TypeVar, Generic, Tuple, Any
from dataclasses import dataclass

S = TypeVar('S')
A = TypeVar('A')
B = TypeVar('B')


@dataclass(frozen=True)
class State(Generic[S, A]):
    """State monad: wraps a state transition ``s -> (a, s)``.

    ``run`` takes the current state and returns a ``(value, new_state)``
    pair. Combinators build larger transitions without running them.
    """
    run: Callable[[S], Tuple[A, S]]

    @staticmethod
    def return_(value: A) -> 'State[S, A]':
        """Produce ``value`` and leave the state untouched."""
        return State(lambda s: (value, s))

    @staticmethod
    def get() -> 'State[S, S]':
        """Produce the current state as the value."""
        return State(lambda s: (s, s))

    @staticmethod
    def put(new_state: S) -> 'State[S, None]':
        """Replace the state with ``new_state``; the value is ``None``."""
        return State(lambda s: (None, new_state))

    @staticmethod
    def modify(func: Callable[[S], S]) -> 'State[S, None]':
        """Replace the state with ``func(state)``; the value is ``None``."""
        return State(lambda s: (None, func(s)))

    def map(self, func: Callable[[A], B]) -> 'State[S, B]':
        """Transform the produced value with ``func``; the state passes through."""
        def run(s: S) -> Tuple[B, S]:
            # Run the wrapped transition once. Running it separately for the
            # value and for the state repeats its work and side effects, and
            # doubles the cost with every chained map.
            value, new_state = self.run(s)
            return func(value), new_state
        return State(run)

    def bind(self, func: Callable[[A], 'State[S, B]']) -> 'State[S, B]':
        """Sequence this transition with the one ``func`` builds from its value."""
        def run(s: S) -> Tuple[B, S]:
            value, new_state = self.run(s)
            return func(value).run(new_state)
        return State(run)

    def eval(self, initial_state: S) -> A:
        """Run from ``initial_state`` and return only the value."""
        return self.run(initial_state)[0]

    def exec(self, initial_state: S) -> S:
        """Run from ``initial_state`` and return only the final state."""
        return self.run(initial_state)[1]

    def run_state(self, initial_state: S) -> Tuple[A, S]:
        """Run from ``initial_state`` and return ``(value, final_state)``."""
        return self.run(initial_state)


class Counter:
    """Integer counter expressed as ``State`` transitions.

    Every operation updates the state and then yields the new count as
    the value.
    """

    @staticmethod
    def increment() -> State[int, int]:
        """Add one to the count."""
        bump = State.modify(lambda s: s + 1)
        return bump.bind(lambda _: State.get())

    @staticmethod
    def decrement() -> State[int, int]:
        """Subtract one from the count."""
        drop = State.modify(lambda s: s - 1)
        return drop.bind(lambda _: State.get())

    @staticmethod
    def add(n: int) -> State[int, int]:
        """Add ``n`` to the count."""
        shift = State.modify(lambda s: s + n)
        return shift.bind(lambda _: State.get())

    @staticmethod
    def reset() -> State[int, int]:
        """Set the count back to zero."""
        clear = State.put(0)
        return clear.bind(lambda _: State.get())


# Demo: compose counter steps first, then run the whole program from 0.
counter_ops = State.return_(None)
counter_ops = counter_ops.bind(lambda _: Counter.add(10))
counter_ops = counter_ops.bind(lambda _: Counter.increment())
counter_ops = counter_ops.bind(lambda _: Counter.increment())
counter_ops = counter_ops.bind(lambda _: Counter.decrement())

result, final_state = counter_ops.run_state(0)
print(f"计数器最终状态: {final_state}")


class Stack:
    """Stack expressed as ``State`` transitions over a list.

    The list is never mutated; each operation yields a new list.
    """

    @staticmethod
    def push(value: Any) -> State[list, None]:
        """Append ``value`` on top of the stack."""
        def do_push(stack: list) -> list:
            return stack + [value]
        return State.modify(do_push)

    @staticmethod
    def pop() -> State[list, Any]:
        """Remove and yield the top element.

        Running the transition on an empty stack raises ``ValueError``.
        """
        def do_pop(stack: list) -> Tuple[Any, list]:
            if stack:
                return stack[-1], stack[:-1]
            raise ValueError("栈为空")
        return State(do_pop)

    @staticmethod
    def peek() -> State[list, Any]:
        """Yield the top element (``None`` when empty) without removing it."""
        def do_peek(stack: list) -> Tuple[Any, list]:
            top = stack[-1] if stack else None
            return top, stack
        return State(do_peek)


# Demo: push three values, pop the top one and report it.
stack_ops = Stack.push(1)
stack_ops = stack_ops.bind(lambda _: Stack.push(2))
stack_ops = stack_ops.bind(lambda _: Stack.push(3))
stack_ops = stack_ops.bind(lambda _: Stack.pop())
stack_ops = stack_ops.bind(lambda top: State.return_(f"弹出: {top}"))

result, final_stack = stack_ops.run_state([])
print(f"栈操作结果: {result}, 最终栈: {final_stack}")

24.10 小结

函数式编程模式强调通过纯函数、不可变数据和函数组合来构建可靠、可测试的软件系统。

关键要点

  1. 纯函数:无副作用、引用透明,易于测试和推理
  2. 不可变性:避免状态共享问题,提高并发安全性
  3. 高阶函数:函数作为一等公民,支持灵活的抽象
  4. 函子与单子:提供强大的组合和错误处理能力
  5. 惰性求值:按需计算,支持无限数据结构
  6. 模式匹配:清晰的条件分支处理,提高代码可读性

实践建议

  1. 优先使用纯函数,将副作用隔离到边界
  2. 使用不可变数据结构,避免状态共享
  3. 利用高阶函数和函数组合提高代码复用
  4. 使用Maybe/Either/Result处理可能的失败
  5. 结合面向对象和函数式编程发挥两者优势

下一章预告

下一章将介绍并发编程模式,探讨如何在Python中设计线程安全、高效的并发系统。

Python技术丛书 - 江苏省宿城中等专业学校