diff --git a/server/db/schema.ts b/server/db/schema.ts index 875754f5..002fc442 100644 --- a/server/db/schema.ts +++ b/server/db/schema.ts @@ -143,6 +143,9 @@ export const clients = sqliteTable("clients", { onDelete: "cascade" }) .notNull(), + exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, { + onDelete: "set null" + }), name: text("name").notNull(), pubKey: text("pubKey"), subnet: text("subnet").notNull(), diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index ac658b78..88c5ba4c 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -1,6 +1,13 @@ import db from "@server/db"; import { MessageHandler } from "../ws"; -import { clients, clientSites, exitNodes, Olm, olms, sites } from "@server/db/schema"; +import { + clients, + clientSites, + exitNodes, + Olm, + olms, + sites +} from "@server/db/schema"; import { eq, inArray } from "drizzle-orm"; import { addPeer, deletePeer } from "../newt/peers"; import logger from "@server/logger"; @@ -23,30 +30,36 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { logger.warn("Public key not provided"); return; } - + // Get the client const [client] = await db .select() .from(clients) .where(eq(clients.clientId, clientId)) .limit(1); - + if (!client) { logger.warn("Client not found"); return; } + + if (client.exitNodeId) { + // Get the exit node for this site + const [exitNode] = await db + .select() + .from(exitNodes) + .where(eq(exitNodes.exitNodeId, client.exitNodeId)) + .limit(1); - // Get all site associations for this client - const clientSiteAssociations = await db - .select() - .from(clientSites) - .where(eq(clientSites.clientId, clientId)); - - if (clientSiteAssociations.length === 0) { - logger.warn("Client is not associated with any sites"); - return; + // Send holepunch message for each site + sendToClient(olm.olmId, { + type: "olm/wg/holepunch", + data: { + serverPubKey: exitNode.publicKey + } + }); } - + // Update the client's public key await db .update(clients) @@ -55,103 +68,97 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { }) .where(eq(clients.clientId, olm.clientId)) .returning(); - + // Check if public key changed and handle old peer deletion later const pubKeyChanged = client.pubKey && client.pubKey !== publicKey; - + // Get all sites data - const siteIds = clientSiteAssociations.map(cs => cs.siteId); const sitesData = await db .select() .from(sites) - .where(inArray(sites.siteId, siteIds)); - + .innerJoin(clientSites, eq(sites.siteId, clientSites.siteId)) + .where(eq(clientSites.clientId, client.clientId)); + // Prepare an array to store site configurations const siteConfigurations = []; const now = new Date().getTime() / 1000; - + // Process each site - for (const site of sitesData) { + for (const { sites: site } of sitesData) { if (!site.exitNodeId) { - logger.warn(`Site ${site.siteId} does not have exit node, skipping`); + logger.warn( + `Site ${site.siteId} does not have exit node, skipping` + ); continue; } - - // Get the exit node for this site - const [exitNode] = await db - .select() - .from(exitNodes) - .where(eq(exitNodes.exitNodeId, site.exitNodeId)) - .limit(1); - + // Validate endpoint and hole punch status if (!site.endpoint) { logger.warn(`Site ${site.siteId} has no endpoint, skipping`); continue; } - + if (site.lastHolePunch && now - site.lastHolePunch > 6) { - logger.warn(`Site ${site.siteId} last hole punch is too old, skipping`); + logger.warn( + `Site ${site.siteId} last hole punch is too old, skipping` + ); continue; } - + if (client.lastHolePunch && now - client.lastHolePunch > 6) { - logger.warn("Client last hole punch is too old, skipping all sites"); + logger.warn( + "Client last hole punch is too old, skipping all sites" + ); break; } - + // If public key changed, delete old peer from this site if (pubKeyChanged) { - logger.info(`Public key mismatch. Deleting old peer from site ${site.siteId}...`); - await deletePeer(site.siteId, client.pubKey); + logger.info( + `Public key mismatch. Deleting old peer from site ${site.siteId}...` + ); + await deletePeer(site.siteId, client.pubKey!); } - + if (!site.subnet) { logger.warn(`Site ${site.siteId} has no subnet, skipping`); continue; } - + // Add the peer to the exit node for this site - await addPeer(site.siteId, { - publicKey: publicKey, - allowedIps: [client.subnet], - endpoint: client.endpoint - }); - + if (client.endpoint) { + await addPeer(site.siteId, { + publicKey: publicKey, + allowedIps: [client.subnet], + endpoint: client.endpoint + }); + } + // Add site configuration to the array siteConfigurations.push({ siteId: site.siteId, endpoint: site.endpoint, publicKey: site.publicKey, - serverIP: site.address, - }); - - // Send holepunch message for each site - sendToClient(olm.olmId, { - type: "olm/wg/holepunch", - data: { - serverPubKey: exitNode.publicKey, - siteId: site.siteId - } + serverIP: site.address }); } - + // If we have no valid site configurations, don't send a connect message if (siteConfigurations.length === 0) { logger.warn("No valid site configurations found"); return; } - + // Return connect message with all site configurations return { message: { type: "olm/wg/connect", data: { sites: siteConfigurations, - tunnelIP: client.subnet, + tunnelIP: client.subnet } }, broadcast: false, excludeSender: false }; -}; \ No newline at end of file +};