Skip to content

Commit

Permalink
Fix column references in versioned schema code
Browse files Browse the repository at this point in the history
  • Loading branch information
simolus3 committed Oct 2, 2024
1 parent 7665813 commit be1a5ce
Show file tree
Hide file tree
Showing 29 changed files with 1,368 additions and 52 deletions.
9 changes: 9 additions & 0 deletions drift/lib/internal/versioned_schema.dart
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,15 @@ class VersionedTable extends Table with TableInfo<Table, QueryRow> {
VersionedTable createAlias(String alias) {
return VersionedTable.aliased(source: this, alias: alias);
}

/// Generates an expression referencing a column in the same table with the
/// given [name].
///
/// Intended for generated code.
static Expression<T> col<T extends Object>(String name) {
return CustomExpression(SqlDialect.sqlite.escape(name),
precedence: Precedence.primary);
}
}

/// The version of [VersionedTable] for virtual tables.
Expand Down
25 changes: 22 additions & 3 deletions drift_dev/lib/src/analysis/resolver/dart/column.dart
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,25 @@ const String _errorMessage = 'This getter does not create a valid column that '
class ColumnParser {
final DartTableResolver _resolver;

ColumnParser(this._resolver);
/// A map of elements to their name for elements defining columns.
///
/// This is used to recognize column references in arbitrary Dart code, e.g.
/// in this definition:
///
/// ```
/// DateTimeColumn get creationTime => dateTime()
/// .check(creationTime.isBiggerThan(Constant(DateTime(2020))))();
/// ```
///
/// Here, the check constraint references the column itself. In some code
/// generation modes where we generate code for individual columns (instead
/// of for entire table structures, this mainly includes step-by-step
/// migrations), there might not be a `creationTime` in scope for the check
/// constraint. So, we annotate these references in [AnnotatedDartCode] and
/// use that information when generating code to transform the code.
final Map<Element, String> _columnsInSameTable;

ColumnParser(this._resolver, this._columnsInSameTable);

Future<PendingColumnInformation?> parse(
ColumnDeclaration columnDeclaration, Element element) async {
Expand Down Expand Up @@ -343,8 +361,9 @@ class ColumnParser {
break;
case _methodCheck:
final expr = remainingExpr.argumentList.arguments.first;
foundConstraints
.add(DartCheckExpression(AnnotatedDartCode.ast(expr)));

foundConstraints.add(DartCheckExpression(AnnotatedDartCode.build(
(b) => b.addAstNode(expr, taggedElements: _columnsInSameTable))));
}

// We're not at a starting method yet, so we need to go deeper!
Expand Down
14 changes: 8 additions & 6 deletions drift_dev/lib/src/analysis/resolver/dart/table.dart
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,9 @@ class DartTableResolver extends LocalElementResolver<DiscoveredDartTable> {
element.lookUpInheritedConcreteGetter(name, element.library);
// ignore: deprecated_member_use
return getter!.variable;
});
}).toList();
final all = {for (final entry in fields) entry.getter ?? entry: entry.name};

final results = <PendingColumnInformation>[];
for (final field in fields) {
final ColumnDeclaration node;
Expand All @@ -317,14 +319,14 @@ class DartTableResolver extends LocalElementResolver<DiscoveredDartTable> {
.loadElementDeclaration(field.declaration)
as VariableDeclaration,
null);
column = await _parseColumn(node, field.declaration);
column = await _parseColumn(node, field.declaration, all);
} else {
node = ColumnDeclaration(
null,
await resolver.driver.backend.loadElementDeclaration(field.getter!)
as MethodDeclaration);

column = await _parseColumn(node, field.getter!);
column = await _parseColumn(node, field.getter!, all);
}

if (column != null) {
Expand All @@ -335,9 +337,9 @@ class DartTableResolver extends LocalElementResolver<DiscoveredDartTable> {
return results.whereType();
}

Future<PendingColumnInformation?> _parseColumn(
ColumnDeclaration declaration, Element element) async {
return ColumnParser(this).parse(declaration, element);
Future<PendingColumnInformation?> _parseColumn(ColumnDeclaration declaration,
Element element, Map<Element, String> allColumns) async {
return ColumnParser(this, allColumns).parse(declaration, element);
}

Future<List<String>> _readCustomConstraints(Set<DriftElement> references,
Expand Down
4 changes: 2 additions & 2 deletions drift_dev/lib/src/analysis/resolver/shared/data_class.dart
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,12 @@ CustomParentClass? parseCustomParentClass(
if (genericType.isDartCoreObject || genericType is DynamicType) {
code = AnnotatedDartCode([
DartTopLevelSymbol.topLevelElement(extendingType.element),
'<',
const DartLexeme('<'),
DartTopLevelSymbol(
dartTypeName ?? dataClassNameForClassName(element.name),
null,
),
'>',
const DartLexeme('>'),
]);
} else {
resolver.reportError(
Expand Down
96 changes: 80 additions & 16 deletions drift_dev/lib/src/analysis/results/dart.dart
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,11 @@ class AnnotatedDartCode {
static final Uri dartCore = Uri.parse('dart:core');
static final Uri drift = Uri.parse('package:drift/drift.dart');

final List<dynamic /* String|DartTopLevelSymbol */ > elements;
final List<DartCodeElement> elements;

AnnotatedDartCode(this.elements)
: assert(elements.every((e) => e is String || e is DartTopLevelSymbol));
AnnotatedDartCode(this.elements);

AnnotatedDartCode.text(String e) : elements = [e];
AnnotatedDartCode.text(String e) : elements = [DartLexeme(e)];

factory AnnotatedDartCode.ast(AstNode node) {
return AnnotatedDartCode.build(((builder) => builder.addAstNode(node)));
Expand All @@ -55,8 +54,7 @@ class AnnotatedDartCode {
final serializedElements = json['elements'] as List;

return AnnotatedDartCode([
for (final part in serializedElements)
if (part is Map) DartTopLevelSymbol.fromJson(part) else part as String
for (final part in serializedElements) DartCodeElement.fromJson(part)
]);
}

Expand All @@ -66,10 +64,7 @@ class AnnotatedDartCode {

Map<String, Object?> toJson() {
return {
'elements': [
for (final element in elements)
if (element is DartTopLevelSymbol) element.toJson() else element
],
'elements': [for (final element in elements) element.toJson()],
};
}

Expand All @@ -90,12 +85,12 @@ class AnnotatedDartCode {
}

class AnnotatedDartCodeBuilder {
final List<dynamic> _elements = [];
final List<DartCodeElement> _elements = [];
final StringBuffer _pendingText = StringBuffer();

void _addPendingText() {
if (_pendingText.isNotEmpty) {
_elements.add(_pendingText.toString());
_elements.add(DartLexeme(_pendingText.toString()));
_pendingText.clear();
}
}
Expand All @@ -122,12 +117,21 @@ class AnnotatedDartCodeBuilder {
_elements.add(DartTopLevelSymbol.topLevelElement(element));
}

void addTagged(String lexeme, String tag) {
_addPendingText();
_elements.add(TaggedDartLexeme(lexeme, tag));
}

void addDartType(DartType type) {
type.accept(_AddFromDartType(this));
}

void addAstNode(AstNode node, {Set<AstNode> exclude = const {}}) {
final visitor = _AddFromAst(this, exclude);
void addAstNode(
AstNode node, {
Set<AstNode> exclude = const {},
Map<Element, String> taggedElements = const {},
}) {
final visitor = _AddFromAst(this, exclude, taggedElements);
node.accept(visitor);
}

Expand Down Expand Up @@ -224,8 +228,64 @@ class AnnotatedDartCodeBuilder {
}
}

sealed class DartCodeElement {
Object? toJson();

factory DartCodeElement.fromJson(Object? json) {
return switch (json) {
String s => DartLexeme(s),
{'import_uri': _} => DartTopLevelSymbol.fromJson(json),
{'tag': _} => TaggedDartLexeme.fromJson(json),
_ => throw ArgumentError.value(json, 'json', 'Unknown code element'),
};
}
}

final class DartLexeme implements DartCodeElement {
final String lexeme;

const DartLexeme(this.lexeme);

@override
Object? toJson() {
return lexeme;
}

@override
String toString() {
return lexeme;
}
}

/// A variant of [DartLexeme] with a custom associated [tag].
///
/// For a motivation, see `ColumnParser._columnsInSameTable` - essentially, some
/// drift tools need to resolve column references in Dart code to rewrite them
/// depending on the generation mode.
@JsonSerializable()
final class TaggedDartLexeme implements DartCodeElement {
final String lexeme;
final String tag;

TaggedDartLexeme(this.lexeme, this.tag);

factory TaggedDartLexeme.fromJson(Map json) =>
_$TaggedDartLexemeFromJson(json);

@override
Map<String, Object?> toJson() => _$TaggedDartLexemeToJson(this);

@override
String toString() {
return lexeme;
}
}

/// A variant of [DartLexeme] that is used for top-level elements to also store
/// the import URI. This allows drift's code generator, when encountering such
/// element, to automatically add the relevant import to generated Dart files.
@JsonSerializable()
class DartTopLevelSymbol {
final class DartTopLevelSymbol implements DartCodeElement {
static final _driftUri = Uri.parse('package:drift/drift.dart');

static final list = DartTopLevelSymbol('List', AnnotatedDartCode.dartCore);
Expand Down Expand Up @@ -259,6 +319,7 @@ class DartTopLevelSymbol {
factory DartTopLevelSymbol.fromJson(Map json) =>
_$DartTopLevelSymbolFromJson(json);

@override
Map<String, Object?> toJson() => _$DartTopLevelSymbolToJson(this);
}

Expand Down Expand Up @@ -453,8 +514,9 @@ class _AddFromDartType extends UnifyingTypeVisitor<void> {
class _AddFromAst extends GeneralizingAstVisitor<void> {
final AnnotatedDartCodeBuilder _builder;
final Set<AstNode> _excluding;
final Map<Element, String> _taggedElements;

_AddFromAst(this._builder, this._excluding);
_AddFromAst(this._builder, this._excluding, this._taggedElements);

void _addTopLevelReference(Element? element, Token name2) {
if (element == null || (element.isSynthetic && element.library == null)) {
Expand Down Expand Up @@ -575,6 +637,8 @@ class _AddFromAst extends GeneralizingAstVisitor<void> {

if (isTopLevel) {
_builder.addTopLevelElement(target!);
} else if (_taggedElements[target] case final tag?) {
_builder.addTagged(node.token.lexeme, tag);
} else {
_builder.addText(node.name);
}
Expand Down
11 changes: 11 additions & 0 deletions drift_dev/lib/src/generated/analysis/results/dart.g.dart

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions drift_dev/lib/src/services/schema/schema_files.dart
Original file line number Diff line number Diff line change
Expand Up @@ -495,8 +495,9 @@ class SchemaReader {
nullable: nullable,
nameInSql: name,
nameInDart: getterName ?? ReCase(name).camelCase,
defaultArgument:
defaultDart != null ? AnnotatedDartCode([defaultDart]) : null,
defaultArgument: defaultDart != null
? AnnotatedDartCode([DartLexeme(defaultDart)])
: null,
declaration: _declaration,
customConstraints: customConstraints,
constraints: dslFeatures,
Expand Down
23 changes: 22 additions & 1 deletion drift_dev/lib/src/writer/schema_version_writer.dart
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,28 @@ class SchemaVersionWriter {
/// called in different places. This method looks up or creates a method for
/// the given [column], returning it if doesn't exist.
String _referenceColumn(DriftColumn column) {
final text = libraryScope.leaf();
final text = libraryScope.leaf(writeTaggedDartCode: (tag, buffer) {
final dartName = tag.tag;
final referencedColumn = column.owner.columns
.singleWhereOrNull((e) => e.nameInDart == dartName);

if (referencedColumn != null) {
// This references a column in the same table. Since we're not
// generating columns in a table structure where they would be in scope
// for Dart, we have to replace this with a custom expression evaluating
// to the column.
final sqlType = libraryScope.innerColumnType(referencedColumn.sqlType);
final result = libraryScope.dartCode(AnnotatedDartCode.build((b) => b
..addText('(')
..addSymbol('VersionedTable', _schemaLibrary)
..addText('.col<')
..addCode(sqlType)
..addText('>(${asDartLiteral(referencedColumn.nameInSql)}))')));
buffer.write(result);
} else {
buffer.write(tag.lexeme);
}
});
final (type, code) = TableOrViewWriter.instantiateColumn(column, text);

return _columnCodeToFactory.putIfAbsent(code, () {
Expand Down
2 changes: 1 addition & 1 deletion drift_dev/lib/src/writer/tables/data_class_writer.dart
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class DataClassWriter {
final nullable = converter.canBeSkippedForNulls && column.nullable;
final code = AnnotatedDartCode([
...AnnotatedDartCode.type(converter.jsonType!).elements,
if (nullable) '?',
if (nullable) const DartLexeme('?'),
]);

return _emitter.dartCode(code);
Expand Down
5 changes: 4 additions & 1 deletion drift_dev/lib/src/writer/utils/column_constraints.dart
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ Map<SqlDialect, String> defaultConstraints(DriftColumn column) {
result.write(defaults);
}
if (feature.dialectSpecific[dialect] case final specific?) {
result.write(' $specific');
if (result.isNotEmpty) {
result.write(' ');
}
result.write(specific);
}
return result.toString();
}
Expand Down
Loading

0 comments on commit be1a5ce

Please sign in to comment.