Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 13 additions & 21 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,35 +26,27 @@
]


class MockCore(trakt.core.Core):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
class MockCore:
def __init__(self):
self.mock_data = {}
for mock_file in MOCK_DATA_FILES:
with open(mock_file, encoding='utf-8') as f:
self.mock_data.update(json.load(f))

def _handle_request(self, method, url, data=None):
uri = url[len(trakt.core.BASE_URL):]
def request(self, method, uri, data=None):
if uri.startswith('/'):
uri = uri[1:]
# use a deepcopy of the mocked data to ensure clean responses on every
# request. this prevents rewrites to JSON responses from persisting
method_responses = deepcopy(self.mock_data).get(uri, {})
result = method_responses.get(method.upper())
if result is None:
print(f"Missing mock for {method.upper()} {trakt.core.BASE_URL}{uri}")

return result


"""Override utility functions from trakt.core to use an underlying MockCore
instance
"""
trakt.core.CORE = MockCore()
trakt.core.get = trakt.core.CORE.get
trakt.core.post = trakt.core.CORE.post
trakt.core.delete = trakt.core.CORE.delete
trakt.core.put = trakt.core.CORE.put
method_responses = self.mock_data.get(uri, {})
response = method_responses.get(method.upper())
if response is None:
print(f"No mock for {uri}")
return deepcopy(response)


trakt.core.CLIENT_ID = 'FOO'
trakt.core.CLIENT_SECRET = 'BAR'

# Override request function with MockCore instance
trakt.core.api().request = MockCore().request
2 changes: 1 addition & 1 deletion trakt/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.4.0.dev0"
__version__ = "4.0.0.dev0"
181 changes: 181 additions & 0 deletions trakt/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import json
import logging
from datetime import datetime, timedelta, timezone
from functools import lru_cache
from json import JSONDecodeError

from requests import Session
from requests.auth import AuthBase

from trakt import errors
from trakt.config import AuthConfig
from trakt.core import TIMEOUT
from trakt.errors import BadResponseException, OAuthException

__author__ = 'Elan Ruusamäe'


class HttpClient:
"""Class for abstracting HTTP requests
"""

#: Default request HEADERS
headers = {'Content-Type': 'application/json', 'trakt-api-version': '2'}

def __init__(self, base_url: str, session: Session, timeout=None):
self.base_url = base_url
self.session = session
self.auth = None
self.timeout = timeout or TIMEOUT
self.logger = logging.getLogger('trakt.http_client')

def get(self, url: str):
return self.request('get', url)

def delete(self, url: str):
self.request('delete', url)

def post(self, url: str, data):
return self.request('post', url, data=data)

def put(self, url: str, data):
return self.request('put', url, data=data)

def set_auth(self, auth):
self.auth = auth

def request(self, method, url, data=None):
"""Handle actually talking out to the trakt API, logging out debug
information, raising any relevant `TraktException` Exception types,
and extracting and returning JSON data

:param method: The HTTP method we're executing on. Will be one of
post, put, delete, get
:param url: The fully qualified url to send our request to
:param data: Optional data payload to send to the API
:return: The decoded JSON response from the Trakt API
:raises TraktException: If any non-200 return code is encountered
"""

url = self.base_url + url
self.logger.debug('REQUEST [%s] (%s)', method, url)
if method == 'get': # GETs need to pass data as params, not body
response = self.session.request(method, url, headers=self.headers, auth=self.auth, timeout=self.timeout, params=data)
else:
response = self.session.request(method, url, headers=self.headers, auth=self.auth, timeout=self.timeout, data=json.dumps(data))
self.logger.debug('RESPONSE [%s] (%s): %s', method, url, str(response))
if response.status_code == 204: # HTTP no content
return None
self.raise_if_needed(response)

return self.decode_response(response)

@staticmethod
def decode_response(response):
try:
return json.loads(response.content.decode('UTF-8', 'ignore'))
except JSONDecodeError as e:
raise BadResponseException(f"Unable to parse JSON: {e}")

def raise_if_needed(self, response):
if response.status_code in self.error_map:
raise self.error_map[response.status_code](response)

@property
@lru_cache(maxsize=None)
def error_map(self):
"""Map HTTP response codes to exception types
"""

# Get all of our exceptions except the base exception
errs = [getattr(errors, att) for att in errors.__all__
if att != 'TraktException']

return {err.http_code: err for err in errs}


class TokenAuth(AuthBase):
"""Attaches Trakt.tv token Authentication to the given Request object."""

#: The OAuth2 Redirect URI for your OAuth Application
REDIRECT_URI: str = 'urn:ietf:wg:oauth:2.0:oob'

def __init__(self, client: HttpClient, config: AuthConfig):
super().__init__()
self.config = config
self.client = client
# OAuth token validity checked
self.OAUTH_TOKEN_VALID = None
self.logger = logging.getLogger('trakt.api.token_auth')

def __call__(self, r):
# Skip oauth requests
if r.path_url.startswith('/oauth/'):
return r

[client_id, client_token] = self.get_token()

r.headers.update({
'trakt-api-key': client_id,
'Authorization': f'Bearer {client_token}',
})
return r

def get_token(self):
"""Return client_id, client_token pair needed for Trakt.tv authentication
"""

self.config.load()
# Check token validity and refresh token if needed
if not self.OAUTH_TOKEN_VALID and self.config.have_refresh_token():
self.validate_token()

return [
self.config.CLIENT_ID,
self.config.OAUTH_TOKEN,
]

def validate_token(self):
"""Check if current OAuth token has not expired"""

current = datetime.now(tz=timezone.utc)
expires_at = datetime.fromtimestamp(self.config.OAUTH_EXPIRES_AT, tz=timezone.utc)
if expires_at - current > timedelta(days=2):
self.OAUTH_TOKEN_VALID = True
else:
self.refresh_token()

def refresh_token(self):
"""Request Trakt API for a new valid OAuth token using refresh_token"""

self.logger.info("OAuth token has expired, refreshing now...")
data = {
'client_id': self.config.CLIENT_ID,
'client_secret': self.config.CLIENT_SECRET,
'refresh_token': self.config.OAUTH_REFRESH,
'redirect_uri': self.REDIRECT_URI,
'grant_type': 'refresh_token'
}

try:
response = self.client.post('/oauth/token', data)
except OAuthException:
self.logger.debug(
"Rejected - Unable to refresh expired OAuth token, "
"refresh_token is invalid"
)
return

self.config.update(
OAUTH_TOKEN=response.get("access_token"),
OAUTH_REFRESH=response.get("refresh_token"),
OAUTH_EXPIRES_AT=response.get("created_at") + response.get("expires_in"),
)
self.OAUTH_TOKEN_VALID = True

self.logger.info(
"OAuth token successfully refreshed, valid until {}".format(
datetime.fromtimestamp(self.config.OAUTH_EXPIRES_AT, tz=timezone.utc)
)
)
self.config.store()
Comment on lines +138 to +181
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Improve exception handling on OAuth refresh failures

If the refresh token is invalid or expired, the refresh_token method logs a debug message but continues without raising an error. This could cause the calling code to proceed incorrectly. Consider raising a specific exception to alert the caller that the refresh sequence failed.

75 changes: 75 additions & 0 deletions trakt/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# -*- coding: utf-8 -*-
"""Authentication methods"""

__author__ = 'Jon Nappi, Elan Ruusamäe'

from trakt import DEVICE_AUTH, OAUTH_AUTH, PIN_AUTH, api
from trakt import config as config_factory
from trakt.config import AuthConfig


def pin_auth(*args, config, **kwargs):
from trakt.auth.pin import PinAuthAdapter

return PinAuthAdapter(*args, client=api(), config=config, **kwargs).authenticate()


def oauth_auth(*args, config, **kwargs):
from trakt.auth.oauth import OAuthAdapter

return OAuthAdapter(*args, client=api(), config=config, **kwargs).authenticate()


def device_auth(config):
from trakt.auth.device import DeviceAuthAdapter

return DeviceAuthAdapter(client=api(), config=config).authenticate()


def get_client_info(app_id: bool, config: AuthConfig):
"""Helper function to poll the user for Client ID and Client Secret
strings

:return: A 2-tuple of client_id, client_secret
"""
print('If you do not have a client ID and secret. Please visit the '
'following url to create them.')
print('https://trakt.tv/oauth/applications')
client_id = input('Please enter your client id: ')
client_secret = input('Please enter your client secret: ')
if app_id:
msg = f'Please enter your application ID ({config.APPLICATION_ID}): '
user_input = input(msg)
if user_input:
config.APPLICATION_ID = user_input
return client_id, client_secret


def init_auth(method: str, *args, client_id=None, client_secret=None, store=False, **kwargs):
"""Run the auth function specified by *AUTH_METHOD*

:param store: Boolean flag used to determine if your trakt api auth data
should be stored locally on the system. Default is :const:`False` for
the security conscious
"""

methods = {
PIN_AUTH: pin_auth,
OAUTH_AUTH: oauth_auth,
DEVICE_AUTH: device_auth,
}

config = config_factory()
adapter = methods.get(method, PIN_AUTH)

"""
Update client_id, client_secret from input or ask them interactively
"""
if client_id is None and client_secret is None:
client_id, client_secret = get_client_info(adapter.NEEDS_APPLICATION_ID, config)
config.CLIENT_ID, config.CLIENT_SECRET = client_id, client_secret

adapter(*args, config=config, **kwargs)

if store:
config.store()
6 changes: 6 additions & 0 deletions trakt/auth/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class BaseAdapter:
#: The OAuth2 Redirect URI for your OAuth Application
REDIRECT_URI: str = 'urn:ietf:wg:oauth:2.0:oob'

#: True if the Adapter needs APPLICATION_ID
NEEDS_APPLICATION_ID = False
Loading