commit c8802cfc654b40eba72814cd4a28906891ba7c18 Author: jungwoo choi Date: Mon Nov 10 09:57:19 2025 +0900 Initial commit: mBART Translation API with Docker support - FastAPI 기반 다국어 번역 REST API 서비스 - mBART-50 모델을 사용한 18개 언어 지원 - Docker 및 Docker Compose 설정 포함 - GPU/CPU 지원 - 헬스 체크 및 API 문서 자동 생성 - 외부 접속 지원 (172.30.1.2:8000) 주요 파일: - main.py: FastAPI 애플리케이션 - translator.py: mBART 번역 서비스 - models.py: Pydantic 데이터 모델 - config.py: 환경 설정 - Dockerfile: 최적화된 Docker 이미지 - docker-compose.yml: 간편한 배포 설정 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..96b780b --- /dev/null +++ b/.dockerignore @@ -0,0 +1,55 @@ +# Python cache and compiled files +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +*.egg-info/ +.installed.cfg +*.egg + +# Virtual environments +env/ +venv/ +ENV/ + +# Models (will be downloaded at runtime) +models/ +*.bin +*.pt +*.pth + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# Git +.git/ +.gitignore + +# Environment +.env +.env.local + +# Logs +*.log +logs/ + +# Documentation (not needed in container) +README.md +CLAUDE.md + +# Docker files (not needed inside container) +Dockerfile +.dockerignore +docker-compose.yml + +# OS +.DS_Store +Thumbs.db + +# Cache +.cache/ diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..f93ede0 --- /dev/null +++ b/.env.example @@ -0,0 +1,8 @@ +# Server Configuration +HOST=0.0.0.0 +PORT=8000 + +# Model Configuration +MODEL_NAME=facebook/mbart-large-50-many-to-many-mmt +MAX_LENGTH=512 +DEVICE=cuda # or cpu diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9fe26ac --- /dev/null +++ b/.gitignore @@ -0,0 +1,49 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +env/ +venv/ +ENV/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Models +models/ +*.bin +*.pt +*.pth + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# Environment +.env +.env.local + +# Logs +*.log +logs/ + +# OS +.DS_Store +Thumbs.db diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..c8b771a --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,151 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +mBART Translation API - FastAPI 기반의 다국어 번역 REST API 서비스. Facebook의 mBART-50 모델을 사용하여 18개 이상의 언어 간 번역을 제공합니다. + +## Commands + +### Docker (권장) + +```bash +# Docker Compose로 시작 (가장 쉬움) +docker-compose up -d + +# 로그 확인 +docker-compose logs -f + +# 서비스 중지 +docker-compose down + +# Docker 직접 사용 +docker build -t mbart-translation-api . +docker run -d -p 8000:8000 --name mbart-api mbart-translation-api + +# GPU 지원 +docker run -d --gpus all -p 8000:8000 -e DEVICE=cuda --name mbart-api-gpu mbart-translation-api +``` + +### Local Development + +```bash +# 개발 서버 실행 (자동 리로드) +python main.py + +# 프로덕션 서버 실행 +uvicorn main:app --host 0.0.0.0 --port 8000 + +# GPU 사용하여 실행 +DEVICE=cuda python main.py +``` + +### Installation + +```bash +# 의존성 설치 +pip install -r requirements.txt + +# 가상 환경 사용 권장 +python -m venv venv +source venv/bin/activate # Linux/Mac +# or +venv\Scripts\activate # Windows +pip install -r requirements.txt +``` + +### Testing API + +```bash +# 헬스 체크 +curl http://localhost:8000/health + +# 번역 테스트 (한국어 -> 영어) +curl -X POST "http://localhost:8000/translate" \ + -H "Content-Type: application/json" \ + -d '{"text": "안녕하세요", "source_lang": "ko", "target_lang": "en"}' + +# 지원 언어 목록 조회 +curl http://localhost:8000/languages +``` + +## Architecture + +### Core Components + +1. **main.py** - FastAPI 애플리케이션 + - `/translate` POST: 번역 요청 처리 + - `/health` GET: 서비스 상태 확인 + - `/languages` GET: 지원 언어 목록 + - `lifespan`: 애플리케이션 시작 시 모델 로드 + +2. **translator.py** - MBartTranslator 클래스 + - `load_model()`: mBART 모델과 토크나이저 초기화 + - `translate()`: 실제 번역 수행 + - GPU/CPU 자동 감지 및 할당 + - 모델: facebook/mbart-large-50-many-to-many-mmt + +3. **models.py** - Pydantic 데이터 모델 + - TranslationRequest: 번역 요청 스키마 + - TranslationResponse: 번역 응답 스키마 + - HealthResponse, LanguagesResponse + +4. **config.py** - 설정 관리 + - 환경 변수 기반 설정 + - SUPPORTED_LANGUAGES: 언어 코드 매핑 (예: "ko" -> "ko_KR") + +### Request Flow + +``` +Client Request → FastAPI Endpoint → Validation (Pydantic) + ↓ +Response ← Translated Text ← MBartTranslator.translate() +``` + +### Model Loading + +- 애플리케이션 시작 시 `lifespan` 이벤트에서 모델 로드 +- 첫 실행 시 HuggingFace에서 모델 다운로드 (약 2.4GB) +- 이후 실행에서는 캐시된 모델 사용 (~/.cache/huggingface/) + +### Language Code Mapping + +mBART는 특정 언어 코드 형식을 사용합니다: +- 일반 코드 (ko, en) → mBART 코드 (ko_KR, en_XX) +- config.py의 SUPPORTED_LANGUAGES에 정의 +- 새 언어 추가 시 이 딕셔너리에 추가 필요 + +## Key Implementation Details + +- **Device Selection**: CUDA 사용 가능 시 자동으로 GPU 사용, 아니면 CPU +- **Token Length**: MAX_LENGTH(기본 512)로 입력 길이 제한 +- **Error Handling**: 언어 코드 검증, 모델 로드 상태 확인 +- **CORS**: 모든 origin 허용 (프로덕션에서는 제한 필요) + +## Configuration + +환경 변수로 설정 변경: +- `HOST`, `PORT`: 서버 바인딩 설정 +- `MODEL_NAME`: 사용할 mBART 모델 (다른 변형 가능) +- `MAX_LENGTH`: 최대 토큰 길이 +- `DEVICE`: "cuda" 또는 "cpu" + +## Docker Details + +### Image Structure +- Base: python:3.11-slim +- Non-root user (appuser) for security +- Health check endpoint: /health +- Model cache volume: /home/appuser/.cache/huggingface + +### Volume Management +모델은 약 2.4GB로, 컨테이너 재시작 시 재다운로드 방지를 위해 볼륨 마운트 필수: +```bash +-v mbart-cache:/home/appuser/.cache/huggingface +``` + +### Resource Requirements +- CPU: 최소 8GB RAM (16GB 권장) +- GPU: 최소 8GB VRAM +- Disk: 모델 캐시용 약 5GB 여유 공간 diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..4a9e3e8 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,39 @@ +# Multi-stage build for smaller image size +FROM python:3.11-slim as base + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + build-essential \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Set working directory +WORKDIR /app + +# Copy requirements first for better caching +COPY requirements.txt . + +# Install Python dependencies +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code +COPY . . + +# Create non-root user for security +RUN useradd -m -u 1000 appuser && \ + chown -R appuser:appuser /app && \ + mkdir -p /home/appuser/.cache/huggingface && \ + chown -R appuser:appuser /home/appuser/.cache + +# Switch to non-root user +USER appuser + +# Expose port +EXPOSE 8000 + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \ + CMD curl -f http://localhost:8000/health || exit 1 + +# Run the application +CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/README.md b/README.md new file mode 100644 index 0000000..052b805 --- /dev/null +++ b/README.md @@ -0,0 +1,274 @@ +# mBART Translation API + +mBART 모델을 사용한 다국어 번역 REST API 서비스입니다. + +## 기능 + +- mBART-50 모델을 사용한 다국어 번역 +- RESTful API 인터페이스 +- 18개 이상의 언어 지원 (한국어, 영어, 일본어, 중국어 등) +- GPU/CPU 지원 +- 자동 API 문서화 (Swagger UI) + +## 지원 언어 + +- 한국어 (ko) +- 영어 (en) +- 일본어 (ja) +- 중국어 (zh) +- 스페인어 (es) +- 프랑스어 (fr) +- 독일어 (de) +- 러시아어 (ru) +- 아랍어 (ar) +- 힌디어 (hi) +- 베트남어 (vi) +- 태국어 (th) +- 인도네시아어 (id) +- 터키어 (tr) +- 포르투갈어 (pt) +- 이탈리아어 (it) +- 네덜란드어 (nl) +- 폴란드어 (pl) + +## 빠른 시작 (Docker 권장) + +### Docker Compose 사용 (가장 쉬운 방법) + +```bash +# 1. 서비스 시작 +docker-compose up -d + +# 2. 로그 확인 +docker-compose logs -f + +# 3. 서비스 중지 +docker-compose down +``` + +서비스가 시작되면 http://localhost:8000 에서 API를 사용할 수 있습니다. + +### 외부 접속 + +외부에서 접속하려면 서버의 IP 주소를 사용하세요: +- 로컬: http://localhost:8000 +- 외부: http://172.30.1.2:8000 (서버 IP 주소로 변경) + +### Docker 직접 사용 + +```bash +# 1. 이미지 빌드 +docker build -t mbart-translation-api . + +# 2. 컨테이너 실행 +docker run -d \ + --name mbart-api \ + -p 8000:8000 \ + -e DEVICE=cpu \ + -v mbart-cache:/home/appuser/.cache/huggingface \ + mbart-translation-api + +# 3. 로그 확인 +docker logs -f mbart-api + +# 4. 컨테이너 중지 +docker stop mbart-api +docker rm mbart-api +``` + +### GPU 지원 (NVIDIA GPU) + +GPU를 사용하려면 [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html)이 설치되어 있어야 합니다. + +```bash +docker run -d \ + --name mbart-api-gpu \ + --gpus all \ + -p 8000:8000 \ + -e DEVICE=cuda \ + -v mbart-cache:/home/appuser/.cache/huggingface \ + mbart-translation-api +``` + +또는 docker-compose.yml에서 GPU 서비스 주석을 해제하세요. + +## 로컬 설치 및 실행 + +### 1. 의존성 설치 + +```bash +pip install -r requirements.txt +``` + +### 2. 환경 변수 설정 (선택사항) + +```bash +cp .env.example .env +``` + +`.env` 파일을 편집하여 설정을 변경할 수 있습니다: + +``` +HOST=0.0.0.0 +PORT=8000 +MODEL_NAME=facebook/mbart-large-50-many-to-many-mmt +MAX_LENGTH=512 +DEVICE=cpu # GPU 사용 시 cuda로 변경 +``` + +### 3. 실행 + +#### 개발 모드 + +```bash +python main.py +``` + +#### 프로덕션 모드 + +```bash +uvicorn main:app --host 0.0.0.0 --port 8000 +``` + +#### GPU 사용 + +```bash +DEVICE=cuda python main.py +``` + +## API 사용법 + +### API 문서 + +서버 실행 후 다음 URL에서 자동 생성된 API 문서를 확인할 수 있습니다: +- Swagger UI: http://localhost:8000/docs (또는 http://172.30.1.2:8000/docs) +- ReDoc: http://localhost:8000/redoc (또는 http://172.30.1.2:8000/redoc) + +### 엔드포인트 + +#### 1. 번역 API + +**POST** `/translate` + +요청 예시: + +```bash +# 로컬에서 테스트 +curl -X POST "http://localhost:8000/translate" \ + -H "Content-Type: application/json" \ + -d '{ + "text": "안녕하세요, 반갑습니다.", + "source_lang": "ko", + "target_lang": "en" + }' + +# 외부에서 접속 +curl -X POST "http://172.30.1.2:8000/translate" \ + -H "Content-Type: application/json" \ + -d '{ + "text": "안녕하세요, 반갑습니다.", + "source_lang": "ko", + "target_lang": "en" + }' +``` + +응답 예시: + +```json +{ + "translated_text": "Hello, nice to meet you.", + "source_lang": "ko", + "target_lang": "en", + "original_text": "안녕하세요, 반갑습니다." +} +``` + +#### 2. 헬스 체크 + +**GET** `/health` + +```bash +# 로컬 +curl http://localhost:8000/health + +# 외부 +curl http://172.30.1.2:8000/health +``` + +응답: + +```json +{ + "status": "healthy", + "model_loaded": true, + "device": "cpu" +} +``` + +#### 3. 지원 언어 목록 + +**GET** `/languages` + +```bash +# 로컬 +curl http://localhost:8000/languages + +# 외부 +curl http://172.30.1.2:8000/languages +``` + +## 프로젝트 구조 + +``` +. +├── main.py # FastAPI 애플리케이션 메인 파일 +├── translator.py # mBART 번역 서비스 클래스 +├── models.py # Pydantic 데이터 모델 +├── config.py # 설정 파일 +├── requirements.txt # Python 의존성 +├── Dockerfile # Docker 이미지 빌드 설정 +├── docker-compose.yml # Docker Compose 설정 +├── .dockerignore # Docker 빌드 제외 파일 +├── .env.example # 환경 변수 예시 +├── .gitignore # Git 무시 파일 +├── README.md # 프로젝트 문서 +└── CLAUDE.md # 코드베이스 가이드 +``` + +## 성능 최적화 + +### GPU 사용 + +CUDA가 설치된 환경에서는 GPU를 사용하여 번역 속도를 크게 향상시킬 수 있습니다: + +```bash +DEVICE=cuda python main.py +``` + +### 모델 캐싱 + +첫 실행 시 모델이 다운로드되며, 이후 실행에서는 캐시된 모델을 사용합니다. + +## 문제 해결 + +### 메모리 부족 + +mBART 모델은 크기가 크므로 충분한 메모리가 필요합니다: +- CPU: 최소 8GB RAM 권장 +- GPU: 최소 8GB VRAM 권장 + +메모리가 부족한 경우 `MAX_LENGTH` 값을 줄이거나 더 작은 모델을 사용하세요. + +### CUDA 오류 + +GPU 사용 시 CUDA 관련 오류가 발생하면: + +```bash +DEVICE=cpu python main.py +``` + +CPU 모드로 전환하여 실행하세요. + +## 라이선스 + +이 프로젝트는 mBART 모델 (Facebook AI)을 사용합니다. diff --git a/config.py b/config.py new file mode 100644 index 0000000..4d3ad55 --- /dev/null +++ b/config.py @@ -0,0 +1,38 @@ +import os +from typing import Optional + +class Config: + """Application configuration""" + + # Server settings + HOST: str = os.getenv("HOST", "0.0.0.0") + PORT: int = int(os.getenv("PORT", "8000")) + + # Model settings + MODEL_NAME: str = os.getenv("MODEL_NAME", "facebook/mbart-large-50-many-to-many-mmt") + MAX_LENGTH: int = int(os.getenv("MAX_LENGTH", "512")) + DEVICE: str = os.getenv("DEVICE", "cpu") + + # Supported languages for mBART-50 + SUPPORTED_LANGUAGES = { + "ko": "ko_KR", # Korean + "en": "en_XX", # English + "ja": "ja_XX", # Japanese + "zh": "zh_CN", # Chinese (Simplified) + "es": "es_XX", # Spanish + "fr": "fr_XX", # French + "de": "de_DE", # German + "ru": "ru_RU", # Russian + "ar": "ar_AR", # Arabic + "hi": "hi_IN", # Hindi + "vi": "vi_VN", # Vietnamese + "th": "th_TH", # Thai + "id": "id_ID", # Indonesian + "tr": "tr_TR", # Turkish + "pt": "pt_XX", # Portuguese + "it": "it_IT", # Italian + "nl": "nl_XX", # Dutch + "pl": "pl_PL", # Polish + } + +config = Config() diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..4cc0919 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,55 @@ +version: '3.8' + +services: + mbart-api: + build: + context: . + dockerfile: Dockerfile + container_name: mbart-translation-api + ports: + - "8000:8000" + environment: + - HOST=0.0.0.0 + - PORT=8000 + - MODEL_NAME=facebook/mbart-large-50-many-to-many-mmt + - MAX_LENGTH=512 + - DEVICE=cpu # Change to 'cuda' for GPU support + volumes: + # Cache HuggingFace models to avoid re-downloading + - huggingface-cache:/home/appuser/.cache/huggingface + restart: unless-stopped + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8000/health"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 60s + + # GPU support (uncomment if you have NVIDIA GPU) + # mbart-api-gpu: + # build: + # context: . + # dockerfile: Dockerfile + # container_name: mbart-translation-api-gpu + # ports: + # - "8000:8000" + # environment: + # - HOST=0.0.0.0 + # - PORT=8000 + # - MODEL_NAME=facebook/mbart-large-50-many-to-many-mmt + # - MAX_LENGTH=512 + # - DEVICE=cuda + # volumes: + # - huggingface-cache:/home/appuser/.cache/huggingface + # deploy: + # resources: + # reservations: + # devices: + # - driver: nvidia + # count: 1 + # capabilities: [gpu] + # restart: unless-stopped + +volumes: + huggingface-cache: + driver: local diff --git a/main.py b/main.py new file mode 100644 index 0000000..c363516 --- /dev/null +++ b/main.py @@ -0,0 +1,143 @@ +from fastapi import FastAPI, HTTPException, status +from fastapi.middleware.cors import CORSMiddleware +from contextlib import asynccontextmanager +import logging + +from config import config +from translator import translator +from models import ( + TranslationRequest, + TranslationResponse, + HealthResponse, + LanguagesResponse, +) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """애플리케이션 시작 시 모델을 로드합니다.""" + logger.info("Starting up: Loading mBART model...") + try: + translator.load_model() + logger.info("Model loaded successfully") + except Exception as e: + logger.error(f"Failed to load model: {str(e)}") + raise + yield + logger.info("Shutting down...") + + +app = FastAPI( + title="mBART Translation API", + description="mBART 모델을 사용한 다국어 번역 API 서비스", + version="1.0.0", + lifespan=lifespan, +) + +# CORS 설정 +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +@app.get("/", tags=["Root"]) +async def root(): + """API 루트 엔드포인트""" + return { + "message": "mBART Translation API", + "docs": "/docs", + "health": "/health", + } + + +@app.get("/health", response_model=HealthResponse, tags=["Health"]) +async def health_check(): + """서비스 헬스 체크""" + return HealthResponse( + status="healthy" if translator.is_ready() else "not ready", + model_loaded=translator.is_ready(), + device=translator.device, + ) + + +@app.get("/languages", response_model=LanguagesResponse, tags=["Languages"]) +async def get_supported_languages(): + """지원하는 언어 목록 조회""" + return LanguagesResponse(supported_languages=config.SUPPORTED_LANGUAGES) + + +@app.post( + "/translate", + response_model=TranslationResponse, + tags=["Translation"], + status_code=status.HTTP_200_OK, +) +async def translate_text(request: TranslationRequest): + """ + 텍스트 번역 API + + - **text**: 번역할 텍스트 + - **source_lang**: 소스 언어 코드 (예: ko, en, ja) + - **target_lang**: 타겟 언어 코드 (예: en, ko, ja) + """ + if not translator.is_ready(): + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Translation model is not ready", + ) + + # 언어 코드 검증 + if request.source_lang not in config.SUPPORTED_LANGUAGES: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unsupported source language: {request.source_lang}", + ) + + if request.target_lang not in config.SUPPORTED_LANGUAGES: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unsupported target language: {request.target_lang}", + ) + + try: + translated_text = translator.translate( + text=request.text, + source_lang=request.source_lang, + target_lang=request.target_lang, + ) + + return TranslationResponse( + translated_text=translated_text, + source_lang=request.source_lang, + target_lang=request.target_lang, + original_text=request.text, + ) + + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=str(e) + ) + except Exception as e: + logger.error(f"Translation failed: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Translation failed", + ) + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run( + "main:app", + host=config.HOST, + port=config.PORT, + reload=True, + ) diff --git a/models.py b/models.py new file mode 100644 index 0000000..89b8eeb --- /dev/null +++ b/models.py @@ -0,0 +1,52 @@ +from pydantic import BaseModel, Field +from typing import Optional + + +class TranslationRequest(BaseModel): + """번역 요청 모델""" + + text: str = Field(..., description="번역할 텍스트", min_length=1) + source_lang: str = Field(..., description="소스 언어 코드 (예: ko, en, ja)") + target_lang: str = Field(..., description="타겟 언어 코드 (예: en, ko, ja)") + + class Config: + json_schema_extra = { + "example": { + "text": "안녕하세요, 반갑습니다.", + "source_lang": "ko", + "target_lang": "en", + } + } + + +class TranslationResponse(BaseModel): + """번역 응답 모델""" + + translated_text: str = Field(..., description="번역된 텍스트") + source_lang: str = Field(..., description="소스 언어 코드") + target_lang: str = Field(..., description="타겟 언어 코드") + original_text: str = Field(..., description="원본 텍스트") + + class Config: + json_schema_extra = { + "example": { + "translated_text": "Hello, nice to meet you.", + "source_lang": "ko", + "target_lang": "en", + "original_text": "안녕하세요, 반갑습니다.", + } + } + + +class HealthResponse(BaseModel): + """헬스 체크 응답 모델""" + + status: str = Field(..., description="서비스 상태") + model_loaded: bool = Field(..., description="모델 로드 여부") + device: str = Field(..., description="사용 중인 디바이스 (cpu/cuda)") + + +class LanguagesResponse(BaseModel): + """지원 언어 목록 응답 모델""" + + supported_languages: dict = Field(..., description="지원하는 언어 코드 목록") diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..510a87b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +fastapi==0.104.1 +uvicorn[standard]==0.24.0 +transformers==4.35.2 +torch==2.1.1 +sentencepiece==0.1.99 +protobuf==4.25.1 +pydantic==2.5.0 +python-multipart==0.0.6 diff --git a/translator.py b/translator.py new file mode 100644 index 0000000..f4666fe --- /dev/null +++ b/translator.py @@ -0,0 +1,105 @@ +import torch +from transformers import MBartForConditionalGeneration, MBart50TokenizerFast +from typing import Optional +import logging + +from config import config + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class MBartTranslator: + """mBART 모델을 사용한 번역 서비스""" + + def __init__(self): + self.model: Optional[MBartForConditionalGeneration] = None + self.tokenizer: Optional[MBart50TokenizerFast] = None + self.device = config.DEVICE + self.model_name = config.MODEL_NAME + self.max_length = config.MAX_LENGTH + + def load_model(self): + """모델과 토크나이저를 로드합니다.""" + try: + logger.info(f"Loading mBART model: {self.model_name}") + self.tokenizer = MBart50TokenizerFast.from_pretrained(self.model_name) + self.model = MBartForConditionalGeneration.from_pretrained(self.model_name) + + # GPU 사용 가능 시 모델을 GPU로 이동 + if self.device == "cuda" and torch.cuda.is_available(): + self.model = self.model.to("cuda") + logger.info("Model loaded on CUDA") + else: + self.device = "cpu" + logger.info("Model loaded on CPU") + + logger.info("Model loaded successfully") + except Exception as e: + logger.error(f"Error loading model: {str(e)}") + raise + + def translate( + self, text: str, source_lang: str, target_lang: str + ) -> str: + """ + 텍스트를 번역합니다. + + Args: + text: 번역할 텍스트 + source_lang: 소스 언어 코드 (예: "ko", "en") + target_lang: 타겟 언어 코드 (예: "en", "ko") + + Returns: + 번역된 텍스트 + """ + if self.model is None or self.tokenizer is None: + raise RuntimeError("Model not loaded. Call load_model() first.") + + # 언어 코드를 mBART 형식으로 변환 + source_lang_code = config.SUPPORTED_LANGUAGES.get(source_lang) + target_lang_code = config.SUPPORTED_LANGUAGES.get(target_lang) + + if not source_lang_code or not target_lang_code: + raise ValueError( + f"Unsupported language. Source: {source_lang}, Target: {target_lang}" + ) + + try: + # 소스 언어 설정 + self.tokenizer.src_lang = source_lang_code + + # 입력 텍스트 토큰화 + encoded = self.tokenizer( + text, return_tensors="pt", max_length=self.max_length, truncation=True + ) + + # GPU 사용 시 입력도 GPU로 이동 + if self.device == "cuda": + encoded = {k: v.to("cuda") for k, v in encoded.items()} + + # 번역 생성 + generated_tokens = self.model.generate( + **encoded, + forced_bos_token_id=self.tokenizer.lang_code_to_id[target_lang_code], + max_length=self.max_length, + ) + + # 디코딩 + translated_text = self.tokenizer.batch_decode( + generated_tokens, skip_special_tokens=True + )[0] + + return translated_text + + except Exception as e: + logger.error(f"Translation error: {str(e)}") + raise + + def is_ready(self) -> bool: + """모델이 로드되어 사용 가능한지 확인합니다.""" + return self.model is not None and self.tokenizer is not None + + +# 글로벌 translator 인스턴스 +translator = MBartTranslator()