Skip to content

Commit

Permalink
Extract port range validation to api/utils/net
Browse files Browse the repository at this point in the history
  • Loading branch information
ravicious committed Nov 13, 2024
1 parent 7170548 commit c931742
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 28 deletions.
17 changes: 3 additions & 14 deletions api/types/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/types/compare"
"github.com/gravitational/teleport/api/utils"
netutils "github.com/gravitational/teleport/api/utils/net"
)

var _ compare.IsEqual[Application] = (*AppV3)(nil)
Expand Down Expand Up @@ -446,21 +447,9 @@ func (a *AppV3) checkTCPPorts() error {
return trace.BadParameter("TCP app URI %q must not include a port number when the app spec defines a list of ports", a.Spec.URI)
}

const minPort = 1
const maxPort = 65535
for _, portRange := range a.Spec.TCPPorts {
if portRange.Port < minPort || portRange.Port > maxPort {
return trace.BadParameter("TCP app port must be between %d and %d, but got %d", minPort, maxPort, portRange.Port)
}

if portRange.EndPort != 0 {
if portRange.EndPort < minPort+1 || portRange.EndPort > maxPort {
return trace.BadParameter("TCP app end port must be between %d and %d, but got %d", minPort+1, maxPort, portRange.EndPort)
}

if portRange.EndPort <= portRange.Port {
return trace.BadParameter("TCP app end port must be greater than port (%d vs %d)", portRange.EndPort, portRange.Port)
}
if err := netutils.ValidatePortRange(int(portRange.Port), int(portRange.EndPort)); err != nil {
return trace.Wrap(err, "validating a port range of a TCP app")
}
}

Expand Down
43 changes: 43 additions & 0 deletions api/utils/net/ports.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright 2024 Gravitational, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package net

import (
"github.com/gravitational/trace"
)

// ValidatePortRange checks if the given port range is within bounds. If endPort is not zero, it
// also checks if it's bigger than port. A port range with zero as endPort is assumed to describe a
// single port.
func ValidatePortRange(port, endPort int) error {
const minPort = 1
const maxPort = 65535

if port < minPort || port > maxPort {
return trace.BadParameter("port must be between %d and %d, but got %d", minPort, maxPort, port)
}

if endPort != 0 {
if endPort < minPort+1 || endPort > maxPort {
return trace.BadParameter("end port must be between %d and %d, but got %d", minPort+1, maxPort, endPort)
}

if endPort <= port {
return trace.BadParameter("end port must be greater than port (%d vs %d)", endPort, port)
}
}

return nil
}
91 changes: 91 additions & 0 deletions api/utils/net/ports_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Copyright 2024 Gravitational, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package net

import (
"testing"

"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
)

func TestValidatePortRange(t *testing.T) {
tests := []struct {
name string
port int
endPort int
check require.ErrorAssertionFunc
}{
{
name: "valid single port",
port: 1337,
endPort: 0,
check: require.NoError,
},
{
name: "valid port range",
port: 1337,
endPort: 3456,
check: require.NoError,
},
{
name: "port smaller than 1",
port: 0,
endPort: 0,
check: badParameterError,
},
{
name: "port bigger than max port",
port: 98765,
endPort: 0,
check: badParameterError,
},
{
name: "end port smaller than 2",
port: 5,
endPort: 1,
check: badParameterErrorAndContains("end port must be between"),
},
{
name: "end port bigger than max port",
port: 5,
endPort: 98765,
check: badParameterErrorAndContains("end port must be between"),
},
{
name: "end port smaller than port",
port: 10,
endPort: 5,
check: badParameterErrorAndContains("end port must be greater than port"),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.check(t, ValidatePortRange(tt.port, tt.endPort))
})
}
}

func badParameterError(t require.TestingT, err error, msgAndArgs ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter error, got %+v", err)
}

func badParameterErrorAndContains(msg string) require.ErrorAssertionFunc {
return func(t require.TestingT, err error, msgAndArgs ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter error, got %+v", err)
require.ErrorContains(t, err, msg, msgAndArgs...)
}
}
17 changes: 3 additions & 14 deletions lib/service/servicecfg/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"k8s.io/apimachinery/pkg/util/validation"

"github.com/gravitational/teleport/api/types"
netutils "github.com/gravitational/teleport/api/utils/net"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/app/common"
)
Expand Down Expand Up @@ -222,21 +223,9 @@ func (a *App) checkPorts() error {
return trace.BadParameter("app URI %q must not include a port number when the app spec defines a list of ports", a.URI)
}

const minPort = 1
const maxPort = 65535
for _, portRange := range a.TCPPorts {
if portRange.Port < minPort || portRange.Port > maxPort {
return trace.BadParameter("app port must be between %d and %d, but got %d", minPort, maxPort, portRange.Port)
}

if portRange.EndPort != 0 {
if portRange.EndPort < minPort+1 || portRange.EndPort > maxPort {
return trace.BadParameter("app end port must be between %d and %d, but got %d", minPort+1, maxPort, portRange.EndPort)
}

if portRange.EndPort <= portRange.Port {
return trace.BadParameter("app end port must be greater than port (%d vs %d)", portRange.EndPort, portRange.Port)
}
if err := netutils.ValidatePortRange(int(portRange.Port), int(portRange.EndPort)); err != nil {
return trace.Wrap(err, "validating a port range of a TCP app")
}
}

Expand Down

0 comments on commit c931742

Please sign in to comment.