Skip to content

Commit

Permalink
add function SessionContext::register_batch to register a single RecordBatch as a table (#3600)
Browse files Browse the repository at this point in the history
  • Loading branch information
BaymaxHWY authored Sep 24, 2022
1 parent d7c0e42 commit 0b204c6
Show file tree
Hide file tree
Showing 22 changed files with 120 additions and 215 deletions.
4 changes: 1 addition & 3 deletions benchmarks/src/bin/tpch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1102,9 +1102,7 @@ mod tests {
let schema = get_schema(table);
let batch = RecordBatch::new_empty(Arc::new(schema.to_owned()));

let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?;

ctx.register_table(table, Arc::new(provider))?;
ctx.register_batch(table, batch)?;
}

let sql = &get_query_sql(n)?;
Expand Down
6 changes: 2 additions & 4 deletions datafusion-examples/examples/dataframe_in_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ use std::sync::Arc;
use datafusion::arrow::array::{Int32Array, StringArray};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::from_slice::FromSlice;
use datafusion::prelude::*;
Expand All @@ -36,7 +35,7 @@ async fn main() -> Result<()> {

// define data.
let batch = RecordBatch::try_new(
schema.clone(),
schema,
vec![
Arc::new(StringArray::from_slice(&["a", "b", "c", "d"])),
Arc::new(Int32Array::from_slice(&[1, 10, 10, 100])),
Expand All @@ -47,8 +46,7 @@ async fn main() -> Result<()> {
let ctx = SessionContext::new();

// declare a table in memory. In spark API, this corresponds to createDataFrame(...).
let provider = MemTable::try_new(schema, vec![vec![batch]])?;
ctx.register_table("t", Arc::new(provider))?;
ctx.register_batch("t", batch)?;
let df = ctx.table("t")?;

// construct an expression corresponding to "SELECT a, b FROM t WHERE b = 10" in SQL
Expand Down
6 changes: 2 additions & 4 deletions datafusion-examples/examples/simple_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ use std::sync::Arc;
// create local execution context with an in-memory table
fn create_context() -> Result<SessionContext> {
use datafusion::arrow::datatypes::{Field, Schema};
use datafusion::datasource::MemTable;
// define a schema.
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float32, false),
Expand All @@ -41,7 +40,7 @@ fn create_context() -> Result<SessionContext> {

// define data.
let batch = RecordBatch::try_new(
schema.clone(),
schema,
vec![
Arc::new(Float32Array::from_slice(&[2.1, 3.1, 4.1, 5.1])),
Arc::new(Float64Array::from_slice(&[1.0, 2.0, 3.0, 4.0])),
Expand All @@ -52,8 +51,7 @@ fn create_context() -> Result<SessionContext> {
let ctx = SessionContext::new();

// declare a table in memory. In spark API, this corresponds to createDataFrame(...).
let provider = MemTable::try_new(schema, vec![vec![batch]])?;
ctx.register_table("t", Arc::new(provider))?;
ctx.register_batch("t", batch)?;
Ok(ctx)
}

Expand Down
4 changes: 1 addition & 3 deletions datafusion/core/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1361,10 +1361,8 @@ mod tests {
],
)?;

let table = crate::datasource::MemTable::try_new(schema, vec![vec![data]])?;

let ctx = SessionContext::new();
ctx.register_table("test", Arc::new(table))?;
ctx.register_batch("test", data)?;

let sql = r#"
SELECT
Expand Down
25 changes: 25 additions & 0 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ use std::{
};

use arrow::datatypes::{DataType, SchemaRef};
use arrow::record_batch::RecordBatch;

use crate::catalog::{
catalog::{CatalogProvider, MemoryCatalogProvider},
Expand Down Expand Up @@ -230,6 +231,16 @@ impl SessionContext {
self.table_factories.insert(file_type.to_string(), factory);
}

/// Registers a single [`RecordBatch`] as a table under `table_name`.
///
/// The batch is wrapped in an in-memory [`MemTable`] (one partition, one
/// batch) and handed to [`Self::register_table`]. Returns the previously
/// registered provider for that name, if any, per `register_table`'s contract.
pub fn register_batch(
    &self,
    table_name: &str,
    batch: RecordBatch,
) -> Result<Option<Arc<dyn TableProvider>>> {
    // The MemTable's schema is taken directly from the batch itself.
    let schema = batch.schema();
    let provider = MemTable::try_new(schema, vec![vec![batch]])?;
    self.register_table(table_name, Arc::new(provider))
}

/// Return the [RuntimeEnv] used to run queries with this [SessionContext]
pub fn runtime_env(&self) -> Arc<RuntimeEnv> {
self.state.read().runtime_env.clone()
Expand Down Expand Up @@ -713,6 +724,20 @@ impl SessionContext {
)))
}

/// Creates a [`DataFrame`] that reads the given [`RecordBatch`].
///
/// The batch is exposed through an in-memory [`MemTable`] (single partition)
/// and scanned under the anonymous table name `UNNAMED_TABLE`.
pub fn read_batch(&self, batch: RecordBatch) -> Result<Arc<DataFrame>> {
    // Wrap the batch in a MemTable whose schema comes from the batch itself.
    let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
    let source = provider_as_source(Arc::new(table));
    // Build a table-scan plan over the whole batch (no projection).
    let plan = LogicalPlanBuilder::scan(UNNAMED_TABLE, source, None)?.build()?;
    Ok(Arc::new(DataFrame::new(self.state.clone(), &plan)))
}

/// Registers a table that uses the listing feature of the object store to
/// find the files to be processed
/// This is async because it might need to resolve the schema.
Expand Down
24 changes: 8 additions & 16 deletions datafusion/core/tests/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use datafusion::error::Result;
use datafusion::execution::context::SessionContext;
use datafusion::logical_plan::{col, Expr};
use datafusion::prelude::CsvReadOptions;
use datafusion::{datasource::MemTable, prelude::JoinType};
use datafusion::prelude::JoinType;
use datafusion_expr::expr::GroupingSet;
use datafusion_expr::{avg, count, lit, sum};

Expand Down Expand Up @@ -63,14 +63,11 @@ async fn join() -> Result<()> {

let ctx = SessionContext::new();

let table1 = MemTable::try_new(schema1, vec![vec![batch1]])?;
let table2 = MemTable::try_new(schema2, vec![vec![batch2]])?;

ctx.register_table("aa", Arc::new(table1))?;
ctx.register_batch("aa", batch1)?;

let df1 = ctx.table("aa")?;

ctx.register_table("aaa", Arc::new(table2))?;
ctx.register_batch("aaa", batch2)?;

let df2 = ctx.table("aaa")?;

Expand Down Expand Up @@ -100,8 +97,7 @@ async fn sort_on_unprojected_columns() -> Result<()> {
.unwrap();

let ctx = SessionContext::new();
let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]]).unwrap();
ctx.register_table("t", Arc::new(provider)).unwrap();
ctx.register_batch("t", batch).unwrap();

let df = ctx
.table("t")
Expand Down Expand Up @@ -143,8 +139,7 @@ async fn filter_with_alias_overwrite() -> Result<()> {
.unwrap();

let ctx = SessionContext::new();
let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]]).unwrap();
ctx.register_table("t", Arc::new(provider)).unwrap();
ctx.register_batch("t", batch).unwrap();

let df = ctx
.table("t")
Expand Down Expand Up @@ -180,8 +175,7 @@ async fn select_with_alias_overwrite() -> Result<()> {
.unwrap();

let ctx = SessionContext::new();
let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]]).unwrap();
ctx.register_table("t", Arc::new(provider)).unwrap();
ctx.register_batch("t", batch).unwrap();

let df = ctx
.table("t")
Expand Down Expand Up @@ -394,7 +388,7 @@ fn create_test_table() -> Result<Arc<DataFrame>> {

// define data.
let batch = RecordBatch::try_new(
schema.clone(),
schema,
vec![
Arc::new(StringArray::from_slice(&[
"abcDEF",
Expand All @@ -408,9 +402,7 @@ fn create_test_table() -> Result<Arc<DataFrame>> {

let ctx = SessionContext::new();

let table = MemTable::try_new(schema, vec![vec![batch]])?;

ctx.register_table("test", Arc::new(table))?;
ctx.register_batch("test", batch)?;

ctx.table("test")
}
Expand Down
7 changes: 2 additions & 5 deletions datafusion/core/tests/dataframe_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use datafusion::from_slice::FromSlice;
use std::sync::Arc;

use datafusion::dataframe::DataFrame;
use datafusion::datasource::MemTable;

use datafusion::error::Result;

Expand All @@ -43,7 +42,7 @@ fn create_test_table() -> Result<Arc<DataFrame>> {

// define data.
let batch = RecordBatch::try_new(
schema.clone(),
schema,
vec![
Arc::new(StringArray::from_slice(&[
"abcDEF",
Expand All @@ -57,9 +56,7 @@ fn create_test_table() -> Result<Arc<DataFrame>> {

let ctx = SessionContext::new();

let table = MemTable::try_new(schema, vec![vec![batch]])?;

ctx.register_table("test", Arc::new(table))?;
ctx.register_batch("test", batch)?;

ctx.table("test")
}
Expand Down
10 changes: 3 additions & 7 deletions datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,7 @@ async fn median_test(
let ctx = SessionContext::new();
let schema = Arc::new(Schema::new(vec![Field::new("a", data_type, false)]));
let batch = RecordBatch::try_new(schema.clone(), vec![values])?;
let table = Arc::new(MemTable::try_new(schema, vec![vec![batch]])?);
ctx.register_table("t", table)?;
ctx.register_batch("t", batch)?;
let sql = format!("SELECT {}(a) FROM t", func);
let actual = execute(&ctx, &sql).await;
let expected = vec![vec![expected.to_owned()]];
Expand Down Expand Up @@ -2108,9 +2107,8 @@ async fn query_sum_distinct() -> Result<()> {
],
)?;

let table = MemTable::try_new(schema, vec![vec![data]])?;
let ctx = SessionContext::new();
ctx.register_table("test", Arc::new(table))?;
ctx.register_batch("test", data)?;

// 2 different aggregate functions: avg and sum(distinct)
let sql = "SELECT AVG(c1), SUM(DISTINCT c2) FROM test";
Expand Down Expand Up @@ -2153,10 +2151,8 @@ async fn query_count_distinct() -> Result<()> {
]))],
)?;

let table = MemTable::try_new(schema, vec![vec![data]])?;

let ctx = SessionContext::new();
ctx.register_table("test", Arc::new(table))?;
ctx.register_batch("test", data)?;
let sql = "SELECT COUNT(DISTINCT c1) FROM test";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
Expand Down
44 changes: 11 additions & 33 deletions datafusion/core/tests/sql/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,8 @@ async fn query_not() -> Result<()> {
]))],
)?;

let table = MemTable::try_new(schema, vec![vec![data]])?;

let ctx = SessionContext::new();
ctx.register_table("test", Arc::new(table))?;
ctx.register_batch("test", data)?;
let sql = "SELECT NOT c1 FROM test";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
Expand Down Expand Up @@ -266,10 +264,8 @@ async fn query_is_null() -> Result<()> {
]))],
)?;

let table = MemTable::try_new(schema, vec![vec![data]])?;

let ctx = SessionContext::new();
ctx.register_table("test", Arc::new(table))?;
ctx.register_batch("test", data)?;
let sql = "SELECT c1 IS NULL FROM test";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
Expand Down Expand Up @@ -298,10 +294,8 @@ async fn query_is_not_null() -> Result<()> {
]))],
)?;

let table = MemTable::try_new(schema, vec![vec![data]])?;

let ctx = SessionContext::new();
ctx.register_table("test", Arc::new(table))?;
ctx.register_batch("test", data)?;
let sql = "SELECT c1 IS NOT NULL FROM test";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
Expand Down Expand Up @@ -330,10 +324,8 @@ async fn query_is_true() -> Result<()> {
]))],
)?;

let table = MemTable::try_new(schema, vec![vec![data]])?;

let ctx = SessionContext::new();
ctx.register_table("test", Arc::new(table))?;
ctx.register_batch("test", data)?;
let sql = "SELECT c1 IS TRUE as t FROM test";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
Expand Down Expand Up @@ -362,10 +354,8 @@ async fn query_is_false() -> Result<()> {
]))],
)?;

let table = MemTable::try_new(schema, vec![vec![data]])?;

let ctx = SessionContext::new();
ctx.register_table("test", Arc::new(table))?;
ctx.register_batch("test", data)?;
let sql = "SELECT c1 IS FALSE as f FROM test";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
Expand Down Expand Up @@ -394,10 +384,8 @@ async fn query_is_not_true() -> Result<()> {
]))],
)?;

let table = MemTable::try_new(schema, vec![vec![data]])?;

let ctx = SessionContext::new();
ctx.register_table("test", Arc::new(table))?;
ctx.register_batch("test", data)?;
let sql = "SELECT c1 IS NOT TRUE as nt FROM test";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
Expand Down Expand Up @@ -426,10 +414,8 @@ async fn query_is_not_false() -> Result<()> {
]))],
)?;

let table = MemTable::try_new(schema, vec![vec![data]])?;

let ctx = SessionContext::new();
ctx.register_table("test", Arc::new(table))?;
ctx.register_batch("test", data)?;
let sql = "SELECT c1 IS NOT FALSE as nf FROM test";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
Expand Down Expand Up @@ -458,10 +444,8 @@ async fn query_is_unknown() -> Result<()> {
]))],
)?;

let table = MemTable::try_new(schema, vec![vec![data]])?;

let ctx = SessionContext::new();
ctx.register_table("test", Arc::new(table))?;
ctx.register_batch("test", data)?;
let sql = "SELECT c1 IS UNKNOWN as t FROM test";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
Expand Down Expand Up @@ -490,10 +474,8 @@ async fn query_is_not_unknown() -> Result<()> {
]))],
)?;

let table = MemTable::try_new(schema, vec![vec![data]])?;

let ctx = SessionContext::new();
ctx.register_table("test", Arc::new(table))?;
ctx.register_batch("test", data)?;
let sql = "SELECT c1 IS NOT UNKNOWN as t FROM test";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
Expand Down Expand Up @@ -554,10 +536,8 @@ async fn query_scalar_minus_array() -> Result<()> {
]))],
)?;

let table = MemTable::try_new(schema, vec![vec![data]])?;

let ctx = SessionContext::new();
ctx.register_table("test", Arc::new(table))?;
ctx.register_batch("test", data)?;
let sql = "SELECT 4 - c1 FROM test";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
Expand Down Expand Up @@ -1708,11 +1688,9 @@ async fn query_binary_eq() -> Result<()> {
vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)],
)?;

let table = MemTable::try_new(schema, vec![vec![data]])?;

let ctx = SessionContext::new();

ctx.register_table("test", Arc::new(table))?;
ctx.register_batch("test", data)?;

let sql = "
SELECT sha256(c1)=digest('one', 'sha256'), sha256(c2)=sha256('two'), digest(c1, 'blake2b')=digest(c3, 'blake2b'), c2=c4
Expand Down
Loading

0 comments on commit 0b204c6

Please sign in to comment.