From e085241fe1333959eaa33b9c2dcede40b697744d Mon Sep 17 00:00:00 2001 From: Renan Butkeraites Date: Thu, 31 Oct 2024 15:09:34 -0300 Subject: [PATCH 1/6] Parallelize discover --- tap_salesforce/__init__.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/tap_salesforce/__init__.py b/tap_salesforce/__init__.py index 58545259..5dde3085 100644 --- a/tap_salesforce/__init__.py +++ b/tap_salesforce/__init__.py @@ -11,6 +11,7 @@ from tap_salesforce.salesforce.bulk import Bulk from tap_salesforce.salesforce.exceptions import ( TapSalesforceException, TapSalesforceQuotaExceededException, TapSalesforceBulkAPIDisabledException) +import concurrent.futures LOGGER = singer.get_logger() @@ -271,44 +272,53 @@ def do_discover(sf): if sf.api_type == 'BULK' and not Bulk(sf).has_permissions(): raise TapSalesforceBulkAPIDisabledException('This client does not have Bulk API permissions, received "API_DISABLED_FOR_ORG" error code') - for sobject_name in sorted(objects_to_discover): - + def process_sobject(sobject_name): # Skip blacklisted SF objects depending on the api_type in use # ChangeEvent objects are not queryable via Bulk or REST (undocumented) if (sobject_name in sf.get_blacklisted_objects() and sobject_name not in ACTIVITY_STREAMS) \ or sobject_name.endswith("ChangeEvent"): - continue + return None, None, None sobject_description = sf.describe(sobject_name) if sobject_description is None: - continue + return None, None, None # Cache customSetting and Tag objects to check for blacklisting after # all objects have been described if sobject_description.get("customSetting"): - sf_custom_setting_objects.append(sobject_name) + return sobject_name, None, None elif sobject_name.endswith("__Tag"): relationship_field = next( (f for f in sobject_description["fields"] if f.get("relationshipName") == "Item"), None) if relationship_field: # Map {"Object":"Object__Tag"} - object_to_tag_references[relationship_field["referenceTo"] - [0]] = sobject_name + return None, {relationship_field["referenceTo"][0]: sobject_name}, None fields = sobject_description['fields'] replication_key = get_replication_key(sobject_name, fields) # Salesforce Objects are skipped when they do not have an Id field - if not [f["name"] for f in fields if f["name"]=="Id"]: + if not [f["name"] for f in fields if f["name"] == "Id"]: LOGGER.info( "Skipping Salesforce Object %s, as it has no Id field", sobject_name) - continue + return None, None, None entry = generate_schema(fields, sf, sobject_name, replication_key) - entries.append(entry) + return None, None, entry + + with concurrent.futures.ThreadPoolExecutor() as executor: + results = list(executor.map(process_sobject, sorted(objects_to_discover))) + + for custom_setting, tag_reference, entry in results: + if custom_setting: + sf_custom_setting_objects.append(custom_setting) + if tag_reference: + object_to_tag_references.update(tag_reference) + if entry: + entries.append(entry) # Handle ListViews views = get_views_list(sf) From e83f0daf3b9adb1a6de5b0dd469705d4f5375f0f Mon Sep 17 00:00:00 2001 From: Renan Butkeraites Date: Thu, 31 Oct 2024 17:52:16 -0300 Subject: [PATCH 2/6] Parallelize record processing and refactor sync --- tap_salesforce/__init__.py | 5 +- tap_salesforce/salesforce/__init__.py | 5 +- tap_salesforce/salesforce/bulk.py | 46 ++-- tap_salesforce/salesforce/rest.py | 51 ++-- tap_salesforce/sync.py | 332 ++++++++++++-------------- 5 files changed, 216 insertions(+), 223 deletions(-) diff --git a/tap_salesforce/__init__.py b/tap_salesforce/__init__.py index 
5dde3085..841ec26d 100644 --- a/tap_salesforce/__init__.py +++ b/tap_salesforce/__init__.py @@ -475,9 +475,9 @@ def do_sync(sf, catalog, state,config=None): state.get('bookmarks', {}).get(catalog_entry['tap_stream_id'], {}).pop('JobID', None) state.get('bookmarks', {}).get(catalog_entry['tap_stream_id'], {}).pop('BatchIDs', None) bookmark = state.get('bookmarks', {}).get(catalog_entry['tap_stream_id'], {}) \ - .pop('JobHighestBookmarkSeen', None) + .pop('JobHighestBookmarkSeen', None) existing_bookmark = state.get('bookmarks', {}).get(catalog_entry['tap_stream_id'], {}) \ - .pop(replication_key, None) + .pop(replication_key, None) state = singer.write_bookmark( state, catalog_entry['tap_stream_id'], @@ -506,7 +506,6 @@ def do_sync(sf, catalog, state,config=None): stream_version) counter = sync_stream(sf, catalog_entry, state, input_state, catalog,config) LOGGER.info("%s: Completed sync (%s rows)", stream_name, counter.value) - state["current_stream"] = None singer.write_state(state) LOGGER.info("Finished sync") diff --git a/tap_salesforce/salesforce/__init__.py b/tap_salesforce/salesforce/__init__.py index 186a9209..0b3dfdb4 100644 --- a/tap_salesforce/salesforce/__init__.py +++ b/tap_salesforce/salesforce/__init__.py @@ -9,8 +9,6 @@ import singer.utils as singer_utils from singer import metadata, metrics -from tap_salesforce.salesforce.bulk import Bulk -from tap_salesforce.salesforce.rest import Rest from simplejson.scanner import JSONDecodeError from tap_salesforce.salesforce.exceptions import ( TapSalesforceException, @@ -289,6 +287,7 @@ def _make_request(self, http_method, url, headers=None, body=None, stream=False, elif http_method == "POST": LOGGER.info("Making %s request to %s with body %s", http_method, url, body) resp = self.session.post(url, headers=headers, data=body) + LOGGER.info("Completed %s request to %s with body %s", http_method, url, body) else: raise TapSalesforceException("Unsupported HTTP method") @@ -436,9 +435,11 @@ def query(self, catalog_entry, state, query_override=None): if state["bookmarks"]["ListView"].get("SystemModstamp"): del state["bookmarks"]["ListView"]["SystemModstamp"] if self.api_type == BULK_API_TYPE and query_override is None: + from tap_salesforce.salesforce.bulk import Bulk bulk = Bulk(self) return bulk.query(catalog_entry, state) elif self.api_type == REST_API_TYPE or query_override is not None: + from tap_salesforce.salesforce.rest import Rest rest = Rest(self) return rest.query(catalog_entry, state, query_override=query_override) else: diff --git a/tap_salesforce/salesforce/bulk.py b/tap_salesforce/salesforce/bulk.py index d8867cef..309f5830 100644 --- a/tap_salesforce/salesforce/bulk.py +++ b/tap_salesforce/salesforce/bulk.py @@ -8,7 +8,7 @@ from singer import metrics import requests from requests.exceptions import RequestException - +import concurrent.futures import xmltodict from tap_salesforce.salesforce.exceptions import ( @@ -16,7 +16,7 @@ BATCH_STATUS_POLLING_SLEEP = 20 PK_CHUNKED_BATCH_STATUS_POLLING_SLEEP = 60 -ITER_CHUNK_SIZE = 1024 +ITER_CHUNK_SIZE = 2**15 DEFAULT_CHUNK_SIZE = 100000 # Max is 250000 LOGGER = singer.get_logger() @@ -197,7 +197,7 @@ def _create_job(self, catalog_entry, pk_chunking=False): return job['id'] def _add_batch(self, catalog_entry, job_id, start_date, order_by_clause=True): - endpoint = "job/{}/batch".format(job_id) + endpoint = self._get_endpoint(job_id) + "/batch" url = self.bulk_url.format(self.sf.instance_url, endpoint) body = self.sf._build_query_string(catalog_entry, start_date, 
order_by_clause=order_by_clause) @@ -241,7 +241,7 @@ def _poll_on_batch_status(self, job_id, batch_id): def job_exists(self, job_id): try: - endpoint = "job/{}".format(job_id) + endpoint = self._get_endpoint(job_id) url = self.bulk_url.format(self.sf.instance_url, endpoint) headers = self._get_bulk_headers() @@ -258,7 +258,7 @@ def job_exists(self, job_id): raise def _get_batches(self, job_id): - endpoint = "job/{}/batch".format(job_id) + endpoint = self._get_endpoint(job_id) + "/batch" url = self.bulk_url.format(self.sf.instance_url, endpoint) headers = self._get_bulk_headers() @@ -272,7 +272,7 @@ def _get_batches(self, job_id): return batches def _get_batch(self, job_id, batch_id): - endpoint = "job/{}/batch/{}".format(job_id, batch_id) + endpoint = self._get_endpoint(job_id) + "/batch/{}".format(batch_id) url = self.bulk_url.format(self.sf.instance_url, endpoint) headers = self._get_bulk_headers() @@ -287,7 +287,7 @@ def get_batch_results(self, job_id, batch_id, catalog_entry): """Given a job_id and batch_id, queries the batches results and reads CSV lines yielding each line as a record.""" headers = self._get_bulk_headers() - endpoint = "job/{}/batch/{}/result".format(job_id, batch_id) + endpoint = self._get_endpoint(job_id) + "/batch/{}/result".format(batch_id) url = self.bulk_url.format(self.sf.instance_url, endpoint) with metrics.http_request_timer("batch_result_list") as timer: @@ -301,31 +301,32 @@ def get_batch_results(self, job_id, batch_id, catalog_entry): xml_attribs=False, force_list={'result'})['result-list'] - for result in batch_result_list['result']: - endpoint = "job/{}/batch/{}/result/{}".format(job_id, batch_id, result) + def process_result(result): + endpoint = self._get_endpoint(job_id) + "/batch/{}/result/{}".format(batch_id, result) url = self.bulk_url.format(self.sf.instance_url, endpoint) headers['Content-Type'] = 'text/csv' with tempfile.NamedTemporaryFile(mode="w+", encoding="utf8") as csv_file: resp = self.sf._make_request('GET', url, headers=headers, stream=True) - for chunk in resp.iter_content(chunk_size=ITER_CHUNK_SIZE, decode_unicode=True): - if chunk: - # Replace any NULL bytes in the chunk so it can be safely given to the CSV reader - csv_file.write(chunk.replace('\0', '')) - - csv_file.seek(0) - csv_reader = csv.reader(csv_file, - delimiter=',', - quotechar='"') + csv_reader = csv.reader( + (chunk.replace('\0', '') for chunk in self._iter_lines(resp) if chunk), + delimiter=',', + quotechar='"' + ) column_name_list = next(csv_reader) - for line in csv_reader: - rec = dict(zip(column_name_list, line)) + records = [dict(zip(column_name_list, line)) for line in csv_reader] + return records + + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [executor.submit(process_result, result) for result in batch_result_list['result']] + for future in concurrent.futures.as_completed(futures): + for rec in future.result(): yield rec def _close_job(self, job_id): - endpoint = "job/{}".format(job_id) + endpoint = self._get_endpoint(job_id) url = self.bulk_url.format(self.sf.instance_url, endpoint) body = {"state": "Closed"} @@ -336,6 +337,9 @@ def _close_job(self, job_id): headers=self._get_bulk_headers(), body=json.dumps(body)) + def _get_endpoint(self, job_id): + return "job/{}".format(job_id) + # pylint: disable=no-self-use def _iter_lines(self, response): """Clone of the iter_lines function from the requests library with the change diff --git a/tap_salesforce/salesforce/rest.py b/tap_salesforce/salesforce/rest.py index 539a58be..75b742aa 100644 
--- a/tap_salesforce/salesforce/rest.py +++ b/tap_salesforce/salesforce/rest.py @@ -2,6 +2,7 @@ import singer import singer.utils as singer_utils from requests.exceptions import HTTPError +from tap_salesforce.salesforce import Salesforce from tap_salesforce.salesforce.exceptions import TapSalesforceException LOGGER = singer.get_logger() @@ -10,7 +11,7 @@ class Rest(): - def __init__(self, sf): + def __init__(self, sf: Salesforce): self.sf = sf def query(self, catalog_entry, state, query_override=None): @@ -42,7 +43,7 @@ def _query_recur( retryable = False try: - for rec in self._sync_records(url, headers, params): + for rec in self._sync_records(url, headers, catalog_entry, params): yield rec # If the date range was chunked (an end_date was passed), sync @@ -70,27 +71,32 @@ def _query_recur( else: raise ex - if retryable: - start_date = singer_utils.strptime_with_tz(start_date_str) - half_day_range = (end_date - start_date) // 2 - end_date = end_date - half_day_range - - if half_day_range.days == 0: - raise TapSalesforceException( - "Attempting to query by 0 day range, this would cause infinite looping.") - - query = self.sf._build_query_string(catalog_entry, singer_utils.strftime(start_date), - singer_utils.strftime(end_date)) - for record in self._query_recur( - query, - catalog_entry, - start_date_str, - end_date, - retries - 1): - yield record - - def _sync_records(self, url, headers, params): + if not retryable: + LOGGER.info("[Rest] Not retrying: Stream:%s - Query:%s", catalog_entry['stream'], query) + return + + start_date = singer_utils.strptime_with_tz(start_date_str) + half_day_range = (end_date - start_date) // 2 + end_date = end_date - half_day_range + + if half_day_range.days == 0: + raise TapSalesforceException( + "Attempting to query by 0 day range, this would cause infinite looping.") + + query = self.sf._build_query_string(catalog_entry, singer_utils.strftime(start_date), + singer_utils.strftime(end_date)) + LOGGER.info("[Rest] Retrying: Stream: %s - Query: %s", catalog_entry['stream'], query) + for record in self._query_recur( + query, + catalog_entry, + start_date_str, + end_date, + retries - 1): + yield record + + def _sync_records(self, url, headers, catalog_entry, params): while True: + LOGGER.info("[Rest] Fetching records from: Stream: %s - URL: %s", catalog_entry['stream'], url) resp = self.sf._make_request('GET', url, headers=headers, params=params, validate_json=True) resp_json = resp.json() @@ -100,6 +106,7 @@ def _sync_records(self, url, headers, params): next_records_url = resp_json.get('nextRecordsUrl') if next_records_url is None: + LOGGER.info("[Rest] No more records to fetch") break url = "{}{}".format(self.sf.instance_url, next_records_url) diff --git a/tap_salesforce/sync.py b/tap_salesforce/sync.py index 08210674..f8063e47 100644 --- a/tap_salesforce/sync.py +++ b/tap_salesforce/sync.py @@ -4,8 +4,10 @@ import singer.utils as singer_utils from singer import Transformer, metadata, metrics from requests.exceptions import RequestException +from tap_salesforce.salesforce import Salesforce from tap_salesforce.salesforce.bulk import Bulk import base64 +from concurrent.futures import ThreadPoolExecutor, as_completed LOGGER = singer.get_logger() @@ -112,7 +114,6 @@ def sync_stream(sf, catalog_entry, state, input_state, catalog,config=None): except Exception as ex: raise Exception("Error syncing {}: {}".format( stream, ex)) from ex - return counter @@ -203,215 +204,196 @@ def handle_ListView(sf,rec_id,sobject,lv_name,lv_catalog_entry,state,input_state 
version=lv_stream_version, time_extracted=start_time)) -def sync_records(sf, catalog_entry, state, input_state, counter, catalog,config=None): - download_files = False - if "download_files" in config: - if config['download_files']==True: - download_files = True +def sync_records(sf, catalog_entry, state, input_state, counter, catalog, config=None): + download_files = config.get('download_files', False) chunked_bookmark = singer_utils.strptime_with_tz(sf.get_start_date(state, catalog_entry)) - stream = catalog_entry['stream'] + stream = catalog_entry['stream'].replace("/", "_") schema = catalog_entry['schema'] stream_alias = catalog_entry.get('stream_alias') catalog_metadata = metadata.to_map(catalog_entry['metadata']) replication_key = catalog_metadata.get((), {}).get('replication-key') stream_version = get_stream_version(catalog_entry, state) - stream = stream.replace("/","_") - activate_version_message = singer.ActivateVersionMessage(stream=(stream_alias or stream), - version=stream_version) + activate_version_message = singer.ActivateVersionMessage(stream=(stream_alias or stream), version=stream_version) start_time = singer_utils.now() - LOGGER.info('Syncing Salesforce data for stream %s', stream) - records_post = [] - + + # Update state for current stream if "/" in state["current_stream"]: - # get current name old_key = state["current_stream"] - # get the new key name - new_key = old_key.replace("/","_") - # move to new key + new_key = old_key.replace("/", "_") state["bookmarks"][new_key] = state["bookmarks"].pop(old_key) - if not replication_key: singer.write_message(activate_version_message) - state = singer.write_bookmark( - state, catalog_entry['tap_stream_id'], 'version', None) + state = singer.write_bookmark(state, catalog_entry['tap_stream_id'], 'version', None) - # If pk_chunking is set, only write a bookmark at the end if sf.pk_chunking: - # Write a bookmark with the highest value we've seen - state = singer.write_bookmark( - state, - catalog_entry['tap_stream_id'], - replication_key, - singer_utils.strftime(chunked_bookmark)) + state = singer.write_bookmark(state, catalog_entry['tap_stream_id'], replication_key, singer_utils.strftime(chunked_bookmark)) + # Process different stream types if catalog_entry["stream"].startswith("Report_"): - report_name = catalog_entry["stream"].split("Report_", 1)[1] - - reports = [] - done = False - headers = sf._get_standard_headers() - endpoint = "queryAll" - params = {'q': 'SELECT Id,DeveloperName FROM Report'} - url = sf.data_url.format(sf.instance_url, endpoint) - - while not done: - response = sf._make_request('GET', url, headers=headers, params=params) - response_json = response.json() - done = response_json.get("done") - reports.extend(response_json.get("records", [])) - if not done: - url = sf.instance_url+response_json.get("nextRecordsUrl") - - report = [r for r in reports if report_name==r["DeveloperName"]][0] - report_id = report["Id"] - - endpoint = f"analytics/reports/{report_id}" - url = sf.data_url.format(sf.instance_url, endpoint) - response = sf._make_request('GET', url, headers=headers) + process_report_stream(sf, catalog_entry, stream, schema, stream_alias, stream_version, start_time) + elif "ListViews" == catalog_entry["stream"]: + process_list_views(sf, catalog_entry, state, input_state, catalog, start_time) + else: + process_other_streams(sf, catalog_entry, state, input_state, counter, catalog, schema, stream, stream_alias, stream_version, start_time, download_files, catalog_metadata, replication_key) + +def 
process_report_stream(sf, catalog_entry, stream, schema, stream_alias, stream_version, start_time): + report_name = catalog_entry["stream"].split("Report_", 1)[1] + reports = [] + done = False + headers = sf._get_standard_headers() + endpoint = "queryAll" + params = {'q': 'SELECT Id,DeveloperName FROM Report'} + url = sf.data_url.format(sf.instance_url, endpoint) + + while not done: + response = sf._make_request('GET', url, headers=headers, params=params) + response_json = response.json() + done = response_json.get("done") + reports.extend(response_json.get("records", [])) + if not done: + url = sf.instance_url + response_json.get("nextRecordsUrl") + report = next(r for r in reports if report_name == r["DeveloperName"]) + report_id = report["Id"] + + endpoint = f"analytics/reports/{report_id}" + url = sf.data_url.format(sf.instance_url, endpoint) + response = sf._make_request('GET', url, headers=headers) + + with Transformer(pre_hook=transform_bulk_data_hook) as transformer: + rec = transformer.transform(response.json(), schema) + rec = fix_record_anytype(rec, schema) + singer.write_message( + singer.RecordMessage( + stream=(stream_alias or stream), + record=rec, + version=stream_version, + time_extracted=start_time)) + +def process_list_views(sf, catalog_entry, state, input_state, catalog, start_time): + headers = sf._get_standard_headers() + endpoint = "queryAll" + params = {'q': f'SELECT Name,Id,SobjectType,DeveloperName FROM ListView'} + url = sf.data_url.format(sf.instance_url, endpoint) + response = sf._make_request('GET', url, headers=headers, params=params) + + Id_Sobject = [{"Id": r["Id"], "SobjectType": r["SobjectType"], "DeveloperName": r["DeveloperName"], "Name": r["Name"]} + for r in response.json().get('records', []) if r["Name"]] + + selected_lists_names = [] + for ln in catalog_entry.get("metadata", [])[:-1]: + if ln.get("metadata", [])['selected']: + selected_list = ln.get('breadcrumb', [])[1] + for isob in Id_Sobject: + if selected_list == f"ListView_{isob['SobjectType']}_{isob['DeveloperName']}": + selected_lists_names.append(isob) + + for list_info in selected_lists_names: + sobject = list_info['SobjectType'] + lv_name = list_info['DeveloperName'] + lv_id = list_info['Id'] + + lv_catalog = [x for x in catalog["streams"] if x["stream"] == sobject] + + if lv_catalog: + lv_catalog_entry = lv_catalog[0].copy() + try: + handle_ListView(sf, lv_id, sobject, lv_name, lv_catalog_entry, state, input_state, start_time) + except RequestException as e: + LOGGER.warning(f"No existing /'results/' endpoint was found for SobjectType:{sobject}, Id:{lv_id}") + +def process_other_streams(sf:Salesforce, catalog_entry, state, input_state, counter, catalog, schema, stream, stream_alias, stream_version, start_time, download_files, catalog_metadata, replication_key): + if catalog_entry["stream"] in ACTIVITY_STREAMS: + start_date_str = sf.get_start_date(state, catalog_entry) + start_date = singer_utils.strptime_with_tz(start_date_str) + start_date = singer_utils.strftime(start_date) + + selected_properties = sf._get_selected_properties(catalog_entry) + + query_map = { + "ActivityHistory": "ActivityHistories", + "OpenActivity": "OpenActivities" + } + + query_field = query_map[catalog_entry['stream']] + + query = "SELECT {} FROM {}".format(",".join(selected_properties), query_field) + query = f"SELECT ({query}) FROM Contact" + + order_by = "" + if replication_key: + where_clause = " WHERE {} > {} ".format( + replication_key, + start_date) + order_by = " ORDER BY {} ASC".format(replication_key) + 
query = query + where_clause + order_by + + def unwrap_query(query_response, query_field): + for q in query_response: + if q.get(query_field): + for f in q[query_field]["records"]: + yield f + + query_response = sf.query(catalog_entry, state, query_override=query) + query_response = unwrap_query(query_response, query_field) + else: + query_response = sf.query(catalog_entry, state) + + def process_record(rec): + counter.increment() with Transformer(pre_hook=transform_bulk_data_hook) as transformer: - rec = transformer.transform(response.json(), schema) + rec = transformer.transform(rec, schema) rec = fix_record_anytype(rec, schema) - stream = stream.replace("/","_") + if stream == 'ContentVersion': + if "IsLatest" in rec: + if rec['IsLatest'] == True and download_files == True: + rec['TextPreview'] = base64.b64encode(get_content_document_file(sf, rec['Id'])).decode('utf-8') singer.write_message( singer.RecordMessage( - stream=( - stream_alias or stream), + stream=(stream_alias or stream), record=rec, version=stream_version, time_extracted=start_time)) - elif "ListViews" == catalog_entry["stream"]: - headers = sf._get_standard_headers() - endpoint = "queryAll" - - params = {'q': f'SELECT Name,Id,SobjectType,DeveloperName FROM ListView'} - url = sf.data_url.format(sf.instance_url, endpoint) - response = sf._make_request('GET', url, headers=headers, params=params) - - Id_Sobject = [{"Id":r["Id"],"SobjectType": r["SobjectType"],"DeveloperName":r["DeveloperName"],"Name":r["Name"]} - for r in response.json().get('records',[]) if r["Name"]] - - selected_lists_names = [] - for ln in catalog_entry.get("metadata",[])[:-1]: - if ln.get("metadata",[])['selected']: - selected_list = ln.get('breadcrumb',[])[1] - for isob in Id_Sobject: - if selected_list==f"ListView_{isob['SobjectType']}_{isob['DeveloperName']}": - selected_lists_names.append(isob) - replication_key_value = replication_key and singer_utils.strptime_with_tz(rec[replication_key]) - for list_info in selected_lists_names: - - sobject = list_info['SobjectType'] - lv_name = list_info['DeveloperName'] - lv_id = list_info['Id'] - - lv_catalog = [x for x in catalog["streams"] if x["stream"] == sobject] - - if lv_catalog: - lv_catalog_entry = lv_catalog[0].copy() - try: - handle_ListView(sf,lv_id,sobject,lv_name,lv_catalog_entry,state,input_state,start_time) - except RequestException as e: - LOGGER.warning(f"No existing /'results/' endpoint was found for SobjectType:{sobject}, Id:{lv_id}") - - else: - if catalog_entry["stream"] in ACTIVITY_STREAMS: - start_date_str = sf.get_start_date(state, catalog_entry) - start_date = singer_utils.strptime_with_tz(start_date_str) - start_date = singer_utils.strftime(start_date) - - selected_properties = sf._get_selected_properties(catalog_entry) - - query_map = { - "ActivityHistory": "ActivityHistories", - "OpenActivity": "OpenActivities" - } - - query_field = query_map[catalog_entry['stream']] - - query = "SELECT {} FROM {}".format(",".join(selected_properties), query_field) - query = f"SELECT ({query}) FROM Contact" - - catalog_metadata = metadata.to_map(catalog_entry['metadata']) - replication_key = catalog_metadata.get((), {}).get('replication-key') - - order_by = "" - if replication_key: - where_clause = " WHERE {} > {} ".format( - replication_key, - start_date) - order_by = " ORDER BY {} ASC".format(replication_key) - query = query + where_clause + order_by - - def unwrap_query(query_response, query_field): - for q in query_response: - if q.get(query_field): - for f in q[query_field]["records"]: - yield f 
- - query_response = sf.query(catalog_entry, state, query_override=query) - query_response = unwrap_query(query_response, query_field) - else: - query_response = sf.query(catalog_entry, state) - - for rec in query_response: - counter.increment() - with Transformer(pre_hook=transform_bulk_data_hook) as transformer: - rec = transformer.transform(rec, schema) - rec = fix_record_anytype(rec, schema) - if stream=='ContentVersion': - if "IsLatest" in rec: - if rec['IsLatest']==True and download_files==True: - rec['TextPreview'] = base64.b64encode(get_content_document_file(sf,rec['Id'])).decode('utf-8') - singer.write_message( - singer.RecordMessage( - stream=( - stream_alias or stream), - record=rec, - version=stream_version, - time_extracted=start_time)) - - replication_key_value = replication_key and singer_utils.strptime_with_tz(rec[replication_key]) - - if sf.pk_chunking: - if replication_key_value and replication_key_value <= start_time and replication_key_value > chunked_bookmark: - # Replace the highest seen bookmark and save the state in case we need to resume later - chunked_bookmark = singer_utils.strptime_with_tz(rec[replication_key]) - state = singer.write_bookmark( - state, - catalog_entry['tap_stream_id'], - 'JobHighestBookmarkSeen', - singer_utils.strftime(chunked_bookmark)) - singer.write_state(state) - # Before writing a bookmark, make sure Salesforce has not given us a - # record with one outside our range - elif replication_key_value and replication_key_value <= start_time: + if sf.pk_chunking: + if replication_key_value and replication_key_value <= start_time and replication_key_value > chunked_bookmark: + chunked_bookmark = singer_utils.strptime_with_tz(rec[replication_key]) state = singer.write_bookmark( state, catalog_entry['tap_stream_id'], - replication_key, - rec[replication_key]) + 'JobHighestBookmarkSeen', + singer_utils.strftime(chunked_bookmark)) singer.write_state(state) + elif replication_key_value and replication_key_value <= start_time: + state = singer.write_bookmark( + state, + catalog_entry['tap_stream_id'], + replication_key, + rec[replication_key]) + singer.write_state(state) - selected = get_selected_streams(catalog) - if stream == "ListView" and rec.get("SobjectType") in selected and rec.get("Id") is not None: - # Handle listview - try: - sobject = rec["SobjectType"] - lv_name = rec["DeveloperName"] - lv_catalog = [x for x in catalog["streams"] if x["stream"] == sobject] - rec_id = rec["Id"] - lv_catalog_entry = lv_catalog[0].copy() - if len(lv_catalog) > 0: - handle_ListView(sf,rec_id,sobject,lv_name,lv_catalog_entry,state,input_state,start_time) - except RequestException as e: - pass + selected = get_selected_streams(catalog) + if stream == "ListView" and rec.get("SobjectType") in selected and rec.get("Id") is not None: + try: + sobject = rec["SobjectType"] + lv_name = rec["DeveloperName"] + lv_catalog = [x for x in catalog["streams"] if x["stream"] == sobject] + rec_id = rec["Id"] + lv_catalog_entry = lv_catalog[0].copy() + if len(lv_catalog) > 0: + handle_ListView(sf, rec_id, sobject, lv_name, lv_catalog_entry, state, input_state, start_time) + except RequestException as e: + pass + + with ThreadPoolExecutor() as executor: + futures = [executor.submit(process_record, rec) for rec in query_response] + for future in as_completed(futures): + future.result() def fix_record_anytype(rec, schema): From 0cf9f430587f610270f22c2f347ba3eb2da3179c Mon Sep 17 00:00:00 2001 From: Renan Butkeraites Date: Mon, 4 Nov 2024 10:48:09 -0300 Subject: [PATCH 3/6] Deprecate 
the temporary files usage when processing requests, change the replication-key, fix issue with concurrent record processing, improve logging messages, and bump default chunk size for pk chunking jobs --- tap_salesforce/__init__.py | 17 ++++---------- tap_salesforce/salesforce/__init__.py | 12 +++++----- tap_salesforce/salesforce/bulk.py | 32 +++++++++++++------------- tap_salesforce/salesforce/rest.py | 33 ++++++++++++++++----------- tap_salesforce/sync.py | 12 +++++----- 5 files changed, 53 insertions(+), 53 deletions(-) diff --git a/tap_salesforce/__init__.py b/tap_salesforce/__init__.py index 841ec26d..69c589f3 100644 --- a/tap_salesforce/__init__.py +++ b/tap_salesforce/__init__.py @@ -75,7 +75,7 @@ def build_state(raw_state, catalog): state = singer.write_bookmark(state, tap_stream_id, 'JobHighestBookmarkSeen', current_bookmark) if replication_method == 'INCREMENTAL': - replication_key = catalog_metadata.get((), {}).get('replication-key') + replication_key = next(iter(catalog_metadata.get((), {}).get('valid-replication-keys', [])), None) replication_key_value = singer.get_bookmark(raw_state, tap_stream_id, replication_key) @@ -432,11 +432,9 @@ def do_sync(sf, catalog, state,config=None): stream=(stream_alias or stream.replace("/","_")), version=stream_version) catalog_metadata = metadata.to_map(catalog_entry['metadata']) - replication_key = catalog_metadata.get((), {}).get('replication-key') + replication_key = next(iter(catalog_metadata.get((), {}).get('valid-replication-keys', [])), None) - mdata = metadata.to_map(catalog_entry['metadata']) - - if not stream_is_selected(mdata): + if not stream_is_selected(catalog_metadata): LOGGER.info("%s: Skipping - not selected", stream_name) continue @@ -490,14 +488,9 @@ def do_sync(sf, catalog, state,config=None): bookmark_is_empty = state.get('bookmarks', {}).get( catalog_entry['tap_stream_id']) is None - if "/" in state["current_stream"]: - # get current name - old_key = state["current_stream"] - # get the new key name - new_key = old_key.replace("/","_") - state["current_stream"] = new_key - + state["current_stream"] = state["current_stream"].replace("/", "_") catalog_entry['tap_stream_id'] = catalog_entry['tap_stream_id'].replace("/","_") + if replication_key or bookmark_is_empty: singer.write_message(activate_version_message) state = singer.write_bookmark(state, diff --git a/tap_salesforce/salesforce/__init__.py b/tap_salesforce/salesforce/__init__.py index 0b3dfdb4..65451a44 100644 --- a/tap_salesforce/salesforce/__init__.py +++ b/tap_salesforce/salesforce/__init__.py @@ -281,13 +281,13 @@ def check_rest_quota_usage(self, headers): on_backoff=log_backoff_attempt) def _make_request(self, http_method, url, headers=None, body=None, stream=False, params=None, validate_json=False, timeout=None): if http_method == "GET": - LOGGER.info("Making %s request to %s with params: %s", http_method, url, params) + LOGGER.debug("[REST] Making %s request to %s", http_method, url) resp = self.session.get(url, headers=headers, stream=stream, params=params, timeout=timeout) - LOGGER.info("Completed %s request to %s with params: %s", http_method, url, params) + LOGGER.debug("[REST] Completed %s request to %s", http_method, url) elif http_method == "POST": - LOGGER.info("Making %s request to %s with body %s", http_method, url, body) + LOGGER.debug("[REST] Making %s request to %s", http_method, url) resp = self.session.post(url, headers=headers, data=body) - LOGGER.info("Completed %s request to %s with body %s", http_method, url, body) + LOGGER.debug("[REST] 
Completed %s request to %s", http_method, url) else: raise TapSalesforceException("Unsupported HTTP method") @@ -399,7 +399,7 @@ def _get_selected_properties(self, catalog_entry): def get_start_date(self, state, catalog_entry): catalog_metadata = metadata.to_map(catalog_entry['metadata']) - replication_key = catalog_metadata.get((), {}).get('replication-key') + replication_key = next(iter(catalog_metadata.get((), {}).get('valid-replication-keys', [])), None) return (singer.get_bookmark(state, catalog_entry['tap_stream_id'], @@ -411,7 +411,7 @@ def _build_query_string(self, catalog_entry, start_date, end_date=None, order_by query = "SELECT {} FROM {}".format(",".join(selected_properties), catalog_entry['stream']) catalog_metadata = metadata.to_map(catalog_entry['metadata']) - replication_key = catalog_metadata.get((), {}).get('replication-key') + replication_key = next(iter(catalog_metadata.get((), {}).get('valid-replication-keys', [])), None) if replication_key: where_clause = " WHERE {} > {} ".format( diff --git a/tap_salesforce/salesforce/bulk.py b/tap_salesforce/salesforce/bulk.py index 309f5830..5b79e707 100644 --- a/tap_salesforce/salesforce/bulk.py +++ b/tap_salesforce/salesforce/bulk.py @@ -11,13 +11,14 @@ import concurrent.futures import xmltodict +from tap_salesforce.salesforce import Salesforce from tap_salesforce.salesforce.exceptions import ( TapSalesforceException, TapSalesforceQuotaExceededException) BATCH_STATUS_POLLING_SLEEP = 20 PK_CHUNKED_BATCH_STATUS_POLLING_SLEEP = 60 ITER_CHUNK_SIZE = 2**15 -DEFAULT_CHUNK_SIZE = 100000 # Max is 250000 +DEFAULT_CHUNK_SIZE = 250000 # Max is 250000 LOGGER = singer.get_logger() @@ -44,7 +45,7 @@ class Bulk(): def bulk_url(self): return "{}/services/async/" + self.sf.version + "/{}" - def __init__(self, sf): + def __init__(self, sf:Salesforce): # Set csv max reading size to the platform's max size available. csv.field_size_limit(sys.maxsize) self.sf = sf @@ -139,8 +140,8 @@ def _bulk_query(self, catalog_entry, state): yield result # Remove the completed batch ID and write state state['bookmarks'][catalog_entry['tap_stream_id']]["BatchIDs"].remove(completed_batch_id) - LOGGER.info("Finished syncing batch %s. Removing batch from state.", completed_batch_id) - LOGGER.info("Batches to go: %d", len(state['bookmarks'][catalog_entry['tap_stream_id']]["BatchIDs"])) + LOGGER.info("[BULK] Finished syncing batch %s. 
Removing batch from state.", completed_batch_id) + LOGGER.info("[BULK] Batches to go: %d", len(state['bookmarks'][catalog_entry['tap_stream_id']]["BatchIDs"])) singer.write_state(state) else: raise TapSalesforceException(batch_status['stateMessage']) @@ -149,7 +150,7 @@ def _bulk_query(self, catalog_entry, state): yield result def _bulk_query_with_pk_chunking(self, catalog_entry, start_date): - LOGGER.info("Retrying Bulk Query with PK Chunking") + LOGGER.info("[BULK] Retrying Bulk Query with PK Chunking") # Create a new job job_id = self._create_job(catalog_entry, True) @@ -175,7 +176,7 @@ def _create_job(self, catalog_entry, pk_chunking=False): headers['Sforce-Disable-Batch-Retry'] = "true" if pk_chunking: - LOGGER.info("ADDING PK CHUNKING HEADER") + LOGGER.info("[BULK] ADDING PK CHUNKING HEADER") headers['Sforce-Enable-PKChunking'] = "true; chunkSize={}".format(DEFAULT_CHUNK_SIZE) @@ -306,18 +307,17 @@ def process_result(result): url = self.bulk_url.format(self.sf.instance_url, endpoint) headers['Content-Type'] = 'text/csv' - with tempfile.NamedTemporaryFile(mode="w+", encoding="utf8") as csv_file: - resp = self.sf._make_request('GET', url, headers=headers, stream=True) - csv_reader = csv.reader( - (chunk.replace('\0', '') for chunk in self._iter_lines(resp) if chunk), - delimiter=',', - quotechar='"' - ) + resp = self.sf._make_request('GET', url, headers=headers, stream=True) + csv_reader = csv.reader( + (chunk.replace('\0', '') for chunk in self._iter_lines(resp) if chunk), + delimiter=',', + quotechar='"' + ) - column_name_list = next(csv_reader) + column_name_list = next(csv_reader) - records = [dict(zip(column_name_list, line)) for line in csv_reader] - return records + records = [dict(zip(column_name_list, line)) for line in csv_reader] + return records with concurrent.futures.ThreadPoolExecutor() as executor: futures = [executor.submit(process_result, result) for result in batch_result_list['result']] diff --git a/tap_salesforce/salesforce/rest.py b/tap_salesforce/salesforce/rest.py index 75b742aa..7e5fa8ff 100644 --- a/tap_salesforce/salesforce/rest.py +++ b/tap_salesforce/salesforce/rest.py @@ -1,7 +1,7 @@ # pylint: disable=protected-access import singer import singer.utils as singer_utils -from requests.exceptions import HTTPError +from requests.exceptions import HTTPError, RequestException from tap_salesforce.salesforce import Salesforce from tap_salesforce.salesforce.exceptions import TapSalesforceException @@ -72,7 +72,7 @@ def _query_recur( raise ex if not retryable: - LOGGER.info("[Rest] Not retrying: Stream:%s - Query:%s", catalog_entry['stream'], query) + LOGGER.info("[REST] Not retrying: Stream:%s", catalog_entry['stream']) return start_date = singer_utils.strptime_with_tz(start_date_str) @@ -85,7 +85,7 @@ def _query_recur( query = self.sf._build_query_string(catalog_entry, singer_utils.strftime(start_date), singer_utils.strftime(end_date)) - LOGGER.info("[Rest] Retrying: Stream: %s - Query: %s", catalog_entry['stream'], query) + LOGGER.info("[REST] Retrying: Stream: %s", catalog_entry['stream']) for record in self._query_recur( query, catalog_entry, @@ -95,18 +95,25 @@ def _query_recur( yield record def _sync_records(self, url, headers, catalog_entry, params): + # Set the desired batch size + params['batchSize'] = 20000 # Adjust this value as needed, max is typically 2000 + while True: - LOGGER.info("[Rest] Fetching records from: Stream: %s - URL: %s", catalog_entry['stream'], url) - resp = self.sf._make_request('GET', url, headers=headers, params=params, 
validate_json=True) - resp_json = resp.json() + LOGGER.debug("[REST] Fetching records from: Stream: %s - URL: %s", catalog_entry['stream'], url) + try: + resp = self.sf._make_request('GET', url, headers=headers, params=params, validate_json=True) + resp_json = resp.json() - for rec in resp_json.get('records'): - yield rec + for rec in resp_json.get('records'): + yield rec - next_records_url = resp_json.get('nextRecordsUrl') + next_records_url = resp_json.get('nextRecordsUrl') - if next_records_url is None: - LOGGER.info("[Rest] No more records to fetch") - break + if next_records_url is None: + LOGGER.info("[REST] No more records to fetch from: Stream: %s - URL: %s", catalog_entry['stream'], url) + break - url = "{}{}".format(self.sf.instance_url, next_records_url) + url = "{}{}".format(self.sf.instance_url, next_records_url) + except RequestException as e: + LOGGER.error("Error fetching records: %s", e) + raise e diff --git a/tap_salesforce/sync.py b/tap_salesforce/sync.py index f8063e47..5e3ccfed 100644 --- a/tap_salesforce/sync.py +++ b/tap_salesforce/sync.py @@ -41,7 +41,7 @@ def transform_bulk_data_hook(data, typ, schema): def get_stream_version(catalog_entry, state): tap_stream_id = catalog_entry['tap_stream_id'] catalog_metadata = metadata.to_map(catalog_entry['metadata']) - replication_key = catalog_metadata.get((), {}).get('replication-key') + replication_key = next(iter(catalog_metadata.get((), {}).get('valid-replication-keys', [])), None) if singer.get_bookmark(state, tap_stream_id, 'version') is None: stream_version = int(time.time() * 1000) @@ -62,7 +62,7 @@ def resume_syncing_bulk_query(sf, catalog_entry, job_id, state, counter): stream = catalog_entry['stream'] stream_alias = catalog_entry.get('stream_alias') catalog_metadata = metadata.to_map(catalog_entry.get('metadata')) - replication_key = catalog_metadata.get((), {}).get('replication-key') + replication_key = next(iter(catalog_metadata.get((), {}).get('valid-replication-keys', [])), None) stream_version = get_stream_version(catalog_entry, state) schema = catalog_entry['schema'] @@ -147,7 +147,7 @@ def handle_ListView(sf,rec_id,sobject,lv_name,lv_catalog_entry,state,input_state # Save the schema lv_schema = lv_catalog_entry['schema'] lv_catalog_metadata = metadata.to_map(lv_catalog_entry['metadata']) - lv_replication_key = lv_catalog_metadata.get((), {}).get('replication-key') + lv_replication_key = next(iter(lv_catalog_metadata.get((), {}).get('valid-replication-keys', [])), None) lv_key_properties = lv_catalog_metadata.get((), {}).get('table-key-properties') date_filter = None @@ -211,7 +211,7 @@ def sync_records(sf, catalog_entry, state, input_state, counter, catalog, config schema = catalog_entry['schema'] stream_alias = catalog_entry.get('stream_alias') catalog_metadata = metadata.to_map(catalog_entry['metadata']) - replication_key = catalog_metadata.get((), {}).get('replication-key') + replication_key = next(iter(catalog_metadata.get((), {}).get('valid-replication-keys', [])), None) stream_version = get_stream_version(catalog_entry, state) activate_version_message = singer.ActivateVersionMessage(stream=(stream_alias or stream), version=stream_version) @@ -342,7 +342,7 @@ def unwrap_query(query_response, query_field): else: query_response = sf.query(catalog_entry, state) - def process_record(rec): + def process_record(rec, state): counter.increment() with Transformer(pre_hook=transform_bulk_data_hook) as transformer: rec = transformer.transform(rec, schema) @@ -391,7 +391,7 @@ def process_record(rec): pass with 
ThreadPoolExecutor() as executor: - futures = [executor.submit(process_record, rec) for rec in query_response] + futures = [executor.submit(process_record, rec, state) for rec in query_response] for future in as_completed(futures): future.result() From 3a598cdb1378fedd047d6831d08436250d40e461 Mon Sep 17 00:00:00 2001 From: Renan Butkeraites Date: Wed, 20 Nov 2024 18:25:19 -0300 Subject: [PATCH 4/6] Change typing to consider date and datetime fields nullable --- tap_salesforce/__init__.py | 2 +- tap_salesforce/salesforce/__init__.py | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/tap_salesforce/__init__.py b/tap_salesforce/__init__.py index 69c589f3..7c96bc75 100644 --- a/tap_salesforce/__init__.py +++ b/tap_salesforce/__init__.py @@ -256,7 +256,7 @@ def get_views_list(sf): # pylint: disable=too-many-branches,too-many-statements -def do_discover(sf): +def do_discover(sf: Salesforce): """Describes a Salesforce instance's objects and generates a JSON schema for each field.""" global_description = sf.describe() diff --git a/tap_salesforce/salesforce/__init__.py b/tap_salesforce/salesforce/__init__.py index 65451a44..7fc584d2 100644 --- a/tap_salesforce/salesforce/__init__.py +++ b/tap_salesforce/salesforce/__init__.py @@ -142,17 +142,17 @@ def field_to_property_schema(field, mdata): # pylint:disable=too-many-branches sf_type = field['type'] if sf_type in STRING_TYPES: - property_schema['type'] = "string" + property_schema['type'] = ["string", "null"] elif sf_type in DATE_TYPES: - date_type = {"type": "string", "format": "date-time"} + date_type = {"type": ["string", "null"], "format": "date-time"} string_type = {"type": ["string", "null"]} property_schema["anyOf"] = [date_type, string_type] elif sf_type == "boolean": - property_schema['type'] = "boolean" + property_schema['type'] = ["boolean", "null"] elif sf_type in NUMBER_TYPES: - property_schema['type'] = "number" + property_schema['type'] = ["number", "null"] elif sf_type == "address": - property_schema['type'] = "object" + property_schema['type'] = ["object", "null"] property_schema['properties'] = { "street": {"type": ["null", "string"]}, "state": {"type": ["null", "string"]}, @@ -164,9 +164,9 @@ def field_to_property_schema(field, mdata): # pylint:disable=too-many-branches "geocodeAccuracy": {"type": ["null", "string"]} } elif sf_type == "int": - property_schema['type'] = "integer" + property_schema['type'] = ["integer", "null"] elif sf_type == "time": - property_schema['type'] = "string" + property_schema['type'] = ["string", "null"] elif sf_type in LOOSE_TYPES: return property_schema, mdata # No type = all types elif sf_type in BINARY_TYPES: @@ -182,13 +182,14 @@ def field_to_property_schema(field, mdata): # pylint:disable=too-many-branches "latitude": {"type": ["null", "number"]} } elif sf_type == 'json': - property_schema['type'] = "string" + property_schema['type'] = ["string", "null"] else: raise TapSalesforceException("Found unsupported type: {}".format(sf_type)) # The nillable field cannot be trusted if field_name != 'Id' and sf_type != 'location' and sf_type not in DATE_TYPES: - property_schema['type'] = ["null", property_schema['type']] + if "null" not in property_schema['type']: + property_schema['type'].append("null") return property_schema, mdata From 9860c3a10e19675803ae9ee41b40e4ba406296ce Mon Sep 17 00:00:00 2001 From: Renan Butkeraites Date: Tue, 10 Dec 2024 11:36:14 -0300 Subject: [PATCH 5/6] Refactor schema type handling to prioritize 'null' in property schemas and update 
record processing to include schema parameter. This improves data transformation for Salesforce Bulk API and ensures proper handling of nullable fields. --- tap_salesforce/salesforce/__init__.py | 22 +++++++++++----------- tap_salesforce/sync.py | 6 +++--- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tap_salesforce/salesforce/__init__.py b/tap_salesforce/salesforce/__init__.py index 7fc584d2..ed8b4e6f 100644 --- a/tap_salesforce/salesforce/__init__.py +++ b/tap_salesforce/salesforce/__init__.py @@ -142,17 +142,17 @@ def field_to_property_schema(field, mdata): # pylint:disable=too-many-branches sf_type = field['type'] if sf_type in STRING_TYPES: - property_schema['type'] = ["string", "null"] + property_schema['type'] = ["null", "string"] elif sf_type in DATE_TYPES: - date_type = {"type": ["string", "null"], "format": "date-time"} - string_type = {"type": ["string", "null"]} + date_type = {"type": ["null", "string"], "format": "date-time"} + string_type = {"type": ["null", "string"]} property_schema["anyOf"] = [date_type, string_type] elif sf_type == "boolean": - property_schema['type'] = ["boolean", "null"] + property_schema['type'] = ["null", "boolean"] elif sf_type in NUMBER_TYPES: - property_schema['type'] = ["number", "null"] + property_schema['type'] = ["null", "number"] elif sf_type == "address": - property_schema['type'] = ["object", "null"] + property_schema['type'] = ["null", "object"] property_schema['properties'] = { "street": {"type": ["null", "string"]}, "state": {"type": ["null", "string"]}, @@ -164,9 +164,9 @@ def field_to_property_schema(field, mdata): # pylint:disable=too-many-branches "geocodeAccuracy": {"type": ["null", "string"]} } elif sf_type == "int": - property_schema['type'] = ["integer", "null"] + property_schema['type'] = ["null", "integer"] elif sf_type == "time": - property_schema['type'] = ["string", "null"] + property_schema['type'] = ["null", "string"] elif sf_type in LOOSE_TYPES: return property_schema, mdata # No type = all types elif sf_type in BINARY_TYPES: @@ -176,20 +176,20 @@ def field_to_property_schema(field, mdata): # pylint:disable=too-many-branches return property_schema, mdata elif sf_type == 'location': # geo coordinates are numbers or objects divided into two fields for lat/long - property_schema['type'] = ["number", "object", "null"] + property_schema['type'] = ["null", "object", "number"] property_schema['properties'] = { "longitude": {"type": ["null", "number"]}, "latitude": {"type": ["null", "number"]} } elif sf_type == 'json': - property_schema['type'] = ["string", "null"] + property_schema['type'] = ["null", "string"] else: raise TapSalesforceException("Found unsupported type: {}".format(sf_type)) # The nillable field cannot be trusted if field_name != 'Id' and sf_type != 'location' and sf_type not in DATE_TYPES: if "null" not in property_schema['type']: - property_schema['type'].append("null") + property_schema['type'].insert(0, "null") return property_schema, mdata diff --git a/tap_salesforce/sync.py b/tap_salesforce/sync.py index 5e3ccfed..e709fec8 100644 --- a/tap_salesforce/sync.py +++ b/tap_salesforce/sync.py @@ -33,7 +33,7 @@ def transform_bulk_data_hook(data, typ, schema): # Salesforce Bulk API returns CSV's with empty strings for text fields. # When the text field is nillable and the data value is an empty string, # change the data so that it is None. 
- if data == "" and "null" in schema['type']: + if data == "" and "null" in schema.get('type', []): result = None return result @@ -342,7 +342,7 @@ def unwrap_query(query_response, query_field): else: query_response = sf.query(catalog_entry, state) - def process_record(rec, state): + def process_record(rec, state, schema): counter.increment() with Transformer(pre_hook=transform_bulk_data_hook) as transformer: rec = transformer.transform(rec, schema) @@ -391,7 +391,7 @@ def process_record(rec, state): pass with ThreadPoolExecutor() as executor: - futures = [executor.submit(process_record, rec, state) for rec in query_response] + futures = [executor.submit(process_record, rec, state, schema) for rec in query_response] for future in as_completed(futures): future.result() From 40ce1987ecefe5126677014de912cc700cb9b3f1 Mon Sep 17 00:00:00 2001 From: Renan Butkeraites Date: Mon, 23 Dec 2024 15:52:52 -0300 Subject: [PATCH 6/6] Enhance schema handling by introducing deep copies of catalog entries and schemas to prevent unintended mutations during processing. This change improves data integrity in sync operations and ensures that original catalog data remains unchanged while being manipulated. --- tap_salesforce/__init__.py | 7 +++++-- tap_salesforce/sync.py | 9 ++++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tap_salesforce/__init__.py b/tap_salesforce/__init__.py index 7c96bc75..518473a6 100644 --- a/tap_salesforce/__init__.py +++ b/tap_salesforce/__init__.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +import copy import json import sys import singer @@ -56,7 +57,8 @@ def stream_is_selected(mdata): def build_state(raw_state, catalog): state = {} - for catalog_entry in catalog['streams']: + for read_only_catalog_entry in catalog['streams']: + catalog_entry = copy.deepcopy(read_only_catalog_entry) tap_stream_id = catalog_entry['tap_stream_id'] catalog_metadata = metadata.to_map(catalog_entry['metadata']) replication_method = catalog_metadata.get((), {}).get('replication-method') @@ -423,7 +425,8 @@ def do_sync(sf, catalog, state,config=None): catalog["streams"] = list_view + catalog["streams"] # Sync Streams - for catalog_entry in catalog["streams"]: + for read_only_catalog_entry in catalog["streams"]: + catalog_entry = copy.deepcopy(read_only_catalog_entry) stream_version = get_stream_version(catalog_entry, state) stream = catalog_entry['stream'] stream_alias = catalog_entry.get('stream_alias') diff --git a/tap_salesforce/sync.py b/tap_salesforce/sync.py index e709fec8..2308093f 100644 --- a/tap_salesforce/sync.py +++ b/tap_salesforce/sync.py @@ -1,3 +1,4 @@ +import copy import time import re import singer @@ -64,7 +65,7 @@ def resume_syncing_bulk_query(sf, catalog_entry, job_id, state, counter): catalog_metadata = metadata.to_map(catalog_entry.get('metadata')) replication_key = next(iter(catalog_metadata.get((), {}).get('valid-replication-keys', [])), None) stream_version = get_stream_version(catalog_entry, state) - schema = catalog_entry['schema'] + schema_copy = copy.deepcopy(catalog_entry['schema']) if not bulk.job_exists(job_id): LOGGER.info("Found stored Job ID that no longer exists, resetting bookmark and removing JobID from state.") @@ -75,6 +76,7 @@ def resume_syncing_bulk_query(sf, catalog_entry, job_id, state, counter): with Transformer(pre_hook=transform_bulk_data_hook) as transformer: for rec in bulk.get_batch_results(job_id, batch_id, catalog_entry): counter.increment() + schema = copy.deepcopy(schema_copy) rec = transformer.transform(rec, schema) rec = 
fix_record_anytype(rec, schema) singer.write_message( @@ -208,7 +210,7 @@ def sync_records(sf, catalog_entry, state, input_state, counter, catalog, config download_files = config.get('download_files', False) chunked_bookmark = singer_utils.strptime_with_tz(sf.get_start_date(state, catalog_entry)) stream = catalog_entry['stream'].replace("/", "_") - schema = catalog_entry['schema'] + schema = copy.deepcopy(catalog_entry['schema']) stream_alias = catalog_entry.get('stream_alias') catalog_metadata = metadata.to_map(catalog_entry['metadata']) replication_key = next(iter(catalog_metadata.get((), {}).get('valid-replication-keys', [])), None) @@ -342,7 +344,8 @@ def unwrap_query(query_response, query_field): else: query_response = sf.query(catalog_entry, state) - def process_record(rec, state, schema): + def process_record(rec, state, read_only_schema): + schema = copy.deepcopy(read_only_schema) counter.increment() with Transformer(pre_hook=transform_bulk_data_hook) as transformer: rec = transformer.transform(rec, schema)