run black formatting #9

Merged
1 commit merged on Aug 4, 2023
2 changes: 1 addition & 1 deletion rl_chain/__init__.py
@@ -1,4 +1,4 @@
from . import pick_best_chain
from . import pick_best_prompt
from . import slates_chain
from .rl_chain_base import Embed, ResponseValidator, Embedder
from .rl_chain_base import Embed, ResponseValidator, Embedder
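For orientation, the re-exports above mean the package surface can be exercised roughly as follows. This is a sketch, not part of the PR; it relies only on the names re-exported in this __init__.py.

# Sketch only: uses just the names re-exported by rl_chain/__init__.py.
from rl_chain import Embed, Embedder, ResponseValidator, pick_best_chain

marked = Embed("dark roast")  # wrap a value so downstream code embeds it rather than passing raw text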
54 changes: 31 additions & 23 deletions rl_chain/pick_best_chain.py
@@ -1,4 +1,4 @@
from . import rl_chain_base as base
from . import rl_chain_base as base

from langchain.prompts import (
ChatPromptTemplate,
@@ -19,6 +19,7 @@
from langchain.chains.llm import LLMChain
from sentence_transformers import SentenceTransformer


class ContextualBanditTextEmbedder(base.Embedder):
"""
Contextual Bandit Text Embedder class that embeds the context and actions into a format that can be used by VW
@@ -55,7 +56,9 @@ def embed_context(self, context: Any):
"""
return base.embed(context, self.model, "Context")

def to_vw_format(self, inputs :Dict[str, Any], cb_label: Optional[Tuple]=None) -> str:
def to_vw_format(
self, inputs: Dict[str, Any], cb_label: Optional[Tuple] = None
) -> str:
"""
Converts the context and actions into a format that can be used by VW

@@ -71,14 +74,16 @@ def to_vw_format(self, inputs :Dict[str, Any], cb_label: Optional[Tuple]=None) -
if cb_label:
chosen_action, cost, prob = cb_label

context = inputs.get('context')
actions = inputs.get('actions')
context = inputs.get("context")
actions = inputs.get("actions")

context_emb = self.embed_context(context) if context else None
action_embs = self.embed_actions(actions) if actions else None

if not context_emb or not action_embs:
raise ValueError("Context and actions must be provided in the inputs dictionary")
raise ValueError(
"Context and actions must be provided in the inputs dictionary"
)

example_string = ""
example_string += f"shared "
@@ -96,6 +101,7 @@ def to_vw_format(self, inputs :Dict[str, Any], cb_label: Optional[Tuple]=None) -
# Strip the last newline
return example_string[:-1]


class AutoValidatePickBest(base.ResponseValidator):
llm_chain: LLMChain
prompt: PromptTemplate
@@ -107,7 +113,7 @@ def __init__(self, llm, prompt=None):
template = "PLEASE RESPOND ONLY WITH A SIGNLE FLOAT AND NO OTHER TEXT EXPLANATION\n You are a strict judge that is called on to rank a response based on given criteria.\
You must respond with your ranking by providing a single float within the range [-1, 1], -1 being very bad response and 1 being very good response."
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
human_template = "Given this context \"{context}\" as the most important attribute, rank how good or bad this text selection is: \"{selected}\"."
human_template = 'Given this context "{context}" as the most important attribute, rank how good or bad this text selection is: "{selected}".'
human_message_prompt = HumanMessagePromptTemplate.from_template(
human_template
)
@@ -119,7 +125,9 @@ def __init__(self, llm, prompt=None):

self.llm_chain = LLMChain(llm=llm, prompt=self.prompt)

def grade_response(self, inputs: Dict[str, Any], llm_response: str, **kwargs) -> float:
def grade_response(
self, inputs: Dict[str, Any], llm_response: str, **kwargs
) -> float:
inputs["llm_response"] = llm_response
inputs["selected"] = inputs["selected"]
ranking = self.llm_chain.predict(**inputs)
@@ -196,11 +204,11 @@ def __init__(self, *args, **kwargs):
raise ValueError(
"If vw_cmd is specified, it must include --cb_explore_adf"
)

kwargs["vw_cmd"] = vw_cmd

super().__init__(*args, **kwargs)

@property
def input_keys(self) -> List[str]:
"""Expect input key.
@@ -248,7 +256,7 @@ def _call(
sampled_prob = sampled_ap[1]

pred_action = actions[sampled_action]
inputs['selected'] = pred_action
inputs["selected"] = pred_action

llm_resp: Dict[str, Any] = super()._call(run_manager=run_manager, inputs=inputs)

@@ -259,14 +267,13 @@ def _call(
if self.response_validator:
try:
cost = -1.0 * self.response_validator.grade_response(
inputs=inputs,
llm_response=llm_resp[self.output_key] )
inputs=inputs, llm_response=llm_resp[self.output_key]
)
latest_cost = cost
cb_label = (sampled_action, cost, sampled_prob)

vw_ex = self.text_embedder.to_vw_format(
cb_label=cb_label,
inputs=inputs,
cb_label=cb_label, inputs=inputs
)
self._learn(vw_ex)

@@ -294,13 +301,15 @@ def _chain_type(self) -> str:
return "llm_personalizer_chain"

@classmethod
def from_chain(cls, llm_chain: Chain, prompt: PromptTemplate = PROMPT, **kwargs: Any):
return PickBest(
llm_chain=llm_chain, prompt=prompt, **kwargs
)
def from_chain(
cls, llm_chain: Chain, prompt: PromptTemplate = PROMPT, **kwargs: Any
):
return PickBest(llm_chain=llm_chain, prompt=prompt, **kwargs)

@classmethod
def from_llm(cls, llm: BaseLanguageModel, prompt: PromptTemplate = PROMPT, **kwargs: Any):
def from_llm(
cls, llm: BaseLanguageModel, prompt: PromptTemplate = PROMPT, **kwargs: Any
):
llm_chain = LLMChain(llm=llm, prompt=prompt)
return PickBest.from_chain(llm_chain=llm_chain, prompt=prompt, **kwargs)

@@ -317,15 +326,14 @@ def learn_delayed_reward(
raise RuntimeError(
"check_response is set to True, this must be turned off for explicit feedback and training to be provided, or overriden by calling the method with force_reward=True"
)
cost = -1.0 * reward
cost = -1.0 * reward
cb_label = (
response_result.chosen_action,
cost,
response_result.chosen_action_probability,
)

vw_ex = self.text_embedder.to_vw_format(
cb_label=cb_label,
inputs=response_result.inputs,
cb_label=cb_label, inputs=response_result.inputs
)
self._learn(vw_ex)
self._learn(vw_ex)
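For context on what this file's chain does after the reformatting, a rough usage sketch follows. It is not part of the PR: the LLM choice is arbitrary, and the input keys (context, actions, text_to_personalize) are inferred from to_vw_format and the prompt's input_variables, so treat them as assumptions.

# Sketch only: assumes an OpenAI-backed chat model and the PickBest chain from this diff.
from langchain.chat_models import ChatOpenAI
from rl_chain.pick_best_chain import PickBest

chain = PickBest.from_llm(llm=ChatOpenAI(temperature=0))

# to_vw_format reads "context" and "actions" from the inputs, so a call
# presumably looks something like this (all values are illustrative):
response = chain.run(
    context="user prefers strong, dark coffee",
    actions=["cold brew", "latte", "decaf drip"],
    text_to_personalize="Our drink of the day is",
)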
3 changes: 1 addition & 2 deletions rl_chain/pick_best_prompt.py
@@ -8,6 +8,5 @@


PROMPT = PromptTemplate(
input_variables=["selected", "text_to_personalize"],
template=_PROMPT_TEMPLATE,
input_variables=["selected", "text_to_personalize"], template=_PROMPT_TEMPLATE
)
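As a quick illustration (not part of the PR), the prompt above can be rendered directly; _PROMPT_TEMPLATE's text is not shown in this diff, so the variable values here are purely hypothetical.

# Sketch only: _PROMPT_TEMPLATE is defined earlier in this file and not visible in the diff.
from rl_chain.pick_best_prompt import PROMPT

filled = PROMPT.format(
    selected="cold brew",                        # hypothetical chosen action
    text_to_personalize="Our drink of the day",  # hypothetical text to rewrite
)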
46 changes: 33 additions & 13 deletions rl_chain/rl_chain_base.py
@@ -25,28 +25,33 @@
ch.setLevel(logging.INFO)
logger.addHandler(ch)


class _Embed:
def __init__(self, impl):
self.impl = impl

def __str__(self):
return self.impl


def Embed(anything):
if isinstance(anything, list):
return [Embed(v) for v in anything]
elif isinstance(anything, dict):
return {k: _Embed(v) for k, v in anything.items()}
return _Embed(anything)


def parse_lines(parser: vw.TextFormatParser, input_str: str) -> List[vw.Example]:
return [parser.parse_line(line) for line in input_str.split("\n")]


class Embedder(ABC):
@abstractmethod
def to_vw_format(self, **kwargs) -> str:
pass


class ResponseValidator(ABC):
"""Abstract method to grade the chosen action or the response of the llm"""

@@ -56,6 +61,7 @@ def grade_response(
) -> float:
pass


class RLChain(Chain):
"""
RLChain class that utilizes the Vowpal Wabbit (VW) model for personalization.
@@ -156,10 +162,7 @@ def _call(
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()

t = self.llm_chain.run(
**inputs,
callbacks=_run_manager.get_child(),
)
t = self.llm_chain.run(**inputs, callbacks=_run_manager.get_child())
_run_manager.on_text(t, color="green", verbose=self.verbose)
t = t.strip()

@@ -201,11 +204,17 @@ def _learn(self, vw_ex):
multi_ex = parse_lines(text_parser, vw_ex)
self.workspace.learn_one(multi_ex)


def is_stringtype_instance(item: Any) -> bool:
"""Helper function to check if an item is a string."""
return isinstance(item, str) or (isinstance(item, _Embed) and isinstance(item.impl, str))
return isinstance(item, str) or (
isinstance(item, _Embed) and isinstance(item.impl, str)
)

def embed_string_type(item: Union[str, _Embed], model: Any, namespace: Optional[str] = None) -> Dict[str, str]:

def embed_string_type(
item: Union[str, _Embed], model: Any, namespace: Optional[str] = None
) -> Dict[str, str]:
"""Helper function to embed a string or an _Embed object."""
join_char = ""
if isinstance(item, _Embed):
@@ -218,11 +227,14 @@ def embed_string_type(item: Union[str, _Embed], model: Any, namespace: Optional[
raise ValueError(f"Unsupported type {type(item)} for embedding")

if namespace is None:
raise ValueError("The default namespace must be provided when embedding a string or _Embed object.")

raise ValueError(
"The default namespace must be provided when embedding a string or _Embed object."
)

return {namespace: join_char.join(map(str, encoded))}

def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str,List[str]]]:

def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]:
"""Helper function to embed a dictionary item."""
inner_dict = {}
for ns, embed_item in item.items():
@@ -235,7 +247,10 @@ def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str,List[str]]]:
inner_dict.update(embed_string_type(embed_item, model, ns))
return inner_dict

def embed_list_type(item: list, model: Any, namespace: Optional[str] = None) -> List[Dict[str, Union[str, List[str]]]]:

def embed_list_type(
item: list, model: Any, namespace: Optional[str] = None
) -> List[Dict[str, Union[str, List[str]]]]:
ret_list = []
for embed_item in item:
if isinstance(embed_item, dict):
@@ -244,11 +259,14 @@ def embed_list_type(item: list, model: Any, namespace: Optional[str] = None) ->
ret_list.append(embed_string_type(embed_item, model, namespace))
return ret_list


def embed(
to_embed: Union[Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict]],
to_embed: Union[
Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict]
],
model: Any,
namespace: Optional[str] = None,
) -> List[Dict[str, Union[str,List[str]]]]:
) -> List[Dict[str, Union[str, List[str]]]]:
"""
Embeds the actions or context using the SentenceTransformer model

@@ -259,7 +277,9 @@ def embed(
Returns:
List[Dict[str, str]]: A list of dictionaries where each dictionary has the namespace as the key and the embedded string as the value
"""
if (isinstance(to_embed, _Embed) and isinstance(to_embed.impl, str)) or isinstance(to_embed, str):
if (isinstance(to_embed, _Embed) and isinstance(to_embed.impl, str)) or isinstance(
to_embed, str
):
return [embed_string_type(to_embed, model, namespace)]
elif isinstance(to_embed, dict):
return [embed_dict_type(to_embed, model)]
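To see how the Embed marker and the embed helper above fit together, here is a rough sketch. It is not part of the PR; the model name is an arbitrary sentence-transformers checkpoint, and the exact text produced for the non-embedded case depends on code below the truncated diff.

# Sketch only: assumes a SentenceTransformer model and the helpers shown in this diff.
from sentence_transformers import SentenceTransformer
from rl_chain.rl_chain_base import Embed, embed

model = SentenceTransformer("all-MiniLM-L6-v2")  # arbitrary checkpoint for illustration

# A plain string is kept as text under the given namespace; wrapping it in
# Embed() asks embed_string_type to encode it with the model instead.
as_text = embed("cold brew", model, namespace="Action")
as_vector = embed(Embed("cold brew"), model, namespace="Action")
# Both return a list of {namespace: value} dicts, per the docstring above.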