from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
import torch
from typing import Dict, Optional
import logging
from .config import settings

logger = logging.getLogger(__name__)


class TranslationService:
    """
    Service for handling multilingual translation.

    Uses the M2M100 model (Apache 2.0 License - Commercial use allowed),
    a many-to-many multilingual translation model. A single checkpoint
    serves every language pair, so loaded models are cached by checkpoint
    name rather than by language pair.
    """

    def __init__(self):
        """Pick the compute device and build the supported-language table."""
        # Cache of loaded checkpoints:
        # model name -> {"tokenizer": ..., "model": ...}
        self.models: Dict[str, Dict] = {}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # Language codes accepted by this service, mapped to the code that
        # is handed to the M2M100 tokenizer (currently an identity mapping).
        # Full list: https://huggingface.co/facebook/m2m100_418M
        # NOTE(review): some codes below (e.g. "te", "eo", "fil", "haw") may
        # not be present in the M2M100 vocabulary; confirm against the model
        # card / the tokenizer's language table before relying on them.
        self.lang_codes = {
            # Major languages
            "en": "en",  # English
            "zh": "zh",  # Chinese
            "es": "es",  # Spanish
            "ar": "ar",  # Arabic
            "hi": "hi",  # Hindi
            "bn": "bn",  # Bengali
            "pt": "pt",  # Portuguese
            "ru": "ru",  # Russian
            "ja": "ja",  # Japanese
            "de": "de",  # German
            "fr": "fr",  # French
            "ko": "ko",  # Korean
            "it": "it",  # Italian
            "tr": "tr",  # Turkish
            "vi": "vi",  # Vietnamese
            "th": "th",  # Thai
            "pl": "pl",  # Polish
            "nl": "nl",  # Dutch
            "uk": "uk",  # Ukrainian
            "ro": "ro",  # Romanian
            # Southeast Asian languages
            "ms": "ms",  # Malay
            "id": "id",  # Indonesian
            "tl": "tl",  # Tagalog
            "my": "my",  # Burmese
            "km": "km",  # Khmer
            "lo": "lo",  # Lao
            # South Asian languages
            "ur": "ur",  # Urdu
            "ta": "ta",  # Tamil
            "te": "te",  # Telugu
            "mr": "mr",  # Marathi
            "gu": "gu",  # Gujarati
            "kn": "kn",  # Kannada
            "ml": "ml",  # Malayalam
            "pa": "pa",  # Punjabi
            "ne": "ne",  # Nepali
            "si": "si",  # Sinhala
            # European languages
            "sv": "sv",  # Swedish
            "da": "da",  # Danish
            "fi": "fi",  # Finnish
            "no": "no",  # Norwegian
            "cs": "cs",  # Czech
            "sk": "sk",  # Slovak
            "hu": "hu",  # Hungarian
            "bg": "bg",  # Bulgarian
            "sr": "sr",  # Serbian
            "hr": "hr",  # Croatian
            "sl": "sl",  # Slovenian
            "et": "et",  # Estonian
            "lv": "lv",  # Latvian
            "lt": "lt",  # Lithuanian
            "el": "el",  # Greek
            "he": "he",  # Hebrew
            "fa": "fa",  # Persian
            # African languages
            "sw": "sw",  # Swahili
            "am": "am",  # Amharic
            "ha": "ha",  # Hausa
            "ig": "ig",  # Igbo
            "yo": "yo",  # Yoruba
            "zu": "zu",  # Zulu
            "xh": "xh",  # Xhosa
            "af": "af",  # Afrikaans
            # Other major languages
            "az": "az",  # Azerbaijani
            "ka": "ka",  # Georgian
            "kk": "kk",  # Kazakh
            "uz": "uz",  # Uzbek
            "mn": "mn",  # Mongolian
            # Additional languages
            "sq": "sq",  # Albanian
            "hy": "hy",  # Armenian
            "be": "be",  # Belarusian
            "bs": "bs",  # Bosnian
            "ca": "ca",  # Catalan
            "ceb": "ceb",  # Cebuano
            "cy": "cy",  # Welsh
            "eo": "eo",  # Esperanto
            "eu": "eu",  # Basque
            "fil": "fil",  # Filipino
            "fy": "fy",  # Frisian
            "ga": "ga",  # Irish
            "gd": "gd",  # Scottish Gaelic
            "gl": "gl",  # Galician
            "haw": "haw",  # Hawaiian
            "hmn": "hmn",  # Hmong
            "ht": "ht",  # Haitian Creole
            "is": "is",  # Icelandic
            "jv": "jv",  # Javanese
            # ("kn" Kannada is already listed under South Asian languages)
            "ku": "ku",  # Kurdish
            "ky": "ky",  # Kyrgyz
            "la": "la",  # Latin
            "lb": "lb",  # Luxembourgish
            "lg": "lg",  # Luganda
            "ln": "ln",  # Lingala
            "mg": "mg",  # Malagasy
            "mi": "mi",  # Maori
            "mk": "mk",  # Macedonian
            "mt": "mt",  # Maltese
            "ny": "ny",  # Chichewa
            "ps": "ps",  # Pashto
            "sn": "sn",  # Shona
            "so": "so",  # Somali
            "st": "st",  # Sesotho
            "su": "su",  # Sundanese
            "tg": "tg",  # Tajik
            "tk": "tk",  # Turkmen
            "ug": "ug",  # Uyghur
            "yi": "yi",  # Yiddish
        }

    def _get_model_info(self, source_lang: str, target_lang: str) -> tuple[str, str, str]:
        """
        Get the model name and language codes for translation.

        Args:
            source_lang: Source language code (a key of ``self.lang_codes``)
            target_lang: Target language code (a key of ``self.lang_codes``)

        Returns:
            Tuple of (model_name, source_code, target_code)

        Raises:
            ValueError: If either language code is not in ``self.lang_codes``
        """
        # Using M2M100 418M model (smaller, faster, commercial-friendly)
        model_name = "facebook/m2m100_418M"

        src_code = self.lang_codes.get(source_lang)
        tgt_code = self.lang_codes.get(target_lang)

        if not src_code or not tgt_code:
            raise ValueError(f"Unsupported language pair: {source_lang} -> {target_lang}")

        return model_name, src_code, tgt_code

    def load_model(self, source_lang: str, target_lang: str) -> None:
        """
        Load the translation model for a specific language pair.

        The tokenizer and model are stored in ``self.models`` under the
        checkpoint name; if that checkpoint is already cached this is a no-op.

        Args:
            source_lang: Source language code
            target_lang: Target language code

        Raises:
            ValueError: If the language pair is unsupported
            Exception: Any error raised while loading is logged and re-raised
        """
        model_name, _, _ = self._get_model_info(source_lang, target_lang)

        if model_name in self.models:
            logger.info(f"Model {model_name} already loaded")
            return

        try:
            logger.info(f"Loading model: {model_name}")

            tokenizer = M2M100Tokenizer.from_pretrained(
                model_name,
                cache_dir=settings.model_cache_dir,
            )
            model = M2M100ForConditionalGeneration.from_pretrained(
                model_name,
                cache_dir=settings.model_cache_dir,
            ).to(self.device)

            self.models[model_name] = {
                "tokenizer": tokenizer,
                "model": model,
            }

            logger.info(f"Successfully loaded model: {model_name}")
        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception("Error loading model %s: %s", model_name, e)
            raise

    def translate(self, text: str, source_lang: str, target_lang: str) -> tuple[str, str]:
        """
        Translate text from source language to target language.

        Args:
            text: Text to translate
            source_lang: Source language code
            target_lang: Target language code

        Returns:
            Tuple of (translated_text, model_name). Empty or whitespace-only
            text is returned unchanged without invoking the model.

        Raises:
            ValueError: If the language pair is unsupported
            Exception: Any loading/tokenization/generation error is logged
                and re-raised
        """
        model_name, src_code, tgt_code = self._get_model_info(source_lang, target_lang)

        # Nothing to translate: skip model loading and generation entirely.
        if not text or not text.strip():
            return text, model_name

        # Load model if not already loaded
        if model_name not in self.models:
            self.load_model(source_lang, target_lang)

        try:
            entry = self.models[model_name]
            tokenizer = entry["tokenizer"]
            model = entry["model"]

            # Set source language for tokenizer
            tokenizer.src_lang = src_code

            # Tokenize input; input longer than settings.max_length is truncated
            inputs = tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=settings.max_length,
            ).to(self.device)

            # M2M100 selects the output language by forcing the target
            # language token as the first generated token.
            target_lang_id = tokenizer.get_lang_id(tgt_code)

            with torch.no_grad():
                translated = model.generate(
                    **inputs,
                    forced_bos_token_id=target_lang_id,
                    max_length=settings.max_length,
                )

            # Decode output (a single input yields a single-element batch)
            translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]

            return translated_text, model_name
        except Exception as e:
            logger.exception("Translation error: %s", e)
            raise

    def preload_all_models(self) -> None:
        """
        Preload the translation model for the configured language pairs.

        Best-effort: a failure for one pair is logged as a warning and does
        not stop the remaining pairs. Every pair currently resolves to the
        same checkpoint, so after the first successful load the later
        iterations are no-ops.
        """
        language_pairs = [
            ("ms", "en"),
            ("en", "ms"),
        ]

        for source, target in language_pairs:
            try:
                self.load_model(source, target)
            except Exception as e:
                logger.warning("Could not preload model for %s->%s: %s", source, target, e)

    def is_ready(self) -> bool:
        """Check if at least one model is loaded."""
        return bool(self.models)


# Global translator instance
translator = TranslationService()