diff --git a/server/db/redis.ts b/server/db/redis.ts new file mode 100644 index 00000000..bae80099 --- /dev/null +++ b/server/db/redis.ts @@ -0,0 +1,316 @@ +import Redis from 'ioredis'; +import logger from '@server/logger'; + +interface RedisConfig { + host: string; + port: number; + password?: string; + db?: number; + retryDelayOnFailover?: number; + maxRetriesPerRequest?: number; +} + +class RedisManager { + private static instance: RedisManager; + private client: Redis | null = null; + private subscriber: Redis | null = null; + private publisher: Redis | null = null; + private isEnabled: boolean = false; + private subscribers: Map void>> = new Map(); + + private constructor() { + this.isEnabled = !!process.env.REDIS; + if (this.isEnabled) { + this.initializeClients(); + } + } + + public static getInstance(): RedisManager { + if (!RedisManager.instance) { + RedisManager.instance = new RedisManager(); + } + return RedisManager.instance; + } + + private getRedisConfig(): RedisConfig { + return { + host: process.env.REDIS_HOST || 'localhost', + port: parseInt(process.env.REDIS_PORT || '6379'), + password: process.env.REDIS_PASSWORD, + db: parseInt(process.env.REDIS_DB || '0'), + retryDelayOnFailover: 100, + maxRetriesPerRequest: 3, + }; + } + + private initializeClients(): void { + const config = this.getRedisConfig(); + + try { + // Main client for general operations + this.client = new Redis(config); + + // Dedicated publisher client + this.publisher = new Redis(config); + + // Dedicated subscriber client + this.subscriber = new Redis(config); + + // Set up error handlers + this.client.on('error', (err) => { + logger.error('Redis client error:', err); + }); + + this.publisher.on('error', (err) => { + logger.error('Redis publisher error:', err); + }); + + this.subscriber.on('error', (err) => { + logger.error('Redis subscriber error:', err); + }); + + // Set up connection handlers + this.client.on('connect', () => { + logger.info('Redis client connected'); + }); + + this.publisher.on('connect', () => { + logger.info('Redis publisher connected'); + }); + + this.subscriber.on('connect', () => { + logger.info('Redis subscriber connected'); + }); + + // Set up message handler for subscriber + this.subscriber.on('message', (channel: string, message: string) => { + const channelSubscribers = this.subscribers.get(channel); + if (channelSubscribers) { + channelSubscribers.forEach(callback => { + try { + callback(channel, message); + } catch (error) { + logger.error(`Error in subscriber callback for channel ${channel}:`, error); + } + }); + } + }); + + logger.info('Redis clients initialized successfully'); + } catch (error) { + logger.error('Failed to initialize Redis clients:', error); + this.isEnabled = false; + } + } + + public isRedisEnabled(): boolean { + return this.isEnabled && this.client !== null; + } + + public getClient(): Redis | null { + return this.client; + } + + public async set(key: string, value: string, ttl?: number): Promise { + if (!this.isRedisEnabled() || !this.client) return false; + + try { + if (ttl) { + await this.client.setex(key, ttl, value); + } else { + await this.client.set(key, value); + } + return true; + } catch (error) { + logger.error('Redis SET error:', error); + return false; + } + } + + public async get(key: string): Promise { + if (!this.isRedisEnabled() || !this.client) return null; + + try { + return await this.client.get(key); + } catch (error) { + logger.error('Redis GET error:', error); + return null; + } + } + + public async del(key: string): Promise { + if (!this.isRedisEnabled() || !this.client) return false; + + try { + await this.client.del(key); + return true; + } catch (error) { + logger.error('Redis DEL error:', error); + return false; + } + } + + public async sadd(key: string, member: string): Promise { + if (!this.isRedisEnabled() || !this.client) return false; + + try { + await this.client.sadd(key, member); + return true; + } catch (error) { + logger.error('Redis SADD error:', error); + return false; + } + } + + public async srem(key: string, member: string): Promise { + if (!this.isRedisEnabled() || !this.client) return false; + + try { + await this.client.srem(key, member); + return true; + } catch (error) { + logger.error('Redis SREM error:', error); + return false; + } + } + + public async smembers(key: string): Promise { + if (!this.isRedisEnabled() || !this.client) return []; + + try { + return await this.client.smembers(key); + } catch (error) { + logger.error('Redis SMEMBERS error:', error); + return []; + } + } + + public async hset(key: string, field: string, value: string): Promise { + if (!this.isRedisEnabled() || !this.client) return false; + + try { + await this.client.hset(key, field, value); + return true; + } catch (error) { + logger.error('Redis HSET error:', error); + return false; + } + } + + public async hget(key: string, field: string): Promise { + if (!this.isRedisEnabled() || !this.client) return null; + + try { + return await this.client.hget(key, field); + } catch (error) { + logger.error('Redis HGET error:', error); + return null; + } + } + + public async hdel(key: string, field: string): Promise { + if (!this.isRedisEnabled() || !this.client) return false; + + try { + await this.client.hdel(key, field); + return true; + } catch (error) { + logger.error('Redis HDEL error:', error); + return false; + } + } + + public async hgetall(key: string): Promise> { + if (!this.isRedisEnabled() || !this.client) return {}; + + try { + return await this.client.hgetall(key); + } catch (error) { + logger.error('Redis HGETALL error:', error); + return {}; + } + } + + public async publish(channel: string, message: string): Promise { + if (!this.isRedisEnabled() || !this.publisher) return false; + + try { + await this.publisher.publish(channel, message); + return true; + } catch (error) { + logger.error('Redis PUBLISH error:', error); + return false; + } + } + + public async subscribe(channel: string, callback: (channel: string, message: string) => void): Promise { + if (!this.isRedisEnabled() || !this.subscriber) return false; + + try { + // Add callback to subscribers map + if (!this.subscribers.has(channel)) { + this.subscribers.set(channel, new Set()); + // Only subscribe to the channel if it's the first subscriber + await this.subscriber.subscribe(channel); + } + + this.subscribers.get(channel)!.add(callback); + return true; + } catch (error) { + logger.error('Redis SUBSCRIBE error:', error); + return false; + } + } + + public async unsubscribe(channel: string, callback?: (channel: string, message: string) => void): Promise { + if (!this.isRedisEnabled() || !this.subscriber) return false; + + try { + const channelSubscribers = this.subscribers.get(channel); + if (!channelSubscribers) return true; + + if (callback) { + // Remove specific callback + channelSubscribers.delete(callback); + if (channelSubscribers.size === 0) { + this.subscribers.delete(channel); + await this.subscriber.unsubscribe(channel); + } + } else { + // Remove all callbacks for this channel + this.subscribers.delete(channel); + await this.subscriber.unsubscribe(channel); + } + + return true; + } catch (error) { + logger.error('Redis UNSUBSCRIBE error:', error); + return false; + } + } + + public async disconnect(): Promise { + try { + if (this.client) { + await this.client.quit(); + this.client = null; + } + if (this.publisher) { + await this.publisher.quit(); + this.publisher = null; + } + if (this.subscriber) { + await this.subscriber.quit(); + this.subscriber = null; + } + this.subscribers.clear(); + logger.info('Redis clients disconnected'); + } catch (error) { + logger.error('Error disconnecting Redis clients:', error); + } + } +} + +// Export singleton instance +export const redisManager = RedisManager.getInstance(); +export default redisManager; \ No newline at end of file diff --git a/server/routers/ws.ts b/server/routers/ws.ts index c4ee8874..c953a60c 100644 --- a/server/routers/ws.ts +++ b/server/routers/ws.ts @@ -9,6 +9,8 @@ import db from "@server/db"; import { validateNewtSessionToken } from "@server/auth/sessions/newt"; import { messageHandlers } from "./messageHandlers"; import logger from "@server/logger"; +import redisManager from "@server/db/redis"; +import { v4 as uuidv4 } from "uuid"; // Custom interfaces interface WebSocketRequest extends IncomingMessage { @@ -17,6 +19,7 @@ interface WebSocketRequest extends IncomingMessage { interface AuthenticatedWebSocket extends WebSocket { newt?: Newt; + connectionId?: string; } interface TokenPayload { @@ -40,45 +43,113 @@ interface HandlerContext { message: WSMessage; senderWs: WebSocket; newt: Newt | undefined; - sendToClient: (newtId: string, message: WSMessage) => boolean; - broadcastToAllExcept: (message: WSMessage, excludeNewtId?: string) => void; + sendToClient: (newtId: string, message: WSMessage) => Promise; + broadcastToAllExcept: (message: WSMessage, excludeNewtId?: string) => Promise; connectedClients: Map; } +interface RedisMessage { + type: 'direct' | 'broadcast'; + targetNewtId?: string; + excludeNewtId?: string; + message: WSMessage; + fromNodeId: string; +} + export type MessageHandler = (context: HandlerContext) => Promise; const router: Router = Router(); const wss: WebSocketServer = new WebSocketServer({ noServer: true }); -// Client tracking map +// Generate unique node ID for this instance +const NODE_ID = uuidv4(); +const REDIS_CHANNEL = 'websocket_messages'; + +// Client tracking map (local to this node) let connectedClients: Map = new Map(); +// Redis keys +const getConnectionsKey = (newtId: string) => `ws:connections:${newtId}`; +const getNodeConnectionsKey = (nodeId: string, newtId: string) => `ws:node:${nodeId}:${newtId}`; + +// Initialize Redis subscription for cross-node messaging +const initializeRedisSubscription = async (): Promise => { + if (!redisManager.isRedisEnabled()) return; + + await redisManager.subscribe(REDIS_CHANNEL, async (channel: string, message: string) => { + try { + const redisMessage: RedisMessage = JSON.parse(message); + + // Ignore messages from this node + if (redisMessage.fromNodeId === NODE_ID) return; + + if (redisMessage.type === 'direct' && redisMessage.targetNewtId) { + // Send to specific client on this node + await sendToClientLocal(redisMessage.targetNewtId, redisMessage.message); + } else if (redisMessage.type === 'broadcast') { + // Broadcast to all clients on this node except excluded + await broadcastToAllExceptLocal(redisMessage.message, redisMessage.excludeNewtId); + } + } catch (error) { + logger.error('Error processing Redis message:', error); + } + }); +}; + // Helper functions for client management -const addClient = (newtId: string, ws: AuthenticatedWebSocket): void => { +const addClient = async (newtId: string, ws: AuthenticatedWebSocket): Promise => { + // Generate unique connection ID + const connectionId = uuidv4(); + ws.connectionId = connectionId; + + // Add to local tracking const existingClients = connectedClients.get(newtId) || []; existingClients.push(ws); connectedClients.set(newtId, existingClients); - logger.info(`Client added to tracking - Newt ID: ${newtId}, Total connections: ${existingClients.length}`); + + // Add to Redis tracking if enabled + if (redisManager.isRedisEnabled()) { + // Add this node to the set of nodes handling this newt + await redisManager.sadd(getConnectionsKey(newtId), NODE_ID); + + // Track specific connection on this node + await redisManager.hset(getNodeConnectionsKey(NODE_ID, newtId), connectionId, Date.now().toString()); + } + + logger.info(`Client added to tracking - Newt ID: ${newtId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}`); }; -const removeClient = (newtId: string, ws: AuthenticatedWebSocket): void => { +const removeClient = async (newtId: string, ws: AuthenticatedWebSocket): Promise => { + // Remove from local tracking const existingClients = connectedClients.get(newtId) || []; const updatedClients = existingClients.filter(client => client !== ws); if (updatedClients.length === 0) { connectedClients.delete(newtId); + + // Remove from Redis tracking if enabled + if (redisManager.isRedisEnabled()) { + await redisManager.srem(getConnectionsKey(newtId), NODE_ID); + await redisManager.del(getNodeConnectionsKey(NODE_ID, newtId)); + } + logger.info(`All connections removed for Newt ID: ${newtId}`); } else { connectedClients.set(newtId, updatedClients); + + // Update Redis tracking if enabled + if (redisManager.isRedisEnabled() && ws.connectionId) { + await redisManager.hdel(getNodeConnectionsKey(NODE_ID, newtId), ws.connectionId); + } + logger.info(`Connection removed - Newt ID: ${newtId}, Remaining connections: ${updatedClients.length}`); } }; -// Helper functions for sending messages -const sendToClient = (newtId: string, message: WSMessage): boolean => { +// Local message sending (within this node) +const sendToClientLocal = async (newtId: string, message: WSMessage): Promise => { const clients = connectedClients.get(newtId); if (!clients || clients.length === 0) { - logger.info(`No active connections found for Newt ID: ${newtId}`); return false; } @@ -91,7 +162,7 @@ const sendToClient = (newtId: string, message: WSMessage): boolean => { return true; }; -const broadcastToAllExcept = (message: WSMessage, excludeNewtId?: string): void => { +const broadcastToAllExceptLocal = async (message: WSMessage, excludeNewtId?: string): Promise => { connectedClients.forEach((clients, newtId) => { if (newtId !== excludeNewtId) { clients.forEach(client => { @@ -103,6 +174,65 @@ const broadcastToAllExcept = (message: WSMessage, excludeNewtId?: string): void }); }; +// Cross-node message sending (via Redis) +const sendToClient = async (newtId: string, message: WSMessage): Promise => { + // Try to send locally first + const localSent = await sendToClientLocal(newtId, message); + + // If Redis is enabled, also send via Redis pub/sub to other nodes + if (redisManager.isRedisEnabled()) { + const redisMessage: RedisMessage = { + type: 'direct', + targetNewtId: newtId, + message, + fromNodeId: NODE_ID + }; + + await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage)); + } + + return localSent; +}; + +const broadcastToAllExcept = async (message: WSMessage, excludeNewtId?: string): Promise => { + // Broadcast locally + await broadcastToAllExceptLocal(message, excludeNewtId); + + // If Redis is enabled, also broadcast via Redis pub/sub to other nodes + if (redisManager.isRedisEnabled()) { + const redisMessage: RedisMessage = { + type: 'broadcast', + excludeNewtId, + message, + fromNodeId: NODE_ID + }; + + await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage)); + } +}; + +// Check if a newt has active connections across all nodes +const hasActiveConnections = async (newtId: string): Promise => { + if (!redisManager.isRedisEnabled()) { + // Fallback to local check + const clients = connectedClients.get(newtId); + return !!(clients && clients.length > 0); + } + + const activeNodes = await redisManager.smembers(getConnectionsKey(newtId)); + return activeNodes.length > 0; +}; + +// Get all active nodes for a newt +const getActiveNodes = async (newtId: string): Promise => { + if (!redisManager.isRedisEnabled()) { + const clients = connectedClients.get(newtId); + return (clients && clients.length > 0) ? [NODE_ID] : []; + } + + return await redisManager.smembers(getConnectionsKey(newtId)); +}; + // Token verification middleware (unchanged) const verifyToken = async (token: string): Promise => { try { @@ -128,7 +258,7 @@ const verifyToken = async (token: string): Promise => { } }; -const setupConnection = (ws: AuthenticatedWebSocket, newt: Newt): void => { +const setupConnection = async (ws: AuthenticatedWebSocket, newt: Newt): Promise => { logger.info("Establishing websocket connection"); if (!newt) { @@ -139,12 +269,11 @@ const setupConnection = (ws: AuthenticatedWebSocket, newt: Newt): void => { ws.newt = newt; // Add client to tracking - addClient(newt.newtId, ws); + await addClient(newt.newtId, ws); ws.on("message", async (data) => { try { const message: WSMessage = JSON.parse(data.toString()); - // logger.info(`Message received from Newt ID ${newtId}:`, message); // Validate message format if (!message.type || typeof message.type !== "string") { @@ -171,10 +300,10 @@ const setupConnection = (ws: AuthenticatedWebSocket, newt: Newt): void => { if (response) { if (response.broadcast) { // Broadcast to all clients except sender if specified - broadcastToAllExcept(response.message, response.excludeSender ? newt.newtId : undefined); + await broadcastToAllExcept(response.message, response.excludeSender ? newt.newtId : undefined); } else if (response.targetNewtId) { // Send to specific client if targetNewtId is provided - sendToClient(response.targetNewtId, response.message); + await sendToClient(response.targetNewtId, response.message); } else { // Send back to sender ws.send(JSON.stringify(response.message)); @@ -193,8 +322,8 @@ const setupConnection = (ws: AuthenticatedWebSocket, newt: Newt): void => { } }); - ws.on("close", () => { - removeClient(newt.newtId, ws); + ws.on("close", async () => { + await removeClient(newt.newtId, ws); logger.info(`Client disconnected - Newt ID: ${newt.newtId}`); }); @@ -202,7 +331,7 @@ const setupConnection = (ws: AuthenticatedWebSocket, newt: Newt): void => { logger.error(`WebSocket error for Newt ID ${newt.newtId}:`, error); }); - logger.info(`WebSocket connection established - Newt ID: ${newt.newtId}`); + logger.info(`WebSocket connection established - Newt ID: ${newt.newtId}, Node ID: ${NODE_ID}`); }; // Router endpoint (unchanged) @@ -233,8 +362,8 @@ const handleWSUpgrade = (server: HttpServer): void => { return; } - wss.handleUpgrade(request, socket, head, (ws: AuthenticatedWebSocket) => { - setupConnection(ws, tokenPayload.newt); + wss.handleUpgrade(request, socket, head, async (ws: AuthenticatedWebSocket) => { + await setupConnection(ws, tokenPayload.newt); }); } catch (error) { logger.error("WebSocket upgrade error:", error); @@ -244,10 +373,54 @@ const handleWSUpgrade = (server: HttpServer): void => { }); }; +// Initialize Redis subscription when the module is loaded +if (redisManager.isRedisEnabled()) { + initializeRedisSubscription().catch(error => { + logger.error('Failed to initialize Redis subscription:', error); + }); + logger.info(`WebSocket handler initialized with Redis support - Node ID: ${NODE_ID}`); +} else { + logger.info('WebSocket handler initialized in local mode (Redis disabled)'); +} + +// Cleanup function for graceful shutdown +const cleanup = async (): Promise => { + try { + // Close all WebSocket connections + connectedClients.forEach((clients) => { + clients.forEach(client => { + if (client.readyState === WebSocket.OPEN) { + client.terminate(); + } + }); + }); + + // Clean up Redis tracking for this node + if (redisManager.isRedisEnabled()) { + const keys = await redisManager.getClient()?.keys(`ws:node:${NODE_ID}:*`) || []; + if (keys.length > 0) { + await Promise.all(keys.map(key => redisManager.del(key))); + } + } + + logger.info('WebSocket cleanup completed'); + } catch (error) { + logger.error('Error during WebSocket cleanup:', error); + } +}; + +// Handle process termination +process.on('SIGTERM', cleanup); +process.on('SIGINT', cleanup); + export { router, handleWSUpgrade, sendToClient, broadcastToAllExcept, - connectedClients -}; + connectedClients, + hasActiveConnections, + getActiveNodes, + NODE_ID, + cleanup +}; \ No newline at end of file