72 changes: 46 additions & 26 deletions src/gentrade/news/factory.py
@@ -12,13 +12,12 @@
 from typing import List, Optional
 from loguru import logger

-from gentrade.scraper.extractor import ArticleContentExtractor
-
 from gentrade.news.meta import NewsProviderBase, NewsDatabase, NewsFileDatabase
-from gentrade.news.newsapi import NewsApiProvider
-from gentrade.news.rss import RssProvider
-from gentrade.news.finnhub import FinnhubNewsProvider
-
+from gentrade.news.providers.newsapi import NewsApiProvider
+from gentrade.news.providers.rss import RssProvider
+from gentrade.news.providers.finnhub import FinnhubNewsProvider
+from gentrade.news.providers.newsnow import NewsNowProvider
+from gentrade.utils.download import ArticleDownloader

 class NewsFactory:
     """Factory class for creating news provider instances based on provider type.
@@ -47,7 +46,8 @@ def create_provider(provider_type: str, **kwargs) -> NewsProviderBase:
         providers = {
             "newsapi": NewsApiProvider,
             "finnhub": FinnhubNewsProvider,
-            "rss": RssProvider
+            "rss": RssProvider,
+            "newsnow": NewsNowProvider
         }

         provider_class = providers.get(provider_type_lower)
@@ -70,6 +70,10 @@ def create_provider(provider_type: str, **kwargs) -> NewsProviderBase:
             feed_url = kwargs.get("feed_url", os.getenv("RSS_FEED_URL"))
             return provider_class(feed_url=feed_url)

+        if provider_type_lower == "newsnow":
+            source = kwargs.get("source", "baidu")
+            return provider_class(source=source)
+
         return provider_class(**kwargs)

 class NewsAggregator:
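Note: a minimal sketch of how the extended factory is driven, matching the provider mapping and the new "newsnow" branch above (the feed URL is a placeholder, not a real endpoint):

import os
from gentrade.news.factory import NewsFactory

# "newsnow" resolves through the new branch; source defaults to "baidu".
newsnow = NewsFactory.create_provider("newsnow")
jin10 = NewsFactory.create_provider("newsnow", source="jin10")

# RSS still takes its feed URL from kwargs or the RSS_FEED_URL env variable.
os.environ.setdefault("RSS_FEED_URL", "https://example.com/feed.xml")  # placeholder
rss = NewsFactory.create_provider("rss")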
@@ -79,7 +83,7 @@ class NewsAggregator:
     and stores results in a database. Includes logic to avoid frequent syncs.
     """

-    def __init__(self, providers: List[NewsProviderBase], db: NewsDatabase):
+    def __init__(self, providers: List[NewsProviderBase], db: NewsDatabase = None):
         """Initialize the NewsAggregator with a list of providers and a database.

         Args:
@@ -91,7 +95,7 @@ def __init__(self, providers: List[NewsProviderBase], db: NewsDatabase):
         self.db_lock = threading.Lock()

     def _fetch_thread(self, provider, aggregator, ticker, category,
-                      max_hour_interval, max_count, is_process=False):
+                      max_hour_interval, max_count, process_content=True):
         if ticker:
             news = provider.fetch_stock_news(
                 ticker, category, max_hour_interval, max_count
@@ -109,21 +113,26 @@ def _fetch_thread(self, provider, aggregator, ticker, category,
                 f"{provider.__class__.__name__}"
             )

-        ace = ArticleContentExtractor.inst()
+        downloader = ArticleDownloader.inst()
         for item in news:
-            item.summary = ace.clean_html(item.summary)
-            if is_process:
-                item.content = ace.extract_content(item.url)
+            item.summary = downloader.clean_html(item.summary)
+            if process_content:
+                logger.info(f"Process content ... {item.url}")
+                item.content = downloader.get_content(item.url)
+                if item.content:
+                    logger.info(f"Content: {item.content[:20]}")

-        with aggregator.db_lock:
-            aggregator.db.add_news(news)
+        if self.db:
+            with aggregator.db_lock:
+                aggregator.db.add_news(news)

     def sync_news(
         self,
         ticker: Optional[str] = None,
         category: str = "business",
         max_hour_interval: int = 24,
-        max_count: int = 10
+        max_count: int = 10,
+        process_content: bool = True
     ) -> None:
         """Synchronize news from providers, skipping if last sync was within 1 hour.

@@ -136,30 +145,35 @@ def sync_news(
             max_hour_interval: Maximum age (in hours) of news articles to fetch (default: 24).
             max_count: Maximum number of articles to fetch per provider (default: 10).
         """
-        current_time = time.time()
-        if current_time < self.db.last_sync + 3600:
-            logger.info("Skipping sync: Last sync was less than 1 hour ago.")
-            return
+        if self.db:
+            current_time = time.time()
+            if current_time < self.db.last_sync + 3600:
+                logger.info("Skipping sync: Last sync was less than 1 hour ago.")
+                return

         logger.info("Starting news sync...")

         threads = []
         for provider in self.providers:
             if not provider.is_available:
                 logger.error(f"Provider {provider.__class__.__name__} is not available")
                 continue

             thread = threading.Thread(
                 target=self._fetch_thread,
-                args=(provider, self, ticker, category, max_hour_interval, max_count)
+                args=(provider, self, ticker, category, max_hour_interval,
+                      max_count, process_content)
             )
             threads.append(thread)
             thread.start()

         for thread in threads:
             thread.join()

-        self.db.last_sync = current_time
-        self.db.save()
+        if self.db:
+            self.db.last_sync = current_time
+            self.db.save()

         logger.info("News sync completed.")

 if __name__ == "__main__":
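Note: _fetch_thread above relies on gentrade.utils.download.ArticleDownloader, which is not part of this diff. The only facts confirmed by the code are the three calls inst(), clean_html(), and get_content(); everything else in this stand-in sketch is an assumption about the real module:

import re
from typing import Optional

import requests

class ArticleDownloader:
    """Hypothetical stand-in; only inst(), clean_html() and get_content()
    are confirmed by the calls in _fetch_thread."""

    _instance: Optional["ArticleDownloader"] = None

    @classmethod
    def inst(cls) -> "ArticleDownloader":
        # Singleton accessor, matching the ArticleDownloader.inst() call.
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def clean_html(self, text: str) -> str:
        # Crude tag stripper; the real helper may sanitise more carefully.
        return re.sub(r"<[^>]+>", "", text or "").strip()

    def get_content(self, url: str) -> Optional[str]:
        # Download the article and roughly de-tag it; returns None on failure,
        # which matches the `if item.content:` guard above.
        try:
            resp = requests.get(url, timeout=15)
            resp.raise_for_status()
            return self.clean_html(resp.text)
        except requests.RequestException:
            return None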
@@ -170,18 +184,24 @@ def sync_news(
     newsapi_provider = NewsFactory.create_provider("newsapi")
     finnhub_provider = NewsFactory.create_provider("finnhub")
     rss_provider = NewsFactory.create_provider("rss")
+    newsnow_provider = NewsFactory.create_provider("newsnow", source="jin10")

     # Create aggregator with selected providers
     aggregator = NewsAggregator(
-        providers=[rss_provider, newsapi_provider, finnhub_provider], db=db)
+        providers=[newsnow_provider, finnhub_provider, rss_provider, newsapi_provider], db=db)

     # Sync market news and stock-specific news
-    aggregator.sync_news(category="business", max_hour_interval=64, max_count=10)
+    aggregator.sync_news(
+        category="business",
+        max_hour_interval=64,
+        max_count=10,
+        process_content=True)
     aggregator.sync_news(
         ticker="AAPL",
         category="business",
         max_hour_interval=240,
-        max_count=10
+        max_count=10,
+        process_content=True
     )

     # Log results
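Note: since db now defaults to None, the aggregator can also run as a pure fetcher. A hedged sketch of that mode — with no database, the hourly throttle, add_news() and save() are all skipped, so nothing is persisted and callers must handle the fetched items themselves:

from gentrade.news.factory import NewsAggregator, NewsFactory

# No NewsDatabase passed: every sync_news call fetches fresh articles,
# and the results are not written anywhere.
fetch_only = NewsAggregator(
    providers=[NewsFactory.create_provider("newsnow", source="jin10")])
fetch_only.sync_news(category="business", max_hour_interval=24,
                     max_count=5, process_content=False)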
35 changes: 7 additions & 28 deletions src/gentrade/news/meta.py
@@ -12,14 +12,14 @@
 import abc
 import time
 import hashlib
-from typing import Dict, List, Any, Optional
+from typing import Dict, List, Any

 from datetime import datetime
 from dataclasses import dataclass
 from loguru import logger
 import requests

 NEWS_MARKET = [
-    'us', 'zh', 'hk', 'cypto', 'common'
+    'us', 'cn', 'hk', 'cypto', 'common'
 ]

 @dataclass
@@ -35,7 +35,7 @@ class NewsInfo:
     summary: str
     url: str
     content: str
-    provider: str # provder like newsapi, finnhub, rss
+    provider: str # provider like newsapi, finnhub, rss
     market: str # market type like us, chn, eur, hk, crypto

     def to_dict(self) -> Dict[str, Any]:
@@ -59,27 +59,6 @@ def to_dict(self) -> Dict[str, Any]:
             "market": self.market,
         }

-    def fetch_article_html(self) -> Optional[str]:
-        """Fetch raw HTML content from the article's direct URL.
-
-        Uses a browser-like user agent to avoid being blocked by servers.
-
-        Returns:
-            Raw HTML string if successful; None if request fails.
-        """
-        headers = {
-            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
-                          "(KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
-        }
-
-        try:
-            response = requests.get(self.url, headers=headers, timeout=15)
-            response.raise_for_status()
-            return response.text
-        except requests.RequestException as e:
-            logger.debug(f"Failed to fetch HTML for {self.url}: {e}")
-            return None
-
 class NewsProviderBase(metaclass=abc.ABCMeta):
     """Abstract base class defining the interface for news providers.

@@ -278,11 +257,11 @@ def save(self):
             "last_sync": self.last_sync,
             "news_list": news_dicts
         }
-        with open(self._filepath, 'w', encoding='utf-8') as f:
-            json.dump(content, f, indent=4) # indent for readability
+        with open(self._filepath, 'w', encoding="utf-8") as f:
+            json.dump(content, f, ensure_ascii=False, indent=4) # indent for readability

     def load(self):
-        with open(self._filepath, 'r', encoding='utf-8') as f:
+        with open(self._filepath, 'r', encoding="utf-8") as f:
             content = json.load(f) # Directly loads JSON content into a Python list/dict
         self.last_sync = content['last_sync']
         self.news_list = [NewsInfo(**item_dict) for item_dict in content['news_list']]
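Note: the ensure_ascii=False change matters for non-English headlines. A stdlib-only illustration (the sample headline is made up):

import json

content = {"news_list": [{"title": "苹果发布新品"}]}  # made-up Chinese headline

# Default json.dump escapes non-ASCII, bloating the file and making it
# hard to read or grep:
print(json.dumps(content))                      # {"news_list": [{"title": "\u82f9\u679c..."}]}

# ensure_ascii=False writes real UTF-8 text, which is also why the file
# is opened with encoding="utf-8":
print(json.dumps(content, ensure_ascii=False))  # {"news_list": [{"title": "苹果发布新品"}]}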