Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 7 additions & 91 deletions mission_blue.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""This module conatins the BlueSky Web Scrapper."""
"""This module contains the BlueSky Web Scrapper."""

import click
import requests
from alive_progress import alive_bar
from alive_progress.animations.bars import bar_factory
from typing import Optional, List, Dict, Any
import auth
import scraper
import file
from typing import Optional, List, Dict, Any

# pylint: disable=C0301

Expand Down Expand Up @@ -213,89 +212,6 @@ def generate_query_params(
}


def search_posts(params: dict, token: str) -> list[dict]:
# pylint: disable=E1102
# pylint: disable=C0301
"""Search for posts using the BlueSky API.

Args:
params (dict): The query parameters for the API request.
- query (str, required): The search term for the BlueSky posts.
- sort (str, optional): The sorting criteria for results.
Options include "top" for top posts or "latest" for the latest posts.
- since (str, optional): The start date for posts (ISO 8601 format).
- until (str, optional): The end date for posts (ISO 8601 format).
- mentions (str, optional): Mentions to filter posts by.
- Handles will be resolved to DIDs using the provided token.
- author (str, optional): The author of the posts (handle or DID).
- lang (str, optional): The language of the posts.
- domain (str, optional): A domain URL included in the posts.
- url (str, optional): A specific URL included in the posts.
- tags (list, optional): Tags to filter posts by (each tag <= 640 characters).
- limit (int, optional): The maximum number of posts to retrieve in a single response.
Defaults to 25.
- cursor (str, optional): Pagination token for continuing from a previous request.
- posts_limit (int, optional): The maximum number of posts to retrieve across all responses.
Defaults to 500.

Returns:
list: A list of posts matching the search criteria.

Notes:
- Progress is displayed using a progress bar indicating the number of posts fetched.
- Handles pagination automatically until `posts_limit` is reached or no further results are available.
- Logs and returns partial results if an error occurs during fetching.

"""
posts = []
url = "https://bsky.social/xrpc/app.bsky.feed.searchPosts"
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}

total_fetched = 0
posts_limit = params.get("posts_limit")
butterfly_bar = bar_factory("✨", tip="🦋", errors="🔥🧯👩‍🚒")

with alive_bar(posts_limit, bar=butterfly_bar, spinner="waves") as progress:
while True:
try:
response = requests.get(url, headers=headers, params=params, timeout=10)
# print(response)
response.raise_for_status()
data = response.json()

# Check if we have reached our overall posts limit
new_posts = data.get("posts", [])
posts.extend(new_posts)
total_fetched += len(new_posts)

# Update progress bar
progress(len(new_posts))

if posts_limit and total_fetched >= posts_limit:
print(
f"Fetched {total_fetched} posts, total: {total_fetched}/{posts_limit}"
)
return posts[:posts_limit]

# Move to the enxt page if available
next_cursor = data.get("cursor")
if not next_cursor:
print(f"All posts fetched. Total: {total_fetched}")
return posts

params["cursor"] = next_cursor
except requests.exceptions.RequestException as err:
print(f"Error fetching posts: {err}")
print(
"Response:",
response.text if "response" in locals() else "No response",
)
return posts


# Begin Click CLI


Expand Down Expand Up @@ -388,10 +304,10 @@ def search_posts(params: dict, token: str) -> list[dict]:
"--posts_limit",
type=click.IntRange(1, None),
required=False,
default=1000,
default=500,
help=(
"Set the total number of posts to fetch from the API across all paginated responses. This value limits the total data retrieved "
even if multiple API calls are required. If not specified, 1000 posts will be received."
even if multiple API calls are required. If not specified, 500 posts will be received."
),
)
def main(
Expand All @@ -406,7 +322,7 @@ def main(
url: str = "",
tags: tuple = (),
limit: int = 25,
posts_limit: int = 1000,
posts_limit: int = 500,
) -> None:
"""Method that tests if each click param flag is being passed in correctly."""
# pylint: disable=R0913
Expand Down Expand Up @@ -442,7 +358,7 @@ def main(

# Fetch posts
print("Fetching posts...")
raw_posts = search_posts(query_param, access_token)
raw_posts = scraper.search_posts(query_param, access_token)

# Extract post data
print("Extracting post data...")
Expand Down
100 changes: 100 additions & 0 deletions scraper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""
This module contains a function to search for posts using the BlueSky API.
"""

import requests
from alive_progress import alive_bar
from alive_progress.animations.bars import bar_factory


def search_posts(params, token):
    # pylint: disable=E1102
    # pylint: disable=C0301

    """
    Search for posts using the BlueSky API.

    Args:
        params (dict): The query parameters for the API request.
            - q (str, required): The search term for the BlueSky posts.
            - sort (str, optional): The sorting criteria for results.
                Options include "top" for top posts or "latest" for the latest posts.
            - since (str, optional): The start date for posts (ISO 8601 format).
            - until (str, optional): The end date for posts (ISO 8601 format).
            - mentions (str, optional): Mentions to filter posts by.
                - Handles will be resolved to DIDs using the provided token.
            - author (str, optional): The author of the posts (handle or DID).
            - lang (str, optional): The language of the posts.
            - domain (str, optional): A domain URL included in the posts.
            - url (str, optional): A specific URL included in the posts.
            - tags (list, optional): Tags to filter posts by (each tag <= 640 characters).
            - limit (int, optional): The maximum number of posts to retrieve in a single response.
                Defaults to 25.
            - cursor (str, optional): Pagination token for continuing from a previous request.
            - posts_limit (int, optional): The maximum number of posts to retrieve across all
                responses. Defaults to 500. This key is consumed locally and is NOT forwarded
                to the API.
        token (str): A valid BlueSky access token, sent as a Bearer credential.

    Returns:
        list: A list of posts matching the search criteria.

    Raises:
        ValueError: If the "q" parameter or the token is missing.

    Notes:
        - Progress is displayed using a progress bar indicating the number of posts fetched.
        - Handles pagination automatically until `posts_limit` is reached or no further
          results are available.
        - Logs and returns partial results if an error occurs during fetching.
        - The caller's `params` dict is never mutated; pagination state is kept on a copy.
    """
    # Validate input parameters.
    if "q" not in params:
        raise ValueError("Query parameter is required.")
    if not token:
        raise ValueError("Token is required.")

    # Work on a shallow copy so the caller's dict is never mutated, and pop the
    # internal pagination control so it is not sent to the API as a bogus
    # query parameter on every request.
    query = dict(params)
    posts_limit = query.pop("posts_limit", 500)

    posts = []
    url = "https://bsky.social/xrpc/app.bsky.feed.searchPosts"
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }

    total_fetched = 0
    butterfly_bar = bar_factory("✨", tip="🦋", errors="🔥🧯👩‍🚒")

    with alive_bar(posts_limit, bar=butterfly_bar, spinner="waves") as progress:
        while True:
            # Sentinel so the error handler can tell whether a response exists,
            # instead of the fragile `"response" in locals()` introspection.
            response = None
            try:
                response = requests.get(url, headers=headers, params=query, timeout=10)
                response.raise_for_status()
                data = response.json()

                new_posts = data.get("posts", [])
                posts.extend(new_posts)
                total_fetched += len(new_posts)

                # Advance the bar, clamped so it never overshoots its total.
                remaining = posts_limit - (total_fetched - len(new_posts))
                progress(min(len(new_posts), remaining))

                # Check if we have reached our overall posts limit.
                if total_fetched >= posts_limit:
                    print(
                        f"Fetched {total_fetched} posts, total: {total_fetched}/{posts_limit}"
                    )
                    # Truncate only if we exceeded the limit.
                    return posts[:posts_limit]

                # Move to the next page if available. An empty page is also
                # treated as exhaustion to guard against an infinite loop on a
                # server that keeps returning a cursor with no posts.
                next_cursor = data.get("cursor")
                if not next_cursor or not new_posts:
                    print(f"All posts fetched. Total: {total_fetched}")
                    return posts

                query["cursor"] = next_cursor
            except requests.exceptions.RequestException as err:
                # Best-effort: report the error and return what we have so far.
                print(f"Error fetching posts: {err}")
                print(
                    "Response:",
                    response.text if response is not None else "No response",
                )
                return posts
105 changes: 105 additions & 0 deletions tests/scraper_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Testing suite for the scraper module."""

import unittest
import requests
import io
import sys
from unittest.mock import patch, MagicMock
from scraper import search_posts


class TestSearchPosts(unittest.TestCase):
    """Testing the search_posts() method."""

    @patch("scraper.requests.get")
    def test_no_query(self, mock_get: MagicMock) -> None:
        """search_posts raises ValueError when the 'q' parameter is missing."""
        params = {}
        token = "valid_token"

        with self.assertRaises(ValueError) as cm:
            search_posts(params, token)

        # Validation must fail before any network call is attempted.
        mock_get.assert_not_called()
        self.assertIn("query", str(cm.exception).lower())

    @patch("scraper.requests.get")
    def test_no_token(self, mock_get: MagicMock) -> None:
        """search_posts raises ValueError when the token is missing."""
        params = {"q": "test"}
        token = None

        with self.assertRaises(ValueError) as cm:
            search_posts(params, token)

        mock_get.assert_not_called()
        self.assertIn("token", str(cm.exception).lower())

    @patch("scraper.requests.get")
    def test_valid_response(self, mock_get: MagicMock) -> None:
        """A successful single-page response is returned as a list of posts."""
        params = {"q": "test"}
        token = "valid_token"

        mock_response = MagicMock()
        mock_response.status_code = 200
        # One post and no cursor: pagination terminates after the first page.
        mock_response.json.return_value = {
            "posts": [
                {
                    "uri": "at://did:plc:12345/app.bsky.feed.post/abcdef",
                    "cid": "bafyre123...",
                    "author": {
                        "did": "did:plc:12345",
                        "handle": "author_handle",
                        "displayName": "Author Name",
                    },
                    "record": {
                        "text": "Post content",
                        "createdAt": "2023-10-01T00:00:00Z",
                        "$type": "app.bsky.feed.post",
                    },
                    "indexedAt": "2023-10-01T00:00:01Z",
                }
            ],
            "cursor": None,
        }

        mock_get.return_value = mock_response

        result = search_posts(params, token)

        self.assertEqual(len(result), 1)
        self.assertEqual(result[0]["record"]["text"], "Post content")
        self.assertEqual(result[0]["author"]["handle"], "author_handle")
        self.assertEqual(result[0]["record"]["createdAt"], "2023-10-01T00:00:00Z")
        self.assertEqual(
            result[0]["uri"], "at://did:plc:12345/app.bsky.feed.post/abcdef"
        )

    # Simulate a failed API response (e.g., 400: [InvalidRequest, ExpiredToken, InvalidToken, BadQueryString])
    @patch("scraper.requests.get")
    def test_invalid_request(self, mock_get: MagicMock) -> None:
        """A failed HTTP request is handled gracefully and returns partial results."""
        params = {"q": "test"}
        token = "invalid_token"

        mock_response = MagicMock()
        mock_response.status_code = 400
        mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
            "400 Client Error: InvalidToken"
        )
        mock_get.return_value = mock_response

        # Capture stdout; restore it in a finally block so an unexpected
        # exception inside search_posts never leaves stdout redirected and
        # silently swallows the output of every subsequent test.
        captured_output = io.StringIO()
        sys.stdout = captured_output
        try:
            result = search_posts(params, token)
        finally:
            sys.stdout = sys.__stdout__

        self.assertEqual(result, [])
        self.assertIn("400 Client Error:", captured_output.getvalue())

# Allow the suite to be run directly with `python tests/scraper_test.py`.
if __name__ == "__main__":
    unittest.main()