
Register tables in BallistaContext using TableProviders instead of Dataframe (#1028)
rdettai authored Sep 20, 2021
1 parent 2258256 commit 843cd93
Showing 5 changed files with 25 additions and 105 deletions.
42 changes: 23 additions & 19 deletions ballista/rust/client/src/context.rs
@@ -23,12 +23,11 @@ use std::path::PathBuf;
 use std::sync::{Arc, Mutex};
 
 use ballista_core::config::BallistaConfig;
-use ballista_core::{
-    datasource::DfTableAdapter, utils::create_df_ctx_with_ballista_query_planner,
-};
+use ballista_core::utils::create_df_ctx_with_ballista_query_planner;
 
 use datafusion::catalog::TableReference;
 use datafusion::dataframe::DataFrame;
+use datafusion::datasource::TableProvider;
 use datafusion::error::{DataFusionError, Result};
 use datafusion::execution::dataframe_impl::DataFrameImpl;
 use datafusion::logical_plan::LogicalPlan;
@@ -44,7 +43,7 @@ struct BallistaContextState {
     /// Scheduler port
     scheduler_port: u16,
     /// Tables that have been registered with this context
-    tables: HashMap<String, LogicalPlan>,
+    tables: HashMap<String, Arc<dyn TableProvider>>,
 }
 
 impl BallistaContextState {
@@ -197,11 +196,13 @@ impl BallistaContext {
     }
 
     /// Register a DataFrame as a table that can be referenced from a SQL query
-    pub fn register_table(&self, name: &str, table: &dyn DataFrame) -> Result<()> {
+    pub fn register_table(
+        &self,
+        name: &str,
+        table: Arc<dyn TableProvider>,
+    ) -> Result<()> {
         let mut state = self.state.lock().unwrap();
-        state
-            .tables
-            .insert(name.to_owned(), table.to_logical_plan());
+        state.tables.insert(name.to_owned(), table);
         Ok(())
     }
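
The new signature accepts any Arc<dyn TableProvider> instead of a DataFrame. A minimal caller-side sketch, not part of the commit (it assumes the `ballista` client crate layout and DataFusion's MemTable from this era of the codebase):

    // Hypothetical usage: any TableProvider, e.g. DataFusion's MemTable,
    // can now be registered directly, with no DataFrame round-trip.
    use std::sync::Arc;

    use ballista::context::BallistaContext;
    use datafusion::arrow::array::Int32Array;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use datafusion::error::Result;

    fn register_in_memory_table(ctx: &BallistaContext) -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        // MemTable implements TableProvider, so it can be stored as-is.
        let provider = MemTable::try_new(schema, vec![vec![batch]])?;
        ctx.register_table("my_table", Arc::new(provider))
    }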

@@ -211,13 +212,17 @@
         path: &str,
         options: CsvReadOptions,
     ) -> Result<()> {
-        let df = self.read_csv(path, options)?;
-        self.register_table(name, df.as_ref())
+        match self.read_csv(path, options)?.to_logical_plan() {
+            LogicalPlan::TableScan { source, .. } => self.register_table(name, source),
+            _ => Err(DataFusionError::Internal("Expected tables scan".to_owned())),
+        }
     }
 
     pub fn register_parquet(&self, name: &str, path: &str) -> Result<()> {
-        let df = self.read_parquet(path)?;
-        self.register_table(name, df.as_ref())
+        match self.read_parquet(path)?.to_logical_plan() {
+            LogicalPlan::TableScan { source, .. } => self.register_table(name, source),
+            _ => Err(DataFusionError::Internal("Expected tables scan".to_owned())),
+        }
     }
 
     pub fn register_avro(
@@ -226,9 +231,10 @@
         path: &str,
         options: AvroReadOptions,
     ) -> Result<()> {
-        let df = self.read_avro(path, options)?;
-        self.register_table(name, df.as_ref())?;
-        Ok(())
+        match self.read_avro(path, options)?.to_logical_plan() {
+            LogicalPlan::TableScan { source, .. } => self.register_table(name, source),
+            _ => Err(DataFusionError::Internal("Expected tables scan".to_owned())),
+        }
     }
 
     /// Create a DataFrame from a SQL statement
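
All three register_* methods above rely on the same invariant: a DataFrame returned by read_csv/read_parquet/read_avro wraps exactly one TableScan node, whose `source` field holds the Arc<dyn TableProvider> worth storing. A standalone sketch of the pattern (the provider_of helper is hypothetical, not in the commit):

    use std::sync::Arc;

    use datafusion::dataframe::DataFrame;
    use datafusion::datasource::TableProvider;
    use datafusion::error::{DataFusionError, Result};
    use datafusion::logical_plan::LogicalPlan;

    // Recover the provider from a freshly created DataFrame; any other plan
    // shape means the DataFrame was not a plain table read.
    fn provider_of(df: Arc<dyn DataFrame>) -> Result<Arc<dyn TableProvider>> {
        match df.to_logical_plan() {
            LogicalPlan::TableScan { source, .. } => Ok(source),
            _ => Err(DataFusionError::Internal(
                "expected a plan with a single TableScan".to_owned(),
            )),
        }
    }
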
@@ -245,12 +251,10 @@
         // register tables with DataFusion context
         {
             let state = self.state.lock().unwrap();
-            for (name, plan) in &state.tables {
-                let plan = ctx.optimize(plan)?;
-                let execution_plan = ctx.create_physical_plan(&plan)?;
+            for (name, prov) in &state.tables {
                 ctx.register_table(
                     TableReference::Bare { table: name },
-                    Arc::new(DfTableAdapter::new(plan, execution_plan)),
+                    Arc::clone(prov),
                 )?;
             }
         }
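
With providers stored directly, sql() no longer optimizes and physically plans every registered table up front; it simply hands each provider to the per-query DataFusion context. A rough end-to-end sketch under the client API assumed at this commit (remote() constructor, synchronous sql(); host, port, and file path are placeholders):

    use ballista::context::BallistaContext;
    use ballista_core::config::BallistaConfig;
    use datafusion::prelude::CsvReadOptions;

    fn main() {
        let config = BallistaConfig::builder().build().unwrap();
        let ctx = BallistaContext::remote("localhost", 50050, &config);

        // register_csv now stores the CSV TableProvider itself...
        ctx.register_csv("example", "tests/example.csv", CsvReadOptions::new())
            .unwrap();

        // ...and sql() passes that provider straight to DataFusion.
        let df = ctx.sql("SELECT count(*) FROM example").unwrap();
        println!("{:?}", df.to_logical_plan());
    }
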
64 changes: 0 additions & 64 deletions ballista/rust/core/src/datasource.rs

This file was deleted.

1 change: 0 additions & 1 deletion ballista/rust/core/src/lib.rs
@@ -25,7 +25,6 @@ pub fn print_version() {
 
 pub mod client;
 pub mod config;
-pub mod datasource;
 pub mod error;
 pub mod execution_plans;
 pub mod memory_stream;
16 changes: 1 addition & 15 deletions ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -20,7 +20,6 @@
 //! processes.
 
 use super::super::proto_error;
-use crate::datasource::DfTableAdapter;
 use crate::serde::{protobuf, BallistaError};
 use datafusion::arrow::datatypes::{
     DataType, Field, IntervalUnit, Schema, SchemaRef, TimeUnit,
@@ -728,20 +727,7 @@ impl TryInto<protobuf::LogicalPlanNode> for &LogicalPlan {
                 ..
             } => {
                 let schema = source.schema();
-
-                // unwrap the DFTableAdapter to get to the real TableProvider
-                let source = if let Some(adapter) =
-                    source.as_any().downcast_ref::<DfTableAdapter>()
-                {
-                    match &adapter.logical_plan {
-                        LogicalPlan::TableScan { source, .. } => Ok(source.as_any()),
-                        _ => Err(BallistaError::General(
-                            "Invalid LogicalPlan::TableScan".to_owned(),
-                        )),
-                    }
-                } else {
-                    Ok(source.as_any())
-                }?;
+                let source = source.as_any();
 
                 let projection = match projection {
                     None => None,
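
Because the TableScan source is now always the real provider, the serializer can inspect it with a single as_any() call instead of first unwrapping DfTableAdapter. A minimal, self-contained illustration of that downcast dispatch (assumed shape; the real code matches DataFusion's concrete providers and emits protobuf):

    use std::sync::Arc;

    use datafusion::arrow::datatypes::Schema;
    use datafusion::datasource::{MemTable, TableProvider};

    // Dispatch on the concrete provider type behind the trait object.
    fn describe(provider: &Arc<dyn TableProvider>) -> &'static str {
        if provider.as_any().downcast_ref::<MemTable>().is_some() {
            "in-memory table"
        } else {
            // The real serializer handles the CSV/Parquet providers here.
            "unsupported provider"
        }
    }

    fn demo() -> datafusion::error::Result<()> {
        let schema = Arc::new(Schema::new(vec![]));
        let provider: Arc<dyn TableProvider> =
            Arc::new(MemTable::try_new(schema, vec![vec![]])?);
        assert_eq!(describe(&provider), "in-memory table");
        Ok(())
    }
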
7 changes: 1 addition & 6 deletions ballista/rust/scheduler/src/planner.rs
@@ -22,13 +22,11 @@
 use std::collections::HashMap;
 use std::sync::Arc;
 
-use ballista_core::datasource::DfTableAdapter;
 use ballista_core::error::{BallistaError, Result};
 use ballista_core::{
     execution_plans::{ShuffleReaderExec, ShuffleWriterExec, UnresolvedShuffleExec},
     serde::scheduler::PartitionLocation,
 };
-use datafusion::execution::context::ExecutionContext;
 use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
 use datafusion::physical_plan::repartition::RepartitionExec;
 use datafusion::physical_plan::windows::WindowAggExec;
@@ -96,10 +94,7 @@ impl DistributedPlanner {
             stages.append(&mut child_stages);
         }
 
-        if let Some(adapter) = execution_plan.as_any().downcast_ref::<DfTableAdapter>() {
-            let ctx = ExecutionContext::new();
-            Ok((ctx.create_physical_plan(&adapter.logical_plan)?, stages))
-        } else if let Some(coalesce) = execution_plan
+        if let Some(coalesce) = execution_plan
             .as_any()
             .downcast_ref::<CoalescePartitionsExec>()
         {
