diff --git a/authmiddleware/authmiddleware.go b/authmiddleware/authmiddleware.go new file mode 100644 index 0000000..0146d54 --- /dev/null +++ b/authmiddleware/authmiddleware.go @@ -0,0 +1,130 @@ +package authmiddleware + +import ( + "errors" + "fmt" + "time" + + "slices" + + "git.sapphic.engineer/roleypoly/v4/discord" + "git.sapphic.engineer/roleypoly/v4/types" + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/middleware/session" +) + +type AuthMiddleware struct { + Client discord.IDiscordClient + + supportIDs []string + superuserIDs []string +} + +type Session struct { + Permissions Permission + User *types.DiscordUser + AccessToken string + LastRefresh time.Time +} + +var SessionKey uint8 + +func New(discordClient discord.IDiscordClient, supportIDs []string, superuserIDs []string) func(fiber.Ctx) error { + am := AuthMiddleware{ + Client: discordClient, + supportIDs: supportIDs, + superuserIDs: superuserIDs, + } + + return am.Handle +} + +var DefaultSession *Session = &Session{Permissions: PermAnonymous} + +func (am *AuthMiddleware) Handle(c fiber.Ctx) error { + sc := session.FromContext(c) + sess := am.Init(sc) + + am.validateAccessToken(sess) // this will remove AccessToken if its no good + am.setPermissions(sess) // this captures this + am.Commit(sc, sess) // and save our first bits? + + return c.Next() +} + +func (am *AuthMiddleware) Init(sess *session.Middleware) *Session { + session, ok := sess.Get(SessionKey).(*Session) + if !ok { + am.Commit(sess, DefaultSession) + return DefaultSession + } + + return session +} + +func (am *AuthMiddleware) Commit(sc *session.Middleware, sess *Session) { + sc.Set(SessionKey, sess) +} + +func SessionFrom(c fiber.Ctx) (s *Session) { + sess := session.FromContext(c) + s, _ = sess.Get(SessionKey).(*Session) + + return s +} + +func (am *AuthMiddleware) isSupport(userID string) bool { + return slices.Contains(am.supportIDs, userID) +} + +func (am *AuthMiddleware) isSuperuser(userID string) bool { + return slices.Contains(am.superuserIDs, userID) +} + +func (am *AuthMiddleware) setPermissions(sess *Session) { + sess.Permissions = PermAnonymous + + if sess.AccessToken != "" { + sess.Permissions = PermUser + } + + if am.isSupport(sess.User.ID) { + sess.Permissions = PermSupport + } + + if am.isSuperuser(sess.User.ID) { + sess.Permissions = PermSuperuser + } +} + +func (am *AuthMiddleware) validateAccessToken(sess *Session) { + if sess.AccessToken == "" { + return + } + + if sess.LastRefresh.Add(time.Hour).Before(time.Now()) { + user, err := am.GetCurrentUser(sess.AccessToken) + if err != nil { + if errors.Is(err, discord.ErrUnauthorized) { + sess.AccessToken = "" + return + } + } + + sess.User = user + } +} + +func (am *AuthMiddleware) GetCurrentUser(accessToken string) (*types.DiscordUser, error) { + req := discord.NewRequest("GET", "/users/@me") + am.Client.ClientAuth(req, accessToken) + + resp, err := am.Client.Do(req) + if err != nil { + return nil, fmt.Errorf("authmiddleware.GetCurrentUser: request failed: %w", err) + } + + var user types.DiscordUser + err = discord.OutputResponse(resp, &user) + return &user, err +} diff --git a/authmiddleware/authmiddleware_test.go b/authmiddleware/authmiddleware_test.go new file mode 100644 index 0000000..292389e --- /dev/null +++ b/authmiddleware/authmiddleware_test.go @@ -0,0 +1,80 @@ +package authmiddleware_test + +import ( + "bytes" + "net/http" + "testing" + + "git.sapphic.engineer/roleypoly/v4/authmiddleware" + "git.sapphic.engineer/roleypoly/v4/discord" + "git.sapphic.engineer/roleypoly/v4/discord/clientmock" + "github.com/goccy/go-json" + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/middleware/session" +) + +func TestAnonymous(t *testing.T) { + dc := clientmock.NewDiscordClientMock() + app := getApp(dc) + + setSession(app, dc, authmiddleware.Session{ + Permissions: authmiddleware.PermAnonymous, + }) +} + +func getApp(dc discord.IDiscordClient) *fiber.App { + app := fiber.New(fiber.Config{}) + sessionMiddleware, sessionStore := session.NewWithStore() + sessionStore.RegisterType(authmiddleware.Session{}) + + app.Use(sessionMiddleware, authmiddleware.New(dc, []string{}, []string{})) + app.Get("/", authState) + app.Post("/updateSession", updateSession) + + return app +} + +func setSession(app *fiber.App, dc *clientmock.DiscordClientMock, sess authmiddleware.Session) error { + body := bytes.Buffer{} + json.NewEncoder(&body).Encode(sess) + + // do mocks here + + req, _ := http.NewRequest("POST", "/updateSession", &body) + _, err := app.Test(req) + return err +} + +func authState(c fiber.Ctx) error { + s := authmiddleware.SessionFrom(c) + + permList := []string{} + + if s.Permissions >= authmiddleware.PermAnonymous { + permList = append(permList, "anonymous") + } + + if s.Permissions >= authmiddleware.PermUser { + permList = append(permList, "user") + } + + if s.Permissions >= authmiddleware.PermSupport { + permList = append(permList, "support") + } + + if s.Permissions >= authmiddleware.PermSuperuser { + permList = append(permList, "superuser") + } + + return c.JSON(permList) +} + +func updateSession(c fiber.Ctx) error { + var newSession *authmiddleware.Session + c.Bind().JSON(&newSession) + + sc := session.FromContext(c) + sc.Set(authmiddleware.SessionKey, newSession) + + return c.SendString("ok") +} diff --git a/authmiddleware/const.go b/authmiddleware/const.go new file mode 100644 index 0000000..ea4b93c --- /dev/null +++ b/authmiddleware/const.go @@ -0,0 +1,10 @@ +package authmiddleware + +type Permission uint8 + +const ( + PermAnonymous Permission = 1 << iota + PermUser + PermSupport + PermSuperuser +) diff --git a/discord/clientmock/discord_client_mock.go b/discord/clientmock/discord_client_mock.go index c0a640d..6d6f523 100644 --- a/discord/clientmock/discord_client_mock.go +++ b/discord/clientmock/discord_client_mock.go @@ -42,6 +42,10 @@ func (c *DiscordClientMock) BotAuth(req *http.Request) { c.Called(req) } +func (c *DiscordClientMock) ClientAuth(req *http.Request, accessToken string) { + c.Called(req, accessToken) +} + func (c *DiscordClientMock) MockResponse(method, path string, statusCode int, data any) { body := bytes.Buffer{} json.NewEncoder(&body).Encode(data) diff --git a/discord/discord_client.go b/discord/discord_client.go index d3f832f..7eb7c8b 100644 --- a/discord/discord_client.go +++ b/discord/discord_client.go @@ -1,6 +1,7 @@ package discord import ( + "errors" "fmt" "net/http" "net/url" @@ -11,9 +12,14 @@ import ( const DiscordBaseUrl = "https://discord.com/api/v10" +var ( + ErrUnauthorized = errors.New("discord: got a 401 from discord") +) + type IDiscordClient interface { Do(req *http.Request) (*http.Response, error) BotAuth(req *http.Request) + ClientAuth(req *http.Request, accessToken string) GetClientID() string } @@ -44,6 +50,10 @@ func (d *DiscordClient) BotAuth(req *http.Request) { req.Header.Set("Authorization", fmt.Sprintf("Bot %s", d.BotToken)) } +func (d *DiscordClient) ClientAuth(req *http.Request, accessToken string) { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", d.BotToken)) +} + func NewRequest(method string, path string) *http.Request { url, err := url.Parse(fmt.Sprintf("%s%s", DiscordBaseUrl, path)) if err != nil { @@ -62,6 +72,10 @@ func NewRequest(method string, path string) *http.Request { } func OutputResponse(resp *http.Response, dst any) error { + if resp.StatusCode == 401 { + return ErrUnauthorized + } + // TODO: more checks? return json.NewDecoder(resp.Body).Decode(dst) } diff --git a/interactions/cmd_roleypoly_test.go b/interactions/cmd_roleypoly_test.go index 869a369..c568332 100644 --- a/interactions/cmd_roleypoly_test.go +++ b/interactions/cmd_roleypoly_test.go @@ -2,6 +2,7 @@ package interactions_test import ( "fmt" + "net/url" "testing" "git.sapphic.engineer/roleypoly/v4/interactions" @@ -28,6 +29,10 @@ func TestCmdRoleypoly(t *testing.T) { button := ir.Data.Components[0] assert.Equal(t, fmt.Sprintf("%s/s/guild-id", i.PublicBaseURL), button.URL) + u, _ := url.Parse(i.PublicBaseURL) + hostname := u.Host + assert.Equal(t, fmt.Sprintf("Pick roles on %s", hostname), button.Label) + // test the command mentions tests := map[string]string{ "See all the roles": "2:pickable-roles>", diff --git a/roleypoly/fiber.go b/roleypoly/fiber.go index b02463a..33a5706 100644 --- a/roleypoly/fiber.go +++ b/roleypoly/fiber.go @@ -1,15 +1,22 @@ package roleypoly import ( + "log" "net/http" + "strings" + "time" "github.com/goccy/go-json" "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/middleware/csrf" "github.com/gofiber/fiber/v3/middleware/session" + "github.com/gofiber/fiber/v3/middleware/static" "github.com/gofiber/template/html/v2" + "git.sapphic.engineer/roleypoly/v4/authmiddleware" "git.sapphic.engineer/roleypoly/v4/discord" "git.sapphic.engineer/roleypoly/v4/interactions" + staticfs "git.sapphic.engineer/roleypoly/v4/static" "git.sapphic.engineer/roleypoly/v4/templates" "git.sapphic.engineer/roleypoly/v4/testing" ) @@ -24,11 +31,38 @@ func CreateFiberApp() *fiber.App { ViewsLayout: "layouts/main", }) - app.Use(session.New(session.Config{})) + sessionMiddleware, sessionStore := session.NewWithStore() + sessionStore.RegisterType(authmiddleware.Session{}) + + app.Use(sessionMiddleware) + app.Use(csrf.New(csrf.Config{ + Session: sessionStore, + })) + + app.Get("/static*", static.New("", static.Config{ + FS: staticfs.FS, + Compress: true, + CacheDuration: 24 * 8 * time.Hour, + })) + app.Get("/favicon.ico", oneStatic) + app.Get("/manifest.json", oneStatic) + app.Get("/robots.txt", oneStatic) + app.Get("/humans.txt", oneStatic) return app } +func oneStatic(c fiber.Ctx) error { + path := strings.Replace(c.OriginalURL(), "/", "", 1) + f, err := staticfs.FS.Open(path) + if err != nil { + log.Println("oneStatic:", c.OriginalURL(), " failed: ", err) + return c.SendStatus(500) + } + + return c.SendStream(f) +} + func SetupControllers(app *fiber.App, dc *discord.DiscordClient, publicKey string, publicBaseURL string) { gs := discord.NewGuildService(dc) @@ -37,6 +71,8 @@ func SetupControllers(app *fiber.App, dc *discord.DiscordClient, publicKey strin }).Routes(app.Group("/testing")) interactions.NewInteractions(publicKey, publicBaseURL, gs).Routes(app.Group("/interactions")) + + app.Use(authmiddleware.New(dc, []string{}, []string{})) } // func getStorageDirectory() string { diff --git a/static/embed.go b/static/embed.go new file mode 100644 index 0000000..0085f72 --- /dev/null +++ b/static/embed.go @@ -0,0 +1,9 @@ +package static + +import "embed" + +//go:embed * +var FS embed.FS + +// hi :3 +// dont look its rude :c diff --git a/static/favicon.ico b/static/favicon.ico new file mode 100644 index 0000000..99cefdb Binary files /dev/null and b/static/favicon.ico differ diff --git a/static/main.css b/static/main.css new file mode 100644 index 0000000..c0c10b1 --- /dev/null +++ b/static/main.css @@ -0,0 +1,3 @@ +body { + font-family: "Atkinson Hyperlegible", sans-serif; +} diff --git a/static/robots.txt b/static/robots.txt new file mode 100644 index 0000000..2b10c71 --- /dev/null +++ b/static/robots.txt @@ -0,0 +1 @@ +# Hi cutie \ No newline at end of file diff --git a/templates/index.html b/templates/index.html index e69de29..774fc13 100644 --- a/templates/index.html +++ b/templates/index.html @@ -0,0 +1,6 @@ +