Skip to content

Commit

Permalink
workload/schemachange: add error screening to insertRow op
Browse files Browse the repository at this point in the history
Release note: None
  • Loading branch information
jayshrivastava committed Nov 25, 2020
1 parent b3c4604 commit 0bce752
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 8 deletions.
18 changes: 17 additions & 1 deletion pkg/sql/rowenc/testutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,13 @@ func RandDatumWithNullChance(rng *rand.Rand, typ *types.T, nullChance int) tree.
return &tree.DBitArray{BitArray: r}
case types.StringFamily:
// Generate a random ASCII string.
p := make([]byte, rng.Intn(10))
var length int
if typ.Oid() == oid.T_char || typ.Oid() == oid.T_bpchar {
length = 1
} else {
length = rng.Intn(10)
}
p := make([]byte, length)
for i := range p {
p[i] = byte(1 + rng.Intn(127))
}
Expand Down Expand Up @@ -750,6 +756,16 @@ func randInterestingDatum(rng *rand.Rand, typ *types.T) tree.Datum {
default:
panic(errors.AssertionFailedf("float with an unexpected width %d", typ.Width()))
}
case types.BitFamily:
// A width of 64 is used by all special BitFamily datums in randInterestingDatums.
// If the provided bit type, typ, has a width of 0 (representing an arbitrary width) or 64 exactly,
// then the special datum will be valid for the provided type. Otherwise, the special type
// must be resized to match the width of the provided type.
if typ.Width() == 0 || typ.Width() == 64 {
return special
}
return &tree.DBitArray{BitArray: special.(*tree.DBitArray).ToWidth(uint(typ.Width()))}

default:
return special
}
Expand Down
156 changes: 156 additions & 0 deletions pkg/workload/schemachange/error_screening.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,159 @@ func colIsPrimaryKey(tx *pgx.Tx, tableName *tree.TableName, columnName string) (
);
`, tableName.Schema(), tableName.Object(), columnName)
}

// valuesViolateUniqueConstraints determines if any unique constraints (including primary constraints)
// will be violated upon inserting the specified rows into the specified table.
func violatesUniqueConstraints(
tx *pgx.Tx, tableName *tree.TableName, columns []string, rows [][]string,
) (bool, error) {

if len(rows) == 0 {
return false, fmt.Errorf("violatesUniqueConstraints: no rows provided")
}

// Fetch unique constraints from the database. The format returned is an array of string arrays.
// Each string array is a group of column names for which a unique constraint exists.
constraints, err := scanStringArrayRows(tx, `
SELECT DISTINCT array_agg(cols.column_name ORDER BY cols.column_name)
FROM (
SELECT d.oid,
d.table_name,
d.schema_name,
conname,
contype,
unnest(conkey) AS position
FROM (
SELECT c.oid AS oid,
c.relname AS table_name,
ns.nspname AS schema_name
FROM pg_catalog.pg_class AS c
JOIN pg_catalog.pg_namespace AS ns ON
ns.oid = c.relnamespace
WHERE ns.nspname = $1
AND c.relname = $2
) AS d
JOIN (
SELECT conname, conkey, conrelid, contype
FROM pg_catalog.pg_constraint
WHERE contype = 'p' OR contype = 'u'
) ON conrelid = d.oid
) AS cons
JOIN (
SELECT table_name,
table_schema,
column_name,
ordinal_position
FROM information_schema.columns
) AS cols ON cons.schema_name = cols.table_schema
AND cols.table_name = cons.table_name
AND cols.ordinal_position = cons.position
GROUP BY cons.conname;
`, tableName.Schema(), tableName.Object())
if err != nil {
return false, err
}

for _, constraint := range constraints {
// previousRows is used to check unique constraints among the values which
// will be inserted into the database.
previousRows := map[string]bool{}
for _, row := range rows {
violation, err := violatesUniqueConstraintsHelper(tx, tableName, columns, constraint, row, previousRows)
if err != nil {
return false, err
}
if violation {
return true, nil
}
}
}

return false, nil
}

func violatesUniqueConstraintsHelper(
tx *pgx.Tx,
tableName *tree.TableName,
columns []string,
constraint []string,
row []string,
previousRows map[string]bool,
) (bool, error) {

// Put values to be inserted into a column name to value map to simplify lookups.
columnsToValues := map[string]string{}
for i := 0; i < len(columns); i++ {
columnsToValues[columns[i]] = row[i]
}

query := strings.Builder{}
query.WriteString(fmt.Sprintf(`SELECT EXISTS (
SELECT *
FROM %s
WHERE
`, tableName.String()))

atLeastOneNonNullValue := false
for _, column := range constraint {

// Null values are not checked because unique constraints do not apply to null values.
if columnsToValues[column] != "NULL" {
if atLeastOneNonNullValue {
query.WriteString(fmt.Sprintf(` AND %s = %s`, column, columnsToValues[column]))
} else {
query.WriteString(fmt.Sprintf(`%s = %s`, column, columnsToValues[column]))
}

atLeastOneNonNullValue = true
}
}
query.WriteString(")")

// If there are only null values being inserted for each of the constrained columns,
// then checking for uniqueness against other rows is not necessary.
if !atLeastOneNonNullValue {
return false, nil
}

queryString := query.String()

// Check for uniqueness against other rows to be inserted. For simplicity, the `SELECT EXISTS`
// query used to check for uniqueness against rows in the database can also
// be used as a unique key to check for uniqueness among rows to be inserted.
if _, duplicateEntry := previousRows[queryString]; duplicateEntry {
return true, nil
}
previousRows[queryString] = true

// Check for uniqueness against rows in the database.
exists, err := scanBool(tx, queryString)
if err != nil {
return false, err
}
if exists {
return true, nil
}

return false, nil
}

func scanStringArrayRows(tx *pgx.Tx, query string, args ...interface{}) ([][]string, error) {
rows, err := tx.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()

results := [][]string{}
for rows.Next() {
var columnNames []string
err := rows.Scan(&columnNames)
if err != nil {
return nil, err
}
results = append(results, columnNames)
}

return results, err
}
46 changes: 39 additions & 7 deletions pkg/workload/schemachange/operation_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ var opsWithExecErrorScreening = map[opType]bool{
renameSequence: true,
renameTable: true,
renameView: true,

insertRow: true,
}

func opScreensForExecErrors(op opType) bool {
Expand Down Expand Up @@ -1205,29 +1207,57 @@ func (og *operationGenerator) insertRow(tx *pgx.Tx) (string, error) {
if err != nil {
return "", errors.Wrapf(err, "error getting random table name")
}
tableExists, err := tableExists(tx, tableName)
if err != nil {
return "", err
}
if !tableExists {
og.expectedExecErrors.add(pgcode.UndefinedTable)
return fmt.Sprintf(
`INSERT INTO %s (IrrelevantColumnName) VALUES ("IrrelevantValue")`,
tableName,
), nil
}
cols, err := og.getTableColumns(tx, tableName.String())
if err != nil {
return "", errors.Wrapf(err, "error getting table columns for insert row")
}
colNames := []string{}
rows := []string{}
rows := [][]string{}
for _, col := range cols {
colNames = append(colNames, fmt.Sprintf(`"%s"`, col.name))
colNames = append(colNames, col.name)
}
numRows := og.randIntn(10) + 1
numRows := og.randIntn(3) + 1
for i := 0; i < numRows; i++ {
var row []string
for _, col := range cols {
d := rowenc.RandDatum(og.params.rng, col.typ, col.nullable)
row = append(row, tree.AsStringWithFlags(d, tree.FmtParsable))
}
rows = append(rows, fmt.Sprintf("(%s)", strings.Join(row, ",")))

rows = append(rows, row)
}

// Verify if the new row will violate unique constraints by checking the constraints and
// existing rows in the database.
uniqueConstraintViolation, err := violatesUniqueConstraints(tx, tableName, colNames, rows)
if err != nil {
return "", err
}
if uniqueConstraintViolation {
og.expectedExecErrors.add(pgcode.UniqueViolation)
}

formattedRows := []string{}
for _, row := range rows {
formattedRows = append(formattedRows, fmt.Sprintf("(%s)", strings.Join(row, ",")))
}

return fmt.Sprintf(
`INSERT INTO %s (%s) VALUES %s`,
tableName,
strings.Join(colNames, ","),
strings.Join(rows, ","),
strings.Join(formattedRows, ","),
), nil
}

Expand Down Expand Up @@ -1287,8 +1317,10 @@ func (og *operationGenerator) getTableColumns(tx *pgx.Tx, tableName string) ([]c
if err != nil {
return nil, err
}
typNames = append(typNames, typName)
ret = append(ret, c)
if c.name != "rowid" {
typNames = append(typNames, typName)
ret = append(ret, c)
}
}
if err := rows.Err(); err != nil {
return nil, err
Expand Down

0 comments on commit 0bce752

Please sign in to comment.