diff --git a/src/Microsoft.Data.Sqlite.Core/SqliteDataReader.cs b/src/Microsoft.Data.Sqlite.Core/SqliteDataReader.cs
index 6e93c53649b..1b2898b0dd0 100644
--- a/src/Microsoft.Data.Sqlite.Core/SqliteDataReader.cs
+++ b/src/Microsoft.Data.Sqlite.Core/SqliteDataReader.cs
@@ -551,9 +551,11 @@ public override Stream GetStream(int ordinal)
/// The zero-based column ordinal.
/// The returned object.
public override TextReader GetTextReader(int ordinal)
- => IsDBNull(ordinal)
- ? (TextReader)new StringReader(string.Empty)
- : new StreamReader(GetStream(ordinal), Encoding.UTF8);
+ => _closed
+ ? throw new InvalidOperationException(Resources.DataReaderClosed(nameof(GetTextReader)))
+ : _record == null
+ ? throw new InvalidOperationException(Resources.NoData)
+ : _record.GetTextReader(ordinal);
///
/// Gets the value of the specified column.
diff --git a/src/Microsoft.Data.Sqlite.Core/SqliteDataRecord.cs b/src/Microsoft.Data.Sqlite.Core/SqliteDataRecord.cs
index e74006d4d39..38d4a8cd2ce 100644
--- a/src/Microsoft.Data.Sqlite.Core/SqliteDataRecord.cs
+++ b/src/Microsoft.Data.Sqlite.Core/SqliteDataRecord.cs
@@ -60,7 +60,19 @@ protected override string GetStringCore(int ordinal)
=> sqlite3_column_text(Handle, ordinal).utf8_to_string();
public override T GetFieldValue(int ordinal)
- => base.GetFieldValue(ordinal)!;
+ {
+ if (typeof(T) == typeof(Stream))
+ {
+ return (T)(object)GetStream(ordinal);
+ }
+
+ if (typeof(T) == typeof(TextReader))
+ {
+ return (T)(object)GetTextReader(ordinal);
+ }
+
+ return base.GetFieldValue(ordinal)!;
+ }
protected override byte[] GetBlob(int ordinal)
=> base.GetBlob(ordinal)!;
@@ -317,6 +329,11 @@ public virtual Stream GetStream(int ordinal)
return new SqliteBlob(_connection, blobDatabaseName, blobTableName, blobColumnName, rowid, readOnly: true);
}
+ public virtual TextReader GetTextReader(int ordinal)
+ => IsDBNull(ordinal)
+ ? new StringReader(string.Empty)
+ : new StreamReader(GetStream(ordinal), Encoding.UTF8);
+
public bool Read()
{
if (!_stepped)
diff --git a/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs b/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs
index 31420413799..7a548763880 100644
--- a/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs
+++ b/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs
@@ -852,6 +852,45 @@ public void GetFieldValue_of_Nullable_works()
"SELECT 1;",
(int?)1);
+ [Fact]
+ public void GetFieldValue_of_Stream_works()
+ {
+ using (var connection = new SqliteConnection("Data Source=:memory:"))
+ {
+ connection.Open();
+
+ using (var reader = connection.ExecuteReader("SELECT x'7E57';"))
+ {
+ var hasData = reader.Read();
+ Assert.True(hasData);
+
+ var stream = reader.GetFieldValue(0);
+ Assert.Equal(0x7E, stream.ReadByte());
+ Assert.Equal(0x57, stream.ReadByte());
+ }
+ }
+ }
+
+ [Fact]
+ public void GetFieldValue_of_TextReader_works()
+ {
+ using (var connection = new SqliteConnection("Data Source=:memory:"))
+ {
+ connection.Open();
+
+ using (var reader = connection.ExecuteReader("SELECT 'test';"))
+ {
+ var hasData = reader.Read();
+ Assert.True(hasData);
+
+ using (var textReader = reader.GetFieldValue(0))
+ {
+ Assert.Equal("test", textReader.ReadToEnd());
+ }
+ }
+ }
+ }
+
[Fact]
public void GetFieldValue_of_TimeSpan_works()
=> GetFieldValue_works(