diff --git a/mem.go b/mem.go index a024c08c..be19ef5d 100644 --- a/mem.go +++ b/mem.go @@ -167,22 +167,27 @@ func (m *Message) NumSegments() int64 { // Segment returns the segment with the given ID. func (m *Message) Segment(id SegmentID) (*Segment, error) { - m.mu.Lock() - defer m.mu.Unlock() + var seg *Segment if isInt32Bit() && id > maxInt32 { return nil, errSegment32Bit } - if seg := m.segment(id); seg != nil { + m.mu.Lock() + if seg = m.segment(id); seg != nil { + m.mu.Unlock() return seg, nil } if int64(id) >= m.Arena.NumSegments() { + m.mu.Unlock() return nil, errSegmentOutOfBounds } data, err := m.Arena.Data(id) if err != nil { + m.mu.Unlock() return nil, err } - return m.setSegment(id, data), nil + seg = m.setSegment(id, data) + m.mu.Unlock() + return seg, nil } // segment returns the segment with the given ID. @@ -230,19 +235,23 @@ func (m *Message) setSegment(id SegmentID, data []byte) *Segment { // cap(seg.Data) - len(seg.Data) >= sz. func (m *Message) allocSegment(sz Size) (*Segment, error) { m.mu.Lock() - defer m.mu.Unlock() + var seg *Segment if m.segs == nil && m.firstSeg.msg != nil { m.segs = make(map[SegmentID]*Segment) m.segs[0] = &m.firstSeg } id, data, err := m.Arena.Allocate(sz, m.segs) if err != nil { + m.mu.Unlock() return nil, err } if isInt32Bit() && id > maxInt32 { + m.mu.Unlock() return nil, errSegment32Bit } - return m.setSegment(id, data), nil + seg = m.setSegment(id, data) + m.mu.Unlock() + return seg, nil } // alloc allocates sz zero-filled bytes. It prefers using s, but may