Skip to content

Commit

Permalink
Merge branch 'update-to-gorilla-websocket'
Browse files Browse the repository at this point in the history
  • Loading branch information
joeybloggs authored and joeybloggs committed May 26, 2016
2 parents 78ede96 + e92bff0 commit f8a1be7
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 24 deletions.
2 changes: 1 addition & 1 deletion context.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
"strings"
"time"

"github.com/gorilla/websocket"
"golang.org/x/net/context"
"golang.org/x/net/websocket"
)

// Param is a single URL parameter, consisting of a key and a value.
Expand Down
20 changes: 9 additions & 11 deletions group.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package lars

import (
"net/http"
"strconv"
"strings"

"golang.org/x/net/websocket"
"github.com/gorilla/websocket"
)

// IRouteGroup interface for router group
Expand All @@ -27,7 +26,7 @@ type IRoutes interface {
Head(string, ...Handler)
Connect(string, ...Handler)
Trace(string, ...Handler)
WebSocket(string, Handler)
WebSocket(websocket.Upgrader, string, Handler)
}

// routeGroup struct containing all fields and methods for use.
Expand Down Expand Up @@ -159,22 +158,21 @@ func (g *routeGroup) Match(methods []string, path string, h ...Handler) {
}

// WebSocket adds a websocket route
func (g *routeGroup) WebSocket(path string, h Handler) {
func (g *routeGroup) WebSocket(upgrader websocket.Upgrader, path string, h Handler) {

handler := g.lars.wrapHandler(h)
g.Get(path, func(c Context) {

ctx := c.BaseContext()
var err error

wss := websocket.Server{
Handler: func(ws *websocket.Conn) {
ctx.websocket = ws
ctx.response.status = http.StatusSwitchingProtocols
ctx.Next()
},
ctx.websocket, err = upgrader.Upgrade(ctx.response, ctx.request, nil)
if err != nil {
return
}

wss.ServeHTTP(ctx.response, ctx.request)
defer ctx.websocket.Close()
c.Next()
}, handler)
}

Expand Down
49 changes: 37 additions & 12 deletions group_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package lars

import (
"bytes"
"fmt"
"log"
"net/http"
"net/http/httptest"
"testing"

"golang.org/x/net/websocket"
"github.com/gorilla/websocket"
. "gopkg.in/go-playground/assert.v1"
)

Expand All @@ -22,14 +23,28 @@ import (
//

func TestWebsockets(t *testing.T) {

origin := "http://localhost"

var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
o := r.Header.Get(Origin)
return o == origin
},
}

l := New()
l.WebSocket("/ws", func(c Context) {
l.WebSocket(upgrader, "/ws", func(c Context) {

recv := make([]byte, 1000)
messageType, b, err := c.WebSocket().ReadMessage()
if err != nil {
return
}

i, err := c.WebSocket().Read(recv)
if err == nil {
_, err := c.WebSocket().Write(recv[:i])
err := c.WebSocket().WriteMessage(messageType, b)
if err != nil {
panic(err)
}
Expand All @@ -40,19 +55,29 @@ func TestWebsockets(t *testing.T) {
defer server.Close()

addr := server.Listener.Addr().String()
origin := "http://localhost"

header := make(http.Header, 0)
header.Set(Origin, origin)

url := fmt.Sprintf("ws://%s/ws", addr)
ws, err := websocket.Dial(url, "", origin)
ws, _, err := websocket.DefaultDialer.Dial(url, header)
if err != nil {
log.Fatal("dial:", err)
}
Equal(t, err, nil)

defer ws.Close()

_, err = ws.Write([]byte("websockets in action!"))
err = ws.WriteMessage(websocket.TextMessage, []byte("websockets in action!"))
Equal(t, err, nil)

buf := new(bytes.Buffer)
_, err = buf.ReadFrom(ws)
typ, b, err := ws.ReadMessage()
Equal(t, err, nil)
Equal(t, "websockets in action!", buf.String())
Equal(t, typ, websocket.TextMessage)
Equal(t, "websockets in action!", string(b))

wsBad, res, err := websocket.DefaultDialer.Dial(url, nil)
NotEqual(t, err, nil)
Equal(t, wsBad, nil)
Equal(t, res.StatusCode, http.StatusForbidden)
}
1 change: 1 addition & 0 deletions lars.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ const (
XForwardedFor = "X-Forwarded-For"
XRealIP = "X-Real-Ip"
Allow = "Allow"
Origin = "Origin"

Gzip = "gzip"

Expand Down

0 comments on commit f8a1be7

Please sign in to comment.