diff --git a/alwaysencrypted_test.go b/alwaysencrypted_test.go index 9dd08b59..518055b7 100644 --- a/alwaysencrypted_test.go +++ b/alwaysencrypted_test.go @@ -41,6 +41,13 @@ type aeColumnInfo struct { sampleValue interface{} } +type customValuer struct { +} + +func (n customValuer) Value() (driver.Value, error) { + return nil, nil +} + func TestAlwaysEncryptedE2E(t *testing.T) { params := testConnParams(t) if !params.ColumnEncryption { @@ -53,7 +60,11 @@ func TestAlwaysEncryptedE2E(t *testing.T) { {"int", "INT", ColumnEncryptionDeterministic, int32(1)}, {"nchar(10) COLLATE Latin1_General_BIN2", "NCHAR", ColumnEncryptionDeterministic, NChar("ncharval")}, {"tinyint", "TINYINT", ColumnEncryptionRandomized, byte(2)}, + {"tinyint", "TINYINT", ColumnEncryptionDeterministic, sql.NullByte{Valid: false}}, + {"tinyint", "TINYINT", ColumnEncryptionDeterministic, sql.NullByte{Valid: true, Byte: 1}}, {"smallint", "SMALLINT", ColumnEncryptionDeterministic, int16(-3)}, + {"smallint", "SMALLINT", ColumnEncryptionRandomized, sql.NullInt16{Valid: false}}, + {"smallint", "SMALLINT", ColumnEncryptionDeterministic, sql.NullInt16{Valid: true, Int16: 32000}}, {"bigint", "BIGINT", ColumnEncryptionRandomized, int64(4)}, // We can't use fractional float/real values due to rounding errors in the round trip {"real", "REAL", ColumnEncryptionDeterministic, float32(5)}, @@ -67,9 +78,13 @@ func TestAlwaysEncryptedE2E(t *testing.T) { {"datetime2(7)", "DATETIME2", ColumnEncryptionDeterministic, civil.DateTimeOf(dt)}, {"nvarchar(max)", "NVARCHAR", ColumnEncryptionRandomized, NVarCharMax("nvarcharmaxval")}, {"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: false}}, + {"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: true, Int32: -75000}}, {"bigint", "BIGINT", ColumnEncryptionDeterministic, sql.NullInt64{Int64: 128, Valid: true}}, + {"bigint", "BIGINT", ColumnEncryptionRandomized, sql.NullInt64{Valid: false}}, {"uniqueidentifier", "UNIQUEIDENTIFIER", ColumnEncryptionRandomized, UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}}, {"uniqueidentifier", "UNIQUEIDENTIFIER", ColumnEncryptionRandomized, NullUniqueIdentifier{Valid: false}}, + {"datetimeoffset(7)", "DATETIMEOFFSET", ColumnEncryptionDeterministic, sql.NullTime{Valid: false}}, + {"datetimeoffset(7)", "DATETIMEOFFSET", ColumnEncryptionDeterministic, sql.NullTime{Valid: true, Time: time.Now()}}, } for _, test := range providerTests { // turn off key caching @@ -108,7 +123,7 @@ func TestAlwaysEncryptedE2E(t *testing.T) { _, _ = query.WriteString(fmt.Sprintf("CREATE TABLE [%s] (", tableName)) _, _ = insert.WriteString(fmt.Sprintf("INSERT INTO [%s] VALUES (", tableName)) _, _ = sel.WriteString("select top(1) ") - insertArgs := make([]interface{}, len(encryptableColumns)+1) + insertArgs := make([]interface{}, len(encryptableColumns)+2) for i, ec := range encryptableColumns { encType := "RANDOMIZED" null := "" @@ -128,11 +143,13 @@ func TestAlwaysEncryptedE2E(t *testing.T) { insert.WriteString(fmt.Sprintf("@p%d,", i+1)) sel.WriteString(fmt.Sprintf("col%d,", i)) } - _, _ = query.WriteString("unencryptedcolumn nvarchar(100)") + _, _ = query.WriteString("unencryptedcolumn nvarchar(100),") + _, _ = query.WriteString("nullableCustomValuer int NULL") _, _ = query.WriteString(")") insertArgs[len(encryptableColumns)] = "unencryptedvalue" - insert.WriteString(fmt.Sprintf("@p%d)", len(encryptableColumns)+1)) - sel.WriteString(fmt.Sprintf("unencryptedcolumn from [%s]", tableName)) + insertArgs[len(encryptableColumns)+1] = customValuer{} + insert.WriteString(fmt.Sprintf("@p%d,@p%d)", len(encryptableColumns)+1, len(encryptableColumns)+2)) + sel.WriteString(fmt.Sprintf("unencryptedcolumn, nullableCustomValuer from [%s]", tableName)) _, err = conn.Exec(query.String()) assert.NoError(t, err, "Failed to create encrypted table") defer func() { _, _ = conn.Exec("DROP TABLE IF EXISTS " + tableName) }() @@ -152,13 +169,15 @@ func TestAlwaysEncryptedE2E(t *testing.T) { } var unencryptedColumnValue string - scanValues := make([]interface{}, len(encryptableColumns)+1) + var nullint sql.NullInt32 + scanValues := make([]interface{}, len(encryptableColumns)+2) for v := range scanValues { if v < len(encryptableColumns) { scanValues[v] = new(interface{}) } } scanValues[len(encryptableColumns)] = &unencryptedColumnValue + scanValues[len(encryptableColumns)+1] = &nullint err = rows.Scan(scanValues...) defer rows.Close() if err != nil { @@ -182,6 +201,7 @@ func TestAlwaysEncryptedE2E(t *testing.T) { assert.Equalf(t, expectedStrVal, strVal, "Incorrect value for col%d. ", i) } assert.Equalf(t, "unencryptedvalue", unencryptedColumnValue, "Got wrong value for unencrypted column") + assert.False(t, nullint.Valid, "custom valuer should have null value") _ = rows.Next() err = rows.Err() assert.NoError(t, err, "rows.Err() has non-nil values") diff --git a/mssql.go b/mssql.go index f86d5361..2d940ddf 100644 --- a/mssql.go +++ b/mssql.go @@ -983,6 +983,19 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) { return } switch valuer := val.(type) { + // sql.Nullxxx integer types return an int64. We want the original type, to match the SQL type size. + case sql.NullByte: + if valuer.Valid { + return s.makeParam(valuer.Byte) + } + case sql.NullInt16: + if valuer.Valid { + return s.makeParam(valuer.Int16) + } + case sql.NullInt32: + if valuer.Valid { + return s.makeParam(valuer.Int32) + } case UniqueIdentifier: case NullUniqueIdentifier: default: @@ -1052,9 +1065,20 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) { res.ti.Size = 8 res.buffer = []byte{} case sql.NullInt32: + // only null values should be getting here res.ti.TypeId = typeIntN res.ti.Size = 4 res.buffer = []byte{} + case sql.NullInt16: + // only null values should be getting here + res.buffer = []byte{} + res.ti.Size = 2 + res.ti.TypeId = typeIntN + case sql.NullByte: + // only null values should be getting here + res.buffer = []byte{} + res.ti.Size = 1 + res.ti.TypeId = typeIntN case byte: res.ti.TypeId = typeIntN res.buffer = []byte{val} @@ -1110,6 +1134,18 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) { res.buffer = encodeDateTime(val) res.ti.Size = len(res.buffer) } + case sql.NullTime: // only null values reach here + res.buffer = []byte{} + res.ti.Size = 8 + if s.c.sess.loginAck.TDSVersion >= verTDS73 { + res.ti.TypeId = typeDateTimeOffsetN + res.ti.Scale = 7 + } else { + res.ti.TypeId = typeDateTimeN + } + case driver.Valuer: + // We have a custom Valuer implementation with a nil value + return s.makeParam(nil) default: return s.makeParamExtra(val) }