Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,9 @@ You can either set this in your .ini-file, or pass/override them directly to the
| expiration | jwt.expiration | | Number of seconds (or a datetime.timedelta |
| | | | instance) before a token expires. |
+--------------+-----------------+---------------+--------------------------------------------+
| audience | jwt.audience | | Proposed audience for the token |
| audience | jwt.audience | | Proposed audience for the token. Multiple |
| | | | audiences accepted via comma separated |
| | | | string. e.g. 'example.org,example2.org' |
+--------------+-----------------+---------------+--------------------------------------------+
| leeway | jwt.leeway | 0 | Number of seconds a token is allowed to be |
| | | | expired before it is rejected. |
Expand Down
17 changes: 14 additions & 3 deletions src/pyramid_jwt/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def create_token(self, principal, expiration=None, audience=None, **claims):
expiration = datetime.timedelta(seconds=expiration)
payload['exp'] = iat + expiration
if audience:
payload['aud'] = audience
payload['aud'] = self._aud_string_to_list(audience)
token = jwt.encode(payload, self.private_key, algorithm=self.algorithm, json_encoder=self.json_encoder)
if not isinstance(token, str): # Python3 unicode madness
token = token.decode('ascii')
Expand All @@ -71,13 +71,12 @@ def get_claims(self, request):
return {}
try:
claims = jwt.decode(token, self.public_key, algorithms=[self.algorithm],
leeway=self.leeway, audience=self.audience)
leeway=self.leeway, audience=self._aud_string_to_list(self.audience))
return claims
except jwt.InvalidTokenError as e:
log.warning('Invalid JWT token from %s: %s', request.remote_addr, e)
return {}


def unauthenticated_userid(self, request):
return request.jwt_claims.get('sub')

Expand All @@ -94,3 +93,15 @@ def forget(self, request):
'has no effect.',
stacklevel=3)
return []

def _aud_string_to_list(self, audience):
"""
Splits the audience variable into a
list to handle multiple audiences.
:param audience: Comma separated list of audiences
:return: List of one or more audiences
"""
if audience is None:
return None
else:
return audience.split(',')
Comment on lines +97 to +107
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move that method out of the class into a separate function? Since it doesn't need to use anything from self there is no need for it to be a method.

I would also suggest splitting on something like re.split(r'\s*,\s*, audience)` to make this a bit friendlier.

40 changes: 39 additions & 1 deletion tests/test_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,27 @@ def test_audience_valid():
request = Request.blank('/')
request.authorization = ('JWT', token)
jwt_claims = policy.get_claims(request)
assert jwt_claims['aud'] == 'example.org'
assert jwt_claims['aud'] == ['example.org']


def test_multiple_audience_valid():
policy = JWTAuthenticationPolicy('secret', audience='example.org,example2.org')
token = policy.create_token(15, name=u'Jöhn', admin=True, audience='example.org,example2.org')
request = Request.blank('/')
request.authorization = ('JWT', token)
jwt_claims = policy.get_claims(request)
for key in ('example.org', 'example2.org'):
assert key in jwt_claims['aud']


def test_multiple_to_one_audience_valid():
policy = JWTAuthenticationPolicy('secret', audience='example.org,example2.org')
token = policy.create_token(15, name=u'Jöhn', admin=True, audience='example.org')
request = Request.blank('/')
request.authorization = ('JWT', token)
jwt_claims = policy.get_claims(request)
assert jwt_claims['aud'] == ['example.org']


def test_audience_invalid():
policy = JWTAuthenticationPolicy('secret', audience='example.org')
Expand All @@ -50,6 +70,24 @@ def test_audience_invalid():
assert jwt_claims == {}


def test_multiple_audience_invalid():
policy = JWTAuthenticationPolicy('secret', audience='example.org,example2.org')
token = policy.create_token(15, name=u'Jöhn', admin=True, audience='example.com,example2.com')
request = Request.blank('/')
request.authorization = ('JWT', token)
jwt_claims = policy.get_claims(request)
assert jwt_claims == {}


def test_multiple_to_one_audience_invalid():
policy = JWTAuthenticationPolicy('secret', audience='example.org,example2.org')
token = policy.create_token(15, name=u'Jöhn', admin=True, audience='example.com')
request = Request.blank('/')
request.authorization = ('JWT', token)
jwt_claims = policy.get_claims(request)
assert jwt_claims == {}


def test_algorithm_unsupported():
policy = JWTAuthenticationPolicy('secret', algorithm='SHA1')
with pytest.raises(NotImplementedError):
Expand Down