"""
Image Generation Service

Image generation service backed by the Replicate API.
"""
import asyncio
import logging
import os
import sys
import base64
from typing import List, Dict, Any
import httpx
from io import BytesIO
from motor.motor_asyncio import AsyncIOMotorClient
from bson import ObjectId

# Import from shared module
from shared.models import PipelineJob
from shared.queue_manager import QueueManager

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ImageGeneratorWorker:
    """Queue worker that attaches generated images to Korean articles.

    Jobs are pulled from the ``image_generation`` queue. For each job the
    article in the ``articles_ko`` collection is looked up by ``news_id``,
    image URLs are produced via Replicate (or placeholder URLs when no API
    token is configured), the article document is updated, and the job is
    forwarded to the ``translation`` queue.
    """

    # Hard upper bound on images generated per article.
    MAX_IMAGES_PER_ARTICLE = 3

    # Korean category name -> English term used in the image prompt.
    # Unknown categories are passed through unchanged.
    _CATEGORY_MAP = {
        '기술': 'technology',
        '경제': 'business',
        '정치': 'politics',
        '교육': 'education',
        '사회': 'society',
        '문화': 'culture',
        '과학': 'science',
    }

    def __init__(self, num_images: int = 1):
        """Read configuration from the environment.

        Args:
            num_images: How many images to generate per article. Clamped to
                ``1..MAX_IMAGES_PER_ARTICLE``. Defaults to 1, matching the
                previous hard-coded test setting.
        """
        self.queue_manager = QueueManager(
            redis_url=os.getenv("REDIS_URL", "redis://redis:6379")
        )
        self.replicate_api_key = os.getenv("REPLICATE_API_TOKEN")
        self.replicate_api_url = "https://api.replicate.com/v1/predictions"

        # Stable Diffusion (SDXL) model version used for predictions.
        self.model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"

        self.mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongodb:27017")
        self.db_name = os.getenv("DB_NAME", "ai_writer_db")
        # Kept on the instance so stop() can close it (previously leaked).
        self.mongo_client = None
        self.db = None

        self.num_images = min(max(num_images, 1), self.MAX_IMAGES_PER_ARTICLE)

    async def start(self):
        """Connect to Redis and MongoDB, then process jobs until cancelled."""
        logger.info("Starting Image Generator Worker")

        # Redis connection
        await self.queue_manager.connect()

        # MongoDB connection
        self.mongo_client = AsyncIOMotorClient(self.mongodb_url)
        self.db = self.mongo_client[self.db_name]

        # Without an API token, _generate_image returns placeholder URLs.
        if not self.replicate_api_key:
            logger.warning("Replicate API key not configured - using placeholder images")

        # Main processing loop: errors are logged and the loop keeps running.
        while True:
            try:
                job = await self.queue_manager.dequeue('image_generation', timeout=5)
                if job:
                    await self.process_job(job)
            except Exception as e:
                logger.exception(f"Error in worker loop: {e}")
                # Back off briefly so a persistent failure does not spin the CPU.
                await asyncio.sleep(1)

    async def process_job(self, job: PipelineJob):
        """Generate images for one job's article and forward the job.

        On success, the article's ``images`` and ``image_prompt`` fields are set,
        ``image_generation`` is added to its ``pipeline_stages``, the job is
        enqueued on ``translation``, and it is marked completed. Any failure
        marks the job failed on the ``image_generation`` queue.
        """
        try:
            logger.info(f"Processing job {job.job_id} for image generation")

            news_id = job.data.get('news_id')
            if not news_id:
                logger.error(f"No news_id in job {job.job_id}")
                await self.queue_manager.mark_failed('image_generation', job, "No news_id")
                return

            # Look up the Korean article (articles_ko collection).
            article = await self.db.articles_ko.find_one({"news_id": news_id})
            if not article:
                logger.error(f"Article {news_id} not found in MongoDB")
                await self.queue_manager.mark_failed('image_generation', job, "Article not found")
                return

            # Build the image prompt from the Korean article's metadata.
            prompt = self._create_image_prompt_from_article(article)

            image_urls = []
            for i in range(self.num_images):
                image_urls.append(await self._generate_image(prompt))
                # Rate-limit real API calls, but only *between* requests:
                # sleeping after the last image only delays the job.
                if self.replicate_api_key and i < self.num_images - 1:
                    await asyncio.sleep(2)

            # Store the images on the article (articles_ko).
            await self.db.articles_ko.update_one(
                {"news_id": news_id},
                {
                    "$set": {
                        "images": image_urls,
                        "image_prompt": prompt
                    },
                    "$addToSet": {
                        "pipeline_stages": "image_generation"
                    }
                }
            )

            logger.info(f"Updated article {news_id} with {len(image_urls)} images")

            # Hand off to the next stage (translation).
            job.stages_completed.append('image_generation')
            job.stage = 'translation'
            await self.queue_manager.enqueue('translation', job)
            await self.queue_manager.mark_completed('image_generation', job.job_id)

        except Exception as e:
            logger.exception(f"Error processing job {job.job_id}: {e}")
            await self.queue_manager.mark_failed('image_generation', job, str(e))

    def _create_image_prompt_from_article(self, article: Dict) -> str:
        """Build an English text-to-image prompt from an article document.

        Uses the article's ``keyword`` and up to two ``categories`` (Korean
        names mapped to English via ``_CATEGORY_MAP``). Falls back to
        ``'news'`` when there are no categories. Missing or ``None`` fields
        are treated as empty rather than rendered as ``"None"``.
        """
        keyword = article.get('keyword') or ''
        categories = article.get('categories') or []

        eng_categories = [self._CATEGORY_MAP.get(cat, cat) for cat in categories]
        category_str = ', '.join(eng_categories[:2]) if eng_categories else 'news'

        # Prompt for a generic news illustration; "no text" discourages
        # the model from rendering garbled lettering.
        return f"News illustration for {keyword} {category_str}, professional, modern, clean design, high quality, 4k, no text"

    async def _generate_image(self, prompt: str) -> str:
        """Create a Replicate prediction for *prompt* and return an image URL.

        Never raises: on a missing API token, API error, or exception, a
        placeholder image URL is returned instead.
        """
        try:
            if not self.replicate_api_key:
                # No API key: return a placeholder image URL.
                return "https://via.placeholder.com/800x600.png?text=News+Image"

            async with httpx.AsyncClient() as client:
                # Create the prediction.
                response = await client.post(
                    self.replicate_api_url,
                    headers={
                        "Authorization": f"Token {self.replicate_api_key}",
                        "Content-Type": "application/json"
                    },
                    json={
                        "version": self.model_version,
                        "input": {
                            "prompt": prompt,
                            "width": 768,
                            "height": 768,
                            "num_outputs": 1,
                            "scheduler": "K_EULER",
                            "num_inference_steps": 25,
                            "guidance_scale": 7.5,
                            "prompt_strength": 0.8,
                            "refine": "expert_ensemble_refiner",
                            "high_noise_frac": 0.8
                        }
                    },
                    timeout=60
                )

            if response.status_code not in (200, 201):
                logger.error(f"Replicate API error: {response.status_code}")
                return "https://via.placeholder.com/800x600.png?text=Generation+Failed"

            prediction_id = response.json().get('id')
            if not prediction_id:
                # Without an id, polling would hit ".../None".
                logger.error("Replicate API response did not include a prediction id")
                return "https://via.placeholder.com/800x600.png?text=Generation+Failed"

            # Poll until the prediction reaches a terminal state.
            return await self._poll_prediction(prediction_id)

        except Exception as e:
            logger.error(f"Error generating image: {e}")
            return "https://via.placeholder.com/800x600.png?text=Error"

    async def _poll_prediction(self, prediction_id: str, max_attempts: int = 30) -> str:
        """Poll a Replicate prediction until it finishes and return its first output URL.

        Polls up to *max_attempts* times, 2 seconds apart. Returns a
        placeholder URL on failure/cancellation, missing output, HTTP error,
        timeout, or exception; never raises.
        """
        try:
            async with httpx.AsyncClient() as client:
                for attempt in range(max_attempts):
                    response = await client.get(
                        f"{self.replicate_api_url}/{prediction_id}",
                        headers={
                            "Authorization": f"Token {self.replicate_api_key}"
                        },
                        timeout=30
                    )

                    if response.status_code != 200:
                        logger.error(f"Error polling prediction: {response.status_code}")
                        return "https://via.placeholder.com/800x600.png?text=Poll+Error"

                    result = response.json()
                    status = result.get('status')

                    if status == 'succeeded':
                        output = result.get('output')
                        if output and isinstance(output, list):
                            return output[0]
                        return "https://via.placeholder.com/800x600.png?text=No+Output"

                    if status in ('failed', 'canceled'):
                        # Both are terminal; polling further only burns the timeout.
                        logger.error(f"Prediction failed: {result.get('error')}")
                        return "https://via.placeholder.com/800x600.png?text=Failed"

                    # Still starting/processing: wait before the next poll,
                    # but don't sleep after the final attempt.
                    if attempt < max_attempts - 1:
                        await asyncio.sleep(2)

            # Maximum attempts exceeded.
            return "https://via.placeholder.com/800x600.png?text=Timeout"

        except Exception as e:
            logger.error(f"Error polling prediction: {e}")
            return "https://via.placeholder.com/800x600.png?text=Poll+Exception"

    async def stop(self):
        """Disconnect from Redis and close the MongoDB client."""
        try:
            await self.queue_manager.disconnect()
        finally:
            # Close Mongo even if the Redis disconnect fails.
            if self.mongo_client is not None:
                self.mongo_client.close()
                self.mongo_client = None
                self.db = None
        logger.info("Image Generator Worker stopped")


async def main():
    """Run the worker until interrupted, always shutting it down cleanly."""
    worker = ImageGeneratorWorker()

    try:
        await worker.start()
    except KeyboardInterrupt:
        logger.info("Received interrupt signal")
    finally:
        await worker.stop()


if __name__ == "__main__":
    asyncio.run(main())