From 87012c47ea6bd12da7b7ac84e41ed4a670750a01 Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 25 Mar 2025 22:01:08 -0400 Subject: [PATCH] Start changes for multi site clients - Org subnet and assign sites and clients out of the same subnet group on each org - Add join table for client on multiple sites - Start to handle websocket endpoints for these multiple connections --- install/config/config.yml | 4 +- server/db/schema.ts | 19 +- server/lib/config.ts | 4 +- server/lib/ip.ts | 59 ++++++ server/routers/client/pickClientDefaults.ts | 38 +--- server/routers/newt/handleGetConfigMessage.ts | 90 ++------- .../routers/olm/handleOlmRegisterMessage.ts | 186 ++++++++++-------- server/routers/org/createOrg.ts | 6 +- 8 files changed, 210 insertions(+), 196 deletions(-) diff --git a/install/config/config.yml b/install/config/config.yml index 043b1421..d972c637 100644 --- a/install/config/config.yml +++ b/install/config/config.yml @@ -38,11 +38,9 @@ gerbil: site_block_size: 30 subnet_group: 100.89.137.0/20 -newt: - start_port: 51820 +orgs: block_size: 24 subnet_group: 100.89.138.0/20 - site_block_size: 30 rate_limits: global: diff --git a/server/db/schema.ts b/server/db/schema.ts index 74e9ecf2..875754f5 100644 --- a/server/db/schema.ts +++ b/server/db/schema.ts @@ -11,7 +11,8 @@ export const domains = sqliteTable("domains", { export const orgs = sqliteTable("orgs", { orgId: text("orgId").primaryKey(), - name: text("name").notNull() + name: text("name").notNull(), + subnet: text("subnet").notNull(), }); export const orgDomains = sqliteTable("orgDomains", { @@ -47,7 +48,6 @@ export const sites = sqliteTable("sites", { address: text("address"), // this is the address of the wireguard interface in gerbil endpoint: text("endpoint"), // this is how to reach gerbil externally - gets put into the wireguard config publicKey: text("pubicKey"), - listenPort: integer("listenPort"), lastHolePunch: integer("lastHolePunch"), }); @@ -138,11 +138,6 @@ export const newts = sqliteTable("newt", { export const clients = sqliteTable("clients", { clientId: integer("id").primaryKey({ autoIncrement: true }), - siteId: integer("siteId") - .references(() => sites.siteId, { - onDelete: "cascade" - }) - .notNull(), orgId: text("orgId") .references(() => orgs.orgId, { onDelete: "cascade" @@ -160,6 +155,15 @@ export const clients = sqliteTable("clients", { lastHolePunch: integer("lastHolePunch"), }); +export const clientSites = sqliteTable("clientSites", { + clientId: integer("clientId") + .notNull() + .references(() => clients.clientId, { onDelete: "cascade" }), + siteId: integer("siteId") + .notNull() + .references(() => sites.siteId, { onDelete: "cascade" }), +}); + export const olms = sqliteTable("olms", { olmId: text("id").primaryKey(), secretHash: text("secretHash").notNull(), @@ -516,6 +520,7 @@ export type ResourceWhitelist = InferSelectModel; export type VersionMigration = InferSelectModel; export type ResourceRule = InferSelectModel; export type Client = InferSelectModel; +export type ClientSite = InferSelectModel; export type RoleClient = InferSelectModel; export type UserClient = InferSelectModel; export type Domain = InferSelectModel; diff --git a/server/lib/config.ts b/server/lib/config.ts index b41be6ec..9df6a55f 100644 --- a/server/lib/config.ts +++ b/server/lib/config.ts @@ -105,11 +105,9 @@ const configSchema = z.object({ block_size: z.number().positive().gt(0), site_block_size: z.number().positive().gt(0) }), - newt: z.object({ + orgs: z.object({ block_size: z.number().positive().gt(0), subnet_group: z.string(), - start_port: portSchema, - site_block_size: z.number().positive().gt(0) }), rate_limits: z.object({ global: z.object({ diff --git a/server/lib/ip.ts b/server/lib/ip.ts index 86fe1169..a3a78027 100644 --- a/server/lib/ip.ts +++ b/server/lib/ip.ts @@ -1,3 +1,8 @@ +import db from "@server/db"; +import { clients, orgs, sites } from "@server/db/schema"; +import { and, eq, isNotNull } from "drizzle-orm"; +import config from "@server/lib/config"; + interface IPRange { start: bigint; end: bigint; @@ -204,4 +209,58 @@ export function isIpInCidr(ip: string, cidr: string): boolean { const ipBigInt = ipToBigInt(ip); const range = cidrToRange(cidr); return ipBigInt >= range.start && ipBigInt <= range.end; +} + +export async function getNextAvailableClientSubnet(orgId: string): Promise { + const existingAddressesSites = await db + .select({ + address: sites.address + }) + .from(sites) + .where(and(isNotNull(sites.address), eq(sites.orgId, orgId))); + + const existingAddressesClients = await db + .select({ + address: clients.subnet + }) + .from(clients) + .where(and(isNotNull(clients.subnet), eq(clients.orgId, orgId))); + + const addresses = [ + ...existingAddressesSites.map((site) => site.address), + ...existingAddressesClients.map((client) => client.address) + ].filter((address) => address !== null) as string[]; + + let subnet = findNextAvailableCidr( + addresses, + 32, + config.getRawConfig().orgs.subnet_group + ); // pick the sites address in the org + if (!subnet) { + throw new Error("No available subnets remaining in space"); + } + + return subnet; +} + +export async function getNextAvailableOrgSubnet(): Promise { + const existingAddresses = await db + .select({ + subnet: orgs.subnet + }) + .from(orgs) + .where(isNotNull(orgs.subnet)); + + const addresses = existingAddresses.map((org) => org.subnet); + + let subnet = findNextAvailableCidr( + addresses, + config.getRawConfig().orgs.block_size, + config.getRawConfig().orgs.subnet_group + ); + if (!subnet) { + throw new Error("No available subnets remaining in space"); + } + + return subnet; } \ No newline at end of file diff --git a/server/routers/client/pickClientDefaults.ts b/server/routers/client/pickClientDefaults.ts index d77ae3bb..5e87759d 100644 --- a/server/routers/client/pickClientDefaults.ts +++ b/server/routers/client/pickClientDefaults.ts @@ -6,7 +6,7 @@ import response from "@server/lib/response"; import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; import logger from "@server/logger"; -import { findNextAvailableCidr } from "@server/lib/ip"; +import { findNextAvailableCidr, getNextAvailableClientSubnet } from "@server/lib/ip"; import { generateId } from "@server/auth/sessions/app"; import config from "@server/lib/config"; import { z } from "zod"; @@ -88,37 +88,8 @@ export async function pickClientDefaults( const { address, publicKey, listenPort, endpoint } = parsedSite.data; - const clientsQuery = await db - .select({ - subnet: clients.subnet - }) - .from(clients) - .where(eq(clients.siteId, site.siteId)); - - let subnets = clientsQuery.map((client) => client.subnet); - - // exclude the exit node address by replacing after the / with a site block size - subnets.push( - address.replace( - /\/\d+$/, - `/${config.getRawConfig().newt.site_block_size}` - ) - ); - - const newSubnet = findNextAvailableCidr( - subnets, - config.getRawConfig().newt.site_block_size, - address - ); - if (!newSubnet) { - return next( - createHttpError( - HttpCode.INTERNAL_SERVER_ERROR, - "No available subnets" - ) - ); - } - + const newSubnet = await getNextAvailableClientSubnet(site.orgId); + const olmId = generateId(15); const secret = generateId(48); @@ -130,8 +101,7 @@ export async function pickClientDefaults( name: site.name, listenPort: listenPort, endpoint: endpoint, - // subnet: `${newSubnet.split("/")[0]}/${config.getRawConfig().newt.block_size}`, // we want the block size of the whole subnet - subnet: newSubnet, + subnet: `${newSubnet.split("/")[0]}/${config.getRawConfig().orgs.block_size}`, // we want the block size of the whole org olmId: olmId, olmSecret: secret }, diff --git a/server/routers/newt/handleGetConfigMessage.ts b/server/routers/newt/handleGetConfigMessage.ts index 62934d2f..00b5ee64 100644 --- a/server/routers/newt/handleGetConfigMessage.ts +++ b/server/routers/newt/handleGetConfigMessage.ts @@ -3,13 +3,12 @@ import { MessageHandler } from "../ws"; import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import db from "@server/db"; -import { clients, Newt, Site, sites } from "@server/db/schema"; -import { eq, isNotNull } from "drizzle-orm"; -import { findNextAvailableCidr } from "@server/lib/ip"; -import config from "@server/lib/config"; +import { clients, clientSites, Newt, Site, sites } from "@server/db/schema"; +import { eq } from "drizzle-orm"; +import { getNextAvailableClientSubnet } from "@server/lib/ip"; const inputSchema = z.object({ - publicKey: z.string(), + publicKey: z.string() }); type Input = z.infer; @@ -57,16 +56,15 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { let site: Site | undefined; if (!siteRes.address) { - const address = await getNextAvailableSubnet(); - const listenPort = await getNextAvailablePort(); + let address = await getNextAvailableClientSubnet(siteRes.orgId); + address = address.split("/")[0]; // get the first part of the CIDR // create a new exit node const [updateRes] = await db .update(sites) .set({ publicKey, - address, - listenPort + address }) .where(eq(sites.siteId, siteId)) .returning(); @@ -95,28 +93,33 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { const clientsRes = await db .select() .from(clients) - .where(eq(clients.siteId, siteId)); + .innerJoin(clientSites, eq(clients.clientId, clientSites.clientId)) + .where(eq(clientSites.siteId, siteId)); const now = new Date().getTime() / 1000; const peers = await Promise.all( clientsRes .filter((client) => { - if (client.lastHolePunch && now - client.lastHolePunch > 6) { + // This filter wasn't returning anything - fixed to properly filter clients + if ( + !client.clients.lastHolePunch || + now - client.clients.lastHolePunch > 6 + ) { logger.warn("Client last hole punch is too old"); - return; + return false; } + return true; }) .map(async (client) => { return { - publicKey: client.pubKey, - allowedIps: [client.subnet], - endpoint: client.endpoint + publicKey: client.clients.pubKey, + allowedIps: [client.clients.subnet], + endpoint: client.clients.endpoint }; }) ); const configResponse = { - listenPort: site.listenPort, ipAddress: site.address, peers }; @@ -133,57 +136,4 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { broadcast: false, // Send to all clients excludeSender: false // Include sender in broadcast }; -}; - -async function getNextAvailableSubnet(): Promise { - const existingAddresses = await db - .select({ - address: sites.address - }) - .from(sites) - .where(isNotNull(sites.address)); - - const addresses = existingAddresses - .map((a) => a.address) - .filter((a) => a) as string[]; - - let subnet = findNextAvailableCidr( - addresses, - config.getRawConfig().newt.block_size, - config.getRawConfig().newt.subnet_group - ); - if (!subnet) { - throw new Error("No available subnets remaining in space"); - } - - // replace the last octet with 1 - subnet = - subnet.split(".").slice(0, 3).join(".") + - ".1" + - "/" + - subnet.split("/")[1]; - return subnet; -} - -async function getNextAvailablePort(): Promise { - // Get all existing ports from exitNodes table - const existingPorts = await db - .select({ - listenPort: sites.listenPort - }) - .from(sites); - - // Find the first available port between 1024 and 65535 - let nextPort = config.getRawConfig().newt.start_port; - for (const port of existingPorts) { - if (port.listenPort && port.listenPort > nextPort) { - break; - } - nextPort++; - if (nextPort > 65535) { - throw new Error("No available ports remaining in space"); - } - } - - return nextPort; -} +}; \ No newline at end of file diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index 4bf46744..ac658b78 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -1,91 +1,53 @@ import db from "@server/db"; import { MessageHandler } from "../ws"; -import { clients, exitNodes, Olm, olms, sites } from "@server/db/schema"; -import { eq } from "drizzle-orm"; +import { clients, clientSites, exitNodes, Olm, olms, sites } from "@server/db/schema"; +import { eq, inArray } from "drizzle-orm"; import { addPeer, deletePeer } from "../newt/peers"; import logger from "@server/logger"; export const handleOlmRegisterMessage: MessageHandler = async (context) => { const { message, client: c, sendToClient } = context; const olm = c as Olm; - logger.info("Handling register olm message!"); - if (!olm) { logger.warn("Olm not found"); return; } - if (!olm.clientId) { - logger.warn("Olm has no site!"); // TODO: Maybe we create the site here? + logger.warn("Olm has no client ID!"); return; } - const clientId = olm.clientId; - const { publicKey } = message.data; if (!publicKey) { logger.warn("Public key not provided"); return; } - + + // Get the client const [client] = await db .select() .from(clients) .where(eq(clients.clientId, clientId)) .limit(1); - - if (!client || !client.siteId) { - logger.warn("Site not found or does not have exit node"); + + if (!client) { + logger.warn("Client not found"); return; } - - const [site] = await db + + // Get all site associations for this client + const clientSiteAssociations = await db .select() - .from(sites) - .where(eq(sites.siteId, client.siteId)) - .limit(1); - - if (!site) { - logger.warn("Site not found or does not have exit node"); + .from(clientSites) + .where(eq(clientSites.clientId, clientId)); + + if (clientSiteAssociations.length === 0) { + logger.warn("Client is not associated with any sites"); return; } - - if (!site.exitNodeId) { - logger.warn("Site does not have exit node"); - return; - } - - const [exitNode] = await db - .select() - .from(exitNodes) - .where(eq(exitNodes.exitNodeId, site.exitNodeId)) - .limit(1); - - sendToClient(olm.olmId, { - type: "olm/wg/holepunch", - data: { - serverPubKey: exitNode.publicKey, - } - }); - - // make sure we hand endpoints for both the site and the client and the lastHolePunch is not too old - if (!site.endpoint || !client.endpoint) { - logger.warn("Site or client has no endpoint or listen port"); - return; - } - - const now = new Date().getTime() / 1000; - if (site.lastHolePunch && now - site.lastHolePunch > 6) { - logger.warn("Site last hole punch is too old"); - return; - } - - if (client.lastHolePunch && now - client.lastHolePunch > 6) { - logger.warn("Client last hole punch is too old"); - return; - } - + + // Update the client's public key await db .update(clients) .set({ @@ -93,35 +55,103 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { }) .where(eq(clients.clientId, olm.clientId)) .returning(); - - if (client.pubKey && client.pubKey !== publicKey) { - logger.info("Public key mismatch. Deleting old peer..."); - await deletePeer(site.siteId, client.pubKey); + + // Check if public key changed and handle old peer deletion later + const pubKeyChanged = client.pubKey && client.pubKey !== publicKey; + + // Get all sites data + const siteIds = clientSiteAssociations.map(cs => cs.siteId); + const sitesData = await db + .select() + .from(sites) + .where(inArray(sites.siteId, siteIds)); + + // Prepare an array to store site configurations + const siteConfigurations = []; + const now = new Date().getTime() / 1000; + + // Process each site + for (const site of sitesData) { + if (!site.exitNodeId) { + logger.warn(`Site ${site.siteId} does not have exit node, skipping`); + continue; + } + + // Get the exit node for this site + const [exitNode] = await db + .select() + .from(exitNodes) + .where(eq(exitNodes.exitNodeId, site.exitNodeId)) + .limit(1); + + // Validate endpoint and hole punch status + if (!site.endpoint) { + logger.warn(`Site ${site.siteId} has no endpoint, skipping`); + continue; + } + + if (site.lastHolePunch && now - site.lastHolePunch > 6) { + logger.warn(`Site ${site.siteId} last hole punch is too old, skipping`); + continue; + } + + if (client.lastHolePunch && now - client.lastHolePunch > 6) { + logger.warn("Client last hole punch is too old, skipping all sites"); + break; + } + + // If public key changed, delete old peer from this site + if (pubKeyChanged) { + logger.info(`Public key mismatch. Deleting old peer from site ${site.siteId}...`); + await deletePeer(site.siteId, client.pubKey); + } + + if (!site.subnet) { + logger.warn(`Site ${site.siteId} has no subnet, skipping`); + continue; + } + + // Add the peer to the exit node for this site + await addPeer(site.siteId, { + publicKey: publicKey, + allowedIps: [client.subnet], + endpoint: client.endpoint + }); + + // Add site configuration to the array + siteConfigurations.push({ + siteId: site.siteId, + endpoint: site.endpoint, + publicKey: site.publicKey, + serverIP: site.address, + }); + + // Send holepunch message for each site + sendToClient(olm.olmId, { + type: "olm/wg/holepunch", + data: { + serverPubKey: exitNode.publicKey, + siteId: site.siteId + } + }); } - - if (!site.subnet) { - logger.warn("Site has no subnet"); + + // If we have no valid site configurations, don't send a connect message + if (siteConfigurations.length === 0) { + logger.warn("No valid site configurations found"); return; } - - // add the peer to the exit node - await addPeer(site.siteId, { - publicKey: publicKey, - allowedIps: [client.subnet], - endpoint: client.endpoint - }); - + + // Return connect message with all site configurations return { message: { type: "olm/wg/connect", data: { - endpoint: site.endpoint, - publicKey: site.publicKey, - serverIP: site.address!.split("/")[0], - tunnelIP: `${client.subnet.split("/")[0]}/${site.address!.split("/")[1]}` // put the client ip in the same subnet as the site. TODO: Is this right? Maybe we need th make .subnet work properly! + sites: siteConfigurations, + tunnelIP: client.subnet, } }, - broadcast: false, // Send to all olms - excludeSender: false // Include sender in broadcast + broadcast: false, + excludeSender: false }; -}; +}; \ No newline at end of file diff --git a/server/routers/org/createOrg.ts b/server/routers/org/createOrg.ts index 381ce20e..fef5e2ac 100644 --- a/server/routers/org/createOrg.ts +++ b/server/routers/org/createOrg.ts @@ -19,6 +19,7 @@ import { createAdminRole } from "@server/setup/ensureActions"; import config from "@server/lib/config"; import { fromError } from "zod-validation-error"; import { defaultRoleAllowedActions } from "../role"; +import { getNextAvailableOrgSubnet } from "@server/lib/ip"; const createOrgSchema = z .object({ @@ -88,6 +89,8 @@ export async function createOrg( let error = ""; let org: Org | null = null; + const subnet = await getNextAvailableOrgSubnet(); + await db.transaction(async (trx) => { const allDomains = await trx .select() @@ -98,7 +101,8 @@ export async function createOrg( .insert(orgs) .values({ orgId, - name + name, + subnet, }) .returning();