Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoPR cognitiveservices/data-plane/FormRecognizer] FormRecognizer: Updating Train API to take optional parameter. #5607

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(


def train_custom_model(
self, source, custom_headers=None, raw=False, **operation_config):
self, source, source_filter=None, custom_headers=None, raw=False, **operation_config):
"""Train Model.

Create and train a custom model. The train request must include a
Expand All @@ -92,6 +92,10 @@ def train_custom_model(

:param source: Get or set source path.
:type source: str
:param source_filter: Get or set filter to further search the
source path for content.
:type source_filter:
~azure.cognitiveservices.formrecognizer.models.TrainSourceFilter
:param dict custom_headers: headers that will be added to the request
:param bool raw: returns the direct response alongside the
deserialized response
Expand All @@ -103,7 +107,7 @@ def train_custom_model(
:raises:
:class:`ErrorResponseException<azure.cognitiveservices.formrecognizer.models.ErrorResponseException>`
"""
train_request = models.TrainRequest(source=source)
train_request = models.TrainRequest(source=source, source_filter=source_filter)

# Construct URL
url = self.train_custom_model.metadata['url']
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# --------------------------------------------------------------------------

try:
from .train_source_filter_py3 import TrainSourceFilter
from .train_request_py3 import TrainRequest
from .form_document_report_py3 import FormDocumentReport
from .form_operation_error_py3 import FormOperationError
Expand Down Expand Up @@ -38,6 +39,7 @@
from .computer_vision_error_py3 import ComputerVisionError, ComputerVisionErrorException
from .image_url_py3 import ImageUrl
except (SyntaxError, ImportError):
from .train_source_filter import TrainSourceFilter
from .train_request import TrainRequest
from .form_document_report import FormDocumentReport
from .form_operation_error import FormOperationError
Expand Down Expand Up @@ -72,6 +74,7 @@
)

__all__ = [
'TrainSourceFilter',
'TrainRequest',
'FormDocumentReport',
'FormOperationError',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ class TrainRequest(Model):

:param source: Required. Get or set source path.
:type source: str
:param source_filter: Get or set filter to further search the
source path for content.
:type source_filter:
~azure.cognitiveservices.formrecognizer.models.TrainSourceFilter
"""

_validation = {
Expand All @@ -27,8 +31,10 @@ class TrainRequest(Model):

_attribute_map = {
'source': {'key': 'source', 'type': 'str'},
'source_filter': {'key': 'sourceFilter', 'type': 'TrainSourceFilter'},
}

def __init__(self, **kwargs):
super(TrainRequest, self).__init__(**kwargs)
self.source = kwargs.get('source', None)
self.source_filter = kwargs.get('source_filter', None)
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ class TrainRequest(Model):

:param source: Required. Get or set source path.
:type source: str
:param source_filter: Get or set filter to further search the
source path for content.
:type source_filter:
~azure.cognitiveservices.formrecognizer.models.TrainSourceFilter
"""

_validation = {
Expand All @@ -27,8 +31,10 @@ class TrainRequest(Model):

_attribute_map = {
'source': {'key': 'source', 'type': 'str'},
'source_filter': {'key': 'sourceFilter', 'type': 'TrainSourceFilter'},
}

def __init__(self, *, source: str, **kwargs) -> None:
def __init__(self, *, source: str, source_filter=None, **kwargs) -> None:
super(TrainRequest, self).__init__(**kwargs)
self.source = source
self.source_filter = source_filter
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# coding=utf-8
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
#
# Code generated by Microsoft (R) AutoRest Code Generator.
# Changes may cause incorrect behavior and will be lost if the code is
# regenerated.
# --------------------------------------------------------------------------

from msrest.serialization import Model


class TrainSourceFilter(Model):
"""Filters to be applied when traversing a data source.

:param prefix: A case-sensitive prefix string to filter content
under the source location. For e.g., when using a Azure Blob
Uri use the prefix to restrict subfolders for content.
:type prefix: str
:param include_sub_folders: A flag to indicate if sub folders within the
set of
prefix folders will also need to be included when searching
for content to be preprocessed.
:type include_sub_folders: bool
"""

_validation = {
'prefix': {'max_length': 128, 'min_length': 0},
}

_attribute_map = {
'prefix': {'key': 'prefix', 'type': 'str'},
'include_sub_folders': {'key': 'includeSubFolders', 'type': 'bool'},
}

def __init__(self, **kwargs):
super(TrainSourceFilter, self).__init__(**kwargs)
self.prefix = kwargs.get('prefix', None)
self.include_sub_folders = kwargs.get('include_sub_folders', None)
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# coding=utf-8
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
#
# Code generated by Microsoft (R) AutoRest Code Generator.
# Changes may cause incorrect behavior and will be lost if the code is
# regenerated.
# --------------------------------------------------------------------------

from msrest.serialization import Model


class TrainSourceFilter(Model):
"""Filters to be applied when traversing a data source.

:param prefix: A case-sensitive prefix string to filter content
under the source location. For e.g., when using a Azure Blob
Uri use the prefix to restrict subfolders for content.
:type prefix: str
:param include_sub_folders: A flag to indicate if sub folders within the
set of
prefix folders will also need to be included when searching
for content to be preprocessed.
:type include_sub_folders: bool
"""

_validation = {
'prefix': {'max_length': 128, 'min_length': 0},
}

_attribute_map = {
'prefix': {'key': 'prefix', 'type': 'str'},
'include_sub_folders': {'key': 'includeSubFolders', 'type': 'bool'},
}

def __init__(self, *, prefix: str=None, include_sub_folders: bool=None, **kwargs) -> None:
super(TrainSourceFilter, self).__init__(**kwargs)
self.prefix = prefix
self.include_sub_folders = include_sub_folders