# Post-patch excerpt of dataloader/loader.py (patch 6bb8367..1d3f929), rebuilt
# from a diff whose line breaks had been lost. Every method below replaced a
# ``raise NotImplementedError`` stub. The class name is taken from the
# docstring examples; ``client``, ``_build_query``, ``_format_dataframe``,
# ``fields`` and ``iter_chunks`` live outside this excerpt, as do the module's
# ``pd`` / ``typing`` imports.
#
# NOTE(review): ``source``, column names and patterns are interpolated into
# the SQL text verbatim. Pass only trusted identifiers to these helpers.
class DataLoader:
    """Class-level query helpers that send SQL to ``cls.client``.

    Most helpers pass the client's result through ``cls._format_dataframe``
    before returning it; the exceptions say so in their own docstrings.
    """

    @classmethod
    def all(cls, source: str) -> pd.DataFrame:
        """Fetch entire table.

        Args:
            source: Table name placed directly in the ``FROM`` clause.

        Returns:
            The client's result after ``cls._format_dataframe``.
        """
        query = f"SELECT * FROM {source}"
        df = cls.client.query(query)
        return cls._format_dataframe(df)

    @classmethod
    def columns(cls, source: str, columns: List[str]) -> pd.DataFrame:
        """Select specific columns from source.

        Args:
            source: Table name placed directly in the ``FROM`` clause.
            columns: Column names, joined with ``", "`` into the select list.

        Returns:
            The client's result after ``cls._format_dataframe``.

        Raises:
            ValueError: If ``columns`` is empty, which would otherwise
                produce the invalid statement ``SELECT  FROM <source>``.
        """
        if not columns:
            raise ValueError("columns must contain at least one column name")
        query = f"SELECT {', '.join(columns)} FROM {source}"
        df = cls.client.query(query)
        return cls._format_dataframe(df)

    @classmethod
    def head(cls, source: str, n: int = 10) -> pd.DataFrame:
        """Get first N rows.

        Args:
            source: Table name placed directly in the ``FROM`` clause.
            n: Row limit; coerced with ``int()`` before it reaches the SQL.

        Returns:
            The client's result after ``cls._format_dataframe``.
        """
        # TODO(review): no ORDER BY is issued, so "first" is whatever order
        # the backend returns. Ordering by date was an open question in the
        # patch and is still undecided.
        query = f"SELECT * FROM {source} LIMIT {int(n)}"
        df = cls.client.query(query)
        return cls._format_dataframe(df)

    @classmethod
    def paginate(cls, source: str, limit: int, offset: int) -> pd.DataFrame:
        """Get paginated results.

        Args:
            source: Table name placed directly in the ``FROM`` clause.
            limit: Page size; coerced with ``int()``.
            offset: Number of rows to skip; coerced with ``int()``.

        Returns:
            The client's result after ``cls._format_dataframe``.
        """
        # TODO(review): same open ordering question as ``head``. Without an
        # ORDER BY, nothing in this query ties page contents to a row order.
        query = f"SELECT * From {source} LIMIT {int(limit)} OFFSET {int(offset)}"
        df = cls.client.query(query)
        return cls._format_dataframe(df)

    @classmethod
    def filter(cls, source: str, **kwargs) -> pd.DataFrame:
        """Query ``source`` restricted by keyword filters.

        ``symbol``, ``date_start`` and ``date_end`` are recognised by name;
        falsy values for them are ignored. ``date_start`` / ``date_end`` are
        stored under the keys ``start`` / ``end``, the same keys
        ``date_range`` uses. All remaining keyword arguments are added to the
        filter dict unchanged.

        Example:
            DataLoader.filter('equities', symbol='AAPL', date_start='2024-01-01')

        Returns:
            The client's result after ``cls._format_dataframe``.
        """
        # Public keyword -> key stored in the filter dict.
        renamed = {"symbol": "symbol", "date_start": "start", "date_end": "end"}

        filters = {}
        for public_key, filter_key in renamed.items():
            value = kwargs.pop(public_key, None)
            if value:
                filters[filter_key] = value
        # Applied last, so an explicit ``start=``/``end=`` keyword wins.
        filters.update(kwargs)

        # The same dict shapes the SQL and is handed to the client as params.
        query = cls._build_query(source, filters=filters)
        df = cls.client.query(query, params=filters)
        return cls._format_dataframe(df)

    @classmethod
    def match_pattern(cls, source: str, pattern: str) -> List[str]:
        """Get columns matching a pattern.

        Args:
            source: Table name placed directly in the ``FROM`` clause.
            pattern: Pattern placed inside ``COLUMNS('...')``.

        Returns:
            The column labels of the query result, in result order.
        """
        query = f"FROM {source} SELECT COLUMNS('{pattern}')"
        df = cls.client.query(query)
        # The matched columns are the result's own columns. The previous
        # ``df["columns"]`` lookup only worked if a column happened to be
        # named "columns".
        return df.columns.tolist()

    @classmethod
    def select_pattern(cls, source: str, pattern: str, **filters) -> pd.DataFrame:
        """Select columns matching a pattern with optional filters.

        Resolves the column list with ``match_pattern`` (one extra query),
        then builds the final statement with ``cls._build_query``.

        Returns:
            The client's result after ``cls._format_dataframe``.
        """
        matched = cls.match_pattern(source, pattern)
        query = cls._build_query(source, columns_list=matched, filters=filters)
        df = cls.client.query(query, params=filters)
        return cls._format_dataframe(df)

    @classmethod
    def date_range(
        cls, source: str, start_date: str, end_date: str, **additional_filters
    ) -> pd.DataFrame:
        """Get data between two dates (YYYY-MM-DD format).

        Args:
            source: Table name handed to ``cls._build_query``.
            start_date: Stored under the filter key ``start``.
            end_date: Stored under the filter key ``end``.
            **additional_filters: Extra filters; a ``start``/``end`` keyword
                here overrides the positional dates.

        Returns:
            The client's result after ``cls._format_dataframe``.
        """
        filters = {"start": start_date, "end": end_date, **additional_filters}

        query = cls._build_query(source, filters=filters)
        df = cls.client.query(query, params=filters)
        return cls._format_dataframe(df)

    @classmethod
    def first_date(cls, source: str) -> pd.Timestamp:
        """Return the earliest date in the table.

        Reads ``MIN(date)`` and converts the single value with
        ``pd.to_datetime``.
        """
        query = f"SELECT MIN(date) AS first_date FROM {source}"
        df = cls.client.query(query)
        return pd.to_datetime(df["first_date"].iloc[0])

    @classmethod
    def last_date(cls, source: str) -> pd.Timestamp:
        """Return the latest date in the table.

        Reads ``MAX(date)`` and converts the single value with
        ``pd.to_datetime``.
        """
        query = f"SELECT MAX(date) AS last_date FROM {source}"
        df = cls.client.query(query)
        return pd.to_datetime(df["last_date"].iloc[0])

    @classmethod
    def latest(cls, source: str, n: int = 1) -> pd.DataFrame:
        """Return the last N rows of the table, newest ``date`` first.

        Args:
            source: Table name placed directly in the ``FROM`` clause.
            n: Row limit; coerced with ``int()``.

        Returns:
            The client's result after ``cls._format_dataframe``.
        """
        # TODO(review): the limit applies to the whole table, not per symbol.
        # An optional ticker parameter was floated in the patch but is not
        # implemented.
        query = f"SELECT * FROM {source} ORDER BY date DESC LIMIT {int(n)}"
        df = cls.client.query(query)
        return cls._format_dataframe(df)

    @classmethod
    def describe(cls, source: str) -> pd.DataFrame:
        """Return the backend's ``DESCRIBE TABLE`` output for ``source``.

        The result is returned as the client produced it, without
        ``cls._format_dataframe``.
        """
        # NOTE(review): the earlier summary promised non-null counts and basic
        # stats; this statement presumably yields schema metadata only.
        # Confirm against the backend.
        query = f"DESCRIBE TABLE {source}"
        return cls.client.query(query)

    @classmethod
    def column_types(cls, source: str) -> Dict[str, str]:
        """Return data types for each column in a table.

        Returns:
            Mapping built by pairing the result's ``field`` values with its
            ``type`` values.
        """
        query = f"SHOW COLUMNS FROM {source}"
        # Expects the result to expose 'field' and 'type' columns.
        df = cls.client.query(query)
        return dict(zip(df["field"], df["type"]))

    @classmethod
    def stream(cls, source: str, batch_size: int = 10000):
        """Yield the table in date-ordered batches of ``batch_size`` rows.

        Each batch is a separate ``LIMIT``/``OFFSET`` query; iteration stops
        at the first empty result.

        Example:
            for df_chunk in DataLoader.stream('equities', 5000):
                process(df_chunk)

        Yields:
            Each non-empty batch after ``cls._format_dataframe``.
        """
        # NOTE(review): every batch re-runs the ORDER BY, and rows sharing a
        # date have no tie-breaker. Presumably acceptable for now; confirm.
        batch_size = int(batch_size)
        offset = 0
        while True:
            query = (
                f"SELECT * FROM {source} ORDER BY date "
                f"LIMIT {batch_size} OFFSET {offset}"
            )
            df = cls.client.query(query)
            if df.empty:
                break
            yield cls._format_dataframe(df)
            offset += batch_size

    @classmethod
    def batch_query(
        cls, sources: List[str], filters: Optional[Dict[str, Any]] = None
    ) -> pd.DataFrame:
        """Query multiple tables or symbols in a single call.

        Args:
            sources: Table names, queried one after another.
            filters: Passed to ``cls._build_query`` and to the client as
                ``params`` for every source.

        Returns:
            The formatted per-source frames concatenated column-wise, or an
            empty DataFrame when ``sources`` is empty.
        """
        frames = []
        for source in sources:
            query = cls._build_query(source, filters=filters)
            df = cls.client.query(query, params=filters)
            frames.append(cls._format_dataframe(df))

        if not frames:
            return pd.DataFrame()
        # One concat instead of re-copying the accumulated frame per source.
        # NOTE(review): axis=1 puts the sources side by side, as the patch
        # did. Confirm that stacking rows (axis=0) was not intended.
        return pd.concat(frames, axis=1)