mirror of
https://github.com/MHSanaei/3x-ui.git
synced 2026-05-17 08:15:56 +03:00
refactor(session): store user ID in session instead of full struct
Replaces storing the full User object in the session cookie with just the user ID. GetLoginUser now re-fetches the user from the database on every request so credential/permission changes take effect immediately without requiring a re-login. Includes a backward-compatible migration path for existing sessions that still carry the old struct payload.
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/mhsanaei/3x-ui/v3/database"
|
||||
"github.com/mhsanaei/3x-ui/v3/database/model"
|
||||
"github.com/mhsanaei/3x-ui/v3/logger"
|
||||
|
||||
@@ -27,7 +28,7 @@ func SetLoginUser(c *gin.Context, user *model.User) error {
|
||||
return nil
|
||||
}
|
||||
s := sessions.Default(c)
|
||||
s.Set(loginUserKey, *user)
|
||||
s.Set(loginUserKey, user.Id)
|
||||
return s.Save()
|
||||
}
|
||||
|
||||
@@ -49,7 +50,7 @@ func GetLoginUser(c *gin.Context) *model.User {
|
||||
if obj == nil {
|
||||
return nil
|
||||
}
|
||||
user, ok := obj.(model.User)
|
||||
userID, ok := sessionUserID(obj)
|
||||
if !ok {
|
||||
s.Delete(loginUserKey)
|
||||
if err := s.Save(); err != nil {
|
||||
@@ -57,13 +58,77 @@ func GetLoginUser(c *gin.Context) *model.User {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return &user
|
||||
if legacyUserID, ok := legacySessionUserID(obj); ok {
|
||||
s.Set(loginUserKey, legacyUserID)
|
||||
if err := s.Save(); err != nil {
|
||||
logger.Warning("session: failed to migrate legacy user payload:", err)
|
||||
}
|
||||
}
|
||||
user, err := getUserByID(userID)
|
||||
if err != nil {
|
||||
logger.Warning("session: failed to load user:", err)
|
||||
s.Delete(loginUserKey)
|
||||
if saveErr := s.Save(); saveErr != nil {
|
||||
logger.Warning("session: failed to drop missing user:", saveErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
||||
func IsLogin(c *gin.Context) bool {
|
||||
return GetLoginUser(c) != nil
|
||||
}
|
||||
|
||||
func sessionUserID(obj any) (int, bool) {
|
||||
switch v := obj.(type) {
|
||||
case int:
|
||||
return v, v > 0
|
||||
case int64:
|
||||
return int(v), v > 0
|
||||
case int32:
|
||||
return int(v), v > 0
|
||||
case float64:
|
||||
id := int(v)
|
||||
return id, v == float64(id) && id > 0
|
||||
case model.User:
|
||||
return v.Id, v.Id > 0
|
||||
case *model.User:
|
||||
if v == nil {
|
||||
return 0, false
|
||||
}
|
||||
return v.Id, v.Id > 0
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func legacySessionUserID(obj any) (int, bool) {
|
||||
switch v := obj.(type) {
|
||||
case model.User:
|
||||
return v.Id, v.Id > 0
|
||||
case *model.User:
|
||||
if v == nil {
|
||||
return 0, false
|
||||
}
|
||||
return v.Id, v.Id > 0
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func getUserByID(id int) (*model.User, error) {
|
||||
db := database.GetDB()
|
||||
if db == nil {
|
||||
return nil, http.ErrServerClosed
|
||||
}
|
||||
user := &model.User{}
|
||||
if err := db.Model(model.User{}).Where("id = ?", id).First(user).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func ClearSession(c *gin.Context) error {
|
||||
s := sessions.Default(c)
|
||||
s.Clear()
|
||||
|
||||
47
web/session/session_test.go
Normal file
47
web/session/session_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/mhsanaei/3x-ui/v3/database/model"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-contrib/sessions/cookie"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestSetLoginUserStoresOnlyUserID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(sessions.Sessions(sessionCookieName, cookie.NewStore([]byte("01234567890123456789012345678901"))))
|
||||
router.GET("/", func(c *gin.Context) {
|
||||
if err := SetLoginUser(c, &model.User{Id: 7, Username: "admin", Password: "hash"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got := sessions.Default(c).Get(loginUserKey)
|
||||
if got != 7 {
|
||||
t.Fatalf("stored session payload = %#v, want user id only", got)
|
||||
}
|
||||
c.Status(http.StatusNoContent)
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusNoContent {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionUserIDSupportsLegacyUserPayload(t *testing.T) {
|
||||
id, ok := sessionUserID(model.User{Id: 11, Username: "admin", Password: "hash"})
|
||||
if !ok || id != 11 {
|
||||
t.Fatalf("legacy session payload resolved to (%d, %v), want (11, true)", id, ok)
|
||||
}
|
||||
id, ok = sessionUserID(&model.User{Id: 12, Username: "admin", Password: "hash"})
|
||||
if !ok || id != 12 {
|
||||
t.Fatalf("legacy pointer session payload resolved to (%d, %v), want (12, true)", id, ok)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user