diff --git a/server/routers/newt/handleNewtRegisterMessage.ts b/server/routers/newt/handleNewtRegisterMessage.ts index 8e263034..54a62735 100644 --- a/server/routers/newt/handleNewtRegisterMessage.ts +++ b/server/routers/newt/handleNewtRegisterMessage.ts @@ -16,7 +16,7 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => { const newt = client; - logger.info("Handling register message!"); + logger.info("Handling register newt message!"); if (!newt) { logger.warn("Newt not found"); diff --git a/server/routers/newt/peers.ts b/server/routers/newt/peers.ts new file mode 100644 index 00000000..ee22c052 --- /dev/null +++ b/server/routers/newt/peers.ts @@ -0,0 +1,46 @@ +import db from '@server/db'; +import { newts, sites } from '@server/db/schema'; +import { eq } from 'drizzle-orm'; +import { sendToClient } from '../ws'; + +export async function addPeer(siteId: number, peer: { + publicKey: string; + allowedIps: string[]; +}) { + + const [site] = await db.select().from(sites).where(eq(sites.siteId, siteId)).limit(1); + if (!site) { + throw new Error(`Exit node with ID ${siteId} not found`); + } + + // get the newt on the site + const [newt] = await db.select().from(newts).where(eq(newts.siteId, siteId)).limit(1); + if (!newt) { + throw new Error(`Newt not found for site ${siteId}`); + } + + sendToClient(newt.newtId, { + type: 'add_peer', + data: peer + }); +} + +export async function deletePeer(siteId: number, publicKey: string) { + const [site] = await db.select().from(sites).where(eq(sites.siteId, siteId)).limit(1); + if (!site) { + throw new Error(`Exit node with ID ${siteId} not found`); + } + + // get the newt on the site + const [newt] = await db.select().from(newts).where(eq(newts.siteId, siteId)).limit(1); + if (!newt) { + throw new Error(`Newt not found for site ${siteId}`); + } + + sendToClient(newt.newtId, { + type: 'delete_peer', + data: { + publicKey + } + }); +} \ No newline at end of file diff --git a/server/routers/olm/handleGetConfigMessage.ts b/server/routers/olm/handleGetConfigMessage.ts deleted file mode 100644 index 6e4f7ebf..00000000 --- a/server/routers/olm/handleGetConfigMessage.ts +++ /dev/null @@ -1,147 +0,0 @@ -import { z } from "zod"; -import { MessageHandler } from "../ws"; -import logger from "@server/logger"; -import { fromError } from "zod-validation-error"; -import db from "@server/db"; -import { olms, Site, sites } from "@server/db/schema"; -import { eq, isNotNull } from "drizzle-orm"; -import { findNextAvailableCidr } from "@server/lib/ip"; -import config from "@server/lib/config"; - -const inputSchema = z.object({ - publicKey: z.string(), - endpoint: z.string(), - listenPort: z.number() -}); - -type Input = z.infer; - -export const handleGetConfigMessage: MessageHandler = async (context) => { - const { message, newt, sendToClient } = context; - - logger.debug("Handling Newt get config message!"); - - if (!newt) { - logger.warn("Newt not found"); - return; - } - - if (!newt.siteId) { - logger.warn("Newt has no site!"); // TODO: Maybe we create the site here? - return; - } - - const parsed = inputSchema.safeParse(message.data); - if (!parsed.success) { - logger.error( - "handleGetConfigMessage: Invalid input: " + - fromError(parsed.error).toString() - ); - return; - } - - const { publicKey, endpoint, listenPort } = message.data as Input; - - const siteId = newt.siteId; - - const [siteRes] = await db - .select() - .from(sites) - .where(eq(sites.siteId, siteId)); - - if (!siteRes) { - logger.warn("handleGetConfigMessage: Site not found"); - return; - } - - let site: Site | undefined; - if (!site) { - const address = await getNextAvailableSubnet(); - - // create a new exit node - const [updateRes] = await db - .update(sites) - .set({ - publicKey, - endpoint, - address, - listenPort - }) - .where(eq(sites.siteId, siteId)) - .returning(); - - site = updateRes; - - logger.info(`Updated site ${siteId} with new WG Newt info`); - } else { - site = siteRes; - } - - if (!site) { - logger.error("handleGetConfigMessage: Failed to update site"); - return; - } - - const clientsRes = await db - .select() - .from(olms) - .where(eq(olms.siteId, siteId)); - - const peers = await Promise.all( - clientsRes.map(async (client) => { - return { - publicKey: client.pubKey, - allowedIps: "0.0.0.0/0" - }; - }) - ); - - const configResponse = { - listenPort: site.listenPort, // ????? - // ipAddress: exitNode[0].address, - peers - }; - - logger.debug("Sending config: ", configResponse); - - return { - message: { - type: "olm/wg/connect", // what to make the response type? - data: { - config: configResponse - } - }, - broadcast: false, // Send to all clients - excludeSender: false // Include sender in broadcast - }; -}; - -async function getNextAvailableSubnet(): Promise { - const existingAddresses = await db - .select({ - address: sites.address - }) - .from(sites) - .where(isNotNull(sites.address)); - - const addresses = existingAddresses - .map((a) => a.address) - .filter((a) => a) as string[]; - - let subnet = findNextAvailableCidr( - addresses, - config.getRawConfig().wg_site.block_size, - config.getRawConfig().wg_site.subnet_group - ); - if (!subnet) { - throw new Error("No available subnets remaining in space"); - } - - // replace the last octet with 1 - subnet = - subnet.split(".").slice(0, 3).join(".") + - ".1" + - "/" + - subnet.split("/")[1]; - return subnet; -} diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index 33786f2d..859f756c 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -1,14 +1,11 @@ import db from "@server/db"; import { MessageHandler } from "../ws"; import { - exitNodes, - resources, + olms, sites, - Target, - targets } from "@server/db/schema"; -import { eq, and, sql } from "drizzle-orm"; -import { addPeer, deletePeer } from "../gerbil/peers"; +import { eq, } from "drizzle-orm"; +import { addPeer, deletePeer } from "../newt/peers"; import logger from "@server/logger"; export const handleOlmRegisterMessage: MessageHandler = async (context) => { @@ -16,7 +13,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { const olm = client; - logger.info("Handling register message!"); + logger.info("Handling register olm message!"); if (!olm) { logger.warn("Olm not found"); @@ -42,28 +39,22 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { .where(eq(sites.siteId, siteId)) .limit(1); - if (!site || !site.exitNodeId) { + if (!site) { logger.warn("Site not found or does not have exit node"); return; } await db - .update(sites) + .update(olms) .set({ pubKey: publicKey }) - .where(eq(sites.siteId, siteId)) + .where(eq(olms.olmId, olm.olmId)) .returning(); - const [exitNode] = await db - .select() - .from(exitNodes) - .where(eq(exitNodes.exitNodeId, site.exitNodeId)) - .limit(1); - - if (site.pubKey && site.pubKey !== publicKey) { + if (olm.pubKey && olm.pubKey !== publicKey) { logger.info("Public key mismatch. Deleting old peer..."); - await deletePeer(site.exitNodeId, site.pubKey); + await deletePeer(site.siteId, site.pubKey); } if (!site.subnet) { @@ -72,7 +63,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { } // add the peer to the exit node - await addPeer(site.exitNodeId, { + await addPeer(site.siteId, { publicKey: publicKey, allowedIps: [site.subnet] }); @@ -81,9 +72,9 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { message: { type: "olm/wg/connect", data: { - endpoint: `${exitNode.endpoint}:${exitNode.listenPort}`, - publicKey: exitNode.publicKey, - serverIP: exitNode.address.split("/")[0], + endpoint: `${site.endpoint}:${site.listenPort}`, + publicKey: site.publicKey, + serverIP: site.address!.split("/")[0], tunnelIP: site.subnet.split("/")[0] } }, diff --git a/server/routers/olm/index.ts b/server/routers/olm/index.ts index 7265331b..4c073152 100644 --- a/server/routers/olm/index.ts +++ b/server/routers/olm/index.ts @@ -1 +1,2 @@ -export * from "./pickOlmDefaults"; \ No newline at end of file +export * from "./pickOlmDefaults"; +export * from "./handleOlmRegisterMessage"; \ No newline at end of file