diff --git a/server/auth/resourceOtp.ts b/server/auth/resourceOtp.ts index a9de7499..43d25776 100644 --- a/server/auth/resourceOtp.ts +++ b/server/auth/resourceOtp.ts @@ -3,11 +3,10 @@ import { resourceOtp } from "@server/db/schema"; import { and, eq } from "drizzle-orm"; import { createDate, isWithinExpirationDate, TimeSpan } from "oslo"; import { alphabet, generateRandomString, sha256 } from "oslo/crypto"; -import { encodeHex } from "oslo/encoding"; import { sendEmail } from "@server/emails"; import ResourceOTPCode from "@server/emails/templates/ResourceOTPCode"; import config from "@server/config"; -import { hash, verify } from "@node-rs/argon2"; +import { verifyPassword } from "./password"; import { hashPassword } from "./password"; export async function sendResourceOtpEmail( @@ -37,24 +36,25 @@ export async function generateResourceOtpCode( resourceId: number, email: string ): Promise { - await db - .delete(resourceOtp) - .where( - and( - eq(resourceOtp.email, email), - eq(resourceOtp.resourceId, resourceId) - ) - ); - const otp = generateRandomString(8, alphabet("0-9", "A-Z", "a-z")); + await db.transaction(async (trx) => { + await trx + .delete(resourceOtp) + .where( + and( + eq(resourceOtp.email, email), + eq(resourceOtp.resourceId, resourceId) + ) + ); - const otpHash = await hashPassword(otp); + const otpHash = await hashPassword(otp); - await db.insert(resourceOtp).values({ - resourceId, - email, - otpHash, - expiresAt: createDate(new TimeSpan(15, "m")).getTime() + await trx.insert(resourceOtp).values({ + resourceId, + email, + otpHash, + expiresAt: createDate(new TimeSpan(15, "m")).getTime() + }); }); return otp; diff --git a/server/auth/sendEmailVerificationCode.ts b/server/auth/sendEmailVerificationCode.ts index 9d9dc08d..4d961d25 100644 --- a/server/auth/sendEmailVerificationCode.ts +++ b/server/auth/sendEmailVerificationCode.ts @@ -31,18 +31,18 @@ async function generateEmailVerificationCode( userId: string, email: string ): Promise { - await db - .delete(emailVerificationCodes) - .where(eq(emailVerificationCodes.userId, userId)); - const code = generateRandomString(8, alphabet("0-9")); + await db.transaction(async (trx) => { + await trx + .delete(emailVerificationCodes) + .where(eq(emailVerificationCodes.userId, userId)); - await db.insert(emailVerificationCodes).values({ - userId, - email, - code, - expiresAt: createDate(new TimeSpan(15, "m")).getTime() + await trx.insert(emailVerificationCodes).values({ + userId, + email, + code, + expiresAt: createDate(new TimeSpan(15, "m")).getTime() + }); }); - return code; } diff --git a/server/routers/auth/disable2fa.ts b/server/routers/auth/disable2fa.ts index 05ed6338..85e6894e 100644 --- a/server/routers/auth/disable2fa.ts +++ b/server/routers/auth/disable2fa.ts @@ -12,10 +12,12 @@ import { verifyPassword } from "@server/auth/password"; import { verifyTotpCode } from "@server/auth/2fa"; import logger from "@server/logger"; -export const disable2faBody = z.object({ - password: z.string(), - code: z.string().optional(), -}).strict(); +export const disable2faBody = z + .object({ + password: z.string(), + code: z.string().optional() + }) + .strict(); export type Disable2faBody = z.infer; @@ -26,7 +28,7 @@ export type Disable2faResponse = { export async function disable2fa( req: Request, res: Response, - next: NextFunction, + next: NextFunction ): Promise { const parsedBody = disable2faBody.safeParse(req.body); @@ -34,8 +36,8 @@ export async function disable2fa( return next( createHttpError( HttpCode.BAD_REQUEST, - fromError(parsedBody.error).toString(), - ), + fromError(parsedBody.error).toString() + ) ); } @@ -52,8 +54,8 @@ export async function disable2fa( return next( createHttpError( HttpCode.BAD_REQUEST, - "Two-factor authentication is already disabled", - ), + "Two-factor authentication is already disabled" + ) ); } else { if (!code) { @@ -62,7 +64,7 @@ export async function disable2fa( success: true, error: false, message: "Two-factor authentication required", - status: HttpCode.ACCEPTED, + status: HttpCode.ACCEPTED }); } } @@ -70,27 +72,28 @@ export async function disable2fa( const validOTP = await verifyTotpCode( code, user.twoFactorSecret!, - user.userId, + user.userId ); if (!validOTP) { return next( createHttpError( HttpCode.BAD_REQUEST, - "The two-factor code you entered is incorrect", - ), + "The two-factor code you entered is incorrect" + ) ); } - await db - .update(users) - .set({ twoFactorEnabled: false }) - .where(eq(users.userId, user.userId)); - - await db - .delete(twoFactorBackupCodes) - .where(eq(twoFactorBackupCodes.userId, user.userId)); + await db.transaction(async (trx) => { + await trx + .update(users) + .set({ twoFactorEnabled: false }) + .where(eq(users.userId, user.userId)); + await trx + .delete(twoFactorBackupCodes) + .where(eq(twoFactorBackupCodes.userId, user.userId)); + }); // TODO: send email to user confirming two-factor authentication is disabled return response(res, { @@ -98,15 +101,15 @@ export async function disable2fa( success: true, error: false, message: "Two-factor authentication disabled", - status: HttpCode.OK, + status: HttpCode.OK }); } catch (error) { logger.error(error); return next( createHttpError( HttpCode.INTERNAL_SERVER_ERROR, - "Failed to disable two-factor authentication", - ), + "Failed to disable two-factor authentication" + ) ); } } diff --git a/server/routers/auth/requestPasswordReset.ts b/server/routers/auth/requestPasswordReset.ts index 62902c0a..df57ec4b 100644 --- a/server/routers/auth/requestPasswordReset.ts +++ b/server/routers/auth/requestPasswordReset.ts @@ -63,18 +63,23 @@ export async function requestPasswordReset( ); } - await db - .delete(passwordResetTokens) - .where(eq(passwordResetTokens.userId, existingUser[0].userId)); + const token = generateRandomString( + 8, + alphabet("0-9", "A-Z", "a-z") + ); + await db.transaction(async (trx) => { + await trx + .delete(passwordResetTokens) + .where(eq(passwordResetTokens.userId, existingUser[0].userId)); - const token = generateRandomString(8, alphabet("0-9", "A-Z", "a-z")); - const tokenHash = await hashPassword(token); + const tokenHash = await hashPassword(token); - await db.insert(passwordResetTokens).values({ - userId: existingUser[0].userId, - email: existingUser[0].email, - tokenHash, - expiresAt: createDate(new TimeSpan(2, "h")).getTime() + await trx.insert(passwordResetTokens).values({ + userId: existingUser[0].userId, + email: existingUser[0].email, + tokenHash, + expiresAt: createDate(new TimeSpan(2, "h")).getTime() + }); }); const url = `${config.app.base_url}/auth/reset-password?email=${email}&token=${token}`; diff --git a/server/routers/auth/resetPassword.ts b/server/routers/auth/resetPassword.ts index 1c358e39..259e4e28 100644 --- a/server/routers/auth/resetPassword.ts +++ b/server/routers/auth/resetPassword.ts @@ -135,20 +135,22 @@ export async function resetPassword( await invalidateAllSessions(resetRequest[0].userId); - await db - .update(users) - .set({ passwordHash }) - .where(eq(users.userId, resetRequest[0].userId)); + await db.transaction(async (trx) => { + await trx + .update(users) + .set({ passwordHash }) + .where(eq(users.userId, resetRequest[0].userId)); - await db - .delete(passwordResetTokens) - .where(eq(passwordResetTokens.email, email)); + await trx + .delete(passwordResetTokens) + .where(eq(passwordResetTokens.email, email)); + }); await sendEmail(ConfirmPasswordReset({ email }), { from: config.email?.no_reply, to: email, subject: "Password Reset Confirmation" - }) + }); return response(res, { data: null, diff --git a/server/routers/auth/verifyEmail.ts b/server/routers/auth/verifyEmail.ts index 59525a0b..a73983ce 100644 --- a/server/routers/auth/verifyEmail.ts +++ b/server/routers/auth/verifyEmail.ts @@ -62,16 +62,18 @@ export async function verifyEmail( const valid = await isValidCode(user, code); if (valid) { - await db - .delete(emailVerificationCodes) - .where(eq(emailVerificationCodes.userId, user.userId)); + await db.transaction(async (trx) => { + await trx + .delete(emailVerificationCodes) + .where(eq(emailVerificationCodes.userId, user.userId)); - await db - .update(users) - .set({ - emailVerified: true - }) - .where(eq(users.userId, user.userId)); + await trx + .update(users) + .set({ + emailVerified: true + }) + .where(eq(users.userId, user.userId)); + }); } else { return next( createHttpError( diff --git a/server/routers/auth/verifyTotp.ts b/server/routers/auth/verifyTotp.ts index 185f3d1a..448a6256 100644 --- a/server/routers/auth/verifyTotp.ts +++ b/server/routers/auth/verifyTotp.ts @@ -73,21 +73,23 @@ export async function verifyTotp( let codes; if (valid) { // if valid, enable two-factor authentication; the totp secret is no longer temporary - await db - .update(users) - .set({ twoFactorEnabled: true }) - .where(eq(users.userId, user.userId)); + await db.transaction(async (trx) => { + await trx + .update(users) + .set({ twoFactorEnabled: true }) + .where(eq(users.userId, user.userId)); - const backupCodes = await generateBackupCodes(); - codes = backupCodes; - for (const code of backupCodes) { - const hash = await hashPassword(code); + const backupCodes = await generateBackupCodes(); + codes = backupCodes; + for (const code of backupCodes) { + const hash = await hashPassword(code); - await db.insert(twoFactorBackupCodes).values({ - userId: user.userId, - codeHash: hash - }); - } + await trx.insert(twoFactorBackupCodes).values({ + userId: user.userId, + codeHash: hash + }); + } + }); } // TODO: send email to user confirming two-factor authentication is enabled diff --git a/server/routers/gerbil/receiveBandwidth.ts b/server/routers/gerbil/receiveBandwidth.ts index 266373a1..0c8d8dd7 100644 --- a/server/routers/gerbil/receiveBandwidth.ts +++ b/server/routers/gerbil/receiveBandwidth.ts @@ -1,10 +1,10 @@ -import { Request, Response, NextFunction } from 'express'; -import { DrizzleError, eq } from 'drizzle-orm'; -import { sites, resources, targets, exitNodes } from '@server/db/schema'; -import db from '@server/db'; -import logger from '@server/logger'; -import createHttpError from 'http-errors'; -import HttpCode from '@server/types/HttpCode'; +import { Request, Response, NextFunction } from "express"; +import { DrizzleError, eq } from "drizzle-orm"; +import { sites, resources, targets, exitNodes } from "@server/db/schema"; +import db from "@server/db"; +import logger from "@server/logger"; +import createHttpError from "http-errors"; +import HttpCode from "@server/types/HttpCode"; import response from "@server/utils/response"; interface PeerBandwidth { @@ -13,62 +13,76 @@ interface PeerBandwidth { bytesOut: number; } -export const receiveBandwidth = async (req: Request, res: Response, next: NextFunction): Promise => { +export const receiveBandwidth = async ( + req: Request, + res: Response, + next: NextFunction +): Promise => { try { const bandwidthData: PeerBandwidth[] = req.body; if (!Array.isArray(bandwidthData)) { - throw new Error('Invalid bandwidth data'); + throw new Error("Invalid bandwidth data"); } - for (const peer of bandwidthData) { - const { publicKey, bytesIn, bytesOut } = peer; + await db.transaction(async (trx) => { + for (const peer of bandwidthData) { + const { publicKey, bytesIn, bytesOut } = peer; - // Find the site by public key - const site = await db.query.sites.findFirst({ - where: eq(sites.pubKey, publicKey), - }); + // Find the site by public key + const site = await trx.query.sites.findFirst({ + where: eq(sites.pubKey, publicKey) + }); - if (!site) { - logger.warn(`Site not found for public key: ${publicKey}`); - continue; - } - let online = site.online; - - // if the bandwidth for the site is > 0 then set it to online. if it has been less than 0 (no update) for 5 minutes then set it to offline - if (bytesIn > 0 || bytesOut > 0) { - online = true; - } else if (site.lastBandwidthUpdate) { - const lastBandwidthUpdate = new Date(site.lastBandwidthUpdate); - const currentTime = new Date(); - const diff = currentTime.getTime() - lastBandwidthUpdate.getTime(); - if (diff < 300000) { - online = false; + if (!site) { + logger.warn(`Site not found for public key: ${publicKey}`); + continue; } + let online = site.online; + + // if the bandwidth for the site is > 0 then set it to online. if it has been less than 0 (no update) for 5 minutes then set it to offline + if (bytesIn > 0 || bytesOut > 0) { + online = true; + } else if (site.lastBandwidthUpdate) { + const lastBandwidthUpdate = new Date( + site.lastBandwidthUpdate + ); + const currentTime = new Date(); + const diff = + currentTime.getTime() - lastBandwidthUpdate.getTime(); + if (diff < 300000) { + online = false; + } + } + + // Update the site's bandwidth usage + await trx + .update(sites) + .set({ + megabytesIn: (site.megabytesIn || 0) + bytesIn, + megabytesOut: (site.megabytesOut || 0) + bytesOut, + lastBandwidthUpdate: new Date().toISOString(), + online + }) + .where(eq(sites.siteId, site.siteId)); } - - // Update the site's bandwidth usage - await db.update(sites) - .set({ - megabytesIn: (site.megabytesIn || 0) + bytesIn, - megabytesOut: (site.megabytesOut || 0) + bytesOut, - lastBandwidthUpdate: new Date().toISOString(), - online, - }) - .where(eq(sites.siteId, site.siteId)); - - } + }); return response(res, { data: {}, success: true, error: false, message: "Organization retrieved successfully", - status: HttpCode.OK, + status: HttpCode.OK }); } catch (error) { - logger.error('Error updating bandwidth data:', error); - return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred...")); + logger.error("Error updating bandwidth data:", error); + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "An error occurred..." + ) + ); } }; diff --git a/server/routers/newt/handleRegisterMessage.ts b/server/routers/newt/handleRegisterMessage.ts index 868372f2..2721ec86 100644 --- a/server/routers/newt/handleRegisterMessage.ts +++ b/server/routers/newt/handleRegisterMessage.ts @@ -39,15 +39,14 @@ export const handleRegisterMessage: MessageHandler = async (context) => { return; } - const [updatedSite] = await db - .update(sites) + await db + .update(sites) .set({ pubKey: publicKey }) .where(eq(sites.siteId, siteId)) .returning(); - const [exitNode] = await db .select() .from(exitNodes) @@ -67,35 +66,41 @@ export const handleRegisterMessage: MessageHandler = async (context) => { // add the peer to the exit node await addPeer(site.exitNodeId, { publicKey: publicKey, - allowedIps: [site.subnet], + allowedIps: [site.subnet] }); - const siteResources = await db.select().from(resources).where(eq(resources.siteId, siteId)); + const siteResources = await db + .select() + .from(resources) + .where(eq(resources.siteId, siteId)); // get the targets from the resourceIds const siteTargets = await db - .select() - .from(targets) - .where( - inArray( - targets.resourceId, - siteResources.map(resource => resource.resourceId) - ) - ); + .select() + .from(targets) + .where( + inArray( + targets.resourceId, + siteResources.map((resource) => resource.resourceId) + ) + ); const udpTargets = siteTargets .filter((target) => target.protocol === "udp") .map((target) => { - return `${target.internalPort ? target.internalPort + ":" : ""}${target.ip}:${target.port}`; + return `${target.internalPort ? target.internalPort + ":" : ""}${ + target.ip + }:${target.port}`; }); const tcpTargets = siteTargets .filter((target) => target.protocol === "tcp") .map((target) => { - return `${target.internalPort ? target.internalPort + ":" : ""}${target.ip}:${target.port}`; + return `${target.internalPort ? target.internalPort + ":" : ""}${ + target.ip + }:${target.port}`; }); - return { message: { type: "newt/wg/connect", @@ -106,11 +111,11 @@ export const handleRegisterMessage: MessageHandler = async (context) => { tunnelIP: site.subnet.split("/")[0], targets: { udp: udpTargets, - tcp: tcpTargets, + tcp: tcpTargets } - }, + } }, broadcast: false, // Send to all clients - excludeSender: false, // Include sender in broadcast + excludeSender: false // Include sender in broadcast }; }; diff --git a/server/routers/org/deleteOrg.ts b/server/routers/org/deleteOrg.ts index 3ebea645..4bf1eaaa 100644 --- a/server/routers/org/deleteOrg.ts +++ b/server/routers/org/deleteOrg.ts @@ -24,9 +24,7 @@ const deleteOrgSchema = z }) .strict(); -export type DeleteOrgResponse = { - -} +export type DeleteOrgResponse = {}; export async function deleteOrg( req: Request, @@ -79,39 +77,47 @@ export async function deleteOrg( .where(eq(sites.orgId, orgId)) .limit(1); - if (sites) { - for (const site of orgSites) { - if (site.pubKey) { - if (site.type == "wireguard") { - await deletePeer(site.exitNodeId!, site.pubKey); - } else if (site.type == "newt") { - // get the newt on the site by querying the newt table for siteId - const [deletedNewt] = await db - .delete(newts) - .where(eq(newts.siteId, site.siteId)) - .returning(); - if (deletedNewt) { - const payload = { - type: `newt/terminate`, - data: {} - }; - sendToClient(deletedNewt.newtId, payload); + await db.transaction(async (trx) => { + if (sites) { + for (const site of orgSites) { + if (site.pubKey) { + if (site.type == "wireguard") { + await deletePeer(site.exitNodeId!, site.pubKey); + } else if (site.type == "newt") { + // get the newt on the site by querying the newt table for siteId + const [deletedNewt] = await trx + .delete(newts) + .where(eq(newts.siteId, site.siteId)) + .returning(); + if (deletedNewt) { + const payload = { + type: `newt/terminate`, + data: {} + }; + sendToClient(deletedNewt.newtId, payload); - // delete all of the sessions for the newt - await db.delete(newtSessions) - .where( - eq(newtSessions.newtId, deletedNewt.newtId) - ); + // delete all of the sessions for the newt + await trx + .delete(newtSessions) + .where( + eq( + newtSessions.newtId, + deletedNewt.newtId + ) + ); + } } } + + logger.info(`Deleting site ${site.siteId}`); + await trx + .delete(sites) + .where(eq(sites.siteId, site.siteId)); } - - logger.info(`Deleting site ${site.siteId}`); - await db.delete(sites).where(eq(sites.siteId, site.siteId)) } - } - await db.delete(orgs).where(eq(orgs.orgId, orgId)); + await trx.delete(orgs).where(eq(orgs.orgId, orgId)); + }); return response(res, { data: null, diff --git a/server/routers/resource/createResource.ts b/server/routers/resource/createResource.ts index 018faf2c..d6ae47f2 100644 --- a/server/routers/resource/createResource.ts +++ b/server/routers/resource/createResource.ts @@ -89,50 +89,50 @@ export async function createResource( } const fullDomain = `${subdomain}.${org[0].domain}`; + await db.transaction(async (trx) => { + const newResource = await trx + .insert(resources) + .values({ + siteId, + fullDomain, + orgId, + name, + subdomain, + ssl: true + }) + .returning(); - const newResource = await db - .insert(resources) - .values({ - siteId, - fullDomain, - orgId, - name, - subdomain, - ssl: true - }) - .returning(); + const adminRole = await db + .select() + .from(roles) + .where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId))) + .limit(1); - const adminRole = await db - .select() - .from(roles) - .where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId))) - .limit(1); + if (adminRole.length === 0) { + return next( + createHttpError(HttpCode.NOT_FOUND, `Admin role not found`) + ); + } - if (adminRole.length === 0) { - return next( - createHttpError(HttpCode.NOT_FOUND, `Admin role not found`) - ); - } - - await db.insert(roleResources).values({ - roleId: adminRole[0].roleId, - resourceId: newResource[0].resourceId - }); - - if (req.userOrgRoleId != adminRole[0].roleId) { - // make sure the user can access the resource - await db.insert(userResources).values({ - userId: req.user?.userId!, + await trx.insert(roleResources).values({ + roleId: adminRole[0].roleId, resourceId: newResource[0].resourceId }); - } - response(res, { - data: newResource[0], - success: true, - error: false, - message: "Resource created successfully", - status: HttpCode.CREATED + if (req.userOrgRoleId != adminRole[0].roleId) { + // make sure the user can access the resource + await trx.insert(userResources).values({ + userId: req.user?.userId!, + resourceId: newResource[0].resourceId + }); + } + response(res, { + data: newResource[0], + success: true, + error: false, + message: "Resource created successfully", + status: HttpCode.CREATED + }); }); } catch (error) { if ( diff --git a/server/routers/role/addRoleSite.ts b/server/routers/role/addRoleSite.ts index 5204cae5..c702614c 100644 --- a/server/routers/role/addRoleSite.ts +++ b/server/routers/role/addRoleSite.ts @@ -51,32 +51,34 @@ export async function addRoleSite( const { roleId } = parsedParams.data; - const newRoleSite = await db - .insert(roleSites) - .values({ - roleId, - siteId - }) - .returning(); + await db.transaction(async (trx) => { + const newRoleSite = await trx + .insert(roleSites) + .values({ + roleId, + siteId + }) + .returning(); - const siteResources = await db - .select() - .from(resources) - .where(eq(resources.siteId, siteId)); + const siteResources = await db + .select() + .from(resources) + .where(eq(resources.siteId, siteId)); - for (const resource of siteResources) { - await db.insert(roleResources).values({ - roleId, - resourceId: resource.resourceId + for (const resource of siteResources) { + await trx.insert(roleResources).values({ + roleId, + resourceId: resource.resourceId + }); + } + + return response(res, { + data: newRoleSite[0], + success: true, + error: false, + message: "Site added to role successfully", + status: HttpCode.CREATED }); - } - - return response(res, { - data: newRoleSite[0], - success: true, - error: false, - message: "Site added to role successfully", - status: HttpCode.CREATED }); } catch (error) { logger.error(error); diff --git a/server/routers/role/createRole.ts b/server/routers/role/createRole.ts index dd3656fe..a6a82dc9 100644 --- a/server/routers/role/createRole.ts +++ b/server/routers/role/createRole.ts @@ -82,31 +82,33 @@ export async function createRole( ); } - const newRole = await db - .insert(roles) - .values({ - ...roleData, - orgId - }) - .returning(); - - await db - .insert(roleActions) - .values( - defaultRoleAllowedActions.map((action) => ({ - roleId: newRole[0].roleId, - actionId: action, + await db.transaction(async (trx) => { + const newRole = await trx + .insert(roles) + .values({ + ...roleData, orgId - })) - ) - .execute(); + }) + .returning(); - return response(res, { - data: newRole[0], - success: true, - error: false, - message: "Role created successfully", - status: HttpCode.CREATED + await trx + .insert(roleActions) + .values( + defaultRoleAllowedActions.map((action) => ({ + roleId: newRole[0].roleId, + actionId: action, + orgId + })) + ) + .execute(); + + return response(res, { + data: newRole[0], + success: true, + error: false, + message: "Role created successfully", + status: HttpCode.CREATED + }); }); } catch (error) { logger.error(error); diff --git a/server/routers/role/deleteRole.ts b/server/routers/role/deleteRole.ts index 1cc44b2f..708b0968 100644 --- a/server/routers/role/deleteRole.ts +++ b/server/routers/role/deleteRole.ts @@ -98,15 +98,17 @@ export async function deleteRole( ); } - // move all users from the userOrgs table with roleId to newRoleId - await db - .update(userOrgs) - .set({ roleId: newRoleId }) - .where(eq(userOrgs.roleId, roleId)); - - // delete the old role - await db.delete(roles).where(eq(roles.roleId, roleId)); + await db.transaction(async (trx) => { + // move all users from the userOrgs table with roleId to newRoleId + await trx + .update(userOrgs) + .set({ roleId: newRoleId }) + .where(eq(userOrgs.roleId, roleId)); + // delete the old role + await trx.delete(roles).where(eq(roles.roleId, roleId)); + }); + return response(res, { data: null, success: true, diff --git a/server/routers/role/removeRoleSite.ts b/server/routers/role/removeRoleSite.ts index 43efca52..e04b9dda 100644 --- a/server/routers/role/removeRoleSite.ts +++ b/server/routers/role/removeRoleSite.ts @@ -51,38 +51,43 @@ export async function removeRoleSite( const { roleId } = parsedBody.data; - const deletedRoleSite = await db - .delete(roleSites) - .where( - and(eq(roleSites.roleId, roleId), eq(roleSites.siteId, siteId)) - ) - .returning(); - - if (deletedRoleSite.length === 0) { - return next( - createHttpError( - HttpCode.NOT_FOUND, - `Site with ID ${siteId} not found for role with ID ${roleId}` - ) - ); - } - - const siteResources = await db - .select() - .from(resources) - .where(eq(resources.siteId, siteId)); - - for (const resource of siteResources) { - await db - .delete(roleResources) + await db.transaction(async (trx) => { + const deletedRoleSite = await trx + .delete(roleSites) .where( and( - eq(roleResources.roleId, roleId), - eq(roleResources.resourceId, resource.resourceId) + eq(roleSites.roleId, roleId), + eq(roleSites.siteId, siteId) ) ) .returning(); - } + + if (deletedRoleSite.length === 0) { + return next( + createHttpError( + HttpCode.NOT_FOUND, + `Site with ID ${siteId} not found for role with ID ${roleId}` + ) + ); + } + + const siteResources = await db + .select() + .from(resources) + .where(eq(resources.siteId, siteId)); + + for (const resource of siteResources) { + await trx + .delete(roleResources) + .where( + and( + eq(roleResources.roleId, roleId), + eq(roleResources.resourceId, resource.resourceId) + ) + ) + .returning(); + } + }); return response(res, { data: null, diff --git a/server/routers/site/createSite.ts b/server/routers/site/createSite.ts index 458e0f63..73ecb490 100644 --- a/server/routers/site/createSite.ts +++ b/server/routers/site/createSite.ts @@ -94,64 +94,69 @@ export async function createSite( }; } - const [newSite] = await db.insert(sites).values(payload).returning(); + await db.transaction(async (trx) => { + const [newSite] = await trx + .insert(sites) + .values(payload) + .returning(); - const adminRole = await db - .select() - .from(roles) - .where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId))) - .limit(1); + const adminRole = await trx + .select() + .from(roles) + .where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId))) + .limit(1); - if (adminRole.length === 0) { - return next( - createHttpError(HttpCode.NOT_FOUND, `Admin role not found`) - ); - } - - await db.insert(roleSites).values({ - roleId: adminRole[0].roleId, - siteId: newSite.siteId - }); - - if (req.userOrgRoleId != adminRole[0].roleId) { - // make sure the user can access the site - db.insert(userSites).values({ - userId: req.user?.userId!, - siteId: newSite.siteId - }); - } - - // add the peer to the exit node - if (type == "newt") { - const secretHash = await hashPassword(secret!); - - await db.insert(newts).values({ - newtId: newtId!, - secretHash, - siteId: newSite.siteId, - dateCreated: moment().toISOString() - }); - } else if (type == "wireguard") { - if (!pubKey) { + if (adminRole.length === 0) { return next( - createHttpError( - HttpCode.BAD_REQUEST, - "Public key is required for wireguard sites" - ) + createHttpError(HttpCode.NOT_FOUND, `Admin role not found`) ); } - await addPeer(exitNodeId, { - publicKey: pubKey, - allowedIps: [] - }); - } - return response(res, { - data: newSite, - success: true, - error: false, - message: "Site created successfully", - status: HttpCode.CREATED + await trx.insert(roleSites).values({ + roleId: adminRole[0].roleId, + siteId: newSite.siteId + }); + + if (req.userOrgRoleId != adminRole[0].roleId) { + // make sure the user can access the site + trx.insert(userSites).values({ + userId: req.user?.userId!, + siteId: newSite.siteId + }); + } + + // add the peer to the exit node + if (type == "newt") { + const secretHash = await hashPassword(secret!); + + await trx.insert(newts).values({ + newtId: newtId!, + secretHash, + siteId: newSite.siteId, + dateCreated: moment().toISOString() + }); + } else if (type == "wireguard") { + if (!pubKey) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "Public key is required for wireguard sites" + ) + ); + } + await addPeer(exitNodeId, { + publicKey: pubKey, + allowedIps: [] + }); + } + + return response(res, { + data: newSite, + success: true, + error: false, + message: "Site created successfully", + status: HttpCode.CREATED + }); }); } catch (error) { logger.error(error); diff --git a/server/routers/site/deleteSite.ts b/server/routers/site/deleteSite.ts index 8d6421fe..d739810b 100644 --- a/server/routers/site/deleteSite.ts +++ b/server/routers/site/deleteSite.ts @@ -50,32 +50,34 @@ export async function deleteSite( ); } - if (site.pubKey) { - if (site.type == "wireguard") { - await deletePeer(site.exitNodeId!, site.pubKey); - } else if (site.type == "newt") { - // get the newt on the site by querying the newt table for siteId - const [deletedNewt] = await db - .delete(newts) - .where(eq(newts.siteId, siteId)) - .returning(); - if (deletedNewt) { - const payload = { - type: `newt/terminate`, - data: {} - }; - sendToClient(deletedNewt.newtId, payload); + await db.transaction(async (trx) => { + if (site.pubKey) { + if (site.type == "wireguard") { + await deletePeer(site.exitNodeId!, site.pubKey); + } else if (site.type == "newt") { + // get the newt on the site by querying the newt table for siteId + const [deletedNewt] = await trx + .delete(newts) + .where(eq(newts.siteId, siteId)) + .returning(); + if (deletedNewt) { + const payload = { + type: `newt/terminate`, + data: {} + }; + sendToClient(deletedNewt.newtId, payload); - // delete all of the sessions for the newt - db.delete(newtSessions) - .where(eq(newtSessions.newtId, deletedNewt.newtId)) - .run(); + // delete all of the sessions for the newt + await trx + .delete(newtSessions) + .where(eq(newtSessions.newtId, deletedNewt.newtId)); + } } } - } - - db.delete(sites).where(eq(sites.siteId, siteId)).run(); + await trx.delete(sites).where(eq(sites.siteId, siteId)); + }); + return response(res, { data: null, success: true, diff --git a/server/routers/user/acceptInvite.ts b/server/routers/user/acceptInvite.ts index 3c3b720b..c097e5ff 100644 --- a/server/routers/user/acceptInvite.ts +++ b/server/routers/user/acceptInvite.ts @@ -118,16 +118,20 @@ export async function acceptInvite( ); } - // add the user to the org - await db.insert(userOrgs).values({ - userId: existingUser[0].userId, - orgId: existingInvite[0].orgId, - roleId: existingInvite[0].roleId + await db.transaction(async (trx) => { + // add the user to the org + await trx.insert(userOrgs).values({ + userId: existingUser[0].userId, + orgId: existingInvite[0].orgId, + roleId: existingInvite[0].roleId + }); + + // delete the invite + await trx + .delete(userInvites) + .where(eq(userInvites.inviteId, inviteId)); }); - - // delete the invite - await db.delete(userInvites).where(eq(userInvites.inviteId, inviteId)); - + return response(res, { data: { accepted: true, orgId: existingInvite[0].orgId }, success: true, diff --git a/server/routers/user/addUserSite.ts b/server/routers/user/addUserSite.ts index 22a08e09..4e6a2ef7 100644 --- a/server/routers/user/addUserSite.ts +++ b/server/routers/user/addUserSite.ts @@ -34,33 +34,36 @@ export async function addUserSite( const { userId, siteId } = parsedBody.data; - const newUserSite = await db - .insert(userSites) - .values({ - userId, - siteId - }) - .returning(); + await db.transaction(async (trx) => { + const newUserSite = await trx + .insert(userSites) + .values({ + userId, + siteId + }) + .returning(); - const siteResources = await db - .select() - .from(resources) - .where(eq(resources.siteId, siteId)); + const siteResources = await trx + .select() + .from(resources) + .where(eq(resources.siteId, siteId)); - for (const resource of siteResources) { - await db.insert(userResources).values({ - userId, - resourceId: resource.resourceId + for (const resource of siteResources) { + await trx.insert(userResources).values({ + userId, + resourceId: resource.resourceId + }); + } + + return response(res, { + data: newUserSite[0], + success: true, + error: false, + message: "Site added to user successfully", + status: HttpCode.CREATED }); - } - - return response(res, { - data: newUserSite[0], - success: true, - error: false, - message: "Site added to user successfully", - status: HttpCode.CREATED }); + } catch (error) { logger.error(error); return next( diff --git a/server/routers/user/inviteUser.ts b/server/routers/user/inviteUser.ts index 318a08b3..2073142f 100644 --- a/server/routers/user/inviteUser.ts +++ b/server/routers/user/inviteUser.ts @@ -130,21 +130,26 @@ export async function inviteUser( const tokenHash = await hashPassword(token); - // delete any existing invites for this email - await db - .delete(userInvites) - .where( - and(eq(userInvites.email, email), eq(userInvites.orgId, orgId)) - ) - .execute(); + await db.transaction(async (trx) => { + // delete any existing invites for this email + await trx + .delete(userInvites) + .where( + and( + eq(userInvites.email, email), + eq(userInvites.orgId, orgId) + ) + ) + .execute(); - await db.insert(userInvites).values({ - inviteId, - orgId, - email, - expiresAt, - tokenHash, - roleId + await trx.insert(userInvites).values({ + inviteId, + orgId, + email, + expiresAt, + tokenHash, + roleId + }); }); const inviteLink = `${config.app.base_url}/invite?token=${inviteId}-${token}`; diff --git a/server/routers/user/removeUserSite.ts b/server/routers/user/removeUserSite.ts index 153f76b7..6d10bce7 100644 --- a/server/routers/user/removeUserSite.ts +++ b/server/routers/user/removeUserSite.ts @@ -51,38 +51,43 @@ export async function removeUserSite( const { siteId } = parsedBody.data; - const deletedUserSite = await db - .delete(userSites) - .where( - and(eq(userSites.userId, userId), eq(userSites.siteId, siteId)) - ) - .returning(); - - if (deletedUserSite.length === 0) { - return next( - createHttpError( - HttpCode.NOT_FOUND, - `Site with ID ${siteId} not found for user with ID ${userId}` - ) - ); - } - - const siteResources = await db - .select() - .from(resources) - .where(eq(resources.siteId, siteId)); - - for (const resource of siteResources) { - await db - .delete(userResources) + await db.transaction(async (trx) => { + const deletedUserSite = await trx + .delete(userSites) .where( and( - eq(userResources.userId, userId), - eq(userResources.resourceId, resource.resourceId) + eq(userSites.userId, userId), + eq(userSites.siteId, siteId) ) ) .returning(); - } + + if (deletedUserSite.length === 0) { + return next( + createHttpError( + HttpCode.NOT_FOUND, + `Site with ID ${siteId} not found for user with ID ${userId}` + ) + ); + } + + const siteResources = await trx + .select() + .from(resources) + .where(eq(resources.siteId, siteId)); + + for (const resource of siteResources) { + await trx + .delete(userResources) + .where( + and( + eq(userResources.userId, userId), + eq(userResources.resourceId, resource.resourceId) + ) + ) + .returning(); + } + }); return response(res, { data: null, diff --git a/server/setup/ensureActions.ts b/server/setup/ensureActions.ts index c83a0a6b..aa7b40f3 100644 --- a/server/setup/ensureActions.ts +++ b/server/setup/ensureActions.ts @@ -22,13 +22,15 @@ export async function ensureActions() { .where(eq(roles.isAdmin, true)) .execute(); + await db.transaction(async (trx) => { + // Add new actions for (const actionId of actionsToAdd) { logger.debug(`Adding action: ${actionId}`); - await db.insert(actions).values({ actionId }).execute(); + await trx.insert(actions).values({ actionId }).execute(); // Add new actions to the Default role if (defaultRoles.length != 0) { - await db + await trx .insert(roleActions) .values( defaultRoles.map((role) => ({ @@ -44,19 +46,23 @@ export async function ensureActions() { // Remove deprecated actions if (actionsToRemove.length > 0) { logger.debug(`Removing actions: ${actionsToRemove.join(", ")}`); - await db + await trx .delete(actions) .where(inArray(actions.actionId, actionsToRemove)) .execute(); - await db + await trx .delete(roleActions) .where(inArray(roleActions.actionId, actionsToRemove)) .execute(); } +}); } export async function createAdminRole(orgId: string) { - const [insertedRole] = await db + let roleId: any; + await db.transaction(async (trx) => { + + const [insertedRole] = await trx .insert(roles) .values({ orgId, @@ -67,16 +73,20 @@ export async function createAdminRole(orgId: string) { .returning({ roleId: roles.roleId }) .execute(); - const roleId = insertedRole.roleId; + if (!insertedRole || !insertedRole.roleId) { + throw new Error("Failed to create Admin role"); + } - const actionIds = await db.select().from(actions).execute(); + roleId = insertedRole.roleId; + + const actionIds = await trx.select().from(actions).execute(); if (actionIds.length === 0) { logger.info("No actions to assign to the Admin role"); return; } - await db + await trx .insert(roleActions) .values( actionIds.map((action) => ({ @@ -86,6 +96,11 @@ export async function createAdminRole(orgId: string) { })) ) .execute(); + }); + + if (!roleId) { + throw new Error("Failed to create Admin role"); + } return roleId; }