Files
site11/services/pipeline/image-generator/image_generator.py
2025-09-28 20:41:57 +09:00

256 lines
9.7 KiB
Python

"""
Image Generation Service
Replicate API를 사용한 이미지 생성 서비스
"""
import asyncio
import logging
import os
import sys
import base64
from typing import List, Dict, Any
import httpx
from io import BytesIO
from motor.motor_asyncio import AsyncIOMotorClient
from bson import ObjectId
# Import from shared module
from shared.models import PipelineJob
from shared.queue_manager import QueueManager
# Configure root logging at INFO level and create the logger used by every
# function and method in this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ImageGeneratorWorker:
    """Queue worker that attaches generated images to Korean news articles.

    Jobs are taken from the ``image_generation`` queue. For each job the
    matching document in the ``articles_ko`` collection is looked up by
    ``news_id``, an English prompt is built from it, images are requested
    from the Replicate predictions API, the resulting URLs are stored on the
    article, and the job is forwarded to the ``translation`` queue.

    When ``REPLICATE_API_TOKEN`` is not set, or when generation fails, a
    placeholder image URL is stored instead so the pipeline keeps moving.
    """

    # Upper bound on the number of images attached to one article.
    MAX_IMAGES = 3
    # Number of images actually requested per article (kept at 1 for testing).
    IMAGES_PER_ARTICLE = 1
    # Pause, in seconds, between consecutive generation requests.
    RATE_LIMIT_DELAY = 2
    # Pause, in seconds, between two polls of a pending prediction.
    POLL_INTERVAL = 2

    # Korean category name -> English word used in the prompt.
    CATEGORY_MAP = {
        '기술': 'technology',
        '경제': 'business',
        '정치': 'politics',
        '교육': 'education',
        '사회': 'society',
        '문화': 'culture',
        '과학': 'science',
    }

    def __init__(self):
        """Read configuration from the environment; no connection is opened here."""
        self.queue_manager = QueueManager(
            redis_url=os.getenv("REDIS_URL", "redis://redis:6379")
        )
        self.replicate_api_key = os.getenv("REPLICATE_API_TOKEN")
        self.replicate_api_url = "https://api.replicate.com/v1/predictions"
        # Stable Diffusion (SDXL) model version used for every prediction.
        self.model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
        self.mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
        self.db_name = os.getenv("DB_NAME", "ai_writer_db")
        # Both are set by start() and cleared by stop().
        self.mongo_client = None
        self.db = None

    async def start(self):
        """Connect to Redis and MongoDB, then process jobs forever.

        This coroutine only returns by raising (for example on cancellation).
        Errors raised while handling a single job are logged and followed by
        a one-second pause so a persistent failure does not spin the loop.
        """
        logger.info("Starting Image Generator Worker")

        # Redis connection.
        await self.queue_manager.connect()

        # MongoDB connection; the client is kept so stop() can close it.
        self.mongo_client = AsyncIOMotorClient(self.mongodb_url)
        self.db = self.mongo_client[self.db_name]

        # Without an API key the worker still runs, but stores placeholders.
        if not self.replicate_api_key:
            logger.warning("Replicate API key not configured - using placeholder images")

        # Main processing loop.
        while True:
            try:
                job = await self.queue_manager.dequeue('image_generation', timeout=5)
                if job:
                    await self.process_job(job)
            except Exception as e:
                logger.error(f"Error in worker loop: {e}")
                await asyncio.sleep(1)

    async def process_job(self, job: "PipelineJob"):
        """Generate images for one job and record them on its article.

        Args:
            job: Pipeline job whose ``data`` mapping carries the ``news_id``
                of the article to illustrate.

        The job is marked failed when ``news_id`` is missing, when the
        article cannot be found, or when any step raises. Otherwise it is
        enqueued on ``translation`` and marked completed on
        ``image_generation``.
        """
        try:
            logger.info(f"Processing job {job.job_id} for image generation")

            news_id = job.data.get('news_id')
            if not news_id:
                logger.error(f"No news_id in job {job.job_id}")
                await self.queue_manager.mark_failed('image_generation', job, "No news_id")
                return

            # Look up the Korean article (articles_ko collection).
            article = await self.db.articles_ko.find_one({"news_id": news_id})
            if not article:
                logger.error(f"Article {news_id} not found in MongoDB")
                await self.queue_manager.mark_failed('image_generation', job, "Article not found")
                return

            prompt = self._create_image_prompt_from_article(article)
            image_urls = await self._generate_images(prompt)

            # Store the images and the prompt on the article (articles_ko).
            await self.db.articles_ko.update_one(
                {"news_id": news_id},
                {
                    "$set": {
                        "images": image_urls,
                        "image_prompt": prompt
                    },
                    "$addToSet": {
                        "pipeline_stages": "image_generation"
                    }
                }
            )
            logger.info(f"Updated article {news_id} with {len(image_urls)} images")

            # Hand the job over to the next stage (translation).
            job.stages_completed.append('image_generation')
            job.stage = 'translation'
            await self.queue_manager.enqueue('translation', job)
            await self.queue_manager.mark_completed('image_generation', job.job_id)
        except Exception as e:
            logger.error(f"Error processing job {job.job_id}: {e}")
            await self.queue_manager.mark_failed('image_generation', job, str(e))

    async def _generate_images(self, prompt: str) -> List[str]:
        """Generate the configured number of images for one prompt.

        The rate-limit pause is inserted only between two consecutive
        requests, never after the last one, and only when a real API key is
        configured (placeholder generation needs no throttling).
        """
        count = min(self.MAX_IMAGES, self.IMAGES_PER_ARTICLE)
        image_urls = []
        for index in range(count):
            image_urls.append(await self._generate_image(prompt))
            if self.replicate_api_key and index < count - 1:
                await asyncio.sleep(self.RATE_LIMIT_DELAY)
        return image_urls

    def _create_image_prompt_from_article(self, article: Dict) -> str:
        """Build an English image prompt from an article document.

        Args:
            article: Article mapping; the ``keyword`` and ``categories``
                entries are used. Missing or ``None`` values are treated as
                empty.

        Returns:
            A prompt naming the keyword and up to two categories (translated
            through ``CATEGORY_MAP`` when known), or ``news`` when the
            article has no categories.
        """
        keyword = article.get('keyword') or ''
        categories = article.get('categories') or []
        if isinstance(categories, str):
            # A lone category string must not be split into characters.
            categories = [categories]

        eng_categories = [self.CATEGORY_MAP.get(cat, cat) for cat in categories]
        category_str = ', '.join(eng_categories[:2]) if eng_categories else 'news'

        return f"News illustration for {keyword} {category_str}, professional, modern, clean design, high quality, 4k, no text"

    async def _generate_image(self, prompt: str) -> str:
        """Request one image from the Replicate API and return its URL.

        Never raises: without an API key, on a non-2xx response, on a
        response lacking a prediction id, or on any exception, a placeholder
        URL describing the situation is returned instead.
        """
        if not self.replicate_api_key:
            # No API key configured: return a placeholder image URL.
            return "https://via.placeholder.com/800x600.png?text=News+Image"

        try:
            async with httpx.AsyncClient() as client:
                # Create the prediction.
                response = await client.post(
                    self.replicate_api_url,
                    headers={
                        "Authorization": f"Token {self.replicate_api_key}",
                        "Content-Type": "application/json"
                    },
                    json={
                        "version": self.model_version,
                        "input": {
                            "prompt": prompt,
                            "width": 768,
                            "height": 768,
                            "num_outputs": 1,
                            "scheduler": "K_EULER",
                            "num_inference_steps": 25,
                            "guidance_scale": 7.5,
                            "prompt_strength": 0.8,
                            "refine": "expert_ensemble_refiner",
                            "high_noise_frac": 0.8
                        }
                    },
                    timeout=60
                )

                if response.status_code not in (200, 201):
                    logger.error(f"Replicate API error: {response.status_code}")
                    return "https://via.placeholder.com/800x600.png?text=Generation+Failed"

                prediction_id = response.json().get('id')
                if not prediction_id:
                    # Without an id there is nothing to poll.
                    logger.error("Replicate API response did not include a prediction id")
                    return "https://via.placeholder.com/800x600.png?text=Generation+Failed"

                # Poll until the prediction reaches a final state.
                return await self._poll_prediction(prediction_id)
        except Exception as e:
            logger.error(f"Error generating image: {e}")
            return "https://via.placeholder.com/800x600.png?text=Error"

    async def _poll_prediction(self, prediction_id: str, max_attempts: int = 30) -> str:
        """Poll a prediction until it finishes and return the image URL.

        Args:
            prediction_id: Identifier returned when the prediction was created.
            max_attempts: Maximum number of polls before giving up.

        Returns:
            The first output URL on success; otherwise a placeholder URL
            naming the reason (no output, failure, poll error, timeout).
        """
        try:
            async with httpx.AsyncClient() as client:
                for _ in range(max_attempts):
                    response = await client.get(
                        f"{self.replicate_api_url}/{prediction_id}",
                        headers={
                            "Authorization": f"Token {self.replicate_api_key}"
                        },
                        timeout=30
                    )

                    if response.status_code != 200:
                        logger.error(f"Error polling prediction: {response.status_code}")
                        return "https://via.placeholder.com/800x600.png?text=Poll+Error"

                    result = response.json()
                    status = result.get('status')

                    if status == 'succeeded':
                        output = result.get('output')
                        if isinstance(output, str) and output:
                            # Accept a single URL as well as a list of URLs.
                            return output
                        if isinstance(output, list) and output:
                            return output[0]
                        return "https://via.placeholder.com/800x600.png?text=No+Output"

                    if status in ('failed', 'canceled'):
                        # Both are final states; polling further cannot help.
                        logger.error(f"Prediction {status}: {result.get('error')}")
                        return "https://via.placeholder.com/800x600.png?text=Failed"

                    # Still in progress: wait before the next poll.
                    await asyncio.sleep(self.POLL_INTERVAL)

            # Maximum number of attempts exceeded.
            return "https://via.placeholder.com/800x600.png?text=Timeout"
        except Exception as e:
            logger.error(f"Error polling prediction: {e}")
            return "https://via.placeholder.com/800x600.png?text=Poll+Exception"

    async def stop(self):
        """Disconnect from the queue and close the MongoDB client, if any."""
        try:
            await self.queue_manager.disconnect()
        finally:
            # Close MongoDB even when the queue disconnect raises.
            if self.mongo_client is not None:
                self.mongo_client.close()
                self.mongo_client = None
                self.db = None
        logger.info("Image Generator Worker stopped")
async def main():
    """Run the image generator worker until it is interrupted.

    The worker is always stopped on the way out, whether ``start()`` ends
    through an interrupt, a cancellation, or an unexpected error.
    """
    worker = ImageGeneratorWorker()
    try:
        await worker.start()
    except KeyboardInterrupt:
        logger.info("Received interrupt signal")
    except asyncio.CancelledError:
        # Under asyncio.run() a Ctrl-C usually reaches this coroutine as a
        # cancellation rather than as KeyboardInterrupt. Log it, then
        # re-raise so the cancellation is still reported to the runner.
        logger.info("Received interrupt signal")
        raise
    finally:
        await worker.stop()
# Start the worker only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    asyncio.run(main())