From c8802cfc654b40eba72814cd4a28906891ba7c18 Mon Sep 17 00:00:00 2001 From: jungwoo choi Date: Mon, 10 Nov 2025 09:57:19 +0900 Subject: [PATCH] Initial commit: mBART Translation API with Docker support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - FastAPI 기반 다국어 번역 REST API 서비스 - mBART-50 모델을 사용한 18개 언어 지원 - Docker 및 Docker Compose 설정 포함 - GPU/CPU 지원 - 헬스 체크 및 API 문서 자동 생성 - 외부 접속 지원 (172.30.1.2:8000) 주요 파일: - main.py: FastAPI 애플리케이션 - translator.py: mBART 번역 서비스 - models.py: Pydantic 데이터 모델 - config.py: 환경 설정 - Dockerfile: 최적화된 Docker 이미지 - docker-compose.yml: 간편한 배포 설정 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- .dockerignore | 55 +++++++++ .env.example | 8 ++ .gitignore | 49 ++++++++ CLAUDE.md | 151 +++++++++++++++++++++++++ Dockerfile | 39 +++++++ README.md | 274 +++++++++++++++++++++++++++++++++++++++++++++ config.py | 38 +++++++ docker-compose.yml | 55 +++++++++ main.py | 143 +++++++++++++++++++++++ models.py | 52 +++++++++ requirements.txt | 8 ++ translator.py | 105 +++++++++++++++++ 12 files changed, 977 insertions(+) create mode 100644 .dockerignore create mode 100644 .env.example create mode 100644 .gitignore create mode 100644 CLAUDE.md create mode 100644 Dockerfile create mode 100644 README.md create mode 100644 config.py create mode 100644 docker-compose.yml create mode 100644 main.py create mode 100644 models.py create mode 100644 requirements.txt create mode 100644 translator.py diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..96b780b --- /dev/null +++ b/.dockerignore @@ -0,0 +1,55 @@ +# Python cache and compiled files +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +*.egg-info/ +.installed.cfg +*.egg + +# Virtual environments +env/ +venv/ +ENV/ + +# Models (will be downloaded at runtime) +models/ +*.bin +*.pt +*.pth + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# Git +.git/ +.gitignore + +# Environment +.env +.env.local + +# Logs +*.log +logs/ + +# Documentation (not needed in container) +README.md +CLAUDE.md + +# Docker files (not needed inside container) +Dockerfile +.dockerignore +docker-compose.yml + +# OS +.DS_Store +Thumbs.db + +# Cache +.cache/ diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..f93ede0 --- /dev/null +++ b/.env.example @@ -0,0 +1,8 @@ +# Server Configuration +HOST=0.0.0.0 +PORT=8000 + +# Model Configuration +MODEL_NAME=facebook/mbart-large-50-many-to-many-mmt +MAX_LENGTH=512 +DEVICE=cuda # or cpu diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9fe26ac --- /dev/null +++ b/.gitignore @@ -0,0 +1,49 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +env/ +venv/ +ENV/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Models +models/ +*.bin +*.pt +*.pth + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# Environment +.env +.env.local + +# Logs +*.log +logs/ + +# OS +.DS_Store +Thumbs.db diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..c8b771a --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,151 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +mBART Translation API - FastAPI 기반의 다국어 번역 REST API 서비스. Facebook의 mBART-50 모델을 사용하여 18개 이상의 언어 간 번역을 제공합니다. + +## Commands + +### Docker (권장) + +```bash +# Docker Compose로 시작 (가장 쉬움) +docker-compose up -d + +# 로그 확인 +docker-compose logs -f + +# 서비스 중지 +docker-compose down + +# Docker 직접 사용 +docker build -t mbart-translation-api . +docker run -d -p 8000:8000 --name mbart-api mbart-translation-api + +# GPU 지원 +docker run -d --gpus all -p 8000:8000 -e DEVICE=cuda --name mbart-api-gpu mbart-translation-api +``` + +### Local Development + +```bash +# 개발 서버 실행 (자동 리로드) +python main.py + +# 프로덕션 서버 실행 +uvicorn main:app --host 0.0.0.0 --port 8000 + +# GPU 사용하여 실행 +DEVICE=cuda python main.py +``` + +### Installation + +```bash +# 의존성 설치 +pip install -r requirements.txt + +# 가상 환경 사용 권장 +python -m venv venv +source venv/bin/activate # Linux/Mac +# or +venv\Scripts\activate # Windows +pip install -r requirements.txt +``` + +### Testing API + +```bash +# 헬스 체크 +curl http://localhost:8000/health + +# 번역 테스트 (한국어 -> 영어) +curl -X POST "http://localhost:8000/translate" \ + -H "Content-Type: application/json" \ + -d '{"text": "안녕하세요", "source_lang": "ko", "target_lang": "en"}' + +# 지원 언어 목록 조회 +curl http://localhost:8000/languages +``` + +## Architecture + +### Core Components + +1. **main.py** - FastAPI 애플리케이션 + - `/translate` POST: 번역 요청 처리 + - `/health` GET: 서비스 상태 확인 + - `/languages` GET: 지원 언어 목록 + - `lifespan`: 애플리케이션 시작 시 모델 로드 + +2. **translator.py** - MBartTranslator 클래스 + - `load_model()`: mBART 모델과 토크나이저 초기화 + - `translate()`: 실제 번역 수행 + - GPU/CPU 자동 감지 및 할당 + - 모델: facebook/mbart-large-50-many-to-many-mmt + +3. **models.py** - Pydantic 데이터 모델 + - TranslationRequest: 번역 요청 스키마 + - TranslationResponse: 번역 응답 스키마 + - HealthResponse, LanguagesResponse + +4. **config.py** - 설정 관리 + - 환경 변수 기반 설정 + - SUPPORTED_LANGUAGES: 언어 코드 매핑 (예: "ko" -> "ko_KR") + +### Request Flow + +``` +Client Request → FastAPI Endpoint → Validation (Pydantic) + ↓ +Response ← Translated Text ← MBartTranslator.translate() +``` + +### Model Loading + +- 애플리케이션 시작 시 `lifespan` 이벤트에서 모델 로드 +- 첫 실행 시 HuggingFace에서 모델 다운로드 (약 2.4GB) +- 이후 실행에서는 캐시된 모델 사용 (~/.cache/huggingface/) + +### Language Code Mapping + +mBART는 특정 언어 코드 형식을 사용합니다: +- 일반 코드 (ko, en) → mBART 코드 (ko_KR, en_XX) +- config.py의 SUPPORTED_LANGUAGES에 정의 +- 새 언어 추가 시 이 딕셔너리에 추가 필요 + +## Key Implementation Details + +- **Device Selection**: CUDA 사용 가능 시 자동으로 GPU 사용, 아니면 CPU +- **Token Length**: MAX_LENGTH(기본 512)로 입력 길이 제한 +- **Error Handling**: 언어 코드 검증, 모델 로드 상태 확인 +- **CORS**: 모든 origin 허용 (프로덕션에서는 제한 필요) + +## Configuration + +환경 변수로 설정 변경: +- `HOST`, `PORT`: 서버 바인딩 설정 +- `MODEL_NAME`: 사용할 mBART 모델 (다른 변형 가능) +- `MAX_LENGTH`: 최대 토큰 길이 +- `DEVICE`: "cuda" 또는 "cpu" + +## Docker Details + +### Image Structure +- Base: python:3.11-slim +- Non-root user (appuser) for security +- Health check endpoint: /health +- Model cache volume: /home/appuser/.cache/huggingface + +### Volume Management +모델은 약 2.4GB로, 컨테이너 재시작 시 재다운로드 방지를 위해 볼륨 마운트 필수: +```bash +-v mbart-cache:/home/appuser/.cache/huggingface +``` + +### Resource Requirements +- CPU: 최소 8GB RAM (16GB 권장) +- GPU: 최소 8GB VRAM +- Disk: 모델 캐시용 약 5GB 여유 공간 diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..4a9e3e8 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,39 @@ +# Multi-stage build for smaller image size +FROM python:3.11-slim as base + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + build-essential \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Set working directory +WORKDIR /app + +# Copy requirements first for better caching +COPY requirements.txt . + +# Install Python dependencies +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code +COPY . . + +# Create non-root user for security +RUN useradd -m -u 1000 appuser && \ + chown -R appuser:appuser /app && \ + mkdir -p /home/appuser/.cache/huggingface && \ + chown -R appuser:appuser /home/appuser/.cache + +# Switch to non-root user +USER appuser + +# Expose port +EXPOSE 8000 + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \ + CMD curl -f http://localhost:8000/health || exit 1 + +# Run the application +CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/README.md b/README.md new file mode 100644 index 0000000..052b805 --- /dev/null +++ b/README.md @@ -0,0 +1,274 @@ +# mBART Translation API + +mBART 모델을 사용한 다국어 번역 REST API 서비스입니다. + +## 기능 + +- mBART-50 모델을 사용한 다국어 번역 +- RESTful API 인터페이스 +- 18개 이상의 언어 지원 (한국어, 영어, 일본어, 중국어 등) +- GPU/CPU 지원 +- 자동 API 문서화 (Swagger UI) + +## 지원 언어 + +- 한국어 (ko) +- 영어 (en) +- 일본어 (ja) +- 중국어 (zh) +- 스페인어 (es) +- 프랑스어 (fr) +- 독일어 (de) +- 러시아어 (ru) +- 아랍어 (ar) +- 힌디어 (hi) +- 베트남어 (vi) +- 태국어 (th) +- 인도네시아어 (id) +- 터키어 (tr) +- 포르투갈어 (pt) +- 이탈리아어 (it) +- 네덜란드어 (nl) +- 폴란드어 (pl) + +## 빠른 시작 (Docker 권장) + +### Docker Compose 사용 (가장 쉬운 방법) + +```bash +# 1. 서비스 시작 +docker-compose up -d + +# 2. 로그 확인 +docker-compose logs -f + +# 3. 서비스 중지 +docker-compose down +``` + +서비스가 시작되면 http://localhost:8000 에서 API를 사용할 수 있습니다. + +### 외부 접속 + +외부에서 접속하려면 서버의 IP 주소를 사용하세요: +- 로컬: http://localhost:8000 +- 외부: http://172.30.1.2:8000 (서버 IP 주소로 변경) + +### Docker 직접 사용 + +```bash +# 1. 이미지 빌드 +docker build -t mbart-translation-api . + +# 2. 컨테이너 실행 +docker run -d \ + --name mbart-api \ + -p 8000:8000 \ + -e DEVICE=cpu \ + -v mbart-cache:/home/appuser/.cache/huggingface \ + mbart-translation-api + +# 3. 로그 확인 +docker logs -f mbart-api + +# 4. 컨테이너 중지 +docker stop mbart-api +docker rm mbart-api +``` + +### GPU 지원 (NVIDIA GPU) + +GPU를 사용하려면 [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html)이 설치되어 있어야 합니다. + +```bash +docker run -d \ + --name mbart-api-gpu \ + --gpus all \ + -p 8000:8000 \ + -e DEVICE=cuda \ + -v mbart-cache:/home/appuser/.cache/huggingface \ + mbart-translation-api +``` + +또는 docker-compose.yml에서 GPU 서비스 주석을 해제하세요. + +## 로컬 설치 및 실행 + +### 1. 의존성 설치 + +```bash +pip install -r requirements.txt +``` + +### 2. 환경 변수 설정 (선택사항) + +```bash +cp .env.example .env +``` + +`.env` 파일을 편집하여 설정을 변경할 수 있습니다: + +``` +HOST=0.0.0.0 +PORT=8000 +MODEL_NAME=facebook/mbart-large-50-many-to-many-mmt +MAX_LENGTH=512 +DEVICE=cpu # GPU 사용 시 cuda로 변경 +``` + +### 3. 실행 + +#### 개발 모드 + +```bash +python main.py +``` + +#### 프로덕션 모드 + +```bash +uvicorn main:app --host 0.0.0.0 --port 8000 +``` + +#### GPU 사용 + +```bash +DEVICE=cuda python main.py +``` + +## API 사용법 + +### API 문서 + +서버 실행 후 다음 URL에서 자동 생성된 API 문서를 확인할 수 있습니다: +- Swagger UI: http://localhost:8000/docs (또는 http://172.30.1.2:8000/docs) +- ReDoc: http://localhost:8000/redoc (또는 http://172.30.1.2:8000/redoc) + +### 엔드포인트 + +#### 1. 번역 API + +**POST** `/translate` + +요청 예시: + +```bash +# 로컬에서 테스트 +curl -X POST "http://localhost:8000/translate" \ + -H "Content-Type: application/json" \ + -d '{ + "text": "안녕하세요, 반갑습니다.", + "source_lang": "ko", + "target_lang": "en" + }' + +# 외부에서 접속 +curl -X POST "http://172.30.1.2:8000/translate" \ + -H "Content-Type: application/json" \ + -d '{ + "text": "안녕하세요, 반갑습니다.", + "source_lang": "ko", + "target_lang": "en" + }' +``` + +응답 예시: + +```json +{ + "translated_text": "Hello, nice to meet you.", + "source_lang": "ko", + "target_lang": "en", + "original_text": "안녕하세요, 반갑습니다." +} +``` + +#### 2. 헬스 체크 + +**GET** `/health` + +```bash +# 로컬 +curl http://localhost:8000/health + +# 외부 +curl http://172.30.1.2:8000/health +``` + +응답: + +```json +{ + "status": "healthy", + "model_loaded": true, + "device": "cpu" +} +``` + +#### 3. 지원 언어 목록 + +**GET** `/languages` + +```bash +# 로컬 +curl http://localhost:8000/languages + +# 외부 +curl http://172.30.1.2:8000/languages +``` + +## 프로젝트 구조 + +``` +. +├── main.py # FastAPI 애플리케이션 메인 파일 +├── translator.py # mBART 번역 서비스 클래스 +├── models.py # Pydantic 데이터 모델 +├── config.py # 설정 파일 +├── requirements.txt # Python 의존성 +├── Dockerfile # Docker 이미지 빌드 설정 +├── docker-compose.yml # Docker Compose 설정 +├── .dockerignore # Docker 빌드 제외 파일 +├── .env.example # 환경 변수 예시 +├── .gitignore # Git 무시 파일 +├── README.md # 프로젝트 문서 +└── CLAUDE.md # 코드베이스 가이드 +``` + +## 성능 최적화 + +### GPU 사용 + +CUDA가 설치된 환경에서는 GPU를 사용하여 번역 속도를 크게 향상시킬 수 있습니다: + +```bash +DEVICE=cuda python main.py +``` + +### 모델 캐싱 + +첫 실행 시 모델이 다운로드되며, 이후 실행에서는 캐시된 모델을 사용합니다. + +## 문제 해결 + +### 메모리 부족 + +mBART 모델은 크기가 크므로 충분한 메모리가 필요합니다: +- CPU: 최소 8GB RAM 권장 +- GPU: 최소 8GB VRAM 권장 + +메모리가 부족한 경우 `MAX_LENGTH` 값을 줄이거나 더 작은 모델을 사용하세요. + +### CUDA 오류 + +GPU 사용 시 CUDA 관련 오류가 발생하면: + +```bash +DEVICE=cpu python main.py +``` + +CPU 모드로 전환하여 실행하세요. + +## 라이선스 + +이 프로젝트는 mBART 모델 (Facebook AI)을 사용합니다. diff --git a/config.py b/config.py new file mode 100644 index 0000000..4d3ad55 --- /dev/null +++ b/config.py @@ -0,0 +1,38 @@ +import os +from typing import Optional + +class Config: + """Application configuration""" + + # Server settings + HOST: str = os.getenv("HOST", "0.0.0.0") + PORT: int = int(os.getenv("PORT", "8000")) + + # Model settings + MODEL_NAME: str = os.getenv("MODEL_NAME", "facebook/mbart-large-50-many-to-many-mmt") + MAX_LENGTH: int = int(os.getenv("MAX_LENGTH", "512")) + DEVICE: str = os.getenv("DEVICE", "cpu") + + # Supported languages for mBART-50 + SUPPORTED_LANGUAGES = { + "ko": "ko_KR", # Korean + "en": "en_XX", # English + "ja": "ja_XX", # Japanese + "zh": "zh_CN", # Chinese (Simplified) + "es": "es_XX", # Spanish + "fr": "fr_XX", # French + "de": "de_DE", # German + "ru": "ru_RU", # Russian + "ar": "ar_AR", # Arabic + "hi": "hi_IN", # Hindi + "vi": "vi_VN", # Vietnamese + "th": "th_TH", # Thai + "id": "id_ID", # Indonesian + "tr": "tr_TR", # Turkish + "pt": "pt_XX", # Portuguese + "it": "it_IT", # Italian + "nl": "nl_XX", # Dutch + "pl": "pl_PL", # Polish + } + +config = Config() diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..4cc0919 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,55 @@ +version: '3.8' + +services: + mbart-api: + build: + context: . + dockerfile: Dockerfile + container_name: mbart-translation-api + ports: + - "8000:8000" + environment: + - HOST=0.0.0.0 + - PORT=8000 + - MODEL_NAME=facebook/mbart-large-50-many-to-many-mmt + - MAX_LENGTH=512 + - DEVICE=cpu # Change to 'cuda' for GPU support + volumes: + # Cache HuggingFace models to avoid re-downloading + - huggingface-cache:/home/appuser/.cache/huggingface + restart: unless-stopped + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8000/health"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 60s + + # GPU support (uncomment if you have NVIDIA GPU) + # mbart-api-gpu: + # build: + # context: . + # dockerfile: Dockerfile + # container_name: mbart-translation-api-gpu + # ports: + # - "8000:8000" + # environment: + # - HOST=0.0.0.0 + # - PORT=8000 + # - MODEL_NAME=facebook/mbart-large-50-many-to-many-mmt + # - MAX_LENGTH=512 + # - DEVICE=cuda + # volumes: + # - huggingface-cache:/home/appuser/.cache/huggingface + # deploy: + # resources: + # reservations: + # devices: + # - driver: nvidia + # count: 1 + # capabilities: [gpu] + # restart: unless-stopped + +volumes: + huggingface-cache: + driver: local diff --git a/main.py b/main.py new file mode 100644 index 0000000..c363516 --- /dev/null +++ b/main.py @@ -0,0 +1,143 @@ +from fastapi import FastAPI, HTTPException, status +from fastapi.middleware.cors import CORSMiddleware +from contextlib import asynccontextmanager +import logging + +from config import config +from translator import translator +from models import ( + TranslationRequest, + TranslationResponse, + HealthResponse, + LanguagesResponse, +) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """애플리케이션 시작 시 모델을 로드합니다.""" + logger.info("Starting up: Loading mBART model...") + try: + translator.load_model() + logger.info("Model loaded successfully") + except Exception as e: + logger.error(f"Failed to load model: {str(e)}") + raise + yield + logger.info("Shutting down...") + + +app = FastAPI( + title="mBART Translation API", + description="mBART 모델을 사용한 다국어 번역 API 서비스", + version="1.0.0", + lifespan=lifespan, +) + +# CORS 설정 +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +@app.get("/", tags=["Root"]) +async def root(): + """API 루트 엔드포인트""" + return { + "message": "mBART Translation API", + "docs": "/docs", + "health": "/health", + } + + +@app.get("/health", response_model=HealthResponse, tags=["Health"]) +async def health_check(): + """서비스 헬스 체크""" + return HealthResponse( + status="healthy" if translator.is_ready() else "not ready", + model_loaded=translator.is_ready(), + device=translator.device, + ) + + +@app.get("/languages", response_model=LanguagesResponse, tags=["Languages"]) +async def get_supported_languages(): + """지원하는 언어 목록 조회""" + return LanguagesResponse(supported_languages=config.SUPPORTED_LANGUAGES) + + +@app.post( + "/translate", + response_model=TranslationResponse, + tags=["Translation"], + status_code=status.HTTP_200_OK, +) +async def translate_text(request: TranslationRequest): + """ + 텍스트 번역 API + + - **text**: 번역할 텍스트 + - **source_lang**: 소스 언어 코드 (예: ko, en, ja) + - **target_lang**: 타겟 언어 코드 (예: en, ko, ja) + """ + if not translator.is_ready(): + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Translation model is not ready", + ) + + # 언어 코드 검증 + if request.source_lang not in config.SUPPORTED_LANGUAGES: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unsupported source language: {request.source_lang}", + ) + + if request.target_lang not in config.SUPPORTED_LANGUAGES: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unsupported target language: {request.target_lang}", + ) + + try: + translated_text = translator.translate( + text=request.text, + source_lang=request.source_lang, + target_lang=request.target_lang, + ) + + return TranslationResponse( + translated_text=translated_text, + source_lang=request.source_lang, + target_lang=request.target_lang, + original_text=request.text, + ) + + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=str(e) + ) + except Exception as e: + logger.error(f"Translation failed: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Translation failed", + ) + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run( + "main:app", + host=config.HOST, + port=config.PORT, + reload=True, + ) diff --git a/models.py b/models.py new file mode 100644 index 0000000..89b8eeb --- /dev/null +++ b/models.py @@ -0,0 +1,52 @@ +from pydantic import BaseModel, Field +from typing import Optional + + +class TranslationRequest(BaseModel): + """번역 요청 모델""" + + text: str = Field(..., description="번역할 텍스트", min_length=1) + source_lang: str = Field(..., description="소스 언어 코드 (예: ko, en, ja)") + target_lang: str = Field(..., description="타겟 언어 코드 (예: en, ko, ja)") + + class Config: + json_schema_extra = { + "example": { + "text": "안녕하세요, 반갑습니다.", + "source_lang": "ko", + "target_lang": "en", + } + } + + +class TranslationResponse(BaseModel): + """번역 응답 모델""" + + translated_text: str = Field(..., description="번역된 텍스트") + source_lang: str = Field(..., description="소스 언어 코드") + target_lang: str = Field(..., description="타겟 언어 코드") + original_text: str = Field(..., description="원본 텍스트") + + class Config: + json_schema_extra = { + "example": { + "translated_text": "Hello, nice to meet you.", + "source_lang": "ko", + "target_lang": "en", + "original_text": "안녕하세요, 반갑습니다.", + } + } + + +class HealthResponse(BaseModel): + """헬스 체크 응답 모델""" + + status: str = Field(..., description="서비스 상태") + model_loaded: bool = Field(..., description="모델 로드 여부") + device: str = Field(..., description="사용 중인 디바이스 (cpu/cuda)") + + +class LanguagesResponse(BaseModel): + """지원 언어 목록 응답 모델""" + + supported_languages: dict = Field(..., description="지원하는 언어 코드 목록") diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..510a87b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +fastapi==0.104.1 +uvicorn[standard]==0.24.0 +transformers==4.35.2 +torch==2.1.1 +sentencepiece==0.1.99 +protobuf==4.25.1 +pydantic==2.5.0 +python-multipart==0.0.6 diff --git a/translator.py b/translator.py new file mode 100644 index 0000000..f4666fe --- /dev/null +++ b/translator.py @@ -0,0 +1,105 @@ +import torch +from transformers import MBartForConditionalGeneration, MBart50TokenizerFast +from typing import Optional +import logging + +from config import config + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class MBartTranslator: + """mBART 모델을 사용한 번역 서비스""" + + def __init__(self): + self.model: Optional[MBartForConditionalGeneration] = None + self.tokenizer: Optional[MBart50TokenizerFast] = None + self.device = config.DEVICE + self.model_name = config.MODEL_NAME + self.max_length = config.MAX_LENGTH + + def load_model(self): + """모델과 토크나이저를 로드합니다.""" + try: + logger.info(f"Loading mBART model: {self.model_name}") + self.tokenizer = MBart50TokenizerFast.from_pretrained(self.model_name) + self.model = MBartForConditionalGeneration.from_pretrained(self.model_name) + + # GPU 사용 가능 시 모델을 GPU로 이동 + if self.device == "cuda" and torch.cuda.is_available(): + self.model = self.model.to("cuda") + logger.info("Model loaded on CUDA") + else: + self.device = "cpu" + logger.info("Model loaded on CPU") + + logger.info("Model loaded successfully") + except Exception as e: + logger.error(f"Error loading model: {str(e)}") + raise + + def translate( + self, text: str, source_lang: str, target_lang: str + ) -> str: + """ + 텍스트를 번역합니다. + + Args: + text: 번역할 텍스트 + source_lang: 소스 언어 코드 (예: "ko", "en") + target_lang: 타겟 언어 코드 (예: "en", "ko") + + Returns: + 번역된 텍스트 + """ + if self.model is None or self.tokenizer is None: + raise RuntimeError("Model not loaded. Call load_model() first.") + + # 언어 코드를 mBART 형식으로 변환 + source_lang_code = config.SUPPORTED_LANGUAGES.get(source_lang) + target_lang_code = config.SUPPORTED_LANGUAGES.get(target_lang) + + if not source_lang_code or not target_lang_code: + raise ValueError( + f"Unsupported language. Source: {source_lang}, Target: {target_lang}" + ) + + try: + # 소스 언어 설정 + self.tokenizer.src_lang = source_lang_code + + # 입력 텍스트 토큰화 + encoded = self.tokenizer( + text, return_tensors="pt", max_length=self.max_length, truncation=True + ) + + # GPU 사용 시 입력도 GPU로 이동 + if self.device == "cuda": + encoded = {k: v.to("cuda") for k, v in encoded.items()} + + # 번역 생성 + generated_tokens = self.model.generate( + **encoded, + forced_bos_token_id=self.tokenizer.lang_code_to_id[target_lang_code], + max_length=self.max_length, + ) + + # 디코딩 + translated_text = self.tokenizer.batch_decode( + generated_tokens, skip_special_tokens=True + )[0] + + return translated_text + + except Exception as e: + logger.error(f"Translation error: {str(e)}") + raise + + def is_ready(self) -> bool: + """모델이 로드되어 사용 가능한지 확인합니다.""" + return self.model is not None and self.tokenizer is not None + + +# 글로벌 translator 인스턴스 +translator = MBartTranslator()