Skip to content

Commit

Permalink
refactor sagemaker resource type (#519)
Browse files Browse the repository at this point in the history
  • Loading branch information
james03160927 authored Jul 27, 2023
1 parent 8ce2b4c commit ff1a157
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 239 deletions.
2 changes: 1 addition & 1 deletion aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -1523,7 +1523,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp
}
if IsNukeable(notebookInstances.ResourceName(), resourceTypes) {
start := time.Now()
instances, err := getAllNotebookInstances(cloudNukeSession, excludeAfter, configObj)
instances, err := notebookInstances.getAll(configObj)
if err != nil {
ge := report.GeneralError{
Error: err,
Expand Down
60 changes: 23 additions & 37 deletions aws/sagemaker_notebook_instance.go
Original file line number Diff line number Diff line change
@@ -1,97 +1,83 @@
package aws

import (
"github.com/gruntwork-io/cloud-nuke/telemetry"
commonTelemetry "github.com/gruntwork-io/go-commons/telemetry"
"time"

"github.com/aws/aws-sdk-go/aws"
awsgo "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sagemaker"
"github.com/gruntwork-io/cloud-nuke/config"
"github.com/gruntwork-io/cloud-nuke/logging"
"github.com/gruntwork-io/cloud-nuke/report"
"github.com/gruntwork-io/cloud-nuke/telemetry"
"github.com/gruntwork-io/go-commons/errors"
commonTelemetry "github.com/gruntwork-io/go-commons/telemetry"
)

func getAllNotebookInstances(session *session.Session, excludeAfter time.Time, configObj config.Config) ([]*string, error) {
svc := sagemaker.New(session)

result, err := svc.ListNotebookInstances(&sagemaker.ListNotebookInstancesInput{})
func (smni SageMakerNotebookInstances) getAll(configObj config.Config) ([]*string, error) {
result, err := smni.Client.ListNotebookInstances(&sagemaker.ListNotebookInstancesInput{})
if err != nil {
return nil, errors.WithStackTrace(err)
}

var names []*string

for _, notebook := range result.NotebookInstances {
if notebook.CreationTime == nil {
continue
}
if !excludeAfter.After(awsgo.TimeValue(notebook.CreationTime)) {
continue
}
if !config.ShouldInclude(awsgo.StringValue(notebook.NotebookInstanceName), configObj.S3.IncludeRule.NamesRegExp, configObj.S3.ExcludeRule.NamesRegExp) {
continue
if configObj.SageMakerNotebook.ShouldInclude(config.ResourceValue{
Name: notebook.NotebookInstanceName,
Time: notebook.CreationTime,
}) {
names = append(names, notebook.NotebookInstanceName)
}
names = append(names, notebook.NotebookInstanceName)
}

return names, nil
}

func nukeAllNotebookInstances(session *session.Session, names []*string) error {
svc := sagemaker.New(session)

func (smni SageMakerNotebookInstances) nukeAll(names []*string) error {
if len(names) == 0 {
logging.Logger.Debugf("No Sagemaker Notebook Instance to nuke in region %s", *session.Config.Region)
logging.Logger.Debugf("No Sagemaker Notebook Instance to nuke in region %s", smni.Region)
return nil
}

logging.Logger.Debugf("Deleting all Sagemaker Notebook Instances in region %s", *session.Config.Region)
logging.Logger.Debugf("Deleting all Sagemaker Notebook Instances in region %s", smni.Region)
deletedNames := []*string{}

for _, name := range names {
params := &sagemaker.DeleteNotebookInstanceInput{
NotebookInstanceName: name,
}

_, err := svc.StopNotebookInstance(&sagemaker.StopNotebookInstanceInput{
_, err := smni.Client.StopNotebookInstance(&sagemaker.StopNotebookInstanceInput{
NotebookInstanceName: name,
})
if err != nil {
logging.Logger.Errorf("[Failed] %s: %s", *name, err)
telemetry.TrackEvent(commonTelemetry.EventContext{
EventName: "Error Nuking Sagemaker Notebook Instance",
}, map[string]interface{}{
"region": *session.Config.Region,
"region": smni.Region,
"reason": "Failed to Stop Notebook",
})
}

err = svc.WaitUntilNotebookInstanceStopped(&sagemaker.DescribeNotebookInstanceInput{
err = smni.Client.WaitUntilNotebookInstanceStopped(&sagemaker.DescribeNotebookInstanceInput{
NotebookInstanceName: name,
})

if err != nil {
logging.Logger.Errorf("[Failed] %s", err)
telemetry.TrackEvent(commonTelemetry.EventContext{
EventName: "Error Nuking Sagemaker Notebook Instance",
}, map[string]interface{}{
"region": *session.Config.Region,
"region": smni.Region,
"reason": "Failed waiting for notebook to stop",
})
}

_, err = svc.DeleteNotebookInstance(params)
_, err = smni.Client.DeleteNotebookInstance(&sagemaker.DeleteNotebookInstanceInput{
NotebookInstanceName: name,
})

if err != nil {
logging.Logger.Errorf("[Failed] %s: %s", *name, err)
telemetry.TrackEvent(commonTelemetry.EventContext{
EventName: "Error Nuking Sagemaker Notebook Instance",
}, map[string]interface{}{
"region": *session.Config.Region,
"region": smni.Region,
"reason": "Failed to Delete Notebook",
})
} else {
Expand All @@ -103,7 +89,7 @@ func nukeAllNotebookInstances(session *session.Session, names []*string) error {
if len(deletedNames) > 0 {
for _, name := range deletedNames {

err := svc.WaitUntilNotebookInstanceDeleted(&sagemaker.DescribeNotebookInstanceInput{
err := smni.Client.WaitUntilNotebookInstanceDeleted(&sagemaker.DescribeNotebookInstanceInput{
NotebookInstanceName: name,
})

Expand All @@ -120,14 +106,14 @@ func nukeAllNotebookInstances(session *session.Session, names []*string) error {
telemetry.TrackEvent(commonTelemetry.EventContext{
EventName: "Error Nuking Sagemaker Notebook Instance",
}, map[string]interface{}{
"region": *session.Config.Region,
"region": smni.Region,
"reason": "Failed waiting for notebook instance to delete",
})
return errors.WithStackTrace(err)
}
}
}

logging.Logger.Debugf("[OK] %d Sagemaker Notebook Instance(s) deleted in %s", len(deletedNames), *session.Config.Region)
logging.Logger.Debugf("[OK] %d Sagemaker Notebook Instance(s) deleted in %s", len(deletedNames), smni.Region)
return nil
}
159 changes: 88 additions & 71 deletions aws/sagemaker_notebook_instance_test.go
Original file line number Diff line number Diff line change
@@ -1,104 +1,121 @@
package aws

import (
"github.com/gruntwork-io/cloud-nuke/telemetry"
"strings"
"testing"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
awsgo "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sagemaker"
"github.com/aws/aws-sdk-go/service/sagemaker/sagemakeriface"
"github.com/gruntwork-io/cloud-nuke/config"
"github.com/gruntwork-io/cloud-nuke/logging"
"github.com/gruntwork-io/cloud-nuke/util"
"github.com/gruntwork-io/go-commons/errors"
"github.com/stretchr/testify/assert"
"github.com/gruntwork-io/cloud-nuke/telemetry"
"github.com/stretchr/testify/require"
"regexp"
"testing"
"time"
)

// There's a built-in function WaitUntilDBInstanceAvailable but
// the times that it was tested, it wasn't returning anything so we'll leave with the
// custom one.

func waitUntilNotebookInstanceCreated(svc *sagemaker.SageMaker, name *string) error {
input := &sagemaker.DescribeNotebookInstanceInput{
NotebookInstanceName: name,
}

for i := 0; i < 600; i++ {
instance, err := svc.DescribeNotebookInstance(input)
status := instance.NotebookInstanceStatus

if awsgo.StringValue(status) != "Pending" {
return nil
}

if err != nil {
return err
}

time.Sleep(1 * time.Second)
logging.Logger.Debug("Waiting for SageMaker Notebook Instance to be created")
}
type mockedSageMakerNotebookInstance struct {
sagemakeriface.SageMakerAPI
ListNotebookInstancesOutput sagemaker.ListNotebookInstancesOutput
StopNotebookInstanceOutput sagemaker.StopNotebookInstanceOutput
DeleteNotebookInstanceOutput sagemaker.DeleteNotebookInstanceOutput
}

return SageMakerNotebookInstanceDeleteError{name: *name}
func (m mockedSageMakerNotebookInstance) ListNotebookInstances(input *sagemaker.ListNotebookInstancesInput) (*sagemaker.ListNotebookInstancesOutput, error) {
return &m.ListNotebookInstancesOutput, nil
}

func createTestNotebookInstance(t *testing.T, session *session.Session, name string, roleArn string) {
svc := sagemaker.New(session)
func (m mockedSageMakerNotebookInstance) StopNotebookInstance(input *sagemaker.StopNotebookInstanceInput) (*sagemaker.StopNotebookInstanceOutput, error) {
return &m.StopNotebookInstanceOutput, nil
}

params := &sagemaker.CreateNotebookInstanceInput{
InstanceType: awsgo.String("ml.t2.medium"),
NotebookInstanceName: awsgo.String(name),
RoleArn: awsgo.String(roleArn),
}
func (m mockedSageMakerNotebookInstance) WaitUntilNotebookInstanceStopped(*sagemaker.DescribeNotebookInstanceInput) error {
return nil
}

_, err := svc.CreateNotebookInstance(params)
require.NoError(t, err)
func (m mockedSageMakerNotebookInstance) WaitUntilNotebookInstanceDeleted(*sagemaker.DescribeNotebookInstanceInput) error {
return nil
}

waitUntilNotebookInstanceCreated(svc, &name)
func (m mockedSageMakerNotebookInstance) DeleteNotebookInstance(input *sagemaker.DeleteNotebookInstanceInput) (*sagemaker.DeleteNotebookInstanceOutput, error) {
return &m.DeleteNotebookInstanceOutput, nil
}

func TestNukeNotebookInstance(t *testing.T) {
func TestSageMakerNotebookInstances_GetAll(t *testing.T) {
telemetry.InitTelemetry("cloud-nuke", "")
t.Parallel()

region, err := getRandomRegion()

require.NoError(t, errors.WithStackTrace(err))

session, err := session.NewSessionWithOptions(
session.Options{
SharedConfigState: session.SharedConfigEnable,
Config: awsgo.Config{
Region: awsgo.String(region),
now := time.Now()
testName1 := "test1"
testName2 := "test2"
smni := SageMakerNotebookInstances{
Client: mockedSageMakerNotebookInstance{
ListNotebookInstancesOutput: sagemaker.ListNotebookInstancesOutput{
NotebookInstances: []*sagemaker.NotebookInstanceSummary{
{
NotebookInstanceName: awsgo.String(testName1),
CreationTime: awsgo.Time(now),
},
{
NotebookInstanceName: awsgo.String(testName2),
CreationTime: awsgo.Time(now.Add(1)),
},
},
},
},
)

notebookName := "cloud-nuke-test-" + util.UniqueID()
excludeAfter := time.Now().Add(1 * time.Hour)

role := createNotebookRole(t, session, notebookName+"-role")
defer deleteNotebookRole(session, role)

createTestNotebookInstance(t, session, notebookName, *role.Arn)

defer func() {
nukeAllNotebookInstances(session, []*string{&notebookName})
}

notebookNames, _ := getAllNotebookInstances(session, excludeAfter, config.Config{})
tests := map[string]struct {
configObj config.ResourceType
expected []string
}{
"emptyFilter": {
configObj: config.ResourceType{},
expected: []string{testName1, testName2},
},
"nameExclusionFilter": {
configObj: config.ResourceType{
ExcludeRule: config.FilterRule{
NamesRegExp: []config.Expression{{
RE: *regexp.MustCompile(testName1),
}}},
},
expected: []string{testName2},
},
"timeAfterExclusionFilter": {
configObj: config.ResourceType{
ExcludeRule: config.FilterRule{
TimeAfter: aws.Time(now.Add(-1 * time.Hour)),
}},
expected: []string{},
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
names, err := smni.getAll(config.Config{
SageMakerNotebook: tc.configObj,
})
require.NoError(t, err)
require.Equal(t, tc.expected, awsgo.StringValueSlice(names))
})
}

assert.NotContains(t, awsgo.StringValueSlice(notebookNames), strings.ToLower(notebookName))
}()
}

instances, err := getAllNotebookInstances(session, excludeAfter, config.Config{})
func TestSageMakerNotebookInstances_NukeAll(t *testing.T) {
telemetry.InitTelemetry("cloud-nuke", "")
t.Parallel()

if err != nil {
assert.Failf(t, "Unable to fetch list of SageMaker Notebook Instances", errors.WithStackTrace(err).Error())
smni := SageMakerNotebookInstances{
Client: mockedSageMakerNotebookInstance{
StopNotebookInstanceOutput: sagemaker.StopNotebookInstanceOutput{},
DeleteNotebookInstanceOutput: sagemaker.DeleteNotebookInstanceOutput{},
},
}

assert.Contains(t, awsgo.StringValueSlice(instances), notebookName)

err := smni.nukeAll([]*string{aws.String("test")})
require.NoError(t, err)
}
14 changes: 7 additions & 7 deletions aws/sagemaker_notebook_instance_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,23 @@ type SageMakerNotebookInstances struct {
InstanceNames []string
}

func (instance SageMakerNotebookInstances) ResourceName() string {
return "sagemaker-notebook-instance"
func (smni SageMakerNotebookInstances) ResourceName() string {
return "sagemaker-notebook-smni"
}

// ResourceIdentifiers - The instance names of the rds db instances
func (instance SageMakerNotebookInstances) ResourceIdentifiers() []string {
return instance.InstanceNames
func (smni SageMakerNotebookInstances) ResourceIdentifiers() []string {
return smni.InstanceNames
}

func (instance SageMakerNotebookInstances) MaxBatchSize() int {
func (smni SageMakerNotebookInstances) MaxBatchSize() int {
// Tentative batch size to ensure AWS doesn't throttle
return 49
}

// Nuke - nuke 'em all!!!
func (instance SageMakerNotebookInstances) Nuke(session *session.Session, identifiers []string) error {
if err := nukeAllNotebookInstances(session, awsgo.StringSlice(identifiers)); err != nil {
func (smni SageMakerNotebookInstances) Nuke(session *session.Session, identifiers []string) error {
if err := smni.nukeAll(awsgo.StringSlice(identifiers)); err != nil {
return errors.WithStackTrace(err)
}

Expand Down
Loading

0 comments on commit ff1a157

Please sign in to comment.