This commit is contained in:
Owen Schwartz 2024-10-13 18:34:01 -04:00
commit a875b65e5b
No known key found for this signature in database
GPG key ID: 8271FDFFD9E0CCBD
39 changed files with 234 additions and 167 deletions

View file

@ -16,8 +16,9 @@
"dependencies": { "dependencies": {
"@esbuild-plugins/tsconfig-paths": "0.1.2", "@esbuild-plugins/tsconfig-paths": "0.1.2",
"@hookform/resolvers": "3.9.0", "@hookform/resolvers": "3.9.0",
"@lucia-auth/adapter-drizzle": "1.1.0",
"@node-rs/argon2": "1.8.3", "@node-rs/argon2": "1.8.3",
"@oslojs/crypto": "1.0.1",
"@oslojs/encoding": "1.1.0",
"@radix-ui/react-checkbox": "1.1.2", "@radix-ui/react-checkbox": "1.1.2",
"@radix-ui/react-dialog": "1.1.2", "@radix-ui/react-dialog": "1.1.2",
"@radix-ui/react-icons": "1.3.0", "@radix-ui/react-icons": "1.3.0",
@ -48,7 +49,6 @@
"http-errors": "2.0.0", "http-errors": "2.0.0",
"input-otp": "1.2.4", "input-otp": "1.2.4",
"js-yaml": "4.1.0", "js-yaml": "4.1.0",
"lucia": "3.2.0",
"lucide-react": "0.447.0", "lucide-react": "0.447.0",
"moment": "2.30.1", "moment": "2.30.1",
"next": "14.2.13", "next": "14.2.13",

View file

@ -1,6 +1,5 @@
import { import {
orgs, orgs,
users,
sites, sites,
resources, resources,
exitNodes, exitNodes,

View file

@ -45,14 +45,14 @@ export async function verifyBackUpCode(
parallelism: 1, parallelism: 1,
}); });
if (validCode) { if (validCode) {
validId = hashedCode.id; validId = hashedCode.codeId;
} }
} }
if (validId) { if (validId) {
await db await db
.delete(twoFactorBackupCodes) .delete(twoFactorBackupCodes)
.where(eq(twoFactorBackupCodes.id, validId)); .where(eq(twoFactorBackupCodes.codeId, validId));
} }
return validId ? true : false; return validId ? true : false;

View file

@ -55,8 +55,7 @@ export enum ActionsEnum {
} }
export async function checkUserActionPermission(actionId: string, req: Request): Promise<boolean> { export async function checkUserActionPermission(actionId: string, req: Request): Promise<boolean> {
const userId = req.user?.id; const userId = req.user?.userId;
if (!userId) { if (!userId) {
throw createHttpError(HttpCode.UNAUTHORIZED, 'User not authenticated'); throw createHttpError(HttpCode.UNAUTHORIZED, 'User not authenticated');
} }

View file

@ -1,51 +1,122 @@
export * from "./unauthorizedResponse";
export * from "./verifySession"; export * from "./verifySession";
export * from "./unauthorizedResponse";
import { Lucia, TimeSpan } from "lucia"; import {
import { DrizzleSQLiteAdapter } from "@lucia-auth/adapter-drizzle"; encodeBase32LowerCaseNoPadding,
encodeHexLowerCase,
} from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { Session, sessions, User, users } from "@server/db/schema";
import db from "@server/db"; import db from "@server/db";
import { sessions, users } from "@server/db/schema"; import { eq } from "drizzle-orm";
import config from "@server/config"; import config from "@server/config";
import type { RandomReader } from "@oslojs/crypto/random";
import { generateRandomString } from "@oslojs/crypto/random";
const adapter = new DrizzleSQLiteAdapter(db, sessions, users); export const SESSION_COOKIE_NAME = "session";
export const SESSION_COOKIE_EXPIRES = 1000 * 60 * 60 * 24 * 30;
export const SECURE_COOKIES = config.server.secure_cookies;
export const COOKIE_DOMAIN =
"." + new URL(config.app.base_url).hostname.split(".").slice(-2).join(".");
export const lucia = new Lucia(adapter, { export function generateSessionToken(): string {
getUserAttributes: (attributes) => { const bytes = new Uint8Array(20);
return { crypto.getRandomValues(bytes);
email: attributes.email, const token = encodeBase32LowerCaseNoPadding(bytes);
twoFactorEnabled: attributes.twoFactorEnabled, return token;
twoFactorSecret: attributes.twoFactorSecret, }
emailVerified: attributes.emailVerified,
dateCreated: attributes.dateCreated, export async function createSession(
token: string,
userId: string,
): Promise<Session> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
);
const session: Session = {
sessionId: sessionId,
userId,
expiresAt: new Date(Date.now() + SESSION_COOKIE_EXPIRES).getTime(),
}; };
}, await db.insert(sessions).values(session);
sessionCookie: { return session;
name: "session", }
expires: false,
attributes: {
sameSite: "strict",
secure: config.server.secure_cookies || false,
domain:
"." + new URL(config.app.base_url).hostname.split(".").slice(-2).join("."),
},
},
sessionExpiresIn: new TimeSpan(2, "w"),
});
export default lucia; export async function validateSessionToken(
token: string,
): Promise<SessionValidationResult> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
);
const result = await db
.select({ user: users, session: sessions })
.from(sessions)
.innerJoin(users, eq(sessions.userId, users.userId))
.where(eq(sessions.sessionId, sessionId));
if (result.length < 1) {
return { session: null, user: null };
}
const { user, session } = result[0];
if (Date.now() >= session.expiresAt) {
await db
.delete(sessions)
.where(eq(sessions.sessionId, session.sessionId));
return { session: null, user: null };
}
if (Date.now() >= session.expiresAt - (SESSION_COOKIE_EXPIRES / 2)) {
session.expiresAt = new Date(
Date.now() + SESSION_COOKIE_EXPIRES,
).getTime();
await db
.update(sessions)
.set({
expiresAt: session.expiresAt,
})
.where(eq(sessions.sessionId, session.sessionId));
}
return { session, user };
}
declare module "lucia" { export async function invalidateSession(sessionId: string): Promise<void> {
interface Register { await db.delete(sessions).where(eq(sessions.sessionId, sessionId));
Lucia: typeof lucia; }
DatabaseUserAttributes: DatabaseUserAttributes;
export async function invalidateAllSessions(userId: string): Promise<void> {
await db.delete(sessions).where(eq(sessions.userId, userId));
}
export function serializeSessionCookie(token: string): string {
if (SECURE_COOKIES) {
return `${SESSION_COOKIE_NAME}=${token}; HttpOnly; SameSite=Lax; Max-Age=${SESSION_COOKIE_EXPIRES}; Path=/; Secure; Domain=${COOKIE_DOMAIN}`;
} else {
return `${SESSION_COOKIE_NAME}=${token}; HttpOnly; SameSite=Lax; Max-Age=${SESSION_COOKIE_EXPIRES}; Path=/; Domain=${COOKIE_DOMAIN}`;
} }
} }
interface DatabaseUserAttributes { export function createBlankSessionTokenCookie(): string {
email: string; if (SECURE_COOKIES) {
passwordHash: string; return `${SESSION_COOKIE_NAME}=; HttpOnly; SameSite=Lax; Max-Age=0; Path=/; Secure; Domain=${COOKIE_DOMAIN}`;
twoFactorEnabled: boolean; } else {
twoFactorSecret?: string; return `${SESSION_COOKIE_NAME}=; HttpOnly; SameSite=Lax; Max-Age=0; Path=/; Domain=${COOKIE_DOMAIN}`;
emailVerified: boolean;
dateCreated: string;
} }
}
const random: RandomReader = {
read(bytes: Uint8Array): void {
crypto.getRandomValues(bytes);
},
};
export function generateId(length: number): string {
const alphabet = "abcdefghijklmnopqrstuvwxyz0123456789";
return generateRandomString(random, alphabet, length);
}
export function generateIdFromEntropySize(size: number): string {
const buffer = crypto.getRandomValues(new Uint8Array(size));
return encodeBase32LowerCaseNoPadding(buffer);
}
export type SessionValidationResult =
| { session: Session; user: User }
| { session: null; user: null };

View file

@ -1,9 +1,9 @@
import { Request } from "express"; import { Request } from "express";
import { lucia } from "@server/auth"; import { validateSessionToken, SESSION_COOKIE_NAME } from "@server/auth";
export async function verifySession(req: Request) { export async function verifySession(req: Request) {
const res = await lucia.validateSession( const res = await validateSessionToken(
req.cookies[lucia.sessionCookieName], req.cookies[SESSION_COOKIE_NAME] ?? "",
); );
return res; return res;
} }

View file

@ -1,14 +1,12 @@
import { sqliteTable, text, integer } from "drizzle-orm/sqlite-core"; import { sqliteTable, text, integer } from "drizzle-orm/sqlite-core";
import { InferSelectModel } from "drizzle-orm"; import { InferSelectModel } from "drizzle-orm";
// Orgs table
export const orgs = sqliteTable("orgs", { export const orgs = sqliteTable("orgs", {
orgId: integer("orgId").primaryKey({ autoIncrement: true }), orgId: integer("orgId").primaryKey({ autoIncrement: true }),
name: text("name").notNull(), name: text("name").notNull(),
domain: text("domain").notNull(), domain: text("domain").notNull(),
}); });
// Sites table
export const sites = sqliteTable("sites", { export const sites = sqliteTable("sites", {
siteId: integer("siteId").primaryKey({ autoIncrement: true }), siteId: integer("siteId").primaryKey({ autoIncrement: true }),
orgId: integer("orgId").references(() => orgs.orgId, { orgId: integer("orgId").references(() => orgs.orgId, {
@ -25,7 +23,6 @@ export const sites = sqliteTable("sites", {
megabytesOut: integer("bytesOut"), megabytesOut: integer("bytesOut"),
}); });
// Resources table
export const resources = sqliteTable("resources", { export const resources = sqliteTable("resources", {
resourceId: text("resourceId", { length: 2048 }).primaryKey(), resourceId: text("resourceId", { length: 2048 }).primaryKey(),
siteId: integer("siteId").references(() => sites.siteId, { siteId: integer("siteId").references(() => sites.siteId, {
@ -38,7 +35,6 @@ export const resources = sqliteTable("resources", {
subdomain: text("subdomain"), subdomain: text("subdomain"),
}); });
// Targets table
export const targets = sqliteTable("targets", { export const targets = sqliteTable("targets", {
targetId: integer("targetId").primaryKey({ autoIncrement: true }), targetId: integer("targetId").primaryKey({ autoIncrement: true }),
resourceId: text("resourceId").references(() => resources.resourceId, { resourceId: text("resourceId").references(() => resources.resourceId, {
@ -51,7 +47,6 @@ export const targets = sqliteTable("targets", {
enabled: integer("enabled", { mode: "boolean" }).notNull().default(true), enabled: integer("enabled", { mode: "boolean" }).notNull().default(true),
}); });
// Exit Nodes table
export const exitNodes = sqliteTable("exitNodes", { export const exitNodes = sqliteTable("exitNodes", {
exitNodeId: integer("exitNodeId").primaryKey({ autoIncrement: true }), exitNodeId: integer("exitNodeId").primaryKey({ autoIncrement: true }),
name: text("name").notNull(), name: text("name").notNull(),
@ -60,7 +55,6 @@ export const exitNodes = sqliteTable("exitNodes", {
listenPort: integer("listenPort"), listenPort: integer("listenPort"),
}); });
// Routes table
export const routes = sqliteTable("routes", { export const routes = sqliteTable("routes", {
routeId: integer("routeId").primaryKey({ autoIncrement: true }), routeId: integer("routeId").primaryKey({ autoIncrement: true }),
exitNodeId: integer("exitNodeId").references(() => exitNodes.exitNodeId, { exitNodeId: integer("exitNodeId").references(() => exitNodes.exitNodeId, {
@ -69,9 +63,8 @@ export const routes = sqliteTable("routes", {
subnet: text("subnet").notNull(), subnet: text("subnet").notNull(),
}); });
// Users table
export const users = sqliteTable("user", { export const users = sqliteTable("user", {
id: text("id").primaryKey(), // has to be id not userId for lucia userId: text("id").primaryKey(),
email: text("email").notNull().unique(), email: text("email").notNull().unique(),
passwordHash: text("passwordHash").notNull(), passwordHash: text("passwordHash").notNull(),
twoFactorEnabled: integer("twoFactorEnabled", { mode: "boolean" }) twoFactorEnabled: integer("twoFactorEnabled", { mode: "boolean" })
@ -85,26 +78,25 @@ export const users = sqliteTable("user", {
}); });
export const twoFactorBackupCodes = sqliteTable("twoFactorBackupCodes", { export const twoFactorBackupCodes = sqliteTable("twoFactorBackupCodes", {
id: integer("id").primaryKey({ autoIncrement: true }), codeId: integer("id").primaryKey({ autoIncrement: true }),
userId: text("userId") userId: text("userId")
.notNull() .notNull()
.references(() => users.id, { onDelete: "cascade" }), .references(() => users.userId, { onDelete: "cascade" }),
codeHash: text("codeHash").notNull(), codeHash: text("codeHash").notNull(),
}); });
// Sessions table
export const sessions = sqliteTable("session", { export const sessions = sqliteTable("session", {
id: text("id").primaryKey(), // has to be id not sessionId for lucia sessionId: text("id").primaryKey(),
userId: text("userId") userId: text("userId")
.notNull() .notNull()
.references(() => users.id, { onDelete: "cascade" }), .references(() => users.userId, { onDelete: "cascade" }),
expiresAt: integer("expiresAt").notNull(), expiresAt: integer("expiresAt").notNull(),
}); });
export const userOrgs = sqliteTable("userOrgs", { export const userOrgs = sqliteTable("userOrgs", {
userId: text("userId") userId: text("userId")
.notNull() .notNull()
.references(() => users.id), .references(() => users.userId),
orgId: integer("orgId") orgId: integer("orgId")
.notNull() .notNull()
.references(() => orgs.orgId), .references(() => orgs.orgId),
@ -114,20 +106,20 @@ export const userOrgs = sqliteTable("userOrgs", {
}); });
export const emailVerificationCodes = sqliteTable("emailVerificationCodes", { export const emailVerificationCodes = sqliteTable("emailVerificationCodes", {
id: integer("id").primaryKey({ autoIncrement: true }), codeId: integer("id").primaryKey({ autoIncrement: true }),
userId: text("userId") userId: text("userId")
.notNull() .notNull()
.references(() => users.id, { onDelete: "cascade" }), .references(() => users.userId, { onDelete: "cascade" }),
email: text("email").notNull(), email: text("email").notNull(),
code: text("code").notNull(), code: text("code").notNull(),
expiresAt: integer("expiresAt").notNull(), expiresAt: integer("expiresAt").notNull(),
}); });
export const passwordResetTokens = sqliteTable("passwordResetTokens", { export const passwordResetTokens = sqliteTable("passwordResetTokens", {
id: integer("id").primaryKey({ autoIncrement: true }), tokenId: integer("id").primaryKey({ autoIncrement: true }),
userId: text("userId") userId: text("userId")
.notNull() .notNull()
.references(() => users.id, { onDelete: "cascade" }), .references(() => users.userId, { onDelete: "cascade" }),
tokenHash: text("tokenHash").notNull(), tokenHash: text("tokenHash").notNull(),
expiresAt: integer("expiresAt").notNull(), expiresAt: integer("expiresAt").notNull(),
}); });
@ -140,7 +132,9 @@ export const actions = sqliteTable("actions", {
export const roles = sqliteTable("roles", { export const roles = sqliteTable("roles", {
roleId: integer("roleId").primaryKey({ autoIncrement: true }), roleId: integer("roleId").primaryKey({ autoIncrement: true }),
orgId: integer("orgId").references(() => orgs.orgId, { onDelete: "cascade" }), orgId: integer("orgId").references(() => orgs.orgId, {
onDelete: "cascade",
}),
isSuperuserRole: integer("isSuperuserRole", { mode: "boolean" }), isSuperuserRole: integer("isSuperuserRole", { mode: "boolean" }),
name: text("name").notNull(), name: text("name").notNull(),
description: text("description"), description: text("description"),
@ -161,7 +155,7 @@ export const roleActions = sqliteTable("roleActions", {
export const userActions = sqliteTable("userActions", { export const userActions = sqliteTable("userActions", {
userId: text("userId") userId: text("userId")
.notNull() .notNull()
.references(() => users.id, { onDelete: "cascade" }), .references(() => users.userId, { onDelete: "cascade" }),
actionId: text("actionId") actionId: text("actionId")
.notNull() .notNull()
.references(() => actions.actionId, { onDelete: "cascade" }), .references(() => actions.actionId, { onDelete: "cascade" }),
@ -182,7 +176,7 @@ export const roleSites = sqliteTable("roleSites", {
export const userSites = sqliteTable("userSites", { export const userSites = sqliteTable("userSites", {
userId: text("userId") userId: text("userId")
.notNull() .notNull()
.references(() => users.id, { onDelete: "cascade" }), .references(() => users.userId, { onDelete: "cascade" }),
siteId: integer("siteId") siteId: integer("siteId")
.notNull() .notNull()
.references(() => sites.siteId, { onDelete: "cascade" }), .references(() => sites.siteId, { onDelete: "cascade" }),
@ -200,7 +194,7 @@ export const roleResources = sqliteTable("roleResources", {
export const userResources = sqliteTable("userResources", { export const userResources = sqliteTable("userResources", {
userId: text("userId") userId: text("userId")
.notNull() .notNull()
.references(() => users.id, { onDelete: "cascade" }), .references(() => users.userId, { onDelete: "cascade" }),
resourceId: text("resourceId") resourceId: text("resourceId")
.notNull() .notNull()
.references(() => resources.resourceId, { onDelete: "cascade" }), .references(() => resources.resourceId, { onDelete: "cascade" }),
@ -216,7 +210,6 @@ export const limitsTable = sqliteTable("limits", {
description: text("description"), description: text("description"),
}); });
// Define the model types for type inference
export type Org = InferSelectModel<typeof orgs>; export type Org = InferSelectModel<typeof orgs>;
export type User = InferSelectModel<typeof users>; export type User = InferSelectModel<typeof users>;
export type Site = InferSelectModel<typeof sites>; export type Site = InferSelectModel<typeof sites>;

View file

@ -20,7 +20,7 @@ export const verifySessionMiddleware = async (
const existingUser = await db const existingUser = await db
.select() .select()
.from(users) .from(users)
.where(eq(users.id, user.id)); .where(eq(users.userId, user.userId));
if (!existingUser || !existingUser[0]) { if (!existingUser || !existingUser[0]) {
return next( return next(

View file

@ -20,7 +20,7 @@ export const verifySessionUserMiddleware = async (
const existingUser = await db const existingUser = await db
.select() .select()
.from(users) .from(users)
.where(eq(users.id, user.id)); .where(eq(users.userId, user.userId));
if (!existingUser || !existingUser[0]) { if (!existingUser || !existingUser[0]) {
return next( return next(

View file

@ -2,7 +2,7 @@ import { Request, Response, NextFunction } from "express";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import lucia, { unauthorized } from "@server/auth"; import { unauthorized, invalidateAllSessions } from "@server/auth";
import { z } from "zod"; import { z } from "zod";
import { db } from "@server/db"; import { db } from "@server/db";
import { User, users } from "@server/db/schema"; import { User, users } from "@server/db/schema";
@ -74,7 +74,7 @@ export async function changePassword(
const validOTP = await verifyTotpCode( const validOTP = await verifyTotpCode(
code!, code!,
user.twoFactorSecret!, user.twoFactorSecret!,
user.id, user.userId,
); );
if (!validOTP) { if (!validOTP) {
@ -94,9 +94,9 @@ export async function changePassword(
.set({ .set({
passwordHash: hash, passwordHash: hash,
}) })
.where(eq(users.id, user.id)); .where(eq(users.userId, user.userId));
await lucia.invalidateUserSessions(user.id); await invalidateAllSessions(user.userId);
// TODO: send email to user confirming password change // TODO: send email to user confirming password change

View file

@ -69,7 +69,7 @@ export async function disable2fa(
const validOTP = await verifyTotpCode( const validOTP = await verifyTotpCode(
code, code,
user.twoFactorSecret!, user.twoFactorSecret!,
user.id, user.userId,
); );
if (!validOTP) { if (!validOTP) {
@ -84,11 +84,11 @@ export async function disable2fa(
await db await db
.update(users) .update(users)
.set({ twoFactorEnabled: false }) .set({ twoFactorEnabled: false })
.where(eq(users.id, user.id)); .where(eq(users.userId, user.userId));
await db await db
.delete(twoFactorBackupCodes) .delete(twoFactorBackupCodes)
.where(eq(twoFactorBackupCodes.userId, user.id)); .where(eq(twoFactorBackupCodes.userId, user.userId));
// TODO: send email to user confirming two-factor authentication is disabled // TODO: send email to user confirming two-factor authentication is disabled

View file

@ -6,7 +6,7 @@ import createHttpError from 'http-errors';
import HttpCode from '@server/types/HttpCode'; import HttpCode from '@server/types/HttpCode';
export async function getUserOrgs(req: Request, res: Response, next: NextFunction) { export async function getUserOrgs(req: Request, res: Response, next: NextFunction) {
const userId = req.user?.id; // Assuming you have user information in the request const userId = req.user?.userId; // Assuming you have user information in the request
if (!userId) { if (!userId) {
return next(createHttpError(HttpCode.UNAUTHORIZED, 'User not authenticated')); return next(createHttpError(HttpCode.UNAUTHORIZED, 'User not authenticated'));

View file

@ -1,5 +1,10 @@
import { verify } from "@node-rs/argon2"; import { verify } from "@node-rs/argon2";
import lucia, { verifySession } from "@server/auth"; import {
createSession,
generateSessionToken,
serializeSessionCookie,
verifySession,
} from "@server/auth";
import db from "@server/db"; import db from "@server/db";
import { users } from "@server/db/schema"; import { users } from "@server/db/schema";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@ -102,7 +107,7 @@ export async function login(
const validOTP = await verifyTotpCode( const validOTP = await verifyTotpCode(
code, code,
existingUser.twoFactorSecret!, existingUser.twoFactorSecret!,
existingUser.id, existingUser.userId,
); );
if (!validOTP) { if (!validOTP) {
@ -115,13 +120,11 @@ export async function login(
} }
} }
const session = await lucia.createSession(existingUser.id, {}); const token = generateSessionToken();
const cookie = lucia.createSessionCookie(session.id).serialize(); await createSession(token, existingUser.userId);
const cookie = serializeSessionCookie(token);
res.appendHeader( res.appendHeader("Set-Cookie", cookie);
"Set-Cookie",
cookie
);
if (!existingUser.emailVerified) { if (!existingUser.emailVerified) {
return response<LoginResponse>(res, { return response<LoginResponse>(res, {

View file

@ -1,16 +1,20 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { lucia } from "@server/auth";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import response from "@server/utils/response"; import response from "@server/utils/response";
import logger from "@server/logger"; import logger from "@server/logger";
import {
createBlankSessionTokenCookie,
invalidateSession,
SESSION_COOKIE_NAME,
} from "@server/auth";
export async function logout( export async function logout(
req: Request, req: Request,
res: Response, res: Response,
next: NextFunction, next: NextFunction,
): Promise<any> { ): Promise<any> {
const sessionId = req.cookies[lucia.sessionCookieName]; const sessionId = req.cookies[SESSION_COOKIE_NAME];
if (!sessionId) { if (!sessionId) {
return next( return next(
@ -22,11 +26,8 @@ export async function logout(
} }
try { try {
await lucia.invalidateSession(sessionId); await invalidateSession(sessionId);
res.setHeader( res.setHeader("Set-Cookie", createBlankSessionTokenCookie());
"Set-Cookie",
lucia.createBlankSessionCookie().serialize(),
);
return response<null>(res, { return response<null>(res, {
data: null, data: null,

View file

@ -26,7 +26,7 @@ export async function requestEmailVerificationCode(
); );
} }
await sendEmailVerificationCode(user.email, user.id); await sendEmailVerificationCode(user.email, user.userId);
return response<RequestEmailVerificationCodeResponse>(res, { return response<RequestEmailVerificationCodeResponse>(res, {
data: { data: {

View file

@ -8,10 +8,11 @@ import { db } from "@server/db";
import { passwordResetTokens, users } from "@server/db/schema"; import { passwordResetTokens, users } from "@server/db/schema";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import { sha256 } from "oslo/crypto"; import { sha256 } from "oslo/crypto";
import { generateIdFromEntropySize, TimeSpan } from "lucia";
import { encodeHex } from "oslo/encoding"; import { encodeHex } from "oslo/encoding";
import { createDate } from "oslo"; import { createDate } from "oslo";
import logger from "@server/logger"; import logger from "@server/logger";
import { generateIdFromEntropySize } from "@server/auth";
import { TimeSpan } from "oslo";
export const requestPasswordResetBody = z.object({ export const requestPasswordResetBody = z.object({
email: z.string().email(), email: z.string().email(),
@ -58,7 +59,7 @@ export async function requestPasswordReset(
await db await db
.delete(passwordResetTokens) .delete(passwordResetTokens)
.where(eq(passwordResetTokens.userId, existingUser[0].id)); .where(eq(passwordResetTokens.userId, existingUser[0].userId));
const token = generateIdFromEntropySize(25); const token = generateIdFromEntropySize(25);
const tokenHash = encodeHex( const tokenHash = encodeHex(
@ -66,7 +67,7 @@ export async function requestPasswordReset(
); );
await db.insert(passwordResetTokens).values({ await db.insert(passwordResetTokens).values({
userId: existingUser[0].id, userId: existingUser[0].userId,
tokenHash, tokenHash,
expiresAt: createDate(new TimeSpan(2, "h")).getTime(), expiresAt: createDate(new TimeSpan(2, "h")).getTime(),
}); });
@ -89,7 +90,7 @@ export async function requestPasswordReset(
return next( return next(
createHttpError( createHttpError(
HttpCode.INTERNAL_SERVER_ERROR, HttpCode.INTERNAL_SERVER_ERROR,
"Failed to process password reset request" "Failed to process password reset request",
), ),
); );
} }

View file

@ -72,7 +72,7 @@ export async function requestTotpSecret(
.set({ .set({
twoFactorSecret: secret, twoFactorSecret: secret,
}) })
.where(eq(users.id, user.id)); .where(eq(users.userId, user.userId));
return response<RequestTotpSecretResponse>(res, { return response<RequestTotpSecretResponse>(res, {
data: { data: {

View file

@ -13,7 +13,7 @@ import { verifyTotpCode } from "@server/auth/2fa";
import { passwordSchema } from "@server/auth/passwordSchema"; import { passwordSchema } from "@server/auth/passwordSchema";
import { encodeHex } from "oslo/encoding"; import { encodeHex } from "oslo/encoding";
import { isWithinExpirationDate } from "oslo"; import { isWithinExpirationDate } from "oslo";
import lucia from "@server/auth"; import { invalidateAllSessions } from "@server/auth";
export const resetPasswordBody = z.object({ export const resetPasswordBody = z.object({
token: z.string(), token: z.string(),
@ -71,7 +71,7 @@ export async function resetPassword(
const user = await db const user = await db
.select() .select()
.from(users) .from(users)
.where(eq(users.id, resetRequest[0].userId)); .where(eq(users.userId, resetRequest[0].userId));
if (!user || !user.length) { if (!user || !user.length) {
return next( return next(
@ -96,7 +96,7 @@ export async function resetPassword(
const validOTP = await verifyTotpCode( const validOTP = await verifyTotpCode(
code!, code!,
user[0].twoFactorSecret!, user[0].twoFactorSecret!,
user[0].id, user[0].userId,
); );
if (!validOTP) { if (!validOTP) {
@ -111,12 +111,12 @@ export async function resetPassword(
const passwordHash = await hashPassword(newPassword); const passwordHash = await hashPassword(newPassword);
await lucia.invalidateUserSessions(resetRequest[0].userId); await invalidateAllSessions(resetRequest[0].userId);
await db await db
.update(users) .update(users)
.set({ passwordHash }) .set({ passwordHash })
.where(eq(users.id, resetRequest[0].userId)); .where(eq(users.userId, resetRequest[0].userId));
await db await db
.delete(passwordResetTokens) .delete(passwordResetTokens)

View file

@ -3,9 +3,7 @@ import db from "@server/db";
import { hash } from "@node-rs/argon2"; import { hash } from "@node-rs/argon2";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import { z } from "zod"; import { z } from "zod";
import { generateId } from "lucia";
import { users } from "@server/db/schema"; import { users } from "@server/db/schema";
import lucia from "@server/auth";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import response from "@server/utils/response"; import response from "@server/utils/response";
@ -14,6 +12,12 @@ import { sendEmailVerificationCode } from "./sendEmailVerificationCode";
import { passwordSchema } from "@server/auth/passwordSchema"; import { passwordSchema } from "@server/auth/passwordSchema";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import moment from "moment"; import moment from "moment";
import {
createSession,
generateId,
generateSessionToken,
serializeSessionCookie,
} from "@server/auth";
export const signupBodySchema = z.object({ export const signupBodySchema = z.object({
email: z.string().email(), email: z.string().email(),
@ -85,22 +89,21 @@ export async function signup(
); );
} else { } else {
// If the user was created more than 2 hours ago, we want to delete the old user and create a new one // If the user was created more than 2 hours ago, we want to delete the old user and create a new one
await db.delete(users).where(eq(users.id, user.id)); await db.delete(users).where(eq(users.userId, user.userId));
} }
} }
await db.insert(users).values({ await db.insert(users).values({
id: userId, userId: userId,
email: email, email: email,
passwordHash, passwordHash,
dateCreated: moment().toISOString(), dateCreated: moment().toISOString(),
}); });
const session = await lucia.createSession(userId, {}); const token = generateSessionToken();
res.appendHeader( await createSession(token, userId);
"Set-Cookie", const cookie = serializeSessionCookie(token);
lucia.createSessionCookie(session.id).serialize(), res.appendHeader("Set-Cookie", cookie);
);
sendEmailVerificationCode(email, userId); sendEmailVerificationCode(email, userId);

View file

@ -51,14 +51,14 @@ export async function verifyEmail(
if (valid) { if (valid) {
await db await db
.delete(emailVerificationCodes) .delete(emailVerificationCodes)
.where(eq(emailVerificationCodes.userId, user.id)); .where(eq(emailVerificationCodes.userId, user.userId));
await db await db
.update(users) .update(users)
.set({ .set({
emailVerified: true, emailVerified: true,
}) })
.where(eq(users.id, user.id)); .where(eq(users.userId, user.userId));
} else { } else {
return next( return next(
createHttpError( createHttpError(
@ -93,7 +93,7 @@ async function isValidCode(user: User, code: string): Promise<boolean> {
const codeRecord = await db const codeRecord = await db
.select() .select()
.from(emailVerificationCodes) .from(emailVerificationCodes)
.where(eq(emailVerificationCodes.userId, user.id)) .where(eq(emailVerificationCodes.userId, user.userId))
.limit(1); .limit(1);
if (user.email !== codeRecord[0].email) { if (user.email !== codeRecord[0].email) {

View file

@ -7,7 +7,7 @@ import HttpCode from '@server/types/HttpCode';
import { AuthenticatedRequest } from '@server/types/Auth'; import { AuthenticatedRequest } from '@server/types/Auth';
export function verifyOrgAccess(req: Request, res: Response, next: NextFunction) { export function verifyOrgAccess(req: Request, res: Response, next: NextFunction) {
const userId = req.user!.id; // Assuming you have user information in the request const userId = req.user!.userId; // Assuming you have user information in the request
const orgId = parseInt(req.params.orgId); const orgId = parseInt(req.params.orgId);
if (!userId) { if (!userId) {

View file

@ -6,7 +6,7 @@ import createHttpError from 'http-errors';
import HttpCode from '@server/types/HttpCode'; import HttpCode from '@server/types/HttpCode';
export async function verifyResourceAccess(req: Request, res: Response, next: NextFunction) { export async function verifyResourceAccess(req: Request, res: Response, next: NextFunction) {
const userId = req.user!.id; // Assuming you have user information in the request const userId = req.user!.userId; // Assuming you have user information in the request
const resourceId = req.params.resourceId || req.body.resourceId || req.query.resourceId; const resourceId = req.params.resourceId || req.body.resourceId || req.query.resourceId;
if (!userId) { if (!userId) {

View file

@ -7,7 +7,7 @@ import HttpCode from '@server/types/HttpCode';
import logger from '@server/logger'; import logger from '@server/logger';
export async function verifyRoleAccess(req: Request, res: Response, next: NextFunction) { export async function verifyRoleAccess(req: Request, res: Response, next: NextFunction) {
const userId = req.user?.id; // Assuming you have user information in the request const userId = req.user?.userId; // Assuming you have user information in the request
const roleId = parseInt(req.params.roleId || req.body.roleId || req.query.roleId); const roleId = parseInt(req.params.roleId || req.body.roleId || req.query.roleId);
if (!userId) { if (!userId) {

View file

@ -6,7 +6,7 @@ import createHttpError from 'http-errors';
import HttpCode from '@server/types/HttpCode'; import HttpCode from '@server/types/HttpCode';
export async function verifySiteAccess(req: Request, res: Response, next: NextFunction) { export async function verifySiteAccess(req: Request, res: Response, next: NextFunction) {
const userId = req.user!.id; // Assuming you have user information in the request const userId = req.user!.userId; // Assuming you have user information in the request
const siteId = parseInt(req.params.siteId || req.body.siteId || req.query.siteId); const siteId = parseInt(req.params.siteId || req.body.siteId || req.query.siteId);
if (!userId) { if (!userId) {

View file

@ -7,7 +7,7 @@ import HttpCode from '@server/types/HttpCode';
import logger from '@server/logger'; import logger from '@server/logger';
export async function verifySuperuser(req: Request, res: Response, next: NextFunction) { export async function verifySuperuser(req: Request, res: Response, next: NextFunction) {
const userId = req.user?.id; // Assuming you have user information in the request const userId = req.user?.userId; // Assuming you have user information in the request
const orgId = req.userOrgId; const orgId = req.userOrgId;
if (!userId) { if (!userId) {

View file

@ -6,7 +6,7 @@ import createHttpError from 'http-errors';
import HttpCode from '@server/types/HttpCode'; import HttpCode from '@server/types/HttpCode';
export async function verifyTargetAccess(req: Request, res: Response, next: NextFunction) { export async function verifyTargetAccess(req: Request, res: Response, next: NextFunction) {
const userId = req.user!.id; // Assuming you have user information in the request const userId = req.user!.userId; // Assuming you have user information in the request
const targetId = parseInt(req.params.targetId); const targetId = parseInt(req.params.targetId);
if (!userId) { if (!userId) {

View file

@ -61,7 +61,7 @@ export async function verifyTotp(
} }
try { try {
const valid = await verifyTotpCode(code, user.twoFactorSecret, user.id); const valid = await verifyTotpCode(code, user.twoFactorSecret, user.userId);
let codes; let codes;
if (valid) { if (valid) {
@ -69,7 +69,7 @@ export async function verifyTotp(
await db await db
.update(users) .update(users)
.set({ twoFactorEnabled: true }) .set({ twoFactorEnabled: true })
.where(eq(users.id, user.id)); .where(eq(users.userId, user.userId));
const backupCodes = await generateBackupCodes(); const backupCodes = await generateBackupCodes();
codes = backupCodes; codes = backupCodes;
@ -77,7 +77,7 @@ export async function verifyTotp(
const hash = await hashPassword(code); const hash = await hashPassword(code);
await db.insert(twoFactorBackupCodes).values({ await db.insert(twoFactorBackupCodes).values({
userId: user.id, userId: user.userId,
codeHash: hash, codeHash: hash,
}); });
} }

View file

@ -6,7 +6,7 @@ import createHttpError from 'http-errors';
import HttpCode from '@server/types/HttpCode'; import HttpCode from '@server/types/HttpCode';
export async function verifyUserAccess(req: Request, res: Response, next: NextFunction) { export async function verifyUserAccess(req: Request, res: Response, next: NextFunction) {
const userId = req.user!.id; // Assuming you have user information in the request const userId = req.user!.userId; // Assuming you have user information in the request
const reqUserId = req.params.userId || req.body.userId || req.query.userId; const reqUserId = req.params.userId || req.body.userId || req.query.userId;
if (!userId) { if (!userId) {

View file

@ -1,11 +1,10 @@
import lucia from "@server/auth";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import { NextFunction, Request, Response } from "express"; import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import { z } from "zod"; import { z } from "zod";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { response } from "@server/utils/response"; import { response } from "@server/utils/response";
import logger from "@server/logger"; import { validateSessionToken } from "@server/auth";
export const verifyUserBody = z.object({ export const verifyUserBody = z.object({
sessionId: z.string(), sessionId: z.string(),
@ -36,7 +35,7 @@ export async function verifyUser(
const { sessionId } = parsedBody.data; const { sessionId } = parsedBody.data;
try { try {
const { session, user } = await lucia.validateSession(sessionId); const { session, user } = await validateSessionToken(sessionId);
if (!session || !user) { if (!session || !user) {
return next( return next(

View file

@ -95,7 +95,7 @@ export async function createResource(req: Request, res: Response, next: NextFunc
if (req.userOrgRoleId != superuserRole[0].roleId) { if (req.userOrgRoleId != superuserRole[0].roleId) {
// make sure the user can access the resource // make sure the user can access the resource
await db.insert(userResources).values({ await db.insert(userResources).values({
userId: req.user?.id!, userId: req.user?.userId!,
resourceId: newResource[0].resourceId, resourceId: newResource[0].resourceId,
}); });
} }

View file

@ -57,7 +57,7 @@ export async function listResources(req: RequestWithOrgAndRole, res: Response, n
.fullJoin(roleResources, eq(userResources.resourceId, roleResources.resourceId)) .fullJoin(roleResources, eq(userResources.resourceId, roleResources.resourceId))
.where( .where(
or( or(
eq(userResources.userId, req.user!.id), eq(userResources.userId, req.user!.userId),
eq(roleResources.roleId, req.userOrgRoleId!) eq(roleResources.roleId, req.userOrgRoleId!)
) )
); );

View file

@ -93,7 +93,7 @@ export async function createSite(req: Request, res: Response, next: NextFunction
if (req.userOrgRoleId != superuserRole[0].roleId) { if (req.userOrgRoleId != superuserRole[0].roleId) {
// make sure the user can access the site // make sure the user can access the site
db.insert(userSites).values({ db.insert(userSites).values({
userId: req.user?.id!, userId: req.user?.userId!,
siteId: newSite[0].siteId, siteId: newSite[0].siteId,
}); });
} }

View file

@ -49,7 +49,7 @@ export async function listSites(req: Request, res: Response, next: NextFunction)
.fullJoin(roleSites, eq(userSites.siteId, roleSites.siteId)) .fullJoin(roleSites, eq(userSites.siteId, roleSites.siteId))
.where( .where(
or( or(
eq(userSites.userId, req.user!.id), eq(userSites.userId, req.user!.userId),
eq(roleSites.roleId, req.userOrgRoleId!) eq(roleSites.roleId, req.userOrgRoleId!)
) )
); );

View file

@ -36,7 +36,7 @@ export async function addUserAction(req: Request, res: Response, next: NextFunct
} }
// Check if the user exists // Check if the user exists
const user = await db.select().from(users).where(eq(users.id, userId)).limit(1); const user = await db.select().from(users).where(eq(users.userId, userId)).limit(1);
if (user.length === 0) { if (user.length === 0) {
return next(createHttpError(HttpCode.NOT_FOUND, `User with ID ${userId} not found`)); return next(createHttpError(HttpCode.NOT_FOUND, `User with ID ${userId} not found`));
} }

View file

@ -51,7 +51,7 @@ export async function addUserOrg(req: Request, res: Response, next: NextFunction
} }
// Check if the user exists // Check if the user exists
const user = await db.select().from(users).where(eq(users.id, userId)).limit(1); const user = await db.select().from(users).where(eq(users.userId, userId)).limit(1);
if (user.length === 0) { if (user.length === 0) {
return next(createHttpError(HttpCode.NOT_FOUND, `User with ID ${userId} not found`)); return next(createHttpError(HttpCode.NOT_FOUND, `User with ID ${userId} not found`));
} }

View file

@ -21,7 +21,7 @@ export async function getUser(
next: NextFunction, next: NextFunction,
): Promise<any> { ): Promise<any> {
try { try {
const userId = req.user?.id; const userId = req.user?.userId;
if (!userId) { if (!userId) {
return next( return next(
@ -32,7 +32,7 @@ export async function getUser(
const user = await db const user = await db
.select() .select()
.from(users) .from(users)
.where(eq(users.id, userId)) .where(eq(users.userId, userId))
.limit(1); .limit(1);
if (user.length === 0) { if (user.length === 0) {

View file

@ -52,7 +52,7 @@ export async function listUsers(req: Request, res: Response, next: NextFunction)
// Query to join users, userOrgs, and roles tables // Query to join users, userOrgs, and roles tables
const usersWithRoles = await db const usersWithRoles = await db
.select({ .select({
id: users.id, id: users.userId,
email: users.email, email: users.email,
emailVerified: users.emailVerified, emailVerified: users.emailVerified,
dateCreated: users.dateCreated, dateCreated: users.dateCreated,
@ -61,7 +61,7 @@ export async function listUsers(req: Request, res: Response, next: NextFunction)
roleName: roles.name, roleName: roles.name,
}) })
.from(users) .from(users)
.leftJoin(userOrgs, sql`${users.id} = ${userOrgs.userId}`) .leftJoin(userOrgs, sql`${users.userId} = ${userOrgs.userId}`)
.leftJoin(roles, sql`${userOrgs.roleId} = ${roles.roleId}`) .leftJoin(roles, sql`${userOrgs.roleId} = ${roles.roleId}`)
.where(sql`${userOrgs.orgId} = ${orgId}`) .where(sql`${userOrgs.orgId} = ${orgId}`)
.limit(limit) .limit(limit)

View file

@ -1,6 +1,6 @@
import { Request } from "express"; import { Request } from "express";
import { User } from "@server/db/schema"; import { User } from "@server/db/schema";
import { Session } from "lucia"; import { Session } from "@server/db/schema";
export interface AuthenticatedRequest extends Request { export interface AuthenticatedRequest extends Request {
user: User; user: User;

View file

@ -9,8 +9,6 @@ export default async function Page() {
redirect("/auth/login"); redirect("/auth/login");
} }
console.log(user);
return ( return (
<> <>
<LandingProvider user={user}> <LandingProvider user={user}>