Skip to content

Commit

Permalink
feat: add index to avoid performance degradation (#1606)
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahwooders authored Aug 7, 2024
2 parents 50840e9 + d1f400e commit bbe37fc
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
4 changes: 4 additions & 0 deletions memgpt/agent_store/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
JSON,
Column,
DateTime,
Index,
String,
TypeDecorator,
and_,
Expand Down Expand Up @@ -163,6 +164,8 @@ class PassageModel(Base):
# Add a datetime column, with default value as the current time
created_at = Column(DateTime(timezone=True))

Index("passage_idx_user", user_id, agent_id, doc_id),

def __repr__(self):
return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"

Expand Down Expand Up @@ -228,6 +231,7 @@ class MessageModel(Base):

# Add a datetime column, with default value as the current time
created_at = Column(DateTime(timezone=True))
Index("message_idx_user", user_id, agent_id),

def __repr__(self):
return f"<Message(message_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
Expand Down
10 changes: 10 additions & 0 deletions memgpt/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Boolean,
Column,
DateTime,
Index,
String,
TypeDecorator,
create_engine,
Expand Down Expand Up @@ -146,6 +147,9 @@ class TokenModel(Base):
# extra (optional) metadata
name = Column(String)

Index(__tablename__ + "_idx_user", user_id),
Index(__tablename__ + "_idx_token", token),

def __repr__(self) -> str:
return f"<Token(id='{self.id}', token='{self.token}', name='{self.name}')>"

Expand Down Expand Up @@ -189,6 +193,8 @@ class AgentModel(Base):
# tools
tools = Column(JSON)

Index(__tablename__ + "_idx_user", user_id),

def __repr__(self) -> str:
return f"<Agent(id='{self.id}', name='{self.name}')>"

Expand Down Expand Up @@ -222,6 +228,7 @@ class SourceModel(Base):
embedding_dim = Column(BIGINT)
embedding_model = Column(String)
description = Column(String)
Index(__tablename__ + "_idx_user", user_id),

# TODO: add num passages

Expand Down Expand Up @@ -249,6 +256,7 @@ class AgentSourceMappingModel(Base):
user_id = Column(CommonUUID, nullable=False)
agent_id = Column(CommonUUID, nullable=False)
source_id = Column(CommonUUID, nullable=False)
Index(__tablename__ + "_idx_user", user_id, agent_id, source_id),

def __repr__(self) -> str:
return f"<AgentSourceMapping(user_id='{self.user_id}', agent_id='{self.agent_id}', source_id='{self.source_id}')>"
Expand All @@ -261,6 +269,7 @@ class PresetSourceMapping(Base):
user_id = Column(CommonUUID, nullable=False)
preset_id = Column(CommonUUID, nullable=False)
source_id = Column(CommonUUID, nullable=False)
Index(__tablename__ + "_idx_user", user_id, preset_id, source_id),

def __repr__(self) -> str:
return f"<PresetSourceMapping(user_id='{self.user_id}', preset_id='{self.preset_id}', source_id='{self.source_id}')>"
Expand Down Expand Up @@ -298,6 +307,7 @@ class PresetModel(Base):
created_at = Column(DateTime(timezone=True), server_default=func.now())

functions_schema = Column(JSON)
Index(__tablename__ + "_idx_user", user_id),

def __repr__(self) -> str:
return f"<Preset(id='{self.id}', name='{self.name}')>"
Expand Down
8 changes: 4 additions & 4 deletions memgpt/models/pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class ToolModel(SQLModel, table=True):
json_schema: Dict = Field(default_factory=dict, sa_column=Column(JSON), description="The JSON schema of the function.")

# optional: user_id (user-specific tools)
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user associated with the function.")
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user associated with the function.", index=True)

# Needed for Column(JSON)
class Config:
Expand Down Expand Up @@ -128,14 +128,14 @@ class HumanModel(SQLModel, table=True):
text: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human text.")
name: str = Field(..., description="The name of the human.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the human.", primary_key=True)
user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the human.")
user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the human.", index=True)


class PersonaModel(SQLModel, table=True):
text: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona text.")
name: str = Field(..., description="The name of the persona.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the persona.", primary_key=True)
user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the persona.")
user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the persona.", index=True)


class SourceModel(SQLModel, table=True):
Expand Down Expand Up @@ -167,7 +167,7 @@ class JobModel(SQLModel, table=True):
status: JobStatus = Field(default=JobStatus.created, description="The status of the job.", sa_column=Column(ChoiceType(JobStatus)))
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the job was created.")
completed_at: Optional[datetime] = Field(None, description="The unix timestamp of when the job was completed.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the job.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the job.", index=True)
metadata_: Optional[dict] = Field({}, sa_column=Column(JSON), description="The metadata of the job.")


Expand Down

0 comments on commit bbe37fc

Please sign in to comment.