3 changes: 2 additions & 1 deletion .gitignore
@@ -6,4 +6,5 @@ cache
 .pytest_cache
 dist
 *.egg-info/
-scraper_data/
+scraper_data/
+*.txt
28 changes: 11 additions & 17 deletions src/gentrade/news/factory.py
@@ -14,7 +14,7 @@
 
 from gentrade.scraper.extractor import ArticleContentExtractor
 
-from gentrade.news.meta import NewsProviderBase, NewsDatabase
+from gentrade.news.meta import NewsProviderBase, NewsDatabase, NewsFileDatabase
 from gentrade.news.newsapi import NewsApiProvider
 from gentrade.news.rss import RssProvider
 from gentrade.news.finnhub import FinnhubNewsProvider
@@ -72,7 +72,6 @@ def create_provider(provider_type: str, **kwargs) -> NewsProviderBase:
 
     return provider_class(**kwargs)
 
-
 class NewsAggregator:
     """Aggregates news articles from multiple providers and synchronizes them to a database.
 
@@ -92,7 +91,7 @@ def __init__(self, providers: List[NewsProviderBase], db: NewsDatabase):
         self.db_lock = threading.Lock()
 
     def _fetch_thread(self, provider, aggregator, ticker, category,
-                      max_hour_interval, max_count, is_process=True):
+                      max_hour_interval, max_count, is_process=False):
         if ticker:
             news = provider.fetch_stock_news(
                 ticker, category, max_hour_interval, max_count
@@ -115,7 +114,6 @@ def _fetch_thread(self, provider, aggregator, ticker, category,
                 item.summary = ace.clean_html(item.summary)
                 if is_process:
                     item.content = ace.extract_content(item.url)
-                    logger.info(item.content)
 
         with aggregator.db_lock:
             aggregator.db.add_news(news)
@@ -147,6 +145,9 @@ def sync_news(
 
         threads = []
         for provider in self.providers:
+            if not provider.is_available:
+                continue
+
             thread = threading.Thread(
                 target=self._fetch_thread,
                 args=(provider, self, ticker, category, max_hour_interval, max_count)
@@ -158,10 +159,11 @@ def sync_news(
             thread.join()
 
         self.db.last_sync = current_time
+        self.db.save()
         logger.info("News sync completed.")
 
 if __name__ == "__main__":
-    db = NewsDatabase()
+    db = NewsFileDatabase("news_db.txt")
 
     try:
         # Initialize providers using the factory
@@ -170,7 +172,8 @@ def sync_news(
         rss_provider = NewsFactory.create_provider("rss")
 
         # Create aggregator with selected providers
-        aggregator = NewsAggregator(providers=[rss_provider], db=db)
+        aggregator = NewsAggregator(
+            providers=[rss_provider, newsapi_provider, finnhub_provider], db=db)
 
         # Sync market news and stock-specific news
         aggregator.sync_news(category="business", max_hour_interval=64, max_count=10)
@@ -185,16 +188,7 @@ def sync_news(
         all_news = db.get_all_news()
         logger.info(f"Total articles in database: {len(all_news)}")
 
-        if all_news:
-            logger.info("Example article:")
-            logger.info(all_news[0].to_dict())
-
-        for news_item in all_news:
-            logger.info("--------------------------------")
-            print(news_item.headline)
-            print(news_item.url)
-            print(news_item.content)
-            logger.info("--------------------------------")
-
+        for news_item in all_news:
+            logger.info("[%s...]: %s..." % (str(news_item.id)[:10], news_item.headline[:15]))
     except ValueError as e:
         logger.error(f"Error during news aggregation: {e}")
20 changes: 12 additions & 8 deletions src/gentrade/news/finnhub.py
@@ -4,7 +4,7 @@
 utilizing the Finnhub.io API to retrieve news articles. It supports both general market news
 and news specific to individual stock tickers, with filtering by time interval and article count.
 """
-
+import os
 import time
 from typing import List
 from datetime import datetime, timedelta
@@ -21,19 +21,23 @@ class FinnhubNewsProvider(NewsProviderBase):
     intervals and maximum article count.
     """
 
-    def __init__(self, api_key: str):
+    def __init__(self, api_key: str = None):
         """Initialize the FinnhubNewsProvider with the required API key.
 
         Args:
             api_key: API key for authenticating requests to Finnhub.io.
         """
-        self.api_key = api_key
+        self.api_key = ( api_key or os.getenv("FINNHUB_API_KEY") )
         self.base_url = "https://finnhub.io/api/v1"
 
     @property
-    def market(self):
+    def market(self) -> str:
         return 'us'
 
+    @property
+    def is_available(self) -> bool:
+        return self.api_key is not None and len(self.api_key) != 0
+
     def fetch_latest_market_news(
         self,
         category: str = "business",
@@ -75,7 +79,7 @@ def fetch_latest_market_news(
                     headline=article.get("headline", ""),
                     id=self.url_to_hash_id(article.get("url", "")),
                     image=article.get("image", ""),
-                    related=article.get("related", ""),
+                    related=article.get("related", []),
                     source=article.get("source", ""),
                     summary=article.get("summary", ""),
                     url=article.get("url", ""),
@@ -85,7 +89,7 @@ def fetch_latest_market_news(
                 ) for article in articles
             ]
 
-            return self._filter_news(news_list, max_hour_interval, max_count)
+            return self.filter_news(news_list, max_hour_interval, max_count)
 
         except requests.RequestException as e:
             logger.debug(f"Error fetching market news from Finnhub: {e}")
@@ -134,7 +138,7 @@ def fetch_stock_news(
                     headline=article.get("headline", ""),
                     id=article.get("id", hash(article.get("url", ""))),
                     image=article.get("image", ""),
-                    related=ticker,
+                    related=[ticker,],
                     source=article.get("source", ""),
                     summary=article.get("summary", ""),
                     url=article.get("url", ""),
@@ -144,7 +148,7 @@ def fetch_stock_news(
                 ) for article in articles
             ]
 
-            return self._filter_news(news_list, max_hour_interval, max_count)
+            return self.filter_news(news_list, max_hour_interval, max_count)
 
         except requests.RequestException as e:
             logger.debug(f"Error fetching stock news from Finnhub: {e}")
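
FinnhubNewsProvider can now be constructed without an explicit key: it falls back to the FINNHUB_API_KEY environment variable, and is_available reports whether any key was found. A short usage sketch (assumes FINNHUB_API_KEY may or may not be set in the environment):

    from gentrade.news.finnhub import FinnhubNewsProvider

    provider = FinnhubNewsProvider()  # key read from FINNHUB_API_KEY when not passed
    if provider.is_available:
        news = provider.fetch_latest_market_news(category="business", max_count=5)
        print(f"fetched {len(news)} articles for market '{provider.market}'")
    else:
        print("no Finnhub key configured; the aggregator will skip this provider")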
94 changes: 80 additions & 14 deletions src/gentrade/news/meta.py
@@ -7,7 +7,8 @@
 
 Supports fetching market-wide and stock-specific news, with filtering by time and count.
 """
-
+import os
+import json
 import abc
 import time
 import hashlib
@@ -29,7 +30,7 @@ class NewsInfo:
     headline: str
     id: int
     image: str
-    related: str # Related stock ticker(s) or empty string
+    related: list[str] # Related stock ticker(s) or empty list
     source: str
     summary: str
     url: str
@@ -79,17 +80,36 @@ def fetch_article_html(self) -> Optional[str]:
             logger.debug(f"Failed to fetch HTML for {self.url}: {e}")
             return None
 
-
 class NewsProviderBase(metaclass=abc.ABCMeta):
     """Abstract base class defining the interface for news providers.
 
     All concrete news providers (e.g., NewsAPI, Finnhub) must implement these methods.
     """
 
     @property
-    def market(self):
+    def market(self) -> str:
+        """Get the market identifier this provider is associated with.
+
+        Defaults to 'common' for providers that cover general markets.
+        Concrete providers may override this to specify a specific market (e.g., 'us', 'cn').
+
+        Returns:
+            str: Market identifier string.
+        """
         return 'common'
 
+    @property
+    def is_available(self) -> bool:
+        """Check if the news provider is currently available/operational.
+
+        Defaults to True. Concrete providers may override this to implement
+        availability checks (e.g., API status, rate limits, connectivity).
+
+        Returns:
+            bool: True if provider is available, False otherwise.
+        """
+        return True
+
     @abc.abstractmethod
     def fetch_latest_market_news(
         self,
@@ -109,7 +129,6 @@ def fetch_latest_market_news(
         """
         raise NotImplementedError
 
-    @abc.abstractmethod
     def fetch_stock_news(
         self,
         ticker: str,
@@ -128,7 +147,27 @@ def fetch_stock_news(
         Returns:
             List of NewsInfo objects related to the specified ticker.
         """
-        raise NotImplementedError
+        # Fetch 2x max_count general news to allow ticker filtering
+        general_news = self.fetch_latest_market_news(
+            category=category,
+            max_hour_interval=max_hour_interval,
+            max_count=max_count * 2
+        )
+
+        # Filter articles where ticker is in headline or summary (case-insensitive)
+        ticker_lower = ticker.lower()
+        ticker_news = [
+            news for news in general_news
+            if ticker_lower in news.headline.lower()
+            or ticker_lower in news.summary.lower()
+        ]
+
+        # Update "related" field to link articles to the target ticker
+        for news in ticker_news:
+            news.related.append(ticker)
+
+        # Limit to max_count results
+        return ticker_news[:max_count]
 
     def _timestamp_to_epoch(self, timestamp: str) -> int:
         """Convert ISO 8601 timestamp to epoch seconds.
@@ -146,7 +185,7 @@ def _timestamp_to_epoch(self, timestamp: str) -> int:
         except ValueError:
             return int(time.time())
 
-    def _filter_news(
+    def filter_news(
         self,
         news_list: List[NewsInfo],
         max_hour_interval: int,
@@ -184,27 +223,29 @@ class NewsDatabase:
 
     def __init__(self):
         """Initialize an empty database with last sync time set to 0."""
-        self.news_dict: Dict[str, NewsInfo] = {} # Key: article URL
-        self.last_sync: float = 0.0 # Epoch time of last successful sync
+        self.news_list: List[NewsInfo] = []
+        self.last_sync = 0
 
     def add_news(self, news_list: List[NewsInfo]) -> None:
         """Add news articles to the database, skipping duplicates.
 
         Args:
             news_list: List of NewsInfo objects to store.
         """
+        news_hash_cache_list = [item.id for item in self.news_list]
         for news in news_list:
-            # Use URL as unique identifier to avoid duplicates
-            if news.url and news.url not in self.news_dict:
-                self.news_dict[news.url] = news
+            if news.id in news_hash_cache_list:
+                logger.error("news %s already in the cache list" % news.id)
+                continue
+            self.news_list.append(news)
 
     def get_all_news(self) -> List[NewsInfo]:
         """Retrieve all stored news articles.
 
         Returns:
             List of all NewsInfo objects in the database.
         """
-        return list(self.news_dict.values())
+        return self.news_list
 
     def get_market_news(self, market='us') -> List[NewsInfo]:
         """Retrieve stored news articles for given market.
@@ -217,7 +258,32 @@ def get_market_news(self, market='us') -> List[NewsInfo]:
 
         assert market in NEWS_MARKET
         market_news = []
-        for item in self.news_dict.values():
+        for item in self.news_list:
             if item.market == market:
                 market_news.append(item)
         return market_news
+
+
+class NewsFileDatabase(NewsDatabase):
+
+    def __init__(self, filepath):
+        super().__init__()
+        self._filepath = filepath
+        if os.path.exists(self._filepath):
+            self.load()
+
+    def save(self):
+        news_dicts = [news.to_dict() for news in self.news_list]
+        content = {
+            "last_sync": self.last_sync,
+            "news_list": news_dicts
+        }
+        with open(self._filepath, 'w', encoding='utf-8') as f:
+            json.dump(content, f, indent=4) # indent for readability
+
+    def load(self):
+        with open(self._filepath, 'r', encoding='utf-8') as f:
+            content = json.load(f) # Directly loads JSON content into a Python list/dict
+            self.last_sync = content['last_sync']
+            self.news_list = [NewsInfo(**item_dict) for item_dict in content['news_list']]
+        logger.info(self.news_list)
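
NewsFileDatabase persists last_sync plus every article (via NewsInfo.to_dict()) as JSON on save(), and reconstructs NewsInfo objects on load(), which runs automatically when the file already exists. A minimal round-trip sketch (the file name is arbitrary):

    from gentrade.news.meta import NewsFileDatabase

    db = NewsFileDatabase("news_db.txt")   # loads prior articles if the file exists
    before = len(db.get_all_news())

    # ... NewsAggregator.sync_news() appends de-duplicated articles here ...

    db.save()                              # writes {"last_sync": ..., "news_list": [...]}

    db2 = NewsFileDatabase("news_db.txt")  # a fresh instance reloads the same data
    assert len(db2.get_all_news()) >= before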