import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import {
    idp,
    idpOidcConfig,
    idpOrg,
    Role,
    roles,
    userOrgs,
    users
} from "@server/db/schemas";
import { and, eq, inArray } from "drizzle-orm";
import * as arctic from "arctic";
import { generateOidcRedirectUrl } from "@server/lib/idp/generateRedirectUrl";
import jmespath from "jmespath";
import jsonwebtoken from "jsonwebtoken";
import config from "@server/lib/config";
import { UserType } from "@server/types/UserTypes";
import {
    createSession,
    generateId,
    generateSessionToken,
    serializeSessionCookie
} from "@server/auth/sessions/app";

const paramsSchema = z
    .object({
        idpId: z.coerce.number()
    })
    .strict();

const bodySchema = z.object({
    code: z.string().nonempty(),
    state: z.string().nonempty(),
    storedState: z.string().nonempty()
});

/** Payload that was signed into the `storedState` JWT when the OIDC flow started. */
const statePayloadSchema = z.object({
    redirectUrl: z.string(),
    state: z.string(),
    codeVerifier: z.string()
});

export type ValidateOidcUrlCallbackResponse = {
    redirectUrl: string;
};

type UserOrgInfo = { orgId: string; roleId: number };

/**
 * Evaluates an optional JMESPath expression against the ID token claims.
 *
 * Returns the result only when it is a string. Returns null when no path is
 * configured, the result is not a string, or evaluation throws; a throw is
 * logged. Optional profile fields therefore never abort the login.
 */
function searchOptionalStringClaim(
    claims: object,
    path: string | null | undefined,
    field: string
): string | null {
    if (!path) {
        return null;
    }
    try {
        const value = jmespath.search(claims, path);
        return typeof value === "string" ? value : null;
    } catch (error) {
        logger.error(`Failed to evaluate ${field} path on ID token claims`, {
            error
        });
        return null;
    }
}

/**
 * Applies every org/role mapping rule configured for the IdP to the ID token
 * claims. Returns the org memberships the user should have.
 *
 * A rule is skipped when:
 * - its `orgMapping` evaluates to a falsy value;
 * - it has no `roleMapping`;
 * - its `roleMapping` yields no string role name;
 * - no role with that name exists in the org.
 */
async function resolveUserOrgInfo(
    claims: object,
    idpId: number
): Promise<UserOrgInfo[]> {
    const orgRules = await db
        .select()
        .from(idpOrg)
        .where(eq(idpOrg.idpId, idpId));

    const result: UserOrgInfo[] = [];

    for (const orgRule of orgRules) {
        // The org mapping acts as a filter: the membership is only granted
        // when the expression evaluates to something truthy for these claims.
        if (orgRule.orgMapping) {
            const mappedOrg = jmespath.search(claims, orgRule.orgMapping);
            if (!mappedOrg) {
                continue;
            }
        }

        // Without a role mapping there is no role to assign, so no membership.
        if (!orgRule.roleMapping) {
            continue;
        }

        const roleName = jmespath.search(claims, orgRule.roleMapping);
        if (!roleName || typeof roleName !== "string") {
            logger.error("Role name not found in the ID token", {
                orgId: orgRule.orgId,
                roleName
            });
            continue;
        }

        const [roleRes] = await db
            .select()
            .from(roles)
            .where(
                and(
                    eq(roles.orgId, orgRule.orgId),
                    eq(roles.name, roleName)
                )
            );

        if (!roleRes) {
            logger.error("Role not found", {
                orgId: orgRule.orgId,
                roleName
            });
            continue;
        }

        result.push({ orgId: orgRule.orgId, roleId: roleRes.roleId });
    }

    return result;
}

/**
 * Completes an OIDC authorization-code login for the IdP in `req.params.idpId`.
 *
 * Steps:
 * 1. Verify the signed `storedState` JWT and check its `state` against the
 *    `state` in the request body.
 * 2. Exchange `code` (with the PKCE verifier from the state) for tokens, then
 *    read the user identifier, email and name from the ID token using the
 *    IdP's JMESPath configuration.
 * 3. In one transaction, upsert the user and sync their org memberships and
 *    roles to what the mapping rules resolve.
 * 4. Create a session, set the session cookie, and return the post-login
 *    redirect URL.
 *
 * Validation failures call `next` with a 400 exactly once. Unexpected errors
 * are logged and reported as 500.
 */
export async function validateOidcCallback(
    req: Request,
    res: Response,
    next: NextFunction
): Promise<any> {
    try {
        const parsedParams = paramsSchema.safeParse(req.params);
        if (!parsedParams.success) {
            return next(
                createHttpError(
                    HttpCode.BAD_REQUEST,
                    fromError(parsedParams.error).toString()
                )
            );
        }

        const { idpId } = parsedParams.data;

        const parsedBody = bodySchema.safeParse(req.body);
        if (!parsedBody.success) {
            return next(
                createHttpError(
                    HttpCode.BAD_REQUEST,
                    fromError(parsedBody.error).toString()
                )
            );
        }

        const { storedState, code, state: expectedState } = parsedBody.data;

        const [existingIdp] = await db
            .select()
            .from(idp)
            .innerJoin(idpOidcConfig, eq(idpOidcConfig.idpId, idp.idpId))
            .where(and(eq(idp.type, "oidc"), eq(idp.idpId, idpId)));

        if (!existingIdp) {
            return next(
                createHttpError(
                    HttpCode.BAD_REQUEST,
                    "IdP not found for the organization"
                )
            );
        }

        const redirectUrl = generateOidcRedirectUrl(existingIdp.idp.idpId);

        const client = new arctic.OAuth2Client(
            existingIdp.idpOidcConfig.clientId,
            existingIdp.idpOidcConfig.clientSecret,
            redirectUrl
        );

        // Synchronous verify: it throws on failure. This lets us return after a
        // single `next(...)`. The callback form would still continue into the
        // schema check below and call `next` a second time.
        let statePayload: unknown;
        try {
            statePayload = jsonwebtoken.verify(
                storedState,
                config.getRawConfig().server.secret
            );
        } catch (err) {
            logger.error("Error verifying state JWT", { err });
            return next(
                createHttpError(HttpCode.BAD_REQUEST, "Invalid state JWT")
            );
        }

        const stateObj = statePayloadSchema.safeParse(statePayload);
        if (!stateObj.success) {
            logger.error("Error parsing state JWT");
            return next(
                createHttpError(
                    HttpCode.BAD_REQUEST,
                    fromError(stateObj.error).toString()
                )
            );
        }

        const {
            codeVerifier,
            state,
            redirectUrl: postAuthRedirectUrl
        } = stateObj.data;

        if (state !== expectedState) {
            logger.error("State mismatch", { expectedState, state });
            return next(
                createHttpError(HttpCode.BAD_REQUEST, "State mismatch")
            );
        }

        const tokens = await client.validateAuthorizationCode(
            existingIdp.idpOidcConfig.tokenUrl,
            code,
            codeVerifier
        );

        const idToken = tokens.idToken();
        const claims = arctic.decodeIdToken(idToken);

        const userIdentifier = jmespath.search(
            claims,
            existingIdp.idpOidcConfig.identifierPath
        );

        if (!userIdentifier) {
            return next(
                createHttpError(
                    HttpCode.BAD_REQUEST,
                    "User identifier not found in the ID token"
                )
            );
        }

        logger.debug("User identifier", { userIdentifier });

        // Email and name are optional. A bad path for one does not affect the other.
        const email = searchOptionalStringClaim(
            claims,
            existingIdp.idpOidcConfig.emailPath,
            "email"
        );
        const name = searchOptionalStringClaim(
            claims,
            existingIdp.idpOidcConfig.namePath,
            "name"
        );

        logger.debug("User email", { email });
        logger.debug("User name", { name });

        const [existingUser] = await db
            .select()
            .from(users)
            .where(
                and(
                    eq(users.username, userIdentifier),
                    eq(users.idpId, existingIdp.idp.idpId)
                )
            );

        const userOrgInfo = await resolveUserOrgInfo(
            claims,
            existingIdp.idp.idpId
        );

        logger.debug("User org info", { userOrgInfo });

        // Resolve the user id up front so it is a definite string everywhere below.
        const userId: string = existingUser?.userId ?? generateId(15);

        // Sync the user with the resolved orgs and roles.
        await db.transaction(async (trx) => {
            if (!existingUser) {
                await trx.insert(users).values({
                    userId,
                    username: userIdentifier,
                    email: email || null,
                    name: name || null,
                    type: UserType.OIDC,
                    idpId: existingIdp.idp.idpId,
                    emailVerified: true, // OIDC users are always verified
                    dateCreated: new Date().toISOString()
                });
            } else {
                // Refresh profile fields from the latest ID token.
                await trx
                    .update(users)
                    .set({
                        username: userIdentifier,
                        email: email || null,
                        name: name || null
                    })
                    .where(eq(users.userId, userId));
            }

            // Desired role per org. On duplicates the first rule wins, and no
            // org is ever inserted twice.
            const desiredRoles = new Map<string, number>();
            for (const { orgId, roleId } of userOrgInfo) {
                if (!desiredRoles.has(orgId)) {
                    desiredRoles.set(orgId, roleId);
                }
            }

            const currentUserOrgs = await trx
                .select()
                .from(userOrgs)
                .where(eq(userOrgs.userId, userId));

            const currentOrgIds = new Set(
                currentUserOrgs.map((row) => row.orgId)
            );

            // Remove memberships the IdP no longer grants.
            const orgIdsToDelete = currentUserOrgs
                .map((row) => row.orgId)
                .filter((orgId) => !desiredRoles.has(orgId));

            if (orgIdsToDelete.length > 0) {
                await trx
                    .delete(userOrgs)
                    .where(
                        and(
                            eq(userOrgs.userId, userId),
                            inArray(userOrgs.orgId, orgIdsToDelete)
                        )
                    );
            }

            // Update roles on memberships that still exist but whose role changed.
            for (const row of currentUserOrgs) {
                const desiredRoleId = desiredRoles.get(row.orgId);
                if (
                    desiredRoleId !== undefined &&
                    desiredRoleId !== row.roleId
                ) {
                    await trx
                        .update(userOrgs)
                        .set({ roleId: desiredRoleId })
                        .where(
                            and(
                                eq(userOrgs.userId, userId),
                                eq(userOrgs.orgId, row.orgId)
                            )
                        );
                }
            }

            // Add memberships that don't exist yet.
            const rowsToAdd = [...desiredRoles]
                .filter(([orgId]) => !currentOrgIds.has(orgId))
                .map(([orgId, roleId]) => ({
                    userId,
                    orgId,
                    roleId,
                    dateCreated: new Date().toISOString()
                }));

            if (rowsToAdd.length > 0) {
                await trx.insert(userOrgs).values(rowsToAdd);
            }
        });

        const token = generateSessionToken();
        const sess = await createSession(token, userId);
        const isSecure = req.protocol === "https";
        const cookie = serializeSessionCookie(
            token,
            isSecure,
            new Date(sess.expiresAt)
        );

        res.appendHeader("Set-Cookie", cookie);

        return response(res, {
            data: {
                redirectUrl: postAuthRedirectUrl
            },
            success: true,
            error: false,
            message: "OIDC callback validated successfully",
            status: HttpCode.CREATED
        });
    } catch (error) {
        logger.error(error);
        return next(
            createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
        );
    }
}