Merge branch 'dev' of https://github.com/fosrl/pangolin into dev

This commit is contained in:
miloschwartz 2025-07-14 22:21:04 -07:00
commit 379d31aac6
No known key found for this signature in database
234 changed files with 16088 additions and 7588 deletions

View file

@ -27,3 +27,4 @@ bruno/
LICENSE
CONTRIBUTING.md
dist
.git

1
.gitignore vendored
View file

@ -34,3 +34,4 @@ bin
test_event.json
.idea/
server/db/index.ts
build.ts

View file

@ -0,0 +1,22 @@
meta {
name: createClient
type: http
seq: 1
}
put {
url: http://localhost:3000/api/v1/site/1/client
body: json
auth: none
}
body:json {
{
"siteId": 1,
"name": "test",
"type": "olm",
"subnet": "100.90.129.4/30",
"olmId": "029yzunhx6nh3y5",
"secret": "l0ymp075y3d4rccb25l6sqpgar52k09etunui970qq5gj7x6"
}
}

View file

@ -0,0 +1,11 @@
meta {
name: pickClientDefaults
type: http
seq: 2
}
get {
url: http://localhost:3000/api/v1/site/1/pick-client-defaults
body: none
auth: none
}

19
docker-compose.dev.yml Normal file
View file

@ -0,0 +1,19 @@
services:
# PostgreSQL Service
db:
image: postgres:17 # Use the PostgreSQL 17 image
container_name: dev_postgres # Name your PostgreSQL container
environment:
POSTGRES_DB: postgres # Default database name
POSTGRES_USER: postgres # Default user
POSTGRES_PASSWORD: password # Default password (change for production!)
ports:
- "5432:5432" # Map host port 5432 to container port 5432
restart: no
redis:
image: redis:latest # Use the latest Redis image
container_name: dev_redis # Name your Redis container
ports:
- "6379:6379" # Map host port 6379 to container port 6379
restart: no

View file

@ -3,7 +3,7 @@ import path from "path";
export default defineConfig({
dialect: "postgresql",
schema: path.join("server", "db", "pg", "schema.ts"),
schema: [path.join("server", "db", "pg", "schema.ts")],
out: path.join("server", "migrations"),
verbose: true,
dbCredentials: {

View file

@ -22,6 +22,10 @@ gerbil:
start_port: 51820
base_endpoint: "{{.DashboardDomain}}"
orgs:
block_size: 24
subnet_group: 100.89.138.0/20
{{if .EnableEmail}}
email:
smtp_host: "{{.EmailSMTPHost}}"

View file

@ -1,4 +1,4 @@
name: captcha_remediation
iame: captcha_remediation
filters:
- Alert.Remediation == true && Alert.GetScope() == "Ip" && Alert.GetScenario() contains "http"
decisions:

View file

@ -10,7 +10,8 @@
"setupErrorIdentifier": "Organization ID is already taken. Please choose a different one.",
"componentsErrorNoMemberCreate": "You are not currently a member of any organizations. Create an organization to get started.",
"componentsErrorNoMember": "You are not currently a member of any organizations.",
"welcome": "Welcome to Pangolin",
"welcome": "Welcome!",
"welcomeTo": "Welcome to",
"componentsCreateOrg": "Create an Organization",
"componentsMember": "You're a member of {count, plural, =0 {no organization} one {one organization} other {# organizations}}.",
"componentsInvalidKey": "Invalid or expired license keys detected. Follow license terms to continue using all features.",
@ -206,6 +207,7 @@
"orgGeneralSettings": "Organization Settings",
"orgGeneralSettingsDescription": "Manage your organization details and configuration",
"saveGeneralSettings": "Save General Settings",
"saveSettings": "Save Settings",
"orgDangerZone": "Danger Zone",
"orgDangerZoneDescription": "Once you delete this org, there is no going back. Please be certain.",
"orgDelete": "Delete Organization",
@ -1092,6 +1094,8 @@
"sidebarAllUsers": "All Users",
"sidebarIdentityProviders": "Identity Providers",
"sidebarLicense": "License",
"sidebarClients": "Clients",
"sidebarDomains": "Domains",
"enableDockerSocket": "Enable Docker Socket",
"enableDockerSocketDescription": "Enable Docker Socket discovery for populating container information. Socket path must be provided to Newt.",
"enableDockerSocketLink": "Learn More",
@ -1131,10 +1135,88 @@
"dark": "dark",
"system": "system",
"theme": "Theme",
"subnetRequired": "Subnet is required",
"initialSetupTitle": "Initial Server Setup",
"initialSetupDescription": "Create the intial server admin account. Only one server admin can exist. You can always change these credentials later.",
"createAdminAccount": "Create Admin Account",
"setupErrorCreateAdmin": "An error occurred while creating the server admin account.",
"certificateStatus": "Certificate Status",
"loading": "Loading",
"restart": "Restart",
"domains": "Domains",
"domainsDescription": "Manage domains for your organization",
"domainsSearch": "Search domains...",
"domainAdd": "Add Domain",
"domainAddDescription": "Register a new domain with your organization",
"domainCreate": "Create Domain",
"domainCreatedDescription": "Domain created successfully",
"domainDeletedDescription": "Domain deleted successfully",
"domainQuestionRemove": "Are you sure you want to remove the domain {domain} from your account?",
"domainMessageRemove": "Once removed, the domain will no longer be associated with your account.",
"domainMessageConfirm": "To confirm, please type the domain name below.",
"domainConfirmDelete": "Confirm Delete Domain",
"domainDelete": "Delete Domain",
"domain": "Domain",
"selectDomainTypeNsName": "Domain Delegation (NS)",
"selectDomainTypeNsDescription": "This domain and all its subdomains. Use this when you want to control an entire domain zone.",
"selectDomainTypeCnameName": "Single Domain (CNAME)",
"selectDomainTypeCnameDescription": "Just this specific domain. Use this for individual subdomains or specific domain entries.",
"selectDomainTypeWildcardName": "Wildcard Domain (CNAME)",
"selectDomainTypeWildcardDescription": "This domain and its first level of subdomains.",
"domainDelegation": "Single Domain",
"selectType": "Select a type",
"actions": "Actions",
"refresh": "Refresh",
"refreshError": "Failed to refresh data",
"verified": "Verified",
"pending": "Pending",
"sidebarBilling": "Billing",
"billing": "Billing",
"orgBillingDescription": "Manage your billing information and subscriptions",
"github": "GitHub",
"pangolinHosted": "Pangolin Hosted",
"fossorial": "Fossorial",
"completeAccountSetup": "Complete Account Setup",
"completeAccountSetupDescription": "Set your password to get started",
"accountSetupSent": "We'll send an account setup code to this email address.",
"accountSetupCode": "Setup Code",
"accountSetupCodeDescription": "Check your email for the setup code.",
"passwordCreate": "Create Password",
"passwordCreateConfirm": "Confirm Password",
"accountSetupSubmit": "Send Setup Code",
"completeSetup": "Complete Setup",
"accountSetupSuccess": "Account setup completed! Welcome to Pangolin!",
"documentation": "Documentation",
"saveAllSettings": "Save All Settings",
"settingsUpdated": "Settings updated",
"settingsUpdatedDescription": "All settings have been updated successfully",
"settingsErrorUpdate": "Failed to update settings",
"settingsErrorUpdateDescription": "An error occurred while updating settings",
"sidebarCollapse": "Collapse",
"sidebarExpand": "Expand",
"newtUpdateAvailable": "Update Available",
"newtUpdateAvailableInfo": "A new version of Newt is available. Please update to the latest version for the best experience.",
"domainPickerEnterDomain": "Enter your domain",
"domainPickerPlaceholder": "myapp.example.com, api.v1.mydomain.com, or just myapp",
"domainPickerDescription": "Enter a full domain, subdomain, or just a name to see available options",
"domainPickerTabAll": "All",
"domainPickerTabOrganization": "Organization",
"domainPickerTabProvided": "Provided",
"domainPickerSortAsc": "A-Z",
"domainPickerSortDesc": "Z-A",
"domainPickerCheckingAvailability": "Checking availability...",
"domainPickerNoMatchingDomains": "No matching domains found for \"{userInput}\". Try a different domain or check your organization's domain settings.",
"domainPickerOrganizationDomains": "Organization Domains",
"domainPickerProvidedDomains": "Provided Domains",
"domainPickerSubdomain": "Subdomain: {subdomain}",
"domainPickerNamespace": "Namespace: {namespace}",
"domainPickerShowMore": "Show More",
"domainNotFound": "Domain Not Found",
"domainNotFoundDescription": "This resource is disabled because the domain no longer exists our system. Please set a new domain for this resource.",
"failed": "Failed",
"createNewOrgDescription": "Create a new organization",
"organization": "Organization",
"port": "Port",
"securityKeyManage": "Manage Security Keys",
"securityKeyDescription": "Add or remove security keys for passwordless authentication",
"securityKeyRegister": "Register New Security Key",

2151
package-lock.json generated

File diff suppressed because it is too large Load diff

View file

@ -49,6 +49,7 @@
"@radix-ui/react-switch": "1.2.5",
"@radix-ui/react-tabs": "1.1.12",
"@radix-ui/react-toast": "1.2.14",
"@radix-ui/react-tooltip": "^1.2.7",
"@react-email/components": "0.3.1",
"@react-email/render": "^1.1.2",
"@simplewebauthn/browser": "^13.1.0",
@ -78,6 +79,7 @@
"http-errors": "2.0.0",
"i": "^0.3.7",
"input-otp": "1.4.2",
"ioredis": "^5.6.1",
"jmespath": "^0.16.0",
"js-yaml": "4.1.0",
"jsonwebtoken": "^9.0.2",
@ -93,6 +95,7 @@
"oslo": "1.2.1",
"pg": "^8.16.2",
"qrcode.react": "4.2.0",
"rate-limit-redis": "^4.2.1",
"react": "19.1.0",
"react-dom": "19.1.0",
"react-easy-sort": "^1.6.0",
@ -127,6 +130,7 @@
"@types/jsonwebtoken": "^9.0.10",
"@types/node": "^24",
"@types/nodemailer": "6.4.17",
"@types/pg": "8.15.4",
"@types/react": "19.1.8",
"@types/react-dom": "19.1.6",
"@types/semver": "^7.7.0",

View file

@ -5,7 +5,7 @@ import config from "@server/lib/config";
import logger from "@server/logger";
import {
errorHandlerMiddleware,
notFoundMiddleware,
notFoundMiddleware
} from "@server/middlewares";
import { authenticated, unauthenticated } from "@server/routers/external";
import { router as wsRouter, handleWSUpgrade } from "@server/routers/ws";
@ -15,12 +15,14 @@ import helmet from "helmet";
import rateLimit from "express-rate-limit";
import createHttpError from "http-errors";
import HttpCode from "./types/HttpCode";
import requestTimeoutMiddleware from "./middlewares/requestTimeout";
const dev = config.isDev;
const externalPort = config.getRawConfig().server.external_port;
export function createApiServer() {
const apiServer = express();
const prefix = `/api/v1`;
const trustProxy = config.getRawConfig().server.trust_proxy;
if (trustProxy) {
@ -56,6 +58,9 @@ export function createApiServer() {
apiServer.use(cookieParser());
apiServer.use(express.json());
// Add request timeout middleware
apiServer.use(requestTimeoutMiddleware(60000)); // 60 second timeout
if (!dev) {
apiServer.use(
rateLimit({
@ -76,7 +81,6 @@ export function createApiServer() {
}
// API routes
const prefix = `/api/v1`;
apiServer.use(logIncomingMiddleware);
apiServer.use(prefix, unauthenticated);
apiServer.use(prefix, authenticated);

View file

@ -69,6 +69,11 @@ export enum ActionsEnum {
deleteResourceRule = "deleteResourceRule",
listResourceRules = "listResourceRules",
updateResourceRule = "updateResourceRule",
createClient = "createClient",
deleteClient = "deleteClient",
updateClient = "updateClient",
listClients = "listClients",
getClient = "getClient",
listOrgDomains = "listOrgDomains",
createNewt = "createNewt",
createIdp = "createIdp",
@ -87,7 +92,10 @@ export enum ActionsEnum {
setApiKeyOrgs = "setApiKeyOrgs",
listApiKeyActions = "listApiKeyActions",
listApiKeys = "listApiKeys",
getApiKey = "getApiKey"
getApiKey = "getApiKey",
createOrgDomain = "createOrgDomain",
deleteOrgDomain = "deleteOrgDomain",
restartOrgDomain = "restartOrgDomain"
}
export async function checkUserActionPermission(

View file

@ -1,40 +0,0 @@
import { db } from '@server/db';
import { limitsTable } from '@server/db';
import { and, eq } from 'drizzle-orm';
import createHttpError from 'http-errors';
import HttpCode from '@server/types/HttpCode';
interface CheckLimitOptions {
orgId: string;
limitName: string;
currentValue: number;
increment?: number;
}
export async function checkOrgLimit({ orgId, limitName, currentValue, increment = 0 }: CheckLimitOptions): Promise<boolean> {
try {
const limit = await db.select()
.from(limitsTable)
.where(
and(
eq(limitsTable.orgId, orgId),
eq(limitsTable.name, limitName)
)
)
.limit(1);
if (limit.length === 0) {
throw createHttpError(HttpCode.NOT_FOUND, `Limit "${limitName}" not found for organization`);
}
const limitValue = limit[0].value;
// Check if the current value plus the increment is within the limit
return (currentValue + increment) <= limitValue;
} catch (error) {
if (error instanceof Error) {
throw createHttpError(HttpCode.INTERNAL_SERVER_ERROR, `Error checking limit: ${error.message}`);
}
throw createHttpError(HttpCode.INTERNAL_SERVER_ERROR, 'Unknown error occurred while checking limit');
}
}

View file

@ -0,0 +1,72 @@
import {
encodeHexLowerCase,
} from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { Olm, olms, olmSessions, OlmSession } from "@server/db";
import { db } from "@server/db";
import { eq } from "drizzle-orm";
export const EXPIRES = 1000 * 60 * 60 * 24 * 30;
export async function createOlmSession(
token: string,
olmId: string,
): Promise<OlmSession> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
);
const session: OlmSession = {
sessionId: sessionId,
olmId,
expiresAt: new Date(Date.now() + EXPIRES).getTime(),
};
await db.insert(olmSessions).values(session);
return session;
}
export async function validateOlmSessionToken(
token: string,
): Promise<SessionValidationResult> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
);
const result = await db
.select({ olm: olms, session: olmSessions })
.from(olmSessions)
.innerJoin(olms, eq(olmSessions.olmId, olms.olmId))
.where(eq(olmSessions.sessionId, sessionId));
if (result.length < 1) {
return { session: null, olm: null };
}
const { olm, session } = result[0];
if (Date.now() >= session.expiresAt) {
await db
.delete(olmSessions)
.where(eq(olmSessions.sessionId, session.sessionId));
return { session: null, olm: null };
}
if (Date.now() >= session.expiresAt - (EXPIRES / 2)) {
session.expiresAt = new Date(
Date.now() + EXPIRES,
).getTime();
await db
.update(olmSessions)
.set({
expiresAt: session.expiresAt,
})
.where(eq(olmSessions.sessionId, session.sessionId));
}
return { session, olm };
}
export async function invalidateOlmSession(sessionId: string): Promise<void> {
await db.delete(olmSessions).where(eq(olmSessions.sessionId, sessionId));
}
export async function invalidateAllOlmSessions(olmId: string): Promise<void> {
await db.delete(olmSessions).where(eq(olmSessions.olmId, olmId));
}
export type SessionValidationResult =
| { session: OlmSession; olm: Olm }
| { session: null; olm: null };

View file

@ -59,7 +59,7 @@ export async function getUniqueExitNodeEndpointName(): Promise<string> {
export function generateName(): string {
return (
const name = (
names.descriptors[
Math.floor(Math.random() * names.descriptors.length)
] +
@ -68,4 +68,7 @@ export function generateName(): string {
)
.toLowerCase()
.replace(/\s/g, "-");
// clean out any non-alphanumeric characters except for dashes
return name.replace(/[^a-z0-9-]/g, "");
}

View file

@ -1,4 +1,5 @@
import { drizzle as DrizzlePostgres } from "drizzle-orm/node-postgres";
import { Pool } from "pg";
import { readConfigFile } from "@server/lib/readConfigFile";
import { withReplicas } from "drizzle-orm/pg-core";
@ -20,19 +21,31 @@ function createDb() {
);
}
const primary = DrizzlePostgres(connectionString);
// Create connection pools instead of individual connections
const primaryPool = new Pool({
connectionString,
max: 20,
idleTimeoutMillis: 30000,
connectionTimeoutMillis: 2000,
});
const replicas = [];
if (!replicaConnections.length) {
replicas.push(primary);
replicas.push(DrizzlePostgres(primaryPool));
} else {
for (const conn of replicaConnections) {
const replica = DrizzlePostgres(conn.connection_string);
replicas.push(replica);
const replicaPool = new Pool({
connectionString: conn.connection_string,
max: 10,
idleTimeoutMillis: 30000,
connectionTimeoutMillis: 2000,
});
replicas.push(DrizzlePostgres(replicaPool));
}
}
return withReplicas(primary, replicas as any);
return withReplicas(DrizzlePostgres(primaryPool), replicas as any);
}
export const db = createDb();

View file

@ -1,2 +1,2 @@
export * from "./driver";
export * from "./schema";
export * from "./schema";

View file

@ -1,5 +1,5 @@
import { migrate } from "drizzle-orm/node-postgres/migrator";
import db from "./driver";
import { db } from "./driver";
import path from "path";
const migrationsFolder = path.join("server/migrations");

View file

@ -12,12 +12,17 @@ import { InferSelectModel } from "drizzle-orm";
export const domains = pgTable("domains", {
domainId: varchar("domainId").primaryKey(),
baseDomain: varchar("baseDomain").notNull(),
configManaged: boolean("configManaged").notNull().default(false)
configManaged: boolean("configManaged").notNull().default(false),
type: varchar("type"), // "ns", "cname", "wildcard"
verified: boolean("verified").notNull().default(false),
failed: boolean("failed").notNull().default(false),
tries: integer("tries").notNull().default(0)
});
export const orgs = pgTable("orgs", {
orgId: varchar("orgId").primaryKey(),
name: varchar("name").notNull()
name: varchar("name").notNull(),
subnet: varchar("subnet").notNull()
});
export const orgDomains = pgTable("orgDomains", {
@ -42,12 +47,17 @@ export const sites = pgTable("sites", {
}),
name: varchar("name").notNull(),
pubKey: varchar("pubKey"),
subnet: varchar("subnet").notNull(),
megabytesIn: real("bytesIn"),
megabytesOut: real("bytesOut"),
subnet: varchar("subnet"),
megabytesIn: real("bytesIn").default(0),
megabytesOut: real("bytesOut").default(0),
lastBandwidthUpdate: varchar("lastBandwidthUpdate"),
type: varchar("type").notNull(), // "newt" or "wireguard"
online: boolean("online").notNull().default(false),
address: varchar("address"),
endpoint: varchar("endpoint"),
publicKey: varchar("publicKey"),
lastHolePunch: bigint("lastHolePunch", { mode: "number" }),
listenPort: integer("listenPort"),
dockerSocketEnabled: boolean("dockerSocketEnabled").notNull().default(true)
});
@ -107,7 +117,8 @@ export const exitNodes = pgTable("exitNodes", {
endpoint: varchar("endpoint").notNull(),
publicKey: varchar("publicKey").notNull(),
listenPort: integer("listenPort").notNull(),
reachableAt: varchar("reachableAt")
reachableAt: varchar("reachableAt"),
maxConnections: integer("maxConnections")
});
export const users = pgTable("user", {
@ -132,6 +143,7 @@ export const newts = pgTable("newt", {
newtId: varchar("id").primaryKey(),
secretHash: varchar("secretHash").notNull(),
dateCreated: varchar("dateCreated").notNull(),
version: varchar("version"),
siteId: integer("siteId").references(() => sites.siteId, {
onDelete: "cascade"
})
@ -274,18 +286,6 @@ export const userResources = pgTable("userResources", {
.references(() => resources.resourceId, { onDelete: "cascade" })
});
export const limitsTable = pgTable("limits", {
limitId: serial("limitId").primaryKey(),
orgId: varchar("orgId")
.references(() => orgs.orgId, {
onDelete: "cascade"
})
.notNull(),
name: varchar("name").notNull(),
value: bigint("value", { mode: "number" }).notNull(),
description: varchar("description")
});
export const userInvites = pgTable("userInvites", {
inviteId: varchar("inviteId").primaryKey(),
orgId: varchar("orgId")
@ -492,6 +492,75 @@ export const idpOrg = pgTable("idpOrg", {
orgMapping: varchar("orgMapping")
});
export const clients = pgTable("clients", {
clientId: serial("id").primaryKey(),
orgId: varchar("orgId")
.references(() => orgs.orgId, {
onDelete: "cascade"
})
.notNull(),
exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, {
onDelete: "set null"
}),
name: varchar("name").notNull(),
pubKey: varchar("pubKey"),
subnet: varchar("subnet").notNull(),
megabytesIn: integer("bytesIn"),
megabytesOut: integer("bytesOut"),
lastBandwidthUpdate: varchar("lastBandwidthUpdate"),
lastPing: varchar("lastPing"),
type: varchar("type").notNull(), // "olm"
online: boolean("online").notNull().default(false),
endpoint: varchar("endpoint"),
lastHolePunch: integer("lastHolePunch"),
maxConnections: integer("maxConnections")
});
export const clientSites = pgTable("clientSites", {
clientId: integer("clientId")
.notNull()
.references(() => clients.clientId, { onDelete: "cascade" }),
siteId: integer("siteId")
.notNull()
.references(() => sites.siteId, { onDelete: "cascade" }),
isRelayed: boolean("isRelayed").notNull().default(false)
});
export const olms = pgTable("olms", {
olmId: varchar("id").primaryKey(),
secretHash: varchar("secretHash").notNull(),
dateCreated: varchar("dateCreated").notNull(),
clientId: integer("clientId").references(() => clients.clientId, {
onDelete: "cascade"
})
});
export const olmSessions = pgTable("clientSession", {
sessionId: varchar("id").primaryKey(),
olmId: varchar("olmId")
.notNull()
.references(() => olms.olmId, { onDelete: "cascade" }),
expiresAt: integer("expiresAt").notNull()
});
export const userClients = pgTable("userClients", {
userId: varchar("userId")
.notNull()
.references(() => users.userId, { onDelete: "cascade" }),
clientId: integer("clientId")
.notNull()
.references(() => clients.clientId, { onDelete: "cascade" })
});
export const roleClients = pgTable("roleClients", {
roleId: integer("roleId")
.notNull()
.references(() => roles.roleId, { onDelete: "cascade" }),
clientId: integer("clientId")
.notNull()
.references(() => clients.clientId, { onDelete: "cascade" })
});
export const securityKeys = pgTable("webauthnCredentials", {
credentialId: varchar("credentialId").primaryKey(),
userId: varchar("userId").notNull().references(() => users.userId, {
@ -538,7 +607,6 @@ export type RoleSite = InferSelectModel<typeof roleSites>;
export type UserSite = InferSelectModel<typeof userSites>;
export type RoleResource = InferSelectModel<typeof roleResources>;
export type UserResource = InferSelectModel<typeof userResources>;
export type Limit = InferSelectModel<typeof limitsTable>;
export type UserInvite = InferSelectModel<typeof userInvites>;
export type UserOrg = InferSelectModel<typeof userOrgs>;
export type ResourceSession = InferSelectModel<typeof resourceSessions>;
@ -555,3 +623,10 @@ export type Idp = InferSelectModel<typeof idp>;
export type ApiKey = InferSelectModel<typeof apiKeys>;
export type ApiKeyAction = InferSelectModel<typeof apiKeyActions>;
export type ApiKeyOrg = InferSelectModel<typeof apiKeyOrg>;
export type Client = InferSelectModel<typeof clients>;
export type ClientSite = InferSelectModel<typeof clientSites>;
export type Olm = InferSelectModel<typeof olms>;
export type OlmSession = InferSelectModel<typeof olmSessions>;
export type UserClient = InferSelectModel<typeof userClients>;
export type RoleClient = InferSelectModel<typeof roleClients>;
export type OrgDomains = InferSelectModel<typeof orgDomains>;

442
server/db/redis.ts Normal file
View file

@ -0,0 +1,442 @@
import Redis, { RedisOptions } from "ioredis";
import logger from "@server/logger";
import config from "@server/lib/config";
import { build } from "@server/build";
class RedisManager {
public client: Redis | null = null;
private subscriber: Redis | null = null;
private publisher: Redis | null = null;
private isEnabled: boolean = false;
private isHealthy: boolean = true;
private lastHealthCheck: number = 0;
private healthCheckInterval: number = 30000; // 30 seconds
private subscribers: Map<
string,
Set<(channel: string, message: string) => void>
> = new Map();
constructor() {
if (build == "oss") {
this.isEnabled = false;
} else {
this.isEnabled = config.getRawConfig().flags?.enable_redis || false;
}
if (this.isEnabled) {
this.initializeClients();
}
}
private getRedisConfig(): RedisOptions {
const redisConfig = config.getRawConfig().redis!;
const opts: RedisOptions = {
host: redisConfig.host!,
port: redisConfig.port!,
password: redisConfig.password,
db: redisConfig.db,
// tls: {
// rejectUnauthorized:
// redisConfig.tls?.reject_unauthorized || false
// }
};
return opts;
}
// Add reconnection logic in initializeClients
private initializeClients(): void {
const config = this.getRedisConfig();
try {
this.client = new Redis({
...config,
enableReadyCheck: false,
maxRetriesPerRequest: 3,
keepAlive: 30000,
connectTimeout: 10000, // 10 seconds
commandTimeout: 5000, // 5 seconds
});
this.publisher = new Redis({
...config,
enableReadyCheck: false,
maxRetriesPerRequest: 3,
keepAlive: 30000,
connectTimeout: 10000, // 10 seconds
commandTimeout: 5000, // 5 seconds
});
this.subscriber = new Redis({
...config,
enableReadyCheck: false,
maxRetriesPerRequest: 3,
keepAlive: 30000,
connectTimeout: 10000, // 10 seconds
commandTimeout: 5000, // 5 seconds
});
// Add reconnection handlers
this.client.on("error", (err) => {
logger.error("Redis client error:", err);
this.isHealthy = false;
});
this.client.on("reconnecting", () => {
logger.info("Redis client reconnecting...");
this.isHealthy = false;
});
this.client.on("ready", () => {
logger.info("Redis client ready");
this.isHealthy = true;
});
this.publisher.on("error", (err) => {
logger.error("Redis publisher error:", err);
this.isHealthy = false;
});
this.publisher.on("ready", () => {
logger.info("Redis publisher ready");
});
this.subscriber.on("error", (err) => {
logger.error("Redis subscriber error:", err);
this.isHealthy = false;
});
this.subscriber.on("ready", () => {
logger.info("Redis subscriber ready");
});
// Set up connection handlers
this.client.on("connect", () => {
logger.info("Redis client connected");
});
this.publisher.on("connect", () => {
logger.info("Redis publisher connected");
});
this.subscriber.on("connect", () => {
logger.info("Redis subscriber connected");
});
// Set up message handler for subscriber
this.subscriber.on(
"message",
(channel: string, message: string) => {
const channelSubscribers = this.subscribers.get(channel);
if (channelSubscribers) {
channelSubscribers.forEach((callback) => {
try {
callback(channel, message);
} catch (error) {
logger.error(
`Error in subscriber callback for channel ${channel}:`,
error
);
}
});
}
}
);
logger.info("Redis clients initialized successfully");
// Start periodic health monitoring
this.startHealthMonitoring();
} catch (error) {
logger.error("Failed to initialize Redis clients:", error);
this.isEnabled = false;
}
}
private startHealthMonitoring(): void {
if (!this.isEnabled) return;
// Check health every 30 seconds
setInterval(async () => {
try {
await this.checkRedisHealth();
} catch (error) {
logger.error("Error during Redis health monitoring:", error);
}
}, this.healthCheckInterval);
}
public isRedisEnabled(): boolean {
return this.isEnabled && this.client !== null && this.isHealthy;
}
private async checkRedisHealth(): Promise<boolean> {
const now = Date.now();
// Only check health every 30 seconds
if (now - this.lastHealthCheck < this.healthCheckInterval) {
return this.isHealthy;
}
this.lastHealthCheck = now;
if (!this.client) {
this.isHealthy = false;
return false;
}
try {
await Promise.race([
this.client.ping(),
new Promise((_, reject) =>
setTimeout(() => reject(new Error('Health check timeout')), 2000)
)
]);
this.isHealthy = true;
return true;
} catch (error) {
logger.error("Redis health check failed:", error);
this.isHealthy = false;
return false;
}
}
public getClient(): Redis {
return this.client!;
}
public async set(
key: string,
value: string,
ttl?: number
): Promise<boolean> {
if (!this.isRedisEnabled() || !this.client) return false;
try {
if (ttl) {
await this.client.setex(key, ttl, value);
} else {
await this.client.set(key, value);
}
return true;
} catch (error) {
logger.error("Redis SET error:", error);
return false;
}
}
public async get(key: string): Promise<string | null> {
if (!this.isRedisEnabled() || !this.client) return null;
try {
return await this.client.get(key);
} catch (error) {
logger.error("Redis GET error:", error);
return null;
}
}
public async del(key: string): Promise<boolean> {
if (!this.isRedisEnabled() || !this.client) return false;
try {
await this.client.del(key);
return true;
} catch (error) {
logger.error("Redis DEL error:", error);
return false;
}
}
public async sadd(key: string, member: string): Promise<boolean> {
if (!this.isRedisEnabled() || !this.client) return false;
try {
await this.client.sadd(key, member);
return true;
} catch (error) {
logger.error("Redis SADD error:", error);
return false;
}
}
public async srem(key: string, member: string): Promise<boolean> {
if (!this.isRedisEnabled() || !this.client) return false;
try {
await this.client.srem(key, member);
return true;
} catch (error) {
logger.error("Redis SREM error:", error);
return false;
}
}
public async smembers(key: string): Promise<string[]> {
if (!this.isRedisEnabled() || !this.client) return [];
try {
return await this.client.smembers(key);
} catch (error) {
logger.error("Redis SMEMBERS error:", error);
return [];
}
}
public async hset(
key: string,
field: string,
value: string
): Promise<boolean> {
if (!this.isRedisEnabled() || !this.client) return false;
try {
await this.client.hset(key, field, value);
return true;
} catch (error) {
logger.error("Redis HSET error:", error);
return false;
}
}
public async hget(key: string, field: string): Promise<string | null> {
if (!this.isRedisEnabled() || !this.client) return null;
try {
return await this.client.hget(key, field);
} catch (error) {
logger.error("Redis HGET error:", error);
return null;
}
}
public async hdel(key: string, field: string): Promise<boolean> {
if (!this.isRedisEnabled() || !this.client) return false;
try {
await this.client.hdel(key, field);
return true;
} catch (error) {
logger.error("Redis HDEL error:", error);
return false;
}
}
public async hgetall(key: string): Promise<Record<string, string>> {
if (!this.isRedisEnabled() || !this.client) return {};
try {
return await this.client.hgetall(key);
} catch (error) {
logger.error("Redis HGETALL error:", error);
return {};
}
}
public async publish(channel: string, message: string): Promise<boolean> {
if (!this.isRedisEnabled() || !this.publisher) return false;
// Quick health check before attempting to publish
const isHealthy = await this.checkRedisHealth();
if (!isHealthy) {
logger.warn("Skipping Redis publish due to unhealthy connection");
return false;
}
try {
// Add timeout to prevent hanging
await Promise.race([
this.publisher.publish(channel, message),
new Promise((_, reject) =>
setTimeout(() => reject(new Error('Redis publish timeout')), 3000)
)
]);
return true;
} catch (error) {
logger.error("Redis PUBLISH error:", error);
this.isHealthy = false; // Mark as unhealthy on error
return false;
}
}
public async subscribe(
channel: string,
callback: (channel: string, message: string) => void
): Promise<boolean> {
if (!this.isRedisEnabled() || !this.subscriber) return false;
try {
// Add callback to subscribers map
if (!this.subscribers.has(channel)) {
this.subscribers.set(channel, new Set());
// Only subscribe to the channel if it's the first subscriber
await Promise.race([
this.subscriber.subscribe(channel),
new Promise((_, reject) =>
setTimeout(() => reject(new Error('Redis subscribe timeout')), 5000)
)
]);
}
this.subscribers.get(channel)!.add(callback);
return true;
} catch (error) {
logger.error("Redis SUBSCRIBE error:", error);
this.isHealthy = false;
return false;
}
}
public async unsubscribe(
channel: string,
callback?: (channel: string, message: string) => void
): Promise<boolean> {
if (!this.isRedisEnabled() || !this.subscriber) return false;
try {
const channelSubscribers = this.subscribers.get(channel);
if (!channelSubscribers) return true;
if (callback) {
// Remove specific callback
channelSubscribers.delete(callback);
if (channelSubscribers.size === 0) {
this.subscribers.delete(channel);
await this.subscriber.unsubscribe(channel);
}
} else {
// Remove all callbacks for this channel
this.subscribers.delete(channel);
await this.subscriber.unsubscribe(channel);
}
return true;
} catch (error) {
logger.error("Redis UNSUBSCRIBE error:", error);
return false;
}
}
public async disconnect(): Promise<void> {
try {
if (this.client) {
await this.client.quit();
this.client = null;
}
if (this.publisher) {
await this.publisher.quit();
this.publisher = null;
}
if (this.subscriber) {
await this.subscriber.quit();
this.subscriber = null;
}
this.subscribers.clear();
logger.info("Redis clients disconnected");
} catch (error) {
logger.error("Error disconnecting Redis clients:", error);
}
}
}
export const redisManager = new RedisManager();
export const redis = redisManager.getClient();
export default redisManager;

View file

@ -1,5 +1,5 @@
import { migrate } from "drizzle-orm/better-sqlite3/migrator";
import db from "./driver";
import { db } from "./driver";
import path from "path";
const migrationsFolder = path.join("server/migrations");

View file

@ -6,12 +6,26 @@ export const domains = sqliteTable("domains", {
baseDomain: text("baseDomain").notNull(),
configManaged: integer("configManaged", { mode: "boolean" })
.notNull()
.default(false)
.default(false),
type: text("type"), // "ns", "cname", "wildcard"
verified: integer("verified", { mode: "boolean" }).notNull().default(false),
failed: integer("failed", { mode: "boolean" }).notNull().default(false),
tries: integer("tries").notNull().default(0)
});
export const orgs = sqliteTable("orgs", {
orgId: text("orgId").primaryKey(),
name: text("name").notNull()
name: text("name").notNull(),
subnet: text("subnet").notNull(),
});
export const userDomains = sqliteTable("userDomains", {
userId: text("userId")
.notNull()
.references(() => users.userId, { onDelete: "cascade" }),
domainId: text("domainId")
.notNull()
.references(() => domains.domainId, { onDelete: "cascade" })
});
export const orgDomains = sqliteTable("orgDomains", {
@ -36,12 +50,19 @@ export const sites = sqliteTable("sites", {
}),
name: text("name").notNull(),
pubKey: text("pubKey"),
subnet: text("subnet").notNull(),
megabytesIn: integer("bytesIn"),
megabytesOut: integer("bytesOut"),
subnet: text("subnet"),
megabytesIn: integer("bytesIn").default(0),
megabytesOut: integer("bytesOut").default(0),
lastBandwidthUpdate: text("lastBandwidthUpdate"),
type: text("type").notNull(), // "newt" or "wireguard"
online: integer("online", { mode: "boolean" }).notNull().default(false),
// exit node stuff that is how to connect to the site when it has a wg server
address: text("address"), // this is the address of the wireguard interface in newt
endpoint: text("endpoint"), // this is how to reach gerbil externally - gets put into the wireguard config
publicKey: text("publicKey"), // TODO: Fix typo in publicKey
lastHolePunch: integer("lastHolePunch"),
listenPort: integer("listenPort"),
dockerSocketEnabled: integer("dockerSocketEnabled", { mode: "boolean" })
.notNull()
.default(true)
@ -109,7 +130,8 @@ export const exitNodes = sqliteTable("exitNodes", {
endpoint: text("endpoint").notNull(), // this is how to reach gerbil externally - gets put into the wireguard config
publicKey: text("publicKey").notNull(),
listenPort: integer("listenPort").notNull(),
reachableAt: text("reachableAt") // this is the internal address of the gerbil http server for command control
reachableAt: text("reachableAt"), // this is the internal address of the gerbil http server for command control
maxConnections: integer("maxConnections")
});
export const users = sqliteTable("user", {
@ -165,11 +187,54 @@ export const newts = sqliteTable("newt", {
newtId: text("id").primaryKey(),
secretHash: text("secretHash").notNull(),
dateCreated: text("dateCreated").notNull(),
version: text("version"),
siteId: integer("siteId").references(() => sites.siteId, {
onDelete: "cascade"
})
});
export const clients = sqliteTable("clients", {
clientId: integer("id").primaryKey({ autoIncrement: true }),
orgId: text("orgId")
.references(() => orgs.orgId, {
onDelete: "cascade"
})
.notNull(),
exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, {
onDelete: "set null"
}),
name: text("name").notNull(),
pubKey: text("pubKey"),
subnet: text("subnet").notNull(),
megabytesIn: integer("bytesIn"),
megabytesOut: integer("bytesOut"),
lastBandwidthUpdate: text("lastBandwidthUpdate"),
lastPing: text("lastPing"),
type: text("type").notNull(), // "olm"
online: integer("online", { mode: "boolean" }).notNull().default(false),
endpoint: text("endpoint"),
lastHolePunch: integer("lastHolePunch")
});
export const clientSites = sqliteTable("clientSites", {
clientId: integer("clientId")
.notNull()
.references(() => clients.clientId, { onDelete: "cascade" }),
siteId: integer("siteId")
.notNull()
.references(() => sites.siteId, { onDelete: "cascade" }),
isRelayed: integer("isRelayed", { mode: "boolean" }).notNull().default(false)
});
export const olms = sqliteTable("olms", {
olmId: text("id").primaryKey(),
secretHash: text("secretHash").notNull(),
dateCreated: text("dateCreated").notNull(),
clientId: integer("clientId").references(() => clients.clientId, {
onDelete: "cascade"
})
});
export const twoFactorBackupCodes = sqliteTable("twoFactorBackupCodes", {
codeId: integer("id").primaryKey({ autoIncrement: true }),
userId: text("userId")
@ -194,6 +259,14 @@ export const newtSessions = sqliteTable("newtSession", {
expiresAt: integer("expiresAt").notNull()
});
export const olmSessions = sqliteTable("clientSession", {
sessionId: text("id").primaryKey(),
olmId: text("olmId")
.notNull()
.references(() => olms.olmId, { onDelete: "cascade" }),
expiresAt: integer("expiresAt").notNull()
});
export const userOrgs = sqliteTable("userOrgs", {
userId: text("userId")
.notNull()
@ -289,6 +362,24 @@ export const userSites = sqliteTable("userSites", {
.references(() => sites.siteId, { onDelete: "cascade" })
});
export const userClients = sqliteTable("userClients", {
userId: text("userId")
.notNull()
.references(() => users.userId, { onDelete: "cascade" }),
clientId: integer("clientId")
.notNull()
.references(() => clients.clientId, { onDelete: "cascade" })
});
export const roleClients = sqliteTable("roleClients", {
roleId: integer("roleId")
.notNull()
.references(() => roles.roleId, { onDelete: "cascade" }),
clientId: integer("clientId")
.notNull()
.references(() => clients.clientId, { onDelete: "cascade" })
});
export const roleResources = sqliteTable("roleResources", {
roleId: integer("roleId")
.notNull()
@ -547,6 +638,8 @@ export type Target = InferSelectModel<typeof targets>;
export type Session = InferSelectModel<typeof sessions>;
export type Newt = InferSelectModel<typeof newts>;
export type NewtSession = InferSelectModel<typeof newtSessions>;
export type Olm = InferSelectModel<typeof olms>;
export type OlmSession = InferSelectModel<typeof olmSessions>;
export type EmailVerificationCode = InferSelectModel<
typeof emailVerificationCodes
>;
@ -572,8 +665,13 @@ export type ResourceWhitelist = InferSelectModel<typeof resourceWhitelist>;
export type VersionMigration = InferSelectModel<typeof versionMigrations>;
export type ResourceRule = InferSelectModel<typeof resourceRules>;
export type Domain = InferSelectModel<typeof domains>;
export type Client = InferSelectModel<typeof clients>;
export type ClientSite = InferSelectModel<typeof clientSites>;
export type RoleClient = InferSelectModel<typeof roleClients>;
export type UserClient = InferSelectModel<typeof userClients>;
export type SupporterKey = InferSelectModel<typeof supporterKey>;
export type Idp = InferSelectModel<typeof idp>;
export type ApiKey = InferSelectModel<typeof apiKeys>;
export type ApiKeyAction = InferSelectModel<typeof apiKeyActions>;
export type ApiKeyOrg = InferSelectModel<typeof apiKeyOrg>;
export type OrgDomains = InferSelectModel<typeof orgDomains>;

View file

@ -2,6 +2,7 @@ import { render } from "@react-email/render";
import { ReactElement } from "react";
import emailClient from "@server/emails";
import logger from "@server/logger";
import config from "@server/lib/config";
export async function sendEmail(
template: ReactElement,
@ -24,9 +25,11 @@ export async function sendEmail(
const emailHtml = await render(template);
const appName = "Fossorial - Pangolin";
await emailClient.sendMail({
from: {
name: opts.name || "Pangolin",
name: opts.name || appName,
address: opts.from,
},
to: opts.to,

View file

@ -1,11 +1,5 @@
import {
Body,
Head,
Html,
Preview,
Tailwind
} from "@react-email/components";
import * as React from "react";
import React from "react";
import { Body, Head, Html, Preview, Tailwind } from "@react-email/components";
import { themeColors } from "./lib/theme";
import {
EmailContainer,
@ -22,29 +16,29 @@ interface Props {
}
export const ConfirmPasswordReset = ({ email }: Props) => {
const previewText = `Your password has been reset`;
const previewText = `Your password has been successfully reset.`;
return (
<Html>
<Head />
<Preview>{previewText}</Preview>
<Tailwind config={themeColors}>
<Body className="font-sans relative">
<Body className="font-sans bg-gray-50">
<EmailContainer>
<EmailLetterHead />
<EmailHeading>Password Reset Confirmation</EmailHeading>
{/* <EmailHeading>Password Successfully Reset</EmailHeading> */}
<EmailGreeting>Hi {email || "there"},</EmailGreeting>
<EmailGreeting>Hi there,</EmailGreeting>
<EmailText>
This email confirms that your password has just been
reset. If you made this change, no further action is
required.
Your password has been successfully reset. You can
now sign in to your account using your new password.
</EmailText>
<EmailText>
Thank you for keeping your account secure.
If you didn't make this change, please contact our
support team immediately to secure your account.
</EmailText>
<EmailFooter>

View file

@ -1,11 +1,5 @@
import {
Body,
Head,
Html,
Preview,
Tailwind
} from "@react-email/components";
import * as React from "react";
import React from "react";
import { Body, Head, Html, Preview, Tailwind } from "@react-email/components";
import { themeColors } from "./lib/theme";
import {
EmailContainer,
@ -18,6 +12,7 @@ import {
EmailText
} from "./components/Email";
import CopyCodeBox from "./components/CopyCodeBox";
import ButtonLink from "./components/ButtonLink";
interface Props {
email: string;
@ -26,37 +21,39 @@ interface Props {
}
export const ResetPasswordCode = ({ email, code, link }: Props) => {
const previewText = `Your password reset code is ${code}`;
const previewText = `Reset your password with code: ${code}`;
return (
<Html>
<Head />
<Preview>{previewText}</Preview>
<Tailwind config={themeColors}>
<Body className="font-sans">
<Body className="font-sans bg-gray-50">
<EmailContainer>
<EmailLetterHead />
<EmailHeading>Password Reset Request</EmailHeading>
{/* <EmailHeading>Reset Your Password</EmailHeading> */}
<EmailGreeting>Hi {email || "there"},</EmailGreeting>
<EmailGreeting>Hi there,</EmailGreeting>
<EmailText>
Youve requested to reset your password. Please{" "}
<a href={link} className="text-primary">
click here
</a>{" "}
and follow the instructions to reset your password,
or manually enter the following code:
You've requested to reset your password. Click the
button below to reset your password, or use the
verification code provided if prompted.
</EmailText>
<EmailSection>
<ButtonLink href={link}>Reset Password</ButtonLink>
</EmailSection>
<EmailSection>
<CopyCodeBox text={code} />
</EmailSection>
<EmailText>
If you didnt request this, you can safely ignore
this email.
This reset code will expire in 2 hours. If you
didn't request a password reset, you can safely
ignore this email.
</EmailText>
<EmailFooter>

View file

@ -1,11 +1,5 @@
import {
Body,
Head,
Html,
Preview,
Tailwind
} from "@react-email/components";
import * as React from "react";
import React from "react";
import { Body, Head, Html, Preview, Tailwind } from "@react-email/components";
import {
EmailContainer,
EmailLetterHead,
@ -32,34 +26,40 @@ export const ResourceOTPCode = ({
orgName: organizationName,
otp
}: ResourceOTPCodeProps) => {
const previewText = `Your one-time password for ${resourceName} is ${otp}`;
const previewText = `Your access code for ${resourceName}: ${otp}`;
return (
<Html>
<Head />
<Preview>{previewText}</Preview>
<Tailwind config={themeColors}>
<Body className="font-sans">
<Body className="font-sans bg-gray-50">
<EmailContainer>
<EmailLetterHead />
<EmailHeading>
Your One-Time Code for {resourceName}
</EmailHeading>
{/* <EmailHeading> */}
{/* Access Code for {resourceName} */}
{/* </EmailHeading> */}
<EmailGreeting>Hi {email || "there"},</EmailGreeting>
<EmailGreeting>Hi there,</EmailGreeting>
<EmailText>
Youve requested a one-time password to access{" "}
You've requested access to{" "}
<strong>{resourceName}</strong> in{" "}
<strong>{organizationName}</strong>. Use the code
below to complete your authentication:
<strong>{organizationName}</strong>. Use the
verification code below to complete your
authentication.
</EmailText>
<EmailSection>
<CopyCodeBox text={otp} />
</EmailSection>
<EmailText>
This code will expire in 15 minutes. If you didn't
request this code, please ignore this email.
</EmailText>
<EmailFooter>
<EmailSignature />
</EmailFooter>

View file

@ -1,11 +1,5 @@
import {
Body,
Head,
Html,
Preview,
Tailwind,
} from "@react-email/components";
import * as React from "react";
import React from "react";
import { Body, Head, Html, Preview, Tailwind } from "@react-email/components";
import { themeColors } from "./lib/theme";
import {
EmailContainer,
@ -41,35 +35,44 @@ export const SendInviteLink = ({
<Head />
<Preview>{previewText}</Preview>
<Tailwind config={themeColors}>
<Body className="font-sans">
<Body className="font-sans bg-gray-50">
<EmailContainer>
<EmailLetterHead />
<EmailHeading>Invited to Join {orgName}</EmailHeading>
{/* <EmailHeading> */}
{/* You're Invited to Join {orgName} */}
{/* </EmailHeading> */}
<EmailGreeting>Hi {email || "there"},</EmailGreeting>
<EmailGreeting>Hi there,</EmailGreeting>
<EmailText>
Youve been invited to join the organization{" "}
You've been invited to join{" "}
<strong>{orgName}</strong>
{inviterName ? ` by ${inviterName}.` : "."} Please
access the link below to accept the invite.
</EmailText>
<EmailText>
This invite will expire in{" "}
<strong>
{expiresInDays}{" "}
{expiresInDays === "1" ? "day" : "days"}.
</strong>
{inviterName ? ` by ${inviterName}` : ""}. Click the
button below to accept your invitation and get
started.
</EmailText>
<EmailSection>
<ButtonLink href={inviteLink}>
Accept Invite to {orgName}
Accept Invitation
</ButtonLink>
</EmailSection>
{/* <EmailText> */}
{/* If you're having trouble clicking the button, copy */}
{/* and paste the URL below into your web browser: */}
{/* <br /> */}
{/* <span className="break-all">{inviteLink}</span> */}
{/* </EmailText> */}
<EmailText>
This invite expires in {expiresInDays}{" "}
{expiresInDays === "1" ? "day" : "days"}. If the
link has expired, please contact the owner of the
organization to request a new invitation.
</EmailText>
<EmailFooter>
<EmailSignature />
</EmailFooter>

View file

@ -1,11 +1,5 @@
import {
Body,
Head,
Html,
Preview,
Tailwind
} from "@react-email/components";
import * as React from "react";
import React from "react";
import { Body, Head, Html, Preview, Tailwind } from "@react-email/components";
import { themeColors } from "./lib/theme";
import {
EmailContainer,
@ -23,44 +17,52 @@ interface Props {
}
export const TwoFactorAuthNotification = ({ email, enabled }: Props) => {
const previewText = `Two-Factor Authentication has been ${enabled ? "enabled" : "disabled"}`;
const previewText = `Two-Factor Authentication ${enabled ? "enabled" : "disabled"} for your account`;
return (
<Html>
<Head />
<Preview>{previewText}</Preview>
<Tailwind config={themeColors}>
<Body className="font-sans">
<Body className="font-sans bg-gray-50">
<EmailContainer>
<EmailLetterHead />
<EmailHeading>
Two-Factor Authentication{" "}
{enabled ? "Enabled" : "Disabled"}
</EmailHeading>
{/* <EmailHeading> */}
{/* Security Update: 2FA{" "} */}
{/* {enabled ? "Enabled" : "Disabled"} */}
{/* </EmailHeading> */}
<EmailGreeting>Hi {email || "there"},</EmailGreeting>
<EmailGreeting>Hi there,</EmailGreeting>
<EmailText>
This email confirms that Two-Factor Authentication
has been successfully{" "}
{enabled ? "enabled" : "disabled"} on your account.
Two-factor authentication has been successfully{" "}
<strong>{enabled ? "enabled" : "disabled"}</strong>{" "}
on your account.
</EmailText>
{enabled ? (
<EmailText>
With Two-Factor Authentication enabled, your
account is now more secure. Please ensure you
keep your authentication method safe.
</EmailText>
<>
<EmailText>
Your account is now protected with an
additional layer of security. Keep your
authentication method safe and accessible.
</EmailText>
</>
) : (
<EmailText>
With Two-Factor Authentication disabled, your
account may be less secure. We recommend
enabling it to protect your account.
</EmailText>
<>
<EmailText>
We recommend re-enabling two-factor
authentication to keep your account secure.
</EmailText>
</>
)}
<EmailText>
If you didn't make this change, please contact our
support team immediately.
</EmailText>
<EmailFooter>
<EmailSignature />
</EmailFooter>

View file

@ -1,5 +1,5 @@
import React from "react";
import { Body, Head, Html, Preview, Tailwind } from "@react-email/components";
import * as React from "react";
import { themeColors } from "./lib/theme";
import {
EmailContainer,
@ -24,25 +24,24 @@ export const VerifyEmail = ({
verificationCode,
verifyLink
}: VerifyEmailProps) => {
const previewText = `Your verification code is ${verificationCode}`;
const previewText = `Verify your email with code: ${verificationCode}`;
return (
<Html>
<Head />
<Preview>{previewText}</Preview>
<Tailwind config={themeColors}>
<Body className="font-sans">
<Body className="font-sans bg-gray-50">
<EmailContainer>
<EmailLetterHead />
<EmailHeading>Please Verify Your Email</EmailHeading>
{/* <EmailHeading>Verify Your Email Address</EmailHeading> */}
<EmailGreeting>Hi {username || "there"},</EmailGreeting>
<EmailGreeting>Hi there,</EmailGreeting>
<EmailText>
Youve requested to verify your email. Please use
the code below to complete the verification process
upon logging in.
Welcome! To complete your account setup, please
verify your email address using the code below.
</EmailText>
<EmailSection>
@ -50,7 +49,8 @@ export const VerifyEmail = ({
</EmailSection>
<EmailText>
If you didnt request this, you can safely ignore
This verification code will expire in 15 minutes. If
you didn't create an account, you can safely ignore
this email.
</EmailText>

View file

@ -0,0 +1,131 @@
import React from "react";
import { Body, Head, Html, Preview, Tailwind } from "@react-email/components";
import { themeColors } from "./lib/theme";
import {
EmailContainer,
EmailFooter,
EmailGreeting,
EmailHeading,
EmailLetterHead,
EmailSection,
EmailSignature,
EmailText,
EmailInfoSection
} from "./components/Email";
import ButtonLink from "./components/ButtonLink";
import CopyCodeBox from "./components/CopyCodeBox";
interface WelcomeQuickStartProps {
username?: string;
link: string;
fallbackLink: string;
resourceMethod: string;
resourceHostname: string;
resourcePort: string | number;
resourceUrl: string;
cliCommand: string;
}
export const WelcomeQuickStart = ({
username,
link,
fallbackLink,
resourceMethod,
resourceHostname,
resourcePort,
resourceUrl,
cliCommand
}: WelcomeQuickStartProps) => {
const previewText = "Welcome! Here's what to do next";
return (
<Html>
<Head />
<Preview>{previewText}</Preview>
<Tailwind config={themeColors}>
<Body className="font-sans bg-gray-50">
<EmailContainer>
<EmailLetterHead />
<EmailGreeting>Hi there,</EmailGreeting>
<EmailText>
Thank you for trying out Pangolin! We're excited to
have you on board.
</EmailText>
<EmailText>
To continue to configure your site, resources, and
other features, complete your account setup to
access the full dashboard.
</EmailText>
<EmailSection>
<ButtonLink href={link}>
View Your Dashboard
</ButtonLink>
{/* <p className="text-sm text-gray-300 mt-2"> */}
{/* If the button above doesn't work, you can also */}
{/* use this{" "} */}
{/* <a href={fallbackLink} className="underline"> */}
{/* link */}
{/* </a> */}
{/* . */}
{/* </p> */}
</EmailSection>
<EmailSection>
<div className="mb-2 font-semibold text-gray-900 text-base text-left">
Connect your site using Newt
</div>
<div className="inline-block w-full">
<div className="bg-gray-50 border border-gray-200 rounded-lg px-6 py-4 mx-auto text-left">
<span className="text-sm font-mono text-gray-900 tracking-wider">
{cliCommand}
</span>
</div>
<p className="text-xs text-gray-500 mt-2">
To learn how to use Newt, including more
installation methods, visit the{" "}
<a
href="https://docs.fossorial.io"
className="underline"
>
docs
</a>
.
</p>
</div>
</EmailSection>
<EmailInfoSection
title="Your Demo Resource"
items={[
{ label: "Method", value: resourceMethod },
{ label: "Hostname", value: resourceHostname },
{ label: "Port", value: resourcePort },
{
label: "Resource URL",
value: (
<a
href={resourceUrl}
className="underline text-blue-600"
>
{resourceUrl}
</a>
)
}
]}
/>
<EmailFooter>
<EmailSignature />
</EmailFooter>
</EmailContainer>
</Body>
</Tailwind>
</Html>
);
};
export default WelcomeQuickStart;

View file

@ -12,7 +12,11 @@ export default function ButtonLink({
return (
<a
href={href}
className={`rounded-full bg-primary px-4 py-2 text-center font-semibold text-white text-xl no-underline inline-block ${className}`}
className={`inline-block bg-primary hover:bg-primary/90 text-white font-semibold px-8 py-3 rounded-lg text-center no-underline transition-colors ${className}`}
style={{
backgroundColor: "#F97316",
textDecoration: "none"
}}
>
{children}
</a>

View file

@ -2,10 +2,15 @@ import React from "react";
export default function CopyCodeBox({ text }: { text: string }) {
return (
<div className="text-center rounded-lg bg-neutral-100 p-2">
<span className="text-2xl font-mono text-neutral-600 tracking-wide">
{text}
</span>
<div className="inline-block">
<div className="bg-gray-50 border border-gray-200 rounded-lg px-6 py-4 mx-auto">
<span className="text-2xl font-mono text-gray-900 tracking-wider font-semibold">
{text}
</span>
</div>
<p className="text-xs text-gray-500 mt-2">
Copy and paste this code when prompted
</p>
</div>
);
}

View file

@ -1,47 +1,26 @@
import { Container } from "@react-email/components";
import React from "react";
import { Container, Img } from "@react-email/components";
// EmailContainer: Wraps the entire email layout
export function EmailContainer({ children }: { children: React.ReactNode }) {
return (
<Container className="bg-white border border-solid border-gray-200 p-6 max-w-lg mx-auto my-8 rounded-lg">
<Container className="bg-white border border-solid border-gray-200 max-w-lg mx-auto my-8 rounded-lg overflow-hidden shadow-sm">
{children}
</Container>
);
}
// EmailLetterHead: For branding or logo at the top
// EmailLetterHead: For branding with logo on dark background
export function EmailLetterHead() {
return (
<div className="mb-4">
<table
role="presentation"
width="100%"
style={{
marginBottom: "24px"
}}
>
<tr>
<td
style={{
fontSize: "14px",
fontWeight: "bold",
color: "#F97317"
}}
>
Pangolin
</td>
<td
style={{
fontSize: "14px",
textAlign: "right",
color: "#6B7280"
}}
>
{new Date().getFullYear()}
</td>
</tr>
</table>
<div className="px-6 pt-8 pb-2 text-center">
<Img
src="https://fossorial-public-assets.s3.us-east-1.amazonaws.com/word_mark_black.png"
alt="Fossorial"
width="120"
height="auto"
className="mx-auto"
/>
</div>
);
}
@ -49,14 +28,22 @@ export function EmailLetterHead() {
// EmailHeading: For the primary message or headline
export function EmailHeading({ children }: { children: React.ReactNode }) {
return (
<h1 className="text-2xl font-semibold text-gray-800 text-center">
{children}
</h1>
<div className="px-6 pt-4 pb-1">
<h1 className="text-2xl font-semibold text-gray-900 text-center leading-tight">
{children}
</h1>
</div>
);
}
export function EmailGreeting({ children }: { children: React.ReactNode }) {
return <p className="text-base text-gray-700 my-4">{children}</p>;
return (
<div className="px-6">
<p className="text-base text-gray-700 leading-relaxed">
{children}
</p>
</div>
);
}
// EmailText: For general text content
@ -68,9 +55,13 @@ export function EmailText({
className?: string;
}) {
return (
<p className={`my-2 text-base text-gray-700 ${className}`}>
{children}
</p>
<div className="px-6">
<p
className={`text-base text-gray-700 leading-relaxed ${className}`}
>
{children}
</p>
</div>
);
}
@ -82,20 +73,70 @@ export function EmailSection({
children: React.ReactNode;
className?: string;
}) {
return <div className={`text-center my-6 ${className}`}>{children}</div>;
return (
<div className={`px-6 py-6 text-center ${className}`}>{children}</div>
);
}
// EmailFooter: For closing or signature
export function EmailFooter({ children }: { children: React.ReactNode }) {
return <div className="text-sm text-gray-500 mt-6">{children}</div>;
return (
<div className="px-6 py-6 border-t border-gray-100 bg-gray-50">
{children}
<p className="text-xs text-gray-400 mt-4">
For any questions or support, please contact us at:
<br />
support@fossorial.io
</p>
<p className="text-xs text-gray-300 text-center mt-4">
&copy; {new Date().getFullYear()} Fossorial, Inc. All rights
reserved.
</p>
</div>
);
}
export function EmailSignature() {
return (
<p>
Best regards,
<br />
Fossorial
</p>
<div className="text-sm text-gray-600">
<p className="mb-2">
Best regards,
<br />
<strong>The Fossorial Team</strong>
</p>
</div>
);
}
// EmailInfoSection: For structured key-value info (like resource details)
export function EmailInfoSection({
title,
items
}: {
title?: string;
items: { label: string; value: React.ReactNode }[];
}) {
return (
<div className="px-6 py-4">
{title && (
<div className="mb-2 font-semibold text-gray-900 text-base">
{title}
</div>
)}
<table className="w-full text-sm text-left">
<tbody>
{items.map((item, idx) => (
<tr key={idx}>
<td className="pr-4 py-1 text-gray-600 align-top whitespace-nowrap">
{item.label}
</td>
<td className="py-1 text-gray-900 break-all">
{item.value}
</td>
</tr>
))}
</tbody>
</table>
</div>
);
}

View file

@ -1,3 +1,5 @@
import React from "react";
export const themeColors = {
theme: {
extend: {

View file

@ -9,6 +9,7 @@ import { createIntegrationApiServer } from "./integrationApiServer";
import config from "@server/lib/config";
async function startServers() {
await config.initServer();
await runSetupFunctions();
// Start all servers

View file

@ -20,8 +20,9 @@ const externalPort = config.getRawConfig().server.integration_port;
export function createIntegrationApiServer() {
const apiServer = express();
if (config.getRawConfig().server.trust_proxy) {
apiServer.set("trust proxy", 1);
const trustProxy = config.getRawConfig().server.trust_proxy;
if (trustProxy) {
apiServer.set("trust proxy", trustProxy);
}
apiServer.use(cors());

View file

@ -17,10 +17,6 @@ export class Config {
isDev: boolean = process.env.ENVIRONMENT !== "prod";
constructor() {
this.load();
}
public load() {
const environment = readConfigFile();
const {
@ -90,17 +86,36 @@ export class Config {
? "true"
: "false";
process.env.DASHBOARD_URL = parsedConfig.app.dashboard_url;
process.env.FLAGS_DISABLE_LOCAL_SITES = parsedConfig.flags
?.disable_local_sites
? "true"
: "false";
process.env.FLAGS_DISABLE_BASIC_WIREGUARD_SITES = parsedConfig.flags
?.disable_basic_wireguard_sites
? "true"
: "false";
license.setServerSecret(parsedConfig.server.secret);
this.checkKeyStatus();
process.env.FLAGS_ENABLE_CLIENTS = parsedConfig.flags?.enable_clients
? "true"
: "false";
this.rawConfig = parsedConfig;
}
public async initServer() {
if (!this.rawConfig) {
throw new Error("Config not loaded. Call load() first.");
}
license.setServerSecret(this.rawConfig.server.secret);
await this.checkKeyStatus();
}
private async checkKeyStatus() {
const licenseStatus = await license.check();
if (!licenseStatus.isHostLicensed) {
if (
!licenseStatus.isHostLicensed
) {
this.checkSupporterKey();
}
}
@ -116,6 +131,9 @@ export class Config {
}
public getDomain(domainId: string) {
if (!this.rawConfig.domains || !this.rawConfig.domains[domainId]) {
return null;
}
return this.rawConfig.domains[domainId];
}

View file

@ -4,7 +4,14 @@ import { assertEquals } from "@test/assert";
// Test cases
function testFindNextAvailableCidr() {
console.log("Running findNextAvailableCidr tests...");
// Test 0: Basic IPv4 allocation with a subnet in the wrong range
{
const existing = ["100.90.130.1/30", "100.90.128.4/30"];
const result = findNextAvailableCidr(existing, 30, "100.90.130.1/24");
assertEquals(result, "100.90.130.4/30", "Basic IPv4 allocation failed");
}
// Test 1: Basic IPv4 allocation
{
const existing = ["10.0.0.0/16", "10.1.0.0/16"];
@ -26,6 +33,12 @@ function testFindNextAvailableCidr() {
assertEquals(result, null, "No available space test failed");
}
// Test 4: Empty existing
{
const existing: string[] = [];
const result = findNextAvailableCidr(existing, 30, "10.0.0.0/8");
assertEquals(result, "10.0.0.0/30", "Empty existing test failed");
}
// // Test 4: IPv6 allocation
// {
// const existing = ["2001:db8::/32", "2001:db8:1::/32"];

View file

@ -1,3 +1,8 @@
import { db } from "@server/db";
import { clients, orgs, sites } from "@server/db";
import { and, eq, isNotNull } from "drizzle-orm";
import config from "@server/lib/config";
interface IPRange {
start: bigint;
end: bigint;
@ -9,7 +14,7 @@ type IPVersion = 4 | 6;
* Detects IP version from address string
*/
function detectIpVersion(ip: string): IPVersion {
return ip.includes(':') ? 6 : 4;
return ip.includes(":") ? 6 : 4;
}
/**
@ -19,34 +24,34 @@ function ipToBigInt(ip: string): bigint {
const version = detectIpVersion(ip);
if (version === 4) {
return ip.split('.')
.reduce((acc, octet) => {
const num = parseInt(octet);
if (isNaN(num) || num < 0 || num > 255) {
throw new Error(`Invalid IPv4 octet: ${octet}`);
}
return BigInt.asUintN(64, (acc << BigInt(8)) + BigInt(num));
}, BigInt(0));
return ip.split(".").reduce((acc, octet) => {
const num = parseInt(octet);
if (isNaN(num) || num < 0 || num > 255) {
throw new Error(`Invalid IPv4 octet: ${octet}`);
}
return BigInt.asUintN(64, (acc << BigInt(8)) + BigInt(num));
}, BigInt(0));
} else {
// Handle IPv6
// Expand :: notation
let fullAddress = ip;
if (ip.includes('::')) {
const parts = ip.split('::');
if (parts.length > 2) throw new Error('Invalid IPv6 address: multiple :: found');
const missing = 8 - (parts[0].split(':').length + parts[1].split(':').length);
const padding = Array(missing).fill('0').join(':');
if (ip.includes("::")) {
const parts = ip.split("::");
if (parts.length > 2)
throw new Error("Invalid IPv6 address: multiple :: found");
const missing =
8 - (parts[0].split(":").length + parts[1].split(":").length);
const padding = Array(missing).fill("0").join(":");
fullAddress = `${parts[0]}:${padding}:${parts[1]}`;
}
return fullAddress.split(':')
.reduce((acc, hextet) => {
const num = parseInt(hextet || '0', 16);
if (isNaN(num) || num < 0 || num > 65535) {
throw new Error(`Invalid IPv6 hextet: ${hextet}`);
}
return BigInt.asUintN(128, (acc << BigInt(16)) + BigInt(num));
}, BigInt(0));
return fullAddress.split(":").reduce((acc, hextet) => {
const num = parseInt(hextet || "0", 16);
if (isNaN(num) || num < 0 || num > 65535) {
throw new Error(`Invalid IPv6 hextet: ${hextet}`);
}
return BigInt.asUintN(128, (acc << BigInt(16)) + BigInt(num));
}, BigInt(0));
}
}
@ -60,11 +65,15 @@ function bigIntToIp(num: bigint, version: IPVersion): string {
octets.unshift(Number(num & BigInt(255)));
num = num >> BigInt(8);
}
return octets.join('.');
return octets.join(".");
} else {
const hextets: string[] = [];
for (let i = 0; i < 8; i++) {
hextets.unshift(Number(num & BigInt(65535)).toString(16).padStart(4, '0'));
hextets.unshift(
Number(num & BigInt(65535))
.toString(16)
.padStart(4, "0")
);
num = num >> BigInt(16);
}
// Compress zero sequences
@ -74,7 +83,7 @@ function bigIntToIp(num: bigint, version: IPVersion): string {
let currentZeroLength = 0;
for (let i = 0; i < hextets.length; i++) {
if (hextets[i] === '0000') {
if (hextets[i] === "0000") {
if (currentZeroStart === -1) currentZeroStart = i;
currentZeroLength++;
if (currentZeroLength > maxZeroLength) {
@ -88,12 +97,14 @@ function bigIntToIp(num: bigint, version: IPVersion): string {
}
if (maxZeroLength > 1) {
hextets.splice(maxZeroStart, maxZeroLength, '');
if (maxZeroStart === 0) hextets.unshift('');
if (maxZeroStart + maxZeroLength === 8) hextets.push('');
hextets.splice(maxZeroStart, maxZeroLength, "");
if (maxZeroStart === 0) hextets.unshift("");
if (maxZeroStart + maxZeroLength === 8) hextets.push("");
}
return hextets.map(h => h === '0000' ? '0' : h.replace(/^0+/, '')).join(':');
return hextets
.map((h) => (h === "0000" ? "0" : h.replace(/^0+/, "")))
.join(":");
}
}
@ -101,7 +112,7 @@ function bigIntToIp(num: bigint, version: IPVersion): string {
* Converts CIDR to IP range
*/
export function cidrToRange(cidr: string): IPRange {
const [ip, prefix] = cidr.split('/');
const [ip, prefix] = cidr.split("/");
const version = detectIpVersion(ip);
const prefixBits = parseInt(prefix);
const ipBigInt = ipToBigInt(ip);
@ -113,7 +124,10 @@ export function cidrToRange(cidr: string): IPRange {
}
const shiftBits = BigInt(maxPrefix - prefixBits);
const mask = BigInt.asUintN(version === 4 ? 64 : 128, (BigInt(1) << shiftBits) - BigInt(1));
const mask = BigInt.asUintN(
version === 4 ? 64 : 128,
(BigInt(1) << shiftBits) - BigInt(1)
);
const start = ipBigInt & ~mask;
const end = start | mask;
@ -132,28 +146,32 @@ export function findNextAvailableCidr(
blockSize: number,
startCidr?: string
): string | null {
if (!startCidr && existingCidrs.length === 0) {
return null;
}
// If no existing CIDRs, use the IP version from startCidr
const version = startCidr
? detectIpVersion(startCidr.split('/')[0])
: 4; // Default to IPv4 if no startCidr provided
const version = startCidr ? detectIpVersion(startCidr.split("/")[0]) : 4; // Default to IPv4 if no startCidr provided
// Use appropriate default startCidr if none provided
startCidr = startCidr || (version === 4 ? "0.0.0.0/0" : "::/0");
// If there are existing CIDRs, ensure all are same version
if (existingCidrs.length > 0 &&
existingCidrs.some(cidr => detectIpVersion(cidr.split('/')[0]) !== version)) {
throw new Error('All CIDRs must be of the same IP version');
if (
existingCidrs.length > 0 &&
existingCidrs.some(
(cidr) => detectIpVersion(cidr.split("/")[0]) !== version
)
) {
throw new Error("All CIDRs must be of the same IP version");
}
// Extract the network part from startCidr to ensure we stay in the right subnet
const startCidrRange = cidrToRange(startCidr);
// Convert existing CIDRs to ranges and sort them
const existingRanges = existingCidrs
.map(cidr => cidrToRange(cidr))
.map((cidr) => cidrToRange(cidr))
.sort((a, b) => (a.start < b.start ? -1 : 1));
// Calculate block size
@ -161,14 +179,17 @@ export function findNextAvailableCidr(
const blockSizeBigInt = BigInt(1) << BigInt(maxPrefix - blockSize);
// Start from the beginning of the given CIDR
let current = cidrToRange(startCidr).start;
const maxIp = cidrToRange(startCidr).end;
let current = startCidrRange.start;
const maxIp = startCidrRange.end;
// Iterate through existing ranges
for (let i = 0; i <= existingRanges.length; i++) {
const nextRange = existingRanges[i];
// Align current to block size
const alignedCurrent = current + ((blockSizeBigInt - (current % blockSizeBigInt)) % blockSizeBigInt);
const alignedCurrent =
current +
((blockSizeBigInt - (current % blockSizeBigInt)) % blockSizeBigInt);
// Check if we've gone beyond the maximum allowed IP
if (alignedCurrent + blockSizeBigInt - BigInt(1) > maxIp) {
@ -176,12 +197,18 @@ export function findNextAvailableCidr(
}
// If we're at the end of existing ranges or found a gap
if (!nextRange || alignedCurrent + blockSizeBigInt - BigInt(1) < nextRange.start) {
if (
!nextRange ||
alignedCurrent + blockSizeBigInt - BigInt(1) < nextRange.start
) {
return `${bigIntToIp(alignedCurrent, version)}/${blockSize}`;
}
// Move current pointer to after the current range
current = nextRange.end + BigInt(1);
// If next range overlaps with our search space, move past it
if (nextRange.end >= startCidrRange.start && nextRange.start <= maxIp) {
// Move current pointer to after the current range
current = nextRange.end + BigInt(1);
}
}
return null;
@ -195,7 +222,7 @@ export function findNextAvailableCidr(
*/
export function isIpInCidr(ip: string, cidr: string): boolean {
const ipVersion = detectIpVersion(ip);
const cidrVersion = detectIpVersion(cidr.split('/')[0]);
const cidrVersion = detectIpVersion(cidr.split("/")[0]);
// If IP versions don't match, the IP cannot be in the CIDR range
if (ipVersion !== cidrVersion) {
@ -207,3 +234,61 @@ export function isIpInCidr(ip: string, cidr: string): boolean {
const range = cidrToRange(cidr);
return ipBigInt >= range.start && ipBigInt <= range.end;
}
export async function getNextAvailableClientSubnet(
orgId: string
): Promise<string> {
const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId));
const existingAddressesSites = await db
.select({
address: sites.address
})
.from(sites)
.where(and(isNotNull(sites.address), eq(sites.orgId, orgId)));
const existingAddressesClients = await db
.select({
address: clients.subnet
})
.from(clients)
.where(and(isNotNull(clients.subnet), eq(clients.orgId, orgId)));
const addresses = [
...existingAddressesSites.map(
(site) => `${site.address?.split("/")[0]}/32`
), // we are overriding the 32 so that we pick individual addresses in the subnet of the org for the site and the client even though they are stored with the /block_size of the org
...existingAddressesClients.map(
(client) => `${client.address.split("/")}/32`
)
].filter((address) => address !== null) as string[];
let subnet = findNextAvailableCidr(addresses, 32, org.subnet); // pick the sites address in the org
if (!subnet) {
throw new Error("No available subnets remaining in space");
}
return subnet;
}
export async function getNextAvailableOrgSubnet(): Promise<string> {
const existingAddresses = await db
.select({
subnet: orgs.subnet
})
.from(orgs)
.where(isNotNull(orgs.subnet));
const addresses = existingAddresses.map((org) => org.subnet);
let subnet = findNextAvailableCidr(
addresses,
config.getRawConfig().orgs.block_size,
config.getRawConfig().orgs.subnet_group
);
if (!subnet) {
throw new Error("No available subnets remaining in space");
}
return subnet;
}

View file

@ -0,0 +1,16 @@
import { MemoryStore, Store } from "express-rate-limit";
import config from "./config";
import redisManager from "@server/db/redis";
import { RedisStore } from "rate-limit-redis";
export function createStore(): Store {
let rateLimitStore: Store = new MemoryStore();
if (config.getRawConfig().flags?.enable_redis) {
const client = redisManager.client!;
rateLimitStore = new RedisStore({
sendCommand: async (command: string, ...args: string[]) =>
(await client.call(command, args)) as any
});
}
return rateLimitStore;
}

View file

@ -3,6 +3,7 @@ import yaml from "js-yaml";
import { configFilePath1, configFilePath2 } from "./consts";
import { z } from "zod";
import stoi from "./stoi";
import { build } from "@server/build";
const portSchema = z.number().positive().gt(0).lte(65535);
@ -10,214 +11,279 @@ const getEnvOrYaml = (envVar: string) => (valFromYaml: any) => {
return process.env[envVar] ?? valFromYaml;
};
export const configSchema = z.object({
app: z.object({
dashboard_url: z
.string()
.url()
.optional()
.pipe(z.string().url())
.transform((url) => url.toLowerCase()),
log_level: z
.enum(["debug", "info", "warn", "error"])
.optional()
.default("info"),
save_logs: z.boolean().optional().default(false),
log_failed_attempts: z.boolean().optional().default(false)
}),
domains: z
.record(
z.string(),
z.object({
base_domain: z
.string()
.nonempty("base_domain must not be empty")
.transform((url) => url.toLowerCase()),
cert_resolver: z.string().optional().default("letsencrypt"),
prefer_wildcard_cert: z.boolean().optional().default(false)
})
)
.refine(
(domains) => {
const keys = Object.keys(domains);
if (keys.length === 0) {
return false;
}
return true;
},
{
message: "At least one domain must be defined"
}
),
server: z.object({
integration_port: portSchema
.optional()
.default(3003)
.transform(stoi)
.pipe(portSchema.optional()),
external_port: portSchema
.optional()
.default(3000)
.transform(stoi)
.pipe(portSchema),
internal_port: portSchema
.optional()
.default(3001)
.transform(stoi)
.pipe(portSchema),
next_port: portSchema
.optional()
.default(3002)
.transform(stoi)
.pipe(portSchema),
internal_hostname: z
.string()
.optional()
.default("pangolin")
.transform((url) => url.toLowerCase()),
session_cookie_name: z.string().optional().default("p_session_token"),
resource_access_token_param: z.string().optional().default("p_token"),
resource_access_token_headers: z
export const configSchema = z
.object({
app: z.object({
dashboard_url: z
.string()
.url()
.optional()
.pipe(z.string().url())
.transform((url) => url.toLowerCase()),
log_level: z
.enum(["debug", "info", "warn", "error"])
.optional()
.default("info"),
save_logs: z.boolean().optional().default(false),
log_failed_attempts: z.boolean().optional().default(false)
}),
domains: z
.record(
z.string(),
z.object({
base_domain: z
.string()
.nonempty("base_domain must not be empty")
.transform((url) => url.toLowerCase()),
cert_resolver: z.string().optional().default("letsencrypt"),
prefer_wildcard_cert: z.boolean().optional().default(false)
})
)
.optional(),
server: z.object({
integration_port: portSchema
.optional()
.default(3003)
.transform(stoi)
.pipe(portSchema.optional()),
external_port: portSchema
.optional()
.default(3000)
.transform(stoi)
.pipe(portSchema),
internal_port: portSchema
.optional()
.default(3001)
.transform(stoi)
.pipe(portSchema),
next_port: portSchema
.optional()
.default(3002)
.transform(stoi)
.pipe(portSchema),
internal_hostname: z
.string()
.optional()
.default("pangolin")
.transform((url) => url.toLowerCase()),
session_cookie_name: z
.string()
.optional()
.default("p_session_token"),
resource_access_token_param: z
.string()
.optional()
.default("p_token"),
resource_access_token_headers: z
.object({
id: z.string().optional().default("P-Access-Token-Id"),
token: z.string().optional().default("P-Access-Token")
})
.optional()
.default({}),
resource_session_request_param: z
.string()
.optional()
.default("resource_session_request_param"),
dashboard_session_length_hours: z
.number()
.positive()
.gt(0)
.optional()
.default(720),
resource_session_length_hours: z
.number()
.positive()
.gt(0)
.optional()
.default(720),
cors: z
.object({
origins: z.array(z.string()).optional(),
methods: z.array(z.string()).optional(),
allowed_headers: z.array(z.string()).optional(),
credentials: z.boolean().optional()
})
.optional(),
trust_proxy: z.number().int().gte(0).optional().default(1),
secret: z
.string()
.optional()
.transform(getEnvOrYaml("SERVER_SECRET"))
.pipe(z.string().min(8))
}),
postgres: z
.object({
id: z.string().optional().default("P-Access-Token-Id"),
token: z.string().optional().default("P-Access-Token")
connection_string: z.string(),
replicas: z
.array(
z.object({
connection_string: z.string()
})
)
.optional()
})
.optional(),
redis: z
.object({
host: z.string(),
port: portSchema,
password: z.string().optional(),
db: z.number().int().nonnegative().optional().default(0),
tls: z
.object({
reject_unauthorized: z
.boolean()
.optional()
.default(true)
})
.optional()
})
.optional(),
traefik: z
.object({
http_entrypoint: z.string().optional().default("web"),
https_entrypoint: z.string().optional().default("websecure"),
additional_middlewares: z.array(z.string()).optional()
})
.optional()
.default({}),
resource_session_request_param: z
.string()
.optional()
.default("resource_session_request_param"),
dashboard_session_length_hours: z
.number()
.positive()
.gt(0)
.optional()
.default(720),
resource_session_length_hours: z
.number()
.positive()
.gt(0)
.optional()
.default(720),
cors: z
gerbil: z
.object({
origins: z.array(z.string()).optional(),
methods: z.array(z.string()).optional(),
allowed_headers: z.array(z.string()).optional(),
credentials: z.boolean().optional()
exit_node_name: z.string().optional(),
start_port: portSchema
.optional()
.default(51820)
.transform(stoi)
.pipe(portSchema),
base_endpoint: z
.string()
.optional()
.pipe(z.string())
.transform((url) => url.toLowerCase()),
use_subdomain: z.boolean().optional().default(false),
subnet_group: z.string().optional().default("100.89.137.0/20"),
block_size: z.number().positive().gt(0).optional().default(24),
site_block_size: z
.number()
.positive()
.gt(0)
.optional()
.default(30)
})
.optional()
.default({}),
orgs: z.object({
block_size: z.number().positive().gt(0),
subnet_group: z.string()
}),
rate_limits: z
.object({
global: z
.object({
window_minutes: z
.number()
.positive()
.gt(0)
.optional()
.default(1),
max_requests: z
.number()
.positive()
.gt(0)
.optional()
.default(500)
})
.optional()
.default({}),
auth: z
.object({
window_minutes: z
.number()
.positive()
.gt(0)
.optional()
.default(1),
max_requests: z
.number()
.positive()
.gt(0)
.optional()
.default(500)
})
.optional()
.default({})
})
.optional()
.default({}),
email: z
.object({
smtp_host: z.string().optional(),
smtp_port: portSchema.optional(),
smtp_user: z.string().optional(),
smtp_pass: z.string().optional(),
smtp_secure: z.boolean().optional(),
smtp_tls_reject_unauthorized: z.boolean().optional(),
no_reply: z.string().email().optional()
})
.optional(),
trust_proxy: z.number().int().gte(0).optional().default(1),
secret: z
.string()
flags: z
.object({
require_email_verification: z.boolean().optional(),
disable_signup_without_invite: z.boolean().optional(),
disable_user_create_org: z.boolean().optional(),
allow_raw_resources: z.boolean().optional(),
allow_base_domain_resources: z.boolean().optional(),
enable_integration_api: z.boolean().optional(),
enable_redis: z.boolean().optional(),
disable_local_sites: z.boolean().optional(),
disable_basic_wireguard_sites: z.boolean().optional(),
disable_config_managed_domains: z.boolean().optional(),
enable_clients: z.boolean().optional()
})
.optional()
.transform(getEnvOrYaml("SERVER_SECRET"))
.pipe(z.string().min(8))
}),
postgres: z
.object({
connection_string: z.string(),
replicas: z
.array(
z.object({
connection_string: z.string()
})
)
.optional()
})
.optional(),
traefik: z
.object({
http_entrypoint: z.string().optional().default("web"),
https_entrypoint: z.string().optional().default("websecure"),
additional_middlewares: z.array(z.string()).optional()
})
.optional()
.default({}),
gerbil: z
.object({
start_port: portSchema
.optional()
.default(51820)
.transform(stoi)
.pipe(portSchema),
base_endpoint: z
.string()
.optional()
.pipe(z.string())
.transform((url) => url.toLowerCase()),
use_subdomain: z.boolean().optional().default(false),
subnet_group: z.string().optional().default("100.89.137.0/20"),
block_size: z.number().positive().gt(0).optional().default(24),
site_block_size: z.number().positive().gt(0).optional().default(30)
})
.optional()
.default({}),
rate_limits: z
.object({
global: z
.object({
window_minutes: z
.number()
.positive()
.gt(0)
.optional()
.default(1),
max_requests: z
.number()
.positive()
.gt(0)
.optional()
.default(500)
})
.optional()
.default({}),
auth: z
.object({
window_minutes: z
.number()
.positive()
.gt(0)
.optional()
.default(1),
max_requests: z
.number()
.positive()
.gt(0)
.optional()
.default(500)
})
.optional()
.default({}),
})
.optional()
.default({}),
email: z
.object({
smtp_host: z.string().optional(),
smtp_port: portSchema.optional(),
smtp_user: z.string().optional(),
smtp_pass: z.string().optional(),
smtp_secure: z.boolean().optional(),
smtp_tls_reject_unauthorized: z.boolean().optional(),
no_reply: z.string().email().optional()
})
.optional(),
flags: z
.object({
require_email_verification: z.boolean().optional(),
disable_signup_without_invite: z.boolean().optional(),
disable_user_create_org: z.boolean().optional(),
allow_raw_resources: z.boolean().optional(),
allow_base_domain_resources: z.boolean().optional(),
allow_local_sites: z.boolean().optional(),
enable_integration_api: z.boolean().optional()
})
.optional()
});
})
.refine(
(data) => {
if (data.flags?.enable_redis) {
return data?.redis !== undefined;
}
return true;
},
{
message:
"If Redis is enabled, configuration details must be provided"
}
)
.refine(
(data) => {
const keys = Object.keys(data.domains || {});
if (data.flags?.disable_config_managed_domains) {
return true;
}
if (keys.length === 0) {
return false;
}
return true;
},
{
message: "At least one domain must be defined"
}
)
.refine(
(data) => {
if (build == "oss" && data.redis) {
return false;
}
if (build == "oss" && data.flags?.enable_redis) {
return false;
}
return true;
},
{
message: "Redis"
}
);
export function readConfigFile() {
const loadConfig = (configPath: string) => {

File diff suppressed because it is too large Load diff

View file

@ -13,9 +13,17 @@ export * from "./verifyAdmin";
export * from "./verifySetResourceUsers";
export * from "./verifyUserInRole";
export * from "./verifyAccessTokenAccess";
export * from "./requestTimeout";
export * from "./verifyClientAccess";
export * from "./verifyUserHasAction";
export * from "./verifyUserIsServerAdmin";
export * from "./verifyIsLoggedInUser";
export * from "./verifyIsLoggedInUser";
export * from "./verifyClientAccess";
export * from "./integration";
export * from "./verifyValidLicense";
export * from "./verifyUserHasAction";
export * from "./verifyApiKeyAccess";
export * from "./verifyDomainAccess";
export * from "./verifyClientsEnabled";
export * from "./verifyUserIsOrgOwner";

View file

@ -0,0 +1,35 @@
import { Request, Response, NextFunction } from 'express';
import logger from '@server/logger';
import createHttpError from 'http-errors';
import HttpCode from '@server/types/HttpCode';
export function requestTimeoutMiddleware(timeoutMs: number = 30000) {
return (req: Request, res: Response, next: NextFunction) => {
// Set a timeout for the request
const timeout = setTimeout(() => {
if (!res.headersSent) {
logger.error(`Request timeout: ${req.method} ${req.url} from ${req.ip}`);
return next(
createHttpError(
HttpCode.REQUEST_TIMEOUT,
'Request timeout - operation took too long to complete'
)
);
}
}, timeoutMs);
// Clear timeout when response finishes
res.on('finish', () => {
clearTimeout(timeout);
});
// Clear timeout when response closes
res.on('close', () => {
clearTimeout(timeout);
});
next();
};
}
export default requestTimeoutMiddleware;

View file

@ -0,0 +1,131 @@
import { Request, Response, NextFunction } from "express";
import { db } from "@server/db";
import { userOrgs, clients, roleClients, userClients } from "@server/db";
import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
export async function verifyClientAccess(
req: Request,
res: Response,
next: NextFunction
) {
const userId = req.user!.userId; // Assuming you have user information in the request
const clientId = parseInt(
req.params.clientId || req.body.clientId || req.query.clientId
);
if (!userId) {
return next(
createHttpError(HttpCode.UNAUTHORIZED, "User not authenticated")
);
}
if (isNaN(clientId)) {
return next(createHttpError(HttpCode.BAD_REQUEST, "Invalid client ID"));
}
try {
// Get the client
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Client with ID ${clientId} not found`
)
);
}
if (!client.orgId) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
`Client with ID ${clientId} does not have an organization ID`
)
);
}
if (!req.userOrg) {
// Get user's role ID in the organization
const userOrgRole = await db
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.userId, userId),
eq(userOrgs.orgId, client.orgId)
)
)
.limit(1);
req.userOrg = userOrgRole[0];
}
if (!req.userOrg) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"User does not have access to this organization"
)
);
}
const userOrgRoleId = req.userOrg.roleId;
req.userOrgRoleId = userOrgRoleId;
req.userOrgId = client.orgId;
// Check role-based site access first
const [roleClientAccess] = await db
.select()
.from(roleClients)
.where(
and(
eq(roleClients.clientId, clientId),
eq(roleClients.roleId, userOrgRoleId)
)
)
.limit(1);
if (roleClientAccess) {
// User has access to the site through their role
return next();
}
// If role doesn't have access, check user-specific site access
const [userClientAccess] = await db
.select()
.from(userClients)
.where(
and(
eq(userClients.userId, userId),
eq(userClients.clientId, clientId)
)
)
.limit(1);
if (userClientAccess) {
// User has direct access to the site
return next();
}
// If we reach here, the user doesn't have access to the site
return next(
createHttpError(
HttpCode.FORBIDDEN,
"User does not have access to this client"
)
);
} catch (error) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Error verifying site access"
)
);
}
}

View file

@ -0,0 +1,29 @@
import { Request, Response, NextFunction } from "express";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import config from "@server/lib/config";
export async function verifyClientsEnabled(
req: Request,
res: Response,
next: NextFunction
) {
try {
if (!config.getRawConfig().flags?.enable_clients) {
return next(
createHttpError(
HttpCode.NOT_IMPLEMENTED,
"Clients are not enabled on this server."
)
);
}
return next();
} catch (error) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to check if clients are enabled"
)
);
}
}

View file

@ -0,0 +1,93 @@
import { Request, Response, NextFunction } from "express";
import { db, domains, orgDomains } from "@server/db";
import { userOrgs, apiKeyOrg } from "@server/db";
import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
export async function verifyDomainAccess(
req: Request,
res: Response,
next: NextFunction
) {
try {
const userId = req.user!.userId;
const domainId =
req.params.domainId || req.body.apiKeyId || req.query.apiKeyId;
const orgId = req.params.orgId;
if (!userId) {
return next(
createHttpError(HttpCode.UNAUTHORIZED, "User not authenticated")
);
}
if (!orgId) {
return next(
createHttpError(HttpCode.BAD_REQUEST, "Invalid organization ID")
);
}
if (!domainId) {
return next(
createHttpError(HttpCode.BAD_REQUEST, "Invalid domain ID")
);
}
const [domain] = await db
.select()
.from(domains)
.innerJoin(orgDomains, eq(orgDomains.domainId, domains.domainId))
.where(
and(
eq(orgDomains.domainId, domainId),
eq(orgDomains.orgId, orgId)
)
)
.limit(1);
if (!domain.orgDomains) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Domain with ID ${domainId} not found`
)
);
}
if (!req.userOrg) {
const userOrgRole = await db
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.userId, userId),
eq(userOrgs.orgId, apiKeyOrg.orgId)
)
)
.limit(1);
req.userOrg = userOrgRole[0];
}
if (!req.userOrg) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"User does not have access to this organization"
)
);
}
const userOrgRoleId = req.userOrg.roleId;
req.userOrgRoleId = userOrgRoleId;
return next();
} catch (error) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Error verifying domain access"
)
);
}
}

View file

@ -14,5 +14,6 @@ export enum OpenAPITags {
AccessToken = "Access Token",
Idp = "Identity Provider",
Client = "Client",
ApiKey = "API Key"
ApiKey = "API Key",
Domain = "Domain"
}

View file

@ -112,7 +112,11 @@ export async function requestTotpSecret(
const hex = crypto.getRandomValues(new Uint8Array(20));
const secret = encodeHex(hex);
const uri = createTOTPKeyURI("Pangolin", user.email!, hex);
const uri = createTOTPKeyURI(
"Pangolin",
user.email!,
hex
);
await db
.update(users)

View file

@ -1,8 +1,7 @@
import { NextFunction, Request, Response } from "express";
import { db } from "@server/db";
import { db, users } from "@server/db";
import HttpCode from "@server/types/HttpCode";
import { z } from "zod";
import { users } from "@server/db";
import { fromError } from "zod-validation-error";
import createHttpError from "http-errors";
import response from "@server/lib/response";
@ -57,8 +56,6 @@ export async function signup(
const { email, password, inviteToken, inviteId } = parsedBody.data;
logger.debug("signup", { email, password, inviteToken, inviteId });
const passwordHash = await hashPassword(password);
const userId = generateId(15);
@ -143,15 +140,21 @@ export async function signup(
if (diff < 2) {
// If the user was created less than 2 hours ago, we don't want to create a new user
return response<SignUpResponse>(res, {
data: {
emailVerificationRequired: true
},
success: true,
error: false,
message: `A user with that email address already exists. We sent an email to ${email} with a verification code.`,
status: HttpCode.OK
});
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"A user with that email address already exists"
)
);
// return response<SignUpResponse>(res, {
// data: {
// emailVerificationRequired: true
// },
// success: true,
// error: false,
// message: `A user with that email address already exists. We sent an email to ${email} with a verification code.`,
// status: HttpCode.OK
// });
} else {
// If the user was created more than 2 hours ago, we want to delete the old user and create a new one
await db.delete(users).where(eq(users.userId, user.userId));

View file

@ -4,7 +4,7 @@ import { z } from "zod";
import { fromError } from "zod-validation-error";
import HttpCode from "@server/types/HttpCode";
import { response } from "@server/lib";
import { db } from "@server/db";
import { db, userOrgs } from "@server/db";
import { User, emailVerificationCodes, users } from "@server/db";
import { eq } from "drizzle-orm";
import { isWithinExpirationDate } from "oslo";

View file

@ -0,0 +1,252 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import {
roles,
Client,
clients,
roleClients,
userClients,
olms,
clientSites,
exitNodes,
orgs,
sites
} from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { eq, and } from "drizzle-orm";
import { fromError } from "zod-validation-error";
import moment from "moment";
import { hashPassword } from "@server/auth/password";
import { isValidCIDR, isValidIP } from "@server/lib/validators";
import { isIpInCidr } from "@server/lib/ip";
import { OpenAPITags, registry } from "@server/openApi";
const createClientParamsSchema = z
.object({
orgId: z.string()
})
.strict();
const createClientSchema = z
.object({
name: z.string().min(1).max(255),
siteIds: z.array(z.number().int().positive()),
olmId: z.string(),
secret: z.string(),
subnet: z.string(),
type: z.enum(["olm"])
})
.strict();
export type CreateClientBody = z.infer<typeof createClientSchema>;
export type CreateClientResponse = Client;
registry.registerPath({
method: "put",
path: "/org/{orgId}/client",
description: "Create a new client.",
tags: [OpenAPITags.Client, OpenAPITags.Org],
request: {
params: createClientParamsSchema,
body: {
content: {
"application/json": {
schema: createClientSchema
}
}
}
},
responses: {}
});
export async function createClient(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedBody = createClientSchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { name, type, siteIds, olmId, secret, subnet } = parsedBody.data;
const parsedParams = createClientParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { orgId } = parsedParams.data;
if (req.user && !req.userOrgRoleId) {
return next(
createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
);
}
if (!isValidIP(subnet)) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Invalid subnet format. Please provide a valid CIDR notation."
)
);
}
const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId));
if (!org) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Organization with ID ${orgId} not found`
)
);
}
if (!isIpInCidr(subnet, org.subnet)) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"IP is not in the CIDR range of the subnet."
)
);
}
const updatedSubnet = `${subnet}/${org.subnet.split("/")[1]}`; // we want the block size of the whole org
// make sure the subnet is unique
const subnetExistsClients = await db
.select()
.from(clients)
.where(eq(clients.subnet, updatedSubnet))
.limit(1);
if (subnetExistsClients.length > 0) {
return next(
createHttpError(
HttpCode.CONFLICT,
`Subnet ${subnet} already exists`
)
);
}
const subnetExistsSites = await db
.select()
.from(sites)
.where(eq(sites.address, updatedSubnet))
.limit(1);
if (subnetExistsSites.length > 0) {
return next(
createHttpError(
HttpCode.CONFLICT,
`Subnet ${subnet} already exists`
)
);
}
await db.transaction(async (trx) => {
// TODO: more intelligent way to pick the exit node
// make sure there is an exit node by counting the exit nodes table
const nodes = await db.select().from(exitNodes);
if (nodes.length === 0) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"No exit nodes available"
)
);
}
// get the first exit node
const exitNode = nodes[0];
const adminRole = await trx
.select()
.from(roles)
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
.limit(1);
if (adminRole.length === 0) {
trx.rollback();
return next(
createHttpError(HttpCode.NOT_FOUND, `Admin role not found`)
);
}
const [newClient] = await trx
.insert(clients)
.values({
exitNodeId: exitNode.exitNodeId,
orgId,
name,
subnet: updatedSubnet,
type
})
.returning();
await trx.insert(roleClients).values({
roleId: adminRole[0].roleId,
clientId: newClient.clientId
});
if (req.user && req.userOrgRoleId != adminRole[0].roleId) {
// make sure the user can access the site
trx.insert(userClients).values({
userId: req.user?.userId!,
clientId: newClient.clientId
});
}
// Create site to client associations
if (siteIds && siteIds.length > 0) {
await trx.insert(clientSites).values(
siteIds.map((siteId) => ({
clientId: newClient.clientId,
siteId
}))
);
}
const secretHash = await hashPassword(secret);
await trx.insert(olms).values({
olmId,
secretHash,
clientId: newClient.clientId,
dateCreated: moment().toISOString()
});
return response<CreateClientResponse>(res, {
data: newClient,
success: true,
error: false,
message: "Site created successfully",
status: HttpCode.CREATED
});
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View file

@ -0,0 +1,88 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { clients, clientSites } from "@server/db";
import { eq } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
const deleteClientSchema = z
.object({
clientId: z.string().transform(Number).pipe(z.number().int().positive())
})
.strict();
registry.registerPath({
method: "delete",
path: "/client/{clientId}",
description: "Delete a client by its client ID.",
tags: [OpenAPITags.Client],
request: {
params: deleteClientSchema
},
responses: {}
});
export async function deleteClient(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = deleteClientSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { clientId } = parsedParams.data;
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Client with ID ${clientId} not found`
)
);
}
await db.transaction(async (trx) => {
// Delete the client-site associations first
await trx
.delete(clientSites)
.where(eq(clientSites.clientId, clientId));
// Then delete the client itself
await trx
.delete(clients)
.where(eq(clients.clientId, clientId));
});
return response(res, {
data: null,
success: true,
error: false,
message: "Client deleted successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View file

@ -0,0 +1,101 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { clients, clientSites } from "@server/db";
import { eq, and } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import stoi from "@server/lib/stoi";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
const getClientSchema = z
.object({
clientId: z.string().transform(stoi).pipe(z.number().int().positive()),
orgId: z.string().optional()
})
.strict();
async function query(clientId: number) {
// Get the client
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
return null;
}
// Get the siteIds associated with this client
const sites = await db
.select({ siteId: clientSites.siteId })
.from(clientSites)
.where(eq(clientSites.clientId, clientId));
// Add the siteIds to the client object
return {
...client,
siteIds: sites.map(site => site.siteId)
};
}
export type GetClientResponse = NonNullable<Awaited<ReturnType<typeof query>>>;
registry.registerPath({
method: "get",
path: "/org/{orgId}/client/{clientId}",
description: "Get a client by its client ID.",
tags: [OpenAPITags.Client, OpenAPITags.Org],
request: {
params: getClientSchema
},
responses: {}
});
export async function getClient(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = getClientSchema.safeParse(req.params);
if (!parsedParams.success) {
logger.error(
`Error parsing params: ${fromError(parsedParams.error).toString()}`
);
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { clientId } = parsedParams.data;
const client = await query(clientId);
if (!client) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Client not found")
);
}
return response<GetClientResponse>(res, {
data: client,
success: true,
error: false,
message: "Client retrieved successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View file

@ -0,0 +1,6 @@
export * from "./pickClientDefaults";
export * from "./createClient";
export * from "./deleteClient";
export * from "./listClients";
export * from "./updateClient";
export * from "./getClient";

View file

@ -0,0 +1,229 @@
import { db } from "@server/db";
import {
clients,
orgs,
roleClients,
sites,
userClients,
clientSites
} from "@server/db";
import logger from "@server/logger";
import HttpCode from "@server/types/HttpCode";
import response from "@server/lib/response";
import { and, count, eq, inArray, or, sql } from "drizzle-orm";
import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
const listClientsParamsSchema = z
.object({
orgId: z.string()
})
.strict();
const listClientsSchema = z.object({
limit: z
.string()
.optional()
.default("1000")
.transform(Number)
.pipe(z.number().int().positive()),
offset: z
.string()
.optional()
.default("0")
.transform(Number)
.pipe(z.number().int().nonnegative())
});
function queryClients(orgId: string, accessibleClientIds: number[]) {
return db
.select({
clientId: clients.clientId,
orgId: clients.orgId,
name: clients.name,
pubKey: clients.pubKey,
subnet: clients.subnet,
megabytesIn: clients.megabytesIn,
megabytesOut: clients.megabytesOut,
orgName: orgs.name,
type: clients.type,
online: clients.online
})
.from(clients)
.leftJoin(orgs, eq(clients.orgId, orgs.orgId))
.where(
and(
inArray(clients.clientId, accessibleClientIds),
eq(clients.orgId, orgId)
)
);
}
async function getSiteAssociations(clientIds: number[]) {
if (clientIds.length === 0) return [];
return db
.select({
clientId: clientSites.clientId,
siteId: clientSites.siteId,
siteName: sites.name,
siteNiceId: sites.niceId
})
.from(clientSites)
.leftJoin(sites, eq(clientSites.siteId, sites.siteId))
.where(inArray(clientSites.clientId, clientIds));
}
export type ListClientsResponse = {
clients: Array<Awaited<ReturnType<typeof queryClients>>[0] & { sites: Array<{
siteId: number;
siteName: string | null;
siteNiceId: string | null;
}> }>;
pagination: { total: number; limit: number; offset: number };
};
registry.registerPath({
method: "get",
path: "/org/{orgId}/clients",
description: "List all clients for an organization.",
tags: [OpenAPITags.Client, OpenAPITags.Org],
request: {
query: listClientsSchema,
params: listClientsParamsSchema
},
responses: {}
});
export async function listClients(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedQuery = listClientsSchema.safeParse(req.query);
if (!parsedQuery.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedQuery.error)
)
);
}
const { limit, offset } = parsedQuery.data;
const parsedParams = listClientsParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error)
)
);
}
const { orgId } = parsedParams.data;
if (req.user && orgId && orgId !== req.userOrgId) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"User does not have access to this organization"
)
);
}
let accessibleClients;
if (req.user) {
accessibleClients = await db
.select({
clientId: sql<number>`COALESCE(${userClients.clientId}, ${roleClients.clientId})`
})
.from(userClients)
.fullJoin(
roleClients,
eq(userClients.clientId, roleClients.clientId)
)
.where(
or(
eq(userClients.userId, req.user!.userId),
eq(roleClients.roleId, req.userOrgRoleId!)
)
);
} else {
accessibleClients = await db
.select({ clientId: clients.clientId })
.from(clients)
.where(eq(clients.orgId, orgId));
}
const accessibleClientIds = accessibleClients.map(
(client) => client.clientId
);
const baseQuery = queryClients(orgId, accessibleClientIds);
// Get client count
const countQuery = db
.select({ count: count() })
.from(clients)
.where(
and(
inArray(clients.clientId, accessibleClientIds),
eq(clients.orgId, orgId)
)
);
const clientsList = await baseQuery.limit(limit).offset(offset);
const totalCountResult = await countQuery;
const totalCount = totalCountResult[0].count;
// Get associated sites for all clients
const clientIds = clientsList.map(client => client.clientId);
const siteAssociations = await getSiteAssociations(clientIds);
// Group site associations by client ID
const sitesByClient = siteAssociations.reduce((acc, association) => {
if (!acc[association.clientId]) {
acc[association.clientId] = [];
}
acc[association.clientId].push({
siteId: association.siteId,
siteName: association.siteName,
siteNiceId: association.siteNiceId
});
return acc;
}, {} as Record<number, Array<{
siteId: number;
siteName: string | null;
siteNiceId: string | null;
}>>);
// Merge clients with their site associations
const clientsWithSites = clientsList.map(client => ({
...client,
sites: sitesByClient[client.clientId] || []
}));
return response<ListClientsResponse>(res, {
data: {
clients: clientsWithSites,
pagination: {
total: totalCount,
limit,
offset
}
},
success: true,
error: false,
message: "Clients retrieved successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View file

@ -0,0 +1,85 @@
import { Request, Response, NextFunction } from "express";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { generateId } from "@server/auth/sessions/app";
import { getNextAvailableClientSubnet } from "@server/lib/ip";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
export type PickClientDefaultsResponse = {
olmId: string;
olmSecret: string;
subnet: string;
};
const pickClientDefaultsSchema = z
.object({
orgId: z.string()
})
.strict();
registry.registerPath({
method: "get",
path: "/site/{siteId}/pick-client-defaults",
description: "Return pre-requisite data for creating a client.",
tags: [OpenAPITags.Client, OpenAPITags.Site],
request: {
params: pickClientDefaultsSchema
},
responses: {}
});
export async function pickClientDefaults(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = pickClientDefaultsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { orgId } = parsedParams.data;
const olmId = generateId(15);
const secret = generateId(48);
const newSubnet = await getNextAvailableClientSubnet(orgId);
if (!newSubnet) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"No available subnet found"
)
);
}
const subnet = newSubnet.split("/")[0];
return response<PickClientDefaultsResponse>(res, {
data: {
olmId: olmId,
olmSecret: secret,
subnet: subnet
},
success: true,
error: false,
message: "Organization retrieved successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View file

@ -0,0 +1,225 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { clients, clientSites } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { eq, and } from "drizzle-orm";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import {
addPeer as newtAddPeer,
deletePeer as newtDeletePeer
} from "../newt/peers";
import {
addPeer as olmAddPeer,
deletePeer as olmDeletePeer
} from "../olm/peers";
const updateClientParamsSchema = z
.object({
clientId: z.string().transform(Number).pipe(z.number().int().positive())
})
.strict();
const updateClientSchema = z
.object({
name: z.string().min(1).max(255).optional(),
siteIds: z
.array(z.string().transform(Number).pipe(z.number()))
.optional()
})
.strict();
export type UpdateClientBody = z.infer<typeof updateClientSchema>;
registry.registerPath({
method: "post",
path: "/client/{clientId}",
description: "Update a client by its client ID.",
tags: [OpenAPITags.Client],
request: {
params: updateClientParamsSchema,
body: {
content: {
"application/json": {
schema: updateClientSchema
}
}
}
},
responses: {}
});
export async function updateClient(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedBody = updateClientSchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { name, siteIds } = parsedBody.data;
const parsedParams = updateClientParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { clientId } = parsedParams.data;
// Fetch the client to make sure it exists and the user has access to it
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Client with ID ${clientId} not found`
)
);
}
if (siteIds) {
let sitesAdded = [];
let sitesRemoved = [];
// Fetch existing site associations
const existingSites = await db
.select({ siteId: clientSites.siteId })
.from(clientSites)
.where(eq(clientSites.clientId, clientId));
const existingSiteIds = existingSites.map((site) => site.siteId);
// Determine which sites were added and removed
sitesAdded = siteIds.filter(
(siteId) => !existingSiteIds.includes(siteId)
);
sitesRemoved = existingSiteIds.filter(
(siteId) => !siteIds.includes(siteId)
);
logger.info(
`Adding ${sitesAdded.length} new sites to client ${client.clientId}`
);
for (const siteId of sitesAdded) {
if (!client.subnet || !client.pubKey || !client.endpoint) {
logger.debug("Client subnet, pubKey or endpoint is not set");
continue;
}
const site = await newtAddPeer(siteId, {
publicKey: client.pubKey,
allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client
endpoint: client.endpoint
});
if (!site) {
logger.debug("Failed to add peer to newt - missing site");
continue;
}
if (!site.endpoint || !site.publicKey) {
logger.debug("Site endpoint or publicKey is not set");
continue;
}
await olmAddPeer(client.clientId, {
siteId: siteId,
endpoint: site.endpoint,
publicKey: site.publicKey,
serverIP: site.address,
serverPort: site.listenPort
});
}
logger.info(
`Removing ${sitesRemoved.length} sites from client ${client.clientId}`
);
for (const siteId of sitesRemoved) {
if (!client.pubKey) {
logger.debug("Client pubKey is not set");
continue;
}
const site = await newtDeletePeer(siteId, client.pubKey);
if (!site) {
logger.debug(
"Failed to delete peer from newt - missing site"
);
continue;
}
if (!site.endpoint || !site.publicKey) {
logger.debug("Site endpoint or publicKey is not set");
continue;
}
await olmDeletePeer(client.clientId, site.siteId, site.publicKey);
}
}
await db.transaction(async (trx) => {
// Update client name if provided
if (name) {
await trx
.update(clients)
.set({ name })
.where(eq(clients.clientId, clientId));
}
// Update site associations if provided
if (siteIds) {
// Delete existing site associations
await trx
.delete(clientSites)
.where(eq(clientSites.clientId, clientId));
// Create new site associations
if (siteIds.length > 0) {
await trx.insert(clientSites).values(
siteIds.map((siteId) => ({
clientId,
siteId
}))
);
}
}
// Fetch the updated client
const [updatedClient] = await trx
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
return response(res, {
data: updatedClient,
success: true,
error: false,
message: "Client updated successfully",
status: HttpCode.OK
});
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View file

@ -0,0 +1,287 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, Domain, domains, OrgDomains, orgDomains } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { subdomainSchema } from "@server/lib/schemas";
import { generateId } from "@server/auth/sessions/app";
import { eq, and } from "drizzle-orm";
import { isValidDomain } from "@server/lib/validators";
import { build } from "@server/build";
const paramsSchema = z
.object({
orgId: z.string()
})
.strict();
const bodySchema = z
.object({
type: z.enum(["ns", "cname", "wildcard"]),
baseDomain: subdomainSchema
})
.strict();
export type CreateDomainResponse = {
domainId: string;
nsRecords?: string[];
cnameRecords?: { baseDomain: string; value: string }[];
txtRecords?: { baseDomain: string; value: string }[];
};
// Helper to check if a domain is a subdomain or equal to another domain
function isSubdomainOrEqual(a: string, b: string): boolean {
const aParts = a.toLowerCase().split(".");
const bParts = b.toLowerCase().split(".");
if (aParts.length < bParts.length) return false;
return aParts.slice(-bParts.length).join(".") === bParts.join(".");
}
export async function createOrgDomain(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedBody = bodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const parsedParams = paramsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { orgId } = parsedParams.data;
const { type, baseDomain } = parsedBody.data;
if (build == "oss") {
if (type !== "wildcard") {
return next(
createHttpError(
HttpCode.NOT_IMPLEMENTED,
"Creating NS or CNAME records is not supported"
)
);
}
} else if (build == "enterprise" || build == "saas") {
if (type !== "ns" && type !== "cname") {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Invalid domain type. Only NS, CNAME are allowed."
)
);
}
}
// Validate organization exists
if (!isValidDomain(baseDomain)) {
return next(
createHttpError(HttpCode.BAD_REQUEST, "Invalid domain format")
);
}
let numOrgDomains: OrgDomains[] | undefined;
let cnameRecords: CreateDomainResponse["cnameRecords"];
let txtRecords: CreateDomainResponse["txtRecords"];
let nsRecords: CreateDomainResponse["nsRecords"];
let returned: Domain | undefined;
await db.transaction(async (trx) => {
const [existing] = await trx
.select()
.from(domains)
.where(
and(
eq(domains.baseDomain, baseDomain),
eq(domains.type, type)
)
)
.leftJoin(
orgDomains,
eq(orgDomains.domainId, domains.domainId)
);
if (existing) {
const {
domains: existingDomain,
orgDomains: existingOrgDomain
} = existing;
// user alrady added domain to this account
// always reject
if (existingOrgDomain?.orgId === orgId) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Domain is already added to this org"
)
);
}
// domain already exists elsewhere
// check if it's already fully verified
if (existingDomain.verified) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Domain is already verified to an org"
)
);
}
}
// --- Domain overlap logic ---
// Only consider existing verified domains
const verifiedDomains = await trx
.select()
.from(domains)
.where(eq(domains.verified, true));
if (type == "cname") {
// Block if a verified CNAME exists at the same name
const cnameExists = verifiedDomains.some(
(d) => d.type === "cname" && d.baseDomain === baseDomain
);
if (cnameExists) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`A CNAME record already exists for ${baseDomain}. Only one CNAME record is allowed per domain.`
)
);
}
// Block if a verified NS exists at or below (same or subdomain)
const nsAtOrBelow = verifiedDomains.some(
(d) =>
d.type === "ns" &&
(isSubdomainOrEqual(baseDomain, d.baseDomain) ||
baseDomain === d.baseDomain)
);
if (nsAtOrBelow) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`A nameserver (NS) record exists at or below ${baseDomain}. You cannot create a CNAME record here.`
)
);
}
} else if (type == "ns") {
// Block if a verified NS exists at or below (same or subdomain)
const nsAtOrBelow = verifiedDomains.some(
(d) =>
d.type === "ns" &&
(isSubdomainOrEqual(baseDomain, d.baseDomain) ||
baseDomain === d.baseDomain)
);
if (nsAtOrBelow) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`A nameserver (NS) record already exists at or below ${baseDomain}. You cannot create another NS record here.`
)
);
}
} else if (type == "wildcard") {
// TODO: Figure out how to handle wildcards
}
const domainId = generateId(15);
const [insertedDomain] = await trx
.insert(domains)
.values({
domainId,
baseDomain,
type,
verified: build == "oss" ? true : false
})
.returning();
returned = insertedDomain;
// add domain to account
await trx
.insert(orgDomains)
.values({
orgId,
domainId
})
.returning();
// TODO: This needs to be cross region and not hardcoded
if (type === "ns") {
nsRecords = ["ns-east.fossorial.io", "ns-west.fossorial.io"];
} else if (type === "cname") {
cnameRecords = [
{
value: `${domainId}.cname.fossorial.io`,
baseDomain: baseDomain
},
{
value: `_acme-challenge.${domainId}.cname.fossorial.io`,
baseDomain: `_acme-challenge.${baseDomain}`
}
];
} else if (type === "wildcard") {
cnameRecords = [
{
value: `Server IP Address`,
baseDomain: `*.${baseDomain}`
},
{
value: `Server IP Address`,
baseDomain: `${baseDomain}`
}
];
}
numOrgDomains = await trx
.select()
.from(orgDomains)
.where(eq(orgDomains.orgId, orgId));
});
if (!returned) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to create domain"
)
);
}
return response<CreateDomainResponse>(res, {
data: {
domainId: returned.domainId,
cnameRecords,
txtRecords,
nsRecords
},
success: true,
error: false,
message: "Domain created successfully",
status: HttpCode.CREATED
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View file

@ -0,0 +1,72 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, domains, OrgDomains, orgDomains } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { and, eq } from "drizzle-orm";
const paramsSchema = z
.object({
domainId: z.string(),
orgId: z.string()
})
.strict();
export type DeleteAccountDomainResponse = {
success: boolean;
};
export async function deleteAccountDomain(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsed = paramsSchema.safeParse(req.params);
if (!parsed.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsed.error).toString()
)
);
}
const { domainId, orgId } = parsed.data;
let numOrgDomains: OrgDomains[] | undefined;
await db.transaction(async (trx) => {
await trx
.delete(orgDomains)
.where(
and(
eq(orgDomains.orgId, orgId),
eq(orgDomains.domainId, domainId)
)
);
await trx.delete(domains).where(eq(domains.domainId, domainId));
numOrgDomains = await trx
.select()
.from(orgDomains)
.where(eq(orgDomains.orgId, orgId));
});
return response<DeleteAccountDomainResponse>(res, {
data: { success: true },
success: true,
error: false,
message: "Domain deleted from account successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View file

@ -1 +1,4 @@
export * from "./listDomains";
export * from "./createOrgDomain";
export * from "./deleteOrgDomain";
export * from "./restartOrgDomain";

View file

@ -37,7 +37,11 @@ async function queryDomains(orgId: string, limit: number, offset: number) {
const res = await db
.select({
domainId: domains.domainId,
baseDomain: domains.baseDomain
baseDomain: domains.baseDomain,
verified: domains.verified,
type: domains.type,
failed: domains.failed,
tries: domains.tries,
})
.from(orgDomains)
.where(eq(orgDomains.orgId, orgId))
@ -112,7 +116,7 @@ export async function listDomains(
},
success: true,
error: false,
message: "Users retrieved successfully",
message: "Domains retrieved successfully",
status: HttpCode.OK
});
} catch (error) {

View file

@ -0,0 +1,57 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, domains } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { and, eq } from "drizzle-orm";
const paramsSchema = z
.object({
domainId: z.string(),
orgId: z.string()
})
.strict();
export type RestartOrgDomainResponse = {
success: boolean;
};
export async function restartOrgDomain(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsed = paramsSchema.safeParse(req.params);
if (!parsed.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsed.error).toString()
)
);
}
const { domainId, orgId } = parsed.data;
await db
.update(domains)
.set({ failed: false, tries: 0 })
.where(and(eq(domains.domainId, domainId)));
return response<RestartOrgDomainResponse>(res, {
data: { success: true },
success: true,
error: false,
message: "Domain restarted successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View file

@ -8,6 +8,7 @@ import * as target from "./target";
import * as user from "./user";
import * as auth from "./auth";
import * as role from "./role";
import * as client from "./client";
import * as supporterKey from "./supporterKey";
import * as accessToken from "./accessToken";
import * as idp from "./idp";
@ -28,14 +29,20 @@ import {
getUserOrgs,
verifyUserIsServerAdmin,
verifyIsLoggedInUser,
verifyApiKeyAccess
verifyClientAccess,
verifyApiKeyAccess,
verifyDomainAccess,
verifyClientsEnabled,
verifyUserHasAction,
verifyUserIsOrgOwner
} from "@server/middlewares";
import { verifyUserHasAction } from "../middlewares/verifyUserHasAction";
import { createStore } from "@server/lib/rateLimitStore";
import { ActionsEnum } from "@server/auth/actions";
import { verifyUserIsOrgOwner } from "../middlewares/verifyUserIsOrgOwner";
import { createNewt, getToken } from "./newt";
import { createNewt, getNewtToken } from "./newt";
import { getOlmToken } from "./olm";
import rateLimit from "express-rate-limit";
import createHttpError from "http-errors";
import { build } from "@server/build";
// Root routes
export const unauthenticated = Router();
@ -48,8 +55,11 @@ unauthenticated.get("/", (_, res) => {
export const authenticated = Router();
authenticated.use(verifySessionUserMiddleware);
authenticated.get("/pick-org-defaults", org.pickOrgDefaults);
authenticated.get("/org/checkId", org.checkId);
authenticated.put("/org", getUserOrgs, org.createOrg);
if (build === "oss" || build === "enterprise") {
authenticated.put("/org", getUserOrgs, org.createOrg);
}
authenticated.get("/orgs", verifyUserIsServerAdmin, org.listOrgs);
authenticated.get("/user/:userId/orgs", verifyIsLoggedInUser, org.listUserOrgs);
@ -104,6 +114,55 @@ authenticated.get(
verifyUserHasAction(ActionsEnum.getSite),
site.getSite
);
authenticated.get(
"/org/:orgId/pick-client-defaults",
verifyClientsEnabled,
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.createClient),
client.pickClientDefaults
);
authenticated.get(
"/org/:orgId/clients",
verifyClientsEnabled,
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.listClients),
client.listClients
);
authenticated.get(
"/org/:orgId/client/:clientId",
verifyClientsEnabled,
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.getClient),
client.getClient
);
authenticated.put(
"/org/:orgId/client",
verifyClientsEnabled,
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.createClient),
client.createClient
);
authenticated.delete(
"/client/:clientId",
verifyClientsEnabled,
verifyClientAccess,
verifyUserHasAction(ActionsEnum.deleteClient),
client.deleteClient
);
authenticated.post(
"/client/:clientId",
verifyClientsEnabled,
verifyClientAccess, // this will check if the user has access to the client
verifyUserHasAction(ActionsEnum.updateClient), // this will check if the user has permission to update the client
client.updateClient
);
// authenticated.get(
// "/site/:siteId/roles",
// verifySiteAccess,
@ -698,6 +757,29 @@ authenticated.get(
apiKeys.getApiKey
);
authenticated.put(
`/org/:orgId/domain`,
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.createOrgDomain),
domain.createOrgDomain
);
authenticated.post(
`/org/:orgId/domain/:domainId/restart`,
verifyOrgAccess,
verifyDomainAccess,
verifyUserHasAction(ActionsEnum.restartOrgDomain),
domain.restartOrgDomain
);
authenticated.delete(
`/org/:orgId/domain/:domainId`,
verifyOrgAccess,
verifyDomainAccess,
verifyUserHasAction(ActionsEnum.deleteOrgDomain),
domain.deleteAccountDomain
);
// Auth routes
export const authRouter = Router();
unauthenticated.use("/auth", authRouter);
@ -751,7 +833,20 @@ authRouter.post(
return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message));
}
}),
getToken
getNewtToken
);
authRouter.post(
"/olm/get-token",
rateLimit({
windowMs: 15 * 60 * 1000,
max: 900,
keyGenerator: (req) => `newtGetToken:${req.body.newtId}`,
handler: (req, res, next) => {
const message = `You can only request an Olm token ${900} times every ${15} minutes. Please try again later.`;
return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message));
}
}),
getOlmToken
);
authRouter.post(
@ -836,7 +931,8 @@ authRouter.post(
handler: (req, res, next) => {
const message = `You can only request an email verification code ${15} times every ${15} minutes. Please try again later.`;
return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message));
}
},
store: createStore()
}),
auth.requestEmailVerificationCode
);
@ -856,7 +952,8 @@ authRouter.post(
handler: (req, res, next) => {
const message = `You can only request a password reset ${15} times every ${15} minutes. Please try again later.`;
return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message));
}
},
store: createStore()
}),
auth.requestPasswordReset
);
@ -914,7 +1011,8 @@ authRouter.post(
handler: (req, res, next) => {
const message = `You can only request an email OTP ${15} times every ${15} minutes. Please try again later.`;
return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message));
}
},
store: createStore()
}),
resource.authWithWhitelist
);

View file

@ -0,0 +1,160 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { clients, exitNodes, newts, olms, Site, sites, clientSites } from "@server/db";
import { db } from "@server/db";
import { eq } from "drizzle-orm";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
// Define Zod schema for request validation
const getAllRelaysSchema = z.object({
publicKey: z.string().optional(),
});
// Type for peer destination
interface PeerDestination {
destinationIP: string;
destinationPort: number;
}
// Updated mappings type to support multiple destinations per endpoint
interface ProxyMapping {
destinations: PeerDestination[];
}
export async function getAllRelays(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
// Validate request parameters
const parsedParams = getAllRelaysSchema.safeParse(req.body);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { publicKey } = parsedParams.data;
if (!publicKey) {
return next(createHttpError(HttpCode.BAD_REQUEST, 'publicKey is required'));
}
// Fetch exit node
let [exitNode] = await db.select().from(exitNodes).where(eq(exitNodes.publicKey, publicKey));
if (!exitNode) {
return next(createHttpError(HttpCode.NOT_FOUND, "Exit node not found"));
}
// Fetch sites for this exit node
const sitesRes = await db.select().from(sites).where(eq(sites.exitNodeId, exitNode.exitNodeId));
if (sitesRes.length === 0) {
return res.status(HttpCode.OK).send({
mappings: {}
});
}
// Initialize mappings object for multi-peer support
let mappings: { [key: string]: ProxyMapping } = {};
// Process each site
for (const site of sitesRes) {
if (!site.endpoint || !site.subnet || !site.listenPort) {
continue;
}
// Find all clients associated with this site through clientSites
const clientSitesRes = await db
.select()
.from(clientSites)
.where(eq(clientSites.siteId, site.siteId));
for (const clientSite of clientSitesRes) {
// Get client information
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientSite.clientId));
if (!client || !client.endpoint) {
continue;
}
// Add this site as a destination for the client
if (!mappings[client.endpoint]) {
mappings[client.endpoint] = { destinations: [] };
}
// Add site as a destination for this client
const destination: PeerDestination = {
destinationIP: site.subnet.split("/")[0],
destinationPort: site.listenPort
};
// Check if this destination is already in the array to avoid duplicates
const isDuplicate = mappings[client.endpoint].destinations.some(
dest => dest.destinationIP === destination.destinationIP &&
dest.destinationPort === destination.destinationPort
);
if (!isDuplicate) {
mappings[client.endpoint].destinations.push(destination);
}
}
// Also handle site-to-site communication (all sites in the same org)
if (site.orgId) {
const orgSites = await db
.select()
.from(sites)
.where(eq(sites.orgId, site.orgId));
for (const peer of orgSites) {
// Skip self
if (peer.siteId === site.siteId || !peer.endpoint || !peer.subnet || !peer.listenPort) {
continue;
}
// Add peer site as a destination for this site
if (!mappings[site.endpoint]) {
mappings[site.endpoint] = { destinations: [] };
}
const destination: PeerDestination = {
destinationIP: peer.subnet.split("/")[0],
destinationPort: peer.listenPort
};
// Check for duplicates
const isDuplicate = mappings[site.endpoint].destinations.some(
dest => dest.destinationIP === destination.destinationIP &&
dest.destinationPort === destination.destinationPort
);
if (!isDuplicate) {
mappings[site.endpoint].destinations.push(destination);
}
}
}
}
logger.debug(`Returning mappings for ${Object.keys(mappings).length} endpoints`);
return res.status(HttpCode.OK).send({ mappings });
} catch (error) {
logger.error(error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"An error occurred..."
)
);
}
}

View file

@ -53,7 +53,7 @@ export async function getConfig(
}
// Fetch exit node
let exitNodeQuery = await db
const exitNodeQuery = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.publicKey, publicKey));
@ -68,6 +68,10 @@ export async function getConfig(
subEndpoint = await getUniqueExitNodeEndpointName();
}
const exitNodeName =
config.getRawConfig().gerbil.exit_node_name ||
`Exit Node ${publicKey.slice(0, 8)}`;
// create a new exit node
exitNode = await db
.insert(exitNodes)
@ -77,7 +81,7 @@ export async function getConfig(
address,
listenPort,
reachableAt,
name: `Exit Node ${publicKey.slice(0, 8)}`
name: exitNodeName
})
.returning()
.execute();

View file

@ -1,2 +1,4 @@
export * from "./getConfig";
export * from "./receiveBandwidth";
export * from "./updateHolePunch";
export * from "./getAllRelays";

View file

@ -1,12 +1,15 @@
import { Request, Response, NextFunction } from "express";
import { eq } from "drizzle-orm";
import { sites, } from "@server/db";
import { eq, and, lt, inArray, sql } from "drizzle-orm";
import { sites } from "@server/db";
import { db } from "@server/db";
import logger from "@server/logger";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import response from "@server/lib/response";
// Track sites that are already offline to avoid unnecessary queries
const offlineSites = new Set<string>();
interface PeerBandwidth {
publicKey: string;
bytesIn: number;
@ -25,47 +28,101 @@ export const receiveBandwidth = async (
throw new Error("Invalid bandwidth data");
}
const currentTime = new Date();
const oneMinuteAgo = new Date(currentTime.getTime() - 60000); // 1 minute ago
logger.debug(`Received data: ${JSON.stringify(bandwidthData)}`);
await db.transaction(async (trx) => {
for (const peer of bandwidthData) {
const { publicKey, bytesIn, bytesOut } = peer;
// First, handle sites that are actively reporting bandwidth
const activePeers = bandwidthData.filter(peer => peer.bytesIn > 0); // Bytesout will have data as it tries to send keep alive messages
const [site] = await trx
.select()
.from(sites)
.where(eq(sites.pubKey, publicKey))
.limit(1);
if (activePeers.length > 0) {
// Remove any active peers from offline tracking since they're sending data
activePeers.forEach(peer => offlineSites.delete(peer.publicKey));
if (!site) {
logger.warn(`Site not found for public key: ${publicKey}`);
continue;
}
let online = site.online;
// Aggregate usage data by organization
const orgUsageMap = new Map<string, number>();
const orgUptimeMap = new Map<string, number>();
// if the bandwidth for the site is > 0 then set it to online. if it has been less than 0 (no update) for 5 minutes then set it to offline
if (bytesIn > 0 || bytesOut > 0) {
online = true;
} else if (site.lastBandwidthUpdate) {
const lastBandwidthUpdate = new Date(
site.lastBandwidthUpdate
);
const currentTime = new Date();
const diff =
currentTime.getTime() - lastBandwidthUpdate.getTime();
if (diff < 300000) {
online = false;
// Update all active sites with bandwidth data and get the site data in one operation
const updatedSites = [];
for (const peer of activePeers) {
const updatedSite = await trx
.update(sites)
.set({
megabytesOut: sql`${sites.megabytesOut} + ${peer.bytesIn}`,
megabytesIn: sql`${sites.megabytesIn} + ${peer.bytesOut}`,
lastBandwidthUpdate: currentTime.toISOString(),
online: true
})
.where(eq(sites.pubKey, peer.publicKey))
.returning({
online: sites.online,
orgId: sites.orgId,
siteId: sites.siteId,
lastBandwidthUpdate: sites.lastBandwidthUpdate,
});
if (updatedSite.length > 0) {
updatedSites.push({ ...updatedSite[0], peer });
}
}
// Update the site's bandwidth usage
await trx
.update(sites)
.set({
megabytesOut: (site.megabytesOut || 0) + bytesIn,
megabytesIn: (site.megabytesIn || 0) + bytesOut,
lastBandwidthUpdate: new Date().toISOString(),
online
})
.where(eq(sites.siteId, site.siteId));
// Calculate org usage aggregations using the updated site data
for (const { peer, ...site } of updatedSites) {
// Aggregate bandwidth usage for the org
const totalBandwidth = peer.bytesIn + peer.bytesOut;
const currentOrgUsage = orgUsageMap.get(site.orgId) || 0;
orgUsageMap.set(site.orgId, currentOrgUsage + totalBandwidth);
// Add 10 seconds of uptime for each active site
const currentOrgUptime = orgUptimeMap.get(site.orgId) || 0;
orgUptimeMap.set(site.orgId, currentOrgUptime + 10 / 60); // Store in minutes and jut add 10 seconds
}
}
// Handle sites that reported zero bandwidth but need online status updated
const zeroBandwidthPeers = bandwidthData.filter(peer =>
peer.bytesIn === 0 && !offlineSites.has(peer.publicKey) // Bytesout will have data as it tries to send keep alive messages
);
if (zeroBandwidthPeers.length > 0) {
const zeroBandwidthSites = await trx
.select()
.from(sites)
.where(inArray(sites.pubKey, zeroBandwidthPeers.map(p => p.publicKey)));
for (const site of zeroBandwidthSites) {
let newOnlineStatus = site.online;
// Check if site should go offline based on last bandwidth update WITH DATA
if (site.lastBandwidthUpdate) {
const lastUpdateWithData = new Date(site.lastBandwidthUpdate);
if (lastUpdateWithData < oneMinuteAgo) {
newOnlineStatus = false;
}
} else {
// No previous data update recorded, set to offline
newOnlineStatus = false;
}
// Always update lastBandwidthUpdate to show this instance is receiving reports
// Only update online status if it changed
if (site.online !== newOnlineStatus) {
await trx
.update(sites)
.set({
online: newOnlineStatus
})
.where(eq(sites.siteId, site.siteId));
// If site went offline, add it to our tracking set
if (!newOnlineStatus && site.pubKey) {
offlineSites.add(site.pubKey);
}
}
}
}
});
@ -73,7 +130,7 @@ export const receiveBandwidth = async (
data: {},
success: true,
error: false,
message: "Organization retrieved successfully",
message: "Bandwidth data updated successfully",
status: HttpCode.OK
});
} catch (error) {

View file

@ -0,0 +1,242 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { clients, newts, olms, Site, sites, clientSites } from "@server/db";
import { db } from "@server/db";
import { eq } from "drizzle-orm";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { validateNewtSessionToken } from "@server/auth/sessions/newt";
import { validateOlmSessionToken } from "@server/auth/sessions/olm";
// Define Zod schema for request validation
const updateHolePunchSchema = z.object({
olmId: z.string().optional(),
newtId: z.string().optional(),
token: z.string(),
ip: z.string(),
port: z.number(),
timestamp: z.number()
});
// New response type with multi-peer destination support
interface PeerDestination {
destinationIP: string;
destinationPort: number;
}
export async function updateHolePunch(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
// Validate request parameters
const parsedParams = updateHolePunchSchema.safeParse(req.body);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { olmId, newtId, ip, port, timestamp, token } = parsedParams.data;
let currentSiteId: number | undefined;
let destinations: PeerDestination[] = [];
if (olmId) {
logger.debug(`Got hole punch with ip: ${ip}, port: ${port} for olmId: ${olmId}`);
const { session, olm: olmSession } =
await validateOlmSessionToken(token);
if (!session || !olmSession) {
return next(
createHttpError(HttpCode.UNAUTHORIZED, "Unauthorized")
);
}
if (olmId !== olmSession.olmId) {
logger.warn(`Olm ID mismatch: ${olmId} !== ${olmSession.olmId}`);
return next(
createHttpError(HttpCode.UNAUTHORIZED, "Unauthorized")
);
}
const [olm] = await db
.select()
.from(olms)
.where(eq(olms.olmId, olmId));
if (!olm || !olm.clientId) {
logger.warn(`Olm not found: ${olmId}`);
return next(
createHttpError(HttpCode.NOT_FOUND, "Olm not found")
);
}
const [client] = await db
.update(clients)
.set({
endpoint: `${ip}:${port}`,
lastHolePunch: timestamp
})
.where(eq(clients.clientId, olm.clientId))
.returning();
if (!client) {
logger.warn(`Client not found for olm: ${olmId}`);
return next(
createHttpError(HttpCode.NOT_FOUND, "Client not found")
);
}
// Get all sites that this client is connected to
const clientSitePairs = await db
.select()
.from(clientSites)
.where(eq(clientSites.clientId, client.clientId));
if (clientSitePairs.length === 0) {
logger.warn(`No sites found for client: ${client.clientId}`);
return next(
createHttpError(HttpCode.NOT_FOUND, "No sites found for client")
);
}
// Get all sites details
const siteIds = clientSitePairs.map(pair => pair.siteId);
for (const siteId of siteIds) {
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId));
if (site && site.subnet && site.listenPort) {
destinations.push({
destinationIP: site.subnet.split("/")[0],
destinationPort: site.listenPort
});
}
}
} else if (newtId) {
const { session, newt: newtSession } =
await validateNewtSessionToken(token);
if (!session || !newtSession) {
return next(
createHttpError(HttpCode.UNAUTHORIZED, "Unauthorized")
);
}
if (newtId !== newtSession.newtId) {
logger.warn(`Newt ID mismatch: ${newtId} !== ${newtSession.newtId}`);
return next(
createHttpError(HttpCode.UNAUTHORIZED, "Unauthorized")
);
}
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.newtId, newtId));
if (!newt || !newt.siteId) {
logger.warn(`Newt not found: ${newtId}`);
return next(
createHttpError(HttpCode.NOT_FOUND, "New not found")
);
}
currentSiteId = newt.siteId;
// Update the current site with the new endpoint
const [updatedSite] = await db
.update(sites)
.set({
endpoint: `${ip}:${port}`,
lastHolePunch: timestamp
})
.where(eq(sites.siteId, newt.siteId))
.returning();
if (!updatedSite || !updatedSite.subnet) {
logger.warn(`Site not found: ${newt.siteId}`);
return next(
createHttpError(HttpCode.NOT_FOUND, "Site not found")
);
}
// Find all clients that connect to this site
const sitesClientPairs = await db
.select()
.from(clientSites)
.where(eq(clientSites.siteId, newt.siteId));
// Get client details for each client
for (const pair of sitesClientPairs) {
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, pair.clientId));
if (client && client.endpoint) {
const [host, portStr] = client.endpoint.split(':');
if (host && portStr) {
destinations.push({
destinationIP: host,
destinationPort: parseInt(portStr, 10)
});
}
}
}
// If this is a newt/site, also add other sites in the same org
// if (updatedSite.orgId) {
// const orgSites = await db
// .select()
// .from(sites)
// .where(eq(sites.orgId, updatedSite.orgId));
// for (const site of orgSites) {
// // Don't add the current site to the destinations
// if (site.siteId !== currentSiteId && site.subnet && site.endpoint && site.listenPort) {
// const [host, portStr] = site.endpoint.split(':');
// if (host && portStr) {
// destinations.push({
// destinationIP: host,
// destinationPort: site.listenPort
// });
// }
// }
// }
// }
}
// if (destinations.length === 0) {
// logger.warn(
// `No peer destinations found for olmId: ${olmId} or newtId: ${newtId}`
// );
// return next(createHttpError(HttpCode.NOT_FOUND, "No peer destinations found"));
// }
// Return the new multi-peer structure
return res.status(HttpCode.OK).send({
destinations: destinations
});
} catch (error) {
logger.error(error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"An error occurred..."
)
);
}
}

View file

@ -11,6 +11,7 @@ import {
idpOidcConfig,
idpOrg,
orgs,
Role,
roles,
userOrgs,
users
@ -307,6 +308,8 @@ export async function validateOidcCallback(
let existingUserId = existingUser?.userId;
let orgUserCounts: { orgId: string; userCount: number }[] = [];
// sync the user with the orgs and roles
await db.transaction(async (trx) => {
let userId = existingUser?.userId;
@ -410,6 +413,19 @@ export async function validateOidcCallback(
}))
);
}
// Loop through all the orgs and get the total number of users from the userOrgs table
for (const org of currentUserOrgs) {
const userCount = await trx
.select()
.from(userOrgs)
.where(eq(userOrgs.orgId, org.orgId));
orgUserCounts.push({
orgId: org.orgId,
userCount: userCount.length
});
}
});
const token = generateSessionToken();

View file

@ -51,6 +51,8 @@ internalRouter.use("/gerbil", gerbilRouter);
gerbilRouter.post("/get-config", gerbil.getConfig);
gerbilRouter.post("/receive-bandwidth", gerbil.receiveBandwidth);
gerbilRouter.post("/update-hole-punch", gerbil.updateHolePunch);
gerbilRouter.post("/get-all-relays", gerbil.getAllRelays);
// Badger routes
const badgerRouter = Router();

View file

@ -1,12 +1,29 @@
import {
handleRegisterMessage,
handleNewtRegisterMessage,
handleReceiveBandwidthMessage,
handleGetConfigMessage,
handleDockerStatusMessage,
handleDockerContainersMessage
handleDockerContainersMessage,
handleNewtPingRequestMessage
} from "./newt";
import {
handleOlmRegisterMessage,
handleOlmRelayMessage,
handleOlmPingMessage,
startOfflineChecker
} from "./olm";
import { MessageHandler } from "./ws";
export const messageHandlers: Record<string, MessageHandler> = {
"newt/wg/register": handleRegisterMessage,
"newt/wg/register": handleNewtRegisterMessage,
"olm/wg/register": handleOlmRegisterMessage,
"newt/wg/get-config": handleGetConfigMessage,
"newt/receive-bandwidth": handleReceiveBandwidthMessage,
"olm/wg/relay": handleOlmRelayMessage,
"olm/ping": handleOlmPingMessage,
"newt/socket/status": handleDockerStatusMessage,
"newt/socket/containers": handleDockerContainersMessage
"newt/socket/containers": handleDockerContainersMessage,
"newt/ping/request": handleNewtPingRequestMessage,
};
startOfflineChecker(); // this is to handle the offline check for olms

View file

@ -24,7 +24,7 @@ export const newtGetTokenBodySchema = z.object({
export type NewtGetTokenBody = z.infer<typeof newtGetTokenBodySchema>;
export async function getToken(
export async function getNewtToken(
req: Request,
res: Response,
next: NextFunction

View file

@ -0,0 +1,165 @@
import { z } from "zod";
import { MessageHandler } from "../ws";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { db } from "@server/db";
import { clients, clientSites, Newt, sites } from "@server/db";
import { eq } from "drizzle-orm";
import { updatePeer } from "../olm/peers";
const inputSchema = z.object({
publicKey: z.string(),
port: z.number().int().positive()
});
type Input = z.infer<typeof inputSchema>;
export const handleGetConfigMessage: MessageHandler = async (context) => {
const { message, client, sendToClient } = context;
const newt = client as Newt;
const now = new Date().getTime() / 1000;
logger.debug("Handling Newt get config message!");
if (!newt) {
logger.warn("Newt not found");
return;
}
if (!newt.siteId) {
logger.warn("Newt has no site!"); // TODO: Maybe we create the site here?
return;
}
const parsed = inputSchema.safeParse(message.data);
if (!parsed.success) {
logger.error(
"handleGetConfigMessage: Invalid input: " +
fromError(parsed.error).toString()
);
return;
}
const { publicKey, port } = message.data as Input;
const siteId = newt.siteId;
// Get the current site data
const [existingSite] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId));
if (!existingSite) {
logger.warn("handleGetConfigMessage: Site not found");
return;
}
// we need to wait for hole punch success
if (!existingSite.endpoint) {
logger.warn(`Site ${existingSite.siteId} has no endpoint, skipping`);
return;
}
if (existingSite.publicKey !== publicKey) {
// TODO: somehow we should make sure a recent hole punch has happened if this occurs (hole punch could be from the last restart if done quickly)
}
if (existingSite.lastHolePunch && now - existingSite.lastHolePunch > 6) {
logger.warn(
`Site ${existingSite.siteId} last hole punch is too old, skipping`
);
return;
}
// update the endpoint and the public key
const [site] = await db
.update(sites)
.set({
publicKey,
listenPort: port
})
.where(eq(sites.siteId, siteId))
.returning();
if (!site) {
logger.error("handleGetConfigMessage: Failed to update site");
return;
}
// Get all clients connected to this site
const clientsRes = await db
.select()
.from(clients)
.innerJoin(clientSites, eq(clients.clientId, clientSites.clientId))
.where(eq(clientSites.siteId, siteId));
// Prepare peers data for the response
const peers = await Promise.all(
clientsRes
.filter((client) => {
if (!client.clients.pubKey) {
return false;
}
if (!client.clients.subnet) {
return false;
}
if (!client.clients.endpoint) {
return false;
}
if (!client.clients.online) {
return false;
}
return true;
})
.map(async (client) => {
// Add or update this peer on the olm if it is connected
try {
if (site.endpoint && site.publicKey) {
await updatePeer(client.clients.clientId, {
siteId: site.siteId,
endpoint: site.endpoint,
publicKey: site.publicKey,
serverIP: site.address,
serverPort: site.listenPort
});
}
} catch (error) {
logger.error(
`Failed to add/update peer ${client.clients.pubKey} to newt ${newt.newtId}: ${error}`
);
}
return {
publicKey: client.clients.pubKey!,
allowedIps: [`${client.clients.subnet.split('/')[0]}/32`], // we want to only allow from that client
endpoint: client.clientSites.isRelayed
? ""
: client.clients.endpoint! // if its relayed it should be localhost
};
})
);
// Filter out any null values from peers that didn't have an olm
const validPeers = peers.filter((peer) => peer !== null);
// Build the configuration response
const configResponse = {
ipAddress: site.address,
peers: validPeers
};
logger.debug("Sending config: ", configResponse);
return {
message: {
type: "newt/wg/receive-config",
data: {
...configResponse
}
},
broadcast: false,
excludeSender: false
};
};

View file

@ -0,0 +1,89 @@
import { db, sites } from "@server/db";
import { MessageHandler } from "../ws";
import { exitNodes, Newt } from "@server/db";
import logger from "@server/logger";
import config from "@server/lib/config";
import { ne, eq, or, and, count } from "drizzle-orm";
export const handleNewtPingRequestMessage: MessageHandler = async (context) => {
const { message, client, sendToClient } = context;
const newt = client as Newt;
logger.info("Handling ping request newt message!");
if (!newt) {
logger.warn("Newt not found");
return;
}
// TODO: pick which nodes to send and ping better than just all of them
let exitNodesList = await db
.select()
.from(exitNodes);
exitNodesList = exitNodesList.filter((node) => node.maxConnections !== 0);
let lastExitNodeId = null;
if (newt.siteId) {
const [lastExitNode] = await db
.select()
.from(sites)
.where(eq(sites.siteId, newt.siteId))
.limit(1);
lastExitNodeId = lastExitNode?.exitNodeId || null;
}
const exitNodesPayload = await Promise.all(
exitNodesList.map(async (node) => {
// (MAX_CONNECTIONS - current_connections) / MAX_CONNECTIONS)
// higher = more desirable
// like saying, this node has x% of its capacity left
let weight = 1;
const maxConnections = node.maxConnections;
if (maxConnections !== null && maxConnections !== undefined) {
const [currentConnections] = await db
.select({
count: count()
})
.from(sites)
.where(
and(
eq(sites.exitNodeId, node.exitNodeId),
eq(sites.online, true)
)
);
if (currentConnections.count >= maxConnections) {
return null
}
weight =
(maxConnections - currentConnections.count) /
maxConnections;
}
return {
exitNodeId: node.exitNodeId,
exitNodeName: node.name,
endpoint: node.endpoint,
weight,
wasPreviouslyConnected: node.exitNodeId === lastExitNodeId
};
})
);
// filter out null values
const filteredExitNodes = exitNodesPayload.filter((node) => node !== null);
return {
message: {
type: "newt/ping/exitNodes",
data: {
exitNodes: filteredExitNodes
}
},
broadcast: false, // Send to all clients
excludeSender: false // Include sender in broadcast
};
};

View file

@ -0,0 +1,358 @@
import { db, newts } from "@server/db";
import { MessageHandler } from "../ws";
import { exitNodes, Newt, resources, sites, Target, targets } from "@server/db";
import { eq, and, sql, inArray } from "drizzle-orm";
import { addPeer, deletePeer } from "../gerbil/peers";
import logger from "@server/logger";
import config from "@server/lib/config";
import {
findNextAvailableCidr,
getNextAvailableClientSubnet
} from "@server/lib/ip";
export type ExitNodePingResult = {
exitNodeId: number;
latencyMs: number;
weight: number;
error?: string;
exitNodeName: string;
endpoint: string;
wasPreviouslyConnected: boolean;
};
export const handleNewtRegisterMessage: MessageHandler = async (context) => {
const { message, client, sendToClient } = context;
const newt = client as Newt;
logger.info("Handling register newt message!");
if (!newt) {
logger.warn("Newt not found");
return;
}
if (!newt.siteId) {
logger.warn("Newt has no site!"); // TODO: Maybe we create the site here?
return;
}
const siteId = newt.siteId;
const { publicKey, pingResults, newtVersion, backwardsCompatible } =
message.data;
if (!publicKey) {
logger.warn("Public key not provided");
return;
}
if (backwardsCompatible) {
logger.debug(
"Backwards compatible mode detecting - not sending connect message and waiting for ping response."
);
return;
}
let exitNodeId: number | undefined;
if (pingResults) {
const bestPingResult = selectBestExitNode(
pingResults as ExitNodePingResult[]
);
if (!bestPingResult) {
logger.warn("No suitable exit node found based on ping results");
return;
}
exitNodeId = bestPingResult.exitNodeId;
}
if (newtVersion) {
// update the newt version in the database
await db
.update(newts)
.set({
version: newtVersion as string
})
.where(eq(newts.newtId, newt.newtId));
}
const [oldSite] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!oldSite || !oldSite.exitNodeId) {
logger.warn("Site not found or does not have exit node");
return;
}
let siteSubnet = oldSite.subnet;
let exitNodeIdToQuery = oldSite.exitNodeId;
if (exitNodeId && (oldSite.exitNodeId !== exitNodeId || !oldSite.subnet)) {
// This effectively moves the exit node to the new one
exitNodeIdToQuery = exitNodeId; // Use the provided exitNodeId if it differs from the site's exitNodeId
const sitesQuery = await db
.select({
subnet: sites.subnet
})
.from(sites)
.where(eq(sites.exitNodeId, exitNodeId));
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, exitNodeIdToQuery))
.limit(1);
const blockSize = config.getRawConfig().gerbil.site_block_size;
const subnets = sitesQuery.map((site) => site.subnet).filter((subnet) => subnet !== null);
subnets.push(exitNode.address.replace(/\/\d+$/, `/${blockSize}`));
const newSubnet = findNextAvailableCidr(
subnets,
blockSize,
exitNode.address
);
if (!newSubnet) {
logger.error("No available subnets found for the new exit node");
return;
}
siteSubnet = newSubnet;
await db
.update(sites)
.set({
pubKey: publicKey,
exitNodeId: exitNodeId,
subnet: newSubnet
})
.where(eq(sites.siteId, siteId))
.returning();
} else {
await db
.update(sites)
.set({
pubKey: publicKey
})
.where(eq(sites.siteId, siteId))
.returning();
}
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, exitNodeIdToQuery))
.limit(1);
if (oldSite.pubKey && oldSite.pubKey !== publicKey) {
logger.info("Public key mismatch. Deleting old peer...");
await deletePeer(oldSite.exitNodeId, oldSite.pubKey);
}
if (!siteSubnet) {
logger.warn("Site has no subnet");
return;
}
// add the peer to the exit node
await addPeer(exitNodeIdToQuery, {
publicKey: publicKey,
allowedIps: [siteSubnet]
});
// Improved version
const allResources = await db.transaction(async (tx) => {
// First get all resources for the site
const resourcesList = await tx
.select({
resourceId: resources.resourceId,
subdomain: resources.subdomain,
fullDomain: resources.fullDomain,
ssl: resources.ssl,
blockAccess: resources.blockAccess,
sso: resources.sso,
emailWhitelistEnabled: resources.emailWhitelistEnabled,
http: resources.http,
proxyPort: resources.proxyPort,
protocol: resources.protocol
})
.from(resources)
.where(eq(resources.siteId, siteId));
// Get all enabled targets for these resources in a single query
const resourceIds = resourcesList.map((r) => r.resourceId);
const allTargets =
resourceIds.length > 0
? await tx
.select({
resourceId: targets.resourceId,
targetId: targets.targetId,
ip: targets.ip,
method: targets.method,
port: targets.port,
internalPort: targets.internalPort,
enabled: targets.enabled
})
.from(targets)
.where(
and(
inArray(targets.resourceId, resourceIds),
eq(targets.enabled, true)
)
)
: [];
// Combine the data in JS instead of using SQL for the JSON
return resourcesList.map((resource) => ({
...resource,
targets: allTargets.filter(
(target) => target.resourceId === resource.resourceId
)
}));
});
const { tcpTargets, udpTargets } = allResources.reduce(
(acc, resource) => {
// Skip resources with no targets
if (!resource.targets?.length) return acc;
// Format valid targets into strings
const formattedTargets = resource.targets
.filter(
(target: Target) =>
target?.internalPort && target?.ip && target?.port
)
.map(
(target: Target) =>
`${target.internalPort}:${target.ip}:${target.port}`
);
// Add to the appropriate protocol array
if (resource.protocol === "tcp") {
acc.tcpTargets.push(...formattedTargets);
} else {
acc.udpTargets.push(...formattedTargets);
}
return acc;
},
{ tcpTargets: [] as string[], udpTargets: [] as string[] }
);
return {
message: {
type: "newt/wg/connect",
data: {
endpoint: `${exitNode.endpoint}:${exitNode.listenPort}`,
publicKey: exitNode.publicKey,
serverIP: exitNode.address.split("/")[0],
tunnelIP: siteSubnet.split("/")[0],
targets: {
udp: udpTargets,
tcp: tcpTargets
}
}
},
broadcast: false, // Send to all clients
excludeSender: false // Include sender in broadcast
};
};
/**
* Selects the most suitable exit node from a list of ping results.
*
* The selection algorithm follows these steps:
*
* 1. **Filter Invalid Nodes**: Excludes nodes with errors or zero weight.
*
* 2. **Sort by Latency**: Sorts valid nodes in ascending order of latency.
*
* 3. **Preferred Selection**:
* - If the lowest-latency node has sufficient capacity (10% weight),
* check if a previously connected node is also acceptable.
* - The previously connected node is preferred if its latency is within
* 30ms or 15% of the best nodes latency.
*
* 4. **Fallback to Next Best**:
* - If the lowest-latency node is under capacity, find the next node
* with acceptable capacity.
*
* 5. **Final Fallback**:
* - If no nodes meet the capacity threshold, fall back to the node
* with the highest weight (i.e., most available capacity).
*
*/
function selectBestExitNode(
pingResults: ExitNodePingResult[]
): ExitNodePingResult | null {
const MIN_CAPACITY_THRESHOLD = 0.1;
const LATENCY_TOLERANCE_MS = 30;
const LATENCY_TOLERANCE_PERCENT = 0.15;
// Filter out invalid nodes
const validNodes = pingResults.filter((n) => !n.error && n.weight > 0);
if (validNodes.length === 0) {
logger.error("No valid exit nodes available");
return null;
}
// Sort by latency (ascending)
const sortedNodes = validNodes
.slice()
.sort((a, b) => a.latencyMs - b.latencyMs);
const lowestLatencyNode = sortedNodes[0];
logger.info(
`Lowest latency node: ${lowestLatencyNode.exitNodeName} (${lowestLatencyNode.latencyMs} ms, weight=${lowestLatencyNode.weight.toFixed(2)})`
);
// If lowest latency node has enough capacity, check if previously connected node is acceptable
if (lowestLatencyNode.weight >= MIN_CAPACITY_THRESHOLD) {
const previouslyConnectedNode = sortedNodes.find(
(n) =>
n.wasPreviouslyConnected && n.weight >= MIN_CAPACITY_THRESHOLD
);
if (previouslyConnectedNode) {
const latencyDiff =
previouslyConnectedNode.latencyMs - lowestLatencyNode.latencyMs;
const percentDiff = latencyDiff / lowestLatencyNode.latencyMs;
if (
latencyDiff <= LATENCY_TOLERANCE_MS ||
percentDiff <= LATENCY_TOLERANCE_PERCENT
) {
logger.info(
`Sticking with previously connected node: ${previouslyConnectedNode.exitNodeName} ` +
`(${previouslyConnectedNode.latencyMs} ms), latency diff = ${latencyDiff.toFixed(1)}ms ` +
`/ ${(percentDiff * 100).toFixed(1)}%.`
);
return previouslyConnectedNode;
}
}
return lowestLatencyNode;
}
// Otherwise, find the next node (after the lowest) that has enough capacity
for (let i = 1; i < sortedNodes.length; i++) {
const node = sortedNodes[i];
if (node.weight >= MIN_CAPACITY_THRESHOLD) {
logger.info(
`Lowest latency node under capacity. Using next best: ${node.exitNodeName} ` +
`(${node.latencyMs} ms, weight=${node.weight.toFixed(2)})`
);
return node;
}
}
// Fallback: pick the highest weight node
const fallbackNode = validNodes.reduce((a, b) =>
a.weight > b.weight ? a : b
);
logger.warn(
`No nodes with ≥10% weight. Falling back to highest capacity node: ${fallbackNode.exitNodeName}`
);
return fallbackNode;
}

View file

@ -0,0 +1,52 @@
import { db } from "@server/db";
import { MessageHandler } from "../ws";
import { clients, Newt } from "@server/db";
import { eq } from "drizzle-orm";
import logger from "@server/logger";
interface PeerBandwidth {
publicKey: string;
bytesIn: number;
bytesOut: number;
}
export const handleReceiveBandwidthMessage: MessageHandler = async (context) => {
const { message, client, sendToClient } = context;
if (!message.data.bandwidthData) {
logger.warn("No bandwidth data provided");
}
const bandwidthData: PeerBandwidth[] = message.data.bandwidthData;
if (!Array.isArray(bandwidthData)) {
throw new Error("Invalid bandwidth data");
}
await db.transaction(async (trx) => {
for (const peer of bandwidthData) {
const { publicKey, bytesIn, bytesOut } = peer;
// Find the client by public key
const [client] = await trx
.select()
.from(clients)
.where(eq(clients.pubKey, publicKey))
.limit(1);
if (!client) {
continue;
}
// Update the client's bandwidth usage
await trx
.update(clients)
.set({
megabytesOut: (client.megabytesIn || 0) + bytesIn,
megabytesIn: (client.megabytesOut || 0) + bytesOut,
lastBandwidthUpdate: new Date().toISOString(),
})
.where(eq(clients.clientId, client.clientId));
}
});
};

View file

@ -1,174 +0,0 @@
import { db } from "@server/db";
import { MessageHandler } from "../ws";
import {
exitNodes,
resources,
sites,
Target,
targets
} from "@server/db";
import { eq, and, sql, inArray } from "drizzle-orm";
import { addPeer, deletePeer } from "../gerbil/peers";
import logger from "@server/logger";
export const handleRegisterMessage: MessageHandler = async (context) => {
const { message, newt, sendToClient } = context;
logger.info("Handling register message!");
if (!newt) {
logger.warn("Newt not found");
return;
}
if (!newt.siteId) {
logger.warn("Newt has no site!"); // TODO: Maybe we create the site here?
return;
}
const siteId = newt.siteId;
const { publicKey } = message.data;
if (!publicKey) {
logger.warn("Public key not provided");
return;
}
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site || !site.exitNodeId) {
logger.warn("Site not found or does not have exit node");
return;
}
await db
.update(sites)
.set({
pubKey: publicKey
})
.where(eq(sites.siteId, siteId))
.returning();
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId))
.limit(1);
if (site.pubKey && site.pubKey !== publicKey) {
logger.info("Public key mismatch. Deleting old peer...");
await deletePeer(site.exitNodeId, site.pubKey);
}
if (!site.subnet) {
logger.warn("Site has no subnet");
return;
}
// add the peer to the exit node
await addPeer(site.exitNodeId, {
publicKey: publicKey,
allowedIps: [site.subnet]
});
// Improved version
const allResources = await db.transaction(async (tx) => {
// First get all resources for the site
const resourcesList = await tx
.select({
resourceId: resources.resourceId,
subdomain: resources.subdomain,
fullDomain: resources.fullDomain,
ssl: resources.ssl,
blockAccess: resources.blockAccess,
sso: resources.sso,
emailWhitelistEnabled: resources.emailWhitelistEnabled,
http: resources.http,
proxyPort: resources.proxyPort,
protocol: resources.protocol
})
.from(resources)
.where(eq(resources.siteId, siteId));
// Get all enabled targets for these resources in a single query
const resourceIds = resourcesList.map((r) => r.resourceId);
const allTargets =
resourceIds.length > 0
? await tx
.select({
resourceId: targets.resourceId,
targetId: targets.targetId,
ip: targets.ip,
method: targets.method,
port: targets.port,
internalPort: targets.internalPort,
enabled: targets.enabled
})
.from(targets)
.where(
and(
inArray(targets.resourceId, resourceIds),
eq(targets.enabled, true)
)
)
: [];
// Combine the data in JS instead of using SQL for the JSON
return resourcesList.map((resource) => ({
...resource,
targets: allTargets.filter(
(target) => target.resourceId === resource.resourceId
)
}));
});
const { tcpTargets, udpTargets } = allResources.reduce(
(acc, resource) => {
// Skip resources with no targets
if (!resource.targets?.length) return acc;
// Format valid targets into strings
const formattedTargets = resource.targets
.filter(
(target: Target) =>
target?.internalPort && target?.ip && target?.port
)
.map(
(target: Target) =>
`${target.internalPort}:${target.ip}:${target.port}`
);
// Add to the appropriate protocol array
if (resource.protocol === "tcp") {
acc.tcpTargets.push(...formattedTargets);
} else {
acc.udpTargets.push(...formattedTargets);
}
return acc;
},
{ tcpTargets: [] as string[], udpTargets: [] as string[] }
);
return {
message: {
type: "newt/wg/connect",
data: {
endpoint: `${exitNode.endpoint}:${exitNode.listenPort}`,
publicKey: exitNode.publicKey,
serverIP: exitNode.address.split("/")[0],
tunnelIP: site.subnet.split("/")[0],
targets: {
udp: udpTargets,
tcp: tcpTargets
}
}
},
broadcast: false, // Send to all clients
excludeSender: false // Include sender in broadcast
};
};

View file

@ -1,9 +1,11 @@
import { MessageHandler } from "../ws";
import logger from "@server/logger";
import { dockerSocketCache } from "./dockerSocket";
import { Newt } from "@server/db";
export const handleDockerStatusMessage: MessageHandler = async (context) => {
const { message, newt } = context;
const { message, client, sendToClient } = context;
const newt = client as Newt;
logger.info("Handling Docker socket check response");
@ -33,7 +35,8 @@ export const handleDockerStatusMessage: MessageHandler = async (context) => {
export const handleDockerContainersMessage: MessageHandler = async (
context
) => {
const { message, newt } = context;
const { message, client, sendToClient } = context;
const newt = client as Newt;
logger.info("Handling Docker containers response");

View file

@ -1,4 +1,7 @@
export * from "./createNewt";
export * from "./getToken";
export * from "./handleRegisterMessage";
export * from "./handleSocketMessages";
export * from "./getNewtToken";
export * from "./handleNewtRegisterMessage";
export * from "./handleReceiveBandwidthMessage";
export * from "./handleGetConfigMessage";
export * from "./handleSocketMessages";
export * from "./handleNewtPingRequestMessage";

View file

@ -0,0 +1,114 @@
import { db } from "@server/db";
import { newts, sites } from "@server/db";
import { eq } from "drizzle-orm";
import { sendToClient } from "../ws";
import logger from "@server/logger";
export async function addPeer(
siteId: number,
peer: {
publicKey: string;
allowedIps: string[];
endpoint: string;
}
) {
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
throw new Error(`Exit node with ID ${siteId} not found`);
}
// get the newt on the site
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
if (!newt) {
throw new Error(`Site found for site ${siteId}`);
}
sendToClient(newt.newtId, {
type: "newt/wg/peer/add",
data: peer
});
logger.info(`Added peer ${peer.publicKey} to newt ${newt.newtId}`);
return site;
}
export async function deletePeer(siteId: number, publicKey: string) {
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
throw new Error(`Site with ID ${siteId} not found`);
}
// get the newt on the site
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
if (!newt) {
throw new Error(`Newt not found for site ${siteId}`);
}
sendToClient(newt.newtId, {
type: "newt/wg/peer/remove",
data: {
publicKey
}
});
logger.info(`Deleted peer ${publicKey} from newt ${newt.newtId}`);
return site;
}
export async function updatePeer(
siteId: number,
publicKey: string,
peer: {
allowedIps?: string[];
endpoint?: string;
}
) {
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
throw new Error(`Site with ID ${siteId} not found`);
}
// get the newt on the site
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
if (!newt) {
throw new Error(`Newt not found for site ${siteId}`);
}
sendToClient(newt.newtId, {
type: "newt/wg/peer/update",
data: {
publicKey,
...peer
}
});
logger.info(`Updated peer ${publicKey} on newt ${newt.newtId}`);
return site;
}

View file

@ -0,0 +1,106 @@
import { NextFunction, Request, Response } from "express";
import { db } from "@server/db";
import { hash } from "@node-rs/argon2";
import HttpCode from "@server/types/HttpCode";
import { z } from "zod";
import { newts } from "@server/db";
import createHttpError from "http-errors";
import response from "@server/lib/response";
import { SqliteError } from "better-sqlite3";
import moment from "moment";
import { generateSessionToken } from "@server/auth/sessions/app";
import { createNewtSession } from "@server/auth/sessions/newt";
import { fromError } from "zod-validation-error";
import { hashPassword } from "@server/auth/password";
export const createNewtBodySchema = z.object({});
export type CreateNewtBody = z.infer<typeof createNewtBodySchema>;
export type CreateNewtResponse = {
token: string;
newtId: string;
secret: string;
};
const createNewtSchema = z
.object({
newtId: z.string(),
secret: z.string()
})
.strict();
export async function createNewt(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedBody = createNewtSchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { newtId, secret } = parsedBody.data;
if (req.user && !req.userOrgRoleId) {
return next(
createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
);
}
const secretHash = await hashPassword(secret);
await db.insert(newts).values({
newtId: newtId,
secretHash,
dateCreated: moment().toISOString(),
});
// give the newt their default permissions:
// await db.insert(newtActions).values({
// newtId: newtId,
// actionId: ActionsEnum.createOrg,
// orgId: null,
// });
const token = generateSessionToken();
await createNewtSession(token, newtId);
return response<CreateNewtResponse>(res, {
data: {
newtId,
secret,
token,
},
success: true,
error: false,
message: "Newt created successfully",
status: HttpCode.OK,
});
} catch (e) {
if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_UNIQUE") {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"A newt with that email address already exists"
)
);
} else {
console.error(e);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to create newt"
)
);
}
}
}

View file

@ -0,0 +1,119 @@
import { generateSessionToken } from "@server/auth/sessions/app";
import { db } from "@server/db";
import { olms } from "@server/db";
import HttpCode from "@server/types/HttpCode";
import response from "@server/lib/response";
import { eq } from "drizzle-orm";
import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import {
createOlmSession,
validateOlmSessionToken
} from "@server/auth/sessions/olm";
import { verifyPassword } from "@server/auth/password";
import logger from "@server/logger";
import config from "@server/lib/config";
export const olmGetTokenBodySchema = z.object({
olmId: z.string(),
secret: z.string(),
token: z.string().optional()
});
export type OlmGetTokenBody = z.infer<typeof olmGetTokenBodySchema>;
export async function getOlmToken(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
const parsedBody = olmGetTokenBodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { olmId, secret, token } = parsedBody.data;
try {
if (token) {
const { session, olm } = await validateOlmSessionToken(token);
if (session) {
if (config.getRawConfig().app.log_failed_attempts) {
logger.info(
`Olm session already valid. Olm ID: ${olmId}. IP: ${req.ip}.`
);
}
return response<null>(res, {
data: null,
success: true,
error: false,
message: "Token session already valid",
status: HttpCode.OK
});
}
}
const existingOlmRes = await db
.select()
.from(olms)
.where(eq(olms.olmId, olmId));
if (!existingOlmRes || !existingOlmRes.length) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"No olm found with that olmId"
)
);
}
const existingOlm = existingOlmRes[0];
const validSecret = await verifyPassword(
secret,
existingOlm.secretHash
);
if (!validSecret) {
if (config.getRawConfig().app.log_failed_attempts) {
logger.info(
`Olm id or secret is incorrect. Olm: ID ${olmId}. IP: ${req.ip}.`
);
}
return next(
createHttpError(HttpCode.BAD_REQUEST, "Secret is incorrect")
);
}
logger.debug("Creating new olm session token");
const resToken = generateSessionToken();
await createOlmSession(resToken, existingOlm.olmId);
logger.debug("Token created successfully");
return response<{ token: string }>(res, {
data: {
token: resToken
},
success: true,
error: false,
message: "Token created successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to authenticate olm"
)
);
}
}

View file

@ -0,0 +1,93 @@
import { db } from "@server/db";
import { MessageHandler } from "../ws";
import { clients, Olm } from "@server/db";
import { eq, lt, isNull } from "drizzle-orm";
import logger from "@server/logger";
// Track if the offline checker interval is running
let offlineCheckerInterval: NodeJS.Timeout | null = null;
const OFFLINE_CHECK_INTERVAL = 30 * 1000; // Check every 30 seconds
const OFFLINE_THRESHOLD_MS = 2 * 60 * 1000; // 2 minutes
/**
* Starts the background interval that checks for clients that haven't pinged recently
* and marks them as offline
*/
export const startOfflineChecker = (): void => {
if (offlineCheckerInterval) {
return; // Already running
}
offlineCheckerInterval = setInterval(async () => {
try {
const twoMinutesAgo = new Date(Date.now() - OFFLINE_THRESHOLD_MS);
// Find clients that haven't pinged in the last 2 minutes and mark them as offline
await db
.update(clients)
.set({ online: false })
.where(
eq(clients.online, true) &&
(lt(clients.lastPing, twoMinutesAgo.toISOString()) || isNull(clients.lastPing))
);
} catch (error) {
logger.error("Error in offline checker interval", { error });
}
}, OFFLINE_CHECK_INTERVAL);
logger.info("Started offline checker interval");
}
/**
* Stops the background interval that checks for offline clients
*/
export const stopOfflineChecker = (): void => {
if (offlineCheckerInterval) {
clearInterval(offlineCheckerInterval);
offlineCheckerInterval = null;
logger.info("Stopped offline checker interval");
}
}
/**
* Handles ping messages from clients and responds with pong
*/
export const handleOlmPingMessage: MessageHandler = async (context) => {
const { message, client: c, sendToClient } = context;
const olm = c as Olm;
if (!olm) {
logger.warn("Olm not found");
return;
}
if (!olm.clientId) {
logger.warn("Olm has no client ID!");
return;
}
try {
// Update the client's last ping timestamp
await db
.update(clients)
.set({
lastPing: new Date().toISOString(),
online: true,
})
.where(eq(clients.clientId, olm.clientId));
} catch (error) {
logger.error("Error handling ping message", { error });
}
return {
message: {
type: "pong",
data: {
timestamp: new Date().toISOString(),
}
},
broadcast: false,
excludeSender: false
};
};

View file

@ -0,0 +1,181 @@
import { db } from "@server/db";
import { MessageHandler } from "../ws";
import {
clients,
clientSites,
exitNodes,
Olm,
olms,
sites
} from "@server/db";
import { eq, inArray } from "drizzle-orm";
import { addPeer, deletePeer } from "../newt/peers";
import logger from "@server/logger";
export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.info("Handling register olm message!");
const { message, client: c, sendToClient } = context;
const olm = c as Olm;
const now = new Date().getTime() / 1000;
if (!olm) {
logger.warn("Olm not found");
return;
}
if (!olm.clientId) {
logger.warn("Olm has no client ID!");
return;
}
const clientId = olm.clientId;
const { publicKey } = message.data;
if (!publicKey) {
logger.warn("Public key not provided");
return;
}
// Get the client
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
logger.warn("Client not found");
return;
}
if (client.exitNodeId) {
// Get the exit node for this site
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, client.exitNodeId))
.limit(1);
// Send holepunch message for each site
sendToClient(olm.olmId, {
type: "olm/wg/holepunch",
data: {
serverPubKey: exitNode.publicKey
}
});
}
if (now - (client.lastHolePunch || 0) > 6) {
logger.warn("Client last hole punch is too old, skipping all sites");
return;
}
if (client.pubKey !== publicKey) {
logger.info(
"Public key mismatch. Updating public key and clearing session info..."
);
// Update the client's public key
await db
.update(clients)
.set({
pubKey: publicKey
})
.where(eq(clients.clientId, olm.clientId));
// set isRelay to false for all of the client's sites to reset the connection metadata
await db
.update(clientSites)
.set({
isRelayed: false
})
.where(eq(clientSites.clientId, olm.clientId));
}
// Get all sites data
const sitesData = await db
.select()
.from(sites)
.innerJoin(clientSites, eq(sites.siteId, clientSites.siteId))
.where(eq(clientSites.clientId, client.clientId));
// Prepare an array to store site configurations
const siteConfigurations = [];
// Process each site
for (const { sites: site } of sitesData) {
if (!site.exitNodeId) {
logger.warn(
`Site ${site.siteId} does not have exit node, skipping`
);
continue;
}
// Validate endpoint and hole punch status
if (!site.endpoint) {
logger.warn(`Site ${site.siteId} has no endpoint, skipping`);
continue;
}
if (site.lastHolePunch && now - site.lastHolePunch > 6) {
logger.warn(
`Site ${site.siteId} last hole punch is too old, skipping`
);
continue;
}
// If public key changed, delete old peer from this site
if (client.pubKey && client.pubKey != publicKey) {
logger.info(
`Public key mismatch. Deleting old peer from site ${site.siteId}...`
);
await deletePeer(site.siteId, client.pubKey!);
}
if (!site.subnet) {
logger.warn(`Site ${site.siteId} has no subnet, skipping`);
continue;
}
// Add the peer to the exit node for this site
if (client.endpoint) {
logger.info(
`Adding peer ${publicKey} to site ${site.siteId} with endpoint ${client.endpoint}`
);
await addPeer(site.siteId, {
publicKey: publicKey,
allowedIps: [`${client.subnet.split('/')[0]}/32`], // we want to only allow from that client
endpoint: client.endpoint
});
} else {
logger.warn(
`Client ${client.clientId} has no endpoint, skipping peer addition`
);
}
// Add site configuration to the array
siteConfigurations.push({
siteId: site.siteId,
endpoint: site.endpoint,
publicKey: site.publicKey,
serverIP: site.address,
serverPort: site.listenPort
});
}
// If we have no valid site configurations, don't send a connect message
if (siteConfigurations.length === 0) {
logger.warn("No valid site configurations found");
return;
}
// Return connect message with all site configurations
return {
message: {
type: "olm/wg/connect",
data: {
sites: siteConfigurations,
tunnelIP: client.subnet
}
},
broadcast: false,
excludeSender: false
};
};

View file

@ -0,0 +1,58 @@
import { db } from "@server/db";
import { MessageHandler } from "../ws";
import { clients, clientSites, Olm } from "@server/db";
import { eq } from "drizzle-orm";
import { updatePeer } from "../newt/peers";
import logger from "@server/logger";
export const handleOlmRelayMessage: MessageHandler = async (context) => {
const { message, client: c, sendToClient } = context;
const olm = c as Olm;
logger.info("Handling relay olm message!");
if (!olm) {
logger.warn("Olm not found");
return;
}
if (!olm.clientId) {
logger.warn("Olm has no site!"); // TODO: Maybe we create the site here?
return;
}
const clientId = olm.clientId;
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
logger.warn("Site not found or does not have exit node");
return;
}
// make sure we hand endpoints for both the site and the client and the lastHolePunch is not too old
if (!client.pubKey) {
logger.warn("Site or client has no endpoint or listen port");
return;
}
const { siteId } = message.data;
await db
.update(clientSites)
.set({
isRelayed: true
})
.where(eq(clientSites.clientId, olm.clientId));
// update the peer on the exit node
await updatePeer(siteId, client.pubKey, {
endpoint: "" // this removes the endpoint
});
return;
};

View file

@ -0,0 +1,5 @@
export * from "./handleOlmRegisterMessage";
export * from "./getOlmToken";
export * from "./createOlm";
export * from "./handleOlmRelayMessage";
export * from "./handleOlmPingMessage";

View file

@ -0,0 +1,92 @@
import { db } from "@server/db";
import { clients, olms, newts, sites } from "@server/db";
import { eq } from "drizzle-orm";
import { sendToClient } from "../ws";
import logger from "@server/logger";
export async function addPeer(
clientId: number,
peer: {
siteId: number;
publicKey: string;
endpoint: string;
serverIP: string | null;
serverPort: number | null;
}
) {
const [olm] = await db
.select()
.from(olms)
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
throw new Error(`Olm with ID ${clientId} not found`);
}
sendToClient(olm.olmId, {
type: "olm/wg/peer/add",
data: {
siteId: peer.siteId,
publicKey: peer.publicKey,
endpoint: peer.endpoint,
serverIP: peer.serverIP,
serverPort: peer.serverPort
}
});
logger.info(`Added peer ${peer.publicKey} to olm ${olm.olmId}`);
}
export async function deletePeer(clientId: number, siteId: number, publicKey: string) {
const [olm] = await db
.select()
.from(olms)
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
throw new Error(`Olm with ID ${clientId} not found`);
}
sendToClient(olm.olmId, {
type: "olm/wg/peer/remove",
data: {
publicKey,
siteId: siteId
}
});
logger.info(`Deleted peer ${publicKey} from olm ${olm.olmId}`);
}
export async function updatePeer(
clientId: number,
peer: {
siteId: number;
publicKey: string;
endpoint: string;
serverIP: string | null;
serverPort: number | null;
}
) {
const [olm] = await db
.select()
.from(olms)
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
throw new Error(`Olm with ID ${clientId} not found`);
}
sendToClient(olm.olmId, {
type: "olm/wg/peer/update",
data: {
siteId: peer.siteId,
publicKey: peer.publicKey,
endpoint: peer.endpoint,
serverIP: peer.serverIP,
serverPort: peer.serverPort
}
});
logger.info(`Added peer ${peer.publicKey} to olm ${olm.olmId}`);
}

View file

@ -23,16 +23,16 @@ import config from "@server/lib/config";
import { fromError } from "zod-validation-error";
import { defaultRoleAllowedActions } from "../role";
import { OpenAPITags, registry } from "@server/openApi";
import { isValidCIDR } from "@server/lib/validators";
const createOrgSchema = z
.object({
orgId: z.string(),
name: z.string().min(1).max(255)
name: z.string().min(1).max(255),
subnet: z.string()
})
.strict();
// const MAX_ORGS = 5;
registry.registerPath({
method: "put",
path: "/org",
@ -78,7 +78,32 @@ export async function createOrg(
);
}
const { orgId, name } = parsedBody.data;
const { orgId, name, subnet } = parsedBody.data;
if (!isValidCIDR(subnet)) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Invalid subnet format. Please provide a valid CIDR notation."
)
);
}
// make sure the subnet is unique
const subnetExists = await db
.select()
.from(orgs)
.where(eq(orgs.subnet, subnet))
.limit(1);
if (subnetExists.length > 0) {
return next(
createHttpError(
HttpCode.CONFLICT,
`Subnet ${subnet} already exists`
)
);
}
// make sure the orgId is unique
const orgExists = await db
@ -109,7 +134,8 @@ export async function createOrg(
.insert(orgs)
.values({
orgId,
name
name,
subnet
})
.returning();
@ -142,25 +168,25 @@ export async function createOrg(
// Get all actions and create role actions
const actionIds = await trx.select().from(actions).execute();
if (actionIds.length > 0) {
await trx
.insert(roleActions)
.values(
actionIds.map((action) => ({
roleId,
actionId: action.actionId,
orgId: newOrg[0].orgId
}))
);
await trx.insert(roleActions).values(
actionIds.map((action) => ({
roleId,
actionId: action.actionId,
orgId: newOrg[0].orgId
}))
);
}
await trx.insert(orgDomains).values(
allDomains.map((domain) => ({
orgId: newOrg[0].orgId,
domainId: domain.domainId
}))
);
if (allDomains.length) {
await trx.insert(orgDomains).values(
allDomains.map((domain) => ({
orgId: newOrg[0].orgId,
domainId: domain.domainId
}))
);
}
if (req.user) {
await trx.insert(userOrgs).values({
@ -187,7 +213,7 @@ export async function createOrg(
orgId: newOrg[0].orgId,
roleId: roleId,
isOwner: true
});
});
}
const memberRole = await trx

View file

@ -89,6 +89,8 @@ export async function deleteOrg(
.where(eq(sites.orgId, orgId))
.limit(1);
const deletedNewtIds: string[] = [];
await db.transaction(async (trx) => {
if (sites) {
for (const site of orgSites) {
@ -102,11 +104,7 @@ export async function deleteOrg(
.where(eq(newts.siteId, site.siteId))
.returning();
if (deletedNewt) {
const payload = {
type: `newt/terminate`,
data: {}
};
sendToClient(deletedNewt.newtId, payload);
deletedNewtIds.push(deletedNewt.newtId);
// delete all of the sessions for the newt
await trx
@ -131,6 +129,18 @@ export async function deleteOrg(
await trx.delete(orgs).where(eq(orgs.orgId, orgId));
});
// Send termination messages outside of transaction to prevent blocking
for (const newtId of deletedNewtIds) {
const payload = {
type: `newt/terminate`,
data: {}
};
// Don't await this to prevent blocking the response
sendToClient(newtId, payload).catch(error => {
logger.error("Failed to send termination message to newt:", error);
});
}
return response(res, {
data: null,
success: true,

View file

@ -6,3 +6,4 @@ export * from "./listUserOrgs";
export * from "./checkId";
export * from "./getOrgOverview";
export * from "./listOrgs";
export * from "./pickOrgDefaults";

View file

@ -5,7 +5,7 @@ import { Org, orgs, userOrgs } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import { sql, inArray, eq } from "drizzle-orm";
import { sql, inArray, eq, and } from "drizzle-orm";
import logger from "@server/logger";
import { fromZodError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
@ -40,8 +40,10 @@ const listOrgsSchema = z.object({
// responses: {}
// });
type ResponseOrg = Org & { isOwner?: boolean };
export type ListUserOrgsResponse = {
orgs: Org[];
orgs: ResponseOrg[];
pagination: { total: number; limit: number; offset: number };
};
@ -106,6 +108,10 @@ export async function listUserOrgs(
.select()
.from(orgs)
.where(inArray(orgs.orgId, userOrgIds))
.leftJoin(
userOrgs,
and(eq(userOrgs.orgId, orgs.orgId), eq(userOrgs.userId, userId))
)
.limit(limit)
.offset(offset);
@ -115,9 +121,19 @@ export async function listUserOrgs(
.where(inArray(orgs.orgId, userOrgIds));
const totalCount = totalCountResult[0].count;
const responseOrgs = organizations.map((val) => {
const res = {
...val.orgs
} as ResponseOrg;
if (val.userOrgs && val.userOrgs.isOwner) {
res.isOwner = val.userOrgs.isOwner;
}
return res;
});
return response<ListUserOrgsResponse>(res, {
data: {
orgs: organizations,
orgs: responseOrgs,
pagination: {
total: totalCount,
limit,

View file

@ -0,0 +1,39 @@
import { Request, Response, NextFunction } from "express";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { getNextAvailableOrgSubnet } from "@server/lib/ip";
import config from "@server/lib/config";
export type PickOrgDefaultsResponse = {
subnet: string;
};
export async function pickOrgDefaults(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
// TODO: Why would each org have to have its own subnet?
// const subnet = await getNextAvailableOrgSubnet();
// Just hard code the subnet for now for everyone
const subnet = config.getRawConfig().orgs.subnet_group;
return response<PickOrgDefaultsResponse>(res, {
data: {
subnet: subnet
},
success: true,
error: false,
message: "Organization defaults created successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View file

@ -21,6 +21,7 @@ import logger from "@server/logger";
import { subdomainSchema } from "@server/lib/schemas";
import config from "@server/lib/config";
import { OpenAPITags, registry } from "@server/openApi";
import { build } from "@server/build";
const createResourceParamsSchema = z
.object({
@ -36,7 +37,6 @@ const createHttpResourceSchema = z
.string()
.optional()
.transform((val) => val?.toLowerCase()),
isBaseDomain: z.boolean().optional(),
siteId: z.number(),
http: z.boolean(),
protocol: z.enum(["tcp", "udp"]),
@ -52,19 +52,6 @@ const createHttpResourceSchema = z
},
{ message: "Invalid subdomain" }
)
.refine(
(data) => {
if (!config.getRawConfig().flags?.allow_base_domain_resources) {
if (data.isBaseDomain) {
return false;
}
}
return true;
},
{
message: "Base domain resources are not allowed"
}
);
const createRawResourceSchema = z
.object({
@ -101,9 +88,12 @@ registry.registerPath({
body: {
content: {
"application/json": {
schema: createHttpResourceSchema.or(
createRawResourceSchema
)
schema:
build == "oss"
? createHttpResourceSchema.or(
createRawResourceSchema
)
: createHttpResourceSchema
}
}
}
@ -166,6 +156,14 @@ export async function createResource(
{ siteId, orgId }
);
} else {
if (!config.getRawConfig().flags?.allow_raw_resources && build == "oss") {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Raw resources are not allowed"
)
);
}
return await createRawResource(
{ req, res, next },
{ siteId, orgId }
@ -203,35 +201,81 @@ async function createHttpResource(
);
}
const { name, subdomain, isBaseDomain, http, protocol, domainId } =
parsedBody.data;
const { name, subdomain, domainId } = parsedBody.data;
const [orgDomain] = await db
const [domainRes] = await db
.select()
.from(orgDomains)
.where(
.from(domains)
.where(eq(domains.domainId, domainId))
.leftJoin(
orgDomains,
and(eq(orgDomains.orgId, orgId), eq(orgDomains.domainId, domainId))
)
.leftJoin(domains, eq(orgDomains.domainId, domains.domainId));
);
if (!orgDomain || !orgDomain.domains) {
if (!domainRes || !domainRes.domains) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Domain with ID ${parsedBody.data.domainId} not found`
`Domain with ID ${domainId} not found`
)
);
}
const domain = orgDomain.domains;
if (domainRes.orgDomains && domainRes.orgDomains.orgId !== orgId) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
`Organization does not have access to domain with ID ${domainId}`
)
);
}
if (!domainRes.domains.verified) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`Domain with ID ${domainRes.domains.domainId} is not verified`
)
);
}
let fullDomain = "";
if (isBaseDomain) {
fullDomain = domain.baseDomain;
} else {
fullDomain = `${subdomain}.${domain.baseDomain}`;
if (domainRes.domains.type == "ns") {
if (subdomain) {
fullDomain = `${subdomain}.${domainRes.domains.baseDomain}`;
} else {
fullDomain = domainRes.domains.baseDomain;
}
} else if (domainRes.domains.type == "cname") {
fullDomain = domainRes.domains.baseDomain;
} else if (domainRes.domains.type == "wildcard") {
if (subdomain) {
// the subdomain cant have a dot in it
const parsedSubdomain = subdomainSchema.safeParse(subdomain);
if (!parsedSubdomain.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedSubdomain.error).toString()
)
);
}
if (parsedSubdomain.data.includes(".")) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Subdomain cannot contain a dot when using wildcard domains"
)
);
}
fullDomain = `${subdomain}.${domainRes.domains.baseDomain}`;
} else {
fullDomain = domainRes.domains.baseDomain;
}
}
fullDomain = fullDomain.toLowerCase();
logger.debug(`Full domain: ${fullDomain}`);
// make sure the full domain is unique
@ -261,10 +305,10 @@ async function createHttpResource(
orgId,
name,
subdomain,
http,
protocol,
http: true,
protocol: "tcp",
ssl: true,
isBaseDomain
isBaseDomain: false
})
.returning();

View file

@ -69,7 +69,8 @@ function queryResources(
http: resources.http,
protocol: resources.protocol,
proxyPort: resources.proxyPort,
enabled: resources.enabled
enabled: resources.enabled,
domainId: resources.domainId
})
.from(resources)
.leftJoin(sites, eq(resources.siteId, sites.siteId))
@ -103,7 +104,8 @@ function queryResources(
http: resources.http,
protocol: resources.protocol,
proxyPort: resources.proxyPort,
enabled: resources.enabled
enabled: resources.enabled,
domainId: resources.domainId
})
.from(resources)
.leftJoin(sites, eq(resources.siteId, sites.siteId))

View file

@ -20,6 +20,7 @@ import { tlsNameSchema } from "@server/lib/schemas";
import { subdomainSchema } from "@server/lib/schemas";
import { registry } from "@server/openApi";
import { OpenAPITags } from "@server/openApi";
import { build } from "@server/build";
const updateResourceParamsSchema = z
.object({
@ -40,7 +41,6 @@ const updateHttpResourceBodySchema = z
sso: z.boolean().optional(),
blockAccess: z.boolean().optional(),
emailWhitelistEnabled: z.boolean().optional(),
isBaseDomain: z.boolean().optional(),
applyRules: z.boolean().optional(),
domainId: z.string().optional(),
enabled: z.boolean().optional(),
@ -61,19 +61,6 @@ const updateHttpResourceBodySchema = z
},
{ message: "Invalid subdomain" }
)
.refine(
(data) => {
if (!config.getRawConfig().flags?.allow_base_domain_resources) {
if (data.isBaseDomain) {
return false;
}
}
return true;
},
{
message: "Base domain resources are not allowed"
}
)
.refine(
(data) => {
if (data.tlsServerName) {
@ -134,9 +121,12 @@ registry.registerPath({
body: {
content: {
"application/json": {
schema: updateHttpResourceBodySchema.and(
updateRawResourceBodySchema
)
schema:
build == "oss"
? updateHttpResourceBodySchema.and(
updateRawResourceBodySchema
)
: updateHttpResourceBodySchema
}
}
}
@ -242,86 +232,120 @@ async function updateHttpResource(
const updateData = parsedBody.data;
if (updateData.domainId) {
const [existingDomain] = await db
.select()
.from(orgDomains)
.where(
and(
eq(orgDomains.orgId, org.orgId),
eq(orgDomains.domainId, updateData.domainId)
)
)
.leftJoin(domains, eq(orgDomains.domainId, domains.domainId));
const domainId = updateData.domainId;
if (!existingDomain) {
const [domainRes] = await db
.select()
.from(domains)
.where(eq(domains.domainId, domainId))
.leftJoin(
orgDomains,
and(
eq(orgDomains.orgId, resource.orgId),
eq(orgDomains.domainId, domainId)
)
);
if (!domainRes || !domainRes.domains) {
return next(
createHttpError(HttpCode.NOT_FOUND, `Domain not found`)
createHttpError(
HttpCode.NOT_FOUND,
`Domain with ID ${updateData.domainId} not found`
)
);
}
}
const domainId = updateData.domainId || resource.domainId!;
const subdomain = updateData.subdomain || resource.subdomain;
const [domain] = await db
.select()
.from(domains)
.where(eq(domains.domainId, domainId));
const isBaseDomain =
updateData.isBaseDomain !== undefined
? updateData.isBaseDomain
: resource.isBaseDomain;
let fullDomain: string | null = null;
if (isBaseDomain) {
fullDomain = domain.baseDomain;
} else if (subdomain && domain) {
fullDomain = `${subdomain}.${domain.baseDomain}`;
}
if (fullDomain) {
const [existingDomain] = await db
.select()
.from(resources)
.where(eq(resources.fullDomain, fullDomain));
if (
existingDomain &&
existingDomain.resourceId !== resource.resourceId
domainRes.orgDomains &&
domainRes.orgDomains.orgId !== resource.orgId
) {
return next(
createHttpError(
HttpCode.CONFLICT,
"Resource with that domain already exists"
HttpCode.FORBIDDEN,
`You do not have permission to use domain with ID ${updateData.domainId}`
)
);
}
}
const updatePayload = {
...updateData,
fullDomain
};
if (!domainRes.domains.verified) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`Domain with ID ${updateData.domainId} is not verified`
)
);
}
let fullDomain = "";
if (domainRes.domains.type == "ns") {
if (updateData.subdomain) {
fullDomain = `${updateData.subdomain}.${domainRes.domains.baseDomain}`;
} else {
fullDomain = domainRes.domains.baseDomain;
}
} else if (domainRes.domains.type == "cname") {
fullDomain = domainRes.domains.baseDomain;
} else if (domainRes.domains.type == "wildcard") {
if (updateData.subdomain) {
// the subdomain cant have a dot in it
const parsedSubdomain = subdomainSchema.safeParse(updateData.subdomain);
if (!parsedSubdomain.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedSubdomain.error).toString()
)
);
}
if (parsedSubdomain.data.includes(".")) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Subdomain cannot contain a dot when using wildcard domains"
)
);
}
fullDomain = `${updateData.subdomain}.${domainRes.domains.baseDomain}`;
} else {
fullDomain = domainRes.domains.baseDomain;
}
}
fullDomain = fullDomain.toLowerCase();
logger.debug(`Full domain: ${fullDomain}`);
if (fullDomain) {
const [existingDomain] = await db
.select()
.from(resources)
.where(eq(resources.fullDomain, fullDomain));
if (
existingDomain &&
existingDomain.resourceId !== resource.resourceId
) {
return next(
createHttpError(
HttpCode.CONFLICT,
"Resource with that domain already exists"
)
);
}
}
// update the full domain if it has changed
if (fullDomain && fullDomain !== resource.fullDomain) {
await db
.update(resources)
.set({ fullDomain })
.where(eq(resources.resourceId, resource.resourceId));
}
}
const updatedResource = await db
.update(resources)
.set({
name: updatePayload.name,
subdomain: updatePayload.subdomain,
ssl: updatePayload.ssl,
sso: updatePayload.sso,
blockAccess: updatePayload.blockAccess,
emailWhitelistEnabled: updatePayload.emailWhitelistEnabled,
isBaseDomain: updatePayload.isBaseDomain,
applyRules: updatePayload.applyRules,
domainId: updatePayload.domainId,
enabled: updatePayload.enabled,
stickySession: updatePayload.stickySession,
tlsServerName: updatePayload.tlsServerName,
setHostHeader: updatePayload.setHostHeader,
fullDomain: updatePayload.fullDomain
})
.set(updateData)
.where(eq(resources.resourceId, resource.resourceId))
.returning();

View file

@ -14,6 +14,9 @@ import { newts } from "@server/db";
import moment from "moment";
import { OpenAPITags, registry } from "@server/openApi";
import { hashPassword } from "@server/auth/password";
import { isValidIP } from "@server/lib/validators";
import { isIpInCidr } from "@server/lib/ip";
import config from "@server/lib/config";
const createSiteParamsSchema = z
.object({
@ -35,9 +38,18 @@ const createSiteSchema = z
subnet: z.string().optional(),
newtId: z.string().optional(),
secret: z.string().optional(),
address: z.string().optional(),
type: z.enum(["newt", "wireguard", "local"])
})
.strict();
.strict()
.refine((data) => {
if (data.type === "local") {
return !config.getRawConfig().flags?.disable_local_sites;
} else if (data.type === "wireguard") {
return !config.getRawConfig().flags?.disable_basic_wireguard_sites;
}
return true;
});
export type CreateSiteBody = z.infer<typeof createSiteSchema>;
@ -84,7 +96,8 @@ export async function createSite(
pubKey,
subnet,
newtId,
secret
secret,
address
} = parsedBody.data;
const parsedParams = createSiteParamsSchema.safeParse(req.params);
@ -116,6 +129,59 @@ export async function createSite(
);
}
let updatedAddress = null;
if (address) {
if (!isValidIP(address)) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Invalid subnet format. Please provide a valid CIDR notation."
)
);
}
if (!isIpInCidr(address, org.subnet)) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"IP is not in the CIDR range of the subnet."
)
);
}
updatedAddress = `${address}/${org.subnet.split("/")[1]}`; // we want the block size of the whole org
// make sure the subnet is unique
const addressExistsSites = await db
.select()
.from(sites)
.where(eq(sites.address, updatedAddress))
.limit(1);
if (addressExistsSites.length > 0) {
return next(
createHttpError(
HttpCode.CONFLICT,
`Subnet ${subnet} already exists`
)
);
}
const addressExistsClients = await db
.select()
.from(sites)
.where(eq(sites.subnet, updatedAddress))
.limit(1);
if (addressExistsClients.length > 0) {
return next(
createHttpError(
HttpCode.CONFLICT,
`Subnet ${subnet} already exists`
)
);
}
}
const niceId = await getUniqueSiteName(orgId);
await db.transaction(async (trx) => {
@ -139,6 +205,7 @@ export async function createSite(
exitNodeId,
name,
niceId,
// address: updatedAddress || null,
subnet,
type,
dockerSocketEnabled: type == "newt",
@ -154,6 +221,7 @@ export async function createSite(
orgId,
name,
niceId,
// address: updatedAddress || null,
type,
dockerSocketEnabled: type == "newt",
subnet: "0.0.0.0/0"

View file

@ -62,6 +62,8 @@ export async function deleteSite(
);
}
let deletedNewtId: string | null = null;
await db.transaction(async (trx) => {
if (site.pubKey) {
if (site.type == "wireguard") {
@ -73,11 +75,7 @@ export async function deleteSite(
.where(eq(newts.siteId, siteId))
.returning();
if (deletedNewt) {
const payload = {
type: `newt/terminate`,
data: {}
};
sendToClient(deletedNewt.newtId, payload);
deletedNewtId = deletedNewt.newtId;
// delete all of the sessions for the newt
await trx
@ -90,6 +88,18 @@ export async function deleteSite(
await trx.delete(sites).where(eq(sites.siteId, siteId));
});
// Send termination message outside of transaction to prevent blocking
if (deletedNewtId) {
const payload = {
type: `newt/terminate`,
data: {}
};
// Don't await this to prevent blocking the response
sendToClient(deletedNewtId, payload).catch(error => {
logger.error("Failed to send termination message to newt:", error);
});
}
return response(res, {
data: null,
success: true,

View file

@ -1,4 +1,4 @@
import { db } from "@server/db";
import { db, newts } from "@server/db";
import { orgs, roleSites, sites, userSites } from "@server/db";
import logger from "@server/logger";
import HttpCode from "@server/types/HttpCode";
@ -9,6 +9,42 @@ import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import NodeCache from "node-cache";
import semver from "semver";
const newtVersionCache = new NodeCache({ stdTTL: 3600 }); // 1 hours in seconds
async function getLatestNewtVersion(): Promise<string | null> {
try {
const cachedVersion = newtVersionCache.get<string>("latestNewtVersion");
if (cachedVersion) {
return cachedVersion;
}
const response = await fetch(
"https://api.github.com/repos/fosrl/newt/tags"
);
if (!response.ok) {
logger.warn("Failed to fetch latest Newt version from GitHub");
return null;
}
const tags = await response.json();
if (!Array.isArray(tags) || tags.length === 0) {
logger.warn("No tags found for Newt repository");
return null;
}
const latestVersion = tags[0].name;
newtVersionCache.set("latestNewtVersion", latestVersion);
return latestVersion;
} catch (error) {
logger.error("Error fetching latest Newt version:", error);
return null;
}
}
const listSitesParamsSchema = z
.object({
@ -43,10 +79,13 @@ function querySites(orgId: string, accessibleSiteIds: number[]) {
megabytesOut: sites.megabytesOut,
orgName: orgs.name,
type: sites.type,
online: sites.online
online: sites.online,
address: sites.address,
newtVersion: newts.version
})
.from(sites)
.leftJoin(orgs, eq(sites.orgId, orgs.orgId))
.leftJoin(newts, eq(newts.siteId, sites.siteId))
.where(
and(
inArray(sites.siteId, accessibleSiteIds),
@ -55,8 +94,12 @@ function querySites(orgId: string, accessibleSiteIds: number[]) {
);
}
type SiteWithUpdateAvailable = Awaited<ReturnType<typeof querySites>>[0] & {
newtUpdateAvailable?: boolean;
};
export type ListSitesResponse = {
sites: Awaited<ReturnType<typeof querySites>>;
sites: SiteWithUpdateAvailable[];
pagination: { total: number; limit: number; offset: number };
};
@ -147,9 +190,36 @@ export async function listSites(
const totalCountResult = await countQuery;
const totalCount = totalCountResult[0].count;
const latestNewtVersion = await getLatestNewtVersion();
const sitesWithUpdates: SiteWithUpdateAvailable[] = sitesList.map(
(site) => {
const siteWithUpdate: SiteWithUpdateAvailable = { ...site };
if (
site.type === "newt" &&
site.newtVersion &&
latestNewtVersion
) {
try {
siteWithUpdate.newtUpdateAvailable = semver.lt(
site.newtVersion,
latestNewtVersion
);
} catch (error) {
siteWithUpdate.newtUpdateAvailable = false;
}
} else {
siteWithUpdate.newtUpdateAvailable = false;
}
return siteWithUpdate;
}
);
return response<ListSitesResponse>(res, {
data: {
sites: sitesList,
sites: sitesWithUpdates,
pagination: {
total: totalCount,
limit,

Some files were not shown because too many files have changed in this diff Show more