add wg site get config and pick client defaults

This commit is contained in:
miloschwartz 2025-02-20 22:34:51 -05:00
parent fb49fb8ddd
commit 41983ce356
No known key found for this signature in database
9 changed files with 305 additions and 5 deletions

View file

@ -62,6 +62,7 @@ export enum ActionsEnum {
deleteResourceRule = "deleteResourceRule", deleteResourceRule = "deleteResourceRule",
listResourceRules = "listResourceRules", listResourceRules = "listResourceRules",
updateResourceRule = "updateResourceRule", updateResourceRule = "updateResourceRule",
createClient = "createClient"
} }
export async function checkUserActionPermission( export async function checkUserActionPermission(

View file

@ -31,8 +31,7 @@ export const sites = sqliteTable("sites", {
address: text("address"), // this is the address of the wireguard interface in gerbil address: text("address"), // this is the address of the wireguard interface in gerbil
endpoint: text("endpoint"), // this is how to reach gerbil externally - gets put into the wireguard config endpoint: text("endpoint"), // this is how to reach gerbil externally - gets put into the wireguard config
publicKey: text("pubicKey"), publicKey: text("pubicKey"),
listenPort: integer("listenPort"), listenPort: integer("listenPort")
reachableAt: text("reachableAt") // this is the internal address of the gerbil http server for command control
}); });
export const resources = sqliteTable("resources", { export const resources = sqliteTable("resources", {
@ -121,7 +120,16 @@ export const clients = sqliteTable("clients", {
dateCreated: text("dateCreated").notNull(), dateCreated: text("dateCreated").notNull(),
siteId: integer("siteId").references(() => sites.siteId, { siteId: integer("siteId").references(() => sites.siteId, {
onDelete: "cascade" onDelete: "cascade"
}) }),
// wgstuff
pubKey: text("pubKey"),
subnet: text("subnet").notNull(),
megabytesIn: integer("bytesIn"),
megabytesOut: integer("bytesOut"),
lastBandwidthUpdate: text("lastBandwidthUpdate"),
type: text("type").notNull(), // "newt" or "wireguard"
online: integer("online", { mode: "boolean" }).notNull().default(false),
}); });
export const twoFactorBackupCodes = sqliteTable("twoFactorBackupCodes", { export const twoFactorBackupCodes = sqliteTable("twoFactorBackupCodes", {

View file

@ -109,6 +109,10 @@ const configSchema = z.object({
block_size: z.number().positive().gt(0), block_size: z.number().positive().gt(0),
site_block_size: z.number().positive().gt(0) site_block_size: z.number().positive().gt(0)
}), }),
wg_site: z.object({
block_size: z.number().positive().gt(0),
subnet_group: z.string(),
}),
rate_limits: z.object({ rate_limits: z.object({
global: z.object({ global: z.object({
window_minutes: z.number().positive().gt(0), window_minutes: z.number().positive().gt(0),

View file

@ -0,0 +1 @@
export * from "./pickClientDefaults";

View file

@ -0,0 +1,128 @@
import { Request, Response, NextFunction } from "express";
import { db } from "@server/db";
import { clients, sites } from "@server/db/schema";
import { eq } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { findNextAvailableCidr } from "@server/lib/ip";
import { generateId } from "@server/auth/sessions/app";
import config from "@server/lib/config";
import { z } from "zod";
import { fromError } from "zod-validation-error";
const getSiteSchema = z
.object({
siteId: z.number().int().positive()
})
.strict();
export type PickClientDefaultsResponse = {
siteId: number;
address: string;
publicKey: string;
name: string;
listenPort: number;
endpoint: string;
subnet: string;
clientId: string;
clientSecret: string;
};
export async function pickClientDefaults(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = getSiteSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { siteId } = parsedParams.data;
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId));
if (!site) {
return next(createHttpError(HttpCode.NOT_FOUND, "Site not found"));
}
// make sure all the required fields are present
if (
!site.address ||
!site.publicKey ||
!site.listenPort ||
!site.endpoint
) {
return next(
createHttpError(HttpCode.BAD_REQUEST, "Site has no address")
);
}
const clientsQuery = await db
.select({
subnet: clients.subnet
})
.from(clients)
.where(eq(clients.siteId, site.siteId));
let subnets = clientsQuery.map((client) => client.subnet);
// exclude the exit node address by replacing after the / with a site block size
subnets.push(
site.address.replace(
/\/\d+$/,
`/${config.getRawConfig().wg_site.block_size}`
)
);
const newSubnet = findNextAvailableCidr(
subnets,
config.getRawConfig().wg_site.block_size,
site.address
);
if (!newSubnet) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"No available subnets"
)
);
}
const clientId = generateId(15);
const secret = generateId(48);
return response<PickClientDefaultsResponse>(res, {
data: {
siteId: site.siteId,
address: site.address,
publicKey: site.publicKey,
name: site.name,
listenPort: site.listenPort,
endpoint: site.endpoint,
subnet: newSubnet,
clientId,
clientSecret: secret
},
success: true,
error: false,
message: "Organization retrieved successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View file

@ -7,6 +7,7 @@ import * as target from "./target";
import * as user from "./user"; import * as user from "./user";
import * as auth from "./auth"; import * as auth from "./auth";
import * as role from "./role"; import * as role from "./role";
import * as client from "./client";
import * as accessToken from "./accessToken"; import * as accessToken from "./accessToken";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import { import {
@ -94,6 +95,14 @@ authenticated.get(
verifyUserHasAction(ActionsEnum.getSite), verifyUserHasAction(ActionsEnum.getSite),
site.getSite site.getSite
); );
authenticated.get(
"/site/:siteId/pick-client-defaults",
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.createClient),
client.pickClientDefaults
);
// authenticated.get( // authenticated.get(
// "/site/:siteId/roles", // "/site/:siteId/roles",
// verifySiteAccess, // verifySiteAccess,

View file

@ -86,7 +86,7 @@ export async function getConfig(req: Request, res: Response, next: NextFunction)
const peers = await Promise.all(sitesRes.map(async (site) => { const peers = await Promise.all(sitesRes.map(async (site) => {
return { return {
publicKey: site.pubKey, publicKey: site.pubKey,
allowedIps: await getAllowedIps(site.siteId) allowedIps: await getAllowedIps(site.siteId) // put 0.0.0.0/0 for now
}; };
})); }));

View file

@ -1,6 +1,8 @@
import { handleRegisterMessage } from "./newt"; import { handleRegisterMessage } from "./newt";
import { handleGetConfigMessage } from "./newt/handleGetConfigMessage";
import { MessageHandler } from "./ws"; import { MessageHandler } from "./ws";
export const messageHandlers: Record<string, MessageHandler> = { export const messageHandlers: Record<string, MessageHandler> = {
"newt/wg/register": handleRegisterMessage, "newt/wg/register": handleRegisterMessage,
"newt/wg/get-config": handleGetConfigMessage,
}; };

View file

@ -0,0 +1,147 @@
import { z } from "zod";
import { MessageHandler } from "../ws";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import db from "@server/db";
import { clients, Site, sites } from "@server/db/schema";
import { eq, isNotNull } from "drizzle-orm";
import { findNextAvailableCidr } from "@server/lib/ip";
import config from "@server/lib/config";
const inputSchema = z.object({
publicKey: z.string(),
endpoint: z.string(),
listenPort: z.number()
});
type Input = z.infer<typeof inputSchema>;
export const handleGetConfigMessage: MessageHandler = async (context) => {
const { message, newt, sendToClient } = context;
logger.debug("Handling Newt get config message!");
if (!newt) {
logger.warn("Newt not found");
return;
}
if (!newt.siteId) {
logger.warn("Newt has no site!"); // TODO: Maybe we create the site here?
return;
}
const parsed = inputSchema.safeParse(message.data);
if (!parsed.success) {
logger.error(
"handleGetConfigMessage: Invalid input: " +
fromError(parsed.error).toString()
);
return;
}
const { publicKey, endpoint, listenPort } = message.data as Input;
const siteId = newt.siteId;
const [siteRes] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId));
if (!siteRes) {
logger.warn("handleGetConfigMessage: Site not found");
return;
}
let site: Site | undefined;
if (!site) {
const address = await getNextAvailableSubnet();
// create a new exit node
const [updateRes] = await db
.update(sites)
.set({
publicKey,
endpoint,
address,
listenPort
})
.where(eq(sites.siteId, siteId))
.returning();
site = updateRes;
logger.info(`Updated site ${siteId} with new WG Newt info`);
} else {
site = siteRes;
}
if (!site) {
logger.error("handleGetConfigMessage: Failed to update site");
return;
}
const clientsRes = await db
.select()
.from(clients)
.where(eq(clients.siteId, siteId));
const peers = await Promise.all(
clientsRes.map(async (client) => {
return {
publicKey: client.pubKey,
allowedIps: "0.0.0.0/0"
};
})
);
const configResponse = {
listenPort: site.listenPort, // ?????
// ipAddress: exitNode[0].address,
peers
};
logger.debug("Sending config: ", configResponse);
return {
message: {
type: "newt/wg/connect", // what to make the response type?
data: {
config: configResponse
}
},
broadcast: false, // Send to all clients
excludeSender: false // Include sender in broadcast
};
};
async function getNextAvailableSubnet(): Promise<string> {
const existingAddresses = await db
.select({
address: sites.address
})
.from(sites)
.where(isNotNull(sites.address));
const addresses = existingAddresses
.map((a) => a.address)
.filter((a) => a) as string[];
let subnet = findNextAvailableCidr(
addresses,
config.getRawConfig().wg_site.block_size,
config.getRawConfig().wg_site.subnet_group
);
if (!subnet) {
throw new Error("No available subnets remaining in space");
}
// replace the last octet with 1
subnet =
subnet.split(".").slice(0, 3).join(".") +
".1" +
"/" +
subnet.split("/")[1];
return subnet;
}