diff --git a/README.rst b/README.rst index 2b04529..a494a28 100644 --- a/README.rst +++ b/README.rst @@ -167,7 +167,9 @@ You can either set this in your .ini-file, or pass/override them directly to the | expiration | jwt.expiration | | Number of seconds (or a datetime.timedelta | | | | | instance) before a token expires. | +--------------+-----------------+---------------+--------------------------------------------+ -| audience | jwt.audience | | Proposed audience for the token | +| audience | jwt.audience | | Proposed audience for the token. Multiple | +| | | | audiences accepted via comma separated | +| | | | string. e.g. 'example.org,example2.org' | +--------------+-----------------+---------------+--------------------------------------------+ | leeway | jwt.leeway | 0 | Number of seconds a token is allowed to be | | | | | expired before it is rejected. | diff --git a/src/pyramid_jwt/policy.py b/src/pyramid_jwt/policy.py index 145e8d4..63b3268 100644 --- a/src/pyramid_jwt/policy.py +++ b/src/pyramid_jwt/policy.py @@ -48,7 +48,7 @@ def create_token(self, principal, expiration=None, audience=None, **claims): expiration = datetime.timedelta(seconds=expiration) payload['exp'] = iat + expiration if audience: - payload['aud'] = audience + payload['aud'] = self._aud_string_to_list(audience) token = jwt.encode(payload, self.private_key, algorithm=self.algorithm, json_encoder=self.json_encoder) if not isinstance(token, str): # Python3 unicode madness token = token.decode('ascii') @@ -71,13 +71,12 @@ def get_claims(self, request): return {} try: claims = jwt.decode(token, self.public_key, algorithms=[self.algorithm], - leeway=self.leeway, audience=self.audience) + leeway=self.leeway, audience=self._aud_string_to_list(self.audience)) return claims except jwt.InvalidTokenError as e: log.warning('Invalid JWT token from %s: %s', request.remote_addr, e) return {} - def unauthenticated_userid(self, request): return request.jwt_claims.get('sub') @@ -94,3 +93,15 @@ def forget(self, request): 'has no effect.', stacklevel=3) return [] + + def _aud_string_to_list(self, audience): + """ + Splits the audience variable into a + list to handle multiple audiences. + :param audience: Comma separated list of audiences + :return: List of one or more audiences + """ + if audience is None: + return None + else: + return audience.split(',') diff --git a/tests/test_policy.py b/tests/test_policy.py index 5b8c8c2..234fa07 100644 --- a/tests/test_policy.py +++ b/tests/test_policy.py @@ -39,7 +39,27 @@ def test_audience_valid(): request = Request.blank('/') request.authorization = ('JWT', token) jwt_claims = policy.get_claims(request) - assert jwt_claims['aud'] == 'example.org' + assert jwt_claims['aud'] == ['example.org'] + + +def test_multiple_audience_valid(): + policy = JWTAuthenticationPolicy('secret', audience='example.org,example2.org') + token = policy.create_token(15, name=u'Jöhn', admin=True, audience='example.org,example2.org') + request = Request.blank('/') + request.authorization = ('JWT', token) + jwt_claims = policy.get_claims(request) + for key in ('example.org', 'example2.org'): + assert key in jwt_claims['aud'] + + +def test_multiple_to_one_audience_valid(): + policy = JWTAuthenticationPolicy('secret', audience='example.org,example2.org') + token = policy.create_token(15, name=u'Jöhn', admin=True, audience='example.org') + request = Request.blank('/') + request.authorization = ('JWT', token) + jwt_claims = policy.get_claims(request) + assert jwt_claims['aud'] == ['example.org'] + def test_audience_invalid(): policy = JWTAuthenticationPolicy('secret', audience='example.org') @@ -50,6 +70,24 @@ def test_audience_invalid(): assert jwt_claims == {} +def test_multiple_audience_invalid(): + policy = JWTAuthenticationPolicy('secret', audience='example.org,example2.org') + token = policy.create_token(15, name=u'Jöhn', admin=True, audience='example.com,example2.com') + request = Request.blank('/') + request.authorization = ('JWT', token) + jwt_claims = policy.get_claims(request) + assert jwt_claims == {} + + +def test_multiple_to_one_audience_invalid(): + policy = JWTAuthenticationPolicy('secret', audience='example.org,example2.org') + token = policy.create_token(15, name=u'Jöhn', admin=True, audience='example.com') + request = Request.blank('/') + request.authorization = ('JWT', token) + jwt_claims = policy.get_claims(request) + assert jwt_claims == {} + + def test_algorithm_unsupported(): policy = JWTAuthenticationPolicy('secret', algorithm='SHA1') with pytest.raises(NotImplementedError):