diff --git a/server/routers/gerbil/getAllRelays.ts b/server/routers/gerbil/getAllRelays.ts new file mode 100644 index 00000000..2284d4eb --- /dev/null +++ b/server/routers/gerbil/getAllRelays.ts @@ -0,0 +1,89 @@ +import { Request, Response, NextFunction } from "express"; +import { z } from "zod"; +import { clients, exitNodes, newts, olms, Site, sites } from "@server/db/schema"; +import { db } from "@server/db"; +import { eq } from "drizzle-orm"; +import HttpCode from "@server/types/HttpCode"; +import createHttpError from "http-errors"; +import logger from "@server/logger"; +import { fromError } from "zod-validation-error"; + +// Define Zod schema for request validation +const getAllRelaysSchema = z.object({ + publicKey: z.string().optional(), +}); + +export async function getAllRelays( + req: Request, + res: Response, + next: NextFunction +): Promise { + try { + // Validate request parameters + const parsedParams = getAllRelaysSchema.safeParse(req.body); + if (!parsedParams.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedParams.error).toString() + ) + ); + } + + const { publicKey } = parsedParams.data; + + if (!publicKey) { + return next(createHttpError(HttpCode.BAD_REQUEST, 'publicKey is required')); + } + + // Fetch exit node + let [exitNode] = await db.select().from(exitNodes).where(eq(exitNodes.publicKey, publicKey)); + if (!exitNode) { + return next(createHttpError(HttpCode.NOT_FOUND, "Exit node not found")); + } + + // Fetch sites for this exit node + const sitesRes = await db.select().from(sites).where(eq(sites.exitNodeId, exitNode.exitNodeId)); + + if (sitesRes.length === 0) { + return next(createHttpError(HttpCode.NOT_FOUND, "No sites found for this exit node")); + } + + // get the clients on each site and map them to the site + const sitesAndClients = await Promise.all(sitesRes.map(async (site) => { + const clientsRes = await db.select().from(clients).where(eq(clients.siteId, site.siteId)); + return { + site, + clients: clientsRes + }; + })); + + let mappings: { [key: string]: { + destinationIp: string; + destinationPort: number; + } } = {}; + + for (const siteAndClients of sitesAndClients) { + const { site, clients } = siteAndClients; + for (const client of clients) { + if (!client.endpoint || !site.endpoint || !site.subnet) { + continue; + } + mappings[client.endpoint] = { + destinationIp: site.subnet.split("/")[0], + destinationPort: parseInt(site.endpoint.split(":")[1]) + }; + } + } + + return res.status(HttpCode.OK).send({ mappings }); + } catch (error) { + logger.error(error); + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "An error occurred..." + ) + ); + } +} diff --git a/server/routers/gerbil/getConfig.ts b/server/routers/gerbil/getConfig.ts index 95e0df6b..0c50944e 100644 --- a/server/routers/gerbil/getConfig.ts +++ b/server/routers/gerbil/getConfig.ts @@ -79,9 +79,7 @@ export async function getConfig(req: Request, res: Response, next: NextFunction) } // Fetch sites for this exit node - const sitesRes = await db.query.sites.findMany({ - where: eq(sites.exitNodeId, exitNode[0].exitNodeId), - }); + const sitesRes = await db.select().from(sites).where(eq(sites.exitNodeId, exitNode[0].exitNodeId)); const peers = await Promise.all(sitesRes.map(async (site) => { return { diff --git a/server/routers/gerbil/index.ts b/server/routers/gerbil/index.ts index bcf1eb24..4a4f3b60 100644 --- a/server/routers/gerbil/index.ts +++ b/server/routers/gerbil/index.ts @@ -1,3 +1,4 @@ export * from "./getConfig"; export * from "./receiveBandwidth"; -export * from "./updateHolePunch"; \ No newline at end of file +export * from "./updateHolePunch"; +export * from "./getAllRelays"; \ No newline at end of file diff --git a/server/routers/gerbil/updateHolePunch.ts b/server/routers/gerbil/updateHolePunch.ts index 36002f57..68a7282f 100644 --- a/server/routers/gerbil/updateHolePunch.ts +++ b/server/routers/gerbil/updateHolePunch.ts @@ -1,6 +1,6 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { clients, newts, olms, sites } from "@server/db/schema"; +import { clients, newts, olms, Site, sites } from "@server/db/schema"; import { db } from "@server/db"; import { eq } from "drizzle-orm"; import HttpCode from "@server/types/HttpCode"; @@ -36,7 +36,9 @@ export async function updateHolePunch( const { olmId, newtId, ip, port, timestamp } = parsedParams.data; - logger.debug(`Got hole punch with ip: ${ip}, port: ${port} for olmId: ${olmId} or newtId: ${newtId}`); + // logger.debug(`Got hole punch with ip: ${ip}, port: ${port} for olmId: ${olmId} or newtId: ${newtId}`); + + let site: Site | undefined; if (olmId) { const [olm] = await db @@ -51,13 +53,19 @@ export async function updateHolePunch( ); } - await db + const [client] = await db .update(clients) .set({ endpoint: `${ip}:${port}`, lastHolePunch: timestamp }) - .where(eq(clients.clientId, olm.clientId)); + .where(eq(clients.clientId, olm.clientId)) + .returning(); + + [site] = await db + .select() + .from(sites) + .where(eq(sites.siteId, client.siteId)); } else if (newtId) { const [newt] = await db .select() @@ -71,16 +79,27 @@ export async function updateHolePunch( ); } - await db + [site] = await db .update(sites) .set({ endpoint: `${ip}:${port}`, lastHolePunch: timestamp }) - .where(eq(sites.siteId, newt.siteId)); + .where(eq(sites.siteId, newt.siteId)) + .returning(); } - return res.status(HttpCode.OK).send({}); + if (!site || !site.endpoint || !site.subnet) { + logger.warn(`Site not found for olmId: ${olmId} or newtId: ${newtId}`); + return next( + createHttpError(HttpCode.NOT_FOUND, "Site not found") + ); + } + + return res.status(HttpCode.OK).send({ + destinationIp: site.subnet.split("/")[0], + destinationPort: parseInt(site.endpoint.split(":")[1]) + }); } catch (error) { logger.error(error); return next( diff --git a/server/routers/internal.ts b/server/routers/internal.ts index 8392cc6e..6b57e3f6 100644 --- a/server/routers/internal.ts +++ b/server/routers/internal.ts @@ -35,6 +35,7 @@ internalRouter.use("/gerbil", gerbilRouter); gerbilRouter.post("/get-config", gerbil.getConfig); gerbilRouter.post("/receive-bandwidth", gerbil.receiveBandwidth); gerbilRouter.post("/update-hole-punch", gerbil.updateHolePunch); +gerbilRouter.post("/get-all-relays", gerbil.getAllRelays); // Badger routes const badgerRouter = Router(); diff --git a/server/routers/newt/handleGetConfigMessage.ts b/server/routers/newt/handleGetConfigMessage.ts index aa994db5..62934d2f 100644 --- a/server/routers/newt/handleGetConfigMessage.ts +++ b/server/routers/newt/handleGetConfigMessage.ts @@ -10,7 +10,6 @@ import config from "@server/lib/config"; const inputSchema = z.object({ publicKey: z.string(), - endpoint: z.string() }); type Input = z.infer; @@ -42,7 +41,7 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { return; } - const { publicKey, endpoint } = message.data as Input; + const { publicKey } = message.data as Input; const siteId = newt.siteId; @@ -66,7 +65,6 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { .update(sites) .set({ publicKey, - // endpoint, address, listenPort }) @@ -82,7 +80,6 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { .update(sites) .set({ publicKey - // endpoint }) .where(eq(sites.siteId, siteId)) .returning(); diff --git a/server/routers/newt/peers.ts b/server/routers/newt/peers.ts index a4bf8ae7..5808484f 100644 --- a/server/routers/newt/peers.ts +++ b/server/routers/newt/peers.ts @@ -2,6 +2,7 @@ import db from '@server/db'; import { newts, sites } from '@server/db/schema'; import { eq } from 'drizzle-orm'; import { sendToClient } from '../ws'; +import logger from '@server/logger'; export async function addPeer(siteId: number, peer: { publicKey: string; @@ -24,6 +25,8 @@ export async function addPeer(siteId: number, peer: { type: 'newt/wg/peer/add', data: peer }); + + logger.info(`Added peer ${peer.publicKey} to newt ${newt.newtId}`); } export async function deletePeer(siteId: number, publicKey: string) { @@ -44,4 +47,6 @@ export async function deletePeer(siteId: number, publicKey: string) { publicKey } }); + + logger.info(`Deleted peer ${publicKey} from newt ${newt.newtId}`); } \ No newline at end of file diff --git a/server/routers/olm/getOlmToken.ts b/server/routers/olm/getOlmToken.ts index ce40a35e..00ab9358 100644 --- a/server/routers/olm/getOlmToken.ts +++ b/server/routers/olm/getOlmToken.ts @@ -74,8 +74,6 @@ export async function getOlmToken( ); } - logger.debug("Existing olm: ", existingOlmRes); - const existingOlm = existingOlmRes[0]; const validSecret = await verifyPassword(