import { Router, Request, Response } from "express"; import { Server as HttpServer } from "http"; import { WebSocket, WebSocketServer } from "ws"; import { IncomingMessage } from "http"; import { Socket } from "net"; import { Newt, newts, NewtSession, olms, Olm, OlmSession } from "@server/db"; import { eq } from "drizzle-orm"; import { db } from "@server/db"; import { validateNewtSessionToken } from "@server/auth/sessions/newt"; import { validateOlmSessionToken } from "@server/auth/sessions/olm"; import { messageHandlers } from "./messageHandlers"; import logger from "@server/logger"; import redisManager from "@server/db/redis"; import { v4 as uuidv4 } from "uuid"; // Custom interfaces interface WebSocketRequest extends IncomingMessage { token?: string; } type ClientType = 'newt' | 'olm'; interface AuthenticatedWebSocket extends WebSocket { client?: Newt | Olm; clientType?: ClientType; connectionId?: string; } interface TokenPayload { client: Newt | Olm; session: NewtSession | OlmSession; clientType: ClientType; } interface WSMessage { type: string; data: any; } interface HandlerResponse { message: WSMessage; broadcast?: boolean; excludeSender?: boolean; targetClientId?: string; } interface HandlerContext { message: WSMessage; senderWs: WebSocket; client: Newt | Olm | undefined; clientType: ClientType; sendToClient: (clientId: string, message: WSMessage) => Promise; broadcastToAllExcept: (message: WSMessage, excludeClientId?: string) => Promise; connectedClients: Map; } interface RedisMessage { type: 'direct' | 'broadcast'; targetClientId?: string; excludeClientId?: string; message: WSMessage; fromNodeId: string; } export type MessageHandler = (context: HandlerContext) => Promise; const router: Router = Router(); const wss: WebSocketServer = new WebSocketServer({ noServer: true }); // Generate unique node ID for this instance const NODE_ID = uuidv4(); const REDIS_CHANNEL = 'websocket_messages'; // Client tracking map (local to this node) let connectedClients: Map = new Map(); // Helper to get map key const getClientMapKey = (clientId: string) => clientId; // Redis keys (generalized) const getConnectionsKey = (clientId: string) => `ws:connections:${clientId}`; const getNodeConnectionsKey = (nodeId: string, clientId: string) => `ws:node:${nodeId}:${clientId}`; // Initialize Redis subscription for cross-node messaging const initializeRedisSubscription = async (): Promise => { if (!redisManager.isRedisEnabled()) return; await redisManager.subscribe(REDIS_CHANNEL, async (channel: string, message: string) => { try { const redisMessage: RedisMessage = JSON.parse(message); // Ignore messages from this node if (redisMessage.fromNodeId === NODE_ID) return; if (redisMessage.type === 'direct' && redisMessage.targetClientId) { // Send to specific client on this node await sendToClientLocal(redisMessage.targetClientId, redisMessage.message); } else if (redisMessage.type === 'broadcast') { // Broadcast to all clients on this node except excluded await broadcastToAllExceptLocal(redisMessage.message, redisMessage.excludeClientId); } } catch (error) { logger.error('Error processing Redis message:', error); } }); }; // Helper functions for client management const addClient = async (clientType: ClientType, clientId: string, ws: AuthenticatedWebSocket): Promise => { // Generate unique connection ID const connectionId = uuidv4(); ws.connectionId = connectionId; // Add to local tracking const mapKey = getClientMapKey(clientId); const existingClients = connectedClients.get(mapKey) || []; existingClients.push(ws); connectedClients.set(mapKey, existingClients); // Add to Redis tracking if enabled if (redisManager.isRedisEnabled()) { await redisManager.sadd(getConnectionsKey(clientId), NODE_ID); await redisManager.hset(getNodeConnectionsKey(NODE_ID, clientId), connectionId, Date.now().toString()); } logger.info(`Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}`); }; const removeClient = async (clientType: ClientType, clientId: string, ws: AuthenticatedWebSocket): Promise => { const mapKey = getClientMapKey(clientId); const existingClients = connectedClients.get(mapKey) || []; const updatedClients = existingClients.filter(client => client !== ws); if (updatedClients.length === 0) { connectedClients.delete(mapKey); if (redisManager.isRedisEnabled()) { await redisManager.srem(getConnectionsKey(clientId), NODE_ID); await redisManager.del(getNodeConnectionsKey(NODE_ID, clientId)); } logger.info(`All connections removed for ${clientType.toUpperCase()} ID: ${clientId}`); } else { connectedClients.set(mapKey, updatedClients); if (redisManager.isRedisEnabled() && ws.connectionId) { await redisManager.hdel(getNodeConnectionsKey(NODE_ID, clientId), ws.connectionId); } logger.info(`Connection removed - ${clientType.toUpperCase()} ID: ${clientId}, Remaining connections: ${updatedClients.length}`); } }; // Local message sending (within this node) const sendToClientLocal = async (clientId: string, message: WSMessage): Promise => { const mapKey = getClientMapKey(clientId); const clients = connectedClients.get(mapKey); if (!clients || clients.length === 0) { return false; } const messageString = JSON.stringify(message); clients.forEach(client => { if (client.readyState === WebSocket.OPEN) { client.send(messageString); } }); return true; }; const broadcastToAllExceptLocal = async (message: WSMessage, excludeClientId?: string): Promise => { connectedClients.forEach((clients, mapKey) => { const [type, id] = mapKey.split(":"); if (!(excludeClientId && id === excludeClientId)) { clients.forEach(client => { if (client.readyState === WebSocket.OPEN) { client.send(JSON.stringify(message)); } }); } }); }; // Cross-node message sending (via Redis) const sendToClient = async (clientId: string, message: WSMessage): Promise => { // Try to send locally first const localSent = await sendToClientLocal(clientId, message); // If Redis is enabled, also send via Redis pub/sub to other nodes if (redisManager.isRedisEnabled()) { const redisMessage: RedisMessage = { type: 'direct', targetClientId: clientId, message, fromNodeId: NODE_ID }; await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage)); } return localSent; }; const broadcastToAllExcept = async (message: WSMessage, excludeClientId?: string): Promise => { // Broadcast locally await broadcastToAllExceptLocal(message, excludeClientId); // If Redis is enabled, also broadcast via Redis pub/sub to other nodes if (redisManager.isRedisEnabled()) { const redisMessage: RedisMessage = { type: 'broadcast', excludeClientId, message, fromNodeId: NODE_ID }; await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage)); } }; // Check if a client has active connections across all nodes const hasActiveConnections = async (clientId: string): Promise => { if (!redisManager.isRedisEnabled()) { const mapKey = getClientMapKey(clientId); const clients = connectedClients.get(mapKey); return !!(clients && clients.length > 0); } const activeNodes = await redisManager.smembers(getConnectionsKey(clientId)); return activeNodes.length > 0; }; // Get all active nodes for a client const getActiveNodes = async (clientType: ClientType, clientId: string): Promise => { if (!redisManager.isRedisEnabled()) { const mapKey = getClientMapKey(clientId); const clients = connectedClients.get(mapKey); return (clients && clients.length > 0) ? [NODE_ID] : []; } return await redisManager.smembers(getConnectionsKey(clientId)); }; // Token verification middleware const verifyToken = async (token: string, clientType: ClientType): Promise => { try { if (clientType === 'newt') { const { session, newt } = await validateNewtSessionToken(token); if (!session || !newt) { return null; } const existingNewt = await db .select() .from(newts) .where(eq(newts.newtId, newt.newtId)); if (!existingNewt || !existingNewt[0]) { return null; } return { client: existingNewt[0], session, clientType }; } else { const { session, olm } = await validateOlmSessionToken(token); if (!session || !olm) { return null; } const existingOlm = await db .select() .from(olms) .where(eq(olms.olmId, olm.olmId)); if (!existingOlm || !existingOlm[0]) { return null; } return { client: existingOlm[0], session, clientType }; } } catch (error) { logger.error("Token verification failed:", error); return null; } }; const setupConnection = async (ws: AuthenticatedWebSocket, client: Newt | Olm, clientType: ClientType): Promise => { logger.info("Establishing websocket connection"); if (!client) { logger.error("Connection attempt without client"); return ws.terminate(); } ws.client = client; ws.clientType = clientType; // Add client to tracking const clientId = clientType === 'newt' ? (client as Newt).newtId : (client as Olm).olmId; await addClient(clientType, clientId, ws); ws.on("message", async (data) => { try { const message: WSMessage = JSON.parse(data.toString()); if (!message.type || typeof message.type !== "string") { throw new Error("Invalid message format: missing or invalid type"); } const handler = messageHandlers[message.type]; if (!handler) { throw new Error(`Unsupported message type: ${message.type}`); } const response = await handler({ message, senderWs: ws, client: ws.client, clientType: ws.clientType!, sendToClient, broadcastToAllExcept, connectedClients }); if (response) { if (response.broadcast) { await broadcastToAllExcept( response.message, response.excludeSender ? clientId : undefined ); } else if (response.targetClientId) { await sendToClient(response.targetClientId, response.message); } else { ws.send(JSON.stringify(response.message)); } } } catch (error) { logger.error("Message handling error:", error); ws.send(JSON.stringify({ type: "error", data: { message: error instanceof Error ? error.message : "Unknown error occurred", originalMessage: data.toString() } })); } }); ws.on("close", () => { removeClient(clientType, clientId, ws); logger.info(`Client disconnected - ${clientType.toUpperCase()} ID: ${clientId}`); }); ws.on("error", (error: Error) => { logger.error(`WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`, error); }); logger.info(`WebSocket connection established - ${clientType.toUpperCase()} ID: ${clientId}`); }; // Router endpoint router.get("/ws", (req: Request, res: Response) => { res.status(200).send("WebSocket endpoint"); }); // WebSocket upgrade handler const handleWSUpgrade = (server: HttpServer): void => { server.on("upgrade", async (request: WebSocketRequest, socket: Socket, head: Buffer) => { try { const url = new URL(request.url || '', `http://${request.headers.host}`); const token = url.searchParams.get('token') || request.headers["sec-websocket-protocol"] || ''; let clientType = url.searchParams.get('clientType') as ClientType; if (!clientType) { clientType = "newt"; } if (!token || !clientType || !['newt', 'olm'].includes(clientType)) { logger.warn("Unauthorized connection attempt: invalid token or client type..."); socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n"); socket.destroy(); return; } const tokenPayload = await verifyToken(token, clientType); if (!tokenPayload) { logger.warn("Unauthorized connection attempt: invalid token..."); socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n"); socket.destroy(); return; } wss.handleUpgrade(request, socket, head, (ws: AuthenticatedWebSocket) => { setupConnection(ws, tokenPayload.client, tokenPayload.clientType); }); } catch (error) { logger.error("WebSocket upgrade error:", error); socket.write("HTTP/1.1 500 Internal Server Error\r\n\r\n"); socket.destroy(); } }); }; // Initialize Redis subscription when the module is loaded if (redisManager.isRedisEnabled()) { initializeRedisSubscription().catch(error => { logger.error('Failed to initialize Redis subscription:', error); }); logger.info(`WebSocket handler initialized with Redis support - Node ID: ${NODE_ID}`); } else { logger.debug('WebSocket handler initialized in local mode (Redis disabled)'); } // Cleanup function for graceful shutdown const cleanup = async (): Promise => { try { // Close all WebSocket connections connectedClients.forEach((clients) => { clients.forEach(client => { if (client.readyState === WebSocket.OPEN) { client.terminate(); } }); }); // Clean up Redis tracking for this node if (redisManager.isRedisEnabled()) { const keys = await redisManager.getClient()?.keys(`ws:node:${NODE_ID}:*`) || []; if (keys.length > 0) { await Promise.all(keys.map(key => redisManager.del(key))); } } logger.info('WebSocket cleanup completed'); } catch (error) { logger.error('Error during WebSocket cleanup:', error); } }; // Handle process termination process.on('SIGTERM', cleanup); process.on('SIGINT', cleanup); export { router, handleWSUpgrade, sendToClient, broadcastToAllExcept, connectedClients, hasActiveConnections, getActiveNodes, NODE_ID, cleanup };