From e41c0c9d7cb23937a10a30e6c9338c68971b5859 Mon Sep 17 00:00:00 2001 From: miloschwartz Date: Sat, 15 Feb 2025 22:41:39 -0500 Subject: [PATCH] sync config managed domains to db --- server/db/schema.ts | 14 +++-- server/lib/config.ts | 23 ++------ server/setup/copyInConfig.ts | 108 +++++++++++++++++++++++++++++------ 3 files changed, 106 insertions(+), 39 deletions(-) diff --git a/server/db/schema.ts b/server/db/schema.ts index a583e4ad..6cbe0fc5 100644 --- a/server/db/schema.ts +++ b/server/db/schema.ts @@ -2,21 +2,23 @@ import { InferSelectModel } from "drizzle-orm"; import { sqliteTable, text, integer } from "drizzle-orm/sqlite-core"; export const domains = sqliteTable("domains", { - domainId: integer("domainId").primaryKey({ autoIncrement: true }), - baseDomain: text("domain").notNull().unique() + domainId: text("domainId").primaryKey(), + baseDomain: text("baseDomain").notNull().unique(), + configManaged: integer("configManaged", { mode: "boolean" }) + .notNull() + .default(false) }); export const orgs = sqliteTable("orgs", { orgId: text("orgId").primaryKey(), - name: text("name").notNull(), - domain: text("domain").notNull() + name: text("name").notNull() }); export const orgDomains = sqliteTable("orgDomains", { orgId: text("orgId") .notNull() .references(() => orgs.orgId, { onDelete: "cascade" }), - domainId: integer("domainId") + domainId: text("domainId") .notNull() .references(() => domains.domainId, { onDelete: "cascade" }) }); @@ -57,7 +59,7 @@ export const resources = sqliteTable("resources", { name: text("name").notNull(), subdomain: text("subdomain"), fullDomain: text("fullDomain"), - domainId: integer("domainId").references(() => domains.domainId, { + domainId: text("domainId").references(() => domains.domainId, { onDelete: "set null" }), ssl: integer("ssl", { mode: "boolean" }).notNull().default(false), diff --git a/server/lib/config.ts b/server/lib/config.ts index bf1f17dc..04f00335 100644 --- a/server/lib/config.ts +++ b/server/lib/config.ts @@ -38,23 +38,12 @@ const configSchema = z.object({ save_logs: z.boolean(), log_failed_attempts: z.boolean().optional() }), - domains: z - .array( - z.object({ - base_domain: hostnameSchema.transform((url) => - url.toLowerCase() - ) - }) - ) - .refine( - (data) => { - const baseDomains = data.map((d) => d.base_domain); - return new Set(baseDomains).size === baseDomains.length; - }, - { - message: "Base domains must be unique" - } - ), + domains: z.record( + z.string(), + z.object({ + base_domain: hostnameSchema.transform((url) => url.toLowerCase()) + }) + ), server: z.object({ external_port: portSchema .optional() diff --git a/server/setup/copyInConfig.ts b/server/setup/copyInConfig.ts index 8f3af8d6..88d7bcdc 100644 --- a/server/setup/copyInConfig.ts +++ b/server/setup/copyInConfig.ts @@ -1,40 +1,116 @@ import { db } from "@server/db"; -import { exitNodes, orgs, resources } from "../db/schema"; +import { domains, exitNodes, orgDomains, orgs, resources } from "../db/schema"; import config from "@server/lib/config"; import { eq, ne } from "drizzle-orm"; import logger from "@server/logger"; export async function copyInConfig() { - const domain = config.getBaseDomain(); const endpoint = config.getRawConfig().gerbil.base_endpoint; const listenPort = config.getRawConfig().gerbil.start_port; - // update the domain on all of the orgs where the domain is not equal to the new domain - // TODO: eventually each org could have a unique domain that we do not want to overwrite, so this will be unnecessary - await db.update(orgs).set({ domain }).where(ne(orgs.domain, domain)); - - // TODO: eventually each exit node could have a different endpoint - await db.update(exitNodes).set({ endpoint }).where(ne(exitNodes.endpoint, endpoint)); - // TODO: eventually each exit node could have a different port - await db.update(exitNodes).set({ listenPort }).where(ne(exitNodes.listenPort, listenPort)); - - // update all resources fullDomain to use the new domain await db.transaction(async (trx) => { - const allResources = await trx.select().from(resources); + const rawDomains = config.getRawConfig().domains; + + const configDomains = Object.entries(rawDomains).map( + ([key, value]) => ({ + domainId: key, + baseDomain: value.base_domain.toLowerCase() + }) + ); + + const existingDomains = await trx + .select() + .from(domains) + .where(eq(domains.configManaged, true)); + const existingDomainKeys = new Set( + existingDomains.map((d) => d.domainId) + ); + + const configDomainKeys = new Set(configDomains.map((d) => d.domainId)); + for (const existingDomain of existingDomains) { + if (!configDomainKeys.has(existingDomain.domainId)) { + await trx + .delete(domains) + .where(eq(domains.domainId, existingDomain.domainId)) + .execute(); + } + } + + for (const { domainId, baseDomain } of configDomains) { + if (existingDomainKeys.has(domainId)) { + await trx + .update(domains) + .set({ baseDomain }) + .where(eq(domains.domainId, domainId)) + .execute(); + } else { + await trx + .insert(domains) + .values({ domainId, baseDomain, configManaged: true }) + .execute(); + } + } + + const allResources = await trx + .select() + .from(resources) + .leftJoin(domains, eq(domains.domainId, resources.domainId)); + + for (const { resources: resource, domains: domain } of allResources) { + if (!resource || !domain) { + continue; + } + + if (!domain.configManaged) { + continue; + } - for (const resource of allResources) { let fullDomain = ""; if (resource.isBaseDomain) { - fullDomain = domain; + fullDomain = domain.baseDomain; } else { fullDomain = `${resource.subdomain}.${domain}`; } + await trx .update(resources) .set({ fullDomain }) .where(eq(resources.resourceId, resource.resourceId)); } + + const allOrgs = await trx.select().from(orgs); + + const existingOrgDomains = await trx.select().from(orgDomains); + const existingOrgDomainSet = new Set( + existingOrgDomains.map((od) => `${od.orgId}-${od.domainId}`) + ); + + const newOrgDomains = []; + for (const org of allOrgs) { + for (const domain of configDomains) { + const key = `${org.orgId}-${domain.domainId}`; + if (!existingOrgDomainSet.has(key)) { + newOrgDomains.push({ + orgId: org.orgId, + domainId: domain.domainId + }); + } + } + } + + if (newOrgDomains.length > 0) { + await trx.insert(orgDomains).values(newOrgDomains).execute(); + } }); - logger.info(`Updated orgs with new domain (${domain})`); + // TODO: eventually each exit node could have a different endpoint + await db + .update(exitNodes) + .set({ endpoint }) + .where(ne(exitNodes.endpoint, endpoint)); + // TODO: eventually each exit node could have a different port + await db + .update(exitNodes) + .set({ listenPort }) + .where(ne(exitNodes.listenPort, listenPort)); }