diff --git a/Project.toml b/Project.toml
index 39244ce..0e01e7a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -7,6 +7,7 @@ version = "0.1.0"
 CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
 DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 
 [compat]
diff --git a/src/connector.jl b/src/connector.jl
index de62907..657eebb 100644
--- a/src/connector.jl
+++ b/src/connector.jl
@@ -1,4 +1,4 @@
-using Tables
+using Tables, StatsBase
 
 abstract type AbstractConnector{T} end
 
@@ -21,6 +21,31 @@ Base.length(conn::TablesConnector) = length(conn.rows)
 
 hasnext(conn::TablesConnector) = conn.state < length(conn)
 
-TablesConnector(df::DataFrames.DataFrame) = TablesConnector{DataFrame}(Tables.rows(df), 0)
+"""
+    TablesConnector(df; orderBy = nothing, rev = false, shuffle = false)
+
+Build a `TablesConnector` over the rows of `df`, starting at state `0`.
+
+The caller's `df` is never modified; any reordering is applied to a copy.
+
+# Keywords
+- `orderBy`: column(s) to sort by, forwarded to `DataFrames.order`; `nothing` keeps the
+  current row order.
+- `rev`: sort in descending order when `true`; only used together with `orderBy`.
+- `shuffle`: randomly permute the rows when `true`; takes precedence over `orderBy`.
+"""
+function TablesConnector(df::DataFrames.DataFrame;
+                         orderBy = nothing,
+                         rev = false,
+                         shuffle = false)
+    if shuffle == true
+        df = df[StatsBase.shuffle(1:size(df, 1)), :]
+    elseif orderBy !== nothing
+        # Non-mutating `sort`: constructing a connector must not reorder the caller's table.
+        df = sort(df, DataFrames.order(orderBy, rev = rev))
+    end
+
+    return TablesConnector{DataFrame}(Tables.rows(df), 0)
+end
 
 TablesConnector(filename::String) = TablesConnector{DataFrame}(CSV.read(filename; header = false))
\ No newline at end of file