diff --git a/packages/api/bindings.d.ts b/packages/api/bindings.d.ts index 342afdd..281cb27 100644 --- a/packages/api/bindings.d.ts +++ b/packages/api/bindings.d.ts @@ -6,6 +6,7 @@ declare global { const UI_PUBLIC_URI: string; const API_PUBLIC_URI: string; const ROOT_USERS: string; + const ALLOWED_CALLBACK_HOSTS: string; const KV_SESSIONS: KVNamespace; const KV_GUILDS: KVNamespace; diff --git a/packages/api/handlers/login-bounce.ts b/packages/api/handlers/login-bounce.ts index 5fcf4c9..3a51656 100644 --- a/packages/api/handlers/login-bounce.ts +++ b/packages/api/handlers/login-bounce.ts @@ -1,4 +1,5 @@ -import KSUID from 'ksuid'; +import { StateSession } from '@roleypoly/types'; +import { getQuery, isAllowedCallbackHost, setupStateSession } from '../utils/api-tools'; import { Bounce } from '../utils/bounce'; import { apiPublicURI, botClientID } from '../utils/config'; @@ -16,9 +17,17 @@ const buildURL = (params: URLParams) => )}&state=${params.state}`; export const LoginBounce = async (request: Request): Promise => { - const state = await KSUID.random(); + const stateSessionData: StateSession = {}; + + const { cbh: callbackHost } = getQuery(request); + if (callbackHost && isAllowedCallbackHost(callbackHost)) { + stateSessionData.callbackHost = callbackHost; + } + + const state = await setupStateSession(stateSessionData); + const redirectURI = `${apiPublicURI}/login-callback`; const clientID = botClientID; - return Bounce(buildURL({ state: state.string, redirectURI, clientID })); + return Bounce(buildURL({ state, redirectURI, clientID })); }; diff --git a/packages/api/handlers/login-callback.ts b/packages/api/handlers/login-callback.ts index c553422..71a3d34 100644 --- a/packages/api/handlers/login-callback.ts +++ b/packages/api/handlers/login-callback.ts @@ -1,9 +1,17 @@ -import { AuthTokenResponse, DiscordUser, GuildSlug, SessionData } from '@roleypoly/types'; +import { + AuthTokenResponse, + DiscordUser, + GuildSlug, + SessionData, + StateSession, +} from '@roleypoly/types'; import KSUID from 'ksuid'; import { AuthType, discordFetch, formData, + getStateSession, + isAllowedCallbackHost, parsePermissions, resolveFailures, userAgent, @@ -21,8 +29,9 @@ const AuthErrorResponse = (extra?: string) => export const LoginCallback = resolveFailures( AuthErrorResponse, async (request: Request): Promise => { - const query = new URL(request.url).searchParams; + let bounceBaseUrl = uiPublicURI; + const query = new URL(request.url).searchParams; const stateValue = query.get('state'); if (stateValue === null) { @@ -37,6 +46,14 @@ export const LoginCallback = resolveFailures( if (currentTime > stateExpiry) { return AuthErrorResponse('state expired'); } + + const stateSession = await getStateSession(state.string); + if ( + stateSession?.callbackHost && + isAllowedCallbackHost(stateSession.callbackHost) + ) { + bounceBaseUrl = stateSession.callbackHost; + } } catch (e) { return AuthErrorResponse('state invalid'); } @@ -90,7 +107,7 @@ export const LoginCallback = resolveFailures( await Sessions.put(sessionID.string, sessionData, 60 * 60 * 6); return Bounce( - uiPublicURI + '/machinery/new-session?session_id=' + sessionID.string + bounceBaseUrl + '/machinery/new-session?session_id=' + sessionID.string ); } ); diff --git a/packages/api/utils/api-tools.ts b/packages/api/utils/api-tools.ts index f251788..61bbf1a 100644 --- a/packages/api/utils/api-tools.ts +++ b/packages/api/utils/api-tools.ts @@ -3,8 +3,9 @@ import { permissions as Permissions, } from '@roleypoly/misc-utils/hasPermission'; import { SessionData, UserGuildPermissions } from '@roleypoly/types'; +import KSUID from 'ksuid'; import { Handler } from '../router'; -import { rootUsers, uiPublicURI } from './config'; +import { allowedCallbackHosts, apiPublicURI, rootUsers } from './config'; import { Sessions, WrappedKVNamespace } from './kv'; export const formData = (obj: Record): string => { @@ -17,7 +18,7 @@ export const addCORS = (init: ResponseInit = {}) => ({ ...init, headers: { ...(init.headers || {}), - 'access-control-allow-origin': uiPublicURI, + 'access-control-allow-origin': '*', 'access-control-allow-methods': '*', 'access-control-allow-headers': '*', }, @@ -159,6 +160,20 @@ export const withSession = ( return await wrappedHandler(session)(request); }; +export const setupStateSession = async (data: T): Promise => { + const stateID = (await KSUID.random()).string; + + await Sessions.put(`state_${stateID}`, { data }, 60 * 5); + + return stateID; +}; + +export const getStateSession = async (stateID: string): Promise => { + const stateSession = await Sessions.get<{ data: T }>(`state_${stateID}`); + + return stateSession?.data; +}; + export const isRoot = (userID: string): boolean => rootUsers.includes(userID); export const onlyRootUsers = (handler: Handler): Handler => @@ -176,3 +191,17 @@ export const onlyRootUsers = (handler: Handler): Handler => } ); }); + +export const getQuery = (request: Request): { [x: string]: string } => { + const output: { [x: string]: string } = {}; + + for (let [key, value] of new URL(request.url).searchParams.entries()) { + output[key] = value; + } + + return output; +}; + +export const isAllowedCallbackHost = (host: string): boolean => { + return host === apiPublicURI || allowedCallbackHosts.includes(host); +}; diff --git a/packages/api/utils/config.ts b/packages/api/utils/config.ts index c089a1e..1fdb774 100644 --- a/packages/api/utils/config.ts +++ b/packages/api/utils/config.ts @@ -11,3 +11,4 @@ export const botToken = env('BOT_TOKEN'); export const uiPublicURI = safeURI(env('UI_PUBLIC_URI')); export const apiPublicURI = safeURI(env('API_PUBLIC_URI')); export const rootUsers = list(env('ROOT_USERS')); +export const allowedCallbackHosts = list(env('ALLOWED_CALLBACK_HOSTS')); diff --git a/packages/api/worker.config.js b/packages/api/worker.config.js index fe82c13..f26bde6 100644 --- a/packages/api/worker.config.js +++ b/packages/api/worker.config.js @@ -10,6 +10,7 @@ module.exports = { 'UI_PUBLIC_URI', 'API_PUBLIC_URI', 'ROOT_USERS', + 'ALLOWED_CALLBACK_HOSTS', ]), kv: ['KV_SESSIONS', 'KV_GUILDS', 'KV_GUILD_DATA'], }; diff --git a/packages/types/Session.ts b/packages/types/Session.ts index e7ae51c..1749a3a 100644 --- a/packages/types/Session.ts +++ b/packages/types/Session.ts @@ -16,3 +16,7 @@ export type SessionData = { user: DiscordUser; guilds: GuildSlug[]; }; + +export type StateSession = { + callbackHost?: string; +}; diff --git a/packages/web/src/pages/auth/login.tsx b/packages/web/src/pages/auth/login.tsx index bf45a66..f0d2a6d 100644 --- a/packages/web/src/pages/auth/login.tsx +++ b/packages/web/src/pages/auth/login.tsx @@ -11,9 +11,10 @@ const Login = () => { React.useEffect(() => { const url = new URL(window.location.href); + const callbackHost = `${url.protocol}://${url.host}`; const redirectServerID = url.searchParams.get('r'); if (!redirectServerID) { - window.location.href = `${apiUrl}/login-bounce`; + window.location.href = `${apiUrl}/login-bounce?cbh=${callbackHost}`; return; }