Skip to content

Commit

Permalink
enable setup common name whitelist for tls checking
Browse files Browse the repository at this point in the history
Signed-off-by: Frank Yang <[email protected]>
  • Loading branch information
allencloud authored and yyb196 committed Apr 2, 2018
2 parents caf4ea8 + 53247a3 commit 45f33ea
Show file tree
Hide file tree
Showing 12 changed files with 324 additions and 110 deletions.
91 changes: 51 additions & 40 deletions apis/server/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,53 +24,53 @@ func initRoute(s *Server) http.Handler {
r := mux.NewRouter()

// system
addRoute(r, http.MethodGet, "/_ping", s.ping)
addRoute(r, http.MethodGet, "/info", s.info)
addRoute(r, http.MethodGet, "/version", s.version)
addRoute(r, http.MethodPost, "/auth", s.auth)
s.addRoute(r, http.MethodGet, "/_ping", s.ping)
s.addRoute(r, http.MethodGet, "/info", s.info)
s.addRoute(r, http.MethodGet, "/version", s.version)
s.addRoute(r, http.MethodPost, "/auth", s.auth)

// daemon, we still list this API into system manager.
addRoute(r, http.MethodPost, "/daemon/update", s.updateDaemon)
s.addRoute(r, http.MethodPost, "/daemon/update", s.updateDaemon)

// container
addRoute(r, http.MethodPost, "/containers/create", s.createContainer)
addRoute(r, http.MethodPost, "/containers/{name:.*}/start", s.startContainer)
addRoute(r, http.MethodPost, "/containers/{name:.*}/stop", s.stopContainer)
addRoute(r, http.MethodPost, "/containers/{name:.*}/attach", s.attachContainer)
addRoute(r, http.MethodGet, "/containers/json", s.getContainers)
addRoute(r, http.MethodGet, "/containers/{name:.*}/json", s.getContainer)
addRoute(r, http.MethodDelete, "/containers/{name:.*}", s.removeContainers)
addRoute(r, http.MethodPost, "/containers/{name:.*}/exec", s.createContainerExec)
addRoute(r, http.MethodPost, "/exec/{name:.*}/start", s.startContainerExec)
addRoute(r, http.MethodPost, "/containers/{name:.*}/rename", s.renameContainer)
addRoute(r, http.MethodPost, "/containers/{name:.*}/restart", s.restartContainer)
addRoute(r, http.MethodPost, "/containers/{name:.*}/pause", s.pauseContainer)
addRoute(r, http.MethodPost, "/containers/{name:.*}/unpause", s.unpauseContainer)
addRoute(r, http.MethodPost, "/containers/{name:.*}/update", s.updateContainer)
addRoute(r, http.MethodPost, "/containers/{name:.*}/upgrade", s.upgradeContainer)
addRoute(r, http.MethodGet, "/containers/{name:.*}/top", s.topContainer)
addRoute(r, http.MethodGet, "/containers/{name:.*}/logs", s.logsContainer)
addRoute(r, http.MethodPost, "/containers/{name:.*}/resize", s.resizeContainer)
addRoute(r, http.MethodPost, "/containers/{name:.*}/restart", s.restartContainer)
s.addRoute(r, http.MethodPost, "/containers/create", s.createContainer)
s.addRoute(r, http.MethodPost, "/containers/{name:.*}/start", s.startContainer)
s.addRoute(r, http.MethodPost, "/containers/{name:.*}/stop", s.stopContainer)
s.addRoute(r, http.MethodPost, "/containers/{name:.*}/attach", s.attachContainer)
s.addRoute(r, http.MethodGet, "/containers/json", s.getContainers)
s.addRoute(r, http.MethodGet, "/containers/{name:.*}/json", s.getContainer)
s.addRoute(r, http.MethodDelete, "/containers/{name:.*}", s.removeContainers)
s.addRoute(r, http.MethodPost, "/containers/{name:.*}/exec", s.createContainerExec)
s.addRoute(r, http.MethodPost, "/exec/{name:.*}/start", s.startContainerExec)
s.addRoute(r, http.MethodPost, "/containers/{name:.*}/rename", s.renameContainer)
s.addRoute(r, http.MethodPost, "/containers/{name:.*}/restart", s.restartContainer)
s.addRoute(r, http.MethodPost, "/containers/{name:.*}/pause", s.pauseContainer)
s.addRoute(r, http.MethodPost, "/containers/{name:.*}/unpause", s.unpauseContainer)
s.addRoute(r, http.MethodPost, "/containers/{name:.*}/update", s.updateContainer)
s.addRoute(r, http.MethodPost, "/containers/{name:.*}/upgrade", s.upgradeContainer)
s.addRoute(r, http.MethodGet, "/containers/{name:.*}/top", s.topContainer)
s.addRoute(r, http.MethodGet, "/containers/{name:.*}/logs", s.logsContainer)
s.addRoute(r, http.MethodPost, "/containers/{name:.*}/resize", s.resizeContainer)
s.addRoute(r, http.MethodPost, "/containers/{name:.*}/restart", s.restartContainer)

// image
addRoute(r, http.MethodPost, "/images/create", s.pullImage)
addRoute(r, http.MethodPost, "/images/search", s.searchImages)
addRoute(r, http.MethodGet, "/images/json", s.listImages)
addRoute(r, http.MethodDelete, "/images/{name:.*}", s.removeImage)
addRoute(r, http.MethodGet, "/images/{name:.*}/json", s.getImage)
s.addRoute(r, http.MethodPost, "/images/create", s.pullImage)
s.addRoute(r, http.MethodPost, "/images/search", s.searchImages)
s.addRoute(r, http.MethodGet, "/images/json", s.listImages)
s.addRoute(r, http.MethodDelete, "/images/{name:.*}", s.removeImage)
s.addRoute(r, http.MethodGet, "/images/{name:.*}/json", s.getImage)

// volume
addRoute(r, http.MethodGet, "/volumes", s.listVolume)
addRoute(r, http.MethodPost, "/volumes/create", s.createVolume)
addRoute(r, http.MethodGet, "/volumes/{name:.*}", s.getVolume)
addRoute(r, http.MethodDelete, "/volumes/{name:.*}", s.removeVolume)
s.addRoute(r, http.MethodGet, "/volumes", s.listVolume)
s.addRoute(r, http.MethodPost, "/volumes/create", s.createVolume)
s.addRoute(r, http.MethodGet, "/volumes/{name:.*}", s.getVolume)
s.addRoute(r, http.MethodDelete, "/volumes/{name:.*}", s.removeVolume)

// network
addRoute(r, http.MethodGet, "/networks", s.listNetwork)
addRoute(r, http.MethodPost, "/networks/create", s.createNetwork)
addRoute(r, http.MethodGet, "/networks/{name:.*}", s.getNetwork)
addRoute(r, http.MethodDelete, "/networks/{name:.*}", s.deleteNetwork)
s.addRoute(r, http.MethodGet, "/networks", s.listNetwork)
s.addRoute(r, http.MethodPost, "/networks/create", s.createNetwork)
s.addRoute(r, http.MethodGet, "/networks/{name:.*}", s.getNetwork)
s.addRoute(r, http.MethodDelete, "/networks/{name:.*}", s.deleteNetwork)

// metrics
r.Path(versionMatcher + "/metrics").Methods(http.MethodGet).Handler(prometheus.Handler())
Expand All @@ -81,8 +81,8 @@ func initRoute(s *Server) http.Handler {
return r
}

func addRoute(r *mux.Router, mothod string, path string, f func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error) {
r.Path(versionMatcher + path).Methods(mothod).Handler(filter(f))
func (s *Server) addRoute(r *mux.Router, mothod string, path string, f func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error) {
r.Path(versionMatcher + path).Methods(mothod).Handler(filter(f, s))
}

func profilerSetup(mainRouter *mux.Router) {
Expand All @@ -100,13 +100,24 @@ func profilerSetup(mainRouter *mux.Router) {

type handler func(context.Context, http.ResponseWriter, *http.Request) error

func filter(handler handler) http.HandlerFunc {
func filter(handler handler, s *Server) http.HandlerFunc {
pctx := context.Background()

return func(w http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithCancel(pctx)
defer cancel()

s.lock.RLock()
if len(s.ManagerWhiteList) > 0 && req.TLS != nil && len(req.TLS.PeerCertificates) > 0 {
if _, isManager := s.ManagerWhiteList[req.TLS.PeerCertificates[0].Subject.CommonName]; !isManager {
s.lock.RUnlock()
w.WriteHeader(http.StatusForbidden)
w.Write([]byte("tls verified error."))
return
}
}
s.lock.RUnlock()

t := time.Now()
clientInfo := req.RemoteAddr
defer func() {
Expand Down
33 changes: 25 additions & 8 deletions apis/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"os"
"strings"
"sync"
"syscall"

"github.com/alibaba/pouch/apis/plugins"
Expand All @@ -19,14 +20,16 @@ import (

// Server is a http server which serves restful api to client.
type Server struct {
Config *config.Config
ContainerMgr mgr.ContainerMgr
SystemMgr mgr.SystemMgr
ImageMgr mgr.ImageMgr
VolumeMgr mgr.VolumeMgr
NetworkMgr mgr.NetworkMgr
listeners []net.Listener
ContainerPlugin plugins.ContainerPlugin
Config *config.Config
ContainerMgr mgr.ContainerMgr
SystemMgr mgr.SystemMgr
ImageMgr mgr.ImageMgr
VolumeMgr mgr.VolumeMgr
NetworkMgr mgr.NetworkMgr
listeners []net.Listener
ContainerPlugin plugins.ContainerPlugin
ManagerWhiteList map[string]struct{}
lock sync.RWMutex
}

// Start setup route table and listen to specified address which currently only supports unix socket and tcp address.
Expand All @@ -51,6 +54,7 @@ func (s *Server) Start() (err error) {
if s.Config.TLS.VerifyRemote {
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
}
SetupManagerWhitelist(s)
}

for _, one := range s.Config.Listen {
Expand All @@ -70,6 +74,19 @@ func (s *Server) Start() (err error) {
return <-errCh
}

// SetupManagerWhitelist enables users to setup which common name can access this server
func SetupManagerWhitelist(server *Server) {
if server.Config.TLS.ManagerWhiteList != "" {
server.lock.Lock()
defer server.lock.Unlock()
arr := strings.Split(server.Config.TLS.ManagerWhiteList, ",")
server.ManagerWhiteList = make(map[string]struct{}, len(arr))
for _, cn := range arr {
server.ManagerWhiteList[cn] = struct{}{}
}
}
}

// Stop will shutdown http server by closing all listeners.
func (s *Server) Stop() error {
for _, one := range s.listeners {
Expand Down
15 changes: 7 additions & 8 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@ type APIClient struct {

// TLSConfig contains information of tls which users can specify
type TLSConfig struct {
CA string `json:"tlscacert,omitempty"`
Cert string `json:"tlscert,omitempty"`
Key string `json:"tlskey,omitempty"`
VerifyRemote bool
CA string `json:"tlscacert,omitempty"`
Cert string `json:"tlscert,omitempty"`
Key string `json:"tlskey,omitempty"`
VerifyRemote bool
ManagerWhiteList string
}

// NewAPIClient initializes a new API client for the given host
Expand Down Expand Up @@ -185,10 +186,8 @@ func GenTLSConfig(key, cert, ca string) (*tls.Config, error) {
if ca == "" {
return tlsConfig, nil
}
cp, err := x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("failed to read system certificates: %v", err)
}

cp := x509.NewCertPool()
pem, err := ioutil.ReadFile(ca)
if err != nil {
return nil, fmt.Errorf("failed to read CA certificate %q: %v", ca, err)
Expand Down
22 changes: 22 additions & 0 deletions client/registry_login.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package client

import (
"context"
"net/http"

"github.com/alibaba/pouch/apis/types"
)

// RegistryLogin authenticates the server with a given registry to login.
func (client *APIClient) RegistryLogin(ctx context.Context, auth *types.AuthConfig) (*types.AuthResponse, error) {
resp, err := client.post(ctx, "/auth", nil, auth, nil)
if err != nil || resp.StatusCode == http.StatusUnauthorized {
return nil, err
}

authResp := &types.AuthResponse{}
err = decodeBody(authResp, resp.Body)
ensureCloseReader(resp)

return authResp, err
}
64 changes: 64 additions & 0 deletions client/registry_login_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package client

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"strings"
"testing"

"github.com/alibaba/pouch/apis/types"
"github.com/stretchr/testify/assert"
)

func TestRegistryLoginError(t *testing.T) {
client := &APIClient{
HTTPCli: newMockClient(errorMockResponse(http.StatusInternalServerError, "Server error")),
}
loginConfig := types.AuthConfig{}
_, err := client.RegistryLogin(context.Background(), &loginConfig)
if err == nil || !strings.Contains(err.Error(), "Server error") {
t.Fatalf("expected a Server Error, got %v", err)
}
}

func TestRegistryLogin(t *testing.T) {
expectedURL := "/auth"

httpClient := newMockClient(func(req *http.Request) (*http.Response, error) {
if !strings.HasPrefix(req.URL.Path, expectedURL) {
return nil, fmt.Errorf("Expected URL '%s', got '%s'", expectedURL, req.URL)
}
if req.Header.Get("Content-Type") == "application/json" {
loginConfig := types.AuthConfig{}
if err := json.NewDecoder(req.Body).Decode(&loginConfig); err != nil {
return nil, fmt.Errorf("failed to parse json: %v", err)
}
}
auth, err := json.Marshal(types.AuthResponse{
IdentityToken: "aaa",
Status: "bbb",
})
if err != nil {
return nil, err
}
return &http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(bytes.NewReader([]byte(auth))),
}, nil
})

client := &APIClient{
HTTPCli: httpClient,
}

res, err := client.RegistryLogin(context.Background(), &types.AuthConfig{})
if err != nil {
t.Fatal(err)
}
assert.Equal(t, res.IdentityToken, "aaa")
assert.Equal(t, res.Status, "bbb")
}
53 changes: 0 additions & 53 deletions client/system.go

This file was deleted.

21 changes: 21 additions & 0 deletions client/system_info.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package client

import (
"context"

"github.com/alibaba/pouch/apis/types"
)

// SystemInfo requests daemon for system info.
func (client *APIClient) SystemInfo(ctx context.Context) (*types.SystemInfo, error) {
resp, err := client.get(ctx, "/info", nil, nil)
if err != nil {
return nil, err
}

info := &types.SystemInfo{}
err = decodeBody(info, resp.Body)
ensureCloseReader(resp)

return info, err
}
Loading

0 comments on commit 45f33ea

Please sign in to comment.