Skip to content

Commit

Permalink
Fix Spark parameter creation when passing a nil-value named paramet…
Browse files Browse the repository at this point in the history
…er to a query (#199)

- When a `sql.NamedValue` has the field `Value` set to `nil`, the
resulting `cli_service.TSparkParameter` will also have the value `nil`
instead of `*cli_service.TSparkParameterValue{StringValue:
*"%!s(<nil>"}`.
- Add the type `SqlVoid`, following the conventions used in the [NodeJS
connector](https://github.com/databricks/databricks-sql-nodejs/blob/main/lib/DBSQLParameter.ts#L43-L51)
and the [Python
driver](https://github.com/databricks/databricks-sql-python/blob/f6fd7a7956a4dbc78ad36b5e079fe8d74176a0f1/src/databricks/sql/parameters/native.py#L319-L323).

Fix #193.

---------

Signed-off-by: Esdras Beleza <[email protected]>
Signed-off-by: Levko Kravets <[email protected]>
Signed-off-by: candiduslynx <[email protected]>
Co-authored-by: Levko Kravets <[email protected]>
Co-authored-by: Mahdi Dibaiee <[email protected]>
Co-authored-by: Alex Shcherbakov <[email protected]>
  • Loading branch information
4 people authored Apr 16, 2024
1 parent 4f9a1a1 commit aeb5e5d
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 6 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Release History

- Fix formatting of *float64 parameters
- Bug fix for issue 193: convertNamedValuesToSparkParams was incorrectly creating a Spark parameter value as "%!s(<nil>)" when a named param was nil (databricks/databricks-sql-go#199 by @esdrasbeleza)
- Fix formatting of *float64 parameters (databricks/databricks-sql-go#215 by @esdrasbeleza)

## v1.5.4 (2024-04-10)

Expand Down
7 changes: 5 additions & 2 deletions parameter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ import (

func TestParameter_Inference(t *testing.T) {
t.Run("Should infer types correctly", func(t *testing.T) {
values := [6]driver.NamedValue{
values := [7]driver.NamedValue{
{Name: "", Value: float32(5.1)},
{Name: "", Value: time.Now()},
{Name: "", Value: int64(5)},
{Name: "", Value: true},
{Name: "", Value: Parameter{Value: "6.2", Type: SqlDecimal}},
{Name: "", Value: nil},
{Name: "", Value: Parameter{Value: float64Ptr(6.2), Type: SqlUnkown}},
}
parameters := convertNamedValuesToSparkParams(values[:])
Expand All @@ -28,7 +29,9 @@ func TestParameter_Inference(t *testing.T) {
assert.Equal(t, string("true"), *parameters[3].Value.StringValue)
assert.Equal(t, string("DECIMAL(2,1)"), *parameters[4].Type)
assert.Equal(t, string("6.2"), *parameters[4].Value.StringValue)
assert.Equal(t, &cli_service.TSparkParameterValue{StringValue: strPtr("6.2")}, parameters[5].Value)
assert.Equal(t, string("VOID"), *parameters[5].Type)
assert.Nil(t, parameters[5].Value)
assert.Equal(t, &cli_service.TSparkParameterValue{StringValue: strPtr("6.2")}, parameters[6].Value)
})
}
func TestParameters_Names(t *testing.T) {
Expand Down
19 changes: 16 additions & 3 deletions parameters.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ const (
SqlBoolean
SqlIntervalMonth
SqlIntervalDay
SqlVoid
)

func (s SqlType) String() string {
Expand Down Expand Up @@ -64,6 +65,8 @@ func (s SqlType) String() string {
return "INTERVAL MONTH"
case SqlIntervalDay:
return "INTERVAL DAY"
case SqlVoid:
return "VOID"
}
return "unknown"
}
Expand Down Expand Up @@ -149,6 +152,9 @@ func inferType(param *Parameter) {
case time.Time:
param.Value = value.Format(time.RFC3339Nano)
param.Type = SqlTimestamp
case nil:
param.Value = nil
param.Type = SqlVoid
default:
s := fmt.Sprintf("%s", param.Value)
param.Value = s
Expand All @@ -163,14 +169,21 @@ func convertNamedValuesToSparkParams(values []driver.NamedValue) []*cli_service.
inferTypes(sqlParams)
for i := range sqlParams {
sqlParam := sqlParams[i]
sparkParamValue := sqlParam.Value.(string)
sparkValue := new(cli_service.TSparkParameterValue)
if sqlParam.Type == SqlVoid {
sparkValue = nil
} else {
stringValue := sqlParam.Value.(string)
sparkValue = &cli_service.TSparkParameterValue{StringValue: &stringValue}
}

var sparkParamType string
if sqlParam.Type == SqlDecimal {
sparkParamType = inferDecimalType(sparkParamValue)
sparkParamType = inferDecimalType(sparkValue.GetStringValue())
} else {
sparkParamType = sqlParam.Type.String()
}
sparkParam := cli_service.TSparkParameter{Name: &sqlParam.Name, Type: &sparkParamType, Value: &cli_service.TSparkParameterValue{StringValue: &sparkParamValue}}
sparkParam := cli_service.TSparkParameter{Name: &sqlParam.Name, Type: &sparkParamType, Value: sparkValue}
sparkParams = append(sparkParams, &sparkParam)
}
return sparkParams
Expand Down

0 comments on commit aeb5e5d

Please sign in to comment.