第28章 数据库编程
学习目标
完成本章学习后,你将能够:
- 理解数据库基础概念:关系型数据库原理、SQL语言、数据库设计范式
- 掌握SQLite数据库操作:创建数据库、执行SQL语句、事务管理
- 使用SQLAlchemy ORM:模型定义、查询构建、关系映射
- 实现数据库连接池:连接管理、性能优化、资源释放
- 处理数据库迁移:Alembic迁移工具、版本控制、回滚策略
- 掌握高级查询技术:复杂查询、聚合函数、子查询、连接查询
- 实现数据库安全实践:参数化查询、SQL注入防护、权限管理
- 构建数据库应用架构:Repository模式、Unit of Work、数据访问层设计
28.1 数据库基础概念
28.1.1 关系型数据库概述
关系型数据库(Relational Database)是基于关系模型的数据库,数据以表格形式存储,表格之间通过关系连接。
┌─────────────────────────────────────────────────────────────────────┐
│ 关系型数据库结构 │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────┐ ┌─────────────────┐ │
│ │ users 表 │ │ orders 表 │ │
│ ├─────────────────┤ ├─────────────────┤ │
│ │ id (PK) │◄────────│ user_id (FK) │ │
│ │ username │ 1:N │ id (PK) │ │
│ │ email │ │ order_date │ │
│ │ created_at │ │ total_amount │ │
│ └─────────────────┘ │ status │ │
│ └─────────────────┘ │
│ │ │
│ │ N:M (经 order_items 表) │
│ ▼ │
│ ┌─────────────────┐ │
│ │ products 表 │ │
│ ├─────────────────┤ │
│ │ id (PK) │ │
│ │ name │ │
│ │ price │ │
│ │ category │ │
│ └─────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────┘
PK = Primary Key (主键)
FK = Foreign Key (外键)
28.1.2 SQL语言基础
SQL(Structured Query Language)是操作关系型数据库的标准语言:
| 类别 | 命令 | 用途 |
|---|---|---|
| DDL | CREATE, ALTER, DROP | 定义数据库结构 |
| DML | INSERT, UPDATE, DELETE, SELECT | 操作数据 |
| DCL | GRANT, REVOKE | 控制访问权限 |
| TCL | COMMIT, ROLLBACK, SAVEPOINT | 事务控制 |
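28.1.3 数据库设计范式
按照范式拆分实体可以避免数据冗余:订单与商品之间的多对多关系通过独立的 order_items 表表达,每条明细只保存 product_id 引用以及下单时的单价(unit_price),而不在订单中重复存储商品名称、分类等信息。下面用数据类描述这些实体: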
python
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
@dataclass()
class User:
    """In-memory record mirroring one row of the ``users`` table."""

    id: int               # primary key
    username: str         # login name
    email: str            # contact address
    created_at: datetime  # when the account row was created
@dataclass()
class Product:
    """In-memory record mirroring one row of the ``products`` table."""

    id: int        # primary key
    name: str      # display name
    price: float   # NOTE(review): float cannot hold most decimal prices exactly; Decimal may be preferable
    category: str  # category name
    stock: int     # units on hand
@dataclass()
class Order:
    """In-memory record mirroring one row of the ``orders`` table."""

    id: int               # primary key
    user_id: int          # id of the ordering user (foreign key)
    order_date: datetime  # when the order was placed
    status: str           # workflow state, e.g. "pending"
    total_amount: float   # order total as stored on the order row
@dataclass()
class OrderItem:
    """In-memory record mirroring one row of the ``order_items`` table."""

    id: int             # primary key
    order_id: int       # owning order (foreign key)
    product_id: int     # purchased product (foreign key)
    quantity: int       # number of units
    unit_price: float   # price per unit captured on the line item
# 28.2 SQLite数据库操作
28.2.1 SQLite简介
SQLite是一个轻量级的嵌入式数据库,无需服务器,适合小型应用和原型开发。
python
import sqlite3
from pathlib import Path
from typing import Any, Optional
import logging
logger = logging.getLogger(__name__)
class SQLiteManager:
def __init__(self, db_path: str | Path = ":memory:"):
self.db_path = Path(db_path) if db_path != ":memory:" else db_path
self._connection: Optional[sqlite3.Connection] = None
self._cursor: Optional[sqlite3.Cursor] = None
@property
def connection(self) -> sqlite3.Connection:
if self._connection is None:
self._connection = sqlite3.connect(
self.db_path,
detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES
)
self._connection.row_factory = sqlite3.Row
self._connection.execute("PRAGMA foreign_keys = ON")
logger.info(f"Connected to database: {self.db_path}")
return self._connection
@property
def cursor(self) -> sqlite3.Cursor:
if self._cursor is None:
self._cursor = self.connection.cursor()
return self._cursor
def execute(
self,
query: str,
params: tuple | dict = (),
commit: bool = False
) -> sqlite3.Cursor:
cursor = self.cursor.execute(query, params)
if commit:
self.connection.commit()
return cursor
def executemany(
self,
query: str,
params_list: list[tuple | dict],
commit: bool = False
) -> sqlite3.Cursor:
cursor = self.cursor.executemany(query, params_list)
if commit:
self.connection.commit()
return cursor
def fetchone(self, query: str, params: tuple | dict = ()) -> Optional[sqlite3.Row]:
cursor = self.execute(query, params)
return cursor.fetchone()
def fetchall(self, query: str, params: tuple | dict = ()) -> list[sqlite3.Row]:
cursor = self.execute(query, params)
return cursor.fetchall()
def fetchvalue(self, query: str, params: tuple | dict = ()) -> Any:
row = self.fetchone(query, params)
return row[0] if row else None
def close(self) -> None:
if self._cursor:
self._cursor.close()
self._cursor = None
if self._connection:
self._connection.close()
self._connection = None
logger.info("Database connection closed")
def __enter__(self) -> "SQLiteManager":
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
if exc_type:
self.connection.rollback()
self.close()
def create_tables(self) -> None:
self.execute("""
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
email TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
self.execute("""
CREATE TABLE IF NOT EXISTS products (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
description TEXT,
price REAL NOT NULL CHECK(price >= 0),
category TEXT NOT NULL,
stock INTEGER DEFAULT 0 CHECK(stock >= 0),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
self.execute("""
CREATE TABLE IF NOT EXISTS orders (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
order_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
status TEXT DEFAULT 'pending',
total_amount REAL DEFAULT 0,
shipping_address TEXT,
FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
)
""")
self.execute("""
CREATE TABLE IF NOT EXISTS order_items (
id INTEGER PRIMARY KEY AUTOINCREMENT,
order_id INTEGER NOT NULL,
product_id INTEGER NOT NULL,
quantity INTEGER NOT NULL CHECK(quantity > 0),
unit_price REAL NOT NULL CHECK(unit_price >= 0),
FOREIGN KEY (order_id) REFERENCES orders(id) ON DELETE CASCADE,
FOREIGN KEY (product_id) REFERENCES products(id) ON DELETE RESTRICT
)
""")
self.connection.commit()
logger.info("Tables created successfully")28.2.2 CRUD操作实现
python
from dataclasses import asdict
from datetime import datetime
import hashlib
import secrets
class UserRepository:
    """CRUD access to the ``users`` table through an ``SQLiteManager``.

    Passwords are stored as ``"<salt>:<hex digest>"``. The digest is
    PBKDF2-HMAC-SHA256 over the UTF-8 password, keyed by a random salt of
    32 hex characters.
    """

    # Shared by hashing and verification so the two can never drift apart.
    _PBKDF2_ITERATIONS = 100000

    def __init__(self, db: "SQLiteManager"):
        self.db = db

    def create(self, username: str, email: str, password: str) -> int:
        """Insert a user with a freshly salted password hash, commit, and return the new id."""
        password_hash = self._hash_password(password)
        cursor = self.db.execute(
            """
            INSERT INTO users (username, email, password_hash)
            VALUES (?, ?, ?)
            """,
            (username, email, password_hash),
            commit=True
        )
        return cursor.lastrowid

    def find_by_id(self, user_id: int) -> Optional[dict]:
        """Return id/username/email/created_at for *user_id* (no password hash), or ``None``."""
        row = self.db.fetchone(
            "SELECT id, username, email, created_at FROM users WHERE id = ?",
            (user_id,)
        )
        return dict(row) if row else None

    def find_by_username(self, username: str) -> Optional[dict]:
        """Return the full row for *username*, or ``None``.

        NOTE(review): ``SELECT *`` includes ``password_hash``; do not hand the
        result to clients unfiltered.
        """
        row = self.db.fetchone(
            "SELECT * FROM users WHERE username = ?",
            (username,)
        )
        return dict(row) if row else None

    def find_all(self, limit: int = 100, offset: int = 0) -> list[dict]:
        """Return one page of users (no password hash), newest first."""
        rows = self.db.fetchall(
            "SELECT id, username, email, created_at FROM users ORDER BY created_at DESC LIMIT ? OFFSET ?",
            (limit, offset)
        )
        return [dict(row) for row in rows]

    def update(self, user_id: int, **fields) -> bool:
        """Update ``username`` and/or ``email`` of *user_id* and commit.

        Unknown keyword names are ignored. Returns ``False`` when no allowed
        field was given or no row matched. ``updated_at`` is set with SQL
        ``CURRENT_TIMESTAMP``, the same UTC clock the column default uses.
        """
        allowed_fields = {"username", "email"}
        updates = {k: v for k, v in fields.items() if k in allowed_fields}
        if not updates:
            return False
        # Column names come only from the fixed whitelist above, never from
        # the caller, so interpolating them into the SQL text is safe.
        set_clause = ", ".join(f"{column} = ?" for column in updates)
        cursor = self.db.execute(
            f"UPDATE users SET {set_clause}, updated_at = CURRENT_TIMESTAMP WHERE id = ?",
            (*updates.values(), user_id),
            commit=True
        )
        return cursor.rowcount > 0

    def delete(self, user_id: int) -> bool:
        """Delete *user_id* and commit; ``True`` when a row was removed."""
        cursor = self.db.execute(
            "DELETE FROM users WHERE id = ?",
            (user_id,),
            commit=True
        )
        return cursor.rowcount > 0

    def verify_password(self, username: str, password: str) -> bool:
        """Return ``True`` when *password* matches the stored hash of *username*."""
        row = self.db.fetchone(
            "SELECT password_hash FROM users WHERE username = ?",
            (username,)
        )
        if not row:
            return False
        return self._verify_password(password, row["password_hash"])

    @staticmethod
    def _hash_password(password: str) -> str:
        """Return ``"<salt>:<hex digest>"`` for *password* with a new random salt."""
        salt = secrets.token_hex(16)
        # The salt's hex TEXT (not the decoded bytes) is the PBKDF2 salt; keep
        # it that way or existing hashes stop verifying.
        digest = hashlib.pbkdf2_hmac(
            "sha256",
            password.encode("utf-8"),
            salt.encode("utf-8"),
            UserRepository._PBKDF2_ITERATIONS
        )
        return f"{salt}:{digest.hex()}"

    @staticmethod
    def _verify_password(password: str, stored_hash: str) -> bool:
        """Constant-time check of *password* against a ``"<salt>:<hex digest>"`` string.

        A malformed stored value returns ``False`` instead of raising.
        """
        salt, separator, expected_hex = stored_hash.partition(":")
        if not separator or not salt or not expected_hex:
            return False
        candidate = hashlib.pbkdf2_hmac(
            "sha256",
            password.encode("utf-8"),
            salt.encode("utf-8"),
            UserRepository._PBKDF2_ITERATIONS
        )
        return secrets.compare_digest(candidate.hex(), expected_hex)
class ProductRepository:
    """CRUD and search access to the ``products`` table through an ``SQLiteManager``."""

    def __init__(self, db: "SQLiteManager"):
        self.db = db

    @staticmethod
    def _contains_pattern(keyword: str) -> str:
        """Return a ``%keyword%`` LIKE pattern whose ``%``, ``_`` and ``!`` match literally.

        Must be paired with ``ESCAPE '!'`` in the SQL.
        """
        escaped = keyword.replace("!", "!!").replace("%", "!%").replace("_", "!_")
        return f"%{escaped}%"

    def create(
        self,
        name: str,
        price: float,
        category: str,
        description: str = "",
        stock: int = 0
    ) -> int:
        """Insert a product, commit, and return its new id."""
        cursor = self.db.execute(
            """
            INSERT INTO products (name, description, price, category, stock)
            VALUES (?, ?, ?, ?, ?)
            """,
            (name, description, price, category, stock),
            commit=True
        )
        return cursor.lastrowid

    def find_by_id(self, product_id: int) -> Optional[dict]:
        """Return the product row as a dict, or ``None``."""
        row = self.db.fetchone(
            "SELECT * FROM products WHERE id = ?",
            (product_id,)
        )
        return dict(row) if row else None

    def find_by_category(self, category: str) -> list[dict]:
        """Return all products in *category*, ordered by name."""
        rows = self.db.fetchall(
            "SELECT * FROM products WHERE category = ? ORDER BY name",
            (category,)
        )
        return [dict(row) for row in rows]

    def search(self, keyword: str) -> list[dict]:
        """Return products whose name or description contains *keyword*, ordered by name.

        *keyword* is matched literally: ``%`` and ``_`` typed by a user are not
        treated as wildcards.
        """
        pattern = self._contains_pattern(keyword)
        rows = self.db.fetchall(
            """
            SELECT * FROM products
            WHERE name LIKE ? ESCAPE '!' OR description LIKE ? ESCAPE '!'
            ORDER BY name
            """,
            (pattern, pattern)
        )
        return [dict(row) for row in rows]

    def update_stock(self, product_id: int, quantity: int) -> bool:
        """Add *quantity* (may be negative) to the stock and commit.

        Returns ``False`` when the product does not exist or the stock would
        become negative; the guard is in the WHERE clause, so the check and
        the update happen in one statement.
        """
        cursor = self.db.execute(
            """
            UPDATE products SET stock = stock + ?
            WHERE id = ? AND stock + ? >= 0
            """,
            (quantity, product_id, quantity),
            commit=True
        )
        return cursor.rowcount > 0

    def list_all(
        self,
        category: Optional[str] = None,
        min_price: Optional[float] = None,
        max_price: Optional[float] = None,
        in_stock_only: bool = False,
        limit: int = 100,
        offset: int = 0
    ) -> list[dict]:
        """Return one page of products matching the optional filters, newest first."""
        conditions = []
        params = []
        if category:
            conditions.append("category = ?")
            params.append(category)
        if min_price is not None:
            conditions.append("price >= ?")
            params.append(min_price)
        if max_price is not None:
            conditions.append("price <= ?")
            params.append(max_price)
        if in_stock_only:
            conditions.append("stock > 0")
        # Only fixed SQL fragments are joined here; every value is a bound parameter.
        where_clause = " AND ".join(conditions) if conditions else "1=1"
        params.extend([limit, offset])
        rows = self.db.fetchall(
            f"""
            SELECT * FROM products
            WHERE {where_clause}
            ORDER BY created_at DESC
            LIMIT ? OFFSET ?
            """,
            tuple(params)
        )
        return [dict(row) for row in rows]
# 28.2.3 事务管理
python
from contextlib import contextmanager
from typing import Generator, Callable, TypeVar, ParamSpec
P = ParamSpec("P")
T = TypeVar("T")


class TransactionManager:
    """Run work against an ``SQLiteManager`` connection as one transaction."""

    def __init__(self, db: SQLiteManager):
        self.db = db

    @contextmanager
    def transaction(self) -> Generator[sqlite3.Connection, None, None]:
        """Yield the shared connection, then commit; on any error roll back, log and re-raise.

        The commit sits inside the ``try`` so that a failing commit is rolled
        back and logged as well.
        """
        active = self.db.connection
        try:
            yield active
            active.commit()
        except Exception as error:
            active.rollback()
            logger.error(f"Transaction rolled back: {error}")
            raise

    def execute_in_transaction(
        self,
        func: Callable[P, T],
        *args: P.args,
        **kwargs: P.kwargs
    ) -> T:
        """Call ``func(*args, **kwargs)`` inside :meth:`transaction` and return its result."""
        with self.transaction():
            outcome = func(*args, **kwargs)
        return outcome
class OrderService:
    """Order workflows (create / cancel / read) on top of the SQLite repositories.

    Each mutating workflow runs inside ``with self.db.connection:``. That is
    the sqlite3 connection context manager: it commits when the block ends
    normally and rolls back when it raises. Every statement inside the block
    therefore runs WITHOUT its own commit, so a failure part-way through
    leaves the database untouched.
    """

    def __init__(self, db: SQLiteManager):
        self.db = db
        self.tx = TransactionManager(db)
        self.user_repo = UserRepository(db)
        self.product_repo = ProductRepository(db)

    def _adjust_stock(self, product_id: int, delta: int) -> None:
        """Add *delta* to a product's stock inside the caller's open transaction.

        Unlike ``ProductRepository.update_stock`` this does not commit. It
        raises ``ValueError`` when the product is missing or the stock would
        go negative, so the enclosing ``with`` rolls everything back.
        """
        cursor = self.db.execute(
            """
            UPDATE products SET stock = stock + ?
            WHERE id = ? AND stock + ? >= 0
            """,
            (delta, product_id, delta)
        )
        if cursor.rowcount == 0:
            raise ValueError(
                f"Insufficient stock for product {product_id} (requested change: {delta})"
            )

    def create_order(
        self,
        user_id: int,
        items: list[dict],
        shipping_address: str
    ) -> int:
        """Create an order atomically and return its id.

        *items* is a list of dicts with ``product_id`` and ``quantity`` keys.
        Prices are taken from the current product rows. Raises ``ValueError``
        for an unknown user or product, or when stock is insufficient. That
        includes the case where the same product appears in several items
        whose combined quantity exceeds the stock.
        """
        with self.db.connection:
            user = self.user_repo.find_by_id(user_id)
            if not user:
                raise ValueError(f"User {user_id} not found")
            total_amount = 0.0
            order_items = []
            for item in items:
                product = self.product_repo.find_by_id(item["product_id"])
                if not product:
                    raise ValueError(f"Product {item['product_id']} not found")
                if product["stock"] < item["quantity"]:
                    raise ValueError(
                        f"Insufficient stock for {product['name']}. "
                        f"Available: {product['stock']}, Requested: {item['quantity']}"
                    )
                unit_price = product["price"]
                total_amount += unit_price * item["quantity"]
                order_items.append({
                    "product_id": item["product_id"],
                    "quantity": item["quantity"],
                    "unit_price": unit_price
                })
            cursor = self.db.execute(
                """
                INSERT INTO orders (user_id, total_amount, shipping_address, status)
                VALUES (?, ?, ?, 'pending')
                """,
                (user_id, total_amount, shipping_address)
            )
            order_id = cursor.lastrowid
            for item in order_items:
                self.db.execute(
                    """
                    INSERT INTO order_items (order_id, product_id, quantity, unit_price)
                    VALUES (?, ?, ?, ?)
                    """,
                    (order_id, item["product_id"], item["quantity"], item["unit_price"])
                )
                # No intermediate commit: a failure here rolls back the whole order.
                self._adjust_stock(item["product_id"], -item["quantity"])
        logger.info(f"Order {order_id} created for user {user_id}")
        return order_id

    def cancel_order(self, order_id: int) -> bool:
        """Cancel an order and restock its items atomically.

        Returns ``True`` (also when the order was already cancelled). Raises
        ``ValueError`` for an unknown or already delivered order.
        """
        with self.db.connection:
            order = self.db.fetchone(
                "SELECT * FROM orders WHERE id = ?",
                (order_id,)
            )
            if not order:
                raise ValueError(f"Order {order_id} not found")
            if order["status"] == "cancelled":
                return True
            if order["status"] == "delivered":
                raise ValueError("Cannot cancel a delivered order")
            items = self.db.fetchall(
                "SELECT product_id, quantity FROM order_items WHERE order_id = ?",
                (order_id,)
            )
            for item in items:
                self._adjust_stock(item["product_id"], item["quantity"])
            self.db.execute(
                "UPDATE orders SET status = 'cancelled' WHERE id = ?",
                (order_id,)
            )
        logger.info(f"Order {order_id} cancelled")
        return True

    def get_order_details(self, order_id: int) -> Optional[dict]:
        """Return ``{"order": ..., "items": [...]}`` with user and product info, or ``None``."""
        order = self.db.fetchone(
            """
            SELECT o.*, u.username, u.email
            FROM orders o
            JOIN users u ON o.user_id = u.id
            WHERE o.id = ?
            """,
            (order_id,)
        )
        if not order:
            return None
        items = self.db.fetchall(
            """
            SELECT oi.*, p.name as product_name, p.category
            FROM order_items oi
            JOIN products p ON oi.product_id = p.id
            WHERE oi.order_id = ?
            """,
            (order_id,)
        )
        return {
            "order": dict(order),
            "items": [dict(item) for item in items]
        }
# 28.3 SQLAlchemy ORM
28.3.1 SQLAlchemy简介
SQLAlchemy是Python最流行的ORM框架,提供了高级的ORM和低级的SQL表达式两种使用方式。
python
from sqlalchemy import create_engine, Column, Integer, String, Float, DateTime, ForeignKey, Text, Boolean
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, relationship, sessionmaker
from datetime import datetime
from typing import Optional, List
Base = declarative_base()


class User(Base):
    """ORM model for the ``users`` table.

    ``orders`` is configured with ``cascade="all, delete-orphan"``, so ORM
    operations on a user propagate to its orders, and an order removed from
    ``user.orders`` is deleted.
    """

    __tablename__ = "users"

    id = Column(Integer, primary_key=True, autoincrement=True)
    username = Column(String(50), unique=True, nullable=False, index=True)
    email = Column(String(100), unique=True, nullable=False, index=True)
    password_hash = Column(String(128), nullable=False)
    is_active = Column(Boolean, default=True)
    is_admin = Column(Boolean, default=False)
    # Python-side defaults. NOTE(review): datetime.utcnow is deprecated since Python 3.12.
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    orders = relationship("Order", back_populates="user", cascade="all, delete-orphan")

    def __repr__(self) -> str:
        return "<User(id={}, username='{}')>".format(self.id, self.username)
class Product(Base):
    """ORM model for the ``products`` table."""

    __tablename__ = "products"

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String(100), nullable=False)
    description = Column(Text)
    # NOTE(review): Float is lossy for money; Numeric would store exact prices.
    price = Column(Float, nullable=False)
    category = Column(String(50), nullable=False, index=True)
    stock = Column(Integer, default=0)
    is_available = Column(Boolean, default=True)
    created_at = Column(DateTime, default=datetime.utcnow)

    order_items = relationship("OrderItem", back_populates="product")

    def __repr__(self) -> str:
        return "<Product(id={}, name='{}', price={})>".format(self.id, self.name, self.price)
class Order(Base):
    """ORM model for the ``orders`` table: one purchase placed by one user."""

    __tablename__ = "orders"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # ondelete is emitted in the DDL, so the database itself removes orders of a deleted user.
    user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
    order_date = Column(DateTime, default=datetime.utcnow)
    status = Column(String(20), default="pending")
    total_amount = Column(Float, default=0.0)
    shipping_address = Column(Text)

    user = relationship("User", back_populates="orders")
    items = relationship("OrderItem", back_populates="order", cascade="all, delete-orphan")

    def __repr__(self) -> str:
        return "<Order(id={}, user_id={}, status='{}')>".format(self.id, self.user_id, self.status)

    def calculate_total(self) -> float:
        """Return the sum of ``quantity * unit_price`` over ``items`` (0 when empty)."""
        total = 0
        for line in self.items:
            total += line.quantity * line.unit_price
        return total
class OrderItem(Base):
    """ORM model for the ``order_items`` table: one product line of an order."""

    __tablename__ = "order_items"

    id = Column(Integer, primary_key=True, autoincrement=True)
    order_id = Column(Integer, ForeignKey("orders.id", ondelete="CASCADE"), nullable=False)
    # RESTRICT: the database refuses to delete a product still referenced by an order line.
    product_id = Column(Integer, ForeignKey("products.id", ondelete="RESTRICT"), nullable=False)
    quantity = Column(Integer, nullable=False)
    unit_price = Column(Float, nullable=False)

    order = relationship("Order", back_populates="items")
    product = relationship("Product", back_populates="order_items")

    def __repr__(self) -> str:
        return "<OrderItem(order_id={}, product_id={}, qty={})>".format(
            self.order_id, self.product_id, self.quantity
        )
class Database:
    """Synchronous SQLAlchemy engine plus a factory for ``Session`` objects."""

    def __init__(self, database_url: str = "sqlite:///app.db", echo: bool = False):
        self.engine = create_engine(database_url, echo=echo)
        # autoflush=False: pending changes are only flushed on commit or an explicit flush().
        self.SessionLocal = sessionmaker(
            autocommit=False,
            autoflush=False,
            bind=self.engine
        )

    def create_tables(self) -> None:
        """Create every table registered on ``Base.metadata`` that does not exist yet."""
        Base.metadata.create_all(self.engine)

    def drop_tables(self) -> None:
        """Drop every table registered on ``Base.metadata`` (all data is lost)."""
        Base.metadata.drop_all(self.engine)

    def get_session(self) -> Session:
        """Return a new session; the caller must close it (or use :meth:`session_context`)."""
        return self.SessionLocal()

    @staticmethod
    @contextmanager
    def session_context(session: Session):
        """Transactional scope around *session*.

        Commits when the ``with`` block succeeds, rolls back and re-raises on
        error, and always closes the session::

            with Database.session_context(db.get_session()) as session:
                session.add(obj)
        """
        # @contextmanager turns this generator into a with-statement context
        # manager; as a bare generator function it could not be used in `with`.
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()
# 28.3.2 查询构建
python
from sqlalchemy import select, update, delete, func, and_, or_, not_
from sqlalchemy.orm import joinedload, selectinload, subqueryload
from typing import TypeVar, Generic, Type, Any
ModelType = TypeVar("ModelType", bound=Base)


class BaseRepository(Generic[ModelType]):
    """Generic CRUD helper for a single mapped model class.

    Every mutating method commits the session immediately, so each call is
    its own transaction.
    """

    def __init__(self, model: Type[ModelType], session: Session):
        self.model = model
        self.session = session

    def create(self, **kwargs) -> ModelType:
        """Insert ``model(**kwargs)``, commit, and return it reloaded from the database."""
        entity = self.model(**kwargs)
        self.session.add(entity)
        self.session.commit()
        self.session.refresh(entity)
        return entity

    def find_by_id(self, id: int) -> Optional[ModelType]:
        """Return the instance with primary key *id*, or ``None``."""
        return self.session.get(self.model, id)

    def find_all(self, skip: int = 0, limit: int = 100) -> List[ModelType]:
        """Return up to *limit* instances after skipping *skip* (no explicit ORDER BY)."""
        page = select(self.model).offset(skip).limit(limit)
        return [*self.session.scalars(page)]

    def update(self, id: int, **kwargs) -> Optional[ModelType]:
        """Assign *kwargs* as attributes of instance *id*, commit, and return it; ``None`` if absent.

        NOTE(review): names are applied with setattr without validation, so a
        misspelled attribute is presumably just not persisted.
        """
        entity = self.find_by_id(id)
        if entity is None:
            return None
        for attr, value in kwargs.items():
            setattr(entity, attr, value)
        self.session.commit()
        self.session.refresh(entity)
        return entity

    def delete(self, id: int) -> bool:
        """Delete instance *id* and commit; ``False`` when it does not exist."""
        entity = self.find_by_id(id)
        if entity is None:
            return False
        self.session.delete(entity)
        self.session.commit()
        return True

    def count(self) -> int:
        """Return the number of rows in the model's table."""
        total = self.session.scalar(select(func.count()).select_from(self.model))
        return total or 0

    def exists(self, id: int) -> bool:
        """Return whether an instance with primary key *id* exists."""
        return self.find_by_id(id) is not None
class UserRepositorySQLAlchemy(BaseRepository[User]):
    """User-specific queries on top of the generic ``BaseRepository``."""

    def __init__(self, session: Session):
        super().__init__(User, session)

    @staticmethod
    def _contains_pattern(keyword: str) -> str:
        """Return ``%keyword%`` with ``/``, ``%`` and ``_`` escaped for ``ESCAPE '/'``."""
        escaped = keyword.replace("/", "//").replace("%", "/%").replace("_", "/_")
        return f"%{escaped}%"

    def find_by_username(self, username: str) -> Optional[User]:
        """Return the user with exactly this username, or ``None``."""
        stmt = select(User).where(User.username == username)
        return self.session.scalar(stmt)

    def find_by_email(self, email: str) -> Optional[User]:
        """Return the user with exactly this e-mail, or ``None``."""
        stmt = select(User).where(User.email == email)
        return self.session.scalar(stmt)

    def search_users(
        self,
        keyword: Optional[str] = None,
        is_active: Optional[bool] = None,
        is_admin: Optional[bool] = None,
        skip: int = 0,
        limit: int = 100
    ) -> List[User]:
        """Return one page of users matching all given filters, newest first.

        *keyword* is a case-insensitive substring match on username or e-mail.
        Wildcard characters in it (``%``, ``_``) are matched literally.
        """
        conditions = []
        if keyword:
            pattern = self._contains_pattern(keyword)
            conditions.append(
                or_(
                    User.username.ilike(pattern, escape="/"),
                    User.email.ilike(pattern, escape="/")
                )
            )
        if is_active is not None:
            conditions.append(User.is_active == is_active)
        if is_admin is not None:
            conditions.append(User.is_admin == is_admin)
        stmt = select(User)
        if conditions:
            stmt = stmt.where(and_(*conditions))
        stmt = stmt.order_by(User.created_at.desc()).offset(skip).limit(limit)
        return list(self.session.scalars(stmt))

    def get_user_with_orders(self, user_id: int) -> Optional[User]:
        """Return the user with ``orders`` eagerly loaded (one extra IN query), or ``None``."""
        stmt = (
            select(User)
            .options(selectinload(User.orders))
            .where(User.id == user_id)
        )
        return self.session.scalar(stmt)
class ProductRepositorySQLAlchemy(BaseRepository[Product]):
    """Product-specific queries and statistics on top of ``BaseRepository``."""

    def __init__(self, session: Session):
        super().__init__(Product, session)

    @staticmethod
    def _contains_pattern(keyword: str) -> str:
        """Return ``%keyword%`` with ``/``, ``%`` and ``_`` escaped for ``ESCAPE '/'``."""
        escaped = keyword.replace("/", "//").replace("%", "/%").replace("_", "/_")
        return f"%{escaped}%"

    def find_by_category(self, category: str) -> List[Product]:
        """Return all products in *category*, ordered by name."""
        stmt = select(Product).where(Product.category == category).order_by(Product.name)
        return list(self.session.scalars(stmt))

    def find_available_products(self, skip: int = 0, limit: int = 100) -> List[Product]:
        """Return one page of products flagged available and in stock, newest first."""
        stmt = (
            select(Product)
            .where(and_(Product.is_available.is_(True), Product.stock > 0))
            .offset(skip)
            .limit(limit)
            .order_by(Product.created_at.desc())
        )
        return list(self.session.scalars(stmt))

    def search_products(
        self,
        keyword: str,
        category: Optional[str] = None,
        min_price: Optional[float] = None,
        max_price: Optional[float] = None
    ) -> List[Product]:
        """Return products whose name or description contains *keyword*, ordered by name.

        The match is case-insensitive, and ``%`` / ``_`` in *keyword* are
        matched literally. The optional filters narrow the result further.
        """
        pattern = self._contains_pattern(keyword)
        conditions = [
            or_(
                Product.name.ilike(pattern, escape="/"),
                Product.description.ilike(pattern, escape="/")
            )
        ]
        if category:
            conditions.append(Product.category == category)
        if min_price is not None:
            conditions.append(Product.price >= min_price)
        if max_price is not None:
            conditions.append(Product.price <= max_price)
        stmt = select(Product).where(and_(*conditions)).order_by(Product.name)
        return list(self.session.scalars(stmt))

    def get_categories(self) -> List[str]:
        """Return the distinct category names, sorted."""
        stmt = select(Product.category).distinct().order_by(Product.category)
        return list(self.session.scalars(stmt))

    def get_stock_statistics(self) -> List[dict]:
        """Return per-category product count, total stock and average price (rounded to 2 places)."""
        stmt = select(
            Product.category,
            func.count(Product.id).label("product_count"),
            func.sum(Product.stock).label("total_stock"),
            func.avg(Product.price).label("avg_price")
        ).group_by(Product.category)
        results = self.session.execute(stmt).all()
        return [
            {
                "category": row.category,
                "product_count": row.product_count,
                "total_stock": row.total_stock or 0,
                "avg_price": round(row.avg_price, 2) if row.avg_price else 0
            }
            for row in results
        ]
class OrderRepositorySQLAlchemy(BaseRepository[Order]):
    """Order lookups and sales reports built on ``BaseRepository``."""

    def __init__(self, session: Session):
        super().__init__(Order, session)

    def find_by_user(self, user_id: int, skip: int = 0, limit: int = 100) -> List[Order]:
        """Return one page of a user's orders, newest first, with ``items`` preloaded."""
        query = select(Order).where(Order.user_id == user_id)
        query = query.options(selectinload(Order.items))
        query = query.offset(skip).limit(limit).order_by(Order.order_date.desc())
        return list(self.session.scalars(query))

    def find_by_status(self, status: str) -> List[Order]:
        """Return orders in *status*, oldest first, with ``user`` joined in."""
        query = select(Order).where(Order.status == status)
        query = query.options(joinedload(Order.user)).order_by(Order.order_date)
        return list(self.session.scalars(query))

    def get_order_with_details(self, order_id: int) -> Optional[Order]:
        """Return one order with its user, items and each item's product eagerly loaded."""
        eager = (
            joinedload(Order.user),
            selectinload(Order.items).joinedload(OrderItem.product),
        )
        query = select(Order).options(*eager).where(Order.id == order_id)
        return self.session.scalar(query)

    def get_daily_sales(self, start_date: datetime, end_date: datetime) -> List[dict]:
        """Return per-day order count and revenue for non-cancelled orders in [start_date, end_date]."""
        day = func.date(Order.order_date)
        query = (
            select(
                day.label("date"),
                func.count(Order.id).label("order_count"),
                func.sum(Order.total_amount).label("total_revenue")
            )
            .where(
                Order.order_date >= start_date,
                Order.order_date <= end_date,
                Order.status != "cancelled"
            )
            .group_by(day)
            .order_by(day)
        )
        report = []
        for row in self.session.execute(query).all():
            report.append({
                "date": row.date,
                "order_count": row.order_count,
                "total_revenue": row.total_revenue or 0
            })
        return report

    def get_top_customers(self, limit: int = 10) -> List[dict]:
        """Return the *limit* users with the highest non-cancelled order total, biggest first."""
        spent = func.sum(Order.total_amount)
        query = (
            select(
                User.id,
                User.username,
                User.email,
                func.count(Order.id).label("order_count"),
                spent.label("total_spent")
            )
            .join(Order, User.id == Order.user_id)
            .where(Order.status != "cancelled")
            .group_by(User.id)
            .order_by(spent.desc())
            .limit(limit)
        )
        leaders = []
        for row in self.session.execute(query).all():
            leaders.append({
                "user_id": row.id,
                "username": row.username,
                "email": row.email,
                "order_count": row.order_count,
                "total_spent": row.total_spent or 0
            })
        return leaders
# 28.3.3 高级查询技术
python
from sqlalchemy import case, literal_column, over
from sqlalchemy.sql import functions as sql_func
class AdvancedQueries:
    """Showcase of joins, subqueries, window functions, CASE and CTEs with SQLAlchemy 2.0."""

    def __init__(self, session: Session):
        self.session = session

    def complex_join_query(self) -> List[dict]:
        """Delivered orders whose line-item total exceeds 100, with user info, largest first."""
        line_total = func.sum(OrderItem.quantity * OrderItem.unit_price)
        stmt = (
            select(
                Order.id,
                Order.order_date,
                User.username,
                User.email,
                func.count(OrderItem.id).label("item_count"),
                line_total.label("total")
            )
            .join(User, Order.user_id == User.id)
            .join(OrderItem, Order.id == OrderItem.order_id)
            .where(Order.status == "delivered")
            .group_by(Order.id, User.id)
            .having(line_total > 100)
            .order_by(line_total.desc())
        )
        results = self.session.execute(stmt).all()
        return [dict(row._mapping) for row in results]

    def subquery_example(self) -> List[dict]:
        """Products priced above their own category's average, with that average."""
        avg_price_subquery = (
            select(
                Product.category,
                func.avg(Product.price).label("avg_price")
            )
            .group_by(Product.category)
            .subquery()
        )
        stmt = (
            select(
                Product.name,
                Product.category,
                Product.price,
                avg_price_subquery.c.avg_price
            )
            .join(
                avg_price_subquery,
                Product.category == avg_price_subquery.c.category
            )
            .where(Product.price > avg_price_subquery.c.avg_price)
            .order_by(Product.category, Product.price.desc())
        )
        results = self.session.execute(stmt).all()
        return [dict(row._mapping) for row in results]

    def window_function_example(self) -> List[dict]:
        """Rank products by price (highest = 1) within each category using ``RANK()``."""
        rank_window = over(
            func.rank(),
            partition_by=Product.category,
            order_by=Product.price.desc()
        )
        stmt = (
            select(
                Product.name,
                Product.category,
                Product.price,
                rank_window.label("price_rank")
            )
            .select_from(Product)
        )
        results = self.session.execute(stmt).all()
        return [
            {
                "name": row.name,
                "category": row.category,
                "price": row.price,
                "price_rank": row.price_rank
            }
            for row in results
        ]

    def case_expression_example(self) -> List[dict]:
        """Bucket products by price: <50 Budget, <100 Mid-range, <500 Premium, else Luxury."""
        price_category = case(
            (Product.price < 50, "Budget"),
            (Product.price < 100, "Mid-range"),
            (Product.price < 500, "Premium"),
            else_="Luxury"
        )
        stmt = (
            select(
                Product.name,
                Product.price,
                Product.category,
                price_category.label("price_tier")
            )
            .order_by(Product.price)
        )
        results = self.session.execute(stmt).all()
        return [dict(row._mapping) for row in results]

    def cte_example(self) -> List[dict]:
        """Month-over-month revenue of delivered orders using a CTE and ``LAG()``.

        Rows carry ``month`` (``YYYY-MM``), ``monthly_total``,
        ``previous_month`` and ``growth`` (``monthly_total - previous_month``).
        The last two are ``None`` for the first month. ``strftime`` is
        SQLite-specific.
        """
        month = func.strftime("%Y-%m", Order.order_date)
        monthly_sales = (
            select(
                month.label("month"),
                func.sum(Order.total_amount).label("monthly_total")
            )
            .where(Order.status == "delivered")
            .group_by(month)
            .cte("monthly_sales")
        )
        # func.lag renders the generic LAG() window function; the
        # sqlalchemy.sql.functions module has no predefined `lag` attribute.
        previous_total = func.lag(monthly_sales.c.monthly_total).over(
            order_by=monthly_sales.c.month
        )
        stmt = (
            select(
                monthly_sales.c.month,
                monthly_sales.c.monthly_total,
                previous_total.label("previous_month"),
                (monthly_sales.c.monthly_total - previous_total).label("growth")
            )
            .order_by(monthly_sales.c.month)
        )
        results = self.session.execute(stmt).all()
        return [dict(row._mapping) for row in results]
# 28.4 数据库连接池
28.4.1 连接池配置
python
from sqlalchemy import create_engine, event
from sqlalchemy.pool import QueuePool, NullPool
import logging
logger = logging.getLogger(__name__)


class ConnectionPoolManager:
    """Lazily builds a ``QueuePool``-backed engine and exposes pool diagnostics.

    The pool keeps up to *pool_size* persistent connections plus up to
    *max_overflow* extra ones. A checkout waits at most *pool_timeout*
    seconds, and connections older than *pool_recycle* seconds are replaced.
    NOTE(review): with ``sqlite:///:memory:`` each pooled connection is a
    separate empty database; QueuePool is meant for server databases.
    """

    def __init__(
        self,
        database_url: str,
        pool_size: int = 5,
        max_overflow: int = 10,
        pool_timeout: int = 30,
        pool_recycle: int = 3600,
        echo: bool = False
    ):
        self.database_url = database_url
        self.pool_size = pool_size
        self.max_overflow = max_overflow
        self.pool_timeout = pool_timeout
        self.pool_recycle = pool_recycle
        self.echo = echo
        self._engine = None

    @property
    def engine(self):
        """Return the engine, creating it (and its event hooks) on first access."""
        if self._engine is None:
            self._engine = create_engine(
                self.database_url,
                poolclass=QueuePool,
                pool_size=self.pool_size,
                max_overflow=self.max_overflow,
                pool_timeout=self.pool_timeout,
                pool_recycle=self.pool_recycle,
                echo=self.echo,
                echo_pool=True,
                # Correct keyword is pool_pre_ping (pre_ping is rejected by
                # create_engine): ping on checkout, replace dead connections.
                pool_pre_ping=True
            )
            self._setup_events()
        return self._engine

    def _setup_events(self):
        """Attach debug/warning logging to the engine's pool lifecycle events."""
        engine = self._engine

        @event.listens_for(engine, "connect")
        def on_connect(dbapi_conn, connection_record):
            logger.debug(f"Connection created: {connection_record.info}")

        @event.listens_for(engine, "checkout")
        def on_checkout(dbapi_conn, connection_record, connection_proxy):
            logger.debug("Connection checked out from pool")

        @event.listens_for(engine, "checkin")
        def on_checkin(dbapi_conn, connection_record):
            logger.debug("Connection returned to pool")

        # Fires when a pooled connection is invalidated (e.g. a failed
        # pre-ping). A second "checkout" listener would run on EVERY checkout
        # and log a false failure warning each time.
        @event.listens_for(engine, "invalidate")
        def on_invalidate(dbapi_conn, connection_record, exception):
            logger.warning(f"Connection invalidated and will be replaced: {exception}")

    def get_pool_status(self) -> dict:
        """Return size/checked-in/checked-out/overflow counters of the pool.

        Uses the lazy ``engine`` property, so calling this before any query no
        longer fails on a missing engine.
        """
        pool = self.engine.pool
        return {
            "pool_size": pool.size(),
            "checked_in": pool.checkedin(),
            "checked_out": pool.checkedout(),
            "overflow": pool.overflow(),
            "invalid": pool.invalidatedcount() if hasattr(pool, "invalidatedcount") else 0
        }

    def dispose(self):
        """Close all pooled connections if the engine was ever created."""
        if self._engine:
            self._engine.dispose()
            logger.info("Connection pool disposed")
class DatabaseConfig:
    """Ready-made SQLAlchemy database URLs.

    User names and passwords are percent-encoded, so characters such as
    ``@``, ``:``, ``/`` or ``%`` in a password cannot corrupt the URL.
    """

    SQLITE_MEMORY = "sqlite:///:memory:"
    SQLITE_FILE = "sqlite:///app.db"

    @staticmethod
    def _quote(value: str) -> str:
        """Percent-encode *value* for the user/password part of a URL."""
        from urllib.parse import quote

        # safe="" also encodes "/" and ":"; quote (not quote_plus) emits %20
        # for spaces, which decodes correctly with either unquote variant.
        return quote(value, safe="")

    @staticmethod
    def postgresql(
        host: str = "localhost",
        port: int = 5432,
        database: str = "app",
        user: str = "postgres",
        password: str = ""
    ) -> str:
        """Return a ``postgresql://`` URL (default DBAPI driver)."""
        credentials = f"{DatabaseConfig._quote(user)}:{DatabaseConfig._quote(password)}"
        return f"postgresql://{credentials}@{host}:{port}/{database}"

    @staticmethod
    def mysql(
        host: str = "localhost",
        port: int = 3306,
        database: str = "app",
        user: str = "root",
        password: str = ""
    ) -> str:
        """Return a ``mysql+pymysql://`` URL."""
        credentials = f"{DatabaseConfig._quote(user)}:{DatabaseConfig._quote(password)}"
        return f"mysql+pymysql://{credentials}@{host}:{port}/{database}"
class DatabaseManager:
    """Process-wide holder of one ``ConnectionPoolManager`` and its session factory.

    All state lives on the class, so every caller shares the same pool.
    """

    _instance = None
    _pool_manager = None
    # sessionmaker bound to the current engine; built once, reset on (re)initialise/dispose.
    _session_factory = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def initialize(
        cls,
        database_url: str,
        pool_size: int = 5,
        max_overflow: int = 10
    ) -> None:
        """(Re)configure the shared pool; a previous pool is disposed first so it does not leak."""
        if cls._pool_manager is not None:
            cls._pool_manager.dispose()
        cls._pool_manager = ConnectionPoolManager(
            database_url=database_url,
            pool_size=pool_size,
            max_overflow=max_overflow
        )
        cls._session_factory = None

    @classmethod
    def get_engine(cls):
        """Return the shared engine; raises ``RuntimeError`` before :meth:`initialize`."""
        if cls._pool_manager is None:
            raise RuntimeError("DatabaseManager not initialized")
        return cls._pool_manager.engine

    @classmethod
    def get_session(cls) -> Session:
        """Return a new session from a cached factory bound to the shared engine."""
        if cls._session_factory is None:
            cls._session_factory = sessionmaker(bind=cls.get_engine())
        return cls._session_factory()

    @classmethod
    def get_pool_status(cls) -> dict:
        """Return pool counters, or an ``error`` entry when not initialised."""
        if cls._pool_manager is None:
            return {"error": "Pool not initialized"}
        return cls._pool_manager.get_pool_status()

    @classmethod
    def dispose(cls) -> None:
        """Close the shared pool and forget it (and its session factory)."""
        if cls._pool_manager:
            cls._pool_manager.dispose()
            cls._pool_manager = None
        cls._session_factory = None
# 28.4.2 异步数据库操作
python
import asyncio
from typing import AsyncIterator

from sqlalchemy import select
from sqlalchemy.ext.asyncio import (
    create_async_engine,
    AsyncSession,
    async_sessionmaker
)
class AsyncDatabase:
    """Async SQLAlchemy engine plus an ``AsyncSession`` factory."""

    def __init__(self, database_url: str = "sqlite+aiosqlite:///app.db"):
        self.engine = create_async_engine(database_url, echo=False)
        # expire_on_commit=False keeps loaded attributes readable after commit
        # without an extra (awaited) refresh.
        self.async_session = async_sessionmaker(
            self.engine,
            class_=AsyncSession,
            expire_on_commit=False
        )

    async def create_tables(self):
        """Create all tables of ``Base.metadata`` that do not exist yet."""
        async with self.engine.begin() as conn:
            # create_all is synchronous; run_sync executes it on the async connection.
            await conn.run_sync(Base.metadata.create_all)

    async def get_session(self) -> AsyncIterator[AsyncSession]:
        """Async generator yielding one session and closing it afterwards.

        This is an async generator, not a coroutine returning a session.
        Consume it with ``async for`` or a framework that accepts
        async-generator dependencies.
        """
        async with self.async_session() as session:
            yield session

    async def dispose(self) -> None:
        """Close every pooled connection; await before the event loop shuts down."""
        await self.engine.dispose()
class AsyncUserRepository:
    """Async CRUD for ``User`` bound to one ``AsyncSession``; every write commits immediately."""

    def __init__(self, session: AsyncSession):
        self.session = session

    async def create(self, username: str, email: str, password_hash: str) -> User:
        """Insert a user, commit, and return it refreshed from the database."""
        new_user = User(username=username, email=email, password_hash=password_hash)
        self.session.add(new_user)
        await self.session.commit()
        await self.session.refresh(new_user)
        return new_user

    async def _one_user_where(self, condition) -> Optional[User]:
        """Return the single user matching *condition*, or ``None``."""
        outcome = await self.session.execute(select(User).where(condition))
        return outcome.scalar_one_or_none()

    async def find_by_id(self, user_id: int) -> Optional[User]:
        """Return the user with this primary key, or ``None``."""
        return await self._one_user_where(User.id == user_id)

    async def find_by_username(self, username: str) -> Optional[User]:
        """Return the user with this username, or ``None``."""
        return await self._one_user_where(User.username == username)

    async def find_all(self, skip: int = 0, limit: int = 100) -> List[User]:
        """Return up to *limit* users after skipping *skip* (no explicit ordering)."""
        outcome = await self.session.execute(select(User).offset(skip).limit(limit))
        return list(outcome.scalars().all())

    async def update(self, user_id: int, **kwargs) -> Optional[User]:
        """Assign *kwargs* to the user, commit and return it; ``None`` if it does not exist."""
        target = await self.find_by_id(user_id)
        if target is None:
            return None
        for attr, value in kwargs.items():
            setattr(target, attr, value)
        await self.session.commit()
        await self.session.refresh(target)
        return target

    async def delete(self, user_id: int) -> bool:
        """Delete the user and commit; ``False`` when it does not exist."""
        target = await self.find_by_id(user_id)
        if target is None:
            return False
        await self.session.delete(target)
        await self.session.commit()
        return True
async def async_main():
    """Demo: create the schema, ensure a sample user exists, then query it back.

    Safe to run repeatedly against the same ``app.db``, and the engine is
    always disposed before the event loop closes.
    """
    db = AsyncDatabase()
    try:
        await db.create_tables()
        async with db.async_session() as session:
            user_repo = AsyncUserRepository(session)
            # Re-use the sample user from an earlier run: inserting it again
            # would violate the UNIQUE username/email constraints.
            user = await user_repo.find_by_username("testuser")
            if user is None:
                user = await user_repo.create(
                    username="testuser",
                    email="test@example.com",
                    password_hash="hashed_password"
                )
                print(f"Created user: {user}")
            found = await user_repo.find_by_username("testuser")
            print(f"Found user: {found}")
            users = await user_repo.find_all()
            print(f"All users: {users}")
    finally:
        # Close pooled async connections while the event loop is still running.
        await db.engine.dispose()


if __name__ == "__main__":
    asyncio.run(async_main())
# 28.5 数据库迁移
28.5.1 Alembic配置
python
from alembic.config import Config
from alembic import command
from pathlib import Path
class MigrationManager:
    """Programmatic wrapper around Alembic's ``command`` API.

    All commands share one in-memory ``Config`` whose ``sqlalchemy.url`` and
    ``script_location`` come from the constructor arguments. No ini file is
    read.
    """

    def __init__(self, database_url: str, migrations_dir: str = "migrations"):
        self.database_url = database_url
        self.migrations_dir = Path(migrations_dir)
        self.config = self._create_config()

    @staticmethod
    def _escape_option(value: str) -> str:
        """Escape ``%`` for ConfigParser interpolation.

        ``Config.set_main_option`` stores values through ConfigParser, which
        treats a bare ``%`` as interpolation syntax. URL-encoded characters in
        a password, such as ``%40``, would otherwise raise an error.
        """
        return value.replace("%", "%%")

    def _create_config(self) -> Config:
        """Build the shared Alembic ``Config`` with URL and script location set."""
        config = Config()
        config.set_main_option("sqlalchemy.url", self._escape_option(self.database_url))
        config.set_main_option("script_location", self._escape_option(str(self.migrations_dir)))
        return config

    def init_migrations(self) -> None:
        """Create the migration environment unless the directory already exists."""
        if self.migrations_dir.exists():
            print(f"Migrations directory already exists: {self.migrations_dir}")
            return
        if self.config.config_file_name is None:
            # command.init also renders an alembic.ini and needs a target path
            # for it; a bare Config() has none, so init would fail.
            self.config.config_file_name = "alembic.ini"
        command.init(self.config, str(self.migrations_dir))
        print(f"Initialized migrations directory: {self.migrations_dir}")

    def create_migration(self, message: str = "auto migration") -> None:
        """Autogenerate a new revision by diffing the models against the database."""
        command.revision(self.config, autogenerate=True, message=message)
        print(f"Created migration: {message}")

    def upgrade(self, revision: str = "head") -> None:
        """Apply migrations up to *revision* (default: latest)."""
        command.upgrade(self.config, revision)
        print(f"Upgraded database to revision: {revision}")

    def downgrade(self, revision: str = "-1") -> None:
        """Revert migrations down to *revision* (default: one step back)."""
        command.downgrade(self.config, revision)
        print(f"Downgraded database to revision: {revision}")

    def current(self) -> None:
        """Show the database's current revision."""
        command.current(self.config)

    def history(self) -> None:
        """Show the revision history."""
        command.history(self.config)
def alembic_env_py_template():
return '''
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent))
from models import Base
config = context.config
if config.config_file_name is not None:
fileConfig(config.config_file_name)
target_metadata = Base.metadata
def run_migrations_offline() -> None:
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
'''28.5.2 迁移脚本示例
python
from alembic import op
import sqlalchemy as sa
from datetime import datetime
def upgrade():
    """Create the users, products, orders and order_items tables.

    Defaults are declared with ``server_default`` so they become part of the
    table DDL. A Python-side ``default=`` on a Column passed to
    ``op.create_table`` is never emitted into the schema, so rows inserted
    outside the ORM would not receive it.
    """
    op.create_table(
        'users',
        sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
        sa.Column('username', sa.String(length=50), nullable=False),
        sa.Column('email', sa.String(length=100), nullable=False),
        sa.Column('password_hash', sa.String(length=128), nullable=False),
        # sa.true()/sa.false() render per dialect (e.g. 1/0 on SQLite,
        # true/false on PostgreSQL).
        sa.Column('is_active', sa.Boolean(), nullable=True, server_default=sa.true()),
        sa.Column('is_admin', sa.Boolean(), nullable=True, server_default=sa.false()),
        sa.Column('created_at', sa.DateTime(), nullable=True, server_default=sa.func.now()),
        sa.Column('updated_at', sa.DateTime(), nullable=True, server_default=sa.func.now()),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('username'),
        sa.UniqueConstraint('email'),
    )
    op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=False)
    op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=False)

    op.create_table(
        'products',
        sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
        sa.Column('name', sa.String(length=100), nullable=False),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('price', sa.Float(), nullable=False),
        sa.Column('category', sa.String(length=50), nullable=False),
        sa.Column('stock', sa.Integer(), nullable=True, server_default=sa.text('0')),
        sa.Column('is_available', sa.Boolean(), nullable=True, server_default=sa.true()),
        sa.Column('created_at', sa.DateTime(), nullable=True, server_default=sa.func.now()),
        sa.PrimaryKeyConstraint('id'),
    )
    op.create_index(op.f('ix_products_category'), 'products', ['category'], unique=False)

    op.create_table(
        'orders',
        sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
        sa.Column('user_id', sa.Integer(), nullable=False),
        sa.Column('order_date', sa.DateTime(), nullable=True, server_default=sa.func.now()),
        sa.Column('status', sa.String(length=20), nullable=True, server_default='pending'),
        sa.Column('total_amount', sa.Float(), nullable=True, server_default=sa.text('0')),
        sa.Column('shipping_address', sa.Text(), nullable=True),
        sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('id'),
    )

    op.create_table(
        'order_items',
        sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
        sa.Column('order_id', sa.Integer(), nullable=False),
        sa.Column('product_id', sa.Integer(), nullable=False),
        sa.Column('quantity', sa.Integer(), nullable=False),
        sa.Column('unit_price', sa.Float(), nullable=False),
        sa.ForeignKeyConstraint(['order_id'], ['orders.id'], ondelete='CASCADE'),
        # RESTRICT: a product that appears in any order cannot be deleted.
        sa.ForeignKeyConstraint(['product_id'], ['products.id'], ondelete='RESTRICT'),
        sa.PrimaryKeyConstraint('id'),
    )
def downgrade():
op.drop_table('order_items')
op.drop_table('orders')
op.drop_index(op.f('ix_products_category'), table_name='products')
op.drop_table('products')
op.drop_index(op.f('ix_users_username'), table_name='users')
op.drop_index(op.f('ix_users_email'), table_name='users')
op.drop_table('users')28.6 数据库安全实践
28.6.1 SQL注入防护
python
import re
from typing import Any
class SQLInjectionProtector:
    """Heuristic screening of free text plus strict SQL identifier checks.

    ``sanitize_input`` is a blacklist. It can reject legitimate text and miss
    real attacks, so it is only an extra layer and never a substitute for
    bound parameters. ``validate_identifier`` is the check that matters for
    table and column names, which cannot be passed as parameters.
    """

    DANGEROUS_PATTERNS = [
        r"(\b(SELECT|INSERT|UPDATE|DELETE|DROP|UNION|ALTER|CREATE|TRUNCATE)\b)",
        r"(--|#|/\*|\*/)",
        r"(;|\||`)",
        r"(\b(OR|AND)\b\s+\d+\s*=\s*\d+)",
        # Match these only as function calls. A bare substring search
        # rejected ordinary words such as "Richard" (CHAR) or "hexagon" (HEX).
        r"(\b(CONCAT|CHAR|ASCII|HEX|UNHEX)\s*\()",
    ]

    # A letter or underscore, then letters, digits or underscores.
    _IDENTIFIER_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")

    @classmethod
    def sanitize_input(cls, value: str) -> str:
        """Return *value* stripped, or raise ValueError if it looks like SQL.

        Non-string values are returned unchanged.
        """
        if not isinstance(value, str):
            return value
        for pattern in cls.DANGEROUS_PATTERNS:
            if re.search(pattern, value, re.IGNORECASE):
                raise ValueError(f"Potentially dangerous input detected: {value}")
        return value.strip()

    @classmethod
    def validate_identifier(cls, identifier: str) -> str:
        """Return *identifier* if it is a plain SQL identifier, else raise ValueError."""
        # fullmatch rather than match + "$": "$" also matches just before a
        # trailing newline, which let e.g. "users\n" through.
        if not cls._IDENTIFIER_RE.fullmatch(identifier):
            raise ValueError(f"Invalid identifier: {identifier}")
        return identifier


class SecureQueryBuilder:
    """Fluent builder for single-table SELECT/COUNT queries with ``?`` placeholders.

    Table and column names are validated by
    ``SQLInjectionProtector.validate_identifier``. Every value is returned as a
    bound parameter, so the output suits a qmark-style driver such as sqlite3:
    ``cursor.execute(query, params)``.
    """

    _VALID_OPERATORS = frozenset({"=", "!=", "<", ">", "<=", ">=", "LIKE", "IN", "IS"})
    # Keywords that may follow IS verbatim; any other IS operand is bound.
    _IS_KEYWORDS = frozenset({"NULL", "NOT NULL", "TRUE", "FALSE", "NOT TRUE", "NOT FALSE"})

    def __init__(self, table: str):
        self.table = SQLInjectionProtector.validate_identifier(table)
        self._conditions = []
        self._params = []
        self._order_by = None
        self._limit = None
        self._offset = None

    def where(self, column: str, operator: str, value: Any) -> "SecureQueryBuilder":
        """Add ``column <operator> value`` to the AND-ed WHERE clause.

        ``IN`` requires a list or tuple. ``IS`` accepts None (meaning
        ``IS NULL``), one of ``_IS_KEYWORDS`` (case-insensitive), or any other
        value, which is bound as a parameter.

        Raises:
            ValueError: invalid column name, operator, or IN operand.
        """
        column = SQLInjectionProtector.validate_identifier(column)
        operator = operator.upper()
        if operator not in self._VALID_OPERATORS:
            raise ValueError(f"Invalid operator: {operator}")
        if operator == "IN":
            if not isinstance(value, (list, tuple)):
                raise ValueError("IN operator requires a list or tuple")
            placeholders = ", ".join(["?"] * len(value))
            self._conditions.append(f"{column} IN ({placeholders})")
            self._params.extend(value)
        elif operator == "IS":
            if value is None:
                keyword = "NULL"
            elif isinstance(value, str):
                keyword = " ".join(value.upper().split())
            else:
                keyword = None
            if keyword in self._IS_KEYWORDS:
                # Whitelisted keyword: safe to emit literally.
                self._conditions.append(f"{column} IS {keyword}")
            else:
                # Never format the operand into SQL; bind it like any other
                # value (SQLite accepts a parameter on the right of IS).
                self._conditions.append(f"{column} IS ?")
                self._params.append(value)
        else:
            self._conditions.append(f"{column} {operator} ?")
            self._params.append(value)
        return self

    def order_by(self, column: str, direction: str = "ASC") -> "SecureQueryBuilder":
        """Sort by one validated column, ASC or DESC (case-insensitive)."""
        column = SQLInjectionProtector.validate_identifier(column)
        direction = direction.upper()
        if direction not in ("ASC", "DESC"):
            raise ValueError(f"Invalid direction: {direction}")
        self._order_by = f"ORDER BY {column} {direction}"
        return self

    def limit(self, limit: int) -> "SecureQueryBuilder":
        """Cap the number of rows; *limit* must be a non-negative int (not bool)."""
        if isinstance(limit, bool) or not isinstance(limit, int) or limit < 0:
            raise ValueError("Limit must be a non-negative integer")
        self._limit = f"LIMIT {limit}"
        return self

    def offset(self, offset: int) -> "SecureQueryBuilder":
        """Skip rows; *offset* must be a non-negative int (not bool)."""
        if isinstance(offset, bool) or not isinstance(offset, int) or offset < 0:
            raise ValueError("Offset must be a non-negative integer")
        self._offset = f"OFFSET {offset}"
        return self

    def _where_sql(self) -> str:
        """Return the WHERE clause (with leading space), or '' without conditions."""
        return (" WHERE " + " AND ".join(self._conditions)) if self._conditions else ""

    def build_select(self, columns: list[str] = None) -> tuple[str, list]:
        """Return ``(sql, params)`` for a SELECT of *columns* (default ``*``).

        *params* is a copy, so mutating it does not affect the builder.
        """
        if columns:
            columns = [SQLInjectionProtector.validate_identifier(c) for c in columns]
            select_clause = ", ".join(columns)
        else:
            select_clause = "*"
        query = f"SELECT {select_clause} FROM {self.table}{self._where_sql()}"
        if self._order_by:
            query += f" {self._order_by}"
        if self._limit:
            query += f" {self._limit}"
        elif self._offset:
            # SQLite only accepts OFFSET after LIMIT; a negative LIMIT means
            # "no upper bound".
            query += " LIMIT -1"
        if self._offset:
            query += f" {self._offset}"
        return query, list(self._params)

    def build_count(self) -> tuple[str, list]:
        """Return ``(sql, params)`` counting the rows matching the WHERE clause."""
        return f"SELECT COUNT(*) FROM {self.table}{self._where_sql()}", list(self._params)
def demonstrate_secure_queries():
builder = SecureQueryBuilder("users")
query, params = (
builder
.where("is_active", "=", True)
.where("created_at", ">", "2024-01-01")
.order_by("created_at", "DESC")
.limit(10)
.build_select(["id", "username", "email"])
)
print(f"Query: {query}")
print(f"Params: {params}")28.6.2 数据加密
python
import os
import base64
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from typing import Optional
class FieldEncryption:
    """Encrypt and decrypt individual string values with Fernet.

    Ciphertext is the Fernet token wrapped once more in URL-safe base64 and
    returned as ``str``, so it fits an ordinary text column. Falsy inputs
    (e.g. ``""`` or ``None``) are passed through unchanged.
    """

    def __init__(self, encryption_key: Optional[bytes] = None):
        # Without a (non-empty) key a brand-new random one is generated; data
        # encrypted by this instance can only be decrypted with that same key.
        self.key = encryption_key if encryption_key else Fernet.generate_key()
        self.cipher = Fernet(self.key)

    @classmethod
    def derive_key_from_password(cls, password: str, salt: Optional[bytes] = None) -> tuple["FieldEncryption", bytes]:
        """Build an instance from *password* using PBKDF2-HMAC-SHA256.

        A random 16-byte salt is generated when none is given. It is returned
        with the instance because the same salt is required to re-derive the
        key later.
        """
        salt = os.urandom(16) if salt is None else salt
        raw_key = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt,
            iterations=480000,
        ).derive(password.encode())
        # Fernet expects its 32-byte key in URL-safe base64 form.
        return cls(base64.urlsafe_b64encode(raw_key)), salt

    def encrypt(self, plaintext: str) -> str:
        """Return the base64 text form of the Fernet token for *plaintext*."""
        if plaintext:
            token = self.cipher.encrypt(plaintext.encode())
            return base64.urlsafe_b64encode(token).decode()
        return plaintext

    def decrypt(self, ciphertext: str) -> str:
        """Reverse ``encrypt``: unwrap the base64 text and decrypt the token."""
        if ciphertext:
            token = base64.urlsafe_b64decode(ciphertext.encode())
            return self.cipher.decrypt(token).decode()
        return ciphertext
class EncryptedField:
    """Data descriptor that keeps only ciphertext in the instance ``__dict__``.

    The ciphertext is stored under ``<attribute name>_encrypted`` and
    decrypted on every read. Falsy assignments are stored as ``None`` and read
    back as ``None``.
    """

    def __init__(self, encryption: FieldEncryption):
        self.encryption = encryption

    def __set_name__(self, owner, name):
        self.name = name

    def __get__(self, obj, objtype=None):
        # Class-level access returns the descriptor itself.
        if obj is None:
            return self
        stored = obj.__dict__.get(f"{self.name}_encrypted")
        return self.encryption.decrypt(stored) if stored else None

    def __set__(self, obj, value):
        obj.__dict__[f"{self.name}_encrypted"] = self.encryption.encrypt(value) if value else None
class SecureUser:
encryption = FieldEncryption()
id: int
username: str
email: str
credit_card: EncryptedField = EncryptedField(encryption)
ssn: EncryptedField = EncryptedField(encryption)
def __init__(self, username: str, email: str):
self.username = username
self.email = email28.7 Repository模式与数据访问层
28.7.1 Repository模式实现
python
from abc import ABC, abstractmethod
from typing import Generic, TypeVar, Optional, List, Protocol
from dataclasses import dataclass
T = TypeVar("T")  # entity type stored by a repository
ID = TypeVar("ID")  # type of the entity's identifier


class Repository(Protocol[T, ID]):
    """Structural interface for a store of ``T`` entities addressed by ``ID``.

    Any class with these four methods satisfies the protocol; no inheritance
    is required.
    """

    def find_by_id(self, id: ID) -> Optional[T]:
        """Return the entity with this identifier, or None."""

    def find_all(self) -> List[T]:
        """Return all stored entities."""

    def save(self, entity: T) -> T:
        """Persist *entity* and return the stored version."""

    def delete(self, entity: T) -> None:
        """Remove *entity* from the store."""
@dataclass
class Specification:
    """Fluent, mutable description of query criteria.

    ``conditions`` holds SQL condition fragments, and ``params`` holds the
    named parameter values they reference. ``order_by``, ``limit`` and
    ``offset`` control ordering and paging. Every builder method mutates this
    instance and returns it for chaining.
    """

    conditions: List[str]
    params: dict
    order_by: Optional[str] = None
    limit: Optional[int] = None
    offset: Optional[int] = None

    @classmethod
    def where(cls, condition: str, **params) -> "Specification":
        """Start a specification from one condition and its parameters."""
        return cls([condition], params)

    def and_where(self, condition: str, **params) -> "Specification":
        """Append another condition and merge its parameters."""
        self.conditions.append(condition)
        self.params.update(params)
        return self

    def _assign(self, field_name: str, value) -> "Specification":
        """Set one ordering/paging field and return self for chaining."""
        setattr(self, field_name, value)
        return self

    def order_by_clause(self, clause: str) -> "Specification":
        """Set the ORDER BY clause text."""
        return self._assign("order_by", clause)

    def limit_rows(self, limit: int) -> "Specification":
        """Set the maximum number of rows."""
        return self._assign("limit", limit)

    def offset_rows(self, offset: int) -> "Specification":
        """Set how many rows to skip."""
        return self._assign("offset", offset)
class UnitOfWork:
    """Context manager that commits a session on success and rolls back on error.

    Usage::

        with uow:
            ...  # work through uow.session

    On a clean exit the session is committed, unless ``commit()`` was already
    called inside that same ``with`` block. If the block raises, the session
    is rolled back and the exception propagates. The object can be reused for
    several ``with`` blocks.
    """

    def __init__(self, session: "Session"):
        self.session = session
        self._committed = False

    def commit(self) -> None:
        """Commit the session and remember that this block has committed."""
        self.session.commit()
        self._committed = True

    def rollback(self) -> None:
        """Roll back the session's current transaction."""
        self.session.rollback()

    def __enter__(self) -> "UnitOfWork":
        # Reset per block: without this, a commit in an earlier block left the
        # flag set and silently skipped the auto-commit of every later block.
        self._committed = False
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        if exc_type is not None:
            self.rollback()
            return
        if self._committed:
            return
        try:
            self.commit()
        except Exception:
            # Leave the session usable after a failed commit, then re-raise.
            self.rollback()
            raise
class ServiceLayer:
def __init__(self, session: Session):
self.session = session
self.uow = UnitOfWork(session)
@property
def users(self) -> UserRepositorySQLAlchemy:
return UserRepositorySQLAlchemy(self.session)
@property
def products(self) -> ProductRepositorySQLAlchemy:
return ProductRepositorySQLAlchemy(self.session)
@property
def orders(self) -> OrderRepositorySQLAlchemy:
return OrderRepositorySQLAlchemy(self.session)28.7.2 完整应用示例
python
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, EmailStr
from typing import List, Optional
from contextlib import contextmanager
import uvicorn
# Extracts "Authorization: Bearer <token>" credentials for get_current_user.
security = HTTPBearer()
# The application object; every route in this example is registered on it.
app = FastAPI(title="E-Commerce API")
# Request/response schemas. Response models use Pydantic v2's
# ``model_config`` with ``from_attributes`` so they can be built directly from
# ORM objects. The nested ``class Config`` form is deprecated in Pydantic v2.
# (Comments rather than docstrings: a model docstring becomes its OpenAPI
# description.)


# Request body for POST /users.
class UserCreate(BaseModel):
    username: str
    email: EmailStr
    password: str


# Public representation of a user (no password hash).
class UserResponse(BaseModel):
    id: int
    username: str
    email: str
    is_active: bool
    created_at: datetime

    model_config = {"from_attributes": True}


# Request body for creating a product.
class ProductCreate(BaseModel):
    name: str
    description: Optional[str] = None
    price: float
    category: str
    stock: int = 0


# Product as returned by the API.
class ProductResponse(BaseModel):
    id: int
    name: str
    description: Optional[str]
    price: float
    category: str
    stock: int

    model_config = {"from_attributes": True}


# One line of an order request.
class OrderItemCreate(BaseModel):
    product_id: int
    quantity: int


# Request body for POST /orders.
class OrderCreate(BaseModel):
    items: List[OrderItemCreate]
    shipping_address: str


# Order summary as returned by the API.
class OrderResponse(BaseModel):
    id: int
    user_id: int
    order_date: datetime
    status: str
    total_amount: float

    model_config = {"from_attributes": True}
def get_db():
    """FastAPI dependency: yield one session per request and always close it."""
    session = DatabaseManager.get_session()
    try:
        yield session
    finally:
        # Runs after the response is produced, including when the endpoint raised.
        session.close()
def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: Session = Depends(get_db)
) -> User:
    """Resolve the bearer token to a ``User`` or fail with 401.

    WARNING: demo-only scheme. The token is simply the user's numeric id, so
    any client can impersonate any user. Replace it with a signed token
    (e.g. a JWT) before real use.

    Raises:
        HTTPException: 401 when the token is not an integer or matches no user.
    """
    def unauthorized() -> HTTPException:
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )

    try:
        user_id = int(credentials.credentials)
    except ValueError:
        # A malformed token is an authentication failure, not a server error.
        raise unauthorized() from None
    user = UserRepositorySQLAlchemy(db).find_by_id(user_id)
    if not user:
        raise unauthorized()
    return user
@app.post("/users", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
def create_user(user_data: UserCreate, db: Session = Depends(get_db)):
    """Register a user after checking that the username and email are unused.

    NOTE(review): the checks and the insert are not atomic, so a concurrent
    request can still hit the database's unique constraints. Also, hashing
    goes through ``UserRepository`` rather than ``UserRepositorySQLAlchemy``;
    confirm that is intended.
    """
    repo = UserRepositorySQLAlchemy(db)
    duplicate_checks = (
        (repo.find_by_username, user_data.username, "Username already exists"),
        (repo.find_by_email, user_data.email, "Email already exists"),
    )
    for lookup, candidate, message in duplicate_checks:
        if lookup(candidate):
            raise HTTPException(status_code=400, detail=message)
    hashed = UserRepository._hash_password(user_data.password)
    return repo.create(
        username=user_data.username,
        email=user_data.email,
        password_hash=hashed,
    )
@app.get("/users/me", response_model=UserResponse)
def get_current_user_info(current_user: User = Depends(get_current_user)):
    """Return the authenticated user; ``UserResponse`` controls which fields are exposed."""
    authenticated = current_user
    return authenticated
@app.get("/products", response_model=List[ProductResponse])
def list_products(
    category: Optional[str] = None,
    min_price: Optional[float] = None,
    max_price: Optional[float] = None,
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db)
):
    """List products filtered by category and price range, paged by skip/limit.

    Raises:
        HTTPException: 400 when *skip* or *limit* is negative. Python slicing
            would otherwise misbehave; e.g. skip=-1 counts from the end.

    NOTE(review): search_products is not given skip/limit here, so paging is
    applied in Python after the search; push it into the query if the product
    table grows.
    """
    if skip < 0 or limit < 0:
        raise HTTPException(status_code=400, detail="skip and limit must be non-negative")
    products = ProductRepositorySQLAlchemy(db).search_products(
        keyword="",
        category=category,
        min_price=min_price,
        max_price=max_price
    )
    return products[skip:skip + limit]
@app.post("/orders", response_model=OrderResponse, status_code=status.HTTP_201_CREATED)
def create_order(
    order_data: OrderCreate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Create a pending order for the current user and decrement stock.

    Every line is validated before anything is written:
    - the product exists;
    - the quantity is positive;
    - stock covers the *total* quantity requested for that product across all
      lines, so repeating a product cannot oversell it.

    Raises:
        HTTPException: 400 for an empty order, a non-positive quantity or
            insufficient stock; 404 for an unknown product.

    NOTE(review): ``order_items`` is computed but never persisted, so the order
    is stored without its line items. TODO: save them in the same transaction.
    """
    if not order_data.items:
        raise HTTPException(status_code=400, detail="Order must contain at least one item")
    product_repo = ProductRepositorySQLAlchemy(db)
    order_repo = OrderRepositorySQLAlchemy(db)
    total_amount = 0.0
    order_items = []
    products = {}   # product_id -> product, fetched once per id
    requested = {}  # product_id -> cumulative quantity across all lines
    for item in order_data.items:
        if item.quantity <= 0:
            raise HTTPException(
                status_code=400,
                detail=f"Quantity for product {item.product_id} must be positive",
            )
        product = products.get(item.product_id)
        if product is None:
            product = product_repo.find_by_id(item.product_id)
            if not product:
                raise HTTPException(status_code=404, detail=f"Product {item.product_id} not found")
            products[item.product_id] = product
        requested[item.product_id] = requested.get(item.product_id, 0) + item.quantity
        if product.stock < requested[item.product_id]:
            raise HTTPException(status_code=400, detail=f"Insufficient stock for {product.name}")
        total_amount += product.price * item.quantity
        order_items.append({
            "product_id": item.product_id,
            "quantity": item.quantity,
            "unit_price": product.price
        })
    order = order_repo.create(
        user_id=current_user.id,
        total_amount=total_amount,
        shipping_address=order_data.shipping_address,
        status="pending"
    )
    for product_id, quantity in requested.items():
        products[product_id].stock -= quantity
    db.commit()
    return order
@app.get("/orders/{order_id}", response_model=OrderResponse)
def get_order(
    order_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return one order; only its owner or an admin may read it.

    Raises:
        HTTPException: 404 if the order does not exist, 403 if the caller is
            neither the owner nor an admin.
    """
    order = OrderRepositorySQLAlchemy(db).get_order_with_details(order_id)
    if not order:
        raise HTTPException(status_code=404, detail="Order not found")
    is_owner = order.user_id == current_user.id
    if not (is_owner or current_user.is_admin):
        raise HTTPException(status_code=403, detail="Access denied")
    return order
if __name__ == "__main__":
DatabaseManager.initialize("sqlite:///ecommerce.db")
DatabaseManager.get_engine().create_tables(Base.metadata)
uvicorn.run(app, host="0.0.0.0", port=8000)28.8 本章小结
本章详细介绍了Python数据库编程的核心概念和实践:
- 数据库基础:理解关系型数据库原理、SQL语言和设计范式
- SQLite操作:使用sqlite3模块进行轻量级数据库操作
- SQLAlchemy ORM:模型定义、关系映射、查询构建
- 连接池管理:配置、监控和优化数据库连接
- 异步数据库:使用asyncio和SQLAlchemy进行异步操作
- 数据库迁移:使用Alembic管理数据库版本
- 安全实践:SQL注入防护、数据加密、权限管理
- 架构模式:Repository模式、Unit of Work、服务层设计
练习题
- 创建一个图书管理系统,使用SQLite存储图书、作者和借阅记录
- 使用SQLAlchemy实现一个博客系统的数据模型
- 实现一个支持分页、排序和过滤的通用查询构建器
- 使用异步SQLAlchemy实现一个高并发的API服务
- 设计一个支持多租户的数据库架构