diff --git a/server/auth/actions.ts b/server/auth/actions.ts index 001b9a6c..972ff1a7 100644 --- a/server/auth/actions.ts +++ b/server/auth/actions.ts @@ -62,6 +62,7 @@ export enum ActionsEnum { deleteResourceRule = "deleteResourceRule", listResourceRules = "listResourceRules", updateResourceRule = "updateResourceRule", + createClient = "createClient" } export async function checkUserActionPermission( diff --git a/server/db/schema.ts b/server/db/schema.ts index 6671152a..70817573 100644 --- a/server/db/schema.ts +++ b/server/db/schema.ts @@ -31,8 +31,7 @@ export const sites = sqliteTable("sites", { address: text("address"), // this is the address of the wireguard interface in gerbil endpoint: text("endpoint"), // this is how to reach gerbil externally - gets put into the wireguard config publicKey: text("pubicKey"), - listenPort: integer("listenPort"), - reachableAt: text("reachableAt") // this is the internal address of the gerbil http server for command control + listenPort: integer("listenPort") }); export const resources = sqliteTable("resources", { @@ -121,7 +120,16 @@ export const clients = sqliteTable("clients", { dateCreated: text("dateCreated").notNull(), siteId: integer("siteId").references(() => sites.siteId, { onDelete: "cascade" - }) + }), + + // wgstuff + pubKey: text("pubKey"), + subnet: text("subnet").notNull(), + megabytesIn: integer("bytesIn"), + megabytesOut: integer("bytesOut"), + lastBandwidthUpdate: text("lastBandwidthUpdate"), + type: text("type").notNull(), // "newt" or "wireguard" + online: integer("online", { mode: "boolean" }).notNull().default(false), }); export const twoFactorBackupCodes = sqliteTable("twoFactorBackupCodes", { diff --git a/server/lib/config.ts b/server/lib/config.ts index 7c5ad227..fc1c0531 100644 --- a/server/lib/config.ts +++ b/server/lib/config.ts @@ -109,6 +109,10 @@ const configSchema = z.object({ block_size: z.number().positive().gt(0), site_block_size: z.number().positive().gt(0) }), + wg_site: z.object({ + block_size: z.number().positive().gt(0), + subnet_group: z.string(), + }), rate_limits: z.object({ global: z.object({ window_minutes: z.number().positive().gt(0), diff --git a/server/routers/client/index.ts b/server/routers/client/index.ts new file mode 100644 index 00000000..5b493724 --- /dev/null +++ b/server/routers/client/index.ts @@ -0,0 +1 @@ +export * from "./pickClientDefaults"; diff --git a/server/routers/client/pickClientDefaults.ts b/server/routers/client/pickClientDefaults.ts new file mode 100644 index 00000000..eb765fc2 --- /dev/null +++ b/server/routers/client/pickClientDefaults.ts @@ -0,0 +1,128 @@ +import { Request, Response, NextFunction } from "express"; +import { db } from "@server/db"; +import { clients, sites } from "@server/db/schema"; +import { eq } from "drizzle-orm"; +import response from "@server/lib/response"; +import HttpCode from "@server/types/HttpCode"; +import createHttpError from "http-errors"; +import logger from "@server/logger"; +import { findNextAvailableCidr } from "@server/lib/ip"; +import { generateId } from "@server/auth/sessions/app"; +import config from "@server/lib/config"; +import { z } from "zod"; +import { fromError } from "zod-validation-error"; + +const getSiteSchema = z + .object({ + siteId: z.number().int().positive() + }) + .strict(); + +export type PickClientDefaultsResponse = { + siteId: number; + address: string; + publicKey: string; + name: string; + listenPort: number; + endpoint: string; + subnet: string; + clientId: string; + clientSecret: string; +}; + +export async function pickClientDefaults( + req: Request, + res: Response, + next: NextFunction +): Promise { + try { + const parsedParams = getSiteSchema.safeParse(req.params); + if (!parsedParams.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedParams.error).toString() + ) + ); + } + + const { siteId } = parsedParams.data; + + const [site] = await db + .select() + .from(sites) + .where(eq(sites.siteId, siteId)); + + if (!site) { + return next(createHttpError(HttpCode.NOT_FOUND, "Site not found")); + } + + // make sure all the required fields are present + if ( + !site.address || + !site.publicKey || + !site.listenPort || + !site.endpoint + ) { + return next( + createHttpError(HttpCode.BAD_REQUEST, "Site has no address") + ); + } + + const clientsQuery = await db + .select({ + subnet: clients.subnet + }) + .from(clients) + .where(eq(clients.siteId, site.siteId)); + + let subnets = clientsQuery.map((client) => client.subnet); + + // exclude the exit node address by replacing after the / with a site block size + subnets.push( + site.address.replace( + /\/\d+$/, + `/${config.getRawConfig().wg_site.block_size}` + ) + ); + const newSubnet = findNextAvailableCidr( + subnets, + config.getRawConfig().wg_site.block_size, + site.address + ); + if (!newSubnet) { + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "No available subnets" + ) + ); + } + + const clientId = generateId(15); + const secret = generateId(48); + + return response(res, { + data: { + siteId: site.siteId, + address: site.address, + publicKey: site.publicKey, + name: site.name, + listenPort: site.listenPort, + endpoint: site.endpoint, + subnet: newSubnet, + clientId, + clientSecret: secret + }, + success: true, + error: false, + message: "Organization retrieved successfully", + status: HttpCode.OK + }); + } catch (error) { + logger.error(error); + return next( + createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred") + ); + } +} diff --git a/server/routers/external.ts b/server/routers/external.ts index 19c57008..778bf288 100644 --- a/server/routers/external.ts +++ b/server/routers/external.ts @@ -7,6 +7,7 @@ import * as target from "./target"; import * as user from "./user"; import * as auth from "./auth"; import * as role from "./role"; +import * as client from "./client"; import * as accessToken from "./accessToken"; import HttpCode from "@server/types/HttpCode"; import { @@ -94,6 +95,14 @@ authenticated.get( verifyUserHasAction(ActionsEnum.getSite), site.getSite ); + +authenticated.get( + "/site/:siteId/pick-client-defaults", + verifyOrgAccess, + verifyUserHasAction(ActionsEnum.createClient), + client.pickClientDefaults +); + // authenticated.get( // "/site/:siteId/roles", // verifySiteAccess, diff --git a/server/routers/gerbil/getConfig.ts b/server/routers/gerbil/getConfig.ts index 28b576d8..95e0df6b 100644 --- a/server/routers/gerbil/getConfig.ts +++ b/server/routers/gerbil/getConfig.ts @@ -86,7 +86,7 @@ export async function getConfig(req: Request, res: Response, next: NextFunction) const peers = await Promise.all(sitesRes.map(async (site) => { return { publicKey: site.pubKey, - allowedIps: await getAllowedIps(site.siteId) + allowedIps: await getAllowedIps(site.siteId) // put 0.0.0.0/0 for now }; })); diff --git a/server/routers/messageHandlers.ts b/server/routers/messageHandlers.ts index 9dd7756f..262f9869 100644 --- a/server/routers/messageHandlers.ts +++ b/server/routers/messageHandlers.ts @@ -1,6 +1,8 @@ import { handleRegisterMessage } from "./newt"; +import { handleGetConfigMessage } from "./newt/handleGetConfigMessage"; import { MessageHandler } from "./ws"; export const messageHandlers: Record = { "newt/wg/register": handleRegisterMessage, -}; \ No newline at end of file + "newt/wg/get-config": handleGetConfigMessage, +}; diff --git a/server/routers/newt/handleGetConfigMessage.ts b/server/routers/newt/handleGetConfigMessage.ts new file mode 100644 index 00000000..17ac63dd --- /dev/null +++ b/server/routers/newt/handleGetConfigMessage.ts @@ -0,0 +1,147 @@ +import { z } from "zod"; +import { MessageHandler } from "../ws"; +import logger from "@server/logger"; +import { fromError } from "zod-validation-error"; +import db from "@server/db"; +import { clients, Site, sites } from "@server/db/schema"; +import { eq, isNotNull } from "drizzle-orm"; +import { findNextAvailableCidr } from "@server/lib/ip"; +import config from "@server/lib/config"; + +const inputSchema = z.object({ + publicKey: z.string(), + endpoint: z.string(), + listenPort: z.number() +}); + +type Input = z.infer; + +export const handleGetConfigMessage: MessageHandler = async (context) => { + const { message, newt, sendToClient } = context; + + logger.debug("Handling Newt get config message!"); + + if (!newt) { + logger.warn("Newt not found"); + return; + } + + if (!newt.siteId) { + logger.warn("Newt has no site!"); // TODO: Maybe we create the site here? + return; + } + + const parsed = inputSchema.safeParse(message.data); + if (!parsed.success) { + logger.error( + "handleGetConfigMessage: Invalid input: " + + fromError(parsed.error).toString() + ); + return; + } + + const { publicKey, endpoint, listenPort } = message.data as Input; + + const siteId = newt.siteId; + + const [siteRes] = await db + .select() + .from(sites) + .where(eq(sites.siteId, siteId)); + + if (!siteRes) { + logger.warn("handleGetConfigMessage: Site not found"); + return; + } + + let site: Site | undefined; + if (!site) { + const address = await getNextAvailableSubnet(); + + // create a new exit node + const [updateRes] = await db + .update(sites) + .set({ + publicKey, + endpoint, + address, + listenPort + }) + .where(eq(sites.siteId, siteId)) + .returning(); + + site = updateRes; + + logger.info(`Updated site ${siteId} with new WG Newt info`); + } else { + site = siteRes; + } + + if (!site) { + logger.error("handleGetConfigMessage: Failed to update site"); + return; + } + + const clientsRes = await db + .select() + .from(clients) + .where(eq(clients.siteId, siteId)); + + const peers = await Promise.all( + clientsRes.map(async (client) => { + return { + publicKey: client.pubKey, + allowedIps: "0.0.0.0/0" + }; + }) + ); + + const configResponse = { + listenPort: site.listenPort, // ????? + // ipAddress: exitNode[0].address, + peers + }; + + logger.debug("Sending config: ", configResponse); + + return { + message: { + type: "newt/wg/connect", // what to make the response type? + data: { + config: configResponse + } + }, + broadcast: false, // Send to all clients + excludeSender: false // Include sender in broadcast + }; +}; + +async function getNextAvailableSubnet(): Promise { + const existingAddresses = await db + .select({ + address: sites.address + }) + .from(sites) + .where(isNotNull(sites.address)); + + const addresses = existingAddresses + .map((a) => a.address) + .filter((a) => a) as string[]; + + let subnet = findNextAvailableCidr( + addresses, + config.getRawConfig().wg_site.block_size, + config.getRawConfig().wg_site.subnet_group + ); + if (!subnet) { + throw new Error("No available subnets remaining in space"); + } + + // replace the last octet with 1 + subnet = + subnet.split(".").slice(0, 3).join(".") + + ".1" + + "/" + + subnet.split("/")[1]; + return subnet; +}