From 1a9de1e5c5c85f95e4770a4b6f15281decf8847e Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 4 Aug 2025 20:17:35 -0700 Subject: [PATCH] Move endpoint to per site --- server/db/pg/schema.ts | 5 +- server/db/sqlite/schema.ts | 5 +- server/routers/client/updateClient.ts | 37 ++- server/routers/gerbil/getAllRelays.ts | 16 +- server/routers/gerbil/updateHolePunch.ts | 231 ++++++++++++------ server/routers/newt/handleGetConfigMessage.ts | 5 +- .../routers/olm/handleOlmRegisterMessage.ts | 17 +- 7 files changed, 208 insertions(+), 108 deletions(-) diff --git a/server/db/pg/schema.ts b/server/db/pg/schema.ts index be4e58e2..d307f399 100644 --- a/server/db/pg/schema.ts +++ b/server/db/pg/schema.ts @@ -516,7 +516,7 @@ export const clients = pgTable("clients", { lastPing: varchar("lastPing"), type: varchar("type").notNull(), // "olm" online: boolean("online").notNull().default(false), - endpoint: varchar("endpoint"), + // endpoint: varchar("endpoint"), lastHolePunch: integer("lastHolePunch"), maxConnections: integer("maxConnections") }); @@ -528,7 +528,8 @@ export const clientSites = pgTable("clientSites", { siteId: integer("siteId") .notNull() .references(() => sites.siteId, { onDelete: "cascade" }), - isRelayed: boolean("isRelayed").notNull().default(false) + isRelayed: boolean("isRelayed").notNull().default(false), + endpoint: varchar("endpoint") }); export const olms = pgTable("olms", { diff --git a/server/db/sqlite/schema.ts b/server/db/sqlite/schema.ts index 5773a5f3..10f6686e 100644 --- a/server/db/sqlite/schema.ts +++ b/server/db/sqlite/schema.ts @@ -216,7 +216,7 @@ export const clients = sqliteTable("clients", { lastPing: text("lastPing"), type: text("type").notNull(), // "olm" online: integer("online", { mode: "boolean" }).notNull().default(false), - endpoint: text("endpoint"), + // endpoint: text("endpoint"), lastHolePunch: integer("lastHolePunch") }); @@ -227,7 +227,8 @@ export const clientSites = sqliteTable("clientSites", { siteId: integer("siteId") .notNull() .references(() => sites.siteId, { onDelete: "cascade" }), - isRelayed: integer("isRelayed", { mode: "boolean" }).notNull().default(false) + isRelayed: integer("isRelayed", { mode: "boolean" }).notNull().default(false), + endpoint: text("endpoint") }); export const olms = sqliteTable("olms", { diff --git a/server/routers/client/updateClient.ts b/server/routers/client/updateClient.ts index 60a48732..de4a7b5e 100644 --- a/server/routers/client/updateClient.ts +++ b/server/routers/client/updateClient.ts @@ -129,7 +129,7 @@ export async function updateClient( `Adding ${sitesAdded.length} new sites to client ${client.clientId}` ); for (const siteId of sitesAdded) { - if (!client.subnet || !client.pubKey || !client.endpoint) { + if (!client.subnet || !client.pubKey) { logger.debug( "Client subnet, pubKey or endpoint is not set" ); @@ -140,10 +140,25 @@ export async function updateClient( // BUT REALLY WE NEED TO TRACK THE USERS PREFERENCE THAT THEY CHOSE IN THE CLIENTS const isRelayed = true; + // get the clientsite + const [clientSite] = await db + .select() + .from(clientSites) + .where(and( + eq(clientSites.clientId, client.clientId), + eq(clientSites.siteId, siteId) + )) + .limit(1); + + if (!clientSite || !clientSite.endpoint) { + logger.debug("Client site is missing or has no endpoint"); + continue; + } + const site = await newtAddPeer(siteId, { publicKey: client.pubKey, allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client - endpoint: isRelayed ? "" : client.endpoint + endpoint: isRelayed ? "" : clientSite.endpoint }); if (!site) { @@ -255,7 +270,6 @@ export async function updateClient( } } - if (client.endpoint) { // get all sites for this client and join with exit nodes with site.exitNodeId const sitesData = await db .select() @@ -272,6 +286,8 @@ export async function updateClient( let exitNodeDestinations: { reachableAt: string; + sourceIp: string; + sourcePort: number; destinations: PeerDestination[]; }[] = []; @@ -282,6 +298,14 @@ export async function updateClient( ); continue; } + + if (!site.clientSites.endpoint) { + logger.warn( + `Site ${site.sites.siteId} has no endpoint, skipping` + ); + continue; + } + // find the destinations in the array let destinations = exitNodeDestinations.find( (d) => d.reachableAt === site.exitNodes?.reachableAt @@ -290,6 +314,8 @@ export async function updateClient( if (!destinations) { destinations = { reachableAt: site.exitNodes?.reachableAt || "", + sourceIp: site.clientSites.endpoint.split(":")[0] || "", + sourcePort: parseInt(site.clientSites.endpoint.split(":")[1]) || 0, destinations: [ { destinationIP: @@ -319,8 +345,8 @@ export async function updateClient( `Updating destinations for exit node at ${destination.reachableAt}` ); const payload = { - sourceIp: client.endpoint?.split(":")[0] || "", - sourcePort: parseInt(client.endpoint?.split(":")[1]) || 0, + sourceIp: destination.sourceIp, + sourcePort: destination.sourcePort, destinations: destination.destinations }; logger.info( @@ -351,7 +377,6 @@ export async function updateClient( } } } - } // Fetch the updated client const [updatedClient] = await trx diff --git a/server/routers/gerbil/getAllRelays.ts b/server/routers/gerbil/getAllRelays.ts index abe4d593..8d1c66b2 100644 --- a/server/routers/gerbil/getAllRelays.ts +++ b/server/routers/gerbil/getAllRelays.ts @@ -78,19 +78,13 @@ export async function getAllRelays( .where(eq(clientSites.siteId, site.siteId)); for (const clientSite of clientSitesRes) { - // Get client information - const [client] = await db - .select() - .from(clients) - .where(eq(clients.clientId, clientSite.clientId)); - - if (!client || !client.endpoint) { + if (!clientSite.endpoint) { continue; } // Add this site as a destination for the client - if (!mappings[client.endpoint]) { - mappings[client.endpoint] = { destinations: [] }; + if (!mappings[clientSite.endpoint]) { + mappings[clientSite.endpoint] = { destinations: [] }; } // Add site as a destination for this client @@ -100,13 +94,13 @@ export async function getAllRelays( }; // Check if this destination is already in the array to avoid duplicates - const isDuplicate = mappings[client.endpoint].destinations.some( + const isDuplicate = mappings[clientSite.endpoint].destinations.some( dest => dest.destinationIP === destination.destinationIP && dest.destinationPort === destination.destinationPort ); if (!isDuplicate) { - mappings[client.endpoint].destinations.push(destination); + mappings[clientSite.endpoint].destinations.push(destination); } } diff --git a/server/routers/gerbil/updateHolePunch.ts b/server/routers/gerbil/updateHolePunch.ts index 836061d6..39771454 100644 --- a/server/routers/gerbil/updateHolePunch.ts +++ b/server/routers/gerbil/updateHolePunch.ts @@ -1,8 +1,16 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { clients, newts, olms, Site, sites, clientSites, exitNodes } from "@server/db"; +import { + clients, + newts, + olms, + Site, + sites, + clientSites, + exitNodes +} from "@server/db"; import { db } from "@server/db"; -import { eq } from "drizzle-orm"; +import { eq, and } from "drizzle-orm"; import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; import logger from "@server/logger"; @@ -19,7 +27,8 @@ const updateHolePunchSchema = z.object({ ip: z.string(), port: z.number(), timestamp: z.number(), - reachableAt: z.string().optional() + reachableAt: z.string().optional(), + publicKey: z.string() }); // New response type with multi-peer destination support @@ -45,13 +54,24 @@ export async function updateHolePunch( ); } - const { olmId, newtId, ip, port, timestamp, token, reachableAt } = parsedParams.data; + const { + olmId, + newtId, + ip, + port, + timestamp, + token, + reachableAt, + publicKey + } = parsedParams.data; let currentSiteId: number | undefined; let destinations: PeerDestination[] = []; - + if (olmId) { - logger.debug(`Got hole punch with ip: ${ip}, port: ${port} for olmId: ${olmId}`); + logger.debug( + `Got hole punch with ip: ${ip}, port: ${port} for olmId: ${olmId}${publicKey ? ` with exit node publicKey: ${publicKey}` : ""}` + ); const { session, olm: olmSession } = await validateOlmSessionToken(token); @@ -62,7 +82,9 @@ export async function updateHolePunch( } if (olmId !== olmSession.olmId) { - logger.warn(`Olm ID mismatch: ${olmId} !== ${olmSession.olmId}`); + logger.warn( + `Olm ID mismatch: ${olmId} !== ${olmSession.olmId}` + ); return next( createHttpError(HttpCode.UNAUTHORIZED, "Unauthorized") ); @@ -83,12 +105,55 @@ export async function updateHolePunch( const [client] = await db .update(clients) .set({ - endpoint: `${ip}:${port}`, lastHolePunch: timestamp }) .where(eq(clients.clientId, olm.clientId)) .returning(); - + + // Get the exit node by public key + const [exitNode] = await db + .select() + .from(exitNodes) + .where(eq(exitNodes.publicKey, publicKey)); + + if (exitNode) { + // Get sites that are on this specific exit node and connected to this client + const sitesOnExitNode = await db + .select({ siteId: sites.siteId }) + .from(sites) + .innerJoin( + clientSites, + eq(sites.siteId, clientSites.siteId) + ) + .where( + and( + eq(sites.exitNodeId, exitNode.exitNodeId), + eq(clientSites.clientId, olm.clientId) + ) + ); + + // Update clientSites for each site on this exit node + for (const site of sitesOnExitNode) { + await db + .update(clientSites) + .set({ + endpoint: `${ip}:${port}` + }) + .where( + and( + eq(clientSites.clientId, olm.clientId), + eq(clientSites.siteId, site.siteId) + ) + ); + } + + logger.debug( + `Updated ${sitesOnExitNode.length} sites on exit node with publicKey: ${publicKey}` + ); + } else { + logger.warn(`Exit node not found for publicKey: ${publicKey}`); + } + if (!client) { logger.warn(`Client not found for olm: ${olmId}`); return next( @@ -101,23 +166,23 @@ export async function updateHolePunch( // .select() // .from(clientSites) // .where(eq(clientSites.clientId, client.clientId)); - + // if (clientSitePairs.length === 0) { // logger.warn(`No sites found for client: ${client.clientId}`); // return next( // createHttpError(HttpCode.NOT_FOUND, "No sites found for client") // ); // } - + // // Get all sites details // const siteIds = clientSitePairs.map(pair => pair.siteId); - + // for (const siteId of siteIds) { // const [site] = await db // .select() // .from(sites) // .where(eq(sites.siteId, siteId)); - + // if (site && site.subnet && site.listenPort) { // destinations.push({ // destinationIP: site.subnet.split("/")[0], @@ -141,7 +206,9 @@ export async function updateHolePunch( for (const site of sitesData) { if (!site.sites.subnet) { - logger.warn(`Site ${site.sites.siteId} has no subnet, skipping`); + logger.warn( + `Site ${site.sites.siteId} has no subnet, skipping` + ); continue; } // find the destinations in the array @@ -176,51 +243,55 @@ export async function updateHolePunch( logger.debug(JSON.stringify(exitNodeDestinations, null, 2)); - for (const destination of exitNodeDestinations) { - // if its the current exit node skip it because it is replying with the same data - if (reachableAt && destination.reachableAt == reachableAt) { - logger.debug(`Skipping update for reachableAt: ${reachableAt}`); - continue; - } + // BECAUSE OF HARD NAT YOU DONT WANT TO SEND THE OLM IP AND PORT TO THE ALL THE OTHER EXIT NODES + // BECAUSE THEY WILL GET A DIFFERENT IP AND PORT - try { - const response = await axios.post( - `${destination.reachableAt}/update-destinations`, - { - sourceIp: client.endpoint?.split(":")[0] || "", - sourcePort: parseInt(client.endpoint?.split(":")[1] || "0"), - destinations: destination.destinations - }, - { - headers: { - "Content-Type": "application/json" - } - } - ); + // for (const destination of exitNodeDestinations) { + // // if its the current exit node skip it because it is replying with the same data + // if (reachableAt && destination.reachableAt == reachableAt) { + // logger.debug(`Skipping update for reachableAt: ${reachableAt}`); + // continue; + // } - logger.info("Destinations updated:", { - peer: response.data.status - }); - } catch (error) { - if (axios.isAxiosError(error)) { - logger.error( - `Error updating destinations (can Pangolin see Gerbil HTTP API?) for exit node at ${destination.reachableAt} (status: ${error.response?.status}): ${JSON.stringify(error.response?.data, null, 2)}` - ); - } else { - logger.error( - `Error updating destinations for exit node at ${destination.reachableAt}: ${error}` - ); - } - } - } + // try { + // const response = await axios.post( + // `${destination.reachableAt}/update-destinations`, + // { + // sourceIp: client.endpoint?.split(":")[0] || "", + // sourcePort: parseInt(client.endpoint?.split(":")[1] || "0"), + // destinations: destination.destinations + // }, + // { + // headers: { + // "Content-Type": "application/json" + // } + // } + // ); + + // logger.info("Destinations updated:", { + // peer: response.data.status + // }); + // } catch (error) { + // if (axios.isAxiosError(error)) { + // logger.error( + // `Error updating destinations (can Pangolin see Gerbil HTTP API?) for exit node at ${destination.reachableAt} (status: ${error.response?.status}): ${JSON.stringify(error.response?.data, null, 2)}` + // ); + // } else { + // logger.error( + // `Error updating destinations for exit node at ${destination.reachableAt}: ${error}` + // ); + // } + // } + // } // Send the desinations back to the origin - destinations = exitNodeDestinations.find( - (d) => d.reachableAt === reachableAt - )?.destinations || []; - + destinations = + exitNodeDestinations.find((d) => d.reachableAt === reachableAt) + ?.destinations || []; } else if (newtId) { - logger.debug(`Got hole punch with ip: ${ip}, port: ${port} for newtId: ${newtId}`); + logger.debug( + `Got hole punch with ip: ${ip}, port: ${port} for newtId: ${newtId}` + ); const { session, newt: newtSession } = await validateNewtSessionToken(token); @@ -232,7 +303,9 @@ export async function updateHolePunch( } if (newtId !== newtSession.newtId) { - logger.warn(`Newt ID mismatch: ${newtId} !== ${newtSession.newtId}`); + logger.warn( + `Newt ID mismatch: ${newtId} !== ${newtSession.newtId}` + ); return next( createHttpError(HttpCode.UNAUTHORIZED, "Unauthorized") ); @@ -261,7 +334,7 @@ export async function updateHolePunch( }) .where(eq(sites.siteId, newt.siteId)) .returning(); - + if (!updatedSite || !updatedSite.subnet) { logger.warn(`Site not found: ${newt.siteId}`); return next( @@ -274,7 +347,7 @@ export async function updateHolePunch( // .select() // .from(clientSites) // .where(eq(clientSites.siteId, newt.siteId)); - + // THE NEWT IS NOT SENDING RAW WG TO THE GERBIL SO IDK IF WE REALLY NEED THIS - REMOVING // Get client details for each client // for (const pair of sitesClientPairs) { @@ -282,7 +355,7 @@ export async function updateHolePunch( // .select() // .from(clients) // .where(eq(clients.clientId, pair.clientId)); - + // if (client && client.endpoint) { // const [host, portStr] = client.endpoint.split(':'); // if (host && portStr) { @@ -293,27 +366,27 @@ export async function updateHolePunch( // } // } // } - + // If this is a newt/site, also add other sites in the same org - // if (updatedSite.orgId) { - // const orgSites = await db - // .select() - // .from(sites) - // .where(eq(sites.orgId, updatedSite.orgId)); - - // for (const site of orgSites) { - // // Don't add the current site to the destinations - // if (site.siteId !== currentSiteId && site.subnet && site.endpoint && site.listenPort) { - // const [host, portStr] = site.endpoint.split(':'); - // if (host && portStr) { - // destinations.push({ - // destinationIP: host, - // destinationPort: site.listenPort - // }); - // } - // } - // } - // } + // if (updatedSite.orgId) { + // const orgSites = await db + // .select() + // .from(sites) + // .where(eq(sites.orgId, updatedSite.orgId)); + + // for (const site of orgSites) { + // // Don't add the current site to the destinations + // if (site.siteId !== currentSiteId && site.subnet && site.endpoint && site.listenPort) { + // const [host, portStr] = site.endpoint.split(':'); + // if (host && portStr) { + // destinations.push({ + // destinationIP: host, + // destinationPort: site.listenPort + // }); + // } + // } + // } + // } } // if (destinations.length === 0) { @@ -336,4 +409,4 @@ export async function updateHolePunch( ) ); } -} \ No newline at end of file +} diff --git a/server/routers/newt/handleGetConfigMessage.ts b/server/routers/newt/handleGetConfigMessage.ts index 1059847c..9ab1c049 100644 --- a/server/routers/newt/handleGetConfigMessage.ts +++ b/server/routers/newt/handleGetConfigMessage.ts @@ -157,9 +157,6 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { if (!client.clients.subnet) { return false; } - if (!client.clients.endpoint) { - return false; - } return true; }) .map(async (client) => { @@ -215,7 +212,7 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { allowedIps: [`${client.clients.subnet.split("/")[0]}/32`], // we want to only allow from that client endpoint: client.clientSites.isRelayed ? "" - : client.clients.endpoint! // if its relayed it should be localhost + : client.clientSites.endpoint! // if its relayed it should be localhost }; }) ); diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index 32e4fe51..7028a2f0 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -8,7 +8,7 @@ import { olms, sites } from "@server/db"; -import { eq, inArray } from "drizzle-orm"; +import { and, eq, inArray } from "drizzle-orm"; import { addPeer, deletePeer } from "../newt/peers"; import logger from "@server/logger"; @@ -147,15 +147,24 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { continue; } + const [clientSite] = await db + .select() + .from(clientSites) + .where(and( + eq(clientSites.clientId, client.clientId), + eq(clientSites.siteId, site.siteId) + )) + .limit(1); + // Add the peer to the exit node for this site - if (client.endpoint) { + if (clientSite.endpoint) { logger.info( - `Adding peer ${publicKey} to site ${site.siteId} with endpoint ${client.endpoint}` + `Adding peer ${publicKey} to site ${site.siteId} with endpoint ${clientSite.endpoint}` ); await addPeer(site.siteId, { publicKey: publicKey, allowedIps: [`${client.subnet.split('/')[0]}/32`], // we want to only allow from that client - endpoint: relay ? "" : client.endpoint + endpoint: relay ? "" : clientSite.endpoint }); } else { logger.warn(