Finish conversion of olm reg to multi site

This commit is contained in:
Owen 2025-03-26 21:23:26 -04:00
parent 87012c47ea
commit 926ec831e2
No known key found for this signature in database
GPG key ID: 8271FDFFD9E0CCBD
2 changed files with 68 additions and 58 deletions

View file

@ -143,6 +143,9 @@ export const clients = sqliteTable("clients", {
onDelete: "cascade" onDelete: "cascade"
}) })
.notNull(), .notNull(),
exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, {
onDelete: "set null"
}),
name: text("name").notNull(), name: text("name").notNull(),
pubKey: text("pubKey"), pubKey: text("pubKey"),
subnet: text("subnet").notNull(), subnet: text("subnet").notNull(),

View file

@ -1,6 +1,13 @@
import db from "@server/db"; import db from "@server/db";
import { MessageHandler } from "../ws"; import { MessageHandler } from "../ws";
import { clients, clientSites, exitNodes, Olm, olms, sites } from "@server/db/schema"; import {
clients,
clientSites,
exitNodes,
Olm,
olms,
sites
} from "@server/db/schema";
import { eq, inArray } from "drizzle-orm"; import { eq, inArray } from "drizzle-orm";
import { addPeer, deletePeer } from "../newt/peers"; import { addPeer, deletePeer } from "../newt/peers";
import logger from "@server/logger"; import logger from "@server/logger";
@ -36,15 +43,21 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
return; return;
} }
// Get all site associations for this client if (client.exitNodeId) {
const clientSiteAssociations = await db // Get the exit node for this site
.select() const [exitNode] = await db
.from(clientSites) .select()
.where(eq(clientSites.clientId, clientId)); .from(exitNodes)
.where(eq(exitNodes.exitNodeId, client.exitNodeId))
.limit(1);
if (clientSiteAssociations.length === 0) { // Send holepunch message for each site
logger.warn("Client is not associated with any sites"); sendToClient(olm.olmId, {
return; type: "olm/wg/holepunch",
data: {
serverPubKey: exitNode.publicKey
}
});
} }
// Update the client's public key // Update the client's public key
@ -60,30 +73,25 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
const pubKeyChanged = client.pubKey && client.pubKey !== publicKey; const pubKeyChanged = client.pubKey && client.pubKey !== publicKey;
// Get all sites data // Get all sites data
const siteIds = clientSiteAssociations.map(cs => cs.siteId);
const sitesData = await db const sitesData = await db
.select() .select()
.from(sites) .from(sites)
.where(inArray(sites.siteId, siteIds)); .innerJoin(clientSites, eq(sites.siteId, clientSites.siteId))
.where(eq(clientSites.clientId, client.clientId));
// Prepare an array to store site configurations // Prepare an array to store site configurations
const siteConfigurations = []; const siteConfigurations = [];
const now = new Date().getTime() / 1000; const now = new Date().getTime() / 1000;
// Process each site // Process each site
for (const site of sitesData) { for (const { sites: site } of sitesData) {
if (!site.exitNodeId) { if (!site.exitNodeId) {
logger.warn(`Site ${site.siteId} does not have exit node, skipping`); logger.warn(
`Site ${site.siteId} does not have exit node, skipping`
);
continue; continue;
} }
// Get the exit node for this site
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId))
.limit(1);
// Validate endpoint and hole punch status // Validate endpoint and hole punch status
if (!site.endpoint) { if (!site.endpoint) {
logger.warn(`Site ${site.siteId} has no endpoint, skipping`); logger.warn(`Site ${site.siteId} has no endpoint, skipping`);
@ -91,19 +99,25 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
} }
if (site.lastHolePunch && now - site.lastHolePunch > 6) { if (site.lastHolePunch && now - site.lastHolePunch > 6) {
logger.warn(`Site ${site.siteId} last hole punch is too old, skipping`); logger.warn(
`Site ${site.siteId} last hole punch is too old, skipping`
);
continue; continue;
} }
if (client.lastHolePunch && now - client.lastHolePunch > 6) { if (client.lastHolePunch && now - client.lastHolePunch > 6) {
logger.warn("Client last hole punch is too old, skipping all sites"); logger.warn(
"Client last hole punch is too old, skipping all sites"
);
break; break;
} }
// If public key changed, delete old peer from this site // If public key changed, delete old peer from this site
if (pubKeyChanged) { if (pubKeyChanged) {
logger.info(`Public key mismatch. Deleting old peer from site ${site.siteId}...`); logger.info(
await deletePeer(site.siteId, client.pubKey); `Public key mismatch. Deleting old peer from site ${site.siteId}...`
);
await deletePeer(site.siteId, client.pubKey!);
} }
if (!site.subnet) { if (!site.subnet) {
@ -112,27 +126,20 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
} }
// Add the peer to the exit node for this site // Add the peer to the exit node for this site
await addPeer(site.siteId, { if (client.endpoint) {
publicKey: publicKey, await addPeer(site.siteId, {
allowedIps: [client.subnet], publicKey: publicKey,
endpoint: client.endpoint allowedIps: [client.subnet],
}); endpoint: client.endpoint
});
}
// Add site configuration to the array // Add site configuration to the array
siteConfigurations.push({ siteConfigurations.push({
siteId: site.siteId, siteId: site.siteId,
endpoint: site.endpoint, endpoint: site.endpoint,
publicKey: site.publicKey, publicKey: site.publicKey,
serverIP: site.address, serverIP: site.address
});
// Send holepunch message for each site
sendToClient(olm.olmId, {
type: "olm/wg/holepunch",
data: {
serverPubKey: exitNode.publicKey,
siteId: site.siteId
}
}); });
} }
@ -148,7 +155,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
type: "olm/wg/connect", type: "olm/wg/connect",
data: { data: {
sites: siteConfigurations, sites: siteConfigurations,
tunnelIP: client.subnet, tunnelIP: client.subnet
} }
}, },
broadcast: false, broadcast: false,