diff --git a/README.md b/README.md
index 0f1b8cc..b72c34f 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@
-
+
diff --git a/README.tex.md b/README.tex.md
index 6dd741f..4872162 100644
--- a/README.tex.md
+++ b/README.tex.md
@@ -7,7 +7,7 @@
-
+
diff --git a/example/Example-Analysis.py b/example/Example-Analysis.py
index 40507ef..79ca9f0 100644
--- a/example/Example-Analysis.py
+++ b/example/Example-Analysis.py
@@ -316,3 +316,45 @@
#
print(pf.data.loc[pf.data.index.year == 2017].head(3))
+
+#
+
+# ## Momentum Indicators
+# `FinQuant` provides a module `finquant.momentum_indicators` to compute and
+# visualize a number of momentum indicators. Currently RSI (Relative Strength Index)
+# and MACD (Moving Average Convergence Divergence) indicators are available.
+# See below.
+
+#
+# plot the RSI (Relative Strength Index) for disney stock proces
+from finquant.momentum_indicators import plot_relative_strength_index as rsi
+
+# get stock data for disney
+dis = pf.get_stock("WIKI/DIS").data.copy(deep=True)
+
+# plot RSI - by default this plots RSI against the price in two graphs
+rsi(dis)
+plt.show()
+
+# plot RSI with custom arguments
+rsi(dis, oversold=20, overbought=80)
+plt.show()
+
+# plot RSI standalone graph
+rsi(dis, oversold=20, overbought=80, standalone=True)
+plt.show()
+
+#
+# plot MACD for disney stock proces
+from finquant.momentum_indicators import plot_macd
+
+# using short time frame of data due to plot warnings from matplotlib/mplfinance
+dis = dis[0:300]
+
+# plot MACD - by default this plots RSI against the price in two graphs
+plot_macd(dis)
+plt.show()
+
+# plot MACD using custom arguments
+plot_macd(dis, longer_ema_window=30, shorter_ema_window=15, signal_ema_window=10)
+plt.show()
diff --git a/finquant/momentum_indicators.py b/finquant/momentum_indicators.py
new file mode 100644
index 0000000..3474e1c
--- /dev/null
+++ b/finquant/momentum_indicators.py
@@ -0,0 +1,427 @@
+""" This module provides function(s) to compute momentum indicators
+used in technical analysis such as RSI, MACD etc. """
+from typing import List, Optional
+
+import matplotlib.pyplot as plt
+import mplfinance as mpf
+import pandas as pd
+
+from finquant.data_types import FLOAT, INT, SERIES_OR_DATAFRAME
+from finquant.type_utilities import type_validation
+from finquant.utils import all_list_ele_in_other, re_download_stock_data
+
+
+def calculate_wilder_smoothing_averages(
+ avg_gain_loss: FLOAT, gain_loss: FLOAT, window_length: INT
+) -> FLOAT:
+ """
+ Calculate Wilder's Smoothing Averages.
+
+ Wilder's Smoothing Averages are used in technical analysis, particularly for
+ calculating indicators like the Relative Strength Index (RSI). This function
+ takes the average gain/loss, the current gain/loss, and the window length as
+ input and returns the smoothed average.
+
+ :param avg_gain_loss: The previous average gain/loss.
+ :type avg_gain_loss: :py:data:`~.finquant.data_types.FLOAT`
+ :param gain_loss: The current gain or loss.
+ :type gain_loss: :py:data:`~.finquant.data_types.FLOAT`
+ :param window_length: The length of the smoothing window.
+ :type window_length: :py:data:`~.finquant.data_types.FLOAT`
+
+ :return: The Wilder's smoothed average value.
+ :rtype: :py:data:`~.finquant.data_types.FLOAT`
+
+ Example:
+
+ .. code-block:: python
+
+ calculate_wilder_smoothing_averages(10.0, 5.0, 14)
+
+ """
+ # Type validations:
+ type_validation(
+ avg_gain_loss=avg_gain_loss,
+ gain_loss=gain_loss,
+ window_length=window_length,
+ )
+ # validating window_length value range
+ if window_length <= 0:
+ raise ValueError("Error: window_length must be > 0.")
+ return (avg_gain_loss * (window_length - 1) + gain_loss) / float(window_length)
+
+
+def calculate_relative_strength_index(
+ data: SERIES_OR_DATAFRAME,
+ window_length: INT = 14,
+ oversold: INT = 30,
+ overbought: INT = 70,
+) -> pd.Series:
+ """Computes the relative strength index of given stock price data.
+
+ Ref: https://www.investopedia.com/terms/r/rsi.asp
+
+ :param data: A series/dataframe of daily stock prices
+ :type data: :py:data:`~.finquant.data_types.SERIES_OR_DATAFRAME`
+ :param window_length: Window length to compute RSI, default being 14 days
+ :type window_length: :py:data:`~.finquant.data_types.INT`
+ :param oversold: Standard level for oversold RSI, default being 30
+ :type oversold: :py:data:`~.finquant.data_types.INT`
+ :param overbought: Standard level for overbought RSI, default being 70
+ :type overbought: :py:data:`~.finquant.data_types.INT`
+
+ :return: A Series of RSI values.
+ """
+ # Type validations:
+ type_validation(
+ data=data,
+ window_length=window_length,
+ oversold=oversold,
+ overbought=overbought,
+ )
+ # validating levels
+ if oversold >= overbought:
+ raise ValueError("oversold level should be < overbought level")
+ if not 0 < oversold < 100 or not 0 < overbought < 100:
+ raise ValueError("levels should be > 0 and < 100")
+
+ if window_length > len(data):
+ raise ValueError("Error: window_length must be <= len(data).")
+
+ # converting data to pd.DataFrame if it is a pd.Series (for subsequent function calls):
+ if isinstance(data, pd.Series):
+ data = data.to_frame()
+
+ # calculate price differences
+ data["diff"] = data.diff(periods=1)
+ # calculate gains and losses
+ data["gain"] = data["diff"].clip(lower=0)
+ data["loss"] = data["diff"].clip(upper=0).abs()
+ # calculate rolling window mean gains and losses
+ data["avg_gain"] = (
+ data["gain"].rolling(window=window_length, min_periods=window_length).mean()
+ )
+ data["avg_loss"] = (
+ data["loss"].rolling(window=window_length, min_periods=window_length).mean()
+ )
+ # ignore SettingWithCopyWarning for the below operation
+ with pd.option_context("mode.chained_assignment", None):
+ for gain_or_loss in ["gain", "loss"]:
+ for idx, _ in enumerate(
+ data[f"avg_{gain_or_loss}"].iloc[window_length + 1 :]
+ ):
+ data[f"avg_{gain_or_loss}"].iloc[
+ idx + window_length + 1
+ ] = calculate_wilder_smoothing_averages(
+ data[f"avg_{gain_or_loss}"].iloc[idx + window_length],
+ data[gain_or_loss].iloc[idx + window_length + 1],
+ window_length,
+ )
+ # calculate RS values
+ data["rs"] = data["avg_gain"] / data["avg_loss"]
+ # calculate RSI
+ data["rsi"] = 100 - (100 / (1.0 + data["rs"]))
+ return data["rsi"]
+
+
+def plot_relative_strength_index(
+ data: SERIES_OR_DATAFRAME,
+ window_length: INT = 14,
+ oversold: INT = 30,
+ overbought: INT = 70,
+ standalone: bool = False,
+) -> None:
+ """Computes and visualizes a RSI graph,
+ plotted along with the prices in another sub-graph
+ for comparison.
+
+ Ref: https://www.investopedia.com/terms/r/rsi.asp
+
+ :param data: A series/dataframe of daily stock prices
+ :type data: :py:data:`~.finquant.data_types.SERIES_OR_DATAFRAME`
+ :param window_length: Window length to compute RSI, default being 14 days
+ :type window_length: :py:data:`~.finquant.data_types.INT`
+ :param oversold: Standard level for oversold RSI, default being 30
+ :type oversold: :py:data:`~.finquant.data_types.INT`
+ :param overbought: Standard level for overbought RSI, default being 70
+ :type overbought: :py:data:`~.finquant.data_types.INT`
+ :param standalone: Plot only the RSI graph
+ """
+
+ # converting data to pd.DataFrame if it is a pd.Series (for subsequent function calls):
+ if isinstance(data, pd.Series):
+ data = data.to_frame()
+ # Get stock name:
+ stock_name = data.keys()[0]
+
+ # compute RSI:
+ data["rsi"] = calculate_relative_strength_index(
+ data, window_length=window_length, oversold=oversold, overbought=overbought
+ )
+
+ # Plot it
+ if standalone:
+ # Single plot
+ fig = plt.figure()
+ axis = fig.add_subplot(111)
+ axis.axhline(
+ y=float(overbought), color="r", linestyle="dashed", label="overbought"
+ )
+ axis.axhline(y=float(oversold), color="g", linestyle="dashed", label="oversold")
+ axis.set_ylim(0, 100)
+ data["rsi"].plot(ylabel="RSI", xlabel="Date", ax=axis, grid=True)
+ plt.title("RSI Plot")
+ plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
+ else:
+ # RSI against price in 2 plots
+ fig, axis = plt.subplots(2, 1, sharex=True, sharey=False)
+ axis[0].axhline(
+ y=float(overbought), color="r", linestyle="dashed", label="overbought"
+ )
+ axis[0].axhline(
+ y=float(oversold), color="g", linestyle="dashed", label="oversold"
+ )
+ axis[0].set_title("RSI + Price Plot")
+ axis[0].set_ylim(0, 100)
+ # plot 2 graphs in 2 colors
+ colors = plt.rcParams["axes.prop_cycle"]()
+ data["rsi"].plot(
+ ylabel="RSI",
+ ax=axis[0],
+ grid=True,
+ color=next(colors)["color"],
+ legend=True,
+ ).legend(loc="center left", bbox_to_anchor=(1, 0.5))
+ data[stock_name].plot(
+ xlabel="Date",
+ ylabel="Price",
+ ax=axis[1],
+ grid=True,
+ color=next(colors)["color"],
+ legend=True,
+ ).legend(loc="center left", bbox_to_anchor=(1, 0.5))
+
+
+# Generating colors for MACD histogram
+def gen_macd_color(df: pd.DataFrame) -> List[str]:
+ """
+ Generate a list of color codes based on MACD histogram values in a DataFrame.
+
+ This function takes a DataFrame containing MACD histogram values ('MACDh') and
+ assigns colors to each data point based on the direction of change in MACD values.
+
+ :param df: A series/dataframe of MACD histogram values
+
+ :return: A list of color codes corresponding to each data point in the DataFrame.
+
+ Note:
+ - This function assumes that the DataFrame contains a column named 'MACDh'.
+ - The color assignments are based on the comparison of each data point with its
+ previous data point in the 'MACDh' column.
+
+ Example:
+
+ .. code-block:: python
+
+ import pandas as pd
+ from typing import List
+
+ # Create a DataFrame with MACD histogram values
+ df = pd.DataFrame({'MACDh': [0.5, -0.2, 0.8, -0.6, 0.2]})
+
+ # Generate MACD color codes
+ colors = gen_macd_color(df)
+ print(colors) # Output: ['#26A69A', '#FFCDD2', '#26A69A', '#FFCDD2', '#26A69A']
+
+ """
+ # Type validations:
+ type_validation(df=df)
+ macd_color = []
+ macd_color.clear()
+ for idx in range(0, len(df["MACDh"])):
+ if (
+ df["MACDh"].iloc[idx] >= 0
+ and df["MACDh"].iloc[idx - 1] < df["MACDh"].iloc[idx]
+ ):
+ macd_color.append("#26A69A") # green
+ elif (
+ df["MACDh"].iloc[idx] >= 0
+ and df["MACDh"].iloc[idx - 1] > df["MACDh"].iloc[idx]
+ ):
+ macd_color.append("#B2DFDB") # faint green
+ elif (
+ df["MACDh"].iloc[idx] < 0
+ and df["MACDh"].iloc[idx - 1] > df["MACDh"].iloc[idx]
+ ):
+ macd_color.append("#FF5252") # red
+ elif (
+ df["MACDh"].iloc[idx] < 0
+ and df["MACDh"].iloc[idx - 1] < df["MACDh"].iloc[idx]
+ ):
+ macd_color.append("#FFCDD2") # faint red
+ else:
+ macd_color.append("#000000")
+ return macd_color
+
+
+def calculate_macd(
+ data: SERIES_OR_DATAFRAME,
+ longer_ema_window: Optional[INT] = 26,
+ shorter_ema_window: Optional[INT] = 12,
+ signal_ema_window: Optional[INT] = 9,
+ stock_name: Optional[str] = None,
+ num_days_predate_stock_price: Optional[INT] = 31,
+) -> pd.DataFrame:
+ # Type validations:
+ type_validation(
+ data=data,
+ longer_ema_window=longer_ema_window,
+ shorter_ema_window=shorter_ema_window,
+ signal_ema_window=signal_ema_window,
+ name=stock_name,
+ num_days_predate_stock_price=num_days_predate_stock_price,
+ )
+
+ # validating windows
+ if longer_ema_window < shorter_ema_window:
+ raise ValueError("longer ema window should be > shorter ema window")
+ if longer_ema_window < signal_ema_window:
+ raise ValueError("longer ema window should be > signal ema window")
+
+ # Taking care of potential column header clash, removing "WIKI/" (which comes from legacy quandl)
+ if stock_name is None:
+ stock_name = data.name
+ if "WIKI/" in stock_name:
+ stock_name = stock_name.replace("WIKI/", "")
+ if isinstance(data, pd.Series):
+ data = data.to_frame()
+ # Remove prefix substring from column headers
+ data.columns = data.columns.str.replace("WIKI/", "")
+
+ # Check if required columns are present, if data is a pd.DataFrame, else re-download stock price data:
+ download_stock_data_again = True
+ if isinstance(data, pd.DataFrame) and all_list_ele_in_other(
+ ["Open", "Close", "High", "Low", "Volume"], data.columns
+ ):
+ download_stock_data_again = False
+ if download_stock_data_again:
+ df = re_download_stock_data(
+ data,
+ stock_name=stock_name,
+ num_days_predate_stock_price=num_days_predate_stock_price,
+ )
+ else:
+ df = data
+
+ # Get the shorter_ema_window-day EMA of the closing price
+ ema_short = (
+ df["Close"]
+ .ewm(span=shorter_ema_window, adjust=False, min_periods=shorter_ema_window)
+ .mean()
+ )
+ # Get the longer_ema_window-day EMA of the closing price
+ ema_long = (
+ df["Close"]
+ .ewm(span=longer_ema_window, adjust=False, min_periods=longer_ema_window)
+ .mean()
+ )
+
+ # Subtract the longer_ema_window-day EMA from the shorter_ema_window-Day EMA to get the MACD
+ macd = ema_short - ema_long
+ # Get the signal_ema_window-Day EMA of the MACD for the Trigger line
+ macd_s = macd.ewm(
+ span=signal_ema_window, adjust=False, min_periods=signal_ema_window
+ ).mean()
+ # Calculate the difference between the MACD - Trigger for the Convergence/Divergence value
+ macd_h = macd - macd_s
+
+ # Add all of our new values for the MACD to the dataframe
+ df["MACD"] = df.index.map(macd)
+ df["MACDh"] = df.index.map(macd_h)
+ df["MACDs"] = df.index.map(macd_s)
+ return df
+
+
+def plot_macd(
+ data: SERIES_OR_DATAFRAME,
+ longer_ema_window: Optional[INT] = 26,
+ shorter_ema_window: Optional[INT] = 12,
+ signal_ema_window: Optional[INT] = 9,
+ stock_name: Optional[str] = None,
+):
+ """
+ Generate a Matplotlib candlestick chart with MACD (Moving Average Convergence Divergence) indicators.
+
+ Ref: https://github.com/matplotlib/mplfinance/blob/master/examples/indicators/macd_histogram_gradient.ipynb
+
+ This function creates a candlestick chart using the given stock price data and overlays
+ MACD, MACD Signal Line, and MACD Histogram indicators. The MACD is calculated by taking
+ the difference between two Exponential Moving Averages (EMAs) of the closing price.
+
+ :param data: Time series data containing stock price information. If a
+ DataFrame is provided, it should have columns 'Open', 'Close', 'High', 'Low', and 'Volume'.
+ Else, stock price data for given time frame is downloaded again.
+ :type data: :py:data:`~.finquant.data_types.SERIES_OR_DATAFRAME`
+ :param longer_ema_window: Optional, window size for the longer-term EMA (default is 26).
+ :type longer_ema_window: :py:data:`~.finquant.data_types.INT`
+ :param shorter_ema_window: Optional, window size for the shorter-term EMA (default is 12).
+ :type shorter_ema_window: :py:data:`~.finquant.data_types.INT`
+ :param signal_ema_window: Optional, window size for the signal line EMA (default is 9).
+ :type signal_ema_window: :py:data:`~.finquant.data_types.INT`
+ :param stock_name: Optional, name of the stock for labeling purposes (default is None).
+
+ Note:
+ - If the input data is a DataFrame, it should contain columns 'Open', 'Close', 'High', 'Low', and 'Volume'.
+ - If the input data is a Series, it should have a valid name.
+ - The longer EMA window should be greater than or equal to the shorter EMA window and signal EMA window.
+
+ Example:
+
+ .. code-block:: python
+
+ import pandas as pd
+ from mplfinance.original_flavor import plot as mpf
+
+ # Create a DataFrame or Series with stock price data
+ data = pd.read_csv('stock_data.csv', index_col='Date', parse_dates=True)
+ plot_macd(data, longer_ema_window=26, shorter_ema_window=12, signal_ema_window=9, stock_name='DIS')
+
+ """
+ # calculate MACD:
+ df = calculate_macd(
+ data,
+ longer_ema_window,
+ shorter_ema_window,
+ signal_ema_window,
+ stock_name=stock_name,
+ )
+
+ # plot macd
+ macd_color = gen_macd_color(df)
+ apds = [
+ mpf.make_addplot(df["MACD"], color="#2962FF", panel=1),
+ mpf.make_addplot(df["MACDs"], color="#FF6D00", panel=1),
+ mpf.make_addplot(
+ df["MACDh"],
+ type="bar",
+ width=0.7,
+ panel=1,
+ color=macd_color,
+ alpha=1,
+ secondary_y=True,
+ ),
+ ]
+ fig, axes = mpf.plot(
+ df,
+ volume=True,
+ type="candle",
+ style="yahoo",
+ addplot=apds,
+ volume_panel=2,
+ figsize=(20, 10),
+ returnfig=True,
+ )
+ axes[2].legend(["MACD"], loc="upper left")
+ axes[3].legend(["Signal"], loc="lower left")
+
+ return fig, axes
diff --git a/finquant/type_utilities.py b/finquant/type_utilities.py
index 25d86fd..e46f441 100644
--- a/finquant/type_utilities.py
+++ b/finquant/type_utilities.py
@@ -64,7 +64,7 @@ def _check_type(
if element_type is not None:
if isinstance(arg_values, pd.DataFrame) and not all(
- arg_values.dtypes == element_type
+ np.issubdtype(value_type, element_type) for value_type in arg_values.dtypes
):
validation_failed = True
@@ -114,7 +114,7 @@ def _check_empty_data(arg_name: str, arg_values: Any) -> None:
],
] = {
# DataFrames, Series, Array:
- "data": ((pd.Series, pd.DataFrame), np.floating),
+ "data": ((pd.Series, pd.DataFrame), np.number),
"pf_allocation": (pd.DataFrame, None),
"returns_df": (pd.DataFrame, np.floating),
"returns_series": (pd.Series, np.floating),
@@ -124,6 +124,7 @@ def _check_empty_data(arg_name: str, arg_values: Any) -> None:
"initial_weights": (np.ndarray, np.floating),
"weights_array": (np.ndarray, np.floating),
"cov_matrix": ((np.ndarray, pd.DataFrame), np.floating),
+ "df": (pd.DataFrame, None),
# Lists:
"names": ((List, np.ndarray), str),
"cols": ((List, np.ndarray), str),
@@ -150,15 +151,25 @@ def _check_empty_data(arg_name: str, arg_values: Any) -> None:
"freq": ((int, np.integer), None),
"span": ((int, np.integer), None),
"num_trials": ((int, np.integer), None),
+ "longer_ema_window": ((int, np.integer), None),
+ "shorter_ema_window": ((int, np.integer), None),
+ "signal_ema_window": ((int, np.integer), None),
+ "window_length": ((int, np.integer), None),
+ "oversold": ((int, np.integer), None),
+ "overbought": ((int, np.integer), None),
+ "num_days_predate_stock_price": ((int, np.integer), None),
# NUMERICs:
"investment": ((int, np.integer, float, np.floating), None),
"dividend": ((int, np.integer, float, np.floating), None),
"target": ((int, np.integer, float, np.floating), None),
+ "avg_gain_loss": ((int, np.integer, float, np.floating), None),
+ "gain_loss": ((int, np.integer, float, np.floating), None),
# Booleans:
"plot": (bool, None),
"save_weights": (bool, None),
"verbose": (bool, None),
"defer_update": (bool, None),
+ "standalone": (bool, None),
}
type_callable_dict: Dict[
diff --git a/finquant/utils.py b/finquant/utils.py
new file mode 100644
index 0000000..e9c34d9
--- /dev/null
+++ b/finquant/utils.py
@@ -0,0 +1,40 @@
+import datetime
+from typing import Optional
+
+import pandas as pd
+
+from finquant.data_types import ELEMENT_TYPE, INT, LIST_DICT_KEYS, SERIES_OR_DATAFRAME
+from finquant.portfolio import _yfinance_request
+from finquant.type_utilities import type_validation
+
+
+def all_list_ele_in_other(
+ l_1: LIST_DICT_KEYS[ELEMENT_TYPE],
+ l_2: LIST_DICT_KEYS[ELEMENT_TYPE],
+) -> bool:
+ """Returns True if all elements of list l1 are found in list l2."""
+ return all(ele in l_2 for ele in l_1)
+
+
+def re_download_stock_data(
+ data: SERIES_OR_DATAFRAME,
+ stock_name: str,
+ num_days_predate_stock_price: Optional[INT] = 0,
+) -> pd.DataFrame:
+ # Type validations:
+ type_validation(
+ data=data,
+ name=stock_name,
+ num_days_predate_stock_price=num_days_predate_stock_price,
+ )
+ if num_days_predate_stock_price < 0:
+ raise ValueError("Error: num_days_predate_stock_price must be >= 0.")
+ # download additional price data 'Open' for given stock and timeframe:
+ start_date = data.index.min() - datetime.timedelta(
+ days=num_days_predate_stock_price
+ )
+ end_date = data.index.max() + datetime.timedelta(days=1)
+ df = _yfinance_request([stock_name], start_date=start_date, end_date=end_date)
+ # dropping second level of column header that yfinance returns
+ df.columns = df.columns.droplevel(1)
+ return df
diff --git a/requirements.txt b/requirements.txt
index 3216f4d..2fad55b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,6 +2,7 @@ numpy>=1.22.0
scipy>=1.2.0
pandas>=2.0
matplotlib>=3.0
+mplfinance>=0.12.10b0
quandl>=3.4.5
yfinance>=0.1.43
scikit-learn>=1.3.0
\ No newline at end of file
diff --git a/tests/test_momentum_indicators.py b/tests/test_momentum_indicators.py
new file mode 100644
index 0000000..e907400
--- /dev/null
+++ b/tests/test_momentum_indicators.py
@@ -0,0 +1,354 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import pytest
+
+from finquant.momentum_indicators import (
+ calculate_macd,
+ calculate_relative_strength_index,
+ calculate_wilder_smoothing_averages,
+ gen_macd_color,
+ plot_macd,
+ plot_relative_strength_index,
+)
+from finquant.utils import re_download_stock_data
+
+plt.close("all")
+plt.switch_backend("Agg")
+
+
+# Define a sample dataframe for testing
+price_data = np.array(
+ [100, 102, 105, 103, 108, 110, 107, 109, 112, 115, 120, 118, 121, 124, 125, 126]
+).astype(np.float64)
+data = pd.DataFrame({"Close": price_data})
+macd_data = pd.DataFrame(
+ {
+ "Date": pd.date_range(start="2022-01-01", periods=16, freq="D"),
+ "DIS": price_data,
+ }
+).set_index("Date", inplace=False)
+macd_data.name = "DIS"
+
+
+def test_calculate_wilder_smoothing_averages():
+ # Test with the example values
+ result = calculate_wilder_smoothing_averages(10.0, 5.0, 14)
+ assert (
+ abs(result - 9.642857142857142) <= 1e-15
+ ) # The expected result calculated manually
+
+ # Test with zero average gain/loss
+ result = calculate_wilder_smoothing_averages(0.0, 5.0, 14)
+ assert (
+ abs(result - 0.35714285714285715) <= 1e-15
+ ) # The expected result calculated manually
+
+ # Test with zero current gain/loss
+ result = calculate_wilder_smoothing_averages(10.0, 0.0, 14)
+ assert (
+ abs(result - 9.285714285714286) <= 1e-15
+ ) # The expected result calculated manually
+
+ # Test with window length of 1
+ result = calculate_wilder_smoothing_averages(10.0, 5.0, 1)
+ assert (
+ abs(result - 5.0) <= 1e-15
+ ) # Since window length is 1, the result should be the current gain/loss
+
+ # Test with negative values
+ result = calculate_wilder_smoothing_averages(-10.0, -5.0, 14)
+ assert (
+ abs(result - -9.642857142857142) <= 1e-15
+ ) # The expected result calculated manually
+
+ # Test with very large numbers
+ result = calculate_wilder_smoothing_averages(1e20, 1e20, int(1e20))
+ assert (
+ abs(result - 1e20) <= 1e-15
+ ) # The expected result is the same as input due to the large window length
+
+ # Test with non-float input (should raise an exception)
+ with pytest.raises(TypeError):
+ calculate_wilder_smoothing_averages("10.0", 5.0, 14)
+
+ # Test with window length of 0 (should raise an exception)
+ with pytest.raises(ValueError):
+ calculate_wilder_smoothing_averages(10.0, 5.0, 0)
+
+ # Test with negative window length (should raise an exception)
+ with pytest.raises(ValueError):
+ calculate_wilder_smoothing_averages(10.0, 5.0, -14)
+
+
+def test_calculate_relative_strength_index():
+ rsi = calculate_relative_strength_index(data["Close"])
+
+ # Check if the result is a Pandas Series
+ assert isinstance(rsi, pd.Series)
+
+ # Check the length of the result
+ assert len(rsi.dropna()) == len(data) - 14
+
+ # Check the first RSI value
+ assert np.isclose(rsi.dropna().iloc[0], 82.051282, rtol=1e-4)
+
+ # Check the last RSI value
+ assert np.isclose(rsi.iloc[-1], 82.53358925143954, rtol=1e-4)
+
+ # Check that the RSI values are within the range [0, 100]
+ assert (rsi.dropna() >= 0).all() and (rsi.dropna() <= 100).all()
+
+ # Check for window_length > data length, should raise a ValueError
+ with pytest.raises(ValueError):
+ calculate_relative_strength_index(data["Close"], window_length=17)
+
+ # Check for oversold >= overbought, should raise a ValueError
+ with pytest.raises(ValueError):
+ calculate_relative_strength_index(data["Close"], oversold=70, overbought=70)
+
+ # Check for invalid levels, should raise a ValueError
+ with pytest.raises(ValueError):
+ calculate_relative_strength_index(data["Close"], oversold=150, overbought=80)
+
+ with pytest.raises(ValueError):
+ calculate_relative_strength_index(data["Close"], oversold=20, overbought=120)
+
+ # Check for empty input data, should raise a ValueError
+ with pytest.raises(ValueError):
+ calculate_relative_strength_index(pd.Series([]))
+
+ # Check for non-Pandas Series input, should raise a TypeError
+ with pytest.raises(TypeError):
+ calculate_relative_strength_index(list(data["Close"]))
+
+
+def test_plot_relative_strength_index_standalone():
+ # Test standalone mode
+ xlabel_orig = "Date"
+ ylabel_orig = "RSI"
+ labels_orig = ["overbought", "oversold", "rsi"]
+ title_orig = "RSI Plot"
+ plot_relative_strength_index(data["Close"], standalone=True)
+ # get data from axis object
+ ax = plt.gca()
+ # ax.lines[2] is the RSI data
+ xlabel_plot = ax.get_xlabel()
+ ylabel_plot = ax.get_ylabel()
+ # tests
+ labels_plot = ax.get_legend_handles_labels()[1]
+ title_plot = ax.get_title()
+ assert labels_plot == labels_orig
+ assert xlabel_plot == xlabel_orig
+ assert ylabel_plot == ylabel_orig
+ assert title_plot == title_orig
+
+
+def test_plot_relative_strength_index_not_standalone():
+ # Test non-standalone mode
+ xlabel_orig = "Date"
+ ylabel_orig = "Price"
+ plot_relative_strength_index(data["Close"], standalone=False)
+ # get data from axis object
+ ax = plt.gca()
+ line1 = ax.lines[0]
+ stock_plot = line1.get_xydata()
+ xlabel_plot = ax.get_xlabel()
+ ylabel_plot = ax.get_ylabel()
+ # tests
+ assert (data["Close"].index.values == stock_plot[:, 0]).all()
+ assert (data["Close"].values == stock_plot[:, 1]).all()
+ assert xlabel_orig == xlabel_plot
+ assert ylabel_orig == ylabel_plot
+
+
+def test_gen_macd_color_valid_input():
+ # Test with valid input
+ macd_df = pd.DataFrame({"MACDh": [0.5, -0.2, 0.8, -0.6, 0.2]})
+ colors = gen_macd_color(macd_df)
+
+ # Check that the result is a list
+ assert isinstance(colors, list)
+
+ # Check the length of the result
+ assert len(colors) == len(macd_df)
+
+ # Check color assignments based on MACD values
+ assert colors == ["#26A69A", "#FF5252", "#26A69A", "#FF5252", "#26A69A"]
+
+
+def test_gen_macd_color_green():
+ # Test with a DataFrame where MACD values are consistently positive, should return
+ # all green colors
+ positive_df = pd.DataFrame({"MACDh": [0.5, 0.6, 0.7, 0.8, 0.9]})
+ colors = gen_macd_color(positive_df)
+
+ # Check that the result is a list of all green colors
+ assert colors == ["#B2DFDB", "#26A69A", "#26A69A", "#26A69A", "#26A69A"]
+
+
+def test_gen_macd_color_faint_green():
+ # Test with a DataFrame where MACD values are consistently positive but decreasing,
+ # should return all faint green colors
+ faint_green_df = pd.DataFrame({"MACDh": [0.5, 0.4, 0.3, 0.2, 0.1]})
+ colors = gen_macd_color(faint_green_df)
+
+ # Check that the result is a list of all faint green colors
+ assert colors == ["#26A69A", "#B2DFDB", "#B2DFDB", "#B2DFDB", "#B2DFDB"]
+
+
+def test_gen_macd_color_red():
+ # Test with a DataFrame where MACD values are consistently negative,
+ # should return all red colors
+ negative_df = pd.DataFrame({"MACDh": [-0.5, -0.6, -0.7, -0.8, -0.9]})
+ colors = gen_macd_color(negative_df)
+
+ # Check that the result is a list of all red colors
+ assert colors == ["#FFCDD2", "#FF5252", "#FF5252", "#FF5252", "#FF5252"]
+
+
+def test_gen_macd_color_faint_red():
+ # Test with a DataFrame where MACD values are consistently negative but decreasing,
+ # should return all faint red colors
+ faint_red_df = pd.DataFrame({"MACDh": [-0.5, -0.4, -0.3, -0.2, -0.1]})
+ colors = gen_macd_color(faint_red_df)
+
+ # Check that the result is a list of all faint red colors
+ assert colors == ["#FF5252", "#FFCDD2", "#FFCDD2", "#FFCDD2", "#FFCDD2"]
+
+
+def test_gen_macd_color_single_element():
+ # Test with a DataFrame containing a single element, should return a list with one color
+ single_element_df = pd.DataFrame({"MACDh": [0.5]})
+ colors = gen_macd_color(single_element_df)
+
+ # Check that the result is a list with one color
+ assert colors == ["#000000"]
+
+
+def test_gen_macd_color_empty_input():
+ # Test with an empty DataFrame, should return an empty list
+ empty_df = pd.DataFrame(columns=["MACDh"])
+ with pytest.raises(ValueError):
+ colors = gen_macd_color(empty_df)
+
+
+def test_gen_macd_color_missing_column():
+ # Test with a DataFrame missing 'MACDh' column, should raise a KeyError
+ df_missing_column = pd.DataFrame({"NotMACDh": [0.5, -0.2, 0.8, -0.6, 0.2]})
+
+ with pytest.raises(KeyError):
+ gen_macd_color(df_missing_column)
+
+
+def test_gen_macd_color_no_color_change():
+ # Test with a DataFrame where MACD values don't change, should return all black colors
+ no_change_df = pd.DataFrame({"MACDh": [0.5, 0.5, 0.5, 0.5, 0.5]})
+ colors = gen_macd_color(no_change_df)
+
+ # Check that the result is a list of all black colors
+ assert colors == ["#000000", "#000000", "#000000", "#000000", "#000000"]
+
+
+def test_calculate_macd_valid_input():
+ # Test with valid input
+ result = calculate_macd(macd_data, num_days_predate_stock_price=0)
+
+ # Check that the result is a DataFrame
+ assert isinstance(result, pd.DataFrame)
+
+ # Check the length of the result
+ assert len(result) == 10
+ # not == len(macd_data) here, as we currently re-download data, weekends are not considered
+
+ # Check that the required columns ('MACD', 'MACDh', 'MACDs') are present in the result
+ assert all(col in result.columns for col in ["MACD", "MACDh", "MACDs"])
+
+
+def test_calculate_macd_correct_values():
+ # Test for correct values in 'MACD', 'MACDh', and 'MACDs' columns
+ longer_ema_window = 10
+ shorter_ema_window = 7
+ signal_ema_window = 4
+ df = re_download_stock_data(
+ macd_data, stock_name="DIS", num_days_predate_stock_price=0
+ )
+ result = calculate_macd(
+ macd_data,
+ longer_ema_window=longer_ema_window,
+ shorter_ema_window=shorter_ema_window,
+ signal_ema_window=signal_ema_window,
+ num_days_predate_stock_price=0,
+ )
+
+ # Calculate expected values manually (using the provided df)
+ ema_short = (
+ df["Close"]
+ .ewm(span=shorter_ema_window, adjust=False, min_periods=shorter_ema_window)
+ .mean()
+ )
+ ema_long = (
+ df["Close"]
+ .ewm(span=longer_ema_window, adjust=False, min_periods=longer_ema_window)
+ .mean()
+ )
+ macd = ema_short - ema_long
+ macd.name = "MACD"
+ signal = macd.ewm(
+ span=signal_ema_window, adjust=False, min_periods=signal_ema_window
+ ).mean()
+ macd_h = macd - signal
+
+ # Check that the calculated values match the values in the DataFrame
+ assert all(result["MACD"].dropna() == macd.dropna())
+ assert all(result["MACDh"].dropna() == macd_h.dropna())
+ assert all(result["MACDs"].dropna() == signal.dropna())
+
+
+def test_calculate_macd_custom_windows():
+ # Test with custom EMA window values
+ result = calculate_macd(
+ macd_data, longer_ema_window=30, shorter_ema_window=15, signal_ema_window=10
+ )
+
+ # Check that the result is a DataFrame
+ assert isinstance(result, pd.DataFrame)
+
+ # Check that the required columns ('MACD', 'MACDh', 'MACDs') are present in the result
+ assert all(col in result.columns for col in ["MACD", "MACDh", "MACDs"])
+
+
+def test_calculate_macd_invalid_windows():
+ # Test with invalid window values, should raise ValueError
+ with pytest.raises(ValueError):
+ calculate_macd(
+ macd_data, longer_ema_window=10, shorter_ema_window=20, signal_ema_window=15
+ )
+ with pytest.raises(ValueError):
+ plot_macd(
+ macd_data, longer_ema_window=10, shorter_ema_window=5, signal_ema_window=30
+ )
+
+
+def test_plot_macd():
+ axes0_ylabel_orig = "Price"
+ axes4_ylabel_orig = "Volume $10^{6}$"
+ # Create sample data for testing
+ x = np.sin(np.linspace(1, 10, 100))
+ df = pd.DataFrame(
+ {"Close": x}, index=pd.date_range("2015-01-01", periods=100, freq="D")
+ )
+ df.name = "DIS"
+
+ # Call mpl_macd function
+ fig, axes = plot_macd(df)
+
+ axes0_ylabel_plot = axes[0].get_ylabel()
+ axes4_ylabel_plot = axes[4].get_ylabel()
+
+ # Check if the function returned valid figures and axes objects
+ assert isinstance(fig, plt.Figure)
+ assert isinstance(axes, list)
+ assert len(axes) == 6 # Assuming there are six subplots in the returned figure
+ assert axes0_ylabel_orig == axes0_ylabel_plot
+ assert axes4_ylabel_orig == axes4_ylabel_plot
diff --git a/version b/version
index 7ad36c6..37902e9 100644
--- a/version
+++ b/version
@@ -1,2 +1,2 @@
-version=0.7.0
-release=0.7.0
+version=0.8.0
+release=0.8.0