Start changes for multi site clients

- Org subnet and assign sites and clients out of the same subnet group
  on each org
- Add join table for client on multiple sites
- Start to handle websocket endpoints for these multiple connections
This commit is contained in:
Owen 2025-03-25 22:01:08 -04:00
parent fbe7e0a427
commit 87012c47ea
No known key found for this signature in database
GPG key ID: 8271FDFFD9E0CCBD
8 changed files with 210 additions and 196 deletions

View file

@ -38,11 +38,9 @@ gerbil:
site_block_size: 30
subnet_group: 100.89.137.0/20
newt:
start_port: 51820
orgs:
block_size: 24
subnet_group: 100.89.138.0/20
site_block_size: 30
rate_limits:
global:

View file

@ -11,7 +11,8 @@ export const domains = sqliteTable("domains", {
export const orgs = sqliteTable("orgs", {
orgId: text("orgId").primaryKey(),
name: text("name").notNull()
name: text("name").notNull(),
subnet: text("subnet").notNull(),
});
export const orgDomains = sqliteTable("orgDomains", {
@ -47,7 +48,6 @@ export const sites = sqliteTable("sites", {
address: text("address"), // this is the address of the wireguard interface in gerbil
endpoint: text("endpoint"), // this is how to reach gerbil externally - gets put into the wireguard config
publicKey: text("pubicKey"),
listenPort: integer("listenPort"),
lastHolePunch: integer("lastHolePunch"),
});
@ -138,11 +138,6 @@ export const newts = sqliteTable("newt", {
export const clients = sqliteTable("clients", {
clientId: integer("id").primaryKey({ autoIncrement: true }),
siteId: integer("siteId")
.references(() => sites.siteId, {
onDelete: "cascade"
})
.notNull(),
orgId: text("orgId")
.references(() => orgs.orgId, {
onDelete: "cascade"
@ -160,6 +155,15 @@ export const clients = sqliteTable("clients", {
lastHolePunch: integer("lastHolePunch"),
});
export const clientSites = sqliteTable("clientSites", {
clientId: integer("clientId")
.notNull()
.references(() => clients.clientId, { onDelete: "cascade" }),
siteId: integer("siteId")
.notNull()
.references(() => sites.siteId, { onDelete: "cascade" }),
});
export const olms = sqliteTable("olms", {
olmId: text("id").primaryKey(),
secretHash: text("secretHash").notNull(),
@ -516,6 +520,7 @@ export type ResourceWhitelist = InferSelectModel<typeof resourceWhitelist>;
export type VersionMigration = InferSelectModel<typeof versionMigrations>;
export type ResourceRule = InferSelectModel<typeof resourceRules>;
export type Client = InferSelectModel<typeof clients>;
export type ClientSite = InferSelectModel<typeof clientSites>;
export type RoleClient = InferSelectModel<typeof roleClients>;
export type UserClient = InferSelectModel<typeof userClients>;
export type Domain = InferSelectModel<typeof domains>;

View file

@ -105,11 +105,9 @@ const configSchema = z.object({
block_size: z.number().positive().gt(0),
site_block_size: z.number().positive().gt(0)
}),
newt: z.object({
orgs: z.object({
block_size: z.number().positive().gt(0),
subnet_group: z.string(),
start_port: portSchema,
site_block_size: z.number().positive().gt(0)
}),
rate_limits: z.object({
global: z.object({

View file

@ -1,3 +1,8 @@
import db from "@server/db";
import { clients, orgs, sites } from "@server/db/schema";
import { and, eq, isNotNull } from "drizzle-orm";
import config from "@server/lib/config";
interface IPRange {
start: bigint;
end: bigint;
@ -205,3 +210,57 @@ export function isIpInCidr(ip: string, cidr: string): boolean {
const range = cidrToRange(cidr);
return ipBigInt >= range.start && ipBigInt <= range.end;
}
export async function getNextAvailableClientSubnet(orgId: string): Promise<string> {
const existingAddressesSites = await db
.select({
address: sites.address
})
.from(sites)
.where(and(isNotNull(sites.address), eq(sites.orgId, orgId)));
const existingAddressesClients = await db
.select({
address: clients.subnet
})
.from(clients)
.where(and(isNotNull(clients.subnet), eq(clients.orgId, orgId)));
const addresses = [
...existingAddressesSites.map((site) => site.address),
...existingAddressesClients.map((client) => client.address)
].filter((address) => address !== null) as string[];
let subnet = findNextAvailableCidr(
addresses,
32,
config.getRawConfig().orgs.subnet_group
); // pick the sites address in the org
if (!subnet) {
throw new Error("No available subnets remaining in space");
}
return subnet;
}
export async function getNextAvailableOrgSubnet(): Promise<string> {
const existingAddresses = await db
.select({
subnet: orgs.subnet
})
.from(orgs)
.where(isNotNull(orgs.subnet));
const addresses = existingAddresses.map((org) => org.subnet);
let subnet = findNextAvailableCidr(
addresses,
config.getRawConfig().orgs.block_size,
config.getRawConfig().orgs.subnet_group
);
if (!subnet) {
throw new Error("No available subnets remaining in space");
}
return subnet;
}

View file

@ -6,7 +6,7 @@ import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { findNextAvailableCidr } from "@server/lib/ip";
import { findNextAvailableCidr, getNextAvailableClientSubnet } from "@server/lib/ip";
import { generateId } from "@server/auth/sessions/app";
import config from "@server/lib/config";
import { z } from "zod";
@ -88,36 +88,7 @@ export async function pickClientDefaults(
const { address, publicKey, listenPort, endpoint } = parsedSite.data;
const clientsQuery = await db
.select({
subnet: clients.subnet
})
.from(clients)
.where(eq(clients.siteId, site.siteId));
let subnets = clientsQuery.map((client) => client.subnet);
// exclude the exit node address by replacing after the / with a site block size
subnets.push(
address.replace(
/\/\d+$/,
`/${config.getRawConfig().newt.site_block_size}`
)
);
const newSubnet = findNextAvailableCidr(
subnets,
config.getRawConfig().newt.site_block_size,
address
);
if (!newSubnet) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"No available subnets"
)
);
}
const newSubnet = await getNextAvailableClientSubnet(site.orgId);
const olmId = generateId(15);
const secret = generateId(48);
@ -130,8 +101,7 @@ export async function pickClientDefaults(
name: site.name,
listenPort: listenPort,
endpoint: endpoint,
// subnet: `${newSubnet.split("/")[0]}/${config.getRawConfig().newt.block_size}`, // we want the block size of the whole subnet
subnet: newSubnet,
subnet: `${newSubnet.split("/")[0]}/${config.getRawConfig().orgs.block_size}`, // we want the block size of the whole org
olmId: olmId,
olmSecret: secret
},

View file

@ -3,13 +3,12 @@ import { MessageHandler } from "../ws";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import db from "@server/db";
import { clients, Newt, Site, sites } from "@server/db/schema";
import { eq, isNotNull } from "drizzle-orm";
import { findNextAvailableCidr } from "@server/lib/ip";
import config from "@server/lib/config";
import { clients, clientSites, Newt, Site, sites } from "@server/db/schema";
import { eq } from "drizzle-orm";
import { getNextAvailableClientSubnet } from "@server/lib/ip";
const inputSchema = z.object({
publicKey: z.string(),
publicKey: z.string()
});
type Input = z.infer<typeof inputSchema>;
@ -57,16 +56,15 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
let site: Site | undefined;
if (!siteRes.address) {
const address = await getNextAvailableSubnet();
const listenPort = await getNextAvailablePort();
let address = await getNextAvailableClientSubnet(siteRes.orgId);
address = address.split("/")[0]; // get the first part of the CIDR
// create a new exit node
const [updateRes] = await db
.update(sites)
.set({
publicKey,
address,
listenPort
address
})
.where(eq(sites.siteId, siteId))
.returning();
@ -95,28 +93,33 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
const clientsRes = await db
.select()
.from(clients)
.where(eq(clients.siteId, siteId));
.innerJoin(clientSites, eq(clients.clientId, clientSites.clientId))
.where(eq(clientSites.siteId, siteId));
const now = new Date().getTime() / 1000;
const peers = await Promise.all(
clientsRes
.filter((client) => {
if (client.lastHolePunch && now - client.lastHolePunch > 6) {
// This filter wasn't returning anything - fixed to properly filter clients
if (
!client.clients.lastHolePunch ||
now - client.clients.lastHolePunch > 6
) {
logger.warn("Client last hole punch is too old");
return;
return false;
}
return true;
})
.map(async (client) => {
return {
publicKey: client.pubKey,
allowedIps: [client.subnet],
endpoint: client.endpoint
publicKey: client.clients.pubKey,
allowedIps: [client.clients.subnet],
endpoint: client.clients.endpoint
};
})
);
const configResponse = {
listenPort: site.listenPort,
ipAddress: site.address,
peers
};
@ -134,56 +137,3 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
excludeSender: false // Include sender in broadcast
};
};
async function getNextAvailableSubnet(): Promise<string> {
const existingAddresses = await db
.select({
address: sites.address
})
.from(sites)
.where(isNotNull(sites.address));
const addresses = existingAddresses
.map((a) => a.address)
.filter((a) => a) as string[];
let subnet = findNextAvailableCidr(
addresses,
config.getRawConfig().newt.block_size,
config.getRawConfig().newt.subnet_group
);
if (!subnet) {
throw new Error("No available subnets remaining in space");
}
// replace the last octet with 1
subnet =
subnet.split(".").slice(0, 3).join(".") +
".1" +
"/" +
subnet.split("/")[1];
return subnet;
}
async function getNextAvailablePort(): Promise<number> {
// Get all existing ports from exitNodes table
const existingPorts = await db
.select({
listenPort: sites.listenPort
})
.from(sites);
// Find the first available port between 1024 and 65535
let nextPort = config.getRawConfig().newt.start_port;
for (const port of existingPorts) {
if (port.listenPort && port.listenPort > nextPort) {
break;
}
nextPort++;
if (nextPort > 65535) {
throw new Error("No available ports remaining in space");
}
}
return nextPort;
}

View file

@ -1,91 +1,53 @@
import db from "@server/db";
import { MessageHandler } from "../ws";
import { clients, exitNodes, Olm, olms, sites } from "@server/db/schema";
import { eq } from "drizzle-orm";
import { clients, clientSites, exitNodes, Olm, olms, sites } from "@server/db/schema";
import { eq, inArray } from "drizzle-orm";
import { addPeer, deletePeer } from "../newt/peers";
import logger from "@server/logger";
export const handleOlmRegisterMessage: MessageHandler = async (context) => {
const { message, client: c, sendToClient } = context;
const olm = c as Olm;
logger.info("Handling register olm message!");
if (!olm) {
logger.warn("Olm not found");
return;
}
if (!olm.clientId) {
logger.warn("Olm has no site!"); // TODO: Maybe we create the site here?
logger.warn("Olm has no client ID!");
return;
}
const clientId = olm.clientId;
const { publicKey } = message.data;
if (!publicKey) {
logger.warn("Public key not provided");
return;
}
// Get the client
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client || !client.siteId) {
logger.warn("Site not found or does not have exit node");
if (!client) {
logger.warn("Client not found");
return;
}
const [site] = await db
// Get all site associations for this client
const clientSiteAssociations = await db
.select()
.from(sites)
.where(eq(sites.siteId, client.siteId))
.limit(1);
.from(clientSites)
.where(eq(clientSites.clientId, clientId));
if (!site) {
logger.warn("Site not found or does not have exit node");
return;
}
if (!site.exitNodeId) {
logger.warn("Site does not have exit node");
return;
}
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId))
.limit(1);
sendToClient(olm.olmId, {
type: "olm/wg/holepunch",
data: {
serverPubKey: exitNode.publicKey,
}
});
// make sure we hand endpoints for both the site and the client and the lastHolePunch is not too old
if (!site.endpoint || !client.endpoint) {
logger.warn("Site or client has no endpoint or listen port");
return;
}
const now = new Date().getTime() / 1000;
if (site.lastHolePunch && now - site.lastHolePunch > 6) {
logger.warn("Site last hole punch is too old");
return;
}
if (client.lastHolePunch && now - client.lastHolePunch > 6) {
logger.warn("Client last hole punch is too old");
if (clientSiteAssociations.length === 0) {
logger.warn("Client is not associated with any sites");
return;
}
// Update the client's public key
await db
.update(clients)
.set({
@ -94,34 +56,102 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
.where(eq(clients.clientId, olm.clientId))
.returning();
if (client.pubKey && client.pubKey !== publicKey) {
logger.info("Public key mismatch. Deleting old peer...");
await deletePeer(site.siteId, client.pubKey);
// Check if public key changed and handle old peer deletion later
const pubKeyChanged = client.pubKey && client.pubKey !== publicKey;
// Get all sites data
const siteIds = clientSiteAssociations.map(cs => cs.siteId);
const sitesData = await db
.select()
.from(sites)
.where(inArray(sites.siteId, siteIds));
// Prepare an array to store site configurations
const siteConfigurations = [];
const now = new Date().getTime() / 1000;
// Process each site
for (const site of sitesData) {
if (!site.exitNodeId) {
logger.warn(`Site ${site.siteId} does not have exit node, skipping`);
continue;
}
// Get the exit node for this site
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId))
.limit(1);
// Validate endpoint and hole punch status
if (!site.endpoint) {
logger.warn(`Site ${site.siteId} has no endpoint, skipping`);
continue;
}
if (site.lastHolePunch && now - site.lastHolePunch > 6) {
logger.warn(`Site ${site.siteId} last hole punch is too old, skipping`);
continue;
}
if (client.lastHolePunch && now - client.lastHolePunch > 6) {
logger.warn("Client last hole punch is too old, skipping all sites");
break;
}
// If public key changed, delete old peer from this site
if (pubKeyChanged) {
logger.info(`Public key mismatch. Deleting old peer from site ${site.siteId}...`);
await deletePeer(site.siteId, client.pubKey);
}
if (!site.subnet) {
logger.warn(`Site ${site.siteId} has no subnet, skipping`);
continue;
}
// Add the peer to the exit node for this site
await addPeer(site.siteId, {
publicKey: publicKey,
allowedIps: [client.subnet],
endpoint: client.endpoint
});
// Add site configuration to the array
siteConfigurations.push({
siteId: site.siteId,
endpoint: site.endpoint,
publicKey: site.publicKey,
serverIP: site.address,
});
// Send holepunch message for each site
sendToClient(olm.olmId, {
type: "olm/wg/holepunch",
data: {
serverPubKey: exitNode.publicKey,
siteId: site.siteId
}
});
}
if (!site.subnet) {
logger.warn("Site has no subnet");
// If we have no valid site configurations, don't send a connect message
if (siteConfigurations.length === 0) {
logger.warn("No valid site configurations found");
return;
}
// add the peer to the exit node
await addPeer(site.siteId, {
publicKey: publicKey,
allowedIps: [client.subnet],
endpoint: client.endpoint
});
// Return connect message with all site configurations
return {
message: {
type: "olm/wg/connect",
data: {
endpoint: site.endpoint,
publicKey: site.publicKey,
serverIP: site.address!.split("/")[0],
tunnelIP: `${client.subnet.split("/")[0]}/${site.address!.split("/")[1]}` // put the client ip in the same subnet as the site. TODO: Is this right? Maybe we need th make .subnet work properly!
sites: siteConfigurations,
tunnelIP: client.subnet,
}
},
broadcast: false, // Send to all olms
excludeSender: false // Include sender in broadcast
broadcast: false,
excludeSender: false
};
};

View file

@ -19,6 +19,7 @@ import { createAdminRole } from "@server/setup/ensureActions";
import config from "@server/lib/config";
import { fromError } from "zod-validation-error";
import { defaultRoleAllowedActions } from "../role";
import { getNextAvailableOrgSubnet } from "@server/lib/ip";
const createOrgSchema = z
.object({
@ -88,6 +89,8 @@ export async function createOrg(
let error = "";
let org: Org | null = null;
const subnet = await getNextAvailableOrgSubnet();
await db.transaction(async (trx) => {
const allDomains = await trx
.select()
@ -98,7 +101,8 @@ export async function createOrg(
.insert(orgs)
.values({
orgId,
name
name,
subnet,
})
.returning();