Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

External functions #567

Merged
merged 4 commits into from
Dec 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions core/ext/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,8 @@ impl ExtFunc {
}
}
}

pub fn init(db: &mut crate::Database) {
jussisaurio marked this conversation as resolved.
Show resolved Hide resolved
#[cfg(feature = "uuid")]
uuid::init(db);
}
28 changes: 15 additions & 13 deletions core/ext/uuid.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
use super::ExtFunc;
use crate::{
types::{LimboText, OwnedValue},
LimboError,
Database, LimboError,
};
use std::rc::Rc;
use uuid::{ContextV7, Timestamp, Uuid};

#[derive(Debug, Clone, Copy, PartialEq)]
pub enum UuidFunc {
Uuid4Str,
Uuid4,
Uuid7,
Uuid7TS,
UuidStr,
Expand All @@ -20,7 +19,6 @@ impl UuidFunc {
pub fn resolve_function(name: &str, num_args: usize) -> Option<ExtFunc> {
match name {
"uuid4_str" => Some(ExtFunc::Uuid(Self::Uuid4Str)),
"uuid4" => Some(ExtFunc::Uuid(Self::Uuid4)),
"uuid7" if num_args < 2 => Some(ExtFunc::Uuid(Self::Uuid7)),
"uuid_str" if num_args == 1 => Some(ExtFunc::Uuid(Self::UuidStr)),
"uuid_blob" if num_args == 1 => Some(ExtFunc::Uuid(Self::UuidBlob)),
Expand All @@ -36,7 +34,6 @@ impl std::fmt::Display for UuidFunc {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Uuid4Str => write!(f, "uuid4_str"),
Self::Uuid4 => write!(f, "uuid4"),
Self::Uuid7 => write!(f, "uuid7"),
Self::Uuid7TS => write!(f, "uuid7_timestamp_ms"),
Self::UuidStr => write!(f, "uuid_str"),
Expand All @@ -47,9 +44,6 @@ impl std::fmt::Display for UuidFunc {

pub fn exec_uuid(var: &UuidFunc, sec: Option<&OwnedValue>) -> crate::Result<OwnedValue> {
match var {
UuidFunc::Uuid4 => Ok(OwnedValue::Blob(Rc::new(
Uuid::new_v4().into_bytes().to_vec(),
))),
UuidFunc::Uuid4Str => Ok(OwnedValue::Text(LimboText::new(Rc::new(
Uuid::new_v4().to_string(),
)))),
Expand All @@ -71,6 +65,12 @@ pub fn exec_uuid(var: &UuidFunc, sec: Option<&OwnedValue>) -> crate::Result<Owne
}
}

pub fn exec_uuid4() -> crate::Result<OwnedValue> {
Ok(OwnedValue::Blob(Rc::new(
Uuid::new_v4().into_bytes().to_vec(),
)))
}

pub fn exec_uuidstr(reg: &OwnedValue) -> crate::Result<OwnedValue> {
match reg {
OwnedValue::Blob(blob) => {
Expand Down Expand Up @@ -136,17 +136,20 @@ fn uuid_to_unix(uuid: &[u8; 16]) -> u64 {
| (uuid[5] as u64)
}

pub fn init(db: &mut Database) {
db.define_scalar_function("uuid4", |_args| exec_uuid4());
}

#[cfg(test)]
#[cfg(feature = "uuid")]
pub mod test {
use super::UuidFunc;
use crate::types::OwnedValue;
#[test]
fn test_exec_uuid_v4blob() {
use super::exec_uuid;
use super::exec_uuid4;
use uuid::Uuid;
let func = UuidFunc::Uuid4;
let owned_val = exec_uuid(&func, None);
let owned_val = exec_uuid4();
match owned_val {
Ok(OwnedValue::Blob(blob)) => {
assert_eq!(blob.len(), 16);
Expand Down Expand Up @@ -303,11 +306,10 @@ pub mod test {

#[test]
fn test_exec_uuid_v4_blob_to_str() {
use super::{exec_uuid, exec_uuidstr, UuidFunc};
use super::{exec_uuid4, exec_uuidstr};
use uuid::Uuid;
// convert a v4 blob to a string
let owned_val =
exec_uuidstr(&exec_uuid(&UuidFunc::Uuid4, None).expect("uuid v7 blob to generate"));
let owned_val = exec_uuidstr(&exec_uuid4().expect("uuid v7 blob to generate"));
match owned_val {
Ok(OwnedValue::Text(v4str)) => {
assert_eq!(v4str.value.len(), 36);
Expand Down
25 changes: 23 additions & 2 deletions core/function.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,25 @@
use crate::ext::ExtFunc;
use std::fmt;
use std::fmt::Display;
use std::fmt::{Debug, Display};
use std::rc::Rc;

pub struct ExternalFunc {
pub name: String,
pub func: Box<dyn Fn(&[crate::types::Value]) -> crate::Result<crate::types::OwnedValue>>,
}

impl Debug for ExternalFunc {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.name)
}
}

impl Display for ExternalFunc {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.name)
}
}

#[cfg(feature = "json")]
#[derive(Debug, Clone, PartialEq)]
pub enum JsonFunc {
Expand Down Expand Up @@ -255,14 +274,15 @@ impl Display for MathFunc {
}
}

#[derive(Debug, Clone, PartialEq)]
#[derive(Debug)]
pub enum Func {
Agg(AggFunc),
Scalar(ScalarFunc),
Math(MathFunc),
#[cfg(feature = "json")]
Json(JsonFunc),
Extension(ExtFunc),
External(Rc<ExternalFunc>),
}

impl Display for Func {
Expand All @@ -274,6 +294,7 @@ impl Display for Func {
#[cfg(feature = "json")]
Self::Json(json_func) => write!(f, "{}", json_func),
Self::Extension(ext_func) => write!(f, "{}", ext_func),
Self::External(generic_func) => write!(f, "{}", generic_func),
}
}
}
Expand Down
88 changes: 70 additions & 18 deletions core/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use schema::Schema;
use sqlite3_parser::ast;
use sqlite3_parser::{ast::Cmd, lexer::sql::Parser};
use std::cell::Cell;
use std::sync::Weak;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock, RwLock};
use std::{cell::RefCell, rc::Rc};
use storage::btree::btree_init_page;
Expand All @@ -37,6 +37,7 @@ pub use storage::wal::WalFileShared;
use util::parse_schema_rows;

use translate::select::prepare_select_plan;
use types::OwnedValue;

pub use error::LimboError;
pub type Result<T> = std::result::Result<T, error::LimboError>;
Expand Down Expand Up @@ -67,6 +68,7 @@ pub struct Database {
pager: Rc<Pager>,
schema: Rc<RefCell<Schema>>,
header: Rc<RefCell<DatabaseHeader>>,
syms: Rc<RefCell<SymbolTable>>,
// Shared structures of a Database are the parts that are common to multiple threads that might
// create DB connections.
_shared_page_cache: Arc<RwLock<DumbLruPageCache>>,
Expand Down Expand Up @@ -119,39 +121,58 @@ impl Database {
_shared_page_cache.clone(),
buffer_pool,
)?);
let bootstrap_schema = Rc::new(RefCell::new(Schema::new()));
let conn = Rc::new(Connection {
let header = db_header;
let schema = Rc::new(RefCell::new(Schema::new()));
let syms = Rc::new(RefCell::new(SymbolTable::new()));
let mut db = Database {
pager: pager.clone(),
schema: bootstrap_schema.clone(),
header: db_header.clone(),
schema: schema.clone(),
header: header.clone(),
_shared_page_cache: _shared_page_cache.clone(),
_shared_wal: shared_wal.clone(),
syms,
};
ext::init(&mut db);
let db = Arc::new(db);
let conn = Rc::new(Connection {
db: db.clone(),
pager: pager,
schema: schema.clone(),
header,
transaction_state: RefCell::new(TransactionState::None),
_db: Weak::new(),
last_insert_rowid: Cell::new(0),
});
let mut schema = Schema::new();
let rows = conn.query("SELECT * FROM sqlite_schema")?;
let mut schema = schema.borrow_mut();
parse_schema_rows(rows, &mut schema, io)?;
let schema = Rc::new(RefCell::new(schema));
let header = db_header;
Ok(Arc::new(Database {
pager,
schema,
header,
_shared_page_cache,
_shared_wal: shared_wal,
}))
Ok(db)
}

pub fn connect(self: &Arc<Database>) -> Rc<Connection> {
Rc::new(Connection {
db: self.clone(),
pager: self.pager.clone(),
schema: self.schema.clone(),
header: self.header.clone(),
last_insert_rowid: Cell::new(0),
_db: Arc::downgrade(self),
transaction_state: RefCell::new(TransactionState::None),
})
}

pub fn define_scalar_function<S: AsRef<str>>(
&self,
name: S,
func: impl Fn(&[Value]) -> Result<OwnedValue> + 'static,
) {
let func = function::ExternalFunc {
name: name.as_ref().to_string(),
func: Box::new(func),
};
self.syms
.borrow_mut()
.functions
.insert(name.as_ref().to_string(), Rc::new(func));
}
}

pub fn maybe_init_database_file(file: &Rc<dyn File>, io: &Arc<dyn IO>) -> Result<()> {
Expand Down Expand Up @@ -204,10 +225,10 @@ pub fn maybe_init_database_file(file: &Rc<dyn File>, io: &Arc<dyn IO>) -> Result
}

pub struct Connection {
db: Arc<Database>,
pager: Rc<Pager>,
schema: Rc<RefCell<Schema>>,
header: Rc<RefCell<DatabaseHeader>>,
_db: Weak<Database>, // backpointer to the database holding this connection
transaction_state: RefCell<TransactionState>,
last_insert_rowid: Cell<u64>,
}
Expand All @@ -216,6 +237,8 @@ impl Connection {
pub fn prepare(self: &Rc<Connection>, sql: impl Into<String>) -> Result<Statement> {
let sql = sql.into();
trace!("Preparing: {}", sql);
let db = self.db.clone();
let syms: &SymbolTable = &db.syms.borrow();
let mut parser = Parser::new(sql.as_bytes());
let cmd = parser.next()?;
if let Some(cmd) = cmd {
Expand All @@ -227,6 +250,7 @@ impl Connection {
self.header.clone(),
self.pager.clone(),
Rc::downgrade(self),
&syms,
)?);
Ok(Statement::new(program, self.pager.clone()))
}
Expand All @@ -241,6 +265,8 @@ impl Connection {
pub fn query(self: &Rc<Connection>, sql: impl Into<String>) -> Result<Option<Rows>> {
let sql = sql.into();
trace!("Querying: {}", sql);
let db = self.db.clone();
let syms: &SymbolTable = &db.syms.borrow();
let mut parser = Parser::new(sql.as_bytes());
let cmd = parser.next()?;
if let Some(cmd) = cmd {
Expand All @@ -252,6 +278,7 @@ impl Connection {
self.header.clone(),
self.pager.clone(),
Rc::downgrade(self),
&syms,
)?);
let stmt = Statement::new(program, self.pager.clone());
Ok(Some(Rows { stmt }))
Expand All @@ -263,6 +290,7 @@ impl Connection {
self.header.clone(),
self.pager.clone(),
Rc::downgrade(self),
&syms,
)?;
program.explain();
Ok(None)
Expand All @@ -286,6 +314,8 @@ impl Connection {

pub fn execute(self: &Rc<Connection>, sql: impl Into<String>) -> Result<()> {
let sql = sql.into();
let db = self.db.clone();
let syms: &SymbolTable = &db.syms.borrow();
let mut parser = Parser::new(sql.as_bytes());
let cmd = parser.next()?;
if let Some(cmd) = cmd {
Expand All @@ -297,6 +327,7 @@ impl Connection {
self.header.clone(),
self.pager.clone(),
Rc::downgrade(self),
&syms,
)?;
program.explain();
}
Expand All @@ -308,6 +339,7 @@ impl Connection {
self.header.clone(),
self.pager.clone(),
Rc::downgrade(self),
&syms,
)?;
let mut state = vdbe::ProgramState::new(program.max_registers);
program.step(&mut state, self.pager.clone())?;
Expand Down Expand Up @@ -426,3 +458,23 @@ impl Rows {
self.stmt.step()
}
}

pub(crate) struct SymbolTable {
pub functions: HashMap<String, Rc<crate::function::ExternalFunc>>,
}

impl SymbolTable {
pub fn new() -> Self {
Self {
functions: HashMap::new(),
}
}

pub fn resolve_function(
&self,
name: &str,
_arg_count: usize,
) -> Option<Rc<crate::function::ExternalFunc>> {
self.functions.get(name).cloned()
}
}
5 changes: 3 additions & 2 deletions core/translate/delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::translate::optimizer::optimize_plan;
use crate::translate::plan::{BTreeTableReference, DeletePlan, Plan, SourceOperator};
use crate::translate::planner::{parse_limit, parse_where};
use crate::{schema::Schema, storage::sqlite3_ondisk::DatabaseHeader, vdbe::Program};
use crate::{Connection, Result};
use crate::{Connection, Result, SymbolTable};
use sqlite3_parser::ast::{Expr, Limit, QualifiedName};
use std::rc::Weak;
use std::{cell::RefCell, rc::Rc};
Expand All @@ -15,10 +15,11 @@ pub fn translate_delete(
limit: Option<Limit>,
database_header: Rc<RefCell<DatabaseHeader>>,
connection: Weak<Connection>,
syms: &SymbolTable,
) -> Result<Program> {
let delete_plan = prepare_delete_plan(schema, tbl_name, where_clause, limit)?;
let optimized_plan = optimize_plan(delete_plan)?;
emit_program(database_header, optimized_plan, connection)
emit_program(database_header, optimized_plan, connection, syms)
}

pub fn prepare_delete_plan(
Expand Down
Loading
Loading