diff --git a/server/service/accessControl.go b/server/service/accessControl.go index e493ffe..ac069ae 100644 --- a/server/service/accessControl.go +++ b/server/service/accessControl.go @@ -5,6 +5,7 @@ import ( "io/ioutil" "net/http" "net/url" + "strings" ) // AccessProfile holds information about which hosts the user is allowed access to, @@ -22,6 +23,22 @@ func (ap *AccessProfile) IsAdmin() bool { return ap.isAdmin } +func (ap *AccessProfile) GetSQLWHERE() string { + //TODO this is temporary (bad) solution; will be removed later. See issue #67 + var sql strings.Builder + sql.WriteString("'") + frist := true + for k := range ap.certs { + if !frist { + sql.WriteString("','") + } + sql.WriteString(k) + frist = false + } + sql.WriteString("'") + return sql.String() +} + func GenerateAccessProfileForUser(userID string) (*AccessProfile, error) { if authorizationPluginURL == "" { // If no authorization plugin is defined, diff --git a/server/service/api_hostlist.go b/server/service/api_hostlist.go index fcab42b..6d1941d 100644 --- a/server/service/api_hostlist.go +++ b/server/service/api_hostlist.go @@ -77,7 +77,7 @@ func (vars *apiMethodHostList) ServeHTTP(w http.ResponseWriter, req *http.Reques // Grouping changes the whole SQL statement and what's returned, // so it is handled in a separate function if req.FormValue("group") != "" { - performGroupQuery(w, req, vars.db, customFieldIDs, vars.devmode) + performGroupQuery(w, req, vars.db, customFieldIDs, vars.devmode, access) return } @@ -305,7 +305,7 @@ func (vars *apiMethodHostList) ServeHTTP(w http.ResponseWriter, req *http.Reques } func performGroupQuery(w http.ResponseWriter, req *http.Request, - db *sql.DB, customFieldIDs map[string]int, devmode bool) { + db *sql.DB, customFieldIDs map[string]int, devmode bool, access *AccessProfile) { if req.FormValue("fields") != "" { http.Error(w, "Can't combine group and fields parameters", http.StatusUnprocessableEntity) @@ -362,6 +362,11 @@ func performGroupQuery(w http.ResponseWriter, req *http.Request, if len(where) > 0 { statement += " WHERE " + where + if access != nil && !access.IsAdmin() { + statement += " AND certfp IN (" + access.GetSQLWHERE() + ")" + } + } else if access != nil && !access.IsAdmin() { + statement += " WHERE certfp IN (" + access.GetSQLWHERE() + ")" } statement += " GROUP BY " + colname diff --git a/server/service/api_hostlist_test.go b/server/service/api_hostlist_test.go index 4aac8b9..d6ec10c 100644 --- a/server/service/api_hostlist_test.go +++ b/server/service/api_hostlist_test.go @@ -153,6 +153,20 @@ func TestApiMethodHostList(t *testing.T) { expectStatus: http.StatusOK, expectJSON: "{\"workstation\":2}", }, + // Test with an access profile that should prevent some hosts from being counted + { + methodAndPath: "GET /api/v0/hostlist?group=osEdition", + expectStatus: http.StatusOK, + expectJSON: "{\"workstation\":1}", + accessProfile: &AccessProfile{isAdmin: false, certs: map[string]bool{"1111": true}}, + }, + // Test with an access profile that should prevent some hosts from being counted + { + methodAndPath: "GET /api/v0/hostlist?group=osEdition&hostname=*baz*", + expectStatus: http.StatusOK, + expectJSON: "{}", + accessProfile: &AccessProfile{isAdmin: false, certs: map[string]bool{"1111": true}}, + }, } db := getDBconnForTesting(t) diff --git a/server/service/api_internal.go b/server/service/api_internal.go index d6e9e6d..fbabb83 100644 --- a/server/service/api_internal.go +++ b/server/service/api_internal.go @@ -65,10 +65,14 @@ func countFiles(w http.ResponseWriter, req *http.Request) { return } i, err := strconv.Atoi(req.FormValue("n")) - if err != nil || i == 0 { + if err != nil { + http.Error(w, "Invalid number: "+req.FormValue("n"), http.StatusBadRequest) return } - pfib.Add(float64(i)) // pfib = parsed files interval buffer + if i > 0 { + pfib.Add(float64(i)) // pfib = parsed files interval buffer + } + http.Error(w, "OK", http.StatusNoContent) } func doNothing(w http.ResponseWriter, req *http.Request) { diff --git a/server/service/fastSearch.go b/server/service/fastSearch.go index 2050eaf..a3ff077 100644 --- a/server/service/fastSearch.go +++ b/server/service/fastSearch.go @@ -111,8 +111,8 @@ func compareSearchCacheToDB(db *sql.DB) { log.Panic(rows.Err()) } // find entries in the cache that should have been removed - fsMutex.RLock() - defer fsMutex.RUnlock() + fsMutex.Lock() + defer fsMutex.Unlock() var obsoleteCount int for fileID := range fsKey { if _, ok := source[fileID]; !ok { diff --git a/server/service/oauth2login.go b/server/service/oauth2login.go index 9c36c26..6e67e7e 100644 --- a/server/service/oauth2login.go +++ b/server/service/oauth2login.go @@ -118,7 +118,9 @@ func handleOauth2Redirect(w http.ResponseWriter, req *http.Request) { http.Error(w, "Error reading Userinfo from Oauth2 provider", http.StatusInternalServerError) return } - log.Printf("Oauth2: Userinfo: %s", string(body)) + if devmode { + log.Printf("Oauth2: Userinfo: %s", string(body)) + } // Parse the JSON var userinfo interface{}