feat: Implement list_columns for datasources. (#1720)
Lists column information for datasources:

```
> CREATE EXTERNAL DATABASE "PgDb"
::: FROM postgres OPTIONS (
:::   host = 'pg.demo.glaredb.com',
:::   port = '5432',
:::   user = 'demo',
:::   password = 'demo',
:::   database = 'postgres',
::: );
Database created
>
> select * from list_columns("PgDb", public, nation);
┌─────────────┬───────────┬──────────┐
│ column_name │ data_type │ nullable │
│ ──          │ ──        │ ──       │
│ Utf8        │ Utf8      │ Boolean  │
╞═════════════╪═══════════╪══════════╡
│ n_nationkey │ Int32     │ true     │
│ n_name      │ Utf8      │ true     │
│ n_regionkey │ Int64     │ true     │
│ n_comment   │ Utf8      │ true     │
└─────────────┴───────────┴──────────┘
```

Syntax:

```sql
-- All three arguments are identifiers, so they can be either
-- double-quoted or unquoted.
--
-- Prefer double-quoting the identifiers on Cloud.
--
select * from list_columns("<database>", "<schema>", "<table>");
```
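As a quick illustration (assuming the `PgDb` database from the example above), the partially quoted call shown earlier and the fully quoted form recommended for Cloud should both list the same columns:

```sql
-- Lowercase identifiers such as public and nation can stay unquoted;
-- the database was created as "PgDb", so it is referenced with quotes.
select * from list_columns("PgDb", public, nation);

-- Fully quoted form, as in the syntax template above.
select * from list_columns("PgDb", "public", "nation");
```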

Fixes #1540

Signed-off-by: Vaibhav <[email protected]>
vrongmeal authored Sep 7, 2023
1 parent c6f7866 commit bb02791
Showing 15 changed files with 309 additions and 28 deletions.
28 changes: 26 additions & 2 deletions crates/datasources/src/bigquery/mod.rs
@@ -187,13 +187,17 @@ impl VirtualLister for BigQueryAccessor {
Ok(schemas)
}

async fn list_tables(&self, schema: &str) -> Result<Vec<String>, DatasourceCommonError> {
async fn list_tables(&self, dataset_id: &str) -> Result<Vec<String>, DatasourceCommonError> {
use DatasourceCommonError::ListingErrBoxed;

let tables = self
.metadata
.table()
.list(&self.gcp_project_id, schema, table::ListOptions::default())
.list(
&self.gcp_project_id,
dataset_id,
table::ListOptions::default(),
)
.await
.map_err(|e| ListingErrBoxed(Box::new(BigQueryError::from(e))))?;

@@ -206,6 +210,26 @@ impl VirtualLister for BigQueryAccessor {

Ok(tables)
}

async fn list_columns(
&self,
dataset_id: &str,
table_id: &str,
) -> Result<Fields, DatasourceCommonError> {
use DatasourceCommonError::ListingErrBoxed;

let table_meta = self
.metadata
.table()
.get(&self.gcp_project_id, dataset_id, table_id, None)
.await
.map_err(|e| ListingErrBoxed(Box::new(e)))?;

let schema = bigquery_table_to_arrow_schema(&table_meta)
.map_err(|e| ListingErrBoxed(Box::new(e)))?;

Ok(schema.fields)
}
}

pub struct BigQueryTableProvider {
4 changes: 4 additions & 0 deletions crates/datasources/src/common/listing.rs
@@ -4,6 +4,7 @@
//! data source. These essentially provide a trimmed down information schema.

use async_trait::async_trait;
use datafusion::arrow::datatypes::Fields;

use super::errors::Result;

@@ -20,4 +21,7 @@ pub trait VirtualLister: Sync + Send {

/// List tables for a data source.
async fn list_tables(&self, schema: &str) -> Result<Vec<String>>;

/// List columns for a specific table in the datasource.
async fn list_columns(&self, schema: &str, table: &str) -> Result<Fields>;
}
21 changes: 20 additions & 1 deletion crates/datasources/src/debug/mod.rs
@@ -6,7 +6,7 @@ use crate::common::listing::VirtualLister;
use async_trait::async_trait;
use datafusion::arrow::array::Int32Array;
use datafusion::arrow::datatypes::{
DataType, Field, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef,
DataType, Field, Fields, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef,
};
use datafusion::arrow::error::Result as ArrowResult;
use datafusion::arrow::record_batch::RecordBatch;
@@ -158,6 +158,25 @@ impl VirtualLister for DebugVirtualLister {
let tables = (0..2).map(|i| format!("{schema}_table_{i}")).collect();
Ok(tables)
}

async fn list_columns(
&self,
schema: &str,
table: &str,
) -> Result<Fields, DatasourceCommonError> {
Ok((0..2)
.map(|i| {
let name = format!("{schema}_{table}_col_{i}");
let datatype = if i % 2 == 0 {
DataType::Utf8
} else {
DataType::Int64
};
let nullable = i % 2 == 0;
Field::new(name, datatype, nullable)
})
.collect())
}
}

pub struct DebugTableProvider {
24 changes: 21 additions & 3 deletions crates/datasources/src/mongodb/mod.rs
@@ -12,7 +12,7 @@ use infer::TableSampler;
use crate::common::errors::DatasourceCommonError;
use crate::common::listing::VirtualLister;
use async_trait::async_trait;
use datafusion::arrow::datatypes::{Schema as ArrowSchema, SchemaRef as ArrowSchemaRef};
use datafusion::arrow::datatypes::{Fields, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef};
use datafusion::datasource::TableProvider;
use datafusion::error::Result as DatafusionResult;
use datafusion::execution::context::SessionState;
@@ -158,17 +158,35 @@ impl VirtualLister for MongoAccessor {
Ok(databases)
}

async fn list_tables(&self, schema: &str) -> Result<Vec<String>, DatasourceCommonError> {
async fn list_tables(&self, database: &str) -> Result<Vec<String>, DatasourceCommonError> {
use DatasourceCommonError::ListingErrBoxed;

let database = self.client.database(schema);
let database = self.client.database(database);
let collections = database
.list_collection_names(/* filter: */ None)
.await
.map_err(|e| ListingErrBoxed(Box::new(e)))?;

Ok(collections)
}

async fn list_columns(
&self,
database: &str,
collection: &str,
) -> Result<Fields, DatasourceCommonError> {
use DatasourceCommonError::ListingErrBoxed;

let collection = self.client.database(database).collection(collection);
let sampler = TableSampler::new(collection);

let schema = sampler
.infer_schema_from_sample()
.await
.map_err(|e| ListingErrBoxed(Box::new(e)))?;

Ok(schema.fields)
}
}

#[derive(Debug, Clone)]
43 changes: 31 additions & 12 deletions crates/datasources/src/mysql/mod.rs
@@ -15,7 +15,7 @@ use async_stream::stream;
use async_trait::async_trait;
use chrono::{NaiveDate, NaiveDateTime, NaiveTime, Timelike};
use datafusion::arrow::datatypes::{
DataType, Field, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef, TimeUnit,
DataType, Field, Fields, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef, TimeUnit,
};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::TableProvider;
@@ -187,27 +187,31 @@ impl MysqlAccessor {
Ok(())
}

pub async fn into_table_provider(
mut self,
table_access: MysqlTableAccess,
predicate_pushdown: bool,
) -> Result<MysqlTableProvider> {
let conn = self.conn.get_mut();
/// Get the arrow schema for the MySQL table.
async fn get_table_schema(&self, schema: &str, table: &str) -> Result<ArrowSchema> {
let mut conn = self.conn.write().await;

let cols = conn
.exec_iter(
format!(
"SELECT * FROM {}.{} where false",
table_access.schema, table_access.name
),
format!("SELECT * FROM {}.{} where false", schema, table),
(),
)
.await?;
let cols = cols.columns_ref();

// Generate arrow schema from table schema
let arrow_schema = try_create_arrow_schema(cols)?;
trace!(?arrow_schema);
Ok(arrow_schema)
}

pub async fn into_table_provider(
self,
table_access: MysqlTableAccess,
predicate_pushdown: bool,
) -> Result<MysqlTableProvider> {
let arrow_schema = self
.get_table_schema(&table_access.schema, &table_access.name)
.await?;

Ok(MysqlTableProvider {
predicate_pushdown,
@@ -261,6 +265,21 @@ impl VirtualLister for MysqlAccessor {

Ok(cols)
}

async fn list_columns(
&self,
schema: &str,
table: &str,
) -> Result<Fields, DatasourceCommonError> {
use DatasourceCommonError::ListingErrBoxed;

let schema = self
.get_table_schema(schema, table)
.await
.map_err(|e| ListingErrBoxed(Box::new(e)))?;

Ok(schema.fields)
}
}

pub struct MysqlTableProvider {
17 changes: 16 additions & 1 deletion crates/datasources/src/postgres/mod.rs
@@ -12,7 +12,7 @@ use chrono::naive::{NaiveDateTime, NaiveTime};
use chrono::{DateTime, NaiveDate, Timelike, Utc};
use datafusion::arrow::array::Decimal128Builder;
use datafusion::arrow::datatypes::{
DataType, Field, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef, TimeUnit,
DataType, Field, Fields, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef, TimeUnit,
};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::TableProvider;
@@ -462,6 +462,21 @@ WHERE
}
Ok(virtual_tables)
}

async fn list_columns(
&self,
schema: &str,
table: &str,
) -> Result<Fields, DatasourceCommonError> {
use DatasourceCommonError::ListingErrBoxed;

let (schema, _) = self
.get_table_schema(schema, table)
.await
.map_err(|e| ListingErrBoxed(Box::new(e)))?;

Ok(schema.fields)
}
}

pub struct PostgresTableProviderConfig {
30 changes: 25 additions & 5 deletions crates/datasources/src/snowflake/mod.rs
@@ -10,6 +10,7 @@ use crate::common::errors::DatasourceCommonError;
use crate::common::listing::VirtualLister;
use crate::common::util;
use async_trait::async_trait;
use datafusion::arrow::datatypes::Fields;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::execution::context::TaskContext;
use datafusion::physical_expr::PhysicalSortExpr;
@@ -99,15 +100,17 @@ impl SnowflakeAccessor {
let _res = accessor.conn.query_sync(query, vec![]).await?;

// Get table schema
accessor.get_table_schema(table_access).await
accessor
.get_table_schema(&table_access.schema_name, &table_access.table_name)
.await
}

async fn get_table_schema(&self, table_access: &SnowflakeTableAccess) -> Result<ArrowSchema> {
async fn get_table_schema(&self, schema_name: &str, table_name: &str) -> Result<ArrowSchema> {
// Snowflake stores data as upper-case. Maybe this won't be an issue
// when we use bindings but for now, manually transform everything to
// uppercase values.
let table_schema = table_access.schema_name.to_uppercase();
let table_name = table_access.table_name.to_uppercase();
let table_schema = schema_name.to_uppercase();
let table_name = table_name.to_uppercase();

let res = self
.conn
@@ -174,7 +177,9 @@ WHERE
table_access: SnowflakeTableAccess,
predicate_pushdown: bool,
) -> Result<SnowflakeTableProvider> {
let arrow_schema = self.get_table_schema(&table_access).await?;
let arrow_schema = self
.get_table_schema(&table_access.schema_name, &table_access.table_name)
.await?;

Ok(SnowflakeTableProvider {
predicate_pushdown,
@@ -265,6 +270,21 @@ impl VirtualLister for SnowflakeAccessor {

Ok(tables_list)
}

async fn list_columns(
&self,
schema_name: &str,
table_name: &str,
) -> Result<Fields, DatasourceCommonError> {
use DatasourceCommonError::ListingErrBoxed;

let schema = self
.get_table_schema(schema_name, table_name)
.await
.map_err(|e| ListingErrBoxed(Box::new(e)))?;

Ok(schema.fields)
}
}

pub struct SnowflakeTableProvider {
3 changes: 2 additions & 1 deletion crates/sqlbuiltins/src/functions/mod.rs
@@ -27,7 +27,7 @@ use self::mysql::ReadMysql;
use self::object_store::{CSV_SCAN, JSON_SCAN, PARQUET_SCAN};
use self::postgres::ReadPostgres;
use self::snowflake::ReadSnowflake;
use self::virtual_listing::{ListSchemas, ListTables};
use self::virtual_listing::{ListColumns, ListSchemas, ListTables};

/// Builtin table returning functions available for all sessions.
pub static BUILTIN_TABLE_FUNCS: Lazy<BuiltinTableFuncs> = Lazy::new(BuiltinTableFuncs::new);
@@ -58,6 +58,7 @@ impl BuiltinTableFuncs {
// Listing
Arc::new(ListSchemas),
Arc::new(ListTables),
Arc::new(ListColumns),
// Series generating
Arc::new(GenerateSeries),
];