-
Notifications
You must be signed in to change notification settings - Fork 44.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #372 from BillSchumacher/redis-backend
Implement Local Cache and Redis Memory backend. Removes dependence on Pinecone
- Loading branch information
Showing
10 changed files
with
391 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from memory.local import LocalCache | ||
try: | ||
from memory.redismem import RedisMemory | ||
except ImportError: | ||
print("Redis not installed. Skipping import.") | ||
RedisMemory = None | ||
|
||
try: | ||
from memory.pinecone import PineconeMemory | ||
except ImportError: | ||
print("Pinecone not installed. Skipping import.") | ||
PineconeMemory = None | ||
|
||
|
||
def get_memory(cfg, init=False): | ||
memory = None | ||
if cfg.memory_backend == "pinecone": | ||
if not PineconeMemory: | ||
print("Error: Pinecone is not installed. Please install pinecone" | ||
" to use Pinecone as a memory backend.") | ||
else: | ||
memory = PineconeMemory(cfg) | ||
if init: | ||
memory.clear() | ||
elif cfg.memory_backend == "redis": | ||
if not RedisMemory: | ||
print("Error: Redis is not installed. Please install redis-py to" | ||
" use Redis as a memory backend.") | ||
else: | ||
memory = RedisMemory(cfg) | ||
|
||
if memory is None: | ||
memory = LocalCache(cfg) | ||
if init: | ||
memory.clear() | ||
return memory | ||
|
||
|
||
__all__ = [ | ||
"get_memory", | ||
"LocalCache", | ||
"RedisMemory", | ||
"PineconeMemory", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
"""Base class for memory providers.""" | ||
import abc | ||
from config import AbstractSingleton | ||
import openai | ||
|
||
|
||
def get_ada_embedding(text): | ||
text = text.replace("\n", " ") | ||
return openai.Embedding.create(input=[text], model="text-embedding-ada-002")["data"][0]["embedding"] | ||
|
||
|
||
class MemoryProviderSingleton(AbstractSingleton): | ||
@abc.abstractmethod | ||
def add(self, data): | ||
pass | ||
|
||
@abc.abstractmethod | ||
def get(self, data): | ||
pass | ||
|
||
@abc.abstractmethod | ||
def clear(self): | ||
pass | ||
|
||
@abc.abstractmethod | ||
def get_relevant(self, data, num_relevant=5): | ||
pass | ||
|
||
@abc.abstractmethod | ||
def get_stats(self): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import dataclasses | ||
import orjson | ||
from typing import Any, List, Optional | ||
import numpy as np | ||
import os | ||
from memory.base import MemoryProviderSingleton, get_ada_embedding | ||
|
||
|
||
EMBED_DIM = 1536 | ||
SAVE_OPTIONS = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SERIALIZE_DATACLASS | ||
|
||
|
||
def create_default_embeddings(): | ||
return np.zeros((0, EMBED_DIM)).astype(np.float32) | ||
|
||
|
||
@dataclasses.dataclass | ||
class CacheContent: | ||
texts: List[str] = dataclasses.field(default_factory=list) | ||
embeddings: np.ndarray = dataclasses.field( | ||
default_factory=create_default_embeddings | ||
) | ||
|
||
|
||
class LocalCache(MemoryProviderSingleton): | ||
|
||
# on load, load our database | ||
def __init__(self, cfg) -> None: | ||
self.filename = f"{cfg.memory_index}.json" | ||
if os.path.exists(self.filename): | ||
with open(self.filename, 'rb') as f: | ||
loaded = orjson.loads(f.read()) | ||
self.data = CacheContent(**loaded) | ||
else: | ||
self.data = CacheContent() | ||
|
||
def add(self, text: str): | ||
""" | ||
Add text to our list of texts, add embedding as row to our | ||
embeddings-matrix | ||
Args: | ||
text: str | ||
Returns: None | ||
""" | ||
if 'Command Error:' in text: | ||
return "" | ||
self.data.texts.append(text) | ||
|
||
embedding = get_ada_embedding(text) | ||
|
||
vector = np.array(embedding).astype(np.float32) | ||
vector = vector[np.newaxis, :] | ||
self.data.embeddings = np.concatenate( | ||
[ | ||
vector, | ||
self.data.embeddings, | ||
], | ||
axis=0, | ||
) | ||
|
||
with open(self.filename, 'wb') as f: | ||
out = orjson.dumps( | ||
self.data, | ||
option=SAVE_OPTIONS | ||
) | ||
f.write(out) | ||
return text | ||
|
||
def clear(self) -> str: | ||
""" | ||
Clears the redis server. | ||
Returns: A message indicating that the memory has been cleared. | ||
""" | ||
self.data = CacheContent() | ||
return "Obliviated" | ||
|
||
def get(self, data: str) -> Optional[List[Any]]: | ||
""" | ||
Gets the data from the memory that is most relevant to the given data. | ||
Args: | ||
data: The data to compare to. | ||
Returns: The most relevant data. | ||
""" | ||
return self.get_relevant(data, 1) | ||
|
||
def get_relevant(self, text: str, k: int) -> List[Any]: | ||
"""" | ||
matrix-vector mult to find score-for-each-row-of-matrix | ||
get indices for top-k winning scores | ||
return texts for those indices | ||
Args: | ||
text: str | ||
k: int | ||
Returns: List[str] | ||
""" | ||
embedding = get_ada_embedding(text) | ||
|
||
scores = np.dot(self.data.embeddings, embedding) | ||
|
||
top_k_indices = np.argsort(scores)[-k:][::-1] | ||
|
||
return [self.data.texts[i] for i in top_k_indices] | ||
|
||
def get_stats(self): | ||
""" | ||
Returns: The stats of the local cache. | ||
""" | ||
return len(self.data.texts), self.data.embeddings.shape |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.