diff --git a/api/types/app.go b/api/types/app.go index e37846ec9c87a..7452d35879815 100644 --- a/api/types/app.go +++ b/api/types/app.go @@ -27,6 +27,7 @@ import ( "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types/compare" "github.com/gravitational/teleport/api/utils" + netutils "github.com/gravitational/teleport/api/utils/net" ) var _ compare.IsEqual[Application] = (*AppV3)(nil) @@ -446,21 +447,9 @@ func (a *AppV3) checkTCPPorts() error { return trace.BadParameter("TCP app URI %q must not include a port number when the app spec defines a list of ports", a.Spec.URI) } - const minPort = 1 - const maxPort = 65535 for _, portRange := range a.Spec.TCPPorts { - if portRange.Port < minPort || portRange.Port > maxPort { - return trace.BadParameter("TCP app port must be between %d and %d, but got %d", minPort, maxPort, portRange.Port) - } - - if portRange.EndPort != 0 { - if portRange.EndPort < minPort+1 || portRange.EndPort > maxPort { - return trace.BadParameter("TCP app end port must be between %d and %d, but got %d", minPort+1, maxPort, portRange.EndPort) - } - - if portRange.EndPort <= portRange.Port { - return trace.BadParameter("TCP app end port must be greater than port (%d vs %d)", portRange.EndPort, portRange.Port) - } + if err := netutils.ValidatePortRange(int(portRange.Port), int(portRange.EndPort)); err != nil { + return trace.Wrap(err, "validating a port range of a TCP app") } } diff --git a/api/utils/net/ports.go b/api/utils/net/ports.go new file mode 100644 index 0000000000000..896fb6ac83118 --- /dev/null +++ b/api/utils/net/ports.go @@ -0,0 +1,43 @@ +// Copyright 2024 Gravitational, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package net + +import ( + "github.com/gravitational/trace" +) + +// ValidatePortRange checks if the given port range is within bounds. If endPort is not zero, it +// also checks if it's bigger than port. A port range with zero as endPort is assumed to describe a +// single port. +func ValidatePortRange(port, endPort int) error { + const minPort = 1 + const maxPort = 65535 + + if port < minPort || port > maxPort { + return trace.BadParameter("port must be between %d and %d, but got %d", minPort, maxPort, port) + } + + if endPort != 0 { + if endPort < minPort+1 || endPort > maxPort { + return trace.BadParameter("end port must be between %d and %d, but got %d", minPort+1, maxPort, endPort) + } + + if endPort <= port { + return trace.BadParameter("end port must be greater than port (%d vs %d)", endPort, port) + } + } + + return nil +} diff --git a/api/utils/net/ports_test.go b/api/utils/net/ports_test.go new file mode 100644 index 0000000000000..7f0bb5e6166ee --- /dev/null +++ b/api/utils/net/ports_test.go @@ -0,0 +1,91 @@ +// Copyright 2024 Gravitational, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package net + +import ( + "testing" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" +) + +func TestValidatePortRange(t *testing.T) { + tests := []struct { + name string + port int + endPort int + check require.ErrorAssertionFunc + }{ + { + name: "valid single port", + port: 1337, + endPort: 0, + check: require.NoError, + }, + { + name: "valid port range", + port: 1337, + endPort: 3456, + check: require.NoError, + }, + { + name: "port smaller than 1", + port: 0, + endPort: 0, + check: badParameterError, + }, + { + name: "port bigger than max port", + port: 98765, + endPort: 0, + check: badParameterError, + }, + { + name: "end port smaller than 2", + port: 5, + endPort: 1, + check: badParameterErrorAndContains("end port must be between"), + }, + { + name: "end port bigger than max port", + port: 5, + endPort: 98765, + check: badParameterErrorAndContains("end port must be between"), + }, + { + name: "end port smaller than port", + port: 10, + endPort: 5, + check: badParameterErrorAndContains("end port must be greater than port"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.check(t, ValidatePortRange(tt.port, tt.endPort)) + }) + } +} + +func badParameterError(t require.TestingT, err error, msgAndArgs ...interface{}) { + require.True(t, trace.IsBadParameter(err), "expected bad parameter error, got %+v", err) +} + +func badParameterErrorAndContains(msg string) require.ErrorAssertionFunc { + return func(t require.TestingT, err error, msgAndArgs ...interface{}) { + require.True(t, trace.IsBadParameter(err), "expected bad parameter error, got %+v", err) + require.ErrorContains(t, err, msg, msgAndArgs...) + } +} diff --git a/lib/service/servicecfg/app.go b/lib/service/servicecfg/app.go index e2f85f9f41584..b007ff962890f 100644 --- a/lib/service/servicecfg/app.go +++ b/lib/service/servicecfg/app.go @@ -30,6 +30,7 @@ import ( "k8s.io/apimachinery/pkg/util/validation" "github.com/gravitational/teleport/api/types" + netutils "github.com/gravitational/teleport/api/utils/net" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/app/common" ) @@ -222,21 +223,9 @@ func (a *App) checkPorts() error { return trace.BadParameter("app URI %q must not include a port number when the app spec defines a list of ports", a.URI) } - const minPort = 1 - const maxPort = 65535 for _, portRange := range a.TCPPorts { - if portRange.Port < minPort || portRange.Port > maxPort { - return trace.BadParameter("app port must be between %d and %d, but got %d", minPort, maxPort, portRange.Port) - } - - if portRange.EndPort != 0 { - if portRange.EndPort < minPort+1 || portRange.EndPort > maxPort { - return trace.BadParameter("app end port must be between %d and %d, but got %d", minPort+1, maxPort, portRange.EndPort) - } - - if portRange.EndPort <= portRange.Port { - return trace.BadParameter("app end port must be greater than port (%d vs %d)", portRange.EndPort, portRange.Port) - } + if err := netutils.ValidatePortRange(int(portRange.Port), int(portRange.EndPort)); err != nil { + return trace.Wrap(err, "validating a port range of a TCP app") } }