diff --git a/cmd/saml2aws/commands/login.go b/cmd/saml2aws/commands/login.go index 986c42e2b..5d4bd5210 100644 --- a/cmd/saml2aws/commands/login.go +++ b/cmd/saml2aws/commands/login.go @@ -215,11 +215,14 @@ func resolveLoginDetails(account *cfg.IDPAccount, loginFlags *flags.LoginExecFla return nil, errors.Wrap(err, "Error loading saved password.") } } - } else { // if user disabled keychain, dont use Okta sessions & dont remember Okta MFA device + } else { // if user disabled keychain, dont use Okta sessions & dont remember Okta MFA device & dont save browser cookies if strings.ToLower(account.Provider) == "okta" { account.DisableSessions = true account.DisableRememberDevice = true } + if strings.ToLower(account.Provider) == "browser" { + account.DisableCookies = true + } } // log.Printf("%s %s", savedUsername, savedPassword) diff --git a/cmd/saml2aws/commands/login_test.go b/cmd/saml2aws/commands/login_test.go index 85f331be8..5db8dec74 100644 --- a/cmd/saml2aws/commands/login_test.go +++ b/cmd/saml2aws/commands/login_test.go @@ -64,6 +64,38 @@ func TestOktaResolveLoginDetailsWithFlags(t *testing.T) { } +func TestBrowserResolveLoginDetailsWithFlags(t *testing.T) { + + // Default state - user did not supply values for DisableCookies + commonFlags := &flags.CommonFlags{URL: "https://id.example.com", Username: "testuser", Password: "testtestlol", SkipPrompt: true} + loginFlags := &flags.LoginExecFlags{CommonFlags: commonFlags} + + idpa := &cfg.IDPAccount{ + URL: "https://id.example.com", + MFA: "none", + Provider: "Browser", + Username: "testuser", + } + loginDetails, err := resolveLoginDetails(idpa, loginFlags) + + assert.Nil(t, err) + assert.False(t, idpa.DisableCookies, fmt.Errorf("default state, DisableCookies should be false")) + assert.Equal(t, &creds.LoginDetails{Username: "testuser", Password: "testtestlol", URL: "https://id.example.com"}, loginDetails) + + // User disabled keychain, resolveLoginDetails should set the account's DisableCookies field to true + + commonFlags = &flags.CommonFlags{URL: "https://id.example.com", Username: "testuser", Password: "testtestlol", SkipPrompt: true, DisableKeychain: true} + loginFlags = &flags.LoginExecFlags{CommonFlags: commonFlags} + + loginDetails, err = resolveLoginDetails(idpa, loginFlags) + + assert.Nil(t, err) + assert.True(t, idpa.DisableCookies, fmt.Errorf("user disabled keychain, DisableCookies should be true")) + assert.Equal(t, &creds.LoginDetails{Username: "testuser", Password: "testtestlol", URL: "https://id.example.com" }, loginDetails) + +} + + func TestResolveRoleSingleEntry(t *testing.T) { adminRole := &saml2aws.AWSRole{ diff --git a/helper/credentials/saml.go b/helper/credentials/saml.go index 0f3ba65a8..3a6532072 100644 --- a/helper/credentials/saml.go +++ b/helper/credentials/saml.go @@ -33,6 +33,15 @@ func LookupCredentials(loginDetails *creds.LoginDetails, provider string) error loginDetails.ClientID = id loginDetails.ClientSecret = secret } + + if provider == "Browser" { + _, cookiesJson, err := CurrentHelper.Get(path.Join(loginDetails.URL, "/browserCookieJson")) + if err != nil { + return err + } + loginDetails.CookiesJson = cookiesJson + } + return nil } diff --git a/pkg/cfg/cfg.go b/pkg/cfg/cfg.go index c2b9277f5..a0ed872ab 100644 --- a/pkg/cfg/cfg.go +++ b/pkg/cfg/cfg.go @@ -63,6 +63,7 @@ type IDPAccount struct { TargetURL string `ini:"target_url"` DisableRememberDevice bool `ini:"disable_remember_device"` // used by Okta DisableSessions bool `ini:"disable_sessions"` // used by Okta + DisableCookies bool `ini:"disable_cookies"` // used by browser DownloadBrowser bool `ini:"download_browser_driver"` // used by browser BrowserDriverDir string `ini:"browser_driver_dir,omitempty"` // used by browser; hide from user if not set Headless bool `ini:"headless"` // used by browser diff --git a/pkg/creds/creds.go b/pkg/creds/creds.go index 9ab7d1d74..0be2ce826 100644 --- a/pkg/creds/creds.go +++ b/pkg/creds/creds.go @@ -5,6 +5,7 @@ type LoginDetails struct { ClientID string // used by OneLogin ClientSecret string // used by OneLogin DownloadBrowser bool // used by Browser + CookiesJson string // used by Browser MFAIPAddress string // used by OneLogin Username string Password string diff --git a/pkg/provider/browser/browser.go b/pkg/provider/browser/browser.go index 0681589b1..c1878933c 100644 --- a/pkg/provider/browser/browser.go +++ b/pkg/provider/browser/browser.go @@ -1,17 +1,19 @@ package browser import ( + "path" "errors" "fmt" "net/url" - "os" "regexp" "strings" + "encoding/json" "github.com/playwright-community/playwright-go" "github.com/sirupsen/logrus" "github.com/versent/saml2aws/v2/pkg/cfg" "github.com/versent/saml2aws/v2/pkg/creds" + "github.com/versent/saml2aws/v2/helper/credentials" ) var logger = logrus.WithField("provider", "browser") @@ -23,6 +25,7 @@ type Client struct { BrowserType string BrowserExecutablePath string Headless bool + DisableCookies bool // Setup alternative directory to download playwright browsers to BrowserDriverDir string Timeout int @@ -36,6 +39,7 @@ func New(idpAccount *cfg.IDPAccount) (*Client, error) { BrowserDriverDir: idpAccount.BrowserDriverDir, BrowserType: strings.ToLower(idpAccount.BrowserType), BrowserExecutablePath: idpAccount.BrowserExecutablePath, + DisableCookies: idpAccount.DisableCookies, Timeout: idpAccount.Timeout, BrowserAutoFill: idpAccount.BrowserAutoFill, }, nil @@ -113,20 +117,28 @@ func (cl *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) // create Context Optionsf contextOptions := playwright.BrowserNewContextOptions{} - // load saved storageState if present and add to contextOptions - userHomeDir, err := os.UserHomeDir() - storageStatePath := fmt.Sprintf("%s/.aws/saml2aws/storageState.json", userHomeDir) + context, err := browser.NewContext(contextOptions) if err != nil { return "", err } - if _, err := os.Stat(storageStatePath); err == nil { - contextOptions.StorageStatePath = playwright.String(storageStatePath) - } - // Create new broswer context - context, err := browser.NewContext(contextOptions) - if err != nil { - return "", err + var cookies []playwright.OptionalCookie + + if !cl.DisableCookies { + + if loginDetails.CookiesJson == "" { + logger.Info("could not retrieve cookies") + } else { + logger.Info("cookie json string length: ", len(loginDetails.CookiesJson)) + } + + if err := json.Unmarshal([]byte(loginDetails.CookiesJson), &cookies); err != nil { + logger.Info("could not unmarshal cookies", err) + } + + if err := context.AddCookies(cookies); err != nil { + logger.Info("could not add cookies", err) + } } page, err := context.NewPage() @@ -135,10 +147,24 @@ func (cl *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) } defer func() { - logger.Info("saving storage state") - _, err := context.StorageState(storageStatePath) - if err != nil { - logger.Info("Error saving storage state", err) + if !cl.DisableCookies { + logger.Info("saving storage state") + cookies, err := context.Cookies(loginDetails.URL) + + if err != nil { + logger.Info("could not get cookies", err) + } + + cookiesByteArr, err := json.Marshal(cookies) + + if err != nil { + logger.Info("Error converting storage state", err) + } + err = credentials.SaveCredentials(path.Join(loginDetails.URL, "/browserCookieJson"), loginDetails.Username, string(cookiesByteArr)) + + if err != nil { + logger.Info("Error saving storage state", err) + } } logger.Info("clean up browser") if err := context.Close(); err != nil { diff --git a/pkg/provider/browser/browser_test.go b/pkg/provider/browser/browser_test.go index af702db2a..6b21d8274 100644 --- a/pkg/provider/browser/browser_test.go +++ b/pkg/provider/browser/browser_test.go @@ -6,6 +6,7 @@ import ( "net/url" "os" "testing" + "fmt" "github.com/playwright-community/playwright-go" "github.com/stretchr/testify/assert" @@ -237,3 +238,27 @@ func TestAutoFill(t *testing.T) { assert.Equal(t, "golang:gopher", result) } } + +func TestOktaCfgFlagsDefaultState(t *testing.T) { + idpAccount := cfg.NewIDPAccount() + idpAccount.URL = "https://idp.example.com/abcd" + idpAccount.Username = "user@example.com" + + oc, err := New(idpAccount) + assert.Nil(t, err) + + assert.False(t, oc.DisableCookies, fmt.Errorf("DisableCookies should be false by default")) +} + +func TestOktaCfgFlagsCustomState(t *testing.T) { + idpAccount := cfg.NewIDPAccount() + idpAccount.URL = "https://idp.example.com/abcd" + idpAccount.Username = "user@example.com" + + idpAccount.DisableCookies = true + + oc, err := New(idpAccount) + assert.Nil(t, err) + + assert.True(t, oc.DisableCookies, fmt.Errorf("DisableCookies was set to true so DisableCookies should be true")) +} \ No newline at end of file