diff --git a/floor_generator/lib/model/change_method.dart b/floor_generator/lib/model/change_method.dart index cee5fea0..9de982f2 100644 --- a/floor_generator/lib/model/change_method.dart +++ b/floor_generator/lib/model/change_method.dart @@ -66,10 +66,10 @@ class ChangeMethod { DartType get _flattenedReturnType { final type = method.returnType.flattenFutures(method.context.typeSystem); - return _returnsList ? flattenList(type) : type; + return returnsList ? flattenList(type) : type; } - bool get _returnsList { + bool get returnsList { final type = method.returnType.flattenFutures(method.context.typeSystem); return isList(type); } diff --git a/floor_generator/lib/writer/insert_method_body_writer.dart b/floor_generator/lib/writer/insert_method_body_writer.dart index 8f9595de..44b855ae 100644 --- a/floor_generator/lib/writer/insert_method_body_writer.dart +++ b/floor_generator/lib/writer/insert_method_body_writer.dart @@ -17,8 +17,9 @@ class InsertMethodBodyWriter implements Writer { } String _generateMethodBody() { - final columnNames = - method.getEntity(library).columns.map((column) => column.name).toList(); + final entity = method.getEntity(library); + + final columnNames = entity.columns.map((column) => column.name).toList(); final constructorParameters = method.flattenedParameterClass.constructors.first.parameters; @@ -29,7 +30,7 @@ class InsertMethodBodyWriter implements Writer { keyValueList.add("'${columnNames[i]}': $valueMapping"); } - final entityName = method.getEntity(library).name; + final entityName = entity.name; final methodSignatureParameterName = method.parameter.displayName; if (method.returnsInt) { @@ -46,7 +47,7 @@ class InsertMethodBodyWriter implements Writer { ); } else { throw InvalidGenerationSourceError( - 'Insert methods have to return a Future of either void, int or List', + 'Insert methods have to return a Future of either void, int or List.', element: method.method, ); } diff --git a/floor_generator/lib/writer/update_method_body_writer.dart b/floor_generator/lib/writer/update_method_body_writer.dart index 18e60ebc..750e595e 100644 --- a/floor_generator/lib/writer/update_method_body_writer.dart +++ b/floor_generator/lib/writer/update_method_body_writer.dart @@ -1,6 +1,7 @@ import 'package:analyzer/dart/element/element.dart'; import 'package:code_builder/code_builder.dart'; import 'package:floor_generator/misc/type_utils.dart'; +import 'package:floor_generator/model/column.dart'; import 'package:floor_generator/model/update_method.dart'; import 'package:floor_generator/writer/writer.dart'; import 'package:source_gen/source_gen.dart'; @@ -17,7 +18,8 @@ class UpdateMethodBodyWriter implements Writer { } String _generateMethodBody() { - final methodHeadParameterName = method.parameter.displayName; + _assertMethodReturnsNoList(); + final entity = method.getEntity(library); final columnNames = entity.columns.map((column) => column.name).toList(); @@ -32,12 +34,71 @@ class UpdateMethodBodyWriter implements Writer { } final entityName = entity.name; + final methodSignatureParameterName = method.parameter.displayName; + final primaryKeyColumn = entity.primaryKeyColumn; + + if (method.returnsInt) { + return _generateIntReturnMethodBody( + methodSignatureParameterName, + keyValueList, + entityName, + primaryKeyColumn, + ); + } else if (method.returnsVoid) { + return _generateVoidReturnMethodBody( + methodSignatureParameterName, + keyValueList, + entityName, + primaryKeyColumn, + ); + } else { + throw InvalidGenerationSourceError( + 'Update methods have to return a Future of either void or int.', + element: method.method, + ); + } + } + + String _generateIntReturnMethodBody( + final String methodSignatureParameterName, + final List keyValueList, + final String entityName, + final Column primaryKeyColumn, + ) { + if (method.changesMultipleItems) { + return ''' + final batch = database.batch(); + for (final item in $methodSignatureParameterName) { + final values = { + ${keyValueList.join(', ')} + }; + batch.update('$entityName', values, where: '${primaryKeyColumn.name} = ?', whereArgs: [item.${primaryKeyColumn.field.displayName}]); + } + return (await batch.commit(noResult: false)) + .cast() + .reduce((first, second) => first + second); + '''; + } else { + return ''' + final item = $methodSignatureParameterName; + final values = { + ${keyValueList.join(', ')} + }; + return database.update('$entityName', values, where: '${primaryKeyColumn.name} = ?', whereArgs: [item.${primaryKeyColumn.field.displayName}]); + '''; + } + } + String _generateVoidReturnMethodBody( + final String methodSignatureParameterName, + final List keyValueList, + final String entityName, + final Column primaryKeyColumn, + ) { if (method.changesMultipleItems) { - final primaryKeyColumn = entity.primaryKeyColumn; return ''' final batch = database.batch(); - for (final item in $methodHeadParameterName) { + for (final item in $methodSignatureParameterName) { final values = { ${keyValueList.join(', ')} }; @@ -47,11 +108,11 @@ class UpdateMethodBodyWriter implements Writer { '''; } else { return ''' - final item = $methodHeadParameterName; + final item = $methodSignatureParameterName; final values = { ${keyValueList.join(', ')} }; - await database.update('$entityName', values); + await database.update('$entityName', values, where: '${primaryKeyColumn.name} = ?', whereArgs: [item.${primaryKeyColumn.field.displayName}]); '''; } } @@ -63,4 +124,13 @@ class UpdateMethodBodyWriter implements Writer { ? 'item.$parameterName ? 1 : 0' : 'item.$parameterName'; } + + void _assertMethodReturnsNoList() { + if (method.returnsList) { + throw InvalidGenerationSourceError( + 'Update methods have to return a Future of either void or int but not a list.', + element: method.method, + ); + } + } } diff --git a/floor_test/test/database.dart b/floor_test/test/database.dart index 18cb199c..4226ae04 100644 --- a/floor_test/test/database.dart +++ b/floor_test/test/database.dart @@ -32,6 +32,12 @@ abstract class TestDatabase extends FloorDatabase { @update Future updatePerson(Person person); + @update + Future updatePersonWithReturn(Person person); + + @update + Future updatePersonsWithReturn(List persons); + @update Future updatePersons(List persons); diff --git a/floor_test/test/database_test.dart b/floor_test/test/database_test.dart index 43b5cf28..369011b6 100644 --- a/floor_test/test/database_test.dart +++ b/floor_test/test/database_test.dart @@ -43,8 +43,8 @@ void main() { test('update person', () async { final person = Person(1, 'Simon'); await database.insertPerson(person); + final updatedPerson = Person(person.id, _reverse(person.name)); - final updatedPerson = Person(person.id, 'Frank'); await database.updatePerson(updatedPerson); final actual = await database.findPersonById(person.id); @@ -71,15 +71,12 @@ void main() { }); test('update persons', () async { - final person1 = Person(1, 'Simon'); - final person2 = Person(2, 'Frank'); - final persons = [person1, person2]; + final persons = [Person(1, 'Simon'), Person(2, 'Frank')]; await database.insertPersons(persons); + final updatedPersons = persons + .map((person) => Person(person.id, _reverse(person.name))) + .toList(); - final updatedPersons = [ - Person(person1.id, _reverse(person1.name)), - Person(person2.id, _reverse(person2.name)) - ]; await database.updatePersons(updatedPersons); final actual = await database.findAllPersons(); @@ -89,8 +86,8 @@ void main() { test('replace persons in transaction', () async { final persons = [Person(1, 'Simon'), Person(2, 'Frank')]; await database.insertPersons(persons); - final newPersons = [Person(3, 'Paul'), Person(4, 'Karl')]; + await database.replacePersons(newPersons); final actual = await database.findAllPersons(); @@ -106,13 +103,38 @@ void main() { }); test('insert persons and return ids of inserted items', () async { - final person1 = Person(1, 'Simon'); - final person2 = Person(2, 'Frank'); - final persons = [person1, person2]; + final persons = [Person(1, 'Simon'), Person(2, 'Frank')]; final actual = await database.insertPersonsWithReturn(persons); - expect(actual, equals([person1.id, person2.id])); + final expected = persons.map((person) => person.id).toList(); + expect(actual, equals(expected)); + }); + + test('update person and return 1 (affected row count)', () async { + final person = Person(1, 'Simon'); + await database.insertPerson(person); + final updatedPerson = Person(person.id, _reverse(person.name)); + + final actual = await database.updatePersonWithReturn(updatedPerson); + + final persistentPerson = await database.findPersonById(person.id); + expect(persistentPerson, equals(updatedPerson)); + expect(actual, equals(1)); + }); + + test('update persons and return affected rows count', () async { + final persons = [Person(1, 'Simon'), Person(2, 'Frank')]; + await database.insertPersons(persons); + final updatedPersons = persons + .map((person) => Person(person.id, _reverse(person.name))) + .toList(); + + final actual = await database.updatePersonsWithReturn(updatedPersons); + + final persistentPersons = await database.findAllPersons(); + expect(persistentPersons, equals(updatedPersons)); + expect(actual, equals(2)); }); }); }