auth yay!
This commit is contained in:
parent
a3a1654030
commit
c50bfc1a7f
9 changed files with 339 additions and 110 deletions
|
@ -6,4 +6,7 @@ DISCORD_BOT_TOKEN=dd
|
|||
LISTEN_ADDR=:8169
|
||||
STORAGE_DIR=./.storage
|
||||
|
||||
PUBLIC_BASE_URL=http://localhost:8169
|
||||
PUBLIC_BASE_URL=http://localhost:8169
|
||||
|
||||
SUPERUSER_IDS=62601275618889728
|
||||
SUPPORT_IDS=
|
|
@ -3,6 +3,7 @@ package authmiddleware
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"slices"
|
||||
|
@ -20,31 +21,12 @@ type AuthMiddleware struct {
|
|||
superuserIDs []string
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
Permissions Permission
|
||||
User *types.DiscordUser
|
||||
AccessToken string
|
||||
LastRefresh time.Time
|
||||
}
|
||||
|
||||
var SessionKey uint8
|
||||
|
||||
func New(discordClient discord.IDiscordClient, supportIDs []string, superuserIDs []string) func(fiber.Ctx) error {
|
||||
am := AuthMiddleware{
|
||||
Client: discordClient,
|
||||
supportIDs: supportIDs,
|
||||
superuserIDs: superuserIDs,
|
||||
}
|
||||
|
||||
return am.Handle
|
||||
}
|
||||
|
||||
var DefaultSession *Session = &Session{Permissions: PermAnonymous}
|
||||
|
||||
func (am *AuthMiddleware) Handle(c fiber.Ctx) error {
|
||||
sc := session.FromContext(c)
|
||||
sess := am.Init(sc)
|
||||
|
||||
// log.Println(sess)
|
||||
|
||||
am.validateAccessToken(sess) // this will remove AccessToken if its no good
|
||||
am.setPermissions(sess) // this captures this
|
||||
am.Commit(sc, sess) // and save our first bits?
|
||||
|
@ -53,23 +35,30 @@ func (am *AuthMiddleware) Handle(c fiber.Ctx) error {
|
|||
}
|
||||
|
||||
func (am *AuthMiddleware) Init(sess *session.Middleware) *Session {
|
||||
session, ok := sess.Get(SessionKey).(*Session)
|
||||
session, ok := sess.Get(SessionKey).([]byte)
|
||||
if !ok {
|
||||
am.Commit(sess, DefaultSession)
|
||||
return DefaultSession
|
||||
}
|
||||
|
||||
return session
|
||||
s, err := SessionFromJSON(session)
|
||||
if err != nil {
|
||||
log.Panicln("session failed to parse", err)
|
||||
return nil
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (am *AuthMiddleware) Commit(sc *session.Middleware, sess *Session) {
|
||||
sc.Set(SessionKey, sess)
|
||||
sc.Set(SessionKey, sess.AsJSON())
|
||||
}
|
||||
|
||||
func SessionFrom(c fiber.Ctx) (s *Session) {
|
||||
sess := session.FromContext(c)
|
||||
s, _ = sess.Get(SessionKey).(*Session)
|
||||
|
||||
sessionJSON := sess.Get(SessionKey).([]byte)
|
||||
s, err := SessionFromJSON(sessionJSON)
|
||||
if err != nil {
|
||||
log.Panicln("session failed to parse", err)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
|
@ -86,6 +75,12 @@ func (am *AuthMiddleware) setPermissions(sess *Session) {
|
|||
|
||||
if sess.AccessToken != "" {
|
||||
sess.Permissions = PermUser
|
||||
} else {
|
||||
return
|
||||
}
|
||||
|
||||
if sess.User == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if am.isSupport(sess.User.ID) {
|
||||
|
@ -112,6 +107,7 @@ func (am *AuthMiddleware) validateAccessToken(sess *Session) {
|
|||
}
|
||||
|
||||
sess.User = user
|
||||
sess.LastRefresh = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,80 +0,0 @@
|
|||
package authmiddleware_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"git.sapphic.engineer/roleypoly/v4/authmiddleware"
|
||||
"git.sapphic.engineer/roleypoly/v4/discord"
|
||||
"git.sapphic.engineer/roleypoly/v4/discord/clientmock"
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/gofiber/fiber/v3/middleware/session"
|
||||
)
|
||||
|
||||
func TestAnonymous(t *testing.T) {
|
||||
dc := clientmock.NewDiscordClientMock()
|
||||
app := getApp(dc)
|
||||
|
||||
setSession(app, dc, authmiddleware.Session{
|
||||
Permissions: authmiddleware.PermAnonymous,
|
||||
})
|
||||
}
|
||||
|
||||
func getApp(dc discord.IDiscordClient) *fiber.App {
|
||||
app := fiber.New(fiber.Config{})
|
||||
sessionMiddleware, sessionStore := session.NewWithStore()
|
||||
sessionStore.RegisterType(authmiddleware.Session{})
|
||||
|
||||
app.Use(sessionMiddleware, authmiddleware.New(dc, []string{}, []string{}))
|
||||
app.Get("/", authState)
|
||||
app.Post("/updateSession", updateSession)
|
||||
|
||||
return app
|
||||
}
|
||||
|
||||
func setSession(app *fiber.App, dc *clientmock.DiscordClientMock, sess authmiddleware.Session) error {
|
||||
body := bytes.Buffer{}
|
||||
json.NewEncoder(&body).Encode(sess)
|
||||
|
||||
// do mocks here
|
||||
|
||||
req, _ := http.NewRequest("POST", "/updateSession", &body)
|
||||
_, err := app.Test(req)
|
||||
return err
|
||||
}
|
||||
|
||||
func authState(c fiber.Ctx) error {
|
||||
s := authmiddleware.SessionFrom(c)
|
||||
|
||||
permList := []string{}
|
||||
|
||||
if s.Permissions >= authmiddleware.PermAnonymous {
|
||||
permList = append(permList, "anonymous")
|
||||
}
|
||||
|
||||
if s.Permissions >= authmiddleware.PermUser {
|
||||
permList = append(permList, "user")
|
||||
}
|
||||
|
||||
if s.Permissions >= authmiddleware.PermSupport {
|
||||
permList = append(permList, "support")
|
||||
}
|
||||
|
||||
if s.Permissions >= authmiddleware.PermSuperuser {
|
||||
permList = append(permList, "superuser")
|
||||
}
|
||||
|
||||
return c.JSON(permList)
|
||||
}
|
||||
|
||||
func updateSession(c fiber.Ctx) error {
|
||||
var newSession *authmiddleware.Session
|
||||
c.Bind().JSON(&newSession)
|
||||
|
||||
sc := session.FromContext(c)
|
||||
sc.Set(authmiddleware.SessionKey, newSession)
|
||||
|
||||
return c.SendString("ok")
|
||||
}
|
20
authmiddleware/must.go
Normal file
20
authmiddleware/must.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
package authmiddleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.sapphic.engineer/roleypoly/v4/types"
|
||||
"github.com/gofiber/fiber/v3"
|
||||
)
|
||||
|
||||
func MustHavePermission(perm Permission) func(fiber.Ctx) error {
|
||||
return func(c fiber.Ctx) error {
|
||||
sess := SessionFrom(c)
|
||||
|
||||
if sess.Permissions >= perm {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
return types.NewAPIError(http.StatusForbidden, "no sorry").Send(c)
|
||||
}
|
||||
}
|
92
authmiddleware/must_test.go
Normal file
92
authmiddleware/must_test.go
Normal file
|
@ -0,0 +1,92 @@
|
|||
package authmiddleware_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.sapphic.engineer/roleypoly/v4/authmiddleware"
|
||||
"git.sapphic.engineer/roleypoly/v4/discord/clientmock"
|
||||
"git.sapphic.engineer/roleypoly/v4/types"
|
||||
"git.sapphic.engineer/roleypoly/v4/types/fixtures"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMustAnonymous(t *testing.T) {
|
||||
dc := clientmock.NewDiscordClientMock()
|
||||
app := getApp(dc)
|
||||
|
||||
cookie, err := setSession(app, dc, authmiddleware.Session{
|
||||
Permissions: authmiddleware.PermAnonymous,
|
||||
}, nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = get(t, app, cookie, "/must/user")
|
||||
assert.ErrorIs(t, err, errUnauthorized)
|
||||
err = get(t, app, cookie, "/must/support")
|
||||
assert.ErrorIs(t, err, errUnauthorized)
|
||||
err = get(t, app, cookie, "/must/superuser")
|
||||
assert.ErrorIs(t, err, errUnauthorized)
|
||||
}
|
||||
|
||||
func TestMustUser(t *testing.T) {
|
||||
dc := clientmock.NewDiscordClientMock()
|
||||
app := getApp(dc)
|
||||
|
||||
user := fixtures.User
|
||||
mockUser(dc, &user)
|
||||
cookie, err := setSession(app, dc, authmiddleware.Session{
|
||||
User: &user,
|
||||
AccessToken: "access-token",
|
||||
}, nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = get(t, app, cookie, "/must/user")
|
||||
assert.Nil(t, err)
|
||||
err = get(t, app, cookie, "/must/support")
|
||||
assert.ErrorIs(t, err, errUnauthorized)
|
||||
err = get(t, app, cookie, "/must/superuser")
|
||||
assert.ErrorIs(t, err, errUnauthorized)
|
||||
}
|
||||
|
||||
func TestMustSupport(t *testing.T) {
|
||||
dc := clientmock.NewDiscordClientMock()
|
||||
app := getApp(dc)
|
||||
|
||||
user := &types.DiscordUser{
|
||||
ID: "support-user",
|
||||
}
|
||||
mockUser(dc, user)
|
||||
cookie, err := setSession(app, dc, authmiddleware.Session{
|
||||
User: user,
|
||||
AccessToken: "access-token",
|
||||
}, nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = get(t, app, cookie, "/must/user")
|
||||
assert.Nil(t, err)
|
||||
err = get(t, app, cookie, "/must/support")
|
||||
assert.Nil(t, err)
|
||||
err = get(t, app, cookie, "/must/superuser")
|
||||
assert.ErrorIs(t, err, errUnauthorized)
|
||||
}
|
||||
|
||||
func TestMustSupuruser(t *testing.T) {
|
||||
dc := clientmock.NewDiscordClientMock()
|
||||
app := getApp(dc)
|
||||
|
||||
user := &types.DiscordUser{
|
||||
ID: "superuser-user",
|
||||
}
|
||||
mockUser(dc, user)
|
||||
cookie, err := setSession(app, dc, authmiddleware.Session{
|
||||
User: user,
|
||||
AccessToken: "access-token",
|
||||
}, nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = get(t, app, cookie, "/must/user")
|
||||
assert.Nil(t, err)
|
||||
err = get(t, app, cookie, "/must/support")
|
||||
assert.Nil(t, err)
|
||||
err = get(t, app, cookie, "/must/superuser")
|
||||
assert.Nil(t, err)
|
||||
}
|
47
authmiddleware/session.go
Normal file
47
authmiddleware/session.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package authmiddleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"git.sapphic.engineer/roleypoly/v4/discord"
|
||||
"git.sapphic.engineer/roleypoly/v4/types"
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/gofiber/fiber/v3"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
Permissions Permission
|
||||
User *types.DiscordUser
|
||||
AccessToken string
|
||||
LastRefresh time.Time
|
||||
}
|
||||
|
||||
func (s *Session) AsJSON() []byte {
|
||||
out, err := json.Marshal(s)
|
||||
if err != nil {
|
||||
log.Panicln("failed to marshal session json: ", err)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func SessionFromJSON(input []byte) (*Session, error) {
|
||||
var s Session
|
||||
err := json.Unmarshal(input, &s)
|
||||
return &s, err
|
||||
}
|
||||
|
||||
var SessionKey uint8
|
||||
|
||||
func New(discordClient discord.IDiscordClient, supportIDs []string, superuserIDs []string) func(fiber.Ctx) error {
|
||||
am := AuthMiddleware{
|
||||
Client: discordClient,
|
||||
supportIDs: supportIDs,
|
||||
superuserIDs: superuserIDs,
|
||||
}
|
||||
|
||||
return am.Handle
|
||||
}
|
||||
|
||||
var DefaultSession *Session = &Session{Permissions: PermAnonymous}
|
117
authmiddleware/util_test.go
Normal file
117
authmiddleware/util_test.go
Normal file
|
@ -0,0 +1,117 @@
|
|||
package authmiddleware_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"git.sapphic.engineer/roleypoly/v4/authmiddleware"
|
||||
"git.sapphic.engineer/roleypoly/v4/discord"
|
||||
"git.sapphic.engineer/roleypoly/v4/discord/clientmock"
|
||||
"git.sapphic.engineer/roleypoly/v4/types"
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/gofiber/fiber/v3/middleware/session"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func mockUser(dc *clientmock.DiscordClientMock, user *types.DiscordUser) {
|
||||
if user == nil {
|
||||
dc.MockResponse("GET", "/users/@me", http.StatusUnauthorized, "unauthorized").Once()
|
||||
}
|
||||
|
||||
dc.MockResponse("GET", "/users/@me", 200, user).Once()
|
||||
}
|
||||
|
||||
func setSession(app *fiber.App, dc *clientmock.DiscordClientMock, sess authmiddleware.Session, inputCookie *http.Cookie) (*http.Cookie, error) {
|
||||
dc.On("ClientAuth", mock.AnythingOfType("*http.Request"), "access-token")
|
||||
|
||||
body := bytes.Buffer{}
|
||||
json.NewEncoder(&body).Encode(sess)
|
||||
|
||||
req, _ := http.NewRequest("POST", "/updateSession", &body)
|
||||
|
||||
if inputCookie != nil {
|
||||
req.AddCookie(inputCookie)
|
||||
}
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cookie := resp.Cookies()[0]
|
||||
return cookie, err
|
||||
}
|
||||
|
||||
func authState(c fiber.Ctx) error {
|
||||
s := authmiddleware.SessionFrom(c)
|
||||
|
||||
permList := []string{}
|
||||
|
||||
if s.Permissions >= authmiddleware.PermAnonymous {
|
||||
permList = append(permList, "anonymous")
|
||||
}
|
||||
|
||||
if s.Permissions >= authmiddleware.PermUser {
|
||||
permList = append(permList, "user")
|
||||
}
|
||||
|
||||
if s.Permissions >= authmiddleware.PermSupport {
|
||||
permList = append(permList, "support")
|
||||
}
|
||||
|
||||
if s.Permissions >= authmiddleware.PermSuperuser {
|
||||
permList = append(permList, "superuser")
|
||||
}
|
||||
|
||||
return c.JSON(permList)
|
||||
}
|
||||
|
||||
func updateSession(c fiber.Ctx) error {
|
||||
var newSession *authmiddleware.Session
|
||||
c.Bind().JSON(&newSession)
|
||||
|
||||
sc := session.FromContext(c)
|
||||
sc.Set(authmiddleware.SessionKey, newSession.AsJSON())
|
||||
|
||||
return c.SendString("ok")
|
||||
}
|
||||
|
||||
var errUnauthorized = errors.New("unauthorized")
|
||||
|
||||
func get(t *testing.T, app *fiber.App, cookie *http.Cookie, path string) error {
|
||||
req, err := http.NewRequest("GET", path, nil)
|
||||
req.AddCookie(cookie)
|
||||
assert.Nil(t, err)
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.Nil(t, err)
|
||||
|
||||
if resp.StatusCode == http.StatusForbidden {
|
||||
return errUnauthorized
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ok(c fiber.Ctx) error {
|
||||
return c.SendString("ok")
|
||||
}
|
||||
|
||||
func getApp(dc discord.IDiscordClient) *fiber.App {
|
||||
app := fiber.New(fiber.Config{})
|
||||
sessionMiddleware, sessionStore := session.NewWithStore()
|
||||
sessionStore.RegisterType(authmiddleware.Session{})
|
||||
|
||||
app.Use(sessionMiddleware, authmiddleware.New(dc, []string{"support-user"}, []string{"superuser-user"}))
|
||||
app.Get("/", authState)
|
||||
app.Post("/updateSession", updateSession)
|
||||
app.Get("/must/user", ok, authmiddleware.MustHavePermission(authmiddleware.PermUser))
|
||||
app.Get("/must/support", ok, authmiddleware.MustHavePermission(authmiddleware.PermSupport))
|
||||
app.Get("/must/superuser", ok, authmiddleware.MustHavePermission(authmiddleware.PermSuperuser))
|
||||
|
||||
return app
|
||||
}
|
34
authmiddleware/validation_test.go
Normal file
34
authmiddleware/validation_test.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package authmiddleware_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.sapphic.engineer/roleypoly/v4/authmiddleware"
|
||||
"git.sapphic.engineer/roleypoly/v4/discord/clientmock"
|
||||
"git.sapphic.engineer/roleypoly/v4/types/fixtures"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestExpiration(t *testing.T) {
|
||||
dc := clientmock.NewDiscordClientMock()
|
||||
app := getApp(dc)
|
||||
|
||||
mockUser(dc, &fixtures.User)
|
||||
cookie, err := setSession(app, dc, authmiddleware.Session{
|
||||
AccessToken: "access-token",
|
||||
}, nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// Good request for now,,
|
||||
err = get(t, app, cookie, "/must/user")
|
||||
assert.Nil(t, err)
|
||||
|
||||
mockUser(dc, nil)
|
||||
cookie, err = setSession(app, dc, authmiddleware.Session{
|
||||
AccessToken: "access-token",
|
||||
}, cookie)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = get(t, app, cookie, "/must/user")
|
||||
assert.ErrorIs(t, errUnauthorized, err)
|
||||
}
|
|
@ -46,7 +46,7 @@ func (c *DiscordClientMock) ClientAuth(req *http.Request, accessToken string) {
|
|||
c.Called(req, accessToken)
|
||||
}
|
||||
|
||||
func (c *DiscordClientMock) MockResponse(method, path string, statusCode int, data any) {
|
||||
func (c *DiscordClientMock) MockResponse(method, path string, statusCode int, data any) *mock.Call {
|
||||
body := bytes.Buffer{}
|
||||
json.NewEncoder(&body).Encode(data)
|
||||
|
||||
|
@ -57,7 +57,7 @@ func (c *DiscordClientMock) MockResponse(method, path string, statusCode int, da
|
|||
|
||||
pathMatcher := regexp.MustCompile(strings.ReplaceAll(path, "*", "[a-z0-9_-]+"))
|
||||
|
||||
c.On("Do", mock.MatchedBy(func(req *http.Request) bool {
|
||||
return c.On("Do", mock.MatchedBy(func(req *http.Request) bool {
|
||||
return req.Method == method && pathMatcher.MatchString(req.URL.Path)
|
||||
})).Return(r, nil)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue