Skip to content

Commit

Permalink
use a generic streams map for outgoing streams
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Aug 2, 2022
1 parent 936a585 commit 67c30d3
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 510 deletions.
10 changes: 6 additions & 4 deletions streams_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ type streamsMap struct {
newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController

mutex sync.Mutex
outgoingBidiStreams *outgoingBidiStreamsMap
outgoingUniStreams *outgoingUniStreamsMap
outgoingBidiStreams *outgoingStreamsMap[streamI]
outgoingUniStreams *outgoingStreamsMap[sendStreamI]
incomingBidiStreams *incomingBidiStreamsMap
incomingUniStreams *incomingUniStreamsMap
reset bool
Expand Down Expand Up @@ -85,7 +85,8 @@ func newStreamsMap(
}

func (m *streamsMap) initMaps() {
m.outgoingBidiStreams = newOutgoingBidiStreamsMap(
m.outgoingBidiStreams = newOutgoingStreamsMap(
protocol.StreamTypeBidi,
func(num protocol.StreamNum) streamI {
id := num.StreamID(protocol.StreamTypeBidi, m.perspective)
return newStream(id, m.sender, m.newFlowController(id), m.version)
Expand All @@ -100,7 +101,8 @@ func (m *streamsMap) initMaps() {
m.maxIncomingBidiStreams,
m.sender.queueControlFrame,
)
m.outgoingUniStreams = newOutgoingUniStreamsMap(
m.outgoingUniStreams = newOutgoingStreamsMap(
protocol.StreamTypeUni,
func(num protocol.StreamNum) sendStreamI {
id := num.StreamID(protocol.StreamTypeUni, m.perspective)
return newSendStream(id, m.sender, m.newFlowController(id), m.version)
Expand Down
64 changes: 34 additions & 30 deletions streams_map_outgoing_uni.go → streams_map_outgoing.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
// This file was automatically generated by genny.
// Any changes will be lost if this file is regenerated.
// see https://github.com/cheekybits/genny

package quic

import (
Expand All @@ -12,10 +8,16 @@ import (
"github.com/lucas-clemente/quic-go/internal/wire"
)

type outgoingUniStreamsMap struct {
type outgoingStream interface {
updateSendWindow(protocol.ByteCount)
closeForShutdown(error)
}

type outgoingStreamsMap[T outgoingStream] struct {
mutex sync.RWMutex

streams map[protocol.StreamNum]sendStreamI
streamType protocol.StreamType
streams map[protocol.StreamNum]T

openQueue map[uint64]chan struct{}
lowestInQueue uint64
Expand All @@ -25,18 +27,20 @@ type outgoingUniStreamsMap struct {
maxStream protocol.StreamNum // the maximum stream ID we're allowed to open
blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream

newStream func(protocol.StreamNum) sendStreamI
newStream func(protocol.StreamNum) T
queueStreamIDBlocked func(*wire.StreamsBlockedFrame)

closeErr error
}

func newOutgoingUniStreamsMap(
newStream func(protocol.StreamNum) sendStreamI,
func newOutgoingStreamsMap[T outgoingStream](
streamType protocol.StreamType,
newStream func(protocol.StreamNum) T,
queueControlFrame func(wire.Frame),
) *outgoingUniStreamsMap {
return &outgoingUniStreamsMap{
streams: make(map[protocol.StreamNum]sendStreamI),
) *outgoingStreamsMap[T] {
return &outgoingStreamsMap[T]{
streamType: streamType,
streams: make(map[protocol.StreamNum]T),
openQueue: make(map[uint64]chan struct{}),
maxStream: protocol.InvalidStreamNum,
nextStream: 1,
Expand All @@ -45,32 +49,32 @@ func newOutgoingUniStreamsMap(
}
}

func (m *outgoingUniStreamsMap) OpenStream() (sendStreamI, error) {
func (m *outgoingStreamsMap[T]) OpenStream() (T, error) {
m.mutex.Lock()
defer m.mutex.Unlock()

if m.closeErr != nil {
return nil, m.closeErr
return *new(T), m.closeErr
}

// if there are OpenStreamSync calls waiting, return an error here
if len(m.openQueue) > 0 || m.nextStream > m.maxStream {
m.maybeSendBlockedFrame()
return nil, streamOpenErr{errTooManyOpenStreams}
return *new(T), streamOpenErr{errTooManyOpenStreams}
}
return m.openStream(), nil
}

func (m *outgoingUniStreamsMap) OpenStreamSync(ctx context.Context) (sendStreamI, error) {
func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) {
m.mutex.Lock()
defer m.mutex.Unlock()

if m.closeErr != nil {
return nil, m.closeErr
return *new(T), m.closeErr
}

if err := ctx.Err(); err != nil {
return nil, err
return *new(T), err
}

if len(m.openQueue) == 0 && m.nextStream <= m.maxStream {
Expand All @@ -92,13 +96,13 @@ func (m *outgoingUniStreamsMap) OpenStreamSync(ctx context.Context) (sendStreamI
case <-ctx.Done():
m.mutex.Lock()
delete(m.openQueue, queuePos)
return nil, ctx.Err()
return *new(T), ctx.Err()
case <-waitChan:
}
m.mutex.Lock()

if m.closeErr != nil {
return nil, m.closeErr
return *new(T), m.closeErr
}
if m.nextStream > m.maxStream {
// no stream available. Continue waiting
Expand All @@ -112,7 +116,7 @@ func (m *outgoingUniStreamsMap) OpenStreamSync(ctx context.Context) (sendStreamI
}
}

func (m *outgoingUniStreamsMap) openStream() sendStreamI {
func (m *outgoingStreamsMap[T]) openStream() T {
s := m.newStream(m.nextStream)
m.streams[m.nextStream] = s
m.nextStream++
Expand All @@ -121,7 +125,7 @@ func (m *outgoingUniStreamsMap) openStream() sendStreamI {

// maybeSendBlockedFrame queues a STREAMS_BLOCKED frame for the current stream offset,
// if we haven't sent one for this offset yet
func (m *outgoingUniStreamsMap) maybeSendBlockedFrame() {
func (m *outgoingStreamsMap[T]) maybeSendBlockedFrame() {
if m.blockedSent {
return
}
Expand All @@ -131,17 +135,17 @@ func (m *outgoingUniStreamsMap) maybeSendBlockedFrame() {
streamNum = m.maxStream
}
m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{
Type: protocol.StreamTypeUni,
Type: m.streamType,
StreamLimit: streamNum,
})
m.blockedSent = true
}

func (m *outgoingUniStreamsMap) GetStream(num protocol.StreamNum) (sendStreamI, error) {
func (m *outgoingStreamsMap[T]) GetStream(num protocol.StreamNum) (T, error) {
m.mutex.RLock()
if num >= m.nextStream {
m.mutex.RUnlock()
return nil, streamError{
return *new(T), streamError{
message: "peer attempted to open stream %d",
nums: []protocol.StreamNum{num},
}
Expand All @@ -151,7 +155,7 @@ func (m *outgoingUniStreamsMap) GetStream(num protocol.StreamNum) (sendStreamI,
return s, nil
}

func (m *outgoingUniStreamsMap) DeleteStream(num protocol.StreamNum) error {
func (m *outgoingStreamsMap[T]) DeleteStream(num protocol.StreamNum) error {
m.mutex.Lock()
defer m.mutex.Unlock()

Expand All @@ -165,7 +169,7 @@ func (m *outgoingUniStreamsMap) DeleteStream(num protocol.StreamNum) error {
return nil
}

func (m *outgoingUniStreamsMap) SetMaxStream(num protocol.StreamNum) {
func (m *outgoingStreamsMap[T]) SetMaxStream(num protocol.StreamNum) {
m.mutex.Lock()
defer m.mutex.Unlock()

Expand All @@ -183,7 +187,7 @@ func (m *outgoingUniStreamsMap) SetMaxStream(num protocol.StreamNum) {
// UpdateSendWindow is called when the peer's transport parameters are received.
// Only in the case of a 0-RTT handshake will we have open streams at this point.
// We might need to update the send window, in case the server increased it.
func (m *outgoingUniStreamsMap) UpdateSendWindow(limit protocol.ByteCount) {
func (m *outgoingStreamsMap[T]) UpdateSendWindow(limit protocol.ByteCount) {
m.mutex.Lock()
for _, str := range m.streams {
str.updateSendWindow(limit)
Expand All @@ -192,7 +196,7 @@ func (m *outgoingUniStreamsMap) UpdateSendWindow(limit protocol.ByteCount) {
}

// unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream
func (m *outgoingUniStreamsMap) unblockOpenSync() {
func (m *outgoingStreamsMap[T]) unblockOpenSync() {
if len(m.openQueue) == 0 {
return
}
Expand All @@ -211,7 +215,7 @@ func (m *outgoingUniStreamsMap) unblockOpenSync() {
}
}

func (m *outgoingUniStreamsMap) CloseWithError(err error) {
func (m *outgoingStreamsMap[T]) CloseWithError(err error) {
m.mutex.Lock()
m.closeErr = err
for _, str := range m.streams {
Expand Down
Loading

0 comments on commit 67c30d3

Please sign in to comment.