This commit is contained in:
Milo Schwartz 2024-11-19 00:05:26 -05:00
commit 96888876e5
No known key found for this signature in database
8 changed files with 300 additions and 72 deletions

View file

@ -23,6 +23,7 @@ export const sites = sqliteTable("sites", {
subnet: text("subnet").notNull(), subnet: text("subnet").notNull(),
megabytesIn: integer("bytesIn"), megabytesIn: integer("bytesIn"),
megabytesOut: integer("bytesOut"), megabytesOut: integer("bytesOut"),
type: text("type").notNull(), // "newt" or "wireguard"
}); });
export const resources = sqliteTable("resources", { export const resources = sqliteTable("resources", {
@ -60,6 +61,7 @@ export const targets = sqliteTable("targets", {
ip: text("ip").notNull(), ip: text("ip").notNull(),
method: text("method").notNull(), method: text("method").notNull(),
port: integer("port").notNull(), port: integer("port").notNull(),
internalPort: integer("internalPort"),
protocol: text("protocol"), protocol: text("protocol"),
enabled: integer("enabled", { mode: "boolean" }).notNull().default(true), enabled: integer("enabled", { mode: "boolean" }).notNull().default(true),
}); });
@ -92,9 +94,7 @@ export const newts = sqliteTable("newt", {
newtId: text("id").primaryKey(), newtId: text("id").primaryKey(),
secretHash: text("secretHash").notNull(), secretHash: text("secretHash").notNull(),
dateCreated: text("dateCreated").notNull(), dateCreated: text("dateCreated").notNull(),
siteId: integer("siteId").references(() => sites.siteId, { siteId: integer("siteId").references(() => sites.siteId),
onDelete: "cascade",
}),
}); });
export const twoFactorBackupCodes = sqliteTable("twoFactorBackupCodes", { export const twoFactorBackupCodes = sqliteTable("twoFactorBackupCodes", {

View file

@ -1,42 +1,116 @@
import db from "@server/db"; import db from "@server/db";
import { MessageHandler } from "../ws"; import { MessageHandler } from "../ws";
import { sites } from "@server/db/schema"; import { exitNodes, resources, sites, targets } from "@server/db/schema";
import { eq } from "drizzle-orm"; import { eq, inArray } from "drizzle-orm";
import { addPeer, deletePeer } from "../gerbil/peers";
import logger from "@server/logger";
import { findNextAvailableCidr } from "@server/utils/ip";
import { exit } from "process";
export const handleRegisterMessage: MessageHandler = async (context) => { export const handleRegisterMessage: MessageHandler = async (context) => {
const { message, newt, sendToClient } = context; const { message, newt, sendToClient } = context;
if (!newt) { if (!newt) {
console.log("Newt not found"); logger.warn("Newt not found");
return; return;
} }
if (!newt.siteId) { if (!newt.siteId) {
console.log("Newt has no site!"); // TODO: Maybe we create the site here? logger.warn("Newt has no site!"); // TODO: Maybe we create the site here?
return; return;
} }
const siteId = newt.siteId; const siteId = newt.siteId;
// get the site const { publicKey } = message.data;
const site = await db if (!publicKey) {
.select() logger.warn("Public key not provided");
.from(sites) return;
}
// const [site] = await db
// .select()
// .from(sites)
// .where(eq(sites.siteId, siteId))
// .limit(1);
const [site] = await db
.update(sites)
.set({
pubKey: publicKey
})
.where(eq(sites.siteId, siteId)) .where(eq(sites.siteId, siteId))
.returning();
if (!site || !site.exitNodeId) {
logger.warn("Site not found or does not have exit node");
return;
}
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId))
.limit(1); .limit(1);
if (site.pubKey && site.pubKey !== publicKey) {
logger.info("Public key mismatch. Deleting old peer...");
await deletePeer(site.exitNodeId, site.pubKey);
}
if (!site.subnet) {
logger.warn("Site has no subnet");
return;
}
// add the peer to the exit node
await addPeer(site.exitNodeId, {
publicKey: publicKey,
allowedIps: [site.subnet],
});
const siteResources = await db.select().from(resources).where(eq(resources.siteId, siteId));
// get the targets from the resourceIds
const siteTargets = await db
.select()
.from(targets)
.where(
inArray(
targets.resourceId,
siteResources.map(resource => resource.resourceId)
)
);
const udpTargets = siteTargets
.filter((target) => target.protocol === "udp")
.map((target) => {
return `${target.internalPort ? target.internalPort + ":" : ""}${target.ip}:${target.port}`;
});
const tcpTargets = siteTargets
.filter((target) => target.protocol === "tcp")
.map((target) => {
return `${target.internalPort ? target.internalPort + ":" : ""}${target.ip}:${target.port}`;
});
const { publicKey } = message.data;
return { return {
message: { message: {
type: 'newt/wg/connect', type: "newt/wg/connect",
data: { data: {
publicKey: 'publicKey', endpoint: exitNode.endpoint,
publicKey: exitNode.publicKey,
serverIP: exitNode.address,
tunnelIP: site.subnet,
targets: {
udp: udpTargets,
tcp: tcpTargets,
} }
}, },
},
broadcast: false, // Send to all clients broadcast: false, // Send to all clients
excludeSender: false // Include sender in broadcast excludeSender: false, // Include sender in broadcast
}; };
}; };

View file

@ -0,0 +1,73 @@
import { Target } from "@server/db/schema";
import { sendToClient } from "../ws";
export async function addTargets(newtId: string, targets: Target[]): Promise<void> {
//create a list of udp and tcp targets
const udpTargets = targets
.filter((target) => target.protocol === "udp")
.map((target) => {
return `${target.internalPort ? target.internalPort + ":" : ""}${target.ip}:${target.port}`;
});
const tcpTargets = targets
.filter((target) => target.protocol === "tcp")
.map((target) => {
return `${target.internalPort ? target.internalPort + ":" : ""}${target.ip}:${target.port}`;
});
if (udpTargets.length > 0) {
const payload = {
type: `newt/udp/add`,
data: {
targets: udpTargets,
},
};
sendToClient(newtId, payload);
}
if (tcpTargets.length > 0) {
const payload = {
type: `newt/tcp/add`,
data: {
targets: tcpTargets,
},
};
sendToClient(newtId, payload);
}
}
export async function removeTargets(newtId: string, targets: Target[]): Promise<void> {
//create a list of udp and tcp targets
const udpTargets = targets
.filter((target) => target.protocol === "udp")
.map((target) => {
return `${target.internalPort ? target.internalPort + ":" : ""}${target.ip}:${target.port}`;
});
const tcpTargets = targets
.filter((target) => target.protocol === "tcp")
.map((target) => {
return `${target.internalPort ? target.internalPort + ":" : ""}${target.ip}:${target.port}`;
});
if (udpTargets.length > 0) {
const payload = {
type: `newt/udp/remove`,
data: {
targets: udpTargets,
},
};
sendToClient(newtId, payload);
}
if (tcpTargets.length > 0) {
const payload = {
type: `newt/tcp/remove`,
data: {
targets: tcpTargets,
},
};
sendToClient(newtId, payload);
}
}

View file

@ -22,6 +22,7 @@ const createSiteSchema = z
subdomain: z.string().min(1).max(255).optional(), subdomain: z.string().min(1).max(255).optional(),
pubKey: z.string().optional(), pubKey: z.string().optional(),
subnet: z.string(), subnet: z.string(),
type: z.string(),
}) })
.strict(); .strict();
@ -48,7 +49,7 @@ export async function createSite(
); );
} }
const { name, subdomain, exitNodeId, pubKey, subnet } = parsedBody.data; const { name, type, exitNodeId, pubKey, subnet } = parsedBody.data;
const parsedParams = createSiteParamsSchema.safeParse(req.params); const parsedParams = createSiteParamsSchema.safeParse(req.params);
if (!parsedParams.success) { if (!parsedParams.success) {
@ -76,6 +77,7 @@ export async function createSite(
name, name,
niceId, niceId,
subnet, subnet,
type,
}; };
if (pubKey) { if (pubKey) {
@ -114,11 +116,18 @@ export async function createSite(
if (pubKey) { if (pubKey) {
// add the peer to the exit node // add the peer to the exit node
if (type == "newt") {
await addPeer(exitNodeId, {
publicKey: pubKey,
allowedIps: [subnet],
});
} else if (type == "wireguard") {
await addPeer(exitNodeId, { await addPeer(exitNodeId, {
publicKey: pubKey, publicKey: pubKey,
allowedIps: [], allowedIps: [],
}); });
} }
}
return response(res, { return response(res, {
data: { data: {

View file

@ -17,11 +17,11 @@ const updateSiteBodySchema = z
.object({ .object({
name: z.string().min(1).max(255).optional(), name: z.string().min(1).max(255).optional(),
subdomain: z.string().min(1).max(255).optional(), subdomain: z.string().min(1).max(255).optional(),
pubKey: z.string().optional(), // pubKey: z.string().optional(),
subnet: z.string().optional(), // subnet: z.string().optional(),
exitNode: z.number().int().positive().optional(), // exitNode: z.number().int().positive().optional(),
megabytesIn: z.number().int().nonnegative().optional(), // megabytesIn: z.number().int().nonnegative().optional(),
megabytesOut: z.number().int().nonnegative().optional(), // megabytesOut: z.number().int().nonnegative().optional(),
}) })
.strict() .strict()
.refine((data) => Object.keys(data).length > 0, { .refine((data) => Object.keys(data).length > 0, {

View file

@ -10,6 +10,7 @@ import { addPeer } from "../gerbil/peers";
import { eq, and } from "drizzle-orm"; import { eq, and } from "drizzle-orm";
import { isIpInCidr } from "@server/utils/ip"; import { isIpInCidr } from "@server/utils/ip";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { addTargets } from "../newt/targets";
const createTargetParamsSchema = z.object({ const createTargetParamsSchema = z.object({
resourceId: z.string().transform(Number).pipe(z.number().int().positive()), resourceId: z.string().transform(Number).pipe(z.number().int().positive()),
@ -111,6 +112,8 @@ export async function createTarget(
}) })
.returning(); .returning();
if (site.pubKey) {
if ( site.type == "wireguard") {
// Fetch resources for this site // Fetch resources for this site
const resourcesRes = await db.query.resources.findMany({ const resourcesRes = await db.query.resources.findMany({
where: eq(resources.siteId, site.siteId), where: eq(resources.siteId, site.siteId),
@ -130,6 +133,10 @@ export async function createTarget(
publicKey: site.pubKey, publicKey: site.pubKey,
allowedIps: targetIps.flat(), allowedIps: targetIps.flat(),
}); });
} else if (site.type == "newt") {
addTargets("",newTarget); // TODO: we need to generate and save the internal port somewhere and also come up with the newtId
}
}
return response<CreateTargetResponse>(res, { return response<CreateTargetResponse>(res, {
data: newTarget[0], data: newTarget[0],

View file

@ -9,6 +9,7 @@ import createHttpError from "http-errors";
import logger from "@server/logger"; import logger from "@server/logger";
import { addPeer } from "../gerbil/peers"; import { addPeer } from "../gerbil/peers";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { removeTargets } from "../newt/targets";
const deleteTargetSchema = z.object({ const deleteTargetSchema = z.object({
targetId: z.string().transform(Number).pipe(z.number().int().positive()), targetId: z.string().transform(Number).pipe(z.number().int().positive()),
@ -80,6 +81,8 @@ export async function deleteTarget(
); );
} }
if (site.pubKey) {
if (site.type == "wireguard") {
// Fetch resources for this site // Fetch resources for this site
const resourcesRes = await db.query.resources.findMany({ const resourcesRes = await db.query.resources.findMany({
where: eq(resources.siteId, site.siteId), where: eq(resources.siteId, site.siteId),
@ -99,6 +102,10 @@ export async function deleteTarget(
publicKey: site.pubKey, publicKey: site.pubKey,
allowedIps: targetIps.flat(), allowedIps: targetIps.flat(),
}); });
} else if (site.type == "newt") {
removeTargets("", [deletedTarget]); // TODO: we need to generate and save the internal port somewhere and also come up with the newtId
}
}
return response(res, { return response(res, {
data: null, data: null,

View file

@ -1,13 +1,14 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db } from "@server/db"; import { db } from "@server/db";
import { targets } from "@server/db/schema"; import { resources, sites, targets } from "@server/db/schema";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import response from "@server/utils/response"; import response from "@server/utils/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import logger from "@server/logger"; import logger from "@server/logger";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { addPeer } from "../gerbil/peers";
const updateTargetParamsSchema = z.object({ const updateTargetParamsSchema = z.object({
targetId: z.string().transform(Number).pipe(z.number().int().positive()), targetId: z.string().transform(Number).pipe(z.number().int().positive()),
@ -54,13 +55,13 @@ export async function updateTarget(
const { targetId } = parsedParams.data; const { targetId } = parsedParams.data;
const updateData = parsedBody.data; const updateData = parsedBody.data;
const updatedTarget = await db const [updatedTarget] = await db
.update(targets) .update(targets)
.set(updateData) .set(updateData)
.where(eq(targets.targetId, targetId)) .where(eq(targets.targetId, targetId))
.returning(); .returning();
if (updatedTarget.length === 0) { if (!updatedTarget) {
return next( return next(
createHttpError( createHttpError(
HttpCode.NOT_FOUND, HttpCode.NOT_FOUND,
@ -69,8 +70,65 @@ export async function updateTarget(
); );
} }
// get the resource
const [resource] = await db
.select({
siteId: resources.siteId,
})
.from(resources)
.where(eq(resources.resourceId, updatedTarget.resourceId!));
if (!resource) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Resource with ID ${updatedTarget.resourceId} not found`
)
);
}
// TODO: is this all inefficient?
// get the site
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, resource.siteId!))
.limit(1);
if (!site) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Site with ID ${resource.siteId} not found`
)
);
}
if (site.pubKey && site.type == "wireguard") {
// Fetch resources for this site
const resourcesRes = await db.query.resources.findMany({
where: eq(resources.siteId, site.siteId),
});
// Fetch targets for all resources of this site
const targetIps = await Promise.all(
resourcesRes.map(async (resource) => {
const targetsRes = await db.query.targets.findMany({
where: eq(targets.resourceId, resource.resourceId),
});
return targetsRes.map((target) => `${target.ip}/32`);
})
);
await addPeer(site.exitNodeId!, {
publicKey: site.pubKey,
allowedIps: targetIps.flat(),
});
}
return response(res, { return response(res, {
data: updatedTarget[0], data: updatedTarget,
success: true, success: true,
error: false, error: false,
message: "Target updated successfully", message: "Target updated successfully",