Skip to content

Commit

Permalink
Enums & Literals as map keys (#1178)
Browse files Browse the repository at this point in the history
Issue #1050

Added support for enums and literal strings in map keys:

```baml
enum MapKey {
  A
  B
  C
}

class Fields {
  // Enum as key
  e map<MapKey, string>
  // Single literal as key
  l1 map<"literal", string>
  // Union of literals as keys
  l2 map<"one" | "two" | ("three" | "four"), string>
}
``` 

Literal integers are more complicated since they require maps to support
int keys. See #1180.
<!-- ELLIPSIS_HIDDEN -->


----

> [!IMPORTANT]
> Add support for enums and literal strings as map keys in BAML, with
updated validation, coercion logic, and tests.
> 
>   - **Behavior**:
> - Support for enums and literal strings as map keys added in `mod.rs`
and `types.rs`.
> - Validation logic updated to allow enums and literal strings as map
keys.
> - Coercion logic updated to handle enums and literal strings as map
keys.
>   - **Tests**:
> - Added tests in `map_enums_and_literals.baml` and `map_types.baml` to
verify new map key functionality.
> - Updated `test_functions.py` and `integ-tests.test.ts` to include
cases for enum and literal string map keys.
>   - **Misc**:
> - Updated error messages in `error.rs` to reflect new map key types.
> - Minor updates in `async_client.py`, `sync_client.py`, and
`client.rb` to support new map key types.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for c7742fd. It will automatically
update as commits are pushed.</sup>


<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
antoniosarosi authored Nov 19, 2024
1 parent fcdbdfb commit 39e0271
Show file tree
Hide file tree
Showing 35 changed files with 1,210 additions and 113 deletions.
8 changes: 7 additions & 1 deletion engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,13 @@ impl IRHelper for IntermediateRepr {
match maybe_item_type {
Some(item_type) => {
let map_type = FieldType::Map(
Box::new(FieldType::Primitive(TypeValue::String)),
Box::new(match &field_type {
FieldType::Map(key, _) => match key.as_ref() {
FieldType::Enum(name) => FieldType::Enum(name.clone()),
_ => FieldType::string(),
},
_ => FieldType::string(),
}),
Box::new(item_type.clone()),
);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use baml_types::TypeValue;
use std::collections::VecDeque;

use baml_types::{LiteralValue, TypeValue};
use either::Either;
use internal_baml_diagnostics::{DatamodelError, DatamodelWarning, Span};
use internal_baml_schema_ast::ast::{
Argument, Attribute, Expression, FieldArity, FieldType, Identifier, WithName, WithSpan,
Expand Down Expand Up @@ -56,12 +59,53 @@ fn validate_type_allowed(ctx: &mut Context<'_>, field_type: &FieldType) {
field_type.span().clone(),
));
}

match &kv_types.0 {
// String key.
FieldType::Primitive(FieldArity::Required, TypeValue::String, ..) => {}
key_type => {
ctx.push_error(DatamodelError::new_validation_error(
"Maps may only have strings as keys",
key_type.span().clone(),

// Enum key.
FieldType::Symbol(FieldArity::Required, identifier, _)
if ctx
.db
.find_type(identifier)
.is_some_and(|t| matches!(t, Either::Right(_))) => {}

// Literal string key.
FieldType::Literal(FieldArity::Required, LiteralValue::String(_), ..) => {}

// Literal string union.
FieldType::Union(FieldArity::Required, items, ..) => {
let mut queue = VecDeque::from_iter(items.iter());

while let Some(item) = queue.pop_front() {
match item {
// Ok, literal string.
FieldType::Literal(
FieldArity::Required,
LiteralValue::String(_),
..,
) => {}

// Nested union, "recurse" but it's iterative.
FieldType::Union(FieldArity::Required, nested, ..) => {
queue.extend(nested.iter());
}

other => {
ctx.push_error(
DatamodelError::new_type_not_allowed_as_map_key_error(
other.span().clone(),
),
);
}
}
}
}

other => {
ctx.push_error(DatamodelError::new_type_not_allowed_as_map_key_error(
other.span().clone(),
));
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
enum MapKey {
A
B
C
}

class Fields {
e map<MapKey, string>
l1 map<"literal", string>
l2 map<"one" | "two" | ("three" | "four"), string>
}

function InOutEnumKey(i1: map<MapKey, string>, i2: map<MapKey, string>) -> map<MapKey, string> {
client "openai/gpt-4o"
prompt #"
Merge these: {{i1}} {{i2}}

{{ ctx.output_format }}
"#
}

function InOutLiteralStringUnionMapKey(
i1: map<"one" | "two" | ("three" | "four"), string>,
i2: map<"one" | "two" | ("three" | "four"), string>
) -> map<"one" | "two" | ("three" | "four"), string> {
client "openai/gpt-4o"
prompt #"
Merge these:

{{i1}}

{{i2}}

{{ ctx.output_format }}
"#
}
22 changes: 14 additions & 8 deletions engine/baml-lib/baml/tests/validation_files/class/map_types.baml
Original file line number Diff line number Diff line change
Expand Up @@ -31,49 +31,55 @@ function InputAndOutput(i1: map<string, string>, i2: map<MapDummy, string>) -> m
"#
}

// error: Error validating: Maps may only have strings as keys
// error: Error validating: Maps may only have strings, enums or literal strings as keys
// --> class/map_types.baml:16
// |
// 15 |
// 16 | b1 map<int, string>
// |
// error: Error validating: Maps may only have strings as keys
// error: Error validating: Maps may only have strings, enums or literal strings as keys
// --> class/map_types.baml:17
// |
// 16 | b1 map<int, string>
// 17 | b2 map<float, string>
// |
// error: Error validating: Maps may only have strings as keys
// error: Error validating: Maps may only have strings, enums or literal strings as keys
// --> class/map_types.baml:18
// |
// 17 | b2 map<float, string>
// 18 | b3 map<MapDummy, string>
// |
// error: Error validating: Maps may only have strings as keys
// error: Error validating: Maps may only have strings, enums or literal strings as keys
// --> class/map_types.baml:19
// |
// 18 | b3 map<MapDummy, string>
// 19 | b4 map<string?, string>
// |
// error: Error validating: Maps may only have strings as keys
// error: Error validating: Maps may only have strings, enums or literal strings as keys
// --> class/map_types.baml:20
// |
// 19 | b4 map<string?, string>
// 20 | b5 map<string | int, string>
// |
// error: Error validating: Maps may only have strings as keys
// error: Error validating: Maps may only have strings, enums or literal strings as keys
// --> class/map_types.baml:20
// |
// 19 | b4 map<string?, string>
// 20 | b5 map<string | int, string>
// |
// error: Error validating: Maps may only have strings, enums or literal strings as keys
// --> class/map_types.baml:23
// |
// 22 | c1 string | map<string, string>
// 23 | c2 string | map<int, string>
// |
// error: Error validating: Maps may only have strings as keys
// error: Error validating: Maps may only have strings, enums or literal strings as keys
// --> class/map_types.baml:24
// |
// 23 | c2 string | map<int, string>
// 24 | c3 string | map<string?, string>
// |
// error: Error validating: Maps may only have strings as keys
// error: Error validating: Maps may only have strings, enums or literal strings as keys
// --> class/map_types.baml:27
// |
// 26 |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class ComplexTypes {
// 2 | class ComplexTypes {
// 3 | a map<string[], (int | bool[]) | apple_pie[][]>
// |
// error: Error validating: Maps may only have strings as keys
// error: Error validating: Maps may only have strings, enums or literal strings as keys
// --> class/secure_types.baml:3
// |
// 2 | class ComplexTypes {
Expand All @@ -43,7 +43,7 @@ class ComplexTypes {
// 3 | a map<string[], (int | bool[]) | apple_pie[][]>
// 4 | b (int, map<bool, string?>, (char | float)[][] | long_word_123.foobar[])
// |
// error: Error validating: Maps may only have strings as keys
// error: Error validating: Maps may only have strings, enums or literal strings as keys
// --> class/secure_types.baml:4
// |
// 3 | a map<string[], (int | bool[]) | apple_pie[][]>
Expand Down Expand Up @@ -73,7 +73,7 @@ class ComplexTypes {
// 5 | c apple123_456_pie | (stringer, bool[], (int | char))[]
// 6 | d map<int[][], ((int | float) | char[])>
// |
// error: Error validating: Maps may only have strings as keys
// error: Error validating: Maps may only have strings, enums or literal strings as keys
// --> class/secure_types.baml:6
// |
// 5 | c apple123_456_pie | (stringer, bool[], (int | char))[]
Expand Down Expand Up @@ -121,7 +121,7 @@ class ComplexTypes {
// 9 | g (int, (float, char, bool), string[]) | tuple_inside_tuple[]
// 10 | h (((int | string)[]) | map<bool[][], char[]>)
// |
// error: Error validating: Maps may only have strings as keys
// error: Error validating: Maps may only have strings, enums or literal strings as keys
// --> class/secure_types.baml:10
// |
// 9 | g (int, (float, char, bool), string[]) | tuple_inside_tuple[]
Expand Down Expand Up @@ -181,13 +181,13 @@ class ComplexTypes {
// 12 | j ((char, int[][], (bool | string[][])) | double[][][][], (float, int)[])
// 13 | k map<string[], (int | long[])> | map<float[][], double[][]>
// |
// error: Error validating: Maps may only have strings as keys
// error: Error validating: Maps may only have strings, enums or literal strings as keys
// --> class/secure_types.baml:13
// |
// 12 | j ((char, int[][], (bool | string[][])) | double[][][][], (float, int)[])
// 13 | k map<string[], (int | long[])> | map<float[][], double[][]>
// |
// error: Error validating: Maps may only have strings as keys
// error: Error validating: Maps may only have strings, enums or literal strings as keys
// --> class/secure_types.baml:13
// |
// 12 | j ((char, int[][], (bool | string[][])) | double[][][][], (float, int)[])
Expand Down Expand Up @@ -247,13 +247,13 @@ class ComplexTypes {
// 15 | m (tuple_1, tuple_2 | tuple_3, (tuple_4, tuple_5))[]
// 16 | n map<complex_key_type[], map<another_key, (int | string[])>>
// |
// error: Error validating: Maps may only have strings as keys
// error: Error validating: Maps may only have strings, enums or literal strings as keys
// --> class/secure_types.baml:16
// |
// 15 | m (tuple_1, tuple_2 | tuple_3, (tuple_4, tuple_5))[]
// 16 | n map<complex_key_type[], map<another_key, (int | string[])>>
// |
// error: Error validating: Maps may only have strings as keys
// error: Error validating: Maps may only have strings, enums or literal strings as keys
// --> class/secure_types.baml:16
// |
// 15 | m (tuple_1, tuple_2 | tuple_3, (tuple_4, tuple_5))[]
Expand Down
7 changes: 7 additions & 0 deletions engine/baml-lib/diagnostics/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,13 @@ impl DatamodelError {
Self::new(msg, span)
}

pub fn new_type_not_allowed_as_map_key_error(span: Span) -> DatamodelError {
Self::new_validation_error(
"Maps may only have strings, enums or literal strings as keys",
span,
)
}

pub fn span(&self) -> &Span {
&self.span
}
Expand Down
83 changes: 70 additions & 13 deletions engine/baml-lib/jsonish/src/deserializer/coercer/coerce_map.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
use std::collections::VecDeque;

use anyhow::Result;

use crate::deserializer::{
deserialize_flags::{DeserializerConditions, Flag},
types::BamlValueWithFlags,
use crate::{
deserializer::{
deserialize_flags::{DeserializerConditions, Flag},
types::BamlValueWithFlags,
},
jsonish,
};
use baml_types::{BamlMap, FieldType, TypeValue};
use baml_types::{BamlMap, FieldType, LiteralValue, TypeValue};

use super::{ParsingContext, ParsingError, TypeCoercer};

pub(super) fn coerce_map(
ctx: &ParsingContext,
map_target: &FieldType,
value: Option<&crate::jsonish::Value>,
value: Option<&jsonish::Value>,
) -> Result<BamlValueWithFlags, ParsingError> {
log::debug!(
"scope: {scope} :: coercing to: {name} (current: {current})",
Expand All @@ -28,22 +33,74 @@ pub(super) fn coerce_map(
return Err(ctx.error_unexpected_type(map_target, value));
};

if !matches!(**key_type, FieldType::Primitive(TypeValue::String)) {
return Err(ctx.error_map_must_have_string_key(key_type));
// TODO: Do we actually need to check the key type here in the coercion
// logic? Can the user pass a "type" here at runtime? Can we pass the wrong
// type from our own code or is this guaranteed to be a valid map key type?
// If we can determine that the type is always valid then we can get rid of
// this logic and skip the loops & allocs in the the union branch.
match key_type.as_ref() {
// String, enum or just one literal string, OK.
FieldType::Primitive(TypeValue::String)
| FieldType::Enum(_)
| FieldType::Literal(LiteralValue::String(_)) => {}

// For unions we need to check if all the items are literal strings.
FieldType::Union(items) => {
let mut queue = VecDeque::from_iter(items.iter());
while let Some(item) = queue.pop_front() {
match item {
FieldType::Literal(LiteralValue::String(_)) => continue,
FieldType::Union(nested) => queue.extend(nested.iter()),
other => return Err(ctx.error_map_must_have_supported_key(other)),
}
}
}

// Key type not allowed.
other => return Err(ctx.error_map_must_have_supported_key(other)),
}

let mut flags = DeserializerConditions::new();
flags.add_flag(Flag::ObjectToMap(value.clone()));

match &value {
crate::jsonish::Value::Object(obj) => {
jsonish::Value::Object(obj) => {
let mut items = BamlMap::new();
for (key, value) in obj.iter() {
match value_type.coerce(&ctx.enter_scope(key), value_type, Some(value)) {
Ok(v) => {
items.insert(key.clone(), (DeserializerConditions::new(), v));
for (idx, (key, value)) in obj.iter().enumerate() {
let coerced_value =
match value_type.coerce(&ctx.enter_scope(key), value_type, Some(value)) {
Ok(v) => v,
Err(e) => {
flags.add_flag(Flag::MapValueParseError(key.clone(), e));
// Could not coerce value, nothing else to do here.
continue;
}
};

// Keys are just strings but since we suport enums and literals
// we have to check that the key we are reading is actually a
// valid enum member or expected literal value. The coercion
// logic already does that so we'll just coerce the key.
//
// TODO: Is it necessary to check that values match here? This
// is also checked at `coerce_arg` in
// baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs
let key_as_jsonish = jsonish::Value::String(key.to_owned());
match key_type.coerce(ctx, &key_type, Some(&key_as_jsonish)) {
Ok(_) => {
// Hack to avoid cloning the key twice.
let jsonish::Value::String(owned_key) = key_as_jsonish else {
unreachable!("key_as_jsonish is defined as jsonish::Value::String");
};

// Both the value and the key were successfully
// coerced, add the key to the map.
items.insert(owned_key, (DeserializerConditions::new(), coerced_value));
}
Err(e) => flags.add_flag(Flag::MapValueParseError(key.clone(), e)),
// Couldn't coerce key, this is either not a valid enum
// variant or it doesn't match any of the literal values
// expected.
Err(e) => flags.add_flag(Flag::MapKeyParseError(idx, e)),
}
}
Ok(BamlValueWithFlags::Map(flags, items))
Expand Down
6 changes: 4 additions & 2 deletions engine/baml-lib/jsonish/src/deserializer/coercer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,11 @@ impl ParsingContext<'_> {
}
}

pub(crate) fn error_map_must_have_string_key(&self, key_type: &FieldType) -> ParsingError {
pub(crate) fn error_map_must_have_supported_key(&self, key_type: &FieldType) -> ParsingError {
ParsingError {
reason: format!("Maps may only have strings for keys, but got {}", key_type),
reason: format!(
"Maps may only have strings, enums or literal strings for keys, but got {key_type}"
),
scope: self.scope.clone(),
causes: vec![],
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub(super) fn visit_constraint_attributes(
ctx.push_error(DatamodelError::new_attribute_validation_error(
"Internal error - the parser should have ruled out other attribute names.",
other_name,
span
span,
));
return ();
}
Expand Down
Loading

0 comments on commit 39e0271

Please sign in to comment.