Source code for stocksimpy.core.stock_data

# src/stocksimpy/core/stock_data.py
from __future__ import annotations

from datetime import date, timedelta
from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union

import numpy as np
import pandas as pd


class StockData:
    """
    Container and validator for stock market time-series data.

    Loads, validates, and exports OHLCV data with a DatetimeIndex. Supports
    CSV, Excel, SQL, JSON, DataFrames, dictionaries, and yfinance. Every
    loader funnels through ``__init__``, which validates the required columns
    (Open, High, Low, Close, Volume) and the datetime index.

    Parameters
    ----------
    df : pandas.DataFrame, optional
        Stock data with a DatetimeIndex (or a 'Date' column) and OHLCV
        columns. If None, an empty instance is created. Default is None.

    Attributes
    ----------
    df : pandas.DataFrame
        Validated stock data with a DatetimeIndex. Columns are a two-level
        MultiIndex ``(field, ticker)``; single-ticker data uses ``""`` as the
        ticker level.

    Examples
    --------
    >>> data = StockData.from_yfinance(["AAPL"], days_before=365)  # doctest: +SKIP
    >>> data.df.head()  # doctest: +SKIP
    """

    def __init__(self, df: Optional[pd.DataFrame] = None) -> None:
        # No input means an empty, unvalidated container.
        if df is None:
            self.df = pd.DataFrame()
            return

        # Otherwise normalise the layout and validate before storing.
        self.df = self._process_and_validate(df)

    def _clean(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Clean and standardize the input DataFrame layout.

        Works on a copy; the caller's DataFrame is not modified.

        Parameters
        ----------
        df : pandas.DataFrame
            Raw input DataFrame.

        Returns
        -------
        pandas.DataFrame
            Copy with a sorted DatetimeIndex and two-level MultiIndex columns.

        Raises
        ------
        ValueError
            If the DataFrame has neither a 'Date' column nor a DatetimeIndex.
        """
        df = df.copy()

        # Data coming from yfinance (or pre-formatted input) already has a
        # DatetimeIndex; anything else must supply a 'Date' column.
        if not isinstance(df.index, pd.DatetimeIndex):
            if "Date" in df.columns:
                df["Date"] = pd.to_datetime(df["Date"])
                df.set_index("Date", inplace=True)
            else:
                raise ValueError(
                    "Input DataFrame must have a 'Date' column or a DatetimeIndex."
                )

        # Promote single-level columns to a two-level MultiIndex whose second
        # level is "". This keeps the API consistent for code that expects
        # tuples like ('Close', symbol).
        if not isinstance(df.columns, pd.MultiIndex):
            df.columns = pd.MultiIndex.from_tuples([(c, "") for c in df.columns])

        df.sort_index(inplace=True)
        return df

    def _validate(self, df: pd.DataFrame) -> None:
        """
        Validate DataFrame structure and data integrity.

        Parameters
        ----------
        df : pandas.DataFrame
            DataFrame to validate (expected to be the output of ``_clean``).

        Raises
        ------
        ValueError
            If the DataFrame is empty, the index is unsorted or has
            duplicates, a required column is missing, volume is negative, or
            Open/Close fall outside the Low/High range.
        TypeError
            If the index is not a DatetimeIndex or an OHLCV column is not
            numeric.
        """
        if df.empty:
            raise ValueError("DataFrame cannot be empty.")

        if not isinstance(df.index, pd.DatetimeIndex):
            raise TypeError("DataFrame index must be a DatetimeIndex.")

        # The index must be sorted and unique.
        if not df.index.is_monotonic_increasing:
            raise ValueError("DataFrame index is not sorted monotonically.")
        if df.index.has_duplicates:
            raise ValueError("DataFrame index contains duplicate dates.")

        required_cols = ["Open", "Close", "Volume", "High", "Low"]

        # First pass: the required fields must exist at the first column level.
        first_level_cols = list(df.columns.get_level_values(0).unique())
        missing = [col for col in required_cols if col not in first_level_cols]
        if missing:
            raise ValueError(f"Missing required columns: {missing}")

        # Second pass: validate each ticker separately. The "" ticker
        # represents single-ticker DataFrames.
        tickers = list(df.columns.get_level_values(1).unique())
        for ticker in tickers:
            fields = {}
            for col in required_cols:
                try:
                    series = df[(col, ticker)]
                except KeyError as err:
                    raise ValueError(
                        f"Missing required column '{col}' for ticker '{ticker}'"
                    ) from err

                if not pd.api.types.is_numeric_dtype(series.dtype):
                    raise TypeError(
                        f"Column '{col}' must be a numerical type. Found: {series.dtype}"
                    )
                fields[col] = series

            open_, close = fields["Open"], fields["Close"]
            high, low = fields["High"], fields["Low"]

            if (fields["Volume"] < 0).any():
                raise ValueError("Volume data contains negative values.")

            # Basic sanity: Low <= Open/Close <= High.
            if (low > high).any():
                raise ValueError(
                    "OHLC data inconsistency: Low price is greater than High price."
                )
            if (open_ > high).any() or (open_ < low).any():
                raise ValueError(
                    "OHLC data inconsistency: Open price is outside the High/Low range."
                )
            if (close > high).any() or (close < low).any():
                raise ValueError(
                    "OHLC data inconsistency: Close price is outside the High/Low range."
                )

    def _process_and_validate(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Clean and validate the input DataFrame.

        Parameters
        ----------
        df : pandas.DataFrame
            Raw input DataFrame.

        Returns
        -------
        pandas.DataFrame
            Cleaned and validated DataFrame.
        """
        df_clean = self._clean(df)
        self._validate(df_clean)
        return df_clean

    def _export_frame(self) -> pd.DataFrame:
        """Return a copy of ``df`` with the index named 'Date' and flat string columns."""
        frame = self.df.copy()
        frame.index.name = "Date"
        # ("Close", "") -> "Close"; ("Close", "AAPL") -> "Close_AAPL".
        if isinstance(frame.columns, pd.MultiIndex):
            frame.columns = [
                col[0] if col[1] == "" else f"{col[0]}_{col[1]}"
                for col in frame.columns
            ]
        return frame

    # ------------------------------------------
    # LOAD DATA

    @classmethod
    def generate_mock_data(cls, days: int = 100, seed: int = 42) -> "StockData":
        """
        Generate synthetic OHLCV data for testing.

        Uses a private legacy ``RandomState`` so the process-wide NumPy random
        state is left untouched.

        Parameters
        ----------
        days : int, optional
            Number of days of data to generate. Default is 100.
        seed : int, optional
            Random seed for reproducibility. Default is 42.

        Returns
        -------
        StockData
            Instance containing generated mock data ending today.

        Examples
        --------
        >>> data = StockData.generate_mock_data(50, seed=123)
        >>> len(data.df)
        50
        """
        rng = np.random.RandomState(seed)
        dates = pd.date_range(end=pd.Timestamp.today(), periods=days)

        # Random-walk base price around 100.
        base_prices = np.cumsum(rng.randn(days)) + 100

        # Symmetric High/Low band around the base price.
        ranges = rng.rand(days) * 2
        high = base_prices + ranges
        low = base_prices - ranges

        # Open and Close are drawn inside the High-Low band so validation passes.
        daily_range = high - low
        open_ = low + daily_range * rng.rand(days)
        close = low + daily_range * rng.rand(days)

        volume = rng.randint(1000, 10000, size=days)

        df = pd.DataFrame(
            {
                "Date": dates,
                "Open": open_,
                "High": high,
                "Low": low,
                "Close": close,
                "Volume": volume,
            }
        )
        return cls(df)

    @classmethod
    def from_csv(cls, file_path: str) -> "StockData":
        """
        Load stock data from a CSV file.

        The first CSV column is read as the index. If it is named 'Date' or
        'Index' it is parsed as the datetime index; otherwise a 'Date' column
        is used. Timezone information is removed if present.

        Parameters
        ----------
        file_path : str
            Path to CSV file.

        Returns
        -------
        StockData
            Instance with validated CSV data.

        Raises
        ------
        FileNotFoundError
            If file does not exist.
        ValueError
            If no usable date index/column is found or required OHLCV columns
            are missing.
        """
        df = pd.read_csv(file_path, index_col=0)

        if df.index.name == "Date" or df.index.name == "Index":
            df.index.name = "Date"
            df.index = pd.to_datetime(df.index)
            if df.index.tz is not None:
                df.index = df.index.tz_localize(None)
        elif "Date" in df.columns:
            # Convert to datetime and strip any timezone.
            df["Date"] = pd.to_datetime(df["Date"])
            if hasattr(df["Date"].dtype, "tz") and df["Date"].dtype.tz is not None:
                df["Date"] = df["Date"].dt.tz_localize(None)
            df.set_index("Date", inplace=True)

        return cls(df)

    @classmethod
    def from_excel(cls, file_path: str) -> "StockData":
        """
        Load stock data from an Excel file.

        Parameters
        ----------
        file_path : str
            Path to Excel file.

        Returns
        -------
        StockData
            Instance with loaded Excel data.
        """
        return cls(pd.read_excel(file_path))

    @classmethod
    def from_sql(cls, query: str, connection: Any) -> "StockData":
        """
        Load stock data from a SQL database.

        Parameters
        ----------
        query : str
            SQL SELECT query.
        connection
            Open database connection. It is not closed by this method.

        Returns
        -------
        StockData
            Instance with query result data.
        """
        return cls(pd.read_sql(query, connection))

    @classmethod
    def from_yfinance(
        cls,
        tickers: str | list[str],
        start_date: Optional[date] = None,
        end_date: Optional[date] = None,
        days_before: Optional[int] = None,
    ) -> "StockData":
        """
        Load stock data from Yahoo Finance.

        Parameters
        ----------
        tickers : str or list of str
            Ticker symbol(s), e.g. ``['AAPL', 'MSFT']``.
        start_date : date, optional
            Start date passed to ``yfinance.download``. Ignored if
            days_before is set. Default is None.
        end_date : date, optional
            End date passed to ``yfinance.download``. Ignored if days_before
            is set. Default is None.
        days_before : int, optional
            Number of days before today to retrieve. A non-zero value
            overrides start_date and end_date. Default is None.

        Returns
        -------
        StockData
            Instance with OHLCV data and a DatetimeIndex.

        Raises
        ------
        ImportError
            If the yfinance package is not installed.

        Notes
        -----
        Requires an internet connection. Prices are requested with
        ``auto_adjust=True``.

        Examples
        --------
        >>> data = StockData.from_yfinance(["AAPL"], days_before=365)  # doctest: +SKIP
        >>> data.df.head()  # doctest: +SKIP
        """
        # The import must live inside the try block, otherwise a missing
        # package surfaces as a bare ModuleNotFoundError without the hint.
        try:
            import yfinance as yf
        except ImportError as err:
            raise ImportError(
                "yfinance is not installed. Install it to use Yahoo Finance "
                "loaders. try using: `pip install yfinance`"
            ) from err

        if days_before:
            today = date.today()
            start_date = today - timedelta(days=days_before)
            end_date = today

        data = yf.download(tickers, start=start_date, end=end_date, auto_adjust=True)

        # Round-trip the index through a 'Date' column so it is guaranteed to
        # be datetime-typed and named 'Date'.
        data.reset_index(inplace=True)
        if "Date" in data.columns:
            data["Date"] = pd.to_datetime(data["Date"])
            data.set_index("Date", inplace=True)

        return cls(data)

    @classmethod
    def from_dataframe(cls, df: pd.DataFrame) -> "StockData":
        """
        Create StockData from an existing DataFrame.

        Parameters
        ----------
        df : pandas.DataFrame
            DataFrame with OHLCV data.

        Returns
        -------
        StockData
            Instance initialized with the DataFrame.
        """
        return cls(df)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "StockData":
        """
        Create StockData from a dictionary.

        Parameters
        ----------
        data : dict
            Column names as keys, value lists as data.

        Returns
        -------
        StockData
            Instance initialized from the dictionary.
        """
        return cls(pd.DataFrame(data))

    @classmethod
    def from_json(cls, json_data: dict[str, Any]) -> "StockData":
        """
        Create StockData from an already-parsed JSON object.

        Parameters
        ----------
        json_data : dict
            Parsed JSON as a dictionary (not a JSON string).

        Returns
        -------
        StockData
            Instance initialized from the JSON data.
        """
        return cls(pd.DataFrame(json_data))

    @classmethod
    def from_sqlite(cls, query: str, db_path: str) -> "StockData":
        """
        Load stock data from a SQLite database.

        Parameters
        ----------
        query : str
            SQL SELECT query.
        db_path : str
            Path to SQLite database file.

        Returns
        -------
        StockData
            Instance with query result data.

        Examples
        --------
        >>> data = StockData.from_sqlite("SELECT * FROM stocks", "data.db")  # doctest: +SKIP
        """
        import sqlite3

        conn = sqlite3.connect(db_path)
        try:
            df = pd.read_sql(query, conn)
        finally:
            # Close even when the query fails so the file handle is released.
            conn.close()
        return cls(df)

    @staticmethod
    def auto_loader(
        source: Union[pd.DataFrame, dict[str, Any], str, tuple[Any, ...]],
        **kwargs: Any,
    ) -> "StockData":
        """
        Auto-detect the input type and load data.

        Supports CSV, Excel, SQLite, DataFrame, dict, and yfinance (via a
        tuple specification).

        Parameters
        ----------
        source : str, dict, tuple, or pandas.DataFrame
            Input source. File paths (str) are dispatched on their extension
            (``.csv``, ``.xls``/``.xlsx``, ``.db``). Tuples trigger yfinance
            loading: ``(ticker, days_before)`` or
            ``(ticker, start_date, end_date)``.
        **kwargs
            Only ``query`` is used, for ``.db`` sources; it defaults to
            ``"SELECT * FROM stock_data"``.

        Returns
        -------
        StockData
            Instance with loaded and validated data.

        Raises
        ------
        ValueError
            If the file extension is unsupported or the tuple format is
            invalid.
        TypeError
            If the source type is not recognized.

        Examples
        --------
        >>> data = StockData.auto_loader("prices.csv")  # doctest: +SKIP
        >>> data = StockData.auto_loader({"Open": [...], "Close": [...]})  # doctest: +SKIP
        """
        import os

        if isinstance(source, pd.DataFrame):
            return StockData.from_dataframe(source)

        # Dict (possibly parsed JSON data).
        if isinstance(source, dict):
            return StockData.from_dict(source)

        # File path: dispatch on the lower-cased extension.
        if isinstance(source, str):
            ext = os.path.splitext(source)[-1].lower()
            if ext == ".csv":
                return StockData.from_csv(source)
            if ext in [".xls", ".xlsx"]:
                return StockData.from_excel(source)
            if ext == ".db":
                query = kwargs.get("query", "SELECT * FROM stock_data")
                return StockData.from_sqlite(query, source)
            raise ValueError(f"Unsupported file extension: {ext}")

        # Tuple -> yfinance loader.
        if isinstance(source, tuple):
            if len(source) == 2 and isinstance(source[1], int):
                ticker, days_before = source
                return StockData.from_yfinance(tickers=ticker, days_before=days_before)
            if (
                len(source) == 3
                and isinstance(source[1], date)
                and isinstance(source[2], date)
            ):
                ticker, start_date, end_date = source
                return StockData.from_yfinance(
                    tickers=ticker, start_date=start_date, end_date=end_date
                )
            raise ValueError(
                "Tuple input must be (ticker, days_before) or (ticker, start_date, end_date)"
            )

        raise TypeError(f"Unsupported source type: {type(source)}")

    # -----------------------
    # BASIC INFO AND FUNCTIONALITIES

    def get(self, column: str) -> Union[pd.Series, pd.DataFrame]:
        """
        Get a column from the DataFrame.

        Parameters
        ----------
        column : str
            Column name (first level of the column MultiIndex).

        Returns
        -------
        pandas.Series or pandas.DataFrame
            Result of ``self.df[column]``. With multi-ticker data this is a
            DataFrame holding one column per ticker.
        """
        return self.df[column]

    def slice(self, start: Any = None, end: Any = None) -> pd.DataFrame:
        """
        Slice the DataFrame by date range.

        Parameters
        ----------
        start : str or datetime, optional
            Start date (inclusive).
        end : str or datetime, optional
            End date (inclusive).

        Returns
        -------
        pandas.DataFrame
            Rows selected with ``self.df.loc[start:end]``.
        """
        return self.df.loc[start:end]

    def head(self, n: int = 5) -> pd.DataFrame:
        """
        Return the first n rows.

        Parameters
        ----------
        n : int, optional
            Number of rows. Default is 5.

        Returns
        -------
        pandas.DataFrame
            First n rows.
        """
        return self.df.head(n)

    def info(self) -> None:
        """
        Print DataFrame information via ``DataFrame.info``.

        Returns
        -------
        None
        """
        self.df.info()

    def fill_missing(self, method: Literal["ffill", "bfill"] = "ffill") -> "StockData":
        """
        Fill missing values in the DataFrame in place.

        Parameters
        ----------
        method : str, optional
            Fill method: 'ffill' (forward fill) or 'bfill' (backward fill).
            Default is 'ffill'.

        Returns
        -------
        StockData
            Self, for method chaining.

        Raises
        ------
        ValueError
            If method is neither 'ffill' nor 'bfill'.
        """
        if method == "ffill":
            self.df.ffill(inplace=True)
        elif method == "bfill":
            self.df.bfill(inplace=True)
        else:
            # f-string so a non-string argument still yields ValueError rather
            # than a TypeError from string concatenation.
            raise ValueError(
                f'{method} is not a fill_missing method option, use either "ffill" or "bfill"'
            )
        return self

    def check_missing(self) -> pd.Series:
        """
        Count missing values per column.

        Returns
        -------
        pandas.Series
            Missing value count per column.
        """
        return self.df.isnull().sum()

    def add_indicator(
        self,
        indicator: str | Callable[..., dict[str, pd.Series]],
        *args: Any,
        base_col: str = "Close",
        symbol: str = "",
        overwrite: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Add one or more technical indicators to the data.

        The indicator function is applied to ``(base_col, ticker)`` for each
        selected ticker and must return a dict mapping indicator names to
        pandas Series. Each entry becomes a new ``(name, ticker)`` column.

        Parameters
        ----------
        indicator : str or callable
            Either the name of a built-in indicator (looked up,
            lower-cased, in ``Indicators.get_name_func()``) or a custom
            function that accepts a pandas Series and returns
            ``dict[str, pandas.Series]``.
        *args, **kwargs
            Additional arguments passed directly to the indicator function.
        base_col : str, default "Close"
            Base data column to apply the indicator to. Keyword-only.
        symbol : str, default ""
            Ticker (second column level) to apply the indicator to. Leave
            empty ("") to apply to every ticker. Keyword-only.
        overwrite : bool, default False
            If True, existing indicator columns with the same key are
            replaced; if False they are left untouched. Keyword-only.

        Returns
        -------
        None
            The instance is mutated in place.

        Raises
        ------
        KeyError
            If the base column/ticker pair does not exist, or the built-in
            indicator name is unknown.
        TypeError
            If ``indicator`` is neither str nor callable, or the indicator
            function does not return a dict.

        Examples
        --------
        >>> def calculate_sma(series, window):
        ...     return {f"sma_{window}": series.rolling(window).mean()}
        >>> data.add_indicator(calculate_sma, 20, symbol="AAPL")  # doctest: +SKIP

        >>> def calculate_bands(series, window):
        ...     sma = series.rolling(window).mean()
        ...     std = series.rolling(window).std()
        ...     return {
        ...         f"bb_upper_{window}": sma + 2 * std,
        ...         f"bb_lower_{window}": sma - 2 * std,
        ...     }
        >>> data.add_indicator(calculate_bands, 20, symbol="MSFT")  # doctest: +SKIP
        """
        if symbol != "" and (base_col, symbol) not in self.df.columns:
            raise KeyError(
                f"The key ({base_col}, {symbol}) does NOT exist in the dataframe"
            )

        if isinstance(indicator, str):
            indicator = indicator.lower()
            from stocksimpy.addons.indicators import Indicators

            try:
                indicator_func = Indicators.get_name_func()[indicator]
            except KeyError:
                raise KeyError(f"Unknown built-in indicator '{indicator}'") from None
        elif callable(indicator):
            indicator_func = indicator
        else:
            raise TypeError(
                f"``indicator`` type can only be str or callable, but argument is of type {type(indicator)}"
            )

        if symbol == "":
            symbol_it = self.df.columns.get_level_values(1).unique()
        else:
            symbol_it = [symbol]

        for ticker in symbol_it:
            series = self.df[(base_col, ticker)]
            result = indicator_func(series, *args, **kwargs)

            if not isinstance(result, dict):
                raise TypeError("Indicator function must return dict[str, pd.Series]")

            new_cols = {}
            replaced = []
            for name, values in result.items():
                col_key = (name, ticker)
                if col_key in self.df.columns:
                    if not overwrite:
                        continue
                    replaced.append(col_key)
                new_cols[col_key] = values

            # Nothing to add: skip the concat, which could otherwise degrade
            # the MultiIndex columns when given an empty frame.
            if not new_cols:
                continue

            # Drop the old versions first so overwrite replaces rather than
            # creating duplicate column keys.
            if replaced:
                self.df = self.df.loc[:, ~self.df.columns.isin(replaced)]

            # One concat per ticker avoids fragmenting the frame with
            # repeated single-column inserts.
            new_df = pd.DataFrame(new_cols, index=self.df.index)
            self.df = pd.concat([self.df, new_df], axis=1)

    def add_indicator_all(
        self, symbol: str = "", base_col: str = "Close", overwrite: bool = False
    ) -> None:
        """
        Add all built-in technical indicators for one or more symbols.

        Iterates over every name in ``Indicators.get_name_func()`` and calls
        ``add_indicator`` for each selected ticker.

        Parameters
        ----------
        symbol : str, optional
            Ticker (second column level). If empty (""), applies to all
            tickers.
        base_col : str, optional
            Base data column to apply the indicators to. Default is "Close".
        overwrite : bool, optional
            Whether to overwrite existing indicator columns. Default is
            False.

        Returns
        -------
        None
            The instance is mutated in place.

        Examples
        --------
        >>> data.add_indicator_all()  # doctest: +SKIP
        >>> data.add_indicator_all(symbol="AAPL", base_col="Close", overwrite=True)  # doctest: +SKIP
        """
        from stocksimpy.addons.indicators import Indicators

        if symbol == "":
            tickers = self.df.columns.get_level_values(1).unique()
        else:
            tickers = [symbol]

        # Snapshot the indicator names once rather than per ticker.
        indicator_names = list(Indicators.get_name_func())

        for ticker in tickers:
            for indicator in indicator_names:
                self.add_indicator(
                    indicator, base_col=base_col, symbol=ticker, overwrite=overwrite
                )

    # --------------------------
    # EXPORT DATA

    def to_csv(self, file_path: str, **kwargs: Any) -> str:
        """
        Export the DataFrame to a CSV file.

        The index is written as a 'Date' column and MultiIndex columns are
        flattened (``Close`` or ``Close_AAPL``).

        Parameters
        ----------
        file_path : str
            Output file path.
        **kwargs
            Additional arguments passed to ``DataFrame.to_csv``.

        Returns
        -------
        str
            Path to the exported file.
        """
        self._export_frame().to_csv(file_path, **kwargs)
        return file_path

    def to_excel(self, file_path: str, **kwargs: Any) -> str:
        """
        Export the DataFrame to an Excel file.

        Parameters
        ----------
        file_path : str
            Output file path (.xls or .xlsx).
        **kwargs
            Additional arguments passed to ``DataFrame.to_excel``.

        Returns
        -------
        str
            Path to the exported file.
        """
        self.df.to_excel(file_path, **kwargs)
        return file_path

    def to_sql(
        self,
        table_name: str,
        connection: Any,
        if_exists: str = "replace",
        **kwargs: Any,
    ) -> str:
        """
        Export the DataFrame to a SQL table.

        The index is written as a 'Date' column and MultiIndex columns are
        flattened (``Close`` or ``Close_AAPL``) so single-ticker tables can be
        loaded back with ``from_sql``.

        Parameters
        ----------
        table_name : str
            Target table name.
        connection : sqlalchemy.engine.Connection or sqlite3.Connection
            Active database connection. It is not closed by this method.
        if_exists : str, optional
            Behavior if the table exists: 'fail', 'replace', 'append'.
            Default is 'replace'.
        **kwargs
            Additional arguments passed to ``DataFrame.to_sql``.

        Returns
        -------
        str
            Table name in the database.
        """
        self._export_frame().to_sql(
            table_name, connection, if_exists=if_exists, index=True, **kwargs
        )
        return table_name

    def to_sqlite(
        self, table_name: str, db_path: str, if_exists: str = "replace", **kwargs: Any
    ) -> str:
        """
        Export the DataFrame to a SQLite database file.

        Uses the same flattened layout as ``to_sql``.

        Parameters
        ----------
        table_name : str
            Target table name in the database.
        db_path : str
            Path to SQLite database file.
        if_exists : str, optional
            Behavior if the table exists: 'fail', 'replace', 'append'.
            Default is 'replace'.
        **kwargs
            Additional arguments passed to ``DataFrame.to_sql``.

        Returns
        -------
        str
            Path to the database file.
        """
        import sqlite3

        conn = sqlite3.connect(db_path)
        try:
            self._export_frame().to_sql(
                table_name, conn, if_exists=if_exists, index=True, **kwargs
            )
        finally:
            # Close even when the write fails so the file handle is released.
            conn.close()
        return db_path

    def to_dataframe(self) -> pd.DataFrame:
        """
        Return a copy of the underlying DataFrame.

        Returns
        -------
        pandas.DataFrame
            Copy of ``self.df``.
        """
        return self.df.copy()

    def to_dict(
        self, orient: str = "records"
    ) -> Union[dict[Any, Any], list[dict[Any, Any]]]:
        """
        Export the DataFrame via ``DataFrame.to_dict``.

        Parameters
        ----------
        orient : str, optional
            Dictionary orientation. Default is 'records'.

        Returns
        -------
        dict or list of dict
            A list of row dicts for ``orient='records'``; a dict for the
            other orientations.
        """
        return self.df.to_dict(orient=orient)

    def to_json(
        self,
        file_path: Optional[str] = None,
        orient: str = "records",
        **kwargs: Any,
    ) -> str:
        """
        Export the DataFrame to JSON.

        Parameters
        ----------
        file_path : str, optional
            Output file path. If given, the JSON is also written there as
            UTF-8. Default is None.
        orient : str, optional
            JSON orientation ('records', 'columns', etc.). Default is
            'records'.
        **kwargs
            Additional arguments passed to ``DataFrame.to_json``.

        Returns
        -------
        str
            JSON string (returned whether or not a file is written).

        Raises
        ------
        OSError
            If the file cannot be written.
        """
        json_str = self.df.to_json(orient=orient, **kwargs)
        if file_path:
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(json_str)
        return json_str

    def to_custom(
        self, export_func: Callable[..., Any], *args: Any, **kwargs: Any
    ) -> Any:
        """
        Export using a custom function.

        Parameters
        ----------
        export_func : callable
            Function taking the DataFrame as its first argument. It receives
            ``self.df`` itself, not a copy.
        *args
            Positional arguments passed to export_func.
        **kwargs
            Keyword arguments passed to export_func.

        Returns
        -------
        Any
            Whatever ``export_func(self.df, *args, **kwargs)`` returns.

        Examples
        --------
        >>> def save_parquet(df, path):
        ...     df.to_parquet(path)
        ...     return path
        >>> data.to_custom(save_parquet, "output.parquet")  # doctest: +SKIP
        """
        return export_func(self.df, *args, **kwargs)