refactor(session): store user ID in session instead of full struct

Replaces storing the full User object in the session cookie with just
the user ID. GetLoginUser now re-fetches the user from the database on
every request so credential/permission changes take effect immediately
without requiring a re-login. Includes a backward-compatible migration
path for existing sessions that still carry the old struct payload.
This commit is contained in:
farhadh
2026-05-11 21:09:26 +02:00
parent cb962175c2
commit ce88b0b432
2 changed files with 115 additions and 3 deletions

View File

@@ -5,6 +5,7 @@ import (
"net/http"
"time"
"github.com/mhsanaei/3x-ui/v3/database"
"github.com/mhsanaei/3x-ui/v3/database/model"
"github.com/mhsanaei/3x-ui/v3/logger"
@@ -27,7 +28,7 @@ func SetLoginUser(c *gin.Context, user *model.User) error {
return nil
}
s := sessions.Default(c)
s.Set(loginUserKey, *user)
s.Set(loginUserKey, user.Id)
return s.Save()
}
@@ -49,7 +50,7 @@ func GetLoginUser(c *gin.Context) *model.User {
if obj == nil {
return nil
}
user, ok := obj.(model.User)
userID, ok := sessionUserID(obj)
if !ok {
s.Delete(loginUserKey)
if err := s.Save(); err != nil {
@@ -57,13 +58,77 @@ func GetLoginUser(c *gin.Context) *model.User {
}
return nil
}
return &user
if legacyUserID, ok := legacySessionUserID(obj); ok {
s.Set(loginUserKey, legacyUserID)
if err := s.Save(); err != nil {
logger.Warning("session: failed to migrate legacy user payload:", err)
}
}
user, err := getUserByID(userID)
if err != nil {
logger.Warning("session: failed to load user:", err)
s.Delete(loginUserKey)
if saveErr := s.Save(); saveErr != nil {
logger.Warning("session: failed to drop missing user:", saveErr)
}
return nil
}
return user
}
func IsLogin(c *gin.Context) bool {
return GetLoginUser(c) != nil
}
func sessionUserID(obj any) (int, bool) {
switch v := obj.(type) {
case int:
return v, v > 0
case int64:
return int(v), v > 0
case int32:
return int(v), v > 0
case float64:
id := int(v)
return id, v == float64(id) && id > 0
case model.User:
return v.Id, v.Id > 0
case *model.User:
if v == nil {
return 0, false
}
return v.Id, v.Id > 0
default:
return 0, false
}
}
func legacySessionUserID(obj any) (int, bool) {
switch v := obj.(type) {
case model.User:
return v.Id, v.Id > 0
case *model.User:
if v == nil {
return 0, false
}
return v.Id, v.Id > 0
default:
return 0, false
}
}
func getUserByID(id int) (*model.User, error) {
db := database.GetDB()
if db == nil {
return nil, http.ErrServerClosed
}
user := &model.User{}
if err := db.Model(model.User{}).Where("id = ?", id).First(user).Error; err != nil {
return nil, err
}
return user, nil
}
func ClearSession(c *gin.Context) error {
s := sessions.Default(c)
s.Clear()

View File

@@ -0,0 +1,47 @@
package session
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/mhsanaei/3x-ui/v3/database/model"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
)
func TestSetLoginUserStoresOnlyUserID(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(sessions.Sessions(sessionCookieName, cookie.NewStore([]byte("01234567890123456789012345678901"))))
router.GET("/", func(c *gin.Context) {
if err := SetLoginUser(c, &model.User{Id: 7, Username: "admin", Password: "hash"}); err != nil {
t.Fatal(err)
}
got := sessions.Default(c).Get(loginUserKey)
if got != 7 {
t.Fatalf("stored session payload = %#v, want user id only", got)
}
c.Status(http.StatusNoContent)
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
router.ServeHTTP(rec, req)
if rec.Code != http.StatusNoContent {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent)
}
}
func TestSessionUserIDSupportsLegacyUserPayload(t *testing.T) {
id, ok := sessionUserID(model.User{Id: 11, Username: "admin", Password: "hash"})
if !ok || id != 11 {
t.Fatalf("legacy session payload resolved to (%d, %v), want (11, true)", id, ok)
}
id, ok = sessionUserID(&model.User{Id: 12, Username: "admin", Password: "hash"})
if !ok || id != 12 {
t.Fatalf("legacy pointer session payload resolved to (%d, %v), want (12, true)", id, ok)
}
}