diff --git a/embed.go b/embed.go index 0380ef2..c2be2fe 100644 --- a/embed.go +++ b/embed.go @@ -34,6 +34,9 @@ var ( {"and", -1, evalAnd}, {"or", -1, evalOr}, {"repeat", 2, evalRepeat}, + {"firstCaseIndex", -1, evalFirstCaseIndex}, + {"firstEqualIndex", -1, evalFirstEqualIndex}, + {"selectCaseByIndex", -1, evalSelectCaseByIndex}, } embedArithmeticsShort = []*EmbeddedFunctionData{ {"add", 2, evalAddUint}, @@ -292,7 +295,7 @@ func evalRepeat(par *CallParams) []byte { fragment := par.Arg(0) n := par.Arg(1) if len(n) != 1 { - par.TracePanic("evalRepeat: count must 1-byte long") + par.TracePanic("evalRepeat: count must be 1-byte long") } ret := bytes.Repeat(fragment, int(n[0])) par.Trace("hasPrefix:: %s, %s -> %s", Fmt(fragment), Fmt(n), Fmt(ret)) @@ -319,6 +322,44 @@ func evalIf(par *CallParams) []byte { return no } +func evalFirstCaseIndex(par *CallParams) []byte { + for i := byte(0); i < par.Arity(); i++ { + if ret := par.Arg(i); len(ret) > 0 { + par.Trace("firstCaseIndex:: -> %d", i) + return []byte{i} + } + } + par.Trace("firstCaseIndex:: -> nil") + return nil +} + +func evalFirstEqualIndex(par *CallParams) []byte { + if par.Arity() == 0 { + return nil + } + + v := par.Arg(0) + for i := byte(1); i < par.Arity(); i++ { + if bytes.Equal(v, par.Arg(i)) { + par.Trace("firstEqualIndex:: -> %d", i) + return []byte{i - 1} + } + } + par.Trace("firstEqualIndex:: -> nil") + return nil +} + +func evalSelectCaseByIndex(par *CallParams) []byte { + if par.Arity() == 0 { + par.TracePanic("evalSelectCaseByIndex: must be at least 1 argument") + } + idx := par.Arg(0) + if len(idx) != 1 || idx[0]+1 >= par.Arity() { + return nil + } + return par.Arg(idx[0] + 1) +} + func evalIsZero(par *CallParams) []byte { arg := par.Arg(0) for _, b := range arg { diff --git a/library_test.go b/library_test.go index 8659b5f..a13dfb7 100644 --- a/library_test.go +++ b/library_test.go @@ -1071,3 +1071,82 @@ func TestBytecodeParams(t *testing.T) { }) } + +func TestCases(t *testing.T) { + lib := NewBase() + t.Run("1", func(t *testing.T) { + const src = `firstCaseIndex( + equal($0, 1), + equal($0, 2), + equal($0, 3), + equal($0, 4), + equal($0, 0xffff), + ) +` + expr, n, _, err := lib.CompileExpression(src) + require.NoError(t, err) + require.EqualValues(t, 1, n) + + res := EvalExpression(nil, expr, []byte{3}) + require.EqualValues(t, []byte{2}, res) + + res = EvalExpression(nil, expr, []byte{4}) + require.EqualValues(t, []byte{3}, res) + + res = EvalExpression(nil, expr, []byte{0}) + require.True(t, len(res) == 0) + + res = EvalExpression(nil, expr, []byte{7}) + require.True(t, len(res) == 0) + + res = EvalExpression(nil, expr, []byte{0xff, 0xff}) + require.EqualValues(t, []byte{4}, res) + }) + t.Run("2", func(t *testing.T) { + const src = "firstEqualIndex($0, 1, 2, 3, 4, 0xffff)" + + expr, n, _, err := lib.CompileExpression(src) + require.NoError(t, err) + require.EqualValues(t, 1, n) + + res := EvalExpression(nil, expr, []byte{3}) + require.EqualValues(t, []byte{2}, res) + + res = EvalExpression(nil, expr, []byte{4}) + require.EqualValues(t, []byte{3}, res) + + res = EvalExpression(nil, expr, []byte{0}) + require.True(t, len(res) == 0) + + res = EvalExpression(nil, expr, []byte{7}) + require.True(t, len(res) == 0) + + res = EvalExpression(nil, expr, []byte{0xff, 0xff}) + require.EqualValues(t, []byte{4}, res) + }) + t.Run("3", func(t *testing.T) { + const src = "selectCaseByIndex($0, 1, 0x1234, add(5,3), true)" + + expr, n, _, err := lib.CompileExpression(src) + require.NoError(t, err) + require.EqualValues(t, 1, n) + + res := EvalExpression(nil, expr, []byte{0}) + require.EqualValues(t, []byte{1}, res) + + res = EvalExpression(nil, expr, []byte{1}) + require.EqualValues(t, []byte{0x12, 0x34}, res) + + res = EvalExpression(nil, expr, []byte{2}) + require.EqualValues(t, []byte{0, 0, 0, 0, 0, 0, 0, 8}, res) + + res = EvalExpression(nil, expr, []byte{3}) + require.EqualValues(t, []byte{0xff}, res) + + res = EvalExpression(nil, expr, []byte{4}) + require.True(t, len(res) == 0) + + res = EvalExpression(nil, expr, []byte{0, 0}) + require.True(t, len(res) == 0) + }) +}