auth yay!

This commit is contained in:
41666 2025-03-26 22:28:20 -07:00
parent a3a1654030
commit c50bfc1a7f
9 changed files with 339 additions and 110 deletions

View file

@ -6,4 +6,7 @@ DISCORD_BOT_TOKEN=dd
LISTEN_ADDR=:8169
STORAGE_DIR=./.storage
PUBLIC_BASE_URL=http://localhost:8169
PUBLIC_BASE_URL=http://localhost:8169
SUPERUSER_IDS=62601275618889728
SUPPORT_IDS=

View file

@ -3,6 +3,7 @@ package authmiddleware
import (
"errors"
"fmt"
"log"
"time"
"slices"
@ -20,31 +21,12 @@ type AuthMiddleware struct {
superuserIDs []string
}
type Session struct {
Permissions Permission
User *types.DiscordUser
AccessToken string
LastRefresh time.Time
}
var SessionKey uint8
func New(discordClient discord.IDiscordClient, supportIDs []string, superuserIDs []string) func(fiber.Ctx) error {
am := AuthMiddleware{
Client: discordClient,
supportIDs: supportIDs,
superuserIDs: superuserIDs,
}
return am.Handle
}
var DefaultSession *Session = &Session{Permissions: PermAnonymous}
func (am *AuthMiddleware) Handle(c fiber.Ctx) error {
sc := session.FromContext(c)
sess := am.Init(sc)
// log.Println(sess)
am.validateAccessToken(sess) // this will remove AccessToken if its no good
am.setPermissions(sess) // this captures this
am.Commit(sc, sess) // and save our first bits?
@ -53,23 +35,30 @@ func (am *AuthMiddleware) Handle(c fiber.Ctx) error {
}
func (am *AuthMiddleware) Init(sess *session.Middleware) *Session {
session, ok := sess.Get(SessionKey).(*Session)
session, ok := sess.Get(SessionKey).([]byte)
if !ok {
am.Commit(sess, DefaultSession)
return DefaultSession
}
return session
s, err := SessionFromJSON(session)
if err != nil {
log.Panicln("session failed to parse", err)
return nil
}
return s
}
func (am *AuthMiddleware) Commit(sc *session.Middleware, sess *Session) {
sc.Set(SessionKey, sess)
sc.Set(SessionKey, sess.AsJSON())
}
func SessionFrom(c fiber.Ctx) (s *Session) {
sess := session.FromContext(c)
s, _ = sess.Get(SessionKey).(*Session)
sessionJSON := sess.Get(SessionKey).([]byte)
s, err := SessionFromJSON(sessionJSON)
if err != nil {
log.Panicln("session failed to parse", err)
}
return s
}
@ -86,6 +75,12 @@ func (am *AuthMiddleware) setPermissions(sess *Session) {
if sess.AccessToken != "" {
sess.Permissions = PermUser
} else {
return
}
if sess.User == nil {
return
}
if am.isSupport(sess.User.ID) {
@ -112,6 +107,7 @@ func (am *AuthMiddleware) validateAccessToken(sess *Session) {
}
sess.User = user
sess.LastRefresh = time.Now()
}
}

View file

@ -1,80 +0,0 @@
package authmiddleware_test
import (
"bytes"
"net/http"
"testing"
"git.sapphic.engineer/roleypoly/v4/authmiddleware"
"git.sapphic.engineer/roleypoly/v4/discord"
"git.sapphic.engineer/roleypoly/v4/discord/clientmock"
"github.com/goccy/go-json"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/session"
)
func TestAnonymous(t *testing.T) {
dc := clientmock.NewDiscordClientMock()
app := getApp(dc)
setSession(app, dc, authmiddleware.Session{
Permissions: authmiddleware.PermAnonymous,
})
}
func getApp(dc discord.IDiscordClient) *fiber.App {
app := fiber.New(fiber.Config{})
sessionMiddleware, sessionStore := session.NewWithStore()
sessionStore.RegisterType(authmiddleware.Session{})
app.Use(sessionMiddleware, authmiddleware.New(dc, []string{}, []string{}))
app.Get("/", authState)
app.Post("/updateSession", updateSession)
return app
}
func setSession(app *fiber.App, dc *clientmock.DiscordClientMock, sess authmiddleware.Session) error {
body := bytes.Buffer{}
json.NewEncoder(&body).Encode(sess)
// do mocks here
req, _ := http.NewRequest("POST", "/updateSession", &body)
_, err := app.Test(req)
return err
}
func authState(c fiber.Ctx) error {
s := authmiddleware.SessionFrom(c)
permList := []string{}
if s.Permissions >= authmiddleware.PermAnonymous {
permList = append(permList, "anonymous")
}
if s.Permissions >= authmiddleware.PermUser {
permList = append(permList, "user")
}
if s.Permissions >= authmiddleware.PermSupport {
permList = append(permList, "support")
}
if s.Permissions >= authmiddleware.PermSuperuser {
permList = append(permList, "superuser")
}
return c.JSON(permList)
}
func updateSession(c fiber.Ctx) error {
var newSession *authmiddleware.Session
c.Bind().JSON(&newSession)
sc := session.FromContext(c)
sc.Set(authmiddleware.SessionKey, newSession)
return c.SendString("ok")
}

20
authmiddleware/must.go Normal file
View file

@ -0,0 +1,20 @@
package authmiddleware
import (
"net/http"
"git.sapphic.engineer/roleypoly/v4/types"
"github.com/gofiber/fiber/v3"
)
func MustHavePermission(perm Permission) func(fiber.Ctx) error {
return func(c fiber.Ctx) error {
sess := SessionFrom(c)
if sess.Permissions >= perm {
return c.Next()
}
return types.NewAPIError(http.StatusForbidden, "no sorry").Send(c)
}
}

View file

@ -0,0 +1,92 @@
package authmiddleware_test
import (
"testing"
"git.sapphic.engineer/roleypoly/v4/authmiddleware"
"git.sapphic.engineer/roleypoly/v4/discord/clientmock"
"git.sapphic.engineer/roleypoly/v4/types"
"git.sapphic.engineer/roleypoly/v4/types/fixtures"
"github.com/stretchr/testify/assert"
)
func TestMustAnonymous(t *testing.T) {
dc := clientmock.NewDiscordClientMock()
app := getApp(dc)
cookie, err := setSession(app, dc, authmiddleware.Session{
Permissions: authmiddleware.PermAnonymous,
}, nil)
assert.Nil(t, err)
err = get(t, app, cookie, "/must/user")
assert.ErrorIs(t, err, errUnauthorized)
err = get(t, app, cookie, "/must/support")
assert.ErrorIs(t, err, errUnauthorized)
err = get(t, app, cookie, "/must/superuser")
assert.ErrorIs(t, err, errUnauthorized)
}
func TestMustUser(t *testing.T) {
dc := clientmock.NewDiscordClientMock()
app := getApp(dc)
user := fixtures.User
mockUser(dc, &user)
cookie, err := setSession(app, dc, authmiddleware.Session{
User: &user,
AccessToken: "access-token",
}, nil)
assert.Nil(t, err)
err = get(t, app, cookie, "/must/user")
assert.Nil(t, err)
err = get(t, app, cookie, "/must/support")
assert.ErrorIs(t, err, errUnauthorized)
err = get(t, app, cookie, "/must/superuser")
assert.ErrorIs(t, err, errUnauthorized)
}
func TestMustSupport(t *testing.T) {
dc := clientmock.NewDiscordClientMock()
app := getApp(dc)
user := &types.DiscordUser{
ID: "support-user",
}
mockUser(dc, user)
cookie, err := setSession(app, dc, authmiddleware.Session{
User: user,
AccessToken: "access-token",
}, nil)
assert.Nil(t, err)
err = get(t, app, cookie, "/must/user")
assert.Nil(t, err)
err = get(t, app, cookie, "/must/support")
assert.Nil(t, err)
err = get(t, app, cookie, "/must/superuser")
assert.ErrorIs(t, err, errUnauthorized)
}
func TestMustSupuruser(t *testing.T) {
dc := clientmock.NewDiscordClientMock()
app := getApp(dc)
user := &types.DiscordUser{
ID: "superuser-user",
}
mockUser(dc, user)
cookie, err := setSession(app, dc, authmiddleware.Session{
User: user,
AccessToken: "access-token",
}, nil)
assert.Nil(t, err)
err = get(t, app, cookie, "/must/user")
assert.Nil(t, err)
err = get(t, app, cookie, "/must/support")
assert.Nil(t, err)
err = get(t, app, cookie, "/must/superuser")
assert.Nil(t, err)
}

47
authmiddleware/session.go Normal file
View file

@ -0,0 +1,47 @@
package authmiddleware
import (
"log"
"time"
"git.sapphic.engineer/roleypoly/v4/discord"
"git.sapphic.engineer/roleypoly/v4/types"
"github.com/goccy/go-json"
"github.com/gofiber/fiber/v3"
)
type Session struct {
Permissions Permission
User *types.DiscordUser
AccessToken string
LastRefresh time.Time
}
func (s *Session) AsJSON() []byte {
out, err := json.Marshal(s)
if err != nil {
log.Panicln("failed to marshal session json: ", err)
}
return out
}
func SessionFromJSON(input []byte) (*Session, error) {
var s Session
err := json.Unmarshal(input, &s)
return &s, err
}
var SessionKey uint8
func New(discordClient discord.IDiscordClient, supportIDs []string, superuserIDs []string) func(fiber.Ctx) error {
am := AuthMiddleware{
Client: discordClient,
supportIDs: supportIDs,
superuserIDs: superuserIDs,
}
return am.Handle
}
var DefaultSession *Session = &Session{Permissions: PermAnonymous}

117
authmiddleware/util_test.go Normal file
View file

@ -0,0 +1,117 @@
package authmiddleware_test
import (
"bytes"
"errors"
"net/http"
"testing"
"git.sapphic.engineer/roleypoly/v4/authmiddleware"
"git.sapphic.engineer/roleypoly/v4/discord"
"git.sapphic.engineer/roleypoly/v4/discord/clientmock"
"git.sapphic.engineer/roleypoly/v4/types"
"github.com/goccy/go-json"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/session"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func mockUser(dc *clientmock.DiscordClientMock, user *types.DiscordUser) {
if user == nil {
dc.MockResponse("GET", "/users/@me", http.StatusUnauthorized, "unauthorized").Once()
}
dc.MockResponse("GET", "/users/@me", 200, user).Once()
}
func setSession(app *fiber.App, dc *clientmock.DiscordClientMock, sess authmiddleware.Session, inputCookie *http.Cookie) (*http.Cookie, error) {
dc.On("ClientAuth", mock.AnythingOfType("*http.Request"), "access-token")
body := bytes.Buffer{}
json.NewEncoder(&body).Encode(sess)
req, _ := http.NewRequest("POST", "/updateSession", &body)
if inputCookie != nil {
req.AddCookie(inputCookie)
}
resp, err := app.Test(req)
if err != nil {
return nil, err
}
cookie := resp.Cookies()[0]
return cookie, err
}
func authState(c fiber.Ctx) error {
s := authmiddleware.SessionFrom(c)
permList := []string{}
if s.Permissions >= authmiddleware.PermAnonymous {
permList = append(permList, "anonymous")
}
if s.Permissions >= authmiddleware.PermUser {
permList = append(permList, "user")
}
if s.Permissions >= authmiddleware.PermSupport {
permList = append(permList, "support")
}
if s.Permissions >= authmiddleware.PermSuperuser {
permList = append(permList, "superuser")
}
return c.JSON(permList)
}
func updateSession(c fiber.Ctx) error {
var newSession *authmiddleware.Session
c.Bind().JSON(&newSession)
sc := session.FromContext(c)
sc.Set(authmiddleware.SessionKey, newSession.AsJSON())
return c.SendString("ok")
}
var errUnauthorized = errors.New("unauthorized")
func get(t *testing.T, app *fiber.App, cookie *http.Cookie, path string) error {
req, err := http.NewRequest("GET", path, nil)
req.AddCookie(cookie)
assert.Nil(t, err)
resp, err := app.Test(req)
assert.Nil(t, err)
if resp.StatusCode == http.StatusForbidden {
return errUnauthorized
}
return nil
}
func ok(c fiber.Ctx) error {
return c.SendString("ok")
}
func getApp(dc discord.IDiscordClient) *fiber.App {
app := fiber.New(fiber.Config{})
sessionMiddleware, sessionStore := session.NewWithStore()
sessionStore.RegisterType(authmiddleware.Session{})
app.Use(sessionMiddleware, authmiddleware.New(dc, []string{"support-user"}, []string{"superuser-user"}))
app.Get("/", authState)
app.Post("/updateSession", updateSession)
app.Get("/must/user", ok, authmiddleware.MustHavePermission(authmiddleware.PermUser))
app.Get("/must/support", ok, authmiddleware.MustHavePermission(authmiddleware.PermSupport))
app.Get("/must/superuser", ok, authmiddleware.MustHavePermission(authmiddleware.PermSuperuser))
return app
}

View file

@ -0,0 +1,34 @@
package authmiddleware_test
import (
"testing"
"git.sapphic.engineer/roleypoly/v4/authmiddleware"
"git.sapphic.engineer/roleypoly/v4/discord/clientmock"
"git.sapphic.engineer/roleypoly/v4/types/fixtures"
"github.com/stretchr/testify/assert"
)
func TestExpiration(t *testing.T) {
dc := clientmock.NewDiscordClientMock()
app := getApp(dc)
mockUser(dc, &fixtures.User)
cookie, err := setSession(app, dc, authmiddleware.Session{
AccessToken: "access-token",
}, nil)
assert.Nil(t, err)
// Good request for now,,
err = get(t, app, cookie, "/must/user")
assert.Nil(t, err)
mockUser(dc, nil)
cookie, err = setSession(app, dc, authmiddleware.Session{
AccessToken: "access-token",
}, cookie)
assert.Nil(t, err)
err = get(t, app, cookie, "/must/user")
assert.ErrorIs(t, errUnauthorized, err)
}

View file

@ -46,7 +46,7 @@ func (c *DiscordClientMock) ClientAuth(req *http.Request, accessToken string) {
c.Called(req, accessToken)
}
func (c *DiscordClientMock) MockResponse(method, path string, statusCode int, data any) {
func (c *DiscordClientMock) MockResponse(method, path string, statusCode int, data any) *mock.Call {
body := bytes.Buffer{}
json.NewEncoder(&body).Encode(data)
@ -57,7 +57,7 @@ func (c *DiscordClientMock) MockResponse(method, path string, statusCode int, da
pathMatcher := regexp.MustCompile(strings.ReplaceAll(path, "*", "[a-z0-9_-]+"))
c.On("Do", mock.MatchedBy(func(req *http.Request) bool {
return c.On("Do", mock.MatchedBy(func(req *http.Request) bool {
return req.Method == method && pathMatcher.MatchString(req.URL.Path)
})).Return(r, nil)
}