From f6fb882fa97d3447a7019c18775f5b5aa0a2d098 Mon Sep 17 00:00:00 2001
From: MBounouar
Date: Fri, 3 Sep 2021 14:18:24 +0200
Subject: [PATCH] removed deprecated api (#46)

---
 src/zipline/_protocol.pyx | 223 +-----------
 src/zipline/algorithm.py | 138 +-------
 src/zipline/assets/_assets.pyx | 46 +--
 src/zipline/data/data_portal.py | 136 +-------
 src/zipline/data/minute_bars.py | 36 +-
 src/zipline/errors.py | 8 -
 src/zipline/gens/tradesimulation.py | 17 +-
 src/zipline/protocol.py | 160 +--------
 tests/test_api_shim.py | 513 +---------------------------
 tests/test_assets.py | 24 --
 tests/test_data_portal.py | 43 ---
 tests/test_history.py | 209 ++++--------
 tests/test_tradesimulation.py | 1 -
 13 files changed, 111 insertions(+), 1443 deletions(-)

diff --git a/src/zipline/_protocol.pyx b/src/zipline/_protocol.pyx
index e45a2a7e2f..d1bc12da8f 100644
--- a/src/zipline/_protocol.pyx
+++ b/src/zipline/_protocol.pyx
@@ -139,18 +139,12 @@ cdef class BarData:
 restrictions : zipline.finance.asset_restrictions.Restrictions
 Object that combines and returns restricted list information from
 multiple sources
- universe_func : callable, optional
- Function which returns the current 'universe'. This is for
- backwards compatibility with older API concepts.
 """
 cdef object data_portal
 cdef object simulation_dt_func
 cdef object data_frequency
 cdef object restrictions
 cdef dict _views
- cdef object _universe_func
- cdef object _last_calculated_universe
- cdef object _universe_last_updated_at
 cdef bool _daily_mode
 cdef object _trading_calendar
 cdef object _is_restricted
@@ -158,7 +152,7 @@ cdef class BarData:
 cdef bool _adjust_minutes

 def __init__(self, data_portal, simulation_dt_func, data_frequency,
- trading_calendar, restrictions, universe_func=None):
+ trading_calendar, restrictions):
 self.data_portal = data_portal
 self.simulation_dt_func = simulation_dt_func
 self.data_frequency = data_frequency
@@ -166,51 +160,11 @@ cdef class BarData:

 self._daily_mode = (self.data_frequency == "daily")

- self._universe_func = universe_func
- self._last_calculated_universe = None
- self._universe_last_updated_at = None
-
 self._adjust_minutes = False

 self._trading_calendar = trading_calendar
 self._is_restricted = restrictions.is_restricted

- cdef _get_equity_price_view(self, asset):
- """
- Returns a DataPortalSidView for the given asset. Used to support the
- data[sid(N)] public API. Not needed if DataPortal is used standalone.
-
- Parameters
- ----------
- asset : Asset
- Asset that is being queried.
-
- Returns
- -------
- SidView : Accessor into the given asset's data.
- """
- try:
- self._warn_deprecated("`data[sid(N)]` is deprecated. Use "
- "`data.current`.")
- view = self._views[asset]
- except KeyError:
- try:
- asset = self.data_portal.asset_finder.retrieve_asset(asset)
- except ValueError:
- # assume fetcher
- pass
- view = self._views[asset] = self._create_sid_view(asset)
-
- return view
-
- cdef _create_sid_view(self, asset):
- return SidView(
- asset,
- self.data_portal,
- self.simulation_dt_func,
- self.data_frequency
- )
-
 cdef _get_current_minute(self):
 """
 Internal utility method to get the current simulation time. 
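Migration note: the dict-style bar access removed above has the replacement named by the deleted deprecation message itself ("`data[sid(N)]` is deprecated. Use `data.current`."). A minimal sketch of the supported form; the ticker, fields, and variable names are illustrative, not part of this patch:

    from zipline.api import symbol

    def initialize(context):
        # any Asset works here; AAPL is only an example
        context.asset = symbol("AAPL")

    def handle_data(context, data):
        # old, removed form: data[sid(N)].price or data[sid(N)]["price"]
        price = data.current(context.asset, "price")
        # data.current also accepts a list of fields and returns a Series
        bar = data.current(context.asset, ["open", "high", "low", "close", "volume"])
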
@@ -742,181 +696,6 @@ cdef class BarData: self.current_session ) - ################# - # OLD API SUPPORT - ################# - cdef _calculate_universe(self): - if self._universe_func is None: - return [] - - simulation_dt = self.simulation_dt_func() - if self._last_calculated_universe is None or \ - self._universe_last_updated_at != simulation_dt: - self._last_calculated_universe = self._universe_func() - self._universe_last_updated_at = simulation_dt - - return self._last_calculated_universe - - def __iter__(self): - self._warn_deprecated("Iterating over the assets in `data` is " - "deprecated.") - for asset in self._calculate_universe(): - yield asset - - def __contains__(self, asset): - self._warn_deprecated("Checking whether an asset is in data is " - "deprecated.") - universe = self._calculate_universe() - return asset in universe - - def items(self): - self._warn_deprecated("Iterating over the assets in `data` is " - "deprecated.") - return [(asset, self[asset]) for asset in self._calculate_universe()] - - def iteritems(self): - self._warn_deprecated("Iterating over the assets in `data` is " - "deprecated.") - for asset in self._calculate_universe(): - yield asset, self[asset] - - def __len__(self): - self._warn_deprecated("Iterating over the assets in `data` is " - "deprecated.") - - return len(self._calculate_universe()) - - def keys(self): - self._warn_deprecated("Iterating over the assets in `data` is " - "deprecated.") - - return list(self._calculate_universe()) - - def iterkeys(self): - return iter(self.keys()) - - def __getitem__(self, name): - return self._get_equity_price_view(name) - - cdef _warn_deprecated(self, msg): - warnings.warn( - msg, - category=ZiplineDeprecationWarning, - stacklevel=1 - ) - -cdef class SidView: - cdef object asset - cdef object data_portal - cdef object simulation_dt_func - cdef object data_frequency - - """ - This class exists to temporarily support the deprecated data[sid(N)] API. - """ - def __init__(self, asset, data_portal, simulation_dt_func, data_frequency): - """ - Parameters - --------- - asset : Asset - The asset for which the instance retrieves data. - - data_portal : DataPortal - Provider for bar pricing data. - - simulation_dt_func: function - Function which returns the current simulation time. - This is usually bound to a method of TradingSimulation. - - data_frequency: string - The frequency of the bar data; i.e. 
whether the data is - 'daily' or 'minute' bars - """ - self.asset = asset - self.data_portal = data_portal - self.simulation_dt_func = simulation_dt_func - self.data_frequency = data_frequency - - def __getattr__(self, column): - # backwards compatibility code for Q1 API - if column == "close_price": - column = "close" - elif column == "open_price": - column = "open" - elif column == "dt": - return self.dt - elif column == "datetime": - return self.datetime - elif column == "sid": - return self.sid - - return self.data_portal.get_spot_value( - self.asset, - column, - self.simulation_dt_func(), - self.data_frequency - ) - - def __contains__(self, column): - return self.data_portal.contains(self.asset, column) - - def __getitem__(self, column): - return self.__getattr__(column) - - property sid: - def __get__(self): - return self.asset - - property dt: - def __get__(self): - return self.datetime - - property datetime: - def __get__(self): - return self.data_portal.get_last_traded_dt( - self.asset, - self.simulation_dt_func(), - self.data_frequency) - - property current_dt: - def __get__(self): - return self.simulation_dt_func() - - def mavg(self, num_minutes): - self._warn_deprecated("The `mavg` method is deprecated.") - return self.data_portal.get_simple_transform( - self.asset, "mavg", self.simulation_dt_func(), - self.data_frequency, bars=num_minutes - ) - - def stddev(self, num_minutes): - self._warn_deprecated("The `stddev` method is deprecated.") - return self.data_portal.get_simple_transform( - self.asset, "stddev", self.simulation_dt_func(), - self.data_frequency, bars=num_minutes - ) - - def vwap(self, num_minutes): - self._warn_deprecated("The `vwap` method is deprecated.") - return self.data_portal.get_simple_transform( - self.asset, "vwap", self.simulation_dt_func(), - self.data_frequency, bars=num_minutes - ) - - def returns(self): - self._warn_deprecated("The `returns` method is deprecated.") - return self.data_portal.get_simple_transform( - self.asset, "returns", self.simulation_dt_func(), - self.data_frequency - ) - - cdef _warn_deprecated(self, msg): - warnings.warn( - msg, - category=ZiplineDeprecationWarning, - stacklevel=1 - ) - cdef class InnerPosition: """The real values of a position. diff --git a/src/zipline/algorithm.py b/src/zipline/algorithm.py index 09071f82e6..7dba12d022 100644 --- a/src/zipline/algorithm.py +++ b/src/zipline/algorithm.py @@ -32,7 +32,6 @@ AttachPipelineAfterInitialize, CannotOrderDelistedAsset, DuplicatePipelineName, - HistoryInInitialize, IncompatibleCommissionModel, IncompatibleSlippageModel, NoSuchPipeline, @@ -269,9 +268,7 @@ def __init__( asset_finder is not None and asset_finder is not data_portal.asset_finder ): - raise ValueError( - "Inconsistent asset_finders in TradingAlgorithm()" - ) + raise ValueError("Inconsistent asset_finders in TradingAlgorithm()") self.asset_finder = data_portal.asset_finder self.benchmark_returns = benchmark_returns @@ -407,8 +404,6 @@ def noop(*args, **kwargs): self.restrictions = NoRestrictions() - self._backwards_compat_universe = None - def init_engine(self, get_loader): """ Construct and store a PipelineEngine from loader. @@ -490,9 +485,7 @@ def _create_clock(self): """ If the clock property is not set, then create one based on frequency. 
""" - trading_o_and_c = self.trading_calendar.schedule.loc[ - self.sim_params.sessions - ] + trading_o_and_c = self.trading_calendar.schedule.loc[self.sim_params.sessions] market_closes = trading_o_and_c["market_close"] minutely_emission = False @@ -536,9 +529,7 @@ def _create_clock(self): def _create_benchmark_source(self): if self.benchmark_sid is not None: - benchmark_asset = self.asset_finder.retrieve_asset( - self.benchmark_sid - ) + benchmark_asset = self.asset_finder.retrieve_asset(self.benchmark_sid) benchmark_returns = None else: @@ -587,22 +578,11 @@ def _create_generator(self, sim_params): self._create_clock(), benchmark_source, self.restrictions, - universe_func=self._calculate_universe, ) metrics_tracker.handle_start_of_simulation(benchmark_source) return self.trading_client.transform() - def _calculate_universe(self): - # this exists to provide backwards compatibility for older, - # deprecated APIs, particularly around the iterability of - # BarData (ie, 'for sid in data`). - if self._backwards_compat_universe is None: - self._backwards_compat_universe = self.asset_finder.retrieve_all( - self.asset_finder.sids - ) - return self._backwards_compat_universe - def compute_eager_pipelines(self): """ Compute any pipelines attached with eager=True. @@ -663,9 +643,7 @@ def _create_daily_stats(self, perfs): for perf in perfs: if "daily_perf" in perf: - perf["daily_perf"].update( - perf["daily_perf"].pop("recorded_vars") - ) + perf["daily_perf"].update(perf["daily_perf"].pop("recorded_vars")) perf["daily_perf"].update(perf["cumulative_risk_metrics"]) daily_perfs.append(perf["daily_perf"]) else: @@ -875,9 +853,7 @@ def fetch_csv( ) # ingest this into dataportal - self.data_portal.handle_extra_source( - csv_data_source.df, self.sim_params - ) + self.data_portal.handle_extra_source(csv_data_source.df, self.sim_params) return csv_data_source @@ -963,9 +939,7 @@ def schedule_function( else: raise ScheduleFunctionInvalidCalendar( given_calendar=calendar, - allowed_calendars=( - "[calendars.US_EQUITIES, calendars.US_FUTURES]" - ), + allowed_calendars=("[calendars.US_EQUITIES, calendars.US_FUTURES]"), ) self.add_event( @@ -1189,22 +1163,16 @@ def _calculate_order_value_amount(self, asset, value): " {1}.".format(asset.symbol, asset.end_date) ) else: - last_price = self.trading_client.current_data.current( - asset, "price" - ) + last_price = self.trading_client.current_data.current(asset, "price") if np.isnan(last_price): raise CannotOrderDelistedAsset( msg="Cannot order {0} on {1} as there is no last " - "price for the security.".format( - asset.symbol, self.datetime - ) + "price for the security.".format(asset.symbol, self.datetime) ) if tolerant_equals(last_price, 0): - zero_message = "Price of 0 for {psid}; can't infer value".format( - psid=asset - ) + zero_message = "Price of 0 for {psid}; can't infer value".format(psid=asset) if self.logger: self.logger.debug(zero_message) # Don't place any order @@ -1241,9 +1209,7 @@ def _can_order_asset(self, asset): @api_method @disallowed_in_before_trading_start(OrderInBeforeTradingStart()) - def order( - self, asset, amount, limit_price=None, stop_price=None, style=None - ): + def order(self, asset, amount, limit_price=None, stop_price=None, style=None): """Place an order for a fixed number of shares. Parameters @@ -1297,9 +1263,7 @@ def _calculate_order( amount = self.round_order(amount) # Raises a ZiplineError if invalid parameters are detected. 
- self.validate_order_params( - asset, amount, limit_price, stop_price, style - ) + self.validate_order_params(asset, amount, limit_price, stop_price, style) # Convert deprecated limit_price and stop_price parameters to use # ExecutionStyle objects. @@ -1320,9 +1284,7 @@ def round_order(amount): """ return int(round_if_near_integer(amount)) - def validate_order_params( - self, asset, amount, limit_price, stop_price, style - ): + def validate_order_params(self, asset, amount, limit_price, stop_price, style): """ Helper method for validating parameters to the order API function. @@ -1355,9 +1317,7 @@ def validate_order_params( ) @staticmethod - def __convert_order_params_for_blotter( - asset, limit_price, stop_price, style - ): + def __convert_order_params_for_blotter(asset, limit_price, stop_price, style): """ Helper method for converting deprecated limit_price and stop_price arguments into ExecutionStyle instances. @@ -1379,9 +1339,7 @@ def __convert_order_params_for_blotter( @api_method @disallowed_in_before_trading_start(OrderInBeforeTradingStart()) - def order_value( - self, asset, value, limit_price=None, stop_price=None, style=None - ): + def order_value(self, asset, value, limit_price=None, stop_price=None, style=None): """ Place an order for a fixed amount of money. @@ -1627,11 +1585,8 @@ def set_symbol_lookup_date(self, dt): except TypeError: self._symbol_lookup_date = pd.Timestamp(dt).tz_convert("UTC") except ValueError: - raise UnsupportedDatetimeFormat( - input=dt, method="set_symbol_lookup_date" - ) + raise UnsupportedDatetimeFormat(input=dt, method="set_symbol_lookup_date") - # Remain backwards compatibility @property def data_frequency(self): return self.sim_params.data_frequency @@ -1925,9 +1880,7 @@ def batch_market_order(self, share_counts): """ style = MarketOrder() order_args = [ - (asset, amount, style) - for (asset, amount) in share_counts.items() - if amount + (asset, amount, style) for (asset, amount) in share_counts.items() if amount ] return self.blotter.batch_order(order_args) @@ -1997,55 +1950,6 @@ def cancel_order(self, order_param): self.blotter.cancel(order_id) - @api_method - @require_initialized(HistoryInInitialize()) - def history(self, bar_count, frequency, field, ffill=True): - """DEPRECATED: use ``data.history`` instead.""" - warnings.warn( - "The `history` method is deprecated. 
Use `data.history` instead.", - category=ZiplineDeprecationWarning, - stacklevel=4, - ) - - return self.get_history_window( - bar_count, frequency, self._calculate_universe(), field, ffill - ) - - def get_history_window(self, bar_count, frequency, assets, field, ffill): - if not self._in_before_trading_start: - return self.data_portal.get_history_window( - assets, - self.datetime, - bar_count, - frequency, - field, - self.data_frequency, - ffill, - ) - else: - # If we are in before_trading_start, we need to get the window - # as of the previous market minute - adjusted_dt = self.trading_calendar.previous_minute(self.datetime) - - window = self.data_portal.get_history_window( - assets, - adjusted_dt, - bar_count, - frequency, - field, - self.data_frequency, - ffill, - ) - - # Get the adjustments between the last market minute and the - # current before_trading_start dt and apply to the window - adjs = self.data_portal.get_adjustments( - assets, field, adjusted_dt, self.datetime - ) - window = window * adjs - - return window - #################### # Account Controls # #################### @@ -2383,9 +2287,7 @@ def run_pipeline(self, pipeline, start_session, chunksize): # until chunksize days of data have been loaded. sim_end_session = self.sim_params.end_session - end_loc = min( - start_date_loc + chunksize, sessions.get_loc(sim_end_session) - ) + end_loc = min(start_date_loc + chunksize, sessions.get_loc(sim_end_session)) end_session = sessions[end_loc] @@ -2424,11 +2326,7 @@ def all_api_methods(cls): """ Return a list of all the TradingAlgorithm API methods. """ - return [ - fn - for fn in vars(cls).values() - if getattr(fn, "is_api_method", False) - ] + return [fn for fn in vars(cls).values() if getattr(fn, "is_api_method", False)] # Map from calendar name to default domain for that calendar. diff --git a/src/zipline/assets/_assets.pyx b/src/zipline/assets/_assets.pyx index b79786cf6b..3034c2c6c7 100644 --- a/src/zipline/assets/_assets.pyx +++ b/src/zipline/assets/_assets.pyx @@ -267,40 +267,7 @@ cdef class Equity(Asset): Asset subclass representing partial ownership of a company, trust, or partnership. """ - - property security_start_date: - """ - DEPRECATION: This property should be deprecated and is only present for - backwards compatibility - """ - def __get__(self): - warnings.warn("The security_start_date property will soon be " - "retired. Please use the start_date property instead.", - DeprecationWarning) - return self.start_date - - property security_end_date: - """ - DEPRECATION: This property should be deprecated and is only present for - backwards compatibility - """ - def __get__(self): - warnings.warn("The security_end_date property will soon be " - "retired. Please use the end_date property instead.", - DeprecationWarning) - return self.end_date - - property security_name: - """ - DEPRECATION: This property should be deprecated and is only present for - backwards compatibility - """ - def __get__(self): - warnings.warn("The security_name property will soon be " - "retired. Please use the asset_name property instead.", - DeprecationWarning) - return self.asset_name - + pass @cython.embedsignature(False) cdef class Future(Asset): @@ -361,17 +328,6 @@ cdef class Future(Asset): else: self.auto_close_date = min(notice_date, expiration_date) - property multiplier: - """ - DEPRECATION: This property should be deprecated and is only present for - backwards compatibility - """ - def __get__(self): - warnings.warn("The multiplier property will soon be " - "retired. 
Please use the price_multiplier property instead.", - DeprecationWarning) - return self.price_multiplier - cpdef __reduce__(self): """ Function used by pickle to determine how to serialize/deserialize this diff --git a/src/zipline/data/data_portal.py b/src/zipline/data/data_portal.py index 62e3f35c74..ab4eb5f330 100644 --- a/src/zipline/data/data_portal.py +++ b/src/zipline/data/data_portal.py @@ -34,7 +34,10 @@ ContinuousFutureSessionBarReader, ContinuousFutureMinuteBarReader, ) -from zipline.assets.roll_finder import CalendarRollFinder, VolumeRollFinder +from zipline.assets.roll_finder import ( + CalendarRollFinder, + VolumeRollFinder, +) from zipline.data.dispatch_bar_reader import ( AssetDispatchMinuteBarReader, AssetDispatchSessionBarReader, @@ -49,12 +52,9 @@ MinuteHistoryLoader, ) from zipline.data.bar_reader import NoDataOnDate -from zipline.utils.math_utils import nansum, nanmean, nanstd -from zipline.utils.memoize import remember_last, weak_lru_cache -from zipline.utils.pandas_utils import ( - normalize_date, - timedelta_to_integral_minutes, -) + +from zipline.utils.memoize import remember_last +from zipline.utils.pandas_utils import normalize_date from zipline.errors import HistoryWindowStartsBeforeData @@ -1212,128 +1212,6 @@ def get_fetcher_assets(self, dt): else: return [assets] if isinstance(assets, Asset) else [] - # cache size picked somewhat loosely. this code exists purely to - # handle deprecated API. - @weak_lru_cache(20) - def _get_minute_count_for_transform(self, ending_minute, days_count): - # This function works in three steps. - # Step 1. Count the minutes from ``ending_minute`` to the start of its - # session. - # Step 2. Count the minutes from the prior ``days_count - 1`` sessions. - # Step 3. Return the sum of the results from steps (1) and (2). - - # Example (NYSE Calendar) - # ending_minute = 2016-12-28 9:40 AM US/Eastern - # days_count = 3 - # Step 1. Calculate that there are 10 minutes in the ending session. - # Step 2. Calculate that there are 390 + 210 = 600 minutes in the prior - # two sessions. (Prior sessions are 2015-12-23 and 2015-12-24.) - # 2015-12-24 is a half day. - # Step 3. Return 600 + 10 = 610. - - cal = self.trading_calendar - - ending_session = cal.minute_to_session_label( - ending_minute, - direction="none", # It's an error to pass a non-trading minute. - ) - - # Assume that calendar days are always full of contiguous minutes, - # which means we can just take 1 + (number of minutes between the last - # minute and the start of the session). We add one so that we include - # the ending minute in the total. - ending_session_minute_count = ( - timedelta_to_integral_minutes( - ending_minute - cal.open_and_close_for_session(ending_session)[0] - ) - + 1 - ) - - if days_count == 1: - # We just need sessions for the active day. - return ending_session_minute_count - - # XXX: We're subtracting 2 here to account for two offsets: - # 1. We only want ``days_count - 1`` sessions, since we've already - # accounted for the ending session above. - # 2. The API of ``sessions_window`` is to return one more session than - # the requested number. I don't think any consumers actually want - # that behavior, but it's the tested and documented behavior right - # now, so we have to request one less session than we actually want. 
- completed_sessions = cal.sessions_window( - cal.previous_session_label(ending_session), - 2 - days_count, - ) - - completed_sessions_minute_count = ( - self.trading_calendar.minutes_count_for_sessions_in_range( - completed_sessions[0], completed_sessions[-1] - ) - ) - return ending_session_minute_count + completed_sessions_minute_count - - def get_simple_transform( - self, asset, transform_name, dt, data_frequency, bars=None - ): - if transform_name == "returns": - # returns is always calculated over the last 2 days, regardless - # of the simulation's data frequency. - hst = self.get_history_window( - [asset], - dt, - 2, - "1d", - "price", - data_frequency, - ffill=True, - )[asset] - - return (hst.iloc[-1] - hst.iloc[0]) / hst.iloc[0] - - if bars is None: - raise ValueError("bars cannot be None!") - - if data_frequency == "minute": - freq_str = "1m" - calculated_bar_count = int(self._get_minute_count_for_transform(dt, bars)) - else: - freq_str = "1d" - calculated_bar_count = bars - - price_arr = self.get_history_window( - [asset], - dt, - calculated_bar_count, - freq_str, - "price", - data_frequency, - ffill=True, - )[asset] - - if transform_name == "mavg": - return nanmean(price_arr) - elif transform_name == "stddev": - return nanstd(price_arr, ddof=1) - elif transform_name == "vwap": - volume_arr = self.get_history_window( - [asset], - dt, - calculated_bar_count, - freq_str, - "volume", - data_frequency, - ffill=True, - )[asset] - - vol_sum = nansum(volume_arr) - - try: - ret = nansum(price_arr * volume_arr) / vol_sum - except ZeroDivisionError: - ret = np.nan - - return ret - def get_current_future_chain(self, continuous_future, dt): """ Retrieves the future chain for the contract at the given `dt` according diff --git a/src/zipline/data/minute_bars.py b/src/zipline/data/minute_bars.py index 69ae703dad..0ca42a9161 100644 --- a/src/zipline/data/minute_bars.py +++ b/src/zipline/data/minute_bars.py @@ -301,28 +301,16 @@ def write(self, rootdir): end_session : datetime 'YYYY-MM-DD' formatted representation of the last trading session in the data set. - - Deprecated, but included for backwards compatibility: - - first_trading_day : string - 'YYYY-MM-DD' formatted representation of the first trading day - available in the dataset. - market_opens : list - List of int64 values representing UTC market opens as - minutes since epoch. - market_closes : list - List of int64 values representing UTC market closes as - minutes since epoch. 
""" - calendar = self.calendar - slicer = calendar.schedule.index.slice_indexer( - self.start_session, - self.end_session, - ) - schedule = calendar.schedule[slicer] - market_opens = schedule.market_open - market_closes = schedule.market_close + # calendar = self.calendar + # slicer = calendar.schedule.index.slice_indexer( + # self.start_session, + # self.end_session, + # ) + # schedule = calendar.schedule[slicer] + # market_opens = schedule.market_open + # market_closes = schedule.market_close metadata = { "version": self.version, @@ -332,14 +320,6 @@ def write(self, rootdir): "calendar_name": self.calendar.name, "start_session": str(self.start_session.date()), "end_session": str(self.end_session.date()), - # Write these values for backwards compatibility - "first_trading_day": str(self.start_session.date()), - "market_opens": ( - market_opens.values.astype("datetime64[m]").astype(np.int64).tolist() - ), - "market_closes": ( - market_closes.values.astype("datetime64[m]").astype(np.int64).tolist() - ), } with open(self.metadata_path(rootdir), "w+") as fp: json.dump(metadata, fp) diff --git a/src/zipline/errors.py b/src/zipline/errors.py index 0614240e21..5ddbe91a55 100644 --- a/src/zipline/errors.py +++ b/src/zipline/errors.py @@ -306,14 +306,6 @@ class IncompatibleHistoryFrequency(ZiplineError): """.strip() -class HistoryInInitialize(ZiplineError): - """ - Raised when an algorithm calls history() in initialize. - """ - - msg = "history() should only be called in handle_data()" - - class OrderInBeforeTradingStart(ZiplineError): """ Raised when an algorithm calls an order method in before_trading_start. diff --git a/src/zipline/gens/tradesimulation.py b/src/zipline/gens/tradesimulation.py index 79ae5a2ee0..ad072ec3cf 100644 --- a/src/zipline/gens/tradesimulation.py +++ b/src/zipline/gens/tradesimulation.py @@ -41,7 +41,6 @@ def __init__( clock, benchmark_source, restrictions, - universe_func, ): # ============== @@ -63,7 +62,7 @@ def __init__( # This object is the way that user algorithms interact with OHLCV data, # fetcher data, and some API methods like `data.can_trade`. - self.current_data = self._create_bar_data(universe_func) + self.current_data = self._create_bar_data() # We don't have a datetime for the current snapshot until we # receive a message. @@ -88,14 +87,13 @@ def inject_algo_dt(record): def get_simulation_dt(self): return self.simulation_dt - def _create_bar_data(self, universe_func): + def _create_bar_data(self): return BarData( data_portal=self.data_portal, simulation_dt_func=self.get_simulation_dt, data_frequency=self.sim_params.data_frequency, trading_calendar=self.algo.trading_calendar, restrictions=self.restrictions, - universe_func=universe_func, ) # TODO: simplify @@ -176,14 +174,11 @@ def once_a_day( # handle any splits that impact any positions or any open orders. assets_we_care_about = ( - metrics_tracker.positions.keys() - | algo.blotter.open_orders.keys() + metrics_tracker.positions.keys() | algo.blotter.open_orders.keys() ) if assets_we_care_about: - splits = data_portal.get_splits( - assets_we_care_about, midnight_dt - ) + splits = data_portal.get_splits(assets_we_care_about, midnight_dt) if splits: algo.blotter.process_splits(splits) metrics_tracker.handle_splits(splits) @@ -296,9 +291,7 @@ def past_auto_close_date(asset): # would not be processed until the first bar of the next day. 
blotter = algo.blotter assets_to_cancel = [ - asset - for asset in blotter.open_orders - if past_auto_close_date(asset) + asset for asset in blotter.open_orders if past_auto_close_date(asset) ] for asset in assets_to_cancel: blotter.cancel_all_orders_for_asset(asset) diff --git a/src/zipline/protocol.py b/src/zipline/protocol.py index c1a9c51a11..72a9b7b6ac 100644 --- a/src/zipline/protocol.py +++ b/src/zipline/protocol.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from warnings import warn - import pandas as pd from .assets import Asset @@ -110,58 +108,8 @@ def to_series(self, index=None): return pd.Series(self.__dict__, index=index) -def _deprecated_getitem_method(name, attrs): - """Create a deprecated ``__getitem__`` method that tells users to use - getattr instead. - - Parameters - ---------- - name : str - The name of the object in the warning message. - attrs : iterable[str] - The set of allowed attributes. - - Returns - ------- - __getitem__ : callable[any, str] - The ``__getitem__`` method to put in the class dict. - """ - attrs = frozenset(attrs) - msg = ( - "'{name}[{attr!r}]' is deprecated, please use" - " '{name}.{attr}' instead" - ) - - def __getitem__(self, key): - """``__getitem__`` is deprecated, please use attribute access instead.""" - warn(msg.format(name=name, attr=key), DeprecationWarning, stacklevel=2) - if key in attrs: - return getattr(self, key) - raise KeyError(key) - - return __getitem__ - - class Order(Event): - # If you are adding new attributes, don't update this set. This method - # is deprecated to normal attribute access so we don't want to encourage - # new usages. - __getitem__ = _deprecated_getitem_method( - "order", - { - "dt", - "sid", - "amount", - "stop", - "limit", - "id", - "filled", - "commission", - "stop_reached", - "limit_reached", - "created", - }, - ) + pass class Portfolio(object): @@ -211,24 +159,6 @@ def __setattr__(self, attr, value): def __repr__(self): return "Portfolio({0})".format(self.__dict__) - # If you are adding new attributes, don't update this set. This method - # is deprecated to normal attribute access so we don't want to encourage - # new usages. - __getitem__ = _deprecated_getitem_method( - "portfolio", - { - "capital_used", - "starting_cash", - "portfolio_value", - "pnl", - "returns", - "cash", - "positions", - "start_date", - "positions_value", - }, - ) - @property def current_portfolio_weights(self): """ @@ -242,9 +172,7 @@ def current_portfolio_weights(self): position_values = pd.Series( { asset: ( - position.last_sale_price - * position.amount - * asset.price_multiplier + position.last_sale_price * position.amount * asset.price_multiplier ) for asset, position in self.positions.items() }, @@ -287,32 +215,6 @@ def __setattr__(self, attr, value): def __repr__(self): return "Account({0})".format(self.__dict__) - # If you are adding new attributes, don't update this set. This method - # is deprecated to normal attribute access so we don't want to encourage - # new usages. 
- __getitem__ = _deprecated_getitem_method( - "account", - { - "settled_cash", - "accrued_interest", - "buying_power", - "equity_with_loan", - "total_positions_value", - "total_positions_exposure", - "regt_equity", - "regt_margin", - "initial_margin_requirement", - "maintenance_margin_requirement", - "available_funds", - "excess_liquidity", - "cushion", - "day_trades_remaining", - "leverage", - "net_leverage", - "net_liquidation", - }, - ) - class Position(object): """ @@ -361,49 +263,6 @@ def __repr__(self): ) } - # If you are adding new attributes, don't update this set. This method - # is deprecated to normal attribute access so we don't want to encourage - # new usages. - __getitem__ = _deprecated_getitem_method( - "position", - { - "sid", - "amount", - "cost_basis", - "last_sale_price", - "last_sale_date", - }, - ) - - -# Copied from Position and renamed. This is used to handle cases where a user -# does something like `context.portfolio.positions[100]` instead of -# `context.portfolio.positions[sid(100)]`. -class _DeprecatedSidLookupPosition(object): - def __init__(self, sid): - self.sid = sid - self.amount = 0 - self.cost_basis = 0.0 # per share - self.last_sale_price = 0.0 - self.last_sale_date = None - - def __repr__(self): - return "_DeprecatedSidLookupPosition({0})".format(self.__dict__) - - # If you are adding new attributes, don't update this set. This method - # is deprecated to normal attribute access so we don't want to encourage - # new usages. - __getitem__ = _deprecated_getitem_method( - "position", - { - "sid", - "amount", - "cost_basis", - "last_sale_price", - "last_sale_date", - }, - ) - class Positions(dict): """A dict-like object containing the algorithm's current positions.""" @@ -411,15 +270,8 @@ class Positions(dict): def __missing__(self, key): if isinstance(key, Asset): return Position(InnerPosition(key)) - elif isinstance(key, int): - warn( - "Referencing positions by integer is deprecated." - " Use an asset instead." 
- ) - else: - warn( - "Position lookup expected a value of type Asset but got {0}" - " instead.".format(type(key).__name__) - ) - return _DeprecatedSidLookupPosition(key) + raise ValueError( + "Position lookup expected a value of type Asset" + f" but got {type(key).__name__} instead" + ) diff --git a/tests/test_api_shim.py b/tests/test_api_shim.py index 81f3b9c87b..da1f44757a 100644 --- a/tests/test_api_shim.py +++ b/tests/test_api_shim.py @@ -1,123 +1,11 @@ -import warnings - -from mock import patch -import numpy as np import pandas as pd -from zipline.finance.trading import SimulationParameters -from zipline.testing import ( - MockDailyBarReader, - create_daily_df_for_asset, - create_minute_df_for_asset, - str_to_seconds, -) from zipline.testing.fixtures import ( WithCreateBarData, WithMakeAlgo, ZiplineTestCase, ) -from zipline.utils.pandas_utils import PerformanceWarning -from zipline.zipline_warnings import ZiplineDeprecationWarning - -simple_algo = """ -from zipline.api import sid, order -def initialize(context): - pass - -def handle_data(context, data): - assert sid(1) in data - assert sid(2) in data - assert len(data) == 3 - for asset in data: - pass -""" - -history_algo = """ -from zipline.api import sid, history - -def initialize(context): - context.sid1 = sid(1) - -def handle_data(context, data): - context.history_window = history(5, "1m", "volume") -""" - -history_bts_algo = """ -from zipline.api import sid, history, record - -def initialize(context): - context.sid3 = sid(3) - context.num_bts = 0 - -def before_trading_start(context, data): - context.num_bts += 1 - - # Get history at the second BTS (beginning of second day) - if context.num_bts == 2: - record(history=history(5, "1m", "volume")) - -def handle_data(context, data): - pass -""" - -simple_transforms_algo = """ -from zipline.api import sid -def initialize(context): - context.count = 0 - -def handle_data(context, data): - if context.count == 2: - context.mavg = data[sid(1)].mavg(5) - context.vwap = data[sid(1)].vwap(5) - context.stddev = data[sid(1)].stddev(5) - context.returns = data[sid(1)].returns() - - context.count += 1 -""" - -manipulation_algo = """ -def initialize(context): - context.asset1 = sid(1) - context.asset2 = sid(2) - -def handle_data(context, data): - assert len(data) == 2 - assert len(data.keys()) == 2 - assert context.asset1 in data.keys() - assert context.asset2 in data.keys() -""" - -sid_accessor_algo = """ -from zipline.api import sid - -def initialize(context): - context.asset1 = sid(1) - -def handle_data(context,data): - assert data[sid(1)].sid == context.asset1 - assert data[sid(1)]["sid"] == context.asset1 -""" - -data_items_algo = """ -from zipline.api import sid - -def initialize(context): - context.asset1 = sid(1) - context.asset2 = sid(2) - -def handle_data(context, data): - iter_list = list(data.iteritems()) - items_list = data.items() - assert iter_list == items_list -""" - -reference_missing_position_by_int_algo = """ -def initialize(context): - pass - -def handle_data(context, data): - context.portfolio.positions[24] -""" +import pytest reference_missing_position_by_unexpected_type_algo = """ def initialize(context): @@ -136,45 +24,6 @@ class TestAPIShim(WithCreateBarData, WithMakeAlgo, ZiplineTestCase): sids = ASSET_FINDER_EQUITY_SIDS = 1, 2, 3 - @classmethod - def make_equity_minute_bar_data(cls): - for sid in cls.sids: - yield sid, create_minute_df_for_asset( - cls.trading_calendar, - cls.SIM_PARAMS_START, - cls.SIM_PARAMS_END, - ) - - @classmethod - def 
make_equity_daily_bar_data(cls, country_code, sids): - for sid in sids: - yield sid, create_daily_df_for_asset( - cls.trading_calendar, - cls.SIM_PARAMS_START, - cls.SIM_PARAMS_END, - ) - - @classmethod - def make_splits_data(cls): - return pd.DataFrame( - [ - { - "effective_date": str_to_seconds("2016-01-06"), - "ratio": 0.5, - "sid": 3, - } - ] - ) - - @classmethod - def make_adjustment_writer_equity_daily_bar_reader(cls): - return MockDailyBarReader( - dates=cls.nyse_calendar.sessions_in_range( - cls.START_DATE, - cls.END_DATE, - ), - ) - @classmethod def init_class_fixtures(cls): super(TestAPIShim, cls).init_class_fixtures() @@ -191,360 +40,10 @@ def create_algo(self, code, filename=None, sim_params=None): script=code, sim_params=sim_params, algo_filename=filename ) - def test_old_new_data_api_paths(self): - """ - Test that the new and old data APIs hit the same code paths. - - We want to ensure that the old data API(data[sid(N)].field and - similar) and the new data API(data.current(sid(N), field) and - similar) hit the same code paths on the DataPortal. - """ - test_start_minute = self.trading_calendar.minutes_for_session( - self.sim_params.sessions[0] - )[1] - test_end_minute = self.trading_calendar.minutes_for_session( - self.sim_params.sessions[0] - )[-1] - bar_data = self.create_bardata( - lambda: test_end_minute, - ) - ohlcvp_fields = [ - "open", - "high", - "low" "close", - "volume", - "price", - ] - spot_value_meth = "zipline.data.data_portal.DataPortal.get_spot_value" - - def assert_get_spot_value_called(fun, field): - """ - Assert that get_spot_value was called during the execution of fun. - - Takes in a function fun and a string field. - """ - with patch(spot_value_meth) as gsv: - fun() - gsv.assert_called_with(self.asset1, field, test_end_minute, "minute") - - # Ensure that data.current(sid(n), field) has the same behaviour as - # data[sid(n)].field. - for field in ohlcvp_fields: - assert_get_spot_value_called( - lambda: getattr(bar_data[self.asset1], field), - field, - ) - assert_get_spot_value_called( - lambda: bar_data.current(self.asset1, field), - field, - ) - - history_meth = "zipline.data.data_portal.DataPortal.get_history_window" - - def assert_get_history_window_called(fun, is_legacy): - """ - Assert that get_history_window was called during fun(). - - Takes in a function fun and a boolean is_legacy. - """ - with patch(history_meth) as ghw: - fun() - # Slightly hacky, but done to get around the fact that - # history( explicitly passes an ffill param as the last arg, - # while data.history doesn't. - if is_legacy: - ghw.assert_called_with( - [self.asset1, self.asset2, self.asset3], - test_end_minute, - 5, - "1m", - "volume", - "minute", - True, - ) - else: - ghw.assert_called_with( - [self.asset1, self.asset2, self.asset3], - test_end_minute, - 5, - "1m", - "volume", - "minute", - ) - - test_sim_params = SimulationParameters( - start_session=test_start_minute, - end_session=test_end_minute, - data_frequency="minute", - trading_calendar=self.trading_calendar, - ) - - history_algorithm = self.create_algo(history_algo, sim_params=test_sim_params) - assert_get_history_window_called( - lambda: history_algorithm.run(), is_legacy=True - ) - assert_get_history_window_called( - lambda: bar_data.history( - [self.asset1, self.asset2, self.asset3], "volume", 5, "1m" - ), - is_legacy=False, - ) - - def test_sid_accessor(self): - """ - Test that we maintain backwards compat for sid access on a data object. 
- - We want to support both data[sid(24)].sid, as well as - data[sid(24)]["sid"]. Since these are deprecated and will eventually - cease to be supported, we also want to assert that we're seeing a - deprecation warning. - """ - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("ignore", PerformanceWarning) - warnings.simplefilter("ignore", RuntimeWarning, append=True) - warnings.simplefilter("default", ZiplineDeprecationWarning) - algo = self.create_algo(sid_accessor_algo) - algo.run() - - # Since we're already raising a warning on doing data[sid(x)], - # we don't want to raise an extra warning on data[sid(x)].sid. - assert 2 == len(w) - - # Check that both the warnings raised were in fact - # ZiplineDeprecationWarnings - for warning in w: - assert ZiplineDeprecationWarning == warning.category - assert "`data[sid(N)]` is deprecated. Use `data.current`." == str( - warning.message - ) - - def test_data_items(self): - """ - Test that we maintain backwards compat for data.[items | iteritems]. - - We also want to assert that we warn that iterating over the assets - in `data` is deprecated. - """ - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("ignore", PerformanceWarning) - # Divide by null warning from empyrical - warnings.simplefilter("ignore", RuntimeWarning, append=True) - warnings.simplefilter("default", ZiplineDeprecationWarning) - algo = self.create_algo(data_items_algo) - algo.run() - - assert 4 == len(w) - - for idx, warning in enumerate(w): - assert ZiplineDeprecationWarning == warning.category - if idx % 2 == 0: - assert "Iterating over the assets in `data` is deprecated." == str( - warning.message - ) - else: - assert "`data[sid(N)]` is deprecated. Use `data.current`." == str( - warning.message - ) - - def test_iterate_data(self): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("ignore", PerformanceWarning) - warnings.simplefilter("ignore", RuntimeWarning, append=True) - warnings.simplefilter("default", ZiplineDeprecationWarning) - - algo = self.create_algo(simple_algo) - algo.run() - - assert 4 == len(w) - - line_nos = [warning.lineno for warning in w] - assert 4 == len(set(line_nos)) - - for idx, warning in enumerate(w): - assert ZiplineDeprecationWarning == warning.category - - assert "" == warning.filename - assert line_nos[idx] == warning.lineno - - if idx < 2: - assert "Checking whether an asset is in data is deprecated." == str( - warning.message - ) - else: - assert "Iterating over the assets in `data` is deprecated." == str( - warning.message - ) - - def test_history(self): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("ignore", PerformanceWarning) - # Divide by null warning from empyrical - warnings.simplefilter("ignore", RuntimeWarning, append=True) - warnings.simplefilter("default", ZiplineDeprecationWarning) - - sim_params = self.sim_params.create_new( - self.sim_params.sessions[1], self.sim_params.end_session - ) - - algo = self.create_algo(history_algo, sim_params=sim_params) - algo.run() - - assert 1 == len(w) - assert ZiplineDeprecationWarning == w[0].category - assert "" == w[0].filename - assert 8 == w[0].lineno - assert ( - "The `history` method is deprecated. Use " - "`data.history` instead." 
== str(w[0].message) - ) - - def test_old_new_history_bts_paths(self): - """ - Tests that calling history in before_trading_start gets us the correct - values, which involves 1) calling data_portal.get_history_window as of - the previous market minute, 2) getting adjustments between the previous - market minute and the current time, and 3) applying those adjustments - """ - algo = self.create_algo(history_bts_algo) - algo.run() - - expected_vol_without_split = np.arange(386, 391) * 100 - expected_vol_with_split = np.arange(386, 391) * 200 - - window = algo.recorded_vars["history"] - np.testing.assert_array_equal( - window[self.asset1].values, expected_vol_without_split - ) - np.testing.assert_array_equal( - window[self.asset2].values, expected_vol_without_split - ) - np.testing.assert_array_equal( - window[self.asset3].values, expected_vol_with_split - ) - - def test_simple_transforms(self): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("ignore", PerformanceWarning) - warnings.simplefilter("ignore", RuntimeWarning, append=True) - warnings.simplefilter("default", ZiplineDeprecationWarning) - - sim_params = SimulationParameters( - start_session=self.sim_params.sessions[8], - end_session=self.sim_params.sessions[-1], - data_frequency="minute", - trading_calendar=self.trading_calendar, - ) - - algo = self.create_algo(simple_transforms_algo, sim_params=sim_params) - algo.run() - - assert 8 == len(w) - transforms = ["mavg", "vwap", "stddev", "returns"] - - for idx, line_no in enumerate(range(8, 12)): - warning1 = w[idx * 2] - warning2 = w[(idx * 2) + 1] - - assert "" == warning1.filename - assert "" == warning2.filename - - assert line_no == warning1.lineno - assert line_no == warning2.lineno - - assert "`data[sid(N)]` is deprecated. Use " "`data.current`." == str( - warning1.message - ) - assert "The `{0}` method is " "deprecated.".format( - transforms[idx] - ) == str(warning2.message) - - # now verify the transform values - # minute price - # 2016-01-11 14:31:00+00:00 1561 - # ... - # 2016-01-14 20:59:00+00:00 3119 - # 2016-01-14 21:00:00+00:00 3120 - # 2016-01-15 14:31:00+00:00 3121 - # 2016-01-15 14:32:00+00:00 3122 - # 2016-01-15 14:33:00+00:00 3123 - - # volume - # 2016-01-11 14:31:00+00:00 156100 - # ... 
- # 2016-01-14 20:59:00+00:00 311900 - # 2016-01-14 21:00:00+00:00 312000 - # 2016-01-15 14:31:00+00:00 312100 - # 2016-01-15 14:32:00+00:00 312200 - # 2016-01-15 14:33:00+00:00 312300 - - # daily price (last day built with minute data) - # 2016-01-14 00:00:00+00:00 9 - # 2016-01-15 00:00:00+00:00 3123 - - # mavg = average of all the prices = (1561 + 3123) / 2 = 2342 - # vwap = sum(price * volume) / sum(volumes) - # = 889119531400.0 / 366054600.0 - # = 2428.9259891830343 - # stddev = stddev(price, ddof=1) = 451.3435498597493 - # returns = (todayprice - yesterdayprice) / yesterdayprice - # = (3123 - 9) / 9 = 346 - assert 2342 == algo.mavg - self.assertAlmostEqual(2428.92599, algo.vwap, places=5) - self.assertAlmostEqual(451.34355, algo.stddev, places=5) - self.assertAlmostEqual(346, algo.returns) - - def test_manipulation(self): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("ignore", PerformanceWarning) - warnings.simplefilter("ignore", RuntimeWarning, append=True) - warnings.simplefilter("default", ZiplineDeprecationWarning) - - algo = self.create_algo(simple_algo) - algo.run() - - assert 4 == len(w) - - for idx, warning in enumerate(w): - assert "" == warning.filename - assert 7 + idx == warning.lineno - - if idx < 2: - assert ( - "Checking whether an asset is in data is " - "deprecated." == str(warning.message) - ) - else: - assert ( - "Iterating over the assets in `data` is " - "deprecated." == str(warning.message) - ) - - def test_reference_empty_position_by_int(self): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("ignore", RuntimeWarning, append=True) - warnings.simplefilter("default", ZiplineDeprecationWarning) - - algo = self.create_algo(reference_missing_position_by_int_algo) - algo.run() - - assert 1 == len(w) - assert ( - str(w[0].message) - == "Referencing positions by integer is deprecated. Use an asset " - "instead." - ) - def test_reference_empty_position_by_unexpected_type(self): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("ignore", RuntimeWarning, append=True) - warnings.simplefilter("default", ZiplineDeprecationWarning) - - algo = self.create_algo(reference_missing_position_by_unexpected_type_algo) + algo = self.create_algo(reference_missing_position_by_unexpected_type_algo) + with pytest.raises( + ValueError, + match="Position lookup expected a value of type Asset but got str instead", + ): algo.run() - - assert 1 == len(w) - assert ( - str(w[0].message) - == "Position lookup expected a value of type Asset but got str" - " instead." 
- ) diff --git a/tests/test_assets.py b/tests/test_assets.py index 537660b324..f1ff1d54e7 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -25,7 +25,6 @@ import sys from types import GetSetDescriptorType import uuid -import warnings from parameterized import parameterized import numpy as np @@ -1007,29 +1006,6 @@ def test_lookup_generic_multiple_symbols_across_countries(self): assert [matches] == [self.asset_finder.retrieve_asset(1)] assert missing == [] - def test_security_dates_warning(self): - - # Build an asset with an end_date - eq_end = pd.Timestamp("2012-01-01", tz="UTC") - equity_asset = Equity( - 1, - symbol="TESTEQ", - end_date=eq_end, - exchange_info=ExchangeInfo("TEST", "TEST", "??"), - ) - - # Catch all warnings - with warnings.catch_warnings(record=True) as w: - # Cause all warnings to always be triggered - warnings.simplefilter("always") - equity_asset.security_start_date - equity_asset.security_end_date - equity_asset.security_name - # Verify the warning - assert 3 == len(w) - for warning in w: - assert issubclass(warning.category, DeprecationWarning) - def test_compute_lifetimes(self): assets_per_exchange = 4 trading_day = self.trading_calendar.day diff --git a/tests/test_data_portal.py b/tests/test_data_portal.py index 98aead7176..951bebe7f7 100644 --- a/tests/test_data_portal.py +++ b/tests/test_data_portal.py @@ -17,7 +17,6 @@ from numpy import array, append, nan, full from numpy.testing import assert_almost_equal import pandas as pd -from pandas import Timedelta from zipline.assets import Equity, Future from zipline.data.data_portal import HISTORY_FREQUENCIES, OHLCV_FIELDS @@ -453,48 +452,6 @@ def test_get_adjustments(self, data_frequency, field): err_msg="at dt={} perspective={}".format(dt, perspective_dt), ) - def test_bar_count_for_simple_transforms(self): - # July 2015 - # Su Mo Tu We Th Fr Sa - # 1 2 3 4 - # 5 6 7 8 9 10 11 - # 12 13 14 15 16 17 18 - # 19 20 21 22 23 24 25 - # 26 27 28 29 30 31 - - # half an hour into july 9, getting a 4-"day" window should get us - # all the minutes of 7/6, 7/7, 7/8, and 31 minutes of 7/9 - - july_9_dt = self.trading_calendar.open_and_close_for_session( - pd.Timestamp("2015-07-09", tz="UTC") - )[0] + Timedelta("30 minutes") - - assert (3 * 390) + 31 == self.data_portal._get_minute_count_for_transform( - july_9_dt, 4 - ) - - # November 2015 - # Su Mo Tu We Th Fr Sa - # 1 2 3 4 5 6 7 - # 8 9 10 11 12 13 14 - # 15 16 17 18 19 20 21 - # 22 23 24 25 26 27 28 - # 29 30 - - # nov 26th closed - # nov 27th was an early close - - # half an hour into nov 30, getting a 4-"day" window should get us - # all the minutes of 11/24, 11/25, 11/27 (half day!), and 31 minutes - # of 11/30 - nov_30_dt = self.trading_calendar.open_and_close_for_session( - pd.Timestamp("2015-11-30", tz="UTC") - )[0] + Timedelta("30 minutes") - - assert 390 + 390 + 210 + 31 == self.data_portal._get_minute_count_for_transform( - nov_30_dt, 4 - ) - def test_get_last_traded_dt_minute(self): minutes = self.nyse_calendar.minutes_for_session(self.trading_days[2]) equity = self.asset_finder.retrieve_asset(1) diff --git a/tests/test_history.py b/tests/test_history.py index 1b25327b01..7b3ef3a039 100644 --- a/tests/test_history.py +++ b/tests/test_history.py @@ -22,7 +22,6 @@ from zipline._protocol import handle_non_market_minutes, BarData from zipline.assets import Asset, Equity from zipline.errors import ( - HistoryInInitialize, HistoryWindowStartsBeforeData, ) from zipline.finance.asset_restrictions import NoRestrictions @@ -203,34 +202,22 @@ def 
make_dividends_data(cls): [ { # only care about ex date, the other dates don't matter here - "ex_date": pd.Timestamp( - "2015-01-06", tz="UTC" - ).to_datetime64(), - "record_date": pd.Timestamp( - "2015-01-06", tz="UTC" - ).to_datetime64(), + "ex_date": pd.Timestamp("2015-01-06", tz="UTC").to_datetime64(), + "record_date": pd.Timestamp("2015-01-06", tz="UTC").to_datetime64(), "declared_date": pd.Timestamp( "2015-01-06", tz="UTC" ).to_datetime64(), - "pay_date": pd.Timestamp( - "2015-01-06", tz="UTC" - ).to_datetime64(), + "pay_date": pd.Timestamp("2015-01-06", tz="UTC").to_datetime64(), "amount": 2.0, "sid": cls.DIVIDEND_ASSET_SID, }, { - "ex_date": pd.Timestamp( - "2015-01-07", tz="UTC" - ).to_datetime64(), - "record_date": pd.Timestamp( - "2015-01-07", tz="UTC" - ).to_datetime64(), + "ex_date": pd.Timestamp("2015-01-07", tz="UTC").to_datetime64(), + "record_date": pd.Timestamp("2015-01-07", tz="UTC").to_datetime64(), "declared_date": pd.Timestamp( "2015-01-07", tz="UTC" ).to_datetime64(), - "pay_date": pd.Timestamp( - "2015-01-07", tz="UTC" - ).to_datetime64(), + "pay_date": pd.Timestamp("2015-01-07", tz="UTC").to_datetime64(), "amount": 4.0, "sid": cls.DIVIDEND_ASSET_SID, }, @@ -255,9 +242,7 @@ def make_adjustment_writer_equity_daily_bar_reader(cls): ) # TODO: simplify (flake8) - def verify_regular_dt( - self, idx, dt, mode, fields=None, assets=None - ): # noqa: C901 + def verify_regular_dt(self, idx, dt, mode, fields=None, assets=None): # noqa: C901 if mode == "daily": freq = "1d" else: @@ -344,16 +329,13 @@ def reindex_to_primary_calendar(a, field): # and some real values np.testing.assert_array_equal( - np.array(range(base, base + present_count + 1)) - * 100, + np.array(range(base, base + present_count + 1)) * 100, asset_series[(9 - present_count) :], ) if asset == self.ASSET3: # asset3 is all zeros, no volume yet - np.testing.assert_array_equal( - np.zeros(10), asset_series - ) + np.testing.assert_array_equal(np.zeros(10), asset_series) else: # asset3 should have data every 10 minutes # construct an array full of nans, put something in the @@ -380,9 +362,7 @@ def reindex_to_primary_calendar(a, field): if asset == self.ASSET2: np.testing.assert_array_equal( reindex_to_primary_calendar( - np.array( - range(base + idx - 9, base + idx + 1) - ), + np.array(range(base + idx - 9, base + idx + 1)), field, ), asset_series, @@ -394,9 +374,7 @@ def reindex_to_primary_calendar(a, field): ) elif field == "volume": asset3_answer_key = np.zeros(10) - asset3_answer_key[-position_from_end] = ( - value_for_asset3 * 100 - ) + asset3_answer_key[-position_from_end] = value_for_asset3 * 100 asset3_answer_key = reindex_to_primary_calendar( asset3_answer_key, field, @@ -405,9 +383,7 @@ def reindex_to_primary_calendar(a, field): if asset == self.ASSET2: np.testing.assert_array_equal( reindex_to_primary_calendar( - np.array( - range(base + idx - 9, base + idx + 1) - ) + np.array(range(base + idx - 9, base + idx + 1)) * 100, field, ), @@ -436,9 +412,7 @@ def reindex_to_primary_calendar(a, field): if asset == self.ASSET3: # Second part begins on the session after # `position_from_end` on the NYSE calendar. - second_begin = dt - equity_cal.day * ( - position_from_end - 1 - ) + second_begin = dt - equity_cal.day * (position_from_end - 1) # First part goes up until the start of the # second part, because we forward-fill. 
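Reference note: the history tests on either side of this point all route through DataPortal.get_history_window. Its positional call shape, as used by the removed TradingAlgorithm.get_history_window earlier in this patch and by the tests below (the timestamp, bar count, and variable names are illustrative):

    import pandas as pd

    window = data_portal.get_history_window(
        [asset],                                     # list of Assets
        pd.Timestamp("2015-01-07 14:35", tz="UTC"),  # end of the window
        10,                                          # bar_count, must be >= 1
        "1m",                                        # frequency: "1m" or "1d"
        "close",                                     # field
        "minute",                                    # data_frequency
        True,                                        # ffill
    )
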
@@ -467,17 +441,13 @@ def reindex_to_primary_calendar(a, field):
 )
 else:
 np.testing.assert_array_equal(
- np.array(
- [decile_count * 10 - 9]
- * len(first_part)
- ),
+ np.array([decile_count * 10 - 9] * len(first_part)),
 first_part,
 )

 np.testing.assert_array_equal(
 np.array(
- [decile_count * 10 + 1]
- * len(second_part)
+ [decile_count * 10 + 1] * len(second_part)
 ),
 second_part,
 )
@@ -500,8 +470,7 @@ def check_internal_consistency(bar_data, assets, fields, bar_count, freq):
 }

 multi_asset_dict = {
- field: bar_data.history(asset_list, field, bar_count, freq)
- for field in fields
+ field: bar_data.history(asset_list, field, bar_count, freq) for field in fields
 }

 df = bar_data.history(asset_list, field_list, bar_count, freq)
@@ -512,13 +481,9 @@
 for asset in asset_list:
 series = bar_data.history(asset, field, bar_count, freq)

- np.testing.assert_array_equal(
- series, multi_asset_dict[field][asset]
- )
+ np.testing.assert_array_equal(series, multi_asset_dict[field][asset])

- np.testing.assert_array_equal(
- series, multi_field_dict[asset][field]
- )
+ np.testing.assert_array_equal(series, multi_field_dict[asset][field])

 np.testing.assert_array_equal(
 series, df.loc[pd.IndexSlice[:, asset], field]
 )
@@ -538,9 +503,7 @@
 }

-class MinuteEquityHistoryTestCase(
- WithHistory, zf.WithMakeAlgo, zf.ZiplineTestCase
-):
+class MinuteEquityHistoryTestCase(WithHistory, zf.WithMakeAlgo, zf.ZiplineTestCase):
 EQUITY_DAILY_BAR_SOURCE_FROM_MINUTE = True

 DATA_PORTAL_FIRST_TRADING_DAY = zf.alias("TRADING_START_DT")
@@ -621,29 +584,13 @@ def make_equity_minute_bar_data(cls):
 )
 return data.items()

- def test_history_in_initialize(self):
- algo_text = dedent(
- """\
- from zipline.api import history
-
- def initialize(context):
- history([1], 10, '1d', 'price')
-
- def handle_data(context, data):
- pass
- """
- )
- algo = self.make_algo(script=algo_text)
- with pytest.raises(HistoryInInitialize):
- algo.run()

 def test_negative_bar_count(self):
 """
 Negative bar counts leak future information. 
""" - with pytest.raises( - ValueError, match="bar_count must be >= 1, but got -1" - ): + with pytest.raises(ValueError, match="bar_count must be >= 1, but got -1"): self.data_portal.get_history_window( [self.ASSET1], pd.Timestamp("2015-01-07 14:35", tz="UTC"), @@ -814,13 +775,9 @@ def test_minute_before_assets_trading(self): np.testing.assert_array_equal(np.zeros(10), asset2_series) np.testing.assert_array_equal(np.zeros(10), asset3_series) else: - np.testing.assert_array_equal( - np.full(10, np.nan), asset2_series - ) + np.testing.assert_array_equal(np.full(10, np.nan), asset2_series) - np.testing.assert_array_equal( - np.full(10, np.nan), asset3_series - ) + np.testing.assert_array_equal(np.full(10, np.nan), asset3_series) @parameterized.expand( [ @@ -887,9 +844,7 @@ def test_minute_after_asset_stopped(self): for idx, minute in enumerate(minutes): bar_data = self.create_bardata(lambda: minute) - check_internal_consistency( - bar_data, self.SHORT_ASSET, ALL_FIELDS, 30, "1m" - ) + check_internal_consistency(bar_data, self.SHORT_ASSET, ALL_FIELDS, 30, "1m") # Reset data portal because it has advanced past next test date. data_portal = self.make_data_portal() @@ -954,17 +909,13 @@ def test_minute_after_asset_stopped(self): np.testing.assert_array_equal( range(76800, 78101, 100), window["volume"][0:14] ) - np.testing.assert_array_equal( - np.zeros(16), window["volume"][-16:] - ) + np.testing.assert_array_equal(np.zeros(16), window["volume"][-16:]) else: np.testing.assert_array_equal( np.array(range(768, 782)) + MINUTE_FIELD_INFO[field], window[field][0:14], ) - np.testing.assert_array_equal( - np.full(16, np.nan), window[field][-16:] - ) + np.testing.assert_array_equal(np.full(16, np.nan), window[field][-16:]) # now do a smaller window that is entirely contained after the asset # ends @@ -1179,9 +1130,7 @@ def test_passing_iterable_to_history_bts(self): ) with handle_non_market_minutes(bar_data): - bar_data.history( - pd.Index([self.ASSET1, self.ASSET2]), "high", 5, "1m" - ) + bar_data.history(pd.Index([self.ASSET1, self.ASSET2]), "high", 5, "1m") # for some obscure reason at best 2 of 3 cases of can pass depending on */ order # in last two assert_array_equal @@ -1240,9 +1189,7 @@ def test_overnight_adjustments(self): values = bar_data.history( self.SPLIT_ASSET, ["open", "volume"], window_length, "1m" ) - np.testing.assert_array_equal( - values.open.values[:10], adj_expected["open"] - ) + np.testing.assert_array_equal(values.open.values[:10], adj_expected["open"]) np.testing.assert_array_equal( values.volume.values[:10], adj_expected["volume"] ) @@ -1266,15 +1213,11 @@ def test_overnight_adjustments(self): "1m", ) np.testing.assert_array_equal( - values.loc[pd.IndexSlice[:, self.SPLIT_ASSET], "open"].values[ - :10 - ], + values.loc[pd.IndexSlice[:, self.SPLIT_ASSET], "open"].values[:10], adj_expected["open"], ) np.testing.assert_array_equal( - values.loc[pd.IndexSlice[:, self.SPLIT_ASSET], "volume"].values[ - :10 - ], + values.loc[pd.IndexSlice[:, self.SPLIT_ASSET], "volume"].values[:10], adj_expected["volume"], ) np.testing.assert_array_equal( @@ -1722,9 +1665,7 @@ def make_equity_daily_bar_data(cls, country_code, sids): ) @classmethod - def create_df_for_asset( - cls, start_day, end_day, interval=1, force_zeroes=False - ): + def create_df_for_asset(cls, start_day, end_day, interval=1, force_zeroes=False): sessions = cls.trading_calendars[Equity].sessions_in_range( start_day, end_day, @@ -1780,13 +1721,9 @@ def test_daily_before_assets_trading(self): 
np.testing.assert_array_equal(np.zeros(10), asset2_series) np.testing.assert_array_equal(np.zeros(10), asset3_series) else: - np.testing.assert_array_equal( - np.full(10, np.nan), asset2_series - ) + np.testing.assert_array_equal(np.full(10, np.nan), asset2_series) - np.testing.assert_array_equal( - np.full(10, np.nan), asset3_series - ) + np.testing.assert_array_equal(np.full(10, np.nan), asset3_series) def test_daily_regular(self): # asset2 and asset3 both started on 1/5/2015, but asset3 trades every @@ -1811,26 +1748,18 @@ def test_daily_some_assets_stopped(self): ) for field in OHLCP: - window = bar_data.history( - [self.ASSET1, self.ASSET2], field, 15, "1d" - ) + window = bar_data.history([self.ASSET1, self.ASSET2], field, 15, "1d") # last 2 values for asset2 should be NaN (# of days since asset2 # delisted) - np.testing.assert_array_equal( - np.full(2, np.nan), window[self.ASSET2][-2:] - ) + np.testing.assert_array_equal(np.full(2, np.nan), window[self.ASSET2][-2:]) # third from last value should not be NaN assert not np.isnan(window[self.ASSET2][-3]) - volume_window = bar_data.history( - [self.ASSET1, self.ASSET2], "volume", 15, "1d" - ) + volume_window = bar_data.history([self.ASSET1, self.ASSET2], "volume", 15, "1d") - np.testing.assert_array_equal( - np.zeros(2), volume_window[self.ASSET2][-2:] - ) + np.testing.assert_array_equal(np.zeros(2), volume_window[self.ASSET2][-2:]) assert 0 != volume_window[self.ASSET2][-3] @@ -1845,21 +1774,15 @@ def test_daily_after_asset_stopped(self): # days has 1/7, 1/8 for idx, day in enumerate(days): bar_data = self.create_bardata(simulation_dt_func=lambda: day) - check_internal_consistency( - bar_data, self.SHORT_ASSET, ALL_FIELDS, 2, "1d" - ) + check_internal_consistency(bar_data, self.SHORT_ASSET, ALL_FIELDS, 2, "1d") for field in ALL_FIELDS: - asset_series = bar_data.history( - self.SHORT_ASSET, field, 2, "1d" - ) + asset_series = bar_data.history(self.SHORT_ASSET, field, 2, "1d") if idx == 0: # one value, then one NaN. base value for 1/6 is 3. 
if field in OHLCP: - assert ( - 3 + MINUTE_FIELD_INFO[field] == asset_series.iloc[0] - ) + assert 3 + MINUTE_FIELD_INFO[field] == asset_series.iloc[0] assert np.isnan(asset_series.iloc[1]) elif field == "volume": @@ -2005,40 +1928,28 @@ def test_daily_blended_some_assets_stopped(self): # asset2 ends on 2016-01-04 bar_data = self.create_bardata( - simulation_dt_func=lambda: pd.Timestamp( - "2016-01-06 16:00", tz="UTC" - ), + simulation_dt_func=lambda: pd.Timestamp("2016-01-06 16:00", tz="UTC"), ) for field in OHLCP: - window = bar_data.history( - [self.ASSET1, self.ASSET2], field, 15, "1d" - ) + window = bar_data.history([self.ASSET1, self.ASSET2], field, 15, "1d") # last 2 values for asset2 should be NaN - np.testing.assert_array_equal( - np.full(2, np.nan), window[self.ASSET2][-2:] - ) + np.testing.assert_array_equal(np.full(2, np.nan), window[self.ASSET2][-2:]) # third from last value should not be NaN assert not np.isnan(window[self.ASSET2][-3]) - volume_window = bar_data.history( - [self.ASSET1, self.ASSET2], "volume", 15, "1d" - ) + volume_window = bar_data.history([self.ASSET1, self.ASSET2], "volume", 15, "1d") - np.testing.assert_array_equal( - np.zeros(2), volume_window[self.ASSET2][-2:] - ) + np.testing.assert_array_equal(np.zeros(2), volume_window[self.ASSET2][-2:]) assert 0 != volume_window[self.ASSET2][-3] def test_history_window_before_first_trading_day(self): # trading_start is 2/3/2014 # get a history window that starts before that, and ends after that - second_day = self.trading_calendar.next_session_label( - self.TRADING_START_DT - ) + second_day = self.trading_calendar.next_session_label(self.TRADING_START_DT) exp_msg = ( "History window extends before 2014-01-03. To use this history " @@ -2066,9 +1977,7 @@ def test_history_window_before_first_trading_day(self): )[self.ASSET1] # Use a minute to force minute mode. - first_minute = self.trading_calendar.schedule.market_open[ - self.TRADING_START_DT - ] + first_minute = self.trading_calendar.schedule.market_open[self.TRADING_START_DT] with pytest.raises(HistoryWindowStartsBeforeData, match=exp_msg): self.data_portal.get_history_window( diff --git a/tests/test_tradesimulation.py b/tests/test_tradesimulation.py index 87a5c9be07..66fecf2ed4 100644 --- a/tests/test_tradesimulation.py +++ b/tests/test_tradesimulation.py @@ -114,7 +114,6 @@ def initialize(context): BeforeTradingStartsOnlyClock(dt), benchmark_source, NoRestrictions(), - None, ) # run through the algo's simulation
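Taken together, the tests above pin down the one surviving history entry point. A minimal sketch of how an algorithm calls it after this patch; the context attribute names are hypothetical, while the signature and return shapes match the bar_data.history calls and pd.IndexSlice lookups exercised in the tests:

    def handle_data(context, data):
        # One asset and one field return a pd.Series of length bar_count.
        prices = data.history(context.asset, "price", 10, "1d")

        # A list of assets and fields returns a pd.DataFrame indexed by
        # (dt, asset), which is why check_internal_consistency slices it
        # with pd.IndexSlice.
        frame = data.history(
            [context.asset1, context.asset2], ["open", "volume"], 30, "1m"
        )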