diff --git a/app/main.py b/app/main.py index e68fe92..3aa77ca 100644 --- a/app/main.py +++ b/app/main.py @@ -103,7 +103,8 @@ async def translate_text(request: TranslationRequest): translated_text, model_used = translator.translate( text=request.text, source_lang=request.source_lang, - target_lang=request.target_lang + target_lang=request.target_lang, + model_type=request.model ) return TranslationResponse( @@ -125,8 +126,15 @@ async def translate_text(request: TranslationRequest): @app.get("/api/supported-languages") -async def get_supported_languages(): - """Get list of supported languages""" +async def get_supported_languages(model: str = "m2m100"): + """ + Get list of supported languages for specified model + + - **model**: Model type ('m2m100' or 'nllb200') + """ + + if model not in ["m2m100", "nllb200"]: + raise HTTPException(status_code=400, detail="Invalid model. Choose 'm2m100' or 'nllb200'") # Language names mapping lang_names = { @@ -235,10 +243,97 @@ async def get_supported_languages(): "tk": {"name": "Turkmen", "native": "Türkmençe"}, "ug": {"name": "Uyghur", "native": "ئۇيغۇرچە"}, "yi": {"name": "Yiddish", "native": "ייִדיש"}, + + # Additional NLLB-200 exclusive languages + "ace": {"name": "Acehnese", "native": "Acèh"}, + "acm": {"name": "Mesopotamian Arabic", "native": "عراقي"}, + "acq": {"name": "Ta'izzi-Adeni Arabic", "native": "تعزية-عدنية"}, + "aeb": {"name": "Tunisian Arabic", "native": "تونسي"}, + "ajp": {"name": "South Levantine Arabic", "native": "شامي"}, + "als": {"name": "Tosk Albanian", "native": "Toskë"}, + "ars": {"name": "Najdi Arabic", "native": "نجدي"}, + "ary": {"name": "Moroccan Arabic", "native": "الدارجة"}, + "arz": {"name": "Egyptian Arabic", "native": "مصري"}, + "asm": {"name": "Assamese", "native": "অসমীয়া"}, + "ast": {"name": "Asturian", "native": "Asturianu"}, + "awa": {"name": "Awadhi", "native": "अवधी"}, + "ayr": {"name": "Central Aymara", "native": "Aymar aru"}, + "azb": {"name": "South Azerbaijani", "native": "تۆرکجه"}, + 
"bak": {"name": "Bashkir", "native": "Башҡортса"}, + "bam": {"name": "Bambara", "native": "Bamanankan"}, + "ban": {"name": "Balinese", "native": "Basa Bali"}, + "bho": {"name": "Bhojpuri", "native": "भोजपुरी"}, + "bjn": {"name": "Banjar", "native": "Bahasa Banjar"}, + "bod": {"name": "Tibetan", "native": "བོད་སྐད་"}, + "bug": {"name": "Buginese", "native": "Basa Ugi"}, + "crh": {"name": "Crimean Tatar", "native": "Qırımtatar tili"}, + "cjk": {"name": "Chokwe", "native": "Chokwe"}, + "ckb": {"name": "Central Kurdish", "native": "کوردیی ناوەندی"}, + "dik": {"name": "Southwestern Dinka", "native": "Thuɔŋjäŋ"}, + "dyu": {"name": "Dyula", "native": "Jula"}, + "dzo": {"name": "Dzongkha", "native": "རྫོང་ཁ"}, + "fur": {"name": "Friulian", "native": "Furlan"}, + "fuv": {"name": "Nigerian Fulfulde", "native": "Fulfulde"}, + "gaz": {"name": "West Central Oromo", "native": "Oromoo"}, + "grn": {"name": "Guarani", "native": "Avañe'ẽ"}, + "hne": {"name": "Chhattisgarhi", "native": "छत्तीसगढ़ी"}, + "ilo": {"name": "Iloko", "native": "Ilokano"}, + "kab": {"name": "Kabyle", "native": "Taqbaylit"}, + "kac": {"name": "Jingpho", "native": "Jinghpaw"}, + "kam": {"name": "Kamba", "native": "Kikamba"}, + "kas": {"name": "Kashmiri", "native": "कॉशुर"}, + "kea": {"name": "Kabuverdianu", "native": "Kabuverdianu"}, + "khk": {"name": "Halh Mongolian", "native": "Монгол хэл"}, + "kin": {"name": "Kinyarwanda", "native": "Ikinyarwanda"}, + "lij": {"name": "Ligurian", "native": "Ligure"}, + "lim": {"name": "Limburgish", "native": "Limburgs"}, + "lin": {"name": "Lingala", "native": "Lingála"}, + "lmo": {"name": "Lombard", "native": "Lombard"}, + "ltg": {"name": "Latgalian", "native": "Latgalīšu"}, + "luo": {"name": "Luo", "native": "Dholuo"}, + "lus": {"name": "Mizo", "native": "Mizo ṭawng"}, + "mag": {"name": "Magahi", "native": "मगही"}, + "mai": {"name": "Maithili", "native": "मैथिली"}, + "min": {"name": "Minangkabau", "native": "Baso Minangkabau"}, + "mni": {"name": "Meitei", "native": 
"মৈতৈলোন্"}, + "mos": {"name": "Mossi", "native": "Mooré"}, + "mri": {"name": "Maori", "native": "Te Reo Māori"}, + "nus": {"name": "Nuer", "native": "Thok Naath"}, + "ory": {"name": "Odia", "native": "ଓଡ଼ିଆ"}, + "pag": {"name": "Pangasinan", "native": "Pangasinan"}, + "pap": {"name": "Papiamento", "native": "Papiamentu"}, + "prs": {"name": "Dari", "native": "دری"}, + "quy": {"name": "Ayacucho Quechua", "native": "Chanka Qhichwa"}, + "run": {"name": "Rundi", "native": "Ikirundi"}, + "sag": {"name": "Sango", "native": "Sängö"}, + "san": {"name": "Sanskrit", "native": "संस्कृतम्"}, + "sat": {"name": "Santali", "native": "ᱥᱟᱱᱛᱟᱲᱤ"}, + "scn": {"name": "Sicilian", "native": "Sicilianu"}, + "shn": {"name": "Shan", "native": "လိၵ်ႈတႆး"}, + "srd": {"name": "Sardinian", "native": "Sardu"}, + "szl": {"name": "Silesian", "native": "Ślōnski"}, + "taq": {"name": "Tamasheq", "native": "Tamasheq"}, + "tat": {"name": "Tatar", "native": "Татарча"}, + "tir": {"name": "Tigrinya", "native": "ትግርኛ"}, + "tpi": {"name": "Tok Pisin", "native": "Tok Pisin"}, + "tsn": {"name": "Tswana", "native": "Setswana"}, + "tso": {"name": "Tsonga", "native": "Xitsonga"}, + "tum": {"name": "Tumbuka", "native": "Chitumbuka"}, + "twi": {"name": "Twi", "native": "Twi"}, + "tzm": {"name": "Central Atlas Tamazight", "native": "ⵜⴰⵎⴰⵣⵉⵖⵜ"}, + "uig": {"name": "Uyghur", "native": "ئۇيغۇرچە"}, + "vec": {"name": "Venetian", "native": "Vèneto"}, + "war": {"name": "Waray", "native": "Winaray"}, + "wol": {"name": "Wolof", "native": "Wolof"}, + "xho": {"name": "Xhosa", "native": "isiXhosa"}, + "ydd": {"name": "Eastern Yiddish", "native": "ייִדיש"}, + "yor": {"name": "Yoruba", "native": "Yorùbá"}, + "yue": {"name": "Cantonese", "native": "粵語"}, + "zho_hant": {"name": "Chinese (Traditional)", "native": "繁體中文"}, } - # Get all supported language codes from translator - supported_codes = list(translator.lang_codes.keys()) + # Get all supported language codes from translator based on model type + supported_codes = 
list(translator.get_supported_languages(model).keys()) # Build language list languages = [ @@ -250,7 +345,25 @@ async def get_supported_languages(): for code in sorted(supported_codes) ] + model_info = { + "m2m100": { + "name": "M2M100", + "languages": 105, + "license": "Apache 2.0", + "commercial_use": True, + "model_id": "facebook/m2m100_418M" + }, + "nllb200": { + "name": "NLLB-200", + "languages": 200, + "license": "CC-BY-NC 4.0", + "commercial_use": False, + "model_id": "facebook/nllb-200-distilled-600M" + } + } + return { + "model": model_info[model], "languages": languages, "total_languages": len(languages), "note": "All language pairs are supported (any-to-any translation)" diff --git a/app/models.py b/app/models.py index 2c07e18..a4b0a5c 100644 --- a/app/models.py +++ b/app/models.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, Field, field_validator -from typing import Optional +from typing import Optional, Literal class TranslationRequest(BaseModel): @@ -7,6 +7,10 @@ class TranslationRequest(BaseModel): text: str = Field(..., description="Text to translate", min_length=1, max_length=5000) source_lang: str = Field(..., description="Source language code (e.g., 'en', 'ms', 'bn', etc.)", min_length=2, max_length=5) target_lang: str = Field(..., description="Target language code (e.g., 'en', 'ms', 'bn', etc.)", min_length=2, max_length=5) + model: Literal["m2m100", "nllb200"] = Field( + default="m2m100", + description="Translation model to use: 'm2m100' (105 langs, Apache 2.0, commercial OK) or 'nllb200' (200 langs, CC-BY-NC, non-commercial only)" + ) @field_validator('source_lang', 'target_lang') @classmethod @@ -19,7 +23,8 @@ class TranslationRequest(BaseModel): "example": { "text": "Selamat pagi, apa khabar?", "source_lang": "ms", - "target_lang": "en" + "target_lang": "en", + "model": "m2m100" } } diff --git a/app/translator.py b/app/translator.py index 9fc1755..2c3e515 100644 --- a/app/translator.py +++ b/app/translator.py @@ -1,4 +1,4 @@ -from 
transformers import M2M100ForConditionalGeneration, M2M100Tokenizer +from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer, AutoModelForSeq2SeqLM, AutoTokenizer import torch from typing import Dict, Optional import logging @@ -10,8 +10,9 @@ logger = logging.getLogger(__name__) class TranslationService: """ Service for handling multilingual translation - Uses M2M100 model (Apache 2.0 License - Commercial use allowed) - Supports 100 languages for many-to-many translation + Supports two models: + - M2M100 (105 languages, Apache 2.0 License - Commercial use allowed) + - NLLB-200 (200 languages, CC-BY-NC License - Non-commercial only) """ def __init__(self): @@ -19,9 +20,9 @@ class TranslationService: self.device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"Using device: {self.device}") - # M2M100 supported language codes (100 languages) + # M2M100 supported language codes (105 languages) # Full list: https://huggingface.co/facebook/m2m100_418M - self.lang_codes = { + self.m2m100_lang_codes = { # Major languages "en": "en", # English "zh": "zh", # Chinese @@ -100,7 +101,7 @@ class TranslationService: "uz": "uz", # Uzbek "mn": "mn", # Mongolian - # Additional languages (completing 100) + # Additional languages "sq": "sq", # Albanian "hy": "hy", # Armenian "be": "be", # Belarusian @@ -120,7 +121,6 @@ class TranslationService: "ht": "ht", # Haitian Creole "is": "is", # Icelandic "jv": "jv", # Javanese - "kn": "kn", # Kannada "ku": "ku", # Kurdish "ky": "ky", # Kyrgyz "la": "la", # Latin @@ -143,21 +143,238 @@ class TranslationService: "yi": "yi", # Yiddish } - def _get_model_info(self, source_lang: str, target_lang: str) -> tuple[str, str, str]: + # NLLB-200 supported language codes (200 languages) + # NLLB uses different format: xxx_Yyyy (language_Script) + # Full list: https://huggingface.co/facebook/nllb-200-distilled-600M + self.nllb200_lang_codes = { + # Major languages + "en": "eng_Latn", # English + "zh": "zho_Hans", # Chinese 
(Simplified) + "es": "spa_Latn", # Spanish + "ar": "arb_Arab", # Arabic (Standard) + "hi": "hin_Deva", # Hindi + "bn": "ben_Beng", # Bengali + "pt": "por_Latn", # Portuguese + "ru": "rus_Cyrl", # Russian + "ja": "jpn_Jpan", # Japanese + "de": "deu_Latn", # German + "fr": "fra_Latn", # French + "ko": "kor_Hang", # Korean + "it": "ita_Latn", # Italian + "tr": "tur_Latn", # Turkish + "vi": "vie_Latn", # Vietnamese + "th": "tha_Thai", # Thai + "pl": "pol_Latn", # Polish + "nl": "nld_Latn", # Dutch + "uk": "ukr_Cyrl", # Ukrainian + "ro": "ron_Latn", # Romanian + + # Southeast Asian languages + "ms": "zsm_Latn", # Malay (Standard) + "id": "ind_Latn", # Indonesian + "tl": "tgl_Latn", # Tagalog + "my": "mya_Mymr", # Burmese + "km": "khm_Khmr", # Khmer + "lo": "lao_Laoo", # Lao + + # South Asian languages + "ur": "urd_Arab", # Urdu + "ta": "tam_Taml", # Tamil + "te": "tel_Telu", # Telugu + "mr": "mar_Deva", # Marathi + "gu": "guj_Gujr", # Gujarati + "kn": "kan_Knda", # Kannada + "ml": "mal_Mlym", # Malayalam + "pa": "pan_Guru", # Punjabi + "ne": "npi_Deva", # Nepali + "si": "sin_Sinh", # Sinhala + + # European languages + "sv": "swe_Latn", # Swedish + "da": "dan_Latn", # Danish + "fi": "fin_Latn", # Finnish + "no": "nob_Latn", # Norwegian (Bokmål) + "cs": "ces_Latn", # Czech + "sk": "slk_Latn", # Slovak + "hu": "hun_Latn", # Hungarian + "bg": "bul_Cyrl", # Bulgarian + "sr": "srp_Cyrl", # Serbian + "hr": "hrv_Latn", # Croatian + "sl": "slv_Latn", # Slovenian + "et": "est_Latn", # Estonian + "lv": "lvs_Latn", # Latvian + "lt": "lit_Latn", # Lithuanian + "el": "ell_Grek", # Greek + "he": "heb_Hebr", # Hebrew + "fa": "pes_Arab", # Persian + + # African languages + "sw": "swh_Latn", # Swahili + "am": "amh_Ethi", # Amharic + "ha": "hau_Latn", # Hausa + "ig": "ibo_Latn", # Igbo + "yo": "yor_Latn", # Yoruba + "zu": "zul_Latn", # Zulu + "xh": "xho_Latn", # Xhosa + "af": "afr_Latn", # Afrikaans + "sn": "sna_Latn", # Shona + "so": "som_Latn", # Somali + + # Other languages + "az": 
"azj_Latn", # Azerbaijani (North) + "ka": "kat_Geor", # Georgian + "kk": "kaz_Cyrl", # Kazakh + "uz": "uzn_Latn", # Uzbek (Northern) + "mn": "khk_Cyrl", # Mongolian (Halh) + "sq": "als_Latn", # Albanian + "hy": "hye_Armn", # Armenian + "be": "bel_Cyrl", # Belarusian + "bs": "bos_Latn", # Bosnian + "ca": "cat_Latn", # Catalan + "ceb": "ceb_Latn", # Cebuano + "cy": "cym_Latn", # Welsh + "eo": "epo_Latn", # Esperanto + "eu": "eus_Latn", # Basque + "gl": "glg_Latn", # Galician + "is": "isl_Latn", # Icelandic + "jv": "jav_Latn", # Javanese + "ku": "kmr_Latn", # Kurdish (Kurmanji) + "ky": "kir_Cyrl", # Kyrgyz + "la": "lat_Latn", # Latin + "lb": "ltz_Latn", # Luxembourgish + "lg": "lug_Latn", # Luganda + "mg": "plt_Latn", # Malagasy + "mk": "mkd_Cyrl", # Macedonian + "mt": "mlt_Latn", # Maltese + "ny": "nya_Latn", # Chichewa + "ps": "pbt_Arab", # Pashto (Southern) + "st": "sot_Latn", # Sesotho + "su": "sun_Latn", # Sundanese + "tg": "tgk_Cyrl", # Tajik + "tk": "tuk_Latn", # Turkmen + "ug": "uig_Arab", # Uyghur + + # Additional NLLB-200 exclusive languages (examples, 95 more) + "ace": "ace_Latn", # Acehnese + "acm": "acm_Arab", # Mesopotamian Arabic + "acq": "acq_Arab", # Ta'izzi-Adeni Arabic + "aeb": "aeb_Arab", # Tunisian Arabic + "ajp": "ajp_Arab", # South Levantine Arabic + "als": "als_Latn", # Tosk Albanian + "ars": "ars_Arab", # Najdi Arabic + "ary": "ary_Arab", # Moroccan Arabic + "arz": "arz_Arab", # Egyptian Arabic + "asm": "asm_Beng", # Assamese + "ast": "ast_Latn", # Asturian + "awa": "awa_Deva", # Awadhi + "ayr": "ayr_Latn", # Central Aymara + "azb": "azb_Arab", # South Azerbaijani + "bak": "bak_Cyrl", # Bashkir + "bam": "bam_Latn", # Bambara + "ban": "ban_Latn", # Balinese + "bho": "bho_Deva", # Bhojpuri + "bjn": "bjn_Latn", # Banjar + "bod": "bod_Tibt", # Tibetan + "bug": "bug_Latn", # Buginese + "crh": "crh_Latn", # Crimean Tatar + "cjk": "cjk_Latn", # Chokwe + "ckb": "ckb_Arab", # Central Kurdish + "dik": "dik_Latn", # Southwestern Dinka + "dyu": 
"dyu_Latn", # Dyula + "dzo": "dzo_Tibt", # Dzongkha + "fur": "fur_Latn", # Friulian + "fuv": "fuv_Latn", # Nigerian Fulfulde + "gaz": "gaz_Latn", # West Central Oromo + "grn": "grn_Latn", # Guarani + "hne": "hne_Deva", # Chhattisgarhi + "ilo": "ilo_Latn", # Iloko + "kab": "kab_Latn", # Kabyle + "kac": "kac_Latn", # Jingpho + "kam": "kam_Latn", # Kamba + "kas": "kas_Arab", # Kashmiri + "kea": "kea_Latn", # Kabuverdianu + "khk": "khk_Cyrl", # Halh Mongolian + "kin": "kin_Latn", # Kinyarwanda + "lij": "lij_Latn", # Ligurian + "lim": "lim_Latn", # Limburgish + "lin": "lin_Latn", # Lingala + "lmo": "lmo_Latn", # Lombard + "ltg": "ltg_Latn", # Latgalian + "luo": "luo_Latn", # Luo + "lus": "lus_Latn", # Mizo + "mag": "mag_Deva", # Magahi + "mai": "mai_Deva", # Maithili + "min": "min_Latn", # Minangkabau + "mni": "mni_Beng", # Meitei + "mos": "mos_Latn", # Mossi + "mri": "mri_Latn", # Maori + "nus": "nus_Latn", # Nuer + "ory": "ory_Orya", # Odia + "pag": "pag_Latn", # Pangasinan + "pap": "pap_Latn", # Papiamento + "prs": "prs_Arab", # Dari + "quy": "quy_Latn", # Ayacucho Quechua + "run": "run_Latn", # Rundi + "sag": "sag_Latn", # Sango + "san": "san_Deva", # Sanskrit + "sat": "sat_Beng", # Santali + "scn": "scn_Latn", # Sicilian + "shn": "shn_Mymr", # Shan + "srd": "srd_Latn", # Sardinian + "szl": "szl_Latn", # Silesian + "taq": "taq_Latn", # Tamasheq + "tat": "tat_Cyrl", # Tatar + "tir": "tir_Ethi", # Tigrinya + "tpi": "tpi_Latn", # Tok Pisin + "tsn": "tsn_Latn", # Tswana + "tso": "tso_Latn", # Tsonga + "tum": "tum_Latn", # Tumbuka + "twi": "twi_Latn", # Twi + "tzm": "tzm_Tfng", # Central Atlas Tamazight + "uig": "uig_Arab", # Uyghur + "vec": "vec_Latn", # Venetian + "war": "war_Latn", # Waray + "wol": "wol_Latn", # Wolof + "xho": "xho_Latn", # Xhosa + "ydd": "ydd_Hebr", # Eastern Yiddish + "yor": "yor_Latn", # Yoruba + "yue": "yue_Hant", # Cantonese + "zho_hant": "zho_Hant", # Chinese (Traditional) + } + + def get_supported_languages(self, model_type: str = "m2m100") -> 
Dict[str, str]: + """Get supported language codes for the specified model""" + if model_type == "m2m100": + return self.m2m100_lang_codes + elif model_type == "nllb200": + return self.nllb200_lang_codes + else: + raise ValueError(f"Unknown model type: {model_type}") + + def _get_model_info(self, source_lang: str, target_lang: str, model_type: str) -> tuple[str, str, str, str]: """Get the model name and language codes for translation""" - # Using M2M100 418M model (smaller, faster, commercial-friendly) - model_name = "facebook/m2m100_418M" - src_code = self.lang_codes.get(source_lang) - tgt_code = self.lang_codes.get(target_lang) + if model_type == "m2m100": + model_name = "facebook/m2m100_418M" + lang_codes = self.m2m100_lang_codes + elif model_type == "nllb200": + model_name = "facebook/nllb-200-distilled-600M" + lang_codes = self.nllb200_lang_codes + else: + raise ValueError(f"Unknown model type: {model_type}") - if not src_code or not tgt_code: - raise ValueError(f"Unsupported language pair: {source_lang} -> {target_lang}") + src_code = lang_codes.get(source_lang) + tgt_code = lang_codes.get(target_lang) - return model_name, src_code, tgt_code + if not src_code: + raise ValueError(f"Source language '{source_lang}' not supported by {model_type}. Use /api/supported-languages?model={model_type} to see available languages.") + if not tgt_code: + raise ValueError(f"Target language '{target_lang}' not supported by {model_type}. 
Use /api/supported-languages?model={model_type} to see available languages.") - def load_model(self, source_lang: str, target_lang: str) -> None: + return model_name, src_code, tgt_code, model_type + + def load_model(self, source_lang: str, target_lang: str, model_type: str = "m2m100") -> None: """Load translation model for specific language pair""" - model_name, _, _ = self._get_model_info(source_lang, target_lang) + model_name, _, _, _ = self._get_model_info(source_lang, target_lang, model_type) if model_name in self.models: logger.info(f"Model {model_name} already loaded") @@ -165,18 +382,30 @@ class TranslationService: try: logger.info(f"Loading model: {model_name}") - tokenizer = M2M100Tokenizer.from_pretrained( - model_name, - cache_dir=settings.model_cache_dir - ) - model = M2M100ForConditionalGeneration.from_pretrained( - model_name, - cache_dir=settings.model_cache_dir - ).to(self.device) + + if model_type == "m2m100": + tokenizer = M2M100Tokenizer.from_pretrained( + model_name, + cache_dir=settings.model_cache_dir + ) + model = M2M100ForConditionalGeneration.from_pretrained( + model_name, + cache_dir=settings.model_cache_dir + ).to(self.device) + elif model_type == "nllb200": + tokenizer = AutoTokenizer.from_pretrained( + model_name, + cache_dir=settings.model_cache_dir + ) + model = AutoModelForSeq2SeqLM.from_pretrained( + model_name, + cache_dir=settings.model_cache_dir + ).to(self.device) self.models[model_name] = { "tokenizer": tokenizer, - "model": model + "model": model, + "type": model_type } logger.info(f"Successfully loaded model: {model_name}") @@ -184,7 +413,7 @@ class TranslationService: logger.error(f"Error loading model {model_name}: {str(e)}") raise - def translate(self, text: str, source_lang: str, target_lang: str) -> tuple[str, str]: + def translate(self, text: str, source_lang: str, target_lang: str, model_type: str = "m2m100") -> tuple[str, str]: """ Translate text from source language to target language @@ -192,41 +421,66 @@ class 
TranslationService: text: Text to translate source_lang: Source language code target_lang: Target language code + model_type: Model to use ('m2m100' or 'nllb200') Returns: Tuple of (translated_text, model_name) """ - model_name, src_code, tgt_code = self._get_model_info(source_lang, target_lang) + model_name, src_code, tgt_code, _ = self._get_model_info(source_lang, target_lang, model_type) # Load model if not already loaded if model_name not in self.models: - self.load_model(source_lang, target_lang) + self.load_model(source_lang, target_lang, model_type) try: tokenizer = self.models[model_name]["tokenizer"] model = self.models[model_name]["model"] - # Set source language for tokenizer - tokenizer.src_lang = src_code + if model_type == "m2m100": + # M2M100 uses src_lang attribute + tokenizer.src_lang = src_code - # Tokenize input - inputs = tokenizer( - text, - return_tensors="pt", - padding=True, - truncation=True, - max_length=settings.max_length - ).to(self.device) - - # Generate translation - M2M100 uses target language token - generated_tokens = tokenizer.get_lang_id(tgt_code) - - with torch.no_grad(): - translated = model.generate( - **inputs, - forced_bos_token_id=generated_tokens, + # Tokenize input + inputs = tokenizer( + text, + return_tensors="pt", + padding=True, + truncation=True, max_length=settings.max_length - ) + ).to(self.device) + + # Generate translation + generated_tokens = tokenizer.get_lang_id(tgt_code) + + with torch.no_grad(): + translated = model.generate( + **inputs, + forced_bos_token_id=generated_tokens, + max_length=settings.max_length + ) + + elif model_type == "nllb200": + # NLLB uses different approach with src_lang in tokenizer call + tokenizer.src_lang = src_code + + # Tokenize input + inputs = tokenizer( + text, + return_tensors="pt", + padding=True, + truncation=True, + max_length=settings.max_length + ).to(self.device) + + # Generate translation - NLLB uses forced_bos_token_id + forced_bos_token_id = 
tokenizer.convert_tokens_to_ids(tgt_code) + + with torch.no_grad(): + translated = model.generate( + **inputs, + forced_bos_token_id=forced_bos_token_id, + max_length=settings.max_length + ) # Decode output translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0] @@ -238,17 +492,18 @@ class TranslationService: raise def preload_all_models(self) -> None: - """Preload all supported translation models""" + """Preload commonly used translation models""" + # Only preload M2M100 by default (commercial-friendly) language_pairs = [ - ("ms", "en"), - ("en", "ms") + ("ms", "en", "m2m100"), + ("en", "ms", "m2m100") ] - for source, target in language_pairs: + for source, target, model_type in language_pairs: try: - self.load_model(source, target) + self.load_model(source, target, model_type) except Exception as e: - logger.warning(f"Could not preload model for {source}->{target}: {str(e)}") + logger.warning(f"Could not preload model for {source}->{target} ({model_type}): {str(e)}") def is_ready(self) -> bool: """Check if at least one model is loaded"""