diff --git a/check.go b/check.go index 64df9e24..d0f8200b 100644 --- a/check.go +++ b/check.go @@ -20,7 +20,7 @@ func Check(request CheckRequest, manager Github) (CheckResponse, error) { filterStates = request.Source.States } - pulls, err := manager.ListPullRequests(filterStates) + pulls, err := manager.ListPullRequests(filterStates, request.Source.Branch) if err != nil { return nil, fmt.Errorf("failed to get last commits: %s", err) } diff --git a/check_test.go b/check_test.go index 8c422914..097eaecd 100644 --- a/check_test.go +++ b/check_test.go @@ -262,6 +262,21 @@ func TestCheck(t *testing.T) { resource.NewVersion(testPullRequests[10]), }, }, + + { + description: "check filters out versions from a PR which do not match the branch filter", + source: resource.Source{ + Repository: "itsdalmo/test-repository", + AccessToken: "oauthtoken", + Branch: "pr3", + }, + version: resource.Version{}, + pullRequests: testPullRequests, + files: [][]string{}, + expected: resource.CheckResponse{ + resource.NewVersion(testPullRequests[2]), + }, + }, } for _, tc := range tests { @@ -274,10 +289,14 @@ func TestCheck(t *testing.T) { } for i := range tc.pullRequests { for j := range filterStates { - if filterStates[j] == tc.pullRequests[i].PullRequestObject.State { - pullRequests = append(pullRequests, tc.pullRequests[i]) - break + if filterStates[j] != tc.pullRequests[i].PullRequestObject.State { + continue + } + if tc.source.Branch != "" && tc.source.Branch != tc.pullRequests[i].PullRequestObject.HeadRefName { + continue } + pullRequests = append(pullRequests, tc.pullRequests[i]) + break } } github.ListPullRequestsReturns(pullRequests, nil) diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 523838c2..f0b9c31e 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -90,6 +90,19 @@ func TestCheckE2E(t *testing.T) { }, }, + { + description: "check will only return versions that match the specified branch", + source: resource.Source{ + Repository: "itsdalmo/test-repository", + AccessToken: os.Getenv("GITHUB_ACCESS_TOKEN"), + Branch: "my_second_pull", + }, + version: resource.Version{}, + expected: resource.CheckResponse{ + resource.Version{PR: targetPullRequestID, Commit: targetCommitID, CommittedDate: targetDateTime}, + }, + }, + { description: "check will skip versions which only match the ignore paths", source: resource.Source{ diff --git a/fakes/fake_github.go b/fakes/fake_github.go index 1847478f..22645d25 100644 --- a/fakes/fake_github.go +++ b/fakes/fake_github.go @@ -61,10 +61,11 @@ type FakeGithub struct { result1 []string result2 error } - ListPullRequestsStub func([]githubv4.PullRequestState) ([]*resource.PullRequest, error) + ListPullRequestsStub func([]githubv4.PullRequestState, string) ([]*resource.PullRequest, error) listPullRequestsMutex sync.RWMutex listPullRequestsArgsForCall []struct { arg1 []githubv4.PullRequestState + arg2 string } listPullRequestsReturns struct { result1 []*resource.PullRequest @@ -357,7 +358,7 @@ func (fake *FakeGithub) ListModifiedFilesReturnsOnCall(i int, result1 []string, }{result1, result2} } -func (fake *FakeGithub) ListPullRequests(arg1 []githubv4.PullRequestState) ([]*resource.PullRequest, error) { +func (fake *FakeGithub) ListPullRequests(arg1 []githubv4.PullRequestState, arg2 string) ([]*resource.PullRequest, error) { var arg1Copy []githubv4.PullRequestState if arg1 != nil { arg1Copy = make([]githubv4.PullRequestState, len(arg1)) @@ -367,11 +368,12 @@ func (fake *FakeGithub) ListPullRequests(arg1 []githubv4.PullRequestState) ([]*r ret, specificReturn := fake.listPullRequestsReturnsOnCall[len(fake.listPullRequestsArgsForCall)] fake.listPullRequestsArgsForCall = append(fake.listPullRequestsArgsForCall, struct { arg1 []githubv4.PullRequestState - }{arg1Copy}) - fake.recordInvocation("ListPullRequests", []interface{}{arg1Copy}) + arg2 string + }{arg1Copy, arg2}) + fake.recordInvocation("ListPullRequests", []interface{}{arg1Copy, arg2}) fake.listPullRequestsMutex.Unlock() if fake.ListPullRequestsStub != nil { - return fake.ListPullRequestsStub(arg1) + return fake.ListPullRequestsStub(arg1, arg2) } if specificReturn { return ret.result1, ret.result2 @@ -386,17 +388,17 @@ func (fake *FakeGithub) ListPullRequestsCallCount() int { return len(fake.listPullRequestsArgsForCall) } -func (fake *FakeGithub) ListPullRequestsCalls(stub func([]githubv4.PullRequestState) ([]*resource.PullRequest, error)) { +func (fake *FakeGithub) ListPullRequestsCalls(stub func([]githubv4.PullRequestState, string) ([]*resource.PullRequest, error)) { fake.listPullRequestsMutex.Lock() defer fake.listPullRequestsMutex.Unlock() fake.ListPullRequestsStub = stub } -func (fake *FakeGithub) ListPullRequestsArgsForCall(i int) []githubv4.PullRequestState { +func (fake *FakeGithub) ListPullRequestsArgsForCall(i int) ([]githubv4.PullRequestState, string) { fake.listPullRequestsMutex.RLock() defer fake.listPullRequestsMutex.RUnlock() argsForCall := fake.listPullRequestsArgsForCall[i] - return argsForCall.arg1 + return argsForCall.arg1, argsForCall.arg2 } func (fake *FakeGithub) ListPullRequestsReturns(result1 []*resource.PullRequest, result2 error) { diff --git a/github.go b/github.go index ab10cbdc..db87f54a 100644 --- a/github.go +++ b/github.go @@ -20,7 +20,7 @@ import ( // Github for testing purposes. //go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -o fakes/fake_github.go . Github type Github interface { - ListPullRequests([]githubv4.PullRequestState) ([]*PullRequest, error) + ListPullRequests([]githubv4.PullRequestState, string) ([]*PullRequest, error) ListModifiedFiles(int) ([]string, error) PostComment(string, string) error GetPullRequest(string, string) (*PullRequest, error) @@ -98,7 +98,7 @@ func NewGithubClient(s *Source) (*GithubClient, error) { } // ListPullRequests gets the last commit on all pull requests with the matching state. -func (m *GithubClient) ListPullRequests(prStates []githubv4.PullRequestState) ([]*PullRequest, error) { +func (m *GithubClient) ListPullRequests(prStates []githubv4.PullRequestState, prHeadRefName string) ([]*PullRequest, error) { var query struct { Repository struct { PullRequests struct { @@ -128,7 +128,7 @@ func (m *GithubClient) ListPullRequests(prStates []githubv4.PullRequestState) ([ EndCursor githubv4.String HasNextPage bool } - } `graphql:"pullRequests(first:$prFirst,states:$prStates,after:$prCursor)"` + } `graphql:"pullRequests(first:$prFirst,states:$prStates,after:$prCursor,headRefName:$prHeadRefName)"` } `graphql:"repository(owner:$repositoryOwner,name:$repositoryName)"` } @@ -137,12 +137,17 @@ func (m *GithubClient) ListPullRequests(prStates []githubv4.PullRequestState) ([ "repositoryName": githubv4.String(m.Repository), "prFirst": githubv4.Int(100), "prStates": prStates, + "prHeadRefName": (*githubv4.String)(nil), "prCursor": (*githubv4.String)(nil), "commitsLast": githubv4.Int(1), "prReviewStates": []githubv4.PullRequestReviewState{githubv4.PullRequestReviewStateApproved}, "labelsFirst": githubv4.Int(100), } + if len(prHeadRefName) > 0 { + vars["prHeadRefName"] = githubv4.String(prHeadRefName) + } + var response []*PullRequest for { if err := m.V4.Query(context.TODO(), &query, vars); err != nil { diff --git a/models.go b/models.go index 9e4e7b1c..0c2c0c84 100644 --- a/models.go +++ b/models.go @@ -27,6 +27,7 @@ type Source struct { RequiredReviewApprovals int `json:"required_review_approvals"` Labels []string `json:"labels"` States []githubv4.PullRequestState `json:"states"` + Branch string `json:"branch"` } // Validate the source configuration.