Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

workload/schemachange: improve error screening #56379

Merged
18 changes: 17 additions & 1 deletion pkg/sql/rowenc/testutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,13 @@ func RandDatumWithNullChance(rng *rand.Rand, typ *types.T, nullChance int) tree.
return &tree.DBitArray{BitArray: r}
case types.StringFamily:
// Generate a random ASCII string.
p := make([]byte, rng.Intn(10))
var length int
if typ.Oid() == oid.T_char || typ.Oid() == oid.T_bpchar {
length = 1
} else {
length = rng.Intn(10)
}
p := make([]byte, length)
for i := range p {
p[i] = byte(1 + rng.Intn(127))
}
Expand Down Expand Up @@ -750,6 +756,16 @@ func randInterestingDatum(rng *rand.Rand, typ *types.T) tree.Datum {
default:
panic(errors.AssertionFailedf("float with an unexpected width %d", typ.Width()))
}
case types.BitFamily:
// A width of 64 is used by all special BitFamily datums in randInterestingDatums.
// If the provided bit type, typ, has a width of 0 (representing an arbitrary width) or 64 exactly,
// then the special datum will be valid for the provided type. Otherwise, the special type
// must be resized to match the width of the provided type.
if typ.Width() == 0 || typ.Width() == 64 {
return special
}
return &tree.DBitArray{BitArray: special.(*tree.DBitArray).ToWidth(uint(typ.Width()))}

default:
return special
}
Expand Down
165 changes: 165 additions & 0 deletions pkg/workload/schemachange/error_screening.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ func viewExists(tx *pgx.Tx, tableName *tree.TableName) (bool, error) {
)`, tableName.Schema(), tableName.Object())
}

func sequenceExists(tx *pgx.Tx, seqName *tree.TableName) (bool, error) {
return scanBool(tx, `SELECT EXISTS (
SELECT sequence_name
FROM information_schema.sequences
WHERE sequence_schema = $1
AND sequence_name = $2
)`, seqName.Schema(), seqName.Object())
}

func columnExistsOnTable(tx *pgx.Tx, tableName *tree.TableName, columnName string) (bool, error) {
return scanBool(tx, `SELECT EXISTS (
SELECT column_name
Expand Down Expand Up @@ -156,3 +165,159 @@ func colIsPrimaryKey(tx *pgx.Tx, tableName *tree.TableName, columnName string) (
);
`, tableName.Schema(), tableName.Object(), columnName)
}

// valuesViolateUniqueConstraints determines if any unique constraints (including primary constraints)
// will be violated upon inserting the specified rows into the specified table.
func violatesUniqueConstraints(
tx *pgx.Tx, tableName *tree.TableName, columns []string, rows [][]string,
) (bool, error) {

if len(rows) == 0 {
return false, fmt.Errorf("violatesUniqueConstraints: no rows provided")
}

// Fetch unique constraints from the database. The format returned is an array of string arrays.
// Each string array is a group of column names for which a unique constraint exists.
constraints, err := scanStringArrayRows(tx, `
SELECT DISTINCT array_agg(cols.column_name ORDER BY cols.column_name)
FROM (
SELECT d.oid,
d.table_name,
d.schema_name,
conname,
contype,
unnest(conkey) AS position
FROM (
SELECT c.oid AS oid,
c.relname AS table_name,
ns.nspname AS schema_name
FROM pg_catalog.pg_class AS c
JOIN pg_catalog.pg_namespace AS ns ON
ns.oid = c.relnamespace
WHERE ns.nspname = $1
AND c.relname = $2
) AS d
JOIN (
SELECT conname, conkey, conrelid, contype
FROM pg_catalog.pg_constraint
WHERE contype = 'p' OR contype = 'u'
) ON conrelid = d.oid
) AS cons
JOIN (
SELECT table_name,
table_schema,
column_name,
ordinal_position
FROM information_schema.columns
) AS cols ON cons.schema_name = cols.table_schema
AND cols.table_name = cons.table_name
AND cols.ordinal_position = cons.position
GROUP BY cons.conname;
`, tableName.Schema(), tableName.Object())
if err != nil {
return false, err
}

for _, constraint := range constraints {
// previousRows is used to check unique constraints among the values which
// will be inserted into the database.
previousRows := map[string]bool{}
for _, row := range rows {
violation, err := violatesUniqueConstraintsHelper(tx, tableName, columns, constraint, row, previousRows)
if err != nil {
return false, err
}
if violation {
return true, nil
}
}
}

return false, nil
}

func violatesUniqueConstraintsHelper(
tx *pgx.Tx,
tableName *tree.TableName,
columns []string,
constraint []string,
row []string,
previousRows map[string]bool,
) (bool, error) {

// Put values to be inserted into a column name to value map to simplify lookups.
columnsToValues := map[string]string{}
for i := 0; i < len(columns); i++ {
columnsToValues[columns[i]] = row[i]
}

query := strings.Builder{}
query.WriteString(fmt.Sprintf(`SELECT EXISTS (
SELECT *
FROM %s
WHERE
`, tableName.String()))

atLeastOneNonNullValue := false
for _, column := range constraint {

// Null values are not checked because unique constraints do not apply to null values.
if columnsToValues[column] != "NULL" {
if atLeastOneNonNullValue {
query.WriteString(fmt.Sprintf(` AND %s = %s`, column, columnsToValues[column]))
} else {
query.WriteString(fmt.Sprintf(`%s = %s`, column, columnsToValues[column]))
}

atLeastOneNonNullValue = true
}
}
query.WriteString(")")

// If there are only null values being inserted for each of the constrained columns,
// then checking for uniqueness against other rows is not necessary.
if !atLeastOneNonNullValue {
return false, nil
}

queryString := query.String()

// Check for uniqueness against other rows to be inserted. For simplicity, the `SELECT EXISTS`
// query used to check for uniqueness against rows in the database can also
// be used as a unique key to check for uniqueness among rows to be inserted.
if _, duplicateEntry := previousRows[queryString]; duplicateEntry {
return true, nil
}
previousRows[queryString] = true

// Check for uniqueness against rows in the database.
exists, err := scanBool(tx, queryString)
if err != nil {
return false, err
}
if exists {
return true, nil
}

return false, nil
}

func scanStringArrayRows(tx *pgx.Tx, query string, args ...interface{}) ([][]string, error) {
rows, err := tx.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()

results := [][]string{}
for rows.Next() {
var columnNames []string
err := rows.Scan(&columnNames)
if err != nil {
return nil, err
}
results = append(results, columnNames)
}

return results, err
}
Loading