- FastAPI 백엔드 (audio-studio-api) - Next.js 프론트엔드 (audio-studio-ui) - Qwen3-TTS 엔진 (audio-studio-tts) - MusicGen 서비스 (audio-studio-musicgen) - Docker Compose 개발/운영 환경 Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
200 lines
5.6 KiB
Python
200 lines
5.6 KiB
Python
"""MusicGen 서비스
|
|
|
|
Meta AudioCraft MusicGen을 사용한 AI 음악 생성
|
|
"""
|
|
|
|
import os
|
|
import io
|
|
import logging
|
|
from typing import Optional
|
|
|
|
import torch
|
|
import soundfile as sf
|
|
import numpy as np
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MusicGenService:
|
|
"""MusicGen 음악 생성 서비스"""
|
|
|
|
def __init__(self):
|
|
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
self.model_name = os.getenv("MODEL_NAME", "facebook/musicgen-medium")
|
|
self.model = None
|
|
self._initialized = False
|
|
|
|
async def initialize(self):
|
|
"""모델 초기화 (서버 시작 시 호출)"""
|
|
if self._initialized:
|
|
return
|
|
|
|
logger.info(f"MusicGen 모델 로딩 중: {self.model_name}")
|
|
|
|
try:
|
|
from audiocraft.models import MusicGen
|
|
|
|
self.model = MusicGen.get_pretrained(self.model_name)
|
|
self.model.set_generation_params(
|
|
use_sampling=True,
|
|
top_k=250,
|
|
duration=30, # 기본 30초
|
|
)
|
|
|
|
self._initialized = True
|
|
logger.info(f"MusicGen 모델 로드 완료 (device: {self.device})")
|
|
|
|
except Exception as e:
|
|
logger.error(f"MusicGen 모델 로드 실패: {e}")
|
|
raise
|
|
|
|
async def generate(
|
|
self,
|
|
prompt: str,
|
|
duration: int = 30,
|
|
top_k: int = 250,
|
|
temperature: float = 1.0,
|
|
) -> bytes:
|
|
"""텍스트 프롬프트로 음악 생성
|
|
|
|
Args:
|
|
prompt: 음악 설명 (예: "upbeat electronic music for gaming")
|
|
duration: 생성 길이 (초, 최대 30초)
|
|
top_k: top-k 샘플링 파라미터
|
|
temperature: 생성 다양성 (높을수록 다양)
|
|
|
|
Returns:
|
|
WAV 바이트
|
|
"""
|
|
if not self._initialized:
|
|
await self.initialize()
|
|
|
|
# 파라미터 제한
|
|
duration = min(max(duration, 5), 30)
|
|
|
|
logger.info(f"음악 생성 시작: prompt='{prompt[:50]}...', duration={duration}s")
|
|
|
|
try:
|
|
# 생성 파라미터 설정
|
|
self.model.set_generation_params(
|
|
use_sampling=True,
|
|
top_k=top_k,
|
|
top_p=0.0,
|
|
temperature=temperature,
|
|
duration=duration,
|
|
)
|
|
|
|
# 생성
|
|
wav = self.model.generate([prompt])
|
|
|
|
# 결과 처리 (첫 번째 결과만)
|
|
audio_data = wav[0].cpu().numpy()
|
|
|
|
# 스테레오인 경우 모노로 변환
|
|
if len(audio_data.shape) > 1:
|
|
audio_data = audio_data.mean(axis=0)
|
|
|
|
# WAV 바이트로 변환
|
|
buffer = io.BytesIO()
|
|
sf.write(buffer, audio_data, 32000, format='WAV') # MusicGen은 32kHz
|
|
buffer.seek(0)
|
|
|
|
logger.info(f"음악 생성 완료: {duration}초")
|
|
return buffer.read()
|
|
|
|
except Exception as e:
|
|
logger.error(f"음악 생성 실패: {e}")
|
|
raise
|
|
|
|
async def generate_with_melody(
|
|
self,
|
|
prompt: str,
|
|
melody_audio: bytes,
|
|
duration: int = 30,
|
|
) -> bytes:
|
|
"""멜로디 조건부 음악 생성
|
|
|
|
Args:
|
|
prompt: 음악 설명
|
|
melody_audio: 참조 멜로디 오디오 (WAV)
|
|
duration: 생성 길이
|
|
|
|
Returns:
|
|
WAV 바이트
|
|
"""
|
|
if not self._initialized:
|
|
await self.initialize()
|
|
|
|
duration = min(max(duration, 5), 30)
|
|
|
|
logger.info(f"멜로디 기반 음악 생성: prompt='{prompt[:50]}...', duration={duration}s")
|
|
|
|
try:
|
|
# 멜로디 로드
|
|
import torchaudio
|
|
|
|
buffer = io.BytesIO(melody_audio)
|
|
melody, sr = torchaudio.load(buffer)
|
|
|
|
# 리샘플링 (32kHz로)
|
|
if sr != 32000:
|
|
melody = torchaudio.functional.resample(melody, sr, 32000)
|
|
|
|
# 모노로 변환
|
|
if melody.shape[0] > 1:
|
|
melody = melody.mean(dim=0, keepdim=True)
|
|
|
|
# 길이 제한 (30초)
|
|
max_samples = 32000 * 30
|
|
if melody.shape[1] > max_samples:
|
|
melody = melody[:, :max_samples]
|
|
|
|
# 생성 파라미터 설정
|
|
self.model.set_generation_params(
|
|
use_sampling=True,
|
|
top_k=250,
|
|
duration=duration,
|
|
)
|
|
|
|
# 멜로디 조건부 생성
|
|
wav = self.model.generate_with_chroma(
|
|
descriptions=[prompt],
|
|
melody_wavs=melody.unsqueeze(0).to(self.device),
|
|
melody_sample_rate=32000,
|
|
progress=True,
|
|
)
|
|
|
|
# 결과 처리
|
|
audio_data = wav[0].cpu().numpy()
|
|
if len(audio_data.shape) > 1:
|
|
audio_data = audio_data.mean(axis=0)
|
|
|
|
buffer = io.BytesIO()
|
|
sf.write(buffer, audio_data, 32000, format='WAV')
|
|
buffer.seek(0)
|
|
|
|
logger.info(f"멜로디 기반 음악 생성 완료")
|
|
return buffer.read()
|
|
|
|
except Exception as e:
|
|
logger.error(f"멜로디 기반 생성 실패: {e}")
|
|
raise
|
|
|
|
def is_initialized(self) -> bool:
|
|
"""초기화 상태 확인"""
|
|
return self._initialized
|
|
|
|
def get_model_info(self) -> dict:
|
|
"""모델 정보 반환"""
|
|
return {
|
|
"model_name": self.model_name,
|
|
"device": self.device,
|
|
"initialized": self._initialized,
|
|
"max_duration": 30,
|
|
"sample_rate": 32000,
|
|
}
|
|
|
|
|
|
# 싱글톤 인스턴스
|
|
musicgen_service = MusicGenService()
|