Skip to content

Commit

Permalink
Merge pull request #188 from safing/maintain/sig-delete--binmeta--svc…
Browse files Browse the repository at this point in the history
…host

Improve: delete sigs, binary metadata, svchost service detection
  • Loading branch information
dhaavi authored Oct 10, 2022
2 parents 37b9178 + 40015b5 commit 8471f4f
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 13 deletions.
15 changes: 14 additions & 1 deletion updater/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package updater

import (
"errors"
"io/fs"
"os"
"path/filepath"
"sort"
Expand Down Expand Up @@ -461,11 +462,23 @@ boundarySearch:
storagePath := rv.storagePath()
err := os.Remove(storagePath)
if err != nil {
log.Warningf("%s: failed to purge resource %s v%s: %s", res.registry.Name, rv.resource.Identifier, rv.VersionNumber, err)
if !errors.Is(err, fs.ErrNotExist) {
log.Warningf("%s: failed to purge resource %s v%s: %s", res.registry.Name, rv.resource.Identifier, rv.VersionNumber, err)
}
} else {
log.Tracef("%s: purged resource %s v%s", res.registry.Name, rv.resource.Identifier, rv.VersionNumber)
}

// Remove resource signature file.
err = os.Remove(rv.storageSigPath())
if err != nil {
if !errors.Is(err, fs.ErrNotExist) {
log.Warningf("%s: failed to purge resource signature %s v%s: %s", res.registry.Name, rv.resource.Identifier, rv.VersionNumber, err)
}
} else {
log.Tracef("%s: purged resource signature %s v%s", res.registry.Name, rv.resource.Identifier, rv.VersionNumber)
}

// Remove unpacked version of resource.
ext := filepath.Ext(storagePath)
if ext == "" {
Expand Down
1 change: 1 addition & 0 deletions utils/osdetail/binmeta_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ func TestGenerateBinaryNameFromPath(t *testing.T) {
assert.Equal(t, "Browser Broker", GenerateBinaryNameFromPath("browser_broker.exe"))
assert.Equal(t, "Virtual Box VM", GenerateBinaryNameFromPath("VirtualBoxVM"))
assert.Equal(t, "Io Elementary Appcenter", GenerateBinaryNameFromPath("io.elementary.appcenter"))
assert.Equal(t, "Microsoft Windows Store", GenerateBinaryNameFromPath("Microsoft.WindowsStore"))
}

func TestCleanFileDescription(t *testing.T) {
Expand Down
26 changes: 14 additions & 12 deletions utils/osdetail/svchost_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

var (
serviceNames map[int32]string
serviceNames map[int32][]string
serviceNamesLock sync.Mutex
)

Expand All @@ -22,7 +22,7 @@ var (
)

// GetServiceNames returns all service names assosicated with a svchost.exe process on Windows.
func GetServiceNames(pid int32) (string, error) {
func GetServiceNames(pid int32) ([]string, error) {
serviceNamesLock.Lock()
defer serviceNamesLock.Unlock()

Expand All @@ -35,19 +35,19 @@ func GetServiceNames(pid int32) (string, error) {

serviceNames, err := GetAllServiceNames()
if err != nil {
return "", err
return nil, err
}

names, ok := serviceNames[pid]
if ok {
return names, nil
}

return "", ErrServiceNotFound
return nil, ErrServiceNotFound
}

// GetAllServiceNames returns a list of service names assosicated with svchost.exe processes on Windows.
func GetAllServiceNames() (map[int32]string, error) {
func GetAllServiceNames() (map[int32][]string, error) {
output, err := exec.Command("tasklist", "/svc", "/fi", "imagename eq svchost.exe").Output()
if err != nil {
return nil, fmt.Errorf("failed to get svchost tasklist: %s", err)
Expand All @@ -66,8 +66,8 @@ func GetAllServiceNames() (map[int32]string, error) {

var (
pid int32
services string
collection = make(map[int32]string)
services []string
collection = make(map[int32][]string)
)

for scanner.Scan() {
Expand All @@ -83,11 +83,11 @@ func GetAllServiceNames() (map[int32]string, error) {
if fields[0] == "svchost.exe" {
// save old entry
if pid != 0 {
collection[pid] = strings.TrimSpace(services)
collection[pid] = services
}
// reset
// reset PID
pid = 0
services = ""
services = make([]string, 0, len(fields))

// check fields length
if len(fields) < 3 {
Expand All @@ -106,12 +106,14 @@ func GetAllServiceNames() (map[int32]string, error) {
}

// add service names
services += " " + strings.Join(fields, " ")
for _, field := range fields {
services = append(services, strings.Trim(strings.TrimSpace(field), ","))
}
}

if pid != 0 {
// save last entry
collection[pid] = strings.TrimSpace(services)
collection[pid] = services
}

return collection, nil
Expand Down

0 comments on commit 8471f4f

Please sign in to comment.