diff --git a/net/ghttp/ghttp_server_handler.go b/net/ghttp/ghttp_server_handler.go index 09c66be3420..da4863032a9 100644 --- a/net/ghttp/ghttp_server_handler.go +++ b/net/ghttp/ghttp_server_handler.go @@ -43,6 +43,9 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Create a new request object. request := newRequest(s, r, w) + // Get sessionId before user handler + sessionId := request.GetSessionId() + defer func() { request.LeaveTime = gtime.TimestampMilli() // error log handling. @@ -176,10 +179,17 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Automatically set the session id to cookie // if it creates a new session id in this request // and SessionCookieOutput is enabled. - if s.config.SessionCookieOutput && - request.Session.IsDirty() && - request.Session.MustId() != request.GetSessionId() { - request.Cookie.SetSessionId(request.Session.MustId()) + if s.config.SessionCookieOutput && request.Session.IsDirty() { + // Can change by r.Session.SetId("") before init session + // Can change by r.Cookie.SetSessionId("") + sidFromSession, sidFromRequest := request.Session.MustId(), request.GetSessionId() + if sidFromSession != sidFromRequest { + if sidFromSession != sessionId { + request.Cookie.SetSessionId(sidFromSession) + } else { + request.Cookie.SetSessionId(sidFromRequest) + } + } } // Output the cookie content to the client. request.Cookie.Flush() diff --git a/net/ghttp/ghttp_z_unit_feature_session_test.go b/net/ghttp/ghttp_z_unit_feature_session_test.go index 7bb9514a4be..2f4b36dd5a4 100644 --- a/net/ghttp/ghttp_z_unit_feature_session_test.go +++ b/net/ghttp/ghttp_z_unit_feature_session_test.go @@ -189,3 +189,92 @@ func Test_Session_Custom_Id(t *testing.T) { t.Assert(client.GetContent(ctx, "/value"), value) }) } + +func Test_Session_New_Id(t *testing.T) { + var ( + sessionId = "1234567890" + newSessionId = "0987654321" + newSessionId2 = "abcdefghij" + key = "key" + value = "value" + s = g.Server(guid.S()) + ) + s.BindHandler("/id", func(r *ghttp.Request) { + if err := r.Session.SetId(sessionId); err != nil { + r.Response.WriteExit(err.Error()) + } + if err := r.Session.Set(key, value); err != nil { + r.Response.WriteExit(err.Error()) + } + r.Response.WriteExit(r.Session.Id()) + }) + + s.BindHandler("/newIdBySession", func(r *ghttp.Request) { + // Use before session init + if err := r.Session.SetId(newSessionId); err != nil { + r.Response.WriteExit(err.Error()) + } + if err := r.Session.Set(key, value); err != nil { + r.Response.WriteExit(err.Error()) + } + r.Response.WriteExit(r.Session.Id()) + }) + + s.BindHandler("/newIdByCookie", func(r *ghttp.Request) { + if err := r.Session.Remove("someKey"); err != nil { + r.Response.WriteExit(err.Error()) + } + + r.Cookie.SetSessionId(newSessionId2) + //r.Response.WriteExit(r.Session.Id()) // only change in cookie + + r.Response.WriteExit(newSessionId2) + }) + + s.BindHandler("/value", func(r *ghttp.Request) { + r.Response.WriteExit(r.Session.Get(key)) + }) + s.SetDumpRouterMap(false) + s.Start() + defer s.Shutdown() + + time.Sleep(100 * time.Millisecond) + + gtest.C(t, func(t *gtest.T) { + client := g.Client() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", s.GetListenedPort())) + r, err := client.Get(ctx, "/id") + t.AssertNil(err) + defer r.Close() + t.Assert(r.ReadAllString(), sessionId) + t.Assert(r.GetCookie(s.GetSessionIdName()), sessionId) + }) + gtest.C(t, func(t *gtest.T) { + client := g.Client() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", s.GetListenedPort())) + client.SetHeader(s.GetSessionIdName(), sessionId) + t.Assert(client.GetContent(ctx, "/value"), value) + }) + + gtest.C(t, func(t *gtest.T) { + client := g.Client() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", s.GetListenedPort())) + client.SetHeader(s.GetSessionIdName(), sessionId) + r, err := client.Get(ctx, "/newIdBySession") + t.AssertNil(err) + defer r.Close() + t.Assert(r.ReadAllString(), newSessionId) + t.Assert(r.GetCookie(s.GetSessionIdName()), newSessionId) + }) + + gtest.C(t, func(t *gtest.T) { + client := g.Client() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", s.GetListenedPort())) + r, err := client.Get(ctx, "/newIdByCookie") + client.SetHeader(s.GetSessionIdName(), sessionId) + t.AssertNil(err) + defer r.Close() + t.Assert(r.ReadAllString(), newSessionId2) + t.Assert(r.GetCookie(s.GetSessionIdName()), newSessionId2) + }) +}