From 8633ace70b52904c9306bf5fc887a433b70cd1bc Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Tue, 29 Jun 2021 17:51:53 -0700 Subject: [PATCH] run black formatting --- .../azure-ai-textanalytics/azure/__init__.py | 2 +- .../azure/ai/__init__.py | 2 +- .../azure/ai/textanalytics/__init__.py | 72 +-- .../azure/ai/textanalytics/_base_client.py | 9 +- .../azure/ai/textanalytics/_lro.py | 59 +- .../azure/ai/textanalytics/_models.py | 532 +++++++++++------- .../azure/ai/textanalytics/_policies.py | 12 +- .../ai/textanalytics/_request_handlers.py | 37 +- .../ai/textanalytics/_response_handlers.py | 218 +++++-- .../textanalytics/_response_handlers_async.py | 55 +- .../textanalytics/_text_analytics_client.py | 157 ++++-- .../azure/ai/textanalytics/aio/__init__.py | 6 +- .../textanalytics/aio/_base_client_async.py | 6 +- .../azure/ai/textanalytics/aio/_lro_async.py | 40 +- .../aio/_text_analytics_client_async.py | 226 +++++--- 15 files changed, 936 insertions(+), 497 deletions(-) diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/__init__.py b/sdk/textanalytics/azure-ai-textanalytics/azure/__init__.py index 0d1f7edf5dc6..d55ccad1f573 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/__init__.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/__init__.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/__init__.py index 0d1f7edf5dc6..d55ccad1f573 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/__init__.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/__init__.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/__init__.py index 37e9c9cc34ed..90bc268bdbc6 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/__init__.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/__init__.py @@ -49,46 +49,46 @@ EntityConditionality, EntityCertainty, EntityAssociation, - HealthcareEntityCategory + HealthcareEntityCategory, ) from ._lro import AnalyzeHealthcareEntitiesLROPoller, AnalyzeActionsLROPoller __all__ = [ - 'TextAnalyticsApiVersion', - 'TextAnalyticsClient', - 'DetectLanguageInput', - 'TextDocumentInput', - 'DetectedLanguage', - 'RecognizeEntitiesResult', - 'DetectLanguageResult', - 'CategorizedEntity', - 'TextAnalyticsError', - 'TextAnalyticsWarning', - 'ExtractKeyPhrasesResult', - 'RecognizeLinkedEntitiesResult', - 'AnalyzeSentimentResult', - 'TextDocumentStatistics', - 'DocumentError', - 'LinkedEntity', - 'LinkedEntityMatch', - 'TextDocumentBatchStatistics', - 'SentenceSentiment', - 'SentimentConfidenceScores', - 'MinedOpinion', - 'TargetSentiment', - 'AssessmentSentiment', - 'RecognizePiiEntitiesResult', - 'PiiEntity', - 'PiiEntityDomain', - 'AnalyzeHealthcareEntitiesResult', - 'HealthcareEntity', - 'HealthcareEntityDataSource', - 'RecognizeEntitiesAction', - 'RecognizeLinkedEntitiesAction', - 'RecognizePiiEntitiesAction', - 'ExtractKeyPhrasesAction', - '_AnalyzeActionsType', + "TextAnalyticsApiVersion", + "TextAnalyticsClient", + "DetectLanguageInput", + "TextDocumentInput", + 
"DetectedLanguage", + "RecognizeEntitiesResult", + "DetectLanguageResult", + "CategorizedEntity", + "TextAnalyticsError", + "TextAnalyticsWarning", + "ExtractKeyPhrasesResult", + "RecognizeLinkedEntitiesResult", + "AnalyzeSentimentResult", + "TextDocumentStatistics", + "DocumentError", + "LinkedEntity", + "LinkedEntityMatch", + "TextDocumentBatchStatistics", + "SentenceSentiment", + "SentimentConfidenceScores", + "MinedOpinion", + "TargetSentiment", + "AssessmentSentiment", + "RecognizePiiEntitiesResult", + "PiiEntity", + "PiiEntityDomain", + "AnalyzeHealthcareEntitiesResult", + "HealthcareEntity", + "HealthcareEntityDataSource", + "RecognizeEntitiesAction", + "RecognizeLinkedEntitiesAction", + "RecognizePiiEntitiesAction", + "ExtractKeyPhrasesAction", + "_AnalyzeActionsType", "PiiEntityCategory", "HealthcareEntityRelation", "HealthcareRelation", @@ -100,7 +100,7 @@ "AnalyzeSentimentAction", "AnalyzeHealthcareEntitiesLROPoller", "AnalyzeActionsLROPoller", - "HealthcareEntityCategory" + "HealthcareEntityCategory", ] __version__ = VERSION diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_base_client.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_base_client.py index 091de75b0438..2d1e9896e856 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_base_client.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_base_client.py @@ -10,6 +10,7 @@ from ._policies import TextAnalyticsResponseHookPolicy from ._user_agent import USER_AGENT + class TextAnalyticsApiVersion(str, Enum): """Text Analytics API versions supported by this package""" @@ -17,6 +18,7 @@ class TextAnalyticsApiVersion(str, Enum): V3_1 = "v3.1" V3_0 = "v3.0" + def _authentication_policy(credential): authentication_policy = None if credential is None: @@ -26,8 +28,10 @@ def _authentication_policy(credential): name="Ocp-Apim-Subscription-Key", credential=credential ) elif credential is not None and not hasattr(credential, "get_token"): - raise TypeError("Unsupported credential: {}. Use an instance of AzureKeyCredential " - "or a token credential from azure.identity".format(type(credential))) + raise TypeError( + "Unsupported credential: {}. 
Use an instance of AzureKeyCredential " + "or a token credential from azure.identity".format(type(credential)) + ) return authentication_policy @@ -43,7 +47,6 @@ def __init__(self, endpoint, credential, **kwargs): **kwargs ) - def __enter__(self): self._client.__enter__() # pylint:disable=no-member return self diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_lro.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_lro.py index 482ead5e1798..e47d5dbfa634 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_lro.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_lro.py @@ -8,7 +8,12 @@ from azure.core.polling._poller import PollingReturnType from azure.core.exceptions import HttpResponseError from azure.core.polling import LROPoller -from azure.core.polling.base_polling import LROBasePolling, OperationResourcePolling, OperationFailed, BadStatus +from azure.core.polling.base_polling import ( + LROBasePolling, + OperationResourcePolling, + OperationFailed, + BadStatus, +) _FINISHED = frozenset(["succeeded", "cancelled", "failed", "partiallycompleted"]) _FAILED = frozenset(["failed"]) @@ -20,19 +25,24 @@ class TextAnalyticsOperationResourcePolling(OperationResourcePolling): - def __init__(self, operation_location_header="operation-location", show_stats=False): - super(TextAnalyticsOperationResourcePolling, self).__init__(operation_location_header=operation_location_header) + def __init__( + self, operation_location_header="operation-location", show_stats=False + ): + super(TextAnalyticsOperationResourcePolling, self).__init__( + operation_location_header=operation_location_header + ) self._show_stats = show_stats - self._query_params = { - "showStats": show_stats - } + self._query_params = {"showStats": show_stats} def get_polling_url(self): if not self._show_stats: return super(TextAnalyticsOperationResourcePolling, self).get_polling_url() - return super(TextAnalyticsOperationResourcePolling, self).get_polling_url() + \ - "?" + urlencode(self._query_params) + return ( + super(TextAnalyticsOperationResourcePolling, self).get_polling_url() + + "?" + + urlencode(self._query_params) + ) class TextAnalyticsLROPollingMethod(LROBasePolling): @@ -93,11 +103,12 @@ def _poll(self): final_get_url = self._operation.get_final_get_url(self._pipeline_response) if final_get_url: self._pipeline_response = self.request_status(final_get_url) - TextAnalyticsLROPollingMethod._raise_if_bad_http_status_and_method(self._pipeline_response.http_response) + TextAnalyticsLROPollingMethod._raise_if_bad_http_status_and_method( + self._pipeline_response.http_response + ) class AnalyzeHealthcareEntitiesLROPollingMethod(TextAnalyticsLROPollingMethod): - def __init__(self, *args, **kwargs): self._text_analytics_client = kwargs.pop("text_analytics_client") super(AnalyzeHealthcareEntitiesLROPollingMethod, self).__init__(*args, **kwargs) @@ -105,6 +116,7 @@ def __init__(self, *args, **kwargs): @property def _current_body(self): from ._generated.v3_1.models import JobMetadata + return JobMetadata.deserialize(self._pipeline_response) @property @@ -133,11 +145,9 @@ def id(self): class AnalyzeHealthcareEntitiesLROPoller(LROPoller, Generic[PollingReturnType]): - def polling_method(self): # type: () -> AnalyzeHealthcareEntitiesLROPollingMethod - """Return the polling method associated to this poller. 
- """ + """Return the polling method associated to this poller.""" return self._polling_method # type: ignore @property @@ -208,20 +218,23 @@ def cancel(self, **kwargs): # type: ignore # Get a final status update. getattr(self._polling_method, "update_status")() - return getattr(self._polling_method, "_text_analytics_client").begin_cancel_health_job( - self.id, - polling=TextAnalyticsLROPollingMethod(timeout=polling_interval) + return getattr( + self._polling_method, "_text_analytics_client" + ).begin_cancel_health_job( + self.id, polling=TextAnalyticsLROPollingMethod(timeout=polling_interval) ) except HttpResponseError as error: from ._response_handlers import process_http_response_error + process_http_response_error(error) -class AnalyzeActionsLROPollingMethod(TextAnalyticsLROPollingMethod): +class AnalyzeActionsLROPollingMethod(TextAnalyticsLROPollingMethod): @property def _current_body(self): from ._generated.v3_1.models import AnalyzeJobMetadata + return AnalyzeJobMetadata.deserialize(self._pipeline_response) @property @@ -246,19 +259,19 @@ def display_name(self): def actions_failed_count(self): if not self._current_body: return None - return self._current_body.additional_properties['tasks']['failed'] + return self._current_body.additional_properties["tasks"]["failed"] @property def actions_in_progress_count(self): if not self._current_body: return None - return self._current_body.additional_properties['tasks']['inProgress'] + return self._current_body.additional_properties["tasks"]["inProgress"] @property def actions_succeeded_count(self): if not self._current_body: return None - return self._current_body.additional_properties['tasks']["completed"] + return self._current_body.additional_properties["tasks"]["completed"] @property def last_modified_on(self): @@ -270,7 +283,7 @@ def last_modified_on(self): def total_actions_count(self): if not self._current_body: return None - return self._current_body.additional_properties['tasks']["total"] + return self._current_body.additional_properties["tasks"]["total"] @property def id(self): @@ -280,11 +293,9 @@ def id(self): class AnalyzeActionsLROPoller(LROPoller, Generic[PollingReturnType]): - def polling_method(self): # type: () -> AnalyzeActionsLROPollingMethod - """Return the polling method associated to this poller. 
- """ + """Return the polling method associated to this poller.""" return self._polling_method # type: ignore @property diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_models.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_models.py index 2654ef4d0449..fab47f9750b7 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_models.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_models.py @@ -13,11 +13,12 @@ from ._generated.v3_0 import models as _v3_0_models from ._generated.v3_1 import models as _v3_1_models + def _get_indices(relation): return [int(s) for s in re.findall(r"\d+", relation)] -class DictMixin(object): +class DictMixin(object): def __setitem__(self, key, item): self.__dict__[key] = item @@ -47,7 +48,7 @@ def __contains__(self, key): return key in self.__dict__ def __str__(self): - return str({k: v for k, v in self.__dict__.items() if not k.startswith('_')}) + return str({k: v for k, v in self.__dict__.items() if not k.startswith("_")}) def has_key(self, k): return k in self.__dict__ @@ -56,13 +57,13 @@ def update(self, *args, **kwargs): return self.__dict__.update(*args, **kwargs) def keys(self): - return [k for k in self.__dict__ if not k.startswith('_')] + return [k for k in self.__dict__ if not k.startswith("_")] def values(self): - return [v for k, v in self.__dict__.items() if not k.startswith('_')] + return [v for k, v in self.__dict__.items() if not k.startswith("_")] def items(self): - return [(k, v) for k, v in self.__dict__.items() if not k.startswith('_')] + return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_")] def get(self, key, default=None): if key in self.__dict__: @@ -71,16 +72,14 @@ def get(self, key, default=None): class EntityAssociation(str, Enum): - """Describes if the entity is the subject of the text or if it describes someone else. - """ + """Describes if the entity is the subject of the text or if it describes someone else.""" SUBJECT = "subject" OTHER = "other" class EntityCertainty(str, Enum): - """Describes the entities certainty and polarity. - """ + """Describes the entities certainty and polarity.""" POSITIVE = "positive" POSITIVE_POSSIBLE = "positivePossible" @@ -90,16 +89,14 @@ class EntityCertainty(str, Enum): class EntityConditionality(str, Enum): - """Describes any conditionality on the entity. - """ + """Describes any conditionality on the entity.""" HYPOTHETICAL = "hypothetical" CONDITIONAL = "conditional" class HealthcareEntityRelation(str, Enum): - """Type of relation. Examples include: ``DosageOfMedication`` or 'FrequencyOfMedication', etc. - """ + """Type of relation. 
Examples include: ``DosageOfMedication`` or 'FrequencyOfMedication', etc.""" ABBREVIATION = "Abbreviation" DIRECTION_OF_BODY_STRUCTURE = "DirectionOfBodyStructure" @@ -139,7 +136,9 @@ class PiiEntityCategory(str, Enum): AT_TAX_IDENTIFICATION_NUMBER = "ATTaxIdentificationNumber" AT_VALUE_ADDED_TAX_NUMBER = "ATValueAddedTaxNumber" AZURE_DOCUMENT_DB_AUTH_KEY = "AzureDocumentDBAuthKey" - AZURE_IAAS_DATABASE_CONNECTION_AND_SQL_STRING = "AzureIAASDatabaseConnectionAndSQLString" + AZURE_IAAS_DATABASE_CONNECTION_AND_SQL_STRING = ( + "AzureIAASDatabaseConnectionAndSQLString" + ) AZURE_IO_T_CONNECTION_STRING = "AzureIoTConnectionString" AZURE_PUBLISH_SETTING_PASSWORD = "AzurePublishSettingPassword" AZURE_REDIS_CACHE_STRING = "AzureRedisCacheString" @@ -227,7 +226,9 @@ class PiiEntityCategory(str, Enum): LV_PERSONAL_CODE = "LVPersonalCode" LT_PERSONAL_CODE = "LTPersonalCode" LU_NATIONAL_IDENTIFICATION_NUMBER_NATURAL = "LUNationalIdentificationNumberNatural" - LU_NATIONAL_IDENTIFICATION_NUMBER_NON_NATURAL = "LUNationalIdentificationNumberNonNatural" + LU_NATIONAL_IDENTIFICATION_NUMBER_NON_NATURAL = ( + "LUNationalIdentificationNumberNonNatural" + ) MY_IDENTITY_CARD_NUMBER = "MYIdentityCardNumber" MT_IDENTITY_CARD_NUMBER = "MTIdentityCardNumber" MT_TAX_ID_NUMBER = "MTTaxIDNumber" @@ -255,7 +256,9 @@ class PiiEntityCategory(str, Enum): RU_PASSPORT_NUMBER_DOMESTIC = "RUPassportNumberDomestic" RU_PASSPORT_NUMBER_INTERNATIONAL = "RUPassportNumberInternational" SA_NATIONAL_ID = "SANationalID" - SG_NATIONAL_REGISTRATION_IDENTITY_CARD_NUMBER = "SGNationalRegistrationIdentityCardNumber" + SG_NATIONAL_REGISTRATION_IDENTITY_CARD_NUMBER = ( + "SGNationalRegistrationIdentityCardNumber" + ) SK_PERSONAL_NUMBER = "SKPersonalNumber" SI_TAX_IDENTIFICATION_NUMBER = "SITaxIdentificationNumber" SI_UNIQUE_MASTER_CITIZEN_NUMBER = "SIUniqueMasterCitizenNumber" @@ -302,8 +305,7 @@ class PiiEntityCategory(str, Enum): class HealthcareEntityCategory(str, Enum): - """Healthcare Entity Category. - """ + """Healthcare Entity Category.""" BODY_STRUCTURE = "BodyStructure" AGE = "Age" @@ -335,7 +337,10 @@ class HealthcareEntityCategory(str, Enum): class PiiEntityDomain(str, Enum): """The different domains of PII entities that users can filter by""" - PROTECTED_HEALTH_INFORMATION = "phi" # See https://aka.ms/tanerpii for more information. + + PROTECTED_HEALTH_INFORMATION = ( + "phi" # See https://aka.ms/tanerpii for more information. 
+ ) class DetectedLanguage(DictMixin): @@ -361,12 +366,15 @@ def __init__(self, **kwargs): @classmethod def _from_generated(cls, language): return cls( - name=language.name, iso6391_name=language.iso6391_name, confidence_score=language.confidence_score + name=language.name, + iso6391_name=language.iso6391_name, + confidence_score=language.confidence_score, ) def __repr__(self): - return "DetectedLanguage(name={}, iso6391_name={}, confidence_score={})" \ - .format(self.name, self.iso6391_name, self.confidence_score)[:1024] + return "DetectedLanguage(name={}, iso6391_name={}, confidence_score={})".format( + self.name, self.iso6391_name, self.confidence_score + )[:1024] class RecognizeEntitiesResult(DictMixin): @@ -399,8 +407,15 @@ def __init__(self, **kwargs): self.is_error = False def __repr__(self): - return "RecognizeEntitiesResult(id={}, entities={}, warnings={}, statistics={}, is_error={})" \ - .format(self.id, repr(self.entities), repr(self.warnings), repr(self.statistics), self.is_error)[:1024] + return "RecognizeEntitiesResult(id={}, entities={}, warnings={}, statistics={}, is_error={})".format( + self.id, + repr(self.entities), + repr(self.warnings), + repr(self.statistics), + self.is_error, + )[ + :1024 + ] class RecognizePiiEntitiesResult(DictMixin): @@ -438,15 +453,17 @@ def __init__(self, **kwargs): self.is_error = False def __repr__(self): - return "RecognizePiiEntitiesResult(id={}, entities={}, redacted_text={}, warnings={}, " \ - "statistics={}, is_error={})" .format( + return ( + "RecognizePiiEntitiesResult(id={}, entities={}, redacted_text={}, warnings={}, " + "statistics={}, is_error={})".format( self.id, repr(self.entities), self.redacted_text, repr(self.warnings), repr(self.statistics), - self.is_error + self.is_error, )[:1024] + ) class AnalyzeHealthcareEntitiesResult(DictMixin): @@ -485,30 +502,42 @@ def __init__(self, **kwargs): self.statistics = kwargs.get("statistics", None) self.is_error = False - @classmethod def _from_generated(cls, healthcare_result): - entities = [HealthcareEntity._from_generated(e) for e in healthcare_result.entities] # pylint: disable=protected-access - relations = [HealthcareRelation._from_generated(r, entities) for r in healthcare_result.relations] # pylint: disable=protected-access + entities = [ + HealthcareEntity._from_generated(e) for e in healthcare_result.entities # pylint: disable=protected-access + ] + relations = [ + HealthcareRelation._from_generated(r, entities) # pylint: disable=protected-access + for r in healthcare_result.relations + ] return cls( id=healthcare_result.id, entities=entities, entity_relations=relations, - warnings=[TextAnalyticsWarning._from_generated(w) for w in healthcare_result.warnings], # pylint: disable=protected-access - statistics=TextDocumentStatistics._from_generated(healthcare_result.statistics), # pylint: disable=protected-access + warnings=[ + TextAnalyticsWarning._from_generated(w) # pylint: disable=protected-access + for w in healthcare_result.warnings + ], + statistics=TextDocumentStatistics._from_generated( # pylint: disable=protected-access + healthcare_result.statistics + ), ) def __repr__(self): - return "AnalyzeHealthcareEntitiesResult(id={}, entities={}, entity_relations={}, warnings={}, "\ - "statistics={}, is_error={})".format( - self.id, - repr(self.entities), - repr(self.entity_relations), - repr(self.warnings), - repr(self.statistics), - self.is_error - )[:1024] + return ( + "AnalyzeHealthcareEntitiesResult(id={}, entities={}, entity_relations={}, warnings={}, " + "statistics={}, 
is_error={})".format( + self.id, + repr(self.entities), + repr(self.entity_relations), + repr(self.warnings), + repr(self.statistics), + self.is_error, + )[:1024] + ) + class HealthcareRelation(DictMixin): """HealthcareRelation is a result object which represents a relation detected in a document. @@ -531,7 +560,9 @@ def __init__(self, **kwargs): @classmethod def _from_generated(cls, healthcare_relation_result, entities): roles = [ - HealthcareRelationRole._from_generated(r, entities) # pylint: disable=protected-access + HealthcareRelationRole._from_generated( # pylint: disable=protected-access + r, entities + ) for r in healthcare_relation_result.entities ] return cls( @@ -545,6 +576,7 @@ def __repr__(self): repr(self.roles), )[:1024] + class HealthcareRelationRole(DictMixin): """A model representing a role in a relation. @@ -569,14 +601,16 @@ def __init__(self, **kwargs): @staticmethod def _get_entity(healthcare_role_result, entities): nums = _get_indices(healthcare_role_result.ref) - entity_index = nums[1] # first num parsed from index is document #, second is entity index + entity_index = nums[ + 1 + ] # first num parsed from index is document #, second is entity index return entities[entity_index] @classmethod def _from_generated(cls, healthcare_role_result, entities): return cls( name=healthcare_role_result.role, - entity=HealthcareRelationRole._get_entity(healthcare_role_result, entities) + entity=HealthcareRelationRole._get_entity(healthcare_role_result, entities), ) def __repr__(self): @@ -614,9 +648,16 @@ def __init__(self, **kwargs): self.is_error = False def __repr__(self): - return "DetectLanguageResult(id={}, primary_language={}, warnings={}, statistics={}, "\ - "is_error={})".format(self.id, repr(self.primary_language), repr(self.warnings), - repr(self.statistics), self.is_error)[:1024] + return ( + "DetectLanguageResult(id={}, primary_language={}, warnings={}, statistics={}, " + "is_error={})".format( + self.id, + repr(self.primary_language), + repr(self.warnings), + repr(self.statistics), + self.is_error, + )[:1024] + ) class CategorizedEntity(DictMixin): @@ -644,12 +685,12 @@ class CategorizedEntity(DictMixin): """ def __init__(self, **kwargs): - self.text = kwargs.get('text', None) - self.category = kwargs.get('category', None) - self.subcategory = kwargs.get('subcategory', None) - self.length = kwargs.get('length', None) - self.offset = kwargs.get('offset', None) - self.confidence_score = kwargs.get('confidence_score', None) + self.text = kwargs.get("text", None) + self.category = kwargs.get("category", None) + self.subcategory = kwargs.get("subcategory", None) + self.length = kwargs.get("length", None) + self.offset = kwargs.get("offset", None) + self.confidence_score = kwargs.get("confidence_score", None) @classmethod def _from_generated(cls, entity): @@ -670,15 +711,17 @@ def _from_generated(cls, entity): ) def __repr__(self): - return "CategorizedEntity(text={}, category={}, subcategory={}, "\ + return ( + "CategorizedEntity(text={}, category={}, subcategory={}, " "length={}, offset={}, confidence_score={})".format( - self.text, - self.category, - self.subcategory, - self.length, - self.offset, - self.confidence_score - )[:1024] + self.text, + self.category, + self.subcategory, + self.length, + self.offset, + self.confidence_score, + )[:1024] + ) class PiiEntity(DictMixin): @@ -701,12 +744,12 @@ class PiiEntity(DictMixin): """ def __init__(self, **kwargs): - self.text = kwargs.get('text', None) - self.category = kwargs.get('category', None) - self.subcategory = 
kwargs.get('subcategory', None) - self.length = kwargs.get('length', None) - self.offset = kwargs.get('offset', None) - self.confidence_score = kwargs.get('confidence_score', None) + self.text = kwargs.get("text", None) + self.category = kwargs.get("category", None) + self.subcategory = kwargs.get("subcategory", None) + self.length = kwargs.get("length", None) + self.offset = kwargs.get("offset", None) + self.confidence_score = kwargs.get("confidence_score", None) @classmethod def _from_generated(cls, entity): @@ -721,14 +764,14 @@ def _from_generated(cls, entity): def __repr__(self): return ( - "PiiEntity(text={}, category={}, subcategory={}, length={}, "\ + "PiiEntity(text={}, category={}, subcategory={}, length={}, " "offset={}, confidence_score={})".format( self.text, self.category, self.subcategory, self.length, self.offset, - self.confidence_score + self.confidence_score, )[:1024] ) @@ -791,26 +834,32 @@ def _from_generated(cls, healthcare_entity): offset=healthcare_entity.offset, confidence_score=healthcare_entity.confidence_score, data_sources=[ - HealthcareEntityDataSource(entity_id=l.id, name=l.data_source) for l in healthcare_entity.links - ] if healthcare_entity.links else None + HealthcareEntityDataSource(entity_id=l.id, name=l.data_source) + for l in healthcare_entity.links + ] + if healthcare_entity.links + else None, ) def __hash__(self): return hash(repr(self)) def __repr__(self): - return "HealthcareEntity(text={}, normalized_text={}, category={}, subcategory={}, assertion={}, length={}, "\ - "offset={}, confidence_score={}, data_sources={})".format( - self.text, - self.normalized_text, - self.category, - self.subcategory, - repr(self.assertion), - self.length, - self.offset, - self.confidence_score, - repr(self.data_sources), - )[:1024] + return ( + "HealthcareEntity(text={}, normalized_text={}, category={}, subcategory={}, assertion={}, length={}, " + "offset={}, confidence_score={}, data_sources={})".format( + self.text, + self.normalized_text, + self.category, + self.subcategory, + repr(self.assertion), + self.length, + self.offset, + self.confidence_score, + repr(self.data_sources), + )[:1024] + ) + class HealthcareEntityAssertion(DictMixin): """Contains various assertions about a `HealthcareEntity`. 
@@ -865,7 +914,9 @@ def __init__(self, **kwargs): self.name = kwargs.get("name", None) def __repr__(self): - return "HealthcareEntityDataSource(entity_id={}, name={})".format(self.entity_id, self.name)[:1024] + return "HealthcareEntityDataSource(entity_id={}, name={})".format( + self.entity_id, self.name + )[:1024] class TextAnalyticsError(DictMixin): @@ -886,9 +937,9 @@ class TextAnalyticsError(DictMixin): """ def __init__(self, **kwargs): - self.code = kwargs.get('code', None) - self.message = kwargs.get('message', None) - self.target = kwargs.get('target', None) + self.code = kwargs.get("code", None) + self.message = kwargs.get("message", None) + self.target = kwargs.get("target", None) @classmethod def _from_generated(cls, err): @@ -896,17 +947,15 @@ def _from_generated(cls, err): return cls( code=err.innererror.code, message=err.innererror.message, - target=err.innererror.target + target=err.innererror.target, ) - return cls( - code=err.code, - message=err.message, - target=err.target - ) + return cls(code=err.code, message=err.message, target=err.target) def __repr__(self): - return "TextAnalyticsError(code={}, message={}, target={})" \ - .format(self.code, self.message, self.target)[:1024] + return "TextAnalyticsError(code={}, message={}, target={})".format( + self.code, self.message, self.target + )[:1024] + class TextAnalyticsWarning(DictMixin): """TextAnalyticsWarning contains the warning code and message that explains why @@ -920,8 +969,8 @@ class TextAnalyticsWarning(DictMixin): """ def __init__(self, **kwargs): - self.code = kwargs.get('code', None) - self.message = kwargs.get('message', None) + self.code = kwargs.get("code", None) + self.message = kwargs.get("message", None) @classmethod def _from_generated(cls, warning): @@ -931,8 +980,9 @@ def _from_generated(cls, warning): ) def __repr__(self): - return "TextAnalyticsWarning(code={}, message={})" \ - .format(self.code, self.message)[:1024] + return "TextAnalyticsWarning(code={}, message={})".format( + self.code, self.message + )[:1024] class ExtractKeyPhrasesResult(DictMixin): @@ -966,8 +1016,15 @@ def __init__(self, **kwargs): self.is_error = False def __repr__(self): - return "ExtractKeyPhrasesResult(id={}, key_phrases={}, warnings={}, statistics={}, is_error={})" \ - .format(self.id, self.key_phrases, repr(self.warnings), repr(self.statistics), self.is_error)[:1024] + return "ExtractKeyPhrasesResult(id={}, key_phrases={}, warnings={}, statistics={}, is_error={})".format( + self.id, + self.key_phrases, + repr(self.warnings), + repr(self.statistics), + self.is_error, + )[ + :1024 + ] class RecognizeLinkedEntitiesResult(DictMixin): @@ -1000,8 +1057,15 @@ def __init__(self, **kwargs): self.is_error = False def __repr__(self): - return "RecognizeLinkedEntitiesResult(id={}, entities={}, warnings={}, statistics={}, is_error={})" \ - .format(self.id, repr(self.entities), repr(self.warnings), repr(self.statistics), self.is_error)[:1024] + return "RecognizeLinkedEntitiesResult(id={}, entities={}, warnings={}, statistics={}, is_error={})".format( + self.id, + repr(self.entities), + repr(self.warnings), + repr(self.statistics), + self.is_error, + )[ + :1024 + ] class AnalyzeSentimentResult(DictMixin): @@ -1045,10 +1109,18 @@ def __init__(self, **kwargs): self.is_error = False def __repr__(self): - return "AnalyzeSentimentResult(id={}, sentiment={}, warnings={}, statistics={}, confidence_scores={}, "\ - "sentences={}, is_error={})".format( - self.id, self.sentiment, repr(self.warnings), repr(self.statistics), - 
repr(self.confidence_scores), repr(self.sentences), self.is_error)[:1024] + return ( + "AnalyzeSentimentResult(id={}, sentiment={}, warnings={}, statistics={}, confidence_scores={}, " + "sentences={}, is_error={})".format( + self.id, + self.sentiment, + repr(self.warnings), + repr(self.statistics), + repr(self.confidence_scores), + repr(self.sentences), + self.is_error, + )[:1024] + ) class TextDocumentStatistics(DictMixin): @@ -1076,8 +1148,11 @@ def _from_generated(cls, stats): ) def __repr__(self): - return "TextDocumentStatistics(character_count={}, transaction_count={})" \ - .format(self.character_count, self.transaction_count)[:1024] + return ( + "TextDocumentStatistics(character_count={}, transaction_count={})".format( + self.character_count, self.transaction_count + )[:1024] + ) class DocumentError(DictMixin): @@ -1102,30 +1177,39 @@ def __init__(self, **kwargs): def __getattr__(self, attr): result_set = set() result_set.update( - RecognizeEntitiesResult().keys() + RecognizePiiEntitiesResult().keys() - + DetectLanguageResult().keys() + RecognizeLinkedEntitiesResult().keys() - + AnalyzeSentimentResult().keys() + ExtractKeyPhrasesResult().keys() + RecognizeEntitiesResult().keys() + + RecognizePiiEntitiesResult().keys() + + DetectLanguageResult().keys() + + RecognizeLinkedEntitiesResult().keys() + + AnalyzeSentimentResult().keys() + + ExtractKeyPhrasesResult().keys() ) result_attrs = result_set.difference(DocumentError().keys()) if attr in result_attrs: raise AttributeError( "'DocumentError' object has no attribute '{}'. The service was unable to process this document:\n" - "Document Id: {}\nError: {} - {}\n". - format(attr, self.id, self.error.code, self.error.message) + "Document Id: {}\nError: {} - {}\n".format( + attr, self.id, self.error.code, self.error.message + ) ) - raise AttributeError("'DocumentError' object has no attribute '{}'".format(attr)) + raise AttributeError( + "'DocumentError' object has no attribute '{}'".format(attr) + ) @classmethod def _from_generated(cls, doc_err): return cls( id=doc_err.id, - error=TextAnalyticsError._from_generated(doc_err.error), # pylint: disable=protected-access - is_error=True + error=TextAnalyticsError._from_generated( # pylint: disable=protected-access + doc_err.error + ), + is_error=True, ) def __repr__(self): - return "DocumentError(id={}, error={}, is_error={})" \ - .format(self.id, repr(self.error), self.is_error)[:1024] + return "DocumentError(id={}, error={}, is_error={})".format( + self.id, repr(self.error), self.is_error + )[:1024] class DetectLanguageInput(LanguageInput): @@ -1155,8 +1239,9 @@ def __init__(self, **kwargs): self.country_hint = kwargs.get("country_hint", None) def __repr__(self): - return "DetectLanguageInput(id={}, text={}, country_hint={})" \ - .format(self.id, self.text, self.country_hint)[:1024] + return "DetectLanguageInput(id={}, text={}, country_hint={})".format( + self.id, self.text, self.country_hint + )[:1024] class LinkedEntity(DictMixin): @@ -1198,10 +1283,14 @@ def __init__(self, **kwargs): @classmethod def _from_generated(cls, entity): - bing_entity_search_api_id = entity.bing_id if hasattr(entity, "bing_id") else None + bing_entity_search_api_id = ( + entity.bing_id if hasattr(entity, "bing_id") else None + ) return cls( name=entity.name, - matches=[LinkedEntityMatch._from_generated(e) for e in entity.matches], # pylint: disable=protected-access + matches=[ + LinkedEntityMatch._from_generated(e) for e in entity.matches # pylint: disable=protected-access + ], language=entity.language, 
data_source_entity_id=entity.id, url=entity.url, @@ -1210,7 +1299,8 @@ def _from_generated(cls, entity): ) def __repr__(self): - return "LinkedEntity(name={}, matches={}, language={}, data_source_entity_id={}, url={}, " \ + return ( + "LinkedEntity(name={}, matches={}, language={}, data_source_entity_id={}, url={}, " "data_source={}, bing_entity_search_api_id={})".format( self.name, repr(self.matches), @@ -1219,7 +1309,8 @@ def __repr__(self): self.url, self.data_source, self.bing_entity_search_api_id, - )[:1024] + )[:1024] + ) class LinkedEntityMatch(DictMixin): @@ -1269,7 +1360,9 @@ def _from_generated(cls, match): def __repr__(self): return "LinkedEntityMatch(confidence_score={}, text={}, length={}, offset={})".format( self.confidence_score, self.text, self.length, self.offset - )[:1024] + )[ + :1024 + ] class TextDocumentInput(DictMixin, MultiLanguageInput): @@ -1297,8 +1390,9 @@ def __init__(self, **kwargs): self.language = kwargs.get("language", None) def __repr__(self): - return "TextDocumentInput(id={}, text={}, language={})" \ - .format(self.id, self.text, self.language)[:1024] + return "TextDocumentInput(id={}, text={}, language={})".format( + self.id, self.text, self.language + )[:1024] class TextDocumentBatchStatistics(DictMixin): @@ -1336,9 +1430,15 @@ def _from_generated(cls, statistics): ) def __repr__(self): - return "TextDocumentBatchStatistics(document_count={}, valid_document_count={}, erroneous_document_count={}, " \ - "transaction_count={})".format(self.document_count, self.valid_document_count, - self.erroneous_document_count, self.transaction_count)[:1024] + return ( + "TextDocumentBatchStatistics(document_count={}, valid_document_count={}, erroneous_document_count={}, " + "transaction_count={})".format( + self.document_count, + self.valid_document_count, + self.erroneous_document_count, + self.transaction_count, + )[:1024] + ) class SentenceSentiment(DictMixin): @@ -1391,30 +1491,39 @@ def _from_generated(cls, sentence, results, sentiment): length = None if hasattr(sentence, "targets"): mined_opinions = ( - [MinedOpinion._from_generated(target, results, sentiment) for target in sentence.targets] # pylint: disable=protected-access - if sentence.targets else [] + [ + MinedOpinion._from_generated(target, results, sentiment) # pylint: disable=protected-access + for target in sentence.targets + ] + if sentence.targets + else [] ) else: mined_opinions = None return cls( text=sentence.text, sentiment=sentence.sentiment, - confidence_scores=SentimentConfidenceScores._from_generated(sentence.confidence_scores), # pylint: disable=protected-access + confidence_scores=SentimentConfidenceScores._from_generated( # pylint: disable=protected-access + sentence.confidence_scores + ), length=length, offset=offset, - mined_opinions=mined_opinions + mined_opinions=mined_opinions, ) def __repr__(self): - return "SentenceSentiment(text={}, sentiment={}, confidence_scores={}, "\ + return ( + "SentenceSentiment(text={}, sentiment={}, confidence_scores={}, " "length={}, offset={}, mined_opinions={})".format( - self.text, - self.sentiment, - repr(self.confidence_scores), - self.length, - self.offset, - repr(self.mined_opinions) - )[:1024] + self.text, + self.sentiment, + repr(self.confidence_scores), + self.length, + self.offset, + repr(self.mined_opinions), + )[:1024] + ) + class MinedOpinion(DictMixin): """A mined opinion object represents an opinion we've extracted from a sentence. 
@@ -1432,10 +1541,14 @@ def __init__(self, **kwargs): self.assessments = kwargs.get("assessments", None) @staticmethod - def _get_assessments(relations, results, sentiment): # pylint: disable=unused-argument + def _get_assessments( + relations, results, sentiment + ): # pylint: disable=unused-argument if not relations: return [] - assessment_relations = [r.ref for r in relations if r.relation_type == "assessment"] + assessment_relations = [ + r.ref for r in relations if r.relation_type == "assessment" + ] assessments = [] for assessment_relation in assessment_relations: nums = _get_indices(assessment_relation) @@ -1449,17 +1562,22 @@ def _get_assessments(relations, results, sentiment): # pylint: disable=unused-a @classmethod def _from_generated(cls, target, results, sentiment): return cls( - target=TargetSentiment._from_generated(target), # pylint: disable=protected-access + target=TargetSentiment._from_generated( # pylint: disable=protected-access + target + ), assessments=[ - AssessmentSentiment._from_generated(assessment) # pylint: disable=protected-access - for assessment in cls._get_assessments(target.relations, results, sentiment) + AssessmentSentiment._from_generated( # pylint: disable=protected-access + assessment + ) + for assessment in cls._get_assessments( + target.relations, results, sentiment + ) ], ) def __repr__(self): return "MinedOpinion(target={}, assessments={})".format( - repr(self.target), - repr(self.assessments) + repr(self.target), repr(self.assessments) )[:1024] @@ -1497,20 +1615,24 @@ def _from_generated(cls, target): return cls( text=target.text, sentiment=target.sentiment, - confidence_scores=SentimentConfidenceScores._from_generated(target.confidence_scores), # pylint: disable=protected-access + confidence_scores=SentimentConfidenceScores._from_generated( # pylint: disable=protected-access + target.confidence_scores + ), length=target.length, offset=target.offset, ) def __repr__(self): - return "TargetSentiment(text={}, sentiment={}, confidence_scores={}, "\ + return ( + "TargetSentiment(text={}, sentiment={}, confidence_scores={}, " "length={}, offset={})".format( - self.text, - self.sentiment, - repr(self.confidence_scores), - self.length, - self.offset, - )[:1024] + self.text, + self.sentiment, + repr(self.confidence_scores), + self.length, + self.offset, + )[:1024] + ) class AssessmentSentiment(DictMixin): @@ -1550,22 +1672,24 @@ def _from_generated(cls, assessment): return cls( text=assessment.text, sentiment=assessment.sentiment, - confidence_scores=SentimentConfidenceScores._from_generated(assessment.confidence_scores), # pylint: disable=protected-access + confidence_scores=SentimentConfidenceScores._from_generated( # pylint: disable=protected-access + assessment.confidence_scores + ), length=assessment.length, offset=assessment.offset, - is_negated=assessment.is_negated + is_negated=assessment.is_negated, ) def __repr__(self): return ( - "AssessmentSentiment(text={}, sentiment={}, confidence_scores={}, length={}, offset={}, " \ + "AssessmentSentiment(text={}, sentiment={}, confidence_scores={}, length={}, offset={}, " "is_negated={})".format( self.text, self.sentiment, repr(self.confidence_scores), self.length, self.offset, - self.is_negated + self.is_negated, )[:1024] ) @@ -1583,32 +1707,38 @@ class SentimentConfidenceScores(DictMixin): """ def __init__(self, **kwargs): - self.positive = kwargs.get('positive', 0.0) - self.neutral = kwargs.get('neutral', 0.0) - self.negative = kwargs.get('negative', 0.0) + self.positive = kwargs.get("positive", 0.0) + 
self.neutral = kwargs.get("neutral", 0.0) + self.negative = kwargs.get("negative", 0.0) @classmethod def _from_generated(cls, score): return cls( positive=score.positive, neutral=score.neutral if hasattr(score, "neutral") else 0.0, - negative=score.negative + negative=score.negative, ) def __repr__(self): - return "SentimentConfidenceScores(positive={}, neutral={}, negative={})" \ - .format(self.positive, self.neutral, self.negative)[:1024] + return "SentimentConfidenceScores(positive={}, neutral={}, negative={})".format( + self.positive, self.neutral, self.negative + )[:1024] class _AnalyzeActionsType(str, Enum): - """The type of action that was applied to the documents - """ + """The type of action that was applied to the documents""" + RECOGNIZE_ENTITIES = "recognize_entities" #: Entities Recognition action. - RECOGNIZE_PII_ENTITIES = "recognize_pii_entities" #: PII Entities Recognition action. + RECOGNIZE_PII_ENTITIES = ( + "recognize_pii_entities" #: PII Entities Recognition action. + ) EXTRACT_KEY_PHRASES = "extract_key_phrases" #: Key Phrase Extraction action. - RECOGNIZE_LINKED_ENTITIES = "recognize_linked_entities" #: Linked Entities Recognition action. + RECOGNIZE_LINKED_ENTITIES = ( + "recognize_linked_entities" #: Linked Entities Recognition action. + ) ANALYZE_SENTIMENT = "analyze_sentiment" #: Sentiment Analysis action. + class RecognizeEntitiesAction(DictMixin): """RecognizeEntitiesAction encapsulates the parameters for starting a long-running Entities Recognition operation. @@ -1650,8 +1780,11 @@ def __init__(self, **kwargs): self.disable_service_logs = kwargs.get("disable_service_logs", False) def __repr__(self, **kwargs): - return "RecognizeEntitiesAction(model_version={}, string_index_type={}, disable_service_logs={})" \ - .format(self.model_version, self.string_index_type, self.disable_service_logs)[:1024] + return "RecognizeEntitiesAction(model_version={}, string_index_type={}, disable_service_logs={})".format( + self.model_version, self.string_index_type, self.disable_service_logs + )[ + :1024 + ] def to_generated(self): return _v3_1_models.EntitiesTask( @@ -1710,19 +1843,21 @@ class AnalyzeSentimentAction(DictMixin): """ def __init__(self, **kwargs): - self.model_version = kwargs.get('model_version', "latest") - self.show_opinion_mining = kwargs.get('show_opinion_mining', False) - self.string_index_type = kwargs.get('string_index_type', None) + self.model_version = kwargs.get("model_version", "latest") + self.show_opinion_mining = kwargs.get("show_opinion_mining", False) + self.string_index_type = kwargs.get("string_index_type", None) self.disable_service_logs = kwargs.get("disable_service_logs", False) def __repr__(self, **kwargs): - return "AnalyzeSentimentAction(model_version={}, show_opinion_mining={}, string_index_type={}, "\ + return ( + "AnalyzeSentimentAction(model_version={}, show_opinion_mining={}, string_index_type={}, " "disable_service_logs={}".format( - self.model_version, - self.show_opinion_mining, - self.string_index_type, - self.disable_service_logs, - )[:1024] + self.model_version, + self.show_opinion_mining, + self.string_index_type, + self.disable_service_logs, + )[:1024] + ) def to_generated(self): return _v3_1_models.SentimentAnalysisTask( @@ -1793,14 +1928,16 @@ def __init__(self, **kwargs): self.disable_service_logs = kwargs.get("disable_service_logs", False) def __repr__(self, **kwargs): - return "RecognizePiiEntitiesAction(model_version={}, domain_filter={}, categories_filter={}, "\ - "string_index_type={}, disable_service_logs={}".format( 
- self.model_version, - self.domain_filter, - self.categories_filter, - self.string_index_type, - self.disable_service_logs, - )[:1024] + return ( + "RecognizePiiEntitiesAction(model_version={}, domain_filter={}, categories_filter={}, " + "string_index_type={}, disable_service_logs={}".format( + self.model_version, + self.domain_filter, + self.categories_filter, + self.string_index_type, + self.disable_service_logs, + )[:1024] + ) def to_generated(self): return _v3_1_models.PiiTask( @@ -1809,7 +1946,7 @@ def to_generated(self): domain=self.domain_filter, pii_categories=self.categories_filter, string_index_type=self.string_index_type, - logging_opt_out=self.disable_service_logs + logging_opt_out=self.disable_service_logs, ) ) @@ -1847,8 +1984,11 @@ def __init__(self, **kwargs): self.disable_service_logs = kwargs.get("disable_service_logs", False) def __repr__(self, **kwargs): - return "ExtractKeyPhrasesAction(model_version={}, disable_service_logs={})" \ - .format(self.model_version, self.disable_service_logs)[:1024] + return ( + "ExtractKeyPhrasesAction(model_version={}, disable_service_logs={})".format( + self.model_version, self.disable_service_logs + )[:1024] + ) def to_generated(self): return _v3_1_models.KeyPhrasesTask( @@ -1901,10 +2041,12 @@ def __init__(self, **kwargs): self.disable_service_logs = kwargs.get("disable_service_logs", False) def __repr__(self, **kwargs): - return "RecognizeLinkedEntitiesAction(model_version={}, string_index_type={}), " \ + return ( + "RecognizeLinkedEntitiesAction(model_version={}, string_index_type={}), " "disable_service_logs={}".format( self.model_version, self.string_index_type, self.disable_service_logs )[:1024] + ) def to_generated(self): return _v3_1_models.EntityLinkingTask( diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_policies.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_policies.py index d03249808f28..a442833c2d7d 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_policies.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_policies.py @@ -17,14 +17,18 @@ def __init__(self, **kwargs): super(TextAnalyticsResponseHookPolicy, self).__init__() def on_request(self, request): - self._response_callback = request.context.options.pop("raw_response_hook", self._response_callback) + self._response_callback = request.context.options.pop( + "raw_response_hook", self._response_callback + ) def on_response(self, request, response): if self._is_lro is None: # determine LRO based off of initial response. 
If 202, we say it's an LRO self._is_lro = response.http_response.status_code == 202 if self._response_callback: - data = ContentDecodePolicy.deserialize_from_http_generics(response.http_response) + data = ContentDecodePolicy.deserialize_from_http_generics( + response.http_response + ) if self._is_lro and (not data or data.get("status") not in _FINISHED): return if data: @@ -32,7 +36,9 @@ def on_response(self, request, response): model_version = data.get("modelVersion", None) if statistics or model_version: - batch_statistics = TextDocumentBatchStatistics._from_generated(statistics) # pylint: disable=protected-access + batch_statistics = TextDocumentBatchStatistics._from_generated( # pylint: disable=protected-access + statistics + ) response.statistics = batch_statistics response.model_version = model_version response.raw_response = data diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_request_handlers.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_request_handlers.py index 6995966d6fc1..83f1c761085a 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_request_handlers.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_request_handlers.py @@ -17,6 +17,7 @@ _AnalyzeActionsType, ) + def _validate_input(documents, hint, whole_input_hint): """Validate that batch input has either all string docs or dict/DetectLanguageInput/TextDocumentInput, not a mix of both. @@ -36,9 +37,13 @@ def _validate_input(documents, hint, whole_input_hint): raise TypeError("Input documents cannot be a dict") if not all(isinstance(x, six.string_types) for x in documents): - if not all(isinstance(x, (dict, TextDocumentInput, DetectLanguageInput)) for x in documents): - raise TypeError("Mixing string and dictionary/object document input unsupported.") - + if not all( + isinstance(x, (dict, TextDocumentInput, DetectLanguageInput)) + for x in documents + ): + raise TypeError( + "Mixing string and dictionary/object document input unsupported." 
+ ) request_batch = [] for idx, doc in enumerate(documents): @@ -50,25 +55,38 @@ def _validate_input(documents, hint, whole_input_hint): if isinstance(doc, dict): item_hint = doc.get(hint, None) if item_hint is None: - doc = {"id": doc.get("id", None), hint: whole_input_hint, "text": doc.get("text", None)} + doc = { + "id": doc.get("id", None), + hint: whole_input_hint, + "text": doc.get("text", None), + } elif item_hint.lower() == "none": - doc = {"id": doc.get("id", None), hint: "", "text": doc.get("text", None)} + doc = { + "id": doc.get("id", None), + hint: "", + "text": doc.get("text", None), + } request_batch.append(doc) if isinstance(doc, TextDocumentInput): item_hint = doc.language if item_hint is None: - doc = TextDocumentInput(id=doc.id, language=whole_input_hint, text=doc.text) + doc = TextDocumentInput( + id=doc.id, language=whole_input_hint, text=doc.text + ) request_batch.append(doc) if isinstance(doc, DetectLanguageInput): item_hint = doc.country_hint if item_hint is None: - doc = DetectLanguageInput(id=doc.id, country_hint=whole_input_hint, text=doc.text) + doc = DetectLanguageInput( + id=doc.id, country_hint=whole_input_hint, text=doc.text + ) elif item_hint.lower() == "none": doc = DetectLanguageInput(id=doc.id, country_hint="", text=doc.text) request_batch.append(doc) return request_batch + def _determine_action_type(action): if isinstance(action, RecognizeEntitiesAction): return _AnalyzeActionsType.RECOGNIZE_ENTITIES @@ -80,7 +98,10 @@ def _determine_action_type(action): return _AnalyzeActionsType.ANALYZE_SENTIMENT return _AnalyzeActionsType.EXTRACT_KEY_PHRASES -def _check_string_index_type_arg(string_index_type_arg, api_version, string_index_type_default="UnicodeCodePoint"): + +def _check_string_index_type_arg( + string_index_type_arg, api_version, string_index_type_default="UnicodeCodePoint" +): string_index_type = None if api_version == "v3.0": diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_response_handlers.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_response_handlers.py index 0e4e3dfe1f17..8beaceb1714d 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_response_handlers.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_response_handlers.py @@ -11,7 +11,7 @@ from azure.core.exceptions import ( HttpResponseError, ClientAuthenticationError, - ODataV4Format + ODataV4Format, ) from azure.core.paging import ItemPaged from ._models import ( @@ -35,24 +35,26 @@ _AnalyzeActionsType, ) -class CSODataV4Format(ODataV4Format): +class CSODataV4Format(ODataV4Format): def __init__(self, odata_error): try: if odata_error["error"]["innererror"]: - super(CSODataV4Format, self).__init__(odata_error["error"]["innererror"]) + super(CSODataV4Format, self).__init__( + odata_error["error"]["innererror"] + ) except KeyError: super(CSODataV4Format, self).__init__(odata_error) def process_http_response_error(error): - """Raise detailed error message. - """ + """Raise detailed error message.""" raise_error = HttpResponseError if error.status_code == 401: raise_error = ClientAuthenticationError raise raise_error(response=error.response, error_format=CSODataV4Format) + def order_results(response, combined): """Order results in the order the user passed them in. 
@@ -77,13 +79,17 @@ def order_lro_results(doc_id_order, combined): """ mapping = [(item.id, item) for item in combined] - ordered_response = [i[1] for i in sorted(mapping, key=lambda m: doc_id_order.index(m[0]))] + ordered_response = [ + i[1] for i in sorted(mapping, key=lambda m: doc_id_order.index(m[0])) + ] return ordered_response def prepare_result(func): def choose_wrapper(*args, **kwargs): - def wrapper(response, obj, response_headers, ordering_function): # pylint: disable=unused-argument + def wrapper( + response, obj, response_headers, ordering_function + ): # pylint: disable=unused-argument if obj.errors: combined = obj.documents + obj.errors results = ordering_function(response, combined) @@ -93,7 +99,9 @@ def wrapper(response, obj, response_headers, ordering_function): # pylint: disa for idx, item in enumerate(results): if hasattr(item, "error"): - results[idx] = DocumentError(id=item.id, error=TextAnalyticsError._from_generated(item.error)) # pylint: disable=protected-access + results[idx] = DocumentError( + id=item.id, error=TextAnalyticsError._from_generated(item.error) # pylint: disable=protected-access + ) else: results[idx] = func(item, results) return results @@ -111,72 +119,133 @@ def wrapper(response, obj, response_headers, ordering_function): # pylint: disa def language_result(language, results): # pylint: disable=unused-argument return DetectLanguageResult( id=language.id, - primary_language=DetectedLanguage._from_generated(language.detected_language), # pylint: disable=protected-access - warnings=[TextAnalyticsWarning._from_generated(w) for w in language.warnings], # pylint: disable=protected-access - statistics=TextDocumentStatistics._from_generated(language.statistics), # pylint: disable=protected-access + primary_language=DetectedLanguage._from_generated( # pylint: disable=protected-access + language.detected_language + ), + warnings=[ + TextAnalyticsWarning._from_generated(w) for w in language.warnings # pylint: disable=protected-access + ], + statistics=TextDocumentStatistics._from_generated( # pylint: disable=protected-access + language.statistics + ), ) @prepare_result -def entities_result(entity, results, *args, **kwargs): # pylint: disable=unused-argument +def entities_result( + entity, results, *args, **kwargs +): # pylint: disable=unused-argument return RecognizeEntitiesResult( id=entity.id, - entities=[CategorizedEntity._from_generated(e) for e in entity.entities], # pylint: disable=protected-access - warnings=[TextAnalyticsWarning._from_generated(w) for w in entity.warnings], # pylint: disable=protected-access - statistics=TextDocumentStatistics._from_generated(entity.statistics), # pylint: disable=protected-access + entities=[ + CategorizedEntity._from_generated(e) for e in entity.entities # pylint: disable=protected-access + ], + warnings=[ + TextAnalyticsWarning._from_generated(w) for w in entity.warnings # pylint: disable=protected-access + ], + statistics=TextDocumentStatistics._from_generated( # pylint: disable=protected-access + entity.statistics + ), ) @prepare_result -def linked_entities_result(entity, results, *args, **kwargs): # pylint: disable=unused-argument +def linked_entities_result( + entity, results, *args, **kwargs +): # pylint: disable=unused-argument return RecognizeLinkedEntitiesResult( id=entity.id, - entities=[LinkedEntity._from_generated(e) for e in entity.entities], # pylint: disable=protected-access - warnings=[TextAnalyticsWarning._from_generated(w) for w in entity.warnings], # pylint: disable=protected-access - 
statistics=TextDocumentStatistics._from_generated(entity.statistics), # pylint: disable=protected-access + entities=[ + LinkedEntity._from_generated(e) for e in entity.entities # pylint: disable=protected-access + ], + warnings=[ + TextAnalyticsWarning._from_generated(w) for w in entity.warnings # pylint: disable=protected-access + ], + statistics=TextDocumentStatistics._from_generated( # pylint: disable=protected-access + entity.statistics + ), ) @prepare_result -def key_phrases_result(phrases, results, *args, **kwargs): # pylint: disable=unused-argument +def key_phrases_result( + phrases, results, *args, **kwargs +): # pylint: disable=unused-argument return ExtractKeyPhrasesResult( id=phrases.id, key_phrases=phrases.key_phrases, - warnings=[TextAnalyticsWarning._from_generated(w) for w in phrases.warnings], # pylint: disable=protected-access - statistics=TextDocumentStatistics._from_generated(phrases.statistics), # pylint: disable=protected-access + warnings=[ + TextAnalyticsWarning._from_generated(w) for w in phrases.warnings # pylint: disable=protected-access + ], + statistics=TextDocumentStatistics._from_generated( # pylint: disable=protected-access + phrases.statistics + ), ) @prepare_result -def sentiment_result(sentiment, results, *args, **kwargs): # pylint: disable=unused-argument +def sentiment_result( + sentiment, results, *args, **kwargs +): # pylint: disable=unused-argument return AnalyzeSentimentResult( id=sentiment.id, sentiment=sentiment.sentiment, - warnings=[TextAnalyticsWarning._from_generated(w) for w in sentiment.warnings], # pylint: disable=protected-access - statistics=TextDocumentStatistics._from_generated(sentiment.statistics), # pylint: disable=protected-access - confidence_scores=SentimentConfidenceScores._from_generated(sentiment.confidence_scores), # pylint: disable=protected-access - sentences=[SentenceSentiment._from_generated(s, results, sentiment) for s in sentiment.sentences], # pylint: disable=protected-access + warnings=[ + TextAnalyticsWarning._from_generated(w) for w in sentiment.warnings # pylint: disable=protected-access + ], + statistics=TextDocumentStatistics._from_generated( # pylint: disable=protected-access + sentiment.statistics + ), + confidence_scores=SentimentConfidenceScores._from_generated( # pylint: disable=protected-access + sentiment.confidence_scores + ), + sentences=[ + SentenceSentiment._from_generated(s, results, sentiment) # pylint: disable=protected-access + for s in sentiment.sentences + ], ) + @prepare_result -def pii_entities_result(entity, results, *args, **kwargs): # pylint: disable=unused-argument +def pii_entities_result( + entity, results, *args, **kwargs +): # pylint: disable=unused-argument return RecognizePiiEntitiesResult( id=entity.id, - entities=[PiiEntity._from_generated(e) for e in entity.entities], # pylint: disable=protected-access - redacted_text=entity.redacted_text if hasattr(entity, "redacted_text") else None, - warnings=[TextAnalyticsWarning._from_generated(w) for w in entity.warnings], # pylint: disable=protected-access - statistics=TextDocumentStatistics._from_generated(entity.statistics), # pylint: disable=protected-access + entities=[ + PiiEntity._from_generated(e) for e in entity.entities # pylint: disable=protected-access + ], + redacted_text=entity.redacted_text + if hasattr(entity, "redacted_text") + else None, + warnings=[ + TextAnalyticsWarning._from_generated(w) for w in entity.warnings # pylint: disable=protected-access + ], + statistics=TextDocumentStatistics._from_generated( # pylint: 
disable=protected-access + entity.statistics + ), ) @prepare_result -def healthcare_result(health_result, results, *args, **kwargs): # pylint: disable=unused-argument - return AnalyzeHealthcareEntitiesResult._from_generated(health_result) # pylint: disable=protected-access +def healthcare_result( + health_result, results, *args, **kwargs +): # pylint: disable=unused-argument + return AnalyzeHealthcareEntitiesResult._from_generated( # pylint: disable=protected-access + health_result + ) -def healthcare_extract_page_data(doc_id_order, obj, response_headers, health_job_state): # pylint: disable=unused-argument - return (health_job_state.next_link, - healthcare_result(doc_id_order, health_job_state.results, response_headers, lro=True)) +def healthcare_extract_page_data( + doc_id_order, obj, response_headers, health_job_state +): # pylint: disable=unused-argument + return ( + health_job_state.next_link, + healthcare_result( + doc_id_order, health_job_state.results, response_headers, lro=True + ), + ) + def _get_deserialization_callback_from_task_type(task_type): if task_type == _AnalyzeActionsType.RECOGNIZE_ENTITIES: @@ -189,6 +258,7 @@ def _get_deserialization_callback_from_task_type(task_type): return sentiment_result return key_phrases_result + def _get_property_name_from_task_type(task_type): if task_type == _AnalyzeActionsType.RECOGNIZE_ENTITIES: return "entity_recognition_tasks" @@ -200,17 +270,31 @@ def _get_property_name_from_task_type(task_type): return "sentiment_analysis_tasks" return "key_phrase_extraction_tasks" -def _get_good_result(current_task_type, index_of_task_result, doc_id_order, response_headers, returned_tasks_object): - deserialization_callback = _get_deserialization_callback_from_task_type(current_task_type) + +def _get_good_result( + current_task_type, + index_of_task_result, + doc_id_order, + response_headers, + returned_tasks_object, +): + deserialization_callback = _get_deserialization_callback_from_task_type( + current_task_type + ) property_name = _get_property_name_from_task_type(current_task_type) - response_task_to_deserialize = getattr(returned_tasks_object, property_name)[index_of_task_result] + response_task_to_deserialize = getattr(returned_tasks_object, property_name)[ + index_of_task_result + ] return deserialization_callback( doc_id_order, response_task_to_deserialize.results, response_headers, lro=True ) + def get_iter_items(doc_id_order, task_order, response_headers, analyze_job_state): - iter_items = defaultdict(list) # map doc id to action results - task_type_to_index = defaultdict(int) # need to keep track of how many of each type of tasks we've seen + iter_items = defaultdict(list) # map doc id to action results + task_type_to_index = defaultdict( + int + ) # need to keep track of how many of each type of tasks we've seen returned_tasks_object = analyze_job_state.tasks for current_task_type in task_order: index_of_task_result = task_type_to_index[current_task_type] @@ -225,19 +309,22 @@ def get_iter_items(doc_id_order, task_order, response_headers, analyze_job_state iter_items[result.id].append(result) task_type_to_index[current_task_type] += 1 - return [ - iter_items[doc_id] - for doc_id in doc_id_order - if doc_id in iter_items - ] + return [iter_items[doc_id] for doc_id in doc_id_order if doc_id in iter_items] + -def analyze_extract_page_data(doc_id_order, task_order, response_headers, analyze_job_state): +def analyze_extract_page_data( + doc_id_order, task_order, response_headers, analyze_job_state +): # return next link, list of - iter_items = 
get_iter_items(doc_id_order, task_order, response_headers, analyze_job_state) + iter_items = get_iter_items( + doc_id_order, task_order, response_headers, analyze_job_state + ) return analyze_job_state.next_link, iter_items -def lro_get_next_page(lro_status_callback, first_page, continuation_token, show_stats=False): +def lro_get_next_page( + lro_status_callback, first_page, continuation_token, show_stats=False +): if continuation_token is None: return first_page @@ -257,14 +344,33 @@ def lro_get_next_page(lro_status_callback, first_page, continuation_token, show_ return lro_status_callback(job_id, **query_params) -def healthcare_paged_result(doc_id_order, health_status_callback, _, obj, response_headers, show_stats=False): # pylint: disable=unused-argument +def healthcare_paged_result( + doc_id_order, health_status_callback, _, obj, response_headers, show_stats=False +): # pylint: disable=unused-argument return ItemPaged( - functools.partial(lro_get_next_page, health_status_callback, obj, show_stats=show_stats), - functools.partial(healthcare_extract_page_data, doc_id_order, obj, response_headers), + functools.partial( + lro_get_next_page, health_status_callback, obj, show_stats=show_stats + ), + functools.partial( + healthcare_extract_page_data, doc_id_order, obj, response_headers + ), ) -def analyze_paged_result(doc_id_order, task_order, analyze_status_callback, _, obj, response_headers, show_stats=False): # pylint: disable=unused-argument + +def analyze_paged_result( + doc_id_order, + task_order, + analyze_status_callback, + _, + obj, + response_headers, + show_stats=False, +): # pylint: disable=unused-argument return ItemPaged( - functools.partial(lro_get_next_page, analyze_status_callback, obj, show_stats=show_stats), - functools.partial(analyze_extract_page_data, doc_id_order, task_order, response_headers) + functools.partial( + lro_get_next_page, analyze_status_callback, obj, show_stats=show_stats + ), + functools.partial( + analyze_extract_page_data, doc_id_order, task_order, response_headers + ), ) diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_response_handlers_async.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_response_handlers_async.py index 682baeba6b4c..3267ff92ca51 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_response_handlers_async.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_response_handlers_async.py @@ -12,12 +12,20 @@ from ._response_handlers import healthcare_result, get_iter_items -async def healthcare_extract_page_data_async(doc_id_order, obj, response_headers, health_job_state): # pylint: disable=unused-argument - return (health_job_state.next_link, - healthcare_result(doc_id_order, health_job_state.results, response_headers, lro=True)) +async def healthcare_extract_page_data_async( + doc_id_order, obj, response_headers, health_job_state +): # pylint: disable=unused-argument + return ( + health_job_state.next_link, + healthcare_result( + doc_id_order, health_job_state.results, response_headers, lro=True + ), + ) -async def lro_get_next_page_async(lro_status_callback, first_page, continuation_token, show_stats=False): +async def lro_get_next_page_async( + lro_status_callback, first_page, continuation_token, show_stats=False +): if continuation_token is None: return first_page @@ -37,20 +45,45 @@ async def lro_get_next_page_async(lro_status_callback, first_page, continuation_ return await lro_status_callback(job_id, **query_params) -def 
healthcare_paged_result(doc_id_order, health_status_callback, response, obj, response_headers, show_stats=False): # pylint: disable=unused-argument +def healthcare_paged_result( + doc_id_order, + health_status_callback, + response, + obj, + response_headers, + show_stats=False, +): # pylint: disable=unused-argument return AsyncItemPaged( - functools.partial(lro_get_next_page_async, health_status_callback, obj, show_stats=show_stats), - functools.partial(healthcare_extract_page_data_async, doc_id_order, obj, response_headers), + functools.partial( + lro_get_next_page_async, health_status_callback, obj, show_stats=show_stats + ), + functools.partial( + healthcare_extract_page_data_async, doc_id_order, obj, response_headers + ), ) -async def analyze_extract_page_data_async(doc_id_order, task_order, response_headers, analyze_job_state): - iter_items = get_iter_items(doc_id_order, task_order, response_headers, analyze_job_state) + +async def analyze_extract_page_data_async( + doc_id_order, task_order, response_headers, analyze_job_state +): + iter_items = get_iter_items( + doc_id_order, task_order, response_headers, analyze_job_state + ) return analyze_job_state.next_link, AsyncList(iter_items) + def analyze_paged_result( - doc_id_order, task_order, analyze_status_callback, response, obj, response_headers, show_stats=False # pylint: disable=unused-argument + doc_id_order, + task_order, + analyze_status_callback, + response, # pylint: disable=unused-argument + obj, + response_headers, + show_stats=False, # pylint: disable=unused-argument ): return AsyncItemPaged( functools.partial(lro_get_next_page_async, analyze_status_callback, obj), - functools.partial(analyze_extract_page_data_async, doc_id_order, task_order, response_headers), + functools.partial( + analyze_extract_page_data_async, doc_id_order, task_order, response_headers + ), ) diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_text_analytics_client.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_text_analytics_client.py index 7255d32eedf4..52900fac63d2 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_text_analytics_client.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_text_analytics_client.py @@ -19,7 +19,7 @@ from ._request_handlers import ( _validate_input, _determine_action_type, - _check_string_index_type_arg + _check_string_index_type_arg, ) from ._response_handlers import ( process_http_response_error, @@ -109,15 +109,14 @@ class TextAnalyticsClient(TextAnalyticsClientBase): def __init__(self, endpoint, credential, **kwargs): # type: (str, Union[AzureKeyCredential, TokenCredential], Any) -> None super(TextAnalyticsClient, self).__init__( - endpoint=endpoint, - credential=credential, - **kwargs + endpoint=endpoint, credential=credential, **kwargs ) self._api_version = kwargs.get("api_version") self._default_language = kwargs.pop("default_language", "en") self._default_country_hint = kwargs.pop("default_country_hint", "US") - self._string_index_type_default = None if kwargs.get("api_version") == "v3.0" else "UnicodeCodePoint" - + self._string_index_type_default = ( + None if kwargs.get("api_version") == "v3.0" else "UnicodeCodePoint" + ) @distributed_trace def detect_language( # type: ignore @@ -178,13 +177,17 @@ def detect_language( # type: ignore :caption: Detecting language in a batch of documents. 
""" country_hint_arg = kwargs.pop("country_hint", None) - country_hint = country_hint_arg if country_hint_arg is not None else self._default_country_hint + country_hint = ( + country_hint_arg + if country_hint_arg is not None + else self._default_country_hint + ) docs = _validate_input(documents, "country_hint", country_hint) model_version = kwargs.pop("model_version", None) show_stats = kwargs.pop("show_stats", False) disable_service_logs = kwargs.pop("disable_service_logs", None) if disable_service_logs is not None: - kwargs['logging_opt_out'] = disable_service_logs + kwargs["logging_opt_out"] = disable_service_logs try: return self._client.languages( documents=docs, @@ -267,13 +270,13 @@ def recognize_entities( # type: ignore string_index_type = _check_string_index_type_arg( kwargs.pop("string_index_type", None), self._api_version, - string_index_type_default=self._string_index_type_default + string_index_type_default=self._string_index_type_default, ) if string_index_type: kwargs.update({"string_index_type": string_index_type}) disable_service_logs = kwargs.pop("disable_service_logs", None) if disable_service_logs is not None: - kwargs['logging_opt_out'] = disable_service_logs + kwargs["logging_opt_out"] = disable_service_logs try: return self._client.entities_recognition_general( @@ -369,13 +372,13 @@ def recognize_pii_entities( # type: ignore string_index_type = _check_string_index_type_arg( kwargs.pop("string_index_type", None), self._api_version, - string_index_type_default=self._string_index_type_default + string_index_type_default=self._string_index_type_default, ) if string_index_type: kwargs.update({"string_index_type": string_index_type}) disable_service_logs = kwargs.pop("disable_service_logs", None) if disable_service_logs is not None: - kwargs['logging_opt_out'] = disable_service_logs + kwargs["logging_opt_out"] = disable_service_logs try: return self._client.entities_recognition_pii( @@ -388,7 +391,10 @@ def recognize_pii_entities( # type: ignore **kwargs ) except ValueError as error: - if "API version v3.0 does not have operation 'entities_recognition_pii'" in str(error): + if ( + "API version v3.0 does not have operation 'entities_recognition_pii'" + in str(error) + ): raise ValueError( "'recognize_pii_entities' endpoint is only available for API version V3_1 and up" ) @@ -467,12 +473,12 @@ def recognize_linked_entities( # type: ignore show_stats = kwargs.pop("show_stats", False) disable_service_logs = kwargs.pop("disable_service_logs", None) if disable_service_logs is not None: - kwargs['logging_opt_out'] = disable_service_logs + kwargs["logging_opt_out"] = disable_service_logs string_index_type = _check_string_index_type_arg( kwargs.pop("string_index_type", None), self._api_version, - string_index_type_default=self._string_index_type_default + string_index_type_default=self._string_index_type_default, ) if string_index_type: kwargs.update({"string_index_type": string_index_type}) @@ -488,17 +494,19 @@ def recognize_linked_entities( # type: ignore except HttpResponseError as error: process_http_response_error(error) - def _healthcare_result_callback(self, doc_id_order, raw_response, _, headers, show_stats=False): - healthcare_result = self._client.models(api_version="v3.1").HealthcareJobState.deserialize( - raw_response - ) + def _healthcare_result_callback( + self, doc_id_order, raw_response, _, headers, show_stats=False + ): + healthcare_result = self._client.models( + api_version="v3.1" + ).HealthcareJobState.deserialize(raw_response) return healthcare_paged_result( 
doc_id_order, self._client.health_status, raw_response, healthcare_result, headers, - show_stats=show_stats + show_stats=show_stats, ) @distributed_trace @@ -571,17 +579,22 @@ def begin_analyze_healthcare_entities( # type: ignore show_stats = kwargs.pop("show_stats", False) polling_interval = kwargs.pop("polling_interval", 5) continuation_token = kwargs.pop("continuation_token", None) - string_index_type = kwargs.pop("string_index_type", self._string_index_type_default) + string_index_type = kwargs.pop( + "string_index_type", self._string_index_type_default + ) doc_id_order = [doc.get("id") for doc in docs] my_cls = kwargs.pop( - "cls", partial(self._healthcare_result_callback, doc_id_order, show_stats=show_stats) + "cls", + partial( + self._healthcare_result_callback, doc_id_order, show_stats=show_stats + ), ) disable_service_logs = kwargs.pop("disable_service_logs", None) polling_kwargs = kwargs operation_kwargs = copy.copy(kwargs) if disable_service_logs is not None: - operation_kwargs['logging_opt_out'] = disable_service_logs + operation_kwargs["logging_opt_out"] = disable_service_logs try: return self._client.begin_health( @@ -595,7 +608,8 @@ def begin_analyze_healthcare_entities( # type: ignore lro_algorithms=[ TextAnalyticsOperationResourcePolling(show_stats=show_stats) ], - **polling_kwargs), + **polling_kwargs + ), continuation_token=continuation_token, **operation_kwargs ) @@ -611,7 +625,6 @@ def begin_analyze_healthcare_entities( # type: ignore except HttpResponseError as error: process_http_response_error(error) - @distributed_trace def extract_key_phrases( # type: ignore self, @@ -679,7 +692,7 @@ def extract_key_phrases( # type: ignore show_stats = kwargs.pop("show_stats", False) disable_service_logs = kwargs.pop("disable_service_logs", None) if disable_service_logs is not None: - kwargs['logging_opt_out'] = disable_service_logs + kwargs["logging_opt_out"] = disable_service_logs try: return self._client.key_phrases( @@ -772,18 +785,21 @@ def analyze_sentiment( # type: ignore show_opinion_mining = kwargs.pop("show_opinion_mining", None) disable_service_logs = kwargs.pop("disable_service_logs", None) if disable_service_logs is not None: - kwargs['logging_opt_out'] = disable_service_logs + kwargs["logging_opt_out"] = disable_service_logs string_index_type = _check_string_index_type_arg( kwargs.pop("string_index_type", None), self._api_version, - string_index_type_default=self._string_index_type_default + string_index_type_default=self._string_index_type_default, ) if string_index_type: kwargs.update({"string_index_type": string_index_type}) if show_opinion_mining is not None: - if self._api_version == TextAnalyticsApiVersion.V3_0 and show_opinion_mining: + if ( + self._api_version == TextAnalyticsApiVersion.V3_0 + and show_opinion_mining + ): raise ValueError( "'show_opinion_mining' is only available for API version v3.1 and up" ) @@ -800,10 +816,12 @@ def analyze_sentiment( # type: ignore except HttpResponseError as error: process_http_response_error(error) - def _analyze_result_callback(self, doc_id_order, task_order, raw_response, _, headers, show_stats=False): - analyze_result = self._client.models(api_version="v3.1").AnalyzeJobState.deserialize( - raw_response - ) + def _analyze_result_callback( + self, doc_id_order, task_order, raw_response, _, headers, show_stats=False + ): + analyze_result = self._client.models( + api_version="v3.1" + ).AnalyzeJobState.deserialize(raw_response) return analyze_paged_result( doc_id_order, task_order, @@ -811,7 +829,7 @@ def 
_analyze_result_callback(self, doc_id_order, task_order, raw_response, _, he raw_response, analyze_result, headers, - show_stats=show_stats + show_stats=show_stats, ) @distributed_trace @@ -897,47 +915,74 @@ def begin_analyze_actions( # type: ignore raise ValueError("Multiple of the same action is not currently supported.") try: - analyze_tasks = self._client.models(api_version='v3.1').JobManifestTasks( + analyze_tasks = self._client.models(api_version="v3.1").JobManifestTasks( entity_recognition_tasks=[ - t.to_generated() for t in - [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.RECOGNIZE_ENTITIES] + t.to_generated() + for t in [ + a + for a in actions + if _determine_action_type(a) + == _AnalyzeActionsType.RECOGNIZE_ENTITIES + ] ], entity_recognition_pii_tasks=[ - t.to_generated() for t in - [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES] + t.to_generated() + for t in [ + a + for a in actions + if _determine_action_type(a) + == _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES + ] ], key_phrase_extraction_tasks=[ - t.to_generated() for t in - [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.EXTRACT_KEY_PHRASES] + t.to_generated() + for t in [ + a + for a in actions + if _determine_action_type(a) + == _AnalyzeActionsType.EXTRACT_KEY_PHRASES + ] ], entity_linking_tasks=[ - t.to_generated() for t in - [ - a for a in actions - if _determine_action_type(a) == _AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES + t.to_generated() + for t in [ + a + for a in actions + if _determine_action_type(a) + == _AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES ] ], sentiment_analysis_tasks=[ - t.to_generated() for t in - [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.ANALYZE_SENTIMENT] - ] + t.to_generated() + for t in [ + a + for a in actions + if _determine_action_type(a) + == _AnalyzeActionsType.ANALYZE_SENTIMENT + ] + ], ) - analyze_body = self._client.models(api_version='v3.1').AnalyzeBatchInput( - display_name=display_name, - tasks=analyze_tasks, - analysis_input=docs + analyze_body = self._client.models(api_version="v3.1").AnalyzeBatchInput( + display_name=display_name, tasks=analyze_tasks, analysis_input=docs ) return self._client.begin_analyze( body=analyze_body, - cls=kwargs.pop("cls", partial( - self._analyze_result_callback, doc_id_order, task_order, show_stats=show_stats - )), + cls=kwargs.pop( + "cls", + partial( + self._analyze_result_callback, + doc_id_order, + task_order, + show_stats=show_stats, + ), + ), polling=AnalyzeActionsLROPollingMethod( timeout=polling_interval, lro_algorithms=[ TextAnalyticsOperationResourcePolling(show_stats=show_stats) ], - **kwargs), + **kwargs + ), continuation_token=continuation_token, **kwargs ) diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/__init__.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/__init__.py index a0af0804b106..766fbc0bb994 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/__init__.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/__init__.py @@ -11,7 +11,7 @@ ) __all__ = [ - 'TextAnalyticsClient', - 'AsyncAnalyzeHealthcareEntitiesLROPoller', - 'AsyncAnalyzeActionsLROPoller', + "TextAnalyticsClient", + "AsyncAnalyzeHealthcareEntitiesLROPoller", + "AsyncAnalyzeActionsLROPoller", ] diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_base_client_async.py 
b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_base_client_async.py index e68bce1eec95..dac8da7b7ad7 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_base_client_async.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_base_client_async.py @@ -21,8 +21,10 @@ def _authentication_policy(credential): name="Ocp-Apim-Subscription-Key", credential=credential ) elif credential is not None and not hasattr(credential, "get_token"): - raise TypeError("Unsupported credential: {}. Use an instance of AzureKeyCredential " - "or a token credential from azure.identity".format(type(credential))) + raise TypeError( + "Unsupported credential: {}. Use an instance of AzureKeyCredential " + "or a token credential from azure.identity".format(type(credential)) + ) return authentication_policy diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_lro_async.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_lro_async.py index c09bbe02b9f0..275ad7feca2d 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_lro_async.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_lro_async.py @@ -18,7 +18,6 @@ class TextAnalyticsAsyncLROPollingMethod(AsyncLROBasePolling): - def finished(self): """Is this polling finished? :rtype: bool @@ -78,15 +77,20 @@ async def _poll(self): # pylint:disable=invalid-overridden-method self._pipeline_response.http_response ) -class AsyncAnalyzeHealthcareEntitiesLROPollingMethod(TextAnalyticsAsyncLROPollingMethod): +class AsyncAnalyzeHealthcareEntitiesLROPollingMethod( + TextAnalyticsAsyncLROPollingMethod +): def __init__(self, *args, **kwargs): self._text_analytics_client = kwargs.pop("text_analytics_client") - super(AsyncAnalyzeHealthcareEntitiesLROPollingMethod, self).__init__(*args, **kwargs) + super(AsyncAnalyzeHealthcareEntitiesLROPollingMethod, self).__init__( + *args, **kwargs + ) @property def _current_body(self): from .._generated.v3_1.models import JobMetadata + return JobMetadata.deserialize(self._pipeline_response) @property @@ -115,10 +119,8 @@ def id(self): class AsyncAnalyzeHealthcareEntitiesLROPoller(AsyncLROPoller[PollingReturnType]): - def polling_method(self) -> AsyncAnalyzeHealthcareEntitiesLROPollingMethod: # type: ignore - """Return the polling method associated to this poller. - """ + """Return the polling method associated to this poller.""" return self._polling_method # type: ignore @property @@ -157,9 +159,8 @@ def id(self) -> str: """ return self.polling_method().id - async def cancel( # type: ignore - self, - **kwargs + async def cancel( # type: ignore + self, **kwargs ) -> "AsyncAnalyzeHealthcareEntitiesLROPoller[None]": """Cancel the operation currently being polled. 
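The custom poller machinery above mainly exists to surface job metadata (id, created_on, expires_on) and a cancel coroutine on top of the plain AsyncLROPoller. A minimal caller-side sketch of that surface, assuming a placeholder endpoint, key, and document text that are not part of this patch:

import asyncio

from azure.core.credentials import AzureKeyCredential
from azure.ai.textanalytics.aio import TextAnalyticsClient

async def main():
    client = TextAnalyticsClient(
        endpoint="https://<resource-name>.cognitiveservices.azure.com/",  # placeholder
        credential=AzureKeyCredential("<api-key>"),  # placeholder
    )
    async with client:
        poller = await client.begin_analyze_healthcare_entities(
            ["Patient was prescribed 100mg of ibuprofen."]
        )
        # Metadata is read off the deserialized job state, so it may be
        # None until the service has reported status at least once.
        print(poller.id, poller.created_on, poller.expires_on)
        # cancel() itself returns a poller for the cancellation operation;
        # the service may reject it (HttpResponseError) if the job has
        # already completed.
        cancel_poller = await poller.cancel()
        await cancel_poller.wait()

asyncio.run(main())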
@@ -182,21 +183,24 @@ async def cancel( # type: ignore await self.polling_method().update_status() try: - return await getattr(self._polling_method, "_text_analytics_client").begin_cancel_health_job( + return await getattr( + self._polling_method, "_text_analytics_client" + ).begin_cancel_health_job( self.id, - polling=TextAnalyticsAsyncLROPollingMethod(timeout=polling_interval) + polling=TextAnalyticsAsyncLROPollingMethod(timeout=polling_interval), ) except HttpResponseError as error: from .._response_handlers import process_http_response_error + process_http_response_error(error) class AsyncAnalyzeActionsLROPollingMethod(TextAnalyticsAsyncLROPollingMethod): - @property def _current_body(self): from .._generated.v3_1.models import AnalyzeJobMetadata + return AnalyzeJobMetadata.deserialize(self._pipeline_response) @property @@ -221,19 +225,19 @@ def expires_on(self): def actions_failed_count(self): if not self._current_body: return None - return self._current_body.additional_properties['tasks']['failed'] + return self._current_body.additional_properties["tasks"]["failed"] @property def actions_in_progress_count(self): if not self._current_body: return None - return self._current_body.additional_properties['tasks']['inProgress'] + return self._current_body.additional_properties["tasks"]["inProgress"] @property def actions_succeeded_count(self): if not self._current_body: return None - return self._current_body.additional_properties['tasks']["completed"] + return self._current_body.additional_properties["tasks"]["completed"] @property def last_modified_on(self): @@ -245,7 +249,7 @@ def last_modified_on(self): def total_actions_count(self): if not self._current_body: return None - return self._current_body.additional_properties['tasks']["total"] + return self._current_body.additional_properties["tasks"]["total"] @property def id(self): @@ -255,10 +259,8 @@ def id(self): class AsyncAnalyzeActionsLROPoller(AsyncLROPoller[PollingReturnType]): - def polling_method(self) -> AsyncAnalyzeActionsLROPollingMethod: # type: ignore - """Return the polling method associated to this poller. - """ + """Return the polling method associated to this poller.""" return self._polling_method # type: ignore @property diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_text_analytics_client_async.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_text_analytics_client_async.py index 6b5051031ac5..8e1347ed3af9 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_text_analytics_client_async.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_text_analytics_client_async.py @@ -3,14 +3,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
# ------------------------------------ +# pylint: disable=too-many-lines + import copy -from typing import ( - Union, - Any, - List, - Dict, - TYPE_CHECKING -) +from typing import Union, Any, List, Dict, TYPE_CHECKING from functools import partial from azure.core.async_paging import AsyncItemPaged from azure.core.tracing.decorator_async import distributed_trace_async @@ -18,7 +14,11 @@ from azure.core.credentials import AzureKeyCredential from ._base_client_async import AsyncTextAnalyticsClientBase from .._base_client import TextAnalyticsApiVersion -from .._request_handlers import _validate_input, _determine_action_type, _check_string_index_type_arg +from .._request_handlers import ( + _validate_input, + _determine_action_type, + _check_string_index_type_arg, +) from .._response_handlers import ( process_http_response_error, entities_result, @@ -106,23 +106,23 @@ def __init__( # type: ignore self, endpoint: str, credential: Union["AzureKeyCredential", "AsyncTokenCredential"], - **kwargs: Any + **kwargs: Any, ) -> None: super(TextAnalyticsClient, self).__init__( - endpoint=endpoint, - credential=credential, - **kwargs + endpoint=endpoint, credential=credential, **kwargs ) self._api_version = kwargs.get("api_version") self._default_language = kwargs.pop("default_language", "en") self._default_country_hint = kwargs.pop("default_country_hint", "US") - self._string_code_unit = None if kwargs.get("api_version") == "v3.0" else "UnicodeCodePoint" + self._string_code_unit = ( + None if kwargs.get("api_version") == "v3.0" else "UnicodeCodePoint" + ) @distributed_trace_async async def detect_language( # type: ignore self, documents: Union[List[str], List[DetectLanguageInput], List[Dict[str, str]]], - **kwargs: Any + **kwargs: Any, ) -> List[Union[DetectLanguageResult, DocumentError]]: """Detect language for a batch of documents. @@ -176,20 +176,24 @@ async def detect_language( # type: ignore :caption: Detecting language in a batch of documents. """ country_hint_arg = kwargs.pop("country_hint", None) - country_hint = country_hint_arg if country_hint_arg is not None else self._default_country_hint + country_hint = ( + country_hint_arg + if country_hint_arg is not None + else self._default_country_hint + ) docs = _validate_input(documents, "country_hint", country_hint) model_version = kwargs.pop("model_version", None) show_stats = kwargs.pop("show_stats", False) disable_service_logs = kwargs.pop("disable_service_logs", None) if disable_service_logs is not None: - kwargs['logging_opt_out'] = disable_service_logs + kwargs["logging_opt_out"] = disable_service_logs try: return await self._client.languages( documents=docs, model_version=model_version, show_stats=show_stats, cls=kwargs.pop("cls", language_result), - **kwargs + **kwargs, ) except HttpResponseError as error: process_http_response_error(error) @@ -198,7 +202,7 @@ async def detect_language( # type: ignore async def recognize_entities( # type: ignore self, documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]], - **kwargs: Any + **kwargs: Any, ) -> List[Union[RecognizeEntitiesResult, DocumentError]]: """Recognize entities for a batch of documents. 
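The detect_language hunk above resolves the effective country hint in two steps: a per-call country_hint keyword wins, otherwise the client-level default_country_hint (itself defaulting to "US") is used. A short sketch of both levels, with placeholder endpoint and key:

import asyncio

from azure.core.credentials import AzureKeyCredential
from azure.ai.textanalytics.aio import TextAnalyticsClient

async def main():
    client = TextAnalyticsClient(
        endpoint="https://<resource-name>.cognitiveservices.azure.com/",  # placeholder
        credential=AzureKeyCredential("<api-key>"),  # placeholder
        default_country_hint="PT",  # client-level default
    )
    async with client:
        # No per-call hint: the client default ("PT") is applied.
        results = await client.detect_language(["Este documento está em português."])
        # A per-call hint overrides the client default for this request only.
        more = await client.detect_language(
            ["This one is in English."], country_hint="US"
        )
        for doc in results + more:
            if not doc.is_error:
                print(doc.id, doc.primary_language.name)

asyncio.run(main())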
@@ -262,12 +266,12 @@ async def recognize_entities( # type: ignore show_stats = kwargs.pop("show_stats", False) disable_service_logs = kwargs.pop("disable_service_logs", None) if disable_service_logs is not None: - kwargs['logging_opt_out'] = disable_service_logs + kwargs["logging_opt_out"] = disable_service_logs string_index_type = _check_string_index_type_arg( kwargs.pop("string_index_type", None), self._api_version, - string_index_type_default=self._string_code_unit + string_index_type_default=self._string_code_unit, ) if string_index_type: kwargs.update({"string_index_type": string_index_type}) @@ -278,7 +282,7 @@ async def recognize_entities( # type: ignore model_version=model_version, show_stats=show_stats, cls=kwargs.pop("cls", entities_result), - **kwargs + **kwargs, ) except HttpResponseError as error: process_http_response_error(error) @@ -287,7 +291,7 @@ async def recognize_entities( # type: ignore async def recognize_pii_entities( # type: ignore self, documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]], - **kwargs: Any + **kwargs: Any, ) -> List[Union[RecognizePiiEntitiesResult, DocumentError]]: """Recognize entities containing personal information for a batch of documents. @@ -364,13 +368,13 @@ async def recognize_pii_entities( # type: ignore string_index_type = _check_string_index_type_arg( kwargs.pop("string_index_type", None), self._api_version, - string_index_type_default=self._string_code_unit + string_index_type_default=self._string_code_unit, ) if string_index_type: kwargs.update({"string_index_type": string_index_type}) disable_service_logs = kwargs.pop("disable_service_logs", None) if disable_service_logs is not None: - kwargs['logging_opt_out'] = disable_service_logs + kwargs["logging_opt_out"] = disable_service_logs try: return await self._client.entities_recognition_pii( @@ -380,10 +384,13 @@ async def recognize_pii_entities( # type: ignore domain=domain_filter, pii_categories=categories_filter, cls=kwargs.pop("cls", pii_entities_result), - **kwargs + **kwargs, ) except ValueError as error: - if "API version v3.0 does not have operation 'entities_recognition_pii'" in str(error): + if ( + "API version v3.0 does not have operation 'entities_recognition_pii'" + in str(error) + ): raise ValueError( "'recognize_pii_entities' endpoint is only available for API version V3_1 and up" ) @@ -395,7 +402,7 @@ async def recognize_pii_entities( # type: ignore async def recognize_linked_entities( # type: ignore self, documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]], - **kwargs: Any + **kwargs: Any, ) -> List[Union[RecognizeLinkedEntitiesResult, DocumentError]]: """Recognize linked entities from a well-known knowledge base for a batch of documents. 
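The recognize_pii_entities hunk above forwards the domain and pii_categories filters straight through and remaps the generated client's ValueError into a clearer message, because entities_recognition_pii only exists from service version v3.1 on. A hedged usage sketch, assuming a v3.1 resource and placeholder credentials:

import asyncio

from azure.core.credentials import AzureKeyCredential
from azure.ai.textanalytics.aio import TextAnalyticsClient

async def main():
    client = TextAnalyticsClient(
        endpoint="https://<resource-name>.cognitiveservices.azure.com/",  # placeholder
        credential=AzureKeyCredential("<api-key>"),  # placeholder
    )
    async with client:
        results = await client.recognize_pii_entities(
            ["My SSN is 859-98-0987 and my phone number is 555-555-5555."]
        )
        for doc in results:
            if not doc.is_error:
                # redacted_text is only populated by v3.1+ responses, which
                # is why the response handler guards it with hasattr().
                print(doc.redacted_text)
                for entity in doc.entities:
                    print(entity.text, entity.category, entity.confidence_score)

asyncio.run(main())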
@@ -460,12 +467,12 @@ async def recognize_linked_entities( # type: ignore show_stats = kwargs.pop("show_stats", False) disable_service_logs = kwargs.pop("disable_service_logs", None) if disable_service_logs is not None: - kwargs['logging_opt_out'] = disable_service_logs + kwargs["logging_opt_out"] = disable_service_logs string_index_type = _check_string_index_type_arg( kwargs.pop("string_index_type", None), self._api_version, - string_index_type_default=self._string_code_unit + string_index_type_default=self._string_code_unit, ) if string_index_type: kwargs.update({"string_index_type": string_index_type}) @@ -476,7 +483,7 @@ async def recognize_linked_entities( # type: ignore model_version=model_version, show_stats=show_stats, cls=kwargs.pop("cls", linked_entities_result), - **kwargs + **kwargs, ) except HttpResponseError as error: process_http_response_error(error) @@ -485,7 +492,7 @@ async def recognize_linked_entities( # type: ignore async def extract_key_phrases( # type: ignore self, documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]], - **kwargs: Any + **kwargs: Any, ) -> List[Union[ExtractKeyPhrasesResult, DocumentError]]: """Extract key phrases from a batch of documents. @@ -547,14 +554,14 @@ async def extract_key_phrases( # type: ignore show_stats = kwargs.pop("show_stats", False) disable_service_logs = kwargs.pop("disable_service_logs", None) if disable_service_logs is not None: - kwargs['logging_opt_out'] = disable_service_logs + kwargs["logging_opt_out"] = disable_service_logs try: return await self._client.key_phrases( documents=docs, model_version=model_version, show_stats=show_stats, cls=kwargs.pop("cls", key_phrases_result), - **kwargs + **kwargs, ) except HttpResponseError as error: process_http_response_error(error) @@ -563,7 +570,7 @@ async def extract_key_phrases( # type: ignore async def analyze_sentiment( # type: ignore self, documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]], - **kwargs: Any + **kwargs: Any, ) -> List[Union[AnalyzeSentimentResult, DocumentError]]: """Analyze sentiment for a batch of documents. Turn on opinion mining with `show_opinion_mining`. 
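The analyze_sentiment hunk below rejects show_opinion_mining on v3.0 and otherwise threads it through to the service; when the flag is on, mined opinions hang off each returned sentence. A usage sketch with placeholder credentials and sample text:

import asyncio

from azure.core.credentials import AzureKeyCredential
from azure.ai.textanalytics.aio import TextAnalyticsClient

async def main():
    client = TextAnalyticsClient(
        endpoint="https://<resource-name>.cognitiveservices.azure.com/",  # placeholder
        credential=AzureKeyCredential("<api-key>"),  # placeholder
    )
    async with client:
        results = await client.analyze_sentiment(
            ["The food was great, but the service was slow."],
            show_opinion_mining=True,
        )
        for doc in results:
            if not doc.is_error:
                print(doc.sentiment, doc.confidence_scores.positive)
                for sentence in doc.sentences:
                    for opinion in sentence.mined_opinions:
                        # Each opinion pairs a target ("service") with the
                        # assessments made about it ("slow").
                        print(opinion.target.text, opinion.target.sentiment)

asyncio.run(main())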
@@ -636,18 +643,21 @@ async def analyze_sentiment( # type: ignore show_opinion_mining = kwargs.pop("show_opinion_mining", None) disable_service_logs = kwargs.pop("disable_service_logs", None) if disable_service_logs is not None: - kwargs['logging_opt_out'] = disable_service_logs + kwargs["logging_opt_out"] = disable_service_logs string_index_type = _check_string_index_type_arg( kwargs.pop("string_index_type", None), self._api_version, - string_index_type_default=self._string_code_unit + string_index_type_default=self._string_code_unit, ) if string_index_type: kwargs.update({"string_index_type": string_index_type}) if show_opinion_mining is not None: - if self._api_version == TextAnalyticsApiVersion.V3_0 and show_opinion_mining: + if ( + self._api_version == TextAnalyticsApiVersion.V3_0 + and show_opinion_mining + ): raise ValueError( "'show_opinion_mining' is only available for API version v3.1 and up" ) @@ -659,22 +669,24 @@ async def analyze_sentiment( # type: ignore model_version=model_version, show_stats=show_stats, cls=kwargs.pop("cls", sentiment_result), - **kwargs + **kwargs, ) except HttpResponseError as error: process_http_response_error(error) - def _healthcare_result_callback(self, doc_id_order, raw_response, _, headers, show_stats=False): - healthcare_result = self._client.models(api_version="v3.1").HealthcareJobState.deserialize( - raw_response - ) + def _healthcare_result_callback( + self, doc_id_order, raw_response, _, headers, show_stats=False + ): + healthcare_result = self._client.models( + api_version="v3.1" + ).HealthcareJobState.deserialize(raw_response) return healthcare_paged_result( doc_id_order, self._client.health_status, raw_response, healthcare_result, headers, - show_stats=show_stats + show_stats=show_stats, ) @distributed_trace_async @@ -682,7 +694,9 @@ async def begin_analyze_healthcare_entities( # type: ignore self, documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]], **kwargs: Any, - ) -> AsyncAnalyzeHealthcareEntitiesLROPoller[AsyncItemPaged[Union[AnalyzeHealthcareEntitiesResult, DocumentError]]]: + ) -> AsyncAnalyzeHealthcareEntitiesLROPoller[ + AsyncItemPaged[Union[AnalyzeHealthcareEntitiesResult, DocumentError]] + ]: """Analyze healthcare entities and identify relationships between these entities in a batch of documents. 
Entities are associated with references that can be found in existing knowledge bases, @@ -750,12 +764,15 @@ async def begin_analyze_healthcare_entities( # type: ignore disable_service_logs = kwargs.pop("disable_service_logs", None) doc_id_order = [doc.get("id") for doc in docs] my_cls = kwargs.pop( - "cls", partial(self._healthcare_result_callback, doc_id_order, show_stats=show_stats) + "cls", + partial( + self._healthcare_result_callback, doc_id_order, show_stats=show_stats + ), ) polling_kwargs = kwargs operation_kwargs = copy.copy(kwargs) if disable_service_logs is not None: - operation_kwargs['logging_opt_out'] = disable_service_logs + operation_kwargs["logging_opt_out"] = disable_service_logs try: return await self._client.begin_health( @@ -769,9 +786,10 @@ async def begin_analyze_healthcare_entities( # type: ignore lro_algorithms=[ TextAnalyticsOperationResourcePolling(show_stats=show_stats) ], - **polling_kwargs), + **polling_kwargs, + ), continuation_token=continuation_token, - **operation_kwargs + **operation_kwargs, ) except ValueError as error: @@ -784,10 +802,12 @@ async def begin_analyze_healthcare_entities( # type: ignore except HttpResponseError as error: process_http_response_error(error) - def _analyze_result_callback(self, doc_id_order, task_order, raw_response, _, headers, show_stats=False): - analyze_result = self._client.models(api_version="v3.1").AnalyzeJobState.deserialize( - raw_response - ) + def _analyze_result_callback( + self, doc_id_order, task_order, raw_response, _, headers, show_stats=False + ): + analyze_result = self._client.models( + api_version="v3.1" + ).AnalyzeJobState.deserialize(raw_response) return analyze_paged_result( doc_id_order, task_order, @@ -795,16 +815,37 @@ def _analyze_result_callback(self, doc_id_order, task_order, raw_response, _, he raw_response, analyze_result, headers, - show_stats=show_stats + show_stats=show_stats, ) @distributed_trace_async async def begin_analyze_actions( # type: ignore self, documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]], - actions: List[Union[RecognizeEntitiesAction, RecognizeLinkedEntitiesAction, RecognizePiiEntitiesAction, ExtractKeyPhrasesAction, AnalyzeSentimentAction]], # pylint: disable=line-too-long - **kwargs: Any - ) -> AsyncAnalyzeActionsLROPoller[AsyncItemPaged[List[Union[RecognizeEntitiesResult, RecognizeLinkedEntitiesResult, RecognizePiiEntitiesResult, ExtractKeyPhrasesResult, AnalyzeSentimentResult, DocumentError]]]]: # pylint: disable=line-too-long + actions: List[ + Union[ + RecognizeEntitiesAction, + RecognizeLinkedEntitiesAction, + RecognizePiiEntitiesAction, + ExtractKeyPhrasesAction, + AnalyzeSentimentAction, + ] + ], # pylint: disable=line-too-long + **kwargs: Any, + ) -> AsyncAnalyzeActionsLROPoller[ + AsyncItemPaged[ + List[ + Union[ + RecognizeEntitiesResult, + RecognizeLinkedEntitiesResult, + RecognizePiiEntitiesResult, + ExtractKeyPhrasesResult, + AnalyzeSentimentResult, + DocumentError, + ] + ] + ] + ]: # pylint: disable=line-too-long """Start a long-running operation to perform a variety of text analysis actions over a batch of documents. 
We recommend you use this function if you're looking to analyze larger documents, and / or @@ -881,49 +922,76 @@ async def begin_analyze_actions( # type: ignore raise ValueError("Multiple of the same action is not currently supported.") try: - analyze_tasks = self._client.models(api_version='v3.1').JobManifestTasks( + analyze_tasks = self._client.models(api_version="v3.1").JobManifestTasks( entity_recognition_tasks=[ - t.to_generated() for t in - [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.RECOGNIZE_ENTITIES] + t.to_generated() + for t in [ + a + for a in actions + if _determine_action_type(a) + == _AnalyzeActionsType.RECOGNIZE_ENTITIES + ] ], entity_recognition_pii_tasks=[ - t.to_generated() for t in - [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES] + t.to_generated() + for t in [ + a + for a in actions + if _determine_action_type(a) + == _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES + ] ], key_phrase_extraction_tasks=[ - t.to_generated() for t in - [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.EXTRACT_KEY_PHRASES] + t.to_generated() + for t in [ + a + for a in actions + if _determine_action_type(a) + == _AnalyzeActionsType.EXTRACT_KEY_PHRASES + ] ], entity_linking_tasks=[ - t.to_generated() for t in - [ - a for a in actions if \ - _determine_action_type(a) == _AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES + t.to_generated() + for t in [ + a + for a in actions + if _determine_action_type(a) + == _AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES ] ], sentiment_analysis_tasks=[ - t.to_generated() for t in - [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.ANALYZE_SENTIMENT] - ] + t.to_generated() + for t in [ + a + for a in actions + if _determine_action_type(a) + == _AnalyzeActionsType.ANALYZE_SENTIMENT + ] + ], ) - analyze_body = self._client.models(api_version='v3.1').AnalyzeBatchInput( - display_name=display_name, - tasks=analyze_tasks, - analysis_input=docs + analyze_body = self._client.models(api_version="v3.1").AnalyzeBatchInput( + display_name=display_name, tasks=analyze_tasks, analysis_input=docs ) return await self._client.begin_analyze( body=analyze_body, - cls=kwargs.pop("cls", partial( - self._analyze_result_callback, doc_id_order, task_order, show_stats=show_stats - )), + cls=kwargs.pop( + "cls", + partial( + self._analyze_result_callback, + doc_id_order, + task_order, + show_stats=show_stats, + ), + ), polling=AsyncAnalyzeActionsLROPollingMethod( timeout=polling_interval, lro_algorithms=[ TextAnalyticsOperationResourcePolling(show_stats=show_stats) ], - **kwargs), + **kwargs, + ), continuation_token=continuation_token, - **kwargs + **kwargs, ) except ValueError as error:
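Taken together, begin_analyze_actions buckets each action into the matching JobManifestTasks field via _determine_action_type, and the paging helpers earlier in this patch reassemble the results per input document, one entry per action, in input order. A caller-side sketch, again with placeholder credentials and sample text:

import asyncio

from azure.core.credentials import AzureKeyCredential
from azure.ai.textanalytics import RecognizeEntitiesAction, ExtractKeyPhrasesAction
from azure.ai.textanalytics.aio import TextAnalyticsClient

async def main():
    client = TextAnalyticsClient(
        endpoint="https://<resource-name>.cognitiveservices.azure.com/",  # placeholder
        credential=AzureKeyCredential("<api-key>"),  # placeholder
    )
    async with client:
        poller = await client.begin_analyze_actions(
            ["Microsoft was founded by Bill Gates and Paul Allen."],
            actions=[RecognizeEntitiesAction(), ExtractKeyPhrasesAction()],
        )
        pages = await poller.result()
        async for document_results in pages:
            # One result per action for this document, ordered like `actions`.
            for result in document_results:
                if not result.is_error:
                    print(type(result).__name__)

asyncio.run(main())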