From dd61f8b0d5629a711ef3582d7a5c3775714ef74d Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Wed, 12 Jun 2019 10:52:26 -0700 Subject: [PATCH] Disable port forwarding by default Fixes #68 --- server.go | 18 ++++++++++++------ session.go | 2 +- session_test.go | 4 ++-- tcpip.go | 11 ++++++++--- 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/server.go b/server.go index 41ba87e..dc61fb4 100644 --- a/server.go +++ b/server.go @@ -65,6 +65,8 @@ func (f RequestHandlerFunc) HandleSSHRequest(ctx Context, srv *Server, req *goss return f(ctx, srv, req) } +var DefaultRequestHandlers = map[string]RequestHandler{} + type ChannelHandler interface { HandleSSHChannel(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) } @@ -75,6 +77,10 @@ func (f ChannelHandlerFunc) HandleSSHChannel(srv *Server, conn *gossh.ServerConn f(srv, conn, newChan, ctx) } +var DefaultChannelHandlers = map[string]ChannelHandler{ + "session": ChannelHandlerFunc(DefaultSessionHandler), +} + func (srv *Server) ensureHostSigner() error { if len(srv.HostSigners) == 0 { signer, err := generateSigner() @@ -90,15 +96,15 @@ func (srv *Server) ensureHandlers() { srv.mu.Lock() defer srv.mu.Unlock() if srv.RequestHandlers == nil { - srv.RequestHandlers = map[string]RequestHandler{ - "tcpip-forward": forwardedTCPHandler{}, - "cancel-tcpip-forward": forwardedTCPHandler{}, + srv.RequestHandlers = map[string]RequestHandler{} + for k, v := range DefaultRequestHandlers { + srv.RequestHandlers[k] = v } } if srv.ChannelHandlers == nil { - srv.ChannelHandlers = map[string]ChannelHandler{ - "session": ChannelHandlerFunc(sessionHandler), - "direct-tcpip": ChannelHandlerFunc(directTcpipHandler), + srv.ChannelHandlers = map[string]ChannelHandler{} + for k, v := range DefaultChannelHandlers { + srv.ChannelHandlers[k] = v } } } diff --git a/session.go b/session.go index 19ddda6..a6085f3 100644 --- a/session.go +++ b/session.go @@ -77,7 +77,7 @@ type Session interface { // when there is no signal channel specified const maxSigBufSize = 128 -func sessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { +func DefaultSessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { ch, reqs, err := newChan.Accept() if err != nil { // TODO: trigger event callback diff --git a/session_test.go b/session_test.go index 4396c5a..f47ff8a 100644 --- a/session_test.go +++ b/session_test.go @@ -20,8 +20,8 @@ func (srv *Server) serveOnce(l net.Listener) error { return e } srv.ChannelHandlers = map[string]ChannelHandler{ - "session": ChannelHandlerFunc(sessionHandler), - "direct-tcpip": ChannelHandlerFunc(directTcpipHandler), + "session": ChannelHandlerFunc(DefaultSessionHandler), + "direct-tcpip": ChannelHandlerFunc(DirectTCPIPHandler), } srv.handleConn(conn) return nil diff --git a/tcpip.go b/tcpip.go index 4afbf2d..2a7f33d 100644 --- a/tcpip.go +++ b/tcpip.go @@ -23,7 +23,9 @@ type localForwardChannelData struct { OriginPort uint32 } -func directTcpipHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { +// DirectTCPIPHandler can be enabled by adding it to the server's +// ChannelHandlers under direct-tcpip. +func DirectTCPIPHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { d := localForwardChannelData{} if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil { newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error()) @@ -84,12 +86,15 @@ type remoteForwardChannelData struct { OriginPort uint32 } -type forwardedTCPHandler struct { +// ForwardedTCPHandler can be enabled by creating a ForwardedTCPHandler and +// adding it to the server's RequestHandlers under tcpip-forward and +// cancel-tcpip-forward. +type ForwardedTCPHandler struct { forwards map[string]net.Listener sync.Mutex } -func (h forwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { +func (h ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { h.Lock() if h.forwards == nil { h.forwards = make(map[string]net.Listener)