diff --git a/server/routers/gerbil/getAllRelays.ts b/server/routers/gerbil/getAllRelays.ts index e846326a..c975efec 100644 --- a/server/routers/gerbil/getAllRelays.ts +++ b/server/routers/gerbil/getAllRelays.ts @@ -1,6 +1,6 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { clients, exitNodes, newts, olms, Site, sites } from "@server/db/schema"; +import { clients, exitNodes, newts, olms, Site, sites, clientSites } from "@server/db/schema"; import { db } from "@server/db"; import { eq } from "drizzle-orm"; import HttpCode from "@server/types/HttpCode"; @@ -13,6 +13,17 @@ const getAllRelaysSchema = z.object({ publicKey: z.string().optional(), }); +// Type for peer destination +interface PeerDestination { + destinationIP: string; + destinationPort: number; +} + +// Updated mappings type to support multiple destinations per endpoint +interface ProxyMapping { + destinations: PeerDestination[]; +} + export async function getAllRelays( req: Request, res: Response, @@ -46,38 +57,96 @@ export async function getAllRelays( const sitesRes = await db.select().from(sites).where(eq(sites.exitNodeId, exitNode.exitNodeId)); if (sitesRes.length === 0) { - return { + return res.status(HttpCode.OK).send({ mappings: {} - } + }); } - // // get the clients on each site and map them to the site - // const sitesAndClients = await Promise.all(sitesRes.map(async (site) => { - // const clientsRes = await db.select().from(clients).where(eq(clients.siteId, site.siteId)); - // return { - // site, - // clients: clientsRes - // }; - // })); + // Initialize mappings object for multi-peer support + let mappings: { [key: string]: ProxyMapping } = {}; - let mappings: { [key: string]: { - destinationIp: string; - destinationPort: number; - } } = {}; + // Process each site + for (const site of sitesRes) { + if (!site.endpoint || !site.subnet || !site.listenPort) { + continue; + } - // for (const siteAndClients of sitesAndClients) { - // const { site, clients } = siteAndClients; - // for (const client of clients) { - // if (!client.endpoint || !site.endpoint || !site.subnet) { - // continue; - // } - // mappings[client.endpoint] = { - // destinationIp: site.subnet.split("/")[0], - // destinationPort: parseInt(site.endpoint.split(":")[1]) - // }; - // } - // } + // Find all clients associated with this site through clientSites + const clientSitesRes = await db + .select() + .from(clientSites) + .where(eq(clientSites.siteId, site.siteId)); + + for (const clientSite of clientSitesRes) { + // Get client information + const [client] = await db + .select() + .from(clients) + .where(eq(clients.clientId, clientSite.clientId)); + if (!client || !client.endpoint) { + continue; + } + + // Add this site as a destination for the client + if (!mappings[client.endpoint]) { + mappings[client.endpoint] = { destinations: [] }; + } + + // Add site as a destination for this client + const destination: PeerDestination = { + destinationIP: site.subnet.split("/")[0], + destinationPort: site.listenPort + }; + + // Check if this destination is already in the array to avoid duplicates + const isDuplicate = mappings[client.endpoint].destinations.some( + dest => dest.destinationIP === destination.destinationIP && + dest.destinationPort === destination.destinationPort + ); + + if (!isDuplicate) { + mappings[client.endpoint].destinations.push(destination); + } + } + + // Also handle site-to-site communication (all sites in the same org) + if (site.orgId) { + const orgSites = await db + .select() + .from(sites) + .where(eq(sites.orgId, site.orgId)); + + for (const peer of orgSites) { + // Skip self + if (peer.siteId === site.siteId || !peer.endpoint || !peer.subnet || !peer.listenPort) { + continue; + } + + // Add peer site as a destination for this site + if (!mappings[site.endpoint]) { + mappings[site.endpoint] = { destinations: [] }; + } + + const destination: PeerDestination = { + destinationIP: peer.subnet.split("/")[0], + destinationPort: peer.listenPort + }; + + // Check for duplicates + const isDuplicate = mappings[site.endpoint].destinations.some( + dest => dest.destinationIP === destination.destinationIP && + dest.destinationPort === destination.destinationPort + ); + + if (!isDuplicate) { + mappings[site.endpoint].destinations.push(destination); + } + } + } + } + + logger.debug(`Returning mappings for ${Object.keys(mappings).length} endpoints`); return res.status(HttpCode.OK).send({ mappings }); } catch (error) { logger.error(error); @@ -88,4 +157,4 @@ export async function getAllRelays( ) ); } -} +} \ No newline at end of file diff --git a/server/routers/gerbil/updateHolePunch.ts b/server/routers/gerbil/updateHolePunch.ts index 0149ec51..8d57d032 100644 --- a/server/routers/gerbil/updateHolePunch.ts +++ b/server/routers/gerbil/updateHolePunch.ts @@ -1,6 +1,6 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { clients, newts, olms, Site, sites } from "@server/db/schema"; +import { clients, newts, olms, Site, sites, clientSites } from "@server/db/schema"; import { db } from "@server/db"; import { eq } from "drizzle-orm"; import HttpCode from "@server/types/HttpCode"; @@ -20,6 +20,12 @@ const updateHolePunchSchema = z.object({ timestamp: z.number() }); +// New response type with multi-peer destination support +interface PeerDestination { + destinationIP: string; + destinationPort: number; +} + export async function updateHolePunch( req: Request, res: Response, @@ -41,7 +47,8 @@ export async function updateHolePunch( // logger.debug(`Got hole punch with ip: ${ip}, port: ${port} for olmId: ${olmId} or newtId: ${newtId}`); - let site: Site | undefined; + let currentSiteId: number | undefined; + let destinations: PeerDestination[] = []; if (olmId) { const { session, olm: olmSession } = @@ -79,11 +86,43 @@ export async function updateHolePunch( }) .where(eq(clients.clientId, olm.clientId)) .returning(); + + if (!client) { + logger.warn(`Client not found for olm: ${olmId}`); + return next( + createHttpError(HttpCode.NOT_FOUND, "Client not found") + ); + } - // [site] = await db - // .select() - // .from(sites) - // .where(eq(sites.siteId, client.siteId)); + // Get all sites that this client is connected to + const clientSitePairs = await db + .select() + .from(clientSites) + .where(eq(clientSites.clientId, client.clientId)); + + if (clientSitePairs.length === 0) { + logger.warn(`No sites found for client: ${client.clientId}`); + return next( + createHttpError(HttpCode.NOT_FOUND, "No sites found for client") + ); + } + + // Get all sites details + const siteIds = clientSitePairs.map(pair => pair.siteId); + + for (const siteId of siteIds) { + const [site] = await db + .select() + .from(sites) + .where(eq(sites.siteId, siteId)); + + if (site && site.subnet && site.listenPort) { + destinations.push({ + destinationIP: site.subnet.split("/")[0], + destinationPort: site.listenPort + }); + } + } } else if (newtId) { const { session, newt: newtSession } = @@ -114,7 +153,10 @@ export async function updateHolePunch( ); } - [site] = await db + currentSiteId = newt.siteId; + + // Update the current site with the new endpoint + const [updatedSite] = await db .update(sites) .set({ endpoint: `${ip}:${port}`, @@ -122,18 +164,70 @@ export async function updateHolePunch( }) .where(eq(sites.siteId, newt.siteId)) .returning(); + + if (!updatedSite || !updatedSite.subnet) { + logger.warn(`Site not found: ${newt.siteId}`); + return next( + createHttpError(HttpCode.NOT_FOUND, "Site not found") + ); + } + + // Find all clients that connect to this site + const sitesClientPairs = await db + .select() + .from(clientSites) + .where(eq(clientSites.siteId, newt.siteId)); + + // Get client details for each client + for (const pair of sitesClientPairs) { + const [client] = await db + .select() + .from(clients) + .where(eq(clients.clientId, pair.clientId)); + + if (client && client.endpoint) { + const [host, portStr] = client.endpoint.split(':'); + if (host && portStr) { + destinations.push({ + destinationIP: host, + destinationPort: parseInt(portStr, 10) + }); + } + } + } + + // If this is a newt/site, also add other sites in the same org + if (updatedSite.orgId) { + const orgSites = await db + .select() + .from(sites) + .where(eq(sites.orgId, updatedSite.orgId)); + + for (const site of orgSites) { + // Don't add the current site to the destinations + if (site.siteId !== currentSiteId && site.subnet && site.endpoint && site.listenPort) { + const [host, portStr] = site.endpoint.split(':'); + if (host && portStr) { + destinations.push({ + destinationIP: host, + destinationPort: site.listenPort + }); + } + } + } + } } - if (!site || !site.endpoint || !site.subnet) { + if (destinations.length === 0) { logger.warn( - `Site not found for olmId: ${olmId} or newtId: ${newtId}` + `No peer destinations found for olmId: ${olmId} or newtId: ${newtId}` ); - return next(createHttpError(HttpCode.NOT_FOUND, "Site not found")); + return next(createHttpError(HttpCode.NOT_FOUND, "No peer destinations found")); } + // Return the new multi-peer structure return res.status(HttpCode.OK).send({ - destinationIp: site.subnet.split("/")[0], - destinationPort: site.listenPort + destinations: destinations }); } catch (error) { logger.error(error); @@ -144,4 +238,4 @@ export async function updateHolePunch( ) ); } -} +} \ No newline at end of file