diff --git a/server/http_test.go b/server/http_test.go index 43d788a0..e17b71f2 100644 --- a/server/http_test.go +++ b/server/http_test.go @@ -425,6 +425,26 @@ func TestSubscribe(t *testing.T) { }, }), t), }, + "Reject DM channel subscription": { + subscription: `{"instance_id": "https://jiraurl1.com", "name": "some name", "channel_id": "aaaaaaaaaaaaaaaaaaaaaaaaab", "filters": {"events": ["jira:issue_created"], "projects": ["myproject"], "issue_types": ["10001"]}}`, + expectedStatusCode: http.StatusBadRequest, + apiCalls: func(api *plugintest.API) { + api.On("GetChannel", "aaaaaaaaaaaaaaaaaaaaaaaaab").Return(&model.Channel{ + Id: "aaaaaaaaaaaaaaaaaaaaaaaaab", + Type: model.ChannelTypeDirect, + }, nil) + }, + }, + "Reject GM channel subscription": { + subscription: `{"instance_id": "https://jiraurl1.com", "name": "some name", "channel_id": "aaaaaaaaaaaaaaaaaaaaaaaaab", "filters": {"events": ["jira:issue_created"], "projects": ["myproject"], "issue_types": ["10001"]}}`, + expectedStatusCode: http.StatusBadRequest, + apiCalls: func(api *plugintest.API) { + api.On("GetChannel", "aaaaaaaaaaaaaaaaaaaaaaaaab").Return(&model.Channel{ + Id: "aaaaaaaaaaaaaaaaaaaaaaaaab", + Type: model.ChannelTypeGroup, + }, nil) + }, + }, } { t.Run(name, func(t *testing.T) { api := &plugintest.API{} @@ -442,6 +462,13 @@ func TestSubscribe(t *testing.T) { tc.apiCalls(api) } + if !api.IsMethodCallable(t, "GetChannel", mock.AnythingOfType("string")) { + api.On("GetChannel", mock.AnythingOfType("string")).Return(&model.Channel{ + Id: "aaaaaaaaaaaaaaaaaaaaaaaaab", + Type: model.ChannelTypeOpen, + }, nil) + } + p.updateConfig(func(conf *config) { conf.Secret = someSecret }) @@ -761,6 +788,26 @@ func TestEditSubscription(t *testing.T) { }, }), t), }, + "Reject editing DM channel subscription": { + subscription: `{"instance_id": "https://jiraurl1.com", "name": "some name", "id": "subaaaaaaaaaabbbbbbbbbbccc", "channel_id": "channelaaaaaaaaaabbbbbbbbb", "filters": {"events": ["jira:issue_created"], "projects": ["myproject"], "issue_types": ["10001"]}}`, + expectedStatusCode: http.StatusBadRequest, + apiCalls: func(api *plugintest.API) { + api.On("GetChannel", "channelaaaaaaaaaabbbbbbbbb").Return(&model.Channel{ + Id: "channelaaaaaaaaaabbbbbbbbb", + Type: model.ChannelTypeDirect, + }, nil) + }, + }, + "Reject editing GM channel subscription": { + subscription: `{"instance_id": "https://jiraurl1.com", "name": "some name", "id": "subaaaaaaaaaabbbbbbbbbbccc", "channel_id": "channelaaaaaaaaaabbbbbbbbb", "filters": {"events": ["jira:issue_created"], "projects": ["myproject"], "issue_types": ["10001"]}}`, + expectedStatusCode: http.StatusBadRequest, + apiCalls: func(api *plugintest.API) { + api.On("GetChannel", "channelaaaaaaaaaabbbbbbbbb").Return(&model.Channel{ + Id: "channelaaaaaaaaaabbbbbbbbb", + Type: model.ChannelTypeGroup, + }, nil) + }, + }, } { t.Run(name, func(t *testing.T) { api := &plugintest.API{} @@ -778,6 +825,13 @@ func TestEditSubscription(t *testing.T) { tc.apiCalls(api) } + if !api.IsMethodCallable(t, "GetChannel", mock.AnythingOfType("string")) { + api.On("GetChannel", mock.AnythingOfType("string")).Return(&model.Channel{ + Id: "channelaaaaaaaaaabbbbbbbbb", + Type: model.ChannelTypeOpen, + }, nil) + } + p.updateConfig(func(conf *config) { conf.Secret = someSecret }) diff --git a/server/subscribe.go b/server/subscribe.go index 9390439b..ac7a3add 100755 --- a/server/subscribe.go +++ b/server/subscribe.go @@ -444,6 +444,58 @@ func (p *Plugin) removeChannelSubscription(instanceID types.ID, subscriptionID s }) } +func (p *Plugin) removeSubscriptionsForChannel(instanceID types.ID, channelID string) error { + subs, err := p.getSubscriptions(instanceID) + if err != nil { + return err + } + + if subs.Channel.IDByChannelID[channelID].Len() == 0 { + return nil + } + + subKey := keyWithInstanceID(instanceID, JiraSubscriptionsKey) + return p.client.KV.SetAtomicWithRetries(subKey, func(initialBytes []byte) (interface{}, error) { + subs, err := SubscriptionsFromJSON(initialBytes, instanceID) + if err != nil { + return nil, err + } + + subIDs := subs.Channel.IDByChannelID[channelID] + for _, subID := range subIDs.Elems() { + if sub, ok := subs.Channel.ByID[subID]; ok { + subs.Channel.remove(&sub) + } + } + + modifiedBytes, marshalErr := json.Marshal(&subs) + if marshalErr != nil { + return nil, marshalErr + } + + return modifiedBytes, nil + }) +} + +func (p *Plugin) cleanupDMSubscriptionsOnDisconnect(instanceID types.ID, mattermostUserID string) { + conf := p.getConfig() + dmChannel, err := p.client.Channel.GetDirect(mattermostUserID, conf.botUserID) + if err != nil { + p.client.Log.Warn("Failed to get DM channel for subscription cleanup on disconnect", + "mattermostUserID", mattermostUserID, + "instanceID", string(instanceID), + "error", err.Error()) + return + } + + if err := p.removeSubscriptionsForChannel(instanceID, dmChannel.Id); err != nil { + p.client.Log.Warn("Failed to clean up DM subscriptions on disconnect", + "mattermostUserID", mattermostUserID, + "instanceID", string(instanceID), + "error", err.Error()) + } +} + func (p *Plugin) addChannelSubscription(instanceID types.ID, newSubscription *ChannelSubscription, client Client) error { subKey := keyWithInstanceID(instanceID, JiraSubscriptionsKey) return p.client.KV.SetAtomicWithRetries(subKey, func(initialBytes []byte) (interface{}, error) { @@ -1114,6 +1166,16 @@ func (p *Plugin) httpChannelCreateSubscription(w http.ResponseWriter, r *http.Re errors.New("not a member of the channel specified")) } + channel, appErr := p.client.Channel.Get(subscription.ChannelID) + if appErr != nil { + return respondErr(w, http.StatusInternalServerError, + errors.Wrap(appErr, "failed to get channel")) + } + if channel.Type == model.ChannelTypeDirect || channel.Type == model.ChannelTypeGroup { + return respondErr(w, http.StatusBadRequest, + errors.New("subscriptions are not allowed in direct message or group message channels")) + } + err = p.hasPermissionToManageSubscription(subscription.InstanceID, mattermostUserID, subscription.ChannelID) if err != nil { return respondErr(w, http.StatusForbidden, @@ -1171,6 +1233,16 @@ func (p *Plugin) httpChannelEditSubscription(w http.ResponseWriter, r *http.Requ fmt.Errorf("channel subscription invalid")) } + channel, appErr := p.client.Channel.Get(subscription.ChannelID) + if appErr != nil { + return respondErr(w, http.StatusInternalServerError, + errors.Wrap(appErr, "failed to get channel")) + } + if channel.Type == model.ChannelTypeDirect || channel.Type == model.ChannelTypeGroup { + return respondErr(w, http.StatusBadRequest, + errors.New("subscriptions are not allowed in direct message or group message channels")) + } + err = p.hasPermissionToManageSubscription(subscription.InstanceID, mattermostUserID, subscription.ChannelID) if err != nil { return respondErr(w, http.StatusForbidden, diff --git a/server/subscribe_test.go b/server/subscribe_test.go index d7b6b13e..563045d9 100644 --- a/server/subscribe_test.go +++ b/server/subscribe_test.go @@ -1678,3 +1678,215 @@ func TestGetChannelsSubscribed(t *testing.T) { }) } } + +func TestRemoveSubscriptionsForChannel(t *testing.T) { + dmChannelID := "dmchannelaaaaaaaaaaaaaaaa" + otherChannelID := "otherchannelbbbbbbbbbbbbbb" + + for name, tc := range map[string]struct { + existingSubs []ChannelSubscription + expectedRemainingByID map[string]bool + }{ + "no subscriptions to remove": { + existingSubs: []ChannelSubscription{}, + expectedRemainingByID: nil, + }, + "remove single DM subscription": { + existingSubs: []ChannelSubscription{ + { + ID: "sub1______________________", + ChannelID: dmChannelID, + Filters: SubscriptionFilters{ + Events: NewStringSet("jira:issue_created"), + Projects: NewStringSet("myproject"), + IssueTypes: NewStringSet("10001"), + }, + }, + }, + expectedRemainingByID: map[string]bool{}, + }, + "remove DM subscription but keep other channel subscriptions": { + existingSubs: []ChannelSubscription{ + { + ID: "sub1______________________", + ChannelID: dmChannelID, + Filters: SubscriptionFilters{ + Events: NewStringSet("jira:issue_created"), + Projects: NewStringSet("myproject"), + IssueTypes: NewStringSet("10001"), + }, + }, + { + ID: "sub2______________________", + ChannelID: otherChannelID, + Filters: SubscriptionFilters{ + Events: NewStringSet("jira:issue_created"), + Projects: NewStringSet("myproject"), + IssueTypes: NewStringSet("10001"), + }, + }, + }, + expectedRemainingByID: map[string]bool{ + "sub2______________________": true, + }, + }, + "remove multiple DM subscriptions": { + existingSubs: []ChannelSubscription{ + { + ID: "sub1______________________", + ChannelID: dmChannelID, + Filters: SubscriptionFilters{ + Events: NewStringSet("jira:issue_created"), + Projects: NewStringSet("myproject"), + IssueTypes: NewStringSet("10001"), + }, + }, + { + ID: "sub2______________________", + ChannelID: dmChannelID, + Filters: SubscriptionFilters{ + Events: NewStringSet("jira:issue_updated"), + Projects: NewStringSet("otherproject"), + IssueTypes: NewStringSet("10002"), + }, + }, + { + ID: "sub3______________________", + ChannelID: otherChannelID, + Filters: SubscriptionFilters{ + Events: NewStringSet("jira:issue_created"), + Projects: NewStringSet("myproject"), + IssueTypes: NewStringSet("10001"), + }, + }, + }, + expectedRemainingByID: map[string]bool{ + "sub3______________________": true, + }, + }, + } { + t.Run(name, func(t *testing.T) { + api := &plugintest.API{} + p := Plugin{} + + p.updateConfig(func(conf *config) { + conf.Secret = someSecret + }) + p.SetAPI(api) + p.client = pluginapi.NewClient(api, p.Driver) + p.instanceStore = p.getMockInstanceStoreKV(1) + + existing := withExistingChannelSubscriptions(tc.existingSubs) + existingBytes, err := json.Marshal(existing) + require.NoError(t, err) + + api.On("KVGet", testSubKey).Return(existingBytes, nil) + + if tc.expectedRemainingByID != nil { + api.On("KVSetWithOptions", testSubKey, mock.MatchedBy(func(data []byte) bool { + var savedSubs Subscriptions + unmarshalErr := json.Unmarshal(data, &savedSubs) + if unmarshalErr != nil { + return false + } + + for id := range savedSubs.Channel.ByID { + if !tc.expectedRemainingByID[id] { + return false + } + } + return len(savedSubs.Channel.ByID) == len(tc.expectedRemainingByID) + }), mock.AnythingOfType("model.PluginKVSetOptions")).Return(true, nil) + } + + err = p.removeSubscriptionsForChannel(testInstance1.GetID(), dmChannelID) + assert.NoError(t, err) + }) + } +} + +func TestCleanupDMSubscriptionsOnDisconnect(t *testing.T) { + botUserID := "botuser___________________" + mattermostUserID := "mmuser____________________" + dmChannelID := "dmchannelaaaaaaaaaaaaaaaa" + otherChannelID := "otherchannelbbbbbbbbbbbbbb" + + t.Run("removes DM subscriptions on disconnect", func(t *testing.T) { + api := &plugintest.API{} + p := Plugin{} + + p.updateConfig(func(conf *config) { + conf.Secret = someSecret + conf.botUserID = botUserID + }) + p.SetAPI(api) + p.client = pluginapi.NewClient(api, p.Driver) + p.instanceStore = p.getMockInstanceStoreKV(1) + + api.On("GetDirectChannel", mattermostUserID, botUserID).Return(&model.Channel{ + Id: dmChannelID, + Type: model.ChannelTypeDirect, + }, nil) + + existing := withExistingChannelSubscriptions([]ChannelSubscription{ + { + ID: "sub1______________________", + ChannelID: dmChannelID, + Filters: SubscriptionFilters{ + Events: NewStringSet("jira:issue_created"), + Projects: NewStringSet("myproject"), + IssueTypes: NewStringSet("10001"), + }, + }, + { + ID: "sub2______________________", + ChannelID: otherChannelID, + Filters: SubscriptionFilters{ + Events: NewStringSet("jira:issue_created"), + Projects: NewStringSet("myproject"), + IssueTypes: NewStringSet("10001"), + }, + }, + }) + existingBytes, err := json.Marshal(existing) + require.NoError(t, err) + + api.On("KVGet", testSubKey).Return(existingBytes, nil) + api.On("KVSetWithOptions", testSubKey, mock.MatchedBy(func(data []byte) bool { + var savedSubs Subscriptions + unmarshalErr := json.Unmarshal(data, &savedSubs) + if unmarshalErr != nil { + return false + } + + _, hasDMSub := savedSubs.Channel.ByID["sub1______________________"] + _, hasOtherSub := savedSubs.Channel.ByID["sub2______________________"] + return !hasDMSub && hasOtherSub && len(savedSubs.Channel.ByID) == 1 + }), mock.AnythingOfType("model.PluginKVSetOptions")).Return(true, nil) + + api.On("LogDebug", mockAnythingOfTypeBatch("string", 11)...).Return() + api.On("LogWarn", mockAnythingOfTypeBatch("string", 10)...).Return() + + p.cleanupDMSubscriptionsOnDisconnect(testInstance1.GetID(), mattermostUserID) + }) + + t.Run("no-op when DM channel does not exist", func(t *testing.T) { + api := &plugintest.API{} + p := Plugin{} + + p.updateConfig(func(conf *config) { + conf.Secret = someSecret + conf.botUserID = botUserID + }) + p.SetAPI(api) + p.client = pluginapi.NewClient(api, p.Driver) + p.instanceStore = p.getMockInstanceStoreKV(1) + + api.On("GetDirectChannel", mattermostUserID, botUserID).Return(nil, &model.AppError{Message: "channel not found"}) + api.On("LogWarn", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + + p.cleanupDMSubscriptionsOnDisconnect(testInstance1.GetID(), mattermostUserID) + + api.AssertNotCalled(t, "KVGet", mock.Anything) + }) +} diff --git a/server/user.go b/server/user.go index fa56b3a2..4746456e 100644 --- a/server/user.go +++ b/server/user.go @@ -313,6 +313,8 @@ func (p *Plugin) disconnectUser(instance Instance, user *User) (*Connection, err return nil, err } + p.cleanupDMSubscriptionsOnDisconnect(instance.GetID(), user.MattermostUserID.String()) + info, err := p.GetUserInfo(user.MattermostUserID, user) if err != nil { return nil, err diff --git a/server/utils_test.go b/server/utils_test.go index 3b327d7e..cb9c86e4 100644 --- a/server/utils_test.go +++ b/server/utils_test.go @@ -132,9 +132,13 @@ func TestDisconnectUserDueToExpiredToken(t *testing.T) { return b.UserId == testMattermostUserID.String() })).Return().Once() + // GetDirectChannel called twice, once for DM subscription cleanup, once for DM notification api.On("GetDirectChannel", testMattermostUserID.String(), testBotUserID).Return(&model.Channel{ Id: testChannelID, - }, nil).Once() + }, nil) + + // KVGet for subscription cleanup + api.On("KVGet", mock.AnythingOfType("string")).Return(nil, nil) api.On("CreatePost", mock.MatchedBy(func(post *model.Post) bool { return post.UserId == testBotUserID &&