Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: handle sql.NullTime parameters #195

Merged
merged 3 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions alwaysencrypted_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ type aeColumnInfo struct {
sampleValue interface{}
}

type customValuer struct {
}

func (n customValuer) Value() (driver.Value, error) {
return nil, nil
}

func TestAlwaysEncryptedE2E(t *testing.T) {
params := testConnParams(t)
if !params.ColumnEncryption {
Expand All @@ -53,7 +60,11 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
{"int", "INT", ColumnEncryptionDeterministic, int32(1)},
{"nchar(10) COLLATE Latin1_General_BIN2", "NCHAR", ColumnEncryptionDeterministic, NChar("ncharval")},
{"tinyint", "TINYINT", ColumnEncryptionRandomized, byte(2)},
{"tinyint", "TINYINT", ColumnEncryptionDeterministic, sql.NullByte{Valid: false}},
{"tinyint", "TINYINT", ColumnEncryptionDeterministic, sql.NullByte{Valid: true, Byte: 1}},
{"smallint", "SMALLINT", ColumnEncryptionDeterministic, int16(-3)},
{"smallint", "SMALLINT", ColumnEncryptionRandomized, sql.NullInt16{Valid: false}},
{"smallint", "SMALLINT", ColumnEncryptionDeterministic, sql.NullInt16{Valid: true, Int16: 32000}},
{"bigint", "BIGINT", ColumnEncryptionRandomized, int64(4)},
// We can't use fractional float/real values due to rounding errors in the round trip
{"real", "REAL", ColumnEncryptionDeterministic, float32(5)},
Expand All @@ -67,9 +78,13 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
{"datetime2(7)", "DATETIME2", ColumnEncryptionDeterministic, civil.DateTimeOf(dt)},
{"nvarchar(max)", "NVARCHAR", ColumnEncryptionRandomized, NVarCharMax("nvarcharmaxval")},
{"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: false}},
{"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: true, Int32: -75000}},
{"bigint", "BIGINT", ColumnEncryptionDeterministic, sql.NullInt64{Int64: 128, Valid: true}},
{"bigint", "BIGINT", ColumnEncryptionRandomized, sql.NullInt64{Valid: false}},
{"uniqueidentifier", "UNIQUEIDENTIFIER", ColumnEncryptionRandomized, UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}},
{"uniqueidentifier", "UNIQUEIDENTIFIER", ColumnEncryptionRandomized, NullUniqueIdentifier{Valid: false}},
{"datetimeoffset(7)", "DATETIMEOFFSET", ColumnEncryptionDeterministic, sql.NullTime{Valid: false}},
{"datetimeoffset(7)", "DATETIMEOFFSET", ColumnEncryptionDeterministic, sql.NullTime{Valid: true, Time: time.Now()}},
}
for _, test := range providerTests {
// turn off key caching
Expand Down Expand Up @@ -108,7 +123,7 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
_, _ = query.WriteString(fmt.Sprintf("CREATE TABLE [%s] (", tableName))
_, _ = insert.WriteString(fmt.Sprintf("INSERT INTO [%s] VALUES (", tableName))
_, _ = sel.WriteString("select top(1) ")
insertArgs := make([]interface{}, len(encryptableColumns)+1)
insertArgs := make([]interface{}, len(encryptableColumns)+2)
for i, ec := range encryptableColumns {
encType := "RANDOMIZED"
null := ""
Expand All @@ -128,11 +143,13 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
insert.WriteString(fmt.Sprintf("@p%d,", i+1))
sel.WriteString(fmt.Sprintf("col%d,", i))
}
_, _ = query.WriteString("unencryptedcolumn nvarchar(100)")
_, _ = query.WriteString("unencryptedcolumn nvarchar(100),")
_, _ = query.WriteString("nullableCustomValuer int NULL")
_, _ = query.WriteString(")")
insertArgs[len(encryptableColumns)] = "unencryptedvalue"
insert.WriteString(fmt.Sprintf("@p%d)", len(encryptableColumns)+1))
sel.WriteString(fmt.Sprintf("unencryptedcolumn from [%s]", tableName))
insertArgs[len(encryptableColumns)+1] = customValuer{}
insert.WriteString(fmt.Sprintf("@p%d,@p%d)", len(encryptableColumns)+1, len(encryptableColumns)+2))
sel.WriteString(fmt.Sprintf("unencryptedcolumn, nullableCustomValuer from [%s]", tableName))
_, err = conn.Exec(query.String())
assert.NoError(t, err, "Failed to create encrypted table")
defer func() { _, _ = conn.Exec("DROP TABLE IF EXISTS " + tableName) }()
Expand All @@ -152,13 +169,15 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
}

var unencryptedColumnValue string
scanValues := make([]interface{}, len(encryptableColumns)+1)
var nullint sql.NullInt32
scanValues := make([]interface{}, len(encryptableColumns)+2)
for v := range scanValues {
if v < len(encryptableColumns) {
scanValues[v] = new(interface{})
}
}
scanValues[len(encryptableColumns)] = &unencryptedColumnValue
scanValues[len(encryptableColumns)+1] = &nullint
err = rows.Scan(scanValues...)
defer rows.Close()
if err != nil {
Expand All @@ -182,6 +201,7 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
assert.Equalf(t, expectedStrVal, strVal, "Incorrect value for col%d. ", i)
}
assert.Equalf(t, "unencryptedvalue", unencryptedColumnValue, "Got wrong value for unencrypted column")
assert.False(t, nullint.Valid, "custom valuer should have null value")
_ = rows.Next()
err = rows.Err()
assert.NoError(t, err, "rows.Err() has non-nil values")
Expand Down
36 changes: 36 additions & 0 deletions mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,19 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
return
}
switch valuer := val.(type) {
// sql.Nullxxx integer types return an int64. We want the original type, to match the SQL type size.
case sql.NullByte:
if valuer.Valid {
return s.makeParam(valuer.Byte)
}
case sql.NullInt16:
if valuer.Valid {
return s.makeParam(valuer.Int16)
}
case sql.NullInt32:
if valuer.Valid {
return s.makeParam(valuer.Int32)
}
case UniqueIdentifier:
case NullUniqueIdentifier:
default:
Expand Down Expand Up @@ -1052,9 +1065,20 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
res.ti.Size = 8
res.buffer = []byte{}
case sql.NullInt32:
// only null values should be getting here
res.ti.TypeId = typeIntN
res.ti.Size = 4
res.buffer = []byte{}
case sql.NullInt16:
// only null values should be getting here
res.buffer = []byte{}
res.ti.Size = 2
res.ti.TypeId = typeIntN
case sql.NullByte:
// only null values should be getting here
res.buffer = []byte{}
res.ti.Size = 1
res.ti.TypeId = typeIntN
case byte:
res.ti.TypeId = typeIntN
res.buffer = []byte{val}
Expand Down Expand Up @@ -1110,6 +1134,18 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
res.buffer = encodeDateTime(val)
res.ti.Size = len(res.buffer)
}
case sql.NullTime: // only null values reach here
res.buffer = []byte{}
res.ti.Size = 8
if s.c.sess.loginAck.TDSVersion >= verTDS73 {
res.ti.TypeId = typeDateTimeOffsetN
res.ti.Scale = 7
} else {
res.ti.TypeId = typeDateTimeN
}
case driver.Valuer:
// We have a custom Valuer implementation with a nil value
return s.makeParam(nil)
default:
return s.makeParamExtra(val)
}
Expand Down
Loading