diff --git a/server/apiServer.ts b/server/apiServer.ts index 8fc131a5..6f05aae5 100644 --- a/server/apiServer.ts +++ b/server/apiServer.ts @@ -6,11 +6,12 @@ import logger from "@server/logger"; import { errorHandlerMiddleware, notFoundMiddleware, - rateLimitMiddleware, + rateLimitMiddleware } from "@server/middlewares"; import { authenticated, unauthenticated } from "@server/routers/external"; import { router as wsRouter, handleWSUpgrade } from "@server/routers/ws"; import { logIncomingMiddleware } from "./middlewares/logIncoming"; +import { csrfProtectionMiddleware } from "./middlewares/csrfProtection"; import helmet from "helmet"; const dev = process.env.ENVIRONMENT !== "prod"; @@ -25,13 +26,22 @@ export function createApiServer() { apiServer.use( cors({ origin: `http://localhost:${config.server.next_port}`, - credentials: true, - }), + credentials: true + }) ); } else { - apiServer.use(cors()); + const corsOptions = { + origin: config.app.base_url, + methods: ["GET", "POST", "PUT", "DELETE", "PATCH"], + allowedHeaders: ["Content-Type", "X-CSRF-Token"], + credentials: true + }; + + apiServer.use(cors(corsOptions)); apiServer.use(helmet()); + apiServer.use(csrfProtectionMiddleware); } + apiServer.use(cookieParser()); apiServer.use(express.json()); @@ -40,8 +50,8 @@ export function createApiServer() { rateLimitMiddleware({ windowMin: config.rate_limits.global.window_minutes, max: config.rate_limits.global.max_requests, - type: "IP_AND_PATH", - }), + type: "IP_AND_PATH" + }) ); } @@ -62,7 +72,7 @@ export function createApiServer() { const httpServer = apiServer.listen(externalPort, (err?: any) => { if (err) throw err; logger.info( - `API server is running on http://localhost:${externalPort}`, + `API server is running on http://localhost:${externalPort}` ); }); diff --git a/server/middlewares/csrfProtection.ts b/server/middlewares/csrfProtection.ts new file mode 100644 index 00000000..33150d65 --- /dev/null +++ b/server/middlewares/csrfProtection.ts @@ -0,0 +1,24 @@ +import { NextFunction, Request, Response } from "express"; + +export function csrfProtectionMiddleware( + req: Request, + res: Response, + next: NextFunction +) { + const csrfToken = req.headers["x-csrf-token"]; + + // Skip CSRF check for GET requests as they should be idempotent + if (req.method === "GET") { + next(); + return; + } + + if (!csrfToken || csrfToken !== "x-csrf-protection") { + res.status(403).json({ + error: "CSRF token missing or invalid" + }); + return; + } + + next(); +} diff --git a/src/api/index.ts b/src/api/index.ts index 32d0df6e..b59445db 100644 --- a/src/api/index.ts +++ b/src/api/index.ts @@ -32,7 +32,8 @@ export function createApiClient({ env }: { env: env }): AxiosInstance { baseURL, timeout: 10000, headers: { - "Content-Type": "application/json" + "Content-Type": "application/json", + "X-CSRF-Token": "x-csrf-protection" } });