roles
This commit is contained in:
parent
8c8cbfd7dd
commit
41d48bf60a
28 changed files with 434 additions and 7 deletions
126
auth/authmiddleware/authmiddleware.go
Normal file
126
auth/authmiddleware/authmiddleware.go
Normal file
|
@ -0,0 +1,126 @@
|
|||
package authmiddleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"slices"
|
||||
|
||||
"git.sapphic.engineer/roleypoly/v4/discord"
|
||||
"git.sapphic.engineer/roleypoly/v4/types"
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/gofiber/fiber/v3/middleware/session"
|
||||
)
|
||||
|
||||
type AuthMiddleware struct {
|
||||
Client discord.IDiscordClient
|
||||
|
||||
supportIDs []string
|
||||
superuserIDs []string
|
||||
}
|
||||
|
||||
func New(discordClient discord.IDiscordClient, supportIDs []string, superuserIDs []string) func(fiber.Ctx) error {
|
||||
am := AuthMiddleware{
|
||||
Client: discordClient,
|
||||
supportIDs: supportIDs,
|
||||
superuserIDs: superuserIDs,
|
||||
}
|
||||
|
||||
return am.Handle
|
||||
}
|
||||
|
||||
func (am *AuthMiddleware) Handle(c fiber.Ctx) error {
|
||||
sc := session.FromContext(c)
|
||||
sess := am.Init(sc)
|
||||
|
||||
// log.Println(sess)
|
||||
|
||||
am.validateAccessToken(sess) // this will remove AccessToken if its no good
|
||||
am.setPermissions(sess) // this captures this
|
||||
am.Commit(sc, sess) // and save our first bits?
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
func (am *AuthMiddleware) Init(sess *session.Middleware) *Session {
|
||||
session, ok := sess.Get(SessionKey).([]byte)
|
||||
if !ok {
|
||||
return DefaultSession
|
||||
}
|
||||
|
||||
s, err := SessionFromJSON(session)
|
||||
if err != nil {
|
||||
log.Panicln("session failed to parse", err)
|
||||
return nil
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (am *AuthMiddleware) Commit(sc *session.Middleware, sess *Session) {
|
||||
sc.Set(SessionKey, sess.AsJSON())
|
||||
}
|
||||
|
||||
func (am *AuthMiddleware) isSupport(userID string) bool {
|
||||
return slices.Contains(am.supportIDs, userID)
|
||||
}
|
||||
|
||||
func (am *AuthMiddleware) isSuperuser(userID string) bool {
|
||||
return slices.Contains(am.superuserIDs, userID)
|
||||
}
|
||||
|
||||
func (am *AuthMiddleware) setPermissions(sess *Session) {
|
||||
sess.Permissions = PermAnonymous
|
||||
|
||||
if sess.AccessToken != "" {
|
||||
sess.Permissions = PermUser
|
||||
} else {
|
||||
return
|
||||
}
|
||||
|
||||
if sess.User == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if am.isSupport(sess.User.ID) {
|
||||
sess.Permissions = PermSupport
|
||||
}
|
||||
|
||||
if am.isSuperuser(sess.User.ID) {
|
||||
sess.Permissions = PermSuperuser
|
||||
}
|
||||
}
|
||||
|
||||
func (am *AuthMiddleware) validateAccessToken(sess *Session) {
|
||||
if sess.AccessToken == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if sess.LastRefresh.Add(time.Hour).Before(time.Now()) {
|
||||
user, err := am.GetCurrentUser(sess.AccessToken)
|
||||
if err != nil {
|
||||
if errors.Is(err, discord.ErrUnauthorized) {
|
||||
sess.AccessToken = ""
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
sess.User = user
|
||||
sess.LastRefresh = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
func (am *AuthMiddleware) GetCurrentUser(accessToken string) (*types.DiscordUser, error) {
|
||||
req := discord.NewRequest("GET", "/users/@me")
|
||||
am.Client.ClientAuth(req, accessToken)
|
||||
|
||||
resp, err := am.Client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authmiddleware.GetCurrentUser: request failed: %w", err)
|
||||
}
|
||||
|
||||
var user types.DiscordUser
|
||||
err = discord.OutputResponse(resp, &user)
|
||||
return &user, err
|
||||
}
|
10
auth/authmiddleware/const.go
Normal file
10
auth/authmiddleware/const.go
Normal file
|
@ -0,0 +1,10 @@
|
|||
package authmiddleware
|
||||
|
||||
type Permission uint8
|
||||
|
||||
const (
|
||||
PermAnonymous Permission = 1 << iota
|
||||
PermUser
|
||||
PermSupport
|
||||
PermSuperuser
|
||||
)
|
20
auth/authmiddleware/must.go
Normal file
20
auth/authmiddleware/must.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
package authmiddleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.sapphic.engineer/roleypoly/v4/types"
|
||||
"github.com/gofiber/fiber/v3"
|
||||
)
|
||||
|
||||
func MustHavePermission(perm Permission) func(fiber.Ctx) error {
|
||||
return func(c fiber.Ctx) error {
|
||||
sess := SessionFrom(c)
|
||||
|
||||
if sess.Permissions >= perm {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
return types.NewAPIError(http.StatusForbidden, "no sorry").Send(c)
|
||||
}
|
||||
}
|
92
auth/authmiddleware/must_test.go
Normal file
92
auth/authmiddleware/must_test.go
Normal file
|
@ -0,0 +1,92 @@
|
|||
package authmiddleware_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.sapphic.engineer/roleypoly/v4/authmiddleware"
|
||||
"git.sapphic.engineer/roleypoly/v4/discord/clientmock"
|
||||
"git.sapphic.engineer/roleypoly/v4/types"
|
||||
"git.sapphic.engineer/roleypoly/v4/types/fixtures"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMustAnonymous(t *testing.T) {
|
||||
dc := clientmock.NewDiscordClientMock()
|
||||
app := getApp(dc)
|
||||
|
||||
cookie, err := setSession(app, dc, authmiddleware.Session{
|
||||
Permissions: authmiddleware.PermAnonymous,
|
||||
}, nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = get(t, app, cookie, "/must/user")
|
||||
assert.ErrorIs(t, err, errUnauthorized)
|
||||
err = get(t, app, cookie, "/must/support")
|
||||
assert.ErrorIs(t, err, errUnauthorized)
|
||||
err = get(t, app, cookie, "/must/superuser")
|
||||
assert.ErrorIs(t, err, errUnauthorized)
|
||||
}
|
||||
|
||||
func TestMustUser(t *testing.T) {
|
||||
dc := clientmock.NewDiscordClientMock()
|
||||
app := getApp(dc)
|
||||
|
||||
user := fixtures.User
|
||||
mockUser(dc, &user)
|
||||
cookie, err := setSession(app, dc, authmiddleware.Session{
|
||||
User: &user,
|
||||
AccessToken: "access-token",
|
||||
}, nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = get(t, app, cookie, "/must/user")
|
||||
assert.Nil(t, err)
|
||||
err = get(t, app, cookie, "/must/support")
|
||||
assert.ErrorIs(t, err, errUnauthorized)
|
||||
err = get(t, app, cookie, "/must/superuser")
|
||||
assert.ErrorIs(t, err, errUnauthorized)
|
||||
}
|
||||
|
||||
func TestMustSupport(t *testing.T) {
|
||||
dc := clientmock.NewDiscordClientMock()
|
||||
app := getApp(dc)
|
||||
|
||||
user := &types.DiscordUser{
|
||||
ID: "support-user",
|
||||
}
|
||||
mockUser(dc, user)
|
||||
cookie, err := setSession(app, dc, authmiddleware.Session{
|
||||
User: user,
|
||||
AccessToken: "access-token",
|
||||
}, nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = get(t, app, cookie, "/must/user")
|
||||
assert.Nil(t, err)
|
||||
err = get(t, app, cookie, "/must/support")
|
||||
assert.Nil(t, err)
|
||||
err = get(t, app, cookie, "/must/superuser")
|
||||
assert.ErrorIs(t, err, errUnauthorized)
|
||||
}
|
||||
|
||||
func TestMustSupuruser(t *testing.T) {
|
||||
dc := clientmock.NewDiscordClientMock()
|
||||
app := getApp(dc)
|
||||
|
||||
user := &types.DiscordUser{
|
||||
ID: "superuser-user",
|
||||
}
|
||||
mockUser(dc, user)
|
||||
cookie, err := setSession(app, dc, authmiddleware.Session{
|
||||
User: user,
|
||||
AccessToken: "access-token",
|
||||
}, nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = get(t, app, cookie, "/must/user")
|
||||
assert.Nil(t, err)
|
||||
err = get(t, app, cookie, "/must/support")
|
||||
assert.Nil(t, err)
|
||||
err = get(t, app, cookie, "/must/superuser")
|
||||
assert.Nil(t, err)
|
||||
}
|
48
auth/authmiddleware/session.go
Normal file
48
auth/authmiddleware/session.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
package authmiddleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"git.sapphic.engineer/roleypoly/v4/types"
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/gofiber/fiber/v3/middleware/session"
|
||||
)
|
||||
|
||||
var (
|
||||
SessionKey uint8
|
||||
DefaultSession *Session = &Session{Permissions: PermAnonymous}
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
Permissions Permission
|
||||
User *types.DiscordUser
|
||||
AccessToken string
|
||||
LastRefresh time.Time
|
||||
}
|
||||
|
||||
func (s *Session) AsJSON() []byte {
|
||||
out, err := json.Marshal(s)
|
||||
if err != nil {
|
||||
log.Panicln("failed to marshal session json: ", err)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func SessionFromJSON(input []byte) (*Session, error) {
|
||||
var s Session
|
||||
err := json.Unmarshal(input, &s)
|
||||
return &s, err
|
||||
}
|
||||
|
||||
func SessionFrom(c fiber.Ctx) (s *Session) {
|
||||
sess := session.FromContext(c)
|
||||
sessionJSON := sess.Get(SessionKey).([]byte)
|
||||
s, err := SessionFromJSON(sessionJSON)
|
||||
if err != nil {
|
||||
log.Panicln("session failed to parse", err)
|
||||
}
|
||||
return s
|
||||
}
|
118
auth/authmiddleware/util_test.go
Normal file
118
auth/authmiddleware/util_test.go
Normal file
|
@ -0,0 +1,118 @@
|
|||
package authmiddleware_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"git.sapphic.engineer/roleypoly/v4/authmiddleware"
|
||||
"git.sapphic.engineer/roleypoly/v4/discord"
|
||||
"git.sapphic.engineer/roleypoly/v4/discord/clientmock"
|
||||
"git.sapphic.engineer/roleypoly/v4/roleypoly"
|
||||
"git.sapphic.engineer/roleypoly/v4/types"
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/gofiber/fiber/v3/middleware/session"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func mockUser(dc *clientmock.DiscordClientMock, user *types.DiscordUser) {
|
||||
if user == nil {
|
||||
dc.MockResponse("GET", "/users/@me", http.StatusUnauthorized, "unauthorized").Once()
|
||||
}
|
||||
|
||||
dc.MockResponse("GET", "/users/@me", 200, user).Once()
|
||||
}
|
||||
|
||||
func setSession(app *fiber.App, dc *clientmock.DiscordClientMock, sess authmiddleware.Session, inputCookie *http.Cookie) (*http.Cookie, error) {
|
||||
dc.On("ClientAuth", mock.AnythingOfType("*http.Request"), "access-token")
|
||||
|
||||
body := bytes.Buffer{}
|
||||
json.NewEncoder(&body).Encode(sess)
|
||||
|
||||
req, _ := http.NewRequest("POST", "/updateSession", &body)
|
||||
|
||||
if inputCookie != nil {
|
||||
req.AddCookie(inputCookie)
|
||||
}
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cookie := resp.Cookies()[0]
|
||||
return cookie, err
|
||||
}
|
||||
|
||||
func authState(c fiber.Ctx) error {
|
||||
s := authmiddleware.SessionFrom(c)
|
||||
|
||||
permList := []string{}
|
||||
|
||||
if s.Permissions >= authmiddleware.PermAnonymous {
|
||||
permList = append(permList, "anonymous")
|
||||
}
|
||||
|
||||
if s.Permissions >= authmiddleware.PermUser {
|
||||
permList = append(permList, "user")
|
||||
}
|
||||
|
||||
if s.Permissions >= authmiddleware.PermSupport {
|
||||
permList = append(permList, "support")
|
||||
}
|
||||
|
||||
if s.Permissions >= authmiddleware.PermSuperuser {
|
||||
permList = append(permList, "superuser")
|
||||
}
|
||||
|
||||
return c.JSON(permList)
|
||||
}
|
||||
|
||||
func updateSession(c fiber.Ctx) error {
|
||||
var newSession *authmiddleware.Session
|
||||
c.Bind().JSON(&newSession)
|
||||
|
||||
sc := session.FromContext(c)
|
||||
sc.Set(authmiddleware.SessionKey, newSession.AsJSON())
|
||||
|
||||
return c.SendString("ok")
|
||||
}
|
||||
|
||||
var errUnauthorized = errors.New("unauthorized")
|
||||
|
||||
func get(t *testing.T, app *fiber.App, cookie *http.Cookie, path string) error {
|
||||
req, err := http.NewRequest("GET", path, nil)
|
||||
req.AddCookie(cookie)
|
||||
assert.Nil(t, err)
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.Nil(t, err)
|
||||
|
||||
if resp.StatusCode == http.StatusForbidden {
|
||||
return errUnauthorized
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ok(c fiber.Ctx) error {
|
||||
return c.SendString("ok")
|
||||
}
|
||||
|
||||
func getApp(dc discord.IDiscordClient) *fiber.App {
|
||||
app := roleypoly.CreateFiberApp()
|
||||
sessionMiddleware, sessionStore := session.NewWithStore()
|
||||
sessionStore.RegisterType(authmiddleware.Session{})
|
||||
|
||||
app.Use(sessionMiddleware, authmiddleware.New(dc, []string{"support-user"}, []string{"superuser-user"}))
|
||||
app.Get("/", authState)
|
||||
app.Post("/updateSession", updateSession)
|
||||
app.Get("/must/user", ok, authmiddleware.MustHavePermission(authmiddleware.PermUser))
|
||||
app.Get("/must/support", ok, authmiddleware.MustHavePermission(authmiddleware.PermSupport))
|
||||
app.Get("/must/superuser", ok, authmiddleware.MustHavePermission(authmiddleware.PermSuperuser))
|
||||
|
||||
return app
|
||||
}
|
34
auth/authmiddleware/validation_test.go
Normal file
34
auth/authmiddleware/validation_test.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package authmiddleware_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.sapphic.engineer/roleypoly/v4/authmiddleware"
|
||||
"git.sapphic.engineer/roleypoly/v4/discord/clientmock"
|
||||
"git.sapphic.engineer/roleypoly/v4/types/fixtures"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestExpiration(t *testing.T) {
|
||||
dc := clientmock.NewDiscordClientMock()
|
||||
app := getApp(dc)
|
||||
|
||||
mockUser(dc, &fixtures.User)
|
||||
cookie, err := setSession(app, dc, authmiddleware.Session{
|
||||
AccessToken: "access-token",
|
||||
}, nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// Good request for now,,
|
||||
err = get(t, app, cookie, "/must/user")
|
||||
assert.Nil(t, err)
|
||||
|
||||
mockUser(dc, nil)
|
||||
cookie, err = setSession(app, dc, authmiddleware.Session{
|
||||
AccessToken: "access-token",
|
||||
}, cookie)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = get(t, app, cookie, "/must/user")
|
||||
assert.ErrorIs(t, errUnauthorized, err)
|
||||
}
|
0
auth/discordauth.go
Normal file
0
auth/discordauth.go
Normal file
Loading…
Add table
Add a link
Reference in a new issue