forked from hashicorp/cap
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathauthn_request.go
More file actions
399 lines (337 loc) · 11.6 KB
/
authn_request.go
File metadata and controls
399 lines (337 loc) · 11.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package saml
import (
"bytes"
"compress/flate"
"encoding/base64"
"encoding/xml"
"fmt"
"net/http"
"net/url"
"strings"
"text/template"
"time"
"github.com/jonboulle/clockwork"
"github.com/hashicorp/cap/saml/models/core"
)
const (
	// postBindingScriptSha256 is the base64-encoded SHA-256 hash of the JavaScript inside the
	// <script> tag of ./authn_request.gohtml.
	// WritePostBindingRequestHeader puts it into the Content-Security-Policy header as a
	// script-src hash source, so that browsers only execute that exact inline script when
	// rendering the SAML HTTP POST-Binding Authentication Request page.
	// See https://content-security-policy.com/hash/ for how such a hash is generated.
	// The POST-Binding script is static, so this value is static as well. It must be
	// regenerated whenever the script changes; a stale hash makes browsers block the script.
	postBindingScriptSha256 = "sha256-T8Q9GZiIVtYoNIdF6UW5hDNgJudFDijQM/usO+xUkes="
)
// authnRequestOptions collects the settings that the Option functions in this file
// (AllowCreate, WithNameIDFormat, ForceAuthn, WithProtocolBinding, ...) apply when an
// AuthnRequest is built or serialized.
type authnRequestOptions struct {
	// clock supplies the request's IssueInstant; replaceable via WithClock.
	clock clockwork.Clock
	// allowCreate becomes NameIDPolicy.AllowCreate.
	allowCreate bool
	// nameIDFormat becomes NameIDPolicy.Format when non-empty.
	nameIDFormat core.NameIDFormat
	// forceAuthn becomes the request's ForceAuthn attribute.
	forceAuthn bool
	// protocolBinding is the binding the IdP should use when returning the <Response>.
	protocolBinding core.ServiceBinding
	// authnContextClassRefs, when non-empty, yields a RequestedAuthnContext (Comparison=exact).
	authnContextClassRefs []string
	// indent is the XML indentation width used when marshalling the request.
	indent int
	// assertionConsumerServiceURL, when non-empty, overrides the configured ACS URL.
	assertionConsumerServiceURL string
}
// authnRequestOptionsDefault returns the baseline options: a real wall clock and
// the HTTP-POST protocol binding. Every flag is off and every other field keeps
// its zero value (no NameID format, no context class refs, no indentation, no
// ACS URL override).
func authnRequestOptionsDefault() authnRequestOptions {
	var defaults authnRequestOptions
	defaults.clock = clockwork.NewRealClock()
	defaults.protocolBinding = core.ServiceBindingHTTPPost
	return defaults
}
// getAuthnRequestOptions starts from authnRequestOptionsDefault and applies the
// given options, in order, on top of it.
func getAuthnRequestOptions(opt ...Option) authnRequestOptions {
	resolved := authnRequestOptionsDefault()
	ApplyOpts(&resolved, opt...)
	return resolved
}
// AllowCreate returns an Option that sets AllowCreate="true" on the request's
// NameIDPolicy. This tells the identity provider that, while fulfilling the request,
// it may create a new identifier to represent the principal.
// It has no effect on option sets other than the AuthnRequest options.
func AllowCreate() Option {
	return func(o interface{}) {
		opts, ok := o.(*authnRequestOptions)
		if !ok {
			return
		}
		opts.allowCreate = true
	}
}
// WithNameIDFormat returns an Option that adds a NameIDPolicy carrying the given
// NameIDFormat to the request. It also turns on AllowCreate, following the SAML 2.0
// core spec recommendation:
// "Requesters that do not make specific use of this (AllowCreate) attribute SHOULD generally set it to “true”
// to maximize interoperability."
// See https://www.oasis-open.org/committees/download.php/56776/sstc-saml-core-errata-2.0-wd-07.pdf
//
// Note that the saml2int profile ([INT_SAML][SDP-SP04]) says the NameIDPolicy MUST NOT
// contain a Format attribute. Only use this option when the identity provider needs it.
func WithNameIDFormat(f core.NameIDFormat) Option {
	return func(o interface{}) {
		opts, ok := o.(*authnRequestOptions)
		if !ok {
			return
		}
		opts.allowCreate = true
		opts.nameIDFormat = f
	}
}
// ForceAuthn returns an Option that sets ForceAuthn="true" on the request. This tells
// the identity provider it MUST authenticate the presenter directly rather than rely
// on a previous security context.
func ForceAuthn() Option {
	return func(o interface{}) {
		opts, ok := o.(*authnRequestOptions)
		if !ok {
			return
		}
		opts.forceAuthn = true
	}
}
// WithProtocolBinding returns an Option that sets the request's ProtocolBinding.
// ProtocolBinding is a URI reference identifying the SAML protocol binding the
// identity provider should use when returning the <Response> message.
// Without this option the HTTP-POST binding is requested.
func WithProtocolBinding(binding core.ServiceBinding) Option {
	return func(o interface{}) {
		opts, ok := o.(*authnRequestOptions)
		if !ok {
			return
		}
		opts.protocolBinding = binding
	}
}
// WithAuthContextClassRefs returns an Option that sets the AuthnContextClassRef values
// the requester places on the authentication context the identity provider uses to
// authenticate the presenter.
// When at least one value is given, CreateAuthnRequest adds a RequestedAuthnContext
// listing them with Comparison="exact"; an empty list adds nothing.
//
// The values are copied each time the option is applied. Later changes to cfs by the
// caller therefore do not alter requests that were already built with it.
func WithAuthContextClassRefs(cfs []string) Option {
	return func(o interface{}) {
		if o, ok := o.(*authnRequestOptions); ok {
			// Copy so the generated AuthnRequest never shares a backing array with the caller.
			o.authnContextClassRefs = append([]string(nil), cfs...)
		}
	}
}
// WithIndent returns an Option that sets the indentation used when the request XML is
// marshalled. Deflate uses it as the number of spaces per nesting level, and
// AuthnRequestPost passes it to CreateXMLDocument. The default of 0 means no indentation.
func WithIndent(indent int) Option {
	return func(o interface{}) {
		opts, ok := o.(*authnRequestOptions)
		if !ok {
			return
		}
		opts.indent = indent
	}
}
// WithClock returns an Option that replaces the clock used for time-dependent work,
// such as the IssueInstant of generated requests. It is accepted by the AuthnRequest,
// response-parsing and IdP-metadata option sets and ignored by any other.
func WithClock(clock clockwork.Clock) Option {
	return func(o interface{}) {
		switch target := o.(type) {
		case *idpMetadataOptions:
			target.clock = clock
		case *parseResponseOptions:
			target.clock = clock
		case *authnRequestOptions:
			target.clock = clock
		}
	}
}
// WithAssertionConsumerServiceURL returns an Option that overrides the Assertion
// Consumer Service URL. It is accepted both when building an AuthnRequest and when
// validating a response, and ignored by any other option set.
func WithAssertionConsumerServiceURL(url string) Option {
	return func(o interface{}) {
		switch target := o.(type) {
		case *parseResponseOptions:
			target.assertionConsumerServiceURL = url
		case *authnRequestOptions:
			target.assertionConsumerServiceURL = url
		}
	}
}
// CreateAuthnRequest builds a SAML 2.0 Authentication Request with the given ID,
// addressed to the identity provider destination registered for binding.
// The defaults follow the deployment profile for federation interoperability.
// See: 3.1.1 https://kantarainitiative.github.io/SAMLprofiles/saml2int.html#_service_provider_requirements [INT_SAML]
//
// An empty id or binding returns an error wrapping ErrInvalidParameter. A failure to
// resolve the destination for binding is returned wrapped as well.
//
// Options:
//   - WithClock
//   - ForceAuthn
//   - AllowCreate
//   - WithNameIDFormat
//   - WithProtocolBinding
//   - WithAuthContextClassRefs
//   - WithAssertionConsumerServiceURL
func (sp *ServiceProvider) CreateAuthnRequest(
	id string,
	binding core.ServiceBinding,
	opt ...Option,
) (*core.AuthnRequest, error) {
	const op = "saml.ServiceProvider.CreateAuthnRequest"

	switch {
	case id == "":
		return nil, fmt.Errorf("%s: no ID provided: %w", op, ErrInvalidParameter)
	case binding == "":
		return nil, fmt.Errorf("%s: no binding provided: %w", op, ErrInvalidParameter)
	}

	opts := getAuthnRequestOptions(opt...)

	destination, err := sp.destination(binding)
	if err != nil {
		return nil, fmt.Errorf(
			"%s: failed to get destination for given service binding (%s): %w",
			op, binding, err,
		)
	}

	req := &core.AuthnRequest{}
	req.ID = id
	req.Version = core.SAMLVersion2
	req.ProtocolBinding = opts.protocolBinding
	req.ForceAuthn = opts.forceAuthn
	req.IssueInstant = opts.clock.Now().Truncate(time.Microsecond).UTC()
	req.Destination = destination
	req.Issuer = &core.Issuer{}
	req.Issuer.Value = sp.cfg.EntityID

	// [INT_SAML][SDP-SP05][SDP-SP06]
	// "The message SHOULD contain an AssertionConsumerServiceURL attribute and MUST NOT contain an
	// AssertionConsumerServiceIndex attribute (i.e., the desired endpoint MUST be the default,
	// or identified via the AssertionConsumerServiceURL attribute)."
	// The configured URL is used unless an option overrides it.
	req.AssertionConsumerServiceURL = sp.cfg.AssertionConsumerServiceURL
	if opts.assertionConsumerServiceURL != "" {
		req.AssertionConsumerServiceURL = opts.assertionConsumerServiceURL
	}

	// [INT_SAML][SDP-SP04]
	// "The <samlp:AuthnRequest> message MUST either omit the <samlp:NameIDPolicy> element (RECOMMENDED),
	// or the element MUST contain an AllowCreate attribute of "true" and MUST NOT contain a Format attribute."
	// A policy is only emitted when an option asked for one. A Format is only present
	// when WithNameIDFormat was used.
	if opts.allowCreate || opts.nameIDFormat != "" {
		policy := &core.NameIDPolicy{
			AllowCreate: opts.allowCreate,
		}
		if opts.nameIDFormat != "" {
			policy.Format = opts.nameIDFormat
		}
		req.NameIDPolicy = policy
	}

	// [INT_SAML][SDP-SP07]
	// "An SP that does not require a specific <saml:AuthnContextClassRef> value MUST NOT include a
	// <samlp:RequestedAuthnContext> element in its requests.
	// An SP that requires specific <saml:AuthnContextClassRef> values MUST specify the allowable values
	// in a <samlp:RequestedAuthnContext> element in its requests, with the Comparison attribute set to exact."
	if len(opts.authnContextClassRefs) > 0 {
		req.RequestedAuthContext = &core.RequestedAuthnContext{
			AuthnContextClassRef: opts.authnContextClassRefs,
			Comparison:           core.ComparisonExact,
		}
	}

	return req, nil
}
// AuthnRequestPost creates an AuthnRequest for the HTTP-POST binding and renders it
// into the POST-binding HTML page (postBindingTempl). It returns the rendered page
// and the request.
//
// The request XML is base64-encoded into the page's SAMLRequest value. relayState and
// the request destination are filled into the page as well.
// The options accepted by CreateAuthnRequest apply to the generated request.
// WithIndent additionally controls the XML indentation.
//
// Every failure is returned wrapping ErrInternal. The underlying cause is included
// in the message.
func (sp *ServiceProvider) AuthnRequestPost(
	relayState string, opt ...Option,
) ([]byte, *core.AuthnRequest, error) {
	const op = "saml.ServiceProvider.AuthnRequestPost"

	requestID, err := sp.cfg.GenerateAuthRequestID()
	if err != nil {
		return nil, nil, fmt.Errorf(
			"%s: failed to generate authentication request ID: %v: %w",
			op, err, ErrInternal,
		)
	}

	// Forward the caller's options so ForceAuthn, AllowCreate, WithClock, etc.
	// take effect for POST-binding requests too.
	authN, err := sp.CreateAuthnRequest(requestID, core.ServiceBindingHTTPPost, opt...)
	if err != nil {
		return nil, nil, fmt.Errorf(
			"%s: failed to create authentication request: %v: %w",
			op, err, ErrInternal,
		)
	}

	opts := getAuthnRequestOptions(opt...)
	payload, err := authN.CreateXMLDocument(opts.indent)
	if err != nil {
		return nil, nil, fmt.Errorf(
			"%s: failed to create request XML: %v: %w",
			op, err, ErrInternal,
		)
	}

	b64Payload := base64.StdEncoding.EncodeToString(payload)

	// NOTE(review): text/template does no HTML escaping, so Destination and RelayState are
	// inserted verbatim. If relayState can carry untrusted input, escape it or render with
	// html/template (confirm the CSP script hash is unaffected before switching).
	tmpl := template.Must(
		template.New("post-binding").Parse(postBindingTempl),
	)

	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, map[string]string{
		"Destination": authN.Destination,
		"SAMLRequest": b64Payload,
		"RelayState":  relayState,
	}); err != nil {
		return nil, nil, fmt.Errorf(
			"%s: failed to execute POST binding template: %v: %w",
			op, err, ErrInternal,
		)
	}

	return buf.Bytes(), authN, nil
}
// WritePostBindingRequestHeader adds the recommended response headers for serving the
// SAML HTTP POST-binding page produced by AuthnRequestPost:
//   - Content-Security-Policy, which only allows the inline script whose hash is
//     postBindingScriptSha256;
//   - Content-Type "text/html".
//
// The headers are appended via Header().Add, so call this before the response body
// is written. It returns an error if w is nil.
func WritePostBindingRequestHeader(w http.ResponseWriter) error {
	// op matches the function name so error messages point at the right call.
	const op = "saml.WritePostBindingRequestHeader"
	if w == nil {
		return fmt.Errorf("%s: response writer is nil", op)
	}

	h := w.Header()
	h.Add("Content-Security-Policy", fmt.Sprintf("script-src '%s'", postBindingScriptSha256))
	h.Add("Content-type", "text/html")
	return nil
}
// AuthnRequestRedirect creates a SAML authentication request for the HTTP-Redirect
// binding. It returns the identity provider URL the user agent should be redirected
// to, along with the request itself.
//
// The request XML is DEFLATE-compressed (see Deflate), base64-encoded and stored in
// the SAMLRequest query parameter of the destination URL. A non-empty relayState is
// stored as RelayState. Other query parameters already present on the destination
// are kept.
//
// TODO: the redirect is not signed. Signing would add SigAlg and a Signature computed
// over SAMLRequest, RelayState and SigAlg.
func (sp *ServiceProvider) AuthnRequestRedirect(
	relayState string, opts ...Option,
) (*url.URL, *core.AuthnRequest, error) {
	const op = "saml.ServiceProvider.AuthnRequestRedirect"

	requestID, err := sp.cfg.GenerateAuthRequestID()
	if err != nil {
		return nil, nil, fmt.Errorf("%s: failed to generate authentication request ID: %w", op, err)
	}

	authN, err := sp.CreateAuthnRequest(requestID, core.ServiceBindingHTTPRedirect, opts...)
	if err != nil {
		return nil, nil, fmt.Errorf("%s: failed to create SAML auth request: %w", op, err)
	}

	compressed, err := Deflate(authN, opts...)
	if err != nil {
		return nil, nil, fmt.Errorf("%s: failed to deflate/compress request: %w", op, err)
	}

	target, err := url.Parse(authN.Destination)
	if err != nil {
		return nil, nil, fmt.Errorf("%s: failed to parse destination URL: %w", op, err)
	}

	query := target.Query()
	query.Set("SAMLRequest", base64.StdEncoding.EncodeToString(compressed))
	if relayState != "" {
		query.Set("RelayState", relayState)
	}
	target.RawQuery = query.Encode()

	return target, authN, nil
}
// Deflate serializes authn to XML and compresses it in the raw DEFLATE format
// (RFC 1951, as produced by compress/flate) at flate.DefaultCompression.
// AuthnRequestRedirect uses it to build the SAMLRequest query parameter.
//
// The WithIndent option sets the number of spaces per XML nesting level. A negative
// value is treated as 0 (no indentation) instead of panicking.
func Deflate(authn *core.AuthnRequest, opt ...Option) ([]byte, error) {
	const op = "saml.Deflate"
	opts := getAuthnRequestOptions(opt...)

	// strings.Repeat panics on a negative count; clamp so bad input cannot crash callers.
	indent := opts.indent
	if indent < 0 {
		indent = 0
	}

	var buf bytes.Buffer
	fw, err := flate.NewWriter(&buf, flate.DefaultCompression)
	if err != nil {
		return nil, fmt.Errorf("%s: failed to create new flate writer: %w", op, err)
	}

	encoder := xml.NewEncoder(fw)
	encoder.Indent("", strings.Repeat(" ", indent))
	if err := encoder.Encode(authn); err != nil {
		return nil, fmt.Errorf("%s: failed to XML encode SAML authn request: %w", op, err)
	}

	// Close flushes the remaining compressed data and writes the final DEFLATE block.
	if err := fw.Close(); err != nil {
		return nil, fmt.Errorf("%s: failed to close flate writer: %w", op, err)
	}

	return buf.Bytes(), nil
}