diff --git a/AiPlatform/metadata/V1/Model.php b/AiPlatform/metadata/V1/Model.php index accbe29d357c..4b55fc391975 100644 Binary files a/AiPlatform/metadata/V1/Model.php and b/AiPlatform/metadata/V1/Model.php differ diff --git a/AiPlatform/metadata/V1/PredictionService.php b/AiPlatform/metadata/V1/PredictionService.php index e1047b52a254..54ed3373dacb 100644 Binary files a/AiPlatform/metadata/V1/PredictionService.php and b/AiPlatform/metadata/V1/PredictionService.php differ diff --git a/AiPlatform/samples/V1/PredictionServiceClient/direct_predict.php b/AiPlatform/samples/V1/PredictionServiceClient/direct_predict.php new file mode 100644 index 000000000000..acc38debd859 --- /dev/null +++ b/AiPlatform/samples/V1/PredictionServiceClient/direct_predict.php @@ -0,0 +1,74 @@ +setEndpoint($formattedEndpoint); + + // Call the API and handle any network failures. + try { + /** @var DirectPredictResponse $response */ + $response = $predictionServiceClient->directPredict($request); + printf('Response data: %s' . PHP_EOL, $response->serializeToJsonString()); + } catch (ApiException $ex) { + printf('Call failed with message: %s' . PHP_EOL, $ex->getMessage()); + } +} + +/** + * Helper to execute the sample. + * + * This sample has been automatically generated and should be regarded as a code + * template only. It will require modifications to work: + * - It may require correct/in-range values for request initialization. + * - It may require specifying regional endpoints when creating the service client, + * please see the apiEndpoint client configuration option for more details. + */ +function callSample(): void +{ + $formattedEndpoint = PredictionServiceClient::endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + + direct_predict_sample($formattedEndpoint); +} +// [END aiplatform_v1_generated_PredictionService_DirectPredict_sync] diff --git a/AiPlatform/samples/V1/PredictionServiceClient/direct_raw_predict.php b/AiPlatform/samples/V1/PredictionServiceClient/direct_raw_predict.php new file mode 100644 index 000000000000..34b69cfa715b --- /dev/null +++ b/AiPlatform/samples/V1/PredictionServiceClient/direct_raw_predict.php @@ -0,0 +1,73 @@ +setEndpoint($formattedEndpoint); + + // Call the API and handle any network failures. + try { + /** @var DirectRawPredictResponse $response */ + $response = $predictionServiceClient->directRawPredict($request); + printf('Response data: %s' . PHP_EOL, $response->serializeToJsonString()); + } catch (ApiException $ex) { + printf('Call failed with message: %s' . PHP_EOL, $ex->getMessage()); + } +} + +/** + * Helper to execute the sample. + * + * This sample has been automatically generated and should be regarded as a code + * template only. It will require modifications to work: + * - It may require correct/in-range values for request initialization. + * - It may require specifying regional endpoints when creating the service client, + * please see the apiEndpoint client configuration option for more details. + */ +function callSample(): void +{ + $formattedEndpoint = PredictionServiceClient::endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + + direct_raw_predict_sample($formattedEndpoint); +} +// [END aiplatform_v1_generated_PredictionService_DirectRawPredict_sync] diff --git a/AiPlatform/samples/V1/PredictionServiceClient/streaming_predict.php b/AiPlatform/samples/V1/PredictionServiceClient/streaming_predict.php new file mode 100644 index 000000000000..6767713cf471 --- /dev/null +++ b/AiPlatform/samples/V1/PredictionServiceClient/streaming_predict.php @@ -0,0 +1,80 @@ +setEndpoint($formattedEndpoint); + + // Call the API and handle any network failures. + try { + /** @var BidiStream $stream */ + $stream = $predictionServiceClient->streamingPredict(); + $stream->writeAll([$request,]); + + /** @var StreamingPredictResponse $element */ + foreach ($stream->closeWriteAndReadAll() as $element) { + printf('Element data: %s' . PHP_EOL, $element->serializeToJsonString()); + } + } catch (ApiException $ex) { + printf('Call failed with message: %s' . PHP_EOL, $ex->getMessage()); + } +} + +/** + * Helper to execute the sample. + * + * This sample has been automatically generated and should be regarded as a code + * template only. It will require modifications to work: + * - It may require correct/in-range values for request initialization. + * - It may require specifying regional endpoints when creating the service client, + * please see the apiEndpoint client configuration option for more details. + */ +function callSample(): void +{ + $formattedEndpoint = PredictionServiceClient::endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + + streaming_predict_sample($formattedEndpoint); +} +// [END aiplatform_v1_generated_PredictionService_StreamingPredict_sync] diff --git a/AiPlatform/samples/V1/PredictionServiceClient/streaming_raw_predict.php b/AiPlatform/samples/V1/PredictionServiceClient/streaming_raw_predict.php new file mode 100644 index 000000000000..a94ae0bd9f84 --- /dev/null +++ b/AiPlatform/samples/V1/PredictionServiceClient/streaming_raw_predict.php @@ -0,0 +1,79 @@ +setEndpoint($formattedEndpoint); + + // Call the API and handle any network failures. + try { + /** @var BidiStream $stream */ + $stream = $predictionServiceClient->streamingRawPredict(); + $stream->writeAll([$request,]); + + /** @var StreamingRawPredictResponse $element */ + foreach ($stream->closeWriteAndReadAll() as $element) { + printf('Element data: %s' . PHP_EOL, $element->serializeToJsonString()); + } + } catch (ApiException $ex) { + printf('Call failed with message: %s' . PHP_EOL, $ex->getMessage()); + } +} + +/** + * Helper to execute the sample. + * + * This sample has been automatically generated and should be regarded as a code + * template only. It will require modifications to work: + * - It may require correct/in-range values for request initialization. + * - It may require specifying regional endpoints when creating the service client, + * please see the apiEndpoint client configuration option for more details. + */ +function callSample(): void +{ + $formattedEndpoint = PredictionServiceClient::endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + + streaming_raw_predict_sample($formattedEndpoint); +} +// [END aiplatform_v1_generated_PredictionService_StreamingRawPredict_sync] diff --git a/AiPlatform/src/V1/Client/PredictionServiceClient.php b/AiPlatform/src/V1/Client/PredictionServiceClient.php index 1e4a21d4d423..40c73a2fc80f 100644 --- a/AiPlatform/src/V1/Client/PredictionServiceClient.php +++ b/AiPlatform/src/V1/Client/PredictionServiceClient.php @@ -25,6 +25,7 @@ namespace Google\Cloud\AIPlatform\V1\Client; use Google\ApiCore\ApiException; +use Google\ApiCore\BidiStream; use Google\ApiCore\CredentialsWrapper; use Google\ApiCore\GapicClientTrait; use Google\ApiCore\PagedListResponse; @@ -35,6 +36,10 @@ use Google\ApiCore\ValidationException; use Google\Api\HttpBody; use Google\Auth\FetchAuthTokenInterface; +use Google\Cloud\AIPlatform\V1\DirectPredictRequest; +use Google\Cloud\AIPlatform\V1\DirectPredictResponse; +use Google\Cloud\AIPlatform\V1\DirectRawPredictRequest; +use Google\Cloud\AIPlatform\V1\DirectRawPredictResponse; use Google\Cloud\AIPlatform\V1\ExplainRequest; use Google\Cloud\AIPlatform\V1\ExplainResponse; use Google\Cloud\AIPlatform\V1\PredictRequest; @@ -68,6 +73,8 @@ * * @experimental * + * @method PromiseInterface directPredictAsync(DirectPredictRequest $request, array $optionalArgs = []) + * @method PromiseInterface directRawPredictAsync(DirectRawPredictRequest $request, array $optionalArgs = []) * @method PromiseInterface explainAsync(ExplainRequest $request, array $optionalArgs = []) * @method PromiseInterface predictAsync(PredictRequest $request, array $optionalArgs = []) * @method PromiseInterface rawPredictAsync(RawPredictRequest $request, array $optionalArgs = []) @@ -274,6 +281,59 @@ public function __call($method, $args) return call_user_func_array([$this, 'startAsyncCall'], $args); } + /** + * Perform an unary online prediction request for Vertex first-party products + * and frameworks. + * + * The async variant is {@see PredictionServiceClient::directPredictAsync()} . + * + * @example samples/V1/PredictionServiceClient/direct_predict.php + * + * @param DirectPredictRequest $request A request to house fields associated with the call. + * @param array $callOptions { + * Optional. + * + * @type RetrySettings|array $retrySettings + * Retry settings to use for this call. Can be a {@see RetrySettings} object, or an + * associative array of retry settings parameters. See the documentation on + * {@see RetrySettings} for example usage. + * } + * + * @return DirectPredictResponse + * + * @throws ApiException Thrown if the API call fails. + */ + public function directPredict(DirectPredictRequest $request, array $callOptions = []): DirectPredictResponse + { + return $this->startApiCall('DirectPredict', $request, $callOptions)->wait(); + } + + /** + * Perform an online prediction request through gRPC. + * + * The async variant is {@see PredictionServiceClient::directRawPredictAsync()} . + * + * @example samples/V1/PredictionServiceClient/direct_raw_predict.php + * + * @param DirectRawPredictRequest $request A request to house fields associated with the call. + * @param array $callOptions { + * Optional. + * + * @type RetrySettings|array $retrySettings + * Retry settings to use for this call. Can be a {@see RetrySettings} object, or an + * associative array of retry settings parameters. See the documentation on + * {@see RetrySettings} for example usage. + * } + * + * @return DirectRawPredictResponse + * + * @throws ApiException Thrown if the API call fails. + */ + public function directRawPredict(DirectRawPredictRequest $request, array $callOptions = []): DirectRawPredictResponse + { + return $this->startApiCall('DirectRawPredict', $request, $callOptions)->wait(); + } + /** * Perform an online explanation. * @@ -395,6 +455,49 @@ public function serverStreamingPredict(StreamingPredictRequest $request, array $ return $this->startApiCall('ServerStreamingPredict', $request, $callOptions); } + /** + * Perform a streaming online prediction request for Vertex first-party + * products and frameworks. + * + * @example samples/V1/PredictionServiceClient/streaming_predict.php + * + * @param array $callOptions { + * Optional. + * + * @type int $timeoutMillis + * Timeout to use for this call. + * } + * + * @return BidiStream + * + * @throws ApiException Thrown if the API call fails. + */ + public function streamingPredict(array $callOptions = []): BidiStream + { + return $this->startApiCall('StreamingPredict', null, $callOptions); + } + + /** + * Perform a streaming online prediction request through gRPC. + * + * @example samples/V1/PredictionServiceClient/streaming_raw_predict.php + * + * @param array $callOptions { + * Optional. + * + * @type int $timeoutMillis + * Timeout to use for this call. + * } + * + * @return BidiStream + * + * @throws ApiException Thrown if the API call fails. + */ + public function streamingRawPredict(array $callOptions = []): BidiStream + { + return $this->startApiCall('StreamingRawPredict', null, $callOptions); + } + /** * Gets information about a location. * diff --git a/AiPlatform/src/V1/DirectPredictRequest.php b/AiPlatform/src/V1/DirectPredictRequest.php new file mode 100644 index 000000000000..6bde21b5e0d7 --- /dev/null +++ b/AiPlatform/src/V1/DirectPredictRequest.php @@ -0,0 +1,154 @@ +google.cloud.aiplatform.v1.DirectPredictRequest + */ +class DirectPredictRequest extends \Google\Protobuf\Internal\Message +{ + /** + * Required. The name of the Endpoint requested to serve the prediction. + * Format: + * `projects/{project}/locations/{location}/endpoints/{endpoint}` + * + * Generated from protobuf field string endpoint = 1 [(.google.api.field_behavior) = REQUIRED, (.google.api.resource_reference) = { + */ + private $endpoint = ''; + /** + * The prediction input. + * + * Generated from protobuf field repeated .google.cloud.aiplatform.v1.Tensor inputs = 2; + */ + private $inputs; + /** + * The parameters that govern the prediction. + * + * Generated from protobuf field .google.cloud.aiplatform.v1.Tensor parameters = 3; + */ + private $parameters = null; + + /** + * Constructor. + * + * @param array $data { + * Optional. Data for populating the Message object. + * + * @type string $endpoint + * Required. The name of the Endpoint requested to serve the prediction. + * Format: + * `projects/{project}/locations/{location}/endpoints/{endpoint}` + * @type array<\Google\Cloud\AIPlatform\V1\Tensor>|\Google\Protobuf\Internal\RepeatedField $inputs + * The prediction input. + * @type \Google\Cloud\AIPlatform\V1\Tensor $parameters + * The parameters that govern the prediction. + * } + */ + public function __construct($data = NULL) { + \GPBMetadata\Google\Cloud\Aiplatform\V1\PredictionService::initOnce(); + parent::__construct($data); + } + + /** + * Required. The name of the Endpoint requested to serve the prediction. + * Format: + * `projects/{project}/locations/{location}/endpoints/{endpoint}` + * + * Generated from protobuf field string endpoint = 1 [(.google.api.field_behavior) = REQUIRED, (.google.api.resource_reference) = { + * @return string + */ + public function getEndpoint() + { + return $this->endpoint; + } + + /** + * Required. The name of the Endpoint requested to serve the prediction. + * Format: + * `projects/{project}/locations/{location}/endpoints/{endpoint}` + * + * Generated from protobuf field string endpoint = 1 [(.google.api.field_behavior) = REQUIRED, (.google.api.resource_reference) = { + * @param string $var + * @return $this + */ + public function setEndpoint($var) + { + GPBUtil::checkString($var, True); + $this->endpoint = $var; + + return $this; + } + + /** + * The prediction input. + * + * Generated from protobuf field repeated .google.cloud.aiplatform.v1.Tensor inputs = 2; + * @return \Google\Protobuf\Internal\RepeatedField + */ + public function getInputs() + { + return $this->inputs; + } + + /** + * The prediction input. + * + * Generated from protobuf field repeated .google.cloud.aiplatform.v1.Tensor inputs = 2; + * @param array<\Google\Cloud\AIPlatform\V1\Tensor>|\Google\Protobuf\Internal\RepeatedField $var + * @return $this + */ + public function setInputs($var) + { + $arr = GPBUtil::checkRepeatedField($var, \Google\Protobuf\Internal\GPBType::MESSAGE, \Google\Cloud\AIPlatform\V1\Tensor::class); + $this->inputs = $arr; + + return $this; + } + + /** + * The parameters that govern the prediction. + * + * Generated from protobuf field .google.cloud.aiplatform.v1.Tensor parameters = 3; + * @return \Google\Cloud\AIPlatform\V1\Tensor|null + */ + public function getParameters() + { + return $this->parameters; + } + + public function hasParameters() + { + return isset($this->parameters); + } + + public function clearParameters() + { + unset($this->parameters); + } + + /** + * The parameters that govern the prediction. + * + * Generated from protobuf field .google.cloud.aiplatform.v1.Tensor parameters = 3; + * @param \Google\Cloud\AIPlatform\V1\Tensor $var + * @return $this + */ + public function setParameters($var) + { + GPBUtil::checkMessage($var, \Google\Cloud\AIPlatform\V1\Tensor::class); + $this->parameters = $var; + + return $this; + } + +} + diff --git a/AiPlatform/src/V1/DirectPredictResponse.php b/AiPlatform/src/V1/DirectPredictResponse.php new file mode 100644 index 000000000000..35a747ce3736 --- /dev/null +++ b/AiPlatform/src/V1/DirectPredictResponse.php @@ -0,0 +1,112 @@ +google.cloud.aiplatform.v1.DirectPredictResponse + */ +class DirectPredictResponse extends \Google\Protobuf\Internal\Message +{ + /** + * The prediction output. + * + * Generated from protobuf field repeated .google.cloud.aiplatform.v1.Tensor outputs = 1; + */ + private $outputs; + /** + * The parameters that govern the prediction. + * + * Generated from protobuf field .google.cloud.aiplatform.v1.Tensor parameters = 2; + */ + private $parameters = null; + + /** + * Constructor. + * + * @param array $data { + * Optional. Data for populating the Message object. + * + * @type array<\Google\Cloud\AIPlatform\V1\Tensor>|\Google\Protobuf\Internal\RepeatedField $outputs + * The prediction output. + * @type \Google\Cloud\AIPlatform\V1\Tensor $parameters + * The parameters that govern the prediction. + * } + */ + public function __construct($data = NULL) { + \GPBMetadata\Google\Cloud\Aiplatform\V1\PredictionService::initOnce(); + parent::__construct($data); + } + + /** + * The prediction output. + * + * Generated from protobuf field repeated .google.cloud.aiplatform.v1.Tensor outputs = 1; + * @return \Google\Protobuf\Internal\RepeatedField + */ + public function getOutputs() + { + return $this->outputs; + } + + /** + * The prediction output. + * + * Generated from protobuf field repeated .google.cloud.aiplatform.v1.Tensor outputs = 1; + * @param array<\Google\Cloud\AIPlatform\V1\Tensor>|\Google\Protobuf\Internal\RepeatedField $var + * @return $this + */ + public function setOutputs($var) + { + $arr = GPBUtil::checkRepeatedField($var, \Google\Protobuf\Internal\GPBType::MESSAGE, \Google\Cloud\AIPlatform\V1\Tensor::class); + $this->outputs = $arr; + + return $this; + } + + /** + * The parameters that govern the prediction. + * + * Generated from protobuf field .google.cloud.aiplatform.v1.Tensor parameters = 2; + * @return \Google\Cloud\AIPlatform\V1\Tensor|null + */ + public function getParameters() + { + return $this->parameters; + } + + public function hasParameters() + { + return isset($this->parameters); + } + + public function clearParameters() + { + unset($this->parameters); + } + + /** + * The parameters that govern the prediction. + * + * Generated from protobuf field .google.cloud.aiplatform.v1.Tensor parameters = 2; + * @param \Google\Cloud\AIPlatform\V1\Tensor $var + * @return $this + */ + public function setParameters($var) + { + GPBUtil::checkMessage($var, \Google\Cloud\AIPlatform\V1\Tensor::class); + $this->parameters = $var; + + return $this; + } + +} + diff --git a/AiPlatform/src/V1/DirectRawPredictRequest.php b/AiPlatform/src/V1/DirectRawPredictRequest.php new file mode 100644 index 000000000000..d4737fd0641a --- /dev/null +++ b/AiPlatform/src/V1/DirectRawPredictRequest.php @@ -0,0 +1,164 @@ +google.cloud.aiplatform.v1.DirectRawPredictRequest + */ +class DirectRawPredictRequest extends \Google\Protobuf\Internal\Message +{ + /** + * Required. The name of the Endpoint requested to serve the prediction. + * Format: + * `projects/{project}/locations/{location}/endpoints/{endpoint}` + * + * Generated from protobuf field string endpoint = 1 [(.google.api.field_behavior) = REQUIRED, (.google.api.resource_reference) = { + */ + private $endpoint = ''; + /** + * Fully qualified name of the API method being invoked to perform + * predictions. + * Format: + * `/namespace.Service/Method/` + * Example: + * `/tensorflow.serving.PredictionService/Predict` + * + * Generated from protobuf field string method_name = 2; + */ + private $method_name = ''; + /** + * The prediction input. + * + * Generated from protobuf field bytes input = 3; + */ + private $input = ''; + + /** + * Constructor. + * + * @param array $data { + * Optional. Data for populating the Message object. + * + * @type string $endpoint + * Required. The name of the Endpoint requested to serve the prediction. + * Format: + * `projects/{project}/locations/{location}/endpoints/{endpoint}` + * @type string $method_name + * Fully qualified name of the API method being invoked to perform + * predictions. + * Format: + * `/namespace.Service/Method/` + * Example: + * `/tensorflow.serving.PredictionService/Predict` + * @type string $input + * The prediction input. + * } + */ + public function __construct($data = NULL) { + \GPBMetadata\Google\Cloud\Aiplatform\V1\PredictionService::initOnce(); + parent::__construct($data); + } + + /** + * Required. The name of the Endpoint requested to serve the prediction. + * Format: + * `projects/{project}/locations/{location}/endpoints/{endpoint}` + * + * Generated from protobuf field string endpoint = 1 [(.google.api.field_behavior) = REQUIRED, (.google.api.resource_reference) = { + * @return string + */ + public function getEndpoint() + { + return $this->endpoint; + } + + /** + * Required. The name of the Endpoint requested to serve the prediction. + * Format: + * `projects/{project}/locations/{location}/endpoints/{endpoint}` + * + * Generated from protobuf field string endpoint = 1 [(.google.api.field_behavior) = REQUIRED, (.google.api.resource_reference) = { + * @param string $var + * @return $this + */ + public function setEndpoint($var) + { + GPBUtil::checkString($var, True); + $this->endpoint = $var; + + return $this; + } + + /** + * Fully qualified name of the API method being invoked to perform + * predictions. + * Format: + * `/namespace.Service/Method/` + * Example: + * `/tensorflow.serving.PredictionService/Predict` + * + * Generated from protobuf field string method_name = 2; + * @return string + */ + public function getMethodName() + { + return $this->method_name; + } + + /** + * Fully qualified name of the API method being invoked to perform + * predictions. + * Format: + * `/namespace.Service/Method/` + * Example: + * `/tensorflow.serving.PredictionService/Predict` + * + * Generated from protobuf field string method_name = 2; + * @param string $var + * @return $this + */ + public function setMethodName($var) + { + GPBUtil::checkString($var, True); + $this->method_name = $var; + + return $this; + } + + /** + * The prediction input. + * + * Generated from protobuf field bytes input = 3; + * @return string + */ + public function getInput() + { + return $this->input; + } + + /** + * The prediction input. + * + * Generated from protobuf field bytes input = 3; + * @param string $var + * @return $this + */ + public function setInput($var) + { + GPBUtil::checkString($var, False); + $this->input = $var; + + return $this; + } + +} + diff --git a/AiPlatform/src/V1/DirectRawPredictResponse.php b/AiPlatform/src/V1/DirectRawPredictResponse.php new file mode 100644 index 000000000000..4c32343e0a84 --- /dev/null +++ b/AiPlatform/src/V1/DirectRawPredictResponse.php @@ -0,0 +1,68 @@ +google.cloud.aiplatform.v1.DirectRawPredictResponse + */ +class DirectRawPredictResponse extends \Google\Protobuf\Internal\Message +{ + /** + * The prediction output. + * + * Generated from protobuf field bytes output = 1; + */ + private $output = ''; + + /** + * Constructor. + * + * @param array $data { + * Optional. Data for populating the Message object. + * + * @type string $output + * The prediction output. + * } + */ + public function __construct($data = NULL) { + \GPBMetadata\Google\Cloud\Aiplatform\V1\PredictionService::initOnce(); + parent::__construct($data); + } + + /** + * The prediction output. + * + * Generated from protobuf field bytes output = 1; + * @return string + */ + public function getOutput() + { + return $this->output; + } + + /** + * The prediction output. + * + * Generated from protobuf field bytes output = 1; + * @param string $var + * @return $this + */ + public function setOutput($var) + { + GPBUtil::checkString($var, False); + $this->output = $var; + + return $this; + } + +} + diff --git a/AiPlatform/src/V1/Gapic/PredictionServiceGapicClient.php b/AiPlatform/src/V1/Gapic/PredictionServiceGapicClient.php index 9e2832941006..ef4e1176a381 100644 --- a/AiPlatform/src/V1/Gapic/PredictionServiceGapicClient.php +++ b/AiPlatform/src/V1/Gapic/PredictionServiceGapicClient.php @@ -35,6 +35,10 @@ use Google\ApiCore\ValidationException; use Google\Api\HttpBody; use Google\Auth\FetchAuthTokenInterface; +use Google\Cloud\AIPlatform\V1\DirectPredictRequest; +use Google\Cloud\AIPlatform\V1\DirectPredictResponse; +use Google\Cloud\AIPlatform\V1\DirectRawPredictRequest; +use Google\Cloud\AIPlatform\V1\DirectRawPredictResponse; use Google\Cloud\AIPlatform\V1\ExplainRequest; use Google\Cloud\AIPlatform\V1\ExplainResponse; use Google\Cloud\AIPlatform\V1\ExplanationSpecOverride; @@ -43,6 +47,8 @@ use Google\Cloud\AIPlatform\V1\RawPredictRequest; use Google\Cloud\AIPlatform\V1\StreamingPredictRequest; use Google\Cloud\AIPlatform\V1\StreamingPredictResponse; +use Google\Cloud\AIPlatform\V1\StreamingRawPredictRequest; +use Google\Cloud\AIPlatform\V1\StreamingRawPredictResponse; use Google\Cloud\AIPlatform\V1\Tensor; use Google\Cloud\Iam\V1\GetIamPolicyRequest; use Google\Cloud\Iam\V1\GetPolicyOptions; @@ -67,8 +73,7 @@ * $predictionServiceClient = new PredictionServiceClient(); * try { * $formattedEndpoint = $predictionServiceClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); - * $instances = []; - * $response = $predictionServiceClient->explain($formattedEndpoint, $instances); + * $response = $predictionServiceClient->directPredict($formattedEndpoint); * } finally { * $predictionServiceClient->close(); * } @@ -357,6 +362,137 @@ public function __construct(array $options = []) $this->setClientOptions($clientOptions); } + /** + * Perform an unary online prediction request for Vertex first-party products + * and frameworks. + * + * Sample code: + * ``` + * $predictionServiceClient = new PredictionServiceClient(); + * try { + * $formattedEndpoint = $predictionServiceClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + * $response = $predictionServiceClient->directPredict($formattedEndpoint); + * } finally { + * $predictionServiceClient->close(); + * } + * ``` + * + * @param string $endpoint Required. The name of the Endpoint requested to serve the prediction. + * Format: + * `projects/{project}/locations/{location}/endpoints/{endpoint}` + * @param array $optionalArgs { + * Optional. + * + * @type Tensor[] $inputs + * The prediction input. + * @type Tensor $parameters + * The parameters that govern the prediction. + * @type RetrySettings|array $retrySettings + * Retry settings to use for this call. Can be a {@see RetrySettings} object, or an + * associative array of retry settings parameters. See the documentation on + * {@see RetrySettings} for example usage. + * } + * + * @return \Google\Cloud\AIPlatform\V1\DirectPredictResponse + * + * @throws ApiException if the remote call fails + */ + public function directPredict($endpoint, array $optionalArgs = []) + { + $request = new DirectPredictRequest(); + $requestParamHeaders = []; + $request->setEndpoint($endpoint); + $requestParamHeaders['endpoint'] = $endpoint; + if (isset($optionalArgs['inputs'])) { + $request->setInputs($optionalArgs['inputs']); + } + + if (isset($optionalArgs['parameters'])) { + $request->setParameters($optionalArgs['parameters']); + } + + $requestParams = new RequestParamsHeaderDescriptor( + $requestParamHeaders + ); + $optionalArgs['headers'] = isset($optionalArgs['headers']) + ? array_merge($requestParams->getHeader(), $optionalArgs['headers']) + : $requestParams->getHeader(); + return $this->startCall( + 'DirectPredict', + DirectPredictResponse::class, + $optionalArgs, + $request + )->wait(); + } + + /** + * Perform an online prediction request through gRPC. + * + * Sample code: + * ``` + * $predictionServiceClient = new PredictionServiceClient(); + * try { + * $formattedEndpoint = $predictionServiceClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + * $response = $predictionServiceClient->directRawPredict($formattedEndpoint); + * } finally { + * $predictionServiceClient->close(); + * } + * ``` + * + * @param string $endpoint Required. The name of the Endpoint requested to serve the prediction. + * Format: + * `projects/{project}/locations/{location}/endpoints/{endpoint}` + * @param array $optionalArgs { + * Optional. + * + * @type string $methodName + * Fully qualified name of the API method being invoked to perform + * predictions. + * + * Format: + * `/namespace.Service/Method/` + * Example: + * `/tensorflow.serving.PredictionService/Predict` + * @type string $input + * The prediction input. + * @type RetrySettings|array $retrySettings + * Retry settings to use for this call. Can be a {@see RetrySettings} object, or an + * associative array of retry settings parameters. See the documentation on + * {@see RetrySettings} for example usage. + * } + * + * @return \Google\Cloud\AIPlatform\V1\DirectRawPredictResponse + * + * @throws ApiException if the remote call fails + */ + public function directRawPredict($endpoint, array $optionalArgs = []) + { + $request = new DirectRawPredictRequest(); + $requestParamHeaders = []; + $request->setEndpoint($endpoint); + $requestParamHeaders['endpoint'] = $endpoint; + if (isset($optionalArgs['methodName'])) { + $request->setMethodName($optionalArgs['methodName']); + } + + if (isset($optionalArgs['input'])) { + $request->setInput($optionalArgs['input']); + } + + $requestParams = new RequestParamsHeaderDescriptor( + $requestParamHeaders + ); + $optionalArgs['headers'] = isset($optionalArgs['headers']) + ? array_merge($requestParams->getHeader(), $optionalArgs['headers']) + : $requestParams->getHeader(); + return $this->startCall( + 'DirectRawPredict', + DirectRawPredictResponse::class, + $optionalArgs, + $request + )->wait(); + } + /** * Perform an online explanation. * @@ -678,6 +814,141 @@ public function serverStreamingPredict($endpoint, array $optionalArgs = []) ); } + /** + * Perform a streaming online prediction request for Vertex first-party + * products and frameworks. + * + * Sample code: + * ``` + * $predictionServiceClient = new PredictionServiceClient(); + * try { + * $endpoint = 'endpoint'; + * $request = new StreamingPredictRequest(); + * $request->setEndpoint($endpoint); + * // Write all requests to the server, then read all responses until the + * // stream is complete + * $requests = [ + * $request, + * ]; + * $stream = $predictionServiceClient->streamingPredict(); + * $stream->writeAll($requests); + * foreach ($stream->closeWriteAndReadAll() as $element) { + * // doSomethingWith($element); + * } + * // Alternatively: + * // Write requests individually, making read() calls if + * // required. Call closeWrite() once writes are complete, and read the + * // remaining responses from the server. + * $requests = [ + * $request, + * ]; + * $stream = $predictionServiceClient->streamingPredict(); + * foreach ($requests as $request) { + * $stream->write($request); + * // if required, read a single response from the stream + * $element = $stream->read(); + * // doSomethingWith($element) + * } + * $stream->closeWrite(); + * $element = $stream->read(); + * while (!is_null($element)) { + * // doSomethingWith($element) + * $element = $stream->read(); + * } + * } finally { + * $predictionServiceClient->close(); + * } + * ``` + * + * @param array $optionalArgs { + * Optional. + * + * @type int $timeoutMillis + * Timeout to use for this call. + * } + * + * @return \Google\ApiCore\BidiStream + * + * @throws ApiException if the remote call fails + */ + public function streamingPredict(array $optionalArgs = []) + { + return $this->startCall( + 'StreamingPredict', + StreamingPredictResponse::class, + $optionalArgs, + null, + Call::BIDI_STREAMING_CALL + ); + } + + /** + * Perform a streaming online prediction request through gRPC. + * + * Sample code: + * ``` + * $predictionServiceClient = new PredictionServiceClient(); + * try { + * $endpoint = 'endpoint'; + * $request = new StreamingRawPredictRequest(); + * $request->setEndpoint($endpoint); + * // Write all requests to the server, then read all responses until the + * // stream is complete + * $requests = [ + * $request, + * ]; + * $stream = $predictionServiceClient->streamingRawPredict(); + * $stream->writeAll($requests); + * foreach ($stream->closeWriteAndReadAll() as $element) { + * // doSomethingWith($element); + * } + * // Alternatively: + * // Write requests individually, making read() calls if + * // required. Call closeWrite() once writes are complete, and read the + * // remaining responses from the server. + * $requests = [ + * $request, + * ]; + * $stream = $predictionServiceClient->streamingRawPredict(); + * foreach ($requests as $request) { + * $stream->write($request); + * // if required, read a single response from the stream + * $element = $stream->read(); + * // doSomethingWith($element) + * } + * $stream->closeWrite(); + * $element = $stream->read(); + * while (!is_null($element)) { + * // doSomethingWith($element) + * $element = $stream->read(); + * } + * } finally { + * $predictionServiceClient->close(); + * } + * ``` + * + * @param array $optionalArgs { + * Optional. + * + * @type int $timeoutMillis + * Timeout to use for this call. + * } + * + * @return \Google\ApiCore\BidiStream + * + * @throws ApiException if the remote call fails + */ + public function streamingRawPredict(array $optionalArgs = []) + { + return $this->startCall( + 'StreamingRawPredict', + StreamingRawPredictResponse::class, + $optionalArgs, + null, + Call::BIDI_STREAMING_CALL + ); + } + /** * Gets information about a location. * diff --git a/AiPlatform/src/V1/ModelContainerSpec.php b/AiPlatform/src/V1/ModelContainerSpec.php index 269b152417dc..1ed2155a6eff 100644 --- a/AiPlatform/src/V1/ModelContainerSpec.php +++ b/AiPlatform/src/V1/ModelContainerSpec.php @@ -225,6 +225,18 @@ class ModelContainerSpec extends \Google\Protobuf\Internal\Message * Generated from protobuf field string health_route = 7 [(.google.api.field_behavior) = IMMUTABLE]; */ private $health_route = ''; + /** + * Immutable. List of ports to expose from the container. Vertex AI sends gRPC + * prediction requests that it receives to the first port on this list. Vertex + * AI also sends liveness and health checks to this port. + * If you do not specify this field, gRPC requests to the container will be + * disabled. + * Vertex AI does not use ports other than the first one listed. This field + * corresponds to the `ports` field of the Kubernetes Containers v1 core API. + * + * Generated from protobuf field repeated .google.cloud.aiplatform.v1.Port grpc_ports = 9 [(.google.api.field_behavior) = IMMUTABLE]; + */ + private $grpc_ports; /** * Immutable. Deployment timeout. * Limit for deployment timeout is 2 hours. @@ -438,6 +450,14 @@ class ModelContainerSpec extends \Google\Protobuf\Internal\Message * (Vertex AI makes this value available to your container code as the * [`AIP_DEPLOYED_MODEL_ID` environment * variable](https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements#aip-variables).) + * @type array<\Google\Cloud\AIPlatform\V1\Port>|\Google\Protobuf\Internal\RepeatedField $grpc_ports + * Immutable. List of ports to expose from the container. Vertex AI sends gRPC + * prediction requests that it receives to the first port on this list. Vertex + * AI also sends liveness and health checks to this port. + * If you do not specify this field, gRPC requests to the container will be + * disabled. + * Vertex AI does not use ports other than the first one listed. This field + * corresponds to the `ports` field of the Kubernetes Containers v1 core API. * @type \Google\Protobuf\Duration $deployment_timeout * Immutable. Deployment timeout. * Limit for deployment timeout is 2 hours. @@ -969,6 +989,44 @@ public function setHealthRoute($var) return $this; } + /** + * Immutable. List of ports to expose from the container. Vertex AI sends gRPC + * prediction requests that it receives to the first port on this list. Vertex + * AI also sends liveness and health checks to this port. + * If you do not specify this field, gRPC requests to the container will be + * disabled. + * Vertex AI does not use ports other than the first one listed. This field + * corresponds to the `ports` field of the Kubernetes Containers v1 core API. + * + * Generated from protobuf field repeated .google.cloud.aiplatform.v1.Port grpc_ports = 9 [(.google.api.field_behavior) = IMMUTABLE]; + * @return \Google\Protobuf\Internal\RepeatedField + */ + public function getGrpcPorts() + { + return $this->grpc_ports; + } + + /** + * Immutable. List of ports to expose from the container. Vertex AI sends gRPC + * prediction requests that it receives to the first port on this list. Vertex + * AI also sends liveness and health checks to this port. + * If you do not specify this field, gRPC requests to the container will be + * disabled. + * Vertex AI does not use ports other than the first one listed. This field + * corresponds to the `ports` field of the Kubernetes Containers v1 core API. + * + * Generated from protobuf field repeated .google.cloud.aiplatform.v1.Port grpc_ports = 9 [(.google.api.field_behavior) = IMMUTABLE]; + * @param array<\Google\Cloud\AIPlatform\V1\Port>|\Google\Protobuf\Internal\RepeatedField $var + * @return $this + */ + public function setGrpcPorts($var) + { + $arr = GPBUtil::checkRepeatedField($var, \Google\Protobuf\Internal\GPBType::MESSAGE, \Google\Cloud\AIPlatform\V1\Port::class); + $this->grpc_ports = $arr; + + return $this; + } + /** * Immutable. Deployment timeout. * Limit for deployment timeout is 2 hours. diff --git a/AiPlatform/src/V1/StreamingRawPredictRequest.php b/AiPlatform/src/V1/StreamingRawPredictRequest.php new file mode 100644 index 000000000000..cc16d4d3a4af --- /dev/null +++ b/AiPlatform/src/V1/StreamingRawPredictRequest.php @@ -0,0 +1,174 @@ +google.cloud.aiplatform.v1.StreamingRawPredictRequest + */ +class StreamingRawPredictRequest extends \Google\Protobuf\Internal\Message +{ + /** + * Required. The name of the Endpoint requested to serve the prediction. + * Format: + * `projects/{project}/locations/{location}/endpoints/{endpoint}` + * + * Generated from protobuf field string endpoint = 1 [(.google.api.field_behavior) = REQUIRED, (.google.api.resource_reference) = { + */ + private $endpoint = ''; + /** + * Fully qualified name of the API method being invoked to perform + * predictions. + * Format: + * `/namespace.Service/Method/` + * Example: + * `/tensorflow.serving.PredictionService/Predict` + * + * Generated from protobuf field string method_name = 2; + */ + private $method_name = ''; + /** + * The prediction input. + * + * Generated from protobuf field bytes input = 3; + */ + private $input = ''; + + /** + * Constructor. + * + * @param array $data { + * Optional. Data for populating the Message object. + * + * @type string $endpoint + * Required. The name of the Endpoint requested to serve the prediction. + * Format: + * `projects/{project}/locations/{location}/endpoints/{endpoint}` + * @type string $method_name + * Fully qualified name of the API method being invoked to perform + * predictions. + * Format: + * `/namespace.Service/Method/` + * Example: + * `/tensorflow.serving.PredictionService/Predict` + * @type string $input + * The prediction input. + * } + */ + public function __construct($data = NULL) { + \GPBMetadata\Google\Cloud\Aiplatform\V1\PredictionService::initOnce(); + parent::__construct($data); + } + + /** + * Required. The name of the Endpoint requested to serve the prediction. + * Format: + * `projects/{project}/locations/{location}/endpoints/{endpoint}` + * + * Generated from protobuf field string endpoint = 1 [(.google.api.field_behavior) = REQUIRED, (.google.api.resource_reference) = { + * @return string + */ + public function getEndpoint() + { + return $this->endpoint; + } + + /** + * Required. The name of the Endpoint requested to serve the prediction. + * Format: + * `projects/{project}/locations/{location}/endpoints/{endpoint}` + * + * Generated from protobuf field string endpoint = 1 [(.google.api.field_behavior) = REQUIRED, (.google.api.resource_reference) = { + * @param string $var + * @return $this + */ + public function setEndpoint($var) + { + GPBUtil::checkString($var, True); + $this->endpoint = $var; + + return $this; + } + + /** + * Fully qualified name of the API method being invoked to perform + * predictions. + * Format: + * `/namespace.Service/Method/` + * Example: + * `/tensorflow.serving.PredictionService/Predict` + * + * Generated from protobuf field string method_name = 2; + * @return string + */ + public function getMethodName() + { + return $this->method_name; + } + + /** + * Fully qualified name of the API method being invoked to perform + * predictions. + * Format: + * `/namespace.Service/Method/` + * Example: + * `/tensorflow.serving.PredictionService/Predict` + * + * Generated from protobuf field string method_name = 2; + * @param string $var + * @return $this + */ + public function setMethodName($var) + { + GPBUtil::checkString($var, True); + $this->method_name = $var; + + return $this; + } + + /** + * The prediction input. + * + * Generated from protobuf field bytes input = 3; + * @return string + */ + public function getInput() + { + return $this->input; + } + + /** + * The prediction input. + * + * Generated from protobuf field bytes input = 3; + * @param string $var + * @return $this + */ + public function setInput($var) + { + GPBUtil::checkString($var, False); + $this->input = $var; + + return $this; + } + +} + diff --git a/AiPlatform/src/V1/StreamingRawPredictResponse.php b/AiPlatform/src/V1/StreamingRawPredictResponse.php new file mode 100644 index 000000000000..ba43e40ae271 --- /dev/null +++ b/AiPlatform/src/V1/StreamingRawPredictResponse.php @@ -0,0 +1,68 @@ +google.cloud.aiplatform.v1.StreamingRawPredictResponse + */ +class StreamingRawPredictResponse extends \Google\Protobuf\Internal\Message +{ + /** + * The prediction output. + * + * Generated from protobuf field bytes output = 1; + */ + private $output = ''; + + /** + * Constructor. + * + * @param array $data { + * Optional. Data for populating the Message object. + * + * @type string $output + * The prediction output. + * } + */ + public function __construct($data = NULL) { + \GPBMetadata\Google\Cloud\Aiplatform\V1\PredictionService::initOnce(); + parent::__construct($data); + } + + /** + * The prediction output. + * + * Generated from protobuf field bytes output = 1; + * @return string + */ + public function getOutput() + { + return $this->output; + } + + /** + * The prediction output. + * + * Generated from protobuf field bytes output = 1; + * @param string $var + * @return $this + */ + public function setOutput($var) + { + GPBUtil::checkString($var, False); + $this->output = $var; + + return $this; + } + +} + diff --git a/AiPlatform/src/V1/gapic_metadata.json b/AiPlatform/src/V1/gapic_metadata.json index 20682b1dfec8..2ccde3e144b5 100644 --- a/AiPlatform/src/V1/gapic_metadata.json +++ b/AiPlatform/src/V1/gapic_metadata.json @@ -970,6 +970,16 @@ "grpc": { "libraryClient": "PredictionServiceGapicClient", "rpcs": { + "DirectPredict": { + "methods": [ + "directPredict" + ] + }, + "DirectRawPredict": { + "methods": [ + "directRawPredict" + ] + }, "Explain": { "methods": [ "explain" @@ -990,6 +1000,16 @@ "serverStreamingPredict" ] }, + "StreamingPredict": { + "methods": [ + "streamingPredict" + ] + }, + "StreamingRawPredict": { + "methods": [ + "streamingRawPredict" + ] + }, "GetLocation": { "methods": [ "getLocation" diff --git a/AiPlatform/src/V1/resources/prediction_service_client_config.json b/AiPlatform/src/V1/resources/prediction_service_client_config.json index 330f5373da63..4f0fb9dbc543 100644 --- a/AiPlatform/src/V1/resources/prediction_service_client_config.json +++ b/AiPlatform/src/V1/resources/prediction_service_client_config.json @@ -16,6 +16,16 @@ } }, "methods": { + "DirectPredict": { + "timeout_millis": 60000, + "retry_codes_name": "no_retry_codes", + "retry_params_name": "no_retry_params" + }, + "DirectRawPredict": { + "timeout_millis": 60000, + "retry_codes_name": "no_retry_codes", + "retry_params_name": "no_retry_params" + }, "Explain": { "timeout_millis": 60000, "retry_codes_name": "no_retry_codes", @@ -34,6 +44,12 @@ "ServerStreamingPredict": { "timeout_millis": 60000 }, + "StreamingPredict": { + "timeout_millis": 60000 + }, + "StreamingRawPredict": { + "timeout_millis": 60000 + }, "GetLocation": { "timeout_millis": 60000, "retry_codes_name": "no_retry_codes", diff --git a/AiPlatform/src/V1/resources/prediction_service_descriptor_config.php b/AiPlatform/src/V1/resources/prediction_service_descriptor_config.php index 6960752d2ae0..a6d2557f1c22 100644 --- a/AiPlatform/src/V1/resources/prediction_service_descriptor_config.php +++ b/AiPlatform/src/V1/resources/prediction_service_descriptor_config.php @@ -3,6 +3,30 @@ return [ 'interfaces' => [ 'google.cloud.aiplatform.v1.PredictionService' => [ + 'DirectPredict' => [ + 'callType' => \Google\ApiCore\Call::UNARY_CALL, + 'responseType' => 'Google\Cloud\AIPlatform\V1\DirectPredictResponse', + 'headerParams' => [ + [ + 'keyName' => 'endpoint', + 'fieldAccessors' => [ + 'getEndpoint', + ], + ], + ], + ], + 'DirectRawPredict' => [ + 'callType' => \Google\ApiCore\Call::UNARY_CALL, + 'responseType' => 'Google\Cloud\AIPlatform\V1\DirectRawPredictResponse', + 'headerParams' => [ + [ + 'keyName' => 'endpoint', + 'fieldAccessors' => [ + 'getEndpoint', + ], + ], + ], + ], 'Explain' => [ 'callType' => \Google\ApiCore\Call::UNARY_CALL, 'responseType' => 'Google\Cloud\AIPlatform\V1\ExplainResponse', @@ -54,6 +78,20 @@ ], ], ], + 'StreamingPredict' => [ + 'grpcStreaming' => [ + 'grpcStreamingType' => 'BidiStreaming', + ], + 'callType' => \Google\ApiCore\Call::BIDI_STREAMING_CALL, + 'responseType' => 'Google\Cloud\AIPlatform\V1\StreamingPredictResponse', + ], + 'StreamingRawPredict' => [ + 'grpcStreaming' => [ + 'grpcStreamingType' => 'BidiStreaming', + ], + 'callType' => \Google\ApiCore\Call::BIDI_STREAMING_CALL, + 'responseType' => 'Google\Cloud\AIPlatform\V1\StreamingRawPredictResponse', + ], 'GetLocation' => [ 'callType' => \Google\ApiCore\Call::UNARY_CALL, 'responseType' => 'Google\Cloud\Location\Location', diff --git a/AiPlatform/src/V1/resources/prediction_service_rest_client_config.php b/AiPlatform/src/V1/resources/prediction_service_rest_client_config.php index 15457404299c..d2d3320b759f 100644 --- a/AiPlatform/src/V1/resources/prediction_service_rest_client_config.php +++ b/AiPlatform/src/V1/resources/prediction_service_rest_client_config.php @@ -3,6 +3,30 @@ return [ 'interfaces' => [ 'google.cloud.aiplatform.v1.PredictionService' => [ + 'DirectPredict' => [ + 'method' => 'post', + 'uriTemplate' => '/v1/{endpoint=projects/*/locations/*/endpoints/*}:directPredict', + 'body' => '*', + 'placeholders' => [ + 'endpoint' => [ + 'getters' => [ + 'getEndpoint', + ], + ], + ], + ], + 'DirectRawPredict' => [ + 'method' => 'post', + 'uriTemplate' => '/v1/{endpoint=projects/*/locations/*/endpoints/*}:directRawPredict', + 'body' => '*', + 'placeholders' => [ + 'endpoint' => [ + 'getters' => [ + 'getEndpoint', + ], + ], + ], + ], 'Explain' => [ 'method' => 'post', 'uriTemplate' => '/v1/{endpoint=projects/*/locations/*/endpoints/*}:explain', diff --git a/AiPlatform/tests/Unit/V1/Client/PredictionServiceClientTest.php b/AiPlatform/tests/Unit/V1/Client/PredictionServiceClientTest.php index f79b676a7cec..62084966ac0a 100644 --- a/AiPlatform/tests/Unit/V1/Client/PredictionServiceClientTest.php +++ b/AiPlatform/tests/Unit/V1/Client/PredictionServiceClientTest.php @@ -23,12 +23,17 @@ namespace Google\Cloud\AIPlatform\Tests\Unit\V1\Client; use Google\ApiCore\ApiException; +use Google\ApiCore\BidiStream; use Google\ApiCore\CredentialsWrapper; use Google\ApiCore\ServerStream; use Google\ApiCore\Testing\GeneratedTest; use Google\ApiCore\Testing\MockTransport; use Google\Api\HttpBody; use Google\Cloud\AIPlatform\V1\Client\PredictionServiceClient; +use Google\Cloud\AIPlatform\V1\DirectPredictRequest; +use Google\Cloud\AIPlatform\V1\DirectPredictResponse; +use Google\Cloud\AIPlatform\V1\DirectRawPredictRequest; +use Google\Cloud\AIPlatform\V1\DirectRawPredictResponse; use Google\Cloud\AIPlatform\V1\ExplainRequest; use Google\Cloud\AIPlatform\V1\ExplainResponse; use Google\Cloud\AIPlatform\V1\PredictRequest; @@ -36,6 +41,8 @@ use Google\Cloud\AIPlatform\V1\RawPredictRequest; use Google\Cloud\AIPlatform\V1\StreamingPredictRequest; use Google\Cloud\AIPlatform\V1\StreamingPredictResponse; +use Google\Cloud\AIPlatform\V1\StreamingRawPredictRequest; +use Google\Cloud\AIPlatform\V1\StreamingRawPredictResponse; use Google\Cloud\Iam\V1\GetIamPolicyRequest; use Google\Cloud\Iam\V1\Policy; use Google\Cloud\Iam\V1\SetIamPolicyRequest; @@ -76,6 +83,132 @@ private function createClient(array $options = []) return new PredictionServiceClient($options); } + /** @test */ + public function directPredictTest() + { + $transport = $this->createTransport(); + $gapicClient = $this->createClient([ + 'transport' => $transport, + ]); + $this->assertTrue($transport->isExhausted()); + // Mock response + $expectedResponse = new DirectPredictResponse(); + $transport->addResponse($expectedResponse); + // Mock request + $formattedEndpoint = $gapicClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + $request = (new DirectPredictRequest()) + ->setEndpoint($formattedEndpoint); + $response = $gapicClient->directPredict($request); + $this->assertEquals($expectedResponse, $response); + $actualRequests = $transport->popReceivedCalls(); + $this->assertSame(1, count($actualRequests)); + $actualFuncCall = $actualRequests[0]->getFuncCall(); + $actualRequestObject = $actualRequests[0]->getRequestObject(); + $this->assertSame('/google.cloud.aiplatform.v1.PredictionService/DirectPredict', $actualFuncCall); + $actualValue = $actualRequestObject->getEndpoint(); + $this->assertProtobufEquals($formattedEndpoint, $actualValue); + $this->assertTrue($transport->isExhausted()); + } + + /** @test */ + public function directPredictExceptionTest() + { + $transport = $this->createTransport(); + $gapicClient = $this->createClient([ + 'transport' => $transport, + ]); + $this->assertTrue($transport->isExhausted()); + $status = new stdClass(); + $status->code = Code::DATA_LOSS; + $status->details = 'internal error'; + $expectedExceptionMessage = json_encode([ + 'message' => 'internal error', + 'code' => Code::DATA_LOSS, + 'status' => 'DATA_LOSS', + 'details' => [], + ], JSON_PRETTY_PRINT); + $transport->addResponse(null, $status); + // Mock request + $formattedEndpoint = $gapicClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + $request = (new DirectPredictRequest()) + ->setEndpoint($formattedEndpoint); + try { + $gapicClient->directPredict($request); + // If the $gapicClient method call did not throw, fail the test + $this->fail('Expected an ApiException, but no exception was thrown.'); + } catch (ApiException $ex) { + $this->assertEquals($status->code, $ex->getCode()); + $this->assertEquals($expectedExceptionMessage, $ex->getMessage()); + } + // Call popReceivedCalls to ensure the stub is exhausted + $transport->popReceivedCalls(); + $this->assertTrue($transport->isExhausted()); + } + + /** @test */ + public function directRawPredictTest() + { + $transport = $this->createTransport(); + $gapicClient = $this->createClient([ + 'transport' => $transport, + ]); + $this->assertTrue($transport->isExhausted()); + // Mock response + $output = '1'; + $expectedResponse = new DirectRawPredictResponse(); + $expectedResponse->setOutput($output); + $transport->addResponse($expectedResponse); + // Mock request + $formattedEndpoint = $gapicClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + $request = (new DirectRawPredictRequest()) + ->setEndpoint($formattedEndpoint); + $response = $gapicClient->directRawPredict($request); + $this->assertEquals($expectedResponse, $response); + $actualRequests = $transport->popReceivedCalls(); + $this->assertSame(1, count($actualRequests)); + $actualFuncCall = $actualRequests[0]->getFuncCall(); + $actualRequestObject = $actualRequests[0]->getRequestObject(); + $this->assertSame('/google.cloud.aiplatform.v1.PredictionService/DirectRawPredict', $actualFuncCall); + $actualValue = $actualRequestObject->getEndpoint(); + $this->assertProtobufEquals($formattedEndpoint, $actualValue); + $this->assertTrue($transport->isExhausted()); + } + + /** @test */ + public function directRawPredictExceptionTest() + { + $transport = $this->createTransport(); + $gapicClient = $this->createClient([ + 'transport' => $transport, + ]); + $this->assertTrue($transport->isExhausted()); + $status = new stdClass(); + $status->code = Code::DATA_LOSS; + $status->details = 'internal error'; + $expectedExceptionMessage = json_encode([ + 'message' => 'internal error', + 'code' => Code::DATA_LOSS, + 'status' => 'DATA_LOSS', + 'details' => [], + ], JSON_PRETTY_PRINT); + $transport->addResponse(null, $status); + // Mock request + $formattedEndpoint = $gapicClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + $request = (new DirectRawPredictRequest()) + ->setEndpoint($formattedEndpoint); + try { + $gapicClient->directRawPredict($request); + // If the $gapicClient method call did not throw, fail the test + $this->fail('Expected an ApiException, but no exception was thrown.'); + } catch (ApiException $ex) { + $this->assertEquals($status->code, $ex->getCode()); + $this->assertEquals($expectedExceptionMessage, $ex->getMessage()); + } + // Call popReceivedCalls to ensure the stub is exhausted + $transport->popReceivedCalls(); + $this->assertTrue($transport->isExhausted()); + } + /** @test */ public function explainTest() { @@ -362,6 +495,200 @@ public function serverStreamingPredictExceptionTest() $this->assertTrue($transport->isExhausted()); } + /** @test */ + public function streamingPredictTest() + { + $transport = $this->createTransport(); + $gapicClient = $this->createClient([ + 'transport' => $transport, + ]); + $this->assertTrue($transport->isExhausted()); + // Mock response + $expectedResponse = new StreamingPredictResponse(); + $transport->addResponse($expectedResponse); + $expectedResponse2 = new StreamingPredictResponse(); + $transport->addResponse($expectedResponse2); + $expectedResponse3 = new StreamingPredictResponse(); + $transport->addResponse($expectedResponse3); + // Mock request + $formattedEndpoint = $gapicClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + $request = new StreamingPredictRequest(); + $request->setEndpoint($formattedEndpoint); + $formattedEndpoint2 = $gapicClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + $request2 = new StreamingPredictRequest(); + $request2->setEndpoint($formattedEndpoint2); + $formattedEndpoint3 = $gapicClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + $request3 = new StreamingPredictRequest(); + $request3->setEndpoint($formattedEndpoint3); + $bidi = $gapicClient->streamingPredict(); + $this->assertInstanceOf(BidiStream::class, $bidi); + $bidi->write($request); + $responses = []; + $responses[] = $bidi->read(); + $bidi->writeAll([ + $request2, + $request3, + ]); + foreach ($bidi->closeWriteAndReadAll() as $response) { + $responses[] = $response; + } + + $expectedResponses = []; + $expectedResponses[] = $expectedResponse; + $expectedResponses[] = $expectedResponse2; + $expectedResponses[] = $expectedResponse3; + $this->assertEquals($expectedResponses, $responses); + $createStreamRequests = $transport->popReceivedCalls(); + $this->assertSame(1, count($createStreamRequests)); + $streamFuncCall = $createStreamRequests[0]->getFuncCall(); + $streamRequestObject = $createStreamRequests[0]->getRequestObject(); + $this->assertSame('/google.cloud.aiplatform.v1.PredictionService/StreamingPredict', $streamFuncCall); + $this->assertNull($streamRequestObject); + $callObjects = $transport->popCallObjects(); + $this->assertSame(1, count($callObjects)); + $bidiCall = $callObjects[0]; + $writeRequests = $bidiCall->popReceivedCalls(); + $expectedRequests = []; + $expectedRequests[] = $request; + $expectedRequests[] = $request2; + $expectedRequests[] = $request3; + $this->assertEquals($expectedRequests, $writeRequests); + $this->assertTrue($transport->isExhausted()); + } + + /** @test */ + public function streamingPredictExceptionTest() + { + $transport = $this->createTransport(); + $gapicClient = $this->createClient([ + 'transport' => $transport, + ]); + $status = new stdClass(); + $status->code = Code::DATA_LOSS; + $status->details = 'internal error'; + $expectedExceptionMessage = json_encode([ + 'message' => 'internal error', + 'code' => Code::DATA_LOSS, + 'status' => 'DATA_LOSS', + 'details' => [], + ], JSON_PRETTY_PRINT); + $transport->setStreamingStatus($status); + $this->assertTrue($transport->isExhausted()); + $bidi = $gapicClient->streamingPredict(); + $results = $bidi->closeWriteAndReadAll(); + try { + iterator_to_array($results); + // If the close stream method call did not throw, fail the test + $this->fail('Expected an ApiException, but no exception was thrown.'); + } catch (ApiException $ex) { + $this->assertEquals($status->code, $ex->getCode()); + $this->assertEquals($expectedExceptionMessage, $ex->getMessage()); + } + // Call popReceivedCalls to ensure the stub is exhausted + $transport->popReceivedCalls(); + $this->assertTrue($transport->isExhausted()); + } + + /** @test */ + public function streamingRawPredictTest() + { + $transport = $this->createTransport(); + $gapicClient = $this->createClient([ + 'transport' => $transport, + ]); + $this->assertTrue($transport->isExhausted()); + // Mock response + $output = '1'; + $expectedResponse = new StreamingRawPredictResponse(); + $expectedResponse->setOutput($output); + $transport->addResponse($expectedResponse); + $output2 = '116'; + $expectedResponse2 = new StreamingRawPredictResponse(); + $expectedResponse2->setOutput($output2); + $transport->addResponse($expectedResponse2); + $output3 = '117'; + $expectedResponse3 = new StreamingRawPredictResponse(); + $expectedResponse3->setOutput($output3); + $transport->addResponse($expectedResponse3); + // Mock request + $formattedEndpoint = $gapicClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + $request = new StreamingRawPredictRequest(); + $request->setEndpoint($formattedEndpoint); + $formattedEndpoint2 = $gapicClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + $request2 = new StreamingRawPredictRequest(); + $request2->setEndpoint($formattedEndpoint2); + $formattedEndpoint3 = $gapicClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + $request3 = new StreamingRawPredictRequest(); + $request3->setEndpoint($formattedEndpoint3); + $bidi = $gapicClient->streamingRawPredict(); + $this->assertInstanceOf(BidiStream::class, $bidi); + $bidi->write($request); + $responses = []; + $responses[] = $bidi->read(); + $bidi->writeAll([ + $request2, + $request3, + ]); + foreach ($bidi->closeWriteAndReadAll() as $response) { + $responses[] = $response; + } + + $expectedResponses = []; + $expectedResponses[] = $expectedResponse; + $expectedResponses[] = $expectedResponse2; + $expectedResponses[] = $expectedResponse3; + $this->assertEquals($expectedResponses, $responses); + $createStreamRequests = $transport->popReceivedCalls(); + $this->assertSame(1, count($createStreamRequests)); + $streamFuncCall = $createStreamRequests[0]->getFuncCall(); + $streamRequestObject = $createStreamRequests[0]->getRequestObject(); + $this->assertSame('/google.cloud.aiplatform.v1.PredictionService/StreamingRawPredict', $streamFuncCall); + $this->assertNull($streamRequestObject); + $callObjects = $transport->popCallObjects(); + $this->assertSame(1, count($callObjects)); + $bidiCall = $callObjects[0]; + $writeRequests = $bidiCall->popReceivedCalls(); + $expectedRequests = []; + $expectedRequests[] = $request; + $expectedRequests[] = $request2; + $expectedRequests[] = $request3; + $this->assertEquals($expectedRequests, $writeRequests); + $this->assertTrue($transport->isExhausted()); + } + + /** @test */ + public function streamingRawPredictExceptionTest() + { + $transport = $this->createTransport(); + $gapicClient = $this->createClient([ + 'transport' => $transport, + ]); + $status = new stdClass(); + $status->code = Code::DATA_LOSS; + $status->details = 'internal error'; + $expectedExceptionMessage = json_encode([ + 'message' => 'internal error', + 'code' => Code::DATA_LOSS, + 'status' => 'DATA_LOSS', + 'details' => [], + ], JSON_PRETTY_PRINT); + $transport->setStreamingStatus($status); + $this->assertTrue($transport->isExhausted()); + $bidi = $gapicClient->streamingRawPredict(); + $results = $bidi->closeWriteAndReadAll(); + try { + iterator_to_array($results); + // If the close stream method call did not throw, fail the test + $this->fail('Expected an ApiException, but no exception was thrown.'); + } catch (ApiException $ex) { + $this->assertEquals($status->code, $ex->getCode()); + $this->assertEquals($expectedExceptionMessage, $ex->getMessage()); + } + // Call popReceivedCalls to ensure the stub is exhausted + $transport->popReceivedCalls(); + $this->assertTrue($transport->isExhausted()); + } + /** @test */ public function getLocationTest() { @@ -693,7 +1020,7 @@ public function testIamPermissionsExceptionTest() } /** @test */ - public function explainAsyncTest() + public function directPredictAsyncTest() { $transport = $this->createTransport(); $gapicClient = $this->createClient([ @@ -701,27 +1028,21 @@ public function explainAsyncTest() ]); $this->assertTrue($transport->isExhausted()); // Mock response - $deployedModelId2 = 'deployedModelId2-380204163'; - $expectedResponse = new ExplainResponse(); - $expectedResponse->setDeployedModelId($deployedModelId2); + $expectedResponse = new DirectPredictResponse(); $transport->addResponse($expectedResponse); // Mock request $formattedEndpoint = $gapicClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); - $instances = []; - $request = (new ExplainRequest()) - ->setEndpoint($formattedEndpoint) - ->setInstances($instances); - $response = $gapicClient->explainAsync($request)->wait(); + $request = (new DirectPredictRequest()) + ->setEndpoint($formattedEndpoint); + $response = $gapicClient->directPredictAsync($request)->wait(); $this->assertEquals($expectedResponse, $response); $actualRequests = $transport->popReceivedCalls(); $this->assertSame(1, count($actualRequests)); $actualFuncCall = $actualRequests[0]->getFuncCall(); $actualRequestObject = $actualRequests[0]->getRequestObject(); - $this->assertSame('/google.cloud.aiplatform.v1.PredictionService/Explain', $actualFuncCall); + $this->assertSame('/google.cloud.aiplatform.v1.PredictionService/DirectPredict', $actualFuncCall); $actualValue = $actualRequestObject->getEndpoint(); $this->assertProtobufEquals($formattedEndpoint, $actualValue); - $actualValue = $actualRequestObject->getInstances(); - $this->assertProtobufEquals($instances, $actualValue); $this->assertTrue($transport->isExhausted()); } } diff --git a/AiPlatform/tests/Unit/V1/PredictionServiceClientTest.php b/AiPlatform/tests/Unit/V1/PredictionServiceClientTest.php index 135031e88837..fc19abd897fe 100644 --- a/AiPlatform/tests/Unit/V1/PredictionServiceClientTest.php +++ b/AiPlatform/tests/Unit/V1/PredictionServiceClientTest.php @@ -23,15 +23,21 @@ namespace Google\Cloud\AIPlatform\Tests\Unit\V1; use Google\ApiCore\ApiException; +use Google\ApiCore\BidiStream; use Google\ApiCore\CredentialsWrapper; use Google\ApiCore\ServerStream; use Google\ApiCore\Testing\GeneratedTest; use Google\ApiCore\Testing\MockTransport; use Google\Api\HttpBody; +use Google\Cloud\AIPlatform\V1\DirectPredictResponse; +use Google\Cloud\AIPlatform\V1\DirectRawPredictResponse; use Google\Cloud\AIPlatform\V1\ExplainResponse; use Google\Cloud\AIPlatform\V1\PredictResponse; use Google\Cloud\AIPlatform\V1\PredictionServiceClient; +use Google\Cloud\AIPlatform\V1\StreamingPredictRequest; use Google\Cloud\AIPlatform\V1\StreamingPredictResponse; +use Google\Cloud\AIPlatform\V1\StreamingRawPredictRequest; +use Google\Cloud\AIPlatform\V1\StreamingRawPredictResponse; use Google\Cloud\Iam\V1\Policy; use Google\Cloud\Iam\V1\TestIamPermissionsResponse; use Google\Cloud\Location\ListLocationsResponse; @@ -67,6 +73,124 @@ private function createClient(array $options = []) return new PredictionServiceClient($options); } + /** @test */ + public function directPredictTest() + { + $transport = $this->createTransport(); + $gapicClient = $this->createClient([ + 'transport' => $transport, + ]); + $this->assertTrue($transport->isExhausted()); + // Mock response + $expectedResponse = new DirectPredictResponse(); + $transport->addResponse($expectedResponse); + // Mock request + $formattedEndpoint = $gapicClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + $response = $gapicClient->directPredict($formattedEndpoint); + $this->assertEquals($expectedResponse, $response); + $actualRequests = $transport->popReceivedCalls(); + $this->assertSame(1, count($actualRequests)); + $actualFuncCall = $actualRequests[0]->getFuncCall(); + $actualRequestObject = $actualRequests[0]->getRequestObject(); + $this->assertSame('/google.cloud.aiplatform.v1.PredictionService/DirectPredict', $actualFuncCall); + $actualValue = $actualRequestObject->getEndpoint(); + $this->assertProtobufEquals($formattedEndpoint, $actualValue); + $this->assertTrue($transport->isExhausted()); + } + + /** @test */ + public function directPredictExceptionTest() + { + $transport = $this->createTransport(); + $gapicClient = $this->createClient([ + 'transport' => $transport, + ]); + $this->assertTrue($transport->isExhausted()); + $status = new stdClass(); + $status->code = Code::DATA_LOSS; + $status->details = 'internal error'; + $expectedExceptionMessage = json_encode([ + 'message' => 'internal error', + 'code' => Code::DATA_LOSS, + 'status' => 'DATA_LOSS', + 'details' => [], + ], JSON_PRETTY_PRINT); + $transport->addResponse(null, $status); + // Mock request + $formattedEndpoint = $gapicClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + try { + $gapicClient->directPredict($formattedEndpoint); + // If the $gapicClient method call did not throw, fail the test + $this->fail('Expected an ApiException, but no exception was thrown.'); + } catch (ApiException $ex) { + $this->assertEquals($status->code, $ex->getCode()); + $this->assertEquals($expectedExceptionMessage, $ex->getMessage()); + } + // Call popReceivedCalls to ensure the stub is exhausted + $transport->popReceivedCalls(); + $this->assertTrue($transport->isExhausted()); + } + + /** @test */ + public function directRawPredictTest() + { + $transport = $this->createTransport(); + $gapicClient = $this->createClient([ + 'transport' => $transport, + ]); + $this->assertTrue($transport->isExhausted()); + // Mock response + $output = '1'; + $expectedResponse = new DirectRawPredictResponse(); + $expectedResponse->setOutput($output); + $transport->addResponse($expectedResponse); + // Mock request + $formattedEndpoint = $gapicClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + $response = $gapicClient->directRawPredict($formattedEndpoint); + $this->assertEquals($expectedResponse, $response); + $actualRequests = $transport->popReceivedCalls(); + $this->assertSame(1, count($actualRequests)); + $actualFuncCall = $actualRequests[0]->getFuncCall(); + $actualRequestObject = $actualRequests[0]->getRequestObject(); + $this->assertSame('/google.cloud.aiplatform.v1.PredictionService/DirectRawPredict', $actualFuncCall); + $actualValue = $actualRequestObject->getEndpoint(); + $this->assertProtobufEquals($formattedEndpoint, $actualValue); + $this->assertTrue($transport->isExhausted()); + } + + /** @test */ + public function directRawPredictExceptionTest() + { + $transport = $this->createTransport(); + $gapicClient = $this->createClient([ + 'transport' => $transport, + ]); + $this->assertTrue($transport->isExhausted()); + $status = new stdClass(); + $status->code = Code::DATA_LOSS; + $status->details = 'internal error'; + $expectedExceptionMessage = json_encode([ + 'message' => 'internal error', + 'code' => Code::DATA_LOSS, + 'status' => 'DATA_LOSS', + 'details' => [], + ], JSON_PRETTY_PRINT); + $transport->addResponse(null, $status); + // Mock request + $formattedEndpoint = $gapicClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + try { + $gapicClient->directRawPredict($formattedEndpoint); + // If the $gapicClient method call did not throw, fail the test + $this->fail('Expected an ApiException, but no exception was thrown.'); + } catch (ApiException $ex) { + $this->assertEquals($status->code, $ex->getCode()); + $this->assertEquals($expectedExceptionMessage, $ex->getMessage()); + } + // Call popReceivedCalls to ensure the stub is exhausted + $transport->popReceivedCalls(); + $this->assertTrue($transport->isExhausted()); + } + /** @test */ public function explainTest() { @@ -333,6 +457,200 @@ public function serverStreamingPredictExceptionTest() $this->assertTrue($transport->isExhausted()); } + /** @test */ + public function streamingPredictTest() + { + $transport = $this->createTransport(); + $gapicClient = $this->createClient([ + 'transport' => $transport, + ]); + $this->assertTrue($transport->isExhausted()); + // Mock response + $expectedResponse = new StreamingPredictResponse(); + $transport->addResponse($expectedResponse); + $expectedResponse2 = new StreamingPredictResponse(); + $transport->addResponse($expectedResponse2); + $expectedResponse3 = new StreamingPredictResponse(); + $transport->addResponse($expectedResponse3); + // Mock request + $formattedEndpoint = $gapicClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + $request = new StreamingPredictRequest(); + $request->setEndpoint($formattedEndpoint); + $formattedEndpoint2 = $gapicClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + $request2 = new StreamingPredictRequest(); + $request2->setEndpoint($formattedEndpoint2); + $formattedEndpoint3 = $gapicClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + $request3 = new StreamingPredictRequest(); + $request3->setEndpoint($formattedEndpoint3); + $bidi = $gapicClient->streamingPredict(); + $this->assertInstanceOf(BidiStream::class, $bidi); + $bidi->write($request); + $responses = []; + $responses[] = $bidi->read(); + $bidi->writeAll([ + $request2, + $request3, + ]); + foreach ($bidi->closeWriteAndReadAll() as $response) { + $responses[] = $response; + } + + $expectedResponses = []; + $expectedResponses[] = $expectedResponse; + $expectedResponses[] = $expectedResponse2; + $expectedResponses[] = $expectedResponse3; + $this->assertEquals($expectedResponses, $responses); + $createStreamRequests = $transport->popReceivedCalls(); + $this->assertSame(1, count($createStreamRequests)); + $streamFuncCall = $createStreamRequests[0]->getFuncCall(); + $streamRequestObject = $createStreamRequests[0]->getRequestObject(); + $this->assertSame('/google.cloud.aiplatform.v1.PredictionService/StreamingPredict', $streamFuncCall); + $this->assertNull($streamRequestObject); + $callObjects = $transport->popCallObjects(); + $this->assertSame(1, count($callObjects)); + $bidiCall = $callObjects[0]; + $writeRequests = $bidiCall->popReceivedCalls(); + $expectedRequests = []; + $expectedRequests[] = $request; + $expectedRequests[] = $request2; + $expectedRequests[] = $request3; + $this->assertEquals($expectedRequests, $writeRequests); + $this->assertTrue($transport->isExhausted()); + } + + /** @test */ + public function streamingPredictExceptionTest() + { + $transport = $this->createTransport(); + $gapicClient = $this->createClient([ + 'transport' => $transport, + ]); + $status = new stdClass(); + $status->code = Code::DATA_LOSS; + $status->details = 'internal error'; + $expectedExceptionMessage = json_encode([ + 'message' => 'internal error', + 'code' => Code::DATA_LOSS, + 'status' => 'DATA_LOSS', + 'details' => [], + ], JSON_PRETTY_PRINT); + $transport->setStreamingStatus($status); + $this->assertTrue($transport->isExhausted()); + $bidi = $gapicClient->streamingPredict(); + $results = $bidi->closeWriteAndReadAll(); + try { + iterator_to_array($results); + // If the close stream method call did not throw, fail the test + $this->fail('Expected an ApiException, but no exception was thrown.'); + } catch (ApiException $ex) { + $this->assertEquals($status->code, $ex->getCode()); + $this->assertEquals($expectedExceptionMessage, $ex->getMessage()); + } + // Call popReceivedCalls to ensure the stub is exhausted + $transport->popReceivedCalls(); + $this->assertTrue($transport->isExhausted()); + } + + /** @test */ + public function streamingRawPredictTest() + { + $transport = $this->createTransport(); + $gapicClient = $this->createClient([ + 'transport' => $transport, + ]); + $this->assertTrue($transport->isExhausted()); + // Mock response + $output = '1'; + $expectedResponse = new StreamingRawPredictResponse(); + $expectedResponse->setOutput($output); + $transport->addResponse($expectedResponse); + $output2 = '116'; + $expectedResponse2 = new StreamingRawPredictResponse(); + $expectedResponse2->setOutput($output2); + $transport->addResponse($expectedResponse2); + $output3 = '117'; + $expectedResponse3 = new StreamingRawPredictResponse(); + $expectedResponse3->setOutput($output3); + $transport->addResponse($expectedResponse3); + // Mock request + $formattedEndpoint = $gapicClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + $request = new StreamingRawPredictRequest(); + $request->setEndpoint($formattedEndpoint); + $formattedEndpoint2 = $gapicClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + $request2 = new StreamingRawPredictRequest(); + $request2->setEndpoint($formattedEndpoint2); + $formattedEndpoint3 = $gapicClient->endpointName('[PROJECT]', '[LOCATION]', '[ENDPOINT]'); + $request3 = new StreamingRawPredictRequest(); + $request3->setEndpoint($formattedEndpoint3); + $bidi = $gapicClient->streamingRawPredict(); + $this->assertInstanceOf(BidiStream::class, $bidi); + $bidi->write($request); + $responses = []; + $responses[] = $bidi->read(); + $bidi->writeAll([ + $request2, + $request3, + ]); + foreach ($bidi->closeWriteAndReadAll() as $response) { + $responses[] = $response; + } + + $expectedResponses = []; + $expectedResponses[] = $expectedResponse; + $expectedResponses[] = $expectedResponse2; + $expectedResponses[] = $expectedResponse3; + $this->assertEquals($expectedResponses, $responses); + $createStreamRequests = $transport->popReceivedCalls(); + $this->assertSame(1, count($createStreamRequests)); + $streamFuncCall = $createStreamRequests[0]->getFuncCall(); + $streamRequestObject = $createStreamRequests[0]->getRequestObject(); + $this->assertSame('/google.cloud.aiplatform.v1.PredictionService/StreamingRawPredict', $streamFuncCall); + $this->assertNull($streamRequestObject); + $callObjects = $transport->popCallObjects(); + $this->assertSame(1, count($callObjects)); + $bidiCall = $callObjects[0]; + $writeRequests = $bidiCall->popReceivedCalls(); + $expectedRequests = []; + $expectedRequests[] = $request; + $expectedRequests[] = $request2; + $expectedRequests[] = $request3; + $this->assertEquals($expectedRequests, $writeRequests); + $this->assertTrue($transport->isExhausted()); + } + + /** @test */ + public function streamingRawPredictExceptionTest() + { + $transport = $this->createTransport(); + $gapicClient = $this->createClient([ + 'transport' => $transport, + ]); + $status = new stdClass(); + $status->code = Code::DATA_LOSS; + $status->details = 'internal error'; + $expectedExceptionMessage = json_encode([ + 'message' => 'internal error', + 'code' => Code::DATA_LOSS, + 'status' => 'DATA_LOSS', + 'details' => [], + ], JSON_PRETTY_PRINT); + $transport->setStreamingStatus($status); + $this->assertTrue($transport->isExhausted()); + $bidi = $gapicClient->streamingRawPredict(); + $results = $bidi->closeWriteAndReadAll(); + try { + iterator_to_array($results); + // If the close stream method call did not throw, fail the test + $this->fail('Expected an ApiException, but no exception was thrown.'); + } catch (ApiException $ex) { + $this->assertEquals($status->code, $ex->getCode()); + $this->assertEquals($expectedExceptionMessage, $ex->getMessage()); + } + // Call popReceivedCalls to ensure the stub is exhausted + $transport->popReceivedCalls(); + $this->assertTrue($transport->isExhausted()); + } + /** @test */ public function getLocationTest() {