Skip to content

Commit

Permalink
added host in secret info struct
Browse files Browse the repository at this point in the history
simplified the mysql test due to huge structure
  • Loading branch information
abmussani committed Sep 9, 2024
1 parent 5b85bbf commit e00c8dd
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 39 deletions.
2 changes: 1 addition & 1 deletion pkg/analyzer/analyzers/mysql/expected_output.json

Large diffs are not rendered by default.

24 changes: 17 additions & 7 deletions pkg/analyzer/analyzers/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func bakeUserBindings(info *SecretInfo) ([]analyzers.Binding, *analyzers.Resourc
// add user and their priviliges to bindings
userResource := analyzers.Resource{
Name: info.User,
FullyQualifiedName: info.User,
FullyQualifiedName: info.Host + "/" + info.User,
Type: "user",
}

Expand All @@ -93,7 +93,7 @@ func bakeDatabaseBindings(userResource *analyzers.Resource, info *SecretInfo) []
for _, database := range info.Databases {
dbResource := analyzers.Resource{
Name: database.Name,
FullyQualifiedName: database.Name,
FullyQualifiedName: info.Host + "/" + database.Name,
Type: "database",
Metadata: map[string]any{
"default": database.Default,
Expand Down Expand Up @@ -131,7 +131,7 @@ func bakeTableBindings(dbResource *analyzers.Resource, database *Database) []ana
for _, table := range *database.Tables {
tableResource := analyzers.Resource{
Name: table.Name,
FullyQualifiedName: table.Name,
FullyQualifiedName: dbResource.FullyQualifiedName + "/" + table.Name,
Type: "table",
Metadata: map[string]any{
"bytes": table.Bytes,
Expand All @@ -153,7 +153,7 @@ func bakeTableBindings(dbResource *analyzers.Resource, database *Database) []ana
for _, column := range table.Columns {
columnResource := analyzers.Resource{
Name: column.Name,
FullyQualifiedName: column.Name,
FullyQualifiedName: tableResource.FullyQualifiedName + "/" + column.Name,
Type: "column",
Parent: &tableResource,
}
Expand Down Expand Up @@ -181,7 +181,7 @@ func bakeRoutineBindings(dbResource *analyzers.Resource, database *Database) []a
for _, routine := range *database.Routines {
routineResource := analyzers.Resource{
Name: routine.Name,
FullyQualifiedName: routine.Name,
FullyQualifiedName: dbResource.FullyQualifiedName + "/" + routine.Name,
Type: "routine",
Metadata: map[string]any{
"non_existent": routine.Nonexistent,
Expand Down Expand Up @@ -257,6 +257,7 @@ type Routine struct {
// USER() returns `doadmin@localhost`

type SecretInfo struct {
Host string
User string
Databases map[string]*Database
GlobalPrivs GlobalPrivs
Expand All @@ -282,8 +283,13 @@ func AnalyzeAndPrintPermissions(cfg *config.Config, key string) {
}

func AnalyzePermissions(cfg *config.Config, connectionStr string) (*SecretInfo, error) {
// Parse the connection string
u, err := parseConnectionStr(connectionStr)
if err != nil {
return nil, fmt.Errorf("parsing the connection string: %w", err)
}

db, err := createConnection(connectionStr)
db, err := createConnection(u)
if err != nil {
return nil, fmt.Errorf("connecting to the MySQL database: %w", err)
}
Expand Down Expand Up @@ -322,13 +328,14 @@ func AnalyzePermissions(cfg *config.Config, connectionStr string) (*SecretInfo,
processGrants(grants, databases, &globalPrivs)

return &SecretInfo{
Host: u.Hostname(),
User: user,
Databases: databases,
GlobalPrivs: globalPrivs,
}, nil
}

func createConnection(connection string) (*sql.DB, error) {
func parseConnectionStr(connection string) (*dburl.URL, error) {
// Check if the connection string starts with 'mysql://'
if !strings.HasPrefix(connection, "mysql://") {
color.Yellow("[i] The connection string should start with 'mysql://'. Adding it for you.")
Expand All @@ -346,7 +353,10 @@ func createConnection(connection string) (*sql.DB, error) {
if err != nil {
return nil, err
}
return u, nil
}

func createConnection(u *dburl.URL) (*sql.DB, error) {
// Connect to the MySQL database
db, err := sql.Open("mysql", u.DSN)
if err != nil {
Expand Down
45 changes: 14 additions & 31 deletions pkg/analyzer/analyzers/mysql/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import (
_ "embed"
"encoding/json"
"fmt"
"sort"
"testing"

"github.com/brianvoe/gofakeit/v7"
"github.com/google/go-cmp/cmp"
"github.com/testcontainers/testcontainers-go/modules/mysql"
"github.com/trufflesecurity/trufflehog/v3/pkg/analyzer/analyzers"
"github.com/trufflesecurity/trufflehog/v3/pkg/analyzer/config"
Expand Down Expand Up @@ -67,9 +67,6 @@ func TestAnalyzer_Analyze(t *testing.T) {
return
}

// bindings need to be in the same order to be comparable
sortBindings(got.Bindings)

// Marshal the actual result to JSON
gotJSON, err := json.Marshal(got)
if err != nil {
Expand All @@ -82,40 +79,26 @@ func TestAnalyzer_Analyze(t *testing.T) {
t.Fatalf("could not unmarshal want JSON string: %s", err)
}

// bindings need to be in the same order to be comparable
sortBindings(wantObj.Bindings)

// Marshal the expected result to JSON (to normalize)
wantJSON, err := json.Marshal(wantObj)
if err != nil {
t.Fatalf("could not marshal want to JSON: %s", err)
}

// Compare the JSON strings
if string(gotJSON) != string(wantJSON) {
// Pretty-print both JSON strings for easier comparison
var gotIndented, wantIndented []byte
gotIndented, err = json.MarshalIndent(got, "", " ")
if err != nil {
t.Fatalf("could not marshal got to indented JSON: %s", err)
}
wantIndented, err = json.MarshalIndent(wantObj, "", " ")
if err != nil {
t.Fatalf("could not marshal want to indented JSON: %s", err)
}

t.Errorf("Analyzer.Analyze() = %s, want %s", gotIndented, wantIndented)
// Compare bindings separately because they are not guaranteed to be in the same order
if len(got.Bindings) != len(wantObj.Bindings) {
t.Errorf("Analyzer.Analyze() = %s, want %s", gotJSON, wantJSON)
return
}

got.Bindings = nil
wantObj.Bindings = nil

// Compare the rest of the Object
if diff := cmp.Diff(&wantObj, got); diff != "" {
t.Errorf("%s: (-want +got)\n%s", tt.name, diff)
return
}
})
}
}

// Helper function to sort bindings
func sortBindings(bindings []analyzers.Binding) {
sort.SliceStable(bindings, func(i, j int) bool {
if bindings[i].Resource.Name == bindings[j].Resource.Name {
return bindings[i].Permission.Value < bindings[j].Permission.Value
}
return bindings[i].Resource.Name < bindings[j].Resource.Name
})
}

0 comments on commit e00c8dd

Please sign in to comment.