diff --git a/merlin/systems/triton/conversions.py b/merlin/systems/triton/conversions.py index cb61d4983..4a6b57f87 100644 --- a/merlin/systems/triton/conversions.py +++ b/merlin/systems/triton/conversions.py @@ -83,17 +83,23 @@ def match_representations(schema: Schema, dict_array: Dict[str, Any]) -> Dict[st Dict[str, Any] A dictionary of NumPy or CuPy ndarrays with representations adjusted """ - schema_names = tensor_names(schema) - aligned = {} - for tensor_name in dict_array.keys(): - if tensor_name in schema_names: - aligned[tensor_name] = dict_array[tensor_name] + for col_name, col_schema in schema.column_schemas.items(): + if col_schema.is_ragged: + vals_name = f"{col_name}__values" + offs_name = f"{col_name}__offsets" + + try: + # Look for values and offsets that already exist + aligned[vals_name] = dict_array[vals_name] + aligned[offs_name] = dict_array[offs_name] + except KeyError: + # If you don't find them, create the offsets + values, offsets = _to_values_offsets(dict_array[col_name]) + aligned[vals_name] = values + aligned[offs_name] = offsets else: - # Ragged columns with fixed shape values - values, offsets = _to_values_offsets(dict_array[tensor_name]) - aligned[f"{tensor_name}__values"] = values - aligned[f"{tensor_name}__offsets"] = offsets + aligned[col_name] = dict_array[col_name] return aligned