Initial commit: mBART Translation API with Docker support
- FastAPI 기반 다국어 번역 REST API 서비스 - mBART-50 모델을 사용한 18개 언어 지원 - Docker 및 Docker Compose 설정 포함 - GPU/CPU 지원 - 헬스 체크 및 API 문서 자동 생성 - 외부 접속 지원 (172.30.1.2:8000) 주요 파일: - main.py: FastAPI 애플리케이션 - translator.py: mBART 번역 서비스 - models.py: Pydantic 데이터 모델 - config.py: 환경 설정 - Dockerfile: 최적화된 Docker 이미지 - docker-compose.yml: 간편한 배포 설정 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
105
translator.py
Normal file
105
translator.py
Normal file
@ -0,0 +1,105 @@
|
||||
import torch
|
||||
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
|
||||
from typing import Optional
|
||||
import logging
|
||||
|
||||
from config import config
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MBartTranslator:
|
||||
"""mBART 모델을 사용한 번역 서비스"""
|
||||
|
||||
def __init__(self):
|
||||
self.model: Optional[MBartForConditionalGeneration] = None
|
||||
self.tokenizer: Optional[MBart50TokenizerFast] = None
|
||||
self.device = config.DEVICE
|
||||
self.model_name = config.MODEL_NAME
|
||||
self.max_length = config.MAX_LENGTH
|
||||
|
||||
def load_model(self):
|
||||
"""모델과 토크나이저를 로드합니다."""
|
||||
try:
|
||||
logger.info(f"Loading mBART model: {self.model_name}")
|
||||
self.tokenizer = MBart50TokenizerFast.from_pretrained(self.model_name)
|
||||
self.model = MBartForConditionalGeneration.from_pretrained(self.model_name)
|
||||
|
||||
# GPU 사용 가능 시 모델을 GPU로 이동
|
||||
if self.device == "cuda" and torch.cuda.is_available():
|
||||
self.model = self.model.to("cuda")
|
||||
logger.info("Model loaded on CUDA")
|
||||
else:
|
||||
self.device = "cpu"
|
||||
logger.info("Model loaded on CPU")
|
||||
|
||||
logger.info("Model loaded successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading model: {str(e)}")
|
||||
raise
|
||||
|
||||
def translate(
|
||||
self, text: str, source_lang: str, target_lang: str
|
||||
) -> str:
|
||||
"""
|
||||
텍스트를 번역합니다.
|
||||
|
||||
Args:
|
||||
text: 번역할 텍스트
|
||||
source_lang: 소스 언어 코드 (예: "ko", "en")
|
||||
target_lang: 타겟 언어 코드 (예: "en", "ko")
|
||||
|
||||
Returns:
|
||||
번역된 텍스트
|
||||
"""
|
||||
if self.model is None or self.tokenizer is None:
|
||||
raise RuntimeError("Model not loaded. Call load_model() first.")
|
||||
|
||||
# 언어 코드를 mBART 형식으로 변환
|
||||
source_lang_code = config.SUPPORTED_LANGUAGES.get(source_lang)
|
||||
target_lang_code = config.SUPPORTED_LANGUAGES.get(target_lang)
|
||||
|
||||
if not source_lang_code or not target_lang_code:
|
||||
raise ValueError(
|
||||
f"Unsupported language. Source: {source_lang}, Target: {target_lang}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 소스 언어 설정
|
||||
self.tokenizer.src_lang = source_lang_code
|
||||
|
||||
# 입력 텍스트 토큰화
|
||||
encoded = self.tokenizer(
|
||||
text, return_tensors="pt", max_length=self.max_length, truncation=True
|
||||
)
|
||||
|
||||
# GPU 사용 시 입력도 GPU로 이동
|
||||
if self.device == "cuda":
|
||||
encoded = {k: v.to("cuda") for k, v in encoded.items()}
|
||||
|
||||
# 번역 생성
|
||||
generated_tokens = self.model.generate(
|
||||
**encoded,
|
||||
forced_bos_token_id=self.tokenizer.lang_code_to_id[target_lang_code],
|
||||
max_length=self.max_length,
|
||||
)
|
||||
|
||||
# 디코딩
|
||||
translated_text = self.tokenizer.batch_decode(
|
||||
generated_tokens, skip_special_tokens=True
|
||||
)[0]
|
||||
|
||||
return translated_text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Translation error: {str(e)}")
|
||||
raise
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""모델이 로드되어 사용 가능한지 확인합니다."""
|
||||
return self.model is not None and self.tokenizer is not None
|
||||
|
||||
|
||||
# 글로벌 translator 인스턴스
|
||||
translator = MBartTranslator()
|
||||
Reference in New Issue
Block a user