Move things around; rename to olm

This commit is contained in:
Owen 2025-02-21 10:13:41 -05:00
parent 41983ce356
commit e112fcba29
No known key found for this signature in database
GPG key ID: 8271FDFFD9E0CCBD
13 changed files with 642 additions and 93 deletions

View file

@ -0,0 +1,72 @@
import {
encodeHexLowerCase,
} from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { Olm, olms, olmSessions, OlmSession } from "@server/db/schema";
import db from "@server/db";
import { eq } from "drizzle-orm";
export const EXPIRES = 1000 * 60 * 60 * 24 * 30;
export async function createOlmSession(
token: string,
olmId: string,
): Promise<OlmSession> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
);
const session: OlmSession = {
sessionId: sessionId,
olmId,
expiresAt: new Date(Date.now() + EXPIRES).getTime(),
};
await db.insert(olmSessions).values(session);
return session;
}
export async function validateOlmSessionToken(
token: string,
): Promise<SessionValidationResult> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
);
const result = await db
.select({ olm: olms, session: olmSessions })
.from(olmSessions)
.innerJoin(olms, eq(olmSessions.olmId, olms.olmId))
.where(eq(olmSessions.sessionId, sessionId));
if (result.length < 1) {
return { session: null, olm: null };
}
const { olm, session } = result[0];
if (Date.now() >= session.expiresAt) {
await db
.delete(olmSessions)
.where(eq(olmSessions.sessionId, session.sessionId));
return { session: null, olm: null };
}
if (Date.now() >= session.expiresAt - (EXPIRES / 2)) {
session.expiresAt = new Date(
Date.now() + EXPIRES,
).getTime();
await db
.update(olmSessions)
.set({
expiresAt: session.expiresAt,
})
.where(eq(olmSessions.sessionId, session.sessionId));
}
return { session, olm };
}
export async function invalidateOlmSession(sessionId: string): Promise<void> {
await db.delete(olmSessions).where(eq(olmSessions.sessionId, sessionId));
}
export async function invalidateAllOlmSessions(olmId: string): Promise<void> {
await db.delete(olmSessions).where(eq(olmSessions.olmId, olmId));
}
export type SessionValidationResult =
| { session: OlmSession; olm: Olm }
| { session: null; olm: null };

View file

@ -114,8 +114,8 @@ export const newts = sqliteTable("newt", {
}) })
}); });
export const clients = sqliteTable("clients", { export const olms = sqliteTable("olms", {
clientId: text("id").primaryKey(), olmId: text("id").primaryKey(),
secretHash: text("secretHash").notNull(), secretHash: text("secretHash").notNull(),
dateCreated: text("dateCreated").notNull(), dateCreated: text("dateCreated").notNull(),
siteId: integer("siteId").references(() => sites.siteId, { siteId: integer("siteId").references(() => sites.siteId, {
@ -156,9 +156,9 @@ export const newtSessions = sqliteTable("newtSession", {
expiresAt: integer("expiresAt").notNull() expiresAt: integer("expiresAt").notNull()
}); });
export const clientSessions = sqliteTable("clientSession", { export const olmSessions = sqliteTable("clientSession", {
sessionId: text("id").primaryKey(), sessionId: text("id").primaryKey(),
clientId: text("clientId") olmId: text("olmId")
.notNull() .notNull()
.references(() => newts.newtId, { onDelete: "cascade" }), .references(() => newts.newtId, { onDelete: "cascade" }),
expiresAt: integer("expiresAt").notNull() expiresAt: integer("expiresAt").notNull()
@ -425,8 +425,8 @@ export type Target = InferSelectModel<typeof targets>;
export type Session = InferSelectModel<typeof sessions>; export type Session = InferSelectModel<typeof sessions>;
export type Newt = InferSelectModel<typeof newts>; export type Newt = InferSelectModel<typeof newts>;
export type NewtSession = InferSelectModel<typeof newtSessions>; export type NewtSession = InferSelectModel<typeof newtSessions>;
export type Client = InferSelectModel<typeof clients>; export type Olm = InferSelectModel<typeof olms>;
export type ClientSession = InferSelectModel<typeof clientSessions>; export type OlmSession = InferSelectModel<typeof olmSessions>;
export type EmailVerificationCode = InferSelectModel< export type EmailVerificationCode = InferSelectModel<
typeof emailVerificationCodes typeof emailVerificationCodes
>; >;

View file

@ -1 +0,0 @@
export * from "./pickClientDefaults";

View file

@ -7,7 +7,7 @@ import * as target from "./target";
import * as user from "./user"; import * as user from "./user";
import * as auth from "./auth"; import * as auth from "./auth";
import * as role from "./role"; import * as role from "./role";
import * as client from "./client"; import * as olm from "./olm";
import * as accessToken from "./accessToken"; import * as accessToken from "./accessToken";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import { import {
@ -100,7 +100,7 @@ authenticated.get(
"/site/:siteId/pick-client-defaults", "/site/:siteId/pick-client-defaults",
verifyOrgAccess, verifyOrgAccess,
verifyUserHasAction(ActionsEnum.createClient), verifyUserHasAction(ActionsEnum.createClient),
client.pickClientDefaults olm.pickOlmDefaults
); );
// authenticated.get( // authenticated.get(

View file

@ -1,8 +1,10 @@
import { handleRegisterMessage } from "./newt"; import { handleNewtRegisterMessage } from "./newt";
import { handleOlmRegisterMessage } from "./olm";
import { handleGetConfigMessage } from "./newt/handleGetConfigMessage"; import { handleGetConfigMessage } from "./newt/handleGetConfigMessage";
import { MessageHandler } from "./ws"; import { MessageHandler } from "./ws";
export const messageHandlers: Record<string, MessageHandler> = { export const messageHandlers: Record<string, MessageHandler> = {
"newt/wg/register": handleRegisterMessage, "newt/wg/register": handleNewtRegisterMessage,
"olm/wg/register": handleOlmRegisterMessage,
"newt/wg/get-config": handleGetConfigMessage, "newt/wg/get-config": handleGetConfigMessage,
}; };

View file

@ -11,8 +11,10 @@ import { eq, and, sql } from "drizzle-orm";
import { addPeer, deletePeer } from "../gerbil/peers"; import { addPeer, deletePeer } from "../gerbil/peers";
import logger from "@server/logger"; import logger from "@server/logger";
export const handleRegisterMessage: MessageHandler = async (context) => { export const handleNewtRegisterMessage: MessageHandler = async (context) => {
const { message, newt, sendToClient } = context; const { message, client, sendToClient } = context;
const newt = client;
logger.info("Handling register message!"); logger.info("Handling register message!");

View file

@ -0,0 +1,106 @@
import { NextFunction, Request, Response } from "express";
import db from "@server/db";
import { hash } from "@node-rs/argon2";
import HttpCode from "@server/types/HttpCode";
import { z } from "zod";
import { newts } from "@server/db/schema";
import createHttpError from "http-errors";
import response from "@server/lib/response";
import { SqliteError } from "better-sqlite3";
import moment from "moment";
import { generateSessionToken } from "@server/auth/sessions/app";
import { createNewtSession } from "@server/auth/sessions/newt";
import { fromError } from "zod-validation-error";
import { hashPassword } from "@server/auth/password";
export const createNewtBodySchema = z.object({});
export type CreateNewtBody = z.infer<typeof createNewtBodySchema>;
export type CreateNewtResponse = {
token: string;
newtId: string;
secret: string;
};
const createNewtSchema = z
.object({
newtId: z.string(),
secret: z.string()
})
.strict();
export async function createNewt(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedBody = createNewtSchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { newtId, secret } = parsedBody.data;
if (!req.userOrgRoleId) {
return next(
createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
);
}
const secretHash = await hashPassword(secret);
await db.insert(newts).values({
newtId: newtId,
secretHash,
dateCreated: moment().toISOString(),
});
// give the newt their default permissions:
// await db.insert(newtActions).values({
// newtId: newtId,
// actionId: ActionsEnum.createOrg,
// orgId: null,
// });
const token = generateSessionToken();
await createNewtSession(token, newtId);
return response<CreateNewtResponse>(res, {
data: {
newtId,
secret,
token,
},
success: true,
error: false,
message: "Newt created successfully",
status: HttpCode.OK,
});
} catch (e) {
if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_UNIQUE") {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"A newt with that email address already exists"
)
);
} else {
console.error(e);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to create newt"
)
);
}
}
}

View file

@ -0,0 +1,115 @@
import { generateSessionToken } from "@server/auth/sessions/app";
import db from "@server/db";
import { newts } from "@server/db/schema";
import HttpCode from "@server/types/HttpCode";
import response from "@server/lib/response";
import { eq } from "drizzle-orm";
import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import {
createNewtSession,
validateNewtSessionToken
} from "@server/auth/sessions/newt";
import { verifyPassword } from "@server/auth/password";
import logger from "@server/logger";
import config from "@server/lib/config";
export const newtGetTokenBodySchema = z.object({
newtId: z.string(),
secret: z.string(),
token: z.string().optional()
});
export type NewtGetTokenBody = z.infer<typeof newtGetTokenBodySchema>;
export async function getToken(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
const parsedBody = newtGetTokenBodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { newtId, secret, token } = parsedBody.data;
try {
if (token) {
const { session, newt } = await validateNewtSessionToken(token);
if (session) {
if (config.getRawConfig().app.log_failed_attempts) {
logger.info(
`Newt session already valid. Newt ID: ${newtId}. IP: ${req.ip}.`
);
}
return response<null>(res, {
data: null,
success: true,
error: false,
message: "Token session already valid",
status: HttpCode.OK
});
}
}
const existingNewtRes = await db
.select()
.from(newts)
.where(eq(newts.newtId, newtId));
if (!existingNewtRes || !existingNewtRes.length) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"No newt found with that newtId"
)
);
}
const existingNewt = existingNewtRes[0];
const validSecret = await verifyPassword(
secret,
existingNewt.secretHash
);
if (!validSecret) {
if (config.getRawConfig().app.log_failed_attempts) {
logger.info(
`Newt id or secret is incorrect. Newt: ID ${newtId}. IP: ${req.ip}.`
);
}
return next(
createHttpError(HttpCode.BAD_REQUEST, "Secret is incorrect")
);
}
const resToken = generateSessionToken();
await createNewtSession(resToken, existingNewt.newtId);
return response<{ token: string }>(res, {
data: {
token: resToken
},
success: true,
error: false,
message: "Token created successfully",
status: HttpCode.OK
});
} catch (e) {
console.error(e);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to authenticate newt"
)
);
}
}

View file

@ -0,0 +1,147 @@
import { z } from "zod";
import { MessageHandler } from "../ws";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import db from "@server/db";
import { olms, Site, sites } from "@server/db/schema";
import { eq, isNotNull } from "drizzle-orm";
import { findNextAvailableCidr } from "@server/lib/ip";
import config from "@server/lib/config";
const inputSchema = z.object({
publicKey: z.string(),
endpoint: z.string(),
listenPort: z.number()
});
type Input = z.infer<typeof inputSchema>;
export const handleGetConfigMessage: MessageHandler = async (context) => {
const { message, newt, sendToClient } = context;
logger.debug("Handling Newt get config message!");
if (!newt) {
logger.warn("Newt not found");
return;
}
if (!newt.siteId) {
logger.warn("Newt has no site!"); // TODO: Maybe we create the site here?
return;
}
const parsed = inputSchema.safeParse(message.data);
if (!parsed.success) {
logger.error(
"handleGetConfigMessage: Invalid input: " +
fromError(parsed.error).toString()
);
return;
}
const { publicKey, endpoint, listenPort } = message.data as Input;
const siteId = newt.siteId;
const [siteRes] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId));
if (!siteRes) {
logger.warn("handleGetConfigMessage: Site not found");
return;
}
let site: Site | undefined;
if (!site) {
const address = await getNextAvailableSubnet();
// create a new exit node
const [updateRes] = await db
.update(sites)
.set({
publicKey,
endpoint,
address,
listenPort
})
.where(eq(sites.siteId, siteId))
.returning();
site = updateRes;
logger.info(`Updated site ${siteId} with new WG Newt info`);
} else {
site = siteRes;
}
if (!site) {
logger.error("handleGetConfigMessage: Failed to update site");
return;
}
const clientsRes = await db
.select()
.from(olms)
.where(eq(olms.siteId, siteId));
const peers = await Promise.all(
clientsRes.map(async (client) => {
return {
publicKey: client.pubKey,
allowedIps: "0.0.0.0/0"
};
})
);
const configResponse = {
listenPort: site.listenPort, // ?????
// ipAddress: exitNode[0].address,
peers
};
logger.debug("Sending config: ", configResponse);
return {
message: {
type: "olm/wg/connect", // what to make the response type?
data: {
config: configResponse
}
},
broadcast: false, // Send to all clients
excludeSender: false // Include sender in broadcast
};
};
async function getNextAvailableSubnet(): Promise<string> {
const existingAddresses = await db
.select({
address: sites.address
})
.from(sites)
.where(isNotNull(sites.address));
const addresses = existingAddresses
.map((a) => a.address)
.filter((a) => a) as string[];
let subnet = findNextAvailableCidr(
addresses,
config.getRawConfig().wg_site.block_size,
config.getRawConfig().wg_site.subnet_group
);
if (!subnet) {
throw new Error("No available subnets remaining in space");
}
// replace the last octet with 1
subnet =
subnet.split(".").slice(0, 3).join(".") +
".1" +
"/" +
subnet.split("/")[1];
return subnet;
}

View file

@ -0,0 +1,93 @@
import db from "@server/db";
import { MessageHandler } from "../ws";
import {
exitNodes,
resources,
sites,
Target,
targets
} from "@server/db/schema";
import { eq, and, sql } from "drizzle-orm";
import { addPeer, deletePeer } from "../gerbil/peers";
import logger from "@server/logger";
export const handleOlmRegisterMessage: MessageHandler = async (context) => {
const { message, client, sendToClient } = context;
const olm = client;
logger.info("Handling register message!");
if (!olm) {
logger.warn("Olm not found");
return;
}
if (!olm.siteId) {
logger.warn("Olm has no site!"); // TODO: Maybe we create the site here?
return;
}
const siteId = olm.siteId;
const { publicKey } = message.data;
if (!publicKey) {
logger.warn("Public key not provided");
return;
}
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site || !site.exitNodeId) {
logger.warn("Site not found or does not have exit node");
return;
}
await db
.update(sites)
.set({
pubKey: publicKey
})
.where(eq(sites.siteId, siteId))
.returning();
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId))
.limit(1);
if (site.pubKey && site.pubKey !== publicKey) {
logger.info("Public key mismatch. Deleting old peer...");
await deletePeer(site.exitNodeId, site.pubKey);
}
if (!site.subnet) {
logger.warn("Site has no subnet");
return;
}
// add the peer to the exit node
await addPeer(site.exitNodeId, {
publicKey: publicKey,
allowedIps: [site.subnet]
});
return {
message: {
type: "olm/wg/connect",
data: {
endpoint: `${exitNode.endpoint}:${exitNode.listenPort}`,
publicKey: exitNode.publicKey,
serverIP: exitNode.address.split("/")[0],
tunnelIP: site.subnet.split("/")[0]
}
},
broadcast: false, // Send to all olms
excludeSender: false // Include sender in broadcast
};
};

View file

@ -0,0 +1 @@
export * from "./pickOlmDefaults";

View file

@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { db } from "@server/db"; import { db } from "@server/db";
import { clients, sites } from "@server/db/schema"; import { olms, sites } from "@server/db/schema";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@ -30,7 +30,7 @@ export type PickClientDefaultsResponse = {
clientSecret: string; clientSecret: string;
}; };
export async function pickClientDefaults( export async function pickOlmDefaults(
req: Request, req: Request,
res: Response, res: Response,
next: NextFunction next: NextFunction
@ -71,10 +71,10 @@ export async function pickClientDefaults(
const clientsQuery = await db const clientsQuery = await db
.select({ .select({
subnet: clients.subnet subnet: olms.subnet
}) })
.from(clients) .from(olms)
.where(eq(clients.siteId, site.siteId)); .where(eq(olms.siteId, site.siteId));
let subnets = clientsQuery.map((client) => client.subnet); let subnets = clientsQuery.map((client) => client.subnet);

View file

@ -3,10 +3,11 @@ import { Server as HttpServer } from "http";
import { WebSocket, WebSocketServer } from "ws"; import { WebSocket, WebSocketServer } from "ws";
import { IncomingMessage } from "http"; import { IncomingMessage } from "http";
import { Socket } from "net"; import { Socket } from "net";
import { Newt, newts, NewtSession } from "@server/db/schema"; import { Newt, newts, NewtSession, Olm, olms, OlmSession } from "@server/db/schema";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import db from "@server/db"; import db from "@server/db";
import { validateNewtSessionToken } from "@server/auth/sessions/newt"; import { validateNewtSessionToken } from "@server/auth/sessions/newt";
import { validateOlmSessionToken } from "@server/auth/sessions/olm";
import { messageHandlers } from "./messageHandlers"; import { messageHandlers } from "./messageHandlers";
import logger from "@server/logger"; import logger from "@server/logger";
@ -15,13 +16,17 @@ interface WebSocketRequest extends IncomingMessage {
token?: string; token?: string;
} }
type ClientType = 'newt' | 'olm';
interface AuthenticatedWebSocket extends WebSocket { interface AuthenticatedWebSocket extends WebSocket {
newt?: Newt; client?: Newt | Olm;
clientType?: ClientType;
} }
interface TokenPayload { interface TokenPayload {
newt: Newt; client: Newt | Olm;
session: NewtSession; session: NewtSession | OlmSession;
clientType: ClientType;
} }
interface WSMessage { interface WSMessage {
@ -33,15 +38,16 @@ interface HandlerResponse {
message: WSMessage; message: WSMessage;
broadcast?: boolean; broadcast?: boolean;
excludeSender?: boolean; excludeSender?: boolean;
targetNewtId?: string; targetClientId?: string;
} }
interface HandlerContext { interface HandlerContext {
message: WSMessage; message: WSMessage;
senderWs: WebSocket; senderWs: WebSocket;
newt: Newt | undefined; client: Newt | Olm | undefined;
sendToClient: (newtId: string, message: WSMessage) => boolean; clientType: ClientType;
broadcastToAllExcept: (message: WSMessage, excludeNewtId?: string) => void; sendToClient: (clientId: string, message: WSMessage) => boolean;
broadcastToAllExcept: (message: WSMessage, excludeClientId?: string) => void;
connectedClients: Map<string, WebSocket[]>; connectedClients: Map<string, WebSocket[]>;
} }
@ -54,34 +60,32 @@ const wss: WebSocketServer = new WebSocketServer({ noServer: true });
let connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map(); let connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map();
// Helper functions for client management // Helper functions for client management
const addClient = (newtId: string, ws: AuthenticatedWebSocket): void => { const addClient = (clientId: string, ws: AuthenticatedWebSocket, clientType: ClientType): void => {
const existingClients = connectedClients.get(newtId) || []; const existingClients = connectedClients.get(clientId) || [];
existingClients.push(ws); existingClients.push(ws);
connectedClients.set(newtId, existingClients); connectedClients.set(clientId, existingClients);
logger.info(`Client added to tracking - Newt ID: ${newtId}, Total connections: ${existingClients.length}`); logger.info(`Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Total connections: ${existingClients.length}`);
}; };
const removeClient = (newtId: string, ws: AuthenticatedWebSocket): void => { const removeClient = (clientId: string, ws: AuthenticatedWebSocket, clientType: ClientType): void => {
const existingClients = connectedClients.get(newtId) || []; const existingClients = connectedClients.get(clientId) || [];
const updatedClients = existingClients.filter(client => client !== ws); const updatedClients = existingClients.filter(client => client !== ws);
if (updatedClients.length === 0) { if (updatedClients.length === 0) {
connectedClients.delete(newtId); connectedClients.delete(clientId);
logger.info(`All connections removed for Newt ID: ${newtId}`); logger.info(`All connections removed for ${clientType.toUpperCase()} ID: ${clientId}`);
} else { } else {
connectedClients.set(newtId, updatedClients); connectedClients.set(clientId, updatedClients);
logger.info(`Connection removed - Newt ID: ${newtId}, Remaining connections: ${updatedClients.length}`); logger.info(`Connection removed - ${clientType.toUpperCase()} ID: ${clientId}, Remaining connections: ${updatedClients.length}`);
} }
}; };
// Helper functions for sending messages // Helper functions for sending messages
const sendToClient = (newtId: string, message: WSMessage): boolean => { const sendToClient = (clientId: string, message: WSMessage): boolean => {
const clients = connectedClients.get(newtId); const clients = connectedClients.get(clientId);
if (!clients || clients.length === 0) { if (!clients || clients.length === 0) {
logger.info(`No active connections found for Newt ID: ${newtId}`); logger.info(`No active connections found for Client ID: ${clientId}`);
return false; return false;
} }
const messageString = JSON.stringify(message); const messageString = JSON.stringify(message);
clients.forEach(client => { clients.forEach(client => {
if (client.readyState === WebSocket.OPEN) { if (client.readyState === WebSocket.OPEN) {
@ -91,9 +95,9 @@ const sendToClient = (newtId: string, message: WSMessage): boolean => {
return true; return true;
}; };
const broadcastToAllExcept = (message: WSMessage, excludeNewtId?: string): void => { const broadcastToAllExcept = (message: WSMessage, excludeClientId?: string): void => {
connectedClients.forEach((clients, newtId) => { connectedClients.forEach((clients, clientId) => {
if (newtId !== excludeNewtId) { if (clientId !== excludeClientId) {
clients.forEach(client => { clients.forEach(client => {
if (client.readyState === WebSocket.OPEN) { if (client.readyState === WebSocket.OPEN) {
client.send(JSON.stringify(message)); client.send(JSON.stringify(message));
@ -103,84 +107,88 @@ const broadcastToAllExcept = (message: WSMessage, excludeNewtId?: string): void
}); });
}; };
// Token verification middleware (unchanged) // Token verification middleware
const verifyToken = async (token: string): Promise<TokenPayload | null> => { const verifyToken = async (token: string, clientType: ClientType): Promise<TokenPayload | null> => {
try { try {
const { session, newt } = await validateNewtSessionToken(token); if (clientType === 'newt') {
const { session, newt } = await validateNewtSessionToken(token);
if (!session || !newt) { if (!session || !newt) {
return null; return null;
}
const existingNewt = await db
.select()
.from(newts)
.where(eq(newts.newtId, newt.newtId));
if (!existingNewt || !existingNewt[0]) {
return null;
}
return { client: existingNewt[0], session, clientType };
} else {
const { session, olm } = await validateOlmSessionToken(token);
if (!session || !olm) {
return null;
}
const existingOlm = await db
.select()
.from(olms)
.where(eq(olms.olmId, olm.olmId));
if (!existingOlm || !existingOlm[0]) {
return null;
}
return { client: existingOlm[0], session, clientType };
} }
const existingNewt = await db
.select()
.from(newts)
.where(eq(newts.newtId, newt.newtId));
if (!existingNewt || !existingNewt[0]) {
return null;
}
return { newt: existingNewt[0], session };
} catch (error) { } catch (error) {
logger.error("Token verification failed:", error); logger.error("Token verification failed:", error);
return null; return null;
} }
}; };
const setupConnection = (ws: AuthenticatedWebSocket, newt: Newt): void => { const setupConnection = (ws: AuthenticatedWebSocket, client: Newt | Olm, clientType: ClientType): void => {
logger.info("Establishing websocket connection"); logger.info("Establishing websocket connection");
if (!client) {
if (!newt) { logger.error("Connection attempt without client");
logger.error("Connection attempt without newt");
return ws.terminate(); return ws.terminate();
} }
ws.newt = newt; ws.client = client;
ws.clientType = clientType;
// Add client to tracking // Add client to tracking
addClient(newt.newtId, ws); const clientId = clientType === 'newt' ? (client as Newt).newtId : (client as Olm).olmId;
addClient(clientId, ws, clientType);
ws.on("message", async (data) => { ws.on("message", async (data) => {
try { try {
const message: WSMessage = JSON.parse(data.toString()); const message: WSMessage = JSON.parse(data.toString());
// logger.info(`Message received from Newt ID ${newtId}:`, message);
// Validate message format
if (!message.type || typeof message.type !== "string") { if (!message.type || typeof message.type !== "string") {
throw new Error("Invalid message format: missing or invalid type"); throw new Error("Invalid message format: missing or invalid type");
} }
// Get the appropriate handler for the message type
const handler = messageHandlers[message.type]; const handler = messageHandlers[message.type];
if (!handler) { if (!handler) {
throw new Error(`Unsupported message type: ${message.type}`); throw new Error(`Unsupported message type: ${message.type}`);
} }
// Process the message and get response
const response = await handler({ const response = await handler({
message, message,
senderWs: ws, senderWs: ws,
newt: ws.newt, client: ws.client,
clientType: ws.clientType!,
sendToClient, sendToClient,
broadcastToAllExcept, broadcastToAllExcept,
connectedClients connectedClients
}); });
// Send response if one was returned
if (response) { if (response) {
if (response.broadcast) { if (response.broadcast) {
// Broadcast to all clients except sender if specified broadcastToAllExcept(response.message, response.excludeSender ? clientId : undefined);
broadcastToAllExcept(response.message, response.excludeSender ? newt.newtId : undefined); } else if (response.targetClientId) {
} else if (response.targetNewtId) { sendToClient(response.targetClientId, response.message);
// Send to specific client if targetNewtId is provided
sendToClient(response.targetNewtId, response.message);
} else { } else {
// Send back to sender
ws.send(JSON.stringify(response.message)); ws.send(JSON.stringify(response.message));
} }
} }
} catch (error) { } catch (error) {
logger.error("Message handling error:", error); logger.error("Message handling error:", error);
ws.send(JSON.stringify({ ws.send(JSON.stringify({
@ -194,18 +202,18 @@ const setupConnection = (ws: AuthenticatedWebSocket, newt: Newt): void => {
}); });
ws.on("close", () => { ws.on("close", () => {
removeClient(newt.newtId, ws); removeClient(clientId, ws, clientType);
logger.info(`Client disconnected - Newt ID: ${newt.newtId}`); logger.info(`Client disconnected - ${clientType.toUpperCase()} ID: ${clientId}`);
}); });
ws.on("error", (error: Error) => { ws.on("error", (error: Error) => {
logger.error(`WebSocket error for Newt ID ${newt.newtId}:`, error); logger.error(`WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`, error);
}); });
logger.info(`WebSocket connection established - Newt ID: ${newt.newtId}`); logger.info(`WebSocket connection established - ${clientType.toUpperCase()} ID: ${clientId}`);
}; };
// Router endpoint (unchanged) // Router endpoint
router.get("/ws", (req: Request, res: Response) => { router.get("/ws", (req: Request, res: Response) => {
res.status(200).send("WebSocket endpoint"); res.status(200).send("WebSocket endpoint");
}); });
@ -214,18 +222,22 @@ router.get("/ws", (req: Request, res: Response) => {
const handleWSUpgrade = (server: HttpServer): void => { const handleWSUpgrade = (server: HttpServer): void => {
server.on("upgrade", async (request: WebSocketRequest, socket: Socket, head: Buffer) => { server.on("upgrade", async (request: WebSocketRequest, socket: Socket, head: Buffer) => {
try { try {
const token = request.url?.includes("?") const url = new URL(request.url || '', `http://${request.headers.host}`);
? new URLSearchParams(request.url.split("?")[1]).get("token") || "" const token = url.searchParams.get('token') || request.headers["sec-websocket-protocol"] || '';
: request.headers["sec-websocket-protocol"]; let clientType = url.searchParams.get('clientType') as ClientType;
if (!token) { if (!clientType) {
logger.warn("Unauthorized connection attempt: no token..."); clientType = "newt";
}
if (!token || !clientType || !['newt', 'olm'].includes(clientType)) {
logger.warn("Unauthorized connection attempt: invalid token or client type...");
socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n"); socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n");
socket.destroy(); socket.destroy();
return; return;
} }
const tokenPayload = await verifyToken(token); const tokenPayload = await verifyToken(token, clientType);
if (!tokenPayload) { if (!tokenPayload) {
logger.warn("Unauthorized connection attempt: invalid token..."); logger.warn("Unauthorized connection attempt: invalid token...");
socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n"); socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n");
@ -234,7 +246,7 @@ const handleWSUpgrade = (server: HttpServer): void => {
} }
wss.handleUpgrade(request, socket, head, (ws: AuthenticatedWebSocket) => { wss.handleUpgrade(request, socket, head, (ws: AuthenticatedWebSocket) => {
setupConnection(ws, tokenPayload.newt); setupConnection(ws, tokenPayload.client, tokenPayload.clientType);
}); });
} catch (error) { } catch (error) {
logger.error("WebSocket upgrade error:", error); logger.error("WebSocket upgrade error:", error);
@ -250,4 +262,4 @@ export {
sendToClient, sendToClient,
broadcastToAllExcept, broadcastToAllExcept,
connectedClients connectedClients
}; };