Skip to content

Commit

Permalink
database: Move db logic to dbutil
Browse files Browse the repository at this point in the history
Move all transaction related logic to dbutil to simplify and later unify
the db interface.
  • Loading branch information
KeyboardNerd committed Mar 5, 2019
1 parent 4fa03d1 commit 4669192
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 145 deletions.
47 changes: 12 additions & 35 deletions api/v3/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
package v3

import (
"fmt"

"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand All @@ -25,7 +23,6 @@ import (
pb "github.com/coreos/clair/api/v3/clairpb"
"github.com/coreos/clair/database"
"github.com/coreos/clair/ext/imagefmt"
"github.com/coreos/clair/pkg/commonerr"
"github.com/coreos/clair/pkg/pagination"
)

Expand Down Expand Up @@ -128,20 +125,13 @@ func (s *AncestryServer) GetAncestry(ctx context.Context, req *pb.GetAncestryReq
return nil, status.Errorf(codes.InvalidArgument, "ancestry name should not be empty")
}

tx, err := s.Store.Begin()
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}

defer tx.Rollback()

ancestry, ok, err := tx.FindAncestry(name)
ancestry, ok, err := database.FindAncestryAndRollback(s.Store, name)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
return nil, newRPCErrorWithClairError(codes.Internal, err)
}

if !ok {
return nil, status.Error(codes.NotFound, fmt.Sprintf("requested ancestry '%s' is not found", req.GetAncestryName()))
return nil, status.Errorf(codes.NotFound, "requested ancestry '%s' is not found", req.GetAncestryName())
}

pbAncestry := &pb.GetAncestryResponse_Ancestry{
Expand All @@ -150,7 +140,7 @@ func (s *AncestryServer) GetAncestry(ctx context.Context, req *pb.GetAncestryReq
}

for _, layer := range ancestry.Layers {
pbLayer, err := GetPbAncestryLayer(tx, layer)
pbLayer, err := s.GetPbAncestryLayer(layer)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -180,25 +170,20 @@ func (s *NotificationServer) GetNotification(ctx context.Context, req *pb.GetNot
return nil, status.Error(codes.InvalidArgument, "notification page limit should not be empty or less than 1")
}

tx, err := s.Store.Begin()
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
defer tx.Rollback()

dbNotification, ok, err := tx.FindVulnerabilityNotification(
dbNotification, ok, err := database.FindVulnerabilityNotificationAndRollback(
s.Store,
req.GetName(),
int(req.GetLimit()),
pagination.Token(req.GetOldVulnerabilityPage()),
pagination.Token(req.GetNewVulnerabilityPage()),
)

if err != nil {
return nil, status.Error(codes.Internal, err.Error())
return nil, newRPCErrorWithClairError(codes.Internal, err)
}

if !ok {
return nil, status.Error(codes.NotFound, fmt.Sprintf("requested notification '%s' is not found", req.GetName()))
return nil, status.Errorf(codes.NotFound, "requested notification '%s' is not found", req.GetName())
}

notification, err := pb.NotificationFromDatabaseModel(dbNotification)
Expand All @@ -216,21 +201,13 @@ func (s *NotificationServer) MarkNotificationAsRead(ctx context.Context, req *pb
return nil, status.Error(codes.InvalidArgument, "notification name should not be empty")
}

tx, err := s.Store.Begin()
found, err := database.MarkNotificationAsReadAndCommit(s.Store, req.GetName())
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}

defer tx.Rollback()
err = tx.DeleteNotification(req.GetName())
if err == commonerr.ErrNotFound {
return nil, status.Error(codes.NotFound, "requested notification \""+req.GetName()+"\" is not found")
} else if err != nil {
return nil, status.Error(codes.Internal, err.Error())
return nil, newRPCErrorWithClairError(codes.Internal, err)
}

if err := tx.Commit(); err != nil {
return nil, status.Error(codes.Internal, err.Error())
if !found {
return nil, status.Errorf(codes.NotFound, "requested notification '%s' is not found", req.GetName())
}

return &pb.MarkNotificationAsReadResponse{}, nil
Expand Down
12 changes: 5 additions & 7 deletions api/v3/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,24 +33,22 @@ func GetClairStatus(store database.Datastore) (*pb.ClairStatus, error) {

// GetPbAncestryLayer retrieves an ancestry layer with vulnerabilities and
// features in an ancestry based on the provided database layer.
func GetPbAncestryLayer(tx database.Session, layer database.AncestryLayer) (*pb.GetAncestryResponse_AncestryLayer, error) {
func (s *AncestryServer) GetPbAncestryLayer(layer database.AncestryLayer) (*pb.GetAncestryResponse_AncestryLayer, error) {
pbLayer := &pb.GetAncestryResponse_AncestryLayer{
Layer: &pb.Layer{
Hash: layer.Hash,
},
}

features := layer.GetFeatures()
affectedFeatures, err := tx.FindAffectedNamespacedFeatures(features)
affectedFeatures, err := database.FindAffectedNamespacedFeaturesAndRollback(s.Store, features)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
return nil, newRPCErrorWithClairError(codes.Internal, err)
}

// NOTE(sidac): It's quite inefficient, but the easiest way to implement
// this feature for now, we should refactor the implementation if there's
// any performance issue. It's expected that the number of features is less
// than 1000.
for _, feature := range affectedFeatures {
// TODO(sidac): This is wrong. Based on this usage, we should not
// return nullable feature at all.
if !feature.Valid {
return nil, status.Error(codes.Internal, "ancestry feature is not found")
}
Expand Down
125 changes: 125 additions & 0 deletions database/dbutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (

log "github.com/sirupsen/logrus"

"github.com/coreos/clair/pkg/commonerr"
"github.com/coreos/clair/pkg/pagination"
"github.com/deckarep/golang-set"
)

Expand Down Expand Up @@ -400,3 +402,126 @@ func PersistDetectorsAndCommit(store Datastore, detectors []Detector) error {

return nil
}

// MarkNotificationAsReadAndCommit marks a notification as read.
func MarkNotificationAsReadAndCommit(store Datastore, name string) (bool, error) {
tx, err := store.Begin()
if err != nil {
return false, err
}

defer tx.Rollback()
err = tx.DeleteNotification(name)
if err == commonerr.ErrNotFound {
return false, nil
} else if err != nil {
return false, err
}

if err := tx.Commit(); err != nil {
return false, err
}

return true, nil
}

// FindAffectedNamespacedFeaturesAndRollback finds the vulnerabilities on each
// feature.
func FindAffectedNamespacedFeaturesAndRollback(store Datastore, features []NamespacedFeature) ([]NullableAffectedNamespacedFeature, error) {
tx, err := store.Begin()
if err != nil {
return nil, err
}

defer tx.Rollback()
nullableFeatures, err := tx.FindAffectedNamespacedFeatures(features)
if err != nil {
return nil, err
}

return nullableFeatures, nil
}

// FindVulnerabilityNotificationAndRollback finds the vulnerability notification
// and rollback.
func FindVulnerabilityNotificationAndRollback(store Datastore, name string, limit int, oldVulnerabilityPage pagination.Token, newVulnerabilityPage pagination.Token) (noti VulnerabilityNotificationWithVulnerable, found bool, err error) {
tx, err := store.Begin()
if err != nil {
return
}

defer tx.Rollback()
noti, found, err = tx.FindVulnerabilityNotification(name, limit, oldVulnerabilityPage, newVulnerabilityPage)
return
}

// FindNewNotification finds notifications either never notified or notified
// before the given time.
func FindNewNotification(store Datastore, notifiedBefore time.Time) (NotificationHook, bool, error) {
tx, err := store.Begin()
if err != nil {
return NotificationHook{}, false, err
}

defer tx.Rollback()
return tx.FindNewNotification(notifiedBefore)
}

// UpdateKeyValueAndCommit stores the key value to storage.
func UpdateKeyValueAndCommit(store Datastore, key, value string) error {
tx, err := store.Begin()
if err != nil {
return err
}

defer tx.Rollback()
if err = tx.UpdateKeyValue(key, value); err != nil {
return err
}

return tx.Commit()
}

// InsertVulnerabilityNotificationsAndCommit inserts the notifications into db
// and commit.
func InsertVulnerabilityNotificationsAndCommit(store Datastore, notifications []VulnerabilityNotification) error {
tx, err := store.Begin()
if err != nil {
return err
}
defer tx.Rollback()

if err := tx.InsertVulnerabilityNotifications(notifications); err != nil {
return err
}

return tx.Commit()
}

// FindVulnerabilitiesAndRollback finds the vulnerabilities based on given ids.
func FindVulnerabilitiesAndRollback(store Datastore, ids []VulnerabilityID) ([]NullableVulnerability, error) {
tx, err := store.Begin()
if err != nil {
return nil, err
}

defer tx.Rollback()
return tx.FindVulnerabilities(ids)
}

func UpdateVulnerabilitiesAndCommit(store Datastore, toRemove []VulnerabilityID, toAdd []VulnerabilityWithAffected) error {
tx, err := store.Begin()
if err != nil {
return err
}

if err := tx.DeleteVulnerabilities(toRemove); err != nil {
return err
}

if err := tx.InsertVulnerabilities(toAdd); err != nil {
return err
}

return tx.Commit()
}
8 changes: 1 addition & 7 deletions ext/vulnsrc/suse/suse.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,10 @@ func init() {
func (u *updater) Update(datastore database.Datastore) (resp vulnsrc.UpdateResponse, err error) {
log.WithField("package", u.Name).Info("Start fetching vulnerabilities")

tx, err := datastore.Begin()
if err != nil {
return resp, err
}
defer tx.Rollback()

// openSUSE and SUSE have one single xml file for all the products, there are no incremental
// xml files. We store into the database the value of the generation timestamp
// of the latest file we parsed.
flagValue, ok, err := tx.FindKeyValue(u.UpdaterFlag)
flagValue, ok, err := database.FindKeyValueAndRollback(datastore, u.UpdaterFlag)
if err != nil {
return resp, err
}
Expand Down
7 changes: 0 additions & 7 deletions ext/vulnsrc/ubuntu/ubuntu.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,6 @@ func (u *updater) Update(db database.Datastore) (resp vulnsrc.UpdateResponse, er
return resp, err
}

// Open a database transaction.
tx, err := db.Begin()
if err != nil {
return resp, err
}
defer tx.Rollback()

// Ask the database for the latest commit we successfully applied.
dbCommit, ok, err := database.FindKeyValueAndRollback(db, updaterFlag)
if err != nil {
Expand Down
26 changes: 2 additions & 24 deletions notifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func RunNotifier(config *notification.Config, datastore database.Datastore, stop
go func() {
success, interrupted := handleTask(*notification, stopper, config.Attempts)
if success {
err := markNotificationAsRead(datastore, notification.Name)
_, err := database.MarkNotificationAsReadAndCommit(datastore, notification.Name)
if err != nil {
log.WithError(err).Error("Failed to mark notification notified")
}
Expand Down Expand Up @@ -126,7 +126,7 @@ func RunNotifier(config *notification.Config, datastore database.Datastore, stop

func findTask(datastore database.Datastore, renotifyInterval time.Duration, whoAmI string, stopper *stopper.Stopper) *database.NotificationHook {
for {
notification, ok, err := findNewNotification(datastore, renotifyInterval)
notification, ok, err := database.FindNewNotification(datastore, time.Now().Add(-renotifyInterval))
if err != nil || !ok {
if !ok {
log.WithError(err).Warning("could not get notification to send")
Expand Down Expand Up @@ -186,25 +186,3 @@ func handleTask(n database.NotificationHook, st *stopper.Stopper, maxAttempts in
log.WithField(logNotiName, n.Name).Info("successfully sent notification")
return true, false
}

func findNewNotification(datastore database.Datastore, renotifyInterval time.Duration) (database.NotificationHook, bool, error) {
tx, err := datastore.Begin()
if err != nil {
return database.NotificationHook{}, false, err
}
defer tx.Rollback()
return tx.FindNewNotification(time.Now().Add(-renotifyInterval))
}

func markNotificationAsRead(datastore database.Datastore, name string) error {
tx, err := datastore.Begin()
if err != nil {
log.WithError(err).Error("an error happens when beginning database transaction")
}
defer tx.Rollback()

if err := tx.MarkNotificationAsRead(name); err != nil {
return err
}
return tx.Commit()
}
Loading

0 comments on commit 4669192

Please sign in to comment.