diff --git a/server/db/schema.ts b/server/db/schema.ts index 9d3ac0e2..2c462bb5 100644 --- a/server/db/schema.ts +++ b/server/db/schema.ts @@ -23,6 +23,7 @@ export const sites = sqliteTable("sites", { subnet: text("subnet").notNull(), megabytesIn: integer("bytesIn"), megabytesOut: integer("bytesOut"), + type: text("type").notNull(), // "newt" or "wireguard" }); export const resources = sqliteTable("resources", { @@ -60,6 +61,7 @@ export const targets = sqliteTable("targets", { ip: text("ip").notNull(), method: text("method").notNull(), port: integer("port").notNull(), + internalPort: integer("internalPort"), protocol: text("protocol"), enabled: integer("enabled", { mode: "boolean" }).notNull().default(true), }); @@ -92,9 +94,7 @@ export const newts = sqliteTable("newt", { newtId: text("id").primaryKey(), secretHash: text("secretHash").notNull(), dateCreated: text("dateCreated").notNull(), - siteId: integer("siteId").references(() => sites.siteId, { - onDelete: "cascade", - }), + siteId: integer("siteId").references(() => sites.siteId), }); export const twoFactorBackupCodes = sqliteTable("twoFactorBackupCodes", { diff --git a/server/routers/newt/handleRegisterMessage.ts b/server/routers/newt/handleRegisterMessage.ts index 3524fd08..cf38b2af 100644 --- a/server/routers/newt/handleRegisterMessage.ts +++ b/server/routers/newt/handleRegisterMessage.ts @@ -1,42 +1,116 @@ import db from "@server/db"; import { MessageHandler } from "../ws"; -import { sites } from "@server/db/schema"; -import { eq } from "drizzle-orm"; +import { exitNodes, resources, sites, targets } from "@server/db/schema"; +import { eq, inArray } from "drizzle-orm"; +import { addPeer, deletePeer } from "../gerbil/peers"; +import logger from "@server/logger"; +import { findNextAvailableCidr } from "@server/utils/ip"; +import { exit } from "process"; export const handleRegisterMessage: MessageHandler = async (context) => { const { message, newt, sendToClient } = context; - + if (!newt) { - console.log("Newt not found"); + logger.warn("Newt not found"); return; } if (!newt.siteId) { - console.log("Newt has no site!"); // TODO: Maybe we create the site here? + logger.warn("Newt has no site!"); // TODO: Maybe we create the site here? return; } - + const siteId = newt.siteId; - - // get the site - const site = await db + + const { publicKey } = message.data; + if (!publicKey) { + logger.warn("Public key not provided"); + return; + } + + // const [site] = await db + // .select() + // .from(sites) + // .where(eq(sites.siteId, siteId)) + // .limit(1); + + const [site] = await db + .update(sites) + .set({ + pubKey: publicKey + }) + .where(eq(sites.siteId, siteId)) + .returning(); + + + if (!site || !site.exitNodeId) { + logger.warn("Site not found or does not have exit node"); + return; + } + + const [exitNode] = await db + .select() + .from(exitNodes) + .where(eq(exitNodes.exitNodeId, site.exitNodeId)) + .limit(1); + + if (site.pubKey && site.pubKey !== publicKey) { + logger.info("Public key mismatch. Deleting old peer..."); + await deletePeer(site.exitNodeId, site.pubKey); + } + + if (!site.subnet) { + logger.warn("Site has no subnet"); + return; + } + + // add the peer to the exit node + await addPeer(site.exitNodeId, { + publicKey: publicKey, + allowedIps: [site.subnet], + }); + + const siteResources = await db.select().from(resources).where(eq(resources.siteId, siteId)); + + // get the targets from the resourceIds + const siteTargets = await db .select() - .from(sites) - .where(eq(sites.siteId, siteId)) - .limit(1); + .from(targets) + .where( + inArray( + targets.resourceId, + siteResources.map(resource => resource.resourceId) + ) + ); + const udpTargets = siteTargets + .filter((target) => target.protocol === "udp") + .map((target) => { + return `${target.internalPort ? target.internalPort + ":" : ""}${target.ip}:${target.port}`; + }); + + const tcpTargets = siteTargets + .filter((target) => target.protocol === "tcp") + .map((target) => { + return `${target.internalPort ? target.internalPort + ":" : ""}${target.ip}:${target.port}`; + }); - const { publicKey } = message.data; return { message: { - type: 'newt/wg/connect', + type: "newt/wg/connect", data: { - publicKey: 'publicKey', - - } + endpoint: exitNode.endpoint, + publicKey: exitNode.publicKey, + serverIP: exitNode.address, + tunnelIP: site.subnet, + targets: { + udp: udpTargets, + tcp: tcpTargets, + } + }, }, - broadcast: false, // Send to all clients - excludeSender: false // Include sender in broadcast + broadcast: false, // Send to all clients + excludeSender: false, // Include sender in broadcast }; -}; \ No newline at end of file +}; diff --git a/server/routers/newt/targets.ts b/server/routers/newt/targets.ts new file mode 100644 index 00000000..1d456329 --- /dev/null +++ b/server/routers/newt/targets.ts @@ -0,0 +1,73 @@ +import { Target } from "@server/db/schema"; +import { sendToClient } from "../ws"; + +export async function addTargets(newtId: string, targets: Target[]): Promise { + //create a list of udp and tcp targets + const udpTargets = targets + .filter((target) => target.protocol === "udp") + .map((target) => { + return `${target.internalPort ? target.internalPort + ":" : ""}${target.ip}:${target.port}`; + }); + + const tcpTargets = targets + .filter((target) => target.protocol === "tcp") + .map((target) => { + return `${target.internalPort ? target.internalPort + ":" : ""}${target.ip}:${target.port}`; + }); + + if (udpTargets.length > 0) { + const payload = { + type: `newt/udp/add`, + data: { + targets: udpTargets, + }, + }; + sendToClient(newtId, payload); + } + + if (tcpTargets.length > 0) { + const payload = { + type: `newt/tcp/add`, + data: { + targets: tcpTargets, + }, + }; + sendToClient(newtId, payload); + } +} + + +export async function removeTargets(newtId: string, targets: Target[]): Promise { + //create a list of udp and tcp targets + const udpTargets = targets + .filter((target) => target.protocol === "udp") + .map((target) => { + return `${target.internalPort ? target.internalPort + ":" : ""}${target.ip}:${target.port}`; + }); + + const tcpTargets = targets + .filter((target) => target.protocol === "tcp") + .map((target) => { + return `${target.internalPort ? target.internalPort + ":" : ""}${target.ip}:${target.port}`; + }); + + if (udpTargets.length > 0) { + const payload = { + type: `newt/udp/remove`, + data: { + targets: udpTargets, + }, + }; + sendToClient(newtId, payload); + } + + if (tcpTargets.length > 0) { + const payload = { + type: `newt/tcp/remove`, + data: { + targets: tcpTargets, + }, + }; + sendToClient(newtId, payload); + } +} diff --git a/server/routers/site/createSite.ts b/server/routers/site/createSite.ts index 24527fae..b624d6ea 100644 --- a/server/routers/site/createSite.ts +++ b/server/routers/site/createSite.ts @@ -22,6 +22,7 @@ const createSiteSchema = z subdomain: z.string().min(1).max(255).optional(), pubKey: z.string().optional(), subnet: z.string(), + type: z.string(), }) .strict(); @@ -48,7 +49,7 @@ export async function createSite( ); } - const { name, subdomain, exitNodeId, pubKey, subnet } = parsedBody.data; + const { name, type, exitNodeId, pubKey, subnet } = parsedBody.data; const parsedParams = createSiteParamsSchema.safeParse(req.params); if (!parsedParams.success) { @@ -76,6 +77,7 @@ export async function createSite( name, niceId, subnet, + type, }; if (pubKey) { @@ -114,10 +116,17 @@ export async function createSite( if (pubKey) { // add the peer to the exit node - await addPeer(exitNodeId, { - publicKey: pubKey, - allowedIps: [], - }); + if (type == "newt") { + await addPeer(exitNodeId, { + publicKey: pubKey, + allowedIps: [subnet], + }); + } else if (type == "wireguard") { + await addPeer(exitNodeId, { + publicKey: pubKey, + allowedIps: [], + }); + } } return response(res, { diff --git a/server/routers/site/updateSite.ts b/server/routers/site/updateSite.ts index 71b36144..70c60fd0 100644 --- a/server/routers/site/updateSite.ts +++ b/server/routers/site/updateSite.ts @@ -17,11 +17,11 @@ const updateSiteBodySchema = z .object({ name: z.string().min(1).max(255).optional(), subdomain: z.string().min(1).max(255).optional(), - pubKey: z.string().optional(), - subnet: z.string().optional(), - exitNode: z.number().int().positive().optional(), - megabytesIn: z.number().int().nonnegative().optional(), - megabytesOut: z.number().int().nonnegative().optional(), + // pubKey: z.string().optional(), + // subnet: z.string().optional(), + // exitNode: z.number().int().positive().optional(), + // megabytesIn: z.number().int().nonnegative().optional(), + // megabytesOut: z.number().int().nonnegative().optional(), }) .strict() .refine((data) => Object.keys(data).length > 0, { diff --git a/server/routers/target/createTarget.ts b/server/routers/target/createTarget.ts index d98437f1..cbbd5a23 100644 --- a/server/routers/target/createTarget.ts +++ b/server/routers/target/createTarget.ts @@ -10,6 +10,7 @@ import { addPeer } from "../gerbil/peers"; import { eq, and } from "drizzle-orm"; import { isIpInCidr } from "@server/utils/ip"; import { fromError } from "zod-validation-error"; +import { addTargets } from "../newt/targets"; const createTargetParamsSchema = z.object({ resourceId: z.string().transform(Number).pipe(z.number().int().positive()), @@ -111,25 +112,31 @@ export async function createTarget( }) .returning(); - // Fetch resources for this site - const resourcesRes = await db.query.resources.findMany({ - where: eq(resources.siteId, site.siteId), - }); + if (site.pubKey) { + if ( site.type == "wireguard") { + // Fetch resources for this site + const resourcesRes = await db.query.resources.findMany({ + where: eq(resources.siteId, site.siteId), + }); - // Fetch targets for all resources of this site - const targetIps = await Promise.all( - resourcesRes.map(async (resource) => { - const targetsRes = await db.query.targets.findMany({ - where: eq(targets.resourceId, resource.resourceId), - }); - return targetsRes.map((target) => `${target.ip}/32`); - }) - ); + // Fetch targets for all resources of this site + const targetIps = await Promise.all( + resourcesRes.map(async (resource) => { + const targetsRes = await db.query.targets.findMany({ + where: eq(targets.resourceId, resource.resourceId), + }); + return targetsRes.map((target) => `${target.ip}/32`); + }) + ); - await addPeer(site.exitNodeId!, { - publicKey: site.pubKey, - allowedIps: targetIps.flat(), - }); + await addPeer(site.exitNodeId!, { + publicKey: site.pubKey, + allowedIps: targetIps.flat(), + }); + } else if (site.type == "newt") { + addTargets("",newTarget); // TODO: we need to generate and save the internal port somewhere and also come up with the newtId + } + } return response(res, { data: newTarget[0], diff --git a/server/routers/target/deleteTarget.ts b/server/routers/target/deleteTarget.ts index b79b2847..efe69181 100644 --- a/server/routers/target/deleteTarget.ts +++ b/server/routers/target/deleteTarget.ts @@ -9,6 +9,7 @@ import createHttpError from "http-errors"; import logger from "@server/logger"; import { addPeer } from "../gerbil/peers"; import { fromError } from "zod-validation-error"; +import { removeTargets } from "../newt/targets"; const deleteTargetSchema = z.object({ targetId: z.string().transform(Number).pipe(z.number().int().positive()), @@ -80,25 +81,31 @@ export async function deleteTarget( ); } - // Fetch resources for this site - const resourcesRes = await db.query.resources.findMany({ - where: eq(resources.siteId, site.siteId), - }); - - // Fetch targets for all resources of this site - const targetIps = await Promise.all( - resourcesRes.map(async (resource) => { - const targetsRes = await db.query.targets.findMany({ - where: eq(targets.resourceId, resource.resourceId), + if (site.pubKey) { + if (site.type == "wireguard") { + // Fetch resources for this site + const resourcesRes = await db.query.resources.findMany({ + where: eq(resources.siteId, site.siteId), }); - return targetsRes.map((target) => `${target.ip}/32`); - }) - ); - await addPeer(site.exitNodeId!, { - publicKey: site.pubKey, - allowedIps: targetIps.flat(), - }); + // Fetch targets for all resources of this site + const targetIps = await Promise.all( + resourcesRes.map(async (resource) => { + const targetsRes = await db.query.targets.findMany({ + where: eq(targets.resourceId, resource.resourceId), + }); + return targetsRes.map((target) => `${target.ip}/32`); + }) + ); + + await addPeer(site.exitNodeId!, { + publicKey: site.pubKey, + allowedIps: targetIps.flat(), + }); + } else if (site.type == "newt") { + removeTargets("", [deletedTarget]); // TODO: we need to generate and save the internal port somewhere and also come up with the newtId + } + } return response(res, { data: null, diff --git a/server/routers/target/updateTarget.ts b/server/routers/target/updateTarget.ts index 58fa4914..7adf4c7a 100644 --- a/server/routers/target/updateTarget.ts +++ b/server/routers/target/updateTarget.ts @@ -1,13 +1,14 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; import { db } from "@server/db"; -import { targets } from "@server/db/schema"; +import { resources, sites, targets } from "@server/db/schema"; import { eq } from "drizzle-orm"; import response from "@server/utils/response"; import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; import logger from "@server/logger"; import { fromError } from "zod-validation-error"; +import { addPeer } from "../gerbil/peers"; const updateTargetParamsSchema = z.object({ targetId: z.string().transform(Number).pipe(z.number().int().positive()), @@ -53,14 +54,14 @@ export async function updateTarget( const { targetId } = parsedParams.data; const updateData = parsedBody.data; - - const updatedTarget = await db + + const [updatedTarget] = await db .update(targets) .set(updateData) .where(eq(targets.targetId, targetId)) .returning(); - if (updatedTarget.length === 0) { + if (!updatedTarget) { return next( createHttpError( HttpCode.NOT_FOUND, @@ -69,8 +70,65 @@ export async function updateTarget( ); } + // get the resource + const [resource] = await db + .select({ + siteId: resources.siteId, + }) + .from(resources) + .where(eq(resources.resourceId, updatedTarget.resourceId!)); + + if (!resource) { + return next( + createHttpError( + HttpCode.NOT_FOUND, + `Resource with ID ${updatedTarget.resourceId} not found` + ) + ); + } + + // TODO: is this all inefficient? + + // get the site + const [site] = await db + .select() + .from(sites) + .where(eq(sites.siteId, resource.siteId!)) + .limit(1); + + if (!site) { + return next( + createHttpError( + HttpCode.NOT_FOUND, + `Site with ID ${resource.siteId} not found` + ) + ); + } + + if (site.pubKey && site.type == "wireguard") { + // Fetch resources for this site + const resourcesRes = await db.query.resources.findMany({ + where: eq(resources.siteId, site.siteId), + }); + + // Fetch targets for all resources of this site + const targetIps = await Promise.all( + resourcesRes.map(async (resource) => { + const targetsRes = await db.query.targets.findMany({ + where: eq(targets.resourceId, resource.resourceId), + }); + return targetsRes.map((target) => `${target.ip}/32`); + }) + ); + + await addPeer(site.exitNodeId!, { + publicKey: site.pubKey, + allowedIps: targetIps.flat(), + }); + } + return response(res, { - data: updatedTarget[0], + data: updatedTarget, success: true, error: false, message: "Target updated successfully",