Skip to content

Commit

Permalink
Incremental test fixes (#2298)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kostas Papageorgiou authored Dec 17, 2020
1 parent 04073f9 commit 5fd1587
Show file tree
Hide file tree
Showing 14 changed files with 285 additions and 238 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ var skipOutput = []string{"SKIP"}

type Handler struct {
SqsClient sqsiface.SQSAPI
Cache *RuleCache
Cache RuleCache
DdbClient dynamodbiface.DynamoDBAPI
AlertTable string
AlertingQueueURL string
Expand Down
30 changes: 20 additions & 10 deletions internal/log_analysis/alert_forwarder/forwarder/rule_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package forwarder
*/

import (
"net/http"

lru "github.com/hashicorp/golang-lru"
"github.com/pkg/errors"
"go.uber.org/zap"
Expand All @@ -27,24 +29,28 @@ import (
"github.com/panther-labs/panther/pkg/gatewayapi"
)

type RuleCache interface {
Get(id, version string) (*models.Rule, error)
}

// s3ClientCacheKey -> S3 client
type RuleCache struct {
cache *lru.ARCCache
policyClient gatewayapi.API
type LRUCache struct {
cache *lru.ARCCache
ruleClient gatewayapi.API
}

func NewCache(policyClient gatewayapi.API) *RuleCache {
func NewCache(ruleClient gatewayapi.API) *LRUCache {
cache, err := lru.NewARC(1000)
if err != nil {
panic("failed to create cache")
}
return &RuleCache{
cache: cache,
policyClient: policyClient,
return &LRUCache{
cache: cache,
ruleClient: ruleClient,
}
}

func (c *RuleCache) Get(id, version string) (*models.Rule, error) {
func (c *LRUCache) Get(id, version string) (*models.Rule, error) {
value, ok := c.cache.Get(cacheKey(id, version))
if !ok {
rule, err := c.getRule(id, version)
Expand All @@ -61,15 +67,19 @@ func cacheKey(id, version string) string {
return id + ":" + version
}

func (c *RuleCache) getRule(id, version string) (*models.Rule, error) {
func (c *LRUCache) getRule(id, version string) (*models.Rule, error) {
zap.L().Debug("calling analysis API to retrieve information for rule", zap.String("ruleId", id), zap.String("ruleVersion", version))
input := models.LambdaInput{
GetRule: &models.GetRuleInput{ID: id, VersionID: version},
}
var rule models.Rule

if _, err := c.policyClient.Invoke(&input, &rule); err != nil {
httpStatus, err := c.ruleClient.Invoke(&input, &rule)
if err != nil {
return nil, errors.Wrapf(err, "failed to fetch information for ruleID [%s], version [%s]", id, version)
}
if httpStatus != http.StatusOK {
return nil, errors.Errorf("failed to fetch information for ruleID [%s], version [%s], got HTTP response [%d]", id, version, httpStatus)
}
return &rule, nil
}
70 changes: 70 additions & 0 deletions internal/log_analysis/alert_forwarder/forwarder/rule_cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package forwarder

import (
"net/http"
"testing"

"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"

"github.com/panther-labs/panther/api/lambda/analysis/models"
"github.com/panther-labs/panther/pkg/testutils"
)

/**
* Panther is a Cloud-Native SIEM for the Modern Security Team.
* Copyright (C) 2020 Panther Labs Inc
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

func TestCacheHttpError(t *testing.T) {
t.Parallel()
ruleClientMock := &testutils.GatewayapiMock{}
cache := NewCache(ruleClientMock)

ruleClientMock.On("Invoke", mock.Anything, mock.Anything).Return(http.StatusInternalServerError, nil).Once()
rule, err := cache.Get("id", "version")
assert.Error(t, err)
assert.Nil(t, rule)
ruleClientMock.AssertExpectations(t)
}

func TestCacheInvocationError(t *testing.T) {
t.Parallel()
ruleClientMock := &testutils.GatewayapiMock{}
cache := NewCache(ruleClientMock)

ruleClientMock.On("Invoke", mock.Anything, mock.Anything).Return(0, errors.New("test")).Once()
rule, err := cache.Get("id", "version")
assert.Error(t, err)
assert.Nil(t, rule)
ruleClientMock.AssertExpectations(t)
}

func TestCacheRuleRetrieval(t *testing.T) {
t.Parallel()
ruleClientMock := &testutils.GatewayapiMock{}
cache := NewCache(ruleClientMock)

expectedInput := &models.LambdaInput{
GetRule: &models.GetRuleInput{ID: "id", VersionID: "version"},
}
ruleClientMock.On("Invoke", expectedInput, mock.Anything).Return(http.StatusOK, nil).Once()
rule, err := cache.Get("id", "version")
assert.NoError(t, err)
assert.NotNil(t, rule)
ruleClientMock.AssertExpectations(t)
}
20 changes: 9 additions & 11 deletions internal/log_analysis/alerts_api/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,10 @@ import (

// API has all of the handlers as receiver methods.
type API struct {
awsSession *session.Session
alertsDB table.API
s3Client s3iface.S3API
analysisClient gatewayapi.API
ruleCache *forwarder.RuleCache
awsSession *session.Session
alertsDB table.API
s3Client s3iface.S3API
ruleCache forwarder.RuleCache

env envConfig
}
Expand All @@ -66,12 +65,11 @@ func Setup() *API {
ruleCache := forwarder.NewCache(analysisClient)

return &API{
awsSession: awsSession,
alertsDB: env.NewAlertsTable(dynamodb.New(awsSession)),
s3Client: s3.New(awsSession),
env: env,
analysisClient: analysisClient,
ruleCache: ruleCache,
awsSession: awsSession,
alertsDB: env.NewAlertsTable(dynamodb.New(awsSession)),
s3Client: s3.New(awsSession),
env: env,
ruleCache: ruleCache,
}
}

Expand Down
1 change: 0 additions & 1 deletion internal/log_analysis/alerts_api/api/get_alert.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ func (api *API) GetAlert(input *models.GetAlertInput) (result *models.GetAlertOu
}

alertRule, err := api.ruleCache.Get(alertItem.RuleID, alertItem.RuleVersion)

if err != nil {
zap.L().Warn("failed to get rule with ID", zap.Any("rule id", alertItem.RuleID),
zap.Any("rule version", alertItem.RuleVersion), zap.Any("error", err))
Expand Down
Loading

0 comments on commit 5fd1587

Please sign in to comment.