126 lines
2.6 KiB
Go
126 lines
2.6 KiB
Go
package authmiddleware
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"time"
|
|
|
|
"slices"
|
|
|
|
"git.sapphic.engineer/roleypoly/v4/discord"
|
|
"git.sapphic.engineer/roleypoly/v4/types"
|
|
"github.com/gofiber/fiber/v3"
|
|
"github.com/gofiber/fiber/v3/middleware/session"
|
|
)
|
|
|
|
type AuthMiddleware struct {
|
|
Client discord.IDiscordClient
|
|
|
|
supportIDs []string
|
|
superuserIDs []string
|
|
}
|
|
|
|
func New(discordClient discord.IDiscordClient, supportIDs []string, superuserIDs []string) func(fiber.Ctx) error {
|
|
am := AuthMiddleware{
|
|
Client: discordClient,
|
|
supportIDs: supportIDs,
|
|
superuserIDs: superuserIDs,
|
|
}
|
|
|
|
return am.Handle
|
|
}
|
|
|
|
func (am *AuthMiddleware) Handle(c fiber.Ctx) error {
|
|
sc := session.FromContext(c)
|
|
sess := am.Init(sc)
|
|
|
|
// log.Println(sess)
|
|
|
|
am.validateAccessToken(sess) // this will remove AccessToken if its no good
|
|
am.setPermissions(sess) // this captures this
|
|
am.Commit(sc, sess) // and save our first bits?
|
|
|
|
return c.Next()
|
|
}
|
|
|
|
func (am *AuthMiddleware) Init(sess *session.Middleware) *Session {
|
|
session, ok := sess.Get(SessionKey).([]byte)
|
|
if !ok {
|
|
return DefaultSession
|
|
}
|
|
|
|
s, err := SessionFromJSON(session)
|
|
if err != nil {
|
|
log.Panicln("session failed to parse", err)
|
|
return nil
|
|
}
|
|
return s
|
|
}
|
|
|
|
func (am *AuthMiddleware) Commit(sc *session.Middleware, sess *Session) {
|
|
sc.Set(SessionKey, sess.AsJSON())
|
|
}
|
|
|
|
func (am *AuthMiddleware) isSupport(userID string) bool {
|
|
return slices.Contains(am.supportIDs, userID)
|
|
}
|
|
|
|
func (am *AuthMiddleware) isSuperuser(userID string) bool {
|
|
return slices.Contains(am.superuserIDs, userID)
|
|
}
|
|
|
|
func (am *AuthMiddleware) setPermissions(sess *Session) {
|
|
sess.Permissions = PermAnonymous
|
|
|
|
if sess.AccessToken != "" {
|
|
sess.Permissions = PermUser
|
|
} else {
|
|
return
|
|
}
|
|
|
|
if sess.User == nil {
|
|
return
|
|
}
|
|
|
|
if am.isSupport(sess.User.ID) {
|
|
sess.Permissions = PermSupport
|
|
}
|
|
|
|
if am.isSuperuser(sess.User.ID) {
|
|
sess.Permissions = PermSuperuser
|
|
}
|
|
}
|
|
|
|
func (am *AuthMiddleware) validateAccessToken(sess *Session) {
|
|
if sess.AccessToken == "" {
|
|
return
|
|
}
|
|
|
|
if sess.LastRefresh.Add(time.Hour).Before(time.Now()) {
|
|
user, err := am.GetCurrentUser(sess.AccessToken)
|
|
if err != nil {
|
|
if errors.Is(err, discord.ErrUnauthorized) {
|
|
sess.AccessToken = ""
|
|
return
|
|
}
|
|
}
|
|
|
|
sess.User = user
|
|
sess.LastRefresh = time.Now()
|
|
}
|
|
}
|
|
|
|
func (am *AuthMiddleware) GetCurrentUser(accessToken string) (*types.DiscordUser, error) {
|
|
req := discord.NewRequest("GET", "/users/@me")
|
|
am.Client.ClientAuth(req, accessToken)
|
|
|
|
resp, err := am.Client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("authmiddleware.GetCurrentUser: request failed: %w", err)
|
|
}
|
|
|
|
var user types.DiscordUser
|
|
err = discord.OutputResponse(resp, &user)
|
|
return &user, err
|
|
}
|