Commit 0232174: Extract more logic into submodule from execution::context

ongchi committed Oct 17, 2023
1 parent 60c0190 commit 0232174

Showing 5 changed files with 394 additions and 311 deletions.
83 changes: 83 additions & 0 deletions datafusion/core/src/execution/context/avro.rs
@@ -0,0 +1,83 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use super::super::options::{AvroReadOptions, ReadOptions};
use super::{DataFilePaths, DataFrame, Result, SessionContext};

impl SessionContext {
    /// Creates a [`DataFrame`] for reading an Avro data source.
    ///
    /// For more control, such as reading multiple files, you can use
    /// [`read_table`](Self::read_table) with a [`super::ListingTable`].
    ///
    /// For an example, see [`read_csv`](Self::read_csv).
    pub async fn read_avro<P: DataFilePaths>(
        &self,
        table_paths: P,
        options: AvroReadOptions<'_>,
    ) -> Result<DataFrame> {
        self._read_type(table_paths, options).await
    }

    /// Registers an Avro file as a table that can be referenced from
    /// SQL statements executed against this context.
    pub async fn register_avro(
        &self,
        name: &str,
        table_path: &str,
        options: AvroReadOptions<'_>,
    ) -> Result<()> {
        let listing_options = options.to_listing_options(&self.copied_config());

        self.register_listing_table(
            name,
            table_path,
            listing_options,
            options.schema.map(|s| Arc::new(s.to_owned())),
            None,
        )
        .await?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    // Test for compilation error when calling read_* functions from an
    // #[async_trait] function.
    // See https://github.com/apache/arrow-datafusion/issues/1154
    #[async_trait]
    trait CallReadTrait {
        async fn call_read_avro(&self) -> DataFrame;
    }

    struct CallRead {}

    #[async_trait]
    impl CallReadTrait for CallRead {
        async fn call_read_avro(&self) -> DataFrame {
            let ctx = SessionContext::new();
            ctx.read_avro("dummy", AvroReadOptions::default())
                .await
                .unwrap()
        }
    }
}
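
A minimal usage sketch of the two methods above (an illustration, not part of this commit): it assumes a hypothetical file `data/example.avro` and table name `events`, and that DataFusion is built with the `avro` feature.

```
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // Read a single Avro file into a DataFrame ("data/example.avro" is a placeholder).
    let df = ctx
        .read_avro("data/example.avro", AvroReadOptions::default())
        .await?;
    df.show().await?;

    // Or register it as a named table and query it with SQL.
    ctx.register_avro("events", "data/example.avro", AvroReadOptions::default())
        .await?;
    ctx.sql("SELECT COUNT(*) FROM events").await?.show().await?;
    Ok(())
}
```
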
143 changes: 143 additions & 0 deletions datafusion/core/src/execution/context/csv.rs
@@ -0,0 +1,143 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use crate::datasource::physical_plan::plan_to_csv;

use super::super::options::{CsvReadOptions, ReadOptions};
use super::{DataFilePaths, DataFrame, ExecutionPlan, Result, SessionContext};

impl SessionContext {
    /// Creates a [`DataFrame`] for reading a CSV data source.
    ///
    /// For more control, such as reading multiple files, you can use
    /// [`read_table`](Self::read_table) with a [`super::ListingTable`].
    ///
    /// Example usage is given below:
    ///
    /// ```
    /// use datafusion::prelude::*;
    /// # use datafusion::error::Result;
    /// # #[tokio::main]
    /// # async fn main() -> Result<()> {
    /// let ctx = SessionContext::new();
    /// // You can read a single file using `read_csv`
    /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?;
    /// // You can also read multiple files:
    /// let df = ctx.read_csv(vec!["tests/data/example.csv", "tests/data/example.csv"], CsvReadOptions::new()).await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn read_csv<P: DataFilePaths>(
        &self,
        table_paths: P,
        options: CsvReadOptions<'_>,
    ) -> Result<DataFrame> {
        self._read_type(table_paths, options).await
    }

    /// Registers a CSV file as a table that can be referenced from SQL
    /// statements executed against this context.
    pub async fn register_csv(
        &self,
        name: &str,
        table_path: &str,
        options: CsvReadOptions<'_>,
    ) -> Result<()> {
        let listing_options = options.to_listing_options(&self.copied_config());

        self.register_listing_table(
            name,
            table_path,
            listing_options,
            options.schema.map(|s| Arc::new(s.to_owned())),
            None,
        )
        .await?;

        Ok(())
    }

    /// Executes a query and writes the results to a partitioned CSV file.
    pub async fn write_csv(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        path: impl AsRef<str>,
    ) -> Result<()> {
        plan_to_csv(self.task_ctx(), plan, path).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_batches_eq;
    use crate::test_util::{plan_and_collect, populate_csv_partitions};
    use async_trait::async_trait;
    use tempfile::TempDir;

    #[tokio::test]
    async fn query_csv_with_custom_partition_extension() -> Result<()> {
        let tmp_dir = TempDir::new()?;

        // The main stipulation of this test: use a file extension that isn't .csv.
        let file_extension = ".tst";

        let ctx = SessionContext::new();
        let schema = populate_csv_partitions(&tmp_dir, 2, file_extension)?;
        ctx.register_csv(
            "test",
            tmp_dir.path().to_str().unwrap(),
            CsvReadOptions::new()
                .schema(&schema)
                .file_extension(file_extension),
        )
        .await?;
        let results =
            plan_and_collect(&ctx, "SELECT SUM(c1), SUM(c2), COUNT(*) FROM test").await?;

        assert_eq!(results.len(), 1);
        let expected = [
            "+--------------+--------------+----------+",
            "| SUM(test.c1) | SUM(test.c2) | COUNT(*) |",
            "+--------------+--------------+----------+",
            "| 10           | 110          | 20       |",
            "+--------------+--------------+----------+",
        ];
        assert_batches_eq!(expected, &results);

        Ok(())
    }

    // Test for compilation error when calling read_* functions from an
    // #[async_trait] function.
    // See https://github.com/apache/arrow-datafusion/issues/1154
    #[async_trait]
    trait CallReadTrait {
        async fn call_read_csv(&self) -> DataFrame;
    }

    struct CallRead {}

    #[async_trait]
    impl CallReadTrait for CallRead {
        async fn call_read_csv(&self) -> DataFrame {
            let ctx = SessionContext::new();
            ctx.read_csv("dummy", CsvReadOptions::new()).await.unwrap()
        }
    }
}
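
The write path above (`write_csv`) takes a physical plan rather than a `DataFrame`, so a query must be lowered with `create_physical_plan` first. A hedged sketch, not part of this commit; the query and the output directory `/tmp/csv_out` are made-up placeholders:

```
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new())
        .await?;

    // write_csv consumes an ExecutionPlan, so build the physical plan
    // for the query explicitly before handing it off.
    let df = ctx.sql("SELECT a, b FROM example WHERE a > 1").await?;
    let plan = df.create_physical_plan().await?;
    ctx.write_csv(plan, "/tmp/csv_out").await?;
    Ok(())
}
```
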
69 changes: 69 additions & 0 deletions datafusion/core/src/execution/context/json.rs
@@ -0,0 +1,69 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use crate::datasource::physical_plan::plan_to_json;

use super::super::options::{NdJsonReadOptions, ReadOptions};
use super::{DataFilePaths, DataFrame, ExecutionPlan, Result, SessionContext};

impl SessionContext {
    /// Creates a [`DataFrame`] for reading a JSON data source.
    ///
    /// For more control, such as reading multiple files, you can use
    /// [`read_table`](Self::read_table) with a [`super::ListingTable`].
    ///
    /// For an example, see [`read_csv`](Self::read_csv).
    pub async fn read_json<P: DataFilePaths>(
        &self,
        table_paths: P,
        options: NdJsonReadOptions<'_>,
    ) -> Result<DataFrame> {
        self._read_type(table_paths, options).await
    }

    /// Registers a JSON file as a table that can be referenced from
    /// SQL statements executed against this context.
    pub async fn register_json(
        &self,
        name: &str,
        table_path: &str,
        options: NdJsonReadOptions<'_>,
    ) -> Result<()> {
        let listing_options = options.to_listing_options(&self.copied_config());

        self.register_listing_table(
            name,
            table_path,
            listing_options,
            options.schema.map(|s| Arc::new(s.to_owned())),
            None,
        )
        .await?;
        Ok(())
    }

    /// Executes a query and writes the results to a partitioned JSON file.
    pub async fn write_json(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        path: impl AsRef<str>,
    ) -> Result<()> {
        plan_to_json(self.task_ctx(), plan, path).await
    }
}
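
As with Avro, a minimal read-path sketch (illustrative, not part of this commit; `data/example.ndjson` and the `logs` table name are hypothetical):

```
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // Read newline-delimited JSON directly into a DataFrame.
    let df = ctx
        .read_json("data/example.ndjson", NdJsonReadOptions::default())
        .await?;
    df.show().await?;

    // Or register it and query it via SQL.
    ctx.register_json("logs", "data/example.ndjson", NdJsonReadOptions::default())
        .await?;
    ctx.sql("SELECT COUNT(*) FROM logs").await?.show().await?;
    Ok(())
}
```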
(Diffs for the remaining two changed files did not load and are not shown.)
