diff --git a/engine/src/main/java/io/seldon/engine/grpc/SeldonGrpcServer.java b/engine/src/main/java/io/seldon/engine/grpc/SeldonGrpcServer.java index 64baa76b50..ddfa909f43 100644 --- a/engine/src/main/java/io/seldon/engine/grpc/SeldonGrpcServer.java +++ b/engine/src/main/java/io/seldon/engine/grpc/SeldonGrpcServer.java @@ -35,7 +35,7 @@ public class SeldonGrpcServer { private final static String ENGINE_SERVER_PORT_KEY = "ENGINE_SERVER_GRPC_PORT"; public static final int SERVER_PORT = 5000; - private final String ANNOTATION_MAX_MESSAGE_SIZE = "seldon.io/grpc-max-message-size"; + public final static String ANNOTATION_MAX_MESSAGE_SIZE = "seldon.io/grpc-max-message-size"; private final int port; private final Server server; @@ -124,7 +124,4 @@ private void blockUntilShutdown() throws InterruptedException { } } - public int getMaxMessageSize() { - return maxMessageSize; - } } diff --git a/engine/src/main/java/io/seldon/engine/service/InternalPredictionService.java b/engine/src/main/java/io/seldon/engine/service/InternalPredictionService.java index 713626168c..8836242840 100644 --- a/engine/src/main/java/io/seldon/engine/service/InternalPredictionService.java +++ b/engine/src/main/java/io/seldon/engine/service/InternalPredictionService.java @@ -40,6 +40,7 @@ import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; +import io.seldon.engine.config.AnnotationsConfig; import io.seldon.engine.exception.APIException; import io.seldon.engine.grpc.SeldonGrpcServer; import io.seldon.engine.pb.ProtoBufUtils; @@ -77,13 +78,24 @@ public class InternalPredictionService { ObjectMapper mapper = new ObjectMapper(); RestTemplate restTemplate; - SeldonGrpcServer grpcServer; - + + private int maxMessageSize = io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; @Autowired - public InternalPredictionService(RestTemplate restTemplate,SeldonGrpcServer grpcServer){ + public InternalPredictionService(RestTemplate restTemplate,AnnotationsConfig annotations){ this.restTemplate = restTemplate; - this.grpcServer = grpcServer; + if (annotations.has(SeldonGrpcServer.ANNOTATION_MAX_MESSAGE_SIZE)) + { + try + { + maxMessageSize =Integer.parseInt(annotations.get(SeldonGrpcServer.ANNOTATION_MAX_MESSAGE_SIZE)); + logger.info("Setting max message to {}",maxMessageSize); + } + catch(NumberFormatException e) + { + logger.warn("Failed to parse {} with value {}",SeldonGrpcServer.ANNOTATION_MAX_MESSAGE_SIZE,annotations.get(SeldonGrpcServer.ANNOTATION_MAX_MESSAGE_SIZE),e); + } + } } public SeldonMessage route(SeldonMessage input, PredictiveUnitState state) throws InvalidProtocolBufferException @@ -98,15 +110,15 @@ public SeldonMessage route(SeldonMessage input, PredictiveUnitState state) throw if (state.type==PredictiveUnitType.UNKNOWN_TYPE){ GenericBlockingStub stub = GenericGrpc.newBlockingStub(getChannel(endpoint)) .withDeadlineAfter(TIMEOUT, TimeUnit.SECONDS) - .withMaxInboundMessageSize(grpcServer.getMaxMessageSize()) - .withMaxOutboundMessageSize(grpcServer.getMaxMessageSize()); + .withMaxInboundMessageSize(maxMessageSize) + .withMaxOutboundMessageSize(maxMessageSize); return stub.route(input); } else { RouterBlockingStub stub = RouterGrpc.newBlockingStub(getChannel(endpoint)) .withDeadlineAfter(TIMEOUT, TimeUnit.SECONDS) - .withMaxInboundMessageSize(grpcServer.getMaxMessageSize()) - .withMaxOutboundMessageSize(grpcServer.getMaxMessageSize()); + .withMaxInboundMessageSize(maxMessageSize) + .withMaxOutboundMessageSize(maxMessageSize); return stub.route(input); } } @@ -125,15 +137,15 @@ public SeldonMessage sendFeedback(Feedback feedback, PredictiveUnitState state) if (state.type==PredictiveUnitType.UNKNOWN_TYPE){ GenericBlockingStub stub = GenericGrpc.newBlockingStub(getChannel(endpoint)) .withDeadlineAfter(TIMEOUT, TimeUnit.SECONDS) - .withMaxInboundMessageSize(grpcServer.getMaxMessageSize()) - .withMaxOutboundMessageSize(grpcServer.getMaxMessageSize()); + .withMaxInboundMessageSize(maxMessageSize) + .withMaxOutboundMessageSize(maxMessageSize); return stub.sendFeedback(feedback); } else { RouterBlockingStub routerStub = RouterGrpc.newBlockingStub(getChannel(endpoint)) .withDeadlineAfter(TIMEOUT, TimeUnit.SECONDS) - .withMaxInboundMessageSize(grpcServer.getMaxMessageSize()) - .withMaxOutboundMessageSize(grpcServer.getMaxMessageSize()); + .withMaxInboundMessageSize(maxMessageSize) + .withMaxOutboundMessageSize(maxMessageSize); return routerStub.sendFeedback(feedback); } } @@ -158,20 +170,20 @@ public SeldonMessage transformInput(SeldonMessage input, PredictiveUnitState sta case UNKNOWN_TYPE: GenericBlockingStub genStub = GenericGrpc.newBlockingStub(getChannel(endpoint)) .withDeadlineAfter(TIMEOUT, TimeUnit.SECONDS) - .withMaxInboundMessageSize(grpcServer.getMaxMessageSize()) - .withMaxOutboundMessageSize(grpcServer.getMaxMessageSize()); + .withMaxInboundMessageSize(maxMessageSize) + .withMaxOutboundMessageSize(maxMessageSize); return genStub.transformInput(input); case MODEL: ModelBlockingStub modelStub = ModelGrpc.newBlockingStub(getChannel(endpoint)) .withDeadlineAfter(TIMEOUT, TimeUnit.SECONDS) - .withMaxInboundMessageSize(grpcServer.getMaxMessageSize()) - .withMaxOutboundMessageSize(grpcServer.getMaxMessageSize()); + .withMaxInboundMessageSize(maxMessageSize) + .withMaxOutboundMessageSize(maxMessageSize); return modelStub.predict(input); case TRANSFORMER: TransformerBlockingStub transformerStub = TransformerGrpc.newBlockingStub(getChannel(endpoint)) .withDeadlineAfter(TIMEOUT, TimeUnit.SECONDS) - .withMaxInboundMessageSize(grpcServer.getMaxMessageSize()) - .withMaxOutboundMessageSize(grpcServer.getMaxMessageSize()); + .withMaxInboundMessageSize(maxMessageSize) + .withMaxOutboundMessageSize(maxMessageSize); return transformerStub.transformInput(input); default: throw new APIException(APIException.ApiExceptionType.ENGINE_MICROSERVICE_ERROR,"Unhandled type"); @@ -192,15 +204,15 @@ public SeldonMessage transformOutput(SeldonMessage output, PredictiveUnitState s if (state.type==PredictiveUnitType.UNKNOWN_TYPE){ GenericBlockingStub stub = GenericGrpc.newBlockingStub(getChannel(endpoint)) .withDeadlineAfter(TIMEOUT, TimeUnit.SECONDS) - .withMaxInboundMessageSize(grpcServer.getMaxMessageSize()) - .withMaxOutboundMessageSize(grpcServer.getMaxMessageSize()); + .withMaxInboundMessageSize(maxMessageSize) + .withMaxOutboundMessageSize(maxMessageSize); return stub.transformOutput(output); } else { OutputTransformerBlockingStub stub = OutputTransformerGrpc.newBlockingStub(getChannel(endpoint)) .withDeadlineAfter(TIMEOUT, TimeUnit.SECONDS) - .withMaxInboundMessageSize(grpcServer.getMaxMessageSize()) - .withMaxOutboundMessageSize(grpcServer.getMaxMessageSize()); + .withMaxInboundMessageSize(maxMessageSize) + .withMaxOutboundMessageSize(maxMessageSize); return stub.transformOutput(output); } } @@ -219,15 +231,15 @@ public SeldonMessage aggregate(List outputs, PredictiveUnitState if (state.type==PredictiveUnitType.UNKNOWN_TYPE){ GenericBlockingStub stub = GenericGrpc.newBlockingStub(getChannel(endpoint)) .withDeadlineAfter(TIMEOUT, TimeUnit.SECONDS) - .withMaxInboundMessageSize(grpcServer.getMaxMessageSize()) - .withMaxOutboundMessageSize(grpcServer.getMaxMessageSize()); + .withMaxInboundMessageSize(maxMessageSize) + .withMaxOutboundMessageSize(maxMessageSize); return stub.aggregate(outputsList); } else { CombinerBlockingStub stub = CombinerGrpc.newBlockingStub(getChannel(endpoint)) .withDeadlineAfter(TIMEOUT, TimeUnit.SECONDS) - .withMaxInboundMessageSize(grpcServer.getMaxMessageSize()) - .withMaxOutboundMessageSize(grpcServer.getMaxMessageSize()); + .withMaxInboundMessageSize(maxMessageSize) + .withMaxOutboundMessageSize(maxMessageSize); return stub.aggregate(outputsList); } }