diff --git a/.env.example b/.env.example index 5d4cde3..dc498e7 100644 --- a/.env.example +++ b/.env.example @@ -6,4 +6,7 @@ DISCORD_BOT_TOKEN=dd LISTEN_ADDR=:8169 STORAGE_DIR=./.storage -PUBLIC_BASE_URL=http://localhost:8169 \ No newline at end of file +PUBLIC_BASE_URL=http://localhost:8169 + +SUPERUSER_IDS=62601275618889728 +SUPPORT_IDS= \ No newline at end of file diff --git a/authmiddleware/authmiddleware.go b/authmiddleware/authmiddleware.go index 0146d54..281bf02 100644 --- a/authmiddleware/authmiddleware.go +++ b/authmiddleware/authmiddleware.go @@ -3,6 +3,7 @@ package authmiddleware import ( "errors" "fmt" + "log" "time" "slices" @@ -20,31 +21,12 @@ type AuthMiddleware struct { superuserIDs []string } -type Session struct { - Permissions Permission - User *types.DiscordUser - AccessToken string - LastRefresh time.Time -} - -var SessionKey uint8 - -func New(discordClient discord.IDiscordClient, supportIDs []string, superuserIDs []string) func(fiber.Ctx) error { - am := AuthMiddleware{ - Client: discordClient, - supportIDs: supportIDs, - superuserIDs: superuserIDs, - } - - return am.Handle -} - -var DefaultSession *Session = &Session{Permissions: PermAnonymous} - func (am *AuthMiddleware) Handle(c fiber.Ctx) error { sc := session.FromContext(c) sess := am.Init(sc) + // log.Println(sess) + am.validateAccessToken(sess) // this will remove AccessToken if its no good am.setPermissions(sess) // this captures this am.Commit(sc, sess) // and save our first bits? @@ -53,23 +35,30 @@ func (am *AuthMiddleware) Handle(c fiber.Ctx) error { } func (am *AuthMiddleware) Init(sess *session.Middleware) *Session { - session, ok := sess.Get(SessionKey).(*Session) + session, ok := sess.Get(SessionKey).([]byte) if !ok { - am.Commit(sess, DefaultSession) return DefaultSession } - return session + s, err := SessionFromJSON(session) + if err != nil { + log.Panicln("session failed to parse", err) + return nil + } + return s } func (am *AuthMiddleware) Commit(sc *session.Middleware, sess *Session) { - sc.Set(SessionKey, sess) + sc.Set(SessionKey, sess.AsJSON()) } func SessionFrom(c fiber.Ctx) (s *Session) { sess := session.FromContext(c) - s, _ = sess.Get(SessionKey).(*Session) - + sessionJSON := sess.Get(SessionKey).([]byte) + s, err := SessionFromJSON(sessionJSON) + if err != nil { + log.Panicln("session failed to parse", err) + } return s } @@ -86,6 +75,12 @@ func (am *AuthMiddleware) setPermissions(sess *Session) { if sess.AccessToken != "" { sess.Permissions = PermUser + } else { + return + } + + if sess.User == nil { + return } if am.isSupport(sess.User.ID) { @@ -112,6 +107,7 @@ func (am *AuthMiddleware) validateAccessToken(sess *Session) { } sess.User = user + sess.LastRefresh = time.Now() } } diff --git a/authmiddleware/authmiddleware_test.go b/authmiddleware/authmiddleware_test.go deleted file mode 100644 index 292389e..0000000 --- a/authmiddleware/authmiddleware_test.go +++ /dev/null @@ -1,80 +0,0 @@ -package authmiddleware_test - -import ( - "bytes" - "net/http" - "testing" - - "git.sapphic.engineer/roleypoly/v4/authmiddleware" - "git.sapphic.engineer/roleypoly/v4/discord" - "git.sapphic.engineer/roleypoly/v4/discord/clientmock" - "github.com/goccy/go-json" - "github.com/gofiber/fiber/v3" - "github.com/gofiber/fiber/v3/middleware/session" -) - -func TestAnonymous(t *testing.T) { - dc := clientmock.NewDiscordClientMock() - app := getApp(dc) - - setSession(app, dc, authmiddleware.Session{ - Permissions: authmiddleware.PermAnonymous, - }) -} - -func getApp(dc discord.IDiscordClient) *fiber.App { - app := fiber.New(fiber.Config{}) - sessionMiddleware, sessionStore := session.NewWithStore() - sessionStore.RegisterType(authmiddleware.Session{}) - - app.Use(sessionMiddleware, authmiddleware.New(dc, []string{}, []string{})) - app.Get("/", authState) - app.Post("/updateSession", updateSession) - - return app -} - -func setSession(app *fiber.App, dc *clientmock.DiscordClientMock, sess authmiddleware.Session) error { - body := bytes.Buffer{} - json.NewEncoder(&body).Encode(sess) - - // do mocks here - - req, _ := http.NewRequest("POST", "/updateSession", &body) - _, err := app.Test(req) - return err -} - -func authState(c fiber.Ctx) error { - s := authmiddleware.SessionFrom(c) - - permList := []string{} - - if s.Permissions >= authmiddleware.PermAnonymous { - permList = append(permList, "anonymous") - } - - if s.Permissions >= authmiddleware.PermUser { - permList = append(permList, "user") - } - - if s.Permissions >= authmiddleware.PermSupport { - permList = append(permList, "support") - } - - if s.Permissions >= authmiddleware.PermSuperuser { - permList = append(permList, "superuser") - } - - return c.JSON(permList) -} - -func updateSession(c fiber.Ctx) error { - var newSession *authmiddleware.Session - c.Bind().JSON(&newSession) - - sc := session.FromContext(c) - sc.Set(authmiddleware.SessionKey, newSession) - - return c.SendString("ok") -} diff --git a/authmiddleware/must.go b/authmiddleware/must.go new file mode 100644 index 0000000..6929120 --- /dev/null +++ b/authmiddleware/must.go @@ -0,0 +1,20 @@ +package authmiddleware + +import ( + "net/http" + + "git.sapphic.engineer/roleypoly/v4/types" + "github.com/gofiber/fiber/v3" +) + +func MustHavePermission(perm Permission) func(fiber.Ctx) error { + return func(c fiber.Ctx) error { + sess := SessionFrom(c) + + if sess.Permissions >= perm { + return c.Next() + } + + return types.NewAPIError(http.StatusForbidden, "no sorry").Send(c) + } +} diff --git a/authmiddleware/must_test.go b/authmiddleware/must_test.go new file mode 100644 index 0000000..20122d3 --- /dev/null +++ b/authmiddleware/must_test.go @@ -0,0 +1,92 @@ +package authmiddleware_test + +import ( + "testing" + + "git.sapphic.engineer/roleypoly/v4/authmiddleware" + "git.sapphic.engineer/roleypoly/v4/discord/clientmock" + "git.sapphic.engineer/roleypoly/v4/types" + "git.sapphic.engineer/roleypoly/v4/types/fixtures" + "github.com/stretchr/testify/assert" +) + +func TestMustAnonymous(t *testing.T) { + dc := clientmock.NewDiscordClientMock() + app := getApp(dc) + + cookie, err := setSession(app, dc, authmiddleware.Session{ + Permissions: authmiddleware.PermAnonymous, + }, nil) + assert.Nil(t, err) + + err = get(t, app, cookie, "/must/user") + assert.ErrorIs(t, err, errUnauthorized) + err = get(t, app, cookie, "/must/support") + assert.ErrorIs(t, err, errUnauthorized) + err = get(t, app, cookie, "/must/superuser") + assert.ErrorIs(t, err, errUnauthorized) +} + +func TestMustUser(t *testing.T) { + dc := clientmock.NewDiscordClientMock() + app := getApp(dc) + + user := fixtures.User + mockUser(dc, &user) + cookie, err := setSession(app, dc, authmiddleware.Session{ + User: &user, + AccessToken: "access-token", + }, nil) + assert.Nil(t, err) + + err = get(t, app, cookie, "/must/user") + assert.Nil(t, err) + err = get(t, app, cookie, "/must/support") + assert.ErrorIs(t, err, errUnauthorized) + err = get(t, app, cookie, "/must/superuser") + assert.ErrorIs(t, err, errUnauthorized) +} + +func TestMustSupport(t *testing.T) { + dc := clientmock.NewDiscordClientMock() + app := getApp(dc) + + user := &types.DiscordUser{ + ID: "support-user", + } + mockUser(dc, user) + cookie, err := setSession(app, dc, authmiddleware.Session{ + User: user, + AccessToken: "access-token", + }, nil) + assert.Nil(t, err) + + err = get(t, app, cookie, "/must/user") + assert.Nil(t, err) + err = get(t, app, cookie, "/must/support") + assert.Nil(t, err) + err = get(t, app, cookie, "/must/superuser") + assert.ErrorIs(t, err, errUnauthorized) +} + +func TestMustSupuruser(t *testing.T) { + dc := clientmock.NewDiscordClientMock() + app := getApp(dc) + + user := &types.DiscordUser{ + ID: "superuser-user", + } + mockUser(dc, user) + cookie, err := setSession(app, dc, authmiddleware.Session{ + User: user, + AccessToken: "access-token", + }, nil) + assert.Nil(t, err) + + err = get(t, app, cookie, "/must/user") + assert.Nil(t, err) + err = get(t, app, cookie, "/must/support") + assert.Nil(t, err) + err = get(t, app, cookie, "/must/superuser") + assert.Nil(t, err) +} diff --git a/authmiddleware/session.go b/authmiddleware/session.go new file mode 100644 index 0000000..c2e9f3d --- /dev/null +++ b/authmiddleware/session.go @@ -0,0 +1,47 @@ +package authmiddleware + +import ( + "log" + "time" + + "git.sapphic.engineer/roleypoly/v4/discord" + "git.sapphic.engineer/roleypoly/v4/types" + "github.com/goccy/go-json" + "github.com/gofiber/fiber/v3" +) + +type Session struct { + Permissions Permission + User *types.DiscordUser + AccessToken string + LastRefresh time.Time +} + +func (s *Session) AsJSON() []byte { + out, err := json.Marshal(s) + if err != nil { + log.Panicln("failed to marshal session json: ", err) + } + + return out +} + +func SessionFromJSON(input []byte) (*Session, error) { + var s Session + err := json.Unmarshal(input, &s) + return &s, err +} + +var SessionKey uint8 + +func New(discordClient discord.IDiscordClient, supportIDs []string, superuserIDs []string) func(fiber.Ctx) error { + am := AuthMiddleware{ + Client: discordClient, + supportIDs: supportIDs, + superuserIDs: superuserIDs, + } + + return am.Handle +} + +var DefaultSession *Session = &Session{Permissions: PermAnonymous} diff --git a/authmiddleware/util_test.go b/authmiddleware/util_test.go new file mode 100644 index 0000000..681407b --- /dev/null +++ b/authmiddleware/util_test.go @@ -0,0 +1,117 @@ +package authmiddleware_test + +import ( + "bytes" + "errors" + "net/http" + "testing" + + "git.sapphic.engineer/roleypoly/v4/authmiddleware" + "git.sapphic.engineer/roleypoly/v4/discord" + "git.sapphic.engineer/roleypoly/v4/discord/clientmock" + "git.sapphic.engineer/roleypoly/v4/types" + "github.com/goccy/go-json" + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/middleware/session" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func mockUser(dc *clientmock.DiscordClientMock, user *types.DiscordUser) { + if user == nil { + dc.MockResponse("GET", "/users/@me", http.StatusUnauthorized, "unauthorized").Once() + } + + dc.MockResponse("GET", "/users/@me", 200, user).Once() +} + +func setSession(app *fiber.App, dc *clientmock.DiscordClientMock, sess authmiddleware.Session, inputCookie *http.Cookie) (*http.Cookie, error) { + dc.On("ClientAuth", mock.AnythingOfType("*http.Request"), "access-token") + + body := bytes.Buffer{} + json.NewEncoder(&body).Encode(sess) + + req, _ := http.NewRequest("POST", "/updateSession", &body) + + if inputCookie != nil { + req.AddCookie(inputCookie) + } + + resp, err := app.Test(req) + if err != nil { + return nil, err + } + + cookie := resp.Cookies()[0] + return cookie, err +} + +func authState(c fiber.Ctx) error { + s := authmiddleware.SessionFrom(c) + + permList := []string{} + + if s.Permissions >= authmiddleware.PermAnonymous { + permList = append(permList, "anonymous") + } + + if s.Permissions >= authmiddleware.PermUser { + permList = append(permList, "user") + } + + if s.Permissions >= authmiddleware.PermSupport { + permList = append(permList, "support") + } + + if s.Permissions >= authmiddleware.PermSuperuser { + permList = append(permList, "superuser") + } + + return c.JSON(permList) +} + +func updateSession(c fiber.Ctx) error { + var newSession *authmiddleware.Session + c.Bind().JSON(&newSession) + + sc := session.FromContext(c) + sc.Set(authmiddleware.SessionKey, newSession.AsJSON()) + + return c.SendString("ok") +} + +var errUnauthorized = errors.New("unauthorized") + +func get(t *testing.T, app *fiber.App, cookie *http.Cookie, path string) error { + req, err := http.NewRequest("GET", path, nil) + req.AddCookie(cookie) + assert.Nil(t, err) + + resp, err := app.Test(req) + assert.Nil(t, err) + + if resp.StatusCode == http.StatusForbidden { + return errUnauthorized + } + + return nil +} + +func ok(c fiber.Ctx) error { + return c.SendString("ok") +} + +func getApp(dc discord.IDiscordClient) *fiber.App { + app := fiber.New(fiber.Config{}) + sessionMiddleware, sessionStore := session.NewWithStore() + sessionStore.RegisterType(authmiddleware.Session{}) + + app.Use(sessionMiddleware, authmiddleware.New(dc, []string{"support-user"}, []string{"superuser-user"})) + app.Get("/", authState) + app.Post("/updateSession", updateSession) + app.Get("/must/user", ok, authmiddleware.MustHavePermission(authmiddleware.PermUser)) + app.Get("/must/support", ok, authmiddleware.MustHavePermission(authmiddleware.PermSupport)) + app.Get("/must/superuser", ok, authmiddleware.MustHavePermission(authmiddleware.PermSuperuser)) + + return app +} diff --git a/authmiddleware/validation_test.go b/authmiddleware/validation_test.go new file mode 100644 index 0000000..20aaa7f --- /dev/null +++ b/authmiddleware/validation_test.go @@ -0,0 +1,34 @@ +package authmiddleware_test + +import ( + "testing" + + "git.sapphic.engineer/roleypoly/v4/authmiddleware" + "git.sapphic.engineer/roleypoly/v4/discord/clientmock" + "git.sapphic.engineer/roleypoly/v4/types/fixtures" + "github.com/stretchr/testify/assert" +) + +func TestExpiration(t *testing.T) { + dc := clientmock.NewDiscordClientMock() + app := getApp(dc) + + mockUser(dc, &fixtures.User) + cookie, err := setSession(app, dc, authmiddleware.Session{ + AccessToken: "access-token", + }, nil) + assert.Nil(t, err) + + // Good request for now,, + err = get(t, app, cookie, "/must/user") + assert.Nil(t, err) + + mockUser(dc, nil) + cookie, err = setSession(app, dc, authmiddleware.Session{ + AccessToken: "access-token", + }, cookie) + assert.Nil(t, err) + + err = get(t, app, cookie, "/must/user") + assert.ErrorIs(t, errUnauthorized, err) +} diff --git a/discord/clientmock/discord_client_mock.go b/discord/clientmock/discord_client_mock.go index 6d6f523..d86eb7c 100644 --- a/discord/clientmock/discord_client_mock.go +++ b/discord/clientmock/discord_client_mock.go @@ -46,7 +46,7 @@ func (c *DiscordClientMock) ClientAuth(req *http.Request, accessToken string) { c.Called(req, accessToken) } -func (c *DiscordClientMock) MockResponse(method, path string, statusCode int, data any) { +func (c *DiscordClientMock) MockResponse(method, path string, statusCode int, data any) *mock.Call { body := bytes.Buffer{} json.NewEncoder(&body).Encode(data) @@ -57,7 +57,7 @@ func (c *DiscordClientMock) MockResponse(method, path string, statusCode int, da pathMatcher := regexp.MustCompile(strings.ReplaceAll(path, "*", "[a-z0-9_-]+")) - c.On("Do", mock.MatchedBy(func(req *http.Request) bool { + return c.On("Do", mock.MatchedBy(func(req *http.Request) bool { return req.Method == method && pathMatcher.MatchString(req.URL.Path) })).Return(r, nil) }