Skip to content

Commit

Permalink
Refactor the doUpgrade transport return parameters (#76)
Browse files Browse the repository at this point in the history
When refactoring the transportSessions, there may be some updates to the doUpgrade function. This is a small PR getting that code inline for an update.
  • Loading branch information
njones authored Apr 27, 2023
1 parent c34fad1 commit 5de2c30
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 45 deletions.
54 changes: 32 additions & 22 deletions engineio/server.v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ const Version2 EIOVersionStr = "2"

func init() { registry[Version2.Int()] = NewServerV2 }

type upgradeable struct {
transport eiot.Transporter
isProbeOnInit bool
upgradeFn func() error
err error
}

type serverV2 struct {
path *string

Expand All @@ -53,7 +60,7 @@ type serverV2 struct {
eto []eiot.Option

servers map[EIOVersionStr]server
sessions mapSessions
sessions transportSessions
transports map[TransportName]func(SessionID, eiot.Codec) eiot.Transporter

transportRunError chan error
Expand Down Expand Up @@ -185,7 +192,8 @@ func (v2 *serverV2) serveTransport(w http.ResponseWriter, r *http.Request) (tran
sessionID, _ := ctx.Value(ctxSessionID).(SessionID)
if sessionID == "" {
sessionID = v2.generateID()
transportName, _ := r.Context().Value(ctxTransportName).(TransportName)

transportName := transportNameFrom(r)
transport = v2.transports[transportName](sessionID, v2.codec)
if err := v2.sessions.Set(transport); err != nil {
return nil, err
Expand All @@ -208,28 +216,28 @@ func (v2 *serverV2) serveTransport(w http.ResponseWriter, r *http.Request) (tran
}
}

var isProbeOnInit bool
var fnOnUpgrade func() error
transport, isProbeOnInit, fnOnUpgrade, err = v2.doUpgrade(v2.sessions.Get(sessionID))(w, r)
if err != nil {
return nil, err
upgrade := v2.doUpgrade(v2.sessions.Get(sessionID))(w, r)
if upgrade.err != nil {
return nil, upgrade.err
}

var opts []eiot.Option
if isProbeOnInit {
opts = []eiot.Option{eiot.OnInitProbe(isProbeOnInit)}
if upgrade.isProbeOnInit {
opts = []eiot.Option{eiot.OnInitProbe(upgrade.isProbeOnInit)}
}
if fnOnUpgrade != nil {
opts = []eiot.Option{eiot.OnUpgrade(fnOnUpgrade)}
if upgrade.upgradeFn != nil {
opts = []eiot.Option{eiot.OnUpgrade(upgrade.upgradeFn)}
}

ctx = v2.sessions.WithTimeout(ctx, v2.pingTimeout*4)
ctx = v2.sessions.WithInterval(ctx, v2.pingTimeout)

opts = append(opts, eiot.WithNoPing())
go func() { v2.transportRunError <- transport.Run(w, r.WithContext(ctx), append(v2.eto, opts...)...) }()
go func() {
v2.transportRunError <- upgrade.transport.Run(w, r.WithContext(ctx), append(v2.eto, opts...)...)
}()

return
return upgrade.transport, nil
}

func (v2 *serverV2) handshakePacket(sessionID SessionID, transportName TransportName) eiop.Packet {
Expand All @@ -243,24 +251,26 @@ func (v2 *serverV2) handshakePacket(sessionID SessionID, transportName Transport
}
}

func (v2 *serverV2) doUpgrade(transport eiot.Transporter, err error) func(http.ResponseWriter, *http.Request) (eiot.Transporter, bool, func() error, error) {
var isUpgrade bool
return func(w http.ResponseWriter, r *http.Request) (eiot.Transporter, bool, func() error, error) {
func (v2 *serverV2) doUpgrade(transport eiot.Transporter, err error) func(http.ResponseWriter, *http.Request) upgradeable {
return func(w http.ResponseWriter, r *http.Request) upgradeable {
if err != nil {
return transport, isUpgrade, nil, err
return upgradeable{transport: transport, err: err}
}
sessionID, from, to := transport.ID(), transport.Name(), transportNameFrom(r)
if to != from {
for _, val := range v2.upgrades(from, v2.transports) {
if string(to) == val {
transport = v2.transports[to](sessionID, v2.codec)
isUpgrade = true
return transport, isUpgrade, func() error { return v2.sessions.Set(transport) }, nil
return upgradeable{
transport: v2.transports[to](sessionID, v2.codec),
isProbeOnInit: true,
upgradeFn: func() error { return v2.sessions.Set(transport) },
err: nil,
}
}
}
return nil, false, nil, ErrTransportUpgradeFailed
return upgradeable{err: ErrTransportUpgradeFailed}
}
return transport, isUpgrade, nil, err
return upgradeable{transport: transport, err: err}
}
}

Expand Down
23 changes: 12 additions & 11 deletions engineio/server.v3.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ func (v3 *serverV3) serveTransport(w http.ResponseWriter, r *http.Request) (tran
sessionID := sessionIDFrom(r)
if sessionID == "" {
sessionID = v3.generateID()

transportName := transportNameFrom(r)
transport = v3.transports[transportName](sessionID, v3.codec)
if err := v3.sessions.Set(transport); err != nil {
Expand All @@ -112,27 +113,27 @@ func (v3 *serverV3) serveTransport(w http.ResponseWriter, r *http.Request) (tran
}
}

var isProbeOnInit bool
var fnOnUpgrade func() error
transport, isProbeOnInit, fnOnUpgrade, err = v3.doUpgrade(v3.sessions.Get(sessionID))(w, r)
if err != nil {
return nil, err
upgrade := v3.doUpgrade(v3.sessions.Get(sessionID))(w, r)
if upgrade.err != nil {
return nil, upgrade.err
}

var opts []eiot.Option
if isProbeOnInit {
opts = []eiot.Option{eiot.OnInitProbe(isProbeOnInit)}
if upgrade.isProbeOnInit {
opts = []eiot.Option{eiot.OnInitProbe(upgrade.isProbeOnInit)}
}
if fnOnUpgrade != nil {
opts = []eiot.Option{eiot.OnUpgrade(fnOnUpgrade)}
if upgrade.upgradeFn != nil {
opts = []eiot.Option{eiot.OnUpgrade(upgrade.upgradeFn)}
}

ctx = v3.sessions.WithInterval(ctx, v3.pingInterval)
ctx = v3.sessions.WithTimeout(ctx, v3.pingTimeout)

go func() { v3.transportRunError <- transport.Run(w, r.WithContext(ctx), append(v3.eto, opts...)...) }()
go func() {
v3.transportRunError <- upgrade.transport.Run(w, r.WithContext(ctx), append(v3.eto, opts...)...)
}()

return
return upgrade.transport, nil
}

func (v3 *serverV3) handshakePacket(sessionID SessionID, transportName TransportName) eiop.Packet {
Expand Down
19 changes: 8 additions & 11 deletions engineio/server.v4.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (v4 *serverV4) serveTransport(w http.ResponseWriter, r *http.Request) (tran
sessionID = v4.generateID()
ctx = context.WithValue(ctx, ctxSessionID, sessionID)

transportName, _ := r.Context().Value(ctxTransportName).(TransportName)
transportName := transportNameFrom(r)
transport = v4.transports[transportName](sessionID, v4.codec)
if err := v4.sessions.Set(transport); err != nil {
return nil, err
Expand All @@ -98,28 +98,25 @@ func (v4 *serverV4) serveTransport(w http.ResponseWriter, r *http.Request) (tran
}
}

var isProbeOnInit bool
var fnOnUpgrade func() error
transport, _, fnOnUpgrade, err = v4.doUpgrade(v4.sessions.Get(sessionID))(w, r)
upgrade := v4.doUpgrade(v4.sessions.Get(sessionID))(w, r)
if err != nil {
return nil, err
}

var opts []eiot.Option
if isProbeOnInit {
opts = []eiot.Option{eiot.OnInitProbe(isProbeOnInit)}
}
if fnOnUpgrade != nil {
opts = []eiot.Option{eiot.OnUpgrade(fnOnUpgrade)}
if upgrade.upgradeFn != nil {
opts = []eiot.Option{eiot.OnUpgrade(upgrade.upgradeFn)}
}

ctx = v4.sessions.WithCancel(ctx)
ctx = v4.sessions.WithInterval(ctx, v4.pingInterval)
ctx = v4.sessions.WithTimeout(ctx, v4.pingTimeout)

go func() { v4.transportRunError <- transport.Run(w, r.WithContext(ctx), append(v4.eto, opts...)...) }()
go func() {
v4.transportRunError <- upgrade.transport.Run(w, r.WithContext(ctx), append(v4.eto, opts...)...)
}()

return
return upgrade.transport, err
}

func (v4 *serverV4) handshakePacket(sessionID SessionID, transportName TransportName) eiop.Packet {
Expand Down
2 changes: 1 addition & 1 deletion engineio/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func storeDuration(addr *time.Duration, val time.Duration) {
atomic.StoreInt64((*int64)(addr), int64(val))
}

type mapSessions interface {
type transportSessions interface {
Set(eiot.Transporter) error
Get(SessionID) (eiot.Transporter, error)

Expand Down

0 comments on commit 5de2c30

Please sign in to comment.