From 5c94887949a5e6e2be68643b13e5ece0cc51a6d1 Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 15 Aug 2025 15:45:45 -0700 Subject: [PATCH] Use new exit node functions --- server/lib/exitNodes/exitNodes.ts | 2 +- server/routers/client/createClient.ts | 20 +++------- .../newt/handleNewtPingRequestMessage.ts | 24 +++++++---- .../routers/newt/handleNewtRegisterMessage.ts | 25 ++++++++---- server/routers/site/createSite.ts | 29 +++++++++++++- server/routers/site/pickSiteDefaults.ts | 40 +++++++++---------- 6 files changed, 87 insertions(+), 53 deletions(-) diff --git a/server/lib/exitNodes/exitNodes.ts b/server/lib/exitNodes/exitNodes.ts index 7b25873e..f5854e27 100644 --- a/server/lib/exitNodes/exitNodes.ts +++ b/server/lib/exitNodes/exitNodes.ts @@ -2,7 +2,7 @@ import { db, exitNodes } from "@server/db"; import logger from "@server/logger"; import { eq, and, or } from "drizzle-orm"; -export async function privateVerifyExitNodeOrgAccess( +export async function verifyExitNodeOrgAccess( exitNodeId: number, orgId: string ) { diff --git a/server/routers/client/createClient.ts b/server/routers/client/createClient.ts index 4e9dcdce..e7762223 100644 --- a/server/routers/client/createClient.ts +++ b/server/routers/client/createClient.ts @@ -24,6 +24,7 @@ import { hashPassword } from "@server/auth/password"; import { isValidCIDR, isValidIP } from "@server/lib/validators"; import { isIpInCidr } from "@server/lib/ip"; import { OpenAPITags, registry } from "@server/openApi"; +import { listExitNodes } from "@server/lib/exitNodes"; const createClientParamsSchema = z .object({ @@ -177,20 +178,9 @@ export async function createClient( await db.transaction(async (trx) => { // TODO: more intelligent way to pick the exit node - - // make sure there is an exit node by counting the exit nodes table - const nodes = await db.select().from(exitNodes); - if (nodes.length === 0) { - return next( - createHttpError( - HttpCode.NOT_FOUND, - "No exit nodes available" - ) - ); - } - - // get the first exit node - const exitNode = nodes[0]; + const exitNodesList = await listExitNodes(orgId); + const randomExitNode = + exitNodesList[Math.floor(Math.random() * exitNodesList.length)]; const adminRole = await trx .select() @@ -208,7 +198,7 @@ export async function createClient( const [newClient] = await trx .insert(clients) .values({ - exitNodeId: exitNode.exitNodeId, + exitNodeId: randomExitNode.exitNodeId, orgId, name, subnet: updatedSubnet, diff --git a/server/routers/newt/handleNewtPingRequestMessage.ts b/server/routers/newt/handleNewtPingRequestMessage.ts index 65edea61..f93862f6 100644 --- a/server/routers/newt/handleNewtPingRequestMessage.ts +++ b/server/routers/newt/handleNewtPingRequestMessage.ts @@ -4,6 +4,7 @@ import { exitNodes, Newt } from "@server/db"; import logger from "@server/logger"; import config from "@server/lib/config"; import { ne, eq, or, and, count } from "drizzle-orm"; +import { listExitNodes } from "@server/lib/exitNodes"; export const handleNewtPingRequestMessage: MessageHandler = async (context) => { const { message, client, sendToClient } = context; @@ -16,12 +17,19 @@ export const handleNewtPingRequestMessage: MessageHandler = async (context) => { return; } - // TODO: pick which nodes to send and ping better than just all of them - let exitNodesList = await db - .select() - .from(exitNodes); + // Get the newt's orgId through the site relationship + if (!newt.siteId) { + logger.warn("Newt siteId not found"); + return; + } - exitNodesList = exitNodesList.filter((node) => node.maxConnections !== 0); + const [site] = await db + .select({ orgId: sites.orgId }) + .from(sites) + .where(eq(sites.siteId, newt.siteId)) + .limit(1); + + const exitNodesList = await listExitNodes(site.orgId, true); // filter for only the online ones let lastExitNodeId = null; if (newt.siteId) { @@ -54,9 +62,9 @@ export const handleNewtPingRequestMessage: MessageHandler = async (context) => { ) ); - if (currentConnections.count >= maxConnections) { - return null; - } + if (currentConnections.count >= maxConnections) { + return null; + } weight = (maxConnections - currentConnections.count) / diff --git a/server/routers/newt/handleNewtRegisterMessage.ts b/server/routers/newt/handleNewtRegisterMessage.ts index bb982c24..26aa3477 100644 --- a/server/routers/newt/handleNewtRegisterMessage.ts +++ b/server/routers/newt/handleNewtRegisterMessage.ts @@ -9,6 +9,7 @@ import { findNextAvailableCidr, getNextAvailableClientSubnet } from "@server/lib/ip"; +import { verifyExitNodeOrgAccess } from "@server/lib/exitNodes"; export type ExitNodePingResult = { exitNodeId: number; @@ -24,7 +25,7 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => { const { message, client, sendToClient } = context; const newt = client as Newt; - logger.info("Handling register newt message!"); + logger.debug("Handling register newt message!"); if (!newt) { logger.warn("Newt not found"); @@ -81,6 +82,18 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => { // This effectively moves the exit node to the new one exitNodeIdToQuery = exitNodeId; // Use the provided exitNodeId if it differs from the site's exitNodeId + const { exitNode, hasAccess } = await verifyExitNodeOrgAccess(exitNodeIdToQuery, oldSite.orgId); + + if (!exitNode) { + logger.warn("Exit node not found"); + return; + } + + if (!hasAccess) { + logger.warn("Not authorized to use this exit node"); + return; + } + const sitesQuery = await db .select({ subnet: sites.subnet @@ -88,14 +101,10 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => { .from(sites) .where(eq(sites.exitNodeId, exitNodeId)); - const [exitNode] = await db - .select() - .from(exitNodes) - .where(eq(exitNodes.exitNodeId, exitNodeIdToQuery)) - .limit(1); - const blockSize = config.getRawConfig().gerbil.site_block_size; - const subnets = sitesQuery.map((site) => site.subnet).filter((subnet) => subnet !== null); + const subnets = sitesQuery + .map((site) => site.subnet) + .filter((subnet) => subnet !== null); subnets.push(exitNode.address.replace(/\/\d+$/, `/${blockSize}`)); const newSubnet = findNextAvailableCidr( subnets, diff --git a/server/routers/site/createSite.ts b/server/routers/site/createSite.ts index fb1170cd..af8e4073 100644 --- a/server/routers/site/createSite.ts +++ b/server/routers/site/createSite.ts @@ -1,6 +1,6 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { clients, db } from "@server/db"; +import { clients, db, exitNodes } from "@server/db"; import { roles, userSites, sites, roleSites, Site, orgs } from "@server/db"; import response from "@server/lib/response"; import HttpCode from "@server/types/HttpCode"; @@ -17,6 +17,7 @@ import { hashPassword } from "@server/auth/password"; import { isValidIP } from "@server/lib/validators"; import { isIpInCidr } from "@server/lib/ip"; import config from "@server/lib/config"; +import { verifyExitNodeOrgAccess } from "@server/lib/exitNodes"; const createSiteParamsSchema = z .object({ @@ -217,6 +218,32 @@ export async function createSite( ); } + const { exitNode, hasAccess } = + await verifyExitNodeOrgAccess( + exitNodeId, + orgId + ); + + if (!exitNode) { + logger.warn("Exit node not found"); + return next( + createHttpError( + HttpCode.NOT_FOUND, + "Exit node not found" + ) + ); + } + + if (!hasAccess) { + logger.warn("Not authorized to use this exit node"); + return next( + createHttpError( + HttpCode.FORBIDDEN, + "Not authorized to use this exit node" + ) + ); + } + [newSite] = await trx .insert(sites) .values({ diff --git a/server/routers/site/pickSiteDefaults.ts b/server/routers/site/pickSiteDefaults.ts index d6309d0c..2e705c56 100644 --- a/server/routers/site/pickSiteDefaults.ts +++ b/server/routers/site/pickSiteDefaults.ts @@ -6,12 +6,16 @@ import response from "@server/lib/response"; import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; import logger from "@server/logger"; -import { findNextAvailableCidr, getNextAvailableClientSubnet } from "@server/lib/ip"; +import { + findNextAvailableCidr, + getNextAvailableClientSubnet +} from "@server/lib/ip"; import { generateId } from "@server/auth/sessions/app"; import config from "@server/lib/config"; import { OpenAPITags, registry } from "@server/openApi"; import { fromError } from "zod-validation-error"; import { z } from "zod"; +import { listExitNodes } from "@server/lib/exitNodes"; export type PickSiteDefaultsResponse = { exitNodeId: number; @@ -65,16 +69,10 @@ export async function pickSiteDefaults( const { orgId } = parsedParams.data; // TODO: more intelligent way to pick the exit node - // make sure there is an exit node by counting the exit nodes table - const nodes = await db.select().from(exitNodes); - if (nodes.length === 0) { - return next( - createHttpError(HttpCode.NOT_FOUND, "No exit nodes available") - ); - } + const exitNodesList = await listExitNodes(orgId); - // get the first exit node - const exitNode = nodes[0]; + const randomExitNode = + exitNodesList[Math.floor(Math.random() * exitNodesList.length)]; // TODO: this probably can be optimized... // list all of the sites on that exit node @@ -83,13 +81,15 @@ export async function pickSiteDefaults( subnet: sites.subnet }) .from(sites) - .where(eq(sites.exitNodeId, exitNode.exitNodeId)); + .where(eq(sites.exitNodeId, randomExitNode.exitNodeId)); // TODO: we need to lock this subnet for some time so someone else does not take it - const subnets = sitesQuery.map((site) => site.subnet).filter((subnet) => subnet !== null); + const subnets = sitesQuery + .map((site) => site.subnet) + .filter((subnet) => subnet !== null); // exclude the exit node address by replacing after the / with a site block size subnets.push( - exitNode.address.replace( + randomExitNode.address.replace( /\/\d+$/, `/${config.getRawConfig().gerbil.site_block_size}` ) @@ -97,7 +97,7 @@ export async function pickSiteDefaults( const newSubnet = findNextAvailableCidr( subnets, config.getRawConfig().gerbil.site_block_size, - exitNode.address + randomExitNode.address ); if (!newSubnet) { return next( @@ -125,12 +125,12 @@ export async function pickSiteDefaults( return response(res, { data: { - exitNodeId: exitNode.exitNodeId, - address: exitNode.address, - publicKey: exitNode.publicKey, - name: exitNode.name, - listenPort: exitNode.listenPort, - endpoint: exitNode.endpoint, + exitNodeId: randomExitNode.exitNodeId, + address: randomExitNode.address, + publicKey: randomExitNode.publicKey, + name: randomExitNode.name, + listenPort: randomExitNode.listenPort, + endpoint: randomExitNode.endpoint, // subnet: `${newSubnet.split("/")[0]}/${config.getRawConfig().gerbil.block_size}`, // we want the block size of the whole subnet subnet: newSubnet, clientAddress: clientAddress,