diff --git a/jumpgate/network/drivers/sl/networks.py b/jumpgate/network/drivers/sl/networks.py index c62a05d..ced6587 100644 --- a/jumpgate/network/drivers/sl/networks.py +++ b/jumpgate/network/drivers/sl/networks.py @@ -36,18 +36,26 @@ def on_get(self, req, resp): @param req: Http Request body @param resp: Http Response body """ - tenant_id = req.env['auth']['tenant_id'] client = req.env['sl_client'] + tenant_id = req.env['auth']['tenant_id'] + + if (req.get_param('tenant_id', store={'tenant_id': tenant_id}) != + tenant_id): + resp.status = 401 - _filter = {'networkVlans': {}} - if req.get_param('name'): - _filter['networkVlans']['id'] = { - 'operation': req.get_param('name')} + if not req.get_param_as_bool('shared', store={'shared': False}): + _filter = {'networkVlans': {}} + if req.get_param('name'): + _filter['networkVlans']['id'] = { + 'operation': req.get_param('name')} - vlans = client['Account'].getNetworkVlans(mask=NETWORK_MASK, - filter=_filter) - network = [format_network(vlan, tenant_id) - for vlan in sorted(vlans, key=operator.itemgetter('id'))] + vlans = client['Account'].getNetworkVlans(mask=NETWORK_MASK, + filter=_filter) + network = [format_network(vlan, tenant_id) + for vlan in sorted(vlans, + key=operator.itemgetter('id'))] + else: + network = [] resp.body = {'networks': network} resp.status = 200 diff --git a/tests/jumpgate-tests/network/test_network.py b/tests/jumpgate-tests/network/test_network.py index c162524..8aba7d4 100644 --- a/tests/jumpgate-tests/network/test_network.py +++ b/tests/jumpgate-tests/network/test_network.py @@ -114,6 +114,18 @@ def test_on_get_response_networksv2_show(self): self.assertEqual(resp.status, 200) self.check_response_body(resp.body['networks'][0]) + def test_on_get_response_networksv2_shared_networks_show(self): + """Test show function in NetworksV2 for shared networks""" + + _, env = get_client_env(query_string='name=123321&shared=True') + + req = falcon.Request(env) + resp = falcon.Response() + + networks.NetworksV2().on_get(req, resp) + self.assertEqual(resp.status, 200) + self.assertEqual(len(resp.body['networks']), 0) + def test_on_get_response_networksv2_show_no_match(self): """Test show function in NetworksV2 with no matching ID""" @@ -144,3 +156,15 @@ def test_on_get_response_networksv2_list(self): networks.NetworksV2().on_get(req, resp) self.assertEqual(resp.status, 200) self.check_response_body(resp.body['networks'][0]) + + def test_on_get_response_networksv2_shared_networks_list(self): + """Test list function for shared networks""" + + _, env = get_client_env(query_string='name=123321&shared=True') + + req = falcon.Request(env) + resp = falcon.Response() + + networks.NetworksV2().on_get(req, resp) + self.assertEqual(resp.status, 200) + self.assertEqual(len(resp.body['networks']), 0)