Skip to content

Commit

Permalink
fix: #2689 change sessionId in cookie (#3203)
Browse files Browse the repository at this point in the history
  • Loading branch information
glennliao authored Dec 20, 2023
1 parent 645c5ff commit d08e3ef
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 4 deletions.
18 changes: 14 additions & 4 deletions net/ghttp/ghttp_server_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Create a new request object.
request := newRequest(s, r, w)

// Get sessionId before user handler
sessionId := request.GetSessionId()

defer func() {
request.LeaveTime = gtime.TimestampMilli()
// error log handling.
Expand Down Expand Up @@ -176,10 +179,17 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Automatically set the session id to cookie
// if it creates a new session id in this request
// and SessionCookieOutput is enabled.
if s.config.SessionCookieOutput &&
request.Session.IsDirty() &&
request.Session.MustId() != request.GetSessionId() {
request.Cookie.SetSessionId(request.Session.MustId())
if s.config.SessionCookieOutput && request.Session.IsDirty() {
// Can change by r.Session.SetId("") before init session
// Can change by r.Cookie.SetSessionId("")
sidFromSession, sidFromRequest := request.Session.MustId(), request.GetSessionId()
if sidFromSession != sidFromRequest {
if sidFromSession != sessionId {
request.Cookie.SetSessionId(sidFromSession)
} else {
request.Cookie.SetSessionId(sidFromRequest)
}
}
}
// Output the cookie content to the client.
request.Cookie.Flush()
Expand Down
89 changes: 89 additions & 0 deletions net/ghttp/ghttp_z_unit_feature_session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,92 @@ func Test_Session_Custom_Id(t *testing.T) {
t.Assert(client.GetContent(ctx, "/value"), value)
})
}

func Test_Session_New_Id(t *testing.T) {
var (
sessionId = "1234567890"
newSessionId = "0987654321"
newSessionId2 = "abcdefghij"
key = "key"
value = "value"
s = g.Server(guid.S())
)
s.BindHandler("/id", func(r *ghttp.Request) {
if err := r.Session.SetId(sessionId); err != nil {
r.Response.WriteExit(err.Error())
}
if err := r.Session.Set(key, value); err != nil {
r.Response.WriteExit(err.Error())
}
r.Response.WriteExit(r.Session.Id())
})

s.BindHandler("/newIdBySession", func(r *ghttp.Request) {
// Use before session init
if err := r.Session.SetId(newSessionId); err != nil {
r.Response.WriteExit(err.Error())
}
if err := r.Session.Set(key, value); err != nil {
r.Response.WriteExit(err.Error())
}
r.Response.WriteExit(r.Session.Id())
})

s.BindHandler("/newIdByCookie", func(r *ghttp.Request) {
if err := r.Session.Remove("someKey"); err != nil {
r.Response.WriteExit(err.Error())
}

r.Cookie.SetSessionId(newSessionId2)
//r.Response.WriteExit(r.Session.Id()) // only change in cookie

r.Response.WriteExit(newSessionId2)
})

s.BindHandler("/value", func(r *ghttp.Request) {
r.Response.WriteExit(r.Session.Get(key))
})
s.SetDumpRouterMap(false)
s.Start()
defer s.Shutdown()

time.Sleep(100 * time.Millisecond)

gtest.C(t, func(t *gtest.T) {
client := g.Client()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", s.GetListenedPort()))
r, err := client.Get(ctx, "/id")
t.AssertNil(err)
defer r.Close()
t.Assert(r.ReadAllString(), sessionId)
t.Assert(r.GetCookie(s.GetSessionIdName()), sessionId)
})
gtest.C(t, func(t *gtest.T) {
client := g.Client()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", s.GetListenedPort()))
client.SetHeader(s.GetSessionIdName(), sessionId)
t.Assert(client.GetContent(ctx, "/value"), value)
})

gtest.C(t, func(t *gtest.T) {
client := g.Client()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", s.GetListenedPort()))
client.SetHeader(s.GetSessionIdName(), sessionId)
r, err := client.Get(ctx, "/newIdBySession")
t.AssertNil(err)
defer r.Close()
t.Assert(r.ReadAllString(), newSessionId)
t.Assert(r.GetCookie(s.GetSessionIdName()), newSessionId)
})

gtest.C(t, func(t *gtest.T) {
client := g.Client()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", s.GetListenedPort()))
r, err := client.Get(ctx, "/newIdByCookie")
client.SetHeader(s.GetSessionIdName(), sessionId)
t.AssertNil(err)
defer r.Close()
t.Assert(r.ReadAllString(), newSessionId2)
t.Assert(r.GetCookie(s.GetSessionIdName()), newSessionId2)
})
}

0 comments on commit d08e3ef

Please sign in to comment.