mirror of
https://github.com/fosrl/pangolin.git
synced 2025-07-23 20:24:34 +02:00
Move things around; rename to olm
This commit is contained in:
parent
41983ce356
commit
e112fcba29
13 changed files with 642 additions and 93 deletions
72
server/auth/sessions/olm.ts
Normal file
72
server/auth/sessions/olm.ts
Normal file
|
@ -0,0 +1,72 @@
|
|||
import {
|
||||
encodeHexLowerCase,
|
||||
} from "@oslojs/encoding";
|
||||
import { sha256 } from "@oslojs/crypto/sha2";
|
||||
import { Olm, olms, olmSessions, OlmSession } from "@server/db/schema";
|
||||
import db from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
|
||||
export const EXPIRES = 1000 * 60 * 60 * 24 * 30;
|
||||
|
||||
export async function createOlmSession(
|
||||
token: string,
|
||||
olmId: string,
|
||||
): Promise<OlmSession> {
|
||||
const sessionId = encodeHexLowerCase(
|
||||
sha256(new TextEncoder().encode(token)),
|
||||
);
|
||||
const session: OlmSession = {
|
||||
sessionId: sessionId,
|
||||
olmId,
|
||||
expiresAt: new Date(Date.now() + EXPIRES).getTime(),
|
||||
};
|
||||
await db.insert(olmSessions).values(session);
|
||||
return session;
|
||||
}
|
||||
|
||||
export async function validateOlmSessionToken(
|
||||
token: string,
|
||||
): Promise<SessionValidationResult> {
|
||||
const sessionId = encodeHexLowerCase(
|
||||
sha256(new TextEncoder().encode(token)),
|
||||
);
|
||||
const result = await db
|
||||
.select({ olm: olms, session: olmSessions })
|
||||
.from(olmSessions)
|
||||
.innerJoin(olms, eq(olmSessions.olmId, olms.olmId))
|
||||
.where(eq(olmSessions.sessionId, sessionId));
|
||||
if (result.length < 1) {
|
||||
return { session: null, olm: null };
|
||||
}
|
||||
const { olm, session } = result[0];
|
||||
if (Date.now() >= session.expiresAt) {
|
||||
await db
|
||||
.delete(olmSessions)
|
||||
.where(eq(olmSessions.sessionId, session.sessionId));
|
||||
return { session: null, olm: null };
|
||||
}
|
||||
if (Date.now() >= session.expiresAt - (EXPIRES / 2)) {
|
||||
session.expiresAt = new Date(
|
||||
Date.now() + EXPIRES,
|
||||
).getTime();
|
||||
await db
|
||||
.update(olmSessions)
|
||||
.set({
|
||||
expiresAt: session.expiresAt,
|
||||
})
|
||||
.where(eq(olmSessions.sessionId, session.sessionId));
|
||||
}
|
||||
return { session, olm };
|
||||
}
|
||||
|
||||
export async function invalidateOlmSession(sessionId: string): Promise<void> {
|
||||
await db.delete(olmSessions).where(eq(olmSessions.sessionId, sessionId));
|
||||
}
|
||||
|
||||
export async function invalidateAllOlmSessions(olmId: string): Promise<void> {
|
||||
await db.delete(olmSessions).where(eq(olmSessions.olmId, olmId));
|
||||
}
|
||||
|
||||
export type SessionValidationResult =
|
||||
| { session: OlmSession; olm: Olm }
|
||||
| { session: null; olm: null };
|
|
@ -114,8 +114,8 @@ export const newts = sqliteTable("newt", {
|
|||
})
|
||||
});
|
||||
|
||||
export const clients = sqliteTable("clients", {
|
||||
clientId: text("id").primaryKey(),
|
||||
export const olms = sqliteTable("olms", {
|
||||
olmId: text("id").primaryKey(),
|
||||
secretHash: text("secretHash").notNull(),
|
||||
dateCreated: text("dateCreated").notNull(),
|
||||
siteId: integer("siteId").references(() => sites.siteId, {
|
||||
|
@ -156,9 +156,9 @@ export const newtSessions = sqliteTable("newtSession", {
|
|||
expiresAt: integer("expiresAt").notNull()
|
||||
});
|
||||
|
||||
export const clientSessions = sqliteTable("clientSession", {
|
||||
export const olmSessions = sqliteTable("clientSession", {
|
||||
sessionId: text("id").primaryKey(),
|
||||
clientId: text("clientId")
|
||||
olmId: text("olmId")
|
||||
.notNull()
|
||||
.references(() => newts.newtId, { onDelete: "cascade" }),
|
||||
expiresAt: integer("expiresAt").notNull()
|
||||
|
@ -425,8 +425,8 @@ export type Target = InferSelectModel<typeof targets>;
|
|||
export type Session = InferSelectModel<typeof sessions>;
|
||||
export type Newt = InferSelectModel<typeof newts>;
|
||||
export type NewtSession = InferSelectModel<typeof newtSessions>;
|
||||
export type Client = InferSelectModel<typeof clients>;
|
||||
export type ClientSession = InferSelectModel<typeof clientSessions>;
|
||||
export type Olm = InferSelectModel<typeof olms>;
|
||||
export type OlmSession = InferSelectModel<typeof olmSessions>;
|
||||
export type EmailVerificationCode = InferSelectModel<
|
||||
typeof emailVerificationCodes
|
||||
>;
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
export * from "./pickClientDefaults";
|
|
@ -7,7 +7,7 @@ import * as target from "./target";
|
|||
import * as user from "./user";
|
||||
import * as auth from "./auth";
|
||||
import * as role from "./role";
|
||||
import * as client from "./client";
|
||||
import * as olm from "./olm";
|
||||
import * as accessToken from "./accessToken";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import {
|
||||
|
@ -100,7 +100,7 @@ authenticated.get(
|
|||
"/site/:siteId/pick-client-defaults",
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.createClient),
|
||||
client.pickClientDefaults
|
||||
olm.pickOlmDefaults
|
||||
);
|
||||
|
||||
// authenticated.get(
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import { handleRegisterMessage } from "./newt";
|
||||
import { handleNewtRegisterMessage } from "./newt";
|
||||
import { handleOlmRegisterMessage } from "./olm";
|
||||
import { handleGetConfigMessage } from "./newt/handleGetConfigMessage";
|
||||
import { MessageHandler } from "./ws";
|
||||
|
||||
export const messageHandlers: Record<string, MessageHandler> = {
|
||||
"newt/wg/register": handleRegisterMessage,
|
||||
"newt/wg/register": handleNewtRegisterMessage,
|
||||
"olm/wg/register": handleOlmRegisterMessage,
|
||||
"newt/wg/get-config": handleGetConfigMessage,
|
||||
};
|
||||
|
|
|
@ -11,8 +11,10 @@ import { eq, and, sql } from "drizzle-orm";
|
|||
import { addPeer, deletePeer } from "../gerbil/peers";
|
||||
import logger from "@server/logger";
|
||||
|
||||
export const handleRegisterMessage: MessageHandler = async (context) => {
|
||||
const { message, newt, sendToClient } = context;
|
||||
export const handleNewtRegisterMessage: MessageHandler = async (context) => {
|
||||
const { message, client, sendToClient } = context;
|
||||
|
||||
const newt = client;
|
||||
|
||||
logger.info("Handling register message!");
|
||||
|
106
server/routers/olm/createOlm.ts
Normal file
106
server/routers/olm/createOlm.ts
Normal file
|
@ -0,0 +1,106 @@
|
|||
import { NextFunction, Request, Response } from "express";
|
||||
import db from "@server/db";
|
||||
import { hash } from "@node-rs/argon2";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { z } from "zod";
|
||||
import { newts } from "@server/db/schema";
|
||||
import createHttpError from "http-errors";
|
||||
import response from "@server/lib/response";
|
||||
import { SqliteError } from "better-sqlite3";
|
||||
import moment from "moment";
|
||||
import { generateSessionToken } from "@server/auth/sessions/app";
|
||||
import { createNewtSession } from "@server/auth/sessions/newt";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { hashPassword } from "@server/auth/password";
|
||||
|
||||
export const createNewtBodySchema = z.object({});
|
||||
|
||||
export type CreateNewtBody = z.infer<typeof createNewtBodySchema>;
|
||||
|
||||
export type CreateNewtResponse = {
|
||||
token: string;
|
||||
newtId: string;
|
||||
secret: string;
|
||||
};
|
||||
|
||||
const createNewtSchema = z
|
||||
.object({
|
||||
newtId: z.string(),
|
||||
secret: z.string()
|
||||
})
|
||||
.strict();
|
||||
|
||||
export async function createNewt(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
|
||||
const parsedBody = createNewtSchema.safeParse(req.body);
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { newtId, secret } = parsedBody.data;
|
||||
|
||||
if (!req.userOrgRoleId) {
|
||||
return next(
|
||||
createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
|
||||
);
|
||||
}
|
||||
|
||||
const secretHash = await hashPassword(secret);
|
||||
|
||||
await db.insert(newts).values({
|
||||
newtId: newtId,
|
||||
secretHash,
|
||||
dateCreated: moment().toISOString(),
|
||||
});
|
||||
|
||||
// give the newt their default permissions:
|
||||
// await db.insert(newtActions).values({
|
||||
// newtId: newtId,
|
||||
// actionId: ActionsEnum.createOrg,
|
||||
// orgId: null,
|
||||
// });
|
||||
|
||||
const token = generateSessionToken();
|
||||
await createNewtSession(token, newtId);
|
||||
|
||||
return response<CreateNewtResponse>(res, {
|
||||
data: {
|
||||
newtId,
|
||||
secret,
|
||||
token,
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Newt created successfully",
|
||||
status: HttpCode.OK,
|
||||
});
|
||||
} catch (e) {
|
||||
if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_UNIQUE") {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"A newt with that email address already exists"
|
||||
)
|
||||
);
|
||||
} else {
|
||||
console.error(e);
|
||||
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to create newt"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
115
server/routers/olm/getToken.ts
Normal file
115
server/routers/olm/getToken.ts
Normal file
|
@ -0,0 +1,115 @@
|
|||
import { generateSessionToken } from "@server/auth/sessions/app";
|
||||
import db from "@server/db";
|
||||
import { newts } from "@server/db/schema";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import response from "@server/lib/response";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import createHttpError from "http-errors";
|
||||
import { z } from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import {
|
||||
createNewtSession,
|
||||
validateNewtSessionToken
|
||||
} from "@server/auth/sessions/newt";
|
||||
import { verifyPassword } from "@server/auth/password";
|
||||
import logger from "@server/logger";
|
||||
import config from "@server/lib/config";
|
||||
|
||||
export const newtGetTokenBodySchema = z.object({
|
||||
newtId: z.string(),
|
||||
secret: z.string(),
|
||||
token: z.string().optional()
|
||||
});
|
||||
|
||||
export type NewtGetTokenBody = z.infer<typeof newtGetTokenBodySchema>;
|
||||
|
||||
export async function getToken(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
const parsedBody = newtGetTokenBodySchema.safeParse(req.body);
|
||||
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { newtId, secret, token } = parsedBody.data;
|
||||
|
||||
try {
|
||||
if (token) {
|
||||
const { session, newt } = await validateNewtSessionToken(token);
|
||||
if (session) {
|
||||
if (config.getRawConfig().app.log_failed_attempts) {
|
||||
logger.info(
|
||||
`Newt session already valid. Newt ID: ${newtId}. IP: ${req.ip}.`
|
||||
);
|
||||
}
|
||||
return response<null>(res, {
|
||||
data: null,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Token session already valid",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const existingNewtRes = await db
|
||||
.select()
|
||||
.from(newts)
|
||||
.where(eq(newts.newtId, newtId));
|
||||
if (!existingNewtRes || !existingNewtRes.length) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"No newt found with that newtId"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const existingNewt = existingNewtRes[0];
|
||||
|
||||
const validSecret = await verifyPassword(
|
||||
secret,
|
||||
existingNewt.secretHash
|
||||
);
|
||||
if (!validSecret) {
|
||||
if (config.getRawConfig().app.log_failed_attempts) {
|
||||
logger.info(
|
||||
`Newt id or secret is incorrect. Newt: ID ${newtId}. IP: ${req.ip}.`
|
||||
);
|
||||
}
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, "Secret is incorrect")
|
||||
);
|
||||
}
|
||||
|
||||
const resToken = generateSessionToken();
|
||||
await createNewtSession(resToken, existingNewt.newtId);
|
||||
|
||||
return response<{ token: string }>(res, {
|
||||
data: {
|
||||
token: resToken
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Token created successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to authenticate newt"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
147
server/routers/olm/handleGetConfigMessage.ts
Normal file
147
server/routers/olm/handleGetConfigMessage.ts
Normal file
|
@ -0,0 +1,147 @@
|
|||
import { z } from "zod";
|
||||
import { MessageHandler } from "../ws";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import db from "@server/db";
|
||||
import { olms, Site, sites } from "@server/db/schema";
|
||||
import { eq, isNotNull } from "drizzle-orm";
|
||||
import { findNextAvailableCidr } from "@server/lib/ip";
|
||||
import config from "@server/lib/config";
|
||||
|
||||
const inputSchema = z.object({
|
||||
publicKey: z.string(),
|
||||
endpoint: z.string(),
|
||||
listenPort: z.number()
|
||||
});
|
||||
|
||||
type Input = z.infer<typeof inputSchema>;
|
||||
|
||||
export const handleGetConfigMessage: MessageHandler = async (context) => {
|
||||
const { message, newt, sendToClient } = context;
|
||||
|
||||
logger.debug("Handling Newt get config message!");
|
||||
|
||||
if (!newt) {
|
||||
logger.warn("Newt not found");
|
||||
return;
|
||||
}
|
||||
|
||||
if (!newt.siteId) {
|
||||
logger.warn("Newt has no site!"); // TODO: Maybe we create the site here?
|
||||
return;
|
||||
}
|
||||
|
||||
const parsed = inputSchema.safeParse(message.data);
|
||||
if (!parsed.success) {
|
||||
logger.error(
|
||||
"handleGetConfigMessage: Invalid input: " +
|
||||
fromError(parsed.error).toString()
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const { publicKey, endpoint, listenPort } = message.data as Input;
|
||||
|
||||
const siteId = newt.siteId;
|
||||
|
||||
const [siteRes] = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(eq(sites.siteId, siteId));
|
||||
|
||||
if (!siteRes) {
|
||||
logger.warn("handleGetConfigMessage: Site not found");
|
||||
return;
|
||||
}
|
||||
|
||||
let site: Site | undefined;
|
||||
if (!site) {
|
||||
const address = await getNextAvailableSubnet();
|
||||
|
||||
// create a new exit node
|
||||
const [updateRes] = await db
|
||||
.update(sites)
|
||||
.set({
|
||||
publicKey,
|
||||
endpoint,
|
||||
address,
|
||||
listenPort
|
||||
})
|
||||
.where(eq(sites.siteId, siteId))
|
||||
.returning();
|
||||
|
||||
site = updateRes;
|
||||
|
||||
logger.info(`Updated site ${siteId} with new WG Newt info`);
|
||||
} else {
|
||||
site = siteRes;
|
||||
}
|
||||
|
||||
if (!site) {
|
||||
logger.error("handleGetConfigMessage: Failed to update site");
|
||||
return;
|
||||
}
|
||||
|
||||
const clientsRes = await db
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(eq(olms.siteId, siteId));
|
||||
|
||||
const peers = await Promise.all(
|
||||
clientsRes.map(async (client) => {
|
||||
return {
|
||||
publicKey: client.pubKey,
|
||||
allowedIps: "0.0.0.0/0"
|
||||
};
|
||||
})
|
||||
);
|
||||
|
||||
const configResponse = {
|
||||
listenPort: site.listenPort, // ?????
|
||||
// ipAddress: exitNode[0].address,
|
||||
peers
|
||||
};
|
||||
|
||||
logger.debug("Sending config: ", configResponse);
|
||||
|
||||
return {
|
||||
message: {
|
||||
type: "olm/wg/connect", // what to make the response type?
|
||||
data: {
|
||||
config: configResponse
|
||||
}
|
||||
},
|
||||
broadcast: false, // Send to all clients
|
||||
excludeSender: false // Include sender in broadcast
|
||||
};
|
||||
};
|
||||
|
||||
async function getNextAvailableSubnet(): Promise<string> {
|
||||
const existingAddresses = await db
|
||||
.select({
|
||||
address: sites.address
|
||||
})
|
||||
.from(sites)
|
||||
.where(isNotNull(sites.address));
|
||||
|
||||
const addresses = existingAddresses
|
||||
.map((a) => a.address)
|
||||
.filter((a) => a) as string[];
|
||||
|
||||
let subnet = findNextAvailableCidr(
|
||||
addresses,
|
||||
config.getRawConfig().wg_site.block_size,
|
||||
config.getRawConfig().wg_site.subnet_group
|
||||
);
|
||||
if (!subnet) {
|
||||
throw new Error("No available subnets remaining in space");
|
||||
}
|
||||
|
||||
// replace the last octet with 1
|
||||
subnet =
|
||||
subnet.split(".").slice(0, 3).join(".") +
|
||||
".1" +
|
||||
"/" +
|
||||
subnet.split("/")[1];
|
||||
return subnet;
|
||||
}
|
93
server/routers/olm/handleOlmRegisterMessage.ts
Normal file
93
server/routers/olm/handleOlmRegisterMessage.ts
Normal file
|
@ -0,0 +1,93 @@
|
|||
import db from "@server/db";
|
||||
import { MessageHandler } from "../ws";
|
||||
import {
|
||||
exitNodes,
|
||||
resources,
|
||||
sites,
|
||||
Target,
|
||||
targets
|
||||
} from "@server/db/schema";
|
||||
import { eq, and, sql } from "drizzle-orm";
|
||||
import { addPeer, deletePeer } from "../gerbil/peers";
|
||||
import logger from "@server/logger";
|
||||
|
||||
export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||
const { message, client, sendToClient } = context;
|
||||
|
||||
const olm = client;
|
||||
|
||||
logger.info("Handling register message!");
|
||||
|
||||
if (!olm) {
|
||||
logger.warn("Olm not found");
|
||||
return;
|
||||
}
|
||||
|
||||
if (!olm.siteId) {
|
||||
logger.warn("Olm has no site!"); // TODO: Maybe we create the site here?
|
||||
return;
|
||||
}
|
||||
|
||||
const siteId = olm.siteId;
|
||||
|
||||
const { publicKey } = message.data;
|
||||
if (!publicKey) {
|
||||
logger.warn("Public key not provided");
|
||||
return;
|
||||
}
|
||||
|
||||
const [site] = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(eq(sites.siteId, siteId))
|
||||
.limit(1);
|
||||
|
||||
if (!site || !site.exitNodeId) {
|
||||
logger.warn("Site not found or does not have exit node");
|
||||
return;
|
||||
}
|
||||
|
||||
await db
|
||||
.update(sites)
|
||||
.set({
|
||||
pubKey: publicKey
|
||||
})
|
||||
.where(eq(sites.siteId, siteId))
|
||||
.returning();
|
||||
|
||||
const [exitNode] = await db
|
||||
.select()
|
||||
.from(exitNodes)
|
||||
.where(eq(exitNodes.exitNodeId, site.exitNodeId))
|
||||
.limit(1);
|
||||
|
||||
if (site.pubKey && site.pubKey !== publicKey) {
|
||||
logger.info("Public key mismatch. Deleting old peer...");
|
||||
await deletePeer(site.exitNodeId, site.pubKey);
|
||||
}
|
||||
|
||||
if (!site.subnet) {
|
||||
logger.warn("Site has no subnet");
|
||||
return;
|
||||
}
|
||||
|
||||
// add the peer to the exit node
|
||||
await addPeer(site.exitNodeId, {
|
||||
publicKey: publicKey,
|
||||
allowedIps: [site.subnet]
|
||||
});
|
||||
|
||||
return {
|
||||
message: {
|
||||
type: "olm/wg/connect",
|
||||
data: {
|
||||
endpoint: `${exitNode.endpoint}:${exitNode.listenPort}`,
|
||||
publicKey: exitNode.publicKey,
|
||||
serverIP: exitNode.address.split("/")[0],
|
||||
tunnelIP: site.subnet.split("/")[0]
|
||||
}
|
||||
},
|
||||
broadcast: false, // Send to all olms
|
||||
excludeSender: false // Include sender in broadcast
|
||||
};
|
||||
};
|
1
server/routers/olm/index.ts
Normal file
1
server/routers/olm/index.ts
Normal file
|
@ -0,0 +1 @@
|
|||
export * from "./pickOlmDefaults";
|
|
@ -1,6 +1,6 @@
|
|||
import { Request, Response, NextFunction } from "express";
|
||||
import { db } from "@server/db";
|
||||
import { clients, sites } from "@server/db/schema";
|
||||
import { olms, sites } from "@server/db/schema";
|
||||
import { eq } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
|
@ -30,7 +30,7 @@ export type PickClientDefaultsResponse = {
|
|||
clientSecret: string;
|
||||
};
|
||||
|
||||
export async function pickClientDefaults(
|
||||
export async function pickOlmDefaults(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
|
@ -71,10 +71,10 @@ export async function pickClientDefaults(
|
|||
|
||||
const clientsQuery = await db
|
||||
.select({
|
||||
subnet: clients.subnet
|
||||
subnet: olms.subnet
|
||||
})
|
||||
.from(clients)
|
||||
.where(eq(clients.siteId, site.siteId));
|
||||
.from(olms)
|
||||
.where(eq(olms.siteId, site.siteId));
|
||||
|
||||
let subnets = clientsQuery.map((client) => client.subnet);
|
||||
|
|
@ -3,10 +3,11 @@ import { Server as HttpServer } from "http";
|
|||
import { WebSocket, WebSocketServer } from "ws";
|
||||
import { IncomingMessage } from "http";
|
||||
import { Socket } from "net";
|
||||
import { Newt, newts, NewtSession } from "@server/db/schema";
|
||||
import { Newt, newts, NewtSession, Olm, olms, OlmSession } from "@server/db/schema";
|
||||
import { eq } from "drizzle-orm";
|
||||
import db from "@server/db";
|
||||
import { validateNewtSessionToken } from "@server/auth/sessions/newt";
|
||||
import { validateOlmSessionToken } from "@server/auth/sessions/olm";
|
||||
import { messageHandlers } from "./messageHandlers";
|
||||
import logger from "@server/logger";
|
||||
|
||||
|
@ -15,13 +16,17 @@ interface WebSocketRequest extends IncomingMessage {
|
|||
token?: string;
|
||||
}
|
||||
|
||||
type ClientType = 'newt' | 'olm';
|
||||
|
||||
interface AuthenticatedWebSocket extends WebSocket {
|
||||
newt?: Newt;
|
||||
client?: Newt | Olm;
|
||||
clientType?: ClientType;
|
||||
}
|
||||
|
||||
interface TokenPayload {
|
||||
newt: Newt;
|
||||
session: NewtSession;
|
||||
client: Newt | Olm;
|
||||
session: NewtSession | OlmSession;
|
||||
clientType: ClientType;
|
||||
}
|
||||
|
||||
interface WSMessage {
|
||||
|
@ -33,15 +38,16 @@ interface HandlerResponse {
|
|||
message: WSMessage;
|
||||
broadcast?: boolean;
|
||||
excludeSender?: boolean;
|
||||
targetNewtId?: string;
|
||||
targetClientId?: string;
|
||||
}
|
||||
|
||||
interface HandlerContext {
|
||||
message: WSMessage;
|
||||
senderWs: WebSocket;
|
||||
newt: Newt | undefined;
|
||||
sendToClient: (newtId: string, message: WSMessage) => boolean;
|
||||
broadcastToAllExcept: (message: WSMessage, excludeNewtId?: string) => void;
|
||||
client: Newt | Olm | undefined;
|
||||
clientType: ClientType;
|
||||
sendToClient: (clientId: string, message: WSMessage) => boolean;
|
||||
broadcastToAllExcept: (message: WSMessage, excludeClientId?: string) => void;
|
||||
connectedClients: Map<string, WebSocket[]>;
|
||||
}
|
||||
|
||||
|
@ -54,34 +60,32 @@ const wss: WebSocketServer = new WebSocketServer({ noServer: true });
|
|||
let connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map();
|
||||
|
||||
// Helper functions for client management
|
||||
const addClient = (newtId: string, ws: AuthenticatedWebSocket): void => {
|
||||
const existingClients = connectedClients.get(newtId) || [];
|
||||
const addClient = (clientId: string, ws: AuthenticatedWebSocket, clientType: ClientType): void => {
|
||||
const existingClients = connectedClients.get(clientId) || [];
|
||||
existingClients.push(ws);
|
||||
connectedClients.set(newtId, existingClients);
|
||||
logger.info(`Client added to tracking - Newt ID: ${newtId}, Total connections: ${existingClients.length}`);
|
||||
connectedClients.set(clientId, existingClients);
|
||||
logger.info(`Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Total connections: ${existingClients.length}`);
|
||||
};
|
||||
|
||||
const removeClient = (newtId: string, ws: AuthenticatedWebSocket): void => {
|
||||
const existingClients = connectedClients.get(newtId) || [];
|
||||
const removeClient = (clientId: string, ws: AuthenticatedWebSocket, clientType: ClientType): void => {
|
||||
const existingClients = connectedClients.get(clientId) || [];
|
||||
const updatedClients = existingClients.filter(client => client !== ws);
|
||||
|
||||
if (updatedClients.length === 0) {
|
||||
connectedClients.delete(newtId);
|
||||
logger.info(`All connections removed for Newt ID: ${newtId}`);
|
||||
connectedClients.delete(clientId);
|
||||
logger.info(`All connections removed for ${clientType.toUpperCase()} ID: ${clientId}`);
|
||||
} else {
|
||||
connectedClients.set(newtId, updatedClients);
|
||||
logger.info(`Connection removed - Newt ID: ${newtId}, Remaining connections: ${updatedClients.length}`);
|
||||
connectedClients.set(clientId, updatedClients);
|
||||
logger.info(`Connection removed - ${clientType.toUpperCase()} ID: ${clientId}, Remaining connections: ${updatedClients.length}`);
|
||||
}
|
||||
};
|
||||
|
||||
// Helper functions for sending messages
|
||||
const sendToClient = (newtId: string, message: WSMessage): boolean => {
|
||||
const clients = connectedClients.get(newtId);
|
||||
const sendToClient = (clientId: string, message: WSMessage): boolean => {
|
||||
const clients = connectedClients.get(clientId);
|
||||
if (!clients || clients.length === 0) {
|
||||
logger.info(`No active connections found for Newt ID: ${newtId}`);
|
||||
logger.info(`No active connections found for Client ID: ${clientId}`);
|
||||
return false;
|
||||
}
|
||||
|
||||
const messageString = JSON.stringify(message);
|
||||
clients.forEach(client => {
|
||||
if (client.readyState === WebSocket.OPEN) {
|
||||
|
@ -91,9 +95,9 @@ const sendToClient = (newtId: string, message: WSMessage): boolean => {
|
|||
return true;
|
||||
};
|
||||
|
||||
const broadcastToAllExcept = (message: WSMessage, excludeNewtId?: string): void => {
|
||||
connectedClients.forEach((clients, newtId) => {
|
||||
if (newtId !== excludeNewtId) {
|
||||
const broadcastToAllExcept = (message: WSMessage, excludeClientId?: string): void => {
|
||||
connectedClients.forEach((clients, clientId) => {
|
||||
if (clientId !== excludeClientId) {
|
||||
clients.forEach(client => {
|
||||
if (client.readyState === WebSocket.OPEN) {
|
||||
client.send(JSON.stringify(message));
|
||||
|
@ -103,84 +107,88 @@ const broadcastToAllExcept = (message: WSMessage, excludeNewtId?: string): void
|
|||
});
|
||||
};
|
||||
|
||||
// Token verification middleware (unchanged)
|
||||
const verifyToken = async (token: string): Promise<TokenPayload | null> => {
|
||||
// Token verification middleware
|
||||
const verifyToken = async (token: string, clientType: ClientType): Promise<TokenPayload | null> => {
|
||||
try {
|
||||
const { session, newt } = await validateNewtSessionToken(token);
|
||||
|
||||
if (!session || !newt) {
|
||||
return null;
|
||||
if (clientType === 'newt') {
|
||||
const { session, newt } = await validateNewtSessionToken(token);
|
||||
if (!session || !newt) {
|
||||
return null;
|
||||
}
|
||||
const existingNewt = await db
|
||||
.select()
|
||||
.from(newts)
|
||||
.where(eq(newts.newtId, newt.newtId));
|
||||
if (!existingNewt || !existingNewt[0]) {
|
||||
return null;
|
||||
}
|
||||
return { client: existingNewt[0], session, clientType };
|
||||
} else {
|
||||
const { session, olm } = await validateOlmSessionToken(token);
|
||||
if (!session || !olm) {
|
||||
return null;
|
||||
}
|
||||
const existingOlm = await db
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(eq(olms.olmId, olm.olmId));
|
||||
if (!existingOlm || !existingOlm[0]) {
|
||||
return null;
|
||||
}
|
||||
return { client: existingOlm[0], session, clientType };
|
||||
}
|
||||
|
||||
const existingNewt = await db
|
||||
.select()
|
||||
.from(newts)
|
||||
.where(eq(newts.newtId, newt.newtId));
|
||||
|
||||
if (!existingNewt || !existingNewt[0]) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return { newt: existingNewt[0], session };
|
||||
} catch (error) {
|
||||
logger.error("Token verification failed:", error);
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
const setupConnection = (ws: AuthenticatedWebSocket, newt: Newt): void => {
|
||||
const setupConnection = (ws: AuthenticatedWebSocket, client: Newt | Olm, clientType: ClientType): void => {
|
||||
logger.info("Establishing websocket connection");
|
||||
|
||||
if (!newt) {
|
||||
logger.error("Connection attempt without newt");
|
||||
if (!client) {
|
||||
logger.error("Connection attempt without client");
|
||||
return ws.terminate();
|
||||
}
|
||||
|
||||
ws.newt = newt;
|
||||
ws.client = client;
|
||||
ws.clientType = clientType;
|
||||
|
||||
// Add client to tracking
|
||||
addClient(newt.newtId, ws);
|
||||
const clientId = clientType === 'newt' ? (client as Newt).newtId : (client as Olm).olmId;
|
||||
addClient(clientId, ws, clientType);
|
||||
|
||||
ws.on("message", async (data) => {
|
||||
try {
|
||||
const message: WSMessage = JSON.parse(data.toString());
|
||||
// logger.info(`Message received from Newt ID ${newtId}:`, message);
|
||||
|
||||
// Validate message format
|
||||
if (!message.type || typeof message.type !== "string") {
|
||||
throw new Error("Invalid message format: missing or invalid type");
|
||||
}
|
||||
|
||||
// Get the appropriate handler for the message type
|
||||
const handler = messageHandlers[message.type];
|
||||
if (!handler) {
|
||||
throw new Error(`Unsupported message type: ${message.type}`);
|
||||
}
|
||||
|
||||
// Process the message and get response
|
||||
const response = await handler({
|
||||
message,
|
||||
senderWs: ws,
|
||||
newt: ws.newt,
|
||||
client: ws.client,
|
||||
clientType: ws.clientType!,
|
||||
sendToClient,
|
||||
broadcastToAllExcept,
|
||||
connectedClients
|
||||
});
|
||||
|
||||
// Send response if one was returned
|
||||
if (response) {
|
||||
if (response.broadcast) {
|
||||
// Broadcast to all clients except sender if specified
|
||||
broadcastToAllExcept(response.message, response.excludeSender ? newt.newtId : undefined);
|
||||
} else if (response.targetNewtId) {
|
||||
// Send to specific client if targetNewtId is provided
|
||||
sendToClient(response.targetNewtId, response.message);
|
||||
broadcastToAllExcept(response.message, response.excludeSender ? clientId : undefined);
|
||||
} else if (response.targetClientId) {
|
||||
sendToClient(response.targetClientId, response.message);
|
||||
} else {
|
||||
// Send back to sender
|
||||
ws.send(JSON.stringify(response.message));
|
||||
}
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
logger.error("Message handling error:", error);
|
||||
ws.send(JSON.stringify({
|
||||
|
@ -194,18 +202,18 @@ const setupConnection = (ws: AuthenticatedWebSocket, newt: Newt): void => {
|
|||
});
|
||||
|
||||
ws.on("close", () => {
|
||||
removeClient(newt.newtId, ws);
|
||||
logger.info(`Client disconnected - Newt ID: ${newt.newtId}`);
|
||||
removeClient(clientId, ws, clientType);
|
||||
logger.info(`Client disconnected - ${clientType.toUpperCase()} ID: ${clientId}`);
|
||||
});
|
||||
|
||||
ws.on("error", (error: Error) => {
|
||||
logger.error(`WebSocket error for Newt ID ${newt.newtId}:`, error);
|
||||
logger.error(`WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`, error);
|
||||
});
|
||||
|
||||
logger.info(`WebSocket connection established - Newt ID: ${newt.newtId}`);
|
||||
logger.info(`WebSocket connection established - ${clientType.toUpperCase()} ID: ${clientId}`);
|
||||
};
|
||||
|
||||
// Router endpoint (unchanged)
|
||||
// Router endpoint
|
||||
router.get("/ws", (req: Request, res: Response) => {
|
||||
res.status(200).send("WebSocket endpoint");
|
||||
});
|
||||
|
@ -214,18 +222,22 @@ router.get("/ws", (req: Request, res: Response) => {
|
|||
const handleWSUpgrade = (server: HttpServer): void => {
|
||||
server.on("upgrade", async (request: WebSocketRequest, socket: Socket, head: Buffer) => {
|
||||
try {
|
||||
const token = request.url?.includes("?")
|
||||
? new URLSearchParams(request.url.split("?")[1]).get("token") || ""
|
||||
: request.headers["sec-websocket-protocol"];
|
||||
const url = new URL(request.url || '', `http://${request.headers.host}`);
|
||||
const token = url.searchParams.get('token') || request.headers["sec-websocket-protocol"] || '';
|
||||
let clientType = url.searchParams.get('clientType') as ClientType;
|
||||
|
||||
if (!token) {
|
||||
logger.warn("Unauthorized connection attempt: no token...");
|
||||
if (!clientType) {
|
||||
clientType = "newt";
|
||||
}
|
||||
|
||||
if (!token || !clientType || !['newt', 'olm'].includes(clientType)) {
|
||||
logger.warn("Unauthorized connection attempt: invalid token or client type...");
|
||||
socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n");
|
||||
socket.destroy();
|
||||
return;
|
||||
}
|
||||
|
||||
const tokenPayload = await verifyToken(token);
|
||||
const tokenPayload = await verifyToken(token, clientType);
|
||||
if (!tokenPayload) {
|
||||
logger.warn("Unauthorized connection attempt: invalid token...");
|
||||
socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n");
|
||||
|
@ -234,7 +246,7 @@ const handleWSUpgrade = (server: HttpServer): void => {
|
|||
}
|
||||
|
||||
wss.handleUpgrade(request, socket, head, (ws: AuthenticatedWebSocket) => {
|
||||
setupConnection(ws, tokenPayload.newt);
|
||||
setupConnection(ws, tokenPayload.client, tokenPayload.clientType);
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error("WebSocket upgrade error:", error);
|
||||
|
@ -250,4 +262,4 @@ export {
|
|||
sendToClient,
|
||||
broadcastToAllExcept,
|
||||
connectedClients
|
||||
};
|
||||
};
|
Loading…
Add table
Add a link
Reference in a new issue