From e112fcba29ab8b3b7cfc59ae07ed9d39da7b3f65 Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 21 Feb 2025 10:13:41 -0500 Subject: [PATCH] Move things around; rename to olm --- server/auth/sessions/olm.ts | 72 ++++++++ server/db/schema.ts | 12 +- server/routers/client/index.ts | 1 - server/routers/external.ts | 4 +- server/routers/messageHandlers.ts | 6 +- ...essage.ts => handleNewtRegisterMessage.ts} | 6 +- server/routers/olm/createOlm.ts | 106 ++++++++++++ server/routers/olm/getToken.ts | 115 +++++++++++++ server/routers/olm/handleGetConfigMessage.ts | 147 ++++++++++++++++ .../routers/olm/handleOlmRegisterMessage.ts | 93 ++++++++++ server/routers/olm/index.ts | 1 + .../pickOlmDefaults.ts} | 10 +- server/routers/ws.ts | 162 ++++++++++-------- 13 files changed, 642 insertions(+), 93 deletions(-) create mode 100644 server/auth/sessions/olm.ts delete mode 100644 server/routers/client/index.ts rename server/routers/newt/{handleRegisterMessage.ts => handleNewtRegisterMessage.ts} (96%) create mode 100644 server/routers/olm/createOlm.ts create mode 100644 server/routers/olm/getToken.ts create mode 100644 server/routers/olm/handleGetConfigMessage.ts create mode 100644 server/routers/olm/handleOlmRegisterMessage.ts create mode 100644 server/routers/olm/index.ts rename server/routers/{client/pickClientDefaults.ts => olm/pickOlmDefaults.ts} (94%) diff --git a/server/auth/sessions/olm.ts b/server/auth/sessions/olm.ts new file mode 100644 index 00000000..8d24c16f --- /dev/null +++ b/server/auth/sessions/olm.ts @@ -0,0 +1,72 @@ +import { + encodeHexLowerCase, +} from "@oslojs/encoding"; +import { sha256 } from "@oslojs/crypto/sha2"; +import { Olm, olms, olmSessions, OlmSession } from "@server/db/schema"; +import db from "@server/db"; +import { eq } from "drizzle-orm"; + +export const EXPIRES = 1000 * 60 * 60 * 24 * 30; + +export async function createOlmSession( + token: string, + olmId: string, +): Promise { + const sessionId = encodeHexLowerCase( + sha256(new TextEncoder().encode(token)), + ); + const session: OlmSession = { + sessionId: sessionId, + olmId, + expiresAt: new Date(Date.now() + EXPIRES).getTime(), + }; + await db.insert(olmSessions).values(session); + return session; +} + +export async function validateOlmSessionToken( + token: string, +): Promise { + const sessionId = encodeHexLowerCase( + sha256(new TextEncoder().encode(token)), + ); + const result = await db + .select({ olm: olms, session: olmSessions }) + .from(olmSessions) + .innerJoin(olms, eq(olmSessions.olmId, olms.olmId)) + .where(eq(olmSessions.sessionId, sessionId)); + if (result.length < 1) { + return { session: null, olm: null }; + } + const { olm, session } = result[0]; + if (Date.now() >= session.expiresAt) { + await db + .delete(olmSessions) + .where(eq(olmSessions.sessionId, session.sessionId)); + return { session: null, olm: null }; + } + if (Date.now() >= session.expiresAt - (EXPIRES / 2)) { + session.expiresAt = new Date( + Date.now() + EXPIRES, + ).getTime(); + await db + .update(olmSessions) + .set({ + expiresAt: session.expiresAt, + }) + .where(eq(olmSessions.sessionId, session.sessionId)); + } + return { session, olm }; +} + +export async function invalidateOlmSession(sessionId: string): Promise { + await db.delete(olmSessions).where(eq(olmSessions.sessionId, sessionId)); +} + +export async function invalidateAllOlmSessions(olmId: string): Promise { + await db.delete(olmSessions).where(eq(olmSessions.olmId, olmId)); +} + +export type SessionValidationResult = + | { session: OlmSession; olm: Olm } + | { session: null; olm: null }; diff --git a/server/db/schema.ts b/server/db/schema.ts index 70817573..9564fdff 100644 --- a/server/db/schema.ts +++ b/server/db/schema.ts @@ -114,8 +114,8 @@ export const newts = sqliteTable("newt", { }) }); -export const clients = sqliteTable("clients", { - clientId: text("id").primaryKey(), +export const olms = sqliteTable("olms", { + olmId: text("id").primaryKey(), secretHash: text("secretHash").notNull(), dateCreated: text("dateCreated").notNull(), siteId: integer("siteId").references(() => sites.siteId, { @@ -156,9 +156,9 @@ export const newtSessions = sqliteTable("newtSession", { expiresAt: integer("expiresAt").notNull() }); -export const clientSessions = sqliteTable("clientSession", { +export const olmSessions = sqliteTable("clientSession", { sessionId: text("id").primaryKey(), - clientId: text("clientId") + olmId: text("olmId") .notNull() .references(() => newts.newtId, { onDelete: "cascade" }), expiresAt: integer("expiresAt").notNull() @@ -425,8 +425,8 @@ export type Target = InferSelectModel; export type Session = InferSelectModel; export type Newt = InferSelectModel; export type NewtSession = InferSelectModel; -export type Client = InferSelectModel; -export type ClientSession = InferSelectModel; +export type Olm = InferSelectModel; +export type OlmSession = InferSelectModel; export type EmailVerificationCode = InferSelectModel< typeof emailVerificationCodes >; diff --git a/server/routers/client/index.ts b/server/routers/client/index.ts deleted file mode 100644 index 5b493724..00000000 --- a/server/routers/client/index.ts +++ /dev/null @@ -1 +0,0 @@ -export * from "./pickClientDefaults"; diff --git a/server/routers/external.ts b/server/routers/external.ts index 778bf288..cb5fbfd7 100644 --- a/server/routers/external.ts +++ b/server/routers/external.ts @@ -7,7 +7,7 @@ import * as target from "./target"; import * as user from "./user"; import * as auth from "./auth"; import * as role from "./role"; -import * as client from "./client"; +import * as olm from "./olm"; import * as accessToken from "./accessToken"; import HttpCode from "@server/types/HttpCode"; import { @@ -100,7 +100,7 @@ authenticated.get( "/site/:siteId/pick-client-defaults", verifyOrgAccess, verifyUserHasAction(ActionsEnum.createClient), - client.pickClientDefaults + olm.pickOlmDefaults ); // authenticated.get( diff --git a/server/routers/messageHandlers.ts b/server/routers/messageHandlers.ts index 262f9869..bf8f357c 100644 --- a/server/routers/messageHandlers.ts +++ b/server/routers/messageHandlers.ts @@ -1,8 +1,10 @@ -import { handleRegisterMessage } from "./newt"; +import { handleNewtRegisterMessage } from "./newt"; +import { handleOlmRegisterMessage } from "./olm"; import { handleGetConfigMessage } from "./newt/handleGetConfigMessage"; import { MessageHandler } from "./ws"; export const messageHandlers: Record = { - "newt/wg/register": handleRegisterMessage, + "newt/wg/register": handleNewtRegisterMessage, + "olm/wg/register": handleOlmRegisterMessage, "newt/wg/get-config": handleGetConfigMessage, }; diff --git a/server/routers/newt/handleRegisterMessage.ts b/server/routers/newt/handleNewtRegisterMessage.ts similarity index 96% rename from server/routers/newt/handleRegisterMessage.ts rename to server/routers/newt/handleNewtRegisterMessage.ts index 0f086698..8e263034 100644 --- a/server/routers/newt/handleRegisterMessage.ts +++ b/server/routers/newt/handleNewtRegisterMessage.ts @@ -11,8 +11,10 @@ import { eq, and, sql } from "drizzle-orm"; import { addPeer, deletePeer } from "../gerbil/peers"; import logger from "@server/logger"; -export const handleRegisterMessage: MessageHandler = async (context) => { - const { message, newt, sendToClient } = context; +export const handleNewtRegisterMessage: MessageHandler = async (context) => { + const { message, client, sendToClient } = context; + + const newt = client; logger.info("Handling register message!"); diff --git a/server/routers/olm/createOlm.ts b/server/routers/olm/createOlm.ts new file mode 100644 index 00000000..d43c4cc6 --- /dev/null +++ b/server/routers/olm/createOlm.ts @@ -0,0 +1,106 @@ +import { NextFunction, Request, Response } from "express"; +import db from "@server/db"; +import { hash } from "@node-rs/argon2"; +import HttpCode from "@server/types/HttpCode"; +import { z } from "zod"; +import { newts } from "@server/db/schema"; +import createHttpError from "http-errors"; +import response from "@server/lib/response"; +import { SqliteError } from "better-sqlite3"; +import moment from "moment"; +import { generateSessionToken } from "@server/auth/sessions/app"; +import { createNewtSession } from "@server/auth/sessions/newt"; +import { fromError } from "zod-validation-error"; +import { hashPassword } from "@server/auth/password"; + +export const createNewtBodySchema = z.object({}); + +export type CreateNewtBody = z.infer; + +export type CreateNewtResponse = { + token: string; + newtId: string; + secret: string; +}; + +const createNewtSchema = z + .object({ + newtId: z.string(), + secret: z.string() + }) + .strict(); + +export async function createNewt( + req: Request, + res: Response, + next: NextFunction +): Promise { + try { + + const parsedBody = createNewtSchema.safeParse(req.body); + if (!parsedBody.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedBody.error).toString() + ) + ); + } + + const { newtId, secret } = parsedBody.data; + + if (!req.userOrgRoleId) { + return next( + createHttpError(HttpCode.FORBIDDEN, "User does not have a role") + ); + } + + const secretHash = await hashPassword(secret); + + await db.insert(newts).values({ + newtId: newtId, + secretHash, + dateCreated: moment().toISOString(), + }); + + // give the newt their default permissions: + // await db.insert(newtActions).values({ + // newtId: newtId, + // actionId: ActionsEnum.createOrg, + // orgId: null, + // }); + + const token = generateSessionToken(); + await createNewtSession(token, newtId); + + return response(res, { + data: { + newtId, + secret, + token, + }, + success: true, + error: false, + message: "Newt created successfully", + status: HttpCode.OK, + }); + } catch (e) { + if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_UNIQUE") { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "A newt with that email address already exists" + ) + ); + } else { + console.error(e); + + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "Failed to create newt" + ) + ); + } + } +} diff --git a/server/routers/olm/getToken.ts b/server/routers/olm/getToken.ts new file mode 100644 index 00000000..e6ae0cd6 --- /dev/null +++ b/server/routers/olm/getToken.ts @@ -0,0 +1,115 @@ +import { generateSessionToken } from "@server/auth/sessions/app"; +import db from "@server/db"; +import { newts } from "@server/db/schema"; +import HttpCode from "@server/types/HttpCode"; +import response from "@server/lib/response"; +import { eq } from "drizzle-orm"; +import { NextFunction, Request, Response } from "express"; +import createHttpError from "http-errors"; +import { z } from "zod"; +import { fromError } from "zod-validation-error"; +import { + createNewtSession, + validateNewtSessionToken +} from "@server/auth/sessions/newt"; +import { verifyPassword } from "@server/auth/password"; +import logger from "@server/logger"; +import config from "@server/lib/config"; + +export const newtGetTokenBodySchema = z.object({ + newtId: z.string(), + secret: z.string(), + token: z.string().optional() +}); + +export type NewtGetTokenBody = z.infer; + +export async function getToken( + req: Request, + res: Response, + next: NextFunction +): Promise { + const parsedBody = newtGetTokenBodySchema.safeParse(req.body); + + if (!parsedBody.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedBody.error).toString() + ) + ); + } + + const { newtId, secret, token } = parsedBody.data; + + try { + if (token) { + const { session, newt } = await validateNewtSessionToken(token); + if (session) { + if (config.getRawConfig().app.log_failed_attempts) { + logger.info( + `Newt session already valid. Newt ID: ${newtId}. IP: ${req.ip}.` + ); + } + return response(res, { + data: null, + success: true, + error: false, + message: "Token session already valid", + status: HttpCode.OK + }); + } + } + + const existingNewtRes = await db + .select() + .from(newts) + .where(eq(newts.newtId, newtId)); + if (!existingNewtRes || !existingNewtRes.length) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "No newt found with that newtId" + ) + ); + } + + const existingNewt = existingNewtRes[0]; + + const validSecret = await verifyPassword( + secret, + existingNewt.secretHash + ); + if (!validSecret) { + if (config.getRawConfig().app.log_failed_attempts) { + logger.info( + `Newt id or secret is incorrect. Newt: ID ${newtId}. IP: ${req.ip}.` + ); + } + return next( + createHttpError(HttpCode.BAD_REQUEST, "Secret is incorrect") + ); + } + + const resToken = generateSessionToken(); + await createNewtSession(resToken, existingNewt.newtId); + + return response<{ token: string }>(res, { + data: { + token: resToken + }, + success: true, + error: false, + message: "Token created successfully", + status: HttpCode.OK + }); + } catch (e) { + console.error(e); + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "Failed to authenticate newt" + ) + ); + } +} diff --git a/server/routers/olm/handleGetConfigMessage.ts b/server/routers/olm/handleGetConfigMessage.ts new file mode 100644 index 00000000..6e4f7ebf --- /dev/null +++ b/server/routers/olm/handleGetConfigMessage.ts @@ -0,0 +1,147 @@ +import { z } from "zod"; +import { MessageHandler } from "../ws"; +import logger from "@server/logger"; +import { fromError } from "zod-validation-error"; +import db from "@server/db"; +import { olms, Site, sites } from "@server/db/schema"; +import { eq, isNotNull } from "drizzle-orm"; +import { findNextAvailableCidr } from "@server/lib/ip"; +import config from "@server/lib/config"; + +const inputSchema = z.object({ + publicKey: z.string(), + endpoint: z.string(), + listenPort: z.number() +}); + +type Input = z.infer; + +export const handleGetConfigMessage: MessageHandler = async (context) => { + const { message, newt, sendToClient } = context; + + logger.debug("Handling Newt get config message!"); + + if (!newt) { + logger.warn("Newt not found"); + return; + } + + if (!newt.siteId) { + logger.warn("Newt has no site!"); // TODO: Maybe we create the site here? + return; + } + + const parsed = inputSchema.safeParse(message.data); + if (!parsed.success) { + logger.error( + "handleGetConfigMessage: Invalid input: " + + fromError(parsed.error).toString() + ); + return; + } + + const { publicKey, endpoint, listenPort } = message.data as Input; + + const siteId = newt.siteId; + + const [siteRes] = await db + .select() + .from(sites) + .where(eq(sites.siteId, siteId)); + + if (!siteRes) { + logger.warn("handleGetConfigMessage: Site not found"); + return; + } + + let site: Site | undefined; + if (!site) { + const address = await getNextAvailableSubnet(); + + // create a new exit node + const [updateRes] = await db + .update(sites) + .set({ + publicKey, + endpoint, + address, + listenPort + }) + .where(eq(sites.siteId, siteId)) + .returning(); + + site = updateRes; + + logger.info(`Updated site ${siteId} with new WG Newt info`); + } else { + site = siteRes; + } + + if (!site) { + logger.error("handleGetConfigMessage: Failed to update site"); + return; + } + + const clientsRes = await db + .select() + .from(olms) + .where(eq(olms.siteId, siteId)); + + const peers = await Promise.all( + clientsRes.map(async (client) => { + return { + publicKey: client.pubKey, + allowedIps: "0.0.0.0/0" + }; + }) + ); + + const configResponse = { + listenPort: site.listenPort, // ????? + // ipAddress: exitNode[0].address, + peers + }; + + logger.debug("Sending config: ", configResponse); + + return { + message: { + type: "olm/wg/connect", // what to make the response type? + data: { + config: configResponse + } + }, + broadcast: false, // Send to all clients + excludeSender: false // Include sender in broadcast + }; +}; + +async function getNextAvailableSubnet(): Promise { + const existingAddresses = await db + .select({ + address: sites.address + }) + .from(sites) + .where(isNotNull(sites.address)); + + const addresses = existingAddresses + .map((a) => a.address) + .filter((a) => a) as string[]; + + let subnet = findNextAvailableCidr( + addresses, + config.getRawConfig().wg_site.block_size, + config.getRawConfig().wg_site.subnet_group + ); + if (!subnet) { + throw new Error("No available subnets remaining in space"); + } + + // replace the last octet with 1 + subnet = + subnet.split(".").slice(0, 3).join(".") + + ".1" + + "/" + + subnet.split("/")[1]; + return subnet; +} diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts new file mode 100644 index 00000000..33786f2d --- /dev/null +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -0,0 +1,93 @@ +import db from "@server/db"; +import { MessageHandler } from "../ws"; +import { + exitNodes, + resources, + sites, + Target, + targets +} from "@server/db/schema"; +import { eq, and, sql } from "drizzle-orm"; +import { addPeer, deletePeer } from "../gerbil/peers"; +import logger from "@server/logger"; + +export const handleOlmRegisterMessage: MessageHandler = async (context) => { + const { message, client, sendToClient } = context; + + const olm = client; + + logger.info("Handling register message!"); + + if (!olm) { + logger.warn("Olm not found"); + return; + } + + if (!olm.siteId) { + logger.warn("Olm has no site!"); // TODO: Maybe we create the site here? + return; + } + + const siteId = olm.siteId; + + const { publicKey } = message.data; + if (!publicKey) { + logger.warn("Public key not provided"); + return; + } + + const [site] = await db + .select() + .from(sites) + .where(eq(sites.siteId, siteId)) + .limit(1); + + if (!site || !site.exitNodeId) { + logger.warn("Site not found or does not have exit node"); + return; + } + + await db + .update(sites) + .set({ + pubKey: publicKey + }) + .where(eq(sites.siteId, siteId)) + .returning(); + + const [exitNode] = await db + .select() + .from(exitNodes) + .where(eq(exitNodes.exitNodeId, site.exitNodeId)) + .limit(1); + + if (site.pubKey && site.pubKey !== publicKey) { + logger.info("Public key mismatch. Deleting old peer..."); + await deletePeer(site.exitNodeId, site.pubKey); + } + + if (!site.subnet) { + logger.warn("Site has no subnet"); + return; + } + + // add the peer to the exit node + await addPeer(site.exitNodeId, { + publicKey: publicKey, + allowedIps: [site.subnet] + }); + + return { + message: { + type: "olm/wg/connect", + data: { + endpoint: `${exitNode.endpoint}:${exitNode.listenPort}`, + publicKey: exitNode.publicKey, + serverIP: exitNode.address.split("/")[0], + tunnelIP: site.subnet.split("/")[0] + } + }, + broadcast: false, // Send to all olms + excludeSender: false // Include sender in broadcast + }; +}; diff --git a/server/routers/olm/index.ts b/server/routers/olm/index.ts new file mode 100644 index 00000000..7265331b --- /dev/null +++ b/server/routers/olm/index.ts @@ -0,0 +1 @@ +export * from "./pickOlmDefaults"; \ No newline at end of file diff --git a/server/routers/client/pickClientDefaults.ts b/server/routers/olm/pickOlmDefaults.ts similarity index 94% rename from server/routers/client/pickClientDefaults.ts rename to server/routers/olm/pickOlmDefaults.ts index eb765fc2..24ddcace 100644 --- a/server/routers/client/pickClientDefaults.ts +++ b/server/routers/olm/pickOlmDefaults.ts @@ -1,6 +1,6 @@ import { Request, Response, NextFunction } from "express"; import { db } from "@server/db"; -import { clients, sites } from "@server/db/schema"; +import { olms, sites } from "@server/db/schema"; import { eq } from "drizzle-orm"; import response from "@server/lib/response"; import HttpCode from "@server/types/HttpCode"; @@ -30,7 +30,7 @@ export type PickClientDefaultsResponse = { clientSecret: string; }; -export async function pickClientDefaults( +export async function pickOlmDefaults( req: Request, res: Response, next: NextFunction @@ -71,10 +71,10 @@ export async function pickClientDefaults( const clientsQuery = await db .select({ - subnet: clients.subnet + subnet: olms.subnet }) - .from(clients) - .where(eq(clients.siteId, site.siteId)); + .from(olms) + .where(eq(olms.siteId, site.siteId)); let subnets = clientsQuery.map((client) => client.subnet); diff --git a/server/routers/ws.ts b/server/routers/ws.ts index afe422d0..1c24f48e 100644 --- a/server/routers/ws.ts +++ b/server/routers/ws.ts @@ -3,10 +3,11 @@ import { Server as HttpServer } from "http"; import { WebSocket, WebSocketServer } from "ws"; import { IncomingMessage } from "http"; import { Socket } from "net"; -import { Newt, newts, NewtSession } from "@server/db/schema"; +import { Newt, newts, NewtSession, Olm, olms, OlmSession } from "@server/db/schema"; import { eq } from "drizzle-orm"; import db from "@server/db"; import { validateNewtSessionToken } from "@server/auth/sessions/newt"; +import { validateOlmSessionToken } from "@server/auth/sessions/olm"; import { messageHandlers } from "./messageHandlers"; import logger from "@server/logger"; @@ -15,13 +16,17 @@ interface WebSocketRequest extends IncomingMessage { token?: string; } +type ClientType = 'newt' | 'olm'; + interface AuthenticatedWebSocket extends WebSocket { - newt?: Newt; + client?: Newt | Olm; + clientType?: ClientType; } interface TokenPayload { - newt: Newt; - session: NewtSession; + client: Newt | Olm; + session: NewtSession | OlmSession; + clientType: ClientType; } interface WSMessage { @@ -33,15 +38,16 @@ interface HandlerResponse { message: WSMessage; broadcast?: boolean; excludeSender?: boolean; - targetNewtId?: string; + targetClientId?: string; } interface HandlerContext { message: WSMessage; senderWs: WebSocket; - newt: Newt | undefined; - sendToClient: (newtId: string, message: WSMessage) => boolean; - broadcastToAllExcept: (message: WSMessage, excludeNewtId?: string) => void; + client: Newt | Olm | undefined; + clientType: ClientType; + sendToClient: (clientId: string, message: WSMessage) => boolean; + broadcastToAllExcept: (message: WSMessage, excludeClientId?: string) => void; connectedClients: Map; } @@ -54,34 +60,32 @@ const wss: WebSocketServer = new WebSocketServer({ noServer: true }); let connectedClients: Map = new Map(); // Helper functions for client management -const addClient = (newtId: string, ws: AuthenticatedWebSocket): void => { - const existingClients = connectedClients.get(newtId) || []; +const addClient = (clientId: string, ws: AuthenticatedWebSocket, clientType: ClientType): void => { + const existingClients = connectedClients.get(clientId) || []; existingClients.push(ws); - connectedClients.set(newtId, existingClients); - logger.info(`Client added to tracking - Newt ID: ${newtId}, Total connections: ${existingClients.length}`); + connectedClients.set(clientId, existingClients); + logger.info(`Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Total connections: ${existingClients.length}`); }; -const removeClient = (newtId: string, ws: AuthenticatedWebSocket): void => { - const existingClients = connectedClients.get(newtId) || []; +const removeClient = (clientId: string, ws: AuthenticatedWebSocket, clientType: ClientType): void => { + const existingClients = connectedClients.get(clientId) || []; const updatedClients = existingClients.filter(client => client !== ws); - if (updatedClients.length === 0) { - connectedClients.delete(newtId); - logger.info(`All connections removed for Newt ID: ${newtId}`); + connectedClients.delete(clientId); + logger.info(`All connections removed for ${clientType.toUpperCase()} ID: ${clientId}`); } else { - connectedClients.set(newtId, updatedClients); - logger.info(`Connection removed - Newt ID: ${newtId}, Remaining connections: ${updatedClients.length}`); + connectedClients.set(clientId, updatedClients); + logger.info(`Connection removed - ${clientType.toUpperCase()} ID: ${clientId}, Remaining connections: ${updatedClients.length}`); } }; // Helper functions for sending messages -const sendToClient = (newtId: string, message: WSMessage): boolean => { - const clients = connectedClients.get(newtId); +const sendToClient = (clientId: string, message: WSMessage): boolean => { + const clients = connectedClients.get(clientId); if (!clients || clients.length === 0) { - logger.info(`No active connections found for Newt ID: ${newtId}`); + logger.info(`No active connections found for Client ID: ${clientId}`); return false; } - const messageString = JSON.stringify(message); clients.forEach(client => { if (client.readyState === WebSocket.OPEN) { @@ -91,9 +95,9 @@ const sendToClient = (newtId: string, message: WSMessage): boolean => { return true; }; -const broadcastToAllExcept = (message: WSMessage, excludeNewtId?: string): void => { - connectedClients.forEach((clients, newtId) => { - if (newtId !== excludeNewtId) { +const broadcastToAllExcept = (message: WSMessage, excludeClientId?: string): void => { + connectedClients.forEach((clients, clientId) => { + if (clientId !== excludeClientId) { clients.forEach(client => { if (client.readyState === WebSocket.OPEN) { client.send(JSON.stringify(message)); @@ -103,84 +107,88 @@ const broadcastToAllExcept = (message: WSMessage, excludeNewtId?: string): void }); }; -// Token verification middleware (unchanged) -const verifyToken = async (token: string): Promise => { +// Token verification middleware +const verifyToken = async (token: string, clientType: ClientType): Promise => { try { - const { session, newt } = await validateNewtSessionToken(token); - - if (!session || !newt) { - return null; + if (clientType === 'newt') { + const { session, newt } = await validateNewtSessionToken(token); + if (!session || !newt) { + return null; + } + const existingNewt = await db + .select() + .from(newts) + .where(eq(newts.newtId, newt.newtId)); + if (!existingNewt || !existingNewt[0]) { + return null; + } + return { client: existingNewt[0], session, clientType }; + } else { + const { session, olm } = await validateOlmSessionToken(token); + if (!session || !olm) { + return null; + } + const existingOlm = await db + .select() + .from(olms) + .where(eq(olms.olmId, olm.olmId)); + if (!existingOlm || !existingOlm[0]) { + return null; + } + return { client: existingOlm[0], session, clientType }; } - - const existingNewt = await db - .select() - .from(newts) - .where(eq(newts.newtId, newt.newtId)); - - if (!existingNewt || !existingNewt[0]) { - return null; - } - - return { newt: existingNewt[0], session }; } catch (error) { logger.error("Token verification failed:", error); return null; } }; -const setupConnection = (ws: AuthenticatedWebSocket, newt: Newt): void => { +const setupConnection = (ws: AuthenticatedWebSocket, client: Newt | Olm, clientType: ClientType): void => { logger.info("Establishing websocket connection"); - - if (!newt) { - logger.error("Connection attempt without newt"); + if (!client) { + logger.error("Connection attempt without client"); return ws.terminate(); } - ws.newt = newt; + ws.client = client; + ws.clientType = clientType; // Add client to tracking - addClient(newt.newtId, ws); + const clientId = clientType === 'newt' ? (client as Newt).newtId : (client as Olm).olmId; + addClient(clientId, ws, clientType); ws.on("message", async (data) => { try { const message: WSMessage = JSON.parse(data.toString()); - // logger.info(`Message received from Newt ID ${newtId}:`, message); - // Validate message format if (!message.type || typeof message.type !== "string") { throw new Error("Invalid message format: missing or invalid type"); } - // Get the appropriate handler for the message type const handler = messageHandlers[message.type]; if (!handler) { throw new Error(`Unsupported message type: ${message.type}`); } - // Process the message and get response const response = await handler({ message, senderWs: ws, - newt: ws.newt, + client: ws.client, + clientType: ws.clientType!, sendToClient, broadcastToAllExcept, connectedClients }); - // Send response if one was returned if (response) { if (response.broadcast) { - // Broadcast to all clients except sender if specified - broadcastToAllExcept(response.message, response.excludeSender ? newt.newtId : undefined); - } else if (response.targetNewtId) { - // Send to specific client if targetNewtId is provided - sendToClient(response.targetNewtId, response.message); + broadcastToAllExcept(response.message, response.excludeSender ? clientId : undefined); + } else if (response.targetClientId) { + sendToClient(response.targetClientId, response.message); } else { - // Send back to sender ws.send(JSON.stringify(response.message)); } } - } catch (error) { logger.error("Message handling error:", error); ws.send(JSON.stringify({ @@ -194,18 +202,18 @@ const setupConnection = (ws: AuthenticatedWebSocket, newt: Newt): void => { }); ws.on("close", () => { - removeClient(newt.newtId, ws); - logger.info(`Client disconnected - Newt ID: ${newt.newtId}`); + removeClient(clientId, ws, clientType); + logger.info(`Client disconnected - ${clientType.toUpperCase()} ID: ${clientId}`); }); ws.on("error", (error: Error) => { - logger.error(`WebSocket error for Newt ID ${newt.newtId}:`, error); + logger.error(`WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`, error); }); - logger.info(`WebSocket connection established - Newt ID: ${newt.newtId}`); + logger.info(`WebSocket connection established - ${clientType.toUpperCase()} ID: ${clientId}`); }; -// Router endpoint (unchanged) +// Router endpoint router.get("/ws", (req: Request, res: Response) => { res.status(200).send("WebSocket endpoint"); }); @@ -214,18 +222,22 @@ router.get("/ws", (req: Request, res: Response) => { const handleWSUpgrade = (server: HttpServer): void => { server.on("upgrade", async (request: WebSocketRequest, socket: Socket, head: Buffer) => { try { - const token = request.url?.includes("?") - ? new URLSearchParams(request.url.split("?")[1]).get("token") || "" - : request.headers["sec-websocket-protocol"]; + const url = new URL(request.url || '', `http://${request.headers.host}`); + const token = url.searchParams.get('token') || request.headers["sec-websocket-protocol"] || ''; + let clientType = url.searchParams.get('clientType') as ClientType; - if (!token) { - logger.warn("Unauthorized connection attempt: no token..."); + if (!clientType) { + clientType = "newt"; + } + + if (!token || !clientType || !['newt', 'olm'].includes(clientType)) { + logger.warn("Unauthorized connection attempt: invalid token or client type..."); socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n"); socket.destroy(); return; } - const tokenPayload = await verifyToken(token); + const tokenPayload = await verifyToken(token, clientType); if (!tokenPayload) { logger.warn("Unauthorized connection attempt: invalid token..."); socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n"); @@ -234,7 +246,7 @@ const handleWSUpgrade = (server: HttpServer): void => { } wss.handleUpgrade(request, socket, head, (ws: AuthenticatedWebSocket) => { - setupConnection(ws, tokenPayload.newt); + setupConnection(ws, tokenPayload.client, tokenPayload.clientType); }); } catch (error) { logger.error("WebSocket upgrade error:", error); @@ -250,4 +262,4 @@ export { sendToClient, broadcastToAllExcept, connectedClients -}; +}; \ No newline at end of file