v4/authmiddleware/authmiddleware.go
2025-03-26 21:08:15 -07:00

130 lines
2.8 KiB
Go

package authmiddleware
import (
"errors"
"fmt"
"time"
"slices"
"git.sapphic.engineer/roleypoly/v4/discord"
"git.sapphic.engineer/roleypoly/v4/types"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/session"
)
// AuthMiddleware loads the caller's auth Session on every request, keeps its
// Discord access token validated, and assigns a Permission level that
// downstream handlers can read via SessionFrom.
type AuthMiddleware struct {
	// Client is used to call Discord on behalf of a session's access token.
	Client discord.IDiscordClient

	// supportIDs and superuserIDs are Discord user IDs that are granted
	// PermSupport and PermSuperuser respectively.
	supportIDs, superuserIDs []string
}
// Session is the authentication state kept in the per-client fiber session
// under SessionKey.
type Session struct {
	// Permissions is the access level recomputed by the middleware on each
	// request.
	Permissions Permission
	// User is the Discord user last fetched with AccessToken. It may be nil
	// (DefaultSession has no user).
	User *types.DiscordUser
	// AccessToken is the Discord token used to authorize /users/@me calls.
	// It is cleared when Discord reports it as unauthorized.
	AccessToken string
	// LastRefresh is compared against the current time to decide when
	// AccessToken must be re-validated with Discord.
	LastRefresh time.Time
}
// SessionKey is the key under which the *Session is stored in the fiber
// session. Its zero value is the key that is actually used.
var SessionKey = uint8(0)
// New builds an AuthMiddleware from a Discord client and the lists of
// privileged user IDs, and returns its Handle method for use as a fiber
// handler.
func New(discordClient discord.IDiscordClient, supportIDs []string, superuserIDs []string) func(fiber.Ctx) error {
	return (&AuthMiddleware{
		Client:       discordClient,
		supportIDs:   supportIDs,
		superuserIDs: superuserIDs,
	}).Handle
}
// DefaultSession is the auth state used for a request that has no stored
// *Session: anonymous permissions, no user and no access token.
var DefaultSession = &Session{Permissions: PermAnonymous}
// Handle is the fiber handler. It loads the request's auth Session,
// re-validates its access token when due, recomputes its permissions, stores
// the result back into the fiber session, and then calls the next handler.
func (am *AuthMiddleware) Handle(c fiber.Ctx) error {
	store := session.FromContext(c)
	current := am.Init(store)

	// May clear current.AccessToken if Discord rejects it.
	am.validateAccessToken(current)
	// Permissions depend on the token/user state settled above.
	am.setPermissions(current)
	am.Commit(store, current)

	return c.Next()
}
// Init returns the *Session stored in sc under SessionKey.
//
// If nothing usable is stored (no value, a value of another type, or a nil
// *Session), it stores and returns a fresh copy of DefaultSession. A copy is
// used so that per-request mutations (permissions, login state) never write
// through to the shared package-level DefaultSession.
func (am *AuthMiddleware) Init(sc *session.Middleware) *Session {
	if existing, ok := sc.Get(SessionKey).(*Session); ok && existing != nil {
		return existing
	}

	fresh := *DefaultSession
	am.Commit(sc, &fresh)
	return &fresh
}
// Commit stores s in the fiber session under SessionKey, replacing any value
// previously stored there.
func (am *AuthMiddleware) Commit(store *session.Middleware, s *Session) {
	store.Set(SessionKey, s)
}
// SessionFrom returns the auth *Session stored for the request in c.
//
// It returns nil in two cases: no session middleware is installed for c
// (session.FromContext yields nil), or no *Session has been stored yet.
func SessionFrom(c fiber.Ctx) *Session {
	sc := session.FromContext(c)
	if sc == nil {
		// Without this guard, sc.Get would panic on a nil middleware.
		return nil
	}

	// A missing or differently-typed value yields nil rather than panicking.
	s, _ := sc.Get(SessionKey).(*Session)
	return s
}
// isSupport reports whether userID appears in the configured support IDs.
func (am *AuthMiddleware) isSupport(userID string) bool {
	return slices.Index(am.supportIDs, userID) >= 0
}
// isSuperuser reports whether userID appears in the configured superuser IDs.
func (am *AuthMiddleware) isSuperuser(userID string) bool {
	return slices.Index(am.superuserIDs, userID) >= 0
}
// setPermissions recomputes sess.Permissions from scratch:
//
//   - no access token, or no user: PermAnonymous
//   - user ID listed as superuser: PermSuperuser (wins over support)
//   - user ID listed as support: PermSupport
//   - otherwise: PermUser
//
// Elevated levels require a live access token. A session whose token was
// cleared (for example after Discord reported it unauthorized) therefore
// cannot keep support or superuser rights through a stale User value.
func (am *AuthMiddleware) setPermissions(sess *Session) {
	if sess.AccessToken == "" || sess.User == nil {
		// DefaultSession has a nil User; dereferencing it here would panic.
		sess.Permissions = PermAnonymous
		return
	}

	switch id := sess.User.ID; {
	case am.isSuperuser(id):
		sess.Permissions = PermSuperuser
	case am.isSupport(id):
		sess.Permissions = PermSupport
	default:
		sess.Permissions = PermUser
	}
}
// validateAccessToken re-checks sess.AccessToken against Discord when more
// than an hour has passed since sess.LastRefresh. Sessions without a token
// are left untouched.
//
//   - On success, sess.User is replaced with the freshly fetched user and
//     sess.LastRefresh is reset to now, so the next check is an hour away.
//   - If Discord reports the token as unauthorized, sess.AccessToken is
//     cleared.
//   - On any other (presumably transient) error, the session is left as-is
//     and validation is retried on the next request.
func (am *AuthMiddleware) validateAccessToken(sess *Session) {
	if sess.AccessToken == "" {
		return
	}
	if time.Since(sess.LastRefresh) <= time.Hour {
		return
	}

	user, err := am.GetCurrentUser(sess.AccessToken)
	if err != nil {
		if errors.Is(err, discord.ErrUnauthorized) {
			sess.AccessToken = ""
		}
		// Keep the previous User rather than overwriting it with nil.
		return
	}

	sess.User = user
	sess.LastRefresh = time.Now()
}
// GetCurrentUser fetches the Discord user that owns accessToken via
// GET /users/@me.
//
// On any error the returned user is nil. Errors from the request and from
// decoding the response are wrapped with %w, so callers can still match
// sentinels such as discord.ErrUnauthorized with errors.Is.
func (am *AuthMiddleware) GetCurrentUser(accessToken string) (*types.DiscordUser, error) {
	req := discord.NewRequest("GET", "/users/@me")
	am.Client.ClientAuth(req, accessToken)

	resp, err := am.Client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("authmiddleware.GetCurrentUser: request failed: %w", err)
	}

	// NOTE(review): confirm discord.OutputResponse closes resp.Body.
	var user types.DiscordUser
	if err := discord.OutputResponse(resp, &user); err != nil {
		return nil, fmt.Errorf("authmiddleware.GetCurrentUser: decoding response: %w", err)
	}
	return &user, nil
}