-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* DSPy instrumentation basic * Fix * Fix * remove hardcodings * Bump version
- Loading branch information
1 parent
aa60509
commit 4316b4c
Showing
12 changed files
with
530 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import dspy | ||
from dspy.datasets.gsm8k import GSM8K, gsm8k_metric | ||
from dspy.teleprompt import BootstrapFewShot | ||
|
||
# flake8: noqa | ||
from langtrace_python_sdk import langtrace, with_langtrace_root_span | ||
|
||
langtrace.init() | ||
|
||
turbo = dspy.OpenAI(model="gpt-3.5-turbo", max_tokens=250) | ||
dspy.settings.configure(lm=turbo) | ||
|
||
# Load math questions from the GSM8K dataset | ||
gsm8k = GSM8K() | ||
gsm8k_trainset, gsm8k_devset = gsm8k.train[:10], gsm8k.dev[:10] | ||
|
||
|
||
class CoT(dspy.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.prog = dspy.ChainOfThought("question -> answer") | ||
|
||
def forward(self, question): | ||
return self.prog(question=question) | ||
|
||
|
||
@with_langtrace_root_span(name="math_problems_cot_example") | ||
def example(): | ||
|
||
# Set up the optimizer: we want to "bootstrap" (i.e., self-generate) 4-shot examples of our CoT program. | ||
config = dict(max_bootstrapped_demos=4, max_labeled_demos=4) | ||
|
||
# Optimize! Use the `gsm8k_metric` here. In general, the metric is going to tell the optimizer how well it's doing. | ||
teleprompter = BootstrapFewShot(metric=gsm8k_metric, **config) | ||
optimized_cot = teleprompter.compile(CoT(), trainset=gsm8k_trainset) | ||
|
||
ans = optimized_cot(question="What is the sqrt of 345?") | ||
print(ans) | ||
|
||
|
||
if __name__ == "__main__": | ||
example() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import dspy | ||
|
||
# flake8: noqa | ||
from langtrace_python_sdk import langtrace, with_langtrace_root_span | ||
|
||
langtrace.init() | ||
|
||
turbo = dspy.OpenAI(model="gpt-3.5-turbo", max_tokens=250) | ||
dspy.settings.configure(lm=turbo) | ||
|
||
|
||
# Define a simple signature for basic question answering | ||
class BasicQA(dspy.Signature): | ||
"""Answer questions with short factoid answers.""" | ||
|
||
question = dspy.InputField() | ||
answer = dspy.OutputField(desc="often between 1 and 5 words") | ||
|
||
|
||
@with_langtrace_root_span(name="pot_example") | ||
def example(): | ||
|
||
# Pass signature to ProgramOfThought Module | ||
pot = dspy.ProgramOfThought(BasicQA) | ||
|
||
# Call the ProgramOfThought module on a particular input | ||
question = "Sarah has 5 apples. She buys 7 more apples from the store. How many apples does Sarah have now?" | ||
result = pot(question=question) | ||
|
||
print(f"Question: {question}") | ||
print(f"Final Predicted Answer (after ProgramOfThought process): {result.answer}") | ||
|
||
|
||
if __name__ == "__main__": | ||
example() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
import dspy | ||
import json | ||
from dspy.datasets import HotPotQA | ||
from dspy.teleprompt import BootstrapFewShot | ||
from dspy.evaluate.evaluate import Evaluate | ||
|
||
# flake8: noqa | ||
from langtrace_python_sdk import langtrace, with_langtrace_root_span | ||
|
||
langtrace.init() | ||
|
||
|
||
colbertv2_wiki17_abstracts = dspy.ColBERTv2( | ||
url="http://20.102.90.50:2017/wiki17_abstracts" | ||
) | ||
dspy.settings.configure(rm=colbertv2_wiki17_abstracts) | ||
turbo = dspy.OpenAI(model="gpt-3.5-turbo-0613", max_tokens=500) | ||
dspy.settings.configure(lm=turbo, trace=[], temperature=0.7) | ||
|
||
dataset = HotPotQA( | ||
train_seed=1, | ||
train_size=300, | ||
eval_seed=2023, | ||
dev_size=300, | ||
test_size=0, | ||
keep_details=True, | ||
) | ||
trainset = [x.with_inputs("question", "answer") for x in dataset.train] | ||
devset = [x.with_inputs("question", "answer") for x in dataset.dev] | ||
|
||
|
||
class GenerateAnswerChoices(dspy.Signature): | ||
"""Generate answer choices in JSON format that include the correct answer and plausible distractors for the specified question.""" | ||
|
||
question = dspy.InputField() | ||
correct_answer = dspy.InputField() | ||
number_of_choices = dspy.InputField() | ||
answer_choices = dspy.OutputField(desc="JSON key-value pairs") | ||
|
||
|
||
class QuizAnswerGenerator(dspy.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.prog = dspy.ChainOfThought(GenerateAnswerChoices) | ||
|
||
def forward(self, question, answer): | ||
choices = self.prog( | ||
question=question, correct_answer=answer, number_of_choices="4" | ||
).answer_choices | ||
# dspy.Suggest( | ||
# format_checker(choices), | ||
# "The format of the answer choices should be in JSON format. Please revise accordingly.", | ||
# target_module=GenerateAnswerChoices, | ||
# ) | ||
return dspy.Prediction(choices=choices) | ||
|
||
|
||
def format_checker(choice_string): | ||
try: | ||
choices = json.loads(choice_string) | ||
if isinstance(choices, dict) and all( | ||
isinstance(key, str) and isinstance(value, str) | ||
for key, value in choices.items() | ||
): | ||
return True | ||
except json.JSONDecodeError: | ||
return False | ||
|
||
return False | ||
|
||
|
||
def format_valid_metric(gold, pred, trace=None): | ||
generated_choices = pred.choices | ||
format_valid = format_checker(generated_choices) | ||
score = format_valid | ||
return score | ||
|
||
|
||
@with_langtrace_root_span(name="quiz_generator_1") | ||
def quiz_generator_1(): | ||
quiz_generator = QuizAnswerGenerator() | ||
|
||
example = devset[67] | ||
print("Example Question: ", example.question) | ||
print("Example Answer: ", example.answer) | ||
# quiz_choices = quiz_generator(question=example.question, answer=example.answer) | ||
# print("Generated Quiz Choices: ", quiz_choices.choices) | ||
|
||
optimizer = BootstrapFewShot( | ||
metric=format_valid_metric, max_bootstrapped_demos=4, max_labeled_demos=4 | ||
) | ||
compiled_quiz_generator = optimizer.compile( | ||
quiz_generator, | ||
trainset=trainset, | ||
) | ||
quiz_choices = compiled_quiz_generator( | ||
question=example.question, answer=example.answer | ||
) | ||
print("Generated Quiz Choices: ", quiz_choices.choices) | ||
|
||
# Evaluate | ||
evaluate = Evaluate( | ||
metric=format_valid_metric, | ||
devset=devset[67:70], | ||
num_threads=1, | ||
display_progress=True, | ||
display_table=5, | ||
) | ||
evaluate(quiz_generator) | ||
|
||
|
||
if __name__ == "__main__": | ||
quiz_generator_1() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import sys | ||
import os | ||
import dspy | ||
|
||
# Add the local src folder to the Python path | ||
sys.path.insert(0, os.path.abspath('/Users/karthikkalyanaraman/work/langtrace/langtrace-python-sdk/src')) | ||
|
||
# flake8: noqa | ||
from langtrace_python_sdk import langtrace, with_langtrace_root_span | ||
langtrace.init() | ||
|
||
turbo = dspy.OpenAI(model='gpt-3.5-turbo', max_tokens=250) | ||
dspy.settings.configure(lm=turbo) | ||
|
||
colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts') | ||
dspy.settings.configure(rm=colbertv2_wiki17_abstracts) | ||
retriever = dspy.Retrieve(k=3) | ||
|
||
# Define a simple signature for basic question answering | ||
class BasicQA(dspy.Signature): | ||
"""Answer questions with short factoid answers.""" | ||
question = dspy.InputField() | ||
answer = dspy.OutputField(desc="often between 1 and 5 words") | ||
|
||
@with_langtrace_root_span(name="react_example") | ||
def example(): | ||
|
||
# Pass signature to ReAct module | ||
react_module = dspy.ReAct(BasicQA) | ||
|
||
# Call the ReAct module on a particular input | ||
question = 'Aside from the Apple Remote, what other devices can control the program Apple Remote was originally designed to interact with?' | ||
result = react_module(question=question) | ||
|
||
print(f"Question: {question}") | ||
print(f"Final Predicted Answer (after ReAct process): {result.answer}") | ||
|
||
if __name__ == '__main__': | ||
example() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .instrumentation import DspyInstrumentor | ||
|
||
__all__ = ["DspyInstrumentor"] |
85 changes: 85 additions & 0 deletions
85
src/langtrace_python_sdk/instrumentation/dspy/instrumentation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
""" | ||
Copyright (c) 2024 Scale3 Labs | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor | ||
from opentelemetry.trace import get_tracer | ||
from wrapt import wrap_function_wrapper as _W | ||
from typing import Collection | ||
from importlib_metadata import version as v | ||
from .patch import patch_bootstrapfewshot_optimizer, patch_signature, patch_evaluate | ||
|
||
|
||
class DspyInstrumentor(BaseInstrumentor): | ||
""" | ||
The DspyInstrumentor class represents the DSPy instrumentation""" | ||
|
||
def instrumentation_dependencies(self) -> Collection[str]: | ||
return ["dspy >= 0.1.5"] | ||
|
||
def _instrument(self, **kwargs): | ||
tracer_provider = kwargs.get("tracer_provider") | ||
tracer = get_tracer(__name__, "", tracer_provider) | ||
version = v("dspy") | ||
_W( | ||
"dspy.teleprompt.bootstrap", | ||
"BootstrapFewShot.compile", | ||
patch_bootstrapfewshot_optimizer( | ||
"BootstrapFewShot.compile", version, tracer | ||
), | ||
) | ||
_W( | ||
"dspy.predict.predict", | ||
"Predict.forward", | ||
patch_signature("Predict.forward", version, tracer), | ||
) | ||
_W( | ||
"dspy.predict.chain_of_thought", | ||
"ChainOfThought.forward", | ||
patch_signature("ChainOfThought.forward", version, tracer), | ||
) | ||
_W( | ||
"dspy.predict.chain_of_thought_with_hint", | ||
"ChainOfThoughtWithHint.forward", | ||
patch_signature("ChainOfThoughtWithHint.forward", version, tracer), | ||
) | ||
_W( | ||
"dspy.predict.react", | ||
"ReAct.forward", | ||
patch_signature("ReAct.forward", version, tracer), | ||
) | ||
_W( | ||
"dspy.predict.program_of_thought", | ||
"ProgramOfThought.forward", | ||
patch_signature("ProgramOfThought.forward", version, tracer), | ||
) | ||
_W( | ||
"dspy.predict.multi_chain_comparison", | ||
"MultiChainComparison.forward", | ||
patch_signature("MultiChainComparison.forward", version, tracer), | ||
) | ||
_W( | ||
"dspy.predict.retry", | ||
"Retry.forward", | ||
patch_signature("Retry.forward", version, tracer), | ||
) | ||
_W( | ||
"dspy.evaluate.evaluate", | ||
"Evaluate.__call__", | ||
patch_evaluate("Evaluate", version, tracer), | ||
) | ||
|
||
def _uninstrument(self, **kwargs): | ||
pass |
Oops, something went wrong.