diff --git a/HawkNet.Owin.Tests/HawkAuthenticationHandlerFixture.cs b/HawkNet.Owin.Tests/HawkAuthenticationHandlerFixture.cs index 773ed73..87c6d48 100644 --- a/HawkNet.Owin.Tests/HawkAuthenticationHandlerFixture.cs +++ b/HawkNet.Owin.Tests/HawkAuthenticationHandlerFixture.cs @@ -636,6 +636,166 @@ public void ShouldParseValidAuthHeaderAndPayloadWithSha256() Assert.IsTrue(logger.Messages.Count == 0); } + [TestMethod] + public void ShouldNotThrowWhenIncludeServerAuthorizationIsTrueAndAuthorizationIsMissing() + { + var credential = new HawkCredential + { + Id = "123", + Algorithm = "sha256", + Key = "werxhqb98rpaxn39848xrunpaw3489ruxnpa98w4rxn", + User = "steve" + }; + + var body = "hello world"; + var bodyBytes = Encoding.UTF8.GetBytes(body); + var ms = new MemoryStream(); + ms.Write(bodyBytes, 0, bodyBytes.Length); + ms.Flush(); + ms.Seek(0, SeekOrigin.Begin); + + var logger = new Logger(); + var builder = new AppBuilderFactory().Create(); + builder.SetLoggerFactory(new LoggerFactory(logger)); + var context = new OwinContext(); + var request = (OwinRequest)context.Request; + + request.Set, object>>("server.OnSendingHeaders", RegisterForOnSendingHeaders); + request.Method = "post"; + request.Body = ms; + request.SetHeader("Host", new string[] { "example.com" }); + request.SetUri(new Uri("http://example.com:8080/resource/4?filter=a")); + request.ContentType = "text/plain"; + + var response = (OwinResponse)context.Response; + + var middleware = new HawkAuthenticationMiddleware( + new AppFuncTransition((env) => + { + response.StatusCode = 200; + return Task.FromResult(null); + }), + builder, + new HawkAuthenticationOptions + { + Credentials = (id) => Task.FromResult(credential), + IncludeServerAuthorization = true + } + ); + + var task = middleware.Invoke(context); + + Assert.AreEqual(200, response.StatusCode); + Assert.AreEqual(null, task.Exception); + } + + [TestMethod] + public void ShouldNotThrowWhenIncludeServerAuthorizationIsTrueAndAuthorizationIsOtherScheme() + { + var credential = new HawkCredential + { + Id = "123", + Algorithm = "sha256", + Key = "werxhqb98rpaxn39848xrunpaw3489ruxnpa98w4rxn", + User = "steve" + }; + + var body = "hello world"; + var bodyBytes = Encoding.UTF8.GetBytes(body); + var ms = new MemoryStream(); + ms.Write(bodyBytes, 0, bodyBytes.Length); + ms.Flush(); + ms.Seek(0, SeekOrigin.Begin); + + var logger = new Logger(); + var builder = new AppBuilderFactory().Create(); + builder.SetLoggerFactory(new LoggerFactory(logger)); + var context = new OwinContext(); + var request = (OwinRequest)context.Request; + request.SetHeader("Authorization", new[] { "OtherScheme" }); + + request.Set, object>>("server.OnSendingHeaders", RegisterForOnSendingHeaders); + request.Method = "post"; + request.Body = ms; + request.SetHeader("Host", new string[] { "example.com" }); + request.SetUri(new Uri("http://example.com:8080/resource/4?filter=a")); + request.ContentType = "text/plain"; + + var response = (OwinResponse)context.Response; + + var middleware = new HawkAuthenticationMiddleware( + new AppFuncTransition((env) => + { + response.StatusCode = 200; + return Task.FromResult(null); + }), + builder, + new HawkAuthenticationOptions + { + Credentials = (id) => Task.FromResult(credential), + IncludeServerAuthorization = true + } + ); + + var task = middleware.Invoke(context); + + Assert.AreEqual(200, response.StatusCode); + Assert.AreEqual(null, task.Exception); + } + [TestMethod] + public void ShouldNotThrowWhenIncludeServerAuthorizationIsTrueAndAuthorizationIsEmpty() + { + var credential = new HawkCredential + { + Id = "123", + Algorithm = "sha256", + Key = "werxhqb98rpaxn39848xrunpaw3489ruxnpa98w4rxn", + User = "steve" + }; + + var body = "hello world"; + var bodyBytes = Encoding.UTF8.GetBytes(body); + var ms = new MemoryStream(); + ms.Write(bodyBytes, 0, bodyBytes.Length); + ms.Flush(); + ms.Seek(0, SeekOrigin.Begin); + + var logger = new Logger(); + var builder = new AppBuilderFactory().Create(); + builder.SetLoggerFactory(new LoggerFactory(logger)); + var context = new OwinContext(); + var request = (OwinRequest)context.Request; + request.SetHeader("Authorization", new[] { "" }); + + request.Set, object>>("server.OnSendingHeaders", RegisterForOnSendingHeaders); + request.Method = "post"; + request.Body = ms; + request.SetHeader("Host", new string[] { "example.com" }); + request.SetUri(new Uri("http://example.com:8080/resource/4?filter=a")); + request.ContentType = "text/plain"; + + var response = (OwinResponse)context.Response; + + var middleware = new HawkAuthenticationMiddleware( + new AppFuncTransition((env) => + { + response.StatusCode = 200; + return Task.FromResult(null); + }), + builder, + new HawkAuthenticationOptions + { + Credentials = (id) => Task.FromResult(credential), + IncludeServerAuthorization = true + } + ); + + var task = middleware.Invoke(context); + + Assert.AreEqual(200, response.StatusCode); + Assert.AreEqual(null, task.Exception); + } + [TestMethod] public void ShouldAuthenticateServer() { diff --git a/HawkNet.Owin/HawkAuthenticationHandler.cs b/HawkNet.Owin/HawkAuthenticationHandler.cs index 4ed01ce..f0484eb 100644 --- a/HawkNet.Owin/HawkAuthenticationHandler.cs +++ b/HawkNet.Owin/HawkAuthenticationHandler.cs @@ -64,7 +64,7 @@ protected override async Task AuthenticateCoreAsync() if (Request.Headers.ContainsKey("authorization")) { - authorization = AuthenticationHeaderValue.Parse(Request.Headers["authorization"]); + AuthenticationHeaderValue.TryParse(Request.Headers["authorization"], out authorization); } if (authorization != null && @@ -153,15 +153,18 @@ protected override async Task ApplyResponseChallengeAsync() { if (this.Options.IncludeServerAuthorization) { - var authorization = AuthenticationHeaderValue.Parse(Request.Headers["authorization"]); - - await AuthenticateResponse(authorization.Parameter, - Request.Host.Value, - Request.Method, - Request.Uri, - Response.ContentType, - this.Options.Credentials, - Response); + AuthenticationHeaderValue authorization; + if (AuthenticationHeaderValue.TryParse(Request.Headers["authorization"], out authorization) + && authorization.Scheme.Equals(HawkAuthenticationOptions.Scheme, StringComparison.OrdinalIgnoreCase)) + { + await AuthenticateResponse(authorization.Parameter, + Request.Host.Value, + Request.Method, + Request.Uri, + Response.ContentType, + this.Options.Credentials, + Response); + } } } else if (Response.StatusCode == 401)