From 5d444b25bcdc8ffeb9ab0e6dac749221ef995627 Mon Sep 17 00:00:00 2001 From: Nathan Smyth Date: Thu, 14 Mar 2024 07:06:52 +0000 Subject: [PATCH] fix: Make audit log unit test pass --- internal/cli/atlas/streams/instance/logs.go | 26 +++++++++++++++---- .../cli/atlas/streams/instance/logs_test.go | 21 ++++++++------- 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/internal/cli/atlas/streams/instance/logs.go b/internal/cli/atlas/streams/instance/logs.go index 663d71fde0..3c72e3183a 100644 --- a/internal/cli/atlas/streams/instance/logs.go +++ b/internal/cli/atlas/streams/instance/logs.go @@ -17,6 +17,7 @@ package instance import ( "context" "fmt" + "io" "slices" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli" @@ -27,6 +28,8 @@ import ( "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/usage" "github.com/spf13/afero" "github.com/spf13/cobra" + + atlasv2 "go.mongodb.org/atlas-sdk/v20231115008/admin" ) var downloadMessage = "Download of %s completed.\n" @@ -38,7 +41,6 @@ type DownloadOpts struct { fileName string start int64 end int64 - decompress bool store store.StreamsDownloader } @@ -58,13 +60,28 @@ func (opts *DownloadOpts) initDefaultOut() error { } func (opts *DownloadOpts) Run() error { - w, err := opts.NewWriteCloser() + params := atlasv2.DownloadStreamTenantAuditLogsApiParams{ + GroupId: opts.ProjectID, + TenantName: opts.tenantName, + StartDate: &opts.start, + EndDate: &opts.end, + } + + f, err := opts.store.DownloadAuditLog(¶ms) if err != nil { return err } - defer w.Close() - return nil + defer f.Close() + + out, err := opts.NewWriteCloser() + if err != nil { + return err + } + defer out.Close() + + _, err = io.Copy(out, f) + return err } // DownloadBuilder @@ -109,7 +126,6 @@ func DownloadBuilder() *cobra.Command { cmd.Flags().Int64Var(&opts.start, flag.Start, 0, usage.LogStart) cmd.Flags().Int64Var(&opts.end, flag.End, 0, usage.LogEnd) cmd.Flags().BoolVar(&opts.Force, flag.Force, false, usage.ForceFile) - cmd.Flags().BoolVarP(&opts.decompress, flag.Decompress, flag.DecompressShort, false, usage.Decompress) cmd.Flags().StringVar(&opts.ProjectID, flag.ProjectID, "", usage.ProjectID) diff --git a/internal/cli/atlas/streams/instance/logs_test.go b/internal/cli/atlas/streams/instance/logs_test.go index d606f5f25b..f782087df4 100644 --- a/internal/cli/atlas/streams/instance/logs_test.go +++ b/internal/cli/atlas/streams/instance/logs_test.go @@ -23,10 +23,10 @@ import ( "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/flag" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/mocks" - "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/store" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/test" "github.com/spf13/afero" "github.com/stretchr/testify/require" + atlasv2 "go.mongodb.org/atlas-sdk/v20231115008/admin" ) func TestDownloadOpts_Run(t *testing.T) { @@ -34,7 +34,8 @@ func TestDownloadOpts_Run(t *testing.T) { mockStore := mocks.NewMockStreamsDownloader(ctrl) const contents = "expected" - const fileName = "auditLogs.gz" + const projectID = "download-project-id" + const tenantName = "streams-tenant" file, err := os.CreateTemp("", "") if err != nil { @@ -52,22 +53,22 @@ func TestDownloadOpts_Run(t *testing.T) { downloadOpts := &DownloadOpts{ store: mockStore, DownloaderOpts: cli.DownloaderOpts{ - Out: fileName, + Out: "auditLogs.gz", Fs: fs, }, } - downloadOpts.ProjectID = "download-project-id" - downloadOpts.decompress = true - downloadOpts.fileName = fileName + downloadOpts.ProjectID = projectID + downloadOpts.tenantName = tenantName endDate := int64(0) startDate := int64(0) - downloadParams := new(store.DownloadStreamTenantAuditLogsApiParams) + downloadParams := new(atlasv2.DownloadStreamTenantAuditLogsApiParams) downloadParams.EndDate = &endDate downloadParams.StartDate = &startDate - downloadParams.TenantName = "streams-tenant" + downloadParams.GroupId = projectID + downloadParams.TenantName = tenantName mockStore. EXPECT(). @@ -79,7 +80,7 @@ func TestDownloadOpts_Run(t *testing.T) { t.Fatalf("Run() unexpected error: %v", err) } - of, _ := fs.Open(fileName) + of, _ := fs.Open("auditLogs.gz") defer of.Close() b, _ := io.ReadAll(of) require.Equal(t, contents, string(b)) @@ -90,6 +91,6 @@ func TestDownloadBuilder(t *testing.T) { t, DownloadBuilder(), 0, - []string{flag.Out, flag.Start, flag.End, flag.Force, flag.Decompress, flag.ProjectID}, + []string{flag.Out, flag.Start, flag.End, flag.Force, flag.ProjectID}, ) }