From dca47f62ea77761225981d15469368eb73155b04 Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Sat, 28 Dec 2024 09:51:46 +0200 Subject: [PATCH 1/4] core: Don't use Weak reference for connection database The database object is a way to represent state that's shared across multiple connections. We don't want to release that object until all connections are closed. --- core/lib.rs | 39 +++++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/core/lib.rs b/core/lib.rs index d7fac66a4..7002c6803 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -23,7 +23,6 @@ use schema::Schema; use sqlite3_parser::ast; use sqlite3_parser::{ast::Cmd, lexer::sql::Parser}; use std::cell::Cell; -use std::sync::Weak; use std::sync::{Arc, OnceLock, RwLock}; use std::{cell::RefCell, rc::Rc}; use storage::btree::btree_init_page; @@ -119,27 +118,27 @@ impl Database { _shared_page_cache.clone(), buffer_pool, )?); - let bootstrap_schema = Rc::new(RefCell::new(Schema::new())); - let conn = Rc::new(Connection { + let header = db_header; + let schema = Rc::new(RefCell::new(Schema::new())); + let db = Arc::new(Database { pager: pager.clone(), - schema: bootstrap_schema.clone(), - header: db_header.clone(), + schema: schema.clone(), + header: header.clone(), + shared_page_cache, + shared_wal, + }); + let conn = Rc::new(Connection { + pager: pager, + schema: schema.clone(), + header, transaction_state: RefCell::new(TransactionState::None), - _db: Weak::new(), + db: db.clone(), last_insert_rowid: Cell::new(0), }); - let mut schema = Schema::new(); let rows = conn.query("SELECT * FROM sqlite_schema")?; + let mut schema = schema.borrow_mut(); parse_schema_rows(rows, &mut schema, io)?; - let schema = Rc::new(RefCell::new(schema)); - let header = db_header; - Ok(Arc::new(Database { - pager, - schema, - header, - _shared_page_cache, - _shared_wal: shared_wal, - })) + Ok(db) } pub fn connect(self: &Arc) -> Rc { @@ -148,7 +147,11 @@ impl Database { schema: self.schema.clone(), header: self.header.clone(), last_insert_rowid: Cell::new(0), +<<<<<<< HEAD _db: Arc::downgrade(self), +======= + db: self.clone(), +>>>>>>> 680b321 (core: Don't use Weak reference for connection database) transaction_state: RefCell::new(TransactionState::None), }) } @@ -204,10 +207,14 @@ pub fn maybe_init_database_file(file: &Rc, io: &Arc) -> Result } pub struct Connection { + db: Arc, pager: Rc, schema: Rc>, header: Rc>, +<<<<<<< HEAD _db: Weak, // backpointer to the database holding this connection +======= +>>>>>>> 680b321 (core: Don't use Weak reference for connection database) transaction_state: RefCell, last_insert_rowid: Cell, } From 858aecfea23aa28b497e3e35d257ad46782b9f49 Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Sat, 28 Dec 2024 10:23:21 +0200 Subject: [PATCH 2/4] core: Drop Clone and PartialEq from Func enum We don't need them anywhere and they make it hard to introduce GenericFunction. --- core/function.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/function.rs b/core/function.rs index a97c6e1b6..44e32f5ab 100644 --- a/core/function.rs +++ b/core/function.rs @@ -255,7 +255,7 @@ impl Display for MathFunc { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug)] pub enum Func { Agg(AggFunc), Scalar(ScalarFunc), From 33dbd6c892b8c8a9c0b4cac3b475988790f5d052 Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Fri, 27 Dec 2024 13:26:41 +0200 Subject: [PATCH 3/4] core: External functions --- core/function.rs | 23 ++++++++- core/lib.rs | 67 +++++++++++++++++++----- core/translate/delete.rs | 5 +- core/translate/emitter.rs | 78 +++++++++++++++++++++++----- core/translate/expr.rs | 106 ++++++++++++++++++++++++++++++++++---- core/translate/insert.rs | 6 +++ core/translate/mod.rs | 9 +++- core/translate/select.rs | 5 +- core/vdbe/mod.rs | 4 ++ 9 files changed, 260 insertions(+), 43 deletions(-) diff --git a/core/function.rs b/core/function.rs index 44e32f5ab..4bb6a3f9b 100644 --- a/core/function.rs +++ b/core/function.rs @@ -1,6 +1,25 @@ use crate::ext::ExtFunc; use std::fmt; -use std::fmt::Display; +use std::fmt::{Debug, Display}; +use std::rc::Rc; + +pub struct ExternalFunc { + pub name: String, + pub func: Box crate::Result>, +} + +impl Debug for ExternalFunc { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.name) + } +} + +impl Display for ExternalFunc { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.name) + } +} + #[cfg(feature = "json")] #[derive(Debug, Clone, PartialEq)] pub enum JsonFunc { @@ -263,6 +282,7 @@ pub enum Func { #[cfg(feature = "json")] Json(JsonFunc), Extension(ExtFunc), + External(Rc), } impl Display for Func { @@ -274,6 +294,7 @@ impl Display for Func { #[cfg(feature = "json")] Self::Json(json_func) => write!(f, "{}", json_func), Self::Extension(ext_func) => write!(f, "{}", ext_func), + Self::External(generic_func) => write!(f, "{}", generic_func), } } } diff --git a/core/lib.rs b/core/lib.rs index 7002c6803..321954e83 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -23,6 +23,7 @@ use schema::Schema; use sqlite3_parser::ast; use sqlite3_parser::{ast::Cmd, lexer::sql::Parser}; use std::cell::Cell; +use std::collections::HashMap; use std::sync::{Arc, OnceLock, RwLock}; use std::{cell::RefCell, rc::Rc}; use storage::btree::btree_init_page; @@ -36,6 +37,7 @@ pub use storage::wal::WalFileShared; use util::parse_schema_rows; use translate::select::prepare_select_plan; +use types::OwnedValue; pub use error::LimboError; pub type Result = std::result::Result; @@ -66,6 +68,7 @@ pub struct Database { pager: Rc, schema: Rc>, header: Rc>, + syms: Rc>, // Shared structures of a Database are the parts that are common to multiple threads that might // create DB connections. _shared_page_cache: Arc>, @@ -120,19 +123,21 @@ impl Database { )?); let header = db_header; let schema = Rc::new(RefCell::new(Schema::new())); + let syms = Rc::new(RefCell::new(SymbolTable::new())); let db = Arc::new(Database { pager: pager.clone(), schema: schema.clone(), header: header.clone(), - shared_page_cache, - shared_wal, + _shared_page_cache: _shared_page_cache.clone(), + _shared_wal: shared_wal.clone(), + syms, }); let conn = Rc::new(Connection { + db: db.clone(), pager: pager, schema: schema.clone(), header, transaction_state: RefCell::new(TransactionState::None), - db: db.clone(), last_insert_rowid: Cell::new(0), }); let rows = conn.query("SELECT * FROM sqlite_schema")?; @@ -143,18 +148,29 @@ impl Database { pub fn connect(self: &Arc) -> Rc { Rc::new(Connection { + db: self.clone(), pager: self.pager.clone(), schema: self.schema.clone(), header: self.header.clone(), last_insert_rowid: Cell::new(0), -<<<<<<< HEAD - _db: Arc::downgrade(self), -======= - db: self.clone(), ->>>>>>> 680b321 (core: Don't use Weak reference for connection database) transaction_state: RefCell::new(TransactionState::None), }) } + + pub fn define_scalar_function>( + &self, + name: S, + func: impl Fn(&[Value]) -> Result + 'static, + ) { + let func = function::ExternalFunc { + name: name.as_ref().to_string(), + func: Box::new(func), + }; + self.syms + .borrow_mut() + .functions + .insert(name.as_ref().to_string(), Rc::new(func)); + } } pub fn maybe_init_database_file(file: &Rc, io: &Arc) -> Result<()> { @@ -211,10 +227,6 @@ pub struct Connection { pager: Rc, schema: Rc>, header: Rc>, -<<<<<<< HEAD - _db: Weak, // backpointer to the database holding this connection -======= ->>>>>>> 680b321 (core: Don't use Weak reference for connection database) transaction_state: RefCell, last_insert_rowid: Cell, } @@ -223,6 +235,8 @@ impl Connection { pub fn prepare(self: &Rc, sql: impl Into) -> Result { let sql = sql.into(); trace!("Preparing: {}", sql); + let db = self.db.clone(); + let syms: &SymbolTable = &db.syms.borrow(); let mut parser = Parser::new(sql.as_bytes()); let cmd = parser.next()?; if let Some(cmd) = cmd { @@ -234,6 +248,7 @@ impl Connection { self.header.clone(), self.pager.clone(), Rc::downgrade(self), + &syms, )?); Ok(Statement::new(program, self.pager.clone())) } @@ -248,6 +263,8 @@ impl Connection { pub fn query(self: &Rc, sql: impl Into) -> Result> { let sql = sql.into(); trace!("Querying: {}", sql); + let db = self.db.clone(); + let syms: &SymbolTable = &db.syms.borrow(); let mut parser = Parser::new(sql.as_bytes()); let cmd = parser.next()?; if let Some(cmd) = cmd { @@ -259,6 +276,7 @@ impl Connection { self.header.clone(), self.pager.clone(), Rc::downgrade(self), + &syms, )?); let stmt = Statement::new(program, self.pager.clone()); Ok(Some(Rows { stmt })) @@ -270,6 +288,7 @@ impl Connection { self.header.clone(), self.pager.clone(), Rc::downgrade(self), + &syms, )?; program.explain(); Ok(None) @@ -293,6 +312,8 @@ impl Connection { pub fn execute(self: &Rc, sql: impl Into) -> Result<()> { let sql = sql.into(); + let db = self.db.clone(); + let syms: &SymbolTable = &db.syms.borrow(); let mut parser = Parser::new(sql.as_bytes()); let cmd = parser.next()?; if let Some(cmd) = cmd { @@ -304,6 +325,7 @@ impl Connection { self.header.clone(), self.pager.clone(), Rc::downgrade(self), + &syms, )?; program.explain(); } @@ -315,6 +337,7 @@ impl Connection { self.header.clone(), self.pager.clone(), Rc::downgrade(self), + &syms, )?; let mut state = vdbe::ProgramState::new(program.max_registers); program.step(&mut state, self.pager.clone())?; @@ -433,3 +456,23 @@ impl Rows { self.stmt.step() } } + +pub(crate) struct SymbolTable { + pub functions: HashMap>, +} + +impl SymbolTable { + pub fn new() -> Self { + Self { + functions: HashMap::new(), + } + } + + pub fn resolve_function( + &self, + name: &str, + _arg_count: usize, + ) -> Option> { + self.functions.get(name).cloned() + } +} diff --git a/core/translate/delete.rs b/core/translate/delete.rs index 38256d4c1..f655e8c3b 100644 --- a/core/translate/delete.rs +++ b/core/translate/delete.rs @@ -3,7 +3,7 @@ use crate::translate::optimizer::optimize_plan; use crate::translate::plan::{BTreeTableReference, DeletePlan, Plan, SourceOperator}; use crate::translate::planner::{parse_limit, parse_where}; use crate::{schema::Schema, storage::sqlite3_ondisk::DatabaseHeader, vdbe::Program}; -use crate::{Connection, Result}; +use crate::{Connection, Result, SymbolTable}; use sqlite3_parser::ast::{Expr, Limit, QualifiedName}; use std::rc::Weak; use std::{cell::RefCell, rc::Rc}; @@ -15,10 +15,11 @@ pub fn translate_delete( limit: Option, database_header: Rc>, connection: Weak, + syms: &SymbolTable, ) -> Result { let delete_plan = prepare_delete_plan(schema, tbl_name, where_clause, limit)?; let optimized_plan = optimize_plan(delete_plan)?; - emit_program(database_header, optimized_plan, connection) + emit_program(database_header, optimized_plan, connection, syms) } pub fn prepare_delete_plan( diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index 9d1c124da..2cf9dafe4 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -14,7 +14,7 @@ use crate::types::{OwnedRecord, OwnedValue}; use crate::util::exprs_are_equivalent; use crate::vdbe::builder::ProgramBuilder; use crate::vdbe::{insn::Insn, BranchOffset, Program}; -use crate::{Connection, Result}; +use crate::{Connection, Result, SymbolTable}; use super::expr::{ translate_aggregation, translate_aggregation_groupby, translate_condition_expr, translate_expr, @@ -176,10 +176,11 @@ pub fn emit_program( database_header: Rc>, plan: Plan, connection: Weak, + syms: &SymbolTable, ) -> Result { match plan { - Plan::Select(plan) => emit_program_for_select(database_header, plan, connection), - Plan::Delete(plan) => emit_program_for_delete(database_header, plan, connection), + Plan::Select(plan) => emit_program_for_select(database_header, plan, connection, syms), + Plan::Delete(plan) => emit_program_for_delete(database_header, plan, connection, syms), } } @@ -187,6 +188,7 @@ fn emit_program_for_select( database_header: Rc>, mut plan: SelectPlan, connection: Weak, + syms: &SymbolTable, ) -> Result { let (mut program, mut metadata, init_label, start_offset) = prologue()?; @@ -235,10 +237,11 @@ fn emit_program_for_select( &mut plan.source, &plan.referenced_tables, &mut metadata, + syms, )?; // Process result columns and expressions in the inner loop - inner_loop_emit(&mut program, &mut plan, &mut metadata)?; + inner_loop_emit(&mut program, &mut plan, &mut metadata, syms)?; // Clean up and close the main execution loop close_loop(&mut program, &plan.source, &mut metadata)?; @@ -260,6 +263,7 @@ fn emit_program_for_select( plan.limit, &plan.referenced_tables, &mut metadata, + syms, )?; } else if !plan.aggregates.is_empty() { // Handle aggregation without GROUP BY @@ -269,6 +273,7 @@ fn emit_program_for_select( &plan.result_columns, &plan.aggregates, &mut metadata, + syms, )?; // Single row result for aggregates without GROUP BY, so ORDER BY not needed order_by_necessary = false; @@ -297,6 +302,7 @@ fn emit_program_for_delete( database_header: Rc>, mut plan: DeletePlan, connection: Weak, + syms: &SymbolTable, ) -> Result { let (mut program, mut metadata, init_label, start_offset) = prologue()?; @@ -328,6 +334,7 @@ fn emit_program_for_delete( &mut plan.source, &plan.referenced_tables, &mut metadata, + syms, )?; emit_delete_insns(&mut program, &plan.source, &plan.limit, &metadata)?; @@ -590,6 +597,7 @@ fn open_loop( source: &mut SourceOperator, referenced_tables: &[BTreeTableReference], metadata: &mut Metadata, + syms: &SymbolTable, ) -> Result<()> { match source { SourceOperator::Join { @@ -600,7 +608,7 @@ fn open_loop( outer, .. } => { - open_loop(program, left, referenced_tables, metadata)?; + open_loop(program, left, referenced_tables, metadata, syms)?; let mut jump_target_when_false = *metadata .next_row_labels @@ -620,7 +628,7 @@ fn open_loop( .next_row_labels .insert(right.id(), jump_target_when_false); - open_loop(program, right, referenced_tables, metadata)?; + open_loop(program, right, referenced_tables, metadata, syms)?; if let Some(predicates) = predicates { let jump_target_when_true = program.allocate_label(); @@ -636,6 +644,7 @@ fn open_loop( predicate, condition_metadata, None, + syms, )?; } program.resolve_label(jump_target_when_true, program.offset()); @@ -707,6 +716,7 @@ fn open_loop( expr, condition_metadata, None, + syms, )?; program.resolve_label(jump_target_when_true, program.offset()); } @@ -745,7 +755,14 @@ fn open_loop( ast::Operator::Equals | ast::Operator::Greater | ast::Operator::GreaterEquals => { - translate_expr(program, Some(referenced_tables), cmp_expr, cmp_reg, None)?; + translate_expr( + program, + Some(referenced_tables), + cmp_expr, + cmp_reg, + None, + syms, + )?; } ast::Operator::Less | ast::Operator::LessEquals => { program.emit_insn(Insn::Null { @@ -778,7 +795,14 @@ fn open_loop( *metadata.termination_label_stack.last().unwrap(), ); if *cmp_op == ast::Operator::Less || *cmp_op == ast::Operator::LessEquals { - translate_expr(program, Some(referenced_tables), cmp_expr, cmp_reg, None)?; + translate_expr( + program, + Some(referenced_tables), + cmp_expr, + cmp_reg, + None, + syms, + )?; } program.defer_label_resolution(scan_loop_body_label, program.offset() as usize); @@ -866,7 +890,14 @@ fn open_loop( if let Search::RowidEq { cmp_expr } = search { let src_reg = program.alloc_register(); - translate_expr(program, Some(referenced_tables), cmp_expr, src_reg, None)?; + translate_expr( + program, + Some(referenced_tables), + cmp_expr, + src_reg, + None, + syms, + )?; program.emit_insn_with_label_dependency( Insn::SeekRowid { cursor_id: table_cursor_id, @@ -890,6 +921,7 @@ fn open_loop( predicate, condition_metadata, None, + syms, )?; program.resolve_label(jump_target_when_true, program.offset()); } @@ -927,6 +959,7 @@ fn inner_loop_emit( program: &mut ProgramBuilder, plan: &mut SelectPlan, metadata: &mut Metadata, + syms: &SymbolTable, ) -> Result<()> { // if we have a group by, we emit a record into the group by sorter. if let Some(group_by) = &plan.group_by { @@ -940,6 +973,7 @@ fn inner_loop_emit( aggregates: &plan.aggregates, }, &plan.referenced_tables, + syms, ); } // if we DONT have a group by, but we have aggregates, we emit without ResultRow. @@ -952,6 +986,7 @@ fn inner_loop_emit( metadata, InnerLoopEmitTarget::AggStep, &plan.referenced_tables, + syms, ); } // if we DONT have a group by, but we have an order by, we emit a record into the order by sorter. @@ -963,6 +998,7 @@ fn inner_loop_emit( metadata, InnerLoopEmitTarget::OrderBySorter { order_by }, &plan.referenced_tables, + syms, ); } // if we have neither, we emit a ResultRow. In that case, if we have a Limit, we handle that with DecrJumpZero. @@ -973,6 +1009,7 @@ fn inner_loop_emit( metadata, InnerLoopEmitTarget::ResultRow { limit: plan.limit }, &plan.referenced_tables, + syms, ) } @@ -986,6 +1023,7 @@ fn inner_loop_source_emit( metadata: &mut Metadata, emit_target: InnerLoopEmitTarget, referenced_tables: &[BTreeTableReference], + syms: &SymbolTable, ) -> Result<()> { match emit_target { InnerLoopEmitTarget::GroupBySorter { @@ -1003,7 +1041,7 @@ fn inner_loop_source_emit( for expr in group_by.exprs.iter() { let key_reg = cur_reg; cur_reg += 1; - translate_expr(program, Some(referenced_tables), expr, key_reg, None)?; + translate_expr(program, Some(referenced_tables), expr, key_reg, None, syms)?; } // Then we have the aggregate arguments. for agg in aggregates.iter() { @@ -1016,7 +1054,7 @@ fn inner_loop_source_emit( for expr in agg.args.iter() { let agg_reg = cur_reg; cur_reg += 1; - translate_expr(program, Some(referenced_tables), expr, agg_reg, None)?; + translate_expr(program, Some(referenced_tables), expr, agg_reg, None, syms)?; } } @@ -1044,6 +1082,7 @@ fn inner_loop_source_emit( &mut metadata.result_column_indexes_in_orderby_sorter, metadata.sort_metadata.as_ref().unwrap(), None, + syms, )?; Ok(()) } @@ -1060,7 +1099,7 @@ fn inner_loop_source_emit( // Instead, we translate the aggregates + any expressions that do not contain aggregates. for (i, agg) in aggregates.iter().enumerate() { let reg = start_reg + i; - translate_aggregation(program, referenced_tables, agg, reg)?; + translate_aggregation(program, referenced_tables, agg, reg, syms)?; } for (i, rc) in result_columns.iter().enumerate() { if rc.contains_aggregates { @@ -1070,7 +1109,7 @@ fn inner_loop_source_emit( continue; } let reg = start_reg + num_aggs + i; - translate_expr(program, Some(referenced_tables), &rc.expr, reg, None)?; + translate_expr(program, Some(referenced_tables), &rc.expr, reg, None, syms)?; } Ok(()) } @@ -1085,6 +1124,7 @@ fn inner_loop_source_emit( result_columns, None, limit.map(|l| (l, *metadata.termination_label_stack.last().unwrap())), + syms, )?; Ok(()) @@ -1301,6 +1341,7 @@ fn group_by_emit( limit: Option, referenced_tables: &[BTreeTableReference], metadata: &mut Metadata, + syms: &SymbolTable, ) -> Result<()> { let sort_loop_start_label = program.allocate_label(); let grouping_done_label = program.allocate_label(); @@ -1451,6 +1492,7 @@ fn group_by_emit( cursor_index, agg, agg_result_reg, + syms, )?; cursor_index += agg.args.len(); } @@ -1585,6 +1627,7 @@ fn group_by_emit( jump_target_when_true: i64::MAX, // unused }, Some(&precomputed_exprs_to_register), + syms, )?; } } @@ -1597,6 +1640,7 @@ fn group_by_emit( result_columns, Some(&precomputed_exprs_to_register), limit.map(|l| (l, *metadata.termination_label_stack.last().unwrap())), + syms, )?; } Some(order_by) => { @@ -1608,6 +1652,7 @@ fn group_by_emit( &mut metadata.result_column_indexes_in_orderby_sorter, metadata.sort_metadata.as_ref().unwrap(), Some(&precomputed_exprs_to_register), + syms, )?; } } @@ -1647,6 +1692,7 @@ fn agg_without_group_by_emit( result_columns: &[ResultSetColumn], aggregates: &[Aggregate], metadata: &mut Metadata, + syms: &SymbolTable, ) -> Result<()> { let agg_start_reg = metadata.aggregation_start_register.unwrap(); for (i, agg) in aggregates.iter().enumerate() { @@ -1672,6 +1718,7 @@ fn agg_without_group_by_emit( result_columns, Some(&precomputed_exprs_to_register), None, + syms, )?; Ok(()) @@ -1822,6 +1869,7 @@ fn emit_select_result( result_columns: &[ResultSetColumn], precomputed_exprs_to_register: Option<&Vec<(&ast::Expr, usize)>>, limit: Option<(usize, BranchOffset)>, + syms: &SymbolTable, ) -> Result<()> { let start_reg = program.alloc_registers(result_columns.len()); for (i, rc) in result_columns.iter().enumerate() { @@ -1832,6 +1880,7 @@ fn emit_select_result( &rc.expr, reg, precomputed_exprs_to_register, + syms, )?; } emit_result_row_and_limit(program, start_reg, result_columns.len(), limit)?; @@ -1867,6 +1916,7 @@ fn order_by_sorter_insert( result_column_indexes_in_orderby_sorter: &mut HashMap, sort_metadata: &SortMetadata, precomputed_exprs_to_register: Option<&Vec<(&ast::Expr, usize)>>, + syms: &SymbolTable, ) -> Result<()> { let order_by_len = order_by.len(); // If any result columns can be skipped due to being an exact duplicate of a sort key, we need to know which ones and their new index in the ORDER BY sorter. @@ -1888,6 +1938,7 @@ fn order_by_sorter_insert( expr, key_reg, precomputed_exprs_to_register, + syms, )?; } let mut cur_reg = start_reg + order_by_len; @@ -1907,6 +1958,7 @@ fn order_by_sorter_insert( &rc.expr, cur_reg, precomputed_exprs_to_register, + syms, )?; result_column_indexes_in_orderby_sorter.insert(i, cur_idx_in_orderby_sorter); cur_idx_in_orderby_sorter += 1; diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 8b91a2969..5032e8274 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -8,7 +8,7 @@ use crate::function::{AggFunc, Func, FuncCtx, MathFuncArity, ScalarFunc}; use crate::schema::Type; use crate::util::{exprs_are_equivalent, normalize_ident}; use crate::vdbe::{builder::ProgramBuilder, insn::Insn, BranchOffset}; -use crate::Result; +use crate::{Result, SymbolTable}; use super::plan::{Aggregate, BTreeTableReference}; @@ -25,6 +25,7 @@ pub fn translate_condition_expr( expr: &ast::Expr, condition_metadata: ConditionMetadata, precomputed_exprs_to_registers: Option<&Vec<(&ast::Expr, usize)>>, + syms: &SymbolTable, ) -> Result<()> { match expr { ast::Expr::Between { .. } => todo!(), @@ -40,6 +41,7 @@ pub fn translate_condition_expr( ..condition_metadata }, precomputed_exprs_to_registers, + syms, ); let _ = translate_condition_expr( program, @@ -47,6 +49,7 @@ pub fn translate_condition_expr( rhs, condition_metadata, precomputed_exprs_to_registers, + syms, ); } ast::Expr::Binary(lhs, ast::Operator::Or, rhs) => { @@ -62,6 +65,7 @@ pub fn translate_condition_expr( ..condition_metadata }, precomputed_exprs_to_registers, + syms, ); program.resolve_label(jump_target_when_false, program.offset()); let _ = translate_condition_expr( @@ -70,6 +74,7 @@ pub fn translate_condition_expr( rhs, condition_metadata, precomputed_exprs_to_registers, + syms, ); } ast::Expr::Binary(lhs, op, rhs) => { @@ -80,6 +85,7 @@ pub fn translate_condition_expr( lhs, lhs_reg, precomputed_exprs_to_registers, + syms, ); if let ast::Expr::Literal(_) = lhs.as_ref() { program.mark_last_insn_constant() @@ -91,6 +97,7 @@ pub fn translate_condition_expr( rhs, rhs_reg, precomputed_exprs_to_registers, + syms, ); if let ast::Expr::Literal(_) = rhs.as_ref() { program.mark_last_insn_constant() @@ -339,6 +346,7 @@ pub fn translate_condition_expr( lhs, lhs_reg, precomputed_exprs_to_registers, + syms, )?; let rhs = rhs.as_ref().unwrap(); @@ -368,6 +376,7 @@ pub fn translate_condition_expr( expr, rhs_reg, precomputed_exprs_to_registers, + syms, )?; // If this is not the last condition, we need to jump to the 'jump_target_when_true' label if the condition is true. if !last_condition { @@ -411,6 +420,7 @@ pub fn translate_condition_expr( expr, rhs_reg, precomputed_exprs_to_registers, + syms, )?; program.emit_insn_with_label_dependency( Insn::Eq { @@ -456,6 +466,7 @@ pub fn translate_condition_expr( lhs, column_reg, precomputed_exprs_to_registers, + syms, )?; if let ast::Expr::Literal(_) = lhs.as_ref() { program.mark_last_insn_constant(); @@ -466,6 +477,7 @@ pub fn translate_condition_expr( rhs, pattern_reg, precomputed_exprs_to_registers, + syms, )?; if let ast::Expr::Literal(_) = rhs.as_ref() { program.mark_last_insn_constant(); @@ -539,6 +551,7 @@ pub fn translate_condition_expr( expr, condition_metadata, precomputed_exprs_to_registers, + syms, ); } } @@ -553,6 +566,7 @@ pub fn translate_expr( expr: &ast::Expr, target_register: usize, precomputed_exprs_to_registers: Option<&Vec<(&ast::Expr, usize)>>, + syms: &SymbolTable, ) -> Result { if let Some(precomputed_exprs_to_registers) = precomputed_exprs_to_registers { for (precomputed_expr, reg) in precomputed_exprs_to_registers.iter() { @@ -576,6 +590,7 @@ pub fn translate_expr( e1, e1_reg, precomputed_exprs_to_registers, + syms, )?; let e2_reg = program.alloc_register(); translate_expr( @@ -584,6 +599,7 @@ pub fn translate_expr( e2, e2_reg, precomputed_exprs_to_registers, + syms, )?; match op { @@ -744,6 +760,7 @@ pub fn translate_expr( base_expr, base_reg.unwrap(), precomputed_exprs_to_registers, + syms, )?; }; for (when_expr, then_expr) in when_then_pairs { @@ -753,6 +770,7 @@ pub fn translate_expr( when_expr, expr_reg, precomputed_exprs_to_registers, + syms, )?; match base_reg { // CASE 1 WHEN 0 THEN 0 ELSE 1 becomes 1==0, Ne branch to next clause @@ -781,6 +799,7 @@ pub fn translate_expr( then_expr, target_register, precomputed_exprs_to_registers, + syms, )?; program.emit_insn_with_label_dependency( Insn::Goto { @@ -801,6 +820,7 @@ pub fn translate_expr( expr, target_register, precomputed_exprs_to_registers, + syms, )?; } // If ELSE isn't specified, it means ELSE null. @@ -823,6 +843,7 @@ pub fn translate_expr( expr, reg_expr, precomputed_exprs_to_registers, + syms, )?; let reg_type = program.alloc_register(); program.emit_insn(Insn::String8 { @@ -855,8 +876,13 @@ pub fn translate_expr( order_by: _, } => { let args_count = if let Some(args) = args { args.len() } else { 0 }; - let func_type: Option = - Func::resolve_function(normalize_ident(name.0.as_str()).as_str(), args_count).ok(); + let func_name = normalize_ident(name.0.as_str()); + let func_type = match Func::resolve_function(&func_name, args_count).ok() { + Some(func) => Some(func), + None => syms + .resolve_function(&func_name, args_count) + .map(|func| Func::External(func)), + }; if func_type.is_none() { crate::bail_parse_error!("unknown function {}", name.0); @@ -871,6 +897,16 @@ pub fn translate_expr( Func::Agg(_) => { crate::bail_parse_error!("aggregation function in non-aggregation context") } + Func::External(_) => { + let regs = program.alloc_register(); + program.emit_insn(Insn::Function { + constant_mask: 0, + start_reg: regs, + dest: target_register, + func: func_ctx, + }); + Ok(target_register) + } #[cfg(feature = "json")] Func::Json(j) => match j { JsonFunc::Json => { @@ -895,6 +931,7 @@ pub fn translate_expr( &args[0], regs, precomputed_exprs_to_registers, + syms, )?; program.emit_insn(Insn::Function { constant_mask: 0, @@ -910,6 +947,7 @@ pub fn translate_expr( args, referenced_tables, precomputed_exprs_to_registers, + syms, )?; program.emit_insn(Insn::Function { @@ -945,6 +983,7 @@ pub fn translate_expr( &args[0], json_reg, precomputed_exprs_to_registers, + syms, )?; if args.len() == 2 { @@ -954,6 +993,7 @@ pub fn translate_expr( &args[1], path_reg, precomputed_exprs_to_registers, + syms, )?; } @@ -977,6 +1017,7 @@ pub fn translate_expr( args, referenced_tables, precomputed_exprs_to_registers, + syms, )?; program.emit_insn(Insn::Function { @@ -1013,6 +1054,7 @@ pub fn translate_expr( arg, target_register, precomputed_exprs_to_registers, + syms, )?; if index < args.len() - 1 { program.emit_insn_with_label_dependency( @@ -1057,6 +1099,7 @@ pub fn translate_expr( arg, reg, precomputed_exprs_to_registers, + syms, )?; } program.emit_insn(Insn::Function { @@ -1089,6 +1132,7 @@ pub fn translate_expr( arg, reg, precomputed_exprs_to_registers, + syms, )?; } program.emit_insn(Insn::Function { @@ -1125,6 +1169,7 @@ pub fn translate_expr( &args[0], temp_reg, precomputed_exprs_to_registers, + syms, )?; program.emit_insn(Insn::NotNull { reg: temp_reg, @@ -1137,6 +1182,7 @@ pub fn translate_expr( &args[1], temp_reg, precomputed_exprs_to_registers, + syms, )?; program.emit_insn(Insn::Copy { src_reg: temp_reg, @@ -1161,6 +1207,7 @@ pub fn translate_expr( &args[0], temp_reg, precomputed_exprs_to_registers, + syms, )?; let jump_target_when_false = program.allocate_label(); program.emit_insn_with_label_dependency( @@ -1177,6 +1224,7 @@ pub fn translate_expr( &args[1], target_register, precomputed_exprs_to_registers, + syms, )?; let jump_target_result = program.allocate_label(); program.emit_insn_with_label_dependency( @@ -1192,6 +1240,7 @@ pub fn translate_expr( &args[2], target_register, precomputed_exprs_to_registers, + syms, )?; program.resolve_label(jump_target_result, program.offset()); Ok(target_register) @@ -1219,6 +1268,7 @@ pub fn translate_expr( arg, reg, precomputed_exprs_to_registers, + syms, )?; if let ast::Expr::Literal(_) = arg { program.mark_last_insn_constant() @@ -1268,6 +1318,7 @@ pub fn translate_expr( &args[0], regs, precomputed_exprs_to_registers, + syms, )?; program.emit_insn(Insn::Function { constant_mask: 0, @@ -1304,6 +1355,7 @@ pub fn translate_expr( arg, target_reg, precomputed_exprs_to_registers, + syms, )?; } } @@ -1341,6 +1393,7 @@ pub fn translate_expr( &args[0], str_reg, precomputed_exprs_to_registers, + syms, )?; translate_expr( program, @@ -1348,6 +1401,7 @@ pub fn translate_expr( &args[1], start_reg, precomputed_exprs_to_registers, + syms, )?; if args.len() == 3 { translate_expr( @@ -1356,6 +1410,7 @@ pub fn translate_expr( &args[2], length_reg, precomputed_exprs_to_registers, + syms, )?; } @@ -1385,6 +1440,7 @@ pub fn translate_expr( &args[0], regs, precomputed_exprs_to_registers, + syms, )?; program.emit_insn(Insn::Function { constant_mask: 0, @@ -1408,6 +1464,7 @@ pub fn translate_expr( &args[0], arg_reg, precomputed_exprs_to_registers, + syms, )?; start_reg = arg_reg; } @@ -1432,6 +1489,7 @@ pub fn translate_expr( arg, target_reg, precomputed_exprs_to_registers, + syms, )?; } } @@ -1471,6 +1529,7 @@ pub fn translate_expr( arg, reg, precomputed_exprs_to_registers, + syms, )?; if let ast::Expr::Literal(_) = arg { program.mark_last_insn_constant(); @@ -1503,6 +1562,7 @@ pub fn translate_expr( arg, reg, precomputed_exprs_to_registers, + syms, )?; if let ast::Expr::Literal(_) = arg { program.mark_last_insn_constant() @@ -1536,6 +1596,7 @@ pub fn translate_expr( arg, reg, precomputed_exprs_to_registers, + syms, )?; if let ast::Expr::Literal(_) = arg { program.mark_last_insn_constant() @@ -1573,6 +1634,7 @@ pub fn translate_expr( &args[0], first_reg, precomputed_exprs_to_registers, + syms, )?; let second_reg = program.alloc_register(); translate_expr( @@ -1581,6 +1643,7 @@ pub fn translate_expr( &args[1], second_reg, precomputed_exprs_to_registers, + syms, )?; program.emit_insn(Insn::Function { constant_mask: 0, @@ -1637,6 +1700,7 @@ pub fn translate_expr( &args[0], str_reg, precomputed_exprs_to_registers, + syms, )?; translate_expr( program, @@ -1644,6 +1708,7 @@ pub fn translate_expr( &args[1], pattern_reg, precomputed_exprs_to_registers, + syms, )?; translate_expr( @@ -1652,6 +1717,7 @@ pub fn translate_expr( &args[2], replacement_reg, precomputed_exprs_to_registers, + syms, )?; program.emit_insn(Insn::Function { @@ -1690,6 +1756,7 @@ pub fn translate_expr( &args[0], regs, precomputed_exprs_to_registers, + syms, )?; program.emit_insn(Insn::Function { constant_mask: 0, @@ -1734,6 +1801,7 @@ pub fn translate_expr( arg, reg, precomputed_exprs_to_registers, + syms, )?; if let ast::Expr::Literal(_) = arg { program.mark_last_insn_constant() @@ -1787,6 +1855,7 @@ pub fn translate_expr( &args[0], reg, precomputed_exprs_to_registers, + syms, )?; program.emit_insn(Insn::Function { @@ -1820,6 +1889,7 @@ pub fn translate_expr( &args[0], reg1, precomputed_exprs_to_registers, + syms, )?; if let ast::Expr::Literal(_) = &args[0] { program.mark_last_insn_constant(); @@ -1831,6 +1901,7 @@ pub fn translate_expr( &args[1], reg2, precomputed_exprs_to_registers, + syms, )?; if let ast::Expr::Literal(_) = &args[1] { program.mark_last_insn_constant(); @@ -1867,6 +1938,7 @@ pub fn translate_expr( arg, regs + i, precomputed_exprs_to_registers, + syms, )?; } @@ -1978,6 +2050,7 @@ pub fn translate_expr( &exprs[0], target_register, precomputed_exprs_to_registers, + syms, )?; } else { // Parenthesized expressions with multiple arguments are reserved for special cases @@ -2030,6 +2103,7 @@ pub fn translate_expr( expr, reg, precomputed_exprs_to_registers, + syms, )?; let zero_reg = program.alloc_register(); program.emit_insn(Insn::Integer { @@ -2077,6 +2151,7 @@ pub fn translate_expr( expr, reg, precomputed_exprs_to_registers, + syms, )?; program.emit_insn(Insn::BitNot { reg, @@ -2097,6 +2172,7 @@ fn translate_variable_sized_function_parameter_list( args: &Option>, referenced_tables: Option<&[BTreeTableReference]>, precomputed_exprs_to_registers: Option<&Vec<(&ast::Expr, usize)>>, + syms: &SymbolTable, ) -> Result { let args = args.as_deref().unwrap_or_default(); @@ -2110,6 +2186,7 @@ fn translate_variable_sized_function_parameter_list( arg, current_reg, precomputed_exprs_to_registers, + syms, )?; current_reg += 1; @@ -2149,6 +2226,7 @@ pub fn translate_aggregation( referenced_tables: &[BTreeTableReference], agg: &Aggregate, target_register: usize, + syms: &SymbolTable, ) -> Result { let dest = match agg.func { AggFunc::Avg => { @@ -2157,7 +2235,7 @@ pub fn translate_aggregation( } let expr = &agg.args[0]; let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, None)?; + let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, None, syms)?; program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -2172,7 +2250,8 @@ pub fn translate_aggregation( } else { let expr = &agg.args[0]; let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, None); + let _ = + translate_expr(program, Some(referenced_tables), expr, expr_reg, None, syms)?; expr_reg }; program.emit_insn(Insn::AggStep { @@ -2208,13 +2287,14 @@ pub fn translate_aggregation( delimiter_expr = ast::Expr::Literal(ast::Literal::String(String::from("\",\""))); } - translate_expr(program, Some(referenced_tables), expr, expr_reg, None)?; + translate_expr(program, Some(referenced_tables), expr, expr_reg, None, syms)?; translate_expr( program, Some(referenced_tables), &delimiter_expr, delimiter_reg, None, + syms, )?; program.emit_insn(Insn::AggStep { @@ -2232,7 +2312,7 @@ pub fn translate_aggregation( } let expr = &agg.args[0]; let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, None)?; + let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, None, syms)?; program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -2247,7 +2327,7 @@ pub fn translate_aggregation( } let expr = &agg.args[0]; let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, None)?; + let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, None, syms)?; program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -2273,13 +2353,14 @@ pub fn translate_aggregation( _ => crate::bail_parse_error!("Incorrect delimiter parameter"), }; - translate_expr(program, Some(referenced_tables), expr, expr_reg, None)?; + translate_expr(program, Some(referenced_tables), expr, expr_reg, None, syms)?; translate_expr( program, Some(referenced_tables), &delimiter_expr, delimiter_reg, None, + syms, )?; program.emit_insn(Insn::AggStep { @@ -2297,7 +2378,7 @@ pub fn translate_aggregation( } let expr = &agg.args[0]; let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, None)?; + let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, None, syms)?; program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -2312,7 +2393,7 @@ pub fn translate_aggregation( } let expr = &agg.args[0]; let expr_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, None)?; + let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, None, syms)?; program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -2332,6 +2413,7 @@ pub fn translate_aggregation_groupby( cursor_index: usize, agg: &Aggregate, target_register: usize, + syms: &SymbolTable, ) -> Result { let emit_column = |program: &mut ProgramBuilder, expr_reg: usize| { program.emit_insn(Insn::Column { @@ -2397,6 +2479,7 @@ pub fn translate_aggregation_groupby( &delimiter_expr, delimiter_reg, None, + syms, )?; program.emit_insn(Insn::AggStep { @@ -2459,6 +2542,7 @@ pub fn translate_aggregation_groupby( &delimiter_expr, delimiter_reg, None, + syms, )?; program.emit_insn(Insn::AggStep { diff --git a/core/translate/insert.rs b/core/translate/insert.rs index b057ade33..e97c8372f 100644 --- a/core/translate/insert.rs +++ b/core/translate/insert.rs @@ -12,6 +12,7 @@ use crate::{ storage::sqlite3_ondisk::DatabaseHeader, translate::expr::translate_expr, vdbe::{builder::ProgramBuilder, insn::Insn, Program}, + SymbolTable, }; use crate::{Connection, Result}; @@ -26,6 +27,7 @@ pub fn translate_insert( _returning: &Option>, database_header: Rc>, connection: Weak, + syms: &SymbolTable, ) -> Result { if with.is_some() { crate::bail_parse_error!("WITH clause is not supported"); @@ -124,6 +126,7 @@ pub fn translate_insert( column_registers_start, true, rowid_reg, + syms, )?; program.emit_insn(Insn::Yield { yield_reg, @@ -165,6 +168,7 @@ pub fn translate_insert( column_registers_start, false, rowid_reg, + syms, )?; } @@ -379,6 +383,7 @@ fn populate_column_registers( column_registers_start: usize, inserting_multiple_rows: bool, rowid_reg: usize, + syms: &SymbolTable, ) -> Result<()> { for (i, mapping) in column_mappings.iter().enumerate() { let target_reg = column_registers_start + i; @@ -401,6 +406,7 @@ fn populate_column_registers( value.get(value_index).expect("value index out of bounds"), reg, None, + syms, )?; if write_directly_to_rowid_reg { program.emit_insn(Insn::SoftNull { reg: target_reg }); diff --git a/core/translate/mod.rs b/core/translate/mod.rs index ca91eeebd..7cd2bf840 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -21,7 +21,7 @@ use crate::storage::pager::Pager; use crate::storage::sqlite3_ondisk::{DatabaseHeader, MIN_PAGE_CACHE_SIZE}; use crate::translate::delete::translate_delete; use crate::vdbe::{builder::ProgramBuilder, insn::Insn, Program}; -use crate::{bail_parse_error, Connection, Result}; +use crate::{bail_parse_error, Connection, Result, SymbolTable}; use insert::translate_insert; use select::translate_select; use sqlite3_parser::ast::fmt::ToTokens; @@ -38,6 +38,7 @@ pub fn translate( database_header: Rc>, pager: Rc, connection: Weak, + syms: &SymbolTable, ) -> Result { match stmt { ast::Stmt::AlterTable(_, _) => bail_parse_error!("ALTER TABLE not supported yet"), @@ -81,6 +82,7 @@ pub fn translate( limit, database_header, connection, + syms, ), ast::Stmt::Detach(_) => bail_parse_error!("DETACH not supported yet"), ast::Stmt::DropIndex { .. } => bail_parse_error!("DROP INDEX not supported yet"), @@ -94,7 +96,9 @@ pub fn translate( ast::Stmt::Release(_) => bail_parse_error!("RELEASE not supported yet"), ast::Stmt::Rollback { .. } => bail_parse_error!("ROLLBACK not supported yet"), ast::Stmt::Savepoint(_) => bail_parse_error!("SAVEPOINT not supported yet"), - ast::Stmt::Select(select) => translate_select(schema, select, database_header, connection), + ast::Stmt::Select(select) => { + translate_select(schema, select, database_header, connection, syms) + } ast::Stmt::Update { .. } => bail_parse_error!("UPDATE not supported yet"), ast::Stmt::Vacuum(_, _) => bail_parse_error!("VACUUM not supported yet"), ast::Stmt::Insert { @@ -114,6 +118,7 @@ pub fn translate( &returning, database_header, connection, + syms, ), } } diff --git a/core/translate/select.rs b/core/translate/select.rs index ea7c9b119..5ad2d41d5 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -11,8 +11,8 @@ use crate::translate::planner::{ parse_where, resolve_aggregates, OperatorIdCounter, }; use crate::util::normalize_ident; -use crate::Connection; use crate::{schema::Schema, vdbe::Program, Result}; +use crate::{Connection, SymbolTable}; use sqlite3_parser::ast; use sqlite3_parser::ast::ResultColumn; @@ -21,10 +21,11 @@ pub fn translate_select( select: ast::Select, database_header: Rc>, connection: Weak, + syms: &SymbolTable, ) -> Result { let select_plan = prepare_select_plan(schema, select)?; let optimized_plan = optimize_plan(select_plan)?; - emit_program(database_header, optimized_plan, connection) + emit_program(database_header, optimized_plan, connection, syms) } pub fn prepare_select_plan(schema: &Schema, select: ast::Select) -> Result { diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 52b74ccc2..470b48588 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -2212,6 +2212,10 @@ impl Program { }, _ => unreachable!(), // when more extension types are added }, + crate::function::Func::External(f) => { + let result = (f.func)(&[])?; + state.registers[*dest] = result; + } crate::function::Func::Math(math_func) => match math_func.arity() { MathFuncArity::Nullary => match math_func { MathFunc::Pi => { From 0aabcddf18fed76be57bb12ed0aa44585f205808 Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Sat, 28 Dec 2024 10:35:54 +0200 Subject: [PATCH 4/4] ext/uuid: Convert uuid4() to external function --- core/ext/mod.rs | 5 +++++ core/ext/uuid.rs | 28 +++++++++++++++------------- core/lib.rs | 6 ++++-- core/translate/expr.rs | 2 +- core/vdbe/mod.rs | 2 +- 5 files changed, 26 insertions(+), 17 deletions(-) diff --git a/core/ext/mod.rs b/core/ext/mod.rs index 2140957d4..cea65a98d 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -30,3 +30,8 @@ impl ExtFunc { } } } + +pub fn init(db: &mut crate::Database) { + #[cfg(feature = "uuid")] + uuid::init(db); +} diff --git a/core/ext/uuid.rs b/core/ext/uuid.rs index d3225e8e0..92fdd831a 100644 --- a/core/ext/uuid.rs +++ b/core/ext/uuid.rs @@ -1,7 +1,7 @@ use super::ExtFunc; use crate::{ types::{LimboText, OwnedValue}, - LimboError, + Database, LimboError, }; use std::rc::Rc; use uuid::{ContextV7, Timestamp, Uuid}; @@ -9,7 +9,6 @@ use uuid::{ContextV7, Timestamp, Uuid}; #[derive(Debug, Clone, Copy, PartialEq)] pub enum UuidFunc { Uuid4Str, - Uuid4, Uuid7, Uuid7TS, UuidStr, @@ -20,7 +19,6 @@ impl UuidFunc { pub fn resolve_function(name: &str, num_args: usize) -> Option { match name { "uuid4_str" => Some(ExtFunc::Uuid(Self::Uuid4Str)), - "uuid4" => Some(ExtFunc::Uuid(Self::Uuid4)), "uuid7" if num_args < 2 => Some(ExtFunc::Uuid(Self::Uuid7)), "uuid_str" if num_args == 1 => Some(ExtFunc::Uuid(Self::UuidStr)), "uuid_blob" if num_args == 1 => Some(ExtFunc::Uuid(Self::UuidBlob)), @@ -36,7 +34,6 @@ impl std::fmt::Display for UuidFunc { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Uuid4Str => write!(f, "uuid4_str"), - Self::Uuid4 => write!(f, "uuid4"), Self::Uuid7 => write!(f, "uuid7"), Self::Uuid7TS => write!(f, "uuid7_timestamp_ms"), Self::UuidStr => write!(f, "uuid_str"), @@ -47,9 +44,6 @@ impl std::fmt::Display for UuidFunc { pub fn exec_uuid(var: &UuidFunc, sec: Option<&OwnedValue>) -> crate::Result { match var { - UuidFunc::Uuid4 => Ok(OwnedValue::Blob(Rc::new( - Uuid::new_v4().into_bytes().to_vec(), - ))), UuidFunc::Uuid4Str => Ok(OwnedValue::Text(LimboText::new(Rc::new( Uuid::new_v4().to_string(), )))), @@ -71,6 +65,12 @@ pub fn exec_uuid(var: &UuidFunc, sec: Option<&OwnedValue>) -> crate::Result crate::Result { + Ok(OwnedValue::Blob(Rc::new( + Uuid::new_v4().into_bytes().to_vec(), + ))) +} + pub fn exec_uuidstr(reg: &OwnedValue) -> crate::Result { match reg { OwnedValue::Blob(blob) => { @@ -136,6 +136,10 @@ fn uuid_to_unix(uuid: &[u8; 16]) -> u64 { | (uuid[5] as u64) } +pub fn init(db: &mut Database) { + db.define_scalar_function("uuid4", |_args| exec_uuid4()); +} + #[cfg(test)] #[cfg(feature = "uuid")] pub mod test { @@ -143,10 +147,9 @@ pub mod test { use crate::types::OwnedValue; #[test] fn test_exec_uuid_v4blob() { - use super::exec_uuid; + use super::exec_uuid4; use uuid::Uuid; - let func = UuidFunc::Uuid4; - let owned_val = exec_uuid(&func, None); + let owned_val = exec_uuid4(); match owned_val { Ok(OwnedValue::Blob(blob)) => { assert_eq!(blob.len(), 16); @@ -303,11 +306,10 @@ pub mod test { #[test] fn test_exec_uuid_v4_blob_to_str() { - use super::{exec_uuid, exec_uuidstr, UuidFunc}; + use super::{exec_uuid4, exec_uuidstr}; use uuid::Uuid; // convert a v4 blob to a string - let owned_val = - exec_uuidstr(&exec_uuid(&UuidFunc::Uuid4, None).expect("uuid v7 blob to generate")); + let owned_val = exec_uuidstr(&exec_uuid4().expect("uuid v7 blob to generate")); match owned_val { Ok(OwnedValue::Text(v4str)) => { assert_eq!(v4str.value.len(), 36); diff --git a/core/lib.rs b/core/lib.rs index 321954e83..5ef48e742 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -124,14 +124,16 @@ impl Database { let header = db_header; let schema = Rc::new(RefCell::new(Schema::new())); let syms = Rc::new(RefCell::new(SymbolTable::new())); - let db = Arc::new(Database { + let mut db = Database { pager: pager.clone(), schema: schema.clone(), header: header.clone(), _shared_page_cache: _shared_page_cache.clone(), _shared_wal: shared_wal.clone(), syms, - }); + }; + ext::init(&mut db); + let db = Arc::new(db); let conn = Rc::new(Connection { db: db.clone(), pager: pager, diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 5032e8274..6b12c24e6 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1766,7 +1766,7 @@ pub fn translate_expr( }); Ok(target_register) } - UuidFunc::Uuid4 | UuidFunc::Uuid4Str => { + UuidFunc::Uuid4Str => { if args.is_some() { crate::bail_parse_error!( "{} function with arguments", diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 470b48588..20b5e39cb 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -2183,7 +2183,7 @@ impl Program { crate::function::Func::Extension(extfn) => match extfn { #[cfg(feature = "uuid")] ExtFunc::Uuid(uuidfn) => match uuidfn { - UuidFunc::Uuid4 | UuidFunc::Uuid4Str => { + UuidFunc::Uuid4Str => { state.registers[*dest] = exec_uuid(uuidfn, None)? } UuidFunc::Uuid7 => match arg_count {