"""
WebSocket Server for real-time notifications
"""
import logging
import json
from typing import Dict, List, Optional
from fastapi import WebSocket
from datetime import datetime, timedelta

logger = logging.getLogger(__name__)


class WebSocketNotificationServer:
    """Manages WebSocket connections for real-time notifications.

    A user may hold several sockets at once (``connect`` appends to the
    user's list rather than replacing it), and every ``send_*`` helper fans
    a message out to all of that user's sockets.
    """

    def __init__(self):
        # user_id -> sockets registered for that user, in connect order.
        self.active_connections: Dict[str, List[WebSocket]] = {}
        # socket -> {"user_id", "connected_at", "last_activity"} bookkeeping.
        self.connection_metadata: Dict[WebSocket, Dict] = {}

    def _remove_connection(self, user_id: str, websocket: WebSocket) -> bool:
        """Unregister one socket and drop the user's entry once it is empty.

        Safe to call when the socket or the user is already gone: between
        awaits another coroutine may have removed it first.

        Returns:
            True if the socket was still registered, False otherwise.
        """
        removed = self.connection_metadata.pop(websocket, None) is not None
        connections = self.active_connections.get(user_id)
        if connections is not None:
            if websocket in connections:
                connections.remove(websocket)
                removed = True
            if not connections:
                del self.active_connections[user_id]
        return removed

    async def connect(self, websocket: WebSocket, user_id: str):
        """Accept ``websocket``, register it for ``user_id`` and greet it.

        Args:
            websocket: The connection to accept and track.
            user_id: Owner of the connection; a user may connect many times.
        """
        await websocket.accept()

        self.active_connections.setdefault(user_id, []).append(websocket)

        # Single timestamp so a fresh socket has connected_at == last_activity.
        now = datetime.now()
        self.connection_metadata[websocket] = {
            "user_id": user_id,
            "connected_at": now,
            "last_activity": now,
        }

        logger.info(
            "WebSocket connected for user %s. Total connections: %d",
            user_id,
            len(self.active_connections[user_id]),
        )

        await self.send_welcome_message(websocket, user_id)

    def disconnect(self, user_id: str, websocket: Optional[WebSocket] = None):
        """Unregister connection(s) for ``user_id``.

        Sockets are only forgotten here; this method does not close them.

        Args:
            user_id: Owner of the connection(s).
            websocket: When given, only this socket is unregistered and the
                user's other sockets keep receiving messages. When omitted
                (the original behaviour), every socket of the user is
                unregistered -- prefer passing the socket when a single
                connection closes.
        """
        if websocket is not None:
            if self._remove_connection(user_id, websocket):
                logger.info("WebSocket disconnected for user %s", user_id)
            return

        connections = self.active_connections.pop(user_id, None)
        if connections is None:
            return
        for ws in connections:
            self.connection_metadata.pop(ws, None)
        logger.info("WebSocket disconnected for user %s", user_id)

    async def send_to_user(self, user_id: str, message: Dict) -> bool:
        """Send ``message`` as JSON to every socket registered for ``user_id``.

        Sockets whose send raises are logged and unregistered; successful
        sends refresh that socket's ``last_activity``.

        Returns:
            True if at least one socket received the message; False if the
            user had no registered sockets or every send failed.
        """
        connections = self.active_connections.get(user_id)
        if not connections:
            logger.debug("No active connections for user %s", user_id)
            return False

        delivered = 0
        # Iterate over a snapshot: each send awaits, and other coroutines
        # (connect / disconnect / cleanup_stale_connections) may mutate the
        # live list or delete the user's entry in the meantime.
        for websocket in list(connections):
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.error("Error sending to WebSocket for user %s: %s", user_id, e)
                self._remove_connection(user_id, websocket)
                continue
            delivered += 1
            metadata = self.connection_metadata.get(websocket)
            if metadata is not None:
                metadata["last_activity"] = datetime.now()

        return delivered > 0

    async def broadcast(self, message: Dict):
        """Send ``message`` to every connected user, one user at a time."""
        # Snapshot the keys: send_to_user may delete users whose sends all fail.
        for user_id in list(self.active_connections):
            await self.send_to_user(user_id, message)

    async def send_notification(self, user_id: str, notification: Dict) -> bool:
        """Wrap ``notification`` in a ``"notification"`` envelope and send it.

        Returns:
            The result of ``send_to_user``.
        """
        message = {
            "type": "notification",
            "timestamp": datetime.now().isoformat(),
            "data": notification,
        }
        return await self.send_to_user(user_id, message)

    async def send_welcome_message(self, websocket: WebSocket, user_id: str):
        """Send the ``"connection"`` greeting to a newly connected socket.

        Send failures are logged, not raised; the socket stays registered.
        """
        welcome_message = {
            "type": "connection",
            "status": "connected",
            "user_id": user_id,
            "timestamp": datetime.now().isoformat(),
            "message": "Connected to notification service",
        }
        try:
            await websocket.send_json(welcome_message)
        except Exception as e:
            logger.error("Error sending welcome message: %s", e)

    def get_connection_count(self, user_id: Optional[str] = None) -> int:
        """Return the number of registered sockets.

        Args:
            user_id: Count only this user's sockets; ``None`` counts all
                users. (An empty-string id is a real id, not "all users".)
        """
        if user_id is not None:
            return len(self.active_connections.get(user_id, []))
        return sum(len(conns) for conns in self.active_connections.values())

    def get_connected_users(self) -> List[str]:
        """Return the ids of users with at least one registered socket."""
        return list(self.active_connections.keys())

    async def send_system_message(self, user_id: str, message: str, severity: str = "info") -> bool:
        """Send a ``"system"`` message with the given ``severity`` label.

        Returns:
            The result of ``send_to_user``.
        """
        system_message = {
            "type": "system",
            "severity": severity,
            "message": message,
            "timestamp": datetime.now().isoformat(),
        }
        return await self.send_to_user(user_id, system_message)

    async def send_presence_update(self, user_id: str, status: str) -> bool:
        """Send a ``"presence"`` update to the user's own sockets.

        Returns:
            The result of ``send_to_user``.
        """
        presence_message = {
            "type": "presence",
            "user_id": user_id,
            "status": status,
            "timestamp": datetime.now().isoformat(),
        }
        # Could send to friends/contacts if implemented
        return await self.send_to_user(user_id, presence_message)

    async def handle_ping(self, websocket: WebSocket):
        """Reply to a client ping with a ``"pong"`` and refresh last_activity.

        Send failures are logged, not raised, and do not refresh activity.
        """
        try:
            await websocket.send_json({
                "type": "pong",
                "timestamp": datetime.now().isoformat(),
            })
        except Exception as e:
            logger.error("Error handling ping: %s", e)
            return

        metadata = self.connection_metadata.get(websocket)
        if metadata is not None:
            metadata["last_activity"] = datetime.now()

    async def cleanup_stale_connections(self, timeout_minutes: int = 30, close: bool = True) -> int:
        """Unregister (and by default close) sockets idle for too long.

        ``last_activity`` is set on connect and refreshed by successful sends
        in ``send_to_user`` and by ``handle_ping``. A socket whose
        ``last_activity`` is more than ``timeout_minutes`` old is stale.

        Args:
            timeout_minutes: Idle threshold in minutes.
            close: Also close each stale socket (best effort). Previously
                stale sockets were only forgotten, leaving them open but no
                longer receiving notifications. NOTE: closing may make the
                endpoint observe a disconnect; if it then calls
                ``disconnect(user_id)`` without the socket, the user's other
                sockets are dropped too -- pass the socket there.

        Returns:
            Number of stale connections removed.
        """
        now = datetime.now()
        timeout = timedelta(minutes=timeout_minutes)

        # Collect first: the removal below mutates connection_metadata.
        stale = []
        for websocket, metadata in self.connection_metadata.items():
            last_activity = metadata.get("last_activity")
            if last_activity is not None and now - last_activity > timeout:
                stale.append((websocket, metadata.get("user_id")))

        for websocket, user_id in stale:
            self._remove_connection(user_id, websocket)
            if close:
                try:
                    await websocket.close()
                except Exception as e:
                    # The peer may already be gone; the socket is untracked either way.
                    logger.debug("Error closing stale WebSocket for user %s: %s", user_id, e)
            logger.info("Cleaned up stale connection for user %s", user_id)

        return len(stale)