util/ring: generic-ify ring.Buffer
Epic: none

Release note: None
ajwerner committed Dec 19, 2022
1 parent 7fd3f09 commit 3f3218d
Showing 15 changed files with 90 additions and 138 deletions.
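The generic container itself is not part of this excerpt, so for orientation, here is a minimal sketch of what the post-change pkg/util/ring API plausibly looks like. The method set (Len, Get, GetFirst, GetLast, AddLast, RemoveFirst, RemoveLast, Discard) is inferred purely from the call sites in the hunks below; the backing layout, growth policy, and Discard semantics are assumptions, not the repository's actual implementation.

package ring

// Buffer is a double-ended queue backed by a circular slice. The zero
// value is assumed ready to use, matching how fields are declared below.
type Buffer[T any] struct {
	buf  []T
	head int // index in buf of the logically first element
	n    int // number of elements currently stored
}

// Len returns the number of stored elements.
func (b *Buffer[T]) Len() int { return b.n }

// Get returns the i-th element, 0 being the oldest; panics if out of range.
func (b *Buffer[T]) Get(i int) T {
	if i < 0 || i >= b.n {
		panic("ring: index out of range")
	}
	return b.buf[(b.head+i)%len(b.buf)]
}

// GetFirst and GetLast return the elements at the two ends.
func (b *Buffer[T]) GetFirst() T { return b.Get(0) }
func (b *Buffer[T]) GetLast() T  { return b.Get(b.n - 1) }

// AddLast appends v at the back, growing the backing slice when full.
func (b *Buffer[T]) AddLast(v T) {
	if b.n == len(b.buf) {
		b.grow()
	}
	b.buf[(b.head+b.n)%len(b.buf)] = v
	b.n++
}

// RemoveFirst drops the oldest element; panics when empty.
func (b *Buffer[T]) RemoveFirst() {
	if b.n == 0 {
		panic("ring: removing from an empty buffer")
	}
	var zero T
	b.buf[b.head] = zero // let the GC reclaim what was stored
	b.head = (b.head + 1) % len(b.buf)
	b.n--
}

// RemoveLast drops the newest element; panics when empty.
func (b *Buffer[T]) RemoveLast() {
	if b.n == 0 {
		panic("ring: removing from an empty buffer")
	}
	var zero T
	b.buf[(b.head+b.n-1)%len(b.buf)] = zero
	b.n--
}

// Discard empties the buffer; whether the real implementation also
// releases the backing slice is left unverified here.
func (b *Buffer[T]) Discard() {
	var zero T
	for i := 0; i < b.n; i++ {
		b.buf[(b.head+i)%len(b.buf)] = zero
	}
	b.head, b.n = 0, 0
}

func (b *Buffer[T]) grow() {
	next := make([]T, 2*len(b.buf)+1)
	for i := 0; i < b.n; i++ {
		next[i] = b.buf[(b.head+i)%len(b.buf)]
	}
	b.buf, b.head = next, 0
}

The payoff is visible in every hunk that follows: Get, GetFirst, and GetLast return T rather than interface{}, so the .(Command), .(*deadFlowsOnNode), .(eval.IndexedRow), .(*indexedValue), and .(*peerGroup) assertions disappear, and value types can be stored inline without boxing.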
4 changes: 2 additions & 2 deletions pkg/sql/conn_io.go
@@ -97,7 +97,7 @@ type StmtBuf struct {
cond *sync.Cond

// data contains the elements of the buffer.
- data ring.Buffer // []Command
+ data ring.Buffer[Command]

// startPos indicates the index of the first command currently in data
// relative to the start of the connection.
@@ -459,7 +459,7 @@ func (buf *StmtBuf) CurCmd() (Command, CmdPos, error) {
}
len := buf.mu.data.Len()
if cmdIdx < len {
- return buf.mu.data.Get(cmdIdx).(Command), curPos, nil
+ return buf.mu.data.Get(cmdIdx), curPos, nil
}
if cmdIdx != len {
return nil, 0, errors.AssertionFailedf(
6 changes: 3 additions & 3 deletions pkg/sql/distsql_running.go
@@ -282,7 +282,7 @@ type cancelFlowsCoordinator struct {
mu struct {
syncutil.Mutex
// deadFlowsByNode is a ring of pointers to deadFlowsOnNode objects.
- deadFlowsByNode ring.Buffer
+ deadFlowsByNode ring.Buffer[*deadFlowsOnNode]
}
// workerWait should be used by canceling workers to block until there are
// some dead flows to cancel.
@@ -301,7 +301,7 @@ func (c *cancelFlowsCoordinator) getFlowsToCancel() (
if c.mu.deadFlowsByNode.Len() == 0 {
return nil, base.SQLInstanceID(0)
}
- deadFlows := c.mu.deadFlowsByNode.GetFirst().(*deadFlowsOnNode)
+ deadFlows := c.mu.deadFlowsByNode.GetFirst()
c.mu.deadFlowsByNode.RemoveFirst()
req := &execinfrapb.CancelDeadFlowsRequest{
FlowIDs: deadFlows.ids,
@@ -322,7 +322,7 @@ func (c *cancelFlowsCoordinator) addFlowsToCancel(
// sufficiently fast.
found := false
for j := 0; j < c.mu.deadFlowsByNode.Len(); j++ {
- deadFlows := c.mu.deadFlowsByNode.Get(j).(*deadFlowsOnNode)
+ deadFlows := c.mu.deadFlowsByNode.Get(j)
if sqlInstanceID == deadFlows.sqlInstanceID {
deadFlows.ids = append(deadFlows.ids, f.FlowID)
found = true
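A note on the access pattern above: deadFlowsByNode serves as a small FIFO that is popped from the front in getFlowsToCancel and scanned linearly in addFlowsToCancel. A toy sketch of the pop-front usage, with a hypothetical stand-in type (deadFlows below is not the real deadFlowsOnNode); only the Buffer calls mirror the diff:

package main

import (
	"fmt"

	"github.com/cockroachdb/cockroach/pkg/util/ring"
)

// deadFlows stands in for deadFlowsOnNode: the flow IDs accumulated for
// one SQL instance.
type deadFlows struct {
	instanceID int
	ids        []string
}

func main() {
	var q ring.Buffer[*deadFlows]
	q.AddLast(&deadFlows{instanceID: 2, ids: []string{"f1", "f2"}})
	q.AddLast(&deadFlows{instanceID: 3, ids: []string{"f3"}})

	// getFlowsToCancel's shape: pop the oldest batch off the front.
	next := q.GetFirst()
	q.RemoveFirst()
	fmt.Println(next.instanceID, next.ids, q.Len()) // 2 [f1 f2] 1
}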
2 changes: 1 addition & 1 deletion pkg/sql/distsql_running_test.go
@@ -483,7 +483,7 @@ func TestCancelFlowsCoordinator(t *testing.T) {
require.GreaterOrEqual(t, numNodes-1, c.mu.deadFlowsByNode.Len())
seen := make(map[base.SQLInstanceID]struct{})
for i := 0; i < c.mu.deadFlowsByNode.Len(); i++ {
- deadFlows := c.mu.deadFlowsByNode.Get(i).(*deadFlowsOnNode)
+ deadFlows := c.mu.deadFlowsByNode.Get(i)
require.NotEqual(t, gatewaySQLInstanceID, deadFlows.sqlInstanceID)
_, ok := seen[deadFlows.sqlInstanceID]
require.False(t, ok)
61 changes: 7 additions & 54 deletions pkg/sql/pgwire/conn.go
@@ -1247,69 +1247,22 @@ type flushInfo struct {
// cmdStarts maintains the state about where the results for the respective
// positions begin. We utilize the invariant that positions are
// monotonically increasing sequences.
- cmdStarts cmdIdxBuffer
+ cmdStarts ring.Buffer[cmdIdx]
}

type cmdIdx struct {
pos sql.CmdPos
idx int
}

- var cmdIdxPool = sync.Pool{
- New: func() interface{} {
- return &cmdIdx{}
- },
- }
-
- func (c *cmdIdx) release() {
- *c = cmdIdx{}
- cmdIdxPool.Put(c)
- }
-
- type cmdIdxBuffer struct {
- // We intentionally do not just embed ring.Buffer in order to restrict the
- // methods that can be called on cmdIdxBuffer.
- buffer ring.Buffer
- }
-
- func (b *cmdIdxBuffer) empty() bool {
- return b.buffer.Len() == 0
- }
-
- func (b *cmdIdxBuffer) addLast(pos sql.CmdPos, idx int) {
- cmdIdx := cmdIdxPool.Get().(*cmdIdx)
- cmdIdx.pos = pos
- cmdIdx.idx = idx
- b.buffer.AddLast(cmdIdx)
- }
-
- // removeLast removes the last cmdIdx from the buffer and will panic if the
- // buffer is empty.
- func (b *cmdIdxBuffer) removeLast() {
- b.getLast().release()
- b.buffer.RemoveLast()
- }
-
- // getLast returns the last cmdIdx in the buffer and will panic if the buffer is
- // empty.
- func (b *cmdIdxBuffer) getLast() *cmdIdx {
- return b.buffer.GetLast().(*cmdIdx)
- }
-
- func (b *cmdIdxBuffer) clear() {
- for !b.empty() {
- b.removeLast()
- }
- }

// registerCmd updates cmdStarts buffer when the first result for a new command
// is received.
func (fi *flushInfo) registerCmd(pos sql.CmdPos) {
- if !fi.cmdStarts.empty() && fi.cmdStarts.getLast().pos >= pos {
+ if fi.cmdStarts.Len() > 0 && fi.cmdStarts.GetLast().pos >= pos {
// Not a new command, nothing to do.
return
}
- fi.cmdStarts.addLast(pos, fi.buf.Len())
+ fi.cmdStarts.AddLast(cmdIdx{pos: pos, idx: fi.buf.Len()})
}

func cookTag(
@@ -1682,7 +1635,7 @@ func (c *conn) Flush(pos sql.CmdPos) error {

c.writerState.fi.lastFlushed = pos
// Make sure that the entire cmdStarts buffer is drained.
- c.writerState.fi.cmdStarts.clear()
+ c.writerState.fi.cmdStarts.Discard()

_ /* n */, err := c.writerState.buf.WriteTo(c.conn)
if err != nil {
@@ -1756,13 +1709,13 @@ func (cl *clientConnLock) RTrim(ctx context.Context, pos sql.CmdPos) {
truncateIdx := cl.buf.Len()
// Update cmdStarts buffer: delete commands that were trimmed from the back
// of the cmdStarts buffer.
- for !cl.cmdStarts.empty() {
- cmdStart := cl.cmdStarts.getLast()
+ for cl.cmdStarts.Len() > 0 {
+ cmdStart := cl.cmdStarts.GetLast()
if cmdStart.pos < pos {
break
}
truncateIdx = cmdStart.idx
- cl.cmdStarts.removeLast()
+ cl.cmdStarts.RemoveLast()
}
cl.buf.Truncate(truncateIdx)
}
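Easy to miss in the hunks above: the sync.Pool was deleted outright, not just the wrapper type. The old buffer stored interface{}, so every cmdIdx had to sit behind a heap pointer, which is what made pooling worthwhile; Buffer[cmdIdx] stores the struct values inline in its backing slice. A hypothetical sketch of the new call pattern (cmdPos stands in for sql.CmdPos; only the Buffer methods come from the diff):

package main

import (
	"fmt"

	"github.com/cockroachdb/cockroach/pkg/util/ring"
)

type cmdPos int

// cmdIdx mirrors the struct kept in conn.go.
type cmdIdx struct {
	pos cmdPos
	idx int
}

func main() {
	var cmdStarts ring.Buffer[cmdIdx]

	// registerCmd's shape: values go in by value, no pool, no boxing.
	cmdStarts.AddLast(cmdIdx{pos: 1, idx: 0})
	cmdStarts.AddLast(cmdIdx{pos: 2, idx: 128})

	// getLast's shape: no type assertion on the way out.
	fmt.Println(cmdStarts.GetLast().pos) // 2

	// The hand-rolled clear() loop became a single Discard call.
	cmdStarts.Discard()
	fmt.Println(cmdStarts.Len()) // 0
}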
6 changes: 3 additions & 3 deletions pkg/sql/rowcontainer/row_container.go
@@ -650,7 +650,7 @@ type DiskBackedIndexedRowContainer struct {
firstCachedRowPos int
nextPosToCache int
// indexedRowsCache is the cache of up to maxCacheSize contiguous rows.
- indexedRowsCache ring.Buffer
+ indexedRowsCache ring.Buffer[eval.IndexedRow]
// maxCacheSize indicates the maximum number of rows to be cached. It is
// initialized to maxIndexedRowsCacheSize and dynamically adjusted if OOM
// error is encountered.
@@ -783,7 +783,7 @@ func (f *DiskBackedIndexedRowContainer) GetRow(
if pos >= f.firstCachedRowPos && pos < f.nextPosToCache {
requestedRowCachePos := pos - f.firstCachedRowPos
f.hitCount++
- return f.indexedRowsCache.Get(requestedRowCachePos).(eval.IndexedRow), nil
+ return f.indexedRowsCache.Get(requestedRowCachePos), nil
}
f.missCount++
if f.diskRowIter == nil {
@@ -860,7 +860,7 @@ func (f *DiskBackedIndexedRowContainer) GetRow(
return nil, errors.Errorf("unexpected last column type: should be DInt but found %T", idx)
}
if f.idxRowIter == pos {
- return f.indexedRowsCache.GetLast().(eval.IndexedRow), nil
+ return f.indexedRowsCache.GetLast(), nil
}
}
f.idxRowIter++
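The hit path above works because the ring buffer always holds a contiguous run of rows, so a container-wide position maps to a buffer offset by subtracting firstCachedRowPos. A toy illustration of that indexing under the same assumption (strings stand in for eval.IndexedRow; the container's real caching and eviction policy is more involved):

package main

import (
	"fmt"

	"github.com/cockroachdb/cockroach/pkg/util/ring"
)

func main() {
	var cache ring.Buffer[string] // stand-in for Buffer[eval.IndexedRow]
	firstCachedRowPos := 100      // container position of cache.Get(0)

	for i := 0; i < 4; i++ {
		cache.AddLast(fmt.Sprintf("row-%d", firstCachedRowPos+i))
	}

	// GetRow's hit path: positions [100, 104) are served from the cache.
	pos := 102
	if pos >= firstCachedRowPos && pos < firstCachedRowPos+cache.Len() {
		fmt.Println(cache.Get(pos - firstCachedRowPos)) // row-102
	}

	// Evicting from the front keeps the run contiguous: bump the base.
	cache.RemoveFirst()
	firstCachedRowPos++
	fmt.Println(cache.Get(0)) // row-101
}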
10 changes: 5 additions & 5 deletions pkg/sql/sem/builtins/window_frame_builtins.go
@@ -37,7 +37,7 @@ type indexedValue struct {
// It assumes that the frame bounds will never go back, i.e. non-decreasing
// sequences of frame start and frame end indices.
type slidingWindow struct {
- values ring.Buffer
+ values ring.Buffer[*indexedValue]
evalCtx *eval.Context
cmp func(*eval.Context, tree.Datum, tree.Datum) int
}
@@ -58,7 +58,7 @@ func makeSlidingWindow(
// largest idx).
func (sw *slidingWindow) add(iv *indexedValue) {
for i := sw.values.Len() - 1; i >= 0; i-- {
- if sw.cmp(sw.evalCtx, sw.values.Get(i).(*indexedValue).value, iv.value) > 0 {
+ if sw.cmp(sw.evalCtx, sw.values.Get(i).value, iv.value) > 0 {
break
}
sw.values.RemoveLast()
@@ -70,15 +70,15 @@ func (sw *slidingWindow) add(iv *indexedValue) {
// indices smaller than given 'idx'. This operation corresponds to shifting the
// start of the frame up to 'idx'.
func (sw *slidingWindow) removeAllBefore(idx int) {
- for sw.values.Len() > 0 && sw.values.Get(0).(*indexedValue).idx < idx {
+ for sw.values.Len() > 0 && sw.values.Get(0).idx < idx {
sw.values.RemoveFirst()
}
}

func (sw *slidingWindow) string() string {
var builder strings.Builder
for i := 0; i < sw.values.Len(); i++ {
- builder.WriteString(fmt.Sprintf("(%v, %v)\t", sw.values.Get(i).(*indexedValue).value, sw.values.Get(i).(*indexedValue).idx))
+ builder.WriteString(fmt.Sprintf("(%v, %v)\t", sw.values.Get(i).value, sw.values.Get(i).idx))
}
return builder.String()
}
@@ -175,7 +175,7 @@ func (w *slidingWindowFunc) Compute(

// The datum with "highest priority" within the frame is at the very front
// of the deque.
- return w.sw.values.GetFirst().(*indexedValue).value, nil
+ return w.sw.values.GetFirst().value, nil
}

func max(a, b int) int {
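slidingWindow is the classic monotonic-deque construction: add pops dominated values off the back, removeAllBefore pops expired indices off the front, and the current maximum is always GetFirst. A self-contained sketch of the same algorithm over plain ints, independent of the evalCtx/Datum machinery (a fixed window of size k corresponds to frame bounds that only move forward):

package main

import "fmt"

// slidingMax returns the maximum of every k-wide window over xs, keeping
// a deque of indices whose values are strictly decreasing.
func slidingMax(xs []int, k int) []int {
	var deque []int // indices into xs
	var out []int
	for i, x := range xs {
		// Mirror of slidingWindow.add: drop values the newcomer dominates.
		for len(deque) > 0 && xs[deque[len(deque)-1]] <= x {
			deque = deque[:len(deque)-1]
		}
		deque = append(deque, i)
		// Mirror of removeAllBefore: drop indices that left the frame.
		if deque[0] <= i-k {
			deque = deque[1:]
		}
		if i >= k-1 {
			out = append(out, xs[deque[0]]) // the front holds the max
		}
	}
	return out
}

func main() {
	fmt.Println(slidingMax([]int{1, 3, 2, 5, 4, 1}, 3)) // [3 5 5 5]
}

Each element enters and leaves the deque at most once, so the whole pass is O(n) even though an individual add call can pop many entries.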
8 changes: 4 additions & 4 deletions pkg/sql/sem/eval/window_funcs_util.go
@@ -38,7 +38,7 @@ type peerGroup struct {
// offsets if we have OFFSET_FOLLOWING type of bound (both F and O are
// upper-bounded by total number of peer groups).
type PeerGroupsIndicesHelper struct {
- groups ring.Buffer // queue of peer groups
+ groups ring.Buffer[*peerGroup]
peerGrouper PeerGroupChecker
headPeerGroupNum int // number of the peer group at the head of the queue
allPeerGroupsSkipped bool // in GROUP mode, indicates whether all peer groups were skipped during Init
@@ -161,7 +161,7 @@ func (p *PeerGroupsIndicesHelper) Update(wfr *WindowFrameRun) error {

// nextPeerGroupStartIdx is the index of the first row that we haven't
// computed peer group for.
- lastPeerGroup := p.groups.GetLast().(*peerGroup)
+ lastPeerGroup := p.groups.GetLast()
nextPeerGroupStartIdx := lastPeerGroup.firstPeerIdx + lastPeerGroup.rowCount

if (wfr.Frame == nil || wfr.Frame.Mode == treewindow.ROWS || wfr.Frame.Mode == treewindow.RANGE) ||
@@ -211,7 +211,7 @@ func (p *PeerGroupsIndicesHelper) GetFirstPeerIdx(peerGroupNum int) int {
if posInBuffer < 0 || p.groups.Len() < posInBuffer {
panic("peerGroupNum out of bounds")
}
- return p.groups.Get(posInBuffer).(*peerGroup).firstPeerIdx
+ return p.groups.Get(posInBuffer).firstPeerIdx
}

// GetRowCount returns the number of rows within peer group of number
@@ -221,7 +221,7 @@ func (p *PeerGroupsIndicesHelper) GetRowCount(peerGroupNum int) int {
if posInBuffer < 0 || p.groups.Len() < posInBuffer {
panic("peerGroupNum out of bounds")
}
- return p.groups.Get(posInBuffer).(*peerGroup).rowCount
+ return p.groups.Get(posInBuffer).rowCount
}

// GetLastPeerGroupNum returns the number of the last peer group in the queue.
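The positional lookups above use the same offset trick as the row cache: the queue only ever holds peer groups numbered [headPeerGroupNum, headPeerGroupNum+Len), so a group number resolves to a buffer slot by subtracting the head number. A toy sketch (peerGroup fields as in this file; the helper's real Init/Update skipping logic is omitted):

package main

import (
	"fmt"

	"github.com/cockroachdb/cockroach/pkg/util/ring"
)

// peerGroup mirrors the struct at the top of window_funcs_util.go.
type peerGroup struct {
	firstPeerIdx int
	rowCount     int
}

func main() {
	var groups ring.Buffer[*peerGroup]
	headPeerGroupNum := 5 // groups 0 through 4 were already dequeued

	groups.AddLast(&peerGroup{firstPeerIdx: 40, rowCount: 3}) // group 5
	groups.AddLast(&peerGroup{firstPeerIdx: 43, rowCount: 2}) // group 6

	// GetFirstPeerIdx(6) resolves to buffer position 6 - 5 = 1.
	peerGroupNum := 6
	posInBuffer := peerGroupNum - headPeerGroupNum
	fmt.Println(groups.Get(posInBuffer).firstPeerIdx) // 43
}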