import torch from transformers import MBartForConditionalGeneration, MBart50TokenizerFast from typing import Optional import logging from config import config logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class MBartTranslator: """mBART 모델을 사용한 번역 서비스""" def __init__(self): self.model: Optional[MBartForConditionalGeneration] = None self.tokenizer: Optional[MBart50TokenizerFast] = None self.device = config.DEVICE self.model_name = config.MODEL_NAME self.max_length = config.MAX_LENGTH def load_model(self): """모델과 토크나이저를 로드합니다.""" try: logger.info(f"Loading mBART model: {self.model_name}") self.tokenizer = MBart50TokenizerFast.from_pretrained(self.model_name) self.model = MBartForConditionalGeneration.from_pretrained(self.model_name) # GPU 사용 가능 시 모델을 GPU로 이동 if self.device == "cuda" and torch.cuda.is_available(): self.model = self.model.to("cuda") logger.info("Model loaded on CUDA") else: self.device = "cpu" logger.info("Model loaded on CPU") logger.info("Model loaded successfully") except Exception as e: logger.error(f"Error loading model: {str(e)}") raise def translate( self, text: str, source_lang: str, target_lang: str ) -> str: """ 텍스트를 번역합니다. Args: text: 번역할 텍스트 source_lang: 소스 언어 코드 (예: "ko", "en") target_lang: 타겟 언어 코드 (예: "en", "ko") Returns: 번역된 텍스트 """ if self.model is None or self.tokenizer is None: raise RuntimeError("Model not loaded. Call load_model() first.") # 언어 코드를 mBART 형식으로 변환 source_lang_code = config.SUPPORTED_LANGUAGES.get(source_lang) target_lang_code = config.SUPPORTED_LANGUAGES.get(target_lang) if not source_lang_code or not target_lang_code: raise ValueError( f"Unsupported language. Source: {source_lang}, Target: {target_lang}" ) try: # 소스 언어 설정 self.tokenizer.src_lang = source_lang_code # 입력 텍스트 토큰화 encoded = self.tokenizer( text, return_tensors="pt", max_length=self.max_length, truncation=True ) # GPU 사용 시 입력도 GPU로 이동 if self.device == "cuda": encoded = {k: v.to("cuda") for k, v in encoded.items()} # 번역 생성 generated_tokens = self.model.generate( **encoded, forced_bos_token_id=self.tokenizer.lang_code_to_id[target_lang_code], max_length=self.max_length, ) # 디코딩 translated_text = self.tokenizer.batch_decode( generated_tokens, skip_special_tokens=True )[0] return translated_text except Exception as e: logger.error(f"Translation error: {str(e)}") raise def is_ready(self) -> bool: """모델이 로드되어 사용 가능한지 확인합니다.""" return self.model is not None and self.tokenizer is not None # 글로벌 translator 인스턴스 translator = MBartTranslator()