Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[rust, python]: add float_precision parameter to DataFrame.write_csv #4504

Merged
merged 2 commits into from
Aug 19, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion polars/polars-io/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ arrow = { package = "arrow2", git = "https://github.com/jorgecarleitao/arrow2",
csv-core = { version = "0.1.10", optional = true }
dirs = "4.0"
flate2 = { version = "1", optional = true, default-features = false }
lexical = { version = "6", optional = true, default-features = false, features = ["std", "parse-floats", "parse-integers"] }
lexical = { version = "6", optional = true, default-features = false, features = ["std", "parse-floats", "parse-integers", "write-floats"] }
matteosantama marked this conversation as resolved.
Show resolved Hide resolved
lexical-core = { version = "0.8", optional = true }
memchr = "2.4"
memmap = { package = "memmap2", version = "0.5.2", optional = true }
Expand All @@ -55,6 +55,7 @@ polars-arrow = { version = "0.23.0", path = "../polars-arrow" }
polars-core = { version = "0.23.0", path = "../polars-core", features = ["private"], default-features = false }
polars-time = { version = "0.23.0", path = "../polars-time", features = ["private"], default-features = false, optional = true }
polars-utils = { version = "0.23.0", path = "../polars-utils" }
rand = "0.8.5"
rayon = "1.5"
regex = "1.5"
serde = { version = "1", features = ["derive"], optional = true }
Expand Down
31 changes: 31 additions & 0 deletions polars/polars-io/src/csv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,34 @@ use crate::mmap::MmapBytesReader;
use crate::predicates::PhysicalIoExpr;
use crate::utils::resolve_homedir;
use crate::{RowCount, SerReader, SerWriter};
matteosantama marked this conversation as resolved.
Show resolved Hide resolved

#[cfg(test)]
mod tests {
    extern crate test;

    use polars_core::prelude::*;
    use rand::prelude::*;
    use test::Bencher;

    use crate::csv::CsvWriter;
    use crate::SerWriter;

    /// Benchmarks writing a single `f32` column of `N` random values to an
    /// in-memory CSV buffer.
    ///
    /// NOTE: This benchmark can be run by executing ``$ cargo bench -p polars-io``
    /// from within /polars/polars/
    #[bench]
    fn benchmark_write_csv_f32(b: &mut Bencher) -> Result<()> {
        const N: usize = 10_000_000;

        // Acquire the thread-local generator once instead of once per element.
        let mut rng = thread_rng();
        let vec: Vec<f32> = (0..N).map(|_| rng.next_u32() as f32).collect();

        let mut df = df![
            "random" => vec.as_slice(),
        ]?;

        let mut buffer: Vec<u8> = Vec::new();

        b.iter(|| {
            // Discard the previous iteration's output (the capacity is kept) so
            // each iteration measures exactly one CSV write instead of appending
            // to an ever-growing buffer.
            buffer.clear();
            CsvWriter::new(&mut buffer).finish(&mut df)
        });

        Ok(())
    }
}
10 changes: 9 additions & 1 deletion polars/polars-io/src/csv/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,22 @@ where
self
}

/// Set the CSV file's timestamp format array
/// Set the CSV file's datetime format
pub fn with_datetime_format(mut self, format: Option<String>) -> Self {
    // Only an explicit `Some(..)` replaces the stored format; passing `None`
    // keeps whatever value the options already hold.
    if let Some(datetime_format) = format {
        self.options.datetime_format = Some(datetime_format);
    }
    self
}

/// Set the number of decimal places used when writing floats to the CSV file.
///
/// Passing `None` leaves the currently configured precision untouched.
pub fn with_float_precision(mut self, precision: Option<usize>) -> Self {
    if let Some(digits) = precision {
        self.options.float_precision = Some(digits);
    }
    self
}

/// Set the single byte character used for quoting
pub fn with_quoting_char(mut self, char: u8) -> Self {
self.options.quote = char;
Expand Down
50 changes: 24 additions & 26 deletions polars/polars-io/src/csv/write_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ fn fmt_and_escape_str(f: &mut Vec<u8>, v: &str, options: &SerializeOptions) -> s
}
}

/// Appends the `lexical` text representation of `n` to the end of `f`.
///
/// `write_size` is the scratch space handed to `to_lexical`; callers pass the
/// type's `FORMATTED_SIZE_DECIMAL` so the formatted number is guaranteed to fit.
/// Only the bytes actually produced remain in `f` afterwards.
///
/// Always returns `Ok(())`; the `io::Result` return type lets the call sit next
/// to `write!` arms in a `match`.
fn fast_float_write<N: ToLexical>(
    f: &mut Vec<u8>,
    n: N,
    write_size: usize,
) -> std::io::Result<()> {
    let len = f.len();
    // Zero-fill the scratch region so `to_lexical` receives a slice of
    // initialized bytes. Building a `&mut [u8]` over merely reserved capacity
    // (as `slice::from_raw_parts_mut` on uninitialized memory does) is
    // undefined behavior.
    f.resize(len + write_size, 0);
    let written_n = n.to_lexical(&mut f[len..]).len();
    // Drop the unused tail of the scratch region.
    f.truncate(len + written_n);
    Ok(())
}

fn write_anyvalue(f: &mut Vec<u8>, value: AnyValue, options: &SerializeOptions) {
match value {
AnyValue::Null => write!(f, ""),
Expand All @@ -45,30 +56,14 @@ fn write_anyvalue(f: &mut Vec<u8>, value: AnyValue, options: &SerializeOptions)
AnyValue::UInt16(v) => write!(f, "{}", v),
AnyValue::UInt32(v) => write!(f, "{}", v),
AnyValue::UInt64(v) => write!(f, "{}", v),
AnyValue::Float32(v) => {
let len = f.len();
let write_size = f32::FORMATTED_SIZE_DECIMAL;
f.reserve(write_size);
unsafe {
let buf = std::slice::from_raw_parts_mut(f.as_mut_ptr().add(len), write_size);

let written_n = v.to_lexical(buf).len();
f.set_len(len + written_n);
}
Ok(())
}
AnyValue::Float64(v) => {
let len = f.len();
let write_size = f64::FORMATTED_SIZE_DECIMAL;
f.reserve(write_size);
unsafe {
let buf = std::slice::from_raw_parts_mut(f.as_mut_ptr().add(len), write_size);

let written_n = v.to_lexical(buf).len();
f.set_len(len + written_n);
}
Ok(())
}
AnyValue::Float32(v) => match &options.float_precision {
None => fast_float_write(f, v, f32::FORMATTED_SIZE_DECIMAL),
Some(precision) => write!(f, "{v:.precision$}", v = v, precision = precision),
},
AnyValue::Float64(v) => match &options.float_precision {
None => fast_float_write(f, v, f64::FORMATTED_SIZE_DECIMAL),
Some(precision) => write!(f, "{v:.precision$}", v = v, precision = precision),
},
AnyValue::Boolean(v) => write!(f, "{}", v),
AnyValue::Utf8(v) => fmt_and_escape_str(f, v, options),
#[cfg(feature = "dtype-categorical")]
Expand Down Expand Up @@ -126,10 +121,12 @@ fn write_anyvalue(f: &mut Vec<u8>, value: AnyValue, options: &SerializeOptions)
pub struct SerializeOptions {
/// used for [`DataType::Date`]
pub date_format: Option<String>,
/// used for [`DataType::Time64`]
/// used for [`DataType::Time`]
pub time_format: Option<String>,
/// used for [`DataType::Timestamp`]
/// used for [`DataType::Datetime`]
pub datetime_format: Option<String>,
/// used for [`DataType::Float64`] and [`DataType::Float32`]
pub float_precision: Option<usize>,
/// used as separator/delimiter
pub delimiter: u8,
/// quoting character
Expand All @@ -142,6 +139,7 @@ impl Default for SerializeOptions {
date_format: None,
time_format: None,
datetime_format: None,
float_precision: None,
delimiter: b',',
quote: b'"',
}
Expand Down
1 change: 1 addition & 0 deletions polars/polars-io/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// The unstable `test` crate is only used by `#[cfg(test)]` benchmark modules, so
// enable the feature gate for test builds only and keep regular builds of the
// library free of this nightly-only attribute.
#![cfg_attr(test, feature(test))]
matteosantama marked this conversation as resolved.
Show resolved Hide resolved
#![cfg_attr(docsrs, feature(doc_cfg))]

#[cfg(feature = "private")]
Expand Down
1 change: 1 addition & 0 deletions py-polars/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions py-polars/polars/internals/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,6 +1089,7 @@ def write_csv(
datetime_format: str | None = None,
date_format: str | None = None,
time_format: str | None = None,
float_precision: int | None = None,
) -> str | None:
"""
Write to comma-separated values (CSV) file.
Expand Down Expand Up @@ -1117,6 +1118,9 @@ def write_csv(
A format string, with the specifiers defined by the
`chrono <https://docs.rs/chrono/latest/chrono/format/strftime/index.html>`_
Rust crate.
float_precision
Number of decimal places to write, applied to both ``Float32`` and
``Float64`` datatypes.

Examples
--------
Expand Down Expand Up @@ -1148,6 +1152,7 @@ def write_csv(
datetime_format,
date_format,
time_format,
float_precision,
)
return str(buffer.getvalue(), encoding="utf-8")

Expand All @@ -1163,6 +1168,7 @@ def write_csv(
datetime_format,
date_format,
time_format,
float_precision,
)
return None

Expand Down
3 changes: 3 additions & 0 deletions py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ impl PyDataFrame {
datetime_format: Option<String>,
date_format: Option<String>,
time_format: Option<String>,
float_precision: Option<usize>,
) -> PyResult<()> {
if let Ok(s) = py_f.extract::<&str>(py) {
let f = std::fs::File::create(s).unwrap();
Expand All @@ -442,6 +443,7 @@ impl PyDataFrame {
.with_datetime_format(datetime_format)
.with_date_format(date_format)
.with_time_format(time_format)
.with_float_precision(float_precision)
.finish(&mut self.df)
.map_err(PyPolarsErr::from)?;
} else {
Expand All @@ -454,6 +456,7 @@ impl PyDataFrame {
.with_datetime_format(datetime_format)
.with_date_format(date_format)
.with_time_format(time_format)
.with_float_precision(float_precision)
.finish(&mut self.df)
.map_err(PyPolarsErr::from)?;
}
Expand Down
11 changes: 11 additions & 0 deletions py-polars/tests/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,3 +730,14 @@ def test_time_format(fmt: str, expected: str) -> None:
df = pl.DataFrame({"dt": [time(16, 15, 30)]})
csv = df.write_csv(time_format=fmt)
assert csv == expected


@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64])
def test_float_precision(dtype: pl.Float32 | pl.Float64) -> None:
    """Check the CSV text produced by ``write_csv`` for each ``float_precision``."""
    frame = pl.Series("col", [1.0, 2.2, 3.33], dtype=dtype).to_frame()

    # Requested precision -> exact CSV output expected for that precision.
    expected_by_precision = {
        None: "col\n1.0\n2.2\n3.33\n",
        0: "col\n1\n2\n3\n",
        1: "col\n1.0\n2.2\n3.3\n",
        2: "col\n1.00\n2.20\n3.33\n",
        3: "col\n1.000\n2.200\n3.330\n",
    }

    for precision, expected in expected_by_precision.items():
        assert frame.write_csv(float_precision=precision) == expected