From 0289211f72428e7aeb5b64ff6e5cbdcb9d354977 Mon Sep 17 00:00:00 2001
From: Chris Polcyn
Date: Mon, 31 Oct 2016 21:31:45 -0500
Subject: [PATCH] added debug version of Apex test runner

---
 cumulusci/tasks/salesforce.py            | 368 ++++++++++++++++++++---
 cumulusci/tasks/tests/test_salesforce.py | 112 ++++++-
 2 files changed, 426 insertions(+), 54 deletions(-)

diff --git a/cumulusci/tasks/salesforce.py b/cumulusci/tasks/salesforce.py
index fc02298f45..21764b19d2 100644
--- a/cumulusci/tasks/salesforce.py
+++ b/cumulusci/tasks/salesforce.py
@@ -1,6 +1,8 @@
 import base64
 import cgi
+import datetime
 import io
+import json
 import logging
 import os
 import tempfile
@@ -68,6 +70,10 @@ class BaseSalesforceToolingApiTask(BaseSalesforceApiTask):
     def _init_task(self):
         self.tooling = self._init_api()
         self.tooling.base_url += 'tooling/'
+        self._init_class()
+
+    def _init_class(self):
+        pass
 
     def _get_tooling_object(self, obj_name):
         obj = getattr(self.tooling, obj_name)
@@ -181,6 +187,32 @@ class RunApexTests(BaseSalesforceToolingApiTask):
         },
     }
 
+    def _init_class(self):
+        self.classes_by_id = {}
+        self.classes_by_name = {}
+        self.job_id = None
+        self.results_by_class_name = {}
+        self._debug_init_class()
+
+    # These are overridden in the debug version
+    def _debug_init_class(self):
+        pass
+
+    def _debug_get_duration_class(self, class_id):
+        pass
+
+    def _debug_get_duration_method(self, result):
+        pass
+
+    def _debug_get_logs(self):
+        pass
+
+    def _debug_get_results(self, result):
+        pass
+
+    def _debug_create_trace_flag(self):
+        pass
+
     def _decode_to_unicode(self, content):
         if content:
             try:
@@ -221,34 +253,44 @@ def _get_test_classes(self):
         self.logger.info('Found {} test classes'.format(result['totalSize']))
         return result
 
-    def _get_test_results(self, job_id, classes_by_id, results_by_class_name,
-            classes_by_name):
+    def _get_test_results(self):
         result = self.tooling.query_all("SELECT StackTrace, Message, " +
             "ApexLogId, AsyncApexJobId, MethodName, Outcome, ApexClassId, " +
             "TestTimestamp FROM ApexTestResult " +
-            "WHERE AsyncApexJobId = '{}'".format(job_id))
+            "WHERE AsyncApexJobId = '{}'".format(self.job_id))
         counts = {
             'Pass': 0,
             'Fail': 0,
             'CompileFail': 0,
             'Skip': 0,
         }
-        for record in result['records']:
-            class_name = classes_by_id[record['ApexClassId']]
-            results_by_class_name[class_name][record['MethodName']] = record
-            counts[record['Outcome']] += 1
+        for test_result in result['records']:
+            class_name = self.classes_by_id[test_result['ApexClassId']]
+            self.results_by_class_name[class_name][test_result[
+                'MethodName']] = test_result
+            counts[test_result['Outcome']] += 1
+            self._debug_get_results(test_result)
+        self._debug_get_logs()
         test_results = []
-        class_names = results_by_class_name.keys()
+        class_names = self.results_by_class_name.keys()
         class_names.sort()
         for class_name in class_names:
-            class_id = classes_by_name[class_name]
-            duration = None
-            self.logger.info(u'Class: {}'.format(class_name))
-            method_names = results_by_class_name[class_name].keys()
+            class_id = self.classes_by_name[class_name]
+            message = 'Class: {}'.format(class_name)
+            duration = self._debug_get_duration_class(class_id)
+            if duration:
+                message += ' ({}s)'.format(duration)
+            self.logger.info(message)
+            method_names = self.results_by_class_name[class_name].keys()
             method_names.sort()
             for method_name in method_names:
-                result = results_by_class_name[class_name][method_name]
-                self.logger.info(u'\t{Outcome}: {MethodName}'.format(**result))
+                result = self.results_by_class_name[class_name][method_name]
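+                # Report each method's outcome; a duration is appended
+                # below only when debug log stats exist for the method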
+                message = '\t{}: {}'.format(result['Outcome'],
+                    result['MethodName'])
+                duration = self._debug_get_duration_method(result)
+                if duration:
+                    message += ' ({}s)'.format(duration)
+                self.logger.info(message)
                 test_results.append({
                     'Children': result.get('children', None),
                     'ClassName': self._decode_to_unicode(class_name),
                     'Method': self._decode_to_unicode(method_name),
                     'Message': self._decode_to_unicode(result['Message']),
                     'Outcome': self._decode_to_unicode(result['Outcome']),
                     'StackTrace': self._decode_to_unicode(
                         result['StackTrace']),
                     'Stats': result.get('stats', None),
                     'TestTimestamp': result.get('TestTimestamp', None),
                 })
                 if result['Outcome'] in ['Fail', 'CompileFail']:
-                    self.logger.info(u'\tMessage: {Message}'.format(**result))
-                    self.logger.info(u'\tStackTrace: {StackTrace}'.format(
-                        **result))
-        self.logger.info(u'-' * 80)
-        self.logger.info(u'Pass: {} Fail: {} CompileFail: {} Skip: {}'
+                    self.logger.info('\tMessage: {}'.format(result['Message']))
+                    self.logger.info('\tStackTrace: {}'.format(
+                        result['StackTrace']))
+        self.logger.info('-' * 80)
+        self.logger.info('Pass: {} Fail: {} CompileFail: {} Skip: {}'
             .format(
                 counts['Pass'],
                 counts['Fail'],
                 counts['CompileFail'],
                 counts['Skip'],
             ))
-        self.logger.info(u'-' * 80)
+        self.logger.info('-' * 80)
         if counts['Fail'] or counts['CompileFail']:
-            self.logger.info(u'-' * 80)
-            self.logger.info(u'Failing Tests')
-            self.logger.info(u'-' * 80)
+            self.logger.info('-' * 80)
+            self.logger.info('Failing Tests')
+            self.logger.info('-' * 80)
             counter = 0
             for result in test_results:
                 if result['Outcome'] not in ['Fail', 'CompileFail']:
                     continue
                 counter += 1
-                self.logger.info(u'{}: {}.{} - {}'.format(counter,
+                self.logger.info('{}: {}.{} - {}'.format(counter,
                     result['ClassName'], result['Method'],
                     result['Outcome']))
-                self.logger.info(u'\tMessage: {}'.format(result['Message']))
-                self.logger.info(u'\tStackTrace: {}'.format(
+                self.logger.info('\tMessage: {}'.format(result['Message']))
+                self.logger.info('\tStackTrace: {}'.format(
                     result['StackTrace']))
         return test_results
 
@@ -293,31 +335,25 @@ def _run_task(self):
         result = self._get_test_classes()
         if result['totalSize'] == 0:
             return
-        classes_by_id = {}
-        classes_by_name = {}
-        trace_id = None
-        results_by_class_name = {}
-        classes_by_log_id = {}
-        logs_by_class_id = {}
-        for record in result['records']:
-            classes_by_id[record['Id']] = record['Name']
-            classes_by_name[record['Name']] = record['Id']
-            results_by_class_name[record['Name']] = {}
+        for test_class in result['records']:
+            self.classes_by_id[test_class['Id']] = test_class['Name']
+            self.classes_by_name[test_class['Name']] = test_class['Id']
+            self.results_by_class_name[test_class['Name']] = {}
+        self._debug_create_trace_flag()
         self.logger.info('Queuing tests for execution...')
-        ids = classes_by_id.keys()
-        job_id = self.tooling.restful('runTestsAsynchronous',
+        ids = self.classes_by_id.keys()
+        self.job_id = self.tooling.restful('runTestsAsynchronous',
             params={'classids': ','.join(str(id) for id in ids)})
-        self._wait_for_tests(job_id)
-        test_results = self._get_test_results(
-            job_id, classes_by_id, results_by_class_name, classes_by_name)
+        self._wait_for_tests()
+        test_results = self._get_test_results()
         self._write_output(test_results)
 
-    def _wait_for_tests(self, job_id):
+    def _wait_for_tests(self):
         poll_interval = int(self.options.get('poll_interval', 1))
         while True:
             result = self.tooling.query_all(
                 "SELECT Id, Status, ApexClassId FROM ApexTestQueueItem " +
-                "WHERE ParentJobId = '{}'".format(job_id))
+                "WHERE ParentJobId = '{}'".format(self.job_id))
             counts = {
                 'Aborted': 0,
                 'Completed': 0,
                 'Failed': 0,
                 'Holding': 0,
                 'Preparing': 0,
                 'Processing': 0,
                 'Queued': 0,
             }
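+            # Tally queue item statuses so each poll can log progress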
-            for record in result['records']:
-                counts[record['Status']] += 1
+            for test_queue_item in result['records']:
+                counts[test_queue_item['Status']] += 1
             self.logger.info('Completed: {} Processing: {} Queued: {}'
                 .format(
                     counts['Completed'],
                     counts['Processing'],
                     counts['Queued'],
                 ))
             if counts['Queued'] == 0 and counts['Processing'] == 0:
                 break
             sleep(poll_interval)
@@ -341,11 +377,11 @@ def _wait_for_tests(self, job_id):
 
     def _write_output(self, test_results):
-        filename = self.options['junit_output']
-        with io.open(filename, mode='w', encoding='utf-8') as f:
+        results_filename = self.options['results_filename']
+        with io.open(results_filename, mode='w', encoding='utf-8') as f:
             f.write(u'<testsuite tests="{}">\n'.format(len(test_results)))
             for result in test_results:
                 s = u'  <testcase classname="{}" name="{}"'.format(
                     result['ClassName'], result['Method'])
+
+
+class RunApexTestsDebug(RunApexTests):
+    """Run Apex tests and collect debug info"""
+
+    def _debug_init_class(self):
+        self.classes_by_log_id = {}
+        self.logs_by_class_id = {}
+        self.tooling.TraceFlag.base_url = (
+            'https://{}/services/data/v{}/tooling/sobjects/{}/'.format(
+            self.tooling.sf_instance, self.tooling.sf_version, 'TraceFlag'))
+        self.trace_id = None
+
+    def _debug_create_trace_flag(self):
+        """Create a TraceFlag for a given user."""
+        self._delete_trace_flags()
+        self.logger.info('Setting up trace flag to capture debug logs')
+        # New TraceFlag expires 12 hours from now
+        expiration_date = (datetime.datetime.now() +
+            datetime.timedelta(seconds=60*60*12))
+        result = self.tooling.TraceFlag.create({
+            'ApexCode': 'Info',
+            'ApexProfiling': 'Debug',
+            'Callout': 'Info',
+            'Database': 'Info',
+            'ExpirationDate': expiration_date.isoformat(),
+            'System': 'Info',
+            'TracedEntityId': self.org_config.user_id,
+            'Validation': 'Info',
+            'Visualforce': 'Info',
+            'Workflow': 'Info',
+        })
+        self.trace_id = result['id']
+        self.logger.info('Created TraceFlag for user')
+
+    def _delete_trace_flags(self):
+        """Delete existing TraceFlags."""
+        self.logger.info('Deleting existing TraceFlags')
+        traceflags = self.tooling.query('Select Id from TraceFlag')
+        if traceflags['totalSize']:
+            for traceflag in traceflags['records']:
+                self.tooling.TraceFlag.delete(str(traceflag['Id']))
+
+    def _debug_get_duration_class(self, class_id):
+        if class_id in self.logs_by_class_id:
+            return int(self.logs_by_class_id[class_id][
+                'DurationMilliseconds']) * .001
+
+    def _debug_get_duration_method(self, result):
+        if result.get('stats') and 'duration' in result['stats']:
+            return result['stats']['duration']
+
+    def _debug_get_logs(self):
+        log_ids = "('{}')".format(
+            "','".join(str(id) for id in self.classes_by_log_id.keys()))
+        result = self.tooling.query_all('SELECT Id, Application, ' +
+            'DurationMilliseconds, Location, LogLength, LogUserId, ' +
+            'Operation, Request, StartTime, Status ' +
+            'from ApexLog where Id in {}'.format(log_ids))
+        for log in result['records']:
+            class_id = self.classes_by_log_id[log['Id']]
+            class_name = self.classes_by_id[class_id]
+            self.logs_by_class_id[class_id] = log
+            body_url = '{}sobjects/ApexLog/{}/Body'.format(
+                self.tooling.base_url, log['Id'])
+            response = self.tooling.request.get(body_url,
+                headers=self.tooling.headers)
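+            # Save the raw log body as <ClassName>.log (optionally under
+            # debug_log_dir), then re-read it to parse per-method stats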
+            log_file = class_name + '.log'
+            debug_log_dir = self.options.get('debug_log_dir')
+            if debug_log_dir:
+                log_file = os.path.join(debug_log_dir, log_file)
+            with io.open(log_file, mode='w', encoding='utf-8') as f:
+                f.write(unicode(response.content))
+            with io.open(log_file, mode='r', encoding='utf-8') as f:
+                method_stats = self._parse_log(class_name, f)
+            # Add method stats to results_by_class_name
+            for method, info in method_stats.items():
+                self.results_by_class_name[class_name][method].update(info)
+        # Delete the TraceFlag
+        self.tooling.TraceFlag.delete(str(self.trace_id))
+
+    def _debug_get_results(self, result):
+        if result['ApexLogId']:
+            self.classes_by_log_id[result['ApexLogId']] = result['ApexClassId']
+
+    def _log_time_delta(self, start, end):
+        """
+        Returns the difference in seconds between two debug log timestamps
+        in the format HH:MM:SS.mmm.
+        """
+        dummy_date = datetime.date(2001, 1, 1)
+        dummy_date_next = datetime.date(2001, 1, 2)
+        # Split out the parts of the start and end string
+        start_parts = re.split(':|\.', start)
+        start_parts = [int(part) for part in start_parts]
+        start_parts[3] = start_parts[3] * 1000
+        t_start = datetime.time(*start_parts)
+        end_parts = re.split(':|\.', end)
+        end_parts = [int(part) for part in end_parts]
+        end_parts[3] = end_parts[3] * 1000
+        t_end = datetime.time(*end_parts)
+        # Combine with dummy date to do date math
+        d_start = datetime.datetime.combine(dummy_date, t_start)
+        # If end was on the next day, attach to next dummy day
+        if start_parts[0] > end_parts[0]:
+            d_end = datetime.datetime.combine(dummy_date_next, t_end)
+        else:
+            d_end = datetime.datetime.combine(dummy_date, t_end)
+        delta = d_end - d_start
+        return delta.total_seconds()
+
+    def _parse_log(self, class_name, f):
+        """Parse an Apex test log."""
+        class_name = self._decode_to_unicode(class_name)
+        methods = {}
+        for method, stats, children in self._parse_log_by_method(class_name,
+                f):
+            methods[method] = {'stats': stats, 'children': children}
+        return methods
+
+    def _parse_log_by_method(self, class_name, f):
+        """Parse an Apex test log by method."""
+        stats = {}
+        last_stats = {}
+        in_limits = False
+        in_cumulative_limits = False
+        in_testing_limits = False
+        unit = None
+        method = None
+        children = {}
+        parent = None
+        for line in f:
+            line = self._decode_to_unicode(line).strip()
+            if '|CODE_UNIT_STARTED|[EXTERNAL]|' in line:
+                unit, unit_type, unit_info = self._parse_unit_started(
+                    class_name, line)
+                if unit_type == 'test_method':
+                    method = self._decode_to_unicode(unit)
+                    method_unit_info = unit_info
+                    children = []
+                    stack = []
+                else:
+                    stack.append({
+                        'unit': unit,
+                        'unit_type': unit_type,
+                        'unit_info': unit_info,
+                        'stats': {},
+                        'children': [],
+                    })
+                continue
+            if '|CUMULATIVE_LIMIT_USAGE' in line and 'USAGE_END' not in line:
+                in_cumulative_limits = True
+                in_testing_limits = False
+                continue
+            if '|TESTING_LIMITS' in line:
+                in_testing_limits = True
+                in_cumulative_limits = False
+                continue
+            if '|LIMIT_USAGE_FOR_NS|(default)|' in line:
+                # Parse the start of the limits section
+                in_limits = True
+                continue
+            if in_limits and ':' not in line:
+                # Parse the end of the limits section
+                in_limits = False
+                in_cumulative_limits = False
+                in_testing_limits = False
+                continue
+            if in_limits:
+                # Parse the limit name, used, and allowed values
+                limit, value = line.split(': ')
+                if in_testing_limits:
+                    limit = 'TESTING_LIMITS: {}'.format(limit)
+                used, allowed = value.split(' out of ')
+                stats[limit] = {'used': used, 'allowed': allowed}
+                continue
+            if '|CODE_UNIT_FINISHED|{}.{}'.format(class_name, method) in line:
+                # Handle the finish of test methods
+                end_timestamp = line.split(' ')[0]
+                stats['duration'] = self._log_time_delta(
+                    method_unit_info['start_timestamp'], end_timestamp)
+                # Yield the stats for the method
+                yield method, stats, children
+                last_stats = stats.copy()
+                stats = {}
+                in_cumulative_limits = False
+                in_limits = False
+            elif '|CODE_UNIT_FINISHED|' in line:
+                # Handle all other code units finishing
+                end_timestamp = line.split(' ')[0]
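+                # Note: this duration is measured from the test method's
+                # start timestamp, not from the child unit's own start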
+                stats['duration'] = self._log_time_delta(
+                    method_unit_info['start_timestamp'], end_timestamp)
+                try:
+                    child = stack.pop()
+                except:
+                    # Skip if there was no stack. This seems to have
+                    # started in Spring 16 where the debug log will contain
+                    # CODE_UNIT_FINISHED lines which have no matching
+                    # CODE_UNIT_STARTED from earlier in the file.
+                    continue
+                child['stats'] = stats
+                if not stack:
+                    # Add the child to the main children list
+                    children.append(child)
+                else:
+                    # Add this child to its parent
+                    stack[-1]['children'].append(child)
+                stats = {}
+                in_cumulative_limits = False
+                in_limits = False
+            if '* MAXIMUM DEBUG LOG SIZE REACHED *' in line:
+                # If debug log size limit was reached, fail gracefully
+                break
+
+    def _parse_unit_started(self, class_name, line):
+        unit = line.split('|')[-1]
+        unit_type = 'other'
+        unit_info = {}
+        if unit.startswith(class_name + '.'):
+            unit_type = 'test_method'
+            unit = unit.split('.')[-1]
+        elif 'trigger event' in unit:
+            unit_type = 'trigger'
+            unit, obj, event = re.match(
+                r'(.*) on (.*) trigger event (.*) for.*', unit).groups()
+            unit_info = {'event': event, 'object': obj}
+        # Add the start timestamp to unit_info
+        unit_info['start_timestamp'] = line.split(' ')[0]
+        return unit, unit_type, unit_info
+
+    def _write_output(self, test_results):
+        results_filename = self.options['results_filename']
+        with io.open(results_filename, mode='w', encoding='utf-8') as f:
+            f.write(unicode(json.dumps(test_results)))
diff --git a/cumulusci/tasks/tests/test_salesforce.py b/cumulusci/tasks/tests/test_salesforce.py
index a5deaad002..c3c69e9ac3 100644
--- a/cumulusci/tasks/tests/test_salesforce.py
+++ b/cumulusci/tasks/tests/test_salesforce.py
@@ -10,6 +10,8 @@ from cumulusci.core.keychain import BaseProjectKeychain
 from cumulusci.tasks.salesforce import BaseSalesforceToolingApiTask
 from cumulusci.tasks.salesforce import RunApexTests
+from cumulusci.tasks.salesforce import RunApexTestsDebug
+
 
 class TestBaseSalesforceToolingApiTask(unittest.TestCase):
 
@@ -38,6 +40,7 @@ def test_get_tooling_object(self):
         url = self.base_tooling_url + 'sobjects/TestObject/'
         self.assertEqual(obj.base_url, url)
 
+
 class TestRunApexTests(unittest.TestCase):
 
     def setUp(self):
@@ -46,7 +49,7 @@ def setUp(self):
             {'project': {'api_version': self.api_version}})
         self.task_config = TaskConfig()
         self.task_config.config['options'] = {
-            'junit_output': 'test_junit_output.txt',
+            'results_filename': self._get_results_filename(),
             'poll_interval': 1,
             'test_name_match': '%_TEST',
         }
@@ -56,18 +59,39 @@ def setUp(self):
         keychain = BaseProjectKeychain(self.project_config, '')
         self.project_config.set_keychain(keychain)
         self.org_config = OrgConfig({
+            'id': 'foo/1',
             'instance_url': 'example.com',
             'access_token': 'abc123',
         })
         self.base_tooling_url = 'https://{}/services/data/v{}/tooling/'.format(
             self.org_config.instance_url, self.api_version)
 
+    def _get_results_filename(self):
+        return 'results_junit.xml'
+
     def _mock_apex_class_query(self):
-        url = self.base_tooling_url + 'query/'
+        url = (self.base_tooling_url + 'query/?q=SELECT+Id%2C+Name+' +
+            'FROM+ApexClass+WHERE+NamespacePrefix+%3D+null' +
+            '+AND+%28Name+LIKE+%27%25_TEST%27%29')
+        expected_response = {
+            'done': True,
+            'records': [{'Id': 1, 'Name': 'TestClass_TEST'}],
+            'totalSize': 1,
+        }
+        responses.add(responses.GET, url, match_querystring=True,
+            json=expected_response)
+
@@ -76,22 +100,100 @@ def _mock_apex_class_query(self):
+    def _mock_get_test_results(self):
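+        # The AsyncApexJobId below is the str() of the OrderedDict that the
+        # mocked runTestsAsynchronous response deserializes to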
+        url = (self.base_tooling_url + 'query/?q=SELECT+StackTrace%2C+' +
+            'Message%2C+ApexLogId%2C+AsyncApexJobId%2C+MethodName%2C+' +
+            'Outcome%2C+ApexClassId%2C+TestTimestamp+FROM+ApexTestResult+' +
+            'WHERE+AsyncApexJobId+%3D+%27OrderedDict%28%5B%28u%27' +
+            'foo%27%2C+u%27bar%27%29%5D%29%27')
         expected_response = {
             'done': True,
             'records': [{
                 'ApexClassId': 1,
+                'ApexLogId': 1,
                 'Id': 1,
                 'Message': 'Test passed',
                 'MethodName': 'TestMethod',
                 'Name': 'TestClass_TEST',
                 'Outcome': 'Pass',
                 'StackTrace': '1. ParentFunction\n2. ChildFunction',
                 'Status': 'Completed',
             }],
-            'totalSize': 1,
         }
-        responses.add(responses.GET, url, json=expected_response)
+        responses.add(responses.GET, url, match_querystring=True,
+            json=expected_response)
+
+    def _mock_tests_complete(self):
+        url = (self.base_tooling_url + 'query/?q=SELECT+Id%2C+Status%2C+' +
+            'ApexClassId+FROM+ApexTestQueueItem+WHERE+ParentJobId+%3D+%27' +
+            'OrderedDict%28%5B%28u%27foo%27%2C+u%27bar%27%29%5D%29%27')
+        expected_response = {
+            'done': True,
+            'records': [{'Status': 'Completed'}],
+        }
+        responses.add(responses.GET, url, match_querystring=True,
+            json=expected_response)
 
     def _mock_run_tests(self):
         url = self.base_tooling_url + 'runTestsAsynchronous'
+        expected_response = {'foo': 'bar'}
+        responses.add(responses.GET, url, json=expected_response)
+
+    @responses.activate
+    def test_run_task(self):
+        self._mock_apex_class_query()
+        self._mock_run_tests()
+        self._mock_tests_complete()
+        self._mock_get_test_results()
+        task = RunApexTests(
+            self.project_config, self.task_config, self.org_config)
+        with patch.object(OrgConfig, 'refresh_oauth_token'):
+            task()
+        self.assertEqual(len(responses.calls), 4)
+
+
+class TestRunApexTestsDebug(TestRunApexTests):
+
+    def _get_results_filename(self):
+        return 'results.json'
+
+    def _mock_create_trace_flag(self):
+        url = self.base_tooling_url.replace('/tooling/',
+            '/sobjects/TraceFlag/')
+        expected_response = {
+            'id': 1,
+        }
+        responses.add(responses.POST, url, json=expected_response)
+
+    def _mock_delete_trace_flags(self):
+        url = self.base_tooling_url.replace('/tooling/',
+            '/sobjects/TraceFlag/1')
+        responses.add(responses.DELETE, url)
+
+    def _mock_get_duration(self):
+        url = (self.base_tooling_url + 'query/?q=SELECT+Id%2C+' +
+            'Application%2C+DurationMilliseconds%2C+Location%2C+LogLength%2C' +
+            '+LogUserId%2C+Operation%2C+Request%2C+StartTime%2C+Status+' +
+            'from+ApexLog+where+Id+in+%28%271%27%29')
+        expected_response = {
+            'done': True,
+            'records': [{'Id': 1, 'DurationMilliseconds': 1}],
+            'totalSize': 1,
+        }
+        responses.add(responses.GET, url, match_querystring=True,
+            json=expected_response)
+
+    def _mock_get_log_body(self):
+        url = self.base_tooling_url + 'sobjects/ApexLog/1/Body'
         expected_response = {
             'foo': 'bar',
         }
         responses.add(responses.GET, url, json=expected_response)
 
+    def _mock_get_trace_flags(self):
+        url = self.base_tooling_url + 'query/?q=Select+Id+from+TraceFlag'
+        expected_response = {
+            'records': [{'Id': 1}],
+            'totalSize': 1,
+        }
+        responses.add(responses.GET, url, match_querystring=True,
+            json=expected_response)
+
     @responses.activate
     def test_run_task(self):
         self._mock_apex_class_query()
+        self._mock_get_trace_flags()
+        self._mock_delete_trace_flags()
+        self._mock_create_trace_flag()
         self._mock_run_tests()
-        task = RunApexTests(
+        self._mock_tests_complete()
+        self._mock_get_test_results()
+        self._mock_get_duration()
+        self._mock_get_log_body()
+        self._mock_delete_trace_flags()
+        task = RunApexTestsDebug(
             self.project_config, self.task_config, self.org_config)
         with patch.object(OrgConfig, 'refresh_oauth_token'):
             task()
+        self.assertEqual(len(responses.calls), 10)