package authmiddleware import ( "errors" "fmt" "time" "slices" "git.sapphic.engineer/roleypoly/v4/discord" "git.sapphic.engineer/roleypoly/v4/types" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/session" ) type AuthMiddleware struct { Client discord.IDiscordClient supportIDs []string superuserIDs []string } type Session struct { Permissions Permission User *types.DiscordUser AccessToken string LastRefresh time.Time } var SessionKey uint8 func New(discordClient discord.IDiscordClient, supportIDs []string, superuserIDs []string) func(fiber.Ctx) error { am := AuthMiddleware{ Client: discordClient, supportIDs: supportIDs, superuserIDs: superuserIDs, } return am.Handle } var DefaultSession *Session = &Session{Permissions: PermAnonymous} func (am *AuthMiddleware) Handle(c fiber.Ctx) error { sc := session.FromContext(c) sess := am.Init(sc) am.validateAccessToken(sess) // this will remove AccessToken if its no good am.setPermissions(sess) // this captures this am.Commit(sc, sess) // and save our first bits? return c.Next() } func (am *AuthMiddleware) Init(sess *session.Middleware) *Session { session, ok := sess.Get(SessionKey).(*Session) if !ok { am.Commit(sess, DefaultSession) return DefaultSession } return session } func (am *AuthMiddleware) Commit(sc *session.Middleware, sess *Session) { sc.Set(SessionKey, sess) } func SessionFrom(c fiber.Ctx) (s *Session) { sess := session.FromContext(c) s, _ = sess.Get(SessionKey).(*Session) return s } func (am *AuthMiddleware) isSupport(userID string) bool { return slices.Contains(am.supportIDs, userID) } func (am *AuthMiddleware) isSuperuser(userID string) bool { return slices.Contains(am.superuserIDs, userID) } func (am *AuthMiddleware) setPermissions(sess *Session) { sess.Permissions = PermAnonymous if sess.AccessToken != "" { sess.Permissions = PermUser } if am.isSupport(sess.User.ID) { sess.Permissions = PermSupport } if am.isSuperuser(sess.User.ID) { sess.Permissions = PermSuperuser } } func (am *AuthMiddleware) validateAccessToken(sess *Session) { if sess.AccessToken == "" { return } if sess.LastRefresh.Add(time.Hour).Before(time.Now()) { user, err := am.GetCurrentUser(sess.AccessToken) if err != nil { if errors.Is(err, discord.ErrUnauthorized) { sess.AccessToken = "" return } } sess.User = user } } func (am *AuthMiddleware) GetCurrentUser(accessToken string) (*types.DiscordUser, error) { req := discord.NewRequest("GET", "/users/@me") am.Client.ClientAuth(req, accessToken) resp, err := am.Client.Do(req) if err != nil { return nil, fmt.Errorf("authmiddleware.GetCurrentUser: request failed: %w", err) } var user types.DiscordUser err = discord.OutputResponse(resp, &user) return &user, err }