Skip to content

Commit

Permalink
Fix installflags change detection (#784)
Browse files Browse the repository at this point in the history
* Fix installFlags change detection

Signed-off-by: Kimmo Lehto <[email protected]>

* Add tests

Signed-off-by: Kimmo Lehto <[email protected]>

---------

Signed-off-by: Kimmo Lehto <[email protected]>
  • Loading branch information
kke authored Nov 6, 2024
1 parent 81260d3 commit fa88107
Show file tree
Hide file tree
Showing 9 changed files with 417 additions and 164 deletions.
62 changes: 62 additions & 0 deletions internal/shell/split.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package shell

// this is borrowed as-is from rig v2 until k0sctl is updated to use it

import (
"fmt"
"strings"
)

// Split splits the input string respecting shell-like quoted segments.
func Split(input string) ([]string, error) { //nolint:cyclop
var segments []string

currentSegment, ok := builderPool.Get().(*strings.Builder)
if !ok {
currentSegment = &strings.Builder{}
}
defer builderPool.Put(currentSegment)
defer currentSegment.Reset()

var inDoubleQuotes, inSingleQuotes, isEscaped bool

for i := range len(input) {
currentChar := input[i]

if isEscaped {
currentSegment.WriteByte(currentChar)
isEscaped = false
continue
}

switch {
case currentChar == '\\' && !inSingleQuotes:
isEscaped = true
case currentChar == '"' && !inSingleQuotes:
inDoubleQuotes = !inDoubleQuotes
case currentChar == '\'' && !inDoubleQuotes:
inSingleQuotes = !inSingleQuotes
case currentChar == ' ' && !inDoubleQuotes && !inSingleQuotes:
// Space outside quotes; delimiter for a new segment
segments = append(segments, currentSegment.String())
currentSegment.Reset()
default:
currentSegment.WriteByte(currentChar)
}
}

if inDoubleQuotes || inSingleQuotes {
return nil, fmt.Errorf("split `%q`: %w", input, ErrMismatchedQuotes)
}

if isEscaped {
return nil, fmt.Errorf("split `%q`: %w", input, ErrTrailingBackslash)
}

// Add the last segment if present
if currentSegment.Len() > 0 {
segments = append(segments, currentSegment.String())
}

return segments, nil
}
80 changes: 80 additions & 0 deletions internal/shell/unquote.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package shell

import (
"errors"
"fmt"
"strings"
"sync"
)

// This is borrowed from rig v2 until k0sctl is updated to use it

var (
builderPool = sync.Pool{
New: func() interface{} {
return &strings.Builder{}
},
}

// ErrMismatchedQuotes is returned when the input string has mismatched quotes when unquoting.
ErrMismatchedQuotes = errors.New("mismatched quotes")

// ErrTrailingBackslash is returned when the input string ends with a trailing backslash.
ErrTrailingBackslash = errors.New("trailing backslash")
)

// Unquote is a mostly POSIX compliant implementation of unquoting a string the same way a shell would.
// Variables and command substitutions are not handled.
func Unquote(input string) (string, error) { //nolint:cyclop
sb, ok := builderPool.Get().(*strings.Builder)
if !ok {
sb = &strings.Builder{}
}
defer builderPool.Put(sb)
defer sb.Reset()

var inDoubleQuotes, inSingleQuotes, isEscaped bool

for i := range len(input) {
currentChar := input[i]

if isEscaped {
sb.WriteByte(currentChar)
isEscaped = false
continue
}

switch currentChar {
case '\\':
if !inSingleQuotes { // Escape works in double quotes or outside any quotes
isEscaped = true
} else {
sb.WriteByte(currentChar) // Treat as a regular character within single quotes
}
case '"':
if !inSingleQuotes { // Toggle double quotes only if not in single quotes
inDoubleQuotes = !inDoubleQuotes
} else {
sb.WriteByte(currentChar) // Treat as a regular character within single quotes
}
case '\'':
if !inDoubleQuotes { // Toggle single quotes only if not in double quotes
inSingleQuotes = !inSingleQuotes
} else {
sb.WriteByte(currentChar) // Treat as a regular character within double quotes
}
default:
sb.WriteByte(currentChar)
}
}

if inDoubleQuotes || inSingleQuotes {
return "", fmt.Errorf("unquote `%q`: %w", input, ErrMismatchedQuotes)
}

if isEscaped {
return "", fmt.Errorf("unquote `%q`: %w", input, ErrTrailingBackslash)
}

return sb.String(), nil
}
40 changes: 40 additions & 0 deletions internal/shell/unquote_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package shell_test

import (
"testing"

"github.com/k0sproject/k0sctl/internal/shell"
"github.com/stretchr/testify/require"
)

func TestUnquote(t *testing.T) {
t.Run("no quotes", func(t *testing.T) {
out, err := shell.Unquote("foo bar")
require.NoError(t, err)
require.Equal(t, "foo bar", out)
})

t.Run("simple quotes", func(t *testing.T) {
out, err := shell.Unquote("\"foo\" 'bar'")
require.NoError(t, err)
require.Equal(t, "foo bar", out)
})

t.Run("mid-word quotes", func(t *testing.T) {
out, err := shell.Unquote("f\"o\"o b'a'r")
require.NoError(t, err)
require.Equal(t, "foo bar", out)
})

t.Run("complex quotes", func(t *testing.T) {
out, err := shell.Unquote(`'"'"'foo'"'"'`)
require.NoError(t, err)
require.Equal(t, `"'foo'"`, out)
})

t.Run("escaped quotes", func(t *testing.T) {
out, err := shell.Unquote("\\'foo\\' 'bar'")
require.NoError(t, err)
require.Equal(t, "'foo' bar", out)
})
}
6 changes: 5 additions & 1 deletion phase/gather_k0s_facts.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,14 @@ func (p *GatherK0sFacts) investigateK0s(h *cluster.Host) error {

h.Metadata.NeedsUpgrade = p.needsUpgrade(h)

var args cluster.Flags
if len(status.Args) > 2 {
// status.Args contains the binary path and the role as the first two elements, which we can ignore here.
h.Metadata.K0sStatusArgs = status.Args[2:]
for _, a := range status.Args[2:] {
args.Add(a)
}
}
h.Metadata.K0sStatusArgs = args

log.Infof("%s: is running k0s %s version %s", h, h.Role, h.Metadata.K0sRunningVersion)
if h.IsController() {
Expand Down
2 changes: 1 addition & 1 deletion phase/reinstall.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (p *Reinstall) reinstall(h *cluster.Host) error {
h.InstallFlags.AddOrReplace("--enable-dynamic-config")
}

h.InstallFlags.AddOrReplace("--force")
h.InstallFlags.AddOrReplace("--force=true")

cmd, err := h.K0sInstallCommand()
if err != nil {
Expand Down
88 changes: 80 additions & 8 deletions pkg/apis/k0sctl.k0sproject.io/v1beta1/cluster/flags.go
Original file line number Diff line number Diff line change
@@ -1,39 +1,55 @@
package cluster

import (
"fmt"
"strconv"
"strings"

"github.com/alessio/shellescape"
"github.com/k0sproject/k0sctl/internal/shell"
)

// Flags is a slice of strings with added functions to ease manipulating lists of command-line flags
type Flags []string

// Add adds a flag regardless if it exists already or not
func (f *Flags) Add(s string) {
if ns, err := shell.Unquote(s); err == nil {
s = ns
}
*f = append(*f, s)
}

// Add a flag with a value
func (f *Flags) AddWithValue(key, value string) {
*f = append(*f, key+" "+value)
if nv, err := shell.Unquote(value); err == nil {
value = nv
}
*f = append(*f, key+"="+value)
}

// AddUnlessExist adds a flag unless one with the same prefix exists
func (f *Flags) AddUnlessExist(s string) {
if ns, err := shell.Unquote(s); err == nil {
s = ns
}
if f.Include(s) {
return
}
*f = append(*f, s)
f.Add(s)
}

// AddOrReplace replaces a flag with the same prefix or adds a new one if one does not exist
func (f *Flags) AddOrReplace(s string) {
if ns, err := shell.Unquote(s); err == nil {
s = ns
}
idx := f.Index(s)
if idx > -1 {
(*f)[idx] = s
return
}
*f = append(*f, s)
f.Add(s)
}

// Include returns true if a flag with a matching prefix can be found
Expand All @@ -43,6 +59,9 @@ func (f Flags) Include(s string) bool {

// Index returns an index to a flag with a matching prefix
func (f Flags) Index(s string) int {
if ns, err := shell.Unquote(s); err == nil {
s = ns
}
var flag string
sepidx := strings.IndexAny(s, "= ")
if sepidx < 0 {
Expand Down Expand Up @@ -73,17 +92,16 @@ func (f Flags) GetValue(s string) string {
if fl == "" {
return ""
}
if nfl, err := shell.Unquote(fl); err == nil {
fl = nfl
}

idx := strings.IndexAny(fl, "= ")
if idx < 0 {
return ""
}

val := fl[idx+1:]
s, err := strconv.Unquote(val)
if err == nil {
return s
}

return val
}
Expand Down Expand Up @@ -137,5 +155,59 @@ func (f *Flags) MergeAdd(b Flags) {

// Join creates a string separated by spaces
func (f *Flags) Join() string {
return strings.Join(*f, " ")
var parts []string
f.Each(func(k, v string) {
if v == "" && k != "" {
parts = append(parts, shellescape.Quote(k))
} else {
parts = append(parts, fmt.Sprintf("%s=%s", k, shellescape.Quote(v)))
}
})
return strings.Join(parts, " ")
}

// Each iterates over each flag and calls the function with the flag key and value as arguments
func (f Flags) Each(fn func(string, string)) {
for _, flag := range f {
sepidx := strings.IndexAny(flag, "= ")
if sepidx < 0 {
if flag == "" {
continue
}
fn(flag, "")
} else {
key, value := flag[:sepidx], flag[sepidx+1:]
if unq, err := shell.Unquote(value); err == nil {
value = unq
}
fn(key, value)
}
}
}

// Map returns a map[string]string of the flags where the key is the flag and the value is the value
func (f Flags) Map() map[string]string {
res := make(map[string]string)
f.Each(func(k, v string) {
res[k] = v
})
return res
}

// Equals compares the flags with another Flags and returns true if they have the same flags and values, ignoring order
func (f Flags) Equals(b Flags) bool {
if len(f) != len(b) {
return false
}
for _, flag := range f {
if !b.Include(flag) {
return false
}
ourValue := f.GetValue(flag)
theirValue := b.GetValue(flag)
if ourValue != theirValue {
return false
}
}
return true
}
Loading

0 comments on commit fa88107

Please sign in to comment.