import datetime as dt
from dataclasses import dataclass
from typing import Union, Optional, Tuple, Dict
import numpy as np
import pandas as pd
from numpy.typing import NDArray
from numba.typed import List as NumbaList
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed
import sys
from finmlkit.bar.utils import footprint_to_dataframe
from finmlkit.utils.log import get_logger
from .utils import comp_trade_side_vector, merge_split_trades
import os
# Module-level logger. (A stray Sphinx "[docs]" marker that followed this line
# was removed: evaluated as Python it raises NameError at import time.)
logger = get_logger(__name__)
class TradesData:
"""
Class to preprocess trades data for bar building.
This class handles standardization of column names, timestamp conversion,
trade merging, and side inference for consistent processing across different
data sources.
"""
def __init__(self,
ts: NDArray, px: NDArray, qty: NDArray, id: NDArray = None, *,
is_buyer_maker: NDArray = None,
side = None,
dt_index: Optional[pd.DatetimeIndex] = None,
timestamp_unit: Optional[str] = None,
preprocess: bool = False,
proc_res: Optional[str] = None, name= None):
"""
Initialize the TradesData with raw trades data.
:param ts: array of timestamps
:param px: array of prices
:param qty: array of quantity or amount of trades
:param id: array of trades id
:param is_buyer_maker: Optional Array of side info: True if buyer maker, False otherwise. If None side information will be inferred from data.
:param side: Optional Array Market order side information (-1: sell, 1: buy) [needed when loading from HDF5 store].
:param dt_index: Optional DatetimeIndex for the trades data. If provided, it will be used as the index. [needed when loading from HDF5 store].
:param timestamp_unit: (Optional) timestamp unit (e.g., 'ms', 'us', 'ns'); inferred if None.
:param proc_res: (Optional) processing resolution for timestamps (e.g., 'ms' cuts us to ms resolution).
:param preprocess: If True, runs the preprocessing pipeline (sorting, merging split trades etc...)
:param name: Optional name for the trades data instance (logging purposes).
:raises ValueError: If required columns are missing or timestamp format is invalid.
"""
if not isinstance(ts, np.ndarray):
raise TypeError("ts must be a np.ndarray")
if not isinstance(px, np.ndarray):
raise TypeError("px must be a np.ndarray")
if not isinstance(qty, np.ndarray):
raise TypeError("qty must be a np.ndarray")
if id is not None and not isinstance(id, np.ndarray):
raise TypeError("id must be a np.ndarray")
if is_buyer_maker is not None and not isinstance(is_buyer_maker, np.ndarray):
raise TypeError("is_buyer_maker must be None or np.ndarray")
if side is not None and not isinstance(side, np.ndarray):
raise TypeError("side must be None or np.ndarray")
self._start_date = self._end_date = None
self._data = pd.DataFrame({'timestamp': ts, 'price': px, 'amount': qty, 'id': id})
self.is_buyer_maker = is_buyer_maker
if side is not None:
self._data['side'] = side
self._orig_timestamp_unit = timestamp_unit if timestamp_unit else self._infer_timestamp_unit()
self.name = name
# Process the trades data
self.missing_pct = 0
self.data_ok = None
self.discontinuities = [] # List to store discontinuity information
if preprocess:
if id is None:
raise ValueError("id is required if preprocess is True")
self._convert_timestamps_to_ns()
self._sort_trades()
self._merge_trades()
self._apply_timestamp_resolution(proc_res)
if "side" not in self._data.columns:
# If side info is not provided, infer it from trades data
self._add_trade_side_info()
# Add datetime_idx
if dt_index is not None:
self._data.set_index(dt_index, inplace=True)
else:
self._data.set_index(pd.to_datetime(self._data['timestamp'], unit='ns'), inplace=True)
self._data.index.name = "datetime"
logger.info("TradesData prepared successfully.")
@property
def start_date(self):
"""
Get the start date of the trades data.
:return: Start date as a pandas Timestamp.
"""
return self._start_date
@property
def end_date(self):
"""
Get the end date of the trades data.
:return: End date as a pandas Timestamp.
"""
return self._end_date
[docs]
def set_view_range(self, start: pd.Timestamp|str, end: pd.Timestamp|str):
"""
Set the view range for the trades data.
:param start: Start timestamp for the view range.
:param end: End timestamp for the view range.
:return: None
"""
if isinstance(start, str):
start = pd.Timestamp(start)
if isinstance(end, str):
end = pd.Timestamp(end)
if start >= end:
raise ValueError("Start timestamp must be before end timestamp.")
self._start_date = start
self._end_date = end
logger.info(f"View range set to {start} - {end}.")
@property
def data(self) -> pd.DataFrame:
"""
Get the processed trades data as a DataFrame corresponding to the active view range.
:return: DataFrame containing trades data.
"""
if self._start_date is None and self._end_date is None:
return self._data
return self._data.loc[self._start_date: self._end_date]
@property
def orig_timestamp_unit(self) -> str:
"""
Get the timestamp unit used for processing.
:return: Timestamp unit string.
"""
return self._orig_timestamp_unit
def _validate_data(self):
"""
Check for gaps in trade IDs
"""
# First convert to numeric to handle potential string IDs
id_diffs = np.diff(self.data['id'].values)
gap_indices = np.where(id_diffs > 1)[0]
cum_gap_size = 0
if len(gap_indices) > 0:
logger.warning(f"{self.name} | Found {len(gap_indices):,} discontinuities in trade IDs. "
f"This indicates missing trades.")
# Record detailed information about each discontinuity
n_large_gaps = 0
for idx in gap_indices:
gap_start_id = int(self.data['id'].iloc[idx])
gap_end_id = int(self.data['id'].iloc[idx + 1])
gap_size = gap_end_id - gap_start_id - 1
cum_gap_size += gap_size
# Get timestamps for the trades before and after the gap
pre_gap_time = pd.to_datetime(self.data['timestamp'].iloc[idx], unit='ns')
post_gap_time = pd.to_datetime(self.data['timestamp'].iloc[idx + 1], unit='ns')
time_diff = post_gap_time - pre_gap_time
# Record the discontinuity if gap is greater than 1 min
if time_diff > pd.Timedelta(minutes=1):
self.data_ok = False
n_large_gaps += 1
self.discontinuities.append({
'start_id': gap_start_id,
'end_id': gap_end_id,
'missing_ids': gap_size,
'pre_gap_time': pre_gap_time,
'post_gap_time': post_gap_time,
'time_interval': time_diff
})
if n_large_gaps > 0:
logger.warning(f"{self.name} | Found {n_large_gaps} large gaps greater than 1 minute.")
logger.info(
f"Recorded {len(self.discontinuities)} trade ID discontinuities with corresponding time intervals.")
self.missing_pct = cum_gap_size / len(self.data) * 100
def _sort_trades(self) -> None:
"""
Sort trades by timestamp to ensure correct order for processing.
Also performs data integrity checks by identifying discontinuities in trade IDs.
"""
self.data_ok = True
self.discontinuities = [] # Reset discontinuities list
# Sort by ID to inspect data integrity
self.data.sort_values(by=['id'], inplace=True)
# Reset index
self.data.reset_index(drop=True, inplace=True)
# Check duplicates in trade IDs
if self.data['id'].duplicated().any():
logger.warning(f"{self.name} | Trade IDs contain duplicates. This may indicate data corruption.")
# Drop duplicates while keeping the first occurrence
self.data.drop_duplicates(subset='id', keep='first', inplace=True)
logger.info("Duplicates in trade IDs have been removed.")
self.data_ok = False
self._validate_data()
# Now sort by timestamp for chronological order if needed
if not self.data.timestamp.is_monotonic_increasing:
logger.warning(f"{self.name} | Trades timestamps are not monotonic increasing after sorting by trade IDs. "
f"Sorting by timestamp for chronological order...")
self.data.sort_values(by=['timestamp', 'id'], inplace=True)
# Reset index
self.data.reset_index(drop=True, inplace=True)
def _merge_trades(self):
"""
Merge trades that occur at the same timestamp and price level.
"""
logger.info('Merging split trades (same timestamps) on same price level...')
ts, px, am, side = merge_split_trades(
self.data['timestamp'].values.astype(np.int64),
self.data['price'].values.astype(np.float64),
self.data['amount'].values.astype(np.float32),
self.is_buyer_maker,
)
self._data = pd.DataFrame({
'timestamp': ts,
'price': px,
'amount': am
})
if self.is_buyer_maker is not None:
self._data['side'] = side
def _convert_timestamps_to_ns(self):
"""
Convert timestamps to nanosecond representation.
:raises ValueError: If timestamp format is invalid.
"""
# Infer or validate timestamp unit
valid_units = ['s', 'ms', 'us', 'ns']
if self.orig_timestamp_unit not in valid_units:
raise ValueError(f"Invalid timestamp format! Must be one of: {', '.join(valid_units)}")
unit_scale_factor = {
's': 1_000_000_000,
'ms': 1_000_000,
'us': 1_000,
'ns': 1
}
# Convert timestamp to nanoseconds
logger.info('Converting timestamp to nanoseconds units for processing...')
# trades.timestamp = pd.to_datetime(trades.timestamp, unit=timestamp_unit).astype(np.int64).values
# Work directly on the underlying NumPy array for better performance
factor = unit_scale_factor[self.orig_timestamp_unit]
self._data['timestamp'].values[:] = np.multiply(self.data['timestamp'].values, factor, dtype=np.int64)
def _apply_timestamp_resolution(self, proc_res: Optional[str]) -> None:
"""
Apply processing resolution to timestamps if specified.
:param proc_res: Target processing resolution for timestamps.
:raises ValueError: If processing resolution is invalid.
"""
if proc_res and proc_res != self.orig_timestamp_unit:
logger.info(f"Processing resolution: {proc_res} -> converting timestamps...")
# Convert proc_res to nanoseconds scale factor
scale_factors = {'s': 1_000_000_000, 'ms': 1_000_000, 'us': 1_000, 'ns': 1}
if proc_res not in scale_factors:
raise ValueError(
f"Invalid processing resolution: {proc_res}. Must be one of: {', '.join(scale_factors.keys())}")
# Round timestamps to the specified resolution
resolution_ns = scale_factors[proc_res]
self.data.timestamp = (self.data.timestamp // resolution_ns) * resolution_ns
def _add_trade_side_info(self) -> None:
"""
Extract trade side information from the trades data.
:returns: None - modifies the trades DataFrame in place to include a 'side' column.
"""
logger.info("No trade side information found. Inferring trade side from price movements.")
self._data['side'] = comp_trade_side_vector(self.data['price'].values)
def _infer_timestamp_unit(self) -> str:
"""
Infer the unit of timestamps in the trades data if not explicitly provided.
:return: Inferred or provided timestamp unit.
"""
max_ts = self.data['timestamp'].max()
if max_ts > 1e18: # Likely in nanoseconds
timestamp_unit = 'ns'
elif max_ts > 1e15: # Likely in microseconds
timestamp_unit = 'us'
elif max_ts > 1e12: # Likely in milliseconds
timestamp_unit = 'ms'
else: # Likely in seconds
timestamp_unit = 's'
logger.warning("Timestamp unit is set to seconds. Please verify the data.")
logger.info(f"Inferred timestamp format: {timestamp_unit}")
return timestamp_unit
[docs]
def save_h5(
self,
filepath: str,
*,
month_key: Optional[str] = None,
complib: str = "blosc:lz4",
complevel: int = 1,
mode: str = "a",
chunksize: int = 1_000_000,
overwrite_month: bool = True,
) -> str:
"""
Persist the raw trades to an on-disk HDF5 store.
The data of **each calendar month** lives under ``/trades/YYYY-MM`` in the file.
- When adding new monthly data, it will be stored in a new group.
- When adding data for an existing month, you can either append to it or overwrite it with confirmation.
:param filepath: Destination `.h5` file. The parent directories are created automatically when missing.
:param month_key: Override the key of the form ``"YYYY-MM"``. When ``None`` the key is derived from the first timestamp of ``self.data``.
:param complib: Compression backend used by PyTables. Default is ``blosc:zstd``.
:param complevel: Compression level. Default is 5.
:param mode: File mode – ``"a"`` to create or append, ``"w"`` to start fresh. Default is ``"a"``.
:param chunksize: Row chunk size used by PyTables when writing large frames. Default is 1000000.
:param overwrite_month: If True and the month data exists, prompts for confirmation to overwrite. Default is True.
:returns: The full key used inside the store, e.g. ``"/trades/2025-02"``.
:raises: ValueError if user declines to overwrite existing data.
"""
# ------------------------------------------------------------------
# Derive the monthly key and ensure output path exists
# ------------------------------------------------------------------
if month_key is None:
first_dt = pd.to_datetime(self.data["timestamp"].iloc[0], unit="ns")
month_key = f"{first_dt.year:04d}-{first_dt.month:02d}"
h5_key = f"/trades/{month_key}"
meta_key = f"/meta/{month_key}"
integrity_key = f"/integrity/{month_key}"
os.makedirs(os.path.dirname(os.path.abspath(filepath)), exist_ok=True)
# ------------------------------------------------------------------
# Build an *indexed* frame for fast time‑slice queries
# ------------------------------------------------------------------
# frame = self.data.copy()
# frame["datetime"] = pd.to_datetime(frame["timestamp"], unit="ns")
# frame.set_index("datetime", inplace=True)
# It is already indexed!
frame = self.data
# ------------------------------------------------------------------
# Write / append to the store
# ------------------------------------------------------------------
with pd.HDFStore(
filepath,
mode=mode,
complib=complib,
complevel=complevel,
) as store:
# Check if month data already exists
month_exists = h5_key in store
should_overwrite = False
if month_exists and overwrite_month:
# Prompt user for confirmation
#record_count = store.get_storer(h5_key).nrows
#user_input = input(
# f"WARNING: Data for {month_key} already exists with {record_count:,} records.\n"
# f"Do you want to overwrite it? [y/N]: "
#).lower()
#should_overwrite = user_input in ('y', 'yes')
should_overwrite = True
if not should_overwrite:
user_input = input("Do you want to append to existing data instead? [Y/n]: ").lower()
if user_input in ('n', 'no'):
logger.info(f"Operation cancelled by user. No changes made to {month_key} data.")
return h5_key
# Handle data writing
if month_exists and should_overwrite:
logger.info(f"Overwriting existing data for {month_key}...")
store.remove(h5_key)
store.remove(meta_key)
if integrity_key in store:
store.remove(integrity_key)
store.put(
key=h5_key,
value=frame,
format="table",
data_columns=["timestamp"],
index=False,
min_itemsize={"side": 1},
)
elif month_exists:
# Append using PyTables row‑wise interface (fast)
logger.info(f"Appending to existing data for {month_key}...")
store.append(
key=h5_key,
value=frame,
format="table",
data_columns=["timestamp"],
index=False,
min_itemsize={"side": 1},
chunksize=chunksize,
)
else:
# Create new month data
logger.info(f"Creating new data for {month_key}...")
store.put(
key=h5_key,
value=frame,
format="table",
data_columns=["timestamp"],
index=False,
min_itemsize={"side": 1},
)
# ------------------------------------------------------------------
# Update lightweight per‑group metadata (fast group discovery later)
# ------------------------------------------------------------------
meta = pd.Series(
{
"record_count": len(frame) if should_overwrite else (
store.get_storer(h5_key).nrows if month_exists else len(frame)),
"first_timestamp": int(frame["timestamp"].iloc[0]),
"last_timestamp": int(frame["timestamp"].iloc[-1]),
"data_integrity_ok": self.data_ok, # Add integrity flag to main metadata
"missing_pct": self.missing_pct # Add count of discontinuities
}
)
store.put(meta_key, meta, format="fixed")
# ------------------------------------------------------------------
# Store data integrity information if discontinuities were found
# ------------------------------------------------------------------
if self.discontinuities:
# Convert discontinuities list to a DataFrame with string representation of objects
discontinuity_data = []
for disc in self.discontinuities:
# Convert pandas Timestamp and Timedelta objects to strings to ensure serialization works
disc_dict = {
'start_id': disc['start_id'],
'end_id': disc['end_id'],
'missing_ids': disc['missing_ids'],
'pre_gap_time_str': str(disc['pre_gap_time']),
'post_gap_time_str': str(disc['post_gap_time']),
'time_interval_str': str(disc['time_interval'])
}
discontinuity_data.append(disc_dict)
# Save discontinuity data as DataFrame
if discontinuity_data:
disc_df = pd.DataFrame(discontinuity_data)
store.put(integrity_key, disc_df, format="table")
logger.info(f"Saved {len(disc_df)} trade ID discontinuities to metadata.")
logger.info(f"Successfully saved {len(frame):,} records for {month_key}")
return h5_key
# ------------------------------------------------------------------
# Reading helpers
# ------------------------------------------------------------------
@classmethod
def _keys_for_timerange(
cls, store: pd.HDFStore, start: Optional[pd.Timestamp], end: Optional[pd.Timestamp]
) -> list[str]:
"""Internal helper – determine which monthly groups intersect the
*[start, end]* interval by consulting the per‑group metadata.
"""
candidate_keys: list[str] = []
for meta_key in (k for k in store.keys() if k.startswith("/meta/")):
meta = store[meta_key]
first = pd.to_datetime(meta["first_timestamp"], unit="ns")
last = pd.to_datetime(meta["last_timestamp"], unit="ns")
if (end is None or first <= end) and (start is None or last >= start):
# If there is intersection, add the corresponding trades key
candidate_keys.append(meta_key.replace("/meta", "/trades"))
return sorted(candidate_keys)
[docs]
@classmethod
def load_trades_h5(
cls,
filepath: str,
*,
key: Optional[str] = None,
start_time: Optional[Union[str, pd.Timestamp]] = None,
end_time: Optional[Union[str, pd.Timestamp]] = None,
n_workers: Optional[int] = None,
enable_multiprocessing: bool = True,
min_groups_for_mp: int = 2,
) -> "TradesData":
"""
Load trades from *filepath* with optional multiprocessing support.
Three usage modes exist:
1. ``key`` only – load the full monthly partition ``/trades/<key>``.
2. ``start_time`` / ``end_time`` – assemble the minimal set of monthly
groups touching the range, slice **at read time** for maximum speed.
3. Combination – constrain selection *within* the chosen "key".
:param filepath: Path to the HDF5 file.
:param key: Optional specific monthly key to load (e.g., "2025-03").
:param start_time: Optional start time for filtering.
:param end_time: Optional end time for filtering.
:param n_workers: Number of worker processes. If None, uses CPU count - 1.
:param enable_multiprocessing: Whether to use multiprocessing when loading multiple groups.
:param min_groups_for_mp: Minimum number of groups required to enable multiprocessing.
:return: TradesData instance with loaded data.
"""
# ------------------------------------------------------------------
# Normalise temporal boundaries
# ------------------------------------------------------------------
if isinstance(start_time, str):
start_time = pd.Timestamp(start_time)
if isinstance(end_time, str):
end_time = pd.Timestamp(end_time)
logger.info(f"Loading trades from {filepath}...")
# First, determine which keys we need to load
with pd.HDFStore(filepath, mode="r") as store:
available_keys = [k for k in store.keys() if k.startswith('/trades/')]
# Determine which groups to read -----------------------------------------------------
if key is not None:
h5_keys = [f"/trades/{key}"]
# Check if key is available
if h5_keys[0] not in store:
logger.info(f"Available keys in the store: {available_keys}")
raise KeyError(f"HDF5 group '{h5_keys[0]}' not found in the store.")
else:
h5_keys = cls._keys_for_timerange(store, start_time, end_time)
if not h5_keys:
raise KeyError("No HDF5 group matches the requested slice.")
# ------------------------------------------------------------------
# Decide whether to use multiprocessing
# ------------------------------------------------------------------
use_multiprocessing = (
enable_multiprocessing and
len(h5_keys) >= min_groups_for_mp
)
# Prepare where clause for time filtering
where_clause = []
if start_time is not None:
where_clause.append(f"index >= Timestamp('{start_time}')")
if end_time is not None:
where_clause.append(f"index <= Timestamp('{end_time}')")
where = " & ".join(where_clause) if where_clause else None
frames: list[pd.DataFrame] = []
if use_multiprocessing:
logger.info(f"Loading {len(h5_keys)} groups using multiprocessing with {n_workers or mp.cpu_count() - 1} workers...")
# Prepare arguments for worker processes
worker_args = [(filepath, h5_key, where) for h5_key in h5_keys]
# Determine number of workers
if n_workers is None:
n_workers = min(mp.cpu_count() - 1, len(h5_keys))
else:
n_workers = min(n_workers, len(h5_keys))
try:
# Use ProcessPoolExecutor for better control and Jupyter compatibility
with ProcessPoolExecutor(max_workers=n_workers) as executor:
# Submit all tasks
future_to_key = {
executor.submit(_load_single_h5_group, args): args[1]
for args in worker_args
}
# Create a dictionary to store results by key for ordered processing
results_by_key = {}
# Collect results as they complete
for future in as_completed(future_to_key):
h5_key = future_to_key[future]
try:
df = future.result()
if not df.empty:
results_by_key[h5_key] = df
except Exception as e:
logger.error(f"Error loading {h5_key}: {str(e)}")
# Continue with other groups instead of failing completely
# Add results to frames in the same order as h5_keys to maintain chronology
for h5_key in h5_keys:
if h5_key in results_by_key:
logger.info(f"Appending {h5_key} to the frame list for concatanation.")
frames.append(results_by_key[h5_key])
except Exception as e:
logger.warning(f"Multiprocessing failed ({str(e)}), falling back to sequential loading...")
use_multiprocessing = False
# Sequential loading (fallback or when multiprocessing is disabled)
if not use_multiprocessing:
logger.info(f"Loading {len(h5_keys)} groups sequentially...")
with pd.HDFStore(filepath, mode="r") as store:
for h5_key in h5_keys:
try:
if where:
df = store.select(h5_key, where=where)
else:
df = store[h5_key]
if not df.empty:
frames.append(df)
except Exception as e:
logger.error(f"Error loading {h5_key}: {str(e)}")
continue
if not frames:
raise ValueError("No data was successfully loaded from any HDF5 group.")
# ------------------------------------------------------------------
# Concatenate & restore original column order
# ------------------------------------------------------------------
logger.info(f"Concatenating {len(frames)} DataFrames...")
df = pd.concat(frames, copy=False)
# Ensure the DataFrame index is sorted
if not df.index.is_monotonic_increasing:
logger.info("Sorting DataFrame by datetime index after concatenation...")
df.sort_index(inplace=True)
logger.info(f"Successfully loaded {len(df):,} trades from {len(frames)} monthly groups.")
side = df["side"] if "side" in df.columns else None
side_values = side.values if side is not None else None
return cls(df["timestamp"].values, df["price"].values, df["amount"].values,
side=side_values, dt_index=df.index)
# --------
# utils
# --------
def _load_single_h5_group(args: Tuple[str, str, Optional[str]]) -> pd.DataFrame:
    """
    Read one HDF5 group; meant to be executed inside a worker process.

    :param args: Tuple of ``(filepath, h5_key, where_clause)``. A falsy ``where_clause`` reads
        the whole group; otherwise only the rows matching the clause are selected.
    :returns: The loaded DataFrame, or an empty DataFrame when reading failed (the failure
        is logged; no error details are attached to the returned frame).
    """
    filepath, h5_key, where_clause = args
    try:
        with pd.HDFStore(filepath, mode="r") as store:
            if not where_clause:
                return store[h5_key]
            return store.select(h5_key, where=where_clause)
    except Exception as e:
        # Best effort: report and hand back an empty frame instead of raising
        logger.error(f"Failed to load {h5_key} from {filepath}: {str(e)}")
        return pd.DataFrame()
def _is_notebook_environment() -> bool:
"""
Detect if we're running in a Jupyter notebook environment.
:returns: True if in notebook, False otherwise
"""
try:
# Check for IPython
from IPython import get_ipython
if get_ipython() is not None:
return True
except ImportError:
pass
# Check for other notebook indicators
return any('jupyter' in arg.lower() or 'ipython' in arg.lower() for arg in sys.argv)