Skip to content

Commit

Permalink
Check that requests to docker endpoint are against rdns-valid hostname (
Browse files Browse the repository at this point in the history
  • Loading branch information
Ian Whitlock authored Apr 12, 2018
1 parent 68a3060 commit e6e9ada
Show file tree
Hide file tree
Showing 6 changed files with 267 additions and 1 deletion.
20 changes: 20 additions & 0 deletions cmd/docker/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@ import (

vicbackends "github.com/vmware/vic/lib/apiservers/engine/backends"
"github.com/vmware/vic/lib/apiservers/engine/backends/executor"
vicmiddleware "github.com/vmware/vic/lib/apiservers/engine/backends/middleware"
"github.com/vmware/vic/lib/config"
"github.com/vmware/vic/lib/constants"
vicdns "github.com/vmware/vic/lib/dns"
"github.com/vmware/vic/lib/portlayer/util"
"github.com/vmware/vic/lib/pprof"
viclog "github.com/vmware/vic/pkg/log"
Expand Down Expand Up @@ -196,6 +198,16 @@ func loadCAPool() *x509.CertPool {
return pool
}

func enforceHostHeaderCheck(addr string, api *apiserver.Server) {
rdnsNames := vicdns.ReverseLookup(addr)
if len(rdnsNames) == 0 {
log.Warnf("Could not resolve domain names for %s. Docker endpoint will only be accessible via the IP", addr)
}
rdnsNames[addr] = true // add the client IP because that's always allowed
hostCheckMW := vicmiddleware.HostCheckMiddleware{ValidDomains: rdnsNames}
api.UseMiddleware(hostCheckMW)
}

func startServer() *apiserver.Server {
serverConfig := &apiserver.Config{
Logging: true,
Expand Down Expand Up @@ -246,6 +258,8 @@ func startServer() *apiserver.Server {
tlsConfig.ClientCAs = loadCAPool()
tlsConfig.InsecureSkipVerify = false
}
} else {
log.Warnf("Docker endpoint running in plain HTTP mode")
}

addr := "0.0.0.0"
Expand All @@ -268,6 +282,12 @@ func startServer() *apiserver.Server {
version.DockerDefaultVersion,
version.DockerMinimumVersion)
api.UseMiddleware(mw)

if vchConfig.HostCertificate.IsNil() && vchConfig.Diagnostics.DebugLevel <= 2 {
// only enforce host header check in non-debug http-only mode
enforceHostHeaderCheck(addr, api)
}

fullserver := fmt.Sprintf("%s:%d", addr, *cli.serverPort)
l, err := listeners.Init(cli.proto, fullserver, "", serverConfig.TLSConfig)
if err != nil {
Expand Down
1 change: 0 additions & 1 deletion lib/apiservers/engine/backends/backends.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ import (
"github.com/vmware/vic/pkg/registry"
"github.com/vmware/vic/pkg/vsphere/session"
"github.com/vmware/vic/pkg/vsphere/sys"
//"github.com/vishvananda/netlink"
)

const (
Expand Down
72 changes: 72 additions & 0 deletions lib/apiservers/engine/backends/middleware/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright 2018 VMware, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package middleware

import (
"fmt"
"net/http"
"strings"

"golang.org/x/net/context"

vicdns "github.com/vmware/vic/lib/dns"
)

// HostCheckMiddleware provides middleware for Host header correctness enforcement
type HostCheckMiddleware struct {
ValidDomains vicdns.SetOfDomains
}

// validateHostname trims the port off the Host header in an HTTP request and returns either the bare IP (v4 or v6) or the FQDN with the port truncated. Returns non-nil error if Host field doesn't make sense.
func validateHostname(r *http.Request) (hostname string, err error) {
if r.Host == "" {
// this really shouldn't be necessary https://tools.ietf.org/html/rfc2616#section-14.23
// you can delete this if stanza if you're braver than me.
return "", fmt.Errorf("empty host header from %s", r.RemoteAddr)
}

if r.Host[len(r.Host)-1] == ']' {
// ipv6 w/o specified port
return r.Host, nil
}

// trim port if it's there. r.Host should never contain a scheme
hostnameSplit := strings.Split(r.Host, ":")

if len(hostnameSplit) <= 2 {
// ipv4 or dns hostname with or without port, first element is hostname
return hostnameSplit[0], nil
}

// if we see >2 colons in the hostname, it's an ipv6 address w/ port
// unfortunately that means we have to recombine the rest..
return fmt.Sprintf("%s", strings.Join(hostnameSplit[:len(hostnameSplit)-1], ":")), nil
}

// WrapHandler satisfies the Docker middleware interface for HostCheckMiddleware to reject http requests that do not specify a known DNS name for this endpoint in the Host: field of the request
func (h HostCheckMiddleware) WrapHandler(f func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error) func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) (err error) {
var hostname string
if hostname, err = validateHostname(r); err != nil {
return err
}

if h.ValidDomains[hostname] {
return f(ctx, w, r, vars)
}

return fmt.Errorf("invalid host header from %s to requested host %s", r.RemoteAddr, r.Host)
}
}
75 changes: 75 additions & 0 deletions lib/apiservers/engine/backends/middleware/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright 2018 VMware, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package middleware

import (
"net/http"
"testing"

"github.com/stretchr/testify/assert"
)

func TestValidateHostname(t *testing.T) {
r := &http.Request{}
hostname, err := validateHostname(r)
assert.Error(t, err)
assert.EqualValues(t, "", hostname)

r.Host = ""
hostname, err = validateHostname(r)
assert.Error(t, err)
assert.EqualValues(t, "", hostname)

r.Host = "localname"
hostname, err = validateHostname(r)
assert.NoError(t, err)
assert.EqualValues(t, "localname", hostname)

r.Host = "localname:4567"
hostname, err = validateHostname(r)
assert.NoError(t, err)
assert.EqualValues(t, "localname", hostname)

r.Host = "[2605:a601:1119:6800:c69b:b2ec:eefa:ef4b]"
hostname, err = validateHostname(r)
assert.NoError(t, err)
assert.EqualValues(t, "[2605:a601:1119:6800:c69b:b2ec:eefa:ef4b]", hostname)

r.Host = "[2605:a601:1119:6800:c69b:b2ec:eefa:ef4b]:8080"
hostname, err = validateHostname(r)
assert.NoError(t, err)
assert.EqualValues(t, "[2605:a601:1119:6800:c69b:b2ec:eefa:ef4b]", hostname)

r.Host = "127.0.0.1:8080"
hostname, err = validateHostname(r)
assert.NoError(t, err)
assert.EqualValues(t, "127.0.0.1", hostname)

r.Host = "127.0.0.1"
hostname, err = validateHostname(r)
assert.NoError(t, err)
assert.EqualValues(t, "127.0.0.1", hostname)

r.Host = "foo.com:80"
hostname, err = validateHostname(r)
assert.NoError(t, err)
assert.EqualValues(t, "foo.com", hostname)

r.Host = "foo.com"
hostname, err = validateHostname(r)
assert.NoError(t, err)
assert.EqualValues(t, "foo.com", hostname)

}
80 changes: 80 additions & 0 deletions lib/dns/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"math/rand"
"net"
"os"
"strconv"
"strings"

"sync"
Expand All @@ -39,6 +40,7 @@ const (
DefaultTTL = 600 * time.Second
DefaultCacheSize = 1024
DefaultTimeout = 4 * time.Second
hexDigit = "0123456789abcdef"
)

var (
Expand Down Expand Up @@ -639,3 +641,81 @@ func (s *Server) Stop() {
func (s *Server) Wait() {
s.wg.Wait()
}

// SetOfDomains is a type for storing string-type domain names as an unsorted set
// var f SetOfDomains
// f = make(map[string]bool)
// Store in the set
// f["foo.com"] = true
// then to check to see if something is in the 'set':
// if f["foo.com"] {
type SetOfDomains map[string]bool

// ReverseLookup returns a set of FQDNs for ipAddr from nameservers in /etc/resolv.conf
// /etc/hosts and /etc/nsswitch.conf are ignored by this function
func ReverseLookup(ipAddr string) (domains SetOfDomains) {
domains = make(map[string]bool)

address, err := reverseaddr(ipAddr)
if err != nil {
log.Errorf("%s", err)
return
}

nameservers := resolvconf()
for _, n := range nameservers {
dnsClient := new(mdns.Client)
msg := new(mdns.Msg)

msg.SetQuestion(address, mdns.TypePTR)
r, _, err := dnsClient.Exchange(msg, n+":53")
if err != nil {
log.Warnf("got error \"%s\" from %s", err, n)
continue
}

if len(r.Answer) == 0 {
log.Warnf("no reply from %s", n)
continue
}

for _, a := range r.Answer {
switch a := a.(type) {
case *mdns.PTR:
// cut the . off the end of the returned record & store it
domains[strings.TrimSuffix(a.Ptr, ".")] = true
default:
log.Debugf("got nonstandard answer %s (from nameserver %s)", a, n)
}
}
}

return
}

// reverseaddr returns the in-addr.arpa. or ip6.arpa. hostname of the IP
// address addr suitable for rDNS (PTR) record lookup or an error if it fails
// to parse the IP address.
// this helper func was lifted from stdlib -- net/dnsclient.go
func reverseaddr(addr string) (arpa string, err error) {
ip := net.ParseIP(addr)
if ip == nil {
return "", &net.DNSError{Err: "unrecognized address", Name: addr}
}
if ip.To4() != nil {
return strconv.FormatUint(uint64(ip[15]), 10) + "." + strconv.FormatUint(uint64(ip[14]), 10) + "." + strconv.FormatUint(uint64(ip[13]), 10) + "." + strconv.FormatUint(uint64(ip[12]), 10) + ".in-addr.arpa.", nil
}
// Must be IPv6
buf := make([]byte, 0, len(ip)*4+len("ip6.arpa."))
// Add it, in reverse, to the buffer
for i := len(ip) - 1; i >= 0; i-- {
v := ip[i]
buf = append(buf, hexDigit[v&0xF])
buf = append(buf, '.')
buf = append(buf, hexDigit[v>>4])
buf = append(buf, '.')
}
// Append "ip6.arpa." and return (buf already has the final .)
buf = append(buf, "ip6.arpa."...)
return string(buf), nil
}
20 changes: 20 additions & 0 deletions tests/test-cases/Group6-VIC-Machine/6-13-TLS.robot
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,24 @@ Resource ../../resources/Util.robot
Test Teardown Run Keyword If Test Failed Cleanup VIC Appliance On Test Server
Test Timeout 20 minutes

*** Keywords ***
Check that requests with invalid host field are rejected
${rc} ${output}= Run And Return Rc And Output curl -vvv -H"Host: please.ddos" %{DOCKER_HOST}/_ping
Should Contain ${output} invalid host header
Should Contain ${output} 500 Internal Server Error
Should Contain ${output} to requested host please.ddos
Should Not Contain ${output} 200 OK

Check that normal requests are accepted
${rc} ${output}= Run And Return Rc And Output curl -vvv %{DOCKER_HOST}/_ping
Should Contain ${output} 200 OK
Should Not Contain ${output} 500 Internal Server Error

# also make sure it works if the port isn't part of the host header
${rc} ${output}= Run And Return Rc And Output curl -vvv -H"Host: %{VCH-IP}" %{DOCKER_HOST}/_ping
Should Contain ${output} 200 OK
Should Not Contain ${output} 500 Internal Server Error

*** Test Cases ***
Create VCH - defaults with --no-tls
Set Test Environment Variables
Expand All @@ -29,6 +47,8 @@ Create VCH - defaults with --no-tls
Get Docker Params ${output} ${true}
Log To Console Installer completed successfully: %{VCH-NAME}

Check that requests with invalid host field are rejected
Check that normal requests are accepted

Run Regression Tests
Cleanup VIC Appliance On Test Server
Expand Down

0 comments on commit e6e9ada

Please sign in to comment.