Add basic transactions

This commit is contained in:
Owen Schwartz 2024-12-24 16:00:02 -05:00
parent c8676ce06a
commit 2f328fc719
No known key found for this signature in database
GPG key ID: 8271FDFFD9E0CCBD
22 changed files with 548 additions and 459 deletions

View file

@ -3,11 +3,10 @@ import { resourceOtp } from "@server/db/schema";
import { and, eq } from "drizzle-orm"; import { and, eq } from "drizzle-orm";
import { createDate, isWithinExpirationDate, TimeSpan } from "oslo"; import { createDate, isWithinExpirationDate, TimeSpan } from "oslo";
import { alphabet, generateRandomString, sha256 } from "oslo/crypto"; import { alphabet, generateRandomString, sha256 } from "oslo/crypto";
import { encodeHex } from "oslo/encoding";
import { sendEmail } from "@server/emails"; import { sendEmail } from "@server/emails";
import ResourceOTPCode from "@server/emails/templates/ResourceOTPCode"; import ResourceOTPCode from "@server/emails/templates/ResourceOTPCode";
import config from "@server/config"; import config from "@server/config";
import { hash, verify } from "@node-rs/argon2"; import { verifyPassword } from "./password";
import { hashPassword } from "./password"; import { hashPassword } from "./password";
export async function sendResourceOtpEmail( export async function sendResourceOtpEmail(
@ -37,24 +36,25 @@ export async function generateResourceOtpCode(
resourceId: number, resourceId: number,
email: string email: string
): Promise<string> { ): Promise<string> {
await db
.delete(resourceOtp)
.where(
and(
eq(resourceOtp.email, email),
eq(resourceOtp.resourceId, resourceId)
)
);
const otp = generateRandomString(8, alphabet("0-9", "A-Z", "a-z")); const otp = generateRandomString(8, alphabet("0-9", "A-Z", "a-z"));
await db.transaction(async (trx) => {
await trx
.delete(resourceOtp)
.where(
and(
eq(resourceOtp.email, email),
eq(resourceOtp.resourceId, resourceId)
)
);
const otpHash = await hashPassword(otp); const otpHash = await hashPassword(otp);
await db.insert(resourceOtp).values({ await trx.insert(resourceOtp).values({
resourceId, resourceId,
email, email,
otpHash, otpHash,
expiresAt: createDate(new TimeSpan(15, "m")).getTime() expiresAt: createDate(new TimeSpan(15, "m")).getTime()
});
}); });
return otp; return otp;

View file

@ -31,18 +31,18 @@ async function generateEmailVerificationCode(
userId: string, userId: string,
email: string email: string
): Promise<string> { ): Promise<string> {
await db
.delete(emailVerificationCodes)
.where(eq(emailVerificationCodes.userId, userId));
const code = generateRandomString(8, alphabet("0-9")); const code = generateRandomString(8, alphabet("0-9"));
await db.transaction(async (trx) => {
await trx
.delete(emailVerificationCodes)
.where(eq(emailVerificationCodes.userId, userId));
await db.insert(emailVerificationCodes).values({ await trx.insert(emailVerificationCodes).values({
userId, userId,
email, email,
code, code,
expiresAt: createDate(new TimeSpan(15, "m")).getTime() expiresAt: createDate(new TimeSpan(15, "m")).getTime()
});
}); });
return code; return code;
} }

View file

@ -12,10 +12,12 @@ import { verifyPassword } from "@server/auth/password";
import { verifyTotpCode } from "@server/auth/2fa"; import { verifyTotpCode } from "@server/auth/2fa";
import logger from "@server/logger"; import logger from "@server/logger";
export const disable2faBody = z.object({ export const disable2faBody = z
password: z.string(), .object({
code: z.string().optional(), password: z.string(),
}).strict(); code: z.string().optional()
})
.strict();
export type Disable2faBody = z.infer<typeof disable2faBody>; export type Disable2faBody = z.infer<typeof disable2faBody>;
@ -26,7 +28,7 @@ export type Disable2faResponse = {
export async function disable2fa( export async function disable2fa(
req: Request, req: Request,
res: Response, res: Response,
next: NextFunction, next: NextFunction
): Promise<any> { ): Promise<any> {
const parsedBody = disable2faBody.safeParse(req.body); const parsedBody = disable2faBody.safeParse(req.body);
@ -34,8 +36,8 @@ export async function disable2fa(
return next( return next(
createHttpError( createHttpError(
HttpCode.BAD_REQUEST, HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString(), fromError(parsedBody.error).toString()
), )
); );
} }
@ -52,8 +54,8 @@ export async function disable2fa(
return next( return next(
createHttpError( createHttpError(
HttpCode.BAD_REQUEST, HttpCode.BAD_REQUEST,
"Two-factor authentication is already disabled", "Two-factor authentication is already disabled"
), )
); );
} else { } else {
if (!code) { if (!code) {
@ -62,7 +64,7 @@ export async function disable2fa(
success: true, success: true,
error: false, error: false,
message: "Two-factor authentication required", message: "Two-factor authentication required",
status: HttpCode.ACCEPTED, status: HttpCode.ACCEPTED
}); });
} }
} }
@ -70,27 +72,28 @@ export async function disable2fa(
const validOTP = await verifyTotpCode( const validOTP = await verifyTotpCode(
code, code,
user.twoFactorSecret!, user.twoFactorSecret!,
user.userId, user.userId
); );
if (!validOTP) { if (!validOTP) {
return next( return next(
createHttpError( createHttpError(
HttpCode.BAD_REQUEST, HttpCode.BAD_REQUEST,
"The two-factor code you entered is incorrect", "The two-factor code you entered is incorrect"
), )
); );
} }
await db await db.transaction(async (trx) => {
.update(users) await trx
.set({ twoFactorEnabled: false }) .update(users)
.where(eq(users.userId, user.userId)); .set({ twoFactorEnabled: false })
.where(eq(users.userId, user.userId));
await db
.delete(twoFactorBackupCodes)
.where(eq(twoFactorBackupCodes.userId, user.userId));
await trx
.delete(twoFactorBackupCodes)
.where(eq(twoFactorBackupCodes.userId, user.userId));
});
// TODO: send email to user confirming two-factor authentication is disabled // TODO: send email to user confirming two-factor authentication is disabled
return response<null>(res, { return response<null>(res, {
@ -98,15 +101,15 @@ export async function disable2fa(
success: true, success: true,
error: false, error: false,
message: "Two-factor authentication disabled", message: "Two-factor authentication disabled",
status: HttpCode.OK, status: HttpCode.OK
}); });
} catch (error) { } catch (error) {
logger.error(error); logger.error(error);
return next( return next(
createHttpError( createHttpError(
HttpCode.INTERNAL_SERVER_ERROR, HttpCode.INTERNAL_SERVER_ERROR,
"Failed to disable two-factor authentication", "Failed to disable two-factor authentication"
), )
); );
} }
} }

View file

@ -63,18 +63,23 @@ export async function requestPasswordReset(
); );
} }
await db const token = generateRandomString(
.delete(passwordResetTokens) 8,
.where(eq(passwordResetTokens.userId, existingUser[0].userId)); alphabet("0-9", "A-Z", "a-z")
);
await db.transaction(async (trx) => {
await trx
.delete(passwordResetTokens)
.where(eq(passwordResetTokens.userId, existingUser[0].userId));
const token = generateRandomString(8, alphabet("0-9", "A-Z", "a-z")); const tokenHash = await hashPassword(token);
const tokenHash = await hashPassword(token);
await db.insert(passwordResetTokens).values({ await trx.insert(passwordResetTokens).values({
userId: existingUser[0].userId, userId: existingUser[0].userId,
email: existingUser[0].email, email: existingUser[0].email,
tokenHash, tokenHash,
expiresAt: createDate(new TimeSpan(2, "h")).getTime() expiresAt: createDate(new TimeSpan(2, "h")).getTime()
});
}); });
const url = `${config.app.base_url}/auth/reset-password?email=${email}&token=${token}`; const url = `${config.app.base_url}/auth/reset-password?email=${email}&token=${token}`;

View file

@ -135,20 +135,22 @@ export async function resetPassword(
await invalidateAllSessions(resetRequest[0].userId); await invalidateAllSessions(resetRequest[0].userId);
await db await db.transaction(async (trx) => {
.update(users) await trx
.set({ passwordHash }) .update(users)
.where(eq(users.userId, resetRequest[0].userId)); .set({ passwordHash })
.where(eq(users.userId, resetRequest[0].userId));
await db await trx
.delete(passwordResetTokens) .delete(passwordResetTokens)
.where(eq(passwordResetTokens.email, email)); .where(eq(passwordResetTokens.email, email));
});
await sendEmail(ConfirmPasswordReset({ email }), { await sendEmail(ConfirmPasswordReset({ email }), {
from: config.email?.no_reply, from: config.email?.no_reply,
to: email, to: email,
subject: "Password Reset Confirmation" subject: "Password Reset Confirmation"
}) });
return response<ResetPasswordResponse>(res, { return response<ResetPasswordResponse>(res, {
data: null, data: null,

View file

@ -62,16 +62,18 @@ export async function verifyEmail(
const valid = await isValidCode(user, code); const valid = await isValidCode(user, code);
if (valid) { if (valid) {
await db await db.transaction(async (trx) => {
.delete(emailVerificationCodes) await trx
.where(eq(emailVerificationCodes.userId, user.userId)); .delete(emailVerificationCodes)
.where(eq(emailVerificationCodes.userId, user.userId));
await db await trx
.update(users) .update(users)
.set({ .set({
emailVerified: true emailVerified: true
}) })
.where(eq(users.userId, user.userId)); .where(eq(users.userId, user.userId));
});
} else { } else {
return next( return next(
createHttpError( createHttpError(

View file

@ -73,21 +73,23 @@ export async function verifyTotp(
let codes; let codes;
if (valid) { if (valid) {
// if valid, enable two-factor authentication; the totp secret is no longer temporary // if valid, enable two-factor authentication; the totp secret is no longer temporary
await db await db.transaction(async (trx) => {
.update(users) await trx
.set({ twoFactorEnabled: true }) .update(users)
.where(eq(users.userId, user.userId)); .set({ twoFactorEnabled: true })
.where(eq(users.userId, user.userId));
const backupCodes = await generateBackupCodes(); const backupCodes = await generateBackupCodes();
codes = backupCodes; codes = backupCodes;
for (const code of backupCodes) { for (const code of backupCodes) {
const hash = await hashPassword(code); const hash = await hashPassword(code);
await db.insert(twoFactorBackupCodes).values({ await trx.insert(twoFactorBackupCodes).values({
userId: user.userId, userId: user.userId,
codeHash: hash codeHash: hash
}); });
} }
});
} }
// TODO: send email to user confirming two-factor authentication is enabled // TODO: send email to user confirming two-factor authentication is enabled

View file

@ -1,10 +1,10 @@
import { Request, Response, NextFunction } from 'express'; import { Request, Response, NextFunction } from "express";
import { DrizzleError, eq } from 'drizzle-orm'; import { DrizzleError, eq } from "drizzle-orm";
import { sites, resources, targets, exitNodes } from '@server/db/schema'; import { sites, resources, targets, exitNodes } from "@server/db/schema";
import db from '@server/db'; import db from "@server/db";
import logger from '@server/logger'; import logger from "@server/logger";
import createHttpError from 'http-errors'; import createHttpError from "http-errors";
import HttpCode from '@server/types/HttpCode'; import HttpCode from "@server/types/HttpCode";
import response from "@server/utils/response"; import response from "@server/utils/response";
interface PeerBandwidth { interface PeerBandwidth {
@ -13,62 +13,76 @@ interface PeerBandwidth {
bytesOut: number; bytesOut: number;
} }
export const receiveBandwidth = async (req: Request, res: Response, next: NextFunction): Promise<any> => { export const receiveBandwidth = async (
req: Request,
res: Response,
next: NextFunction
): Promise<any> => {
try { try {
const bandwidthData: PeerBandwidth[] = req.body; const bandwidthData: PeerBandwidth[] = req.body;
if (!Array.isArray(bandwidthData)) { if (!Array.isArray(bandwidthData)) {
throw new Error('Invalid bandwidth data'); throw new Error("Invalid bandwidth data");
} }
for (const peer of bandwidthData) { await db.transaction(async (trx) => {
const { publicKey, bytesIn, bytesOut } = peer; for (const peer of bandwidthData) {
const { publicKey, bytesIn, bytesOut } = peer;
// Find the site by public key // Find the site by public key
const site = await db.query.sites.findFirst({ const site = await trx.query.sites.findFirst({
where: eq(sites.pubKey, publicKey), where: eq(sites.pubKey, publicKey)
}); });
if (!site) { if (!site) {
logger.warn(`Site not found for public key: ${publicKey}`); logger.warn(`Site not found for public key: ${publicKey}`);
continue; continue;
}
let online = site.online;
// if the bandwidth for the site is > 0 then set it to online. if it has been less than 0 (no update) for 5 minutes then set it to offline
if (bytesIn > 0 || bytesOut > 0) {
online = true;
} else if (site.lastBandwidthUpdate) {
const lastBandwidthUpdate = new Date(site.lastBandwidthUpdate);
const currentTime = new Date();
const diff = currentTime.getTime() - lastBandwidthUpdate.getTime();
if (diff < 300000) {
online = false;
} }
let online = site.online;
// if the bandwidth for the site is > 0 then set it to online. if it has been less than 0 (no update) for 5 minutes then set it to offline
if (bytesIn > 0 || bytesOut > 0) {
online = true;
} else if (site.lastBandwidthUpdate) {
const lastBandwidthUpdate = new Date(
site.lastBandwidthUpdate
);
const currentTime = new Date();
const diff =
currentTime.getTime() - lastBandwidthUpdate.getTime();
if (diff < 300000) {
online = false;
}
}
// Update the site's bandwidth usage
await trx
.update(sites)
.set({
megabytesIn: (site.megabytesIn || 0) + bytesIn,
megabytesOut: (site.megabytesOut || 0) + bytesOut,
lastBandwidthUpdate: new Date().toISOString(),
online
})
.where(eq(sites.siteId, site.siteId));
} }
});
// Update the site's bandwidth usage
await db.update(sites)
.set({
megabytesIn: (site.megabytesIn || 0) + bytesIn,
megabytesOut: (site.megabytesOut || 0) + bytesOut,
lastBandwidthUpdate: new Date().toISOString(),
online,
})
.where(eq(sites.siteId, site.siteId));
}
return response(res, { return response(res, {
data: {}, data: {},
success: true, success: true,
error: false, error: false,
message: "Organization retrieved successfully", message: "Organization retrieved successfully",
status: HttpCode.OK, status: HttpCode.OK
}); });
} catch (error) { } catch (error) {
logger.error('Error updating bandwidth data:', error); logger.error("Error updating bandwidth data:", error);
return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred...")); return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"An error occurred..."
)
);
} }
}; };

View file

@ -39,15 +39,14 @@ export const handleRegisterMessage: MessageHandler = async (context) => {
return; return;
} }
const [updatedSite] = await db await db
.update(sites) .update(sites)
.set({ .set({
pubKey: publicKey pubKey: publicKey
}) })
.where(eq(sites.siteId, siteId)) .where(eq(sites.siteId, siteId))
.returning(); .returning();
const [exitNode] = await db const [exitNode] = await db
.select() .select()
.from(exitNodes) .from(exitNodes)
@ -67,35 +66,41 @@ export const handleRegisterMessage: MessageHandler = async (context) => {
// add the peer to the exit node // add the peer to the exit node
await addPeer(site.exitNodeId, { await addPeer(site.exitNodeId, {
publicKey: publicKey, publicKey: publicKey,
allowedIps: [site.subnet], allowedIps: [site.subnet]
}); });
const siteResources = await db.select().from(resources).where(eq(resources.siteId, siteId)); const siteResources = await db
.select()
.from(resources)
.where(eq(resources.siteId, siteId));
// get the targets from the resourceIds // get the targets from the resourceIds
const siteTargets = await db const siteTargets = await db
.select() .select()
.from(targets) .from(targets)
.where( .where(
inArray( inArray(
targets.resourceId, targets.resourceId,
siteResources.map(resource => resource.resourceId) siteResources.map((resource) => resource.resourceId)
) )
); );
const udpTargets = siteTargets const udpTargets = siteTargets
.filter((target) => target.protocol === "udp") .filter((target) => target.protocol === "udp")
.map((target) => { .map((target) => {
return `${target.internalPort ? target.internalPort + ":" : ""}${target.ip}:${target.port}`; return `${target.internalPort ? target.internalPort + ":" : ""}${
target.ip
}:${target.port}`;
}); });
const tcpTargets = siteTargets const tcpTargets = siteTargets
.filter((target) => target.protocol === "tcp") .filter((target) => target.protocol === "tcp")
.map((target) => { .map((target) => {
return `${target.internalPort ? target.internalPort + ":" : ""}${target.ip}:${target.port}`; return `${target.internalPort ? target.internalPort + ":" : ""}${
target.ip
}:${target.port}`;
}); });
return { return {
message: { message: {
type: "newt/wg/connect", type: "newt/wg/connect",
@ -106,11 +111,11 @@ export const handleRegisterMessage: MessageHandler = async (context) => {
tunnelIP: site.subnet.split("/")[0], tunnelIP: site.subnet.split("/")[0],
targets: { targets: {
udp: udpTargets, udp: udpTargets,
tcp: tcpTargets, tcp: tcpTargets
} }
}, }
}, },
broadcast: false, // Send to all clients broadcast: false, // Send to all clients
excludeSender: false, // Include sender in broadcast excludeSender: false // Include sender in broadcast
}; };
}; };

View file

@ -24,9 +24,7 @@ const deleteOrgSchema = z
}) })
.strict(); .strict();
export type DeleteOrgResponse = { export type DeleteOrgResponse = {};
}
export async function deleteOrg( export async function deleteOrg(
req: Request, req: Request,
@ -79,39 +77,47 @@ export async function deleteOrg(
.where(eq(sites.orgId, orgId)) .where(eq(sites.orgId, orgId))
.limit(1); .limit(1);
if (sites) { await db.transaction(async (trx) => {
for (const site of orgSites) { if (sites) {
if (site.pubKey) { for (const site of orgSites) {
if (site.type == "wireguard") { if (site.pubKey) {
await deletePeer(site.exitNodeId!, site.pubKey); if (site.type == "wireguard") {
} else if (site.type == "newt") { await deletePeer(site.exitNodeId!, site.pubKey);
// get the newt on the site by querying the newt table for siteId } else if (site.type == "newt") {
const [deletedNewt] = await db // get the newt on the site by querying the newt table for siteId
.delete(newts) const [deletedNewt] = await trx
.where(eq(newts.siteId, site.siteId)) .delete(newts)
.returning(); .where(eq(newts.siteId, site.siteId))
if (deletedNewt) { .returning();
const payload = { if (deletedNewt) {
type: `newt/terminate`, const payload = {
data: {} type: `newt/terminate`,
}; data: {}
sendToClient(deletedNewt.newtId, payload); };
sendToClient(deletedNewt.newtId, payload);
// delete all of the sessions for the newt // delete all of the sessions for the newt
await db.delete(newtSessions) await trx
.where( .delete(newtSessions)
eq(newtSessions.newtId, deletedNewt.newtId) .where(
); eq(
newtSessions.newtId,
deletedNewt.newtId
)
);
}
} }
} }
logger.info(`Deleting site ${site.siteId}`);
await trx
.delete(sites)
.where(eq(sites.siteId, site.siteId));
} }
logger.info(`Deleting site ${site.siteId}`);
await db.delete(sites).where(eq(sites.siteId, site.siteId))
} }
}
await db.delete(orgs).where(eq(orgs.orgId, orgId)); await trx.delete(orgs).where(eq(orgs.orgId, orgId));
});
return response(res, { return response(res, {
data: null, data: null,

View file

@ -89,50 +89,50 @@ export async function createResource(
} }
const fullDomain = `${subdomain}.${org[0].domain}`; const fullDomain = `${subdomain}.${org[0].domain}`;
await db.transaction(async (trx) => {
const newResource = await trx
.insert(resources)
.values({
siteId,
fullDomain,
orgId,
name,
subdomain,
ssl: true
})
.returning();
const newResource = await db const adminRole = await db
.insert(resources) .select()
.values({ .from(roles)
siteId, .where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
fullDomain, .limit(1);
orgId,
name,
subdomain,
ssl: true
})
.returning();
const adminRole = await db if (adminRole.length === 0) {
.select() return next(
.from(roles) createHttpError(HttpCode.NOT_FOUND, `Admin role not found`)
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId))) );
.limit(1); }
if (adminRole.length === 0) { await trx.insert(roleResources).values({
return next( roleId: adminRole[0].roleId,
createHttpError(HttpCode.NOT_FOUND, `Admin role not found`)
);
}
await db.insert(roleResources).values({
roleId: adminRole[0].roleId,
resourceId: newResource[0].resourceId
});
if (req.userOrgRoleId != adminRole[0].roleId) {
// make sure the user can access the resource
await db.insert(userResources).values({
userId: req.user?.userId!,
resourceId: newResource[0].resourceId resourceId: newResource[0].resourceId
}); });
}
response<CreateResourceResponse>(res, { if (req.userOrgRoleId != adminRole[0].roleId) {
data: newResource[0], // make sure the user can access the resource
success: true, await trx.insert(userResources).values({
error: false, userId: req.user?.userId!,
message: "Resource created successfully", resourceId: newResource[0].resourceId
status: HttpCode.CREATED });
}
response<CreateResourceResponse>(res, {
data: newResource[0],
success: true,
error: false,
message: "Resource created successfully",
status: HttpCode.CREATED
});
}); });
} catch (error) { } catch (error) {
if ( if (

View file

@ -51,32 +51,34 @@ export async function addRoleSite(
const { roleId } = parsedParams.data; const { roleId } = parsedParams.data;
const newRoleSite = await db await db.transaction(async (trx) => {
.insert(roleSites) const newRoleSite = await trx
.values({ .insert(roleSites)
roleId, .values({
siteId roleId,
}) siteId
.returning(); })
.returning();
const siteResources = await db const siteResources = await db
.select() .select()
.from(resources) .from(resources)
.where(eq(resources.siteId, siteId)); .where(eq(resources.siteId, siteId));
for (const resource of siteResources) { for (const resource of siteResources) {
await db.insert(roleResources).values({ await trx.insert(roleResources).values({
roleId, roleId,
resourceId: resource.resourceId resourceId: resource.resourceId
});
}
return response(res, {
data: newRoleSite[0],
success: true,
error: false,
message: "Site added to role successfully",
status: HttpCode.CREATED
}); });
}
return response(res, {
data: newRoleSite[0],
success: true,
error: false,
message: "Site added to role successfully",
status: HttpCode.CREATED
}); });
} catch (error) { } catch (error) {
logger.error(error); logger.error(error);

View file

@ -82,31 +82,33 @@ export async function createRole(
); );
} }
const newRole = await db await db.transaction(async (trx) => {
.insert(roles) const newRole = await trx
.values({ .insert(roles)
...roleData, .values({
orgId ...roleData,
})
.returning();
await db
.insert(roleActions)
.values(
defaultRoleAllowedActions.map((action) => ({
roleId: newRole[0].roleId,
actionId: action,
orgId orgId
})) })
) .returning();
.execute();
return response<Role>(res, { await trx
data: newRole[0], .insert(roleActions)
success: true, .values(
error: false, defaultRoleAllowedActions.map((action) => ({
message: "Role created successfully", roleId: newRole[0].roleId,
status: HttpCode.CREATED actionId: action,
orgId
}))
)
.execute();
return response<Role>(res, {
data: newRole[0],
success: true,
error: false,
message: "Role created successfully",
status: HttpCode.CREATED
});
}); });
} catch (error) { } catch (error) {
logger.error(error); logger.error(error);

View file

@ -98,14 +98,16 @@ export async function deleteRole(
); );
} }
// move all users from the userOrgs table with roleId to newRoleId await db.transaction(async (trx) => {
await db // move all users from the userOrgs table with roleId to newRoleId
.update(userOrgs) await trx
.set({ roleId: newRoleId }) .update(userOrgs)
.where(eq(userOrgs.roleId, roleId)); .set({ roleId: newRoleId })
.where(eq(userOrgs.roleId, roleId));
// delete the old role // delete the old role
await db.delete(roles).where(eq(roles.roleId, roleId)); await trx.delete(roles).where(eq(roles.roleId, roleId));
});
return response(res, { return response(res, {
data: null, data: null,

View file

@ -51,38 +51,43 @@ export async function removeRoleSite(
const { roleId } = parsedBody.data; const { roleId } = parsedBody.data;
const deletedRoleSite = await db await db.transaction(async (trx) => {
.delete(roleSites) const deletedRoleSite = await trx
.where( .delete(roleSites)
and(eq(roleSites.roleId, roleId), eq(roleSites.siteId, siteId))
)
.returning();
if (deletedRoleSite.length === 0) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Site with ID ${siteId} not found for role with ID ${roleId}`
)
);
}
const siteResources = await db
.select()
.from(resources)
.where(eq(resources.siteId, siteId));
for (const resource of siteResources) {
await db
.delete(roleResources)
.where( .where(
and( and(
eq(roleResources.roleId, roleId), eq(roleSites.roleId, roleId),
eq(roleResources.resourceId, resource.resourceId) eq(roleSites.siteId, siteId)
) )
) )
.returning(); .returning();
}
if (deletedRoleSite.length === 0) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Site with ID ${siteId} not found for role with ID ${roleId}`
)
);
}
const siteResources = await db
.select()
.from(resources)
.where(eq(resources.siteId, siteId));
for (const resource of siteResources) {
await trx
.delete(roleResources)
.where(
and(
eq(roleResources.roleId, roleId),
eq(roleResources.resourceId, resource.resourceId)
)
)
.returning();
}
});
return response(res, { return response(res, {
data: null, data: null,

View file

@ -94,64 +94,69 @@ export async function createSite(
}; };
} }
const [newSite] = await db.insert(sites).values(payload).returning(); await db.transaction(async (trx) => {
const [newSite] = await trx
.insert(sites)
.values(payload)
.returning();
const adminRole = await db const adminRole = await trx
.select() .select()
.from(roles) .from(roles)
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId))) .where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
.limit(1); .limit(1);
if (adminRole.length === 0) { if (adminRole.length === 0) {
return next(
createHttpError(HttpCode.NOT_FOUND, `Admin role not found`)
);
}
await db.insert(roleSites).values({
roleId: adminRole[0].roleId,
siteId: newSite.siteId
});
if (req.userOrgRoleId != adminRole[0].roleId) {
// make sure the user can access the site
db.insert(userSites).values({
userId: req.user?.userId!,
siteId: newSite.siteId
});
}
// add the peer to the exit node
if (type == "newt") {
const secretHash = await hashPassword(secret!);
await db.insert(newts).values({
newtId: newtId!,
secretHash,
siteId: newSite.siteId,
dateCreated: moment().toISOString()
});
} else if (type == "wireguard") {
if (!pubKey) {
return next( return next(
createHttpError( createHttpError(HttpCode.NOT_FOUND, `Admin role not found`)
HttpCode.BAD_REQUEST,
"Public key is required for wireguard sites"
)
); );
} }
await addPeer(exitNodeId, {
publicKey: pubKey,
allowedIps: []
});
}
return response<CreateSiteResponse>(res, { await trx.insert(roleSites).values({
data: newSite, roleId: adminRole[0].roleId,
success: true, siteId: newSite.siteId
error: false, });
message: "Site created successfully",
status: HttpCode.CREATED if (req.userOrgRoleId != adminRole[0].roleId) {
// make sure the user can access the site
trx.insert(userSites).values({
userId: req.user?.userId!,
siteId: newSite.siteId
});
}
// add the peer to the exit node
if (type == "newt") {
const secretHash = await hashPassword(secret!);
await trx.insert(newts).values({
newtId: newtId!,
secretHash,
siteId: newSite.siteId,
dateCreated: moment().toISOString()
});
} else if (type == "wireguard") {
if (!pubKey) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Public key is required for wireguard sites"
)
);
}
await addPeer(exitNodeId, {
publicKey: pubKey,
allowedIps: []
});
}
return response<CreateSiteResponse>(res, {
data: newSite,
success: true,
error: false,
message: "Site created successfully",
status: HttpCode.CREATED
});
}); });
} catch (error) { } catch (error) {
logger.error(error); logger.error(error);

View file

@ -50,31 +50,33 @@ export async function deleteSite(
); );
} }
if (site.pubKey) { await db.transaction(async (trx) => {
if (site.type == "wireguard") { if (site.pubKey) {
await deletePeer(site.exitNodeId!, site.pubKey); if (site.type == "wireguard") {
} else if (site.type == "newt") { await deletePeer(site.exitNodeId!, site.pubKey);
// get the newt on the site by querying the newt table for siteId } else if (site.type == "newt") {
const [deletedNewt] = await db // get the newt on the site by querying the newt table for siteId
.delete(newts) const [deletedNewt] = await trx
.where(eq(newts.siteId, siteId)) .delete(newts)
.returning(); .where(eq(newts.siteId, siteId))
if (deletedNewt) { .returning();
const payload = { if (deletedNewt) {
type: `newt/terminate`, const payload = {
data: {} type: `newt/terminate`,
}; data: {}
sendToClient(deletedNewt.newtId, payload); };
sendToClient(deletedNewt.newtId, payload);
// delete all of the sessions for the newt // delete all of the sessions for the newt
db.delete(newtSessions) await trx
.where(eq(newtSessions.newtId, deletedNewt.newtId)) .delete(newtSessions)
.run(); .where(eq(newtSessions.newtId, deletedNewt.newtId));
}
} }
} }
}
db.delete(sites).where(eq(sites.siteId, siteId)).run(); await trx.delete(sites).where(eq(sites.siteId, siteId));
});
return response(res, { return response(res, {
data: null, data: null,

View file

@ -118,15 +118,19 @@ export async function acceptInvite(
); );
} }
// add the user to the org await db.transaction(async (trx) => {
await db.insert(userOrgs).values({ // add the user to the org
userId: existingUser[0].userId, await trx.insert(userOrgs).values({
orgId: existingInvite[0].orgId, userId: existingUser[0].userId,
roleId: existingInvite[0].roleId orgId: existingInvite[0].orgId,
}); roleId: existingInvite[0].roleId
});
// delete the invite // delete the invite
await db.delete(userInvites).where(eq(userInvites.inviteId, inviteId)); await trx
.delete(userInvites)
.where(eq(userInvites.inviteId, inviteId));
});
return response<AcceptInviteResponse>(res, { return response<AcceptInviteResponse>(res, {
data: { accepted: true, orgId: existingInvite[0].orgId }, data: { accepted: true, orgId: existingInvite[0].orgId },

View file

@ -34,33 +34,36 @@ export async function addUserSite(
const { userId, siteId } = parsedBody.data; const { userId, siteId } = parsedBody.data;
const newUserSite = await db await db.transaction(async (trx) => {
.insert(userSites) const newUserSite = await trx
.values({ .insert(userSites)
userId, .values({
siteId userId,
}) siteId
.returning(); })
.returning();
const siteResources = await db const siteResources = await trx
.select() .select()
.from(resources) .from(resources)
.where(eq(resources.siteId, siteId)); .where(eq(resources.siteId, siteId));
for (const resource of siteResources) { for (const resource of siteResources) {
await db.insert(userResources).values({ await trx.insert(userResources).values({
userId, userId,
resourceId: resource.resourceId resourceId: resource.resourceId
});
}
return response(res, {
data: newUserSite[0],
success: true,
error: false,
message: "Site added to user successfully",
status: HttpCode.CREATED
}); });
}
return response(res, {
data: newUserSite[0],
success: true,
error: false,
message: "Site added to user successfully",
status: HttpCode.CREATED
}); });
} catch (error) { } catch (error) {
logger.error(error); logger.error(error);
return next( return next(

View file

@ -130,21 +130,26 @@ export async function inviteUser(
const tokenHash = await hashPassword(token); const tokenHash = await hashPassword(token);
// delete any existing invites for this email await db.transaction(async (trx) => {
await db // delete any existing invites for this email
.delete(userInvites) await trx
.where( .delete(userInvites)
and(eq(userInvites.email, email), eq(userInvites.orgId, orgId)) .where(
) and(
.execute(); eq(userInvites.email, email),
eq(userInvites.orgId, orgId)
)
)
.execute();
await db.insert(userInvites).values({ await trx.insert(userInvites).values({
inviteId, inviteId,
orgId, orgId,
email, email,
expiresAt, expiresAt,
tokenHash, tokenHash,
roleId roleId
});
}); });
const inviteLink = `${config.app.base_url}/invite?token=${inviteId}-${token}`; const inviteLink = `${config.app.base_url}/invite?token=${inviteId}-${token}`;

View file

@ -51,38 +51,43 @@ export async function removeUserSite(
const { siteId } = parsedBody.data; const { siteId } = parsedBody.data;
const deletedUserSite = await db await db.transaction(async (trx) => {
.delete(userSites) const deletedUserSite = await trx
.where( .delete(userSites)
and(eq(userSites.userId, userId), eq(userSites.siteId, siteId))
)
.returning();
if (deletedUserSite.length === 0) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Site with ID ${siteId} not found for user with ID ${userId}`
)
);
}
const siteResources = await db
.select()
.from(resources)
.where(eq(resources.siteId, siteId));
for (const resource of siteResources) {
await db
.delete(userResources)
.where( .where(
and( and(
eq(userResources.userId, userId), eq(userSites.userId, userId),
eq(userResources.resourceId, resource.resourceId) eq(userSites.siteId, siteId)
) )
) )
.returning(); .returning();
}
if (deletedUserSite.length === 0) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Site with ID ${siteId} not found for user with ID ${userId}`
)
);
}
const siteResources = await trx
.select()
.from(resources)
.where(eq(resources.siteId, siteId));
for (const resource of siteResources) {
await trx
.delete(userResources)
.where(
and(
eq(userResources.userId, userId),
eq(userResources.resourceId, resource.resourceId)
)
)
.returning();
}
});
return response(res, { return response(res, {
data: null, data: null,

View file

@ -22,13 +22,15 @@ export async function ensureActions() {
.where(eq(roles.isAdmin, true)) .where(eq(roles.isAdmin, true))
.execute(); .execute();
await db.transaction(async (trx) => {
// Add new actions // Add new actions
for (const actionId of actionsToAdd) { for (const actionId of actionsToAdd) {
logger.debug(`Adding action: ${actionId}`); logger.debug(`Adding action: ${actionId}`);
await db.insert(actions).values({ actionId }).execute(); await trx.insert(actions).values({ actionId }).execute();
// Add new actions to the Default role // Add new actions to the Default role
if (defaultRoles.length != 0) { if (defaultRoles.length != 0) {
await db await trx
.insert(roleActions) .insert(roleActions)
.values( .values(
defaultRoles.map((role) => ({ defaultRoles.map((role) => ({
@ -44,19 +46,23 @@ export async function ensureActions() {
// Remove deprecated actions // Remove deprecated actions
if (actionsToRemove.length > 0) { if (actionsToRemove.length > 0) {
logger.debug(`Removing actions: ${actionsToRemove.join(", ")}`); logger.debug(`Removing actions: ${actionsToRemove.join(", ")}`);
await db await trx
.delete(actions) .delete(actions)
.where(inArray(actions.actionId, actionsToRemove)) .where(inArray(actions.actionId, actionsToRemove))
.execute(); .execute();
await db await trx
.delete(roleActions) .delete(roleActions)
.where(inArray(roleActions.actionId, actionsToRemove)) .where(inArray(roleActions.actionId, actionsToRemove))
.execute(); .execute();
} }
});
} }
export async function createAdminRole(orgId: string) { export async function createAdminRole(orgId: string) {
const [insertedRole] = await db let roleId: any;
await db.transaction(async (trx) => {
const [insertedRole] = await trx
.insert(roles) .insert(roles)
.values({ .values({
orgId, orgId,
@ -67,16 +73,20 @@ export async function createAdminRole(orgId: string) {
.returning({ roleId: roles.roleId }) .returning({ roleId: roles.roleId })
.execute(); .execute();
const roleId = insertedRole.roleId; if (!insertedRole || !insertedRole.roleId) {
throw new Error("Failed to create Admin role");
}
const actionIds = await db.select().from(actions).execute(); roleId = insertedRole.roleId;
const actionIds = await trx.select().from(actions).execute();
if (actionIds.length === 0) { if (actionIds.length === 0) {
logger.info("No actions to assign to the Admin role"); logger.info("No actions to assign to the Admin role");
return; return;
} }
await db await trx
.insert(roleActions) .insert(roleActions)
.values( .values(
actionIds.map((action) => ({ actionIds.map((action) => ({
@ -86,6 +96,11 @@ export async function createAdminRole(orgId: string) {
})) }))
) )
.execute(); .execute();
});
if (!roleId) {
throw new Error("Failed to create Admin role");
}
return roleId; return roleId;
} }