diff --git a/src/Microsoft.Data.Sqlite.Core/SqliteDataReader.cs b/src/Microsoft.Data.Sqlite.Core/SqliteDataReader.cs index 6e93c53649b..1b2898b0dd0 100644 --- a/src/Microsoft.Data.Sqlite.Core/SqliteDataReader.cs +++ b/src/Microsoft.Data.Sqlite.Core/SqliteDataReader.cs @@ -551,9 +551,11 @@ public override Stream GetStream(int ordinal) /// The zero-based column ordinal. /// The returned object. public override TextReader GetTextReader(int ordinal) - => IsDBNull(ordinal) - ? (TextReader)new StringReader(string.Empty) - : new StreamReader(GetStream(ordinal), Encoding.UTF8); + => _closed + ? throw new InvalidOperationException(Resources.DataReaderClosed(nameof(GetTextReader))) + : _record == null + ? throw new InvalidOperationException(Resources.NoData) + : _record.GetTextReader(ordinal); /// /// Gets the value of the specified column. diff --git a/src/Microsoft.Data.Sqlite.Core/SqliteDataRecord.cs b/src/Microsoft.Data.Sqlite.Core/SqliteDataRecord.cs index e74006d4d39..38d4a8cd2ce 100644 --- a/src/Microsoft.Data.Sqlite.Core/SqliteDataRecord.cs +++ b/src/Microsoft.Data.Sqlite.Core/SqliteDataRecord.cs @@ -60,7 +60,19 @@ protected override string GetStringCore(int ordinal) => sqlite3_column_text(Handle, ordinal).utf8_to_string(); public override T GetFieldValue(int ordinal) - => base.GetFieldValue(ordinal)!; + { + if (typeof(T) == typeof(Stream)) + { + return (T)(object)GetStream(ordinal); + } + + if (typeof(T) == typeof(TextReader)) + { + return (T)(object)GetTextReader(ordinal); + } + + return base.GetFieldValue(ordinal)!; + } protected override byte[] GetBlob(int ordinal) => base.GetBlob(ordinal)!; @@ -317,6 +329,11 @@ public virtual Stream GetStream(int ordinal) return new SqliteBlob(_connection, blobDatabaseName, blobTableName, blobColumnName, rowid, readOnly: true); } + public virtual TextReader GetTextReader(int ordinal) + => IsDBNull(ordinal) + ? new StringReader(string.Empty) + : new StreamReader(GetStream(ordinal), Encoding.UTF8); + public bool Read() { if (!_stepped) diff --git a/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs b/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs index 31420413799..7a548763880 100644 --- a/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs +++ b/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs @@ -852,6 +852,45 @@ public void GetFieldValue_of_Nullable_works() "SELECT 1;", (int?)1); + [Fact] + public void GetFieldValue_of_Stream_works() + { + using (var connection = new SqliteConnection("Data Source=:memory:")) + { + connection.Open(); + + using (var reader = connection.ExecuteReader("SELECT x'7E57';")) + { + var hasData = reader.Read(); + Assert.True(hasData); + + var stream = reader.GetFieldValue(0); + Assert.Equal(0x7E, stream.ReadByte()); + Assert.Equal(0x57, stream.ReadByte()); + } + } + } + + [Fact] + public void GetFieldValue_of_TextReader_works() + { + using (var connection = new SqliteConnection("Data Source=:memory:")) + { + connection.Open(); + + using (var reader = connection.ExecuteReader("SELECT 'test';")) + { + var hasData = reader.Read(); + Assert.True(hasData); + + using (var textReader = reader.GetFieldValue(0)) + { + Assert.Equal("test", textReader.ReadToEnd()); + } + } + } + } + [Fact] public void GetFieldValue_of_TimeSpan_works() => GetFieldValue_works(