Skip to content

Commit

Permalink
Refactor Correlation and beta indicators (#8485)
Browse files Browse the repository at this point in the history
* WIP: Refactor Correlation Indicator

* Simplified comparison logic and improved abstraction

* Refactor Correlation and Beta indicators

- Created a base class to handle indicators with dual-symbol
  functionality.
- Refactored the Beta and Correlation indicators to inherit from the new
  base class.
- Updated unit tests.
- Added a new regression test to validate the latest computed value.

* Addressed review comments

* Update regression test

* Add new unit test to CommonIndicatorTests

* Addressed new review comments
  • Loading branch information
JosueNina authored Jan 2, 2025
1 parent fb49f5d commit f3ea223
Show file tree
Hide file tree
Showing 8 changed files with 488 additions and 263 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* QUANTCONNECT.COM - Democratizing Finance, Empowering Individuals.
* Lean Algorithmic Trading Engine v2.0. Copyright 2014 QuantConnect Corporation.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System.Collections.Generic;
using QuantConnect.Data;
using QuantConnect.Indicators;
using QuantConnect.Interfaces;

namespace QuantConnect.Algorithm.CSharp.RegressionTests
{
/// <summary>
/// Validates the <see cref="Correlation"/> indicator by ensuring no mismatch between the last computed value
/// and the expected value. Also verifies proper functionality across different time zones.
/// </summary>
public class CorrelationLastComputedValueRegressionAlgorithm : QCAlgorithm, IRegressionAlgorithmDefinition
{
private Correlation _correlationPearson;
private decimal _lastCorrelationValue;
private decimal _totalCount;
private decimal _matchingCount;

public override void Initialize()
{
SetStartDate(2015, 05, 08);
SetEndDate(2017, 06, 15);

EnableAutomaticIndicatorWarmUp = true;
AddCrypto("BTCUSD", Resolution.Daily);
AddEquity("SPY", Resolution.Daily);

_correlationPearson = C("BTCUSD", "SPY", 3, CorrelationType.Pearson, Resolution.Daily);
if (!_correlationPearson.IsReady)
{
throw new RegressionTestException("Correlation indicator was expected to be ready");
}
_lastCorrelationValue = _correlationPearson.Current.Value;
_totalCount = 0;
_matchingCount = 0;
}

public override void OnData(Slice slice)
{
if (_lastCorrelationValue == _correlationPearson[1].Value)
{
_matchingCount++;
}
Debug($"CorrelationPearson between BTCUSD and SPY - Current: {_correlationPearson[0].Value}, Previous: {_correlationPearson[1].Value}");
_lastCorrelationValue = _correlationPearson.Current.Value;
_totalCount++;
}

public override void OnEndOfAlgorithm()
{
if (_totalCount == 0)
{
throw new RegressionTestException("No data points were processed.");
}
if (_totalCount != _matchingCount)
{
throw new RegressionTestException("Mismatch in the last computed CorrelationPearson values.");
}
Debug($"{_totalCount} data points were processed, {_matchingCount} matched the last computed value.");
}

/// <summary>
/// Final status of the algorithm
/// </summary>
public AlgorithmStatus AlgorithmStatus => AlgorithmStatus.Completed;

/// <summary>
/// This is used by the regression test system to indicate if the open source Lean repository has the required data to run this algorithm.
/// </summary>
public bool CanRunLocally => true;

/// <summary>
/// This is used by the regression test system to indicate which languages this algorithm is written in.
/// </summary>
public List<Language> Languages { get; } = new() { Language.CSharp };

/// <summary>
/// Data Points count of all timeslices of algorithm
/// </summary>
public long DataPoints => 5798;

/// <summary>
/// Data Points count of the algorithm history
/// </summary>
public int AlgorithmHistoryDataPoints => 72;

/// <summary>
/// This is used by the regression test system to indicate what the expected statistics are from running the algorithm
/// </summary>
public Dictionary<string, string> ExpectedStatistics => new Dictionary<string, string>
{
{"Total Orders", "0"},
{"Average Win", "0%"},
{"Average Loss", "0%"},
{"Compounding Annual Return", "0%"},
{"Drawdown", "0%"},
{"Expectancy", "0"},
{"Start Equity", "100000.00"},
{"End Equity", "100000"},
{"Net Profit", "0%"},
{"Sharpe Ratio", "0"},
{"Sortino Ratio", "0"},
{"Probabilistic Sharpe Ratio", "0%"},
{"Loss Rate", "0%"},
{"Win Rate", "0%"},
{"Profit-Loss Ratio", "0"},
{"Alpha", "0"},
{"Beta", "0"},
{"Annual Standard Deviation", "0"},
{"Annual Variance", "0"},
{"Information Ratio", "-0.616"},
{"Tracking Error", "0.111"},
{"Treynor Ratio", "0"},
{"Total Fees", "$0.00"},
{"Estimated Strategy Capacity", "$0"},
{"Lowest Capacity Asset", ""},
{"Portfolio Turnover", "0%"},
{"OrderListHash", "d41d8cd98f00b204e9800998ecf8427e"}
};
}
}
171 changes: 15 additions & 156 deletions Indicators/Beta.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
using System;
using QuantConnect.Data.Market;
using MathNet.Numerics.Statistics;
using QuantConnect.Securities;
using NodaTime;

namespace QuantConnect.Indicators
{
Expand All @@ -32,58 +30,8 @@ namespace QuantConnect.Indicators
/// The indicator only updates when both assets have a price for a time step. When a bar is missing for one of the assets,
/// the indicator value fills forward to improve the accuracy of the indicator.
/// </summary>
public class Beta : BarIndicator, IIndicatorWarmUpPeriodProvider
public class Beta : DualSymbolIndicator<decimal>
{
/// <summary>
/// RollingWindow to store the data points of the target symbol
/// </summary>
private readonly RollingWindow<decimal> _targetDataPoints;

/// <summary>
/// RollingWindow to store the data points of the reference symbol
/// </summary>
private readonly RollingWindow<decimal> _referenceDataPoints;

/// <summary>
/// Symbol of the reference used
/// </summary>
private readonly Symbol _referenceSymbol;

/// <summary>
/// Symbol of the target used
/// </summary>
private readonly Symbol _targetSymbol;

/// <summary>
/// Stores the previous input data point.
/// </summary>
private IBaseDataBar _previousInput;

/// <summary>
/// Indicates whether the previous symbol is the target symbol.
/// </summary>
private bool _previousSymbolIsTarget;

/// <summary>
/// Indicates if the time zone for the target and reference are different.
/// </summary>
private bool _isTimezoneDifferent;

/// <summary>
/// Time zone of the target symbol.
/// </summary>
private DateTimeZone _targetTimeZone;

/// <summary>
/// Time zone of the reference symbol.
/// </summary>
private DateTimeZone _referenceTimeZone;

/// <summary>
/// The resolution of the data (e.g., daily, hourly, etc.).
/// </summary>
private Resolution _resolution;

/// <summary>
/// RollingWindow of returns of the target symbol in the given period
/// </summary>
Expand All @@ -94,16 +42,6 @@ public class Beta : BarIndicator, IIndicatorWarmUpPeriodProvider
/// </summary>
private readonly RollingWindow<double> _referenceReturns;

/// <summary>
/// Beta of the target used in relation with the reference
/// </summary>
private decimal _beta;

/// <summary>
/// Required period, in data points, for the indicator to be ready and fully initialized.
/// </summary>
public int WarmUpPeriod { get; private set; }

/// <summary>
/// Gets a flag indicating when the indicator is ready and fully initialized
/// </summary>
Expand All @@ -118,27 +56,17 @@ public class Beta : BarIndicator, IIndicatorWarmUpPeriodProvider
/// <param name="period">The period of this indicator</param>
/// <param name="referenceSymbol">The reference symbol of this indicator</param>
public Beta(string name, Symbol targetSymbol, Symbol referenceSymbol, int period)
: base(name)
: base(name, targetSymbol, referenceSymbol, 2)
{
// Assert the period is greater than two, otherwise the beta can not be computed
if (period < 2)
{
throw new ArgumentException($"Period parameter for Beta indicator must be greater than 2 but was {period}.");
}
_referenceSymbol = referenceSymbol;
_targetSymbol = targetSymbol;

_targetDataPoints = new RollingWindow<decimal>(2);
_referenceDataPoints = new RollingWindow<decimal>(2);

_targetReturns = new RollingWindow<double>(period);
_referenceReturns = new RollingWindow<double>(period);
_beta = 0;
var dataFolder = MarketHoursDatabase.FromDataFolder();
_targetTimeZone = dataFolder.GetExchangeHours(_targetSymbol.ID.Market, _targetSymbol, _targetSymbol.ID.SecurityType).TimeZone;
_referenceTimeZone = dataFolder.GetExchangeHours(_referenceSymbol.ID.Market, _referenceSymbol, _referenceSymbol.ID.SecurityType).TimeZone;
_isTimezoneDifferent = _targetTimeZone != _referenceTimeZone;
WarmUpPeriod = period + 1 + (_isTimezoneDifferent ? 1 : 0);
WarmUpPeriod = period + 1 + (IsTimezoneDifferent ? 1 : 0);
}

/// <summary>
Expand Down Expand Up @@ -167,97 +95,32 @@ public Beta(string name, int period, Symbol targetSymbol, Symbol referenceSymbol
{
}

/// <summary>
/// Computes the next value for this indicator from the given state.
///
/// As this indicator is receiving data points from two different symbols,
/// it's going to compute the next value when the amount of data points
/// of each of them is the same. Otherwise, it will return the last beta
/// value computed
/// </summary>
/// <param name="input">The input value of this indicator on this time step.
/// It can be either from the target or the reference symbol</param>
/// <returns>The beta value of the target used in relation with the reference</returns>
protected override decimal ComputeNextValue(IBaseDataBar input)
{
if (_previousInput == null)
{
_previousInput = input;
_previousSymbolIsTarget = input.Symbol == _targetSymbol;
var timeDifference = input.EndTime - input.Time;
_resolution = timeDifference.TotalHours > 1 ? Resolution.Daily : timeDifference.ToHigherResolutionEquivalent(false);
return decimal.Zero;
}

var inputEndTime = input.EndTime;
var previousInputEndTime = _previousInput.EndTime;

if (_isTimezoneDifferent)
{
inputEndTime = inputEndTime.ConvertToUtc(_previousSymbolIsTarget ? _referenceTimeZone : _targetTimeZone);
previousInputEndTime = previousInputEndTime.ConvertToUtc(_previousSymbolIsTarget ? _targetTimeZone : _referenceTimeZone);
}

// Process data if symbol has changed and timestamps match
if (input.Symbol != _previousInput.Symbol && TruncateToResolution(inputEndTime) == TruncateToResolution(previousInputEndTime))
{
AddDataPoint(input);
AddDataPoint(_previousInput);
ComputeBeta();
}
_previousInput = input;
_previousSymbolIsTarget = input.Symbol == _targetSymbol;
return _beta;
}

/// <summary>
/// Truncates the given DateTime based on the specified resolution (Daily, Hourly, Minute, or Second).
/// </summary>
/// <param name="date">The DateTime to truncate.</param>
/// <returns>A DateTime truncated to the specified resolution.</returns>
private DateTime TruncateToResolution(DateTime date)
{
switch (_resolution)
{
case Resolution.Daily:
return date.Date;
case Resolution.Hour:
return date.Date.AddHours(date.Hour);
case Resolution.Minute:
return date.Date.AddHours(date.Hour).AddMinutes(date.Minute);
case Resolution.Second:
return date;
default:
return date;
}
}

/// <summary>
/// Adds the closing price to the corresponding symbol's data set (target or reference).
/// Computes returns when there are enough data points for each symbol.
/// </summary>
/// <param name="input">The input value for this symbol</param>
private void AddDataPoint(IBaseDataBar input)
protected override void AddDataPoint(IBaseDataBar input)
{
if (input.Symbol == _targetSymbol)
if (input.Symbol == TargetSymbol)
{
_targetDataPoints.Add(input.Close);
if (_targetDataPoints.Count > 1)
TargetDataPoints.Add(input.Close);
if (TargetDataPoints.Count > 1)
{
_targetReturns.Add(GetNewReturn(_targetDataPoints));
_targetReturns.Add(GetNewReturn(TargetDataPoints));
}
}
else if (input.Symbol == _referenceSymbol)
else if (input.Symbol == ReferenceSymbol)
{
_referenceDataPoints.Add(input.Close);
if (_referenceDataPoints.Count > 1)
ReferenceDataPoints.Add(input.Close);
if (ReferenceDataPoints.Count > 1)
{
_referenceReturns.Add(GetNewReturn(_referenceDataPoints));
_referenceReturns.Add(GetNewReturn(ReferenceDataPoints));
}
}
else
{
throw new ArgumentException($"The given symbol {input.Symbol} was not {_targetSymbol} or {_referenceSymbol} symbol");
throw new ArgumentException($"The given symbol {input.Symbol} was not {TargetSymbol} or {ReferenceSymbol} symbol");
}
}

Expand All @@ -276,28 +139,24 @@ private static double GetNewReturn(RollingWindow<decimal> rollingWindow)
/// Computes the beta value of the target in relation with the reference
/// using the target and reference returns
/// </summary>
private void ComputeBeta()
protected override void ComputeIndicator()
{
var varianceComputed = _referenceReturns.Variance();
var covarianceComputed = _targetReturns.Covariance(_referenceReturns);

// Avoid division with NaN or by zero
var variance = !varianceComputed.IsNaNOrZero() ? varianceComputed : 1;
var covariance = !covarianceComputed.IsNaNOrZero() ? covarianceComputed : 0;
_beta = (decimal)(covariance / variance);
IndicatorValue = (decimal)(covariance / variance);
}

/// <summary>
/// Resets this indicator to its initial state
/// </summary>
public override void Reset()
{
_previousInput = null;
_targetDataPoints.Reset();
_referenceDataPoints.Reset();
_targetReturns.Reset();
_referenceReturns.Reset();
_beta = 0;
base.Reset();
}
}
Expand Down
Loading

0 comments on commit f3ea223

Please sign in to comment.