diff --git a/server/http_test.go b/server/http_test.go index e17b71f2..195e644f 100644 --- a/server/http_test.go +++ b/server/http_test.go @@ -636,9 +636,22 @@ func TestEditSubscription(t *testing.T) { expectedStatusCode: http.StatusBadRequest, }, "No Permissions": { - subscription: `{"id": "aaaaaaaaaaaaaaaaaaaaaaaaab", "channel_id": "aaaaaaaaaaaaaaaaaaaaaaaaac", "filters": {"events": ["jira:issue_created"], "project": ["otherproject"]}}`, + subscription: `{"instance_id": "https://jiraurl1.com", "id": "aaaaaaaaaaaaaaaaaaaaaaaaab", "channel_id": "aaaaaaaaaaaaaaaaaaaaaaaaac", "filters": {"events": ["jira:issue_created"], "project": ["otherproject"]}}`, expectedStatusCode: http.StatusForbidden, apiCalls: func(api *plugintest.API) { + existing := withExistingChannelSubscriptions([]ChannelSubscription{ + { + ID: "aaaaaaaaaaaaaaaaaaaaaaaaab", + ChannelID: "aaaaaaaaaaaaaaaaaaaaaaaaac", + Filters: SubscriptionFilters{ + Events: NewStringSet("jira:issue_created"), + Projects: NewStringSet("myproject"), + IssueTypes: NewStringSet("10001"), + }, + }, + }) + existingBytes, _ := json.Marshal(existing) + api.On("KVGet", testSubKey).Return(existingBytes, nil) api.On("HasPermissionTo", mock.AnythingOfType("string"), mock.Anything).Return(false) }, }, @@ -792,6 +805,20 @@ func TestEditSubscription(t *testing.T) { subscription: `{"instance_id": "https://jiraurl1.com", "name": "some name", "id": "subaaaaaaaaaabbbbbbbbbbccc", "channel_id": "channelaaaaaaaaaabbbbbbbbb", "filters": {"events": ["jira:issue_created"], "projects": ["myproject"], "issue_types": ["10001"]}}`, expectedStatusCode: http.StatusBadRequest, apiCalls: func(api *plugintest.API) { + existing := withExistingChannelSubscriptions([]ChannelSubscription{ + { + ID: "subaaaaaaaaaabbbbbbbbbbccc", + ChannelID: "channelaaaaaaaaaabbbbbbbbb", + Filters: SubscriptionFilters{ + Events: NewStringSet("jira:issue_created"), + Projects: NewStringSet("myproject"), + IssueTypes: NewStringSet("10001"), + }, + }, + }) + existingBytes, _ := json.Marshal(existing) + api.On("KVGet", testSubKey).Return(existingBytes, nil) + api.On("HasPermissionTo", mock.AnythingOfType("string"), mock.Anything).Return(true) api.On("GetChannel", "channelaaaaaaaaaabbbbbbbbb").Return(&model.Channel{ Id: "channelaaaaaaaaaabbbbbbbbb", Type: model.ChannelTypeDirect, @@ -802,12 +829,33 @@ func TestEditSubscription(t *testing.T) { subscription: `{"instance_id": "https://jiraurl1.com", "name": "some name", "id": "subaaaaaaaaaabbbbbbbbbbccc", "channel_id": "channelaaaaaaaaaabbbbbbbbb", "filters": {"events": ["jira:issue_created"], "projects": ["myproject"], "issue_types": ["10001"]}}`, expectedStatusCode: http.StatusBadRequest, apiCalls: func(api *plugintest.API) { + existing := withExistingChannelSubscriptions([]ChannelSubscription{ + { + ID: "subaaaaaaaaaabbbbbbbbbbccc", + ChannelID: "channelaaaaaaaaaabbbbbbbbb", + Filters: SubscriptionFilters{ + Events: NewStringSet("jira:issue_created"), + Projects: NewStringSet("myproject"), + IssueTypes: NewStringSet("10001"), + }, + }, + }) + existingBytes, _ := json.Marshal(existing) + api.On("KVGet", testSubKey).Return(existingBytes, nil) + api.On("HasPermissionTo", mock.AnythingOfType("string"), mock.Anything).Return(true) api.On("GetChannel", "channelaaaaaaaaaabbbbbbbbb").Return(&model.Channel{ Id: "channelaaaaaaaaaabbbbbbbbb", Type: model.ChannelTypeGroup, }, nil) }, }, + "Reject editing non-existent subscription": { + subscription: `{"instance_id": "https://jiraurl1.com", "name": "hijacked", "id": "nonexistentsubidaaaaaaaaaa", "channel_id": "attackerchannelaaabbbbbccc", "filters": {"events": ["jira:issue_created"], "projects": ["myproject"], "issue_types": ["10001"]}}`, + expectedStatusCode: http.StatusBadRequest, + apiCalls: func(api *plugintest.API) { + api.On("KVGet", testSubKey).Return(nil, nil) + }, + }, } { t.Run(name, func(t *testing.T) { api := &plugintest.API{} diff --git a/server/subscribe.go b/server/subscribe.go index ac7a3add..0d485945 100755 --- a/server/subscribe.go +++ b/server/subscribe.go @@ -1233,26 +1233,46 @@ func (p *Plugin) httpChannelEditSubscription(w http.ResponseWriter, r *http.Requ fmt.Errorf("channel subscription invalid")) } - channel, appErr := p.client.Channel.Get(subscription.ChannelID) - if appErr != nil { - return respondErr(w, http.StatusInternalServerError, - errors.Wrap(appErr, "failed to get channel")) - } - if channel.Type == model.ChannelTypeDirect || channel.Type == model.ChannelTypeGroup { + existingSub, err := p.getChannelSubscription(subscription.InstanceID, subscription.ID) + if err != nil { return respondErr(w, http.StatusBadRequest, - errors.New("subscriptions are not allowed in direct message or group message channels")) + errors.Wrap(err, "failed to find existing subscription")) } - err = p.hasPermissionToManageSubscription(subscription.InstanceID, mattermostUserID, subscription.ChannelID) + err = p.hasPermissionToManageSubscription(subscription.InstanceID, mattermostUserID, existingSub.ChannelID) if err != nil { return respondErr(w, http.StatusForbidden, - errors.Wrap(err, "you don't have permission to manage subscriptions")) + errors.Wrap(err, "you don't have permission to manage subscriptions in the original channel")) } - _, err = p.client.Channel.GetMember(subscription.ChannelID, mattermostUserID) + _, err = p.client.Channel.GetMember(existingSub.ChannelID, mattermostUserID) if err != nil { return respondErr(w, http.StatusForbidden, - errors.New("not a member of the channel specified")) + errors.New("not a member of the channel that owns this subscription")) + } + + if subscription.ChannelID != existingSub.ChannelID { + err = p.hasPermissionToManageSubscription(subscription.InstanceID, mattermostUserID, subscription.ChannelID) + if err != nil { + return respondErr(w, http.StatusForbidden, + errors.Wrap(err, "you don't have permission to manage subscriptions in the target channel")) + } + + _, err = p.client.Channel.GetMember(subscription.ChannelID, mattermostUserID) + if err != nil { + return respondErr(w, http.StatusForbidden, + errors.New("not a member of the target channel")) + } + } + + channel, appErr := p.client.Channel.Get(subscription.ChannelID) + if appErr != nil { + return respondErr(w, http.StatusInternalServerError, + errors.Wrap(appErr, "failed to get channel")) + } + if channel.Type == model.ChannelTypeDirect || channel.Type == model.ChannelTypeGroup { + return respondErr(w, http.StatusBadRequest, + errors.New("subscriptions are not allowed in direct message or group message channels")) } client, _, connection, err := p.getClient(subscription.InstanceID, types.ID(mattermostUserID))