diff --git a/server/lib/readConfigFile.ts b/server/lib/readConfigFile.ts index b9f90d46..5d8500b4 100644 --- a/server/lib/readConfigFile.ts +++ b/server/lib/readConfigFile.ts @@ -162,6 +162,7 @@ export const configSchema = z gerbil: z .object({ exit_node_name: z.string().optional(), + max_connections: z.number().positive().gt(0).optional(), start_port: portSchema .optional() .default(51820) diff --git a/server/routers/newt/handleNewtPingRequestMessage.ts b/server/routers/newt/handleNewtPingRequestMessage.ts index 70bb52c3..946152a8 100644 --- a/server/routers/newt/handleNewtPingRequestMessage.ts +++ b/server/routers/newt/handleNewtPingRequestMessage.ts @@ -1,7 +1,9 @@ -import { db } from "@server/db"; +import { db, sites } from "@server/db"; import { MessageHandler } from "../ws"; import { exitNodes, Newt } from "@server/db"; import logger from "@server/logger"; +import config from "@server/lib/config"; +import { eq, and, count } from "drizzle-orm"; export const handleNewtPingRequestMessage: MessageHandler = async (context) => { const { message, client, sendToClient } = context; @@ -15,18 +17,41 @@ export const handleNewtPingRequestMessage: MessageHandler = async (context) => { } // TODO: pick which nodes to send and ping better than just all of them - const exitNodesList = await db - .select() - .from(exitNodes); + const exitNodesList = await db.select().from(exitNodes); - const exitNodesPayload = exitNodesList.map((node) => ({ - exitNodeId: node.exitNodeId, - exitNodeName: node.name, - endpoint: node.endpoint, - weight: 1 // TODO: Implement weight calculation if needed depending on load - // (MAX_CONNECTIONS - current_connections) / MAX_CONNECTIONS) - // higher = more desirable - })); + const exitNodesPayload = await Promise.all( + exitNodesList.map(async (node) => { + // (MAX_CONNECTIONS - current_connections) / MAX_CONNECTIONS) + // higher = more desirable + + let weight = 1; + const maxConnections = config.getRawConfig().gerbil.max_connections; + if (maxConnections !== undefined) { + const [currentConnections] = await db + .select({ + count: count() + }) + .from(sites) + .where( + and( + eq(sites.exitNodeId, node.exitNodeId), + eq(sites.online, true) + ) + ); + + weight = + (maxConnections - currentConnections.count) / + maxConnections; + } + + return { + exitNodeId: node.exitNodeId, + exitNodeName: node.name, + endpoint: node.endpoint, + weight + }; + }) + ); return { message: {