diff --git a/server/db/schema.ts b/server/db/schema.ts index 92710278..05825c3f 100644 --- a/server/db/schema.ts +++ b/server/db/schema.ts @@ -157,7 +157,7 @@ export const clients = sqliteTable("clients", { type: text("type").notNull(), // "olm" online: integer("online", { mode: "boolean" }).notNull().default(false), endpoint: text("endpoint"), - lastHolePunch: integer("lastHolePunch"), + lastHolePunch: integer("lastHolePunch") }); export const clientSites = sqliteTable("clientSites", { @@ -167,6 +167,7 @@ export const clientSites = sqliteTable("clientSites", { siteId: integer("siteId") .notNull() .references(() => sites.siteId, { onDelete: "cascade" }), + isRelayed: integer("isRelayed", { mode: "boolean" }).notNull().default(false) }); export const olms = sqliteTable("olms", { diff --git a/server/routers/messageHandlers.ts b/server/routers/messageHandlers.ts index 8fb240c1..074cd4e2 100644 --- a/server/routers/messageHandlers.ts +++ b/server/routers/messageHandlers.ts @@ -1,5 +1,5 @@ import { handleNewtRegisterMessage, handleReceiveBandwidthMessage, handleGetConfigMessage } from "./newt"; -import { handleOlmRegisterMessage, handleOlmRelayMessage, handleOlmPingMessage } from "./olm"; +import { handleOlmRegisterMessage, handleOlmRelayMessage, handleOlmPingMessage, startOfflineChecker } from "./olm"; import { MessageHandler } from "./ws"; export const messageHandlers: Record = { @@ -10,3 +10,5 @@ export const messageHandlers: Record = { "olm/wg/relay": handleOlmRelayMessage, "olm/ping": handleOlmPingMessage }; + +startOfflineChecker(); // this is to handle the offline check for olms \ No newline at end of file diff --git a/server/routers/newt/handleGetConfigMessage.ts b/server/routers/newt/handleGetConfigMessage.ts index 19e91f8a..6c4ace06 100644 --- a/server/routers/newt/handleGetConfigMessage.ts +++ b/server/routers/newt/handleGetConfigMessage.ts @@ -3,14 +3,22 @@ import { MessageHandler } from "../ws"; import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import db from "@server/db"; -import { clients, clientSites, Newt, Site, sites } from "@server/db/schema"; +import { + clients, + clientSites, + Newt, + Site, + sites, + olms +} from "@server/db/schema"; import { eq } from "drizzle-orm"; import { getNextAvailableClientSubnet } from "@server/lib/ip"; import config from "@server/lib/config"; +import { addPeer } from "../olm/peers"; const inputSchema = z.object({ publicKey: z.string(), - port: z.number().int().positive(), + port: z.number().int().positive() }); type Input = z.infer; @@ -43,42 +51,42 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { } const { publicKey, port } = message.data as Input; - const siteId = newt.siteId; - const [siteRes] = await db + // Get the current site data + const [existingSite] = await db .select() .from(sites) .where(eq(sites.siteId, siteId)); - if (!siteRes) { + if (!existingSite) { logger.warn("handleGetConfigMessage: Site not found"); return; } let site: Site | undefined; - if (!siteRes.address) { - let address = await getNextAvailableClientSubnet(siteRes.orgId); + if (!existingSite.address) { + // This is a new site configuration + let address = await getNextAvailableClientSubnet(existingSite.orgId); if (!address) { logger.error("handleGetConfigMessage: No available address"); return; } - address = `${address.split("/")[0]}/${config.getRawConfig().orgs.block_size}` // we want the block size of the whole org + address = `${address.split("/")[0]}/${config.getRawConfig().orgs.block_size}`; // we want the block size of the whole org - // create a new exit node + // Update the site with new WireGuard info const [updateRes] = await db .update(sites) .set({ publicKey, address, - listenPort: port, + listenPort: port }) .where(eq(sites.siteId, siteId)) .returning(); site = updateRes; - logger.info(`Updated site ${siteId} with new WG Newt info`); } else { // update the endpoint and the public key @@ -86,7 +94,7 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { .update(sites) .set({ publicKey, - listenPort: port, + listenPort: port }) .where(eq(sites.siteId, siteId)) .returning(); @@ -99,12 +107,14 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { return; } + // Get all clients connected to this site const clientsRes = await db .select() .from(clients) .innerJoin(clientSites, eq(clients.clientId, clientSites.clientId)) .where(eq(clientSites.siteId, siteId)); + // Prepare peers data for the response const peers = await Promise.all( clientsRes .filter((client) => { @@ -124,29 +134,49 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { return true; }) .map(async (client) => { - return { - publicKey: client.clients.pubKey, - allowedIps: [client.clients.subnet], - endpoint: client.clients.endpoint + const peerData = { + publicKey: client.clients.pubKey!, + allowedIps: [client.clients.subnet!], + endpoint: client.clientSites.isRelayed ? "" : client.clients.endpoint! // if its relayed it should be localhost }; + + // Add or update this peer on the olm if it is connected + try { + await addPeer(client.clients.clientId, { + ...peerData, + siteId: siteId, + serverIP: site.address, + serverPort: site.listenPort + }); + } catch (error) { + logger.error( + `Failed to add/update peer ${client.clients.pubKey} to newt ${newt.newtId}: ${error}` + ); + } + + return peerData; }) ); + // Filter out any null values from peers that didn't have an olm + const validPeers = peers.filter((peer) => peer !== null); + + // Build the configuration response const configResponse = { ipAddress: site.address, - peers + peers: validPeers }; logger.debug("Sending config: ", configResponse); return { message: { - type: "newt/wg/receive-config", // what to make the response type? + type: "newt/wg/receive-config", data: { ...configResponse } }, - broadcast: false, // Send to all clients - excludeSender: false // Include sender in broadcast + broadcast: false, + excludeSender: false }; -}; \ No newline at end of file +}; diff --git a/server/routers/newt/peers.ts b/server/routers/newt/peers.ts index 99aacf0d..f5e6c518 100644 --- a/server/routers/newt/peers.ts +++ b/server/routers/newt/peers.ts @@ -18,7 +18,7 @@ export async function addPeer(siteId: number, peer: { // get the newt on the site const [newt] = await db.select().from(newts).where(eq(newts.siteId, siteId)).limit(1); if (!newt) { - throw new Error(`Newt not found for site ${siteId}`); + throw new Error(`Site found for site ${siteId}`); } sendToClient(newt.newtId, { @@ -32,7 +32,7 @@ export async function addPeer(siteId: number, peer: { export async function deletePeer(siteId: number, publicKey: string) { const [site] = await db.select().from(sites).where(eq(sites.siteId, siteId)).limit(1); if (!site) { - throw new Error(`Exit node with ID ${siteId} not found`); + throw new Error(`Site with ID ${siteId} not found`); } // get the newt on the site @@ -57,7 +57,7 @@ export async function updatePeer(siteId: number, publicKey: string, peer: { }) { const [site] = await db.select().from(sites).where(eq(sites.siteId, siteId)).limit(1); if (!site) { - throw new Error(`Exit node with ID ${siteId} not found`); + throw new Error(`Site with ID ${siteId} not found`); } // get the newt on the site diff --git a/server/routers/olm/handleOlmPingMessage.ts b/server/routers/olm/handleOlmPingMessage.ts index 10067832..c958c38d 100644 --- a/server/routers/olm/handleOlmPingMessage.ts +++ b/server/routers/olm/handleOlmPingMessage.ts @@ -3,7 +3,6 @@ import { MessageHandler } from "../ws"; import { clients, Olm } from "@server/db/schema"; import { eq, lt, isNull } from "drizzle-orm"; import logger from "@server/logger"; -import { time } from "console"; // Track if the offline checker interval is running let offlineCheckerInterval: NodeJS.Timeout | null = null; diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index 1ada2eba..a398d5e4 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -68,13 +68,26 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { return; } - // Update the client's public key - await db - .update(clients) - .set({ - pubKey: publicKey - }) - .where(eq(clients.clientId, olm.clientId)); + if (client.pubKey !== publicKey) { + logger.info( + "Public key mismatch. Updating public key and clearing session info..." + ); + // Update the client's public key + await db + .update(clients) + .set({ + pubKey: publicKey + }) + .where(eq(clients.clientId, olm.clientId)); + + // set isRelay to false for all of the client's sites to reset the connection metadata + await db + .update(clientSites) + .set({ + isRelayed: false + }) + .where(eq(clientSites.clientId, olm.clientId)); + } // Get all sites data const sitesData = await db @@ -143,7 +156,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { endpoint: site.endpoint, publicKey: site.publicKey, serverIP: site.address, - serverPort: site.listenPort, + serverPort: site.listenPort }); } diff --git a/server/routers/olm/peers.ts b/server/routers/olm/peers.ts new file mode 100644 index 00000000..90c91c07 --- /dev/null +++ b/server/routers/olm/peers.ts @@ -0,0 +1,70 @@ +import db from '@server/db'; +import { clients, olms, newts } from '@server/db/schema'; +import { eq } from 'drizzle-orm'; +import { sendToClient } from '../ws'; +import logger from '@server/logger'; + +export async function addPeer(clientId: number, peer: { + siteId: number, + publicKey: string; + allowedIps: string[]; + endpoint: string; + serverIP: string | null; + serverPort: number | null; +}) { + const [olm] = await db.select().from(olms).where(eq(olms.clientId, clientId)).limit(1); + if (!olm) { + throw new Error(`Olm with ID ${clientId} not found`); + } + + sendToClient(olm.olmId, { + type: 'olm/wg/peer/add', + data: { + publicKey: peer.publicKey, + allowedIps: peer.allowedIps, + endpoint: peer.endpoint, + serverIP: peer.serverIP, + serverPort: peer.serverPort + } + }); + + logger.info(`Added peer ${peer.publicKey} to olm ${olm.olmId}`); +} + +export async function deletePeer(clientId: number, publicKey: string) { + const [olm] = await db.select().from(olms).where(eq(olms.clientId, clientId)).limit(1); + if (!olm) { + throw new Error(`Olm with ID ${clientId} not found`); + } + + sendToClient(olm.olmId, { + type: 'olm/wg/peer/remove', + data: { + publicKey + } + }); + + logger.info(`Deleted peer ${publicKey} from olm ${olm.olmId}`); +} + +export async function updatePeer(clientId: number, publicKey: string, peer: { + allowedIps?: string[]; + endpoint?: string; + serverIP?: string; + serverPort?: number; +}) { + const [olm] = await db.select().from(olms).where(eq(olms.clientId, clientId)).limit(1); + if (!olm) { + throw new Error(`Olm with ID ${clientId} not found`); + } + + sendToClient(olm.olmId, { + type: 'olm/wg/peer/update', + data: { + publicKey, + ...peer + } + }); + + logger.info(`Updated peer ${publicKey} on olm ${olm.olmId}`); +} \ No newline at end of file