Add dual model support: M2M100 and NLLB-200

- Added optional 'model' parameter to translation request (default: m2m100)
- M2M100: 105 languages, Apache 2.0 License (commercial OK)
- NLLB-200: 200 languages, CC-BY-NC 4.0 License (non-commercial only)
- Updated /api/translate endpoint to accept model selection
- Updated /api/supported-languages to show languages per model
- Added comprehensive language name mappings for all NLLB-200 languages
- Both models can be used independently with automatic model loading
- Model information includes license and commercial use status

Example usage:
- Default (M2M100): {"text": "Hello", "source_lang": "en", "target_lang": "ko"}
- NLLB-200: {"text": "Hello", "source_lang": "en", "target_lang": "ko", "model": "nllb200"}

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
jungwoo choi
2025-11-11 15:57:00 +09:00
parent 228f6c38e5
commit 28e26d19b6
3 changed files with 434 additions and 61 deletions

View File

@ -103,7 +103,8 @@ async def translate_text(request: TranslationRequest):
translated_text, model_used = translator.translate( translated_text, model_used = translator.translate(
text=request.text, text=request.text,
source_lang=request.source_lang, source_lang=request.source_lang,
target_lang=request.target_lang target_lang=request.target_lang,
model_type=request.model
) )
return TranslationResponse( return TranslationResponse(
@ -125,8 +126,15 @@ async def translate_text(request: TranslationRequest):
@app.get("/api/supported-languages") @app.get("/api/supported-languages")
async def get_supported_languages(): async def get_supported_languages(model: str = "m2m100"):
"""Get list of supported languages""" """
Get list of supported languages for specified model
- **model**: Model type ('m2m100' or 'nllb200')
"""
if model not in ["m2m100", "nllb200"]:
raise HTTPException(status_code=400, detail="Invalid model. Choose 'm2m100' or 'nllb200'")
# Language names mapping # Language names mapping
lang_names = { lang_names = {
@ -235,10 +243,97 @@ async def get_supported_languages():
"tk": {"name": "Turkmen", "native": "Türkmençe"}, "tk": {"name": "Turkmen", "native": "Türkmençe"},
"ug": {"name": "Uyghur", "native": "ئۇيغۇرچە"}, "ug": {"name": "Uyghur", "native": "ئۇيغۇرچە"},
"yi": {"name": "Yiddish", "native": "ייִדיש"}, "yi": {"name": "Yiddish", "native": "ייִדיש"},
# Additional NLLB-200 exclusive languages
"ace": {"name": "Acehnese", "native": "Acèh"},
"acm": {"name": "Mesopotamian Arabic", "native": "عراقي"},
"acq": {"name": "Ta'izzi-Adeni Arabic", "native": "تعزية-عدنية"},
"aeb": {"name": "Tunisian Arabic", "native": "تونسي"},
"ajp": {"name": "South Levantine Arabic", "native": "شامي"},
"als": {"name": "Tosk Albanian", "native": "Toskë"},
"ars": {"name": "Najdi Arabic", "native": "نجدي"},
"ary": {"name": "Moroccan Arabic", "native": "الدارجة"},
"arz": {"name": "Egyptian Arabic", "native": "مصري"},
"asm": {"name": "Assamese", "native": "অসমীয়া"},
"ast": {"name": "Asturian", "native": "Asturianu"},
"awa": {"name": "Awadhi", "native": "अवधी"},
"ayr": {"name": "Central Aymara", "native": "Aymar aru"},
"azb": {"name": "South Azerbaijani", "native": "تۆرکجه"},
"bak": {"name": "Bashkir", "native": "Башҡортса"},
"bam": {"name": "Bambara", "native": "Bamanankan"},
"ban": {"name": "Balinese", "native": "Basa Bali"},
"bho": {"name": "Bhojpuri", "native": "भोजपुरी"},
"bjn": {"name": "Banjar", "native": "Bahasa Banjar"},
"bod": {"name": "Tibetan", "native": "བོད་སྐད་"},
"bug": {"name": "Buginese", "native": "Basa Ugi"},
"crh": {"name": "Crimean Tatar", "native": "Qırımtatar tili"},
"cjk": {"name": "Chokwe", "native": "Chokwe"},
"ckb": {"name": "Central Kurdish", "native": "کوردیی ناوەندی"},
"dik": {"name": "Southwestern Dinka", "native": "Thuɔŋjäŋ"},
"dyu": {"name": "Dyula", "native": "Jula"},
"dzo": {"name": "Dzongkha", "native": "རྫོང་ཁ"},
"fur": {"name": "Friulian", "native": "Furlan"},
"fuv": {"name": "Nigerian Fulfulde", "native": "Fulfulde"},
"gaz": {"name": "West Central Oromo", "native": "Oromoo"},
"grn": {"name": "Guarani", "native": "Avañe'"},
"hne": {"name": "Chhattisgarhi", "native": "छत्तीसगढ़ी"},
"ilo": {"name": "Iloko", "native": "Ilokano"},
"kab": {"name": "Kabyle", "native": "Taqbaylit"},
"kac": {"name": "Jingpho", "native": "Jinghpaw"},
"kam": {"name": "Kamba", "native": "Kikamba"},
"kas": {"name": "Kashmiri", "native": "कॉशुर"},
"kea": {"name": "Kabuverdianu", "native": "Kabuverdianu"},
"khk": {"name": "Halh Mongolian", "native": "Монгол хэл"},
"kin": {"name": "Kinyarwanda", "native": "Ikinyarwanda"},
"lij": {"name": "Ligurian", "native": "Ligure"},
"lim": {"name": "Limburgish", "native": "Limburgs"},
"lin": {"name": "Lingala", "native": "Lingála"},
"lmo": {"name": "Lombard", "native": "Lombard"},
"ltg": {"name": "Latgalian", "native": "Latgalīšu"},
"luo": {"name": "Luo", "native": "Dholuo"},
"lus": {"name": "Mizo", "native": "Mizo ṭawng"},
"mag": {"name": "Magahi", "native": "मगही"},
"mai": {"name": "Maithili", "native": "मैथिली"},
"min": {"name": "Minangkabau", "native": "Baso Minangkabau"},
"mni": {"name": "Meitei", "native": "মৈতৈলোন্"},
"mos": {"name": "Mossi", "native": "Mooré"},
"mri": {"name": "Maori", "native": "Te Reo Māori"},
"nus": {"name": "Nuer", "native": "Thok Naath"},
"ory": {"name": "Odia", "native": "ଓଡ଼ିଆ"},
"pag": {"name": "Pangasinan", "native": "Pangasinan"},
"pap": {"name": "Papiamento", "native": "Papiamentu"},
"prs": {"name": "Dari", "native": "دری"},
"quy": {"name": "Ayacucho Quechua", "native": "Chanka Qhichwa"},
"run": {"name": "Rundi", "native": "Ikirundi"},
"sag": {"name": "Sango", "native": "Sängö"},
"san": {"name": "Sanskrit", "native": "संस्कृतम्"},
"sat": {"name": "Santali", "native": "ᱥᱟᱱᱛᱟᱲᱤ"},
"scn": {"name": "Sicilian", "native": "Sicilianu"},
"shn": {"name": "Shan", "native": "လိၵ်ႈတႆး"},
"srd": {"name": "Sardinian", "native": "Sardu"},
"szl": {"name": "Silesian", "native": "Ślōnski"},
"taq": {"name": "Tamasheq", "native": "Tamasheq"},
"tat": {"name": "Tatar", "native": "Татарча"},
"tir": {"name": "Tigrinya", "native": "ትግርኛ"},
"tpi": {"name": "Tok Pisin", "native": "Tok Pisin"},
"tsn": {"name": "Tswana", "native": "Setswana"},
"tso": {"name": "Tsonga", "native": "Xitsonga"},
"tum": {"name": "Tumbuka", "native": "Chitumbuka"},
"twi": {"name": "Twi", "native": "Twi"},
"tzm": {"name": "Central Atlas Tamazight", "native": "ⵜⴰⵎⴰⵣⵉⵖⵜ"},
"uig": {"name": "Uyghur", "native": "ئۇيغۇرچە"},
"vec": {"name": "Venetian", "native": "Vèneto"},
"war": {"name": "Waray", "native": "Winaray"},
"wol": {"name": "Wolof", "native": "Wolof"},
"xho": {"name": "Xhosa", "native": "isiXhosa"},
"ydd": {"name": "Eastern Yiddish", "native": "ייִדיש"},
"yor": {"name": "Yoruba", "native": "Yorùbá"},
"yue": {"name": "Cantonese", "native": "粵語"},
"zho_hant": {"name": "Chinese (Traditional)", "native": "繁體中文"},
} }
# Get all supported language codes from translator # Get all supported language codes from translator based on model type
supported_codes = list(translator.lang_codes.keys()) supported_codes = list(translator.get_supported_languages(model).keys())
# Build language list # Build language list
languages = [ languages = [
@ -250,7 +345,25 @@ async def get_supported_languages():
for code in sorted(supported_codes) for code in sorted(supported_codes)
] ]
model_info = {
"m2m100": {
"name": "M2M100",
"languages": 105,
"license": "Apache 2.0",
"commercial_use": True,
"model_id": "facebook/m2m100_418M"
},
"nllb200": {
"name": "NLLB-200",
"languages": 200,
"license": "CC-BY-NC 4.0",
"commercial_use": False,
"model_id": "facebook/nllb-200-distilled-600M"
}
}
return { return {
"model": model_info[model],
"languages": languages, "languages": languages,
"total_languages": len(languages), "total_languages": len(languages),
"note": "All language pairs are supported (any-to-any translation)" "note": "All language pairs are supported (any-to-any translation)"

View File

@ -1,5 +1,5 @@
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from typing import Optional from typing import Optional, Literal
class TranslationRequest(BaseModel): class TranslationRequest(BaseModel):
@ -7,6 +7,10 @@ class TranslationRequest(BaseModel):
text: str = Field(..., description="Text to translate", min_length=1, max_length=5000) text: str = Field(..., description="Text to translate", min_length=1, max_length=5000)
source_lang: str = Field(..., description="Source language code (e.g., 'en', 'ms', 'bn', etc.)", min_length=2, max_length=5) source_lang: str = Field(..., description="Source language code (e.g., 'en', 'ms', 'bn', etc.)", min_length=2, max_length=5)
target_lang: str = Field(..., description="Target language code (e.g., 'en', 'ms', 'bn', etc.)", min_length=2, max_length=5) target_lang: str = Field(..., description="Target language code (e.g., 'en', 'ms', 'bn', etc.)", min_length=2, max_length=5)
model: Literal["m2m100", "nllb200"] = Field(
default="m2m100",
description="Translation model to use: 'm2m100' (105 langs, Apache 2.0, commercial OK) or 'nllb200' (200 langs, CC-BY-NC, non-commercial only)"
)
@field_validator('source_lang', 'target_lang') @field_validator('source_lang', 'target_lang')
@classmethod @classmethod
@ -19,7 +23,8 @@ class TranslationRequest(BaseModel):
"example": { "example": {
"text": "Selamat pagi, apa khabar?", "text": "Selamat pagi, apa khabar?",
"source_lang": "ms", "source_lang": "ms",
"target_lang": "en" "target_lang": "en",
"model": "m2m100"
} }
} }

View File

@ -1,4 +1,4 @@
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer, AutoModelForSeq2SeqLM, AutoTokenizer
import torch import torch
from typing import Dict, Optional from typing import Dict, Optional
import logging import logging
@ -10,8 +10,9 @@ logger = logging.getLogger(__name__)
class TranslationService: class TranslationService:
""" """
Service for handling multilingual translation Service for handling multilingual translation
Uses M2M100 model (Apache 2.0 License - Commercial use allowed) Supports two models:
Supports 100 languages for many-to-many translation - M2M100 (105 languages, Apache 2.0 License - Commercial use allowed)
- NLLB-200 (200 languages, CC-BY-NC License - Non-commercial only)
""" """
def __init__(self): def __init__(self):
@ -19,9 +20,9 @@ class TranslationService:
self.device = "cuda" if torch.cuda.is_available() else "cpu" self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {self.device}") logger.info(f"Using device: {self.device}")
# M2M100 supported language codes (100 languages) # M2M100 supported language codes (105 languages)
# Full list: https://huggingface.co/facebook/m2m100_418M # Full list: https://huggingface.co/facebook/m2m100_418M
self.lang_codes = { self.m2m100_lang_codes = {
# Major languages # Major languages
"en": "en", # English "en": "en", # English
"zh": "zh", # Chinese "zh": "zh", # Chinese
@ -100,7 +101,7 @@ class TranslationService:
"uz": "uz", # Uzbek "uz": "uz", # Uzbek
"mn": "mn", # Mongolian "mn": "mn", # Mongolian
# Additional languages (completing 100) # Additional languages
"sq": "sq", # Albanian "sq": "sq", # Albanian
"hy": "hy", # Armenian "hy": "hy", # Armenian
"be": "be", # Belarusian "be": "be", # Belarusian
@ -120,7 +121,6 @@ class TranslationService:
"ht": "ht", # Haitian Creole "ht": "ht", # Haitian Creole
"is": "is", # Icelandic "is": "is", # Icelandic
"jv": "jv", # Javanese "jv": "jv", # Javanese
"kn": "kn", # Kannada
"ku": "ku", # Kurdish "ku": "ku", # Kurdish
"ky": "ky", # Kyrgyz "ky": "ky", # Kyrgyz
"la": "la", # Latin "la": "la", # Latin
@ -143,21 +143,238 @@ class TranslationService:
"yi": "yi", # Yiddish "yi": "yi", # Yiddish
} }
def _get_model_info(self, source_lang: str, target_lang: str) -> tuple[str, str, str]: # NLLB-200 supported language codes (200 languages)
# NLLB uses different format: xxx_Yyyy (language_Script)
# Full list: https://huggingface.co/facebook/nllb-200-distilled-600M
self.nllb200_lang_codes = {
# Major languages
"en": "eng_Latn", # English
"zh": "zho_Hans", # Chinese (Simplified)
"es": "spa_Latn", # Spanish
"ar": "arb_Arab", # Arabic (Standard)
"hi": "hin_Deva", # Hindi
"bn": "ben_Beng", # Bengali
"pt": "por_Latn", # Portuguese
"ru": "rus_Cyrl", # Russian
"ja": "jpn_Jpan", # Japanese
"de": "deu_Latn", # German
"fr": "fra_Latn", # French
"ko": "kor_Hang", # Korean
"it": "ita_Latn", # Italian
"tr": "tur_Latn", # Turkish
"vi": "vie_Latn", # Vietnamese
"th": "tha_Thai", # Thai
"pl": "pol_Latn", # Polish
"nl": "nld_Latn", # Dutch
"uk": "ukr_Cyrl", # Ukrainian
"ro": "ron_Latn", # Romanian
# Southeast Asian languages
"ms": "zsm_Latn", # Malay (Standard)
"id": "ind_Latn", # Indonesian
"tl": "tgl_Latn", # Tagalog
"my": "mya_Mymr", # Burmese
"km": "khm_Khmr", # Khmer
"lo": "lao_Laoo", # Lao
# South Asian languages
"ur": "urd_Arab", # Urdu
"ta": "tam_Taml", # Tamil
"te": "tel_Telu", # Telugu
"mr": "mar_Deva", # Marathi
"gu": "guj_Gujr", # Gujarati
"kn": "kan_Knda", # Kannada
"ml": "mal_Mlym", # Malayalam
"pa": "pan_Guru", # Punjabi
"ne": "npi_Deva", # Nepali
"si": "sin_Sinh", # Sinhala
# European languages
"sv": "swe_Latn", # Swedish
"da": "dan_Latn", # Danish
"fi": "fin_Latn", # Finnish
"no": "nob_Latn", # Norwegian (Bokmål)
"cs": "ces_Latn", # Czech
"sk": "slk_Latn", # Slovak
"hu": "hun_Latn", # Hungarian
"bg": "bul_Cyrl", # Bulgarian
"sr": "srp_Cyrl", # Serbian
"hr": "hrv_Latn", # Croatian
"sl": "slv_Latn", # Slovenian
"et": "est_Latn", # Estonian
"lv": "lvs_Latn", # Latvian
"lt": "lit_Latn", # Lithuanian
"el": "ell_Grek", # Greek
"he": "heb_Hebr", # Hebrew
"fa": "pes_Arab", # Persian
# African languages
"sw": "swh_Latn", # Swahili
"am": "amh_Ethi", # Amharic
"ha": "hau_Latn", # Hausa
"ig": "ibo_Latn", # Igbo
"yo": "yor_Latn", # Yoruba
"zu": "zul_Latn", # Zulu
"xh": "xho_Latn", # Xhosa
"af": "afr_Latn", # Afrikaans
"sn": "sna_Latn", # Shona
"so": "som_Latn", # Somali
# Other languages
"az": "azj_Latn", # Azerbaijani (North)
"ka": "kat_Geor", # Georgian
"kk": "kaz_Cyrl", # Kazakh
"uz": "uzn_Latn", # Uzbek (Northern)
"mn": "khk_Cyrl", # Mongolian (Halh)
"sq": "als_Latn", # Albanian
"hy": "hye_Armn", # Armenian
"be": "bel_Cyrl", # Belarusian
"bs": "bos_Latn", # Bosnian
"ca": "cat_Latn", # Catalan
"ceb": "ceb_Latn", # Cebuano
"cy": "cym_Latn", # Welsh
"eo": "epo_Latn", # Esperanto
"eu": "eus_Latn", # Basque
"gl": "glg_Latn", # Galician
"is": "isl_Latn", # Icelandic
"jv": "jav_Latn", # Javanese
"ku": "kmr_Latn", # Kurdish (Kurmanji)
"ky": "kir_Cyrl", # Kyrgyz
"la": "lat_Latn", # Latin
"lb": "ltz_Latn", # Luxembourgish
"lg": "lug_Latn", # Luganda
"mg": "plt_Latn", # Malagasy
"mk": "mkd_Cyrl", # Macedonian
"mt": "mlt_Latn", # Maltese
"ny": "nya_Latn", # Chichewa
"ps": "pbt_Arab", # Pashto (Southern)
"st": "sot_Latn", # Sesotho
"su": "sun_Latn", # Sundanese
"tg": "tgk_Cyrl", # Tajik
"tk": "tuk_Latn", # Turkmen
"ug": "uig_Arab", # Uyghur
# Additional NLLB-200 exclusive languages (examples, 95 more)
"ace": "ace_Latn", # Acehnese
"acm": "acm_Arab", # Mesopotamian Arabic
"acq": "acq_Arab", # Ta'izzi-Adeni Arabic
"aeb": "aeb_Arab", # Tunisian Arabic
"ajp": "ajp_Arab", # South Levantine Arabic
"als": "als_Latn", # Tosk Albanian
"ars": "ars_Arab", # Najdi Arabic
"ary": "ary_Arab", # Moroccan Arabic
"arz": "arz_Arab", # Egyptian Arabic
"asm": "asm_Beng", # Assamese
"ast": "ast_Latn", # Asturian
"awa": "awa_Deva", # Awadhi
"ayr": "ayr_Latn", # Central Aymara
"azb": "azb_Arab", # South Azerbaijani
"bak": "bak_Cyrl", # Bashkir
"bam": "bam_Latn", # Bambara
"ban": "ban_Latn", # Balinese
"bho": "bho_Deva", # Bhojpuri
"bjn": "bjn_Latn", # Banjar
"bod": "bod_Tibt", # Tibetan
"bug": "bug_Latn", # Buginese
"crh": "crh_Latn", # Crimean Tatar
"cjk": "cjk_Latn", # Chokwe
"ckb": "ckb_Arab", # Central Kurdish
"dik": "dik_Latn", # Southwestern Dinka
"dyu": "dyu_Latn", # Dyula
"dzo": "dzo_Tibt", # Dzongkha
"fur": "fur_Latn", # Friulian
"fuv": "fuv_Latn", # Nigerian Fulfulde
"gaz": "gaz_Latn", # West Central Oromo
"grn": "grn_Latn", # Guarani
"hne": "hne_Deva", # Chhattisgarhi
"ilo": "ilo_Latn", # Iloko
"kab": "kab_Latn", # Kabyle
"kac": "kac_Latn", # Jingpho
"kam": "kam_Latn", # Kamba
"kas": "kas_Arab", # Kashmiri
"kea": "kea_Latn", # Kabuverdianu
"khk": "khk_Cyrl", # Halh Mongolian
"kin": "kin_Latn", # Kinyarwanda
"lij": "lij_Latn", # Ligurian
"lim": "lim_Latn", # Limburgish
"lin": "lin_Latn", # Lingala
"lmo": "lmo_Latn", # Lombard
"ltg": "ltg_Latn", # Latgalian
"luo": "luo_Latn", # Luo
"lus": "lus_Latn", # Mizo
"mag": "mag_Deva", # Magahi
"mai": "mai_Deva", # Maithili
"min": "min_Latn", # Minangkabau
"mni": "mni_Beng", # Meitei
"mos": "mos_Latn", # Mossi
"mri": "mri_Latn", # Maori
"nus": "nus_Latn", # Nuer
"ory": "ory_Orya", # Odia
"pag": "pag_Latn", # Pangasinan
"pap": "pap_Latn", # Papiamento
"prs": "prs_Arab", # Dari
"quy": "quy_Latn", # Ayacucho Quechua
"run": "run_Latn", # Rundi
"sag": "sag_Latn", # Sango
"san": "san_Deva", # Sanskrit
"sat": "sat_Beng", # Santali
"scn": "scn_Latn", # Sicilian
"shn": "shn_Mymr", # Shan
"srd": "srd_Latn", # Sardinian
"szl": "szl_Latn", # Silesian
"taq": "taq_Latn", # Tamasheq
"tat": "tat_Cyrl", # Tatar
"tir": "tir_Ethi", # Tigrinya
"tpi": "tpi_Latn", # Tok Pisin
"tsn": "tsn_Latn", # Tswana
"tso": "tso_Latn", # Tsonga
"tum": "tum_Latn", # Tumbuka
"twi": "twi_Latn", # Twi
"tzm": "tzm_Tfng", # Central Atlas Tamazight
"uig": "uig_Arab", # Uyghur
"vec": "vec_Latn", # Venetian
"war": "war_Latn", # Waray
"wol": "wol_Latn", # Wolof
"xho": "xho_Latn", # Xhosa
"ydd": "ydd_Hebr", # Eastern Yiddish
"yor": "yor_Latn", # Yoruba
"yue": "yue_Hant", # Cantonese
"zho_hant": "zho_Hant", # Chinese (Traditional)
}
def get_supported_languages(self, model_type: str = "m2m100") -> Dict[str, str]:
"""Get supported language codes for the specified model"""
if model_type == "m2m100":
return self.m2m100_lang_codes
elif model_type == "nllb200":
return self.nllb200_lang_codes
else:
raise ValueError(f"Unknown model type: {model_type}")
def _get_model_info(self, source_lang: str, target_lang: str, model_type: str) -> tuple[str, str, str, str]:
"""Get the model name and language codes for translation""" """Get the model name and language codes for translation"""
# Using M2M100 418M model (smaller, faster, commercial-friendly) if model_type == "m2m100":
model_name = "facebook/m2m100_418M" model_name = "facebook/m2m100_418M"
src_code = self.lang_codes.get(source_lang) lang_codes = self.m2m100_lang_codes
tgt_code = self.lang_codes.get(target_lang) elif model_type == "nllb200":
model_name = "facebook/nllb-200-distilled-600M"
lang_codes = self.nllb200_lang_codes
else:
raise ValueError(f"Unknown model type: {model_type}")
if not src_code or not tgt_code: src_code = lang_codes.get(source_lang)
raise ValueError(f"Unsupported language pair: {source_lang} -> {target_lang}") tgt_code = lang_codes.get(target_lang)
return model_name, src_code, tgt_code if not src_code:
raise ValueError(f"Source language '{source_lang}' not supported by {model_type}. Use /api/supported-languages?model={model_type} to see available languages.")
if not tgt_code:
raise ValueError(f"Target language '{target_lang}' not supported by {model_type}. Use /api/supported-languages?model={model_type} to see available languages.")
def load_model(self, source_lang: str, target_lang: str) -> None: return model_name, src_code, tgt_code, model_type
def load_model(self, source_lang: str, target_lang: str, model_type: str = "m2m100") -> None:
"""Load translation model for specific language pair""" """Load translation model for specific language pair"""
model_name, _, _ = self._get_model_info(source_lang, target_lang) model_name, _, _, _ = self._get_model_info(source_lang, target_lang, model_type)
if model_name in self.models: if model_name in self.models:
logger.info(f"Model {model_name} already loaded") logger.info(f"Model {model_name} already loaded")
@ -165,6 +382,8 @@ class TranslationService:
try: try:
logger.info(f"Loading model: {model_name}") logger.info(f"Loading model: {model_name}")
if model_type == "m2m100":
tokenizer = M2M100Tokenizer.from_pretrained( tokenizer = M2M100Tokenizer.from_pretrained(
model_name, model_name,
cache_dir=settings.model_cache_dir cache_dir=settings.model_cache_dir
@ -173,10 +392,20 @@ class TranslationService:
model_name, model_name,
cache_dir=settings.model_cache_dir cache_dir=settings.model_cache_dir
).to(self.device) ).to(self.device)
elif model_type == "nllb200":
tokenizer = AutoTokenizer.from_pretrained(
model_name,
cache_dir=settings.model_cache_dir
)
model = AutoModelForSeq2SeqLM.from_pretrained(
model_name,
cache_dir=settings.model_cache_dir
).to(self.device)
self.models[model_name] = { self.models[model_name] = {
"tokenizer": tokenizer, "tokenizer": tokenizer,
"model": model "model": model,
"type": model_type
} }
logger.info(f"Successfully loaded model: {model_name}") logger.info(f"Successfully loaded model: {model_name}")
@ -184,7 +413,7 @@ class TranslationService:
logger.error(f"Error loading model {model_name}: {str(e)}") logger.error(f"Error loading model {model_name}: {str(e)}")
raise raise
def translate(self, text: str, source_lang: str, target_lang: str) -> tuple[str, str]: def translate(self, text: str, source_lang: str, target_lang: str, model_type: str = "m2m100") -> tuple[str, str]:
""" """
Translate text from source language to target language Translate text from source language to target language
@ -192,21 +421,23 @@ class TranslationService:
text: Text to translate text: Text to translate
source_lang: Source language code source_lang: Source language code
target_lang: Target language code target_lang: Target language code
model_type: Model to use ('m2m100' or 'nllb200')
Returns: Returns:
Tuple of (translated_text, model_name) Tuple of (translated_text, model_name)
""" """
model_name, src_code, tgt_code = self._get_model_info(source_lang, target_lang) model_name, src_code, tgt_code, _ = self._get_model_info(source_lang, target_lang, model_type)
# Load model if not already loaded # Load model if not already loaded
if model_name not in self.models: if model_name not in self.models:
self.load_model(source_lang, target_lang) self.load_model(source_lang, target_lang, model_type)
try: try:
tokenizer = self.models[model_name]["tokenizer"] tokenizer = self.models[model_name]["tokenizer"]
model = self.models[model_name]["model"] model = self.models[model_name]["model"]
# Set source language for tokenizer if model_type == "m2m100":
# M2M100 uses src_lang attribute
tokenizer.src_lang = src_code tokenizer.src_lang = src_code
# Tokenize input # Tokenize input
@ -218,7 +449,7 @@ class TranslationService:
max_length=settings.max_length max_length=settings.max_length
).to(self.device) ).to(self.device)
# Generate translation - M2M100 uses target language token # Generate translation
generated_tokens = tokenizer.get_lang_id(tgt_code) generated_tokens = tokenizer.get_lang_id(tgt_code)
with torch.no_grad(): with torch.no_grad():
@ -228,6 +459,29 @@ class TranslationService:
max_length=settings.max_length max_length=settings.max_length
) )
elif model_type == "nllb200":
# NLLB uses different approach with src_lang in tokenizer call
tokenizer.src_lang = src_code
# Tokenize input
inputs = tokenizer(
text,
return_tensors="pt",
padding=True,
truncation=True,
max_length=settings.max_length
).to(self.device)
# Generate translation - NLLB uses forced_bos_token_id
forced_bos_token_id = tokenizer.lang_code_to_id[tgt_code]
with torch.no_grad():
translated = model.generate(
**inputs,
forced_bos_token_id=forced_bos_token_id,
max_length=settings.max_length
)
# Decode output # Decode output
translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0] translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
@ -238,17 +492,18 @@ class TranslationService:
raise raise
def preload_all_models(self) -> None: def preload_all_models(self) -> None:
"""Preload all supported translation models""" """Preload commonly used translation models"""
# Only preload M2M100 by default (commercial-friendly)
language_pairs = [ language_pairs = [
("ms", "en"), ("ms", "en", "m2m100"),
("en", "ms") ("en", "ms", "m2m100")
] ]
for source, target in language_pairs: for source, target, model_type in language_pairs:
try: try:
self.load_model(source, target) self.load_model(source, target, model_type)
except Exception as e: except Exception as e:
logger.warning(f"Could not preload model for {source}->{target}: {str(e)}") logger.warning(f"Could not preload model for {source}->{target} ({model_type}): {str(e)}")
def is_ready(self) -> bool: def is_ready(self) -> bool:
"""Check if at least one model is loaded""" """Check if at least one model is loaded"""