Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions mediacloud/api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import datetime as dt
import importlib.metadata
import logging
from typing import Any, Dict, List, Optional, Union
import warnings
from typing import Any, Dict, List, Optional, Union

import requests

Expand Down Expand Up @@ -62,11 +62,10 @@ def _query(self, endpoint: str, params: Optional[Dict] = None, method: str = 'GE
r = self._session.post(endpoint_url, json=params, timeout=self.TIMEOUT_SECS)
else:
raise RuntimeError(f"Unsupported method of '{method}'")
if r.status_code != 200:
if r.status_code != 200:
raise mediacloud.error.APIResponseError(r, params, r.json())

return r.json()

return r.json()


class DirectoryApi(BaseApi):
Expand Down Expand Up @@ -148,7 +147,6 @@ def _prep_default_params(self, query: str, start_date: dt.date, end_date: dt.dat
end_date = end_date.date()
warnings.warn("end_date was passed as datetime, but expected as date, and has been recast")


params: Dict[Any, Any] = dict(start=start_date.isoformat(), end=end_date.isoformat(), q=query,
platform=(platform or self.PROVIDER))
if len(source_ids):
Expand All @@ -172,6 +170,13 @@ def story_count_over_time(self, query: str, start_date: dt.date, end_date: dt.da
d['date'] = dt.date.fromisoformat(d['date'][:10])
return results['count_over_time']['counts']

def stories_by_source_week(self, query: str, start_date: dt.date, end_date: dt.date,
collection_ids: Optional[List[int]] = [], source_ids: Optional[List[int]] = [],
platform: Optional[str] = None) -> List[Dict]:
params = self._prep_default_params(query, start_date, end_date, collection_ids, source_ids, platform)
results = self._query('search/count-by-source-week', params)
return results['source-week-attention']

def story_list(self, query: str, start_date: dt.date, end_date: dt.date, collection_ids: Optional[List[int]] = [],
source_ids: Optional[List[int]] = [], platform: Optional[str] = None,
expanded: Optional[bool] = None, pagination_token: Optional[str] = None,
Expand Down
22 changes: 19 additions & 3 deletions mediacloud/test/api_search_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import os
import time
from unittest import TestCase

import pytest

import mediacloud.api

COLLECTION_US_NATIONAL = 34412234
Expand Down Expand Up @@ -236,6 +238,19 @@ def test_collection_ids_filter(self):
for s in results:
assert s['media_url'] in domains

def test_stories_by_source_week(self):
results = self._search.stories_by_source_week(query="tariff AND Trump",
start_date=dt.date(2025, 1, 1),
end_date=dt.date(2025, 12, 31),
collection_ids=[262985236])
assert len(results) > 0
for entry in results:
assert 'media_name' in entry
assert 'week' in entry
assert 'matching_stories' in entry
assert 'total_stories' in entry
assert 'ratio' in entry


class SearchSyntaxTest(TestCase):

Expand Down Expand Up @@ -324,8 +339,9 @@ def test_negation_source(self):
assert not_count < all_count
assert minus_count == not_count


class SearchErrorHandlingTest(TestCase):
#New test cases for how the api handles bad input and errors from the server.
# New test cases for how the api handles bad input and errors from the server.

START_DATE = dt.date(2024, 1, 1)
END_DATE = dt.date(2024, 1, 30)
Expand All @@ -339,10 +355,10 @@ def setUp(self):
def test_datetime(self):
query = "biden"
result_via_date = self._search.story_count(query=query, start_date=self.START_DATE, end_date=self.END_DATE,
collection_ids=[COLLECTION_US_NATIONAL])['relevant']
collection_ids=[COLLECTION_US_NATIONAL])['relevant']

with pytest.warns(UserWarning):
result_via_datetime = self._search.story_count(query=query, start_date=self.START_DATETIME, end_date=self.END_DATETIME,
collection_ids=[COLLECTION_US_NATIONAL])['relevant']
collection_ids=[COLLECTION_US_NATIONAL])['relevant']

assert result_via_date == result_via_datetime
Loading