FastAPIとデータベース連携 – SQLAlchemyによるORM実装

FastAPIとデータベース連携 – SQLAlchemyによるORM実装

FastAPIでWebアプリケーションを開発する際、データベースとの連携は必須の機能です。この記事では、SQLAlchemyを使用してFastAPIでデータベース操作を実装する方法を詳しく解説します。

SQLAlchemyとは?

SQLAlchemyは、Pythonで最も人気のあるORM(Object-Relational Mapping)ライブラリです。

主な特徴

  • データベース抽象化: 複数のデータベースエンジンに対応
  • ORM機能: Pythonオブジェクトとデータベーステーブルのマッピング
  • クエリビルダー: Pythonらしい記法でSQLクエリを構築
  • マイグレーション: Alembicによるスキーマ変更管理

環境構築

必要なパッケージのインストール

# FastAPI、SQLAlchemy、データベースドライバーをインストール
pip install fastapi uvicorn[standard] sqlalchemy psycopg2-binary

# 開発環境用(SQLite使用)
pip install fastapi uvicorn[standard] sqlalchemy

プロジェクト構成

project/
├── app/
│   ├── __init__.py
│   ├── main.py              # FastAPIアプリケーション
│   ├── database.py          # データベース設定
│   ├── models.py            # SQLAlchemyモデル
│   ├── schemas.py           # Pydanticスキーマ
│   ├── crud.py              # データベース操作
│   └── dependencies.py      # 依存性注入
├── alembic/                 # マイグレーションファイル
└── requirements.txt

データベース設定

database.py – データベース接続設定

from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
import os

# データベースURL設定
DATABASE_URL = os.getenv(
    "DATABASE_URL", 
    "sqlite:///./app.db"  # 開発環境用のSQLite
)

# PostgreSQL使用例
# DATABASE_URL = "postgresql://user:password@localhost/dbname"

# エンジンの作成
engine = create_engine(
    DATABASE_URL,
    # SQLite用設定(PostgreSQL使用時は不要)
    connect_args={"check_same_thread": False} if "sqlite" in DATABASE_URL else {}
)

# セッションローカルの作成
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# ベースクラスの作成
Base = declarative_base()

# データベースセッションの取得
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

SQLAlchemyモデルの定義

models.py – データベースモデル

from sqlalchemy import Column, Integer, String, Boolean, DateTime, Text, ForeignKey, Float
from sqlalchemy.relationship import relationship
from sqlalchemy.sql import func
from .database import Base

class User(Base):
    __tablename__ = "users"

    id = Column(Integer, primary_key=True, index=True)
    username = Column(String(50), unique=True, index=True, nullable=False)
    email = Column(String(100), unique=True, index=True, nullable=False)
    hashed_password = Column(String(255), nullable=False)
    full_name = Column(String(100))
    is_active = Column(Boolean, default=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), onupdate=func.now())

    # リレーションシップ
    posts = relationship("Post", back_populates="author")

class Post(Base):
    __tablename__ = "posts"

    id = Column(Integer, primary_key=True, index=True)
    title = Column(String(200), nullable=False)
    content = Column(Text)
    published = Column(Boolean, default=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), onupdate=func.now())

    # 外部キー
    author_id = Column(Integer, ForeignKey("users.id"))

    # リレーションシップ
    author = relationship("User", back_populates="posts")
    tags = relationship("Tag", secondary="post_tags", back_populates="posts")

class Tag(Base):
    __tablename__ = "tags"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(50), unique=True, nullable=False)

    # リレーションシップ
    posts = relationship("Post", secondary="post_tags", back_populates="tags")

# 多対多のための中間テーブル
from sqlalchemy import Table

post_tags = Table(
    'post_tags',
    Base.metadata,
    Column('post_id', Integer, ForeignKey('posts.id')),
    Column('tag_id', Integer, ForeignKey('tags.id'))
)

Pydanticスキーマの定義

schemas.py – リクエスト/レスポンススキーマ

from pydantic import BaseModel, EmailStr
from typing import List, Optional
from datetime import datetime

# ユーザースキーマ
class UserBase(BaseModel):
    username: str
    email: EmailStr
    full_name: Optional[str] = None

class UserCreate(UserBase):
    password: str

class UserUpdate(BaseModel):
    username: Optional[str] = None
    email: Optional[EmailStr] = None
    full_name: Optional[str] = None
    is_active: Optional[bool] = None

class User(UserBase):
    id: int
    is_active: bool
    created_at: datetime
    updated_at: Optional[datetime] = None

    class Config:
        orm_mode = True

# 投稿スキーマ
class PostBase(BaseModel):
    title: str
    content: Optional[str] = None
    published: bool = False

class PostCreate(PostBase):
    pass

class PostUpdate(BaseModel):
    title: Optional[str] = None
    content: Optional[str] = None
    published: Optional[bool] = None

class Post(PostBase):
    id: int
    author_id: int
    created_at: datetime
    updated_at: Optional[datetime] = None
    author: User

    class Config:
        orm_mode = True

# タグスキーマ
class TagBase(BaseModel):
    name: str

class TagCreate(TagBase):
    pass

class Tag(TagBase):
    id: int

    class Config:
        orm_mode = True

# 詳細なユーザー情報(投稿含む)
class UserWithPosts(User):
    posts: List[Post] = []

CRUD操作の実装

crud.py – データベース操作

from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
from typing import List, Optional
from . import models, schemas
from passlib.context import CryptContext

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

def get_password_hash(password: str) -> str:
    return pwd_context.hash(password)

def verify_password(plain_password: str, hashed_password: str) -> bool:
    return pwd_context.verify(plain_password, hashed_password)

# ユーザーCRUD操作
class UserCRUD:
    @staticmethod
    def get_user(db: Session, user_id: int) -> Optional[models.User]:
        return db.query(models.User).filter(models.User.id == user_id).first()

    @staticmethod
    def get_user_by_email(db: Session, email: str) -> Optional[models.User]:
        return db.query(models.User).filter(models.User.email == email).first()

    @staticmethod
    def get_user_by_username(db: Session, username: str) -> Optional[models.User]:
        return db.query(models.User).filter(models.User.username == username).first()

    @staticmethod
    def get_users(db: Session, skip: int = 0, limit: int = 100) -> List[models.User]:
        return db.query(models.User).offset(skip).limit(limit).all()

    @staticmethod
    def create_user(db: Session, user: schemas.UserCreate) -> models.User:
        hashed_password = get_password_hash(user.password)
        db_user = models.User(
            username=user.username,
            email=user.email,
            hashed_password=hashed_password,
            full_name=user.full_name
        )
        db.add(db_user)
        db.commit()
        db.refresh(db_user)
        return db_user

    @staticmethod
    def update_user(db: Session, user_id: int, user_update: schemas.UserUpdate) -> Optional[models.User]:
        db_user = UserCRUD.get_user(db, user_id)
        if db_user:
            update_data = user_update.dict(exclude_unset=True)
            for field, value in update_data.items():
                setattr(db_user, field, value)
            db.commit()
            db.refresh(db_user)
        return db_user

    @staticmethod
    def delete_user(db: Session, user_id: int) -> bool:
        db_user = UserCRUD.get_user(db, user_id)
        if db_user:
            db.delete(db_user)
            db.commit()
            return True
        return False

# 投稿CRUD操作
class PostCRUD:
    @staticmethod
    def get_post(db: Session, post_id: int) -> Optional[models.Post]:
        return db.query(models.Post).filter(models.Post.id == post_id).first()

    @staticmethod
    def get_posts(db: Session, skip: int = 0, limit: int = 100, published_only: bool = False) -> List[models.Post]:
        query = db.query(models.Post)
        if published_only:
            query = query.filter(models.Post.published == True)
        return query.offset(skip).limit(limit).all()

    @staticmethod
    def get_posts_by_user(db: Session, user_id: int, skip: int = 0, limit: int = 100) -> List[models.Post]:
        return db.query(models.Post).filter(
            models.Post.author_id == user_id
        ).offset(skip).limit(limit).all()

    @staticmethod
    def create_post(db: Session, post: schemas.PostCreate, author_id: int) -> models.Post:
        db_post = models.Post(**post.dict(), author_id=author_id)
        db.add(db_post)
        db.commit()
        db.refresh(db_post)
        return db_post

    @staticmethod
    def update_post(db: Session, post_id: int, post_update: schemas.PostUpdate) -> Optional[models.Post]:
        db_post = PostCRUD.get_post(db, post_id)
        if db_post:
            update_data = post_update.dict(exclude_unset=True)
            for field, value in update_data.items():
                setattr(db_post, field, value)
            db.commit()
            db.refresh(db_post)
        return db_post

    @staticmethod
    def delete_post(db: Session, post_id: int) -> bool:
        db_post = PostCRUD.get_post(db, post_id)
        if db_post:
            db.delete(db_post)
            db.commit()
            return True
        return False

# 検索機能
class SearchCRUD:
    @staticmethod
    def search_posts(db: Session, query: str, skip: int = 0, limit: int = 100) -> List[models.Post]:
        return db.query(models.Post).filter(
            or_(
                models.Post.title.contains(query),
                models.Post.content.contains(query)
            )
        ).offset(skip).limit(limit).all()

FastAPIエンドポイントの実装

main.py – APIエンドポイント

from fastapi import FastAPI, Depends, HTTPException, status
from sqlalchemy.orm import Session
from typing import List
from . import models, schemas, crud
from .database import SessionLocal, engine, get_db

# データベーステーブルの作成
models.Base.metadata.create_all(bind=engine)

app = FastAPI(title="ブログAPI", version="1.0.0")

# ユーザーエンドポイント
@app.post("/users/", response_model=schemas.User, status_code=status.HTTP_201_CREATED)
async def create_user(user: schemas.UserCreate, db: Session = Depends(get_db)):
    # 既存ユーザーチェック
    db_user = crud.UserCRUD.get_user_by_email(db, email=user.email)
    if db_user:
        raise HTTPException(
            status_code=400,
            detail="このメールアドレスは既に登録されています"
        )

    db_user = crud.UserCRUD.get_user_by_username(db, username=user.username)
    if db_user:
        raise HTTPException(
            status_code=400,
            detail="このユーザー名は既に使用されています"
        )

    return crud.UserCRUD.create_user(db=db, user=user)

@app.get("/users/", response_model=List[schemas.User])
async def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
    users = crud.UserCRUD.get_users(db, skip=skip, limit=limit)
    return users

@app.get("/users/{user_id}", response_model=schemas.UserWithPosts)
async def read_user(user_id: int, db: Session = Depends(get_db)):
    db_user = crud.UserCRUD.get_user(db, user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="ユーザーが見つかりません")
    return db_user

@app.put("/users/{user_id}", response_model=schemas.User)
async def update_user(user_id: int, user_update: schemas.UserUpdate, db: Session = Depends(get_db)):
    db_user = crud.UserCRUD.update_user(db, user_id=user_id, user_update=user_update)
    if db_user is None:
        raise HTTPException(status_code=404, detail="ユーザーが見つかりません")
    return db_user

@app.delete("/users/{user_id}")
async def delete_user(user_id: int, db: Session = Depends(get_db)):
    success = crud.UserCRUD.delete_user(db, user_id=user_id)
    if not success:
        raise HTTPException(status_code=404, detail="ユーザーが見つかりません")
    return {"message": "ユーザーが削除されました"}

# 投稿エンドポイント
@app.post("/users/{user_id}/posts/", response_model=schemas.Post, status_code=status.HTTP_201_CREATED)
async def create_post_for_user(user_id: int, post: schemas.PostCreate, db: Session = Depends(get_db)):
    # ユーザーの存在確認
    db_user = crud.UserCRUD.get_user(db, user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="ユーザーが見つかりません")

    return crud.PostCRUD.create_post(db=db, post=post, author_id=user_id)

@app.get("/posts/", response_model=List[schemas.Post])
async def read_posts(skip: int = 0, limit: int = 100, published_only: bool = False, db: Session = Depends(get_db)):
    posts = crud.PostCRUD.get_posts(db, skip=skip, limit=limit, published_only=published_only)
    return posts

@app.get("/posts/{post_id}", response_model=schemas.Post)
async def read_post(post_id: int, db: Session = Depends(get_db)):
    db_post = crud.PostCRUD.get_post(db, post_id=post_id)
    if db_post is None:
        raise HTTPException(status_code=404, detail="投稿が見つかりません")
    return db_post

@app.put("/posts/{post_id}", response_model=schemas.Post)
async def update_post(post_id: int, post_update: schemas.PostUpdate, db: Session = Depends(get_db)):
    db_post = crud.PostCRUD.update_post(db, post_id=post_id, post_update=post_update)
    if db_post is None:
        raise HTTPException(status_code=404, detail="投稿が見つかりません")
    return db_post

@app.delete("/posts/{post_id}")
async def delete_post(post_id: int, db: Session = Depends(get_db)):
    success = crud.PostCRUD.delete_post(db, post_id=post_id)
    if not success:
        raise HTTPException(status_code=404, detail="投稿が見つかりません")
    return {"message": "投稿が削除されました"}

# 検索エンドポイント
@app.get("/search/posts/", response_model=List[schemas.Post])
async def search_posts(q: str, skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
    posts = crud.SearchCRUD.search_posts(db, query=q, skip=skip, limit=limit)
    return posts

マイグレーション管理(Alembic)

Alembicの設定

# Alembicのインストール
pip install alembic

# Alembicの初期化
alembic init alembic

alembic.ini の設定

# alembic.ini
[alembic]
script_location = alembic
sqlalchemy.url = sqlite:///./app.db

# PostgreSQL使用例
# sqlalchemy.url = postgresql://user:password@localhost/dbname

env.py の設定

# alembic/env.py
from alembic import context
from sqlalchemy import engine_from_config, pool
from logging.config import fileConfig

# アプリケーションのモデルをインポート
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from app.models import Base
from app.database import DATABASE_URL

config = context.config
config.set_main_option("sqlalchemy.url", DATABASE_URL)

fileConfig(config.config_file_name)
target_metadata = Base.metadata

def run_migrations_offline():
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()

def run_migrations_online():
    connectable = engine_from_config(
        config.get_section(config.config_ini_section),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(
            connection=connection, target_metadata=target_metadata
        )

        with context.begin_transaction():
            context.run_migrations()

if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()

マイグレーションの実行

# 初回マイグレーションファイルの作成
alembic revision --autogenerate -m "Initial migration"

# マイグレーションの実行
alembic upgrade head

# 特定のリビジョンへの移行
alembic upgrade revision_id

# 一つ前のバージョンに戻す
alembic downgrade -1

高度なクエリ操作

複雑なクエリの例

from sqlalchemy import func, desc, and_, or_

# 集計クエリ
def get_user_post_count(db: Session):
    return db.query(
        models.User.username,
        func.count(models.Post.id).label('post_count')
    ).join(models.Post).group_by(models.User.id).all()

# サブクエリ
def get_users_with_published_posts(db: Session):
    subquery = db.query(models.Post.author_id).filter(
        models.Post.published == True
    ).distinct().subquery()

    return db.query(models.User).filter(
        models.User.id.in_(subquery)
    ).all()

# 複雑な検索
def advanced_post_search(db: Session, title: str = None, author_name: str = None, 
                        published: bool = None, skip: int = 0, limit: int = 100):
    query = db.query(models.Post).join(models.User)

    if title:
        query = query.filter(models.Post.title.contains(title))
    if author_name:
        query = query.filter(models.User.username.contains(author_name))
    if published is not None:
        query = query.filter(models.Post.published == published)

    return query.order_by(desc(models.Post.created_at)).offset(skip).limit(limit).all()

パフォーマンス最適化

接続プールの設定

# database.py
from sqlalchemy import create_engine
from sqlalchemy.pool import QueuePool

engine = create_engine(
    DATABASE_URL,
    poolclass=QueuePool,
    pool_size=10,          # 通常の接続数
    max_overflow=20,       # 最大接続数
    pool_pre_ping=True,    # 接続の有効性チェック
    pool_recycle=3600,     # 接続の再利用時間(秒)
)

遅延読み込みの最適化

from sqlalchemy.orm import joinedload, selectinload

# N+1問題の解決
def get_users_with_posts_optimized(db: Session):
    return db.query(models.User).options(
        joinedload(models.User.posts)
    ).all()

# 大量データの場合はselectinloadを使用
def get_users_with_many_posts(db: Session):
    return db.query(models.User).options(
        selectinload(models.User.posts)
    ).all()

まとめ

FastAPIとSQLAlchemyを組み合わせることで、以下の利点があります:

  1. 型安全性: PydanticとSQLAlchemyの型定義による安全なデータ処理
  2. 自動ドキュメント: スキーマ定義から自動的にAPI仕様書を生成
  3. スケーラビリティ: 接続プールやクエリ最適化による高性能
  4. 保守性: 明確な関心事の分離による保守しやすいコード

次回は、FastAPIでの認証・セキュリティ実装について詳しく解説します。JWT認証、OAuth2、セキュリティヘッダーの設定など、本格的なWebアプリケーションに必要なセキュリティ機能を学んでいきましょう。

参考リンク