wtf auth??

This commit is contained in:
41666 2025-03-26 21:08:15 -07:00
parent 607d7e121c
commit a3a1654030
14 changed files with 337 additions and 1 deletions

View file

@ -0,0 +1,130 @@
package authmiddleware
import (
"errors"
"fmt"
"time"
"slices"
"git.sapphic.engineer/roleypoly/v4/discord"
"git.sapphic.engineer/roleypoly/v4/types"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/session"
)
type AuthMiddleware struct {
Client discord.IDiscordClient
supportIDs []string
superuserIDs []string
}
type Session struct {
Permissions Permission
User *types.DiscordUser
AccessToken string
LastRefresh time.Time
}
var SessionKey uint8
func New(discordClient discord.IDiscordClient, supportIDs []string, superuserIDs []string) func(fiber.Ctx) error {
am := AuthMiddleware{
Client: discordClient,
supportIDs: supportIDs,
superuserIDs: superuserIDs,
}
return am.Handle
}
var DefaultSession *Session = &Session{Permissions: PermAnonymous}
func (am *AuthMiddleware) Handle(c fiber.Ctx) error {
sc := session.FromContext(c)
sess := am.Init(sc)
am.validateAccessToken(sess) // this will remove AccessToken if its no good
am.setPermissions(sess) // this captures this
am.Commit(sc, sess) // and save our first bits?
return c.Next()
}
func (am *AuthMiddleware) Init(sess *session.Middleware) *Session {
session, ok := sess.Get(SessionKey).(*Session)
if !ok {
am.Commit(sess, DefaultSession)
return DefaultSession
}
return session
}
func (am *AuthMiddleware) Commit(sc *session.Middleware, sess *Session) {
sc.Set(SessionKey, sess)
}
func SessionFrom(c fiber.Ctx) (s *Session) {
sess := session.FromContext(c)
s, _ = sess.Get(SessionKey).(*Session)
return s
}
func (am *AuthMiddleware) isSupport(userID string) bool {
return slices.Contains(am.supportIDs, userID)
}
func (am *AuthMiddleware) isSuperuser(userID string) bool {
return slices.Contains(am.superuserIDs, userID)
}
func (am *AuthMiddleware) setPermissions(sess *Session) {
sess.Permissions = PermAnonymous
if sess.AccessToken != "" {
sess.Permissions = PermUser
}
if am.isSupport(sess.User.ID) {
sess.Permissions = PermSupport
}
if am.isSuperuser(sess.User.ID) {
sess.Permissions = PermSuperuser
}
}
func (am *AuthMiddleware) validateAccessToken(sess *Session) {
if sess.AccessToken == "" {
return
}
if sess.LastRefresh.Add(time.Hour).Before(time.Now()) {
user, err := am.GetCurrentUser(sess.AccessToken)
if err != nil {
if errors.Is(err, discord.ErrUnauthorized) {
sess.AccessToken = ""
return
}
}
sess.User = user
}
}
func (am *AuthMiddleware) GetCurrentUser(accessToken string) (*types.DiscordUser, error) {
req := discord.NewRequest("GET", "/users/@me")
am.Client.ClientAuth(req, accessToken)
resp, err := am.Client.Do(req)
if err != nil {
return nil, fmt.Errorf("authmiddleware.GetCurrentUser: request failed: %w", err)
}
var user types.DiscordUser
err = discord.OutputResponse(resp, &user)
return &user, err
}

View file

@ -0,0 +1,80 @@
package authmiddleware_test
import (
"bytes"
"net/http"
"testing"
"git.sapphic.engineer/roleypoly/v4/authmiddleware"
"git.sapphic.engineer/roleypoly/v4/discord"
"git.sapphic.engineer/roleypoly/v4/discord/clientmock"
"github.com/goccy/go-json"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/session"
)
func TestAnonymous(t *testing.T) {
dc := clientmock.NewDiscordClientMock()
app := getApp(dc)
setSession(app, dc, authmiddleware.Session{
Permissions: authmiddleware.PermAnonymous,
})
}
func getApp(dc discord.IDiscordClient) *fiber.App {
app := fiber.New(fiber.Config{})
sessionMiddleware, sessionStore := session.NewWithStore()
sessionStore.RegisterType(authmiddleware.Session{})
app.Use(sessionMiddleware, authmiddleware.New(dc, []string{}, []string{}))
app.Get("/", authState)
app.Post("/updateSession", updateSession)
return app
}
func setSession(app *fiber.App, dc *clientmock.DiscordClientMock, sess authmiddleware.Session) error {
body := bytes.Buffer{}
json.NewEncoder(&body).Encode(sess)
// do mocks here
req, _ := http.NewRequest("POST", "/updateSession", &body)
_, err := app.Test(req)
return err
}
func authState(c fiber.Ctx) error {
s := authmiddleware.SessionFrom(c)
permList := []string{}
if s.Permissions >= authmiddleware.PermAnonymous {
permList = append(permList, "anonymous")
}
if s.Permissions >= authmiddleware.PermUser {
permList = append(permList, "user")
}
if s.Permissions >= authmiddleware.PermSupport {
permList = append(permList, "support")
}
if s.Permissions >= authmiddleware.PermSuperuser {
permList = append(permList, "superuser")
}
return c.JSON(permList)
}
func updateSession(c fiber.Ctx) error {
var newSession *authmiddleware.Session
c.Bind().JSON(&newSession)
sc := session.FromContext(c)
sc.Set(authmiddleware.SessionKey, newSession)
return c.SendString("ok")
}

10
authmiddleware/const.go Normal file
View file

@ -0,0 +1,10 @@
package authmiddleware
type Permission uint8
const (
PermAnonymous Permission = 1 << iota
PermUser
PermSupport
PermSuperuser
)

View file

@ -42,6 +42,10 @@ func (c *DiscordClientMock) BotAuth(req *http.Request) {
c.Called(req)
}
func (c *DiscordClientMock) ClientAuth(req *http.Request, accessToken string) {
c.Called(req, accessToken)
}
func (c *DiscordClientMock) MockResponse(method, path string, statusCode int, data any) {
body := bytes.Buffer{}
json.NewEncoder(&body).Encode(data)

View file

@ -1,6 +1,7 @@
package discord
import (
"errors"
"fmt"
"net/http"
"net/url"
@ -11,9 +12,14 @@ import (
const DiscordBaseUrl = "https://discord.com/api/v10"
var (
ErrUnauthorized = errors.New("discord: got a 401 from discord")
)
type IDiscordClient interface {
Do(req *http.Request) (*http.Response, error)
BotAuth(req *http.Request)
ClientAuth(req *http.Request, accessToken string)
GetClientID() string
}
@ -44,6 +50,10 @@ func (d *DiscordClient) BotAuth(req *http.Request) {
req.Header.Set("Authorization", fmt.Sprintf("Bot %s", d.BotToken))
}
func (d *DiscordClient) ClientAuth(req *http.Request, accessToken string) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", d.BotToken))
}
func NewRequest(method string, path string) *http.Request {
url, err := url.Parse(fmt.Sprintf("%s%s", DiscordBaseUrl, path))
if err != nil {
@ -62,6 +72,10 @@ func NewRequest(method string, path string) *http.Request {
}
func OutputResponse(resp *http.Response, dst any) error {
if resp.StatusCode == 401 {
return ErrUnauthorized
}
// TODO: more checks?
return json.NewDecoder(resp.Body).Decode(dst)
}

View file

@ -2,6 +2,7 @@ package interactions_test
import (
"fmt"
"net/url"
"testing"
"git.sapphic.engineer/roleypoly/v4/interactions"
@ -28,6 +29,10 @@ func TestCmdRoleypoly(t *testing.T) {
button := ir.Data.Components[0]
assert.Equal(t, fmt.Sprintf("%s/s/guild-id", i.PublicBaseURL), button.URL)
u, _ := url.Parse(i.PublicBaseURL)
hostname := u.Host
assert.Equal(t, fmt.Sprintf("Pick roles on %s", hostname), button.Label)
// test the command mentions
tests := map[string]string{
"See all the roles": "</2:pickable-roles>",

View file

@ -1,15 +1,22 @@
package roleypoly
import (
"log"
"net/http"
"strings"
"time"
"github.com/goccy/go-json"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/csrf"
"github.com/gofiber/fiber/v3/middleware/session"
"github.com/gofiber/fiber/v3/middleware/static"
"github.com/gofiber/template/html/v2"
"git.sapphic.engineer/roleypoly/v4/authmiddleware"
"git.sapphic.engineer/roleypoly/v4/discord"
"git.sapphic.engineer/roleypoly/v4/interactions"
staticfs "git.sapphic.engineer/roleypoly/v4/static"
"git.sapphic.engineer/roleypoly/v4/templates"
"git.sapphic.engineer/roleypoly/v4/testing"
)
@ -24,11 +31,38 @@ func CreateFiberApp() *fiber.App {
ViewsLayout: "layouts/main",
})
app.Use(session.New(session.Config{}))
sessionMiddleware, sessionStore := session.NewWithStore()
sessionStore.RegisterType(authmiddleware.Session{})
app.Use(sessionMiddleware)
app.Use(csrf.New(csrf.Config{
Session: sessionStore,
}))
app.Get("/static*", static.New("", static.Config{
FS: staticfs.FS,
Compress: true,
CacheDuration: 24 * 8 * time.Hour,
}))
app.Get("/favicon.ico", oneStatic)
app.Get("/manifest.json", oneStatic)
app.Get("/robots.txt", oneStatic)
app.Get("/humans.txt", oneStatic)
return app
}
func oneStatic(c fiber.Ctx) error {
path := strings.Replace(c.OriginalURL(), "/", "", 1)
f, err := staticfs.FS.Open(path)
if err != nil {
log.Println("oneStatic:", c.OriginalURL(), " failed: ", err)
return c.SendStatus(500)
}
return c.SendStream(f)
}
func SetupControllers(app *fiber.App, dc *discord.DiscordClient, publicKey string, publicBaseURL string) {
gs := discord.NewGuildService(dc)
@ -37,6 +71,8 @@ func SetupControllers(app *fiber.App, dc *discord.DiscordClient, publicKey strin
}).Routes(app.Group("/testing"))
interactions.NewInteractions(publicKey, publicBaseURL, gs).Routes(app.Group("/interactions"))
app.Use(authmiddleware.New(dc, []string{}, []string{}))
}
// func getStorageDirectory() string {

9
static/embed.go Normal file
View file

@ -0,0 +1,9 @@
package static
import "embed"
//go:embed *
var FS embed.FS
// hi :3
// dont look its rude :c

BIN
static/favicon.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 101 KiB

3
static/main.css Normal file
View file

@ -0,0 +1,3 @@
body {
font-family: "Atkinson Hyperlegible", sans-serif;
}

1
static/robots.txt Normal file
View file

@ -0,0 +1 @@
# Hi cutie

View file

@ -0,0 +1,6 @@
<h1>roleypoly!</h1>
<div>put a role here</div>
<div>put a role here</div>
<div>put a role here</div>
<div>put a role here</div>
<div>put a role here</div>

View file

@ -1,7 +1,14 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<title>{{ .HeadTitle }}</title>
<link rel="preconnect" href="https://fonts.bunny.net" />
<link
href="https://fonts.bunny.net/css?family=atkinson-hyperlegible:400,400i"
rel="stylesheet"
/>
<link rel="stylesheet" href="/static/main.css" />
</head>
<body>
{{embed}}

View file

@ -5,6 +5,7 @@ import (
"github.com/gofiber/fiber/v3"
"git.sapphic.engineer/roleypoly/v4/authmiddleware"
"git.sapphic.engineer/roleypoly/v4/discord"
"git.sapphic.engineer/roleypoly/v4/types"
)
@ -15,8 +16,14 @@ type TestingController struct {
func (t *TestingController) Routes(r fiber.Router) {
r.Get("/picker/:version?", t.Picker)
r.Get("/index", t.Index)
r.Get("/m/:server/:user", t.GetMember)
r.Get("/g/:server", t.GetGuild)
r.Get("/auth", t.AuthState, authmiddleware.New(t.Guilds.Client(), []string{}, []string{}))
}
func (t *TestingController) Index(c fiber.Ctx) error {
return c.Render("index", fiber.Map{})
}
func (t *TestingController) Picker(c fiber.Ctx) error {
@ -54,3 +61,27 @@ func (t *TestingController) GetGuild(c fiber.Ctx) error {
return c.JSON(g.DiscordGuild)
}
func (t *TestingController) AuthState(c fiber.Ctx) error {
s := authmiddleware.SessionFrom(c)
permList := []string{}
if s.Permissions >= authmiddleware.PermAnonymous {
permList = append(permList, "anonymous")
}
if s.Permissions >= authmiddleware.PermUser {
permList = append(permList, "user")
}
if s.Permissions >= authmiddleware.PermSupport {
permList = append(permList, "support")
}
if s.Permissions >= authmiddleware.PermSuperuser {
permList = append(permList, "superuser")
}
return c.JSON(permList)
}