Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 86 additions & 16 deletions dataloader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,22 +185,34 @@ def fields(cls, source: str) -> List[str]:
@classmethod
def all(cls, source: str) -> pd.DataFrame:
"""Fetch entire table."""
raise NotImplementedError
query = f"SELECT * FROM {source}"
df = cls.client.query(query)
return cls._format_dataframe(df)

# added columns here
@classmethod
def columns(cls, source: str) -> pd.DataFrame:
def columns(cls, source: str, columns: List[str]) -> pd.DataFrame:
"""Select specific columns from source."""
raise NotImplementedError
query = f"SELECT {', '.join(columns)} FROM {source}"
df = cls.client.query(query)
return cls._format_dataframe(df)

# QUESTION
# order by date here?
@classmethod
def head(cls, source: str, n: int = 10) -> pd.DataFrame:
"""Get first N rows."""
raise NotImplementedError
query = f"SELECT * FROM {source} LIMIT {n}"
df = cls.client.query(query)
return cls._format_dataframe(df)

# Same question as above
@classmethod
def paginate(cls, source: str, limit: int, offset: int) -> pd.DataFrame:
"""Get paginated results."""
raise NotImplementedError
query = f"SELECT * From {source} LIMIT {limit} OFFSET {offset}"
df = cls.client.query(query)
return cls._format_dataframe(df)

@classmethod
def filter(cls, source: str, **kwargs) -> pd.DataFrame:
Expand All @@ -210,50 +222,95 @@ def filter(cls, source: str, **kwargs) -> pd.DataFrame:
Example:
DataLoader.filter('equities', symbol='AAPL', date_start='2024-01-01')
"""
raise NotImplementedError
symbol = kwargs.pop("symbol", None)
date_start = kwargs.pop("date_start", None)
date_end = kwargs.pop("date_end", None)

filters = {}
if symbol:
filters["symbol"] = symbol
if date_start:
filters["start"] = date_start
if date_end:
filters["end"] = date_end
filters.update(kwargs)

query = cls._build_query(source, filters=filters)
df = cls.client.query(query, params=filters)
return cls._format_dataframe(df)

@classmethod
def match_pattern(cls, source: str, pattern: str) -> List[str]:
"""Get columns matching a pattern."""
raise NotImplementedError
query = f"FROM {source} SELECT COLUMNS('{pattern}')"
df = cls.client.query(query)
return df["columns"].tolist()

@classmethod
def select_pattern(cls, source: str, pattern: str, **filters) -> pd.DataFrame:
"""Select columns matching a pattern with optional filters."""
raise NotImplementedError
columns = cls.match_pattern(source, pattern)
query = cls._build_query(source, columns_list=columns, filters=filters)
df = cls.client.query(query, params=filters)
return cls._format_dataframe(df)

@classmethod
def date_range(
cls, source: str, start_date: str, end_date: str, **additional_filters
) -> pd.DataFrame:
"""Get data between two dates (YYYY-MM-DD format)."""
raise NotImplementedError
filters = {
"start": start_date,
"end": end_date,
}
filters.update(additional_filters)

query = cls._build_query(source, filters=filters)
df = cls.client.query(query, params=filters)
return cls._format_dataframe(df)

@classmethod
def first_date(cls, source: str) -> pd.Timestamp:
"""Return the earliest date in the table."""
raise NotImplementedError
query = f"SELECT MIN(date) AS first_date FROM {source}"
df = cls.client.query(query)
return pd.to_datetime(df["first_date"].iloc[0])

@classmethod
def last_date(cls, source: str) -> pd.Timestamp:
"""Return the latest date in the table."""
raise NotImplementedError
query = f"SELECT MAX(date) AS last_date FROM {source}"
df = cls.client.query(query)
return pd.to_datetime(df["last_date"].iloc[0])

# want me to add an optional ticker parameter here?
@classmethod
def latest(cls, source: str, n: int = 1) -> pd.DataFrame:
"""Return the last N rows per symbol or table."""
raise NotImplementedError
query = f"""
SELECT * FROM {source}
ORDER BY date DESC
LIMIT {n}
"""
df = cls.client.query(query)
return cls._format_dataframe(df)

@classmethod
def describe(cls, source: str) -> pd.DataFrame:
"""Return column types, non-null counts, basic stats."""
raise NotImplementedError
query = f"DESCRIBE TABLE {source}"
df = cls.client.query(query)
return df

@classmethod
def column_types(cls, source: str) -> Dict[str, str]:
"""Return data types for each column in a table."""
raise NotImplementedError
query = f"SHOW COLUMNS FROM {source}"
# will return a dataframe with 'field' and 'type' columns
df = cls.client.query(query)
return dict(zip(df["field"], df["type"]))

# order here too?
@classmethod
def stream(cls, source: str, batch_size: int = 10000):
"""
Expand All @@ -262,7 +319,14 @@ def stream(cls, source: str, batch_size: int = 10000):
for df_chunk in DataLoader.stream('equities', 5000):
process(df_chunk)
"""
raise NotImplementedError
offset = 0 # for pagination
while True:
query = f"SELECT * FROM {source} ORDER BY date LIMIT {batch_size} OFFSET {offset}"
df = cls.client.query(query)
if df.empty:
break
yield cls._format_dataframe(df)
offset += batch_size

@classmethod
def iter_chunks(cls, source: str, chunk_size: int = 10000):
Expand All @@ -272,4 +336,10 @@ def iter_chunks(cls, source: str, chunk_size: int = 10000):
@classmethod
def batch_query(cls, sources: List[str], filters: Optional[Dict[str, Any]] = None):
"""Query multiple tables or symbols in a single call."""
raise NotImplementedError
combined_df = pd.DataFrame()
for source in sources:
query = cls._build_query(source, filters=filters)
df = cls.client.query(query, params=filters)
formatted_df = cls._format_dataframe(df)
combined_df = pd.concat([combined_df, formatted_df], axis=1)
return combined_df
Loading