diff --git a/xray/plugin.go b/xray/plugin.go index 2382cac..a6217bf 100644 --- a/xray/plugin.go +++ b/xray/plugin.go @@ -84,6 +84,7 @@ func getVersion() string { // search the most specific module if strings.HasPrefix(pkg, dep.Path) && len(dep.Path) > depth { version = dep.Version + depth = len(dep.Path) } } return version diff --git a/xray/plugin_test.go b/xray/plugin_test.go index e844b58..56fe417 100644 --- a/xray/plugin_test.go +++ b/xray/plugin_test.go @@ -13,7 +13,9 @@ func TestAddPlugin(t *testing.T) { var wg sync.WaitGroup wg.Add(n) for i := 0; i < n; i++ { - go AddPlugin(&xrayPlugin{}) + go AddPlugin(&xrayPlugin{ + sdkVersion: getVersion(), + }) go getPlugins() } diff --git a/xraysql/connector.go b/xraysql/connector.go index b72bbe2..d7d11c9 100644 --- a/xraysql/connector.go +++ b/xraysql/connector.go @@ -6,7 +6,9 @@ import ( "fmt" "io" "reflect" + "runtime/debug" "strconv" + "strings" "sync" "time" @@ -172,21 +174,7 @@ func newDBAttribute(ctx context.Context, driverName string, d driver.Driver, con } } - // There's no standard to get SQL driver version information - // So we invent an interface by which drivers can provide us this data - type versionedDriver interface { - Version() string - } - - if vd, ok := d.(versionedDriver); ok { - attr.driverVersion = vd.Version() - } else { - t := reflect.TypeOf(d) - for t.Kind() == reflect.Ptr { - t = t.Elem() - } - attr.driverVersion = t.PkgPath() - } + attr.driverVersion = getDriverVersion(d) if driverName != "" { attr.name = attr.dbname + "@" + driverName } else { @@ -205,6 +193,34 @@ func newDBAttribute(ctx context.Context, driverName string, d driver.Driver, con return &attr, nil } +func getDriverVersion(d driver.Driver) string { + t := reflect.TypeOf(d) + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + pkg := t.PkgPath() + + info, ok := debug.ReadBuildInfo() + if !ok { + return pkg + } + + version := "" + depth := 0 + for _, dep := range info.Deps { + // search the most specific module + if strings.HasPrefix(pkg, dep.Path) && len(dep.Path) > depth { + version = dep.Version + depth = len(dep.Path) + } + } + + if version == "" { + return pkg + } + return pkg + "@" + version +} + func postgresDetector(ctx context.Context, conn driver.Conn, attr *dbAttribute) error { var databaseVersion, user, dbname string err := queryRow(