fosrl.pangolin/server/routers/ws.ts

445 lines
15 KiB
TypeScript
Raw Normal View History

2024-11-10 17:34:07 -05:00
import { Router, Request, Response } from "express";
import { Server as HttpServer } from "http";
import { WebSocket, WebSocketServer } from "ws";
import { IncomingMessage } from "http";
import { Socket } from "net";
2025-06-10 13:00:20 -04:00
import { Newt, newts, NewtSession, olms, Olm, OlmSession } from "@server/db";
2024-11-10 17:34:07 -05:00
import { eq } from "drizzle-orm";
2025-06-04 12:02:07 -04:00
import { db } from "@server/db";
2025-01-01 21:41:31 -05:00
import { validateNewtSessionToken } from "@server/auth/sessions/newt";
2025-02-21 10:13:41 -05:00
import { validateOlmSessionToken } from "@server/auth/sessions/olm";
2024-11-10 21:06:36 -05:00
import { messageHandlers } from "./messageHandlers";
2024-11-15 21:53:58 -05:00
import logger from "@server/logger";
2025-05-28 20:59:06 -04:00
import redisManager from "@server/db/redis";
import { v4 as uuidv4 } from "uuid";
2024-11-04 00:29:25 -05:00
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------

/** Incoming HTTP upgrade request, optionally carrying a client token. */
interface WebSocketRequest extends IncomingMessage {
    token?: string;
}

/** The two kinds of client that may open a WebSocket here. */
type ClientType = "newt" | "olm";

/** A WebSocket tagged with the identity established during the upgrade. */
interface AuthenticatedWebSocket extends WebSocket {
    /** The Newt or Olm record that owns this socket. */
    client?: Newt | Olm;
    /** Which of the two record types `client` holds. */
    clientType?: ClientType;
    /** Unique per-socket ID; also the field name in this node's Redis hash. */
    connectionId?: string;
}

/** Result of a successful token verification. */
interface TokenPayload {
    clientType: ClientType;
    client: Newt | Olm;
    session: NewtSession | OlmSession;
}

/**
 * Wire format of every WebSocket message in either direction.
 * `data` is left as `any`; its shape presumably depends on `type`.
 */
interface WSMessage {
    type: string;
    data: any;
}

/**
 * Optional reply returned by a message handler. Delivery precedence:
 * `broadcast` (optionally skipping the sender), then `targetClientId`,
 * otherwise the reply is sent back on the sender's socket.
 */
interface HandlerResponse {
    message: WSMessage;
    broadcast?: boolean;
    excludeSender?: boolean;
    targetClientId?: string;
}

/** Everything a message handler receives for one inbound message. */
interface HandlerContext {
    message: WSMessage;
    senderWs: WebSocket;
    client: Newt | Olm | undefined;
    clientType: ClientType;
    sendToClient: (clientId: string, message: WSMessage) => Promise<boolean>;
    broadcastToAllExcept: (message: WSMessage, excludeClientId?: string) => Promise<void>;
    connectedClients: Map<string, WebSocket[]>;
}

/** Envelope published on the shared Redis channel for cross-node delivery. */
interface RedisMessage {
    type: "direct" | "broadcast";
    /** Recipient of a `direct` message. */
    targetClientId?: string;
    /** Client skipped by a `broadcast` message. */
    excludeClientId?: string;
    message: WSMessage;
    /** Publishing node; receivers ignore envelopes they published themselves. */
    fromNodeId: string;
}

/** Signature of a message handler: invoked with a HandlerContext per inbound message. */
export type MessageHandler = (context: HandlerContext) => Promise<HandlerResponse | void>;
/** Express router exposing the plain-HTTP `/ws` endpoint. */
const router: Router = Router();

/**
 * WebSocket server in `noServer` mode: upgrades are handed to it explicitly
 * from `handleWSUpgrade` after token verification.
 */
const wss: WebSocketServer = new WebSocketServer({ noServer: true });

// Unique ID for this process, generated once at module load. It tags Redis
// pub/sub envelopes and namespaces this node's Redis keys.
const NODE_ID = uuidv4();

// Redis pub/sub channel shared by all nodes for cross-node delivery.
const REDIS_CHANNEL = 'websocket_messages';

// Sockets connected to THIS node, keyed by getClientMapKey(clientId); one
// client may hold several sockets at once. Declared `const`: the binding is
// exported and never reassigned (only the map's contents change), so it
// should not be an exported mutable `let`.
const connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map();
// ---------------------------------------------------------------------------
// Key helpers
// ---------------------------------------------------------------------------

/** Key into `connectedClients` for a client (currently the raw client ID). */
const getClientMapKey = (clientId: string): string => clientId;

/** Redis set holding the IDs of nodes that have at least one socket for `clientId`. */
const getConnectionsKey = (clientId: string): string => 'ws:connections:' + clientId;

/** Redis hash (connectionId -> connect timestamp) of `clientId`'s sockets on `nodeId`. */
const getNodeConnectionsKey = (nodeId: string, clientId: string): string =>
    'ws:node:' + nodeId + ':' + clientId;
/**
 * Subscribes this node to the shared Redis channel so that envelopes
 * published by other nodes reach sockets connected here. No-op when Redis
 * is disabled.
 *
 * Each payload is a JSON-encoded RedisMessage. Envelopes published by this
 * node are ignored; a `direct` envelope without a target is ignored; parse or
 * delivery errors are logged and never rethrown.
 */
const initializeRedisSubscription = async (): Promise<void> => {
    if (!redisManager.isRedisEnabled()) return;

    const onChannelMessage = async (channel: string, message: string) => {
        try {
            const envelope: RedisMessage = JSON.parse(message);

            // Already delivered locally by the publishing call on this node.
            if (envelope.fromNodeId === NODE_ID) {
                return;
            }

            switch (envelope.type) {
                case 'direct':
                    if (envelope.targetClientId) {
                        await sendToClientLocal(envelope.targetClientId, envelope.message);
                    }
                    break;
                case 'broadcast':
                    await broadcastToAllExceptLocal(envelope.message, envelope.excludeClientId);
                    break;
            }
        } catch (error) {
            logger.error('Error processing Redis message:', error);
        }
    };

    await redisManager.subscribe(REDIS_CHANNEL, onChannelMessage);
};
// ---------------------------------------------------------------------------
// Client tracking
// ---------------------------------------------------------------------------

/**
 * Registers `ws` as one of `clientId`'s live connections on this node.
 *
 * Assigns the socket a fresh `connectionId` and appends it to the client's
 * local socket list. That local update happens before any await. When Redis
 * is enabled, it then adds this node to the client's node set and records
 * the connection (with a millisecond timestamp) in this node's hash.
 */
const addClient = async (clientType: ClientType, clientId: string, ws: AuthenticatedWebSocket): Promise<void> => {
    const connectionId: string = uuidv4();
    ws.connectionId = connectionId;

    const key = getClientMapKey(clientId);
    const sockets = connectedClients.get(key) ?? [];
    sockets.push(ws);
    connectedClients.set(key, sockets);

    if (redisManager.isRedisEnabled()) {
        await redisManager.sadd(getConnectionsKey(clientId), NODE_ID);
        await redisManager.hset(getNodeConnectionsKey(NODE_ID, clientId), connectionId, String(Date.now()));
    }

    const total = sockets.length;
    logger.info(`Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${total}`);
};
/**
 * Removes `ws` from `clientId`'s tracked connections on this node.
 *
 * If other sockets for the client remain, only this socket's field is
 * dropped from the node's Redis hash (when Redis is enabled and the socket
 * has a `connectionId`). If it was the last one, the local entry is deleted
 * and, with Redis enabled, this node leaves the client's node set and the
 * node's hash is deleted.
 */
const removeClient = async (clientType: ClientType, clientId: string, ws: AuthenticatedWebSocket): Promise<void> => {
    const key = getClientMapKey(clientId);
    const remaining = (connectedClients.get(key) ?? []).filter((socket) => socket !== ws);
    const label = clientType.toUpperCase();

    if (remaining.length > 0) {
        connectedClients.set(key, remaining);
        if (redisManager.isRedisEnabled() && ws.connectionId) {
            await redisManager.hdel(getNodeConnectionsKey(NODE_ID, clientId), ws.connectionId);
        }
        logger.info(`Connection removed - ${label} ID: ${clientId}, Remaining connections: ${remaining.length}`);
        return;
    }

    connectedClients.delete(key);
    if (redisManager.isRedisEnabled()) {
        await redisManager.srem(getConnectionsKey(clientId), NODE_ID);
        await redisManager.del(getNodeConnectionsKey(NODE_ID, clientId));
    }
    logger.info(`All connections removed for ${label} ID: ${clientId}`);
};
// ---------------------------------------------------------------------------
// Local delivery (this node only)
// ---------------------------------------------------------------------------

/**
 * Sends `message` to every OPEN socket that `clientId` has on this node.
 *
 * @returns `false` when this node tracks no sockets for the client; `true`
 *   otherwise, even if none of those sockets is currently OPEN.
 */
const sendToClientLocal = async (clientId: string, message: WSMessage): Promise<boolean> => {
    const sockets = connectedClients.get(getClientMapKey(clientId)) ?? [];
    if (sockets.length === 0) {
        return false;
    }

    const payload = JSON.stringify(message);
    for (const socket of sockets) {
        if (socket.readyState === WebSocket.OPEN) {
            socket.send(payload);
        }
    }
    return true;
};
/**
 * Sends `message` to every OPEN socket tracked on this node, skipping all
 * sockets belonging to `excludeClientId` when one is given.
 *
 * The exclusion compares `connectedClients` keys against
 * `getClientMapKey(excludeClientId)`. The previous version split each key
 * on ":" and compared the second segment, which appears to assume an older
 * `type:id` key format. With plain client-ID keys that segment is
 * `undefined`, so nothing was ever excluded and a sender asking for
 * `excludeSender` still received its own broadcast.
 */
const broadcastToAllExceptLocal = async (message: WSMessage, excludeClientId?: string): Promise<void> => {
    const excludedKey = excludeClientId ? getClientMapKey(excludeClientId) : undefined;
    // Serialise once instead of once per socket.
    const payload = JSON.stringify(message);

    connectedClients.forEach((clients, mapKey) => {
        if (excludedKey !== undefined && mapKey === excludedKey) {
            return;
        }
        for (const client of clients) {
            if (client.readyState === WebSocket.OPEN) {
                client.send(payload);
            }
        }
    });
};
// ---------------------------------------------------------------------------
// Cluster-wide delivery (local + Redis pub/sub)
// ---------------------------------------------------------------------------

/**
 * Sends `message` to `clientId` wherever it is connected.
 *
 * Delivers to this node's sockets first, then, when Redis is enabled,
 * publishes a `direct` envelope so other nodes can deliver to theirs.
 *
 * @returns Whether this node had sockets tracked for the client; delivery on
 *   other nodes is not reflected in the result.
 */
const sendToClient = async (clientId: string, message: WSMessage): Promise<boolean> => {
    const deliveredLocally = await sendToClientLocal(clientId, message);

    if (!redisManager.isRedisEnabled()) {
        return deliveredLocally;
    }

    const envelope: RedisMessage = {
        type: 'direct',
        targetClientId: clientId,
        message,
        fromNodeId: NODE_ID
    };
    await redisManager.publish(REDIS_CHANNEL, JSON.stringify(envelope));

    return deliveredLocally;
};
/**
 * Broadcasts `message` to every connected client except `excludeClientId`.
 *
 * Delivers to this node's sockets first, then, when Redis is enabled,
 * publishes a `broadcast` envelope so every other node does the same.
 */
const broadcastToAllExcept = async (message: WSMessage, excludeClientId?: string): Promise<void> => {
    await broadcastToAllExceptLocal(message, excludeClientId);

    if (!redisManager.isRedisEnabled()) {
        return;
    }

    const envelope: RedisMessage = {
        type: 'broadcast',
        excludeClientId,
        message,
        fromNodeId: NODE_ID
    };
    await redisManager.publish(REDIS_CHANNEL, JSON.stringify(envelope));
};
/**
 * Reports whether `clientId` has at least one live connection.
 *
 * With Redis enabled this consults the client's node set (cluster-wide);
 * otherwise it only checks sockets tracked on this node.
 */
const hasActiveConnections = async (clientId: string): Promise<boolean> => {
    if (redisManager.isRedisEnabled()) {
        const nodes = await redisManager.smembers(getConnectionsKey(clientId));
        return nodes.length > 0;
    }

    const localSockets = connectedClients.get(getClientMapKey(clientId));
    return localSockets !== undefined && localSockets.length > 0;
};
/**
 * Lists the IDs of nodes holding at least one socket for `clientId`.
 *
 * With Redis enabled this is the client's node set; otherwise it is
 * `[NODE_ID]` or `[]` depending on whether this node tracks the client.
 * `clientType` is currently unused.
 */
const getActiveNodes = async (clientType: ClientType, clientId: string): Promise<string[]> => {
    if (redisManager.isRedisEnabled()) {
        return redisManager.smembers(getConnectionsKey(clientId));
    }

    const localSockets = connectedClients.get(getClientMapKey(clientId)) ?? [];
    return localSockets.length > 0 ? [NODE_ID] : [];
};
/**
 * Resolves a session token to its client record.
 *
 * Validates the token with the Newt or Olm session validator (any
 * `clientType` other than "newt" is treated as "olm"), then re-reads the
 * client row from the database.
 *
 * @returns The client, its session and `clientType`, or `null` when the
 *   token is invalid, the client row is missing, or anything throws (the
 *   error is logged).
 */
const verifyToken = async (token: string, clientType: ClientType): Promise<TokenPayload | null> => {
    try {
        if (clientType !== 'newt') {
            const { session, olm } = await validateOlmSessionToken(token);
            if (!session || !olm) {
                return null;
            }
            const olmRows = await db
                .select()
                .from(olms)
                .where(eq(olms.olmId, olm.olmId));
            const olmRow = olmRows?.[0];
            return olmRow ? { client: olmRow, session, clientType } : null;
        }

        const { session, newt } = await validateNewtSessionToken(token);
        if (!session || !newt) {
            return null;
        }
        const newtRows = await db
            .select()
            .from(newts)
            .where(eq(newts.newtId, newt.newtId));
        const newtRow = newtRows?.[0];
        return newtRow ? { client: newtRow, session, clientType } : null;
    } catch (error) {
        logger.error("Token verification failed:", error);
        return null;
    }
};
/**
 * Finishes setting up an upgraded, authenticated WebSocket.
 *
 * Tags the socket with its client record and type, registers it in
 * connection tracking via `addClient`, and attaches `message`, `close` and
 * `error` listeners. Each inbound message is parsed as JSON, dispatched to
 * `messageHandlers[message.type]`, and the handler's response (if any) is
 * broadcast, sent to `targetClientId`, or sent back on this socket. Parse
 * or handler failures are reported to the sender as an `error` message.
 *
 * Listeners are attached BEFORE the Redis part of registration is awaited.
 * Previously they were attached after `await addClient(...)`, so a message
 * or a close arriving during that await was silently dropped. A dropped
 * close left the socket tracked (locally and in Redis) forever.
 *
 * @param ws - Socket produced by `wss.handleUpgrade`.
 * @param client - Newt or Olm record returned by token verification.
 * @param clientType - Which kind of record `client` is.
 */
const setupConnection = async (ws: AuthenticatedWebSocket, client: Newt | Olm, clientType: ClientType): Promise<void> => {
    logger.info("Establishing websocket connection");
    if (!client) {
        logger.error("Connection attempt without client");
        return ws.terminate();
    }

    ws.client = client;
    ws.clientType = clientType;

    const clientId = clientType === 'newt' ? (client as Newt).newtId : (client as Olm).olmId;

    // Start registration now. addClient updates the local map synchronously
    // and only awaits its Redis writes, so the listeners below are attached
    // before any socket event can be delivered.
    const registration = addClient(clientType, clientId, ws);

    ws.on("message", async (data) => {
        try {
            const message: WSMessage = JSON.parse(data.toString());

            if (!message.type || typeof message.type !== "string") {
                throw new Error("Invalid message format: missing or invalid type");
            }

            const handler = messageHandlers[message.type];
            if (!handler) {
                throw new Error(`Unsupported message type: ${message.type}`);
            }

            const response = await handler({
                message,
                senderWs: ws,
                client: ws.client,
                clientType,
                sendToClient,
                broadcastToAllExcept,
                connectedClients
            });

            if (response) {
                if (response.broadcast) {
                    await broadcastToAllExcept(
                        response.message,
                        response.excludeSender ? clientId : undefined
                    );
                } else if (response.targetClientId) {
                    await sendToClient(response.targetClientId, response.message);
                } else {
                    ws.send(JSON.stringify(response.message));
                }
            }
        } catch (error) {
            logger.error("Message handling error:", error);
            ws.send(JSON.stringify({
                type: "error",
                data: {
                    message: error instanceof Error ? error.message : "Unknown error occurred",
                    originalMessage: data.toString()
                }
            }));
        }
    });

    ws.on("close", () => {
        // Remove only after registration has settled (success or failure) so
        // a late Redis write from addClient cannot re-create tracking for a
        // socket that is already gone. Failures are logged, not left as
        // unhandled rejections.
        registration
            .catch(() => undefined)
            .then(() => removeClient(clientType, clientId, ws))
            .catch((error) => {
                logger.error(`Failed to remove ${clientType.toUpperCase()} ID ${clientId} from tracking:`, error);
            });
        logger.info(`Client disconnected - ${clientType.toUpperCase()} ID: ${clientId}`);
    });

    ws.on("error", (error: Error) => {
        logger.error(`WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`, error);
    });

    try {
        await registration;
    } catch (error) {
        // Local tracking is already in place; only the Redis bookkeeping
        // failed. Log rather than reject: the caller does not await this
        // promise, and an unhandled rejection can terminate the process under
        // Node's default unhandled-rejection policy.
        logger.error(`Failed to register ${clientType.toUpperCase()} ID ${clientId} in connection tracking:`, error);
    }

    logger.info(`WebSocket connection established - ${clientType.toUpperCase()} ID: ${clientId}`);
};
// ---------------------------------------------------------------------------
// HTTP endpoint
// ---------------------------------------------------------------------------

// A plain GET on /ws (no upgrade) answers 200 with a short marker string.
router.get("/ws", (_req: Request, res: Response) => {
    res.status(200).send("WebSocket endpoint");
});
/**
 * Installs the HTTP `upgrade` handler that authenticates WebSocket clients.
 *
 * The token is read from the `token` query parameter, falling back to the
 * `Sec-WebSocket-Protocol` header. `clientType` comes from the query string
 * and defaults to "newt". Requests with no token or an unknown client type,
 * or whose token fails verification, get a raw 401 and the socket is
 * destroyed. Unexpected errors get a raw 500.
 *
 * @param server - The HTTP server whose upgrades should be handled here.
 */
const handleWSUpgrade = (server: HttpServer): void => {
    server.on("upgrade", async (request: WebSocketRequest, socket: Socket, head: Buffer) => {
        // Follow the `ws` noServer authentication pattern: keep an 'error'
        // listener on the raw socket while the async token check runs, so a
        // client reset during verification is logged instead of surfacing as
        // an unhandled 'error' event. It is removed just before handing the
        // socket to wss.handleUpgrade, which installs its own handling.
        const onSocketError = (error: Error) => {
            logger.error("WebSocket upgrade socket error:", error);
        };
        socket.on("error", onSocketError);

        try {
            const url = new URL(request.url || '', `http://${request.headers.host}`);
            const token = url.searchParams.get('token') || request.headers["sec-websocket-protocol"] || '';
            let clientType = url.searchParams.get('clientType') as ClientType;
            if (!clientType) {
                clientType = "newt";
            }

            if (!token || !clientType || !['newt', 'olm'].includes(clientType)) {
                logger.warn("Unauthorized connection attempt: invalid token or client type...");
                socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n");
                socket.destroy();
                return;
            }

            const tokenPayload = await verifyToken(token, clientType);
            if (!tokenPayload) {
                logger.warn("Unauthorized connection attempt: invalid token...");
                socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n");
                socket.destroy();
                return;
            }

            socket.removeListener("error", onSocketError);
            wss.handleUpgrade(request, socket, head, (ws: AuthenticatedWebSocket) => {
                // setupConnection is not awaited here; surface a rejection in
                // the log instead of leaving it unhandled.
                setupConnection(ws, tokenPayload.client, tokenPayload.clientType).catch((error) => {
                    logger.error("WebSocket connection setup error:", error);
                });
            });
        } catch (error) {
            logger.error("WebSocket upgrade error:", error);
            socket.write("HTTP/1.1 500 Internal Server Error\r\n\r\n");
            socket.destroy();
        }
    });
};
// At module load: subscribe to cross-node traffic when Redis is enabled.
// A failed subscription is logged rather than thrown.
if (!redisManager.isRedisEnabled()) {
    logger.debug('WebSocket handler initialized in local mode (Redis disabled)');
} else {
    initializeRedisSubscription().catch((error) => {
        logger.error('Failed to initialize Redis subscription:', error);
    });
    logger.info(`WebSocket handler initialized with Redis support - Node ID: ${NODE_ID}`);
}
/**
 * Best-effort shutdown hook for this node.
 *
 * Terminates every OPEN socket tracked here. When Redis is enabled, it also:
 *   - removes this node from each tracked client's `ws:connections:<id>` set;
 *   - deletes every `ws:node:<NODE_ID>:*` hash.
 *
 * The first step is the fix: previously only the hashes were deleted, so
 * NODE_ID stayed in the client node sets and `hasActiveConnections` /
 * `getActiveNodes` kept reporting this node after shutdown.
 *
 * Errors are logged rather than thrown so shutdown can proceed.
 */
const cleanup = async (): Promise<void> => {
    try {
        // Snapshot tracked IDs before terminating: socket close handlers
        // mutate connectedClients afterwards. Map keys are client IDs
        // (getClientMapKey is the identity mapping).
        const trackedClientIds = Array.from(connectedClients.keys());

        // Close all WebSocket connections
        connectedClients.forEach((clients) => {
            clients.forEach(client => {
                if (client.readyState === WebSocket.OPEN) {
                    client.terminate();
                }
            });
        });

        // Clean up Redis tracking for this node
        if (redisManager.isRedisEnabled()) {
            await Promise.all(
                trackedClientIds.map(clientId => redisManager.srem(getConnectionsKey(clientId), NODE_ID))
            );
            const keys = await redisManager.getClient()?.keys(`ws:node:${NODE_ID}:*`) || [];
            if (keys.length > 0) {
                await Promise.all(keys.map(key => redisManager.del(key)));
            }
        }
        logger.info('WebSocket cleanup completed');
    } catch (error) {
        logger.error('Error during WebSocket cleanup:', error);
    }
};

// Handle process termination.
// NOTE(review): per the Node.js docs, installing a SIGTERM/SIGINT listener
// removes Node's default exit-on-signal behaviour; confirm that another
// shutdown handler still exits the process after this cleanup runs.
process.on('SIGTERM', cleanup);
process.on('SIGINT', cleanup);
// Public surface of this module.
export {
    // HTTP wiring
    router,
    handleWSUpgrade,
    // Message delivery
    sendToClient,
    broadcastToAllExcept,
    // Connection state and introspection
    connectedClients,
    hasActiveConnections,
    getActiveNodes,
    NODE_ID,
    // Lifecycle
    cleanup
};