Skip to content

Commit

Permalink
Merge pull request #35 from SaguaroCapital/tyler-expand-test-coverage
Browse files Browse the repository at this point in the history
Expand Test Coverage
  • Loading branch information
tylerjthomas9 authored Nov 30, 2023
2 parents a362162 + 24cc890 commit 3aef664
Show file tree
Hide file tree
Showing 11 changed files with 54 additions and 29 deletions.
2 changes: 1 addition & 1 deletion ext/MarketDataExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ function SaguaroTrader.download_market_data(
end

function SaguaroTrader.download_market_data(
securities::Vector{Symbol},
securities::AbstractVector{Symbol},
data_dir::String="./temp/";
start_dt::DateTime=DateTime(1990, 1, 1),
end_dt::DateTime=DateTime(2040, 1, 1),
Expand Down
4 changes: 2 additions & 2 deletions src/asset/universe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Dynamic asset universe. Each asset has a start date
Parameters
----------
- `asset_dates::Vector{Dict{Symbol,DateTime}}`
- `asset_dates::AbstractVector{Dict{Symbol,DateTime}}`
"""
struct DynamicUniverse <: Universe
asset_dates::Dict{Asset,DateTime}
Expand All @@ -46,7 +46,7 @@ Returns
-------
- `Vector{Asset}`
"""
function _get_assets(uni::DynamicUniverse, dt)::Vector{Asset}
function _get_assets(uni::DynamicUniverse, dt)::AbstractVector{Asset}
return Vector{Asset}([
asset for (asset, asset_date) in uni.asset_dates if asset_date < dt
])
Expand Down
28 changes: 16 additions & 12 deletions src/data/daily_bar_csv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ function _load_csv_into_df(
end

function _load_csvs_into_dfs(
csv_files::Vector{String},
csv_files::AbstractVector{String},
adjust_prices::Bool=false,
market_open::Dates.CompoundPeriod=Hour(14) + Minute(30),
market_close::Dates.CompoundPeriod=Hour(21) + Minute(59),
Expand Down Expand Up @@ -60,13 +60,13 @@ function _load_csvs_into_dfs(
return dict_asset_dfs
end

function _detect_adj_column(columns::Vector{String}, identifier::String)
function _detect_adj_column(columns::AbstractVector{String}, identifier::String)
identifier = lowercase(identifier)
for col in columns
col_lowered = lowercase(col)
if occursin("adj", col_lowered)
if occursin(identifier, col_lowered)
return col
return Symbol(col)
end
end
end
Expand All @@ -75,12 +75,16 @@ function _detect_adj_column(columns::Vector{String}, identifier::String)
for col in columns
col_lowered = lowercase(col)
if identifier == col_lowered
return col
return Symbol(col)
end
end
return error("Unable to detect '$identifier' column in columns: $columns")
end

_detect_adj_column(columns::AbstractVector{Symbol}, identifier::String) = _detect_adj_column(string.(columns), identifier)
_detect_adj_column(columns::AbstractVector{String}, identifier::Symbol) = _detect_adj_column(columns, string(identifier))
_detect_adj_column(columns::AbstractVector{Symbol}, identifier::Symbol) = _detect_adj_column(string.(columns), string(identifier))

"""
Estimate Bid-Ask spreads from OHLCV data
Expand Down Expand Up @@ -188,7 +192,7 @@ struct CSVDailyBarSource <: DataSource
csv_dir::String
asset_type::DataType
adjust_prices::Bool
assets::Vector{Asset}
assets::AbstractVector{Asset}
dict_asset_dfs::Dict{Symbol,DataFrame}
csv_symbols::Union{Nothing,Vector{Symbol}}
market_open::Dates.CompoundPeriod
Expand Down Expand Up @@ -244,7 +248,7 @@ end

function get_bid(ds::CSVDailyBarSource, dt::DateTime, asset::Symbol)::Float64
df_bid_ask = ds.dict_asset_dfs[asset]
ix = searchsortedfirst(df_bid_ask.timestamp::Vector{DateTime}, dt)::Int64
ix = searchsortedfirst(df_bid_ask.timestamp::AbstractVector{DateTime}, dt)::Int64
if (ix == 1) || (ix > size(df_bid_ask, 1))
return NaN
else
Expand All @@ -254,7 +258,7 @@ end

function get_ask(ds::CSVDailyBarSource, dt::DateTime, asset::Symbol)::Float64
df_bid_ask = ds.dict_asset_dfs[asset]
ix = searchsortedfirst(df_bid_ask.timestamp::Vector{DateTime}, dt)::Int64
ix = searchsortedfirst(df_bid_ask.timestamp::AbstractVector{DateTime}, dt)::Int64
if (ix == 1) || (ix > size(df_bid_ask, 1))
return NaN
else
Expand All @@ -264,15 +268,15 @@ end

function get_volume(ds::CSVDailyBarSource, dt::DateTime, asset::Symbol)::Float64
df_bid_ask = ds.dict_asset_dfs[asset]
ix = searchsortedfirst(df_bid_ask.timestamp::Vector{DateTime}, dt)::Int64
ix = searchsortedfirst(df_bid_ask.timestamp::AbstractVector{DateTime}, dt)::Int64
if (ix == 1) || (ix > size(df_bid_ask, 1))
return NaN
else
return @inbounds df_bid_ask[ix, :Volume]::Float64
end
end

function _get_timestamp_start(start_dt::DateTime, v::Vector{DateTime})::keytype(v)
function _get_timestamp_start(start_dt::DateTime, v::AbstractVector{DateTime})::keytype(v)
start_ix = findfirst(>=(start_dt), v)
if start_ix === nothing
return firstindex(v)
Expand All @@ -281,7 +285,7 @@ function _get_timestamp_start(start_dt::DateTime, v::Vector{DateTime})::keytype(
end
end

function _get_timestamp_end(end_dt::DateTime, v::Vector{DateTime})::keytype(v)
function _get_timestamp_end(end_dt::DateTime, v::AbstractVector{DateTime})::keytype(v)
end_ix = findlast(<=(end_dt), v)
if end_ix === nothing
return firstindex(v)
Expand All @@ -291,7 +295,7 @@ function _get_timestamp_end(end_dt::DateTime, v::Vector{DateTime})::keytype(v)
end

function get_assets_historical_bids(
ds::CSVDailyBarSource, start_dt::DateTime, end_dt::DateTime, assets::Vector{Symbol}
ds::CSVDailyBarSource, start_dt::DateTime, end_dt::DateTime, assets::AbstractVector{Symbol}
)
# TODO: do we want historical close like qstrader uses?
prices_df = DataFrame(; timestamp=Vector{DateTime}())
Expand All @@ -301,7 +305,7 @@ function get_assets_historical_bids(
continue
end
df_bid_ask = ds.dict_asset_dfs[asset]
timestamp = df_bid_ask.timestamp::Vector{DateTime}
timestamp = df_bid_ask.timestamp::AbstractVector{DateTime}
start_ix = _get_timestamp_start(start_dt, timestamp)
end_ix = _get_timestamp_end(end_dt, timestamp)
df_bid_ask_subset = df_bid_ask[start_ix:end_ix, [:timestamp, :Bid]]::DataFrame
Expand Down
6 changes: 3 additions & 3 deletions src/data/data_handler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ function get_asset_latest_mid_price(dh::DataHandler, dt::DateTime, asset::Symbol
end

function get_assets_historical_range_close_price(
dh::DataHandler, start_dt::DateTime, end_dt::DateTime, assets::Vector{Symbol}
dh::DataHandler, start_dt::DateTime, end_dt::DateTime, assets::AbstractVector{Symbol}
)
df_prices = DataFrame(; timestamp=Vector{DateTime})
for ds in dh.data_sources
Expand Down Expand Up @@ -212,7 +212,7 @@ get_assets_historical_range_close_price(
dh::DataHandler,
start_dt::DateTime,
end_dt::DateTime,
assets::Vector{Symbol},
assets::AbstractVector{Symbol},
)
```
Expand All @@ -223,7 +223,7 @@ Parameters
- `dh::DataHandler`
- `start_dt::DateTime`
- `end_dt::DateTime`
- `asset::Vector{Symbol}`
- `asset::AbstractVector{Symbol}`
Returns
-------
Expand Down
4 changes: 2 additions & 2 deletions src/portfolio/portfolio.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Fields
- `pos_handler::PositionHandler`
- `portfolio_id::String`
- `name::String`
- `history::Vector{PortfolioEvent}`
- `history::AbstractVector{PortfolioEvent}`
"""
mutable struct Portfolio
current_dt::DateTime
Expand All @@ -24,7 +24,7 @@ mutable struct Portfolio
pos_handler::PositionHandler
portfolio_id::String
name::String
history::Vector{PortfolioEvent}
history::AbstractVector{PortfolioEvent}
function Portfolio(
start_dt::DateTime,
starting_cash::Float64,
Expand Down
4 changes: 2 additions & 2 deletions src/rebalance/buy_and_hold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ Only rebalance at the start date
Fields
------
- `start_dt::DateTime`
- `rebalances::Vector{DateTime}`
- `rebalances::AbstractVector{DateTime}`
"""
struct BuyAndHoldRebalance <: Rebalance
start_date::DateTime
end_date::DateTime
rebalances::Vector{DateTime}
rebalances::AbstractVector{DateTime}
function BuyAndHoldRebalance(start_dt::DateTime)
return new(Date(start_dt), Date(start_dt), [start_dt])
end
Expand Down
4 changes: 2 additions & 2 deletions src/rebalance/daily.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ Fields
- `start_date::Date`
- `end_date::Date`
- `market_time::DateTime`
- `rebalances::Vector{DateTime}`
- `rebalances::AbstractVector{DateTime}`
"""
struct DailyRebalance <: Rebalance
start_date::Date
end_date::Date
market_time::Union{Hour,Minute,Dates.CompoundPeriod}
rebalances::Vector{DateTime}
rebalances::AbstractVector{DateTime}
function DailyRebalance(
start_date::Date,
end_date::Date,
Expand Down
4 changes: 2 additions & 2 deletions src/rebalance/monthly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ Fields
- `start_date::Date`
- `end_date::Date`
- `market_time::DateTime`
- `rebalances::Vector{DateTime}`
- `rebalances::AbstractVector{DateTime}`
"""
struct MonthlyRebalance <: Rebalance
start_date::Date
end_date::Date
market_time::Union{Hour,Minute,Dates.CompoundPeriod}
rebalances::Vector{DateTime}
rebalances::AbstractVector{DateTime}
function MonthlyRebalance(
start_date::Date,
end_date::Date,
Expand Down
4 changes: 2 additions & 2 deletions src/rebalance/weekly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ Fields
- `start_date::Date`
- `end_date::Date`
- `market_time::DateTime`
- `rebalances::Vector{DateTime}`
- `rebalances::AbstractVector{DateTime}`
"""
struct WeeklyRebalance <: Rebalance
start_date::Date
end_date::Date
market_time::Union{Hour,Minute,Dates.CompoundPeriod}
rebalances::Vector{DateTime}
rebalances::AbstractVector{DateTime}
function WeeklyRebalance(
start_date::Date,
end_date::Date,
Expand Down
17 changes: 17 additions & 0 deletions test/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,23 @@
@test size(unique_events) == size(unique(unique_events))
end

@testset "_detect_adj_column" begin
columns = ["open", "high", "low", "close", "volume"]
println(SaguaroTrader._detect_adj_column(columns, "close"))
@test SaguaroTrader._detect_adj_column(columns, "close") == :close
@test SaguaroTrader._detect_adj_column(columns, "open") == :open

columns = ["open", "high", "low", "close", "volume", "adj_open", "adj_close"]
@test SaguaroTrader._detect_adj_column(columns, "close") == :adj_close
@test SaguaroTrader._detect_adj_column(columns, "open") == :adj_open


@test SaguaroTrader._detect_adj_column(columns, :close) == :adj_close
@test SaguaroTrader._detect_adj_column(Symbol.(columns), "close") == :adj_close
@test SaguaroTrader._detect_adj_column(Symbol.(columns), :close) == :adj_close

end

########################################################################
# DataHandler
########################################################################
Expand Down
6 changes: 5 additions & 1 deletion test/order.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
ord1 = Order(DateTime(2022, 7, 1), 100.0, Equity(:AMD))
ord2 = Order(DateTime(2022, 7, 1), 100.0, Equity(:AMD))
ord3 = Order(DateTime(2022, 7, 1), -100.0, Equity(:AMD))
ord4 = Order(DateTime(2023, 7, 1), 100.0, Equity(:AMD))
ord5 = Order(DateTime(2023, 7, 1), 100.0, Equity(:NVDA))

@test SaguaroTrader.equal_orders(ord1, ord2)
@test !SaguaroTrader.equal_orders(ord1, ord3)
@test !SaguaroTrader.equal_orders(ord1, ord3) # different quantity
@test !SaguaroTrader.equal_orders(ord1, ord4) # different date
@test !SaguaroTrader.equal_orders(ord1, ord5) # different asset
@test ord1.direction == 1
@test ord3.direction == -1
end

0 comments on commit 3aef664

Please sign in to comment.