Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ func main() {
flagSet.Bool("ssl-insecure-skip-verify", false, "skip validation of certificates presented when using HTTPS")

flagSet.Var(&emailDomains, "email-domain", "authenticate emails with the specified domain (may be given multiple times). Use * to authenticate any email")
flagSet.String("keycloak-group", "", "restrict login to members of this group.")
flagSet.String("azure-tenant", "common", "go to a tenant-specific or common (tenant-independent) endpoint.")
flagSet.String("github-org", "", "restrict logins to members of this organisation")
flagSet.String("github-team", "", "restrict logins to members of this team")
Expand Down
3 changes: 3 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type Options struct {
TLSKeyFile string `flag:"tls-key" cfg:"tls_key_file"`

AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"`
KeycloakGroup string `flag:"keycloak-group" cfg:"keycloak_group"`
AzureTenant string `flag:"azure-tenant" cfg:"azure_tenant"`
EmailDomains []string `flag:"email-domain" cfg:"email_domains"`
GitHubOrg string `flag:"github-org" cfg:"github_org"`
Expand Down Expand Up @@ -241,6 +242,8 @@ func parseProviderInfo(o *Options, msgs []string) []string {
p.Configure(o.AzureTenant)
case *providers.GitHubProvider:
p.SetOrgTeam(o.GitHubOrg, o.GitHubTeam)
case *providers.KeycloakProvider:
p.SetGroup(o.KeycloakGroup)
case *providers.GoogleProvider:
if o.GoogleServiceAccountJSON != "" {
file, err := os.Open(o.GoogleServiceAccountJSON)
Expand Down
85 changes: 85 additions & 0 deletions providers/keycloak.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package providers

import (
"log"
"net/http"
"net/url"

"github.com/bitly/oauth2_proxy/api"
)

type KeycloakProvider struct {
*ProviderData
Group string
}

func NewKeycloakProvider(p *ProviderData) *KeycloakProvider {
p.ProviderName = "Keycloak"
if p.LoginURL == nil || p.LoginURL.String() == "" {
p.LoginURL = &url.URL{
Scheme: "https",
Host: "keycloak.org",
Path: "/oauth/authorize",
}
}
if p.RedeemURL == nil || p.RedeemURL.String() == "" {
p.RedeemURL = &url.URL{
Scheme: "https",
Host: "keycloak.org",
Path: "/oauth/token",
}
}
if p.ValidateURL == nil || p.ValidateURL.String() == "" {
p.ValidateURL = &url.URL{
Scheme: "https",
Host: "keycloak.org",
Path: "/api/v3/user",
}
}
if p.Scope == "" {
p.Scope = "api"
}
return &KeycloakProvider{ProviderData: p}
}

func (p *KeycloakProvider) SetGroup(group string) {
p.Group = group
}

func (p *KeycloakProvider) GetEmailAddress(s *SessionState) (string, error) {

req, err := http.NewRequest("GET", p.ValidateURL.String(), nil)
req.Header.Set("Authorization", "Bearer "+s.AccessToken)
if err != nil {
log.Printf("failed building request %s", err)
return "", err
}
json, err := api.Request(req)
if err != nil {
log.Printf("failed making request %s", err)
return "", err
}

if p.Group != "" {
var groups, err = json.Get("groups").Array()
if err != nil {
log.Printf("groups not found %s", err)
return "", err
}

var found = false
for i := range groups {
if groups[i].(string) == p.Group {
found = true
break
}
}

if found != true {
log.Printf("group not found, access denied")
return "", nil
}
}

return json.Get("email").String()
}
147 changes: 147 additions & 0 deletions providers/keycloak_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package providers

import (
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/bmizerany/assert"
)

func testKeycloakProvider(hostname, group string) *KeycloakProvider {
p := NewKeycloakProvider(
&ProviderData{
ProviderName: "",
LoginURL: &url.URL{},
RedeemURL: &url.URL{},
ProfileURL: &url.URL{},
ValidateURL: &url.URL{},
Scope: ""})

if group != "" {
p.SetGroup(group)
}

if hostname != "" {
updateURL(p.Data().LoginURL, hostname)
updateURL(p.Data().RedeemURL, hostname)
updateURL(p.Data().ProfileURL, hostname)
updateURL(p.Data().ValidateURL, hostname)
}
return p
}

func testKeycloakBackend(payload string) *httptest.Server {
path := "/api/v3/user"

return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
url := r.URL
if url.Path != path {
w.WriteHeader(404)
} else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" {
w.WriteHeader(403)
} else {
w.WriteHeader(200)
w.Write([]byte(payload))
}
}))
}

func TestKeycloakProviderDefaults(t *testing.T) {
p := testKeycloakProvider("", "")
assert.NotEqual(t, nil, p)
assert.Equal(t, "Keycloak", p.Data().ProviderName)
assert.Equal(t, "https://keycloak.org/oauth/authorize",
p.Data().LoginURL.String())
assert.Equal(t, "https://keycloak.org/oauth/token",
p.Data().RedeemURL.String())
assert.Equal(t, "https://keycloak.org/api/v3/user",
p.Data().ValidateURL.String())
assert.Equal(t, "api", p.Data().Scope)
}

func TestKeycloakProviderOverrides(t *testing.T) {
p := NewKeycloakProvider(
&ProviderData{
LoginURL: &url.URL{
Scheme: "https",
Host: "example.com",
Path: "/oauth/auth"},
RedeemURL: &url.URL{
Scheme: "https",
Host: "example.com",
Path: "/oauth/token"},
ValidateURL: &url.URL{
Scheme: "https",
Host: "example.com",
Path: "/api/v3/user"},
Scope: "profile"})
assert.NotEqual(t, nil, p)
assert.Equal(t, "Keycloak", p.Data().ProviderName)
assert.Equal(t, "https://example.com/oauth/auth",
p.Data().LoginURL.String())
assert.Equal(t, "https://example.com/oauth/token",
p.Data().RedeemURL.String())
assert.Equal(t, "https://example.com/api/v3/user",
p.Data().ValidateURL.String())
assert.Equal(t, "profile", p.Data().Scope)
}

func TestKeycloakProviderGetEmailAddress(t *testing.T) {
b := testKeycloakBackend("{\"email\": \"michael.bland@gsa.gov\"}")
defer b.Close()

b_url, _ := url.Parse(b.URL)
p := testKeycloakProvider(b_url.Host, "")

session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email)
}

func TestKeycloakProviderGetEmailAddressAndGroup(t *testing.T) {
b := testKeycloakBackend("{\"email\": \"michael.bland@gsa.gov\", \"groups\": [\"test-grp1\", \"test-grp2\"]}")
defer b.Close()

b_url, _ := url.Parse(b.URL)
p := testKeycloakProvider(b_url.Host, "test-grp1")

session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email)
}

// Note that trying to trigger the "failed building request" case is not
// practical, since the only way it can fail is if the URL fails to parse.
func TestKeycloakProviderGetEmailAddressFailedRequest(t *testing.T) {
b := testKeycloakBackend("unused payload")
defer b.Close()

b_url, _ := url.Parse(b.URL)
p := testKeycloakProvider(b_url.Host, "")

// We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as
// JSON to fail.
session := &SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
}

func TestKeycloakProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
b := testKeycloakBackend("{\"foo\": \"bar\"}")
defer b.Close()

b_url, _ := url.Parse(b.URL)
p := testKeycloakProvider(b_url.Host, "")

session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
}
2 changes: 2 additions & 0 deletions providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ func New(provider string, p *ProviderData) Provider {
return NewAzureProvider(p)
case "gitlab":
return NewGitLabProvider(p)
case "keycloak":
return NewKeycloakProvider(p)
default:
return NewGoogleProvider(p)
}
Expand Down