Updated dual model system to fully support both M2M100 and NLLB-200: **NLLB-200 Model (204 languages)** - Added all 204 FLORES-200 language codes to nllb200_lang_codes dictionary - Updated language code mappings with FLORES-200 format (xxx_Yyyy) - Added 24+ NLLB-exclusive languages including: - Southeast Asian: Acehnese, Balinese, Banjar, Buginese, Minangkabau - South Asian: Assamese, Awadhi, Bhojpuri, Chhattisgarhi, Magahi, Maithili, Meitei, Odia, Santali - African: Akan, Bambara, Bemba, Chokwe, Dyula, Fon, Kikuyu, Kimbundu, Kongo, Luba-Kasai, Luo, Mossi, Nuer - Arabic dialects: Mesopotamian, Najdi, Moroccan, Egyptian, Tunisian, South/North Levantine - European regional: Asturian, Friulian, Latgalian, Ligurian, Limburgish, Lombard, Norwegian Nynorsk/Bokmål, Occitan, Sardinian, Sicilian, Silesian, Venetian - Other: Dzongkha, Fijian, Guarani, Kabyle, Kabuverdianu, Papiamento, Quechua, Samoan, Sango, Shan, Tamasheq, Tibetan, Tok Pisin **Updated Files** - app/translator.py: Complete NLLB-200 language mappings (204 languages) - app/main.py: Added display names for all 204+ language codes - README.md: Updated with dual model system, NLLB-200 details, license info - CLAUDE.md: Updated developer documentation with model architecture **Testing** - Verified M2M100: 105 languages working ✅ - Verified NLLB-200: 204 languages working ✅ - Tested NLLB-exclusive languages (Bemba, Fon, etc.) ✅ **License Information** - M2M100: Apache 2.0 - Commercial use allowed - NLLB-200: CC-BY-NC 4.0 - Non-commercial only 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
527 lines
20 KiB
Python
527 lines
20 KiB
Python
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer, AutoModelForSeq2SeqLM, AutoTokenizer
|
|
import torch
|
|
from typing import Dict, Optional
|
|
import logging
|
|
from .config import settings
|
|
|
|
# Module-level logger named after the module, per standard logging convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TranslationService:
    """
    Service for handling multilingual translation.

    Supports two models:
    - M2M100 (105 languages, Apache 2.0 License - Commercial use allowed)
    - NLLB-200 (200 languages, CC-BY-NC License - Non-commercial only)

    Models are downloaded/loaded lazily per model type (see load_model) and
    cached in ``self.models`` keyed by Hugging Face model name.
    """
|
|
|
|
def __init__(self):
    """Select the compute device and build the per-model language catalogs.

    No model weights are loaded here; loading is deferred to load_model so
    importing this module stays cheap.
    """
    # Cache of loaded models: model_name -> {"tokenizer", "model", "type"}.
    self.models: Dict[str, Dict] = {}
    # Prefer GPU when available; all models are moved to this device on load.
    self.device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {self.device}")

    # M2M100 supported language codes (105 languages).
    # M2M100 uses plain ISO codes, so API code -> model code is the identity.
    # Full list: https://huggingface.co/facebook/m2m100_418M
    self.m2m100_lang_codes = {
        # Major languages
        "en": "en",  # English
        "zh": "zh",  # Chinese
        "es": "es",  # Spanish
        "ar": "ar",  # Arabic
        "hi": "hi",  # Hindi
        "bn": "bn",  # Bengali
        "pt": "pt",  # Portuguese
        "ru": "ru",  # Russian
        "ja": "ja",  # Japanese
        "de": "de",  # German
        "fr": "fr",  # French
        "ko": "ko",  # Korean
        "it": "it",  # Italian
        "tr": "tr",  # Turkish
        "vi": "vi",  # Vietnamese
        "th": "th",  # Thai
        "pl": "pl",  # Polish
        "nl": "nl",  # Dutch
        "uk": "uk",  # Ukrainian
        "ro": "ro",  # Romanian

        # Southeast Asian languages
        "ms": "ms",  # Malay
        "id": "id",  # Indonesian
        "tl": "tl",  # Tagalog
        "my": "my",  # Burmese
        "km": "km",  # Khmer
        "lo": "lo",  # Lao

        # South Asian languages
        "ur": "ur",  # Urdu
        "ta": "ta",  # Tamil
        "te": "te",  # Telugu
        "mr": "mr",  # Marathi
        "gu": "gu",  # Gujarati
        "kn": "kn",  # Kannada
        "ml": "ml",  # Malayalam
        "pa": "pa",  # Punjabi
        "ne": "ne",  # Nepali
        "si": "si",  # Sinhala

        # European languages
        "sv": "sv",  # Swedish
        "da": "da",  # Danish
        "fi": "fi",  # Finnish
        "no": "no",  # Norwegian
        "cs": "cs",  # Czech
        "sk": "sk",  # Slovak
        "hu": "hu",  # Hungarian
        "bg": "bg",  # Bulgarian
        "sr": "sr",  # Serbian
        "hr": "hr",  # Croatian
        "sl": "sl",  # Slovenian
        "et": "et",  # Estonian
        "lv": "lv",  # Latvian
        "lt": "lt",  # Lithuanian
        "el": "el",  # Greek
        "he": "he",  # Hebrew
        "fa": "fa",  # Persian

        # African languages
        "sw": "sw",  # Swahili
        "am": "am",  # Amharic
        "ha": "ha",  # Hausa
        "ig": "ig",  # Igbo
        "yo": "yo",  # Yoruba
        "zu": "zu",  # Zulu
        "xh": "xh",  # Xhosa
        "af": "af",  # Afrikaans

        # Other major languages
        "az": "az",  # Azerbaijani
        "ka": "ka",  # Georgian
        "kk": "kk",  # Kazakh
        "uz": "uz",  # Uzbek
        "mn": "mn",  # Mongolian

        # Additional languages
        "sq": "sq",  # Albanian
        "hy": "hy",  # Armenian
        "be": "be",  # Belarusian
        "bs": "bs",  # Bosnian
        "ca": "ca",  # Catalan
        "ceb": "ceb",  # Cebuano
        "cy": "cy",  # Welsh
        "eo": "eo",  # Esperanto
        "eu": "eu",  # Basque
        "fil": "fil",  # Filipino
        "fy": "fy",  # Frisian
        "ga": "ga",  # Irish
        "gd": "gd",  # Scottish Gaelic
        "gl": "gl",  # Galician
        "haw": "haw",  # Hawaiian
        "hmn": "hmn",  # Hmong
        "ht": "ht",  # Haitian Creole
        "is": "is",  # Icelandic
        "jv": "jv",  # Javanese
        "ku": "ku",  # Kurdish
        "ky": "ky",  # Kyrgyz
        "la": "la",  # Latin
        "lb": "lb",  # Luxembourgish
        "lg": "lg",  # Luganda
        "ln": "ln",  # Lingala
        "mg": "mg",  # Malagasy
        "mi": "mi",  # Maori
        "mk": "mk",  # Macedonian
        "mt": "mt",  # Maltese
        "ny": "ny",  # Chichewa
        "ps": "ps",  # Pashto
        "sn": "sn",  # Shona
        "so": "so",  # Somali
        "st": "st",  # Sesotho
        "su": "su",  # Sundanese
        "tg": "tg",  # Tajik
        "tk": "tk",  # Turkmen
        "ug": "ug",  # Uyghur
        "yi": "yi",  # Yiddish
    }

    # NLLB-200 supported language codes (FLORES-200 set).
    # NLLB uses the format xxx_Yyyy (ISO-639-3 language + ISO-15924 script),
    # so this maps the API's short codes to FLORES-200 codes.
    # Full list: https://github.com/facebookresearch/flores/blob/main/flores200/README.md
    self.nllb200_lang_codes = {
        "ace_arab": "ace_Arab",  # Acehnese (Arabic script)
        "ace": "ace_Latn",  # Acehnese
        "acm": "acm_Arab",  # Mesopotamian Arabic
        "acq": "acq_Arab",  # Ta'izzi-Adeni Arabic
        "aeb": "aeb_Arab",  # Tunisian Arabic
        "af": "afr_Latn",  # Afrikaans
        "ajp": "ajp_Arab",  # South Levantine Arabic
        "aka": "aka_Latn",  # Akan
        "am": "amh_Ethi",  # Amharic
        "apc": "apc_Arab",  # North Levantine Arabic
        "ar": "arb_Arab",  # Arabic (Standard)
        "ar_latn": "arb_Latn",  # Arabic (Latin script)
        "ars": "ars_Arab",  # Najdi Arabic
        "ary": "ary_Arab",  # Moroccan Arabic
        "arz": "arz_Arab",  # Egyptian Arabic
        "as": "asm_Beng",  # Assamese
        "ast": "ast_Latn",  # Asturian
        "awa": "awa_Deva",  # Awadhi
        "ayr": "ayr_Latn",  # Central Aymara
        "azb": "azb_Arab",  # South Azerbaijani
        "az": "azj_Latn",  # Azerbaijani
        "ba": "bak_Cyrl",  # Bashkir
        "bam": "bam_Latn",  # Bambara
        "ban": "ban_Latn",  # Balinese
        "be": "bel_Cyrl",  # Belarusian
        "bem": "bem_Latn",  # Bemba
        "bn": "ben_Beng",  # Bengali
        "bho": "bho_Deva",  # Bhojpuri
        "bjn_arab": "bjn_Arab",  # Banjar (Arabic script)
        "bjn": "bjn_Latn",  # Banjar
        "bo": "bod_Tibt",  # Tibetan
        "bs": "bos_Latn",  # Bosnian
        "bug": "bug_Latn",  # Buginese
        "bg": "bul_Cyrl",  # Bulgarian
        "ca": "cat_Latn",  # Catalan
        "ceb": "ceb_Latn",  # Cebuano
        "cs": "ces_Latn",  # Czech
        "cjk": "cjk_Latn",  # Chokwe
        "ckb": "ckb_Arab",  # Central Kurdish
        "crh": "crh_Latn",  # Crimean Tatar
        "cy": "cym_Latn",  # Welsh
        "da": "dan_Latn",  # Danish
        "de": "deu_Latn",  # German
        "dik": "dik_Latn",  # Southwestern Dinka
        "dyu": "dyu_Latn",  # Dyula
        "dz": "dzo_Tibt",  # Dzongkha
        "el": "ell_Grek",  # Greek
        "en": "eng_Latn",  # English
        "eo": "epo_Latn",  # Esperanto
        "et": "est_Latn",  # Estonian
        "eu": "eus_Latn",  # Basque
        "ee": "ewe_Latn",  # Ewe
        "fo": "fao_Latn",  # Faroese
        "fj": "fij_Latn",  # Fijian
        "fi": "fin_Latn",  # Finnish
        "fon": "fon_Latn",  # Fon
        "fr": "fra_Latn",  # French
        "fur": "fur_Latn",  # Friulian
        "fuv": "fuv_Latn",  # Nigerian Fulfulde
        "om": "gaz_Latn",  # West Central Oromo
        "gd": "gla_Latn",  # Scottish Gaelic
        "ga": "gle_Latn",  # Irish
        "gl": "glg_Latn",  # Galician
        "gn": "grn_Latn",  # Guarani
        "gu": "guj_Gujr",  # Gujarati
        "ht": "hat_Latn",  # Haitian Creole
        "ha": "hau_Latn",  # Hausa
        "he": "heb_Hebr",  # Hebrew
        "hi": "hin_Deva",  # Hindi
        "hne": "hne_Deva",  # Chhattisgarhi
        "hr": "hrv_Latn",  # Croatian
        "hu": "hun_Latn",  # Hungarian
        "hy": "hye_Armn",  # Armenian
        "ig": "ibo_Latn",  # Igbo
        "ilo": "ilo_Latn",  # Iloko
        "id": "ind_Latn",  # Indonesian
        "is": "isl_Latn",  # Icelandic
        "it": "ita_Latn",  # Italian
        "jv": "jav_Latn",  # Javanese
        "ja": "jpn_Jpan",  # Japanese
        "kab": "kab_Latn",  # Kabyle
        "kac": "kac_Latn",  # Jingpho
        "kam": "kam_Latn",  # Kamba
        "kn": "kan_Knda",  # Kannada
        "ks": "kas_Arab",  # Kashmiri (Arabic)
        "ks_deva": "kas_Deva",  # Kashmiri (Devanagari)
        "ka": "kat_Geor",  # Georgian
        "kk": "kaz_Cyrl",  # Kazakh
        "kbp": "kbp_Latn",  # Kabiyè
        "kea": "kea_Latn",  # Kabuverdianu
        "mn": "khk_Cyrl",  # Mongolian (Halh)
        "km": "khm_Khmr",  # Khmer
        "ki": "kik_Latn",  # Kikuyu
        "rw": "kin_Latn",  # Kinyarwanda
        "ky": "kir_Cyrl",  # Kyrgyz
        "kmb": "kmb_Latn",  # Kimbundu
        "ku": "kmr_Latn",  # Kurdish (Kurmanji)
        "knc_arab": "knc_Arab",  # Kanuri (Arabic script)
        "knc": "knc_Latn",  # Kanuri
        "kg": "kon_Latn",  # Kongo
        "ko": "kor_Hang",  # Korean
        "lo": "lao_Laoo",  # Lao
        "lij": "lij_Latn",  # Ligurian
        "li": "lim_Latn",  # Limburgish
        "ln": "lin_Latn",  # Lingala
        "lt": "lit_Latn",  # Lithuanian
        "lmo": "lmo_Latn",  # Lombard
        "ltg": "ltg_Latn",  # Latgalian
        "lb": "ltz_Latn",  # Luxembourgish
        "lua": "lua_Latn",  # Luba-Kasai
        "lg": "lug_Latn",  # Luganda
        "luo": "luo_Latn",  # Luo
        "lus": "lus_Latn",  # Mizo
        "lv": "lvs_Latn",  # Latvian
        "mag": "mag_Deva",  # Magahi
        "mai": "mai_Deva",  # Maithili
        "ml": "mal_Mlym",  # Malayalam
        "mr": "mar_Deva",  # Marathi
        "min_arab": "min_Arab",  # Minangkabau (Arabic)
        "min": "min_Latn",  # Minangkabau
        "mk": "mkd_Cyrl",  # Macedonian
        "mt": "mlt_Latn",  # Maltese
        "mni": "mni_Beng",  # Meitei
        "mos": "mos_Latn",  # Mossi
        "mi": "mri_Latn",  # Maori
        "my": "mya_Mymr",  # Burmese
        "nl": "nld_Latn",  # Dutch
        "nn": "nno_Latn",  # Norwegian Nynorsk
        "nb": "nob_Latn",  # Norwegian Bokmål
        "ne": "npi_Deva",  # Nepali
        "nso": "nso_Latn",  # Northern Sotho
        "nus": "nus_Latn",  # Nuer
        "ny": "nya_Latn",  # Chichewa
        "oc": "oci_Latn",  # Occitan
        "or": "ory_Orya",  # Odia
        "pag": "pag_Latn",  # Pangasinan
        "pa": "pan_Guru",  # Punjabi
        "pap": "pap_Latn",  # Papiamento
        "ps": "pbt_Arab",  # Pashto (Southern)
        "fa": "pes_Arab",  # Persian
        "mg": "plt_Latn",  # Malagasy
        "pl": "pol_Latn",  # Polish
        "pt": "por_Latn",  # Portuguese
        "prs": "prs_Arab",  # Dari
        "qu": "quy_Latn",  # Ayacucho Quechua
        "ro": "ron_Latn",  # Romanian
        "rn": "run_Latn",  # Rundi
        "ru": "rus_Cyrl",  # Russian
        "sg": "sag_Latn",  # Sango
        "sa": "san_Deva",  # Sanskrit
        "sat": "sat_Olck",  # Santali
        "scn": "scn_Latn",  # Sicilian
        "shn": "shn_Mymr",  # Shan
        "si": "sin_Sinh",  # Sinhala
        "sk": "slk_Latn",  # Slovak
        "sl": "slv_Latn",  # Slovenian
        "sm": "smo_Latn",  # Samoan
        "sn": "sna_Latn",  # Shona
        "sd": "snd_Arab",  # Sindhi
        "so": "som_Latn",  # Somali
        "st": "sot_Latn",  # Sesotho
        "es": "spa_Latn",  # Spanish
        "sq": "als_Latn",  # Albanian (Tosk)
        "sc": "srd_Latn",  # Sardinian
        "sr": "srp_Cyrl",  # Serbian
        "ss": "ssw_Latn",  # Swazi
        "su": "sun_Latn",  # Sundanese
        "sv": "swe_Latn",  # Swedish
        "sw": "swh_Latn",  # Swahili
        "szl": "szl_Latn",  # Silesian
        "ta": "tam_Taml",  # Tamil
        "taq": "taq_Latn",  # Tamasheq (Latin)
        "taq_tfng": "taq_Tfng",  # Tamasheq (Tifinagh)
        "tt": "tat_Cyrl",  # Tatar
        "te": "tel_Telu",  # Telugu
        "tg": "tgk_Cyrl",  # Tajik
        "tl": "tgl_Latn",  # Tagalog
        "th": "tha_Thai",  # Thai
        "ti": "tir_Ethi",  # Tigrinya
        "tpi": "tpi_Latn",  # Tok Pisin
        "tn": "tsn_Latn",  # Tswana
        "ts": "tso_Latn",  # Tsonga
        "tk": "tuk_Latn",  # Turkmen
        "tum": "tum_Latn",  # Tumbuka
        "tr": "tur_Latn",  # Turkish
        "tw": "twi_Latn",  # Twi
        "tzm": "tzm_Tfng",  # Central Atlas Tamazight
        "ug": "uig_Arab",  # Uyghur
        "uk": "ukr_Cyrl",  # Ukrainian
        "umb": "umb_Latn",  # Umbundu
        "ur": "urd_Arab",  # Urdu
        "uz": "uzn_Latn",  # Uzbek (Northern)
        "vec": "vec_Latn",  # Venetian
        "vi": "vie_Latn",  # Vietnamese
        "war": "war_Latn",  # Waray
        "wo": "wol_Latn",  # Wolof
        "xh": "xho_Latn",  # Xhosa
        "yi": "ydd_Hebr",  # Eastern Yiddish
        "yo": "yor_Latn",  # Yoruba
        "yue": "yue_Hant",  # Cantonese
        "zh": "zho_Hans",  # Chinese (Simplified)
        "zh_hant": "zho_Hant",  # Chinese (Traditional)
        "ms": "zsm_Latn",  # Malay (Standard)
        "zu": "zul_Latn",  # Zulu
    }
|
|
|
|
def get_supported_languages(self, model_type: str = "m2m100") -> Dict[str, str]:
|
|
"""Get supported language codes for the specified model"""
|
|
if model_type == "m2m100":
|
|
return self.m2m100_lang_codes
|
|
elif model_type == "nllb200":
|
|
return self.nllb200_lang_codes
|
|
else:
|
|
raise ValueError(f"Unknown model type: {model_type}")
|
|
|
|
def _get_model_info(self, source_lang: str, target_lang: str, model_type: str) -> tuple[str, str, str, str]:
|
|
"""Get the model name and language codes for translation"""
|
|
if model_type == "m2m100":
|
|
model_name = "facebook/m2m100_418M"
|
|
lang_codes = self.m2m100_lang_codes
|
|
elif model_type == "nllb200":
|
|
model_name = "facebook/nllb-200-distilled-600M"
|
|
lang_codes = self.nllb200_lang_codes
|
|
else:
|
|
raise ValueError(f"Unknown model type: {model_type}")
|
|
|
|
src_code = lang_codes.get(source_lang)
|
|
tgt_code = lang_codes.get(target_lang)
|
|
|
|
if not src_code:
|
|
raise ValueError(f"Source language '{source_lang}' not supported by {model_type}. Use /api/supported-languages?model={model_type} to see available languages.")
|
|
if not tgt_code:
|
|
raise ValueError(f"Target language '{target_lang}' not supported by {model_type}. Use /api/supported-languages?model={model_type} to see available languages.")
|
|
|
|
return model_name, src_code, tgt_code, model_type
|
|
|
|
def load_model(self, source_lang: str, target_lang: str, model_type: str = "m2m100") -> None:
    """Download/load the model for this language pair, if not already cached.

    Args:
        source_lang: API code of the source language.
        target_lang: API code of the target language.
        model_type: 'm2m100' or 'nllb200'.

    Raises:
        ValueError: If the model type or either language is unsupported
            (raised by _get_model_info before any loading starts).
        Exception: Re-raised after logging if loading the weights fails.
    """
    # Also validates model_type and both languages as a side effect.
    model_name, _, _, _ = self._get_model_info(source_lang, target_lang, model_type)

    # One entry per model name; both directions of a pair share the model.
    if model_name in self.models:
        logger.info(f"Model {model_name} already loaded")
        return

    try:
        logger.info(f"Loading model: {model_name}")

        if model_type == "m2m100":
            # M2M100 has dedicated tokenizer/model classes in transformers.
            tokenizer = M2M100Tokenizer.from_pretrained(
                model_name,
                cache_dir=settings.model_cache_dir
            )
            model = M2M100ForConditionalGeneration.from_pretrained(
                model_name,
                cache_dir=settings.model_cache_dir
            ).to(self.device)
        elif model_type == "nllb200":
            # NLLB-200 is loaded via the Auto* factories.
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=settings.model_cache_dir
            )
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                cache_dir=settings.model_cache_dir
            ).to(self.device)

        self.models[model_name] = {
            "tokenizer": tokenizer,
            "model": model,
            "type": model_type
        }
        logger.info(f"Successfully loaded model: {model_name}")

    except Exception as e:
        logger.error(f"Error loading model {model_name}: {str(e)}")
        raise
|
|
|
|
def translate(self, text: str, source_lang: str, target_lang: str, model_type: str = "m2m100") -> tuple[str, str]:
    """
    Translate text from source language to target language.

    Args:
        text: Text to translate
        source_lang: Source language code
        target_lang: Target language code
        model_type: Model to use ('m2m100' or 'nllb200')

    Returns:
        Tuple of (translated_text, model_name)

    Raises:
        ValueError: If the model type or either language is unsupported.
        Exception: Re-raised after logging if tokenization/generation fails.
    """
    # Validates model_type and both languages; raises ValueError early.
    model_name, src_code, tgt_code, _ = self._get_model_info(source_lang, target_lang, model_type)

    # Load model if not already loaded
    if model_name not in self.models:
        self.load_model(source_lang, target_lang, model_type)

    try:
        tokenizer = self.models[model_name]["tokenizer"]
        model = self.models[model_name]["model"]

        # Both model families take the source language on the tokenizer.
        tokenizer.src_lang = src_code

        # Tokenize input (shared by both models; previously duplicated per branch).
        inputs = tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=settings.max_length
        ).to(self.device)

        # The models differ only in how the forced target-language BOS token
        # id is looked up: M2M100 exposes get_lang_id(), while for NLLB the
        # FLORES-200 code (e.g. 'eng_Latn') is itself a vocabulary token.
        if model_type == "m2m100":
            forced_bos_token_id = tokenizer.get_lang_id(tgt_code)
        else:  # nllb200 — _get_model_info guarantees one of the two types
            forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_code)

        with torch.no_grad():
            translated = model.generate(
                **inputs,
                forced_bos_token_id=forced_bos_token_id,
                max_length=settings.max_length
            )

        # Decode output (single input => take the first decoded sequence)
        translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]

        return translated_text, model_name

    except Exception as e:
        logger.error(f"Translation error: {str(e)}")
        raise
|
|
|
|
def preload_all_models(self) -> None:
|
|
"""Preload commonly used translation models"""
|
|
# Only preload M2M100 by default (commercial-friendly)
|
|
language_pairs = [
|
|
("ms", "en", "m2m100"),
|
|
("en", "ms", "m2m100")
|
|
]
|
|
|
|
for source, target, model_type in language_pairs:
|
|
try:
|
|
self.load_model(source, target, model_type)
|
|
except Exception as e:
|
|
logger.warning(f"Could not preload model for {source}->{target} ({model_type}): {str(e)}")
|
|
|
|
def is_ready(self) -> bool:
|
|
"""Check if at least one model is loaded"""
|
|
return len(self.models) > 0
|
|
|
|
|
|
# Global translator instance shared by the app.
# NOTE(review): constructed at import time — cheap today because __init__
# loads no model weights, but any heavy work added to __init__ would run
# on import. Confirm this matches how app/main.py wires it up.
translator = TranslationService()
|