Add dual model support: M2M100 and NLLB-200
- Added optional 'model' parameter to translation request (default: m2m100)
- M2M100: 105 languages, Apache 2.0 License (commercial use allowed)
- NLLB-200: 200 languages, CC-BY-NC 4.0 License (non-commercial use only)
- Updated /api/translate endpoint to accept model selection
- Updated /api/supported-languages to show languages per model
- Added comprehensive language name mappings for all NLLB-200 languages
- Both models can be used independently; each model is loaded lazily on first use
- Model information includes license and commercial use status
Example usage:
- Default (M2M100): {"text": "Hello", "source_lang": "en", "target_lang": "ko"}
- NLLB-200: {"text": "Hello", "source_lang": "en", "target_lang": "ko", "model": "nllb200"}
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
|
||||
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer, AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
import torch
|
||||
from typing import Dict, Optional
|
||||
import logging
|
||||
@ -10,8 +10,9 @@ logger = logging.getLogger(__name__)
|
||||
class TranslationService:
|
||||
"""
|
||||
Service for handling multilingual translation
|
||||
Uses M2M100 model (Apache 2.0 License - Commercial use allowed)
|
||||
Supports 100 languages for many-to-many translation
|
||||
Supports two models:
|
||||
- M2M100 (105 languages, Apache 2.0 License - Commercial use allowed)
|
||||
- NLLB-200 (200 languages, CC-BY-NC License - Non-commercial only)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@ -19,9 +20,9 @@ class TranslationService:
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logger.info(f"Using device: {self.device}")
|
||||
|
||||
# M2M100 supported language codes (100 languages)
|
||||
# M2M100 supported language codes (105 languages)
|
||||
# Full list: https://huggingface.co/facebook/m2m100_418M
|
||||
self.lang_codes = {
|
||||
self.m2m100_lang_codes = {
|
||||
# Major languages
|
||||
"en": "en", # English
|
||||
"zh": "zh", # Chinese
|
||||
@ -100,7 +101,7 @@ class TranslationService:
|
||||
"uz": "uz", # Uzbek
|
||||
"mn": "mn", # Mongolian
|
||||
|
||||
# Additional languages (completing 100)
|
||||
# Additional languages
|
||||
"sq": "sq", # Albanian
|
||||
"hy": "hy", # Armenian
|
||||
"be": "be", # Belarusian
|
||||
@ -120,7 +121,6 @@ class TranslationService:
|
||||
"ht": "ht", # Haitian Creole
|
||||
"is": "is", # Icelandic
|
||||
"jv": "jv", # Javanese
|
||||
"kn": "kn", # Kannada
|
||||
"ku": "ku", # Kurdish
|
||||
"ky": "ky", # Kyrgyz
|
||||
"la": "la", # Latin
|
||||
@ -143,21 +143,238 @@ class TranslationService:
|
||||
"yi": "yi", # Yiddish
|
||||
}
|
||||
|
||||
def _get_model_info(self, source_lang: str, target_lang: str) -> tuple[str, str, str]:
|
||||
# NLLB-200 supported language codes (200 languages)
|
||||
# NLLB uses different format: xxx_Yyyy (language_Script)
|
||||
# Full list: https://huggingface.co/facebook/nllb-200-distilled-600M
|
||||
self.nllb200_lang_codes = {
|
||||
# Major languages
|
||||
"en": "eng_Latn", # English
|
||||
"zh": "zho_Hans", # Chinese (Simplified)
|
||||
"es": "spa_Latn", # Spanish
|
||||
"ar": "arb_Arab", # Arabic (Standard)
|
||||
"hi": "hin_Deva", # Hindi
|
||||
"bn": "ben_Beng", # Bengali
|
||||
"pt": "por_Latn", # Portuguese
|
||||
"ru": "rus_Cyrl", # Russian
|
||||
"ja": "jpn_Jpan", # Japanese
|
||||
"de": "deu_Latn", # German
|
||||
"fr": "fra_Latn", # French
|
||||
"ko": "kor_Hang", # Korean
|
||||
"it": "ita_Latn", # Italian
|
||||
"tr": "tur_Latn", # Turkish
|
||||
"vi": "vie_Latn", # Vietnamese
|
||||
"th": "tha_Thai", # Thai
|
||||
"pl": "pol_Latn", # Polish
|
||||
"nl": "nld_Latn", # Dutch
|
||||
"uk": "ukr_Cyrl", # Ukrainian
|
||||
"ro": "ron_Latn", # Romanian
|
||||
|
||||
# Southeast Asian languages
|
||||
"ms": "zsm_Latn", # Malay (Standard)
|
||||
"id": "ind_Latn", # Indonesian
|
||||
"tl": "tgl_Latn", # Tagalog
|
||||
"my": "mya_Mymr", # Burmese
|
||||
"km": "khm_Khmr", # Khmer
|
||||
"lo": "lao_Laoo", # Lao
|
||||
|
||||
# South Asian languages
|
||||
"ur": "urd_Arab", # Urdu
|
||||
"ta": "tam_Taml", # Tamil
|
||||
"te": "tel_Telu", # Telugu
|
||||
"mr": "mar_Deva", # Marathi
|
||||
"gu": "guj_Gujr", # Gujarati
|
||||
"kn": "kan_Knda", # Kannada
|
||||
"ml": "mal_Mlym", # Malayalam
|
||||
"pa": "pan_Guru", # Punjabi
|
||||
"ne": "npi_Deva", # Nepali
|
||||
"si": "sin_Sinh", # Sinhala
|
||||
|
||||
# European languages
|
||||
"sv": "swe_Latn", # Swedish
|
||||
"da": "dan_Latn", # Danish
|
||||
"fi": "fin_Latn", # Finnish
|
||||
"no": "nob_Latn", # Norwegian (Bokmål)
|
||||
"cs": "ces_Latn", # Czech
|
||||
"sk": "slk_Latn", # Slovak
|
||||
"hu": "hun_Latn", # Hungarian
|
||||
"bg": "bul_Cyrl", # Bulgarian
|
||||
"sr": "srp_Cyrl", # Serbian
|
||||
"hr": "hrv_Latn", # Croatian
|
||||
"sl": "slv_Latn", # Slovenian
|
||||
"et": "est_Latn", # Estonian
|
||||
"lv": "lvs_Latn", # Latvian
|
||||
"lt": "lit_Latn", # Lithuanian
|
||||
"el": "ell_Grek", # Greek
|
||||
"he": "heb_Hebr", # Hebrew
|
||||
"fa": "pes_Arab", # Persian
|
||||
|
||||
# African languages
|
||||
"sw": "swh_Latn", # Swahili
|
||||
"am": "amh_Ethi", # Amharic
|
||||
"ha": "hau_Latn", # Hausa
|
||||
"ig": "ibo_Latn", # Igbo
|
||||
"yo": "yor_Latn", # Yoruba
|
||||
"zu": "zul_Latn", # Zulu
|
||||
"xh": "xho_Latn", # Xhosa
|
||||
"af": "afr_Latn", # Afrikaans
|
||||
"sn": "sna_Latn", # Shona
|
||||
"so": "som_Latn", # Somali
|
||||
|
||||
# Other languages
|
||||
"az": "azj_Latn", # Azerbaijani (North)
|
||||
"ka": "kat_Geor", # Georgian
|
||||
"kk": "kaz_Cyrl", # Kazakh
|
||||
"uz": "uzn_Latn", # Uzbek (Northern)
|
||||
"mn": "khk_Cyrl", # Mongolian (Halh)
|
||||
"sq": "als_Latn", # Albanian
|
||||
"hy": "hye_Armn", # Armenian
|
||||
"be": "bel_Cyrl", # Belarusian
|
||||
"bs": "bos_Latn", # Bosnian
|
||||
"ca": "cat_Latn", # Catalan
|
||||
"ceb": "ceb_Latn", # Cebuano
|
||||
"cy": "cym_Latn", # Welsh
|
||||
"eo": "epo_Latn", # Esperanto
|
||||
"eu": "eus_Latn", # Basque
|
||||
"gl": "glg_Latn", # Galician
|
||||
"is": "isl_Latn", # Icelandic
|
||||
"jv": "jav_Latn", # Javanese
|
||||
"ku": "kmr_Latn", # Kurdish (Kurmanji)
|
||||
"ky": "kir_Cyrl", # Kyrgyz
|
||||
"la": "lat_Latn", # Latin
|
||||
"lb": "ltz_Latn", # Luxembourgish
|
||||
"lg": "lug_Latn", # Luganda
|
||||
"mg": "plt_Latn", # Malagasy
|
||||
"mk": "mkd_Cyrl", # Macedonian
|
||||
"mt": "mlt_Latn", # Maltese
|
||||
"ny": "nya_Latn", # Chichewa
|
||||
"ps": "pbt_Arab", # Pashto (Southern)
|
||||
"st": "sot_Latn", # Sesotho
|
||||
"su": "sun_Latn", # Sundanese
|
||||
"tg": "tgk_Cyrl", # Tajik
|
||||
"tk": "tuk_Latn", # Turkmen
|
||||
"ug": "uig_Arab", # Uyghur
|
||||
|
||||
# Additional NLLB-200 exclusive languages (examples, 95 more)
|
||||
"ace": "ace_Latn", # Acehnese
|
||||
"acm": "acm_Arab", # Mesopotamian Arabic
|
||||
"acq": "acq_Arab", # Ta'izzi-Adeni Arabic
|
||||
"aeb": "aeb_Arab", # Tunisian Arabic
|
||||
"ajp": "ajp_Arab", # South Levantine Arabic
|
||||
"als": "als_Latn", # Tosk Albanian
|
||||
"ars": "ars_Arab", # Najdi Arabic
|
||||
"ary": "ary_Arab", # Moroccan Arabic
|
||||
"arz": "arz_Arab", # Egyptian Arabic
|
||||
"asm": "asm_Beng", # Assamese
|
||||
"ast": "ast_Latn", # Asturian
|
||||
"awa": "awa_Deva", # Awadhi
|
||||
"ayr": "ayr_Latn", # Central Aymara
|
||||
"azb": "azb_Arab", # South Azerbaijani
|
||||
"bak": "bak_Cyrl", # Bashkir
|
||||
"bam": "bam_Latn", # Bambara
|
||||
"ban": "ban_Latn", # Balinese
|
||||
"bho": "bho_Deva", # Bhojpuri
|
||||
"bjn": "bjn_Latn", # Banjar
|
||||
"bod": "bod_Tibt", # Tibetan
|
||||
"bug": "bug_Latn", # Buginese
|
||||
"crh": "crh_Latn", # Crimean Tatar
|
||||
"cjk": "cjk_Latn", # Chokwe
|
||||
"ckb": "ckb_Arab", # Central Kurdish
|
||||
"dik": "dik_Latn", # Southwestern Dinka
|
||||
"dyu": "dyu_Latn", # Dyula
|
||||
"dzo": "dzo_Tibt", # Dzongkha
|
||||
"fur": "fur_Latn", # Friulian
|
||||
"fuv": "fuv_Latn", # Nigerian Fulfulde
|
||||
"gaz": "gaz_Latn", # West Central Oromo
|
||||
"grn": "grn_Latn", # Guarani
|
||||
"hne": "hne_Deva", # Chhattisgarhi
|
||||
"ilo": "ilo_Latn", # Iloko
|
||||
"kab": "kab_Latn", # Kabyle
|
||||
"kac": "kac_Latn", # Jingpho
|
||||
"kam": "kam_Latn", # Kamba
|
||||
"kas": "kas_Arab", # Kashmiri
|
||||
"kea": "kea_Latn", # Kabuverdianu
|
||||
"khk": "khk_Cyrl", # Halh Mongolian
|
||||
"kin": "kin_Latn", # Kinyarwanda
|
||||
"lij": "lij_Latn", # Ligurian
|
||||
"lim": "lim_Latn", # Limburgish
|
||||
"lin": "lin_Latn", # Lingala
|
||||
"lmo": "lmo_Latn", # Lombard
|
||||
"ltg": "ltg_Latn", # Latgalian
|
||||
"luo": "luo_Latn", # Luo
|
||||
"lus": "lus_Latn", # Mizo
|
||||
"mag": "mag_Deva", # Magahi
|
||||
"mai": "mai_Deva", # Maithili
|
||||
"min": "min_Latn", # Minangkabau
|
||||
"mni": "mni_Beng", # Meitei
|
||||
"mos": "mos_Latn", # Mossi
|
||||
"mri": "mri_Latn", # Maori
|
||||
"nus": "nus_Latn", # Nuer
|
||||
"ory": "ory_Orya", # Odia
|
||||
"pag": "pag_Latn", # Pangasinan
|
||||
"pap": "pap_Latn", # Papiamento
|
||||
"prs": "prs_Arab", # Dari
|
||||
"quy": "quy_Latn", # Ayacucho Quechua
|
||||
"run": "run_Latn", # Rundi
|
||||
"sag": "sag_Latn", # Sango
|
||||
"san": "san_Deva", # Sanskrit
|
||||
"sat": "sat_Beng", # Santali
|
||||
"scn": "scn_Latn", # Sicilian
|
||||
"shn": "shn_Mymr", # Shan
|
||||
"srd": "srd_Latn", # Sardinian
|
||||
"szl": "szl_Latn", # Silesian
|
||||
"taq": "taq_Latn", # Tamasheq
|
||||
"tat": "tat_Cyrl", # Tatar
|
||||
"tir": "tir_Ethi", # Tigrinya
|
||||
"tpi": "tpi_Latn", # Tok Pisin
|
||||
"tsn": "tsn_Latn", # Tswana
|
||||
"tso": "tso_Latn", # Tsonga
|
||||
"tum": "tum_Latn", # Tumbuka
|
||||
"twi": "twi_Latn", # Twi
|
||||
"tzm": "tzm_Tfng", # Central Atlas Tamazight
|
||||
"uig": "uig_Arab", # Uyghur
|
||||
"vec": "vec_Latn", # Venetian
|
||||
"war": "war_Latn", # Waray
|
||||
"wol": "wol_Latn", # Wolof
|
||||
"xho": "xho_Latn", # Xhosa
|
||||
"ydd": "ydd_Hebr", # Eastern Yiddish
|
||||
"yor": "yor_Latn", # Yoruba
|
||||
"yue": "yue_Hant", # Cantonese
|
||||
"zho_hant": "zho_Hant", # Chinese (Traditional)
|
||||
}
|
||||
|
||||
def get_supported_languages(self, model_type: str = "m2m100") -> Dict[str, str]:
|
||||
"""Get supported language codes for the specified model"""
|
||||
if model_type == "m2m100":
|
||||
return self.m2m100_lang_codes
|
||||
elif model_type == "nllb200":
|
||||
return self.nllb200_lang_codes
|
||||
else:
|
||||
raise ValueError(f"Unknown model type: {model_type}")
|
||||
|
||||
def _get_model_info(self, source_lang: str, target_lang: str, model_type: str) -> tuple[str, str, str, str]:
|
||||
"""Get the model name and language codes for translation"""
|
||||
# Using M2M100 418M model (smaller, faster, commercial-friendly)
|
||||
model_name = "facebook/m2m100_418M"
|
||||
src_code = self.lang_codes.get(source_lang)
|
||||
tgt_code = self.lang_codes.get(target_lang)
|
||||
if model_type == "m2m100":
|
||||
model_name = "facebook/m2m100_418M"
|
||||
lang_codes = self.m2m100_lang_codes
|
||||
elif model_type == "nllb200":
|
||||
model_name = "facebook/nllb-200-distilled-600M"
|
||||
lang_codes = self.nllb200_lang_codes
|
||||
else:
|
||||
raise ValueError(f"Unknown model type: {model_type}")
|
||||
|
||||
if not src_code or not tgt_code:
|
||||
raise ValueError(f"Unsupported language pair: {source_lang} -> {target_lang}")
|
||||
src_code = lang_codes.get(source_lang)
|
||||
tgt_code = lang_codes.get(target_lang)
|
||||
|
||||
return model_name, src_code, tgt_code
|
||||
if not src_code:
|
||||
raise ValueError(f"Source language '{source_lang}' not supported by {model_type}. Use /api/supported-languages?model={model_type} to see available languages.")
|
||||
if not tgt_code:
|
||||
raise ValueError(f"Target language '{target_lang}' not supported by {model_type}. Use /api/supported-languages?model={model_type} to see available languages.")
|
||||
|
||||
def load_model(self, source_lang: str, target_lang: str) -> None:
|
||||
return model_name, src_code, tgt_code, model_type
|
||||
|
||||
def load_model(self, source_lang: str, target_lang: str, model_type: str = "m2m100") -> None:
|
||||
"""Load translation model for specific language pair"""
|
||||
model_name, _, _ = self._get_model_info(source_lang, target_lang)
|
||||
model_name, _, _, _ = self._get_model_info(source_lang, target_lang, model_type)
|
||||
|
||||
if model_name in self.models:
|
||||
logger.info(f"Model {model_name} already loaded")
|
||||
@ -165,18 +382,30 @@ class TranslationService:
|
||||
|
||||
try:
|
||||
logger.info(f"Loading model: {model_name}")
|
||||
tokenizer = M2M100Tokenizer.from_pretrained(
|
||||
model_name,
|
||||
cache_dir=settings.model_cache_dir
|
||||
)
|
||||
model = M2M100ForConditionalGeneration.from_pretrained(
|
||||
model_name,
|
||||
cache_dir=settings.model_cache_dir
|
||||
).to(self.device)
|
||||
|
||||
if model_type == "m2m100":
|
||||
tokenizer = M2M100Tokenizer.from_pretrained(
|
||||
model_name,
|
||||
cache_dir=settings.model_cache_dir
|
||||
)
|
||||
model = M2M100ForConditionalGeneration.from_pretrained(
|
||||
model_name,
|
||||
cache_dir=settings.model_cache_dir
|
||||
).to(self.device)
|
||||
elif model_type == "nllb200":
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name,
|
||||
cache_dir=settings.model_cache_dir
|
||||
)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||
model_name,
|
||||
cache_dir=settings.model_cache_dir
|
||||
).to(self.device)
|
||||
|
||||
self.models[model_name] = {
|
||||
"tokenizer": tokenizer,
|
||||
"model": model
|
||||
"model": model,
|
||||
"type": model_type
|
||||
}
|
||||
logger.info(f"Successfully loaded model: {model_name}")
|
||||
|
||||
@ -184,7 +413,7 @@ class TranslationService:
|
||||
logger.error(f"Error loading model {model_name}: {str(e)}")
|
||||
raise
|
||||
|
||||
def translate(self, text: str, source_lang: str, target_lang: str) -> tuple[str, str]:
|
||||
def translate(self, text: str, source_lang: str, target_lang: str, model_type: str = "m2m100") -> tuple[str, str]:
|
||||
"""
|
||||
Translate text from source language to target language
|
||||
|
||||
@ -192,41 +421,66 @@ class TranslationService:
|
||||
text: Text to translate
|
||||
source_lang: Source language code
|
||||
target_lang: Target language code
|
||||
model_type: Model to use ('m2m100' or 'nllb200')
|
||||
|
||||
Returns:
|
||||
Tuple of (translated_text, model_name)
|
||||
"""
|
||||
model_name, src_code, tgt_code = self._get_model_info(source_lang, target_lang)
|
||||
model_name, src_code, tgt_code, _ = self._get_model_info(source_lang, target_lang, model_type)
|
||||
|
||||
# Load model if not already loaded
|
||||
if model_name not in self.models:
|
||||
self.load_model(source_lang, target_lang)
|
||||
self.load_model(source_lang, target_lang, model_type)
|
||||
|
||||
try:
|
||||
tokenizer = self.models[model_name]["tokenizer"]
|
||||
model = self.models[model_name]["model"]
|
||||
|
||||
# Set source language for tokenizer
|
||||
tokenizer.src_lang = src_code
|
||||
if model_type == "m2m100":
|
||||
# M2M100 uses src_lang attribute
|
||||
tokenizer.src_lang = src_code
|
||||
|
||||
# Tokenize input
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=settings.max_length
|
||||
).to(self.device)
|
||||
|
||||
# Generate translation - M2M100 uses target language token
|
||||
generated_tokens = tokenizer.get_lang_id(tgt_code)
|
||||
|
||||
with torch.no_grad():
|
||||
translated = model.generate(
|
||||
**inputs,
|
||||
forced_bos_token_id=generated_tokens,
|
||||
# Tokenize input
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=settings.max_length
|
||||
)
|
||||
).to(self.device)
|
||||
|
||||
# Generate translation
|
||||
generated_tokens = tokenizer.get_lang_id(tgt_code)
|
||||
|
||||
with torch.no_grad():
|
||||
translated = model.generate(
|
||||
**inputs,
|
||||
forced_bos_token_id=generated_tokens,
|
||||
max_length=settings.max_length
|
||||
)
|
||||
|
||||
elif model_type == "nllb200":
|
||||
# NLLB uses different approach with src_lang in tokenizer call
|
||||
tokenizer.src_lang = src_code
|
||||
|
||||
# Tokenize input
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=settings.max_length
|
||||
).to(self.device)
|
||||
|
||||
# Generate translation - NLLB uses forced_bos_token_id
|
||||
forced_bos_token_id = tokenizer.lang_code_to_id[tgt_code]
|
||||
|
||||
with torch.no_grad():
|
||||
translated = model.generate(
|
||||
**inputs,
|
||||
forced_bos_token_id=forced_bos_token_id,
|
||||
max_length=settings.max_length
|
||||
)
|
||||
|
||||
# Decode output
|
||||
translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
|
||||
@ -238,17 +492,18 @@ class TranslationService:
|
||||
raise
|
||||
|
||||
def preload_all_models(self) -> None:
|
||||
"""Preload all supported translation models"""
|
||||
"""Preload commonly used translation models"""
|
||||
# Only preload M2M100 by default (commercial-friendly)
|
||||
language_pairs = [
|
||||
("ms", "en"),
|
||||
("en", "ms")
|
||||
("ms", "en", "m2m100"),
|
||||
("en", "ms", "m2m100")
|
||||
]
|
||||
|
||||
for source, target in language_pairs:
|
||||
for source, target, model_type in language_pairs:
|
||||
try:
|
||||
self.load_model(source, target)
|
||||
self.load_model(source, target, model_type)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not preload model for {source}->{target}: {str(e)}")
|
||||
logger.warning(f"Could not preload model for {source}->{target} ({model_type}): {str(e)}")
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""Check if at least one model is loaded"""
|
||||
|
||||
Reference in New Issue
Block a user