diff --git a/server/db/schema.ts b/server/db/schema.ts index 002fc442..80baa52a 100644 --- a/server/db/schema.ts +++ b/server/db/schema.ts @@ -48,7 +48,8 @@ export const sites = sqliteTable("sites", { address: text("address"), // this is the address of the wireguard interface in gerbil endpoint: text("endpoint"), // this is how to reach gerbil externally - gets put into the wireguard config publicKey: text("pubicKey"), - lastHolePunch: integer("lastHolePunch"), + lastHolePunch: integer("lastHolePunch"), + listenPort: integer("listenPort") }); export const resources = sqliteTable("resources", { diff --git a/server/lib/ip.test.ts b/server/lib/ip.test.ts index 2c2dd057..67a2faaa 100644 --- a/server/lib/ip.test.ts +++ b/server/lib/ip.test.ts @@ -4,7 +4,14 @@ import { assertEquals } from "@test/assert"; // Test cases function testFindNextAvailableCidr() { console.log("Running findNextAvailableCidr tests..."); - + + // Test 0: Basic IPv4 allocation with a subnet in the wrong range + { + const existing = ["100.90.130.1/30", "100.90.128.4/30"]; + const result = findNextAvailableCidr(existing, 30, "100.90.130.1/24"); + assertEquals(result, "100.90.130.4/30", "Basic IPv4 allocation failed"); + } + // Test 1: Basic IPv4 allocation { const existing = ["10.0.0.0/16", "10.1.0.0/16"]; @@ -26,6 +33,12 @@ function testFindNextAvailableCidr() { assertEquals(result, null, "No available space test failed"); } + // Test 4: Empty existing + { + const existing: string[] = []; + const result = findNextAvailableCidr(existing, 30, "10.0.0.0/8"); + assertEquals(result, "10.0.0.0/30", "Empty existing test failed"); + } // // Test 4: IPv6 allocation // { // const existing = ["2001:db8::/32", "2001:db8:1::/32"]; diff --git a/server/lib/ip.ts b/server/lib/ip.ts index a3a78027..d06ce27f 100644 --- a/server/lib/ip.ts +++ b/server/lib/ip.ts @@ -137,7 +137,6 @@ export function findNextAvailableCidr( blockSize: number, startCidr?: string ): string | null { - if (!startCidr && existingCidrs.length === 0) { return null; } @@ -155,40 +154,47 @@ export function findNextAvailableCidr( existingCidrs.some(cidr => detectIpVersion(cidr.split('/')[0]) !== version)) { throw new Error('All CIDRs must be of the same IP version'); } - + + // Extract the network part from startCidr to ensure we stay in the right subnet + const startCidrRange = cidrToRange(startCidr); + // Convert existing CIDRs to ranges and sort them const existingRanges = existingCidrs .map(cidr => cidrToRange(cidr)) .sort((a, b) => (a.start < b.start ? -1 : 1)); - + // Calculate block size const maxPrefix = version === 4 ? 32 : 128; const blockSizeBigInt = BigInt(1) << BigInt(maxPrefix - blockSize); - + // Start from the beginning of the given CIDR - let current = cidrToRange(startCidr).start; - const maxIp = cidrToRange(startCidr).end; - + let current = startCidrRange.start; + const maxIp = startCidrRange.end; + // Iterate through existing ranges for (let i = 0; i <= existingRanges.length; i++) { const nextRange = existingRanges[i]; + // Align current to block size const alignedCurrent = current + ((blockSizeBigInt - (current % blockSizeBigInt)) % blockSizeBigInt); - + // Check if we've gone beyond the maximum allowed IP if (alignedCurrent + blockSizeBigInt - BigInt(1) > maxIp) { return null; } - + // If we're at the end of existing ranges or found a gap if (!nextRange || alignedCurrent + blockSizeBigInt - BigInt(1) < nextRange.start) { return `${bigIntToIp(alignedCurrent, version)}/${blockSize}`; } - - // Move current pointer to after the current range - current = nextRange.end + BigInt(1); + + // If next range overlaps with our search space, move past it + if (nextRange.end >= startCidrRange.start && nextRange.start <= maxIp) { + // Move current pointer to after the current range + current = nextRange.end + BigInt(1); + } } - + return null; } diff --git a/server/routers/gerbil/updateHolePunch.ts b/server/routers/gerbil/updateHolePunch.ts index 26608345..e2ba01c6 100644 --- a/server/routers/gerbil/updateHolePunch.ts +++ b/server/routers/gerbil/updateHolePunch.ts @@ -133,7 +133,7 @@ export async function updateHolePunch( return res.status(HttpCode.OK).send({ destinationIp: site.subnet.split("/")[0], - destinationPort: parseInt(site.endpoint.split(":")[1]) + destinationPort: site.listenPort }); } catch (error) { logger.error(error); diff --git a/server/routers/newt/handleGetConfigMessage.ts b/server/routers/newt/handleGetConfigMessage.ts index 00b5ee64..78ec32aa 100644 --- a/server/routers/newt/handleGetConfigMessage.ts +++ b/server/routers/newt/handleGetConfigMessage.ts @@ -8,7 +8,8 @@ import { eq } from "drizzle-orm"; import { getNextAvailableClientSubnet } from "@server/lib/ip"; const inputSchema = z.object({ - publicKey: z.string() + publicKey: z.string(), + port: z.number().int().positive(), }); type Input = z.infer; @@ -40,7 +41,7 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { return; } - const { publicKey } = message.data as Input; + const { publicKey, port } = message.data as Input; const siteId = newt.siteId; @@ -64,7 +65,8 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { .update(sites) .set({ publicKey, - address + address, + listenPort: port, }) .where(eq(sites.siteId, siteId)) .returning(); @@ -77,7 +79,8 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { const [siteRes] = await db .update(sites) .set({ - publicKey + publicKey, + listenPort: port, }) .where(eq(sites.siteId, siteId)) .returning(); diff --git a/server/routers/olm/handleOlmRelayMessage.ts b/server/routers/olm/handleOlmRelayMessage.ts new file mode 100644 index 00000000..ef42e05a --- /dev/null +++ b/server/routers/olm/handleOlmRelayMessage.ts @@ -0,0 +1,76 @@ +import db from "@server/db"; +import { MessageHandler } from "../ws"; +import { clients, Olm, olms, sites } from "@server/db/schema"; +import { eq } from "drizzle-orm"; +import { addPeer, deletePeer } from "../newt/peers"; +import logger from "@server/logger"; + +export const handleOlmRelayMessage: MessageHandler = async (context) => { + const { message, client: c, sendToClient } = context; + const olm = c as Olm; + + logger.info("Handling relay olm message!"); + + if (!olm) { + logger.warn("Olm not found"); + return; + } + + if (!olm.clientId) { + logger.warn("Olm has no site!"); // TODO: Maybe we create the site here? + return; + } + + const clientId = olm.clientId; + + const [client] = await db + .select() + .from(clients) + .where(eq(clients.clientId, clientId)) + .limit(1); + + if (!client || !client.siteId) { + logger.warn("Site not found or does not have exit node"); + return; + } + + const [site] = await db + .select() + .from(sites) + .where(eq(sites.siteId, client.siteId)) + .limit(1); + + if (!client) { + logger.warn("Site not found or does not have exit node"); + return; + } + + // make sure we hand endpoints for both the site and the client and the lastHolePunch is not too old + if (!client.pubKey) { + logger.warn("Site or client has no endpoint or listen port"); + return; + } + + if (!site.subnet) { + logger.warn("Site has no subnet"); + return; + } + + await deletePeer(site.siteId, client.pubKey); + + // add the peer to the exit node + await addPeer(site.siteId, { + publicKey: client.pubKey, + allowedIps: [client.subnet], + endpoint: "" + }); + + return { + message: { + type: "olm/wg/relay-success", + data: {} + }, + broadcast: false, // Send to all olms + excludeSender: false // Include sender in broadcast + }; +}; diff --git a/server/routers/site/createSite.ts b/server/routers/site/createSite.ts index a1d1876f..f79149cc 100644 --- a/server/routers/site/createSite.ts +++ b/server/routers/site/createSite.ts @@ -10,7 +10,6 @@ import { eq, and } from "drizzle-orm"; import { getUniqueSiteName } from "@server/db/names"; import { addPeer } from "../gerbil/peers"; import { fromError } from "zod-validation-error"; -import { hash } from "@node-rs/argon2"; import { newts } from "@server/db/schema"; import moment from "moment"; import { hashPassword } from "@server/auth/password"; diff --git a/src/app/[orgId]/settings/sites/create/page.tsx b/src/app/[orgId]/settings/sites/create/page.tsx index a8002705..acafd5bb 100644 --- a/src/app/[orgId]/settings/sites/create/page.tsx +++ b/src/app/[orgId]/settings/sites/create/page.tsx @@ -324,7 +324,7 @@ PersistentKeepalive = 5`; let payload: CreateSiteBody = { name: data.name, - type: data.method + type: data.method as any, }; if (data.method == "wireguard") {