Skip to content

Commit

Permalink
ai21[patch]: AI21 Labs Contextual Answers support (langchain-ai#18270)
Browse files Browse the repository at this point in the history
Description: Added support for AI21 Labs model - Contextual Answers
Dependencies: ai21, ai21-tokenizer
Twitter handle: https://github.com/AI21Labs

---------

Co-authored-by: Asaf Gardin <[email protected]>
Co-authored-by: Erick Friis <[email protected]>
  • Loading branch information
3 people authored and gkorland committed Mar 30, 2024
1 parent 1407542 commit 43a2b4a
Show file tree
Hide file tree
Showing 10 changed files with 662 additions and 230 deletions.
97 changes: 85 additions & 12 deletions docs/docs/integrations/llms/ai21.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"id": "59c710c4",
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-05T20:58:42.397591Z",
"start_time": "2024-03-05T20:58:40.944729Z"
}
},
"outputs": [],
"source": [
"!pip install -qU langchain-ai21"
Expand All @@ -46,10 +51,14 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "035dea0f",
"metadata": {
"tags": []
"tags": [],
"ExecuteTime": {
"end_time": "2024-03-05T20:58:44.465443Z",
"start_time": "2024-03-05T20:58:42.399724Z"
}
},
"outputs": [],
"source": [
Expand All @@ -74,14 +83,16 @@
"execution_count": 6,
"id": "98f70927a87e4745",
"metadata": {
"collapsed": false
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-03-05T20:58:45.859265Z",
"start_time": "2024-03-05T20:58:44.466637Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"'\\nLangChain is a decentralized blockchain network that leverages AI and machine learning to provide language translation services.'"
]
"text/plain": "'\\nLangChain is a (database)\\nLangChain is a database for storing and processing documents'"
},
"execution_count": 6,
"metadata": {},
Expand All @@ -105,13 +116,75 @@
"chain.invoke({\"question\": \"What is LangChain?\"})"
]
},
{
"cell_type": "markdown",
"source": [
"# AI21 Contextual Answer\n",
"\n",
"You can use AI21's contextual answers model to receives text or document, serving as a context,\n",
"and a question and returns an answer based entirely on this context.\n",
"\n",
"This means that if the answer to your question is not in the document,\n",
"the model will indicate it (instead of providing a false answer)"
],
"metadata": {
"collapsed": false
},
"id": "9965c10269159ed1"
},
{
"cell_type": "code",
"outputs": [],
"source": [
"from langchain_ai21 import AI21ContextualAnswers\n",
"\n",
"tsm = AI21ContextualAnswers()\n",
"\n",
"response = tsm.invoke(input={\"context\": \"Your context\", \"question\": \"Your question\"})"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-03-05T20:59:00.943426Z",
"start_time": "2024-03-05T20:59:00.263497Z"
}
},
"id": "411adf42eab80829",
"execution_count": 9
},
{
"cell_type": "markdown",
"source": [
"You can also use it with chains and output parsers and vector DBs"
],
"metadata": {
"collapsed": false
},
"id": "af59ffdbf4964875"
},
{
"cell_type": "code",
"execution_count": null,
"id": "a52f765c",
"metadata": {},
"outputs": [],
"source": []
"source": [
"from langchain_ai21 import AI21ContextualAnswers\n",
"from langchain_core.output_parsers import StrOutputParser\n",
"\n",
"tsm = AI21ContextualAnswers()\n",
"chain = tsm | StrOutputParser()\n",
"\n",
"response = chain.invoke(\n",
" {\"context\": \"Your context\", \"question\": \"Your question\"},\n",
")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-03-05T20:59:07.719225Z",
"start_time": "2024-03-05T20:59:07.102950Z"
}
},
"id": "bc63830f921b4ac9",
"execution_count": 10
}
],
"metadata": {
Expand Down
30 changes: 30 additions & 0 deletions libs/partners/ai21/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,33 @@ from langchain_ai21 import AI21Embeddings
embeddings = AI21Embeddings()
embeddings.embed_documents(["Hello! This is document 1", "And this is document 2!"])
```

## Task Specific Models

### Contextual Answers

You can use AI21's contextual answers model to receives text or document, serving as a context,
and a question and returns an answer based entirely on this context.

This means that if the answer to your question is not in the document,
the model will indicate it (instead of providing a false answer)

```python
from langchain_ai21 import AI21ContextualAnswers

tsm = AI21ContextualAnswers()

response = tsm.invoke(input={"context": "Your context", "question": "Your question"})
```
You can also use it with chains and output parsers and vector DBs:
```python
from langchain_ai21 import AI21ContextualAnswers
from langchain_core.output_parsers import StrOutputParser

tsm = AI21ContextualAnswers()
chain = tsm | StrOutputParser()

response = chain.invoke(
{"context": "Your context", "question": "Your question"},
)
```
2 changes: 2 additions & 0 deletions libs/partners/ai21/langchain_ai21/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from langchain_ai21.chat_models import ChatAI21
from langchain_ai21.contextual_answers import AI21ContextualAnswers
from langchain_ai21.embeddings import AI21Embeddings
from langchain_ai21.llms import AI21LLM

__all__ = [
"AI21LLM",
"ChatAI21",
"AI21Embeddings",
"AI21ContextualAnswers",
]
108 changes: 108 additions & 0 deletions libs/partners/ai21/langchain_ai21/contextual_answers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from typing import (
Any,
List,
Optional,
Tuple,
Type,
TypedDict,
Union,
)

from langchain_core.documents import Document
from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config

from langchain_ai21.ai21_base import AI21Base

ANSWER_NOT_IN_CONTEXT_RESPONSE = "Answer not in context"

ContextType = Union[str, List[Union[Document, str]]]


class ContextualAnswerInput(TypedDict):
context: ContextType
question: str


class AI21ContextualAnswers(RunnableSerializable[ContextualAnswerInput, str], AI21Base):
class Config:
"""Configuration for this pydantic object."""

arbitrary_types_allowed = True

@property
def InputType(self) -> Type[ContextualAnswerInput]:
"""Get the input type for this runnable."""
return ContextualAnswerInput

@property
def OutputType(self) -> Type[str]:
"""Get the input type for this runnable."""
return str

def invoke(
self,
input: ContextualAnswerInput,
config: Optional[RunnableConfig] = None,
response_if_no_answer_found: str = ANSWER_NOT_IN_CONTEXT_RESPONSE,
**kwargs: Any,
) -> str:
config = ensure_config(config)
return self._call_with_config(
func=lambda inner_input: self._call_contextual_answers(
inner_input, response_if_no_answer_found
),
input=input,
config=config,
run_type="llm",
)

def _call_contextual_answers(
self,
input: ContextualAnswerInput,
response_if_no_answer_found: str,
) -> str:
context, question = self._convert_input(input)
response = self.client.answer.create(context=context, question=question)

if response.answer is None:
return response_if_no_answer_found

return response.answer

def _convert_input(self, input: ContextualAnswerInput) -> Tuple[str, str]:
context, question = self._extract_context_and_question(input)

context = self._parse_context(context)

return context, question

def _extract_context_and_question(
self,
input: ContextualAnswerInput,
) -> Tuple[ContextType, str]:
context = input.get("context")
question = input.get("question")

if not context or not question:
raise ValueError(
f"Input must contain a 'context' and 'question' fields. Got {input}"
)

if not isinstance(context, list) and not isinstance(context, str):
raise ValueError(
f"Expected input to be a list of strings or Documents."
f" Received {type(input)}"
)

return context, question

def _parse_context(self, context: ContextType) -> str:
if isinstance(context, str):
return context

docs = [
item.page_content if isinstance(item, Document) else item
for item in context
]

return "\n".join(docs)
Loading

0 comments on commit 43a2b4a

Please sign in to comment.