diff --git a/cmd/server.go b/cmd/server.go index 877aa6c25c..d74af5890c 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -119,6 +119,7 @@ const ( WebBasicAuthFlag = "web-basic-auth" WebUsernameFlag = "web-username" WebPasswordFlag = "web-password" + WebsocketCheckOrigin = "websocket-check-origin" // NOTE: Must manually set these as defaults in the setDefaults function. DefaultADBasicUser = "" @@ -470,6 +471,10 @@ var boolFlags = map[string]boolFlag{ description: "Switches on or off the Basic Authentication on the HTTP Middleware interface", defaultValue: DefaultWebBasicAuth, }, + WebsocketCheckOrigin: { + description: "Enable websocket origin check", + defaultValue: false, + }, } var intFlags = map[string]intFlag{ ParallelPoolSize: { @@ -682,6 +687,7 @@ func (s *ServerCmd) run() error { RepoConfigJSONFlag: RepoConfigJSONFlag, SilenceForkPRErrorsFlag: SilenceForkPRErrorsFlag, }) + if err != nil { return errors.Wrap(err, "initializing server") } diff --git a/runatlantis.io/docs/server-configuration.md b/runatlantis.io/docs/server-configuration.md index 6cbf5617e8..16a9402887 100644 --- a/runatlantis.io/docs/server-configuration.md +++ b/runatlantis.io/docs/server-configuration.md @@ -767,3 +767,9 @@ Values are chosen in this order: atlantis server --web-password="atlantis" ``` Password used for Basic Authentication on the Atlantis web service. Defaults to `atlantis`. + +### `--websocket-check-origin` + ```bash + atlantis server --websocket-check-origin + ``` + Only allow websockets connection when they originate from the running Atlantis web server \ No newline at end of file diff --git a/server/controllers/websocket/mux.go b/server/controllers/websocket/mux.go index 8288df3212..e0924a2e52 100644 --- a/server/controllers/websocket/mux.go +++ b/server/controllers/websocket/mux.go @@ -31,9 +31,19 @@ type Multiplexor struct { registry PartitionRegistry } -func NewMultiplexor(log logging.SimpleLogging, keyGenerator PartitionKeyGenerator, registry PartitionRegistry) *Multiplexor { - upgrader := websocket.Upgrader{} - upgrader.CheckOrigin = func(r *http.Request) bool { return true } +func checkOriginFunc(checkOrigin bool) func(r *http.Request) bool { + if checkOrigin { + return nil // use Gorilla websocket's checkSameOrigin + } + return func(r *http.Request) bool { + return true + } +} + +func NewMultiplexor(log logging.SimpleLogging, keyGenerator PartitionKeyGenerator, registry PartitionRegistry, checkOrigin bool) *Multiplexor { + upgrader := websocket.Upgrader{ + CheckOrigin: checkOriginFunc(checkOrigin), + } return &Multiplexor{ writer: &Writer{ upgrader: upgrader, diff --git a/server/controllers/websocket/mux_test.go b/server/controllers/websocket/mux_test.go new file mode 100644 index 0000000000..baddfb2196 --- /dev/null +++ b/server/controllers/websocket/mux_test.go @@ -0,0 +1,61 @@ +package websocket + +import ( + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gorilla/websocket" +) + +func wsHandler(t *testing.T, checkOrigin bool) http.HandlerFunc { + upgrader := websocket.Upgrader{ + CheckOrigin: checkOriginFunc(checkOrigin), + } + return func(w http.ResponseWriter, r *http.Request) { + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Log("upgrade:", err) + return + } + defer c.Close() + } +} + +func TestCheckOriginFunc(t *testing.T) { + + tests := []struct { + name string + checkOrigin bool + origin string + host string + wantErr bool + }{ + {"same origin", true, "http://example.com/", "example.com", false}, + {"same origin with port", true, "http://example.com:8080/", "example.com:8080", false}, + {"fail with different origin", true, "http://example.net/", "example.com", true}, + {"success with same origin without check", false, "http://example.com/", "example.com", false}, + {"success with different origin without check", false, "http://example.net/", "example.com", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := httptest.NewServer(wsHandler(t, tt.checkOrigin)) + u, _ := url.Parse(s.URL) + u.Path = "/" + u.Scheme = "ws" + header := http.Header{ + "Origin": []string{tt.origin}, + "Host": []string{tt.host}, + } + c, _, err := websocket.DefaultDialer.Dial(u.String(), header) + if err == nil { + defer c.Close() + } + if (err != nil) != tt.wantErr { + t.Errorf("websocket dial error = %v, wantErr %v", err, tt.wantErr) + } + }) + } + +} diff --git a/server/controllers/websocket/writer.go b/server/controllers/websocket/writer.go index 89eb3b6dbe..1b862d628d 100644 --- a/server/controllers/websocket/writer.go +++ b/server/controllers/websocket/writer.go @@ -8,8 +8,10 @@ import ( "github.com/runatlantis/atlantis/server/logging" ) -func NewWriter(log logging.SimpleLogging) *Writer { - upgrader := websocket.Upgrader{} +func NewWriter(log logging.SimpleLogging, checkOrigin bool) *Writer { + upgrader := websocket.Upgrader{ + CheckOrigin: checkOriginFunc(checkOrigin), + } upgrader.CheckOrigin = func(r *http.Request) bool { return true } return &Writer{ upgrader: upgrader, diff --git a/server/server.go b/server/server.go index 34d99befc0..c41612d1d2 100644 --- a/server/server.go +++ b/server/server.go @@ -747,6 +747,7 @@ func NewServer(userConfig UserConfig, config Config) (*Server, error) { logger, controllers.JobIDKeyGenerator{}, projectCmdOutputHandler, + userConfig.WebsocketCheckOrigin, ) jobsController := &controllers.JobsController{ diff --git a/server/user_config.go b/server/user_config.go index c8f5435669..a68ef23037 100644 --- a/server/user_config.go +++ b/server/user_config.go @@ -105,6 +105,7 @@ type UserConfig struct { WebUsername string `mapstructure:"web-username"` WebPassword string `mapstructure:"web-password"` WriteGitCreds bool `mapstructure:"write-git-creds"` + WebsocketCheckOrigin bool `mapstructure:"websocket-check-origin"` } // ToLogLevel returns the LogLevel object corresponding to the user-passed