diff --git a/mission_blue.py b/mission_blue.py
index fcda485..03b1ebe 100644
--- a/mission_blue.py
+++ b/mission_blue.py
@@ -1,12 +1,11 @@
-"""This module conatins the BlueSky Web Scrapper."""
+"""This module contains the BlueSky Web Scraper."""
 import click
 import requests
-from alive_progress import alive_bar
-from alive_progress.animations.bars import bar_factory
-from typing import Optional, List, Dict, Any
 import auth
+import scraper
 import file
+from typing import Optional, List, Dict, Any

 # pylint: disable=C0301


@@ -213,89 +212,6 @@ def generate_query_params(
     }


-def search_posts(params: dict, token: str) -> list[dict]:
-    # pylint: disable=E1102
-    # pylint: disable=C0301
-    """Search for posts using the BlueSky API.
-
-    Args:
-        params (dict): The query parameters for the API request.
-        - query (str, required): The search term for the BlueSky posts.
-        - sort (str, optional): The sorting criteria for results.
-          Options include "top" for top posts or "latest" for the latest posts.
-        - since (str, optional): The start date for posts (ISO 8601 format).
-        - until (str, optional): The end date for posts (ISO 8601 format).
-        - mentions (str, optional): Mentions to filter posts by.
-          - Handles will be resolved to DIDs using the provided token.
-        - author (str, optional): The author of the posts (handle or DID).
-        - lang (str, optional): The language of the posts.
-        - domain (str, optional): A domain URL included in the posts.
-        - url (str, optional): A specific URL included in the posts.
-        - tags (list, optional): Tags to filter posts by (each tag <= 640 characters).
-        - limit (int, optional): The maximum number of posts to retrieve in a single response.
-          Defaults to 25.
-        - cursor (str, optional): Pagination token for continuing from a previous request.
-        - posts_limit (int, optional): The maximum number of posts to retrieve across all responses.
-          Defaults to 500.
-
-    Returns:
-        list: A list of posts matching the search criteria.
-
-    Notes:
-        - Progress is displayed using a progress bar indicating the number of posts fetched.
-        - Handles pagination automatically until `posts_limit` is reached or no further results are available.
-        - Logs and returns partial results if an error occurs during fetching.
-
-    """
-    posts = []
-    url = "https://bsky.social/xrpc/app.bsky.feed.searchPosts"
-    headers = {
-        "Authorization": f"Bearer {token}",
-        "Content-Type": "application/json",
-    }
-
-    total_fetched = 0
-    posts_limit = params.get("posts_limit")
-    butterfly_bar = bar_factory("✨", tip="🦋", errors="🔥🧯👩‍🚒")
-
-    with alive_bar(posts_limit, bar=butterfly_bar, spinner="waves") as progress:
-        while True:
-            try:
-                response = requests.get(url, headers=headers, params=params, timeout=10)
-                # print(response)
-                response.raise_for_status()
-                data = response.json()
-
-                # Check if we have reached our overall posts limit
-                new_posts = data.get("posts", [])
-                posts.extend(new_posts)
-                total_fetched += len(new_posts)
-
-                # Update progress bar
-                progress(len(new_posts))
-
-                if posts_limit and total_fetched >= posts_limit:
-                    print(
-                        f"Fetched {total_fetched} posts, total: {total_fetched}/{posts_limit}"
-                    )
-                    return posts[:posts_limit]
-
-                # Move to the enxt page if available
-                next_cursor = data.get("cursor")
-                if not next_cursor:
-                    print(f"All posts fetched. Total: {total_fetched}")
-                    return posts
-
-                params["cursor"] = next_cursor
-            except requests.exceptions.RequestException as err:
-                print(f"Error fetching posts: {err}")
-                print(
-                    "Response:",
-                    response.text if "response" in locals() else "No response",
-                )
-                return posts
-
-
 # Begin Click CLI


@@ -388,10 +304,10 @@ def search_posts(params: dict, token: str) -> list[dict]:
     "--posts_limit",
     type=click.IntRange(1, None),
     required=False,
-    default=1000,
+    default=500,
     help=(
         "Set the total number of posts to fetch from the API across all paginated responses. This value limits the total data retrieved "
-        "even if multiple API calls are required. If not specified, 1000 posts will be recieved."
+        "even if multiple API calls are required. If not specified, 500 posts will be received."
     ),
 )
 def main(
@@ -406,7 +322,7 @@ def main(
     url: str = "",
     tags: tuple = (),
     limit: int = 25,
-    posts_limit: int = 1000,
+    posts_limit: int = 500,
 ) -> None:
     """Method that tests if each click param flag is being passed in correctly."""
     # pylint: disable=R0913
@@ -442,7 +358,7 @@ def main(

     # Fetch posts
     print("Fetching posts...")
-    raw_posts = search_posts(query_param, access_token)
+    raw_posts = scraper.search_posts(query_param, access_token)

     # Extract post data
     print("Extracting post data...")
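For reference, a minimal usage sketch of the relocated function (not part of the patch): the query values and the token string are placeholders, and the parameter keys mirror generate_query_params and the docstring in scraper.py below.

import scraper

# Hypothetical standalone call; a real Bearer token would come from the
# project's auth flow, elided here as a placeholder string.
params = {
    "q": "monarch butterfly",  # required search term (the API's `q` key)
    "sort": "latest",          # optional: "top" or "latest"
    "limit": 25,               # page size per API response
    "posts_limit": 50,         # overall cap across paginated responses
}
token = "PLACEHOLDER_BEARER_TOKEN"

posts = scraper.search_posts(params, token)
print(f"Retrieved {len(posts)} posts")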
Total: {total_fetched}") - return posts - - params["cursor"] = next_cursor - except requests.exceptions.RequestException as err: - print(f"Error fetching posts: {err}") - print( - "Response:", - response.text if "response" in locals() else "No response", - ) - return posts - - # Begin Click CLI @@ -388,10 +304,10 @@ def search_posts(params: dict, token: str) -> list[dict]: "--posts_limit", type=click.IntRange(1, None), required=False, - default=1000, + default=500, help=( "Set the total number of posts to fetch from the API across all paginated responses. This value limits the total data retrieved " - "even if multiple API calls are required. If not specified, 1000 posts will be recieved." + "even if multiple API calls are required. If not specified, 500 posts will be recieved." ), ) def main( @@ -406,7 +322,7 @@ def main( url: str = "", tags: tuple = (), limit: int = 25, - posts_limit: int = 1000, + posts_limit: int = 500, ) -> None: """Method that tests if each click param flag is being passed in correctly.""" # pylint: disable=R0913 @@ -442,7 +358,7 @@ def main( # Fetch posts print("Fetching posts...") - raw_posts = search_posts(query_param, access_token) + raw_posts = scraper.search_posts(query_param, access_token) # Extract post data print("Extracting post data...") diff --git a/scraper.py b/scraper.py new file mode 100644 index 0000000..933878c --- /dev/null +++ b/scraper.py @@ -0,0 +1,100 @@ +""" +This module contains a function to search for posts using the BlueSky API. +""" + +import requests +from alive_progress import alive_bar +from alive_progress.animations.bars import bar_factory + + +def search_posts(params, token): + # pylint: disable=E1102 + # pylint: disable=C0301 + + """ + Search for posts using the BlueSky API. + + Args: + params (dict): The query parameters for the API request. + - query (str, required): The search term for the BlueSky posts. + - sort (str, optional): The sorting criteria for results. + Options include "top" for top posts or "latest" for the latest posts. + - since (str, optional): The start date for posts (ISO 8601 format). + - until (str, optional): The end date for posts (ISO 8601 format). + - mentions (str, optional): Mentions to filter posts by. + - Handles will be resolved to DIDs using the provided token. + - author (str, optional): The author of the posts (handle or DID). + - lang (str, optional): The language of the posts. + - domain (str, optional): A domain URL included in the posts. + - url (str, optional): A specific URL included in the posts. + - tags (list, optional): Tags to filter posts by (each tag <= 640 characters). + - limit (int, optional): The maximum number of posts to retrieve in a single response. + Defaults to 25. + - cursor (str, optional): Pagination token for continuing from a previous request. + - posts_limit (int, optional): The maximum number of posts to retrieve across all responses. + Defaults to 500. + + Returns: + list: A list of posts matching the search criteria. + + Notes: + - Progress is displayed using a progress bar indicating the number of posts fetched. + - Handles pagination automatically until `posts_limit` is reached or no further results are available. + - Logs and returns partial results if an error occurs during fetching. 
+ """ + # Validate input parameters + if "q" not in params: + raise ValueError("Query parameter is required.") + if not token: + raise ValueError("Token is required.") + + posts = [] + url = "https://bsky.social/xrpc/app.bsky.feed.searchPosts" + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + } + + total_fetched = 0 + posts_limit = params.get("posts_limit", 500) + butterfly_bar = bar_factory("✨", tip="🦋", errors="🔥🧯👩‍🚒") + + with alive_bar(posts_limit, bar=butterfly_bar, spinner="waves") as progress: + while True: + try: + response = requests.get(url, headers=headers, params=params, timeout=10) + response.raise_for_status() + data = response.json() + + # Check if we have reached our overall posts limit + new_posts = data.get("posts", []) + posts.extend(new_posts) + total_fetched += len(new_posts) + + # Update progress bar + remaining = posts_limit - (total_fetched - len(new_posts)) + progress(min(len(new_posts), remaining)) + + if total_fetched >= posts_limit: + print( + f"Fetched {total_fetched} posts, total: {total_fetched}/{posts_limit}" + ) + # Truncate only if we exceeded the limit + if len(posts) > posts_limit: + posts = posts[:posts_limit] + return posts + + # Move to the next page if available + next_cursor = data.get("cursor") + if not next_cursor: + print(f"All posts fetched. Total: {total_fetched}") + return posts + + params["cursor"] = next_cursor + except requests.exceptions.RequestException as err: + print(f"Error fetching posts: {err}") + print( + "Response:", + response.text if "response" in locals() else "No response", + ) + return posts diff --git a/tests/scraper_test.py b/tests/scraper_test.py new file mode 100644 index 0000000..6be76f2 --- /dev/null +++ b/tests/scraper_test.py @@ -0,0 +1,105 @@ +"""Testing suite for the scraper module.""" + +import unittest +import requests +import io +import sys +from unittest.mock import patch, MagicMock +from scraper import search_posts + + +class TestSearchPosts(unittest.TestCase): + """Testing the search_posts() method.""" + + @patch("scraper.requests.get") + def test_no_query(self, mock_get: MagicMock) -> None: + """Test if the function raises ValueError when a query is not provided.""" + params = {} + token = "valid_token" + + with self.assertRaises(ValueError) as cm: + search_posts(params, token) + + mock_get.assert_not_called() + self.assertIn("query", str(cm.exception).lower()) + + @patch("scraper.requests.get") + def test_no_token(self, mock_get: MagicMock) -> None: + """Test if the function raises ValueError when a token is not provided.""" + params = {"q": "test"} + token = None + + with self.assertRaises(ValueError) as cm: + search_posts(params, token) + + mock_get.assert_not_called() + self.assertIn("token", str(cm.exception).lower()) + + @patch("scraper.requests.get") + def test_valid_response(self, mock_get: MagicMock) -> None: + """Test that the function returns a list of posts when valid parameters are provided.""" + params = {"q": "test"} + token = "valid_token" + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "posts": [ + { + "uri": "at://did:plc:12345/app.bsky.feed.post/abcdef", + "cid": "bafyre123...", + "author": { + "did": "did:plc:12345", + "handle": "author_handle", + "displayName": "Author Name", + }, + "record": { + "text": "Post content", + "createdAt": "2023-10-01T00:00:00Z", + "$type": "app.bsky.feed.post", + }, + "indexedAt": "2023-10-01T00:00:01Z", + } + ], + "cursor": None, + } + + mock_get.return_value = 
mock_response + + result = search_posts(params, token) + + self.assertEqual(len(result), 1) + self.assertEqual(result[0]["record"]["text"], "Post content") + self.assertEqual(result[0]["author"]["handle"], "author_handle") + self.assertEqual(result[0]["record"]["createdAt"], "2023-10-01T00:00:00Z") + self.assertEqual( + result[0]["uri"], "at://did:plc:12345/app.bsky.feed.post/abcdef" + ) + + # Simulate a failed API response (e.g., 400: [InvalidRequest, ExpiredToken, InvalidToken, BadQueryString]) + @patch("scraper.requests.get") + def test_invalid_request(self, mock_get: MagicMock) -> None: + """Test that the function handles invalid requests gracefully.""" + params = {"q": "test"} + token = "invalid_token" + + mock_response = MagicMock() + mock_response.status_code = 400 + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError( + "400 Client Error: InvalidToken" + ) + mock_get.return_value = mock_response + + # Redirecting stdout to StringIO + captured_output = io.StringIO() + sys.stdout = captured_output + + result = search_posts(params, token) + sys.stdout = sys.__stdout__ + + self.assertEqual(result, []) + self.assertIn("400 Client Error:", captured_output.getvalue()) + + +if __name__ == "__main__": + unittest.main()
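The suite above covers parameter validation, a single-page success, and an HTTP error path, but not cursor pagination. A sketch of such a test, assuming the same mocking style (the post payloads are minimal stand-ins, not full records):

"""Hypothetical follow-up test: search_posts should walk the cursor chain."""

import unittest
from unittest.mock import patch, MagicMock

from scraper import search_posts


class TestPagination(unittest.TestCase):
    """Sketch only: two pages should be fetched and merged."""

    @patch("scraper.requests.get")
    def test_follows_cursor(self, mock_get: MagicMock) -> None:
        page1 = MagicMock()
        page1.json.return_value = {
            "posts": [{"uri": "at://did:plc:1/app.bsky.feed.post/a"}],
            "cursor": "page-2",  # non-empty cursor keeps the loop going
        }
        page2 = MagicMock()
        page2.json.return_value = {
            "posts": [{"uri": "at://did:plc:1/app.bsky.feed.post/b"}],
            "cursor": None,  # missing cursor ends the loop
        }
        mock_get.side_effect = [page1, page2]  # one mock response per request

        result = search_posts({"q": "test", "posts_limit": 10}, "token")

        self.assertEqual(len(result), 2)
        self.assertEqual(mock_get.call_count, 2)


if __name__ == "__main__":
    unittest.main()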