Skip to content

Commit

Permalink
Merge 'External functions' from Pekka Enberg
Browse files Browse the repository at this point in the history
This pull request adds support for external functions, which are
functions provided by extensions. The main difference to how things are
today is that extensions can register functions to a symbol table at
runtime instead of specifying an enum variant.

Reviewed-by: Pere Diaz Bou <[email protected]>

Closes #567
  • Loading branch information
penberg committed Dec 31, 2024
2 parents 9b7b2f6 + 0aabcdd commit b1ca2b0
Show file tree
Hide file tree
Showing 11 changed files with 298 additions and 65 deletions.
5 changes: 5 additions & 0 deletions core/ext/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,8 @@ impl ExtFunc {
}
}
}

pub fn init(db: &mut crate::Database) {
#[cfg(feature = "uuid")]
uuid::init(db);
}
28 changes: 15 additions & 13 deletions core/ext/uuid.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
use super::ExtFunc;
use crate::{
types::{LimboText, OwnedValue},
LimboError,
Database, LimboError,
};
use std::rc::Rc;
use uuid::{ContextV7, Timestamp, Uuid};

#[derive(Debug, Clone, Copy, PartialEq)]
pub enum UuidFunc {
Uuid4Str,
Uuid4,
Uuid7,
Uuid7TS,
UuidStr,
Expand All @@ -20,7 +19,6 @@ impl UuidFunc {
pub fn resolve_function(name: &str, num_args: usize) -> Option<ExtFunc> {
match name {
"uuid4_str" => Some(ExtFunc::Uuid(Self::Uuid4Str)),
"uuid4" => Some(ExtFunc::Uuid(Self::Uuid4)),
"uuid7" if num_args < 2 => Some(ExtFunc::Uuid(Self::Uuid7)),
"uuid_str" if num_args == 1 => Some(ExtFunc::Uuid(Self::UuidStr)),
"uuid_blob" if num_args == 1 => Some(ExtFunc::Uuid(Self::UuidBlob)),
Expand All @@ -36,7 +34,6 @@ impl std::fmt::Display for UuidFunc {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Uuid4Str => write!(f, "uuid4_str"),
Self::Uuid4 => write!(f, "uuid4"),
Self::Uuid7 => write!(f, "uuid7"),
Self::Uuid7TS => write!(f, "uuid7_timestamp_ms"),
Self::UuidStr => write!(f, "uuid_str"),
Expand All @@ -47,9 +44,6 @@ impl std::fmt::Display for UuidFunc {

pub fn exec_uuid(var: &UuidFunc, sec: Option<&OwnedValue>) -> crate::Result<OwnedValue> {
match var {
UuidFunc::Uuid4 => Ok(OwnedValue::Blob(Rc::new(
Uuid::new_v4().into_bytes().to_vec(),
))),
UuidFunc::Uuid4Str => Ok(OwnedValue::Text(LimboText::new(Rc::new(
Uuid::new_v4().to_string(),
)))),
Expand All @@ -71,6 +65,12 @@ pub fn exec_uuid(var: &UuidFunc, sec: Option<&OwnedValue>) -> crate::Result<Owne
}
}

pub fn exec_uuid4() -> crate::Result<OwnedValue> {
Ok(OwnedValue::Blob(Rc::new(
Uuid::new_v4().into_bytes().to_vec(),
)))
}

pub fn exec_uuidstr(reg: &OwnedValue) -> crate::Result<OwnedValue> {
match reg {
OwnedValue::Blob(blob) => {
Expand Down Expand Up @@ -136,17 +136,20 @@ fn uuid_to_unix(uuid: &[u8; 16]) -> u64 {
| (uuid[5] as u64)
}

pub fn init(db: &mut Database) {
db.define_scalar_function("uuid4", |_args| exec_uuid4());
}

#[cfg(test)]
#[cfg(feature = "uuid")]
pub mod test {
use super::UuidFunc;
use crate::types::OwnedValue;
#[test]
fn test_exec_uuid_v4blob() {
use super::exec_uuid;
use super::exec_uuid4;
use uuid::Uuid;
let func = UuidFunc::Uuid4;
let owned_val = exec_uuid(&func, None);
let owned_val = exec_uuid4();
match owned_val {
Ok(OwnedValue::Blob(blob)) => {
assert_eq!(blob.len(), 16);
Expand Down Expand Up @@ -303,11 +306,10 @@ pub mod test {

#[test]
fn test_exec_uuid_v4_blob_to_str() {
use super::{exec_uuid, exec_uuidstr, UuidFunc};
use super::{exec_uuid4, exec_uuidstr};
use uuid::Uuid;
// convert a v4 blob to a string
let owned_val =
exec_uuidstr(&exec_uuid(&UuidFunc::Uuid4, None).expect("uuid v7 blob to generate"));
let owned_val = exec_uuidstr(&exec_uuid4().expect("uuid v7 blob to generate"));
match owned_val {
Ok(OwnedValue::Text(v4str)) => {
assert_eq!(v4str.value.len(), 36);
Expand Down
25 changes: 23 additions & 2 deletions core/function.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,25 @@
use crate::ext::ExtFunc;
use std::fmt;
use std::fmt::Display;
use std::fmt::{Debug, Display};
use std::rc::Rc;

pub struct ExternalFunc {
pub name: String,
pub func: Box<dyn Fn(&[crate::types::Value]) -> crate::Result<crate::types::OwnedValue>>,
}

impl Debug for ExternalFunc {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.name)
}
}

impl Display for ExternalFunc {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.name)
}
}

#[cfg(feature = "json")]
#[derive(Debug, Clone, PartialEq)]
pub enum JsonFunc {
Expand Down Expand Up @@ -255,14 +274,15 @@ impl Display for MathFunc {
}
}

#[derive(Debug, Clone, PartialEq)]
#[derive(Debug)]
pub enum Func {
Agg(AggFunc),
Scalar(ScalarFunc),
Math(MathFunc),
#[cfg(feature = "json")]
Json(JsonFunc),
Extension(ExtFunc),
External(Rc<ExternalFunc>),
}

impl Display for Func {
Expand All @@ -274,6 +294,7 @@ impl Display for Func {
#[cfg(feature = "json")]
Self::Json(json_func) => write!(f, "{}", json_func),
Self::Extension(ext_func) => write!(f, "{}", ext_func),
Self::External(generic_func) => write!(f, "{}", generic_func),
}
}
}
Expand Down
88 changes: 70 additions & 18 deletions core/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use schema::Schema;
use sqlite3_parser::ast;
use sqlite3_parser::{ast::Cmd, lexer::sql::Parser};
use std::cell::Cell;
use std::sync::Weak;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock, RwLock};
use std::{cell::RefCell, rc::Rc};
use storage::btree::btree_init_page;
Expand All @@ -37,6 +37,7 @@ pub use storage::wal::WalFileShared;
use util::parse_schema_rows;

use translate::select::prepare_select_plan;
use types::OwnedValue;

pub use error::LimboError;
pub type Result<T> = std::result::Result<T, error::LimboError>;
Expand Down Expand Up @@ -67,6 +68,7 @@ pub struct Database {
pager: Rc<Pager>,
schema: Rc<RefCell<Schema>>,
header: Rc<RefCell<DatabaseHeader>>,
syms: Rc<RefCell<SymbolTable>>,
// Shared structures of a Database are the parts that are common to multiple threads that might
// create DB connections.
_shared_page_cache: Arc<RwLock<DumbLruPageCache>>,
Expand Down Expand Up @@ -119,39 +121,58 @@ impl Database {
_shared_page_cache.clone(),
buffer_pool,
)?);
let bootstrap_schema = Rc::new(RefCell::new(Schema::new()));
let conn = Rc::new(Connection {
let header = db_header;
let schema = Rc::new(RefCell::new(Schema::new()));
let syms = Rc::new(RefCell::new(SymbolTable::new()));
let mut db = Database {
pager: pager.clone(),
schema: bootstrap_schema.clone(),
header: db_header.clone(),
schema: schema.clone(),
header: header.clone(),
_shared_page_cache: _shared_page_cache.clone(),
_shared_wal: shared_wal.clone(),
syms,
};
ext::init(&mut db);
let db = Arc::new(db);
let conn = Rc::new(Connection {
db: db.clone(),
pager: pager,
schema: schema.clone(),
header,
transaction_state: RefCell::new(TransactionState::None),
_db: Weak::new(),
last_insert_rowid: Cell::new(0),
});
let mut schema = Schema::new();
let rows = conn.query("SELECT * FROM sqlite_schema")?;
let mut schema = schema.borrow_mut();
parse_schema_rows(rows, &mut schema, io)?;
let schema = Rc::new(RefCell::new(schema));
let header = db_header;
Ok(Arc::new(Database {
pager,
schema,
header,
_shared_page_cache,
_shared_wal: shared_wal,
}))
Ok(db)
}

pub fn connect(self: &Arc<Database>) -> Rc<Connection> {
Rc::new(Connection {
db: self.clone(),
pager: self.pager.clone(),
schema: self.schema.clone(),
header: self.header.clone(),
last_insert_rowid: Cell::new(0),
_db: Arc::downgrade(self),
transaction_state: RefCell::new(TransactionState::None),
})
}

pub fn define_scalar_function<S: AsRef<str>>(
&self,
name: S,
func: impl Fn(&[Value]) -> Result<OwnedValue> + 'static,
) {
let func = function::ExternalFunc {
name: name.as_ref().to_string(),
func: Box::new(func),
};
self.syms
.borrow_mut()
.functions
.insert(name.as_ref().to_string(), Rc::new(func));
}
}

pub fn maybe_init_database_file(file: &Rc<dyn File>, io: &Arc<dyn IO>) -> Result<()> {
Expand Down Expand Up @@ -204,10 +225,10 @@ pub fn maybe_init_database_file(file: &Rc<dyn File>, io: &Arc<dyn IO>) -> Result
}

pub struct Connection {
db: Arc<Database>,
pager: Rc<Pager>,
schema: Rc<RefCell<Schema>>,
header: Rc<RefCell<DatabaseHeader>>,
_db: Weak<Database>, // backpointer to the database holding this connection
transaction_state: RefCell<TransactionState>,
last_insert_rowid: Cell<u64>,
}
Expand All @@ -216,6 +237,8 @@ impl Connection {
pub fn prepare(self: &Rc<Connection>, sql: impl Into<String>) -> Result<Statement> {
let sql = sql.into();
trace!("Preparing: {}", sql);
let db = self.db.clone();
let syms: &SymbolTable = &db.syms.borrow();
let mut parser = Parser::new(sql.as_bytes());
let cmd = parser.next()?;
if let Some(cmd) = cmd {
Expand All @@ -227,6 +250,7 @@ impl Connection {
self.header.clone(),
self.pager.clone(),
Rc::downgrade(self),
&syms,
)?);
Ok(Statement::new(program, self.pager.clone()))
}
Expand All @@ -241,6 +265,8 @@ impl Connection {
pub fn query(self: &Rc<Connection>, sql: impl Into<String>) -> Result<Option<Rows>> {
let sql = sql.into();
trace!("Querying: {}", sql);
let db = self.db.clone();
let syms: &SymbolTable = &db.syms.borrow();
let mut parser = Parser::new(sql.as_bytes());
let cmd = parser.next()?;
if let Some(cmd) = cmd {
Expand All @@ -252,6 +278,7 @@ impl Connection {
self.header.clone(),
self.pager.clone(),
Rc::downgrade(self),
&syms,
)?);
let stmt = Statement::new(program, self.pager.clone());
Ok(Some(Rows { stmt }))
Expand All @@ -263,6 +290,7 @@ impl Connection {
self.header.clone(),
self.pager.clone(),
Rc::downgrade(self),
&syms,
)?;
program.explain();
Ok(None)
Expand All @@ -286,6 +314,8 @@ impl Connection {

pub fn execute(self: &Rc<Connection>, sql: impl Into<String>) -> Result<()> {
let sql = sql.into();
let db = self.db.clone();
let syms: &SymbolTable = &db.syms.borrow();
let mut parser = Parser::new(sql.as_bytes());
let cmd = parser.next()?;
if let Some(cmd) = cmd {
Expand All @@ -297,6 +327,7 @@ impl Connection {
self.header.clone(),
self.pager.clone(),
Rc::downgrade(self),
&syms,
)?;
program.explain();
}
Expand All @@ -308,6 +339,7 @@ impl Connection {
self.header.clone(),
self.pager.clone(),
Rc::downgrade(self),
&syms,
)?;
let mut state = vdbe::ProgramState::new(program.max_registers);
program.step(&mut state, self.pager.clone())?;
Expand Down Expand Up @@ -426,3 +458,23 @@ impl Rows {
self.stmt.step()
}
}

pub(crate) struct SymbolTable {
pub functions: HashMap<String, Rc<crate::function::ExternalFunc>>,
}

impl SymbolTable {
pub fn new() -> Self {
Self {
functions: HashMap::new(),
}
}

pub fn resolve_function(
&self,
name: &str,
_arg_count: usize,
) -> Option<Rc<crate::function::ExternalFunc>> {
self.functions.get(name).cloned()
}
}
5 changes: 3 additions & 2 deletions core/translate/delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::translate::optimizer::optimize_plan;
use crate::translate::plan::{BTreeTableReference, DeletePlan, Plan, SourceOperator};
use crate::translate::planner::{parse_limit, parse_where};
use crate::{schema::Schema, storage::sqlite3_ondisk::DatabaseHeader, vdbe::Program};
use crate::{Connection, Result};
use crate::{Connection, Result, SymbolTable};
use sqlite3_parser::ast::{Expr, Limit, QualifiedName};
use std::rc::Weak;
use std::{cell::RefCell, rc::Rc};
Expand All @@ -15,10 +15,11 @@ pub fn translate_delete(
limit: Option<Limit>,
database_header: Rc<RefCell<DatabaseHeader>>,
connection: Weak<Connection>,
syms: &SymbolTable,
) -> Result<Program> {
let delete_plan = prepare_delete_plan(schema, tbl_name, where_clause, limit)?;
let optimized_plan = optimize_plan(delete_plan)?;
emit_program(database_header, optimized_plan, connection)
emit_program(database_header, optimized_plan, connection, syms)
}

pub fn prepare_delete_plan(
Expand Down
Loading

0 comments on commit b1ca2b0

Please sign in to comment.