Initial commit - cleaned repository

This commit is contained in:
jungwoo choi
2025-09-28 20:41:57 +09:00
commit e3c28f796a
188 changed files with 28102 additions and 0 deletions

View File

@ -0,0 +1,27 @@
FROM python:3.11-slim
WORKDIR /app
# Install system dependencies for Pillow and file type detection
RUN apt-get update && apt-get install -y \
gcc \
libmagic1 \
libjpeg-dev \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements first for better caching
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY . .
# Create directories for thumbnails cache
RUN mkdir -p /tmp/thumbnails
# Expose port
EXPOSE 8000
# Run the application
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]

View File

@ -0,0 +1,247 @@
"""
File Processor for handling file uploads and processing
"""
import hashlib
import mimetypes
from datetime import datetime
from typing import Dict, Any, Optional
import logging
import uuid
from fastapi import UploadFile
from models import FileType, FileStatus
logger = logging.getLogger(__name__)
class FileProcessor:
def __init__(self, minio_client, metadata_manager, thumbnail_generator):
self.minio_client = minio_client
self.metadata_manager = metadata_manager
self.thumbnail_generator = thumbnail_generator
def _determine_file_type(self, content_type: str) -> FileType:
"""Determine file type from content type"""
if content_type.startswith('image/'):
return FileType.IMAGE
elif content_type.startswith('video/'):
return FileType.VIDEO
elif content_type.startswith('audio/'):
return FileType.AUDIO
elif content_type in ['application/pdf', 'application/msword',
'application/vnd.openxmlformats-officedocument',
'text/plain', 'text/html', 'text/csv']:
return FileType.DOCUMENT
elif content_type in ['application/zip', 'application/x-rar-compressed',
'application/x-tar', 'application/gzip']:
return FileType.ARCHIVE
else:
return FileType.OTHER
def _calculate_file_hash(self, file_data: bytes) -> str:
"""Calculate SHA256 hash of file data"""
return hashlib.sha256(file_data).hexdigest()
async def process_upload(self, file: UploadFile, user_id: str,
bucket: str = "default",
public: bool = False,
generate_thumbnail: bool = True,
tags: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""Process file upload"""
try:
# Read file data
file_data = await file.read()
file_size = len(file_data)
# Get content type
content_type = file.content_type or mimetypes.guess_type(file.filename)[0] or 'application/octet-stream'
# Generate file ID and object name
file_id = str(uuid.uuid4())
timestamp = datetime.now().strftime('%Y%m%d')
file_extension = file.filename.split('.')[-1] if '.' in file.filename else ''
object_name = f"{timestamp}/{user_id}/{file_id}.{file_extension}" if file_extension else f"{timestamp}/{user_id}/{file_id}"
# Calculate file hash
file_hash = self._calculate_file_hash(file_data)
# Check for duplicates
duplicates = await self.metadata_manager.find_duplicate_files(file_hash)
if duplicates and not public: # Allow duplicates for public files
# Return existing file info
existing = duplicates[0]
logger.info(f"Duplicate file detected: {existing['id']}")
return {
"file_id": existing["id"],
"filename": existing["filename"],
"size": existing["size"],
"content_type": existing["content_type"],
"file_type": existing["file_type"],
"bucket": existing["bucket"],
"public": existing["public"],
"has_thumbnail": existing.get("has_thumbnail", False),
"thumbnail_url": existing.get("thumbnail_url"),
"created_at": existing["created_at"],
"duplicate": True
}
# Upload to MinIO
upload_result = await self.minio_client.upload_file(
bucket=bucket,
object_name=object_name,
file_data=file_data,
content_type=content_type,
metadata={
"user_id": user_id,
"original_name": file.filename,
"upload_date": datetime.now().isoformat()
}
)
# Determine file type
file_type = self._determine_file_type(content_type)
# Generate thumbnail if applicable
has_thumbnail = False
thumbnail_url = None
if generate_thumbnail and file_type == FileType.IMAGE:
thumbnail_data = await self.thumbnail_generator.generate_thumbnail(
file_data=file_data,
content_type=content_type
)
if thumbnail_data:
has_thumbnail = True
# Generate multiple sizes
await self.thumbnail_generator.generate_multiple_sizes(
file_data=file_data,
content_type=content_type,
file_id=file_id
)
if public:
thumbnail_url = await self.minio_client.generate_presigned_download_url(
bucket="thumbnails",
object_name=f"thumbnails/{file_id}_medium.jpg",
expires_in=86400 * 30 # 30 days
)
# Create metadata
metadata = {
"id": file_id,
"filename": file.filename,
"original_name": file.filename,
"size": file_size,
"content_type": content_type,
"file_type": file_type.value,
"bucket": bucket,
"object_name": object_name,
"user_id": user_id,
"hash": file_hash,
"public": public,
"has_thumbnail": has_thumbnail,
"thumbnail_url": thumbnail_url,
"tags": tags or {},
"metadata": {
"etag": upload_result.get("etag"),
"version_id": upload_result.get("version_id")
}
}
# Save metadata to database
await self.metadata_manager.create_file_metadata(metadata)
# Generate download URL if public
download_url = None
if public:
download_url = await self.minio_client.generate_presigned_download_url(
bucket=bucket,
object_name=object_name,
expires_in=86400 * 30 # 30 days
)
logger.info(f"File uploaded successfully: {file_id}")
return {
"file_id": file_id,
"filename": file.filename,
"size": file_size,
"content_type": content_type,
"file_type": file_type.value,
"bucket": bucket,
"public": public,
"has_thumbnail": has_thumbnail,
"thumbnail_url": thumbnail_url,
"download_url": download_url,
"created_at": datetime.now()
}
except Exception as e:
logger.error(f"File processing error: {e}")
raise
async def process_large_file(self, file: UploadFile, user_id: str,
bucket: str = "default",
chunk_size: int = 1024 * 1024 * 5) -> Dict[str, Any]:
"""Process large file upload in chunks"""
try:
file_id = str(uuid.uuid4())
timestamp = datetime.now().strftime('%Y%m%d')
file_extension = file.filename.split('.')[-1] if '.' in file.filename else ''
object_name = f"{timestamp}/{user_id}/{file_id}.{file_extension}"
# Initialize multipart upload
hasher = hashlib.sha256()
total_size = 0
# Process file in chunks
chunks = []
while True:
chunk = await file.read(chunk_size)
if not chunk:
break
chunks.append(chunk)
hasher.update(chunk)
total_size += len(chunk)
# Combine chunks and upload
file_data = b''.join(chunks)
file_hash = hasher.hexdigest()
# Upload to MinIO
content_type = file.content_type or 'application/octet-stream'
await self.minio_client.upload_file(
bucket=bucket,
object_name=object_name,
file_data=file_data,
content_type=content_type
)
# Create metadata
metadata = {
"id": file_id,
"filename": file.filename,
"original_name": file.filename,
"size": total_size,
"content_type": content_type,
"file_type": self._determine_file_type(content_type).value,
"bucket": bucket,
"object_name": object_name,
"user_id": user_id,
"hash": file_hash,
"public": False,
"has_thumbnail": False
}
await self.metadata_manager.create_file_metadata(metadata)
return {
"file_id": file_id,
"filename": file.filename,
"size": total_size,
"message": "Large file uploaded successfully"
}
except Exception as e:
logger.error(f"Large file processing error: {e}")
raise

View File

@ -0,0 +1,541 @@
"""
File Management Service - S3-compatible Object Storage with MinIO
"""
from fastapi import FastAPI, File, UploadFile, HTTPException, Depends, Query, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse
import uvicorn
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any
import asyncio
import os
import hashlib
import magic
import io
from contextlib import asynccontextmanager
import logging
from pathlib import Path
import json
# Import custom modules
from models import FileMetadata, FileUploadResponse, FileListResponse, StorageStats
from minio_client import MinIOManager
from thumbnail_generator import ThumbnailGenerator
from metadata_manager import MetadataManager
from file_processor import FileProcessor
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global instances
minio_manager = None
thumbnail_generator = None
metadata_manager = None
file_processor = None
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
global minio_manager, thumbnail_generator, metadata_manager, file_processor
try:
# Initialize MinIO client
minio_manager = MinIOManager(
endpoint=os.getenv("MINIO_ENDPOINT", "minio:9000"),
access_key=os.getenv("MINIO_ACCESS_KEY", "minioadmin"),
secret_key=os.getenv("MINIO_SECRET_KEY", "minioadmin"),
secure=os.getenv("MINIO_SECURE", "false").lower() == "true"
)
await minio_manager.initialize()
logger.info("MinIO client initialized")
# Initialize Metadata Manager (MongoDB)
metadata_manager = MetadataManager(
mongodb_url=os.getenv("MONGODB_URL", "mongodb://mongodb:27017"),
database=os.getenv("FILES_DB_NAME", "files_db")
)
await metadata_manager.connect()
logger.info("Metadata manager connected to MongoDB")
# Initialize Thumbnail Generator
thumbnail_generator = ThumbnailGenerator(
minio_client=minio_manager,
cache_dir="/tmp/thumbnails"
)
logger.info("Thumbnail generator initialized")
# Initialize File Processor
file_processor = FileProcessor(
minio_client=minio_manager,
metadata_manager=metadata_manager,
thumbnail_generator=thumbnail_generator
)
logger.info("File processor initialized")
except Exception as e:
logger.error(f"Failed to start File service: {e}")
raise
yield
# Shutdown
if metadata_manager:
await metadata_manager.close()
logger.info("File service shutdown complete")
app = FastAPI(
title="File Management Service",
description="S3-compatible object storage with MinIO",
version="1.0.0",
lifespan=lifespan
)
# CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
async def root():
return {
"service": "File Management Service",
"status": "running",
"timestamp": datetime.now().isoformat()
}
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"service": "files",
"components": {
"minio": "connected" if minio_manager and minio_manager.is_connected else "disconnected",
"mongodb": "connected" if metadata_manager and metadata_manager.is_connected else "disconnected",
"thumbnail_generator": "ready" if thumbnail_generator else "not_initialized"
},
"timestamp": datetime.now().isoformat()
}
# File Upload Endpoints
@app.post("/api/files/upload")
async def upload_file(
file: UploadFile = File(...),
user_id: str = Form(...),
bucket: str = Form("default"),
public: bool = Form(False),
generate_thumbnail: bool = Form(True),
tags: Optional[str] = Form(None)
):
"""Upload a file to object storage"""
try:
# Validate file
if not file.filename:
raise HTTPException(status_code=400, detail="No file provided")
# Process file upload
result = await file_processor.process_upload(
file=file,
user_id=user_id,
bucket=bucket,
public=public,
generate_thumbnail=generate_thumbnail,
tags=json.loads(tags) if tags else {}
)
return FileUploadResponse(**result)
except Exception as e:
logger.error(f"File upload error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/files/upload-multiple")
async def upload_multiple_files(
files: List[UploadFile] = File(...),
user_id: str = Form(...),
bucket: str = Form("default"),
public: bool = Form(False)
):
"""Upload multiple files"""
try:
results = []
for file in files:
result = await file_processor.process_upload(
file=file,
user_id=user_id,
bucket=bucket,
public=public,
generate_thumbnail=True
)
results.append(result)
return {
"uploaded": len(results),
"files": results
}
except Exception as e:
logger.error(f"Multiple file upload error: {e}")
raise HTTPException(status_code=500, detail=str(e))
# File Retrieval Endpoints
@app.get("/api/files/{file_id}")
async def get_file(file_id: str):
"""Get file by ID"""
try:
# Get metadata
metadata = await metadata_manager.get_file_metadata(file_id)
if not metadata:
raise HTTPException(status_code=404, detail="File not found")
# Get file from MinIO
file_stream = await minio_manager.get_file(
bucket=metadata["bucket"],
object_name=metadata["object_name"]
)
return StreamingResponse(
file_stream,
media_type=metadata.get("content_type", "application/octet-stream"),
headers={
"Content-Disposition": f'attachment; filename="{metadata["filename"]}"'
}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"File retrieval error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/files/{file_id}/metadata")
async def get_file_metadata(file_id: str):
"""Get file metadata"""
try:
metadata = await metadata_manager.get_file_metadata(file_id)
if not metadata:
raise HTTPException(status_code=404, detail="File not found")
return FileMetadata(**metadata)
except HTTPException:
raise
except Exception as e:
logger.error(f"Metadata retrieval error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/files/{file_id}/thumbnail")
async def get_thumbnail(
file_id: str,
width: int = Query(200, ge=50, le=1000),
height: int = Query(200, ge=50, le=1000)
):
"""Get file thumbnail"""
try:
# Get metadata
metadata = await metadata_manager.get_file_metadata(file_id)
if not metadata:
raise HTTPException(status_code=404, detail="File not found")
# Check if file has thumbnail
if not metadata.get("has_thumbnail"):
raise HTTPException(status_code=404, detail="No thumbnail available")
# Get or generate thumbnail
thumbnail = await thumbnail_generator.get_thumbnail(
file_id=file_id,
bucket=metadata["bucket"],
object_name=metadata["object_name"],
width=width,
height=height
)
return StreamingResponse(
io.BytesIO(thumbnail),
media_type="image/jpeg"
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Thumbnail retrieval error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/files/{file_id}/download")
async def download_file(file_id: str):
"""Download file with proper headers"""
try:
# Get metadata
metadata = await metadata_manager.get_file_metadata(file_id)
if not metadata:
raise HTTPException(status_code=404, detail="File not found")
# Update download count
await metadata_manager.increment_download_count(file_id)
# Get file from MinIO
file_stream = await minio_manager.get_file(
bucket=metadata["bucket"],
object_name=metadata["object_name"]
)
return StreamingResponse(
file_stream,
media_type=metadata.get("content_type", "application/octet-stream"),
headers={
"Content-Disposition": f'attachment; filename="{metadata["filename"]}"',
"Content-Length": str(metadata["size"])
}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"File download error: {e}")
raise HTTPException(status_code=500, detail=str(e))
# File Management Endpoints
@app.delete("/api/files/{file_id}")
async def delete_file(file_id: str, user_id: str):
"""Delete a file"""
try:
# Get metadata
metadata = await metadata_manager.get_file_metadata(file_id)
if not metadata:
raise HTTPException(status_code=404, detail="File not found")
# Check ownership
if metadata["user_id"] != user_id:
raise HTTPException(status_code=403, detail="Permission denied")
# Delete from MinIO
await minio_manager.delete_file(
bucket=metadata["bucket"],
object_name=metadata["object_name"]
)
# Delete thumbnail if exists
if metadata.get("has_thumbnail"):
await thumbnail_generator.delete_thumbnail(file_id)
# Delete metadata
await metadata_manager.delete_file_metadata(file_id)
return {"status": "deleted", "file_id": file_id}
except HTTPException:
raise
except Exception as e:
logger.error(f"File deletion error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.patch("/api/files/{file_id}")
async def update_file_metadata(
file_id: str,
user_id: str,
updates: Dict[str, Any]
):
"""Update file metadata"""
try:
# Get existing metadata
metadata = await metadata_manager.get_file_metadata(file_id)
if not metadata:
raise HTTPException(status_code=404, detail="File not found")
# Check ownership
if metadata["user_id"] != user_id:
raise HTTPException(status_code=403, detail="Permission denied")
# Update metadata
updated = await metadata_manager.update_file_metadata(file_id, updates)
return {"status": "updated", "file_id": file_id, "metadata": updated}
except HTTPException:
raise
except Exception as e:
logger.error(f"Metadata update error: {e}")
raise HTTPException(status_code=500, detail=str(e))
# File Listing Endpoints
@app.get("/api/files")
async def list_files(
user_id: Optional[str] = None,
bucket: str = Query("default"),
limit: int = Query(20, le=100),
offset: int = Query(0),
search: Optional[str] = None,
file_type: Optional[str] = None,
sort_by: str = Query("created_at", pattern="^(created_at|filename|size)$"),
order: str = Query("desc", pattern="^(asc|desc)$")
):
"""List files with filtering and pagination"""
try:
files = await metadata_manager.list_files(
user_id=user_id,
bucket=bucket,
limit=limit,
offset=offset,
search=search,
file_type=file_type,
sort_by=sort_by,
order=order
)
return FileListResponse(**files)
except Exception as e:
logger.error(f"File listing error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/files/user/{user_id}")
async def get_user_files(
user_id: str,
limit: int = Query(20, le=100),
offset: int = Query(0)
):
"""Get all files for a specific user"""
try:
files = await metadata_manager.list_files(
user_id=user_id,
limit=limit,
offset=offset
)
return FileListResponse(**files)
except Exception as e:
logger.error(f"User files listing error: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Storage Management Endpoints
@app.get("/api/storage/stats")
async def get_storage_stats():
"""Get storage statistics"""
try:
stats = await minio_manager.get_storage_stats()
db_stats = await metadata_manager.get_storage_stats()
return StorageStats(
total_files=db_stats["total_files"],
total_size=db_stats["total_size"],
buckets=stats["buckets"],
users_count=db_stats["users_count"],
file_types=db_stats["file_types"]
)
except Exception as e:
logger.error(f"Storage stats error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/storage/buckets")
async def create_bucket(bucket_name: str, public: bool = False):
"""Create a new storage bucket"""
try:
await minio_manager.create_bucket(bucket_name, public=public)
return {"status": "created", "bucket": bucket_name}
except Exception as e:
logger.error(f"Bucket creation error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/storage/buckets")
async def list_buckets():
"""List all storage buckets"""
try:
buckets = await minio_manager.list_buckets()
return {"buckets": buckets}
except Exception as e:
logger.error(f"Bucket listing error: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Presigned URL Endpoints
@app.post("/api/files/presigned-upload")
async def generate_presigned_upload_url(
filename: str,
content_type: str,
bucket: str = "default",
expires_in: int = Query(3600, ge=60, le=86400)
):
"""Generate presigned URL for direct upload to MinIO"""
try:
url = await minio_manager.generate_presigned_upload_url(
bucket=bucket,
object_name=f"{datetime.now().strftime('%Y%m%d')}/{filename}",
expires_in=expires_in
)
return {
"upload_url": url,
"expires_in": expires_in,
"method": "PUT"
}
except Exception as e:
logger.error(f"Presigned URL generation error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/files/{file_id}/share")
async def generate_share_link(
file_id: str,
expires_in: int = Query(86400, ge=60, le=604800) # 1 day default, max 7 days
):
"""Generate a shareable link for a file"""
try:
# Get metadata
metadata = await metadata_manager.get_file_metadata(file_id)
if not metadata:
raise HTTPException(status_code=404, detail="File not found")
# Generate presigned URL
url = await minio_manager.generate_presigned_download_url(
bucket=metadata["bucket"],
object_name=metadata["object_name"],
expires_in=expires_in
)
return {
"share_url": url,
"expires_in": expires_in,
"file_id": file_id,
"filename": metadata["filename"]
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Share link generation error: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Batch Operations
@app.post("/api/files/batch-delete")
async def batch_delete_files(file_ids: List[str], user_id: str):
"""Delete multiple files at once"""
try:
deleted = []
errors = []
for file_id in file_ids:
try:
# Get metadata
metadata = await metadata_manager.get_file_metadata(file_id)
if metadata and metadata["user_id"] == user_id:
# Delete from MinIO
await minio_manager.delete_file(
bucket=metadata["bucket"],
object_name=metadata["object_name"]
)
# Delete metadata
await metadata_manager.delete_file_metadata(file_id)
deleted.append(file_id)
else:
errors.append({"file_id": file_id, "error": "Not found or permission denied"})
except Exception as e:
errors.append({"file_id": file_id, "error": str(e)})
return {
"deleted": deleted,
"errors": errors,
"total_deleted": len(deleted)
}
except Exception as e:
logger.error(f"Batch delete error: {e}")
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
uvicorn.run(
"main:app",
host="0.0.0.0",
port=8000,
reload=True
)

View File

@ -0,0 +1,331 @@
"""
Metadata Manager for file information storage in MongoDB
"""
from motor.motor_asyncio import AsyncIOMotorClient
from datetime import datetime
from typing import Optional, Dict, Any, List
import logging
import uuid
from models import FileType, FileStatus
logger = logging.getLogger(__name__)
class MetadataManager:
def __init__(self, mongodb_url: str, database: str = "files_db"):
self.mongodb_url = mongodb_url
self.database_name = database
self.client = None
self.db = None
self.collection = None
self.is_connected = False
async def connect(self):
"""Connect to MongoDB"""
try:
self.client = AsyncIOMotorClient(self.mongodb_url)
self.db = self.client[self.database_name]
self.collection = self.db.files
# Create indexes
await self._create_indexes()
# Test connection
await self.client.admin.command('ping')
self.is_connected = True
logger.info(f"Connected to MongoDB at {self.mongodb_url}")
except Exception as e:
logger.error(f"Failed to connect to MongoDB: {e}")
self.is_connected = False
raise
async def _create_indexes(self):
"""Create database indexes for better performance"""
try:
# Create indexes
await self.collection.create_index("user_id")
await self.collection.create_index("bucket")
await self.collection.create_index("created_at")
await self.collection.create_index("file_type")
await self.collection.create_index([("filename", "text")])
await self.collection.create_index([("user_id", 1), ("created_at", -1)])
logger.info("Database indexes created")
except Exception as e:
logger.error(f"Failed to create indexes: {e}")
async def create_file_metadata(self, metadata: Dict[str, Any]) -> str:
"""Create new file metadata"""
try:
# Add timestamps
metadata["created_at"] = datetime.now()
metadata["updated_at"] = datetime.now()
metadata["download_count"] = 0
metadata["status"] = FileStatus.READY.value
# Generate unique ID if not provided
if "id" not in metadata:
metadata["id"] = str(uuid.uuid4())
# Insert document
result = await self.collection.insert_one(metadata)
logger.info(f"Created metadata for file: {metadata['id']}")
return metadata["id"]
except Exception as e:
logger.error(f"Failed to create file metadata: {e}")
raise
async def get_file_metadata(self, file_id: str) -> Optional[Dict[str, Any]]:
"""Get file metadata by ID"""
try:
metadata = await self.collection.find_one({"id": file_id})
if metadata:
# Remove MongoDB's _id field
metadata.pop("_id", None)
return metadata
except Exception as e:
logger.error(f"Failed to get file metadata: {e}")
raise
async def update_file_metadata(self, file_id: str, updates: Dict[str, Any]) -> Dict[str, Any]:
"""Update file metadata"""
try:
# Add update timestamp
updates["updated_at"] = datetime.now()
# Update document
result = await self.collection.update_one(
{"id": file_id},
{"$set": updates}
)
if result.modified_count == 0:
raise Exception(f"File {file_id} not found")
# Return updated metadata
return await self.get_file_metadata(file_id)
except Exception as e:
logger.error(f"Failed to update file metadata: {e}")
raise
async def delete_file_metadata(self, file_id: str) -> bool:
"""Delete file metadata (soft delete)"""
try:
# Soft delete by marking as deleted
updates = {
"status": FileStatus.DELETED.value,
"deleted_at": datetime.now(),
"updated_at": datetime.now()
}
result = await self.collection.update_one(
{"id": file_id},
{"$set": updates}
)
return result.modified_count > 0
except Exception as e:
logger.error(f"Failed to delete file metadata: {e}")
raise
async def list_files(self, user_id: Optional[str] = None,
bucket: Optional[str] = None,
limit: int = 20,
offset: int = 0,
search: Optional[str] = None,
file_type: Optional[str] = None,
sort_by: str = "created_at",
order: str = "desc") -> Dict[str, Any]:
"""List files with filtering and pagination"""
try:
# Build query
query = {"status": {"$ne": FileStatus.DELETED.value}}
if user_id:
query["user_id"] = user_id
if bucket:
query["bucket"] = bucket
if file_type:
query["file_type"] = file_type
if search:
query["$text"] = {"$search": search}
# Count total documents
total = await self.collection.count_documents(query)
# Sort order
sort_order = -1 if order == "desc" else 1
# Execute query with pagination
cursor = self.collection.find(query)\
.sort(sort_by, sort_order)\
.skip(offset)\
.limit(limit)
files = []
async for doc in cursor:
doc.pop("_id", None)
files.append(doc)
return {
"files": files,
"total": total,
"limit": limit,
"offset": offset,
"has_more": (offset + limit) < total
}
except Exception as e:
logger.error(f"Failed to list files: {e}")
raise
async def increment_download_count(self, file_id: str):
"""Increment download counter for a file"""
try:
await self.collection.update_one(
{"id": file_id},
{
"$inc": {"download_count": 1},
"$set": {"last_accessed": datetime.now()}
}
)
except Exception as e:
logger.error(f"Failed to increment download count: {e}")
async def get_storage_stats(self) -> Dict[str, Any]:
"""Get storage statistics"""
try:
# Aggregation pipeline for statistics
pipeline = [
{"$match": {"status": {"$ne": FileStatus.DELETED.value}}},
{
"$group": {
"_id": None,
"total_files": {"$sum": 1},
"total_size": {"$sum": "$size"},
"users": {"$addToSet": "$user_id"}
}
}
]
cursor = self.collection.aggregate(pipeline)
result = await cursor.to_list(length=1)
if result:
stats = result[0]
users_count = len(stats.get("users", []))
else:
stats = {"total_files": 0, "total_size": 0}
users_count = 0
# Get file type distribution
type_pipeline = [
{"$match": {"status": {"$ne": FileStatus.DELETED.value}}},
{
"$group": {
"_id": "$file_type",
"count": {"$sum": 1}
}
}
]
type_cursor = self.collection.aggregate(type_pipeline)
type_results = await type_cursor.to_list(length=None)
file_types = {
item["_id"]: item["count"]
for item in type_results if item["_id"]
}
return {
"total_files": stats.get("total_files", 0),
"total_size": stats.get("total_size", 0),
"users_count": users_count,
"file_types": file_types
}
except Exception as e:
logger.error(f"Failed to get storage stats: {e}")
raise
async def find_duplicate_files(self, file_hash: str) -> List[Dict[str, Any]]:
"""Find duplicate files by hash"""
try:
cursor = self.collection.find({
"hash": file_hash,
"status": {"$ne": FileStatus.DELETED.value}
})
duplicates = []
async for doc in cursor:
doc.pop("_id", None)
duplicates.append(doc)
return duplicates
except Exception as e:
logger.error(f"Failed to find duplicate files: {e}")
raise
async def get_user_storage_usage(self, user_id: str) -> Dict[str, Any]:
"""Get storage usage for a specific user"""
try:
pipeline = [
{
"$match": {
"user_id": user_id,
"status": {"$ne": FileStatus.DELETED.value}
}
},
{
"$group": {
"_id": "$file_type",
"count": {"$sum": 1},
"size": {"$sum": "$size"}
}
}
]
cursor = self.collection.aggregate(pipeline)
results = await cursor.to_list(length=None)
total_size = sum(item["size"] for item in results)
total_files = sum(item["count"] for item in results)
breakdown = {
item["_id"]: {
"count": item["count"],
"size": item["size"]
}
for item in results if item["_id"]
}
return {
"user_id": user_id,
"total_files": total_files,
"total_size": total_size,
"breakdown": breakdown
}
except Exception as e:
logger.error(f"Failed to get user storage usage: {e}")
raise
async def close(self):
"""Close MongoDB connection"""
if self.client:
self.client.close()
self.is_connected = False
logger.info("MongoDB connection closed")

View File

@ -0,0 +1,333 @@
"""
MinIO Client for S3-compatible object storage
"""
from minio import Minio
from minio.error import S3Error
import asyncio
import io
from typing import Optional, Dict, Any, List
import logging
from datetime import timedelta
logger = logging.getLogger(__name__)
class MinIOManager:
def __init__(self, endpoint: str, access_key: str, secret_key: str, secure: bool = False):
self.endpoint = endpoint
self.access_key = access_key
self.secret_key = secret_key
self.secure = secure
self.client = None
self.is_connected = False
async def initialize(self):
"""Initialize MinIO client and create default buckets"""
try:
self.client = Minio(
self.endpoint,
access_key=self.access_key,
secret_key=self.secret_key,
secure=self.secure
)
# Create default buckets
default_buckets = ["default", "public", "thumbnails", "temp"]
for bucket in default_buckets:
await self.create_bucket(bucket, public=(bucket == "public"))
self.is_connected = True
logger.info(f"Connected to MinIO at {self.endpoint}")
except Exception as e:
logger.error(f"Failed to initialize MinIO: {e}")
self.is_connected = False
raise
async def create_bucket(self, bucket_name: str, public: bool = False):
"""Create a new bucket"""
try:
# Run in executor to avoid blocking
loop = asyncio.get_event_loop()
# Check if bucket exists
exists = await loop.run_in_executor(
None,
self.client.bucket_exists,
bucket_name
)
if not exists:
await loop.run_in_executor(
None,
self.client.make_bucket,
bucket_name
)
logger.info(f"Created bucket: {bucket_name}")
# Set bucket policy if public
if public:
policy = {
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Principal": {"AWS": ["*"]},
"Action": ["s3:GetObject"],
"Resource": [f"arn:aws:s3:::{bucket_name}/*"]
}
]
}
import json
await loop.run_in_executor(
None,
self.client.set_bucket_policy,
bucket_name,
json.dumps(policy)
)
logger.info(f"Set public policy for bucket: {bucket_name}")
except Exception as e:
logger.error(f"Failed to create bucket {bucket_name}: {e}")
raise
async def upload_file(self, bucket: str, object_name: str, file_data: bytes,
content_type: str = "application/octet-stream",
metadata: Optional[Dict[str, str]] = None):
"""Upload a file to MinIO"""
try:
loop = asyncio.get_event_loop()
# Convert bytes to BytesIO
file_stream = io.BytesIO(file_data)
length = len(file_data)
# Upload file
result = await loop.run_in_executor(
None,
self.client.put_object,
bucket,
object_name,
file_stream,
length,
content_type,
metadata
)
logger.info(f"Uploaded {object_name} to {bucket}")
return {
"bucket": bucket,
"object_name": object_name,
"etag": result.etag,
"version_id": result.version_id
}
except Exception as e:
logger.error(f"Failed to upload file: {e}")
raise
async def get_file(self, bucket: str, object_name: str) -> io.BytesIO:
"""Get a file from MinIO"""
try:
loop = asyncio.get_event_loop()
# Get object
response = await loop.run_in_executor(
None,
self.client.get_object,
bucket,
object_name
)
# Read data
data = response.read()
response.close()
response.release_conn()
return io.BytesIO(data)
except Exception as e:
logger.error(f"Failed to get file: {e}")
raise
async def delete_file(self, bucket: str, object_name: str):
"""Delete a file from MinIO"""
try:
loop = asyncio.get_event_loop()
await loop.run_in_executor(
None,
self.client.remove_object,
bucket,
object_name
)
logger.info(f"Deleted {object_name} from {bucket}")
except Exception as e:
logger.error(f"Failed to delete file: {e}")
raise
async def list_files(self, bucket: str, prefix: Optional[str] = None,
recursive: bool = True) -> List[Dict[str, Any]]:
"""List files in a bucket"""
try:
loop = asyncio.get_event_loop()
objects = await loop.run_in_executor(
None,
lambda: list(self.client.list_objects(
bucket,
prefix=prefix,
recursive=recursive
))
)
files = []
for obj in objects:
files.append({
"name": obj.object_name,
"size": obj.size,
"last_modified": obj.last_modified,
"etag": obj.etag,
"content_type": obj.content_type
})
return files
except Exception as e:
logger.error(f"Failed to list files: {e}")
raise
async def get_file_info(self, bucket: str, object_name: str) -> Dict[str, Any]:
"""Get file information"""
try:
loop = asyncio.get_event_loop()
stat = await loop.run_in_executor(
None,
self.client.stat_object,
bucket,
object_name
)
return {
"size": stat.size,
"etag": stat.etag,
"content_type": stat.content_type,
"last_modified": stat.last_modified,
"metadata": stat.metadata
}
except Exception as e:
logger.error(f"Failed to get file info: {e}")
raise
async def generate_presigned_download_url(self, bucket: str, object_name: str,
expires_in: int = 3600) -> str:
"""Generate a presigned URL for downloading"""
try:
loop = asyncio.get_event_loop()
url = await loop.run_in_executor(
None,
self.client.presigned_get_object,
bucket,
object_name,
timedelta(seconds=expires_in)
)
return url
except Exception as e:
logger.error(f"Failed to generate presigned URL: {e}")
raise
async def generate_presigned_upload_url(self, bucket: str, object_name: str,
expires_in: int = 3600) -> str:
"""Generate a presigned URL for uploading"""
try:
loop = asyncio.get_event_loop()
url = await loop.run_in_executor(
None,
self.client.presigned_put_object,
bucket,
object_name,
timedelta(seconds=expires_in)
)
return url
except Exception as e:
logger.error(f"Failed to generate presigned upload URL: {e}")
raise
async def copy_file(self, source_bucket: str, source_object: str,
dest_bucket: str, dest_object: str):
"""Copy a file within MinIO"""
try:
loop = asyncio.get_event_loop()
await loop.run_in_executor(
None,
self.client.copy_object,
dest_bucket,
dest_object,
f"/{source_bucket}/{source_object}"
)
logger.info(f"Copied {source_object} to {dest_object}")
except Exception as e:
logger.error(f"Failed to copy file: {e}")
raise
async def list_buckets(self) -> List[str]:
"""List all buckets"""
try:
loop = asyncio.get_event_loop()
buckets = await loop.run_in_executor(
None,
self.client.list_buckets
)
return [bucket.name for bucket in buckets]
except Exception as e:
logger.error(f"Failed to list buckets: {e}")
raise
async def get_storage_stats(self) -> Dict[str, Any]:
"""Get storage statistics"""
try:
buckets = await self.list_buckets()
stats = {
"buckets": buckets,
"bucket_count": len(buckets),
"bucket_stats": {}
}
# Get stats for each bucket
for bucket in buckets:
files = await self.list_files(bucket)
total_size = sum(f["size"] for f in files)
stats["bucket_stats"][bucket] = {
"file_count": len(files),
"total_size": total_size
}
return stats
except Exception as e:
logger.error(f"Failed to get storage stats: {e}")
raise
async def check_file_exists(self, bucket: str, object_name: str) -> bool:
"""Check if a file exists"""
try:
await self.get_file_info(bucket, object_name)
return True
except:
return False

View File

@ -0,0 +1,112 @@
"""
Data models for File Management Service
"""
from pydantic import BaseModel, Field
from datetime import datetime
from typing import Optional, List, Dict, Any
from enum import Enum
class FileType(str, Enum):
IMAGE = "image"
VIDEO = "video"
AUDIO = "audio"
DOCUMENT = "document"
ARCHIVE = "archive"
OTHER = "other"
class FileStatus(str, Enum):
PENDING = "pending"
PROCESSING = "processing"
READY = "ready"
ERROR = "error"
DELETED = "deleted"
class FileMetadata(BaseModel):
id: str
filename: str
original_name: str
size: int
content_type: str
file_type: FileType
bucket: str
object_name: str
user_id: str
hash: str
status: FileStatus = FileStatus.READY
public: bool = False
has_thumbnail: bool = False
thumbnail_url: Optional[str] = None
tags: Dict[str, Any] = {}
metadata: Dict[str, Any] = {}
download_count: int = 0
created_at: datetime
updated_at: datetime
deleted_at: Optional[datetime] = None
class FileUploadResponse(BaseModel):
file_id: str
filename: str
size: int
content_type: str
file_type: FileType
bucket: str
public: bool
has_thumbnail: bool
thumbnail_url: Optional[str] = None
download_url: Optional[str] = None
created_at: datetime
message: str = "File uploaded successfully"
class FileListResponse(BaseModel):
files: List[FileMetadata]
total: int
limit: int
offset: int
has_more: bool
class StorageStats(BaseModel):
total_files: int
total_size: int
buckets: List[str]
users_count: int
file_types: Dict[str, int]
storage_used_percentage: Optional[float] = None
class ThumbnailRequest(BaseModel):
file_id: str
width: int = Field(200, ge=50, le=1000)
height: int = Field(200, ge=50, le=1000)
quality: int = Field(85, ge=50, le=100)
format: str = Field("jpeg", pattern="^(jpeg|png|webp)$")
class PresignedUrlResponse(BaseModel):
url: str
expires_in: int
method: str
headers: Optional[Dict[str, str]] = None
class BatchOperationResult(BaseModel):
successful: List[str]
failed: List[Dict[str, str]]
total_processed: int
total_successful: int
total_failed: int
class FileShareLink(BaseModel):
share_url: str
expires_in: int
file_id: str
filename: str
created_at: datetime
expires_at: datetime
class FileProcessingJob(BaseModel):
job_id: str
file_id: str
job_type: str # thumbnail, compress, convert, etc.
status: str # pending, processing, completed, failed
progress: Optional[float] = None
result: Optional[Dict[str, Any]] = None
error: Optional[str] = None
created_at: datetime
completed_at: Optional[datetime] = None

View File

@ -0,0 +1,11 @@
fastapi==0.109.0
uvicorn[standard]==0.27.0
pydantic==2.5.3
python-dotenv==1.0.0
motor==3.5.1
pymongo==4.6.1
minio==7.2.3
pillow==10.2.0
python-magic==0.4.27
aiofiles==23.2.1
python-multipart==0.0.6

View File

@ -0,0 +1,281 @@
#!/usr/bin/env python3
"""
Test script for File Management Service
"""
import asyncio
import httpx
import os
import json
from datetime import datetime
import base64
BASE_URL = "http://localhost:8014"
# Sample image for testing (1x1 pixel PNG)
TEST_IMAGE_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
TEST_IMAGE_DATA = base64.b64decode(TEST_IMAGE_BASE64)
async def test_file_api():
"""Test file management API endpoints"""
async with httpx.AsyncClient() as client:
print("\n📁 Testing File Management Service API...")
# Test health check
print("\n1. Testing health check...")
response = await client.get(f"{BASE_URL}/health")
print(f"Health check: {response.json()}")
# Test file upload
print("\n2. Testing file upload...")
files = {
'file': ('test_image.png', TEST_IMAGE_DATA, 'image/png')
}
data = {
'user_id': 'test_user_123',
'bucket': 'default',
'public': 'false',
'generate_thumbnail': 'true',
'tags': json.dumps({"test": "true", "category": "sample"})
}
response = await client.post(
f"{BASE_URL}/api/files/upload",
files=files,
data=data
)
if response.status_code == 200:
upload_result = response.json()
print(f"File uploaded: {upload_result}")
file_id = upload_result.get("file_id")
else:
print(f"Upload failed: {response.status_code} - {response.text}")
file_id = None
# Test multiple file upload
print("\n3. Testing multiple file upload...")
files = [
('files', ('test1.png', TEST_IMAGE_DATA, 'image/png')),
('files', ('test2.png', TEST_IMAGE_DATA, 'image/png')),
('files', ('test3.png', TEST_IMAGE_DATA, 'image/png'))
]
data = {
'user_id': 'test_user_123',
'bucket': 'default',
'public': 'false'
}
response = await client.post(
f"{BASE_URL}/api/files/upload-multiple",
files=files,
data=data
)
if response.status_code == 200:
print(f"Multiple files uploaded: {response.json()}")
else:
print(f"Multiple upload failed: {response.status_code}")
# Test file metadata retrieval
if file_id:
print("\n4. Testing file metadata retrieval...")
response = await client.get(f"{BASE_URL}/api/files/{file_id}/metadata")
if response.status_code == 200:
print(f"File metadata: {response.json()}")
else:
print(f"Metadata retrieval failed: {response.status_code}")
# Test thumbnail generation
print("\n5. Testing thumbnail retrieval...")
response = await client.get(
f"{BASE_URL}/api/files/{file_id}/thumbnail",
params={"width": 100, "height": 100}
)
if response.status_code == 200:
print(f"Thumbnail retrieved: {len(response.content)} bytes")
else:
print(f"Thumbnail retrieval failed: {response.status_code}")
# Test file download
print("\n6. Testing file download...")
response = await client.get(f"{BASE_URL}/api/files/{file_id}/download")
if response.status_code == 200:
print(f"File downloaded: {len(response.content)} bytes")
else:
print(f"Download failed: {response.status_code}")
# Test share link generation
print("\n7. Testing share link generation...")
response = await client.get(
f"{BASE_URL}/api/files/{file_id}/share",
params={"expires_in": 3600}
)
if response.status_code == 200:
share_result = response.json()
print(f"Share link generated: {share_result.get('share_url', 'N/A')[:50]}...")
else:
print(f"Share link generation failed: {response.status_code}")
# Test file listing
print("\n8. Testing file listing...")
response = await client.get(
f"{BASE_URL}/api/files",
params={
"user_id": "test_user_123",
"limit": 10,
"offset": 0
}
)
if response.status_code == 200:
files_list = response.json()
print(f"Files found: {files_list.get('total', 0)} files")
else:
print(f"File listing failed: {response.status_code}")
# Test storage statistics
print("\n9. Testing storage statistics...")
response = await client.get(f"{BASE_URL}/api/storage/stats")
if response.status_code == 200:
stats = response.json()
print(f"Storage stats: {stats}")
else:
print(f"Storage stats failed: {response.status_code}")
# Test bucket operations
print("\n10. Testing bucket operations...")
response = await client.post(
f"{BASE_URL}/api/storage/buckets",
params={"bucket_name": "test-bucket", "public": False}
)
if response.status_code == 200:
print(f"Bucket created: {response.json()}")
else:
print(f"Bucket creation: {response.status_code}")
response = await client.get(f"{BASE_URL}/api/storage/buckets")
if response.status_code == 200:
print(f"Available buckets: {response.json()}")
else:
print(f"Bucket listing failed: {response.status_code}")
# Test presigned URL generation
print("\n11. Testing presigned URL generation...")
response = await client.post(
f"{BASE_URL}/api/files/presigned-upload",
params={
"filename": "test_upload.txt",
"content_type": "text/plain",
"bucket": "default",
"expires_in": 3600
}
)
if response.status_code == 200:
presigned = response.json()
print(f"Presigned upload URL generated: {presigned.get('upload_url', 'N/A')[:50]}...")
else:
print(f"Presigned URL generation failed: {response.status_code}")
# Test file deletion
if file_id:
print("\n12. Testing file deletion...")
response = await client.delete(
f"{BASE_URL}/api/files/{file_id}",
params={"user_id": "test_user_123"}
)
if response.status_code == 200:
print(f"File deleted: {response.json()}")
else:
print(f"File deletion failed: {response.status_code}")
async def test_large_file_upload():
"""Test large file upload"""
print("\n\n📦 Testing Large File Upload...")
# Create a larger test file (1MB)
large_data = b"x" * (1024 * 1024) # 1MB of data
async with httpx.AsyncClient(timeout=30.0) as client:
files = {
'file': ('large_test.bin', large_data, 'application/octet-stream')
}
data = {
'user_id': 'test_user_123',
'bucket': 'default',
'public': 'false'
}
print("Uploading 1MB file...")
response = await client.post(
f"{BASE_URL}/api/files/upload",
files=files,
data=data
)
if response.status_code == 200:
result = response.json()
print(f"Large file uploaded successfully: {result.get('file_id')}")
print(f"File size: {result.get('size')} bytes")
else:
print(f"Large file upload failed: {response.status_code}")
async def test_duplicate_detection():
"""Test duplicate file detection"""
print("\n\n🔍 Testing Duplicate Detection...")
async with httpx.AsyncClient() as client:
# Upload the same file twice
files = {
'file': ('duplicate_test.png', TEST_IMAGE_DATA, 'image/png')
}
data = {
'user_id': 'test_user_123',
'bucket': 'default',
'public': 'false'
}
print("Uploading file first time...")
response1 = await client.post(
f"{BASE_URL}/api/files/upload",
files=files,
data=data
)
if response1.status_code == 200:
result1 = response1.json()
print(f"First upload: {result1.get('file_id')}")
print("Uploading same file again...")
response2 = await client.post(
f"{BASE_URL}/api/files/upload",
files=files,
data=data
)
if response2.status_code == 200:
result2 = response2.json()
print(f"Second upload: {result2.get('file_id')}")
if result2.get('duplicate'):
print("✅ Duplicate detected successfully!")
else:
print("❌ Duplicate not detected")
async def main():
"""Run all tests"""
print("=" * 60)
print("FILE MANAGEMENT SERVICE TEST SUITE")
print("=" * 60)
print(f"Started at: {datetime.now().isoformat()}")
# Run tests
await test_file_api()
await test_large_file_upload()
await test_duplicate_detection()
print("\n" + "=" * 60)
print("✅ All tests completed!")
print(f"Finished at: {datetime.now().isoformat()}")
print("=" * 60)
if __name__ == "__main__":
asyncio.run(main())

View File

@ -0,0 +1,236 @@
"""
Thumbnail Generator for image and video files
"""
from PIL import Image, ImageOps
import io
import os
import hashlib
import logging
from typing import Optional, Tuple
import asyncio
from pathlib import Path
logger = logging.getLogger(__name__)
class ThumbnailGenerator:
def __init__(self, minio_client, cache_dir: str = "/tmp/thumbnails"):
self.minio_client = minio_client
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(parents=True, exist_ok=True)
# Supported image formats for thumbnail generation
self.supported_formats = {
'image/jpeg', 'image/jpg', 'image/png', 'image/gif',
'image/webp', 'image/bmp', 'image/tiff'
}
def _get_cache_path(self, file_id: str, width: int, height: int) -> Path:
"""Generate cache file path for thumbnail"""
cache_key = f"{file_id}_{width}x{height}"
cache_hash = hashlib.md5(cache_key.encode()).hexdigest()
return self.cache_dir / f"{cache_hash[:2]}" / f"{cache_hash}.jpg"
async def generate_thumbnail(self, file_data: bytes, content_type: str,
width: int = 200, height: int = 200) -> Optional[bytes]:
"""Generate a thumbnail from file data"""
try:
if content_type not in self.supported_formats:
logger.warning(f"Unsupported format for thumbnail: {content_type}")
return None
loop = asyncio.get_event_loop()
# Generate thumbnail in thread pool
thumbnail_data = await loop.run_in_executor(
None,
self._create_thumbnail,
file_data,
width,
height
)
return thumbnail_data
except Exception as e:
logger.error(f"Failed to generate thumbnail: {e}")
return None
def _create_thumbnail(self, file_data: bytes, width: int, height: int) -> bytes:
"""Create thumbnail using PIL"""
try:
# Open image
image = Image.open(io.BytesIO(file_data))
# Convert RGBA to RGB if necessary
if image.mode in ('RGBA', 'LA', 'P'):
# Create a white background
background = Image.new('RGB', image.size, (255, 255, 255))
if image.mode == 'P':
image = image.convert('RGBA')
background.paste(image, mask=image.split()[-1] if image.mode == 'RGBA' else None)
image = background
elif image.mode not in ('RGB', 'L'):
image = image.convert('RGB')
# Calculate thumbnail size maintaining aspect ratio
image.thumbnail((width, height), Image.Resampling.LANCZOS)
# Apply EXIF orientation if present
image = ImageOps.exif_transpose(image)
# Save thumbnail to bytes
output = io.BytesIO()
image.save(output, format='JPEG', quality=85, optimize=True)
output.seek(0)
return output.read()
except Exception as e:
logger.error(f"Thumbnail creation failed: {e}")
raise
async def get_thumbnail(self, file_id: str, bucket: str, object_name: str,
width: int = 200, height: int = 200) -> Optional[bytes]:
"""Get or generate thumbnail for a file"""
try:
# Check cache first
cache_path = self._get_cache_path(file_id, width, height)
if cache_path.exists():
logger.info(f"Thumbnail found in cache: {cache_path}")
with open(cache_path, 'rb') as f:
return f.read()
# Check if thumbnail exists in MinIO
thumbnail_object = f"thumbnails/{file_id}_{width}x{height}.jpg"
try:
thumbnail_stream = await self.minio_client.get_file(
bucket="thumbnails",
object_name=thumbnail_object
)
thumbnail_data = thumbnail_stream.read()
# Save to cache
await self._save_to_cache(cache_path, thumbnail_data)
return thumbnail_data
except:
pass # Thumbnail doesn't exist, generate it
# Get original file
file_stream = await self.minio_client.get_file(bucket, object_name)
file_data = file_stream.read()
# Get file info for content type
file_info = await self.minio_client.get_file_info(bucket, object_name)
content_type = file_info.get("content_type", "")
# Generate thumbnail
thumbnail_data = await self.generate_thumbnail(
file_data, content_type, width, height
)
if thumbnail_data:
# Save to MinIO
await self.minio_client.upload_file(
bucket="thumbnails",
object_name=thumbnail_object,
file_data=thumbnail_data,
content_type="image/jpeg"
)
# Save to cache
await self._save_to_cache(cache_path, thumbnail_data)
return thumbnail_data
except Exception as e:
logger.error(f"Failed to get thumbnail: {e}")
return None
async def _save_to_cache(self, cache_path: Path, data: bytes):
"""Save thumbnail to cache"""
try:
cache_path.parent.mkdir(parents=True, exist_ok=True)
loop = asyncio.get_event_loop()
await loop.run_in_executor(
None,
lambda: cache_path.write_bytes(data)
)
logger.info(f"Thumbnail saved to cache: {cache_path}")
except Exception as e:
logger.error(f"Failed to save to cache: {e}")
async def delete_thumbnail(self, file_id: str):
"""Delete all thumbnails for a file"""
try:
# Delete from cache
for cache_file in self.cache_dir.rglob(f"*{file_id}*"):
try:
cache_file.unlink()
logger.info(f"Deleted cache file: {cache_file}")
except:
pass
# Delete from MinIO (list and delete all sizes)
files = await self.minio_client.list_files(
bucket="thumbnails",
prefix=f"thumbnails/{file_id}_"
)
for file in files:
await self.minio_client.delete_file(
bucket="thumbnails",
object_name=file["name"]
)
logger.info(f"Deleted thumbnail: {file['name']}")
except Exception as e:
logger.error(f"Failed to delete thumbnails: {e}")
async def generate_multiple_sizes(self, file_data: bytes, content_type: str,
file_id: str) -> dict:
"""Generate thumbnails in multiple sizes"""
sizes = {
"small": (150, 150),
"medium": (300, 300),
"large": (600, 600)
}
results = {}
for size_name, (width, height) in sizes.items():
thumbnail = await self.generate_thumbnail(
file_data, content_type, width, height
)
if thumbnail:
# Save to MinIO
object_name = f"thumbnails/{file_id}_{size_name}.jpg"
await self.minio_client.upload_file(
bucket="thumbnails",
object_name=object_name,
file_data=thumbnail,
content_type="image/jpeg"
)
results[size_name] = {
"size": len(thumbnail),
"dimensions": f"{width}x{height}",
"object_name": object_name
}
return results
def clear_cache(self):
"""Clear thumbnail cache"""
try:
import shutil
shutil.rmtree(self.cache_dir)
self.cache_dir.mkdir(parents=True, exist_ok=True)
logger.info("Thumbnail cache cleared")
except Exception as e:
logger.error(f"Failed to clear cache: {e}")

View File

@ -0,0 +1,26 @@
FROM python:3.11-slim
WORKDIR /app
# 시스템 패키지 설치
RUN apt-get update && apt-get install -y \
gcc \
libheif-dev \
libde265-dev \
libjpeg-dev \
libpng-dev \
&& rm -rf /var/lib/apt/lists/*
# Python 패키지 설치
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# 애플리케이션 코드 복사
COPY . .
# 캐시 디렉토리 생성
RUN mkdir -p /app/cache
EXPOSE 8000
CMD ["python", "main.py"]

View File

View File

@ -0,0 +1,197 @@
from fastapi import APIRouter, Query, HTTPException, Body
from fastapi.responses import Response
from typing import Optional, Dict
import mimetypes
from pathlib import Path
import hashlib
from ..core.config import settings
# MinIO 사용 여부에 따라 적절한 캐시 모듈 선택
if settings.use_minio:
from ..core.minio_cache import cache
else:
from ..core.cache import cache
router = APIRouter()
@router.get("/image")
async def get_image(
url: str = Query(..., description="원본 이미지 URL"),
size: Optional[str] = Query(None, description="이미지 크기 (thumb, card, list, detail, hero)")
):
"""
이미지 프록시 엔드포인트
- 외부 URL의 이미지를 가져와서 캐싱
- 선택적으로 리사이징 및 최적화
- WebP 포맷으로 자동 변환 (설정에 따라)
"""
try:
# 캐시 확인
cached_data = await cache.get(url, size)
if cached_data:
# 캐시된 이미지 반환
# SVG 체크
if url.lower().endswith('.svg') or cache._is_svg(cached_data):
content_type = 'image/svg+xml'
# GIF 체크 (GIF는 WebP로 변환하지 않음)
elif url.lower().endswith('.gif'):
content_type = 'image/gif'
# WebP 변환이 활성화된 경우 항상 WebP로 제공 (GIF 제외)
elif settings.convert_to_webp and size:
content_type = 'image/webp'
else:
content_type = mimetypes.guess_type(url)[0] or 'image/jpeg'
return Response(
content=cached_data,
media_type=content_type,
headers={
"Cache-Control": f"public, max-age={86400 * 7}", # 7일 브라우저 캐시
"X-Cache": "HIT",
"X-Image-Format": content_type.split('/')[-1].upper()
}
)
# 캐시 미스 - 이미지 다운로드
image_data = await cache.download_image(url)
# URL에서 MIME 타입 추측
guessed_type = mimetypes.guess_type(url)[0]
# SVG 확장자 체크 (mimetypes가 SVG를 제대로 인식하지 못할 수 있음)
if url.lower().endswith('.svg') or cache._is_svg(image_data):
content_type = 'image/svg+xml'
# GIF 체크
elif url.lower().endswith('.gif') or (guessed_type and 'gif' in guessed_type.lower()):
content_type = 'image/gif'
else:
content_type = guessed_type or 'image/jpeg'
# 리사이징 및 최적화 (SVG와 GIF는 특별 처리)
if size and content_type != 'image/svg+xml':
# GIF는 특별 처리
if content_type == 'image/gif':
image_data, content_type = cache._process_gif(image_data, settings.thumbnail_sizes[size])
else:
image_data, content_type = cache.resize_and_optimize_image(image_data, size)
# 캐시에 저장
await cache.set(url, image_data, size)
# 백그라운드에서 다른 크기들도 생성하도록 트리거
await cache.trigger_background_generation(url)
# 이미지 반환
return Response(
content=image_data,
media_type=content_type,
headers={
"Cache-Control": f"public, max-age={86400 * 7}",
"X-Cache": "MISS",
"X-Image-Format": content_type.split('/')[-1].upper()
}
)
except HTTPException:
raise
except Exception as e:
import traceback
print(f"Error processing image from {url}: {str(e)}")
traceback.print_exc()
# 403 에러를 명확히 처리
if "403" in str(e):
raise HTTPException(
status_code=403,
detail=f"이미지 접근 거부됨: {url}"
)
raise HTTPException(
status_code=500,
detail=f"이미지 처리 실패: {str(e)}"
)
@router.get("/stats")
async def get_stats():
"""캐시 통계 정보"""
cache_size = await cache.get_cache_size()
# 디렉토리 구조 통계 추가 (MinIO 또는 파일시스템)
dir_stats = await cache.get_directory_stats()
return {
"cache_size_gb": round(cache_size, 2),
"max_cache_size_gb": settings.max_cache_size_gb,
"cache_usage_percent": round((cache_size / settings.max_cache_size_gb) * 100, 2),
"directory_stats": dir_stats
}
@router.post("/cleanup")
async def cleanup_cache():
"""오래된 캐시 정리"""
await cache.cleanup_old_cache()
return {"message": "캐시 정리 완료"}
@router.post("/cache/delete")
async def delete_cache(request: Dict = Body(...)):
"""특정 URL의 캐시 삭제"""
url = request.get("url")
if not url:
raise HTTPException(status_code=400, detail="URL이 필요합니다")
try:
# URL의 모든 크기 버전 삭제
sizes = ["thumb", "card", "list", "detail", "hero", None] # None은 원본
deleted_count = 0
for size in sizes:
# 캐시 경로 계산
url_hash = hashlib.md5(url.encode()).hexdigest()
# 3단계 디렉토리 구조
level1 = url_hash[:2]
level2 = url_hash[2:4]
level3 = url_hash[4:6]
# 크기별 파일명
if size:
patterns = [
f"{url_hash}_{size}.webp",
f"{url_hash}_{size}.jpg",
f"{url_hash}_{size}.jpeg",
f"{url_hash}_{size}.png",
f"{url_hash}_{size}.gif"
]
else:
patterns = [
f"{url_hash}",
f"{url_hash}.jpg",
f"{url_hash}.jpeg",
f"{url_hash}.png",
f"{url_hash}.gif",
f"{url_hash}.webp"
]
# 각 패턴에 대해 파일 삭제 시도
for filename in patterns:
cache_path = settings.cache_dir / level1 / level2 / level3 / filename
if cache_path.exists():
cache_path.unlink()
deleted_count += 1
print(f"✅ 캐시 파일 삭제: {cache_path}")
return {
"status": "success",
"message": f"{deleted_count}개의 캐시 파일이 삭제되었습니다",
"url": url
}
except Exception as e:
print(f"❌ 캐시 삭제 오류: {e}")
raise HTTPException(
status_code=500,
detail=f"캐시 삭제 실패: {str(e)}"
)

View File

@ -0,0 +1,91 @@
import asyncio
import logging
from typing import Set, Optional
from pathlib import Path
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class BackgroundTaskManager:
"""백그라운드 작업 관리자"""
def __init__(self):
self.processing_urls: Set[str] = set() # 현재 처리 중인 URL 목록
self.task_queue: asyncio.Queue = None
self.worker_task: Optional[asyncio.Task] = None
async def start(self):
"""백그라운드 워커 시작"""
self.task_queue = asyncio.Queue(maxsize=100)
self.worker_task = asyncio.create_task(self._worker())
logger.info("백그라운드 작업 관리자 시작됨")
async def stop(self):
"""백그라운드 워커 정지"""
if self.worker_task:
self.worker_task.cancel()
try:
await self.worker_task
except asyncio.CancelledError:
pass
logger.info("백그라운드 작업 관리자 정지됨")
async def add_task(self, url: str):
"""작업 큐에 URL 추가"""
if url not in self.processing_urls and self.task_queue:
try:
self.processing_urls.add(url)
await self.task_queue.put(url)
logger.info(f"백그라운드 작업 추가: {url}")
except asyncio.QueueFull:
self.processing_urls.discard(url)
logger.warning(f"작업 큐가 가득 참: {url}")
async def _worker(self):
"""백그라운드 워커 - 큐에서 작업을 가져와 처리"""
from .cache import cache
while True:
try:
# 큐에서 URL 가져오기
url = await self.task_queue.get()
try:
# 원본 이미지가 캐시에 있는지 확인
original_data = await cache.get(url, None)
if not original_data:
# 원본 이미지 다운로드
original_data = await cache.download_image(url)
await cache.set(url, original_data, None)
# 모든 크기의 이미지 생성
sizes = ['thumb', 'card', 'list', 'detail', 'hero']
for size in sizes:
# 이미 존재하는지 확인
existing = await cache.get(url, size)
if not existing:
try:
# 리사이징 및 최적화 - cache.resize_and_optimize_image가 WebP를 처리함
resized_data, _ = cache.resize_and_optimize_image(original_data, size)
await cache.set(url, resized_data, size)
logger.info(f"백그라운드 생성 완료: {url} ({size})")
except Exception as e:
logger.error(f"백그라운드 리사이징 실패: {url} ({size}) - {str(e)}")
import traceback
logger.error(f"Traceback: {traceback.format_exc()}")
except Exception as e:
logger.error(f"백그라운드 작업 실패: {url} - {str(e)}")
finally:
# 처리 완료된 URL 제거
self.processing_urls.discard(url)
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"백그라운드 워커 오류: {str(e)}")
await asyncio.sleep(1) # 오류 발생 시 잠시 대기
# 전역 백그라운드 작업 관리자
background_manager = BackgroundTaskManager()

View File

@ -0,0 +1,796 @@
import hashlib
import aiofiles
import os
from pathlib import Path
from datetime import datetime, timedelta
from typing import Optional
import httpx
from PIL import Image
try:
from pillow_heif import register_heif_opener, register_avif_opener
register_heif_opener() # HEIF/HEIC 지원
register_avif_opener() # AVIF 지원
print("HEIF/AVIF support enabled successfully")
except ImportError:
print("Warning: pillow_heif not installed, HEIF/AVIF support disabled")
import io
import asyncio
import ssl
from .config import settings
class ImageCache:
def __init__(self):
self.cache_dir = settings.cache_dir
self.cache_dir.mkdir(parents=True, exist_ok=True)
def _get_cache_path(self, url: str, size: Optional[str] = None) -> Path:
"""URL을 기반으로 캐시 파일 경로 생성"""
# URL을 해시하여 파일명 생성
url_hash = hashlib.md5(url.encode()).hexdigest()
# 3단계 디렉토리 구조 생성
# 예: 10f8a8f96aa1377e86fdbc6bf3c631cf -> 10/f8/a8/
level1 = url_hash[:2] # 첫 2자리
level2 = url_hash[2:4] # 다음 2자리
level3 = url_hash[4:6] # 다음 2자리
# 크기별로 다른 파일명 사용
if size:
filename = f"{url_hash}_{size}"
else:
filename = url_hash
# 확장자 추출 (WebP로 저장되는 경우 .webp 사용)
if settings.convert_to_webp and size:
filename = f"{filename}.webp"
else:
ext = self._get_extension_from_url(url)
if ext:
filename = f"{filename}.{ext}"
# 3단계 디렉토리 경로 생성
path = self.cache_dir / level1 / level2 / level3 / filename
path.parent.mkdir(parents=True, exist_ok=True)
return path
def _get_extension_from_url(self, url: str) -> Optional[str]:
"""URL에서 파일 확장자 추출"""
path = url.split('?')[0] # 쿼리 파라미터 제거
parts = path.split('.')
if len(parts) > 1:
ext = parts[-1].lower()
if ext in settings.allowed_formats:
return ext
return None
def _is_svg(self, data: bytes) -> bool:
"""SVG 파일인지 확인"""
# SVG 파일의 시작 부분 확인
if len(data) < 100:
return False
# 처음 1000바이트만 확인 (성능 최적화)
header = data[:1000].lower()
# SVG 시그니처 확인
svg_signatures = [
b'<svg',
b'<?xml',
b'<!doctype svg'
]
for sig in svg_signatures:
if sig in header:
return True
return False
def _process_gif(self, gif_data: bytes, target_size: tuple) -> tuple[bytes, str]:
"""GIF 처리 - JPEG로 변환하여 안정적으로 처리"""
try:
from PIL import Image
# GIF 열기
img = Image.open(io.BytesIO(gif_data))
# 모든 GIF를 RGB로 변환 (팔레트 모드 문제 해결)
# 팔레트 모드(P)를 RGB로 직접 변환
if img.mode != 'RGB':
img = img.convert('RGB')
# 리사이징
img = img.resize(target_size, Image.Resampling.LANCZOS)
# JPEG로 저장 (안정적)
output = io.BytesIO()
img.save(output, format='JPEG', quality=85, optimize=True)
return output.getvalue(), 'image/jpeg'
except Exception as e:
print(f"GIF 처리 중 오류: {e}")
import traceback
traceback.print_exc()
# 오류 발생 시 원본 반환
return gif_data, 'image/gif'
async def get(self, url: str, size: Optional[str] = None) -> Optional[bytes]:
"""캐시에서 이미지 가져오기"""
cache_path = self._get_cache_path(url, size)
if cache_path.exists():
# 캐시 만료 확인
stat = cache_path.stat()
age = datetime.now() - datetime.fromtimestamp(stat.st_mtime)
if age < timedelta(days=settings.cache_ttl_days):
async with aiofiles.open(cache_path, 'rb') as f:
return await f.read()
else:
# 만료된 캐시 삭제
cache_path.unlink()
return None
async def set(self, url: str, data: bytes, size: Optional[str] = None):
"""캐시에 이미지 저장"""
cache_path = self._get_cache_path(url, size)
async with aiofiles.open(cache_path, 'wb') as f:
await f.write(data)
async def download_image(self, url: str) -> bytes:
"""외부 URL에서 이미지 다운로드"""
from urllib.parse import urlparse
# URL에서 도메인 추출
parsed_url = urlparse(url)
domain = parsed_url.netloc
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}/"
# 기본 헤더 설정
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
'Accept': 'image/webp,image/apng,image/svg+xml,image/*,*/*;q=0.8',
'Accept-Language': 'ko-KR,ko;q=0.9,en-US;q=0.8,en;q=0.7',
'Cache-Control': 'no-cache',
'Pragma': 'no-cache',
'Sec-Fetch-Dest': 'image',
'Sec-Fetch-Mode': 'no-cors',
'Sec-Fetch-Site': 'cross-site',
'Referer': base_url # 항상 기본 Referer 설정
}
# 특정 사이트별 Referer 오버라이드
if 'yna.co.kr' in url:
headers['Referer'] = 'https://www.yna.co.kr/'
client = httpx.AsyncClient(
verify=False, # SSL 검증 비활성화
timeout=30.0,
follow_redirects=True
)
elif 'investing.com' in url:
headers['Referer'] = 'https://www.investing.com/'
client = httpx.AsyncClient()
elif 'naver.com' in url:
headers['Referer'] = 'https://news.naver.com/'
client = httpx.AsyncClient()
elif 'daum.net' in url:
headers['Referer'] = 'https://news.daum.net/'
client = httpx.AsyncClient()
elif 'chosun.com' in url:
headers['Referer'] = 'https://www.chosun.com/'
client = httpx.AsyncClient()
elif 'vietnam.vn' in url or 'vstatic.vietnam.vn' in url:
headers['Referer'] = 'https://vietnam.vn/'
client = httpx.AsyncClient()
elif 'ddaily.co.kr' in url:
# ddaily는 /photos/ 경로를 사용해야 함
headers['Referer'] = 'https://www.ddaily.co.kr/'
# URL이 잘못된 경로를 사용하는 경우 수정
if '/2025/' in url and '/photos/' not in url:
url = url.replace('/2025/', '/photos/2025/')
print(f"Fixed ddaily URL: {url}")
client = httpx.AsyncClient()
else:
# 기본적으로 도메인 기반 Referer 사용
client = httpx.AsyncClient()
async with client:
try:
response = await client.get(
url,
headers=headers,
timeout=settings.request_timeout,
follow_redirects=True
)
response.raise_for_status()
except Exception as e:
# 모든 에러에 대해 Playwright 사용 시도
error_msg = str(e)
if isinstance(e, httpx.HTTPStatusError):
error_type = f"HTTP {e.response.status_code}"
elif isinstance(e, httpx.ConnectError):
error_type = "Connection Error"
elif isinstance(e, ssl.SSLError):
error_type = "SSL Error"
elif "resolve" in error_msg.lower() or "dns" in error_msg.lower():
error_type = "DNS Resolution Error"
else:
error_type = "Network Error"
print(f"{error_type} for {url}, trying with Playwright...")
# Playwright로 이미지 가져오기 시도
try:
from playwright.async_api import async_playwright
from PIL import Image
import io
async with async_playwright() as p:
# 브라우저 실행
browser = await p.chromium.launch(
headless=True,
args=['--no-sandbox', '--disable-setuid-sandbox']
)
# Referer 설정을 위한 도메인 추출
from urllib.parse import urlparse
parsed = urlparse(url)
referer_url = f"{parsed.scheme}://{parsed.netloc}/"
context = await browser.new_context(
viewport={'width': 1920, 'height': 1080},
user_agent='Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
extra_http_headers={
'Referer': referer_url
}
)
page = await context.new_page()
try:
# Response를 가로채기 위한 설정
image_data = None
async def handle_response(response):
nonlocal image_data
# 이미지 URL에 대한 응답 가로채기
if url in response.url or response.url == url:
try:
image_data = await response.body()
print(f"✅ Image intercepted: {len(image_data)} bytes")
except:
pass
# Response 이벤트 리스너 등록
page.on('response', handle_response)
# 이미지 URL로 이동 (에러 무시)
try:
await page.goto(url, wait_until='networkidle', timeout=30000)
except Exception as goto_error:
print(f"⚠️ Direct navigation failed: {goto_error}")
# 직접 이동 실패 시 HTML에 img 태그 삽입
await page.set_content(f'''
<html>
<body style="margin:0;padding:0;">
<img src="{url}" style="max-width:100%;height:auto;"
crossorigin="anonymous" />
</body>
</html>
''')
await page.wait_for_timeout(3000) # 이미지 로딩 대기
# 이미지 데이터가 없으면 JavaScript로 직접 fetch
if not image_data:
# JavaScript로 이미지 fetch
image_data_base64 = await page.evaluate('''
async (url) => {
try {
const response = await fetch(url);
const blob = await response.blob();
return new Promise((resolve) => {
const reader = new FileReader();
reader.onloadend = () => resolve(reader.result.split(',')[1]);
reader.readAsDataURL(blob);
});
} catch (e) {
return null;
}
}
''', url)
if image_data_base64:
import base64
image_data = base64.b64decode(image_data_base64)
print(f"✅ Image fetched via JavaScript: {len(image_data)} bytes")
# 여전히 데이터가 없으면 스크린샷 사용
if not image_data:
# 이미지 요소 찾기
img_element = await page.query_selector('img')
if img_element:
# 이미지가 로드되었는지 확인
is_loaded = await img_element.evaluate('(img) => img.complete && img.naturalHeight > 0')
if is_loaded:
image_data = await img_element.screenshot()
print(f"✅ Screenshot from loaded image: {len(image_data)} bytes")
else:
# 이미지 로드 대기
try:
await img_element.evaluate('(img) => new Promise(r => img.onload = r)')
image_data = await img_element.screenshot()
print(f"✅ Screenshot after waiting: {len(image_data)} bytes")
except:
# 전체 페이지 스크린샷
image_data = await page.screenshot(full_page=True)
print(f"⚠️ Full page screenshot: {len(image_data)} bytes")
else:
image_data = await page.screenshot(full_page=True)
print(f"⚠️ No image element, full screenshot: {len(image_data)} bytes")
print(f"✅ Successfully fetched image with Playwright: {url}")
return image_data
finally:
await page.close()
await context.close()
await browser.close()
except Exception as pw_error:
print(f"Playwright failed: {pw_error}, returning placeholder")
# Playwright도 실패하면 세련된 placeholder 반환
from PIL import Image, ImageDraw, ImageFont
import io
import random
# 그라디언트 배경색 선택 (부드러운 색상)
gradients = [
('#667eea', '#764ba2'), # 보라 그라디언트
('#f093fb', '#f5576c'), # 핑크 그라디언트
('#4facfe', '#00f2fe'), # 하늘색 그라디언트
('#43e97b', '#38f9d7'), # 민트 그라디언트
('#fa709a', '#fee140'), # 선셋 그라디언트
('#30cfd0', '#330867'), # 딥 오션
('#a8edea', '#fed6e3'), # 파스텔
('#ffecd2', '#fcb69f'), # 피치
]
# 랜덤 그라디언트 선택
color1, color2 = random.choice(gradients)
# 이미지 생성 (16:9 비율)
width, height = 800, 450
img = Image.new('RGB', (width, height))
draw = ImageDraw.Draw(img)
# 그라디언트 배경 생성
def hex_to_rgb(hex_color):
hex_color = hex_color.lstrip('#')
return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))
rgb1 = hex_to_rgb(color1)
rgb2 = hex_to_rgb(color2)
# 세로 그라디언트
for y in range(height):
ratio = y / height
r = int(rgb1[0] * (1 - ratio) + rgb2[0] * ratio)
g = int(rgb1[1] * (1 - ratio) + rgb2[1] * ratio)
b = int(rgb1[2] * (1 - ratio) + rgb2[2] * ratio)
draw.rectangle([(0, y), (width, y + 1)], fill=(r, g, b))
# 반투명 오버레이 추가 (깊이감)
overlay = Image.new('RGBA', (width, height), (0, 0, 0, 0))
overlay_draw = ImageDraw.Draw(overlay)
# 중앙 원형 그라디언트 효과
center_x, center_y = width // 2, height // 2
max_radius = min(width, height) // 3
for radius in range(max_radius, 0, -2):
opacity = int(255 * (1 - radius / max_radius) * 0.3)
overlay_draw.ellipse(
[(center_x - radius, center_y - radius),
(center_x + radius, center_y + radius)],
fill=(255, 255, 255, opacity)
)
# 이미지 아이콘 그리기 (산 모양)
icon_color = (255, 255, 255, 200)
icon_size = 80
icon_x = center_x
icon_y = center_y - 20
# 산 아이콘 (사진 이미지를 나타냄)
mountain_points = [
(icon_x - icon_size, icon_y + icon_size//2),
(icon_x - icon_size//2, icon_y - icon_size//4),
(icon_x - icon_size//4, icon_y),
(icon_x + icon_size//4, icon_y - icon_size//2),
(icon_x + icon_size, icon_y + icon_size//2),
]
overlay_draw.polygon(mountain_points, fill=icon_color)
# 태양/달 원
sun_radius = icon_size // 4
overlay_draw.ellipse(
[(icon_x - icon_size//2, icon_y - icon_size//2 - sun_radius),
(icon_x - icon_size//2 + sun_radius*2, icon_y - icon_size//2 + sun_radius)],
fill=icon_color
)
# 프레임 테두리
frame_margin = 40
overlay_draw.rectangle(
[(frame_margin, frame_margin),
(width - frame_margin, height - frame_margin)],
outline=(255, 255, 255, 150),
width=3
)
# 코너 장식
corner_size = 20
corner_width = 4
corners = [
(frame_margin, frame_margin),
(width - frame_margin - corner_size, frame_margin),
(frame_margin, height - frame_margin - corner_size),
(width - frame_margin - corner_size, height - frame_margin - corner_size)
]
for x, y in corners:
# 가로선
overlay_draw.rectangle(
[(x, y), (x + corner_size, y + corner_width)],
fill=(255, 255, 255, 200)
)
# 세로선
overlay_draw.rectangle(
[(x, y), (x + corner_width, y + corner_size)],
fill=(255, 255, 255, 200)
)
# "Image Loading..." 텍스트 (작게)
try:
# 시스템 폰트 시도
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 16)
except:
font = ImageFont.load_default()
text = "Image Loading..."
bbox = draw.textbbox((0, 0), text, font=font)
text_width = bbox[2] - bbox[0]
text_height = bbox[3] - bbox[1]
text_x = (width - text_width) // 2
text_y = center_y + icon_size
# 텍스트 그림자
for offset in [(2, 2), (-1, -1)]:
overlay_draw.text(
(text_x + offset[0], text_y + offset[1]),
text,
font=font,
fill=(0, 0, 0, 100)
)
# 텍스트 본체
overlay_draw.text(
(text_x, text_y),
text,
font=font,
fill=(255, 255, 255, 220)
)
# 오버레이 합성
img = Image.alpha_composite(img.convert('RGBA'), overlay).convert('RGB')
# 약간의 노이즈 추가 (텍스처)
pixels = img.load()
for _ in range(1000):
x = random.randint(0, width - 1)
y = random.randint(0, height - 1)
r, g, b = pixels[x, y]
brightness = random.randint(-20, 20)
pixels[x, y] = (
max(0, min(255, r + brightness)),
max(0, min(255, g + brightness)),
max(0, min(255, b + brightness))
)
# JPEG로 변환 (높은 품질)
output = io.BytesIO()
img.save(output, format='JPEG', quality=85, optimize=True)
return output.getvalue()
raise
# 이미지 크기 확인
content_length = int(response.headers.get('content-length', 0))
max_size = settings.max_image_size_mb * 1024 * 1024
if content_length > max_size:
raise ValueError(f"Image too large: {content_length} bytes")
# 응답 데이터 확인
content = response.content
print(f"Downloaded {len(content)} bytes from {url[:50]}...")
# gzip 압축 확인 및 해제
import gzip
if len(content) > 2 and content[:2] == b'\x1f\x8b':
print("📦 Gzip compressed data detected, decompressing...")
try:
content = gzip.decompress(content)
print(f"✅ Decompressed to {len(content)} bytes")
except Exception as e:
print(f"❌ Failed to decompress gzip: {e}")
# 처음 몇 바이트로 이미지 형식 확인
if len(content) > 10:
header = content[:12]
if header[:2] == b'\xff\xd8':
print("✅ JPEG image detected")
elif header[:8] == b'\x89PNG\r\n\x1a\n':
print("✅ PNG image detected")
elif header[:6] in (b'GIF87a', b'GIF89a'):
print("✅ GIF image detected")
elif header[:4] == b'RIFF' and header[8:12] == b'WEBP':
print("✅ WebP image detected")
elif b'<svg' in header or b'<?xml' in header:
print("✅ SVG image detected")
elif header[4:12] == b'ftypavif':
print("✅ AVIF image detected")
else:
print(f"⚠️ Unknown image format. Header: {header.hex()}")
return content
def resize_and_optimize_image(self, image_data: bytes, size: str) -> tuple[bytes, str]:
"""이미지 리사이징 및 최적화"""
if size not in settings.thumbnail_sizes:
raise ValueError(f"Invalid size: {size}")
target_size = settings.thumbnail_sizes[size]
# SVG 체크 - SVG는 리사이징하지 않고 그대로 반환
if self._is_svg(image_data):
return image_data, 'image/svg+xml'
# PIL로 이미지 열기
try:
img = Image.open(io.BytesIO(image_data))
except Exception as e:
# WebP 헤더 체크 (RIFF....WEBP)
header = image_data[:12] if len(image_data) >= 12 else image_data
if header[:4] == b'RIFF' and header[8:12] == b'WEBP':
print("🎨 WebP 이미지 감지됨, 변환 시도")
# WebP 형식이지만 PIL이 열지 못하는 경우
# Pillow-SIMD 또는 추가 라이브러리가 필요할 수 있음
try:
# 재시도
from PIL import WebPImagePlugin
img = Image.open(io.BytesIO(image_data))
except:
print("❌ WebP 이미지를 열 수 없음, 원본 반환")
return image_data, 'image/webp'
else:
raise e
# GIF 애니메이션 체크 및 처리
if getattr(img, "format", None) == "GIF":
return self._process_gif(image_data, target_size)
# WebP 형식 체크
original_format = getattr(img, "format", None)
is_webp = original_format == "WEBP"
# 원본 모드와 투명도 정보 저장
original_mode = img.mode
original_has_transparency = img.mode in ('RGBA', 'LA')
original_has_palette = img.mode == 'P'
# 팔레트 모드(P) 처리 - 간단하게 PIL의 기본 변환 사용
if img.mode == 'P':
# 팔레트 모드는 RGB로 직접 변환
# PIL의 convert 메서드가 팔레트를 올바르게 처리함
img = img.convert('RGB')
# 투명도가 있는 이미지 처리
if img.mode == 'RGBA':
# RGBA는 흰색 배경과 합성
background = Image.new('RGB', img.size, (255, 255, 255))
background.paste(img, mask=img.split()[-1])
img = background
elif img.mode == 'LA':
# LA(그레이스케일+알파)는 RGBA를 거쳐 RGB로
img = img.convert('RGBA')
background = Image.new('RGB', img.size, (255, 255, 255))
background.paste(img, mask=img.split()[-1])
img = background
elif img.mode == 'L':
# 그레이스케일은 RGB로 변환
img = img.convert('RGB')
elif img.mode not in ('RGB',):
# 기타 모드는 모두 RGB로 변환
img = img.convert('RGB')
# EXIF 방향 정보 처리 (RGB 변환 후에 수행)
try:
from PIL import ImageOps
img = ImageOps.exif_transpose(img)
except:
pass
# 메타데이터 제거는 스킵 (팔레트 모드 이미지에서 문제 발생)
# RGB로 변환되었으므로 이미 메타데이터는 대부분 제거됨
# 비율 유지하며 리사이징 (크롭 없이)
img_ratio = img.width / img.height
target_width = target_size[0]
target_height = target_size[1]
# 원본 비율을 유지하면서 목표 크기에 맞추기
# 너비 또는 높이 중 하나를 기준으로 비율 계산
if img.width > target_width or img.height > target_height:
# 너비 기준 리사이징
width_ratio = target_width / img.width
# 높이 기준 리사이징
height_ratio = target_height / img.height
# 둘 중 작은 비율 사용 (목표 크기를 넘지 않도록)
ratio = min(width_ratio, height_ratio)
new_width = int(img.width * ratio)
new_height = int(img.height * ratio)
# 큰 이미지를 작게 만들 때는 2단계 리샘플링으로 품질 향상
if img.width > new_width * 2 or img.height > new_height * 2:
# 1단계: 목표 크기의 2배로 먼저 축소
intermediate_width = new_width * 2
intermediate_height = new_height * 2
img = img.resize((intermediate_width, intermediate_height), Image.Resampling.LANCZOS)
# 최종 목표 크기로 리샘플링
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 샤프닝 적용 (작은 이미지에만)
if target_size[0] <= 400:
from PIL import ImageEnhance
enhancer = ImageEnhance.Sharpness(img)
img = enhancer.enhance(1.2)
# 바이트로 변환
output = io.BytesIO()
# 적응형 품질 계산 (이미지 크기에 따라 조정)
def get_adaptive_quality(base_quality: int, target_width: int) -> int:
"""이미지 크기에 따른 적응형 품질 계산"""
# 품질을 더 높게 설정하여 검정색 문제 해결
if target_width <= 150: # 썸네일
return min(base_quality + 10, 95)
elif target_width <= 360: # 카드
return min(base_quality + 5, 90)
elif target_width <= 800: # 상세
return base_quality # 85
else: # 히어로
return base_quality # 85
# WebP 변환 및 최적화 - 최고 압축률 설정
# WebP 입력은 JPEG로 변환 (WebP 리사이징 문제 회피)
if is_webp:
output_format = 'JPEG'
content_type = 'image/jpeg'
else:
output_format = 'WEBP' if settings.convert_to_webp else 'JPEG'
content_type = 'image/webp' if output_format == 'WEBP' else 'image/jpeg'
if output_format == 'WEBP':
# WebP 최적화: method=6(최고품질), lossless=False, exact=False
adaptive_quality = get_adaptive_quality(settings.webp_quality, target_size[0])
save_kwargs = {
'format': 'WEBP',
'quality': adaptive_quality,
'method': 6, # 최고 압축 알고리즘 (0-6)
'lossless': settings.webp_lossless,
'exact': False, # 약간의 품질 손실 허용하여 더 작은 크기
}
img.save(output, **save_kwargs)
elif original_has_transparency and not settings.convert_to_webp:
# PNG 최적화 (투명도가 있는 이미지)
save_kwargs = {
'format': 'PNG',
'optimize': settings.optimize_png,
'compress_level': settings.png_compress_level,
}
# 팔레트 모드로 변환 가능한지 확인 (256색 이하)
if settings.optimize_png:
try:
# 색상 수가 256개 이하이면 팔레트 모드로 변환
quantized = img.quantize(colors=256, method=Image.Quantize.MEDIANCUT)
if len(quantized.getcolors()) <= 256:
img = quantized
save_kwargs['format'] = 'PNG'
except:
pass
content_type = 'image/png'
img.save(output, **save_kwargs)
else:
# JPEG 최적화 설정 (기본값)
adaptive_quality = get_adaptive_quality(settings.jpeg_quality, target_size[0])
save_kwargs = {
'format': 'JPEG',
'quality': adaptive_quality,
'optimize': True,
'progressive': settings.progressive_jpeg,
}
img.save(output, **save_kwargs)
return output.getvalue(), content_type
async def get_cache_size(self) -> float:
"""현재 캐시 크기 (GB)"""
total_size = 0
for dirpath, dirnames, filenames in os.walk(self.cache_dir):
for filename in filenames:
filepath = os.path.join(dirpath, filename)
total_size += os.path.getsize(filepath)
return total_size / (1024 ** 3) # GB로 변환
async def cleanup_old_cache(self):
"""오래된 캐시 파일 정리"""
cutoff_time = datetime.now() - timedelta(days=settings.cache_ttl_days)
for dirpath, dirnames, filenames in os.walk(self.cache_dir):
for filename in filenames:
filepath = Path(dirpath) / filename
if filepath.stat().st_mtime < cutoff_time.timestamp():
filepath.unlink()
async def trigger_background_generation(self, url: str):
"""백그라운드에서 모든 크기의 이미지 생성 트리거"""
from .background_tasks import background_manager
# 백그라운드 작업 큐에 추가
asyncio.create_task(background_manager.add_task(url))
async def get_directory_stats(self) -> dict:
"""디렉토리 구조 통계 정보"""
total_files = 0
total_dirs = 0
files_per_dir = {}
for root, dirs, files in os.walk(self.cache_dir):
total_dirs += len(dirs)
total_files += len(files)
# 각 디렉토리의 파일 수 계산
rel_path = os.path.relpath(root, self.cache_dir)
depth = len(Path(rel_path).parts) if rel_path != '.' else 0
if files and depth == 3: # 3단계 디렉토리에서만 파일 수 계산
files_per_dir[rel_path] = len(files)
# 통계 계산
avg_files_per_dir = sum(files_per_dir.values()) / len(files_per_dir) if files_per_dir else 0
max_files_in_dir = max(files_per_dir.values()) if files_per_dir else 0
return {
"total_files": total_files,
"total_directories": total_dirs,
"average_files_per_directory": round(avg_files_per_dir, 2),
"max_files_in_single_directory": max_files_in_dir,
"directory_depth": 3
}
cache = ImageCache()

View File

@ -0,0 +1,54 @@
from pydantic_settings import BaseSettings
from pathlib import Path
class Settings(BaseSettings):
# 기본 설정
app_name: str = "Image Proxy Service"
debug: bool = True
# 캐시 설정 (MinIO 전환 시에도 로컬 임시 파일용)
cache_dir: Path = Path("/app/cache")
max_cache_size_gb: int = 10
cache_ttl_days: int = 30
# MinIO 설정
use_minio: bool = True # MinIO 사용 여부
minio_endpoint: str = "minio:9000"
minio_access_key: str = "minioadmin"
minio_secret_key: str = "minioadmin"
minio_bucket_name: str = "image-cache"
minio_secure: bool = False
# 이미지 설정
max_image_size_mb: int = 20
allowed_formats: list = ["jpg", "jpeg", "png", "gif", "webp", "svg"]
# 리사이징 설정 - 뉴스 카드 용도별 최적화
thumbnail_sizes: dict = {
"thumb": (150, 100), # 작은 썸네일 (3:2 비율)
"card": (360, 240), # 뉴스 카드용 (3:2 비율)
"list": (300, 200), # 리스트용 (3:2 비율)
"detail": (800, 533), # 상세 페이지용 (원본 비율 유지)
"hero": (1200, 800) # 히어로 이미지용 (원본 비율 유지)
}
# 이미지 최적화 설정 - 품질 보장하면서 최저 용량
jpeg_quality: int = 85 # JPEG 품질 (품질 향상)
webp_quality: int = 85 # WebP 품질 (품질 향상으로 검정색 문제 해결)
webp_lossless: bool = False # 무손실 압축 비활성화 (용량 최적화)
png_compress_level: int = 9 # PNG 최대 압축 (0-9, 9가 최고 압축)
convert_to_webp: bool = False # WebP 변환 임시 비활성화 (검정색 이미지 문제)
# 고급 최적화 설정
progressive_jpeg: bool = True # 점진적 JPEG (로딩 성능 향상)
strip_metadata: bool = True # EXIF 등 메타데이터 제거 (용량 절약)
optimize_png: bool = True # PNG 팔레트 최적화
# 외부 요청 설정
request_timeout: int = 30
user_agent: str = "ImageProxyService/1.0"
class Config:
env_file = ".env"
settings = Settings()

View File

@ -0,0 +1,414 @@
import hashlib
import os
from pathlib import Path
from datetime import datetime, timedelta
from typing import Optional, Tuple
import httpx
from PIL import Image
try:
from pillow_heif import register_heif_opener, register_avif_opener
register_heif_opener() # HEIF/HEIC 지원
register_avif_opener() # AVIF 지원
print("HEIF/AVIF support enabled successfully")
except ImportError:
print("Warning: pillow_heif not installed, HEIF/AVIF support disabled")
import io
import asyncio
import ssl
from minio import Minio
from minio.error import S3Error
import tempfile
from .config import settings
class MinIOImageCache:
def __init__(self):
# MinIO 클라이언트 초기화
self.client = Minio(
settings.minio_endpoint,
access_key=settings.minio_access_key,
secret_key=settings.minio_secret_key,
secure=settings.minio_secure
)
# 버킷 생성 (동기 호출)
self._ensure_bucket()
# 로컬 임시 디렉토리 (이미지 처리용)
self.temp_dir = Path(tempfile.gettempdir()) / "image_cache_temp"
self.temp_dir.mkdir(parents=True, exist_ok=True)
def _ensure_bucket(self):
"""버킷이 존재하는지 확인하고 없으면 생성"""
try:
if not self.client.bucket_exists(settings.minio_bucket_name):
self.client.make_bucket(settings.minio_bucket_name)
print(f"✅ Created MinIO bucket: {settings.minio_bucket_name}")
else:
print(f"✅ MinIO bucket exists: {settings.minio_bucket_name}")
except S3Error as e:
print(f"❌ Error creating bucket: {e}")
def _get_object_name(self, url: str, size: Optional[str] = None) -> str:
"""URL을 기반으로 MinIO 객체 이름 생성"""
url_hash = hashlib.md5(url.encode()).hexdigest()
# 3단계 디렉토리 구조 생성 (MinIO는 /를 디렉토리처럼 취급)
level1 = url_hash[:2]
level2 = url_hash[2:4]
level3 = url_hash[4:6]
# 크기별로 다른 파일명 사용
if size:
filename = f"{url_hash}_{size}"
else:
filename = url_hash
# 확장자 추출 (WebP로 저장되는 경우 .webp 사용)
if settings.convert_to_webp and size:
filename = f"{filename}.webp"
else:
ext = self._get_extension_from_url(url)
if ext:
filename = f"{filename}.{ext}"
# MinIO 객체 경로 생성
object_name = f"{level1}/{level2}/{level3}/{filename}"
return object_name
def _get_extension_from_url(self, url: str) -> Optional[str]:
"""URL에서 파일 확장자 추출"""
path = url.split('?')[0] # 쿼리 파라미터 제거
parts = path.split('.')
if len(parts) > 1:
ext = parts[-1].lower()
if ext in settings.allowed_formats:
return ext
return None
def _is_svg(self, data: bytes) -> bool:
"""SVG 파일인지 확인"""
if len(data) < 100:
return False
header = data[:1000].lower()
svg_signatures = [
b'<svg',
b'<?xml',
b'<!doctype svg'
]
for sig in svg_signatures:
if sig in header:
return True
return False
def _process_gif(self, gif_data: bytes, target_size: tuple) -> tuple[bytes, str]:
"""GIF 처리 - JPEG로 변환하여 안정적으로 처리"""
try:
img = Image.open(io.BytesIO(gif_data))
if img.mode != 'RGB':
if img.mode == 'P':
img = img.convert('RGBA')
if img.mode == 'RGBA':
background = Image.new('RGB', img.size, (255, 255, 255))
background.paste(img, mask=img.split()[3] if len(img.split()) == 4 else None)
img = background
elif img.mode != 'RGB':
img = img.convert('RGB')
# 리사이즈
img.thumbnail(target_size, Image.Resampling.LANCZOS)
# JPEG로 저장
output = io.BytesIO()
img.save(
output,
format='JPEG',
quality=settings.jpeg_quality,
optimize=True,
progressive=settings.progressive_jpeg
)
return output.getvalue(), 'image/jpeg'
except Exception as e:
print(f"GIF 처리 오류: {e}")
return gif_data, 'image/gif'
def resize_and_optimize_image(self, image_data: bytes, size: str) -> tuple[bytes, str]:
"""이미지 리사이징 및 최적화"""
try:
target_size = settings.thumbnail_sizes.get(size, settings.thumbnail_sizes["thumb"])
# 이미지 열기
img = Image.open(io.BytesIO(image_data))
# EXIF 회전 정보 처리
try:
from PIL import ImageOps
img = ImageOps.exif_transpose(img)
except:
pass
# 리사이즈 (원본 비율 유지)
img.thumbnail(target_size, Image.Resampling.LANCZOS)
# 출력 버퍼
output = io.BytesIO()
# WebP로 변환 설정이 활성화되어 있으면
if settings.convert_to_webp:
# RGBA를 RGB로 변환 (WebP는 투명도 지원하지만 일부 브라우저 호환성 문제)
if img.mode in ('RGBA', 'LA', 'P'):
# 투명 배경을 흰색으로
background = Image.new('RGB', img.size, (255, 255, 255))
if img.mode == 'P':
img = img.convert('RGBA')
background.paste(img, mask=img.split()[-1] if 'A' in img.mode else None)
img = background
elif img.mode != 'RGB':
img = img.convert('RGB')
# WebP로 저장
img.save(
output,
format='WEBP',
quality=settings.webp_quality,
lossless=settings.webp_lossless,
method=6 # 최고 압축
)
content_type = 'image/webp'
else:
# 원본 포맷 유지하면서 최적화
if img.format == 'PNG':
img.save(
output,
format='PNG',
compress_level=settings.png_compress_level,
optimize=settings.optimize_png
)
content_type = 'image/png'
else:
# JPEG로 변환
if img.mode != 'RGB':
img = img.convert('RGB')
img.save(
output,
format='JPEG',
quality=settings.jpeg_quality,
optimize=True,
progressive=settings.progressive_jpeg
)
content_type = 'image/jpeg'
return output.getvalue(), content_type
except Exception as e:
print(f"이미지 최적화 오류: {e}")
import traceback
traceback.print_exc()
return image_data, 'image/jpeg'
async def get(self, url: str, size: Optional[str] = None) -> Optional[bytes]:
"""MinIO에서 캐시된 이미지 가져오기"""
object_name = self._get_object_name(url, size)
try:
# MinIO에서 객체 가져오기
response = self.client.get_object(settings.minio_bucket_name, object_name)
data = response.read()
response.close()
response.release_conn()
print(f"✅ Cache HIT from MinIO: {object_name}")
return data
except S3Error as e:
if e.code == 'NoSuchKey':
print(f"📭 Cache MISS in MinIO: {object_name}")
return None
else:
print(f"❌ MinIO error: {e}")
return None
async def set(self, url: str, data: bytes, size: Optional[str] = None):
"""MinIO에 이미지 캐시 저장"""
object_name = self._get_object_name(url, size)
try:
# 바이트 데이터를 스트림으로 변환
data_stream = io.BytesIO(data)
data_length = len(data)
# content-type 결정
if url.lower().endswith('.svg') or self._is_svg(data):
content_type = 'image/svg+xml'
elif url.lower().endswith('.gif'):
content_type = 'image/gif'
elif settings.convert_to_webp and size:
content_type = 'image/webp'
else:
content_type = 'application/octet-stream'
# MinIO에 저장 (메타데이터는 ASCII만 지원하므로 URL 해시 사용)
self.client.put_object(
settings.minio_bucket_name,
object_name,
data_stream,
data_length,
content_type=content_type,
metadata={
'url_hash': hashlib.md5(url.encode()).hexdigest(),
'cached_at': datetime.utcnow().isoformat(),
'size_variant': size or 'original'
}
)
print(f"✅ Cached to MinIO: {object_name} ({data_length} bytes)")
except S3Error as e:
print(f"❌ Failed to cache to MinIO: {e}")
async def download_image(self, url: str) -> bytes:
"""외부 URL에서 이미지 다운로드"""
# SSL 검증 비활성화 (개발 환경용)
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
async with httpx.AsyncClient(
timeout=settings.request_timeout,
verify=False,
follow_redirects=True
) as client:
headers = {
"User-Agent": settings.user_agent,
"Accept": "image/webp,image/apng,image/*,*/*;q=0.8",
"Accept-Encoding": "gzip, deflate, br",
"Cache-Control": "no-cache",
"Referer": url.split('/')[0] + '//' + url.split('/')[2] if len(url.split('/')) > 2 else url
}
response = await client.get(url, headers=headers)
if response.status_code == 403:
headers["User-Agent"] = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36"
response = await client.get(url, headers=headers)
response.raise_for_status()
content_length = response.headers.get("content-length")
if content_length:
size_mb = int(content_length) / (1024 * 1024)
if size_mb > settings.max_image_size_mb:
raise ValueError(f"이미지 크기가 {settings.max_image_size_mb}MB를 초과합니다")
return response.content
async def get_cache_size(self) -> float:
"""MinIO 버킷 크기 조회 (GB)"""
try:
total_size = 0
objects = self.client.list_objects(settings.minio_bucket_name, recursive=True)
for obj in objects:
total_size += obj.size
return total_size / (1024 ** 3) # GB로 변환
except S3Error as e:
print(f"❌ Failed to get cache size: {e}")
return 0.0
async def get_directory_stats(self) -> dict:
"""MinIO 디렉토리 구조 통계"""
try:
total_files = 0
directories = set()
objects = self.client.list_objects(settings.minio_bucket_name, recursive=True)
for obj in objects:
total_files += 1
# 디렉토리 경로 추출
parts = obj.object_name.split('/')
if len(parts) > 1:
dir_path = '/'.join(parts[:-1])
directories.add(dir_path)
return {
"total_files": total_files,
"total_directories": len(directories),
"average_files_per_directory": total_files / max(len(directories), 1),
"bucket_name": settings.minio_bucket_name
}
except S3Error as e:
print(f"❌ Failed to get directory stats: {e}")
return {
"total_files": 0,
"total_directories": 0,
"average_files_per_directory": 0,
"bucket_name": settings.minio_bucket_name
}
async def cleanup_old_cache(self):
"""오래된 캐시 정리"""
try:
cutoff_date = datetime.utcnow() - timedelta(days=settings.cache_ttl_days)
deleted_count = 0
objects = self.client.list_objects(settings.minio_bucket_name, recursive=True)
for obj in objects:
# 객체의 마지막 수정 시간이 cutoff_date 이전이면 삭제
if obj.last_modified.replace(tzinfo=None) < cutoff_date:
self.client.remove_object(settings.minio_bucket_name, obj.object_name)
deleted_count += 1
print(f"🗑️ Deleted old cache: {obj.object_name}")
print(f"✅ Cleaned up {deleted_count} old cached files")
return deleted_count
except S3Error as e:
print(f"❌ Failed to cleanup cache: {e}")
return 0
async def trigger_background_generation(self, url: str):
"""백그라운드에서 다양한 크기 생성"""
asyncio.create_task(self._generate_all_sizes(url))
async def _generate_all_sizes(self, url: str):
"""모든 크기 버전 생성"""
try:
# 원본 이미지 다운로드
image_data = await self.download_image(url)
# SVG는 리사이징 불필요
if self._is_svg(image_data):
return
# 모든 크기 생성
for size_name in settings.thumbnail_sizes.keys():
# 이미 캐시되어 있는지 확인
existing = await self.get(url, size_name)
if not existing:
# 리사이징 및 최적화
if url.lower().endswith('.gif'):
resized_data, _ = self._process_gif(image_data, settings.thumbnail_sizes[size_name])
else:
resized_data, _ = self.resize_and_optimize_image(image_data, size_name)
# 캐시에 저장
await self.set(url, resized_data, size_name)
print(f"✅ Generated {size_name} version for {url}")
except Exception as e:
print(f"❌ Background generation failed for {url}: {e}")
# 싱글톤 인스턴스
cache = MinIOImageCache()

View File

@ -0,0 +1,65 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
import uvicorn
from datetime import datetime
from app.api.endpoints import router
from app.core.config import settings
@asynccontextmanager
async def lifespan(app: FastAPI):
# 시작 시
print("Images service starting...")
yield
# 종료 시
print("Images service stopping...")
app = FastAPI(
title="Images Service",
description="이미지 업로드, 프록시 및 캐싱 서비스",
version="2.0.0",
lifespan=lifespan
)
# CORS 설정
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 라우터 등록
app.include_router(router, prefix="/api/v1")
@app.get("/")
async def root():
return {
"service": "Images Service",
"version": "2.0.0",
"timestamp": datetime.now().isoformat(),
"endpoints": {
"proxy": "/api/v1/image?url=<image_url>&size=<optional_size>",
"upload": "/api/v1/upload",
"stats": "/api/v1/stats",
"cleanup": "/api/v1/cleanup"
}
}
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"service": "images",
"timestamp": datetime.now().isoformat()
}
if __name__ == "__main__":
uvicorn.run(
"main:app",
host="0.0.0.0",
port=8000,
reload=True
)

View File

@ -0,0 +1,12 @@
fastapi==0.109.0
uvicorn[standard]==0.27.0
httpx==0.26.0
pillow==10.2.0
pillow-heif==0.20.0
aiofiles==23.2.1
python-multipart==0.0.6
pydantic==2.5.3
pydantic-settings==2.1.0
motor==3.3.2
redis==5.0.1
minio==7.2.3

View File

@ -0,0 +1,21 @@
FROM python:3.11-slim
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y \
gcc \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements first for better caching
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY . .
# Expose port
EXPOSE 8000
# Run the application
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]

View File

@ -0,0 +1,335 @@
"""
Channel Handlers for different notification delivery methods
"""
import logging
import asyncio
from typing import Optional, Dict, Any
from models import Notification, NotificationStatus
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
import httpx
import json
logger = logging.getLogger(__name__)
class BaseChannelHandler:
"""Base class for channel handlers"""
async def send(self, notification: Notification) -> bool:
"""Send notification through the channel"""
raise NotImplementedError
async def verify_delivery(self, notification: Notification) -> bool:
"""Verify if notification was delivered"""
return True
class EmailHandler(BaseChannelHandler):
"""Email notification handler"""
def __init__(self, smtp_host: str, smtp_port: int, smtp_user: str, smtp_password: str):
self.smtp_host = smtp_host
self.smtp_port = smtp_port
self.smtp_user = smtp_user
self.smtp_password = smtp_password
async def send(self, notification: Notification) -> bool:
"""Send email notification"""
try:
# In production, would use async SMTP library
# For demo, we'll simulate email sending
logger.info(f"Sending email to user {notification.user_id}")
if not self.smtp_user or not self.smtp_password:
# Simulate sending without actual SMTP config
await asyncio.sleep(0.1) # Simulate network delay
logger.info(f"Email sent (simulated) to user {notification.user_id}")
return True
# Create message
msg = MIMEMultipart()
msg['From'] = self.smtp_user
msg['To'] = f"user_{notification.user_id}@example.com" # Would fetch actual email
msg['Subject'] = notification.title
# Add body
body = notification.message
if notification.data and "html_content" in notification.data:
msg.attach(MIMEText(notification.data["html_content"], 'html'))
else:
msg.attach(MIMEText(body, 'plain'))
# Send email (would be async in production)
# server = smtplib.SMTP(self.smtp_host, self.smtp_port)
# server.starttls()
# server.login(self.smtp_user, self.smtp_password)
# server.send_message(msg)
# server.quit()
logger.info(f"Email sent successfully to user {notification.user_id}")
return True
except Exception as e:
logger.error(f"Failed to send email: {e}")
return False
class SMSHandler(BaseChannelHandler):
"""SMS notification handler"""
def __init__(self, api_key: str, api_url: str):
self.api_key = api_key
self.api_url = api_url
self.client = httpx.AsyncClient()
async def send(self, notification: Notification) -> bool:
"""Send SMS notification"""
try:
# In production, would integrate with SMS provider (Twilio, etc.)
logger.info(f"Sending SMS to user {notification.user_id}")
if not self.api_key or not self.api_url:
# Simulate sending without actual API config
await asyncio.sleep(0.1) # Simulate network delay
logger.info(f"SMS sent (simulated) to user {notification.user_id}")
return True
# Would fetch user's phone number from database
phone_number = notification.data.get("phone") if notification.data else None
if not phone_number:
phone_number = "+1234567890" # Demo number
# Send SMS via API (example structure)
payload = {
"to": phone_number,
"message": f"{notification.title}\n{notification.message}",
"api_key": self.api_key
}
# response = await self.client.post(self.api_url, json=payload)
# return response.status_code == 200
# Simulate success
await asyncio.sleep(0.1)
logger.info(f"SMS sent successfully to user {notification.user_id}")
return True
except Exception as e:
logger.error(f"Failed to send SMS: {e}")
return False
class PushHandler(BaseChannelHandler):
"""Push notification handler (FCM/APNS)"""
def __init__(self, fcm_server_key: str):
self.fcm_server_key = fcm_server_key
self.fcm_url = "https://fcm.googleapis.com/fcm/send"
self.client = httpx.AsyncClient()
async def send(self, notification: Notification) -> bool:
"""Send push notification"""
try:
logger.info(f"Sending push notification to user {notification.user_id}")
if not self.fcm_server_key:
# Simulate sending without actual FCM config
await asyncio.sleep(0.1)
logger.info(f"Push notification sent (simulated) to user {notification.user_id}")
return True
# Would fetch user's device tokens from database
device_tokens = notification.data.get("device_tokens", []) if notification.data else []
if not device_tokens:
# Simulate with dummy token
device_tokens = ["dummy_token"]
# Send to each device token
for token in device_tokens:
payload = {
"to": token,
"notification": {
"title": notification.title,
"body": notification.message,
"icon": notification.data.get("icon") if notification.data else None,
"click_action": notification.data.get("click_action") if notification.data else None
},
"data": notification.data or {}
}
headers = {
"Authorization": f"key={self.fcm_server_key}",
"Content-Type": "application/json"
}
# response = await self.client.post(
# self.fcm_url,
# json=payload,
# headers=headers
# )
# Simulate success
await asyncio.sleep(0.05)
logger.info(f"Push notification sent successfully to user {notification.user_id}")
return True
except Exception as e:
logger.error(f"Failed to send push notification: {e}")
return False
class InAppHandler(BaseChannelHandler):
"""In-app notification handler"""
def __init__(self):
self.ws_server = None
def set_ws_server(self, ws_server):
"""Set WebSocket server for real-time delivery"""
self.ws_server = ws_server
async def send(self, notification: Notification) -> bool:
"""Send in-app notification"""
try:
logger.info(f"Sending in-app notification to user {notification.user_id}")
# Store notification in database (already done in manager)
# This would be retrieved when user logs in or requests notifications
# If WebSocket connection exists, send real-time
if self.ws_server:
await self.ws_server.send_to_user(
notification.user_id,
{
"type": "notification",
"notification": {
"id": notification.id,
"title": notification.title,
"message": notification.message,
"priority": notification.priority.value,
"category": notification.category.value if hasattr(notification, 'category') else "system",
"timestamp": notification.created_at.isoformat(),
"data": notification.data
}
}
)
logger.info(f"In-app notification sent successfully to user {notification.user_id}")
return True
except Exception as e:
logger.error(f"Failed to send in-app notification: {e}")
return False
class SlackHandler(BaseChannelHandler):
"""Slack notification handler"""
def __init__(self, webhook_url: Optional[str] = None):
self.webhook_url = webhook_url
self.client = httpx.AsyncClient()
async def send(self, notification: Notification) -> bool:
"""Send Slack notification"""
try:
logger.info(f"Sending Slack notification for user {notification.user_id}")
if not self.webhook_url:
# Simulate sending
await asyncio.sleep(0.1)
logger.info(f"Slack notification sent (simulated) for user {notification.user_id}")
return True
# Format message for Slack
slack_message = {
"text": notification.title,
"blocks": [
{
"type": "header",
"text": {
"type": "plain_text",
"text": notification.title
}
},
{
"type": "section",
"text": {
"type": "mrkdwn",
"text": notification.message
}
}
]
}
# Add additional fields if present
if notification.data:
fields = []
for key, value in notification.data.items():
if key not in ["html_content", "device_tokens"]:
fields.append({
"type": "mrkdwn",
"text": f"*{key}:* {value}"
})
if fields:
slack_message["blocks"].append({
"type": "section",
"fields": fields[:10] # Slack limits to 10 fields
})
# Send to Slack
# response = await self.client.post(self.webhook_url, json=slack_message)
# return response.status_code == 200
await asyncio.sleep(0.1)
logger.info(f"Slack notification sent successfully")
return True
except Exception as e:
logger.error(f"Failed to send Slack notification: {e}")
return False
class WebhookHandler(BaseChannelHandler):
"""Generic webhook notification handler"""
def __init__(self, default_webhook_url: Optional[str] = None):
self.default_webhook_url = default_webhook_url
self.client = httpx.AsyncClient()
async def send(self, notification: Notification) -> bool:
"""Send webhook notification"""
try:
# Get webhook URL from notification data or use default
webhook_url = None
if notification.data and "webhook_url" in notification.data:
webhook_url = notification.data["webhook_url"]
else:
webhook_url = self.default_webhook_url
if not webhook_url:
logger.warning("No webhook URL configured")
return False
logger.info(f"Sending webhook notification for user {notification.user_id}")
# Prepare payload
payload = {
"notification_id": notification.id,
"user_id": notification.user_id,
"title": notification.title,
"message": notification.message,
"priority": notification.priority.value,
"timestamp": notification.created_at.isoformat(),
"data": notification.data
}
# Send webhook
# response = await self.client.post(webhook_url, json=payload)
# return response.status_code in [200, 201, 202, 204]
# Simulate success
await asyncio.sleep(0.1)
logger.info(f"Webhook notification sent successfully")
return True
except Exception as e:
logger.error(f"Failed to send webhook notification: {e}")
return False

View File

@ -0,0 +1,514 @@
"""
Notification Service - Real-time Multi-channel Notifications
"""
from fastapi import FastAPI, HTTPException, Depends, BackgroundTasks, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import uvicorn
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any
import asyncio
import os
from contextlib import asynccontextmanager
import logging
# Import custom modules
from models import (
Notification, NotificationChannel, NotificationTemplate,
NotificationPreference, NotificationHistory, NotificationStatus,
NotificationPriority, CreateNotificationRequest, BulkNotificationRequest
)
from notification_manager import NotificationManager
from channel_handlers import EmailHandler, SMSHandler, PushHandler, InAppHandler
from websocket_server import WebSocketNotificationServer
from queue_manager import NotificationQueueManager
from template_engine import TemplateEngine
from preference_manager import PreferenceManager
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global instances
notification_manager = None
ws_server = None
queue_manager = None
template_engine = None
preference_manager = None
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
global notification_manager, ws_server, queue_manager, template_engine, preference_manager
try:
# Initialize Template Engine
template_engine = TemplateEngine()
await template_engine.load_templates()
logger.info("Template engine initialized")
# Initialize Preference Manager
preference_manager = PreferenceManager(
mongodb_url=os.getenv("MONGODB_URL", "mongodb://mongodb:27017"),
database_name="notifications"
)
await preference_manager.connect()
logger.info("Preference manager connected")
# Initialize Notification Queue Manager
queue_manager = NotificationQueueManager(
redis_url=os.getenv("REDIS_URL", "redis://redis:6379")
)
await queue_manager.connect()
logger.info("Queue manager connected")
# Initialize Channel Handlers
email_handler = EmailHandler(
smtp_host=os.getenv("SMTP_HOST", "smtp.gmail.com"),
smtp_port=int(os.getenv("SMTP_PORT", 587)),
smtp_user=os.getenv("SMTP_USER", ""),
smtp_password=os.getenv("SMTP_PASSWORD", "")
)
sms_handler = SMSHandler(
api_key=os.getenv("SMS_API_KEY", ""),
api_url=os.getenv("SMS_API_URL", "")
)
push_handler = PushHandler(
fcm_server_key=os.getenv("FCM_SERVER_KEY", "")
)
in_app_handler = InAppHandler()
# Initialize Notification Manager
notification_manager = NotificationManager(
channel_handlers={
NotificationChannel.EMAIL: email_handler,
NotificationChannel.SMS: sms_handler,
NotificationChannel.PUSH: push_handler,
NotificationChannel.IN_APP: in_app_handler
},
queue_manager=queue_manager,
template_engine=template_engine,
preference_manager=preference_manager
)
await notification_manager.start()
logger.info("Notification manager started")
# Initialize WebSocket Server
ws_server = WebSocketNotificationServer()
logger.info("WebSocket server initialized")
# Register in-app handler with WebSocket server
in_app_handler.set_ws_server(ws_server)
except Exception as e:
logger.error(f"Failed to start Notification service: {e}")
raise
yield
# Shutdown
if notification_manager:
await notification_manager.stop()
if queue_manager:
await queue_manager.close()
if preference_manager:
await preference_manager.close()
logger.info("Notification service shutdown complete")
app = FastAPI(
title="Notification Service",
description="Real-time Multi-channel Notification Service",
version="1.0.0",
lifespan=lifespan
)
# CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
async def root():
return {
"service": "Notification Service",
"status": "running",
"timestamp": datetime.now().isoformat()
}
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"service": "notifications",
"components": {
"queue_manager": "connected" if queue_manager and queue_manager.is_connected else "disconnected",
"preference_manager": "connected" if preference_manager and preference_manager.is_connected else "disconnected",
"notification_manager": "running" if notification_manager and notification_manager.is_running else "stopped",
"websocket_connections": len(ws_server.active_connections) if ws_server else 0
},
"timestamp": datetime.now().isoformat()
}
# Notification Endpoints
@app.post("/api/notifications/send")
async def send_notification(
request: CreateNotificationRequest,
background_tasks: BackgroundTasks
):
"""Send a single notification"""
try:
notification = await notification_manager.create_notification(
user_id=request.user_id,
title=request.title,
message=request.message,
channels=request.channels,
priority=request.priority,
data=request.data,
template_id=request.template_id,
schedule_at=request.schedule_at
)
if request.schedule_at and request.schedule_at > datetime.now():
# Schedule for later
await queue_manager.schedule_notification(notification, request.schedule_at)
return {
"notification_id": notification.id,
"status": "scheduled",
"scheduled_at": request.schedule_at.isoformat()
}
else:
# Send immediately
background_tasks.add_task(
notification_manager.send_notification,
notification
)
return {
"notification_id": notification.id,
"status": "queued"
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/notifications/send-bulk")
async def send_bulk_notifications(
request: BulkNotificationRequest,
background_tasks: BackgroundTasks
):
"""Send notifications to multiple users"""
try:
notifications = []
for user_id in request.user_ids:
notification = await notification_manager.create_notification(
user_id=user_id,
title=request.title,
message=request.message,
channels=request.channels,
priority=request.priority,
data=request.data,
template_id=request.template_id
)
notifications.append(notification)
# Queue all notifications
background_tasks.add_task(
notification_manager.send_bulk_notifications,
notifications
)
return {
"count": len(notifications),
"notification_ids": [n.id for n in notifications],
"status": "queued"
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/notifications/user/{user_id}")
async def get_user_notifications(
user_id: str,
status: Optional[NotificationStatus] = None,
channel: Optional[NotificationChannel] = None,
limit: int = Query(50, le=200),
offset: int = Query(0, ge=0)
):
"""Get notifications for a specific user"""
try:
notifications = await notification_manager.get_user_notifications(
user_id=user_id,
status=status,
channel=channel,
limit=limit,
offset=offset
)
return {
"notifications": notifications,
"count": len(notifications),
"limit": limit,
"offset": offset
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.patch("/api/notifications/{notification_id}/read")
async def mark_notification_read(notification_id: str):
"""Mark a notification as read"""
try:
success = await notification_manager.mark_as_read(notification_id)
if success:
return {"status": "marked_as_read", "notification_id": notification_id}
else:
raise HTTPException(status_code=404, detail="Notification not found")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.delete("/api/notifications/{notification_id}")
async def delete_notification(notification_id: str):
"""Delete a notification"""
try:
success = await notification_manager.delete_notification(notification_id)
if success:
return {"status": "deleted", "notification_id": notification_id}
else:
raise HTTPException(status_code=404, detail="Notification not found")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Template Endpoints
@app.get("/api/templates")
async def get_templates():
"""Get all notification templates"""
templates = await template_engine.get_all_templates()
return {"templates": templates}
@app.post("/api/templates")
async def create_template(template: NotificationTemplate):
"""Create a new notification template"""
try:
template_id = await template_engine.create_template(template)
return {"template_id": template_id, "status": "created"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.put("/api/templates/{template_id}")
async def update_template(template_id: str, template: NotificationTemplate):
"""Update an existing template"""
try:
success = await template_engine.update_template(template_id, template)
if success:
return {"status": "updated", "template_id": template_id}
else:
raise HTTPException(status_code=404, detail="Template not found")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Preference Endpoints
@app.get("/api/preferences/{user_id}")
async def get_user_preferences(user_id: str):
"""Get notification preferences for a user"""
try:
preferences = await preference_manager.get_user_preferences(user_id)
return {"user_id": user_id, "preferences": preferences}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.put("/api/preferences/{user_id}")
async def update_user_preferences(
user_id: str,
preferences: NotificationPreference
):
"""Update notification preferences for a user"""
try:
success = await preference_manager.update_user_preferences(user_id, preferences)
if success:
return {"status": "updated", "user_id": user_id}
else:
return {"status": "created", "user_id": user_id}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/preferences/{user_id}/unsubscribe/{category}")
async def unsubscribe_from_category(user_id: str, category: str):
"""Unsubscribe user from a notification category"""
try:
success = await preference_manager.unsubscribe_category(user_id, category)
if success:
return {"status": "unsubscribed", "user_id": user_id, "category": category}
else:
raise HTTPException(status_code=404, detail="User preferences not found")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# History and Analytics Endpoints
@app.get("/api/history")
async def get_notification_history(
user_id: Optional[str] = None,
channel: Optional[NotificationChannel] = None,
status: Optional[NotificationStatus] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
limit: int = Query(100, le=1000)
):
"""Get notification history with filters"""
try:
history = await notification_manager.get_notification_history(
user_id=user_id,
channel=channel,
status=status,
start_date=start_date,
end_date=end_date,
limit=limit
)
return {
"history": history,
"count": len(history),
"filters": {
"user_id": user_id,
"channel": channel,
"status": status,
"date_range": f"{start_date} to {end_date}" if start_date and end_date else None
}
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/analytics")
async def get_notification_analytics(
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None
):
"""Get notification analytics"""
try:
if not start_date:
start_date = datetime.now() - timedelta(days=7)
if not end_date:
end_date = datetime.now()
analytics = await notification_manager.get_analytics(start_date, end_date)
return analytics
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Queue Management Endpoints
@app.get("/api/queue/status")
async def get_queue_status():
"""Get current queue status"""
try:
status = await queue_manager.get_queue_status()
return status
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/queue/retry/{notification_id}")
async def retry_failed_notification(
notification_id: str,
background_tasks: BackgroundTasks
):
"""Retry a failed notification"""
try:
notification = await notification_manager.get_notification(notification_id)
if not notification:
raise HTTPException(status_code=404, detail="Notification not found")
if notification.status != NotificationStatus.FAILED:
raise HTTPException(status_code=400, detail="Only failed notifications can be retried")
background_tasks.add_task(
notification_manager.retry_notification,
notification
)
return {"status": "retry_queued", "notification_id": notification_id}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# WebSocket Endpoint
from fastapi import WebSocket, WebSocketDisconnect
@app.websocket("/ws/notifications/{user_id}")
async def websocket_notifications(websocket: WebSocket, user_id: str):
"""WebSocket endpoint for real-time notifications"""
await ws_server.connect(websocket, user_id)
try:
while True:
# Keep connection alive and handle incoming messages
data = await websocket.receive_text()
# Handle different message types
if data == "ping":
await websocket.send_text("pong")
elif data.startswith("read:"):
# Mark notification as read
notification_id = data.split(":")[1]
await notification_manager.mark_as_read(notification_id)
except WebSocketDisconnect:
ws_server.disconnect(user_id)
except Exception as e:
logger.error(f"WebSocket error for user {user_id}: {e}")
ws_server.disconnect(user_id)
# Device Token Management
@app.post("/api/devices/register")
async def register_device_token(
user_id: str,
device_token: str,
device_type: str = Query(..., regex="^(ios|android|web)$")
):
"""Register a device token for push notifications"""
try:
success = await notification_manager.register_device_token(
user_id=user_id,
device_token=device_token,
device_type=device_type
)
if success:
return {
"status": "registered",
"user_id": user_id,
"device_type": device_type
}
else:
raise HTTPException(status_code=500, detail="Failed to register device token")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.delete("/api/devices/{device_token}")
async def unregister_device_token(device_token: str):
"""Unregister a device token"""
try:
success = await notification_manager.unregister_device_token(device_token)
if success:
return {"status": "unregistered", "device_token": device_token}
else:
raise HTTPException(status_code=404, detail="Device token not found")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
uvicorn.run(
"main:app",
host="0.0.0.0",
port=8000,
reload=True
)

View File

@ -0,0 +1,201 @@
"""
Data models for Notification Service
"""
from pydantic import BaseModel, Field
from datetime import datetime
from typing import Optional, List, Dict, Any, Literal
from enum import Enum
class NotificationChannel(str, Enum):
"""Notification delivery channels"""
EMAIL = "email"
SMS = "sms"
PUSH = "push"
IN_APP = "in_app"
class NotificationStatus(str, Enum):
"""Notification status"""
PENDING = "pending"
SENT = "sent"
DELIVERED = "delivered"
READ = "read"
FAILED = "failed"
CANCELLED = "cancelled"
class NotificationPriority(str, Enum):
"""Notification priority levels"""
LOW = "low"
NORMAL = "normal"
HIGH = "high"
URGENT = "urgent"
class NotificationCategory(str, Enum):
"""Notification categories"""
SYSTEM = "system"
MARKETING = "marketing"
TRANSACTION = "transaction"
SOCIAL = "social"
SECURITY = "security"
UPDATE = "update"
class Notification(BaseModel):
"""Notification model"""
id: Optional[str] = Field(None, description="Unique notification ID")
user_id: str = Field(..., description="Target user ID")
title: str = Field(..., description="Notification title")
message: str = Field(..., description="Notification message")
channel: NotificationChannel = Field(..., description="Delivery channel")
status: NotificationStatus = Field(default=NotificationStatus.PENDING)
priority: NotificationPriority = Field(default=NotificationPriority.NORMAL)
category: NotificationCategory = Field(default=NotificationCategory.SYSTEM)
data: Optional[Dict[str, Any]] = Field(default=None, description="Additional data")
template_id: Optional[str] = Field(None, description="Template ID if using template")
scheduled_at: Optional[datetime] = Field(None, description="Scheduled delivery time")
sent_at: Optional[datetime] = Field(None, description="Actual sent time")
delivered_at: Optional[datetime] = Field(None, description="Delivery confirmation time")
read_at: Optional[datetime] = Field(None, description="Read time")
retry_count: int = Field(default=0, description="Number of retry attempts")
error_message: Optional[str] = Field(None, description="Error message if failed")
created_at: datetime = Field(default_factory=datetime.now)
updated_at: datetime = Field(default_factory=datetime.now)
class Config:
json_encoders = {
datetime: lambda v: v.isoformat()
}
class NotificationTemplate(BaseModel):
"""Notification template model"""
id: Optional[str] = Field(None, description="Template ID")
name: str = Field(..., description="Template name")
channel: NotificationChannel = Field(..., description="Target channel")
category: NotificationCategory = Field(..., description="Template category")
subject_template: Optional[str] = Field(None, description="Subject template (for email)")
body_template: str = Field(..., description="Body template with variables")
variables: List[str] = Field(default_factory=list, description="List of required variables")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Template metadata")
is_active: bool = Field(default=True, description="Template active status")
created_at: datetime = Field(default_factory=datetime.now)
updated_at: datetime = Field(default_factory=datetime.now)
class Config:
json_encoders = {
datetime: lambda v: v.isoformat()
}
class NotificationPreference(BaseModel):
"""User notification preferences"""
user_id: str = Field(..., description="User ID")
channels: Dict[NotificationChannel, bool] = Field(
default_factory=lambda: {
NotificationChannel.EMAIL: True,
NotificationChannel.SMS: False,
NotificationChannel.PUSH: True,
NotificationChannel.IN_APP: True
}
)
categories: Dict[NotificationCategory, bool] = Field(
default_factory=lambda: {
NotificationCategory.SYSTEM: True,
NotificationCategory.MARKETING: False,
NotificationCategory.TRANSACTION: True,
NotificationCategory.SOCIAL: True,
NotificationCategory.SECURITY: True,
NotificationCategory.UPDATE: True
}
)
quiet_hours: Optional[Dict[str, str]] = Field(
default=None,
description="Quiet hours configuration {start: 'HH:MM', end: 'HH:MM'}"
)
timezone: str = Field(default="UTC", description="User timezone")
language: str = Field(default="en", description="Preferred language")
email_frequency: Literal["immediate", "daily", "weekly"] = Field(default="immediate")
updated_at: datetime = Field(default_factory=datetime.now)
class Config:
json_encoders = {
datetime: lambda v: v.isoformat()
}
class NotificationHistory(BaseModel):
"""Notification history entry"""
notification_id: str
user_id: str
channel: NotificationChannel
status: NotificationStatus
title: str
message: str
sent_at: Optional[datetime]
delivered_at: Optional[datetime]
read_at: Optional[datetime]
error_message: Optional[str]
metadata: Dict[str, Any] = Field(default_factory=dict)
class Config:
json_encoders = {
datetime: lambda v: v.isoformat()
}
class CreateNotificationRequest(BaseModel):
"""Request model for creating notification"""
user_id: str
title: str
message: str
channels: List[NotificationChannel] = Field(default=[NotificationChannel.IN_APP])
priority: NotificationPriority = Field(default=NotificationPriority.NORMAL)
category: NotificationCategory = Field(default=NotificationCategory.SYSTEM)
data: Optional[Dict[str, Any]] = None
template_id: Optional[str] = None
schedule_at: Optional[datetime] = None
class BulkNotificationRequest(BaseModel):
"""Request model for bulk notifications"""
user_ids: List[str]
title: str
message: str
channels: List[NotificationChannel] = Field(default=[NotificationChannel.IN_APP])
priority: NotificationPriority = Field(default=NotificationPriority.NORMAL)
category: NotificationCategory = Field(default=NotificationCategory.SYSTEM)
data: Optional[Dict[str, Any]] = None
template_id: Optional[str] = None
class DeviceToken(BaseModel):
"""Device token for push notifications"""
user_id: str
token: str
device_type: Literal["ios", "android", "web"]
app_version: Optional[str] = None
created_at: datetime = Field(default_factory=datetime.now)
updated_at: datetime = Field(default_factory=datetime.now)
class Config:
json_encoders = {
datetime: lambda v: v.isoformat()
}
class NotificationStats(BaseModel):
"""Notification statistics"""
total_sent: int
total_delivered: int
total_read: int
total_failed: int
delivery_rate: float
read_rate: float
channel_stats: Dict[str, Dict[str, int]]
category_stats: Dict[str, Dict[str, int]]
period: str
class NotificationEvent(BaseModel):
"""Notification event for tracking"""
event_type: Literal["sent", "delivered", "read", "failed", "clicked"]
notification_id: str
user_id: str
channel: NotificationChannel
timestamp: datetime = Field(default_factory=datetime.now)
metadata: Dict[str, Any] = Field(default_factory=dict)
class Config:
json_encoders = {
datetime: lambda v: v.isoformat()
}

View File

@ -0,0 +1,375 @@
"""
Notification Manager - Core notification orchestration
"""
import asyncio
import logging
from datetime import datetime
from typing import List, Optional, Dict, Any
import uuid
from models import (
Notification, NotificationChannel, NotificationStatus,
NotificationPriority, NotificationHistory, NotificationPreference
)
logger = logging.getLogger(__name__)
class NotificationManager:
"""Manages notification creation, delivery, and tracking"""
def __init__(
self,
channel_handlers: Dict[NotificationChannel, Any],
queue_manager: Any,
template_engine: Any,
preference_manager: Any
):
self.channel_handlers = channel_handlers
self.queue_manager = queue_manager
self.template_engine = template_engine
self.preference_manager = preference_manager
self.is_running = False
self.notification_store = {} # In-memory store for demo
self.history_store = [] # In-memory history for demo
self.device_tokens = {} # In-memory device tokens for demo
async def start(self):
"""Start notification manager"""
self.is_running = True
# Start background tasks for processing queued notifications
asyncio.create_task(self._process_notification_queue())
asyncio.create_task(self._process_scheduled_notifications())
logger.info("Notification manager started")
async def stop(self):
"""Stop notification manager"""
self.is_running = False
logger.info("Notification manager stopped")
async def create_notification(
self,
user_id: str,
title: str,
message: str,
channels: List[NotificationChannel],
priority: NotificationPriority = NotificationPriority.NORMAL,
data: Optional[Dict[str, Any]] = None,
template_id: Optional[str] = None,
schedule_at: Optional[datetime] = None
) -> Notification:
"""Create a new notification"""
# Check user preferences
preferences = await self.preference_manager.get_user_preferences(user_id)
if preferences:
# Filter channels based on user preferences
channels = [ch for ch in channels if preferences.channels.get(ch, True)]
# Apply template if provided
if template_id:
template = await self.template_engine.get_template(template_id)
if template:
message = await self.template_engine.render_template(template, data or {})
# Create notification objects for each channel
notification = Notification(
id=str(uuid.uuid4()),
user_id=user_id,
title=title,
message=message,
channel=channels[0] if channels else NotificationChannel.IN_APP,
priority=priority,
data=data,
template_id=template_id,
scheduled_at=schedule_at,
created_at=datetime.now()
)
# Store notification
self.notification_store[notification.id] = notification
logger.info(f"Created notification {notification.id} for user {user_id}")
return notification
async def send_notification(self, notification: Notification):
"""Send a single notification"""
try:
# Check if notification should be sent now
if notification.scheduled_at and notification.scheduled_at > datetime.now():
await self.queue_manager.schedule_notification(notification, notification.scheduled_at)
return
# Get the appropriate handler
handler = self.channel_handlers.get(notification.channel)
if not handler:
raise ValueError(f"No handler for channel {notification.channel}")
# Send through the channel
success = await handler.send(notification)
if success:
notification.status = NotificationStatus.SENT
notification.sent_at = datetime.now()
logger.info(f"Notification {notification.id} sent successfully")
else:
notification.status = NotificationStatus.FAILED
notification.retry_count += 1
logger.error(f"Failed to send notification {notification.id}")
# Retry if needed
if notification.retry_count < self._get_max_retries(notification.priority):
await self.queue_manager.enqueue_notification(notification)
# Update notification
self.notification_store[notification.id] = notification
# Add to history
await self._add_to_history(notification)
except Exception as e:
logger.error(f"Error sending notification {notification.id}: {e}")
notification.status = NotificationStatus.FAILED
notification.error_message = str(e)
self.notification_store[notification.id] = notification
async def send_bulk_notifications(self, notifications: List[Notification]):
"""Send multiple notifications"""
tasks = []
for notification in notifications:
tasks.append(self.send_notification(notification))
await asyncio.gather(*tasks, return_exceptions=True)
async def mark_as_read(self, notification_id: str) -> bool:
"""Mark notification as read"""
notification = self.notification_store.get(notification_id)
if notification:
notification.status = NotificationStatus.READ
notification.read_at = datetime.now()
self.notification_store[notification_id] = notification
logger.info(f"Notification {notification_id} marked as read")
return True
return False
async def delete_notification(self, notification_id: str) -> bool:
"""Delete a notification"""
if notification_id in self.notification_store:
del self.notification_store[notification_id]
logger.info(f"Notification {notification_id} deleted")
return True
return False
async def get_notification(self, notification_id: str) -> Optional[Notification]:
"""Get a notification by ID"""
return self.notification_store.get(notification_id)
async def get_user_notifications(
self,
user_id: str,
status: Optional[NotificationStatus] = None,
channel: Optional[NotificationChannel] = None,
limit: int = 50,
offset: int = 0
) -> List[Notification]:
"""Get notifications for a user"""
notifications = []
for notification in self.notification_store.values():
if notification.user_id != user_id:
continue
if status and notification.status != status:
continue
if channel and notification.channel != channel:
continue
notifications.append(notification)
# Sort by created_at descending
notifications.sort(key=lambda x: x.created_at, reverse=True)
# Apply pagination
return notifications[offset:offset + limit]
async def retry_notification(self, notification: Notification):
"""Retry a failed notification"""
notification.retry_count += 1
notification.status = NotificationStatus.PENDING
notification.error_message = None
await self.send_notification(notification)
async def get_notification_history(
self,
user_id: Optional[str] = None,
channel: Optional[NotificationChannel] = None,
status: Optional[NotificationStatus] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
limit: int = 100
) -> List[NotificationHistory]:
"""Get notification history"""
history = []
for entry in self.history_store:
if user_id and entry.user_id != user_id:
continue
if channel and entry.channel != channel:
continue
if status and entry.status != status:
continue
if start_date and entry.sent_at and entry.sent_at < start_date:
continue
if end_date and entry.sent_at and entry.sent_at > end_date:
continue
history.append(entry)
# Sort by sent_at descending and limit
history.sort(key=lambda x: x.sent_at or datetime.min, reverse=True)
return history[:limit]
async def get_analytics(self, start_date: datetime, end_date: datetime) -> Dict[str, Any]:
"""Get notification analytics"""
total_sent = 0
total_delivered = 0
total_read = 0
total_failed = 0
channel_stats = {}
for notification in self.notification_store.values():
if notification.created_at < start_date or notification.created_at > end_date:
continue
if notification.status == NotificationStatus.SENT:
total_sent += 1
elif notification.status == NotificationStatus.DELIVERED:
total_delivered += 1
elif notification.status == NotificationStatus.READ:
total_read += 1
elif notification.status == NotificationStatus.FAILED:
total_failed += 1
# Channel stats
channel_name = notification.channel.value
if channel_name not in channel_stats:
channel_stats[channel_name] = {
"sent": 0,
"delivered": 0,
"read": 0,
"failed": 0
}
if notification.status == NotificationStatus.SENT:
channel_stats[channel_name]["sent"] += 1
elif notification.status == NotificationStatus.DELIVERED:
channel_stats[channel_name]["delivered"] += 1
elif notification.status == NotificationStatus.READ:
channel_stats[channel_name]["read"] += 1
elif notification.status == NotificationStatus.FAILED:
channel_stats[channel_name]["failed"] += 1
total = total_sent + total_delivered + total_read + total_failed
return {
"period": f"{start_date.isoformat()} to {end_date.isoformat()}",
"total_notifications": total,
"total_sent": total_sent,
"total_delivered": total_delivered,
"total_read": total_read,
"total_failed": total_failed,
"delivery_rate": (total_delivered / total * 100) if total > 0 else 0,
"read_rate": (total_read / total * 100) if total > 0 else 0,
"channel_stats": channel_stats
}
async def register_device_token(
self,
user_id: str,
device_token: str,
device_type: str
) -> bool:
"""Register a device token for push notifications"""
if user_id not in self.device_tokens:
self.device_tokens[user_id] = []
# Check if token already exists
for token in self.device_tokens[user_id]:
if token["token"] == device_token:
# Update existing token
token["device_type"] = device_type
token["updated_at"] = datetime.now()
return True
# Add new token
self.device_tokens[user_id].append({
"token": device_token,
"device_type": device_type,
"created_at": datetime.now(),
"updated_at": datetime.now()
})
logger.info(f"Registered device token for user {user_id}")
return True
async def unregister_device_token(self, device_token: str) -> bool:
"""Unregister a device token"""
for user_id, tokens in self.device_tokens.items():
for i, token in enumerate(tokens):
if token["token"] == device_token:
del self.device_tokens[user_id][i]
logger.info(f"Unregistered device token for user {user_id}")
return True
return False
def _get_max_retries(self, priority: NotificationPriority) -> int:
"""Get max retries based on priority"""
retry_map = {
NotificationPriority.LOW: 1,
NotificationPriority.NORMAL: 3,
NotificationPriority.HIGH: 5,
NotificationPriority.URGENT: 10
}
return retry_map.get(priority, 3)
async def _add_to_history(self, notification: Notification):
"""Add notification to history"""
history_entry = NotificationHistory(
notification_id=notification.id,
user_id=notification.user_id,
channel=notification.channel,
status=notification.status,
title=notification.title,
message=notification.message,
sent_at=notification.sent_at,
delivered_at=notification.delivered_at,
read_at=notification.read_at,
error_message=notification.error_message,
metadata={"priority": notification.priority.value}
)
self.history_store.append(history_entry)
async def _process_notification_queue(self):
"""Process queued notifications"""
while self.is_running:
try:
# Get notification from queue
notification_data = await self.queue_manager.dequeue_notification()
if notification_data:
notification = Notification(**notification_data)
await self.send_notification(notification)
except Exception as e:
logger.error(f"Error processing notification queue: {e}")
await asyncio.sleep(1)
async def _process_scheduled_notifications(self):
"""Process scheduled notifications"""
while self.is_running:
try:
# Check for scheduled notifications
now = datetime.now()
for notification in self.notification_store.values():
if (notification.scheduled_at and
notification.scheduled_at <= now and
notification.status == NotificationStatus.PENDING):
await self.send_notification(notification)
except Exception as e:
logger.error(f"Error processing scheduled notifications: {e}")
await asyncio.sleep(10) # Check every 10 seconds

View File

@ -0,0 +1,340 @@
"""
Preference Manager for user notification preferences
"""
import logging
from typing import Optional, Dict, Any, List
from datetime import datetime
import motor.motor_asyncio
from models import NotificationPreference, NotificationChannel, NotificationCategory
logger = logging.getLogger(__name__)
class PreferenceManager:
"""Manages user notification preferences"""
def __init__(self, mongodb_url: str = "mongodb://mongodb:27017", database_name: str = "notifications"):
self.mongodb_url = mongodb_url
self.database_name = database_name
self.client = None
self.db = None
self.preferences_collection = None
self.is_connected = False
# In-memory cache for demo
self.preferences_cache = {}
async def connect(self):
"""Connect to MongoDB"""
try:
self.client = motor.motor_asyncio.AsyncIOMotorClient(self.mongodb_url)
self.db = self.client[self.database_name]
self.preferences_collection = self.db["preferences"]
# Test connection
await self.client.admin.command('ping')
self.is_connected = True
# Create indexes
await self._create_indexes()
logger.info("Connected to MongoDB for preferences")
except Exception as e:
logger.error(f"Failed to connect to MongoDB: {e}")
# Fallback to in-memory storage
self.is_connected = False
logger.warning("Using in-memory storage for preferences")
async def close(self):
"""Close MongoDB connection"""
if self.client:
self.client.close()
self.is_connected = False
logger.info("Disconnected from MongoDB")
async def _create_indexes(self):
"""Create database indexes"""
if self.preferences_collection:
try:
await self.preferences_collection.create_index("user_id", unique=True)
logger.info("Created indexes for preferences collection")
except Exception as e:
logger.error(f"Failed to create indexes: {e}")
async def get_user_preferences(self, user_id: str) -> Optional[NotificationPreference]:
"""Get notification preferences for a user"""
try:
# Check cache first
if user_id in self.preferences_cache:
return self.preferences_cache[user_id]
if self.is_connected and self.preferences_collection:
# Get from MongoDB
doc = await self.preferences_collection.find_one({"user_id": user_id})
if doc:
# Convert document to model
doc.pop('_id', None) # Remove MongoDB ID
preference = NotificationPreference(**doc)
# Update cache
self.preferences_cache[user_id] = preference
return preference
# Return default preferences if not found
return self._get_default_preferences(user_id)
except Exception as e:
logger.error(f"Failed to get preferences for user {user_id}: {e}")
return self._get_default_preferences(user_id)
async def update_user_preferences(
self,
user_id: str,
preferences: NotificationPreference
) -> bool:
"""Update notification preferences for a user"""
try:
preferences.user_id = user_id
preferences.updated_at = datetime.now()
# Update cache
self.preferences_cache[user_id] = preferences
if self.is_connected and self.preferences_collection:
# Convert to dict for MongoDB
pref_dict = preferences.dict()
# Upsert in MongoDB
result = await self.preferences_collection.update_one(
{"user_id": user_id},
{"$set": pref_dict},
upsert=True
)
logger.info(f"Updated preferences for user {user_id}")
return result.modified_count > 0 or result.upserted_id is not None
# If not connected, just use cache
return True
except Exception as e:
logger.error(f"Failed to update preferences for user {user_id}: {e}")
return False
async def unsubscribe_category(self, user_id: str, category: str) -> bool:
"""Unsubscribe user from a notification category"""
try:
preferences = await self.get_user_preferences(user_id)
if not preferences:
preferences = self._get_default_preferences(user_id)
# Update category preference
if hasattr(NotificationCategory, category.upper()):
cat_enum = NotificationCategory(category.lower())
preferences.categories[cat_enum] = False
# Save updated preferences
return await self.update_user_preferences(user_id, preferences)
return False
except Exception as e:
logger.error(f"Failed to unsubscribe user {user_id} from {category}: {e}")
return False
async def subscribe_category(self, user_id: str, category: str) -> bool:
"""Subscribe user to a notification category"""
try:
preferences = await self.get_user_preferences(user_id)
if not preferences:
preferences = self._get_default_preferences(user_id)
# Update category preference
if hasattr(NotificationCategory, category.upper()):
cat_enum = NotificationCategory(category.lower())
preferences.categories[cat_enum] = True
# Save updated preferences
return await self.update_user_preferences(user_id, preferences)
return False
except Exception as e:
logger.error(f"Failed to subscribe user {user_id} to {category}: {e}")
return False
async def enable_channel(self, user_id: str, channel: NotificationChannel) -> bool:
"""Enable a notification channel for user"""
try:
preferences = await self.get_user_preferences(user_id)
if not preferences:
preferences = self._get_default_preferences(user_id)
preferences.channels[channel] = True
return await self.update_user_preferences(user_id, preferences)
except Exception as e:
logger.error(f"Failed to enable channel {channel} for user {user_id}: {e}")
return False
async def disable_channel(self, user_id: str, channel: NotificationChannel) -> bool:
"""Disable a notification channel for user"""
try:
preferences = await self.get_user_preferences(user_id)
if not preferences:
preferences = self._get_default_preferences(user_id)
preferences.channels[channel] = False
return await self.update_user_preferences(user_id, preferences)
except Exception as e:
logger.error(f"Failed to disable channel {channel} for user {user_id}: {e}")
return False
async def set_quiet_hours(
self,
user_id: str,
start_time: str,
end_time: str
) -> bool:
"""Set quiet hours for user"""
try:
preferences = await self.get_user_preferences(user_id)
if not preferences:
preferences = self._get_default_preferences(user_id)
preferences.quiet_hours = {
"start": start_time,
"end": end_time
}
return await self.update_user_preferences(user_id, preferences)
except Exception as e:
logger.error(f"Failed to set quiet hours for user {user_id}: {e}")
return False
async def clear_quiet_hours(self, user_id: str) -> bool:
"""Clear quiet hours for user"""
try:
preferences = await self.get_user_preferences(user_id)
if not preferences:
preferences = self._get_default_preferences(user_id)
preferences.quiet_hours = None
return await self.update_user_preferences(user_id, preferences)
except Exception as e:
logger.error(f"Failed to clear quiet hours for user {user_id}: {e}")
return False
async def set_email_frequency(self, user_id: str, frequency: str) -> bool:
"""Set email notification frequency"""
try:
if frequency not in ["immediate", "daily", "weekly"]:
return False
preferences = await self.get_user_preferences(user_id)
if not preferences:
preferences = self._get_default_preferences(user_id)
preferences.email_frequency = frequency
return await self.update_user_preferences(user_id, preferences)
except Exception as e:
logger.error(f"Failed to set email frequency for user {user_id}: {e}")
return False
async def batch_get_preferences(self, user_ids: List[str]) -> Dict[str, NotificationPreference]:
"""Get preferences for multiple users"""
results = {}
for user_id in user_ids:
pref = await self.get_user_preferences(user_id)
if pref:
results[user_id] = pref
return results
async def delete_user_preferences(self, user_id: str) -> bool:
"""Delete all preferences for a user"""
try:
# Remove from cache
if user_id in self.preferences_cache:
del self.preferences_cache[user_id]
if self.is_connected and self.preferences_collection:
# Delete from MongoDB
result = await self.preferences_collection.delete_one({"user_id": user_id})
logger.info(f"Deleted preferences for user {user_id}")
return result.deleted_count > 0
return True
except Exception as e:
logger.error(f"Failed to delete preferences for user {user_id}: {e}")
return False
def _get_default_preferences(self, user_id: str) -> NotificationPreference:
"""Get default notification preferences"""
return NotificationPreference(
user_id=user_id,
channels={
NotificationChannel.EMAIL: True,
NotificationChannel.SMS: False,
NotificationChannel.PUSH: True,
NotificationChannel.IN_APP: True
},
categories={
NotificationCategory.SYSTEM: True,
NotificationCategory.MARKETING: False,
NotificationCategory.TRANSACTION: True,
NotificationCategory.SOCIAL: True,
NotificationCategory.SECURITY: True,
NotificationCategory.UPDATE: True
},
email_frequency="immediate",
timezone="UTC",
language="en"
)
async def is_notification_allowed(
self,
user_id: str,
channel: NotificationChannel,
category: NotificationCategory
) -> bool:
"""Check if notification is allowed based on preferences"""
preferences = await self.get_user_preferences(user_id)
if not preferences:
return True # Allow by default if no preferences
# Check channel preference
if not preferences.channels.get(channel, True):
return False
# Check category preference
if not preferences.categories.get(category, True):
return False
# Check quiet hours
if preferences.quiet_hours and channel != NotificationChannel.IN_APP:
# Would need to check current time against quiet hours
# For demo, we'll allow all
pass
return True

View File

@ -0,0 +1,304 @@
"""
Notification Queue Manager with priority support
"""
import logging
import json
import asyncio
from typing import Optional, Dict, Any, List
from datetime import datetime
import redis.asyncio as redis
from models import NotificationPriority
logger = logging.getLogger(__name__)
class NotificationQueueManager:
"""Manages notification queues with priority levels"""
def __init__(self, redis_url: str = "redis://redis:6379"):
self.redis_url = redis_url
self.redis_client = None
self.is_connected = False
# Queue names by priority
self.queue_names = {
NotificationPriority.URGENT: "notifications:queue:urgent",
NotificationPriority.HIGH: "notifications:queue:high",
NotificationPriority.NORMAL: "notifications:queue:normal",
NotificationPriority.LOW: "notifications:queue:low"
}
# Scheduled notifications sorted set
self.scheduled_key = "notifications:scheduled"
# Failed notifications queue (DLQ)
self.dlq_key = "notifications:dlq"
async def connect(self):
"""Connect to Redis"""
try:
self.redis_client = await redis.from_url(self.redis_url)
await self.redis_client.ping()
self.is_connected = True
logger.info("Connected to Redis for notification queue")
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
self.is_connected = False
raise
async def close(self):
"""Close Redis connection"""
if self.redis_client:
await self.redis_client.close()
self.is_connected = False
logger.info("Disconnected from Redis")
async def enqueue_notification(self, notification: Any, priority: Optional[NotificationPriority] = None):
"""Add notification to queue based on priority"""
if not self.is_connected:
logger.error("Redis not connected")
return False
try:
# Use notification's priority or provided priority
if priority is None:
priority = notification.priority if hasattr(notification, 'priority') else NotificationPriority.NORMAL
queue_name = self.queue_names.get(priority, self.queue_names[NotificationPriority.NORMAL])
# Serialize notification
notification_data = notification.dict() if hasattr(notification, 'dict') else notification
notification_json = json.dumps(notification_data, default=str)
# Add to appropriate queue
await self.redis_client.lpush(queue_name, notification_json)
logger.info(f"Enqueued notification to {queue_name}")
return True
except Exception as e:
logger.error(f"Failed to enqueue notification: {e}")
return False
async def dequeue_notification(self, timeout: int = 1) -> Optional[Dict[str, Any]]:
"""Dequeue notification with priority order"""
if not self.is_connected:
return None
try:
# Check queues in priority order
for priority in [NotificationPriority.URGENT, NotificationPriority.HIGH,
NotificationPriority.NORMAL, NotificationPriority.LOW]:
queue_name = self.queue_names[priority]
# Try to get from this queue
result = await self.redis_client.brpop(queue_name, timeout=timeout)
if result:
_, notification_json = result
notification_data = json.loads(notification_json)
logger.debug(f"Dequeued notification from {queue_name}")
return notification_data
return None
except Exception as e:
logger.error(f"Failed to dequeue notification: {e}")
return None
async def schedule_notification(self, notification: Any, scheduled_time: datetime):
"""Schedule a notification for future delivery"""
if not self.is_connected:
return False
try:
# Serialize notification
notification_data = notification.dict() if hasattr(notification, 'dict') else notification
notification_json = json.dumps(notification_data, default=str)
# Add to scheduled set with timestamp as score
timestamp = scheduled_time.timestamp()
await self.redis_client.zadd(self.scheduled_key, {notification_json: timestamp})
logger.info(f"Scheduled notification for {scheduled_time}")
return True
except Exception as e:
logger.error(f"Failed to schedule notification: {e}")
return False
async def get_due_notifications(self) -> List[Dict[str, Any]]:
"""Get notifications that are due for delivery"""
if not self.is_connected:
return []
try:
# Get current timestamp
now = datetime.now().timestamp()
# Get all notifications with score <= now
results = await self.redis_client.zrangebyscore(
self.scheduled_key,
min=0,
max=now,
withscores=False
)
notifications = []
for notification_json in results:
notification_data = json.loads(notification_json)
notifications.append(notification_data)
# Remove from scheduled set
await self.redis_client.zrem(self.scheduled_key, notification_json)
if notifications:
logger.info(f"Retrieved {len(notifications)} due notifications")
return notifications
except Exception as e:
logger.error(f"Failed to get due notifications: {e}")
return []
async def add_to_dlq(self, notification: Any, error_message: str):
"""Add failed notification to Dead Letter Queue"""
if not self.is_connected:
return False
try:
# Add error information
notification_data = notification.dict() if hasattr(notification, 'dict') else notification
notification_data['dlq_error'] = error_message
notification_data['dlq_timestamp'] = datetime.now().isoformat()
notification_json = json.dumps(notification_data, default=str)
# Add to DLQ
await self.redis_client.lpush(self.dlq_key, notification_json)
logger.info(f"Added notification to DLQ: {error_message}")
return True
except Exception as e:
logger.error(f"Failed to add to DLQ: {e}")
return False
async def get_dlq_notifications(self, limit: int = 10) -> List[Dict[str, Any]]:
"""Get notifications from Dead Letter Queue"""
if not self.is_connected:
return []
try:
# Get from DLQ
results = await self.redis_client.lrange(self.dlq_key, 0, limit - 1)
notifications = []
for notification_json in results:
notification_data = json.loads(notification_json)
notifications.append(notification_data)
return notifications
except Exception as e:
logger.error(f"Failed to get DLQ notifications: {e}")
return []
async def retry_dlq_notification(self, index: int) -> bool:
"""Retry a notification from DLQ"""
if not self.is_connected:
return False
try:
# Get notification at index
notification_json = await self.redis_client.lindex(self.dlq_key, index)
if not notification_json:
return False
# Parse and remove DLQ info
notification_data = json.loads(notification_json)
notification_data.pop('dlq_error', None)
notification_data.pop('dlq_timestamp', None)
# Re-enqueue
priority = NotificationPriority(notification_data.get('priority', 'normal'))
queue_name = self.queue_names[priority]
new_json = json.dumps(notification_data, default=str)
await self.redis_client.lpush(queue_name, new_json)
# Remove from DLQ
await self.redis_client.lrem(self.dlq_key, 1, notification_json)
logger.info(f"Retried DLQ notification at index {index}")
return True
except Exception as e:
logger.error(f"Failed to retry DLQ notification: {e}")
return False
async def get_queue_status(self) -> Dict[str, Any]:
"""Get current queue status"""
if not self.is_connected:
return {"status": "disconnected"}
try:
status = {
"status": "connected",
"queues": {},
"scheduled": 0,
"dlq": 0
}
# Get queue lengths
for priority, queue_name in self.queue_names.items():
length = await self.redis_client.llen(queue_name)
status["queues"][priority.value] = length
# Get scheduled count
status["scheduled"] = await self.redis_client.zcard(self.scheduled_key)
# Get DLQ count
status["dlq"] = await self.redis_client.llen(self.dlq_key)
return status
except Exception as e:
logger.error(f"Failed to get queue status: {e}")
return {"status": "error", "error": str(e)}
async def clear_queue(self, priority: NotificationPriority) -> bool:
"""Clear a specific priority queue"""
if not self.is_connected:
return False
try:
queue_name = self.queue_names[priority]
await self.redis_client.delete(queue_name)
logger.info(f"Cleared queue: {queue_name}")
return True
except Exception as e:
logger.error(f"Failed to clear queue: {e}")
return False
async def clear_all_queues(self) -> bool:
"""Clear all notification queues"""
if not self.is_connected:
return False
try:
# Clear all priority queues
for queue_name in self.queue_names.values():
await self.redis_client.delete(queue_name)
# Clear scheduled and DLQ
await self.redis_client.delete(self.scheduled_key)
await self.redis_client.delete(self.dlq_key)
logger.info("Cleared all notification queues")
return True
except Exception as e:
logger.error(f"Failed to clear all queues: {e}")
return False

View File

@ -0,0 +1,11 @@
fastapi==0.109.0
uvicorn[standard]==0.27.0
pydantic==2.5.3
python-dotenv==1.0.0
redis==5.0.1
motor==3.5.1
pymongo==4.6.1
httpx==0.26.0
websockets==12.0
aiofiles==23.2.1
python-multipart==0.0.6

View File

@ -0,0 +1,334 @@
"""
Template Engine for notification templates
"""
import logging
import re
from typing import Dict, Any, List, Optional
from datetime import datetime
import uuid
from models import NotificationTemplate, NotificationChannel, NotificationCategory
logger = logging.getLogger(__name__)
class TemplateEngine:
"""Manages and renders notification templates"""
def __init__(self):
self.templates = {} # In-memory storage for demo
self._load_default_templates()
async def load_templates(self):
"""Load templates from storage"""
# In production, would load from database
logger.info(f"Loaded {len(self.templates)} templates")
def _load_default_templates(self):
"""Load default system templates"""
default_templates = [
NotificationTemplate(
id="welcome",
name="Welcome Email",
channel=NotificationChannel.EMAIL,
category=NotificationCategory.SYSTEM,
subject_template="Welcome to {{app_name}}!",
body_template="""
Hi {{user_name}},
Welcome to {{app_name}}! We're excited to have you on board.
Here are some things you can do to get started:
- Complete your profile
- Explore our features
- Connect with other users
If you have any questions, feel free to reach out to our support team.
Best regards,
The {{app_name}} Team
""",
variables=["user_name", "app_name"]
),
NotificationTemplate(
id="password_reset",
name="Password Reset",
channel=NotificationChannel.EMAIL,
category=NotificationCategory.SECURITY,
subject_template="Password Reset Request",
body_template="""
Hi {{user_name}},
We received a request to reset your password for {{app_name}}.
Click the link below to reset your password:
{{reset_link}}
This link will expire in {{expiry_hours}} hours.
If you didn't request this, please ignore this email or contact support.
Best regards,
The {{app_name}} Team
""",
variables=["user_name", "app_name", "reset_link", "expiry_hours"]
),
NotificationTemplate(
id="order_confirmation",
name="Order Confirmation",
channel=NotificationChannel.EMAIL,
category=NotificationCategory.TRANSACTION,
subject_template="Order #{{order_id}} Confirmed",
body_template="""
Hi {{user_name}},
Your order #{{order_id}} has been confirmed!
Order Details:
- Total: {{order_total}}
- Items: {{item_count}}
- Estimated Delivery: {{delivery_date}}
You can track your order status at: {{tracking_link}}
Thank you for your purchase!
Best regards,
The {{app_name}} Team
""",
variables=["user_name", "app_name", "order_id", "order_total", "item_count", "delivery_date", "tracking_link"]
),
NotificationTemplate(
id="sms_verification",
name="SMS Verification",
channel=NotificationChannel.SMS,
category=NotificationCategory.SECURITY,
body_template="Your {{app_name}} verification code is: {{code}}. Valid for {{expiry_minutes}} minutes.",
variables=["app_name", "code", "expiry_minutes"]
),
NotificationTemplate(
id="push_reminder",
name="Push Reminder",
channel=NotificationChannel.PUSH,
category=NotificationCategory.UPDATE,
body_template="{{reminder_text}}",
variables=["reminder_text"]
),
NotificationTemplate(
id="in_app_alert",
name="In-App Alert",
channel=NotificationChannel.IN_APP,
category=NotificationCategory.SYSTEM,
body_template="{{alert_message}}",
variables=["alert_message"]
),
NotificationTemplate(
id="weekly_digest",
name="Weekly Digest",
channel=NotificationChannel.EMAIL,
category=NotificationCategory.MARKETING,
subject_template="Your Weekly {{app_name}} Digest",
body_template="""
Hi {{user_name}},
Here's what happened this week on {{app_name}}:
📊 Stats:
- New connections: {{new_connections}}
- Messages received: {{messages_count}}
- Activities completed: {{activities_count}}
🔥 Trending:
{{trending_items}}
💡 Tip of the week:
{{weekly_tip}}
See you next week!
The {{app_name}} Team
""",
variables=["user_name", "app_name", "new_connections", "messages_count", "activities_count", "trending_items", "weekly_tip"]
),
NotificationTemplate(
id="friend_request",
name="Friend Request",
channel=NotificationChannel.IN_APP,
category=NotificationCategory.SOCIAL,
body_template="{{sender_name}} sent you a friend request. {{personal_message}}",
variables=["sender_name", "personal_message"]
)
]
for template in default_templates:
self.templates[template.id] = template
async def create_template(self, template: NotificationTemplate) -> str:
"""Create a new template"""
if not template.id:
template.id = str(uuid.uuid4())
# Validate template
if not self._validate_template(template):
raise ValueError("Invalid template format")
# Extract variables from template
template.variables = self._extract_variables(template.body_template)
if template.subject_template:
template.variables.extend(self._extract_variables(template.subject_template))
template.variables = list(set(template.variables)) # Remove duplicates
# Store template
self.templates[template.id] = template
logger.info(f"Created template: {template.id}")
return template.id
async def update_template(self, template_id: str, template: NotificationTemplate) -> bool:
"""Update an existing template"""
if template_id not in self.templates:
return False
# Validate template
if not self._validate_template(template):
raise ValueError("Invalid template format")
# Update template
template.id = template_id
template.updated_at = datetime.now()
# Re-extract variables
template.variables = self._extract_variables(template.body_template)
if template.subject_template:
template.variables.extend(self._extract_variables(template.subject_template))
template.variables = list(set(template.variables))
self.templates[template_id] = template
logger.info(f"Updated template: {template_id}")
return True
async def get_template(self, template_id: str) -> Optional[NotificationTemplate]:
"""Get a template by ID"""
return self.templates.get(template_id)
async def get_all_templates(self) -> List[NotificationTemplate]:
"""Get all templates"""
return list(self.templates.values())
async def delete_template(self, template_id: str) -> bool:
"""Delete a template"""
if template_id in self.templates:
del self.templates[template_id]
logger.info(f"Deleted template: {template_id}")
return True
return False
async def render_template(self, template: NotificationTemplate, variables: Dict[str, Any]) -> str:
"""Render a template with variables"""
if not template:
raise ValueError("Template not provided")
# Start with body template
rendered = template.body_template
# Replace variables
for var_name in template.variables:
placeholder = f"{{{{{var_name}}}}}"
value = variables.get(var_name, f"[{var_name}]") # Default to placeholder if not provided
# Convert non-string values to string
if not isinstance(value, str):
value = str(value)
rendered = rendered.replace(placeholder, value)
# Clean up extra whitespace
rendered = re.sub(r'\n\s*\n', '\n\n', rendered.strip())
return rendered
async def render_subject(self, template: NotificationTemplate, variables: Dict[str, Any]) -> Optional[str]:
"""Render a template subject with variables"""
if not template or not template.subject_template:
return None
rendered = template.subject_template
# Replace variables
for var_name in self._extract_variables(template.subject_template):
placeholder = f"{{{{{var_name}}}}}"
value = variables.get(var_name, f"[{var_name}]")
if not isinstance(value, str):
value = str(value)
rendered = rendered.replace(placeholder, value)
return rendered
def _validate_template(self, template: NotificationTemplate) -> bool:
"""Validate template format"""
if not template.name or not template.body_template:
return False
# Check for basic template syntax
try:
# Check for balanced braces
open_count = template.body_template.count("{{")
close_count = template.body_template.count("}}")
if open_count != close_count:
return False
if template.subject_template:
open_count = template.subject_template.count("{{")
close_count = template.subject_template.count("}}")
if open_count != close_count:
return False
return True
except Exception as e:
logger.error(f"Template validation error: {e}")
return False
def _extract_variables(self, template_text: str) -> List[str]:
"""Extract variable names from template text"""
if not template_text:
return []
# Find all {{variable_name}} patterns
pattern = r'\{\{(\w+)\}\}'
matches = re.findall(pattern, template_text)
return list(set(matches)) # Return unique variable names
async def get_templates_by_channel(self, channel: NotificationChannel) -> List[NotificationTemplate]:
"""Get templates for a specific channel"""
return [t for t in self.templates.values() if t.channel == channel]
async def get_templates_by_category(self, category: NotificationCategory) -> List[NotificationTemplate]:
"""Get templates for a specific category"""
return [t for t in self.templates.values() if t.category == category]
async def clone_template(self, template_id: str, new_name: str) -> str:
"""Clone an existing template"""
original = self.templates.get(template_id)
if not original:
raise ValueError(f"Template {template_id} not found")
# Create new template
new_template = NotificationTemplate(
id=str(uuid.uuid4()),
name=new_name,
channel=original.channel,
category=original.category,
subject_template=original.subject_template,
body_template=original.body_template,
variables=original.variables.copy(),
metadata=original.metadata.copy(),
is_active=True,
created_at=datetime.now()
)
self.templates[new_template.id] = new_template
logger.info(f"Cloned template {template_id} to {new_template.id}")
return new_template.id

View File

@ -0,0 +1,268 @@
"""
Test script for Notification Service
"""
import asyncio
import httpx
import websockets
import json
from datetime import datetime, timedelta
BASE_URL = "http://localhost:8013"
WS_URL = "ws://localhost:8013/ws/notifications"
async def test_notification_api():
"""Test notification API endpoints"""
async with httpx.AsyncClient() as client:
print("\n🔔 Testing Notification Service API...")
# Test health check
print("\n1. Testing health check...")
response = await client.get(f"{BASE_URL}/health")
print(f"Health check: {response.json()}")
# Test sending single notification
print("\n2. Testing single notification...")
notification_data = {
"user_id": "test_user_123",
"title": "Welcome to Our App!",
"message": "Thank you for joining our platform. We're excited to have you!",
"channels": ["in_app", "email"],
"priority": "high",
"category": "system",
"data": {
"action_url": "https://example.com/welcome",
"icon": "welcome"
}
}
response = await client.post(
f"{BASE_URL}/api/notifications/send",
json=notification_data
)
notification_result = response.json()
print(f"Notification sent: {notification_result}")
notification_id = notification_result.get("notification_id")
# Test bulk notifications
print("\n3. Testing bulk notifications...")
bulk_data = {
"user_ids": ["user1", "user2", "user3"],
"title": "System Maintenance Notice",
"message": "We will be performing system maintenance tonight from 2-4 AM.",
"channels": ["in_app", "push"],
"priority": "normal",
"category": "update"
}
response = await client.post(
f"{BASE_URL}/api/notifications/send-bulk",
json=bulk_data
)
print(f"Bulk notifications: {response.json()}")
# Test scheduled notification
print("\n4. Testing scheduled notification...")
scheduled_time = datetime.now() + timedelta(minutes=5)
scheduled_data = {
"user_id": "test_user_123",
"title": "Reminder: Meeting in 5 minutes",
"message": "Your scheduled meeting is about to start.",
"channels": ["in_app", "push"],
"priority": "urgent",
"category": "system",
"schedule_at": scheduled_time.isoformat()
}
response = await client.post(
f"{BASE_URL}/api/notifications/send",
json=scheduled_data
)
print(f"Scheduled notification: {response.json()}")
# Test get user notifications
print("\n5. Testing get user notifications...")
response = await client.get(
f"{BASE_URL}/api/notifications/user/test_user_123"
)
notifications = response.json()
print(f"User notifications: Found {notifications['count']} notifications")
# Test mark as read
if notification_id:
print("\n6. Testing mark as read...")
response = await client.patch(
f"{BASE_URL}/api/notifications/{notification_id}/read"
)
print(f"Mark as read: {response.json()}")
# Test templates
print("\n7. Testing templates...")
response = await client.get(f"{BASE_URL}/api/templates")
templates = response.json()
print(f"Available templates: {len(templates['templates'])} templates")
# Test preferences
print("\n8. Testing user preferences...")
# Get preferences
response = await client.get(
f"{BASE_URL}/api/preferences/test_user_123"
)
print(f"Current preferences: {response.json()}")
# Update preferences
new_preferences = {
"user_id": "test_user_123",
"channels": {
"email": True,
"sms": False,
"push": True,
"in_app": True
},
"categories": {
"system": True,
"marketing": False,
"transaction": True,
"social": True,
"security": True,
"update": True
},
"email_frequency": "daily",
"timezone": "America/New_York",
"language": "en"
}
response = await client.put(
f"{BASE_URL}/api/preferences/test_user_123",
json=new_preferences
)
print(f"Update preferences: {response.json()}")
# Test unsubscribe
response = await client.post(
f"{BASE_URL}/api/preferences/test_user_123/unsubscribe/marketing"
)
print(f"Unsubscribe from marketing: {response.json()}")
# Test notification with template
print("\n9. Testing notification with template...")
template_notification = {
"user_id": "test_user_123",
"title": "Password Reset Request",
"message": "", # Will be filled by template
"channels": ["email"],
"priority": "high",
"category": "security",
"template_id": "password_reset",
"data": {
"user_name": "John Doe",
"app_name": "Our App",
"reset_link": "https://example.com/reset/abc123",
"expiry_hours": 24
}
}
response = await client.post(
f"{BASE_URL}/api/notifications/send",
json=template_notification
)
print(f"Template notification: {response.json()}")
# Test queue status
print("\n10. Testing queue status...")
response = await client.get(f"{BASE_URL}/api/queue/status")
print(f"Queue status: {response.json()}")
# Test analytics
print("\n11. Testing analytics...")
response = await client.get(f"{BASE_URL}/api/analytics")
analytics = response.json()
print(f"Analytics overview: {analytics}")
# Test notification history
print("\n12. Testing notification history...")
response = await client.get(
f"{BASE_URL}/api/history",
params={"user_id": "test_user_123", "limit": 10}
)
history = response.json()
print(f"Notification history: {history['count']} entries")
# Test device registration
print("\n13. Testing device registration...")
response = await client.post(
f"{BASE_URL}/api/devices/register",
params={
"user_id": "test_user_123",
"device_token": "dummy_token_12345",
"device_type": "ios"
}
)
print(f"Device registration: {response.json()}")
async def test_websocket():
"""Test WebSocket connection for real-time notifications"""
print("\n\n🌐 Testing WebSocket Connection...")
try:
uri = f"{WS_URL}/test_user_123"
async with websockets.connect(uri) as websocket:
print(f"Connected to WebSocket at {uri}")
# Listen for welcome message
message = await websocket.recv()
data = json.loads(message)
print(f"Welcome message: {data}")
# Send ping
await websocket.send("ping")
pong = await websocket.recv()
print(f"Ping response: {pong}")
# Send notification via API while connected
async with httpx.AsyncClient() as client:
notification_data = {
"user_id": "test_user_123",
"title": "Real-time Test",
"message": "This should appear in WebSocket!",
"channels": ["in_app"],
"priority": "normal"
}
response = await client.post(
f"{BASE_URL}/api/notifications/send",
json=notification_data
)
print(f"Sent notification: {response.json()}")
# Wait for real-time notification
print("Waiting for real-time notification...")
try:
notification = await asyncio.wait_for(websocket.recv(), timeout=5.0)
print(f"Received real-time notification: {json.loads(notification)}")
except asyncio.TimeoutError:
print("No real-time notification received (timeout)")
print("WebSocket test completed")
except Exception as e:
print(f"WebSocket error: {e}")
async def main():
"""Run all tests"""
print("=" * 60)
print("NOTIFICATION SERVICE TEST SUITE")
print("=" * 60)
# Test API endpoints
await test_notification_api()
# Test WebSocket
await test_websocket()
print("\n" + "=" * 60)
print("✅ All tests completed!")
print("=" * 60)
if __name__ == "__main__":
asyncio.run(main())

View File

@ -0,0 +1,194 @@
"""
WebSocket Server for real-time notifications
"""
import logging
import json
from typing import Dict, List
from fastapi import WebSocket
from datetime import datetime
logger = logging.getLogger(__name__)
class WebSocketNotificationServer:
"""Manages WebSocket connections for real-time notifications"""
def __init__(self):
# Store connections by user_id
self.active_connections: Dict[str, List[WebSocket]] = {}
self.connection_metadata: Dict[WebSocket, Dict] = {}
async def connect(self, websocket: WebSocket, user_id: str):
"""Accept a new WebSocket connection"""
await websocket.accept()
# Add to active connections
if user_id not in self.active_connections:
self.active_connections[user_id] = []
self.active_connections[user_id].append(websocket)
# Store metadata
self.connection_metadata[websocket] = {
"user_id": user_id,
"connected_at": datetime.now(),
"last_activity": datetime.now()
}
logger.info(f"WebSocket connected for user {user_id}. Total connections: {len(self.active_connections[user_id])}")
# Send welcome message
await self.send_welcome_message(websocket, user_id)
def disconnect(self, user_id: str):
"""Remove a WebSocket connection"""
if user_id in self.active_connections:
# Remove all connections for this user
for websocket in self.active_connections[user_id]:
if websocket in self.connection_metadata:
del self.connection_metadata[websocket]
del self.active_connections[user_id]
logger.info(f"WebSocket disconnected for user {user_id}")
async def send_to_user(self, user_id: str, message: Dict):
"""Send a message to all connections for a specific user"""
if user_id not in self.active_connections:
logger.debug(f"No active connections for user {user_id}")
return False
disconnected = []
for websocket in self.active_connections[user_id]:
try:
await websocket.send_json(message)
# Update last activity
if websocket in self.connection_metadata:
self.connection_metadata[websocket]["last_activity"] = datetime.now()
except Exception as e:
logger.error(f"Error sending to WebSocket for user {user_id}: {e}")
disconnected.append(websocket)
# Remove disconnected websockets
for ws in disconnected:
self.active_connections[user_id].remove(ws)
if ws in self.connection_metadata:
del self.connection_metadata[ws]
# Clean up if no more connections
if not self.active_connections[user_id]:
del self.active_connections[user_id]
return True
async def broadcast(self, message: Dict):
"""Broadcast a message to all connected users"""
for user_id in list(self.active_connections.keys()):
await self.send_to_user(user_id, message)
async def send_notification(self, user_id: str, notification: Dict):
"""Send a notification to a specific user"""
message = {
"type": "notification",
"timestamp": datetime.now().isoformat(),
"data": notification
}
return await self.send_to_user(user_id, message)
async def send_welcome_message(self, websocket: WebSocket, user_id: str):
"""Send a welcome message to newly connected user"""
welcome_message = {
"type": "connection",
"status": "connected",
"user_id": user_id,
"timestamp": datetime.now().isoformat(),
"message": "Connected to notification service"
}
try:
await websocket.send_json(welcome_message)
except Exception as e:
logger.error(f"Error sending welcome message: {e}")
def get_connection_count(self, user_id: str = None) -> int:
"""Get the number of active connections"""
if user_id:
return len(self.active_connections.get(user_id, []))
total = 0
for connections in self.active_connections.values():
total += len(connections)
return total
def get_connected_users(self) -> List[str]:
"""Get list of connected user IDs"""
return list(self.active_connections.keys())
async def send_system_message(self, user_id: str, message: str, severity: str = "info"):
"""Send a system message to a user"""
system_message = {
"type": "system",
"severity": severity,
"message": message,
"timestamp": datetime.now().isoformat()
}
return await self.send_to_user(user_id, system_message)
async def send_presence_update(self, user_id: str, status: str):
"""Send presence update to user's connections"""
presence_message = {
"type": "presence",
"user_id": user_id,
"status": status,
"timestamp": datetime.now().isoformat()
}
# Could send to friends/contacts if implemented
return await self.send_to_user(user_id, presence_message)
async def handle_ping(self, websocket: WebSocket):
"""Handle ping message from client"""
try:
await websocket.send_json({
"type": "pong",
"timestamp": datetime.now().isoformat()
})
# Update last activity
if websocket in self.connection_metadata:
self.connection_metadata[websocket]["last_activity"] = datetime.now()
except Exception as e:
logger.error(f"Error handling ping: {e}")
async def cleanup_stale_connections(self, timeout_minutes: int = 30):
"""Clean up stale connections that haven't been active"""
now = datetime.now()
stale_connections = []
for websocket, metadata in self.connection_metadata.items():
last_activity = metadata.get("last_activity")
if last_activity:
time_diff = (now - last_activity).total_seconds() / 60
if time_diff > timeout_minutes:
stale_connections.append({
"websocket": websocket,
"user_id": metadata.get("user_id")
})
# Remove stale connections
for conn in stale_connections:
user_id = conn["user_id"]
websocket = conn["websocket"]
if user_id in self.active_connections:
if websocket in self.active_connections[user_id]:
self.active_connections[user_id].remove(websocket)
# Clean up if no more connections
if not self.active_connections[user_id]:
del self.active_connections[user_id]
if websocket in self.connection_metadata:
del self.connection_metadata[websocket]
logger.info(f"Cleaned up stale connection for user {user_id}")
return len(stale_connections)

View File

@ -0,0 +1,21 @@
FROM python:3.11-slim
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y \
gcc \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements first for better caching
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY . .
# Expose port
EXPOSE 8000
# Run the application
CMD ["python", "main.py"]

View File

@ -0,0 +1,142 @@
from motor.motor_asyncio import AsyncIOMotorClient
from beanie import init_beanie
import os
from models import OAuthApplication, AuthorizationCode, AccessToken, OAuthScope, UserConsent
async def init_db():
client = AsyncIOMotorClient(os.getenv("MONGODB_URL", "mongodb://mongodb:27017"))
database = client[os.getenv("OAUTH_DB_NAME", "oauth_db")]
await init_beanie(
database=database,
document_models=[
OAuthApplication,
AuthorizationCode,
AccessToken,
OAuthScope,
UserConsent
]
)
# 기본 스코프 생성
await create_default_scopes()
async def create_default_scopes():
"""기본 OAuth 스코프 생성"""
default_scopes = [
# 기본 인증 스코프
{
"name": "openid",
"display_name": "OpenID Connect",
"description": "기본 사용자 인증 정보",
"is_default": True,
"requires_approval": False
},
{
"name": "profile",
"display_name": "프로필 정보",
"description": "이름, 프로필 이미지, 기본 정보 접근",
"is_default": True,
"requires_approval": True
},
{
"name": "email",
"display_name": "이메일 주소",
"description": "이메일 주소 및 인증 상태 확인",
"is_default": False,
"requires_approval": True
},
{
"name": "picture",
"display_name": "프로필 사진",
"description": "프로필 사진 및 썸네일 접근",
"is_default": False,
"requires_approval": True
},
# 사용자 데이터 접근 스코프
{
"name": "user:read",
"display_name": "사용자 정보 읽기",
"description": "사용자 프로필 및 설정 읽기",
"is_default": False,
"requires_approval": True
},
{
"name": "user:write",
"display_name": "사용자 정보 수정",
"description": "사용자 프로필 및 설정 수정",
"is_default": False,
"requires_approval": True
},
# 애플리케이션 관리 스코프
{
"name": "app:read",
"display_name": "애플리케이션 정보 읽기",
"description": "OAuth 애플리케이션 정보 조회",
"is_default": False,
"requires_approval": True
},
{
"name": "app:write",
"display_name": "애플리케이션 관리",
"description": "OAuth 애플리케이션 생성 및 수정",
"is_default": False,
"requires_approval": True
},
# 조직/팀 관련 스코프
{
"name": "org:read",
"display_name": "조직 정보 읽기",
"description": "소속 조직 및 팀 정보 조회",
"is_default": False,
"requires_approval": True
},
{
"name": "org:write",
"display_name": "조직 관리",
"description": "조직 설정 및 멤버 관리",
"is_default": False,
"requires_approval": True
},
# API 접근 스코프
{
"name": "api:read",
"display_name": "API 데이터 읽기",
"description": "API를 통한 데이터 조회",
"is_default": False,
"requires_approval": True
},
{
"name": "api:write",
"display_name": "API 데이터 쓰기",
"description": "API를 통한 데이터 생성/수정/삭제",
"is_default": False,
"requires_approval": True
},
# 특수 스코프
{
"name": "offline_access",
"display_name": "오프라인 액세스",
"description": "리프레시 토큰 발급 (장기 액세스)",
"is_default": False,
"requires_approval": True
},
{
"name": "admin",
"display_name": "관리자 권한",
"description": "전체 시스템 관리 권한",
"is_default": False,
"requires_approval": True
}
]
for scope_data in default_scopes:
existing = await OAuthScope.find_one(OAuthScope.name == scope_data["name"])
if not existing:
scope = OAuthScope(**scope_data)
await scope.create()

View File

@ -0,0 +1,591 @@
from fastapi import FastAPI, HTTPException, Depends, Form, Query, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse, JSONResponse
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import Optional, List, Dict
import uvicorn
import os
import sys
import logging
from database import init_db
from models import (
OAuthApplication, AuthorizationCode, AccessToken,
OAuthScope, UserConsent, GrantType, ResponseType
)
from utils import OAuthUtils, TokenGenerator, ScopeValidator
from pydantic import BaseModel, Field
from beanie import PydanticObjectId
sys.path.append('/app')
from shared.kafka import KafkaProducer, Event, EventType
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Pydantic models
class ApplicationCreate(BaseModel):
name: str
description: Optional[str] = None
redirect_uris: List[str]
website_url: Optional[str] = None
logo_url: Optional[str] = None
privacy_policy_url: Optional[str] = None
terms_url: Optional[str] = None
sso_enabled: Optional[bool] = False
sso_provider: Optional[str] = None
sso_config: Optional[Dict] = None
allowed_domains: Optional[List[str]] = None
class ApplicationUpdate(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
redirect_uris: Optional[List[str]] = None
website_url: Optional[str] = None
logo_url: Optional[str] = None
privacy_policy_url: Optional[str] = None
terms_url: Optional[str] = None
is_active: Optional[bool] = None
sso_enabled: Optional[bool] = None
sso_provider: Optional[str] = None
sso_config: Optional[Dict] = None
allowed_domains: Optional[List[str]] = None
class ApplicationResponse(BaseModel):
id: str
client_id: str
name: str
description: Optional[str]
redirect_uris: List[str]
allowed_scopes: List[str]
grant_types: List[str]
is_active: bool
is_trusted: bool
sso_enabled: bool
sso_provider: Optional[str]
allowed_domains: List[str]
website_url: Optional[str]
logo_url: Optional[str]
created_at: datetime
class TokenRequest(BaseModel):
grant_type: str
code: Optional[str] = None
redirect_uri: Optional[str] = None
client_id: Optional[str] = None
client_secret: Optional[str] = None
refresh_token: Optional[str] = None
scope: Optional[str] = None
code_verifier: Optional[str] = None
# Global Kafka producer
kafka_producer: Optional[KafkaProducer] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
global kafka_producer
await init_db()
# Initialize Kafka producer
try:
kafka_producer = KafkaProducer(
bootstrap_servers=os.getenv('KAFKA_BOOTSTRAP_SERVERS', 'kafka:9092')
)
await kafka_producer.start()
logger.info("Kafka producer initialized")
except Exception as e:
logger.warning(f"Failed to initialize Kafka producer: {e}")
kafka_producer = None
yield
# Shutdown
if kafka_producer:
await kafka_producer.stop()
app = FastAPI(
title="OAuth 2.0 Service",
description="OAuth 2.0 인증 서버 및 애플리케이션 관리",
version="1.0.0",
lifespan=lifespan
)
# CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Health check
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"service": "oauth",
"timestamp": datetime.now().isoformat()
}
# OAuth Application Management
@app.post("/applications", response_model=ApplicationResponse, status_code=201)
async def create_application(
app_data: ApplicationCreate,
current_user_id: str = "test_user" # TODO: Get from JWT token
):
"""새로운 OAuth 애플리케이션 등록"""
client_id = OAuthUtils.generate_client_id()
client_secret = OAuthUtils.generate_client_secret()
hashed_secret = OAuthUtils.hash_client_secret(client_secret)
# 기본 스코프 가져오기
default_scopes = await OAuthScope.find(OAuthScope.is_default == True).to_list()
allowed_scopes = [scope.name for scope in default_scopes]
application = OAuthApplication(
client_id=client_id,
client_secret=hashed_secret,
name=app_data.name,
description=app_data.description,
owner_id=current_user_id,
redirect_uris=app_data.redirect_uris,
allowed_scopes=allowed_scopes,
grant_types=[GrantType.AUTHORIZATION_CODE, GrantType.REFRESH_TOKEN],
sso_enabled=app_data.sso_enabled or False,
sso_provider=app_data.sso_provider,
sso_config=app_data.sso_config or {},
allowed_domains=app_data.allowed_domains or [],
website_url=app_data.website_url,
logo_url=app_data.logo_url,
privacy_policy_url=app_data.privacy_policy_url,
terms_url=app_data.terms_url
)
await application.create()
# 이벤트 발행
if kafka_producer:
event = Event(
event_type=EventType.TASK_CREATED,
service="oauth",
data={
"app_id": str(application.id),
"client_id": client_id,
"name": application.name,
"owner_id": current_user_id
}
)
await kafka_producer.send_event("oauth-events", event)
# 클라이언트 시크릿은 생성 시에만 반환
return {
**ApplicationResponse(
id=str(application.id),
client_id=application.client_id,
name=application.name,
description=application.description,
redirect_uris=application.redirect_uris,
allowed_scopes=application.allowed_scopes,
grant_types=[gt.value for gt in application.grant_types],
is_active=application.is_active,
is_trusted=application.is_trusted,
sso_enabled=application.sso_enabled,
sso_provider=application.sso_provider,
allowed_domains=application.allowed_domains,
website_url=application.website_url,
logo_url=application.logo_url,
created_at=application.created_at
).dict(),
"client_secret": client_secret # 최초 생성 시에만 반환
}
@app.get("/applications", response_model=List[ApplicationResponse])
async def list_applications(
owner_id: Optional[str] = None,
is_active: Optional[bool] = None
):
"""OAuth 애플리케이션 목록 조회"""
query = {}
if owner_id:
query["owner_id"] = owner_id
if is_active is not None:
query["is_active"] = is_active
applications = await OAuthApplication.find(query).to_list()
return [
ApplicationResponse(
id=str(app.id),
client_id=app.client_id,
name=app.name,
description=app.description,
redirect_uris=app.redirect_uris,
allowed_scopes=app.allowed_scopes,
grant_types=[gt.value for gt in app.grant_types],
is_active=app.is_active,
is_trusted=app.is_trusted,
sso_enabled=app.sso_enabled,
sso_provider=app.sso_provider,
allowed_domains=app.allowed_domains,
website_url=app.website_url,
logo_url=app.logo_url,
created_at=app.created_at
)
for app in applications
]
@app.get("/applications/{client_id}", response_model=ApplicationResponse)
async def get_application(client_id: str):
"""OAuth 애플리케이션 상세 조회"""
application = await OAuthApplication.find_one(OAuthApplication.client_id == client_id)
if not application:
raise HTTPException(status_code=404, detail="Application not found")
return ApplicationResponse(
id=str(application.id),
client_id=application.client_id,
name=application.name,
description=application.description,
redirect_uris=application.redirect_uris,
allowed_scopes=application.allowed_scopes,
grant_types=[gt.value for gt in application.grant_types],
is_active=application.is_active,
is_trusted=application.is_trusted,
sso_enabled=application.sso_enabled,
sso_provider=application.sso_provider,
allowed_domains=application.allowed_domains,
website_url=application.website_url,
logo_url=application.logo_url,
created_at=application.created_at
)
# OAuth 2.0 Authorization Endpoint
@app.get("/authorize")
async def authorize(
response_type: str = Query(..., description="응답 타입 (code, token)"),
client_id: str = Query(..., description="클라이언트 ID"),
redirect_uri: str = Query(..., description="리다이렉트 URI"),
scope: str = Query("", description="요청 스코프"),
state: Optional[str] = Query(None, description="상태 값"),
code_challenge: Optional[str] = Query(None, description="PKCE challenge"),
code_challenge_method: Optional[str] = Query("S256", description="PKCE method"),
current_user_id: str = "test_user" # TODO: Get from session/JWT
):
"""OAuth 2.0 인증 엔드포인트"""
# 애플리케이션 확인
application = await OAuthApplication.find_one(OAuthApplication.client_id == client_id)
if not application or not application.is_active:
raise HTTPException(status_code=400, detail="Invalid client")
# 리다이렉트 URI 확인
if redirect_uri not in application.redirect_uris:
raise HTTPException(status_code=400, detail="Invalid redirect URI")
# 스코프 검증
requested_scopes = ScopeValidator.parse_scope_string(scope)
valid_scopes = ScopeValidator.validate_scopes(requested_scopes, application.allowed_scopes)
# 사용자 동의 확인 (신뢰할 수 있는 앱이거나 이미 동의한 경우 건너뛰기)
if not application.is_trusted:
consent = await UserConsent.find_one(
UserConsent.user_id == current_user_id,
UserConsent.client_id == client_id
)
if not consent or set(valid_scopes) - set(consent.granted_scopes):
# TODO: 동의 화면으로 리다이렉트
pass
if response_type == "code":
# Authorization Code Flow
code = OAuthUtils.generate_authorization_code()
auth_code = AuthorizationCode(
code=code,
client_id=client_id,
user_id=current_user_id,
redirect_uri=redirect_uri,
scopes=valid_scopes,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
expires_at=datetime.now() + timedelta(minutes=10)
)
await auth_code.create()
# 리다이렉트 URL 생성
redirect_url = f"{redirect_uri}?code={code}"
if state:
redirect_url += f"&state={state}"
return RedirectResponse(url=redirect_url)
elif response_type == "token":
# Implicit Flow (권장하지 않음)
raise HTTPException(status_code=400, detail="Implicit flow not supported")
else:
raise HTTPException(status_code=400, detail="Unsupported response type")
# OAuth 2.0 Token Endpoint
@app.post("/token")
async def token(
grant_type: str = Form(...),
code: Optional[str] = Form(None),
redirect_uri: Optional[str] = Form(None),
client_id: Optional[str] = Form(None),
client_secret: Optional[str] = Form(None),
refresh_token: Optional[str] = Form(None),
scope: Optional[str] = Form(None),
code_verifier: Optional[str] = Form(None)
):
"""OAuth 2.0 토큰 엔드포인트"""
# 클라이언트 인증
if not client_id or not client_secret:
raise HTTPException(
status_code=401,
detail="Client authentication required",
headers={"WWW-Authenticate": "Basic"}
)
application = await OAuthApplication.find_one(OAuthApplication.client_id == client_id)
if not application or not OAuthUtils.verify_client_secret(client_secret, application.client_secret):
raise HTTPException(status_code=401, detail="Invalid client credentials")
if grant_type == "authorization_code":
# Authorization Code Grant
if not code or not redirect_uri:
raise HTTPException(status_code=400, detail="Missing required parameters")
auth_code = await AuthorizationCode.find_one(
AuthorizationCode.code == code,
AuthorizationCode.client_id == client_id
)
if not auth_code:
raise HTTPException(status_code=400, detail="Invalid authorization code")
if auth_code.used:
raise HTTPException(status_code=400, detail="Authorization code already used")
if auth_code.expires_at < datetime.now():
raise HTTPException(status_code=400, detail="Authorization code expired")
if auth_code.redirect_uri != redirect_uri:
raise HTTPException(status_code=400, detail="Redirect URI mismatch")
# PKCE 검증
if auth_code.code_challenge:
if not code_verifier:
raise HTTPException(status_code=400, detail="Code verifier required")
if not OAuthUtils.verify_pkce_challenge(
code_verifier,
auth_code.code_challenge,
auth_code.code_challenge_method
):
raise HTTPException(status_code=400, detail="Invalid code verifier")
# 코드를 사용됨으로 표시
auth_code.used = True
auth_code.used_at = datetime.now()
await auth_code.save()
# 토큰 생성
access_token = OAuthUtils.generate_access_token()
refresh_token = OAuthUtils.generate_refresh_token()
token_doc = AccessToken(
token=access_token,
refresh_token=refresh_token,
client_id=client_id,
user_id=auth_code.user_id,
scopes=auth_code.scopes,
expires_at=datetime.now() + timedelta(hours=1),
refresh_expires_at=datetime.now() + timedelta(days=30)
)
await token_doc.create()
return TokenGenerator.generate_token_response(
access_token=access_token,
expires_in=3600,
refresh_token=refresh_token,
scope=" ".join(auth_code.scopes)
)
elif grant_type == "refresh_token":
# Refresh Token Grant
if not refresh_token:
raise HTTPException(status_code=400, detail="Refresh token required")
token_doc = await AccessToken.find_one(
AccessToken.refresh_token == refresh_token,
AccessToken.client_id == client_id
)
if not token_doc:
raise HTTPException(status_code=400, detail="Invalid refresh token")
if token_doc.revoked:
raise HTTPException(status_code=400, detail="Token has been revoked")
if token_doc.refresh_expires_at and token_doc.refresh_expires_at < datetime.now():
raise HTTPException(status_code=400, detail="Refresh token expired")
# 기존 토큰 폐기
token_doc.revoked = True
token_doc.revoked_at = datetime.now()
await token_doc.save()
# 새 토큰 생성
new_access_token = OAuthUtils.generate_access_token()
new_refresh_token = OAuthUtils.generate_refresh_token()
new_token_doc = AccessToken(
token=new_access_token,
refresh_token=new_refresh_token,
client_id=client_id,
user_id=token_doc.user_id,
scopes=token_doc.scopes,
expires_at=datetime.now() + timedelta(hours=1),
refresh_expires_at=datetime.now() + timedelta(days=30)
)
await new_token_doc.create()
return TokenGenerator.generate_token_response(
access_token=new_access_token,
expires_in=3600,
refresh_token=new_refresh_token,
scope=" ".join(token_doc.scopes)
)
elif grant_type == "client_credentials":
# Client Credentials Grant
requested_scopes = ScopeValidator.parse_scope_string(scope) if scope else []
valid_scopes = ScopeValidator.validate_scopes(requested_scopes, application.allowed_scopes)
access_token = OAuthUtils.generate_access_token()
token_doc = AccessToken(
token=access_token,
client_id=client_id,
scopes=valid_scopes,
expires_at=datetime.now() + timedelta(hours=1)
)
await token_doc.create()
return TokenGenerator.generate_token_response(
access_token=access_token,
expires_in=3600,
scope=" ".join(valid_scopes)
)
else:
raise HTTPException(status_code=400, detail="Unsupported grant type")
# Token Introspection Endpoint
@app.post("/introspect")
async def introspect(
token: str = Form(...),
token_type_hint: Optional[str] = Form(None),
client_id: str = Form(...),
client_secret: str = Form(...)
):
"""토큰 검증 엔드포인트"""
# 클라이언트 인증
application = await OAuthApplication.find_one(OAuthApplication.client_id == client_id)
if not application or not OAuthUtils.verify_client_secret(client_secret, application.client_secret):
raise HTTPException(status_code=401, detail="Invalid client credentials")
# 토큰 조회
token_doc = await AccessToken.find_one(AccessToken.token == token)
if not token_doc or token_doc.revoked or token_doc.expires_at < datetime.now():
return {"active": False}
# 토큰 사용 시간 업데이트
token_doc.last_used_at = datetime.now()
await token_doc.save()
return {
"active": True,
"scope": " ".join(token_doc.scopes),
"client_id": token_doc.client_id,
"username": token_doc.user_id,
"exp": int(token_doc.expires_at.timestamp())
}
# Token Revocation Endpoint
@app.post("/revoke")
async def revoke(
token: str = Form(...),
token_type_hint: Optional[str] = Form(None),
client_id: str = Form(...),
client_secret: str = Form(...)
):
"""토큰 폐기 엔드포인트"""
# 클라이언트 인증
application = await OAuthApplication.find_one(OAuthApplication.client_id == client_id)
if not application or not OAuthUtils.verify_client_secret(client_secret, application.client_secret):
raise HTTPException(status_code=401, detail="Invalid client credentials")
# 토큰 조회 및 폐기
token_doc = await AccessToken.find_one(
AccessToken.token == token,
AccessToken.client_id == client_id
)
if token_doc and not token_doc.revoked:
token_doc.revoked = True
token_doc.revoked_at = datetime.now()
await token_doc.save()
# 이벤트 발행
if kafka_producer:
event = Event(
event_type=EventType.TASK_COMPLETED,
service="oauth",
data={
"action": "token_revoked",
"token_id": str(token_doc.id),
"client_id": client_id
}
)
await kafka_producer.send_event("oauth-events", event)
return {"status": "success"}
# Scopes Management
@app.get("/scopes")
async def list_scopes():
"""사용 가능한 스코프 목록 조회"""
scopes = await OAuthScope.find_all().to_list()
return [
{
"name": scope.name,
"display_name": scope.display_name,
"description": scope.description,
"is_default": scope.is_default,
"requires_approval": scope.requires_approval
}
for scope in scopes
]
if __name__ == "__main__":
uvicorn.run(
"main:app",
host="0.0.0.0",
port=8000,
reload=True
)

View File

@ -0,0 +1,126 @@
from beanie import Document, PydanticObjectId
from pydantic import BaseModel, Field, EmailStr
from typing import Optional, List, Dict
from datetime import datetime
from enum import Enum
class GrantType(str, Enum):
AUTHORIZATION_CODE = "authorization_code"
CLIENT_CREDENTIALS = "client_credentials"
PASSWORD = "password"
REFRESH_TOKEN = "refresh_token"
class ResponseType(str, Enum):
CODE = "code"
TOKEN = "token"
class TokenType(str, Enum):
BEARER = "Bearer"
class OAuthApplication(Document):
"""OAuth 2.0 클라이언트 애플리케이션"""
client_id: str = Field(..., unique=True, description="클라이언트 ID")
client_secret: str = Field(..., description="클라이언트 시크릿 (해시됨)")
name: str = Field(..., description="애플리케이션 이름")
description: Optional[str] = Field(None, description="애플리케이션 설명")
owner_id: str = Field(..., description="애플리케이션 소유자 ID")
redirect_uris: List[str] = Field(default_factory=list, description="허용된 리다이렉트 URI들")
allowed_scopes: List[str] = Field(default_factory=list, description="허용된 스코프들")
grant_types: List[GrantType] = Field(default_factory=lambda: [GrantType.AUTHORIZATION_CODE], description="허용된 grant types")
is_active: bool = Field(default=True, description="활성화 상태")
is_trusted: bool = Field(default=False, description="신뢰할 수 있는 앱 (자동 승인)")
# SSO 설정
sso_enabled: bool = Field(default=False, description="SSO 활성화 여부")
sso_provider: Optional[str] = Field(None, description="SSO 제공자 (google, github, saml 등)")
sso_config: Optional[Dict] = Field(default_factory=dict, description="SSO 설정 (provider별 설정)")
allowed_domains: List[str] = Field(default_factory=list, description="SSO 허용 도메인 (예: @company.com)")
website_url: Optional[str] = Field(None, description="애플리케이션 웹사이트")
logo_url: Optional[str] = Field(None, description="애플리케이션 로고 URL")
privacy_policy_url: Optional[str] = Field(None, description="개인정보 처리방침 URL")
terms_url: Optional[str] = Field(None, description="이용약관 URL")
created_at: datetime = Field(default_factory=datetime.now)
updated_at: datetime = Field(default_factory=datetime.now)
class Settings:
collection = "oauth_applications"
class AuthorizationCode(Document):
"""OAuth 2.0 인증 코드"""
code: str = Field(..., unique=True, description="인증 코드")
client_id: str = Field(..., description="클라이언트 ID")
user_id: str = Field(..., description="사용자 ID")
redirect_uri: str = Field(..., description="리다이렉트 URI")
scopes: List[str] = Field(default_factory=list, description="요청된 스코프")
code_challenge: Optional[str] = Field(None, description="PKCE code challenge")
code_challenge_method: Optional[str] = Field(None, description="PKCE challenge method")
expires_at: datetime = Field(..., description="만료 시간")
used: bool = Field(default=False, description="사용 여부")
used_at: Optional[datetime] = Field(None, description="사용 시간")
created_at: datetime = Field(default_factory=datetime.now)
class Settings:
collection = "authorization_codes"
class AccessToken(Document):
"""OAuth 2.0 액세스 토큰"""
token: str = Field(..., unique=True, description="액세스 토큰")
refresh_token: Optional[str] = Field(None, description="리프레시 토큰")
client_id: str = Field(..., description="클라이언트 ID")
user_id: Optional[str] = Field(None, description="사용자 ID (client credentials flow에서는 없음)")
token_type: TokenType = Field(default=TokenType.BEARER)
scopes: List[str] = Field(default_factory=list, description="부여된 스코프")
expires_at: datetime = Field(..., description="액세스 토큰 만료 시간")
refresh_expires_at: Optional[datetime] = Field(None, description="리프레시 토큰 만료 시간")
revoked: bool = Field(default=False, description="폐기 여부")
revoked_at: Optional[datetime] = Field(None, description="폐기 시간")
created_at: datetime = Field(default_factory=datetime.now)
last_used_at: Optional[datetime] = Field(None, description="마지막 사용 시간")
class Settings:
collection = "access_tokens"
class OAuthScope(Document):
"""OAuth 스코프 정의"""
name: str = Field(..., unique=True, description="스코프 이름 (예: read:profile)")
display_name: str = Field(..., description="표시 이름")
description: str = Field(..., description="스코프 설명")
is_default: bool = Field(default=False, description="기본 스코프 여부")
requires_approval: bool = Field(default=True, description="사용자 승인 필요 여부")
created_at: datetime = Field(default_factory=datetime.now)
class Settings:
collection = "oauth_scopes"
class UserConsent(Document):
"""사용자 동의 기록"""
user_id: str = Field(..., description="사용자 ID")
client_id: str = Field(..., description="클라이언트 ID")
granted_scopes: List[str] = Field(default_factory=list, description="승인된 스코프")
created_at: datetime = Field(default_factory=datetime.now)
updated_at: datetime = Field(default_factory=datetime.now)
expires_at: Optional[datetime] = Field(None, description="동의 만료 시간")
class Settings:
collection = "user_consents"
indexes = [
[("user_id", 1), ("client_id", 1)]
]

View File

@ -0,0 +1,11 @@
fastapi==0.109.0
uvicorn[standard]==0.27.0
pydantic[email]==2.5.3
pymongo==4.6.1
motor==3.3.2
beanie==1.23.6
authlib==1.3.0
python-jose[cryptography]==3.3.0
passlib[bcrypt]==1.7.4
python-multipart==0.0.6
aiokafka==0.10.0

View File

@ -0,0 +1,131 @@
import secrets
import hashlib
import base64
from datetime import datetime, timedelta
from typing import Optional, List
from passlib.context import CryptContext
from jose import JWTError, jwt
import os
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
class OAuthUtils:
@staticmethod
def generate_client_id() -> str:
"""클라이언트 ID 생성"""
return secrets.token_urlsafe(24)
@staticmethod
def generate_client_secret() -> str:
"""클라이언트 시크릿 생성"""
return secrets.token_urlsafe(32)
@staticmethod
def hash_client_secret(secret: str) -> str:
"""클라이언트 시크릿 해싱"""
return pwd_context.hash(secret)
@staticmethod
def verify_client_secret(plain_secret: str, hashed_secret: str) -> bool:
"""클라이언트 시크릿 검증"""
return pwd_context.verify(plain_secret, hashed_secret)
@staticmethod
def generate_authorization_code() -> str:
"""인증 코드 생성"""
return secrets.token_urlsafe(32)
@staticmethod
def generate_access_token() -> str:
"""액세스 토큰 생성"""
return secrets.token_urlsafe(32)
@staticmethod
def generate_refresh_token() -> str:
"""리프레시 토큰 생성"""
return secrets.token_urlsafe(48)
@staticmethod
def verify_pkce_challenge(verifier: str, challenge: str, method: str = "S256") -> bool:
"""PKCE challenge 검증"""
if method == "plain":
return verifier == challenge
elif method == "S256":
verifier_hash = hashlib.sha256(verifier.encode()).digest()
verifier_challenge = base64.urlsafe_b64encode(verifier_hash).decode().rstrip("=")
return verifier_challenge == challenge
return False
@staticmethod
def create_jwt_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""JWT 토큰 생성"""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
to_encode.update({"exp": expire})
secret_key = os.getenv("JWT_SECRET_KEY", "your-secret-key")
algorithm = os.getenv("JWT_ALGORITHM", "HS256")
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm)
return encoded_jwt
@staticmethod
def decode_jwt_token(token: str) -> Optional[dict]:
"""JWT 토큰 디코딩"""
try:
secret_key = os.getenv("JWT_SECRET_KEY", "your-secret-key")
algorithm = os.getenv("JWT_ALGORITHM", "HS256")
payload = jwt.decode(token, secret_key, algorithms=[algorithm])
return payload
except JWTError:
return None
class TokenGenerator:
@staticmethod
def generate_token_response(
access_token: str,
token_type: str = "Bearer",
expires_in: int = 3600,
refresh_token: Optional[str] = None,
scope: Optional[str] = None,
id_token: Optional[str] = None
) -> dict:
"""OAuth 2.0 토큰 응답 생성"""
response = {
"access_token": access_token,
"token_type": token_type,
"expires_in": expires_in
}
if refresh_token:
response["refresh_token"] = refresh_token
if scope:
response["scope"] = scope
if id_token:
response["id_token"] = id_token
return response
class ScopeValidator:
@staticmethod
def validate_scopes(requested_scopes: List[str], allowed_scopes: List[str]) -> List[str]:
"""요청된 스코프가 허용된 스코프에 포함되는지 검증"""
return [scope for scope in requested_scopes if scope in allowed_scopes]
@staticmethod
def has_scope(token_scopes: List[str], required_scope: str) -> bool:
"""토큰이 특정 스코프를 가지고 있는지 확인"""
return required_scope in token_scopes
@staticmethod
def parse_scope_string(scope_string: str) -> List[str]:
"""스코프 문자열을 리스트로 파싱"""
if not scope_string:
return []
return scope_string.strip().split()

View File

@ -0,0 +1,90 @@
# Pipeline Makefile
.PHONY: help build up down restart logs clean test monitor
help:
@echo "Pipeline Management Commands:"
@echo " make build - Build all Docker images"
@echo " make up - Start all services"
@echo " make down - Stop all services"
@echo " make restart - Restart all services"
@echo " make logs - View logs for all services"
@echo " make clean - Clean up containers and volumes"
@echo " make monitor - Open monitor dashboard"
@echo " make test - Test pipeline with sample keyword"
build:
docker-compose build
up:
docker-compose up -d
down:
docker-compose down
restart:
docker-compose restart
logs:
docker-compose logs -f
clean:
docker-compose down -v
docker system prune -f
monitor:
@echo "Opening monitor dashboard..."
@echo "Dashboard: http://localhost:8100"
@echo "API Docs: http://localhost:8100/docs"
test:
@echo "Testing pipeline with sample keyword..."
curl -X POST http://localhost:8100/api/keywords \
-H "Content-Type: application/json" \
-d '{"keyword": "테스트", "schedule": "30min"}'
@echo "\nTriggering immediate processing..."
curl -X POST http://localhost:8100/api/trigger/테스트
# Service-specific commands
scheduler-logs:
docker-compose logs -f scheduler
rss-logs:
docker-compose logs -f rss-collector
search-logs:
docker-compose logs -f google-search
summarizer-logs:
docker-compose logs -f ai-summarizer
assembly-logs:
docker-compose logs -f article-assembly
monitor-logs:
docker-compose logs -f monitor
# Database commands
redis-cli:
docker-compose exec redis redis-cli
mongo-shell:
docker-compose exec mongodb mongosh -u admin -p password123
# Queue management
queue-status:
@echo "Checking queue status..."
docker-compose exec redis redis-cli --raw LLEN queue:keyword
docker-compose exec redis redis-cli --raw LLEN queue:rss
docker-compose exec redis redis-cli --raw LLEN queue:search
docker-compose exec redis redis-cli --raw LLEN queue:summarize
docker-compose exec redis redis-cli --raw LLEN queue:assembly
queue-clear:
@echo "Clearing all queues..."
docker-compose exec redis redis-cli FLUSHDB
# Health check
health:
@echo "Checking service health..."
curl -s http://localhost:8100/api/health | python3 -m json.tool

154
services/pipeline/README.md Normal file
View File

@ -0,0 +1,154 @@
# News Pipeline System
비동기 큐 기반 뉴스 생성 파이프라인 시스템
## 아키텍처
```
Scheduler → RSS Collector → Google Search → AI Summarizer → Article Assembly → MongoDB
↓ ↓ ↓ ↓ ↓
Redis Queue Redis Queue Redis Queue Redis Queue Redis Queue
```
## 서비스 구성
### 1. Scheduler
- 30분마다 등록된 키워드 처리
- 오전 7시, 낮 12시, 저녁 6시 우선 처리
- MongoDB에서 키워드 로드 후 큐에 작업 생성
### 2. RSS Collector
- RSS 피드 수집 (Google News RSS)
- 7일간 중복 방지 (Redis Set)
- 키워드 관련성 필터링
### 3. Google Search
- RSS 아이템별 추가 검색 결과 수집
- 아이템당 최대 3개 결과
- 작업당 최대 5개 아이템 처리
### 4. AI Summarizer
- Claude Haiku로 빠른 요약 생성
- 200자 이내 한국어 요약
- 병렬 처리 지원 (3 workers)
### 5. Article Assembly
- Claude Sonnet으로 종합 기사 작성
- 1500자 이내 전문 기사
- MongoDB 저장 및 통계 업데이트
### 6. Monitor
- 실시간 파이프라인 모니터링
- 큐 상태, 워커 상태 확인
- REST API 제공 (포트 8100)
## 시작하기
### 1. 환경 변수 설정
```bash
# .env 파일 확인
CLAUDE_API_KEY=your_claude_api_key
GOOGLE_API_KEY=your_google_api_key
GOOGLE_SEARCH_ENGINE_ID=your_search_engine_id
```
### 2. 서비스 시작
```bash
cd pipeline
docker-compose up -d
```
### 3. 모니터링
```bash
# 로그 확인
docker-compose logs -f
# 특정 서비스 로그
docker-compose logs -f scheduler
# 모니터 API
curl http://localhost:8100/api/stats
```
## API 엔드포인트
### Monitor API (포트 8100)
- `GET /api/stats` - 전체 통계
- `GET /api/queues/{queue_name}` - 큐 상세 정보
- `GET /api/keywords` - 키워드 목록
- `POST /api/keywords` - 키워드 등록
- `DELETE /api/keywords/{id}` - 키워드 삭제
- `GET /api/articles` - 기사 목록
- `GET /api/articles/{id}` - 기사 상세
- `GET /api/workers` - 워커 상태
- `POST /api/trigger/{keyword}` - 수동 처리 트리거
- `GET /api/health` - 헬스 체크
## 키워드 등록 예시
```bash
# 새 키워드 등록
curl -X POST http://localhost:8100/api/keywords \
-H "Content-Type: application/json" \
-d '{"keyword": "인공지능", "schedule": "30min"}'
# 수동 처리 트리거
curl -X POST http://localhost:8100/api/trigger/인공지능
```
## 데이터베이스
### MongoDB Collections
- `keywords` - 등록된 키워드
- `articles` - 생성된 기사
- `keyword_stats` - 키워드별 통계
### Redis Keys
- `queue:*` - 작업 큐
- `processing:*` - 처리 중 작업
- `failed:*` - 실패한 작업
- `dedup:rss:*` - RSS 중복 방지
- `workers:*:active` - 활성 워커
## 트러블슈팅
### 큐 초기화
```bash
docker-compose exec redis redis-cli FLUSHDB
```
### 워커 재시작
```bash
docker-compose restart rss-collector
```
### 데이터베이스 접속
```bash
# MongoDB
docker-compose exec mongodb mongosh -u admin -p password123
# Redis
docker-compose exec redis redis-cli
```
## 스케일링
워커 수 조정:
```yaml
# docker-compose.yml
ai-summarizer:
deploy:
replicas: 5 # 워커 수 증가
```
## 모니터링 대시보드
브라우저에서 http://localhost:8100 접속하여 파이프라인 상태 확인
## 로그 레벨 설정
`.env` 파일에서 조정:
```
LOG_LEVEL=DEBUG # INFO, WARNING, ERROR
```

View File

@ -0,0 +1,19 @@
FROM python:3.11-slim
WORKDIR /app
# 의존성 설치
COPY ./ai-article-generator/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# 공통 모듈 복사
COPY ./shared /app/shared
# AI Article Generator 코드 복사
COPY ./ai-article-generator /app
# 환경변수
ENV PYTHONUNBUFFERED=1
# 실행
CMD ["python", "ai_article_generator.py"]

View File

@ -0,0 +1,300 @@
"""
AI Article Generator Service
Claude API를 사용한 뉴스 기사 생성 서비스
"""
import asyncio
import logging
import os
import sys
import json
from datetime import datetime
from typing import List, Dict, Any
from anthropic import AsyncAnthropic
from motor.motor_asyncio import AsyncIOMotorClient
# Import from shared module
from shared.models import PipelineJob, EnrichedItem, FinalArticle, Subtopic, Entities, NewsReference
from shared.queue_manager import QueueManager
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class AIArticleGeneratorWorker:
def __init__(self):
self.queue_manager = QueueManager(
redis_url=os.getenv("REDIS_URL", "redis://redis:6379")
)
self.claude_api_key = os.getenv("CLAUDE_API_KEY")
self.claude_client = None
self.mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
self.db_name = os.getenv("DB_NAME", "ai_writer_db") # ai_writer_db 사용
self.db = None
async def start(self):
"""워커 시작"""
logger.info("Starting AI Article Generator Worker")
# Redis 연결
await self.queue_manager.connect()
# MongoDB 연결
client = AsyncIOMotorClient(self.mongodb_url)
self.db = client[self.db_name]
# Claude 클라이언트 초기화
if self.claude_api_key:
self.claude_client = AsyncAnthropic(api_key=self.claude_api_key)
else:
logger.error("Claude API key not configured")
return
# 메인 처리 루프
while True:
try:
# 큐에서 작업 가져오기
job = await self.queue_manager.dequeue('ai_article_generation', timeout=5)
if job:
await self.process_job(job)
except Exception as e:
logger.error(f"Error in worker loop: {e}")
await asyncio.sleep(1)
async def process_job(self, job: PipelineJob):
"""AI 기사 생성 작업 처리 - 단일 RSS 아이템"""
try:
start_time = datetime.now()
logger.info(f"Processing job {job.job_id} for AI article generation")
# 단일 enriched item 처리
enriched_item_data = job.data.get('enriched_item')
if not enriched_item_data:
# 이전 버전 호환성
enriched_items = job.data.get('enriched_items', [])
if enriched_items:
enriched_item_data = enriched_items[0]
else:
logger.warning(f"No enriched item in job {job.job_id}")
await self.queue_manager.mark_failed(
'ai_article_generation',
job,
"No enriched item to process"
)
return
enriched_item = EnrichedItem(**enriched_item_data)
# 기사 생성
article = await self._generate_article(job, enriched_item)
# 처리 시간 계산
processing_time = (datetime.now() - start_time).total_seconds()
article.processing_time = processing_time
# MongoDB에 저장 (ai_writer_db.articles_ko)
result = await self.db.articles_ko.insert_one(article.model_dump())
mongodb_id = str(result.inserted_id)
logger.info(f"Article {article.news_id} saved to MongoDB with _id: {mongodb_id}")
# 다음 단계로 전달 (이미지 생성)
job.data['news_id'] = article.news_id
job.data['mongodb_id'] = mongodb_id
job.stages_completed.append('ai_article_generation')
job.stage = 'image_generation'
await self.queue_manager.enqueue('image_generation', job)
await self.queue_manager.mark_completed('ai_article_generation', job.job_id)
except Exception as e:
logger.error(f"Error processing job {job.job_id}: {e}")
await self.queue_manager.mark_failed('ai_article_generation', job, str(e))
async def _generate_article(self, job: PipelineJob, enriched_item: EnrichedItem) -> FinalArticle:
"""Claude를 사용한 기사 생성"""
# RSS 아이템 정보
rss_item = enriched_item.rss_item
search_results = enriched_item.search_results
# 검색 결과 텍스트 준비 (최대 10개)
search_text = ""
if search_results:
search_text = "\n관련 검색 결과:\n"
for idx, result in enumerate(search_results[:10], 1):
search_text += f"{idx}. {result.title}\n"
if result.snippet:
search_text += f" {result.snippet}\n"
# Claude로 기사 작성
prompt = f"""다음 뉴스 정보를 바탕으로 상세한 기사를 작성해주세요.
키워드: {job.keyword}
뉴스 정보:
제목: {rss_item.title}
요약: {rss_item.summary or '내용 없음'}
링크: {rss_item.link}
{search_text}
다음 JSON 형식으로 작성해주세요:
{{
"title": "기사 제목 (50자 이내)",
"summary": "한 줄 요약 (100자 이내)",
"subtopics": [
{{
"title": "소제목1",
"content": ["문단1", "문단2", "문단3"]
}},
{{
"title": "소제목2",
"content": ["문단1", "문단2"]
}},
{{
"title": "소제목3",
"content": ["문단1", "문단2"]
}}
],
"categories": ["카테고리1", "카테고리2"],
"entities": {{
"people": ["인물1", "인물2"],
"organizations": ["조직1", "조직2"],
"groups": ["그룹1"],
"countries": ["국가1"],
"events": ["이벤트1"]
}}
}}
요구사항:
- 3개의 소제목로 구성
- 각 소제목별로 2-3개 문단
- 전문적이고 객관적인 톤
- 한국어로 작성
- 실제 정보를 바탕으로 구체적으로 작성"""
try:
response = await self.claude_client.messages.create(
model="claude-sonnet-4-20250514",
max_tokens=4000,
temperature=0.7,
messages=[
{"role": "user", "content": prompt}
]
)
# JSON 파싱
content_text = response.content[0].text
json_start = content_text.find('{')
json_end = content_text.rfind('}') + 1
if json_start != -1 and json_end > json_start:
article_data = json.loads(content_text[json_start:json_end])
else:
raise ValueError("No valid JSON in response")
# Subtopic 객체 생성
subtopics = []
for subtopic_data in article_data.get('subtopics', []):
subtopics.append(Subtopic(
title=subtopic_data.get('title', ''),
content=subtopic_data.get('content', [])
))
# Entities 객체 생성
entities_data = article_data.get('entities', {})
entities = Entities(
people=entities_data.get('people', []),
organizations=entities_data.get('organizations', []),
groups=entities_data.get('groups', []),
countries=entities_data.get('countries', []),
events=entities_data.get('events', [])
)
# 레퍼런스 생성
references = []
# RSS 원본 추가
references.append(NewsReference(
title=rss_item.title,
link=rss_item.link,
source=rss_item.source_feed,
published=rss_item.published
))
# 검색 결과 레퍼런스 추가 (최대 9개 - RSS 원본과 합쳐 총 10개)
for search_result in search_results[:9]: # 상위 9개까지
references.append(NewsReference(
title=search_result.title,
link=search_result.link,
source=search_result.source,
published=None
))
# FinalArticle 생성 (ai_writer_db.articles 스키마)
article = FinalArticle(
title=article_data.get('title', rss_item.title),
summary=article_data.get('summary', ''),
subtopics=subtopics,
categories=article_data.get('categories', []),
entities=entities,
source_keyword=job.keyword,
source_count=len(references),
references=references,
job_id=job.job_id,
keyword_id=job.keyword_id,
pipeline_stages=job.stages_completed.copy(),
language='ko',
rss_guid=rss_item.guid # RSS GUID 저장
)
return article
except Exception as e:
logger.error(f"Error generating article: {e}")
# 폴백 기사 생성
fallback_references = [NewsReference(
title=rss_item.title,
link=rss_item.link,
source=rss_item.source_feed,
published=rss_item.published
)]
return FinalArticle(
title=rss_item.title,
summary=rss_item.summary[:100] if rss_item.summary else '',
subtopics=[
Subtopic(
title="주요 내용",
content=[rss_item.summary or rss_item.title]
)
],
categories=['자동생성'],
entities=Entities(),
source_keyword=job.keyword,
source_count=1,
references=fallback_references,
job_id=job.job_id,
keyword_id=job.keyword_id,
pipeline_stages=job.stages_completed.copy(),
language='ko',
rss_guid=rss_item.guid # RSS GUID 저장
)
async def stop(self):
"""워커 중지"""
await self.queue_manager.disconnect()
logger.info("AI Article Generator Worker stopped")
async def main():
"""메인 함수"""
worker = AIArticleGeneratorWorker()
try:
await worker.start()
except KeyboardInterrupt:
logger.info("Received interrupt signal")
finally:
await worker.stop()
if __name__ == "__main__":
asyncio.run(main())

View File

@ -0,0 +1,5 @@
anthropic==0.50.0
redis[hiredis]==5.0.1
pydantic==2.5.0
motor==3.1.1
pymongo==4.3.3

View File

@ -0,0 +1,37 @@
#!/usr/bin/env python3
"""키워드 데이터베이스 확인 스크립트"""
import asyncio
from motor.motor_asyncio import AsyncIOMotorClient
from datetime import datetime
async def check_keywords():
client = AsyncIOMotorClient("mongodb://localhost:27017")
db = client.ai_writer_db
# 키워드 조회
keywords = await db.keywords.find().to_list(None)
print(f"\n=== 등록된 키워드: {len(keywords)}개 ===\n")
for kw in keywords:
print(f"키워드: {kw['keyword']}")
print(f" - ID: {kw['keyword_id']}")
print(f" - 간격: {kw['interval_minutes']}")
print(f" - 활성화: {kw['is_active']}")
print(f" - 우선순위: {kw['priority']}")
print(f" - RSS 피드: {len(kw.get('rss_feeds', []))}")
if kw.get('last_run'):
print(f" - 마지막 실행: {kw['last_run']}")
if kw.get('next_run'):
next_run = kw['next_run']
remaining = (next_run - datetime.now()).total_seconds() / 60
print(f" - 다음 실행: {next_run} ({remaining:.1f}분 후)")
print()
client.close()
if __name__ == "__main__":
asyncio.run(check_keywords())

View File

@ -0,0 +1,85 @@
{
"enabled_languages": [
{
"code": "en",
"name": "English",
"deepl_code": "EN",
"collection": "articles_en",
"enabled": true
},
{
"code": "zh-CN",
"name": "Chinese (Simplified)",
"deepl_code": "ZH",
"collection": "articles_zh_cn",
"enabled": false
},
{
"code": "zh-TW",
"name": "Chinese (Traditional)",
"deepl_code": "ZH-HANT",
"collection": "articles_zh_tw",
"enabled": false
},
{
"code": "ja",
"name": "Japanese",
"deepl_code": "JA",
"collection": "articles_ja",
"enabled": false
},
{
"code": "fr",
"name": "French",
"deepl_code": "FR",
"collection": "articles_fr",
"enabled": false
},
{
"code": "de",
"name": "German",
"deepl_code": "DE",
"collection": "articles_de",
"enabled": false
},
{
"code": "es",
"name": "Spanish",
"deepl_code": "ES",
"collection": "articles_es",
"enabled": false
},
{
"code": "pt",
"name": "Portuguese",
"deepl_code": "PT",
"collection": "articles_pt",
"enabled": false
},
{
"code": "ru",
"name": "Russian",
"deepl_code": "RU",
"collection": "articles_ru",
"enabled": false
},
{
"code": "it",
"name": "Italian",
"deepl_code": "IT",
"collection": "articles_it",
"enabled": false
}
],
"source_language": {
"code": "ko",
"name": "Korean",
"collection": "articles_ko"
},
"translation_settings": {
"batch_size": 5,
"delay_between_languages": 2.0,
"delay_between_articles": 0.5,
"max_retries": 3
}
}

View File

@ -0,0 +1,62 @@
#!/usr/bin/env python3
"""Fix import statements in all pipeline services"""
import os
import re
def fix_imports(filepath):
"""Fix import statements in a Python file"""
with open(filepath, 'r') as f:
content = f.read()
# Pattern to match the old import style
old_pattern = r"# 상위 디렉토리의 shared 모듈 import\nsys\.path\.append\(os\.path\.join\(os\.path\.dirname\(__file__\), '\.\.', 'shared'\)\)\nfrom ([\w, ]+) import ([\w, ]+)"
# Replace with new import style
def replace_imports(match):
modules = match.group(1)
items = match.group(2)
# Build new import statements
imports = []
if 'models' in modules:
imports.append(f"from shared.models import {items}" if 'models' in modules else "")
if 'queue_manager' in modules:
imports.append(f"from shared.queue_manager import QueueManager")
return "# Import from shared module\n" + "\n".join(filter(None, imports))
# Apply the replacement
new_content = re.sub(old_pattern, replace_imports, content)
# Also handle simpler patterns
new_content = new_content.replace(
"sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'shared'))\nfrom models import",
"from shared.models import"
)
new_content = new_content.replace(
"\nfrom queue_manager import",
"\nfrom shared.queue_manager import"
)
# Write back if changed
if new_content != content:
with open(filepath, 'w') as f:
f.write(new_content)
print(f"Fixed imports in {filepath}")
return True
return False
# Files to fix
files_to_fix = [
"monitor/monitor.py",
"google-search/google_search.py",
"article-assembly/article_assembly.py",
"rss-collector/rss_collector.py",
"ai-summarizer/ai_summarizer.py"
]
for file_path in files_to_fix:
full_path = os.path.join(os.path.dirname(__file__), file_path)
if os.path.exists(full_path):
fix_imports(full_path)

View File

@ -0,0 +1,19 @@
FROM python:3.11-slim
WORKDIR /app
# 의존성 설치
COPY ./google-search/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# 공통 모듈 복사
COPY ./shared /app/shared
# Google Search 코드 복사
COPY ./google-search /app
# 환경변수
ENV PYTHONUNBUFFERED=1
# 실행
CMD ["python", "google_search.py"]

View File

@ -0,0 +1,152 @@
"""
Google Search Service
Google 검색으로 RSS 항목 강화
"""
import asyncio
import logging
import os
import sys
import json
from typing import List, Dict, Any
import aiohttp
from datetime import datetime
# Import from shared module
from shared.models import PipelineJob, RSSItem, SearchResult, EnrichedItem
from shared.queue_manager import QueueManager
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class GoogleSearchWorker:
def __init__(self):
self.queue_manager = QueueManager(
redis_url=os.getenv("REDIS_URL", "redis://redis:6379")
)
self.google_api_key = os.getenv("GOOGLE_API_KEY")
self.search_engine_id = os.getenv("GOOGLE_SEARCH_ENGINE_ID")
self.max_results_per_item = 3
async def start(self):
"""워커 시작"""
logger.info("Starting Google Search Worker")
# Redis 연결
await self.queue_manager.connect()
# 메인 처리 루프
while True:
try:
# 큐에서 작업 가져오기
job = await self.queue_manager.dequeue('search_enrichment', timeout=5)
if job:
await self.process_job(job)
except Exception as e:
logger.error(f"Error in worker loop: {e}")
await asyncio.sleep(1)
async def process_job(self, job: PipelineJob):
"""검색 강화 작업 처리 - 단일 RSS 아이템"""
try:
logger.info(f"Processing job {job.job_id} for search enrichment")
# 단일 RSS 아이템 처리
rss_item_data = job.data.get('rss_item')
if not rss_item_data:
# 이전 버전 호환성 - 여러 아이템 처리
rss_items = job.data.get('rss_items', [])
if rss_items:
rss_item_data = rss_items[0] # 첫 번째 아이템만 처리
else:
logger.warning(f"No RSS item in job {job.job_id}")
await self.queue_manager.mark_failed(
'search_enrichment',
job,
"No RSS item to process"
)
return
rss_item = RSSItem(**rss_item_data)
# 제목으로 Google 검색
search_results = await self._search_google(rss_item.title)
enriched_item = EnrichedItem(
rss_item=rss_item,
search_results=search_results
)
logger.info(f"Enriched item with {len(search_results)} search results")
# 다음 단계로 전달 - 단일 enriched item
job.data['enriched_item'] = enriched_item.dict()
job.stages_completed.append('search_enrichment')
job.stage = 'ai_article_generation'
await self.queue_manager.enqueue('ai_article_generation', job)
await self.queue_manager.mark_completed('search_enrichment', job.job_id)
except Exception as e:
logger.error(f"Error processing job {job.job_id}: {e}")
await self.queue_manager.mark_failed('search_enrichment', job, str(e))
async def _search_google(self, query: str) -> List[SearchResult]:
"""Google Custom Search API 호출"""
results = []
if not self.google_api_key or not self.search_engine_id:
logger.warning("Google API credentials not configured")
return results
try:
url = "https://www.googleapis.com/customsearch/v1"
params = {
"key": self.google_api_key,
"cx": self.search_engine_id,
"q": query,
"num": self.max_results_per_item,
"hl": "ko",
"gl": "kr"
}
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params, timeout=30) as response:
if response.status == 200:
data = await response.json()
for item in data.get('items', []):
result = SearchResult(
title=item.get('title', ''),
link=item.get('link', ''),
snippet=item.get('snippet', ''),
source='google'
)
results.append(result)
else:
logger.error(f"Google API error: {response.status}")
except Exception as e:
logger.error(f"Error searching Google for '{query}': {e}")
return results
async def stop(self):
"""워커 중지"""
await self.queue_manager.disconnect()
logger.info("Google Search Worker stopped")
async def main():
"""메인 함수"""
worker = GoogleSearchWorker()
try:
await worker.start()
except KeyboardInterrupt:
logger.info("Received interrupt signal")
finally:
await worker.stop()
if __name__ == "__main__":
asyncio.run(main())

View File

@ -0,0 +1,3 @@
aiohttp==3.9.1
redis[hiredis]==5.0.1
pydantic==2.5.0

View File

@ -0,0 +1,15 @@
FROM python:3.11-slim
WORKDIR /app
# Install dependencies
COPY ./image-generator/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy shared modules
COPY ./shared /app/shared
# Copy application code
COPY ./image-generator /app
CMD ["python", "image_generator.py"]

View File

@ -0,0 +1,256 @@
"""
Image Generation Service
Replicate API를 사용한 이미지 생성 서비스
"""
import asyncio
import logging
import os
import sys
import base64
from typing import List, Dict, Any
import httpx
from io import BytesIO
from motor.motor_asyncio import AsyncIOMotorClient
from bson import ObjectId
# Import from shared module
from shared.models import PipelineJob
from shared.queue_manager import QueueManager
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ImageGeneratorWorker:
def __init__(self):
self.queue_manager = QueueManager(
redis_url=os.getenv("REDIS_URL", "redis://redis:6379")
)
self.replicate_api_key = os.getenv("REPLICATE_API_TOKEN")
self.replicate_api_url = "https://api.replicate.com/v1/predictions"
# Stable Diffusion 모델 사용
self.model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
self.mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
self.db_name = os.getenv("DB_NAME", "ai_writer_db")
self.db = None
async def start(self):
"""워커 시작"""
logger.info("Starting Image Generator Worker")
# Redis 연결
await self.queue_manager.connect()
# MongoDB 연결
client = AsyncIOMotorClient(self.mongodb_url)
self.db = client[self.db_name]
# API 키 확인
if not self.replicate_api_key:
logger.warning("Replicate API key not configured - using placeholder images")
# 메인 처리 루프
while True:
try:
# 큐에서 작업 가져오기
job = await self.queue_manager.dequeue('image_generation', timeout=5)
if job:
await self.process_job(job)
except Exception as e:
logger.error(f"Error in worker loop: {e}")
await asyncio.sleep(1)
async def process_job(self, job: PipelineJob):
"""이미지 생성 및 MongoDB 업데이트"""
try:
logger.info(f"Processing job {job.job_id} for image generation")
# MongoDB에서 기사 정보 가져오기
news_id = job.data.get('news_id')
mongodb_id = job.data.get('mongodb_id')
if not news_id:
logger.error(f"No news_id in job {job.job_id}")
await self.queue_manager.mark_failed('image_generation', job, "No news_id")
return
# MongoDB에서 한국어 기사 조회 (articles_ko)
article = await self.db.articles_ko.find_one({"news_id": news_id})
if not article:
logger.error(f"Article {news_id} not found in MongoDB")
await self.queue_manager.mark_failed('image_generation', job, "Article not found")
return
# 이미지 생성을 위한 프롬프트 생성 (한국어 기사 기반)
prompt = self._create_image_prompt_from_article(article)
# 이미지 생성 (최대 3개)
image_urls = []
for i in range(min(3, 1)): # 테스트를 위해 1개만 생성
image_url = await self._generate_image(prompt)
image_urls.append(image_url)
# API 속도 제한
if self.replicate_api_key and i < 2:
await asyncio.sleep(2)
# MongoDB 업데이트 (이미지 추가 - articles_ko)
await self.db.articles_ko.update_one(
{"news_id": news_id},
{
"$set": {
"images": image_urls,
"image_prompt": prompt
},
"$addToSet": {
"pipeline_stages": "image_generation"
}
}
)
logger.info(f"Updated article {news_id} with {len(image_urls)} images")
# 다음 단계로 전달 (번역)
job.stages_completed.append('image_generation')
job.stage = 'translation'
await self.queue_manager.enqueue('translation', job)
await self.queue_manager.mark_completed('image_generation', job.job_id)
except Exception as e:
logger.error(f"Error processing job {job.job_id}: {e}")
await self.queue_manager.mark_failed('image_generation', job, str(e))
def _create_image_prompt_from_article(self, article: Dict) -> str:
"""기사로부터 이미지 프롬프트 생성"""
# 키워드와 제목을 기반으로 프롬프트 생성
keyword = article.get('keyword', '')
title = article.get('title', '')
categories = article.get('categories', [])
# 카테고리 맵핑 (한글 -> 영어)
category_map = {
'기술': 'technology',
'경제': 'business',
'정치': 'politics',
'교육': 'education',
'사회': 'society',
'문화': 'culture',
'과학': 'science'
}
eng_categories = [category_map.get(cat, cat) for cat in categories]
category_str = ', '.join(eng_categories[:2]) if eng_categories else 'news'
# 뉴스 관련 이미지를 위한 프롬프트
prompt = f"News illustration for {keyword} {category_str}, professional, modern, clean design, high quality, 4k, no text"
return prompt
async def _generate_image(self, prompt: str) -> str:
"""Replicate API를 사용한 이미지 생성"""
try:
if not self.replicate_api_key:
# API 키가 없으면 플레이스홀더 이미지 URL 반환
return "https://via.placeholder.com/800x600.png?text=News+Image"
async with httpx.AsyncClient() as client:
# 예측 생성 요청
response = await client.post(
self.replicate_api_url,
headers={
"Authorization": f"Token {self.replicate_api_key}",
"Content-Type": "application/json"
},
json={
"version": self.model_version,
"input": {
"prompt": prompt,
"width": 768,
"height": 768,
"num_outputs": 1,
"scheduler": "K_EULER",
"num_inference_steps": 25,
"guidance_scale": 7.5,
"prompt_strength": 0.8,
"refine": "expert_ensemble_refiner",
"high_noise_frac": 0.8
}
},
timeout=60
)
if response.status_code in [200, 201]:
result = response.json()
prediction_id = result.get('id')
# 예측 결과 폴링
image_url = await self._poll_prediction(prediction_id)
return image_url
else:
logger.error(f"Replicate API error: {response.status_code}")
return "https://via.placeholder.com/800x600.png?text=Generation+Failed"
except Exception as e:
logger.error(f"Error generating image: {e}")
return "https://via.placeholder.com/800x600.png?text=Error"
async def _poll_prediction(self, prediction_id: str, max_attempts: int = 30) -> str:
"""예측 결과 폴링"""
try:
async with httpx.AsyncClient() as client:
for attempt in range(max_attempts):
response = await client.get(
f"{self.replicate_api_url}/{prediction_id}",
headers={
"Authorization": f"Token {self.replicate_api_key}"
},
timeout=30
)
if response.status_code == 200:
result = response.json()
status = result.get('status')
if status == 'succeeded':
output = result.get('output')
if output and isinstance(output, list) and len(output) > 0:
return output[0]
else:
return "https://via.placeholder.com/800x600.png?text=No+Output"
elif status == 'failed':
logger.error(f"Prediction failed: {result.get('error')}")
return "https://via.placeholder.com/800x600.png?text=Failed"
# 아직 처리중이면 대기
await asyncio.sleep(2)
else:
logger.error(f"Error polling prediction: {response.status_code}")
return "https://via.placeholder.com/800x600.png?text=Poll+Error"
# 최대 시도 횟수 초과
return "https://via.placeholder.com/800x600.png?text=Timeout"
except Exception as e:
logger.error(f"Error polling prediction: {e}")
return "https://via.placeholder.com/800x600.png?text=Poll+Exception"
async def stop(self):
"""워커 중지"""
await self.queue_manager.disconnect()
logger.info("Image Generator Worker stopped")
async def main():
"""메인 함수"""
worker = ImageGeneratorWorker()
try:
await worker.start()
except KeyboardInterrupt:
logger.info("Received interrupt signal")
finally:
await worker.stop()
if __name__ == "__main__":
asyncio.run(main())

View File

@ -0,0 +1,5 @@
httpx==0.25.0
redis[hiredis]==5.0.1
pydantic==2.5.0
motor==3.1.1
pymongo==4.3.3

View File

@ -0,0 +1,22 @@
FROM python:3.11-slim
WORKDIR /app
# Install dependencies
COPY ./monitor/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy shared modules
COPY ./shared /app/shared
# Copy monitor code
COPY ./monitor /app
# Environment variables
ENV PYTHONUNBUFFERED=1
# Expose port
EXPOSE 8000
# Run
CMD ["uvicorn", "monitor:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]

View File

@ -0,0 +1,349 @@
"""
Pipeline Monitor Service
파이프라인 상태 모니터링 및 대시보드 API
"""
import os
import sys
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Any
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from motor.motor_asyncio import AsyncIOMotorClient
import redis.asyncio as redis
# Import from shared module
from shared.models import KeywordSubscription, PipelineJob, FinalArticle
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(title="Pipeline Monitor", version="1.0.0")
# CORS 설정
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Global connections
redis_client = None
mongodb_client = None
db = None
@app.on_event("startup")
async def startup_event():
"""서버 시작 시 연결 초기화"""
global redis_client, mongodb_client, db
# Redis 연결
redis_url = os.getenv("REDIS_URL", "redis://redis:6379")
redis_client = await redis.from_url(redis_url, decode_responses=True)
# MongoDB 연결
mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
mongodb_client = AsyncIOMotorClient(mongodb_url)
db = mongodb_client[os.getenv("DB_NAME", "ai_writer_db")]
logger.info("Pipeline Monitor started successfully")
@app.on_event("shutdown")
async def shutdown_event():
"""서버 종료 시 연결 해제"""
if redis_client:
await redis_client.close()
if mongodb_client:
mongodb_client.close()
@app.get("/")
async def root():
"""헬스 체크"""
return {"status": "Pipeline Monitor is running"}
@app.get("/api/stats")
async def get_stats():
"""전체 파이프라인 통계"""
try:
# 큐별 대기 작업 수
queue_stats = {}
queues = [
"queue:keyword",
"queue:rss",
"queue:search",
"queue:summarize",
"queue:assembly"
]
for queue in queues:
length = await redis_client.llen(queue)
queue_stats[queue] = length
# 오늘 생성된 기사 수
today = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
articles_today = await db.articles.count_documents({
"created_at": {"$gte": today}
})
# 활성 키워드 수
active_keywords = await db.keywords.count_documents({
"is_active": True
})
# 총 기사 수
total_articles = await db.articles.count_documents({})
return {
"queues": queue_stats,
"articles_today": articles_today,
"active_keywords": active_keywords,
"total_articles": total_articles,
"timestamp": datetime.now().isoformat()
}
except Exception as e:
logger.error(f"Error getting stats: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/queues/{queue_name}")
async def get_queue_details(queue_name: str):
"""특정 큐의 상세 정보"""
try:
queue_key = f"queue:{queue_name}"
# 큐 길이
length = await redis_client.llen(queue_key)
# 최근 10개 작업 미리보기
items = await redis_client.lrange(queue_key, 0, 9)
# 처리 중인 작업
processing_key = f"processing:{queue_name}"
processing = await redis_client.smembers(processing_key)
# 실패한 작업
failed_key = f"failed:{queue_name}"
failed_count = await redis_client.llen(failed_key)
return {
"queue": queue_name,
"length": length,
"processing_count": len(processing),
"failed_count": failed_count,
"preview": items[:10],
"timestamp": datetime.now().isoformat()
}
except Exception as e:
logger.error(f"Error getting queue details: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/keywords")
async def get_keywords():
"""등록된 키워드 목록"""
try:
keywords = []
cursor = db.keywords.find({"is_active": True})
async for keyword in cursor:
# 해당 키워드의 최근 기사
latest_article = await db.articles.find_one(
{"keyword_id": str(keyword["_id"])},
sort=[("created_at", -1)]
)
keywords.append({
"id": str(keyword["_id"]),
"keyword": keyword["keyword"],
"schedule": keyword.get("schedule", "30분마다"),
"created_at": keyword.get("created_at"),
"last_article": latest_article["created_at"] if latest_article else None,
"article_count": await db.articles.count_documents(
{"keyword_id": str(keyword["_id"])}
)
})
return keywords
except Exception as e:
logger.error(f"Error getting keywords: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/keywords")
async def add_keyword(keyword: str, schedule: str = "30min"):
"""새 키워드 등록"""
try:
new_keyword = {
"keyword": keyword,
"schedule": schedule,
"is_active": True,
"created_at": datetime.now(),
"updated_at": datetime.now()
}
result = await db.keywords.insert_one(new_keyword)
return {
"id": str(result.inserted_id),
"keyword": keyword,
"message": "Keyword registered successfully"
}
except Exception as e:
logger.error(f"Error adding keyword: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.delete("/api/keywords/{keyword_id}")
async def delete_keyword(keyword_id: str):
"""키워드 비활성화"""
try:
result = await db.keywords.update_one(
{"_id": keyword_id},
{"$set": {"is_active": False, "updated_at": datetime.now()}}
)
if result.modified_count > 0:
return {"message": "Keyword deactivated successfully"}
else:
raise HTTPException(status_code=404, detail="Keyword not found")
except Exception as e:
logger.error(f"Error deleting keyword: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/articles")
async def get_articles(limit: int = 10, skip: int = 0):
"""최근 생성된 기사 목록"""
try:
articles = []
cursor = db.articles.find().sort("created_at", -1).skip(skip).limit(limit)
async for article in cursor:
articles.append({
"id": str(article["_id"]),
"title": article["title"],
"keyword": article["keyword"],
"summary": article.get("summary", ""),
"created_at": article["created_at"],
"processing_time": article.get("processing_time", 0),
"pipeline_stages": article.get("pipeline_stages", [])
})
total = await db.articles.count_documents({})
return {
"articles": articles,
"total": total,
"limit": limit,
"skip": skip
}
except Exception as e:
logger.error(f"Error getting articles: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/articles/{article_id}")
async def get_article(article_id: str):
"""특정 기사 상세 정보"""
try:
article = await db.articles.find_one({"_id": article_id})
if not article:
raise HTTPException(status_code=404, detail="Article not found")
return article
except Exception as e:
logger.error(f"Error getting article: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/workers")
async def get_workers():
"""워커 상태 정보"""
try:
workers = {}
worker_types = [
"scheduler",
"rss_collector",
"google_search",
"ai_summarizer",
"article_assembly"
]
for worker_type in worker_types:
active_key = f"workers:{worker_type}:active"
active_workers = await redis_client.smembers(active_key)
workers[worker_type] = {
"active": len(active_workers),
"worker_ids": list(active_workers)
}
return workers
except Exception as e:
logger.error(f"Error getting workers: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/trigger/{keyword}")
async def trigger_keyword_processing(keyword: str):
"""수동으로 키워드 처리 트리거"""
try:
# 키워드 찾기
keyword_doc = await db.keywords.find_one({
"keyword": keyword,
"is_active": True
})
if not keyword_doc:
raise HTTPException(status_code=404, detail="Keyword not found or inactive")
# 작업 생성
job = PipelineJob(
keyword_id=str(keyword_doc["_id"]),
keyword=keyword,
stage="keyword_processing",
created_at=datetime.now()
)
# 큐에 추가
await redis_client.rpush("queue:keyword", job.json())
return {
"message": f"Processing triggered for keyword: {keyword}",
"job_id": job.job_id
}
except Exception as e:
logger.error(f"Error triggering keyword: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/health")
async def health_check():
"""시스템 헬스 체크"""
try:
# Redis 체크
redis_status = await redis_client.ping()
# MongoDB 체크
mongodb_status = await db.command("ping")
return {
"status": "healthy",
"redis": "connected" if redis_status else "disconnected",
"mongodb": "connected" if mongodb_status else "disconnected",
"timestamp": datetime.now().isoformat()
}
except Exception as e:
return {
"status": "unhealthy",
"error": str(e),
"timestamp": datetime.now().isoformat()
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@ -0,0 +1,6 @@
fastapi==0.104.1
uvicorn[standard]==0.24.0
redis[hiredis]==5.0.1
motor==3.1.1
pymongo==4.3.3
pydantic==2.5.0

View File

@ -0,0 +1,19 @@
FROM python:3.11-slim
WORKDIR /app
# 의존성 설치
COPY ./rss-collector/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# 공통 모듈 복사
COPY ./shared /app/shared
# RSS Collector 코드 복사
COPY ./rss-collector /app
# 환경변수
ENV PYTHONUNBUFFERED=1
# 실행
CMD ["python", "rss_collector.py"]

View File

@ -0,0 +1,5 @@
feedparser==6.0.11
aiohttp==3.9.1
redis[hiredis]==5.0.1
pydantic==2.5.0
motor==3.6.0

View File

@ -0,0 +1,270 @@
"""
RSS Collector Service
RSS 피드 수집 및 중복 제거 서비스
"""
import asyncio
import logging
import os
import sys
import hashlib
from datetime import datetime
import feedparser
import aiohttp
import redis.asyncio as redis
from motor.motor_asyncio import AsyncIOMotorClient
from typing import List, Dict, Any
# Import from shared module
from shared.models import PipelineJob, RSSItem, EnrichedItem
from shared.queue_manager import QueueManager
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class RSSCollectorWorker:
def __init__(self):
self.queue_manager = QueueManager(
redis_url=os.getenv("REDIS_URL", "redis://redis:6379")
)
self.redis_client = None
self.redis_url = os.getenv("REDIS_URL", "redis://redis:6379")
self.mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
self.db_name = os.getenv("DB_NAME", "ai_writer_db")
self.db = None
self.dedup_ttl = 86400 * 7 # 7일간 중복 방지
self.max_items_per_feed = 100 # 피드당 최대 항목 수 (Google News는 최대 100개)
async def start(self):
"""워커 시작"""
logger.info("Starting RSS Collector Worker")
# Redis 연결
await self.queue_manager.connect()
self.redis_client = await redis.from_url(
self.redis_url,
encoding="utf-8",
decode_responses=True
)
# MongoDB 연결
client = AsyncIOMotorClient(self.mongodb_url)
self.db = client[self.db_name]
# 메인 처리 루프
while True:
try:
# 큐에서 작업 가져오기 (5초 대기)
job = await self.queue_manager.dequeue('rss_collection', timeout=5)
if job:
await self.process_job(job)
except Exception as e:
logger.error(f"Error in worker loop: {e}")
await asyncio.sleep(1)
async def process_job(self, job: PipelineJob):
"""RSS 수집 작업 처리"""
try:
logger.info(f"Processing job {job.job_id} for keyword '{job.keyword}'")
keyword = job.keyword # keyword는 job의 직접 속성
rss_feeds = job.data.get('rss_feeds', [])
# RSS 피드가 없으면 기본 피드 사용
if not rss_feeds:
# 기본 RSS 피드 추가 (Google News RSS)
rss_feeds = [
f"https://news.google.com/rss/search?q={keyword}&hl=en-US&gl=US&ceid=US:en",
f"https://news.google.com/rss/search?q={keyword}&hl=ko&gl=KR&ceid=KR:ko",
"https://feeds.bbci.co.uk/news/technology/rss.xml",
"https://rss.nytimes.com/services/xml/rss/nyt/Technology.xml"
]
logger.info(f"Using default RSS feeds for keyword: {keyword}")
# 키워드가 포함된 RSS URL 생성
processed_feeds = self._prepare_feeds(rss_feeds, keyword)
all_items = []
for feed_url in processed_feeds:
try:
items = await self._fetch_rss_feed(feed_url, keyword)
all_items.extend(items)
except Exception as e:
logger.error(f"Error fetching feed {feed_url}: {e}")
if all_items:
# 중복 제거
unique_items = await self._deduplicate_items(all_items, keyword)
if unique_items:
logger.info(f"Collected {len(unique_items)} unique items for '{keyword}'")
# 각 RSS 아이템별로 개별 job 생성하여 다음 단계로 전달
# 시간 지연을 추가하여 API 호출 분산 (초기값: 1초, 점진적으로 조정 가능)
enqueue_delay = float(os.getenv("RSS_ENQUEUE_DELAY", "1.0"))
for idx, item in enumerate(unique_items):
# 각 아이템별로 새로운 job 생성
item_job = PipelineJob(
keyword_id=f"{job.keyword_id}_{idx}",
keyword=job.keyword,
stage='search_enrichment',
data={
'rss_item': item.dict(), # 단일 아이템
'original_job_id': job.job_id,
'item_index': idx,
'total_items': len(unique_items),
'item_hash': hashlib.md5(
f"{keyword}:guid:{item.guid}".encode() if item.guid
else f"{keyword}:title:{item.title}:link:{item.link}".encode()
).hexdigest() # GUID 또는 title+link 해시
},
stages_completed=['rss_collection']
)
# 개별 아이템을 다음 단계로 전달
await self.queue_manager.enqueue('search_enrichment', item_job)
logger.info(f"Enqueued item {idx+1}/{len(unique_items)} for keyword '{keyword}'")
# 다음 아이템 enqueue 전에 지연 추가 (마지막 아이템 제외)
if idx < len(unique_items) - 1:
await asyncio.sleep(enqueue_delay)
logger.debug(f"Waiting {enqueue_delay}s before next item...")
# 원본 job 완료 처리
await self.queue_manager.mark_completed('rss_collection', job.job_id)
logger.info(f"Completed RSS collection for job {job.job_id}: {len(unique_items)} items processed")
else:
logger.info(f"No new items found for '{keyword}' after deduplication")
await self.queue_manager.mark_completed('rss_collection', job.job_id)
else:
logger.warning(f"No RSS items collected for '{keyword}'")
await self.queue_manager.mark_failed(
'rss_collection',
job,
"No RSS items collected"
)
except Exception as e:
logger.error(f"Error processing job {job.job_id}: {e}")
await self.queue_manager.mark_failed('rss_collection', job, str(e))
def _prepare_feeds(self, feeds: List[str], keyword: str) -> List[str]:
"""RSS 피드 URL 준비 (키워드 치환)"""
processed = []
for feed in feeds:
if '{keyword}' in feed:
processed.append(feed.replace('{keyword}', keyword))
else:
processed.append(feed)
return processed
async def _fetch_rss_feed(self, feed_url: str, keyword: str) -> List[RSSItem]:
"""RSS 피드 가져오기"""
items = []
try:
async with aiohttp.ClientSession() as session:
async with session.get(feed_url, timeout=30) as response:
content = await response.text()
# feedparser로 파싱
feed = feedparser.parse(content)
logger.info(f"Found {len(feed.entries)} entries in feed {feed_url}")
for entry in feed.entries[:self.max_items_per_feed]:
# 키워드 관련성 체크
title = entry.get('title', '')
summary = entry.get('summary', '')
# 대소문자 무시하고 키워드 매칭 (영문의 경우)
title_lower = title.lower() if keyword.isascii() else title
summary_lower = summary.lower() if keyword.isascii() else summary
keyword_lower = keyword.lower() if keyword.isascii() else keyword
# 제목이나 요약에 키워드가 포함된 경우
# Google News RSS는 이미 키워드 검색 결과이므로 모든 항목 포함
if "news.google.com" in feed_url or keyword_lower in title_lower or keyword_lower in summary_lower:
# GUID 추출 (Google RSS에서 일반적으로 사용)
guid = entry.get('id', entry.get('guid', ''))
item = RSSItem(
title=title,
link=entry.get('link', ''),
guid=guid, # GUID 추가
published=entry.get('published', ''),
summary=summary[:500] if summary else '',
source_feed=feed_url
)
items.append(item)
logger.debug(f"Added item: {title[:50]}... (guid: {guid[:30] if guid else 'no-guid'})")
except Exception as e:
logger.error(f"Error fetching RSS feed {feed_url}: {e}")
return items
async def _deduplicate_items(self, items: List[RSSItem], keyword: str) -> List[RSSItem]:
"""중복 항목 제거 - GUID 또는 링크 기준으로만 중복 체크"""
unique_items = []
seen_guids = set() # 현재 배치에서 본 GUID
seen_links = set() # 현재 배치에서 본 링크
for item in items:
# GUID가 있는 경우 GUID로 중복 체크
if item.guid:
if item.guid in seen_guids:
logger.debug(f"Duplicate GUID in batch: {item.guid[:30]}")
continue
# MongoDB에서 이미 처리된 기사인지 확인
existing_article = await self.db.articles_ko.find_one({"rss_guid": item.guid})
if existing_article:
logger.info(f"Article with GUID {item.guid[:30]} already processed, skipping")
continue
seen_guids.add(item.guid)
else:
# GUID가 없으면 링크로 중복 체크
if item.link in seen_links:
logger.debug(f"Duplicate link in batch: {item.link[:50]}")
continue
# MongoDB에서 링크로 중복 확인 (references 필드에서 검색)
existing_article = await self.db.articles_ko.find_one({"references.link": item.link})
if existing_article:
logger.info(f"Article with link {item.link[:50]} already processed, skipping")
continue
seen_links.add(item.link)
unique_items.append(item)
logger.debug(f"New item added: {item.title[:50]}...")
logger.info(f"Deduplication result: {len(unique_items)} new items out of {len(items)} total")
return unique_items
async def stop(self):
"""워커 중지"""
await self.queue_manager.disconnect()
if self.redis_client:
await self.redis_client.close()
logger.info("RSS Collector Worker stopped")
async def main():
"""메인 함수"""
worker = RSSCollectorWorker()
try:
await worker.start()
except KeyboardInterrupt:
logger.info("Received interrupt signal")
finally:
await worker.stop()
if __name__ == "__main__":
asyncio.run(main())

View File

@ -0,0 +1,16 @@
FROM python:3.11-slim
WORKDIR /app
# Install dependencies
COPY ./scheduler/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy shared module
COPY ./shared /app/shared
# Copy scheduler code
COPY ./scheduler /app
# Run scheduler
CMD ["python", "keyword_scheduler.py"]

View File

@ -0,0 +1,336 @@
"""
Keyword Manager API
키워드를 추가/수정/삭제하는 관리 API
"""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
from datetime import datetime, timedelta
from motor.motor_asyncio import AsyncIOMotorClient
import uvicorn
import os
import sys
import uuid
# Import from shared module
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from shared.models import Keyword
app = FastAPI(title="Keyword Manager API")
# MongoDB 연결
mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
db_name = os.getenv("DB_NAME", "ai_writer_db")
@app.on_event("startup")
async def startup_event():
"""앱 시작 시 MongoDB 연결"""
app.mongodb_client = AsyncIOMotorClient(mongodb_url)
app.db = app.mongodb_client[db_name]
@app.on_event("shutdown")
async def shutdown_event():
"""앱 종료 시 연결 해제"""
app.mongodb_client.close()
class KeywordCreate(BaseModel):
"""키워드 생성 요청 모델"""
keyword: str
interval_minutes: int = 60
priority: int = 0
rss_feeds: List[str] = []
max_articles_per_run: int = 100
is_active: bool = True
class KeywordUpdate(BaseModel):
"""키워드 업데이트 요청 모델"""
interval_minutes: Optional[int] = None
priority: Optional[int] = None
rss_feeds: Optional[List[str]] = None
max_articles_per_run: Optional[int] = None
is_active: Optional[bool] = None
@app.get("/")
async def root():
"""API 상태 확인"""
return {"status": "Keyword Manager API is running"}
@app.get("/threads/status")
async def get_threads_status():
"""모든 스레드 상태 조회"""
try:
# MongoDB에서 키워드 정보와 함께 상태 반환
cursor = app.db.keywords.find()
keywords = await cursor.to_list(None)
threads_status = []
for kw in keywords:
status = {
"keyword": kw.get("keyword"),
"keyword_id": kw.get("keyword_id"),
"is_active": kw.get("is_active"),
"interval_minutes": kw.get("interval_minutes"),
"priority": kw.get("priority"),
"last_run": kw.get("last_run").isoformat() if kw.get("last_run") else None,
"next_run": kw.get("next_run").isoformat() if kw.get("next_run") else None,
"thread_status": "active" if kw.get("is_active") else "inactive"
}
# 다음 실행까지 남은 시간 계산
if kw.get("next_run"):
remaining = (kw.get("next_run") - datetime.now()).total_seconds()
if remaining > 0:
status["minutes_until_next_run"] = round(remaining / 60, 1)
else:
status["minutes_until_next_run"] = 0
status["thread_status"] = "pending_execution"
threads_status.append(status)
# 우선순위 순으로 정렬
threads_status.sort(key=lambda x: x.get("priority", 0), reverse=True)
return {
"total_threads": len(threads_status),
"active_threads": sum(1 for t in threads_status if t.get("is_active")),
"threads": threads_status
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/keywords")
async def list_keywords():
"""모든 키워드 조회"""
try:
cursor = app.db.keywords.find()
keywords = await cursor.to_list(None)
# 각 키워드 정보 정리
result = []
for kw in keywords:
result.append({
"keyword_id": kw.get("keyword_id"),
"keyword": kw.get("keyword"),
"interval_minutes": kw.get("interval_minutes"),
"priority": kw.get("priority"),
"is_active": kw.get("is_active"),
"last_run": kw.get("last_run").isoformat() if kw.get("last_run") else None,
"next_run": kw.get("next_run").isoformat() if kw.get("next_run") else None,
"rss_feeds": kw.get("rss_feeds", []),
"max_articles_per_run": kw.get("max_articles_per_run", 100)
})
return {
"total": len(result),
"keywords": result
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/keywords/{keyword_text}")
async def get_keyword(keyword_text: str):
"""특정 키워드 조회"""
try:
keyword = await app.db.keywords.find_one({"keyword": keyword_text})
if not keyword:
raise HTTPException(status_code=404, detail=f"Keyword '{keyword_text}' not found")
return {
"keyword_id": keyword.get("keyword_id"),
"keyword": keyword.get("keyword"),
"interval_minutes": keyword.get("interval_minutes"),
"priority": keyword.get("priority"),
"is_active": keyword.get("is_active"),
"last_run": keyword.get("last_run").isoformat() if keyword.get("last_run") else None,
"next_run": keyword.get("next_run").isoformat() if keyword.get("next_run") else None,
"rss_feeds": keyword.get("rss_feeds", []),
"max_articles_per_run": keyword.get("max_articles_per_run", 100)
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/keywords")
async def create_keyword(keyword_data: KeywordCreate):
"""새 키워드 생성"""
try:
# 중복 체크
existing = await app.db.keywords.find_one({"keyword": keyword_data.keyword})
if existing:
raise HTTPException(status_code=400, detail=f"Keyword '{keyword_data.keyword}' already exists")
# 새 키워드 생성
keyword = Keyword(
keyword_id=str(uuid.uuid4()),
keyword=keyword_data.keyword,
interval_minutes=keyword_data.interval_minutes,
priority=keyword_data.priority,
rss_feeds=keyword_data.rss_feeds,
max_articles_per_run=keyword_data.max_articles_per_run,
is_active=keyword_data.is_active,
next_run=datetime.now() + timedelta(minutes=1), # 1분 후 첫 실행
created_at=datetime.now(),
updated_at=datetime.now()
)
await app.db.keywords.insert_one(keyword.model_dump())
return {
"message": f"Keyword '{keyword_data.keyword}' created successfully",
"keyword_id": keyword.keyword_id,
"note": "The scheduler will automatically detect and start processing this keyword within 30 seconds"
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.put("/keywords/{keyword_text}")
async def update_keyword(keyword_text: str, update_data: KeywordUpdate):
"""키워드 업데이트"""
try:
# 키워드 존재 확인
existing = await app.db.keywords.find_one({"keyword": keyword_text})
if not existing:
raise HTTPException(status_code=404, detail=f"Keyword '{keyword_text}' not found")
# 업데이트 데이터 준비
update_dict = {}
if update_data.interval_minutes is not None:
update_dict["interval_minutes"] = update_data.interval_minutes
if update_data.priority is not None:
update_dict["priority"] = update_data.priority
if update_data.rss_feeds is not None:
update_dict["rss_feeds"] = update_data.rss_feeds
if update_data.max_articles_per_run is not None:
update_dict["max_articles_per_run"] = update_data.max_articles_per_run
if update_data.is_active is not None:
update_dict["is_active"] = update_data.is_active
if update_dict:
update_dict["updated_at"] = datetime.now()
# 만약 interval이 변경되면 next_run도 재계산
if "interval_minutes" in update_dict:
update_dict["next_run"] = datetime.now() + timedelta(minutes=update_dict["interval_minutes"])
result = await app.db.keywords.update_one(
{"keyword": keyword_text},
{"$set": update_dict}
)
if result.modified_count > 0:
action_note = ""
if update_data.is_active is False:
action_note = "The scheduler will stop the thread for this keyword within 30 seconds."
elif update_data.is_active is True and not existing.get("is_active"):
action_note = "The scheduler will start a new thread for this keyword within 30 seconds."
return {
"message": f"Keyword '{keyword_text}' updated successfully",
"updated_fields": list(update_dict.keys()),
"note": action_note
}
else:
return {"message": "No changes made"}
else:
return {"message": "No update data provided"}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.delete("/keywords/{keyword_text}")
async def delete_keyword(keyword_text: str):
"""키워드 삭제"""
try:
# 키워드 존재 확인
existing = await app.db.keywords.find_one({"keyword": keyword_text})
if not existing:
raise HTTPException(status_code=404, detail=f"Keyword '{keyword_text}' not found")
# 삭제
result = await app.db.keywords.delete_one({"keyword": keyword_text})
if result.deleted_count > 0:
return {
"message": f"Keyword '{keyword_text}' deleted successfully",
"note": "The scheduler will stop the thread for this keyword within 30 seconds"
}
else:
raise HTTPException(status_code=500, detail="Failed to delete keyword")
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/keywords/{keyword_text}/activate")
async def activate_keyword(keyword_text: str):
"""키워드 활성화"""
try:
result = await app.db.keywords.update_one(
{"keyword": keyword_text},
{"$set": {"is_active": True, "updated_at": datetime.now()}}
)
if result.matched_count == 0:
raise HTTPException(status_code=404, detail=f"Keyword '{keyword_text}' not found")
return {
"message": f"Keyword '{keyword_text}' activated",
"note": "The scheduler will start processing this keyword within 30 seconds"
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/keywords/{keyword_text}/deactivate")
async def deactivate_keyword(keyword_text: str):
"""키워드 비활성화"""
try:
result = await app.db.keywords.update_one(
{"keyword": keyword_text},
{"$set": {"is_active": False, "updated_at": datetime.now()}}
)
if result.matched_count == 0:
raise HTTPException(status_code=404, detail=f"Keyword '{keyword_text}' not found")
return {
"message": f"Keyword '{keyword_text}' deactivated",
"note": "The scheduler will stop processing this keyword within 30 seconds"
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/keywords/{keyword_text}/trigger")
async def trigger_keyword(keyword_text: str):
"""키워드 즉시 실행 트리거"""
try:
# next_run을 현재 시간으로 설정하여 즉시 실행되도록 함
result = await app.db.keywords.update_one(
{"keyword": keyword_text},
{"$set": {"next_run": datetime.now(), "updated_at": datetime.now()}}
)
if result.matched_count == 0:
raise HTTPException(status_code=404, detail=f"Keyword '{keyword_text}' not found")
return {
"message": f"Keyword '{keyword_text}' triggered for immediate execution",
"note": "The scheduler will execute this keyword within the next minute"
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
port = int(os.getenv("API_PORT", "8100"))
uvicorn.run(app, host="0.0.0.0", port=port)

View File

@ -0,0 +1,245 @@
"""
Keyword Scheduler Service
데이터베이스에 등록된 키워드를 주기적으로 실행하는 스케줄러
"""
import asyncio
import logging
import os
import sys
from datetime import datetime, timedelta
from motor.motor_asyncio import AsyncIOMotorClient
from typing import List, Optional
import uuid
# Import from shared module
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from shared.models import Keyword, PipelineJob
from shared.queue_manager import QueueManager
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class KeywordScheduler:
def __init__(self):
self.queue_manager = QueueManager(
redis_url=os.getenv("REDIS_URL", "redis://redis:6379")
)
self.mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
self.db_name = os.getenv("DB_NAME", "ai_writer_db")
self.db = None
self.check_interval = int(os.getenv("SCHEDULER_CHECK_INTERVAL", "60")) # 1분마다 체크
self.default_interval = int(os.getenv("DEFAULT_KEYWORD_INTERVAL", "60")) # 기본 1시간
async def start(self):
"""스케줄러 시작"""
logger.info("Starting Keyword Scheduler")
# Redis 연결
await self.queue_manager.connect()
# MongoDB 연결
client = AsyncIOMotorClient(self.mongodb_url)
self.db = client[self.db_name]
# 초기 키워드 설정
await self.initialize_keywords()
# 메인 루프
while True:
try:
await self.check_and_execute_keywords()
await asyncio.sleep(self.check_interval)
except Exception as e:
logger.error(f"Error in scheduler loop: {e}")
await asyncio.sleep(10)
async def initialize_keywords(self):
"""초기 키워드 설정 (없으면 생성)"""
try:
# keywords 컬렉션 확인
count = await self.db.keywords.count_documents({})
if count == 0:
logger.info("No keywords found. Creating default keywords...")
# 기본 키워드 생성
default_keywords = [
{
"keyword": "AI",
"interval_minutes": 60,
"is_active": True,
"priority": 1,
"rss_feeds": []
},
{
"keyword": "경제",
"interval_minutes": 120,
"is_active": True,
"priority": 0,
"rss_feeds": []
},
{
"keyword": "테크놀로지",
"interval_minutes": 60,
"is_active": True,
"priority": 1,
"rss_feeds": []
}
]
for kw_data in default_keywords:
keyword = Keyword(**kw_data)
# 다음 실행 시간 설정
keyword.next_run = datetime.now() + timedelta(minutes=5) # 5분 후 첫 실행
await self.db.keywords.insert_one(keyword.dict())
logger.info(f"Created keyword: {keyword.keyword}")
logger.info(f"Found {count} keywords in database")
except Exception as e:
logger.error(f"Error initializing keywords: {e}")
async def check_and_execute_keywords(self):
"""실행할 키워드 체크 및 실행"""
try:
# 현재 시간
now = datetime.now()
# 실행할 키워드 조회 (활성화되고 next_run이 현재 시간 이전인 것)
query = {
"is_active": True,
"$or": [
{"next_run": {"$lte": now}},
{"next_run": None} # next_run이 설정되지 않은 경우
]
}
# 우선순위 순으로 정렬
cursor = self.db.keywords.find(query).sort("priority", -1)
keywords = await cursor.to_list(None)
for keyword_data in keywords:
keyword = Keyword(**keyword_data)
await self.execute_keyword(keyword)
except Exception as e:
logger.error(f"Error checking keywords: {e}")
async def execute_keyword(self, keyword: Keyword):
"""키워드 실행"""
try:
logger.info(f"Executing keyword: {keyword.keyword}")
# PipelineJob 생성
job = PipelineJob(
keyword_id=keyword.keyword_id,
keyword=keyword.keyword,
stage='rss_collection',
data={
'rss_feeds': keyword.rss_feeds if keyword.rss_feeds else [],
'max_articles': keyword.max_articles_per_run,
'scheduled': True
},
priority=keyword.priority
)
# 큐에 작업 추가
await self.queue_manager.enqueue('rss_collection', job)
logger.info(f"Enqueued job for keyword '{keyword.keyword}' with job_id: {job.job_id}")
# 키워드 업데이트
update_data = {
"last_run": datetime.now(),
"next_run": datetime.now() + timedelta(minutes=keyword.interval_minutes),
"updated_at": datetime.now()
}
await self.db.keywords.update_one(
{"keyword_id": keyword.keyword_id},
{"$set": update_data}
)
logger.info(f"Updated keyword '{keyword.keyword}' - next run at {update_data['next_run']}")
except Exception as e:
logger.error(f"Error executing keyword {keyword.keyword}: {e}")
async def add_keyword(self, keyword_text: str, interval_minutes: int = None,
rss_feeds: List[str] = None, priority: int = 0):
"""새 키워드 추가"""
try:
# 중복 체크
existing = await self.db.keywords.find_one({"keyword": keyword_text})
if existing:
logger.warning(f"Keyword '{keyword_text}' already exists")
return None
# 새 키워드 생성
keyword = Keyword(
keyword=keyword_text,
interval_minutes=interval_minutes or self.default_interval,
rss_feeds=rss_feeds or [],
priority=priority,
next_run=datetime.now() + timedelta(minutes=1) # 1분 후 첫 실행
)
result = await self.db.keywords.insert_one(keyword.dict())
logger.info(f"Added new keyword: {keyword_text}")
return keyword
except Exception as e:
logger.error(f"Error adding keyword: {e}")
return None
async def update_keyword(self, keyword_id: str, **kwargs):
"""키워드 업데이트"""
try:
# 업데이트할 필드
update_data = {k: v for k, v in kwargs.items() if v is not None}
update_data["updated_at"] = datetime.now()
result = await self.db.keywords.update_one(
{"keyword_id": keyword_id},
{"$set": update_data}
)
if result.modified_count > 0:
logger.info(f"Updated keyword {keyword_id}")
return True
return False
except Exception as e:
logger.error(f"Error updating keyword: {e}")
return False
async def delete_keyword(self, keyword_id: str):
"""키워드 삭제"""
try:
result = await self.db.keywords.delete_one({"keyword_id": keyword_id})
if result.deleted_count > 0:
logger.info(f"Deleted keyword {keyword_id}")
return True
return False
except Exception as e:
logger.error(f"Error deleting keyword: {e}")
return False
async def stop(self):
"""스케줄러 중지"""
await self.queue_manager.disconnect()
logger.info("Keyword Scheduler stopped")
async def main():
"""메인 함수"""
scheduler = KeywordScheduler()
try:
await scheduler.start()
except KeyboardInterrupt:
logger.info("Received interrupt signal")
finally:
await scheduler.stop()
if __name__ == "__main__":
asyncio.run(main())

View File

@ -0,0 +1,361 @@
"""
Multi-threaded Keyword Scheduler Service
하나의 프로세스에서 여러 스레드로 키워드를 관리하는 스케줄러
"""
import asyncio
import logging
import os
import sys
from datetime import datetime, timedelta
from motor.motor_asyncio import AsyncIOMotorClient
from typing import Dict
import threading
import time
# Import from shared module
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from shared.models import Keyword, PipelineJob
from shared.queue_manager import QueueManager
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 전역 변수로 스케줄러 인스턴스 참조 저장
scheduler_instance = None
class KeywordThread(threading.Thread):
"""개별 키워드를 관리하는 스레드"""
def __init__(self, keyword_text: str, mongodb_url: str, db_name: str, redis_url: str):
super().__init__(name=f"Thread-{keyword_text}")
self.keyword_text = keyword_text
self.mongodb_url = mongodb_url
self.db_name = db_name
self.redis_url = redis_url
self.running = True
self.keyword = None
self.status = "initializing"
self.last_execution = None
self.execution_count = 0
self.error_count = 0
self.last_error = None
def run(self):
"""스레드 실행"""
# 새로운 이벤트 루프 생성
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(self.run_scheduler())
finally:
loop.close()
async def run_scheduler(self):
"""비동기 스케줄러 실행"""
# Redis 연결
self.queue_manager = QueueManager(redis_url=self.redis_url)
await self.queue_manager.connect()
# MongoDB 연결
client = AsyncIOMotorClient(self.mongodb_url)
self.db = client[self.db_name]
logger.info(f"[{self.keyword_text}] Thread started")
# 키워드 로드
await self.load_keyword()
if not self.keyword:
logger.error(f"[{self.keyword_text}] Failed to load keyword")
return
# 메인 루프
while self.running:
try:
# 키워드 상태 체크
await self.reload_keyword()
if not self.keyword.is_active:
self.status = "inactive"
logger.info(f"[{self.keyword_text}] Keyword is inactive, sleeping...")
await asyncio.sleep(60)
continue
# 실행 시간 체크
now = datetime.now()
if self.keyword.next_run and self.keyword.next_run <= now:
self.status = "executing"
await self.execute_keyword()
# 다음 실행 시간까지 대기
sleep_seconds = self.keyword.interval_minutes * 60
self.status = "waiting"
else:
# 다음 체크까지 1분 대기
sleep_seconds = 60
self.status = "waiting"
await asyncio.sleep(sleep_seconds)
except Exception as e:
self.error_count += 1
self.last_error = str(e)
self.status = "error"
logger.error(f"[{self.keyword_text}] Error in thread loop: {e}")
await asyncio.sleep(60)
await self.queue_manager.disconnect()
logger.info(f"[{self.keyword_text}] Thread stopped")
async def load_keyword(self):
"""키워드 초기 로드"""
try:
keyword_doc = await self.db.keywords.find_one({"keyword": self.keyword_text})
if keyword_doc:
self.keyword = Keyword(**keyword_doc)
logger.info(f"[{self.keyword_text}] Loaded keyword")
except Exception as e:
logger.error(f"[{self.keyword_text}] Error loading keyword: {e}")
async def reload_keyword(self):
"""키워드 정보 재로드"""
try:
keyword_doc = await self.db.keywords.find_one({"keyword": self.keyword_text})
if keyword_doc:
self.keyword = Keyword(**keyword_doc)
except Exception as e:
logger.error(f"[{self.keyword_text}] Error reloading keyword: {e}")
async def execute_keyword(self):
"""키워드 실행"""
try:
logger.info(f"[{self.keyword_text}] Executing keyword")
# PipelineJob 생성
job = PipelineJob(
keyword_id=self.keyword.keyword_id,
keyword=self.keyword.keyword,
stage='rss_collection',
data={
'rss_feeds': self.keyword.rss_feeds if self.keyword.rss_feeds else [],
'max_articles': self.keyword.max_articles_per_run,
'scheduled': True,
'thread_name': self.name
},
priority=self.keyword.priority
)
# 큐에 작업 추가
await self.queue_manager.enqueue('rss_collection', job)
logger.info(f"[{self.keyword_text}] Enqueued job {job.job_id}")
# 키워드 업데이트
update_data = {
"last_run": datetime.now(),
"next_run": datetime.now() + timedelta(minutes=self.keyword.interval_minutes),
"updated_at": datetime.now()
}
await self.db.keywords.update_one(
{"keyword_id": self.keyword.keyword_id},
{"$set": update_data}
)
self.last_execution = datetime.now()
self.execution_count += 1
logger.info(f"[{self.keyword_text}] Next run at {update_data['next_run']}")
except Exception as e:
self.error_count += 1
self.last_error = str(e)
logger.error(f"[{self.keyword_text}] Error executing keyword: {e}")
def stop(self):
"""스레드 중지"""
self.running = False
self.status = "stopped"
def get_status(self):
"""스레드 상태 반환"""
return {
"keyword": self.keyword_text,
"thread_name": self.name,
"status": self.status,
"is_alive": self.is_alive(),
"execution_count": self.execution_count,
"last_execution": self.last_execution.isoformat() if self.last_execution else None,
"error_count": self.error_count,
"last_error": self.last_error,
"next_run": self.keyword.next_run.isoformat() if self.keyword and self.keyword.next_run else None
}
class MultiThreadScheduler:
"""멀티스레드 키워드 스케줄러"""
def __init__(self):
self.mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
self.db_name = os.getenv("DB_NAME", "ai_writer_db")
self.redis_url = os.getenv("REDIS_URL", "redis://redis:6379")
self.threads: Dict[str, KeywordThread] = {}
self.running = True
# Singleton 인스턴스를 전역 변수로 저장
global scheduler_instance
scheduler_instance = self
async def start(self):
"""스케줄러 시작"""
logger.info("Starting Multi-threaded Keyword Scheduler")
# MongoDB 연결
client = AsyncIOMotorClient(self.mongodb_url)
self.db = client[self.db_name]
# 초기 키워드 설정
await self.initialize_keywords()
# 키워드 로드 및 스레드 시작
await self.load_and_start_threads()
# 메인 루프 - 새로운 키워드 체크
while self.running:
try:
await self.check_new_keywords()
await asyncio.sleep(30) # 30초마다 새 키워드 체크
except Exception as e:
logger.error(f"Error in main loop: {e}")
await asyncio.sleep(30)
async def initialize_keywords(self):
"""초기 키워드 설정 (없으면 생성)"""
try:
count = await self.db.keywords.count_documents({})
if count == 0:
logger.info("No keywords found. Creating default keywords...")
default_keywords = [
{
"keyword": "AI",
"interval_minutes": 60,
"is_active": True,
"priority": 1,
"rss_feeds": [],
"next_run": datetime.now() + timedelta(minutes=1)
},
{
"keyword": "경제",
"interval_minutes": 120,
"is_active": True,
"priority": 0,
"rss_feeds": [],
"next_run": datetime.now() + timedelta(minutes=1)
},
{
"keyword": "테크놀로지",
"interval_minutes": 60,
"is_active": True,
"priority": 1,
"rss_feeds": [],
"next_run": datetime.now() + timedelta(minutes=1)
}
]
for kw_data in default_keywords:
keyword = Keyword(**kw_data)
await self.db.keywords.insert_one(keyword.model_dump())
logger.info(f"Created keyword: {keyword.keyword}")
logger.info(f"Found {count} keywords in database")
except Exception as e:
logger.error(f"Error initializing keywords: {e}")
async def load_and_start_threads(self):
"""키워드 로드 및 스레드 시작"""
try:
# 활성 키워드 조회
cursor = self.db.keywords.find({"is_active": True})
keywords = await cursor.to_list(None)
for keyword_doc in keywords:
keyword = Keyword(**keyword_doc)
if keyword.keyword not in self.threads:
self.start_keyword_thread(keyword.keyword)
logger.info(f"Started {len(self.threads)} keyword threads")
except Exception as e:
logger.error(f"Error loading keywords: {e}")
def start_keyword_thread(self, keyword_text: str):
"""키워드 스레드 시작"""
if keyword_text not in self.threads:
thread = KeywordThread(
keyword_text=keyword_text,
mongodb_url=self.mongodb_url,
db_name=self.db_name,
redis_url=self.redis_url
)
thread.start()
self.threads[keyword_text] = thread
logger.info(f"Started thread for keyword: {keyword_text}")
async def check_new_keywords(self):
"""새로운 키워드 체크 및 스레드 관리"""
try:
# 현재 활성 키워드 조회
cursor = self.db.keywords.find({"is_active": True})
active_keywords = await cursor.to_list(None)
active_keyword_texts = {kw['keyword'] for kw in active_keywords}
# 새 키워드 시작
for keyword_text in active_keyword_texts:
if keyword_text not in self.threads:
self.start_keyword_thread(keyword_text)
# 비활성화된 키워드 스레드 중지
for keyword_text in list(self.threads.keys()):
if keyword_text not in active_keyword_texts:
thread = self.threads[keyword_text]
thread.stop()
del self.threads[keyword_text]
logger.info(f"Stopped thread for keyword: {keyword_text}")
except Exception as e:
logger.error(f"Error checking new keywords: {e}")
def stop(self):
"""모든 스레드 중지"""
self.running = False
for thread in self.threads.values():
thread.stop()
# 모든 스레드가 종료될 때까지 대기
for thread in self.threads.values():
thread.join(timeout=5)
logger.info("Multi-threaded Keyword Scheduler stopped")
def get_threads_status(self):
"""모든 스레드 상태 반환"""
status_list = []
for thread in self.threads.values():
status_list.append(thread.get_status())
return status_list
async def main():
"""메인 함수"""
scheduler = MultiThreadScheduler()
try:
await scheduler.start()
except KeyboardInterrupt:
logger.info("Received interrupt signal")
finally:
scheduler.stop()
if __name__ == "__main__":
asyncio.run(main())

View File

@ -0,0 +1,5 @@
motor==3.6.0
redis[hiredis]==5.0.1
pydantic==2.5.0
fastapi==0.104.1
uvicorn==0.24.0

View File

@ -0,0 +1,203 @@
"""
News Pipeline Scheduler
뉴스 파이프라인 스케줄러 서비스
"""
import asyncio
import logging
import os
import sys
from datetime import datetime, timedelta
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from motor.motor_asyncio import AsyncIOMotorClient
# Import from shared module
from shared.models import KeywordSubscription, PipelineJob
from shared.queue_manager import QueueManager
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class NewsScheduler:
def __init__(self):
self.scheduler = AsyncIOScheduler()
self.mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
self.db_name = os.getenv("DB_NAME", "ai_writer_db")
self.db = None
self.queue_manager = QueueManager(
redis_url=os.getenv("REDIS_URL", "redis://redis:6379")
)
async def start(self):
"""스케줄러 시작"""
logger.info("Starting News Pipeline Scheduler")
# MongoDB 연결
client = AsyncIOMotorClient(self.mongodb_url)
self.db = client[self.db_name]
# Redis 연결
await self.queue_manager.connect()
# 기본 스케줄 설정
# 매 30분마다 실행
self.scheduler.add_job(
self.process_keywords,
'interval',
minutes=30,
id='keyword_processor',
name='Process Active Keywords'
)
# 특정 시간대 강화 스케줄 (아침 7시, 점심 12시, 저녁 6시)
for hour in [7, 12, 18]:
self.scheduler.add_job(
self.process_priority_keywords,
'cron',
hour=hour,
minute=0,
id=f'priority_processor_{hour}',
name=f'Process Priority Keywords at {hour}:00'
)
# 매일 자정 통계 초기화
self.scheduler.add_job(
self.reset_daily_stats,
'cron',
hour=0,
minute=0,
id='stats_reset',
name='Reset Daily Statistics'
)
self.scheduler.start()
logger.info("Scheduler started successfully")
# 시작 즉시 한 번 실행
await self.process_keywords()
async def process_keywords(self):
"""활성 키워드 처리"""
try:
logger.info("Processing active keywords")
# MongoDB에서 활성 키워드 로드
now = datetime.now()
thirty_minutes_ago = now - timedelta(minutes=30)
keywords = await self.db.keywords.find({
"is_active": True,
"$or": [
{"last_processed": {"$lt": thirty_minutes_ago}},
{"last_processed": None}
]
}).to_list(None)
logger.info(f"Found {len(keywords)} keywords to process")
for keyword_doc in keywords:
await self._create_job(keyword_doc)
# 처리 시간 업데이트
await self.db.keywords.update_one(
{"keyword_id": keyword_doc['keyword_id']},
{"$set": {"last_processed": now}}
)
logger.info(f"Created jobs for {len(keywords)} keywords")
except Exception as e:
logger.error(f"Error processing keywords: {e}")
async def process_priority_keywords(self):
"""우선순위 키워드 처리"""
try:
logger.info("Processing priority keywords")
keywords = await self.db.keywords.find({
"is_active": True,
"is_priority": True
}).to_list(None)
for keyword_doc in keywords:
await self._create_job(keyword_doc, priority=1)
logger.info(f"Created priority jobs for {len(keywords)} keywords")
except Exception as e:
logger.error(f"Error processing priority keywords: {e}")
async def _create_job(self, keyword_doc: dict, priority: int = 0):
"""파이프라인 작업 생성"""
try:
# KeywordSubscription 모델로 변환
keyword = KeywordSubscription(**keyword_doc)
# PipelineJob 생성
job = PipelineJob(
keyword_id=keyword.keyword_id,
keyword=keyword.keyword,
stage='rss_collection',
stages_completed=[],
priority=priority,
data={
'keyword': keyword.keyword,
'language': keyword.language,
'rss_feeds': keyword.rss_feeds or self._get_default_rss_feeds(),
'categories': keyword.categories
}
)
# 첫 번째 큐에 추가
await self.queue_manager.enqueue(
'rss_collection',
job,
priority=priority
)
logger.info(f"Created job {job.job_id} for keyword '{keyword.keyword}'")
except Exception as e:
logger.error(f"Error creating job for keyword: {e}")
def _get_default_rss_feeds(self) -> list:
"""기본 RSS 피드 목록"""
return [
"https://news.google.com/rss/search?q={keyword}&hl=ko&gl=KR&ceid=KR:ko",
"https://trends.google.com/trends/trendingsearches/daily/rss?geo=KR",
"https://www.mk.co.kr/rss/40300001/", # 매일경제
"https://www.hankyung.com/feed/all-news", # 한국경제
"https://www.zdnet.co.kr/news/news_rss.xml", # ZDNet Korea
]
async def reset_daily_stats(self):
"""일일 통계 초기화"""
try:
logger.info("Resetting daily statistics")
# Redis 통계 초기화
# 구현 필요
pass
except Exception as e:
logger.error(f"Error resetting stats: {e}")
async def stop(self):
"""스케줄러 중지"""
self.scheduler.shutdown()
await self.queue_manager.disconnect()
logger.info("Scheduler stopped")
async def main():
"""메인 함수"""
scheduler = NewsScheduler()
try:
await scheduler.start()
# 계속 실행
while True:
await asyncio.sleep(60)
except KeyboardInterrupt:
logger.info("Received interrupt signal")
finally:
await scheduler.stop()
if __name__ == "__main__":
asyncio.run(main())

View File

@ -0,0 +1,173 @@
"""
Single Keyword Scheduler Service
단일 키워드를 전담하는 스케줄러
"""
import asyncio
import logging
import os
import sys
from datetime import datetime, timedelta
from motor.motor_asyncio import AsyncIOMotorClient
import uuid
# Import from shared module
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from shared.models import Keyword, PipelineJob
from shared.queue_manager import QueueManager
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SingleKeywordScheduler:
def __init__(self):
self.queue_manager = QueueManager(
redis_url=os.getenv("REDIS_URL", "redis://redis:6379")
)
self.mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
self.db_name = os.getenv("DB_NAME", "ai_writer_db")
self.keyword_text = os.getenv("KEYWORD") # 환경변수로 키워드 지정
self.interval_minutes = int(os.getenv("INTERVAL_MINUTES", "60"))
self.db = None
self.keyword = None
async def start(self):
"""스케줄러 시작"""
if not self.keyword_text:
logger.error("KEYWORD environment variable is required")
return
logger.info(f"Starting Single Keyword Scheduler for '{self.keyword_text}'")
# Redis 연결
await self.queue_manager.connect()
# MongoDB 연결
client = AsyncIOMotorClient(self.mongodb_url)
self.db = client[self.db_name]
# 키워드 초기화 또는 로드
await self.initialize_keyword()
if not self.keyword:
logger.error(f"Failed to initialize keyword '{self.keyword_text}'")
return
# 메인 루프 - 이 키워드만 처리
while True:
try:
await self.check_and_execute()
# 다음 실행까지 대기
sleep_seconds = self.keyword.interval_minutes * 60
logger.info(f"Sleeping for {self.keyword.interval_minutes} minutes until next execution")
await asyncio.sleep(sleep_seconds)
except Exception as e:
logger.error(f"Error in scheduler loop: {e}")
await asyncio.sleep(60) # 에러 발생시 1분 후 재시도
async def initialize_keyword(self):
"""키워드 초기화 또는 로드"""
try:
# 기존 키워드 찾기
keyword_doc = await self.db.keywords.find_one({"keyword": self.keyword_text})
if keyword_doc:
self.keyword = Keyword(**keyword_doc)
logger.info(f"Loaded existing keyword: {self.keyword_text}")
else:
# 새 키워드 생성
self.keyword = Keyword(
keyword=self.keyword_text,
interval_minutes=self.interval_minutes,
is_active=True,
priority=int(os.getenv("PRIORITY", "0")),
rss_feeds=os.getenv("RSS_FEEDS", "").split(",") if os.getenv("RSS_FEEDS") else [],
max_articles_per_run=int(os.getenv("MAX_ARTICLES", "100"))
)
await self.db.keywords.insert_one(self.keyword.model_dump())
logger.info(f"Created new keyword: {self.keyword_text}")
except Exception as e:
logger.error(f"Error initializing keyword: {e}")
async def check_and_execute(self):
"""키워드 실행 체크 및 실행"""
try:
# 최신 키워드 정보 다시 로드
keyword_doc = await self.db.keywords.find_one({"keyword": self.keyword_text})
if not keyword_doc:
logger.error(f"Keyword '{self.keyword_text}' not found in database")
return
self.keyword = Keyword(**keyword_doc)
# 비활성화된 경우 스킵
if not self.keyword.is_active:
logger.info(f"Keyword '{self.keyword_text}' is inactive, skipping")
return
# 실행
await self.execute_keyword()
except Exception as e:
logger.error(f"Error checking keyword: {e}")
async def execute_keyword(self):
"""키워드 실행"""
try:
logger.info(f"Executing keyword: {self.keyword.keyword}")
# PipelineJob 생성
job = PipelineJob(
keyword_id=self.keyword.keyword_id,
keyword=self.keyword.keyword,
stage='rss_collection',
data={
'rss_feeds': self.keyword.rss_feeds if self.keyword.rss_feeds else [],
'max_articles': self.keyword.max_articles_per_run,
'scheduled': True,
'scheduler_instance': f"single-{self.keyword_text}"
},
priority=self.keyword.priority
)
# 큐에 작업 추가
await self.queue_manager.enqueue('rss_collection', job)
logger.info(f"Enqueued job for keyword '{self.keyword.keyword}' with job_id: {job.job_id}")
# 키워드 업데이트
update_data = {
"last_run": datetime.now(),
"next_run": datetime.now() + timedelta(minutes=self.keyword.interval_minutes),
"updated_at": datetime.now()
}
await self.db.keywords.update_one(
{"keyword_id": self.keyword.keyword_id},
{"$set": update_data}
)
logger.info(f"Updated keyword '{self.keyword.keyword}' - next run at {update_data['next_run']}")
except Exception as e:
logger.error(f"Error executing keyword {self.keyword.keyword}: {e}")
async def stop(self):
"""스케줄러 중지"""
await self.queue_manager.disconnect()
logger.info(f"Single Keyword Scheduler for '{self.keyword_text}' stopped")
async def main():
"""메인 함수"""
scheduler = SingleKeywordScheduler()
try:
await scheduler.start()
except KeyboardInterrupt:
logger.info("Received interrupt signal")
finally:
await scheduler.stop()
if __name__ == "__main__":
asyncio.run(main())

View File

@ -0,0 +1 @@
# Shared modules for pipeline services

View File

@ -0,0 +1,159 @@
"""
Pipeline Data Models
파이프라인 전체에서 사용되는 공통 데이터 모델
"""
from datetime import datetime
from typing import List, Dict, Any, Optional
from pydantic import BaseModel, Field
import uuid
class KeywordSubscription(BaseModel):
"""키워드 구독 모델"""
keyword_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
keyword: str
language: str = "ko"
schedule: str = "0 */30 * * *" # Cron expression (30분마다)
is_active: bool = True
is_priority: bool = False
last_processed: Optional[datetime] = None
rss_feeds: List[str] = Field(default_factory=list)
categories: List[str] = Field(default_factory=list)
created_at: datetime = Field(default_factory=datetime.now)
owner: Optional[str] = None
class PipelineJob(BaseModel):
"""파이프라인 작업 모델"""
job_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
keyword_id: str
keyword: str
stage: str # current stage
stages_completed: List[str] = Field(default_factory=list)
data: Dict[str, Any] = Field(default_factory=dict)
retry_count: int = 0
max_retries: int = 3
priority: int = 0
created_at: datetime = Field(default_factory=datetime.now)
updated_at: datetime = Field(default_factory=datetime.now)
class RSSItem(BaseModel):
"""RSS 피드 아이템"""
item_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
title: str
link: str
guid: Optional[str] = None # RSS GUID for deduplication
published: Optional[str] = None
summary: Optional[str] = None
source_feed: str
class SearchResult(BaseModel):
"""검색 결과"""
title: str
link: str
snippet: Optional[str] = None
source: str = "google"
class EnrichedItem(BaseModel):
"""강화된 뉴스 아이템"""
rss_item: RSSItem
search_results: List[SearchResult] = Field(default_factory=list)
class SummarizedItem(BaseModel):
"""요약된 아이템"""
enriched_item: EnrichedItem
ai_summary: str
summary_language: str = "ko"
class TranslatedItem(BaseModel):
"""번역된 아이템"""
summarized_item: SummarizedItem
title_en: str
summary_en: str
class ItemWithImage(BaseModel):
"""이미지가 추가된 아이템"""
translated_item: TranslatedItem
image_url: str
image_prompt: str
class Subtopic(BaseModel):
"""기사 소주제"""
title: str
content: List[str] # 문단별 내용
class Entities(BaseModel):
"""개체명"""
people: List[str] = Field(default_factory=list)
organizations: List[str] = Field(default_factory=list)
groups: List[str] = Field(default_factory=list)
countries: List[str] = Field(default_factory=list)
events: List[str] = Field(default_factory=list)
class NewsReference(BaseModel):
"""뉴스 레퍼런스"""
title: str
link: str
source: str
published: Optional[str] = None
class FinalArticle(BaseModel):
"""최종 기사 - ai_writer_db.articles 스키마와 일치"""
news_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
title: str
created_at: str = Field(default_factory=lambda: datetime.now().isoformat())
summary: str
subtopics: List[Subtopic] = Field(default_factory=list)
categories: List[str] = Field(default_factory=list)
entities: Entities = Field(default_factory=Entities)
source_keyword: str
source_count: int = 1
# 레퍼런스 뉴스 정보
references: List[NewsReference] = Field(default_factory=list)
# 파이프라인 관련 추가 필드
job_id: Optional[str] = None
keyword_id: Optional[str] = None
pipeline_stages: List[str] = Field(default_factory=list)
processing_time: Optional[float] = None
# 다국어 지원
language: str = 'ko'
ref_news_id: Optional[str] = None
# RSS 중복 체크용 GUID
rss_guid: Optional[str] = None
# 이미지 관련 필드
image_prompt: Optional[str] = None
images: List[str] = Field(default_factory=list)
# 번역 추적
translated_languages: List[str] = Field(default_factory=list)
class TranslatedItem(BaseModel):
"""번역된 아이템"""
summarized_item: Dict[str, Any] # SummarizedItem as dict
translated_title: str
translated_summary: str
target_language: str = 'en'
class GeneratedImageItem(BaseModel):
"""이미지 생성된 아이템"""
translated_item: Dict[str, Any] # TranslatedItem as dict
image_url: str
image_prompt: str
class QueueMessage(BaseModel):
"""큐 메시지"""
message_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
queue_name: str
job: PipelineJob
timestamp: datetime = Field(default_factory=datetime.now)
retry_count: int = 0
class Keyword(BaseModel):
"""스케줄러용 키워드 모델"""
keyword_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
keyword: str
interval_minutes: int = Field(default=60) # 기본 1시간
is_active: bool = Field(default=True)
last_run: Optional[datetime] = None
next_run: Optional[datetime] = None
created_at: datetime = Field(default_factory=datetime.now)
updated_at: datetime = Field(default_factory=datetime.now)
rss_feeds: List[str] = Field(default_factory=list) # 커스텀 RSS 피드
priority: int = Field(default=0) # 우선순위 (높을수록 우선)
max_articles_per_run: int = Field(default=100) # 실행당 최대 기사 수

View File

@ -0,0 +1,176 @@
"""
Queue Manager
Redis 기반 큐 관리 시스템
"""
import redis.asyncio as redis
import json
import logging
from typing import Optional, Dict, Any, List
from datetime import datetime
from .models import PipelineJob, QueueMessage
logger = logging.getLogger(__name__)
class QueueManager:
"""Redis 기반 큐 매니저"""
QUEUES = {
"keyword_processing": "queue:keyword_processing",
"rss_collection": "queue:rss_collection",
"search_enrichment": "queue:search_enrichment",
"google_search": "queue:google_search",
"ai_article_generation": "queue:ai_article_generation",
"image_generation": "queue:image_generation",
"translation": "queue:translation",
"failed": "queue:failed",
"scheduled": "queue:scheduled"
}
def __init__(self, redis_url: str = "redis://redis:6379"):
self.redis_url = redis_url
self.redis_client: Optional[redis.Redis] = None
async def connect(self):
"""Redis 연결"""
if not self.redis_client:
self.redis_client = await redis.from_url(
self.redis_url,
encoding="utf-8",
decode_responses=True
)
logger.info("Connected to Redis")
async def disconnect(self):
"""Redis 연결 해제"""
if self.redis_client:
await self.redis_client.close()
self.redis_client = None
async def enqueue(self, queue_name: str, job: PipelineJob, priority: int = 0) -> str:
"""작업을 큐에 추가"""
try:
queue_key = self.QUEUES.get(queue_name, f"queue:{queue_name}")
message = QueueMessage(
queue_name=queue_name,
job=job
)
# 우선순위에 따라 추가
if priority > 0:
await self.redis_client.lpush(queue_key, message.json())
else:
await self.redis_client.rpush(queue_key, message.json())
# 통계 업데이트
await self.redis_client.hincrby("stats:queues", queue_name, 1)
logger.info(f"Job {job.job_id} enqueued to {queue_name}")
return job.job_id
except Exception as e:
logger.error(f"Failed to enqueue job: {e}")
raise
async def dequeue(self, queue_name: str, timeout: int = 0) -> Optional[PipelineJob]:
"""큐에서 작업 가져오기"""
try:
queue_key = self.QUEUES.get(queue_name, f"queue:{queue_name}")
logger.info(f"Attempting to dequeue from {queue_key} with timeout={timeout}")
if timeout > 0:
result = await self.redis_client.blpop(queue_key, timeout)
if result:
_, data = result
logger.info(f"Dequeued item from {queue_key}")
else:
logger.debug(f"No item available in {queue_key}")
return None
else:
data = await self.redis_client.lpop(queue_key)
if data:
message = QueueMessage.parse_raw(data)
# 처리 중 목록에 추가
processing_key = f"processing:{queue_name}"
await self.redis_client.hset(
processing_key,
message.job.job_id,
message.json()
)
return message.job
return None
except Exception as e:
logger.error(f"Failed to dequeue job: {e}")
return None
async def mark_completed(self, queue_name: str, job_id: str):
"""작업 완료 표시"""
try:
processing_key = f"processing:{queue_name}"
await self.redis_client.hdel(processing_key, job_id)
# 통계 업데이트
await self.redis_client.hincrby("stats:completed", queue_name, 1)
logger.info(f"Job {job_id} completed in {queue_name}")
except Exception as e:
logger.error(f"Failed to mark job as completed: {e}")
async def mark_failed(self, queue_name: str, job: PipelineJob, error: str):
"""작업 실패 처리"""
try:
processing_key = f"processing:{queue_name}"
await self.redis_client.hdel(processing_key, job.job_id)
# 재시도 확인
if job.retry_count < job.max_retries:
job.retry_count += 1
await self.enqueue(queue_name, job)
logger.info(f"Job {job.job_id} requeued (retry {job.retry_count}/{job.max_retries})")
else:
# 실패 큐로 이동
job.data["error"] = error
job.data["failed_stage"] = queue_name
await self.enqueue("failed", job)
# 통계 업데이트
await self.redis_client.hincrby("stats:failed", queue_name, 1)
logger.error(f"Job {job.job_id} failed: {error}")
except Exception as e:
logger.error(f"Failed to mark job as failed: {e}")
async def get_queue_stats(self) -> Dict[str, Any]:
"""큐 통계 조회"""
try:
stats = {}
for name, key in self.QUEUES.items():
stats[name] = {
"pending": await self.redis_client.llen(key),
"processing": await self.redis_client.hlen(f"processing:{name}"),
}
# 완료/실패 통계
stats["completed"] = await self.redis_client.hgetall("stats:completed") or {}
stats["failed"] = await self.redis_client.hgetall("stats:failed") or {}
return stats
except Exception as e:
logger.error(f"Failed to get queue stats: {e}")
return {}
async def clear_queue(self, queue_name: str):
"""큐 초기화 (테스트용)"""
queue_key = self.QUEUES.get(queue_name, f"queue:{queue_name}")
await self.redis_client.delete(queue_key)
await self.redis_client.delete(f"processing:{queue_name}")
logger.info(f"Queue {queue_name} cleared")

View File

@ -0,0 +1,5 @@
redis[hiredis]==5.0.1
motor==3.1.1
pymongo==4.3.3
pydantic==2.5.0
python-dateutil==2.8.2

View File

@ -0,0 +1,54 @@
#!/usr/bin/env python3
"""
Simple pipeline test - direct queue injection
"""
import asyncio
import json
import redis.asyncio as redis
from datetime import datetime
import uuid
async def test():
# Redis 연결
r = await redis.from_url("redis://redis:6379", decode_responses=True)
# 작업 생성
job = {
"job_id": str(uuid.uuid4()),
"keyword_id": str(uuid.uuid4()),
"keyword": "전기차",
"stage": "rss_collection",
"stages_completed": [],
"data": {
"rss_feeds": [
"https://news.google.com/rss/search?q=전기차&hl=ko&gl=KR&ceid=KR:ko"
],
"categories": ["technology", "automotive"]
},
"priority": 1,
"retry_count": 0,
"max_retries": 3,
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat()
}
# QueueMessage 형식으로 래핑
message = {
"message_id": str(uuid.uuid4()),
"queue_name": "rss_collection",
"job": job,
"timestamp": datetime.now().isoformat()
}
# 큐에 추가
await r.lpush("queue:rss_collection", json.dumps(message))
print(f"✅ Job {job['job_id']} added to queue:rss_collection")
# 큐 상태 확인
length = await r.llen("queue:rss_collection")
print(f"📊 Queue length: {length}")
await r.aclose()
if __name__ == "__main__":
asyncio.run(test())

View File

@ -0,0 +1,57 @@
#!/usr/bin/env python3
"""
Direct dequeue test
"""
import asyncio
import redis.asyncio as redis
import json
async def test_dequeue():
"""Test dequeue directly"""
# Connect to Redis
redis_client = await redis.from_url(
"redis://redis:6379",
encoding="utf-8",
decode_responses=True
)
print("Connected to Redis")
# Check queue length
length = await redis_client.llen("queue:rss_collection")
print(f"Queue length: {length}")
if length > 0:
# Get the first item
item = await redis_client.lrange("queue:rss_collection", 0, 0)
print(f"First item preview: {item[0][:200]}...")
# Try blpop with timeout
print("Trying blpop with timeout=5...")
result = await redis_client.blpop("queue:rss_collection", 5)
if result:
queue, data = result
print(f"Successfully dequeued from {queue}")
print(f"Data: {data[:200]}...")
# Parse the message
try:
message = json.loads(data)
print(f"Message ID: {message.get('message_id')}")
print(f"Queue Name: {message.get('queue_name')}")
if 'job' in message:
job = message['job']
print(f"Job ID: {job.get('job_id')}")
print(f"Keyword: {job.get('keyword')}")
except Exception as e:
print(f"Failed to parse message: {e}")
else:
print("blpop timed out - no result")
else:
print("Queue is empty")
await redis_client.close()
if __name__ == "__main__":
asyncio.run(test_dequeue())

View File

@ -0,0 +1,118 @@
#!/usr/bin/env python3
"""
Pipeline Test Script
파이프라인 전체 플로우를 테스트하는 스크립트
"""
import asyncio
import json
from datetime import datetime
from motor.motor_asyncio import AsyncIOMotorClient
import redis.asyncio as redis
from shared.models import KeywordSubscription, PipelineJob
async def test_pipeline():
"""파이프라인 테스트"""
# MongoDB 연결
mongo_client = AsyncIOMotorClient("mongodb://mongodb:27017")
db = mongo_client.pipeline
# Redis 연결
redis_client = redis.Redis(host='redis', port=6379, decode_responses=True)
# 1. 테스트 키워드 추가
test_keyword = KeywordSubscription(
keyword="전기차",
language="ko",
schedule="*/1 * * * *", # 1분마다 (테스트용)
is_active=True,
is_priority=True,
rss_feeds=[
"https://news.google.com/rss/search?q=전기차&hl=ko&gl=KR&ceid=KR:ko",
"https://news.google.com/rss/search?q=electric+vehicle&hl=en&gl=US&ceid=US:en"
],
categories=["technology", "automotive", "environment"],
owner="test_user"
)
# MongoDB에 저장
await db.keyword_subscriptions.replace_one(
{"keyword": test_keyword.keyword},
test_keyword.dict(),
upsert=True
)
print(f"✅ 키워드 '{test_keyword.keyword}' 추가 완료")
# 2. 즉시 파이프라인 트리거 (스케줄러를 거치지 않고 직접)
job = PipelineJob(
keyword_id=test_keyword.keyword_id,
keyword=test_keyword.keyword,
stage="rss_collection",
data={
"rss_feeds": test_keyword.rss_feeds,
"categories": test_keyword.categories
},
priority=1 if test_keyword.is_priority else 0
)
# Redis 큐에 직접 추가 (QueueMessage 형식으로)
from shared.queue_manager import QueueMessage
message = QueueMessage(
queue_name="rss_collection",
job=job
)
await redis_client.lpush("queue:rss_collection", message.json())
print(f"✅ 작업을 RSS Collection 큐에 추가: {job.job_id}")
# 3. 파이프라인 상태 모니터링
print("\n📊 파이프라인 실행 모니터링 중...")
print("각 단계별 로그를 확인하려면 다음 명령을 실행하세요:")
print(" docker-compose logs -f pipeline-rss-collector")
print(" docker-compose logs -f pipeline-google-search")
print(" docker-compose logs -f pipeline-ai-summarizer")
print(" docker-compose logs -f pipeline-translator")
print(" docker-compose logs -f pipeline-image-generator")
print(" docker-compose logs -f pipeline-article-assembly")
# 큐 상태 확인
for i in range(10):
await asyncio.sleep(5)
# 각 큐의 길이 확인
queues = [
"queue:rss_collection",
"queue:google_search",
"queue:ai_summarization",
"queue:translation",
"queue:image_generation",
"queue:article_assembly"
]
print(f"\n[{datetime.now().strftime('%H:%M:%S')}] 큐 상태:")
for queue in queues:
length = await redis_client.llen(queue)
if length > 0:
print(f" {queue}: {length} 작업 대기 중")
# 4. 최종 결과 확인
print("\n📄 MongoDB에서 생성된 기사 확인 중...")
articles = await db.articles.find({"keyword": test_keyword.keyword}).to_list(length=5)
if articles:
print(f"{len(articles)}개의 기사 생성 완료!")
for article in articles:
print(f"\n제목: {article.get('title', 'N/A')}")
print(f"ID: {article.get('article_id', 'N/A')}")
print(f"생성 시간: {article.get('created_at', 'N/A')}")
print(f"처리 시간: {article.get('processing_time', 'N/A')}")
print(f"이미지 수: {len(article.get('images', []))}")
else:
print("⚠️ 아직 기사가 생성되지 않았습니다. 조금 더 기다려주세요.")
# 연결 종료
await redis_client.close()
mongo_client.close()
if __name__ == "__main__":
print("🚀 파이프라인 테스트 시작")
asyncio.run(test_pipeline())

View File

@ -0,0 +1,56 @@
#!/usr/bin/env python3
"""
스타크래프트 키워드로 파이프라인 테스트
"""
import asyncio
import sys
import os
sys.path.append(os.path.dirname(__file__))
from shared.queue_manager import QueueManager
from shared.models import PipelineJob
async def test_starcraft_pipeline():
"""스타크래프트 키워드로 파이프라인 테스트"""
# Queue manager 초기화
queue_manager = QueueManager(redis_url="redis://redis:6379")
await queue_manager.connect()
try:
# 스타크래프트 파이프라인 작업 생성
job = PipelineJob(
keyword_id="test_starcraft_001",
keyword="스타크래프트",
stage="rss_collection",
data={}
)
print(f"🚀 스타크래프트 파이프라인 작업 시작")
print(f" 작업 ID: {job.job_id}")
print(f" 키워드: {job.keyword}")
print(f" 키워드 ID: {job.keyword_id}")
# RSS 수집 큐에 작업 추가
await queue_manager.enqueue('rss_collection', job)
print(f"✅ 작업이 rss_collection 큐에 추가되었습니다")
# 큐 상태 확인
stats = await queue_manager.get_queue_stats()
print(f"\n📊 현재 큐 상태:")
for queue_name, stat in stats.items():
if queue_name not in ['completed', 'failed']:
pending = stat.get('pending', 0)
processing = stat.get('processing', 0)
if pending > 0 or processing > 0:
print(f" {queue_name}: 대기={pending}, 처리중={processing}")
print(f"\n⏳ 파이프라인 실행을 모니터링하세요:")
print(f" docker logs site11_pipeline_rss_collector --tail 20 -f")
print(f" python3 check_mongodb.py")
finally:
await queue_manager.disconnect()
if __name__ == "__main__":
asyncio.run(test_starcraft_pipeline())

View File

@ -0,0 +1,54 @@
"""
파이프라인 테스트 작업 제출 스크립트
"""
import redis
import json
from datetime import datetime
import uuid
import sys
def submit_test_job(keyword='나스닥'):
# Redis 연결
redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True)
# 테스트 작업 생성
job_id = str(uuid.uuid4())
keyword_id = f'test_{job_id[:8]}'
job_data = {
'job_id': job_id,
'keyword_id': keyword_id,
'keyword': keyword,
'created_at': datetime.now().isoformat(),
'stage': 'rss_collection',
'stages_completed': [],
'data': {}
}
# QueueMessage 래퍼 생성
queue_message = {
'message_id': str(uuid.uuid4()),
'queue_name': 'rss_collection',
'job': job_data,
'timestamp': datetime.now().isoformat(),
'attempts': 0
}
# 큐에 작업 추가 (rpush 사용 - priority=0인 경우)
redis_client.rpush('queue:rss_collection', json.dumps(queue_message))
print(f'✅ 파이프라인 시작: job_id={job_id}')
print(f'✅ 키워드: {keyword}')
print(f'✅ RSS Collection 큐에 작업 추가 완료')
# 큐 상태 확인
queue_len = redis_client.llen('queue:rss_collection')
print(f'✅ 현재 큐 길이: {queue_len}')
redis_client.close()
if __name__ == "__main__":
if len(sys.argv) > 1:
keyword = sys.argv[1]
else:
keyword = '나스닥'
submit_test_job(keyword)

View File

@ -0,0 +1,19 @@
FROM python:3.11-slim
WORKDIR /app
# Install dependencies
COPY ./translator/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy shared modules
COPY ./shared /app/shared
# Copy config directory
COPY ./config /app/config
# Copy application code
COPY ./translator /app
# Use multi_translator.py as the main service
CMD ["python", "multi_translator.py"]

View File

@ -0,0 +1,329 @@
"""
Language Sync Service
기존 기사를 새로운 언어로 번역하는 백그라운드 서비스
"""
import asyncio
import logging
import os
import sys
import json
from typing import List, Dict, Any
import httpx
from motor.motor_asyncio import AsyncIOMotorClient
from datetime import datetime
# Add parent directory to path for shared module
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Import from shared module
from shared.models import FinalArticle, Subtopic, Entities, NewsReference
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class LanguageSyncService:
def __init__(self):
self.deepl_api_key = os.getenv("DEEPL_API_KEY", "3abbc796-2515-44a8-972d-22dcf27ab54a")
self.deepl_api_url = "https://api.deepl.com/v2/translate"
self.mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
self.db_name = os.getenv("DB_NAME", "ai_writer_db")
self.db = None
self.languages_config = None
self.config_path = "/app/config/languages.json"
self.sync_batch_size = 10
self.sync_delay = 2.0 # 언어 간 지연
async def load_config(self):
"""언어 설정 파일 로드"""
try:
if os.path.exists(self.config_path):
with open(self.config_path, 'r', encoding='utf-8') as f:
self.languages_config = json.load(f)
logger.info(f"Loaded language config")
else:
raise FileNotFoundError(f"Config file not found: {self.config_path}")
except Exception as e:
logger.error(f"Error loading config: {e}")
raise
async def start(self):
"""백그라운드 싱크 서비스 시작"""
logger.info("Starting Language Sync Service")
# 설정 로드
await self.load_config()
# MongoDB 연결
client = AsyncIOMotorClient(self.mongodb_url)
self.db = client[self.db_name]
# 주기적으로 싱크 체크 (10분마다)
while True:
try:
await self.sync_missing_translations()
await asyncio.sleep(600) # 10분 대기
except Exception as e:
logger.error(f"Error in sync loop: {e}")
await asyncio.sleep(60) # 에러 시 1분 후 재시도
async def sync_missing_translations(self):
"""누락된 번역 싱크"""
try:
# 활성화된 언어 목록
enabled_languages = [
lang for lang in self.languages_config["enabled_languages"]
if lang["enabled"]
]
if not enabled_languages:
logger.info("No enabled languages for sync")
return
# 원본 언어 컬렉션
source_collection = self.languages_config["source_language"]["collection"]
for lang_config in enabled_languages:
await self.sync_language(source_collection, lang_config)
await asyncio.sleep(self.sync_delay)
except Exception as e:
logger.error(f"Error in sync_missing_translations: {e}")
async def sync_language(self, source_collection: str, lang_config: Dict):
"""특정 언어로 누락된 기사 번역"""
try:
target_collection = lang_config["collection"]
# 번역되지 않은 기사 찾기
# 원본에는 있지만 대상 컬렉션에는 없는 기사
source_articles = await self.db[source_collection].find(
{},
{"news_id": 1}
).to_list(None)
source_ids = {article["news_id"] for article in source_articles}
translated_articles = await self.db[target_collection].find(
{},
{"news_id": 1}
).to_list(None)
translated_ids = {article["news_id"] for article in translated_articles}
# 누락된 news_id
missing_ids = source_ids - translated_ids
if not missing_ids:
logger.info(f"No missing translations for {lang_config['name']}")
return
logger.info(f"Found {len(missing_ids)} missing translations for {lang_config['name']}")
# 배치로 처리
missing_list = list(missing_ids)
for i in range(0, len(missing_list), self.sync_batch_size):
batch = missing_list[i:i+self.sync_batch_size]
for news_id in batch:
try:
# 원본 기사 조회
korean_article = await self.db[source_collection].find_one(
{"news_id": news_id}
)
if not korean_article:
continue
# 번역 수행
await self.translate_and_save(
korean_article,
lang_config
)
logger.info(f"Synced article {news_id} to {lang_config['code']}")
# API 속도 제한
await asyncio.sleep(0.5)
except Exception as e:
logger.error(f"Error translating {news_id} to {lang_config['code']}: {e}")
continue
# 배치 간 지연
if i + self.sync_batch_size < len(missing_list):
await asyncio.sleep(self.sync_delay)
except Exception as e:
logger.error(f"Error syncing language {lang_config['code']}: {e}")
async def translate_and_save(self, korean_article: Dict, lang_config: Dict):
"""기사 번역 및 저장"""
try:
# 제목 번역
translated_title = await self._translate_text(
korean_article.get('title', ''),
target_lang=lang_config["deepl_code"]
)
# 요약 번역
translated_summary = await self._translate_text(
korean_article.get('summary', ''),
target_lang=lang_config["deepl_code"]
)
# Subtopics 번역
translated_subtopics = []
for subtopic in korean_article.get('subtopics', []):
translated_subtopic_title = await self._translate_text(
subtopic.get('title', ''),
target_lang=lang_config["deepl_code"]
)
translated_content_list = []
for content_para in subtopic.get('content', []):
translated_para = await self._translate_text(
content_para,
target_lang=lang_config["deepl_code"]
)
translated_content_list.append(translated_para)
translated_subtopics.append(Subtopic(
title=translated_subtopic_title,
content=translated_content_list
))
# 카테고리 번역
translated_categories = []
for category in korean_article.get('categories', []):
translated_cat = await self._translate_text(
category,
target_lang=lang_config["deepl_code"]
)
translated_categories.append(translated_cat)
# Entities와 References는 원본 유지
entities_data = korean_article.get('entities', {})
translated_entities = Entities(**entities_data) if entities_data else Entities()
references = []
for ref_data in korean_article.get('references', []):
references.append(NewsReference(**ref_data))
# 번역된 기사 생성
translated_article = FinalArticle(
news_id=korean_article.get('news_id'),
title=translated_title,
summary=translated_summary,
subtopics=translated_subtopics,
categories=translated_categories,
entities=translated_entities,
source_keyword=korean_article.get('source_keyword'),
source_count=korean_article.get('source_count', 1),
references=references,
job_id=korean_article.get('job_id'),
keyword_id=korean_article.get('keyword_id'),
pipeline_stages=korean_article.get('pipeline_stages', []) + ['sync_translation'],
processing_time=korean_article.get('processing_time', 0),
language=lang_config["code"],
ref_news_id=None,
rss_guid=korean_article.get('rss_guid'), # RSS GUID 유지
image_prompt=korean_article.get('image_prompt'), # 이미지 프롬프트 유지
images=korean_article.get('images', []), # 이미지 URL 리스트 유지
translated_languages=korean_article.get('translated_languages', []) # 번역 언어 목록 유지
)
# MongoDB에 저장
collection_name = lang_config["collection"]
result = await self.db[collection_name].insert_one(translated_article.model_dump())
# 원본 기사에 번역 완료 표시
await self.db[self.languages_config["source_language"]["collection"]].update_one(
{"news_id": korean_article.get('news_id')},
{
"$addToSet": {
"translated_languages": lang_config["code"]
}
}
)
logger.info(f"Synced article to {collection_name}: {result.inserted_id}")
except Exception as e:
logger.error(f"Error in translate_and_save: {e}")
raise
async def _translate_text(self, text: str, target_lang: str = 'EN') -> str:
"""DeepL API를 사용한 텍스트 번역"""
try:
if not text:
return ""
async with httpx.AsyncClient() as client:
response = await client.post(
self.deepl_api_url,
data={
'auth_key': self.deepl_api_key,
'text': text,
'target_lang': target_lang,
'source_lang': 'KO'
},
timeout=30
)
if response.status_code == 200:
result = response.json()
return result['translations'][0]['text']
else:
logger.error(f"DeepL API error: {response.status_code}")
return text
except Exception as e:
logger.error(f"Error translating text: {e}")
return text
async def manual_sync(self, language_code: str = None):
"""수동 싱크 실행"""
logger.info(f"Manual sync requested for language: {language_code or 'all'}")
await self.load_config()
client = AsyncIOMotorClient(self.mongodb_url)
self.db = client[self.db_name]
if language_code:
# 특정 언어만 싱크
lang_config = next(
(lang for lang in self.languages_config["enabled_languages"]
if lang["code"] == language_code and lang["enabled"]),
None
)
if lang_config:
source_collection = self.languages_config["source_language"]["collection"]
await self.sync_language(source_collection, lang_config)
else:
logger.error(f"Language {language_code} not found or not enabled")
else:
# 모든 활성 언어 싱크
await self.sync_missing_translations()
async def main():
"""메인 함수"""
service = LanguageSyncService()
# 명령줄 인수 확인
if len(sys.argv) > 1:
if sys.argv[1] == "sync":
# 수동 싱크 모드
language = sys.argv[2] if len(sys.argv) > 2 else None
await service.manual_sync(language)
else:
logger.error(f"Unknown command: {sys.argv[1]}")
else:
# 백그라운드 서비스 모드
try:
await service.start()
except KeyboardInterrupt:
logger.info("Received interrupt signal")
if __name__ == "__main__":
asyncio.run(main())

View File

@ -0,0 +1,320 @@
"""
Multi-Language Translation Service
다국어 번역 서비스 - 설정 기반 다중 언어 지원
"""
import asyncio
import logging
import os
import sys
import json
from typing import List, Dict, Any
import httpx
import redis.asyncio as redis
from motor.motor_asyncio import AsyncIOMotorClient
from datetime import datetime
# Import from shared module
from shared.models import PipelineJob, FinalArticle
from shared.queue_manager import QueueManager
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class MultiLanguageTranslator:
def __init__(self):
self.queue_manager = QueueManager(
redis_url=os.getenv("REDIS_URL", "redis://redis:6379")
)
self.deepl_api_key = os.getenv("DEEPL_API_KEY", "3abbc796-2515-44a8-972d-22dcf27ab54a")
self.deepl_api_url = "https://api.deepl.com/v2/translate"
self.mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
self.db_name = os.getenv("DB_NAME", "ai_writer_db")
self.db = None
self.languages_config = None
self.config_path = "/app/config/languages.json"
async def load_config(self):
"""언어 설정 파일 로드"""
try:
if os.path.exists(self.config_path):
with open(self.config_path, 'r', encoding='utf-8') as f:
self.languages_config = json.load(f)
else:
# 기본 설정 (영어만)
self.languages_config = {
"enabled_languages": [
{
"code": "en",
"name": "English",
"deepl_code": "EN",
"collection": "articles_en",
"enabled": True
}
],
"source_language": {
"code": "ko",
"name": "Korean",
"collection": "articles_ko"
},
"translation_settings": {
"batch_size": 5,
"delay_between_languages": 2.0,
"delay_between_articles": 0.5,
"max_retries": 3
}
}
logger.info(f"Loaded language config: {len(self.get_enabled_languages())} languages enabled")
except Exception as e:
logger.error(f"Error loading config: {e}")
raise
def get_enabled_languages(self) -> List[Dict]:
"""활성화된 언어 목록 반환"""
return [lang for lang in self.languages_config["enabled_languages"] if lang["enabled"]]
async def start(self):
"""워커 시작"""
logger.info("Starting Multi-Language Translator Worker")
# 설정 로드
await self.load_config()
# Redis 연결
await self.queue_manager.connect()
# MongoDB 연결
client = AsyncIOMotorClient(self.mongodb_url)
self.db = client[self.db_name]
# DeepL API 키 확인
if not self.deepl_api_key:
logger.error("DeepL API key not configured")
return
# 메인 처리 루프
while True:
try:
# 큐에서 작업 가져오기
job = await self.queue_manager.dequeue('translation', timeout=5)
if job:
await self.process_job(job)
except Exception as e:
logger.error(f"Error in worker loop: {e}")
await asyncio.sleep(1)
async def process_job(self, job: PipelineJob):
"""모든 활성 언어로 번역"""
try:
logger.info(f"Processing job {job.job_id} for multi-language translation")
# MongoDB에서 한국어 기사 가져오기
news_id = job.data.get('news_id')
if not news_id:
logger.error(f"No news_id in job {job.job_id}")
await self.queue_manager.mark_failed('translation', job, "No news_id")
return
# 원본 컬렉션에서 기사 조회
source_collection = self.languages_config["source_language"]["collection"]
korean_article = await self.db[source_collection].find_one({"news_id": news_id})
if not korean_article:
logger.error(f"Article {news_id} not found in {source_collection}")
await self.queue_manager.mark_failed('translation', job, "Article not found")
return
# 활성화된 모든 언어로 번역
enabled_languages = self.get_enabled_languages()
settings = self.languages_config["translation_settings"]
for lang_config in enabled_languages:
try:
logger.info(f"Translating article {news_id} to {lang_config['name']}")
# 이미 번역되었는지 확인
existing = await self.db[lang_config["collection"]].find_one({"news_id": news_id})
if existing:
logger.info(f"Article {news_id} already translated to {lang_config['code']}")
continue
# 번역 수행
await self.translate_article(
korean_article,
lang_config,
job
)
# 언어 간 지연
if settings.get("delay_between_languages"):
await asyncio.sleep(settings["delay_between_languages"])
except Exception as e:
logger.error(f"Error translating to {lang_config['code']}: {e}")
continue
# 파이프라인 완료 로그
logger.info(f"Translation pipeline completed for news_id: {news_id}")
# 완료 표시
job.stages_completed.append('translation')
await self.queue_manager.mark_completed('translation', job.job_id)
logger.info(f"Multi-language translation completed for job {job.job_id}")
except Exception as e:
logger.error(f"Error processing job {job.job_id}: {e}")
await self.queue_manager.mark_failed('translation', job, str(e))
async def translate_article(self, korean_article: Dict, lang_config: Dict, job: PipelineJob):
"""특정 언어로 기사 번역"""
try:
# 제목 번역
translated_title = await self._translate_text(
korean_article.get('title', ''),
target_lang=lang_config["deepl_code"]
)
# 요약 번역
translated_summary = await self._translate_text(
korean_article.get('summary', ''),
target_lang=lang_config["deepl_code"]
)
# Subtopics 번역
from shared.models import Subtopic
translated_subtopics = []
for subtopic in korean_article.get('subtopics', []):
translated_subtopic_title = await self._translate_text(
subtopic.get('title', ''),
target_lang=lang_config["deepl_code"]
)
translated_content_list = []
for content_para in subtopic.get('content', []):
translated_para = await self._translate_text(
content_para,
target_lang=lang_config["deepl_code"]
)
translated_content_list.append(translated_para)
# API 속도 제한
settings = self.languages_config["translation_settings"]
if settings.get("delay_between_articles"):
await asyncio.sleep(settings["delay_between_articles"])
translated_subtopics.append(Subtopic(
title=translated_subtopic_title,
content=translated_content_list
))
# 카테고리 번역
translated_categories = []
for category in korean_article.get('categories', []):
translated_cat = await self._translate_text(
category,
target_lang=lang_config["deepl_code"]
)
translated_categories.append(translated_cat)
# Entities와 References는 원본 유지
from shared.models import Entities, NewsReference
entities_data = korean_article.get('entities', {})
translated_entities = Entities(**entities_data) if entities_data else Entities()
references = []
for ref_data in korean_article.get('references', []):
references.append(NewsReference(**ref_data))
# 번역된 기사 생성
translated_article = FinalArticle(
news_id=korean_article.get('news_id'), # 같은 news_id 사용
title=translated_title,
summary=translated_summary,
subtopics=translated_subtopics,
categories=translated_categories,
entities=translated_entities,
source_keyword=job.keyword if hasattr(job, 'keyword') else korean_article.get('source_keyword'),
source_count=korean_article.get('source_count', 1),
references=references,
job_id=job.job_id,
keyword_id=job.keyword_id if hasattr(job, 'keyword_id') else None,
pipeline_stages=korean_article.get('pipeline_stages', []) + ['translation'],
processing_time=korean_article.get('processing_time', 0),
language=lang_config["code"],
ref_news_id=None, # 같은 news_id 사용하므로 불필요
rss_guid=korean_article.get('rss_guid'), # RSS GUID 유지
image_prompt=korean_article.get('image_prompt'), # 이미지 프롬프트 유지
images=korean_article.get('images', []), # 이미지 URL 리스트 유지
translated_languages=korean_article.get('translated_languages', []) # 번역 언어 목록 유지
)
# MongoDB에 저장
collection_name = lang_config["collection"]
result = await self.db[collection_name].insert_one(translated_article.model_dump())
logger.info(f"Article saved to {collection_name} with _id: {result.inserted_id}, language: {lang_config['code']}")
# 원본 기사에 번역 완료 표시
await self.db[self.languages_config["source_language"]["collection"]].update_one(
{"news_id": korean_article.get('news_id')},
{
"$addToSet": {
"translated_languages": lang_config["code"]
}
}
)
except Exception as e:
logger.error(f"Error translating article to {lang_config['code']}: {e}")
raise
async def _translate_text(self, text: str, target_lang: str = 'EN') -> str:
"""DeepL API를 사용한 텍스트 번역"""
try:
if not text:
return ""
async with httpx.AsyncClient() as client:
response = await client.post(
self.deepl_api_url,
data={
'auth_key': self.deepl_api_key,
'text': text,
'target_lang': target_lang,
'source_lang': 'KO'
},
timeout=30
)
if response.status_code == 200:
result = response.json()
return result['translations'][0]['text']
else:
logger.error(f"DeepL API error: {response.status_code}")
return text # 번역 실패시 원본 반환
except Exception as e:
logger.error(f"Error translating text: {e}")
return text # 번역 실패시 원본 반환
async def stop(self):
"""워커 중지"""
await self.queue_manager.disconnect()
logger.info("Multi-Language Translator Worker stopped")
async def main():
"""메인 함수"""
worker = MultiLanguageTranslator()
try:
await worker.start()
except KeyboardInterrupt:
logger.info("Received interrupt signal")
finally:
await worker.stop()
if __name__ == "__main__":
asyncio.run(main())

View File

@ -0,0 +1,5 @@
httpx==0.25.0
redis[hiredis]==5.0.1
pydantic==2.5.0
motor==3.1.1
pymongo==4.3.3

View File

@ -0,0 +1,230 @@
"""
Translation Service
DeepL API를 사용한 번역 서비스
"""
import asyncio
import logging
import os
import sys
from typing import List, Dict, Any
import httpx
from motor.motor_asyncio import AsyncIOMotorClient
from datetime import datetime
# Import from shared module
from shared.models import PipelineJob, FinalArticle
from shared.queue_manager import QueueManager
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TranslatorWorker:
def __init__(self):
self.queue_manager = QueueManager(
redis_url=os.getenv("REDIS_URL", "redis://redis:6379")
)
self.deepl_api_key = os.getenv("DEEPL_API_KEY", "3abbc796-2515-44a8-972d-22dcf27ab54a")
# DeepL Pro API 엔드포인트 사용
self.deepl_api_url = "https://api.deepl.com/v2/translate"
self.mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
self.db_name = os.getenv("DB_NAME", "ai_writer_db")
self.db = None
async def start(self):
"""워커 시작"""
logger.info("Starting Translator Worker")
# Redis 연결
await self.queue_manager.connect()
# MongoDB 연결
client = AsyncIOMotorClient(self.mongodb_url)
self.db = client[self.db_name]
# DeepL API 키 확인
if not self.deepl_api_key:
logger.error("DeepL API key not configured")
return
# 메인 처리 루프
while True:
try:
# 큐에서 작업 가져오기
job = await self.queue_manager.dequeue('translation', timeout=5)
if job:
await self.process_job(job)
except Exception as e:
logger.error(f"Error in worker loop: {e}")
await asyncio.sleep(1)
async def process_job(self, job: PipelineJob):
"""영어 버전 기사 생성 및 저장"""
try:
logger.info(f"Processing job {job.job_id} for translation")
# MongoDB에서 한국어 기사 가져오기
news_id = job.data.get('news_id')
if not news_id:
logger.error(f"No news_id in job {job.job_id}")
await self.queue_manager.mark_failed('translation', job, "No news_id")
return
# MongoDB에서 한국어 기사 조회 (articles_ko)
korean_article = await self.db.articles_ko.find_one({"news_id": news_id})
if not korean_article:
logger.error(f"Article {news_id} not found in MongoDB")
await self.queue_manager.mark_failed('translation', job, "Article not found")
return
# 영어로 번역
translated_title = await self._translate_text(
korean_article.get('title', ''),
target_lang='EN'
)
translated_summary = await self._translate_text(
korean_article.get('summary', ''),
target_lang='EN'
)
# Subtopics 번역
from shared.models import Subtopic
translated_subtopics = []
for subtopic in korean_article.get('subtopics', []):
translated_subtopic_title = await self._translate_text(
subtopic.get('title', ''),
target_lang='EN'
)
translated_content_list = []
for content_para in subtopic.get('content', []):
translated_para = await self._translate_text(
content_para,
target_lang='EN'
)
translated_content_list.append(translated_para)
await asyncio.sleep(0.2) # API 속도 제한
translated_subtopics.append(Subtopic(
title=translated_subtopic_title,
content=translated_content_list
))
# 카테고리 번역
translated_categories = []
for category in korean_article.get('categories', []):
translated_cat = await self._translate_text(category, target_lang='EN')
translated_categories.append(translated_cat)
await asyncio.sleep(0.2) # API 속도 제한
# Entities 번역 (선택적)
from shared.models import Entities
entities_data = korean_article.get('entities', {})
translated_entities = Entities(
people=entities_data.get('people', []), # 인명은 번역하지 않음
organizations=entities_data.get('organizations', []), # 조직명은 번역하지 않음
groups=entities_data.get('groups', []),
countries=entities_data.get('countries', []),
events=entities_data.get('events', [])
)
# 레퍼런스 가져오기 (번역하지 않음)
from shared.models import NewsReference
references = []
for ref_data in korean_article.get('references', []):
references.append(NewsReference(**ref_data))
# 영어 버전 기사 생성 - 같은 news_id 사용
english_article = FinalArticle(
news_id=news_id, # 원본과 같은 news_id 사용
title=translated_title,
summary=translated_summary,
subtopics=translated_subtopics,
categories=translated_categories,
entities=translated_entities,
source_keyword=job.keyword,
source_count=korean_article.get('source_count', 1),
references=references, # 원본 레퍼런스 그대로 사용
job_id=job.job_id,
keyword_id=job.keyword_id,
pipeline_stages=job.stages_completed.copy() + ['translation'],
processing_time=korean_article.get('processing_time', 0),
language='en', # 영어
ref_news_id=None # 같은 news_id를 사용하므로 ref 불필요
)
# MongoDB에 영어 버전 저장 (articles_en)
result = await self.db.articles_en.insert_one(english_article.model_dump())
english_article_id = str(result.inserted_id)
logger.info(f"English article saved with _id: {english_article_id}, news_id: {news_id}, language: en")
# 원본 한국어 기사 업데이트 - 번역 완료 표시
await self.db.articles_ko.update_one(
{"news_id": news_id},
{
"$addToSet": {
"pipeline_stages": "translation"
}
}
)
# 완료 표시
job.stages_completed.append('translation')
await self.queue_manager.mark_completed('translation', job.job_id)
logger.info(f"Translation completed for job {job.job_id}")
except Exception as e:
logger.error(f"Error processing job {job.job_id}: {e}")
await self.queue_manager.mark_failed('translation', job, str(e))
async def _translate_text(self, text: str, target_lang: str = 'EN') -> str:
"""DeepL API를 사용한 텍스트 번역"""
try:
if not text:
return ""
async with httpx.AsyncClient() as client:
response = await client.post(
self.deepl_api_url,
data={
'auth_key': self.deepl_api_key,
'text': text,
'target_lang': target_lang,
'source_lang': 'KO'
},
timeout=30
)
if response.status_code == 200:
result = response.json()
return result['translations'][0]['text']
else:
logger.error(f"DeepL API error: {response.status_code}")
return text # 번역 실패시 원본 반환
except Exception as e:
logger.error(f"Error translating text: {e}")
return text # 번역 실패시 원본 반환
async def stop(self):
"""워커 중지"""
await self.queue_manager.disconnect()
logger.info("Translator Worker stopped")
async def main():
"""메인 함수"""
worker = TranslatorWorker()
try:
await worker.start()
except KeyboardInterrupt:
logger.info("Received interrupt signal")
finally:
await worker.stop()
if __name__ == "__main__":
asyncio.run(main())

View File

@ -0,0 +1,21 @@
FROM python:3.11-slim
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y \
gcc \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements first for better caching
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY . .
# Create necessary directories
RUN mkdir -p /app/logs
# Run the application
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]

View File

@ -0,0 +1,286 @@
"""
Data indexer for synchronizing data from other services to Solr
"""
import asyncio
import logging
from typing import Dict, Any, List
from motor.motor_asyncio import AsyncIOMotorClient
from aiokafka import AIOKafkaConsumer
import json
from solr_client import SolrClient
from datetime import datetime
logger = logging.getLogger(__name__)
class DataIndexer:
def __init__(self, solr_client: SolrClient, mongodb_url: str, kafka_servers: str):
self.solr = solr_client
self.mongodb_url = mongodb_url
self.kafka_servers = kafka_servers
self.mongo_client = None
self.kafka_consumer = None
self.running = False
async def start(self):
"""Start the indexer"""
try:
# Connect to MongoDB
self.mongo_client = AsyncIOMotorClient(self.mongodb_url)
# Initialize Kafka consumer
await self._init_kafka_consumer()
# Start background tasks
self.running = True
asyncio.create_task(self._consume_kafka_events())
asyncio.create_task(self._periodic_sync())
logger.info("Data indexer started")
except Exception as e:
logger.error(f"Failed to start indexer: {e}")
async def stop(self):
"""Stop the indexer"""
self.running = False
if self.kafka_consumer:
await self.kafka_consumer.stop()
if self.mongo_client:
self.mongo_client.close()
logger.info("Data indexer stopped")
async def _init_kafka_consumer(self):
"""Initialize Kafka consumer"""
try:
self.kafka_consumer = AIOKafkaConsumer(
'user_events',
'file_events',
'content_events',
bootstrap_servers=self.kafka_servers,
value_deserializer=lambda m: json.loads(m.decode('utf-8')),
group_id='search_indexer',
auto_offset_reset='latest'
)
await self.kafka_consumer.start()
logger.info("Kafka consumer initialized")
except Exception as e:
logger.warning(f"Kafka consumer initialization failed: {e}")
self.kafka_consumer = None
async def _consume_kafka_events(self):
"""Consume events from Kafka and index them"""
if not self.kafka_consumer:
return
while self.running:
try:
async for msg in self.kafka_consumer:
await self._handle_kafka_event(msg.topic, msg.value)
except Exception as e:
logger.error(f"Kafka consumption error: {e}")
await asyncio.sleep(5)
async def _handle_kafka_event(self, topic: str, event: Dict[str, Any]):
"""Handle a Kafka event"""
try:
event_type = event.get('type')
data = event.get('data', {})
if topic == 'user_events':
await self._index_user_event(event_type, data)
elif topic == 'file_events':
await self._index_file_event(event_type, data)
elif topic == 'content_events':
await self._index_content_event(event_type, data)
except Exception as e:
logger.error(f"Failed to handle event: {e}")
async def _index_user_event(self, event_type: str, data: Dict):
"""Index user-related events"""
if event_type == 'user_created' or event_type == 'user_updated':
user_doc = {
'id': f"user_{data.get('user_id')}",
'doc_type': 'user',
'user_id': data.get('user_id'),
'username': data.get('username'),
'email': data.get('email'),
'name': data.get('name', ''),
'bio': data.get('bio', ''),
'tags': data.get('tags', []),
'created_at': data.get('created_at'),
'updated_at': datetime.utcnow().isoformat()
}
self.solr.index_document(user_doc)
elif event_type == 'user_deleted':
self.solr.delete_document(f"user_{data.get('user_id')}")
async def _index_file_event(self, event_type: str, data: Dict):
"""Index file-related events"""
if event_type == 'file_uploaded':
file_doc = {
'id': f"file_{data.get('file_id')}",
'doc_type': 'file',
'file_id': data.get('file_id'),
'filename': data.get('filename'),
'content_type': data.get('content_type'),
'size': data.get('size'),
'user_id': data.get('user_id'),
'tags': data.get('tags', []),
'description': data.get('description', ''),
'created_at': data.get('created_at'),
'updated_at': datetime.utcnow().isoformat()
}
self.solr.index_document(file_doc)
elif event_type == 'file_deleted':
self.solr.delete_document(f"file_{data.get('file_id')}")
async def _index_content_event(self, event_type: str, data: Dict):
"""Index content-related events"""
if event_type in ['content_created', 'content_updated']:
content_doc = {
'id': f"content_{data.get('content_id')}",
'doc_type': 'content',
'content_id': data.get('content_id'),
'title': data.get('title'),
'content': data.get('content', ''),
'summary': data.get('summary', ''),
'author_id': data.get('author_id'),
'tags': data.get('tags', []),
'category': data.get('category'),
'status': data.get('status', 'draft'),
'created_at': data.get('created_at'),
'updated_at': datetime.utcnow().isoformat()
}
self.solr.index_document(content_doc)
elif event_type == 'content_deleted':
self.solr.delete_document(f"content_{data.get('content_id')}")
async def _periodic_sync(self):
"""Periodically sync data from MongoDB"""
while self.running:
try:
# Sync every 5 minutes
await asyncio.sleep(300)
await self.sync_all_data()
except Exception as e:
logger.error(f"Periodic sync error: {e}")
async def sync_all_data(self):
"""Sync all data from MongoDB to Solr"""
try:
logger.info("Starting full data sync")
# Sync users
await self._sync_users()
# Sync files
await self._sync_files()
# Optimize index after bulk sync
self.solr.optimize_index()
logger.info("Full data sync completed")
except Exception as e:
logger.error(f"Full sync failed: {e}")
async def _sync_users(self):
"""Sync users from MongoDB"""
try:
db = self.mongo_client['users_db']
collection = db['users']
users = []
async for user in collection.find({'deleted_at': None}):
user_doc = {
'id': f"user_{str(user['_id'])}",
'doc_type': 'user',
'user_id': str(user['_id']),
'username': user.get('username'),
'email': user.get('email'),
'name': user.get('name', ''),
'bio': user.get('bio', ''),
'tags': user.get('tags', []),
'created_at': user.get('created_at').isoformat() if user.get('created_at') else None,
'updated_at': datetime.utcnow().isoformat()
}
users.append(user_doc)
# Bulk index every 100 documents
if len(users) >= 100:
self.solr.bulk_index(users, 'user')
users = []
# Index remaining users
if users:
self.solr.bulk_index(users, 'user')
logger.info(f"Synced users to Solr")
except Exception as e:
logger.error(f"Failed to sync users: {e}")
async def _sync_files(self):
"""Sync files from MongoDB"""
try:
db = self.mongo_client['files_db']
collection = db['file_metadata']
files = []
async for file in collection.find({'deleted_at': None}):
file_doc = {
'id': f"file_{str(file['_id'])}",
'doc_type': 'file',
'file_id': str(file['_id']),
'filename': file.get('filename'),
'original_name': file.get('original_name'),
'content_type': file.get('content_type'),
'size': file.get('size'),
'user_id': file.get('user_id'),
'tags': list(file.get('tags', {}).keys()),
'description': file.get('metadata', {}).get('description', ''),
'created_at': file.get('created_at').isoformat() if file.get('created_at') else None,
'updated_at': datetime.utcnow().isoformat()
}
files.append(file_doc)
# Bulk index every 100 documents
if len(files) >= 100:
self.solr.bulk_index(files, 'file')
files = []
# Index remaining files
if files:
self.solr.bulk_index(files, 'file')
logger.info(f"Synced files to Solr")
except Exception as e:
logger.error(f"Failed to sync files: {e}")
async def reindex_collection(self, collection_name: str, doc_type: str):
"""Reindex a specific collection"""
try:
# Delete existing documents of this type
self.solr.delete_by_query(f'doc_type:{doc_type}')
# Sync the collection
if collection_name == 'users':
await self._sync_users()
elif collection_name == 'files':
await self._sync_files()
logger.info(f"Reindexed {collection_name}")
except Exception as e:
logger.error(f"Failed to reindex {collection_name}: {e}")

View File

@ -0,0 +1,362 @@
"""
Search Service with Apache Solr
"""
from fastapi import FastAPI, Query, HTTPException
from fastapi.responses import JSONResponse
from contextlib import asynccontextmanager
import logging
import os
from typing import Optional, List, Dict, Any
from datetime import datetime
from solr_client import SolrClient
from indexer import DataIndexer
import asyncio
import time
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Global instances
solr_client = None
data_indexer = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Manage application lifecycle"""
global solr_client, data_indexer
# Startup
logger.info("Starting Search Service...")
# Wait for Solr to be ready
solr_url = os.getenv("SOLR_URL", "http://solr:8983/solr")
max_retries = 30
for i in range(max_retries):
try:
solr_client = SolrClient(solr_url=solr_url, core_name="site11")
logger.info("Connected to Solr")
break
except Exception as e:
logger.warning(f"Waiting for Solr... ({i+1}/{max_retries})")
await asyncio.sleep(2)
if solr_client:
# Initialize data indexer
mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
kafka_servers = os.getenv("KAFKA_BOOTSTRAP_SERVERS", "kafka:9092")
data_indexer = DataIndexer(solr_client, mongodb_url, kafka_servers)
await data_indexer.start()
# Initial data sync
asyncio.create_task(data_indexer.sync_all_data())
yield
# Shutdown
if data_indexer:
await data_indexer.stop()
logger.info("Search Service stopped")
app = FastAPI(
title="Search Service",
description="Full-text search with Apache Solr",
version="1.0.0",
lifespan=lifespan
)
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy",
"service": "search",
"timestamp": datetime.utcnow().isoformat(),
"solr_connected": solr_client is not None
}
@app.get("/api/search")
async def search(
q: str = Query(..., description="Search query"),
doc_type: Optional[str] = Query(None, description="Filter by document type"),
start: int = Query(0, ge=0, description="Starting offset"),
rows: int = Query(10, ge=1, le=100, description="Number of results"),
sort: Optional[str] = Query(None, description="Sort order (e.g., 'created_at desc')"),
facet: bool = Query(False, description="Enable faceting"),
facet_field: Optional[List[str]] = Query(None, description="Fields to facet on")
):
"""
Search documents across all indexed content
"""
if not solr_client:
raise HTTPException(status_code=503, detail="Search service unavailable")
try:
# Build filter query
fq = []
if doc_type:
fq.append(f"doc_type:{doc_type}")
# Prepare search parameters
search_params = {
'start': start,
'rows': rows,
'facet': facet
}
if fq:
search_params['fq'] = fq
if sort:
search_params['sort'] = sort
if facet_field:
search_params['facet_field'] = facet_field
# Execute search
results = solr_client.search(q, **search_params)
return {
"query": q,
"total": results['total'],
"start": start,
"rows": rows,
"documents": results['documents'],
"facets": results.get('facets', {}),
"highlighting": results.get('highlighting', {})
}
except Exception as e:
logger.error(f"Search failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/search/suggest")
async def suggest(
q: str = Query(..., min_length=1, description="Query prefix"),
field: str = Query("title", description="Field to search in"),
limit: int = Query(10, ge=1, le=50, description="Maximum suggestions")
):
"""
Get autocomplete suggestions
"""
if not solr_client:
raise HTTPException(status_code=503, detail="Search service unavailable")
try:
suggestions = solr_client.suggest(q, field, limit)
return {
"query": q,
"suggestions": suggestions,
"count": len(suggestions)
}
except Exception as e:
logger.error(f"Suggest failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/search/similar/{doc_id}")
async def find_similar(
doc_id: str,
rows: int = Query(5, ge=1, le=20, description="Number of similar documents")
):
"""
Find documents similar to the given document
"""
if not solr_client:
raise HTTPException(status_code=503, detail="Search service unavailable")
try:
similar_docs = solr_client.more_like_this(doc_id, rows=rows)
return {
"source_document": doc_id,
"similar_documents": similar_docs,
"count": len(similar_docs)
}
except Exception as e:
logger.error(f"Similar search failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/search/index")
async def index_document(document: Dict[str, Any]):
"""
Index a single document
"""
if not solr_client:
raise HTTPException(status_code=503, detail="Search service unavailable")
try:
doc_type = document.get('doc_type', 'general')
success = solr_client.index_document(document, doc_type)
if success:
return {
"status": "success",
"message": "Document indexed",
"document_id": document.get('id')
}
else:
raise HTTPException(status_code=500, detail="Failed to index document")
except Exception as e:
logger.error(f"Indexing failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/search/bulk-index")
async def bulk_index(documents: List[Dict[str, Any]]):
"""
Bulk index multiple documents
"""
if not solr_client:
raise HTTPException(status_code=503, detail="Search service unavailable")
try:
indexed = solr_client.bulk_index(documents)
return {
"status": "success",
"message": f"Indexed {indexed} documents",
"count": indexed
}
except Exception as e:
logger.error(f"Bulk indexing failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.delete("/api/search/document/{doc_id}")
async def delete_document(doc_id: str):
"""
Delete a document from the index
"""
if not solr_client:
raise HTTPException(status_code=503, detail="Search service unavailable")
try:
success = solr_client.delete_document(doc_id)
if success:
return {
"status": "success",
"message": "Document deleted",
"document_id": doc_id
}
else:
raise HTTPException(status_code=500, detail="Failed to delete document")
except Exception as e:
logger.error(f"Deletion failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/search/stats")
async def get_stats():
"""
Get search index statistics
"""
if not solr_client:
raise HTTPException(status_code=503, detail="Search service unavailable")
try:
stats = solr_client.get_stats()
return {
"status": "success",
"statistics": stats,
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Failed to get stats: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/search/reindex/{collection}")
async def reindex_collection(
collection: str,
doc_type: Optional[str] = Query(None, description="Document type for the collection")
):
"""
Reindex a specific collection
"""
if not data_indexer:
raise HTTPException(status_code=503, detail="Indexer service unavailable")
try:
if not doc_type:
# Map collection to doc_type
doc_type_map = {
'users': 'user',
'files': 'file',
'content': 'content'
}
doc_type = doc_type_map.get(collection, collection)
asyncio.create_task(data_indexer.reindex_collection(collection, doc_type))
return {
"status": "success",
"message": f"Reindexing {collection} started",
"collection": collection,
"doc_type": doc_type
}
except Exception as e:
logger.error(f"Reindex failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/search/optimize")
async def optimize_index():
"""
Optimize the search index
"""
if not solr_client:
raise HTTPException(status_code=503, detail="Search service unavailable")
try:
success = solr_client.optimize_index()
if success:
return {
"status": "success",
"message": "Index optimization started"
}
else:
raise HTTPException(status_code=500, detail="Failed to optimize index")
except Exception as e:
logger.error(f"Optimization failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/search/clear")
async def clear_index():
"""
Clear all documents from the index (DANGER!)
"""
if not solr_client:
raise HTTPException(status_code=503, detail="Search service unavailable")
try:
success = solr_client.clear_index()
if success:
return {
"status": "success",
"message": "Index cleared",
"warning": "All documents have been deleted!"
}
else:
raise HTTPException(status_code=500, detail="Failed to clear index")
except Exception as e:
logger.error(f"Clear index failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@ -0,0 +1,10 @@
fastapi==0.109.0
uvicorn[standard]==0.27.0
pydantic==2.5.3
python-dotenv==1.0.0
pysolr==3.9.0
httpx==0.25.2
motor==3.5.1
pymongo==4.6.1
aiokafka==0.10.0
redis==5.0.1

View File

@ -0,0 +1,303 @@
"""
Apache Solr client for search operations
"""
import pysolr
import logging
from typing import Dict, List, Any, Optional
from datetime import datetime
import json
logger = logging.getLogger(__name__)
class SolrClient:
def __init__(self, solr_url: str = "http://solr:8983/solr", core_name: str = "site11"):
self.solr_url = f"{solr_url}/{core_name}"
self.core_name = core_name
self.solr = None
self.connect()
def connect(self):
"""Connect to Solr instance"""
try:
self.solr = pysolr.Solr(
self.solr_url,
always_commit=True,
timeout=10
)
# Test connection
self.solr.ping()
logger.info(f"Connected to Solr at {self.solr_url}")
except Exception as e:
logger.error(f"Failed to connect to Solr: {e}")
raise
def index_document(self, document: Dict[str, Any], doc_type: str = None) -> bool:
"""Index a single document"""
try:
# Add metadata
if doc_type:
document["doc_type"] = doc_type
if "id" not in document:
document["id"] = f"{doc_type}_{document.get('_id', '')}"
# Add indexing timestamp
document["indexed_at"] = datetime.utcnow().isoformat()
# Index the document
self.solr.add([document])
logger.info(f"Indexed document: {document.get('id')}")
return True
except Exception as e:
logger.error(f"Failed to index document: {e}")
return False
def bulk_index(self, documents: List[Dict[str, Any]], doc_type: str = None) -> int:
"""Bulk index multiple documents"""
try:
indexed = 0
for doc in documents:
if doc_type:
doc["doc_type"] = doc_type
if "id" not in doc:
doc["id"] = f"{doc_type}_{doc.get('_id', '')}"
doc["indexed_at"] = datetime.utcnow().isoformat()
self.solr.add(documents)
indexed = len(documents)
logger.info(f"Bulk indexed {indexed} documents")
return indexed
except Exception as e:
logger.error(f"Failed to bulk index: {e}")
return 0
def search(self, query: str, **kwargs) -> Dict[str, Any]:
"""
Search documents
Args:
query: Search query string
**kwargs: Additional search parameters
- fq: Filter queries
- fl: Fields to return
- start: Starting offset
- rows: Number of rows
- sort: Sort order
- facet: Enable faceting
- facet.field: Fields to facet on
"""
try:
# Default parameters
params = {
'q': query,
'start': kwargs.get('start', 0),
'rows': kwargs.get('rows', 10),
'fl': kwargs.get('fl', '*,score'),
'defType': 'edismax',
'qf': 'title^3 content^2 tags description name', # Boost fields
'mm': '2<-25%', # Minimum match
'hl': 'true', # Highlighting
'hl.fl': 'title,content,description',
'hl.simple.pre': '<mark>',
'hl.simple.post': '</mark>'
}
# Add filter queries
if 'fq' in kwargs:
params['fq'] = kwargs['fq']
# Add sorting
if 'sort' in kwargs:
params['sort'] = kwargs['sort']
# Add faceting
if kwargs.get('facet'):
params.update({
'facet': 'true',
'facet.field': kwargs.get('facet_field', ['doc_type', 'tags', 'status']),
'facet.mincount': 1
})
# Execute search
results = self.solr.search(**params)
# Format response
response = {
'total': results.hits,
'documents': [],
'facets': {},
'highlighting': {}
}
# Add documents
for doc in results.docs:
response['documents'].append(doc)
# Add facets if available
if hasattr(results, 'facets') and results.facets:
if 'facet_fields' in results.facets:
for field, values in results.facets['facet_fields'].items():
response['facets'][field] = [
{'value': values[i], 'count': values[i+1]}
for i in range(0, len(values), 2)
]
# Add highlighting if available
if hasattr(results, 'highlighting'):
response['highlighting'] = results.highlighting
return response
except Exception as e:
logger.error(f"Search failed: {e}")
return {'total': 0, 'documents': [], 'error': str(e)}
def suggest(self, prefix: str, field: str = "suggest", limit: int = 10) -> List[str]:
"""Get autocomplete suggestions"""
try:
params = {
'q': f'{field}:{prefix}*',
'fl': field,
'rows': limit,
'start': 0
}
results = self.solr.search(**params)
suggestions = []
for doc in results.docs:
if field in doc:
value = doc[field]
if isinstance(value, list):
suggestions.extend(value)
else:
suggestions.append(value)
# Remove duplicates and limit
seen = set()
unique_suggestions = []
for s in suggestions:
if s not in seen:
seen.add(s)
unique_suggestions.append(s)
if len(unique_suggestions) >= limit:
break
return unique_suggestions
except Exception as e:
logger.error(f"Suggest failed: {e}")
return []
def more_like_this(self, doc_id: str, mlt_fields: List[str] = None, rows: int = 5) -> List[Dict]:
"""Find similar documents"""
try:
if not mlt_fields:
mlt_fields = ['title', 'content', 'tags', 'description']
params = {
'q': f'id:{doc_id}',
'mlt': 'true',
'mlt.fl': ','.join(mlt_fields),
'mlt.mindf': 1,
'mlt.mintf': 1,
'mlt.count': rows,
'fl': '*,score'
}
results = self.solr.search(**params)
if results.docs:
# The MLT results are in the moreLikeThis section
if hasattr(results, 'moreLikeThis'):
mlt_results = results.moreLikeThis.get(doc_id, {})
if 'docs' in mlt_results:
return mlt_results['docs']
return []
except Exception as e:
logger.error(f"More like this failed: {e}")
return []
def delete_document(self, doc_id: str) -> bool:
"""Delete a document by ID"""
try:
self.solr.delete(id=doc_id)
logger.info(f"Deleted document: {doc_id}")
return True
except Exception as e:
logger.error(f"Failed to delete document: {e}")
return False
def delete_by_query(self, query: str) -> bool:
"""Delete documents matching a query"""
try:
self.solr.delete(q=query)
logger.info(f"Deleted documents matching: {query}")
return True
except Exception as e:
logger.error(f"Failed to delete by query: {e}")
return False
def clear_index(self) -> bool:
"""Clear all documents from index"""
try:
self.solr.delete(q='*:*')
logger.info("Cleared all documents from index")
return True
except Exception as e:
logger.error(f"Failed to clear index: {e}")
return False
def get_stats(self) -> Dict[str, Any]:
"""Get index statistics"""
try:
# Get document count
results = self.solr.search(q='*:*', rows=0)
# Get facet counts for doc_type
facet_results = self.solr.search(
q='*:*',
rows=0,
facet='true',
**{'facet.field': ['doc_type', 'status']}
)
stats = {
'total_documents': results.hits,
'doc_types': {},
'status_counts': {}
}
if hasattr(facet_results, 'facets') and facet_results.facets:
if 'facet_fields' in facet_results.facets:
# Parse doc_type facets
doc_type_facets = facet_results.facets['facet_fields'].get('doc_type', [])
for i in range(0, len(doc_type_facets), 2):
stats['doc_types'][doc_type_facets[i]] = doc_type_facets[i+1]
# Parse status facets
status_facets = facet_results.facets['facet_fields'].get('status', [])
for i in range(0, len(status_facets), 2):
stats['status_counts'][status_facets[i]] = status_facets[i+1]
return stats
except Exception as e:
logger.error(f"Failed to get stats: {e}")
return {'error': str(e)}
def optimize_index(self) -> bool:
"""Optimize the Solr index"""
try:
self.solr.optimize()
logger.info("Index optimized")
return True
except Exception as e:
logger.error(f"Failed to optimize index: {e}")
return False

View File

@ -0,0 +1,292 @@
#!/usr/bin/env python3
"""
Test script for Search Service with Apache Solr
"""
import asyncio
import httpx
import json
from datetime import datetime
BASE_URL = "http://localhost:8015"
async def test_search_api():
"""Test search API endpoints"""
async with httpx.AsyncClient() as client:
print("\n🔍 Testing Search Service API...")
# Test health check
print("\n1. Testing health check...")
response = await client.get(f"{BASE_URL}/health")
print(f"Health check: {response.json()}")
# Test index sample documents
print("\n2. Indexing sample documents...")
# Index user document
user_doc = {
"id": "user_test_001",
"doc_type": "user",
"user_id": "test_001",
"username": "john_doe",
"email": "john@example.com",
"name": "John Doe",
"bio": "Software developer passionate about Python and microservices",
"tags": ["python", "developer", "backend"],
"created_at": datetime.utcnow().isoformat()
}
response = await client.post(f"{BASE_URL}/api/search/index", json=user_doc)
print(f"Indexed user: {response.json()}")
# Index file documents
file_docs = [
{
"id": "file_test_001",
"doc_type": "file",
"file_id": "test_file_001",
"filename": "architecture_diagram.png",
"content_type": "image/png",
"size": 1024000,
"user_id": "test_001",
"tags": ["architecture", "design", "documentation"],
"description": "System architecture diagram showing microservices",
"created_at": datetime.utcnow().isoformat()
},
{
"id": "file_test_002",
"doc_type": "file",
"file_id": "test_file_002",
"filename": "user_manual.pdf",
"content_type": "application/pdf",
"size": 2048000,
"user_id": "test_001",
"tags": ["documentation", "manual", "guide"],
"description": "Complete user manual for the application",
"created_at": datetime.utcnow().isoformat()
}
]
response = await client.post(f"{BASE_URL}/api/search/bulk-index", json=file_docs)
print(f"Bulk indexed files: {response.json()}")
# Index content documents
content_docs = [
{
"id": "content_test_001",
"doc_type": "content",
"content_id": "test_content_001",
"title": "Getting Started with Microservices",
"content": "Microservices architecture is a method of developing software applications as a suite of independently deployable services.",
"summary": "Introduction to microservices architecture patterns",
"author_id": "test_001",
"tags": ["microservices", "architecture", "tutorial"],
"category": "technology",
"status": "published",
"created_at": datetime.utcnow().isoformat()
},
{
"id": "content_test_002",
"doc_type": "content",
"content_id": "test_content_002",
"title": "Python Best Practices",
"content": "Learn the best practices for writing clean, maintainable Python code including PEP 8 style guide.",
"summary": "Essential Python coding standards and practices",
"author_id": "test_001",
"tags": ["python", "programming", "best-practices"],
"category": "programming",
"status": "published",
"created_at": datetime.utcnow().isoformat()
}
]
response = await client.post(f"{BASE_URL}/api/search/bulk-index", json=content_docs)
print(f"Bulk indexed content: {response.json()}")
# Wait for indexing
await asyncio.sleep(2)
# Test basic search
print("\n3. Testing basic search...")
response = await client.get(
f"{BASE_URL}/api/search",
params={"q": "microservices"}
)
results = response.json()
print(f"Search for 'microservices': Found {results['total']} results")
if results['documents']:
print(f"First result: {results['documents'][0].get('title', results['documents'][0].get('filename', 'N/A'))}")
# Test search with filters
print("\n4. Testing filtered search...")
response = await client.get(
f"{BASE_URL}/api/search",
params={
"q": "*:*",
"doc_type": "file",
"rows": 5
}
)
results = response.json()
print(f"Files search: Found {results['total']} files")
# Test faceted search
print("\n5. Testing faceted search...")
response = await client.get(
f"{BASE_URL}/api/search",
params={
"q": "*:*",
"facet": "true",
"facet_field": ["doc_type", "tags", "category", "status"]
}
)
results = response.json()
print(f"Facets: {json.dumps(results['facets'], indent=2)}")
# Test autocomplete/suggest
print("\n6. Testing autocomplete...")
response = await client.get(
f"{BASE_URL}/api/search/suggest",
params={
"q": "micro",
"field": "title",
"limit": 5
}
)
suggestions = response.json()
print(f"Suggestions for 'micro': {suggestions['suggestions']}")
# Test similar documents
print("\n7. Testing similar documents...")
response = await client.get(f"{BASE_URL}/api/search/similar/content_test_001")
if response.status_code == 200:
similar = response.json()
print(f"Found {similar['count']} similar documents")
else:
print(f"Similar search: {response.status_code}")
# Test search with highlighting
print("\n8. Testing search with highlighting...")
response = await client.get(
f"{BASE_URL}/api/search",
params={"q": "Python"}
)
results = response.json()
if results['highlighting']:
print(f"Highlighting results: {len(results['highlighting'])} documents highlighted")
# Test search statistics
print("\n9. Testing search statistics...")
response = await client.get(f"{BASE_URL}/api/search/stats")
if response.status_code == 200:
stats = response.json()
print(f"Index stats: {stats['statistics']}")
# Test complex query
print("\n10. Testing complex query...")
response = await client.get(
f"{BASE_URL}/api/search",
params={
"q": "architecture OR python",
"doc_type": "content",
"sort": "created_at desc",
"rows": 10
}
)
results = response.json()
print(f"Complex query: Found {results['total']} results")
# Test delete document
print("\n11. Testing document deletion...")
response = await client.delete(f"{BASE_URL}/api/search/document/content_test_002")
if response.status_code == 200:
print(f"Deleted document: {response.json()}")
# Verify deletion
await asyncio.sleep(1)
response = await client.get(
f"{BASE_URL}/api/search",
params={"q": "id:content_test_002"}
)
results = response.json()
print(f"Verify deletion: Found {results['total']} results (should be 0)")
async def test_performance():
"""Test search performance"""
print("\n\n⚡ Testing Search Performance...")
async with httpx.AsyncClient(timeout=30.0) as client:
# Index many documents
print("Indexing 100 test documents...")
docs = []
for i in range(100):
docs.append({
"id": f"perf_test_{i}",
"doc_type": "content",
"title": f"Test Document {i}",
"content": f"This is test content for document {i} with various keywords like search, Solr, Python, microservices",
"tags": [f"tag{i%10}", f"category{i%5}"],
"created_at": datetime.utcnow().isoformat()
})
response = await client.post(f"{BASE_URL}/api/search/bulk-index", json=docs)
print(f"Indexed {response.json().get('count', 0)} documents")
# Wait for indexing
await asyncio.sleep(2)
# Test search speed
print("\nTesting search response times...")
import time
queries = ["search", "Python", "document", "test", "microservices"]
for query in queries:
start = time.time()
response = await client.get(
f"{BASE_URL}/api/search",
params={"q": query, "rows": 20}
)
elapsed = time.time() - start
results = response.json()
print(f"Query '{query}': {results['total']} results in {elapsed:.3f}s")
async def test_reindex():
"""Test reindexing from MongoDB"""
print("\n\n🔄 Testing Reindex Functionality...")
async with httpx.AsyncClient() as client:
# Trigger reindex for users collection
print("Triggering reindex for users collection...")
response = await client.post(
f"{BASE_URL}/api/search/reindex/users",
params={"doc_type": "user"}
)
if response.status_code == 200:
print(f"Reindex started: {response.json()}")
else:
print(f"Reindex failed: {response.status_code}")
# Test index optimization
print("\nTesting index optimization...")
response = await client.post(f"{BASE_URL}/api/search/optimize")
if response.status_code == 200:
print(f"Optimization: {response.json()}")
async def main():
"""Run all tests"""
print("=" * 60)
print("SEARCH SERVICE TEST SUITE (Apache Solr)")
print("=" * 60)
print(f"Started at: {datetime.now().isoformat()}")
# Run tests
await test_search_api()
await test_performance()
await test_reindex()
print("\n" + "=" * 60)
print("✅ All search tests completed!")
print(f"Finished at: {datetime.now().isoformat()}")
print("=" * 60)
if __name__ == "__main__":
asyncio.run(main())

View File

@ -0,0 +1,105 @@
<?xml version="1.0" encoding="UTF-8"?>
<schema name="site11" version="1.6">
<!-- Field Types -->
<fieldType name="string" class="solr.StrField" sortMissingLast="true" omitNorms="true"/>
<fieldType name="boolean" class="solr.BoolField" sortMissingLast="true" omitNorms="true"/>
<fieldType name="int" class="solr.IntPointField" omitNorms="true"/>
<fieldType name="long" class="solr.LongPointField" omitNorms="true"/>
<fieldType name="float" class="solr.FloatPointField" omitNorms="true"/>
<fieldType name="double" class="solr.DoublePointField" omitNorms="true"/>
<fieldType name="date" class="solr.DatePointField" omitNorms="true"/>
<!-- Text field with analysis -->
<fieldType name="text_general" class="solr.TextField" positionIncrementGap="100">
<analyzer type="index">
<tokenizer class="solr.StandardTokenizerFactory"/>
<filter class="solr.StopFilterFactory" ignoreCase="true" words="stopwords.txt"/>
<filter class="solr.LowerCaseFilterFactory"/>
<filter class="solr.EdgeNGramFilterFactory" minGramSize="2" maxGramSize="15"/>
</analyzer>
<analyzer type="query">
<tokenizer class="solr.StandardTokenizerFactory"/>
<filter class="solr.StopFilterFactory" ignoreCase="true" words="stopwords.txt"/>
<filter class="solr.SynonymGraphFilterFactory" synonyms="synonyms.txt" ignoreCase="true" expand="true"/>
<filter class="solr.LowerCaseFilterFactory"/>
</analyzer>
</fieldType>
<!-- Text field for exact matching -->
<fieldType name="text_exact" class="solr.TextField">
<analyzer>
<tokenizer class="solr.KeywordTokenizerFactory"/>
<filter class="solr.LowerCaseFilterFactory"/>
</analyzer>
</fieldType>
<!-- Autocomplete/Suggest field -->
<fieldType name="text_suggest" class="solr.TextField" positionIncrementGap="100">
<analyzer>
<tokenizer class="solr.StandardTokenizerFactory"/>
<filter class="solr.LowerCaseFilterFactory"/>
<filter class="solr.EdgeNGramFilterFactory" minGramSize="1" maxGramSize="20"/>
</analyzer>
</fieldType>
<!-- Fields -->
<field name="id" type="string" indexed="true" stored="true" required="true"/>
<field name="_version_" type="long" indexed="true" stored="true"/>
<!-- Document type and metadata -->
<field name="doc_type" type="string" indexed="true" stored="true" docValues="true"/>
<field name="indexed_at" type="date" indexed="true" stored="true"/>
<!-- Common fields across document types -->
<field name="title" type="text_general" indexed="true" stored="true" termVectors="true"/>
<field name="content" type="text_general" indexed="true" stored="true" termVectors="true"/>
<field name="description" type="text_general" indexed="true" stored="true"/>
<field name="summary" type="text_general" indexed="true" stored="true"/>
<field name="tags" type="string" indexed="true" stored="true" multiValued="true" docValues="true"/>
<field name="category" type="string" indexed="true" stored="true" docValues="true"/>
<field name="status" type="string" indexed="true" stored="true" docValues="true"/>
<!-- User-specific fields -->
<field name="user_id" type="string" indexed="true" stored="true"/>
<field name="username" type="text_exact" indexed="true" stored="true"/>
<field name="email" type="text_exact" indexed="true" stored="true"/>
<field name="name" type="text_general" indexed="true" stored="true"/>
<field name="bio" type="text_general" indexed="true" stored="true"/>
<!-- File-specific fields -->
<field name="file_id" type="string" indexed="true" stored="true"/>
<field name="filename" type="text_general" indexed="true" stored="true"/>
<field name="original_name" type="text_general" indexed="true" stored="true"/>
<field name="content_type" type="string" indexed="true" stored="true" docValues="true"/>
<field name="size" type="long" indexed="true" stored="true"/>
<!-- Content-specific fields -->
<field name="content_id" type="string" indexed="true" stored="true"/>
<field name="author_id" type="string" indexed="true" stored="true"/>
<!-- Dates -->
<field name="created_at" type="date" indexed="true" stored="true"/>
<field name="updated_at" type="date" indexed="true" stored="true"/>
<!-- Suggest field for autocomplete -->
<field name="suggest" type="text_suggest" indexed="true" stored="false" multiValued="true"/>
<!-- Copy fields for better search -->
<copyField source="title" dest="suggest"/>
<copyField source="name" dest="suggest"/>
<copyField source="filename" dest="suggest"/>
<copyField source="tags" dest="suggest"/>
<!-- Dynamic fields -->
<dynamicField name="*_i" type="int" indexed="true" stored="true"/>
<dynamicField name="*_l" type="long" indexed="true" stored="true"/>
<dynamicField name="*_f" type="float" indexed="true" stored="true"/>
<dynamicField name="*_d" type="double" indexed="true" stored="true"/>
<dynamicField name="*_s" type="string" indexed="true" stored="true"/>
<dynamicField name="*_t" type="text_general" indexed="true" stored="true"/>
<dynamicField name="*_dt" type="date" indexed="true" stored="true"/>
<dynamicField name="*_b" type="boolean" indexed="true" stored="true"/>
<!-- Unique Key -->
<uniqueKey>id</uniqueKey>
</schema>

View File

@ -0,0 +1,152 @@
<?xml version="1.0" encoding="UTF-8" ?>
<config>
<luceneMatchVersion>9.4.0</luceneMatchVersion>
<!-- Data Directory -->
<dataDir>${solr.data.dir:}</dataDir>
<!-- Index Config -->
<indexConfig>
<ramBufferSizeMB>100</ramBufferSizeMB>
<maxBufferedDocs>1000</maxBufferedDocs>
<mergePolicyFactory class="org.apache.solr.index.TieredMergePolicyFactory">
<int name="maxMergeAtOnce">10</int>
<int name="segmentsPerTier">10</int>
</mergePolicyFactory>
</indexConfig>
<!-- Update Handler -->
<updateHandler class="solr.DirectUpdateHandler2">
<updateLog>
<str name="dir">${solr.ulog.dir:}</str>
<int name="numVersionBuckets">${solr.ulog.numVersionBuckets:65536}</int>
</updateLog>
<autoCommit>
<maxTime>${solr.autoCommit.maxTime:15000}</maxTime>
<openSearcher>false</openSearcher>
</autoCommit>
<autoSoftCommit>
<maxTime>${solr.autoSoftCommit.maxTime:1000}</maxTime>
</autoSoftCommit>
</updateHandler>
<!-- Query Settings -->
<query>
<maxBooleanClauses>1024</maxBooleanClauses>
<filterCache class="solr.CaffeineCache" size="512" initialSize="512" autowarmCount="0"/>
<queryResultCache class="solr.CaffeineCache" size="512" initialSize="512" autowarmCount="0"/>
<documentCache class="solr.CaffeineCache" size="512" initialSize="512" autowarmCount="0"/>
<enableLazyFieldLoading>true</enableLazyFieldLoading>
<queryResultWindowSize>20</queryResultWindowSize>
<queryResultMaxDocsCached>200</queryResultMaxDocsCached>
</query>
<!-- Request Dispatcher -->
<requestDispatcher>
<requestParsers enableRemoteStreaming="true" multipartUploadLimitInKB="2048000"
formdataUploadLimitInKB="2048" addHttpRequestToContext="false"/>
<httpCaching never304="true"/>
</requestDispatcher>
<!-- Request Handlers -->
<!-- Standard search handler -->
<requestHandler name="/select" class="solr.SearchHandler">
<lst name="defaults">
<str name="echoParams">explicit</str>
<int name="rows">10</int>
<str name="df">content</str>
<str name="q.op">OR</str>
<str name="defType">edismax</str>
<str name="qf">
title^3.0 name^2.5 content^2.0 description^1.5 summary^1.5
filename^1.5 tags^1.2 category username email bio
</str>
<str name="pf">
title^4.0 name^3.0 content^2.5 description^2.0
</str>
<str name="mm">2&lt;-25%</str>
<str name="hl">true</str>
<str name="hl.fl">title,content,description,summary</str>
<str name="hl.simple.pre">&lt;mark&gt;</str>
<str name="hl.simple.post">&lt;/mark&gt;</str>
<str name="facet">true</str>
<str name="facet.mincount">1</str>
</lst>
</requestHandler>
<!-- Update handler -->
<requestHandler name="/update" class="solr.UpdateRequestHandler"/>
<!-- Get handler -->
<requestHandler name="/get" class="solr.RealTimeGetHandler">
<lst name="defaults">
<str name="omitHeader">true</str>
</lst>
</requestHandler>
<!-- Admin handlers -->
<requestHandler name="/admin/ping" class="solr.PingRequestHandler">
<lst name="invariants">
<str name="q">solrpingquery</str>
</lst>
<lst name="defaults">
<str name="echoParams">all</str>
</lst>
</requestHandler>
<!-- Suggest/Autocomplete handler -->
<requestHandler name="/suggest" class="solr.SearchHandler">
<lst name="defaults">
<str name="suggest">true</str>
<str name="suggest.count">10</str>
<str name="suggest.dictionary">suggest</str>
</lst>
<arr name="components">
<str>suggest</str>
</arr>
</requestHandler>
<!-- Spell check component -->
<searchComponent name="spellcheck" class="solr.SpellCheckComponent">
<str name="queryAnalyzerFieldType">text_general</str>
<lst name="spellchecker">
<str name="name">default</str>
<str name="field">content</str>
<str name="classname">solr.DirectSolrSpellChecker</str>
<str name="distanceMeasure">internal</str>
<float name="accuracy">0.5</float>
<int name="maxEdits">2</int>
<int name="minPrefix">1</int>
<int name="maxInspections">5</int>
<int name="minQueryLength">4</int>
<float name="maxQueryFrequency">0.01</float>
</lst>
</searchComponent>
<!-- Suggest component -->
<searchComponent name="suggest" class="solr.SuggestComponent">
<lst name="suggester">
<str name="name">suggest</str>
<str name="lookupImpl">FuzzyLookupFactory</str>
<str name="dictionaryImpl">DocumentDictionaryFactory</str>
<str name="field">suggest</str>
<str name="suggestAnalyzerFieldType">text_suggest</str>
<str name="buildOnStartup">false</str>
</lst>
</searchComponent>
<!-- More Like This handler -->
<requestHandler name="/mlt" class="solr.MoreLikeThisHandler">
<lst name="defaults">
<str name="mlt.fl">title,content,description,tags</str>
<int name="mlt.mindf">1</int>
<int name="mlt.mintf">1</int>
<int name="mlt.count">10</int>
</lst>
</requestHandler>
<!-- Schema handler (removed for Solr 9.x compatibility) -->
<!-- Config handler (removed for Solr 9.x compatibility) -->
</config>

View File

@ -0,0 +1,35 @@
# Licensed to the Apache Software Foundation (ASF)
# Standard English stop words
a
an
and
are
as
at
be
but
by
for
if
in
into
is
it
no
not
of
on
or
such
that
the
their
then
there
these
they
this
to
was
will
with

View File

@ -0,0 +1,38 @@
# Synonyms for site11 search
# Format: term1, term2, term3 => all are synonyms
# Or: term1, term2 => term1 is replaced by term2
# Technology synonyms
javascript, js
typescript, ts
python, py
golang, go
database, db
kubernetes, k8s
docker, container, containerization
# Common terms
search, find, query, lookup
upload, import, add
download, export, get
delete, remove, erase
update, modify, edit, change
create, make, new, add
# File related
document, doc, file
image, picture, photo, img
video, movie, clip
audio, sound, music
# User related
user, member, account
admin, administrator, moderator
profile, account, user
# Status
active, enabled, live
inactive, disabled, offline
pending, waiting, processing
complete, done, finished
error, failed, failure

View File

@ -0,0 +1,21 @@
FROM python:3.11-slim
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y \
gcc \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements first for better caching
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY . .
# Expose port
EXPOSE 8000
# Run the application
CMD ["python", "main.py"]

View File

@ -0,0 +1,617 @@
"""
Data Aggregator - Performs data aggregation and analytics
"""
import asyncio
import logging
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from models import (
AggregatedMetric, AggregationType, Granularity,
UserAnalytics, SystemAnalytics, EventAnalytics,
AlertRule, Alert
)
import uuid
import io
import csv
logger = logging.getLogger(__name__)
class DataAggregator:
"""Performs data aggregation and analytics operations"""
def __init__(self, ts_db, cache):
self.ts_db = ts_db
self.cache = cache
self.is_running = False
self.alert_rules = {}
self.active_alerts = {}
self.aggregation_jobs = []
async def start_aggregation_jobs(self):
"""Start background aggregation jobs"""
self.is_running = True
# Schedule periodic aggregation jobs
self.aggregation_jobs = [
asyncio.create_task(self._aggregate_hourly_metrics()),
asyncio.create_task(self._aggregate_daily_metrics()),
asyncio.create_task(self._check_alert_rules()),
asyncio.create_task(self._cleanup_old_data())
]
logger.info("Data aggregation jobs started")
async def stop(self):
"""Stop aggregation jobs"""
self.is_running = False
# Cancel all jobs
for job in self.aggregation_jobs:
job.cancel()
# Wait for jobs to complete
await asyncio.gather(*self.aggregation_jobs, return_exceptions=True)
logger.info("Data aggregation jobs stopped")
async def _aggregate_hourly_metrics(self):
"""Aggregate metrics every hour"""
while self.is_running:
try:
await asyncio.sleep(3600) # Run every hour
end_time = datetime.now()
start_time = end_time - timedelta(hours=1)
# Aggregate different metric types
await self._aggregate_metric_type("user.event", start_time, end_time, Granularity.HOUR)
await self._aggregate_metric_type("system.cpu", start_time, end_time, Granularity.HOUR)
await self._aggregate_metric_type("system.memory", start_time, end_time, Granularity.HOUR)
logger.info("Completed hourly metrics aggregation")
except Exception as e:
logger.error(f"Error in hourly aggregation: {e}")
async def _aggregate_daily_metrics(self):
"""Aggregate metrics every day"""
while self.is_running:
try:
await asyncio.sleep(86400) # Run every 24 hours
end_time = datetime.now()
start_time = end_time - timedelta(days=1)
# Aggregate different metric types
await self._aggregate_metric_type("user.event", start_time, end_time, Granularity.DAY)
await self._aggregate_metric_type("system", start_time, end_time, Granularity.DAY)
logger.info("Completed daily metrics aggregation")
except Exception as e:
logger.error(f"Error in daily aggregation: {e}")
async def _aggregate_metric_type(
self,
metric_prefix: str,
start_time: datetime,
end_time: datetime,
granularity: Granularity
):
"""Aggregate a specific metric type"""
try:
# Query raw metrics
metrics = await self.ts_db.query_metrics(
metric_type=metric_prefix,
start_time=start_time,
end_time=end_time
)
if not metrics:
return
# Calculate aggregations
aggregations = {
AggregationType.AVG: sum(m['value'] for m in metrics) / len(metrics),
AggregationType.SUM: sum(m['value'] for m in metrics),
AggregationType.MIN: min(m['value'] for m in metrics),
AggregationType.MAX: max(m['value'] for m in metrics),
AggregationType.COUNT: len(metrics)
}
# Store aggregated results
for agg_type, value in aggregations.items():
aggregated = AggregatedMetric(
metric_name=metric_prefix,
aggregation_type=agg_type,
value=value,
start_time=start_time,
end_time=end_time,
granularity=granularity,
count=len(metrics)
)
await self.ts_db.store_aggregated_metric(aggregated)
except Exception as e:
logger.error(f"Error aggregating {metric_prefix}: {e}")
async def aggregate_metrics(
self,
metric_type: str,
aggregation: str,
group_by: Optional[str],
start_time: datetime,
end_time: datetime
) -> Dict[str, Any]:
"""Perform custom metric aggregation"""
try:
# Query metrics
metrics = await self.ts_db.query_metrics(
metric_type=metric_type,
start_time=start_time,
end_time=end_time
)
if not metrics:
return {"result": 0, "count": 0}
# Group metrics if requested
if group_by:
grouped = {}
for metric in metrics:
key = metric.get('tags', {}).get(group_by, 'unknown')
if key not in grouped:
grouped[key] = []
grouped[key].append(metric['value'])
# Aggregate each group
results = {}
for key, values in grouped.items():
results[key] = self._calculate_aggregation(values, aggregation)
return {"grouped_results": results, "count": len(metrics)}
else:
# Single aggregation
values = [m['value'] for m in metrics]
result = self._calculate_aggregation(values, aggregation)
return {"result": result, "count": len(metrics)}
except Exception as e:
logger.error(f"Error in custom aggregation: {e}")
raise
def _calculate_aggregation(self, values: List[float], aggregation: str) -> float:
"""Calculate aggregation on values"""
if not values:
return 0
if aggregation == "avg":
return sum(values) / len(values)
elif aggregation == "sum":
return sum(values)
elif aggregation == "min":
return min(values)
elif aggregation == "max":
return max(values)
elif aggregation == "count":
return len(values)
else:
return 0
async def get_overview(self) -> Dict[str, Any]:
"""Get analytics overview"""
try:
now = datetime.now()
last_hour = now - timedelta(hours=1)
last_day = now - timedelta(days=1)
last_week = now - timedelta(weeks=1)
# Get various metrics
hourly_events = await self.ts_db.count_metrics("user.event", last_hour, now)
daily_events = await self.ts_db.count_metrics("user.event", last_day, now)
weekly_events = await self.ts_db.count_metrics("user.event", last_week, now)
# Get system status
cpu_avg = await self.ts_db.get_average("system.cpu.usage", last_hour, now)
memory_avg = await self.ts_db.get_average("system.memory.usage", last_hour, now)
# Get active users (approximate from events)
active_users = await self.ts_db.count_distinct_tags("user.event", "user_id", last_day, now)
return {
"events": {
"last_hour": hourly_events,
"last_day": daily_events,
"last_week": weekly_events
},
"system": {
"cpu_usage": cpu_avg,
"memory_usage": memory_avg
},
"users": {
"active_daily": active_users
},
"alerts": {
"active": len(self.active_alerts)
},
"timestamp": now.isoformat()
}
except Exception as e:
logger.error(f"Error getting overview: {e}")
return {}
async def get_user_analytics(
self,
start_date: datetime,
end_date: datetime,
granularity: str
) -> UserAnalytics:
"""Get user analytics"""
try:
# Get user metrics
total_users = await self.ts_db.count_distinct_tags(
"user.event.user_created",
"user_id",
datetime.min,
end_date
)
active_users = await self.ts_db.count_distinct_tags(
"user.event",
"user_id",
start_date,
end_date
)
new_users = await self.ts_db.count_metrics(
"user.event.user_created",
start_date,
end_date
)
# Calculate growth rate
prev_period_start = start_date - (end_date - start_date)
prev_users = await self.ts_db.count_distinct_tags(
"user.event",
"user_id",
prev_period_start,
start_date
)
growth_rate = ((active_users - prev_users) / max(prev_users, 1)) * 100
# Get top actions
top_actions = await self.ts_db.get_top_metrics(
"user.event",
"event_type",
start_date,
end_date,
limit=10
)
return UserAnalytics(
total_users=total_users,
active_users=active_users,
new_users=new_users,
user_growth_rate=growth_rate,
average_session_duration=0, # Would need session tracking
top_actions=top_actions,
user_distribution={}, # Would need geographic data
period=f"{start_date.date()} to {end_date.date()}"
)
except Exception as e:
logger.error(f"Error getting user analytics: {e}")
raise
async def get_system_analytics(self) -> SystemAnalytics:
"""Get system performance analytics"""
try:
now = datetime.now()
last_hour = now - timedelta(hours=1)
last_day = now - timedelta(days=1)
# Calculate uptime (simplified - would need actual downtime tracking)
total_checks = await self.ts_db.count_metrics("system.health", last_day, now)
successful_checks = await self.ts_db.count_metrics_with_value(
"system.health",
1,
last_day,
now
)
uptime = (successful_checks / max(total_checks, 1)) * 100
# Get averages
cpu_usage = await self.ts_db.get_average("system.cpu.usage", last_hour, now)
memory_usage = await self.ts_db.get_average("system.memory.usage", last_hour, now)
disk_usage = await self.ts_db.get_average("system.disk.usage", last_hour, now)
response_time = await self.ts_db.get_average("api.response_time", last_hour, now)
# Get error rate
total_requests = await self.ts_db.count_metrics("api.request", last_hour, now)
error_requests = await self.ts_db.count_metrics("api.error", last_hour, now)
error_rate = (error_requests / max(total_requests, 1)) * 100
# Throughput
throughput = total_requests / 3600 # requests per second
return SystemAnalytics(
uptime_percentage=uptime,
average_response_time=response_time or 0,
error_rate=error_rate,
throughput=throughput,
cpu_usage=cpu_usage or 0,
memory_usage=memory_usage or 0,
disk_usage=disk_usage or 0,
active_connections=0, # Would need connection tracking
services_health={} # Would need service health checks
)
except Exception as e:
logger.error(f"Error getting system analytics: {e}")
raise
async def get_event_analytics(
self,
event_type: Optional[str],
limit: int
) -> EventAnalytics:
"""Get event analytics"""
try:
now = datetime.now()
last_hour = now - timedelta(hours=1)
# Get total events
total_events = await self.ts_db.count_metrics(
event_type or "user.event",
last_hour,
now
)
# Events per second
events_per_second = total_events / 3600
# Get event types distribution
event_types = await self.ts_db.get_metric_distribution(
"user.event",
"event_type",
last_hour,
now
)
# Top events
top_events = await self.ts_db.get_top_metrics(
event_type or "user.event",
"event_type",
last_hour,
now,
limit=limit
)
# Error events
error_events = await self.ts_db.count_metrics(
"user.event.error",
last_hour,
now
)
# Success rate
success_rate = ((total_events - error_events) / max(total_events, 1)) * 100
return EventAnalytics(
total_events=total_events,
events_per_second=events_per_second,
event_types=event_types,
top_events=top_events,
error_events=error_events,
success_rate=success_rate,
processing_time={} # Would need timing metrics
)
except Exception as e:
logger.error(f"Error getting event analytics: {e}")
raise
async def get_dashboard_configs(self) -> List[Dict[str, Any]]:
"""Get available dashboard configurations"""
return [
{
"id": "overview",
"name": "Overview Dashboard",
"description": "General system overview"
},
{
"id": "users",
"name": "User Analytics",
"description": "User behavior and statistics"
},
{
"id": "system",
"name": "System Performance",
"description": "System health and performance metrics"
},
{
"id": "events",
"name": "Event Analytics",
"description": "Event processing and statistics"
}
]
async def get_dashboard_data(self, dashboard_id: str) -> Dict[str, Any]:
"""Get data for a specific dashboard"""
if dashboard_id == "overview":
return await self.get_overview()
elif dashboard_id == "users":
end_date = datetime.now()
start_date = end_date - timedelta(days=7)
analytics = await self.get_user_analytics(start_date, end_date, "day")
return analytics.dict()
elif dashboard_id == "system":
analytics = await self.get_system_analytics()
return analytics.dict()
elif dashboard_id == "events":
analytics = await self.get_event_analytics(None, 100)
return analytics.dict()
else:
raise ValueError(f"Unknown dashboard: {dashboard_id}")
async def create_alert_rule(self, rule_data: Dict[str, Any]) -> str:
"""Create a new alert rule"""
rule = AlertRule(**rule_data)
rule.id = str(uuid.uuid4())
self.alert_rules[rule.id] = rule
# Store in cache
await self.cache.set(
f"alert_rule:{rule.id}",
rule.json(),
expire=None # Permanent
)
return rule.id
async def _check_alert_rules(self):
"""Check alert rules periodically"""
while self.is_running:
try:
await asyncio.sleep(60) # Check every minute
for rule_id, rule in self.alert_rules.items():
if not rule.enabled:
continue
await self._evaluate_alert_rule(rule)
except Exception as e:
logger.error(f"Error checking alert rules: {e}")
async def _evaluate_alert_rule(self, rule: AlertRule):
"""Evaluate a single alert rule"""
try:
# Get recent metric values
end_time = datetime.now()
start_time = end_time - timedelta(seconds=rule.duration)
avg_value = await self.ts_db.get_average(
rule.metric_name,
start_time,
end_time
)
if avg_value is None:
return
# Check condition
triggered = False
if rule.condition == "gt" and avg_value > rule.threshold:
triggered = True
elif rule.condition == "lt" and avg_value < rule.threshold:
triggered = True
elif rule.condition == "gte" and avg_value >= rule.threshold:
triggered = True
elif rule.condition == "lte" and avg_value <= rule.threshold:
triggered = True
elif rule.condition == "eq" and avg_value == rule.threshold:
triggered = True
elif rule.condition == "neq" and avg_value != rule.threshold:
triggered = True
# Handle alert state
alert_key = f"{rule.id}:{rule.metric_name}"
if triggered:
if alert_key not in self.active_alerts:
# New alert
alert = Alert(
id=str(uuid.uuid4()),
rule_id=rule.id,
rule_name=rule.name,
metric_name=rule.metric_name,
current_value=avg_value,
threshold=rule.threshold,
severity=rule.severity,
triggered_at=datetime.now(),
status="active"
)
self.active_alerts[alert_key] = alert
# Send notifications
await self._send_alert_notifications(alert, rule)
else:
if alert_key in self.active_alerts:
# Alert resolved
alert = self.active_alerts[alert_key]
alert.resolved_at = datetime.now()
alert.status = "resolved"
del self.active_alerts[alert_key]
logger.info(f"Alert resolved: {rule.name}")
except Exception as e:
logger.error(f"Error evaluating alert rule {rule.id}: {e}")
async def _send_alert_notifications(self, alert: Alert, rule: AlertRule):
"""Send alert notifications"""
logger.warning(f"ALERT: {rule.name} - {alert.metric_name} = {alert.current_value} (threshold: {alert.threshold})")
# Would implement actual notification channels here
async def get_active_alerts(self) -> List[Dict[str, Any]]:
"""Get currently active alerts"""
return [alert.dict() for alert in self.active_alerts.values()]
async def export_to_csv(
self,
metric_type: str,
start_time: datetime,
end_time: datetime
):
"""Export metrics to CSV"""
try:
# Get metrics
metrics = await self.ts_db.query_metrics(
metric_type=metric_type,
start_time=start_time,
end_time=end_time
)
# Create CSV
output = io.StringIO()
writer = csv.DictWriter(
output,
fieldnames=['timestamp', 'metric_name', 'value', 'tags', 'service']
)
writer.writeheader()
for metric in metrics:
writer.writerow({
'timestamp': metric.get('timestamp'),
'metric_name': metric.get('name'),
'value': metric.get('value'),
'tags': str(metric.get('tags', {})),
'service': metric.get('service')
})
output.seek(0)
return output
except Exception as e:
logger.error(f"Error exporting to CSV: {e}")
raise
async def _cleanup_old_data(self):
"""Clean up old data periodically"""
while self.is_running:
try:
await asyncio.sleep(86400) # Run daily
# Delete data older than 30 days
cutoff_date = datetime.now() - timedelta(days=30)
await self.ts_db.delete_old_data(cutoff_date)
logger.info("Completed old data cleanup")
except Exception as e:
logger.error(f"Error in data cleanup: {e}")

View File

@ -0,0 +1,32 @@
"""Cache Manager for Redis"""
import json
import logging
from typing import Optional, Any
logger = logging.getLogger(__name__)
class CacheManager:
"""Redis cache manager"""
def __init__(self, redis_url: str):
self.redis_url = redis_url
self.is_connected = False
self.cache = {} # Simplified in-memory cache
async def connect(self):
"""Connect to Redis"""
self.is_connected = True
logger.info("Connected to cache")
async def close(self):
"""Close Redis connection"""
self.is_connected = False
logger.info("Disconnected from cache")
async def get(self, key: str) -> Optional[str]:
"""Get value from cache"""
return self.cache.get(key)
async def set(self, key: str, value: str, expire: Optional[int] = None):
"""Set value in cache"""
self.cache[key] = value

View File

@ -0,0 +1,396 @@
"""
Statistics Service - Real-time Analytics and Metrics
"""
from fastapi import FastAPI, HTTPException, Depends, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
import uvicorn
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any
import asyncio
import json
import os
from contextlib import asynccontextmanager
import logging
# Import custom modules
from models import Metric, AggregatedMetric, TimeSeriesData, DashboardConfig
from metrics_collector import MetricsCollector
from aggregator import DataAggregator
from websocket_manager import WebSocketManager
from time_series_db import TimeSeriesDB
from cache_manager import CacheManager
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global instances
metrics_collector = None
data_aggregator = None
ws_manager = None
ts_db = None
cache_manager = None
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
global metrics_collector, data_aggregator, ws_manager, ts_db, cache_manager
try:
# Initialize TimeSeriesDB (using InfluxDB)
ts_db = TimeSeriesDB(
host=os.getenv("INFLUXDB_HOST", "influxdb"),
port=int(os.getenv("INFLUXDB_PORT", 8086)),
database=os.getenv("INFLUXDB_DATABASE", "statistics")
)
await ts_db.connect()
logger.info("Connected to InfluxDB")
# Initialize Cache Manager
cache_manager = CacheManager(
redis_url=os.getenv("REDIS_URL", "redis://redis:6379")
)
await cache_manager.connect()
logger.info("Connected to Redis cache")
# Initialize Metrics Collector (optional Kafka connection)
try:
metrics_collector = MetricsCollector(
kafka_bootstrap_servers=os.getenv("KAFKA_BOOTSTRAP_SERVERS", "kafka:9092"),
ts_db=ts_db,
cache=cache_manager
)
await metrics_collector.start()
logger.info("Metrics collector started")
except Exception as e:
logger.warning(f"Metrics collector failed to start (Kafka not available): {e}")
metrics_collector = None
# Initialize Data Aggregator
data_aggregator = DataAggregator(
ts_db=ts_db,
cache=cache_manager
)
asyncio.create_task(data_aggregator.start_aggregation_jobs())
logger.info("Data aggregator started")
# Initialize WebSocket Manager
ws_manager = WebSocketManager()
logger.info("WebSocket manager initialized")
except Exception as e:
logger.error(f"Failed to start Statistics service: {e}")
raise
yield
# Shutdown
if metrics_collector:
await metrics_collector.stop()
if data_aggregator:
await data_aggregator.stop()
if ts_db:
await ts_db.close()
if cache_manager:
await cache_manager.close()
logger.info("Statistics service shutdown complete")
app = FastAPI(
title="Statistics Service",
description="Real-time Analytics and Metrics Service",
version="1.0.0",
lifespan=lifespan
)
# CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
async def root():
return {
"service": "Statistics Service",
"status": "running",
"timestamp": datetime.now().isoformat()
}
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"service": "statistics",
"components": {
"influxdb": "connected" if ts_db and ts_db.is_connected else "disconnected",
"redis": "connected" if cache_manager and cache_manager.is_connected else "disconnected",
"metrics_collector": "running" if metrics_collector and metrics_collector.is_running else "stopped",
"aggregator": "running" if data_aggregator and data_aggregator.is_running else "stopped"
},
"timestamp": datetime.now().isoformat()
}
# Metrics Endpoints
@app.post("/api/metrics")
async def record_metric(metric: Metric):
"""Record a single metric"""
try:
await metrics_collector.record_metric(metric)
return {"status": "recorded", "metric_id": metric.id}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/metrics/batch")
async def record_metrics_batch(metrics: List[Metric]):
"""Record multiple metrics in batch"""
try:
await metrics_collector.record_metrics_batch(metrics)
return {"status": "recorded", "count": len(metrics)}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/metrics/realtime/{metric_type}")
async def get_realtime_metrics(
metric_type: str,
duration: int = Query(60, description="Duration in seconds")
):
"""Get real-time metrics for the specified type"""
try:
end_time = datetime.now()
start_time = end_time - timedelta(seconds=duration)
metrics = await ts_db.query_metrics(
metric_type=metric_type,
start_time=start_time,
end_time=end_time
)
return {
"metric_type": metric_type,
"duration": duration,
"data": metrics,
"timestamp": datetime.now().isoformat()
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Analytics Endpoints
@app.get("/api/analytics/overview")
async def get_analytics_overview():
"""Get overall analytics overview"""
try:
# Try to get from cache first
cached = await cache_manager.get("analytics:overview")
if cached:
return json.loads(cached)
# Calculate analytics
overview = await data_aggregator.get_overview()
# Cache for 1 minute
await cache_manager.set(
"analytics:overview",
json.dumps(overview),
expire=60
)
return overview
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/analytics/users")
async def get_user_analytics(
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
granularity: str = Query("hour", regex="^(minute|hour|day|week|month)$")
):
"""Get user analytics"""
try:
if not start_date:
start_date = datetime.now() - timedelta(days=7)
if not end_date:
end_date = datetime.now()
analytics = await data_aggregator.get_user_analytics(
start_date=start_date,
end_date=end_date,
granularity=granularity
)
return analytics
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/analytics/system")
async def get_system_analytics():
"""Get system performance analytics"""
try:
analytics = await data_aggregator.get_system_analytics()
return analytics
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/analytics/events")
async def get_event_analytics(
event_type: Optional[str] = None,
limit: int = Query(100, le=1000)
):
"""Get event analytics"""
try:
analytics = await data_aggregator.get_event_analytics(
event_type=event_type,
limit=limit
)
return analytics
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Time Series Endpoints
@app.get("/api/timeseries/{metric_name}")
async def get_time_series(
metric_name: str,
start_time: datetime,
end_time: datetime,
interval: str = Query("1m", regex="^\\d+[smhd]$")
):
"""Get time series data for a specific metric"""
try:
data = await ts_db.get_time_series(
metric_name=metric_name,
start_time=start_time,
end_time=end_time,
interval=interval
)
return TimeSeriesData(
metric_name=metric_name,
start_time=start_time,
end_time=end_time,
interval=interval,
data=data
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Aggregation Endpoints
@app.get("/api/aggregates/{metric_type}")
async def get_aggregated_metrics(
metric_type: str,
aggregation: str = Query("avg", regex="^(avg|sum|min|max|count)$"),
group_by: Optional[str] = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None
):
"""Get aggregated metrics"""
try:
if not start_time:
start_time = datetime.now() - timedelta(hours=24)
if not end_time:
end_time = datetime.now()
result = await data_aggregator.aggregate_metrics(
metric_type=metric_type,
aggregation=aggregation,
group_by=group_by,
start_time=start_time,
end_time=end_time
)
return result
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Dashboard Endpoints
@app.get("/api/dashboard/configs")
async def get_dashboard_configs():
"""Get available dashboard configurations"""
configs = await data_aggregator.get_dashboard_configs()
return {"configs": configs}
@app.get("/api/dashboard/{dashboard_id}")
async def get_dashboard_data(dashboard_id: str):
"""Get data for a specific dashboard"""
try:
data = await data_aggregator.get_dashboard_data(dashboard_id)
return data
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# WebSocket Endpoint for Real-time Updates
from fastapi import WebSocket, WebSocketDisconnect
@app.websocket("/ws/metrics")
async def websocket_metrics(websocket: WebSocket):
"""WebSocket endpoint for real-time metrics streaming"""
await ws_manager.connect(websocket)
try:
while True:
# Send metrics updates every second
metrics = await metrics_collector.get_latest_metrics()
await websocket.send_json({
"type": "metrics_update",
"data": metrics,
"timestamp": datetime.now().isoformat()
})
await asyncio.sleep(1)
except WebSocketDisconnect:
ws_manager.disconnect(websocket)
except Exception as e:
logger.error(f"WebSocket error: {e}")
ws_manager.disconnect(websocket)
# Alert Management Endpoints
@app.post("/api/alerts/rules")
async def create_alert_rule(rule: Dict[str, Any]):
"""Create a new alert rule"""
try:
rule_id = await data_aggregator.create_alert_rule(rule)
return {"rule_id": rule_id, "status": "created"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/alerts/active")
async def get_active_alerts():
"""Get currently active alerts"""
try:
alerts = await data_aggregator.get_active_alerts()
return {"alerts": alerts, "count": len(alerts)}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Export Endpoints
@app.get("/api/export/csv")
async def export_metrics_csv(
metric_type: str,
start_time: datetime,
end_time: datetime
):
"""Export metrics as CSV"""
try:
csv_data = await data_aggregator.export_to_csv(
metric_type=metric_type,
start_time=start_time,
end_time=end_time
)
return StreamingResponse(
csv_data,
media_type="text/csv",
headers={
"Content-Disposition": f"attachment; filename=metrics_{metric_type}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
}
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
uvicorn.run(
"main:app",
host="0.0.0.0",
port=8000,
reload=True
)

View File

@ -0,0 +1,242 @@
"""
Metrics Collector - Collects metrics from Kafka and other sources
"""
import asyncio
import json
import logging
from typing import List, Dict, Any, Optional
from datetime import datetime
from aiokafka import AIOKafkaConsumer
from models import Metric, MetricType
import uuid
logger = logging.getLogger(__name__)
class MetricsCollector:
"""Collects and processes metrics from various sources"""
def __init__(self, kafka_bootstrap_servers: str, ts_db, cache):
self.kafka_servers = kafka_bootstrap_servers
self.ts_db = ts_db
self.cache = cache
self.consumer = None
self.is_running = False
self.latest_metrics = {}
self.metrics_buffer = []
self.buffer_size = 100
self.flush_interval = 5 # seconds
async def start(self):
"""Start the metrics collector"""
try:
# Start Kafka consumer for event metrics
self.consumer = AIOKafkaConsumer(
'metrics-events',
'user-events',
'system-metrics',
bootstrap_servers=self.kafka_servers,
group_id='statistics-consumer-group',
value_deserializer=lambda m: json.loads(m.decode('utf-8'))
)
await self.consumer.start()
self.is_running = True
# Start background tasks
asyncio.create_task(self._consume_metrics())
asyncio.create_task(self._flush_metrics_periodically())
logger.info("Metrics collector started")
except Exception as e:
logger.error(f"Failed to start metrics collector: {e}")
raise
async def stop(self):
"""Stop the metrics collector"""
self.is_running = False
if self.consumer:
await self.consumer.stop()
# Flush remaining metrics
if self.metrics_buffer:
await self._flush_metrics()
logger.info("Metrics collector stopped")
async def _consume_metrics(self):
"""Consume metrics from Kafka"""
while self.is_running:
try:
async for msg in self.consumer:
if not self.is_running:
break
metric = self._parse_kafka_message(msg)
if metric:
await self.record_metric(metric)
except Exception as e:
logger.error(f"Error consuming metrics: {e}")
await asyncio.sleep(5)
def _parse_kafka_message(self, msg) -> Optional[Metric]:
"""Parse Kafka message into Metric"""
try:
data = msg.value
topic = msg.topic
# Create metric based on topic
if topic == 'user-events':
return self._create_user_metric(data)
elif topic == 'system-metrics':
return self._create_system_metric(data)
elif topic == 'metrics-events':
return Metric(**data)
else:
return None
except Exception as e:
logger.error(f"Failed to parse Kafka message: {e}")
return None
def _create_user_metric(self, data: Dict) -> Metric:
"""Create metric from user event"""
event_type = data.get('event_type', 'unknown')
return Metric(
id=str(uuid.uuid4()),
name=f"user.event.{event_type.lower()}",
type=MetricType.COUNTER,
value=1,
tags={
"event_type": event_type,
"user_id": data.get('data', {}).get('user_id', 'unknown'),
"service": data.get('service', 'unknown')
},
timestamp=datetime.fromisoformat(data.get('timestamp', datetime.now().isoformat())),
service=data.get('service', 'users')
)
def _create_system_metric(self, data: Dict) -> Metric:
"""Create metric from system event"""
return Metric(
id=str(uuid.uuid4()),
name=data.get('metric_name', 'system.unknown'),
type=MetricType.GAUGE,
value=float(data.get('value', 0)),
tags=data.get('tags', {}),
timestamp=datetime.fromisoformat(data.get('timestamp', datetime.now().isoformat())),
service=data.get('service', 'system')
)
async def record_metric(self, metric: Metric):
"""Record a single metric"""
try:
# Add to buffer
self.metrics_buffer.append(metric)
# Update latest metrics cache
self.latest_metrics[metric.name] = {
"value": metric.value,
"timestamp": metric.timestamp.isoformat(),
"tags": metric.tags
}
# Flush if buffer is full
if len(self.metrics_buffer) >= self.buffer_size:
await self._flush_metrics()
except Exception as e:
logger.error(f"Failed to record metric: {e}")
raise
async def record_metrics_batch(self, metrics: List[Metric]):
"""Record multiple metrics"""
for metric in metrics:
await self.record_metric(metric)
async def _flush_metrics(self):
"""Flush metrics buffer to time series database"""
if not self.metrics_buffer:
return
try:
# Write to time series database
await self.ts_db.write_metrics(self.metrics_buffer)
# Clear buffer
self.metrics_buffer.clear()
logger.debug(f"Flushed {len(self.metrics_buffer)} metrics to database")
except Exception as e:
logger.error(f"Failed to flush metrics: {e}")
async def _flush_metrics_periodically(self):
"""Periodically flush metrics buffer"""
while self.is_running:
await asyncio.sleep(self.flush_interval)
await self._flush_metrics()
async def get_latest_metrics(self) -> Dict[str, Any]:
"""Get latest metrics for real-time display"""
return self.latest_metrics
async def collect_system_metrics(self):
"""Collect system-level metrics"""
import psutil
try:
# CPU metrics
cpu_percent = psutil.cpu_percent(interval=1)
await self.record_metric(Metric(
name="system.cpu.usage",
type=MetricType.GAUGE,
value=cpu_percent,
tags={"host": "localhost"},
service="statistics"
))
# Memory metrics
memory = psutil.virtual_memory()
await self.record_metric(Metric(
name="system.memory.usage",
type=MetricType.GAUGE,
value=memory.percent,
tags={"host": "localhost"},
service="statistics"
))
# Disk metrics
disk = psutil.disk_usage('/')
await self.record_metric(Metric(
name="system.disk.usage",
type=MetricType.GAUGE,
value=disk.percent,
tags={"host": "localhost", "mount": "/"},
service="statistics"
))
# Network metrics
net_io = psutil.net_io_counters()
await self.record_metric(Metric(
name="system.network.bytes_sent",
type=MetricType.COUNTER,
value=net_io.bytes_sent,
tags={"host": "localhost"},
service="statistics"
))
await self.record_metric(Metric(
name="system.network.bytes_recv",
type=MetricType.COUNTER,
value=net_io.bytes_recv,
tags={"host": "localhost"},
service="statistics"
))
except Exception as e:
logger.error(f"Failed to collect system metrics: {e}")
async def collect_application_metrics(self):
"""Collect application-level metrics"""
# This would be called by other services to report their metrics
pass

View File

@ -0,0 +1,159 @@
"""
Data models for Statistics Service
"""
from pydantic import BaseModel, Field
from datetime import datetime
from typing import Optional, List, Dict, Any, Literal
from enum import Enum
class MetricType(str, Enum):
"""Types of metrics"""
COUNTER = "counter"
GAUGE = "gauge"
HISTOGRAM = "histogram"
SUMMARY = "summary"
class AggregationType(str, Enum):
"""Types of aggregation"""
AVG = "avg"
SUM = "sum"
MIN = "min"
MAX = "max"
COUNT = "count"
PERCENTILE = "percentile"
class Granularity(str, Enum):
"""Time granularity for aggregation"""
MINUTE = "minute"
HOUR = "hour"
DAY = "day"
WEEK = "week"
MONTH = "month"
class Metric(BaseModel):
"""Single metric data point"""
id: Optional[str] = Field(None, description="Unique metric ID")
name: str = Field(..., description="Metric name")
type: MetricType = Field(..., description="Metric type")
value: float = Field(..., description="Metric value")
tags: Dict[str, str] = Field(default_factory=dict, description="Metric tags")
timestamp: datetime = Field(default_factory=datetime.now, description="Metric timestamp")
service: str = Field(..., description="Source service")
class Config:
json_encoders = {
datetime: lambda v: v.isoformat()
}
class AggregatedMetric(BaseModel):
"""Aggregated metric result"""
metric_name: str
aggregation_type: AggregationType
value: float
start_time: datetime
end_time: datetime
granularity: Optional[Granularity] = None
group_by: Optional[str] = None
count: int = Field(..., description="Number of data points aggregated")
class Config:
json_encoders = {
datetime: lambda v: v.isoformat()
}
class TimeSeriesData(BaseModel):
"""Time series data response"""
metric_name: str
start_time: datetime
end_time: datetime
interval: str
data: List[Dict[str, Any]]
class Config:
json_encoders = {
datetime: lambda v: v.isoformat()
}
class DashboardConfig(BaseModel):
"""Dashboard configuration"""
id: str
name: str
description: Optional[str] = None
widgets: List[Dict[str, Any]]
refresh_interval: int = Field(60, description="Refresh interval in seconds")
created_at: datetime = Field(default_factory=datetime.now)
updated_at: datetime = Field(default_factory=datetime.now)
class Config:
json_encoders = {
datetime: lambda v: v.isoformat()
}
class AlertRule(BaseModel):
"""Alert rule configuration"""
id: Optional[str] = None
name: str
metric_name: str
condition: Literal["gt", "lt", "gte", "lte", "eq", "neq"]
threshold: float
duration: int = Field(..., description="Duration in seconds")
severity: Literal["low", "medium", "high", "critical"]
enabled: bool = True
notification_channels: List[str] = Field(default_factory=list)
created_at: datetime = Field(default_factory=datetime.now)
class Config:
json_encoders = {
datetime: lambda v: v.isoformat()
}
class Alert(BaseModel):
"""Active alert"""
id: str
rule_id: str
rule_name: str
metric_name: str
current_value: float
threshold: float
severity: str
triggered_at: datetime
resolved_at: Optional[datetime] = None
status: Literal["active", "resolved", "acknowledged"]
class Config:
json_encoders = {
datetime: lambda v: v.isoformat()
}
class UserAnalytics(BaseModel):
"""User analytics data"""
total_users: int
active_users: int
new_users: int
user_growth_rate: float
average_session_duration: float
top_actions: List[Dict[str, Any]]
user_distribution: Dict[str, int]
period: str
class SystemAnalytics(BaseModel):
"""System performance analytics"""
uptime_percentage: float
average_response_time: float
error_rate: float
throughput: float
cpu_usage: float
memory_usage: float
disk_usage: float
active_connections: int
services_health: Dict[str, str]
class EventAnalytics(BaseModel):
"""Event analytics data"""
total_events: int
events_per_second: float
event_types: Dict[str, int]
top_events: List[Dict[str, Any]]
error_events: int
success_rate: float
processing_time: Dict[str, float]

View File

@ -0,0 +1,9 @@
fastapi==0.109.0
uvicorn[standard]==0.27.0
pydantic==2.5.3
python-dotenv==1.0.0
aiokafka==0.10.0
redis==5.0.1
psutil==5.9.8
httpx==0.26.0
websockets==12.0

View File

@ -0,0 +1,165 @@
"""
Time Series Database Interface (Simplified for InfluxDB)
"""
import logging
from typing import List, Dict, Any, Optional
from datetime import datetime
from models import Metric, AggregatedMetric
logger = logging.getLogger(__name__)
class TimeSeriesDB:
"""Time series database interface"""
def __init__(self, host: str, port: int, database: str):
self.host = host
self.port = port
self.database = database
self.is_connected = False
# In production, would use actual InfluxDB client
self.data_store = [] # Simplified in-memory storage
async def connect(self):
"""Connect to database"""
# Simplified connection
self.is_connected = True
logger.info(f"Connected to time series database at {self.host}:{self.port}")
async def close(self):
"""Close database connection"""
self.is_connected = False
logger.info("Disconnected from time series database")
async def write_metrics(self, metrics: List[Metric]):
"""Write metrics to database"""
for metric in metrics:
self.data_store.append({
"name": metric.name,
"value": metric.value,
"timestamp": metric.timestamp,
"tags": metric.tags,
"service": metric.service
})
async def query_metrics(
self,
metric_type: str,
start_time: datetime,
end_time: datetime
) -> List[Dict[str, Any]]:
"""Query metrics from database"""
results = []
for data in self.data_store:
if (data["name"].startswith(metric_type) and
start_time <= data["timestamp"] <= end_time):
results.append(data)
return results
async def get_time_series(
self,
metric_name: str,
start_time: datetime,
end_time: datetime,
interval: str
) -> List[Dict[str, Any]]:
"""Get time series data"""
return await self.query_metrics(metric_name, start_time, end_time)
async def store_aggregated_metric(self, metric: AggregatedMetric):
"""Store aggregated metric"""
self.data_store.append({
"name": f"agg.{metric.metric_name}",
"value": metric.value,
"timestamp": metric.end_time,
"tags": {"aggregation": metric.aggregation_type},
"service": "statistics"
})
async def count_metrics(
self,
metric_type: str,
start_time: datetime,
end_time: datetime
) -> int:
"""Count metrics"""
metrics = await self.query_metrics(metric_type, start_time, end_time)
return len(metrics)
async def get_average(
self,
metric_name: str,
start_time: datetime,
end_time: datetime
) -> Optional[float]:
"""Get average value"""
metrics = await self.query_metrics(metric_name, start_time, end_time)
if not metrics:
return None
values = [m["value"] for m in metrics]
return sum(values) / len(values)
async def count_distinct_tags(
self,
metric_type: str,
tag_name: str,
start_time: datetime,
end_time: datetime
) -> int:
"""Count distinct tag values"""
metrics = await self.query_metrics(metric_type, start_time, end_time)
unique_values = set()
for metric in metrics:
if tag_name in metric.get("tags", {}):
unique_values.add(metric["tags"][tag_name])
return len(unique_values)
async def get_top_metrics(
self,
metric_type: str,
group_by: str,
start_time: datetime,
end_time: datetime,
limit: int = 10
) -> List[Dict[str, Any]]:
"""Get top metrics grouped by tag"""
metrics = await self.query_metrics(metric_type, start_time, end_time)
grouped = {}
for metric in metrics:
key = metric.get("tags", {}).get(group_by, "unknown")
grouped[key] = grouped.get(key, 0) + 1
sorted_items = sorted(grouped.items(), key=lambda x: x[1], reverse=True)
return [{"name": k, "count": v} for k, v in sorted_items[:limit]]
async def count_metrics_with_value(
self,
metric_name: str,
value: float,
start_time: datetime,
end_time: datetime
) -> int:
"""Count metrics with specific value"""
metrics = await self.query_metrics(metric_name, start_time, end_time)
return sum(1 for m in metrics if m["value"] == value)
async def get_metric_distribution(
self,
metric_type: str,
tag_name: str,
start_time: datetime,
end_time: datetime
) -> Dict[str, int]:
"""Get metric distribution by tag"""
metrics = await self.query_metrics(metric_type, start_time, end_time)
distribution = {}
for metric in metrics:
key = metric.get("tags", {}).get(tag_name, "unknown")
distribution[key] = distribution.get(key, 0) + 1
return distribution
async def delete_old_data(self, cutoff_date: datetime):
"""Delete old data"""
self.data_store = [
d for d in self.data_store
if d["timestamp"] >= cutoff_date
]

View File

@ -0,0 +1,33 @@
"""WebSocket Manager for real-time updates"""
from typing import List
from fastapi import WebSocket
import logging
logger = logging.getLogger(__name__)
class WebSocketManager:
"""Manages WebSocket connections"""
def __init__(self):
self.active_connections: List[WebSocket] = []
async def connect(self, websocket: WebSocket):
"""Accept WebSocket connection"""
await websocket.accept()
self.active_connections.append(websocket)
logger.info(f"WebSocket connected. Total connections: {len(self.active_connections)}")
def disconnect(self, websocket: WebSocket):
"""Remove WebSocket connection"""
if websocket in self.active_connections:
self.active_connections.remove(websocket)
logger.info(f"WebSocket disconnected. Total connections: {len(self.active_connections)}")
async def broadcast(self, message: dict):
"""Broadcast message to all connected clients"""
for connection in self.active_connections:
try:
await connection.send_json(message)
except Exception as e:
logger.error(f"Error broadcasting to WebSocket: {e}")
self.disconnect(connection)

View File

@ -0,0 +1,21 @@
FROM python:3.11-slim
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y \
gcc \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements first for better caching
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY . .
# Expose port
EXPOSE 8000
# Run the application
CMD ["python", "main.py"]

View File

@ -0,0 +1,22 @@
from motor.motor_asyncio import AsyncIOMotorClient
from beanie import init_beanie
import os
from models import User
async def init_db():
"""Initialize database connection"""
# Get MongoDB URL from environment
mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
db_name = os.getenv("DB_NAME", "users_db")
# Create Motor client
client = AsyncIOMotorClient(mongodb_url)
# Initialize beanie with the User model
await init_beanie(
database=client[db_name],
document_models=[User]
)
print(f"Connected to MongoDB: {mongodb_url}/{db_name}")

View File

@ -0,0 +1,334 @@
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional
from datetime import datetime
import uvicorn
import os
import sys
import logging
from contextlib import asynccontextmanager
from database import init_db
from models import User
from beanie import PydanticObjectId
sys.path.append('/app')
from shared.kafka import KafkaProducer, Event, EventType
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Pydantic models for requests
class UserCreate(BaseModel):
username: str
email: str
full_name: Optional[str] = None
profile_picture: Optional[str] = None
bio: Optional[str] = None
location: Optional[str] = None
website: Optional[str] = None
class UserUpdate(BaseModel):
username: Optional[str] = None
email: Optional[str] = None
full_name: Optional[str] = None
profile_picture: Optional[str] = None
profile_picture_thumbnail: Optional[str] = None
bio: Optional[str] = None
location: Optional[str] = None
website: Optional[str] = None
is_email_verified: Optional[bool] = None
is_active: Optional[bool] = None
class UserResponse(BaseModel):
id: str
username: str
email: str
full_name: Optional[str] = None
profile_picture: Optional[str] = None
profile_picture_thumbnail: Optional[str] = None
bio: Optional[str] = None
location: Optional[str] = None
website: Optional[str] = None
is_email_verified: bool
is_active: bool
created_at: datetime
updated_at: datetime
class UserPublicResponse(BaseModel):
"""공개 프로필용 응답 (민감한 정보 제외)"""
id: str
username: str
full_name: Optional[str] = None
profile_picture: Optional[str] = None
profile_picture_thumbnail: Optional[str] = None
bio: Optional[str] = None
location: Optional[str] = None
website: Optional[str] = None
created_at: datetime
# Global Kafka producer
kafka_producer: Optional[KafkaProducer] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
global kafka_producer
await init_db()
# Initialize Kafka producer
try:
kafka_producer = KafkaProducer(
bootstrap_servers=os.getenv('KAFKA_BOOTSTRAP_SERVERS', 'kafka:9092')
)
await kafka_producer.start()
logger.info("Kafka producer initialized")
except Exception as e:
logger.warning(f"Failed to initialize Kafka producer: {e}")
kafka_producer = None
yield
# Shutdown
if kafka_producer:
await kafka_producer.stop()
app = FastAPI(
title="Users Service",
description="User management microservice with MongoDB",
version="0.2.0",
lifespan=lifespan
)
# CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Health check
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"service": "users",
"timestamp": datetime.now().isoformat()
}
# CRUD Operations
@app.get("/users", response_model=List[UserResponse])
async def get_users():
users = await User.find_all().to_list()
return [UserResponse(
id=str(user.id),
username=user.username,
email=user.email,
full_name=user.full_name,
profile_picture=user.profile_picture,
profile_picture_thumbnail=user.profile_picture_thumbnail,
bio=user.bio,
location=user.location,
website=user.website,
is_email_verified=user.is_email_verified,
is_active=user.is_active,
created_at=user.created_at,
updated_at=user.updated_at
) for user in users]
@app.get("/users/{user_id}", response_model=UserResponse)
async def get_user(user_id: str):
try:
user = await User.get(PydanticObjectId(user_id))
if not user:
raise HTTPException(status_code=404, detail="User not found")
return UserResponse(
id=str(user.id),
username=user.username,
email=user.email,
full_name=user.full_name,
profile_picture=user.profile_picture,
profile_picture_thumbnail=user.profile_picture_thumbnail,
bio=user.bio,
location=user.location,
website=user.website,
is_email_verified=user.is_email_verified,
is_active=user.is_active,
created_at=user.created_at,
updated_at=user.updated_at
)
except Exception:
raise HTTPException(status_code=404, detail="User not found")
@app.post("/users", response_model=UserResponse, status_code=201)
async def create_user(user_data: UserCreate):
# Check if username already exists
existing_user = await User.find_one(User.username == user_data.username)
if existing_user:
raise HTTPException(status_code=400, detail="Username already exists")
# Create new user
user = User(
username=user_data.username,
email=user_data.email,
full_name=user_data.full_name,
profile_picture=user_data.profile_picture,
bio=user_data.bio,
location=user_data.location,
website=user_data.website
)
await user.create()
# Publish event
if kafka_producer:
event = Event(
event_type=EventType.USER_CREATED,
service="users",
data={
"user_id": str(user.id),
"username": user.username,
"email": user.email
},
user_id=str(user.id)
)
await kafka_producer.send_event("user-events", event)
return UserResponse(
id=str(user.id),
username=user.username,
email=user.email,
full_name=user.full_name,
profile_picture=user.profile_picture,
profile_picture_thumbnail=user.profile_picture_thumbnail,
bio=user.bio,
location=user.location,
website=user.website,
is_email_verified=user.is_email_verified,
is_active=user.is_active,
created_at=user.created_at,
updated_at=user.updated_at
)
@app.put("/users/{user_id}", response_model=UserResponse)
async def update_user(user_id: str, user_update: UserUpdate):
try:
user = await User.get(PydanticObjectId(user_id))
if not user:
raise HTTPException(status_code=404, detail="User not found")
except Exception:
raise HTTPException(status_code=404, detail="User not found")
if user_update.username is not None:
# Check if new username already exists
existing_user = await User.find_one(
User.username == user_update.username,
User.id != user.id
)
if existing_user:
raise HTTPException(status_code=400, detail="Username already exists")
user.username = user_update.username
if user_update.email is not None:
user.email = user_update.email
if user_update.full_name is not None:
user.full_name = user_update.full_name
if user_update.profile_picture is not None:
user.profile_picture = user_update.profile_picture
if user_update.profile_picture_thumbnail is not None:
user.profile_picture_thumbnail = user_update.profile_picture_thumbnail
if user_update.bio is not None:
user.bio = user_update.bio
if user_update.location is not None:
user.location = user_update.location
if user_update.website is not None:
user.website = user_update.website
if user_update.is_email_verified is not None:
user.is_email_verified = user_update.is_email_verified
if user_update.is_active is not None:
user.is_active = user_update.is_active
user.updated_at = datetime.now()
await user.save()
# Publish event
if kafka_producer:
event = Event(
event_type=EventType.USER_UPDATED,
service="users",
data={
"user_id": str(user.id),
"username": user.username,
"email": user.email,
"updated_fields": list(user_update.dict(exclude_unset=True).keys())
},
user_id=str(user.id)
)
await kafka_producer.send_event("user-events", event)
return UserResponse(
id=str(user.id),
username=user.username,
email=user.email,
full_name=user.full_name,
profile_picture=user.profile_picture,
profile_picture_thumbnail=user.profile_picture_thumbnail,
bio=user.bio,
location=user.location,
website=user.website,
is_email_verified=user.is_email_verified,
is_active=user.is_active,
created_at=user.created_at,
updated_at=user.updated_at
)
@app.delete("/users/{user_id}")
async def delete_user(user_id: str):
try:
user = await User.get(PydanticObjectId(user_id))
if not user:
raise HTTPException(status_code=404, detail="User not found")
user_id_str = str(user.id)
username = user.username
await user.delete()
# Publish event
if kafka_producer:
event = Event(
event_type=EventType.USER_DELETED,
service="users",
data={
"user_id": user_id_str,
"username": username
},
user_id=user_id_str
)
await kafka_producer.send_event("user-events", event)
return {"message": "User deleted successfully"}
except Exception:
raise HTTPException(status_code=404, detail="User not found")
if __name__ == "__main__":
uvicorn.run(
"main:app",
host="0.0.0.0",
port=8000,
reload=True
)

View File

@ -0,0 +1,31 @@
from beanie import Document
from pydantic import EmailStr, Field, HttpUrl
from datetime import datetime
from typing import Optional
class User(Document):
username: str = Field(..., unique=True)
email: EmailStr
full_name: Optional[str] = None
profile_picture: Optional[str] = Field(None, description="프로필 사진 URL")
profile_picture_thumbnail: Optional[str] = Field(None, description="프로필 사진 썸네일 URL")
bio: Optional[str] = Field(None, max_length=500, description="자기소개")
location: Optional[str] = Field(None, description="위치")
website: Optional[str] = Field(None, description="개인 웹사이트")
is_email_verified: bool = Field(default=False, description="이메일 인증 여부")
is_active: bool = Field(default=True, description="계정 활성화 상태")
created_at: datetime = Field(default_factory=datetime.now)
updated_at: datetime = Field(default_factory=datetime.now)
class Settings:
collection = "users"
class Config:
json_schema_extra = {
"example": {
"username": "john_doe",
"email": "john@example.com",
"full_name": "John Doe"
}
}

View File

@ -0,0 +1,7 @@
fastapi==0.109.0
uvicorn[standard]==0.27.0
pydantic[email]==2.5.3
pymongo==4.6.1
motor==3.3.2
beanie==1.23.6
aiokafka==0.10.0