diff --git a/server/routers/gerbil/peers.ts b/server/routers/gerbil/peers.ts index 525dced5..6aaeae01 100644 --- a/server/routers/gerbil/peers.ts +++ b/server/routers/gerbil/peers.ts @@ -43,9 +43,7 @@ export async function deletePeer(exitNodeId: number, publicKey: string) { throw new Error(`Exit node with ID ${exitNodeId} is not reachable`); } try { - const response = await axios.delete(`${exitNode.reachableAt}/peer`, { - data: { publicKey } // Send public key in request body - }); + const response = await axios.delete(`${exitNode.reachableAt}/peer?public_key=${encodeURIComponent(publicKey)}`); logger.info('Peer deleted successfully:', response.data.status); return response.data; } catch (error) { diff --git a/server/routers/newt/createNewt.ts b/server/routers/newt/createNewt.ts index 4a177e01..565d3a8b 100644 --- a/server/routers/newt/createNewt.ts +++ b/server/routers/newt/createNewt.ts @@ -13,6 +13,7 @@ import { generateSessionToken, } from "@server/auth"; import { createNewtSession } from "@server/auth/newt"; +import { fromError } from "zod-validation-error"; export const createNewtBodySchema = z.object({}); @@ -24,6 +25,13 @@ export type CreateNewtResponse = { secret: string; }; +const createNewtSchema = z + .object({ + newtId: z.string(), + secret: z.string() + }) + .strict(); + export async function createNewt( req: Request, res: Response, @@ -31,8 +39,25 @@ export async function createNewt( ): Promise { try { + const parsedBody = createNewtSchema.safeParse(req.body); + if (!parsedBody.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedBody.error).toString() + ) + ); + } + + const { newtId, secret } = parsedBody.data; + + if (!req.userOrgRoleId) { + return next( + createHttpError(HttpCode.FORBIDDEN, "User does not have a role") + ); + } + // generate a newtId and secret - const secret = generateId(48); const secretHash = await hash(secret, { memoryCost: 19456, timeCost: 2, @@ -40,8 +65,6 @@ export async function createNewt( parallelism: 1, }); - const newtId = generateId(15); - await db.insert(newts).values({ newtId: newtId, secretHash, diff --git a/server/routers/newt/handleRegisterMessage.ts b/server/routers/newt/handleRegisterMessage.ts index cf38b2af..22b0f487 100644 --- a/server/routers/newt/handleRegisterMessage.ts +++ b/server/routers/newt/handleRegisterMessage.ts @@ -4,8 +4,6 @@ import { exitNodes, resources, sites, targets } from "@server/db/schema"; import { eq, inArray } from "drizzle-orm"; import { addPeer, deletePeer } from "../gerbil/peers"; import logger from "@server/logger"; -import { findNextAvailableCidr } from "@server/utils/ip"; -import { exit } from "process"; export const handleRegisterMessage: MessageHandler = async (context) => { const { message, newt, sendToClient } = context; @@ -28,13 +26,18 @@ export const handleRegisterMessage: MessageHandler = async (context) => { return; } - // const [site] = await db - // .select() - // .from(sites) - // .where(eq(sites.siteId, siteId)) - // .limit(1); - const [site] = await db + .select() + .from(sites) + .where(eq(sites.siteId, siteId)) + .limit(1); + + if (!site || !site.exitNodeId) { + logger.warn("Site not found or does not have exit node"); + return; + } + + const [updatedSite] = await db .update(sites) .set({ pubKey: publicKey @@ -43,11 +46,6 @@ export const handleRegisterMessage: MessageHandler = async (context) => { .returning(); - if (!site || !site.exitNodeId) { - logger.warn("Site not found or does not have exit node"); - return; - } - const [exitNode] = await db .select() .from(exitNodes) @@ -100,10 +98,10 @@ export const handleRegisterMessage: MessageHandler = async (context) => { message: { type: "newt/wg/connect", data: { - endpoint: exitNode.endpoint, + endpoint: `${exitNode.endpoint}:${exitNode.listenPort}`, publicKey: exitNode.publicKey, - serverIP: exitNode.address, - tunnelIP: site.subnet, + serverIP: exitNode.address.split("/")[0], + tunnelIP: site.subnet.split("/")[0], targets: { udp: udpTargets, tcp: tcpTargets, diff --git a/server/routers/site/createSite.ts b/server/routers/site/createSite.ts index b624d6ea..45be6b20 100644 --- a/server/routers/site/createSite.ts +++ b/server/routers/site/createSite.ts @@ -10,6 +10,9 @@ import { eq, and } from "drizzle-orm"; import { getUniqueSiteName } from "@server/db/names"; import { addPeer } from "../gerbil/peers"; import { fromError } from "zod-validation-error"; +import { hash } from "@node-rs/argon2"; +import { newts } from "@server/db/schema"; +import moment from "moment"; const createSiteParamsSchema = z.object({ orgId: z.string(), @@ -22,10 +25,14 @@ const createSiteSchema = z subdomain: z.string().min(1).max(255).optional(), pubKey: z.string().optional(), subnet: z.string(), + newtId: z.string().optional(), + secret: z.string().optional(), type: z.string(), }) .strict(); +export type CreateSiteBody = z.infer; + export type CreateSiteResponse = { name: string; siteId: number; @@ -49,7 +56,8 @@ export async function createSite( ); } - const { name, type, exitNodeId, pubKey, subnet } = parsedBody.data; + const { name, type, exitNodeId, pubKey, subnet, newtId, secret } = + parsedBody.data; const parsedParams = createSiteParamsSchema.safeParse(req.params); if (!parsedParams.success) { @@ -80,7 +88,8 @@ export async function createSite( type, }; - if (pubKey) { + if (pubKey && type == "wireguard") { + // we dont add the pubKey for newts because the newt will generate it payload = { ...payload, pubKey, @@ -114,19 +123,34 @@ export async function createSite( }); } - if (pubKey) { - // add the peer to the exit node - if (type == "newt") { - await addPeer(exitNodeId, { - publicKey: pubKey, - allowedIps: [subnet], - }); - } else if (type == "wireguard") { - await addPeer(exitNodeId, { - publicKey: pubKey, - allowedIps: [], - }); + // add the peer to the exit node + if (type == "newt") { + const secretHash = await hash(secret!, { + memoryCost: 19456, + timeCost: 2, + outputLen: 32, + parallelism: 1, + }); + + await db.insert(newts).values({ + newtId: newtId!, + secretHash, + siteId: newSite.siteId, + dateCreated: moment().toISOString(), + }); + } else if (type == "wireguard") { + if (!pubKey) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "Public key is required for wireguard sites" + ) + ); } + await addPeer(exitNodeId, { + publicKey: pubKey, + allowedIps: [], + }); } return response(res, { @@ -142,7 +166,7 @@ export async function createSite( status: HttpCode.CREATED, }); } catch (error) { - logger.error(error); + throw error; return next( createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred") ); diff --git a/server/routers/site/pickSiteDefaults.ts b/server/routers/site/pickSiteDefaults.ts index cddddcea..63642ba5 100644 --- a/server/routers/site/pickSiteDefaults.ts +++ b/server/routers/site/pickSiteDefaults.ts @@ -7,6 +7,7 @@ import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; import logger from "@server/logger"; import { findNextAvailableCidr } from "@server/utils/ip"; +import { generateId } from "@server/auth"; export type PickSiteDefaultsResponse = { exitNodeId: number; @@ -16,6 +17,8 @@ export type PickSiteDefaultsResponse = { listenPort: number; endpoint: string; subnet: string; + newtId: string; + newtSecret: string; }; export async function pickSiteDefaults( @@ -60,6 +63,9 @@ export async function pickSiteDefaults( ); } + const newtId = generateId(15); + const secret = generateId(48); + return response(res, { data: { exitNodeId: exitNode.exitNodeId, @@ -69,6 +75,8 @@ export async function pickSiteDefaults( listenPort: exitNode.listenPort, endpoint: exitNode.endpoint, subnet: newSubnet, + newtId, + newtSecret: secret, }, success: true, error: false, diff --git a/server/routers/target/createTarget.ts b/server/routers/target/createTarget.ts index 41df7fd6..11b73dca 100644 --- a/server/routers/target/createTarget.ts +++ b/server/routers/target/createTarget.ts @@ -91,7 +91,7 @@ export async function createTarget( } // make sure the target is within the site subnet - if (!isIpInCidr(targetData.ip, site.subnet!)) { + if (site.type == "wireguard" && !isIpInCidr(targetData.ip, site.subnet!)) { return next( createHttpError( HttpCode.BAD_REQUEST, diff --git a/src/app/[orgId]/settings/sites/components/CreateSiteForm.tsx b/src/app/[orgId]/settings/sites/components/CreateSiteForm.tsx index 08abcf5b..85e81b43 100644 --- a/src/app/[orgId]/settings/sites/components/CreateSiteForm.tsx +++ b/src/app/[orgId]/settings/sites/components/CreateSiteForm.tsx @@ -29,7 +29,7 @@ import { } from "@app/components/Credenza"; import { useOrgContext } from "@app/hooks/useOrgContext"; import { useParams, useRouter } from "next/navigation"; -import { PickSiteDefaultsResponse } from "@server/routers/site"; +import { CreateSiteBody, PickSiteDefaultsResponse } from "@server/routers/site"; import { generateKeypair } from "../[niceId]/components/wireguardConfig"; import CopyTextBox from "@app/components/CopyTextBox"; import { Checkbox } from "@app/components/ui/checkbox"; @@ -43,8 +43,8 @@ import { import { formatAxiosError } from "@app/lib/utils"; const method = [ - { label: "Wireguard", value: "wg" }, { label: "Newt", value: "newt" }, + { label: "Wireguard", value: "wireguard" }, ] as const; const accountFormSchema = z.object({ @@ -56,14 +56,14 @@ const accountFormSchema = z.object({ .max(30, { message: "Name must not be longer than 30 characters.", }), - method: z.enum(["wg", "newt"]), + method: z.enum(["wireguard", "newt"]), }); type AccountFormValues = z.infer; const defaultValues: Partial = { name: "", - method: "wg", + method: "newt", }; type CreateSiteFormProps = { @@ -124,13 +124,22 @@ export default function CreateSiteForm({ open, setOpen }: CreateSiteFormProps) { async function onSubmit(data: AccountFormValues) { setLoading(true); + if (!siteDefaults || !keypair) { + return; + } + let payload: CreateSiteBody = { + name: data.name, + subnet: siteDefaults.subnet, + exitNodeId: siteDefaults.exitNodeId, + pubKey: keypair.publicKey, + type: data.method, + }; + if (data.method === "newt") { + payload.secret = siteDefaults.newtSecret; + payload.newtId = siteDefaults.newtId; + } const res = await api - .put(`/org/${orgId}/site/`, { - name: data.name, - subnet: siteDefaults?.subnet, - exitNodeId: siteDefaults?.exitNodeId, - pubKey: keypair?.publicKey, - }) + .put(`/org/${orgId}/site/`, payload) .catch((e) => { toast({ variant: "destructive", @@ -165,8 +174,7 @@ Endpoint = ${siteDefaults.endpoint}:${siteDefaults.listenPort} PersistentKeepalive = 5` : ""; - const newtConfig = `curl -fsSL https://get.docker.com -o get-docker.sh -sh get-docker.sh`; + const newtConfig = `newt --id ${siteDefaults?.newtId} --secret ${siteDefaults?.newtSecret}`; return ( <> @@ -236,7 +244,7 @@ sh get-docker.sh`; - + WireGuard @@ -255,10 +263,10 @@ sh get-docker.sh`; />
- {form.watch("method") === "wg" && + {form.watch("method") === "wireguard" && !isLoading ? ( - ) : form.watch("method") === "wg" && + ) : form.watch("method") === "wireguard" && isLoading ? (

Loading WireGuard