From 20db6d450c3d833d23b910d63c39f2238082ef42 Mon Sep 17 00:00:00 2001 From: Owen Schwartz Date: Sun, 6 Oct 2024 16:19:04 -0400 Subject: [PATCH] Update to verify middleware & lists agenst new permissions tables --- server/auth/actions.ts | 53 ++++++++++++ server/db/schema.ts | 28 +++--- server/index.ts | 4 +- server/routers/auth/getUserOrgs.ts | 2 +- server/routers/auth/verifyOrgAccess.ts | 3 +- server/routers/auth/verifyResourceAccess.ts | 94 ++++++++++++-------- server/routers/auth/verifySiteAccess.ts | 95 +++++++++++++-------- server/routers/auth/verifyTargetAccess.ts | 3 +- server/routers/org/listOrgs.ts | 2 +- server/routers/resource/listResources.ts | 65 +++++++++----- server/routers/site/listSites.ts | 52 +++++++---- server/types/Auth.ts | 2 +- 12 files changed, 275 insertions(+), 128 deletions(-) create mode 100644 server/auth/actions.ts diff --git a/server/auth/actions.ts b/server/auth/actions.ts new file mode 100644 index 00000000..8fa8f441 --- /dev/null +++ b/server/auth/actions.ts @@ -0,0 +1,53 @@ +import { Request } from 'express'; +import { db } from '@server/db'; +import { userActions, roleActions, userOrgs } from '@server/db/schema'; +import { and, eq, or } from 'drizzle-orm'; +import createHttpError from 'http-errors'; +import HttpCode from '@server/types/HttpCode'; + +export async function checkUserActionPermission(actionId: number, req: Request): Promise { + const userId = req.user?.id; + + if (!userId) { + throw createHttpError(HttpCode.UNAUTHORIZED, 'User not authenticated'); + } + + try { + // Check if the user has direct permission for the action + const userActionPermission = await db.select() + .from(userActions) + .where(and(eq(userActions.userId, userId), eq(userActions.actionId, actionId))) + .limit(1); + + if (userActionPermission.length > 0) { + return true; + } + + // If no direct permission, check role-based permission + const userOrgRoles = await db.select() + .from(userOrgs) + .where(eq(userOrgs.userId, userId)); + + if (userOrgRoles.length === 0) { + return false; // User doesn't belong to any organization + } + + const roleIds = userOrgRoles.map(role => role.roleId); + + const roleActionPermission = await db.select() + .from(roleActions) + .where( + and( + eq(roleActions.actionId, actionId), + or(...roleIds.map(roleId => eq(roleActions.roleId, roleId))) + ) + ) + .limit(1); + + return roleActionPermission.length > 0; + + } catch (error) { + console.error('Error checking user action permission:', error); + throw createHttpError(HttpCode.INTERNAL_SERVER_ERROR, 'Error checking action permission'); + } +} \ No newline at end of file diff --git a/server/db/schema.ts b/server/db/schema.ts index 896b6d84..46b264b1 100644 --- a/server/db/schema.ts +++ b/server/db/schema.ts @@ -107,7 +107,7 @@ export const userOrgs = sqliteTable("userOrgs", { orgId: integer("orgId") .notNull() .references(() => orgs.orgId), - role: text("role").notNull(), // e.g., 'admin', 'member', etc. + roleId: integer("roleId").notNull().references(() => roles.roleId), }); export const emailVerificationCodes = sqliteTable("emailVerificationCodes", { @@ -149,6 +149,9 @@ export const roleActions = sqliteTable("roleActions", { actionId: integer("actionId") .notNull() .references(() => actions.actionId, { onDelete: "cascade" }), + orgId: integer("orgId") + .notNull() + .references(() => orgs.orgId, { onDelete: "cascade" }), }); export const userActions = sqliteTable("userActions", { @@ -158,10 +161,13 @@ export const userActions = sqliteTable("userActions", { actionId: integer("actionId") .notNull() .references(() => actions.actionId, { onDelete: "cascade" }), + orgId: integer("orgId") + .notNull() + .references(() => orgs.orgId, { onDelete: "cascade" }), }); -export const roleSites = sqliteTable("roleActions", { - roleId: integer("role]Id") +export const roleSites = sqliteTable("roleSites", { + roleId: integer("roleId") .notNull() .references(() => roles.roleId, { onDelete: "cascade" }), siteId: integer("siteId") @@ -169,8 +175,8 @@ export const roleSites = sqliteTable("roleActions", { .references(() => sites.siteId, { onDelete: "cascade" }), }); -export const userSites = sqliteTable("userActions", { - userId: text("user]Id") +export const userSites = sqliteTable("userSites", { + userId: text("userId") .notNull() .references(() => users.id, { onDelete: "cascade" }), siteId: integer("siteId") @@ -178,20 +184,20 @@ export const userSites = sqliteTable("userActions", { .references(() => sites.siteId, { onDelete: "cascade" }), }); -export const roleResources = sqliteTable("roleActions", { - roleId: integer("role]Id") +export const roleResources = sqliteTable("roleResources", { + roleId: integer("roleId") .notNull() .references(() => roles.roleId, { onDelete: "cascade" }), - resourceId: integer("resourceId") + resourceId: text("resourceId") .notNull() .references(() => resources.resourceId, { onDelete: "cascade" }), }); -export const userResources = sqliteTable("userActions", { - userId: text("user]Id") +export const userResources = sqliteTable("userResources", { + userId: text("userId") .notNull() .references(() => users.id, { onDelete: "cascade" }), - resourceId: integer("resourceId") + resourceId: text("resourceId") .notNull() .references(() => resources.resourceId, { onDelete: "cascade" }), }); diff --git a/server/index.ts b/server/index.ts index 120e1400..d5169bda 100644 --- a/server/index.ts +++ b/server/index.ts @@ -82,8 +82,8 @@ declare global { namespace Express { interface Request { user?: User; - userOrgRole?: string; - userOrgs?: number[]; + userOrgRoleId?: number; + userOrgId?: number; } } } diff --git a/server/routers/auth/getUserOrgs.ts b/server/routers/auth/getUserOrgs.ts index f73bebca..1753d0b7 100644 --- a/server/routers/auth/getUserOrgs.ts +++ b/server/routers/auth/getUserOrgs.ts @@ -21,7 +21,7 @@ export async function getUserOrgs(req: Request, res: Response, next: NextFunctio .where(eq(userOrgs.userId, userId)); req.userOrgs = userOrganizations.map(org => org.orgId); - // req.userOrgRoles = userOrganizations.reduce((acc, org) => { + // req.userOrgRoleIds = userOrganizations.reduce((acc, org) => { // acc[org.orgId] = org.role; // return acc; // }, {} as Record); diff --git a/server/routers/auth/verifyOrgAccess.ts b/server/routers/auth/verifyOrgAccess.ts index 2e79d674..e24e765e 100644 --- a/server/routers/auth/verifyOrgAccess.ts +++ b/server/routers/auth/verifyOrgAccess.ts @@ -26,7 +26,8 @@ export function verifyOrgAccess(req: Request, res: Response, next: NextFunction) next(createHttpError(HttpCode.FORBIDDEN, 'User does not have access to this organization')); } else { // User has access, attach the user's role to the request for potential future use - req.userOrgRole = result[0].role; + req.userOrgRoleId = result[0].roleId; + req.userOrgId = orgId; next(); } }) diff --git a/server/routers/auth/verifyResourceAccess.ts b/server/routers/auth/verifyResourceAccess.ts index 2efec9fd..f5518323 100644 --- a/server/routers/auth/verifyResourceAccess.ts +++ b/server/routers/auth/verifyResourceAccess.ts @@ -1,6 +1,6 @@ import { Request, Response, NextFunction } from 'express'; import { db } from '@server/db'; -import { resources, userOrgs } from '@server/db/schema'; +import { resources, userOrgs, userResources, roleResources } from '@server/db/schema'; import { and, eq } from 'drizzle-orm'; import createHttpError from 'http-errors'; import HttpCode from '@server/types/HttpCode'; @@ -13,42 +13,66 @@ export async function verifyResourceAccess(req: Request, res: Response, next: Ne return next(createHttpError(HttpCode.UNAUTHORIZED, 'User not authenticated')); } - const resource = await db.select() - .from(resources) - .where(eq(resources.resourceId, resourceId)) - .limit(1); + try { + // Get the resource + const resource = await db.select() + .from(resources) + .where(eq(resources.resourceId, resourceId)) + .limit(1); - if (resource.length === 0) { - return next( - createHttpError( - HttpCode.NOT_FOUND, - `resource with ID ${resourceId} not found` + if (resource.length === 0) { + return next(createHttpError(HttpCode.NOT_FOUND, `Resource with ID ${resourceId} not found`)); + } + + if (!resource[0].orgId) { + return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, `Resource with ID ${resourceId} does not have an organization ID`)); + } + + // Get user's role ID in the organization + const userOrgRole = await db.select() + .from(userOrgs) + .where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, resource[0].orgId))) + .limit(1); + + if (userOrgRole.length === 0) { + return next(createHttpError(HttpCode.FORBIDDEN, 'User does not have access to this organization')); + } + + const userOrgRoleId = userOrgRole[0].roleId; + req.userOrgRoleId = userOrgRoleId; + req.userOrgId = resource[0].orgId; + + // Check role-based resource access first + const roleResourceAccess = await db.select() + .from(roleResources) + .where( + and( + eq(roleResources.resourceId, resourceId), + eq(roleResources.roleId, userOrgRoleId) + ) ) - ); - } + .limit(1); - if (!resource[0].orgId) { - return next( - createHttpError( - HttpCode.INTERNAL_SERVER_ERROR, - `resource with ID ${resourceId} does not have an organization ID` - ) - ); - } + if (roleResourceAccess.length > 0) { + // User's role has access to the resource + return next(); + } - db.select() - .from(userOrgs) - .where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, resource[0].orgId))) - .then((result) => { - if (result.length === 0) { - next(createHttpError(HttpCode.FORBIDDEN, 'User does not have access to this organization')); - } else { - // User has access, attach the user's role to the request for potential future use - req.userOrgRole = result[0].role; - next(); - } - }) - .catch((error) => { - next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, 'Error verifying organization access')); - }); + // If role doesn't have access, check user-specific resource access + const userResourceAccess = await db.select() + .from(userResources) + .where(and(eq(userResources.userId, userId), eq(userResources.resourceId, resourceId))) + .limit(1); + + if (userResourceAccess.length > 0) { + // User has direct access to the resource + return next(); + } + + // If we reach here, the user doesn't have access to the resource + return next(createHttpError(HttpCode.FORBIDDEN, 'User does not have access to this resource')); + + } catch (error) { + return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, 'Error verifying resource access')); + } } \ No newline at end of file diff --git a/server/routers/auth/verifySiteAccess.ts b/server/routers/auth/verifySiteAccess.ts index 397b1491..6b8e16cd 100644 --- a/server/routers/auth/verifySiteAccess.ts +++ b/server/routers/auth/verifySiteAccess.ts @@ -1,7 +1,7 @@ import { Request, Response, NextFunction } from 'express'; import { db } from '@server/db'; -import { sites, userOrgs } from '@server/db/schema'; -import { and, eq } from 'drizzle-orm'; +import { sites, userOrgs, userSites, roleSites, roles } from '@server/db/schema'; +import { and, eq, or } from 'drizzle-orm'; import createHttpError from 'http-errors'; import HttpCode from '@server/types/HttpCode'; @@ -14,45 +14,66 @@ export async function verifySiteAccess(req: Request, res: Response, next: NextFu } if (isNaN(siteId)) { - return next(createHttpError(HttpCode.BAD_REQUEST, 'Invalid organization ID')); + return next(createHttpError(HttpCode.BAD_REQUEST, 'Invalid site ID')); } - const site = await db.select() - .from(sites) - .where(eq(sites.siteId, siteId)) - .limit(1); + try { + // Get the site + const site = await db.select().from(sites).where(eq(sites.siteId, siteId)).limit(1); - if (site.length === 0) { - return next( - createHttpError( - HttpCode.NOT_FOUND, - `Site with ID ${siteId} not found` + if (site.length === 0) { + return next(createHttpError(HttpCode.NOT_FOUND, `Site with ID ${siteId} not found`)); + } + + if (!site[0].orgId) { + return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, `Site with ID ${siteId} does not have an organization ID`)); + } + + // Get user's role ID in the organization + const userOrgRole = await db.select() + .from(userOrgs) + .where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, site[0].orgId))) + .limit(1); + + if (userOrgRole.length === 0) { + return next(createHttpError(HttpCode.FORBIDDEN, 'User does not have access to this organization')); + } + + const userOrgRoleId = userOrgRole[0].roleId; + req.userOrgRoleId = userOrgRoleId; + req.userOrgId = site[0].orgId; + + // Check role-based site access first + const roleSiteAccess = await db.select() + .from(roleSites) + .where( + and( + eq(roleSites.siteId, siteId), + eq(roleSites.roleId, userOrgRoleId) + ) ) - ); - } + .limit(1); - if (!site[0].orgId) { - return next( - createHttpError( - HttpCode.INTERNAL_SERVER_ERROR, - `Site with ID ${siteId} does not have an organization ID` - ) - ); - } + if (roleSiteAccess.length > 0) { + // User's role has access to the site + return next(); + } - db.select() - .from(userOrgs) - .where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, site[0].orgId))) - .then((result) => { - if (result.length === 0) { - next(createHttpError(HttpCode.FORBIDDEN, 'User does not have access to this organization')); - } else { - // User has access, attach the user's role to the request for potential future use - req.userOrgRole = result[0].role; - next(); - } - }) - .catch((error) => { - next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, 'Error verifying organization access')); - }); + // If role doesn't have access, check user-specific site access + const userSiteAccess = await db.select() + .from(userSites) + .where(and(eq(userSites.userId, userId), eq(userSites.siteId, siteId))) + .limit(1); + + if (userSiteAccess.length > 0) { + // User has direct access to the site + return next(); + } + + // If we reach here, the user doesn't have access to the site + return next(createHttpError(HttpCode.FORBIDDEN, 'User does not have access to this site')); + + } catch (error) { + return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, 'Error verifying site access')); + } } \ No newline at end of file diff --git a/server/routers/auth/verifyTargetAccess.ts b/server/routers/auth/verifyTargetAccess.ts index 6958db3e..fa070e22 100644 --- a/server/routers/auth/verifyTargetAccess.ts +++ b/server/routers/auth/verifyTargetAccess.ts @@ -73,7 +73,8 @@ export async function verifyTargetAccess(req: Request, res: Response, next: Next next(createHttpError(HttpCode.FORBIDDEN, 'User does not have access to this organization')); } else { // User has access, attach the user's role to the request for potential future use - req.userOrgRole = result[0].role; + req.userOrgRoleId = result[0].roleId; + req.userOrgId = resource[0].orgId!; next(); } }) diff --git a/server/routers/org/listOrgs.ts b/server/routers/org/listOrgs.ts index fda2bcde..37617a7d 100644 --- a/server/routers/org/listOrgs.ts +++ b/server/routers/org/listOrgs.ts @@ -62,7 +62,7 @@ export async function listOrgs(req: Request, res: Response, next: NextFunction): // // Add the user's role for each organization // const organizationsWithRoles = organizations.map(org => ({ // ...org, - // userRole: req.userOrgRoles[org.orgId], + // userRole: req.userOrgRoleIds[org.orgId], // })); return res.status(HttpCode.OK).send( diff --git a/server/routers/resource/listResources.ts b/server/routers/resource/listResources.ts index 0f20650b..6441f681 100644 --- a/server/routers/resource/listResources.ts +++ b/server/routers/resource/listResources.ts @@ -1,11 +1,11 @@ import { Request, Response, NextFunction } from 'express'; import { z } from 'zod'; import { db } from '@server/db'; -import { resources, sites } from '@server/db/schema'; +import { resources, sites, userResources, roleResources } from '@server/db/schema'; import response from "@server/utils/response"; import HttpCode from '@server/types/HttpCode'; import createHttpError from 'http-errors'; -import { sql, eq } from 'drizzle-orm'; +import { sql, eq, and, or, inArray } from 'drizzle-orm'; const listResourcesParamsSchema = z.object({ siteId: z.coerce.number().int().positive().optional(), @@ -19,30 +19,50 @@ const listResourcesSchema = z.object({ offset: z.coerce.number().int().nonnegative().default(0), }); -export async function listResources(req: Request, res: Response, next: NextFunction): Promise { +interface RequestWithOrgAndRole extends Request { + userOrgRoleId?: number; + orgId?: number; +} + +export async function listResources(req: RequestWithOrgAndRole, res: Response, next: NextFunction): Promise { try { + // Check if the user has permission to list resources + // const LIST_RESOURCES_ACTION_ID = 3; // Assume 3 is the action ID for listing resources + // const hasPermission = await checkUserActionPermission(LIST_RESOURCES_ACTION_ID, req); + // if (!hasPermission) { + // return next(createHttpError(HttpCode.FORBIDDEN, 'User does not have permission to list resources')); + // } + const parsedQuery = listResourcesSchema.safeParse(req.query); if (!parsedQuery.success) { - return next( - createHttpError( - HttpCode.BAD_REQUEST, - parsedQuery.error.errors.map(e => e.message).join(', ') - ) - ); + return next(createHttpError(HttpCode.BAD_REQUEST, parsedQuery.error.errors.map(e => e.message).join(', '))); } const { limit, offset } = parsedQuery.data; const parsedParams = listResourcesParamsSchema.safeParse(req.params); if (!parsedParams.success) { - return next( - createHttpError( - HttpCode.BAD_REQUEST, - parsedParams.error.errors.map(e => e.message).join(', ') - ) - ); + return next(createHttpError(HttpCode.BAD_REQUEST, parsedParams.error.errors.map(e => e.message).join(', '))); } const { siteId, orgId } = parsedParams.data; + if (orgId && orgId !== req.orgId) { + return next(createHttpError(HttpCode.FORBIDDEN, 'User does not have access to this organization')); + } + + // Get the list of resources the user has access to + const accessibleResources = await db + .select({ resourceId: sql`COALESCE(${userResources.resourceId}, ${roleResources.resourceId})` }) + .from(userResources) + .fullJoin(roleResources, eq(userResources.resourceId, roleResources.resourceId)) + .where( + or( + eq(userResources.userId, req.user!.id), + eq(roleResources.roleId, req.userOrgRoleId!) + ) + ); + + const accessibleResourceIds = accessibleResources.map(resource => resource.resourceId); + let baseQuery: any = db .select({ resourceId: resources.resourceId, @@ -51,16 +71,21 @@ export async function listResources(req: Request, res: Response, next: NextFunct siteName: sites.name, }) .from(resources) - .leftJoin(sites, eq(resources.siteId, sites.siteId)); + .leftJoin(sites, eq(resources.siteId, sites.siteId)) + .where(inArray(resources.resourceId, accessibleResourceIds)); - let countQuery: any = db.select({ count: sql`cast(count(*) as integer)` }).from(resources); + let countQuery: any = db + .select({ count: sql`cast(count(*) as integer)` }) + .from(resources) + .where(inArray(resources.resourceId, accessibleResourceIds)); if (siteId) { baseQuery = baseQuery.where(eq(resources.siteId, siteId)); countQuery = countQuery.where(eq(resources.siteId, siteId)); - } else if (orgId) { - baseQuery = baseQuery.where(eq(resources.orgId, orgId)); - countQuery = countQuery.where(eq(resources.orgId, orgId)); + } else { + // If orgId is provided, it's already checked to match req.orgId + baseQuery = baseQuery.where(eq(resources.orgId, req.orgId!)); + countQuery = countQuery.where(eq(resources.orgId, req.orgId!)); } const resourcesList = await baseQuery.limit(limit).offset(offset); diff --git a/server/routers/site/listSites.ts b/server/routers/site/listSites.ts index 6b7cba79..1d72bbb1 100644 --- a/server/routers/site/listSites.ts +++ b/server/routers/site/listSites.ts @@ -1,11 +1,12 @@ import { Request, Response, NextFunction } from 'express'; import { z } from 'zod'; import { db } from '@server/db'; -import { sites, orgs, exitNodes } from '@server/db/schema'; +import { sites, orgs, exitNodes, userSites, roleSites } from '@server/db/schema'; import response from "@server/utils/response"; import HttpCode from '@server/types/HttpCode'; import createHttpError from 'http-errors'; -import { sql, eq } from 'drizzle-orm'; +import { sql, eq, and, or, inArray } from 'drizzle-orm'; +// import { checkUserActionPermission } from './checkUserActionPermission'; // Import the function we created earlier const listSitesParamsSchema = z.object({ orgId: z.string().optional().transform(Number).pipe(z.number().int().positive()), @@ -18,29 +19,41 @@ const listSitesSchema = z.object({ export async function listSites(req: Request, res: Response, next: NextFunction): Promise { try { + // Check if the user has permission to list sites + // const LIST_SITES_ACTION_ID = 1; // Assume 1 is the action ID for listing sites + // const hasPermission = await checkUserActionPermission(LIST_SITES_ACTION_ID, req); + // if (!hasPermission) { + // return next(createHttpError(HttpCode.FORBIDDEN, 'User does not have permission to list sites')); + // } + const parsedQuery = listSitesSchema.safeParse(req.query); if (!parsedQuery.success) { - return next( - createHttpError( - HttpCode.BAD_REQUEST, - parsedQuery.error.errors.map(e => e.message).join(', ') - ) - ); + return next(createHttpError(HttpCode.BAD_REQUEST, parsedQuery.error.errors.map(e => e.message).join(', '))); } - const { limit, offset } = parsedQuery.data; const parsedParams = listSitesParamsSchema.safeParse(req.params); if (!parsedParams.success) { - return next( - createHttpError( - HttpCode.BAD_REQUEST, - parsedParams.error.errors.map(e => e.message).join(', ') - ) - ); + return next(createHttpError(HttpCode.BAD_REQUEST, parsedParams.error.errors.map(e => e.message).join(', '))); + } + const { orgId } = parsedParams.data; + + if (orgId && orgId !== req.userOrgId) { + return next(createHttpError(HttpCode.FORBIDDEN, 'User does not have access to this organization')); } - const { orgId } = parsedParams.data; + const accessibleSites = await db + .select({ siteId: sql`COALESCE(${userSites.siteId}, ${roleSites.siteId})` }) + .from(userSites) + .fullJoin(roleSites, eq(userSites.siteId, roleSites.siteId)) + .where( + or( + eq(userSites.userId, req.user!.id), + eq(roleSites.roleId, req.userOrgRoleId!) + ) + ); + + const accessibleSiteIds = accessibleSites.map(site => site.siteId); let baseQuery: any = db .select({ @@ -56,9 +69,12 @@ export async function listSites(req: Request, res: Response, next: NextFunction) }) .from(sites) .leftJoin(orgs, eq(sites.orgId, orgs.orgId)) - .leftJoin(exitNodes, eq(sites.exitNode, exitNodes.exitNodeId)); + .where(inArray(sites.siteId, accessibleSiteIds)); - let countQuery: any = db.select({ count: sql`cast(count(*) as integer)` }).from(sites); + let countQuery: any = db + .select({ count: sql`cast(count(*) as integer)` }) + .from(sites) + .where(inArray(sites.siteId, accessibleSiteIds)); if (orgId) { baseQuery = baseQuery.where(eq(sites.orgId, orgId)); diff --git a/server/types/Auth.ts b/server/types/Auth.ts index 9a228ee0..69f6cefc 100644 --- a/server/types/Auth.ts +++ b/server/types/Auth.ts @@ -5,5 +5,5 @@ import { Session } from "lucia"; export interface AuthenticatedRequest extends Request { user: User; session: Session; - userOrgRole?: string; + userOrgRoleId?: number; }