diff --git a/server/auth/resource.ts b/server/auth/resource.ts new file mode 100644 index 00000000..0d87b105 --- /dev/null +++ b/server/auth/resource.ts @@ -0,0 +1,125 @@ +import { encodeHexLowerCase } from "@oslojs/encoding"; +import { sha256 } from "@oslojs/crypto/sha2"; +import { + resourceSessions, + ResourceSession, + User, + users, +} from "@server/db/schema"; +import db from "@server/db"; +import { eq, and } from "drizzle-orm"; + +export const SESSION_COOKIE_NAME = "resource_session"; +export const SESSION_COOKIE_EXPIRES = 1000 * 60 * 60 * 24 * 30; + +export type ResourceAuthMethod = "password" | "pincode"; + +export async function createResourceSession( + token: string, + userId: string, + resourceId: number, + method: ResourceAuthMethod +): Promise { + const sessionId = encodeHexLowerCase( + sha256(new TextEncoder().encode(token)) + ); + const session: ResourceSession = { + sessionId: sessionId, + userId, + expiresAt: new Date(Date.now() + SESSION_COOKIE_EXPIRES).getTime(), + resourceId, + method, + }; + await db.insert(resourceSessions).values(session); + return session; +} + +export async function validateResourceSessionToken( + token: string +): Promise { + const sessionId = encodeHexLowerCase( + sha256(new TextEncoder().encode(token)) + ); + const result = await db + .select({ user: users, resourceSession: resourceSessions }) + .from(resourceSessions) + .innerJoin(users, eq(resourceSessions.userId, users.userId)) + .where(eq(resourceSessions.sessionId, sessionId)); + if (result.length < 1) { + return { session: null, user: null }; + } + const { user, resourceSession } = result[0]; + if (Date.now() >= resourceSession.expiresAt) { + await db + .delete(resourceSessions) + .where(eq(resourceSessions.sessionId, resourceSession.sessionId)); + return { session: null, user: null }; + } + if (Date.now() >= resourceSession.expiresAt - SESSION_COOKIE_EXPIRES / 2) { + resourceSession.expiresAt = new Date( + Date.now() + SESSION_COOKIE_EXPIRES + ).getTime(); + await db + .update(resourceSessions) + .set({ + expiresAt: resourceSession.expiresAt, + }) + .where(eq(resourceSessions.sessionId, resourceSession.sessionId)); + } + return { session: resourceSession, user }; +} + +export async function invalidateResourceSession( + sessionId: string +): Promise { + await db + .delete(resourceSessions) + .where(eq(resourceSessions.sessionId, sessionId)); +} + +export async function invalidateAllSessions( + userId: string, + method?: ResourceAuthMethod +): Promise { + if (!method) { + await db + .delete(resourceSessions) + .where(eq(resourceSessions.userId, userId)); + } else { + await db + .delete(resourceSessions) + .where( + and( + eq(resourceSessions.userId, userId), + eq(resourceSessions.method, method) + ) + ); + } +} + +export function serializeSessionCookie( + token: string, + fqdn: string, + secure: boolean +): string { + if (secure) { + return `${SESSION_COOKIE_NAME}=${token}; HttpOnly; SameSite=Lax; Max-Age=${SESSION_COOKIE_EXPIRES}; Path=/; Secure; Domain=${fqdn}`; + } else { + return `${SESSION_COOKIE_NAME}=${token}; HttpOnly; SameSite=Lax; Max-Age=${SESSION_COOKIE_EXPIRES}; Path=/; Domain=${fqdn}`; + } +} + +export function createBlankSessionTokenCookie( + fqdn: string, + secure: boolean +): string { + if (secure) { + return `${SESSION_COOKIE_NAME}=; HttpOnly; SameSite=Lax; Max-Age=0; Path=/; Secure; Domain=${fqdn}`; + } else { + return `${SESSION_COOKIE_NAME}=; HttpOnly; SameSite=Lax; Max-Age=0; Path=/; Domain=${fqdn}`; + } +} + +export type ResourceSessionValidationResult = + | { session: ResourceSession; user: User } + | { session: null; user: null }; diff --git a/server/db/schema.ts b/server/db/schema.ts index 2bad27b2..93168d1c 100644 --- a/server/db/schema.ts +++ b/server/db/schema.ts @@ -9,9 +9,11 @@ export const orgs = sqliteTable("orgs", { export const sites = sqliteTable("sites", { siteId: integer("siteId").primaryKey({ autoIncrement: true }), - orgId: text("orgId").references(() => orgs.orgId, { - onDelete: "cascade", - }).notNull(), + orgId: text("orgId") + .references(() => orgs.orgId, { + onDelete: "cascade", + }) + .notNull(), niceId: text("niceId").notNull(), exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, { onDelete: "set null", @@ -25,22 +27,32 @@ export const sites = sqliteTable("sites", { export const resources = sqliteTable("resources", { resourceId: integer("resourceId").primaryKey({ autoIncrement: true }), - siteId: integer("siteId").references(() => sites.siteId, { - onDelete: "cascade", - }).notNull(), - orgId: text("orgId").references(() => orgs.orgId, { - onDelete: "cascade", - }).notNull(), + siteId: integer("siteId") + .references(() => sites.siteId, { + onDelete: "cascade", + }) + .notNull(), + orgId: text("orgId") + .references(() => orgs.orgId, { + onDelete: "cascade", + }) + .notNull(), name: text("name").notNull(), subdomain: text("subdomain").notNull(), + fullDomain: text("fullDomain").notNull().unique(), ssl: integer("ssl", { mode: "boolean" }).notNull().default(false), + appSSOEnabled: integer("appSSOEnabled", { mode: "boolean" }) + .notNull() + .default(false), }); export const targets = sqliteTable("targets", { targetId: integer("targetId").primaryKey({ autoIncrement: true }), - resourceId: integer("resourceId").references(() => resources.resourceId, { - onDelete: "cascade", - }).notNull(), + resourceId: integer("resourceId") + .references(() => resources.resourceId, { + onDelete: "cascade", + }) + .notNull(), ip: text("ip").notNull(), method: text("method").notNull(), port: integer("port").notNull(), @@ -145,9 +157,11 @@ export const actions = sqliteTable("actions", { export const roles = sqliteTable("roles", { roleId: integer("roleId").primaryKey({ autoIncrement: true }), - orgId: text("orgId").references(() => orgs.orgId, { - onDelete: "cascade", - }).notNull(), + orgId: text("orgId") + .references(() => orgs.orgId, { + onDelete: "cascade", + }) + .notNull(), isAdmin: integer("isAdmin", { mode: "boolean" }), name: text("name").notNull(), description: text("description"), @@ -215,9 +229,11 @@ export const userResources = sqliteTable("userResources", { export const limitsTable = sqliteTable("limits", { limitId: integer("limitId").primaryKey({ autoIncrement: true }), - orgId: text("orgId").references(() => orgs.orgId, { - onDelete: "cascade", - }).notNull(), + orgId: text("orgId") + .references(() => orgs.orgId, { + onDelete: "cascade", + }) + .notNull(), name: text("name").notNull(), value: integer("value").notNull(), description: text("description"), @@ -236,6 +252,39 @@ export const userInvites = sqliteTable("userInvites", { .references(() => roles.roleId, { onDelete: "cascade" }), }); +export const resourcePincode = sqliteTable("resourcePincode", { + resourcePincodeId: integer("resourcePincodeId").primaryKey({ + autoIncrement: true, + }), + resourceId: integer("resourceId") + .notNull() + .references(() => resources.resourceId, { onDelete: "cascade" }), + pincodeHash: text("pincodeHash").notNull(), + digitLength: integer("digitLength").notNull(), +}); + +export const resourcePassword = sqliteTable("resourcePassword", { + resourcePasswordId: integer("resourcePasswordId").primaryKey({ + autoIncrement: true, + }), + resourceId: integer("resourceId") + .notNull() + .references(() => resources.resourceId, { onDelete: "cascade" }), + passwordHash: text("passwordHash").notNull(), +}); + +export const resourceSessions = sqliteTable("resourceSessions", { + sessionId: text("id").primaryKey(), + resourceId: integer("resourceId") + .notNull() + .references(() => resources.resourceId, { onDelete: "cascade" }), + userId: text("userId") + .notNull() + .references(() => users.userId, { onDelete: "cascade" }), + expiresAt: integer("expiresAt").notNull(), + method: text("method").notNull(), +}); + export type Org = InferSelectModel; export type User = InferSelectModel; export type Site = InferSelectModel; @@ -261,3 +310,4 @@ export type UserResource = InferSelectModel; export type Limit = InferSelectModel; export type UserInvite = InferSelectModel; export type UserOrg = InferSelectModel; +export type ResourceSession = InferSelectModel; diff --git a/server/routers/badger/index.ts b/server/routers/badger/index.ts index f7622030..7af4684a 100644 --- a/server/routers/badger/index.ts +++ b/server/routers/badger/index.ts @@ -1 +1 @@ -export * from "./verifyUser"; +export * from "./verifySession"; diff --git a/server/routers/badger/verifySession.ts b/server/routers/badger/verifySession.ts new file mode 100644 index 00000000..93b58ccc --- /dev/null +++ b/server/routers/badger/verifySession.ts @@ -0,0 +1,153 @@ +import HttpCode from "@server/types/HttpCode"; +import { NextFunction, Request, Response } from "express"; +import createHttpError from "http-errors"; +import { z } from "zod"; +import { fromError } from "zod-validation-error"; +import { response } from "@server/utils/response"; +import { validateSessionToken } from "@server/auth"; +import db from "@server/db"; +import { + resourcePassword, + resourcePincode, + resources, +} from "@server/db/schema"; +import { eq } from "drizzle-orm"; +import config from "@server/config"; +import { validateResourceSessionToken } from "@server/auth/resource"; + +const verifyResourceSessionSchema = z.object({ + cookies: z.object({ + session: z.string().nullable(), + resource_session: z.string().nullable(), + }), + originalRequestURL: z.string().url(), + scheme: z.string(), + host: z.string(), + path: z.string(), + method: z.string(), + tls: z.boolean(), +}); + +export type VerifyResourceSessionSchema = z.infer< + typeof verifyResourceSessionSchema +>; + +export type VerifyUserResponse = { + valid: boolean; + redirectUrl?: string; +}; + +export async function verifyResourceSession( + req: Request, + res: Response, + next: NextFunction +): Promise { + const parsedBody = verifyResourceSessionSchema.safeParse(req.query); + + if (!parsedBody.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedBody.error).toString() + ) + ); + } + + try { + const { cookies, host, originalRequestURL } = parsedBody.data; + + const [result] = await db + .select() + .from(resources) + .leftJoin( + resourcePincode, + eq(resourcePincode.resourceId, resources.resourceId) + ) + .leftJoin( + resourcePassword, + eq(resourcePassword.resourceId, resources.resourceId) + ) + .where(eq(resources.fullDomain, host)) + .limit(1); + + const resource = result?.resources; + const pincode = result?.resourcePincode; + const password = result?.resourcePassword; + + // resource doesn't exist for some reason + if (!resource) { + return notAllowed(res); // no resource to redirect to + } + + // no auth is configured; auth check is disabled + if (!resource.appSSOEnabled && !pincode && !password) { + return allowed(res); + } + + const redirectUrl = `${config.app.base_url}/auth/resource/${resource.resourceId}/login?redirect=${originalRequestURL}`; + + // we need to check all session to find at least one valid session + // if we find one, we allow access + // if we don't find any, we deny access and redirect to the login page + + // we found a session token, and app sso is enabled, so we need to check if it's a valid session + if (cookies.session && resource.appSSOEnabled) { + const { user, session } = await validateSessionToken( + cookies.session + ); + if (user && session) { + return allowed(res); + } + } + + // we found a resource session token, and either pincode or password is enabled for the resource + // so we need to check if it's a valid session + if (cookies.resource_session && (pincode || password)) { + const { session, user } = await validateResourceSessionToken( + cookies.resource_session + ); + + if (session && user) { + if (pincode && session.method === "pincode") { + return allowed(res); + } + + if (password && session.method === "password") { + return allowed(res); + } + } + } + + // a valid session was not found for an enabled auth method so we deny access + // the user is redirected to the login page + // the login page with render which auth methods are enabled and show the user the correct login form + return notAllowed(res, redirectUrl); + } catch (e) { + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "Failed to verify session" + ) + ); + } +} + +function notAllowed(res: Response, redirectUrl?: string) { + return response(res, { + data: { valid: false, redirectUrl }, + success: true, + error: false, + message: "Access denied", + status: HttpCode.OK, + }); +} + +function allowed(res: Response) { + return response(res, { + data: { valid: true }, + success: true, + error: false, + message: "Access allowed", + status: HttpCode.OK, + }); +} diff --git a/server/routers/badger/verifyUser.ts b/server/routers/badger/verifyUser.ts deleted file mode 100644 index 952dfd7d..00000000 --- a/server/routers/badger/verifyUser.ts +++ /dev/null @@ -1,61 +0,0 @@ -import HttpCode from "@server/types/HttpCode"; -import { NextFunction, Request, Response } from "express"; -import createHttpError from "http-errors"; -import { z } from "zod"; -import { fromError } from "zod-validation-error"; -import { response } from "@server/utils/response"; -import { validateSessionToken } from "@server/auth"; - -export const verifyUserBody = z.object({ - sessionId: z.string(), -}); - -export type VerifyUserBody = z.infer; - -export type VerifyUserResponse = { - valid: boolean; -}; - -export async function verifyUser( - req: Request, - res: Response, - next: NextFunction, -): Promise { - const parsedBody = verifyUserBody.safeParse(req.query); - - if (!parsedBody.success) { - return next( - createHttpError( - HttpCode.BAD_REQUEST, - fromError(parsedBody.error).toString(), - ), - ); - } - - const { sessionId } = parsedBody.data; - - try { - const { session, user } = await validateSessionToken(sessionId); - - if (!session || !user) { - return next( - createHttpError(HttpCode.UNAUTHORIZED, "Invalid session"), - ); - } - - return response(res, { - data: { valid: true }, - success: true, - error: false, - message: "Access allowed", - status: HttpCode.OK, - }); - } catch (e) { - return next( - createHttpError( - HttpCode.INTERNAL_SERVER_ERROR, - "Failed to check user", - ), - ); - } -} diff --git a/server/routers/internal.ts b/server/routers/internal.ts index d477e696..64cb91bc 100644 --- a/server/routers/internal.ts +++ b/server/routers/internal.ts @@ -24,6 +24,6 @@ gerbilRouter.post("/receive-bandwidth", gerbil.receiveBandwidth); const badgerRouter = Router(); internalRouter.use("/badger", badgerRouter); -badgerRouter.get("/verify-user", badger.verifyUser) +badgerRouter.get("/verify-session", badger.verifyResourceSession); export default internalRouter; diff --git a/server/routers/resource/createResource.ts b/server/routers/resource/createResource.ts index 63f1ef3c..50473fad 100644 --- a/server/routers/resource/createResource.ts +++ b/server/routers/resource/createResource.ts @@ -18,11 +18,7 @@ import { fromError } from "zod-validation-error"; import { subdomainSchema } from "@server/schemas/subdomainSchema"; const createResourceParamsSchema = z.object({ - siteId: z - .string() - .optional() - .transform(stoi) - .pipe(z.number().int().positive().optional()), + siteId: z.string().transform(stoi).pipe(z.number().int().positive()), orgId: z.string(), }); @@ -88,10 +84,13 @@ export async function createResource( ); } + const fullDomain = `${subdomain}.${org[0].domain}`; + const newResource = await db .insert(resources) .values({ siteId, + fullDomain, orgId, name, subdomain, diff --git a/server/routers/resource/updateResource.ts b/server/routers/resource/updateResource.ts index 664f8d84..cf909ce1 100644 --- a/server/routers/resource/updateResource.ts +++ b/server/routers/resource/updateResource.ts @@ -1,8 +1,8 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; import { db } from "@server/db"; -import { resources, sites } from "@server/db/schema"; -import { eq } from "drizzle-orm"; +import { orgs, resources, sites } from "@server/db/schema"; +import { eq, or } from "drizzle-orm"; import response from "@server/utils/response"; import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; @@ -55,9 +55,35 @@ export async function updateResource( const { resourceId } = parsedParams.data; const updateData = parsedBody.data; + const resource = await db + .select() + .from(resources) + .where(eq(resources.resourceId, resourceId)) + .leftJoin(orgs, eq(resources.orgId, orgs.orgId)); + + if (resource.length === 0) { + return next( + createHttpError( + HttpCode.NOT_FOUND, + `Resource with ID ${resourceId} not found` + ) + ); + } + + if (!resource[0].orgs?.domain) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "Resource does not have a domain" + ) + ); + } + + const fullDomain = `${updateData.subdomain}.${resource[0].orgs.domain}`; + const updatedResource = await db .update(resources) - .set(updateData) + .set({ ...updateData, fullDomain }) .where(eq(resources.resourceId, resourceId)) .returning();