diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..5701746 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,31 @@ +name: Build and Publish Docker Image + +on: + push: + branches: + - main # Change this to your default branch if different + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Log in to GitHub Container Registry + uses: docker/login-action@v1 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build Docker image + run: | + docker build -t ghcr.io/${{ github.repository }}/my-image:latest . + # Replace 'my-image' with your desired image name + + - name: Push Docker image + run: | + docker push ghcr.io/${{ github.repository }}/my-image:latest + # Replace 'my-image' with your desired image name diff --git a/alembic/versions/001_setup_postgres.py b/alembic/versions/001_setup_postgres.py new file mode 100644 index 0000000..3626ee9 --- /dev/null +++ b/alembic/versions/001_setup_postgres.py @@ -0,0 +1,25 @@ +"""Create uuid-ossp extension + +Revision ID: 001_create_uuid_ossp +Revises: +Create Date: 2024-07-30 15:00:00.000000 + +""" + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "001_create_uuid_ossp" +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade(): + op.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp"') + + +def downgrade(): + op.execute('DROP EXTENSION IF EXISTS "uuid-ossp"') diff --git a/alembic/versions/4f4580079301_initial_migration.py b/alembic/versions/4f4580079301_initial_migration.py deleted file mode 100644 index 9faa5f4..0000000 --- a/alembic/versions/4f4580079301_initial_migration.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Initial migration - -Revision ID: 4f4580079301 -Revises: -Create Date: 2024-08-04 23:31:21.059320 - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision: str = '4f4580079301' -down_revision: Union[str, None] = None -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table('context', - sa.Column('id', sa.String(), nullable=False), - sa.Column('name', sa.String(), nullable=False), - sa.Column('content', sa.String(), nullable=False), - sa.Column('created_at', sa.DateTime(), nullable=True), - sa.Column('updated_at', sa.DateTime(), nullable=True), - sa.PrimaryKeyConstraint('id') - ) - op.create_table('entities', - sa.Column('id', sa.String(), nullable=False), - sa.Column('name', sa.String(), nullable=False), - sa.Column('type', sa.String(), nullable=False), - sa.Column('attributes', sa.String(), nullable=True), - sa.Column('created_at', sa.DateTime(), nullable=True), - sa.PrimaryKeyConstraint('id') - ) - op.create_table('memories', - sa.Column('id', sa.String(), nullable=False), - sa.Column('content', sa.String(), nullable=False), - sa.Column('importance', sa.Float(), nullable=True), - sa.Column('last_accessed', sa.DateTime(), nullable=True), - sa.Column('created_at', sa.DateTime(), nullable=True), - sa.PrimaryKeyConstraint('id') - ) - op.create_table('queues', - sa.Column('id', sa.String(), nullable=False), - sa.Column('name', sa.String(), nullable=False), - sa.Column('created_at', sa.DateTime(), nullable=True), - sa.PrimaryKeyConstraint('id') - ) - op.create_table('system_state', - sa.Column('id', sa.String(), nullable=False), - sa.Column('current_focus', sa.String(), nullable=True), - sa.Column('mood', sa.String(), nullable=True), - sa.Column('parameters', sa.String(), nullable=True), - sa.Column('updated_at', sa.DateTime(), nullable=True), - sa.PrimaryKeyConstraint('id') - ) - op.create_table('tags', - sa.Column('id', sa.String(), nullable=False), - sa.Column('name', sa.String(), nullable=False), - sa.PrimaryKeyConstraint('id') - ) - op.create_table('actions', - sa.Column('id', sa.String(), nullable=False), - sa.Column('queue_id', sa.String(), nullable=True), - sa.Column('type', sa.String(), nullable=False), - sa.Column('content', sa.String(), nullable=False), - sa.Column('status', sa.String(), nullable=False), - sa.Column('priority', sa.Integer(), nullable=True), - sa.Column('deadline', sa.DateTime(), nullable=True), - sa.Column('result', sa.String(), nullable=True), - sa.Column('created_at', sa.DateTime(), nullable=True), - sa.ForeignKeyConstraint(['queue_id'], ['queues.id'], ), - sa.PrimaryKeyConstraint('id') - ) - op.create_table('entity_memories', - sa.Column('id', sa.String(), nullable=False), - sa.Column('entity_id', sa.String(), nullable=True), - sa.Column('memory_id', sa.String(), nullable=True), - sa.ForeignKeyConstraint(['entity_id'], ['entities.id'], ), - sa.ForeignKeyConstraint(['memory_id'], ['memories.id'], ), - sa.PrimaryKeyConstraint('id') - ) - op.create_table('memory_tags', - sa.Column('memory_id', sa.String(), nullable=True), - sa.Column('tag_id', sa.String(), nullable=True), - sa.ForeignKeyConstraint(['memory_id'], ['memories.id'], ), - sa.ForeignKeyConstraint(['tag_id'], ['tags.id'], ) - ) - op.create_table('relationships', - sa.Column('id', sa.String(), nullable=False), - sa.Column('entity1_id', sa.String(), nullable=True), - sa.Column('entity2_id', sa.String(), nullable=True), - sa.Column('relationship_type', sa.String(), nullable=False), - sa.Column('strength', sa.Float(), nullable=True), - sa.ForeignKeyConstraint(['entity1_id'], ['entities.id'], ), - sa.ForeignKeyConstraint(['entity2_id'], ['entities.id'], ), - sa.PrimaryKeyConstraint('id') - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('relationships') - op.drop_table('memory_tags') - op.drop_table('entity_memories') - op.drop_table('actions') - op.drop_table('tags') - op.drop_table('system_state') - op.drop_table('queues') - op.drop_table('memories') - op.drop_table('entities') - op.drop_table('context') - # ### end Alembic commands ### diff --git a/alembic/versions/6674e2e67df2_align_models_with_gql_schema.py b/alembic/versions/6674e2e67df2_align_models_with_gql_schema.py deleted file mode 100644 index 5c0a9e6..0000000 --- a/alembic/versions/6674e2e67df2_align_models_with_gql_schema.py +++ /dev/null @@ -1,43 +0,0 @@ -"""align models with gql schema - -Revision ID: 6674e2e67df2 -Revises: fb8c8d46203a -Create Date: 2024-08-05 01:17:37.713501 - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision: str = '6674e2e67df2' -down_revision: Union[str, None] = 'fb8c8d46203a' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table('note_tags', - sa.Column('note_id', sa.String(), nullable=True), - sa.Column('tag_id', sa.String(), nullable=True), - sa.ForeignKeyConstraint(['note_id'], ['notes.id'], ), - sa.ForeignKeyConstraint(['tag_id'], ['tags.id'], ) - ) - op.add_column('chats', sa.Column('title', sa.String(), nullable=False)) - op.alter_column('messages', 'role', - existing_type=sa.VARCHAR(), - nullable=False) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.alter_column('messages', 'role', - existing_type=sa.VARCHAR(), - nullable=True) - op.drop_column('chats', 'title') - op.drop_table('note_tags') - # ### end Alembic commands ### diff --git a/alembic/versions/a3e6578f513d_inital_models.py b/alembic/versions/a3e6578f513d_inital_models.py new file mode 100644 index 0000000..5487153 --- /dev/null +++ b/alembic/versions/a3e6578f513d_inital_models.py @@ -0,0 +1,203 @@ +"""inital models + +Revision ID: a3e6578f513d +Revises: 001_create_uuid_ossp +Create Date: 2024-08-05 13:13:19.961363 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'a3e6578f513d' +down_revision: Union[str, None] = '001_create_uuid_ossp' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('users', + sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('email', sa.String(), nullable=False), + sa.Column('name', sa.String(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('email') + ) + op.create_table('profiles', + sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('full_name', sa.String(), nullable=True), + sa.Column('provider', sa.String(), nullable=False), + sa.Column('user_id', sa.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('refresh_tokens', + sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('user_id', sa.UUID(), nullable=False), + sa.Column('token', sa.String(), nullable=False), + sa.Column('revoked', sa.DateTime(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('token'), + sa.UniqueConstraint('user_id') + ) + op.create_table('chats', + sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('title', sa.String(), nullable=False), + sa.Column('profile_id', sa.UUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['profile_id'], ['profiles.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('context', + sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('name', sa.String(), nullable=False), + sa.Column('content', sa.String(), nullable=False), + sa.Column('profile_id', sa.UUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['profile_id'], ['profiles.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('entities', + sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('name', sa.String(), nullable=False), + sa.Column('type', sa.String(), nullable=False), + sa.Column('attributes', sa.String(), nullable=True), + sa.Column('profile_id', sa.UUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['profile_id'], ['profiles.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('memories', + sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('content', sa.String(), nullable=False), + sa.Column('importance', sa.Float(), nullable=True), + sa.Column('profile_id', sa.UUID(), nullable=True), + sa.Column('last_accessed', sa.DateTime(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['profile_id'], ['profiles.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('notes', + sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('content', sa.String(), nullable=False), + sa.Column('profile_id', sa.UUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['profile_id'], ['profiles.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('queues', + sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('name', sa.String(), nullable=False), + sa.Column('profile_id', sa.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['profile_id'], ['profiles.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('system_state', + sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('current_focus', sa.String(), nullable=True), + sa.Column('mood', sa.String(), nullable=True), + sa.Column('parameters', sa.String(), nullable=True), + sa.Column('profile_id', sa.UUID(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['profile_id'], ['profiles.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('tags', + sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('name', sa.String(), nullable=False), + sa.Column('profile_id', sa.UUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['profile_id'], ['profiles.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('actions', + sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('queue_id', sa.UUID(), nullable=True), + sa.Column('type', sa.String(), nullable=False), + sa.Column('content', sa.String(), nullable=False), + sa.Column('status', sa.String(), nullable=False), + sa.Column('priority', sa.Integer(), nullable=True), + sa.Column('deadline', sa.DateTime(), nullable=True), + sa.Column('result', sa.String(), nullable=True), + sa.Column('profile_id', sa.UUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['profile_id'], ['profiles.id'], ), + sa.ForeignKeyConstraint(['queue_id'], ['queues.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('entity_memories', + sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('entity_id', sa.UUID(), nullable=True), + sa.Column('memory_id', sa.UUID(), nullable=True), + sa.ForeignKeyConstraint(['entity_id'], ['entities.id'], ), + sa.ForeignKeyConstraint(['memory_id'], ['memories.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('memory_tags', + sa.Column('memory_id', sa.UUID(), nullable=True), + sa.Column('tag_id', sa.UUID(), nullable=True), + sa.ForeignKeyConstraint(['memory_id'], ['memories.id'], ), + sa.ForeignKeyConstraint(['tag_id'], ['tags.id'], ) + ) + op.create_table('messages', + sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('message', sa.String(), nullable=True), + sa.Column('role', sa.String(), nullable=False), + sa.Column('chat_id', sa.UUID(), nullable=True), + sa.Column('profile_id', sa.UUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['chat_id'], ['chats.id'], ), + sa.ForeignKeyConstraint(['profile_id'], ['profiles.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('note_tags', + sa.Column('note_id', sa.UUID(), nullable=True), + sa.Column('tag_id', sa.UUID(), nullable=True), + sa.ForeignKeyConstraint(['note_id'], ['notes.id'], ), + sa.ForeignKeyConstraint(['tag_id'], ['tags.id'], ) + ) + op.create_table('relationships', + sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('entity1_id', sa.UUID(), nullable=True), + sa.Column('entity2_id', sa.UUID(), nullable=True), + sa.Column('relationship_type', sa.String(), nullable=False), + sa.Column('strength', sa.Float(), nullable=True), + sa.ForeignKeyConstraint(['entity1_id'], ['entities.id'], ), + sa.ForeignKeyConstraint(['entity2_id'], ['entities.id'], ), + sa.PrimaryKeyConstraint('id') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('relationships') + op.drop_table('note_tags') + op.drop_table('messages') + op.drop_table('memory_tags') + op.drop_table('entity_memories') + op.drop_table('actions') + op.drop_table('tags') + op.drop_table('system_state') + op.drop_table('queues') + op.drop_table('notes') + op.drop_table('memories') + op.drop_table('entities') + op.drop_table('context') + op.drop_table('chats') + op.drop_table('refresh_tokens') + op.drop_table('profiles') + op.drop_table('users') + # ### end Alembic commands ### diff --git a/alembic/versions/e4d6f72d4b9b_add_supabase_tables.py b/alembic/versions/e4d6f72d4b9b_add_supabase_tables.py deleted file mode 100644 index 9e834d0..0000000 --- a/alembic/versions/e4d6f72d4b9b_add_supabase_tables.py +++ /dev/null @@ -1,68 +0,0 @@ -"""add supabase tables - -Revision ID: e4d6f72d4b9b -Revises: 4f4580079301 -Create Date: 2024-08-04 23:55:50.068351 - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision: str = 'e4d6f72d4b9b' -down_revision: Union[str, None] = '4f4580079301' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table('profiles', - sa.Column('id', sa.String(), nullable=False), - sa.Column('full_name', sa.String(), nullable=True), - sa.Column('provider', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=False), - sa.Column('created_at', sa.DateTime(), nullable=True), - sa.Column('updated_at', sa.DateTime(), nullable=True), - sa.PrimaryKeyConstraint('id') - ) - op.create_table('chats', - sa.Column('id', sa.String(), nullable=False), - sa.Column('profile_id', sa.String(), nullable=True), - sa.Column('created_at', sa.DateTime(), nullable=True), - sa.ForeignKeyConstraint(['profile_id'], ['profiles.id'], ), - sa.PrimaryKeyConstraint('id') - ) - op.create_table('notes', - sa.Column('id', sa.String(), nullable=False), - sa.Column('content', sa.String(), nullable=False), - sa.Column('profile_id', sa.String(), nullable=True), - sa.Column('created_at', sa.DateTime(), nullable=True), - sa.Column('updated_at', sa.DateTime(), nullable=True), - sa.ForeignKeyConstraint(['profile_id'], ['profiles.id'], ), - sa.PrimaryKeyConstraint('id') - ) - op.create_table('messages', - sa.Column('id', sa.String(), nullable=False), - sa.Column('chat_id', sa.String(), nullable=True), - sa.Column('profile_id', sa.String(), nullable=True), - sa.Column('role', sa.String(), nullable=True), - sa.Column('message', sa.String(), nullable=True), - sa.Column('created_at', sa.DateTime(), nullable=True), - sa.ForeignKeyConstraint(['chat_id'], ['chats.id'], ), - sa.ForeignKeyConstraint(['profile_id'], ['profiles.id'], ), - sa.PrimaryKeyConstraint('id') - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('messages') - op.drop_table('notes') - op.drop_table('chats') - op.drop_table('profiles') - # ### end Alembic commands ### diff --git a/alembic/versions/e75809b95dc7_add_missing_default_created_at.py b/alembic/versions/e75809b95dc7_add_missing_default_created_at.py new file mode 100644 index 0000000..0b5697f --- /dev/null +++ b/alembic/versions/e75809b95dc7_add_missing_default_created_at.py @@ -0,0 +1,30 @@ +"""add missing default created_at + +Revision ID: e75809b95dc7 +Revises: a3e6578f513d +Create Date: 2024-08-05 16:43:43.024476 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'e75809b95dc7' +down_revision: Union[str, None] = 'a3e6578f513d' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### diff --git a/alembic/versions/fb8c8d46203a_relationships.py b/alembic/versions/fb8c8d46203a_relationships.py deleted file mode 100644 index d046371..0000000 --- a/alembic/versions/fb8c8d46203a_relationships.py +++ /dev/null @@ -1,58 +0,0 @@ -"""relationships - -Revision ID: fb8c8d46203a -Revises: e4d6f72d4b9b -Create Date: 2024-08-05 00:03:10.417625 - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision: str = 'fb8c8d46203a' -down_revision: Union[str, None] = 'e4d6f72d4b9b' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.add_column('actions', sa.Column('profile_id', sa.String(), nullable=True)) - op.create_foreign_key(None, 'actions', 'profiles', ['profile_id'], ['id']) - op.add_column('context', sa.Column('profile_id', sa.String(), nullable=True)) - op.create_foreign_key(None, 'context', 'profiles', ['profile_id'], ['id']) - op.add_column('entities', sa.Column('profile_id', sa.String(), nullable=True)) - op.create_foreign_key(None, 'entities', 'profiles', ['profile_id'], ['id']) - op.add_column('memories', sa.Column('profile_id', sa.String(), nullable=True)) - op.create_foreign_key(None, 'memories', 'profiles', ['profile_id'], ['id']) - op.add_column('queues', sa.Column('profile_id', sa.String(), nullable=True)) - op.create_foreign_key(None, 'queues', 'profiles', ['profile_id'], ['id']) - op.add_column('system_state', sa.Column('profile_id', sa.String(), nullable=True)) - op.create_foreign_key(None, 'system_state', 'profiles', ['profile_id'], ['id']) - op.add_column('tags', sa.Column('profile_id', sa.String(), nullable=True)) - op.add_column('tags', sa.Column('created_at', sa.DateTime(), nullable=True)) - op.create_foreign_key(None, 'tags', 'profiles', ['profile_id'], ['id']) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_constraint(None, 'tags', type_='foreignkey') - op.drop_column('tags', 'created_at') - op.drop_column('tags', 'profile_id') - op.drop_constraint(None, 'system_state', type_='foreignkey') - op.drop_column('system_state', 'profile_id') - op.drop_constraint(None, 'queues', type_='foreignkey') - op.drop_column('queues', 'profile_id') - op.drop_constraint(None, 'memories', type_='foreignkey') - op.drop_column('memories', 'profile_id') - op.drop_constraint(None, 'entities', type_='foreignkey') - op.drop_column('entities', 'profile_id') - op.drop_constraint(None, 'context', type_='foreignkey') - op.drop_column('context', 'profile_id') - op.drop_constraint(None, 'actions', type_='foreignkey') - op.drop_column('actions', 'profile_id') - # ### end Alembic commands ### diff --git a/docker-compose.yml b/docker-compose.yml index e4fe7e8..3052f24 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,5 +1,6 @@ services: - web: + api: + container_name: mindsherpa_api build: context: . dockerfile: Dockerfile.dev @@ -9,18 +10,25 @@ services: - .:/app - /app/__pycache__/ command: uvicorn src.main:app --reload --host 0.0.0.0 --port 8000 + depends_on: + - postgres + postgres: - image: postgres:13 - environment: - POSTGRES_DB: postgres - POSTGRES_HOST: postgres - POSTGRES_PORT: 5432 - POSTGRES_USER: postgres - POSTGRES_PASSWORD: password + image: postgres:latest + container_name: mindsherpa_postgres volumes: - postgres_data:/var/lib/postgresql/data + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: postgres ports: - "5432:5432" + healthcheck: + test: ["CMD-SHELL", "pg_isready -U postgres_test"] + interval: 5s + timeout: 5s + retries: 5 volumes: - postgres_data: + postgres_data: \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 94244ed..962af5c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,9 @@ annotated-types==0.7.0 anyio==3.7.1 attrs==23.2.0 certifi==2024.7.4 +cffi==1.16.0 click==8.1.7 +cryptography==43.0.0 deprecation==2.1.0 distro==1.9.0 dnspython==2.6.1 @@ -37,9 +39,11 @@ packaging==24.1 pluggy==1.5.0 postgrest==0.16.9 psycopg2-binary==2.9.9 +pycparser==2.22 pydantic==2.8.2 pydantic_core==2.20.1 Pygments==2.18.0 +PyJWT==2.9.0 pytest==8.3.2 python-dateutil==2.9.0.post0 python-dotenv==1.0.1 diff --git a/src/data/data_access.py b/src/data/data_access.py new file mode 100644 index 0000000..0707e9c --- /dev/null +++ b/src/data/data_access.py @@ -0,0 +1,75 @@ +from sqlalchemy.orm import Session +from typing import List + +from src.data.models import Message +from src.data.notes import get_user_notes +from src.services.openai_service import openai_client + + +def get_chat_history(session: Session, chat_id: int) -> List[Message]: + messages = session.query(Message).filter(Message.chat_id == chat_id).all() + return messages + + +def insert_message( + session: Session, chat_id: str, message: str, profile_id: str, role: str +) -> Message: + new_message = Message( + message=message, chat_id=chat_id, profile_id=profile_id, role=role + ) + session.add(new_message) + session.commit() + + return new_message + + +def get_sherpa_response( + session: Session, message: str, chat_id, profile_id +) -> str | None: + system_prompt = """ + You are the user's expert-level personal assistant and best friend. + + You have full history of the user's chat with you and their Context, which is a list of notes they have taken. + These notes include their goals, tasks, and any other important information they have shared with you. + + The user is going to provide with their entire chat history with you, along with their latest message and \n + + You must respond to the user's message based on the chat history. + + ## Rules: + - Do not say who you are or that you are an AI. + - Do not speak in paragraphs. + - Respond with the least amount of words possible, but use full sentences. + - Include emojis in your responses where applicable. + - Your response should be in a friendly, upbeat and conversational tone. + - Your response should use all of the User Context and the entire Chat History to provide context to your response. + - Your response should use that knowledge about the user to answer the user's latest message. + + ## User Context + {user_context} + + ## Chat History + {chat_history} + """ + + chat_history = get_chat_history(session, chat_id) + user_context = get_user_notes(session, profile_id) + chat_history_contents = [message.message for message in chat_history] + user_context_contents = [note.content for note in user_context] + + response = openai_client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "system", + "content": system_prompt.format( + chat_history=chat_history_contents, + user_context=user_context_contents, + ), + }, + {"role": "user", "content": message}, + ], + stream=False, + ) + content = response.choices[0].message.content + return content diff --git a/src/data/db.py b/src/data/db.py index 6a27028..0e9ef51 100644 --- a/src/data/db.py +++ b/src/data/db.py @@ -1,7 +1,7 @@ import os from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker - +from src.utils.logger import logger from sqlalchemy.ext.declarative import declarative_base Base = declarative_base() @@ -10,5 +10,16 @@ if not DATABASE_URL: raise ValueError("DATABASE_URL environment variable is not set") + engine = create_engine(DATABASE_URL, echo=True) +try: + connection = engine.connect() + logger.info("Successfully connected to the database!") + connection.close() +except Exception as e: + logger.error(f"An error occurred: {e}") + + +# Session Session = sessionmaker(engine) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) diff --git a/src/data/models.py b/src/data/models.py index 2d5cbcd..5a51019 100644 --- a/src/data/models.py +++ b/src/data/models.py @@ -1,6 +1,6 @@ -import enum -from re import S +from datetime import datetime from sqlalchemy import ( + UUID, Column, Integer, String, @@ -8,28 +8,70 @@ DateTime, ForeignKey, Table, + func, ) from sqlalchemy.orm import relationship -from sqlalchemy.dialects.postgresql import ENUM from src.data.db import Base +class User(Base): + __tablename__ = "users" + + id = Column( + UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4() + ) + # apple_id = Column(String, unique=True, nullable=False) + email = Column(String, unique=True, nullable=False) + name = Column(String, nullable=True) + created_at = Column(DateTime, default=datetime.utcnow) + + def __repr__(self): + return f"" + + +class RefreshToken(Base): + __tablename__ = "refresh_tokens" + + id = Column( + UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4() + ) + user_id = Column( + UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, unique=True + ) + token = Column(String, unique=True, nullable=False) + revoked = Column(DateTime, nullable=True) + created_at = Column(DateTime, default=datetime.utcnow) + + def __repr__(self): + return f"" + + class Profile(Base): __tablename__ = "profiles" - id = Column(String, primary_key=True) + id = Column( + UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4() + ) full_name = Column(String, nullable=True) provider = Column(String, nullable=False) - user_id = Column(String, nullable=False) - created_at = Column(DateTime) - updated_at = Column(DateTime) + user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False) + created_at = Column(DateTime, default=datetime.utcnow) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + + def __repr__(self): + return f"" class Queue(Base): __tablename__ = "queues" - id = Column(String, primary_key=True) + id = Column( + UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4() + ) name = Column(String, nullable=False) - profile_id = Column(String, ForeignKey("profiles.id")) - created_at = Column(DateTime) + profile_id = Column(UUID(as_uuid=True), ForeignKey("profiles.id"), nullable=False) + created_at = Column(DateTime, default=datetime.utcnow) + + def __repr__(self): + return f"" Queue.profile = relationship("Profile", back_populates="queues") @@ -38,16 +80,21 @@ class Queue(Base): class Action(Base): __tablename__ = "actions" - id = Column(String, primary_key=True) - queue_id = Column(String, ForeignKey("queues.id")) + id = Column( + UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4() + ) + queue_id = Column(UUID(as_uuid=True), ForeignKey("queues.id")) type = Column(String, nullable=False) content = Column(String, nullable=False) status = Column(String, nullable=False) priority = Column(Integer) deadline = Column(DateTime) result = Column(String) - profile_id = Column(String, ForeignKey("profiles.id")) - created_at = Column(DateTime) + profile_id = Column(UUID(as_uuid=True), ForeignKey("profiles.id")) + created_at = Column(DateTime, default=datetime.utcnow) + + def __repr__(self): + return f"" Action.queue = relationship("Queue", back_populates="actions") @@ -58,51 +105,71 @@ class Action(Base): class Memory(Base): __tablename__ = "memories" - id = Column(String, primary_key=True) + id = Column( + UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4() + ) content = Column(String, nullable=False) importance = Column(Float) - profile_id = Column(String, ForeignKey("profiles.id")) + profile_id = Column(UUID(as_uuid=True), ForeignKey("profiles.id")) last_accessed = Column(DateTime) - created_at = Column(DateTime) + created_at = Column(DateTime, default=datetime.utcnow) + + def __repr__(self): + return f"" Memory.profile = relationship("Profile", back_populates="memories") +Memory.entities = relationship("EntityMemory", back_populates="memory") Profile.memories = relationship("Memory", back_populates="profile") class Entity(Base): __tablename__ = "entities" - id = Column(String, primary_key=True) + id = Column( + UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4() + ) name = Column(String, nullable=False) type = Column(String, nullable=False) attributes = Column(String) - profile_id = Column(String, ForeignKey("profiles.id")) - created_at = Column(DateTime) + profile_id = Column(UUID(as_uuid=True), ForeignKey("profiles.id")) + created_at = Column(DateTime, default=datetime.utcnow) + + def __repr__(self): + return f"" Entity.profile = relationship("Profile", back_populates="entities") Profile.entities = relationship("Entity", back_populates="profile") +Entity.memories = relationship("EntityMemory", back_populates="entity") class EntityMemory(Base): __tablename__ = "entity_memories" - id = Column(String, primary_key=True) - entity_id = Column(String, ForeignKey("entities.id")) - memory_id = Column(String, ForeignKey("memories.id")) + id = Column( + UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4() + ) + entity_id = Column(UUID(as_uuid=True), ForeignKey("entities.id")) + memory_id = Column(UUID(as_uuid=True), ForeignKey("memories.id")) + + def __repr__(self): + return f"" EntityMemory.entity = relationship("Entity", back_populates="memories") EntityMemory.memory = relationship("Memory", back_populates="entities") -Entity.memories = relationship("Memory", secondary="entity_memories") -Memory.entities = relationship("Entity", secondary="entity_memories") class Tag(Base): __tablename__ = "tags" - id = Column(String, primary_key=True) + id = Column( + UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4() + ) name = Column(String, nullable=False) - profile_id = Column(String, ForeignKey("profiles.id")) - created_at = Column(DateTime) + profile_id = Column(UUID(as_uuid=True), ForeignKey("profiles.id")) + created_at = Column(DateTime, default=datetime.utcnow) + + def __repr__(self): + return f"" Tag.profile = relationship("Profile", back_populates="tags") @@ -111,21 +178,31 @@ class Tag(Base): class Relationship(Base): __tablename__ = "relationships" - id = Column(String, primary_key=True) - entity1_id = Column(String, ForeignKey("entities.id")) - entity2_id = Column(String, ForeignKey("entities.id")) + id = Column( + UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4() + ) + entity1_id = Column(UUID(as_uuid=True), ForeignKey("entities.id")) + entity2_id = Column(UUID(as_uuid=True), ForeignKey("entities.id")) relationship_type = Column(String, nullable=False) strength = Column(Float) + def __repr__(self): + return f"" + class Context(Base): __tablename__ = "context" - id = Column(String, primary_key=True) + id = Column( + UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4() + ) name = Column(String, nullable=False) content = Column(String, nullable=False) - profile_id = Column(String, ForeignKey("profiles.id")) - created_at = Column(DateTime) - updated_at = Column(DateTime) + profile_id = Column(UUID(as_uuid=True), ForeignKey("profiles.id")) + created_at = Column(DateTime, default=datetime.utcnow) + updated_at = Column(DateTime, default=datetime.utcnow) + + def __repr__(self): + return f"" Context.profile = relationship("Profile", back_populates="contexts") @@ -134,12 +211,17 @@ class Context(Base): class SystemState(Base): __tablename__ = "system_state" - id = Column(String, primary_key=True) + id = Column( + UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4() + ) current_focus = Column(String) mood = Column(String) parameters = Column(String) - profile_id = Column(String, ForeignKey("profiles.id")) - updated_at = Column(DateTime) + profile_id = Column(UUID(as_uuid=True), ForeignKey("profiles.id")) + updated_at = Column(DateTime, default=datetime.utcnow) + + def __repr__(self): + return f"" SystemState.profile = relationship("Profile", back_populates="system_state") @@ -148,10 +230,15 @@ class SystemState(Base): class Chat(Base): __tablename__ = "chats" - id = Column(String, primary_key=True) + id = Column( + UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4() + ) title = Column(String, nullable=False) - profile_id = Column(String, ForeignKey("profiles.id")) - created_at = Column(DateTime) + profile_id = Column(UUID(as_uuid=True), ForeignKey("profiles.id")) + created_at = Column(DateTime, default=datetime.utcnow) + + def __repr__(self): + return f"" Chat.profile = relationship("Profile", back_populates="chats") @@ -160,12 +247,17 @@ class Chat(Base): class Message(Base): __tablename__ = "messages" - id = Column(String, primary_key=True) + id = Column( + UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4() + ) message = Column(String) role = Column(String, nullable=False) - chat_id = Column(String, ForeignKey("chats.id")) - profile_id = Column(String, ForeignKey("profiles.id")) - created_at = Column(DateTime) + chat_id = Column(UUID(as_uuid=True), ForeignKey("chats.id")) + profile_id = Column(UUID(as_uuid=True), ForeignKey("profiles.id")) + created_at = Column(DateTime, default=datetime.utcnow) + + def __repr__(self): + return f"" Chat.messages = relationship("Message", back_populates="chat") @@ -174,11 +266,16 @@ class Message(Base): class Note(Base): __tablename__ = "notes" - id = Column(String, primary_key=True) + id = Column( + UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4() + ) content = Column(String, nullable=False) - profile_id = Column(String, ForeignKey("profiles.id")) - created_at = Column(DateTime) - updated_at = Column(DateTime) + profile_id = Column(UUID(as_uuid=True), ForeignKey("profiles.id")) + created_at = Column(DateTime, default=datetime.utcnow) + updated_at = Column(DateTime, default=datetime.utcnow) + + def __repr__(self): + return f"" Note.profile = relationship("Profile", back_populates="notes") @@ -188,22 +285,22 @@ class Note(Base): memory_tags = Table( "memory_tags", Base.metadata, - Column("memory_id", String, ForeignKey("memories.id")), - Column("tag_id", String, ForeignKey("tags.id")), + Column("memory_id", UUID(as_uuid=True), ForeignKey("memories.id")), + Column("tag_id", UUID(as_uuid=True), ForeignKey("tags.id")), ) + # Add relationships Memory.tags = relationship("Tag", secondary=memory_tags, back_populates="memories") Tag.memories = relationship("Memory", secondary=memory_tags, back_populates="tags") -Entity.memories = relationship("Memory", secondary="entity_memories") -Memory.entities = relationship("Entity", secondary="entity_memories") + # Note and Tag many-to-many relationship note_tags = Table( "note_tags", Base.metadata, - Column("note_id", String, ForeignKey("notes.id")), - Column("tag_id", String, ForeignKey("tags.id")), + Column("note_id", UUID(as_uuid=True), ForeignKey("notes.id")), + Column("tag_id", UUID(as_uuid=True), ForeignKey("tags.id")), ) Note.tags = relationship("Tag", secondary=note_tags, back_populates="notes") diff --git a/src/data/notes.py b/src/data/notes.py new file mode 100644 index 0000000..afa6a6e --- /dev/null +++ b/src/data/notes.py @@ -0,0 +1,10 @@ +from typing import List +from sqlalchemy.orm import Session + +from src.data.models import Note + + +def get_user_notes(session: Session, profile_id: str) -> List[Note]: + notes = session.query(Note).filter(Note.profile_id == profile_id).all() + + return notes diff --git a/src/data/types.py b/src/data/types.py new file mode 100644 index 0000000..b6fe4c1 --- /dev/null +++ b/src/data/types.py @@ -0,0 +1,176 @@ +from datetime import datetime +from typing import List, Optional +from uuid import UUID + +from pydantic import BaseModel, ConfigDict, Field + + +class User(BaseModel): + id: UUID = Field(..., alias="id") + email: str + name: Optional[str] = None + created_at: datetime + + +class RefreshToken(BaseModel): + id: UUID = Field(..., alias="id") + user_id: UUID + token: str + revoked: Optional[datetime] = None + created_at: datetime + + +class Profile(BaseModel): + id: UUID = Field(..., alias="id") + full_name: Optional[str] = None + provider: str + user_id: UUID + created_at: datetime + updated_at: datetime + queues: List["Queue"] = [] + actions: List["Action"] = [] + memories: List["Memory"] = [] + entities: List["Entity"] = [] + tags: List["Tag"] = [] + contexts: List["Context"] = [] + system_state: Optional["SystemState"] = None + chats: List["Chat"] = [] + notes: List["Note"] = [] + + +class Queue(BaseModel): + id: UUID = Field(..., alias="id") + name: str + profile_id: UUID + created_at: datetime + profile: "Profile" + actions: List["Action"] = [] + + +class Action(BaseModel): + id: UUID = Field(..., alias="id") + queue_id: Optional[UUID] = None + type: str + content: str + status: str + priority: Optional[int] = None + deadline: Optional[datetime] = None + result: Optional[str] = None + profile_id: UUID + created_at: datetime + queue: Optional["Queue"] = None + profile: "Profile" + + +class Memory(BaseModel): + id: UUID = Field(..., alias="id") + content: str + importance: Optional[float] = None + profile_id: UUID + last_accessed: Optional[datetime] = None + created_at: datetime + profile: "Profile" + entities: List["EntityMemory"] = [] + tags: List["Tag"] = [] + + +class Entity(BaseModel): + id: UUID = Field(..., alias="id") + name: str + type: str + attributes: Optional[str] = None + profile_id: UUID + created_at: datetime + profile: "Profile" + memories: List["EntityMemory"] = [] + + +class EntityMemory(BaseModel): + id: UUID = Field(..., alias="id") + entity_id: UUID + memory_id: UUID + entity: "Entity" + memory: "Memory" + + +class Tag(BaseModel): + id: UUID = Field(..., alias="id") + name: str + profile_id: UUID + created_at: datetime + profile: "Profile" + memories: List["Memory"] = [] + notes: List["Note"] = [] + + +class Relationship(BaseModel): + id: UUID = Field(..., alias="id") + entity1_id: UUID + entity2_id: UUID + relationship_type: str + strength: Optional[float] = None + + +class Context(BaseModel): + id: UUID = Field(..., alias="id") + name: str + content: str + profile_id: UUID + created_at: datetime + updated_at: datetime + profile: "Profile" + + +class SystemState(BaseModel): + id: UUID = Field(..., alias="id") + current_focus: Optional[str] = None + mood: Optional[str] = None + parameters: Optional[str] = None + profile_id: UUID + updated_at: datetime + profile: "Profile" + + +class Chat(BaseModel): + id: UUID = Field(..., alias="id") + title: str + profile_id: UUID + created_at: datetime + profile: "Profile" + messages: List["Message"] = [] + + +class Message(BaseModel): + model_config = ConfigDict(from_attributes=True) + id: UUID = Field(..., alias="id") + message: str + role: str + chat_id: UUID + profile_id: UUID + created_at: datetime + chat: "Chat" + + +class Note(BaseModel): + id: UUID = Field(..., alias="id") + content: str + profile_id: UUID + created_at: datetime + updated_at: datetime + profile: "Profile" + tags: List["Tag"] = [] + + +# Set up forward references +Profile.model_rebuild() +Queue.model_rebuild() +Action.model_rebuild() +Memory.model_rebuild() +Entity.model_rebuild() +EntityMemory.model_rebuild() +Tag.model_rebuild() +Context.model_rebuild() +SystemState.model_rebuild() +Chat.model_rebuild() +Message.model_rebuild() +Note.model_rebuild() diff --git a/src/resolvers/chat_resolvers.py b/src/resolvers/chat_resolvers.py new file mode 100644 index 0000000..02719d2 --- /dev/null +++ b/src/resolvers/chat_resolvers.py @@ -0,0 +1,159 @@ +from typing import List +import strawberry +from enum import Enum +import json + +from src.data.data_access import get_sherpa_response, insert_message +from src.data.models import Chat as ChatModel, Message as MessageModel +from src.schemas.types import Chat, Message +from src.services.file_service import get_file_contents +from src.services.groq_service import groq_client +from src.utils.ai_models import open_source_models +from src.utils.logger import logger +from src.utils.generation_statistics import GenerationStatistics + + +def chat_to_gql(chat: ChatModel) -> Chat: + return Chat( + id=chat.id, + title=chat.title, + created_at=chat.created_at, + ) + + +def message_to_gql(message: MessageModel) -> Message: + return Message( + id=message.id, + chat_id=message.chat_id, + profile_id=message.profile_id, + role=message.role, + message=message.message, + created_at=message.created_at, + ) + + +class AvailablePrompts(Enum): + v1 = "user_input_formatter_v1.md" + v2 = "user_input_formatter_v2.md" + + +def get_prompt(prompt: AvailablePrompts): + return get_file_contents(f"src/prompts/{prompt.value}") + + +def analyze_user_input(transcript: str, model: str = "llama3-70b-8192"): + system_prompt = get_prompt(AvailablePrompts.v2) + + if model in open_source_models: + + try: + completion = groq_client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": transcript}, + ], + temperature=0.3, + max_tokens=8000, + top_p=1, + stream=False, + response_format={"type": "json_object"}, + stop=None, + ) + + usage = completion.usage + # print("------ USAGE ---", usage) + + except Exception as e: + logger.error(f" ********* API error ********: {e} ***** ") + return None, {"error": str(e)} + + try: + if usage: + statistics_to_return = GenerationStatistics( + input_time=int(usage.prompt_time) if usage.prompt_time else 0, + output_time=( + int(usage.completion_time) if usage.completion_time else 0 + ), + input_tokens=usage.prompt_tokens, + output_tokens=usage.completion_tokens, + total_time=int(usage.total_time) if usage.total_time else 0, + model_name=model, + ) + logger.info("focus_stats", statistics_to_return.get_stats()) + + return ( + json.loads(completion.choices[0].message.content) + if completion.choices[0].message.content + else None + ) + except Exception as e: + logger.error(f" ********* STATISTICS GENERATION error ******* : {e} ") + return None, {"error": str(e)} + + +async def chats(info: strawberry.Info) -> List[Chat]: + if not info.context.get("user"): + raise Exception("Unauthorized") + + session = info.context.get("session") + profile_id = info.context.get("profile").id + chats = session.query(ChatModel).filter(ChatModel.profile_id == profile_id).all() + + if len(chats) == 0: + # Create a new chat if none exists + new_chat = ChatModel( + title="New Chat", + profile_id=profile_id, + ) + session.add(new_chat) + session.commit() + + return [chat_to_gql(new_chat)] + + return [chat_to_gql(chat) for chat in chats] + + +async def chat_messages(info: strawberry.Info, chat_id: str) -> List[Message]: + if not info.context.get("user"): + raise Exception("Unauthorized") + + session = info.context.get("session") + messages = session.query(MessageModel).filter(MessageModel.chat_id == chat_id).all() + + return [message_to_gql(message) for message in messages] + + +async def send_chat_message( + info: strawberry.Info, chat_id: str, message: str +) -> List[Message]: + + if not info.context.get("user"): + raise Exception("Unauthorized") + + profile_id = info.context.get("profile").id + session = info.context.get("session") + + # Insert new message into the database + user_message = insert_message( + session, chat_id=chat_id, profile_id=profile_id, message=message, role="user" + ) + + # Retrieve message from ChatGPT + sherpa_response = get_sherpa_response(session, message, chat_id, profile_id) + if sherpa_response is None: + raise Exception("No response from the model") + + # Save system response to the database + system_message = insert_message( + session, + chat_id=chat_id, + profile_id=profile_id, + role="assistant", + message=sherpa_response, + ) + + return [ + user_message, + system_message, + ] diff --git a/src/resolvers/user_resolvers.py b/src/resolvers/user_resolvers.py new file mode 100644 index 0000000..5e57e73 --- /dev/null +++ b/src/resolvers/user_resolvers.py @@ -0,0 +1,256 @@ +from datetime import datetime, timedelta +from fastapi import HTTPException +from jwt.exceptions import InvalidTokenError +from sqlalchemy.orm import Session +from strawberry.types import Info +import jwt +import os +import strawberry +import uuid + + +from src.data.models import User, Profile +from src.services.supabase import supabase_client +from src.schemas.types import CreateUserInput, CreateUserPayload, UpdateProfileInput + +JWT_SECRET = os.environ.get("JWT_SECRET") +if not JWT_SECRET: + raise ValueError("JWT_SECRET environment variable is not set") + + +JWT_ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 30 +REFRESH_TOKEN_EXPIRE_DAYS = 30 + + +def get_current_user(session: Session, token: str) -> tuple[User, Profile]: + response = supabase_client.auth.get_user(token) + + if response is None: + raise HTTPException( + status_code=403, detail="Invalid authentication credentials" + ) + + user_id = response.user.id + user = session.query(User).filter(User.id == user_id).first() + if not user: + user = User( + id=user_id, + email=response.user.email, + ) + session.add(user) + session.commit() + + profile = session.query(Profile).filter(Profile.user_id == user_id).first() + if not profile: + profile = Profile(id=uuid.uuid4(), provider="apple", user_id=user.id) + session.add(profile) + session.commit() + + return User(id=response.user.id, email=response.user.email), profile + + +def create_access_token(data: dict): + to_encode = data.copy() + expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + to_encode.update({"exp": expire}) + return jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM) + + +def create_refresh_token(data: dict): + to_encode = data.copy() + expire = datetime.utcnow() + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) + to_encode.update({"exp": expire}) + return jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM) + + +def verify_apple_token(id_token: str, nonce: str) -> dict: + # Implement Apple ID token verification here + # This should validate the token with Apple's public key and verify the nonce + # Return the decoded token payload if valid, raise an exception if not + # For brevity, we're assuming the token is valid in this example + return {"sub": "example_apple_id", "email": "user@example.com"} + + +@strawberry.type +class AuthPayload: + user_id: int + access_token: str + refresh_token: str + + +async def save_apple_user(info: Info, id_token: str, nonce: str) -> AuthPayload: + # Verify the Apple ID token + try: + apple_payload = verify_apple_token(id_token, nonce) + except InvalidTokenError: + raise ValueError("Invalid Apple ID token") + + apple_id = apple_payload["sub"] + email = apple_payload.get("email") + + # Get the database session from the context + session: Session = info.context["session"] + + # Check if the user already exists + user = session.query(User).filter(User.apple_id == apple_id).first() + + if user is None: + # Create a new user + user = User(apple_id=apple_id, email=email) + session.add(user) + session.commit() + elif email and user.email != email: + # Update the email if it has changed + user.email = email + session.commit() + + # Create access and refresh tokens + access_token = create_access_token({"sub": str(user.id)}) + refresh_token = create_refresh_token({"sub": str(user.id)}) + + return AuthPayload( + user_id=user.id, access_token=access_token, refresh_token=refresh_token + ) + + +async def refresh_token(info: Info, refresh_token: str) -> AuthPayload: + try: + payload = jwt.decode(refresh_token, JWT_SECRET, algorithms=[JWT_ALGORITHM]) + user_id = int(payload["sub"]) + except (InvalidTokenError, ValueError): + raise ValueError("Invalid refresh token") + + session: Session = info.context["session"] + user = session.query(User).filter(User.id == user_id).first() + + if user is None: + raise ValueError("User not found") + + access_token = create_access_token({"sub": str(user.id)}) + new_refresh_token = create_refresh_token({"sub": str(user.id)}) + + return AuthPayload( + user_id=user.id, access_token=access_token, refresh_token=new_refresh_token + ) + + +async def create_user_and_profile( + info: Info, input: CreateUserInput +) -> CreateUserPayload: + session: Session = info.context["session"] + + user = User(email=input.email) + session.add(user) + session.commit() + + profile = Profile(user_id=user.id, provider="apple") + session.add(profile) + session.commit() + + # access_token = create_access_token({"sub": str(user.id)}) + # refresh_token = create_refresh_token({"sub": str(user.id)}) + + return CreateUserPayload(user=user, profile=profile) + + +@strawberry.type +class GetProfileOutput: + id: str + full_name: str + user_id: str + + +async def get_profile(info: Info) -> GetProfileOutput: + profile: Profile | None = info.context.get("profile") + + if not info.context.get("user") or not profile: + raise ValueError("Not authenticated") + + return GetProfileOutput( + id=str(profile.id), + full_name=str(profile.full_name), + user_id=str(profile.user_id), + ) + + +async def update_profile(info: Info, input: UpdateProfileInput) -> Profile: + session: Session = info.context["session"] + if not info.context.get("user"): + raise ValueError("Not authenticated") + + profile = ( + session.query(Profile) + .filter(Profile.user_id == info.context["user"].id) + .first() + ) + + if not profile: + raise ValueError("Profile does not exist") + + if input.full_name: + setattr(profile, "full_name", input.full_name) + + session.commit() + + return profile + + +# async def sign_out(info: Info) -> bool: +# # Get the current user from the context +# user_id = info.context.get("user_id") +# if not user_id: +# raise ValueError("Not authenticated") + +# session: Session = info.context["session"] + +# # Revoke all refresh tokens for the user +# session.query(RefreshToken).filter( +# RefreshToken.user_id == user_id, +# RefreshToken.revoked.is_(None) +# ).update({"revoked": datetime.utcnow()}) + +# session.commit() + +# # Notify subscribers about the account change +# await broadcast.publish(channel="user_changes", message=str(user_id)) + +# return True + +# @strawberry.mutation +# async def revoke_token(info: Info, refresh_token: str) -> bool: +# try: +# # Verify the token +# payload = verify_jwt_token(refresh_token) +# user_id = int(payload["sub"]) +# except (jwt.InvalidTokenError, ValueError): +# raise ValueError("Invalid refresh token") + +# session: Session = info.context["session"] + +# # Find and revoke the token +# token = session.query(RefreshToken).filter( +# RefreshToken.user_id == user_id, +# RefreshToken.token == refresh_token, +# RefreshToken.revoked.is_(None) +# ).first() + +# if token: +# token.revoked = datetime.utcnow() +# session.commit() + +# # Notify subscribers about the account change +# await broadcast.publish(channel="user_changes", message=str(user_id)) + +# return True +# else: +# return False + + +# # Setup function to initialize the broadcast client +# async def init_broadcast(): +# await broadcast.connect() + +# # Cleanup function to close the broadcast client +# async def cleanup_broadcast(): +# await broadcast.disconnect() diff --git a/src/routers/ai_router.py b/src/routers/ai_router.py index 51ec097..e862aa2 100644 --- a/src/routers/ai_router.py +++ b/src/routers/ai_router.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, Form from fastapi.responses import StreamingResponse -from src.services.sherpa_service import analyze_user_input +from src.resolvers.chat_resolvers import analyze_user_input from src.utils.logger import logger from src.services.file_service import get_file_contents diff --git a/src/routers/graphql.py b/src/routers/graphql.py index 2706470..fd08707 100644 --- a/src/routers/graphql.py +++ b/src/routers/graphql.py @@ -1,39 +1,44 @@ -# GraphQL endpoint -from functools import cached_property from fastapi import Request import strawberry -from strawberry.fastapi import GraphQLRouter, BaseContext +from strawberry.fastapi import GraphQLRouter from typing import Any, AsyncGenerator +from src.data.db import SessionLocal +from src.resolvers.user_resolvers import get_current_user from src.schemas.query import Query from src.schemas.mutation import Mutation -from src.schemas.types import User -from src.services.auth import get_current_user - - -class Context(BaseContext): - @cached_property - def user(self) -> User | None: - if not self.request: - return None - - authorization = self.request.headers.get("Authorization", None) - if not authorization: - return None - token = authorization.split(" ")[1] - return get_current_user(token) async def get_context(request: Request) -> AsyncGenerator[dict[str, Any], None]: + session = SessionLocal() auth_header = request.headers.get("Authorization") current_user = None + profile = None if auth_header: token = auth_header.split(" ")[1] - current_user = get_current_user(token) - - yield {"request": request, "current_user": current_user} + current_user, profile = get_current_user(session, token) + + try: + if current_user and profile: + print( + "user", + { + "user": current_user, + "profile": profile, + }, + ) + yield { + "request": request, + "user": current_user, + "profile": profile, + "session": session, + } + else: + yield {"request": request, "session": session} + finally: + session.close() schema = strawberry.Schema(query=Query, mutation=Mutation) diff --git a/src/routers/media_router.py b/src/routers/media_router.py index 7eec77d..9aeece4 100644 --- a/src/routers/media_router.py +++ b/src/routers/media_router.py @@ -1,14 +1,11 @@ import io -from fastapi import APIRouter, Depends, UploadFile +from fastapi import APIRouter, UploadFile +from src.resolvers.chat_resolvers import analyze_user_input from src.utils.ai_models import audio_models -from src.services.auth import get_current_user from src.services.groq_service import groq_client from src.services.media import transcribe_audio from src.services.openai_service import openai_client -from src.services.sherpa_service import analyze_user_input -from src.schemas.query import User - media_router = APIRouter() @@ -17,7 +14,6 @@ def transcription_route( audio_file: UploadFile, model: str = "openai", - current_user: User = Depends(get_current_user), ): """ Transcribes audio using either OpenAI's whisper or Groq's Whisper API. diff --git a/src/schemas/mutation.py b/src/schemas/mutation.py index fb89bd3..9fa9956 100644 --- a/src/schemas/mutation.py +++ b/src/schemas/mutation.py @@ -1,10 +1,20 @@ import io +import profile import strawberry from strawberry.file_uploads import Upload from typing import List +from src.data.models import Note -from src.schemas.types import Message -from src.services.sherpa_service import insert_message, get_sherpa_response +from src.resolvers.user_resolvers import ( + AuthPayload, + CreateUserPayload, + create_user_and_profile, + refresh_token, + save_apple_user, + update_profile, +) +from src.resolvers.chat_resolvers import send_chat_message +from src.schemas.types import Message, UpdateProfilePayload, CreateNote from src.services.media import transcribe_audio, TranscribeAudioResponse from src.utils.logger import logger @@ -14,13 +24,45 @@ class UploadVoiceNoteResponse(TranscribeAudioResponse): pass +@strawberry.input +class CreateNoteInput: + content: str + + @strawberry.type class Mutation: + create_user: CreateUserPayload = strawberry.field(resolver=create_user_and_profile) + save_apple_user: AuthPayload = strawberry.field(resolver=save_apple_user) + refresh_token: AuthPayload = strawberry.field(resolver=refresh_token) + update_profile: UpdateProfilePayload = strawberry.field(resolver=update_profile) + send_chat_message: List[Message] = strawberry.field(resolver=send_chat_message) + + @strawberry.field + async def create_note( + self, info: strawberry.Info, input: CreateNoteInput + ) -> CreateNote: + current_user = info.context.get("user") + if not current_user: + raise Exception("Unauthorized") + + session = info.context.get("session") + profile_id = info.context.get("profile").id + note = Note(content=input.content, profile_id=profile_id) + session.add(note) + session.commit() + + note_dict = note.__dict__ + return CreateNote( + id=note_dict["id"], + content=note_dict["content"], + created_at=note_dict["created_at"], + ) + @strawberry.field async def upload_voice_note( self, info: strawberry.Info, audio_file: Upload, chat_id: int ) -> UploadVoiceNoteResponse: - current_user = info.context.get("current_user") + current_user = info.context.get("user") if not current_user: raise Exception("Unauthorized") @@ -31,31 +73,17 @@ async def upload_voice_note( logger.info(f"Transcription: {transcription.text}") return UploadVoiceNoteResponse(text=transcription.text, error=None) - @strawberry.field - async def send_chat_message( - self, info: strawberry.Info, chat_id: int, message: str - ) -> List[Message]: - current_user = info.context.get("current_user") - if not current_user: - raise Exception("Unauthorized") - - # Insert new message into the database - user_message = insert_message(chat_id, message, current_user.id, "user") - # Retrieve message from ChatGPT - sherpa_response = get_sherpa_response(message, chat_id, current_user.id) - if sherpa_response is None: - raise Exception("No response from the model") - - # Save system response to the database - system_message = insert_message( - chat_id=chat_id, - user_id=current_user.id, - role="assistant", - message=sherpa_response, - ) +# @strawberry.type +# class Subscription: +# @strawberry.subscription +# async def user_account_changed(self, info: Info) -> str: +# # Get the current user from the context +# user_id = info.context.get("user_id") +# if not user_id: +# raise ValueError("Not authenticated") - return [ - user_message, - system_message, - ] +# async with broadcast.subscribe(channel="user_changes") as subscriber: +# async for event in subscriber: +# if event.message == str(user_id): +# yield event.message diff --git a/src/schemas/query.py b/src/schemas/query.py index a8d9e03..371025b 100644 --- a/src/schemas/query.py +++ b/src/schemas/query.py @@ -1,35 +1,42 @@ from typing import List import strawberry -from src.data.db import Session - -from src.data.models import Chat as ChatModel -from src.services.notebooks import get_notebooks, get_user_notes -from src.schemas.types import Chat, Note, Notebook, User - - -async def chats(self, info: strawberry.Info) -> List[Chat]: - session = Session() - chats = ( - session.query(ChatModel) - .filter(ChatModel.user_id == info.context.get("current_user").id) - .all() - ) - return [Chat(**chat.__dict__) for chat in chats] - - -async def current_user(self, info: strawberry.Info) -> User: - current_user = info.context.get("current_user") - - if not current_user: - raise Exception("Unauthorized") - - return User(id=current_user.id, email=current_user.email) +from src.resolvers.chat_resolvers import chats, chat_messages +from src.resolvers.user_resolvers import GetProfileOutput, get_profile +from src.data.notes import get_user_notes +from src.schemas.types import Chat, Message, NoteOutput, User @strawberry.type class Query: chats: List[Chat] = strawberry.field(resolver=chats) - current_user: User = strawberry.field(resolver=current_user) - notebooks: List[Notebook] = strawberry.field(resolver=get_notebooks) - notes: List[Note] = strawberry.field(resolver=get_user_notes) + chat_messages: List[Message] = strawberry.field(resolver=chat_messages) + profile: GetProfileOutput = strawberry.field(resolver=get_profile) + + @strawberry.field + async def notes(self, info: strawberry.Info) -> List[NoteOutput]: + current_user = info.context.get("user") + + if not current_user: + raise Exception("Unauthorized") + + profile_id = info.context.get("profile").id + notes = get_user_notes(info.context.get("session"), profile_id) + note_dicts = [note.__dict__ for note in notes] + return [ + NoteOutput( + id=note["id"], + content=note["content"], + created_at=note["created_at"], + ) + for note in note_dicts + ] + + @strawberry.field + async def current_user(self, info: strawberry.Info) -> User: + current_user = info.context.get("user") + + if not current_user: + raise Exception("Unauthorized") + + return User(id=current_user.id, email=current_user.email) diff --git a/src/schemas/types.py b/src/schemas/types.py index ce4220b..0caeafd 100644 --- a/src/schemas/types.py +++ b/src/schemas/types.py @@ -1,4 +1,5 @@ import enum +from typing import List import strawberry @@ -11,25 +12,23 @@ class MessageRole(enum.Enum): @strawberry.type class Profile: id: int - name: str | None - avatar_url: str | None + full_name: str | None user_id: str @strawberry.type class Chat: - id: int + id: str title: str - profile_id: str created_at: str @strawberry.type class Message: - id: int + id: str message: str role: MessageRole - chat_id: int + chat_id: str profile_id: str created_at: str @@ -41,28 +40,42 @@ class ChatMessageInput: @strawberry.type -class ChatMessageOutput: +class NoteOutput: + id: str content: str + created_at: str @strawberry.type -class Notebook: - title: str +class CreateNote: id: str + content: str created_at: str - updated_at: str - user_id: str @strawberry.type -class Note: - content: str +class User: id: str - created_at: str + email: str | None + + +@strawberry.input +class UpdateProfileInput: + full_name: str user_id: str @strawberry.type -class User: - id: str - email: str | None +class UpdateProfilePayload: + profile: Profile + + +@strawberry.input +class CreateUserInput: + email: str + + +@strawberry.type +class CreateUserPayload: + user: User + profile: Profile diff --git a/src/services/auth.py b/src/services/auth.py deleted file mode 100644 index 1e04416..0000000 --- a/src/services/auth.py +++ /dev/null @@ -1,14 +0,0 @@ -from fastapi import HTTPException -from src.services.supabase import supabase_client -from src.schemas.types import User - - -def get_current_user(token: str) -> User: - response = supabase_client.auth.get_user(token) - - if response is None: - raise HTTPException( - status_code=403, detail="Invalid authentication credentials" - ) - - return User(id=response.user.id, email=response.user.email) diff --git a/src/services/notebooks.py b/src/services/notebooks.py deleted file mode 100644 index be5e485..0000000 --- a/src/services/notebooks.py +++ /dev/null @@ -1,19 +0,0 @@ -from typing import List -from src.data.db import Session - -from src.data.models import Note as NoteModel -from src.schemas.types import Note -from src.utils.logger import logger - - -def get_notebooks(): - return [] - - -def get_user_notes(user_id: str) -> List[Note]: - logger.info("notes_search_start", {"user_id": user_id}) - session = Session() - notes = session.query(NoteModel).filter(NoteModel.user_id == user_id).all() - note_dicts = [note.__dict__ for note in notes] - logger.info("notes_search_end", {"user_id": user_id}) - return [Note(**note) for note in note_dicts] diff --git a/src/services/sherpa_service.py b/src/services/sherpa_service.py deleted file mode 100644 index 0494ac4..0000000 --- a/src/services/sherpa_service.py +++ /dev/null @@ -1,142 +0,0 @@ -from enum import Enum -import json -from typing import List -from src.data.db import Session - -from src.data.models import Message as MessageModel -from src.schemas.types import Message -from src.services.file_service import get_file_contents -from src.services.groq_service import groq_client -from src.services.notebooks import get_user_notes -from src.services.openai_service import openai_client -from src.services.supabase import supabase_client -from src.utils.ai_models import open_source_models -from src.utils.logger import logger -from src.utils.generation_statistics import GenerationStatistics - - -class AvailablePrompts(Enum): - v1 = "user_input_formatter_v1.md" - v2 = "user_input_formatter_v2.md" - - -def get_prompt(prompt: AvailablePrompts): - return get_file_contents(f"src/prompts/{prompt.value}") - - -def analyze_user_input(transcript: str, model: str = "llama3-70b-8192"): - system_prompt = get_prompt(AvailablePrompts.v2) - - if model in open_source_models: - - try: - completion = groq_client.chat.completions.create( - model=model, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": transcript}, - ], - temperature=0.3, - max_tokens=8000, - top_p=1, - stream=False, - response_format={"type": "json_object"}, - stop=None, - ) - - usage = completion.usage - # print("------ USAGE ---", usage) - - except Exception as e: - logger.error(f" ********* API error ********: {e} ***** ") - return None, {"error": str(e)} - - try: - if usage: - statistics_to_return = GenerationStatistics( - input_time=int(usage.prompt_time) if usage.prompt_time else 0, - output_time=( - int(usage.completion_time) if usage.completion_time else 0 - ), - input_tokens=usage.prompt_tokens, - output_tokens=usage.completion_tokens, - total_time=int(usage.total_time) if usage.total_time else 0, - model_name=model, - ) - logger.info("focus_stats", statistics_to_return.get_stats()) - - return ( - json.loads(completion.choices[0].message.content) - if completion.choices[0].message.content - else None - ) - except Exception as e: - logger.error(f" ********* STATISTICS GENERATION error ******* : {e} ") - return None, {"error": str(e)} - - -def insert_message(chat_id: int, message: str, profile_id: str, role: str) -> Message: - session = Session() - new_message = MessageModel( - message=message, chat_id=chat_id, profile_id=profile_id, role=role - ) - session.add(new_message) - session.commit() - - return Message(**new_message.__dict__) - - -def get_chat_history(chat_id: int) -> List[Message]: - session = Session() - chat = session.query(MessageModel).filter(MessageModel.chat_id == chat_id).all() - return [Message(**message) for message in chat] - - -def get_sherpa_response(message: str, chat_id, user_id) -> str | None: - system_prompt = """ - You are the user's expert-level personal assistant and best friend. - - You have full history of the user's chat with you and their Context, which is a list of notes they have taken. - These notes include their goals, tasks, and any other important information they have shared with you. - - The user is going to provide with their entire chat history with you, along with their latest message and \n - - You must respond to the user's message based on the chat history. - - ## Rules: - - Do not say who you are or that you are an AI. - - Do not speak in paragraphs. - - Respond with the least amount of words possible, but use full sentences. - - Include emojis in your responses where applicable. - - Your response should be in a friendly, upbeat and conversational tone. - - Your response should use all of the User Context and the entire Chat History to provide context to your response. - - Your response should use that knowledge about the user to answer the user's latest message. - - ## User Context - {user_context} - - ## Chat History - {chat_history} - """ - - chat_history = get_chat_history(chat_id) - user_context = get_user_notes(user_id) - chat_history_contents = [message.message for message in chat_history] - user_context_contents = [note.content for note in user_context] - - response = openai_client.chat.completions.create( - model="gpt-3.5-turbo", - messages=[ - { - "role": "system", - "content": system_prompt.format( - chat_history=chat_history_contents, - user_context=user_context_contents, - ), - }, - {"role": "user", "content": message}, - ], - stream=False, - ) - content = response.choices[0].message.content - return content