wtf auth??
This commit is contained in:
parent
607d7e121c
commit
a3a1654030
14 changed files with 337 additions and 1 deletions
130
authmiddleware/authmiddleware.go
Normal file
130
authmiddleware/authmiddleware.go
Normal file
|
@ -0,0 +1,130 @@
|
|||
package authmiddleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"slices"
|
||||
|
||||
"git.sapphic.engineer/roleypoly/v4/discord"
|
||||
"git.sapphic.engineer/roleypoly/v4/types"
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/gofiber/fiber/v3/middleware/session"
|
||||
)
|
||||
|
||||
type AuthMiddleware struct {
|
||||
Client discord.IDiscordClient
|
||||
|
||||
supportIDs []string
|
||||
superuserIDs []string
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
Permissions Permission
|
||||
User *types.DiscordUser
|
||||
AccessToken string
|
||||
LastRefresh time.Time
|
||||
}
|
||||
|
||||
var SessionKey uint8
|
||||
|
||||
func New(discordClient discord.IDiscordClient, supportIDs []string, superuserIDs []string) func(fiber.Ctx) error {
|
||||
am := AuthMiddleware{
|
||||
Client: discordClient,
|
||||
supportIDs: supportIDs,
|
||||
superuserIDs: superuserIDs,
|
||||
}
|
||||
|
||||
return am.Handle
|
||||
}
|
||||
|
||||
var DefaultSession *Session = &Session{Permissions: PermAnonymous}
|
||||
|
||||
func (am *AuthMiddleware) Handle(c fiber.Ctx) error {
|
||||
sc := session.FromContext(c)
|
||||
sess := am.Init(sc)
|
||||
|
||||
am.validateAccessToken(sess) // this will remove AccessToken if its no good
|
||||
am.setPermissions(sess) // this captures this
|
||||
am.Commit(sc, sess) // and save our first bits?
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
func (am *AuthMiddleware) Init(sess *session.Middleware) *Session {
|
||||
session, ok := sess.Get(SessionKey).(*Session)
|
||||
if !ok {
|
||||
am.Commit(sess, DefaultSession)
|
||||
return DefaultSession
|
||||
}
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
func (am *AuthMiddleware) Commit(sc *session.Middleware, sess *Session) {
|
||||
sc.Set(SessionKey, sess)
|
||||
}
|
||||
|
||||
func SessionFrom(c fiber.Ctx) (s *Session) {
|
||||
sess := session.FromContext(c)
|
||||
s, _ = sess.Get(SessionKey).(*Session)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (am *AuthMiddleware) isSupport(userID string) bool {
|
||||
return slices.Contains(am.supportIDs, userID)
|
||||
}
|
||||
|
||||
func (am *AuthMiddleware) isSuperuser(userID string) bool {
|
||||
return slices.Contains(am.superuserIDs, userID)
|
||||
}
|
||||
|
||||
func (am *AuthMiddleware) setPermissions(sess *Session) {
|
||||
sess.Permissions = PermAnonymous
|
||||
|
||||
if sess.AccessToken != "" {
|
||||
sess.Permissions = PermUser
|
||||
}
|
||||
|
||||
if am.isSupport(sess.User.ID) {
|
||||
sess.Permissions = PermSupport
|
||||
}
|
||||
|
||||
if am.isSuperuser(sess.User.ID) {
|
||||
sess.Permissions = PermSuperuser
|
||||
}
|
||||
}
|
||||
|
||||
func (am *AuthMiddleware) validateAccessToken(sess *Session) {
|
||||
if sess.AccessToken == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if sess.LastRefresh.Add(time.Hour).Before(time.Now()) {
|
||||
user, err := am.GetCurrentUser(sess.AccessToken)
|
||||
if err != nil {
|
||||
if errors.Is(err, discord.ErrUnauthorized) {
|
||||
sess.AccessToken = ""
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
sess.User = user
|
||||
}
|
||||
}
|
||||
|
||||
func (am *AuthMiddleware) GetCurrentUser(accessToken string) (*types.DiscordUser, error) {
|
||||
req := discord.NewRequest("GET", "/users/@me")
|
||||
am.Client.ClientAuth(req, accessToken)
|
||||
|
||||
resp, err := am.Client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authmiddleware.GetCurrentUser: request failed: %w", err)
|
||||
}
|
||||
|
||||
var user types.DiscordUser
|
||||
err = discord.OutputResponse(resp, &user)
|
||||
return &user, err
|
||||
}
|
80
authmiddleware/authmiddleware_test.go
Normal file
80
authmiddleware/authmiddleware_test.go
Normal file
|
@ -0,0 +1,80 @@
|
|||
package authmiddleware_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"git.sapphic.engineer/roleypoly/v4/authmiddleware"
|
||||
"git.sapphic.engineer/roleypoly/v4/discord"
|
||||
"git.sapphic.engineer/roleypoly/v4/discord/clientmock"
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/gofiber/fiber/v3/middleware/session"
|
||||
)
|
||||
|
||||
func TestAnonymous(t *testing.T) {
|
||||
dc := clientmock.NewDiscordClientMock()
|
||||
app := getApp(dc)
|
||||
|
||||
setSession(app, dc, authmiddleware.Session{
|
||||
Permissions: authmiddleware.PermAnonymous,
|
||||
})
|
||||
}
|
||||
|
||||
func getApp(dc discord.IDiscordClient) *fiber.App {
|
||||
app := fiber.New(fiber.Config{})
|
||||
sessionMiddleware, sessionStore := session.NewWithStore()
|
||||
sessionStore.RegisterType(authmiddleware.Session{})
|
||||
|
||||
app.Use(sessionMiddleware, authmiddleware.New(dc, []string{}, []string{}))
|
||||
app.Get("/", authState)
|
||||
app.Post("/updateSession", updateSession)
|
||||
|
||||
return app
|
||||
}
|
||||
|
||||
func setSession(app *fiber.App, dc *clientmock.DiscordClientMock, sess authmiddleware.Session) error {
|
||||
body := bytes.Buffer{}
|
||||
json.NewEncoder(&body).Encode(sess)
|
||||
|
||||
// do mocks here
|
||||
|
||||
req, _ := http.NewRequest("POST", "/updateSession", &body)
|
||||
_, err := app.Test(req)
|
||||
return err
|
||||
}
|
||||
|
||||
func authState(c fiber.Ctx) error {
|
||||
s := authmiddleware.SessionFrom(c)
|
||||
|
||||
permList := []string{}
|
||||
|
||||
if s.Permissions >= authmiddleware.PermAnonymous {
|
||||
permList = append(permList, "anonymous")
|
||||
}
|
||||
|
||||
if s.Permissions >= authmiddleware.PermUser {
|
||||
permList = append(permList, "user")
|
||||
}
|
||||
|
||||
if s.Permissions >= authmiddleware.PermSupport {
|
||||
permList = append(permList, "support")
|
||||
}
|
||||
|
||||
if s.Permissions >= authmiddleware.PermSuperuser {
|
||||
permList = append(permList, "superuser")
|
||||
}
|
||||
|
||||
return c.JSON(permList)
|
||||
}
|
||||
|
||||
func updateSession(c fiber.Ctx) error {
|
||||
var newSession *authmiddleware.Session
|
||||
c.Bind().JSON(&newSession)
|
||||
|
||||
sc := session.FromContext(c)
|
||||
sc.Set(authmiddleware.SessionKey, newSession)
|
||||
|
||||
return c.SendString("ok")
|
||||
}
|
10
authmiddleware/const.go
Normal file
10
authmiddleware/const.go
Normal file
|
@ -0,0 +1,10 @@
|
|||
package authmiddleware
|
||||
|
||||
type Permission uint8
|
||||
|
||||
const (
|
||||
PermAnonymous Permission = 1 << iota
|
||||
PermUser
|
||||
PermSupport
|
||||
PermSuperuser
|
||||
)
|
|
@ -42,6 +42,10 @@ func (c *DiscordClientMock) BotAuth(req *http.Request) {
|
|||
c.Called(req)
|
||||
}
|
||||
|
||||
func (c *DiscordClientMock) ClientAuth(req *http.Request, accessToken string) {
|
||||
c.Called(req, accessToken)
|
||||
}
|
||||
|
||||
func (c *DiscordClientMock) MockResponse(method, path string, statusCode int, data any) {
|
||||
body := bytes.Buffer{}
|
||||
json.NewEncoder(&body).Encode(data)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package discord
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
@ -11,9 +12,14 @@ import (
|
|||
|
||||
const DiscordBaseUrl = "https://discord.com/api/v10"
|
||||
|
||||
var (
|
||||
ErrUnauthorized = errors.New("discord: got a 401 from discord")
|
||||
)
|
||||
|
||||
type IDiscordClient interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
BotAuth(req *http.Request)
|
||||
ClientAuth(req *http.Request, accessToken string)
|
||||
GetClientID() string
|
||||
}
|
||||
|
||||
|
@ -44,6 +50,10 @@ func (d *DiscordClient) BotAuth(req *http.Request) {
|
|||
req.Header.Set("Authorization", fmt.Sprintf("Bot %s", d.BotToken))
|
||||
}
|
||||
|
||||
func (d *DiscordClient) ClientAuth(req *http.Request, accessToken string) {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", d.BotToken))
|
||||
}
|
||||
|
||||
func NewRequest(method string, path string) *http.Request {
|
||||
url, err := url.Parse(fmt.Sprintf("%s%s", DiscordBaseUrl, path))
|
||||
if err != nil {
|
||||
|
@ -62,6 +72,10 @@ func NewRequest(method string, path string) *http.Request {
|
|||
}
|
||||
|
||||
func OutputResponse(resp *http.Response, dst any) error {
|
||||
if resp.StatusCode == 401 {
|
||||
return ErrUnauthorized
|
||||
}
|
||||
|
||||
// TODO: more checks?
|
||||
return json.NewDecoder(resp.Body).Decode(dst)
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package interactions_test
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"git.sapphic.engineer/roleypoly/v4/interactions"
|
||||
|
@ -28,6 +29,10 @@ func TestCmdRoleypoly(t *testing.T) {
|
|||
button := ir.Data.Components[0]
|
||||
assert.Equal(t, fmt.Sprintf("%s/s/guild-id", i.PublicBaseURL), button.URL)
|
||||
|
||||
u, _ := url.Parse(i.PublicBaseURL)
|
||||
hostname := u.Host
|
||||
assert.Equal(t, fmt.Sprintf("Pick roles on %s", hostname), button.Label)
|
||||
|
||||
// test the command mentions
|
||||
tests := map[string]string{
|
||||
"See all the roles": "</2:pickable-roles>",
|
||||
|
|
|
@ -1,15 +1,22 @@
|
|||
package roleypoly
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/gofiber/fiber/v3/middleware/csrf"
|
||||
"github.com/gofiber/fiber/v3/middleware/session"
|
||||
"github.com/gofiber/fiber/v3/middleware/static"
|
||||
"github.com/gofiber/template/html/v2"
|
||||
|
||||
"git.sapphic.engineer/roleypoly/v4/authmiddleware"
|
||||
"git.sapphic.engineer/roleypoly/v4/discord"
|
||||
"git.sapphic.engineer/roleypoly/v4/interactions"
|
||||
staticfs "git.sapphic.engineer/roleypoly/v4/static"
|
||||
"git.sapphic.engineer/roleypoly/v4/templates"
|
||||
"git.sapphic.engineer/roleypoly/v4/testing"
|
||||
)
|
||||
|
@ -24,11 +31,38 @@ func CreateFiberApp() *fiber.App {
|
|||
ViewsLayout: "layouts/main",
|
||||
})
|
||||
|
||||
app.Use(session.New(session.Config{}))
|
||||
sessionMiddleware, sessionStore := session.NewWithStore()
|
||||
sessionStore.RegisterType(authmiddleware.Session{})
|
||||
|
||||
app.Use(sessionMiddleware)
|
||||
app.Use(csrf.New(csrf.Config{
|
||||
Session: sessionStore,
|
||||
}))
|
||||
|
||||
app.Get("/static*", static.New("", static.Config{
|
||||
FS: staticfs.FS,
|
||||
Compress: true,
|
||||
CacheDuration: 24 * 8 * time.Hour,
|
||||
}))
|
||||
app.Get("/favicon.ico", oneStatic)
|
||||
app.Get("/manifest.json", oneStatic)
|
||||
app.Get("/robots.txt", oneStatic)
|
||||
app.Get("/humans.txt", oneStatic)
|
||||
|
||||
return app
|
||||
}
|
||||
|
||||
func oneStatic(c fiber.Ctx) error {
|
||||
path := strings.Replace(c.OriginalURL(), "/", "", 1)
|
||||
f, err := staticfs.FS.Open(path)
|
||||
if err != nil {
|
||||
log.Println("oneStatic:", c.OriginalURL(), " failed: ", err)
|
||||
return c.SendStatus(500)
|
||||
}
|
||||
|
||||
return c.SendStream(f)
|
||||
}
|
||||
|
||||
func SetupControllers(app *fiber.App, dc *discord.DiscordClient, publicKey string, publicBaseURL string) {
|
||||
gs := discord.NewGuildService(dc)
|
||||
|
||||
|
@ -37,6 +71,8 @@ func SetupControllers(app *fiber.App, dc *discord.DiscordClient, publicKey strin
|
|||
}).Routes(app.Group("/testing"))
|
||||
|
||||
interactions.NewInteractions(publicKey, publicBaseURL, gs).Routes(app.Group("/interactions"))
|
||||
|
||||
app.Use(authmiddleware.New(dc, []string{}, []string{}))
|
||||
}
|
||||
|
||||
// func getStorageDirectory() string {
|
||||
|
|
9
static/embed.go
Normal file
9
static/embed.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
package static
|
||||
|
||||
import "embed"
|
||||
|
||||
//go:embed *
|
||||
var FS embed.FS
|
||||
|
||||
// hi :3
|
||||
// dont look its rude :c
|
BIN
static/favicon.ico
Normal file
BIN
static/favicon.ico
Normal file
Binary file not shown.
After Width: | Height: | Size: 101 KiB |
3
static/main.css
Normal file
3
static/main.css
Normal file
|
@ -0,0 +1,3 @@
|
|||
body {
|
||||
font-family: "Atkinson Hyperlegible", sans-serif;
|
||||
}
|
1
static/robots.txt
Normal file
1
static/robots.txt
Normal file
|
@ -0,0 +1 @@
|
|||
# Hi cutie
|
|
@ -0,0 +1,6 @@
|
|||
<h1>roleypoly!</h1>
|
||||
<div>put a role here</div>
|
||||
<div>put a role here</div>
|
||||
<div>put a role here</div>
|
||||
<div>put a role here</div>
|
||||
<div>put a role here</div>
|
|
@ -1,7 +1,14 @@
|
|||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<title>{{ .HeadTitle }}</title>
|
||||
<link rel="preconnect" href="https://fonts.bunny.net" />
|
||||
<link
|
||||
href="https://fonts.bunny.net/css?family=atkinson-hyperlegible:400,400i"
|
||||
rel="stylesheet"
|
||||
/>
|
||||
<link rel="stylesheet" href="/static/main.css" />
|
||||
</head>
|
||||
<body>
|
||||
{{embed}}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
|
||||
"github.com/gofiber/fiber/v3"
|
||||
|
||||
"git.sapphic.engineer/roleypoly/v4/authmiddleware"
|
||||
"git.sapphic.engineer/roleypoly/v4/discord"
|
||||
"git.sapphic.engineer/roleypoly/v4/types"
|
||||
)
|
||||
|
@ -15,8 +16,14 @@ type TestingController struct {
|
|||
|
||||
func (t *TestingController) Routes(r fiber.Router) {
|
||||
r.Get("/picker/:version?", t.Picker)
|
||||
r.Get("/index", t.Index)
|
||||
r.Get("/m/:server/:user", t.GetMember)
|
||||
r.Get("/g/:server", t.GetGuild)
|
||||
r.Get("/auth", t.AuthState, authmiddleware.New(t.Guilds.Client(), []string{}, []string{}))
|
||||
}
|
||||
|
||||
func (t *TestingController) Index(c fiber.Ctx) error {
|
||||
return c.Render("index", fiber.Map{})
|
||||
}
|
||||
|
||||
func (t *TestingController) Picker(c fiber.Ctx) error {
|
||||
|
@ -54,3 +61,27 @@ func (t *TestingController) GetGuild(c fiber.Ctx) error {
|
|||
|
||||
return c.JSON(g.DiscordGuild)
|
||||
}
|
||||
|
||||
func (t *TestingController) AuthState(c fiber.Ctx) error {
|
||||
s := authmiddleware.SessionFrom(c)
|
||||
|
||||
permList := []string{}
|
||||
|
||||
if s.Permissions >= authmiddleware.PermAnonymous {
|
||||
permList = append(permList, "anonymous")
|
||||
}
|
||||
|
||||
if s.Permissions >= authmiddleware.PermUser {
|
||||
permList = append(permList, "user")
|
||||
}
|
||||
|
||||
if s.Permissions >= authmiddleware.PermSupport {
|
||||
permList = append(permList, "support")
|
||||
}
|
||||
|
||||
if s.Permissions >= authmiddleware.PermSuperuser {
|
||||
permList = append(permList, "superuser")
|
||||
}
|
||||
|
||||
return c.JSON(permList)
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue