diff --git a/src/Microsoft.Data.Analysis/DataFrame.IDataView.cs b/src/Microsoft.Data.Analysis/DataFrame.IDataView.cs index 79d5c693fc..27de92da69 100644 --- a/src/Microsoft.Data.Analysis/DataFrame.IDataView.cs +++ b/src/Microsoft.Data.Analysis/DataFrame.IDataView.cs @@ -16,7 +16,7 @@ public partial class DataFrame : IDataView bool IDataView.CanShuffle => false; private DataViewSchema _schema; - internal DataViewSchema DataViewSchema + private DataViewSchema DataViewSchema { get { @@ -70,29 +70,22 @@ private sealed class RowCursor : DataViewRowCursor private bool _disposed; private long _position; private readonly DataFrame _dataFrame; - private readonly List _getters; - private Dictionary _columnIndexToGetterIndex; + private readonly Delegate[] _getters; public RowCursor(DataFrame dataFrame, bool[] activeColumns) { Debug.Assert(dataFrame != null); Debug.Assert(activeColumns != null); - _columnIndexToGetterIndex = new Dictionary(); _position = -1; _dataFrame = dataFrame; - _getters = new List(); - for (int i = 0; i < Schema.Count; i++) + _getters = new Delegate[Schema.Count]; + for (int i = 0; i < _getters.Length; i++) { if (!activeColumns[i]) - { continue; - } - - Delegate getter = CreateGetterDelegate(i); - _getters.Add(getter); - Debug.Assert(getter != null); - _columnIndexToGetterIndex[i] = _getters.Count - 1; + _getters[i] = CreateGetterDelegate(i); + Debug.Assert(_getters[i] != null); } } @@ -103,15 +96,11 @@ public RowCursor(DataFrame dataFrame, bool[] activeColumns) protected override void Dispose(bool disposing) { if (_disposed) - { return; - } - if (disposing) { _position = -1; } - _disposed = true; base.Dispose(disposing); } @@ -127,7 +116,7 @@ public override ValueGetter GetGetter(DataViewSchema.Column colu if (!IsColumnActive(column)) throw new ArgumentOutOfRangeException(nameof(column)); - return (ValueGetter)_getters[_columnIndexToGetterIndex[column.Index]]; + return (ValueGetter)_getters[column.Index]; } public override ValueGetter GetIdGetter() @@ -137,15 +126,13 @@ public override ValueGetter GetIdGetter() public override bool IsColumnActive(DataViewSchema.Column column) { - return _getters[_columnIndexToGetterIndex[column.Index]] != null; + return _getters[column.Index] != null; } public override bool MoveNext() { if (_disposed) - { return false; - } _position++; return _position < _dataFrame.Rows.Count; } diff --git a/src/Microsoft.Data.Analysis/DataFrameColumn.cs b/src/Microsoft.Data.Analysis/DataFrameColumn.cs index 346dd4f242..bd21d6fe96 100644 --- a/src/Microsoft.Data.Analysis/DataFrameColumn.cs +++ b/src/Microsoft.Data.Analysis/DataFrameColumn.cs @@ -251,15 +251,14 @@ public virtual DataFrameColumn Sort(bool ascending = true) /// Appends a value to this using /// /// The row cursor which has the current position - /// The in /// The cached ValueGetter for this column. - protected internal virtual void AddValueUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column schemaColumn, Delegate ValueGetter) => throw new NotImplementedException(); + protected internal virtual void AddValueUsingCursor(DataViewRowCursor cursor, Delegate ValueGetter) => throw new NotImplementedException(); /// /// Returns the ValueGetter for each active column in as a delegate to be cached. /// /// The row cursor which has the current position - /// The in + /// The to return the ValueGetter for. protected internal virtual Delegate GetValueGetterUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column schemaColumn) => throw new NotImplementedException(); /// diff --git a/src/Microsoft.Data.Analysis/IDataView.Extension.cs b/src/Microsoft.Data.Analysis/IDataView.Extension.cs index 5b9c9034bc..32b97d365a 100644 --- a/src/Microsoft.Data.Analysis/IDataView.Extension.cs +++ b/src/Microsoft.Data.Analysis/IDataView.Extension.cs @@ -13,20 +13,40 @@ public static class IDataViewExtensions { private const int defaultMaxRows = 100; + /// + /// Returns a from this . + /// + /// The current . + /// The max number or rows in the . Defaults to 100. Use -1 to construct a DataFrame using all the rows in . + /// A with . public static DataFrame ToDataFrame(this IDataView dataView, long maxRows = defaultMaxRows) { return ToDataFrame(dataView, maxRows, null); } + /// + /// Returns a with the first 100 rows of this . + /// + /// The current . + /// The columns selected for the resultant DataFrame + /// A with the selected columns and 100 rows. public static DataFrame ToDataFrame(this IDataView dataView, params string[] selectColumns) { return ToDataFrame(dataView, defaultMaxRows, selectColumns); } + /// + /// Returns a with the first of this . + /// + /// The current . + /// The max number or rows in the . Use -1 to construct a DataFrame using all the rows in . + /// The columns selected for the resultant DataFrame + /// A with the selected columns and rows. public static DataFrame ToDataFrame(this IDataView dataView, long maxRows, params string[] selectColumns) { DataViewSchema schema = dataView.Schema; - List columns = new List(schema.Count); + List dataFrameColumns = new List(schema.Count); + maxRows = maxRows == -1 ? long.MaxValue : maxRows; HashSet selectColumnsSet = null; if (selectColumns != null && selectColumns.Length > 0) @@ -34,63 +54,63 @@ public static DataFrame ToDataFrame(this IDataView dataView, long maxRows, param selectColumnsSet = new HashSet(selectColumns); } - List activeColumns = new List(); - foreach (DataViewSchema.Column column in schema) + List activeDataViewColumns = new List(); + foreach (DataViewSchema.Column dataViewColumn in schema) { - if (column.IsHidden || (selectColumnsSet != null && !selectColumnsSet.Contains(column.Name))) + if (dataViewColumn.IsHidden || (selectColumnsSet != null && !selectColumnsSet.Contains(dataViewColumn.Name))) { continue; } - activeColumns.Add(column); - DataViewType type = column.Type; + activeDataViewColumns.Add(dataViewColumn); + DataViewType type = dataViewColumn.Type; if (type == BooleanDataViewType.Instance) { - columns.Add(new BooleanDataFrameColumn(column.Name)); + dataFrameColumns.Add(new BooleanDataFrameColumn(dataViewColumn.Name)); } else if (type == NumberDataViewType.Byte) { - columns.Add(new ByteDataFrameColumn(column.Name)); + dataFrameColumns.Add(new ByteDataFrameColumn(dataViewColumn.Name)); } else if (type == NumberDataViewType.Double) { - columns.Add(new DoubleDataFrameColumn(column.Name)); + dataFrameColumns.Add(new DoubleDataFrameColumn(dataViewColumn.Name)); } else if (type == NumberDataViewType.Single) { - columns.Add(new SingleDataFrameColumn(column.Name)); + dataFrameColumns.Add(new SingleDataFrameColumn(dataViewColumn.Name)); } else if (type == NumberDataViewType.Int32) { - columns.Add(new Int32DataFrameColumn(column.Name)); + dataFrameColumns.Add(new Int32DataFrameColumn(dataViewColumn.Name)); } else if (type == NumberDataViewType.Int64) { - columns.Add(new Int64DataFrameColumn(column.Name)); + dataFrameColumns.Add(new Int64DataFrameColumn(dataViewColumn.Name)); } else if (type == NumberDataViewType.SByte) { - columns.Add(new SByteDataFrameColumn(column.Name)); + dataFrameColumns.Add(new SByteDataFrameColumn(dataViewColumn.Name)); } else if (type == NumberDataViewType.Int16) { - columns.Add(new Int16DataFrameColumn(column.Name)); + dataFrameColumns.Add(new Int16DataFrameColumn(dataViewColumn.Name)); } else if (type == NumberDataViewType.UInt32) { - columns.Add(new UInt32DataFrameColumn(column.Name)); + dataFrameColumns.Add(new UInt32DataFrameColumn(dataViewColumn.Name)); } else if (type == NumberDataViewType.UInt64) { - columns.Add(new UInt64DataFrameColumn(column.Name)); + dataFrameColumns.Add(new UInt64DataFrameColumn(dataViewColumn.Name)); } else if (type == NumberDataViewType.UInt16) { - columns.Add(new UInt16DataFrameColumn(column.Name)); + dataFrameColumns.Add(new UInt16DataFrameColumn(dataViewColumn.Name)); } else if (type == TextDataViewType.Instance) { - columns.Add(new StringDataFrameColumn(column.Name)); + dataFrameColumns.Add(new StringDataFrameColumn(dataViewColumn.Name)); } else { @@ -98,28 +118,26 @@ public static DataFrame ToDataFrame(this IDataView dataView, long maxRows, param } } - using (DataViewRowCursor cursor = dataView.GetRowCursor(activeColumns)) + using (DataViewRowCursor cursor = dataView.GetRowCursor(activeDataViewColumns)) { - Delegate[] activeColumnDelegates = new Delegate[activeColumns.Count]; + Delegate[] activeColumnDelegates = new Delegate[activeDataViewColumns.Count]; int columnIndex = 0; - foreach (DataViewSchema.Column column in activeColumns) + foreach (DataViewSchema.Column activeDataViewColumn in activeDataViewColumns) { - Delegate valueGetter = columns[columnIndex].GetValueGetterUsingCursor(cursor, column); + Delegate valueGetter = dataFrameColumns[columnIndex].GetValueGetterUsingCursor(cursor, activeDataViewColumn); activeColumnDelegates[columnIndex] = valueGetter; columnIndex++; } while (cursor.MoveNext() && cursor.Position < maxRows) { - columnIndex = 0; - foreach (DataViewSchema.Column column in activeColumns) + for (int i = 0; i < activeColumnDelegates.Length; i++) { - columns[columnIndex].AddValueUsingCursor(cursor, column, activeColumnDelegates[columnIndex]); - columnIndex++; + dataFrameColumns[i].AddValueUsingCursor(cursor, activeColumnDelegates[i]); } } } - return new DataFrame(columns); + return new DataFrame(dataFrameColumns); } } diff --git a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs index f91c72802c..a7e7d20cb9 100644 --- a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs +++ b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs @@ -776,7 +776,7 @@ private static ValueGetter CreateCharValueGetterDelegate(DataViewRowCurs private static ValueGetter CreateDecimalValueGetterDelegate(DataViewRowCursor cursor, PrimitiveDataFrameColumn column) => (ref double value) => value = (double?)column[cursor.Position] ?? double.NaN; - protected internal override void AddValueUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column column, Delegate getter) + protected internal override void AddValueUsingCursor(DataViewRowCursor cursor, Delegate getter) { long row = cursor.Position; T value = default; diff --git a/src/Microsoft.Data.Analysis/StringDataFrameColumn.cs b/src/Microsoft.Data.Analysis/StringDataFrameColumn.cs index 197cce721d..7ada30e10c 100644 --- a/src/Microsoft.Data.Analysis/StringDataFrameColumn.cs +++ b/src/Microsoft.Data.Analysis/StringDataFrameColumn.cs @@ -468,7 +468,7 @@ protected internal override Delegate GetDataViewGetter(DataViewRowCursor cursor) private ValueGetter> CreateValueGetterDelegate(DataViewRowCursor cursor) => (ref ReadOnlyMemory value) => value = this[cursor.Position].AsMemory(); - protected internal override void AddValueUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column schemaColumn, Delegate getter) + protected internal override void AddValueUsingCursor(DataViewRowCursor cursor, Delegate getter) { long row = cursor.Position; ReadOnlyMemory value = default; @@ -489,6 +489,7 @@ protected internal override void AddValueUsingCursor(DataViewRowCursor cursor, D throw new IndexOutOfRangeException(nameof(row)); } } + protected internal override Delegate GetValueGetterUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column schemaColumn) { return cursor.GetGetter>(schemaColumn); diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs index c090817cf5..dea8099876 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs @@ -252,25 +252,41 @@ public void TestDataFrameFromIDataView_SelectColumns() Assert.True(df.Columns["Double"].ElementwiseEquals(newDf.Columns["Double"]).All()); } - [Fact] - public void TestDataFrameFromIDataView_SelectRows() + [Theory] + [InlineData(10, 5)] + [InlineData(110, 100)] + [InlineData(110, -1)] + public void TestDataFrameFromIDataView_SelectRows(int dataFrameSize, int rowSize) { - DataFrame df = DataFrameTests.MakeDataFrameWithAllColumnTypes(10, withNulls: false); + DataFrame df = DataFrameTests.MakeDataFrameWithAllColumnTypes(dataFrameSize, withNulls: false); df.Columns.Remove("Char"); // Because chars are returned as uint16 by DataViewSchema, so end up comparing CharDataFrameColumn to UInt16DataFrameColumn and fail asserts df.Columns.Remove("Decimal"); // Because decimal is returned as double by DataViewSchema, so end up comparing DecimalDataFrameColumn to DoubleDataFrameColumn and fail asserts IDataView dfAsIDataView = df; - DataFrame newDf = dfAsIDataView.ToDataFrame(5); - Assert.Equal(5, newDf.Rows.Count); + DataFrame newDf; + if (rowSize == 100) + { + // Test default + newDf = dfAsIDataView.ToDataFrame(); + } + else + { + newDf = dfAsIDataView.ToDataFrame(rowSize); + } + if (rowSize == -1) + { + rowSize = dataFrameSize; + } + Assert.Equal(rowSize, newDf.Rows.Count); Assert.Equal(df.Columns.Count, newDf.Columns.Count); for (int i = 0; i < newDf.Columns.Count; i++) { - Assert.Equal(5, newDf.Columns[i].Length); + Assert.Equal(rowSize, newDf.Columns[i].Length); Assert.Equal(df.Columns[i].Name, newDf.Columns[i].Name); } Assert.Equal(dfAsIDataView.Schema.Count, newDf.Columns.Count); for (int c = 0; c < df.Columns.Count; c++) { - for (int r = 0; r < 5; r++) + for (int r = 0; r < rowSize; r++) { Assert.Equal(df.Columns[c][r], newDf.Columns[c][r]); } diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs index c277aae36e..300babbffb 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs @@ -62,10 +62,12 @@ public static ArrowStringDataFrameColumn CreateArrowStringColumn(int length, boo // write the current length to (index + 1) int offsetIndex = (i + 1) * 4; - offsetMemory[offsetIndex++] = (byte)(3 * validStringsIndex); - offsetMemory[offsetIndex++] = 0; - offsetMemory[offsetIndex++] = 0; - offsetMemory[offsetIndex++] = 0; + int offsetValue = 3 * validStringsIndex; + byte[] offsetValueBytes = BitConverter.GetBytes(offsetValue); + offsetMemory[offsetIndex++] = offsetValueBytes[0]; + offsetMemory[offsetIndex++] = offsetValueBytes[1]; + offsetMemory[offsetIndex++] = offsetValueBytes[2]; + offsetMemory[offsetIndex++] = offsetValueBytes[3]; } int nullCount = withNulls ? 1 : 0;