Initial commit: mBART Translation API with Docker support

- FastAPI 기반 다국어 번역 REST API 서비스
- mBART-50 모델을 사용한 18개 언어 지원
- Docker 및 Docker Compose 설정 포함
- GPU/CPU 지원
- 헬스 체크 및 API 문서 자동 생성
- 외부 접속 지원 (172.30.1.2:8000)

주요 파일:
- main.py: FastAPI 애플리케이션
- translator.py: mBART 번역 서비스
- models.py: Pydantic 데이터 모델
- config.py: 환경 설정
- Dockerfile: 최적화된 Docker 이미지
- docker-compose.yml: 간편한 배포 설정

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
jungwoo choi
2025-11-10 09:57:19 +09:00
commit c8802cfc65
12 changed files with 977 additions and 0 deletions

55
.dockerignore Normal file
View File

@ -0,0 +1,55 @@
# Python cache and compiled files
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
*.egg-info/
.installed.cfg
*.egg
# Virtual environments
env/
venv/
ENV/
# Models (will be downloaded at runtime)
models/
*.bin
*.pt
*.pth
# IDE
.vscode/
.idea/
*.swp
*.swo
*~
# Git
.git/
.gitignore
# Environment
.env
.env.local
# Logs
*.log
logs/
# Documentation (not needed in container)
README.md
CLAUDE.md
# Docker files (not needed inside container)
Dockerfile
.dockerignore
docker-compose.yml
# OS
.DS_Store
Thumbs.db
# Cache
.cache/

8
.env.example Normal file
View File

@ -0,0 +1,8 @@
# Server Configuration
HOST=0.0.0.0
PORT=8000
# Model Configuration
MODEL_NAME=facebook/mbart-large-50-many-to-many-mmt
MAX_LENGTH=512
DEVICE=cuda # or cpu

49
.gitignore vendored Normal file
View File

@ -0,0 +1,49 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
env/
venv/
ENV/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# Models
models/
*.bin
*.pt
*.pth
# IDE
.vscode/
.idea/
*.swp
*.swo
*~
# Environment
.env
.env.local
# Logs
*.log
logs/
# OS
.DS_Store
Thumbs.db

151
CLAUDE.md Normal file
View File

@ -0,0 +1,151 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview
mBART Translation API - FastAPI 기반의 다국어 번역 REST API 서비스. Facebook의 mBART-50 모델을 사용하여 18개 이상의 언어 간 번역을 제공합니다.
## Commands
### Docker (권장)
```bash
# Docker Compose로 시작 (가장 쉬움)
docker-compose up -d
# 로그 확인
docker-compose logs -f
# 서비스 중지
docker-compose down
# Docker 직접 사용
docker build -t mbart-translation-api .
docker run -d -p 8000:8000 --name mbart-api mbart-translation-api
# GPU 지원
docker run -d --gpus all -p 8000:8000 -e DEVICE=cuda --name mbart-api-gpu mbart-translation-api
```
### Local Development
```bash
# 개발 서버 실행 (자동 리로드)
python main.py
# 프로덕션 서버 실행
uvicorn main:app --host 0.0.0.0 --port 8000
# GPU 사용하여 실행
DEVICE=cuda python main.py
```
### Installation
```bash
# 의존성 설치
pip install -r requirements.txt
# 가상 환경 사용 권장
python -m venv venv
source venv/bin/activate # Linux/Mac
# or
venv\Scripts\activate # Windows
pip install -r requirements.txt
```
### Testing API
```bash
# 헬스 체크
curl http://localhost:8000/health
# 번역 테스트 (한국어 -> 영어)
curl -X POST "http://localhost:8000/translate" \
-H "Content-Type: application/json" \
-d '{"text": "안녕하세요", "source_lang": "ko", "target_lang": "en"}'
# 지원 언어 목록 조회
curl http://localhost:8000/languages
```
## Architecture
### Core Components
1. **main.py** - FastAPI 애플리케이션
- `/translate` POST: 번역 요청 처리
- `/health` GET: 서비스 상태 확인
- `/languages` GET: 지원 언어 목록
- `lifespan`: 애플리케이션 시작 시 모델 로드
2. **translator.py** - MBartTranslator 클래스
- `load_model()`: mBART 모델과 토크나이저 초기화
- `translate()`: 실제 번역 수행
- GPU/CPU 자동 감지 및 할당
- 모델: facebook/mbart-large-50-many-to-many-mmt
3. **models.py** - Pydantic 데이터 모델
- TranslationRequest: 번역 요청 스키마
- TranslationResponse: 번역 응답 스키마
- HealthResponse, LanguagesResponse
4. **config.py** - 설정 관리
- 환경 변수 기반 설정
- SUPPORTED_LANGUAGES: 언어 코드 매핑 (예: "ko" -> "ko_KR")
### Request Flow
```
Client Request → FastAPI Endpoint → Validation (Pydantic)
Response ← Translated Text ← MBartTranslator.translate()
```
### Model Loading
- 애플리케이션 시작 시 `lifespan` 이벤트에서 모델 로드
- 첫 실행 시 HuggingFace에서 모델 다운로드 (약 2.4GB)
- 이후 실행에서는 캐시된 모델 사용 (~/.cache/huggingface/)
### Language Code Mapping
mBART는 특정 언어 코드 형식을 사용합니다:
- 일반 코드 (ko, en) → mBART 코드 (ko_KR, en_XX)
- config.py의 SUPPORTED_LANGUAGES에 정의
- 새 언어 추가 시 이 딕셔너리에 추가 필요
## Key Implementation Details
- **Device Selection**: CUDA 사용 가능 시 자동으로 GPU 사용, 아니면 CPU
- **Token Length**: MAX_LENGTH(기본 512)로 입력 길이 제한
- **Error Handling**: 언어 코드 검증, 모델 로드 상태 확인
- **CORS**: 모든 origin 허용 (프로덕션에서는 제한 필요)
## Configuration
환경 변수로 설정 변경:
- `HOST`, `PORT`: 서버 바인딩 설정
- `MODEL_NAME`: 사용할 mBART 모델 (다른 변형 가능)
- `MAX_LENGTH`: 최대 토큰 길이
- `DEVICE`: "cuda" 또는 "cpu"
## Docker Details
### Image Structure
- Base: python:3.11-slim
- Non-root user (appuser) for security
- Health check endpoint: /health
- Model cache volume: /home/appuser/.cache/huggingface
### Volume Management
모델은 약 2.4GB로, 컨테이너 재시작 시 재다운로드 방지를 위해 볼륨 마운트 필수:
```bash
-v mbart-cache:/home/appuser/.cache/huggingface
```
### Resource Requirements
- CPU: 최소 8GB RAM (16GB 권장)
- GPU: 최소 8GB VRAM
- Disk: 모델 캐시용 약 5GB 여유 공간

39
Dockerfile Normal file
View File

@ -0,0 +1,39 @@
# Multi-stage build for smaller image size
FROM python:3.11-slim as base
# Install system dependencies
RUN apt-get update && apt-get install -y \
build-essential \
curl \
&& rm -rf /var/lib/apt/lists/*
# Set working directory
WORKDIR /app
# Copy requirements first for better caching
COPY requirements.txt .
# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY . .
# Create non-root user for security
RUN useradd -m -u 1000 appuser && \
chown -R appuser:appuser /app && \
mkdir -p /home/appuser/.cache/huggingface && \
chown -R appuser:appuser /home/appuser/.cache
# Switch to non-root user
USER appuser
# Expose port
EXPOSE 8000
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
# Run the application
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]

274
README.md Normal file
View File

@ -0,0 +1,274 @@
# mBART Translation API
mBART 모델을 사용한 다국어 번역 REST API 서비스입니다.
## 기능
- mBART-50 모델을 사용한 다국어 번역
- RESTful API 인터페이스
- 18개 이상의 언어 지원 (한국어, 영어, 일본어, 중국어 등)
- GPU/CPU 지원
- 자동 API 문서화 (Swagger UI)
## 지원 언어
- 한국어 (ko)
- 영어 (en)
- 일본어 (ja)
- 중국어 (zh)
- 스페인어 (es)
- 프랑스어 (fr)
- 독일어 (de)
- 러시아어 (ru)
- 아랍어 (ar)
- 힌디어 (hi)
- 베트남어 (vi)
- 태국어 (th)
- 인도네시아어 (id)
- 터키어 (tr)
- 포르투갈어 (pt)
- 이탈리아어 (it)
- 네덜란드어 (nl)
- 폴란드어 (pl)
## 빠른 시작 (Docker 권장)
### Docker Compose 사용 (가장 쉬운 방법)
```bash
# 1. 서비스 시작
docker-compose up -d
# 2. 로그 확인
docker-compose logs -f
# 3. 서비스 중지
docker-compose down
```
서비스가 시작되면 http://localhost:8000 에서 API를 사용할 수 있습니다.
### 외부 접속
외부에서 접속하려면 서버의 IP 주소를 사용하세요:
- 로컬: http://localhost:8000
- 외부: http://172.30.1.2:8000 (서버 IP 주소로 변경)
### Docker 직접 사용
```bash
# 1. 이미지 빌드
docker build -t mbart-translation-api .
# 2. 컨테이너 실행
docker run -d \
--name mbart-api \
-p 8000:8000 \
-e DEVICE=cpu \
-v mbart-cache:/home/appuser/.cache/huggingface \
mbart-translation-api
# 3. 로그 확인
docker logs -f mbart-api
# 4. 컨테이너 중지
docker stop mbart-api
docker rm mbart-api
```
### GPU 지원 (NVIDIA GPU)
GPU를 사용하려면 [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html)이 설치되어 있어야 합니다.
```bash
docker run -d \
--name mbart-api-gpu \
--gpus all \
-p 8000:8000 \
-e DEVICE=cuda \
-v mbart-cache:/home/appuser/.cache/huggingface \
mbart-translation-api
```
또는 docker-compose.yml에서 GPU 서비스 주석을 해제하세요.
## 로컬 설치 및 실행
### 1. 의존성 설치
```bash
pip install -r requirements.txt
```
### 2. 환경 변수 설정 (선택사항)
```bash
cp .env.example .env
```
`.env` 파일을 편집하여 설정을 변경할 수 있습니다:
```
HOST=0.0.0.0
PORT=8000
MODEL_NAME=facebook/mbart-large-50-many-to-many-mmt
MAX_LENGTH=512
DEVICE=cpu # GPU 사용 시 cuda로 변경
```
### 3. 실행
#### 개발 모드
```bash
python main.py
```
#### 프로덕션 모드
```bash
uvicorn main:app --host 0.0.0.0 --port 8000
```
#### GPU 사용
```bash
DEVICE=cuda python main.py
```
## API 사용법
### API 문서
서버 실행 후 다음 URL에서 자동 생성된 API 문서를 확인할 수 있습니다:
- Swagger UI: http://localhost:8000/docs (또는 http://172.30.1.2:8000/docs)
- ReDoc: http://localhost:8000/redoc (또는 http://172.30.1.2:8000/redoc)
### 엔드포인트
#### 1. 번역 API
**POST** `/translate`
요청 예시:
```bash
# 로컬에서 테스트
curl -X POST "http://localhost:8000/translate" \
-H "Content-Type: application/json" \
-d '{
"text": "안녕하세요, 반갑습니다.",
"source_lang": "ko",
"target_lang": "en"
}'
# 외부에서 접속
curl -X POST "http://172.30.1.2:8000/translate" \
-H "Content-Type: application/json" \
-d '{
"text": "안녕하세요, 반갑습니다.",
"source_lang": "ko",
"target_lang": "en"
}'
```
응답 예시:
```json
{
"translated_text": "Hello, nice to meet you.",
"source_lang": "ko",
"target_lang": "en",
"original_text": "안녕하세요, 반갑습니다."
}
```
#### 2. 헬스 체크
**GET** `/health`
```bash
# 로컬
curl http://localhost:8000/health
# 외부
curl http://172.30.1.2:8000/health
```
응답:
```json
{
"status": "healthy",
"model_loaded": true,
"device": "cpu"
}
```
#### 3. 지원 언어 목록
**GET** `/languages`
```bash
# 로컬
curl http://localhost:8000/languages
# 외부
curl http://172.30.1.2:8000/languages
```
## 프로젝트 구조
```
.
├── main.py # FastAPI 애플리케이션 메인 파일
├── translator.py # mBART 번역 서비스 클래스
├── models.py # Pydantic 데이터 모델
├── config.py # 설정 파일
├── requirements.txt # Python 의존성
├── Dockerfile # Docker 이미지 빌드 설정
├── docker-compose.yml # Docker Compose 설정
├── .dockerignore # Docker 빌드 제외 파일
├── .env.example # 환경 변수 예시
├── .gitignore # Git 무시 파일
├── README.md # 프로젝트 문서
└── CLAUDE.md # 코드베이스 가이드
```
## 성능 최적화
### GPU 사용
CUDA가 설치된 환경에서는 GPU를 사용하여 번역 속도를 크게 향상시킬 수 있습니다:
```bash
DEVICE=cuda python main.py
```
### 모델 캐싱
첫 실행 시 모델이 다운로드되며, 이후 실행에서는 캐시된 모델을 사용합니다.
## 문제 해결
### 메모리 부족
mBART 모델은 크기가 크므로 충분한 메모리가 필요합니다:
- CPU: 최소 8GB RAM 권장
- GPU: 최소 8GB VRAM 권장
메모리가 부족한 경우 `MAX_LENGTH` 값을 줄이거나 더 작은 모델을 사용하세요.
### CUDA 오류
GPU 사용 시 CUDA 관련 오류가 발생하면:
```bash
DEVICE=cpu python main.py
```
CPU 모드로 전환하여 실행하세요.
## 라이선스
이 프로젝트는 mBART 모델 (Facebook AI)을 사용합니다.

38
config.py Normal file
View File

@ -0,0 +1,38 @@
import os
from typing import Optional
class Config:
"""Application configuration"""
# Server settings
HOST: str = os.getenv("HOST", "0.0.0.0")
PORT: int = int(os.getenv("PORT", "8000"))
# Model settings
MODEL_NAME: str = os.getenv("MODEL_NAME", "facebook/mbart-large-50-many-to-many-mmt")
MAX_LENGTH: int = int(os.getenv("MAX_LENGTH", "512"))
DEVICE: str = os.getenv("DEVICE", "cpu")
# Supported languages for mBART-50
SUPPORTED_LANGUAGES = {
"ko": "ko_KR", # Korean
"en": "en_XX", # English
"ja": "ja_XX", # Japanese
"zh": "zh_CN", # Chinese (Simplified)
"es": "es_XX", # Spanish
"fr": "fr_XX", # French
"de": "de_DE", # German
"ru": "ru_RU", # Russian
"ar": "ar_AR", # Arabic
"hi": "hi_IN", # Hindi
"vi": "vi_VN", # Vietnamese
"th": "th_TH", # Thai
"id": "id_ID", # Indonesian
"tr": "tr_TR", # Turkish
"pt": "pt_XX", # Portuguese
"it": "it_IT", # Italian
"nl": "nl_XX", # Dutch
"pl": "pl_PL", # Polish
}
config = Config()

55
docker-compose.yml Normal file
View File

@ -0,0 +1,55 @@
version: '3.8'
services:
mbart-api:
build:
context: .
dockerfile: Dockerfile
container_name: mbart-translation-api
ports:
- "8000:8000"
environment:
- HOST=0.0.0.0
- PORT=8000
- MODEL_NAME=facebook/mbart-large-50-many-to-many-mmt
- MAX_LENGTH=512
- DEVICE=cpu # Change to 'cuda' for GPU support
volumes:
# Cache HuggingFace models to avoid re-downloading
- huggingface-cache:/home/appuser/.cache/huggingface
restart: unless-stopped
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 30s
timeout: 10s
retries: 3
start_period: 60s
# GPU support (uncomment if you have NVIDIA GPU)
# mbart-api-gpu:
# build:
# context: .
# dockerfile: Dockerfile
# container_name: mbart-translation-api-gpu
# ports:
# - "8000:8000"
# environment:
# - HOST=0.0.0.0
# - PORT=8000
# - MODEL_NAME=facebook/mbart-large-50-many-to-many-mmt
# - MAX_LENGTH=512
# - DEVICE=cuda
# volumes:
# - huggingface-cache:/home/appuser/.cache/huggingface
# deploy:
# resources:
# reservations:
# devices:
# - driver: nvidia
# count: 1
# capabilities: [gpu]
# restart: unless-stopped
volumes:
huggingface-cache:
driver: local

143
main.py Normal file
View File

@ -0,0 +1,143 @@
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
import logging
from config import config
from translator import translator
from models import (
TranslationRequest,
TranslationResponse,
HealthResponse,
LanguagesResponse,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""애플리케이션 시작 시 모델을 로드합니다."""
logger.info("Starting up: Loading mBART model...")
try:
translator.load_model()
logger.info("Model loaded successfully")
except Exception as e:
logger.error(f"Failed to load model: {str(e)}")
raise
yield
logger.info("Shutting down...")
app = FastAPI(
title="mBART Translation API",
description="mBART 모델을 사용한 다국어 번역 API 서비스",
version="1.0.0",
lifespan=lifespan,
)
# CORS 설정
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/", tags=["Root"])
async def root():
"""API 루트 엔드포인트"""
return {
"message": "mBART Translation API",
"docs": "/docs",
"health": "/health",
}
@app.get("/health", response_model=HealthResponse, tags=["Health"])
async def health_check():
"""서비스 헬스 체크"""
return HealthResponse(
status="healthy" if translator.is_ready() else "not ready",
model_loaded=translator.is_ready(),
device=translator.device,
)
@app.get("/languages", response_model=LanguagesResponse, tags=["Languages"])
async def get_supported_languages():
"""지원하는 언어 목록 조회"""
return LanguagesResponse(supported_languages=config.SUPPORTED_LANGUAGES)
@app.post(
"/translate",
response_model=TranslationResponse,
tags=["Translation"],
status_code=status.HTTP_200_OK,
)
async def translate_text(request: TranslationRequest):
"""
텍스트 번역 API
- **text**: 번역할 텍스트
- **source_lang**: 소스 언어 코드 (예: ko, en, ja)
- **target_lang**: 타겟 언어 코드 (예: en, ko, ja)
"""
if not translator.is_ready():
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Translation model is not ready",
)
# 언어 코드 검증
if request.source_lang not in config.SUPPORTED_LANGUAGES:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported source language: {request.source_lang}",
)
if request.target_lang not in config.SUPPORTED_LANGUAGES:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported target language: {request.target_lang}",
)
try:
translated_text = translator.translate(
text=request.text,
source_lang=request.source_lang,
target_lang=request.target_lang,
)
return TranslationResponse(
translated_text=translated_text,
source_lang=request.source_lang,
target_lang=request.target_lang,
original_text=request.text,
)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)
)
except Exception as e:
logger.error(f"Translation failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Translation failed",
)
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"main:app",
host=config.HOST,
port=config.PORT,
reload=True,
)

52
models.py Normal file
View File

@ -0,0 +1,52 @@
from pydantic import BaseModel, Field
from typing import Optional
class TranslationRequest(BaseModel):
"""번역 요청 모델"""
text: str = Field(..., description="번역할 텍스트", min_length=1)
source_lang: str = Field(..., description="소스 언어 코드 (예: ko, en, ja)")
target_lang: str = Field(..., description="타겟 언어 코드 (예: en, ko, ja)")
class Config:
json_schema_extra = {
"example": {
"text": "안녕하세요, 반갑습니다.",
"source_lang": "ko",
"target_lang": "en",
}
}
class TranslationResponse(BaseModel):
"""번역 응답 모델"""
translated_text: str = Field(..., description="번역된 텍스트")
source_lang: str = Field(..., description="소스 언어 코드")
target_lang: str = Field(..., description="타겟 언어 코드")
original_text: str = Field(..., description="원본 텍스트")
class Config:
json_schema_extra = {
"example": {
"translated_text": "Hello, nice to meet you.",
"source_lang": "ko",
"target_lang": "en",
"original_text": "안녕하세요, 반갑습니다.",
}
}
class HealthResponse(BaseModel):
"""헬스 체크 응답 모델"""
status: str = Field(..., description="서비스 상태")
model_loaded: bool = Field(..., description="모델 로드 여부")
device: str = Field(..., description="사용 중인 디바이스 (cpu/cuda)")
class LanguagesResponse(BaseModel):
"""지원 언어 목록 응답 모델"""
supported_languages: dict = Field(..., description="지원하는 언어 코드 목록")

8
requirements.txt Normal file
View File

@ -0,0 +1,8 @@
fastapi==0.104.1
uvicorn[standard]==0.24.0
transformers==4.35.2
torch==2.1.1
sentencepiece==0.1.99
protobuf==4.25.1
pydantic==2.5.0
python-multipart==0.0.6

105
translator.py Normal file
View File

@ -0,0 +1,105 @@
import torch
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
from typing import Optional
import logging
from config import config
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class MBartTranslator:
"""mBART 모델을 사용한 번역 서비스"""
def __init__(self):
self.model: Optional[MBartForConditionalGeneration] = None
self.tokenizer: Optional[MBart50TokenizerFast] = None
self.device = config.DEVICE
self.model_name = config.MODEL_NAME
self.max_length = config.MAX_LENGTH
def load_model(self):
"""모델과 토크나이저를 로드합니다."""
try:
logger.info(f"Loading mBART model: {self.model_name}")
self.tokenizer = MBart50TokenizerFast.from_pretrained(self.model_name)
self.model = MBartForConditionalGeneration.from_pretrained(self.model_name)
# GPU 사용 가능 시 모델을 GPU로 이동
if self.device == "cuda" and torch.cuda.is_available():
self.model = self.model.to("cuda")
logger.info("Model loaded on CUDA")
else:
self.device = "cpu"
logger.info("Model loaded on CPU")
logger.info("Model loaded successfully")
except Exception as e:
logger.error(f"Error loading model: {str(e)}")
raise
def translate(
self, text: str, source_lang: str, target_lang: str
) -> str:
"""
텍스트를 번역합니다.
Args:
text: 번역할 텍스트
source_lang: 소스 언어 코드 (예: "ko", "en")
target_lang: 타겟 언어 코드 (예: "en", "ko")
Returns:
번역된 텍스트
"""
if self.model is None or self.tokenizer is None:
raise RuntimeError("Model not loaded. Call load_model() first.")
# 언어 코드를 mBART 형식으로 변환
source_lang_code = config.SUPPORTED_LANGUAGES.get(source_lang)
target_lang_code = config.SUPPORTED_LANGUAGES.get(target_lang)
if not source_lang_code or not target_lang_code:
raise ValueError(
f"Unsupported language. Source: {source_lang}, Target: {target_lang}"
)
try:
# 소스 언어 설정
self.tokenizer.src_lang = source_lang_code
# 입력 텍스트 토큰화
encoded = self.tokenizer(
text, return_tensors="pt", max_length=self.max_length, truncation=True
)
# GPU 사용 시 입력도 GPU로 이동
if self.device == "cuda":
encoded = {k: v.to("cuda") for k, v in encoded.items()}
# 번역 생성
generated_tokens = self.model.generate(
**encoded,
forced_bos_token_id=self.tokenizer.lang_code_to_id[target_lang_code],
max_length=self.max_length,
)
# 디코딩
translated_text = self.tokenizer.batch_decode(
generated_tokens, skip_special_tokens=True
)[0]
return translated_text
except Exception as e:
logger.error(f"Translation error: {str(e)}")
raise
def is_ready(self) -> bool:
"""모델이 로드되어 사용 가능한지 확인합니다."""
return self.model is not None and self.tokenizer is not None
# 글로벌 translator 인스턴스
translator = MBartTranslator()