diff --git a/compiler.go b/compiler.go index f71dcf6..f2f9f23 100644 --- a/compiler.go +++ b/compiler.go @@ -126,14 +126,28 @@ func (c *Compiler) setCallbackData(cb CompilerIncludeFunc) { } } +var ( + errParse = errors.New("Compiler cannot be used after parse error") + errRules = errors.New("Compiler cannot be used after producing rule set") +) + +func (c *Compiler) checkUsage() (err error) { + if c.cptr.errors != 0 { + err = errParse + } else if c.cptr.rules != nil { + err = errRules + } + return +} + // AddFile compiles rules from a file. Rules are added to the // specified namespace. // // If this function returns an error, the Compiler object will become // unusable. func (c *Compiler) AddFile(file *os.File, namespace string) (err error) { - if c.cptr.errors != 0 { - return errors.New("Compiler cannot be used after parse error") + if err := c.checkUsage(); err != nil { + return err } var ns *C.char if namespace != "" { @@ -164,8 +178,8 @@ func (c *Compiler) AddFile(file *os.File, namespace string) (err error) { // If this function returns an error, the Compiler object will become // unusable. func (c *Compiler) AddString(rules string, namespace string) (err error) { - if c.cptr.errors != 0 { - return errors.New("Compiler cannot be used after parse error") + if err := c.checkUsage(); err != nil { + return err } var ns *C.char if namespace != "" { @@ -224,8 +238,8 @@ func (c *Compiler) DefineVariable(identifier string, value interface{}) (err err // GetRules returns the compiled ruleset. func (c *Compiler) GetRules() (*Rules, error) { - if c.cptr.errors != 0 { - return nil, errors.New("Compiler cannot be used after parse error") + if err := c.checkUsage(); err != nil { + return nil, err } var yrRules *C.YR_RULES if err := newError(C.yr_compiler_get_rules(c.cptr, &yrRules)); err != nil { diff --git a/compiler_test.go b/compiler_test.go index fc23319..4b05546 100644 --- a/compiler_test.go +++ b/compiler_test.go @@ -57,6 +57,19 @@ func TestErrors(t *testing.T) { } } +func TestErrorNoPanic(t *testing.T) { + c, _ := NewCompiler() + c.AddString("rule test { condition: true }", "") + if _, err := c.GetRules(); err != nil { + t.Errorf("did not expect error: %v", err) + } + if err := c.AddString("rule test { }", ""); err == nil { + t.Error("expected AddString after GetRules to fail") + } else { + t.Logf("got error as expected: %v", err) + } +} + func setupCompiler(t *testing.T) *Compiler { c, err := NewCompiler() if err != nil {