GitHub Actions
Deploy backend from GitHub Actions
806492a
raw
history blame
6.44 kB
"""
CSRF Protection Middleware for Cookie-based Authentication
This middleware implements CSRF protection using the double-submit cookie pattern
to prevent Cross-Site Request Forgery attacks when using HTTP-only cookies.
"""
import secrets
import time
from typing import Callable, Optional

from fastapi import Request, Response, HTTPException, status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
class CSRFMiddleware(BaseHTTPMiddleware):
    """
    CSRF protection middleware for cookie-based authentication.

    Implements the double-submit cookie pattern:

    1. A random CSRF token is issued to the client in a cookie.
    2. For every request whose method is not in ``safe_methods`` the client
       must echo that token in the ``header_name`` request header.
    3. The header value is compared with the expected token; a missing or
       mismatching token is answered with a 403 JSON response and the
       request is never forwarded to the application.

    The expected token is chosen as follows:

    * If ``request.state.session_id`` is set when the middleware runs, the
      token stored server-side for that session is used. The store is a plain
      in-process dict; in production use Redis or a database instead.
    * Otherwise the value of the CSRF cookie itself is the expected token.
    * If neither exists, a fresh token is generated; it cannot match anything
      the client sent, so the request is rejected, and the token is handed to
      the client so that a retry can succeed.
    """

    def __init__(
        self,
        app: Callable,
        cookie_name: str = "csrf_token",
        header_name: str = "X-CSRF-Token",
        secure: bool = True,
        httponly: bool = False,
        samesite: str = "lax",
        max_age: int = 3600,  # 1 hour
        exempt_paths: Optional[list] = None,
        safe_methods: Optional[list] = None,
    ):
        """
        Configure the middleware.

        Args:
            app: The ASGI application being wrapped.
            cookie_name: Name of the cookie that carries the CSRF token.
            header_name: Name of the request header the client must echo the
                token in (also used as a response header on unsafe requests).
            secure: ``Secure`` attribute of the CSRF cookie.
            httponly: ``HttpOnly`` attribute of the CSRF cookie.
            samesite: ``SameSite`` attribute of the CSRF cookie.
            max_age: Lifetime in seconds of the cookie and of server-side
                tokens.
            exempt_paths: Paths that bypass CSRF handling. ``None`` selects
                the defaults; an explicit empty list means "no exemptions".
            safe_methods: HTTP methods that are never validated. ``None``
                selects the defaults.
        """
        super().__init__(app)
        self.cookie_name = cookie_name
        self.header_name = header_name
        self.secure = secure
        self.httponly = httponly
        self.samesite = samesite
        self.max_age = max_age
        # Compare against None (not truthiness) so that an explicitly passed
        # empty list is honoured instead of being replaced by the defaults.
        self.exempt_paths = (
            ["/health", "/docs", "/openapi.json"]
            if exempt_paths is None
            else list(exempt_paths)
        )
        self.safe_methods = (
            ["GET", "HEAD", "OPTIONS", "TRACE"]
            if safe_methods is None
            else list(safe_methods)
        )
        # Store tokens for validation (in production, use Redis or database)
        self._tokens: dict[str, dict] = {}
        self._cleanup_interval = 300  # 5 minutes
        self._last_cleanup = time.time()

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Enforce CSRF validation and keep the client's token up to date.

        Exempt paths are forwarded untouched. Requests using a safe method
        are forwarded and, if the client's cookie is missing or stale, receive
        the CSRF cookie. All other requests are validated first and rejected
        with a 403 JSON response when validation fails.
        """
        if self._is_path_exempt(request):
            return await call_next(request)

        cookie_token = request.cookies.get(self.cookie_name)
        expected_token = self._expected_token(request, cookie_token)
        is_unsafe = request.method not in self.safe_methods

        if is_unsafe:
            try:
                await self._validate_csrf_token(request, expected_token)
            except HTTPException as exc:
                # An HTTPException raised from middleware is not seen by
                # FastAPI's exception handlers and would surface as a 500,
                # so it is turned into the intended response here.
                rejection = JSONResponse(
                    status_code=exc.status_code,
                    content={"detail": exc.detail},
                    headers=exc.headers,
                )
                # Hand the token to the client so a retry can succeed even if
                # it never made a safe request beforehand.
                self._publish_token(
                    rejection, expected_token, cookie_token, include_header=True
                )
                return rejection

        response = await call_next(request)
        self._publish_token(
            response, expected_token, cookie_token, include_header=is_unsafe
        )
        return response

    def _expected_token(self, request: Request, cookie_token: Optional[str]) -> str:
        """
        Return the token the request header must match.

        A session-bound server-side token takes precedence; without a session
        the CSRF cookie value is used; without either a new token is created.
        """
        session_id = getattr(request.state, "session_id", None)
        if session_id or not cookie_token:
            return self._get_or_generate_token(request)
        return cookie_token

    def _publish_token(
        self,
        response: Response,
        token: str,
        cookie_token: Optional[str],
        *,
        include_header: bool,
    ) -> None:
        """Expose ``token`` on ``response`` via header and/or refreshed cookie."""
        if include_header:
            # Add CSRF token to response headers for client access
            response.headers[self.header_name] = token
        if cookie_token != token:
            self._set_csrf_cookie(response, token)

    def _is_path_exempt(self, request: Request) -> bool:
        """
        Check if request path is exempt from CSRF protection.

        An exempt entry matches the path itself and anything below it
        (``/docs`` matches ``/docs`` and ``/docs/x`` but not ``/docs-admin``),
        so an exemption cannot leak onto unrelated routes that merely share a
        string prefix. An entry of ``"/"`` or ``""`` exempts every path.
        """
        path = request.url.path
        for entry in self.exempt_paths:
            base = entry.rstrip("/")
            if not base or path == base or path.startswith(base + "/"):
                return True
        return False

    def _get_or_generate_token(self, request: Request) -> str:
        """
        Get the session's stored CSRF token or generate a new one.

        When ``request.state.session_id`` is set, an unexpired stored token is
        reused and a newly generated token is stored for that session. Without
        a session id a new token is returned and nothing is stored.
        """
        # In production, store tokens in database/Redis with user_id
        session_id = getattr(request.state, "session_id", None)
        # Clean up expired tokens periodically
        self._cleanup_expired_tokens()

        now = time.time()
        if session_id:
            stored = self._tokens.get(session_id)
            if stored is not None:
                if stored["expires"] > now:
                    return stored["token"]
                del self._tokens[session_id]

        # Generate new token
        token = secrets.token_urlsafe(32)
        if session_id:
            self._tokens[session_id] = {
                "token": token,
                "expires": now + self.max_age,
            }
        return token

    def _set_csrf_cookie(self, response: Response, token: str):
        """Set CSRF token in response cookie."""
        response.set_cookie(
            key=self.cookie_name,
            value=token,
            max_age=self.max_age,
            secure=self.secure,
            httponly=self.httponly,
            samesite=self.samesite,
            path="/",
        )

    async def _validate_csrf_token(self, request: Request, expected_token: str):
        """
        Validate the CSRF token from the request header.

        Raises:
            HTTPException: 403 when the header is missing, does not match
                ``expected_token``, or the session's stored token has expired.
                ``dispatch`` converts this into a response.
        """
        # Get token from header
        token = request.headers.get(self.header_name)
        if not token:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="CSRF token missing",
                headers={"X-Error": "CSRF token required"},
            )
        # Compare as bytes: compare_digest() raises TypeError for str
        # arguments containing non-ASCII characters, which a client can send.
        if not expected_token or not secrets.compare_digest(
            token.encode("utf-8"), expected_token.encode("utf-8")
        ):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Invalid CSRF token",
                headers={"X-Error": "CSRF token validation failed"},
            )
        # Check token expiration if we have session info
        session_id = getattr(request.state, "session_id", None)
        if session_id and session_id in self._tokens:
            token_data = self._tokens[session_id]
            if token_data["expires"] <= time.time():
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="CSRF token expired",
                    headers={"X-Error": "CSRF token expired"},
                )

    def _cleanup_expired_tokens(self):
        """Drop expired server-side tokens, at most once per cleanup interval."""
        now = time.time()
        if now - self._last_cleanup <= self._cleanup_interval:
            return
        expired_sessions = [
            session_id
            for session_id, data in self._tokens.items()
            if data["expires"] <= now
        ]
        for session_id in expired_sessions:
            del self._tokens[session_id]
        self._last_cleanup = now
def get_csrf_token(request: Request) -> Optional[str]:
    """
    Get CSRF token from request headers.

    Helper function for use in route handlers.

    Returns:
        The value of the ``X-CSRF-Token`` request header, or ``None`` when
        the request does not carry that header.
    """
    return request.headers.get("X-CSRF-Token")
def validate_csrf_token(request: Request, token: str) -> bool:
    """
    Validate the ``X-CSRF-Token`` request header against the expected token.

    Helper function for use in route handlers.

    Args:
        request: Incoming request whose headers are inspected.
        token: The token the header is expected to carry.

    Returns:
        ``True`` only when the header is present, ``token`` is non-empty and
        both are equal. A missing header or missing expected token is always
        a failure, so "nothing sent" never validates against "nothing
        expected".
    """
    supplied = request.headers.get("X-CSRF-Token")
    if not supplied or not token:
        return False
    # Constant-time comparison avoids leaking how much of the token matched;
    # bytes are used because compare_digest() rejects non-ASCII str values.
    return secrets.compare_digest(supplied.encode("utf-8"), token.encode("utf-8"))