diff --git a/server/db/schema.ts b/server/db/schema.ts index 36b384b7..e784c523 100644 --- a/server/db/schema.ts +++ b/server/db/schema.ts @@ -130,8 +130,10 @@ export const userOrgs = sqliteTable("userOrgs", { .notNull() .references(() => users.userId), orgId: text("orgId") - .notNull() - .references(() => orgs.orgId), + .references(() => orgs.orgId, { + onDelete: "cascade" + }) + .notNull(), roleId: integer("roleId") .notNull() .references(() => roles.roleId), diff --git a/server/middlewares/verifyUserIsOrgOwner.ts b/server/middlewares/verifyUserIsOrgOwner.ts index 1b89ba67..49ddafc6 100644 --- a/server/middlewares/verifyUserIsOrgOwner.ts +++ b/server/middlewares/verifyUserIsOrgOwner.ts @@ -27,7 +27,6 @@ export async function verifyUserIsOrgOwner( ) ); } - try { if (!req.userOrg) { const res = await db @@ -56,6 +55,8 @@ export async function verifyUserIsOrgOwner( ) ); } + + return next(); } catch (e) { return next( createHttpError( diff --git a/server/routers/org/deleteOrg.ts b/server/routers/org/deleteOrg.ts index c9fdbc39..3ebea645 100644 --- a/server/routers/org/deleteOrg.ts +++ b/server/routers/org/deleteOrg.ts @@ -24,6 +24,10 @@ const deleteOrgSchema = z }) .strict(); +export type DeleteOrgResponse = { + +} + export async function deleteOrg( req: Request, res: Response, @@ -41,7 +45,6 @@ export async function deleteOrg( } const { orgId } = parsedParams.data; - // Check if the user has permission to list sites const hasPermission = await checkUserActionPermission( ActionsEnum.deleteOrg, @@ -55,7 +58,6 @@ export async function deleteOrg( ) ); } - const [org] = await db .select() .from(orgs) @@ -70,7 +72,6 @@ export async function deleteOrg( ) ); } - // we need to handle deleting each site const orgSites = await db .select() @@ -97,20 +98,20 @@ export async function deleteOrg( sendToClient(deletedNewt.newtId, payload); // delete all of the sessions for the newt - db.delete(newtSessions) + await db.delete(newtSessions) .where( eq(newtSessions.newtId, deletedNewt.newtId) - ) - .run(); + ); } } } - db.delete(sites).where(eq(sites.siteId, site.siteId)).run(); + logger.info(`Deleting site ${site.siteId}`); + await db.delete(sites).where(eq(sites.siteId, site.siteId)) } } - await db.delete(orgs).where(eq(orgs.orgId, orgId)).returning(); + await db.delete(orgs).where(eq(orgs.orgId, orgId)); return response(res, { data: null, diff --git a/server/routers/site/getSite.ts b/server/routers/site/getSite.ts index 6c259e3d..5bc25e09 100644 --- a/server/routers/site/getSite.ts +++ b/server/routers/site/getSite.ts @@ -28,6 +28,7 @@ export type GetSiteResponse = { name: string; subdomain: string; subnet: string; + type: string; }; export async function getSite( @@ -81,7 +82,8 @@ export async function getSite( siteId: site[0].siteId, niceId: site[0].niceId, name: site[0].name, - subnet: site[0].subnet + subnet: site[0].subnet, + type: site[0].type }, success: true, error: false, diff --git a/src/app/[orgId]/settings/general/page.tsx b/src/app/[orgId]/settings/general/page.tsx index 213542a5..d134715f 100644 --- a/src/app/[orgId]/settings/general/page.tsx +++ b/src/app/[orgId]/settings/general/page.tsx @@ -30,6 +30,9 @@ import { CardHeader, CardTitle } from "@/components/ui/card"; +import { AxiosResponse } from "axios"; +import { DeleteOrgResponse, ListOrgsResponse } from "@server/routers/org"; +import { redirect, useRouter } from "next/navigation"; const GeneralFormSchema = z.object({ name: z.string() @@ -41,6 +44,7 @@ export default function GeneralPage() { const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false); const { orgUser } = userOrgUserContext(); + const router = useRouter(); const { org } = useOrgContext(); const { toast } = useToast(); const api = createApiClient(useEnvContext()); @@ -54,16 +58,54 @@ export default function GeneralPage() { }); async function deleteOrg() { - await api.delete(`/org/${org?.org.orgId}`).catch((e) => { + try { + const res = await api.delete>( + `/org/${org?.org.orgId}` + ); + if (res.status === 200) { + pickNewOrgAndNavigate(); + } + } catch (err) { + console.error(err); toast({ variant: "destructive", title: "Failed to delete org", description: formatAxiosError( - e, + err, "An error occurred while deleting the org." ) }); - }); + } + } + + async function pickNewOrgAndNavigate() { + try { + + const res = await api.get>( + `/orgs` + ); + + if (res.status === 200) { + if (res.data.data.orgs.length > 0) { + const orgId = res.data.data.orgs[0].orgId; + // go to `/${orgId}/settings`); + router.push(`/${orgId}/settings`); + } else { + // go to `/setup` + router.push("/setup"); + } + } + } catch (err) { + console.error(err); + toast({ + variant: "destructive", + title: "Failed to fetch orgs", + description: formatAxiosError( + err, + "An error occurred while listing your orgs" + ) + }); + } } async function onSubmit(data: GeneralFormValues) { diff --git a/src/app/[orgId]/settings/resources/[resourceId]/connectivity/page.tsx b/src/app/[orgId]/settings/resources/[resourceId]/connectivity/page.tsx index 441fdae9..31f49564 100644 --- a/src/app/[orgId]/settings/resources/[resourceId]/connectivity/page.tsx +++ b/src/app/[orgId]/settings/resources/[resourceId]/connectivity/page.tsx @@ -51,6 +51,7 @@ import { ArrayElement } from "@server/types/ArrayElement"; import { formatAxiosError } from "@app/lib/utils"; import { useEnvContext } from "@app/hooks/useEnvContext"; import { createApiClient } from "@app/api"; +import { GetSiteResponse } from "@server/routers/site"; const addTargetSchema = z.object({ ip: z.string().ip(), @@ -85,6 +86,7 @@ export default function ReverseProxyTargets(props: { const api = createApiClient(useEnvContext()); const [targets, setTargets] = useState([]); + const [site, setSite] = useState(); const [targetsToRemove, setTargetsToRemove] = useState([]); const [sslEnabled, setSslEnabled] = useState(resource.ssl); @@ -103,7 +105,7 @@ export default function ReverseProxyTargets(props: { }); useEffect(() => { - const fetchSites = async () => { + const fetchTargets = async () => { try { const res = await api.get>( `/resource/${params.resourceId}/targets`, @@ -126,7 +128,30 @@ export default function ReverseProxyTargets(props: { setPageLoading(false); } }; - fetchSites(); + fetchTargets(); + + const fetchSite = async () => { + try { + const res = await api.get>( + `/site/${resource.siteId}`, + ); + + if (res.status === 200) { + setSite(res.data.data); + } + } catch (err) { + console.error(err); + toast({ + variant: "destructive", + title: "Failed to fetch resource", + description: formatAxiosError( + err, + "An error occurred while fetching resource", + ), + }); + } + } + fetchSite(); }, []); async function addTarget(data: AddTargetFormValues) { @@ -146,6 +171,20 @@ export default function ReverseProxyTargets(props: { return; } + if (site && site.type == "wireguard" && site.subnet) { + // make sure that the target IP is within the site subnet + const targetIp = data.ip; + const subnet = site.subnet; + if (!isIPInSubnet(targetIp, subnet)) { + toast({ + variant: "destructive", + title: "Invalid target IP", + description: "Target IP must be within the site subnet", + }); + return; + } + } + const newTarget: LocalTarget = { ...data, enabled: true, @@ -602,3 +641,40 @@ export default function ReverseProxyTargets(props: { ); } + +function isIPInSubnet(subnet: string, ip: string): boolean { + // Split subnet into IP and mask parts + const [subnetIP, maskBits] = subnet.split('/'); + const mask = parseInt(maskBits); + + if (mask < 0 || mask > 32) { + throw new Error('Invalid subnet mask. Must be between 0 and 32.'); + } + + // Convert IP addresses to binary numbers + const subnetNum = ipToNumber(subnetIP); + const ipNum = ipToNumber(ip); + + // Calculate subnet mask + const maskNum = mask === 32 ? -1 : ~((1 << (32 - mask)) - 1); + + // Check if the IP is in the subnet + return (subnetNum & maskNum) === (ipNum & maskNum); +} + +function ipToNumber(ip: string): number { + // Validate IP address format + const parts = ip.split('.'); + if (parts.length !== 4) { + throw new Error('Invalid IP address format'); + } + + // Convert IP octets to 32-bit number + return parts.reduce((num, octet) => { + const oct = parseInt(octet); + if (isNaN(oct) || oct < 0 || oct > 255) { + throw new Error('Invalid IP address octet'); + } + return (num << 8) + oct; + }, 0); +}