diff --git a/server/db/schema.ts b/server/db/schema.ts index 24c705a9..7355d0ca 100644 --- a/server/db/schema.ts +++ b/server/db/schema.ts @@ -31,7 +31,8 @@ export const sites = sqliteTable("sites", { address: text("address"), // this is the address of the wireguard interface in gerbil endpoint: text("endpoint"), // this is how to reach gerbil externally - gets put into the wireguard config publicKey: text("pubicKey"), - listenPort: integer("listenPort") + listenPort: integer("listenPort"), + lastHolePunch: integer("lastHolePunch"), }); export const resources = sqliteTable("resources", { @@ -135,7 +136,9 @@ export const clients = sqliteTable("clients", { megabytesOut: integer("bytesOut"), lastBandwidthUpdate: text("lastBandwidthUpdate"), type: text("type").notNull(), // "olm" - online: integer("online", { mode: "boolean" }).notNull().default(false) + online: integer("online", { mode: "boolean" }).notNull().default(false), + endpoint: text("endpoint"), + lastHolePunch: integer("lastHolePunch"), }); export const olms = sqliteTable("olms", { diff --git a/server/routers/gerbil/updateHolePunch.ts b/server/routers/gerbil/updateHolePunch.ts new file mode 100644 index 00000000..50648f13 --- /dev/null +++ b/server/routers/gerbil/updateHolePunch.ts @@ -0,0 +1,91 @@ +import { Request, Response, NextFunction } from "express"; +import { z } from "zod"; +import { clients, newts, olms, sites } from "@server/db/schema"; +import { db } from "@server/db"; +import { eq } from "drizzle-orm"; +import HttpCode from "@server/types/HttpCode"; +import createHttpError from "http-errors"; +import logger from "@server/logger"; +import { fromError } from "zod-validation-error"; + +// Define Zod schema for request validation +const updateHolePunchSchema = z.object({ + olmId: z.string().optional(), + newtId: z.string().optional(), + ip: z.string(), + port: z.number(), + timestamp: z.number() +}); + +export async function updateHolePunch( + req: Request, + res: Response, + next: NextFunction +): Promise { + try { + // Validate request parameters + const parsedParams = updateHolePunchSchema.safeParse(req.body); + if (!parsedParams.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedParams.error).toString() + ) + ); + } + + const { olmId, newtId, ip, port, timestamp } = parsedParams.data; + + if (olmId) { + const [olm] = await db + .select() + .from(olms) + .where(eq(olms.olmId, olmId)); + + if (!olm || !olm.clientId) { + logger.warn(`Olm not found: ${olmId}`); + return next( + createHttpError(HttpCode.NOT_FOUND, "Olm not found") + ); + } + + await db + .update(clients) + .set({ + endpoint: `${ip}:${port}`, + lastHolePunch: timestamp + }) + .where(eq(clients.clientId, olm.clientId)); + } else if (newtId) { + const [newt] = await db + .select() + .from(newts) + .where(eq(newts.newtId, newtId)); + + if (!newt || !newt.siteId) { + logger.warn(`Newt not found: ${newtId}`); + return next( + createHttpError(HttpCode.NOT_FOUND, "New not found") + ); + } + + await db + .update(sites) + .set({ + endpoint: `${ip}:${port}`, + lastHolePunch: timestamp + }) + .where(eq(sites.siteId, newt.siteId)); + } + + return res.status(HttpCode.OK).send({}); + } catch (error) { + logger.error(error); + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "An error occurred..." + ) + ); + } +}