
155 add source and test file diffs in database logging (#156)
* Adding comments to top level CoverAgent class.

* Refactored out init runs from UnitTestGenerator.

* Added source file to DB.

* Added language.

* Refactored out DRY RUN and updated HTML template for DB insertion.

* Updated report template.

* Added report generation with poetry.

* Refactored unit test calls.

* Adding test analysis and summary report.

* Updated doc and fixed test analysis logic.
EmbeddedDevops1 authored Sep 15, 2024
1 parent 5df4c89 commit d894997
Showing 13 changed files with 470 additions and 333 deletions.
72 changes: 69 additions & 3 deletions cover_agent/CoverAgent.py
@@ -11,6 +11,15 @@

class CoverAgent:
def __init__(self, args):
"""
Initialize the CoverAgent class with the provided arguments.
Parameters:
args (Namespace): The parsed command-line arguments containing necessary information for test generation.
Returns:
None
"""
self.args = args
self.logger = CustomLogger.get_logger(__name__)

@@ -33,49 +42,99 @@ def __init__(self, args):
)

def _validate_paths(self):
"""
Validate the paths provided in the arguments.
Raises:
FileNotFoundError: If the source file or test file is not found at the specified paths.
"""
# Ensure the source file exists
if not os.path.isfile(self.args.source_file_path):
raise FileNotFoundError(
f"Source file not found at {self.args.source_file_path}"
)
# Ensure the test file exists
if not os.path.isfile(self.args.test_file_path):
raise FileNotFoundError(
f"Test file not found at {self.args.test_file_path}"
)
# Create default DB file if not provided
if not self.args.log_db_path:
# Create default DB file if not provided
self.args.log_db_path = "cover_agent_unit_test_runs.db"
# Connect to the test DB
self.test_db = UnitTestDB(db_connection_string=f"sqlite:///{self.args.log_db_path}")

def _duplicate_test_file(self):
"""
Initialize the CoverAgent class with the provided arguments and run the test generation process.
Parameters:
args (Namespace): The parsed command-line arguments containing necessary information for test generation.
Returns:
None
"""
# If the test file output path is set, copy the test file there
if self.args.test_file_output_path != "":
shutil.copy(self.args.test_file_path, self.args.test_file_output_path)
else:
# Otherwise, set the test file output path to the current test file
self.args.test_file_output_path = self.args.test_file_path

def run(self):
"""
Run the test generation process.
This method performs the following steps:
1. Initialize the Weights & Biases run if the WANDB_API_KEY environment variable is set.
2. Initialize variables to track progress.
3. Run the initial test suite analysis.
4. Loop until desired coverage is reached or maximum iterations are met.
5. Generate new tests.
6. Loop through each new test and validate it.
7. Insert the test result into the database.
8. Increment the iteration count.
9. Check if the desired coverage has been reached.
10. If the desired coverage has been reached, log the final coverage.
11. If the maximum iteration limit is reached, log a failure message if strict coverage is specified.
12. Provide metrics on total token usage.
13. Generate a report.
14. Finish the Weights & Biases run if it was initialized.
"""
# Check if the user has exported the WANDB_API_KEY environment variable
if "WANDB_API_KEY" in os.environ:
# Initialize the Weights & Biases run
wandb.login(key=os.environ["WANDB_API_KEY"])
time_and_date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
run_name = f"{self.args.model}_" + time_and_date
wandb.init(project="cover-agent", name=run_name)

# Initialize variables to track progress
iteration_count = 0
test_results_list = []

# Run initial test suite analysis
self.test_gen.get_coverage_and_build_prompt()
self.test_gen.initial_test_suite_analysis()

# Loop until desired coverage is reached or maximum iterations are met
while (
self.test_gen.current_coverage < (self.test_gen.desired_coverage / 100)
and iteration_count < self.args.max_iterations
):
# Log the current coverage
self.logger.info(
f"Current Coverage: {round(self.test_gen.current_coverage * 100, 2)}%"
)
self.logger.info(f"Desired Coverage: {self.test_gen.desired_coverage}%")

# Generate new tests
generated_tests_dict = self.test_gen.generate_tests(max_tokens=4096)

# Loop through each new test and validate it
for generated_test in generated_tests_dict.get("new_tests", []):
# Validate the test and record the result
test_result = self.test_gen.validate_test(
generated_test, self.args.run_tests_multiple_times
)
@@ -84,11 +143,15 @@ def run(self):
# Insert the test result into the database
self.test_db.insert_attempt(test_result)

# Increment the iteration count
iteration_count += 1

# Check if the desired coverage has been reached
if self.test_gen.current_coverage < (self.test_gen.desired_coverage / 100):
# Run the coverage tool again if the desired coverage hasn't been reached
self.test_gen.run_coverage()

# Log the final coverage
if self.test_gen.current_coverage >= (self.test_gen.desired_coverage / 100):
self.logger.info(
f"Reached above target coverage of {self.test_gen.desired_coverage}% (Current Coverage: {round(self.test_gen.current_coverage * 100, 2)}%) in {iteration_count} iterations."
@@ -102,15 +165,18 @@
else:
self.logger.info(failure_message)

# Provide metric on total token usage
# Provide metrics on total token usage
self.logger.info(
f"Total number of input tokens used for LLM model {self.test_gen.ai_caller.model}: {self.test_gen.total_input_token_count}"
)
self.logger.info(
f"Total number of output tokens used for LLM model {self.test_gen.ai_caller.model}: {self.test_gen.total_output_token_count}"
)

ReportGenerator.generate_report(test_results_list, self.args.report_filepath)
# Generate a report
# ReportGenerator.generate_report(test_results_list, self.args.report_filepath)
self.test_db.dump_to_report(self.args.report_filepath)

# Finish the Weights & Biases run if it was initialized
if "WANDB_API_KEY" in os.environ:
wandb.finish()
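
For context, here is a minimal sketch of how CoverAgent might be driven after this change. The field names come from the attributes this diff references (source_file_path, test_file_path, test_file_output_path, log_db_path, report_filepath, model, max_iterations, run_tests_multiple_times); desired_coverage and every sample value are assumptions, and the real CLI may accept more arguments.

from argparse import Namespace
from cover_agent.CoverAgent import CoverAgent

args = Namespace(
    source_file_path="app.py",            # file to cover (hypothetical)
    test_file_path="test_app.py",         # existing test suite (hypothetical)
    test_file_output_path="",             # empty string: edit the test file in place
    log_db_path="",                       # falsy: defaults to cover_agent_unit_test_runs.db
    report_filepath="test_results.html",  # target of dump_to_report()
    model="gpt-4o",                       # model name echoed in the token metrics (hypothetical)
    desired_coverage=80,                  # target percentage (assumed field name)
    max_iterations=3,                     # cap on the generate/validate loop
    run_tests_multiple_times=1,           # forwarded to validate_test()
)

agent = CoverAgent(args)  # validates paths and opens the SQLite log DB
agent.run()               # loops until the coverage target or max_iterations is hit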
95 changes: 86 additions & 9 deletions cover_agent/ReportGenerator.py
@@ -1,8 +1,9 @@
import difflib
from jinja2 import Template


class ReportGenerator:
# Enhanced HTML template with additional styling
# HTML template with fixed code formatting and dark background for the code block
HTML_TEMPLATE = """
<!DOCTYPE html>
<html lang="en">
@@ -39,10 +40,14 @@ class ReportGenerator:
color: red;
}
pre {
background-color: #000000 !important;
background-color: #282c34 !important;
color: #ffffff !important;
padding: 5px;
padding: 10px;
border-radius: 5px;
overflow-x: auto;
white-space: pre-wrap;
font-family: 'Courier New', Courier, monospace;
font-size: 1.1em; /* Slightly larger font size */
}
</style>
</head>
@@ -52,18 +57,31 @@ class ReportGenerator:
<th>Status</th>
<th>Reason</th>
<th>Exit Code</th>
<th>Stderr</th>
<th>Stdout</th>
<th>Test</th>
<th>Language</th>
<th>Modified Test File</th>
<th>Details</th>
</tr>
{% for result in results %}
<tr>
<td class="status-{{ result.status }}">{{ result.status }}</td>
<td>{{ result.reason }}</td>
<td>{{ result.exit_code }}</td>
<td>{% if result.stderr %}<pre><code class="language-shell">{{ result.stderr }}</code></pre>{% else %}&nbsp;{% endif %}</td>
<td>{% if result.stdout %}<pre><code class="language-shell">{{ result.stdout }}</code></pre>{% else %}&nbsp;{% endif %}</td>
<td>{% if result.test %}<pre><code class="language-python">{{ result.test }}</code></pre>{% else %}&nbsp;{% endif %}</td>
<td>{{ result.language }}</td>
<td>
<details>
<summary>View Full Code</summary>
<pre><code>{{ result.full_diff | safe }}</code></pre>
</details>
</td>
<td>
<details>
<summary>View More</summary>
<div><strong>STDERR:</strong> <pre><code class="language-{{ result.language|lower }}">{{ result.stderr }}</code></pre></div>
<div><strong>STDOUT:</strong> <pre><code class="language-{{ result.language|lower }}">{{ result.stdout }}</code></pre></div>
<div><strong>Test Code:</strong> <pre><code class="language-{{ result.language|lower }}">{{ result.test_code }}</code></pre></div>
<div><strong>Imports:</strong> <pre><code class="language-{{ result.language|lower }}">{{ result.imports }}</code></pre></div>
</details>
</td>
</tr>
{% endfor %}
</table>
@@ -72,6 +90,61 @@
</html>
"""

@classmethod
def generate_full_diff(cls, original, processed):
"""
Generates a full view of both the original and processed test files,
highlighting added, removed, and unchanged lines, showing the full code.
:param original: String content of the original test file.
:param processed: String content of the processed test file.
:return: Full diff string formatted for HTML display, highlighting added, removed, and unchanged lines.
"""
diff = difflib.ndiff(original.splitlines(), processed.splitlines())

diff_html = []
for line in diff:
if line.startswith('? '):
# Skip the intraline hint lines that ndiff emits alongside changed lines
continue
elif line.startswith('+'):
diff_html.append(f'<span class="diff-added">{line}</span>')
elif line.startswith('-'):
diff_html.append(f'<span class="diff-removed">{line}</span>')
else:
diff_html.append(f'<span class="diff-unchanged">{line}</span>')
return '\n'.join(diff_html)
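
# Example with hypothetical two-line inputs:
#   generate_full_diff("a\nb", "a\nc")
# ndiff walks every line of both versions, so this returns:
#   <span class="diff-unchanged">  a</span>
#   <span class="diff-removed">- b</span>
#   <span class="diff-added">+ c</span>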

@classmethod
def generate_partial_diff(cls, original, processed, context_lines=3):
"""
Generates a partial diff of both the original and processed test files,
showing only added, removed, or changed lines with a few lines of context.
:param original: String content of the original test file.
:param processed: String content of the processed test file.
:param context_lines: Number of context lines to include around changes.
:return: Partial diff string formatted for HTML display, highlighting added, removed, and unchanged lines with context.
"""
# Use unified_diff to generate a partial diff with context
diff = difflib.unified_diff(
original.splitlines(),
processed.splitlines(),
n=context_lines
)

diff_html = []
for line in diff:
if line.startswith('+++') or line.startswith('---'):
# Skip the file header lines that unified_diff emits
continue
elif line.startswith('+'):
diff_html.append(f'<span class="diff-added">{line}</span>')
elif line.startswith('-'):
diff_html.append(f'<span class="diff-removed">{line}</span>')
elif line.startswith('@@'):
# Highlight the diff context (line numbers)
diff_html.append(f'<span class="diff-context">{line}</span>')
else:
# Show unchanged lines as context
diff_html.append(f'<span class="diff-unchanged">{line}</span>')

return '\n'.join(diff_html)
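
# Example with hypothetical inputs:
#   generate_partial_diff("a\nb\nc", "a\nB\nc")
# unified_diff emits only the changed hunk plus context, so this returns roughly:
#   <span class="diff-context">@@ -1,3 +1,3 @@</span>
#   <span class="diff-unchanged"> a</span>
#   <span class="diff-removed">-b</span>
#   <span class="diff-added">+B</span>
#   <span class="diff-unchanged"> c</span>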

@classmethod
def generate_report(cls, results, file_path):
"""
@@ -80,6 +153,10 @@ def generate_report(cls, results, file_path):
:param results: List of dictionaries with test results.
:param file_path: Path to the HTML file where the report will be written.
"""
# Generate the full diff for each result
for result in results:
result['full_diff'] = cls.generate_full_diff(result['original_test_file'], result['processed_test_file'])

template = Template(cls.HTML_TEMPLATE)
html_content = template.render(results=results)

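A minimal usage sketch for the updated ReportGenerator. The dict keys mirror the fields that generate_report and the HTML template reference in this diff, while every sample value is hypothetical; note that CoverAgent.run() now writes its report through UnitTestDB.dump_to_report() instead, so a direct call like this is illustrative only.

from cover_agent.ReportGenerator import ReportGenerator

results = [{
    "status": "pass",
    "reason": "",
    "exit_code": 0,
    "stderr": "",
    "stdout": "1 passed in 0.02s",
    "test_code": "def test_add():\n    assert add(1, 2) == 3\n",
    "imports": "import pytest",
    "language": "python",
    "original_test_file": "# original suite\n",
    "processed_test_file": "# original suite\ndef test_add():\n    assert add(1, 2) == 3\n",
}]

# generate_report() derives full_diff from the two file versions,
# renders the Jinja2 template, and writes the HTML report to the given path.
ReportGenerator.generate_report(results, "test_results.html")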
