diff --git a/cmd/entrypoints/serve.go b/cmd/entrypoints/serve.go index 8cb8140e31..2028dfeb9c 100644 --- a/cmd/entrypoints/serve.go +++ b/cmd/entrypoints/serve.go @@ -113,6 +113,9 @@ func newGRPCServer(ctx context.Context, cfg *config.ServerConfig, authCtx interf grpc.StreamInterceptor(grpcPrometheus.StreamServerInterceptor), grpc.UnaryInterceptor(chainedUnaryInterceptors), } + if cfg.GrpcConfig.MaxMessageSizeBytes > 0 { + serverOpts = append(serverOpts, grpc.MaxRecvMsgSize(cfg.GrpcConfig.MaxMessageSizeBytes)) + } serverOpts = append(serverOpts, opts...) grpcServer := grpc.NewServer(serverOpts...) grpcPrometheus.Register(grpcServer) @@ -125,7 +128,7 @@ func newGRPCServer(ctx context.Context, cfg *config.ServerConfig, authCtx interf healthServer := health.NewServer() healthServer.SetServingStatus("flyteadmin", grpc_health_v1.HealthCheckResponse_SERVING) grpc_health_v1.RegisterHealthServer(grpcServer, healthServer) - if cfg.GrpcServerReflection { + if cfg.GrpcConfig.ServerReflection || cfg.GrpcServerReflection { reflection.Register(grpcServer) } return grpcServer, nil @@ -263,8 +266,15 @@ func serveGatewayInsecure(ctx context.Context, cfg *config.ServerConfig, authCfg }() logger.Infof(ctx, "Starting HTTP/1 Gateway server on %s", cfg.GetHostAddress()) - httpServer, err := newHTTPServer(ctx, cfg, authCfg, authCtx, cfg.GetGrpcHostAddress(), grpc.WithInsecure(), - grpc.WithMaxHeaderListSize(common.MaxResponseStatusBytes)) + grpcOptions := []grpc.DialOption{ + grpc.WithInsecure(), + grpc.WithMaxHeaderListSize(common.MaxResponseStatusBytes), + } + if cfg.GrpcConfig.MaxMessageSizeBytes > 0 { + grpcOptions = append(grpcOptions, + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(cfg.GrpcConfig.MaxMessageSizeBytes))) + } + httpServer, err := newHTTPServer(ctx, cfg, authCfg, authCtx, cfg.GetGrpcHostAddress(), grpcOptions...) if err != nil { return err } @@ -351,7 +361,14 @@ func serveGatewaySecure(ctx context.Context, cfg *config.ServerConfig, authCfg * ServerName: cfg.GetHostAddress(), RootCAs: certPool, }) - httpServer, err := newHTTPServer(ctx, cfg, authCfg, authCtx, cfg.GetHostAddress(), grpc.WithTransportCredentials(dialCreds)) + serverOpts := []grpc.DialOption{ + grpc.WithTransportCredentials(dialCreds), + } + if cfg.GrpcConfig.MaxMessageSizeBytes > 0 { + serverOpts = append(serverOpts, + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(cfg.GrpcConfig.MaxMessageSizeBytes))) + } + httpServer, err := newHTTPServer(ctx, cfg, authCfg, authCtx, cfg.GetHostAddress(), serverOpts...) if err != nil { return err } diff --git a/pkg/config/config.go b/pkg/config/config.go index 41c8290b2e..b5ba90e2b4 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -13,16 +13,22 @@ const SectionKey = "server" type ServerConfig struct { HTTPPort int `json:"httpPort" pflag:",On which http port to serve admin"` - GrpcPort int `json:"grpcPort" pflag:",On which grpc port to serve admin"` - GrpcServerReflection bool `json:"grpcServerReflection" pflag:",Enable GRPC Server Reflection"` + GrpcPort int `json:"grpcPort" pflag:",deprecated"` + GrpcServerReflection bool `json:"grpcServerReflection" pflag:",deprecated"` KubeConfig string `json:"kube-config" pflag:",Path to kubernetes client config file, default is empty, useful for incluster config."` Master string `json:"master" pflag:",The address of the Kubernetes API server."` Security ServerSecurityOptions `json:"security"` - + GrpcConfig GrpcConfig `json:"grpc"` // Deprecated: please use auth.AppAuth.ThirdPartyConfig instead. DeprecatedThirdPartyConfig authConfig.ThirdPartyConfigOptions `json:"thirdPartyConfig" pflag:",Deprecated please use auth.appAuth.thirdPartyConfig instead."` } +type GrpcConfig struct { + Port int `json:"port" pflag:",On which grpc port to serve admin"` + ServerReflection bool `json:"serverReflection" pflag:",Enable GRPC Server Reflection"` + MaxMessageSizeBytes int `json:"maxMessageSizeBytes" pflag:",The max size in bytes for incoming gRPC messages"` +} + type ServerSecurityOptions struct { Secure bool `json:"secure"` Ssl SslOptions `json:"ssl"` @@ -48,14 +54,17 @@ type SslOptions struct { } var defaultServerConfig = &ServerConfig{ - HTTPPort: 8088, - GrpcPort: 8089, - GrpcServerReflection: true, + HTTPPort: 8088, + KubeConfig: "$HOME/.kube/config", Security: ServerSecurityOptions{ AllowCors: true, AllowedHeaders: []string{"Content-Type", "flyte-authorization"}, AllowedOrigins: []string{"*"}, }, + GrpcConfig: GrpcConfig{ + Port: 8089, + ServerReflection: true, + }, } var serverConfig = config.MustRegisterSection(SectionKey, defaultServerConfig) @@ -78,6 +87,9 @@ func (s ServerConfig) GetHostAddress() string { } func (s ServerConfig) GetGrpcHostAddress() string { + if s.GrpcConfig.Port >= 0 { + return fmt.Sprintf(":%d", s.GrpcConfig.Port) + } return fmt.Sprintf(":%d", s.GrpcPort) } diff --git a/pkg/config/serverconfig_flags.go b/pkg/config/serverconfig_flags.go index f9ad6a732a..300ed61bbd 100755 --- a/pkg/config/serverconfig_flags.go +++ b/pkg/config/serverconfig_flags.go @@ -51,8 +51,8 @@ func (ServerConfig) mustMarshalJSON(v json.Marshaler) string { func (cfg ServerConfig) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags := pflag.NewFlagSet("ServerConfig", pflag.ExitOnError) cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "httpPort"), defaultServerConfig.HTTPPort, "On which http port to serve admin") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "grpcPort"), defaultServerConfig.GrpcPort, "On which grpc port to serve admin") - cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "grpcServerReflection"), defaultServerConfig.GrpcServerReflection, "Enable GRPC Server Reflection") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "grpcPort"), defaultServerConfig.GrpcPort, "deprecated") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "grpcServerReflection"), defaultServerConfig.GrpcServerReflection, "deprecated") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "kube-config"), defaultServerConfig.KubeConfig, "Path to kubernetes client config file, default is empty, useful for incluster config.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "master"), defaultServerConfig.Master, "The address of the Kubernetes API server.") cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "security.secure"), defaultServerConfig.Security.Secure, "") @@ -63,6 +63,9 @@ func (cfg ServerConfig) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "security.allowCors"), defaultServerConfig.Security.AllowCors, "") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "security.allowedOrigins"), []string{}, "") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "security.allowedHeaders"), []string{}, "") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "grpc.port"), defaultServerConfig.GrpcConfig.Port, "On which grpc port to serve admin") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "grpc.serverReflection"), defaultServerConfig.GrpcConfig.ServerReflection, "Enable GRPC Server Reflection") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "grpc.maxMessageSizeBytes"), defaultServerConfig.GrpcConfig.MaxMessageSizeBytes, "The max size in bytes for incoming gRPC messages") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "thirdPartyConfig.flyteClient.clientId"), defaultServerConfig.DeprecatedThirdPartyConfig.FlyteClientConfig.ClientID, "public identifier for the app which handles authorization for a Flyte deployment") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "thirdPartyConfig.flyteClient.redirectUri"), defaultServerConfig.DeprecatedThirdPartyConfig.FlyteClientConfig.RedirectURI, "This is the callback uri registered with the app which handles authorization for a Flyte deployment") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "thirdPartyConfig.flyteClient.scopes"), []string{}, "Recommended scopes for the client to request.") diff --git a/pkg/config/serverconfig_flags_test.go b/pkg/config/serverconfig_flags_test.go index 99c1081254..e000f2dcba 100755 --- a/pkg/config/serverconfig_flags_test.go +++ b/pkg/config/serverconfig_flags_test.go @@ -281,6 +281,48 @@ func TestServerConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_grpc.port", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("grpc.port", testValue) + if vInt, err := cmdFlags.GetInt("grpc.port"); err == nil { + testDecodeJson_ServerConfig(t, fmt.Sprintf("%v", vInt), &actual.GrpcConfig.Port) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_grpc.serverReflection", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("grpc.serverReflection", testValue) + if vBool, err := cmdFlags.GetBool("grpc.serverReflection"); err == nil { + testDecodeJson_ServerConfig(t, fmt.Sprintf("%v", vBool), &actual.GrpcConfig.ServerReflection) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_grpc.maxMessageSizeBytes", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("grpc.maxMessageSizeBytes", testValue) + if vInt, err := cmdFlags.GetInt("grpc.maxMessageSizeBytes"); err == nil { + testDecodeJson_ServerConfig(t, fmt.Sprintf("%v", vInt), &actual.GrpcConfig.MaxMessageSizeBytes) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_thirdPartyConfig.flyteClient.clientId", func(t *testing.T) { t.Run("Override", func(t *testing.T) {