Basic websocket and auth for newt

This commit is contained in:
Owen Schwartz 2024-11-10 17:08:11 -05:00
parent 231e1d2e2d
commit e5e78ff1bf
No known key found for this signature in database
GPG key ID: 8271FDFFD9E0CCBD
7 changed files with 328 additions and 95 deletions

80
server/auth/newt.ts Normal file
View file

@ -0,0 +1,80 @@
export * from "./verifySession";
export * from "./unauthorizedResponse";
import {
encodeHexLowerCase,
} from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { Newt, newts, newtSessions, NewtSession } from "@server/db/schema";
import db from "@server/db";
import { eq } from "drizzle-orm";
import config from "@server/config";
export const SESSION_COOKIE_NAME = "session";
export const SESSION_COOKIE_EXPIRES = 1000 * 60 * 60 * 24 * 30;
export const SECURE_COOKIES = config.server.secure_cookies;
export const COOKIE_DOMAIN =
"." + new URL(config.app.base_url).hostname.split(".").slice(-2).join(".");
export async function createNewtSession(
token: string,
newtId: string,
): Promise<NewtSession> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
);
const session: NewtSession = {
sessionId: sessionId,
newtId,
expiresAt: new Date(Date.now() + SESSION_COOKIE_EXPIRES).getTime(),
};
await db.insert(newtSessions).values(session);
return session;
}
export async function validateNewtSessionToken(
token: string,
): Promise<SessionValidationResult> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
);
const result = await db
.select({ newt: newts, session: newtSessions })
.from(newtSessions)
.innerJoin(newts, eq(newtSessions.newtId, newts.newtId))
.where(eq(newtSessions.sessionId, sessionId));
if (result.length < 1) {
return { session: null, newt: null };
}
const { newt, session } = result[0];
if (Date.now() >= session.expiresAt) {
await db
.delete(newtSessions)
.where(eq(newtSessions.sessionId, session.sessionId));
return { session: null, newt: null };
}
if (Date.now() >= session.expiresAt - (SESSION_COOKIE_EXPIRES / 2)) {
session.expiresAt = new Date(
Date.now() + SESSION_COOKIE_EXPIRES,
).getTime();
await db
.update(newtSessions)
.set({
expiresAt: session.expiresAt,
})
.where(eq(newtSessions.sessionId, session.sessionId));
}
return { session, newt };
}
export async function invalidateNewtSession(sessionId: string): Promise<void> {
await db.delete(newtSessions).where(eq(newtSessions.sessionId, sessionId));
}
export async function invalidateAllNewtSessions(newtId: string): Promise<void> {
await db.delete(newtSessions).where(eq(newtSessions.newtId, newtId));
}
export type SessionValidationResult =
| { session: NewtSession; newt: Newt }
| { session: null; newt: null };

View file

@ -73,6 +73,12 @@ export const users = sqliteTable("user", {
dateCreated: text("dateCreated").notNull(), dateCreated: text("dateCreated").notNull(),
}); });
export const newts = sqliteTable("newt", {
newtId: text("id").primaryKey(),
secretHash: text("secretHash").notNull(),
dateCreated: text("dateCreated").notNull(),
});
export const twoFactorBackupCodes = sqliteTable("twoFactorBackupCodes", { export const twoFactorBackupCodes = sqliteTable("twoFactorBackupCodes", {
codeId: integer("id").primaryKey({ autoIncrement: true }), codeId: integer("id").primaryKey({ autoIncrement: true }),
userId: text("userId") userId: text("userId")
@ -89,6 +95,14 @@ export const sessions = sqliteTable("session", {
expiresAt: integer("expiresAt").notNull(), expiresAt: integer("expiresAt").notNull(),
}); });
export const newtSessions = sqliteTable("newtSession", {
sessionId: text("id").primaryKey(),
newtId: text("newtId")
.notNull()
.references(() => newts.newtId, { onDelete: "cascade" }),
expiresAt: integer("expiresAt").notNull(),
});
export const userOrgs = sqliteTable("userOrgs", { export const userOrgs = sqliteTable("userOrgs", {
userId: text("userId") userId: text("userId")
.notNull() .notNull()
@ -227,6 +241,8 @@ export type Resource = InferSelectModel<typeof resources>;
export type ExitNode = InferSelectModel<typeof exitNodes>; export type ExitNode = InferSelectModel<typeof exitNodes>;
export type Target = InferSelectModel<typeof targets>; export type Target = InferSelectModel<typeof targets>;
export type Session = InferSelectModel<typeof sessions>; export type Session = InferSelectModel<typeof sessions>;
export type Newt = InferSelectModel<typeof newts>;
export type NewtSession = InferSelectModel<typeof newtSessions>;
export type EmailVerificationCode = InferSelectModel< export type EmailVerificationCode = InferSelectModel<
typeof emailVerificationCodes typeof emailVerificationCodes
>; >;

View file

@ -12,6 +12,7 @@ export * from "./verifyTargetAccess";
export * from "./verifyRoleAccess"; export * from "./verifyRoleAccess";
export * from "./verifyUserAccess"; export * from "./verifyUserAccess";
export * from "./verifyAdmin"; export * from "./verifyAdmin";
// export * from "./verifySuperUser";
export * from "./verifyEmail"; export * from "./verifyEmail";
export * from "./requestEmailVerificationCode"; export * from "./requestEmailVerificationCode";
export * from "./changePassword"; export * from "./changePassword";

View file

@ -0,0 +1,115 @@
import { verify } from "@node-rs/argon2";
import {
createSession,
generateSessionToken,
verifySession,
} from "@server/auth";
import db from "@server/db";
import { newts } from "@server/db/schema";
import HttpCode from "@server/types/HttpCode";
import response from "@server/utils/response";
import { eq } from "drizzle-orm";
import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import config from "@server/config";
import { validateNewtSessionToken } from "@server/auth/newt";
export const newtGetTokenBodySchema = z.object({
newtId: z.string().email(),
secret: z.string(),
token: z.string().optional(),
});
export type NewtGetTokenBody = z.infer<typeof newtGetTokenBodySchema>;
export async function newtGetToken(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
const parsedBody = newtGetTokenBodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { newtId, secret, token } = parsedBody.data;
try {
if (token) {
const { session, newt } = await validateNewtSessionToken(
token
);
if (session) {
return response<null>(res, {
data: null,
success: true,
error: false,
message: "Token session already valid",
status: HttpCode.OK,
});
}
}
const existingNewtRes = await db
.select()
.from(newts)
.where(eq(newts.newtId, newtId));
if (!existingNewtRes || !existingNewtRes.length) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"No newt found with that newtId"
)
);
}
const existingNewt = existingNewtRes[0];
const validSecret = await verify(
existingNewt.secretHash,
secret,
{
memoryCost: 19456,
timeCost: 2,
outputLen: 32,
parallelism: 1,
}
);
if (!validSecret) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Secret is incorrect"
)
);
}
const resToken = generateSessionToken();
await createSession(resToken, existingNewt.newtId);
return response<{ token: string }>(res, {
data: {
token: resToken
},
success: true,
error: false,
message: "Token created successfully",
status: HttpCode.OK,
});
} catch (e) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to authenticate newt"
)
);
}
}

View file

@ -4,8 +4,8 @@ import db from "@server/db";
import { users, emailVerificationCodes } from "@server/db/schema"; import { users, emailVerificationCodes } from "@server/db/schema";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import { sendEmail } from "@server/emails"; import { sendEmail } from "@server/emails";
import VerifyEmail from "@server/emails/templates/VerifyEmailCode";
import config from "@server/config"; import config from "@server/config";
import VerifyEmail from "@server/emails/templates/verifyEmailCode";
export async function sendEmailVerificationCode( export async function sendEmailVerificationCode(
email: string, email: string,

View file

@ -19,7 +19,6 @@ import {
verifyResourceAccess, verifyResourceAccess,
verifyTargetAccess, verifyTargetAccess,
verifyRoleAccess, verifyRoleAccess,
verifyUserInRole,
verifyUserAccess, verifyUserAccess,
} from "./auth"; } from "./auth";
import { verifyUserHasAction } from "./auth/verifyUserHasAction"; import { verifyUserHasAction } from "./auth/verifyUserHasAction";

View file

@ -3,20 +3,25 @@ import { Server as HttpServer } from 'http';
import { WebSocket, WebSocketServer } from 'ws'; import { WebSocket, WebSocketServer } from 'ws';
import { IncomingMessage } from 'http'; import { IncomingMessage } from 'http';
import { Socket } from 'net'; import { Socket } from 'net';
import { Newt, newts, NewtSession } from '@server/db/schema';
import { eq } from 'drizzle-orm';
import db from '@server/db';
import { newtGetToken } from './auth';
import { validateNewtSessionToken } from '@server/auth/newt';
// Custom interfaces // Custom interfaces
interface WebSocketRequest extends IncomingMessage { interface WebSocketRequest extends IncomingMessage {
token?: string; token?: string;
} }
interface AuthenticatedWebSocket extends WebSocket { interface AuthenticatedWebSocket extends WebSocket {
userId?: string; newt?: Newt;
isAlive?: boolean; isAlive?: boolean;
} }
interface TokenPayload { interface TokenPayload {
userId: string; newt: Newt;
// Add other token payload properties as needed session: NewtSession;
} }
const router: Router = Router(); const router: Router = Router();
@ -24,121 +29,138 @@ const wss: WebSocketServer = new WebSocketServer({ noServer: true });
// Token verification middleware // Token verification middleware
const verifyToken = async (token: string): Promise<TokenPayload | null> => { const verifyToken = async (token: string): Promise<TokenPayload | null> => {
try { try {
// This is where you'd implement your token verification logic
// For example, verify JWT, check against database, etc. const { session, newt } = await validateNewtSessionToken(
// Return the token payload if valid, null if invalid token
return { userId: 'dummy-user-id' }; // Placeholder return );
} catch (error) {
console.error('Token verification failed:', error); if (!session || !newt) {
return null; return null;
} }
const existingNewt = await db
.select()
.from(newts)
.where(eq(newts.newtId, newt.newtId));
if (!existingNewt || !existingNewt[0]) {
return null;
}
return { newt: existingNewt[0], session };
} catch (error) {
console.error('Token verification failed:', error);
return null;
}
}; };
// Handle WebSocket upgrade requests // Handle WebSocket upgrade requests
router.get('/ws', (req: Request, res: Response) => { router.get('/ws', (req: Request, res: Response) => {
// WebSocket upgrade will be handled by the server // WebSocket upgrade will be handled by the server
res.status(200).send('WebSocket endpoint'); res.status(200).send('WebSocket endpoint');
}); });
router.get('/ws/auth/newtGetToken', newtGetToken);
// Set up WebSocket server handling // Set up WebSocket server handling
const handleWSUpgrade = (server: HttpServer): void => { const handleWSUpgrade = (server: HttpServer): void => {
server.on('upgrade', async (request: WebSocketRequest, socket: Socket, head: Buffer) => { server.on('upgrade', async (request: WebSocketRequest, socket: Socket, head: Buffer) => {
try { try {
// Extract token from query parameters or headers // Extract token from query parameters or headers
const token = request.url?.includes('?') const token = request.url?.includes('?')
? new URLSearchParams(request.url.split('?')[1]).get('token') || '' ? new URLSearchParams(request.url.split('?')[1]).get('token') || ''
: request.headers['sec-websocket-protocol']; : request.headers['sec-websocket-protocol'];
if (!token) { if (!token) {
socket.write('HTTP/1.1 401 Unauthorized\r\n\r\n'); socket.write('HTTP/1.1 401 Unauthorized\r\n\r\n');
socket.destroy(); socket.destroy();
return; return;
} }
// Verify the token // Verify the token
const tokenPayload = await verifyToken(token); const tokenPayload = await verifyToken(token);
if (!tokenPayload) { if (!tokenPayload) {
socket.write('HTTP/1.1 401 Unauthorized\r\n\r\n'); socket.write('HTTP/1.1 401 Unauthorized\r\n\r\n');
socket.destroy(); socket.destroy();
return; return;
} }
// Store token payload data in the request for later use // Store token payload data in the request for later use
request.token = token; request.token = token;
wss.handleUpgrade(request, socket, head, (ws: AuthenticatedWebSocket) => { wss.handleUpgrade(request, socket, head, (ws: AuthenticatedWebSocket) => {
// Attach user data to the WebSocket instance // Attach newt data to the WebSocket instance
ws.userId = tokenPayload.userId; ws.newt = tokenPayload.newt;
ws.isAlive = true; ws.isAlive = true;
wss.emit('connection', ws, request); wss.emit('connection', ws, request);
}); });
} catch (error) { } catch (error) {
console.error('Upgrade error:', error); console.error('Upgrade error:', error);
socket.write('HTTP/1.1 500 Internal Server Error\r\n\r\n'); socket.write('HTTP/1.1 500 Internal Server Error\r\n\r\n');
socket.destroy(); socket.destroy();
} }
}); });
}; };
// WebSocket message interface // WebSocket message interface
interface WSMessage { interface WSMessage {
type: string; type: string;
data: any; data: any;
} }
// WebSocket connection handler // WebSocket connection handler
wss.on('connection', (ws: AuthenticatedWebSocket, request: WebSocketRequest) => { wss.on('connection', (ws: AuthenticatedWebSocket, request: WebSocketRequest) => {
console.log(`Client connected - User ID: ${ws.userId}`); console.log(`Client connected - Newt ID: ${ws.newt?.newtId}`);
// Set up ping-pong for connection health check // Set up ping-pong for connection health check
const pingInterval = setInterval(() => { const pingInterval = setInterval(() => {
if (ws.isAlive === false) { if (ws.isAlive === false) {
clearInterval(pingInterval); clearInterval(pingInterval);
return ws.terminate(); return ws.terminate();
} }
ws.isAlive = false; ws.isAlive = false;
ws.ping(); ws.ping();
}, 30000); }, 30000);
// Handle pong response // Handle pong response
ws.on('pong', () => { ws.on('pong', () => {
ws.isAlive = true; ws.isAlive = true;
}); });
// Set up message handler // Set up message handler
ws.on('message', (data) => { ws.on('message', (data) => {
try { try {
const message: WSMessage = JSON.parse(data.toString()); const message: WSMessage = JSON.parse(data.toString());
console.log('Received:', message); console.log('Received:', message);
// Echo the message back // Echo the message back
ws.send(JSON.stringify({ ws.send(JSON.stringify({
type: 'echo', type: 'echo',
data: message data: message
})); }));
} catch (error) { } catch (error) {
console.error('Message parsing error:', error); console.error('Message parsing error:', error);
ws.send(JSON.stringify({ ws.send(JSON.stringify({
type: 'error', type: 'error',
data: 'Invalid message format' data: 'Invalid message format'
})); }));
} }
}); });
// Handle client disconnect // Handle client disconnect
ws.on('close', () => { ws.on('close', () => {
clearInterval(pingInterval); clearInterval(pingInterval);
console.log(`Client disconnected - User ID: ${ws.userId}`); console.log(`Client disconnected - Newt ID: ${ws.newt?.newtId}`);
}); });
// Handle errors // Handle errors
ws.on('error', (error: Error) => { ws.on('error', (error: Error) => {
console.error('WebSocket error:', error); console.error('WebSocket error:', error);
}); });
}); });
export { export {
router, router,
handleWSUpgrade handleWSUpgrade
}; };