Initial commit - cleaned repository
This commit is contained in:
27
services/files/backend/Dockerfile
Normal file
27
services/files/backend/Dockerfile
Normal file
@ -0,0 +1,27 @@
|
||||
FROM python:3.11-slim

WORKDIR /app

# System packages needed for Pillow and file type detection
RUN apt-get update && apt-get install -y \
    gcc \
    libmagic1 \
    libjpeg-dev \
    zlib1g-dev \
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies before copying the source so this layer is
# reused from cache until requirements.txt changes
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Application source
COPY . .

# Cache directory for generated thumbnails
RUN mkdir -p /tmp/thumbnails

# Port the API server listens on
EXPOSE 8000

# Start the API server.
# NOTE(review): --reload restarts the server on source changes and is a
# development option - confirm it is wanted in this image.
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
|
||||
247
services/files/backend/file_processor.py
Normal file
247
services/files/backend/file_processor.py
Normal file
@ -0,0 +1,247 @@
|
||||
"""
|
||||
File Processor for handling file uploads and processing
|
||||
"""
|
||||
import hashlib
|
||||
import mimetypes
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
import logging
|
||||
import uuid
|
||||
from fastapi import UploadFile
|
||||
from models import FileType, FileStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class FileProcessor:
    """Handle uploads end to end: store the bytes, record metadata, build thumbnails.

    The three collaborators are injected so the processor itself holds no
    connection state:

    * ``minio_client`` - object storage (``upload_file``,
      ``generate_presigned_download_url``)
    * ``metadata_manager`` - metadata store (``find_duplicate_files``,
      ``create_file_metadata``)
    * ``thumbnail_generator`` - image thumbnails (``generate_thumbnail``,
      ``generate_multiple_sizes``)
    """

    # MIME types that are classified as documents by exact match.
    _DOCUMENT_TYPES = frozenset({
        'application/pdf', 'application/msword',
        'text/plain', 'text/html', 'text/csv',
    })
    # Office Open XML types only share this prefix (the real values continue
    # with e.g. ".wordprocessingml.document"), so they need a prefix match.
    _DOCUMENT_PREFIXES = ('application/vnd.openxmlformats-officedocument',)
    # MIME types that are classified as archives by exact match.
    _ARCHIVE_TYPES = frozenset({
        'application/zip', 'application/x-rar-compressed',
        'application/x-tar', 'application/gzip',
    })
    # Lifetime of presigned URLs handed out for public files: 30 days.
    _PRESIGNED_URL_TTL = 86400 * 30

    def __init__(self, minio_client, metadata_manager, thumbnail_generator):
        """Store the injected storage, metadata and thumbnail collaborators."""
        self.minio_client = minio_client
        self.metadata_manager = metadata_manager
        self.thumbnail_generator = thumbnail_generator

    def _determine_file_type(self, content_type: str) -> "FileType":
        """Map a MIME content type to a ``FileType`` category.

        Parameters after ``;`` (such as ``charset``) and letter case are
        ignored. Unknown or empty content types map to ``FileType.OTHER``.
        """
        mime = (content_type or '').split(';', 1)[0].strip().lower()

        if mime.startswith('image/'):
            return FileType.IMAGE
        if mime.startswith('video/'):
            return FileType.VIDEO
        if mime.startswith('audio/'):
            return FileType.AUDIO
        if mime in self._DOCUMENT_TYPES or mime.startswith(self._DOCUMENT_PREFIXES):
            return FileType.DOCUMENT
        if mime in self._ARCHIVE_TYPES:
            return FileType.ARCHIVE
        return FileType.OTHER

    def _calculate_file_hash(self, file_data: bytes) -> str:
        """Return the hex-encoded SHA256 digest of ``file_data``."""
        return hashlib.sha256(file_data).hexdigest()

    @staticmethod
    def _resolve_content_type(file: "UploadFile") -> str:
        """Pick the client-declared type, else guess from the name, else octet-stream."""
        return (
            file.content_type
            or mimetypes.guess_type(file.filename or '')[0]
            or 'application/octet-stream'
        )

    @staticmethod
    def _build_object_name(user_id: str, file_id: str,
                           filename: Optional[str],
                           day: Optional[str] = None) -> str:
        """Build the storage key ``<YYYYMMDD>/<user_id>/<file_id>[.<ext>]``.

        The extension is the text after the last dot of ``filename``. It is
        only appended when it is purely alphanumeric, because the filename is
        client supplied and must not be able to add separators or path
        segments to the key. ``day`` defaults to today's date.
        """
        if day is None:
            day = datetime.now().strftime('%Y%m%d')
        base = f"{day}/{user_id}/{file_id}"

        name = filename or ''
        extension = name.rsplit('.', 1)[-1] if '.' in name else ''
        if extension and extension.isalnum():
            return f"{base}.{extension}"
        return base

    async def _find_reusable_duplicate(self, file_hash: str,
                                       user_id: str) -> Optional[Dict[str, Any]]:
        """Return an existing record with the same hash that this user may reuse.

        Records of other users are skipped so an upload never hands back a
        file the caller does not own, and soft-deleted records are skipped
        because they no longer describe a usable file.
        """
        duplicates = await self.metadata_manager.find_duplicate_files(file_hash)
        for candidate in duplicates or []:
            if (candidate.get("user_id") == user_id
                    and candidate.get("status") != FileStatus.DELETED.value):
                return candidate
        return None

    async def process_upload(self, file: "UploadFile", user_id: str,
                             bucket: str = "default",
                             public: bool = False,
                             generate_thumbnail: bool = True,
                             tags: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Store an uploaded file and return a summary of the stored record.

        Args:
            file: The incoming upload; it is read fully into memory.
            user_id: Owner of the file; becomes part of the object name.
            bucket: Target storage bucket.
            public: When true, presigned download/thumbnail URLs are
                generated and duplicate detection is skipped.
            generate_thumbnail: Build thumbnails when the file is an image.
            tags: Free-form tags stored with the metadata.

        Returns:
            A dict describing the stored file. When a non-public upload
            matches an existing file of the same user, that file's data is
            returned with ``"duplicate": True`` and nothing new is stored.

        Raises:
            Exception: Any error from the collaborators is logged and
                re-raised unchanged.
        """
        try:
            # Read file data
            file_data = await file.read()
            file_size = len(file_data)

            content_type = self._resolve_content_type(file)

            # Generate file ID and object name
            file_id = str(uuid.uuid4())
            object_name = self._build_object_name(user_id, file_id, file.filename)

            file_hash = self._calculate_file_hash(file_data)

            # Duplicates are allowed for public files, so only look them up
            # for private uploads.
            if not public:
                existing = await self._find_reusable_duplicate(file_hash, user_id)
                if existing is not None:
                    logger.info(f"Duplicate file detected: {existing['id']}")
                    return {
                        "file_id": existing["id"],
                        "filename": existing["filename"],
                        "size": existing["size"],
                        "content_type": existing["content_type"],
                        "file_type": existing["file_type"],
                        "bucket": existing["bucket"],
                        "public": existing["public"],
                        "has_thumbnail": existing.get("has_thumbnail", False),
                        "thumbnail_url": existing.get("thumbnail_url"),
                        "created_at": existing["created_at"],
                        "duplicate": True
                    }

            # Upload to MinIO
            upload_result = await self.minio_client.upload_file(
                bucket=bucket,
                object_name=object_name,
                file_data=file_data,
                content_type=content_type,
                metadata={
                    "user_id": user_id,
                    "original_name": file.filename,
                    "upload_date": datetime.now().isoformat()
                }
            )

            file_type = self._determine_file_type(content_type)

            # Generate thumbnail if applicable
            has_thumbnail = False
            thumbnail_url = None

            if generate_thumbnail and file_type == FileType.IMAGE:
                thumbnail_data = await self.thumbnail_generator.generate_thumbnail(
                    file_data=file_data,
                    content_type=content_type
                )

                if thumbnail_data:
                    has_thumbnail = True
                    # Generate multiple sizes
                    await self.thumbnail_generator.generate_multiple_sizes(
                        file_data=file_data,
                        content_type=content_type,
                        file_id=file_id
                    )

                    if public:
                        thumbnail_url = await self.minio_client.generate_presigned_download_url(
                            bucket="thumbnails",
                            object_name=f"thumbnails/{file_id}_medium.jpg",
                            expires_in=self._PRESIGNED_URL_TTL
                        )

            # Create metadata
            metadata = {
                "id": file_id,
                "filename": file.filename,
                "original_name": file.filename,
                "size": file_size,
                "content_type": content_type,
                "file_type": file_type.value,
                "bucket": bucket,
                "object_name": object_name,
                "user_id": user_id,
                "hash": file_hash,
                "public": public,
                "has_thumbnail": has_thumbnail,
                "thumbnail_url": thumbnail_url,
                "tags": tags or {},
                "metadata": {
                    "etag": upload_result.get("etag"),
                    "version_id": upload_result.get("version_id")
                }
            }

            # Save metadata to database
            await self.metadata_manager.create_file_metadata(metadata)

            # Generate download URL if public
            download_url = None
            if public:
                download_url = await self.minio_client.generate_presigned_download_url(
                    bucket=bucket,
                    object_name=object_name,
                    expires_in=self._PRESIGNED_URL_TTL
                )

            logger.info(f"File uploaded successfully: {file_id}")

            return {
                "file_id": file_id,
                "filename": file.filename,
                "size": file_size,
                "content_type": content_type,
                "file_type": file_type.value,
                "bucket": bucket,
                "public": public,
                "has_thumbnail": has_thumbnail,
                "thumbnail_url": thumbnail_url,
                "download_url": download_url,
                "created_at": datetime.now()
            }

        except Exception as e:
            # logger.exception keeps the traceback alongside the message
            logger.exception(f"File processing error: {e}")
            raise

    async def process_large_file(self, file: "UploadFile", user_id: str,
                                 bucket: str = "default",
                                 chunk_size: int = 1024 * 1024 * 5) -> Dict[str, Any]:
        """Read an upload in ``chunk_size`` pieces, hash it incrementally and store it.

        The chunks are still joined into one bytes object before the upload
        because ``upload_file`` is called with the complete ``file_data``;
        reading in chunks only keeps the hashing incremental.

        Args:
            file: The incoming upload.
            user_id: Owner of the file; becomes part of the object name.
            bucket: Target storage bucket.
            chunk_size: Bytes to read per iteration; must be positive.

        Returns:
            A dict with ``file_id``, ``filename``, ``size`` and ``message``.

        Raises:
            ValueError: If ``chunk_size`` is not positive.
            Exception: Any collaborator error is logged and re-raised.
        """
        try:
            if chunk_size <= 0:
                # read(0) returns nothing and read(-1) reads everything at
                # once; neither is a chunked read.
                raise ValueError("chunk_size must be a positive number of bytes")

            file_id = str(uuid.uuid4())
            # Same naming rule as process_upload: no trailing dot when the
            # filename has no extension.
            object_name = self._build_object_name(user_id, file_id, file.filename)

            hasher = hashlib.sha256()
            total_size = 0

            # Process file in chunks
            chunks = []
            while chunk := await file.read(chunk_size):
                chunks.append(chunk)
                hasher.update(chunk)
                total_size += len(chunk)

            # Combine chunks and upload
            file_data = b''.join(chunks)
            file_hash = hasher.hexdigest()

            # Upload to MinIO
            content_type = self._resolve_content_type(file)
            await self.minio_client.upload_file(
                bucket=bucket,
                object_name=object_name,
                file_data=file_data,
                content_type=content_type
            )

            # Create metadata
            metadata = {
                "id": file_id,
                "filename": file.filename,
                "original_name": file.filename,
                "size": total_size,
                "content_type": content_type,
                "file_type": self._determine_file_type(content_type).value,
                "bucket": bucket,
                "object_name": object_name,
                "user_id": user_id,
                "hash": file_hash,
                "public": False,
                "has_thumbnail": False
            }

            await self.metadata_manager.create_file_metadata(metadata)

            return {
                "file_id": file_id,
                "filename": file.filename,
                "size": total_size,
                "message": "Large file uploaded successfully"
            }

        except Exception as e:
            logger.exception(f"Large file processing error: {e}")
            raise
|
||||
541
services/files/backend/main.py
Normal file
541
services/files/backend/main.py
Normal file
@ -0,0 +1,541 @@
|
||||
"""
|
||||
File Management Service - S3-compatible Object Storage with MinIO
|
||||
"""
|
||||
from fastapi import FastAPI, File, UploadFile, HTTPException, Depends, Query, Form
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse, FileResponse
|
||||
import uvicorn
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict, Any
|
||||
import asyncio
|
||||
import os
|
||||
import hashlib
|
||||
import magic
|
||||
import io
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
# Import custom modules
|
||||
from models import FileMetadata, FileUploadResponse, FileListResponse, StorageStats
|
||||
from minio_client import MinIOManager
|
||||
from thumbnail_generator import ThumbnailGenerator
|
||||
from metadata_manager import MetadataManager
|
||||
from file_processor import FileProcessor
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Shared service objects. They stay None until the lifespan startup hook
# replaces them with real instances.
minio_manager = thumbnail_generator = metadata_manager = file_processor = None
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared service objects on startup and release them on shutdown.

    Startup builds, in order, the MinIO manager, the MongoDB metadata
    manager, the thumbnail generator and the file processor, and publishes
    them through the module-level globals. Any startup failure is logged and
    re-raised so the application does not start half-initialised.

    Shutdown runs in a ``finally`` block so the MongoDB connection is closed
    even when the application exits through an exception.
    """
    global minio_manager, thumbnail_generator, metadata_manager, file_processor

    try:
        # Initialize MinIO client
        # NOTE(review): the fallback credentials are the literal
        # "minioadmin" values - confirm real deployments always override them.
        minio_manager = MinIOManager(
            endpoint=os.getenv("MINIO_ENDPOINT", "minio:9000"),
            access_key=os.getenv("MINIO_ACCESS_KEY", "minioadmin"),
            secret_key=os.getenv("MINIO_SECRET_KEY", "minioadmin"),
            secure=os.getenv("MINIO_SECURE", "false").lower() == "true"
        )
        await minio_manager.initialize()
        logger.info("MinIO client initialized")

        # Initialize Metadata Manager (MongoDB)
        metadata_manager = MetadataManager(
            mongodb_url=os.getenv("MONGODB_URL", "mongodb://mongodb:27017"),
            database=os.getenv("FILES_DB_NAME", "files_db")
        )
        await metadata_manager.connect()
        logger.info("Metadata manager connected to MongoDB")

        # Initialize Thumbnail Generator
        thumbnail_generator = ThumbnailGenerator(
            minio_client=minio_manager,
            cache_dir="/tmp/thumbnails"
        )
        logger.info("Thumbnail generator initialized")

        # Initialize File Processor
        file_processor = FileProcessor(
            minio_client=minio_manager,
            metadata_manager=metadata_manager,
            thumbnail_generator=thumbnail_generator
        )
        logger.info("File processor initialized")

    except Exception as e:
        logger.error(f"Failed to start File service: {e}")
        raise

    try:
        yield
    finally:
        # Shutdown
        if metadata_manager:
            await metadata_manager.close()

        logger.info("File service shutdown complete")
|
||||
|
||||
app = FastAPI(
    title="File Management Service",
    description="S3-compatible object storage with MinIO",
    version="1.0.0",
    lifespan=lifespan
)

# CORS middleware
# A wildcard origin must not be combined with credentials: every website
# could then make credentialed requests to this API. Allowed origins come
# from the comma-separated CORS_ALLOW_ORIGINS variable (default "*"), and
# credentials are only enabled when an explicit origin list is configured.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
@app.get("/")
async def root():
    """Return a small status document identifying the service."""
    now_iso = datetime.now().isoformat()
    return dict(
        service="File Management Service",
        status="running",
        timestamp=now_iso,
    )
|
||||
|
||||
@app.get("/health")
async def health_check():
    """Report the state of each backing component.

    The top-level ``status`` is always "healthy"; the per-component values
    show whether each manager exists and reports itself as connected.
    """
    minio_ok = bool(minio_manager and minio_manager.is_connected)
    mongo_ok = bool(metadata_manager and metadata_manager.is_connected)

    components = {
        "minio": "connected" if minio_ok else "disconnected",
        "mongodb": "connected" if mongo_ok else "disconnected",
        "thumbnail_generator": "ready" if thumbnail_generator else "not_initialized",
    }

    return {
        "status": "healthy",
        "service": "files",
        "components": components,
        "timestamp": datetime.now().isoformat(),
    }
|
||||
|
||||
# File Upload Endpoints
|
||||
@app.post("/api/files/upload")
async def upload_file(
    file: UploadFile = File(...),
    user_id: str = Form(...),
    bucket: str = Form("default"),
    public: bool = Form(False),
    generate_thumbnail: bool = Form(True),
    tags: Optional[str] = Form(None)
):
    """Upload a file to object storage.

    ``tags`` is an optional JSON object passed as a form string.

    Raises:
        HTTPException: 400 when no filename is supplied or ``tags`` is not a
            JSON object; 500 for any processing failure.
    """
    try:
        # Validate file
        if not file.filename:
            raise HTTPException(status_code=400, detail="No file provided")

        # Malformed tags are a client error, not a server failure
        parsed_tags: Dict[str, Any] = {}
        if tags:
            try:
                parsed_tags = json.loads(tags)
            except json.JSONDecodeError as exc:
                raise HTTPException(status_code=400, detail="tags must be valid JSON") from exc
            if not isinstance(parsed_tags, dict):
                raise HTTPException(status_code=400, detail="tags must be a JSON object")

        # Process file upload
        result = await file_processor.process_upload(
            file=file,
            user_id=user_id,
            bucket=bucket,
            public=public,
            generate_thumbnail=generate_thumbnail,
            tags=parsed_tags
        )

        return FileUploadResponse(**result)
    except HTTPException:
        # Keep deliberate 4xx responses instead of rewriting them as 500
        raise
    except Exception as e:
        logger.error(f"File upload error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/api/files/upload-multiple")
async def upload_multiple_files(
    files: List[UploadFile] = File(...),
    user_id: str = Form(...),
    bucket: str = Form("default"),
    public: bool = Form(False)
):
    """Upload several files one after another with thumbnails enabled.

    The first failure aborts the whole request with a 500 response.
    """
    try:
        # Awaited sequentially, in the order the files were submitted
        stored = [
            await file_processor.process_upload(
                file=upload,
                user_id=user_id,
                bucket=bucket,
                public=public,
                generate_thumbnail=True
            )
            for upload in files
        ]
        return {"uploaded": len(stored), "files": stored}
    except Exception as exc:
        logger.error(f"Multiple file upload error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
# File Retrieval Endpoints
|
||||
@app.get("/api/files/{file_id}")
async def get_file(file_id: str):
    """Stream a stored file back to the client as an attachment.

    Raises:
        HTTPException: 404 when no metadata exists for ``file_id``; 500 for
            storage or database failures.
    """
    from urllib.parse import quote

    try:
        # Get metadata
        metadata = await metadata_manager.get_file_metadata(file_id)
        if not metadata:
            raise HTTPException(status_code=404, detail="File not found")

        # Get file from MinIO
        file_stream = await minio_manager.get_file(
            bucket=metadata["bucket"],
            object_name=metadata["object_name"]
        )

        # The stored filename is client supplied. Header values cannot carry
        # quotes, control characters or non-latin-1 text, so send a sanitised
        # ASCII fallback plus the percent-encoded UTF-8 name (RFC 6266).
        filename = str(metadata["filename"])
        ascii_name = "".join(
            ch if ch.isascii() and ch.isprintable() and ch not in '"\\' else "_"
            for ch in filename
        )
        content_disposition = (
            f'attachment; filename="{ascii_name}"; '
            f"filename*=UTF-8''{quote(filename, safe='')}"
        )

        return StreamingResponse(
            file_stream,
            media_type=metadata.get("content_type", "application/octet-stream"),
            headers={"Content-Disposition": content_disposition}
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"File retrieval error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/api/files/{file_id}/metadata")
async def get_file_metadata(file_id: str):
    """Return the stored metadata record of a file, or 404 if there is none."""
    try:
        record = await metadata_manager.get_file_metadata(file_id)
        if record:
            return FileMetadata(**record)
        raise HTTPException(status_code=404, detail="File not found")
    except HTTPException:
        # Pass the 404 through untouched
        raise
    except Exception as exc:
        logger.error(f"Metadata retrieval error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
@app.get("/api/files/{file_id}/thumbnail")
async def get_thumbnail(
    file_id: str,
    width: int = Query(200, ge=50, le=1000),
    height: int = Query(200, ge=50, le=1000)
):
    """Return a JPEG thumbnail of the file at the requested size.

    Responds with 404 when the file is unknown or its metadata does not
    flag a thumbnail.
    """
    try:
        record = await metadata_manager.get_file_metadata(file_id)
        if not record:
            raise HTTPException(status_code=404, detail="File not found")
        if not record.get("has_thumbnail"):
            raise HTTPException(status_code=404, detail="No thumbnail available")

        # The generator is handed the source object's location and the
        # requested dimensions
        image_bytes = await thumbnail_generator.get_thumbnail(
            file_id=file_id,
            bucket=record["bucket"],
            object_name=record["object_name"],
            width=width,
            height=height
        )
        body = io.BytesIO(image_bytes)
        return StreamingResponse(body, media_type="image/jpeg")
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Thumbnail retrieval error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
@app.get("/api/files/{file_id}/download")
async def download_file(file_id: str):
    """Stream a file as a download and record the download in its metadata.

    Raises:
        HTTPException: 404 when no metadata exists for ``file_id``; 500 for
            storage or database failures.
    """
    from urllib.parse import quote

    try:
        # Get metadata
        metadata = await metadata_manager.get_file_metadata(file_id)
        if not metadata:
            raise HTTPException(status_code=404, detail="File not found")

        # Get file from MinIO
        file_stream = await minio_manager.get_file(
            bucket=metadata["bucket"],
            object_name=metadata["object_name"]
        )

        # Count the download only once the object was actually obtained, so
        # failed fetches do not inflate the counter
        await metadata_manager.increment_download_count(file_id)

        # The stored filename is client supplied. Header values cannot carry
        # quotes, control characters or non-latin-1 text, so send a sanitised
        # ASCII fallback plus the percent-encoded UTF-8 name (RFC 6266).
        filename = str(metadata["filename"])
        ascii_name = "".join(
            ch if ch.isascii() and ch.isprintable() and ch not in '"\\' else "_"
            for ch in filename
        )
        content_disposition = (
            f'attachment; filename="{ascii_name}"; '
            f"filename*=UTF-8''{quote(filename, safe='')}"
        )

        return StreamingResponse(
            file_stream,
            media_type=metadata.get("content_type", "application/octet-stream"),
            headers={
                "Content-Disposition": content_disposition,
                "Content-Length": str(metadata["size"])
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"File download error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# File Management Endpoints
|
||||
@app.delete("/api/files/{file_id}")
async def delete_file(file_id: str, user_id: str):
    """Delete a file's object, its thumbnail (if flagged) and its metadata.

    Only the user recorded as owner in the metadata may delete the file;
    anyone else receives 403.
    """
    try:
        record = await metadata_manager.get_file_metadata(file_id)
        if not record:
            raise HTTPException(status_code=404, detail="File not found")

        owner = record["user_id"]
        if owner != user_id:
            raise HTTPException(status_code=403, detail="Permission denied")

        # Stored object first, then the thumbnail, then the metadata record
        await minio_manager.delete_file(
            bucket=record["bucket"],
            object_name=record["object_name"]
        )
        if record.get("has_thumbnail"):
            await thumbnail_generator.delete_thumbnail(file_id)
        await metadata_manager.delete_file_metadata(file_id)

        return {"status": "deleted", "file_id": file_id}
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"File deletion error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
@app.patch("/api/files/{file_id}")
async def update_file_metadata(
    file_id: str,
    user_id: str,
    updates: Dict[str, Any]
):
    """Apply ``updates`` to a file's metadata on behalf of its owner.

    Responds with 404 for an unknown file and 403 when ``user_id`` does not
    match the recorded owner.
    """
    try:
        current = await metadata_manager.get_file_metadata(file_id)
        if not current:
            raise HTTPException(status_code=404, detail="File not found")
        if current["user_id"] != user_id:
            raise HTTPException(status_code=403, detail="Permission denied")

        # NOTE(review): the update keys are passed through unfiltered, so a
        # caller can also overwrite fields such as "user_id", "bucket" or
        # "object_name" - confirm whether that is intended.
        refreshed = await metadata_manager.update_file_metadata(file_id, updates)

        return {"status": "updated", "file_id": file_id, "metadata": refreshed}
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Metadata update error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
# File Listing Endpoints
|
||||
@app.get("/api/files")
async def list_files(
    user_id: Optional[str] = None,
    bucket: str = Query("default"),
    # limit=0 would be passed to the cursor as "no limit" and a negative
    # offset is rejected by the driver, so both are bounded here
    limit: int = Query(20, ge=1, le=100),
    offset: int = Query(0, ge=0),
    search: Optional[str] = None,
    file_type: Optional[str] = None,
    sort_by: str = Query("created_at", pattern="^(created_at|filename|size)$"),
    order: str = Query("desc", pattern="^(asc|desc)$")
):
    """List files with filtering, sorting and pagination.

    All filters are optional except ``bucket``, which defaults to
    "default". Out-of-range ``limit``/``offset`` values are rejected by
    request validation instead of reaching the database.
    """
    try:
        files = await metadata_manager.list_files(
            user_id=user_id,
            bucket=bucket,
            limit=limit,
            offset=offset,
            search=search,
            file_type=file_type,
            sort_by=sort_by,
            order=order
        )

        return FileListResponse(**files)
    except Exception as e:
        logger.error(f"File listing error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/api/files/user/{user_id}")
async def get_user_files(
    user_id: str,
    # limit=0 would be passed to the cursor as "no limit" and a negative
    # offset is rejected by the driver, so both are bounded here
    limit: int = Query(20, ge=1, le=100),
    offset: int = Query(0, ge=0)
):
    """List one user's files across all buckets, paginated."""
    try:
        files = await metadata_manager.list_files(
            user_id=user_id,
            limit=limit,
            offset=offset
        )

        return FileListResponse(**files)
    except Exception as e:
        logger.error(f"User files listing error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# Storage Management Endpoints
|
||||
@app.get("/api/storage/stats")
async def get_storage_stats():
    """Combine bucket information from MinIO with file statistics from the database."""
    try:
        object_store_stats = await minio_manager.get_storage_stats()
        database_stats = await metadata_manager.get_storage_stats()

        # Only "buckets" comes from the object store; every other figure is
        # taken from the metadata database
        summary = StorageStats(
            total_files=database_stats["total_files"],
            total_size=database_stats["total_size"],
            buckets=object_store_stats["buckets"],
            users_count=database_stats["users_count"],
            file_types=database_stats["file_types"]
        )
        return summary
    except Exception as exc:
        logger.error(f"Storage stats error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
@app.post("/api/storage/buckets")
async def create_bucket(bucket_name: str, public: bool = False):
    """Create a storage bucket, forwarding the ``public`` flag to the MinIO manager."""
    try:
        await minio_manager.create_bucket(bucket_name, public=public)
    except Exception as exc:
        logger.error(f"Bucket creation error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
    return {"status": "created", "bucket": bucket_name}
|
||||
|
||||
@app.get("/api/storage/buckets")
async def list_buckets():
    """Return every bucket reported by the MinIO manager."""
    try:
        bucket_list = await minio_manager.list_buckets()
    except Exception as exc:
        logger.error(f"Bucket listing error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
    return {"buckets": bucket_list}
|
||||
|
||||
# Presigned URL Endpoints
|
||||
@app.post("/api/files/presigned-upload")
async def generate_presigned_upload_url(
    filename: str,
    content_type: str,
    bucket: str = "default",
    expires_in: int = Query(3600, ge=60, le=86400)
):
    """Create a presigned URL the client can PUT a file to directly.

    The object is named ``<YYYYMMDD>/<filename>``.

    NOTE(review): ``content_type`` is accepted but not used, and the
    filename goes into the object name unchanged - confirm both are intended.
    """
    try:
        day_prefix = datetime.now().strftime('%Y%m%d')
        target_object = f"{day_prefix}/{filename}"

        upload_url = await minio_manager.generate_presigned_upload_url(
            bucket=bucket,
            object_name=target_object,
            expires_in=expires_in
        )

        return {"upload_url": upload_url, "expires_in": expires_in, "method": "PUT"}
    except Exception as exc:
        logger.error(f"Presigned URL generation error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
@app.get("/api/files/{file_id}/share")
async def generate_share_link(
    file_id: str,
    expires_in: int = Query(86400, ge=60, le=604800)  # default 1 day, at most 7 days
):
    """Return a time-limited presigned download URL for a file.

    Responds with 404 when the file has no metadata record.
    """
    try:
        record = await metadata_manager.get_file_metadata(file_id)
        if not record:
            raise HTTPException(status_code=404, detail="File not found")

        share_url = await minio_manager.generate_presigned_download_url(
            bucket=record["bucket"],
            object_name=record["object_name"],
            expires_in=expires_in
        )

        response = {
            "share_url": share_url,
            "expires_in": expires_in,
            "file_id": file_id,
            "filename": record["filename"],
        }
        return response
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Share link generation error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
# Batch Operations
|
||||
@app.post("/api/files/batch-delete")
async def batch_delete_files(file_ids: List[str], user_id: str):
    """Delete several files owned by ``user_id``.

    Each file is handled independently: a failure or a missing/foreign file
    is reported in ``errors`` and does not stop the remaining deletions.

    Returns:
        A dict with the deleted ids, the per-file errors and the count of
        deleted files.
    """
    try:
        deleted = []
        errors = []

        for file_id in file_ids:
            try:
                # Get metadata
                metadata = await metadata_manager.get_file_metadata(file_id)
                if not metadata or metadata["user_id"] != user_id:
                    errors.append({"file_id": file_id, "error": "Not found or permission denied"})
                    continue

                # Delete from MinIO
                await minio_manager.delete_file(
                    bucket=metadata["bucket"],
                    object_name=metadata["object_name"]
                )

                # Same cleanup as the single-file delete endpoint, so batch
                # deletes do not leave thumbnails behind
                if metadata.get("has_thumbnail"):
                    await thumbnail_generator.delete_thumbnail(file_id)

                # Delete metadata
                await metadata_manager.delete_file_metadata(file_id)
                deleted.append(file_id)
            except Exception as e:
                errors.append({"file_id": file_id, "error": str(e)})

        return {
            "deleted": deleted,
            "errors": errors,
            "total_deleted": len(deleted)
        }
    except Exception as e:
        logger.error(f"Batch delete error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
if __name__ == "__main__":
    # Direct-execution entry point: serve on all interfaces, port 8000,
    # with uvicorn's reload option switched on
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("main:app", **server_options)
|
||||
331
services/files/backend/metadata_manager.py
Normal file
331
services/files/backend/metadata_manager.py
Normal file
@ -0,0 +1,331 @@
|
||||
"""
|
||||
Metadata Manager for file information storage in MongoDB
|
||||
"""
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
import logging
|
||||
import uuid
|
||||
from models import FileType, FileStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MetadataManager:
|
||||
def __init__(self, mongodb_url: str, database: str = "files_db"):
    """Remember the connection settings; nothing is opened until connect()."""
    self.mongodb_url, self.database_name = mongodb_url, database

    # Handles assigned by connect()
    self.client = self.db = self.collection = None
    self.is_connected = False
|
||||
|
||||
async def connect(self):
    """Open the MongoDB client, verify the server responds and create indexes.

    Sets ``is_connected`` to True on success. On any failure the error is
    logged, ``is_connected`` is set to False and the exception is re-raised.
    """
    try:
        self.client = AsyncIOMotorClient(self.mongodb_url)
        self.db = self.client[self.database_name]
        self.collection = self.db.files

        # Test connection first: index creation swallows its own errors, so
        # pinging beforehand makes an unreachable server fail right here
        await self.client.admin.command('ping')

        # Create indexes
        await self._create_indexes()

        self.is_connected = True
        # The URL may embed credentials, so only the database name is logged
        logger.info(f"Connected to MongoDB database '{self.database_name}'")

    except Exception as e:
        logger.error(f"Failed to connect to MongoDB: {e}")
        self.is_connected = False
        raise
|
||||
|
||||
async def _create_indexes(self):
    """Create the indexes used by the listing and search queries.

    Failures are logged and not raised; creation stops at the first index
    that fails.
    """
    index_specs = (
        "user_id",
        "bucket",
        "created_at",
        "file_type",
        [("filename", "text")],                  # text index on filename
        [("user_id", 1), ("created_at", -1)],    # compound index
    )
    try:
        for spec in index_specs:
            await self.collection.create_index(spec)
        logger.info("Database indexes created")
    except Exception as exc:
        logger.error(f"Failed to create indexes: {exc}")
|
||||
|
||||
async def create_file_metadata(self, metadata: Dict[str, Any]) -> str:
    """Insert a new file metadata document and return its ``id``.

    The stored document is a copy of ``metadata`` with ``created_at``,
    ``updated_at``, ``download_count`` and ``status`` set here (overriding
    any values supplied for those keys) and a generated ``id`` when none is
    given. The caller's dict is left untouched.

    Raises:
        Exception: Insert failures are logged and re-raised.
    """
    try:
        # One timestamp so created_at and updated_at start out equal
        now = datetime.now()

        # Work on a copy: insert_one adds an "_id" key to the dict it is given
        document = dict(metadata)
        document.update(
            created_at=now,
            updated_at=now,
            download_count=0,
            status=FileStatus.READY.value,
        )

        # Generate unique ID if not provided
        if "id" not in document:
            document["id"] = str(uuid.uuid4())

        await self.collection.insert_one(document)

        logger.info(f"Created metadata for file: {document['id']}")
        return document["id"]

    except Exception as e:
        logger.error(f"Failed to create file metadata: {e}")
        raise
|
||||
|
||||
async def get_file_metadata(self, file_id: str) -> Optional[Dict[str, Any]]:
    """Look up one metadata document by its ``id`` field.

    Returns the document without MongoDB's ``_id`` key, or the lookup's
    empty result when nothing matches. Lookup errors are logged and
    re-raised.

    NOTE(review): the query does not filter on ``status``, so documents
    marked as deleted are returned as well - confirm that is intended.
    """
    try:
        document = await self.collection.find_one({"id": file_id})
        if not document:
            return document

        # Drop the database-internal identifier
        document.pop("_id", None)
        return document
    except Exception as exc:
        logger.error(f"Failed to get file metadata: {exc}")
        raise
|
||||
|
||||
async def update_file_metadata(self, file_id: str, updates: Dict[str, Any]) -> Dict[str, Any]:
    """Apply ``updates`` to a file's metadata and return the refreshed document.

    ``updated_at`` is set to the current time. The caller's ``updates``
    dict is not modified.

    Raises:
        LookupError: When no document has this ``id``.
        Exception: Database errors are logged and re-raised.
    """
    try:
        # Merge into a new dict so the timestamp does not leak into the
        # caller's data
        changes = {**updates, "updated_at": datetime.now()}

        result = await self.collection.update_one(
            {"id": file_id},
            {"$set": changes}
        )

        # matched_count tells whether the document exists; modified_count
        # only tells whether any stored value actually changed
        if result.matched_count == 0:
            raise LookupError(f"File {file_id} not found")

        # Return updated metadata
        return await self.get_file_metadata(file_id)

    except Exception as e:
        logger.error(f"Failed to update file metadata: {e}")
        raise
|
||||
|
||||
async def delete_file_metadata(self, file_id: str) -> bool:
    """Soft-delete a file's metadata.

    The document stays in the collection; its status is set to deleted and
    the deletion/update timestamps are recorded. Returns True when a
    document was modified. Database errors are logged and re-raised.
    """
    try:
        deletion_marker = {
            "status": FileStatus.DELETED.value,
            "deleted_at": datetime.now(),
            "updated_at": datetime.now(),
        }
        outcome = await self.collection.update_one(
            {"id": file_id},
            {"$set": deletion_marker},
        )
        return outcome.modified_count > 0
    except Exception as exc:
        logger.error(f"Failed to delete file metadata: {exc}")
        raise
|
||||
|
||||
async def list_files(self, user_id: Optional[str] = None,
                     bucket: Optional[str] = None,
                     limit: int = 20,
                     offset: int = 0,
                     search: Optional[str] = None,
                     file_type: Optional[str] = None,
                     sort_by: str = "created_at",
                     order: str = "desc") -> Dict[str, Any]:
    """Return one page of non-deleted file metadata plus paging details.

    Args:
        user_id: Keep only files with this ``user_id`` (ignored when falsy).
        bucket: Keep only files in this bucket (ignored when falsy).
        limit: Page size passed to the cursor.
        offset: Number of documents to skip.
        search: Text passed to a ``$text`` query (ignored when falsy).
        file_type: Keep only files of this type (ignored when falsy).
        sort_by: Field to sort on.
        order: ``"desc"`` sorts descending; anything else ascending.

    Returns:
        Dict with ``files``, ``total``, ``limit``, ``offset`` and ``has_more``.

    Raises:
        Exception: Any database error, after it has been logged.
    """
    try:
        # Deleted files are always excluded; the rest is opt-in.
        query = {"status": {"$ne": FileStatus.DELETED.value}}
        for field, wanted in (("user_id", user_id),
                              ("bucket", bucket),
                              ("file_type", file_type)):
            if wanted:
                query[field] = wanted
        if search:
            query["$text"] = {"$search": search}

        total = await self.collection.count_documents(query)

        direction = 1 if order != "desc" else -1
        cursor = self.collection.find(query)
        cursor = cursor.sort(sort_by, direction)
        cursor = cursor.skip(offset)
        cursor = cursor.limit(limit)

        files = []
        async for document in cursor:
            document.pop("_id", None)
            files.append(document)

        page = {
            "files": files,
            "total": total,
            "limit": limit,
            "offset": offset,
            "has_more": (offset + limit) < total,
        }
        return page

    except Exception as e:
        logger.error(f"Failed to list files: {e}")
        raise
|
||||
|
||||
async def increment_download_count(self, file_id: str):
    """Add one to a file's ``download_count`` and stamp ``last_accessed``.

    Best-effort bookkeeping: a failure is logged and swallowed, so this
    never raises to the caller.
    """
    try:
        selector = {"id": file_id}
        change = {
            "$inc": {"download_count": 1},
            "$set": {"last_accessed": datetime.now()},
        }
        await self.collection.update_one(selector, change)

    except Exception as e:
        logger.error(f"Failed to increment download count: {e}")
|
||||
|
||||
async def get_storage_stats(self) -> Dict[str, Any]:
    """Aggregate totals over all non-deleted file documents.

    Returns:
        Dict with ``total_files``, ``total_size`` (sum of ``size``),
        ``users_count`` (distinct ``user_id`` values) and ``file_types``
        (document count per truthy ``file_type``).

    Raises:
        Exception: Any database error, after it has been logged.
    """
    try:
        deleted = FileStatus.DELETED.value

        # Overall totals: a single group over every live document.
        totals_cursor = self.collection.aggregate([
            {"$match": {"status": {"$ne": deleted}}},
            {
                "$group": {
                    "_id": None,
                    "total_files": {"$sum": 1},
                    "total_size": {"$sum": "$size"},
                    "users": {"$addToSet": "$user_id"},
                }
            },
        ])
        totals_rows = await totals_cursor.to_list(length=1)

        if not totals_rows:
            totals = {"total_files": 0, "total_size": 0}
            users_count = 0
        else:
            totals = totals_rows[0]
            users_count = len(totals.get("users", []))

        # Distribution of documents per file type.
        types_cursor = self.collection.aggregate([
            {"$match": {"status": {"$ne": deleted}}},
            {
                "$group": {
                    "_id": "$file_type",
                    "count": {"$sum": 1},
                }
            },
        ])
        type_rows = await types_cursor.to_list(length=None)

        file_types = {}
        for row in type_rows:
            # Rows grouped under a missing/empty file type are left out.
            if row["_id"]:
                file_types[row["_id"]] = row["count"]

        return {
            "total_files": totals.get("total_files", 0),
            "total_size": totals.get("total_size", 0),
            "users_count": users_count,
            "file_types": file_types,
        }

    except Exception as e:
        logger.error(f"Failed to get storage stats: {e}")
        raise
|
||||
|
||||
async def find_duplicate_files(self, file_hash: str) -> List[Dict[str, Any]]:
    """Return every non-deleted metadata document whose ``hash`` matches.

    Each returned dict has MongoDB's ``_id`` key removed.

    Raises:
        Exception: Any database error, after it has been logged.
    """
    try:
        criteria = {
            "hash": file_hash,
            "status": {"$ne": FileStatus.DELETED.value},
        }

        duplicates = []
        async for document in self.collection.find(criteria):
            document.pop("_id", None)
            duplicates.append(document)
        return duplicates

    except Exception as e:
        logger.error(f"Failed to find duplicate files: {e}")
        raise
|
||||
|
||||
async def get_user_storage_usage(self, user_id: str) -> Dict[str, Any]:
    """Summarise one user's non-deleted files, grouped by file type.

    Returns:
        Dict with ``user_id``, ``total_files``, ``total_size`` and a
        ``breakdown`` mapping each truthy file type to its ``count`` and
        ``size``. Totals cover every group, including ones without a type.

    Raises:
        Exception: Any database error, after it has been logged.
    """
    try:
        only_this_user = {
            "$match": {
                "user_id": user_id,
                "status": {"$ne": FileStatus.DELETED.value},
            }
        }
        per_type = {
            "$group": {
                "_id": "$file_type",
                "count": {"$sum": 1},
                "size": {"$sum": "$size"},
            }
        }

        rows = await self.collection.aggregate([only_this_user, per_type]).to_list(length=None)

        total_size = sum(row["size"] for row in rows)
        total_files = sum(row["count"] for row in rows)

        breakdown = {}
        for row in rows:
            if not row["_id"]:
                continue
            breakdown[row["_id"]] = {"count": row["count"], "size": row["size"]}

        return {
            "user_id": user_id,
            "total_files": total_files,
            "total_size": total_size,
            "breakdown": breakdown,
        }

    except Exception as e:
        logger.error(f"Failed to get user storage usage: {e}")
        raise
|
||||
|
||||
async def close(self):
    """Close the MongoDB client, if there is one, and mark us disconnected."""
    if not self.client:
        return

    self.client.close()
    self.is_connected = False
    logger.info("MongoDB connection closed")
|
||||
333
services/files/backend/minio_client.py
Normal file
333
services/files/backend/minio_client.py
Normal file
@ -0,0 +1,333 @@
|
||||
"""
|
||||
MinIO Client for S3-compatible object storage
|
||||
"""
|
||||
from minio import Minio
|
||||
from minio.error import S3Error
|
||||
import asyncio
|
||||
import io
|
||||
from typing import Optional, Dict, Any, List
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MinIOManager:
    """Asynchronous facade over the blocking MinIO (S3-compatible) client.

    Every SDK call is handed to the event loop's default executor so that
    coroutines awaiting these methods do not block the loop.
    """

    def __init__(self, endpoint: str, access_key: str, secret_key: str, secure: bool = False):
        """Record connection settings; the client is built in ``initialize``."""
        self.endpoint = endpoint
        self.access_key = access_key
        self.secret_key = secret_key
        self.secure = secure
        self.client = None
        self.is_connected = False

    async def _run(self, func, *args):
        """Run blocking ``func(*args)`` in the default executor and await it."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, func, *args)

    async def initialize(self):
        """Create the MinIO client and make sure the default buckets exist.

        Raises:
            Exception: Whatever client construction or bucket creation
                raised, after logging it and clearing ``is_connected``.
        """
        try:
            self.client = Minio(
                self.endpoint,
                access_key=self.access_key,
                secret_key=self.secret_key,
                secure=self.secure
            )

            # Create default buckets; only "public" gets an anonymous-read policy.
            default_buckets = ["default", "public", "thumbnails", "temp"]
            for bucket in default_buckets:
                await self.create_bucket(bucket, public=(bucket == "public"))

            self.is_connected = True
            logger.info(f"Connected to MinIO at {self.endpoint}")

        except Exception as e:
            logger.error(f"Failed to initialize MinIO: {e}")
            self.is_connected = False
            raise

    async def create_bucket(self, bucket_name: str, public: bool = False):
        """Create ``bucket_name`` if it does not exist yet.

        Args:
            bucket_name: Name of the bucket.
            public: When true, (re)apply a policy allowing anonymous
                ``s3:GetObject`` on every object in the bucket.

        Raises:
            Exception: Any SDK error, after it has been logged.
        """
        try:
            exists = await self._run(self.client.bucket_exists, bucket_name)

            if not exists:
                await self._run(self.client.make_bucket, bucket_name)
                logger.info(f"Created bucket: {bucket_name}")

            # Set bucket policy if public
            if public:
                policy = {
                    "Version": "2012-10-17",
                    "Statement": [
                        {
                            "Effect": "Allow",
                            "Principal": {"AWS": ["*"]},
                            "Action": ["s3:GetObject"],
                            "Resource": [f"arn:aws:s3:::{bucket_name}/*"]
                        }
                    ]
                }
                import json
                await self._run(
                    self.client.set_bucket_policy,
                    bucket_name,
                    json.dumps(policy)
                )
                logger.info(f"Set public policy for bucket: {bucket_name}")

        except Exception as e:
            logger.error(f"Failed to create bucket {bucket_name}: {e}")
            raise

    async def upload_file(self, bucket: str, object_name: str, file_data: bytes,
                          content_type: str = "application/octet-stream",
                          metadata: Optional[Dict[str, str]] = None):
        """Upload ``file_data`` as ``object_name`` in ``bucket``.

        Returns:
            Dict with ``bucket``, ``object_name``, ``etag`` and ``version_id``
            taken from the SDK's write result.

        Raises:
            Exception: Any SDK error, after it has been logged.
        """
        try:
            # put_object wants a readable stream plus its exact length.
            file_stream = io.BytesIO(file_data)
            length = len(file_data)

            result = await self._run(
                self.client.put_object,
                bucket,
                object_name,
                file_stream,
                length,
                content_type,
                metadata
            )

            logger.info(f"Uploaded {object_name} to {bucket}")
            return {
                "bucket": bucket,
                "object_name": object_name,
                "etag": result.etag,
                "version_id": result.version_id
            }

        except Exception as e:
            logger.error(f"Failed to upload file: {e}")
            raise

    async def get_file(self, bucket: str, object_name: str) -> io.BytesIO:
        """Download an object fully into memory.

        Returns:
            A ``BytesIO`` positioned at the start of the object's bytes.

        Raises:
            Exception: Any SDK or read error, after it has been logged.
        """
        def _download() -> bytes:
            # Reading the body is blocking network I/O too, so it happens in
            # the worker thread; the connection is always handed back to the
            # pool, even when the read fails.
            response = self.client.get_object(bucket, object_name)
            try:
                return response.read()
            finally:
                response.close()
                response.release_conn()

        try:
            data = await self._run(_download)
            return io.BytesIO(data)

        except Exception as e:
            logger.error(f"Failed to get file: {e}")
            raise

    async def delete_file(self, bucket: str, object_name: str):
        """Remove ``object_name`` from ``bucket``.

        Raises:
            Exception: Any SDK error, after it has been logged.
        """
        try:
            await self._run(self.client.remove_object, bucket, object_name)
            logger.info(f"Deleted {object_name} from {bucket}")

        except Exception as e:
            logger.error(f"Failed to delete file: {e}")
            raise

    async def list_files(self, bucket: str, prefix: Optional[str] = None,
                         recursive: bool = True) -> List[Dict[str, Any]]:
        """List objects in ``bucket``, optionally filtered by ``prefix``.

        Returns:
            One dict per object with ``name``, ``size``, ``last_modified``,
            ``etag`` and ``content_type``.

        Raises:
            Exception: Any SDK error, after it has been logged.
        """
        try:
            # list_objects is a lazy iterator; drain it inside the worker
            # thread so the network round-trips stay off the event loop.
            objects = await self._run(
                lambda: list(self.client.list_objects(
                    bucket,
                    prefix=prefix,
                    recursive=recursive
                ))
            )

            return [
                {
                    "name": obj.object_name,
                    "size": obj.size,
                    "last_modified": obj.last_modified,
                    "etag": obj.etag,
                    "content_type": obj.content_type
                }
                for obj in objects
            ]

        except Exception as e:
            logger.error(f"Failed to list files: {e}")
            raise

    async def get_file_info(self, bucket: str, object_name: str) -> Dict[str, Any]:
        """Return ``size``, ``etag``, ``content_type``, ``last_modified`` and
        ``metadata`` of an object as reported by ``stat_object``.

        Raises:
            Exception: Any SDK error, after it has been logged.
        """
        try:
            stat = await self._run(self.client.stat_object, bucket, object_name)

            return {
                "size": stat.size,
                "etag": stat.etag,
                "content_type": stat.content_type,
                "last_modified": stat.last_modified,
                "metadata": stat.metadata
            }

        except Exception as e:
            logger.error(f"Failed to get file info: {e}")
            raise

    async def generate_presigned_download_url(self, bucket: str, object_name: str,
                                              expires_in: int = 3600) -> str:
        """Return a presigned GET URL valid for ``expires_in`` seconds.

        Raises:
            Exception: Any SDK error, after it has been logged.
        """
        try:
            return await self._run(
                self.client.presigned_get_object,
                bucket,
                object_name,
                timedelta(seconds=expires_in)
            )

        except Exception as e:
            logger.error(f"Failed to generate presigned URL: {e}")
            raise

    async def generate_presigned_upload_url(self, bucket: str, object_name: str,
                                            expires_in: int = 3600) -> str:
        """Return a presigned PUT URL valid for ``expires_in`` seconds.

        Raises:
            Exception: Any SDK error, after it has been logged.
        """
        try:
            return await self._run(
                self.client.presigned_put_object,
                bucket,
                object_name,
                timedelta(seconds=expires_in)
            )

        except Exception as e:
            logger.error(f"Failed to generate presigned upload URL: {e}")
            raise

    async def copy_file(self, source_bucket: str, source_object: str,
                        dest_bucket: str, dest_object: str):
        """Copy an object server-side from source to destination.

        Raises:
            Exception: Any SDK error, after it has been logged.
        """
        try:
            # minio 7.x rejects a "/bucket/object" string here; the source
            # must be described by a CopySource instance.
            from minio.commonconfig import CopySource

            await self._run(
                self.client.copy_object,
                dest_bucket,
                dest_object,
                CopySource(source_bucket, source_object)
            )

            logger.info(f"Copied {source_object} to {dest_object}")

        except Exception as e:
            logger.error(f"Failed to copy file: {e}")
            raise

    async def list_buckets(self) -> List[str]:
        """Return the names of all buckets.

        Raises:
            Exception: Any SDK error, after it has been logged.
        """
        try:
            buckets = await self._run(self.client.list_buckets)
            return [bucket.name for bucket in buckets]

        except Exception as e:
            logger.error(f"Failed to list buckets: {e}")
            raise

    async def get_storage_stats(self) -> Dict[str, Any]:
        """Return bucket names, their count, and per-bucket object count/size.

        Sizes are computed by listing every object of every bucket.

        Raises:
            Exception: Any SDK error, after it has been logged.
        """
        try:
            buckets = await self.list_buckets()

            stats = {
                "buckets": buckets,
                "bucket_count": len(buckets),
                "bucket_stats": {}
            }

            for bucket in buckets:
                files = await self.list_files(bucket)
                stats["bucket_stats"][bucket] = {
                    "file_count": len(files),
                    "total_size": sum(f["size"] for f in files)
                }

            return stats

        except Exception as e:
            logger.error(f"Failed to get storage stats: {e}")
            raise

    async def check_file_exists(self, bucket: str, object_name: str) -> bool:
        """Return True when ``stat_object`` succeeds for the object.

        Any ordinary error (missing object, connectivity, ...) yields False.
        Task cancellation and interpreter-exit signals are not swallowed.
        """
        try:
            await self.get_file_info(bucket, object_name)
        except Exception:
            return False
        return True
|
||||
112
services/files/backend/models.py
Normal file
112
services/files/backend/models.py
Normal file
@ -0,0 +1,112 @@
|
||||
"""
|
||||
Data models for File Management Service
|
||||
"""
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from enum import Enum
|
||||
|
||||
class FileType(str, Enum):
    """Broad file categories.

    The ``str`` mixin makes each member compare equal to its string value.
    """

    IMAGE = "image"
    VIDEO = "video"
    AUDIO = "audio"
    DOCUMENT = "document"
    ARCHIVE = "archive"
    OTHER = "other"
|
||||
|
||||
class FileStatus(str, Enum):
    """Status values a file record can carry.

    The ``str`` mixin makes each member compare equal to its string value.
    """

    PENDING = "pending"
    PROCESSING = "processing"
    READY = "ready"
    ERROR = "error"
    DELETED = "deleted"
|
||||
|
||||
class FileMetadata(BaseModel):
    """Metadata record describing one stored file."""

    # Identity and naming
    id: str
    filename: str
    original_name: str

    # Content description
    size: int
    content_type: str
    file_type: FileType

    # Storage location
    bucket: str
    object_name: str

    # Ownership and content hash
    user_id: str
    hash: str

    # State flags
    status: FileStatus = FileStatus.READY
    public: bool = False
    has_thumbnail: bool = False
    thumbnail_url: Optional[str] = None

    # Free-form annotations
    tags: Dict[str, Any] = {}
    metadata: Dict[str, Any] = {}

    # Usage counter and timestamps
    download_count: int = 0
    created_at: datetime
    updated_at: datetime
    deleted_at: Optional[datetime] = None
|
||||
|
||||
class FileUploadResponse(BaseModel):
    """Response payload describing a file that was just uploaded."""

    file_id: str
    filename: str
    size: int
    content_type: str
    file_type: FileType
    bucket: str
    public: bool
    has_thumbnail: bool

    # Optional links
    thumbnail_url: Optional[str] = None
    download_url: Optional[str] = None

    created_at: datetime
    message: str = "File uploaded successfully"
|
||||
|
||||
class FileListResponse(BaseModel):
    """One page of file metadata together with its paging details."""

    files: List[FileMetadata]
    total: int
    limit: int
    offset: int
    has_more: bool
|
||||
|
||||
class StorageStats(BaseModel):
    """Aggregated storage figures."""

    total_files: int
    total_size: int
    buckets: List[str]
    users_count: int
    # File count keyed by file type name
    file_types: Dict[str, int]
    storage_used_percentage: Optional[float] = None
|
||||
|
||||
class ThumbnailRequest(BaseModel):
    """Parameters of a thumbnail request, with validated bounds."""

    file_id: str
    # Width and height each default to 200 and must lie in 50..1000.
    width: int = Field(default=200, ge=50, le=1000)
    height: int = Field(default=200, ge=50, le=1000)
    # Quality defaults to 85 and must lie in 50..100.
    quality: int = Field(default=85, ge=50, le=100)
    # One of "jpeg", "png" or "webp".
    format: str = Field(default="jpeg", pattern="^(jpeg|png|webp)$")
|
||||
|
||||
class PresignedUrlResponse(BaseModel):
    """A presigned URL plus the details needed to use it."""

    url: str
    expires_in: int
    method: str
    headers: Optional[Dict[str, str]] = None
|
||||
|
||||
class BatchOperationResult(BaseModel):
    """Outcome summary of an operation applied to several items."""

    successful: List[str]
    failed: List[Dict[str, str]]
    total_processed: int
    total_successful: int
    total_failed: int
|
||||
|
||||
class FileShareLink(BaseModel):
    """A share link for a file, with its creation and expiry times."""

    share_url: str
    expires_in: int
    file_id: str
    filename: str
    created_at: datetime
    expires_at: datetime
|
||||
|
||||
class FileProcessingJob(BaseModel):
    """State of a processing job attached to a file."""

    job_id: str
    file_id: str
    job_type: str  # thumbnail, compress, convert, etc.
    status: str  # pending, processing, completed, failed

    # Optional details about the job's progress and outcome
    progress: Optional[float] = None
    result: Optional[Dict[str, Any]] = None
    error: Optional[str] = None

    created_at: datetime
    completed_at: Optional[datetime] = None
|
||||
11
services/files/backend/requirements.txt
Normal file
11
services/files/backend/requirements.txt
Normal file
@ -0,0 +1,11 @@
|
||||
fastapi==0.109.0
|
||||
uvicorn[standard]==0.27.0
|
||||
pydantic==2.5.3
|
||||
python-dotenv==1.0.0
|
||||
motor==3.5.1
|
||||
pymongo==4.6.1
|
||||
minio==7.2.3
|
||||
pillow==10.2.0
|
||||
python-magic==0.4.27
|
||||
aiofiles==23.2.1
|
||||
python-multipart==0.0.6
|
||||
281
services/files/backend/test_files.py
Executable file
281
services/files/backend/test_files.py
Executable file
@ -0,0 +1,281 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for File Management Service
|
||||
"""
|
||||
import asyncio
|
||||
import httpx
|
||||
import os
|
||||
import json
|
||||
from datetime import datetime
|
||||
import base64
|
||||
|
||||
# Base URL of the service under test. Set FILES_SERVICE_URL to point the
# script at a service that is not published on localhost:8014.
BASE_URL = os.environ.get("FILES_SERVICE_URL", "http://localhost:8014")

# Sample image for testing (1x1 pixel PNG)
TEST_IMAGE_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
TEST_IMAGE_DATA = base64.b64decode(TEST_IMAGE_BASE64)
|
||||
|
||||
async def test_file_api():
    """Walk through the service's HTTP endpoints and print each outcome.

    Steps 4-7 and 12 run only when the upload in step 2 returned a file id.
    Nothing is asserted; every result is reported with ``print``.
    """
    async with httpx.AsyncClient() as client:
        print("\n📁 Testing File Management Service API...")

        # 1. Health check
        print("\n1. Testing health check...")
        reply = await client.get(f"{BASE_URL}/health")
        print(f"Health check: {reply.json()}")

        # 2. Single upload; the returned id drives the per-file steps below
        print("\n2. Testing file upload...")
        upload_files = {
            'file': ('test_image.png', TEST_IMAGE_DATA, 'image/png')
        }
        upload_form = {
            'user_id': 'test_user_123',
            'bucket': 'default',
            'public': 'false',
            'generate_thumbnail': 'true',
            'tags': json.dumps({"test": "true", "category": "sample"})
        }
        reply = await client.post(
            f"{BASE_URL}/api/files/upload",
            files=upload_files,
            data=upload_form
        )

        if reply.status_code != 200:
            print(f"Upload failed: {reply.status_code} - {reply.text}")
            file_id = None
        else:
            upload_result = reply.json()
            print(f"File uploaded: {upload_result}")
            file_id = upload_result.get("file_id")

        # 3. Multi-file upload
        print("\n3. Testing multiple file upload...")
        batch_files = [
            ('files', (name, TEST_IMAGE_DATA, 'image/png'))
            for name in ('test1.png', 'test2.png', 'test3.png')
        ]
        batch_form = {
            'user_id': 'test_user_123',
            'bucket': 'default',
            'public': 'false'
        }
        reply = await client.post(
            f"{BASE_URL}/api/files/upload-multiple",
            files=batch_files,
            data=batch_form
        )

        if reply.status_code != 200:
            print(f"Multiple upload failed: {reply.status_code}")
        else:
            print(f"Multiple files uploaded: {reply.json()}")

        if file_id:
            # 4. Metadata retrieval
            print("\n4. Testing file metadata retrieval...")
            reply = await client.get(f"{BASE_URL}/api/files/{file_id}/metadata")
            if reply.status_code != 200:
                print(f"Metadata retrieval failed: {reply.status_code}")
            else:
                print(f"File metadata: {reply.json()}")

            # 5. Thumbnail retrieval
            print("\n5. Testing thumbnail retrieval...")
            reply = await client.get(
                f"{BASE_URL}/api/files/{file_id}/thumbnail",
                params={"width": 100, "height": 100}
            )
            if reply.status_code != 200:
                print(f"Thumbnail retrieval failed: {reply.status_code}")
            else:
                print(f"Thumbnail retrieved: {len(reply.content)} bytes")

            # 6. Download
            print("\n6. Testing file download...")
            reply = await client.get(f"{BASE_URL}/api/files/{file_id}/download")
            if reply.status_code != 200:
                print(f"Download failed: {reply.status_code}")
            else:
                print(f"File downloaded: {len(reply.content)} bytes")

            # 7. Share link
            print("\n7. Testing share link generation...")
            reply = await client.get(
                f"{BASE_URL}/api/files/{file_id}/share",
                params={"expires_in": 3600}
            )
            if reply.status_code != 200:
                print(f"Share link generation failed: {reply.status_code}")
            else:
                share_result = reply.json()
                print(f"Share link generated: {share_result.get('share_url', 'N/A')[:50]}...")

        # 8. Listing
        print("\n8. Testing file listing...")
        listing_params = {
            "user_id": "test_user_123",
            "limit": 10,
            "offset": 0
        }
        reply = await client.get(f"{BASE_URL}/api/files", params=listing_params)
        if reply.status_code != 200:
            print(f"File listing failed: {reply.status_code}")
        else:
            files_list = reply.json()
            print(f"Files found: {files_list.get('total', 0)} files")

        # 9. Storage statistics
        print("\n9. Testing storage statistics...")
        reply = await client.get(f"{BASE_URL}/api/storage/stats")
        if reply.status_code != 200:
            print(f"Storage stats failed: {reply.status_code}")
        else:
            stats = reply.json()
            print(f"Storage stats: {stats}")

        # 10. Bucket creation and listing
        print("\n10. Testing bucket operations...")
        reply = await client.post(
            f"{BASE_URL}/api/storage/buckets",
            params={"bucket_name": "test-bucket", "public": False}
        )
        if reply.status_code != 200:
            print(f"Bucket creation: {reply.status_code}")
        else:
            print(f"Bucket created: {reply.json()}")

        reply = await client.get(f"{BASE_URL}/api/storage/buckets")
        if reply.status_code != 200:
            print(f"Bucket listing failed: {reply.status_code}")
        else:
            print(f"Available buckets: {reply.json()}")

        # 11. Presigned upload URL
        print("\n11. Testing presigned URL generation...")
        presign_params = {
            "filename": "test_upload.txt",
            "content_type": "text/plain",
            "bucket": "default",
            "expires_in": 3600
        }
        reply = await client.post(
            f"{BASE_URL}/api/files/presigned-upload",
            params=presign_params
        )
        if reply.status_code != 200:
            print(f"Presigned URL generation failed: {reply.status_code}")
        else:
            presigned = reply.json()
            print(f"Presigned upload URL generated: {presigned.get('upload_url', 'N/A')[:50]}...")

        # 12. Deletion of the file uploaded in step 2
        if file_id:
            print("\n12. Testing file deletion...")
            reply = await client.delete(
                f"{BASE_URL}/api/files/{file_id}",
                params={"user_id": "test_user_123"}
            )
            if reply.status_code != 200:
                print(f"File deletion failed: {reply.status_code}")
            else:
                print(f"File deleted: {reply.json()}")
|
||||
|
||||
async def test_large_file_upload():
    """Upload a 1 MiB binary payload and print the outcome."""
    print("\n\n📦 Testing Large File Upload...")

    # 1 MiB of filler bytes
    payload = b"x" * (1024 * 1024)

    async with httpx.AsyncClient(timeout=30.0) as client:
        upload_files = {
            'file': ('large_test.bin', payload, 'application/octet-stream')
        }
        upload_form = {
            'user_id': 'test_user_123',
            'bucket': 'default',
            'public': 'false'
        }

        print("Uploading 1MB file...")
        reply = await client.post(
            f"{BASE_URL}/api/files/upload",
            files=upload_files,
            data=upload_form
        )

        if reply.status_code != 200:
            print(f"Large file upload failed: {reply.status_code}")
        else:
            result = reply.json()
            print(f"Large file uploaded successfully: {result.get('file_id')}")
            print(f"File size: {result.get('size')} bytes")
|
||||
|
||||
async def test_duplicate_detection():
    """Upload the same image twice and report whether the second response
    carries a truthy ``duplicate`` field.

    The second upload is attempted only after the first one succeeded.
    """
    print("\n\n🔍 Testing Duplicate Detection...")

    async with httpx.AsyncClient() as client:
        # Identical payload for both uploads
        upload_files = {
            'file': ('duplicate_test.png', TEST_IMAGE_DATA, 'image/png')
        }
        upload_form = {
            'user_id': 'test_user_123',
            'bucket': 'default',
            'public': 'false'
        }
        upload_url = f"{BASE_URL}/api/files/upload"

        print("Uploading file first time...")
        first = await client.post(upload_url, files=upload_files, data=upload_form)
        if first.status_code != 200:
            return

        print(f"First upload: {first.json().get('file_id')}")

        print("Uploading same file again...")
        second = await client.post(upload_url, files=upload_files, data=upload_form)
        if second.status_code != 200:
            return

        second_result = second.json()
        print(f"Second upload: {second_result.get('file_id')}")

        if second_result.get('duplicate'):
            print("✅ Duplicate detected successfully!")
        else:
            print("❌ Duplicate not detected")
|
||||
|
||||
async def main():
    """Run every test scenario in order, framed by a start/finish banner."""
    rule = "=" * 60

    print(rule)
    print("FILE MANAGEMENT SERVICE TEST SUITE")
    print(rule)
    print(f"Started at: {datetime.now().isoformat()}")

    # Scenarios run sequentially against the same service.
    await test_file_api()
    await test_large_file_upload()
    await test_duplicate_detection()

    print("\n" + rule)
    print("✅ All tests completed!")
    print(f"Finished at: {datetime.now().isoformat()}")
    print(rule)


if __name__ == "__main__":
    asyncio.run(main())
|
||||
236
services/files/backend/thumbnail_generator.py
Normal file
236
services/files/backend/thumbnail_generator.py
Normal file
@ -0,0 +1,236 @@
|
||||
"""
|
||||
Thumbnail Generator for image and video files
|
||||
"""
|
||||
from PIL import Image, ImageOps
|
||||
import io
|
||||
import os
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ThumbnailGenerator:
|
||||
def __init__(self, minio_client, cache_dir: str = "/tmp/thumbnails"):
    """Remember the storage client and prepare the local cache directory.

    Args:
        minio_client: Object used to read and write files in object storage.
        cache_dir: Directory for cached thumbnails; created if missing.
    """
    self.minio_client = minio_client

    cache_root = Path(cache_dir)
    cache_root.mkdir(parents=True, exist_ok=True)
    self.cache_dir = cache_root

    # Content types a thumbnail can be generated for
    self.supported_formats = set((
        'image/jpeg', 'image/jpg', 'image/png', 'image/gif',
        'image/webp', 'image/bmp', 'image/tiff',
    ))
|
||||
|
||||
def _get_cache_path(self, file_id: str, width: int, height: int) -> Path:
    """Return the cache file path for a thumbnail of the given size.

    The file name is the md5 hex digest of ``"<file_id>_<width>x<height>"``
    with a ``.jpg`` suffix, placed in a sub-directory named after the
    digest's first two characters.
    """
    digest = hashlib.md5(f"{file_id}_{width}x{height}".encode()).hexdigest()
    shard = digest[:2]
    return self.cache_dir / shard / (digest + ".jpg")
|
||||
|
||||
async def generate_thumbnail(self, file_data: bytes, content_type: str,
                             width: int = 200, height: int = 200) -> Optional[bytes]:
    """Build a thumbnail for ``file_data`` without blocking the event loop.

    Returns:
        The thumbnail bytes, or None when ``content_type`` is not supported
        or thumbnail creation failed (the failure is logged).
    """
    try:
        if content_type in self.supported_formats:
            # Image work is CPU-bound; run it in the default executor.
            return await asyncio.get_event_loop().run_in_executor(
                None,
                self._create_thumbnail,
                file_data,
                width,
                height
            )

        logger.warning(f"Unsupported format for thumbnail: {content_type}")
        return None

    except Exception as e:
        logger.error(f"Failed to generate thumbnail: {e}")
        return None
|
||||
|
||||
def _create_thumbnail(self, file_data: bytes, width: int, height: int) -> bytes:
    """Render ``file_data`` as a JPEG thumbnail that fits in width x height.

    Args:
        file_data: Encoded source image bytes.
        width: Maximum thumbnail width in pixels.
        height: Maximum thumbnail height in pixels.

    Returns:
        JPEG-encoded thumbnail bytes (quality 85).

    Raises:
        Exception: Any decoding or encoding error, after it has been logged.
    """
    try:
        with Image.open(io.BytesIO(file_data)) as source:
            # Apply EXIF orientation first: rotating after the resize could
            # push the result outside the requested box, and flattening
            # creates a new image that no longer carries the tag.
            image = ImageOps.exif_transpose(source)

            if image.mode in ('RGBA', 'LA', 'P'):
                # Flatten transparency onto a white background, using the
                # alpha channel as the paste mask for every alpha-capable mode.
                rgba = image.convert('RGBA')
                background = Image.new('RGB', rgba.size, (255, 255, 255))
                background.paste(rgba, mask=rgba.getchannel('A'))
                image = background
            elif image.mode not in ('RGB', 'L'):
                image = image.convert('RGB')

            # Shrink in place, keeping the aspect ratio
            image.thumbnail((width, height), Image.Resampling.LANCZOS)

            output = io.BytesIO()
            image.save(output, format='JPEG', quality=85, optimize=True)
            return output.getvalue()

    except Exception as e:
        logger.error(f"Thumbnail creation failed: {e}")
        raise
|
||||
|
||||
async def get_thumbnail(self, file_id: str, bucket: str, object_name: str,
                        width: int = 200, height: int = 200) -> Optional[bytes]:
    """Return a thumbnail for a stored file, generating it when needed.

    Lookup order: local cache, then the ``thumbnails`` bucket, then a fresh
    thumbnail generated from the original object (which is uploaded to the
    ``thumbnails`` bucket and cached).

    Returns:
        Thumbnail bytes, or None when no thumbnail could be produced
        (failures are logged).
    """
    try:
        # Check cache first
        cache_path = self._get_cache_path(file_id, width, height)

        if cache_path.exists():
            logger.info(f"Thumbnail found in cache: {cache_path}")
            with open(cache_path, 'rb') as f:
                return f.read()

        # Check if thumbnail exists in MinIO
        thumbnail_object = f"thumbnails/{file_id}_{width}x{height}.jpg"
        try:
            thumbnail_stream = await self.minio_client.get_file(
                bucket="thumbnails",
                object_name=thumbnail_object
            )
            thumbnail_data = thumbnail_stream.read()

            # Save to cache
            await self._save_to_cache(cache_path, thumbnail_data)

            return thumbnail_data
        except Exception:
            # No stored thumbnail (or it could not be read): fall through
            # and generate one. Catching Exception rather than everything
            # lets task cancellation propagate.
            pass

        # Get original file
        file_stream = await self.minio_client.get_file(bucket, object_name)
        file_data = file_stream.read()

        # Get file info for content type
        file_info = await self.minio_client.get_file_info(bucket, object_name)
        content_type = file_info.get("content_type", "")

        # Generate thumbnail
        thumbnail_data = await self.generate_thumbnail(
            file_data, content_type, width, height
        )

        if thumbnail_data:
            # Save to MinIO
            await self.minio_client.upload_file(
                bucket="thumbnails",
                object_name=thumbnail_object,
                file_data=thumbnail_data,
                content_type="image/jpeg"
            )

            # Save to cache
            await self._save_to_cache(cache_path, thumbnail_data)

        return thumbnail_data

    except Exception as e:
        logger.error(f"Failed to get thumbnail: {e}")
        return None
|
||||
|
||||
async def _save_to_cache(self, cache_path: Path, data: bytes):
|
||||
"""Save thumbnail to cache"""
|
||||
try:
|
||||
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
lambda: cache_path.write_bytes(data)
|
||||
)
|
||||
|
||||
logger.info(f"Thumbnail saved to cache: {cache_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save to cache: {e}")
|
||||
|
||||
async def delete_thumbnail(self, file_id: str):
    """Delete every cached and stored thumbnail belonging to ``file_id``.

    Removes matching files below ``self.cache_dir`` and all objects in the
    ``thumbnails`` bucket whose name starts with ``thumbnails/{file_id}_``.
    Errors are logged, never raised.

    Args:
        file_id: Identifier of the file whose thumbnails should be removed.
    """
    from glob import escape as glob_escape

    if not file_id:
        # An empty id would turn the patterns below into match-all wildcards.
        logger.error("Failed to delete thumbnails: empty file_id")
        return

    try:
        # Local cache. Escape the id so glob metacharacters in it
        # (``*``, ``?``, ``[``) cannot widen the match.
        for cache_file in self.cache_dir.rglob(f"*{glob_escape(file_id)}*"):
            try:
                cache_file.unlink()
                logger.info(f"Deleted cache file: {cache_file}")
            except OSError as e:
                # Best effort: keep going, but leave a trace instead of
                # silently discarding the failure.
                logger.warning(f"Could not delete cache file {cache_file}: {e}")

        # MinIO: list every size variant and delete each one.
        stored = await self.minio_client.list_files(
            bucket="thumbnails",
            prefix=f"thumbnails/{file_id}_"
        )

        for entry in stored:
            await self.minio_client.delete_file(
                bucket="thumbnails",
                object_name=entry["name"]
            )
            logger.info(f"Deleted thumbnail: {entry['name']}")

    except Exception as e:
        logger.error(f"Failed to delete thumbnails: {e}")
|
||||
|
||||
async def generate_multiple_sizes(self, file_data: bytes, content_type: str,
                                  file_id: str) -> dict:
    """Render and upload the small, medium and large thumbnails of a file.

    Args:
        file_data: Bytes of the original file.
        content_type: Content type handed to ``generate_thumbnail``.
        file_id: Identifier used in the uploaded object names.

    Returns:
        A dict keyed by size name. Each value holds the byte ``size`` of the
        thumbnail, the requested ``dimensions`` and the ``object_name`` it was
        uploaded under. Sizes for which nothing was rendered are omitted.
    """
    presets = (("small", 150), ("medium", 300), ("large", 600))
    generated = {}

    for label, edge in presets:
        rendered = await self.generate_thumbnail(
            file_data, content_type, edge, edge
        )
        if not rendered:
            continue

        target = f"thumbnails/{file_id}_{label}.jpg"
        await self.minio_client.upload_file(
            bucket="thumbnails",
            object_name=target,
            file_data=rendered,
            content_type="image/jpeg"
        )
        generated[label] = {
            "size": len(rendered),
            "dimensions": f"{edge}x{edge}",
            "object_name": target
        }

    return generated
|
||||
|
||||
def clear_cache(self):
    """Remove everything in the thumbnail cache and leave an empty directory.

    Errors are logged, never raised.
    """
    try:
        import shutil

        # rmtree raises when the directory is already gone, which used to
        # skip the mkdir below and leave the cache without a directory.
        if self.cache_dir.exists():
            shutil.rmtree(self.cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        logger.info("Thumbnail cache cleared")
    except Exception as e:
        logger.error(f"Failed to clear cache: {e}")
|
||||
26
services/images/backend/Dockerfile
Normal file
26
services/images/backend/Dockerfile
Normal file
@ -0,0 +1,26 @@
|
||||
FROM python:3.11-slim

WORKDIR /app

# Install system packages
RUN apt-get update && apt-get install -y \
    gcc \
    libde265-dev \
    libheif-dev \
    libjpeg-dev \
    libpng-dev \
    && rm -rf /var/lib/apt/lists/*

# Install Python packages (requirements are copied on their own first)
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code
COPY . .

# Create the cache directory
RUN mkdir -p /app/cache

EXPOSE 8000

CMD ["python", "main.py"]
|
||||
0
services/images/backend/app/__init__.py
Normal file
0
services/images/backend/app/__init__.py
Normal file
197
services/images/backend/app/api/endpoints.py
Normal file
197
services/images/backend/app/api/endpoints.py
Normal file
@ -0,0 +1,197 @@
|
||||
from fastapi import APIRouter, Query, HTTPException, Body
|
||||
from fastapi.responses import Response
|
||||
from typing import Optional, Dict
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
import hashlib
|
||||
|
||||
from ..core.config import settings
|
||||
|
||||
# MinIO 사용 여부에 따라 적절한 캐시 모듈 선택
|
||||
if settings.use_minio:
|
||||
from ..core.minio_cache import cache
|
||||
else:
|
||||
from ..core.cache import cache
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
def _sniff_image_type(data: bytes) -> Optional[str]:
    """Return the MIME type implied by the payload's magic bytes, or None."""
    head = data[:12]
    if head[:2] == b'\xff\xd8':
        return 'image/jpeg'
    if head[:8] == b'\x89PNG\r\n\x1a\n':
        return 'image/png'
    if head[:6] in (b'GIF87a', b'GIF89a'):
        return 'image/gif'
    if head[:4] == b'RIFF' and head[8:12] == b'WEBP':
        return 'image/webp'
    if head[4:12] == b'ftypavif':
        return 'image/avif'
    return None


def _image_response(data: bytes, content_type: str, cache_state: str) -> Response:
    """Build the image response with the browser-cache and debug headers."""
    return Response(
        content=data,
        media_type=content_type,
        headers={
            "Cache-Control": f"public, max-age={86400 * 7}",  # 7-day browser cache
            "X-Cache": cache_state,
            "X-Image-Format": content_type.split('/')[-1].upper()
        }
    )


@router.get("/image")
async def get_image(
    url: str = Query(..., description="원본 이미지 URL"),
    size: Optional[str] = Query(None, description="이미지 크기 (thumb, card, list, detail, hero)")
):
    """
    Image proxy endpoint.

    - Fetches the image behind an external URL and caches it
    - Optionally resizes and optimizes it
    - Serves whatever format the processing step produced

    Raises:
        HTTPException: 400 for a ``size`` that is not configured, 403 when
            the failure message mentions 403, 500 for any other failure.
    """
    # Reject unknown sizes up front. They used to surface as a 500 from a
    # KeyError / ValueError deep inside the processing code.
    if size and size not in settings.thumbnail_sizes:
        raise HTTPException(
            status_code=400,
            detail=f"지원하지 않는 이미지 크기: {size}"
        )

    try:
        cached_data = await cache.get(url, size)

        if cached_data:
            # Trust the stored bytes over the URL: a resized GIF is stored as
            # JPEG, so guessing from the ".gif" suffix mislabelled cache hits.
            content_type = _sniff_image_type(cached_data)
            if content_type is None:
                if url.lower().endswith('.svg') or cache._is_svg(cached_data):
                    content_type = 'image/svg+xml'
                elif url.lower().endswith('.gif'):
                    content_type = 'image/gif'
                elif settings.convert_to_webp and size:
                    content_type = 'image/webp'
                else:
                    content_type = mimetypes.guess_type(url)[0] or 'image/jpeg'

            return _image_response(cached_data, content_type, "HIT")

        # Cache miss: download the original.
        image_data = await cache.download_image(url)

        # First guess from the URL; mimetypes may not know SVG, so it is
        # checked explicitly.
        guessed_type = mimetypes.guess_type(url)[0]
        if url.lower().endswith('.svg') or cache._is_svg(image_data):
            content_type = 'image/svg+xml'
        elif url.lower().endswith('.gif') or (guessed_type and 'gif' in guessed_type.lower()):
            content_type = 'image/gif'
        else:
            content_type = guessed_type or 'image/jpeg'

        # Resize and optimize. SVG is passed through and GIF has its own path.
        if size and content_type != 'image/svg+xml':
            if content_type == 'image/gif':
                image_data, content_type = cache._process_gif(image_data, settings.thumbnail_sizes[size])
            else:
                image_data, content_type = cache.resize_and_optimize_image(image_data, size)

        # Same rule as the cache-hit path, so HIT and MISS agree on the type.
        content_type = _sniff_image_type(image_data) or content_type

        await cache.set(url, image_data, size)

        # Ask the background worker to produce the remaining sizes.
        await cache.trigger_background_generation(url)

        return _image_response(image_data, content_type, "MISS")

    except HTTPException:
        raise
    except Exception as e:
        import traceback
        print(f"Error processing image from {url}: {str(e)}")
        traceback.print_exc()

        # Report upstream "forbidden" failures as 403.
        # NOTE(review): this is a substring test on the message, so an error
        # text that merely contains "403" (e.g. inside the URL) also matches.
        if "403" in str(e):
            raise HTTPException(
                status_code=403,
                detail=f"이미지 접근 거부됨: {url}"
            )

        raise HTTPException(
            status_code=500,
            detail=f"이미지 처리 실패: {str(e)}"
        )
|
||||
|
||||
@router.get("/stats")
async def get_stats():
    """Report cache usage: size, configured limit, usage percent and directory stats."""
    used_gb = await cache.get_cache_size()

    # Directory statistics come from whichever cache backend is active.
    layout = await cache.get_directory_stats()

    limit_gb = settings.max_cache_size_gb
    usage_percent = round((used_gb / limit_gb) * 100, 2)

    return {
        "cache_size_gb": round(used_gb, 2),
        "max_cache_size_gb": limit_gb,
        "cache_usage_percent": usage_percent,
        "directory_stats": layout
    }
|
||||
|
||||
@router.post("/cleanup")
async def cleanup_cache():
    """Run the cache's old-entry cleanup and acknowledge completion."""
    await cache.cleanup_old_cache()

    acknowledgement = {"message": "캐시 정리 완료"}
    return acknowledgement
|
||||
|
||||
@router.post("/cache/delete")
async def delete_cache(request: Dict = Body(...)):
    """Delete every cached variant (original and all sizes) of one URL.

    Args:
        request: JSON body; must contain a non-empty string under ``"url"``.

    Returns:
        A dict with ``status``, a ``message`` stating how many files were
        removed, and the ``url``.

    Raises:
        HTTPException: 400 when ``url`` is missing or not a string, 500 when
            deletion fails.
    """
    url = request.get("url")
    # A non-string value used to crash in ``.encode()`` and surface as a 500.
    if not url or not isinstance(url, str):
        raise HTTPException(status_code=400, detail="URL이 필요합니다")

    try:
        # NOTE(review): only files below settings.cache_dir are removed. When
        # settings.use_minio is enabled, entries held by the MinIO-backed cache
        # are presumably left untouched; confirm.
        url_hash = hashlib.md5(url.encode()).hexdigest()

        # Three-level shard directory built from the first six hex digits.
        shard_dir = settings.cache_dir / url_hash[:2] / url_hash[2:4] / url_hash[4:6]

        # Every variant is named "<hash>", "<hash>.<ext>" or
        # "<hash>_<size>.<ext>", so one prefix glob covers all sizes and all
        # extensions (the fixed extension list missed e.g. svg and avif).
        deleted_count = 0
        for cache_path in shard_dir.glob(f"{url_hash}*"):
            if not cache_path.is_file():
                continue
            # missing_ok: another request may remove the file concurrently.
            cache_path.unlink(missing_ok=True)
            deleted_count += 1
            print(f"✅ 캐시 파일 삭제: {cache_path}")

        return {
            "status": "success",
            "message": f"{deleted_count}개의 캐시 파일이 삭제되었습니다",
            "url": url
        }

    except Exception as e:
        print(f"❌ 캐시 삭제 오류: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"캐시 삭제 실패: {str(e)}"
        )
|
||||
91
services/images/backend/app/core/background_tasks.py
Normal file
91
services/images/backend/app/core/background_tasks.py
Normal file
@ -0,0 +1,91 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Set, Optional
|
||||
from pathlib import Path
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BackgroundTaskManager:
    """Queue-backed background worker that pre-generates resized image variants."""

    def __init__(self):
        # URLs that are queued or currently being processed (de-duplication).
        self.processing_urls: Set[str] = set()
        # Both stay None until start() is called.
        self.task_queue: Optional[asyncio.Queue] = None
        self.worker_task: Optional[asyncio.Task] = None

    async def start(self):
        """Create the bounded queue and launch the worker task."""
        self.task_queue = asyncio.Queue(maxsize=100)
        self.worker_task = asyncio.create_task(self._worker())
        logger.info("백그라운드 작업 관리자 시작됨")

    async def stop(self):
        """Cancel the worker task, if any, and wait for it to finish."""
        if self.worker_task:
            self.worker_task.cancel()
            try:
                await self.worker_task
            except asyncio.CancelledError:
                pass
            logger.info("백그라운드 작업 관리자 정지됨")

    async def add_task(self, url: str):
        """Queue ``url`` for background processing.

        Does nothing when the manager has not been started or the URL is
        already pending. When the queue is full the URL is dropped with a
        warning instead of making the caller wait.
        """
        if self.task_queue is None or url in self.processing_urls:
            return

        try:
            # put_nowait() is what raises QueueFull. ``await put()`` blocks
            # until space frees up, which stalled the request handler and
            # left the QueueFull branch unreachable.
            self.task_queue.put_nowait(url)
        except asyncio.QueueFull:
            logger.warning(f"작업 큐가 가득 참: {url}")
            return

        self.processing_urls.add(url)
        logger.info(f"백그라운드 작업 추가: {url}")

    async def _worker(self):
        """Consume URLs from the queue until the task is cancelled."""
        # NOTE(review): this always imports the filesystem cache module; a
        # MinIO-backed cache is selected elsewhere via settings.use_minio.
        # Confirm that writing to the filesystem cache here is intended.
        from .cache import cache

        while True:
            try:
                url = await self.task_queue.get()

                try:
                    await self._generate_variants(cache, url)
                except Exception as e:
                    logger.error(f"백그라운드 작업 실패: {url} - {str(e)}")
                finally:
                    # Allow the URL to be queued again later.
                    self.processing_urls.discard(url)

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"백그라운드 워커 오류: {str(e)}")
                await asyncio.sleep(1)  # brief pause before retrying

    async def _generate_variants(self, cache, url: str) -> None:
        """Make sure the original and every size variant of ``url`` are cached."""
        original_data = await cache.get(url, None)
        if not original_data:
            original_data = await cache.download_image(url)
            await cache.set(url, original_data, None)

        for size in ('thumb', 'card', 'list', 'detail', 'hero'):
            # Skip variants that are already cached.
            if await cache.get(url, size):
                continue

            try:
                resized_data, _ = cache.resize_and_optimize_image(original_data, size)
                await cache.set(url, resized_data, size)
                logger.info(f"백그라운드 생성 완료: {url} ({size})")
            except Exception as e:
                # One failing size must not prevent the others.
                logger.error(f"백그라운드 리사이징 실패: {url} ({size}) - {str(e)}")
                import traceback
                logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
# 전역 백그라운드 작업 관리자
|
||||
background_manager = BackgroundTaskManager()
|
||||
796
services/images/backend/app/core/cache.py
Normal file
796
services/images/backend/app/core/cache.py
Normal file
@ -0,0 +1,796 @@
|
||||
import hashlib
|
||||
import aiofiles
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
import httpx
|
||||
from PIL import Image
|
||||
try:
|
||||
from pillow_heif import register_heif_opener, register_avif_opener
|
||||
register_heif_opener() # HEIF/HEIC 지원
|
||||
register_avif_opener() # AVIF 지원
|
||||
print("HEIF/AVIF support enabled successfully")
|
||||
except ImportError:
|
||||
print("Warning: pillow_heif not installed, HEIF/AVIF support disabled")
|
||||
import io
|
||||
import asyncio
|
||||
import ssl
|
||||
|
||||
from .config import settings
|
||||
|
||||
class ImageCache:
|
||||
def __init__(self):
    """Use the configured cache directory, creating it if it does not exist."""
    cache_root = settings.cache_dir
    cache_root.mkdir(parents=True, exist_ok=True)
    self.cache_dir = cache_root
|
||||
|
||||
def _get_cache_path(self, url: str, size: Optional[str] = None) -> Path:
    """Return the cache file path for ``url`` (or one of its size variants).

    The file name is the MD5 hex digest of the URL, suffixed with
    ``_<size>`` for variants. Variants get a ``.webp`` extension when WebP
    conversion is enabled; otherwise the extension comes from the URL, if it
    has an allowed one. The parent directories are created as a side effect.
    """
    digest = hashlib.md5(url.encode()).hexdigest()
    stem = f"{digest}_{size}" if size else digest

    if settings.convert_to_webp and size:
        filename = f"{stem}.webp"
    else:
        suffix = self._get_extension_from_url(url)
        filename = f"{stem}.{suffix}" if suffix else stem

    # Shard into three directory levels taken from the first six hex digits,
    # e.g. 10f8a8f96aa1377e86fdbc6bf3c631cf -> 10/f8/a8/
    target = self.cache_dir / digest[:2] / digest[2:4] / digest[4:6] / filename
    target.parent.mkdir(parents=True, exist_ok=True)

    return target
|
||||
|
||||
def _get_extension_from_url(self, url: str) -> Optional[str]:
    """Return the lower-cased file extension of ``url`` if it is an allowed format.

    The query string and the fragment are ignored. Returns ``None`` when the
    URL has no extension or the extension is not in
    ``settings.allowed_formats``.
    """
    # Drop "?query" and "#fragment"; a trailing fragment used to end up in
    # the extension ("jpg#top") and make an otherwise valid URL unrecognised.
    path = url.split('?', 1)[0].split('#', 1)[0]

    _, dot, ext = path.rpartition('.')
    if dot:
        ext = ext.lower()
        if ext in settings.allowed_formats:
            return ext
    return None
|
||||
|
||||
def _is_svg(self, data: bytes) -> bool:
|
||||
"""SVG 파일인지 확인"""
|
||||
# SVG 파일의 시작 부분 확인
|
||||
if len(data) < 100:
|
||||
return False
|
||||
|
||||
# 처음 1000바이트만 확인 (성능 최적화)
|
||||
header = data[:1000].lower()
|
||||
|
||||
# SVG 시그니처 확인
|
||||
svg_signatures = [
|
||||
b'<svg',
|
||||
b'<?xml',
|
||||
b'<!doctype svg'
|
||||
]
|
||||
|
||||
for sig in svg_signatures:
|
||||
if sig in header:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _process_gif(self, gif_data: bytes, target_size: tuple) -> tuple[bytes, str]:
|
||||
"""GIF 처리 - JPEG로 변환하여 안정적으로 처리"""
|
||||
try:
|
||||
from PIL import Image
|
||||
|
||||
# GIF 열기
|
||||
img = Image.open(io.BytesIO(gif_data))
|
||||
|
||||
# 모든 GIF를 RGB로 변환 (팔레트 모드 문제 해결)
|
||||
# 팔레트 모드(P)를 RGB로 직접 변환
|
||||
if img.mode != 'RGB':
|
||||
img = img.convert('RGB')
|
||||
|
||||
# 리사이징
|
||||
img = img.resize(target_size, Image.Resampling.LANCZOS)
|
||||
|
||||
# JPEG로 저장 (안정적)
|
||||
output = io.BytesIO()
|
||||
img.save(output, format='JPEG', quality=85, optimize=True)
|
||||
return output.getvalue(), 'image/jpeg'
|
||||
|
||||
except Exception as e:
|
||||
print(f"GIF 처리 중 오류: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
# 오류 발생 시 원본 반환
|
||||
return gif_data, 'image/gif'
|
||||
|
||||
async def get(self, url: str, size: Optional[str] = None) -> Optional[bytes]:
    """Return the cached bytes for ``url``/``size``, or None on a miss.

    An entry older than ``settings.cache_ttl_days`` is deleted and reported
    as a miss.
    """
    cache_path = self._get_cache_path(url, size)

    # stat() directly instead of exists()+stat(): another request or the
    # cleanup job may delete the file between the two calls.
    try:
        stat = cache_path.stat()
    except FileNotFoundError:
        return None

    age = datetime.now() - datetime.fromtimestamp(stat.st_mtime)
    if age >= timedelta(days=settings.cache_ttl_days):
        # Expired; missing_ok covers a concurrent removal.
        cache_path.unlink(missing_ok=True)
        return None

    try:
        async with aiofiles.open(cache_path, 'rb') as f:
            return await f.read()
    except FileNotFoundError:
        # Removed after the stat() above; treat as a miss.
        return None
|
||||
|
||||
async def set(self, url: str, data: bytes, size: Optional[str] = None):
    """Store ``data`` as the cache entry for ``url``/``size``.

    The bytes are written to a uniquely named temporary file next to the
    final path and then moved into place, so a concurrent ``get`` never
    reads a half-written entry.
    """
    import uuid

    cache_path = self._get_cache_path(url, size)
    # Unique per call: the request handler and the background worker may
    # write the same entry at the same time.
    tmp_path = cache_path.with_name(f"{cache_path.name}.{uuid.uuid4().hex}.tmp")

    try:
        async with aiofiles.open(tmp_path, 'wb') as f:
            await f.write(data)
        # Same directory, so the rename replaces the entry in one step.
        os.replace(tmp_path, cache_path)
    finally:
        # After a successful replace the temporary name no longer exists;
        # after a failure this removes the partial file.
        tmp_path.unlink(missing_ok=True)
|
||||
|
||||
async def download_image(self, url: str) -> bytes:
    """Download the image behind ``url``.

    The request is sent with browser-like headers and a Referer chosen per
    site. If the HTTP request fails for any reason, the image is fetched
    through a headless Chromium (Playwright); if that fails as well, a
    generated placeholder JPEG is returned instead of raising.

    NOTE(review): the placeholder is returned like a real image, so callers
    that cache the result will cache the placeholder too.
    NOTE(security): ``url`` is fetched as given. If it can come from
    untrusted input, scheme/host restrictions belong in front of this call.

    Args:
        url: Absolute URL of the image.

    Returns:
        The image bytes (gzip-decompressed when the body is a gzip stream),
        the bytes captured by Playwright, or a placeholder JPEG.

    Raises:
        ValueError: The downloaded image exceeds ``settings.max_image_size_mb``.
    """
    from urllib.parse import urlparse

    parsed_url = urlparse(url)

    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
        'Accept': 'image/webp,image/apng,image/svg+xml,image/*,*/*;q=0.8',
        'Accept-Language': 'ko-KR,ko;q=0.9,en-US;q=0.8,en;q=0.7',
        'Cache-Control': 'no-cache',
        'Pragma': 'no-cache',
        'Sec-Fetch-Dest': 'image',
        'Sec-Fetch-Mode': 'no-cors',
        'Sec-Fetch-Site': 'cross-site',
        # Default Referer: the root of the image's own site.
        'Referer': f"{parsed_url.scheme}://{parsed_url.netloc}/"
    }

    # Site-specific Referer overrides; the first marker found in the URL wins
    # (same order as the former if/elif chain).
    referer_overrides = (
        ('yna.co.kr', 'https://www.yna.co.kr/'),
        ('investing.com', 'https://www.investing.com/'),
        ('naver.com', 'https://news.naver.com/'),
        ('daum.net', 'https://news.daum.net/'),
        ('chosun.com', 'https://www.chosun.com/'),
        ('vietnam.vn', 'https://vietnam.vn/'),
        ('ddaily.co.kr', 'https://www.ddaily.co.kr/'),
    )
    matched_site = None
    for marker, referer in referer_overrides:
        if marker in url:
            matched_site = marker
            headers['Referer'] = referer
            break

    # ddaily images live under /photos/; repair URLs that omit it.
    if matched_site == 'ddaily.co.kr' and '/2025/' in url and '/photos/' not in url:
        url = url.replace('/2025/', '/photos/2025/')
        print(f"Fixed ddaily URL: {url}")

    # TLS verification is switched off for yna.co.kr only. The decision is
    # based on the real host name: keying it on a substring of the URL let
    # any URL containing "yna.co.kr" (e.g. in its query string) disable
    # certificate checks for an arbitrary host.
    host = (parsed_url.hostname or '').lower()
    if matched_site == 'yna.co.kr' and (host == 'yna.co.kr' or host.endswith('.yna.co.kr')):
        client = httpx.AsyncClient(
            verify=False,
            timeout=30.0,
            follow_redirects=True
        )
    else:
        client = httpx.AsyncClient()

    async with client:
        try:
            response = await client.get(
                url,
                headers=headers,
                timeout=settings.request_timeout,
                follow_redirects=True
            )
            response.raise_for_status()
        except Exception as e:
            # Any failure falls back to Playwright, then to a placeholder.
            error_type = self._describe_download_error(e)
            print(f"{error_type} for {url}, trying with Playwright...")

            try:
                return await self._fetch_with_playwright(url)
            except Exception as pw_error:
                print(f"Playwright failed: {pw_error}, returning placeholder")
                return self._build_placeholder_image()

    # Enforce the size limit on the declared and on the actual length; the
    # header may be absent or wrong.
    max_size = settings.max_image_size_mb * 1024 * 1024
    try:
        content_length = int(response.headers.get('content-length', 0))
    except ValueError:
        content_length = 0

    if content_length > max_size:
        raise ValueError(f"Image too large: {content_length} bytes")

    content = response.content
    if len(content) > max_size:
        raise ValueError(f"Image too large: {len(content)} bytes")

    print(f"Downloaded {len(content)} bytes from {url[:50]}...")

    # A body that still starts with the gzip magic is decompressed here.
    import gzip
    if len(content) > 2 and content[:2] == b'\x1f\x8b':
        print("📦 Gzip compressed data detected, decompressing...")
        try:
            content = gzip.decompress(content)
            print(f"✅ Decompressed to {len(content)} bytes")
        except Exception as e:
            print(f"❌ Failed to decompress gzip: {e}")

    self._log_image_format(content)

    return content

def _describe_download_error(self, error: Exception) -> str:
    """Return a short label for a failed HTTP download, used in log output."""
    error_msg = str(error)
    if isinstance(error, httpx.HTTPStatusError):
        return f"HTTP {error.response.status_code}"
    if isinstance(error, httpx.ConnectError):
        return "Connection Error"
    if isinstance(error, ssl.SSLError):
        return "SSL Error"
    if "resolve" in error_msg.lower() or "dns" in error_msg.lower():
        return "DNS Resolution Error"
    return "Network Error"

def _log_image_format(self, content: bytes) -> None:
    """Print which image format the leading bytes of ``content`` indicate."""
    if len(content) <= 10:
        return

    header = content[:12]
    if header[:2] == b'\xff\xd8':
        print("✅ JPEG image detected")
    elif header[:8] == b'\x89PNG\r\n\x1a\n':
        print("✅ PNG image detected")
    elif header[:6] in (b'GIF87a', b'GIF89a'):
        print("✅ GIF image detected")
    elif header[:4] == b'RIFF' and header[8:12] == b'WEBP':
        print("✅ WebP image detected")
    elif b'<svg' in header or b'<?xml' in header:
        print("✅ SVG image detected")
    elif header[4:12] == b'ftypavif':
        print("✅ AVIF image detected")
    else:
        print(f"⚠️ Unknown image format. Header: {header.hex()}")

async def _fetch_with_playwright(self, url: str) -> bytes:
    """Fetch ``url`` through headless Chromium and return the image bytes.

    Tries, in order: the intercepted network response, an in-page
    ``fetch()``, a screenshot of the ``<img>`` element, and finally a
    full-page screenshot. Raises whatever Playwright raises (including
    ImportError when Playwright is not installed).
    """
    from html import escape
    from urllib.parse import urlparse
    from playwright.async_api import async_playwright

    parsed = urlparse(url)
    referer_url = f"{parsed.scheme}://{parsed.netloc}/"

    async with async_playwright() as p:
        browser = await p.chromium.launch(
            headless=True,
            args=['--no-sandbox', '--disable-setuid-sandbox']
        )
        try:
            context = await browser.new_context(
                viewport={'width': 1920, 'height': 1080},
                user_agent='Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
                extra_http_headers={
                    'Referer': referer_url
                }
            )
            page = await context.new_page()

            image_data = None

            async def handle_response(response):
                nonlocal image_data
                # Capture the body of any response for the image URL.
                if url in response.url:
                    try:
                        image_data = await response.body()
                        print(f"✅ Image intercepted: {len(image_data)} bytes")
                    except Exception:
                        # Body unavailable (e.g. a redirect); the later
                        # fallbacks take over.
                        pass

            page.on('response', handle_response)

            # Navigate straight to the image; on failure embed it in a page.
            try:
                await page.goto(url, wait_until='networkidle', timeout=30000)
            except Exception as goto_error:
                print(f"⚠️ Direct navigation failed: {goto_error}")
                # The URL is escaped so it cannot break out of the attribute.
                await page.set_content(f'''
                    <html>
                    <body style="margin:0;padding:0;">
                    <img src="{escape(url, quote=True)}" style="max-width:100%;height:auto;"
                    crossorigin="anonymous" />
                    </body>
                    </html>
                ''')
                await page.wait_for_timeout(3000)  # give the image time to load

            # Nothing intercepted: fetch from inside the page.
            if not image_data:
                image_data_base64 = await page.evaluate('''
                    async (url) => {
                        try {
                            const response = await fetch(url);
                            const blob = await response.blob();
                            return new Promise((resolve) => {
                                const reader = new FileReader();
                                reader.onloadend = () => resolve(reader.result.split(',')[1]);
                                reader.readAsDataURL(blob);
                            });
                        } catch (e) {
                            return null;
                        }
                    }
                ''', url)

                if image_data_base64:
                    import base64
                    image_data = base64.b64decode(image_data_base64)
                    print(f"✅ Image fetched via JavaScript: {len(image_data)} bytes")

            # Still nothing: fall back to screenshots.
            if not image_data:
                img_element = await page.query_selector('img')
                if img_element:
                    is_loaded = await img_element.evaluate('(img) => img.complete && img.naturalHeight > 0')
                    if is_loaded:
                        image_data = await img_element.screenshot()
                        print(f"✅ Screenshot from loaded image: {len(image_data)} bytes")
                    else:
                        try:
                            # onload never fires for an image that already
                            # failed, so bound the wait instead of hanging.
                            await asyncio.wait_for(
                                img_element.evaluate('(img) => new Promise(r => img.onload = r)'),
                                timeout=10
                            )
                            image_data = await img_element.screenshot()
                            print(f"✅ Screenshot after waiting: {len(image_data)} bytes")
                        except Exception:
                            image_data = await page.screenshot(full_page=True)
                            print(f"⚠️ Full page screenshot: {len(image_data)} bytes")
                else:
                    image_data = await page.screenshot(full_page=True)
                    print(f"⚠️ No image element, full screenshot: {len(image_data)} bytes")

            print(f"✅ Successfully fetched image with Playwright: {url}")
            return image_data

        finally:
            # Closing the browser also closes its contexts and pages.
            await browser.close()

def _build_placeholder_image(self) -> bytes:
    """Render the 800x450 "Image Loading..." placeholder and return it as JPEG bytes."""
    from PIL import Image, ImageDraw, ImageFont
    import random

    # Soft gradient colour pairs; one is picked at random.
    gradients = [
        ('#667eea', '#764ba2'),  # purple
        ('#f093fb', '#f5576c'),  # pink
        ('#4facfe', '#00f2fe'),  # sky blue
        ('#43e97b', '#38f9d7'),  # mint
        ('#fa709a', '#fee140'),  # sunset
        ('#30cfd0', '#330867'),  # deep ocean
        ('#a8edea', '#fed6e3'),  # pastel
        ('#ffecd2', '#fcb69f'),  # peach
    ]
    color1, color2 = random.choice(gradients)

    # 16:9 canvas.
    width, height = 800, 450
    img = Image.new('RGB', (width, height))
    draw = ImageDraw.Draw(img)

    def hex_to_rgb(hex_color):
        hex_color = hex_color.lstrip('#')
        return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))

    rgb1 = hex_to_rgb(color1)
    rgb2 = hex_to_rgb(color2)

    # Vertical gradient, one row at a time.
    for y in range(height):
        ratio = y / height
        r = int(rgb1[0] * (1 - ratio) + rgb2[0] * ratio)
        g = int(rgb1[1] * (1 - ratio) + rgb2[1] * ratio)
        b = int(rgb1[2] * (1 - ratio) + rgb2[2] * ratio)
        draw.rectangle([(0, y), (width, y + 1)], fill=(r, g, b))

    # Everything else is drawn on a transparent overlay and composited.
    overlay = Image.new('RGBA', (width, height), (0, 0, 0, 0))
    overlay_draw = ImageDraw.Draw(overlay)

    # Radial highlight in the centre.
    center_x, center_y = width // 2, height // 2
    max_radius = min(width, height) // 3

    for radius in range(max_radius, 0, -2):
        opacity = int(255 * (1 - radius / max_radius) * 0.3)
        overlay_draw.ellipse(
            [(center_x - radius, center_y - radius),
             (center_x + radius, center_y + radius)],
            fill=(255, 255, 255, opacity)
        )

    # "Picture" icon: a mountain outline with a sun/moon disc.
    icon_color = (255, 255, 255, 200)
    icon_size = 80
    icon_x = center_x
    icon_y = center_y - 20

    mountain_points = [
        (icon_x - icon_size, icon_y + icon_size//2),
        (icon_x - icon_size//2, icon_y - icon_size//4),
        (icon_x - icon_size//4, icon_y),
        (icon_x + icon_size//4, icon_y - icon_size//2),
        (icon_x + icon_size, icon_y + icon_size//2),
    ]
    overlay_draw.polygon(mountain_points, fill=icon_color)

    sun_radius = icon_size // 4
    overlay_draw.ellipse(
        [(icon_x - icon_size//2, icon_y - icon_size//2 - sun_radius),
         (icon_x - icon_size//2 + sun_radius*2, icon_y - icon_size//2 + sun_radius)],
        fill=icon_color
    )

    # Frame border.
    frame_margin = 40
    overlay_draw.rectangle(
        [(frame_margin, frame_margin),
         (width - frame_margin, height - frame_margin)],
        outline=(255, 255, 255, 150),
        width=3
    )

    # Corner decorations.
    corner_size = 20
    corner_width = 4
    corners = [
        (frame_margin, frame_margin),
        (width - frame_margin - corner_size, frame_margin),
        (frame_margin, height - frame_margin - corner_size),
        (width - frame_margin - corner_size, height - frame_margin - corner_size)
    ]

    for x, y in corners:
        # Horizontal bar.
        overlay_draw.rectangle(
            [(x, y), (x + corner_size, y + corner_width)],
            fill=(255, 255, 255, 200)
        )
        # Vertical bar.
        overlay_draw.rectangle(
            [(x, y), (x + corner_width, y + corner_size)],
            fill=(255, 255, 255, 200)
        )

    # Caption; fall back to Pillow's built-in font when DejaVu is missing.
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 16)
    except Exception:
        font = ImageFont.load_default()

    text = "Image Loading..."
    bbox = draw.textbbox((0, 0), text, font=font)
    text_width = bbox[2] - bbox[0]
    text_x = (width - text_width) // 2
    text_y = center_y + icon_size

    # Text shadow.
    for offset in [(2, 2), (-1, -1)]:
        overlay_draw.text(
            (text_x + offset[0], text_y + offset[1]),
            text,
            font=font,
            fill=(0, 0, 0, 100)
        )

    # Text body.
    overlay_draw.text(
        (text_x, text_y),
        text,
        font=font,
        fill=(255, 255, 255, 220)
    )

    img = Image.alpha_composite(img.convert('RGBA'), overlay).convert('RGB')

    # Sprinkle a little brightness noise for texture.
    pixels = img.load()
    for _ in range(1000):
        x = random.randint(0, width - 1)
        y = random.randint(0, height - 1)
        r, g, b = pixels[x, y]
        brightness = random.randint(-20, 20)
        pixels[x, y] = (
            max(0, min(255, r + brightness)),
            max(0, min(255, g + brightness)),
            max(0, min(255, b + brightness))
        )

    output = io.BytesIO()
    img.save(output, format='JPEG', quality=85, optimize=True)
    return output.getvalue()
|
||||
|
||||
def resize_and_optimize_image(self, image_data: bytes, size: str) -> tuple[bytes, str]:
|
||||
"""이미지 리사이징 및 최적화"""
|
||||
if size not in settings.thumbnail_sizes:
|
||||
raise ValueError(f"Invalid size: {size}")
|
||||
|
||||
target_size = settings.thumbnail_sizes[size]
|
||||
|
||||
# SVG 체크 - SVG는 리사이징하지 않고 그대로 반환
|
||||
if self._is_svg(image_data):
|
||||
return image_data, 'image/svg+xml'
|
||||
|
||||
# PIL로 이미지 열기
|
||||
try:
|
||||
img = Image.open(io.BytesIO(image_data))
|
||||
except Exception as e:
|
||||
# WebP 헤더 체크 (RIFF....WEBP)
|
||||
header = image_data[:12] if len(image_data) >= 12 else image_data
|
||||
if header[:4] == b'RIFF' and header[8:12] == b'WEBP':
|
||||
print("🎨 WebP 이미지 감지됨, 변환 시도")
|
||||
# WebP 형식이지만 PIL이 열지 못하는 경우
|
||||
# Pillow-SIMD 또는 추가 라이브러리가 필요할 수 있음
|
||||
try:
|
||||
# 재시도
|
||||
from PIL import WebPImagePlugin
|
||||
img = Image.open(io.BytesIO(image_data))
|
||||
except:
|
||||
print("❌ WebP 이미지를 열 수 없음, 원본 반환")
|
||||
return image_data, 'image/webp'
|
||||
else:
|
||||
raise e
|
||||
|
||||
# GIF 애니메이션 체크 및 처리
|
||||
if getattr(img, "format", None) == "GIF":
|
||||
return self._process_gif(image_data, target_size)
|
||||
|
||||
# WebP 형식 체크
|
||||
original_format = getattr(img, "format", None)
|
||||
is_webp = original_format == "WEBP"
|
||||
|
||||
# 원본 모드와 투명도 정보 저장
|
||||
original_mode = img.mode
|
||||
original_has_transparency = img.mode in ('RGBA', 'LA')
|
||||
original_has_palette = img.mode == 'P'
|
||||
|
||||
# 팔레트 모드(P) 처리 - 간단하게 PIL의 기본 변환 사용
|
||||
if img.mode == 'P':
|
||||
# 팔레트 모드는 RGB로 직접 변환
|
||||
# PIL의 convert 메서드가 팔레트를 올바르게 처리함
|
||||
img = img.convert('RGB')
|
||||
|
||||
# 투명도가 있는 이미지 처리
|
||||
if img.mode == 'RGBA':
|
||||
# RGBA는 흰색 배경과 합성
|
||||
background = Image.new('RGB', img.size, (255, 255, 255))
|
||||
background.paste(img, mask=img.split()[-1])
|
||||
img = background
|
||||
elif img.mode == 'LA':
|
||||
# LA(그레이스케일+알파)는 RGBA를 거쳐 RGB로
|
||||
img = img.convert('RGBA')
|
||||
background = Image.new('RGB', img.size, (255, 255, 255))
|
||||
background.paste(img, mask=img.split()[-1])
|
||||
img = background
|
||||
elif img.mode == 'L':
|
||||
# 그레이스케일은 RGB로 변환
|
||||
img = img.convert('RGB')
|
||||
elif img.mode not in ('RGB',):
|
||||
# 기타 모드는 모두 RGB로 변환
|
||||
img = img.convert('RGB')
|
||||
|
||||
# EXIF 방향 정보 처리 (RGB 변환 후에 수행)
|
||||
try:
|
||||
from PIL import ImageOps
|
||||
img = ImageOps.exif_transpose(img)
|
||||
except:
|
||||
pass
|
||||
|
||||
# 메타데이터 제거는 스킵 (팔레트 모드 이미지에서 문제 발생)
|
||||
# RGB로 변환되었으므로 이미 메타데이터는 대부분 제거됨
|
||||
|
||||
# 비율 유지하며 리사이징 (크롭 없이)
|
||||
img_ratio = img.width / img.height
|
||||
target_width = target_size[0]
|
||||
target_height = target_size[1]
|
||||
|
||||
# 원본 비율을 유지하면서 목표 크기에 맞추기
|
||||
# 너비 또는 높이 중 하나를 기준으로 비율 계산
|
||||
if img.width > target_width or img.height > target_height:
|
||||
# 너비 기준 리사이징
|
||||
width_ratio = target_width / img.width
|
||||
# 높이 기준 리사이징
|
||||
height_ratio = target_height / img.height
|
||||
# 둘 중 작은 비율 사용 (목표 크기를 넘지 않도록)
|
||||
ratio = min(width_ratio, height_ratio)
|
||||
|
||||
new_width = int(img.width * ratio)
|
||||
new_height = int(img.height * ratio)
|
||||
|
||||
# 큰 이미지를 작게 만들 때는 2단계 리샘플링으로 품질 향상
|
||||
if img.width > new_width * 2 or img.height > new_height * 2:
|
||||
# 1단계: 목표 크기의 2배로 먼저 축소
|
||||
intermediate_width = new_width * 2
|
||||
intermediate_height = new_height * 2
|
||||
img = img.resize((intermediate_width, intermediate_height), Image.Resampling.LANCZOS)
|
||||
|
||||
# 최종 목표 크기로 리샘플링
|
||||
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
|
||||
# 샤프닝 적용 (작은 이미지에만)
|
||||
if target_size[0] <= 400:
|
||||
from PIL import ImageEnhance
|
||||
enhancer = ImageEnhance.Sharpness(img)
|
||||
img = enhancer.enhance(1.2)
|
||||
|
||||
# 바이트로 변환
|
||||
output = io.BytesIO()
|
||||
|
||||
# 적응형 품질 계산 (이미지 크기에 따라 조정)
|
||||
def get_adaptive_quality(base_quality: int, target_width: int) -> int:
    """Return the encoder quality to use for a given output width.

    Small outputs get a quality boost (capped), larger ones use the base
    quality unchanged.
    """
    # (max width, boost added to the base quality, upper cap)
    boosts = (
        (150, 10, 95),  # thumbnail
        (360, 5, 90),   # card
    )
    for width_limit, boost, cap in boosts:
        if target_width <= width_limit:
            return min(base_quality + boost, cap)

    # Detail and hero sizes keep the configured quality as-is.
    return base_quality
|
||||
|
||||
# WebP 변환 및 최적화 - 최고 압축률 설정
|
||||
# WebP 입력은 JPEG로 변환 (WebP 리사이징 문제 회피)
|
||||
if is_webp:
|
||||
output_format = 'JPEG'
|
||||
content_type = 'image/jpeg'
|
||||
else:
|
||||
output_format = 'WEBP' if settings.convert_to_webp else 'JPEG'
|
||||
content_type = 'image/webp' if output_format == 'WEBP' else 'image/jpeg'
|
||||
|
||||
if output_format == 'WEBP':
|
||||
# WebP 최적화: method=6(최고품질), lossless=False, exact=False
|
||||
adaptive_quality = get_adaptive_quality(settings.webp_quality, target_size[0])
|
||||
|
||||
save_kwargs = {
|
||||
'format': 'WEBP',
|
||||
'quality': adaptive_quality,
|
||||
'method': 6, # 최고 압축 알고리즘 (0-6)
|
||||
'lossless': settings.webp_lossless,
|
||||
'exact': False, # 약간의 품질 손실 허용하여 더 작은 크기
|
||||
}
|
||||
|
||||
img.save(output, **save_kwargs)
|
||||
elif original_has_transparency and not settings.convert_to_webp:
|
||||
# PNG 최적화 (투명도가 있는 이미지)
|
||||
save_kwargs = {
|
||||
'format': 'PNG',
|
||||
'optimize': settings.optimize_png,
|
||||
'compress_level': settings.png_compress_level,
|
||||
}
|
||||
|
||||
# 팔레트 모드로 변환 가능한지 확인 (256색 이하)
|
||||
if settings.optimize_png:
|
||||
try:
|
||||
# 색상 수가 256개 이하이면 팔레트 모드로 변환
|
||||
quantized = img.quantize(colors=256, method=Image.Quantize.MEDIANCUT)
|
||||
if len(quantized.getcolors()) <= 256:
|
||||
img = quantized
|
||||
save_kwargs['format'] = 'PNG'
|
||||
except:
|
||||
pass
|
||||
|
||||
content_type = 'image/png'
|
||||
img.save(output, **save_kwargs)
|
||||
else:
|
||||
# JPEG 최적화 설정 (기본값)
|
||||
adaptive_quality = get_adaptive_quality(settings.jpeg_quality, target_size[0])
|
||||
|
||||
save_kwargs = {
|
||||
'format': 'JPEG',
|
||||
'quality': adaptive_quality,
|
||||
'optimize': True,
|
||||
'progressive': settings.progressive_jpeg,
|
||||
}
|
||||
|
||||
img.save(output, **save_kwargs)
|
||||
|
||||
return output.getvalue(), content_type
|
||||
|
||||
async def get_cache_size(self) -> float:
    """Return the current size of the on-disk cache in gigabytes.

    Walks ``self.cache_dir`` recursively and sums the size of every file
    found. Entries that cannot be stat'ed (for example a file deleted by a
    concurrent cleanup, or a dangling symlink) are skipped instead of
    aborting the whole measurement.
    """
    total_size = 0

    for dirpath, _dirnames, filenames in os.walk(self.cache_dir):
        for filename in filenames:
            filepath = os.path.join(dirpath, filename)
            try:
                total_size += os.path.getsize(filepath)
            except OSError:
                # The entry vanished or is unreadable; it does not count.
                continue

    return total_size / (1024 ** 3)  # bytes -> GB
|
||||
|
||||
async def cleanup_old_cache(self):
    """Delete cached files older than ``settings.cache_ttl_days`` days.

    Every file under ``self.cache_dir`` whose modification time is before
    the cutoff is unlinked. Files that disappear while the walk is in
    progress are ignored.
    """
    cutoff_time = datetime.now() - timedelta(days=settings.cache_ttl_days)
    # The cutoff is loop-invariant, so convert it to a timestamp once.
    cutoff_timestamp = cutoff_time.timestamp()

    for dirpath, _dirnames, filenames in os.walk(self.cache_dir):
        for filename in filenames:
            filepath = Path(dirpath) / filename
            try:
                if filepath.stat().st_mtime < cutoff_timestamp:
                    filepath.unlink()
            except FileNotFoundError:
                # Already removed by someone else; nothing left to clean.
                continue
|
||||
|
||||
async def trigger_background_generation(self, url: str):
    """Trigger background generation of all image sizes for ``url``.

    The work is handed to ``background_manager.add_task`` inside a new
    asyncio task; this coroutine returns without waiting for it.
    """
    from .background_tasks import background_manager

    # Add to the background work queue.
    task = asyncio.create_task(background_manager.add_task(url))

    # The event loop only holds weak references to tasks, so keep a strong
    # reference until the task finishes to stop it being garbage-collected
    # mid-execution.
    pending = getattr(self, "_background_tasks", None)
    if pending is None:
        pending = self._background_tasks = set()
    pending.add(task)
    task.add_done_callback(pending.discard)
|
||||
|
||||
async def get_directory_stats(self) -> dict:
    """Summarise the layout of the cache directory tree.

    Counts every file and directory below ``self.cache_dir`` and, for
    directories exactly three levels deep that contain files, tracks how
    many files each one holds to derive the average and the maximum.
    """
    file_count = 0
    dir_count = 0
    leaf_sizes = {}

    for current, subdirs, names in os.walk(self.cache_dir):
        dir_count += len(subdirs)
        file_count += len(names)

        relative = os.path.relpath(current, self.cache_dir)
        if relative == '.':
            level = 0
        else:
            level = len(Path(relative).parts)

        # Only third-level directories take part in the per-directory stats.
        if level == 3 and names:
            leaf_sizes[relative] = len(names)

    if leaf_sizes:
        counts = list(leaf_sizes.values())
        mean_per_dir = sum(counts) / len(counts)
        largest_dir = max(counts)
    else:
        mean_per_dir = 0
        largest_dir = 0

    return {
        "total_files": file_count,
        "total_directories": dir_count,
        "average_files_per_directory": round(mean_per_dir, 2),
        "max_files_in_single_directory": largest_dir,
        "directory_depth": 3
    }
|
||||
|
||||
# Module-level cache instance, created when this module is imported.
cache = ImageCache()
|
||||
54
services/images/backend/app/core/config.py
Normal file
54
services/images/backend/app/core/config.py
Normal file
@ -0,0 +1,54 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from pathlib import Path
|
||||
|
||||
class Settings(BaseSettings):
    """Configuration values for the image proxy service.

    The values below are defaults; ``Config.env_file`` points
    pydantic-settings at a ``.env`` file.
    """

    # Basic settings
    app_name: str = "Image Proxy Service"
    debug: bool = True

    # Cache settings (still used for local temporary files after the
    # switch to MinIO)
    cache_dir: Path = Path("/app/cache")
    max_cache_size_gb: int = 10
    cache_ttl_days: int = 30

    # MinIO settings
    # NOTE(review): the access/secret keys default to "minioadmin";
    # presumably overridden through the environment in deployment - confirm.
    use_minio: bool = True  # whether MinIO is used
    minio_endpoint: str = "minio:9000"
    minio_access_key: str = "minioadmin"
    minio_secret_key: str = "minioadmin"
    minio_bucket_name: str = "image-cache"
    minio_secure: bool = False

    # Image settings
    max_image_size_mb: int = 20
    allowed_formats: list = ["jpg", "jpeg", "png", "gif", "webp", "svg"]

    # Resizing settings - tuned per news-card use case
    thumbnail_sizes: dict = {
        "thumb": (150, 100),    # small thumbnail (3:2 ratio)
        "card": (360, 240),     # news card (3:2 ratio)
        "list": (300, 200),     # list view (3:2 ratio)
        "detail": (800, 533),   # detail page (original aspect ratio kept)
        "hero": (1200, 800)     # hero image (original aspect ratio kept)
    }

    # Image optimization settings - smallest size while preserving quality
    jpeg_quality: int = 85  # JPEG quality (raised for better quality)
    webp_quality: int = 85  # WebP quality (raised to address the black-image problem)
    webp_lossless: bool = False  # lossless compression disabled (size optimization)
    png_compress_level: int = 9  # maximum PNG compression (0-9, 9 compresses most)
    convert_to_webp: bool = False  # WebP conversion temporarily disabled (black-image problem)

    # Advanced optimization settings
    progressive_jpeg: bool = True  # progressive JPEG (better loading performance)
    strip_metadata: bool = True  # strip EXIF and other metadata (saves space)
    optimize_png: bool = True  # PNG palette optimization

    # Outbound request settings
    request_timeout: int = 30
    user_agent: str = "ImageProxyService/1.0"

    class Config:
        env_file = ".env"


# Settings instance created at import time.
settings = Settings()
|
||||
414
services/images/backend/app/core/minio_cache.py
Normal file
414
services/images/backend/app/core/minio_cache.py
Normal file
@ -0,0 +1,414 @@
|
||||
import hashlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Tuple
|
||||
import httpx
|
||||
from PIL import Image
|
||||
try:
|
||||
from pillow_heif import register_heif_opener, register_avif_opener
|
||||
register_heif_opener() # HEIF/HEIC 지원
|
||||
register_avif_opener() # AVIF 지원
|
||||
print("HEIF/AVIF support enabled successfully")
|
||||
except ImportError:
|
||||
print("Warning: pillow_heif not installed, HEIF/AVIF support disabled")
|
||||
import io
|
||||
import asyncio
|
||||
import ssl
|
||||
from minio import Minio
|
||||
from minio.error import S3Error
|
||||
import tempfile
|
||||
|
||||
from .config import settings
|
||||
|
||||
class MinIOImageCache:
|
||||
def __init__(self):
    """Create the MinIO client, ensure the bucket, and prepare a temp dir."""
    # Connection parameters all come from the application settings.
    connection_options = {
        "access_key": settings.minio_access_key,
        "secret_key": settings.minio_secret_key,
        "secure": settings.minio_secure,
    }
    self.client = Minio(settings.minio_endpoint, **connection_options)

    # Synchronous call: the bucket is checked/created during construction.
    self._ensure_bucket()

    # Local temporary directory (for image processing).
    self.temp_dir = Path(tempfile.gettempdir()) / "image_cache_temp"
    self.temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _ensure_bucket(self):
    """Create the configured bucket when it does not exist yet.

    An ``S3Error`` is reported on stdout and otherwise ignored.
    """
    bucket = settings.minio_bucket_name
    try:
        if self.client.bucket_exists(bucket):
            print(f"✅ MinIO bucket exists: {settings.minio_bucket_name}")
            return

        self.client.make_bucket(bucket)
        print(f"✅ Created MinIO bucket: {settings.minio_bucket_name}")
    except S3Error as e:
        print(f"❌ Error creating bucket: {e}")
|
||||
|
||||
def _get_object_name(self, url: str, size: Optional[str] = None) -> str:
    """Derive the MinIO object name for ``url`` (and an optional size).

    The name is ``<h[0:2]>/<h[2:4]>/<h[4:6]>/<file>`` where ``h`` is the
    MD5 hex digest of the URL. ``<file>`` is the digest, optionally
    suffixed with ``_<size>``, plus an extension when one applies.
    """
    digest = hashlib.md5(url.encode()).hexdigest()

    # A size variant gets its own file name next to the original.
    stem = f"{digest}_{size}" if size else digest

    if settings.convert_to_webp and size:
        # Resized variants are stored as WebP when conversion is enabled.
        leaf = f"{stem}.webp"
    else:
        suffix = self._get_extension_from_url(url)
        leaf = f"{stem}.{suffix}" if suffix else stem

    # MinIO treats "/" like a directory separator: three nested levels.
    return "/".join((digest[:2], digest[2:4], digest[4:6], leaf))
|
||||
|
||||
def _get_extension_from_url(self, url: str) -> Optional[str]:
    """Return the lower-cased extension of ``url`` if it is allowed.

    The query string is dropped first. ``None`` is returned when there is
    no dot in the remainder or the text after the last dot is not listed
    in ``settings.allowed_formats``.
    """
    without_query = url.partition('?')[0]
    _head, dot, tail = without_query.rpartition('.')
    if not dot:
        return None

    candidate = tail.lower()
    return candidate if candidate in settings.allowed_formats else None
|
||||
|
||||
def _is_svg(self, data: bytes) -> bool:
|
||||
"""SVG 파일인지 확인"""
|
||||
if len(data) < 100:
|
||||
return False
|
||||
|
||||
header = data[:1000].lower()
|
||||
svg_signatures = [
|
||||
b'<svg',
|
||||
b'<?xml',
|
||||
b'<!doctype svg'
|
||||
]
|
||||
|
||||
for sig in svg_signatures:
|
||||
if sig in header:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _process_gif(self, gif_data: bytes, target_size: tuple) -> tuple[bytes, str]:
    """Resize a GIF and re-encode it as JPEG.

    Returns ``(jpeg_bytes, 'image/jpeg')`` on success. On any error the
    untouched input is returned together with ``'image/gif'``.
    """
    try:
        frame = Image.open(io.BytesIO(gif_data))

        # Palette images go through RGBA so they can be composited below.
        if frame.mode == 'P':
            frame = frame.convert('RGBA')

        if frame.mode == 'RGBA':
            # Flatten onto a white canvas using the alpha band as the mask.
            bands = frame.split()
            alpha = bands[3] if len(bands) == 4 else None
            canvas = Image.new('RGB', frame.size, (255, 255, 255))
            canvas.paste(frame, mask=alpha)
            frame = canvas
        elif frame.mode != 'RGB':
            frame = frame.convert('RGB')

        # Shrink in place, keeping the aspect ratio.
        frame.thumbnail(target_size, Image.Resampling.LANCZOS)

        buffer = io.BytesIO()
        jpeg_options = {
            'format': 'JPEG',
            'quality': settings.jpeg_quality,
            'optimize': True,
            'progressive': settings.progressive_jpeg,
        }
        frame.save(buffer, **jpeg_options)

        return buffer.getvalue(), 'image/jpeg'

    except Exception as e:
        print(f"GIF 처리 오류: {e}")
        return gif_data, 'image/gif'
|
||||
|
||||
def resize_and_optimize_image(self, image_data: bytes, size: str) -> tuple[bytes, str]:
    """Resize ``image_data`` to a named size and re-encode it.

    ``size`` selects an entry of ``settings.thumbnail_sizes``; unknown
    names fall back to ``"thumb"``. The image is shrunk in place with its
    aspect ratio kept, then written as WebP (when
    ``settings.convert_to_webp`` is set), as PNG (when the source was a
    PNG) or as JPEG otherwise.

    Returns ``(encoded_bytes, content_type)``. On any failure the error is
    printed and the untouched input is returned with ``'image/jpeg'``.
    """
    try:
        target_size = settings.thumbnail_sizes.get(size, settings.thumbnail_sizes["thumb"])

        # Open the image.
        img = Image.open(io.BytesIO(image_data))

        # Remember the source format now. ImageOps.exif_transpose() returns
        # a new Image object whose ``format`` is None, so checking
        # ``img.format`` afterwards can never match 'PNG' and every PNG
        # would silently be converted to JPEG.
        source_format = img.format

        # Apply the EXIF orientation.
        try:
            from PIL import ImageOps
            img = ImageOps.exif_transpose(img)
        except Exception:
            # Best effort: keep the un-rotated image if EXIF handling fails.
            pass

        # Resize (original aspect ratio kept).
        img.thumbnail(target_size, Image.Resampling.LANCZOS)

        # Output buffer.
        output = io.BytesIO()

        if settings.convert_to_webp:
            # Flatten transparency to RGB (WebP supports alpha, but the
            # original author notes browser compatibility issues).
            if img.mode in ('RGBA', 'LA', 'P'):
                # Transparent background becomes white.
                background = Image.new('RGB', img.size, (255, 255, 255))
                if img.mode == 'P':
                    img = img.convert('RGBA')
                background.paste(img, mask=img.split()[-1] if 'A' in img.mode else None)
                img = background
            elif img.mode != 'RGB':
                img = img.convert('RGB')

            # Save as WebP.
            img.save(
                output,
                format='WEBP',
                quality=settings.webp_quality,
                lossless=settings.webp_lossless,
                method=6  # strongest compression
            )
            content_type = 'image/webp'
        elif source_format == 'PNG':
            # Keep the original format while optimizing.
            img.save(
                output,
                format='PNG',
                compress_level=settings.png_compress_level,
                optimize=settings.optimize_png
            )
            content_type = 'image/png'
        else:
            # Convert to JPEG.
            if img.mode != 'RGB':
                img = img.convert('RGB')
            img.save(
                output,
                format='JPEG',
                quality=settings.jpeg_quality,
                optimize=True,
                progressive=settings.progressive_jpeg
            )
            content_type = 'image/jpeg'

        return output.getvalue(), content_type

    except Exception as e:
        print(f"이미지 최적화 오류: {e}")
        import traceback
        traceback.print_exc()
        return image_data, 'image/jpeg'
|
||||
|
||||
async def get(self, url: str, size: Optional[str] = None) -> Optional[bytes]:
    """Fetch a cached image from MinIO.

    Returns the object's bytes, or ``None`` when the object does not exist
    or MinIO reports any other ``S3Error`` (the outcome is printed).
    """
    object_name = self._get_object_name(url, size)

    try:
        # Fetch the object from MinIO.
        response = self.client.get_object(settings.minio_bucket_name, object_name)
        try:
            data = response.read()
        finally:
            # Always hand the connection back, even when the read fails.
            response.close()
            response.release_conn()

        print(f"✅ Cache HIT from MinIO: {object_name}")
        return data

    except S3Error as e:
        if e.code == 'NoSuchKey':
            print(f"📭 Cache MISS in MinIO: {object_name}")
        else:
            print(f"❌ MinIO error: {e}")
        return None
|
||||
|
||||
async def set(self, url: str, data: bytes, size: Optional[str] = None):
    """Store image bytes in the MinIO cache.

    The stored content type is derived from the payload itself (SVG check,
    then JPEG/PNG/GIF/WebP magic bytes). Resized variants are often
    re-encoded to a format that no longer matches the URL's extension, so
    the URL alone is only used as a fallback. An ``S3Error`` is printed
    and otherwise ignored.
    """
    object_name = self._get_object_name(url, size)

    try:
        # Wrap the bytes in a stream for put_object.
        data_stream = io.BytesIO(data)
        data_length = len(data)

        # Decide the content type.
        if url.lower().endswith('.svg') or self._is_svg(data):
            content_type = 'image/svg+xml'
        elif data[:3] == b'\xff\xd8\xff':
            content_type = 'image/jpeg'
        elif data[:8] == b'\x89PNG\r\n\x1a\n':
            content_type = 'image/png'
        elif data[:6] in (b'GIF87a', b'GIF89a'):
            content_type = 'image/gif'
        elif data[:4] == b'RIFF' and data[8:12] == b'WEBP':
            content_type = 'image/webp'
        elif url.lower().endswith('.gif'):
            content_type = 'image/gif'
        elif settings.convert_to_webp and size:
            content_type = 'image/webp'
        else:
            content_type = 'application/octet-stream'

        # Store in MinIO (metadata only supports ASCII, hence the URL hash).
        self.client.put_object(
            settings.minio_bucket_name,
            object_name,
            data_stream,
            data_length,
            content_type=content_type,
            metadata={
                'url_hash': hashlib.md5(url.encode()).hexdigest(),
                'cached_at': datetime.utcnow().isoformat(),
                'size_variant': size or 'original'
            }
        )

        print(f"✅ Cached to MinIO: {object_name} ({data_length} bytes)")

    except S3Error as e:
        print(f"❌ Failed to cache to MinIO: {e}")
|
||||
|
||||
async def download_image(self, url: str) -> bytes:
    """Download an image from an external URL and return its bytes.

    A 403 response is retried once with a browser-like User-Agent.

    Raises:
        httpx.HTTPStatusError: if the final response is not successful.
        ValueError: if the declared or actual body size exceeds
            ``settings.max_image_size_mb``.
    """
    # NOTE(security): certificate verification is disabled (verify=False).
    # The original comment marks this as a development-only setting; it
    # should be reviewed before production use.
    url_parts = url.split('/')
    # "<scheme>://<host>" of the target, or the URL itself if it has no path.
    referer = url_parts[0] + '//' + url_parts[2] if len(url_parts) > 2 else url

    async with httpx.AsyncClient(
        timeout=settings.request_timeout,
        verify=False,
        follow_redirects=True
    ) as client:
        # Accept-Encoding is deliberately left to httpx: it advertises only
        # the encodings it can decode. Forcing "br" without a brotli
        # decoder installed can yield a still-compressed body.
        headers = {
            "User-Agent": settings.user_agent,
            "Accept": "image/webp,image/apng,image/*,*/*;q=0.8",
            "Cache-Control": "no-cache",
            "Referer": referer
        }

        response = await client.get(url, headers=headers)

        if response.status_code == 403:
            headers["User-Agent"] = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36"
            response = await client.get(url, headers=headers)

        response.raise_for_status()

        max_bytes = settings.max_image_size_mb * 1024 * 1024
        size_error = f"이미지 크기가 {settings.max_image_size_mb}MB를 초과합니다"

        content_length = response.headers.get("content-length")
        if content_length and int(content_length) > max_bytes:
            raise ValueError(size_error)

        # The header is optional and may describe the compressed size, so
        # enforce the limit on the bytes actually received as well.
        body = response.content
        if len(body) > max_bytes:
            raise ValueError(size_error)

        return body
|
||||
|
||||
async def get_cache_size(self) -> float:
    """Return the total size of all objects in the bucket, in gigabytes.

    Returns ``0.0`` (after printing the error) when MinIO raises ``S3Error``.
    """
    bucket = settings.minio_bucket_name
    try:
        listing = self.client.list_objects(bucket, recursive=True)
        byte_total = sum(entry.size for entry in listing)
        return byte_total / (1024 ** 3)  # bytes -> GB
    except S3Error as e:
        print(f"❌ Failed to get cache size: {e}")
        return 0.0
|
||||
|
||||
async def get_directory_stats(self) -> dict:
    """Return object and pseudo-directory counts for the bucket.

    A "directory" is the part of an object name before its last ``/``.
    On ``S3Error`` the error is printed and zeroed statistics are returned.
    """
    try:
        object_total = 0
        prefixes = set()

        for entry in self.client.list_objects(settings.minio_bucket_name, recursive=True):
            object_total += 1
            # Everything before the final "/" is the directory path.
            parent, separator, _leaf = entry.object_name.rpartition('/')
            if separator:
                prefixes.add(parent)

        prefix_total = len(prefixes)
        return {
            "total_files": object_total,
            "total_directories": prefix_total,
            "average_files_per_directory": object_total / max(prefix_total, 1),
            "bucket_name": settings.minio_bucket_name
        }

    except S3Error as e:
        print(f"❌ Failed to get directory stats: {e}")
        empty_stats = dict.fromkeys(
            ("total_files", "total_directories", "average_files_per_directory"), 0
        )
        empty_stats["bucket_name"] = settings.minio_bucket_name
        return empty_stats
|
||||
|
||||
async def cleanup_old_cache(self):
    """Remove objects last modified more than ``cache_ttl_days`` days ago.

    Returns the number of objects deleted, or ``0`` (after printing the
    error) when MinIO raises ``S3Error``.
    """
    bucket = settings.minio_bucket_name
    try:
        threshold = datetime.utcnow() - timedelta(days=settings.cache_ttl_days)
        deleted_count = 0

        for obj in self.client.list_objects(bucket, recursive=True):
            # Drop tzinfo so the value compares against the naive threshold.
            modified = obj.last_modified.replace(tzinfo=None)
            if modified >= threshold:
                continue

            self.client.remove_object(bucket, obj.object_name)
            deleted_count += 1
            print(f"🗑️ Deleted old cache: {obj.object_name}")

        print(f"✅ Cleaned up {deleted_count} old cached files")
        return deleted_count

    except S3Error as e:
        print(f"❌ Failed to cleanup cache: {e}")
        return 0
|
||||
|
||||
async def trigger_background_generation(self, url: str):
    """Start generating the size variants for ``url`` in the background.

    Schedules ``self._generate_all_sizes(url)`` as an asyncio task and
    returns without waiting for it.
    """
    task = asyncio.create_task(self._generate_all_sizes(url))

    # The event loop only holds weak references to tasks, so keep a strong
    # reference until the task finishes to stop it being garbage-collected
    # mid-execution.
    pending = getattr(self, "_background_tasks", None)
    if pending is None:
        pending = self._background_tasks = set()
    pending.add(task)
    task.add_done_callback(pending.discard)
|
||||
|
||||
async def _generate_all_sizes(self, url: str):
    """Download ``url`` once and cache every configured size variant.

    SVG payloads are skipped, and sizes that are already cached are left
    alone. Any exception is printed and swallowed.
    """
    try:
        source_bytes = await self.download_image(url)

        # SVG needs no resizing.
        if self._is_svg(source_bytes):
            return

        for size_name in settings.thumbnail_sizes.keys():
            # Skip variants that are already in the cache.
            if await self.get(url, size_name):
                continue

            if url.lower().endswith('.gif'):
                dimensions = settings.thumbnail_sizes[size_name]
                variant_bytes, _ = self._process_gif(source_bytes, dimensions)
            else:
                variant_bytes, _ = self.resize_and_optimize_image(source_bytes, size_name)

            await self.set(url, variant_bytes, size_name)
            print(f"✅ Generated {size_name} version for {url}")

    except Exception as e:
        print(f"❌ Background generation failed for {url}: {e}")
|
||||
|
||||
# Singleton instance, created when this module is imported.
cache = MinIOImageCache()
|
||||
65
services/images/backend/main.py
Normal file
65
services/images/backend/main.py
Normal file
@ -0,0 +1,65 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
import uvicorn
|
||||
from datetime import datetime
|
||||
|
||||
from app.api.endpoints import router
|
||||
from app.core.config import settings
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: print a message on startup and on shutdown."""
    # On startup
    print("Images service starting...")
    yield
    # On shutdown
    print("Images service stopping...")
|
||||
|
||||
# FastAPI application object for the images service.
app = FastAPI(
    title="Images Service",
    description="이미지 업로드, 프록시 및 캐싱 서비스",
    version="2.0.0",
    lifespan=lifespan
)

# CORS configuration
# NOTE(review): wildcard origins are combined with allow_credentials=True;
# confirm this permissive policy is intended outside development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Register the API router under /api/v1
app.include_router(router, prefix="/api/v1")
|
||||
|
||||
@app.get("/")
async def root():
    """Return service identification and a map of the main endpoints."""
    endpoint_map = {
        "proxy": "/api/v1/image?url=<image_url>&size=<optional_size>",
        "upload": "/api/v1/upload",
        "stats": "/api/v1/stats",
        "cleanup": "/api/v1/cleanup"
    }
    payload = {
        "service": "Images Service",
        "version": "2.0.0",
        "timestamp": datetime.now().isoformat(),
    }
    payload["endpoints"] = endpoint_map
    return payload
|
||||
|
||||
@app.get("/health")
async def health_check():
    """Return a static health payload stamped with the current time."""
    checked_at = datetime.now().isoformat()
    return {"status": "healthy", "service": "images", "timestamp": checked_at}
|
||||
|
||||
# Run the app with uvicorn when this file is executed directly.
if __name__ == "__main__":
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=True  # NOTE(review): auto-reload is enabled; presumably for development - confirm
    )
|
||||
12
services/images/backend/requirements.txt
Normal file
12
services/images/backend/requirements.txt
Normal file
@ -0,0 +1,12 @@
|
||||
fastapi==0.109.0
|
||||
uvicorn[standard]==0.27.0
|
||||
httpx==0.26.0
|
||||
pillow==10.2.0
|
||||
pillow-heif==0.20.0
|
||||
aiofiles==23.2.1
|
||||
python-multipart==0.0.6
|
||||
pydantic==2.5.3
|
||||
pydantic-settings==2.1.0
|
||||
motor==3.3.2
|
||||
redis==5.0.1
|
||||
minio==7.2.3
|
||||
21
services/notifications/backend/Dockerfile
Normal file
21
services/notifications/backend/Dockerfile
Normal file
@ -0,0 +1,21 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements first for better caching
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy application code
|
||||
COPY . .
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
# Run the application
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
|
||||
335
services/notifications/backend/channel_handlers.py
Normal file
335
services/notifications/backend/channel_handlers.py
Normal file
@ -0,0 +1,335 @@
|
||||
"""
|
||||
Channel Handlers for different notification delivery methods
|
||||
"""
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Optional, Dict, Any
|
||||
from models import Notification, NotificationStatus
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
import httpx
|
||||
import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseChannelHandler:
    """Base class for channel handlers.

    Subclasses override ``send``; ``verify_delivery`` has a default that
    always reports success.
    """

    async def send(self, notification: Notification) -> bool:
        """Send notification through the channel.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    async def verify_delivery(self, notification: Notification) -> bool:
        """Verify if notification was delivered (base implementation returns True)."""
        return True
|
||||
|
||||
class EmailHandler(BaseChannelHandler):
    """Email notification handler."""

    def __init__(self, smtp_host: str, smtp_port: int, smtp_user: str, smtp_password: str):
        """Store the SMTP connection settings."""
        self.smtp_host, self.smtp_port = smtp_host, smtp_port
        self.smtp_user, self.smtp_password = smtp_user, smtp_password

    async def send(self, notification: Notification) -> bool:
        """Build the email for ``notification`` and report the outcome.

        Without SMTP credentials the send is simulated. With credentials
        the message is assembled, but the SMTP calls are still disabled
        (see the commented block). Returns ``True`` on success and
        ``False`` when any exception occurs.
        """
        try:
            # In production this would use an async SMTP library.
            logger.info(f"Sending email to user {notification.user_id}")

            credentials_present = bool(self.smtp_user and self.smtp_password)
            if not credentials_present:
                # No SMTP config: pretend to send after a short delay.
                await asyncio.sleep(0.1)
                logger.info(f"Email sent (simulated) to user {notification.user_id}")
                return True

            # Assemble the message headers.
            msg = MIMEMultipart()
            msg['From'] = self.smtp_user
            msg['To'] = f"user_{notification.user_id}@example.com"  # Would fetch actual email
            msg['Subject'] = notification.title

            # Prefer the HTML body when the notification carries one.
            body = notification.message
            extra = notification.data
            if extra and "html_content" in extra:
                part = MIMEText(extra["html_content"], 'html')
            else:
                part = MIMEText(body, 'plain')
            msg.attach(part)

            # Send email (would be async in production)
            # server = smtplib.SMTP(self.smtp_host, self.smtp_port)
            # server.starttls()
            # server.login(self.smtp_user, self.smtp_password)
            # server.send_message(msg)
            # server.quit()

            logger.info(f"Email sent successfully to user {notification.user_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to send email: {e}")
            return False
|
||||
|
||||
class SMSHandler(BaseChannelHandler):
    """SMS notification handler."""

    def __init__(self, api_key: str, api_url: str):
        """Store the SMS API settings and create an HTTP client."""
        self.api_key, self.api_url = api_key, api_url
        self.client = httpx.AsyncClient()

    async def send(self, notification: Notification) -> bool:
        """Prepare an SMS for ``notification`` and report the outcome.

        The provider call is disabled (see the commented block), so
        delivery is simulated with a short sleep on both paths. Returns
        ``True`` on success and ``False`` when any exception occurs.
        """
        try:
            # In production, would integrate with SMS provider (Twilio, etc.)
            logger.info(f"Sending SMS to user {notification.user_id}")

            configured = bool(self.api_key and self.api_url)
            if not configured:
                # No API config: pretend to send after a short delay.
                await asyncio.sleep(0.1)
                logger.info(f"SMS sent (simulated) to user {notification.user_id}")
                return True

            # Would fetch user's phone number from database; fall back to
            # a demo number when the notification does not carry one.
            supplied = notification.data.get("phone") if notification.data else None
            phone_number = supplied or "+1234567890"

            # Send SMS via API (example structure)
            payload = dict(
                to=phone_number,
                message=f"{notification.title}\n{notification.message}",
                api_key=self.api_key,
            )

            # response = await self.client.post(self.api_url, json=payload)
            # return response.status_code == 200

            # Simulate success
            await asyncio.sleep(0.1)
            logger.info(f"SMS sent successfully to user {notification.user_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to send SMS: {e}")
            return False
|
||||
|
||||
class PushHandler(BaseChannelHandler):
    """Push notification handler (FCM/APNS)."""

    def __init__(self, fcm_server_key: str):
        """Store the FCM server key and create an HTTP client."""
        self.client = httpx.AsyncClient()
        self.fcm_url = "https://fcm.googleapis.com/fcm/send"
        self.fcm_server_key = fcm_server_key

    async def send(self, notification: Notification) -> bool:
        """Prepare a push message per device token and report the outcome.

        The FCM request is disabled (see the commented block), so each
        delivery is simulated with a short sleep. Returns ``True`` on
        success and ``False`` when any exception occurs.
        """
        try:
            logger.info(f"Sending push notification to user {notification.user_id}")

            if not self.fcm_server_key:
                # No FCM config: pretend to send after a short delay.
                await asyncio.sleep(0.1)
                logger.info(f"Push notification sent (simulated) to user {notification.user_id}")
                return True

            # Would fetch user's device tokens from database
            if notification.data:
                device_tokens = notification.data.get("device_tokens", [])
            else:
                device_tokens = []

            # Simulate with dummy token when none were supplied.
            device_tokens = device_tokens or ["dummy_token"]

            # The request headers do not depend on the token.
            headers = {
                "Authorization": f"key={self.fcm_server_key}",
                "Content-Type": "application/json"
            }

            # Send to each device token
            for token in device_tokens:
                extra = notification.data
                display = {
                    "title": notification.title,
                    "body": notification.message,
                    "icon": extra.get("icon") if extra else None,
                    "click_action": extra.get("click_action") if extra else None
                }
                payload = {
                    "to": token,
                    "notification": display,
                    "data": extra or {}
                }

                # response = await self.client.post(
                #     self.fcm_url,
                #     json=payload,
                #     headers=headers
                # )

                # Simulate success
                await asyncio.sleep(0.05)

            logger.info(f"Push notification sent successfully to user {notification.user_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to send push notification: {e}")
            return False
|
||||
|
||||
class InAppHandler(BaseChannelHandler):
    """In-app notification handler."""

    def __init__(self):
        """Start without a WebSocket server attached."""
        self.ws_server = None

    def set_ws_server(self, ws_server):
        """Set WebSocket server for real-time delivery."""
        self.ws_server = ws_server

    async def send(self, notification: Notification) -> bool:
        """Push ``notification`` over the WebSocket server if one is set.

        Returns ``True`` on success (including when no server is attached)
        and ``False`` when any exception occurs.
        """
        try:
            logger.info(f"Sending in-app notification to user {notification.user_id}")

            # Store notification in database (already done in manager)
            # This would be retrieved when user logs in or requests notifications

            # If WebSocket connection exists, send real-time
            if self.ws_server:
                if hasattr(notification, 'category'):
                    category = notification.category.value
                else:
                    category = "system"

                envelope = {
                    "type": "notification",
                    "notification": {
                        "id": notification.id,
                        "title": notification.title,
                        "message": notification.message,
                        "priority": notification.priority.value,
                        "category": category,
                        "timestamp": notification.created_at.isoformat(),
                        "data": notification.data
                    }
                }
                await self.ws_server.send_to_user(notification.user_id, envelope)

            logger.info(f"In-app notification sent successfully to user {notification.user_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to send in-app notification: {e}")
            return False
|
||||
|
||||
class SlackHandler(BaseChannelHandler):
    """Slack notification handler."""

    def __init__(self, webhook_url: Optional[str] = None):
        """Store the optional webhook URL and create an HTTP client."""
        self.client = httpx.AsyncClient()
        self.webhook_url = webhook_url

    async def send(self, notification: Notification) -> bool:
        """Format ``notification`` as a Slack message and report the outcome.

        The webhook request is disabled (see the commented block), so
        delivery is simulated with a short sleep on both paths. Returns
        ``True`` on success and ``False`` when any exception occurs.
        """
        try:
            logger.info(f"Sending Slack notification for user {notification.user_id}")

            if not self.webhook_url:
                # Simulate sending
                await asyncio.sleep(0.1)
                logger.info(f"Slack notification sent (simulated) for user {notification.user_id}")
                return True

            # Format message for Slack: a header block plus a body section.
            header_block = {
                "type": "header",
                "text": {"type": "plain_text", "text": notification.title}
            }
            body_block = {
                "type": "section",
                "text": {"type": "mrkdwn", "text": notification.message}
            }
            slack_message = {
                "text": notification.title,
                "blocks": [header_block, body_block]
            }

            # Add additional fields if present
            if notification.data:
                fields = [
                    {"type": "mrkdwn", "text": f"*{key}:* {value}"}
                    for key, value in notification.data.items()
                    if key not in ["html_content", "device_tokens"]
                ]

                if fields:
                    extra_block = {
                        "type": "section",
                        "fields": fields[:10]  # Slack limits to 10 fields
                    }
                    slack_message["blocks"].append(extra_block)

            # Send to Slack
            # response = await self.client.post(self.webhook_url, json=slack_message)
            # return response.status_code == 200

            await asyncio.sleep(0.1)
            logger.info(f"Slack notification sent successfully")
            return True

        except Exception as e:
            logger.error(f"Failed to send Slack notification: {e}")
            return False
|
||||
|
||||
class WebhookHandler(BaseChannelHandler):
    """Generic webhook notification handler.

    The target URL comes from ``notification.data["webhook_url"]`` when that
    entry holds a non-empty value, otherwise from ``default_webhook_url``.
    Delivery is currently simulated: the payload is built but the HTTP POST
    is disabled.
    """

    def __init__(self, default_webhook_url: Optional[str] = None):
        self.default_webhook_url = default_webhook_url
        self.client = httpx.AsyncClient()

    async def send(self, notification: Notification) -> bool:
        """Send (currently: simulate sending) a webhook notification.

        Returns:
            True on success; False when no webhook URL is available or when
            any exception was raised.
        """
        try:
            # Per-notification URL wins, but an empty/None entry must not
            # block the fallback to the configured default.
            # NOTE(review): the per-notification URL comes straight from
            # notification data; validate it against an allow-list before the
            # real POST below is enabled.
            extra = notification.data or {}
            webhook_url = extra.get("webhook_url") or self.default_webhook_url

            if not webhook_url:
                logger.warning("No webhook URL configured")
                return False

            logger.info(f"Sending webhook notification for user {notification.user_id}")

            # Prepare payload
            payload = {
                "notification_id": notification.id,
                "user_id": notification.user_id,
                "title": notification.title,
                "message": notification.message,
                "priority": notification.priority.value,
                "timestamp": notification.created_at.isoformat(),
                "data": notification.data,
            }

            # Real delivery is disabled for now; the intended call is:
            # response = await self.client.post(webhook_url, json=payload)
            # return response.status_code in [200, 201, 202, 204]

            # Simulate success
            await asyncio.sleep(0.1)
            logger.info("Webhook notification sent successfully")
            return True

        except Exception as e:
            logger.error(f"Failed to send webhook notification: {e}")
            return False
|
||||
514
services/notifications/backend/main.py
Normal file
514
services/notifications/backend/main.py
Normal file
@ -0,0 +1,514 @@
|
||||
"""
|
||||
Notification Service - Real-time Multi-channel Notifications
|
||||
"""
|
||||
from fastapi import FastAPI, HTTPException, Depends, BackgroundTasks, Query
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
import uvicorn
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict, Any
|
||||
import asyncio
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
|
||||
# Import custom modules
|
||||
from models import (
|
||||
Notification, NotificationChannel, NotificationTemplate,
|
||||
NotificationPreference, NotificationHistory, NotificationStatus,
|
||||
NotificationPriority, CreateNotificationRequest, BulkNotificationRequest
|
||||
)
|
||||
from notification_manager import NotificationManager
|
||||
from channel_handlers import EmailHandler, SMSHandler, PushHandler, InAppHandler
|
||||
from websocket_server import WebSocketNotificationServer
|
||||
from queue_manager import NotificationQueueManager
|
||||
from template_engine import TemplateEngine
|
||||
from preference_manager import PreferenceManager
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Service-wide singletons. They stay None until lifespan() populates them
# during application startup.
notification_manager: Optional[NotificationManager] = None
ws_server: Optional[WebSocketNotificationServer] = None
queue_manager: Optional[NotificationQueueManager] = None
template_engine: Optional[TemplateEngine] = None
preference_manager: Optional[PreferenceManager] = None
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the service singletons on startup and release them on shutdown.

    Startup order: template engine, preference manager, queue manager,
    channel handlers, notification manager, WebSocket server. Any startup
    failure is logged and re-raised.
    """
    global notification_manager, ws_server, queue_manager, template_engine, preference_manager

    env = os.getenv

    # ---- Startup ----
    try:
        template_engine = TemplateEngine()
        await template_engine.load_templates()
        logger.info("Template engine initialized")

        preference_manager = PreferenceManager(
            mongodb_url=env("MONGODB_URL", "mongodb://mongodb:27017"),
            database_name="notifications",
        )
        await preference_manager.connect()
        logger.info("Preference manager connected")

        queue_manager = NotificationQueueManager(
            redis_url=env("REDIS_URL", "redis://redis:6379"),
        )
        await queue_manager.connect()
        logger.info("Queue manager connected")

        # One handler per delivery channel, configured from the environment.
        email_handler = EmailHandler(
            smtp_host=env("SMTP_HOST", "smtp.gmail.com"),
            smtp_port=int(env("SMTP_PORT", 587)),
            smtp_user=env("SMTP_USER", ""),
            smtp_password=env("SMTP_PASSWORD", ""),
        )
        sms_handler = SMSHandler(
            api_key=env("SMS_API_KEY", ""),
            api_url=env("SMS_API_URL", ""),
        )
        push_handler = PushHandler(fcm_server_key=env("FCM_SERVER_KEY", ""))
        in_app_handler = InAppHandler()

        handlers = {
            NotificationChannel.EMAIL: email_handler,
            NotificationChannel.SMS: sms_handler,
            NotificationChannel.PUSH: push_handler,
            NotificationChannel.IN_APP: in_app_handler,
        }

        notification_manager = NotificationManager(
            channel_handlers=handlers,
            queue_manager=queue_manager,
            template_engine=template_engine,
            preference_manager=preference_manager,
        )
        await notification_manager.start()
        logger.info("Notification manager started")

        ws_server = WebSocketNotificationServer()
        logger.info("WebSocket server initialized")

        # Let the in-app handler push through the WebSocket server.
        in_app_handler.set_ws_server(ws_server)

    except Exception as e:
        logger.error(f"Failed to start Notification service: {e}")
        raise

    yield

    # ---- Shutdown ----
    if notification_manager:
        await notification_manager.stop()
    if queue_manager:
        await queue_manager.close()
    if preference_manager:
        await preference_manager.close()

    logger.info("Notification service shutdown complete")
|
||||
|
||||
# ASGI application; lifespan() wires up and tears down the service singletons.
app = FastAPI(
    title="Notification Service",
    description="Real-time Multi-channel Notification Service",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS middleware
# NOTE(review): every origin is allowed while credentials are enabled;
# confirm this is intended outside development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
@app.get("/")
async def root():
    """Return a minimal service banner with the current server time."""
    banner = {
        "service": "Notification Service",
        "status": "running",
    }
    banner["timestamp"] = datetime.now().isoformat()
    return banner
|
||||
|
||||
@app.get("/health")
async def health_check():
    """Report the state of each service component.

    Components that have not been initialised are reported as
    disconnected/stopped (or 0 WebSocket connections).
    """
    queue_up = queue_manager and queue_manager.is_connected
    prefs_up = preference_manager and preference_manager.is_connected
    manager_up = notification_manager and notification_manager.is_running
    open_sockets = len(ws_server.active_connections) if ws_server else 0

    components = {
        "queue_manager": "connected" if queue_up else "disconnected",
        "preference_manager": "connected" if prefs_up else "disconnected",
        "notification_manager": "running" if manager_up else "stopped",
        "websocket_connections": open_sockets,
    }
    return {
        "status": "healthy",
        "service": "notifications",
        "components": components,
        "timestamp": datetime.now().isoformat(),
    }
|
||||
|
||||
# Notification Endpoints
|
||||
@app.post("/api/notifications/send")
async def send_notification(
    request: CreateNotificationRequest,
    background_tasks: BackgroundTasks
):
    """Create a notification and either schedule it or dispatch it now.

    A ``schedule_at`` in the future hands the notification to the queue
    manager; otherwise delivery runs as a background task after the response.
    Any failure is reported as HTTP 500.
    """
    try:
        notification = await notification_manager.create_notification(
            user_id=request.user_id,
            title=request.title,
            message=request.message,
            channels=request.channels,
            priority=request.priority,
            data=request.data,
            template_id=request.template_id,
            schedule_at=request.schedule_at,
        )

        when = request.schedule_at
        if when and when > datetime.now():
            # Deferred delivery.
            await queue_manager.schedule_notification(notification, when)
            return {
                "notification_id": notification.id,
                "status": "scheduled",
                "scheduled_at": when.isoformat(),
            }

        # Immediate delivery, performed after the response is sent.
        background_tasks.add_task(notification_manager.send_notification, notification)
        return {"notification_id": notification.id, "status": "queued"}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/api/notifications/send-bulk")
async def send_bulk_notifications(
    request: BulkNotificationRequest,
    background_tasks: BackgroundTasks
):
    """Create one notification per user id and dispatch them in the background.

    Any failure is reported as HTTP 500.
    """
    try:
        created = []
        for recipient in request.user_ids:
            created.append(
                await notification_manager.create_notification(
                    user_id=recipient,
                    title=request.title,
                    message=request.message,
                    channels=request.channels,
                    priority=request.priority,
                    data=request.data,
                    template_id=request.template_id,
                )
            )

        # Delivery of the whole batch happens after the response is sent.
        background_tasks.add_task(
            notification_manager.send_bulk_notifications,
            created,
        )

        return {
            "count": len(created),
            "notification_ids": [item.id for item in created],
            "status": "queued",
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/api/notifications/user/{user_id}")
async def get_user_notifications(
    user_id: str,
    status: Optional[NotificationStatus] = None,
    channel: Optional[NotificationChannel] = None,
    limit: int = Query(50, le=200),
    offset: int = Query(0, ge=0)
):
    """List a user's notifications, optionally filtered by status and channel.

    The response echoes the paging parameters; ``count`` is the size of the
    returned page. Any failure is reported as HTTP 500.
    """
    try:
        page = await notification_manager.get_user_notifications(
            user_id=user_id,
            status=status,
            channel=channel,
            limit=limit,
            offset=offset,
        )
        return {
            "notifications": page,
            "count": len(page),
            "limit": limit,
            "offset": offset,
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.patch("/api/notifications/{notification_id}/read")
async def mark_notification_read(notification_id: str):
    """Mark a notification as read.

    Raises:
        HTTPException: 404 when the manager reports the notification as not
            found, 500 on any unexpected error.
    """
    try:
        success = await notification_manager.mark_as_read(notification_id)
        if not success:
            raise HTTPException(status_code=404, detail="Notification not found")
        return {"status": "marked_as_read", "notification_id": notification_id}

    except HTTPException:
        # Let the deliberate 404 through instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.delete("/api/notifications/{notification_id}")
async def delete_notification(notification_id: str):
    """Delete a notification.

    Raises:
        HTTPException: 404 when the manager reports the notification as not
            found, 500 on any unexpected error.
    """
    try:
        success = await notification_manager.delete_notification(notification_id)
        if not success:
            raise HTTPException(status_code=404, detail="Notification not found")
        return {"status": "deleted", "notification_id": notification_id}

    except HTTPException:
        # Let the deliberate 404 through instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# Template Endpoints
|
||||
@app.get("/api/templates")
async def get_templates():
    """Return every notification template known to the template engine."""
    return {"templates": await template_engine.get_all_templates()}
|
||||
|
||||
@app.post("/api/templates")
async def create_template(template: NotificationTemplate):
    """Register a new notification template and return its id.

    Any failure is reported as HTTP 500.
    """
    try:
        new_id = await template_engine.create_template(template)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"template_id": new_id, "status": "created"}
|
||||
|
||||
@app.put("/api/templates/{template_id}")
async def update_template(template_id: str, template: NotificationTemplate):
    """Replace an existing template.

    Raises:
        HTTPException: 404 when the engine reports the template as not found,
            500 on any unexpected error.
    """
    try:
        success = await template_engine.update_template(template_id, template)
        if not success:
            raise HTTPException(status_code=404, detail="Template not found")
        return {"status": "updated", "template_id": template_id}

    except HTTPException:
        # Let the deliberate 404 through instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# Preference Endpoints
|
||||
@app.get("/api/preferences/{user_id}")
async def get_user_preferences(user_id: str):
    """Return the stored notification preferences of ``user_id``.

    Any failure is reported as HTTP 500.
    """
    try:
        stored = await preference_manager.get_user_preferences(user_id)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"user_id": user_id, "preferences": stored}
|
||||
|
||||
@app.put("/api/preferences/{user_id}")
async def update_user_preferences(
    user_id: str,
    preferences: NotificationPreference
):
    """Store notification preferences for ``user_id``.

    The response status is "updated" when the preference manager returns a
    truthy result and "created" otherwise. Any failure is reported as
    HTTP 500.
    """
    try:
        was_updated = await preference_manager.update_user_preferences(user_id, preferences)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    outcome = "updated" if was_updated else "created"
    return {"status": outcome, "user_id": user_id}
|
||||
|
||||
@app.post("/api/preferences/{user_id}/unsubscribe/{category}")
async def unsubscribe_from_category(user_id: str, category: str):
    """Unsubscribe a user from one notification category.

    Raises:
        HTTPException: 404 when the preference manager reports failure
            (no preferences for the user), 500 on any unexpected error.
    """
    try:
        success = await preference_manager.unsubscribe_category(user_id, category)
        if not success:
            raise HTTPException(status_code=404, detail="User preferences not found")
        return {"status": "unsubscribed", "user_id": user_id, "category": category}

    except HTTPException:
        # Let the deliberate 404 through instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# History and Analytics Endpoints
|
||||
@app.get("/api/history")
async def get_notification_history(
    user_id: Optional[str] = None,
    channel: Optional[NotificationChannel] = None,
    status: Optional[NotificationStatus] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    limit: int = Query(100, le=1000)
):
    """Return notification history matching the given filters.

    The applied filters are echoed back; ``date_range`` is only filled in
    when both ``start_date`` and ``end_date`` are supplied. Any failure is
    reported as HTTP 500.
    """
    try:
        entries = await notification_manager.get_notification_history(
            user_id=user_id,
            channel=channel,
            status=status,
            start_date=start_date,
            end_date=end_date,
            limit=limit,
        )

        if start_date and end_date:
            date_range = f"{start_date} to {end_date}"
        else:
            date_range = None

        applied_filters = {
            "user_id": user_id,
            "channel": channel,
            "status": status,
            "date_range": date_range,
        }
        return {
            "history": entries,
            "count": len(entries),
            "filters": applied_filters,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/api/analytics")
async def get_notification_analytics(
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None
):
    """Return notification analytics for a period.

    The period defaults to the seven days ending now. Any failure is
    reported as HTTP 500.
    """
    try:
        start_date = start_date or (datetime.now() - timedelta(days=7))
        end_date = end_date or datetime.now()
        return await notification_manager.get_analytics(start_date, end_date)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# Queue Management Endpoints
|
||||
@app.get("/api/queue/status")
async def get_queue_status():
    """Return the queue manager's current status report.

    Any failure is reported as HTTP 500.
    """
    try:
        return await queue_manager.get_queue_status()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/api/queue/retry/{notification_id}")
async def retry_failed_notification(
    notification_id: str,
    background_tasks: BackgroundTasks
):
    """Queue a retry for a notification whose status is FAILED.

    Raises:
        HTTPException: 404 when the notification does not exist, 400 when it
            is not in FAILED status, 500 on any unexpected error.
    """
    try:
        target = await notification_manager.get_notification(notification_id)

        if not target:
            raise HTTPException(status_code=404, detail="Notification not found")
        if target.status != NotificationStatus.FAILED:
            raise HTTPException(status_code=400, detail="Only failed notifications can be retried")

        # The retry itself runs after the response has been sent.
        background_tasks.add_task(notification_manager.retry_notification, target)
        return {"status": "retry_queued", "notification_id": notification_id}

    except HTTPException:
        # Deliberate 404/400 responses pass through unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# WebSocket Endpoint
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
|
||||
@app.websocket("/ws/notifications/{user_id}")
async def websocket_notifications(websocket: WebSocket, user_id: str):
    """WebSocket endpoint for real-time notifications.

    Text protocol handled here:
      * ``ping``       -> replies ``pong``
      * ``read:<id>``  -> marks notification ``<id>`` as read

    The connection is always deregistered from the WebSocket server when the
    receive loop ends, whatever the reason.
    """
    await ws_server.connect(websocket, user_id)
    try:
        while True:
            # Blocks until the client sends something; this also keeps the
            # connection alive.
            data = await websocket.receive_text()

            if data == "ping":
                await websocket.send_text("pong")
            elif data.startswith("read:"):
                # Everything after the first colon is the id, so ids that
                # themselves contain ':' are not truncated.
                notification_id = data.partition(":")[2]
                if notification_id:
                    await notification_manager.mark_as_read(notification_id)

    except WebSocketDisconnect:
        # Normal client disconnect; cleanup happens in ``finally``.
        pass
    except Exception as e:
        logger.error(f"WebSocket error for user {user_id}: {e}")
    finally:
        # Also runs on task cancellation, which the except clauses above do
        # not catch.
        ws_server.disconnect(user_id)
|
||||
|
||||
# Device Token Management
|
||||
@app.post("/api/devices/register")
async def register_device_token(
    user_id: str,
    device_token: str,
    device_type: str = Query(..., regex="^(ios|android|web)$")
):
    """Register a device token for push notifications.

    Raises:
        HTTPException: 500 when the manager reports failure or on any
            unexpected error.
    """
    try:
        success = await notification_manager.register_device_token(
            user_id=user_id,
            device_token=device_token,
            device_type=device_type,
        )
        if not success:
            raise HTTPException(status_code=500, detail="Failed to register device token")

        return {
            "status": "registered",
            "user_id": user_id,
            "device_type": device_type,
        }

    except HTTPException:
        # Keep the deliberate error as-is rather than re-wrapping it, which
        # would replace its detail with the exception's string form.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.delete("/api/devices/{device_token}")
async def unregister_device_token(device_token: str):
    """Unregister a device token.

    Raises:
        HTTPException: 404 when the manager reports the token as not found,
            500 on any unexpected error.
    """
    try:
        success = await notification_manager.unregister_device_token(device_token)
        if not success:
            raise HTTPException(status_code=404, detail="Device token not found")
        return {"status": "unregistered", "device_token": device_token}

    except HTTPException:
        # Let the deliberate 404 through instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
if __name__ == "__main__":
    # Direct execution: serve on all interfaces, port 8000, with auto-reload.
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|
||||
201
services/notifications/backend/models.py
Normal file
201
services/notifications/backend/models.py
Normal file
@ -0,0 +1,201 @@
|
||||
"""
|
||||
Data models for Notification Service
|
||||
"""
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any, Literal
|
||||
from enum import Enum
|
||||
|
||||
class NotificationChannel(str, Enum):
    """Delivery channels a notification can be sent through.

    Members are ``str`` subclasses, so they compare equal to their values.
    """

    EMAIL = "email"
    SMS = "sms"
    PUSH = "push"
    IN_APP = "in_app"
|
||||
|
||||
class NotificationStatus(str, Enum):
    """Status values a notification can hold.

    Members are ``str`` subclasses, so they compare equal to their values.
    """

    PENDING = "pending"
    SENT = "sent"
    DELIVERED = "delivered"
    READ = "read"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
||||
|
||||
class NotificationPriority(str, Enum):
    """Priority levels, declared from lowest to highest.

    Members are ``str`` subclasses, so they compare equal to their values.
    """

    LOW = "low"
    NORMAL = "normal"
    HIGH = "high"
    URGENT = "urgent"
|
||||
|
||||
class NotificationCategory(str, Enum):
    """Categories used to classify notifications.

    Members are ``str`` subclasses, so they compare equal to their values.
    """

    SYSTEM = "system"
    MARKETING = "marketing"
    TRANSACTION = "transaction"
    SOCIAL = "social"
    SECURITY = "security"
    UPDATE = "update"
|
||||
|
||||
class Notification(BaseModel):
    """A single notification addressed to one user on one channel."""

    # --- identity and content ---
    id: Optional[str] = Field(default=None, description="Unique notification ID")
    user_id: str = Field(..., description="Target user ID")
    title: str = Field(..., description="Notification title")
    message: str = Field(..., description="Notification message")
    channel: NotificationChannel = Field(..., description="Delivery channel")

    # --- classification ---
    status: NotificationStatus = NotificationStatus.PENDING
    priority: NotificationPriority = NotificationPriority.NORMAL
    category: NotificationCategory = NotificationCategory.SYSTEM

    # --- payload / templating ---
    data: Optional[Dict[str, Any]] = Field(default=None, description="Additional data")
    template_id: Optional[str] = Field(default=None, description="Template ID if using template")

    # --- lifecycle timestamps ---
    scheduled_at: Optional[datetime] = Field(default=None, description="Scheduled delivery time")
    sent_at: Optional[datetime] = Field(default=None, description="Actual sent time")
    delivered_at: Optional[datetime] = Field(default=None, description="Delivery confirmation time")
    read_at: Optional[datetime] = Field(default=None, description="Read time")

    # --- failure tracking ---
    retry_count: int = Field(default=0, description="Number of retry attempts")
    error_message: Optional[str] = Field(default=None, description="Error message if failed")

    # --- bookkeeping (local time at instantiation) ---
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)

    class Config:
        # Serialise datetimes as ISO-8601 strings.
        json_encoders = {
            datetime: lambda v: v.isoformat()
        }
|
||||
|
||||
class NotificationTemplate(BaseModel):
    """A reusable message template bound to one channel and one category."""

    id: Optional[str] = Field(default=None, description="Template ID")
    name: str = Field(..., description="Template name")
    channel: NotificationChannel = Field(..., description="Target channel")
    category: NotificationCategory = Field(..., description="Template category")

    # Template texts; the subject is optional.
    subject_template: Optional[str] = Field(default=None, description="Subject template (for email)")
    body_template: str = Field(..., description="Body template with variables")
    variables: List[str] = Field(default_factory=list, description="List of required variables")

    metadata: Dict[str, Any] = Field(default_factory=dict, description="Template metadata")
    is_active: bool = Field(default=True, description="Template active status")

    # Bookkeeping (local time at instantiation).
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)

    class Config:
        # Serialise datetimes as ISO-8601 strings.
        json_encoders = {
            datetime: lambda v: v.isoformat()
        }
|
||||
|
||||
class NotificationPreference(BaseModel):
    """Per-user notification settings.

    By default every channel except SMS and every category except MARKETING
    is enabled.
    """

    user_id: str = Field(..., description="User ID")

    # Channel -> enabled flag.
    channels: Dict[NotificationChannel, bool] = Field(
        default_factory=lambda: {
            NotificationChannel.EMAIL: True,
            NotificationChannel.SMS: False,
            NotificationChannel.PUSH: True,
            NotificationChannel.IN_APP: True,
        }
    )

    # Category -> subscribed flag.
    categories: Dict[NotificationCategory, bool] = Field(
        default_factory=lambda: {
            NotificationCategory.SYSTEM: True,
            NotificationCategory.MARKETING: False,
            NotificationCategory.TRANSACTION: True,
            NotificationCategory.SOCIAL: True,
            NotificationCategory.SECURITY: True,
            NotificationCategory.UPDATE: True,
        }
    )

    quiet_hours: Optional[Dict[str, str]] = Field(
        default=None,
        description="Quiet hours configuration {start: 'HH:MM', end: 'HH:MM'}"
    )
    timezone: str = Field(default="UTC", description="User timezone")
    language: str = Field(default="en", description="Preferred language")
    email_frequency: Literal["immediate", "daily", "weekly"] = "immediate"
    updated_at: datetime = Field(default_factory=datetime.now)

    class Config:
        # Serialise datetimes as ISO-8601 strings.
        json_encoders = {
            datetime: lambda v: v.isoformat()
        }
|
||||
|
||||
class NotificationHistory(BaseModel):
    """One entry of the notification history log."""

    # What was sent, to whom, and how.
    notification_id: str
    user_id: str
    channel: NotificationChannel
    status: NotificationStatus
    title: str
    message: str

    # Lifecycle timestamps and failure detail.
    sent_at: Optional[datetime]
    delivered_at: Optional[datetime]
    read_at: Optional[datetime]
    error_message: Optional[str]

    metadata: Dict[str, Any] = Field(default_factory=dict)

    class Config:
        # Serialise datetimes as ISO-8601 strings.
        json_encoders = {
            datetime: lambda v: v.isoformat()
        }
|
||||
|
||||
class CreateNotificationRequest(BaseModel):
    """Request body for sending a notification to a single user."""

    user_id: str
    title: str
    message: str

    # Delivery options; in-app delivery at normal priority by default.
    channels: List[NotificationChannel] = Field(default=[NotificationChannel.IN_APP])
    priority: NotificationPriority = NotificationPriority.NORMAL
    category: NotificationCategory = NotificationCategory.SYSTEM

    data: Optional[Dict[str, Any]] = None
    template_id: Optional[str] = None
    # When set, the requested delivery time.
    schedule_at: Optional[datetime] = None
|
||||
|
||||
class BulkNotificationRequest(BaseModel):
    """Request body for sending the same notification to several users."""

    user_ids: List[str]
    title: str
    message: str

    # Delivery options; in-app delivery at normal priority by default.
    channels: List[NotificationChannel] = Field(default=[NotificationChannel.IN_APP])
    priority: NotificationPriority = NotificationPriority.NORMAL
    category: NotificationCategory = NotificationCategory.SYSTEM

    data: Optional[Dict[str, Any]] = None
    template_id: Optional[str] = None
|
||||
|
||||
class DeviceToken(BaseModel):
    """A push-notification device token registered for a user."""

    user_id: str
    token: str
    device_type: Literal["ios", "android", "web"]
    app_version: Optional[str] = None

    # Bookkeeping (local time at instantiation).
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)

    class Config:
        # Serialise datetimes as ISO-8601 strings.
        json_encoders = {
            datetime: lambda v: v.isoformat()
        }
|
||||
|
||||
class NotificationStats(BaseModel):
    """Aggregated notification statistics for a reporting period."""

    # Totals.
    total_sent: int
    total_delivered: int
    total_read: int
    total_failed: int

    # Rates.
    delivery_rate: float
    read_rate: float

    # Breakdowns: name -> {counter name -> count}.
    channel_stats: Dict[str, Dict[str, int]]
    category_stats: Dict[str, Dict[str, int]]

    # Label of the period the figures cover.
    period: str
|
||||
|
||||
class NotificationEvent(BaseModel):
    """A tracking event recorded for one notification."""

    event_type: Literal["sent", "delivered", "read", "failed", "clicked"]
    notification_id: str
    user_id: str
    channel: NotificationChannel

    # Defaults to local time at instantiation.
    timestamp: datetime = Field(default_factory=datetime.now)
    metadata: Dict[str, Any] = Field(default_factory=dict)

    class Config:
        # Serialise datetimes as ISO-8601 strings.
        json_encoders = {
            datetime: lambda v: v.isoformat()
        }
|
||||
375
services/notifications/backend/notification_manager.py
Normal file
375
services/notifications/backend/notification_manager.py
Normal file
@ -0,0 +1,375 @@
|
||||
"""
|
||||
Notification Manager - Core notification orchestration
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Any
|
||||
import uuid
|
||||
from models import (
|
||||
Notification, NotificationChannel, NotificationStatus,
|
||||
NotificationPriority, NotificationHistory, NotificationPreference
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class NotificationManager:
|
||||
"""Manages notification creation, delivery, and tracking"""
|
||||
|
||||
def __init__(
    self,
    channel_handlers: Dict[NotificationChannel, Any],
    queue_manager: Any,
    template_engine: Any,
    preference_manager: Any
):
    """Store the collaborators and set up the in-memory demo stores.

    Args:
        channel_handlers: Handler to use for each delivery channel.
        queue_manager: Queue used for scheduling and retrying notifications.
        template_engine: Engine used to look up and render templates.
        preference_manager: Source of per-user notification preferences.
    """
    # Injected collaborators.
    self.channel_handlers = channel_handlers
    self.queue_manager = queue_manager
    self.template_engine = template_engine
    self.preference_manager = preference_manager

    # Set to True by start() and back to False by stop().
    self.is_running = False

    # In-memory stores for demo purposes.
    self.notification_store = {}  # notification id -> Notification
    self.history_store = []
    self.device_tokens = {}
|
||||
|
||||
async def start(self):
    """Mark the manager as running and launch its background workers.

    The two worker tasks are kept in ``self._background_tasks``. The event
    loop only holds weak references to tasks, so a task that nothing else
    references may be garbage-collected before it finishes.
    """
    self.is_running = True
    # Workers for queued and for scheduled notifications.
    self._background_tasks = [
        asyncio.create_task(self._process_notification_queue()),
        asyncio.create_task(self._process_scheduled_notifications()),
    ]
    logger.info("Notification manager started")
|
||||
|
||||
async def stop(self):
    """Flag the manager as no longer running and log the shutdown."""
    self.is_running = False
    logger.info("Notification manager stopped")
|
||||
|
||||
async def create_notification(
    self,
    user_id: str,
    title: str,
    message: str,
    channels: List[NotificationChannel],
    priority: NotificationPriority = NotificationPriority.NORMAL,
    data: Optional[Dict[str, Any]] = None,
    template_id: Optional[str] = None,
    schedule_at: Optional[datetime] = None
) -> Notification:
    """Create and store a notification for ``user_id``.

    Channels the user has disabled in their preferences are removed first.
    A single notification is created for the first remaining channel, or for
    IN_APP when none remain. If ``template_id`` resolves to a template, the
    message is replaced by the rendered template (``data`` supplies the
    variables); the title is left as given.

    Returns:
        The stored ``Notification``.
    """
    # Drop channels the user has switched off; unknown channels stay enabled.
    preferences = await self.preference_manager.get_user_preferences(user_id)
    if preferences:
        channels = [ch for ch in channels if preferences.channels.get(ch, True)]

    # Apply template if provided
    if template_id:
        template = await self.template_engine.get_template(template_id)
        if template:
            message = await self.template_engine.render_template(template, data or {})

    # Only one notification is created, on the first channel. Make the
    # dropped channels visible instead of discarding them silently.
    if len(channels) > 1:
        logger.warning(
            f"Notification for user {user_id} requested {len(channels)} channels; "
            f"only {channels[0]} will be used"
        )

    # NOTE(review): when preferences filter out every channel this falls back
    # to IN_APP even if the user disabled IN_APP; confirm that is intended.
    notification = Notification(
        id=str(uuid.uuid4()),
        user_id=user_id,
        title=title,
        message=message,
        channel=channels[0] if channels else NotificationChannel.IN_APP,
        priority=priority,
        data=data,
        template_id=template_id,
        scheduled_at=schedule_at,
        created_at=datetime.now()
    )

    # Store notification
    self.notification_store[notification.id] = notification

    logger.info(f"Created notification {notification.id} for user {user_id}")
    return notification
|
||||
|
||||
async def send_notification(self, notification: Notification):
|
||||
"""Send a single notification"""
|
||||
try:
|
||||
# Check if notification should be sent now
|
||||
if notification.scheduled_at and notification.scheduled_at > datetime.now():
|
||||
await self.queue_manager.schedule_notification(notification, notification.scheduled_at)
|
||||
return
|
||||
|
||||
# Get the appropriate handler
|
||||
handler = self.channel_handlers.get(notification.channel)
|
||||
if not handler:
|
||||
raise ValueError(f"No handler for channel {notification.channel}")
|
||||
|
||||
# Send through the channel
|
||||
success = await handler.send(notification)
|
||||
|
||||
if success:
|
||||
notification.status = NotificationStatus.SENT
|
||||
notification.sent_at = datetime.now()
|
||||
logger.info(f"Notification {notification.id} sent successfully")
|
||||
else:
|
||||
notification.status = NotificationStatus.FAILED
|
||||
notification.retry_count += 1
|
||||
logger.error(f"Failed to send notification {notification.id}")
|
||||
|
||||
# Retry if needed
|
||||
if notification.retry_count < self._get_max_retries(notification.priority):
|
||||
await self.queue_manager.enqueue_notification(notification)
|
||||
|
||||
# Update notification
|
||||
self.notification_store[notification.id] = notification
|
||||
|
||||
# Add to history
|
||||
await self._add_to_history(notification)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending notification {notification.id}: {e}")
|
||||
notification.status = NotificationStatus.FAILED
|
||||
notification.error_message = str(e)
|
||||
self.notification_store[notification.id] = notification
|
||||
|
||||
async def send_bulk_notifications(self, notifications: List[Notification]):
|
||||
"""Send multiple notifications"""
|
||||
tasks = []
|
||||
for notification in notifications:
|
||||
tasks.append(self.send_notification(notification))
|
||||
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
async def mark_as_read(self, notification_id: str) -> bool:
|
||||
"""Mark notification as read"""
|
||||
notification = self.notification_store.get(notification_id)
|
||||
if notification:
|
||||
notification.status = NotificationStatus.READ
|
||||
notification.read_at = datetime.now()
|
||||
self.notification_store[notification_id] = notification
|
||||
logger.info(f"Notification {notification_id} marked as read")
|
||||
return True
|
||||
return False
|
||||
|
||||
async def delete_notification(self, notification_id: str) -> bool:
|
||||
"""Delete a notification"""
|
||||
if notification_id in self.notification_store:
|
||||
del self.notification_store[notification_id]
|
||||
logger.info(f"Notification {notification_id} deleted")
|
||||
return True
|
||||
return False
|
||||
|
||||
async def get_notification(self, notification_id: str) -> Optional[Notification]:
|
||||
"""Get a notification by ID"""
|
||||
return self.notification_store.get(notification_id)
|
||||
|
||||
async def get_user_notifications(
|
||||
self,
|
||||
user_id: str,
|
||||
status: Optional[NotificationStatus] = None,
|
||||
channel: Optional[NotificationChannel] = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> List[Notification]:
|
||||
"""Get notifications for a user"""
|
||||
notifications = []
|
||||
|
||||
for notification in self.notification_store.values():
|
||||
if notification.user_id != user_id:
|
||||
continue
|
||||
if status and notification.status != status:
|
||||
continue
|
||||
if channel and notification.channel != channel:
|
||||
continue
|
||||
notifications.append(notification)
|
||||
|
||||
# Sort by created_at descending
|
||||
notifications.sort(key=lambda x: x.created_at, reverse=True)
|
||||
|
||||
# Apply pagination
|
||||
return notifications[offset:offset + limit]
|
||||
|
||||
async def retry_notification(self, notification: Notification):
|
||||
"""Retry a failed notification"""
|
||||
notification.retry_count += 1
|
||||
notification.status = NotificationStatus.PENDING
|
||||
notification.error_message = None
|
||||
await self.send_notification(notification)
|
||||
|
||||
async def get_notification_history(
|
||||
self,
|
||||
user_id: Optional[str] = None,
|
||||
channel: Optional[NotificationChannel] = None,
|
||||
status: Optional[NotificationStatus] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
limit: int = 100
|
||||
) -> List[NotificationHistory]:
|
||||
"""Get notification history"""
|
||||
history = []
|
||||
|
||||
for entry in self.history_store:
|
||||
if user_id and entry.user_id != user_id:
|
||||
continue
|
||||
if channel and entry.channel != channel:
|
||||
continue
|
||||
if status and entry.status != status:
|
||||
continue
|
||||
if start_date and entry.sent_at and entry.sent_at < start_date:
|
||||
continue
|
||||
if end_date and entry.sent_at and entry.sent_at > end_date:
|
||||
continue
|
||||
history.append(entry)
|
||||
|
||||
# Sort by sent_at descending and limit
|
||||
history.sort(key=lambda x: x.sent_at or datetime.min, reverse=True)
|
||||
return history[:limit]
|
||||
|
||||
async def get_analytics(self, start_date: datetime, end_date: datetime) -> Dict[str, Any]:
|
||||
"""Get notification analytics"""
|
||||
total_sent = 0
|
||||
total_delivered = 0
|
||||
total_read = 0
|
||||
total_failed = 0
|
||||
channel_stats = {}
|
||||
|
||||
for notification in self.notification_store.values():
|
||||
if notification.created_at < start_date or notification.created_at > end_date:
|
||||
continue
|
||||
|
||||
if notification.status == NotificationStatus.SENT:
|
||||
total_sent += 1
|
||||
elif notification.status == NotificationStatus.DELIVERED:
|
||||
total_delivered += 1
|
||||
elif notification.status == NotificationStatus.READ:
|
||||
total_read += 1
|
||||
elif notification.status == NotificationStatus.FAILED:
|
||||
total_failed += 1
|
||||
|
||||
# Channel stats
|
||||
channel_name = notification.channel.value
|
||||
if channel_name not in channel_stats:
|
||||
channel_stats[channel_name] = {
|
||||
"sent": 0,
|
||||
"delivered": 0,
|
||||
"read": 0,
|
||||
"failed": 0
|
||||
}
|
||||
|
||||
if notification.status == NotificationStatus.SENT:
|
||||
channel_stats[channel_name]["sent"] += 1
|
||||
elif notification.status == NotificationStatus.DELIVERED:
|
||||
channel_stats[channel_name]["delivered"] += 1
|
||||
elif notification.status == NotificationStatus.READ:
|
||||
channel_stats[channel_name]["read"] += 1
|
||||
elif notification.status == NotificationStatus.FAILED:
|
||||
channel_stats[channel_name]["failed"] += 1
|
||||
|
||||
total = total_sent + total_delivered + total_read + total_failed
|
||||
|
||||
return {
|
||||
"period": f"{start_date.isoformat()} to {end_date.isoformat()}",
|
||||
"total_notifications": total,
|
||||
"total_sent": total_sent,
|
||||
"total_delivered": total_delivered,
|
||||
"total_read": total_read,
|
||||
"total_failed": total_failed,
|
||||
"delivery_rate": (total_delivered / total * 100) if total > 0 else 0,
|
||||
"read_rate": (total_read / total * 100) if total > 0 else 0,
|
||||
"channel_stats": channel_stats
|
||||
}
|
||||
|
||||
async def register_device_token(
|
||||
self,
|
||||
user_id: str,
|
||||
device_token: str,
|
||||
device_type: str
|
||||
) -> bool:
|
||||
"""Register a device token for push notifications"""
|
||||
if user_id not in self.device_tokens:
|
||||
self.device_tokens[user_id] = []
|
||||
|
||||
# Check if token already exists
|
||||
for token in self.device_tokens[user_id]:
|
||||
if token["token"] == device_token:
|
||||
# Update existing token
|
||||
token["device_type"] = device_type
|
||||
token["updated_at"] = datetime.now()
|
||||
return True
|
||||
|
||||
# Add new token
|
||||
self.device_tokens[user_id].append({
|
||||
"token": device_token,
|
||||
"device_type": device_type,
|
||||
"created_at": datetime.now(),
|
||||
"updated_at": datetime.now()
|
||||
})
|
||||
|
||||
logger.info(f"Registered device token for user {user_id}")
|
||||
return True
|
||||
|
||||
async def unregister_device_token(self, device_token: str) -> bool:
|
||||
"""Unregister a device token"""
|
||||
for user_id, tokens in self.device_tokens.items():
|
||||
for i, token in enumerate(tokens):
|
||||
if token["token"] == device_token:
|
||||
del self.device_tokens[user_id][i]
|
||||
logger.info(f"Unregistered device token for user {user_id}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_max_retries(self, priority: NotificationPriority) -> int:
|
||||
"""Get max retries based on priority"""
|
||||
retry_map = {
|
||||
NotificationPriority.LOW: 1,
|
||||
NotificationPriority.NORMAL: 3,
|
||||
NotificationPriority.HIGH: 5,
|
||||
NotificationPriority.URGENT: 10
|
||||
}
|
||||
return retry_map.get(priority, 3)
|
||||
|
||||
async def _add_to_history(self, notification: Notification):
|
||||
"""Add notification to history"""
|
||||
history_entry = NotificationHistory(
|
||||
notification_id=notification.id,
|
||||
user_id=notification.user_id,
|
||||
channel=notification.channel,
|
||||
status=notification.status,
|
||||
title=notification.title,
|
||||
message=notification.message,
|
||||
sent_at=notification.sent_at,
|
||||
delivered_at=notification.delivered_at,
|
||||
read_at=notification.read_at,
|
||||
error_message=notification.error_message,
|
||||
metadata={"priority": notification.priority.value}
|
||||
)
|
||||
self.history_store.append(history_entry)
|
||||
|
||||
async def _process_notification_queue(self):
|
||||
"""Process queued notifications"""
|
||||
while self.is_running:
|
||||
try:
|
||||
# Get notification from queue
|
||||
notification_data = await self.queue_manager.dequeue_notification()
|
||||
if notification_data:
|
||||
notification = Notification(**notification_data)
|
||||
await self.send_notification(notification)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing notification queue: {e}")
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def _process_scheduled_notifications(self):
|
||||
"""Process scheduled notifications"""
|
||||
while self.is_running:
|
||||
try:
|
||||
# Check for scheduled notifications
|
||||
now = datetime.now()
|
||||
for notification in self.notification_store.values():
|
||||
if (notification.scheduled_at and
|
||||
notification.scheduled_at <= now and
|
||||
notification.status == NotificationStatus.PENDING):
|
||||
await self.send_notification(notification)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing scheduled notifications: {e}")
|
||||
|
||||
await asyncio.sleep(10) # Check every 10 seconds
|
||||
340
services/notifications/backend/preference_manager.py
Normal file
340
services/notifications/backend/preference_manager.py
Normal file
@ -0,0 +1,340 @@
|
||||
"""
|
||||
Preference Manager for user notification preferences
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
import motor.motor_asyncio
|
||||
from models import NotificationPreference, NotificationChannel, NotificationCategory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class PreferenceManager:
    """Manages user notification preferences.

    Preferences are read from and written to a MongoDB collection when a
    connection is available, with an in-memory cache in front of it. When
    MongoDB cannot be reached the cache alone is used.
    """

    def __init__(self, mongodb_url: str = "mongodb://mongodb:27017", database_name: str = "notifications"):
        """Remember the connection settings; no I/O happens until connect().

        Args:
            mongodb_url: MongoDB connection URL.
            database_name: Name of the database holding the preferences.
        """
        self.mongodb_url = mongodb_url
        self.database_name = database_name
        self.client = None
        self.db = None
        self.preferences_collection = None
        self.is_connected = False

        # In-memory cache for demo
        self.preferences_cache = {}

    def _has_collection(self) -> bool:
        """Return True when MongoDB is connected and a collection handle exists.

        The handle is compared with None on purpose: PyMongo 4 collection
        objects raise NotImplementedError on truth testing, and Motor 3 is
        believed to do the same (TODO confirm).
        """
        return self.is_connected and self.preferences_collection is not None

    async def connect(self):
        """Connect to MongoDB, falling back to in-memory storage on failure."""
        try:
            self.client = motor.motor_asyncio.AsyncIOMotorClient(self.mongodb_url)
            self.db = self.client[self.database_name]
            self.preferences_collection = self.db["preferences"]

            # Test connection
            await self.client.admin.command('ping')
            self.is_connected = True

            await self._create_indexes()

            logger.info("Connected to MongoDB for preferences")

        except Exception as e:
            logger.error(f"Failed to connect to MongoDB: {e}")
            # Fallback to in-memory storage
            self.is_connected = False
            logger.warning("Using in-memory storage for preferences")

    async def close(self):
        """Close the MongoDB connection if one was opened."""
        if self.client:
            self.client.close()
            self.is_connected = False
            logger.info("Disconnected from MongoDB")

    async def _create_indexes(self):
        """Create the unique ``user_id`` index; failures are only logged."""
        # Compare with None instead of truth-testing the collection handle.
        if self.preferences_collection is None:
            return

        try:
            await self.preferences_collection.create_index("user_id", unique=True)
            logger.info("Created indexes for preferences collection")
        except Exception as e:
            logger.error(f"Failed to create indexes: {e}")

    async def get_user_preferences(self, user_id: str) -> Optional[NotificationPreference]:
        """Return a user's preferences.

        Lookup order: cache, then MongoDB (the result is cached), then the
        defaults. Defaults are also returned when the lookup raises; they are
        not cached.
        """
        try:
            if user_id in self.preferences_cache:
                return self.preferences_cache[user_id]

            if self._has_collection():
                doc = await self.preferences_collection.find_one({"user_id": user_id})

                if doc:
                    doc.pop('_id', None)  # Remove MongoDB ID
                    preference = NotificationPreference(**doc)
                    self.preferences_cache[user_id] = preference
                    return preference

            # Return default preferences if not found
            return self._get_default_preferences(user_id)

        except Exception as e:
            logger.error(f"Failed to get preferences for user {user_id}: {e}")
            return self._get_default_preferences(user_id)

    async def update_user_preferences(
        self,
        user_id: str,
        preferences: NotificationPreference
    ) -> bool:
        """Store a user's preferences in the cache and, if connected, MongoDB.

        The object's ``user_id`` and ``updated_at`` are overwritten first.

        Returns:
            True when only the cache is in use, or when the upsert modified
            or inserted a document; False otherwise or on error.
        """
        try:
            preferences.user_id = user_id
            preferences.updated_at = datetime.now()

            self.preferences_cache[user_id] = preferences

            if not self._has_collection():
                # If not connected, just use cache
                return True

            result = await self.preferences_collection.update_one(
                {"user_id": user_id},
                {"$set": preferences.dict()},
                upsert=True
            )

            logger.info(f"Updated preferences for user {user_id}")
            return result.modified_count > 0 or result.upserted_id is not None

        except Exception as e:
            logger.error(f"Failed to update preferences for user {user_id}: {e}")
            return False

    async def _load_preferences(self, user_id: str) -> NotificationPreference:
        """Return the user's preferences, substituting defaults when falsy."""
        preferences = await self.get_user_preferences(user_id)
        if not preferences:
            preferences = self._get_default_preferences(user_id)
        return preferences

    async def _set_category_enabled(self, user_id: str, category: str, enabled: bool) -> bool:
        """Switch one category on or off; return False for an unknown name."""
        preferences = await self._load_preferences(user_id)

        # The name is matched against the enum's member names (upper case)
        # and then resolved through its lower-case value.
        if not hasattr(NotificationCategory, category.upper()):
            return False

        cat_enum = NotificationCategory(category.lower())
        preferences.categories[cat_enum] = enabled
        return await self.update_user_preferences(user_id, preferences)

    async def unsubscribe_category(self, user_id: str, category: str) -> bool:
        """Unsubscribe a user from a notification category.

        Returns:
            False for an unknown category or on error; otherwise the result
            of saving the preferences.
        """
        try:
            return await self._set_category_enabled(user_id, category, False)
        except Exception as e:
            logger.error(f"Failed to unsubscribe user {user_id} from {category}: {e}")
            return False

    async def subscribe_category(self, user_id: str, category: str) -> bool:
        """Subscribe a user to a notification category.

        Returns:
            False for an unknown category or on error; otherwise the result
            of saving the preferences.
        """
        try:
            return await self._set_category_enabled(user_id, category, True)
        except Exception as e:
            logger.error(f"Failed to subscribe user {user_id} to {category}: {e}")
            return False

    async def enable_channel(self, user_id: str, channel: NotificationChannel) -> bool:
        """Enable a notification channel for a user; False on error."""
        try:
            preferences = await self._load_preferences(user_id)
            preferences.channels[channel] = True
            return await self.update_user_preferences(user_id, preferences)
        except Exception as e:
            logger.error(f"Failed to enable channel {channel} for user {user_id}: {e}")
            return False

    async def disable_channel(self, user_id: str, channel: NotificationChannel) -> bool:
        """Disable a notification channel for a user; False on error."""
        try:
            preferences = await self._load_preferences(user_id)
            preferences.channels[channel] = False
            return await self.update_user_preferences(user_id, preferences)
        except Exception as e:
            logger.error(f"Failed to disable channel {channel} for user {user_id}: {e}")
            return False

    async def set_quiet_hours(
        self,
        user_id: str,
        start_time: str,
        end_time: str
    ) -> bool:
        """Store a user's quiet hours as ``{"start": ..., "end": ...}``.

        The two strings are stored as given; they are not validated here.
        """
        try:
            preferences = await self._load_preferences(user_id)
            preferences.quiet_hours = {
                "start": start_time,
                "end": end_time
            }
            return await self.update_user_preferences(user_id, preferences)
        except Exception as e:
            logger.error(f"Failed to set quiet hours for user {user_id}: {e}")
            return False

    async def clear_quiet_hours(self, user_id: str) -> bool:
        """Remove a user's quiet hours; False on error."""
        try:
            preferences = await self._load_preferences(user_id)
            preferences.quiet_hours = None
            return await self.update_user_preferences(user_id, preferences)
        except Exception as e:
            logger.error(f"Failed to clear quiet hours for user {user_id}: {e}")
            return False

    async def set_email_frequency(self, user_id: str, frequency: str) -> bool:
        """Set the email frequency to "immediate", "daily" or "weekly".

        Returns:
            False for any other value or on error.
        """
        try:
            if frequency not in ["immediate", "daily", "weekly"]:
                return False

            preferences = await self._load_preferences(user_id)
            preferences.email_frequency = frequency
            return await self.update_user_preferences(user_id, preferences)
        except Exception as e:
            logger.error(f"Failed to set email frequency for user {user_id}: {e}")
            return False

    async def batch_get_preferences(self, user_ids: List[str]) -> Dict[str, NotificationPreference]:
        """Return the preferences of several users, keyed by user ID."""
        results = {}
        for user_id in user_ids:
            pref = await self.get_user_preferences(user_id)
            if pref:
                results[user_id] = pref
        return results

    async def delete_user_preferences(self, user_id: str) -> bool:
        """Delete a user's preferences from the cache and, if connected, MongoDB.

        Returns:
            True when only the cache is in use, or when a document was
            deleted; False when no document matched or on error.
        """
        try:
            self.preferences_cache.pop(user_id, None)

            if self._has_collection():
                result = await self.preferences_collection.delete_one({"user_id": user_id})
                logger.info(f"Deleted preferences for user {user_id}")
                return result.deleted_count > 0

            return True

        except Exception as e:
            logger.error(f"Failed to delete preferences for user {user_id}: {e}")
            return False

    def _get_default_preferences(self, user_id: str) -> NotificationPreference:
        """Build the default preferences: SMS and marketing off, the rest on."""
        return NotificationPreference(
            user_id=user_id,
            channels={
                NotificationChannel.EMAIL: True,
                NotificationChannel.SMS: False,
                NotificationChannel.PUSH: True,
                NotificationChannel.IN_APP: True
            },
            categories={
                NotificationCategory.SYSTEM: True,
                NotificationCategory.MARKETING: False,
                NotificationCategory.TRANSACTION: True,
                NotificationCategory.SOCIAL: True,
                NotificationCategory.SECURITY: True,
                NotificationCategory.UPDATE: True
            },
            email_frequency="immediate",
            timezone="UTC",
            language="en"
        )

    async def is_notification_allowed(
        self,
        user_id: str,
        channel: NotificationChannel,
        category: NotificationCategory
    ) -> bool:
        """Check whether the user's preferences allow this notification.

        Channels and categories missing from the preferences count as
        enabled. Quiet hours are not enforced yet.
        """
        preferences = await self.get_user_preferences(user_id)

        if not preferences:
            return True  # Allow by default if no preferences

        if not preferences.channels.get(channel, True):
            return False

        if not preferences.categories.get(category, True):
            return False

        # Check quiet hours
        if preferences.quiet_hours and channel != NotificationChannel.IN_APP:
            # Would need to check current time against quiet hours
            # For demo, we'll allow all
            pass

        return True
|
||||
304
services/notifications/backend/queue_manager.py
Normal file
304
services/notifications/backend/queue_manager.py
Normal file
@ -0,0 +1,304 @@
|
||||
"""
|
||||
Notification Queue Manager with priority support
|
||||
"""
|
||||
import logging
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
import redis.asyncio as redis
|
||||
from models import NotificationPriority
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class NotificationQueueManager:
|
||||
"""Manages notification queues with priority levels"""
|
||||
|
||||
def __init__(self, redis_url: str = "redis://redis:6379"):
|
||||
self.redis_url = redis_url
|
||||
self.redis_client = None
|
||||
self.is_connected = False
|
||||
|
||||
# Queue names by priority
|
||||
self.queue_names = {
|
||||
NotificationPriority.URGENT: "notifications:queue:urgent",
|
||||
NotificationPriority.HIGH: "notifications:queue:high",
|
||||
NotificationPriority.NORMAL: "notifications:queue:normal",
|
||||
NotificationPriority.LOW: "notifications:queue:low"
|
||||
}
|
||||
|
||||
# Scheduled notifications sorted set
|
||||
self.scheduled_key = "notifications:scheduled"
|
||||
|
||||
# Failed notifications queue (DLQ)
|
||||
self.dlq_key = "notifications:dlq"
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to Redis"""
|
||||
try:
|
||||
self.redis_client = await redis.from_url(self.redis_url)
|
||||
await self.redis_client.ping()
|
||||
self.is_connected = True
|
||||
logger.info("Connected to Redis for notification queue")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Redis: {e}")
|
||||
self.is_connected = False
|
||||
raise
|
||||
|
||||
async def close(self):
|
||||
"""Close Redis connection"""
|
||||
if self.redis_client:
|
||||
await self.redis_client.close()
|
||||
self.is_connected = False
|
||||
logger.info("Disconnected from Redis")
|
||||
|
||||
async def enqueue_notification(self, notification: Any, priority: Optional[NotificationPriority] = None):
|
||||
"""Add notification to queue based on priority"""
|
||||
if not self.is_connected:
|
||||
logger.error("Redis not connected")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Use notification's priority or provided priority
|
||||
if priority is None:
|
||||
priority = notification.priority if hasattr(notification, 'priority') else NotificationPriority.NORMAL
|
||||
|
||||
queue_name = self.queue_names.get(priority, self.queue_names[NotificationPriority.NORMAL])
|
||||
|
||||
# Serialize notification
|
||||
notification_data = notification.dict() if hasattr(notification, 'dict') else notification
|
||||
notification_json = json.dumps(notification_data, default=str)
|
||||
|
||||
# Add to appropriate queue
|
||||
await self.redis_client.lpush(queue_name, notification_json)
|
||||
|
||||
logger.info(f"Enqueued notification to {queue_name}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to enqueue notification: {e}")
|
||||
return False
|
||||
|
||||
async def dequeue_notification(self, timeout: int = 1) -> Optional[Dict[str, Any]]:
|
||||
"""Dequeue notification with priority order"""
|
||||
if not self.is_connected:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Check queues in priority order
|
||||
for priority in [NotificationPriority.URGENT, NotificationPriority.HIGH,
|
||||
NotificationPriority.NORMAL, NotificationPriority.LOW]:
|
||||
queue_name = self.queue_names[priority]
|
||||
|
||||
# Try to get from this queue
|
||||
result = await self.redis_client.brpop(queue_name, timeout=timeout)
|
||||
|
||||
if result:
|
||||
_, notification_json = result
|
||||
notification_data = json.loads(notification_json)
|
||||
logger.debug(f"Dequeued notification from {queue_name}")
|
||||
return notification_data
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to dequeue notification: {e}")
|
||||
return None
|
||||
|
||||
async def schedule_notification(self, notification: Any, scheduled_time: datetime):
|
||||
"""Schedule a notification for future delivery"""
|
||||
if not self.is_connected:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Serialize notification
|
||||
notification_data = notification.dict() if hasattr(notification, 'dict') else notification
|
||||
notification_json = json.dumps(notification_data, default=str)
|
||||
|
||||
# Add to scheduled set with timestamp as score
|
||||
timestamp = scheduled_time.timestamp()
|
||||
await self.redis_client.zadd(self.scheduled_key, {notification_json: timestamp})
|
||||
|
||||
logger.info(f"Scheduled notification for {scheduled_time}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to schedule notification: {e}")
|
||||
return False
|
||||
|
||||
async def get_due_notifications(self) -> List[Dict[str, Any]]:
|
||||
"""Get notifications that are due for delivery"""
|
||||
if not self.is_connected:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Get current timestamp
|
||||
now = datetime.now().timestamp()
|
||||
|
||||
# Get all notifications with score <= now
|
||||
results = await self.redis_client.zrangebyscore(
|
||||
self.scheduled_key,
|
||||
min=0,
|
||||
max=now,
|
||||
withscores=False
|
||||
)
|
||||
|
||||
notifications = []
|
||||
for notification_json in results:
|
||||
notification_data = json.loads(notification_json)
|
||||
notifications.append(notification_data)
|
||||
|
||||
# Remove from scheduled set
|
||||
await self.redis_client.zrem(self.scheduled_key, notification_json)
|
||||
|
||||
if notifications:
|
||||
logger.info(f"Retrieved {len(notifications)} due notifications")
|
||||
|
||||
return notifications
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get due notifications: {e}")
|
||||
return []
|
||||
|
||||
async def add_to_dlq(self, notification: Any, error_message: str):
|
||||
"""Add failed notification to Dead Letter Queue"""
|
||||
if not self.is_connected:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Add error information
|
||||
notification_data = notification.dict() if hasattr(notification, 'dict') else notification
|
||||
notification_data['dlq_error'] = error_message
|
||||
notification_data['dlq_timestamp'] = datetime.now().isoformat()
|
||||
|
||||
notification_json = json.dumps(notification_data, default=str)
|
||||
|
||||
# Add to DLQ
|
||||
await self.redis_client.lpush(self.dlq_key, notification_json)
|
||||
|
||||
logger.info(f"Added notification to DLQ: {error_message}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add to DLQ: {e}")
|
||||
return False
|
||||
|
||||
async def get_dlq_notifications(self, limit: int = 10) -> List[Dict[str, Any]]:
|
||||
"""Get notifications from Dead Letter Queue"""
|
||||
if not self.is_connected:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Get from DLQ
|
||||
results = await self.redis_client.lrange(self.dlq_key, 0, limit - 1)
|
||||
|
||||
notifications = []
|
||||
for notification_json in results:
|
||||
notification_data = json.loads(notification_json)
|
||||
notifications.append(notification_data)
|
||||
|
||||
return notifications
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get DLQ notifications: {e}")
|
||||
return []
|
||||
|
||||
async def retry_dlq_notification(self, index: int) -> bool:
|
||||
"""Retry a notification from DLQ"""
|
||||
if not self.is_connected:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Get notification at index
|
||||
notification_json = await self.redis_client.lindex(self.dlq_key, index)
|
||||
|
||||
if not notification_json:
|
||||
return False
|
||||
|
||||
# Parse and remove DLQ info
|
||||
notification_data = json.loads(notification_json)
|
||||
notification_data.pop('dlq_error', None)
|
||||
notification_data.pop('dlq_timestamp', None)
|
||||
|
||||
# Re-enqueue
|
||||
priority = NotificationPriority(notification_data.get('priority', 'normal'))
|
||||
queue_name = self.queue_names[priority]
|
||||
|
||||
new_json = json.dumps(notification_data, default=str)
|
||||
await self.redis_client.lpush(queue_name, new_json)
|
||||
|
||||
# Remove from DLQ
|
||||
await self.redis_client.lrem(self.dlq_key, 1, notification_json)
|
||||
|
||||
logger.info(f"Retried DLQ notification at index {index}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retry DLQ notification: {e}")
|
||||
return False
|
||||
|
||||
async def get_queue_status(self) -> Dict[str, Any]:
    """Summarise the sizes of every queue structure.

    Returns:
        ``{"status": "disconnected"}`` when not connected. Otherwise a dict
        with ``status`` set to ``"connected"``, ``queues`` mapping each
        priority value to its list length, ``scheduled`` holding the sorted
        set's cardinality and ``dlq`` the DLQ list length. On failure,
        ``{"status": "error", "error": <message>}``.
    """
    if not self.is_connected:
        return {"status": "disconnected"}

    try:
        report = {"status": "connected", "queues": {}, "scheduled": 0, "dlq": 0}

        # One LLEN per priority queue, keyed by the priority's value.
        per_queue = report["queues"]
        for level, key in self.queue_names.items():
            per_queue[level.value] = await self.redis_client.llen(key)

        report["scheduled"] = await self.redis_client.zcard(self.scheduled_key)
        report["dlq"] = await self.redis_client.llen(self.dlq_key)
        return report

    except Exception as e:
        logger.error(f"Failed to get queue status: {e}")
        return {"status": "error", "error": str(e)}
|
||||
|
||||
async def clear_queue(self, priority: NotificationPriority) -> bool:
    """Delete the Redis key backing the queue for ``priority``.

    Returns:
        True once the key has been deleted; False when not connected or when
        the lookup or delete raises.
    """
    if self.is_connected:
        try:
            target_key = self.queue_names[priority]
            await self.redis_client.delete(target_key)
            logger.info(f"Cleared queue: {target_key}")
            return True
        except Exception as e:
            logger.error(f"Failed to clear queue: {e}")
    return False
|
||||
|
||||
async def clear_all_queues(self) -> bool:
    """Delete every priority queue, the scheduled set and the DLQ.

    All keys are removed with a single multi-key DELETE so a failure cannot
    leave the queues partially cleared.

    Returns:
        True when the delete succeeded; False when not connected or on error.
    """
    if not self.is_connected:
        return False

    try:
        keys = [*self.queue_names.values(), self.scheduled_key, self.dlq_key]
        await self.redis_client.delete(*keys)

        logger.info("Cleared all notification queues")
        return True

    except Exception as e:
        logger.error(f"Failed to clear all queues: {e}")
        return False
|
||||
11
services/notifications/backend/requirements.txt
Normal file
11
services/notifications/backend/requirements.txt
Normal file
@ -0,0 +1,11 @@
|
||||
fastapi==0.109.0
|
||||
uvicorn[standard]==0.27.0
|
||||
pydantic==2.5.3
|
||||
python-dotenv==1.0.0
|
||||
redis==5.0.1
|
||||
motor==3.5.1
|
||||
pymongo==4.6.1
|
||||
httpx==0.26.0
|
||||
websockets==12.0
|
||||
aiofiles==23.2.1
|
||||
python-multipart==0.0.6
|
||||
334
services/notifications/backend/template_engine.py
Normal file
334
services/notifications/backend/template_engine.py
Normal file
@ -0,0 +1,334 @@
|
||||
"""
|
||||
Template Engine for notification templates
|
||||
"""
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
from models import NotificationTemplate, NotificationChannel, NotificationCategory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TemplateEngine:
    """Manages and renders notification templates.

    Templates are kept in an in-memory dict keyed by template id and are
    seeded with a set of default system templates on construction.
    Placeholders use the ``{{variable_name}}`` syntax.
    """

    # Matches one ``{{variable_name}}`` placeholder; group 1 is the name.
    _PLACEHOLDER_RE = re.compile(r'\{\{(\w+)\}\}')

    def __init__(self):
        self.templates = {}  # In-memory storage for demo
        self._load_default_templates()

    async def load_templates(self):
        """Load templates from storage (currently only logs the in-memory count)."""
        # In production, would load from database
        logger.info(f"Loaded {len(self.templates)} templates")

    def _load_default_templates(self):
        """Register the built-in system templates in ``self.templates``."""
        default_templates = [
            NotificationTemplate(
                id="welcome",
                name="Welcome Email",
                channel=NotificationChannel.EMAIL,
                category=NotificationCategory.SYSTEM,
                subject_template="Welcome to {{app_name}}!",
                body_template="""
Hi {{user_name}},

Welcome to {{app_name}}! We're excited to have you on board.

Here are some things you can do to get started:
- Complete your profile
- Explore our features
- Connect with other users

If you have any questions, feel free to reach out to our support team.

Best regards,
The {{app_name}} Team
""",
                variables=["user_name", "app_name"]
            ),
            NotificationTemplate(
                id="password_reset",
                name="Password Reset",
                channel=NotificationChannel.EMAIL,
                category=NotificationCategory.SECURITY,
                subject_template="Password Reset Request",
                body_template="""
Hi {{user_name}},

We received a request to reset your password for {{app_name}}.

Click the link below to reset your password:
{{reset_link}}

This link will expire in {{expiry_hours}} hours.

If you didn't request this, please ignore this email or contact support.

Best regards,
The {{app_name}} Team
""",
                variables=["user_name", "app_name", "reset_link", "expiry_hours"]
            ),
            NotificationTemplate(
                id="order_confirmation",
                name="Order Confirmation",
                channel=NotificationChannel.EMAIL,
                category=NotificationCategory.TRANSACTION,
                subject_template="Order #{{order_id}} Confirmed",
                body_template="""
Hi {{user_name}},

Your order #{{order_id}} has been confirmed!

Order Details:
- Total: {{order_total}}
- Items: {{item_count}}
- Estimated Delivery: {{delivery_date}}

You can track your order status at: {{tracking_link}}

Thank you for your purchase!

Best regards,
The {{app_name}} Team
""",
                variables=["user_name", "app_name", "order_id", "order_total", "item_count", "delivery_date", "tracking_link"]
            ),
            NotificationTemplate(
                id="sms_verification",
                name="SMS Verification",
                channel=NotificationChannel.SMS,
                category=NotificationCategory.SECURITY,
                body_template="Your {{app_name}} verification code is: {{code}}. Valid for {{expiry_minutes}} minutes.",
                variables=["app_name", "code", "expiry_minutes"]
            ),
            NotificationTemplate(
                id="push_reminder",
                name="Push Reminder",
                channel=NotificationChannel.PUSH,
                category=NotificationCategory.UPDATE,
                body_template="{{reminder_text}}",
                variables=["reminder_text"]
            ),
            NotificationTemplate(
                id="in_app_alert",
                name="In-App Alert",
                channel=NotificationChannel.IN_APP,
                category=NotificationCategory.SYSTEM,
                body_template="{{alert_message}}",
                variables=["alert_message"]
            ),
            NotificationTemplate(
                id="weekly_digest",
                name="Weekly Digest",
                channel=NotificationChannel.EMAIL,
                category=NotificationCategory.MARKETING,
                subject_template="Your Weekly {{app_name}} Digest",
                body_template="""
Hi {{user_name}},

Here's what happened this week on {{app_name}}:

📊 Stats:
- New connections: {{new_connections}}
- Messages received: {{messages_count}}
- Activities completed: {{activities_count}}

🔥 Trending:
{{trending_items}}

💡 Tip of the week:
{{weekly_tip}}

See you next week!
The {{app_name}} Team
""",
                variables=["user_name", "app_name", "new_connections", "messages_count", "activities_count", "trending_items", "weekly_tip"]
            ),
            NotificationTemplate(
                id="friend_request",
                name="Friend Request",
                channel=NotificationChannel.IN_APP,
                category=NotificationCategory.SOCIAL,
                body_template="{{sender_name}} sent you a friend request. {{personal_message}}",
                variables=["sender_name", "personal_message"]
            )
        ]

        for template in default_templates:
            self.templates[template.id] = template

    async def create_template(self, template: NotificationTemplate) -> str:
        """Validate and store a new template.

        A UUID is assigned when the template has no id, and ``variables`` is
        recomputed from the placeholders found in the body and subject. A
        template stored under the same id is overwritten.

        Returns:
            The id under which the template was stored.

        Raises:
            ValueError: If the template fails validation.
        """
        if not template.id:
            template.id = str(uuid.uuid4())

        if not self._validate_template(template):
            raise ValueError("Invalid template format")

        template.variables = self._collect_variables(template)
        self.templates[template.id] = template

        logger.info(f"Created template: {template.id}")
        return template.id

    async def update_template(self, template_id: str, template: NotificationTemplate) -> bool:
        """Replace the template stored under ``template_id``.

        The replacement's id is forced to ``template_id``, its ``updated_at``
        is set to the current time, and ``variables`` is recomputed.

        Returns:
            False when no template with that id exists, True after the update.

        Raises:
            ValueError: If the replacement fails validation.
        """
        if template_id not in self.templates:
            return False

        if not self._validate_template(template):
            raise ValueError("Invalid template format")

        template.id = template_id
        template.updated_at = datetime.now()
        template.variables = self._collect_variables(template)

        self.templates[template_id] = template

        logger.info(f"Updated template: {template_id}")
        return True

    async def get_template(self, template_id: str) -> Optional[NotificationTemplate]:
        """Return the template with ``template_id``, or None when unknown."""
        return self.templates.get(template_id)

    async def get_all_templates(self) -> List[NotificationTemplate]:
        """Return all stored templates as a new list."""
        return list(self.templates.values())

    async def delete_template(self, template_id: str) -> bool:
        """Remove a template; return True when it existed, False otherwise."""
        if template_id in self.templates:
            del self.templates[template_id]
            logger.info(f"Deleted template: {template_id}")
            return True
        return False

    async def render_template(self, template: NotificationTemplate, variables: Dict[str, Any]) -> str:
        """Render the template body with ``variables``.

        Every ``{{name}}`` placeholder present in the body is replaced in a
        single pass, so a substituted value that itself contains ``{{...}}``
        is never expanded again. Missing variables render as ``[name]``;
        non-string values are converted with ``str``. The result is stripped
        and whitespace-only gaps between lines are collapsed to one blank line.

        Raises:
            ValueError: If ``template`` is falsy.
        """
        if not template:
            raise ValueError("Template not provided")

        rendered = self._substitute(template.body_template, variables)

        # Clean up extra whitespace
        return re.sub(r'\n\s*\n', '\n\n', rendered.strip())

    async def render_subject(self, template: NotificationTemplate, variables: Dict[str, Any]) -> Optional[str]:
        """Render the template subject, or return None when there is none.

        Uses the same substitution rules as ``render_template`` but applies
        no whitespace clean-up.
        """
        if not template or not template.subject_template:
            return None

        return self._substitute(template.subject_template, variables)

    def _substitute(self, text: str, variables: Optional[Dict[str, Any]]) -> str:
        """Replace every ``{{name}}`` placeholder in ``text`` in one pass."""
        values = variables or {}

        def _fill(match):
            name = match.group(1)
            # Default to a visible [name] marker if the variable is missing.
            value = values.get(name, f"[{name}]")
            return value if isinstance(value, str) else str(value)

        # A callable replacement keeps backslashes in values literal.
        return self._PLACEHOLDER_RE.sub(_fill, text)

    def _collect_variables(self, template: NotificationTemplate) -> List[str]:
        """Return the unique placeholder names of body then subject, in order."""
        names = self._extract_variables(template.body_template)
        if template.subject_template:
            names.extend(self._extract_variables(template.subject_template))
        return list(dict.fromkeys(names))

    def _validate_template(self, template: NotificationTemplate) -> bool:
        """Check that a template has a name, a body and balanced brace pairs.

        Balance is judged by comparing the number of ``{{`` and ``}}``
        occurrences in the body and, when present, in the subject.
        """
        if not template.name or not template.body_template:
            return False

        texts = [template.body_template]
        if template.subject_template:
            texts.append(template.subject_template)

        return all(text.count("{{") == text.count("}}") for text in texts)

    def _extract_variables(self, template_text: str) -> List[str]:
        """Return the unique placeholder names in ``template_text``.

        Names are returned in order of first appearance so results are
        deterministic.
        """
        if not template_text:
            return []

        matches = self._PLACEHOLDER_RE.findall(template_text)
        return list(dict.fromkeys(matches))

    async def get_templates_by_channel(self, channel: NotificationChannel) -> List[NotificationTemplate]:
        """Return the stored templates whose channel equals ``channel``."""
        return [t for t in self.templates.values() if t.channel == channel]

    async def get_templates_by_category(self, category: NotificationCategory) -> List[NotificationTemplate]:
        """Return the stored templates whose category equals ``category``."""
        return [t for t in self.templates.values() if t.category == category]

    async def clone_template(self, template_id: str, new_name: str) -> str:
        """Store a copy of an existing template under a new UUID and name.

        The clone is marked active with ``created_at`` set to now; its
        ``variables`` and ``metadata`` are shallow copies of the original's.

        Returns:
            The id of the new template.

        Raises:
            ValueError: If ``template_id`` is unknown.
        """
        original = self.templates.get(template_id)
        if not original:
            raise ValueError(f"Template {template_id} not found")

        new_template = NotificationTemplate(
            id=str(uuid.uuid4()),
            name=new_name,
            channel=original.channel,
            category=original.category,
            subject_template=original.subject_template,
            body_template=original.body_template,
            variables=original.variables.copy(),
            metadata=original.metadata.copy(),
            is_active=True,
            created_at=datetime.now()
        )

        self.templates[new_template.id] = new_template

        logger.info(f"Cloned template {template_id} to {new_template.id}")
        return new_template.id
|
||||
268
services/notifications/backend/test_notifications.py
Normal file
268
services/notifications/backend/test_notifications.py
Normal file
@ -0,0 +1,268 @@
|
||||
"""
|
||||
Test script for Notification Service
|
||||
"""
|
||||
import asyncio
|
||||
import httpx
|
||||
import websockets
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
BASE_URL = "http://localhost:8013"
|
||||
WS_URL = "ws://localhost:8013/ws/notifications"
|
||||
|
||||
async def test_notification_api():
    """Walk through the notification service's REST endpoints in order.

    Each numbered step issues one or more requests against ``BASE_URL`` and
    prints the decoded JSON response; nothing is asserted.
    """
    send_url = f"{BASE_URL}/api/notifications/send"
    prefs_url = f"{BASE_URL}/api/preferences/test_user_123"

    async with httpx.AsyncClient() as http:
        print("\n🔔 Testing Notification Service API...")

        # 1. Health check
        print("\n1. Testing health check...")
        health = await http.get(f"{BASE_URL}/health")
        print(f"Health check: {health.json()}")

        # 2. Single notification
        print("\n2. Testing single notification...")
        single_payload = {
            "user_id": "test_user_123",
            "title": "Welcome to Our App!",
            "message": "Thank you for joining our platform. We're excited to have you!",
            "channels": ["in_app", "email"],
            "priority": "high",
            "category": "system",
            "data": {
                "action_url": "https://example.com/welcome",
                "icon": "welcome"
            }
        }
        single_reply = await http.post(send_url, json=single_payload)
        sent = single_reply.json()
        print(f"Notification sent: {sent}")
        # Kept for the mark-as-read step below.
        notification_id = sent.get("notification_id")

        # 3. Bulk notifications
        print("\n3. Testing bulk notifications...")
        bulk_payload = {
            "user_ids": ["user1", "user2", "user3"],
            "title": "System Maintenance Notice",
            "message": "We will be performing system maintenance tonight from 2-4 AM.",
            "channels": ["in_app", "push"],
            "priority": "normal",
            "category": "update"
        }
        bulk_reply = await http.post(
            f"{BASE_URL}/api/notifications/send-bulk",
            json=bulk_payload
        )
        print(f"Bulk notifications: {bulk_reply.json()}")

        # 4. Scheduled notification, five minutes from now
        print("\n4. Testing scheduled notification...")
        fire_at = datetime.now() + timedelta(minutes=5)
        scheduled_payload = {
            "user_id": "test_user_123",
            "title": "Reminder: Meeting in 5 minutes",
            "message": "Your scheduled meeting is about to start.",
            "channels": ["in_app", "push"],
            "priority": "urgent",
            "category": "system",
            "schedule_at": fire_at.isoformat()
        }
        scheduled_reply = await http.post(send_url, json=scheduled_payload)
        print(f"Scheduled notification: {scheduled_reply.json()}")

        # 5. Fetch the user's notifications
        print("\n5. Testing get user notifications...")
        inbox_reply = await http.get(
            f"{BASE_URL}/api/notifications/user/test_user_123"
        )
        inbox = inbox_reply.json()
        print(f"User notifications: Found {inbox['count']} notifications")

        # 6. Mark as read (only when step 2 returned an id)
        if notification_id:
            print("\n6. Testing mark as read...")
            read_reply = await http.patch(
                f"{BASE_URL}/api/notifications/{notification_id}/read"
            )
            print(f"Mark as read: {read_reply.json()}")

        # 7. Templates
        print("\n7. Testing templates...")
        templates_reply = await http.get(f"{BASE_URL}/api/templates")
        template_listing = templates_reply.json()
        print(f"Available templates: {len(template_listing['templates'])} templates")

        # 8. Preferences: read, update, unsubscribe
        print("\n8. Testing user preferences...")

        current_prefs = await http.get(prefs_url)
        print(f"Current preferences: {current_prefs.json()}")

        updated_prefs = {
            "user_id": "test_user_123",
            "channels": {
                "email": True,
                "sms": False,
                "push": True,
                "in_app": True
            },
            "categories": {
                "system": True,
                "marketing": False,
                "transaction": True,
                "social": True,
                "security": True,
                "update": True
            },
            "email_frequency": "daily",
            "timezone": "America/New_York",
            "language": "en"
        }
        update_reply = await http.put(prefs_url, json=updated_prefs)
        print(f"Update preferences: {update_reply.json()}")

        unsubscribe_reply = await http.post(
            f"{BASE_URL}/api/preferences/test_user_123/unsubscribe/marketing"
        )
        print(f"Unsubscribe from marketing: {unsubscribe_reply.json()}")

        # 9. Notification rendered from a template
        print("\n9. Testing notification with template...")
        templated_payload = {
            "user_id": "test_user_123",
            "title": "Password Reset Request",
            "message": "",  # Will be filled by template
            "channels": ["email"],
            "priority": "high",
            "category": "security",
            "template_id": "password_reset",
            "data": {
                "user_name": "John Doe",
                "app_name": "Our App",
                "reset_link": "https://example.com/reset/abc123",
                "expiry_hours": 24
            }
        }
        templated_reply = await http.post(send_url, json=templated_payload)
        print(f"Template notification: {templated_reply.json()}")

        # 10. Queue status
        print("\n10. Testing queue status...")
        queue_reply = await http.get(f"{BASE_URL}/api/queue/status")
        print(f"Queue status: {queue_reply.json()}")

        # 11. Analytics
        print("\n11. Testing analytics...")
        analytics_reply = await http.get(f"{BASE_URL}/api/analytics")
        overview = analytics_reply.json()
        print(f"Analytics overview: {overview}")

        # 12. Notification history
        print("\n12. Testing notification history...")
        history_reply = await http.get(
            f"{BASE_URL}/api/history",
            params={"user_id": "test_user_123", "limit": 10}
        )
        entries = history_reply.json()
        print(f"Notification history: {entries['count']} entries")

        # 13. Device registration
        print("\n13. Testing device registration...")
        device_reply = await http.post(
            f"{BASE_URL}/api/devices/register",
            params={
                "user_id": "test_user_123",
                "device_token": "dummy_token_12345",
                "device_type": "ios"
            }
        )
        print(f"Device registration: {device_reply.json()}")
|
||||
|
||||
async def test_websocket():
    """Exercise the real-time WebSocket channel for ``test_user_123``.

    Connects, prints the first message received, sends a ``ping`` and prints
    the reply, then posts an in-app notification over HTTP and waits up to
    five seconds for it to arrive on the socket. Any exception is printed
    rather than raised.
    """
    print("\n\n🌐 Testing WebSocket Connection...")

    try:
        endpoint = f"{WS_URL}/test_user_123"
        async with websockets.connect(endpoint) as ws:
            print(f"Connected to WebSocket at {endpoint}")

            # First frame is expected to be the welcome message.
            greeting = json.loads(await ws.recv())
            print(f"Welcome message: {greeting}")

            # Ping / pong round trip
            await ws.send("ping")
            reply = await ws.recv()
            print(f"Ping response: {reply}")

            # Trigger a notification through the REST API while connected.
            async with httpx.AsyncClient() as http:
                payload = {
                    "user_id": "test_user_123",
                    "title": "Real-time Test",
                    "message": "This should appear in WebSocket!",
                    "channels": ["in_app"],
                    "priority": "normal"
                }
                posted = await http.post(
                    f"{BASE_URL}/api/notifications/send",
                    json=payload
                )
                print(f"Sent notification: {posted.json()}")

            print("Waiting for real-time notification...")
            try:
                pushed = await asyncio.wait_for(ws.recv(), timeout=5.0)
                print(f"Received real-time notification: {json.loads(pushed)}")
            except asyncio.TimeoutError:
                print("No real-time notification received (timeout)")

            print("WebSocket test completed")

    except Exception as e:
        print(f"WebSocket error: {e}")
|
||||
|
||||
async def main():
    """Run the REST checks, then the WebSocket check, between banner lines."""
    banner = "=" * 60

    print(banner)
    print("NOTIFICATION SERVICE TEST SUITE")
    print(banner)

    # REST endpoints first, then the real-time channel.
    await test_notification_api()
    await test_websocket()

    print("\n" + banner)
    print("✅ All tests completed!")
    print(banner)


if __name__ == "__main__":
    asyncio.run(main())
|
||||
194
services/notifications/backend/websocket_server.py
Normal file
194
services/notifications/backend/websocket_server.py
Normal file
@ -0,0 +1,194 @@
|
||||
"""
|
||||
WebSocket Server for real-time notifications
|
||||
"""
|
||||
import logging
|
||||
import json
|
||||
from typing import Dict, List
|
||||
from fastapi import WebSocket
|
||||
from datetime import datetime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class WebSocketNotificationServer:
    """Manages WebSocket connections for real-time notifications.

    ``active_connections`` maps a user id to that user's open sockets and
    ``connection_metadata`` maps each socket to its user id, connect time and
    last-activity time.
    """

    def __init__(self):
        # Store connections by user_id
        self.active_connections: Dict[str, List[WebSocket]] = {}
        self.connection_metadata: Dict[WebSocket, Dict] = {}

    async def connect(self, websocket: WebSocket, user_id: str):
        """Accept a new WebSocket connection, register it and greet the client."""
        await websocket.accept()

        self.active_connections.setdefault(user_id, []).append(websocket)

        now = datetime.now()
        self.connection_metadata[websocket] = {
            "user_id": user_id,
            "connected_at": now,
            "last_activity": now
        }

        logger.info(f"WebSocket connected for user {user_id}. Total connections: {len(self.active_connections[user_id])}")

        await self.send_welcome_message(websocket, user_id)

    def disconnect(self, user_id: str, websocket: WebSocket = None):
        """Remove WebSocket connection(s) for a user.

        Args:
            user_id: Owner of the connection(s).
            websocket: When given, only this socket is removed so the user's
                other connections stay registered. When omitted, every
                connection of the user is removed.
        """
        if websocket is not None:
            self._remove_connection(user_id, websocket)
            logger.info(f"WebSocket disconnected for user {user_id}")
            return

        connections = self.active_connections.pop(user_id, None)
        if connections is None:
            return

        for connection in connections:
            self.connection_metadata.pop(connection, None)
        logger.info(f"WebSocket disconnected for user {user_id}")

    def _remove_connection(self, user_id: str, websocket: WebSocket):
        """Drop one socket from both registries, tolerating missing entries."""
        connections = self.active_connections.get(user_id)
        if connections is not None:
            if websocket in connections:
                connections.remove(websocket)
            # Clean up if no more connections
            if not connections:
                del self.active_connections[user_id]
        self.connection_metadata.pop(websocket, None)

    async def send_to_user(self, user_id: str, message: Dict):
        """Send a message to all connections for a specific user.

        Sockets whose send raises are removed from the registries.

        Returns:
            True when at least one connection received the message; False
            when the user has no connections or every send failed.
        """
        connections = self.active_connections.get(user_id)
        if not connections:
            logger.debug(f"No active connections for user {user_id}")
            return False

        delivered = 0
        # Iterate over a snapshot: the registry can change while a send is
        # awaited (another connect/disconnect for the same user).
        for websocket in list(connections):
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.error(f"Error sending to WebSocket for user {user_id}: {e}")
                self._remove_connection(user_id, websocket)
                continue

            delivered += 1
            metadata = self.connection_metadata.get(websocket)
            if metadata is not None:
                metadata["last_activity"] = datetime.now()

        return delivered > 0

    async def broadcast(self, message: Dict):
        """Broadcast a message to all connected users."""
        for user_id in list(self.active_connections.keys()):
            await self.send_to_user(user_id, message)

    async def send_notification(self, user_id: str, notification: Dict):
        """Wrap ``notification`` in a ``notification`` envelope and send it."""
        message = {
            "type": "notification",
            "timestamp": datetime.now().isoformat(),
            "data": notification
        }
        return await self.send_to_user(user_id, message)

    async def send_welcome_message(self, websocket: WebSocket, user_id: str):
        """Send a connection confirmation to a newly connected socket.

        A send failure is logged and not raised.
        """
        welcome_message = {
            "type": "connection",
            "status": "connected",
            "user_id": user_id,
            "timestamp": datetime.now().isoformat(),
            "message": "Connected to notification service"
        }

        try:
            await websocket.send_json(welcome_message)
        except Exception as e:
            logger.error(f"Error sending welcome message: {e}")

    def get_connection_count(self, user_id: str = None) -> int:
        """Return the connection count for ``user_id``, or for all users when omitted."""
        if user_id:
            return len(self.active_connections.get(user_id, []))

        return sum(len(connections) for connections in self.active_connections.values())

    def get_connected_users(self) -> List[str]:
        """Get list of connected user IDs."""
        return list(self.active_connections.keys())

    async def send_system_message(self, user_id: str, message: str, severity: str = "info"):
        """Send a ``system`` message with the given severity to a user."""
        system_message = {
            "type": "system",
            "severity": severity,
            "message": message,
            "timestamp": datetime.now().isoformat()
        }
        return await self.send_to_user(user_id, system_message)

    async def send_presence_update(self, user_id: str, status: str):
        """Send a ``presence`` message to the user's own connections."""
        presence_message = {
            "type": "presence",
            "user_id": user_id,
            "status": status,
            "timestamp": datetime.now().isoformat()
        }

        # Could send to friends/contacts if implemented
        return await self.send_to_user(user_id, presence_message)

    async def handle_ping(self, websocket: WebSocket):
        """Reply to a client ping with a ``pong`` and refresh its activity time.

        Errors are logged and not raised.
        """
        try:
            await websocket.send_json({
                "type": "pong",
                "timestamp": datetime.now().isoformat()
            })

            metadata = self.connection_metadata.get(websocket)
            if metadata is not None:
                metadata["last_activity"] = datetime.now()
        except Exception as e:
            logger.error(f"Error handling ping: {e}")

    async def cleanup_stale_connections(self, timeout_minutes: int = 30):
        """Unregister connections idle for more than ``timeout_minutes``.

        Returns:
            The number of connections removed from the registries.
        """
        now = datetime.now()

        # Collect first so the metadata dict is not mutated while iterating.
        stale_connections = []
        for websocket, metadata in self.connection_metadata.items():
            last_activity = metadata.get("last_activity")
            if not last_activity:
                continue
            idle_minutes = (now - last_activity).total_seconds() / 60
            if idle_minutes > timeout_minutes:
                stale_connections.append((websocket, metadata.get("user_id")))

        # NOTE(review): the sockets are only unregistered here, not closed.
        for websocket, user_id in stale_connections:
            self._remove_connection(user_id, websocket)
            logger.info(f"Cleaned up stale connection for user {user_id}")

        return len(stale_connections)
|
||||
21
services/oauth/backend/Dockerfile
Normal file
21
services/oauth/backend/Dockerfile
Normal file
@ -0,0 +1,21 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements first for better caching
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy application code
|
||||
COPY . .
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
# Run the application
|
||||
CMD ["python", "main.py"]
|
||||
142
services/oauth/backend/database.py
Normal file
142
services/oauth/backend/database.py
Normal file
@ -0,0 +1,142 @@
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from beanie import init_beanie
|
||||
import os
|
||||
from models import OAuthApplication, AuthorizationCode, AccessToken, OAuthScope, UserConsent
|
||||
|
||||
async def init_db():
    """Connect to MongoDB, register the OAuth documents with Beanie and seed scopes.

    The connection URL and database name are read from the ``MONGODB_URL``
    and ``OAUTH_DB_NAME`` environment variables, with defaults.
    """
    mongo_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
    db_name = os.getenv("OAUTH_DB_NAME", "oauth_db")

    mongo_client = AsyncIOMotorClient(mongo_url)
    oauth_models = [
        OAuthApplication,
        AuthorizationCode,
        AccessToken,
        OAuthScope,
        UserConsent
    ]
    await init_beanie(database=mongo_client[db_name], document_models=oauth_models)

    # Create the default scopes
    await create_default_scopes()
|
||||
|
||||
async def create_default_scopes():
    """Insert the built-in OAuth scopes that are not stored yet.

    Each scope is looked up by name first, so calling this repeatedly only
    creates the entries that are missing.
    """
    # Columns: name, display_name, description, is_default, requires_approval
    scope_table = (
        # Basic authentication scopes
        ("openid", "OpenID Connect", "기본 사용자 인증 정보", True, False),
        ("profile", "프로필 정보", "이름, 프로필 이미지, 기본 정보 접근", True, True),
        ("email", "이메일 주소", "이메일 주소 및 인증 상태 확인", False, True),
        ("picture", "프로필 사진", "프로필 사진 및 썸네일 접근", False, True),
        # User data access scopes
        ("user:read", "사용자 정보 읽기", "사용자 프로필 및 설정 읽기", False, True),
        ("user:write", "사용자 정보 수정", "사용자 프로필 및 설정 수정", False, True),
        # Application management scopes
        ("app:read", "애플리케이션 정보 읽기", "OAuth 애플리케이션 정보 조회", False, True),
        ("app:write", "애플리케이션 관리", "OAuth 애플리케이션 생성 및 수정", False, True),
        # Organization / team scopes
        ("org:read", "조직 정보 읽기", "소속 조직 및 팀 정보 조회", False, True),
        ("org:write", "조직 관리", "조직 설정 및 멤버 관리", False, True),
        # API access scopes
        ("api:read", "API 데이터 읽기", "API를 통한 데이터 조회", False, True),
        ("api:write", "API 데이터 쓰기", "API를 통한 데이터 생성/수정/삭제", False, True),
        # Special scopes
        ("offline_access", "오프라인 액세스", "리프레시 토큰 발급 (장기 액세스)", False, True),
        ("admin", "관리자 권한", "전체 시스템 관리 권한", False, True),
    )

    for name, display_name, description, is_default, requires_approval in scope_table:
        # Skip scopes that already exist
        if await OAuthScope.find_one(OAuthScope.name == name):
            continue

        new_scope = OAuthScope(
            name=name,
            display_name=display_name,
            description=description,
            is_default=is_default,
            requires_approval=requires_approval,
        )
        await new_scope.create()
|
||||
591
services/oauth/backend/main.py
Normal file
591
services/oauth/backend/main.py
Normal file
@ -0,0 +1,591 @@
|
||||
from fastapi import FastAPI, HTTPException, Depends, Form, Query, Request, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import RedirectResponse, JSONResponse
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict
|
||||
import uvicorn
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
|
||||
from database import init_db
|
||||
from models import (
|
||||
OAuthApplication, AuthorizationCode, AccessToken,
|
||||
OAuthScope, UserConsent, GrantType, ResponseType
|
||||
)
|
||||
from utils import OAuthUtils, TokenGenerator, ScopeValidator
|
||||
from pydantic import BaseModel, Field
|
||||
from beanie import PydanticObjectId
|
||||
|
||||
sys.path.append('/app')
|
||||
from shared.kafka import KafkaProducer, Event, EventType
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Pydantic models
|
||||
class ApplicationCreate(BaseModel):
    # Request body of POST /applications (the `app_data` parameter of
    # create_application). Only `name` and `redirect_uris` are required.
    name: str
    description: Optional[str] = None
    # Redirect URIs registered for the client; /authorize only accepts a
    # redirect_uri that appears in this list.
    redirect_uris: List[str]
    website_url: Optional[str] = None
    logo_url: Optional[str] = None
    privacy_policy_url: Optional[str] = None
    terms_url: Optional[str] = None
    # SSO-related settings; create_application stores them as given, replacing
    # missing values with False / {} / [].
    sso_enabled: Optional[bool] = False
    sso_provider: Optional[str] = None
    sso_config: Optional[Dict] = None
    allowed_domains: Optional[List[str]] = None
|
||||
|
||||
class ApplicationUpdate(BaseModel):
    # Partial-update payload for an OAuth application: every field is optional
    # and defaults to None.
    # NOTE(review): no endpoint in this part of the file consumes this model;
    # presumably None means "leave unchanged" -- confirm once an update
    # endpoint exists.
    name: Optional[str] = None
    description: Optional[str] = None
    redirect_uris: Optional[List[str]] = None
    website_url: Optional[str] = None
    logo_url: Optional[str] = None
    privacy_policy_url: Optional[str] = None
    terms_url: Optional[str] = None
    is_active: Optional[bool] = None
    sso_enabled: Optional[bool] = None
    sso_provider: Optional[str] = None
    sso_config: Optional[Dict] = None
    allowed_domains: Optional[List[str]] = None
|
||||
|
||||
class ApplicationResponse(BaseModel):
    # Public representation of an OAuth application, used as the response
    # model of the /applications endpoints. It has no client_secret field.
    id: str  # string form of the document id
    client_id: str
    name: str
    description: Optional[str]
    redirect_uris: List[str]
    allowed_scopes: List[str]
    grant_types: List[str]  # endpoints fill this with the GrantType enum values
    is_active: bool
    is_trusted: bool
    sso_enabled: bool
    sso_provider: Optional[str]
    allowed_domains: List[str]
    website_url: Optional[str]
    logo_url: Optional[str]
    created_at: datetime
|
||||
|
||||
class TokenRequest(BaseModel):
    # Fields of an OAuth 2.0 token request; they mirror the form parameters of
    # the /token endpoint.
    # NOTE(review): the /token endpoint reads individual Form(...) parameters
    # and does not use this model -- confirm whether it is still needed.
    grant_type: str
    code: Optional[str] = None
    redirect_uri: Optional[str] = None
    client_id: Optional[str] = None
    client_secret: Optional[str] = None
    refresh_token: Optional[str] = None
    scope: Optional[str] = None
    code_verifier: Optional[str] = None  # PKCE verifier
|
||||
|
||||
# Global Kafka producer
|
||||
kafka_producer: Optional[KafkaProducer] = None
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: set up the database and Kafka, tear Kafka down on exit.

    Kafka is treated as optional: when the producer cannot be created or
    started, a warning is logged and the service runs with
    ``kafka_producer = None``.
    """
    global kafka_producer

    # --- Startup ---
    await init_db()

    bootstrap_servers = os.getenv('KAFKA_BOOTSTRAP_SERVERS', 'kafka:9092')
    try:
        producer = KafkaProducer(bootstrap_servers=bootstrap_servers)
        await producer.start()
    except Exception as exc:
        logger.warning(f"Failed to initialize Kafka producer: {exc}")
        kafka_producer = None
    else:
        kafka_producer = producer
        logger.info("Kafka producer initialized")

    yield

    # --- Shutdown ---
    if kafka_producer:
        await kafka_producer.stop()
|
||||
|
||||
# FastAPI application instance; startup and shutdown work is delegated to the
# `lifespan` context manager.
app = FastAPI(
    title="OAuth 2.0 Service",
    description="OAuth 2.0 인증 서버 및 애플리케이션 관리",
    version="1.0.0",
    lifespan=lifespan
)

# CORS middleware
# NOTE(review): every origin, method and header is allowed while credentials
# are enabled as well. For an authorization server this is very permissive --
# confirm it is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
# Health check
|
||||
@app.get("/health")
async def health_check():
    """Liveness probe: report a static healthy status plus the current local time."""
    checked_at = datetime.now().isoformat()
    return {"status": "healthy", "service": "oauth", "timestamp": checked_at}
|
||||
|
||||
# OAuth Application Management
|
||||
class ApplicationCreateResponse(ApplicationResponse):
    # Response of POST /applications: the public application fields plus the
    # plaintext client secret. A dedicated model is required because FastAPI
    # filters the response through `response_model` and drops undeclared keys.
    client_secret: str


@app.post("/applications", response_model=ApplicationCreateResponse, status_code=201)
async def create_application(
    app_data: ApplicationCreate,
    current_user_id: str = "test_user"  # TODO: Get from JWT token
):
    """Register a new OAuth application.

    Generates a client id and a client secret, stores only the hash of the
    secret, and grants the application every scope flagged ``is_default``.

    Args:
        app_data: Application attributes supplied by the caller.
        current_user_id: Owner of the new application (placeholder until real
            authentication is wired in).

    Returns:
        The created application including ``client_secret`` in plaintext.
        This is the only time the secret is available; only its hash is kept.
    """
    client_id = OAuthUtils.generate_client_id()
    client_secret = OAuthUtils.generate_client_secret()
    hashed_secret = OAuthUtils.hash_client_secret(client_secret)

    # Fetch the default scopes
    default_scopes = await OAuthScope.find(OAuthScope.is_default == True).to_list()
    allowed_scopes = [scope.name for scope in default_scopes]

    application = OAuthApplication(
        client_id=client_id,
        client_secret=hashed_secret,
        name=app_data.name,
        description=app_data.description,
        owner_id=current_user_id,
        redirect_uris=app_data.redirect_uris,
        allowed_scopes=allowed_scopes,
        grant_types=[GrantType.AUTHORIZATION_CODE, GrantType.REFRESH_TOKEN],
        sso_enabled=app_data.sso_enabled or False,
        sso_provider=app_data.sso_provider,
        sso_config=app_data.sso_config or {},
        allowed_domains=app_data.allowed_domains or [],
        website_url=app_data.website_url,
        logo_url=app_data.logo_url,
        privacy_policy_url=app_data.privacy_policy_url,
        terms_url=app_data.terms_url
    )

    await application.create()

    # Publish the creation event
    if kafka_producer:
        event = Event(
            event_type=EventType.TASK_CREATED,
            service="oauth",
            data={
                "app_id": str(application.id),
                "client_id": client_id,
                "name": application.name,
                "owner_id": current_user_id
            }
        )
        try:
            await kafka_producer.send_event("oauth-events", event)
        except Exception as exc:
            # The application is already persisted and its secret can only be
            # shown once, so a failed notification must not turn the request
            # into an error response.
            logger.warning(f"Failed to publish application-created event: {exc}")

    # The client secret is returned only at creation time
    return ApplicationCreateResponse(
        id=str(application.id),
        client_id=application.client_id,
        name=application.name,
        description=application.description,
        redirect_uris=application.redirect_uris,
        allowed_scopes=application.allowed_scopes,
        grant_types=[gt.value for gt in application.grant_types],
        is_active=application.is_active,
        is_trusted=application.is_trusted,
        sso_enabled=application.sso_enabled,
        sso_provider=application.sso_provider,
        allowed_domains=application.allowed_domains,
        website_url=application.website_url,
        logo_url=application.logo_url,
        created_at=application.created_at,
        client_secret=client_secret
    )
|
||||
|
||||
@app.get("/applications", response_model=List[ApplicationResponse])
async def list_applications(
    owner_id: Optional[str] = None,
    is_active: Optional[bool] = None
):
    """List OAuth applications, optionally filtered by owner and active flag.

    A filter is applied only when its parameter is supplied (``owner_id``
    non-empty, ``is_active`` not None).
    """
    filters = {}
    if owner_id:
        filters["owner_id"] = owner_id
    if is_active is not None:
        filters["is_active"] = is_active

    records = await OAuthApplication.find(filters).to_list()

    results = []
    for record in records:
        results.append(
            ApplicationResponse(
                id=str(record.id),
                client_id=record.client_id,
                name=record.name,
                description=record.description,
                redirect_uris=record.redirect_uris,
                allowed_scopes=record.allowed_scopes,
                grant_types=[grant.value for grant in record.grant_types],
                is_active=record.is_active,
                is_trusted=record.is_trusted,
                sso_enabled=record.sso_enabled,
                sso_provider=record.sso_provider,
                allowed_domains=record.allowed_domains,
                website_url=record.website_url,
                logo_url=record.logo_url,
                created_at=record.created_at,
            )
        )
    return results
|
||||
|
||||
@app.get("/applications/{client_id}", response_model=ApplicationResponse)
async def get_application(client_id: str):
    """Return the OAuth application registered under ``client_id``.

    Raises:
        HTTPException: 404 when no application has that client id.
    """
    found = await OAuthApplication.find_one(OAuthApplication.client_id == client_id)
    if found:
        public_fields = {
            "id": str(found.id),
            "client_id": found.client_id,
            "name": found.name,
            "description": found.description,
            "redirect_uris": found.redirect_uris,
            "allowed_scopes": found.allowed_scopes,
            "grant_types": [grant.value for grant in found.grant_types],
            "is_active": found.is_active,
            "is_trusted": found.is_trusted,
            "sso_enabled": found.sso_enabled,
            "sso_provider": found.sso_provider,
            "allowed_domains": found.allowed_domains,
            "website_url": found.website_url,
            "logo_url": found.logo_url,
            "created_at": found.created_at,
        }
        return ApplicationResponse(**public_fields)

    raise HTTPException(status_code=404, detail="Application not found")
|
||||
|
||||
# OAuth 2.0 Authorization Endpoint
|
||||
@app.get("/authorize")
async def authorize(
    response_type: str = Query(..., description="응답 타입 (code, token)"),
    client_id: str = Query(..., description="클라이언트 ID"),
    redirect_uri: str = Query(..., description="리다이렉트 URI"),
    scope: str = Query("", description="요청 스코프"),
    state: Optional[str] = Query(None, description="상태 값"),
    code_challenge: Optional[str] = Query(None, description="PKCE challenge"),
    code_challenge_method: Optional[str] = Query("S256", description="PKCE method"),
    current_user_id: str = "test_user"  # TODO: Get from session/JWT
):
    """OAuth 2.0 authorization endpoint.

    Validates the client, the redirect URI and the requested scopes, then (for
    ``response_type=code``) stores a 10-minute authorization code and redirects
    the user agent back to the client with ``code`` and, if given, ``state``.

    Returns:
        A redirect to ``redirect_uri`` carrying the authorization code.

    Raises:
        HTTPException: 400 for an unknown or inactive client, an unregistered
            redirect URI, an unsupported PKCE method, the implicit flow, or an
            unsupported response type.
    """
    # Local import: keeps this block self-contained without touching the
    # module-level import section.
    from urllib.parse import urlencode, urlsplit, urlunsplit

    # Verify the application
    application = await OAuthApplication.find_one(OAuthApplication.client_id == client_id)
    if not application or not application.is_active:
        raise HTTPException(status_code=400, detail="Invalid client")

    # Verify the redirect URI (exact match against the registered list)
    if redirect_uri not in application.redirect_uris:
        raise HTTPException(status_code=400, detail="Invalid redirect URI")

    # Reject PKCE methods the token endpoint cannot verify; otherwise a code
    # would be issued that can never be redeemed.
    if code_challenge and code_challenge_method not in ("S256", "plain"):
        raise HTTPException(status_code=400, detail="Unsupported code challenge method")

    # Validate scopes: scopes the application is not allowed to use are dropped
    requested_scopes = ScopeValidator.parse_scope_string(scope)
    valid_scopes = ScopeValidator.validate_scopes(requested_scopes, application.allowed_scopes)

    # Check user consent (skipped for trusted applications)
    if not application.is_trusted:
        consent = await UserConsent.find_one(
            UserConsent.user_id == current_user_id,
            UserConsent.client_id == client_id
        )

        if not consent or set(valid_scopes) - set(consent.granted_scopes):
            # TODO: redirect to the consent screen. Until that exists the
            # request proceeds without recorded consent.
            pass

    if response_type == "code":
        # Authorization Code Flow
        code = OAuthUtils.generate_authorization_code()

        auth_code = AuthorizationCode(
            code=code,
            client_id=client_id,
            user_id=current_user_id,
            redirect_uri=redirect_uri,
            scopes=valid_scopes,
            code_challenge=code_challenge,
            code_challenge_method=code_challenge_method,
            expires_at=datetime.now() + timedelta(minutes=10)
        )

        await auth_code.create()

        # Build the redirect URL. The response parameters are URL-encoded
        # (`state` is caller-controlled) and appended to any query string the
        # registered redirect URI already has, instead of adding a second "?".
        response_params = [("code", code)]
        if state:
            response_params.append(("state", state))
        encoded_params = urlencode(response_params)

        target = urlsplit(redirect_uri)
        merged_query = f"{target.query}&{encoded_params}" if target.query else encoded_params
        redirect_url = urlunsplit(target._replace(query=merged_query))

        return RedirectResponse(url=redirect_url)

    elif response_type == "token":
        # Implicit Flow (not recommended)
        raise HTTPException(status_code=400, detail="Implicit flow not supported")

    else:
        raise HTTPException(status_code=400, detail="Unsupported response type")
|
||||
|
||||
# OAuth 2.0 Token Endpoint
|
||||
@app.post("/token")
async def token(
    grant_type: str = Form(...),
    code: Optional[str] = Form(None),
    redirect_uri: Optional[str] = Form(None),
    client_id: Optional[str] = Form(None),
    client_secret: Optional[str] = Form(None),
    refresh_token: Optional[str] = Form(None),
    scope: Optional[str] = Form(None),
    code_verifier: Optional[str] = Form(None)
):
    """OAuth 2.0 token endpoint.

    Authenticates the client with ``client_id`` / ``client_secret`` and then
    serves one of the ``authorization_code``, ``refresh_token`` or
    ``client_credentials`` grants. Access tokens live for one hour, refresh
    tokens for 30 days.

    Returns:
        The token response dict built by ``TokenGenerator.generate_token_response``.

    Raises:
        HTTPException: 401 when client authentication fails or the application
            is deactivated; 400 for an unsupported or disallowed grant type
            and for any invalid, used, expired or mismatching grant data.
    """
    # Client authentication
    if not client_id or not client_secret:
        raise HTTPException(
            status_code=401,
            detail="Client authentication required",
            headers={"WWW-Authenticate": "Basic"}
        )

    application = await OAuthApplication.find_one(OAuthApplication.client_id == client_id)
    if not application or not OAuthUtils.verify_client_secret(client_secret, application.client_secret):
        raise HTTPException(status_code=401, detail="Invalid client credentials")

    # A deactivated application must not obtain tokens; /authorize already
    # rejects inactive clients.
    if not application.is_active:
        raise HTTPException(status_code=401, detail="Invalid client credentials")

    if grant_type not in ("authorization_code", "refresh_token", "client_credentials"):
        raise HTTPException(status_code=400, detail="Unsupported grant type")

    # Enforce the grant types registered for this application; without this
    # check any client could use e.g. client_credentials.
    allowed_grants = {getattr(grant, "value", grant) for grant in application.grant_types}
    if grant_type not in allowed_grants:
        raise HTTPException(status_code=400, detail="Grant type not allowed for this client")

    if grant_type == "authorization_code":
        # Authorization Code Grant
        if not code or not redirect_uri:
            raise HTTPException(status_code=400, detail="Missing required parameters")

        auth_code = await AuthorizationCode.find_one(
            AuthorizationCode.code == code,
            AuthorizationCode.client_id == client_id
        )

        if not auth_code:
            raise HTTPException(status_code=400, detail="Invalid authorization code")

        if auth_code.used:
            raise HTTPException(status_code=400, detail="Authorization code already used")

        if auth_code.expires_at < datetime.now():
            raise HTTPException(status_code=400, detail="Authorization code expired")

        if auth_code.redirect_uri != redirect_uri:
            raise HTTPException(status_code=400, detail="Redirect URI mismatch")

        # PKCE verification (only when the code was issued with a challenge)
        if auth_code.code_challenge:
            if not code_verifier:
                raise HTTPException(status_code=400, detail="Code verifier required")

            if not OAuthUtils.verify_pkce_challenge(
                code_verifier,
                auth_code.code_challenge,
                auth_code.code_challenge_method
            ):
                raise HTTPException(status_code=400, detail="Invalid code verifier")

        # Mark the code as used so it cannot be redeemed twice
        auth_code.used = True
        auth_code.used_at = datetime.now()
        await auth_code.save()

        # Create the tokens. A refresh token is only issued when the client
        # is allowed to use the refresh_token grant.
        issue_refresh = "refresh_token" in allowed_grants
        access_token = OAuthUtils.generate_access_token()
        issued_refresh_token = OAuthUtils.generate_refresh_token() if issue_refresh else None

        token_doc = AccessToken(
            token=access_token,
            refresh_token=issued_refresh_token,
            client_id=client_id,
            user_id=auth_code.user_id,
            scopes=auth_code.scopes,
            expires_at=datetime.now() + timedelta(hours=1),
            refresh_expires_at=(datetime.now() + timedelta(days=30)) if issue_refresh else None
        )

        await token_doc.create()

        return TokenGenerator.generate_token_response(
            access_token=access_token,
            expires_in=3600,
            refresh_token=issued_refresh_token,
            scope=" ".join(auth_code.scopes)
        )

    if grant_type == "refresh_token":
        # Refresh Token Grant
        if not refresh_token:
            raise HTTPException(status_code=400, detail="Refresh token required")

        token_doc = await AccessToken.find_one(
            AccessToken.refresh_token == refresh_token,
            AccessToken.client_id == client_id
        )

        if not token_doc:
            raise HTTPException(status_code=400, detail="Invalid refresh token")

        if token_doc.revoked:
            raise HTTPException(status_code=400, detail="Token has been revoked")

        if token_doc.refresh_expires_at and token_doc.refresh_expires_at < datetime.now():
            raise HTTPException(status_code=400, detail="Refresh token expired")

        # Revoke the old token pair (refresh-token rotation)
        token_doc.revoked = True
        token_doc.revoked_at = datetime.now()
        await token_doc.save()

        # Create the new token pair with the same user and scopes
        new_access_token = OAuthUtils.generate_access_token()
        new_refresh_token = OAuthUtils.generate_refresh_token()

        new_token_doc = AccessToken(
            token=new_access_token,
            refresh_token=new_refresh_token,
            client_id=client_id,
            user_id=token_doc.user_id,
            scopes=token_doc.scopes,
            expires_at=datetime.now() + timedelta(hours=1),
            refresh_expires_at=datetime.now() + timedelta(days=30)
        )

        await new_token_doc.create()

        return TokenGenerator.generate_token_response(
            access_token=new_access_token,
            expires_in=3600,
            refresh_token=new_refresh_token,
            scope=" ".join(token_doc.scopes)
        )

    # Client Credentials Grant (the only remaining supported grant type):
    # no user and no refresh token are involved.
    requested_scopes = ScopeValidator.parse_scope_string(scope) if scope else []
    valid_scopes = ScopeValidator.validate_scopes(requested_scopes, application.allowed_scopes)

    access_token = OAuthUtils.generate_access_token()

    token_doc = AccessToken(
        token=access_token,
        client_id=client_id,
        scopes=valid_scopes,
        expires_at=datetime.now() + timedelta(hours=1)
    )

    await token_doc.create()

    return TokenGenerator.generate_token_response(
        access_token=access_token,
        expires_in=3600,
        scope=" ".join(valid_scopes)
    )
|
||||
|
||||
# Token Introspection Endpoint
|
||||
@app.post("/introspect")
async def introspect(
    token: str = Form(...),
    token_type_hint: Optional[str] = Form(None),
    client_id: str = Form(...),
    client_secret: str = Form(...)
):
    """Token introspection endpoint.

    After authenticating the calling client, looks the access token up and
    reports whether it is active. Unknown, revoked and expired tokens all
    yield ``{"active": False}``. For an active token ``last_used_at`` is
    updated before its metadata is returned.

    Raises:
        HTTPException: 401 when the client credentials are invalid.
    """
    # Client authentication
    caller = await OAuthApplication.find_one(OAuthApplication.client_id == client_id)
    if not caller or not OAuthUtils.verify_client_secret(client_secret, caller.client_secret):
        raise HTTPException(status_code=401, detail="Invalid client credentials")

    # Token lookup (by access-token value only)
    record = await AccessToken.find_one(AccessToken.token == token)

    if not record:
        return {"active": False}
    if record.revoked:
        return {"active": False}
    if record.expires_at < datetime.now():
        return {"active": False}

    # Record the time of use
    record.last_used_at = datetime.now()
    await record.save()

    granted_scope = " ".join(record.scopes)
    expiry_epoch = int(record.expires_at.timestamp())
    return {
        "active": True,
        "scope": granted_scope,
        "client_id": record.client_id,
        "username": record.user_id,
        "exp": expiry_epoch
    }
|
||||
|
||||
# Token Revocation Endpoint
|
||||
@app.post("/revoke")
async def revoke(
    token: str = Form(...),
    token_type_hint: Optional[str] = Form(None),
    client_id: str = Form(...),
    client_secret: str = Form(...)
):
    """Token revocation endpoint.

    Authenticates the client, then revokes the token record that belongs to it
    and matches ``token`` either as access token or as refresh token.
    ``token_type_hint`` only decides which of the two is tried first. Unknown
    or already revoked tokens are ignored and still reported as success.

    Returns:
        ``{"status": "success"}``.

    Raises:
        HTTPException: 401 when the client credentials are invalid.
    """
    # Client authentication
    application = await OAuthApplication.find_one(OAuthApplication.client_id == client_id)
    if not application or not OAuthUtils.verify_client_secret(client_secret, application.client_secret):
        raise HTTPException(status_code=401, detail="Invalid client credentials")

    # Look the token up as access token and as refresh token; previously only
    # the access-token column was searched, so refresh tokens could never be
    # revoked.
    lookup_fields = [AccessToken.token, AccessToken.refresh_token]
    if token_type_hint == "refresh_token":
        lookup_fields.reverse()

    token_doc = None
    for lookup_field in lookup_fields:
        token_doc = await AccessToken.find_one(
            lookup_field == token,
            AccessToken.client_id == client_id
        )
        if token_doc:
            break

    if token_doc and not token_doc.revoked:
        # Revoking the record invalidates both the access token and its
        # refresh token, because they are stored on the same document.
        token_doc.revoked = True
        token_doc.revoked_at = datetime.now()
        await token_doc.save()

        # Publish the revocation event
        if kafka_producer:
            event = Event(
                event_type=EventType.TASK_COMPLETED,
                service="oauth",
                data={
                    "action": "token_revoked",
                    "token_id": str(token_doc.id),
                    "client_id": client_id
                }
            )
            await kafka_producer.send_event("oauth-events", event)

    return {"status": "success"}
|
||||
|
||||
# Scopes Management
|
||||
@app.get("/scopes")
async def list_scopes():
    """Return every stored OAuth scope as a plain dict."""
    stored_scopes = await OAuthScope.find_all().to_list()

    listing = []
    for entry in stored_scopes:
        listing.append({
            "name": entry.name,
            "display_name": entry.display_name,
            "description": entry.description,
            "is_default": entry.is_default,
            "requires_approval": entry.requires_approval,
        })
    return listing
|
||||
|
||||
# Script entry point: when this module is executed directly, serve `main:app`
# with uvicorn on all interfaces, port 8000.
# NOTE(review): reload=True turns on uvicorn's auto-reload, which is a
# development convenience. The service's Dockerfile starts the app with
# `python main.py`, i.e. through this code path -- confirm this is intended
# for deployed containers.
if __name__ == "__main__":
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=True
    )
|
||||
126
services/oauth/backend/models.py
Normal file
126
services/oauth/backend/models.py
Normal file
@ -0,0 +1,126 @@
|
||||
from beanie import Document, PydanticObjectId
|
||||
from pydantic import BaseModel, Field, EmailStr
|
||||
from typing import Optional, List, Dict
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
class GrantType(str, Enum):
    # OAuth 2.0 grant type identifiers. Members are also `str` instances, so
    # they compare equal to their plain string values.
    AUTHORIZATION_CODE = "authorization_code"
    CLIENT_CREDENTIALS = "client_credentials"
    PASSWORD = "password"
    REFRESH_TOKEN = "refresh_token"
|
||||
|
||||
class ResponseType(str, Enum):
    # Values of the `response_type` authorization request parameter.
    CODE = "code"
    TOKEN = "token"
|
||||
|
||||
class TokenType(str, Enum):
    # Token types issued by the service; Bearer is the only one defined.
    BEARER = "Bearer"
|
||||
|
||||
class OAuthApplication(Document):
    """OAuth 2.0 client application."""
    # NOTE(review): `unique=True` is handed to pydantic's Field here. Verify
    # that this really produces a unique index in MongoDB; Beanie documents
    # `Indexed(str, unique=True)` for that purpose.
    client_id: str = Field(..., unique=True, description="클라이언트 ID")
    # Hashed client secret; the plaintext is not stored.
    client_secret: str = Field(..., description="클라이언트 시크릿 (해시됨)")
    name: str = Field(..., description="애플리케이션 이름")
    description: Optional[str] = Field(None, description="애플리케이션 설명")

    # Id of the user who owns the application
    owner_id: str = Field(..., description="애플리케이션 소유자 ID")

    # Allowed redirect URIs, allowed scopes and allowed grant types
    redirect_uris: List[str] = Field(default_factory=list, description="허용된 리다이렉트 URI들")
    allowed_scopes: List[str] = Field(default_factory=list, description="허용된 스코프들")
    grant_types: List[GrantType] = Field(default_factory=lambda: [GrantType.AUTHORIZATION_CODE], description="허용된 grant types")

    # Activation flag and "trusted application" flag (automatic approval)
    is_active: bool = Field(default=True, description="활성화 상태")
    is_trusted: bool = Field(default=False, description="신뢰할 수 있는 앱 (자동 승인)")

    # SSO settings: enabled flag, provider name, provider-specific
    # configuration and the list of allowed domains
    sso_enabled: bool = Field(default=False, description="SSO 활성화 여부")
    sso_provider: Optional[str] = Field(None, description="SSO 제공자 (google, github, saml 등)")
    sso_config: Optional[Dict] = Field(default_factory=dict, description="SSO 설정 (provider별 설정)")
    allowed_domains: List[str] = Field(default_factory=list, description="SSO 허용 도메인 (예: @company.com)")

    # Informational URLs: website, logo, privacy policy, terms of service
    website_url: Optional[str] = Field(None, description="애플리케이션 웹사이트")
    logo_url: Optional[str] = Field(None, description="애플리케이션 로고 URL")
    privacy_policy_url: Optional[str] = Field(None, description="개인정보 처리방침 URL")
    terms_url: Optional[str] = Field(None, description="이용약관 URL")

    # Both timestamps default to the local time at construction.
    # NOTE(review): nothing in this class refreshes updated_at on later saves.
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)

    class Settings:
        # NOTE(review): Beanie's documented setting for the collection name is
        # `name`; confirm that `collection` has the intended effect.
        collection = "oauth_applications"
|
||||
|
||||
class AuthorizationCode(Document):
    """OAuth 2.0 authorization code."""
    # NOTE(review): `unique=True` is handed to pydantic's Field; verify that a
    # unique index is actually created (Beanie documents `Indexed(...)`).
    code: str = Field(..., unique=True, description="인증 코드")
    client_id: str = Field(..., description="클라이언트 ID")
    user_id: str = Field(..., description="사용자 ID")

    # Redirect URI of the authorization request and the requested scopes
    redirect_uri: str = Field(..., description="리다이렉트 URI")
    scopes: List[str] = Field(default_factory=list, description="요청된 스코프")

    # PKCE data; code_challenge is None when the request carried no challenge
    code_challenge: Optional[str] = Field(None, description="PKCE code challenge")
    code_challenge_method: Optional[str] = Field(None, description="PKCE challenge method")

    # Expiry time plus single-use bookkeeping (used flag and time of use)
    expires_at: datetime = Field(..., description="만료 시간")
    used: bool = Field(default=False, description="사용 여부")
    used_at: Optional[datetime] = Field(None, description="사용 시간")

    created_at: datetime = Field(default_factory=datetime.now)

    class Settings:
        # NOTE(review): Beanie's documented setting for the collection name is
        # `name`; confirm that `collection` has the intended effect.
        collection = "authorization_codes"
|
||||
|
||||
class AccessToken(Document):
    """OAuth 2.0 access token (together with its optional refresh token)."""
    # NOTE(review): `unique=True` is handed to pydantic's Field; verify that a
    # unique index is actually created (Beanie documents `Indexed(...)`).
    token: str = Field(..., unique=True, description="액세스 토큰")
    refresh_token: Optional[str] = Field(None, description="리프레시 토큰")

    client_id: str = Field(..., description="클라이언트 ID")
    # User id; absent for the client credentials flow
    user_id: Optional[str] = Field(None, description="사용자 ID (client credentials flow에서는 없음)")

    token_type: TokenType = Field(default=TokenType.BEARER)
    scopes: List[str] = Field(default_factory=list, description="부여된 스코프")

    # Expiry of the access token and, when present, of the refresh token
    expires_at: datetime = Field(..., description="액세스 토큰 만료 시간")
    refresh_expires_at: Optional[datetime] = Field(None, description="리프레시 토큰 만료 시간")

    # Revocation flag and revocation time
    revoked: bool = Field(default=False, description="폐기 여부")
    revoked_at: Optional[datetime] = Field(None, description="폐기 시간")

    created_at: datetime = Field(default_factory=datetime.now)
    # Time of last use
    last_used_at: Optional[datetime] = Field(None, description="마지막 사용 시간")

    class Settings:
        # NOTE(review): Beanie's documented setting for the collection name is
        # `name`; confirm that `collection` has the intended effect.
        collection = "access_tokens"
|
||||
|
||||
class OAuthScope(Document):
    """OAuth scope definition."""
    # Scope name (e.g. read:profile).
    # NOTE(review): `unique=True` is handed to pydantic's Field; verify that a
    # unique index is actually created (Beanie documents `Indexed(...)`).
    name: str = Field(..., unique=True, description="스코프 이름 (예: read:profile)")
    display_name: str = Field(..., description="표시 이름")
    description: str = Field(..., description="스코프 설명")

    # Whether this is a default scope, and whether user approval is required
    is_default: bool = Field(default=False, description="기본 스코프 여부")
    requires_approval: bool = Field(default=True, description="사용자 승인 필요 여부")

    created_at: datetime = Field(default_factory=datetime.now)

    class Settings:
        # NOTE(review): Beanie's documented setting for the collection name is
        # `name`; confirm that `collection` has the intended effect.
        collection = "oauth_scopes"
|
||||
|
||||
class UserConsent(Document):
    """Record of a user's consent for a client application."""
    user_id: str = Field(..., description="사용자 ID")
    client_id: str = Field(..., description="클라이언트 ID")

    # Scopes the user has approved for this client
    granted_scopes: List[str] = Field(default_factory=list, description="승인된 스코프")

    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)
    # Consent expiry time; None when no expiry is set
    expires_at: Optional[datetime] = Field(None, description="동의 만료 시간")

    class Settings:
        # NOTE(review): Beanie's documented setting for the collection name is
        # `name`; confirm that `collection` has the intended effect.
        collection = "user_consents"
        # Compound ascending index on (user_id, client_id)
        indexes = [
            [("user_id", 1), ("client_id", 1)]
        ]
|
||||
11
services/oauth/backend/requirements.txt
Normal file
11
services/oauth/backend/requirements.txt
Normal file
@ -0,0 +1,11 @@
|
||||
fastapi==0.109.0
|
||||
uvicorn[standard]==0.27.0
|
||||
pydantic[email]==2.5.3
|
||||
pymongo==4.6.1
|
||||
motor==3.3.2
|
||||
beanie==1.23.6
|
||||
authlib==1.3.0
|
||||
python-jose[cryptography]==3.3.0
|
||||
passlib[bcrypt]==1.7.4
|
||||
python-multipart==0.0.6
|
||||
aiokafka==0.10.0
|
||||
131
services/oauth/backend/utils.py
Normal file
131
services/oauth/backend/utils.py
Normal file
@ -0,0 +1,131 @@
|
||||
import secrets
|
||||
import hashlib
|
||||
import base64
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List
|
||||
from passlib.context import CryptContext
|
||||
from jose import JWTError, jwt
|
||||
import os
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
class OAuthUtils:
|
||||
@staticmethod
|
||||
def generate_client_id() -> str:
|
||||
"""클라이언트 ID 생성"""
|
||||
return secrets.token_urlsafe(24)
|
||||
|
||||
@staticmethod
|
||||
def generate_client_secret() -> str:
|
||||
"""클라이언트 시크릿 생성"""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
@staticmethod
|
||||
def hash_client_secret(secret: str) -> str:
|
||||
"""클라이언트 시크릿 해싱"""
|
||||
return pwd_context.hash(secret)
|
||||
|
||||
@staticmethod
|
||||
def verify_client_secret(plain_secret: str, hashed_secret: str) -> bool:
|
||||
"""클라이언트 시크릿 검증"""
|
||||
return pwd_context.verify(plain_secret, hashed_secret)
|
||||
|
||||
@staticmethod
|
||||
def generate_authorization_code() -> str:
|
||||
"""인증 코드 생성"""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
@staticmethod
|
||||
def generate_access_token() -> str:
|
||||
"""액세스 토큰 생성"""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
@staticmethod
|
||||
def generate_refresh_token() -> str:
|
||||
"""리프레시 토큰 생성"""
|
||||
return secrets.token_urlsafe(48)
|
||||
|
||||
@staticmethod
|
||||
def verify_pkce_challenge(verifier: str, challenge: str, method: str = "S256") -> bool:
|
||||
"""PKCE challenge 검증"""
|
||||
if method == "plain":
|
||||
return verifier == challenge
|
||||
elif method == "S256":
|
||||
verifier_hash = hashlib.sha256(verifier.encode()).digest()
|
||||
verifier_challenge = base64.urlsafe_b64encode(verifier_hash).decode().rstrip("=")
|
||||
return verifier_challenge == challenge
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def create_jwt_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""JWT 토큰 생성"""
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=15)
|
||||
to_encode.update({"exp": expire})
|
||||
|
||||
secret_key = os.getenv("JWT_SECRET_KEY", "your-secret-key")
|
||||
algorithm = os.getenv("JWT_ALGORITHM", "HS256")
|
||||
|
||||
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm)
|
||||
return encoded_jwt
|
||||
|
||||
@staticmethod
def decode_jwt_token(token: str) -> Optional[dict]:
    """Decode and verify a JWT.

    Args:
        token: Encoded token string.

    Returns:
        The decoded payload, or ``None`` when decoding raises ``JWTError``.
    """
    # Key and algorithm come from the same environment variables (and the
    # same fallbacks) that are used when the token is created.
    verification_key = os.getenv("JWT_SECRET_KEY", "your-secret-key")
    expected_algorithm = os.getenv("JWT_ALGORITHM", "HS256")

    try:
        return jwt.decode(token, verification_key, algorithms=[expected_algorithm])
    except JWTError:
        return None
|
||||
|
||||
class TokenGenerator:
    """Builds OAuth 2.0 token response payloads."""

    @staticmethod
    def generate_token_response(
        access_token: str,
        token_type: str = "Bearer",
        expires_in: int = 3600,
        refresh_token: Optional[str] = None,
        scope: Optional[str] = None,
        id_token: Optional[str] = None
    ) -> dict:
        """Assemble the token response dictionary.

        Args:
            access_token: Issued access token.
            token_type: Token type reported to the client.
            expires_in: Lifetime of the access token in seconds.
            refresh_token: Included only when truthy.
            scope: Included only when truthy.
            id_token: Included only when truthy.

        Returns:
            A dict with the three mandatory fields followed by whichever
            optional fields were supplied.
        """
        payload = {
            "access_token": access_token,
            "token_type": token_type,
            "expires_in": expires_in
        }

        # Optional members are appended in this fixed order, skipping any
        # that are None or empty.
        optional_members = (
            ("refresh_token", refresh_token),
            ("scope", scope),
            ("id_token", id_token),
        )
        for field_name, field_value in optional_members:
            if field_value:
                payload[field_name] = field_value

        return payload
|
||||
|
||||
class ScopeValidator:
    """Helpers for filtering, checking and parsing OAuth scopes."""

    @staticmethod
    def validate_scopes(requested_scopes: List[str], allowed_scopes: List[str]) -> List[str]:
        """Return the requested scopes that are also allowed.

        The order (and any duplicates) of ``requested_scopes`` is preserved.
        """
        granted = []
        for candidate in requested_scopes:
            if candidate in allowed_scopes:
                granted.append(candidate)
        return granted

    @staticmethod
    def has_scope(token_scopes: List[str], required_scope: str) -> bool:
        """Report whether ``required_scope`` is among the token's scopes."""
        is_present = required_scope in token_scopes
        return is_present

    @staticmethod
    def parse_scope_string(scope_string: str) -> List[str]:
        """Split a whitespace-separated scope string into a list.

        A falsy input (``None`` or ``""``) yields an empty list.
        """
        return scope_string.strip().split() if scope_string else []
|
||||
90
services/pipeline/Makefile
Normal file
90
services/pipeline/Makefile
Normal file
@ -0,0 +1,90 @@
|
||||
# Pipeline Makefile
|
||||
|
||||
# Every target in this file is a command, not a file to build; declaring them
# all phony keeps a same-named file or directory from suppressing the recipe.
.PHONY: help build up down restart logs clean test monitor \
	scheduler-logs rss-logs search-logs summarizer-logs assembly-logs monitor-logs \
	redis-cli mongo-shell queue-status queue-clear health
|
||||
|
||||
help:
|
||||
@echo "Pipeline Management Commands:"
|
||||
@echo " make build - Build all Docker images"
|
||||
@echo " make up - Start all services"
|
||||
@echo " make down - Stop all services"
|
||||
@echo " make restart - Restart all services"
|
||||
@echo " make logs - View logs for all services"
|
||||
@echo " make clean - Clean up containers and volumes"
|
||||
@echo " make monitor - Open monitor dashboard"
|
||||
@echo " make test - Test pipeline with sample keyword"
|
||||
|
||||
build:
|
||||
docker-compose build
|
||||
|
||||
up:
|
||||
docker-compose up -d
|
||||
|
||||
down:
|
||||
docker-compose down
|
||||
|
||||
restart:
|
||||
docker-compose restart
|
||||
|
||||
logs:
|
||||
docker-compose logs -f
|
||||
|
||||
clean:
|
||||
docker-compose down -v
|
||||
docker system prune -f
|
||||
|
||||
monitor:
|
||||
@echo "Opening monitor dashboard..."
|
||||
@echo "Dashboard: http://localhost:8100"
|
||||
@echo "API Docs: http://localhost:8100/docs"
|
||||
|
||||
# Register a sample keyword through the monitor API, then trigger it at once.
test:
	@echo "Testing pipeline with sample keyword..."
	curl -X POST http://localhost:8100/api/keywords \
		-H "Content-Type: application/json" \
		-d '{"keyword": "테스트", "schedule": "30min"}'
	@# printf instead of echo: whether echo expands "\n" depends on the shell.
	@printf '\nTriggering immediate processing...\n'
	curl -X POST http://localhost:8100/api/trigger/테스트
|
||||
|
||||
# Service-specific commands
|
||||
scheduler-logs:
|
||||
docker-compose logs -f scheduler
|
||||
|
||||
rss-logs:
|
||||
docker-compose logs -f rss-collector
|
||||
|
||||
search-logs:
|
||||
docker-compose logs -f google-search
|
||||
|
||||
summarizer-logs:
|
||||
docker-compose logs -f ai-summarizer
|
||||
|
||||
assembly-logs:
|
||||
docker-compose logs -f article-assembly
|
||||
|
||||
monitor-logs:
|
||||
docker-compose logs -f monitor
|
||||
|
||||
# Database commands
|
||||
redis-cli:
|
||||
docker-compose exec redis redis-cli
|
||||
|
||||
mongo-shell:
|
||||
docker-compose exec mongodb mongosh -u admin -p password123
|
||||
|
||||
# Queue management
|
||||
queue-status:
|
||||
@echo "Checking queue status..."
|
||||
docker-compose exec redis redis-cli --raw LLEN queue:keyword
|
||||
docker-compose exec redis redis-cli --raw LLEN queue:rss
|
||||
docker-compose exec redis redis-cli --raw LLEN queue:search
|
||||
docker-compose exec redis redis-cli --raw LLEN queue:summarize
|
||||
docker-compose exec redis redis-cli --raw LLEN queue:assembly
|
||||
|
||||
queue-clear:
|
||||
@echo "Clearing all queues..."
|
||||
docker-compose exec redis redis-cli FLUSHDB
|
||||
|
||||
# Health check
|
||||
health:
|
||||
@echo "Checking service health..."
|
||||
curl -s http://localhost:8100/api/health | python3 -m json.tool
|
||||
154
services/pipeline/README.md
Normal file
154
services/pipeline/README.md
Normal file
@ -0,0 +1,154 @@
|
||||
# News Pipeline System
|
||||
|
||||
비동기 큐 기반 뉴스 생성 파이프라인 시스템
|
||||
|
||||
## 아키텍처
|
||||
|
||||
```
|
||||
Scheduler → RSS Collector → Google Search → AI Summarizer → Article Assembly → MongoDB
|
||||
↓ ↓ ↓ ↓ ↓
|
||||
Redis Queue Redis Queue Redis Queue Redis Queue Redis Queue
|
||||
```
|
||||
|
||||
## 서비스 구성
|
||||
|
||||
### 1. Scheduler
|
||||
- 30분마다 등록된 키워드 처리
|
||||
- 오전 7시, 낮 12시, 저녁 6시 우선 처리
|
||||
- MongoDB에서 키워드 로드 후 큐에 작업 생성
|
||||
|
||||
### 2. RSS Collector
|
||||
- RSS 피드 수집 (Google News RSS)
|
||||
- 7일간 중복 방지 (Redis Set)
|
||||
- 키워드 관련성 필터링
|
||||
|
||||
### 3. Google Search
|
||||
- RSS 아이템별 추가 검색 결과 수집
|
||||
- 아이템당 최대 3개 결과
|
||||
- 작업당 최대 5개 아이템 처리
|
||||
|
||||
### 4. AI Summarizer
|
||||
- Claude Haiku로 빠른 요약 생성
|
||||
- 200자 이내 한국어 요약
|
||||
- 병렬 처리 지원 (3 workers)
|
||||
|
||||
### 5. Article Assembly
|
||||
- Claude Sonnet으로 종합 기사 작성
|
||||
- 1500자 이내 전문 기사
|
||||
- MongoDB 저장 및 통계 업데이트
|
||||
|
||||
### 6. Monitor
|
||||
- 실시간 파이프라인 모니터링
|
||||
- 큐 상태, 워커 상태 확인
|
||||
- REST API 제공 (포트 8100)
|
||||
|
||||
## 시작하기
|
||||
|
||||
### 1. 환경 변수 설정
|
||||
```bash
|
||||
# .env 파일 확인
|
||||
CLAUDE_API_KEY=your_claude_api_key
|
||||
GOOGLE_API_KEY=your_google_api_key
|
||||
GOOGLE_SEARCH_ENGINE_ID=your_search_engine_id
|
||||
```
|
||||
|
||||
### 2. 서비스 시작
|
||||
```bash
|
||||
cd services/pipeline  # 저장소 루트 기준 경로
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
### 3. 모니터링
|
||||
```bash
|
||||
# 로그 확인
|
||||
docker-compose logs -f
|
||||
|
||||
# 특정 서비스 로그
|
||||
docker-compose logs -f scheduler
|
||||
|
||||
# 모니터 API
|
||||
curl http://localhost:8100/api/stats
|
||||
```
|
||||
|
||||
## API 엔드포인트
|
||||
|
||||
### Monitor API (포트 8100)
|
||||
|
||||
- `GET /api/stats` - 전체 통계
|
||||
- `GET /api/queues/{queue_name}` - 큐 상세 정보
|
||||
- `GET /api/keywords` - 키워드 목록
|
||||
- `POST /api/keywords` - 키워드 등록
|
||||
- `DELETE /api/keywords/{id}` - 키워드 삭제
|
||||
- `GET /api/articles` - 기사 목록
|
||||
- `GET /api/articles/{id}` - 기사 상세
|
||||
- `GET /api/workers` - 워커 상태
|
||||
- `POST /api/trigger/{keyword}` - 수동 처리 트리거
|
||||
- `GET /api/health` - 헬스 체크
|
||||
|
||||
## 키워드 등록 예시
|
||||
|
||||
```bash
|
||||
# 새 키워드 등록
|
||||
curl -X POST http://localhost:8100/api/keywords \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"keyword": "인공지능", "schedule": "30min"}'
|
||||
|
||||
# 수동 처리 트리거
|
||||
curl -X POST http://localhost:8100/api/trigger/인공지능
|
||||
```
|
||||
|
||||
## 데이터베이스
|
||||
|
||||
### MongoDB Collections
|
||||
- `keywords` - 등록된 키워드
|
||||
- `articles` - 생성된 기사
|
||||
- `keyword_stats` - 키워드별 통계
|
||||
|
||||
### Redis Keys
|
||||
- `queue:*` - 작업 큐
|
||||
- `processing:*` - 처리 중 작업
|
||||
- `failed:*` - 실패한 작업
|
||||
- `dedup:rss:*` - RSS 중복 방지
|
||||
- `workers:*:active` - 활성 워커
|
||||
|
||||
## 트러블슈팅
|
||||
|
||||
### 큐 초기화
|
||||
```bash
|
||||
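# 주의: FLUSHDB는 큐(queue:*)뿐 아니라 processing:*, failed:*, dedup:rss:*, workers:* 등 현재 Redis DB의 모든 키를 삭제합니다
docker-compose exec redis redis-cli FLUSHDB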
|
||||
```
|
||||
|
||||
### 워커 재시작
|
||||
```bash
|
||||
docker-compose restart rss-collector
|
||||
```
|
||||
|
||||
### 데이터베이스 접속
|
||||
```bash
|
||||
# MongoDB
|
||||
docker-compose exec mongodb mongosh -u admin -p password123
|
||||
|
||||
# Redis
|
||||
docker-compose exec redis redis-cli
|
||||
```
|
||||
|
||||
## 스케일링
|
||||
|
||||
워커 수 조정:
|
||||
```yaml
|
||||
# docker-compose.yml
|
||||
ai-summarizer:
|
||||
deploy:
|
||||
replicas: 5 # 워커 수 증가
|
||||
```
|
||||
|
||||
## 모니터링 대시보드
|
||||
|
||||
브라우저에서 http://localhost:8100 접속하여 파이프라인 상태 확인
|
||||
|
||||
## 로그 레벨 설정
|
||||
|
||||
`.env` 파일에서 조정:
|
||||
```
|
||||
LOG_LEVEL=DEBUG # INFO, WARNING, ERROR
|
||||
```
|
||||
19
services/pipeline/ai-article-generator/Dockerfile
Normal file
19
services/pipeline/ai-article-generator/Dockerfile
Normal file
@ -0,0 +1,19 @@
|
||||
# Image for the AI Article Generator worker.
# COPY sources are relative to the build context, which therefore has to
# contain both ai-article-generator/ and shared/.
FROM python:3.11-slim

WORKDIR /app

# Install dependencies (copied before the code so this layer can be cached)
COPY ./ai-article-generator/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the shared modules
COPY ./shared /app/shared

# Copy the AI Article Generator code
COPY ./ai-article-generator /app

# Environment variables
ENV PYTHONUNBUFFERED=1

# Run the worker
CMD ["python", "ai_article_generator.py"]
|
||||
300
services/pipeline/ai-article-generator/ai_article_generator.py
Normal file
300
services/pipeline/ai-article-generator/ai_article_generator.py
Normal file
@ -0,0 +1,300 @@
|
||||
"""
|
||||
AI Article Generator Service
|
||||
Claude API를 사용한 뉴스 기사 생성 서비스
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
from anthropic import AsyncAnthropic
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
# Import from shared module
|
||||
from shared.models import PipelineJob, EnrichedItem, FinalArticle, Subtopic, Entities, NewsReference
|
||||
from shared.queue_manager import QueueManager
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AIArticleGeneratorWorker:
    """Queue worker that writes a Korean news article for one enriched RSS item.

    Jobs are taken from the ``ai_article_generation`` queue, an article is
    requested from Claude, the result is inserted into the ``articles_ko``
    collection and the job is forwarded to the ``image_generation`` queue.
    """

    def __init__(self):
        self.queue_manager = QueueManager(
            redis_url=os.getenv("REDIS_URL", "redis://redis:6379")
        )
        self.claude_api_key = os.getenv("CLAUDE_API_KEY")
        self.claude_client = None
        self.mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
        self.db_name = os.getenv("DB_NAME", "ai_writer_db")  # defaults to ai_writer_db
        # Kept so stop() can close the client that start() opens.
        self.mongo_client = None
        self.db = None

    async def start(self):
        """Connect to Redis and MongoDB, then process jobs until cancelled.

        Returns early (after logging an error) when no Claude API key is
        configured.
        """
        logger.info("Starting AI Article Generator Worker")

        # Connect to Redis
        await self.queue_manager.connect()

        # Connect to MongoDB
        self.mongo_client = AsyncIOMotorClient(self.mongodb_url)
        self.db = self.mongo_client[self.db_name]

        # Initialise the Claude client
        if self.claude_api_key:
            self.claude_client = AsyncAnthropic(api_key=self.claude_api_key)
        else:
            logger.error("Claude API key not configured")
            return

        # Main processing loop
        while True:
            try:
                # Fetch the next job from the queue
                job = await self.queue_manager.dequeue('ai_article_generation', timeout=5)

                if job:
                    await self.process_job(job)

            except Exception as e:
                logger.error(f"Error in worker loop: {e}")
                await asyncio.sleep(1)

    async def process_job(self, job: PipelineJob):
        """Generate, store and forward the article for a single RSS item.

        Any exception is logged and the job is marked failed on the
        ``ai_article_generation`` queue; nothing is re-raised.
        """
        try:
            start_time = datetime.now()
            logger.info(f"Processing job {job.job_id} for AI article generation")

            # Single enriched item per job
            enriched_item_data = job.data.get('enriched_item')
            if not enriched_item_data:
                # Backward compatibility: older jobs carry a list; only the
                # first entry is processed.
                enriched_items = job.data.get('enriched_items', [])
                if enriched_items:
                    enriched_item_data = enriched_items[0]
                else:
                    logger.warning(f"No enriched item in job {job.job_id}")
                    await self.queue_manager.mark_failed(
                        'ai_article_generation',
                        job,
                        "No enriched item to process"
                    )
                    return

            enriched_item = EnrichedItem(**enriched_item_data)

            # Generate the article
            article = await self._generate_article(job, enriched_item)

            # Record how long generation took
            processing_time = (datetime.now() - start_time).total_seconds()
            article.processing_time = processing_time

            # Store in MongoDB (articles_ko collection)
            result = await self.db.articles_ko.insert_one(article.model_dump())
            mongodb_id = str(result.inserted_id)

            logger.info(f"Article {article.news_id} saved to MongoDB with _id: {mongodb_id}")

            # Hand over to the next stage (image generation)
            job.data['news_id'] = article.news_id
            job.data['mongodb_id'] = mongodb_id
            job.stages_completed.append('ai_article_generation')
            job.stage = 'image_generation'

            await self.queue_manager.enqueue('image_generation', job)
            await self.queue_manager.mark_completed('ai_article_generation', job.job_id)

        except Exception as e:
            logger.error(f"Error processing job {job.job_id}: {e}")
            await self.queue_manager.mark_failed('ai_article_generation', job, str(e))

    async def _generate_article(self, job: PipelineJob, enriched_item: EnrichedItem) -> FinalArticle:
        """Ask Claude for an article; fall back to an RSS-only article on error.

        Args:
            job: Pipeline job supplying keyword and bookkeeping ids.
            enriched_item: RSS item together with its search results.

        Returns:
            The generated ``FinalArticle``, or a minimal fallback article
            built from the RSS item when the Claude call or the parsing of
            its answer raises.
        """
        rss_item = enriched_item.rss_item
        # A missing result list is treated as empty; slicing None further
        # down would otherwise silently force the fallback article.
        search_results = enriched_item.search_results or []

        prompt = self._build_prompt(job, rss_item, search_results)

        try:
            response = await self.claude_client.messages.create(
                model="claude-sonnet-4-20250514",
                max_tokens=4000,
                temperature=0.7,
                messages=[
                    {"role": "user", "content": prompt}
                ]
            )

            article_data = self._extract_json(response.content[0].text)

            subtopics = [
                Subtopic(
                    title=subtopic_data.get('title', ''),
                    content=subtopic_data.get('content', [])
                )
                for subtopic_data in (article_data.get('subtopics') or [])
            ]

            entities_data = article_data.get('entities') or {}
            entities = Entities(
                people=entities_data.get('people', []),
                organizations=entities_data.get('organizations', []),
                groups=entities_data.get('groups', []),
                countries=entities_data.get('countries', []),
                events=entities_data.get('events', [])
            )

            # References: the RSS original plus up to 9 search results
            # (10 in total).
            references = [self._rss_reference(rss_item)]
            references.extend(
                NewsReference(
                    title=search_result.title,
                    link=search_result.link,
                    source=search_result.source,
                    published=None
                )
                for search_result in search_results[:9]
            )

            return FinalArticle(
                title=article_data.get('title', rss_item.title),
                summary=article_data.get('summary', ''),
                subtopics=subtopics,
                categories=article_data.get('categories', []),
                entities=entities,
                source_keyword=job.keyword,
                source_count=len(references),
                references=references,
                job_id=job.job_id,
                keyword_id=job.keyword_id,
                pipeline_stages=job.stages_completed.copy(),
                language='ko',
                rss_guid=rss_item.guid  # keep the RSS GUID with the article
            )

        except Exception as e:
            logger.error(f"Error generating article: {e}")
            return self._build_fallback_article(job, rss_item)

    @staticmethod
    def _build_prompt(job: PipelineJob, rss_item, search_results) -> str:
        """Render the Korean prompt from the RSS item and up to 10 search results."""
        search_text = ""
        if search_results:
            search_text = "\n관련 검색 결과:\n"
            for idx, result in enumerate(search_results[:10], 1):
                search_text += f"{idx}. {result.title}\n"
                if result.snippet:
                    search_text += f" {result.snippet}\n"

        return f"""다음 뉴스 정보를 바탕으로 상세한 기사를 작성해주세요.

키워드: {job.keyword}

뉴스 정보:
제목: {rss_item.title}
요약: {rss_item.summary or '내용 없음'}
링크: {rss_item.link}
{search_text}

다음 JSON 형식으로 작성해주세요:
{{
"title": "기사 제목 (50자 이내)",
"summary": "한 줄 요약 (100자 이내)",
"subtopics": [
{{
"title": "소제목1",
"content": ["문단1", "문단2", "문단3"]
}},
{{
"title": "소제목2",
"content": ["문단1", "문단2"]
}},
{{
"title": "소제목3",
"content": ["문단1", "문단2"]
}}
],
"categories": ["카테고리1", "카테고리2"],
"entities": {{
"people": ["인물1", "인물2"],
"organizations": ["조직1", "조직2"],
"groups": ["그룹1"],
"countries": ["국가1"],
"events": ["이벤트1"]
}}
}}

요구사항:
- 3개의 소제목로 구성
- 각 소제목별로 2-3개 문단
- 전문적이고 객관적인 톤
- 한국어로 작성
- 실제 정보를 바탕으로 구체적으로 작성"""

    @staticmethod
    def _extract_json(content_text: str) -> dict:
        """Parse the span from the first '{' to the last '}' of the model answer."""
        json_start = content_text.find('{')
        json_end = content_text.rfind('}') + 1

        if json_start == -1 or json_end <= json_start:
            raise ValueError("No valid JSON in response")
        return json.loads(content_text[json_start:json_end])

    @staticmethod
    def _rss_reference(rss_item) -> NewsReference:
        """Build the reference entry that points at the original RSS item."""
        return NewsReference(
            title=rss_item.title,
            link=rss_item.link,
            source=rss_item.source_feed,
            published=rss_item.published
        )

    def _build_fallback_article(self, job: PipelineJob, rss_item) -> FinalArticle:
        """Build a minimal article from the RSS item alone when generation fails."""
        fallback_references = [self._rss_reference(rss_item)]

        return FinalArticle(
            title=rss_item.title,
            summary=rss_item.summary[:100] if rss_item.summary else '',
            subtopics=[
                Subtopic(
                    title="주요 내용",
                    content=[rss_item.summary or rss_item.title]
                )
            ],
            categories=['자동생성'],
            entities=Entities(),
            source_keyword=job.keyword,
            source_count=1,
            references=fallback_references,
            job_id=job.job_id,
            keyword_id=job.keyword_id,
            pipeline_stages=job.stages_completed.copy(),
            language='ko',
            rss_guid=rss_item.guid  # keep the RSS GUID with the article
        )

    async def stop(self):
        """Disconnect from Redis and close the MongoDB client, if one was opened."""
        try:
            await self.queue_manager.disconnect()
        finally:
            # Close the client even when the Redis disconnect fails.
            if self.mongo_client is not None:
                self.mongo_client.close()
                self.mongo_client = None
        logger.info("AI Article Generator Worker stopped")
|
||||
|
||||
async def main():
    """Run the article generator worker and always shut it down afterwards."""
    article_worker = AIArticleGeneratorWorker()

    try:
        await article_worker.start()
    except KeyboardInterrupt:
        logger.info("Received interrupt signal")
    finally:
        # Runs on normal return, on interrupt and on any other exception.
        await article_worker.stop()
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
5
services/pipeline/ai-article-generator/requirements.txt
Normal file
5
services/pipeline/ai-article-generator/requirements.txt
Normal file
@ -0,0 +1,5 @@
|
||||
anthropic==0.50.0
|
||||
redis[hiredis]==5.0.1
|
||||
pydantic==2.5.0
|
||||
motor==3.1.1
|
||||
pymongo==4.3.3
|
||||
37
services/pipeline/check_keywords.py
Normal file
37
services/pipeline/check_keywords.py
Normal file
@ -0,0 +1,37 @@
|
||||
#!/usr/bin/env python3
|
||||
"""키워드 데이터베이스 확인 스크립트"""
|
||||
import asyncio
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from datetime import datetime
|
||||
|
||||
async def check_keywords():
    """Print every keyword document stored in ``ai_writer_db.keywords``.

    Connects to MongoDB at ``mongodb://localhost:27017``. The client is
    closed even when reading or printing a document raises (for example a
    ``KeyError`` on a document that lacks one of the printed fields).
    """
    client = AsyncIOMotorClient("mongodb://localhost:27017")
    try:
        db = client.ai_writer_db

        # Load all keyword documents
        keywords = await db.keywords.find().to_list(None)

        print(f"\n=== 등록된 키워드: {len(keywords)}개 ===\n")

        for kw in keywords:
            print(f"키워드: {kw['keyword']}")
            print(f" - ID: {kw['keyword_id']}")
            print(f" - 간격: {kw['interval_minutes']}분")
            print(f" - 활성화: {kw['is_active']}")
            print(f" - 우선순위: {kw['priority']}")
            print(f" - RSS 피드: {len(kw.get('rss_feeds', []))}개")

            if kw.get('last_run'):
                print(f" - 마지막 실행: {kw['last_run']}")

            if kw.get('next_run'):
                next_run = kw['next_run']
                # NOTE(review): assumes next_run is stored as naive local
                # time like datetime.now(); confirm against the scheduler.
                remaining = (next_run - datetime.now()).total_seconds() / 60
                print(f" - 다음 실행: {next_run} ({remaining:.1f}분 후)")

            print()
    finally:
        client.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(check_keywords())
|
||||
85
services/pipeline/config/languages.json
Normal file
85
services/pipeline/config/languages.json
Normal file
@ -0,0 +1,85 @@
|
||||
{
|
||||
"enabled_languages": [
|
||||
{
|
||||
"code": "en",
|
||||
"name": "English",
|
||||
"deepl_code": "EN",
|
||||
"collection": "articles_en",
|
||||
"enabled": true
|
||||
},
|
||||
{
|
||||
"code": "zh-CN",
|
||||
"name": "Chinese (Simplified)",
|
||||
"deepl_code": "ZH",
|
||||
"collection": "articles_zh_cn",
|
||||
"enabled": false
|
||||
},
|
||||
{
|
||||
"code": "zh-TW",
|
||||
"name": "Chinese (Traditional)",
|
||||
"deepl_code": "ZH-HANT",
|
||||
"collection": "articles_zh_tw",
|
||||
"enabled": false
|
||||
},
|
||||
{
|
||||
"code": "ja",
|
||||
"name": "Japanese",
|
||||
"deepl_code": "JA",
|
||||
"collection": "articles_ja",
|
||||
"enabled": false
|
||||
},
|
||||
{
|
||||
"code": "fr",
|
||||
"name": "French",
|
||||
"deepl_code": "FR",
|
||||
"collection": "articles_fr",
|
||||
"enabled": false
|
||||
},
|
||||
{
|
||||
"code": "de",
|
||||
"name": "German",
|
||||
"deepl_code": "DE",
|
||||
"collection": "articles_de",
|
||||
"enabled": false
|
||||
},
|
||||
{
|
||||
"code": "es",
|
||||
"name": "Spanish",
|
||||
"deepl_code": "ES",
|
||||
"collection": "articles_es",
|
||||
"enabled": false
|
||||
},
|
||||
{
|
||||
"code": "pt",
|
||||
"name": "Portuguese",
|
||||
"deepl_code": "PT",
|
||||
"collection": "articles_pt",
|
||||
"enabled": false
|
||||
},
|
||||
{
|
||||
"code": "ru",
|
||||
"name": "Russian",
|
||||
"deepl_code": "RU",
|
||||
"collection": "articles_ru",
|
||||
"enabled": false
|
||||
},
|
||||
{
|
||||
"code": "it",
|
||||
"name": "Italian",
|
||||
"deepl_code": "IT",
|
||||
"collection": "articles_it",
|
||||
"enabled": false
|
||||
}
|
||||
],
|
||||
"source_language": {
|
||||
"code": "ko",
|
||||
"name": "Korean",
|
||||
"collection": "articles_ko"
|
||||
},
|
||||
"translation_settings": {
|
||||
"batch_size": 5,
|
||||
"delay_between_languages": 2.0,
|
||||
"delay_between_articles": 0.5,
|
||||
"max_retries": 3
|
||||
}
|
||||
}
|
||||
62
services/pipeline/fix_imports.py
Normal file
62
services/pipeline/fix_imports.py
Normal file
@ -0,0 +1,62 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Fix import statements in all pipeline services"""
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
def fix_imports(filepath):
    """Rewrite legacy ``sys.path``-based shared imports in one Python file.

    Three rewrites are applied in order:

    1. The commented ``sys.path.append(...)`` block followed by
       ``from models|queue_manager import ...`` becomes a
       ``from shared.<module> import ...`` line under a new header comment.
    2. A bare ``sys.path.append(...)`` line followed by ``from models import``
       becomes ``from shared.models import``.
    3. Any remaining line starting with ``from queue_manager import`` becomes
       ``from shared.queue_manager import``.

    Args:
        filepath: Path of the file to rewrite in place.

    Returns:
        True when the file was changed and written back, False otherwise.
    """
    # The service files contain Korean comments, so read and write them as
    # UTF-8 instead of relying on the platform default encoding.
    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()

    # Pattern to match the old import style
    old_pattern = r"# 상위 디렉토리의 shared 모듈 import\nsys\.path\.append\(os\.path\.join\(os\.path\.dirname\(__file__\), '\.\.', 'shared'\)\)\nfrom ([\w, ]+) import ([\w, ]+)"

    def replace_imports(match):
        """Build the replacement text for one legacy import block."""
        module = match.group(1).strip()
        items = match.group(2)

        # Leave unknown modules untouched rather than deleting their import.
        if module not in ('models', 'queue_manager'):
            return match.group(0)

        # Keep the imported names exactly as they were written.
        return f"# Import from shared module\nfrom shared.{module} import {items}"

    # Apply the replacement
    new_content = re.sub(old_pattern, replace_imports, content)

    # Also handle simpler patterns
    new_content = new_content.replace(
        "sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'shared'))\nfrom models import",
        "from shared.models import"
    )
    new_content = new_content.replace(
        "\nfrom queue_manager import",
        "\nfrom shared.queue_manager import"
    )

    # Write back if changed
    if new_content != content:
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(new_content)
        print(f"Fixed imports in {filepath}")
        return True
    return False
|
||||
|
||||
# Service entry points to rewrite, relative to this script's directory
files_to_fix = [
    "monitor/monitor.py",
    "google-search/google_search.py",
    "article-assembly/article_assembly.py",
    "rss-collector/rss_collector.py",
    "ai-summarizer/ai_summarizer.py"
]

# Rewrite each listed file that exists; missing files are skipped silently.
for file_path in files_to_fix:
    full_path = os.path.join(os.path.dirname(__file__), file_path)
    if not os.path.exists(full_path):
        continue
    fix_imports(full_path)
|
||||
19
services/pipeline/google-search/Dockerfile
Normal file
19
services/pipeline/google-search/Dockerfile
Normal file
@ -0,0 +1,19 @@
|
||||
# Image for the Google Search enrichment worker.
# COPY sources are relative to the build context, which therefore has to
# contain both google-search/ and shared/.
FROM python:3.11-slim

WORKDIR /app

# Install dependencies (copied before the code so this layer can be cached)
COPY ./google-search/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the shared modules
COPY ./shared /app/shared

# Copy the Google Search code
COPY ./google-search /app

# Environment variables
ENV PYTHONUNBUFFERED=1

# Run the worker
CMD ["python", "google_search.py"]
|
||||
152
services/pipeline/google-search/google_search.py
Normal file
152
services/pipeline/google-search/google_search.py
Normal file
@ -0,0 +1,152 @@
|
||||
"""
|
||||
Google Search Service
|
||||
Google 검색으로 RSS 항목 강화
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from typing import List, Dict, Any
|
||||
import aiohttp
|
||||
from datetime import datetime
|
||||
|
||||
# Import from shared module
|
||||
from shared.models import PipelineJob, RSSItem, SearchResult, EnrichedItem
|
||||
from shared.queue_manager import QueueManager
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class GoogleSearchWorker:
    """Queue worker that enriches one RSS item with Google search results.

    Jobs are taken from the ``search_enrichment`` queue, the RSS item's title
    is searched through the Google Custom Search API, and the job is
    forwarded to the ``ai_article_generation`` queue.
    """

    def __init__(self):
        self.queue_manager = QueueManager(
            redis_url=os.getenv("REDIS_URL", "redis://redis:6379")
        )
        self.google_api_key = os.getenv("GOOGLE_API_KEY")
        self.search_engine_id = os.getenv("GOOGLE_SEARCH_ENGINE_ID")
        # Number of results requested per search ("num" parameter)
        self.max_results_per_item = 3

    async def start(self):
        """Connect to Redis and process jobs until cancelled."""
        logger.info("Starting Google Search Worker")

        # Connect to Redis
        await self.queue_manager.connect()

        # Main processing loop
        while True:
            try:
                # Fetch the next job from the queue
                job = await self.queue_manager.dequeue('search_enrichment', timeout=5)

                if job:
                    await self.process_job(job)

            except Exception as e:
                logger.error(f"Error in worker loop: {e}")
                await asyncio.sleep(1)

    async def process_job(self, job: PipelineJob):
        """Enrich the job's single RSS item and forward the job.

        Any exception is logged and the job is marked failed on the
        ``search_enrichment`` queue; nothing is re-raised.
        """
        try:
            logger.info(f"Processing job {job.job_id} for search enrichment")

            # Single RSS item per job
            rss_item_data = job.data.get('rss_item')
            if not rss_item_data:
                # Backward compatibility: older jobs carry a list; only the
                # first entry is processed.
                rss_items = job.data.get('rss_items', [])
                if rss_items:
                    rss_item_data = rss_items[0]
                else:
                    logger.warning(f"No RSS item in job {job.job_id}")
                    await self.queue_manager.mark_failed(
                        'search_enrichment',
                        job,
                        "No RSS item to process"
                    )
                    return

            rss_item = RSSItem(**rss_item_data)

            # Search Google using the item's title
            search_results = await self._search_google(rss_item.title)

            enriched_item = EnrichedItem(
                rss_item=rss_item,
                search_results=search_results
            )

            logger.info(f"Enriched item with {len(search_results)} search results")

            # Hand over to the next stage with a single enriched item.
            # model_dump() is the Pydantic v2 spelling of the deprecated .dict().
            job.data['enriched_item'] = enriched_item.model_dump()
            job.stages_completed.append('search_enrichment')
            job.stage = 'ai_article_generation'

            await self.queue_manager.enqueue('ai_article_generation', job)
            await self.queue_manager.mark_completed('search_enrichment', job.job_id)

        except Exception as e:
            logger.error(f"Error processing job {job.job_id}: {e}")
            await self.queue_manager.mark_failed('search_enrichment', job, str(e))

    async def _search_google(self, query: str) -> List[SearchResult]:
        """Call the Google Custom Search API for ``query``.

        Returns:
            The parsed results. The list is empty when credentials are
            missing or the API answers with a non-200 status; a request or
            parsing error is logged and whatever was collected so far is
            returned.
        """
        results = []

        if not self.google_api_key or not self.search_engine_id:
            logger.warning("Google API credentials not configured")
            return results

        try:
            url = "https://www.googleapis.com/customsearch/v1"
            params = {
                "key": self.google_api_key,
                "cx": self.search_engine_id,
                "q": query,
                "num": self.max_results_per_item,
                "hl": "ko",
                "gl": "kr"
            }
            # An explicit ClientTimeout replaces the deprecated bare number.
            request_timeout = aiohttp.ClientTimeout(total=30)

            async with aiohttp.ClientSession() as session:
                async with session.get(url, params=params, timeout=request_timeout) as response:
                    if response.status == 200:
                        data = await response.json()

                        for item in data.get('items', []):
                            result = SearchResult(
                                title=item.get('title', ''),
                                link=item.get('link', ''),
                                snippet=item.get('snippet', ''),
                                source='google'
                            )
                            results.append(result)
                    else:
                        logger.error(f"Google API error: {response.status}")

        except Exception as e:
            logger.error(f"Error searching Google for '{query}': {e}")

        return results

    async def stop(self):
        """Disconnect from Redis."""
        await self.queue_manager.disconnect()
        logger.info("Google Search Worker stopped")
|
||||
|
||||
async def main():
    """Run the Google Search worker and always shut it down afterwards."""
    search_worker = GoogleSearchWorker()

    try:
        await search_worker.start()
    except KeyboardInterrupt:
        logger.info("Received interrupt signal")
    finally:
        # Runs on normal return, on interrupt and on any other exception.
        await search_worker.stop()
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
3
services/pipeline/google-search/requirements.txt
Normal file
3
services/pipeline/google-search/requirements.txt
Normal file
@ -0,0 +1,3 @@
|
||||
aiohttp==3.9.1
|
||||
redis[hiredis]==5.0.1
|
||||
pydantic==2.5.0
|
||||
15
services/pipeline/image-generator/Dockerfile
Normal file
15
services/pipeline/image-generator/Dockerfile
Normal file
@ -0,0 +1,15 @@
|
||||
# Image for the Image Generator worker.
# COPY sources are relative to the build context, which therefore has to
# contain both image-generator/ and shared/.
FROM python:3.11-slim

WORKDIR /app

# Install dependencies (copied before the code so this layer can be cached)
COPY ./image-generator/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy shared modules
COPY ./shared /app/shared

# Copy application code
COPY ./image-generator /app

# Unbuffered output, matching the other pipeline worker images
ENV PYTHONUNBUFFERED=1

CMD ["python", "image_generator.py"]
|
||||
256
services/pipeline/image-generator/image_generator.py
Normal file
256
services/pipeline/image-generator/image_generator.py
Normal file
@ -0,0 +1,256 @@
|
||||
"""
|
||||
Image Generation Service
|
||||
Replicate API를 사용한 이미지 생성 서비스
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import base64
|
||||
from typing import List, Dict, Any
|
||||
import httpx
|
||||
from io import BytesIO
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from bson import ObjectId
|
||||
|
||||
# Import from shared module
|
||||
from shared.models import PipelineJob
|
||||
from shared.queue_manager import QueueManager
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ImageGeneratorWorker:
|
||||
def __init__(self):
    """Read the worker configuration from the environment; no I/O happens here."""
    # Redis-backed job queue
    redis_url = os.getenv("REDIS_URL", "redis://redis:6379")
    self.queue_manager = QueueManager(redis_url=redis_url)

    # Replicate API settings
    self.replicate_api_key = os.getenv("REPLICATE_API_TOKEN")
    self.replicate_api_url = "https://api.replicate.com/v1/predictions"
    # Stable Diffusion (SDXL) model identifier
    self.model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"

    # MongoDB settings; the database handle is created in start()
    self.mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
    self.db_name = os.getenv("DB_NAME", "ai_writer_db")
    self.db = None
|
||||
|
||||
async def start(self):
|
||||
"""워커 시작"""
|
||||
logger.info("Starting Image Generator Worker")
|
||||
|
||||
# Redis 연결
|
||||
await self.queue_manager.connect()
|
||||
|
||||
# MongoDB 연결
|
||||
client = AsyncIOMotorClient(self.mongodb_url)
|
||||
self.db = client[self.db_name]
|
||||
|
||||
# API 키 확인
|
||||
if not self.replicate_api_key:
|
||||
logger.warning("Replicate API key not configured - using placeholder images")
|
||||
|
||||
# 메인 처리 루프
|
||||
while True:
|
||||
try:
|
||||
# 큐에서 작업 가져오기
|
||||
job = await self.queue_manager.dequeue('image_generation', timeout=5)
|
||||
|
||||
if job:
|
||||
await self.process_job(job)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in worker loop: {e}")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def process_job(self, job: PipelineJob):
|
||||
"""이미지 생성 및 MongoDB 업데이트"""
|
||||
try:
|
||||
logger.info(f"Processing job {job.job_id} for image generation")
|
||||
|
||||
# MongoDB에서 기사 정보 가져오기
|
||||
news_id = job.data.get('news_id')
|
||||
mongodb_id = job.data.get('mongodb_id')
|
||||
|
||||
if not news_id:
|
||||
logger.error(f"No news_id in job {job.job_id}")
|
||||
await self.queue_manager.mark_failed('image_generation', job, "No news_id")
|
||||
return
|
||||
|
||||
# MongoDB에서 한국어 기사 조회 (articles_ko)
|
||||
article = await self.db.articles_ko.find_one({"news_id": news_id})
|
||||
if not article:
|
||||
logger.error(f"Article {news_id} not found in MongoDB")
|
||||
await self.queue_manager.mark_failed('image_generation', job, "Article not found")
|
||||
return
|
||||
|
||||
# 이미지 생성을 위한 프롬프트 생성 (한국어 기사 기반)
|
||||
prompt = self._create_image_prompt_from_article(article)
|
||||
|
||||
# 이미지 생성 (최대 3개)
|
||||
image_urls = []
|
||||
for i in range(min(3, 1)): # 테스트를 위해 1개만 생성
|
||||
image_url = await self._generate_image(prompt)
|
||||
image_urls.append(image_url)
|
||||
|
||||
# API 속도 제한
|
||||
if self.replicate_api_key and i < 2:
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# MongoDB 업데이트 (이미지 추가 - articles_ko)
|
||||
await self.db.articles_ko.update_one(
|
||||
{"news_id": news_id},
|
||||
{
|
||||
"$set": {
|
||||
"images": image_urls,
|
||||
"image_prompt": prompt
|
||||
},
|
||||
"$addToSet": {
|
||||
"pipeline_stages": "image_generation"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Updated article {news_id} with {len(image_urls)} images")
|
||||
|
||||
# 다음 단계로 전달 (번역)
|
||||
job.stages_completed.append('image_generation')
|
||||
job.stage = 'translation'
|
||||
|
||||
await self.queue_manager.enqueue('translation', job)
|
||||
await self.queue_manager.mark_completed('image_generation', job.job_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing job {job.job_id}: {e}")
|
||||
await self.queue_manager.mark_failed('image_generation', job, str(e))
|
||||
|
||||
def _create_image_prompt_from_article(self, article: Dict) -> str:
|
||||
"""기사로부터 이미지 프롬프트 생성"""
|
||||
# 키워드와 제목을 기반으로 프롬프트 생성
|
||||
keyword = article.get('keyword', '')
|
||||
title = article.get('title', '')
|
||||
categories = article.get('categories', [])
|
||||
|
||||
# 카테고리 맵핑 (한글 -> 영어)
|
||||
category_map = {
|
||||
'기술': 'technology',
|
||||
'경제': 'business',
|
||||
'정치': 'politics',
|
||||
'교육': 'education',
|
||||
'사회': 'society',
|
||||
'문화': 'culture',
|
||||
'과학': 'science'
|
||||
}
|
||||
|
||||
eng_categories = [category_map.get(cat, cat) for cat in categories]
|
||||
category_str = ', '.join(eng_categories[:2]) if eng_categories else 'news'
|
||||
|
||||
# 뉴스 관련 이미지를 위한 프롬프트
|
||||
prompt = f"News illustration for {keyword} {category_str}, professional, modern, clean design, high quality, 4k, no text"
|
||||
|
||||
return prompt
|
||||
|
||||
async def _generate_image(self, prompt: str) -> str:
|
||||
"""Replicate API를 사용한 이미지 생성"""
|
||||
try:
|
||||
if not self.replicate_api_key:
|
||||
# API 키가 없으면 플레이스홀더 이미지 URL 반환
|
||||
return "https://via.placeholder.com/800x600.png?text=News+Image"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
# 예측 생성 요청
|
||||
response = await client.post(
|
||||
self.replicate_api_url,
|
||||
headers={
|
||||
"Authorization": f"Token {self.replicate_api_key}",
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
json={
|
||||
"version": self.model_version,
|
||||
"input": {
|
||||
"prompt": prompt,
|
||||
"width": 768,
|
||||
"height": 768,
|
||||
"num_outputs": 1,
|
||||
"scheduler": "K_EULER",
|
||||
"num_inference_steps": 25,
|
||||
"guidance_scale": 7.5,
|
||||
"prompt_strength": 0.8,
|
||||
"refine": "expert_ensemble_refiner",
|
||||
"high_noise_frac": 0.8
|
||||
}
|
||||
},
|
||||
timeout=60
|
||||
)
|
||||
|
||||
if response.status_code in [200, 201]:
|
||||
result = response.json()
|
||||
prediction_id = result.get('id')
|
||||
|
||||
# 예측 결과 폴링
|
||||
image_url = await self._poll_prediction(prediction_id)
|
||||
return image_url
|
||||
else:
|
||||
logger.error(f"Replicate API error: {response.status_code}")
|
||||
return "https://via.placeholder.com/800x600.png?text=Generation+Failed"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating image: {e}")
|
||||
return "https://via.placeholder.com/800x600.png?text=Error"
|
||||
|
||||
async def _poll_prediction(self, prediction_id: str, max_attempts: int = 30) -> str:
|
||||
"""예측 결과 폴링"""
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
for attempt in range(max_attempts):
|
||||
response = await client.get(
|
||||
f"{self.replicate_api_url}/{prediction_id}",
|
||||
headers={
|
||||
"Authorization": f"Token {self.replicate_api_key}"
|
||||
},
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
status = result.get('status')
|
||||
|
||||
if status == 'succeeded':
|
||||
output = result.get('output')
|
||||
if output and isinstance(output, list) and len(output) > 0:
|
||||
return output[0]
|
||||
else:
|
||||
return "https://via.placeholder.com/800x600.png?text=No+Output"
|
||||
elif status == 'failed':
|
||||
logger.error(f"Prediction failed: {result.get('error')}")
|
||||
return "https://via.placeholder.com/800x600.png?text=Failed"
|
||||
|
||||
# 아직 처리중이면 대기
|
||||
await asyncio.sleep(2)
|
||||
else:
|
||||
logger.error(f"Error polling prediction: {response.status_code}")
|
||||
return "https://via.placeholder.com/800x600.png?text=Poll+Error"
|
||||
|
||||
# 최대 시도 횟수 초과
|
||||
return "https://via.placeholder.com/800x600.png?text=Timeout"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error polling prediction: {e}")
|
||||
return "https://via.placeholder.com/800x600.png?text=Poll+Exception"
|
||||
|
||||
async def stop(self):
|
||||
"""워커 중지"""
|
||||
await self.queue_manager.disconnect()
|
||||
logger.info("Image Generator Worker stopped")
|
||||
|
||||
async def main():
    """Run the image generator worker and always release its connections."""
    image_worker = ImageGeneratorWorker()

    try:
        await image_worker.start()
    except KeyboardInterrupt:
        logger.info("Received interrupt signal")
    finally:
        # Runs on interrupt as well as on any other exit from start()
        await image_worker.stop()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
5
services/pipeline/image-generator/requirements.txt
Normal file
5
services/pipeline/image-generator/requirements.txt
Normal file
@ -0,0 +1,5 @@
|
||||
httpx==0.25.0
|
||||
redis[hiredis]==5.0.1
|
||||
pydantic==2.5.0
|
||||
motor==3.1.1
|
||||
pymongo==4.3.3
|
||||
22
services/pipeline/monitor/Dockerfile
Normal file
22
services/pipeline/monitor/Dockerfile
Normal file
@ -0,0 +1,22 @@
|
||||
FROM python:3.11-slim

WORKDIR /app

# Python dependencies (copied before the code so this layer can be cached)
COPY ./monitor/requirements.txt ./requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

# Shared pipeline modules
COPY ./shared /app/shared

# Monitor service code
COPY ./monitor /app

# Disable Python output buffering
ENV PYTHONUNBUFFERED=1

# Port the API listens on
EXPOSE 8000

# Start the API with uvicorn (auto-reload enabled)
CMD ["uvicorn", "monitor:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
|
||||
349
services/pipeline/monitor/monitor.py
Normal file
349
services/pipeline/monitor/monitor.py
Normal file
@ -0,0 +1,349 @@
|
||||
"""
|
||||
Pipeline Monitor Service
|
||||
파이프라인 상태 모니터링 및 대시보드 API
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
import redis.asyncio as redis
|
||||
|
||||
# Import from shared module
|
||||
from shared.models import KeywordSubscription, PipelineJob, FinalArticle
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(version="1.0.0", title="Pipeline Monitor")

# CORS: every origin, method and header is accepted.
# NOTE(review): FastAPI's CORS documentation states that wildcard origins must
# not be combined with allow_credentials=True - confirm whether credentials
# are really needed or list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Connections shared by the request handlers; assigned in startup_event()
redis_client = None
mongodb_client = None
db = None
|
||||
|
||||
@app.on_event("startup")
async def startup_event():
    """Create the module-level Redis and MongoDB connections at server start."""
    global redis_client, mongodb_client, db

    # Redis (responses are decoded to str)
    redis_client = await redis.from_url(
        os.getenv("REDIS_URL", "redis://redis:6379"),
        decode_responses=True,
    )

    # MongoDB
    mongodb_client = AsyncIOMotorClient(
        os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
    )
    database_name = os.getenv("DB_NAME", "ai_writer_db")
    db = mongodb_client[database_name]

    logger.info("Pipeline Monitor started successfully")
|
||||
|
||||
@app.on_event("shutdown")
async def shutdown_event():
    """Close the connections opened at startup, skipping any that were never set."""
    # Redis: closing is awaited
    if redis_client:
        await redis_client.close()

    # MongoDB: close() is called without awaiting
    if mongodb_client:
        mongodb_client.close()
|
||||
|
||||
@app.get("/")
async def root():
    """Liveness endpoint: returns a fixed status message."""
    payload = {"status": "Pipeline Monitor is running"}
    return payload
|
||||
|
||||
@app.get("/api/stats")
async def get_stats():
    """Return overall pipeline statistics.

    The response holds the length of each monitored Redis list, the number of
    articles created since local midnight, the number of active keywords, the
    total article count and the time the snapshot was taken.

    Raises:
        HTTPException: 500 with the error text when Redis or MongoDB fails.
    """
    try:
        monitored_queues = (
            "queue:keyword",
            "queue:rss",
            "queue:search",
            "queue:summarize",
            "queue:assembly",
        )

        # Pending jobs per queue (queried one after another)
        queue_stats = {
            name: await redis_client.llen(name) for name in monitored_queues
        }

        # Articles created today
        midnight = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
        created_since_midnight = {"created_at": {"$gte": midnight}}
        articles_today = await db.articles.count_documents(created_since_midnight)

        # Active keywords
        active_keywords = await db.keywords.count_documents({"is_active": True})

        # All articles
        total_articles = await db.articles.count_documents({})

        return {
            "queues": queue_stats,
            "articles_today": articles_today,
            "active_keywords": active_keywords,
            "total_articles": total_articles,
            "timestamp": datetime.now().isoformat(),
        }

    except Exception as e:
        logger.error(f"Error getting stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/api/queues/{queue_name}")
async def get_queue_details(queue_name: str):
    """Return details for one queue.

    Reads the Redis keys ``queue:<name>`` (list), ``processing:<name>`` (set)
    and ``failed:<name>`` (list) and reports their sizes together with the
    first ten entries of the queue.

    Raises:
        HTTPException: 500 with the error text when Redis fails.
    """
    try:
        pending_key = f"queue:{queue_name}"

        # Queue length and a preview of its first ten entries
        pending_total = await redis_client.llen(pending_key)
        head = await redis_client.lrange(pending_key, 0, 9)

        # Jobs currently being processed
        in_flight = await redis_client.smembers(f"processing:{queue_name}")

        # Failed jobs
        failed_total = await redis_client.llen(f"failed:{queue_name}")

        details = {
            "queue": queue_name,
            "length": pending_total,
            "processing_count": len(in_flight),
            "failed_count": failed_total,
            "preview": head[:10],
            "timestamp": datetime.now().isoformat(),
        }
        return details

    except Exception as e:
        logger.error(f"Error getting queue details: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/api/keywords")
async def get_keywords():
    """List the active keywords with their latest article time and article count.

    Raises:
        HTTPException: 500 with the error text when a query fails.
    """
    try:
        listing = []

        async for keyword in db.keywords.find({"is_active": True}):
            keyword_id = str(keyword["_id"])

            # Most recent article written for this keyword
            latest_article = await db.articles.find_one(
                {"keyword_id": keyword_id},
                sort=[("created_at", -1)],
            )

            listing.append({
                "id": keyword_id,
                "keyword": keyword["keyword"],
                "schedule": keyword.get("schedule", "30분마다"),
                "created_at": keyword.get("created_at"),
                "last_article": latest_article["created_at"] if latest_article else None,
                "article_count": await db.articles.count_documents(
                    {"keyword_id": keyword_id}
                ),
            })

        return listing

    except Exception as e:
        logger.error(f"Error getting keywords: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/api/keywords")
async def add_keyword(keyword: str, schedule: str = "30min"):
    """Register a new active keyword.

    Args:
        keyword: Keyword text to store.
        schedule: Schedule label stored with the keyword.

    Returns:
        The id of the inserted document, the keyword and a confirmation message.

    Raises:
        HTTPException: 500 with the error text when the insert fails.
    """
    try:
        document = dict(
            keyword=keyword,
            schedule=schedule,
            is_active=True,
            created_at=datetime.now(),
            updated_at=datetime.now(),
        )

        inserted = await db.keywords.insert_one(document)
        new_id = str(inserted.inserted_id)

        return {
            "id": new_id,
            "keyword": keyword,
            "message": "Keyword registered successfully",
        }

    except Exception as e:
        logger.error(f"Error adding keyword: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.delete("/api/keywords/{keyword_id}")
async def delete_keyword(keyword_id: str):
    """Deactivate a keyword (soft delete: ``is_active`` is set to False).

    Args:
        keyword_id: Id as returned by the keyword endpoints, i.e. the string
            form of the document's ``_id``.

    Raises:
        HTTPException: 404 when no keyword has this id, 500 when the update fails.
    """
    # Local import: bson is part of pymongo, which motor is built on.
    from bson import ObjectId
    from bson.errors import InvalidId

    # Keywords inserted without an explicit _id get an ObjectId, so the path
    # string has to be converted before it can match. The raw string is kept
    # as a candidate too, for documents stored with a string _id.
    id_candidates = [keyword_id]
    try:
        id_candidates.append(ObjectId(keyword_id))
    except (InvalidId, TypeError):
        pass  # not an ObjectId hex string; only the raw string can match

    try:
        result = await db.keywords.update_one(
            {"_id": {"$in": id_candidates}},
            {"$set": {"is_active": False, "updated_at": datetime.now()}}
        )
    except Exception as e:
        logger.error(f"Error deleting keyword: {e}")
        raise HTTPException(status_code=500, detail=str(e))

    # Raised outside the try block so the 404 is not turned into a 500.
    # matched_count is used because it reports whether the keyword exists.
    if result.matched_count == 0:
        raise HTTPException(status_code=404, detail="Keyword not found")

    return {"message": "Keyword deactivated successfully"}
|
||||
|
||||
@app.get("/api/articles")
async def get_articles(limit: int = 10, skip: int = 0):
    """Return a page of the most recently created articles.

    Args:
        limit: Maximum number of articles to return.
        skip: Number of newest articles to skip.

    Returns:
        The article summaries, the total article count and the paging values.

    Raises:
        HTTPException: 500 with the error text when a query fails.
    """
    try:
        newest_first = db.articles.find().sort("created_at", -1)
        page = newest_first.skip(skip).limit(limit)

        articles = [
            {
                "id": str(doc["_id"]),
                "title": doc["title"],
                "keyword": doc["keyword"],
                "summary": doc.get("summary", ""),
                "created_at": doc["created_at"],
                "processing_time": doc.get("processing_time", 0),
                "pipeline_stages": doc.get("pipeline_stages", []),
            }
            async for doc in page
        ]

        total = await db.articles.count_documents({})

        return {"articles": articles, "total": total, "limit": limit, "skip": skip}

    except Exception as e:
        logger.error(f"Error getting articles: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/api/articles/{article_id}")
async def get_article(article_id: str):
    """Return the full document of one article.

    Args:
        article_id: Id as returned by ``/api/articles``, i.e. the string form
            of the document's ``_id``.

    Returns:
        The article document with ``_id`` converted to a string.

    Raises:
        HTTPException: 404 when no article has this id, 500 when the query fails.
    """
    # Local import: bson is part of pymongo, which motor is built on.
    from bson import ObjectId
    from bson.errors import InvalidId

    # The listing endpoint exposes str(_id); convert it back so that documents
    # with an ObjectId _id can be found. The raw string stays a candidate for
    # documents stored with a string _id.
    id_candidates = [article_id]
    try:
        id_candidates.append(ObjectId(article_id))
    except (InvalidId, TypeError):
        pass  # not an ObjectId hex string; only the raw string can match

    try:
        article = await db.articles.find_one({"_id": {"$in": id_candidates}})
    except Exception as e:
        logger.error(f"Error getting article: {e}")
        raise HTTPException(status_code=500, detail=str(e))

    # Raised outside the try block so the 404 is not turned into a 500
    if not article:
        raise HTTPException(status_code=404, detail="Article not found")

    # An ObjectId cannot be serialised to JSON; expose it as a string
    article["_id"] = str(article["_id"])
    return article
|
||||
|
||||
@app.get("/api/workers")
async def get_workers():
    """Report, per worker type, the members of its ``workers:<type>:active`` set.

    Raises:
        HTTPException: 500 with the error text when Redis fails.
    """
    try:
        known_types = (
            "scheduler",
            "rss_collector",
            "google_search",
            "ai_summarizer",
            "article_assembly",
        )

        report = {}
        for worker_type in known_types:
            members = await redis_client.smembers(f"workers:{worker_type}:active")
            report[worker_type] = {
                "active": len(members),
                "worker_ids": list(members),
            }

        return report

    except Exception as e:
        logger.error(f"Error getting workers: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/api/trigger/{keyword}")
async def trigger_keyword_processing(keyword: str):
    """Manually queue a processing job for an active keyword.

    Args:
        keyword: Keyword text; it must exist in ``keywords`` with
            ``is_active`` set to True.

    Returns:
        A confirmation message and the id of the queued job.

    Raises:
        HTTPException: 404 when the keyword is unknown or inactive, 500 when
            MongoDB or Redis fails.
    """
    try:
        # Look up the keyword
        keyword_doc = await db.keywords.find_one({
            "keyword": keyword,
            "is_active": True
        })

        if not keyword_doc:
            raise HTTPException(status_code=404, detail="Keyword not found or inactive")

        # Create the job
        job = PipelineJob(
            keyword_id=str(keyword_doc["_id"]),
            keyword=keyword,
            stage="keyword_processing",
            created_at=datetime.now()
        )

        # Append it to the keyword queue
        await redis_client.rpush("queue:keyword", job.json())

        return {
            "message": f"Processing triggered for keyword: {keyword}",
            "job_id": job.job_id
        }

    except HTTPException:
        # Keep the intended status code (404) instead of wrapping it in a 500
        raise
    except Exception as e:
        logger.error(f"Error triggering keyword: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/api/health")
async def health_check():
    """Ping Redis and MongoDB and report the result.

    Errors are not raised: if either ping throws, the response body has
    status ``unhealthy`` and carries the error text.
    """
    try:
        redis_ok = await redis_client.ping()
        mongo_ok = await db.command("ping")
    except Exception as e:
        return {
            "status": "unhealthy",
            "error": str(e),
            "timestamp": datetime.now().isoformat(),
        }

    redis_state = "connected" if redis_ok else "disconnected"
    mongo_state = "connected" if mongo_ok else "disconnected"
    return {
        "status": "healthy",
        "redis": redis_state,
        "mongodb": mongo_state,
        "timestamp": datetime.now().isoformat(),
    }
|
||||
|
||||
if __name__ == "__main__":
    # Lets the module be run directly as a script
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")
|
||||
6
services/pipeline/monitor/requirements.txt
Normal file
6
services/pipeline/monitor/requirements.txt
Normal file
@ -0,0 +1,6 @@
|
||||
fastapi==0.104.1
|
||||
uvicorn[standard]==0.24.0
|
||||
redis[hiredis]==5.0.1
|
||||
motor==3.1.1
|
||||
pymongo==4.3.3
|
||||
pydantic==2.5.0
|
||||
19
services/pipeline/rss-collector/Dockerfile
Normal file
19
services/pipeline/rss-collector/Dockerfile
Normal file
@ -0,0 +1,19 @@
|
||||
FROM python:3.11-slim

WORKDIR /app

# Install dependencies (copied before the code so this layer can be cached)
COPY ./rss-collector/requirements.txt ./requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

# Copy the shared modules
COPY ./shared /app/shared

# Copy the RSS collector code
COPY ./rss-collector /app

# Environment: disable Python output buffering
ENV PYTHONUNBUFFERED=1

# Run the worker
CMD ["python", "rss_collector.py"]
|
||||
5
services/pipeline/rss-collector/requirements.txt
Normal file
5
services/pipeline/rss-collector/requirements.txt
Normal file
@ -0,0 +1,5 @@
|
||||
feedparser==6.0.11
|
||||
aiohttp==3.9.1
|
||||
redis[hiredis]==5.0.1
|
||||
pydantic==2.5.0
|
||||
motor==3.6.0
|
||||
270
services/pipeline/rss-collector/rss_collector.py
Normal file
270
services/pipeline/rss-collector/rss_collector.py
Normal file
@ -0,0 +1,270 @@
|
||||
"""
|
||||
RSS Collector Service
|
||||
RSS 피드 수집 및 중복 제거 서비스
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
import feedparser
|
||||
import aiohttp
|
||||
import redis.asyncio as redis
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from typing import List, Dict, Any
|
||||
|
||||
# Import from shared module
|
||||
from shared.models import PipelineJob, RSSItem, EnrichedItem
|
||||
from shared.queue_manager import QueueManager
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class RSSCollectorWorker:
    """Queue worker that collects RSS entries for a keyword and fans them out.

    Each job taken from the ``rss_collection`` queue is expanded into one
    ``search_enrichment`` job per feed entry that has not been processed
    before (checked against the ``articles_ko`` collection).
    """

    # Feeds used when a job brings none. ``{keyword}`` is replaced with the
    # URL-encoded keyword by _prepare_feeds().
    DEFAULT_FEEDS = (
        "https://news.google.com/rss/search?q={keyword}&hl=en-US&gl=US&ceid=US:en",
        "https://news.google.com/rss/search?q={keyword}&hl=ko&gl=KR&ceid=KR:ko",
        "https://feeds.bbci.co.uk/news/technology/rss.xml",
        "https://rss.nytimes.com/services/xml/rss/nyt/Technology.xml",
    )

    def __init__(self):
        """Read configuration from the environment; no connection is opened here."""
        self.queue_manager = QueueManager(
            redis_url=os.getenv("REDIS_URL", "redis://redis:6379")
        )
        self.redis_client = None
        self.redis_url = os.getenv("REDIS_URL", "redis://redis:6379")
        self.mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
        self.db_name = os.getenv("DB_NAME", "ai_writer_db")
        self.mongo_client = None  # kept so stop() can close it
        self.db = None
        self.dedup_ttl = 86400 * 7  # duplicate-prevention window: 7 days, in seconds
        self.max_items_per_feed = 100  # per-feed cap (Google News returns at most 100)

    async def start(self):
        """Connect to Redis and MongoDB, then process jobs in an endless loop."""
        logger.info("Starting RSS Collector Worker")

        # Redis connections
        await self.queue_manager.connect()
        self.redis_client = await redis.from_url(
            self.redis_url,
            encoding="utf-8",
            decode_responses=True
        )

        # MongoDB connection
        self.mongo_client = AsyncIOMotorClient(self.mongodb_url)
        self.db = self.mongo_client[self.db_name]

        # Main processing loop
        while True:
            try:
                job = await self.queue_manager.dequeue('rss_collection', timeout=5)
                if job:
                    await self.process_job(job)
            except Exception as e:
                # Keep the worker alive; back off briefly before retrying
                logger.error(f"Error in worker loop: {e}")
                await asyncio.sleep(1)

    async def process_job(self, job: PipelineJob):
        """Collect feed entries for the job's keyword and enqueue one job per entry.

        Args:
            job: Pipeline job; ``job.data['rss_feeds']`` may list feed URLs
                (``{keyword}`` placeholders allowed). Default feeds are used
                when it is empty.

        The job is marked failed when no entry could be collected at all and
        completed otherwise; errors are reported through
        ``queue_manager.mark_failed`` instead of being raised.
        """
        try:
            logger.info(f"Processing job {job.job_id} for keyword '{job.keyword}'")

            keyword = job.keyword
            rss_feeds = job.data.get('rss_feeds', [])

            if not rss_feeds:
                rss_feeds = list(self.DEFAULT_FEEDS)
                logger.info(f"Using default RSS feeds for keyword: {keyword}")

            all_items = []
            for feed_url in self._prepare_feeds(rss_feeds, keyword):
                try:
                    all_items.extend(await self._fetch_rss_feed(feed_url, keyword))
                except Exception as e:
                    logger.error(f"Error fetching feed {feed_url}: {e}")

            if not all_items:
                logger.warning(f"No RSS items collected for '{keyword}'")
                await self.queue_manager.mark_failed(
                    'rss_collection',
                    job,
                    "No RSS items collected"
                )
                return

            unique_items = await self._deduplicate_items(all_items, keyword)
            if not unique_items:
                logger.info(f"No new items found for '{keyword}' after deduplication")
                await self.queue_manager.mark_completed('rss_collection', job.job_id)
                return

            logger.info(f"Collected {len(unique_items)} unique items for '{keyword}'")

            # Spread the enqueues over time so that the downstream API calls
            # are not issued all at once.
            enqueue_delay = float(os.getenv("RSS_ENQUEUE_DELAY", "1.0"))
            total_items = len(unique_items)

            for idx, item in enumerate(unique_items):
                # One new job per feed entry
                item_job = PipelineJob(
                    keyword_id=f"{job.keyword_id}_{idx}",
                    keyword=job.keyword,
                    stage='search_enrichment',
                    data={
                        'rss_item': item.dict(),
                        'original_job_id': job.job_id,
                        'item_index': idx,
                        'total_items': total_items,
                        'item_hash': self._item_hash(keyword, item)
                    },
                    stages_completed=['rss_collection']
                )

                await self.queue_manager.enqueue('search_enrichment', item_job)
                logger.info(f"Enqueued item {idx+1}/{total_items} for keyword '{keyword}'")

                # No delay after the last entry
                if idx < total_items - 1:
                    logger.debug(f"Waiting {enqueue_delay}s before next item...")
                    await asyncio.sleep(enqueue_delay)

            # The original job is done once every entry has been handed over
            await self.queue_manager.mark_completed('rss_collection', job.job_id)
            logger.info(f"Completed RSS collection for job {job.job_id}: {total_items} items processed")

        except Exception as e:
            logger.error(f"Error processing job {job.job_id}: {e}")
            await self.queue_manager.mark_failed('rss_collection', job, str(e))

    @staticmethod
    def _item_hash(keyword: str, item: RSSItem) -> str:
        """Return the MD5 hex digest identifying an entry: GUID-based, else title+link."""
        if item.guid:
            identity = f"{keyword}:guid:{item.guid}"
        else:
            identity = f"{keyword}:title:{item.title}:link:{item.link}"
        return hashlib.md5(identity.encode()).hexdigest()

    def _prepare_feeds(self, feeds: List[str], keyword: str) -> List[str]:
        """Return the feed URLs with ``{keyword}`` replaced by the encoded keyword.

        The keyword is percent-encoded so that spaces, non-ASCII text and
        characters such as ``&`` cannot break the URL. Feeds without the
        placeholder are returned unchanged.
        """
        from urllib.parse import quote

        encoded_keyword = quote(keyword, safe='')
        return [feed.replace('{keyword}', encoded_keyword) for feed in feeds]

    async def _fetch_rss_feed(self, feed_url: str, keyword: str) -> List[RSSItem]:
        """Download and parse one feed and return the entries relevant to ``keyword``.

        Entries of Google News feeds are all kept (the feed is already a
        keyword search); for other feeds the keyword must occur in the title
        or summary, compared case-insensitively. Errors are logged and the
        entries gathered so far are returned.
        """
        items = []

        try:
            request_timeout = aiohttp.ClientTimeout(total=30)
            async with aiohttp.ClientSession(timeout=request_timeout) as session:
                async with session.get(feed_url) as response:
                    # Do not parse error pages as if they were feeds
                    response.raise_for_status()
                    content = await response.text()

            feed = feedparser.parse(content)
            logger.info(f"Found {len(feed.entries)} entries in feed {feed_url}")

            is_keyword_search_feed = "news.google.com" in feed_url
            # casefold() is a no-op for scripts without case, so it is applied
            # unconditionally instead of only for ASCII keywords
            needle = keyword.casefold()

            for entry in feed.entries[:self.max_items_per_feed]:
                title = entry.get('title', '')
                summary = entry.get('summary', '')

                relevant = (
                    is_keyword_search_feed
                    or needle in title.casefold()
                    or needle in summary.casefold()
                )
                if not relevant:
                    continue

                guid = entry.get('id', entry.get('guid', ''))
                items.append(RSSItem(
                    title=title,
                    link=entry.get('link', ''),
                    guid=guid,
                    published=entry.get('published', ''),
                    summary=summary[:500] if summary else '',
                    source_feed=feed_url
                ))
                logger.debug(f"Added item: {title[:50]}... (guid: {guid[:30] if guid else 'no-guid'})")

        except Exception as e:
            logger.error(f"Error fetching RSS feed {feed_url}: {e}")

        return items

    async def _deduplicate_items(self, items: List[RSSItem], keyword: str) -> List[RSSItem]:
        """Drop entries already seen in this batch or already stored in MongoDB.

        Entries with a GUID are compared by GUID (``articles_ko.rss_guid``),
        all others by link (``articles_ko.references.link``).
        """
        unique_items = []
        seen_guids = set()  # GUIDs seen in the current batch
        seen_links = set()  # links seen in the current batch

        for item in items:
            if item.guid:
                if item.guid in seen_guids:
                    logger.debug(f"Duplicate GUID in batch: {item.guid[:30]}")
                    continue

                # Already turned into an article?
                existing_article = await self.db.articles_ko.find_one({"rss_guid": item.guid})
                if existing_article:
                    logger.info(f"Article with GUID {item.guid[:30]} already processed, skipping")
                    continue

                seen_guids.add(item.guid)
            else:
                if item.link in seen_links:
                    logger.debug(f"Duplicate link in batch: {item.link[:50]}")
                    continue

                # Already referenced by an article?
                existing_article = await self.db.articles_ko.find_one({"references.link": item.link})
                if existing_article:
                    logger.info(f"Article with link {item.link[:50]} already processed, skipping")
                    continue

                seen_links.add(item.link)

            unique_items.append(item)
            logger.debug(f"New item added: {item.title[:50]}...")

        logger.info(f"Deduplication result: {len(unique_items)} new items out of {len(items)} total")

        return unique_items

    async def stop(self):
        """Disconnect from Redis and close the MongoDB client."""
        try:
            await self.queue_manager.disconnect()
            if self.redis_client:
                await self.redis_client.close()
        finally:
            if self.mongo_client is not None:
                self.mongo_client.close()
                self.mongo_client = None
        logger.info("RSS Collector Worker stopped")
|
||||
|
||||
async def main():
    """Run the RSS collector worker and always release its connections."""
    collector = RSSCollectorWorker()

    try:
        await collector.start()
    except KeyboardInterrupt:
        logger.info("Received interrupt signal")
    finally:
        # Runs on interrupt as well as on any other exit from start()
        await collector.stop()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
16
services/pipeline/scheduler/Dockerfile
Normal file
16
services/pipeline/scheduler/Dockerfile
Normal file
@ -0,0 +1,16 @@
|
||||
FROM python:3.11-slim

WORKDIR /app

# Install dependencies in their own layer so it is reused when only code changes
COPY ./scheduler/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy shared module
COPY ./shared /app/shared

# Copy scheduler code
COPY ./scheduler /app

# Do not buffer Python output, so scheduler logs show up in the container log
# immediately (same setting as the monitor and rss-collector images)
ENV PYTHONUNBUFFERED=1

# Run scheduler
CMD ["python", "keyword_scheduler.py"]
|
||||
336
services/pipeline/scheduler/keyword_manager.py
Normal file
336
services/pipeline/scheduler/keyword_manager.py
Normal file
@ -0,0 +1,336 @@
|
||||
"""
|
||||
Keyword Manager API
|
||||
키워드를 추가/수정/삭제하는 관리 API
|
||||
"""
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
import uvicorn
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
# Import from shared module
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from shared.models import Keyword
|
||||
|
||||
app = FastAPI(title="Keyword Manager API")

# MongoDB connection settings, overridable through the environment
mongodb_url = os.environ.get("MONGODB_URL", "mongodb://mongodb:27017")
db_name = os.environ.get("DB_NAME", "ai_writer_db")
|
||||
|
||||
@app.on_event("startup")
async def startup_event():
    """Open the MongoDB client and attach it, and its database, to the app."""
    client = AsyncIOMotorClient(mongodb_url)
    app.mongodb_client = client
    app.db = client[db_name]
|
||||
|
||||
@app.on_event("shutdown")
async def shutdown_event():
    """Close the MongoDB client that was attached to the app at startup."""
    client = app.mongodb_client
    client.close()
|
||||
|
||||
class KeywordCreate(BaseModel):
    """Request body for creating a keyword.

    Only ``keyword`` is required; every other field has a default.
    """

    # Keyword text
    keyword: str
    # Minutes between two runs
    interval_minutes: int = 60
    priority: int = 0
    # Feed URLs for this keyword (empty by default)
    rss_feeds: List[str] = []
    max_articles_per_run: int = 100
    is_active: bool = True
|
||||
|
||||
class KeywordUpdate(BaseModel):
    """Request body for updating a keyword.

    Every field is optional and defaults to ``None``.
    """

    interval_minutes: Optional[int] = None
    priority: Optional[int] = None
    rss_feeds: Optional[List[str]] = None
    max_articles_per_run: Optional[int] = None
    is_active: Optional[bool] = None
|
||||
|
||||
@app.get("/")
async def root():
    """API status endpoint: returns a fixed status message."""
    payload = {"status": "Keyword Manager API is running"}
    return payload
|
||||
|
||||
@app.get("/threads/status")
async def get_threads_status():
    """Report the scheduling state of every keyword stored in MongoDB.

    Each entry holds the keyword's settings, its last and next run time and a
    ``thread_status`` of ``active``, ``inactive`` or - when ``next_run`` is
    already in the past - ``pending_execution``. Entries are sorted by
    priority, highest first.

    Raises:
        HTTPException: 500 with the error text when anything fails.
    """
    try:
        keywords = await app.db.keywords.find().to_list(None)

        # One reference time for the whole response
        now = datetime.now()

        threads_status = []
        for kw in keywords:
            last_run = kw.get("last_run")
            next_run = kw.get("next_run")
            is_active = kw.get("is_active")

            status = {
                "keyword": kw.get("keyword"),
                "keyword_id": kw.get("keyword_id"),
                "is_active": is_active,
                "interval_minutes": kw.get("interval_minutes"),
                "priority": kw.get("priority"),
                "last_run": last_run.isoformat() if last_run else None,
                "next_run": next_run.isoformat() if next_run else None,
                "thread_status": "active" if is_active else "inactive"
            }

            # Time left until the next run
            if next_run:
                remaining = (next_run - now).total_seconds()
                if remaining > 0:
                    status["minutes_until_next_run"] = round(remaining / 60, 1)
                else:
                    status["minutes_until_next_run"] = 0
                    status["thread_status"] = "pending_execution"

            threads_status.append(status)

        # Highest priority first. The "priority" key is always present but is
        # None for documents without the field; treat that as 0 so that None
        # is never compared with an int (which raises TypeError).
        threads_status.sort(key=lambda entry: entry.get("priority") or 0, reverse=True)

        return {
            "total_threads": len(threads_status),
            "active_threads": sum(1 for t in threads_status if t.get("is_active")),
            "threads": threads_status
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/keywords")
async def list_keywords():
    """Return every stored keyword together with its scheduling settings.

    Raises:
        HTTPException: 500 if reading or formatting the keywords fails.
    """
    try:
        documents = await app.db.keywords.find().to_list(None)

        def summarize(doc):
            # Shape one keyword document for the response.
            last_run = doc.get("last_run")
            next_run = doc.get("next_run")
            return {
                "keyword_id": doc.get("keyword_id"),
                "keyword": doc.get("keyword"),
                "interval_minutes": doc.get("interval_minutes"),
                "priority": doc.get("priority"),
                "is_active": doc.get("is_active"),
                "last_run": last_run.isoformat() if last_run else None,
                "next_run": next_run.isoformat() if next_run else None,
                "rss_feeds": doc.get("rss_feeds", []),
                "max_articles_per_run": doc.get("max_articles_per_run", 100),
            }

        summaries = [summarize(doc) for doc in documents]
        return {"total": len(summaries), "keywords": summaries}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/keywords/{keyword_text}")
async def get_keyword(keyword_text: str):
    """Return the stored settings of the keyword whose text is ``keyword_text``.

    Raises:
        HTTPException: 404 if no such keyword exists, 500 on any other error.
    """
    try:
        doc = await app.db.keywords.find_one({"keyword": keyword_text})
        if not doc:
            raise HTTPException(status_code=404, detail=f"Keyword '{keyword_text}' not found")

        last_run = doc.get("last_run")
        next_run = doc.get("next_run")
        return {
            "keyword_id": doc.get("keyword_id"),
            "keyword": doc.get("keyword"),
            "interval_minutes": doc.get("interval_minutes"),
            "priority": doc.get("priority"),
            "is_active": doc.get("is_active"),
            "last_run": last_run.isoformat() if last_run else None,
            "next_run": next_run.isoformat() if next_run else None,
            "rss_feeds": doc.get("rss_feeds", []),
            "max_articles_per_run": doc.get("max_articles_per_run", 100),
        }
    except HTTPException:
        # Let the 404 above through unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/keywords")
async def create_keyword(keyword_data: KeywordCreate):
    """Store a new keyword whose first run is scheduled one minute from now.

    Raises:
        HTTPException: 400 if a keyword with the same text already exists,
            500 on any other error.
    """
    try:
        # Reject duplicates up front.
        duplicate = await app.db.keywords.find_one({"keyword": keyword_data.keyword})
        if duplicate:
            raise HTTPException(status_code=400, detail=f"Keyword '{keyword_data.keyword}' already exists")

        fields = {
            "keyword_id": str(uuid.uuid4()),
            "keyword": keyword_data.keyword,
            "interval_minutes": keyword_data.interval_minutes,
            "priority": keyword_data.priority,
            "rss_feeds": keyword_data.rss_feeds,
            "max_articles_per_run": keyword_data.max_articles_per_run,
            "is_active": keyword_data.is_active,
            "next_run": datetime.now() + timedelta(minutes=1),  # first run in one minute
            "created_at": datetime.now(),
            "updated_at": datetime.now(),
        }
        record = Keyword(**fields)

        await app.db.keywords.insert_one(record.model_dump())

        return {
            "message": f"Keyword '{keyword_data.keyword}' created successfully",
            "keyword_id": record.keyword_id,
            "note": "The scheduler will automatically detect and start processing this keyword within 30 seconds",
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.put("/keywords/{keyword_text}")
async def update_keyword(keyword_text: str, update_data: KeywordUpdate):
    """Apply the non-``None`` fields of ``update_data`` to an existing keyword.

    Changing ``interval_minutes`` also moves ``next_run`` to now plus the new
    interval.

    Raises:
        HTTPException: 404 if the keyword does not exist, 500 on any other error.
    """
    try:
        selector = {"keyword": keyword_text}

        # The keyword has to exist before it can be updated.
        current = await app.db.keywords.find_one(selector)
        if not current:
            raise HTTPException(status_code=404, detail=f"Keyword '{keyword_text}' not found")

        # Keep only the fields the caller actually supplied.
        editable_fields = (
            "interval_minutes",
            "priority",
            "rss_feeds",
            "max_articles_per_run",
            "is_active",
        )
        changes = {}
        for field_name in editable_fields:
            value = getattr(update_data, field_name)
            if value is not None:
                changes[field_name] = value

        if not changes:
            return {"message": "No update data provided"}

        changes["updated_at"] = datetime.now()

        # A new interval means the next run has to be recomputed as well.
        if "interval_minutes" in changes:
            changes["next_run"] = datetime.now() + timedelta(minutes=changes["interval_minutes"])

        outcome = await app.db.keywords.update_one(selector, {"$set": changes})
        if outcome.modified_count <= 0:
            return {"message": "No changes made"}

        if update_data.is_active is False:
            action_note = "The scheduler will stop the thread for this keyword within 30 seconds."
        elif update_data.is_active is True and not current.get("is_active"):
            action_note = "The scheduler will start a new thread for this keyword within 30 seconds."
        else:
            action_note = ""

        return {
            "message": f"Keyword '{keyword_text}' updated successfully",
            "updated_fields": list(changes.keys()),
            "note": action_note,
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.delete("/keywords/{keyword_text}")
async def delete_keyword(keyword_text: str):
    """Delete the keyword whose text is ``keyword_text``.

    Raises:
        HTTPException: 404 if the keyword does not exist, 500 if nothing was
            deleted or any other error occurs.
    """
    try:
        selector = {"keyword": keyword_text}

        # The keyword has to exist before it can be deleted.
        if not await app.db.keywords.find_one(selector):
            raise HTTPException(status_code=404, detail=f"Keyword '{keyword_text}' not found")

        outcome = await app.db.keywords.delete_one(selector)
        if outcome.deleted_count <= 0:
            raise HTTPException(status_code=500, detail="Failed to delete keyword")

        return {
            "message": f"Keyword '{keyword_text}' deleted successfully",
            "note": "The scheduler will stop the thread for this keyword within 30 seconds",
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/keywords/{keyword_text}/activate")
async def activate_keyword(keyword_text: str):
    """Mark the keyword as active.

    Raises:
        HTTPException: 404 if the keyword does not exist, 500 on any other error.
    """
    try:
        selector = {"keyword": keyword_text}
        changes = {"is_active": True, "updated_at": datetime.now()}
        outcome = await app.db.keywords.update_one(selector, {"$set": changes})

        if not outcome.matched_count:
            raise HTTPException(status_code=404, detail=f"Keyword '{keyword_text}' not found")

        return {
            "message": f"Keyword '{keyword_text}' activated",
            "note": "The scheduler will start processing this keyword within 30 seconds",
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/keywords/{keyword_text}/deactivate")
async def deactivate_keyword(keyword_text: str):
    """Mark the keyword as inactive.

    Raises:
        HTTPException: 404 if the keyword does not exist, 500 on any other error.
    """
    try:
        selector = {"keyword": keyword_text}
        changes = {"is_active": False, "updated_at": datetime.now()}
        outcome = await app.db.keywords.update_one(selector, {"$set": changes})

        if not outcome.matched_count:
            raise HTTPException(status_code=404, detail=f"Keyword '{keyword_text}' not found")

        return {
            "message": f"Keyword '{keyword_text}' deactivated",
            "note": "The scheduler will stop processing this keyword within 30 seconds",
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/keywords/{keyword_text}/trigger")
async def trigger_keyword(keyword_text: str):
    """Request an immediate run of the keyword by setting ``next_run`` to now.

    Raises:
        HTTPException: 404 if the keyword does not exist, 500 on any other error.
    """
    try:
        selector = {"keyword": keyword_text}
        # A next_run of "now" makes the keyword due right away.
        changes = {"next_run": datetime.now(), "updated_at": datetime.now()}
        outcome = await app.db.keywords.update_one(selector, {"$set": changes})

        if not outcome.matched_count:
            raise HTTPException(status_code=404, detail=f"Keyword '{keyword_text}' not found")

        return {
            "message": f"Keyword '{keyword_text}' triggered for immediate execution",
            "note": "The scheduler will execute this keyword within the next minute",
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
if __name__ == "__main__":
    # Serve on all interfaces; the port comes from API_PORT (default 8100).
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("API_PORT", "8100")))
|
||||
245
services/pipeline/scheduler/keyword_scheduler.py
Normal file
245
services/pipeline/scheduler/keyword_scheduler.py
Normal file
@ -0,0 +1,245 @@
|
||||
"""
|
||||
Keyword Scheduler Service
|
||||
데이터베이스에 등록된 키워드를 주기적으로 실행하는 스케줄러
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
|
||||
# Import from shared module
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from shared.models import Keyword, PipelineJob
|
||||
from shared.queue_manager import QueueManager
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class KeywordScheduler:
    """Polling scheduler that enqueues a pipeline job for every due keyword.

    Keywords are read from the ``keywords`` MongoDB collection. A keyword is
    due when it is active and its ``next_run`` is unset or not in the future.
    Jobs are pushed to the ``rss_collection`` queue through ``QueueManager``.
    """

    def __init__(self):
        """Read settings from the environment; connections are opened in ``start``."""
        self.queue_manager = QueueManager(
            redis_url=os.getenv("REDIS_URL", "redis://redis:6379")
        )
        self.mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
        self.db_name = os.getenv("DB_NAME", "ai_writer_db")
        # Kept so that stop() can close the client again.
        self.mongo_client = None
        self.db = None
        # Seconds between two polls of the keywords collection.
        self.check_interval = int(os.getenv("SCHEDULER_CHECK_INTERVAL", "60"))
        # Interval in minutes for keywords added without an explicit one.
        self.default_interval = int(os.getenv("DEFAULT_KEYWORD_INTERVAL", "60"))

    async def start(self):
        """Connect to Redis and MongoDB, seed default keywords, then poll forever."""
        logger.info("Starting Keyword Scheduler")

        # Redis connection
        await self.queue_manager.connect()

        # MongoDB connection
        self.mongo_client = AsyncIOMotorClient(self.mongodb_url)
        self.db = self.mongo_client[self.db_name]

        # Seed default keywords when the collection is empty.
        await self.initialize_keywords()

        # Main loop
        while True:
            try:
                await self.check_and_execute_keywords()
                await asyncio.sleep(self.check_interval)
            except Exception as e:
                logger.error(f"Error in scheduler loop: {e}")
                await asyncio.sleep(10)

    async def initialize_keywords(self):
        """Create the default keywords if the collection is empty.

        Errors are logged and not propagated.
        """
        try:
            count = await self.db.keywords.count_documents({})

            if count == 0:
                logger.info("No keywords found. Creating default keywords...")

                default_keywords = [
                    {
                        "keyword": "AI",
                        "interval_minutes": 60,
                        "is_active": True,
                        "priority": 1,
                        "rss_feeds": [],
                    },
                    {
                        "keyword": "경제",
                        "interval_minutes": 120,
                        "is_active": True,
                        "priority": 0,
                        "rss_feeds": [],
                    },
                    {
                        "keyword": "테크놀로지",
                        "interval_minutes": 60,
                        "is_active": True,
                        "priority": 1,
                        "rss_feeds": [],
                    },
                ]

                for kw_data in default_keywords:
                    keyword = Keyword(**kw_data)
                    # First run five minutes after seeding.
                    keyword.next_run = datetime.now() + timedelta(minutes=5)
                    await self.db.keywords.insert_one(keyword.model_dump())
                    logger.info(f"Created keyword: {keyword.keyword}")
            else:
                logger.info(f"Found {count} keywords in database")

        except Exception as e:
            logger.error(f"Error initializing keywords: {e}")

    async def check_and_execute_keywords(self):
        """Enqueue a job for every due keyword, highest priority first.

        Errors are logged and not propagated.
        """
        try:
            now = datetime.now()

            # Active keywords whose next_run is due or has never been set.
            query = {
                "is_active": True,
                "$or": [
                    {"next_run": {"$lte": now}},
                    {"next_run": None},
                ],
            }

            cursor = self.db.keywords.find(query).sort("priority", -1)
            keywords = await cursor.to_list(None)

            for keyword_data in keywords:
                # Validate each document on its own so that one malformed
                # entry does not stop the remaining keywords from running.
                try:
                    keyword = Keyword(**keyword_data)
                except Exception as e:
                    logger.error(
                        f"Skipping invalid keyword document {keyword_data.get('keyword')!r}: {e}"
                    )
                    continue
                await self.execute_keyword(keyword)

        except Exception as e:
            logger.error(f"Error checking keywords: {e}")

    async def execute_keyword(self, keyword: Keyword):
        """Enqueue an ``rss_collection`` job for ``keyword`` and schedule its next run.

        Errors are logged and not propagated.
        """
        try:
            logger.info(f"Executing keyword: {keyword.keyword}")

            job = PipelineJob(
                keyword_id=keyword.keyword_id,
                keyword=keyword.keyword,
                stage='rss_collection',
                data={
                    'rss_feeds': keyword.rss_feeds if keyword.rss_feeds else [],
                    'max_articles': keyword.max_articles_per_run,
                    'scheduled': True,
                },
                priority=keyword.priority,
            )

            await self.queue_manager.enqueue('rss_collection', job)
            logger.info(f"Enqueued job for keyword '{keyword.keyword}' with job_id: {job.job_id}")

            # One timestamp so last_run, next_run and updated_at agree.
            now = datetime.now()
            update_data = {
                "last_run": now,
                "next_run": now + timedelta(minutes=keyword.interval_minutes),
                "updated_at": now,
            }

            await self.db.keywords.update_one(
                {"keyword_id": keyword.keyword_id},
                {"$set": update_data},
            )

            logger.info(f"Updated keyword '{keyword.keyword}' - next run at {update_data['next_run']}")

        except Exception as e:
            logger.error(f"Error executing keyword {keyword.keyword}: {e}")

    async def add_keyword(self, keyword_text: str, interval_minutes: Optional[int] = None,
                          rss_feeds: Optional[List[str]] = None, priority: int = 0):
        """Insert a new keyword whose first run is one minute from now.

        Args:
            keyword_text: Text of the keyword; must not exist yet.
            interval_minutes: Minutes between runs; a falsy value selects the
                configured default interval.
            rss_feeds: Feed entries to store with the keyword; defaults to none.
            priority: Priority stored with the keyword.

        Returns:
            The created ``Keyword``, or ``None`` if the keyword already exists
            or an error occurred (errors are logged).
        """
        try:
            # Duplicate check
            existing = await self.db.keywords.find_one({"keyword": keyword_text})
            if existing:
                logger.warning(f"Keyword '{keyword_text}' already exists")
                return None

            keyword = Keyword(
                keyword=keyword_text,
                interval_minutes=interval_minutes or self.default_interval,
                rss_feeds=rss_feeds or [],
                priority=priority,
                next_run=datetime.now() + timedelta(minutes=1),
            )

            await self.db.keywords.insert_one(keyword.model_dump())
            logger.info(f"Added new keyword: {keyword_text}")
            return keyword

        except Exception as e:
            logger.error(f"Error adding keyword: {e}")
            return None

    async def update_keyword(self, keyword_id: str, **kwargs):
        """Set the given non-``None`` fields on the keyword with ``keyword_id``.

        ``updated_at`` is always refreshed.

        Returns:
            ``True`` if a document was modified, ``False`` otherwise or on
            error (errors are logged).
        """
        try:
            update_data = {k: v for k, v in kwargs.items() if v is not None}
            update_data["updated_at"] = datetime.now()

            result = await self.db.keywords.update_one(
                {"keyword_id": keyword_id},
                {"$set": update_data},
            )

            if result.modified_count > 0:
                logger.info(f"Updated keyword {keyword_id}")
                return True
            return False

        except Exception as e:
            logger.error(f"Error updating keyword: {e}")
            return False

    async def delete_keyword(self, keyword_id: str):
        """Delete the keyword with ``keyword_id``.

        Returns:
            ``True`` if a document was deleted, ``False`` otherwise or on
            error (errors are logged).
        """
        try:
            result = await self.db.keywords.delete_one({"keyword_id": keyword_id})
            if result.deleted_count > 0:
                logger.info(f"Deleted keyword {keyword_id}")
                return True
            return False

        except Exception as e:
            logger.error(f"Error deleting keyword: {e}")
            return False

    async def stop(self):
        """Disconnect from Redis and close the MongoDB client."""
        try:
            await self.queue_manager.disconnect()
        finally:
            # Close the client even if the queue disconnect failed.
            if self.mongo_client is not None:
                self.mongo_client.close()
                self.mongo_client = None
        logger.info("Keyword Scheduler stopped")
|
||||
|
||||
async def main():
    """Run the keyword scheduler and shut it down when it ends."""
    service = KeywordScheduler()

    try:
        await service.start()
    except KeyboardInterrupt:
        logger.info("Received interrupt signal")
    finally:
        await service.stop()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
361
services/pipeline/scheduler/multi_thread_scheduler.py
Normal file
361
services/pipeline/scheduler/multi_thread_scheduler.py
Normal file
@ -0,0 +1,361 @@
|
||||
"""
|
||||
Multi-threaded Keyword Scheduler Service
|
||||
하나의 프로세스에서 여러 스레드로 키워드를 관리하는 스케줄러
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from typing import Dict
|
||||
import threading
|
||||
import time
|
||||
|
||||
# Import from shared module
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from shared.models import Keyword, PipelineJob
|
||||
from shared.queue_manager import QueueManager
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 전역 변수로 스케줄러 인스턴스 참조 저장
|
||||
scheduler_instance = None
|
||||
|
||||
class KeywordThread(threading.Thread):
    """Thread that schedules pipeline jobs for a single keyword.

    The thread runs its own asyncio event loop, re-reads its keyword document
    from MongoDB on every iteration and enqueues an ``rss_collection`` job
    whenever ``next_run`` is due.

    NOTE(review): after a run the thread waits a full interval before it
    re-reads the keyword, so a ``next_run`` changed in the database during
    that wait is not seen until the wait ends -- confirm this is intended.
    """

    # Longest single sleep slice in seconds; bounds how long stop() can go
    # unnoticed while the thread is waiting.
    STOP_POLL_SECONDS = 1.0

    def __init__(self, keyword_text: str, mongodb_url: str, db_name: str, redis_url: str):
        """Store the connection settings; connections are opened in ``run_scheduler``."""
        super().__init__(name=f"Thread-{keyword_text}")
        self.keyword_text = keyword_text
        self.mongodb_url = mongodb_url
        self.db_name = db_name
        self.redis_url = redis_url
        self.running = True
        self.keyword = None
        self.queue_manager = None
        self.db = None
        self.status = "initializing"
        self.last_execution = None
        self.execution_count = 0
        self.error_count = 0
        self.last_error = None

    def run(self):
        """Thread entry point: run the scheduler coroutine on a fresh event loop."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        try:
            loop.run_until_complete(self.run_scheduler())
        finally:
            loop.close()

    async def _sleep_while_running(self, seconds: float) -> None:
        """Sleep up to ``seconds``, returning early once ``stop()`` was called."""
        loop = asyncio.get_running_loop()
        deadline = loop.time() + seconds
        while self.running:
            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            await asyncio.sleep(min(self.STOP_POLL_SECONDS, remaining))

    async def run_scheduler(self):
        """Connect, load the keyword and loop until stopped.

        The connections are released on every exit path, including the early
        exit taken when the keyword cannot be loaded.
        """
        # Redis connection
        self.queue_manager = QueueManager(redis_url=self.redis_url)
        await self.queue_manager.connect()

        client = None
        try:
            # MongoDB connection
            client = AsyncIOMotorClient(self.mongodb_url)
            self.db = client[self.db_name]

            logger.info(f"[{self.keyword_text}] Thread started")

            await self.load_keyword()

            if not self.keyword:
                logger.error(f"[{self.keyword_text}] Failed to load keyword")
                return

            # Main loop
            while self.running:
                try:
                    # Pick up changes made to the keyword in the database.
                    await self.reload_keyword()

                    if not self.keyword.is_active:
                        self.status = "inactive"
                        logger.info(f"[{self.keyword_text}] Keyword is inactive, sleeping...")
                        await self._sleep_while_running(60)
                        continue

                    now = datetime.now()
                    if self.keyword.next_run and self.keyword.next_run <= now:
                        self.status = "executing"
                        await self.execute_keyword()
                        # Wait one full interval before the next check.
                        sleep_seconds = self.keyword.interval_minutes * 60
                    else:
                        # Not due yet: check again in one minute.
                        sleep_seconds = 60
                    self.status = "waiting"

                    await self._sleep_while_running(sleep_seconds)

                except Exception as e:
                    self.error_count += 1
                    self.last_error = str(e)
                    self.status = "error"
                    logger.error(f"[{self.keyword_text}] Error in thread loop: {e}")
                    await self._sleep_while_running(60)
        finally:
            self.status = "stopped"
            if client is not None:
                client.close()
            await self.queue_manager.disconnect()
            logger.info(f"[{self.keyword_text}] Thread stopped")

    async def _fetch_keyword(self):
        """Return the keyword document for this thread, or ``None`` if absent."""
        return await self.db.keywords.find_one({"keyword": self.keyword_text})

    async def load_keyword(self):
        """Load the keyword for the first time; errors are logged, not raised."""
        try:
            keyword_doc = await self._fetch_keyword()
            if keyword_doc:
                self.keyword = Keyword(**keyword_doc)
                logger.info(f"[{self.keyword_text}] Loaded keyword")
        except Exception as e:
            logger.error(f"[{self.keyword_text}] Error loading keyword: {e}")

    async def reload_keyword(self):
        """Refresh the keyword from the database.

        The previously loaded keyword is kept when the document is missing or
        the lookup fails; errors are logged, not raised.
        """
        try:
            keyword_doc = await self._fetch_keyword()
            if keyword_doc:
                self.keyword = Keyword(**keyword_doc)
        except Exception as e:
            logger.error(f"[{self.keyword_text}] Error reloading keyword: {e}")

    async def execute_keyword(self):
        """Enqueue an ``rss_collection`` job and schedule the keyword's next run.

        Errors are counted, logged and not propagated.
        """
        try:
            logger.info(f"[{self.keyword_text}] Executing keyword")

            job = PipelineJob(
                keyword_id=self.keyword.keyword_id,
                keyword=self.keyword.keyword,
                stage='rss_collection',
                data={
                    'rss_feeds': self.keyword.rss_feeds if self.keyword.rss_feeds else [],
                    'max_articles': self.keyword.max_articles_per_run,
                    'scheduled': True,
                    'thread_name': self.name,
                },
                priority=self.keyword.priority,
            )

            await self.queue_manager.enqueue('rss_collection', job)
            logger.info(f"[{self.keyword_text}] Enqueued job {job.job_id}")

            # One timestamp so last_run, next_run and updated_at agree.
            now = datetime.now()
            update_data = {
                "last_run": now,
                "next_run": now + timedelta(minutes=self.keyword.interval_minutes),
                "updated_at": now,
            }

            await self.db.keywords.update_one(
                {"keyword_id": self.keyword.keyword_id},
                {"$set": update_data},
            )

            self.last_execution = now
            self.execution_count += 1
            logger.info(f"[{self.keyword_text}] Next run at {update_data['next_run']}")

        except Exception as e:
            self.error_count += 1
            self.last_error = str(e)
            logger.error(f"[{self.keyword_text}] Error executing keyword: {e}")

    def stop(self):
        """Ask the thread to finish; it exits at its next stop check."""
        self.running = False
        self.status = "stopped"

    def get_status(self):
        """Return a JSON-friendly snapshot of this thread's state and counters."""
        return {
            "keyword": self.keyword_text,
            "thread_name": self.name,
            "status": self.status,
            "is_alive": self.is_alive(),
            "execution_count": self.execution_count,
            "last_execution": self.last_execution.isoformat() if self.last_execution else None,
            "error_count": self.error_count,
            "last_error": self.last_error,
            "next_run": self.keyword.next_run.isoformat() if self.keyword and self.keyword.next_run else None,
        }
|
||||
|
||||
|
||||
class MultiThreadScheduler:
    """Scheduler that runs one ``KeywordThread`` per active keyword.

    The main loop re-reads the active keywords every 30 seconds, starts
    threads for new ones and stops threads whose keyword is no longer active.
    """

    def __init__(self):
        """Read settings from the environment and register this instance globally."""
        self.mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
        self.db_name = os.getenv("DB_NAME", "ai_writer_db")
        self.redis_url = os.getenv("REDIS_URL", "redis://redis:6379")
        # Keyword text -> thread handling that keyword.
        self.threads: Dict[str, KeywordThread] = {}
        self.running = True
        # Kept so that stop() can close the client again.
        self.mongo_client = None
        self.db = None
        # Publish this instance through the module-level reference.
        global scheduler_instance
        scheduler_instance = self

    async def start(self):
        """Connect to MongoDB, seed defaults, start threads and watch for changes."""
        logger.info("Starting Multi-threaded Keyword Scheduler")

        # MongoDB connection
        self.mongo_client = AsyncIOMotorClient(self.mongodb_url)
        self.db = self.mongo_client[self.db_name]

        # Seed default keywords when the collection is empty.
        await self.initialize_keywords()

        # Start one thread per active keyword.
        await self.load_and_start_threads()

        # Main loop: reconcile threads with the active keywords every 30 s.
        while self.running:
            try:
                await self.check_new_keywords()
                await asyncio.sleep(30)
            except Exception as e:
                logger.error(f"Error in main loop: {e}")
                await asyncio.sleep(30)

    async def initialize_keywords(self):
        """Create the default keywords if the collection is empty.

        Errors are logged and not propagated.
        """
        try:
            count = await self.db.keywords.count_documents({})

            if count == 0:
                logger.info("No keywords found. Creating default keywords...")

                default_keywords = [
                    {
                        "keyword": "AI",
                        "interval_minutes": 60,
                        "is_active": True,
                        "priority": 1,
                        "rss_feeds": [],
                        "next_run": datetime.now() + timedelta(minutes=1),
                    },
                    {
                        "keyword": "경제",
                        "interval_minutes": 120,
                        "is_active": True,
                        "priority": 0,
                        "rss_feeds": [],
                        "next_run": datetime.now() + timedelta(minutes=1),
                    },
                    {
                        "keyword": "테크놀로지",
                        "interval_minutes": 60,
                        "is_active": True,
                        "priority": 1,
                        "rss_feeds": [],
                        "next_run": datetime.now() + timedelta(minutes=1),
                    },
                ]

                for kw_data in default_keywords:
                    keyword = Keyword(**kw_data)
                    await self.db.keywords.insert_one(keyword.model_dump())
                    logger.info(f"Created keyword: {keyword.keyword}")
            else:
                logger.info(f"Found {count} keywords in database")

        except Exception as e:
            logger.error(f"Error initializing keywords: {e}")

    async def load_and_start_threads(self):
        """Start a thread for every active keyword that does not have one yet.

        Errors are logged and not propagated.
        """
        try:
            cursor = self.db.keywords.find({"is_active": True})
            keywords = await cursor.to_list(None)

            for keyword_doc in keywords:
                keyword = Keyword(**keyword_doc)
                if keyword.keyword not in self.threads:
                    self.start_keyword_thread(keyword.keyword)

            logger.info(f"Started {len(self.threads)} keyword threads")

        except Exception as e:
            logger.error(f"Error loading keywords: {e}")

    def start_keyword_thread(self, keyword_text: str):
        """Create, start and register a thread for ``keyword_text`` if none exists."""
        if keyword_text not in self.threads:
            thread = KeywordThread(
                keyword_text=keyword_text,
                mongodb_url=self.mongodb_url,
                db_name=self.db_name,
                redis_url=self.redis_url,
            )
            thread.start()
            self.threads[keyword_text] = thread
            logger.info(f"Started thread for keyword: {keyword_text}")

    def _reap_dead_threads(self):
        """Drop registered threads that are no longer alive; return their keywords."""
        dead = [text for text, thread in self.threads.items() if not thread.is_alive()]
        for text in dead:
            del self.threads[text]
        return dead

    async def check_new_keywords(self):
        """Reconcile the running threads with the active keywords in the database.

        Threads that have exited are removed first, so a keyword that is still
        active gets a fresh thread instead of being blocked by a dead entry.
        Errors are logged and not propagated.
        """
        try:
            for keyword_text in self._reap_dead_threads():
                logger.warning(f"Thread for keyword '{keyword_text}' is no longer running; removing it")

            # Currently active keywords
            cursor = self.db.keywords.find({"is_active": True})
            active_keywords = await cursor.to_list(None)
            active_keyword_texts = {kw['keyword'] for kw in active_keywords}

            # Start threads for keywords that have none.
            for keyword_text in active_keyword_texts:
                if keyword_text not in self.threads:
                    self.start_keyword_thread(keyword_text)

            # Stop threads whose keyword was deactivated or deleted.
            for keyword_text in list(self.threads.keys()):
                if keyword_text not in active_keyword_texts:
                    thread = self.threads[keyword_text]
                    thread.stop()
                    del self.threads[keyword_text]
                    logger.info(f"Stopped thread for keyword: {keyword_text}")

        except Exception as e:
            logger.error(f"Error checking new keywords: {e}")

    def stop(self):
        """Stop all threads, wait briefly for them and close the MongoDB client."""
        self.running = False
        workers = list(self.threads.values())
        for thread in workers:
            thread.stop()

        # Give every thread up to five seconds to finish.
        for thread in workers:
            thread.join(timeout=5)

        if self.mongo_client is not None:
            self.mongo_client.close()
            self.mongo_client = None

        logger.info("Multi-threaded Keyword Scheduler stopped")

    def get_threads_status(self):
        """Return the status dictionary of every registered thread."""
        return [thread.get_status() for thread in self.threads.values()]
|
||||
|
||||
|
||||
async def main():
    """Run the multi-threaded scheduler and stop its threads when it ends."""
    service = MultiThreadScheduler()

    try:
        await service.start()
    except KeyboardInterrupt:
        logger.info("Received interrupt signal")
    finally:
        service.stop()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
5
services/pipeline/scheduler/requirements.txt
Normal file
5
services/pipeline/scheduler/requirements.txt
Normal file
@ -0,0 +1,5 @@
|
||||
motor==3.6.0
|
||||
redis[hiredis]==5.0.1
|
||||
pydantic==2.5.0
|
||||
fastapi==0.104.1
|
||||
uvicorn==0.24.0
|
||||
203
services/pipeline/scheduler/scheduler.py
Normal file
203
services/pipeline/scheduler/scheduler.py
Normal file
@ -0,0 +1,203 @@
|
||||
"""
|
||||
News Pipeline Scheduler
|
||||
뉴스 파이프라인 스케줄러 서비스
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
# Import from shared module
|
||||
from shared.models import KeywordSubscription, PipelineJob
|
||||
from shared.queue_manager import QueueManager
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class NewsScheduler:
    """Schedule recurring pipeline jobs for the keywords stored in MongoDB.

    ``start()`` registers three APScheduler jobs: a sweep over active keywords
    every 30 minutes, a sweep over priority keywords at 07:00, 12:00 and
    18:00, and a statistics reset at midnight.  Each job created is pushed
    onto the ``rss_collection`` queue through ``QueueManager``.
    """

    def __init__(self):
        """Read connection settings from the environment; nothing is opened yet."""
        self.scheduler = AsyncIOScheduler()
        self.mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
        self.db_name = os.getenv("DB_NAME", "ai_writer_db")
        # Both are set by start(); the client is kept so stop() can close it.
        self.mongo_client = None
        self.db = None
        self.queue_manager = QueueManager(
            redis_url=os.getenv("REDIS_URL", "redis://redis:6379")
        )

    async def start(self):
        """Connect to MongoDB and Redis, register the jobs and run one sweep."""
        logger.info("Starting News Pipeline Scheduler")

        # MongoDB connection
        self.mongo_client = AsyncIOMotorClient(self.mongodb_url)
        self.db = self.mongo_client[self.db_name]

        # Redis connection
        await self.queue_manager.connect()

        # Baseline schedule: every 30 minutes.
        self.scheduler.add_job(
            self.process_keywords,
            'interval',
            minutes=30,
            id='keyword_processor',
            name='Process Active Keywords'
        )

        # Extra runs for priority keywords at 07:00, 12:00 and 18:00.
        for hour in [7, 12, 18]:
            self.scheduler.add_job(
                self.process_priority_keywords,
                'cron',
                hour=hour,
                minute=0,
                id=f'priority_processor_{hour}',
                name=f'Process Priority Keywords at {hour}:00'
            )

        # Reset statistics every day at midnight.
        self.scheduler.add_job(
            self.reset_daily_stats,
            'cron',
            hour=0,
            minute=0,
            id='stats_reset',
            name='Reset Daily Statistics'
        )

        self.scheduler.start()
        logger.info("Scheduler started successfully")

        # Run once immediately instead of waiting for the first interval.
        await self.process_keywords()

    async def process_keywords(self):
        """Create a job for every active keyword not processed in the last 30 minutes.

        ``last_processed`` is only advanced for keywords whose job was actually
        enqueued, so a failed enqueue is retried on the next sweep.  Errors are
        logged and never propagate to the scheduler.
        """
        try:
            logger.info("Processing active keywords")

            now = datetime.now()
            # NOTE(review): the cutoff equals the job interval, so a keyword
            # stamped a moment after the previous tick can be skipped by the
            # next one -- confirm whether that is acceptable.
            thirty_minutes_ago = now - timedelta(minutes=30)

            keywords = await self.db.keywords.find({
                "is_active": True,
                "$or": [
                    {"last_processed": {"$lt": thirty_minutes_ago}},
                    {"last_processed": None}
                ]
            }).to_list(None)

            logger.info(f"Found {len(keywords)} keywords to process")

            created = 0
            for keyword_doc in keywords:
                if not await self._create_job(keyword_doc):
                    # Leave last_processed untouched so the keyword is retried.
                    continue

                await self.db.keywords.update_one(
                    {"keyword_id": keyword_doc['keyword_id']},
                    {"$set": {"last_processed": now}}
                )
                created += 1

            logger.info(f"Created jobs for {created} keywords")

        except Exception as e:
            logger.error(f"Error processing keywords: {e}")

    async def process_priority_keywords(self):
        """Create a priority job for every active keyword flagged ``is_priority``.

        Unlike ``process_keywords`` this does not look at or update
        ``last_processed``.
        """
        try:
            logger.info("Processing priority keywords")

            keywords = await self.db.keywords.find({
                "is_active": True,
                "is_priority": True
            }).to_list(None)

            created = 0
            for keyword_doc in keywords:
                if await self._create_job(keyword_doc, priority=1):
                    created += 1

            logger.info(f"Created priority jobs for {created} keywords")

        except Exception as e:
            logger.error(f"Error processing priority keywords: {e}")

    async def _create_job(self, keyword_doc: dict, priority: int = 0) -> bool:
        """Build a ``PipelineJob`` from a keyword document and enqueue it.

        Args:
            keyword_doc: Keyword document as loaded from MongoDB.
            priority: Passed to both the job and ``QueueManager.enqueue``.

        Returns:
            True when the job was enqueued, False when anything failed (the
            error is logged, not raised).
        """
        try:
            keyword = KeywordSubscription(**keyword_doc)

            job = PipelineJob(
                keyword_id=keyword.keyword_id,
                keyword=keyword.keyword,
                stage='rss_collection',
                stages_completed=[],
                priority=priority,
                data={
                    'keyword': keyword.keyword,
                    'language': keyword.language,
                    'rss_feeds': keyword.rss_feeds or self._get_default_rss_feeds(),
                    'categories': keyword.categories
                }
            )

            # The pipeline starts at the RSS collection queue.
            await self.queue_manager.enqueue(
                'rss_collection',
                job,
                priority=priority
            )

            logger.info(f"Created job {job.job_id} for keyword '{keyword.keyword}'")
            return True

        except Exception as e:
            logger.error(f"Error creating job for keyword: {e}")
            return False

    def _get_default_rss_feeds(self) -> list:
        """Return the feed URLs used when a keyword defines none of its own.

        NOTE(review): the first URL carries a literal ``{keyword}``
        placeholder that is not substituted here -- presumably the consumer
        formats it; confirm.
        """
        return [
            "https://news.google.com/rss/search?q={keyword}&hl=ko&gl=KR&ceid=KR:ko",
            "https://trends.google.com/trends/trendingsearches/daily/rss?geo=KR",
            "https://www.mk.co.kr/rss/40300001/",  # Maeil Business Newspaper
            "https://www.hankyung.com/feed/all-news",  # Korea Economic Daily
            "https://www.zdnet.co.kr/news/news_rss.xml",  # ZDNet Korea
        ]

    async def reset_daily_stats(self):
        """Reset the daily statistics (placeholder: only logs for now)."""
        try:
            logger.info("Resetting daily statistics")
            # TODO: reset the Redis statistics; not implemented yet.
            pass
        except Exception as e:
            logger.error(f"Error resetting stats: {e}")

    async def stop(self):
        """Stop the scheduler and release the Redis and MongoDB connections."""
        # shutdown() raises if the scheduler never started, which would mask
        # the original error when start() failed part-way.
        if self.scheduler.running:
            self.scheduler.shutdown()
        await self.queue_manager.disconnect()
        if self.mongo_client is not None:
            self.mongo_client.close()
            self.mongo_client = None
            self.db = None
        logger.info("Scheduler stopped")
|
||||
|
||||
async def main():
    """Run the news scheduler until the process is interrupted."""
    scheduler = NewsScheduler()

    try:
        await scheduler.start()
        # Nothing else to do here; just keep the coroutine alive.
        while True:
            await asyncio.sleep(60)
    except (KeyboardInterrupt, asyncio.CancelledError):
        # Since Python 3.11 asyncio.run() turns Ctrl-C into a cancellation of
        # the main task, so the interrupt arrives here as CancelledError.
        logger.info("Received interrupt signal")
    finally:
        await scheduler.stop()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
173
services/pipeline/scheduler/single_keyword_scheduler.py
Normal file
173
services/pipeline/scheduler/single_keyword_scheduler.py
Normal file
@ -0,0 +1,173 @@
|
||||
"""
|
||||
Single Keyword Scheduler Service
|
||||
단일 키워드를 전담하는 스케줄러
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
import uuid
|
||||
|
||||
# Import from shared module
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from shared.models import Keyword, PipelineJob
|
||||
from shared.queue_manager import QueueManager
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SingleKeywordScheduler:
    """Scheduler dedicated to the one keyword named by the ``KEYWORD`` env var.

    The keyword document is loaded from (or created in) the ``keywords``
    collection, then a job is enqueued on ``rss_collection`` once per
    ``interval_minutes`` for as long as the keyword stays active.
    """

    def __init__(self):
        """Read settings from the environment; nothing is opened yet."""
        self.queue_manager = QueueManager(
            redis_url=os.getenv("REDIS_URL", "redis://redis:6379")
        )
        self.mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
        self.db_name = os.getenv("DB_NAME", "ai_writer_db")
        self.keyword_text = os.getenv("KEYWORD")  # keyword this instance serves
        self.interval_minutes = int(os.getenv("INTERVAL_MINUTES", "60"))
        # Set by start(); the client is kept so stop() can close it.
        self.mongo_client = None
        self.db = None
        self.keyword = None

    @staticmethod
    def _parse_rss_feeds(raw):
        """Split a comma-separated feed list, dropping blanks and padding."""
        if not raw:
            return []
        return [feed.strip() for feed in raw.split(",") if feed.strip()]

    async def start(self):
        """Connect, load the keyword and loop forever enqueuing its jobs.

        Returns early (after logging) when ``KEYWORD`` is unset or the keyword
        cannot be initialised.
        """
        if not self.keyword_text:
            logger.error("KEYWORD environment variable is required")
            return

        logger.info(f"Starting Single Keyword Scheduler for '{self.keyword_text}'")

        # Redis connection
        await self.queue_manager.connect()

        # MongoDB connection
        self.mongo_client = AsyncIOMotorClient(self.mongodb_url)
        self.db = self.mongo_client[self.db_name]

        await self.initialize_keyword()

        if not self.keyword:
            logger.error(f"Failed to initialize keyword '{self.keyword_text}'")
            return

        # Main loop: this instance only ever handles its own keyword.
        while True:
            try:
                await self.check_and_execute()
                # The interval is re-read each cycle because
                # check_and_execute() reloads the keyword from the database.
                sleep_seconds = self.keyword.interval_minutes * 60
                logger.info(f"Sleeping for {self.keyword.interval_minutes} minutes until next execution")
                await asyncio.sleep(sleep_seconds)
            except Exception as e:
                logger.error(f"Error in scheduler loop: {e}")
                await asyncio.sleep(60)  # retry one minute after an error

    async def initialize_keyword(self):
        """Load the keyword document, creating it from env settings if absent.

        On failure the error is logged and ``self.keyword`` keeps its previous
        value.
        """
        try:
            keyword_doc = await self.db.keywords.find_one({"keyword": self.keyword_text})

            if keyword_doc:
                self.keyword = Keyword(**keyword_doc)
                logger.info(f"Loaded existing keyword: {self.keyword_text}")
            else:
                self.keyword = Keyword(
                    keyword=self.keyword_text,
                    interval_minutes=self.interval_minutes,
                    is_active=True,
                    priority=int(os.getenv("PRIORITY", "0")),
                    rss_feeds=self._parse_rss_feeds(os.getenv("RSS_FEEDS")),
                    max_articles_per_run=int(os.getenv("MAX_ARTICLES", "100"))
                )

                await self.db.keywords.insert_one(self.keyword.model_dump())
                logger.info(f"Created new keyword: {self.keyword_text}")

        except Exception as e:
            logger.error(f"Error initializing keyword: {e}")

    async def check_and_execute(self):
        """Reload the keyword and enqueue a job unless it is missing or inactive."""
        try:
            # Re-read so edits made in the database take effect.
            keyword_doc = await self.db.keywords.find_one({"keyword": self.keyword_text})

            if not keyword_doc:
                logger.error(f"Keyword '{self.keyword_text}' not found in database")
                return

            self.keyword = Keyword(**keyword_doc)

            if not self.keyword.is_active:
                logger.info(f"Keyword '{self.keyword_text}' is inactive, skipping")
                return

            await self.execute_keyword()

        except Exception as e:
            logger.error(f"Error checking keyword: {e}")

    async def execute_keyword(self):
        """Enqueue one ``rss_collection`` job and record the run on the keyword."""
        try:
            logger.info(f"Executing keyword: {self.keyword.keyword}")

            job = PipelineJob(
                keyword_id=self.keyword.keyword_id,
                keyword=self.keyword.keyword,
                stage='rss_collection',
                data={
                    'rss_feeds': self.keyword.rss_feeds if self.keyword.rss_feeds else [],
                    'max_articles': self.keyword.max_articles_per_run,
                    'scheduled': True,
                    'scheduler_instance': f"single-{self.keyword_text}"
                },
                priority=self.keyword.priority
            )

            # enqueue() decides queue position from its own priority argument,
            # so the keyword's priority has to be passed explicitly.
            await self.queue_manager.enqueue(
                'rss_collection',
                job,
                priority=self.keyword.priority
            )
            logger.info(f"Enqueued job for keyword '{self.keyword.keyword}' with job_id: {job.job_id}")

            # One timestamp so last_run, next_run and updated_at agree.
            now = datetime.now()
            update_data = {
                "last_run": now,
                "next_run": now + timedelta(minutes=self.keyword.interval_minutes),
                "updated_at": now
            }

            await self.db.keywords.update_one(
                {"keyword_id": self.keyword.keyword_id},
                {"$set": update_data}
            )

            logger.info(f"Updated keyword '{self.keyword.keyword}' - next run at {update_data['next_run']}")

        except Exception as e:
            logger.error(f"Error executing keyword {self.keyword.keyword}: {e}")

    async def stop(self):
        """Release the Redis and MongoDB connections."""
        await self.queue_manager.disconnect()
        if self.mongo_client is not None:
            self.mongo_client.close()
            self.mongo_client = None
            self.db = None
        logger.info(f"Single Keyword Scheduler for '{self.keyword_text}' stopped")
|
||||
|
||||
async def main():
    """Run the single-keyword scheduler until it returns or is interrupted."""
    scheduler = SingleKeywordScheduler()

    try:
        await scheduler.start()
    except (KeyboardInterrupt, asyncio.CancelledError):
        # Since Python 3.11 asyncio.run() turns Ctrl-C into a cancellation of
        # the main task, so the interrupt arrives here as CancelledError.
        logger.info("Received interrupt signal")
    finally:
        await scheduler.stop()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
1
services/pipeline/shared/__init__.py
Normal file
1
services/pipeline/shared/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# Shared modules for pipeline services
|
||||
159
services/pipeline/shared/models.py
Normal file
159
services/pipeline/shared/models.py
Normal file
@ -0,0 +1,159 @@
|
||||
"""
|
||||
Pipeline Data Models
|
||||
파이프라인 전체에서 사용되는 공통 데이터 모델
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
import uuid
|
||||
|
||||
class KeywordSubscription(BaseModel):
    """Keyword subscription: what to collect and how it is scheduled."""
    keyword_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    keyword: str
    language: str = "ko"
    # Five-field cron expression, minute field first: every 30 minutes.
    # (The step belongs in the minute field; "0 */30 * * *" would put it in
    # the hour field.)
    schedule: str = "*/30 * * * *"
    is_active: bool = True
    is_priority: bool = False
    last_processed: Optional[datetime] = None
    rss_feeds: List[str] = Field(default_factory=list)
    categories: List[str] = Field(default_factory=list)
    created_at: datetime = Field(default_factory=datetime.now)
    owner: Optional[str] = None
|
||||
|
||||
class PipelineJob(BaseModel):
    """Unit of work passed between pipeline stages through the queues."""
    job_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    keyword_id: str
    keyword: str
    # Name of the stage the job is currently in.
    stage: str
    stages_completed: List[str] = Field(default_factory=list)
    data: Dict[str, Any] = Field(default_factory=dict)
    retry_count: int = Field(default=0)
    max_retries: int = Field(default=3)
    priority: int = Field(default=0)
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)
|
||||
|
||||
class RSSItem(BaseModel):
    """One entry taken from an RSS feed."""
    item_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    title: str
    link: str
    # RSS GUID, used for deduplication.
    guid: Optional[str] = Field(default=None)
    published: Optional[str] = Field(default=None)
    summary: Optional[str] = Field(default=None)
    source_feed: str
|
||||
|
||||
class SearchResult(BaseModel):
    """A single search hit attached to a news item."""
    title: str
    link: str
    snippet: Optional[str] = Field(default=None)
    source: str = Field(default="google")
|
||||
|
||||
class EnrichedItem(BaseModel):
    """An RSS item together with the search results gathered for it."""
    rss_item: RSSItem
    # Empty until enrichment has added results.
    search_results: List[SearchResult] = Field(default_factory=list)
|
||||
|
||||
class SummarizedItem(BaseModel):
    """An enriched item plus its AI-generated summary."""
    enriched_item: EnrichedItem
    ai_summary: str
    summary_language: str = Field(default="ko")
|
||||
|
||||
# NOTE(review): a second class named TranslatedItem is defined further down in
# this module and replaces this one in the module namespace once the file has
# loaded.  Only annotations evaluated before that point see this version.
# Confirm which definition is meant to survive.
class TranslatedItem(BaseModel):
    """A summarized item with its English title and summary."""
    summarized_item: SummarizedItem
    title_en: str
    summary_en: str
|
||||
|
||||
class ItemWithImage(BaseModel):
    """A translated item with a generated image attached."""
    # The annotation is resolved when this class is created, so it refers to
    # the TranslatedItem definition that precedes it in the module.
    translated_item: TranslatedItem
    image_url: str
    image_prompt: str
|
||||
|
||||
class Subtopic(BaseModel):
    """A sub-topic section of an article."""
    title: str
    # Section body, one list entry per paragraph.
    content: List[str]
|
||||
|
||||
class Entities(BaseModel):
    """Named entities grouped by kind; every list defaults to empty."""
    people: List[str] = Field(default_factory=list)
    organizations: List[str] = Field(default_factory=list)
    groups: List[str] = Field(default_factory=list)
    countries: List[str] = Field(default_factory=list)
    events: List[str] = Field(default_factory=list)
|
||||
|
||||
class NewsReference(BaseModel):
    """A source news item referenced by an article."""
    title: str
    link: str
    source: str
    published: Optional[str] = Field(default=None)
|
||||
|
||||
class FinalArticle(BaseModel):
    """Finished article; mirrors the ``ai_writer_db.articles`` schema."""
    news_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    title: str
    # Stored as an ISO-8601 string rather than a datetime.
    created_at: str = Field(default_factory=lambda: datetime.now().isoformat())
    summary: str
    subtopics: List[Subtopic] = Field(default_factory=list)
    categories: List[str] = Field(default_factory=list)
    entities: Entities = Field(default_factory=Entities)
    source_keyword: str
    source_count: int = Field(default=1)

    # Reference news the article was built from.
    references: List[NewsReference] = Field(default_factory=list)

    # Pipeline bookkeeping.
    job_id: Optional[str] = Field(default=None)
    keyword_id: Optional[str] = Field(default=None)
    pipeline_stages: List[str] = Field(default_factory=list)
    processing_time: Optional[float] = Field(default=None)

    # Multi-language support.
    language: str = Field(default='ko')
    ref_news_id: Optional[str] = Field(default=None)

    # RSS GUID used for duplicate detection.
    rss_guid: Optional[str] = Field(default=None)

    # Image fields.
    image_prompt: Optional[str] = Field(default=None)
    images: List[str] = Field(default_factory=list)

    # Translation tracking.
    translated_languages: List[str] = Field(default_factory=list)
|
||||
|
||||
# NOTE(review): this redefines TranslatedItem; an earlier class of the same
# name exists above in this module and is shadowed from here on.
class TranslatedItem(BaseModel):
    """A translated item carrying its source as a plain dict."""
    # SummarizedItem serialised to a dict.
    summarized_item: Dict[str, Any]
    translated_title: str
    translated_summary: str
    target_language: str = Field(default='en')
|
||||
|
||||
class GeneratedImageItem(BaseModel):
    """An item for which an image has been generated."""
    # TranslatedItem serialised to a dict.
    translated_item: Dict[str, Any]
    image_url: str
    image_prompt: str
|
||||
|
||||
class QueueMessage(BaseModel):
    """Envelope that wraps a ``PipelineJob`` while it sits in a queue."""
    message_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    queue_name: str
    job: PipelineJob
    timestamp: datetime = Field(default_factory=datetime.now)
    retry_count: int = Field(default=0)
|
||||
class Keyword(BaseModel):
    """Keyword model used by the schedulers."""
    keyword_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    keyword: str
    # Minutes between runs; one hour by default.
    interval_minutes: int = 60
    is_active: bool = True
    last_run: Optional[datetime] = None
    next_run: Optional[datetime] = None
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)
    # Custom RSS feeds for this keyword.
    rss_feeds: List[str] = Field(default_factory=list)
    # Higher values are served first.
    priority: int = 0
    # Upper bound on articles per run.
    max_articles_per_run: int = 100
|
||||
176
services/pipeline/shared/queue_manager.py
Normal file
176
services/pipeline/shared/queue_manager.py
Normal file
@ -0,0 +1,176 @@
|
||||
"""
|
||||
Queue Manager
|
||||
Redis 기반 큐 관리 시스템
|
||||
"""
|
||||
import redis.asyncio as redis
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
|
||||
from .models import PipelineJob, QueueMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class QueueManager:
    """Redis-backed job queues shared by the pipeline services.

    Pending messages live in Redis lists (see ``QUEUES``).  A dequeued message
    is mirrored in the ``processing:<queue>`` hash until it is marked
    completed or failed.  Counters are kept in the ``stats:queues``,
    ``stats:completed`` and ``stats:failed`` hashes.
    """

    # Logical queue name -> Redis key.  Names missing from this table fall
    # back to ``queue:<name>`` at call time.
    QUEUES = {
        "keyword_processing": "queue:keyword_processing",
        "rss_collection": "queue:rss_collection",
        "search_enrichment": "queue:search_enrichment",
        "google_search": "queue:google_search",
        "ai_article_generation": "queue:ai_article_generation",
        "image_generation": "queue:image_generation",
        "translation": "queue:translation",
        "failed": "queue:failed",
        "scheduled": "queue:scheduled"
    }

    def __init__(self, redis_url: str = "redis://redis:6379"):
        """Remember the Redis URL; the connection is opened by ``connect()``."""
        self.redis_url = redis_url
        self.redis_client: Optional[redis.Redis] = None

    async def connect(self):
        """Open the Redis connection if it is not already open."""
        if not self.redis_client:
            self.redis_client = await redis.from_url(
                self.redis_url,
                encoding="utf-8",
                decode_responses=True
            )
            logger.info("Connected to Redis")

    async def disconnect(self):
        """Close the Redis connection if one is open."""
        if self.redis_client:
            # aclose() is the non-deprecated spelling of close() on the
            # asyncio client in the pinned redis release.
            await self.redis_client.aclose()
            self.redis_client = None

    async def enqueue(self, queue_name: str, job: PipelineJob, priority: int = 0) -> str:
        """Wrap ``job`` in a ``QueueMessage`` and push it onto ``queue_name``.

        Args:
            queue_name: Logical queue name.
            job: Job to enqueue.
            priority: Values above 0 push to the head of the list (the end
                ``dequeue`` pops from); otherwise the message goes to the tail.

        Returns:
            The job's ``job_id``.

        Raises:
            Exception: Any failure is logged and re-raised.
        """
        try:
            queue_key = self.QUEUES.get(queue_name, f"queue:{queue_name}")

            message = QueueMessage(
                queue_name=queue_name,
                job=job
            )
            payload = message.model_dump_json()

            if priority > 0:
                await self.redis_client.lpush(queue_key, payload)
            else:
                await self.redis_client.rpush(queue_key, payload)

            # Per-queue enqueue counter.
            await self.redis_client.hincrby("stats:queues", queue_name, 1)

            logger.info(f"Job {job.job_id} enqueued to {queue_name}")
            return job.job_id

        except Exception as e:
            logger.error(f"Failed to enqueue job: {e}")
            raise

    async def dequeue(self, queue_name: str, timeout: int = 0) -> Optional[PipelineJob]:
        """Pop the next job from ``queue_name`` and record it as in progress.

        Args:
            queue_name: Logical queue name.
            timeout: When above 0, block with BLPOP for up to this timeout;
                otherwise do a single non-blocking LPOP.

        Returns:
            The job, or None when the queue is empty or an error occurred
            (errors are logged, not raised).
        """
        try:
            queue_key = self.QUEUES.get(queue_name, f"queue:{queue_name}")
            logger.info(f"Attempting to dequeue from {queue_key} with timeout={timeout}")

            if timeout > 0:
                result = await self.redis_client.blpop(queue_key, timeout)
                if result:
                    _, data = result
                    logger.info(f"Dequeued item from {queue_key}")
                else:
                    logger.debug(f"No item available in {queue_key}")
                    return None
            else:
                data = await self.redis_client.lpop(queue_key)

            if data:
                message = QueueMessage.model_validate_json(data)

                # Track the message until mark_completed()/mark_failed().
                processing_key = f"processing:{queue_name}"
                await self.redis_client.hset(
                    processing_key,
                    message.job.job_id,
                    message.model_dump_json()
                )

                return message.job

            return None

        except Exception as e:
            logger.error(f"Failed to dequeue job: {e}")
            return None

    async def mark_completed(self, queue_name: str, job_id: str):
        """Drop ``job_id`` from the processing hash and count the completion."""
        try:
            processing_key = f"processing:{queue_name}"
            await self.redis_client.hdel(processing_key, job_id)

            await self.redis_client.hincrby("stats:completed", queue_name, 1)

            logger.info(f"Job {job_id} completed in {queue_name}")

        except Exception as e:
            logger.error(f"Failed to mark job as completed: {e}")

    async def mark_failed(self, queue_name: str, job: PipelineJob, error: str):
        """Retry a failed job, or move it to the ``failed`` queue.

        While ``job.retry_count`` is below ``job.max_retries`` the counter is
        incremented and the job is re-enqueued at the tail of the same queue.
        After that the error and failing stage are stored in ``job.data`` and
        the job goes to the ``failed`` queue.  Errors raised here are logged
        and swallowed.
        """
        try:
            processing_key = f"processing:{queue_name}"
            await self.redis_client.hdel(processing_key, job.job_id)

            if job.retry_count < job.max_retries:
                job.retry_count += 1
                await self.enqueue(queue_name, job)
                logger.info(f"Job {job.job_id} requeued (retry {job.retry_count}/{job.max_retries})")
            else:
                # Retries exhausted: park the job on the failed queue.
                job.data["error"] = error
                job.data["failed_stage"] = queue_name
                await self.enqueue("failed", job)

                await self.redis_client.hincrby("stats:failed", queue_name, 1)
                logger.error(f"Job {job.job_id} failed: {error}")

        except Exception as e:
            logger.error(f"Failed to mark job as failed: {e}")

    async def get_queue_stats(self) -> Dict[str, Any]:
        """Return pending/processing counts per queue plus the counters.

        The result maps every name in ``QUEUES`` to
        ``{"pending": ..., "processing": ...}`` and then sets ``"completed"``
        and ``"failed"`` to the corresponding counter hashes (this overwrites
        the per-queue entry for the ``failed`` queue).  Returns ``{}`` on
        error.
        """
        try:
            stats = {}

            for name, key in self.QUEUES.items():
                stats[name] = {
                    "pending": await self.redis_client.llen(key),
                    "processing": await self.redis_client.hlen(f"processing:{name}"),
                }

            stats["completed"] = await self.redis_client.hgetall("stats:completed") or {}
            stats["failed"] = await self.redis_client.hgetall("stats:failed") or {}

            return stats

        except Exception as e:
            logger.error(f"Failed to get queue stats: {e}")
            return {}

    async def clear_queue(self, queue_name: str):
        """Delete a queue and its processing hash (intended for tests)."""
        queue_key = self.QUEUES.get(queue_name, f"queue:{queue_name}")
        await self.redis_client.delete(queue_key)
        await self.redis_client.delete(f"processing:{queue_name}")
        logger.info(f"Queue {queue_name} cleared")
|
||||
5
services/pipeline/shared/requirements.txt
Normal file
5
services/pipeline/shared/requirements.txt
Normal file
@ -0,0 +1,5 @@
|
||||
redis[hiredis]==5.0.1
|
||||
motor==3.1.1
|
||||
pymongo==4.3.3
|
||||
pydantic==2.5.0
|
||||
python-dateutil==2.8.2
|
||||
54
services/pipeline/simple_test.py
Normal file
54
services/pipeline/simple_test.py
Normal file
@ -0,0 +1,54 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple pipeline test - direct queue injection
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import redis.asyncio as redis
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
async def test():
    """Push one hand-built rss_collection message into Redis and print the queue length."""
    conn = await redis.from_url("redis://redis:6379", decode_responses=True)

    # Job payload, laid out like a serialised PipelineJob.
    feeds = ["https://news.google.com/rss/search?q=전기차&hl=ko&gl=KR&ceid=KR:ko"]
    job = {
        "job_id": str(uuid.uuid4()),
        "keyword_id": str(uuid.uuid4()),
        "keyword": "전기차",
        "stage": "rss_collection",
        "stages_completed": [],
        "data": {
            "rss_feeds": feeds,
            "categories": ["technology", "automotive"],
        },
        "priority": 1,
        "retry_count": 0,
        "max_retries": 3,
        "created_at": datetime.now().isoformat(),
        "updated_at": datetime.now().isoformat(),
    }

    # Wrap it the way QueueMessage would.
    envelope = {
        "message_id": str(uuid.uuid4()),
        "queue_name": "rss_collection",
        "job": job,
        "timestamp": datetime.now().isoformat(),
    }

    await conn.lpush("queue:rss_collection", json.dumps(envelope))
    print(f"✅ Job {job['job_id']} added to queue:rss_collection")

    # Report the resulting queue depth.
    length = await conn.llen("queue:rss_collection")
    print(f"📊 Queue length: {length}")

    await conn.aclose()


if __name__ == "__main__":
    asyncio.run(test())
|
||||
57
services/pipeline/test_dequeue.py
Normal file
57
services/pipeline/test_dequeue.py
Normal file
@ -0,0 +1,57 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Direct dequeue test
|
||||
"""
|
||||
import asyncio
|
||||
import redis.asyncio as redis
|
||||
import json
|
||||
|
||||
async def test_dequeue():
    """Pop one message from queue:rss_collection with BLPOP and print what it holds."""

    def show(raw):
        # Decode the popped payload and print its identifying fields.
        try:
            message = json.loads(raw)
            print(f"Message ID: {message.get('message_id')}")
            print(f"Queue Name: {message.get('queue_name')}")
            if 'job' in message:
                job = message['job']
                print(f"Job ID: {job.get('job_id')}")
                print(f"Keyword: {job.get('keyword')}")
        except Exception as e:
            print(f"Failed to parse message: {e}")

    redis_client = await redis.from_url(
        "redis://redis:6379",
        encoding="utf-8",
        decode_responses=True
    )
    print("Connected to Redis")

    length = await redis_client.llen("queue:rss_collection")
    print(f"Queue length: {length}")

    if not length > 0:
        print("Queue is empty")
    else:
        # Peek at the head without removing it.
        item = await redis_client.lrange("queue:rss_collection", 0, 0)
        print(f"First item preview: {item[0][:200]}...")

        print("Trying blpop with timeout=5...")
        result = await redis_client.blpop("queue:rss_collection", 5)
        if not result:
            print("blpop timed out - no result")
        else:
            queue, data = result
            print(f"Successfully dequeued from {queue}")
            print(f"Data: {data[:200]}...")
            show(data)

    await redis_client.close()


if __name__ == "__main__":
    asyncio.run(test_dequeue())
|
||||
118
services/pipeline/test_pipeline.py
Normal file
118
services/pipeline/test_pipeline.py
Normal file
@ -0,0 +1,118 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Pipeline Test Script
|
||||
파이프라인 전체 플로우를 테스트하는 스크립트
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
import redis.asyncio as redis
|
||||
from shared.models import KeywordSubscription, PipelineJob
|
||||
|
||||
async def test_pipeline():
    """Seed a test keyword, enqueue one job, watch the queues, then list articles."""
    # NOTE(review): the schedulers in this commit read the `keywords`
    # collection of DB_NAME (default ai_writer_db), and FinalArticle uses
    # `source_keyword` / `news_id`.  This script uses the `pipeline` database,
    # `keyword_subscriptions`, `keyword` and `article_id` -- confirm which
    # names are current.
    mongo_client = AsyncIOMotorClient("mongodb://mongodb:27017")
    db = mongo_client.pipeline

    redis_client = redis.Redis(host='redis', port=6379, decode_responses=True)

    # 1. Register the test keyword.
    feeds = [
        "https://news.google.com/rss/search?q=전기차&hl=ko&gl=KR&ceid=KR:ko",
        "https://news.google.com/rss/search?q=electric+vehicle&hl=en&gl=US&ceid=US:en"
    ]
    test_keyword = KeywordSubscription(
        keyword="전기차",
        language="ko",
        schedule="*/1 * * * *",  # every minute (testing only)
        is_active=True,
        is_priority=True,
        rss_feeds=feeds,
        categories=["technology", "automotive", "environment"],
        owner="test_user"
    )

    await db.keyword_subscriptions.replace_one(
        {"keyword": test_keyword.keyword},
        test_keyword.dict(),
        upsert=True
    )
    print(f"✅ 키워드 '{test_keyword.keyword}' 추가 완료")

    # 2. Trigger the pipeline directly, bypassing the scheduler.
    job_priority = 1 if test_keyword.is_priority else 0
    job = PipelineJob(
        keyword_id=test_keyword.keyword_id,
        keyword=test_keyword.keyword,
        stage="rss_collection",
        data={
            "rss_feeds": test_keyword.rss_feeds,
            "categories": test_keyword.categories
        },
        priority=job_priority
    )

    # Push onto the Redis list in QueueMessage form.
    from shared.queue_manager import QueueMessage
    message = QueueMessage(
        queue_name="rss_collection",
        job=job
    )
    await redis_client.lpush("queue:rss_collection", message.json())
    print(f"✅ 작업을 RSS Collection 큐에 추가: {job.job_id}")

    # 3. Tell the operator where to look, then poll the queues.
    print("\n📊 파이프라인 실행 모니터링 중...")
    print("각 단계별 로그를 확인하려면 다음 명령을 실행하세요:")
    services = (
        "pipeline-rss-collector",
        "pipeline-google-search",
        "pipeline-ai-summarizer",
        "pipeline-translator",
        "pipeline-image-generator",
        "pipeline-article-assembly",
    )
    for service in services:
        print(f" docker-compose logs -f {service}")

    queues = (
        "queue:rss_collection",
        "queue:google_search",
        "queue:ai_summarization",
        "queue:translation",
        "queue:image_generation",
        "queue:article_assembly",
    )
    # Ten polls, five seconds apart; only non-empty queues are printed.
    for _ in range(10):
        await asyncio.sleep(5)

        print(f"\n[{datetime.now().strftime('%H:%M:%S')}] 큐 상태:")
        for queue in queues:
            length = await redis_client.llen(queue)
            if length > 0:
                print(f" {queue}: {length} 작업 대기 중")

    # 4. Look for articles produced for the keyword.
    print("\n📄 MongoDB에서 생성된 기사 확인 중...")
    articles = await db.articles.find({"keyword": test_keyword.keyword}).to_list(length=5)

    if not articles:
        print("⚠️ 아직 기사가 생성되지 않았습니다. 조금 더 기다려주세요.")
    else:
        print(f"✅ {len(articles)}개의 기사 생성 완료!")
        for article in articles:
            print(f"\n제목: {article.get('title', 'N/A')}")
            print(f"ID: {article.get('article_id', 'N/A')}")
            print(f"생성 시간: {article.get('created_at', 'N/A')}")
            print(f"처리 시간: {article.get('processing_time', 'N/A')}초")
            print(f"이미지 수: {len(article.get('images', []))}")

    # Close both connections.
    await redis_client.close()
    mongo_client.close()


if __name__ == "__main__":
    print("🚀 파이프라인 테스트 시작")
    asyncio.run(test_pipeline())
|
||||
56
services/pipeline/test_starcraft.py
Normal file
56
services/pipeline/test_starcraft.py
Normal file
@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
스타크래프트 키워드로 파이프라인 테스트
|
||||
"""
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
|
||||
from shared.queue_manager import QueueManager
|
||||
from shared.models import PipelineJob
|
||||
|
||||
async def test_starcraft_pipeline():
    """Enqueue one rss_collection job for the StarCraft keyword and print queue stats."""
    queue_manager = QueueManager(redis_url="redis://redis:6379")
    await queue_manager.connect()

    try:
        job = PipelineJob(
            keyword_id="test_starcraft_001",
            keyword="스타크래프트",
            stage="rss_collection",
            data={}
        )

        print("🚀 스타크래프트 파이프라인 작업 시작")
        print(f" 작업 ID: {job.job_id}")
        print(f" 키워드: {job.keyword}")
        print(f" 키워드 ID: {job.keyword_id}")

        await queue_manager.enqueue('rss_collection', job)
        print("✅ 작업이 rss_collection 큐에 추가되었습니다")

        # Show only the queues that currently hold work.
        stats = await queue_manager.get_queue_stats()
        print("\n📊 현재 큐 상태:")
        for queue_name, stat in stats.items():
            if queue_name in ['completed', 'failed']:
                # These two entries are counter hashes, not per-queue stats.
                continue
            pending = stat.get('pending', 0)
            processing = stat.get('processing', 0)
            if pending > 0 or processing > 0:
                print(f" {queue_name}: 대기={pending}, 처리중={processing}")

        print("\n⏳ 파이프라인 실행을 모니터링하세요:")
        print(" docker logs site11_pipeline_rss_collector --tail 20 -f")
        print(" python3 check_mongodb.py")

    finally:
        await queue_manager.disconnect()


if __name__ == "__main__":
    asyncio.run(test_starcraft_pipeline())
|
||||
54
services/pipeline/test_submit_job.py
Normal file
54
services/pipeline/test_submit_job.py
Normal file
@ -0,0 +1,54 @@
|
||||
"""
|
||||
파이프라인 테스트 작업 제출 스크립트
|
||||
"""
|
||||
import redis
|
||||
import json
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
import sys
|
||||
|
||||
def submit_test_job(keyword='나스닥'):
    """Push one test job for ``keyword`` onto queue:rss_collection via a local Redis."""
    redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True)

    job_id = str(uuid.uuid4())

    # Job payload; keyword_id is derived from the first 8 chars of the job id.
    job_data = {
        'job_id': job_id,
        'keyword_id': f'test_{job_id[:8]}',
        'keyword': keyword,
        'created_at': datetime.now().isoformat(),
        'stage': 'rss_collection',
        'stages_completed': [],
        'data': {}
    }

    # QueueMessage-style wrapper around the job.
    queue_message = {
        'message_id': str(uuid.uuid4()),
        'queue_name': 'rss_collection',
        'job': job_data,
        'timestamp': datetime.now().isoformat(),
        'attempts': 0
    }
    payload = json.dumps(queue_message)

    # rpush appends at the tail, i.e. the priority=0 path.
    redis_client.rpush('queue:rss_collection', payload)
    print(f'✅ 파이프라인 시작: job_id={job_id}')
    print(f'✅ 키워드: {keyword}')
    print('✅ RSS Collection 큐에 작업 추가 완료')

    queue_len = redis_client.llen('queue:rss_collection')
    print(f'✅ 현재 큐 길이: {queue_len}')

    redis_client.close()


if __name__ == "__main__":
    # Optional first CLI argument overrides the default keyword.
    keyword = sys.argv[1] if len(sys.argv) > 1 else '나스닥'
    submit_test_job(keyword)
|
||||
19
services/pipeline/translator/Dockerfile
Normal file
19
services/pipeline/translator/Dockerfile
Normal file
@ -0,0 +1,19 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install dependencies
|
||||
COPY ./translator/requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy shared modules
|
||||
COPY ./shared /app/shared
|
||||
|
||||
# Copy config directory
|
||||
COPY ./config /app/config
|
||||
|
||||
# Copy application code
|
||||
COPY ./translator /app
|
||||
|
||||
# Use multi_translator.py as the main service
|
||||
CMD ["python", "multi_translator.py"]
|
||||
329
services/pipeline/translator/language_sync.py
Normal file
329
services/pipeline/translator/language_sync.py
Normal file
@ -0,0 +1,329 @@
|
||||
"""
|
||||
Language Sync Service
|
||||
기존 기사를 새로운 언어로 번역하는 백그라운드 서비스
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from typing import List, Dict, Any
|
||||
import httpx
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from datetime import datetime
|
||||
|
||||
# Add parent directory to path for shared module
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
# Import from shared module
|
||||
from shared.models import FinalArticle, Subtopic, Entities, NewsReference
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LanguageSyncService:
    """Background service that back-fills translations of existing articles.

    For every enabled language in ``languages.json`` the service compares the
    ``news_id`` values of the source collection with those of the language's
    target collection and translates the missing articles through the DeepL
    HTTP API.
    """

    def __init__(self):
        """Read connection settings from the environment and set sync defaults."""
        # Credentials must be supplied through the environment; no fallback
        # key is embedded in the source.
        self.deepl_api_key = os.getenv("DEEPL_API_KEY", "")
        self.deepl_api_url = "https://api.deepl.com/v2/translate"
        self.mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
        self.db_name = os.getenv("DB_NAME", "ai_writer_db")
        self.db = None
        self.languages_config = None
        self.config_path = "/app/config/languages.json"
        self.sync_batch_size = 10
        self.sync_delay = 2.0  # seconds to pause between languages and between batches

    def _has_api_key(self) -> bool:
        """Return True when a DeepL key is configured; log an error otherwise."""
        if self.deepl_api_key:
            return True
        logger.error("DeepL API key not configured")
        return False

    async def load_config(self):
        """Load the language configuration file into ``self.languages_config``.

        Raises:
            FileNotFoundError: If ``self.config_path`` does not exist.
            Exception: Any read or JSON parsing error, re-raised after logging.
        """
        try:
            if not os.path.exists(self.config_path):
                raise FileNotFoundError(f"Config file not found: {self.config_path}")
            with open(self.config_path, 'r', encoding='utf-8') as f:
                self.languages_config = json.load(f)
            logger.info("Loaded language config")
        except Exception as e:
            logger.error(f"Error loading config: {e}")
            raise

    async def start(self):
        """Run the background sync loop until the task is cancelled.

        A sync pass is attempted every 10 minutes; after an unexpected error
        the loop waits 1 minute before trying again.
        """
        logger.info("Starting Language Sync Service")

        await self.load_config()

        # Without a key every request would fail and untranslated text would
        # be stored in the target collections, so refuse to start.
        if not self._has_api_key():
            return

        client = AsyncIOMotorClient(self.mongodb_url)
        self.db = client[self.db_name]

        try:
            while True:
                try:
                    await self.sync_missing_translations()
                    await asyncio.sleep(600)  # wait 10 minutes between passes
                except Exception as e:
                    logger.error(f"Error in sync loop: {e}")
                    await asyncio.sleep(60)  # retry after 1 minute on error
        finally:
            # Release the MongoDB connection when the loop is cancelled.
            client.close()

    async def sync_missing_translations(self):
        """Sync missing translations for every enabled language.

        Errors are logged and not propagated.
        """
        try:
            enabled_languages = [
                lang for lang in self.languages_config["enabled_languages"]
                if lang["enabled"]
            ]

            if not enabled_languages:
                logger.info("No enabled languages for sync")
                return

            source_collection = self.languages_config["source_language"]["collection"]

            for lang_config in enabled_languages:
                await self.sync_language(source_collection, lang_config)
                await asyncio.sleep(self.sync_delay)

        except Exception as e:
            logger.error(f"Error in sync_missing_translations: {e}")

    async def _collect_news_ids(self, collection_name: str) -> set:
        """Return the set of ``news_id`` values stored in ``collection_name``."""
        documents = await self.db[collection_name].find({}, {"news_id": 1}).to_list(None)
        # Documents without a news_id cannot be matched across collections;
        # skip them instead of aborting the whole sync with a KeyError.
        return {doc["news_id"] for doc in documents if doc.get("news_id") is not None}

    async def sync_language(self, source_collection: str, lang_config: Dict):
        """Translate every source article that is missing for one language.

        Args:
            source_collection: Name of the collection holding source articles.
            lang_config: Language entry from the config; ``collection``,
                ``name``, ``code`` and ``deepl_code`` are read.

        Errors are logged and not propagated; a failure on one article does
        not stop the remaining articles.
        """
        try:
            target_collection = lang_config["collection"]

            # Articles present in the source but absent from the target.
            source_ids = await self._collect_news_ids(source_collection)
            translated_ids = await self._collect_news_ids(target_collection)
            missing_ids = source_ids - translated_ids

            if not missing_ids:
                logger.info(f"No missing translations for {lang_config['name']}")
                return

            logger.info(f"Found {len(missing_ids)} missing translations for {lang_config['name']}")

            missing_list = list(missing_ids)
            total = len(missing_list)
            for offset in range(0, total, self.sync_batch_size):
                for news_id in missing_list[offset:offset + self.sync_batch_size]:
                    try:
                        korean_article = await self.db[source_collection].find_one(
                            {"news_id": news_id}
                        )
                        if not korean_article:
                            continue

                        await self.translate_and_save(korean_article, lang_config)
                        logger.info(f"Synced article {news_id} to {lang_config['code']}")

                        # Throttle requests to the translation API.
                        await asyncio.sleep(0.5)

                    except Exception as e:
                        logger.error(f"Error translating {news_id} to {lang_config['code']}: {e}")
                        continue

                # Pause between batches, but not after the last one.
                if offset + self.sync_batch_size < total:
                    await asyncio.sleep(self.sync_delay)

        except Exception as e:
            logger.error(f"Error syncing language {lang_config['code']}: {e}")

    async def translate_and_save(self, korean_article: Dict, lang_config: Dict):
        """Translate one article and insert it into the language's collection.

        Title, summary, subtopics and categories are translated; entities,
        references and image fields are copied from the source article. After
        the insert, the language code is added to the source article's
        ``translated_languages`` set.

        Raises:
            Exception: Any model or database error, re-raised after logging.
        """
        try:
            target_lang = lang_config["deepl_code"]

            translated_title = await self._translate_text(
                korean_article.get('title', ''), target_lang=target_lang
            )
            translated_summary = await self._translate_text(
                korean_article.get('summary', ''), target_lang=target_lang
            )

            translated_subtopics = []
            for subtopic in korean_article.get('subtopics', []):
                subtopic_title = await self._translate_text(
                    subtopic.get('title', ''), target_lang=target_lang
                )
                paragraphs = [
                    await self._translate_text(paragraph, target_lang=target_lang)
                    for paragraph in subtopic.get('content', [])
                ]
                translated_subtopics.append(Subtopic(title=subtopic_title, content=paragraphs))

            translated_categories = [
                await self._translate_text(category, target_lang=target_lang)
                for category in korean_article.get('categories', [])
            ]

            # Entities and references are kept as in the source article.
            entities_data = korean_article.get('entities', {})
            translated_entities = Entities(**entities_data) if entities_data else Entities()
            references = [
                NewsReference(**ref_data)
                for ref_data in korean_article.get('references', [])
            ]

            translated_article = FinalArticle(
                news_id=korean_article.get('news_id'),
                title=translated_title,
                summary=translated_summary,
                subtopics=translated_subtopics,
                categories=translated_categories,
                entities=translated_entities,
                source_keyword=korean_article.get('source_keyword'),
                source_count=korean_article.get('source_count', 1),
                references=references,
                job_id=korean_article.get('job_id'),
                keyword_id=korean_article.get('keyword_id'),
                pipeline_stages=korean_article.get('pipeline_stages', []) + ['sync_translation'],
                processing_time=korean_article.get('processing_time', 0),
                language=lang_config["code"],
                ref_news_id=None,
                rss_guid=korean_article.get('rss_guid'),  # keep the RSS GUID
                image_prompt=korean_article.get('image_prompt'),  # keep the image prompt
                images=korean_article.get('images', []),  # keep the image URL list
                translated_languages=korean_article.get('translated_languages', []),
            )

            collection_name = lang_config["collection"]
            result = await self.db[collection_name].insert_one(translated_article.model_dump())

            # Record on the source article that this language now exists.
            await self.db[self.languages_config["source_language"]["collection"]].update_one(
                {"news_id": korean_article.get('news_id')},
                {"$addToSet": {"translated_languages": lang_config["code"]}},
            )

            logger.info(f"Synced article to {collection_name}: {result.inserted_id}")

        except Exception as e:
            logger.error(f"Error in translate_and_save: {e}")
            raise

    async def _translate_text(self, text: str, target_lang: str = 'EN') -> str:
        """Translate Korean ``text`` into ``target_lang`` with the DeepL API.

        Returns:
            The translated text; ``""`` for empty input; the original text
            when the request fails (best-effort, the failure is logged).
        """
        try:
            if not text:
                return ""

            async with httpx.AsyncClient() as client:
                response = await client.post(
                    self.deepl_api_url,
                    # Header-based authentication; the key is no longer sent
                    # in the form body.
                    headers={'Authorization': f'DeepL-Auth-Key {self.deepl_api_key}'},
                    data={
                        'text': text,
                        'target_lang': target_lang,
                        'source_lang': 'KO',
                    },
                    timeout=30,
                )

                if response.status_code == 200:
                    return response.json()['translations'][0]['text']

                logger.error(f"DeepL API error: {response.status_code}")
                return text

        except Exception as e:
            logger.error(f"Error translating text: {e}")
            return text

    async def manual_sync(self, language_code: str = None):
        """Run a single sync pass, for one language or for all enabled ones.

        Args:
            language_code: Code of the language to sync; ``None`` syncs every
                enabled language.
        """
        logger.info(f"Manual sync requested for language: {language_code or 'all'}")

        await self.load_config()

        if not self._has_api_key():
            return

        client = AsyncIOMotorClient(self.mongodb_url)
        self.db = client[self.db_name]

        try:
            if not language_code:
                # Sync every enabled language.
                await self.sync_missing_translations()
                return

            # Sync only the requested language, if it is enabled.
            lang_config = next(
                (lang for lang in self.languages_config["enabled_languages"]
                 if lang["code"] == language_code and lang["enabled"]),
                None,
            )
            if lang_config:
                source_collection = self.languages_config["source_language"]["collection"]
                await self.sync_language(source_collection, lang_config)
            else:
                logger.error(f"Language {language_code} not found or not enabled")
        finally:
            # One-shot run: release the MongoDB connection before returning.
            client.close()
|
||||
|
||||
async def main():
    """Command-line entry point.

    With no arguments the background sync loop is started. ``sync [code]``
    runs a single manual sync, optionally limited to one language code. Any
    other command is reported and the process exits with status 1.
    """
    service = LanguageSyncService()
    args = sys.argv[1:]

    if not args:
        # Background service mode.
        try:
            await service.start()
        except KeyboardInterrupt:
            logger.info("Received interrupt signal")
        return

    command = args[0]
    if command != "sync":
        logger.error(f"Unknown command: {command}")
        # Signal the failure to the caller instead of exiting with status 0.
        sys.exit(1)

    # Manual sync mode.
    language = args[1] if len(args) > 1 else None
    await service.manual_sync(language)


if __name__ == "__main__":
    asyncio.run(main())
|
||||
320
services/pipeline/translator/multi_translator.py
Normal file
320
services/pipeline/translator/multi_translator.py
Normal file
@ -0,0 +1,320 @@
|
||||
"""
|
||||
Multi-Language Translation Service
|
||||
다국어 번역 서비스 - 설정 기반 다중 언어 지원
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from typing import List, Dict, Any
|
||||
import httpx
|
||||
import redis.asyncio as redis
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from datetime import datetime
|
||||
|
||||
# Import from shared module
|
||||
from shared.models import PipelineJob, FinalArticle
|
||||
from shared.queue_manager import QueueManager
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MultiLanguageTranslator:
    """Queue worker that translates an article into every enabled language.

    Jobs are taken from the ``translation`` queue. For each job the source
    article is loaded from MongoDB, translated through the DeepL HTTP API,
    and stored in the collection configured for each enabled language.
    """

    def __init__(self):
        """Read connection settings from the environment."""
        self.queue_manager = QueueManager(
            redis_url=os.getenv("REDIS_URL", "redis://redis:6379")
        )
        # Credentials must be supplied through the environment; no fallback
        # key is embedded in the source. start() refuses to run without one.
        self.deepl_api_key = os.getenv("DEEPL_API_KEY", "")
        self.deepl_api_url = "https://api.deepl.com/v2/translate"
        self.mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
        self.db_name = os.getenv("DB_NAME", "ai_writer_db")
        self.db = None
        self.languages_config = None
        self.config_path = "/app/config/languages.json"

    async def load_config(self):
        """Load the language configuration, or fall back to English only.

        Raises:
            Exception: Any read or JSON parsing error, re-raised after logging.
        """
        try:
            if os.path.exists(self.config_path):
                with open(self.config_path, 'r', encoding='utf-8') as f:
                    self.languages_config = json.load(f)
            else:
                # Default configuration: English only.
                self.languages_config = {
                    "enabled_languages": [
                        {
                            "code": "en",
                            "name": "English",
                            "deepl_code": "EN",
                            "collection": "articles_en",
                            "enabled": True,
                        }
                    ],
                    "source_language": {
                        "code": "ko",
                        "name": "Korean",
                        "collection": "articles_ko",
                    },
                    "translation_settings": {
                        "batch_size": 5,
                        "delay_between_languages": 2.0,
                        "delay_between_articles": 0.5,
                        "max_retries": 3,
                    },
                }
            logger.info(f"Loaded language config: {len(self.get_enabled_languages())} languages enabled")
        except Exception as e:
            logger.error(f"Error loading config: {e}")
            raise

    def get_enabled_languages(self) -> List[Dict]:
        """Return the configured languages whose ``enabled`` flag is truthy."""
        return [lang for lang in self.languages_config["enabled_languages"] if lang["enabled"]]

    async def start(self):
        """Connect to Redis and MongoDB, then process queue jobs forever."""
        logger.info("Starting Multi-Language Translator Worker")

        await self.load_config()

        await self.queue_manager.connect()

        client = AsyncIOMotorClient(self.mongodb_url)
        self.db = client[self.db_name]

        if not self.deepl_api_key:
            logger.error("DeepL API key not configured")
            return

        while True:
            try:
                job = await self.queue_manager.dequeue('translation', timeout=5)
                if job:
                    await self.process_job(job)
            except Exception as e:
                logger.error(f"Error in worker loop: {e}")
                await asyncio.sleep(1)

    async def process_job(self, job: "PipelineJob"):
        """Translate the job's article into every enabled language.

        Languages that already have the article are skipped. If any language
        fails, the job is marked failed rather than completed; on a later
        attempt the languages that succeeded are skipped by the
        already-translated check.
        """
        try:
            logger.info(f"Processing job {job.job_id} for multi-language translation")

            news_id = job.data.get('news_id')
            if not news_id:
                logger.error(f"No news_id in job {job.job_id}")
                await self.queue_manager.mark_failed('translation', job, "No news_id")
                return

            # Load the source article.
            source_collection = self.languages_config["source_language"]["collection"]
            korean_article = await self.db[source_collection].find_one({"news_id": news_id})

            if not korean_article:
                logger.error(f"Article {news_id} not found in {source_collection}")
                await self.queue_manager.mark_failed('translation', job, "Article not found")
                return

            settings = self.languages_config["translation_settings"]
            failed_languages = []

            for lang_config in self.get_enabled_languages():
                try:
                    logger.info(f"Translating article {news_id} to {lang_config['name']}")

                    # Skip languages that already have this article.
                    existing = await self.db[lang_config["collection"]].find_one({"news_id": news_id})
                    if existing:
                        logger.info(f"Article {news_id} already translated to {lang_config['code']}")
                        continue

                    await self.translate_article(korean_article, lang_config, job)

                    # Pause between languages.
                    if settings.get("delay_between_languages"):
                        await asyncio.sleep(settings["delay_between_languages"])

                except Exception as e:
                    logger.error(f"Error translating to {lang_config['code']}: {e}")
                    failed_languages.append(lang_config['code'])
                    continue

            if failed_languages:
                # Do not report success when a translation could not be saved.
                await self.queue_manager.mark_failed(
                    'translation',
                    job,
                    f"Translation failed for languages: {', '.join(failed_languages)}",
                )
                return

            logger.info(f"Translation pipeline completed for news_id: {news_id}")

            job.stages_completed.append('translation')
            await self.queue_manager.mark_completed('translation', job.job_id)

            logger.info(f"Multi-language translation completed for job {job.job_id}")

        except Exception as e:
            logger.error(f"Error processing job {job.job_id}: {e}")
            await self.queue_manager.mark_failed('translation', job, str(e))

    async def translate_article(self, korean_article: Dict, lang_config: Dict, job: "PipelineJob"):
        """Translate one article into one language and store it.

        Title, summary, subtopics and categories are translated; entities,
        references and image fields are copied from the source article. After
        the insert, the language code is added to the source article's
        ``translated_languages`` set.

        Raises:
            Exception: Any model or database error, re-raised after logging.
        """
        from shared.models import Entities, NewsReference, Subtopic

        try:
            target_lang = lang_config["deepl_code"]
            # Read once; the value does not change while an article is processed.
            subtopic_delay = self.languages_config["translation_settings"].get("delay_between_articles")

            translated_title = await self._translate_text(
                korean_article.get('title', ''), target_lang=target_lang
            )
            translated_summary = await self._translate_text(
                korean_article.get('summary', ''), target_lang=target_lang
            )

            translated_subtopics = []
            for subtopic in korean_article.get('subtopics', []):
                subtopic_title = await self._translate_text(
                    subtopic.get('title', ''), target_lang=target_lang
                )
                paragraphs = [
                    await self._translate_text(paragraph, target_lang=target_lang)
                    for paragraph in subtopic.get('content', [])
                ]

                # Throttle requests to the translation API.
                if subtopic_delay:
                    await asyncio.sleep(subtopic_delay)

                translated_subtopics.append(Subtopic(title=subtopic_title, content=paragraphs))

            translated_categories = [
                await self._translate_text(category, target_lang=target_lang)
                for category in korean_article.get('categories', [])
            ]

            # Entities and references are kept as in the source article.
            entities_data = korean_article.get('entities', {})
            translated_entities = Entities(**entities_data) if entities_data else Entities()
            references = [
                NewsReference(**ref_data)
                for ref_data in korean_article.get('references', [])
            ]

            translated_article = FinalArticle(
                news_id=korean_article.get('news_id'),  # same news_id as the source
                title=translated_title,
                summary=translated_summary,
                subtopics=translated_subtopics,
                categories=translated_categories,
                entities=translated_entities,
                source_keyword=job.keyword if hasattr(job, 'keyword') else korean_article.get('source_keyword'),
                source_count=korean_article.get('source_count', 1),
                references=references,
                job_id=job.job_id,
                keyword_id=job.keyword_id if hasattr(job, 'keyword_id') else None,
                pipeline_stages=korean_article.get('pipeline_stages', []) + ['translation'],
                processing_time=korean_article.get('processing_time', 0),
                language=lang_config["code"],
                ref_news_id=None,  # not needed: the same news_id is reused
                rss_guid=korean_article.get('rss_guid'),  # keep the RSS GUID
                image_prompt=korean_article.get('image_prompt'),  # keep the image prompt
                images=korean_article.get('images', []),  # keep the image URL list
                translated_languages=korean_article.get('translated_languages', []),
            )

            collection_name = lang_config["collection"]
            result = await self.db[collection_name].insert_one(translated_article.model_dump())

            logger.info(f"Article saved to {collection_name} with _id: {result.inserted_id}, language: {lang_config['code']}")

            # Record on the source article that this language now exists.
            await self.db[self.languages_config["source_language"]["collection"]].update_one(
                {"news_id": korean_article.get('news_id')},
                {"$addToSet": {"translated_languages": lang_config["code"]}},
            )

        except Exception as e:
            logger.error(f"Error translating article to {lang_config['code']}: {e}")
            raise

    async def _translate_text(self, text: str, target_lang: str = 'EN') -> str:
        """Translate Korean ``text`` into ``target_lang`` with the DeepL API.

        Returns:
            The translated text; ``""`` for empty input; the original text
            when the request fails (best-effort, the failure is logged).
        """
        try:
            if not text:
                return ""

            async with httpx.AsyncClient() as client:
                response = await client.post(
                    self.deepl_api_url,
                    # Header-based authentication; the key is no longer sent
                    # in the form body.
                    headers={'Authorization': f'DeepL-Auth-Key {self.deepl_api_key}'},
                    data={
                        'text': text,
                        'target_lang': target_lang,
                        'source_lang': 'KO',
                    },
                    timeout=30,
                )

                if response.status_code == 200:
                    return response.json()['translations'][0]['text']

                logger.error(f"DeepL API error: {response.status_code}")
                return text  # fall back to the original text

        except Exception as e:
            logger.error(f"Error translating text: {e}")
            return text  # fall back to the original text

    async def stop(self):
        """Disconnect from the queue backend."""
        await self.queue_manager.disconnect()
        logger.info("Multi-Language Translator Worker stopped")
|
||||
|
||||
async def main():
    """Run the multi-language translator until interrupted, then shut it down."""
    translator = MultiLanguageTranslator()

    try:
        await translator.start()
    except KeyboardInterrupt:
        logger.info("Received interrupt signal")
    finally:
        # Always release the queue connection, however start() ended.
        await translator.stop()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
5
services/pipeline/translator/requirements.txt
Normal file
5
services/pipeline/translator/requirements.txt
Normal file
@ -0,0 +1,5 @@
|
||||
httpx==0.25.0
|
||||
redis[hiredis]==5.0.1
|
||||
pydantic==2.5.0
|
||||
motor==3.1.1
|
||||
pymongo==4.3.3
|
||||
230
services/pipeline/translator/translator.py
Normal file
230
services/pipeline/translator/translator.py
Normal file
@ -0,0 +1,230 @@
|
||||
"""
|
||||
Translation Service
|
||||
DeepL API를 사용한 번역 서비스
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Dict, Any
|
||||
import httpx
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from datetime import datetime
|
||||
|
||||
# Import from shared module
|
||||
from shared.models import PipelineJob, FinalArticle
|
||||
from shared.queue_manager import QueueManager
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TranslatorWorker:
    """Queue worker that creates the English version of a Korean article.

    Jobs are taken from the ``translation`` queue. The article is read from
    ``articles_ko``, translated through the DeepL HTTP API, and stored in
    ``articles_en`` under the same ``news_id``.
    """

    def __init__(self):
        """Read connection settings from the environment."""
        self.queue_manager = QueueManager(
            redis_url=os.getenv("REDIS_URL", "redis://redis:6379")
        )
        # Credentials must be supplied through the environment; no fallback
        # key is embedded in the source. start() refuses to run without one.
        self.deepl_api_key = os.getenv("DEEPL_API_KEY", "")
        # DeepL Pro API endpoint.
        self.deepl_api_url = "https://api.deepl.com/v2/translate"
        self.mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
        self.db_name = os.getenv("DB_NAME", "ai_writer_db")
        self.db = None

    async def start(self):
        """Connect to Redis and MongoDB, then process queue jobs forever."""
        logger.info("Starting Translator Worker")

        await self.queue_manager.connect()

        client = AsyncIOMotorClient(self.mongodb_url)
        self.db = client[self.db_name]

        if not self.deepl_api_key:
            logger.error("DeepL API key not configured")
            return

        while True:
            try:
                job = await self.queue_manager.dequeue('translation', timeout=5)
                if job:
                    await self.process_job(job)
            except Exception as e:
                logger.error(f"Error in worker loop: {e}")
                await asyncio.sleep(1)

    async def process_job(self, job: "PipelineJob"):
        """Create and store the English version of the job's article.

        If ``articles_en`` already holds the ``news_id`` (for example after a
        retried job), translation is skipped so no duplicate is inserted. Any
        error marks the job as failed.
        """
        try:
            logger.info(f"Processing job {job.job_id} for translation")

            news_id = job.data.get('news_id')
            if not news_id:
                logger.error(f"No news_id in job {job.job_id}")
                await self.queue_manager.mark_failed('translation', job, "No news_id")
                return

            # Load the Korean source article.
            korean_article = await self.db.articles_ko.find_one({"news_id": news_id})
            if not korean_article:
                logger.error(f"Article {news_id} not found in MongoDB")
                await self.queue_manager.mark_failed('translation', job, "Article not found")
                return

            # A retried job must not insert a second English copy.
            existing = await self.db.articles_en.find_one({"news_id": news_id})
            if existing:
                logger.info(f"Article {news_id} already translated to en")
            else:
                await self._translate_and_store(korean_article, news_id, job)

            # Record on the Korean article that translation has run.
            await self.db.articles_ko.update_one(
                {"news_id": news_id},
                {"$addToSet": {"pipeline_stages": "translation"}},
            )

            job.stages_completed.append('translation')
            await self.queue_manager.mark_completed('translation', job.job_id)

            logger.info(f"Translation completed for job {job.job_id}")

        except Exception as e:
            logger.error(f"Error processing job {job.job_id}: {e}")
            await self.queue_manager.mark_failed('translation', job, str(e))

    async def _translate_and_store(self, korean_article: Dict, news_id, job: "PipelineJob"):
        """Translate ``korean_article`` to English and insert it into ``articles_en``."""
        from shared.models import Entities, NewsReference, Subtopic

        translated_title = await self._translate_text(
            korean_article.get('title', ''), target_lang='EN'
        )
        translated_summary = await self._translate_text(
            korean_article.get('summary', ''), target_lang='EN'
        )

        translated_subtopics = []
        for subtopic in korean_article.get('subtopics', []):
            subtopic_title = await self._translate_text(
                subtopic.get('title', ''), target_lang='EN'
            )

            paragraphs = []
            for content_para in subtopic.get('content', []):
                paragraphs.append(await self._translate_text(content_para, target_lang='EN'))
                await asyncio.sleep(0.2)  # throttle requests to the API

            translated_subtopics.append(Subtopic(title=subtopic_title, content=paragraphs))

        translated_categories = []
        for category in korean_article.get('categories', []):
            translated_categories.append(await self._translate_text(category, target_lang='EN'))
            await asyncio.sleep(0.2)  # throttle requests to the API

        # Entity names (people, organizations, ...) are copied untranslated.
        # A stored null is treated like a missing value.
        entities_data = korean_article.get('entities') or {}
        translated_entities = Entities(
            people=entities_data.get('people', []),
            organizations=entities_data.get('organizations', []),
            groups=entities_data.get('groups', []),
            countries=entities_data.get('countries', []),
            events=entities_data.get('events', []),
        )

        # References are copied untranslated.
        references = [
            NewsReference(**ref_data)
            for ref_data in korean_article.get('references', [])
        ]

        english_article = FinalArticle(
            news_id=news_id,  # same news_id as the source article
            title=translated_title,
            summary=translated_summary,
            subtopics=translated_subtopics,
            categories=translated_categories,
            entities=translated_entities,
            source_keyword=job.keyword,
            source_count=korean_article.get('source_count', 1),
            references=references,
            job_id=job.job_id,
            keyword_id=job.keyword_id,
            pipeline_stages=job.stages_completed.copy() + ['translation'],
            processing_time=korean_article.get('processing_time', 0),
            language='en',
            ref_news_id=None,  # not needed: the same news_id is reused
        )

        result = await self.db.articles_en.insert_one(english_article.model_dump())
        english_article_id = str(result.inserted_id)

        logger.info(f"English article saved with _id: {english_article_id}, news_id: {news_id}, language: en")

    async def _translate_text(self, text: str, target_lang: str = 'EN') -> str:
        """Translate Korean ``text`` into ``target_lang`` with the DeepL API.

        Returns:
            The translated text; ``""`` for empty input; the original text
            when the request fails (best-effort, the failure is logged).
        """
        try:
            if not text:
                return ""

            async with httpx.AsyncClient() as client:
                response = await client.post(
                    self.deepl_api_url,
                    # Header-based authentication; the key is no longer sent
                    # in the form body.
                    headers={'Authorization': f'DeepL-Auth-Key {self.deepl_api_key}'},
                    data={
                        'text': text,
                        'target_lang': target_lang,
                        'source_lang': 'KO',
                    },
                    timeout=30,
                )

                if response.status_code == 200:
                    return response.json()['translations'][0]['text']

                logger.error(f"DeepL API error: {response.status_code}")
                return text  # fall back to the original text

        except Exception as e:
            logger.error(f"Error translating text: {e}")
            return text  # fall back to the original text

    async def stop(self):
        """Disconnect from the queue backend."""
        await self.queue_manager.disconnect()
        logger.info("Translator Worker stopped")
|
||||
|
||||
async def main():
    """Run the translator worker until interrupted, then shut it down."""
    translator_worker = TranslatorWorker()

    try:
        await translator_worker.start()
    except KeyboardInterrupt:
        logger.info("Received interrupt signal")
    finally:
        # Always release the queue connection, however start() ended.
        await translator_worker.stop()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
21
services/search/backend/Dockerfile
Normal file
21
services/search/backend/Dockerfile
Normal file
@ -0,0 +1,21 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements first for better caching
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy application code
|
||||
COPY . .
|
||||
|
||||
# Create necessary directories
|
||||
RUN mkdir -p /app/logs
|
||||
|
||||
# Run the application
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
|
||||
286
services/search/backend/indexer.py
Normal file
286
services/search/backend/indexer.py
Normal file
@ -0,0 +1,286 @@
|
||||
"""
|
||||
Data indexer for synchronizing data from other services to Solr
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Any, List
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from aiokafka import AIOKafkaConsumer
|
||||
import json
|
||||
from solr_client import SolrClient
|
||||
from datetime import datetime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DataIndexer:
    """Keep the Solr index in sync with MongoDB and Kafka events.

    ``start()`` launches two background tasks: one consumes the
    ``user_events``, ``file_events`` and ``content_events`` Kafka topics and
    indexes/deletes the matching Solr documents; the other re-syncs users and
    files from MongoDB at a fixed interval. ``stop()`` cancels both tasks and
    closes the Kafka and MongoDB connections.
    """

    # Number of documents sent to Solr per bulk request during a sync.
    BULK_BATCH_SIZE = 100
    # Seconds between two periodic full syncs.
    SYNC_INTERVAL_SECONDS = 300

    def __init__(self, solr_client: "SolrClient", mongodb_url: str, kafka_servers: str):
        """Store the configuration; no connection is opened until start()."""
        self.solr = solr_client
        self.mongodb_url = mongodb_url
        self.kafka_servers = kafka_servers
        self.mongo_client = None
        self.kafka_consumer = None
        self.running = False
        # Strong references to the background tasks: the event loop only keeps
        # weak references, so an unreferenced task may be garbage-collected
        # before it finishes.
        self._tasks = []

    async def start(self):
        """Connect to MongoDB and Kafka and launch the background tasks.

        Failures are logged rather than raised so that the hosting service
        can still start without the indexer.
        """
        try:
            # Connect to MongoDB
            self.mongo_client = AsyncIOMotorClient(self.mongodb_url)

            # Initialize Kafka consumer
            await self._init_kafka_consumer()

            # Start background tasks
            self.running = True
            self._tasks = [
                asyncio.create_task(self._consume_kafka_events()),
                asyncio.create_task(self._periodic_sync()),
            ]

            logger.info("Data indexer started")

        except Exception as e:
            logger.error(f"Failed to start indexer: {e}")

    async def stop(self):
        """Cancel the background tasks and close external connections."""
        self.running = False

        # Cancel before closing the connections so that no task keeps using a
        # closed MongoDB client or Kafka consumer.
        tasks, self._tasks = self._tasks, []
        for task in tasks:
            task.cancel()
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

        if self.kafka_consumer:
            await self.kafka_consumer.stop()

        if self.mongo_client:
            self.mongo_client.close()

        logger.info("Data indexer stopped")

    async def _init_kafka_consumer(self):
        """Create and start the Kafka consumer.

        On failure the consumer is left as ``None`` and only a warning is
        logged; event consumption is then skipped.
        """
        try:
            self.kafka_consumer = AIOKafkaConsumer(
                'user_events',
                'file_events',
                'content_events',
                bootstrap_servers=self.kafka_servers,
                value_deserializer=lambda m: json.loads(m.decode('utf-8')),
                group_id='search_indexer',
                auto_offset_reset='latest'
            )
            await self.kafka_consumer.start()
            logger.info("Kafka consumer initialized")

        except Exception as e:
            logger.warning(f"Kafka consumer initialization failed: {e}")
            self.kafka_consumer = None

    async def _consume_kafka_events(self):
        """Consume events from Kafka and index them until stopped."""
        if not self.kafka_consumer:
            return

        while self.running:
            try:
                async for msg in self.kafka_consumer:
                    await self._handle_kafka_event(msg.topic, msg.value)

            except Exception as e:
                logger.error(f"Kafka consumption error: {e}")
                # Back off before retrying so a persistent failure does not
                # turn into a busy loop.
                await asyncio.sleep(5)

    async def _handle_kafka_event(self, topic: str, event: Dict[str, Any]):
        """Dispatch one decoded Kafka event to the handler for its topic.

        Events on unknown topics are ignored; handler errors are logged.
        """
        try:
            event_type = event.get('type')
            data = event.get('data', {})

            handlers = {
                'user_events': self._index_user_event,
                'file_events': self._index_file_event,
                'content_events': self._index_content_event,
            }
            handler = handlers.get(topic)
            if handler is not None:
                await handler(event_type, data)

        except Exception as e:
            logger.error(f"Failed to handle event: {e}")

    async def _index_user_event(self, event_type: str, data: Dict):
        """Index user-related events (create/update -> index, delete -> remove)."""
        if event_type in ('user_created', 'user_updated'):
            user_doc = {
                'id': f"user_{data.get('user_id')}",
                'doc_type': 'user',
                'user_id': data.get('user_id'),
                'username': data.get('username'),
                'email': data.get('email'),
                'name': data.get('name', ''),
                'bio': data.get('bio', ''),
                'tags': data.get('tags', []),
                'created_at': data.get('created_at'),
                'updated_at': datetime.utcnow().isoformat()
            }
            self.solr.index_document(user_doc)

        elif event_type == 'user_deleted':
            self.solr.delete_document(f"user_{data.get('user_id')}")

    async def _index_file_event(self, event_type: str, data: Dict):
        """Index file-related events (upload -> index, delete -> remove)."""
        if event_type == 'file_uploaded':
            file_doc = {
                'id': f"file_{data.get('file_id')}",
                'doc_type': 'file',
                'file_id': data.get('file_id'),
                'filename': data.get('filename'),
                'content_type': data.get('content_type'),
                'size': data.get('size'),
                'user_id': data.get('user_id'),
                'tags': data.get('tags', []),
                'description': data.get('description', ''),
                'created_at': data.get('created_at'),
                'updated_at': datetime.utcnow().isoformat()
            }
            self.solr.index_document(file_doc)

        elif event_type == 'file_deleted':
            self.solr.delete_document(f"file_{data.get('file_id')}")

    async def _index_content_event(self, event_type: str, data: Dict):
        """Index content-related events (create/update -> index, delete -> remove)."""
        if event_type in ('content_created', 'content_updated'):
            content_doc = {
                'id': f"content_{data.get('content_id')}",
                'doc_type': 'content',
                'content_id': data.get('content_id'),
                'title': data.get('title'),
                'content': data.get('content', ''),
                'summary': data.get('summary', ''),
                'author_id': data.get('author_id'),
                'tags': data.get('tags', []),
                'category': data.get('category'),
                'status': data.get('status', 'draft'),
                'created_at': data.get('created_at'),
                'updated_at': datetime.utcnow().isoformat()
            }
            self.solr.index_document(content_doc)

        elif event_type == 'content_deleted':
            self.solr.delete_document(f"content_{data.get('content_id')}")

    async def _periodic_sync(self):
        """Run a full MongoDB -> Solr sync every SYNC_INTERVAL_SECONDS."""
        while self.running:
            try:
                await asyncio.sleep(self.SYNC_INTERVAL_SECONDS)
                await self.sync_all_data()

            except Exception as e:
                logger.error(f"Periodic sync error: {e}")

    async def sync_all_data(self):
        """Sync users and files from MongoDB to Solr, then optimize the index."""
        try:
            logger.info("Starting full data sync")

            # Sync users
            await self._sync_users()

            # Sync files
            await self._sync_files()

            # Optimize index after bulk sync
            self.solr.optimize_index()

            logger.info("Full data sync completed")

        except Exception as e:
            logger.error(f"Full sync failed: {e}")

    @staticmethod
    def _to_iso(value):
        """Return ``value.isoformat()`` for date-like values, else ``value``.

        Falsy values become ``None``. Values without ``isoformat`` (for
        example timestamps already stored as strings) are passed through
        instead of raising.
        """
        if not value:
            return None
        return value.isoformat() if hasattr(value, 'isoformat') else value

    async def _bulk_index_cursor(self, cursor, build_doc, doc_type: str) -> int:
        """Convert every record of *cursor* and bulk-index the results in batches.

        Returns the number of documents handed to Solr.
        """
        batch: List[Dict[str, Any]] = []
        total = 0

        async for record in cursor:
            batch.append(build_doc(record))

            # Flush a full batch so memory use stays bounded.
            if len(batch) >= self.BULK_BATCH_SIZE:
                self.solr.bulk_index(batch, doc_type)
                total += len(batch)
                batch = []

        # Index whatever is left over.
        if batch:
            self.solr.bulk_index(batch, doc_type)
            total += len(batch)

        return total

    async def _sync_users(self):
        """Sync non-deleted users from ``users_db.users`` to Solr."""
        try:
            collection = self.mongo_client['users_db']['users']

            def build(user):
                user_id = str(user['_id'])
                return {
                    'id': f"user_{user_id}",
                    'doc_type': 'user',
                    'user_id': user_id,
                    'username': user.get('username'),
                    'email': user.get('email'),
                    'name': user.get('name', ''),
                    'bio': user.get('bio', ''),
                    'tags': user.get('tags', []),
                    'created_at': self._to_iso(user.get('created_at')),
                    'updated_at': datetime.utcnow().isoformat()
                }

            count = await self._bulk_index_cursor(
                collection.find({'deleted_at': None}), build, 'user'
            )
            logger.info(f"Synced {count} users to Solr")

        except Exception as e:
            logger.error(f"Failed to sync users: {e}")

    async def _sync_files(self):
        """Sync non-deleted files from ``files_db.file_metadata`` to Solr."""
        try:
            collection = self.mongo_client['files_db']['file_metadata']

            def build(file):
                file_id = str(file['_id'])
                # Tags are read as the keys of a mapping, but a plain list (as
                # used in the Kafka events) is accepted too, so that one
                # unexpected record cannot abort the whole sync.
                tags = file.get('tags') or {}
                tag_names = list(tags.keys()) if isinstance(tags, dict) else list(tags)
                metadata = file.get('metadata') or {}
                return {
                    'id': f"file_{file_id}",
                    'doc_type': 'file',
                    'file_id': file_id,
                    'filename': file.get('filename'),
                    'original_name': file.get('original_name'),
                    'content_type': file.get('content_type'),
                    'size': file.get('size'),
                    'user_id': file.get('user_id'),
                    'tags': tag_names,
                    'description': metadata.get('description', ''),
                    'created_at': self._to_iso(file.get('created_at')),
                    'updated_at': datetime.utcnow().isoformat()
                }

            count = await self._bulk_index_cursor(
                collection.find({'deleted_at': None}), build, 'file'
            )
            logger.info(f"Synced {count} files to Solr")

        except Exception as e:
            logger.error(f"Failed to sync files: {e}")

    async def reindex_collection(self, collection_name: str, doc_type: str):
        """Drop all Solr documents of *doc_type* and rebuild them from MongoDB.

        Only collections that can actually be re-synced (``users`` and
        ``files``) are accepted. Any other name is rejected *before* anything
        is deleted, because the deleted documents could not be restored.
        Errors are logged, not raised.
        """
        try:
            syncers = {
                'users': self._sync_users,
                'files': self._sync_files,
            }
            sync = syncers.get(collection_name)
            if sync is None:
                logger.error(
                    f"Failed to reindex {collection_name}: unsupported collection"
                )
                return

            # Delete existing documents of this type
            self.solr.delete_by_query(f'doc_type:{doc_type}')

            # Sync the collection
            await sync()

            logger.info(f"Reindexed {collection_name}")

        except Exception as e:
            logger.error(f"Failed to reindex {collection_name}: {e}")
|
||||
362
services/search/backend/main.py
Normal file
362
services/search/backend/main.py
Normal file
@ -0,0 +1,362 @@
|
||||
"""
|
||||
Search Service with Apache Solr
|
||||
"""
|
||||
from fastapi import FastAPI, Query, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
from solr_client import SolrClient
|
||||
from indexer import DataIndexer
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
# Configure logging
# Root logger at INFO; each record carries timestamp, logger name and level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Global instances
# Both stay None until lifespan() assigns them; the endpoints answer 503
# while the one they need is still unset.
solr_client = None
data_indexer = None
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifecycle.

    Startup: wait for Solr (up to 30 attempts, 2 s apart), then start the
    data indexer and kick off an initial full sync in the background.
    Shutdown: cancel a still-running initial sync and stop the indexer.
    If Solr never becomes reachable the service still starts, and the search
    endpoints answer 503.
    """
    global solr_client, data_indexer

    # Startup
    logger.info("Starting Search Service...")

    # Wait for Solr to be ready
    solr_url = os.getenv("SOLR_URL", "http://solr:8983/solr")
    max_retries = 30
    # Held for the whole application lifetime: the event loop keeps only weak
    # references to tasks, so an unreferenced task may be collected mid-run.
    initial_sync_task = None

    for i in range(max_retries):
        try:
            solr_client = SolrClient(solr_url=solr_url, core_name="site11")
            logger.info("Connected to Solr")
            break
        except Exception:
            logger.warning(f"Waiting for Solr... ({i+1}/{max_retries})")
            await asyncio.sleep(2)
    else:
        logger.error(f"Solr not reachable after {max_retries} attempts; search is unavailable")

    if solr_client:
        # Initialize data indexer
        mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
        kafka_servers = os.getenv("KAFKA_BOOTSTRAP_SERVERS", "kafka:9092")

        data_indexer = DataIndexer(solr_client, mongodb_url, kafka_servers)
        await data_indexer.start()

        # Initial data sync
        initial_sync_task = asyncio.create_task(data_indexer.sync_all_data())

    yield

    # Shutdown
    # Stop the initial sync first so it does not keep using connections that
    # the indexer is about to close.
    if initial_sync_task is not None and not initial_sync_task.done():
        initial_sync_task.cancel()
        await asyncio.gather(initial_sync_task, return_exceptions=True)

    if data_indexer:
        await data_indexer.stop()

    logger.info("Search Service stopped")
|
||||
|
||||
# The ASGI application; startup and shutdown are handled by lifespan().
app = FastAPI(
    title="Search Service",
    description="Full-text search with Apache Solr",
    version="1.0.0",
    lifespan=lifespan
)
|
||||
|
||||
@app.get("/health")
async def health_check():
    """Report liveness plus whether a Solr client has been created."""
    solr_ready = solr_client is not None
    report = {"status": "healthy", "service": "search"}
    report["timestamp"] = datetime.utcnow().isoformat()
    report["solr_connected"] = solr_ready
    return report
|
||||
|
||||
@app.get("/api/search")
async def search(
    q: str = Query(..., description="Search query"),
    doc_type: Optional[str] = Query(None, description="Filter by document type"),
    start: int = Query(0, ge=0, description="Starting offset"),
    rows: int = Query(10, ge=1, le=100, description="Number of results"),
    sort: Optional[str] = Query(None, description="Sort order (e.g., 'created_at desc')"),
    facet: bool = Query(False, description="Enable faceting"),
    facet_field: Optional[List[str]] = Query(None, description="Fields to facet on")
):
    """
    Search documents across all indexed content.

    Returns the matching documents together with the total hit count and,
    when present, facets and highlighting.

    Raises:
        HTTPException 503: no Solr client is available.
        HTTPException 500: the Solr query failed.
    """
    if not solr_client:
        raise HTTPException(status_code=503, detail="Search service unavailable")

    try:
        # Build filter query
        fq = []
        if doc_type:
            fq.append(f"doc_type:{doc_type}")

        # Prepare search parameters
        search_params = {
            'start': start,
            'rows': rows,
            'facet': facet
        }

        if fq:
            search_params['fq'] = fq

        if sort:
            search_params['sort'] = sort

        if facet_field:
            search_params['facet_field'] = facet_field

        # Execute search
        results = solr_client.search(q, **search_params)

        # The client reports failures through an 'error' key instead of
        # raising; surface them rather than answering 200 with zero hits.
        if 'error' in results:
            raise HTTPException(status_code=500, detail=results['error'])

        return {
            "query": q,
            "total": results['total'],
            "start": start,
            "rows": rows,
            "documents": results['documents'],
            "facets": results.get('facets', {}),
            "highlighting": results.get('highlighting', {})
        }

    except HTTPException:
        # Already a proper HTTP error; do not re-wrap it below.
        raise
    except Exception as e:
        logger.error(f"Search failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/api/search/suggest")
async def suggest(
    q: str = Query(..., min_length=1, description="Query prefix"),
    field: str = Query("title", description="Field to search in"),
    limit: int = Query(10, ge=1, le=50, description="Maximum suggestions")
):
    """
    Return autocomplete suggestions for the prefix *q* taken from *field*.

    Answers 503 without a Solr client and 500 if the lookup raises.
    """
    if solr_client is None or not solr_client:
        raise HTTPException(status_code=503, detail="Search service unavailable")

    try:
        found = solr_client.suggest(q, field, limit)
        payload = {"query": q, "suggestions": found}
        payload["count"] = len(found)
        return payload
    except Exception as e:
        logger.error(f"Suggest failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/api/search/similar/{doc_id}")
async def find_similar(
    doc_id: str,
    rows: int = Query(5, ge=1, le=20, description="Number of similar documents")
):
    """
    Return up to *rows* documents that Solr considers similar to *doc_id*.

    Answers 503 without a Solr client and 500 if the lookup raises.
    """
    if solr_client is None or not solr_client:
        raise HTTPException(status_code=503, detail="Search service unavailable")

    try:
        matches = solr_client.more_like_this(doc_id, rows=rows)
        payload = {"source_document": doc_id, "similar_documents": matches}
        payload["count"] = len(matches)
        return payload
    except Exception as e:
        logger.error(f"Similar search failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/api/search/index")
async def index_document(document: Dict[str, Any]):
    """
    Index a single document.

    The document's own ``doc_type`` is used, defaulting to ``general``.

    Raises:
        HTTPException 503: no Solr client is available.
        HTTPException 500: Solr rejected the document or the call failed.
    """
    if not solr_client:
        raise HTTPException(status_code=503, detail="Search service unavailable")

    try:
        doc_type = document.get('doc_type', 'general')
        success = solr_client.index_document(document, doc_type)

        if success:
            return {
                "status": "success",
                "message": "Document indexed",
                "document_id": document.get('id')
            }
        else:
            raise HTTPException(status_code=500, detail="Failed to index document")

    except HTTPException:
        # Re-raise untouched: the generic handler below would otherwise
        # replace the detail with str(exc), i.e. "500: Failed to ...".
        raise
    except Exception as e:
        logger.error(f"Indexing failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/api/search/bulk-index")
async def bulk_index(documents: List[Dict[str, Any]]):
    """
    Bulk index multiple documents.

    Raises:
        HTTPException 503: no Solr client is available.
        HTTPException 500: a non-empty batch could not be indexed.
    """
    if not solr_client:
        raise HTTPException(status_code=503, detail="Search service unavailable")

    try:
        indexed = solr_client.bulk_index(documents)

        # The client returns 0 when the batch failed; only an empty request
        # may legitimately index nothing.
        if documents and not indexed:
            raise HTTPException(status_code=500, detail="Failed to index documents")

        return {
            "status": "success",
            "message": f"Indexed {indexed} documents",
            "count": indexed
        }

    except HTTPException:
        # Already a proper HTTP error; do not re-wrap it below.
        raise
    except Exception as e:
        logger.error(f"Bulk indexing failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.delete("/api/search/document/{doc_id}")
async def delete_document(doc_id: str):
    """
    Delete a document from the index.

    Raises:
        HTTPException 503: no Solr client is available.
        HTTPException 500: the deletion failed.
    """
    if not solr_client:
        raise HTTPException(status_code=503, detail="Search service unavailable")

    try:
        success = solr_client.delete_document(doc_id)

        if success:
            return {
                "status": "success",
                "message": "Document deleted",
                "document_id": doc_id
            }
        else:
            raise HTTPException(status_code=500, detail="Failed to delete document")

    except HTTPException:
        # Re-raise untouched so the original detail reaches the client.
        raise
    except Exception as e:
        logger.error(f"Deletion failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/api/search/stats")
async def get_stats():
    """
    Get search index statistics.

    Raises:
        HTTPException 503: no Solr client is available.
        HTTPException 500: the statistics could not be collected.
    """
    if not solr_client:
        raise HTTPException(status_code=503, detail="Search service unavailable")

    try:
        stats = solr_client.get_stats()

        # The client signals failure with an 'error' key instead of raising;
        # do not report that as a successful response.
        if 'error' in stats:
            raise HTTPException(status_code=500, detail=stats['error'])

        return {
            "status": "success",
            "statistics": stats,
            "timestamp": datetime.utcnow().isoformat()
        }

    except HTTPException:
        # Already a proper HTTP error; do not re-wrap it below.
        raise
    except Exception as e:
        logger.error(f"Failed to get stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# Strong references to running reindex tasks. The event loop keeps only weak
# references, so a fire-and-forget task could otherwise be garbage-collected
# before it completes. Each task removes itself when done.
_reindex_tasks = set()


@app.post("/api/search/reindex/{collection}")
async def reindex_collection(
    collection: str,
    doc_type: Optional[str] = Query(None, description="Document type for the collection")
):
    """
    Start reindexing a collection in the background.

    When *doc_type* is not given it is derived from the collection name
    (``users`` -> ``user``, ``files`` -> ``file``, ``content`` -> ``content``;
    any other name is used as-is). The response is sent as soon as the task
    is scheduled; it does not wait for the reindex to finish.

    Raises:
        HTTPException 503: the data indexer is not available.
        HTTPException 500: the task could not be scheduled.
    """
    if not data_indexer:
        raise HTTPException(status_code=503, detail="Indexer service unavailable")

    try:
        if not doc_type:
            # Map collection to doc_type
            doc_type_map = {
                'users': 'user',
                'files': 'file',
                'content': 'content'
            }
            doc_type = doc_type_map.get(collection, collection)

        task = asyncio.create_task(data_indexer.reindex_collection(collection, doc_type))
        _reindex_tasks.add(task)
        task.add_done_callback(_reindex_tasks.discard)

        return {
            "status": "success",
            "message": f"Reindexing {collection} started",
            "collection": collection,
            "doc_type": doc_type
        }

    except Exception as e:
        logger.error(f"Reindex failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/api/search/optimize")
async def optimize_index():
    """
    Optimize the search index.

    Raises:
        HTTPException 503: no Solr client is available.
        HTTPException 500: the optimization failed.
    """
    if not solr_client:
        raise HTTPException(status_code=503, detail="Search service unavailable")

    try:
        success = solr_client.optimize_index()

        if success:
            return {
                "status": "success",
                "message": "Index optimization started"
            }
        else:
            raise HTTPException(status_code=500, detail="Failed to optimize index")

    except HTTPException:
        # Re-raise untouched so the original detail reaches the client.
        raise
    except Exception as e:
        logger.error(f"Optimization failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/api/search/clear")
async def clear_index():
    """
    Clear all documents from the index (DANGER!).

    Raises:
        HTTPException 503: no Solr client is available.
        HTTPException 500: the index could not be cleared.
    """
    if not solr_client:
        raise HTTPException(status_code=503, detail="Search service unavailable")

    try:
        success = solr_client.clear_index()

        if success:
            return {
                "status": "success",
                "message": "Index cleared",
                "warning": "All documents have been deleted!"
            }
        else:
            raise HTTPException(status_code=500, detail="Failed to clear index")

    except HTTPException:
        # Re-raise untouched so the original detail reaches the client.
        raise
    except Exception as e:
        logger.error(f"Clear index failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# Direct execution: serve the app with uvicorn on all interfaces, port 8000.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
10
services/search/backend/requirements.txt
Normal file
10
services/search/backend/requirements.txt
Normal file
@ -0,0 +1,10 @@
|
||||
fastapi==0.109.0
|
||||
uvicorn[standard]==0.27.0
|
||||
pydantic==2.5.3
|
||||
python-dotenv==1.0.0
|
||||
pysolr==3.9.0
|
||||
httpx==0.25.2
|
||||
motor==3.5.1
|
||||
pymongo==4.6.1
|
||||
aiokafka==0.10.0
|
||||
redis==5.0.1
|
||||
303
services/search/backend/solr_client.py
Normal file
303
services/search/backend/solr_client.py
Normal file
@ -0,0 +1,303 @@
|
||||
"""
|
||||
Apache Solr client for search operations
|
||||
"""
|
||||
import pysolr
|
||||
import logging
|
||||
from typing import Dict, List, Any, Optional
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SolrClient:
    """Wrapper around a ``pysolr.Solr`` connection to a single core.

    Every operation catches its own errors, logs them and returns a neutral
    value (``False``, ``0``, ``[]`` or a dict carrying an ``error`` key)
    instead of raising. Only ``connect()`` (and therefore the constructor)
    propagates failures.
    """

    # Characters with a special meaning in the Solr/Lucene query syntax, plus
    # the space; they are backslash-escaped before user text is put into a
    # query string.
    _QUERY_SPECIAL_CHARS = frozenset('+-!(){}[]^"~*?:\\/&| ')

    def __init__(self, solr_url: str = "http://solr:8983/solr", core_name: str = "site11"):
        """Build the core URL and connect immediately (raises on failure)."""
        self.solr_url = f"{solr_url}/{core_name}"
        self.core_name = core_name
        self.solr = None
        self.connect()

    def connect(self):
        """Connect to the Solr core and verify it with a ping.

        Raises:
            Exception: whatever pysolr raises when the core is unreachable.
        """
        try:
            self.solr = pysolr.Solr(
                self.solr_url,
                always_commit=True,
                timeout=10
            )
            # Test connection
            self.solr.ping()
            logger.info(f"Connected to Solr at {self.solr_url}")
        except Exception as e:
            logger.error(f"Failed to connect to Solr: {e}")
            raise

    @classmethod
    def _escape_term(cls, text) -> str:
        """Return *text* with query-syntax characters backslash-escaped."""
        return ''.join(
            '\\' + ch if ch in cls._QUERY_SPECIAL_CHARS else ch
            for ch in str(text)
        )

    @staticmethod
    def _pair_up(flat):
        """Turn Solr's flat ``[value, count, value, count, ...]`` list into pairs.

        A dangling last element without a count is ignored.
        """
        return zip(flat[0::2], flat[1::2])

    @staticmethod
    def _stamp(document: Dict[str, Any], doc_type: Optional[str]) -> None:
        """Add ``doc_type``, a fallback ``id`` and ``indexed_at`` to *document* in place."""
        if doc_type:
            document["doc_type"] = doc_type

        if "id" not in document:
            # Without an explicit doc_type, use the document's own one so the
            # generated id is not prefixed with the literal text "None".
            prefix = doc_type if doc_type else document.get("doc_type")
            document["id"] = f"{prefix}_{document.get('_id', '')}"

        # Add indexing timestamp
        document["indexed_at"] = datetime.utcnow().isoformat()

    def index_document(self, document: Dict[str, Any], doc_type: Optional[str] = None) -> bool:
        """Index a single document; *document* is modified in place.

        Returns:
            True on success, False if indexing failed (the error is logged).
        """
        try:
            self._stamp(document, doc_type)

            # Index the document
            self.solr.add([document])
            logger.info(f"Indexed document: {document.get('id')}")
            return True

        except Exception as e:
            logger.error(f"Failed to index document: {e}")
            return False

    def bulk_index(self, documents: List[Dict[str, Any]], doc_type: Optional[str] = None) -> int:
        """Index several documents in one request; they are modified in place.

        Returns:
            The number of documents sent, or 0 if the list was empty or the
            request failed (the error is logged).
        """
        try:
            if not documents:
                # Nothing to send; avoid an empty update request.
                return 0

            for doc in documents:
                self._stamp(doc, doc_type)

            self.solr.add(documents)
            indexed = len(documents)
            logger.info(f"Bulk indexed {indexed} documents")
            return indexed

        except Exception as e:
            logger.error(f"Failed to bulk index: {e}")
            return 0

    def search(self, query: str, **kwargs) -> Dict[str, Any]:
        """
        Search documents with the edismax parser and highlighting enabled.

        Args:
            query: Search query string
            **kwargs: Additional search parameters
                - fq: Filter queries
                - fl: Fields to return
                - start: Starting offset
                - rows: Number of rows
                - sort: Sort order
                - facet: Enable faceting
                - facet_field: Fields to facet on

        Returns:
            A dict with ``total``, ``documents``, ``facets`` and
            ``highlighting``. On failure: ``total`` 0, no documents and an
            ``error`` message.
        """
        try:
            # Default parameters
            params = {
                'q': query,
                'start': kwargs.get('start', 0),
                'rows': kwargs.get('rows', 10),
                'fl': kwargs.get('fl', '*,score'),
                'defType': 'edismax',
                'qf': 'title^3 content^2 tags description name',  # Boost fields
                'mm': '2<-25%',  # Minimum match
                'hl': 'true',  # Highlighting
                'hl.fl': 'title,content,description',
                'hl.simple.pre': '<mark>',
                'hl.simple.post': '</mark>'
            }

            # Optional filter queries and sorting are passed through as given.
            for key in ('fq', 'sort'):
                if key in kwargs:
                    params[key] = kwargs[key]

            # Add faceting
            if kwargs.get('facet'):
                params.update({
                    'facet': 'true',
                    'facet.field': kwargs.get('facet_field', ['doc_type', 'tags', 'status']),
                    'facet.mincount': 1
                })

            # Execute search
            results = self.solr.search(**params)

            # Format response
            response = {
                'total': results.hits,
                'documents': list(results.docs),
                'facets': {},
                'highlighting': {}
            }

            # Add facets if available
            facet_fields = (getattr(results, 'facets', None) or {}).get('facet_fields', {})
            for field, values in facet_fields.items():
                response['facets'][field] = [
                    {'value': value, 'count': count}
                    for value, count in self._pair_up(values)
                ]

            # Add highlighting if available
            if hasattr(results, 'highlighting'):
                response['highlighting'] = results.highlighting

            return response

        except Exception as e:
            logger.error(f"Search failed: {e}")
            return {'total': 0, 'documents': [], 'error': str(e)}

    def suggest(self, prefix: str, field: str = "suggest", limit: int = 10) -> List[str]:
        """Get autocomplete suggestions via a prefix query on *field*.

        *prefix* is escaped so that characters such as ``:``, ``+`` or spaces
        are matched literally instead of being parsed as query syntax.

        Returns:
            Up to *limit* distinct field values in result order, or an empty
            list on failure.
        """
        try:
            params = {
                'q': f'{field}:{self._escape_term(prefix)}*',
                'fl': field,
                'rows': limit,
                'start': 0
            }

            results = self.solr.search(**params)

            # Flatten: a field value is either a single value or a list.
            suggestions = []
            for doc in results.docs:
                if field in doc:
                    value = doc[field]
                    if isinstance(value, list):
                        suggestions.extend(value)
                    else:
                        suggestions.append(value)

            # Remove duplicates and limit
            seen = set()
            unique_suggestions = []
            for s in suggestions:
                if s not in seen:
                    seen.add(s)
                    unique_suggestions.append(s)
                    if len(unique_suggestions) >= limit:
                        break

            return unique_suggestions

        except Exception as e:
            logger.error(f"Suggest failed: {e}")
            return []

    def more_like_this(self, doc_id: str, mlt_fields: List[str] = None, rows: int = 5) -> List[Dict]:
        """Find documents similar to the one with id *doc_id*.

        Args:
            doc_id: Id of the source document.
            mlt_fields: Fields compared for similarity; defaults to title,
                content, tags and description.
            rows: Maximum number of similar documents.

        Returns:
            The similar documents, or an empty list if the source document
            does not exist, nothing similar was found or the query failed.
        """
        try:
            if not mlt_fields:
                mlt_fields = ['title', 'content', 'tags', 'description']

            params = {
                'q': f'id:{self._escape_term(doc_id)}',
                'mlt': 'true',
                'mlt.fl': ','.join(mlt_fields),
                'mlt.mindf': 1,
                'mlt.mintf': 1,
                'mlt.count': rows,
                'fl': '*,score'
            }

            results = self.solr.search(**params)

            if results.docs:
                # pysolr's result object does not expose the MoreLikeThis
                # section as an attribute, so read it from the raw response.
                raw = getattr(results, 'raw_response', None) or {}
                mlt_section = raw.get('moreLikeThis') or {}
                # NOTE(review): assumed that some Solr versions serialize this
                # section as a flat [id, result, ...] list - confirm.
                if isinstance(mlt_section, list):
                    mlt_section = dict(self._pair_up(mlt_section))

                mlt_results = mlt_section.get(doc_id) or {}
                if 'docs' in mlt_results:
                    return mlt_results['docs']

            return []

        except Exception as e:
            logger.error(f"More like this failed: {e}")
            return []

    def delete_document(self, doc_id: str) -> bool:
        """Delete a document by ID; returns False (and logs) on failure."""
        try:
            self.solr.delete(id=doc_id)
            logger.info(f"Deleted document: {doc_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to delete document: {e}")
            return False

    def delete_by_query(self, query: str) -> bool:
        """Delete documents matching a query; returns False (and logs) on failure."""
        try:
            self.solr.delete(q=query)
            logger.info(f"Deleted documents matching: {query}")
            return True
        except Exception as e:
            logger.error(f"Failed to delete by query: {e}")
            return False

    def clear_index(self) -> bool:
        """Clear all documents from index; returns False (and logs) on failure."""
        try:
            self.solr.delete(q='*:*')
            logger.info("Cleared all documents from index")
            return True
        except Exception as e:
            logger.error(f"Failed to clear index: {e}")
            return False

    def get_stats(self) -> Dict[str, Any]:
        """Get index statistics.

        Returns:
            ``total_documents`` plus per-``doc_type`` and per-``status``
            counts, or ``{'error': message}`` on failure.
        """
        try:
            # One facet query yields both the total hit count and the
            # per-field counts.
            facet_results = self.solr.search(
                q='*:*',
                rows=0,
                facet='true',
                **{'facet.field': ['doc_type', 'status']}
            )

            stats = {
                'total_documents': facet_results.hits,
                'doc_types': {},
                'status_counts': {}
            }

            facet_fields = (getattr(facet_results, 'facets', None) or {}).get('facet_fields', {})

            # Parse doc_type facets
            stats['doc_types'].update(self._pair_up(facet_fields.get('doc_type', [])))

            # Parse status facets
            stats['status_counts'].update(self._pair_up(facet_fields.get('status', [])))

            return stats

        except Exception as e:
            logger.error(f"Failed to get stats: {e}")
            return {'error': str(e)}

    def optimize_index(self) -> bool:
        """Optimize the Solr index; returns False (and logs) on failure."""
        try:
            self.solr.optimize()
            logger.info("Index optimized")
            return True
        except Exception as e:
            logger.error(f"Failed to optimize index: {e}")
            return False
|
||||
292
services/search/backend/test_search.py
Normal file
292
services/search/backend/test_search.py
Normal file
@ -0,0 +1,292 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for Search Service with Apache Solr
|
||||
"""
|
||||
import asyncio
|
||||
import httpx
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
BASE_URL = "http://localhost:8015"
|
||||
|
||||
async def test_search_api():
    """Walk through the Search Service HTTP API end to end.

    Against the service at ``BASE_URL`` this runs, in order: a health check,
    single and bulk indexing of sample documents, basic / filtered / faceted
    queries, suggestions, similar-document lookup, a highlighted query, index
    statistics, a complex query, and a delete followed by a verification
    query. Each step prints what the service returned; nothing is asserted.
    """
    search_url = f"{BASE_URL}/api/search"
    bulk_url = f"{search_url}/bulk-index"

    def timestamp():
        # Naive UTC "now" in ISO-8601 form, stamped on every sample document.
        return datetime.utcnow().isoformat()

    async with httpx.AsyncClient() as client:

        async def run_query(**params):
            # GET the main search endpoint and hand back the decoded JSON body.
            reply = await client.get(search_url, params=params)
            return reply.json()

        print("\n🔍 Testing Search Service API...")

        # 1. Health check
        print("\n1. Testing health check...")
        health = await client.get(f"{BASE_URL}/health")
        print(f"Health check: {health.json()}")

        # 2. Index sample documents
        print("\n2. Indexing sample documents...")

        # One user document through the single-document endpoint.
        user_doc = {
            "id": "user_test_001",
            "doc_type": "user",
            "user_id": "test_001",
            "username": "john_doe",
            "email": "john@example.com",
            "name": "John Doe",
            "bio": "Software developer passionate about Python and microservices",
            "tags": ["python", "developer", "backend"],
            "created_at": timestamp()
        }
        indexed_user = await client.post(f"{search_url}/index", json=user_doc)
        print(f"Indexed user: {indexed_user.json()}")

        # Two file documents through the bulk endpoint.
        file_docs = [
            {
                "id": "file_test_001",
                "doc_type": "file",
                "file_id": "test_file_001",
                "filename": "architecture_diagram.png",
                "content_type": "image/png",
                "size": 1024000,
                "user_id": "test_001",
                "tags": ["architecture", "design", "documentation"],
                "description": "System architecture diagram showing microservices",
                "created_at": timestamp()
            },
            {
                "id": "file_test_002",
                "doc_type": "file",
                "file_id": "test_file_002",
                "filename": "user_manual.pdf",
                "content_type": "application/pdf",
                "size": 2048000,
                "user_id": "test_001",
                "tags": ["documentation", "manual", "guide"],
                "description": "Complete user manual for the application",
                "created_at": timestamp()
            }
        ]
        indexed_files = await client.post(bulk_url, json=file_docs)
        print(f"Bulk indexed files: {indexed_files.json()}")

        # Two content documents through the bulk endpoint.
        content_docs = [
            {
                "id": "content_test_001",
                "doc_type": "content",
                "content_id": "test_content_001",
                "title": "Getting Started with Microservices",
                "content": "Microservices architecture is a method of developing software applications as a suite of independently deployable services.",
                "summary": "Introduction to microservices architecture patterns",
                "author_id": "test_001",
                "tags": ["microservices", "architecture", "tutorial"],
                "category": "technology",
                "status": "published",
                "created_at": timestamp()
            },
            {
                "id": "content_test_002",
                "doc_type": "content",
                "content_id": "test_content_002",
                "title": "Python Best Practices",
                "content": "Learn the best practices for writing clean, maintainable Python code including PEP 8 style guide.",
                "summary": "Essential Python coding standards and practices",
                "author_id": "test_001",
                "tags": ["python", "programming", "best-practices"],
                "category": "programming",
                "status": "published",
                "created_at": timestamp()
            }
        ]
        indexed_content = await client.post(bulk_url, json=content_docs)
        print(f"Bulk indexed content: {indexed_content.json()}")

        # Give the service a moment before querying what was just indexed.
        await asyncio.sleep(2)

        # 3. Basic search
        print("\n3. Testing basic search...")
        basic = await run_query(q="microservices")
        print(f"Search for 'microservices': Found {basic['total']} results")
        if basic['documents']:
            top_hit = basic['documents'][0]
            print(f"First result: {top_hit.get('title', top_hit.get('filename', 'N/A'))}")

        # 4. Search restricted to one document type
        print("\n4. Testing filtered search...")
        only_files = await run_query(q="*:*", doc_type="file", rows=5)
        print(f"Files search: Found {only_files['total']} files")

        # 5. Faceted search
        print("\n5. Testing faceted search...")
        faceted = await run_query(
            q="*:*",
            facet="true",
            facet_field=["doc_type", "tags", "category", "status"],
        )
        print(f"Facets: {json.dumps(faceted['facets'], indent=2)}")

        # 6. Autocomplete / suggest
        print("\n6. Testing autocomplete...")
        suggest_reply = await client.get(
            f"{search_url}/suggest",
            params={"q": "micro", "field": "title", "limit": 5},
        )
        suggestions = suggest_reply.json()
        print(f"Suggestions for 'micro': {suggestions['suggestions']}")

        # 7. Similar documents
        print("\n7. Testing similar documents...")
        similar_reply = await client.get(f"{search_url}/similar/content_test_001")
        if similar_reply.status_code != 200:
            print(f"Similar search: {similar_reply.status_code}")
        else:
            similar = similar_reply.json()
            print(f"Found {similar['count']} similar documents")

        # 8. Highlighting
        print("\n8. Testing search with highlighting...")
        highlighted = await run_query(q="Python")
        if highlighted['highlighting']:
            print(f"Highlighting results: {len(highlighted['highlighting'])} documents highlighted")

        # 9. Index statistics
        print("\n9. Testing search statistics...")
        stats_reply = await client.get(f"{search_url}/stats")
        if stats_reply.status_code == 200:
            stats = stats_reply.json()
            print(f"Index stats: {stats['statistics']}")

        # 10. Boolean query with filter and sort
        print("\n10. Testing complex query...")
        complex_hits = await run_query(
            q="architecture OR python",
            doc_type="content",
            sort="created_at desc",
            rows=10,
        )
        print(f"Complex query: Found {complex_hits['total']} results")

        # 11. Delete one document, then confirm it can no longer be found
        print("\n11. Testing document deletion...")
        delete_reply = await client.delete(f"{search_url}/document/content_test_002")
        if delete_reply.status_code == 200:
            print(f"Deleted document: {delete_reply.json()}")

        await asyncio.sleep(1)
        leftover = await run_query(q="id:content_test_002")
        print(f"Verify deletion: Found {leftover['total']} results (should be 0)")
|
||||
|
||||
async def test_performance():
    """Bulk-index 100 documents, then time a handful of searches.

    Posts 100 generated ``content`` documents to the bulk-index endpoint,
    waits two seconds, and then issues five single-term queries, printing
    the hit count and the elapsed time of each request.
    """
    # Imported locally, as in the original script, so the module-level
    # import block is left untouched.
    import time

    print("\n\n⚡ Testing Search Performance...")

    async with httpx.AsyncClient(timeout=30.0) as client:
        # Index many documents
        print("Indexing 100 test documents...")
        docs = [
            {
                "id": f"perf_test_{i}",
                "doc_type": "content",
                "title": f"Test Document {i}",
                "content": f"This is test content for document {i} with various keywords like search, Solr, Python, microservices",
                "tags": [f"tag{i%10}", f"category{i%5}"],
                "created_at": datetime.utcnow().isoformat()
            }
            for i in range(100)
        ]

        response = await client.post(f"{BASE_URL}/api/search/bulk-index", json=docs)
        print(f"Indexed {response.json().get('count', 0)} documents")

        # Wait for indexing
        await asyncio.sleep(2)

        # Test search speed
        print("\nTesting search response times...")

        queries = ["search", "Python", "document", "test", "microservices"]
        for query in queries:
            # perf_counter() is monotonic and high-resolution; time.time() is
            # a wall clock that can jump and skew (or negate) the interval.
            start = time.perf_counter()
            response = await client.get(
                f"{BASE_URL}/api/search",
                params={"q": query, "rows": 20}
            )
            elapsed = time.perf_counter() - start
            results = response.json()
            print(f"Query '{query}': {results['total']} results in {elapsed:.3f}s")
|
||||
|
||||
async def test_reindex():
    """Trigger a reindex of the ``users`` collection, then an index optimize.

    Both calls are POSTs to the search service; the outcome of each is
    printed rather than asserted.
    """
    print("\n\n🔄 Testing Reindex Functionality...")

    api_root = f"{BASE_URL}/api/search"

    async with httpx.AsyncClient() as client:
        # Kick off a reindex of the users collection as "user" documents.
        print("Triggering reindex for users collection...")
        reindex_reply = await client.post(
            f"{api_root}/reindex/users",
            params={"doc_type": "user"},
        )
        if reindex_reply.status_code != 200:
            print(f"Reindex failed: {reindex_reply.status_code}")
        else:
            print(f"Reindex started: {reindex_reply.json()}")

        # Ask the service to optimize its index; only a 200 is reported.
        print("\nTesting index optimization...")
        optimize_reply = await client.post(f"{api_root}/optimize")
        if optimize_reply.status_code == 200:
            print(f"Optimization: {optimize_reply.json()}")
|
||||
|
||||
async def main():
    """Run the three test coroutines in order, framed by a banner.

    Prints a header with the start time, awaits the API, performance and
    reindex tests sequentially, then prints a footer with the finish time.
    """
    rule = "=" * 60

    print(rule)
    print("SEARCH SERVICE TEST SUITE (Apache Solr)")
    print(rule)
    print(f"Started at: {datetime.now().isoformat()}")

    # Sequential on purpose: each stage runs only after the previous one ends.
    for stage in (test_search_api, test_performance, test_reindex):
        await stage()

    print("\n" + rule)
    print("✅ All search tests completed!")
    print(f"Finished at: {datetime.now().isoformat()}")
    print(rule)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
105
services/search/solr-config/conf/managed-schema.xml
Normal file
105
services/search/solr-config/conf/managed-schema.xml
Normal file
@ -0,0 +1,105 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<schema name="site11" version="1.6">
|
||||
<!-- Field Types -->
|
||||
<fieldType name="string" class="solr.StrField" sortMissingLast="true" omitNorms="true"/>
|
||||
<fieldType name="boolean" class="solr.BoolField" sortMissingLast="true" omitNorms="true"/>
|
||||
<fieldType name="int" class="solr.IntPointField" omitNorms="true"/>
|
||||
<fieldType name="long" class="solr.LongPointField" omitNorms="true"/>
|
||||
<fieldType name="float" class="solr.FloatPointField" omitNorms="true"/>
|
||||
<fieldType name="double" class="solr.DoublePointField" omitNorms="true"/>
|
||||
<fieldType name="date" class="solr.DatePointField" omitNorms="true"/>
|
||||
|
||||
<!-- Text field with analysis -->
|
||||
<fieldType name="text_general" class="solr.TextField" positionIncrementGap="100">
|
||||
<analyzer type="index">
|
||||
<tokenizer class="solr.StandardTokenizerFactory"/>
|
||||
<filter class="solr.StopFilterFactory" ignoreCase="true" words="stopwords.txt"/>
|
||||
<filter class="solr.LowerCaseFilterFactory"/>
|
||||
<filter class="solr.EdgeNGramFilterFactory" minGramSize="2" maxGramSize="15"/>
|
||||
</analyzer>
|
||||
<analyzer type="query">
|
||||
<tokenizer class="solr.StandardTokenizerFactory"/>
|
||||
<filter class="solr.StopFilterFactory" ignoreCase="true" words="stopwords.txt"/>
|
||||
<filter class="solr.SynonymGraphFilterFactory" synonyms="synonyms.txt" ignoreCase="true" expand="true"/>
|
||||
<filter class="solr.LowerCaseFilterFactory"/>
|
||||
</analyzer>
|
||||
</fieldType>
|
||||
|
||||
<!-- Text field for exact matching -->
|
||||
<fieldType name="text_exact" class="solr.TextField">
|
||||
<analyzer>
|
||||
<tokenizer class="solr.KeywordTokenizerFactory"/>
|
||||
<filter class="solr.LowerCaseFilterFactory"/>
|
||||
</analyzer>
|
||||
</fieldType>
|
||||
|
||||
<!-- Autocomplete/Suggest field -->
|
||||
<fieldType name="text_suggest" class="solr.TextField" positionIncrementGap="100">
|
||||
<analyzer>
|
||||
<tokenizer class="solr.StandardTokenizerFactory"/>
|
||||
<filter class="solr.LowerCaseFilterFactory"/>
|
||||
<filter class="solr.EdgeNGramFilterFactory" minGramSize="1" maxGramSize="20"/>
|
||||
</analyzer>
|
||||
</fieldType>
|
||||
|
||||
<!-- Fields -->
|
||||
<field name="id" type="string" indexed="true" stored="true" required="true"/>
|
||||
<field name="_version_" type="long" indexed="true" stored="true"/>
|
||||
|
||||
<!-- Document type and metadata -->
|
||||
<field name="doc_type" type="string" indexed="true" stored="true" docValues="true"/>
|
||||
<field name="indexed_at" type="date" indexed="true" stored="true"/>
|
||||
|
||||
<!-- Common fields across document types -->
|
||||
<field name="title" type="text_general" indexed="true" stored="true" termVectors="true"/>
|
||||
<field name="content" type="text_general" indexed="true" stored="true" termVectors="true"/>
|
||||
<field name="description" type="text_general" indexed="true" stored="true"/>
|
||||
<field name="summary" type="text_general" indexed="true" stored="true"/>
|
||||
<field name="tags" type="string" indexed="true" stored="true" multiValued="true" docValues="true"/>
|
||||
<field name="category" type="string" indexed="true" stored="true" docValues="true"/>
|
||||
<field name="status" type="string" indexed="true" stored="true" docValues="true"/>
|
||||
|
||||
<!-- User-specific fields -->
|
||||
<field name="user_id" type="string" indexed="true" stored="true"/>
|
||||
<field name="username" type="text_exact" indexed="true" stored="true"/>
|
||||
<field name="email" type="text_exact" indexed="true" stored="true"/>
|
||||
<field name="name" type="text_general" indexed="true" stored="true"/>
|
||||
<field name="bio" type="text_general" indexed="true" stored="true"/>
|
||||
|
||||
<!-- File-specific fields -->
|
||||
<field name="file_id" type="string" indexed="true" stored="true"/>
|
||||
<field name="filename" type="text_general" indexed="true" stored="true"/>
|
||||
<field name="original_name" type="text_general" indexed="true" stored="true"/>
|
||||
<field name="content_type" type="string" indexed="true" stored="true" docValues="true"/>
|
||||
<field name="size" type="long" indexed="true" stored="true"/>
|
||||
|
||||
<!-- Content-specific fields -->
|
||||
<field name="content_id" type="string" indexed="true" stored="true"/>
|
||||
<field name="author_id" type="string" indexed="true" stored="true"/>
|
||||
|
||||
<!-- Dates -->
|
||||
<field name="created_at" type="date" indexed="true" stored="true"/>
|
||||
<field name="updated_at" type="date" indexed="true" stored="true"/>
|
||||
|
||||
<!-- Suggest field for autocomplete -->
|
||||
<field name="suggest" type="text_suggest" indexed="true" stored="false" multiValued="true"/>
|
||||
|
||||
<!-- Copy fields for better search -->
|
||||
<copyField source="title" dest="suggest"/>
|
||||
<copyField source="name" dest="suggest"/>
|
||||
<copyField source="filename" dest="suggest"/>
|
||||
<copyField source="tags" dest="suggest"/>
|
||||
|
||||
<!-- Dynamic fields -->
|
||||
<dynamicField name="*_i" type="int" indexed="true" stored="true"/>
|
||||
<dynamicField name="*_l" type="long" indexed="true" stored="true"/>
|
||||
<dynamicField name="*_f" type="float" indexed="true" stored="true"/>
|
||||
<dynamicField name="*_d" type="double" indexed="true" stored="true"/>
|
||||
<dynamicField name="*_s" type="string" indexed="true" stored="true"/>
|
||||
<dynamicField name="*_t" type="text_general" indexed="true" stored="true"/>
|
||||
<dynamicField name="*_dt" type="date" indexed="true" stored="true"/>
|
||||
<dynamicField name="*_b" type="boolean" indexed="true" stored="true"/>
|
||||
|
||||
<!-- Unique Key -->
|
||||
<uniqueKey>id</uniqueKey>
|
||||
</schema>
|
||||
152
services/search/solr-config/conf/solrconfig.xml
Normal file
152
services/search/solr-config/conf/solrconfig.xml
Normal file
@ -0,0 +1,152 @@
|
||||
<?xml version="1.0" encoding="UTF-8" ?>
|
||||
<config>
|
||||
<luceneMatchVersion>9.4.0</luceneMatchVersion>
|
||||
|
||||
<!-- Data Directory -->
|
||||
<dataDir>${solr.data.dir:}</dataDir>
|
||||
|
||||
<!-- Index Config -->
|
||||
<indexConfig>
|
||||
<ramBufferSizeMB>100</ramBufferSizeMB>
|
||||
<maxBufferedDocs>1000</maxBufferedDocs>
|
||||
<mergePolicyFactory class="org.apache.solr.index.TieredMergePolicyFactory">
|
||||
<int name="maxMergeAtOnce">10</int>
|
||||
<int name="segmentsPerTier">10</int>
|
||||
</mergePolicyFactory>
|
||||
</indexConfig>
|
||||
|
||||
<!-- Update Handler -->
|
||||
<updateHandler class="solr.DirectUpdateHandler2">
|
||||
<updateLog>
|
||||
<str name="dir">${solr.ulog.dir:}</str>
|
||||
<int name="numVersionBuckets">${solr.ulog.numVersionBuckets:65536}</int>
|
||||
</updateLog>
|
||||
<autoCommit>
|
||||
<maxTime>${solr.autoCommit.maxTime:15000}</maxTime>
|
||||
<openSearcher>false</openSearcher>
|
||||
</autoCommit>
|
||||
<autoSoftCommit>
|
||||
<maxTime>${solr.autoSoftCommit.maxTime:1000}</maxTime>
|
||||
</autoSoftCommit>
|
||||
</updateHandler>
|
||||
|
||||
<!-- Query Settings -->
|
||||
<query>
|
||||
<maxBooleanClauses>1024</maxBooleanClauses>
|
||||
<filterCache class="solr.CaffeineCache" size="512" initialSize="512" autowarmCount="0"/>
|
||||
<queryResultCache class="solr.CaffeineCache" size="512" initialSize="512" autowarmCount="0"/>
|
||||
<documentCache class="solr.CaffeineCache" size="512" initialSize="512" autowarmCount="0"/>
|
||||
<enableLazyFieldLoading>true</enableLazyFieldLoading>
|
||||
<queryResultWindowSize>20</queryResultWindowSize>
|
||||
<queryResultMaxDocsCached>200</queryResultMaxDocsCached>
|
||||
</query>
|
||||
|
||||
<!-- Request Dispatcher -->
|
||||
<requestDispatcher>
|
||||
<requestParsers enableRemoteStreaming="true" multipartUploadLimitInKB="2048000"
|
||||
formdataUploadLimitInKB="2048" addHttpRequestToContext="false"/>
|
||||
<httpCaching never304="true"/>
|
||||
</requestDispatcher>
|
||||
|
||||
<!-- Request Handlers -->
|
||||
|
||||
<!-- Standard search handler -->
|
||||
<requestHandler name="/select" class="solr.SearchHandler">
|
||||
<lst name="defaults">
|
||||
<str name="echoParams">explicit</str>
|
||||
<int name="rows">10</int>
|
||||
<str name="df">content</str>
|
||||
<str name="q.op">OR</str>
|
||||
<str name="defType">edismax</str>
|
||||
<str name="qf">
|
||||
title^3.0 name^2.5 content^2.0 description^1.5 summary^1.5
|
||||
filename^1.5 tags^1.2 category username email bio
|
||||
</str>
|
||||
<str name="pf">
|
||||
title^4.0 name^3.0 content^2.5 description^2.0
|
||||
</str>
|
||||
      <str name="mm">2&lt;-25%</str>
|
||||
<str name="hl">true</str>
|
||||
<str name="hl.fl">title,content,description,summary</str>
|
||||
      <str name="hl.simple.pre">&lt;mark&gt;</str>
|
||||
      <str name="hl.simple.post">&lt;/mark&gt;</str>
|
||||
<str name="facet">true</str>
|
||||
<str name="facet.mincount">1</str>
|
||||
</lst>
|
||||
</requestHandler>
|
||||
|
||||
<!-- Update handler -->
|
||||
<requestHandler name="/update" class="solr.UpdateRequestHandler"/>
|
||||
|
||||
<!-- Get handler -->
|
||||
<requestHandler name="/get" class="solr.RealTimeGetHandler">
|
||||
<lst name="defaults">
|
||||
<str name="omitHeader">true</str>
|
||||
</lst>
|
||||
</requestHandler>
|
||||
|
||||
<!-- Admin handlers -->
|
||||
<requestHandler name="/admin/ping" class="solr.PingRequestHandler">
|
||||
<lst name="invariants">
|
||||
<str name="q">solrpingquery</str>
|
||||
</lst>
|
||||
<lst name="defaults">
|
||||
<str name="echoParams">all</str>
|
||||
</lst>
|
||||
</requestHandler>
|
||||
|
||||
<!-- Suggest/Autocomplete handler -->
|
||||
<requestHandler name="/suggest" class="solr.SearchHandler">
|
||||
<lst name="defaults">
|
||||
<str name="suggest">true</str>
|
||||
<str name="suggest.count">10</str>
|
||||
<str name="suggest.dictionary">suggest</str>
|
||||
</lst>
|
||||
<arr name="components">
|
||||
<str>suggest</str>
|
||||
</arr>
|
||||
</requestHandler>
|
||||
|
||||
<!-- Spell check component -->
|
||||
<searchComponent name="spellcheck" class="solr.SpellCheckComponent">
|
||||
<str name="queryAnalyzerFieldType">text_general</str>
|
||||
<lst name="spellchecker">
|
||||
<str name="name">default</str>
|
||||
<str name="field">content</str>
|
||||
<str name="classname">solr.DirectSolrSpellChecker</str>
|
||||
<str name="distanceMeasure">internal</str>
|
||||
<float name="accuracy">0.5</float>
|
||||
<int name="maxEdits">2</int>
|
||||
<int name="minPrefix">1</int>
|
||||
<int name="maxInspections">5</int>
|
||||
<int name="minQueryLength">4</int>
|
||||
<float name="maxQueryFrequency">0.01</float>
|
||||
</lst>
|
||||
</searchComponent>
|
||||
|
||||
<!-- Suggest component -->
|
||||
<searchComponent name="suggest" class="solr.SuggestComponent">
|
||||
<lst name="suggester">
|
||||
<str name="name">suggest</str>
|
||||
<str name="lookupImpl">FuzzyLookupFactory</str>
|
||||
<str name="dictionaryImpl">DocumentDictionaryFactory</str>
|
||||
<str name="field">suggest</str>
|
||||
<str name="suggestAnalyzerFieldType">text_suggest</str>
|
||||
<str name="buildOnStartup">false</str>
|
||||
</lst>
|
||||
</searchComponent>
|
||||
|
||||
<!-- More Like This handler -->
|
||||
<requestHandler name="/mlt" class="solr.MoreLikeThisHandler">
|
||||
<lst name="defaults">
|
||||
<str name="mlt.fl">title,content,description,tags</str>
|
||||
<int name="mlt.mindf">1</int>
|
||||
<int name="mlt.mintf">1</int>
|
||||
<int name="mlt.count">10</int>
|
||||
</lst>
|
||||
</requestHandler>
|
||||
|
||||
<!-- Schema handler (removed for Solr 9.x compatibility) -->
|
||||
|
||||
<!-- Config handler (removed for Solr 9.x compatibility) -->
|
||||
</config>
|
||||
35
services/search/solr-config/conf/stopwords.txt
Normal file
35
services/search/solr-config/conf/stopwords.txt
Normal file
@ -0,0 +1,35 @@
|
||||
# Licensed to the Apache Software Foundation (ASF)
|
||||
# Standard English stop words
|
||||
a
|
||||
an
|
||||
and
|
||||
are
|
||||
as
|
||||
at
|
||||
be
|
||||
but
|
||||
by
|
||||
for
|
||||
if
|
||||
in
|
||||
into
|
||||
is
|
||||
it
|
||||
no
|
||||
not
|
||||
of
|
||||
on
|
||||
or
|
||||
such
|
||||
that
|
||||
the
|
||||
their
|
||||
then
|
||||
there
|
||||
these
|
||||
they
|
||||
this
|
||||
to
|
||||
was
|
||||
will
|
||||
with
|
||||
38
services/search/solr-config/conf/synonyms.txt
Normal file
38
services/search/solr-config/conf/synonyms.txt
Normal file
@ -0,0 +1,38 @@
|
||||
# Synonyms for site11 search
|
||||
# Format: term1, term2, term3 (comma-separated terms are treated as equivalent synonyms)
|
||||
# Or: term1, term2 => term3 (explicit mapping: term1 and term2 are replaced by term3)
|
||||
|
||||
# Technology synonyms
|
||||
javascript, js
|
||||
typescript, ts
|
||||
python, py
|
||||
golang, go
|
||||
database, db
|
||||
kubernetes, k8s
|
||||
docker, container, containerization
|
||||
|
||||
# Common terms
|
||||
search, find, query, lookup
|
||||
upload, import, add
|
||||
download, export, get
|
||||
delete, remove, erase
|
||||
update, modify, edit, change
|
||||
create, make, new, add
|
||||
|
||||
# File related
|
||||
document, doc, file
|
||||
image, picture, photo, img
|
||||
video, movie, clip
|
||||
audio, sound, music
|
||||
|
||||
# User related
|
||||
user, member, account
|
||||
admin, administrator, moderator
|
||||
profile, account, user
|
||||
|
||||
# Status
|
||||
active, enabled, live
|
||||
inactive, disabled, offline
|
||||
pending, waiting, processing
|
||||
complete, done, finished
|
||||
error, failed, failure
|
||||
21
services/statistics/backend/Dockerfile
Normal file
21
services/statistics/backend/Dockerfile
Normal file
@ -0,0 +1,21 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements first for better caching
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy application code
|
||||
COPY . .
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
# Run the application
|
||||
CMD ["python", "main.py"]
|
||||
617
services/statistics/backend/aggregator.py
Normal file
617
services/statistics/backend/aggregator.py
Normal file
@ -0,0 +1,617 @@
|
||||
"""
|
||||
Data Aggregator - Performs data aggregation and analytics
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from models import (
|
||||
AggregatedMetric, AggregationType, Granularity,
|
||||
UserAnalytics, SystemAnalytics, EventAnalytics,
|
||||
AlertRule, Alert
|
||||
)
|
||||
import uuid
|
||||
import io
|
||||
import csv
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DataAggregator:
|
||||
"""Performs data aggregation and analytics operations"""
|
||||
|
||||
def __init__(self, ts_db, cache):
|
||||
self.ts_db = ts_db
|
||||
self.cache = cache
|
||||
self.is_running = False
|
||||
self.alert_rules = {}
|
||||
self.active_alerts = {}
|
||||
self.aggregation_jobs = []
|
||||
|
||||
async def start_aggregation_jobs(self):
|
||||
"""Start background aggregation jobs"""
|
||||
self.is_running = True
|
||||
|
||||
# Schedule periodic aggregation jobs
|
||||
self.aggregation_jobs = [
|
||||
asyncio.create_task(self._aggregate_hourly_metrics()),
|
||||
asyncio.create_task(self._aggregate_daily_metrics()),
|
||||
asyncio.create_task(self._check_alert_rules()),
|
||||
asyncio.create_task(self._cleanup_old_data())
|
||||
]
|
||||
|
||||
logger.info("Data aggregation jobs started")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop aggregation jobs"""
|
||||
self.is_running = False
|
||||
|
||||
# Cancel all jobs
|
||||
for job in self.aggregation_jobs:
|
||||
job.cancel()
|
||||
|
||||
# Wait for jobs to complete
|
||||
await asyncio.gather(*self.aggregation_jobs, return_exceptions=True)
|
||||
|
||||
logger.info("Data aggregation jobs stopped")
|
||||
|
||||
async def _aggregate_hourly_metrics(self):
|
||||
"""Aggregate metrics every hour"""
|
||||
while self.is_running:
|
||||
try:
|
||||
await asyncio.sleep(3600) # Run every hour
|
||||
|
||||
end_time = datetime.now()
|
||||
start_time = end_time - timedelta(hours=1)
|
||||
|
||||
# Aggregate different metric types
|
||||
await self._aggregate_metric_type("user.event", start_time, end_time, Granularity.HOUR)
|
||||
await self._aggregate_metric_type("system.cpu", start_time, end_time, Granularity.HOUR)
|
||||
await self._aggregate_metric_type("system.memory", start_time, end_time, Granularity.HOUR)
|
||||
|
||||
logger.info("Completed hourly metrics aggregation")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in hourly aggregation: {e}")
|
||||
|
||||
async def _aggregate_daily_metrics(self):
|
||||
"""Aggregate metrics every day"""
|
||||
while self.is_running:
|
||||
try:
|
||||
await asyncio.sleep(86400) # Run every 24 hours
|
||||
|
||||
end_time = datetime.now()
|
||||
start_time = end_time - timedelta(days=1)
|
||||
|
||||
# Aggregate different metric types
|
||||
await self._aggregate_metric_type("user.event", start_time, end_time, Granularity.DAY)
|
||||
await self._aggregate_metric_type("system", start_time, end_time, Granularity.DAY)
|
||||
|
||||
logger.info("Completed daily metrics aggregation")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in daily aggregation: {e}")
|
||||
|
||||
async def _aggregate_metric_type(
|
||||
self,
|
||||
metric_prefix: str,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
granularity: Granularity
|
||||
):
|
||||
"""Aggregate a specific metric type"""
|
||||
try:
|
||||
# Query raw metrics
|
||||
metrics = await self.ts_db.query_metrics(
|
||||
metric_type=metric_prefix,
|
||||
start_time=start_time,
|
||||
end_time=end_time
|
||||
)
|
||||
|
||||
if not metrics:
|
||||
return
|
||||
|
||||
# Calculate aggregations
|
||||
aggregations = {
|
||||
AggregationType.AVG: sum(m['value'] for m in metrics) / len(metrics),
|
||||
AggregationType.SUM: sum(m['value'] for m in metrics),
|
||||
AggregationType.MIN: min(m['value'] for m in metrics),
|
||||
AggregationType.MAX: max(m['value'] for m in metrics),
|
||||
AggregationType.COUNT: len(metrics)
|
||||
}
|
||||
|
||||
# Store aggregated results
|
||||
for agg_type, value in aggregations.items():
|
||||
aggregated = AggregatedMetric(
|
||||
metric_name=metric_prefix,
|
||||
aggregation_type=agg_type,
|
||||
value=value,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
granularity=granularity,
|
||||
count=len(metrics)
|
||||
)
|
||||
|
||||
await self.ts_db.store_aggregated_metric(aggregated)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error aggregating {metric_prefix}: {e}")
|
||||
|
||||
async def aggregate_metrics(
|
||||
self,
|
||||
metric_type: str,
|
||||
aggregation: str,
|
||||
group_by: Optional[str],
|
||||
start_time: datetime,
|
||||
end_time: datetime
|
||||
) -> Dict[str, Any]:
|
||||
"""Perform custom metric aggregation"""
|
||||
try:
|
||||
# Query metrics
|
||||
metrics = await self.ts_db.query_metrics(
|
||||
metric_type=metric_type,
|
||||
start_time=start_time,
|
||||
end_time=end_time
|
||||
)
|
||||
|
||||
if not metrics:
|
||||
return {"result": 0, "count": 0}
|
||||
|
||||
# Group metrics if requested
|
||||
if group_by:
|
||||
grouped = {}
|
||||
for metric in metrics:
|
||||
key = metric.get('tags', {}).get(group_by, 'unknown')
|
||||
if key not in grouped:
|
||||
grouped[key] = []
|
||||
grouped[key].append(metric['value'])
|
||||
|
||||
# Aggregate each group
|
||||
results = {}
|
||||
for key, values in grouped.items():
|
||||
results[key] = self._calculate_aggregation(values, aggregation)
|
||||
|
||||
return {"grouped_results": results, "count": len(metrics)}
|
||||
else:
|
||||
# Single aggregation
|
||||
values = [m['value'] for m in metrics]
|
||||
result = self._calculate_aggregation(values, aggregation)
|
||||
return {"result": result, "count": len(metrics)}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in custom aggregation: {e}")
|
||||
raise
|
||||
|
||||
def _calculate_aggregation(self, values: List[float], aggregation: str) -> float:
|
||||
"""Calculate aggregation on values"""
|
||||
if not values:
|
||||
return 0
|
||||
|
||||
if aggregation == "avg":
|
||||
return sum(values) / len(values)
|
||||
elif aggregation == "sum":
|
||||
return sum(values)
|
||||
elif aggregation == "min":
|
||||
return min(values)
|
||||
elif aggregation == "max":
|
||||
return max(values)
|
||||
elif aggregation == "count":
|
||||
return len(values)
|
||||
else:
|
||||
return 0
|
||||
|
||||
async def get_overview(self) -> Dict[str, Any]:
|
||||
"""Get analytics overview"""
|
||||
try:
|
||||
now = datetime.now()
|
||||
last_hour = now - timedelta(hours=1)
|
||||
last_day = now - timedelta(days=1)
|
||||
last_week = now - timedelta(weeks=1)
|
||||
|
||||
# Get various metrics
|
||||
hourly_events = await self.ts_db.count_metrics("user.event", last_hour, now)
|
||||
daily_events = await self.ts_db.count_metrics("user.event", last_day, now)
|
||||
weekly_events = await self.ts_db.count_metrics("user.event", last_week, now)
|
||||
|
||||
# Get system status
|
||||
cpu_avg = await self.ts_db.get_average("system.cpu.usage", last_hour, now)
|
||||
memory_avg = await self.ts_db.get_average("system.memory.usage", last_hour, now)
|
||||
|
||||
# Get active users (approximate from events)
|
||||
active_users = await self.ts_db.count_distinct_tags("user.event", "user_id", last_day, now)
|
||||
|
||||
return {
|
||||
"events": {
|
||||
"last_hour": hourly_events,
|
||||
"last_day": daily_events,
|
||||
"last_week": weekly_events
|
||||
},
|
||||
"system": {
|
||||
"cpu_usage": cpu_avg,
|
||||
"memory_usage": memory_avg
|
||||
},
|
||||
"users": {
|
||||
"active_daily": active_users
|
||||
},
|
||||
"alerts": {
|
||||
"active": len(self.active_alerts)
|
||||
},
|
||||
"timestamp": now.isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting overview: {e}")
|
||||
return {}
|
||||
|
||||
async def get_user_analytics(
    self,
    start_date: datetime,
    end_date: datetime,
    granularity: str
) -> UserAnalytics:
    """Summarise user activity between ``start_date`` and ``end_date``.

    Growth compares the number of distinct active users in the requested
    window with the immediately preceding window of the same length; a
    previous count of zero is treated as one so the division is defined.
    ``granularity`` is accepted but not used by this method.

    Raises:
        Any exception raised while querying, after logging it.
    """
    try:
        db = self.ts_db
        created_metric = "user.event.user_created"

        # Headcounts: all-time, in-window, and newly created in-window
        total_users = await db.count_distinct_tags(
            created_metric, "user_id", datetime.min, end_date
        )
        active_users = await db.count_distinct_tags(
            "user.event", "user_id", start_date, end_date
        )
        new_users = await db.count_metrics(created_metric, start_date, end_date)

        # Growth versus the equally long window just before start_date
        window = end_date - start_date
        previous_active = await db.count_distinct_tags(
            "user.event", "user_id", start_date - window, start_date
        )
        baseline = max(previous_active, 1)
        growth_rate = ((active_users - previous_active) / baseline) * 100

        # Ten most frequent event types in the window
        top_actions = await db.get_top_metrics(
            "user.event", "event_type", start_date, end_date, limit=10
        )

        period_label = f"{start_date.date()} to {end_date.date()}"
        return UserAnalytics(
            total_users=total_users,
            active_users=active_users,
            new_users=new_users,
            user_growth_rate=growth_rate,
            average_session_duration=0,  # Would need session tracking
            top_actions=top_actions,
            user_distribution={},  # Would need geographic data
            period=period_label
        )

    except Exception as e:
        logger.error(f"Error getting user analytics: {e}")
        raise
|
||||
|
||||
async def get_system_analytics(self) -> SystemAnalytics:
    """Compute system performance figures from recent metrics.

    Uptime is the share of "system.health" samples with value 1 over the
    last day. Resource usage, response time, error rate and throughput
    are taken over the last hour. Missing averages are reported as 0.

    Raises:
        Any exception raised while querying, after logging it.
    """
    try:
        now = datetime.now()
        hour_ago = now - timedelta(hours=1)
        day_ago = now - timedelta(days=1)
        db = self.ts_db

        # Uptime (simplified - would need actual downtime tracking)
        checks_total = await db.count_metrics("system.health", day_ago, now)
        checks_ok = await db.count_metrics_with_value(
            "system.health", 1, day_ago, now
        )
        uptime = (checks_ok / max(checks_total, 1)) * 100

        # Last-hour averages
        cpu_usage = await db.get_average("system.cpu.usage", hour_ago, now)
        memory_usage = await db.get_average("system.memory.usage", hour_ago, now)
        disk_usage = await db.get_average("system.disk.usage", hour_ago, now)
        response_time = await db.get_average("api.response_time", hour_ago, now)

        # Error rate as a percentage of last-hour requests
        requests_total = await db.count_metrics("api.request", hour_ago, now)
        requests_failed = await db.count_metrics("api.error", hour_ago, now)
        error_rate = (requests_failed / max(requests_total, 1)) * 100

        # Requests per second across the one-hour window
        throughput = requests_total / 3600

        return SystemAnalytics(
            uptime_percentage=uptime,
            average_response_time=response_time or 0,
            error_rate=error_rate,
            throughput=throughput,
            cpu_usage=cpu_usage or 0,
            memory_usage=memory_usage or 0,
            disk_usage=disk_usage or 0,
            active_connections=0,  # Would need connection tracking
            services_health={}  # Would need service health checks
        )

    except Exception as e:
        logger.error(f"Error getting system analytics: {e}")
        raise
|
||||
|
||||
async def get_event_analytics(
    self,
    event_type: Optional[str],
    limit: int
) -> EventAnalytics:
    """Summarise event traffic over the last hour.

    Args:
        event_type: Metric name to count and rank; falls back to
            "user.event" when not given.
        limit: Maximum number of entries requested for ``top_events``.

    Note:
        The event-type distribution is always taken from "user.event",
        and the error count always from "user.event.error", regardless
        of ``event_type``.

    Raises:
        Any exception raised while querying, after logging it.
    """
    try:
        now = datetime.now()
        last_hour = now - timedelta(hours=1)
        metric_name = event_type or "user.event"

        # Get total events
        total_events = await self.ts_db.count_metrics(metric_name, last_hour, now)

        # Events per second over the one-hour window
        events_per_second = total_events / 3600

        # Get event types distribution
        event_types = await self.ts_db.get_metric_distribution(
            "user.event",
            "event_type",
            last_hour,
            now
        )

        # Top events
        top_events = await self.ts_db.get_top_metrics(
            metric_name,
            "event_type",
            last_hour,
            now,
            limit=limit
        )

        # Error events
        error_events = await self.ts_db.count_metrics(
            "user.event.error",
            last_hour,
            now
        )

        # Success rate. The error count is not restricted to `event_type`,
        # so it can exceed `total_events` when a narrower metric was
        # requested; clamp so the percentage stays within 0..100.
        succeeded = max(total_events - error_events, 0)
        success_rate = min((succeeded / max(total_events, 1)) * 100, 100)

        return EventAnalytics(
            total_events=total_events,
            events_per_second=events_per_second,
            event_types=event_types,
            top_events=top_events,
            error_events=error_events,
            success_rate=success_rate,
            processing_time={}  # Would need timing metrics
        )

    except Exception as e:
        logger.error(f"Error getting event analytics: {e}")
        raise
|
||||
|
||||
async def get_dashboard_configs(self) -> List[Dict[str, Any]]:
    """List the dashboards this service can render.

    Returns:
        One dict per dashboard with "id", "name" and "description" keys.
    """
    catalogue = (
        ("overview", "Overview Dashboard", "General system overview"),
        ("users", "User Analytics", "User behavior and statistics"),
        ("system", "System Performance", "System health and performance metrics"),
        ("events", "Event Analytics", "Event processing and statistics"),
    )
    return [
        {"id": dashboard_id, "name": title, "description": summary}
        for dashboard_id, title, summary in catalogue
    ]
|
||||
|
||||
async def get_dashboard_data(self, dashboard_id: str) -> Dict[str, Any]:
    """Return the payload for one dashboard.

    "overview" returns the overview dict as-is; "users" covers the last
    seven days at "day" granularity; "system" and "events" use their
    analytics objects converted with ``.dict()``.

    Raises:
        ValueError: if ``dashboard_id`` is not one of the four known ids.
    """
    if dashboard_id == "overview":
        return await self.get_overview()

    if dashboard_id == "users":
        window_end = datetime.now()
        window_start = window_end - timedelta(days=7)
        report = await self.get_user_analytics(window_start, window_end, "day")
        return report.dict()

    if dashboard_id == "system":
        report = await self.get_system_analytics()
        return report.dict()

    if dashboard_id == "events":
        report = await self.get_event_analytics(None, 100)
        return report.dict()

    raise ValueError(f"Unknown dashboard: {dashboard_id}")
|
||||
|
||||
async def create_alert_rule(self, rule_data: Dict[str, Any]) -> str:
    """Register a new alert rule built from ``rule_data``.

    The rule is given a freshly generated UUID (replacing any id in the
    input), kept in ``self.alert_rules`` and written to the cache as JSON
    with no expiry.

    Returns:
        The id assigned to the rule.
    """
    new_rule = AlertRule(**rule_data)
    new_rule.id = str(uuid.uuid4())
    self.alert_rules[new_rule.id] = new_rule

    # Store in cache; expire=None means the entry is permanent
    cache_key = f"alert_rule:{new_rule.id}"
    await self.cache.set(cache_key, new_rule.json(), expire=None)

    return new_rule.id
|
||||
|
||||
async def _check_alert_rules(self):
|
||||
"""Check alert rules periodically"""
|
||||
while self.is_running:
|
||||
try:
|
||||
await asyncio.sleep(60) # Check every minute
|
||||
|
||||
for rule_id, rule in self.alert_rules.items():
|
||||
if not rule.enabled:
|
||||
continue
|
||||
|
||||
await self._evaluate_alert_rule(rule)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking alert rules: {e}")
|
||||
|
||||
async def _evaluate_alert_rule(self, rule: AlertRule):
    """Evaluate one rule against the recent average of its metric.

    The metric is averaged over the last ``rule.duration`` seconds and
    compared with ``rule.threshold`` using ``rule.condition`` ("gt",
    "lt", "gte", "lte", "eq" or "neq"; anything else never triggers).
    A newly triggered rule creates an active alert and sends
    notifications; a rule that stops triggering has its alert marked
    resolved and removed. If no average is available nothing happens.
    Errors are logged and swallowed.
    """
    try:
        # Average the metric over the rule's look-back window
        window_end = datetime.now()
        window_start = window_end - timedelta(seconds=rule.duration)
        avg_value = await self.ts_db.get_average(
            rule.metric_name, window_start, window_end
        )
        if avg_value is None:
            return

        # Condition name -> comparison of (observed, threshold)
        comparisons = {
            "gt": lambda observed, limit: observed > limit,
            "lt": lambda observed, limit: observed < limit,
            "gte": lambda observed, limit: observed >= limit,
            "lte": lambda observed, limit: observed <= limit,
            "eq": lambda observed, limit: observed == limit,
            "neq": lambda observed, limit: observed != limit,
        }
        compare = comparisons.get(rule.condition)
        triggered = False
        if compare is not None and compare(avg_value, rule.threshold):
            triggered = True

        # Reconcile with the set of alerts that are already active
        alert_key = f"{rule.id}:{rule.metric_name}"
        already_active = alert_key in self.active_alerts

        if triggered and not already_active:
            # New alert
            alert = Alert(
                id=str(uuid.uuid4()),
                rule_id=rule.id,
                rule_name=rule.name,
                metric_name=rule.metric_name,
                current_value=avg_value,
                threshold=rule.threshold,
                severity=rule.severity,
                triggered_at=datetime.now(),
                status="active"
            )
            self.active_alerts[alert_key] = alert

            # Send notifications
            await self._send_alert_notifications(alert, rule)

        elif already_active and not triggered:
            # Alert resolved
            alert = self.active_alerts[alert_key]
            alert.resolved_at = datetime.now()
            alert.status = "resolved"
            del self.active_alerts[alert_key]

            logger.info(f"Alert resolved: {rule.name}")

    except Exception as e:
        logger.error(f"Error evaluating alert rule {rule.id}: {e}")
|
||||
|
||||
async def _send_alert_notifications(self, alert: Alert, rule: AlertRule):
    """Announce a triggered alert.

    Currently this only writes a warning to the log; no external
    notification channel is contacted.
    """
    message = f"ALERT: {rule.name} - {alert.metric_name} = {alert.current_value} (threshold: {alert.threshold})"
    logger.warning(message)
    # Would implement actual notification channels here
|
||||
|
||||
async def get_active_alerts(self) -> List[Dict[str, Any]]:
    """Return every currently active alert converted with ``.dict()``."""
    serialised = []
    for active in self.active_alerts.values():
        serialised.append(active.dict())
    return serialised
|
||||
|
||||
async def export_to_csv(
    self,
    metric_type: str,
    start_time: datetime,
    end_time: datetime
):
    """Render the metrics of ``metric_type`` in the given range as CSV.

    Each metric dict becomes one row with the columns timestamp,
    metric_name (taken from the "name" key), value, tags (the ``str()``
    of the tags dict) and service.

    Returns:
        An ``io.StringIO`` positioned at the start of the CSV text.

    Raises:
        Any exception raised while querying or writing, after logging it.
    """
    try:
        rows = await self.ts_db.query_metrics(
            metric_type=metric_type,
            start_time=start_time,
            end_time=end_time
        )

        columns = ['timestamp', 'metric_name', 'value', 'tags', 'service']
        buffer = io.StringIO()
        csv_writer = csv.DictWriter(buffer, fieldnames=columns)
        csv_writer.writeheader()
        csv_writer.writerows(
            {
                'timestamp': row.get('timestamp'),
                'metric_name': row.get('name'),
                'value': row.get('value'),
                'tags': str(row.get('tags', {})),
                'service': row.get('service'),
            }
            for row in rows
        )

        # Rewind so the caller reads from the beginning
        buffer.seek(0)
        return buffer

    except Exception as e:
        logger.error(f"Error exporting to CSV: {e}")
        raise
|
||||
|
||||
async def _cleanup_old_data(self):
|
||||
"""Clean up old data periodically"""
|
||||
while self.is_running:
|
||||
try:
|
||||
await asyncio.sleep(86400) # Run daily
|
||||
|
||||
# Delete data older than 30 days
|
||||
cutoff_date = datetime.now() - timedelta(days=30)
|
||||
await self.ts_db.delete_old_data(cutoff_date)
|
||||
|
||||
logger.info("Completed old data cleanup")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in data cleanup: {e}")
|
||||
32
services/statistics/backend/cache_manager.py
Normal file
32
services/statistics/backend/cache_manager.py
Normal file
@ -0,0 +1,32 @@
|
||||
"""Cache Manager for Redis"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CacheManager:
    """Cache manager with a Redis-shaped interface.

    This implementation does not talk to Redis: values live in an
    in-process dict and ``connect``/``close`` only flip a flag. Entries
    written with an ``expire`` value are dropped once that many seconds
    have passed, so callers that ask for a short-lived entry do not get
    stale data back indefinitely.
    """

    def __init__(self, redis_url: str):
        self.redis_url = redis_url
        self.is_connected = False
        self.cache = {}  # Simplified in-memory cache: key -> value
        self._deadlines = {}  # key -> time.monotonic() value at which it expires

    async def connect(self):
        """Mark the cache as connected."""
        self.is_connected = True
        logger.info("Connected to cache")

    async def close(self):
        """Mark the cache as disconnected."""
        self.is_connected = False
        logger.info("Disconnected from cache")

    async def get(self, key: str) -> Optional[str]:
        """Return the value stored under ``key``.

        Returns ``None`` when the key is absent or its entry has expired;
        an expired entry is removed on access.
        """
        import time

        deadline = self._deadlines.get(key)
        if deadline is not None and time.monotonic() >= deadline:
            self.cache.pop(key, None)
            self._deadlines.pop(key, None)
            return None
        return self.cache.get(key)

    async def set(self, key: str, value: str, expire: Optional[int] = None):
        """Store ``value`` under ``key``.

        Args:
            expire: Lifetime in seconds; ``None`` keeps the entry until
                it is overwritten.
        """
        import time

        self.cache[key] = value
        if expire is None:
            # Overwriting with a permanent entry clears any earlier TTL
            self._deadlines.pop(key, None)
        else:
            # monotonic() is unaffected by wall-clock adjustments
            self._deadlines[key] = time.monotonic() + expire
|
||||
396
services/statistics/backend/main.py
Normal file
396
services/statistics/backend/main.py
Normal file
@ -0,0 +1,396 @@
|
||||
"""
|
||||
Statistics Service - Real-time Analytics and Metrics
|
||||
"""
|
||||
from fastapi import FastAPI, HTTPException, Depends, Query
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
import uvicorn
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict, Any
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
|
||||
# Import custom modules
|
||||
from models import Metric, AggregatedMetric, TimeSeriesData, DashboardConfig
|
||||
from metrics_collector import MetricsCollector
|
||||
from aggregator import DataAggregator
|
||||
from websocket_manager import WebSocketManager
|
||||
from time_series_db import TimeSeriesDB
|
||||
from cache_manager import CacheManager
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Service-wide singletons; all start unset and are populated by lifespan()
# during application startup.
metrics_collector = data_aggregator = ws_manager = ts_db = cache_manager = None
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start the service's components, yield to serve, then shut them down.

    Startup connects the time-series DB and cache, tries to start the
    Kafka-backed metrics collector (the service continues without it if
    that fails), launches the aggregation jobs as a background task and
    creates the WebSocket manager. Any other startup failure is logged
    and re-raised. Shutdown stops/closes whatever was started.
    """
    # Startup
    global metrics_collector, data_aggregator, ws_manager, ts_db, cache_manager

    try:
        # Initialize TimeSeriesDB (using InfluxDB)
        ts_db = TimeSeriesDB(
            host=os.getenv("INFLUXDB_HOST", "influxdb"),
            port=int(os.getenv("INFLUXDB_PORT", 8086)),
            database=os.getenv("INFLUXDB_DATABASE", "statistics")
        )
        await ts_db.connect()
        logger.info("Connected to InfluxDB")

        # Initialize Cache Manager
        cache_manager = CacheManager(
            redis_url=os.getenv("REDIS_URL", "redis://redis:6379")
        )
        await cache_manager.connect()
        logger.info("Connected to Redis cache")

        # Initialize Metrics Collector (optional Kafka connection)
        try:
            metrics_collector = MetricsCollector(
                kafka_bootstrap_servers=os.getenv("KAFKA_BOOTSTRAP_SERVERS", "kafka:9092"),
                ts_db=ts_db,
                cache=cache_manager
            )
            await metrics_collector.start()
            logger.info("Metrics collector started")
        except Exception as e:
            logger.warning(f"Metrics collector failed to start (Kafka not available): {e}")
            metrics_collector = None

        # Initialize Data Aggregator
        data_aggregator = DataAggregator(
            ts_db=ts_db,
            cache=cache_manager
        )
        # The event loop keeps only a weak reference to tasks, so a task
        # whose result is discarded can be garbage-collected before it
        # finishes. Keep a strong reference for the lifetime of the app.
        app.state.aggregation_task = asyncio.create_task(
            data_aggregator.start_aggregation_jobs()
        )
        logger.info("Data aggregator started")

        # Initialize WebSocket Manager
        ws_manager = WebSocketManager()
        logger.info("WebSocket manager initialized")

    except Exception as e:
        logger.error(f"Failed to start Statistics service: {e}")
        raise

    yield

    # Shutdown
    if metrics_collector:
        await metrics_collector.stop()
    if data_aggregator:
        await data_aggregator.stop()
    if ts_db:
        await ts_db.close()
    if cache_manager:
        await cache_manager.close()

    logger.info("Statistics service shutdown complete")
|
||||
|
||||
app = FastAPI(title="Statistics Service", description="Real-time Analytics and Metrics Service", version="1.0.0", lifespan=lifespan)

# CORS middleware: every origin, method and header is allowed, with credentials.
# NOTE(review): wildcard origins combined with credentials is very permissive - confirm this is intended outside development.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
|
||||
|
||||
@app.get("/")
async def root():
    """Identify the service and report the current server time."""
    banner = {"service": "Statistics Service", "status": "running"}
    banner["timestamp"] = datetime.now().isoformat()
    return banner
|
||||
|
||||
@app.get("/health")
async def health_check():
    """Report per-component status.

    The top-level "status" is always "healthy"; the state of each
    dependency is listed under "components".
    """
    def describe(component, flag, up, down):
        # A component that was never created counts as down.
        return up if component and getattr(component, flag) else down

    components = {
        "influxdb": describe(ts_db, "is_connected", "connected", "disconnected"),
        "redis": describe(cache_manager, "is_connected", "connected", "disconnected"),
        "metrics_collector": describe(metrics_collector, "is_running", "running", "stopped"),
        "aggregator": describe(data_aggregator, "is_running", "running", "stopped"),
    }
    return {
        "status": "healthy",
        "service": "statistics",
        "components": components,
        "timestamp": datetime.now().isoformat(),
    }
|
||||
|
||||
# Metrics Endpoints
|
||||
@app.post("/api/metrics")
async def record_metric(metric: Metric):
    """Record a single metric.

    Raises:
        HTTPException: 503 when the metrics collector is not running
            (startup leaves it unset if Kafka was unavailable); 500 if
            recording fails.
    """
    # Checked outside the try block so the 503 is not rewritten to a 500
    # by the generic handler below.
    if metrics_collector is None:
        raise HTTPException(status_code=503, detail="Metrics collector is not available")

    try:
        await metrics_collector.record_metric(metric)
        return {"status": "recorded", "metric_id": metric.id}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/api/metrics/batch")
async def record_metrics_batch(metrics: List[Metric]):
    """Record multiple metrics in batch.

    Raises:
        HTTPException: 503 when the metrics collector is not running
            (startup leaves it unset if Kafka was unavailable); 500 if
            recording fails.
    """
    # Checked outside the try block so the 503 is not rewritten to a 500
    # by the generic handler below.
    if metrics_collector is None:
        raise HTTPException(status_code=503, detail="Metrics collector is not available")

    try:
        await metrics_collector.record_metrics_batch(metrics)
        return {"status": "recorded", "count": len(metrics)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/api/metrics/realtime/{metric_type}")
async def get_realtime_metrics(
    metric_type: str,
    duration: int = Query(60, description="Duration in seconds")
):
    """Return the metrics of ``metric_type`` from the last ``duration`` seconds.

    Raises:
        HTTPException: 500 if the query fails.
    """
    try:
        window_end = datetime.now()
        window_start = window_end - timedelta(seconds=duration)
        samples = await ts_db.query_metrics(
            metric_type=metric_type,
            start_time=window_start,
            end_time=window_end
        )
        response = {"metric_type": metric_type, "duration": duration, "data": samples}
        response["timestamp"] = datetime.now().isoformat()
        return response
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
# Analytics Endpoints
|
||||
@app.get("/api/analytics/overview")
async def get_analytics_overview():
    """Get overall analytics overview, served from cache when available.

    Raises:
        HTTPException: 500 if the cache or the aggregator raises.
    """
    try:
        # Try to get from cache first
        cached = await cache_manager.get("analytics:overview")
        if cached:
            return json.loads(cached)

        # Calculate analytics
        overview = await data_aggregator.get_overview()

        # Cache for 1 minute. The aggregator returns an empty dict when
        # it fails; do not cache that, or the failure would be replayed
        # to later callers as if it were data.
        if overview:
            await cache_manager.set(
                "analytics:overview",
                json.dumps(overview),
                expire=60
            )

        return overview
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/api/analytics/users")
async def get_user_analytics(
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    granularity: str = Query("hour", regex="^(minute|hour|day|week|month)$")
):
    """Return user analytics for a date range.

    The range defaults to the seven days ending now.

    Raises:
        HTTPException: 500 if the aggregator raises.
    """
    try:
        range_start = start_date or (datetime.now() - timedelta(days=7))
        range_end = end_date or datetime.now()
        return await data_aggregator.get_user_analytics(
            start_date=range_start,
            end_date=range_end,
            granularity=granularity
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
@app.get("/api/analytics/system")
async def get_system_analytics():
    """Return the aggregator's system performance analytics.

    Raises:
        HTTPException: 500 if the aggregator raises.
    """
    try:
        return await data_aggregator.get_system_analytics()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
@app.get("/api/analytics/events")
async def get_event_analytics(
    event_type: Optional[str] = None,
    limit: int = Query(100, le=1000)
):
    """Return event analytics, optionally for one event type.

    Raises:
        HTTPException: 500 if the aggregator raises.
    """
    try:
        return await data_aggregator.get_event_analytics(event_type=event_type, limit=limit)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
# Time Series Endpoints
|
||||
@app.get("/api/timeseries/{metric_name}")
async def get_time_series(
    metric_name: str,
    start_time: datetime,
    end_time: datetime,
    interval: str = Query("1m", regex="^\\d+[smhd]$")
):
    """Return the time series of one metric, bucketed by ``interval``.

    ``interval`` is a number followed by s, m, h or d (for example "1m").

    Raises:
        HTTPException: 500 if the query or response construction fails.
    """
    try:
        # The same four values describe both the query and the response.
        series_spec = {
            "metric_name": metric_name,
            "start_time": start_time,
            "end_time": end_time,
            "interval": interval,
        }
        points = await ts_db.get_time_series(**series_spec)
        return TimeSeriesData(**series_spec, data=points)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
# Aggregation Endpoints
|
||||
@app.get("/api/aggregates/{metric_type}")
async def get_aggregated_metrics(
    metric_type: str,
    aggregation: str = Query("avg", regex="^(avg|sum|min|max|count)$"),
    group_by: Optional[str] = None,
    start_time: Optional[datetime] = None,
    end_time: Optional[datetime] = None
):
    """Return an aggregate of ``metric_type`` over a time range.

    The range defaults to the 24 hours ending now.

    Raises:
        HTTPException: 500 if the aggregator raises.
    """
    try:
        range_start = start_time or (datetime.now() - timedelta(hours=24))
        range_end = end_time or datetime.now()
        return await data_aggregator.aggregate_metrics(
            metric_type=metric_type,
            aggregation=aggregation,
            group_by=group_by,
            start_time=range_start,
            end_time=range_end
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
# Dashboard Endpoints
|
||||
@app.get("/api/dashboard/configs")
async def get_dashboard_configs():
    """Return the aggregator's dashboard list under a "configs" key."""
    return {"configs": await data_aggregator.get_dashboard_configs()}
|
||||
|
||||
@app.get("/api/dashboard/{dashboard_id}")
async def get_dashboard_data(dashboard_id: str):
    """Get data for a specific dashboard.

    Raises:
        HTTPException: 404 when the aggregator rejects ``dashboard_id``
            with a ValueError (it does so for unknown ids); 500 for any
            other failure.
    """
    try:
        data = await data_aggregator.get_dashboard_data(dashboard_id)
        return data
    except ValueError as e:
        # A dashboard id the aggregator does not know is a client error,
        # not a server fault.
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# WebSocket Endpoint for Real-time Updates
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
|
||||
@app.websocket("/ws/metrics")
async def websocket_metrics(websocket: WebSocket):
    """Push the collector's latest metrics to the client once a second.

    The connection is registered with the WebSocket manager and removed
    again when the client disconnects or any error occurs; errors other
    than a disconnect are logged.
    """
    await ws_manager.connect(websocket)
    try:
        while True:
            snapshot = await metrics_collector.get_latest_metrics()
            update = {
                "type": "metrics_update",
                "data": snapshot,
                "timestamp": datetime.now().isoformat(),
            }
            await websocket.send_json(update)
            await asyncio.sleep(1)
    except Exception as e:
        # A client hanging up is routine; anything else is worth a log line.
        if not isinstance(e, WebSocketDisconnect):
            logger.error(f"WebSocket error: {e}")
        ws_manager.disconnect(websocket)
|
||||
|
||||
# Alert Management Endpoints
|
||||
@app.post("/api/alerts/rules")
async def create_alert_rule(rule: Dict[str, Any]):
    """Create a new alert rule from a JSON object.

    Raises:
        HTTPException: 422 when the rule data is rejected with a
            ValueError (pydantic validation errors are ValueError
            subclasses); 500 for any other failure.
    """
    try:
        rule_id = await data_aggregator.create_alert_rule(rule)
        return {"rule_id": rule_id, "status": "created"}
    except ValueError as e:
        # Malformed rule payloads are the caller's mistake, not a server fault.
        raise HTTPException(status_code=422, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/api/alerts/active")
async def get_active_alerts():
    """Return the active alerts together with how many there are.

    Raises:
        HTTPException: 500 if the aggregator raises.
    """
    try:
        active = await data_aggregator.get_active_alerts()
        return {"alerts": active, "count": len(active)}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
# Export Endpoints
|
||||
@app.get("/api/export/csv")
async def export_metrics_csv(
    metric_type: str,
    start_time: datetime,
    end_time: datetime
):
    """Stream the metrics of ``metric_type`` in the given range as a CSV download.

    The suggested filename contains the metric type and the current
    server time.

    Raises:
        HTTPException: 500 if the export fails.
    """
    try:
        csv_stream = await data_aggregator.export_to_csv(
            metric_type=metric_type,
            start_time=start_time,
            end_time=end_time
        )
        disposition = f"attachment; filename=metrics_{metric_type}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
        return StreamingResponse(
            csv_stream,
            media_type="text/csv",
            headers={"Content-Disposition": disposition}
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
if __name__ == "__main__":
    # Direct execution: serve on all interfaces, port 8000, with auto-reload.
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|
||||
242
services/statistics/backend/metrics_collector.py
Normal file
242
services/statistics/backend/metrics_collector.py
Normal file
@ -0,0 +1,242 @@
|
||||
"""
|
||||
Metrics Collector - Collects metrics from Kafka and other sources
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
from aiokafka import AIOKafkaConsumer
|
||||
from models import Metric, MetricType
|
||||
import uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MetricsCollector:
|
||||
"""Collects and processes metrics from various sources"""
|
||||
|
||||
def __init__(self, kafka_bootstrap_servers: str, ts_db, cache):
|
||||
self.kafka_servers = kafka_bootstrap_servers
|
||||
self.ts_db = ts_db
|
||||
self.cache = cache
|
||||
self.consumer = None
|
||||
self.is_running = False
|
||||
self.latest_metrics = {}
|
||||
self.metrics_buffer = []
|
||||
self.buffer_size = 100
|
||||
self.flush_interval = 5 # seconds
|
||||
|
||||
async def start(self):
|
||||
"""Start the metrics collector"""
|
||||
try:
|
||||
# Start Kafka consumer for event metrics
|
||||
self.consumer = AIOKafkaConsumer(
|
||||
'metrics-events',
|
||||
'user-events',
|
||||
'system-metrics',
|
||||
bootstrap_servers=self.kafka_servers,
|
||||
group_id='statistics-consumer-group',
|
||||
value_deserializer=lambda m: json.loads(m.decode('utf-8'))
|
||||
)
|
||||
await self.consumer.start()
|
||||
self.is_running = True
|
||||
|
||||
# Start background tasks
|
||||
asyncio.create_task(self._consume_metrics())
|
||||
asyncio.create_task(self._flush_metrics_periodically())
|
||||
|
||||
logger.info("Metrics collector started")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start metrics collector: {e}")
|
||||
raise
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the metrics collector"""
|
||||
self.is_running = False
|
||||
if self.consumer:
|
||||
await self.consumer.stop()
|
||||
|
||||
# Flush remaining metrics
|
||||
if self.metrics_buffer:
|
||||
await self._flush_metrics()
|
||||
|
||||
logger.info("Metrics collector stopped")
|
||||
|
||||
async def _consume_metrics(self):
|
||||
"""Consume metrics from Kafka"""
|
||||
while self.is_running:
|
||||
try:
|
||||
async for msg in self.consumer:
|
||||
if not self.is_running:
|
||||
break
|
||||
|
||||
metric = self._parse_kafka_message(msg)
|
||||
if metric:
|
||||
await self.record_metric(metric)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error consuming metrics: {e}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
def _parse_kafka_message(self, msg) -> Optional[Metric]:
|
||||
"""Parse Kafka message into Metric"""
|
||||
try:
|
||||
data = msg.value
|
||||
topic = msg.topic
|
||||
|
||||
# Create metric based on topic
|
||||
if topic == 'user-events':
|
||||
return self._create_user_metric(data)
|
||||
elif topic == 'system-metrics':
|
||||
return self._create_system_metric(data)
|
||||
elif topic == 'metrics-events':
|
||||
return Metric(**data)
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse Kafka message: {e}")
|
||||
return None
|
||||
|
||||
def _create_user_metric(self, data: Dict) -> Metric:
|
||||
"""Create metric from user event"""
|
||||
event_type = data.get('event_type', 'unknown')
|
||||
|
||||
return Metric(
|
||||
id=str(uuid.uuid4()),
|
||||
name=f"user.event.{event_type.lower()}",
|
||||
type=MetricType.COUNTER,
|
||||
value=1,
|
||||
tags={
|
||||
"event_type": event_type,
|
||||
"user_id": data.get('data', {}).get('user_id', 'unknown'),
|
||||
"service": data.get('service', 'unknown')
|
||||
},
|
||||
timestamp=datetime.fromisoformat(data.get('timestamp', datetime.now().isoformat())),
|
||||
service=data.get('service', 'users')
|
||||
)
|
||||
|
||||
def _create_system_metric(self, data: Dict) -> Metric:
|
||||
"""Create metric from system event"""
|
||||
return Metric(
|
||||
id=str(uuid.uuid4()),
|
||||
name=data.get('metric_name', 'system.unknown'),
|
||||
type=MetricType.GAUGE,
|
||||
value=float(data.get('value', 0)),
|
||||
tags=data.get('tags', {}),
|
||||
timestamp=datetime.fromisoformat(data.get('timestamp', datetime.now().isoformat())),
|
||||
service=data.get('service', 'system')
|
||||
)
|
||||
|
||||
async def record_metric(self, metric: Metric):
|
||||
"""Record a single metric"""
|
||||
try:
|
||||
# Add to buffer
|
||||
self.metrics_buffer.append(metric)
|
||||
|
||||
# Update latest metrics cache
|
||||
self.latest_metrics[metric.name] = {
|
||||
"value": metric.value,
|
||||
"timestamp": metric.timestamp.isoformat(),
|
||||
"tags": metric.tags
|
||||
}
|
||||
|
||||
# Flush if buffer is full
|
||||
if len(self.metrics_buffer) >= self.buffer_size:
|
||||
await self._flush_metrics()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to record metric: {e}")
|
||||
raise
|
||||
|
||||
async def record_metrics_batch(self, metrics: List[Metric]):
|
||||
"""Record multiple metrics"""
|
||||
for metric in metrics:
|
||||
await self.record_metric(metric)
|
||||
|
||||
async def _flush_metrics(self):
|
||||
"""Flush metrics buffer to time series database"""
|
||||
if not self.metrics_buffer:
|
||||
return
|
||||
|
||||
try:
|
||||
# Write to time series database
|
||||
await self.ts_db.write_metrics(self.metrics_buffer)
|
||||
|
||||
# Clear buffer
|
||||
self.metrics_buffer.clear()
|
||||
|
||||
logger.debug(f"Flushed {len(self.metrics_buffer)} metrics to database")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to flush metrics: {e}")
|
||||
|
||||
async def _flush_metrics_periodically(self):
|
||||
"""Periodically flush metrics buffer"""
|
||||
while self.is_running:
|
||||
await asyncio.sleep(self.flush_interval)
|
||||
await self._flush_metrics()
|
||||
|
||||
async def get_latest_metrics(self) -> Dict[str, Any]:
|
||||
"""Get latest metrics for real-time display"""
|
||||
return self.latest_metrics
|
||||
|
||||
async def collect_system_metrics(self):
    """Sample host CPU, memory, disk and network counters and record them.

    Records, in order: ``system.cpu.usage``, ``system.memory.usage`` and
    ``system.disk.usage`` (gauges, percentages) followed by
    ``system.network.bytes_sent`` and ``system.network.bytes_recv``
    (counters). Every sample is tagged ``host=localhost`` and attributed to
    the ``statistics`` service. Any failure is logged and swallowed; samples
    recorded before the failure are kept.
    """
    import psutil

    async def record(name, metric_type, value, extra_tags=None):
        """Record one system sample with the shared host tag and service name."""
        tags = {"host": "localhost"}
        if extra_tags:
            tags.update(extra_tags)
        await self.record_metric(Metric(
            name=name,
            type=metric_type,
            value=value,
            tags=tags,
            service="statistics",
        ))

    try:
        # cpu_percent(interval=1) sleeps for the whole interval. Calling it
        # directly would stall every other coroutine for a second, so it is
        # run in the default executor instead.
        loop = asyncio.get_running_loop()
        cpu_percent = await loop.run_in_executor(None, psutil.cpu_percent, 1)
        await record("system.cpu.usage", MetricType.GAUGE, cpu_percent)

        # Memory metrics
        memory = psutil.virtual_memory()
        await record("system.memory.usage", MetricType.GAUGE, memory.percent)

        # Disk metrics (root mount only)
        disk = psutil.disk_usage('/')
        await record("system.disk.usage", MetricType.GAUGE, disk.percent, {"mount": "/"})

        # Network metrics: both counters come from a single reading
        net_io = psutil.net_io_counters()
        await record("system.network.bytes_sent", MetricType.COUNTER, net_io.bytes_sent)
        await record("system.network.bytes_recv", MetricType.COUNTER, net_io.bytes_recv)

    except Exception as e:
        logger.error(f"Failed to collect system metrics: {e}")
|
||||
|
||||
async def collect_application_metrics(self):
    """Placeholder for application-level metric collection.

    Intentionally a no-op: other services are expected to report their own
    metrics to this service.
    """
    return None
|
||||
159
services/statistics/backend/models.py
Normal file
159
services/statistics/backend/models.py
Normal file
@ -0,0 +1,159 @@
|
||||
"""
|
||||
Data models for Statistics Service
|
||||
"""
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any, Literal
|
||||
from enum import Enum
|
||||
|
||||
class MetricType(str, Enum):
    """Types of metrics.

    Members are also ``str`` instances, so they compare equal to their
    lowercase string values.
    """
    COUNTER = "counter"
    GAUGE = "gauge"
    HISTOGRAM = "histogram"
    SUMMARY = "summary"
|
||||
|
||||
class AggregationType(str, Enum):
    """Types of aggregation that can be applied to a metric series.

    Members are also ``str`` instances equal to their lowercase values.
    """
    AVG = "avg"
    SUM = "sum"
    MIN = "min"
    MAX = "max"
    COUNT = "count"
    PERCENTILE = "percentile"
|
||||
|
||||
class Granularity(str, Enum):
    """Time granularity (bucket size) for aggregation.

    Members are also ``str`` instances equal to their lowercase values.
    """
    MINUTE = "minute"
    HOUR = "hour"
    DAY = "day"
    WEEK = "week"
    MONTH = "month"
|
||||
|
||||
class Metric(BaseModel):
    """Single metric data point.

    ``name``, ``type``, ``value`` and ``service`` are required; ``id`` is
    optional, ``tags`` defaults to an empty dict and ``timestamp`` defaults
    to the moment the instance is created.
    """
    id: Optional[str] = Field(None, description="Unique metric ID")
    name: str = Field(..., description="Metric name")
    type: MetricType = Field(..., description="Metric type")
    value: float = Field(..., description="Metric value")
    tags: Dict[str, str] = Field(default_factory=dict, description="Metric tags")
    # datetime.now() without a tz argument yields a naive local-time value.
    timestamp: datetime = Field(default_factory=datetime.now, description="Metric timestamp")
    service: str = Field(..., description="Source service")

    class Config:
        # Serialize datetimes as ISO 8601 strings.
        # NOTE(review): class-based Config / json_encoders is the pydantic v1
        # style — confirm it behaves as intended with the pinned pydantic.
        json_encoders = {
            datetime: lambda v: v.isoformat()
        }
|
||||
|
||||
class AggregatedMetric(BaseModel):
    """Aggregated metric result over the window ``start_time``..``end_time``."""
    metric_name: str
    aggregation_type: AggregationType
    value: float
    start_time: datetime
    end_time: datetime
    # Optional bucket size and grouping key; both default to None.
    granularity: Optional[Granularity] = None
    group_by: Optional[str] = None
    count: int = Field(..., description="Number of data points aggregated")

    class Config:
        # Serialize datetimes as ISO 8601 strings.
        json_encoders = {
            datetime: lambda v: v.isoformat()
        }
|
||||
|
||||
class TimeSeriesData(BaseModel):
    """Time series data response for one metric over a time window."""
    metric_name: str
    start_time: datetime
    end_time: datetime
    # Free-form interval label; its accepted values are not constrained here.
    interval: str
    # One dict per data point; the dict schema is not declared by this model.
    data: List[Dict[str, Any]]

    class Config:
        # Serialize datetimes as ISO 8601 strings.
        json_encoders = {
            datetime: lambda v: v.isoformat()
        }
|
||||
|
||||
class DashboardConfig(BaseModel):
    """Dashboard configuration: identity, widgets and refresh cadence."""
    id: str
    name: str
    description: Optional[str] = None
    # Widget definitions as free-form dicts; no schema is enforced here.
    widgets: List[Dict[str, Any]]
    refresh_interval: int = Field(60, description="Refresh interval in seconds")
    # Both timestamps default to the (naive, local) creation time of the
    # instance; nothing in this model updates updated_at automatically.
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)

    class Config:
        # Serialize datetimes as ISO 8601 strings.
        json_encoders = {
            datetime: lambda v: v.isoformat()
        }
|
||||
|
||||
class AlertRule(BaseModel):
    """Alert rule configuration.

    A rule names a metric, a comparison ``condition`` and a ``threshold``.
    How ``duration`` is applied during evaluation is decided by the code
    that consumes the rule, not by this model.
    """
    id: Optional[str] = None
    name: str
    metric_name: str
    # Comparison operator, restricted to these six codes.
    condition: Literal["gt", "lt", "gte", "lte", "eq", "neq"]
    threshold: float
    duration: int = Field(..., description="Duration in seconds")
    severity: Literal["low", "medium", "high", "critical"]
    enabled: bool = True
    notification_channels: List[str] = Field(default_factory=list)
    created_at: datetime = Field(default_factory=datetime.now)

    class Config:
        # Serialize datetimes as ISO 8601 strings.
        json_encoders = {
            datetime: lambda v: v.isoformat()
        }
|
||||
|
||||
class Alert(BaseModel):
    """Active alert raised for a rule, with its current lifecycle status."""
    id: str
    rule_id: str
    rule_name: str
    metric_name: str
    current_value: float
    threshold: float
    # NOTE(review): plain str here, whereas AlertRule.severity is limited to
    # "low"/"medium"/"high"/"critical" — confirm whether this should match.
    severity: str
    triggered_at: datetime
    # None until the alert has been resolved.
    resolved_at: Optional[datetime] = None
    status: Literal["active", "resolved", "acknowledged"]

    class Config:
        # Serialize datetimes as ISO 8601 strings.
        json_encoders = {
            datetime: lambda v: v.isoformat()
        }
|
||||
|
||||
class UserAnalytics(BaseModel):
    """User analytics data for one reporting period. All fields are required."""
    total_users: int
    active_users: int
    new_users: int
    user_growth_rate: float
    # NOTE(review): unit (seconds vs. minutes) is not stated here — confirm
    # with the code that produces this value.
    average_session_duration: float
    top_actions: List[Dict[str, Any]]
    user_distribution: Dict[str, int]
    # Free-form label of the reporting period.
    period: str
|
||||
|
||||
class SystemAnalytics(BaseModel):
    """System performance analytics snapshot. All fields are required."""
    uptime_percentage: float
    # NOTE(review): units of response time and throughput are not stated
    # here — confirm with the producer of these values.
    average_response_time: float
    error_rate: float
    throughput: float
    cpu_usage: float
    memory_usage: float
    disk_usage: float
    active_connections: int
    # Maps a service name to a health status string.
    services_health: Dict[str, str]
|
||||
|
||||
class EventAnalytics(BaseModel):
    """Event analytics data. All fields are required."""
    total_events: int
    events_per_second: float
    # Count of events per event-type name.
    event_types: Dict[str, int]
    top_events: List[Dict[str, Any]]
    error_events: int
    success_rate: float
    # Named timing figures; the keys are chosen by the producer.
    processing_time: Dict[str, float]
|
||||
9
services/statistics/backend/requirements.txt
Normal file
9
services/statistics/backend/requirements.txt
Normal file
@ -0,0 +1,9 @@
|
||||
fastapi==0.109.0
|
||||
uvicorn[standard]==0.27.0
|
||||
pydantic==2.5.3
|
||||
python-dotenv==1.0.0
|
||||
aiokafka==0.10.0
|
||||
redis==5.0.1
|
||||
psutil==5.9.8
|
||||
httpx==0.26.0
|
||||
websockets==12.0
|
||||
165
services/statistics/backend/time_series_db.py
Normal file
165
services/statistics/backend/time_series_db.py
Normal file
@ -0,0 +1,165 @@
|
||||
"""
|
||||
Time Series Database Interface (Simplified for InfluxDB)
|
||||
"""
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
from models import Metric, AggregatedMetric
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TimeSeriesDB:
    """Time series database interface backed by a simplified in-memory list.

    Every stored row is a dict with the keys ``name``, ``value``,
    ``timestamp``, ``tags`` and ``service``. All query helpers filter rows
    by metric-name *prefix* and an inclusive time window.
    """

    def __init__(self, host: str, port: int, database: str):
        """Remember the connection settings and start with an empty store."""
        self.host = host
        self.port = port
        self.database = database
        self.is_connected = False
        # In production, would use actual InfluxDB client
        self.data_store = []  # Simplified in-memory storage

    async def connect(self):
        """Mark the database as connected (no real connection is opened)."""
        self.is_connected = True
        logger.info(f"Connected to time series database at {self.host}:{self.port}")

    async def close(self):
        """Mark the database as disconnected."""
        self.is_connected = False
        logger.info("Disconnected from time series database")

    async def write_metrics(self, metrics: List[Metric]):
        """Append one row per metric to the store, in the given order."""
        self.data_store.extend(
            {
                "name": item.name,
                "value": item.value,
                "timestamp": item.timestamp,
                "tags": item.tags,
                "service": item.service,
            }
            for item in metrics
        )

    async def query_metrics(
        self,
        metric_type: str,
        start_time: datetime,
        end_time: datetime
    ) -> List[Dict[str, Any]]:
        """Return rows whose name starts with *metric_type* inside the inclusive window."""
        return [
            row
            for row in self.data_store
            if row["name"].startswith(metric_type)
            and start_time <= row["timestamp"] <= end_time
        ]

    async def get_time_series(
        self,
        metric_name: str,
        start_time: datetime,
        end_time: datetime,
        interval: str
    ) -> List[Dict[str, Any]]:
        """Return the raw rows for a metric; *interval* is accepted but not applied."""
        rows = await self.query_metrics(metric_name, start_time, end_time)
        return rows

    async def store_aggregated_metric(self, metric: AggregatedMetric):
        """Store an aggregate as an ``agg.<metric_name>`` row stamped with its end time."""
        row = {
            "name": f"agg.{metric.metric_name}",
            "value": metric.value,
            "timestamp": metric.end_time,
            "tags": {"aggregation": metric.aggregation_type},
            "service": "statistics",
        }
        self.data_store.append(row)

    async def count_metrics(
        self,
        metric_type: str,
        start_time: datetime,
        end_time: datetime
    ) -> int:
        """Count the rows matching the prefix and window."""
        return len(await self.query_metrics(metric_type, start_time, end_time))

    async def get_average(
        self,
        metric_name: str,
        start_time: datetime,
        end_time: datetime
    ) -> Optional[float]:
        """Return the mean value of the matching rows, or None when there are none."""
        rows = await self.query_metrics(metric_name, start_time, end_time)
        if len(rows) == 0:
            return None
        total = sum([row["value"] for row in rows])
        return total / len(rows)

    async def count_distinct_tags(
        self,
        metric_type: str,
        tag_name: str,
        start_time: datetime,
        end_time: datetime
    ) -> int:
        """Count the distinct values of tag *tag_name* among the matching rows."""
        rows = await self.query_metrics(metric_type, start_time, end_time)
        distinct = {
            row["tags"][tag_name]
            for row in rows
            if tag_name in row.get("tags", {})
        }
        return len(distinct)

    async def get_top_metrics(
        self,
        metric_type: str,
        group_by: str,
        start_time: datetime,
        end_time: datetime,
        limit: int = 10
    ) -> List[Dict[str, Any]]:
        """Return up to *limit* ``{"name", "count"}`` entries, most frequent tag value first.

        Rows without the *group_by* tag are counted under ``"unknown"``.
        Ties keep the order in which the tag values were first seen.
        """
        tally = await self.get_metric_distribution(
            metric_type, group_by, start_time, end_time
        )
        ranked = sorted(tally.items(), key=lambda pair: pair[1], reverse=True)
        return [{"name": label, "count": hits} for label, hits in ranked[:limit]]

    async def count_metrics_with_value(
        self,
        metric_name: str,
        value: float,
        start_time: datetime,
        end_time: datetime
    ) -> int:
        """Count the matching rows whose value equals *value* exactly."""
        rows = await self.query_metrics(metric_name, start_time, end_time)
        return len([row for row in rows if row["value"] == value])

    async def get_metric_distribution(
        self,
        metric_type: str,
        tag_name: str,
        start_time: datetime,
        end_time: datetime
    ) -> Dict[str, int]:
        """Return how many matching rows carry each value of tag *tag_name*.

        Rows without the tag are counted under ``"unknown"``.
        """
        rows = await self.query_metrics(metric_type, start_time, end_time)
        tally = {}
        for row in rows:
            label = row.get("tags", {}).get(tag_name, "unknown")
            if label in tally:
                tally[label] += 1
            else:
                tally[label] = 1
        return tally

    async def delete_old_data(self, cutoff_date: datetime):
        """Drop every row stamped strictly before *cutoff_date*."""
        kept = []
        for row in self.data_store:
            if row["timestamp"] >= cutoff_date:
                kept.append(row)
        self.data_store = kept
|
||||
33
services/statistics/backend/websocket_manager.py
Normal file
33
services/statistics/backend/websocket_manager.py
Normal file
@ -0,0 +1,33 @@
|
||||
"""WebSocket Manager for real-time updates"""
|
||||
from typing import List
|
||||
from fastapi import WebSocket
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class WebSocketManager:
    """Manages WebSocket connections and broadcasts messages to them."""

    def __init__(self):
        """Start with no active connections."""
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and start tracking the connection."""
        await websocket.accept()
        self.active_connections.append(websocket)
        logger.info(f"WebSocket connected. Total connections: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        """Stop tracking *websocket*; a connection that is not tracked is ignored."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
        logger.info(f"WebSocket disconnected. Total connections: {len(self.active_connections)}")

    async def broadcast(self, message: dict):
        """Send *message* as JSON to every tracked connection.

        A connection whose send fails is logged and dropped; the remaining
        connections still receive the message.
        """
        # Iterate over a snapshot: disconnect() removes entries from the live
        # list, and mutating a list while iterating it makes the loop skip
        # the element that follows each removed one.
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception as e:
                logger.error(f"Error broadcasting to WebSocket: {e}")
                self.disconnect(connection)
|
||||
21
services/users/backend/Dockerfile
Normal file
21
services/users/backend/Dockerfile
Normal file
@ -0,0 +1,21 @@
|
||||
# Image for the users backend service.
FROM python:3.11-slim

WORKDIR /app

# Install system dependencies.
# gcc is presumably needed to build Python packages that ship no prebuilt
# wheel — confirm before removing it. The apt package lists are deleted in
# the same layer to keep the image small.
RUN apt-get update && apt-get install -y \
    gcc \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first for better caching: the dependency layer is only
# rebuilt when requirements.txt changes, not on every source edit.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Expose port (main.py starts uvicorn on 0.0.0.0:8000)
EXPOSE 8000

# Run the application.
# NOTE(review): main.py launches uvicorn with reload=True, which is a
# development setting — confirm this image is not used in production as-is.
CMD ["python", "main.py"]
|
||||
22
services/users/backend/database.py
Normal file
22
services/users/backend/database.py
Normal file
@ -0,0 +1,22 @@
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from beanie import init_beanie
|
||||
import os
|
||||
from models import User
|
||||
|
||||
|
||||
def _redact_credentials(url: str) -> str:
    """Return *url* with any ``user:password@`` part of its authority removed."""
    scheme, separator, rest = url.partition("://")
    if not separator:
        return url
    authority, slash, tail = rest.partition("/")
    # Everything before the last "@" of the authority is user info.
    host = authority.rsplit("@", 1)[-1]
    return f"{scheme}{separator}{host}{slash}{tail}"


async def init_db():
    """Initialize the MongoDB connection and register the Beanie documents.

    Reads ``MONGODB_URL`` (default ``mongodb://mongodb:27017``) and
    ``DB_NAME`` (default ``users_db``) from the environment, creates a Motor
    client and initializes Beanie with the ``User`` model.
    """
    # Get MongoDB URL from environment
    mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
    db_name = os.getenv("DB_NAME", "users_db")

    # Create Motor client
    client = AsyncIOMotorClient(mongodb_url)

    # Initialize beanie with the User model
    await init_beanie(
        database=client[db_name],
        document_models=[User]
    )

    # The URL may embed a username and password; never echo those to stdout.
    print(f"Connected to MongoDB: {_redact_credentials(mongodb_url)}/{db_name}")
|
||||
334
services/users/backend/main.py
Normal file
334
services/users/backend/main.py
Normal file
@ -0,0 +1,334 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
import uvicorn
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from database import init_db
|
||||
from models import User
|
||||
from beanie import PydanticObjectId
|
||||
|
||||
sys.path.append('/app')
|
||||
from shared.kafka import KafkaProducer, Event, EventType
|
||||
|
||||
# Configure the root logger at import time so INFO-level messages from this
# service are emitted, then create the module-level logger used below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Pydantic models for requests
|
||||
class UserCreate(BaseModel):
    """Request body for creating a user; only username and email are required."""
    username: str
    # NOTE(review): plain str here, while the stored User document declares
    # email as EmailStr — an invalid address passes request validation and
    # is only rejected when the document is built. Confirm this is intended.
    email: str
    full_name: Optional[str] = None
    profile_picture: Optional[str] = None
    bio: Optional[str] = None
    location: Optional[str] = None
    website: Optional[str] = None
|
||||
|
||||
class UserUpdate(BaseModel):
    """Request body for a partial user update.

    Every field is optional and defaults to None; the update endpoint
    leaves a field untouched when its value is None.
    """
    username: Optional[str] = None
    email: Optional[str] = None
    full_name: Optional[str] = None
    profile_picture: Optional[str] = None
    profile_picture_thumbnail: Optional[str] = None
    bio: Optional[str] = None
    location: Optional[str] = None
    website: Optional[str] = None
    # NOTE(review): these two account flags are writable through the same
    # payload as profile fields — confirm callers are authorized upstream.
    is_email_verified: Optional[bool] = None
    is_active: Optional[bool] = None
|
||||
|
||||
class UserResponse(BaseModel):
    """Full user representation returned by the CRUD endpoints."""
    # String form of the document's ObjectId.
    id: str
    username: str
    email: str
    full_name: Optional[str] = None
    profile_picture: Optional[str] = None
    profile_picture_thumbnail: Optional[str] = None
    bio: Optional[str] = None
    location: Optional[str] = None
    website: Optional[str] = None
    is_email_verified: bool
    is_active: bool
    created_at: datetime
    updated_at: datetime
|
||||
|
||||
class UserPublicResponse(BaseModel):
    """Response for public profiles (sensitive information excluded).

    Compared with ``UserResponse`` this omits ``email``, the
    ``is_email_verified`` / ``is_active`` flags and ``updated_at``.
    """
    id: str
    username: str
    full_name: Optional[str] = None
    profile_picture: Optional[str] = None
    profile_picture_thumbnail: Optional[str] = None
    bio: Optional[str] = None
    location: Optional[str] = None
    website: Optional[str] = None
    created_at: datetime
|
||||
|
||||
|
||||
# Process-wide Kafka producer; stays None when Kafka could not be reached at
# startup, in which case the endpoints skip event publishing.
kafka_producer: Optional[KafkaProducer] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: set up the database and Kafka, tear Kafka down on exit.

    A database initialization failure aborts startup. A Kafka failure is
    only logged as a warning and leaves ``kafka_producer`` as None.
    """
    global kafka_producer

    # --- startup ---
    await init_db()

    bootstrap_servers = os.getenv('KAFKA_BOOTSTRAP_SERVERS', 'kafka:9092')
    try:
        kafka_producer = KafkaProducer(bootstrap_servers=bootstrap_servers)
        await kafka_producer.start()
    except Exception as exc:
        logger.warning(f"Failed to initialize Kafka producer: {exc}")
        kafka_producer = None
    else:
        logger.info("Kafka producer initialized")

    yield

    # --- shutdown ---
    if not kafka_producer:
        return
    await kafka_producer.stop()
|
||||
|
||||
|
||||
# ASGI application object; startup and shutdown work is delegated to lifespan().
app = FastAPI(
    lifespan=lifespan,
    title="Users Service",
    version="0.2.0",
    description="User management microservice with MongoDB",
)

# CORS middleware.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended outside development.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
|
||||
|
||||
# Health check
|
||||
@app.get("/health")
async def health_check():
    """Liveness probe: report a static healthy status with the current local time.

    No dependency (database, Kafka) is checked.
    """
    checked_at = datetime.now().isoformat()
    return dict(status="healthy", service="users", timestamp=checked_at)
|
||||
|
||||
# CRUD Operations
|
||||
@app.get("/users", response_model=List[UserResponse])
async def get_users():
    """Return every stored user as a list of ``UserResponse`` objects.

    The whole collection is loaded at once; there is no pagination.
    """
    documents = await User.find_all().to_list()

    payload = []
    for doc in documents:
        payload.append(
            UserResponse(
                id=str(doc.id),
                username=doc.username,
                email=doc.email,
                full_name=doc.full_name,
                profile_picture=doc.profile_picture,
                profile_picture_thumbnail=doc.profile_picture_thumbnail,
                bio=doc.bio,
                location=doc.location,
                website=doc.website,
                is_email_verified=doc.is_email_verified,
                is_active=doc.is_active,
                created_at=doc.created_at,
                updated_at=doc.updated_at,
            )
        )
    return payload
|
||||
|
||||
@app.get("/users/{user_id}", response_model=UserResponse)
async def get_user(user_id: str):
    """Return the user identified by *user_id*.

    Raises:
        HTTPException: 404 when *user_id* is not a valid ObjectId or no such
            user exists. Database errors are no longer reported as 404; they
            propagate and surface as server errors.
    """
    # A malformed ID cannot match any document. Checking it up front replaces
    # the former blanket ``except Exception`` that also hid real failures.
    if not PydanticObjectId.is_valid(user_id):
        raise HTTPException(status_code=404, detail="User not found")

    user = await User.get(PydanticObjectId(user_id))
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    return UserResponse(
        id=str(user.id),
        username=user.username,
        email=user.email,
        full_name=user.full_name,
        profile_picture=user.profile_picture,
        profile_picture_thumbnail=user.profile_picture_thumbnail,
        bio=user.bio,
        location=user.location,
        website=user.website,
        is_email_verified=user.is_email_verified,
        is_active=user.is_active,
        created_at=user.created_at,
        updated_at=user.updated_at
    )
|
||||
|
||||
@app.post("/users", response_model=UserResponse, status_code=201)
async def create_user(user_data: UserCreate):
    """Create a user and publish a ``USER_CREATED`` event.

    Raises:
        HTTPException: 400 when the username is already taken; 422 when the
            payload fails the ``User`` document's validation (for example an
            invalid email address or an over-long bio).

    Event publishing is best-effort: a failure is logged and does not fail
    the request, because the user has already been stored.
    """
    # Check if username already exists
    existing_user = await User.find_one(User.username == user_data.username)
    if existing_user:
        raise HTTPException(status_code=400, detail="Username already exists")

    # Create new user. The document model is stricter than the request model,
    # so turn its validation errors (ValueError subclasses in pydantic v2)
    # into a client error instead of an unhandled 500.
    try:
        user = User(
            username=user_data.username,
            email=user_data.email,
            full_name=user_data.full_name,
            profile_picture=user_data.profile_picture,
            bio=user_data.bio,
            location=user_data.location,
            website=user_data.website
        )
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    await user.create()

    # Publish event
    if kafka_producer:
        try:
            event = Event(
                event_type=EventType.USER_CREATED,
                service="users",
                data={
                    "user_id": str(user.id),
                    "username": user.username,
                    "email": user.email
                },
                user_id=str(user.id)
            )
            await kafka_producer.send_event("user-events", event)
        except Exception:
            logger.exception("Failed to publish USER_CREATED event for user %s", user.id)

    return UserResponse(
        id=str(user.id),
        username=user.username,
        email=user.email,
        full_name=user.full_name,
        profile_picture=user.profile_picture,
        profile_picture_thumbnail=user.profile_picture_thumbnail,
        bio=user.bio,
        location=user.location,
        website=user.website,
        is_email_verified=user.is_email_verified,
        is_active=user.is_active,
        created_at=user.created_at,
        updated_at=user.updated_at
    )
|
||||
|
||||
@app.put("/users/{user_id}", response_model=UserResponse)
async def update_user(user_id: str, user_update: UserUpdate):
    """Apply a partial update to a user and publish a ``USER_UPDATED`` event.

    Only fields that were sent with a non-null value are changed;
    ``updated_at`` is always refreshed.

    Raises:
        HTTPException: 404 when *user_id* is not a valid ObjectId or no such
            user exists; 400 when the new username belongs to another user.

    Event publishing is best-effort: a failure is logged and does not fail
    the request, because the change has already been saved.
    """
    # A malformed ID cannot match any document; database errors are allowed
    # to propagate instead of being reported as "not found".
    if not PydanticObjectId.is_valid(user_id):
        raise HTTPException(status_code=404, detail="User not found")

    user = await User.get(PydanticObjectId(user_id))
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    # Fields the caller actually sent, minus explicit nulls (which are
    # ignored rather than used to clear a value).
    changes = {
        field: value
        for field, value in user_update.model_dump(exclude_unset=True).items()
        if value is not None
    }

    if "username" in changes:
        # Check if new username already exists
        existing_user = await User.find_one(
            User.username == changes["username"],
            User.id != user.id
        )
        if existing_user:
            raise HTTPException(status_code=400, detail="Username already exists")

    # NOTE(review): values are assigned without re-validating the document
    # (e.g. the email format) — confirm whether validation on save is enabled.
    for field, value in changes.items():
        setattr(user, field, value)

    user.updated_at = datetime.now()
    await user.save()

    # Publish event
    if kafka_producer:
        try:
            event = Event(
                event_type=EventType.USER_UPDATED,
                service="users",
                data={
                    "user_id": str(user.id),
                    "username": user.username,
                    "email": user.email,
                    # Only fields that were really applied are reported.
                    "updated_fields": list(changes)
                },
                user_id=str(user.id)
            )
            await kafka_producer.send_event("user-events", event)
        except Exception:
            logger.exception("Failed to publish USER_UPDATED event for user %s", user.id)

    return UserResponse(
        id=str(user.id),
        username=user.username,
        email=user.email,
        full_name=user.full_name,
        profile_picture=user.profile_picture,
        profile_picture_thumbnail=user.profile_picture_thumbnail,
        bio=user.bio,
        location=user.location,
        website=user.website,
        is_email_verified=user.is_email_verified,
        is_active=user.is_active,
        created_at=user.created_at,
        updated_at=user.updated_at
    )
|
||||
|
||||
@app.delete("/users/{user_id}")
async def delete_user(user_id: str):
    """Delete a user and publish a ``USER_DELETED`` event.

    Raises:
        HTTPException: 404 when *user_id* is not a valid ObjectId or no such
            user exists.

    Event publishing is best-effort: a failure is logged and the request
    still succeeds, because the user has already been removed.
    """
    # Only a missing or malformed ID is "not found". The former blanket
    # ``except Exception`` around the whole body also answered 404 when the
    # delete itself or the event publish failed.
    if not PydanticObjectId.is_valid(user_id):
        raise HTTPException(status_code=404, detail="User not found")

    user = await User.get(PydanticObjectId(user_id))
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    # Capture identifying data before the document is removed.
    user_id_str = str(user.id)
    username = user.username

    await user.delete()

    # Publish event
    if kafka_producer:
        try:
            event = Event(
                event_type=EventType.USER_DELETED,
                service="users",
                data={
                    "user_id": user_id_str,
                    "username": username
                },
                user_id=user_id_str
            )
            await kafka_producer.send_event("user-events", event)
        except Exception:
            logger.exception("Failed to publish USER_DELETED event for user %s", user_id_str)

    return {"message": "User deleted successfully"}
|
||||
|
||||
if __name__ == "__main__":
    # Direct execution: serve the app on all interfaces, port 8000.
    # NOTE(review): reload=True enables uvicorn's auto-reloader, a development
    # feature — confirm it is wanted when this runs inside the container.
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|
||||
31
services/users/backend/models.py
Normal file
31
services/users/backend/models.py
Normal file
@ -0,0 +1,31 @@
|
||||
from beanie import Document
|
||||
from pydantic import EmailStr, Field, HttpUrl
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class User(Document):
    """User document stored in MongoDB through Beanie.

    ``username`` and ``email`` are required; profile fields are optional.
    New users start with an unverified email and an active account, and
    both timestamps default to the (naive, local) creation time.
    """
    # NOTE(review): unique=True is passed to pydantic's Field; this probably
    # does not create a unique index in MongoDB — confirm, and consider
    # Beanie's Indexed(...) if database-level uniqueness is required.
    username: str = Field(..., unique=True)
    email: EmailStr
    full_name: Optional[str] = None
    # Profile picture URL
    profile_picture: Optional[str] = Field(None, description="프로필 사진 URL")
    # Profile picture thumbnail URL
    profile_picture_thumbnail: Optional[str] = Field(None, description="프로필 사진 썸네일 URL")
    # Self-introduction, at most 500 characters
    bio: Optional[str] = Field(None, max_length=500, description="자기소개")
    # Location
    location: Optional[str] = Field(None, description="위치")
    # Personal website
    website: Optional[str] = Field(None, description="개인 웹사이트")
    # Whether the email address has been verified
    is_email_verified: bool = Field(default=False, description="이메일 인증 여부")
    # Whether the account is active
    is_active: bool = Field(default=True, description="계정 활성화 상태")
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)

    class Settings:
        # NOTE(review): Beanie appears to read the collection name from
        # ``Settings.name``; confirm that ``collection`` is honored by the
        # pinned Beanie version, otherwise a default name is used.
        collection = "users"

    class Config:
        # Example payload shown in the generated JSON schema.
        json_schema_extra = {
            "example": {
                "username": "john_doe",
                "email": "john@example.com",
                "full_name": "John Doe"
            }
        }
|
||||
7
services/users/backend/requirements.txt
Normal file
7
services/users/backend/requirements.txt
Normal file
@ -0,0 +1,7 @@
|
||||
fastapi==0.109.0
|
||||
uvicorn[standard]==0.27.0
|
||||
pydantic[email]==2.5.3
|
||||
pymongo==4.6.1
|
||||
motor==3.3.2
|
||||
beanie==1.23.6
|
||||
aiokafka==0.10.0
|
||||
Reference in New Issue
Block a user