Skip to content

Commit

Permalink
Add pgvector provider implementation (apache#35399)
Browse files Browse the repository at this point in the history
This PR is part of our larger effort to add first-class integrations to support LLMOps that was [presented at Airflow Summit](https://www.youtube.com/watch?v=mgA6m3ggKhs&t=4s).

This PR explicitly adds the pgvector provider.
https://github.com/pgvector/pgvector

The email discussion related to this effort can be found here: https://lists.apache.org/thread/0d669fmy4hn29h5c0wj0ottdskd77ktp
  • Loading branch information
pankajkoti authored and romsharon98 committed Nov 10, 2023
1 parent b319fa1 commit 202b63d
Show file tree
Hide file tree
Showing 36 changed files with 1,036 additions and 123 deletions.
1 change: 1 addition & 0 deletions .github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ body:
- oracle
- pagerduty
- papermill
- pgvector
- pinecone
- plexus
- postgres
Expand Down
8 changes: 4 additions & 4 deletions CONTRIBUTING.rst
Original file line number Diff line number Diff line change
Expand Up @@ -677,10 +677,10 @@ doc, doc_gen, docker, druid, elasticsearch, exasol, facebook, ftp, gcp, gcp_api,
github_enterprise, google, google_auth, grpc, hashicorp, hdfs, hive, http, imap, influxdb, jdbc,
jenkins, kerberos, kubernetes, ldap, leveldb, microsoft.azure, microsoft.mssql, microsoft.psrp,
microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openfaas, openlineage, opensearch, opsgenie,
oracle, otel, pagerduty, pandas, papermill, password, pinecone, pinot, plexus, postgres, presto,
rabbitmq, redis, s3, s3fs, salesforce, samba, segment, sendgrid, sentry, sftp, singularity, slack,
smtp, snowflake, spark, sqlite, ssh, statsd, tableau, tabular, telegram, trino, vertica, virtualenv,
weaviate, webhdfs, winrm, yandex, zendesk
oracle, otel, pagerduty, pandas, papermill, password, pgvector, pinecone, pinot, plexus, postgres,
presto, rabbitmq, redis, s3, s3fs, salesforce, samba, segment, sendgrid, sentry, sftp, singularity,
slack, smtp, snowflake, spark, sqlite, ssh, statsd, tableau, tabular, telegram, trino, vertica,
virtualenv, weaviate, webhdfs, winrm, yandex, zendesk
.. END EXTRAS HERE
Provider packages
Expand Down
8 changes: 4 additions & 4 deletions INSTALL
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ doc, doc_gen, docker, druid, elasticsearch, exasol, facebook, ftp, gcp, gcp_api,
github_enterprise, google, google_auth, grpc, hashicorp, hdfs, hive, http, imap, influxdb, jdbc,
jenkins, kerberos, kubernetes, ldap, leveldb, microsoft.azure, microsoft.mssql, microsoft.psrp,
microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openfaas, openlineage, opensearch, opsgenie,
oracle, otel, pagerduty, pandas, papermill, password, pinecone, pinot, plexus, postgres, presto,
rabbitmq, redis, s3, s3fs, salesforce, samba, segment, sendgrid, sentry, sftp, singularity, slack,
smtp, snowflake, spark, sqlite, ssh, statsd, tableau, tabular, telegram, trino, vertica, virtualenv,
weaviate, webhdfs, winrm, yandex, zendesk
oracle, otel, pagerduty, pandas, papermill, password, pgvector, pinecone, pinot, plexus, postgres,
presto, rabbitmq, redis, s3, s3fs, salesforce, samba, segment, sendgrid, sentry, sftp, singularity,
slack, smtp, snowflake, spark, sqlite, ssh, statsd, tableau, tabular, telegram, trino, vertica,
virtualenv, weaviate, webhdfs, winrm, yandex, zendesk
# END EXTRAS HERE

# For installing Airflow in development environments - see CONTRIBUTING.rst
Expand Down
26 changes: 26 additions & 0 deletions airflow/providers/pgvector/CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
``apache-airflow-providers-pgvector``

Changelog
---------

1.0.0
.....

Initial version of the provider.
43 changes: 43 additions & 0 deletions airflow/providers/pgvector/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE
# OVERWRITTEN WHEN PREPARING DOCUMENTATION FOR THE PACKAGES.
#
# IF YOU WANT TO MODIFY IT, YOU SHOULD MODIFY THE TEMPLATE
# `PROVIDER__INIT__PY_TEMPLATE.py.jinja2` IN the `dev/provider_packages` DIRECTORY
#
from __future__ import annotations

import packaging.version

__all__ = ["__version__"]

__version__ = "1.0.0"

# Look up the version of the Airflow installation this provider is imported into.
# The top-level ``airflow`` package is tried first; ``airflow.version`` is the fallback
# when that import raises ImportError.
try:
    from airflow import __version__ as airflow_version
except ImportError:
    from airflow.version import version as airflow_version

# Compare on the base version only, so that any pre-release / dev / local suffix of the
# installed Airflow version is ignored when checking against the 2.5.0 minimum.
if packaging.version.parse("2.5.0") > packaging.version.parse(
    packaging.version.parse(airflow_version).base_version
):
    raise RuntimeError(
        f"The package `apache-airflow-providers-pgvector:{__version__}` requires Apache Airflow 2.5.0+"
    )
24 changes: 24 additions & 0 deletions airflow/providers/pgvector/hooks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE
# OVERWRITTEN WHEN PREPARING DOCUMENTATION FOR THE PACKAGES.
#
# IF YOU WANT TO MODIFY IT, YOU SHOULD MODIFY THE TEMPLATE
# `PROVIDER__INIT__PY_TEMPLATE.py.jinja2` IN the `dev/provider_packages` DIRECTORY
#
80 changes: 80 additions & 0 deletions airflow/providers/pgvector/hooks/pgvector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from airflow.providers.postgres.hooks.postgres import PostgresHook


class PgVectorHook(PostgresHook):
    """PostgresHook subclass adding small DDL helpers for databases that use the pgvector extension."""

    def __init__(self, *args, **kwargs) -> None:
        """Pass every positional and keyword argument straight through to ``PostgresHook``."""
        super().__init__(*args, **kwargs)

    def create_table(self, table_name: str, columns: list[str], if_not_exists: bool = True) -> None:
        """
        Build a ``CREATE TABLE`` statement and hand it to ``self.run``.

        :param table_name: Name of the table; inserted into the statement verbatim.
        :param columns: Column definitions; joined with ``", "`` between the parentheses.
        :param if_not_exists: When truthy, ``IF NOT EXISTS`` is added to the statement.
        """
        existence_clause = " IF NOT EXISTS" if if_not_exists else ""
        column_spec = ", ".join(columns)
        self.run(f"CREATE TABLE{existence_clause} {table_name} ({column_spec})")

    def create_extension(self, extension_name: str, if_not_exists: bool = True) -> None:
        """
        Build a ``CREATE EXTENSION`` statement and hand it to ``self.run``.

        :param extension_name: Name of the extension; inserted into the statement verbatim.
        :param if_not_exists: When truthy, ``IF NOT EXISTS`` is added to the statement.
        """
        existence_clause = " IF NOT EXISTS" if if_not_exists else ""
        self.run(f"CREATE EXTENSION{existence_clause} {extension_name}")

    def drop_table(self, table_name: str, if_exists: bool = True) -> None:
        """
        Build a ``DROP TABLE`` statement and hand it to ``self.run``.

        :param table_name: Name of the table; inserted into the statement verbatim.
        :param if_exists: When truthy, ``IF EXISTS`` is added to the statement.
        """
        existence_clause = " IF EXISTS" if if_exists else ""
        self.run(f"DROP TABLE{existence_clause} {table_name}")

    def truncate_table(self, table_name: str, restart_identity: bool = True) -> None:
        """
        Build a ``TRUNCATE TABLE`` statement and hand it to ``self.run``.

        :param table_name: Name of the table; inserted into the statement verbatim.
        :param restart_identity: When truthy, ``RESTART IDENTITY`` is appended to the statement.
        """
        identity_clause = " RESTART IDENTITY" if restart_identity else ""
        self.run(f"TRUNCATE TABLE {table_name}{identity_clause}")
24 changes: 24 additions & 0 deletions airflow/providers/pgvector/operators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE
# OVERWRITTEN WHEN PREPARING DOCUMENTATION FOR THE PACKAGES.
#
# IF YOU WANT TO MODIFY IT, YOU SHOULD MODIFY THE TEMPLATE
# `PROVIDER__INIT__PY_TEMPLATE.py.jinja2` IN the `dev/provider_packages` DIRECTORY
#
49 changes: 49 additions & 0 deletions airflow/providers/pgvector/operators/pgvector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from pgvector.psycopg2 import register_vector

from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator


class PgVectorIngestOperator(SQLExecuteQueryOperator):
    """
    Ingest data into a PostgreSQL database with pgvector support.

    The operator inherits from ``SQLExecuteQueryOperator`` and, before delegating to the
    parent's ``execute``, calls pgvector's ``register_vector`` on a connection obtained
    from the operator's database hook.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:PgVectorIngestOperator`
    """

    def __init__(self, *args, **kwargs) -> None:
        """Initialize a new PgVectorIngestOperator; all arguments go to the parent class."""
        super().__init__(*args, **kwargs)

    def _register_vector(self) -> None:
        """Register the pgvector ``vector`` type using a connection from the database hook."""
        # NOTE(review): the connection obtained here is not closed in this block, and it is
        # presumably not the same connection the parent uses to run the SQL - confirm that
        # the registration performed by ``register_vector`` is visible to that connection.
        conn = self.get_db_hook().get_conn()
        register_vector(conn)

    def execute(self, context):
        """
        Register the vector type, then run the query through the parent operator.

        :param context: The task execution context, forwarded unchanged to the parent.
        :return: Whatever the parent ``execute`` returns, so that query results are not
            silently dropped (previously this method always returned ``None``).
        """
        self._register_vector()
        return super().execute(context)
51 changes: 51 additions & 0 deletions airflow/providers/pgvector/provider.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

---
package-name: apache-airflow-providers-pgvector

name: pgvector

description: |
`pgvector <https://github.com/pgvector/pgvector>`__
suspended: false

versions:
- 1.0.0

integrations:
- integration-name: pgvector
external-doc-url: https://github.com/pgvector/pgvector
how-to-guide:
- /docs/apache-airflow-providers-pgvector/operators/pgvector.rst
tags: [software]

dependencies:
- apache-airflow>=2.5.0
- apache-airflow-providers-postgres>=5.7.1
- pgvector>=0.2.3

hooks:
- integration-name: pgvector
python-modules:
- airflow.providers.pgvector.hooks.pgvector

operators:
- integration-name: pgvector
python-modules:
- airflow.providers.pgvector.operators.pgvector
12 changes: 7 additions & 5 deletions dev/breeze/tests/test_selective_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
"tests/providers/postgres/file.py",
),
{
"affected-providers-list-as-string": "amazon common.sql google openlineage postgres",
"affected-providers-list-as-string": "amazon common.sql google openlineage "
"pgvector postgres",
"all-python-versions": "['3.8']",
"all-python-versions-list-as-string": "3.8",
"python-versions": "['3.8']",
Expand All @@ -169,7 +170,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
"docs-build": "true",
"upgrade-to-newer-dependencies": "false",
"parallel-test-types-list-as-string": "API Always Providers[amazon] "
"Providers[common.sql,openlineage,postgres] Providers[google]",
"Providers[common.sql,openlineage,pgvector,postgres] Providers[google]",
},
id="API and providers tests and docs should run",
)
Expand Down Expand Up @@ -225,7 +226,8 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
"tests/providers/postgres/file.py",
),
{
"affected-providers-list-as-string": "amazon common.sql google openlineage postgres",
"affected-providers-list-as-string": "amazon common.sql google openlineage "
"pgvector postgres",
"all-python-versions": "['3.8']",
"all-python-versions-list-as-string": "3.8",
"python-versions": "['3.8']",
Expand All @@ -239,7 +241,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
"run-kubernetes-tests": "true",
"upgrade-to-newer-dependencies": "false",
"parallel-test-types-list-as-string": "Always Providers[amazon] "
"Providers[common.sql,openlineage,postgres] Providers[google]",
"Providers[common.sql,openlineage,pgvector,postgres] Providers[google]",
},
id="Helm tests, providers (both upstream and downstream),"
"kubernetes tests and docs should run",
Expand Down Expand Up @@ -1064,7 +1066,7 @@ def test_upgrade_to_newer_dependencies(
"docs-list-as-string": "apache-airflow amazon apache.drill apache.druid apache.hive "
"apache.impala apache.pinot common.sql databricks elasticsearch "
"exasol google jdbc microsoft.mssql mysql odbc openlineage "
"oracle postgres presto slack snowflake sqlite trino vertica",
"oracle pgvector postgres presto slack snowflake sqlite trino vertica",
},
id="Common SQL provider package python files changed",
),
Expand Down
Loading

0 comments on commit 202b63d

Please sign in to comment.