diff --git a/pkg/provider/aad/aad.go b/pkg/provider/aad/aad.go index 212cb3249..22d842b9c 100644 --- a/pkg/provider/aad/aad.go +++ b/pkg/provider/aad/aad.go @@ -34,26 +34,27 @@ type Client struct { // Autogenrated Converged Response struct // for some cases, some fields may not exist type ConvergedResponse struct { - URLGetCredentialType string `json:"urlGetCredentialType"` - ArrUserProofs []userProof `json:"arrUserProofs"` - URLSkipMfaRegistration string `json:"urlSkipMfaRegistration"` - OPerAuthPollingInterval map[string]float64 `json:"oPerAuthPollingInterval"` - URLBeginAuth string `json:"urlBeginAuth"` - URLEndAuth string `json:"urlEndAuth"` - URLPost string `json:"urlPost"` - SErrorCode string `json:"sErrorCode"` - SErrTxt string `json:"sErrTxt"` - SPOSTUsername string `json:"sPOST_Username"` - SFT string `json:"sFT"` - SFTName string `json:"sFTName"` - SCtx string `json:"sCtx"` - Hpgact int `json:"hpgact"` - Hpgid int `json:"hpgid"` - Pgid string `json:"pgid"` - APICanary string `json:"apiCanary"` - Canary string `json:"canary"` - CorrelationID string `json:"correlationId"` - SessionID string `json:"sessionId"` + URLGetCredentialType string `json:"urlGetCredentialType"` + ArrUserProofs []userProof `json:"arrUserProofs"` + URLSkipMfaRegistration string `json:"urlSkipMfaRegistration"` + OPerAuthPollingInterval map[string]float64 `json:"oPerAuthPollingInterval"` + URLBeginAuth string `json:"urlBeginAuth"` + URLEndAuth string `json:"urlEndAuth"` + URLPost string `json:"urlPost"` + SErrorCode string `json:"sErrorCode"` + SErrTxt string `json:"sErrTxt"` + SPOSTUsername string `json:"sPOST_Username"` + SFT string `json:"sFT"` + SFTName string `json:"sFTName"` + SCtx string `json:"sCtx"` + Hpgact int `json:"hpgact"` + Hpgid int `json:"hpgid"` + Pgid string `json:"pgid"` + APICanary string `json:"apiCanary"` + Canary string `json:"canary"` + CorrelationID string `json:"correlationId"` + SessionID string `json:"sessionId"` + AuthMethodInputFieldName string `json:"sAuthMethodInputFieldName"` } // Autogenerated GetCredentialType Request struct @@ -195,6 +196,9 @@ AuthProcessor: case strings.Contains(resBodyStr, "SAMLRequest"): logger.Debug("processing SAMLRequest") res, err = ac.processSAMLRequest(res, resBodyStr) + case strings.Contains(resBodyStr, "CmsiInterrupt"): + logger.Debug("processing CmsiInterrupt") + res, err = ac.processCmsiInterrupt(res, resBodyStr) case ac.isHiddenForm(resBodyStr): if samlAssertion, _ = ac.getSamlAssertion(resBodyStr); samlAssertion != "" { logger.Debug("processing a SAMLResponse") @@ -389,6 +393,8 @@ func (ac *Client) processKmsiInterrupt(res *http.Response, srcBodyStr string) (* formValues.Set(convergedResponse.SFTName, convergedResponse.SFT) formValues.Set("ctx", convergedResponse.SCtx) formValues.Set("LoginOptions", "1") + formValues.Set("canary", convergedResponse.Canary) + formValues.Set("hpgrequestid", convergedResponse.SessionID) req, err := http.NewRequest("POST", ac.fullUrl(res, convergedResponse.URLPost), strings.NewReader(formValues.Encode())) if err != nil { @@ -407,6 +413,38 @@ func (ac *Client) processKmsiInterrupt(res *http.Response, srcBodyStr string) (* return res, nil } +func (ac *Client) processCmsiInterrupt(res *http.Response, srcBodyStr string) (*http.Response, error) { + var convergedResponse *ConvergedResponse + var err error + + if err := ac.unmarshalEmbeddedJson(srcBodyStr, &convergedResponse); err != nil { + return res, errors.Wrap(err, "CMSI request unmarshal error") + } + + formValues := url.Values{} + formValues.Set("canary", convergedResponse.Canary) + formValues.Set("ContinueAuth", "true") + formValues.Set("ctx", convergedResponse.SCtx) + formValues.Set(convergedResponse.SFTName, convergedResponse.SFT) + formValues.Set("hpgrequestid", convergedResponse.SessionID) + + req, err := http.NewRequest("POST", ac.fullUrl(res, convergedResponse.URLPost), strings.NewReader(formValues.Encode())) + if err != nil { + return res, errors.Wrap(err, "error building CMSI request") + } + + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + ac.client.DisableFollowRedirect() + res, err = ac.client.Do(req) + if err != nil { + return res, errors.Wrap(err, "error retrieving CMSI results") + } + ac.client.EnableFollowRedirect() + + return res, nil +} + func (ac *Client) processConvergedTFA(res *http.Response, srcBodyStr string, loginDetails *creds.LoginDetails) (*http.Response, error) { var convergedResponse *ConvergedResponse var err error @@ -603,6 +641,9 @@ func (ac *Client) processMfaAuth(mfaResp mfaResponse, convergedResponse *Converg formValues.Set(convergedResponse.SFTName, mfaResp.FlowToken) formValues.Set("request", mfaResp.Ctx) formValues.Set("login", convergedResponse.SPOSTUsername) + formValues.Set(convergedResponse.AuthMethodInputFieldName, mfaResp.AuthMethodID) + formValues.Set("canary", convergedResponse.APICanary) + formValues.Set("hpgrequestid", convergedResponse.SessionID) req, err = http.NewRequest("POST", convergedResponse.URLPost, strings.NewReader(formValues.Encode())) if err != nil {