Skip to content

Commit

Permalink
Use atomic types instead of raw values
Browse files Browse the repository at this point in the history
Per Go documentation recommendation (e.g.
[link](https://pkg.go.dev/sync/atomic#AddUint64)), use the `atomic`
types and their associated methods instead of the
`atomic.Add*`/`.Store*` functions.

This makes the intent clearer, prevents (accidental) non-atomic access,
and (for boolean variables) simplifies code.

Signed-off-by: Hamza El-Saawy <[email protected]>
  • Loading branch information
helsaawy committed Aug 21, 2024
1 parent 4f3da95 commit 2eb9244
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 57 deletions.
8 changes: 4 additions & 4 deletions internal/guest/bridge/bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,8 @@ type Bridge struct {
hostState *hcsv2.Host

quitChan chan bool
// hasQuitPending when != 0 will cause no more requests to be Read.
hasQuitPending uint32
// hasQuitPending indicates the bridge is shutting down and cause no more requests to be Read.
hasQuitPending atomic.Bool

protVer prot.ProtocolVersion
}
Expand Down Expand Up @@ -243,7 +243,7 @@ func (b *Bridge) ListenAndServe(bridgeIn io.ReadCloser, bridgeOut io.WriteCloser
go func() {
var recverr error
for {
if atomic.LoadUint32(&b.hasQuitPending) == 0 {
if !b.hasQuitPending.Load() {
header := &prot.MessageHeader{}
if err := binary.Read(bridgeIn, binary.LittleEndian, header); err != nil {
if err == io.ErrUnexpectedEOF || err == os.ErrClosed { //nolint:errorlint
Expand Down Expand Up @@ -405,7 +405,7 @@ func (b *Bridge) ListenAndServe(bridgeIn io.ReadCloser, bridgeOut io.WriteCloser
case <-b.quitChan:
// The request loop needs to exit so that the teardown process begins.
// Set the request loop to stop processing new messages
atomic.StoreUint32(&b.hasQuitPending, 1)
b.hasQuitPending.Store(true)
// Wait for the request loop to process its last message. Its possible
// that if it lost the race with the hasQuitPending it could be stuck in
// a pending read from bridgeIn. Wait 2 seconds and kill the connection.
Expand Down
16 changes: 10 additions & 6 deletions internal/guest/runtime/hcsv2/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,13 @@ type Container struct {
processesMutex sync.Mutex
processes map[uint32]*containerProcess

// Only access atomically through getStatus/setStatus.
status containerStatus
// current container (creation) status.
// Only access through [getStatus] and [setStatus].
//
// Note: its more ergonomic to store the uint32 and convert to/from [containerStatus]
// then use [atomic.Value] and deal with unsafe conversions to/from [any], or use [atomic.Pointer]
// and deal with the pointer dereferencing and extra storage.
status atomic.Uint32

// scratchDirPath represents the path inside the UVM where the scratch directory
// of this container is located. Usually, this is either `/run/gcs/c/<containerID>` or
Expand Down Expand Up @@ -268,17 +273,16 @@ func (c *Container) GetStats(ctx context.Context) (*v1.Metrics, error) {
return cg.Stat(cgroups.IgnoreNotExist)
}

func (c *Container) modifyContainerConstraints(ctx context.Context, rt guestrequest.RequestType, cc *guestresource.LCOWContainerConstraints) (err error) {
func (c *Container) modifyContainerConstraints(ctx context.Context, _ guestrequest.RequestType, cc *guestresource.LCOWContainerConstraints) (err error) {
return c.Update(ctx, cc.Linux)
}

func (c *Container) getStatus() containerStatus {
val := atomic.LoadUint32((*uint32)(&c.status))
return containerStatus(val)
return containerStatus(c.status.Load())
}

func (c *Container) setStatus(st containerStatus) {
atomic.StoreUint32((*uint32)(&c.status), uint32(st))
c.status.Store(uint32(st))
}

func (c *Container) ID() string {
Expand Down
2 changes: 1 addition & 1 deletion internal/guest/runtime/hcsv2/uvm.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,9 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM
isSandbox: criType == "sandbox",
exitType: prot.NtUnexpectedExit,
processes: make(map[uint32]*containerProcess),
status: containerCreating,
scratchDirPath: settings.ScratchDirPath,
}
c.setStatus(containerCreating)

if err := h.AddContainer(id, c); err != nil {
return nil, err
Expand Down
20 changes: 6 additions & 14 deletions internal/jobobject/jobobject.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ import (
// of the job and a mutex for synchronized handle access.
type JobObject struct {
handle windows.Handle
// All accesses to this MUST be done atomically except in `Open` as the object
// is being created in the function. 1 signifies that this job is currently a silo.
silo uint32
// silo signifies that this job is currently a silo.
silo atomic.Bool
mq *queue.MessageQueue
handleLock sync.RWMutex
}
Expand Down Expand Up @@ -204,9 +203,7 @@ func Open(ctx context.Context, options *Options) (_ *JobObject, err error) {
handle: jobHandle,
}

if isJobSilo(jobHandle) {
job.silo = 1
}
job.silo.Store(isJobSilo(jobHandle))

// If the IOCP we'll be using to receive messages for all jobs hasn't been
// created, create it and start polling.
Expand Down Expand Up @@ -479,7 +476,7 @@ func (job *JobObject) ApplyFileBinding(root, target string, readOnly bool) error
return ErrAlreadyClosed
}

if !job.isSilo() {
if !job.silo.Load() {
return ErrNotSilo
}

Expand Down Expand Up @@ -546,7 +543,7 @@ func (job *JobObject) PromoteToSilo() error {
return ErrAlreadyClosed
}

if job.isSilo() {
if job.silo.Load() {
return nil
}

Expand All @@ -569,15 +566,10 @@ func (job *JobObject) PromoteToSilo() error {
return fmt.Errorf("failed to promote job to silo: %w", err)
}

atomic.StoreUint32(&job.silo, 1)
job.silo.Store(true)
return nil
}

// isSilo returns if the job object is a silo.
func (job *JobObject) isSilo() bool {
return atomic.LoadUint32(&job.silo) == 1
}

// QueryPrivateWorkingSet returns the private working set size for the job. This is calculated by adding up the
// private working set for every process running in the job.
func (job *JobObject) QueryPrivateWorkingSet() (uint64, error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/jobobject/jobobject_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestSiloCreateAndOpen(t *testing.T) {
}
defer jobOpen.Close()

if !jobOpen.isSilo() {
if !jobOpen.silo.Load() {
t.Fatal("job is supposed to be a silo")
}
}
Expand Down
15 changes: 3 additions & 12 deletions internal/log/scrub.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,14 @@ var (
// case sensitive keywords, so "env" is not a substring on "Environment"
_scrubKeywords = [][]byte{[]byte("env"), []byte("Environment")}

_scrub int32
_scrub atomic.Bool
)

// SetScrubbing enables scrubbing
func SetScrubbing(enable bool) {
v := int32(0) // cant convert from bool to int32 directly
if enable {
v = 1
}
atomic.StoreInt32(&_scrub, v)
}
func SetScrubbing(enable bool) { _scrub.Store(enable) }

// IsScrubbingEnabled checks if scrubbing is enabled
func IsScrubbingEnabled() bool {
v := atomic.LoadInt32(&_scrub)
return v != 0
}
func IsScrubbingEnabled() bool { return _scrub.Load() }

// ScrubProcessParameters scrubs HCS Create Process requests with config parameters of
// type internal/hcs/schema2.ScrubProcessParameters (aka hcsshema.ScrubProcessParameters)
Expand Down
13 changes: 5 additions & 8 deletions internal/uvm/counter.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,15 @@

package uvm

import (
"sync/atomic"
)

// ContainerCounter is used for where we layout things for a container in
// a utility VM. For WCOW it'll be C:\c\N\. For LCOW it'll be /run/gcs/c/N/.
// ContainerCounter is used for where we layout things for a container in a utility VM.
// For WCOW it'll be C:\c\N\.
// For LCOW it'll be /run/gcs/c/N/.
func (uvm *UtilityVM) ContainerCounter() uint64 {
return atomic.AddUint64(&uvm.containerCounter, 1)
return uvm.containerCounter.Add(1)
}

// mountCounter is used for maintaining the number of mounts to the UVM.
// This helps in generating unique mount paths for every mount.
func (uvm *UtilityVM) UVMMountCounter() uint64 {
return atomic.AddUint64(&uvm.mountCounter, 1)
return uvm.mountCounter.Add(1)
}
12 changes: 5 additions & 7 deletions internal/uvm/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"net"
"sync"
"sync/atomic"

"github.com/Microsoft/go-winio/pkg/guid"
"golang.org/x/sys/windows"
Expand Down Expand Up @@ -56,11 +57,9 @@ type UtilityVM struct {
protocol uint32
guestCaps schema1.GuestDefinedCapabilities

// containerCounter is the current number of containers that have been
// created. This is never decremented in the life of the UVM.
//
// NOTE: All accesses to this MUST be done atomically.
containerCounter uint64
// containerCounter is the current number of containers that have been created.
// This is never decremented in the life of the UVM.
containerCounter atomic.Uint64

// noWritableFileShares disables mounting any writable vSMB or Plan9 shares
// on the uVM. This prevents containers in the uVM modifying files and directories
Expand Down Expand Up @@ -118,8 +117,7 @@ type UtilityVM struct {

// mountCounter is the number of mounts that have been added to the UVM
// This is used in generating a unique mount path inside the UVM for every mount.
// Access to this variable should be done atomically.
mountCounter uint64
mountCounter atomic.Uint64

// Location that container process dumps will get written too.
processDumpLocation string
Expand Down
10 changes: 6 additions & 4 deletions test/gcs/helper_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,13 @@ const (
// port numbers to assign to connections.
var (
_pipes sync.Map
_portNumber uint32 = 1
_portNumber atomic.Uint32
)

func init() {
_portNumber.Store(1) // start at port 1
}

type PipeTransport struct{}

var _ transport.Transport = &PipeTransport{}
Expand Down Expand Up @@ -250,9 +254,7 @@ func newConnectionSettings(in, out, err bool) stdio.ConnectionSettings {
return c
}

func nextPortNumber() uint32 {
return atomic.AddUint32(&_portNumber, 2)
}
func nextPortNumber() uint32 { return _portNumber.Add(2) }

func TestFakeSocket(t *testing.T) {
ctx := context.Background()
Expand Down

0 comments on commit 2eb9244

Please sign in to comment.