diff --git a/README.md b/README.md index 0f1b8cc..b72c34f 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ pypi - pypi + pypi GitHub Actions diff --git a/README.tex.md b/README.tex.md index 6dd741f..4872162 100644 --- a/README.tex.md +++ b/README.tex.md @@ -7,7 +7,7 @@ pypi - pypi + pypi GitHub Actions diff --git a/example/Example-Analysis.py b/example/Example-Analysis.py index 40507ef..79ca9f0 100644 --- a/example/Example-Analysis.py +++ b/example/Example-Analysis.py @@ -316,3 +316,45 @@ # print(pf.data.loc[pf.data.index.year == 2017].head(3)) + +# + +# ## Momentum Indicators +# `FinQuant` provides a module `finquant.momentum_indicators` to compute and +# visualize a number of momentum indicators. Currently RSI (Relative Strength Index) +# and MACD (Moving Average Convergence Divergence) indicators are available. +# See below. + +# +# plot the RSI (Relative Strength Index) for disney stock proces +from finquant.momentum_indicators import plot_relative_strength_index as rsi + +# get stock data for disney +dis = pf.get_stock("WIKI/DIS").data.copy(deep=True) + +# plot RSI - by default this plots RSI against the price in two graphs +rsi(dis) +plt.show() + +# plot RSI with custom arguments +rsi(dis, oversold=20, overbought=80) +plt.show() + +# plot RSI standalone graph +rsi(dis, oversold=20, overbought=80, standalone=True) +plt.show() + +# +# plot MACD for disney stock proces +from finquant.momentum_indicators import plot_macd + +# using short time frame of data due to plot warnings from matplotlib/mplfinance +dis = dis[0:300] + +# plot MACD - by default this plots RSI against the price in two graphs +plot_macd(dis) +plt.show() + +# plot MACD using custom arguments +plot_macd(dis, longer_ema_window=30, shorter_ema_window=15, signal_ema_window=10) +plt.show() diff --git a/finquant/momentum_indicators.py b/finquant/momentum_indicators.py new file mode 100644 index 0000000..3474e1c --- /dev/null +++ b/finquant/momentum_indicators.py @@ -0,0 +1,427 @@ +""" This module provides function(s) to compute momentum indicators +used in technical analysis such as RSI, MACD etc. """ +from typing import List, Optional + +import matplotlib.pyplot as plt +import mplfinance as mpf +import pandas as pd + +from finquant.data_types import FLOAT, INT, SERIES_OR_DATAFRAME +from finquant.type_utilities import type_validation +from finquant.utils import all_list_ele_in_other, re_download_stock_data + + +def calculate_wilder_smoothing_averages( + avg_gain_loss: FLOAT, gain_loss: FLOAT, window_length: INT +) -> FLOAT: + """ + Calculate Wilder's Smoothing Averages. + + Wilder's Smoothing Averages are used in technical analysis, particularly for + calculating indicators like the Relative Strength Index (RSI). This function + takes the average gain/loss, the current gain/loss, and the window length as + input and returns the smoothed average. + + :param avg_gain_loss: The previous average gain/loss. + :type avg_gain_loss: :py:data:`~.finquant.data_types.FLOAT` + :param gain_loss: The current gain or loss. + :type gain_loss: :py:data:`~.finquant.data_types.FLOAT` + :param window_length: The length of the smoothing window. + :type window_length: :py:data:`~.finquant.data_types.FLOAT` + + :return: The Wilder's smoothed average value. + :rtype: :py:data:`~.finquant.data_types.FLOAT` + + Example: + + .. code-block:: python + + calculate_wilder_smoothing_averages(10.0, 5.0, 14) + + """ + # Type validations: + type_validation( + avg_gain_loss=avg_gain_loss, + gain_loss=gain_loss, + window_length=window_length, + ) + # validating window_length value range + if window_length <= 0: + raise ValueError("Error: window_length must be > 0.") + return (avg_gain_loss * (window_length - 1) + gain_loss) / float(window_length) + + +def calculate_relative_strength_index( + data: SERIES_OR_DATAFRAME, + window_length: INT = 14, + oversold: INT = 30, + overbought: INT = 70, +) -> pd.Series: + """Computes the relative strength index of given stock price data. + + Ref: https://www.investopedia.com/terms/r/rsi.asp + + :param data: A series/dataframe of daily stock prices + :type data: :py:data:`~.finquant.data_types.SERIES_OR_DATAFRAME` + :param window_length: Window length to compute RSI, default being 14 days + :type window_length: :py:data:`~.finquant.data_types.INT` + :param oversold: Standard level for oversold RSI, default being 30 + :type oversold: :py:data:`~.finquant.data_types.INT` + :param overbought: Standard level for overbought RSI, default being 70 + :type overbought: :py:data:`~.finquant.data_types.INT` + + :return: A Series of RSI values. + """ + # Type validations: + type_validation( + data=data, + window_length=window_length, + oversold=oversold, + overbought=overbought, + ) + # validating levels + if oversold >= overbought: + raise ValueError("oversold level should be < overbought level") + if not 0 < oversold < 100 or not 0 < overbought < 100: + raise ValueError("levels should be > 0 and < 100") + + if window_length > len(data): + raise ValueError("Error: window_length must be <= len(data).") + + # converting data to pd.DataFrame if it is a pd.Series (for subsequent function calls): + if isinstance(data, pd.Series): + data = data.to_frame() + + # calculate price differences + data["diff"] = data.diff(periods=1) + # calculate gains and losses + data["gain"] = data["diff"].clip(lower=0) + data["loss"] = data["diff"].clip(upper=0).abs() + # calculate rolling window mean gains and losses + data["avg_gain"] = ( + data["gain"].rolling(window=window_length, min_periods=window_length).mean() + ) + data["avg_loss"] = ( + data["loss"].rolling(window=window_length, min_periods=window_length).mean() + ) + # ignore SettingWithCopyWarning for the below operation + with pd.option_context("mode.chained_assignment", None): + for gain_or_loss in ["gain", "loss"]: + for idx, _ in enumerate( + data[f"avg_{gain_or_loss}"].iloc[window_length + 1 :] + ): + data[f"avg_{gain_or_loss}"].iloc[ + idx + window_length + 1 + ] = calculate_wilder_smoothing_averages( + data[f"avg_{gain_or_loss}"].iloc[idx + window_length], + data[gain_or_loss].iloc[idx + window_length + 1], + window_length, + ) + # calculate RS values + data["rs"] = data["avg_gain"] / data["avg_loss"] + # calculate RSI + data["rsi"] = 100 - (100 / (1.0 + data["rs"])) + return data["rsi"] + + +def plot_relative_strength_index( + data: SERIES_OR_DATAFRAME, + window_length: INT = 14, + oversold: INT = 30, + overbought: INT = 70, + standalone: bool = False, +) -> None: + """Computes and visualizes a RSI graph, + plotted along with the prices in another sub-graph + for comparison. + + Ref: https://www.investopedia.com/terms/r/rsi.asp + + :param data: A series/dataframe of daily stock prices + :type data: :py:data:`~.finquant.data_types.SERIES_OR_DATAFRAME` + :param window_length: Window length to compute RSI, default being 14 days + :type window_length: :py:data:`~.finquant.data_types.INT` + :param oversold: Standard level for oversold RSI, default being 30 + :type oversold: :py:data:`~.finquant.data_types.INT` + :param overbought: Standard level for overbought RSI, default being 70 + :type overbought: :py:data:`~.finquant.data_types.INT` + :param standalone: Plot only the RSI graph + """ + + # converting data to pd.DataFrame if it is a pd.Series (for subsequent function calls): + if isinstance(data, pd.Series): + data = data.to_frame() + # Get stock name: + stock_name = data.keys()[0] + + # compute RSI: + data["rsi"] = calculate_relative_strength_index( + data, window_length=window_length, oversold=oversold, overbought=overbought + ) + + # Plot it + if standalone: + # Single plot + fig = plt.figure() + axis = fig.add_subplot(111) + axis.axhline( + y=float(overbought), color="r", linestyle="dashed", label="overbought" + ) + axis.axhline(y=float(oversold), color="g", linestyle="dashed", label="oversold") + axis.set_ylim(0, 100) + data["rsi"].plot(ylabel="RSI", xlabel="Date", ax=axis, grid=True) + plt.title("RSI Plot") + plt.legend(loc="center left", bbox_to_anchor=(1, 0.5)) + else: + # RSI against price in 2 plots + fig, axis = plt.subplots(2, 1, sharex=True, sharey=False) + axis[0].axhline( + y=float(overbought), color="r", linestyle="dashed", label="overbought" + ) + axis[0].axhline( + y=float(oversold), color="g", linestyle="dashed", label="oversold" + ) + axis[0].set_title("RSI + Price Plot") + axis[0].set_ylim(0, 100) + # plot 2 graphs in 2 colors + colors = plt.rcParams["axes.prop_cycle"]() + data["rsi"].plot( + ylabel="RSI", + ax=axis[0], + grid=True, + color=next(colors)["color"], + legend=True, + ).legend(loc="center left", bbox_to_anchor=(1, 0.5)) + data[stock_name].plot( + xlabel="Date", + ylabel="Price", + ax=axis[1], + grid=True, + color=next(colors)["color"], + legend=True, + ).legend(loc="center left", bbox_to_anchor=(1, 0.5)) + + +# Generating colors for MACD histogram +def gen_macd_color(df: pd.DataFrame) -> List[str]: + """ + Generate a list of color codes based on MACD histogram values in a DataFrame. + + This function takes a DataFrame containing MACD histogram values ('MACDh') and + assigns colors to each data point based on the direction of change in MACD values. + + :param df: A series/dataframe of MACD histogram values + + :return: A list of color codes corresponding to each data point in the DataFrame. + + Note: + - This function assumes that the DataFrame contains a column named 'MACDh'. + - The color assignments are based on the comparison of each data point with its + previous data point in the 'MACDh' column. + + Example: + + .. code-block:: python + + import pandas as pd + from typing import List + + # Create a DataFrame with MACD histogram values + df = pd.DataFrame({'MACDh': [0.5, -0.2, 0.8, -0.6, 0.2]}) + + # Generate MACD color codes + colors = gen_macd_color(df) + print(colors) # Output: ['#26A69A', '#FFCDD2', '#26A69A', '#FFCDD2', '#26A69A'] + + """ + # Type validations: + type_validation(df=df) + macd_color = [] + macd_color.clear() + for idx in range(0, len(df["MACDh"])): + if ( + df["MACDh"].iloc[idx] >= 0 + and df["MACDh"].iloc[idx - 1] < df["MACDh"].iloc[idx] + ): + macd_color.append("#26A69A") # green + elif ( + df["MACDh"].iloc[idx] >= 0 + and df["MACDh"].iloc[idx - 1] > df["MACDh"].iloc[idx] + ): + macd_color.append("#B2DFDB") # faint green + elif ( + df["MACDh"].iloc[idx] < 0 + and df["MACDh"].iloc[idx - 1] > df["MACDh"].iloc[idx] + ): + macd_color.append("#FF5252") # red + elif ( + df["MACDh"].iloc[idx] < 0 + and df["MACDh"].iloc[idx - 1] < df["MACDh"].iloc[idx] + ): + macd_color.append("#FFCDD2") # faint red + else: + macd_color.append("#000000") + return macd_color + + +def calculate_macd( + data: SERIES_OR_DATAFRAME, + longer_ema_window: Optional[INT] = 26, + shorter_ema_window: Optional[INT] = 12, + signal_ema_window: Optional[INT] = 9, + stock_name: Optional[str] = None, + num_days_predate_stock_price: Optional[INT] = 31, +) -> pd.DataFrame: + # Type validations: + type_validation( + data=data, + longer_ema_window=longer_ema_window, + shorter_ema_window=shorter_ema_window, + signal_ema_window=signal_ema_window, + name=stock_name, + num_days_predate_stock_price=num_days_predate_stock_price, + ) + + # validating windows + if longer_ema_window < shorter_ema_window: + raise ValueError("longer ema window should be > shorter ema window") + if longer_ema_window < signal_ema_window: + raise ValueError("longer ema window should be > signal ema window") + + # Taking care of potential column header clash, removing "WIKI/" (which comes from legacy quandl) + if stock_name is None: + stock_name = data.name + if "WIKI/" in stock_name: + stock_name = stock_name.replace("WIKI/", "") + if isinstance(data, pd.Series): + data = data.to_frame() + # Remove prefix substring from column headers + data.columns = data.columns.str.replace("WIKI/", "") + + # Check if required columns are present, if data is a pd.DataFrame, else re-download stock price data: + download_stock_data_again = True + if isinstance(data, pd.DataFrame) and all_list_ele_in_other( + ["Open", "Close", "High", "Low", "Volume"], data.columns + ): + download_stock_data_again = False + if download_stock_data_again: + df = re_download_stock_data( + data, + stock_name=stock_name, + num_days_predate_stock_price=num_days_predate_stock_price, + ) + else: + df = data + + # Get the shorter_ema_window-day EMA of the closing price + ema_short = ( + df["Close"] + .ewm(span=shorter_ema_window, adjust=False, min_periods=shorter_ema_window) + .mean() + ) + # Get the longer_ema_window-day EMA of the closing price + ema_long = ( + df["Close"] + .ewm(span=longer_ema_window, adjust=False, min_periods=longer_ema_window) + .mean() + ) + + # Subtract the longer_ema_window-day EMA from the shorter_ema_window-Day EMA to get the MACD + macd = ema_short - ema_long + # Get the signal_ema_window-Day EMA of the MACD for the Trigger line + macd_s = macd.ewm( + span=signal_ema_window, adjust=False, min_periods=signal_ema_window + ).mean() + # Calculate the difference between the MACD - Trigger for the Convergence/Divergence value + macd_h = macd - macd_s + + # Add all of our new values for the MACD to the dataframe + df["MACD"] = df.index.map(macd) + df["MACDh"] = df.index.map(macd_h) + df["MACDs"] = df.index.map(macd_s) + return df + + +def plot_macd( + data: SERIES_OR_DATAFRAME, + longer_ema_window: Optional[INT] = 26, + shorter_ema_window: Optional[INT] = 12, + signal_ema_window: Optional[INT] = 9, + stock_name: Optional[str] = None, +): + """ + Generate a Matplotlib candlestick chart with MACD (Moving Average Convergence Divergence) indicators. + + Ref: https://github.com/matplotlib/mplfinance/blob/master/examples/indicators/macd_histogram_gradient.ipynb + + This function creates a candlestick chart using the given stock price data and overlays + MACD, MACD Signal Line, and MACD Histogram indicators. The MACD is calculated by taking + the difference between two Exponential Moving Averages (EMAs) of the closing price. + + :param data: Time series data containing stock price information. If a + DataFrame is provided, it should have columns 'Open', 'Close', 'High', 'Low', and 'Volume'. + Else, stock price data for given time frame is downloaded again. + :type data: :py:data:`~.finquant.data_types.SERIES_OR_DATAFRAME` + :param longer_ema_window: Optional, window size for the longer-term EMA (default is 26). + :type longer_ema_window: :py:data:`~.finquant.data_types.INT` + :param shorter_ema_window: Optional, window size for the shorter-term EMA (default is 12). + :type shorter_ema_window: :py:data:`~.finquant.data_types.INT` + :param signal_ema_window: Optional, window size for the signal line EMA (default is 9). + :type signal_ema_window: :py:data:`~.finquant.data_types.INT` + :param stock_name: Optional, name of the stock for labeling purposes (default is None). + + Note: + - If the input data is a DataFrame, it should contain columns 'Open', 'Close', 'High', 'Low', and 'Volume'. + - If the input data is a Series, it should have a valid name. + - The longer EMA window should be greater than or equal to the shorter EMA window and signal EMA window. + + Example: + + .. code-block:: python + + import pandas as pd + from mplfinance.original_flavor import plot as mpf + + # Create a DataFrame or Series with stock price data + data = pd.read_csv('stock_data.csv', index_col='Date', parse_dates=True) + plot_macd(data, longer_ema_window=26, shorter_ema_window=12, signal_ema_window=9, stock_name='DIS') + + """ + # calculate MACD: + df = calculate_macd( + data, + longer_ema_window, + shorter_ema_window, + signal_ema_window, + stock_name=stock_name, + ) + + # plot macd + macd_color = gen_macd_color(df) + apds = [ + mpf.make_addplot(df["MACD"], color="#2962FF", panel=1), + mpf.make_addplot(df["MACDs"], color="#FF6D00", panel=1), + mpf.make_addplot( + df["MACDh"], + type="bar", + width=0.7, + panel=1, + color=macd_color, + alpha=1, + secondary_y=True, + ), + ] + fig, axes = mpf.plot( + df, + volume=True, + type="candle", + style="yahoo", + addplot=apds, + volume_panel=2, + figsize=(20, 10), + returnfig=True, + ) + axes[2].legend(["MACD"], loc="upper left") + axes[3].legend(["Signal"], loc="lower left") + + return fig, axes diff --git a/finquant/type_utilities.py b/finquant/type_utilities.py index 25d86fd..e46f441 100644 --- a/finquant/type_utilities.py +++ b/finquant/type_utilities.py @@ -64,7 +64,7 @@ def _check_type( if element_type is not None: if isinstance(arg_values, pd.DataFrame) and not all( - arg_values.dtypes == element_type + np.issubdtype(value_type, element_type) for value_type in arg_values.dtypes ): validation_failed = True @@ -114,7 +114,7 @@ def _check_empty_data(arg_name: str, arg_values: Any) -> None: ], ] = { # DataFrames, Series, Array: - "data": ((pd.Series, pd.DataFrame), np.floating), + "data": ((pd.Series, pd.DataFrame), np.number), "pf_allocation": (pd.DataFrame, None), "returns_df": (pd.DataFrame, np.floating), "returns_series": (pd.Series, np.floating), @@ -124,6 +124,7 @@ def _check_empty_data(arg_name: str, arg_values: Any) -> None: "initial_weights": (np.ndarray, np.floating), "weights_array": (np.ndarray, np.floating), "cov_matrix": ((np.ndarray, pd.DataFrame), np.floating), + "df": (pd.DataFrame, None), # Lists: "names": ((List, np.ndarray), str), "cols": ((List, np.ndarray), str), @@ -150,15 +151,25 @@ def _check_empty_data(arg_name: str, arg_values: Any) -> None: "freq": ((int, np.integer), None), "span": ((int, np.integer), None), "num_trials": ((int, np.integer), None), + "longer_ema_window": ((int, np.integer), None), + "shorter_ema_window": ((int, np.integer), None), + "signal_ema_window": ((int, np.integer), None), + "window_length": ((int, np.integer), None), + "oversold": ((int, np.integer), None), + "overbought": ((int, np.integer), None), + "num_days_predate_stock_price": ((int, np.integer), None), # NUMERICs: "investment": ((int, np.integer, float, np.floating), None), "dividend": ((int, np.integer, float, np.floating), None), "target": ((int, np.integer, float, np.floating), None), + "avg_gain_loss": ((int, np.integer, float, np.floating), None), + "gain_loss": ((int, np.integer, float, np.floating), None), # Booleans: "plot": (bool, None), "save_weights": (bool, None), "verbose": (bool, None), "defer_update": (bool, None), + "standalone": (bool, None), } type_callable_dict: Dict[ diff --git a/finquant/utils.py b/finquant/utils.py new file mode 100644 index 0000000..e9c34d9 --- /dev/null +++ b/finquant/utils.py @@ -0,0 +1,40 @@ +import datetime +from typing import Optional + +import pandas as pd + +from finquant.data_types import ELEMENT_TYPE, INT, LIST_DICT_KEYS, SERIES_OR_DATAFRAME +from finquant.portfolio import _yfinance_request +from finquant.type_utilities import type_validation + + +def all_list_ele_in_other( + l_1: LIST_DICT_KEYS[ELEMENT_TYPE], + l_2: LIST_DICT_KEYS[ELEMENT_TYPE], +) -> bool: + """Returns True if all elements of list l1 are found in list l2.""" + return all(ele in l_2 for ele in l_1) + + +def re_download_stock_data( + data: SERIES_OR_DATAFRAME, + stock_name: str, + num_days_predate_stock_price: Optional[INT] = 0, +) -> pd.DataFrame: + # Type validations: + type_validation( + data=data, + name=stock_name, + num_days_predate_stock_price=num_days_predate_stock_price, + ) + if num_days_predate_stock_price < 0: + raise ValueError("Error: num_days_predate_stock_price must be >= 0.") + # download additional price data 'Open' for given stock and timeframe: + start_date = data.index.min() - datetime.timedelta( + days=num_days_predate_stock_price + ) + end_date = data.index.max() + datetime.timedelta(days=1) + df = _yfinance_request([stock_name], start_date=start_date, end_date=end_date) + # dropping second level of column header that yfinance returns + df.columns = df.columns.droplevel(1) + return df diff --git a/requirements.txt b/requirements.txt index 3216f4d..2fad55b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ numpy>=1.22.0 scipy>=1.2.0 pandas>=2.0 matplotlib>=3.0 +mplfinance>=0.12.10b0 quandl>=3.4.5 yfinance>=0.1.43 scikit-learn>=1.3.0 \ No newline at end of file diff --git a/tests/test_momentum_indicators.py b/tests/test_momentum_indicators.py new file mode 100644 index 0000000..e907400 --- /dev/null +++ b/tests/test_momentum_indicators.py @@ -0,0 +1,354 @@ +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import pytest + +from finquant.momentum_indicators import ( + calculate_macd, + calculate_relative_strength_index, + calculate_wilder_smoothing_averages, + gen_macd_color, + plot_macd, + plot_relative_strength_index, +) +from finquant.utils import re_download_stock_data + +plt.close("all") +plt.switch_backend("Agg") + + +# Define a sample dataframe for testing +price_data = np.array( + [100, 102, 105, 103, 108, 110, 107, 109, 112, 115, 120, 118, 121, 124, 125, 126] +).astype(np.float64) +data = pd.DataFrame({"Close": price_data}) +macd_data = pd.DataFrame( + { + "Date": pd.date_range(start="2022-01-01", periods=16, freq="D"), + "DIS": price_data, + } +).set_index("Date", inplace=False) +macd_data.name = "DIS" + + +def test_calculate_wilder_smoothing_averages(): + # Test with the example values + result = calculate_wilder_smoothing_averages(10.0, 5.0, 14) + assert ( + abs(result - 9.642857142857142) <= 1e-15 + ) # The expected result calculated manually + + # Test with zero average gain/loss + result = calculate_wilder_smoothing_averages(0.0, 5.0, 14) + assert ( + abs(result - 0.35714285714285715) <= 1e-15 + ) # The expected result calculated manually + + # Test with zero current gain/loss + result = calculate_wilder_smoothing_averages(10.0, 0.0, 14) + assert ( + abs(result - 9.285714285714286) <= 1e-15 + ) # The expected result calculated manually + + # Test with window length of 1 + result = calculate_wilder_smoothing_averages(10.0, 5.0, 1) + assert ( + abs(result - 5.0) <= 1e-15 + ) # Since window length is 1, the result should be the current gain/loss + + # Test with negative values + result = calculate_wilder_smoothing_averages(-10.0, -5.0, 14) + assert ( + abs(result - -9.642857142857142) <= 1e-15 + ) # The expected result calculated manually + + # Test with very large numbers + result = calculate_wilder_smoothing_averages(1e20, 1e20, int(1e20)) + assert ( + abs(result - 1e20) <= 1e-15 + ) # The expected result is the same as input due to the large window length + + # Test with non-float input (should raise an exception) + with pytest.raises(TypeError): + calculate_wilder_smoothing_averages("10.0", 5.0, 14) + + # Test with window length of 0 (should raise an exception) + with pytest.raises(ValueError): + calculate_wilder_smoothing_averages(10.0, 5.0, 0) + + # Test with negative window length (should raise an exception) + with pytest.raises(ValueError): + calculate_wilder_smoothing_averages(10.0, 5.0, -14) + + +def test_calculate_relative_strength_index(): + rsi = calculate_relative_strength_index(data["Close"]) + + # Check if the result is a Pandas Series + assert isinstance(rsi, pd.Series) + + # Check the length of the result + assert len(rsi.dropna()) == len(data) - 14 + + # Check the first RSI value + assert np.isclose(rsi.dropna().iloc[0], 82.051282, rtol=1e-4) + + # Check the last RSI value + assert np.isclose(rsi.iloc[-1], 82.53358925143954, rtol=1e-4) + + # Check that the RSI values are within the range [0, 100] + assert (rsi.dropna() >= 0).all() and (rsi.dropna() <= 100).all() + + # Check for window_length > data length, should raise a ValueError + with pytest.raises(ValueError): + calculate_relative_strength_index(data["Close"], window_length=17) + + # Check for oversold >= overbought, should raise a ValueError + with pytest.raises(ValueError): + calculate_relative_strength_index(data["Close"], oversold=70, overbought=70) + + # Check for invalid levels, should raise a ValueError + with pytest.raises(ValueError): + calculate_relative_strength_index(data["Close"], oversold=150, overbought=80) + + with pytest.raises(ValueError): + calculate_relative_strength_index(data["Close"], oversold=20, overbought=120) + + # Check for empty input data, should raise a ValueError + with pytest.raises(ValueError): + calculate_relative_strength_index(pd.Series([])) + + # Check for non-Pandas Series input, should raise a TypeError + with pytest.raises(TypeError): + calculate_relative_strength_index(list(data["Close"])) + + +def test_plot_relative_strength_index_standalone(): + # Test standalone mode + xlabel_orig = "Date" + ylabel_orig = "RSI" + labels_orig = ["overbought", "oversold", "rsi"] + title_orig = "RSI Plot" + plot_relative_strength_index(data["Close"], standalone=True) + # get data from axis object + ax = plt.gca() + # ax.lines[2] is the RSI data + xlabel_plot = ax.get_xlabel() + ylabel_plot = ax.get_ylabel() + # tests + labels_plot = ax.get_legend_handles_labels()[1] + title_plot = ax.get_title() + assert labels_plot == labels_orig + assert xlabel_plot == xlabel_orig + assert ylabel_plot == ylabel_orig + assert title_plot == title_orig + + +def test_plot_relative_strength_index_not_standalone(): + # Test non-standalone mode + xlabel_orig = "Date" + ylabel_orig = "Price" + plot_relative_strength_index(data["Close"], standalone=False) + # get data from axis object + ax = plt.gca() + line1 = ax.lines[0] + stock_plot = line1.get_xydata() + xlabel_plot = ax.get_xlabel() + ylabel_plot = ax.get_ylabel() + # tests + assert (data["Close"].index.values == stock_plot[:, 0]).all() + assert (data["Close"].values == stock_plot[:, 1]).all() + assert xlabel_orig == xlabel_plot + assert ylabel_orig == ylabel_plot + + +def test_gen_macd_color_valid_input(): + # Test with valid input + macd_df = pd.DataFrame({"MACDh": [0.5, -0.2, 0.8, -0.6, 0.2]}) + colors = gen_macd_color(macd_df) + + # Check that the result is a list + assert isinstance(colors, list) + + # Check the length of the result + assert len(colors) == len(macd_df) + + # Check color assignments based on MACD values + assert colors == ["#26A69A", "#FF5252", "#26A69A", "#FF5252", "#26A69A"] + + +def test_gen_macd_color_green(): + # Test with a DataFrame where MACD values are consistently positive, should return + # all green colors + positive_df = pd.DataFrame({"MACDh": [0.5, 0.6, 0.7, 0.8, 0.9]}) + colors = gen_macd_color(positive_df) + + # Check that the result is a list of all green colors + assert colors == ["#B2DFDB", "#26A69A", "#26A69A", "#26A69A", "#26A69A"] + + +def test_gen_macd_color_faint_green(): + # Test with a DataFrame where MACD values are consistently positive but decreasing, + # should return all faint green colors + faint_green_df = pd.DataFrame({"MACDh": [0.5, 0.4, 0.3, 0.2, 0.1]}) + colors = gen_macd_color(faint_green_df) + + # Check that the result is a list of all faint green colors + assert colors == ["#26A69A", "#B2DFDB", "#B2DFDB", "#B2DFDB", "#B2DFDB"] + + +def test_gen_macd_color_red(): + # Test with a DataFrame where MACD values are consistently negative, + # should return all red colors + negative_df = pd.DataFrame({"MACDh": [-0.5, -0.6, -0.7, -0.8, -0.9]}) + colors = gen_macd_color(negative_df) + + # Check that the result is a list of all red colors + assert colors == ["#FFCDD2", "#FF5252", "#FF5252", "#FF5252", "#FF5252"] + + +def test_gen_macd_color_faint_red(): + # Test with a DataFrame where MACD values are consistently negative but decreasing, + # should return all faint red colors + faint_red_df = pd.DataFrame({"MACDh": [-0.5, -0.4, -0.3, -0.2, -0.1]}) + colors = gen_macd_color(faint_red_df) + + # Check that the result is a list of all faint red colors + assert colors == ["#FF5252", "#FFCDD2", "#FFCDD2", "#FFCDD2", "#FFCDD2"] + + +def test_gen_macd_color_single_element(): + # Test with a DataFrame containing a single element, should return a list with one color + single_element_df = pd.DataFrame({"MACDh": [0.5]}) + colors = gen_macd_color(single_element_df) + + # Check that the result is a list with one color + assert colors == ["#000000"] + + +def test_gen_macd_color_empty_input(): + # Test with an empty DataFrame, should return an empty list + empty_df = pd.DataFrame(columns=["MACDh"]) + with pytest.raises(ValueError): + colors = gen_macd_color(empty_df) + + +def test_gen_macd_color_missing_column(): + # Test with a DataFrame missing 'MACDh' column, should raise a KeyError + df_missing_column = pd.DataFrame({"NotMACDh": [0.5, -0.2, 0.8, -0.6, 0.2]}) + + with pytest.raises(KeyError): + gen_macd_color(df_missing_column) + + +def test_gen_macd_color_no_color_change(): + # Test with a DataFrame where MACD values don't change, should return all black colors + no_change_df = pd.DataFrame({"MACDh": [0.5, 0.5, 0.5, 0.5, 0.5]}) + colors = gen_macd_color(no_change_df) + + # Check that the result is a list of all black colors + assert colors == ["#000000", "#000000", "#000000", "#000000", "#000000"] + + +def test_calculate_macd_valid_input(): + # Test with valid input + result = calculate_macd(macd_data, num_days_predate_stock_price=0) + + # Check that the result is a DataFrame + assert isinstance(result, pd.DataFrame) + + # Check the length of the result + assert len(result) == 10 + # not == len(macd_data) here, as we currently re-download data, weekends are not considered + + # Check that the required columns ('MACD', 'MACDh', 'MACDs') are present in the result + assert all(col in result.columns for col in ["MACD", "MACDh", "MACDs"]) + + +def test_calculate_macd_correct_values(): + # Test for correct values in 'MACD', 'MACDh', and 'MACDs' columns + longer_ema_window = 10 + shorter_ema_window = 7 + signal_ema_window = 4 + df = re_download_stock_data( + macd_data, stock_name="DIS", num_days_predate_stock_price=0 + ) + result = calculate_macd( + macd_data, + longer_ema_window=longer_ema_window, + shorter_ema_window=shorter_ema_window, + signal_ema_window=signal_ema_window, + num_days_predate_stock_price=0, + ) + + # Calculate expected values manually (using the provided df) + ema_short = ( + df["Close"] + .ewm(span=shorter_ema_window, adjust=False, min_periods=shorter_ema_window) + .mean() + ) + ema_long = ( + df["Close"] + .ewm(span=longer_ema_window, adjust=False, min_periods=longer_ema_window) + .mean() + ) + macd = ema_short - ema_long + macd.name = "MACD" + signal = macd.ewm( + span=signal_ema_window, adjust=False, min_periods=signal_ema_window + ).mean() + macd_h = macd - signal + + # Check that the calculated values match the values in the DataFrame + assert all(result["MACD"].dropna() == macd.dropna()) + assert all(result["MACDh"].dropna() == macd_h.dropna()) + assert all(result["MACDs"].dropna() == signal.dropna()) + + +def test_calculate_macd_custom_windows(): + # Test with custom EMA window values + result = calculate_macd( + macd_data, longer_ema_window=30, shorter_ema_window=15, signal_ema_window=10 + ) + + # Check that the result is a DataFrame + assert isinstance(result, pd.DataFrame) + + # Check that the required columns ('MACD', 'MACDh', 'MACDs') are present in the result + assert all(col in result.columns for col in ["MACD", "MACDh", "MACDs"]) + + +def test_calculate_macd_invalid_windows(): + # Test with invalid window values, should raise ValueError + with pytest.raises(ValueError): + calculate_macd( + macd_data, longer_ema_window=10, shorter_ema_window=20, signal_ema_window=15 + ) + with pytest.raises(ValueError): + plot_macd( + macd_data, longer_ema_window=10, shorter_ema_window=5, signal_ema_window=30 + ) + + +def test_plot_macd(): + axes0_ylabel_orig = "Price" + axes4_ylabel_orig = "Volume $10^{6}$" + # Create sample data for testing + x = np.sin(np.linspace(1, 10, 100)) + df = pd.DataFrame( + {"Close": x}, index=pd.date_range("2015-01-01", periods=100, freq="D") + ) + df.name = "DIS" + + # Call mpl_macd function + fig, axes = plot_macd(df) + + axes0_ylabel_plot = axes[0].get_ylabel() + axes4_ylabel_plot = axes[4].get_ylabel() + + # Check if the function returned valid figures and axes objects + assert isinstance(fig, plt.Figure) + assert isinstance(axes, list) + assert len(axes) == 6 # Assuming there are six subplots in the returned figure + assert axes0_ylabel_orig == axes0_ylabel_plot + assert axes4_ylabel_orig == axes4_ylabel_plot diff --git a/version b/version index 7ad36c6..37902e9 100644 --- a/version +++ b/version @@ -1,2 +1,2 @@ -version=0.7.0 -release=0.7.0 +version=0.8.0 +release=0.8.0