diff --git a/server/routers/newt/handleSocketMessages.ts b/server/routers/newt/handleSocketMessages.ts index 0a217c52..01b7be60 100644 --- a/server/routers/newt/handleSocketMessages.ts +++ b/server/routers/newt/handleSocketMessages.ts @@ -1,9 +1,11 @@ import { MessageHandler } from "../ws"; import logger from "@server/logger"; import { dockerSocketCache } from "./dockerSocket"; +import { Newt } from "@server/db"; export const handleDockerStatusMessage: MessageHandler = async (context) => { - const { message, newt } = context; + const { message, client, sendToClient } = context; + const newt = client as Newt; logger.info("Handling Docker socket check response"); @@ -33,7 +35,8 @@ export const handleDockerStatusMessage: MessageHandler = async (context) => { export const handleDockerContainersMessage: MessageHandler = async ( context ) => { - const { message, newt } = context; + const { message, client, sendToClient } = context; + const newt = client as Newt; logger.info("Handling Docker containers response"); diff --git a/server/routers/newt/index.ts b/server/routers/newt/index.ts index e51ef93d..08f047e3 100644 --- a/server/routers/newt/index.ts +++ b/server/routers/newt/index.ts @@ -4,3 +4,4 @@ export * from "./handleNewtRegisterMessage"; export * from "./handleReceiveBandwidthMessage"; export * from "./handleGetConfigMessage"; export * from "./handleSocketMessages"; +export * from "./handleNewtPingRequestMessage"; \ No newline at end of file diff --git a/server/routers/ws.ts b/server/routers/ws.ts index 709da400..f6b5b99d 100644 --- a/server/routers/ws.ts +++ b/server/routers/ws.ts @@ -49,16 +49,14 @@ interface HandlerContext { senderWs: WebSocket; client: Newt | Olm | undefined; clientType: ClientType; - sendToClient: (clientType: ClientType, clientId: string, message: WSMessage) => Promise; - broadcastToAllExcept: (message: WSMessage, excludeClientType?: ClientType, excludeClientId?: string) => Promise; + sendToClient: (clientId: string, message: WSMessage) => Promise; + broadcastToAllExcept: (message: WSMessage, excludeClientId?: string) => Promise; connectedClients: Map; } interface RedisMessage { type: 'direct' | 'broadcast'; - targetClientType?: ClientType; targetClientId?: string; - excludeClientType?: ClientType; excludeClientId?: string; message: WSMessage; fromNodeId: string; @@ -76,11 +74,11 @@ const REDIS_CHANNEL = 'websocket_messages'; // Client tracking map (local to this node) let connectedClients: Map = new Map(); // Helper to get map key -const getClientMapKey = (clientType: ClientType, clientId: string) => `${clientType}:${clientId}`; +const getClientMapKey = (clientId: string) => clientId; // Redis keys (generalized) -const getConnectionsKey = (clientType: ClientType, clientId: string) => `ws:connections:${clientType}:${clientId}`; -const getNodeConnectionsKey = (nodeId: string, clientType: ClientType, clientId: string) => `ws:node:${nodeId}:${clientType}:${clientId}`; +const getConnectionsKey = (clientId: string) => `ws:connections:${clientId}`; +const getNodeConnectionsKey = (nodeId: string, clientId: string) => `ws:node:${nodeId}:${clientId}`; // Initialize Redis subscription for cross-node messaging const initializeRedisSubscription = async (): Promise => { @@ -93,12 +91,12 @@ const initializeRedisSubscription = async (): Promise => { // Ignore messages from this node if (redisMessage.fromNodeId === NODE_ID) return; - if (redisMessage.type === 'direct' && redisMessage.targetClientType && redisMessage.targetClientId) { + if (redisMessage.type === 'direct' && redisMessage.targetClientId) { // Send to specific client on this node - await sendToClientLocal(redisMessage.targetClientType, redisMessage.targetClientId, redisMessage.message); + await sendToClientLocal(redisMessage.targetClientId, redisMessage.message); } else if (redisMessage.type === 'broadcast') { // Broadcast to all clients on this node except excluded - await broadcastToAllExceptLocal(redisMessage.message, redisMessage.excludeClientType, redisMessage.excludeClientId); + await broadcastToAllExceptLocal(redisMessage.message, redisMessage.excludeClientId); } } catch (error) { logger.error('Error processing Redis message:', error); @@ -113,30 +111,30 @@ const addClient = async (clientType: ClientType, clientId: string, ws: Authentic ws.connectionId = connectionId; // Add to local tracking - const mapKey = getClientMapKey(clientType, clientId); + const mapKey = getClientMapKey(clientId); const existingClients = connectedClients.get(mapKey) || []; existingClients.push(ws); connectedClients.set(mapKey, existingClients); // Add to Redis tracking if enabled if (redisManager.isRedisEnabled()) { - await redisManager.sadd(getConnectionsKey(clientType, clientId), NODE_ID); - await redisManager.hset(getNodeConnectionsKey(NODE_ID, clientType, clientId), connectionId, Date.now().toString()); + await redisManager.sadd(getConnectionsKey(clientId), NODE_ID); + await redisManager.hset(getNodeConnectionsKey(NODE_ID, clientId), connectionId, Date.now().toString()); } logger.info(`Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}`); }; const removeClient = async (clientType: ClientType, clientId: string, ws: AuthenticatedWebSocket): Promise => { - const mapKey = getClientMapKey(clientType, clientId); + const mapKey = getClientMapKey(clientId); const existingClients = connectedClients.get(mapKey) || []; const updatedClients = existingClients.filter(client => client !== ws); if (updatedClients.length === 0) { connectedClients.delete(mapKey); if (redisManager.isRedisEnabled()) { - await redisManager.srem(getConnectionsKey(clientType, clientId), NODE_ID); - await redisManager.del(getNodeConnectionsKey(NODE_ID, clientType, clientId)); + await redisManager.srem(getConnectionsKey(clientId), NODE_ID); + await redisManager.del(getNodeConnectionsKey(NODE_ID, clientId)); } logger.info(`All connections removed for ${clientType.toUpperCase()} ID: ${clientId}`); @@ -144,7 +142,7 @@ const removeClient = async (clientType: ClientType, clientId: string, ws: Authen connectedClients.set(mapKey, updatedClients); if (redisManager.isRedisEnabled() && ws.connectionId) { - await redisManager.hdel(getNodeConnectionsKey(NODE_ID, clientType, clientId), ws.connectionId); + await redisManager.hdel(getNodeConnectionsKey(NODE_ID, clientId), ws.connectionId); } logger.info(`Connection removed - ${clientType.toUpperCase()} ID: ${clientId}, Remaining connections: ${updatedClients.length}`); @@ -152,8 +150,8 @@ const removeClient = async (clientType: ClientType, clientId: string, ws: Authen }; // Local message sending (within this node) -const sendToClientLocal = async (clientType: ClientType, clientId: string, message: WSMessage): Promise => { - const mapKey = getClientMapKey(clientType, clientId); +const sendToClientLocal = async (clientId: string, message: WSMessage): Promise => { + const mapKey = getClientMapKey(clientId); const clients = connectedClients.get(mapKey); if (!clients || clients.length === 0) { return false; @@ -167,10 +165,10 @@ const sendToClientLocal = async (clientType: ClientType, clientId: string, messa return true; }; -const broadcastToAllExceptLocal = async (message: WSMessage, excludeClientType?: ClientType, excludeClientId?: string): Promise => { +const broadcastToAllExceptLocal = async (message: WSMessage, excludeClientId?: string): Promise => { connectedClients.forEach((clients, mapKey) => { const [type, id] = mapKey.split(":"); - if (!(excludeClientType && excludeClientId && type === excludeClientType && id === excludeClientId)) { + if (!(excludeClientId && id === excludeClientId)) { clients.forEach(client => { if (client.readyState === WebSocket.OPEN) { client.send(JSON.stringify(message)); @@ -181,15 +179,14 @@ const broadcastToAllExceptLocal = async (message: WSMessage, excludeClientType?: }; // Cross-node message sending (via Redis) -const sendToClient = async (clientType: ClientType, clientId: string, message: WSMessage): Promise => { +const sendToClient = async (clientId: string, message: WSMessage): Promise => { // Try to send locally first - const localSent = await sendToClientLocal(clientType, clientId, message); + const localSent = await sendToClientLocal(clientId, message); // If Redis is enabled, also send via Redis pub/sub to other nodes if (redisManager.isRedisEnabled()) { const redisMessage: RedisMessage = { type: 'direct', - targetClientType: clientType, targetClientId: clientId, message, fromNodeId: NODE_ID @@ -201,15 +198,14 @@ const sendToClient = async (clientType: ClientType, clientId: string, message: W return localSent; }; -const broadcastToAllExcept = async (message: WSMessage, excludeClientType?: ClientType, excludeClientId?: string): Promise => { +const broadcastToAllExcept = async (message: WSMessage, excludeClientId?: string): Promise => { // Broadcast locally - await broadcastToAllExceptLocal(message, excludeClientType, excludeClientId); + await broadcastToAllExceptLocal(message, excludeClientId); // If Redis is enabled, also broadcast via Redis pub/sub to other nodes if (redisManager.isRedisEnabled()) { const redisMessage: RedisMessage = { type: 'broadcast', - excludeClientType, excludeClientId, message, fromNodeId: NODE_ID @@ -220,26 +216,26 @@ const broadcastToAllExcept = async (message: WSMessage, excludeClientType?: Clie }; // Check if a client has active connections across all nodes -const hasActiveConnections = async (clientType: ClientType, clientId: string): Promise => { +const hasActiveConnections = async (clientId: string): Promise => { if (!redisManager.isRedisEnabled()) { - const mapKey = getClientMapKey(clientType, clientId); + const mapKey = getClientMapKey(clientId); const clients = connectedClients.get(mapKey); return !!(clients && clients.length > 0); } - const activeNodes = await redisManager.smembers(getConnectionsKey(clientType, clientId)); + const activeNodes = await redisManager.smembers(getConnectionsKey(clientId)); return activeNodes.length > 0; }; // Get all active nodes for a client const getActiveNodes = async (clientType: ClientType, clientId: string): Promise => { if (!redisManager.isRedisEnabled()) { - const mapKey = getClientMapKey(clientType, clientId); + const mapKey = getClientMapKey(clientId); const clients = connectedClients.get(mapKey); return (clients && clients.length > 0) ? [NODE_ID] : []; } - return await redisManager.smembers(getConnectionsKey(clientType, clientId)); + return await redisManager.smembers(getConnectionsKey(clientId)); }; // Token verification middleware @@ -320,11 +316,10 @@ const setupConnection = async (ws: AuthenticatedWebSocket, client: Newt | Olm, c if (response.broadcast) { await broadcastToAllExcept( response.message, - response.excludeSender ? clientType : undefined, response.excludeSender ? clientId : undefined ); } else if (response.targetClientId) { - await sendToClient(clientType, response.targetClientId, response.message); + await sendToClient(response.targetClientId, response.message); } else { ws.send(JSON.stringify(response.message)); }