Skip to content

Commit

Permalink
Fix infering result type of fts5 functions
Browse files Browse the repository at this point in the history
  • Loading branch information
simolus3 committed Oct 18, 2024
1 parent f692bb1 commit 2cd568a
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 9 deletions.
56 changes: 56 additions & 0 deletions drift_dev/test/analysis/resolver/drift/regression_3292_test.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import 'package:drift/drift.dart';
import 'package:drift_dev/src/analysis/options.dart';
import 'package:drift_dev/src/analysis/results/results.dart';
import 'package:test/test.dart';

import '../../test_utils.dart';

void main() {
test('infers types for bm25 and snippet functions', () async {
final backend = await TestBackend.inTest({
'a|lib/a.drift': '''
CREATE VIRTUAL TABLE songs_fts USING fts5(uuid, source_bank_id, title, lyrics, composer, poet, translator, pitch_field);
song_fulltext_search(:match_string AS TEXT):
SELECT
BM25(songs_fts, 0.0, 0.0, 10.0, 0.5, 5.0, 5.0, 2.0, 0.0) AS rank
,uuid
,source_bank_id
,pitch_field
,SNIPPET(songs_fts, 2, '<?', '?>', '...', 30) AS match_title
,SNIPPET(songs_fts, 3, '<?', '?>', '...', 30) AS match_lyrics
,SNIPPET(songs_fts, 4, '<?', '?>', '...', 30) AS match_composer
,SNIPPET(songs_fts, 5, '<?', '?>', '...', 30) AS match_poet
,SNIPPET(songs_fts, 6, '<?', '?>', '...', 30) AS match_translator
FROM songs_fts
WHERE songs_fts MATCH :match_string
ORDER BY rank;
''',
}, options: DriftOptions.defaults(modules: [SqlModule.fts5]));

final file = await backend.analyze('package:a/a.drift');
backend.expectNoErrors();

final query =
file.fileAnalysis!.resolvedQueries.values.single as SqlSelectQuery;
expect(
query.resultSet.columns.map(
(e) => (
e.dartGetterName(const []),
(e as ScalarResultColumn).sqlType.builtin
),
),
[
('rank', DriftSqlType.double),
('uuid', DriftSqlType.string),
('sourceBankId', DriftSqlType.string),
('pitchField', DriftSqlType.string),
('matchTitle', DriftSqlType.string),
('matchLyrics', DriftSqlType.string),
('matchComposer', DriftSqlType.string),
('matchPoet', DriftSqlType.string),
('matchTranslator', DriftSqlType.string),
],
);
});
}
4 changes: 2 additions & 2 deletions sqlparser/lib/src/engine/module/fts5.dart
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class _Fts5Functions with ArgumentCountLinter implements FunctionHandler {
if (argumentIndex == 0) {
return const ResolveResult.unknown();
} else {
return const ResolveResult(ResolvedType(type: BasicType.int));
return const ResolveResult(ResolvedType(type: BasicType.real));
}

case 'highlight':
Expand All @@ -184,7 +184,7 @@ class _Fts5Functions with ArgumentCountLinter implements FunctionHandler {
@override
ResolveResult inferReturnType(AnalysisContext context, SqlInvocation call,
List<Typeable> expandedArgs) {
switch (call.name) {
switch (call.name.toLowerCase()) {
case 'bm25':
return const ResolveResult(ResolvedType(type: BasicType.real));
case 'highlight':
Expand Down
6 changes: 1 addition & 5 deletions sqlparser/lib/src/engine/sql_engine.dart
Original file line number Diff line number Diff line change
Expand Up @@ -264,11 +264,7 @@ class SqlEngine {
AnalysisContext analyzeParsed(ParseResult result,
{AnalyzeStatementOptions? stmtOptions}) {
final node = result.rootNode;

final context = _createContext(node, result.sql, stmtOptions);
_analyzeContext(context);

return context;
return analyzeNode(node, result.sql, stmtOptions: stmtOptions);
}

/// Analyzes the given [node], which should be a [CrudStatement].
Expand Down
6 changes: 4 additions & 2 deletions sqlparser/test/engine/module/fts5_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ void main() {
final column = select.resolvedColumns!.singleWhere((c) => c.name == 'b');
expect(result.typeOf(column),
const ResolveResult(ResolvedType(type: BasicType.real)));
expect(result.typeOf((column as ExpressionColumn).expression),
const ResolveResult(ResolvedType(type: BasicType.real)));
});

test('return type of highlight()', () {
Expand Down Expand Up @@ -245,8 +247,8 @@ void main() {
checkVarTypes(
'SELECT bm25(fts, ?, ?) FROM fts;',
[
BasicType.int,
BasicType.int,
BasicType.real,
BasicType.real,
],
);
});
Expand Down

0 comments on commit 2cd568a

Please sign in to comment.