diff --git a/package-lock.json b/package-lock.json index 57750db7..240e8cc7 100644 --- a/package-lock.json +++ b/package-lock.json @@ -58,6 +58,7 @@ "http-errors": "2.0.0", "i": "^0.3.7", "input-otp": "1.4.2", + "ioredis": "^5.6.1", "jmespath": "^0.16.0", "js-yaml": "4.1.0", "jsonwebtoken": "^9.0.2", @@ -854,6 +855,12 @@ "integrity": "sha512-Sx1pU8EM64o2BrqNpEO1CNLtKQwyhuXuqyfH7oGKCk+1a33d2r5saW8zNwm3j6BTExtjrv2BxTgzzkMwts6vGg==", "license": "MIT" }, + "node_modules/@ioredis/commands": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/@ioredis/commands/-/commands-1.2.0.tgz", + "integrity": "sha512-Sx1pU8EM64o2BrqNpEO1CNLtKQwyhuXuqyfH7oGKCk+1a33d2r5saW8zNwm3j6BTExtjrv2BxTgzzkMwts6vGg==", + "license": "MIT" + }, "node_modules/@isaacs/balanced-match": { "version": "4.0.1", "resolved": "https://registry.npmjs.org/@isaacs/balanced-match/-/balanced-match-4.0.1.tgz", diff --git a/package.json b/package.json index 040a5453..02d26197 100644 --- a/package.json +++ b/package.json @@ -75,6 +75,7 @@ "http-errors": "2.0.0", "i": "^0.3.7", "input-otp": "1.4.2", + "ioredis": "^5.6.1", "jmespath": "^0.16.0", "js-yaml": "4.1.0", "jsonwebtoken": "^9.0.2", diff --git a/server/db/redis.ts b/server/db/redis.ts new file mode 100644 index 00000000..80f1e690 --- /dev/null +++ b/server/db/redis.ts @@ -0,0 +1,333 @@ +import Redis, { RedisOptions } from "ioredis"; +import logger from "@server/logger"; +import config from "@server/lib/config"; + +class RedisManager { + private static instance: RedisManager; + private client: Redis | null = null; + private subscriber: Redis | null = null; + private publisher: Redis | null = null; + private isEnabled: boolean = false; + private subscribers: Map< + string, + Set<(channel: string, message: string) => void> + > = new Map(); + + private constructor() { + this.isEnabled = config.getRawConfig().redis?.enabled || false; + if (this.isEnabled) { + this.initializeClients(); + } + } + + public static getInstance(): RedisManager { + if (!RedisManager.instance) { + RedisManager.instance = new RedisManager(); + } + return RedisManager.instance; + } + + private getRedisConfig(): RedisOptions { + const redisConfig = config.getRawConfig().redis!; + const opts: RedisOptions = { + host: redisConfig.host!, + port: redisConfig.port!, + password: redisConfig.password, + db: redisConfig.db, + tls: { + rejectUnauthorized: false + }, + }; + return opts; + } + + private initializeClients(): void { + const config = this.getRedisConfig(); + + try { + // Main client for general operations + this.client = new Redis(config); + + // Dedicated publisher client + this.publisher = new Redis(config); + + // Dedicated subscriber client + this.subscriber = new Redis(config); + + // Set up error handlers + this.client.on("error", (err) => { + logger.error("Redis client error:", err); + }); + + this.publisher.on("error", (err) => { + logger.error("Redis publisher error:", err); + }); + + this.subscriber.on("error", (err) => { + logger.error("Redis subscriber error:", err); + }); + + // Set up connection handlers + this.client.on("connect", () => { + logger.info("Redis client connected"); + }); + + this.publisher.on("connect", () => { + logger.info("Redis publisher connected"); + }); + + this.subscriber.on("connect", () => { + logger.info("Redis subscriber connected"); + }); + + // Set up message handler for subscriber + this.subscriber.on( + "message", + (channel: string, message: string) => { + const channelSubscribers = this.subscribers.get(channel); + if (channelSubscribers) { + channelSubscribers.forEach((callback) => { + try { + callback(channel, message); + } catch (error) { + logger.error( + `Error in subscriber callback for channel ${channel}:`, + error + ); + } + }); + } + } + ); + + logger.info("Redis clients initialized successfully"); + } catch (error) { + logger.error("Failed to initialize Redis clients:", error); + this.isEnabled = false; + } + } + + public isRedisEnabled(): boolean { + return this.isEnabled && this.client !== null; + } + + public getClient(): Redis | null { + return this.client; + } + + public async set( + key: string, + value: string, + ttl?: number + ): Promise { + if (!this.isRedisEnabled() || !this.client) return false; + + try { + if (ttl) { + await this.client.setex(key, ttl, value); + } else { + await this.client.set(key, value); + } + return true; + } catch (error) { + logger.error("Redis SET error:", error); + return false; + } + } + + public async get(key: string): Promise { + if (!this.isRedisEnabled() || !this.client) return null; + + try { + return await this.client.get(key); + } catch (error) { + logger.error("Redis GET error:", error); + return null; + } + } + + public async del(key: string): Promise { + if (!this.isRedisEnabled() || !this.client) return false; + + try { + await this.client.del(key); + return true; + } catch (error) { + logger.error("Redis DEL error:", error); + return false; + } + } + + public async sadd(key: string, member: string): Promise { + if (!this.isRedisEnabled() || !this.client) return false; + + try { + await this.client.sadd(key, member); + return true; + } catch (error) { + logger.error("Redis SADD error:", error); + return false; + } + } + + public async srem(key: string, member: string): Promise { + if (!this.isRedisEnabled() || !this.client) return false; + + try { + await this.client.srem(key, member); + return true; + } catch (error) { + logger.error("Redis SREM error:", error); + return false; + } + } + + public async smembers(key: string): Promise { + if (!this.isRedisEnabled() || !this.client) return []; + + try { + return await this.client.smembers(key); + } catch (error) { + logger.error("Redis SMEMBERS error:", error); + return []; + } + } + + public async hset( + key: string, + field: string, + value: string + ): Promise { + if (!this.isRedisEnabled() || !this.client) return false; + + try { + await this.client.hset(key, field, value); + return true; + } catch (error) { + logger.error("Redis HSET error:", error); + return false; + } + } + + public async hget(key: string, field: string): Promise { + if (!this.isRedisEnabled() || !this.client) return null; + + try { + return await this.client.hget(key, field); + } catch (error) { + logger.error("Redis HGET error:", error); + return null; + } + } + + public async hdel(key: string, field: string): Promise { + if (!this.isRedisEnabled() || !this.client) return false; + + try { + await this.client.hdel(key, field); + return true; + } catch (error) { + logger.error("Redis HDEL error:", error); + return false; + } + } + + public async hgetall(key: string): Promise> { + if (!this.isRedisEnabled() || !this.client) return {}; + + try { + return await this.client.hgetall(key); + } catch (error) { + logger.error("Redis HGETALL error:", error); + return {}; + } + } + + public async publish(channel: string, message: string): Promise { + if (!this.isRedisEnabled() || !this.publisher) return false; + + try { + await this.publisher.publish(channel, message); + return true; + } catch (error) { + logger.error("Redis PUBLISH error:", error); + return false; + } + } + + public async subscribe( + channel: string, + callback: (channel: string, message: string) => void + ): Promise { + if (!this.isRedisEnabled() || !this.subscriber) return false; + + try { + // Add callback to subscribers map + if (!this.subscribers.has(channel)) { + this.subscribers.set(channel, new Set()); + // Only subscribe to the channel if it's the first subscriber + await this.subscriber.subscribe(channel); + } + + this.subscribers.get(channel)!.add(callback); + return true; + } catch (error) { + logger.error("Redis SUBSCRIBE error:", error); + return false; + } + } + + public async unsubscribe( + channel: string, + callback?: (channel: string, message: string) => void + ): Promise { + if (!this.isRedisEnabled() || !this.subscriber) return false; + + try { + const channelSubscribers = this.subscribers.get(channel); + if (!channelSubscribers) return true; + + if (callback) { + // Remove specific callback + channelSubscribers.delete(callback); + if (channelSubscribers.size === 0) { + this.subscribers.delete(channel); + await this.subscriber.unsubscribe(channel); + } + } else { + // Remove all callbacks for this channel + this.subscribers.delete(channel); + await this.subscriber.unsubscribe(channel); + } + + return true; + } catch (error) { + logger.error("Redis UNSUBSCRIBE error:", error); + return false; + } + } + + public async disconnect(): Promise { + try { + if (this.client) { + await this.client.quit(); + this.client = null; + } + if (this.publisher) { + await this.publisher.quit(); + this.publisher = null; + } + if (this.subscriber) { + await this.subscriber.quit(); + this.subscriber = null; + } + this.subscribers.clear(); + logger.info("Redis clients disconnected"); + } catch (error) { + logger.error("Error disconnecting Redis clients:", error); + } + } +} + +export const redisManager = RedisManager.getInstance(); +export default redisManager; diff --git a/server/lib/readConfigFile.ts b/server/lib/readConfigFile.ts index 49a7f3b0..d1a98a3c 100644 --- a/server/lib/readConfigFile.ts +++ b/server/lib/readConfigFile.ts @@ -131,6 +131,32 @@ export const configSchema = z.object({ .optional() }) .optional(), + redis: z + .object({ + enabled: z.boolean(), + host: z.string().optional(), + port: portSchema.optional(), + password: z.string().optional(), + db: z.number().int().nonnegative().optional().default(0), + tls: z + .object({ + rejectUnauthorized: z.boolean().optional().default(true) + }) + .optional() + }) + .refine( + (redis) => { + if (!redis.enabled) { + return true; + } + return redis.host !== undefined && redis.port !== undefined; + }, + { + message: + "If Redis is enabled, connection details must be provided" + } + ) + .optional(), traefik: z .object({ http_entrypoint: z.string().optional().default("web"), diff --git a/server/routers/ws.ts b/server/routers/ws.ts index 1459e79c..709da400 100644 --- a/server/routers/ws.ts +++ b/server/routers/ws.ts @@ -10,6 +10,8 @@ import { validateNewtSessionToken } from "@server/auth/sessions/newt"; import { validateOlmSessionToken } from "@server/auth/sessions/olm"; import { messageHandlers } from "./messageHandlers"; import logger from "@server/logger"; +import redisManager from "@server/db/redis"; +import { v4 as uuidv4 } from "uuid"; // Custom interfaces interface WebSocketRequest extends IncomingMessage { @@ -21,6 +23,7 @@ type ClientType = 'newt' | 'olm'; interface AuthenticatedWebSocket extends WebSocket { client?: Newt | Olm; clientType?: ClientType; + connectionId?: string; } interface TokenPayload { @@ -46,44 +49,113 @@ interface HandlerContext { senderWs: WebSocket; client: Newt | Olm | undefined; clientType: ClientType; - sendToClient: (clientId: string, message: WSMessage) => boolean; - broadcastToAllExcept: (message: WSMessage, excludeClientId?: string) => void; + sendToClient: (clientType: ClientType, clientId: string, message: WSMessage) => Promise; + broadcastToAllExcept: (message: WSMessage, excludeClientType?: ClientType, excludeClientId?: string) => Promise; connectedClients: Map; } +interface RedisMessage { + type: 'direct' | 'broadcast'; + targetClientType?: ClientType; + targetClientId?: string; + excludeClientType?: ClientType; + excludeClientId?: string; + message: WSMessage; + fromNodeId: string; +} + export type MessageHandler = (context: HandlerContext) => Promise; const router: Router = Router(); const wss: WebSocketServer = new WebSocketServer({ noServer: true }); -// Client tracking map -let connectedClients: Map = new Map(); +// Generate unique node ID for this instance +const NODE_ID = uuidv4(); +const REDIS_CHANNEL = 'websocket_messages'; -// Helper functions for client management -const addClient = (clientId: string, ws: AuthenticatedWebSocket, clientType: ClientType): void => { - const existingClients = connectedClients.get(clientId) || []; - existingClients.push(ws); - connectedClients.set(clientId, existingClients); - logger.info(`Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Total connections: ${existingClients.length}`); +// Client tracking map (local to this node) +let connectedClients: Map = new Map(); +// Helper to get map key +const getClientMapKey = (clientType: ClientType, clientId: string) => `${clientType}:${clientId}`; + +// Redis keys (generalized) +const getConnectionsKey = (clientType: ClientType, clientId: string) => `ws:connections:${clientType}:${clientId}`; +const getNodeConnectionsKey = (nodeId: string, clientType: ClientType, clientId: string) => `ws:node:${nodeId}:${clientType}:${clientId}`; + +// Initialize Redis subscription for cross-node messaging +const initializeRedisSubscription = async (): Promise => { + if (!redisManager.isRedisEnabled()) return; + + await redisManager.subscribe(REDIS_CHANNEL, async (channel: string, message: string) => { + try { + const redisMessage: RedisMessage = JSON.parse(message); + + // Ignore messages from this node + if (redisMessage.fromNodeId === NODE_ID) return; + + if (redisMessage.type === 'direct' && redisMessage.targetClientType && redisMessage.targetClientId) { + // Send to specific client on this node + await sendToClientLocal(redisMessage.targetClientType, redisMessage.targetClientId, redisMessage.message); + } else if (redisMessage.type === 'broadcast') { + // Broadcast to all clients on this node except excluded + await broadcastToAllExceptLocal(redisMessage.message, redisMessage.excludeClientType, redisMessage.excludeClientId); + } + } catch (error) { + logger.error('Error processing Redis message:', error); + } + }); }; -const removeClient = (clientId: string, ws: AuthenticatedWebSocket, clientType: ClientType): void => { - const existingClients = connectedClients.get(clientId) || []; +// Helper functions for client management +const addClient = async (clientType: ClientType, clientId: string, ws: AuthenticatedWebSocket): Promise => { + // Generate unique connection ID + const connectionId = uuidv4(); + ws.connectionId = connectionId; + + // Add to local tracking + const mapKey = getClientMapKey(clientType, clientId); + const existingClients = connectedClients.get(mapKey) || []; + existingClients.push(ws); + connectedClients.set(mapKey, existingClients); + + // Add to Redis tracking if enabled + if (redisManager.isRedisEnabled()) { + await redisManager.sadd(getConnectionsKey(clientType, clientId), NODE_ID); + await redisManager.hset(getNodeConnectionsKey(NODE_ID, clientType, clientId), connectionId, Date.now().toString()); + } + + logger.info(`Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}`); +}; + +const removeClient = async (clientType: ClientType, clientId: string, ws: AuthenticatedWebSocket): Promise => { + const mapKey = getClientMapKey(clientType, clientId); + const existingClients = connectedClients.get(mapKey) || []; const updatedClients = existingClients.filter(client => client !== ws); if (updatedClients.length === 0) { - connectedClients.delete(clientId); + connectedClients.delete(mapKey); + + if (redisManager.isRedisEnabled()) { + await redisManager.srem(getConnectionsKey(clientType, clientId), NODE_ID); + await redisManager.del(getNodeConnectionsKey(NODE_ID, clientType, clientId)); + } + logger.info(`All connections removed for ${clientType.toUpperCase()} ID: ${clientId}`); } else { - connectedClients.set(clientId, updatedClients); + connectedClients.set(mapKey, updatedClients); + + if (redisManager.isRedisEnabled() && ws.connectionId) { + await redisManager.hdel(getNodeConnectionsKey(NODE_ID, clientType, clientId), ws.connectionId); + } + logger.info(`Connection removed - ${clientType.toUpperCase()} ID: ${clientId}, Remaining connections: ${updatedClients.length}`); } }; -// Helper functions for sending messages -const sendToClient = (clientId: string, message: WSMessage): boolean => { - const clients = connectedClients.get(clientId); +// Local message sending (within this node) +const sendToClientLocal = async (clientType: ClientType, clientId: string, message: WSMessage): Promise => { + const mapKey = getClientMapKey(clientType, clientId); + const clients = connectedClients.get(mapKey); if (!clients || clients.length === 0) { - logger.info(`No active connections found for Client ID: ${clientId}`); return false; } const messageString = JSON.stringify(message); @@ -95,9 +167,10 @@ const sendToClient = (clientId: string, message: WSMessage): boolean => { return true; }; -const broadcastToAllExcept = (message: WSMessage, excludeClientId?: string): void => { - connectedClients.forEach((clients, clientId) => { - if (clientId !== excludeClientId) { +const broadcastToAllExceptLocal = async (message: WSMessage, excludeClientType?: ClientType, excludeClientId?: string): Promise => { + connectedClients.forEach((clients, mapKey) => { + const [type, id] = mapKey.split(":"); + if (!(excludeClientType && excludeClientId && type === excludeClientType && id === excludeClientId)) { clients.forEach(client => { if (client.readyState === WebSocket.OPEN) { client.send(JSON.stringify(message)); @@ -107,9 +180,72 @@ const broadcastToAllExcept = (message: WSMessage, excludeClientId?: string): voi }); }; +// Cross-node message sending (via Redis) +const sendToClient = async (clientType: ClientType, clientId: string, message: WSMessage): Promise => { + // Try to send locally first + const localSent = await sendToClientLocal(clientType, clientId, message); + + // If Redis is enabled, also send via Redis pub/sub to other nodes + if (redisManager.isRedisEnabled()) { + const redisMessage: RedisMessage = { + type: 'direct', + targetClientType: clientType, + targetClientId: clientId, + message, + fromNodeId: NODE_ID + }; + + await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage)); + } + + return localSent; +}; + +const broadcastToAllExcept = async (message: WSMessage, excludeClientType?: ClientType, excludeClientId?: string): Promise => { + // Broadcast locally + await broadcastToAllExceptLocal(message, excludeClientType, excludeClientId); + + // If Redis is enabled, also broadcast via Redis pub/sub to other nodes + if (redisManager.isRedisEnabled()) { + const redisMessage: RedisMessage = { + type: 'broadcast', + excludeClientType, + excludeClientId, + message, + fromNodeId: NODE_ID + }; + + await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage)); + } +}; + +// Check if a client has active connections across all nodes +const hasActiveConnections = async (clientType: ClientType, clientId: string): Promise => { + if (!redisManager.isRedisEnabled()) { + const mapKey = getClientMapKey(clientType, clientId); + const clients = connectedClients.get(mapKey); + return !!(clients && clients.length > 0); + } + + const activeNodes = await redisManager.smembers(getConnectionsKey(clientType, clientId)); + return activeNodes.length > 0; +}; + +// Get all active nodes for a client +const getActiveNodes = async (clientType: ClientType, clientId: string): Promise => { + if (!redisManager.isRedisEnabled()) { + const mapKey = getClientMapKey(clientType, clientId); + const clients = connectedClients.get(mapKey); + return (clients && clients.length > 0) ? [NODE_ID] : []; + } + + return await redisManager.smembers(getConnectionsKey(clientType, clientId)); +}; + // Token verification middleware const verifyToken = async (token: string, clientType: ClientType): Promise => { - try { + +try { if (clientType === 'newt') { const { session, newt } = await validateNewtSessionToken(token); if (!session || !newt) { @@ -143,7 +279,7 @@ const verifyToken = async (token: string, clientType: ClientType): Promise { +const setupConnection = async (ws: AuthenticatedWebSocket, client: Newt | Olm, clientType: ClientType): Promise => { logger.info("Establishing websocket connection"); if (!client) { logger.error("Connection attempt without client"); @@ -155,7 +291,7 @@ const setupConnection = (ws: AuthenticatedWebSocket, client: Newt | Olm, clientT // Add client to tracking const clientId = clientType === 'newt' ? (client as Newt).newtId : (client as Olm).olmId; - addClient(clientId, ws, clientType); + await addClient(clientType, clientId, ws); ws.on("message", async (data) => { try { @@ -182,9 +318,13 @@ const setupConnection = (ws: AuthenticatedWebSocket, client: Newt | Olm, clientT if (response) { if (response.broadcast) { - broadcastToAllExcept(response.message, response.excludeSender ? clientId : undefined); + await broadcastToAllExcept( + response.message, + response.excludeSender ? clientType : undefined, + response.excludeSender ? clientId : undefined + ); } else if (response.targetClientId) { - sendToClient(response.targetClientId, response.message); + await sendToClient(clientType, response.targetClientId, response.message); } else { ws.send(JSON.stringify(response.message)); } @@ -202,7 +342,7 @@ const setupConnection = (ws: AuthenticatedWebSocket, client: Newt | Olm, clientT }); ws.on("close", () => { - removeClient(clientId, ws, clientType); + removeClient(clientType, clientId, ws); logger.info(`Client disconnected - ${clientType.toUpperCase()} ID: ${clientId}`); }); @@ -256,10 +396,54 @@ const handleWSUpgrade = (server: HttpServer): void => { }); }; +// Initialize Redis subscription when the module is loaded +if (redisManager.isRedisEnabled()) { + initializeRedisSubscription().catch(error => { + logger.error('Failed to initialize Redis subscription:', error); + }); + logger.info(`WebSocket handler initialized with Redis support - Node ID: ${NODE_ID}`); +} else { + logger.info('WebSocket handler initialized in local mode (Redis disabled)'); +} + +// Cleanup function for graceful shutdown +const cleanup = async (): Promise => { + try { + // Close all WebSocket connections + connectedClients.forEach((clients) => { + clients.forEach(client => { + if (client.readyState === WebSocket.OPEN) { + client.terminate(); + } + }); + }); + + // Clean up Redis tracking for this node + if (redisManager.isRedisEnabled()) { + const keys = await redisManager.getClient()?.keys(`ws:node:${NODE_ID}:*`) || []; + if (keys.length > 0) { + await Promise.all(keys.map(key => redisManager.del(key))); + } + } + + logger.info('WebSocket cleanup completed'); + } catch (error) { + logger.error('Error during WebSocket cleanup:', error); + } +}; + +// Handle process termination +process.on('SIGTERM', cleanup); +process.on('SIGINT', cleanup); + export { router, handleWSUpgrade, sendToClient, broadcastToAllExcept, - connectedClients + connectedClients, + hasActiveConnections, + getActiveNodes, + NODE_ID, + cleanup };