From 3ee0804a7b57e45ca569af9750ae1a8e21152886 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Fri, 13 Feb 2026 10:51:21 +0000 Subject: [PATCH 01/67] [py] BiDi Python code generation from CDDL --- common/bidi/spec/BUILD.bazel | 13 + common/bidi/spec/all.cddl | 2489 +++++++++++++++++ common/bidi/spec/local.cddl | 1331 +++++++++ common/bidi/spec/remote.cddl | 1716 ++++++++++++ py/AGENTS.md | 17 +- py/BUILD.bazel | 19 + py/conftest.py | 8 + py/generate_bidi.py | 1824 ++++++++++++ py/private/BUILD.bazel | 5 + py/private/bidi_enhancements_manifest.py | 1557 +++++++++++ py/private/cdp.py | 515 ++++ py/private/generate_bidi.bzl | 112 + py/requirements.txt | 1 + py/requirements_lock.txt | 5 +- py/selenium/common/exceptions.py | 62 +- py/selenium/webdriver/common/bidi/__init__.py | 21 +- py/selenium/webdriver/common/bidi/browser.py | 508 ++-- .../webdriver/common/bidi/browsing_context.py | 1505 +++++----- py/selenium/webdriver/common/bidi/common.py | 27 +- py/selenium/webdriver/common/bidi/console.py | 0 .../webdriver/common/bidi/emulation.py | 835 +++--- py/selenium/webdriver/common/bidi/input.py | 684 +++-- py/selenium/webdriver/common/bidi/log.py | 140 +- py/selenium/webdriver/common/bidi/network.py | 1151 ++++++-- .../webdriver/common/bidi/permissions.py | 85 +- py/selenium/webdriver/common/bidi/py.typed | 0 py/selenium/webdriver/common/bidi/script.py | 1539 +++++++--- py/selenium/webdriver/common/bidi/session.py | 314 ++- py/selenium/webdriver/common/bidi/storage.py | 593 ++-- .../webdriver/common/bidi/webextension.py | 142 +- py/selenium/webdriver/common/by.py | 13 +- py/selenium/webdriver/common/proxy.py | 36 +- py/selenium/webdriver/remote/webdriver.py | 202 +- .../webdriver/remote/websocket_connection.py | 37 +- .../webdriver/common/bidi_browser_tests.py | 11 +- 35 files changed, 14293 insertions(+), 3224 deletions(-) create mode 100644 common/bidi/spec/BUILD.bazel create mode 100644 common/bidi/spec/all.cddl create mode 100644 common/bidi/spec/local.cddl create mode 100644 common/bidi/spec/remote.cddl create mode 100755 py/generate_bidi.py create mode 100644 py/private/bidi_enhancements_manifest.py create mode 100644 py/private/cdp.py create mode 100644 py/private/generate_bidi.bzl mode change 100644 => 100755 py/selenium/webdriver/common/bidi/console.py create mode 100755 py/selenium/webdriver/common/bidi/py.typed diff --git a/common/bidi/spec/BUILD.bazel b/common/bidi/spec/BUILD.bazel new file mode 100644 index 0000000000000..74c3cffd35ed0 --- /dev/null +++ b/common/bidi/spec/BUILD.bazel @@ -0,0 +1,13 @@ +package( + default_visibility = [ + "//py:__pkg__", + ], +) + +exports_files( + srcs = [ + "all.cddl", + "local.cddl", + "remote.cddl", + ], +) diff --git a/common/bidi/spec/all.cddl b/common/bidi/spec/all.cddl new file mode 100644 index 0000000000000..85c4536a2cd10 --- /dev/null +++ b/common/bidi/spec/all.cddl @@ -0,0 +1,2489 @@ +Command = { + id: js-uint, + CommandData, + Extensible, +} + +CommandData = ( + BrowserCommand // + BrowsingContextCommand // + EmulationCommand // + InputCommand // + NetworkCommand // + ScriptCommand // + SessionCommand // + StorageCommand // + WebExtensionCommand +) + +EmptyParams = { + Extensible +} + +Message = ( + CommandResponse / + ErrorResponse / + Event +) + +CommandResponse = { + type: "success", + id: js-uint, + result: ResultData, + Extensible +} + +ErrorResponse = { + type: "error", + id: js-uint / null, + error: ErrorCode, + message: text, + ? stacktrace: text, + Extensible +} + +ResultData = ( + BrowserResult / + BrowsingContextResult / + EmulationResult / + InputResult / + NetworkResult / + ScriptResult / + SessionResult / + StorageResult / + WebExtensionResult +) + +EmptyResult = { + Extensible +} + +Event = { + type: "event", + EventData, + Extensible +} + +EventData = ( + BrowsingContextEvent // + InputEvent // + LogEvent // + NetworkEvent // + ScriptEvent +) + +Extensible = (*text => any) + +js-int = -9007199254740991..9007199254740991 +js-uint = 0..9007199254740991 + +ErrorCode = "invalid argument" / + "invalid selector" / + "invalid session id" / + "invalid web extension" / + "move target out of bounds" / + "no such alert" / + "no such network collector" / + "no such element" / + "no such frame" / + "no such handle" / + "no such history entry" / + "no such intercept" / + "no such network data" / + "no such node" / + "no such request" / + "no such script" / + "no such storage partition" / + "no such user context" / + "no such web extension" / + "session not created" / + "unable to capture screen" / + "unable to close browser" / + "unable to set cookie" / + "unable to set file input" / + "unavailable network data" / + "underspecified storage partition" / + "unknown command" / + "unknown error" / + "unsupported operation" + +SessionCommand = ( + session.End // + session.New // + session.Status // + session.Subscribe // + session.Unsubscribe +) + +SessionResult = ( + session.EndResult / + session.NewResult / + session.StatusResult / + session.SubscribeResult / + session.UnsubscribeResult +) + +session.CapabilitiesRequest = { + ? alwaysMatch: session.CapabilityRequest, + ? firstMatch: [*session.CapabilityRequest] +} + +session.CapabilityRequest = { + ? acceptInsecureCerts: bool, + ? browserName: text, + ? browserVersion: text, + ? platformName: text, + ? proxy: session.ProxyConfiguration, + ? unhandledPromptBehavior: session.UserPromptHandler, + Extensible +} + +session.ProxyConfiguration = { + session.AutodetectProxyConfiguration // + session.DirectProxyConfiguration // + session.ManualProxyConfiguration // + session.PacProxyConfiguration // + session.SystemProxyConfiguration +} + +session.AutodetectProxyConfiguration = ( + proxyType: "autodetect", + Extensible +) + +session.DirectProxyConfiguration = ( + proxyType: "direct", + Extensible +) + +session.ManualProxyConfiguration = ( + proxyType: "manual", + ? httpProxy: text, + ? sslProxy: text, + ? session.SocksProxyConfiguration, + ? noProxy: [*text], + Extensible +) + +session.SocksProxyConfiguration = ( + socksProxy: text, + socksVersion: 0..255, +) + +session.PacProxyConfiguration = ( + proxyType: "pac", + proxyAutoconfigUrl: text, + Extensible +) + +session.SystemProxyConfiguration = ( + proxyType: "system", + Extensible +) + + +session.UserPromptHandler = { + ? alert: session.UserPromptHandlerType, + ? beforeUnload: session.UserPromptHandlerType, + ? confirm: session.UserPromptHandlerType, + ? default: session.UserPromptHandlerType, + ? file: session.UserPromptHandlerType, + ? prompt: session.UserPromptHandlerType, +} + +session.UserPromptHandlerType = "accept" / "dismiss" / "ignore"; + +session.Subscription = text + +session.SubscribeParameters = { + events: [+text], + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +session.UnsubscribeByIDRequest = { + subscriptions: [+session.Subscription], +} + +session.UnsubscribeByAttributesRequest = { + events: [+text], +} + +session.Status = ( + method: "session.status", + params: EmptyParams, +) + +session.StatusResult = { + ready: bool, + message: text, +} + +session.New = ( + method: "session.new", + params: session.NewParameters +) + +session.NewParameters = { + capabilities: session.CapabilitiesRequest +} + +session.NewResult = { + sessionId: text, + capabilities: { + acceptInsecureCerts: bool, + browserName: text, + browserVersion: text, + platformName: text, + setWindowRect: bool, + userAgent: text, + ? proxy: session.ProxyConfiguration, + ? unhandledPromptBehavior: session.UserPromptHandler, + ? webSocketUrl: text, + Extensible + } +} + +session.End = ( + method: "session.end", + params: EmptyParams +) + + +session.EndResult = EmptyResult + +session.Subscribe = ( + method: "session.subscribe", + params: session.SubscribeParameters +) + +session.SubscribeResult = { + subscription: session.Subscription, +} + +session.Unsubscribe = ( + method: "session.unsubscribe", + params: session.UnsubscribeParameters, +) + +session.UnsubscribeParameters = session.UnsubscribeByAttributesRequest / session.UnsubscribeByIDRequest + +session.UnsubscribeResult = EmptyResult + +BrowserCommand = ( + browser.Close // + browser.CreateUserContext // + browser.GetClientWindows // + browser.GetUserContexts // + browser.RemoveUserContext // + browser.SetClientWindowState // + browser.SetDownloadBehavior +) + +BrowserResult = ( + browser.CloseResult / + browser.CreateUserContextResult / + browser.GetClientWindowsResult / + browser.GetUserContextsResult / + browser.RemoveUserContextResult / + browser.SetClientWindowStateResult / + browser.SetDownloadBehaviorResult +) + +browser.ClientWindow = text; + +browser.ClientWindowInfo = { + active: bool, + clientWindow: browser.ClientWindow, + height: js-uint, + state: "fullscreen" / "maximized" / "minimized" / "normal", + width: js-uint, + x: js-int, + y: js-int, +} + +browser.UserContext = text; + +browser.UserContextInfo = { + userContext: browser.UserContext +} + +browser.Close = ( + method: "browser.close", + params: EmptyParams, +) + +browser.CloseResult = EmptyResult + +browser.CreateUserContext = ( + method: "browser.createUserContext", + params: browser.CreateUserContextParameters, +) + +browser.CreateUserContextParameters = { + ? acceptInsecureCerts: bool, + ? proxy: session.ProxyConfiguration, + ? unhandledPromptBehavior: session.UserPromptHandler +} + +browser.CreateUserContextResult = browser.UserContextInfo + +browser.GetClientWindows = ( + method: "browser.getClientWindows", + params: EmptyParams, +) + +browser.GetClientWindowsResult = { + clientWindows: [ * browser.ClientWindowInfo] +} + +browser.GetUserContexts = ( + method: "browser.getUserContexts", + params: EmptyParams, +) + +browser.GetUserContextsResult = { + userContexts: [ + browser.UserContextInfo] +} + +browser.RemoveUserContext = ( + method: "browser.removeUserContext", + params: browser.RemoveUserContextParameters +) + +browser.RemoveUserContextParameters = { + userContext: browser.UserContext +} + +browser.RemoveUserContextResult = EmptyResult + +browser.SetClientWindowState = ( + method: "browser.setClientWindowState", + params: browser.SetClientWindowStateParameters +) + +browser.SetClientWindowStateParameters = { + clientWindow: browser.ClientWindow, + (browser.ClientWindowNamedState // browser.ClientWindowRectState) +} + +browser.ClientWindowNamedState = ( + state: "fullscreen" / "maximized" / "minimized" +) + +browser.ClientWindowRectState = ( + state: "normal", + ? width: js-uint, + ? height: js-uint, + ? x: js-int, + ? y: js-int, +) + +browser.SetClientWindowStateResult = browser.ClientWindowInfo + +browser.SetDownloadBehavior = ( + method: "browser.setDownloadBehavior", + params: browser.SetDownloadBehaviorParameters +) + +browser.SetDownloadBehaviorParameters = { + downloadBehavior: browser.DownloadBehavior / null, + ? userContexts: [+browser.UserContext] +} + +browser.DownloadBehavior = { + ( + browser.DownloadBehaviorAllowed // + browser.DownloadBehaviorDenied + ) +} + +browser.DownloadBehaviorAllowed = ( + type: "allowed", + destinationFolder: text +) + +browser.DownloadBehaviorDenied = ( + type: "denied" +) + +browser.SetDownloadBehaviorResult = EmptyResult + +BrowsingContextCommand = ( + browsingContext.Activate // + browsingContext.CaptureScreenshot // + browsingContext.Close // + browsingContext.Create // + browsingContext.GetTree // + browsingContext.HandleUserPrompt // + browsingContext.LocateNodes // + browsingContext.Navigate // + browsingContext.Print // + browsingContext.Reload // + browsingContext.SetViewport // + browsingContext.TraverseHistory +) + +BrowsingContextResult = ( + browsingContext.ActivateResult / + browsingContext.CaptureScreenshotResult / + browsingContext.CloseResult / + browsingContext.CreateResult / + browsingContext.GetTreeResult / + browsingContext.HandleUserPromptResult / + browsingContext.LocateNodesResult / + browsingContext.NavigateResult / + browsingContext.PrintResult / + browsingContext.ReloadResult / + browsingContext.SetViewportResult / + browsingContext.TraverseHistoryResult +) + +BrowsingContextEvent = ( + browsingContext.ContextCreated // + browsingContext.ContextDestroyed // + browsingContext.DomContentLoaded // + browsingContext.DownloadEnd // + browsingContext.DownloadWillBegin // + browsingContext.FragmentNavigated // + browsingContext.HistoryUpdated // + browsingContext.Load // + browsingContext.NavigationAborted // + browsingContext.NavigationCommitted // + browsingContext.NavigationFailed // + browsingContext.NavigationStarted // + browsingContext.UserPromptClosed // + browsingContext.UserPromptOpened +) + +browsingContext.BrowsingContext = text; + +browsingContext.InfoList = [*browsingContext.Info] + +browsingContext.Info = { + children: browsingContext.InfoList / null, + clientWindow: browser.ClientWindow, + context: browsingContext.BrowsingContext, + originalOpener: browsingContext.BrowsingContext / null, + url: text, + userContext: browser.UserContext, + ? parent: browsingContext.BrowsingContext / null, +} + +browsingContext.Locator = ( + browsingContext.AccessibilityLocator / + browsingContext.CssLocator / + browsingContext.ContextLocator / + browsingContext.InnerTextLocator / + browsingContext.XPathLocator +) + +browsingContext.AccessibilityLocator = { + type: "accessibility", + value: { + ? name: text, + ? role: text, + } +} + +browsingContext.CssLocator = { + type: "css", + value: text +} + +browsingContext.ContextLocator = { + type: "context", + value: { + context: browsingContext.BrowsingContext, + } +} + +browsingContext.InnerTextLocator = { + type: "innerText", + value: text, + ? ignoreCase: bool + ? matchType: "full" / "partial", + ? maxDepth: js-uint, +} + +browsingContext.XPathLocator = { + type: "xpath", + value: text +} + +browsingContext.Navigation = text; + +browsingContext.BaseNavigationInfo = ( + context: browsingContext.BrowsingContext, + navigation: browsingContext.Navigation / null, + timestamp: js-uint, + url: text, +) + +browsingContext.NavigationInfo = { + browsingContext.BaseNavigationInfo +} + +browsingContext.ReadinessState = "none" / "interactive" / "complete" + +browsingContext.UserPromptType = "alert" / "beforeunload" / "confirm" / "prompt"; + +browsingContext.Activate = ( + method: "browsingContext.activate", + params: browsingContext.ActivateParameters +) + +browsingContext.ActivateParameters = { + context: browsingContext.BrowsingContext +} + +browsingContext.ActivateResult = EmptyResult + +browsingContext.CaptureScreenshot = ( + method: "browsingContext.captureScreenshot", + params: browsingContext.CaptureScreenshotParameters +) + +browsingContext.CaptureScreenshotParameters = { + context: browsingContext.BrowsingContext, + ? origin: ("viewport" / "document") .default "viewport", + ? format: browsingContext.ImageFormat, + ? clip: browsingContext.ClipRectangle, +} + +browsingContext.ImageFormat = { + type: text, + ? quality: 0.0..1.0, +} + +browsingContext.ClipRectangle = ( + browsingContext.BoxClipRectangle / + browsingContext.ElementClipRectangle +) + +browsingContext.ElementClipRectangle = { + type: "element", + element: script.SharedReference +} + +browsingContext.BoxClipRectangle = { + type: "box", + x: float, + y: float, + width: float, + height: float +} + +browsingContext.CaptureScreenshotResult = { + data: text +} + +browsingContext.Close = ( + method: "browsingContext.close", + params: browsingContext.CloseParameters +) + +browsingContext.CloseParameters = { + context: browsingContext.BrowsingContext, + ? promptUnload: bool .default false +} + +browsingContext.CloseResult = EmptyResult + +browsingContext.Create = ( + method: "browsingContext.create", + params: browsingContext.CreateParameters +) + +browsingContext.CreateType = "tab" / "window" + +browsingContext.CreateParameters = { + type: browsingContext.CreateType, + ? referenceContext: browsingContext.BrowsingContext, + ? background: bool .default false, + ? userContext: browser.UserContext +} + +browsingContext.CreateResult = { + context: browsingContext.BrowsingContext +} + +browsingContext.GetTree = ( + method: "browsingContext.getTree", + params: browsingContext.GetTreeParameters +) + +browsingContext.GetTreeParameters = { + ? maxDepth: js-uint, + ? root: browsingContext.BrowsingContext, +} + +browsingContext.GetTreeResult = { + contexts: browsingContext.InfoList +} + +browsingContext.HandleUserPrompt = ( + method: "browsingContext.handleUserPrompt", + params: browsingContext.HandleUserPromptParameters +) + +browsingContext.HandleUserPromptParameters = { + context: browsingContext.BrowsingContext, + ? accept: bool, + ? userText: text, +} + +browsingContext.HandleUserPromptResult = EmptyResult + +browsingContext.LocateNodes = ( + method: "browsingContext.locateNodes", + params: browsingContext.LocateNodesParameters +) + +browsingContext.LocateNodesParameters = { + context: browsingContext.BrowsingContext, + locator: browsingContext.Locator, + ? maxNodeCount: (js-uint .ge 1), + ? serializationOptions: script.SerializationOptions, + ? startNodes: [ + script.SharedReference ] +} + +browsingContext.LocateNodesResult = { + nodes: [ * script.NodeRemoteValue ] +} + +browsingContext.Navigate = ( + method: "browsingContext.navigate", + params: browsingContext.NavigateParameters +) + +browsingContext.NavigateParameters = { + context: browsingContext.BrowsingContext, + url: text, + ? wait: browsingContext.ReadinessState, +} + +browsingContext.NavigateResult = { + navigation: browsingContext.Navigation / null, + url: text, +} + +browsingContext.Print = ( + method: "browsingContext.print", + params: browsingContext.PrintParameters +) + +browsingContext.PrintParameters = { + context: browsingContext.BrowsingContext, + ? background: bool .default false, + ? margin: browsingContext.PrintMarginParameters, + ? orientation: ("portrait" / "landscape") .default "portrait", + ? page: browsingContext.PrintPageParameters, + ? pageRanges: [*(js-uint / text)], + ? scale: (0.1..2.0) .default 1.0, + ? shrinkToFit: bool .default true, +} + +browsingContext.PrintMarginParameters = { + ? bottom: (float .ge 0.0) .default 1.0, + ? left: (float .ge 0.0) .default 1.0, + ? right: (float .ge 0.0) .default 1.0, + ? top: (float .ge 0.0) .default 1.0, +} + +; Minimum size is 1pt x 1pt. Conversion follows from +; https://www.w3.org/TR/css3-values/#absolute-lengths +browsingContext.PrintPageParameters = { + ? height: (float .ge 0.0352) .default 27.94, + ? width: (float .ge 0.0352) .default 21.59, +} + +browsingContext.PrintResult = { + data: text +} + +browsingContext.Reload = ( + method: "browsingContext.reload", + params: browsingContext.ReloadParameters +) + +browsingContext.ReloadParameters = { + context: browsingContext.BrowsingContext, + ? ignoreCache: bool, + ? wait: browsingContext.ReadinessState, +} + +browsingContext.ReloadResult = browsingContext.NavigateResult + +browsingContext.SetViewport = ( + method: "browsingContext.setViewport", + params: browsingContext.SetViewportParameters +) + +browsingContext.SetViewportParameters = { + ? context: browsingContext.BrowsingContext, + ? viewport: browsingContext.Viewport / null, + ? devicePixelRatio: (float .gt 0.0) / null, + ? userContexts: [+browser.UserContext], +} + +browsingContext.Viewport = { + width: js-uint, + height: js-uint, +} + +browsingContext.SetViewportResult = EmptyResult + +browsingContext.TraverseHistory = ( + method: "browsingContext.traverseHistory", + params: browsingContext.TraverseHistoryParameters +) + +browsingContext.TraverseHistoryParameters = { + context: browsingContext.BrowsingContext, + delta: js-int, +} + +browsingContext.TraverseHistoryResult = EmptyResult + +browsingContext.ContextCreated = ( + method: "browsingContext.contextCreated", + params: browsingContext.Info +) + +browsingContext.ContextDestroyed = ( + method: "browsingContext.contextDestroyed", + params: browsingContext.Info +) + +browsingContext.NavigationStarted = ( + method: "browsingContext.navigationStarted", + params: browsingContext.NavigationInfo +) + +browsingContext.FragmentNavigated = ( + method: "browsingContext.fragmentNavigated", + params: browsingContext.NavigationInfo +) + +browsingContext.HistoryUpdated = ( + method: "browsingContext.historyUpdated", + params: browsingContext.HistoryUpdatedParameters +) + +browsingContext.HistoryUpdatedParameters = { + context: browsingContext.BrowsingContext, + timestamp: js-uint, + url: text +} + +browsingContext.DomContentLoaded = ( + method: "browsingContext.domContentLoaded", + params: browsingContext.NavigationInfo +) + +browsingContext.Load = ( + method: "browsingContext.load", + params: browsingContext.NavigationInfo +) + +browsingContext.DownloadWillBegin = ( + method: "browsingContext.downloadWillBegin", + params: browsingContext.DownloadWillBeginParams +) + +browsingContext.DownloadWillBeginParams = { + suggestedFilename: text, + browsingContext.BaseNavigationInfo +} + +browsingContext.DownloadEnd = ( + method: "browsingContext.downloadEnd", + params: browsingContext.DownloadEndParams +) + +browsingContext.DownloadEndParams = { + ( + browsingContext.DownloadCanceledParams // + browsingContext.DownloadCompleteParams + ) +} + +browsingContext.DownloadCanceledParams = ( + status: "canceled", + browsingContext.BaseNavigationInfo +) + +browsingContext.DownloadCompleteParams = ( + status: "complete", + filepath: text / null, + browsingContext.BaseNavigationInfo +) + +browsingContext.NavigationAborted = ( + method: "browsingContext.navigationAborted", + params: browsingContext.NavigationInfo +) + +browsingContext.NavigationCommitted = ( + method: "browsingContext.navigationCommitted", + params: browsingContext.NavigationInfo +) + +browsingContext.NavigationFailed = ( + method: "browsingContext.navigationFailed", + params: browsingContext.NavigationInfo +) + +browsingContext.UserPromptClosed = ( + method: "browsingContext.userPromptClosed", + params: browsingContext.UserPromptClosedParameters +) + +browsingContext.UserPromptClosedParameters = { + context: browsingContext.BrowsingContext, + accepted: bool, + type: browsingContext.UserPromptType, + ? userText: text +} + +browsingContext.UserPromptOpened = ( + method: "browsingContext.userPromptOpened", + params: browsingContext.UserPromptOpenedParameters +) + +browsingContext.UserPromptOpenedParameters = { + context: browsingContext.BrowsingContext, + handler: session.UserPromptHandlerType, + message: text, + type: browsingContext.UserPromptType, + ? defaultValue: text +} + +EmulationCommand = ( + emulation.SetForcedColorsModeThemeOverride // + emulation.SetGeolocationOverride // + emulation.SetLocaleOverride // + emulation.SetNetworkConditions // + emulation.SetScreenOrientationOverride // + emulation.SetScreenSettingsOverride // + emulation.SetScriptingEnabled // + emulation.SetScrollbarTypeOverride // + emulation.SetTimezoneOverride // + emulation.SetTouchOverride // + emulation.SetUserAgentOverride // + emulation.SetViewportMetaOverride +) + + +EmulationResult = ( + emulation.SetForcedColorsModeThemeOverrideResult / + emulation.SetGeolocationOverrideResult / + emulation.SetLocaleOverrideResult / + emulation.SetScreenOrientationOverrideResult / + emulation.SetScriptingEnabledResult / + emulation.SetScrollbarTypeOverrideResult / + emulation.SetTimezoneOverrideResult / + emulation.SetTouchOverrideResult / + emulation.SetUserAgentOverrideResult / + emulation.SetViewportMetaOverrideResult +) + +emulation.SetForcedColorsModeThemeOverride = ( + method: "emulation.setForcedColorsModeThemeOverride", + params: emulation.SetForcedColorsModeThemeOverrideParameters +) + +emulation.SetForcedColorsModeThemeOverrideParameters = { + theme: emulation.ForcedColorsModeTheme / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.ForcedColorsModeTheme = "light" / "dark" + +emulation.SetForcedColorsModeThemeOverrideResult = EmptyResult + +emulation.SetGeolocationOverride = ( + method: "emulation.setGeolocationOverride", + params: emulation.SetGeolocationOverrideParameters +) + +emulation.SetGeolocationOverrideParameters = { + ( + (coordinates: emulation.GeolocationCoordinates / null) // + (error: emulation.GeolocationPositionError) + ), + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.GeolocationCoordinates = { + latitude: -90.0..90.0, + longitude: -180.0..180.0, + ? accuracy: (float .ge 0.0) .default 1.0, + ? altitude: float / null .default null, + ? altitudeAccuracy: (float .ge 0.0) / null .default null, + ? heading: (0.0...360.0) / null .default null, + ? speed: (float .ge 0.0) / null .default null, +} + +emulation.GeolocationPositionError = { + type: "positionUnavailable" +} + +emulation.SetGeolocationOverrideResult = EmptyResult + +emulation.SetLocaleOverride = ( + method: "emulation.setLocaleOverride", + params: emulation.SetLocaleOverrideParameters +) + +emulation.SetLocaleOverrideParameters = { + locale: text / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetLocaleOverrideResult = EmptyResult + +emulation.SetNetworkConditions = ( + method: "emulation.setNetworkConditions", + params: emulation.setNetworkConditionsParameters +) + +emulation.setNetworkConditionsParameters = { + networkConditions: emulation.NetworkConditions / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.NetworkConditions = emulation.NetworkConditionsOffline + +emulation.NetworkConditionsOffline = { + type: "offline" +} + +emulation.SetNetworkConditionsResult = EmptyResult + +emulation.SetScreenSettingsOverride = ( + method: "emulation.setScreenSettingsOverride", + params: emulation.SetScreenSettingsOverrideParameters +) + +emulation.ScreenArea = { + width: js-uint, + height: js-uint +} + +emulation.SetScreenSettingsOverrideParameters = { + screenArea: emulation.ScreenArea / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetScreenSettingsOverrideResult = EmptyResult + +emulation.SetScreenOrientationOverride = ( + method: "emulation.setScreenOrientationOverride", + params: emulation.SetScreenOrientationOverrideParameters +) + +emulation.ScreenOrientationNatural = "portrait" / "landscape" +emulation.ScreenOrientationType = "portrait-primary" / "portrait-secondary" / "landscape-primary" / "landscape-secondary" + +emulation.ScreenOrientation = { + natural: emulation.ScreenOrientationNatural, + type: emulation.ScreenOrientationType +} + +emulation.SetScreenOrientationOverrideParameters = { + screenOrientation: emulation.ScreenOrientation / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetScreenOrientationOverrideResult = EmptyResult + +emulation.SetUserAgentOverride = ( + method: "emulation.setUserAgentOverride", + params: emulation.SetUserAgentOverrideParameters +) + +emulation.SetUserAgentOverrideParameters = { + userAgent: text / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetUserAgentOverrideResult = EmptyResult + +emulation.SetViewportMetaOverride = ( + method: "emulation.setViewportMetaOverride", + params: emulation.SetViewportMetaOverrideParameters +) + +emulation.SetViewportMetaOverrideParameters = { + viewportMeta: true / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetViewportMetaOverrideResult = EmptyResult + +emulation.SetScriptingEnabled = ( + method: "emulation.setScriptingEnabled", + params: emulation.SetScriptingEnabledParameters +) + +emulation.SetScriptingEnabledParameters = { + enabled: false / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetScriptingEnabledResult = EmptyResult + +emulation.SetScrollbarTypeOverride = ( + method: "emulation.setScrollbarTypeOverride", + params: emulation.SetScrollbarTypeOverrideParameters +) + +emulation.SetScrollbarTypeOverrideParameters = { + scrollbarType: "classic" / "overlay" / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetScrollbarTypeOverrideResult = EmptyResult + +emulation.SetTimezoneOverride = ( + method: "emulation.setTimezoneOverride", + params: emulation.SetTimezoneOverrideParameters +) + +emulation.SetTimezoneOverrideParameters = { + timezone: text / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetTimezoneOverrideResult = EmptyResult + +emulation.SetTouchOverride = ( + method: "emulation.setTouchOverride", + params: emulation.SetTouchOverrideParameters +) + +emulation.SetTouchOverrideParameters = { + maxTouchPoints: (js-uint .ge 1) / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetTouchOverrideResult = EmptyResult + + +NetworkCommand = ( + network.AddDataCollector // + network.AddIntercept // + network.ContinueRequest // + network.ContinueResponse // + network.ContinueWithAuth // + network.DisownData // + network.FailRequest // + network.GetData // + network.ProvideResponse // + network.RemoveDataCollector // + network.RemoveIntercept // + network.SetCacheBehavior // + network.SetExtraHeaders +) + + + +NetworkResult = ( + network.AddDataCollectorResult / + network.AddInterceptResult / + network.ContinueRequestResult / + network.ContinueResponseResult / + network.ContinueWithAuthResult / + network.DisownDataResult / + network.FailRequestResult / + network.GetDataResult / + network.ProvideResponseResult / + network.RemoveDataCollectorResult / + network.RemoveInterceptResult / + network.SetCacheBehaviorResult / + network.SetExtraHeadersResult +) + +NetworkEvent = ( + network.AuthRequired // + network.BeforeRequestSent // + network.FetchError // + network.ResponseCompleted // + network.ResponseStarted +) + + +network.AuthChallenge = { + scheme: text, + realm: text, +} + +network.AuthCredentials = { + type: "password", + username: text, + password: text, +} + +network.BaseParameters = ( + context: browsingContext.BrowsingContext / null, + isBlocked: bool, + navigation: browsingContext.Navigation / null, + redirectCount: js-uint, + request: network.RequestData, + timestamp: js-uint, + ? intercepts: [+network.Intercept] +) + +network.BytesValue = network.StringValue / network.Base64Value; + +network.StringValue = { + type: "string", + value: text, +} + +network.Base64Value = { + type: "base64", + value: text, +} + +network.Collector = text + +network.CollectorType = "blob" + + +network.SameSite = "strict" / "lax" / "none" / "default" + + +network.Cookie = { + name: text, + value: network.BytesValue, + domain: text, + path: text, + size: js-uint, + httpOnly: bool, + secure: bool, + sameSite: network.SameSite, + ? expiry: js-uint, + Extensible, +} + +network.CookieHeader = { + name: text, + value: network.BytesValue, +} + +network.DataType = "request" / "response" + +network.FetchTimingInfo = { + timeOrigin: float, + requestTime: float, + redirectStart: float, + redirectEnd: float, + fetchStart: float, + dnsStart: float, + dnsEnd: float, + connectStart: float, + connectEnd: float, + tlsStart: float, + + requestStart: float, + responseStart: float, + + responseEnd: float, +} + +network.Header = { + name: text, + value: network.BytesValue, +} + +network.Initiator = { + ? columnNumber: js-uint, + ? lineNumber: js-uint, + ? request: network.Request, + ? stackTrace: script.StackTrace, + ? type: "parser" / "script" / "preflight" / "other" +} + +network.Intercept = text + +network.Request = text; + +network.RequestData = { + request: network.Request, + url: text, + method: text, + headers: [*network.Header], + cookies: [*network.Cookie], + headersSize: js-uint, + bodySize: js-uint / null, + destination: text, + initiatorType: text / null, + timings: network.FetchTimingInfo, +} + +network.ResponseContent = { + size: js-uint +} + +network.ResponseData = { + url: text, + protocol: text, + status: js-uint, + statusText: text, + fromCache: bool, + headers: [*network.Header], + mimeType: text, + bytesReceived: js-uint, + headersSize: js-uint / null, + bodySize: js-uint / null, + content: network.ResponseContent, + ?authChallenges: [*network.AuthChallenge], +} + + +network.SetCookieHeader = { + name: text, + value: network.BytesValue, + ? domain: text, + ? httpOnly: bool, + ? expiry: text, + ? maxAge: js-int, + ? path: text, + ? sameSite: network.SameSite, + ? secure: bool, +} + +network.UrlPattern = ( + network.UrlPatternPattern / + network.UrlPatternString +) + +network.UrlPatternPattern = { + type: "pattern", + ?protocol: text, + ?hostname: text, + ?port: text, + ?pathname: text, + ?search: text, +} + + +network.UrlPatternString = { + type: "string", + pattern: text, +} + + +network.AddDataCollector = ( + method: "network.addDataCollector", + params: network.AddDataCollectorParameters +) + +network.AddDataCollectorParameters = { + dataTypes: [+network.DataType], + maxEncodedDataSize: js-uint, + ? collectorType: network.CollectorType .default "blob", + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +network.AddDataCollectorResult = { + collector: network.Collector +} + +network.AddIntercept = ( + method: "network.addIntercept", + params: network.AddInterceptParameters +) + +network.AddInterceptParameters = { + phases: [+network.InterceptPhase], + ? contexts: [+browsingContext.BrowsingContext], + ? urlPatterns: [*network.UrlPattern], +} + +network.InterceptPhase = "beforeRequestSent" / "responseStarted" / + "authRequired" + +network.AddInterceptResult = { + intercept: network.Intercept +} + +network.ContinueRequest = ( + method: "network.continueRequest", + params: network.ContinueRequestParameters +) + +network.ContinueRequestParameters = { + request: network.Request, + ?body: network.BytesValue, + ?cookies: [*network.CookieHeader], + ?headers: [*network.Header], + ?method: text, + ?url: text, +} + +network.ContinueRequestResult = EmptyResult + +network.ContinueResponse = ( + method: "network.continueResponse", + params: network.ContinueResponseParameters +) + +network.ContinueResponseParameters = { + request: network.Request, + ?cookies: [*network.SetCookieHeader] + ?credentials: network.AuthCredentials, + ?headers: [*network.Header], + ?reasonPhrase: text, + ?statusCode: js-uint, +} + +network.ContinueResponseResult = EmptyResult + +network.ContinueWithAuth = ( + method: "network.continueWithAuth", + params: network.ContinueWithAuthParameters +) + +network.ContinueWithAuthParameters = { + request: network.Request, + (network.ContinueWithAuthCredentials // network.ContinueWithAuthNoCredentials) +} + +network.ContinueWithAuthCredentials = ( + action: "provideCredentials", + credentials: network.AuthCredentials +) + +network.ContinueWithAuthNoCredentials = ( + action: "default" / "cancel" +) + +network.ContinueWithAuthResult = EmptyResult + +network.DisownData = ( + method: "network.disownData", + params: network.disownDataParameters +) + +network.disownDataParameters = { + dataType: network.DataType, + collector: network.Collector, + request: network.Request, +} + +network.DisownDataResult = EmptyResult + +network.FailRequest = ( + method: "network.failRequest", + params: network.FailRequestParameters +) + +network.FailRequestParameters = { + request: network.Request, +} + +network.FailRequestResult = EmptyResult + +network.GetData = ( + method: "network.getData", + params: network.GetDataParameters +) + +network.GetDataParameters = { + dataType: network.DataType, + ? collector: network.Collector, + ? disown: bool .default false, + request: network.Request, +} + +network.GetDataResult = { + bytes: network.BytesValue, +} + +network.ProvideResponse = ( + method: "network.provideResponse", + params: network.ProvideResponseParameters +) + +network.ProvideResponseParameters = { + request: network.Request, + ?body: network.BytesValue, + ?cookies: [*network.SetCookieHeader], + ?headers: [*network.Header], + ?reasonPhrase: text, + ?statusCode: js-uint, +} + +network.ProvideResponseResult = EmptyResult + +network.RemoveDataCollector = ( + method: "network.removeDataCollector", + params: network.RemoveDataCollectorParameters +) + +network.RemoveDataCollectorParameters = { + collector: network.Collector +} + +network.RemoveDataCollectorResult = EmptyResult + +network.RemoveIntercept = ( + method: "network.removeIntercept", + params: network.RemoveInterceptParameters +) + +network.RemoveInterceptParameters = { + intercept: network.Intercept +} + +network.RemoveInterceptResult = EmptyResult + +network.SetCacheBehavior = ( + method: "network.setCacheBehavior", + params: network.SetCacheBehaviorParameters +) + +network.SetCacheBehaviorParameters = { + cacheBehavior: "default" / "bypass", + ? contexts: [+browsingContext.BrowsingContext] +} + +network.SetCacheBehaviorResult = EmptyResult + +network.SetExtraHeaders = ( + method: "network.setExtraHeaders", + params: network.SetExtraHeadersParameters +) + +network.SetExtraHeadersParameters = { + headers: [*network.Header] + ? contexts: [+browsingContext.BrowsingContext] + ? userContexts: [+browser.UserContext] +} + +network.SetExtraHeadersResult = EmptyResult + +network.AuthRequired = ( + method: "network.authRequired", + params: network.AuthRequiredParameters +) + +network.AuthRequiredParameters = { + network.BaseParameters, + response: network.ResponseData +} + + network.BeforeRequestSent = ( + method: "network.beforeRequestSent", + params: network.BeforeRequestSentParameters + ) + +network.BeforeRequestSentParameters = { + network.BaseParameters, + ? initiator: network.Initiator, +} + + network.FetchError = ( + method: "network.fetchError", + params: network.FetchErrorParameters + ) + +network.FetchErrorParameters = { + network.BaseParameters, + errorText: text, +} + + network.ResponseCompleted = ( + method: "network.responseCompleted", + params: network.ResponseCompletedParameters + ) + +network.ResponseCompletedParameters = { + network.BaseParameters, + response: network.ResponseData, +} + + network.ResponseStarted = ( + method: "network.responseStarted", + params: network.ResponseStartedParameters + ) + +network.ResponseStartedParameters = { + network.BaseParameters, + response: network.ResponseData, +} + +ScriptCommand = ( + script.AddPreloadScript // + script.CallFunction // + script.Disown // + script.Evaluate // + script.GetRealms // + script.RemovePreloadScript +) + +ScriptResult = ( + script.AddPreloadScriptResult / + script.CallFunctionResult / + script.DisownResult / + script.EvaluateResult / + script.GetRealmsResult / + script.RemovePreloadScriptResult +) + +ScriptEvent = ( + script.Message // + script.RealmCreated // + script.RealmDestroyed +) + +script.Channel = text; + +script.ChannelValue = { + type: "channel", + value: script.ChannelProperties, +} + +script.ChannelProperties = { + channel: script.Channel, + ? serializationOptions: script.SerializationOptions, + ? ownership: script.ResultOwnership, +} + +script.EvaluateResult = ( + script.EvaluateResultSuccess / + script.EvaluateResultException +) + +script.EvaluateResultSuccess = { + type: "success", + result: script.RemoteValue, + realm: script.Realm +} + +script.EvaluateResultException = { + type: "exception", + exceptionDetails: script.ExceptionDetails + realm: script.Realm +} + +script.ExceptionDetails = { + columnNumber: js-uint, + exception: script.RemoteValue, + lineNumber: js-uint, + stackTrace: script.StackTrace, + text: text, +} + +script.Handle = text; + +script.InternalId = text; + +script.LocalValue = ( + script.RemoteReference / + script.PrimitiveProtocolValue / + script.ChannelValue / + script.ArrayLocalValue / + { script.DateLocalValue } / + script.MapLocalValue / + script.ObjectLocalValue / + { script.RegExpLocalValue } / + script.SetLocalValue +) + +script.ListLocalValue = [*script.LocalValue]; + +script.ArrayLocalValue = { + type: "array", + value: script.ListLocalValue, +} + +script.DateLocalValue = ( + type: "date", + value: text +) + +script.MappingLocalValue = [*[(script.LocalValue / text), script.LocalValue]]; + +script.MapLocalValue = { + type: "map", + value: script.MappingLocalValue, +} + +script.ObjectLocalValue = { + type: "object", + value: script.MappingLocalValue, +} + +script.RegExpValue = { + pattern: text, + ? flags: text, +} + +script.RegExpLocalValue = ( + type: "regexp", + value: script.RegExpValue, +) + +script.SetLocalValue = { + type: "set", + value: script.ListLocalValue, +} + +script.PreloadScript = text; + +script.Realm = text; + +script.PrimitiveProtocolValue = ( + script.UndefinedValue / + script.NullValue / + script.StringValue / + script.NumberValue / + script.BooleanValue / + script.BigIntValue +) + +script.UndefinedValue = { + type: "undefined", +} + +script.NullValue = { + type: "null", +} + +script.StringValue = { + type: "string", + value: text, +} + +script.SpecialNumber = "NaN" / "-0" / "Infinity" / "-Infinity"; + +script.NumberValue = { + type: "number", + value: number / script.SpecialNumber, +} + +script.BooleanValue = { + type: "boolean", + value: bool, +} + +script.BigIntValue = { + type: "bigint", + value: text, +} + +script.RealmInfo = ( + script.WindowRealmInfo / + script.DedicatedWorkerRealmInfo / + script.SharedWorkerRealmInfo / + script.ServiceWorkerRealmInfo / + script.WorkerRealmInfo / + script.PaintWorkletRealmInfo / + script.AudioWorkletRealmInfo / + script.WorkletRealmInfo +) + +script.BaseRealmInfo = ( + realm: script.Realm, + origin: text +) + +script.WindowRealmInfo = { + script.BaseRealmInfo, + type: "window", + context: browsingContext.BrowsingContext, + ? sandbox: text +} + +script.DedicatedWorkerRealmInfo = { + script.BaseRealmInfo, + type: "dedicated-worker", + owners: [script.Realm] +} + +script.SharedWorkerRealmInfo = { + script.BaseRealmInfo, + type: "shared-worker" +} + +script.ServiceWorkerRealmInfo = { + script.BaseRealmInfo, + type: "service-worker" +} + +script.WorkerRealmInfo = { + script.BaseRealmInfo, + type: "worker" +} + +script.PaintWorkletRealmInfo = { + script.BaseRealmInfo, + type: "paint-worklet" +} + +script.AudioWorkletRealmInfo = { + script.BaseRealmInfo, + type: "audio-worklet" +} + +script.WorkletRealmInfo = { + script.BaseRealmInfo, + type: "worklet" +} + +script.RealmType = "window" / "dedicated-worker" / "shared-worker" / "service-worker" / + "worker" / "paint-worklet" / "audio-worklet" / "worklet" + + + +script.RemoteReference = ( + script.SharedReference / + script.RemoteObjectReference +) + +script.SharedReference = { + sharedId: script.SharedId + + ? handle: script.Handle, + Extensible +} + +script.RemoteObjectReference = { + handle: script.Handle, + + ? sharedId: script.SharedId + Extensible +} + +script.RemoteValue = ( + script.PrimitiveProtocolValue / + script.SymbolRemoteValue / + script.ArrayRemoteValue / + script.ObjectRemoteValue / + script.FunctionRemoteValue / + script.RegExpRemoteValue / + script.DateRemoteValue / + script.MapRemoteValue / + script.SetRemoteValue / + script.WeakMapRemoteValue / + script.WeakSetRemoteValue / + script.GeneratorRemoteValue / + script.ErrorRemoteValue / + script.ProxyRemoteValue / + script.PromiseRemoteValue / + script.TypedArrayRemoteValue / + script.ArrayBufferRemoteValue / + script.NodeListRemoteValue / + script.HTMLCollectionRemoteValue / + script.NodeRemoteValue / + script.WindowProxyRemoteValue +) + +script.ListRemoteValue = [*script.RemoteValue]; + +script.MappingRemoteValue = [*[(script.RemoteValue / text), script.RemoteValue]]; + +script.SymbolRemoteValue = { + type: "symbol", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ArrayRemoteValue = { + type: "array", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.ObjectRemoteValue = { + type: "object", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.MappingRemoteValue, +} + +script.FunctionRemoteValue = { + type: "function", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.RegExpRemoteValue = { + script.RegExpLocalValue, + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.DateRemoteValue = { + script.DateLocalValue, + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.MapRemoteValue = { + type: "map", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.MappingRemoteValue, +} + +script.SetRemoteValue = { + type: "set", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue +} + +script.WeakMapRemoteValue = { + type: "weakmap", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.WeakSetRemoteValue = { + type: "weakset", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.GeneratorRemoteValue = { + type: "generator", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ErrorRemoteValue = { + type: "error", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ProxyRemoteValue = { + type: "proxy", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.PromiseRemoteValue = { + type: "promise", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.TypedArrayRemoteValue = { + type: "typedarray", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ArrayBufferRemoteValue = { + type: "arraybuffer", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.NodeListRemoteValue = { + type: "nodelist", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.HTMLCollectionRemoteValue = { + type: "htmlcollection", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.NodeRemoteValue = { + type: "node", + ? sharedId: script.SharedId, + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.NodeProperties, +} + +script.NodeProperties = { + nodeType: js-uint, + childNodeCount: js-uint, + ? attributes: {*text => text}, + ? children: [*script.NodeRemoteValue], + ? localName: text, + ? mode: "open" / "closed", + ? namespaceURI: text, + ? nodeValue: text, + ? shadowRoot: script.NodeRemoteValue / null, +} + +script.WindowProxyRemoteValue = { + type: "window", + value: script.WindowProxyProperties, + ? handle: script.Handle, + ? internalId: script.InternalId +} + +script.WindowProxyProperties = { + context: browsingContext.BrowsingContext +} + +script.ResultOwnership = "root" / "none" + +script.SerializationOptions = { + ? maxDomDepth: (js-uint / null) .default 0, + ? maxObjectDepth: (js-uint / null) .default null, + ? includeShadowTree: ("none" / "open" / "all") .default "none", +} + +script.SharedId = text; + +script.StackFrame = { + columnNumber: js-uint, + functionName: text, + lineNumber: js-uint, + url: text, +} + +script.StackTrace = { + callFrames: [*script.StackFrame], +} + +script.Source = { + realm: script.Realm, + ? context: browsingContext.BrowsingContext +} + +script.RealmTarget = { + realm: script.Realm +} + +script.ContextTarget = { + context: browsingContext.BrowsingContext, + ? sandbox: text +} + +script.Target = ( + script.ContextTarget / + script.RealmTarget +) + +script.AddPreloadScript = ( + method: "script.addPreloadScript", + params: script.AddPreloadScriptParameters +) + +script.AddPreloadScriptParameters = { + functionDeclaration: text, + ? arguments: [*script.ChannelValue], + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], + ? sandbox: text +} + +script.AddPreloadScriptResult = { + script: script.PreloadScript +} + +script.Disown = ( + method: "script.disown", + params: script.DisownParameters +) + +script.DisownParameters = { + handles: [*script.Handle] + target: script.Target; +} + +script.DisownResult = EmptyResult + +script.CallFunction = ( + method: "script.callFunction", + params: script.CallFunctionParameters +) + +script.CallFunctionParameters = { + functionDeclaration: text, + awaitPromise: bool, + target: script.Target, + ? arguments: [*script.LocalValue], + ? resultOwnership: script.ResultOwnership, + ? serializationOptions: script.SerializationOptions, + ? this: script.LocalValue, + ? userActivation: bool .default false, +} + +script.CallFunctionResult = script.EvaluateResult + +script.Evaluate = ( + method: "script.evaluate", + params: script.EvaluateParameters +) + +script.EvaluateParameters = { + expression: text, + target: script.Target, + awaitPromise: bool, + ? resultOwnership: script.ResultOwnership, + ? serializationOptions: script.SerializationOptions, + ? userActivation: bool .default false, +} + +script.GetRealms = ( + method: "script.getRealms", + params: script.GetRealmsParameters +) + +script.GetRealmsParameters = { + ? context: browsingContext.BrowsingContext, + ? type: script.RealmType, +} + +script.GetRealmsResult = { + realms: [*script.RealmInfo] +} + +script.RemovePreloadScript = ( + method: "script.removePreloadScript", + params: script.RemovePreloadScriptParameters +) + +script.RemovePreloadScriptParameters = { + script: script.PreloadScript +} + +script.RemovePreloadScriptResult = EmptyResult + + script.Message = ( + method: "script.message", + params: script.MessageParameters + ) + +script.MessageParameters = { + channel: script.Channel, + data: script.RemoteValue, + source: script.Source, +} + +script.RealmCreated = ( + method: "script.realmCreated", + params: script.RealmInfo +) + +script.RealmDestroyed = ( + method: "script.realmDestroyed", + params: script.RealmDestroyedParameters +) + +script.RealmDestroyedParameters = { + realm: script.Realm +} + + +StorageCommand = ( + storage.DeleteCookies // + storage.GetCookies // + storage.SetCookie +) + +StorageResult = ( + storage.DeleteCookiesResult / + storage.GetCookiesResult / + storage.SetCookieResult +) + +storage.PartitionKey = { + ? userContext: text, + ? sourceOrigin: text, + Extensible, +} + +storage.GetCookies = ( + method: "storage.getCookies", + params: storage.GetCookiesParameters +) + + +storage.CookieFilter = { + ? name: text, + ? value: network.BytesValue, + ? domain: text, + ? path: text, + ? size: js-uint, + ? httpOnly: bool, + ? secure: bool, + ? sameSite: network.SameSite, + ? expiry: js-uint, + Extensible, +} + +storage.BrowsingContextPartitionDescriptor = { + type: "context", + context: browsingContext.BrowsingContext +} + +storage.StorageKeyPartitionDescriptor = { + type: "storageKey", + ? userContext: text, + ? sourceOrigin: text, + Extensible, +} + +storage.PartitionDescriptor = ( + storage.BrowsingContextPartitionDescriptor / + storage.StorageKeyPartitionDescriptor +) + +storage.GetCookiesParameters = { + ? filter: storage.CookieFilter, + ? partition: storage.PartitionDescriptor, +} + +storage.GetCookiesResult = { + cookies: [*network.Cookie], + partitionKey: storage.PartitionKey, +} + +storage.SetCookie = ( + method: "storage.setCookie", + params: storage.SetCookieParameters, +) + + +storage.PartialCookie = { + name: text, + value: network.BytesValue, + domain: text, + ? path: text, + ? httpOnly: bool, + ? secure: bool, + ? sameSite: network.SameSite, + ? expiry: js-uint, + Extensible, +} + +storage.SetCookieParameters = { + cookie: storage.PartialCookie, + ? partition: storage.PartitionDescriptor, +} + +storage.SetCookieResult = { + partitionKey: storage.PartitionKey +} + +storage.DeleteCookies = ( + method: "storage.deleteCookies", + params: storage.DeleteCookiesParameters, +) + +storage.DeleteCookiesParameters = { + ? filter: storage.CookieFilter, + ? partition: storage.PartitionDescriptor, +} + +storage.DeleteCookiesResult = { + partitionKey: storage.PartitionKey +} + +LogEvent = ( + log.EntryAdded +) + +log.Level = "debug" / "info" / "warn" / "error" + +log.Entry = ( + log.GenericLogEntry / + log.ConsoleLogEntry / + log.JavascriptLogEntry +) + +log.BaseLogEntry = ( + level: log.Level, + source: script.Source, + text: text / null, + timestamp: js-uint, + ? stackTrace: script.StackTrace, +) + +log.GenericLogEntry = { + log.BaseLogEntry, + type: text, +} + +log.ConsoleLogEntry = { + log.BaseLogEntry, + type: "console", + method: text, + args: [*script.RemoteValue], +} + +log.JavascriptLogEntry = { + log.BaseLogEntry, + type: "javascript", +} + +log.EntryAdded = ( + method: "log.entryAdded", + params: log.Entry, +) + +InputCommand = ( + input.PerformActions // + input.ReleaseActions // + input.SetFiles +) + +InputResult = ( + input.PerformActionsResult / + input.ReleaseActionsResult / + input.SetFilesResult +) + + +InputEvent = ( + input.FileDialogOpened +) + +input.ElementOrigin = { + type: "element", + element: script.SharedReference +} + +input.PerformActions = ( + method: "input.performActions", + params: input.PerformActionsParameters +) + +input.PerformActionsParameters = { + context: browsingContext.BrowsingContext, + actions: [*input.SourceActions] +} + +input.SourceActions = ( + input.NoneSourceActions / + input.KeySourceActions / + input.PointerSourceActions / + input.WheelSourceActions +) + +input.NoneSourceActions = { + type: "none", + id: text, + actions: [*input.NoneSourceAction] +} + +input.NoneSourceAction = input.PauseAction + +input.KeySourceActions = { + type: "key", + id: text, + actions: [*input.KeySourceAction] +} + +input.KeySourceAction = ( + input.PauseAction / + input.KeyDownAction / + input.KeyUpAction +) + +input.PointerSourceActions = { + type: "pointer", + id: text, + ? parameters: input.PointerParameters, + actions: [*input.PointerSourceAction] +} + +input.PointerType = "mouse" / "pen" / "touch" + +input.PointerParameters = { + ? pointerType: input.PointerType .default "mouse" +} + +input.PointerSourceAction = ( + input.PauseAction / + input.PointerDownAction / + input.PointerUpAction / + input.PointerMoveAction +) + +input.WheelSourceActions = { + type: "wheel", + id: text, + actions: [*input.WheelSourceAction] +} + +input.WheelSourceAction = ( + input.PauseAction / + input.WheelScrollAction +) + +input.PauseAction = { + type: "pause", + ? duration: js-uint +} + +input.KeyDownAction = { + type: "keyDown", + value: text +} + +input.KeyUpAction = { + type: "keyUp", + value: text +} + +input.PointerUpAction = { + type: "pointerUp", + button: js-uint, +} + +input.PointerDownAction = { + type: "pointerDown", + button: js-uint, + input.PointerCommonProperties +} + +input.PointerMoveAction = { + type: "pointerMove", + x: float, + y: float, + ? duration: js-uint, + ? origin: input.Origin, + input.PointerCommonProperties +} + +input.WheelScrollAction = { + type: "scroll", + x: js-int, + y: js-int, + deltaX: js-int, + deltaY: js-int, + ? duration: js-uint, + ? origin: input.Origin .default "viewport", +} + +input.PointerCommonProperties = ( + ? width: js-uint .default 1, + ? height: js-uint .default 1, + ? pressure: float .default 0.0, + ? tangentialPressure: float .default 0.0, + ? twist: (0..359) .default 0, + ; 0 .. Math.PI / 2 + ? altitudeAngle: (0.0..1.5707963267948966) .default 0.0, + ; 0 .. 2 * Math.PI + ? azimuthAngle: (0.0..6.283185307179586) .default 0.0, +) + +input.Origin = "viewport" / "pointer" / input.ElementOrigin + +input.PerformActionsResult = EmptyResult + +input.ReleaseActions = ( + method: "input.releaseActions", + params: input.ReleaseActionsParameters +) + +input.ReleaseActionsParameters = { + context: browsingContext.BrowsingContext, +} + +input.ReleaseActionsResult = EmptyResult + +input.SetFiles = ( + method: "input.setFiles", + params: input.SetFilesParameters +) + +input.SetFilesParameters = { + context: browsingContext.BrowsingContext, + element: script.SharedReference, + files: [*text] +} + +input.SetFilesResult = EmptyResult + +input.FileDialogOpened = ( + method: "input.fileDialogOpened", + params: input.FileDialogInfo +) + +input.FileDialogInfo = { + context: browsingContext.BrowsingContext, + ? element: script.SharedReference, + multiple: bool, +} + +WebExtensionCommand = ( + webExtension.Install // + webExtension.Uninstall +) + +WebExtensionResult = ( + webExtension.InstallResult / + webExtension.UninstallResult +) + +webExtension.Extension = text + +webExtension.Install = ( + method: "webExtension.install", + params: webExtension.InstallParameters +) + +webExtension.InstallParameters = { + extensionData: webExtension.ExtensionData, +} + +webExtension.ExtensionData = ( + webExtension.ExtensionArchivePath / + webExtension.ExtensionBase64Encoded / + webExtension.ExtensionPath +) + +webExtension.ExtensionPath = { + type: "path", + path: text, +} + +webExtension.ExtensionArchivePath = { + type: "archivePath", + path: text, +} + +webExtension.ExtensionBase64Encoded = { + type: "base64", + value: text, +} + +webExtension.InstallResult = { + extension: webExtension.Extension +} + +webExtension.Uninstall = ( + method: "webExtension.uninstall", + params: webExtension.UninstallParameters +) + +webExtension.UninstallParameters = { + extension: webExtension.Extension, +} + +webExtension.UninstallResult = EmptyResult diff --git a/common/bidi/spec/local.cddl b/common/bidi/spec/local.cddl new file mode 100644 index 0000000000000..d43af0ae11b03 --- /dev/null +++ b/common/bidi/spec/local.cddl @@ -0,0 +1,1331 @@ +Message = ( + CommandResponse / + ErrorResponse / + Event +) + +CommandResponse = { + type: "success", + id: js-uint, + result: ResultData, + Extensible +} + +ErrorResponse = { + type: "error", + id: js-uint / null, + error: ErrorCode, + message: text, + ? stacktrace: text, + Extensible +} + +ResultData = ( + BrowserResult / + BrowsingContextResult / + EmulationResult / + InputResult / + NetworkResult / + ScriptResult / + SessionResult / + StorageResult / + WebExtensionResult +) + +EmptyResult = { + Extensible +} + +Event = { + type: "event", + EventData, + Extensible +} + +EventData = ( + BrowsingContextEvent // + InputEvent // + LogEvent // + NetworkEvent // + ScriptEvent +) + +Extensible = (*text => any) + +js-int = -9007199254740991..9007199254740991 +js-uint = 0..9007199254740991 + +ErrorCode = "invalid argument" / + "invalid selector" / + "invalid session id" / + "invalid web extension" / + "move target out of bounds" / + "no such alert" / + "no such network collector" / + "no such element" / + "no such frame" / + "no such handle" / + "no such history entry" / + "no such intercept" / + "no such network data" / + "no such node" / + "no such request" / + "no such script" / + "no such storage partition" / + "no such user context" / + "no such web extension" / + "session not created" / + "unable to capture screen" / + "unable to close browser" / + "unable to set cookie" / + "unable to set file input" / + "unavailable network data" / + "underspecified storage partition" / + "unknown command" / + "unknown error" / + "unsupported operation" + +SessionResult = ( + session.EndResult / + session.NewResult / + session.StatusResult / + session.SubscribeResult / + session.UnsubscribeResult +) + +session.CapabilitiesRequest = { + ? alwaysMatch: session.CapabilityRequest, + ? firstMatch: [*session.CapabilityRequest] +} + +session.CapabilityRequest = { + ? acceptInsecureCerts: bool, + ? browserName: text, + ? browserVersion: text, + ? platformName: text, + ? proxy: session.ProxyConfiguration, + ? unhandledPromptBehavior: session.UserPromptHandler, + Extensible +} + +session.ProxyConfiguration = { + session.AutodetectProxyConfiguration // + session.DirectProxyConfiguration // + session.ManualProxyConfiguration // + session.PacProxyConfiguration // + session.SystemProxyConfiguration +} + +session.AutodetectProxyConfiguration = ( + proxyType: "autodetect", + Extensible +) + +session.DirectProxyConfiguration = ( + proxyType: "direct", + Extensible +) + +session.ManualProxyConfiguration = ( + proxyType: "manual", + ? httpProxy: text, + ? sslProxy: text, + ? session.SocksProxyConfiguration, + ? noProxy: [*text], + Extensible +) + +session.SocksProxyConfiguration = ( + socksProxy: text, + socksVersion: 0..255, +) + +session.PacProxyConfiguration = ( + proxyType: "pac", + proxyAutoconfigUrl: text, + Extensible +) + +session.SystemProxyConfiguration = ( + proxyType: "system", + Extensible +) + + +session.UserPromptHandler = { + ? alert: session.UserPromptHandlerType, + ? beforeUnload: session.UserPromptHandlerType, + ? confirm: session.UserPromptHandlerType, + ? default: session.UserPromptHandlerType, + ? file: session.UserPromptHandlerType, + ? prompt: session.UserPromptHandlerType, +} + +session.UserPromptHandlerType = "accept" / "dismiss" / "ignore"; + +session.Subscription = text + +session.StatusResult = { + ready: bool, + message: text, +} + +session.NewResult = { + sessionId: text, + capabilities: { + acceptInsecureCerts: bool, + browserName: text, + browserVersion: text, + platformName: text, + setWindowRect: bool, + userAgent: text, + ? proxy: session.ProxyConfiguration, + ? unhandledPromptBehavior: session.UserPromptHandler, + ? webSocketUrl: text, + Extensible + } +} + +session.EndResult = EmptyResult + +session.SubscribeResult = { + subscription: session.Subscription, +} + +session.UnsubscribeResult = EmptyResult + +BrowserResult = ( + browser.CloseResult / + browser.CreateUserContextResult / + browser.GetClientWindowsResult / + browser.GetUserContextsResult / + browser.RemoveUserContextResult / + browser.SetClientWindowStateResult / + browser.SetDownloadBehaviorResult +) + +browser.ClientWindow = text; + +browser.ClientWindowInfo = { + active: bool, + clientWindow: browser.ClientWindow, + height: js-uint, + state: "fullscreen" / "maximized" / "minimized" / "normal", + width: js-uint, + x: js-int, + y: js-int, +} + +browser.UserContext = text; + +browser.UserContextInfo = { + userContext: browser.UserContext +} + +browser.CloseResult = EmptyResult + +browser.CreateUserContextResult = browser.UserContextInfo + +browser.GetClientWindowsResult = { + clientWindows: [ * browser.ClientWindowInfo] +} + +browser.GetUserContextsResult = { + userContexts: [ + browser.UserContextInfo] +} + +browser.RemoveUserContextResult = EmptyResult + +browser.SetClientWindowStateResult = browser.ClientWindowInfo + +browser.SetDownloadBehaviorResult = EmptyResult + +BrowsingContextResult = ( + browsingContext.ActivateResult / + browsingContext.CaptureScreenshotResult / + browsingContext.CloseResult / + browsingContext.CreateResult / + browsingContext.GetTreeResult / + browsingContext.HandleUserPromptResult / + browsingContext.LocateNodesResult / + browsingContext.NavigateResult / + browsingContext.PrintResult / + browsingContext.ReloadResult / + browsingContext.SetViewportResult / + browsingContext.TraverseHistoryResult +) + +BrowsingContextEvent = ( + browsingContext.ContextCreated // + browsingContext.ContextDestroyed // + browsingContext.DomContentLoaded // + browsingContext.DownloadEnd // + browsingContext.DownloadWillBegin // + browsingContext.FragmentNavigated // + browsingContext.HistoryUpdated // + browsingContext.Load // + browsingContext.NavigationAborted // + browsingContext.NavigationCommitted // + browsingContext.NavigationFailed // + browsingContext.NavigationStarted // + browsingContext.UserPromptClosed // + browsingContext.UserPromptOpened +) + +browsingContext.BrowsingContext = text; + +browsingContext.InfoList = [*browsingContext.Info] + +browsingContext.Info = { + children: browsingContext.InfoList / null, + clientWindow: browser.ClientWindow, + context: browsingContext.BrowsingContext, + originalOpener: browsingContext.BrowsingContext / null, + url: text, + userContext: browser.UserContext, + ? parent: browsingContext.BrowsingContext / null, +} + +browsingContext.Locator = ( + browsingContext.AccessibilityLocator / + browsingContext.CssLocator / + browsingContext.ContextLocator / + browsingContext.InnerTextLocator / + browsingContext.XPathLocator +) + +browsingContext.AccessibilityLocator = { + type: "accessibility", + value: { + ? name: text, + ? role: text, + } +} + +browsingContext.CssLocator = { + type: "css", + value: text +} + +browsingContext.ContextLocator = { + type: "context", + value: { + context: browsingContext.BrowsingContext, + } +} + +browsingContext.InnerTextLocator = { + type: "innerText", + value: text, + ? ignoreCase: bool + ? matchType: "full" / "partial", + ? maxDepth: js-uint, +} + +browsingContext.XPathLocator = { + type: "xpath", + value: text +} + +browsingContext.Navigation = text; + +browsingContext.BaseNavigationInfo = ( + context: browsingContext.BrowsingContext, + navigation: browsingContext.Navigation / null, + timestamp: js-uint, + url: text, +) + +browsingContext.NavigationInfo = { + browsingContext.BaseNavigationInfo +} + +browsingContext.UserPromptType = "alert" / "beforeunload" / "confirm" / "prompt"; + +browsingContext.ActivateResult = EmptyResult + +browsingContext.CaptureScreenshotResult = { + data: text +} + +browsingContext.CloseResult = EmptyResult + +browsingContext.CreateResult = { + context: browsingContext.BrowsingContext +} + +browsingContext.GetTreeResult = { + contexts: browsingContext.InfoList +} + +browsingContext.HandleUserPromptResult = EmptyResult + +browsingContext.LocateNodesResult = { + nodes: [ * script.NodeRemoteValue ] +} + +browsingContext.NavigateResult = { + navigation: browsingContext.Navigation / null, + url: text, +} + +browsingContext.PrintResult = { + data: text +} + +browsingContext.ReloadResult = browsingContext.NavigateResult + +browsingContext.SetViewportResult = EmptyResult + +browsingContext.TraverseHistoryResult = EmptyResult + +browsingContext.ContextCreated = ( + method: "browsingContext.contextCreated", + params: browsingContext.Info +) + +browsingContext.ContextDestroyed = ( + method: "browsingContext.contextDestroyed", + params: browsingContext.Info +) + +browsingContext.NavigationStarted = ( + method: "browsingContext.navigationStarted", + params: browsingContext.NavigationInfo +) + +browsingContext.FragmentNavigated = ( + method: "browsingContext.fragmentNavigated", + params: browsingContext.NavigationInfo +) + +browsingContext.HistoryUpdated = ( + method: "browsingContext.historyUpdated", + params: browsingContext.HistoryUpdatedParameters +) + +browsingContext.HistoryUpdatedParameters = { + context: browsingContext.BrowsingContext, + timestamp: js-uint, + url: text +} + +browsingContext.DomContentLoaded = ( + method: "browsingContext.domContentLoaded", + params: browsingContext.NavigationInfo +) + +browsingContext.Load = ( + method: "browsingContext.load", + params: browsingContext.NavigationInfo +) + +browsingContext.DownloadWillBegin = ( + method: "browsingContext.downloadWillBegin", + params: browsingContext.DownloadWillBeginParams +) + +browsingContext.DownloadWillBeginParams = { + suggestedFilename: text, + browsingContext.BaseNavigationInfo +} + +browsingContext.DownloadEnd = ( + method: "browsingContext.downloadEnd", + params: browsingContext.DownloadEndParams +) + +browsingContext.DownloadEndParams = { + ( + browsingContext.DownloadCanceledParams // + browsingContext.DownloadCompleteParams + ) +} + +browsingContext.DownloadCanceledParams = ( + status: "canceled", + browsingContext.BaseNavigationInfo +) + +browsingContext.DownloadCompleteParams = ( + status: "complete", + filepath: text / null, + browsingContext.BaseNavigationInfo +) + +browsingContext.NavigationAborted = ( + method: "browsingContext.navigationAborted", + params: browsingContext.NavigationInfo +) + +browsingContext.NavigationCommitted = ( + method: "browsingContext.navigationCommitted", + params: browsingContext.NavigationInfo +) + +browsingContext.NavigationFailed = ( + method: "browsingContext.navigationFailed", + params: browsingContext.NavigationInfo +) + +browsingContext.UserPromptClosed = ( + method: "browsingContext.userPromptClosed", + params: browsingContext.UserPromptClosedParameters +) + +browsingContext.UserPromptClosedParameters = { + context: browsingContext.BrowsingContext, + accepted: bool, + type: browsingContext.UserPromptType, + ? userText: text +} + +browsingContext.UserPromptOpened = ( + method: "browsingContext.userPromptOpened", + params: browsingContext.UserPromptOpenedParameters +) + +browsingContext.UserPromptOpenedParameters = { + context: browsingContext.BrowsingContext, + handler: session.UserPromptHandlerType, + message: text, + type: browsingContext.UserPromptType, + ? defaultValue: text +} + +EmulationResult = ( + emulation.SetForcedColorsModeThemeOverrideResult / + emulation.SetGeolocationOverrideResult / + emulation.SetLocaleOverrideResult / + emulation.SetScreenOrientationOverrideResult / + emulation.SetScriptingEnabledResult / + emulation.SetScrollbarTypeOverrideResult / + emulation.SetTimezoneOverrideResult / + emulation.SetTouchOverrideResult / + emulation.SetUserAgentOverrideResult / + emulation.SetViewportMetaOverrideResult +) + +emulation.SetForcedColorsModeThemeOverrideResult = EmptyResult + +emulation.SetGeolocationOverrideResult = EmptyResult + +emulation.SetLocaleOverrideResult = EmptyResult + +emulation.SetNetworkConditionsResult = EmptyResult + +emulation.SetScreenSettingsOverrideResult = EmptyResult + +emulation.SetScreenOrientationOverrideResult = EmptyResult + +emulation.SetUserAgentOverrideResult = EmptyResult + +emulation.SetViewportMetaOverrideResult = EmptyResult + +emulation.SetScriptingEnabledResult = EmptyResult + +emulation.SetScrollbarTypeOverrideResult = EmptyResult + +emulation.SetTimezoneOverrideResult = EmptyResult + +emulation.SetTouchOverrideResult = EmptyResult + + +NetworkResult = ( + network.AddDataCollectorResult / + network.AddInterceptResult / + network.ContinueRequestResult / + network.ContinueResponseResult / + network.ContinueWithAuthResult / + network.DisownDataResult / + network.FailRequestResult / + network.GetDataResult / + network.ProvideResponseResult / + network.RemoveDataCollectorResult / + network.RemoveInterceptResult / + network.SetCacheBehaviorResult / + network.SetExtraHeadersResult +) + +NetworkEvent = ( + network.AuthRequired // + network.BeforeRequestSent // + network.FetchError // + network.ResponseCompleted // + network.ResponseStarted +) + + +network.AuthChallenge = { + scheme: text, + realm: text, +} + +network.BaseParameters = ( + context: browsingContext.BrowsingContext / null, + isBlocked: bool, + navigation: browsingContext.Navigation / null, + redirectCount: js-uint, + request: network.RequestData, + timestamp: js-uint, + ? intercepts: [+network.Intercept] +) + +network.BytesValue = network.StringValue / network.Base64Value; + +network.StringValue = { + type: "string", + value: text, +} + +network.Base64Value = { + type: "base64", + value: text, +} + +network.Collector = text + +network.CollectorType = "blob" + + +network.SameSite = "strict" / "lax" / "none" / "default" + + +network.Cookie = { + name: text, + value: network.BytesValue, + domain: text, + path: text, + size: js-uint, + httpOnly: bool, + secure: bool, + sameSite: network.SameSite, + ? expiry: js-uint, + Extensible, +} + +network.DataType = "request" / "response" + +network.FetchTimingInfo = { + timeOrigin: float, + requestTime: float, + redirectStart: float, + redirectEnd: float, + fetchStart: float, + dnsStart: float, + dnsEnd: float, + connectStart: float, + connectEnd: float, + tlsStart: float, + + requestStart: float, + responseStart: float, + + responseEnd: float, +} + +network.Header = { + name: text, + value: network.BytesValue, +} + +network.Initiator = { + ? columnNumber: js-uint, + ? lineNumber: js-uint, + ? request: network.Request, + ? stackTrace: script.StackTrace, + ? type: "parser" / "script" / "preflight" / "other" +} + +network.Intercept = text + +network.Request = text; + +network.RequestData = { + request: network.Request, + url: text, + method: text, + headers: [*network.Header], + cookies: [*network.Cookie], + headersSize: js-uint, + bodySize: js-uint / null, + destination: text, + initiatorType: text / null, + timings: network.FetchTimingInfo, +} + +network.ResponseContent = { + size: js-uint +} + +network.ResponseData = { + url: text, + protocol: text, + status: js-uint, + statusText: text, + fromCache: bool, + headers: [*network.Header], + mimeType: text, + bytesReceived: js-uint, + headersSize: js-uint / null, + bodySize: js-uint / null, + content: network.ResponseContent, + ?authChallenges: [*network.AuthChallenge], +} + +network.AddDataCollectorResult = { + collector: network.Collector +} + +network.AddInterceptResult = { + intercept: network.Intercept +} + +network.ContinueRequestResult = EmptyResult + +network.ContinueResponseResult = EmptyResult + +network.ContinueWithAuthResult = EmptyResult + +network.DisownDataResult = EmptyResult + +network.FailRequestResult = EmptyResult + +network.GetDataResult = { + bytes: network.BytesValue, +} + +network.ProvideResponseResult = EmptyResult + +network.RemoveDataCollectorResult = EmptyResult + +network.RemoveInterceptResult = EmptyResult + +network.SetCacheBehaviorResult = EmptyResult + +network.SetExtraHeadersResult = EmptyResult + +network.AuthRequired = ( + method: "network.authRequired", + params: network.AuthRequiredParameters +) + +network.AuthRequiredParameters = { + network.BaseParameters, + response: network.ResponseData +} + + network.BeforeRequestSent = ( + method: "network.beforeRequestSent", + params: network.BeforeRequestSentParameters + ) + +network.BeforeRequestSentParameters = { + network.BaseParameters, + ? initiator: network.Initiator, +} + + network.FetchError = ( + method: "network.fetchError", + params: network.FetchErrorParameters + ) + +network.FetchErrorParameters = { + network.BaseParameters, + errorText: text, +} + + network.ResponseCompleted = ( + method: "network.responseCompleted", + params: network.ResponseCompletedParameters + ) + +network.ResponseCompletedParameters = { + network.BaseParameters, + response: network.ResponseData, +} + + network.ResponseStarted = ( + method: "network.responseStarted", + params: network.ResponseStartedParameters + ) + +network.ResponseStartedParameters = { + network.BaseParameters, + response: network.ResponseData, +} + +ScriptResult = ( + script.AddPreloadScriptResult / + script.CallFunctionResult / + script.DisownResult / + script.EvaluateResult / + script.GetRealmsResult / + script.RemovePreloadScriptResult +) + +ScriptEvent = ( + script.Message // + script.RealmCreated // + script.RealmDestroyed +) + +script.Channel = text; + +script.ChannelValue = { + type: "channel", + value: script.ChannelProperties, +} + +script.ChannelProperties = { + channel: script.Channel, + ? serializationOptions: script.SerializationOptions, + ? ownership: script.ResultOwnership, +} + +script.EvaluateResult = ( + script.EvaluateResultSuccess / + script.EvaluateResultException +) + +script.EvaluateResultSuccess = { + type: "success", + result: script.RemoteValue, + realm: script.Realm +} + +script.EvaluateResultException = { + type: "exception", + exceptionDetails: script.ExceptionDetails + realm: script.Realm +} + +script.ExceptionDetails = { + columnNumber: js-uint, + exception: script.RemoteValue, + lineNumber: js-uint, + stackTrace: script.StackTrace, + text: text, +} + +script.Handle = text; + +script.InternalId = text; + +script.LocalValue = ( + script.RemoteReference / + script.PrimitiveProtocolValue / + script.ChannelValue / + script.ArrayLocalValue / + { script.DateLocalValue } / + script.MapLocalValue / + script.ObjectLocalValue / + { script.RegExpLocalValue } / + script.SetLocalValue +) + +script.ListLocalValue = [*script.LocalValue]; + +script.ArrayLocalValue = { + type: "array", + value: script.ListLocalValue, +} + +script.DateLocalValue = ( + type: "date", + value: text +) + +script.MappingLocalValue = [*[(script.LocalValue / text), script.LocalValue]]; + +script.MapLocalValue = { + type: "map", + value: script.MappingLocalValue, +} + +script.ObjectLocalValue = { + type: "object", + value: script.MappingLocalValue, +} + +script.RegExpValue = { + pattern: text, + ? flags: text, +} + +script.RegExpLocalValue = ( + type: "regexp", + value: script.RegExpValue, +) + +script.SetLocalValue = { + type: "set", + value: script.ListLocalValue, +} + +script.PreloadScript = text; + +script.Realm = text; + +script.PrimitiveProtocolValue = ( + script.UndefinedValue / + script.NullValue / + script.StringValue / + script.NumberValue / + script.BooleanValue / + script.BigIntValue +) + +script.UndefinedValue = { + type: "undefined", +} + +script.NullValue = { + type: "null", +} + +script.StringValue = { + type: "string", + value: text, +} + +script.SpecialNumber = "NaN" / "-0" / "Infinity" / "-Infinity"; + +script.NumberValue = { + type: "number", + value: number / script.SpecialNumber, +} + +script.BooleanValue = { + type: "boolean", + value: bool, +} + +script.BigIntValue = { + type: "bigint", + value: text, +} + +script.RealmInfo = ( + script.WindowRealmInfo / + script.DedicatedWorkerRealmInfo / + script.SharedWorkerRealmInfo / + script.ServiceWorkerRealmInfo / + script.WorkerRealmInfo / + script.PaintWorkletRealmInfo / + script.AudioWorkletRealmInfo / + script.WorkletRealmInfo +) + +script.BaseRealmInfo = ( + realm: script.Realm, + origin: text +) + +script.WindowRealmInfo = { + script.BaseRealmInfo, + type: "window", + context: browsingContext.BrowsingContext, + ? sandbox: text +} + +script.DedicatedWorkerRealmInfo = { + script.BaseRealmInfo, + type: "dedicated-worker", + owners: [script.Realm] +} + +script.SharedWorkerRealmInfo = { + script.BaseRealmInfo, + type: "shared-worker" +} + +script.ServiceWorkerRealmInfo = { + script.BaseRealmInfo, + type: "service-worker" +} + +script.WorkerRealmInfo = { + script.BaseRealmInfo, + type: "worker" +} + +script.PaintWorkletRealmInfo = { + script.BaseRealmInfo, + type: "paint-worklet" +} + +script.AudioWorkletRealmInfo = { + script.BaseRealmInfo, + type: "audio-worklet" +} + +script.WorkletRealmInfo = { + script.BaseRealmInfo, + type: "worklet" +} + +script.RealmType = "window" / "dedicated-worker" / "shared-worker" / "service-worker" / + "worker" / "paint-worklet" / "audio-worklet" / "worklet" + + + +script.RemoteReference = ( + script.SharedReference / + script.RemoteObjectReference +) + +script.SharedReference = { + sharedId: script.SharedId + + ? handle: script.Handle, + Extensible +} + +script.RemoteObjectReference = { + handle: script.Handle, + + ? sharedId: script.SharedId + Extensible +} + +script.RemoteValue = ( + script.PrimitiveProtocolValue / + script.SymbolRemoteValue / + script.ArrayRemoteValue / + script.ObjectRemoteValue / + script.FunctionRemoteValue / + script.RegExpRemoteValue / + script.DateRemoteValue / + script.MapRemoteValue / + script.SetRemoteValue / + script.WeakMapRemoteValue / + script.WeakSetRemoteValue / + script.GeneratorRemoteValue / + script.ErrorRemoteValue / + script.ProxyRemoteValue / + script.PromiseRemoteValue / + script.TypedArrayRemoteValue / + script.ArrayBufferRemoteValue / + script.NodeListRemoteValue / + script.HTMLCollectionRemoteValue / + script.NodeRemoteValue / + script.WindowProxyRemoteValue +) + +script.ListRemoteValue = [*script.RemoteValue]; + +script.MappingRemoteValue = [*[(script.RemoteValue / text), script.RemoteValue]]; + +script.SymbolRemoteValue = { + type: "symbol", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ArrayRemoteValue = { + type: "array", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.ObjectRemoteValue = { + type: "object", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.MappingRemoteValue, +} + +script.FunctionRemoteValue = { + type: "function", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.RegExpRemoteValue = { + script.RegExpLocalValue, + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.DateRemoteValue = { + script.DateLocalValue, + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.MapRemoteValue = { + type: "map", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.MappingRemoteValue, +} + +script.SetRemoteValue = { + type: "set", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue +} + +script.WeakMapRemoteValue = { + type: "weakmap", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.WeakSetRemoteValue = { + type: "weakset", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.GeneratorRemoteValue = { + type: "generator", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ErrorRemoteValue = { + type: "error", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ProxyRemoteValue = { + type: "proxy", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.PromiseRemoteValue = { + type: "promise", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.TypedArrayRemoteValue = { + type: "typedarray", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ArrayBufferRemoteValue = { + type: "arraybuffer", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.NodeListRemoteValue = { + type: "nodelist", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.HTMLCollectionRemoteValue = { + type: "htmlcollection", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.NodeRemoteValue = { + type: "node", + ? sharedId: script.SharedId, + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.NodeProperties, +} + +script.NodeProperties = { + nodeType: js-uint, + childNodeCount: js-uint, + ? attributes: {*text => text}, + ? children: [*script.NodeRemoteValue], + ? localName: text, + ? mode: "open" / "closed", + ? namespaceURI: text, + ? nodeValue: text, + ? shadowRoot: script.NodeRemoteValue / null, +} + +script.WindowProxyRemoteValue = { + type: "window", + value: script.WindowProxyProperties, + ? handle: script.Handle, + ? internalId: script.InternalId +} + +script.WindowProxyProperties = { + context: browsingContext.BrowsingContext +} + +script.ResultOwnership = "root" / "none" + +script.SerializationOptions = { + ? maxDomDepth: (js-uint / null) .default 0, + ? maxObjectDepth: (js-uint / null) .default null, + ? includeShadowTree: ("none" / "open" / "all") .default "none", +} + +script.SharedId = text; + +script.StackFrame = { + columnNumber: js-uint, + functionName: text, + lineNumber: js-uint, + url: text, +} + +script.StackTrace = { + callFrames: [*script.StackFrame], +} + +script.Source = { + realm: script.Realm, + ? context: browsingContext.BrowsingContext +} + +script.AddPreloadScriptResult = { + script: script.PreloadScript +} + +script.DisownResult = EmptyResult + +script.CallFunctionResult = script.EvaluateResult + +script.GetRealmsResult = { + realms: [*script.RealmInfo] +} + +script.RemovePreloadScriptResult = EmptyResult + + script.Message = ( + method: "script.message", + params: script.MessageParameters + ) + +script.MessageParameters = { + channel: script.Channel, + data: script.RemoteValue, + source: script.Source, +} + +script.RealmCreated = ( + method: "script.realmCreated", + params: script.RealmInfo +) + +script.RealmDestroyed = ( + method: "script.realmDestroyed", + params: script.RealmDestroyedParameters +) + +script.RealmDestroyedParameters = { + realm: script.Realm +} + + +StorageResult = ( + storage.DeleteCookiesResult / + storage.GetCookiesResult / + storage.SetCookieResult +) + +storage.PartitionKey = { + ? userContext: text, + ? sourceOrigin: text, + Extensible, +} + +storage.GetCookiesResult = { + cookies: [*network.Cookie], + partitionKey: storage.PartitionKey, +} + +storage.SetCookieResult = { + partitionKey: storage.PartitionKey +} + +storage.DeleteCookiesResult = { + partitionKey: storage.PartitionKey +} + +LogEvent = ( + log.EntryAdded +) + +log.Level = "debug" / "info" / "warn" / "error" + +log.Entry = ( + log.GenericLogEntry / + log.ConsoleLogEntry / + log.JavascriptLogEntry +) + +log.BaseLogEntry = ( + level: log.Level, + source: script.Source, + text: text / null, + timestamp: js-uint, + ? stackTrace: script.StackTrace, +) + +log.GenericLogEntry = { + log.BaseLogEntry, + type: text, +} + +log.ConsoleLogEntry = { + log.BaseLogEntry, + type: "console", + method: text, + args: [*script.RemoteValue], +} + +log.JavascriptLogEntry = { + log.BaseLogEntry, + type: "javascript", +} + +log.EntryAdded = ( + method: "log.entryAdded", + params: log.Entry, +) + + +InputEvent = ( + input.FileDialogOpened +) + +input.PerformActionsResult = EmptyResult + +input.ReleaseActionsResult = EmptyResult + +input.SetFilesResult = EmptyResult + +input.FileDialogOpened = ( + method: "input.fileDialogOpened", + params: input.FileDialogInfo +) + +input.FileDialogInfo = { + context: browsingContext.BrowsingContext, + ? element: script.SharedReference, + multiple: bool, +} + +WebExtensionResult = ( + webExtension.InstallResult / + webExtension.UninstallResult +) + +webExtension.Extension = text + +webExtension.InstallResult = { + extension: webExtension.Extension +} + +webExtension.UninstallResult = EmptyResult diff --git a/common/bidi/spec/remote.cddl b/common/bidi/spec/remote.cddl new file mode 100644 index 0000000000000..a98859a021e12 --- /dev/null +++ b/common/bidi/spec/remote.cddl @@ -0,0 +1,1716 @@ +Command = { + id: js-uint, + CommandData, + Extensible, +} + +CommandData = ( + BrowserCommand // + BrowsingContextCommand // + EmulationCommand // + InputCommand // + NetworkCommand // + ScriptCommand // + SessionCommand // + StorageCommand // + WebExtensionCommand +) + +EmptyParams = { + Extensible +} + +Extensible = (*text => any) + +js-int = -9007199254740991..9007199254740991 +js-uint = 0..9007199254740991 + +SessionCommand = ( + session.End // + session.New // + session.Status // + session.Subscribe // + session.Unsubscribe +) + +session.CapabilitiesRequest = { + ? alwaysMatch: session.CapabilityRequest, + ? firstMatch: [*session.CapabilityRequest] +} + +session.CapabilityRequest = { + ? acceptInsecureCerts: bool, + ? browserName: text, + ? browserVersion: text, + ? platformName: text, + ? proxy: session.ProxyConfiguration, + ? unhandledPromptBehavior: session.UserPromptHandler, + Extensible +} + +session.ProxyConfiguration = { + session.AutodetectProxyConfiguration // + session.DirectProxyConfiguration // + session.ManualProxyConfiguration // + session.PacProxyConfiguration // + session.SystemProxyConfiguration +} + +session.AutodetectProxyConfiguration = ( + proxyType: "autodetect", + Extensible +) + +session.DirectProxyConfiguration = ( + proxyType: "direct", + Extensible +) + +session.ManualProxyConfiguration = ( + proxyType: "manual", + ? httpProxy: text, + ? sslProxy: text, + ? session.SocksProxyConfiguration, + ? noProxy: [*text], + Extensible +) + +session.SocksProxyConfiguration = ( + socksProxy: text, + socksVersion: 0..255, +) + +session.PacProxyConfiguration = ( + proxyType: "pac", + proxyAutoconfigUrl: text, + Extensible +) + +session.SystemProxyConfiguration = ( + proxyType: "system", + Extensible +) + + +session.UserPromptHandler = { + ? alert: session.UserPromptHandlerType, + ? beforeUnload: session.UserPromptHandlerType, + ? confirm: session.UserPromptHandlerType, + ? default: session.UserPromptHandlerType, + ? file: session.UserPromptHandlerType, + ? prompt: session.UserPromptHandlerType, +} + +session.UserPromptHandlerType = "accept" / "dismiss" / "ignore"; + +session.Subscription = text + +session.SubscribeParameters = { + events: [+text], + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +session.UnsubscribeByIDRequest = { + subscriptions: [+session.Subscription], +} + +session.UnsubscribeByAttributesRequest = { + events: [+text], +} + +session.Status = ( + method: "session.status", + params: EmptyParams, +) + +session.New = ( + method: "session.new", + params: session.NewParameters +) + +session.NewParameters = { + capabilities: session.CapabilitiesRequest +} + +session.End = ( + method: "session.end", + params: EmptyParams +) + + +session.Subscribe = ( + method: "session.subscribe", + params: session.SubscribeParameters +) + +session.Unsubscribe = ( + method: "session.unsubscribe", + params: session.UnsubscribeParameters, +) + +session.UnsubscribeParameters = session.UnsubscribeByAttributesRequest / session.UnsubscribeByIDRequest + +BrowserCommand = ( + browser.Close // + browser.CreateUserContext // + browser.GetClientWindows // + browser.GetUserContexts // + browser.RemoveUserContext // + browser.SetClientWindowState // + browser.SetDownloadBehavior +) + +browser.ClientWindow = text; + +browser.ClientWindowInfo = { + active: bool, + clientWindow: browser.ClientWindow, + height: js-uint, + state: "fullscreen" / "maximized" / "minimized" / "normal", + width: js-uint, + x: js-int, + y: js-int, +} + +browser.UserContext = text; + +browser.UserContextInfo = { + userContext: browser.UserContext +} + +browser.Close = ( + method: "browser.close", + params: EmptyParams, +) + +browser.CreateUserContext = ( + method: "browser.createUserContext", + params: browser.CreateUserContextParameters, +) + +browser.CreateUserContextParameters = { + ? acceptInsecureCerts: bool, + ? proxy: session.ProxyConfiguration, + ? unhandledPromptBehavior: session.UserPromptHandler +} + +browser.GetClientWindows = ( + method: "browser.getClientWindows", + params: EmptyParams, +) + +browser.GetUserContexts = ( + method: "browser.getUserContexts", + params: EmptyParams, +) + +browser.RemoveUserContext = ( + method: "browser.removeUserContext", + params: browser.RemoveUserContextParameters +) + +browser.RemoveUserContextParameters = { + userContext: browser.UserContext +} + +browser.SetClientWindowState = ( + method: "browser.setClientWindowState", + params: browser.SetClientWindowStateParameters +) + +browser.SetClientWindowStateParameters = { + clientWindow: browser.ClientWindow, + (browser.ClientWindowNamedState // browser.ClientWindowRectState) +} + +browser.ClientWindowNamedState = ( + state: "fullscreen" / "maximized" / "minimized" +) + +browser.ClientWindowRectState = ( + state: "normal", + ? width: js-uint, + ? height: js-uint, + ? x: js-int, + ? y: js-int, +) + +browser.SetDownloadBehavior = ( + method: "browser.setDownloadBehavior", + params: browser.SetDownloadBehaviorParameters +) + +browser.SetDownloadBehaviorParameters = { + downloadBehavior: browser.DownloadBehavior / null, + ? userContexts: [+browser.UserContext] +} + +browser.DownloadBehavior = { + ( + browser.DownloadBehaviorAllowed // + browser.DownloadBehaviorDenied + ) +} + +browser.DownloadBehaviorAllowed = ( + type: "allowed", + destinationFolder: text +) + +browser.DownloadBehaviorDenied = ( + type: "denied" +) + +BrowsingContextCommand = ( + browsingContext.Activate // + browsingContext.CaptureScreenshot // + browsingContext.Close // + browsingContext.Create // + browsingContext.GetTree // + browsingContext.HandleUserPrompt // + browsingContext.LocateNodes // + browsingContext.Navigate // + browsingContext.Print // + browsingContext.Reload // + browsingContext.SetViewport // + browsingContext.TraverseHistory +) + +browsingContext.BrowsingContext = text; + +browsingContext.Locator = ( + browsingContext.AccessibilityLocator / + browsingContext.CssLocator / + browsingContext.ContextLocator / + browsingContext.InnerTextLocator / + browsingContext.XPathLocator +) + +browsingContext.AccessibilityLocator = { + type: "accessibility", + value: { + ? name: text, + ? role: text, + } +} + +browsingContext.CssLocator = { + type: "css", + value: text +} + +browsingContext.ContextLocator = { + type: "context", + value: { + context: browsingContext.BrowsingContext, + } +} + +browsingContext.InnerTextLocator = { + type: "innerText", + value: text, + ? ignoreCase: bool + ? matchType: "full" / "partial", + ? maxDepth: js-uint, +} + +browsingContext.XPathLocator = { + type: "xpath", + value: text +} + +browsingContext.Navigation = text; + +browsingContext.ReadinessState = "none" / "interactive" / "complete" + +browsingContext.UserPromptType = "alert" / "beforeunload" / "confirm" / "prompt"; + +browsingContext.Activate = ( + method: "browsingContext.activate", + params: browsingContext.ActivateParameters +) + +browsingContext.ActivateParameters = { + context: browsingContext.BrowsingContext +} + +browsingContext.CaptureScreenshot = ( + method: "browsingContext.captureScreenshot", + params: browsingContext.CaptureScreenshotParameters +) + +browsingContext.CaptureScreenshotParameters = { + context: browsingContext.BrowsingContext, + ? origin: ("viewport" / "document") .default "viewport", + ? format: browsingContext.ImageFormat, + ? clip: browsingContext.ClipRectangle, +} + +browsingContext.ImageFormat = { + type: text, + ? quality: 0.0..1.0, +} + +browsingContext.ClipRectangle = ( + browsingContext.BoxClipRectangle / + browsingContext.ElementClipRectangle +) + +browsingContext.ElementClipRectangle = { + type: "element", + element: script.SharedReference +} + +browsingContext.BoxClipRectangle = { + type: "box", + x: float, + y: float, + width: float, + height: float +} + +browsingContext.Close = ( + method: "browsingContext.close", + params: browsingContext.CloseParameters +) + +browsingContext.CloseParameters = { + context: browsingContext.BrowsingContext, + ? promptUnload: bool .default false +} + +browsingContext.Create = ( + method: "browsingContext.create", + params: browsingContext.CreateParameters +) + +browsingContext.CreateType = "tab" / "window" + +browsingContext.CreateParameters = { + type: browsingContext.CreateType, + ? referenceContext: browsingContext.BrowsingContext, + ? background: bool .default false, + ? userContext: browser.UserContext +} + +browsingContext.GetTree = ( + method: "browsingContext.getTree", + params: browsingContext.GetTreeParameters +) + +browsingContext.GetTreeParameters = { + ? maxDepth: js-uint, + ? root: browsingContext.BrowsingContext, +} + +browsingContext.HandleUserPrompt = ( + method: "browsingContext.handleUserPrompt", + params: browsingContext.HandleUserPromptParameters +) + +browsingContext.HandleUserPromptParameters = { + context: browsingContext.BrowsingContext, + ? accept: bool, + ? userText: text, +} + +browsingContext.LocateNodes = ( + method: "browsingContext.locateNodes", + params: browsingContext.LocateNodesParameters +) + +browsingContext.LocateNodesParameters = { + context: browsingContext.BrowsingContext, + locator: browsingContext.Locator, + ? maxNodeCount: (js-uint .ge 1), + ? serializationOptions: script.SerializationOptions, + ? startNodes: [ + script.SharedReference ] +} + +browsingContext.Navigate = ( + method: "browsingContext.navigate", + params: browsingContext.NavigateParameters +) + +browsingContext.NavigateParameters = { + context: browsingContext.BrowsingContext, + url: text, + ? wait: browsingContext.ReadinessState, +} + +browsingContext.Print = ( + method: "browsingContext.print", + params: browsingContext.PrintParameters +) + +browsingContext.PrintParameters = { + context: browsingContext.BrowsingContext, + ? background: bool .default false, + ? margin: browsingContext.PrintMarginParameters, + ? orientation: ("portrait" / "landscape") .default "portrait", + ? page: browsingContext.PrintPageParameters, + ? pageRanges: [*(js-uint / text)], + ? scale: (0.1..2.0) .default 1.0, + ? shrinkToFit: bool .default true, +} + +browsingContext.PrintMarginParameters = { + ? bottom: (float .ge 0.0) .default 1.0, + ? left: (float .ge 0.0) .default 1.0, + ? right: (float .ge 0.0) .default 1.0, + ? top: (float .ge 0.0) .default 1.0, +} + +; Minimum size is 1pt x 1pt. Conversion follows from +; https://www.w3.org/TR/css3-values/#absolute-lengths +browsingContext.PrintPageParameters = { + ? height: (float .ge 0.0352) .default 27.94, + ? width: (float .ge 0.0352) .default 21.59, +} + +browsingContext.Reload = ( + method: "browsingContext.reload", + params: browsingContext.ReloadParameters +) + +browsingContext.ReloadParameters = { + context: browsingContext.BrowsingContext, + ? ignoreCache: bool, + ? wait: browsingContext.ReadinessState, +} + +browsingContext.SetViewport = ( + method: "browsingContext.setViewport", + params: browsingContext.SetViewportParameters +) + +browsingContext.SetViewportParameters = { + ? context: browsingContext.BrowsingContext, + ? viewport: browsingContext.Viewport / null, + ? devicePixelRatio: (float .gt 0.0) / null, + ? userContexts: [+browser.UserContext], +} + +browsingContext.Viewport = { + width: js-uint, + height: js-uint, +} + +browsingContext.TraverseHistory = ( + method: "browsingContext.traverseHistory", + params: browsingContext.TraverseHistoryParameters +) + +browsingContext.TraverseHistoryParameters = { + context: browsingContext.BrowsingContext, + delta: js-int, +} + +EmulationCommand = ( + emulation.SetForcedColorsModeThemeOverride // + emulation.SetGeolocationOverride // + emulation.SetLocaleOverride // + emulation.SetNetworkConditions // + emulation.SetScreenOrientationOverride // + emulation.SetScreenSettingsOverride // + emulation.SetScriptingEnabled // + emulation.SetScrollbarTypeOverride // + emulation.SetTimezoneOverride // + emulation.SetTouchOverride // + emulation.SetUserAgentOverride // + emulation.SetViewportMetaOverride +) + + +emulation.SetForcedColorsModeThemeOverride = ( + method: "emulation.setForcedColorsModeThemeOverride", + params: emulation.SetForcedColorsModeThemeOverrideParameters +) + +emulation.SetForcedColorsModeThemeOverrideParameters = { + theme: emulation.ForcedColorsModeTheme / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.ForcedColorsModeTheme = "light" / "dark" + +emulation.SetGeolocationOverride = ( + method: "emulation.setGeolocationOverride", + params: emulation.SetGeolocationOverrideParameters +) + +emulation.SetGeolocationOverrideParameters = { + ( + (coordinates: emulation.GeolocationCoordinates / null) // + (error: emulation.GeolocationPositionError) + ), + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.GeolocationCoordinates = { + latitude: -90.0..90.0, + longitude: -180.0..180.0, + ? accuracy: (float .ge 0.0) .default 1.0, + ? altitude: float / null .default null, + ? altitudeAccuracy: (float .ge 0.0) / null .default null, + ? heading: (0.0...360.0) / null .default null, + ? speed: (float .ge 0.0) / null .default null, +} + +emulation.GeolocationPositionError = { + type: "positionUnavailable" +} + +emulation.SetLocaleOverride = ( + method: "emulation.setLocaleOverride", + params: emulation.SetLocaleOverrideParameters +) + +emulation.SetLocaleOverrideParameters = { + locale: text / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetNetworkConditions = ( + method: "emulation.setNetworkConditions", + params: emulation.setNetworkConditionsParameters +) + +emulation.setNetworkConditionsParameters = { + networkConditions: emulation.NetworkConditions / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.NetworkConditions = emulation.NetworkConditionsOffline + +emulation.NetworkConditionsOffline = { + type: "offline" +} + +emulation.SetScreenSettingsOverride = ( + method: "emulation.setScreenSettingsOverride", + params: emulation.SetScreenSettingsOverrideParameters +) + +emulation.ScreenArea = { + width: js-uint, + height: js-uint +} + +emulation.SetScreenSettingsOverrideParameters = { + screenArea: emulation.ScreenArea / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetScreenOrientationOverride = ( + method: "emulation.setScreenOrientationOverride", + params: emulation.SetScreenOrientationOverrideParameters +) + +emulation.ScreenOrientationNatural = "portrait" / "landscape" +emulation.ScreenOrientationType = "portrait-primary" / "portrait-secondary" / "landscape-primary" / "landscape-secondary" + +emulation.ScreenOrientation = { + natural: emulation.ScreenOrientationNatural, + type: emulation.ScreenOrientationType +} + +emulation.SetScreenOrientationOverrideParameters = { + screenOrientation: emulation.ScreenOrientation / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetUserAgentOverride = ( + method: "emulation.setUserAgentOverride", + params: emulation.SetUserAgentOverrideParameters +) + +emulation.SetUserAgentOverrideParameters = { + userAgent: text / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetViewportMetaOverride = ( + method: "emulation.setViewportMetaOverride", + params: emulation.SetViewportMetaOverrideParameters +) + +emulation.SetViewportMetaOverrideParameters = { + viewportMeta: true / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetScriptingEnabled = ( + method: "emulation.setScriptingEnabled", + params: emulation.SetScriptingEnabledParameters +) + +emulation.SetScriptingEnabledParameters = { + enabled: false / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetScrollbarTypeOverride = ( + method: "emulation.setScrollbarTypeOverride", + params: emulation.SetScrollbarTypeOverrideParameters +) + +emulation.SetScrollbarTypeOverrideParameters = { + scrollbarType: "classic" / "overlay" / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetTimezoneOverride = ( + method: "emulation.setTimezoneOverride", + params: emulation.SetTimezoneOverrideParameters +) + +emulation.SetTimezoneOverrideParameters = { + timezone: text / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetTouchOverride = ( + method: "emulation.setTouchOverride", + params: emulation.SetTouchOverrideParameters +) + +emulation.SetTouchOverrideParameters = { + maxTouchPoints: (js-uint .ge 1) / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + + +NetworkCommand = ( + network.AddDataCollector // + network.AddIntercept // + network.ContinueRequest // + network.ContinueResponse // + network.ContinueWithAuth // + network.DisownData // + network.FailRequest // + network.GetData // + network.ProvideResponse // + network.RemoveDataCollector // + network.RemoveIntercept // + network.SetCacheBehavior // + network.SetExtraHeaders +) + + +network.AuthCredentials = { + type: "password", + username: text, + password: text, +} + +network.BytesValue = network.StringValue / network.Base64Value; + +network.StringValue = { + type: "string", + value: text, +} + +network.Base64Value = { + type: "base64", + value: text, +} + +network.Collector = text + +network.CollectorType = "blob" + + +network.SameSite = "strict" / "lax" / "none" / "default" + + +network.Cookie = { + name: text, + value: network.BytesValue, + domain: text, + path: text, + size: js-uint, + httpOnly: bool, + secure: bool, + sameSite: network.SameSite, + ? expiry: js-uint, + Extensible, +} + +network.CookieHeader = { + name: text, + value: network.BytesValue, +} + +network.DataType = "request" / "response" + +network.Header = { + name: text, + value: network.BytesValue, +} + +network.Intercept = text + +network.Request = text; + + +network.SetCookieHeader = { + name: text, + value: network.BytesValue, + ? domain: text, + ? httpOnly: bool, + ? expiry: text, + ? maxAge: js-int, + ? path: text, + ? sameSite: network.SameSite, + ? secure: bool, +} + +network.UrlPattern = ( + network.UrlPatternPattern / + network.UrlPatternString +) + +network.UrlPatternPattern = { + type: "pattern", + ?protocol: text, + ?hostname: text, + ?port: text, + ?pathname: text, + ?search: text, +} + + +network.UrlPatternString = { + type: "string", + pattern: text, +} + + +network.AddDataCollector = ( + method: "network.addDataCollector", + params: network.AddDataCollectorParameters +) + +network.AddDataCollectorParameters = { + dataTypes: [+network.DataType], + maxEncodedDataSize: js-uint, + ? collectorType: network.CollectorType .default "blob", + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +network.AddIntercept = ( + method: "network.addIntercept", + params: network.AddInterceptParameters +) + +network.AddInterceptParameters = { + phases: [+network.InterceptPhase], + ? contexts: [+browsingContext.BrowsingContext], + ? urlPatterns: [*network.UrlPattern], +} + +network.InterceptPhase = "beforeRequestSent" / "responseStarted" / + "authRequired" + +network.ContinueRequest = ( + method: "network.continueRequest", + params: network.ContinueRequestParameters +) + +network.ContinueRequestParameters = { + request: network.Request, + ?body: network.BytesValue, + ?cookies: [*network.CookieHeader], + ?headers: [*network.Header], + ?method: text, + ?url: text, +} + +network.ContinueResponse = ( + method: "network.continueResponse", + params: network.ContinueResponseParameters +) + +network.ContinueResponseParameters = { + request: network.Request, + ?cookies: [*network.SetCookieHeader] + ?credentials: network.AuthCredentials, + ?headers: [*network.Header], + ?reasonPhrase: text, + ?statusCode: js-uint, +} + +network.ContinueWithAuth = ( + method: "network.continueWithAuth", + params: network.ContinueWithAuthParameters +) + +network.ContinueWithAuthParameters = { + request: network.Request, + (network.ContinueWithAuthCredentials // network.ContinueWithAuthNoCredentials) +} + +network.ContinueWithAuthCredentials = ( + action: "provideCredentials", + credentials: network.AuthCredentials +) + +network.ContinueWithAuthNoCredentials = ( + action: "default" / "cancel" +) + +network.DisownData = ( + method: "network.disownData", + params: network.disownDataParameters +) + +network.disownDataParameters = { + dataType: network.DataType, + collector: network.Collector, + request: network.Request, +} + +network.FailRequest = ( + method: "network.failRequest", + params: network.FailRequestParameters +) + +network.FailRequestParameters = { + request: network.Request, +} + +network.GetData = ( + method: "network.getData", + params: network.GetDataParameters +) + +network.GetDataParameters = { + dataType: network.DataType, + ? collector: network.Collector, + ? disown: bool .default false, + request: network.Request, +} + +network.ProvideResponse = ( + method: "network.provideResponse", + params: network.ProvideResponseParameters +) + +network.ProvideResponseParameters = { + request: network.Request, + ?body: network.BytesValue, + ?cookies: [*network.SetCookieHeader], + ?headers: [*network.Header], + ?reasonPhrase: text, + ?statusCode: js-uint, +} + +network.RemoveDataCollector = ( + method: "network.removeDataCollector", + params: network.RemoveDataCollectorParameters +) + +network.RemoveDataCollectorParameters = { + collector: network.Collector +} + +network.RemoveIntercept = ( + method: "network.removeIntercept", + params: network.RemoveInterceptParameters +) + +network.RemoveInterceptParameters = { + intercept: network.Intercept +} + +network.SetCacheBehavior = ( + method: "network.setCacheBehavior", + params: network.SetCacheBehaviorParameters +) + +network.SetCacheBehaviorParameters = { + cacheBehavior: "default" / "bypass", + ? contexts: [+browsingContext.BrowsingContext] +} + +network.SetExtraHeaders = ( + method: "network.setExtraHeaders", + params: network.SetExtraHeadersParameters +) + +network.SetExtraHeadersParameters = { + headers: [*network.Header] + ? contexts: [+browsingContext.BrowsingContext] + ? userContexts: [+browser.UserContext] +} + +ScriptCommand = ( + script.AddPreloadScript // + script.CallFunction // + script.Disown // + script.Evaluate // + script.GetRealms // + script.RemovePreloadScript +) + +script.Channel = text; + +script.ChannelValue = { + type: "channel", + value: script.ChannelProperties, +} + +script.ChannelProperties = { + channel: script.Channel, + ? serializationOptions: script.SerializationOptions, + ? ownership: script.ResultOwnership, +} + +script.EvaluateResult = ( + script.EvaluateResultSuccess / + script.EvaluateResultException +) + +script.EvaluateResultSuccess = { + type: "success", + result: script.RemoteValue, + realm: script.Realm +} + +script.EvaluateResultException = { + type: "exception", + exceptionDetails: script.ExceptionDetails + realm: script.Realm +} + +script.ExceptionDetails = { + columnNumber: js-uint, + exception: script.RemoteValue, + lineNumber: js-uint, + stackTrace: script.StackTrace, + text: text, +} + +script.Handle = text; + +script.InternalId = text; + +script.LocalValue = ( + script.RemoteReference / + script.PrimitiveProtocolValue / + script.ChannelValue / + script.ArrayLocalValue / + { script.DateLocalValue } / + script.MapLocalValue / + script.ObjectLocalValue / + { script.RegExpLocalValue } / + script.SetLocalValue +) + +script.ListLocalValue = [*script.LocalValue]; + +script.ArrayLocalValue = { + type: "array", + value: script.ListLocalValue, +} + +script.DateLocalValue = ( + type: "date", + value: text +) + +script.MappingLocalValue = [*[(script.LocalValue / text), script.LocalValue]]; + +script.MapLocalValue = { + type: "map", + value: script.MappingLocalValue, +} + +script.ObjectLocalValue = { + type: "object", + value: script.MappingLocalValue, +} + +script.RegExpValue = { + pattern: text, + ? flags: text, +} + +script.RegExpLocalValue = ( + type: "regexp", + value: script.RegExpValue, +) + +script.SetLocalValue = { + type: "set", + value: script.ListLocalValue, +} + +script.PreloadScript = text; + +script.Realm = text; + +script.PrimitiveProtocolValue = ( + script.UndefinedValue / + script.NullValue / + script.StringValue / + script.NumberValue / + script.BooleanValue / + script.BigIntValue +) + +script.UndefinedValue = { + type: "undefined", +} + +script.NullValue = { + type: "null", +} + +script.StringValue = { + type: "string", + value: text, +} + +script.SpecialNumber = "NaN" / "-0" / "Infinity" / "-Infinity"; + +script.NumberValue = { + type: "number", + value: number / script.SpecialNumber, +} + +script.BooleanValue = { + type: "boolean", + value: bool, +} + +script.BigIntValue = { + type: "bigint", + value: text, +} + +script.RealmType = "window" / "dedicated-worker" / "shared-worker" / "service-worker" / + "worker" / "paint-worklet" / "audio-worklet" / "worklet" + + + +script.RemoteReference = ( + script.SharedReference / + script.RemoteObjectReference +) + +script.SharedReference = { + sharedId: script.SharedId + + ? handle: script.Handle, + Extensible +} + +script.RemoteObjectReference = { + handle: script.Handle, + + ? sharedId: script.SharedId + Extensible +} + +script.RemoteValue = ( + script.PrimitiveProtocolValue / + script.SymbolRemoteValue / + script.ArrayRemoteValue / + script.ObjectRemoteValue / + script.FunctionRemoteValue / + script.RegExpRemoteValue / + script.DateRemoteValue / + script.MapRemoteValue / + script.SetRemoteValue / + script.WeakMapRemoteValue / + script.WeakSetRemoteValue / + script.GeneratorRemoteValue / + script.ErrorRemoteValue / + script.ProxyRemoteValue / + script.PromiseRemoteValue / + script.TypedArrayRemoteValue / + script.ArrayBufferRemoteValue / + script.NodeListRemoteValue / + script.HTMLCollectionRemoteValue / + script.NodeRemoteValue / + script.WindowProxyRemoteValue +) + +script.ListRemoteValue = [*script.RemoteValue]; + +script.MappingRemoteValue = [*[(script.RemoteValue / text), script.RemoteValue]]; + +script.SymbolRemoteValue = { + type: "symbol", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ArrayRemoteValue = { + type: "array", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.ObjectRemoteValue = { + type: "object", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.MappingRemoteValue, +} + +script.FunctionRemoteValue = { + type: "function", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.RegExpRemoteValue = { + script.RegExpLocalValue, + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.DateRemoteValue = { + script.DateLocalValue, + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.MapRemoteValue = { + type: "map", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.MappingRemoteValue, +} + +script.SetRemoteValue = { + type: "set", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue +} + +script.WeakMapRemoteValue = { + type: "weakmap", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.WeakSetRemoteValue = { + type: "weakset", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.GeneratorRemoteValue = { + type: "generator", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ErrorRemoteValue = { + type: "error", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ProxyRemoteValue = { + type: "proxy", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.PromiseRemoteValue = { + type: "promise", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.TypedArrayRemoteValue = { + type: "typedarray", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ArrayBufferRemoteValue = { + type: "arraybuffer", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.NodeListRemoteValue = { + type: "nodelist", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.HTMLCollectionRemoteValue = { + type: "htmlcollection", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.NodeRemoteValue = { + type: "node", + ? sharedId: script.SharedId, + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.NodeProperties, +} + +script.NodeProperties = { + nodeType: js-uint, + childNodeCount: js-uint, + ? attributes: {*text => text}, + ? children: [*script.NodeRemoteValue], + ? localName: text, + ? mode: "open" / "closed", + ? namespaceURI: text, + ? nodeValue: text, + ? shadowRoot: script.NodeRemoteValue / null, +} + +script.WindowProxyRemoteValue = { + type: "window", + value: script.WindowProxyProperties, + ? handle: script.Handle, + ? internalId: script.InternalId +} + +script.WindowProxyProperties = { + context: browsingContext.BrowsingContext +} + +script.ResultOwnership = "root" / "none" + +script.SerializationOptions = { + ? maxDomDepth: (js-uint / null) .default 0, + ? maxObjectDepth: (js-uint / null) .default null, + ? includeShadowTree: ("none" / "open" / "all") .default "none", +} + +script.SharedId = text; + +script.StackFrame = { + columnNumber: js-uint, + functionName: text, + lineNumber: js-uint, + url: text, +} + +script.StackTrace = { + callFrames: [*script.StackFrame], +} + +script.RealmTarget = { + realm: script.Realm +} + +script.ContextTarget = { + context: browsingContext.BrowsingContext, + ? sandbox: text +} + +script.Target = ( + script.ContextTarget / + script.RealmTarget +) + +script.AddPreloadScript = ( + method: "script.addPreloadScript", + params: script.AddPreloadScriptParameters +) + +script.AddPreloadScriptParameters = { + functionDeclaration: text, + ? arguments: [*script.ChannelValue], + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], + ? sandbox: text +} + +script.Disown = ( + method: "script.disown", + params: script.DisownParameters +) + +script.DisownParameters = { + handles: [*script.Handle] + target: script.Target; +} + +script.CallFunction = ( + method: "script.callFunction", + params: script.CallFunctionParameters +) + +script.CallFunctionParameters = { + functionDeclaration: text, + awaitPromise: bool, + target: script.Target, + ? arguments: [*script.LocalValue], + ? resultOwnership: script.ResultOwnership, + ? serializationOptions: script.SerializationOptions, + ? this: script.LocalValue, + ? userActivation: bool .default false, +} + +script.Evaluate = ( + method: "script.evaluate", + params: script.EvaluateParameters +) + +script.EvaluateParameters = { + expression: text, + target: script.Target, + awaitPromise: bool, + ? resultOwnership: script.ResultOwnership, + ? serializationOptions: script.SerializationOptions, + ? userActivation: bool .default false, +} + +script.GetRealms = ( + method: "script.getRealms", + params: script.GetRealmsParameters +) + +script.GetRealmsParameters = { + ? context: browsingContext.BrowsingContext, + ? type: script.RealmType, +} + +script.RemovePreloadScript = ( + method: "script.removePreloadScript", + params: script.RemovePreloadScriptParameters +) + +script.RemovePreloadScriptParameters = { + script: script.PreloadScript +} + +StorageCommand = ( + storage.DeleteCookies // + storage.GetCookies // + storage.SetCookie +) + +storage.PartitionKey = { + ? userContext: text, + ? sourceOrigin: text, + Extensible, +} + +storage.GetCookies = ( + method: "storage.getCookies", + params: storage.GetCookiesParameters +) + + +storage.CookieFilter = { + ? name: text, + ? value: network.BytesValue, + ? domain: text, + ? path: text, + ? size: js-uint, + ? httpOnly: bool, + ? secure: bool, + ? sameSite: network.SameSite, + ? expiry: js-uint, + Extensible, +} + +storage.BrowsingContextPartitionDescriptor = { + type: "context", + context: browsingContext.BrowsingContext +} + +storage.StorageKeyPartitionDescriptor = { + type: "storageKey", + ? userContext: text, + ? sourceOrigin: text, + Extensible, +} + +storage.PartitionDescriptor = ( + storage.BrowsingContextPartitionDescriptor / + storage.StorageKeyPartitionDescriptor +) + +storage.GetCookiesParameters = { + ? filter: storage.CookieFilter, + ? partition: storage.PartitionDescriptor, +} + +storage.SetCookie = ( + method: "storage.setCookie", + params: storage.SetCookieParameters, +) + + +storage.PartialCookie = { + name: text, + value: network.BytesValue, + domain: text, + ? path: text, + ? httpOnly: bool, + ? secure: bool, + ? sameSite: network.SameSite, + ? expiry: js-uint, + Extensible, +} + +storage.SetCookieParameters = { + cookie: storage.PartialCookie, + ? partition: storage.PartitionDescriptor, +} + +storage.DeleteCookies = ( + method: "storage.deleteCookies", + params: storage.DeleteCookiesParameters, +) + +storage.DeleteCookiesParameters = { + ? filter: storage.CookieFilter, + ? partition: storage.PartitionDescriptor, +} + +InputCommand = ( + input.PerformActions // + input.ReleaseActions // + input.SetFiles +) + +InputResult = ( + input.PerformActionsResult / + input.ReleaseActionsResult / + input.SetFilesResult +) + +input.ElementOrigin = { + type: "element", + element: script.SharedReference +} + +input.PerformActions = ( + method: "input.performActions", + params: input.PerformActionsParameters +) + +input.PerformActionsParameters = { + context: browsingContext.BrowsingContext, + actions: [*input.SourceActions] +} + +input.SourceActions = ( + input.NoneSourceActions / + input.KeySourceActions / + input.PointerSourceActions / + input.WheelSourceActions +) + +input.NoneSourceActions = { + type: "none", + id: text, + actions: [*input.NoneSourceAction] +} + +input.NoneSourceAction = input.PauseAction + +input.KeySourceActions = { + type: "key", + id: text, + actions: [*input.KeySourceAction] +} + +input.KeySourceAction = ( + input.PauseAction / + input.KeyDownAction / + input.KeyUpAction +) + +input.PointerSourceActions = { + type: "pointer", + id: text, + ? parameters: input.PointerParameters, + actions: [*input.PointerSourceAction] +} + +input.PointerType = "mouse" / "pen" / "touch" + +input.PointerParameters = { + ? pointerType: input.PointerType .default "mouse" +} + +input.PointerSourceAction = ( + input.PauseAction / + input.PointerDownAction / + input.PointerUpAction / + input.PointerMoveAction +) + +input.WheelSourceActions = { + type: "wheel", + id: text, + actions: [*input.WheelSourceAction] +} + +input.WheelSourceAction = ( + input.PauseAction / + input.WheelScrollAction +) + +input.PauseAction = { + type: "pause", + ? duration: js-uint +} + +input.KeyDownAction = { + type: "keyDown", + value: text +} + +input.KeyUpAction = { + type: "keyUp", + value: text +} + +input.PointerUpAction = { + type: "pointerUp", + button: js-uint, +} + +input.PointerDownAction = { + type: "pointerDown", + button: js-uint, + input.PointerCommonProperties +} + +input.PointerMoveAction = { + type: "pointerMove", + x: float, + y: float, + ? duration: js-uint, + ? origin: input.Origin, + input.PointerCommonProperties +} + +input.WheelScrollAction = { + type: "scroll", + x: js-int, + y: js-int, + deltaX: js-int, + deltaY: js-int, + ? duration: js-uint, + ? origin: input.Origin .default "viewport", +} + +input.PointerCommonProperties = ( + ? width: js-uint .default 1, + ? height: js-uint .default 1, + ? pressure: float .default 0.0, + ? tangentialPressure: float .default 0.0, + ? twist: (0..359) .default 0, + ; 0 .. Math.PI / 2 + ? altitudeAngle: (0.0..1.5707963267948966) .default 0.0, + ; 0 .. 2 * Math.PI + ? azimuthAngle: (0.0..6.283185307179586) .default 0.0, +) + +input.Origin = "viewport" / "pointer" / input.ElementOrigin + +input.ReleaseActions = ( + method: "input.releaseActions", + params: input.ReleaseActionsParameters +) + +input.ReleaseActionsParameters = { + context: browsingContext.BrowsingContext, +} + +input.SetFiles = ( + method: "input.setFiles", + params: input.SetFilesParameters +) + +input.SetFilesParameters = { + context: browsingContext.BrowsingContext, + element: script.SharedReference, + files: [*text] +} + +input.FileDialogOpened = ( + method: "input.fileDialogOpened", + params: input.FileDialogInfo +) + +input.FileDialogInfo = { + context: browsingContext.BrowsingContext, + ? element: script.SharedReference, + multiple: bool, +} + +WebExtensionCommand = ( + webExtension.Install // + webExtension.Uninstall +) + +webExtension.Extension = text + +webExtension.Install = ( + method: "webExtension.install", + params: webExtension.InstallParameters +) + +webExtension.InstallParameters = { + extensionData: webExtension.ExtensionData, +} + +webExtension.ExtensionData = ( + webExtension.ExtensionArchivePath / + webExtension.ExtensionBase64Encoded / + webExtension.ExtensionPath +) + +webExtension.ExtensionPath = { + type: "path", + path: text, +} + +webExtension.ExtensionArchivePath = { + type: "archivePath", + path: text, +} + +webExtension.ExtensionBase64Encoded = { + type: "base64", + value: text, +} + +webExtension.Uninstall = ( + method: "webExtension.uninstall", + params: webExtension.UninstallParameters +) + +webExtension.UninstallParameters = { + extension: webExtension.Extension, +} diff --git a/py/AGENTS.md b/py/AGENTS.md index 57a5e819e1a8e..27c1aaac41e9a 100644 --- a/py/AGENTS.md +++ b/py/AGENTS.md @@ -51,24 +51,25 @@ def method(param: str | None) -> int | None: pass # Avoid -from typing import Optional def method(param: Optional[str]) -> Optional[int]: pass ``` ### Python version -Code must work with Python 3.10 or later. Use modern syntax features available in 3.10+. +Code must work with Python 3.10 or later. Use modern syntax features available in 3.10+: -See the **Type hints** section for guidance on preferred type annotation syntax (including unions). +- Use `|` for union types instead of `Union[]` +- Use `X | None` instead of `Optional[X]` -For testing: use `bazel test //py/...` which employs a hermetic Python 3.10+ toolchain (see `/AGENTS.md`). - -For ad-hoc scripts, check your Python version locally before running: +When running tests or code in the terminal, explicitly use `python3.10` or later: ```bash -python --version -# Ensure you have 3.10+; on macOS/Linux use python3.10+ or on Windows py -3.10 +# Use explicitly +python3.10 -c "..." +python3.11 -c "..." + +# Avoid relying on `python3` as it may be 3.9 or earlier ``` ### Documentation diff --git a/py/BUILD.bazel b/py/BUILD.bazel index 0a4d5228d518d..4d11a3063e66a 100644 --- a/py/BUILD.bazel +++ b/py/BUILD.bazel @@ -10,6 +10,7 @@ load("@rules_python//sphinxdocs:sphinx.bzl", "sphinx_build_binary", "sphinx_docs load("//common:defs.bzl", "copy_file") load("//py:defs.bzl", "generate_devtools", "py_test_suite") load("//py/private:browsers.bzl", "BROWSERS") +load("//py/private:generate_bidi.bzl", "generate_bidi") load("//py/private:import.bzl", "py_import") exports_files( @@ -574,6 +575,12 @@ py_binary( deps = [requirement("inflection")], ) +py_binary( + name = "generate_bidi", + srcs = ["generate_bidi.py"], + srcs_version = "PY3", +) + [generate_devtools( name = "create-cdp-srcs-{}".format(devtools_version), browser_protocol = "//common/devtools/chromium/{}:browser_protocol".format(devtools_version), @@ -583,6 +590,17 @@ py_binary( protocol_version = devtools_version, ) for devtools_version in BROWSER_VERSIONS] +# Pilot BiDi code generation from CDDL specification +generate_bidi( + name = "create-bidi-src", + cddl_file = "//common/bidi/spec:all.cddl", + enhancements_manifest = "//py/private:bidi_enhancements_manifest.py", + extra_srcs = ["//py/private:cdp.py"], + generator = ":generate_bidi", + module_name = "selenium/webdriver/common/bidi", + spec_version = "1.0", +) + py_test_suite( name = "unit", size = "small", @@ -762,6 +780,7 @@ BROWSER_TESTS = { ] ] + test_suite( name = "test-remote", tags = ["remote"], diff --git a/py/conftest.py b/py/conftest.py index 7cdb2446e4107..ff60cf06f08ec 100644 --- a/py/conftest.py +++ b/py/conftest.py @@ -118,6 +118,14 @@ def pytest_addoption(parser): metavar="DRIVER", help="Driver to run tests against ({})".format(", ".join(drivers)), ) + parser.addoption( + "--browser", + action="append", + choices=drivers, + dest="drivers", + metavar="BROWSER", + help="Browser to run tests against (alias for --driver)", + ) parser.addoption( "--browser-binary", action="store", diff --git a/py/generate_bidi.py b/py/generate_bidi.py new file mode 100755 index 0000000000000..1770cf436bef1 --- /dev/null +++ b/py/generate_bidi.py @@ -0,0 +1,1824 @@ +#!/usr/bin/env python3 +""" +Generate Python WebDriver BiDi command modules from CDDL specification. + +This generator reads CDDL (Concise Data Definition Language) specification files +and produces Python type definitions and command classes that conform to the +WebDriver BiDi protocol. + +Usage: + python generate_bidi.py + +Example: + python generate_bidi.py local.cddl ./selenium/webdriver/common/bidi 1.0 +""" + +import argparse +import importlib.util +import logging +import re +import sys +from collections import defaultdict +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from textwrap import dedent, indent as tw_indent +from typing import Any, Dict, List, Optional, Set, Tuple + +__version__ = "1.0.0" + +# Logging setup +log_level = logging.INFO +logging.basicConfig(level=log_level) +logger = logging.getLogger("generate_bidi") + +# File headers +SHARED_HEADER = """# DO NOT EDIT THIS FILE! +# +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules.""" + +MODULE_HEADER = f"""{SHARED_HEADER} +# +# WebDriver BiDi module: {{}} +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +""" + + +def indent(s: str, n: int) -> str: + """Indent a string by n spaces.""" + return tw_indent(s, n * " ") + + +def load_enhancements_manifest(manifest_path: Optional[str]) -> Dict[str, Any]: + """Load enhancement manifest from a Python file. + + Args: + manifest_path: Path to Python file containing ENHANCEMENTS dict + + Returns: + Dictionary with enhancement rules, or empty dict if no manifest provided + """ + if not manifest_path: + return {} + + manifest_file = Path(manifest_path) + if not manifest_file.exists(): + logger.warning(f"Enhancement manifest not found: {manifest_path}") + return {} + + try: + spec = importlib.util.spec_from_file_location( + "bidi_enhancements", manifest_file + ) + if spec is None or spec.loader is None: + logger.warning(f"Could not load manifest: {manifest_path}") + return {} + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + enhancements = getattr(module, "ENHANCEMENTS", {}) + dataclass_methods = getattr(module, "DATACLASS_METHOD_TEMPLATES", {}) + method_docstrings = getattr(module, "DATACLASS_METHOD_DOCSTRINGS", {}) + + logger.info(f"Loaded enhancement manifest from: {manifest_path}") + logger.debug(f"Enhancements for modules: {list(enhancements.keys())}") + + return { + "enhancements": enhancements, + "dataclass_methods": dataclass_methods, + "method_docstrings": method_docstrings, + } + except Exception as e: + logger.error(f"Failed to load enhancement manifest: {e}", exc_info=True) + return {} + + +class CddlType(Enum): + """CDDL type mappings to Python types.""" + + TSTR = "str" # text string + TEXT = "str" # text (alias) + UINT = "int" # unsigned integer + INT = "int" # signed integer + NINT = "int" # negative integer + BOOL = "bool" # boolean + NULL = "None" # null + ANY = "Any" # any type + + @classmethod + def get_annotation(cls, cddl_type: str) -> str: + """Get Python type annotation for a CDDL type.""" + cddl_type = cddl_type.strip().lower() + + # Handle basic types + for member in cls: + if cddl_type == member.name.lower(): + return member.value + + # Handle composite types + if cddl_type.startswith("["): # Array + inner = cddl_type.strip("[]+ ") + inner_type = cls.get_annotation(inner) + return f"List[{inner_type}]" + + if cddl_type.startswith("{"): # Map/Dict + return "Dict[str, Any]" + + # Default to Any for unknown types + return "Any" + + +@dataclass +class CddlCommand: + """Represents a CDDL command definition.""" + + module: str + name: str + params: Dict[str, str] = field(default_factory=dict) + result: Optional[str] = None + description: str = "" + + def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + """Generate Python method code for this command. + + Args: + enhancements: Dictionary with enhancement rules for this method + """ + enhancements = enhancements or {} + method_name = self._camel_to_snake(self.name) + + # Build parameter list with type hints + # Check if there's a params_override for user-friendly named arguments + params_to_use = self.params + if "params_override" in enhancements: + params_to_use = enhancements["params_override"] + + param_strs = [] + param_names = [] # Keep track of parameter names for later use + for param_name, param_type in params_to_use.items(): + if param_type in ["bool", "str", "int"]: + python_type = param_type + else: + python_type = CddlType.get_annotation(param_type) + snake_param = self._camel_to_snake(param_name) + param_names.append((param_name, snake_param)) + param_strs.append(f"{snake_param}: {python_type} | None = None") + + if param_strs: + param_list = "self, " + ", ".join(param_strs) + else: + param_list = "self" + + # Build method body + body = f" def {method_name}({param_list}):\n" + body += f' """{self.description or "Execute " + self.module + "." + self.name}."""\n' + + # Add validation if specified + if "validate" in enhancements: + validate_func = enhancements["validate"] + # Build parameter list for validation function + param_args = ", ".join(f"{snake}={snake}" for _, snake in param_names) + body += f" {validate_func}({param_args})\n" + body += "\n" + + # Add transformation and preprocessing + # First, check if any transform is needed + if "transform" in enhancements: + transform_spec = enhancements["transform"] + if isinstance(transform_spec, dict): + # New format with explicit function and result parameter + transform_func = transform_spec.get("func") + result_param = transform_spec.get("result_param", "params") + input_params = [ + transform_spec.get(k) + for k in ["allowed", "destination_folder"] + if transform_spec.get(k) + ] + + if transform_func and result_param: + body += f" {result_param} = None\n" + param_args = ", ".join(input_params) + body += f" {result_param} = {transform_func}({param_args})\n" + body += "\n" + else: + # Legacy format for backward compatibility + transform_func = transform_spec + if self.name == "setDownloadBehavior": + body += " download_behavior = None\n" + body += f" download_behavior = {transform_func}(allowed, destination_folder)\n" + body += "\n" + + # Add preprocessing for serialization (check for to_bidi_dict method) + if "preprocess" in enhancements: + preprocess_rules = enhancements["preprocess"] + for param_name, preprocess_type in preprocess_rules.items(): + snake_param = self._camel_to_snake(param_name) + if preprocess_type == "check_serialize_method": + body += f" if {snake_param} and hasattr({snake_param}, 'to_bidi_dict'):\n" + body += ( + f" {snake_param} = {snake_param}.to_bidi_dict()\n" + ) + body += "\n" + + # Build params dict + body += " params = {\n" + + # If there's a transform with a result parameter, map it to the BiDi protocol name + if "transform" in enhancements and isinstance(enhancements["transform"], dict): + transform_spec = enhancements["transform"] + result_param = transform_spec.get("result_param") + + # Map the result parameter to the original CDDL parameter name + if result_param == "download_behavior": + body += ' "downloadBehavior": download_behavior,\n' + # Add remaining parameters that weren't part of the transform + override_params = enhancements.get("params_override", {}) + for cddl_param_name in self.params: + if cddl_param_name not in ["downloadBehavior"]: + snake_name = self._camel_to_snake(cddl_param_name) + body += f' "{cddl_param_name}": {snake_name},\n' + else: + # Standard parameter mapping from CDDL + for param_name, snake_param in param_names: + body += f' "{param_name}": {snake_param},\n' + + body += " }\n" + body += " params = {k: v for k, v in params.items() if v is not None}\n" + body += f' cmd = command_builder("{self.module}.{self.name}", params)\n' + body += " result = self._conn.execute(cmd)\n" + + # Add response handling for extraction/deserialization + if "extract_field" in enhancements: + extract_field = enhancements["extract_field"] + extract_property = enhancements.get("extract_property") + + # Check if we also need to deserialize the extracted field + deserialize_rules = enhancements.get("deserialize", {}) + + if extract_property: + # Extract property from list items + body += f' if result and "{extract_field}" in result:\n' + body += f' items = result.get("{extract_field}", [])\n' + body += f" return [\n" + body += f' item.get("{extract_property}")\n' + body += f" for item in items\n" + body += f" if isinstance(item, dict)\n" + body += f" ]\n" + body += f" return []\n" + elif extract_field in deserialize_rules: + # Extract field and deserialize to typed objects + type_name = deserialize_rules[extract_field] + body += f' if result and "{extract_field}" in result:\n' + body += f' items = result.get("{extract_field}", [])\n' + body += f" return [\n" + body += f" {type_name}(\n" + body += self._generate_field_args(extract_field, type_name) + body += f" )\n" + body += f" for item in items\n" + body += f" if isinstance(item, dict)\n" + body += f" ]\n" + body += f" return []\n" + else: + # Simple field extraction (return the value directly, not wrapped in result dict) + body += f' if result and "{extract_field}" in result:\n' + body += f' extracted = result.get("{extract_field}")\n' + body += f" return extracted\n" + body += f" return result\n" + elif "deserialize" in enhancements: + # Deserialize response to typed objects (legacy, without extract_field) + deserialize_rules = enhancements["deserialize"] + for response_field, type_name in deserialize_rules.items(): + body += f' if result and "{response_field}" in result:\n' + body += f' items = result.get("{response_field}", [])\n' + body += f" return [\n" + body += f" {type_name}(\n" + body += self._generate_field_args(response_field, type_name) + body += f" )\n" + body += f" for item in items\n" + body += f" if isinstance(item, dict)\n" + body += f" ]\n" + body += f" return []\n" + else: + # No special response handling, just return the result + body += " return result\n" + + return body + + def _generate_field_args(self, response_field: str, type_name: str) -> str: + """Generate constructor arguments for deserializing response objects. + + For now, this handles ClientWindowInfo and Info specifically. + Could be extended to be more generic. + """ + if type_name == "ClientWindowInfo": + return ( + ' active=item.get("active"),\n' + ' client_window=item.get("clientWindow"),\n' + ' height=item.get("height"),\n' + ' state=item.get("state"),\n' + ' width=item.get("width"),\n' + ' x=item.get("x"),\n' + ' y=item.get("y")\n' + ) + elif type_name == "Info": + return ( + ' children=_deserialize_info_list(item.get("children", [])),\n' + ' client_window=item.get("clientWindow"),\n' + ' context=item.get("context"),\n' + ' original_opener=item.get("originalOpener"),\n' + ' url=item.get("url"),\n' + ' user_context=item.get("userContext"),\n' + ' parent=item.get("parent")\n' + ) + # For other types, return empty + return "" + + @staticmethod + def _camel_to_snake(name: str) -> str: + """Convert camelCase to snake_case.""" + name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower() + + +@dataclass +class CddlTypeDefinition: + """Represents a CDDL type definition.""" + + module: str + name: str + fields: Dict[str, str] = field(default_factory=dict) + description: str = "" + + def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + """Generate Python dataclass code for this type. + + Args: + enhancements: Dictionary containing dataclass_methods and method_docstrings + """ + enhancements = enhancements or {} + dataclass_methods = enhancements.get("dataclass_methods", {}) + method_docstrings = enhancements.get("method_docstrings", {}) + + # Generate class name from type name (keep it as-is, don't split on underscores) + class_name = self.name + code = f"@dataclass\n" + code += f"class {class_name}:\n" + code += f' """{self.description or self.name}."""\n\n' + + if not self.fields: + code += " pass\n" + else: + for field_name, field_type in self.fields.items(): + # Convert CDDL type to Python type + python_type = self._get_python_type(field_type) + snake_name = CddlCommand._camel_to_snake(field_name) + + # Check if the CDDL field type is a quoted string literal (e.g., type: "key") + # These are discriminant fields: auto-populate and exclude from __init__ + # so callers don't need to pass them as positional or keyword arguments. + literal_match = re.match(r'^"([^"]+)"$', field_type.strip()) + if literal_match: + literal_value = literal_match.group(1) + code += f' {snake_name}: str = field(default="{literal_value}", init=False)\n' + # Check if this field is a list type + elif "List[" in python_type: + code += f" {snake_name}: {python_type} = field(default_factory=list)\n" + else: + code += f" {snake_name}: {python_type} = None\n" + + # Add custom methods if defined for this class + if class_name in dataclass_methods: + code += "\n" + methods_dict = dataclass_methods[class_name] + docstrings_dict = method_docstrings.get(class_name, {}) + + for method_name in methods_dict: + method_impl = methods_dict[method_name] + docstring = docstrings_dict.get(method_name, "") + code += f" def {method_name}(self):\n" + if docstring: + code += f' """{docstring}"""\n' + code += f" {method_impl}\n" + code += "\n" + + return code + + @staticmethod + def _get_python_type(cddl_type: str) -> str: + """Convert CDDL type to Python type annotation using Python 3.10+ syntax.""" + cddl_type = cddl_type.strip().lower() + + # Handle basic types + type_mapping = { + "tstr": "str", + "text": "str", + "uint": "int", + "int": "int", + "nint": "int", + "bool": "bool", + "null": "None", + } + + for cddl, python in type_mapping.items(): + if cddl_type == cddl: + # Use Python 3.10+ union syntax: type | None + return f"{python} | None" + + # Handle arrays + if cddl_type.startswith("["): + inner = cddl_type.strip("[]+ ") + inner_type = CddlTypeDefinition._get_python_type(inner) + # Remove " | None" from inner type since it might be wrapped + if " | None" in inner_type: + inner_base = inner_type.replace(" | None", "") + return f"list[{inner_base} | None] | None" + return f"list[{inner_type}] | None" + + # Handle maps/dicts + if cddl_type.startswith("{"): + return "dict[str, Any] | None" + + # Default to Any for unknown/complex types + return "Any | None" + + +@dataclass +class CddlEnum: + """Represents a CDDL enum definition (string union).""" + + module: str + name: str + values: List[str] = field(default_factory=list) + description: str = "" + + def to_python_class(self) -> str: + """Generate Python enum class code. + + Generates a simple class with string constants to match the existing + pattern in the codebase (e.g., ClientWindowState). + """ + class_name = self.name + code = f"class {class_name}:\n" + code += f' """{self.description or self.name}."""\n\n' + + for value in self.values: + # Convert value to UPPER_SNAKE_CASE constant name + const_name = self._value_to_const_name(value) + code += f' {const_name} = "{value}"\n' + + return code + + @staticmethod + def _value_to_const_name(value: str) -> str: + """Convert enum string value to constant name. + + Examples: + "none" -> "NONE" + "portrait-primary" -> "PORTRAIT_PRIMARY" + "interactive" -> "INTERACTIVE" + """ + # Replace hyphens with underscores + const_name = value.replace("-", "_") + # Convert to uppercase + return const_name.upper() + + +@dataclass +class CddlEvent: + """Represents a CDDL event definition (incoming message from browser).""" + + module: str + name: str + method: str + params_type: str + description: str = "" + + def to_python_dataclass(self) -> str: + """Generate Python dataclass code for the event info type. + + Returns a dataclass code that attempts to use globals().get() for safety. + """ + class_name = self.name + + # Extract the type name from params_type (e.g., "browsingContext.Info" -> "Info") + # The params_type comes from the CDDL and includes module prefix + type_name = ( + self.params_type.split(".")[-1] + if "." in self.params_type + else self.params_type + ) + + # Special case: if the type is BaseNavigationInfo, use BaseNavigationInfo directly + # (NavigationInfo will be created as an alias to it) + if type_name == "NavigationInfo": + type_name = "BaseNavigationInfo" + + # Generate type alias using globals().get() for safety + code = f"# Event: {self.method}\n" + code += f"{class_name} = globals().get('{type_name}', dict) # Fallback to dict if type not defined\n" + + return code + + +@dataclass +class CddlModule: + """Represents a CDDL module (e.g., script, network, browsing_context).""" + + name: str + commands: List[CddlCommand] = field(default_factory=list) + types: List[CddlTypeDefinition] = field(default_factory=list) + enums: List[CddlEnum] = field(default_factory=list) + events: List[CddlEvent] = field(default_factory=list) + + @staticmethod + def _convert_method_to_event_name(method_suffix: str) -> str: + """Convert BiDi method suffix to friendly event name. + + Examples: + "contextCreated" -> "context_created" + "navigationStarted" -> "navigation_started" + "userPromptOpened" -> "user_prompt_opened" + """ + # Convert camelCase to snake_case + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", method_suffix) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + + def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + """Generate Python code for this module. + + Args: + enhancements: Dictionary with module-level enhancements + """ + enhancements = enhancements or {} + code = MODULE_HEADER.format(self.name) + + # Add imports if needed + if self.types: + code += "from dataclasses import field\n" + if self.commands or self.types: + code += "from typing import Generator\n" + code += "from dataclasses import dataclass\n" + + # Add imports for event handling if needed + if self.events: + code += "import threading\n" + code += "from collections.abc import Callable\n" + code += "from dataclasses import dataclass\n" + code += "from selenium.webdriver.common.bidi.session import Session\n" + + code += "\n\n" + + # Add helper function definitions from enhancements + # Collect all referenced helper functions (validate, transform) + helper_funcs_to_add = set() + for cmd in self.commands: + method_name_snake = cmd._camel_to_snake(cmd.name) + method_enhancements = enhancements.get(method_name_snake, {}) + if "validate" in method_enhancements: + helper_funcs_to_add.add(("validate", method_enhancements["validate"])) + if "transform" in method_enhancements and isinstance( + method_enhancements["transform"], dict + ): + transform_spec = method_enhancements["transform"] + if "func" in transform_spec: + helper_funcs_to_add.add(("transform", transform_spec["func"])) + + # Generate helper functions if needed + if helper_funcs_to_add: + for func_type, func_name in sorted(helper_funcs_to_add): + if ( + func_type == "validate" + and func_name == "validate_download_behavior" + ): + code += """def validate_download_behavior( + allowed: bool | None, + destination_folder: str | None, + user_contexts: Any | None = None, +) -> None: + \"\"\"Validate download behavior parameters. + + Args: + allowed: Whether downloads are allowed + destination_folder: Destination folder for downloads + user_contexts: Optional list of user contexts + + Raises: + ValueError: If parameters are invalid + \"\"\" + if allowed is True and not destination_folder: + raise ValueError("destination_folder is required when allowed=True") + if allowed is False and destination_folder: + raise ValueError("destination_folder should not be provided when allowed=False") + + +""" + elif ( + func_type == "transform" + and func_name == "transform_download_params" + ): + code += """def transform_download_params( + allowed: bool | None, + destination_folder: str | None, +) -> dict[str, Any] | None: + \"\"\"Transform download parameters into download_behavior object. + + Args: + allowed: Whether downloads are allowed + destination_folder: Destination folder for downloads (accepts str or + pathlib.Path; will be coerced to str) + + Returns: + Dictionary representing the download_behavior object, or None if allowed is None + \"\"\" + if allowed is True: + return { + "type": "allowed", + # Coerce pathlib.Path (or any path-like) to str so the BiDi + # protocol always receives a plain JSON string. + "destinationFolder": str(destination_folder) if destination_folder is not None else None, + } + elif allowed is False: + return {"type": "denied"} + else: # None — reset to browser default (sent as JSON null) + return None + + +""" + + # Generate enums first + for enum_def in self.enums: + code += enum_def.to_python_class() + code += "\n\n" + + # Emit module-level aliases from enhancement manifest (e.g. LogLevel = Level) + for alias, target in enhancements.get("aliases", {}).items(): + code += f"{alias} = {target}\n\n" + + # Generate type dataclasses, skipping any overridden by extra_dataclasses + exclude_types = set(enhancements.get("exclude_types", [])) + for type_def in self.types: + if type_def.name in exclude_types: + continue + code += type_def.to_python_dataclass(enhancements) + code += "\n\n" + + # Emit extra dataclasses from enhancement manifest (non-CDDL additions) + for extra_cls in enhancements.get("extra_dataclasses", []): + code += extra_cls + code += "\n\n" + + # NOTE: Don't generate event type aliases here - they reference types that may not be defined yet + # They will be generated after the class definition instead + + # Generate EVENT_NAME_MAPPING for the module (before the module class) + if self.events: + # Generate EVENT_NAME_MAPPING for the module + code += "# BiDi Event Name to Parameter Type Mapping\n" + code += "EVENT_NAME_MAPPING = {\n" + for event_def in self.events: + # Convert method name to user-friendly event name + # e.g., "browsingContext.contextCreated" -> "context_created" + method_parts = event_def.method.split(".") + if len(method_parts) == 2: + event_name = self._convert_method_to_event_name(method_parts[1]) + code += f' "{event_name}": "{event_def.method}",\n' + # Extra events not in the CDDL spec (e.g. Chromium-specific events) + for extra_evt in enhancements.get("extra_events", []): + code += ( + f' "{extra_evt["event_key"]}": "{extra_evt["bidi_event"]}",\n' + ) + code += "}\n\n" + + # Add custom method function definitions before the class (for browsingContext) + if self.name == "browsingContext": + # Add helper function for recursive Info deserialization + code += """def _deserialize_info_list(items: list) -> list | None: + \"\"\"Recursively deserialize a list of dicts to Info objects. + + Args: + items: List of dicts from the API response + + Returns: + List of Info objects with properly nested children, or None if empty + \"\"\" + if not items or not isinstance(items, list): + return None + + result = [] + for item in items: + if isinstance(item, dict): + # Recursively deserialize children only if the key exists in response + children_list = None + if "children" in item: + children_list = _deserialize_info_list(item.get("children", [])) + info = Info( + children=children_list, + client_window=item.get("clientWindow"), + context=item.get("context"), + original_opener=item.get("originalOpener"), + url=item.get("url"), + user_context=item.get("userContext"), + parent=item.get("parent"), + ) + result.append(info) + return result if result else None + + +""" + code += "\n\n" + + # Generate EventConfig and _EventManager for modules with events + if self.events: + # Generate EventConfig dataclass + code += """@dataclass +class EventConfig: + \"\"\"Configuration for a BiDi event.\"\"\" + event_key: str + bidi_event: str + event_class: type + + +""" + + # Generate _EventManager class + code += """class _EventWrapper: + \"\"\"Wrapper to provide event_class attribute for WebSocketConnection callbacks.\"\"\" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + \"\"\"Deserialize event params into the wrapped Python dataclass. + + Args: + params: Raw BiDi event params with camelCase keys. + + Returns: + An instance of the dataclass, or the raw dict on failure. + \"\"\" + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, \"from_json\") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend([\"_\", char.lower()]) + else: + result.append(char) + return \"\".join(result) + + +class _EventManager: + \"\"\"Manages event subscriptions and callbacks.\"\"\" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + \"\"\"Subscribe to a BiDi event if not already subscribed.\"\"\" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get(\"subscription\") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + \"callbacks\": [], + \"subscription_id\": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + \"\"\"Unsubscribe from a BiDi event if no more callbacks exist.\"\"\" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry[\"callbacks\"]: + session = Session(self.conn) + sub_id = entry.get(\"subscription_id\") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event][\"callbacks\"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry[\"callbacks\"]: + entry[\"callbacks\"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + \"\"\"Clear all event handlers.\"\"\" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry[\"callbacks\"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get(\"subscription_id\") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() + + +""" + code += "\n\n" + + # Generate class + # Convert module name (camelCase or snake_case) to proper class name (PascalCase) + class_name = module_name_to_class_name(self.name) + code += f"class {class_name}:\n" + code += f' """WebDriver BiDi {self.name} module."""\n\n' + + # Add EVENT_CONFIGS dict if there are events + if self.events: + code += ( + " EVENT_CONFIGS = {}\n" # Will be populated after types are defined + ) + + if self.name == "script": + code += " def __init__(self, conn, driver=None) -> None:\n" + code += " self._conn = conn\n" + code += " self._driver = driver\n" + else: + code += " def __init__(self, conn) -> None:\n" + code += " self._conn = conn\n" + + # Initialize _event_manager if there are events + if self.events: + code += " self._event_manager = _EventManager(conn, self.EVENT_CONFIGS)\n" + + # Append extra init code from enhancements (e.g. self.intercepts = []) + for init_line in enhancements.get("extra_init_code", []): + code += f" {init_line}\n" + + code += "\n" + + # Generate command methods + exclude_methods = enhancements.get("exclude_methods", []) + if self.commands: + for command in self.commands: + # Get method-specific enhancements + # Convert command name to snake_case to match enhancement manifest keys + method_name_snake = command._camel_to_snake(command.name) + if method_name_snake in exclude_methods: + continue + method_enhancements = enhancements.get(method_name_snake, {}) + code += command.to_python_method(method_enhancements) + code += "\n" + else: + code += " pass\n" + + # Emit extra methods from enhancement manifest + for extra_method in enhancements.get("extra_methods", []): + code += extra_method + code += "\n" + + # Add delegating event handler methods if events are present + if self.events: + code += """ + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + \"\"\"Add an event handler. + + Args: + event: The event to subscribe to. + callback: The callback function to execute on event. + contexts: The context IDs to subscribe to (optional). + + Returns: + The callback ID. + \"\"\" + return self._event_manager.add_event_handler(event, callback, contexts) + + def remove_event_handler(self, event: str, callback_id: int) -> None: + \"\"\"Remove an event handler. + + Args: + event: The event to unsubscribe from. + callback_id: The callback ID. + \"\"\" + return self._event_manager.remove_event_handler(event, callback_id) + + def clear_event_handlers(self) -> None: + \"\"\"Clear all event handlers.\"\"\" + return self._event_manager.clear_event_handlers() +""" + + # Generate event info type aliases AFTER the class definition + # This ensures all types are available when we create the aliases + if self.events: + code += "\n# Event Info Type Aliases\n" + for event_def in self.events: + code += event_def.to_python_dataclass() + code += "\n" + + # Now populate EVENT_CONFIGS after the aliases are defined + code += f"\n# Populate EVENT_CONFIGS with event configuration mappings\n" + # Use globals() to look up types dynamically to handle missing types gracefully + code += f"_globals = globals()\n" + code += f"{class_name}.EVENT_CONFIGS = {{\n" + for event_def in self.events: + # Convert method name to user-friendly event name + method_parts = event_def.method.split(".") + if len(method_parts) == 2: + event_name = self._convert_method_to_event_name(method_parts[1]) + # The event class is the event name (e.g., ContextCreated) + # Try to get it from globals, default to dict if not found + code += f' "{event_name}": (EventConfig("{event_name}", "{event_def.method}", _globals.get("{event_def.name}", dict)) if _globals.get("{event_def.name}") else EventConfig("{event_name}", "{event_def.method}", dict)),\n' + # Extra events not in the CDDL spec + for extra_evt in enhancements.get("extra_events", []): + ek = extra_evt["event_key"] + be = extra_evt["bidi_event"] + ec = extra_evt["event_class"] + code += f' "{ek}": EventConfig("{ek}", "{be}", _globals.get("{ec}", dict)),\n' + code += "}\n" + + return code + + +class CddlParser: + """Parse CDDL specification files.""" + + def __init__(self, cddl_path: str): + """Initialize parser with CDDL file path.""" + self.cddl_path = Path(cddl_path) + self.content = "" + self.modules: Dict[str, CddlModule] = {} + self.definitions: Dict[str, str] = {} + self.event_names: Set[str] = set() # Names of definitions that are events + self._read_file() + + def _read_file(self) -> None: + """Read and preprocess CDDL file.""" + if not self.cddl_path.exists(): + raise FileNotFoundError(f"CDDL file not found: {self.cddl_path}") + + with open(self.cddl_path, "r", encoding="utf-8") as f: + self.content = f.read() + + logger.info(f"Loaded CDDL file: {self.cddl_path}") + + def parse(self) -> Dict[str, CddlModule]: + """Parse CDDL content and return modules.""" + # Remove comments + content = self._remove_comments(self.content) + + # Extract all definitions + self._extract_definitions(content) + + # Extract event names from event union definitions + self._extract_event_names() + + # Extract type definitions by module + self._extract_types() + + # Extract event definitions by module + self._extract_events() + + # Extract command definitions by module + self._extract_commands() + + # If no modules found, create a default one from the filename + if not self.modules: + module_name = self.cddl_path.stem + default_module = CddlModule(name=module_name) + self.modules[module_name] = default_module + logger.warning(f"No modules found in CDDL, creating default: {module_name}") + + return self.modules + + def _remove_comments(self, content: str) -> str: + """Remove comments from CDDL content.""" + # CDDL uses ; for comments to end of line + lines = content.split("\n") + cleaned = [] + for line in lines: + if ";" in line and not line.strip().startswith(";"): + line = line[: line.index(";")] + elif line.strip().startswith(";"): + continue + cleaned.append(line) + return "\n".join(cleaned) + + def _extract_definitions(self, content: str) -> None: + """Extract CDDL definitions (type definitions, commands, etc.).""" + # Match pattern: Name = Definition + # Handles multiline definitions properly + pattern = r"(\w+(?:\.\w+)*)\s*=\s*(.+?)(?=\n\w+(?:\.\w+)?\s*=|\Z)" + + for match in re.finditer(pattern, content, re.DOTALL): + name = match.group(1).strip() + definition = match.group(2).strip() + self.definitions[name] = definition + logger.debug(f"Extracted definition: {name}") + + def _extract_event_names(self) -> None: + """Extract event names from event union definitions. + + Event union definitions follow pattern: + module.ModuleEvent = ( + module.EventName1 // + module.EventName2 // + ... + ) + """ + # Look for definitions like "BrowsingContextEvent", "SessionEvent", etc. + event_union_pattern = re.compile(r"(\w+\.)?(\w+)Event") + + for def_name, def_content in self.definitions.items(): + # Check if this looks like an event union (name ends with "Event") and + # contains a module-qualified reference like "module.EventName". + # Handles both single-item (no //) and multi-item (// separated) unions. + if "Event" in def_name and re.search(r"\w+\.\w+", def_content): + # Extract event names from the union (works for single and multi-item) + event_refs = re.findall(r"(\w+\.\w+)", def_content) + for event_ref in event_refs: + self.event_names.add(event_ref) + logger.debug(f"Identified event: {event_ref} (from {def_name})") + + def _extract_types(self) -> None: + """Extract type definitions from parsed definitions.""" + # Type definitions follow pattern: module.TypeName = { field: type, ... } + # They have dots in the name and curly braces in the content + # But they DON'T have method: "..." pattern (which means it's not a command) + # Enums follow pattern: module.EnumName = "value1" / "value2" / ... + + for def_name, def_content in self.definitions.items(): + # Skip if not a namespaced name (e.g., skip "EmptyParams", "Extensible") + if "." not in def_name: + continue + + # Skip if it's a command (contains method: pattern) + if "method:" in def_content: + continue + + # Extract module.TypeName + if "." in def_name: + module_name, type_name = def_name.rsplit(".", 1) + + # Create module if not exists + if module_name not in self.modules: + self.modules[module_name] = CddlModule(name=module_name) + + # Check if this is an enum (string union with /) + if self._is_enum_definition(def_content): + # Extract enum values + values = self._extract_enum_values(def_content) + if values: + enum_def = CddlEnum( + module=module_name, + name=type_name, + values=values, + description=f"{type_name}", + ) + self.modules[module_name].enums.append(enum_def) + logger.debug( + f"Found enum: {def_name} with {len(values)} values" + ) + else: + # Extract fields from type definition + fields = self._extract_type_fields(def_content) + + if fields: # Only create type if it has fields + type_def = CddlTypeDefinition( + module=module_name, + name=type_name, + fields=fields, + description=f"{type_name}", + ) + self.modules[module_name].types.append(type_def) + logger.debug( + f"Found type: {def_name} with {len(fields)} fields" + ) + + def _is_enum_definition(self, definition: str) -> bool: + """Check if a definition is an enum (string union with /). + + Enums are defined as: "value1" / "value2" / "value3" + """ + # Clean whitespace + clean_def = definition.strip() + + # Must not have curly braces (that would be a type definition) + if "{" in clean_def or "}" in clean_def: + return False + + # Must contain the union operator / surrounded by quotes + # Pattern: "something" / "something_else" + return " / " in clean_def and '"' in clean_def + + def _extract_enum_values(self, enum_definition: str) -> List[str]: + """Extract individual values from an enum definition. + + Enums are defined as: "value1" / "value2" / "value3" + Can span multiple lines. + """ + values = [] + + # Clean the definition and extract quoted strings + # Split by / and extract quoted values + parts = enum_definition.split("/") + + for part in parts: + part = part.strip() + + # Extract quoted string - use search instead of match to find quotes anywhere + match = re.search(r'"([^"]*)"', part) + if match: + value = match.group(1) + values.append(value) + logger.debug(f"Extracted enum value: {value}") + + return values + + @staticmethod + def _normalize_cddl_type(field_type: str) -> str: + """Normalize a CDDL type expression to a simple Python-compatible form. + + Strips CDDL control operators (.ge, .le, .gt, .lt, .default, etc.) and + replaces interval/constraint expressions with their base types so that + the caller can safely check for nested struct syntax. + + Examples: + '(float .ge 0.0) .default 1.0' -> 'float' + '(float .ge 0.0) / null' -> 'float / null' + '(0.0...360.0) / null' -> 'float / null' + '-90.0..90.0' -> 'float' + 'float / null .default null' -> 'float / null' + """ + result = field_type + # Remove trailing .default annotations + result = re.sub(r"\s*\.default\s+\S+", "", result) + # Replace parenthesised constraint expressions: (baseType .operator ...) -> baseType + result = re.sub(r"\((\w+)\s+\.\w+[^)]*\)", r"\1", result) + # Replace parenthesised numeric interval types: (0.0...360.0) -> float + result = re.sub(r"\(-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?\)", "float", result) + # Replace bare numeric interval types: -90.0..90.0 -> float + result = re.sub(r"-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?", "float", result) + return result.strip() + + def _extract_type_fields(self, type_definition: str) -> Dict[str, str]: + """Extract fields from a type definition block.""" + fields = {} + + # Remove outer braces + clean_def = type_definition.strip() + if clean_def.startswith("{"): + clean_def = clean_def[1:] + if clean_def.endswith("}"): + clean_def = clean_def[:-1] + + # Parse each line for field: type patterns + for line in clean_def.split("\n"): + line = line.strip() + if not line or "Extensible" in line or line.startswith("//"): + continue + + # Match pattern: [?] fieldName: type + match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + if not match: + # Try without optional marker + match = re.match(r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + + if match: + field_name = match.group(1).strip() + field_type = match.group(2).strip() + normalized_type = self._normalize_cddl_type(field_type) + + # Skip lines that are part of nested definitions + if "{" not in normalized_type and "(" not in normalized_type: + fields[field_name] = normalized_type + logger.debug(f"Extracted field {field_name}: {normalized_type}") + + return fields + + def _extract_events(self) -> None: + """Extract event definitions from parsed definitions. + + Events are definitions that: + 1. Are listed in an event union (e.g., BrowsingContextEvent) + 2. Have method: "..." and params: ... fields + + Event pattern: module.EventName = (method: "module.eventName", params: module.ParamType) + """ + # Find definitions that are in the event_names set + event_pattern = re.compile( + r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" + ) + + for def_name, def_content in self.definitions.items(): + # Skip if not identified as an event + if def_name not in self.event_names: + continue + + # Extract method and params + match = event_pattern.search(def_content) + if match: + method = match.group(1) # e.g., "browsingContext.contextCreated" + params_type = match.group(2) # e.g., "browsingContext.Info" + + # Extract module name from method + if "." in method: + module_name, _ = method.split(".", 1) + + # Create module if not exists + if module_name not in self.modules: + self.modules[module_name] = CddlModule(name=module_name) + + # Extract event name from definition name (e.g., browsingContext.ContextCreated) + _, event_name = def_name.rsplit(".", 1) + + # Create event + event = CddlEvent( + module=module_name, + name=event_name, + method=method, + params_type=params_type, + description=f"Event: {method}", + ) + + self.modules[module_name].events.append(event) + logger.debug( + f"Found event: {def_name} (method={method}, params={params_type})" + ) + + def _extract_commands(self) -> None: + """Extract command definitions from parsed definitions.""" + # Find command definitions that follow pattern: module.Command = (method: "...", params: ...) + command_pattern = re.compile( + r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" + ) + + for def_name, def_content in self.definitions.items(): + # Skip definitions that are events (they share the same pattern) + if def_name in self.event_names: + continue + matches = list(command_pattern.finditer(def_content)) + if matches: + for match in matches: + method = match.group(1) # e.g., "session.new" + params_type = match.group(2) # e.g., "session.NewParameters" + + # Extract module name from method + if "." in method: + module_name, command_name = method.split(".", 1) + + # Create module if not exists + if module_name not in self.modules: + self.modules[module_name] = CddlModule(name=module_name) + + # Extract parameters + params = self._extract_parameters(params_type) + + # Create command + cmd = CddlCommand( + module=module_name, + name=command_name, + params=params, + description=f"Execute {method}", + ) + + self.modules[module_name].commands.append(cmd) + logger.debug( + f"Found command: {method} with params {params_type}" + ) + + def _extract_parameters( + self, params_type: str, _seen: Optional[Set[str]] = None + ) -> Dict[str, str]: + """Extract parameters from a parameter type definition. + + Handles both struct types ({...}) and top-level union types (TypeA / TypeB), + merging all fields from each alternative as optional parameters. + """ + params = {} + + if _seen is None: + _seen = set() + if params_type in _seen: + return params + _seen.add(params_type) + + if params_type not in self.definitions: + logger.debug(f"Parameter type not found: {params_type}") + return params + + definition = self.definitions[params_type] + + # Handle top-level type alias that is a union of other named types: + # e.g. session.UnsubscribeByAttributesRequest / session.UnsubscribeByIDRequest + # These definitions contain a single line with "/" separating type names + # (not the double-slash "//" used for command unions). + stripped = definition.strip() + if not stripped.startswith("{") and "/" in stripped and "//" not in stripped: + # Each token separated by "/" should be a named type reference + alternatives = [a.strip() for a in stripped.split("/") if a.strip()] + all_named = all(re.match(r"^[\w.]+$", a) for a in alternatives) + if all_named: + for alt_type in alternatives: + alt_params = self._extract_parameters(alt_type, _seen) + params.update(alt_params) + return params + + # Remove the outer curly braces and split by comma + # Then parse each line for key: type patterns + clean_def = stripped + if clean_def.startswith("{"): + clean_def = clean_def[1:] + if clean_def.endswith("}"): + clean_def = clean_def[:-1] + + # Split by newlines and process each line + for line in clean_def.split("\n"): + line = line.strip() + if not line or "Extensible" in line: + continue + + # Match pattern: [?] name: type + # Using a simple pattern that handles optional prefix + match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + if not match: + # Try without optional marker + match = re.match(r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + + if match: + param_name = match.group(1).strip() + param_type = match.group(2).strip() + normalized_type = self._normalize_cddl_type(param_type) + + # Skip lines that are part of nested definitions + if "{" not in normalized_type and "(" not in normalized_type: + params[param_name] = normalized_type + logger.debug( + f"Extracted param {param_name}: {normalized_type} from {params_type}" + ) + + return params + + +def module_name_to_class_name(module_name: str) -> str: + """Convert module name to class name (PascalCase). + + Handles both camelCase (browsingContext) and snake_case (browsing_context). + """ + if "_" in module_name: + # Snake_case: browsing_context -> BrowsingContext + return "".join(word.capitalize() for word in module_name.split("_")) + else: + # CamelCase: browsingContext -> BrowsingContext + return module_name[0].upper() + module_name[1:] if module_name else "" + + +def module_name_to_filename(module_name: str) -> str: + """Convert module name to Python filename (snake_case). + + Handles both camelCase (browsingContext) and snake_case (browsing_context). + Special cases: + - browsingContext -> browsing_context + - webExtension -> webextension + """ + # Handle explicit mappings for known camelCase names + camel_to_snake_map = { + "browsingContext": "browsing_context", + "webExtension": "webextension", + } + + if module_name in camel_to_snake_map: + return camel_to_snake_map[module_name] + + if "_" in module_name: + # Already snake_case + return module_name + else: + # Convert camelCase to snake_case for other cases + # This handles cases like "myModuleName" -> "my_module_name" + import re + + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", module_name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + + +def generate_init_file(output_path: Path, modules: Dict[str, CddlModule]) -> None: + """Generate __init__.py file for the module.""" + init_path = output_path / "__init__.py" + + code = f"""{SHARED_HEADER} + +from __future__ import annotations + +""" + + for module_name in sorted(modules.keys()): + class_name = module_name_to_class_name(module_name) + filename = module_name_to_filename(module_name) + code += f"from .{filename} import {class_name}\n" + + code += f"\n__all__ = [\n" + for module_name in sorted(modules.keys()): + class_name = module_name_to_class_name(module_name) + code += f' "{class_name}",\n' + code += "]\n" + + with open(init_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {init_path}") + + +def generate_common_file(output_path: Path) -> None: + """Generate common.py file with shared utilities.""" + common_path = output_path / "common.py" + + code = ( + "# Licensed to the Software Freedom Conservancy (SFC) under one\n" + "# or more contributor license agreements. See the NOTICE file\n" + "# distributed with this work for additional information\n" + "# regarding copyright ownership. The SFC licenses this file\n" + "# to you under the Apache License, Version 2.0 (the\n" + '# "License"); you may not use this file except in compliance\n' + "# with the License. You may obtain a copy of the License at\n" + "#\n" + "# http://www.apache.org/licenses/LICENSE-2.0\n" + "#\n" + "# Unless required by applicable law or agreed to in writing,\n" + "# software distributed under the License is distributed on an\n" + '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' + "# KIND, either express or implied. See the License for the\n" + "# specific language governing permissions and limitations\n" + "# under the License.\n" + "\n" + '"""Common utilities for BiDi command construction."""\n' + "\n" + "from typing import Any, Dict, Generator\n" + "\n" + "\n" + "def command_builder(\n" + " method: str, params: Dict[str, Any]\n" + ") -> Generator[Dict[str, Any], Any, Any]:\n" + ' """Build a BiDi command generator.\n' + "\n" + " Args:\n" + ' method: The BiDi method name (e.g., "session.status", "browser.close")\n' + " params: The parameters for the command\n" + "\n" + " Yields:\n" + " A dictionary representing the BiDi command\n" + "\n" + " Returns:\n" + " The result from the BiDi command execution\n" + ' """\n' + ' result = yield {"method": method, "params": params}\n' + " return result\n" + ) + + with open(common_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {common_path}") + + +def generate_console_file(output_path: Path) -> None: + """Generate console.py file with Console enum helper.""" + console_path = output_path / "console.py" + + code = ( + "# Licensed to the Software Freedom Conservancy (SFC) under one\n" + "# or more contributor license agreements. See the NOTICE file\n" + "# distributed with this work for additional information\n" + "# regarding copyright ownership. The SFC licenses this file\n" + "# to you under the Apache License, Version 2.0 (the\n" + '# "License"); you may not use this file except in compliance\n' + "# with the License. You may obtain a copy of the License at\n" + "#\n" + "# http://www.apache.org/licenses/LICENSE-2.0\n" + "#\n" + "# Unless required by applicable law or agreed to in writing,\n" + "# software distributed under the License is distributed on an\n" + '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' + "# KIND, either express or implied. See the License for the\n" + "# specific language governing permissions and limitations\n" + "# under the License.\n" + "\n" + "from enum import Enum\n" + "\n" + "\n" + "class Console(Enum):\n" + ' ALL = "all"\n' + ' LOG = "log"\n' + ' ERROR = "error"\n' + ) + + with open(console_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {console_path}") + + +def generate_permissions_file(output_path: Path) -> None: + """Generate permissions.py file with permission-related classes.""" + permissions_path = output_path / "permissions.py" + + code = ( + "# Licensed to the Software Freedom Conservancy (SFC) under one\n" + "# or more contributor license agreements. See the NOTICE file\n" + "# distributed with this work for additional information\n" + "# regarding copyright ownership. The SFC licenses this file\n" + "# to you under the Apache License, Version 2.0 (the\n" + '# "License"); you may not use this file except in compliance\n' + "# with the License. You may obtain a copy of the License at\n" + "#\n" + "# http://www.apache.org/licenses/LICENSE-2.0\n" + "#\n" + "# Unless required by applicable law or agreed to in writing,\n" + "# software distributed under the License is distributed on an\n" + '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' + "# KIND, either express or implied. See the License for the\n" + "# specific language governing permissions and limitations\n" + "# under the License.\n" + "\n" + '"""WebDriver BiDi Permissions module."""\n' + "\n" + "from __future__ import annotations\n" + "\n" + "from enum import Enum\n" + "from typing import Any, Optional, Union\n" + "\n" + "from .common import command_builder\n" + "\n" + '_VALID_PERMISSION_STATES = {"granted", "denied", "prompt"}\n' + "\n" + "\n" + "class PermissionState(str, Enum):\n" + ' """Permission state enumeration."""\n' + "\n" + ' GRANTED = "granted"\n' + ' DENIED = "denied"\n' + ' PROMPT = "prompt"\n' + "\n" + "\n" + "class PermissionDescriptor:\n" + ' """Descriptor for a permission."""\n' + "\n" + " def __init__(self, name: str) -> None:\n" + ' """Initialize a PermissionDescriptor.\n' + "\n" + " Args:\n" + " name: The name of the permission (e.g., 'geolocation', 'microphone', 'camera')\n" + ' """\n' + " self.name = name\n" + "\n" + " def __repr__(self) -> str:\n" + " return f\"PermissionDescriptor('{self.name}')\"\n" + "\n" + "\n" + "class Permissions:\n" + ' """WebDriver BiDi Permissions module."""\n' + "\n" + " def __init__(self, websocket_connection: Any) -> None:\n" + ' """Initialize the Permissions module.\n' + "\n" + " Args:\n" + " websocket_connection: The WebSocket connection for sending BiDi commands\n" + ' """\n' + " self._conn = websocket_connection\n" + "\n" + " def set_permission(\n" + " self,\n" + " descriptor: Union[PermissionDescriptor, str],\n" + " state: Union[PermissionState, str],\n" + " origin: Optional[str] = None,\n" + " user_context: Optional[str] = None,\n" + " ) -> None:\n" + ' """Set a permission for a given origin.\n' + "\n" + " Args:\n" + " descriptor: The permission descriptor or permission name as a string\n" + " state: The desired permission state\n" + " origin: The origin for which to set the permission\n" + " user_context: Optional user context ID to scope the permission\n" + "\n" + " Raises:\n" + " ValueError: If the state is not a valid permission state\n" + ' """\n' + " state_value = state.value if isinstance(state, PermissionState) else state\n" + " if state_value not in _VALID_PERMISSION_STATES:\n" + " raise ValueError(\n" + ' f"Invalid permission state: {state_value!r}. "\n' + ' f"Must be one of {sorted(_VALID_PERMISSION_STATES)}"\n' + " )\n" + "\n" + " if isinstance(descriptor, str):\n" + ' descriptor_dict = {"name": descriptor}\n' + " else:\n" + ' descriptor_dict = {"name": descriptor.name}\n' + "\n" + " params: dict[str, Any] = {\n" + ' "descriptor": descriptor_dict,\n' + ' "state": state_value,\n' + " }\n" + " if origin is not None:\n" + ' params["origin"] = origin\n' + " if user_context is not None:\n" + ' params["userContext"] = user_context\n' + "\n" + ' cmd = command_builder("permissions.setPermission", params)\n' + " self._conn.execute(cmd)\n" + ) + + with open(permissions_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {permissions_path}") + + +def main( + cddl_file: str, + output_dir: str, + spec_version: str = "1.0", + enhancements_manifest: Optional[str] = None, +) -> None: + """Main entry point. + + Args: + cddl_file: Path to CDDL specification file + output_dir: Output directory for generated modules + spec_version: BiDi spec version + enhancements_manifest: Path to enhancement manifest Python file + """ + output_path = Path(output_dir).resolve() + output_path.mkdir(parents=True, exist_ok=True) + + logger.info(f"WebDriver BiDi Code Generator v{__version__}") + logger.info(f"Input CDDL: {cddl_file}") + logger.info(f"Output directory: {output_path}") + logger.info(f"Spec version: {spec_version}") + + # Load enhancement manifest + manifest = load_enhancements_manifest(enhancements_manifest) + if manifest: + logger.info(f"Loaded enhancement manifest from: {enhancements_manifest}") + + # Parse CDDL + parser = CddlParser(cddl_file) + modules = parser.parse() + + logger.info(f"Parsed {len(modules)} modules") + + # Clean up existing generated files + for file_path in output_path.glob("*.py"): + if file_path.name != "py.typed" and not file_path.name.startswith("_"): + file_path.unlink() + logger.debug(f"Removed: {file_path}") + + # Generate module files using snake_case filenames + for module_name, module in sorted(modules.items()): + filename = module_name_to_filename(module_name) + module_path = output_path / f"{filename}.py" + + # Get module-specific enhancements (merge with dataclass templates) + module_enhancements = manifest.get("enhancements", {}).get(module_name, {}) + + # Add dataclass methods and docstrings to the enhancement data for this module + full_module_enhancements = { + **module_enhancements, + "dataclass_methods": manifest.get("dataclass_methods", {}), + "method_docstrings": manifest.get("method_docstrings", {}), + } + + with open(module_path, "w", encoding="utf-8") as f: + f.write(module.generate_code(full_module_enhancements)) + logger.info(f"Generated: {module_path}") + + # Generate __init__.py + generate_init_file(output_path, modules) + + # Generate common.py + generate_common_file(output_path) + + # Generate permissions.py + generate_permissions_file(output_path) + + # Generate console.py + generate_console_file(output_path) + + # Create py.typed marker + py_typed_path = output_path / "py.typed" + py_typed_path.touch() + logger.info(f"Generated type marker: {py_typed_path}") + + logger.info("Code generation complete!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate Python WebDriver BiDi modules from CDDL specification" + ) + parser.add_argument( + "cddl_file", + help="Path to CDDL specification file", + ) + parser.add_argument( + "output_dir", + help="Output directory for generated Python modules", + ) + parser.add_argument( + "--version", + default="1.0", + help="BiDi spec version (default: 1.0)", + ) + parser.add_argument( + "--enhancements-manifest", + default=None, + help="Path to enhancement manifest Python file (optional)", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Enable verbose logging", + ) + + args = parser.parse_args() + + if args.verbose: + logging.getLogger("generate_bidi").setLevel(logging.DEBUG) + + try: + main( + args.cddl_file, + args.output_dir, + args.version, + args.enhancements_manifest, + ) + sys.exit(0) + except Exception as e: + logger.error(f"Generation failed: {e}", exc_info=True) + sys.exit(1) diff --git a/py/private/BUILD.bazel b/py/private/BUILD.bazel index 8b02ac341a0dc..88acc9d2aba11 100644 --- a/py/private/BUILD.bazel +++ b/py/private/BUILD.bazel @@ -1,5 +1,10 @@ load("@rules_python//python:defs.bzl", "py_binary") +exports_files([ + "bidi_enhancements_manifest.py", + "cdp.py", +]) + py_binary( name = "untar", srcs = [ diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py new file mode 100644 index 0000000000000..ae7229f6ddebd --- /dev/null +++ b/py/private/bidi_enhancements_manifest.py @@ -0,0 +1,1557 @@ +""" +Enhancement manifest for BiDi code generation. + +This file defines custom enhancements applied to generated BiDi modules, +including custom dataclass methods, parameter validation/transformation, +response deserialization, and field extraction. + +All code must be compatible with Python 3.10+. +""" + +from __future__ import annotations + +from typing import Any + +# ============================================================================ +# Format Guide +# ============================================================================ +# Each module in ENHANCEMENTS specifies enhancement rules for methods: +# +# 'module_name': { +# 'method_name': { +# 'dataclass_methods': { # For dataclass enhancements +# 'ClassName': ['method1', 'method2', ...] +# }, +# 'preprocess': { # Pre-processing on parameters +# 'param_name': 'check_serialize_method' +# }, +# 'deserialize': { # Deserialize response to typed objects +# 'response_field': 'TypeName', +# }, +# 'extract_field': str, # Extract nested field from response +# 'extract_property': str, # Extract property from extracted items +# 'validate': str, # Validation function name +# 'transform': str, # Transformation function name +# } +# } +# ============================================================================ + +ENHANCEMENTS: dict[str, dict[str, Any]] = { + "browser": { + # Dataclass custom methods + "__dataclass_methods__": { + "ClientWindowInfo": [ + "get_client_window", + "get_state", + "get_width", + "get_height", + "is_active", + "get_x", + "get_y", + ], + }, + # Method enhancements + "create_user_context": { + "preprocess": { + "proxy": "check_serialize_method", + "unhandled_prompt_behavior": "check_serialize_method", + }, + "extract_field": "userContext", + }, + "get_client_windows": { + "deserialize": { + "clientWindows": "ClientWindowInfo", + }, + }, + "get_user_contexts": { + "extract_field": "userContexts", + "extract_property": "userContext", + }, + "set_download_behavior": { + "params_override": { + "allowed": "bool", + "destination_folder": "str", + "userContexts": "[*browser.UserContext]", + }, + "validate": "validate_download_behavior", + "transform": { + "allowed": "allowed", + "destination_folder": "destination_folder", + "func": "transform_download_params", + "result_param": "download_behavior", + }, + }, + # Override the generator-produced set_download_behavior so that + # downloadBehavior is never stripped by the generic None filter. + # The BiDi spec marks it as required (can be null, but must be present). + "extra_methods": [ + ''' def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): + """Set the download behavior for the browser. + + Args: + allowed: ``True`` to allow downloads, ``False`` to deny, or ``None`` + to reset to browser default (sends ``null`` to the protocol). + destination_folder: Destination folder for downloads. Required when + ``allowed=True``. Accepts a string or :class:`pathlib.Path`. + user_contexts: Optional list of user context IDs. + + Raises: + ValueError: If *allowed* is ``True`` and *destination_folder* is + omitted, or ``False`` and *destination_folder* is provided. + """ + validate_download_behavior( + allowed=allowed, + destination_folder=destination_folder, + user_contexts=user_contexts, + ) + download_behavior = transform_download_params(allowed, destination_folder) + # downloadBehavior is a REQUIRED field in the BiDi spec (can be null but + # must be present). Do NOT use a generic None-filter on it. + params: dict = {"downloadBehavior": download_behavior} + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("browser.setDownloadBehavior", params) + return self._conn.execute(cmd)''', + ], + }, + "browsingContext": { + # Method enhancements + "create": { + "extract_field": "context", + }, + "get_tree": { + "extract_field": "contexts", + "deserialize": { + "contexts": "Info", + }, + }, + "capture_screenshot": { + "extract_field": "data", + "params_override": { + "context": "str", + "format": "ImageFormat", + "clip": "BoxClipRectangle", + "origin": "str", + }, + }, + "print": { + "extract_field": "data", + }, + "locate_nodes": { + "extract_field": "nodes", + "params_override": { + "context": "str", + "locator": "dict", + "serializationOptions": "dict", + "startNodes": "list", + "maxNodeCount": "int", + }, + }, + "set_viewport": { + "params_override": { + "context": "str", + "viewport": "dict", + "userContexts": "list", + "devicePixelRatio": "float", + }, + }, + # Non-CDDL download event dataclasses (Chromium-specific) + "extra_dataclasses": [ + '''@dataclass +class DownloadWillBeginParams: + """DownloadWillBeginParams.""" + + suggested_filename: str | None = None''', + '''@dataclass +class DownloadCanceledParams: + """DownloadCanceledParams.""" + + status: Any | None = None''', + '''@dataclass +class DownloadParams: + """DownloadParams - fields shared by all download end event variants.""" + + status: str | None = None + context: Any | None = None + navigation: Any | None = None + timestamp: Any | None = None + url: str | None = None + filepath: str | None = None''', + '''@dataclass +class DownloadEndParams: + """DownloadEndParams - params for browsingContext.downloadEnd event.""" + + download_params: "DownloadParams | None" = None + + @classmethod + def from_json(cls, params: dict) -> "DownloadEndParams": + """Deserialize from BiDi wire-level params dict.""" + dp = DownloadParams( + status=params.get("status"), + context=params.get("context"), + navigation=params.get("navigation"), + timestamp=params.get("timestamp"), + url=params.get("url"), + filepath=params.get("filepath"), + ) + return cls(download_params=dp)''', + ], + # Non-CDDL download events (Chromium-specific, not in the BiDi spec) + "extra_events": [ + { + "event_key": "download_will_begin", + "bidi_event": "browsingContext.downloadWillBegin", + "event_class": "DownloadWillBeginParams", + }, + { + "event_key": "download_end", + "bidi_event": "browsingContext.downloadEnd", + "event_class": "DownloadEndParams", + }, + ], + }, + "log": { + # Make LogLevel an alias for Level so existing code using LogLevel works + "aliases": {"LogLevel": "Level"}, + # Replace the minimal CDDL-generated versions with richer ones that have from_json + "exclude_types": ["JavascriptLogEntry"], + "extra_dataclasses": [ + '''@dataclass +class ConsoleLogEntry: + """ConsoleLogEntry - a console log entry from the browser.""" + + type_: str | None = None + method: str | None = None + args: list | None = None + level: Any | None = None + text: Any | None = None + source: Any | None = None + timestamp: Any | None = None + stack_trace: Any | None = None + + @classmethod + def from_json(cls, params: dict) -> "ConsoleLogEntry": + """Deserialize from BiDi params dict.""" + return cls( + type_=params.get("type"), + method=params.get("method"), + args=params.get("args"), + level=params.get("level"), + text=params.get("text"), + source=params.get("source"), + timestamp=params.get("timestamp"), + stack_trace=params.get("stackTrace"), + )''', + '''@dataclass +class JavascriptLogEntry: + """JavascriptLogEntry - a JavaScript error log entry from the browser.""" + + type_: str | None = None + level: Any | None = None + text: Any | None = None + source: Any | None = None + timestamp: Any | None = None + stacktrace: Any | None = None + + @classmethod + def from_json(cls, params: dict) -> "JavascriptLogEntry": + """Deserialize from BiDi params dict.""" + return cls( + type_=params.get("type"), + level=params.get("level"), + text=params.get("text"), + source=params.get("source"), + timestamp=params.get("timestamp"), + stacktrace=params.get("stackTrace"), + )''', + ], + }, + "emulation": { + "extra_methods": [ + ''' def set_geolocation_override( + self, + coordinates=None, + error=None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setGeolocationOverride. + + Sets or clears the geolocation override for specified browsing or user contexts. + + Args: + coordinates: A GeolocationCoordinates instance (or dict) to override the + position, or ``None`` to clear a previously-set override. + error: A GeolocationPositionError instance (or dict) to simulate a + position-unavailable error. Mutually exclusive with *coordinates*. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + params = {} + if coordinates is not None: + if isinstance(coordinates, dict): + coords_dict = coordinates + else: + coords_dict = {} + if coordinates.latitude is not None: + coords_dict["latitude"] = coordinates.latitude + if coordinates.longitude is not None: + coords_dict["longitude"] = coordinates.longitude + if coordinates.accuracy is not None: + coords_dict["accuracy"] = coordinates.accuracy + if coordinates.altitude is not None: + coords_dict["altitude"] = coordinates.altitude + if coordinates.altitude_accuracy is not None: + coords_dict["altitudeAccuracy"] = coordinates.altitude_accuracy + if coordinates.heading is not None: + coords_dict["heading"] = coordinates.heading + if coordinates.speed is not None: + coords_dict["speed"] = coordinates.speed + params["coordinates"] = coords_dict + if error is not None: + if isinstance(error, dict): + params["error"] = error + else: + params["error"] = { + "type": error.type if error.type is not None else "positionUnavailable" + } + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setGeolocationOverride", params) + result = self._conn.execute(cmd) + return result''', + ''' def set_timezone_override( + self, + timezone=None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setTimezoneOverride. + + Sets or clears the timezone override for specified browsing or user contexts. + Pass ``timezone=None`` (or omit it) to clear a previously-set override. + + Args: + timezone: IANA timezone string (e.g. ``"America/New_York"``) or ``None`` + to clear the override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + params = {"timezone": timezone} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setTimezoneOverride", params) + return self._conn.execute(cmd)''', + ''' def set_scripting_enabled( + self, + enabled=None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setScriptingEnabled. + + Enables or disables scripting for specified browsing or user contexts. + Pass ``enabled=None`` to restore the default behaviour. + + Args: + enabled: ``True`` to enable scripting, ``False`` to disable it, or + ``None`` to clear the override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + params = {"enabled": enabled} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setScriptingEnabled", params) + return self._conn.execute(cmd)''', + ''' def set_user_agent_override( + self, + user_agent=None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setUserAgentOverride. + + Overrides the User-Agent string for specified browsing or user contexts. + Pass ``user_agent=None`` to clear a previously-set override. + + Args: + user_agent: Custom User-Agent string, or ``None`` to clear the override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + params = {"userAgent": user_agent} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setUserAgentOverride", params) + return self._conn.execute(cmd)''', + ''' def set_screen_orientation_override( + self, + screen_orientation=None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setScreenOrientationOverride. + + Sets or clears the screen orientation override for specified browsing or + user contexts. + + Args: + screen_orientation: A :class:`ScreenOrientation` instance (or dict with + ``natural`` and ``type`` keys) to lock the orientation, or ``None`` + to clear a previously-set override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + if screen_orientation is None: + so_value = None + elif isinstance(screen_orientation, dict): + so_value = screen_orientation + else: + natural = screen_orientation.natural + orientation_type = screen_orientation.type + so_value = { + "natural": natural.lower() if isinstance(natural, str) else natural, + "type": orientation_type.lower() if isinstance(orientation_type, str) else orientation_type, + } + params = {"screenOrientation": so_value} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setScreenOrientationOverride", params) + return self._conn.execute(cmd)''', + ''' def set_network_conditions( + self, + network_conditions=None, + offline: bool | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setNetworkConditions. + + Sets or clears network condition emulation for specified browsing or user + contexts. + + Args: + network_conditions: A dict with the raw ``networkConditions`` value + (e.g. ``{"type": "offline"}``), or ``None`` to clear the override. + Mutually exclusive with *offline*. + offline: Convenience bool — ``True`` sets offline conditions, + ``False`` clears them (sends ``null``). When provided, this takes + precedence over *network_conditions*. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + if offline is not None: + nc_value = {"type": "offline"} if offline else None + else: + nc_value = network_conditions + params = {"networkConditions": nc_value} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setNetworkConditions", params) + return self._conn.execute(cmd)''', + ], + }, + "script": { + "extra_methods": [ + ''' def execute(self, function_declaration: str, *args, context_id: str | None = None) -> Any: + """Execute a function declaration in the browser context. + + Args: + function_declaration: The function as a string, e.g. ``"() => document.title"``. + *args: Optional Python values to pass as arguments to the function. + Each value is serialised to a BiDi ``LocalValue`` automatically. + Supported types: ``None``, ``bool``, ``int``, ``float`` + (including ``NaN`` and ``Infinity``), ``str``, ``list``, + ``dict``, and ``datetime.datetime``. + context_id: The browsing context ID to run in. Defaults to the + driver\'s current window handle when a driver was provided. + + Returns: + The inner RemoteValue result dict, or raises WebDriverException on exception. + """ + import math as _math + import datetime as _datetime + from selenium.common.exceptions import WebDriverException as _WebDriverException + + def _serialize_arg(value): + """Serialise a Python value to a BiDi LocalValue dict.""" + if value is None: + return {"type": "null"} + if isinstance(value, bool): + return {"type": "boolean", "value": value} + if isinstance(value, _datetime.datetime): + return {"type": "date", "value": value.isoformat()} + if isinstance(value, float): + if _math.isnan(value): + return {"type": "number", "value": "NaN"} + if _math.isinf(value): + return {"type": "number", "value": "Infinity" if value > 0 else "-Infinity"} + return {"type": "number", "value": value} + if isinstance(value, int): + _MAX_SAFE_INT = 9007199254740991 + if abs(value) > _MAX_SAFE_INT: + return {"type": "bigint", "value": str(value)} + return {"type": "number", "value": value} + if isinstance(value, str): + return {"type": "string", "value": value} + if isinstance(value, list): + return {"type": "array", "value": [_serialize_arg(v) for v in value]} + if isinstance(value, dict): + return {"type": "object", "value": [[str(k), _serialize_arg(v)] for k, v in value.items()]} + return value + + if context_id is None and self._driver is not None: + try: + context_id = self._driver.current_window_handle + except Exception: + pass + target = {"context": context_id} if context_id else {} + serialized_args = [_serialize_arg(a) for a in args] if args else None + raw = self.call_function( + function_declaration=function_declaration, + await_promise=True, + target=target, + arguments=serialized_args, + ) + if isinstance(raw, dict): + if raw.get("type") == "exception": + exc = raw.get("exceptionDetails", {}) + msg = exc.get("text", str(exc)) if isinstance(exc, dict) else str(exc) + raise _WebDriverException(msg) + if raw.get("type") == "success": + return raw.get("result") + return raw''', + ''' def _add_preload_script(self, function_declaration, arguments=None, contexts=None, user_contexts=None, sandbox=None): + """Add a preload script with validation. + + Args: + function_declaration: The JS function to run on page load. + arguments: Optional list of BiDi arguments. + contexts: Optional list of browsing context IDs. + user_contexts: Optional list of user context IDs. + sandbox: Optional sandbox name. + + Returns: + script_id: The ID of the added preload script (str). + + Raises: + ValueError: If both contexts and user_contexts are specified. + """ + if contexts is not None and user_contexts is not None: + raise ValueError("Cannot specify both contexts and user_contexts") + result = self.add_preload_script( + function_declaration=function_declaration, + arguments=arguments, + contexts=contexts, + user_contexts=user_contexts, + sandbox=sandbox, + ) + if isinstance(result, dict): + return result.get("script") + return result''', + ''' def _remove_preload_script(self, script_id): + """Remove a preload script by ID. + + Args: + script_id: The ID of the preload script to remove. + """ + return self.remove_preload_script(script=script_id)''', + ''' def pin(self, function_declaration): + """Pin (add) a preload script that runs on every page load. + + Args: + function_declaration: The JS function to execute on page load. + + Returns: + script_id: The ID of the pinned script (str). + """ + return self._add_preload_script(function_declaration)''', + ''' def unpin(self, script_id): + """Unpin (remove) a previously pinned preload script. + + Args: + script_id: The ID returned by pin(). + """ + return self._remove_preload_script(script_id=script_id)''', + ''' def _evaluate(self, expression, target, await_promise, result_ownership=None, serialization_options=None, user_activation=None): + """Evaluate a script expression and return a structured result. + + Args: + expression: The JavaScript expression to evaluate. + target: A dict like {"context": } or {"realm": }. + await_promise: Whether to await a returned promise. + result_ownership: Optional result ownership setting. + serialization_options: Optional serialization options dict. + user_activation: Optional user activation flag. + + Returns: + An object with .realm, .result (dict or None), and .exception_details (or None). + """ + class _EvalResult: + def __init__(self2, realm, result, exception_details): + self2.realm = realm + self2.result = result + self2.exception_details = exception_details + + raw = self.evaluate( + expression=expression, + target=target, + await_promise=await_promise, + result_ownership=result_ownership, + serialization_options=serialization_options, + user_activation=user_activation, + ) + if isinstance(raw, dict): + realm = raw.get("realm") + if raw.get("type") == "exception": + exc = raw.get("exceptionDetails") + return _EvalResult(realm=realm, result=None, exception_details=exc) + return _EvalResult(realm=realm, result=raw.get("result"), exception_details=None) + return _EvalResult(realm=None, result=raw, exception_details=None)''', + ''' def _call_function(self, function_declaration, await_promise, target, arguments=None, result_ownership=None, this=None, user_activation=None, serialization_options=None): + """Call a function and return a structured result. + + Args: + function_declaration: The JS function string. + await_promise: Whether to await the return value. + target: A dict like {"context": }. + arguments: Optional list of BiDi arguments. + result_ownership: Optional result ownership. + this: Optional \'this\' binding. + user_activation: Optional user activation flag. + serialization_options: Optional serialization options dict. + + Returns: + An object with .result (dict or None) and .exception_details (or None). + """ + class _CallResult: + def __init__(self2, result, exception_details): + self2.result = result + self2.exception_details = exception_details + + raw = self.call_function( + function_declaration=function_declaration, + await_promise=await_promise, + target=target, + arguments=arguments, + result_ownership=result_ownership, + this=this, + user_activation=user_activation, + serialization_options=serialization_options, + ) + if isinstance(raw, dict): + if raw.get("type") == "exception": + exc = raw.get("exceptionDetails") + return _CallResult(result=None, exception_details=exc) + if raw.get("type") == "success": + return _CallResult(result=raw.get("result"), exception_details=None) + return _CallResult(result=raw, exception_details=None)''', + ''' def _get_realms(self, context=None, type=None): + """Get all realms, optionally filtered by context and type. + + Args: + context: Optional browsing context ID to filter by. + type: Optional realm type string to filter by (e.g. RealmType.WINDOW). + + Returns: + List of realm info objects with .realm, .origin, .type, .context attributes. + """ + class _RealmInfo: + def __init__(self2, realm, origin, type_, context): + self2.realm = realm + self2.origin = origin + self2.type = type_ + self2.context = context + + raw = self.get_realms(context=context, type=type) + realms_list = raw.get("realms", []) if isinstance(raw, dict) else [] + result = [] + for r in realms_list: + if isinstance(r, dict): + result.append(_RealmInfo( + realm=r.get("realm"), + origin=r.get("origin"), + type_=r.get("type"), + context=r.get("context"), + )) + return result''', + ''' def _disown(self, handles, target): + """Disown handles in a browsing context. + + Args: + handles: List of handle strings to disown. + target: A dict like {"context": }. + """ + return self.disown(handles=handles, target=target)''', + ''' def _subscribe_log_entry(self, callback, entry_type_filter=None): + """Subscribe to log.entryAdded BiDi events with optional type filtering.""" + import threading as _threading + from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod + + bidi_event = "log.entryAdded" + + if not hasattr(self, "_log_subscriptions"): + self._log_subscriptions = {} + self._log_lock = _threading.Lock() + + def _deserialize(params): + t = params.get("type") if isinstance(params, dict) else None + if t == "console": + cls = getattr(_log_mod, "ConsoleLogEntry", None) + if cls is not None and hasattr(cls, "from_json"): + try: + return cls.from_json(params) + except Exception: + pass + elif t == "javascript": + cls = getattr(_log_mod, "JavascriptLogEntry", None) + if cls is not None and hasattr(cls, "from_json"): + try: + return cls.from_json(params) + except Exception: + pass + return params + + def _wrapped(raw): + entry = _deserialize(raw) + if entry_type_filter is None: + callback(entry) + else: + t = getattr(entry, "type_", None) or ( + entry.get("type") if isinstance(entry, dict) else None + ) + if t == entry_type_filter: + callback(entry) + + class _BidiRef: + event_class = bidi_event + + def from_json(self2, p): + return p + + _wrapper = _BidiRef() + callback_id = self._conn.add_callback(_wrapper, _wrapped) + with self._log_lock: + if bidi_event not in self._log_subscriptions: + session = _Session(self._conn) + result = session.subscribe([bidi_event]) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self._log_subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + self._log_subscriptions[bidi_event]["callbacks"].append(callback_id) + return callback_id''', + ''' def _unsubscribe_log_entry(self, callback_id): + """Unsubscribe a log entry callback by ID.""" + from selenium.webdriver.common.bidi.session import Session as _Session + + bidi_event = "log.entryAdded" + if not hasattr(self, "_log_subscriptions"): + return + + class _BidiRef: + event_class = bidi_event + + def from_json(self2, p): + return p + + _wrapper = _BidiRef() + self._conn.remove_callback(_wrapper, callback_id) + with self._log_lock: + entry = self._log_subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + if entry is not None and not entry["callbacks"]: + session = _Session(self._conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self._log_subscriptions[bidi_event]''', + ''' def add_console_message_handler(self, callback: Callable) -> int: + """Add a handler for console log messages (log.entryAdded type=console). + + Args: + callback: Function called with a ConsoleLogEntry on each console message. + + Returns: + callback_id for use with remove_console_message_handler. + """ + return self._subscribe_log_entry(callback, entry_type_filter="console")''', + ''' def remove_console_message_handler(self, callback_id: int) -> None: + """Remove a console message handler by callback ID.""" + self._unsubscribe_log_entry(callback_id)''', + ''' def add_javascript_error_handler(self, callback: Callable) -> int: + """Add a handler for JavaScript error log messages (log.entryAdded type=javascript). + + Args: + callback: Function called with a JavascriptLogEntry on each JS error. + + Returns: + callback_id for use with remove_javascript_error_handler. + """ + return self._subscribe_log_entry(callback, entry_type_filter="javascript")''', + ''' def remove_javascript_error_handler(self, callback_id: int) -> None: + """Remove a JavaScript error handler by callback ID.""" + self._unsubscribe_log_entry(callback_id)''', + ], + }, + "network": { + # Initialize intercepts tracking list in __init__ + "extra_init_code": ["self.intercepts = []"], + # Request class wraps a beforeRequestSent event params and provides actions + "extra_dataclasses": [ + '''class BytesValue: + """A string or base64-encoded bytes value used in cookie operations. + + This corresponds to network.BytesValue in the WebDriver BiDi specification, + wrapping either a plain string or a base64-encoded binary value. + """ + + TYPE_STRING = "string" + TYPE_BASE64 = "base64" + + def __init__(self, type: str, value: str) -> None: + self.type = type + self.value = value + + def to_bidi_dict(self) -> dict: + return {"type": self.type, "value": self.value}''', + '''class Request: + """Wraps a BiDi network request event params and provides request action methods.""" + + def __init__(self, conn, params): + self._conn = conn + self._params = params if isinstance(params, dict) else {} + req = self._params.get("request", {}) or {} + self.url = req.get("url", "") + self._request_id = req.get("request") + + def continue_request(self, **kwargs): + """Continue the intercepted request.""" + from selenium.webdriver.common.bidi.common import command_builder as _cb + + params = {"request": self._request_id} + params.update(kwargs) + self._conn.execute(_cb("network.continueRequest", params))''', + ], + # Add before_request event (maps to network.beforeRequestSent) + "extra_events": [ + { + "event_key": "before_request", + "bidi_event": "network.beforeRequestSent", + "event_class": "dict", + }, + ], + "extra_methods": [ + ''' def _add_intercept(self, phases=None, url_patterns=None): + """Add a low-level network intercept. + + Args: + phases: list of intercept phases (default: ["beforeRequestSent"]) + url_patterns: optional URL patterns to filter + + Returns: + dict with "intercept" key containing the intercept ID + """ + from selenium.webdriver.common.bidi.common import command_builder as _cb + + if phases is None: + phases = ["beforeRequestSent"] + params = {"phases": phases} + if url_patterns: + params["urlPatterns"] = url_patterns + result = self._conn.execute(_cb("network.addIntercept", params)) + if result: + intercept_id = result.get("intercept") + if intercept_id and intercept_id not in self.intercepts: + self.intercepts.append(intercept_id) + return result''', + ''' def _remove_intercept(self, intercept_id): + """Remove a low-level network intercept.""" + from selenium.webdriver.common.bidi.common import command_builder as _cb + + self._conn.execute(_cb("network.removeIntercept", {"intercept": intercept_id})) + if intercept_id in self.intercepts: + self.intercepts.remove(intercept_id)''', + ''' def add_request_handler(self, event, callback, url_patterns=None): + """Add a handler for network requests at the specified phase. + + Args: + event: Event name, e.g. ``"before_request"``. + callback: Callable receiving a :class:`Request` instance. + url_patterns: optional list of URL pattern dicts to filter. + + Returns: + callback_id int for later removal via remove_request_handler. + """ + phase_map = { + "before_request": "beforeRequestSent", + "before_request_sent": "beforeRequestSent", + "response_started": "responseStarted", + "auth_required": "authRequired", + } + phase = phase_map.get(event, "beforeRequestSent") + self._add_intercept(phases=[phase], url_patterns=url_patterns) + + def _request_callback(params): + raw = ( + params + if isinstance(params, dict) + else (params.__dict__ if hasattr(params, "__dict__") else {}) + ) + request = Request(self._conn, raw) + callback(request) + + return self.add_event_handler(event, _request_callback)''', + ''' def remove_request_handler(self, event, callback_id): + """Remove a network request handler. + + Args: + event: The event name used when adding the handler. + callback_id: The int returned by add_request_handler. + """ + self.remove_event_handler(event, callback_id)''', + ''' def clear_request_handlers(self): + """Clear all request handlers and remove all tracked intercepts.""" + self.clear_event_handlers() + for intercept_id in list(self.intercepts): + self._remove_intercept(intercept_id)''', + ''' def add_auth_handler(self, username, password): + """Add an auth handler that automatically provides credentials. + + Args: + username: The username for basic authentication. + password: The password for basic authentication. + + Returns: + callback_id int for later removal via remove_auth_handler. + """ + from selenium.webdriver.common.bidi.common import command_builder as _cb + + def _auth_callback(params): + raw = ( + params + if isinstance(params, dict) + else (params.__dict__ if hasattr(params, "__dict__") else {}) + ) + request_id = ( + raw.get("request", {}).get("request") + if isinstance(raw, dict) + else None + ) + if request_id: + self._conn.execute( + _cb( + "network.continueWithAuth", + { + "request": request_id, + "action": "provideCredentials", + "credentials": { + "type": "password", + "username": username, + "password": password, + }, + }, + ) + ) + + return self.add_event_handler("auth_required", _auth_callback)''', + ''' def remove_auth_handler(self, callback_id): + """Remove an auth handler by callback ID.""" + self.remove_event_handler("auth_required", callback_id)''', + ], + }, + "storage": { + # Exclude auto-generated dataclasses that need custom to_bidi_dict() + # for JSON-over-WebSocket serialization, or custom constructors. + "exclude_types": [ + "CookieFilter", + "PartialCookie", + "BrowsingContextPartitionDescriptor", + "StorageKeyPartitionDescriptor", + ], + "extra_dataclasses": [ + # Re-export network types used in cookie operations so they can be + # imported from selenium.webdriver.common.bidi.storage alongside + # the storage-specific classes. + '''class BytesValue: + """A string or base64-encoded bytes value used in cookie operations. + + This corresponds to network.BytesValue in the WebDriver BiDi specification, + wrapping either a plain string or a base64-encoded binary value. + """ + + TYPE_STRING = "string" + TYPE_BASE64 = "base64" + + def __init__(self, type: str, value: str) -> None: + self.type = type + self.value = value + + def to_bidi_dict(self) -> dict: + return {"type": self.type, "value": self.value}''', + '''class SameSite: + """SameSite cookie attribute values.""" + + STRICT = "strict" + LAX = "lax" + NONE = "none" + DEFAULT = "default"''', + # Helper: cookie object returned inside a GetCookiesResult response + '''@dataclass +class StorageCookie: + """A cookie object returned by storage.getCookies.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + path: str | None = None + size: Any | None = None + http_only: bool | None = None + secure: bool | None = None + same_site: Any | None = None + expiry: Any | None = None + + @classmethod + def from_bidi_dict(cls, raw: dict) -> "StorageCookie": + """Deserialize a wire-level cookie dict to a StorageCookie.""" + value_raw = raw.get("value") + if isinstance(value_raw, dict): + value = BytesValue(value_raw.get("type"), value_raw.get("value")) + else: + value = value_raw + return cls( + name=raw.get("name"), + value=value, + domain=raw.get("domain"), + path=raw.get("path"), + size=raw.get("size"), + http_only=raw.get("httpOnly"), + secure=raw.get("secure"), + same_site=raw.get("sameSite"), + expiry=raw.get("expiry"), + )''', + # Custom CookieFilter with camelCase serialization + '''@dataclass +class CookieFilter: + """CookieFilter.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + path: str | None = None + size: Any | None = None + http_only: bool | None = None + secure: bool | None = None + same_site: Any | None = None + expiry: Any | None = None + + def to_bidi_dict(self) -> dict: + """Serialize to the BiDi wire-protocol dict.""" + result: dict = {} + if self.name is not None: + result["name"] = self.name + if self.value is not None: + result["value"] = self.value.to_bidi_dict() if hasattr(self.value, "to_bidi_dict") else self.value + if self.domain is not None: + result["domain"] = self.domain + if self.path is not None: + result["path"] = self.path + if self.size is not None: + result["size"] = self.size + if self.http_only is not None: + result["httpOnly"] = self.http_only + if self.secure is not None: + result["secure"] = self.secure + if self.same_site is not None: + result["sameSite"] = self.same_site + if self.expiry is not None: + result["expiry"] = self.expiry + return result''', + # Custom PartialCookie with camelCase serialization + '''@dataclass +class PartialCookie: + """PartialCookie.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + path: str | None = None + http_only: bool | None = None + secure: bool | None = None + same_site: Any | None = None + expiry: Any | None = None + + def to_bidi_dict(self) -> dict: + """Serialize to the BiDi wire-protocol dict.""" + result: dict = {} + if self.name is not None: + result["name"] = self.name + if self.value is not None: + result["value"] = self.value.to_bidi_dict() if hasattr(self.value, "to_bidi_dict") else self.value + if self.domain is not None: + result["domain"] = self.domain + if self.path is not None: + result["path"] = self.path + if self.http_only is not None: + result["httpOnly"] = self.http_only + if self.secure is not None: + result["secure"] = self.secure + if self.same_site is not None: + result["sameSite"] = self.same_site + if self.expiry is not None: + result["expiry"] = self.expiry + return result''', + # BrowsingContextPartitionDescriptor: first positional arg is *context* + # (the auto-generated dataclass had `type` first, breaking positional + # usage like BrowsingContextPartitionDescriptor(driver.current_window_handle)) + '''class BrowsingContextPartitionDescriptor: + """BrowsingContextPartitionDescriptor. + + The first positional argument is *context* (a browsing-context ID / window + handle), mirroring how the class is used throughout the test suite: + ``BrowsingContextPartitionDescriptor(driver.current_window_handle)``. + """ + + def __init__(self, context: Any = None, type: str = "context") -> None: + self.context = context + self.type = type + + def to_bidi_dict(self) -> dict: + return {"type": "context", "context": self.context}''', + # StorageKeyPartitionDescriptor with camelCase serialization + '''@dataclass +class StorageKeyPartitionDescriptor: + """StorageKeyPartitionDescriptor.""" + + type: Any | None = "storageKey" + user_context: str | None = None + source_origin: str | None = None + + def to_bidi_dict(self) -> dict: + """Serialize to the BiDi wire-protocol dict.""" + result: dict = {"type": "storageKey"} + if self.user_context is not None: + result["userContext"] = self.user_context + if self.source_origin is not None: + result["sourceOrigin"] = self.source_origin + return result''', + ], + # Override the generated Storage class methods (Python's last-definition- + # wins semantics means these extra_methods shadow the generated ones). + "extra_methods": [ + ''' def get_cookies(self, filter=None, partition=None): + """Execute storage.getCookies and return a GetCookiesResult.""" + if filter and hasattr(filter, "to_bidi_dict"): + filter = filter.to_bidi_dict() + if partition and hasattr(partition, "to_bidi_dict"): + partition = partition.to_bidi_dict() + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.getCookies", params) + result = self._conn.execute(cmd) + if result and "cookies" in result: + cookies = [ + StorageCookie.from_bidi_dict(c) + for c in result.get("cookies", []) + if isinstance(c, dict) + ] + pk_raw = result.get("partitionKey") + pk = ( + PartitionKey( + user_context=pk_raw.get("userContext"), + source_origin=pk_raw.get("sourceOrigin"), + ) + if isinstance(pk_raw, dict) + else None + ) + return GetCookiesResult(cookies=cookies, partition_key=pk) + return GetCookiesResult(cookies=[], partition_key=None)''', + ''' def set_cookie(self, cookie=None, partition=None): + """Execute storage.setCookie.""" + if cookie and hasattr(cookie, "to_bidi_dict"): + cookie = cookie.to_bidi_dict() + if partition and hasattr(partition, "to_bidi_dict"): + partition = partition.to_bidi_dict() + params = { + "cookie": cookie, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.setCookie", params) + result = self._conn.execute(cmd) + return result''', + ''' def delete_cookies(self, filter=None, partition=None): + """Execute storage.deleteCookies.""" + if filter and hasattr(filter, "to_bidi_dict"): + filter = filter.to_bidi_dict() + if partition and hasattr(partition, "to_bidi_dict"): + partition = partition.to_bidi_dict() + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.deleteCookies", params) + result = self._conn.execute(cmd) + return result''', + ], + }, + "session": { + # Override UserPromptHandler to add to_bidi_dict() for JSON serialization + "exclude_types": ["UserPromptHandler"], + "extra_dataclasses": [ + '''@dataclass +class UserPromptHandler: + """UserPromptHandler.""" + + alert: Any | None = None + before_unload: Any | None = None + confirm: Any | None = None + default: Any | None = None + file: Any | None = None + prompt: Any | None = None + + def to_bidi_dict(self) -> dict: + """Convert to BiDi protocol dict with camelCase keys.""" + result = {} + if self.alert is not None: + result["alert"] = self.alert + if self.before_unload is not None: + result["beforeUnload"] = self.before_unload + if self.confirm is not None: + result["confirm"] = self.confirm + if self.default is not None: + result["default"] = self.default + if self.file is not None: + result["file"] = self.file + if self.prompt is not None: + result["prompt"] = self.prompt + return result''', + ], + }, + "webExtension": { + # Suppress the raw generated stubs; hand-written versions follow below + "exclude_methods": ["install", "uninstall"], + "extra_methods": [ + ''' def install(self, path: str | None = None, archive_path: str | None = None, base64_value: str | None = None): + """Install a web extension. + + Exactly one of the three keyword arguments must be provided. + + Args: + path: Directory path to an unpacked extension (also accepted for + signed ``.xpi`` / ``.crx`` archive files on Firefox). + archive_path: File-system path to a packed extension archive. + base64_value: Base64-encoded extension archive string. + + Returns: + The raw result dict from the BiDi ``webExtension.install`` command + (contains at least an ``"extension"`` key with the extension ID). + + Raises: + ValueError: If more than one, or none, of the arguments is provided. + """ + provided = [k for k, v in {"path": path, "archive_path": archive_path, "base64_value": base64_value}.items() if v is not None] + if len(provided) != 1: + raise ValueError( + f"Exactly one of path, archive_path, or base64_value must be provided; got: {provided}" + ) + if path is not None: + extension_data = {"type": "path", "path": path} + elif archive_path is not None: + extension_data = {"type": "archivePath", "path": archive_path} + else: + extension_data = {"type": "base64", "value": base64_value} + params = {"extensionData": extension_data} + cmd = command_builder("webExtension.install", params) + return self._conn.execute(cmd)''', + ''' def uninstall(self, extension: Any | None = None): + """Uninstall a web extension. + + Args: + extension: Either the extension ID string returned by ``install``, + or the full result dict returned by ``install`` (the + ``"extension"`` value is extracted automatically). + """ + if isinstance(extension, dict): + extension = extension.get("extension") + params = {"extension": extension} + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("webExtension.uninstall", params) + return self._conn.execute(cmd)''', + ], + }, + "input": { + # FileDialogInfo needs from_json for event deserialization + "exclude_types": ["FileDialogInfo", "PointerMoveAction", "PointerDownAction"], + "extra_dataclasses": [ + '''@dataclass +class FileDialogInfo: + """FileDialogInfo - parameters for the input.fileDialogOpened event.""" + + context: Any | None = None + element: Any | None = None + multiple: bool | None = None + + @classmethod + def from_json(cls, params: dict) -> "FileDialogInfo": + """Deserialize event params into FileDialogInfo.""" + return cls( + context=params.get("context"), + element=params.get("element"), + multiple=params.get("multiple"), + )''', + '''@dataclass +class PointerMoveAction: + """PointerMoveAction.""" + + type: str = field(default="pointerMove", init=False) + x: Any | None = None + y: Any | None = None + duration: Any | None = None + origin: Any | None = None + properties: Any | None = None''', + '''@dataclass +class PointerDownAction: + """PointerDownAction.""" + + type: str = field(default="pointerDown", init=False) + button: Any | None = None + properties: Any | None = None''', + ], + "extra_methods": [ + ''' def add_file_dialog_handler(self, callback) -> int: + """Subscribe to the input.fileDialogOpened event. + + Args: + callback: Callable invoked with a FileDialogInfo when a file dialog opens. + + Returns: + A handler ID that can be passed to remove_file_dialog_handler. + """ + return self._event_manager.add_event_handler("file_dialog_opened", callback) + + def remove_file_dialog_handler(self, handler_id: int) -> None: + """Unsubscribe a previously registered file dialog event handler. + + Args: + handler_id: The handler ID returned by add_file_dialog_handler. + """ + return self._event_manager.remove_event_handler("file_dialog_opened", handler_id)''', + ], + }, +} + + +# ============================================================================ +# Pre-processing Functions +# ============================================================================ + + +def check_serialize_method(obj: Any) -> Any: + """Check if object has to_bidi_dict() method and use it for serialization.""" + if obj and hasattr(obj, "to_bidi_dict"): + return obj.to_bidi_dict() + return obj + + +# ============================================================================ +# Validation Functions +# ============================================================================ + + +def validate_download_behavior( + allowed: bool | None, + destination_folder: str | None, + user_contexts: Any | None = None, +) -> None: + """Validate download behavior parameters. + + Args: + allowed: Whether downloads are allowed + destination_folder: Destination folder for downloads + user_contexts: Optional list of user contexts (ignored for validation) + + Raises: + ValueError: If parameters are invalid + """ + if allowed is True and not destination_folder: + raise ValueError("destination_folder is required when allowed=True") + if allowed is False and destination_folder: + raise ValueError("destination_folder should not be provided when allowed=False") + + +# ============================================================================ +# Transformation Functions +# ============================================================================ + + +def transform_download_params( + allowed: bool | None, + destination_folder: str | None, +) -> dict[str, Any]: + """Transform download parameters into download_behavior object. + + Args: + allowed: Whether downloads are allowed + destination_folder: Destination folder for downloads + + Returns: + Dictionary representing the download_behavior object, or None if allowed is None + """ + if allowed is True: + return { + "type": "allowed", + # Convert pathlib.Path (or any path-like) to str so the BiDi + # protocol always receives a plain JSON string. + "destinationFolder": ( + str(destination_folder) if destination_folder is not None else None + ), + } + elif allowed is False: + return {"type": "denied"} + else: # None — reset to browser default (sent as JSON null) + return None + + +# ============================================================================ +# Dataclass Method Templates +# ============================================================================ + +DATACLASS_METHOD_TEMPLATES: dict[str, dict[str, str]] = { + "ClientWindowInfo": { + "get_client_window": "return self.client_window", + "get_state": "return self.state", + "get_width": "return self.width", + "get_height": "return self.height", + "is_active": "return self.active", + "get_x": "return self.x", + "get_y": "return self.y", + }, + "BrowsingContext": { + "add_event_handler": "_add_event_handler_impl", + "remove_event_handler": "_remove_event_handler_impl", + }, +} + +DATACLASS_METHOD_DOCSTRINGS: dict[str, dict[str, str]] = { + "ClientWindowInfo": { + "get_client_window": "Get the client window ID.", + "get_state": "Get the client window state.", + "get_width": "Get the client window width.", + "get_height": "Get the client window height.", + "is_active": "Check if the client window is active.", + "get_x": "Get the client window X position.", + "get_y": "Get the client window Y position.", + }, + "BrowsingContext": { + "add_event_handler": "Add an event handler for browsing context events.", + "remove_event_handler": "Remove an event handler by callback ID.", + }, +} + +# ============================================================================ +# Event Handler Support for BrowsingContext +# ============================================================================ + + +def _add_event_handler( + self, + event_name: str, + callback: callable, + contexts: list[str] | None = None, +) -> str: + """Add an event handler for a browsing context event. + + Supported events: + - 'context_created' + - 'context_destroyed' + - 'navigation_started' + - 'navigation_committed' + - 'navigation_failed' + - 'dom_content_loaded' + - 'load' + - 'fragment_navigated' + - 'user_prompt_opened' + - 'user_prompt_closed' + - 'download_will_begin' + - 'download_end' + - 'history_updated' + + Args: + event_name: The name of the event to subscribe to + callback: Callback function to invoke when event occurs + contexts: Optional list of context IDs to limit event subscription + + Returns: + A callback ID that can be used to unsubscribe the handler + """ + if not hasattr(self, "_event_handlers"): + self._event_handlers = {} + self._event_callback_id_counter = 0 + + # Generate unique callback ID + self._event_callback_id_counter += 1 + callback_id = f"callback_{self._event_callback_id_counter}" + + # Store the handler + self._event_handlers[callback_id] = { + "event": event_name, + "callback": callback, + "contexts": contexts, + } + + # Subscribe via the driver's event listening mechanism + if hasattr(self._driver, "_subscribe_event"): + self._driver._subscribe_event(event_name, callback, contexts) + + return callback_id + + +def _remove_event_handler( + self, + callback_id: str, +) -> None: + """Remove an event handler by its callback ID. + + Args: + callback_id: The callback ID returned from add_event_handler + """ + if not hasattr(self, "_event_handlers"): + return + + if callback_id in self._event_handlers: + handler_info = self._event_handlers[callback_id] + + # Unsubscribe from the driver + if hasattr(self._driver, "_unsubscribe_event"): + self._driver._unsubscribe_event( + handler_info["event"], + handler_info["callback"], + handler_info["contexts"], + ) + + del self._event_handlers[callback_id] diff --git a/py/private/cdp.py b/py/private/cdp.py new file mode 100644 index 0000000000000..b097762fe50cd --- /dev/null +++ b/py/private/cdp.py @@ -0,0 +1,515 @@ +# The MIT License(MIT) +# +# Copyright(c) 2018 Hyperion Gray +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files(the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# This code comes from https://github.com/HyperionGray/trio-chrome-devtools-protocol/tree/master/trio_cdp + +import contextvars +import importlib +import itertools +import json +import logging +import pathlib +from collections import defaultdict +from collections.abc import AsyncGenerator, AsyncIterator, Generator +from contextlib import asynccontextmanager, contextmanager +from dataclasses import dataclass +from typing import Any, TypeVar + +import trio +from trio_websocket import ConnectionClosed as WsConnectionClosed +from trio_websocket import connect_websocket_url + +logger = logging.getLogger("trio_cdp") +T = TypeVar("T") +MAX_WS_MESSAGE_SIZE = 2**24 + +devtools = None +version = None + + +def import_devtools(ver): + """Attempt to load the current latest available devtools into the module cache for use later.""" + global devtools + global version + version = ver + base = "selenium.webdriver.common.devtools.v" + try: + devtools = importlib.import_module(f"{base}{ver}") + return devtools + except ModuleNotFoundError: + # Attempt to parse and load the 'most recent' devtools module. This is likely + # because cdp has been updated but selenium python has not been released yet. + devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools") + versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir()) + latest = max(int(x[1:]) for x in versions) + selenium_logger = logging.getLogger(__name__) + selenium_logger.debug("Falling back to loading `devtools`: v%s", latest) + devtools = importlib.import_module(f"{base}{latest}") + return devtools + + +_connection_context: contextvars.ContextVar = contextvars.ContextVar("connection_context") +_session_context: contextvars.ContextVar = contextvars.ContextVar("session_context") + + +def get_connection_context(fn_name): + """Look up the current connection. + + If there is no current connection, raise a ``RuntimeError`` with a + helpful message. + """ + try: + return _connection_context.get() + except LookupError: + raise RuntimeError(f"{fn_name}() must be called in a connection context.") + + +def get_session_context(fn_name): + """Look up the current session. + + If there is no current session, raise a ``RuntimeError`` with a + helpful message. + """ + try: + return _session_context.get() + except LookupError: + raise RuntimeError(f"{fn_name}() must be called in a session context.") + + +@contextmanager +def connection_context(connection): + """Context manager installs ``connection`` as the session context for the current Trio task.""" + token = _connection_context.set(connection) + try: + yield + finally: + _connection_context.reset(token) + + +@contextmanager +def session_context(session): + """Context manager installs ``session`` as the session context for the current Trio task.""" + token = _session_context.set(session) + try: + yield + finally: + _session_context.reset(token) + + +def set_global_connection(connection): + """Install ``connection`` in the root context so that it will become the default connection for all tasks. + + This is generally not recommended, except it may be necessary in + certain use cases such as running inside Jupyter notebook. + """ + global _connection_context + _connection_context = contextvars.ContextVar("_connection_context", default=connection) + + +def set_global_session(session): + """Install ``session`` in the root context so that it will become the default session for all tasks. + + This is generally not recommended, except it may be necessary in + certain use cases such as running inside Jupyter notebook. + """ + global _session_context + _session_context = contextvars.ContextVar("_session_context", default=session) + + +class BrowserError(Exception): + """This exception is raised when the browser's response to a command indicates that an error occurred.""" + + def __init__(self, obj): + self.code = obj.get("code") + self.message = obj.get("message") + self.detail = obj.get("data") + + def __str__(self): + return f"BrowserError {self.detail}" + + +class CdpConnectionClosed(WsConnectionClosed): + """Raised when a public method is called on a closed CDP connection.""" + + def __init__(self, reason): + """Constructor. + + Args: + reason: wsproto.frame_protocol.CloseReason + """ + self.reason = reason + + def __repr__(self): + """Return representation.""" + return f"{self.__class__.__name__}<{self.reason}>" + + +class InternalError(Exception): + """This exception is only raised when there is faulty logic in TrioCDP or the integration with PyCDP.""" + + pass + + +@dataclass +class CmEventProxy: + """A proxy object returned by :meth:`CdpBase.wait_for()``. + + After the context manager executes, this proxy object will have a + value set that contains the returned event. + """ + + value: Any = None + + +class CdpBase: + def __init__(self, ws, session_id, target_id): + self.ws = ws + self.session_id = session_id + self.target_id = target_id + self.channels = defaultdict(set) + self.id_iter = itertools.count() + self.inflight_cmd = {} + self.inflight_result = {} + + async def execute(self, cmd: Generator[dict, T, Any]) -> T: + """Execute a command on the server and wait for the result. + + Args: + cmd: any CDP command + + Returns: + a CDP result + """ + cmd_id = next(self.id_iter) + cmd_event = trio.Event() + self.inflight_cmd[cmd_id] = cmd, cmd_event + request = next(cmd) + request["id"] = cmd_id + if self.session_id: + request["sessionId"] = self.session_id + request_str = json.dumps(request) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"Sending CDP message: {cmd_id} {cmd_event}: {request_str}") + try: + await self.ws.send_message(request_str) + except WsConnectionClosed as wcc: + raise CdpConnectionClosed(wcc.reason) from None + await cmd_event.wait() + response = self.inflight_result.pop(cmd_id) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"Received CDP message: {response}") + if isinstance(response, Exception): + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"Exception raised by {cmd_event} message: {type(response).__name__}") + raise response + return response + + def listen(self, *event_types, buffer_size=10): + """Listen for events. + + Returns: + An async iterator that iterates over events matching the indicated types. + """ + sender, receiver = trio.open_memory_channel(buffer_size) + for event_type in event_types: + self.channels[event_type].add(sender) + return receiver + + @asynccontextmanager + async def wait_for(self, event_type: type[T], buffer_size=10) -> AsyncGenerator[CmEventProxy, None]: + """Wait for an event of the given type and return it. + + This is an async context manager, so you should open it inside + an async with block. The block will not exit until the indicated + event is received. + """ + sender: trio.MemorySendChannel + receiver: trio.MemoryReceiveChannel + sender, receiver = trio.open_memory_channel(buffer_size) + self.channels[event_type].add(sender) + proxy = CmEventProxy() + yield proxy + async with receiver: + event = await receiver.receive() + proxy.value = event + + def _handle_data(self, data): + """Handle incoming WebSocket data. + + Args: + data: a JSON dictionary + """ + if "id" in data: + self._handle_cmd_response(data) + else: + self._handle_event(data) + + def _handle_cmd_response(self, data: dict): + """Handle a response to a command. + + This will set an event flag that will return control to the + task that called the command. + + Args: + data: response as a JSON dictionary + """ + cmd_id = data["id"] + try: + cmd, event = self.inflight_cmd.pop(cmd_id) + except KeyError: + logger.warning("Got a message with a command ID that does not exist: %s", data) + return + if "error" in data: + # If the server reported an error, convert it to an exception and do + # not process the response any further. + self.inflight_result[cmd_id] = BrowserError(data["error"]) + else: + # Otherwise, continue the generator to parse the JSON result + # into a CDP object. + try: + _ = cmd.send(data["result"]) + raise InternalError("The command's generator function did not exit when expected!") + except StopIteration as exit: + return_ = exit.value + self.inflight_result[cmd_id] = return_ + event.set() + + def _handle_event(self, data: dict): + """Handle an event. + + Args: + data: event as a JSON dictionary + """ + global devtools + if devtools is None: + raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") + event = devtools.util.parse_json_event(data) + logger.debug("Received event: %s", event) + to_remove = set() + for sender in self.channels[type(event)]: + try: + sender.send_nowait(event) + except trio.WouldBlock: + logger.error('Unable to send event "%r" due to full channel %s', event, sender) + except trio.BrokenResourceError: + to_remove.add(sender) + if to_remove: + self.channels[type(event)] -= to_remove + + +class CdpSession(CdpBase): + """Contains the state for a CDP session. + + Generally you should not instantiate this object yourself; you should call + :meth:`CdpConnection.open_session`. + """ + + def __init__(self, ws, session_id, target_id): + """Constructor. + + Args: + ws: trio_websocket.WebSocketConnection + session_id: devtools.target.SessionID + target_id: devtools.target.TargetID + """ + super().__init__(ws, session_id, target_id) + + self._dom_enable_count = 0 + self._dom_enable_lock = trio.Lock() + self._page_enable_count = 0 + self._page_enable_lock = trio.Lock() + + @asynccontextmanager + async def dom_enable(self): + """Context manager that executes ``dom.enable()`` when it enters and then calls ``dom.disable()``. + + This keeps track of concurrent callers and only disables DOM + events when all callers have exited. + """ + global devtools + async with self._dom_enable_lock: + self._dom_enable_count += 1 + if self._dom_enable_count == 1: + await self.execute(devtools.dom.enable()) + + yield + + async with self._dom_enable_lock: + self._dom_enable_count -= 1 + if self._dom_enable_count == 0: + await self.execute(devtools.dom.disable()) + + @asynccontextmanager + async def page_enable(self): + """Context manager executes ``page.enable()`` when it enters and then calls ``page.disable()`` when it exits. + + This keeps track of concurrent callers and only disables page + events when all callers have exited. + """ + global devtools + async with self._page_enable_lock: + self._page_enable_count += 1 + if self._page_enable_count == 1: + await self.execute(devtools.page.enable()) + + yield + + async with self._page_enable_lock: + self._page_enable_count -= 1 + if self._page_enable_count == 0: + await self.execute(devtools.page.disable()) + + +class CdpConnection(CdpBase, trio.abc.AsyncResource): + """Contains the connection state for a Chrome DevTools Protocol server. + + CDP can multiplex multiple "sessions" over a single connection. This + class corresponds to the "root" session, i.e. the implicitly created + session that has no session ID. This class is responsible for + reading incoming WebSocket messages and forwarding them to the + corresponding session, as well as handling messages targeted at the + root session itself. You should generally call the + :func:`open_cdp()` instead of instantiating this class directly. + """ + + def __init__(self, ws): + """Constructor. + + Args: + ws: trio_websocket.WebSocketConnection + """ + super().__init__(ws, session_id=None, target_id=None) + self.sessions = {} + + async def aclose(self): + """Close the underlying WebSocket connection. + + This will cause the reader task to gracefully exit when it tries + to read the next message from the WebSocket. All of the public + APIs (``execute()``, ``listen()``, etc.) will raise + ``CdpConnectionClosed`` after the CDP connection is closed. It + is safe to call this multiple times. + """ + await self.ws.aclose() + + @asynccontextmanager + async def open_session(self, target_id) -> AsyncIterator[CdpSession]: + """Context manager opens a session and enables the "simple" style of calling CDP APIs. + + For example, inside a session context, you can call ``await + dom.get_document()`` and it will execute on the current session + automatically. + """ + session = await self.connect_session(target_id) + with session_context(session): + yield session + + async def connect_session(self, target_id) -> "CdpSession": + """Returns a new :class:`CdpSession` connected to the specified target.""" + global devtools + if devtools is None: + raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") + session_id = await self.execute(devtools.target.attach_to_target(target_id, True)) + session = CdpSession(self.ws, session_id, target_id) + self.sessions[session_id] = session + return session + + async def _reader_task(self): + """Runs in the background and handles incoming messages. + + Dispatches responses to commands and events to listeners. + """ + global devtools + if devtools is None: + raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") + while True: + try: + message = await self.ws.get_message() + except WsConnectionClosed: + # If the WebSocket is closed, we don't want to throw an + # exception from the reader task. Instead we will throw + # exceptions from the public API methods, and we can quietly + # exit the reader task here. + break + try: + data = json.loads(message) + except json.JSONDecodeError: + raise BrowserError({"code": -32700, "message": "Client received invalid JSON", "data": message}) + logger.debug("Received message %r", data) + if "sessionId" in data: + session_id = devtools.target.SessionID(data["sessionId"]) + try: + session = self.sessions[session_id] + except KeyError: + raise BrowserError( + { + "code": -32700, + "message": "Browser sent a message for an invalid session", + "data": f"{session_id!r}", + } + ) + session._handle_data(data) + else: + self._handle_data(data) + + for _, session in self.sessions.items(): + for _, senders in session.channels.items(): + for sender in senders: + sender.close() + + +@asynccontextmanager +async def open_cdp(url) -> AsyncIterator[CdpConnection]: + """Async context manager opens a connection to the browser then closes the connection when the block exits. + + The context manager also sets the connection as the default + connection for the current task, so that commands like ``await + target.get_targets()`` will run on this connection automatically. If + you want to use multiple connections concurrently, it is recommended + to open each on in a separate task. + """ + async with trio.open_nursery() as nursery: + conn = await connect_cdp(nursery, url) + try: + with connection_context(conn): + yield conn + finally: + await conn.aclose() + + +async def connect_cdp(nursery, url) -> CdpConnection: + """Connect to the browser specified by ``url`` and spawn a background task in the specified nursery. + + The ``open_cdp()`` context manager is preferred in most situations. + You should only use this function if you need to specify a custom + nursery. This connection is not automatically closed! You can either + use the connection object as a context manager (``async with + conn:``) or else call ``await conn.aclose()`` on it when you are + done with it. If ``set_context`` is True, then the returned + connection will be installed as the default connection for the + current task. This argument is for unusual use cases, such as + running inside of a notebook. + """ + ws = await connect_websocket_url(nursery, url, max_message_size=MAX_WS_MESSAGE_SIZE) + cdp_conn = CdpConnection(ws) + nursery.start_soon(cdp_conn._reader_task) + return cdp_conn diff --git a/py/private/generate_bidi.bzl b/py/private/generate_bidi.bzl new file mode 100644 index 0000000000000..c11b6efe4735f --- /dev/null +++ b/py/private/generate_bidi.bzl @@ -0,0 +1,112 @@ +"""Bazel rule for generating WebDriver BiDi Python modules from CDDL specification.""" + +def _generate_bidi_impl(ctx): + """Implementation of the generate_bidi rule.""" + + cddl_file = ctx.file.cddl_file + manifest_file = ctx.file.enhancements_manifest + generator = ctx.executable.generator + output_dir = ctx.attr.module_name + spec_version = ctx.attr.spec_version + + # The generator creates BiDi modules from the CDDL spec + # Using snake_case naming convention for Python files + module_names = [ + "browser", + "browsing_context", + "common", + "console", + "emulation", + "input", + "log", + "network", + "permissions", + "script", + "session", + "storage", + "webextension", + ] + + # Declare all output files + module_files = [ + ctx.actions.declare_file(output_dir + "/" + name + ".py") + for name in module_names + ] + init_file = ctx.actions.declare_file(output_dir + "/__init__.py") + py_typed = ctx.actions.declare_file(output_dir + "/py.typed") + + gen_outputs = module_files + [init_file, py_typed] + + # Copy static extra_srcs into the output directory + extra_outputs = [] + for src in ctx.files.extra_srcs: + out = ctx.actions.declare_file(output_dir + "/" + src.basename) + ctx.actions.symlink(output = out, target_file = src) + extra_outputs.append(out) + + outputs = gen_outputs + extra_outputs + + # Output directory for the generator + output_base = init_file.dirname + + # Build the command to run the generator + args = [ + cddl_file.path, + output_base, + "--version", + spec_version, + ] + + # Add enhancement manifest if provided + inputs = [cddl_file] + if manifest_file: + args.extend(["--enhancements-manifest", manifest_file.path]) + inputs.append(manifest_file) + + ctx.actions.run( + inputs = inputs, + outputs = gen_outputs, + executable = generator, + arguments = args, + use_default_shell_env = True, + ) + + return [DefaultInfo(files = depset(outputs))] + + +generate_bidi = rule( + implementation = _generate_bidi_impl, + attrs = { + "cddl_file": attr.label( + allow_single_file = [".cddl"], + mandatory = True, + doc = "CDDL specification file", + ), + "enhancements_manifest": attr.label( + allow_single_file = [".py"], + mandatory = False, + doc = "Enhancement manifest Python file (optional)", + ), + "extra_srcs": attr.label_list( + allow_files = [".py"], + mandatory = False, + default = [], + doc = "Additional static Python files to copy verbatim into the output directory", + ), + "generator": attr.label( + executable = True, + cfg = "exec", + mandatory = True, + doc = "Generator script (e.g., generate_bidi.py)", + ), + "module_name": attr.string( + mandatory = True, + doc = "Name of the module being generated (e.g., 'selenium/webdriver/common/bidi')", + ), + "spec_version": attr.string( + default = "1.0", + doc = "WebDriver BiDi specification version", + ), + }, + doc = "Generates Python WebDriver BiDi modules from CDDL specification", +) diff --git a/py/requirements.txt b/py/requirements.txt index fe7abe214f2e5..5f943fdd24f91 100644 --- a/py/requirements.txt +++ b/py/requirements.txt @@ -2,6 +2,7 @@ async-generator==1.10 attrs==25.4.0 backports.tarfile==1.2.0 cachetools==7.0.1 + certifi==2026.1.4 cffi==2.0.0 chardet==5.2.0 diff --git a/py/requirements_lock.txt b/py/requirements_lock.txt index 222690e162f3c..66fa883ec4a53 100644 --- a/py/requirements_lock.txt +++ b/py/requirements_lock.txt @@ -416,7 +416,6 @@ jeepney==0.9.0 \ --hash=sha256:cf0e9e845622b81e4a28df94c40345400256ec608d0e55bb8a3feaa9163f5732 # via # -r py/requirements.txt - # keyring # secretstorage jinja2==3.1.6 \ --hash=sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d \ @@ -965,9 +964,7 @@ rich==14.3.2 \ secretstorage==3.5.0 \ --hash=sha256:0ce65888c0725fcb2c5bc0fdb8e5438eece02c523557ea40ce0703c266248137 \ --hash=sha256:f04b8e4689cbce351744d5537bf6b1329c6fc68f91fa666f60a380edddcd11be - # via - # -r py/requirements.txt - # keyring + # via -r py/requirements.txt sniffio==1.3.1 \ --hash=sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2 \ --hash=sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc diff --git a/py/selenium/common/exceptions.py b/py/selenium/common/exceptions.py index c45f530002a8f..7ec809eb20b18 100644 --- a/py/selenium/common/exceptions.py +++ b/py/selenium/common/exceptions.py @@ -27,7 +27,10 @@ class WebDriverException(Exception): """Base webdriver exception.""" def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: super().__init__() self.msg = msg @@ -73,7 +76,10 @@ class NoSuchElementException(WebDriverException): """ def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#nosuchelementexception" @@ -111,9 +117,14 @@ class StaleElementReferenceException(WebDriverException): """ def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: - with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#staleelementreferenceexception" + with_support = ( + f"{msg}; {SUPPORT_MSG} {ERROR_URL}#staleelementreferenceexception" + ) super().__init__(with_support, screen, stacktrace) @@ -161,7 +172,10 @@ class ElementNotVisibleException(InvalidElementStateException): """ def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementnotvisibleexception" @@ -172,9 +186,14 @@ class ElementNotInteractableException(InvalidElementStateException): """Thrown when element interactions will hit another element due to paint order.""" def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: - with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementnotinteractableexception" + with_support = ( + f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementnotinteractableexception" + ) super().__init__(with_support, screen, stacktrace) @@ -213,7 +232,10 @@ class InvalidSelectorException(WebDriverException): """ def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#invalidselectorexception" @@ -252,9 +274,14 @@ class ElementClickInterceptedException(WebDriverException): """Thrown when element click fails because another element obscures it.""" def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: - with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementclickinterceptedexception" + with_support = ( + f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementclickinterceptedexception" + ) super().__init__(with_support, screen, stacktrace) @@ -271,7 +298,10 @@ class InvalidSessionIdException(WebDriverException): """Thrown when the given session id is not in the list of active sessions.""" def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#invalidsessionidexception" @@ -282,7 +312,10 @@ class SessionNotCreatedException(WebDriverException): """A new session could not be created.""" def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#sessionnotcreatedexception" @@ -297,7 +330,10 @@ class NoSuchDriverException(WebDriverException): """Raised when driver is not specified and cannot be located.""" def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}/driver_location" diff --git a/py/selenium/webdriver/common/bidi/__init__.py b/py/selenium/webdriver/common/bidi/__init__.py index a5b1e6f85a09e..ab96f2d81e292 100644 --- a/py/selenium/webdriver/common/bidi/__init__.py +++ b/py/selenium/webdriver/common/bidi/__init__.py @@ -1,16 +1,7 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. + +from __future__ import annotations + diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 5b449ae69276a..ed6a4d8f33bc5 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -1,280 +1,330 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import os -from typing import Any - -from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi.session import UserPromptHandler -from selenium.webdriver.common.proxy import Proxy - - -class ClientWindowState: - """Represents a window state.""" +# WebDriver BiDi module: browser +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass + + +def transform_download_params( + allowed: bool | None, + destination_folder: str | None, +) -> dict[str, Any] | None: + """Transform download parameters into download_behavior object. + + Args: + allowed: Whether downloads are allowed + destination_folder: Destination folder for downloads (accepts str or + pathlib.Path; will be coerced to str) + + Returns: + Dictionary representing the download_behavior object, or None if allowed is None + """ + if allowed is True: + return { + "type": "allowed", + # Coerce pathlib.Path (or any path-like) to str so the BiDi + # protocol always receives a plain JSON string. + "destinationFolder": str(destination_folder) if destination_folder is not None else None, + } + elif allowed is False: + return {"type": "denied"} + else: # None — reset to browser default (sent as JSON null) + return None + + +def validate_download_behavior( + allowed: bool | None, + destination_folder: str | None, + user_contexts: Any | None = None, +) -> None: + """Validate download behavior parameters. + + Args: + allowed: Whether downloads are allowed + destination_folder: Destination folder for downloads + user_contexts: Optional list of user contexts + + Raises: + ValueError: If parameters are invalid + """ + if allowed is True and not destination_folder: + raise ValueError("destination_folder is required when allowed=True") + if allowed is False and destination_folder: + raise ValueError("destination_folder should not be provided when allowed=False") + + +class ClientWindowNamedState: + """ClientWindowNamedState.""" FULLSCREEN = "fullscreen" MAXIMIZED = "maximized" MINIMIZED = "minimized" - NORMAL = "normal" - - VALID_STATES = {FULLSCREEN, MAXIMIZED, MINIMIZED, NORMAL} +@dataclass class ClientWindowInfo: - """Represents a client window information.""" - - def __init__( - self, - client_window: str, - state: str, - width: int, - height: int, - x: int, - y: int, - active: bool, - ): - self.client_window = client_window - self.state = state - self.width = width - self.height = height - self.x = x - self.y = y - self.active = active - - def get_state(self) -> str: - """Gets the state of the client window. - - Returns: - str: The state of the client window (one of the ClientWindowState constants). - """ - return self.state - - def get_client_window(self) -> str: - """Gets the client window identifier. - - Returns: - str: The client window identifier. - """ + """ClientWindowInfo.""" + + active: bool | None = None + client_window: Any | None = None + height: Any | None = None + state: Any | None = None + width: Any | None = None + x: Any | None = None + y: Any | None = None + + def get_client_window(self): + """Get the client window ID.""" return self.client_window - def get_width(self) -> int: - """Gets the width of the client window. + def get_state(self): + """Get the client window state.""" + return self.state - Returns: - int: The width of the client window. - """ + def get_width(self): + """Get the client window width.""" return self.width - def get_height(self) -> int: - """Gets the height of the client window. - - Returns: - int: The height of the client window. - """ + def get_height(self): + """Get the client window height.""" return self.height - def get_x(self) -> int: - """Gets the x coordinate of the client window. + def is_active(self): + """Check if the client window is active.""" + return self.active - Returns: - int: The x coordinate of the client window. - """ + def get_x(self): + """Get the client window X position.""" return self.x - def get_y(self) -> int: - """Gets the y coordinate of the client window. - - Returns: - int: The y coordinate of the client window. - """ + def get_y(self): + """Get the client window Y position.""" return self.y - def is_active(self) -> bool: - """Checks if the client window is active. - Returns: - bool: True if the client window is active, False otherwise. - """ - return self.active - @classmethod - def from_dict(cls, data: dict) -> "ClientWindowInfo": - """Creates a ClientWindowInfo instance from a dictionary. +@dataclass +class UserContextInfo: + """UserContextInfo.""" - Args: - data: A dictionary containing the client window information. + user_context: Any | None = None - Returns: - ClientWindowInfo: A new instance of ClientWindowInfo. - Raises: - ValueError: If required fields are missing or have invalid types. - """ - try: - client_window = data["clientWindow"] - if not isinstance(client_window, str): - raise ValueError("clientWindow must be a string") - - state = data["state"] - if not isinstance(state, str): - raise ValueError("state must be a string") - if state not in ClientWindowState.VALID_STATES: - raise ValueError(f"Invalid state: {state}. Must be one of {ClientWindowState.VALID_STATES}") - - width = data["width"] - if not isinstance(width, int) or width < 0: - raise ValueError(f"width must be a non-negative integer, got {width}") - - height = data["height"] - if not isinstance(height, int) or height < 0: - raise ValueError(f"height must be a non-negative integer, got {height}") - - x = data["x"] - if not isinstance(x, int): - raise ValueError(f"x must be an integer, got {type(x).__name__}") - - y = data["y"] - if not isinstance(y, int): - raise ValueError(f"y must be an integer, got {type(y).__name__}") - - active = data["active"] - if not isinstance(active, bool): - raise ValueError("active must be a boolean") - - return cls( - client_window=client_window, - state=state, - width=width, - height=height, - x=x, - y=y, - active=active, - ) - except (KeyError, TypeError) as e: - raise ValueError(f"Invalid data format for ClientWindowInfo: {e}") from e +@dataclass +class CreateUserContextParameters: + """CreateUserContextParameters.""" + accept_insecure_certs: bool | None = None + proxy: Any | None = None + unhandled_prompt_behavior: Any | None = None -class Browser: - """BiDi implementation of the browser module.""" - def __init__(self, conn): - self.conn = conn +@dataclass +class GetClientWindowsResult: + """GetClientWindowsResult.""" - def create_user_context( - self, - accept_insecure_certs: bool | None = None, - proxy: Proxy | None = None, - unhandled_prompt_behavior: UserPromptHandler | None = None, - ) -> str: - """Creates a new user context. + client_windows: list[Any | None] | None = None - Args: - accept_insecure_certs: Optional flag to accept insecure TLS certificates. - proxy: Optional proxy configuration for the user context. - unhandled_prompt_behavior: Optional configuration for handling user prompts. - Returns: - str: The ID of the created user context. - """ - params: dict[str, Any] = {} +@dataclass +class GetUserContextsResult: + """GetUserContextsResult.""" - if accept_insecure_certs is not None: - params["acceptInsecureCerts"] = accept_insecure_certs + user_contexts: list[Any | None] | None = None - if proxy is not None: - params["proxy"] = proxy.to_bidi_dict() - if unhandled_prompt_behavior is not None: - params["unhandledPromptBehavior"] = unhandled_prompt_behavior.to_dict() +@dataclass +class RemoveUserContextParameters: + """RemoveUserContextParameters.""" - result = self.conn.execute(command_builder("browser.createUserContext", params)) - return result["userContext"] + user_context: Any | None = None - def get_user_contexts(self) -> list[str]: - """Gets all user contexts. - Returns: - List[str]: A list of user context IDs. - """ - result = self.conn.execute(command_builder("browser.getUserContexts", {})) - return [context_info["userContext"] for context_info in result["userContexts"]] +@dataclass +class SetClientWindowStateParameters: + """SetClientWindowStateParameters.""" - def remove_user_context(self, user_context_id: str) -> None: - """Removes a user context. + client_window: Any | None = None - Args: - user_context_id: The ID of the user context to remove. - Raises: - ValueError: If the user context ID is "default" or does not exist. - """ - if user_context_id == "default": - raise ValueError("Cannot remove the default user context") +@dataclass +class ClientWindowRectState: + """ClientWindowRectState.""" - params = {"userContext": user_context_id} - self.conn.execute(command_builder("browser.removeUserContext", params)) + state: str = field(default="normal", init=False) + width: Any | None = None + height: Any | None = None + x: Any | None = None + y: Any | None = None - def get_client_windows(self) -> list[ClientWindowInfo]: - """Gets all client windows. - Returns: - List[ClientWindowInfo]: A list of client window information. - """ - result = self.conn.execute(command_builder("browser.getClientWindows", {})) - return [ClientWindowInfo.from_dict(window) for window in result["clientWindows"]] - - def set_download_behavior( - self, - *, - allowed: bool | None = None, - destination_folder: str | os.PathLike | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set the download behavior for the browser or specific user contexts. +@dataclass +class SetDownloadBehaviorParameters: + """SetDownloadBehaviorParameters.""" + + download_behavior: Any | None = None + user_contexts: list[Any | None] | None = None + + +@dataclass +class DownloadBehaviorAllowed: + """DownloadBehaviorAllowed.""" + + type: str = field(default="allowed", init=False) + destination_folder: str | None = None + + +@dataclass +class DownloadBehaviorDenied: + """DownloadBehaviorDenied.""" + + type: str = field(default="denied", init=False) + + +class Browser: + """WebDriver BiDi browser module.""" + + def __init__(self, conn) -> None: + self._conn = conn + + def close(self): + """Execute browser.close.""" + params = { + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.close", params) + result = self._conn.execute(cmd) + return result + + def create_user_context(self, accept_insecure_certs: bool | None = None, proxy: Any | None = None, unhandled_prompt_behavior: Any | None = None): + """Execute browser.createUserContext.""" + if proxy and hasattr(proxy, 'to_bidi_dict'): + proxy = proxy.to_bidi_dict() + + if unhandled_prompt_behavior and hasattr(unhandled_prompt_behavior, 'to_bidi_dict'): + unhandled_prompt_behavior = unhandled_prompt_behavior.to_bidi_dict() + + params = { + "acceptInsecureCerts": accept_insecure_certs, + "proxy": proxy, + "unhandledPromptBehavior": unhandled_prompt_behavior, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.createUserContext", params) + result = self._conn.execute(cmd) + if result and "userContext" in result: + extracted = result.get("userContext") + return extracted + return result + + def get_client_windows(self): + """Execute browser.getClientWindows.""" + params = { + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.getClientWindows", params) + result = self._conn.execute(cmd) + if result and "clientWindows" in result: + items = result.get("clientWindows", []) + return [ + ClientWindowInfo( + active=item.get("active"), + client_window=item.get("clientWindow"), + height=item.get("height"), + state=item.get("state"), + width=item.get("width"), + x=item.get("x"), + y=item.get("y") + ) + for item in items + if isinstance(item, dict) + ] + return [] + + def get_user_contexts(self): + """Execute browser.getUserContexts.""" + params = { + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.getUserContexts", params) + result = self._conn.execute(cmd) + if result and "userContexts" in result: + items = result.get("userContexts", []) + return [ + item.get("userContext") + for item in items + if isinstance(item, dict) + ] + return [] + + def remove_user_context(self, user_context: Any | None = None): + """Execute browser.removeUserContext.""" + params = { + "userContext": user_context, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.removeUserContext", params) + result = self._conn.execute(cmd) + return result + + def set_client_window_state(self, client_window: Any | None = None): + """Execute browser.setClientWindowState.""" + params = { + "clientWindow": client_window, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.setClientWindowState", params) + result = self._conn.execute(cmd) + return result + + def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): + """Execute browser.setDownloadBehavior.""" + validate_download_behavior(allowed=allowed, destination_folder=destination_folder, user_contexts=user_contexts) + + download_behavior = None + download_behavior = transform_download_params(allowed, destination_folder) + + params = { + "downloadBehavior": download_behavior, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.setDownloadBehavior", params) + result = self._conn.execute(cmd) + return result + + def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): + """Set the download behavior for the browser. Args: - allowed: True to allow downloads, False to deny downloads, or None to - clear download behavior (revert to default). - destination_folder: Required when allowed is True. Specifies the folder - to store downloads in. - user_contexts: Optional list of user context IDs to apply this - behavior to. If omitted, updates the default behavior. + allowed: ``True`` to allow downloads, ``False`` to deny, or ``None`` + to reset to browser default (sends ``null`` to the protocol). + destination_folder: Destination folder for downloads. Required when + ``allowed=True``. Accepts a string or :class:`pathlib.Path`. + user_contexts: Optional list of user context IDs. Raises: - ValueError: If allowed=True and destination_folder is missing, or if - allowed=False and destination_folder is provided. + ValueError: If *allowed* is ``True`` and *destination_folder* is + omitted, or ``False`` and *destination_folder* is provided. """ - params: dict[str, Any] = {} - - if allowed is None: - params["downloadBehavior"] = None - else: - if allowed: - if not destination_folder: - raise ValueError("destination_folder is required when allowed=True.") - params["downloadBehavior"] = { - "type": "allowed", - "destinationFolder": os.fspath(destination_folder), - } - else: - if destination_folder: - raise ValueError("destination_folder should not be provided when allowed=False.") - params["downloadBehavior"] = {"type": "denied"} - + validate_download_behavior( + allowed=allowed, + destination_folder=destination_folder, + user_contexts=user_contexts, + ) + download_behavior = transform_download_params(allowed, destination_folder) + # downloadBehavior is a REQUIRED field in the BiDi spec (can be null but + # must be present). Do NOT use a generic None-filter on it. + params: dict = {"downloadBehavior": download_behavior} if user_contexts is not None: params["userContexts"] = user_contexts - - self.conn.execute(command_builder("browser.setDownloadBehavior", params)) + cmd = command_builder("browser.setDownloadBehavior", params) + return self._conn.execute(cmd) diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index e8ae150342bda..35aea615d1780 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -1,35 +1,24 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# WebDriver BiDi module: browsingContext +from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass import threading from collections.abc import Callable from dataclasses import dataclass -from typing import Any - -from typing_extensions import Sentinel - -from selenium.webdriver.common.bidi.common import command_builder from selenium.webdriver.common.bidi.session import Session -UNDEFINED = Sentinel("UNDEFINED") - class ReadinessState: - """Represents the stage of document loading at which a navigation command will return.""" + """ReadinessState.""" NONE = "none" INTERACTIVE = "interactive" @@ -37,576 +26,509 @@ class ReadinessState: class UserPromptType: - """Represents the possible user prompt types.""" + """UserPromptType.""" ALERT = "alert" - BEFORE_UNLOAD = "beforeunload" + BEFOREUNLOAD = "beforeunload" CONFIRM = "confirm" PROMPT = "prompt" -class NavigationInfo: - """Provides details of an ongoing navigation.""" +class CreateType: + """CreateType.""" - def __init__( - self, - context: str, - navigation: str | None, - timestamp: int, - url: str, - ): - self.context = context - self.navigation = navigation - self.timestamp = timestamp - self.url = url + TAB = "tab" + WINDOW = "window" - @classmethod - def from_json(cls, json: dict) -> "NavigationInfo": - """Creates a NavigationInfo instance from a dictionary. - Args: - json: A dictionary containing the navigation information. +class DownloadCompleteParams: + """DownloadCompleteParams.""" - Returns: - A new instance of NavigationInfo. - """ - context = json.get("context") - if context is None or not isinstance(context, str): - raise ValueError("context is required and must be a string") - - navigation = json.get("navigation") - if navigation is not None and not isinstance(navigation, str): - raise ValueError("navigation must be a string") - - timestamp = json.get("timestamp") - if timestamp is None or not isinstance(timestamp, int) or timestamp < 0: - raise ValueError("timestamp is required and must be a non-negative integer") - - url = json.get("url") - if url is None or not isinstance(url, str): - raise ValueError("url is required and must be a string") - - return cls(context, navigation, timestamp, url) - - -class BrowsingContextInfo: - """Represents the properties of a navigable.""" - - def __init__( - self, - context: str, - url: str, - children: list["BrowsingContextInfo"] | None, - client_window: str, - user_context: str, - parent: str | None = None, - original_opener: str | None = None, - ): - self.context = context - self.url = url - self.children = children - self.parent = parent - self.user_context = user_context - self.original_opener = original_opener - self.client_window = client_window + COMPLETE = "complete" - @classmethod - def from_json(cls, json: dict) -> "BrowsingContextInfo": - """Creates a BrowsingContextInfo instance from a dictionary. - Args: - json: A dictionary containing the browsing context information. +@dataclass +class Info: + """Info.""" - Returns: - A new instance of BrowsingContextInfo. - """ - children = None - raw_children = json.get("children") - if raw_children is not None: - if not isinstance(raw_children, list): - raise ValueError("children must be a list if provided") - - children = [] - for child in raw_children: - if not isinstance(child, dict): - raise ValueError(f"Each child must be a dictionary, got {type(child)}") - children.append(BrowsingContextInfo.from_json(child)) - - context = json.get("context") - if context is None or not isinstance(context, str): - raise ValueError("context is required and must be a string") - - url = json.get("url") - if url is None or not isinstance(url, str): - raise ValueError("url is required and must be a string") - - parent = json.get("parent") - if parent is not None and not isinstance(parent, str): - raise ValueError("parent must be a string if provided") - - user_context = json.get("userContext") - if user_context is None or not isinstance(user_context, str): - raise ValueError("userContext is required and must be a string") - - original_opener = json.get("originalOpener") - if original_opener is not None and not isinstance(original_opener, str): - raise ValueError("originalOpener must be a string if provided") - - client_window = json.get("clientWindow") - if client_window is None or not isinstance(client_window, str): - raise ValueError("clientWindow is required and must be a string") - - return cls( - context=context, - url=url, - children=children, - client_window=client_window, - user_context=user_context, - parent=parent, - original_opener=original_opener, - ) + children: Any | None = None + client_window: Any | None = None + context: Any | None = None + original_opener: Any | None = None + url: str | None = None + user_context: Any | None = None + parent: Any | None = None -class DownloadWillBeginParams(NavigationInfo): - """Parameters for the downloadWillBegin event.""" +@dataclass +class AccessibilityLocator: + """AccessibilityLocator.""" - def __init__( - self, - context: str, - navigation: str | None, - timestamp: int, - url: str, - suggested_filename: str, - ): - super().__init__(context, navigation, timestamp, url) - self.suggested_filename = suggested_filename + type: str = field(default="accessibility", init=False) + name: str | None = None + role: str | None = None - @classmethod - def from_json(cls, json: dict) -> "DownloadWillBeginParams": - nav_info = NavigationInfo.from_json(json) - - suggested_filename = json.get("suggestedFilename") - if suggested_filename is None or not isinstance(suggested_filename, str): - raise ValueError("suggestedFilename is required and must be a string") - - return cls( - context=nav_info.context, - navigation=nav_info.navigation, - timestamp=nav_info.timestamp, - url=nav_info.url, - suggested_filename=suggested_filename, - ) +@dataclass +class CssLocator: + """CssLocator.""" -class UserPromptOpenedParams: - """Parameters for the userPromptOpened event.""" + type: str = field(default="css", init=False) + value: str | None = None - def __init__( - self, - context: str, - handler: str, - message: str, - type: str, - default_value: str | None = None, - ): - self.context = context - self.handler = handler - self.message = message - self.type = type - self.default_value = default_value - @classmethod - def from_json(cls, json: dict) -> "UserPromptOpenedParams": - """Creates a UserPromptOpenedParams instance from a dictionary. +@dataclass +class ContextLocator: + """ContextLocator.""" - Args: - json: A dictionary containing the user prompt parameters. + type: str = field(default="context", init=False) + context: Any | None = None - Returns: - A new instance of UserPromptOpenedParams. - """ - context = json.get("context") - if context is None or not isinstance(context, str): - raise ValueError("context is required and must be a string") - - handler = json.get("handler") - if handler is None or not isinstance(handler, str): - raise ValueError("handler is required and must be a string") - - message = json.get("message") - if message is None or not isinstance(message, str): - raise ValueError("message is required and must be a string") - - type_value = json.get("type") - if type_value is None or not isinstance(type_value, str): - raise ValueError("type is required and must be a string") - - default_value = json.get("defaultValue") - if default_value is not None and not isinstance(default_value, str): - raise ValueError("defaultValue must be a string if provided") - - return cls( - context=context, - handler=handler, - message=message, - type=type_value, - default_value=default_value, - ) +@dataclass +class InnerTextLocator: + """InnerTextLocator.""" -class UserPromptClosedParams: - """Parameters for the userPromptClosed event.""" + type: str = field(default="innerText", init=False) + value: str | None = None + ignore_case: bool | None = None + match_type: Any | None = None + max_depth: Any | None = None - def __init__( - self, - context: str, - accepted: bool, - type: str, - user_text: str | None = None, - ): - self.context = context - self.accepted = accepted - self.type = type - self.user_text = user_text - @classmethod - def from_json(cls, json: dict) -> "UserPromptClosedParams": - """Creates a UserPromptClosedParams instance from a dictionary. +@dataclass +class XPathLocator: + """XPathLocator.""" - Args: - json: A dictionary containing the user prompt closed parameters. + type: str = field(default="xpath", init=False) + value: str | None = None - Returns: - A new instance of UserPromptClosedParams. - """ - context = json.get("context") - if context is None or not isinstance(context, str): - raise ValueError("context is required and must be a string") - - accepted = json.get("accepted") - if accepted is None or not isinstance(accepted, bool): - raise ValueError("accepted is required and must be a boolean") - - type_value = json.get("type") - if type_value is None or not isinstance(type_value, str): - raise ValueError("type is required and must be a string") - - user_text = json.get("userText") - if user_text is not None and not isinstance(user_text, str): - raise ValueError("userText must be a string if provided") - - return cls( - context=context, - accepted=accepted, - type=type_value, - user_text=user_text, - ) +@dataclass +class BaseNavigationInfo: + """BaseNavigationInfo.""" + + context: Any | None = None + navigation: Any | None = None + timestamp: Any | None = None + url: str | None = None -class HistoryUpdatedParams: - """Parameters for the historyUpdated event.""" - def __init__( - self, - context: str, - timestamp: int, - url: str, - ): - self.context = context - self.timestamp = timestamp - self.url = url +@dataclass +class ActivateParameters: + """ActivateParameters.""" - @classmethod - def from_json(cls, json: dict) -> "HistoryUpdatedParams": - """Creates a HistoryUpdatedParams instance from a dictionary. + context: Any | None = None - Args: - json: A dictionary containing the history updated parameters. - Returns: - A new instance of HistoryUpdatedParams. - """ - context = json.get("context") - if context is None or not isinstance(context, str): - raise ValueError("context is required and must be a string") - - timestamp = json.get("timestamp") - if timestamp is None or not isinstance(timestamp, int) or timestamp < 0: - raise ValueError("timestamp is required and must be a non-negative integer") - - url = json.get("url") - if url is None or not isinstance(url, str): - raise ValueError("url is required and must be a string") - - return cls( - context=context, - timestamp=timestamp, - url=url, - ) +@dataclass +class CaptureScreenshotParameters: + """CaptureScreenshotParameters.""" + context: Any | None = None + format: Any | None = None + clip: Any | None = None -class DownloadCanceledParams(NavigationInfo): - def __init__( - self, - context: str, - navigation: str | None, - timestamp: int, - url: str, - status: str = "canceled", - ): - super().__init__(context, navigation, timestamp, url) - self.status = status - @classmethod - def from_json(cls, json: dict) -> "DownloadCanceledParams": - nav_info = NavigationInfo.from_json(json) - - status = json.get("status") - if status is None or status != "canceled": - raise ValueError("status is required and must be 'canceled'") - - return cls( - context=nav_info.context, - navigation=nav_info.navigation, - timestamp=nav_info.timestamp, - url=nav_info.url, - status=status, - ) +@dataclass +class ImageFormat: + """ImageFormat.""" + type: str | None = None + quality: Any | None = None -class DownloadCompleteParams(NavigationInfo): - def __init__( - self, - context: str, - navigation: str | None, - timestamp: int, - url: str, - status: str = "complete", - filepath: str | None = None, - ): - super().__init__(context, navigation, timestamp, url) - self.status = status - self.filepath = filepath - @classmethod - def from_json(cls, json: dict) -> "DownloadCompleteParams": - nav_info = NavigationInfo.from_json(json) - - status = json.get("status") - if status is None or status != "complete": - raise ValueError("status is required and must be 'complete'") - - filepath = json.get("filepath") - if filepath is not None and not isinstance(filepath, str): - raise ValueError("filepath must be a string if provided") - - return cls( - context=nav_info.context, - navigation=nav_info.navigation, - timestamp=nav_info.timestamp, - url=nav_info.url, - status=status, - filepath=filepath, - ) +@dataclass +class ElementClipRectangle: + """ElementClipRectangle.""" + type: str = field(default="element", init=False) + element: Any | None = None -class DownloadEndParams: - """Parameters for the downloadEnd event.""" - def __init__( - self, - download_params: DownloadCanceledParams | DownloadCompleteParams, - ): - self.download_params = download_params +@dataclass +class BoxClipRectangle: + """BoxClipRectangle.""" - @classmethod - def from_json(cls, json: dict) -> "DownloadEndParams": - status = json.get("status") - if status == "canceled": - return cls(DownloadCanceledParams.from_json(json)) - elif status == "complete": - return cls(DownloadCompleteParams.from_json(json)) - else: - raise ValueError("status must be either 'canceled' or 'complete'") + type: str = field(default="box", init=False) + x: Any | None = None + y: Any | None = None + width: Any | None = None + height: Any | None = None -class ContextCreated: - """Event class for browsingContext.contextCreated event.""" +@dataclass +class CaptureScreenshotResult: + """CaptureScreenshotResult.""" - event_class = "browsingContext.contextCreated" + data: str | None = None - @classmethod - def from_json(cls, json: dict): - if isinstance(json, BrowsingContextInfo): - return json - return BrowsingContextInfo.from_json(json) +@dataclass +class CloseParameters: + """CloseParameters.""" -class ContextDestroyed: - """Event class for browsingContext.contextDestroyed event.""" + context: Any | None = None + prompt_unload: bool | None = None - event_class = "browsingContext.contextDestroyed" - @classmethod - def from_json(cls, json: dict): - if isinstance(json, BrowsingContextInfo): - return json - return BrowsingContextInfo.from_json(json) +@dataclass +class CreateParameters: + """CreateParameters.""" + type: Any | None = None + reference_context: Any | None = None + background: bool | None = None + user_context: Any | None = None -class NavigationStarted: - """Event class for browsingContext.navigationStarted event.""" - event_class = "browsingContext.navigationStarted" +@dataclass +class CreateResult: + """CreateResult.""" - @classmethod - def from_json(cls, json: dict): - if isinstance(json, NavigationInfo): - return json - return NavigationInfo.from_json(json) + context: Any | None = None -class NavigationCommitted: - """Event class for browsingContext.navigationCommitted event.""" +@dataclass +class GetTreeParameters: + """GetTreeParameters.""" - event_class = "browsingContext.navigationCommitted" + max_depth: Any | None = None + root: Any | None = None - @classmethod - def from_json(cls, json: dict): - if isinstance(json, NavigationInfo): - return json - return NavigationInfo.from_json(json) +@dataclass +class GetTreeResult: + """GetTreeResult.""" -class NavigationFailed: - """Event class for browsingContext.navigationFailed event.""" + contexts: Any | None = None - event_class = "browsingContext.navigationFailed" - @classmethod - def from_json(cls, json: dict): - if isinstance(json, NavigationInfo): - return json - return NavigationInfo.from_json(json) +@dataclass +class HandleUserPromptParameters: + """HandleUserPromptParameters.""" + context: Any | None = None + accept: bool | None = None + user_text: str | None = None -class NavigationAborted: - """Event class for browsingContext.navigationAborted event.""" - event_class = "browsingContext.navigationAborted" +@dataclass +class LocateNodesParameters: + """LocateNodesParameters.""" - @classmethod - def from_json(cls, json: dict): - if isinstance(json, NavigationInfo): - return json - return NavigationInfo.from_json(json) + context: Any | None = None + locator: Any | None = None + serialization_options: Any | None = None + start_nodes: list[Any | None] | None = None -class DomContentLoaded: - """Event class for browsingContext.domContentLoaded event.""" +@dataclass +class LocateNodesResult: + """LocateNodesResult.""" - event_class = "browsingContext.domContentLoaded" + nodes: list[Any | None] | None = None - @classmethod - def from_json(cls, json: dict): - if isinstance(json, NavigationInfo): - return json - return NavigationInfo.from_json(json) +@dataclass +class NavigateParameters: + """NavigateParameters.""" -class Load: - """Event class for browsingContext.load event.""" + context: Any | None = None + url: str | None = None + wait: Any | None = None - event_class = "browsingContext.load" - @classmethod - def from_json(cls, json: dict): - if isinstance(json, NavigationInfo): - return json - return NavigationInfo.from_json(json) +@dataclass +class NavigateResult: + """NavigateResult.""" + navigation: Any | None = None + url: str | None = None -class FragmentNavigated: - """Event class for browsingContext.fragmentNavigated event.""" - event_class = "browsingContext.fragmentNavigated" +@dataclass +class PrintParameters: + """PrintParameters.""" - @classmethod - def from_json(cls, json: dict): - if isinstance(json, NavigationInfo): - return json - return NavigationInfo.from_json(json) + context: Any | None = None + background: bool | None = None + margin: Any | None = None + page: Any | None = None + scale: Any | None = None + shrink_to_fit: bool | None = None -class DownloadWillBegin: - """Event class for browsingContext.downloadWillBegin event.""" +@dataclass +class PrintMarginParameters: + """PrintMarginParameters.""" - event_class = "browsingContext.downloadWillBegin" + bottom: Any | None = None + left: Any | None = None + right: Any | None = None + top: Any | None = None - @classmethod - def from_json(cls, json: dict): - return DownloadWillBeginParams.from_json(json) +@dataclass +class PrintPageParameters: + """PrintPageParameters.""" -class UserPromptOpened: - """Event class for browsingContext.userPromptOpened event.""" + height: Any | None = None + width: Any | None = None - event_class = "browsingContext.userPromptOpened" - @classmethod - def from_json(cls, json: dict): - return UserPromptOpenedParams.from_json(json) +@dataclass +class PrintResult: + """PrintResult.""" + + data: str | None = None -class UserPromptClosed: - """Event class for browsingContext.userPromptClosed event.""" +@dataclass +class ReloadParameters: + """ReloadParameters.""" - event_class = "browsingContext.userPromptClosed" + context: Any | None = None + ignore_cache: bool | None = None + wait: Any | None = None - @classmethod - def from_json(cls, json: dict): - return UserPromptClosedParams.from_json(json) +@dataclass +class SetViewportParameters: + """SetViewportParameters.""" -class HistoryUpdated: - """Event class for browsingContext.historyUpdated event.""" + context: Any | None = None + viewport: Any | None = None + device_pixel_ratio: Any | None = None + user_contexts: list[Any | None] | None = None - event_class = "browsingContext.historyUpdated" - @classmethod - def from_json(cls, json: dict): - return HistoryUpdatedParams.from_json(json) +@dataclass +class Viewport: + """Viewport.""" + + width: Any | None = None + height: Any | None = None + + +@dataclass +class TraverseHistoryParameters: + """TraverseHistoryParameters.""" + + context: Any | None = None + delta: Any | None = None + + +@dataclass +class HistoryUpdatedParameters: + """HistoryUpdatedParameters.""" + + context: Any | None = None + timestamp: Any | None = None + url: str | None = None + + +@dataclass +class DownloadWillBeginParams: + """DownloadWillBeginParams.""" + + suggested_filename: str | None = None + + +@dataclass +class DownloadCanceledParams: + """DownloadCanceledParams.""" + + status: str = field(default="canceled", init=False) -class DownloadEnd: - """Event class for browsingContext.downloadEnd event.""" +@dataclass +class UserPromptClosedParameters: + """UserPromptClosedParameters.""" + + context: Any | None = None + accepted: bool | None = None + type: Any | None = None + user_text: str | None = None + + +@dataclass +class UserPromptOpenedParameters: + """UserPromptOpenedParameters.""" + + context: Any | None = None + handler: Any | None = None + message: str | None = None + type: Any | None = None + default_value: str | None = None + + +@dataclass +class DownloadWillBeginParams: + """DownloadWillBeginParams.""" - event_class = "browsingContext.downloadEnd" + suggested_filename: str | None = None + +@dataclass +class DownloadCanceledParams: + """DownloadCanceledParams.""" + + status: Any | None = None + +@dataclass +class DownloadParams: + """DownloadParams - fields shared by all download end event variants.""" + + status: str | None = None + context: Any | None = None + navigation: Any | None = None + timestamp: Any | None = None + url: str | None = None + filepath: str | None = None + +@dataclass +class DownloadEndParams: + """DownloadEndParams - params for browsingContext.downloadEnd event.""" + + download_params: "DownloadParams | None" = None @classmethod - def from_json(cls, json: dict): - return DownloadEndParams.from_json(json) + def from_json(cls, params: dict) -> "DownloadEndParams": + """Deserialize from BiDi wire-level params dict.""" + dp = DownloadParams( + status=params.get("status"), + context=params.get("context"), + navigation=params.get("navigation"), + timestamp=params.get("timestamp"), + url=params.get("url"), + filepath=params.get("filepath"), + ) + return cls(download_params=dp) + +# BiDi Event Name to Parameter Type Mapping +EVENT_NAME_MAPPING = { + "context_created": "browsingContext.contextCreated", + "context_destroyed": "browsingContext.contextDestroyed", + "navigation_started": "browsingContext.navigationStarted", + "fragment_navigated": "browsingContext.fragmentNavigated", + "history_updated": "browsingContext.historyUpdated", + "dom_content_loaded": "browsingContext.domContentLoaded", + "load": "browsingContext.load", + "download_will_begin": "browsingContext.downloadWillBegin", + "download_end": "browsingContext.downloadEnd", + "navigation_aborted": "browsingContext.navigationAborted", + "navigation_committed": "browsingContext.navigationCommitted", + "navigation_failed": "browsingContext.navigationFailed", + "user_prompt_closed": "browsingContext.userPromptClosed", + "user_prompt_opened": "browsingContext.userPromptOpened", + "download_will_begin": "browsingContext.downloadWillBegin", + "download_end": "browsingContext.downloadEnd", +} + +def _deserialize_info_list(items: list) -> list | None: + """Recursively deserialize a list of dicts to Info objects. + + Args: + items: List of dicts from the API response + + Returns: + List of Info objects with properly nested children, or None if empty + """ + if not items or not isinstance(items, list): + return None + + result = [] + for item in items: + if isinstance(item, dict): + # Recursively deserialize children only if the key exists in response + children_list = None + if "children" in item: + children_list = _deserialize_info_list(item.get("children", [])) + info = Info( + children=children_list, + client_window=item.get("clientWindow"), + context=item.get("context"), + original_opener=item.get("originalOpener"), + url=item.get("url"), + user_context=item.get("userContext"), + parent=item.get("parent"), + ) + result.append(info) + return result if result else None + + @dataclass class EventConfig: + """Configuration for a BiDi event.""" event_key: str bidi_event: str event_class: type +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. + + Args: + params: Raw BiDi event params with camelCase keys. + + Returns: + An instance of the dataclass, or the raw dict on failure. + """ + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) + + class _EventManager: - """Class to manage event subscriptions and callbacks for BrowsingContext.""" + """Manages event subscriptions and callbacks.""" def __init__(self, conn, event_configs: dict[str, EventConfig]): self.conn = conn self.event_configs = event_configs self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} self._available_events = ", ".join(sorted(event_configs.keys())) - # Thread safety lock for subscription operations self._subscription_lock = threading.Lock() + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + def validate_event(self, event: str) -> EventConfig: event_config = self.event_configs.get(event) if not event_config: @@ -614,447 +536,352 @@ def validate_event(self, event: str) -> EventConfig: return event_config def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: - """Subscribe to a BiDi event if not already subscribed. - - Args: - bidi_event: The BiDi event name. - contexts: Optional browsing context IDs to subscribe to. - """ + """Subscribe to a BiDi event if not already subscribed.""" with self._subscription_lock: if bidi_event not in self.subscriptions: session = Session(self.conn) - self.conn.execute(session.subscribe(bidi_event, browsing_contexts=contexts)) - self.subscriptions[bidi_event] = [] + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } def unsubscribe_from_event(self, bidi_event: str) -> None: - """Unsubscribe from a BiDi event if no more callbacks exist. - - Args: - bidi_event: The BiDi event name. - """ + """Unsubscribe from a BiDi event if no more callbacks exist.""" with self._subscription_lock: - callback_list = self.subscriptions.get(bidi_event) - if callback_list is not None and not callback_list: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: session = Session(self.conn) - self.conn.execute(session.unsubscribe(bidi_event)) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) del self.subscriptions[bidi_event] def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: with self._subscription_lock: - self.subscriptions[bidi_event].append(callback_id) + self.subscriptions[bidi_event]["callbacks"].append(callback_id) def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: with self._subscription_lock: - callback_list = self.subscriptions.get(bidi_event) - if callback_list and callback_id in callback_list: - callback_list.remove(callback_id) + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: event_config = self.validate_event(event) - - callback_id = self.conn.add_callback(event_config.event_class, callback) - - # Subscribe to the event if needed + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) self.subscribe_to_event(event_config.bidi_event, contexts) - - # Track the callback self.add_callback_to_tracking(event_config.bidi_event, callback_id) - return callback_id def remove_event_handler(self, event: str, callback_id: int) -> None: event_config = self.validate_event(event) - - # Remove the callback from the connection - self.conn.remove_callback(event_config.event_class, callback_id) - - # Remove from tracking collections + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) self.remove_callback_from_tracking(event_config.bidi_event, callback_id) - - # Unsubscribe if no more callbacks exist self.unsubscribe_from_event(event_config.bidi_event) def clear_event_handlers(self) -> None: - """Clear all event handlers from the browsing context.""" + """Clear all event handlers.""" with self._subscription_lock: if not self.subscriptions: return - session = Session(self.conn) - - for bidi_event, callback_ids in list(self.subscriptions.items()): - event_class = self._bidi_to_class.get(bidi_event) - if event_class: - # Remove all callbacks for this event - for callback_id in callback_ids: - self.conn.remove_callback(event_class, callback_id) - - self.conn.execute(session.unsubscribe(bidi_event)) - + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) self.subscriptions.clear() -class BrowsingContext: - """BiDi implementation of the browsingContext module.""" - - EVENT_CONFIGS = { - "context_created": EventConfig("context_created", "browsingContext.contextCreated", ContextCreated), - "context_destroyed": EventConfig("context_destroyed", "browsingContext.contextDestroyed", ContextDestroyed), - "dom_content_loaded": EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", DomContentLoaded), - "download_end": EventConfig("download_end", "browsingContext.downloadEnd", DownloadEnd), - "download_will_begin": EventConfig( - "download_will_begin", "browsingContext.downloadWillBegin", DownloadWillBegin - ), - "fragment_navigated": EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", FragmentNavigated), - "history_updated": EventConfig("history_updated", "browsingContext.historyUpdated", HistoryUpdated), - "load": EventConfig("load", "browsingContext.load", Load), - "navigation_aborted": EventConfig("navigation_aborted", "browsingContext.navigationAborted", NavigationAborted), - "navigation_committed": EventConfig( - "navigation_committed", "browsingContext.navigationCommitted", NavigationCommitted - ), - "navigation_failed": EventConfig("navigation_failed", "browsingContext.navigationFailed", NavigationFailed), - "navigation_started": EventConfig("navigation_started", "browsingContext.navigationStarted", NavigationStarted), - "user_prompt_closed": EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", UserPromptClosed), - "user_prompt_opened": EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", UserPromptOpened), - } - - def __init__(self, conn): - self.conn = conn - self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - - @classmethod - def get_event_names(cls) -> list[str]: - """Get a list of all available event names. - - Returns: - A list of event names that can be used with event handlers. - """ - return list(cls.EVENT_CONFIGS.keys()) - def activate(self, context: str) -> None: - """Activates and focuses the given top-level traversable. - Args: - context: The browsing context ID to activate. +class BrowsingContext: + """WebDriver BiDi browsingContext module.""" - Raises: - Exception: If the browsing context is not a top-level traversable. - """ - params = {"context": context} - self.conn.execute(command_builder("browsingContext.activate", params)) - - def capture_screenshot( - self, - context: str, - origin: str = "viewport", - format: dict | None = None, - clip: dict | None = None, - ) -> str: - """Captures an image of the given navigable, and returns it as a Base64-encoded string. + EVENT_CONFIGS = {} + def __init__(self, conn) -> None: + self._conn = conn + self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - Args: - context: The browsing context ID to capture. - origin: The origin of the screenshot, either "viewport" or "document". - format: The format of the screenshot. - clip: The clip rectangle of the screenshot. + def activate(self, context: Any | None = None): + """Execute browsingContext.activate.""" + params = { + "context": context, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.activate", params) + result = self._conn.execute(cmd) + return result - Returns: - The Base64-encoded screenshot. - """ - params: dict[str, Any] = {"context": context, "origin": origin} - if format is not None: - params["format"] = format - if clip is not None: - params["clip"] = clip + def capture_screenshot(self, context: str | None = None, format: Any | None = None, clip: Any | None = None, origin: str | None = None): + """Execute browsingContext.captureScreenshot.""" + params = { + "context": context, + "format": format, + "clip": clip, + "origin": origin, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.captureScreenshot", params) + result = self._conn.execute(cmd) + if result and "data" in result: + extracted = result.get("data") + return extracted + return result - result = self.conn.execute(command_builder("browsingContext.captureScreenshot", params)) - return result["data"] + def close(self, context: Any | None = None, prompt_unload: bool | None = None): + """Execute browsingContext.close.""" + params = { + "context": context, + "promptUnload": prompt_unload, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.close", params) + result = self._conn.execute(cmd) + return result - def close(self, context: str, prompt_unload: bool = False) -> None: - """Closes a top-level traversable. + def create(self, type: Any | None = None, reference_context: Any | None = None, background: bool | None = None, user_context: Any | None = None): + """Execute browsingContext.create.""" + params = { + "type": type, + "referenceContext": reference_context, + "background": background, + "userContext": user_context, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.create", params) + result = self._conn.execute(cmd) + if result and "context" in result: + extracted = result.get("context") + return extracted + return result - Args: - context: The browsing context ID to close. - prompt_unload: Whether to prompt to unload. + def get_tree(self, max_depth: Any | None = None, root: Any | None = None): + """Execute browsingContext.getTree.""" + params = { + "maxDepth": max_depth, + "root": root, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.getTree", params) + result = self._conn.execute(cmd) + if result and "contexts" in result: + items = result.get("contexts", []) + return [ + Info( + children=_deserialize_info_list(item.get("children", [])), + client_window=item.get("clientWindow"), + context=item.get("context"), + original_opener=item.get("originalOpener"), + url=item.get("url"), + user_context=item.get("userContext"), + parent=item.get("parent") + ) + for item in items + if isinstance(item, dict) + ] + return [] + + def handle_user_prompt(self, context: Any | None = None, accept: bool | None = None, user_text: Any | None = None): + """Execute browsingContext.handleUserPrompt.""" + params = { + "context": context, + "accept": accept, + "userText": user_text, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.handleUserPrompt", params) + result = self._conn.execute(cmd) + return result - Raises: - Exception: If the browsing context is not a top-level traversable. - """ - params = {"context": context, "promptUnload": prompt_unload} - self.conn.execute(command_builder("browsingContext.close", params)) - - def create( - self, - type: str, - reference_context: str | None = None, - background: bool = False, - user_context: str | None = None, - ) -> str: - """Creates a new navigable, either in a new tab or in a new window, and returns its navigable id. + def locate_nodes(self, context: str | None = None, locator: Any | None = None, serialization_options: Any | None = None, start_nodes: Any | None = None, max_node_count: int | None = None): + """Execute browsingContext.locateNodes.""" + params = { + "context": context, + "locator": locator, + "serializationOptions": serialization_options, + "startNodes": start_nodes, + "maxNodeCount": max_node_count, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.locateNodes", params) + result = self._conn.execute(cmd) + if result and "nodes" in result: + extracted = result.get("nodes") + return extracted + return result - Args: - type: The type of the new navigable, either "tab" or "window". - reference_context: The reference browsing context ID. - background: Whether to create the new navigable in the background. - user_context: The user context ID. + def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any | None = None): + """Execute browsingContext.navigate.""" + params = { + "context": context, + "url": url, + "wait": wait, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.navigate", params) + result = self._conn.execute(cmd) + return result - Returns: - The browsing context ID of the created navigable. - """ - params: dict[str, Any] = {"type": type} - if reference_context is not None: - params["referenceContext"] = reference_context - if background is not None: - params["background"] = background - if user_context is not None: - params["userContext"] = user_context - - result = self.conn.execute(command_builder("browsingContext.create", params)) - return result["context"] - - def get_tree( - self, - max_depth: int | None = None, - root: str | None = None, - ) -> list[BrowsingContextInfo]: - """Get a tree of all descendent navigables including the given parent itself. - - Returns a tree of all descendent navigables including the given parent itself, or all top-level contexts - when no parent is provided. + def print(self, context: Any | None = None, background: bool | None = None, margin: Any | None = None, page: Any | None = None, scale: Any | None = None, shrink_to_fit: bool | None = None): + """Execute browsingContext.print.""" + params = { + "context": context, + "background": background, + "margin": margin, + "page": page, + "scale": scale, + "shrinkToFit": shrink_to_fit, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.print", params) + result = self._conn.execute(cmd) + if result and "data" in result: + extracted = result.get("data") + return extracted + return result - Args: - max_depth: The maximum depth of the tree. - root: The root browsing context ID. + def reload(self, context: Any | None = None, ignore_cache: bool | None = None, wait: Any | None = None): + """Execute browsingContext.reload.""" + params = { + "context": context, + "ignoreCache": ignore_cache, + "wait": wait, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.reload", params) + result = self._conn.execute(cmd) + return result - Returns: - A list of browsing context information. - """ - params: dict[str, Any] = {} - if max_depth is not None: - params["maxDepth"] = max_depth - if root is not None: - params["root"] = root - - result = self.conn.execute(command_builder("browsingContext.getTree", params)) - return [BrowsingContextInfo.from_json(context) for context in result["contexts"]] - - def handle_user_prompt( - self, - context: str, - accept: bool | None = None, - user_text: str | None = None, - ) -> None: - """Allows closing an open prompt. + def set_viewport(self, context: str | None = None, viewport: Any | None = None, user_contexts: Any | None = None, device_pixel_ratio: Any | None = None): + """Execute browsingContext.setViewport.""" + params = { + "context": context, + "viewport": viewport, + "userContexts": user_contexts, + "devicePixelRatio": device_pixel_ratio, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.setViewport", params) + result = self._conn.execute(cmd) + return result - Args: - context: The browsing context ID. - accept: Whether to accept the prompt. - user_text: The text to enter in the prompt. - """ - params: dict[str, Any] = {"context": context} - if accept is not None: - params["accept"] = accept - if user_text is not None: - params["userText"] = user_text - - self.conn.execute(command_builder("browsingContext.handleUserPrompt", params)) - - def locate_nodes( - self, - context: str, - locator: dict, - max_node_count: int | None = None, - serialization_options: dict | None = None, - start_nodes: list[dict] | None = None, - ) -> list[dict]: - """Returns a list of all nodes matching the specified locator. + def traverse_history(self, context: Any | None = None, delta: Any | None = None): + """Execute browsingContext.traverseHistory.""" + params = { + "context": context, + "delta": delta, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.traverseHistory", params) + result = self._conn.execute(cmd) + return result - Args: - context: The browsing context ID. - locator: The locator to use. - max_node_count: The maximum number of nodes to return. - serialization_options: The serialization options. - start_nodes: The start nodes. - Returns: - A list of nodes. - """ - params: dict[str, Any] = {"context": context, "locator": locator} - if max_node_count is not None: - params["maxNodeCount"] = max_node_count - if serialization_options is not None: - params["serializationOptions"] = serialization_options - if start_nodes is not None: - params["startNodes"] = start_nodes - - result = self.conn.execute(command_builder("browsingContext.locateNodes", params)) - return result["nodes"] - - def navigate( - self, - context: str, - url: str, - wait: str | None = None, - ) -> dict: - """Navigates a navigable to the given URL. + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + """Add an event handler. Args: - context: The browsing context ID. - url: The URL to navigate to. - wait: The readiness state to wait for. + event: The event to subscribe to. + callback: The callback function to execute on event. + contexts: The context IDs to subscribe to (optional). Returns: - A dictionary containing the navigation result. + The callback ID. """ - params = {"context": context, "url": url} - if wait is not None: - params["wait"] = wait - - result = self.conn.execute(command_builder("browsingContext.navigate", params)) - return result + return self._event_manager.add_event_handler(event, callback, contexts) - def print( - self, - context: str, - background: bool = False, - margin: dict | None = None, - orientation: str = "portrait", - page: dict | None = None, - page_ranges: list[int | str] | None = None, - scale: float = 1.0, - shrink_to_fit: bool = True, - ) -> str: - """Create a paginated PDF representation of the document as a Base64-encoded string. + def remove_event_handler(self, event: str, callback_id: int) -> None: + """Remove an event handler. Args: - context: The browsing context ID. - background: Whether to include the background. - margin: The margin parameters. - orientation: The orientation, either "portrait" or "landscape". - page: The page parameters. - page_ranges: The page ranges. - scale: The scale. - shrink_to_fit: Whether to shrink to fit. - - Returns: - The Base64-encoded PDF document. + event: The event to unsubscribe from. + callback_id: The callback ID. """ - params = { - "context": context, - "background": background, - "orientation": orientation, - "scale": scale, - "shrinkToFit": shrink_to_fit, - } - if margin is not None: - params["margin"] = margin - if page is not None: - params["page"] = page - if page_ranges is not None: - params["pageRanges"] = page_ranges - - result = self.conn.execute(command_builder("browsingContext.print", params)) - return result["data"] - - def reload( - self, - context: str, - ignore_cache: bool | None = None, - wait: str | None = None, - ) -> dict: - """Reloads a navigable. + return self._event_manager.remove_event_handler(event, callback_id) - Args: - context: The browsing context ID. - ignore_cache: Whether to ignore the cache. - wait: The readiness state to wait for. + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + return self._event_manager.clear_event_handlers() - Returns: - A dictionary containing the navigation result. - """ - params: dict[str, Any] = {"context": context} - if ignore_cache is not None: - params["ignoreCache"] = ignore_cache - if wait is not None: - params["wait"] = wait +# Event Info Type Aliases +# Event: browsingContext.contextCreated +ContextCreated = globals().get('Info', dict) # Fallback to dict if type not defined - result = self.conn.execute(command_builder("browsingContext.reload", params)) - return result +# Event: browsingContext.contextDestroyed +ContextDestroyed = globals().get('Info', dict) # Fallback to dict if type not defined - def set_viewport( - self, - context: str | None = None, - viewport: dict | None | Sentinel = UNDEFINED, - device_pixel_ratio: float | None | Sentinel = UNDEFINED, - user_contexts: list[str] | None = None, - ) -> None: - """Modifies specific viewport characteristics on the given top-level traversable. +# Event: browsingContext.navigationStarted +NavigationStarted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined - Args: - context: The browsing context ID. - viewport: The viewport parameters - {"width": , "height": } (`None` resets to default). - device_pixel_ratio: The device pixel ratio (`None` resets to default). - user_contexts: The user context IDs. - - Raises: - Exception: If the browsing context is not a top-level traversable - ValueError: If neither `context` nor `user_contexts` is provided - ValueError: If both `context` and `user_contexts` are provided - """ - if context is not None and user_contexts is not None: - raise ValueError("Cannot specify both context and user_contexts") +# Event: browsingContext.fragmentNavigated +FragmentNavigated = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined - if context is None and user_contexts is None: - raise ValueError("Must specify either context or user_contexts") +# Event: browsingContext.historyUpdated +HistoryUpdated = globals().get('HistoryUpdatedParameters', dict) # Fallback to dict if type not defined - params: dict[str, Any] = {} - if context is not None: - params["context"] = context - elif user_contexts is not None: - params["userContexts"] = user_contexts - if viewport is not UNDEFINED: - params["viewport"] = viewport - if device_pixel_ratio is not UNDEFINED: - params["devicePixelRatio"] = device_pixel_ratio +# Event: browsingContext.domContentLoaded +DomContentLoaded = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined - self.conn.execute(command_builder("browsingContext.setViewport", params)) +# Event: browsingContext.load +Load = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined - def traverse_history(self, context: str, delta: int) -> dict: - """Traverses the history of a given navigable by a delta. +# Event: browsingContext.downloadWillBegin +DownloadWillBegin = globals().get('DownloadWillBeginParams', dict) # Fallback to dict if type not defined - Args: - context: The browsing context ID. - delta: The delta to traverse by. +# Event: browsingContext.downloadEnd +DownloadEnd = globals().get('DownloadEndParams', dict) # Fallback to dict if type not defined - Returns: - A dictionary containing the traverse history result. - """ - params = {"context": context, "delta": delta} - result = self.conn.execute(command_builder("browsingContext.traverseHistory", params)) - return result +# Event: browsingContext.navigationAborted +NavigationAborted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - """Add an event handler to the browsing context. +# Event: browsingContext.navigationCommitted +NavigationCommitted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined - Args: - event: The event to subscribe to. - callback: The callback function to execute on event. - contexts: The browsing context IDs to subscribe to. +# Event: browsingContext.navigationFailed +NavigationFailed = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined - Returns: - Callback id. - """ - return self._event_manager.add_event_handler(event, callback, contexts) +# Event: browsingContext.userPromptClosed +UserPromptClosed = globals().get('UserPromptClosedParameters', dict) # Fallback to dict if type not defined - def remove_event_handler(self, event: str, callback_id: int) -> None: - """Remove an event handler from the browsing context. +# Event: browsingContext.userPromptOpened +UserPromptOpened = globals().get('UserPromptOpenedParameters', dict) # Fallback to dict if type not defined - Args: - event: The event to unsubscribe from. - callback_id: The callback id to remove. - """ - self._event_manager.remove_event_handler(event, callback_id) - def clear_event_handlers(self) -> None: - """Clear all event handlers from the browsing context.""" - self._event_manager.clear_event_handlers() +# Populate EVENT_CONFIGS with event configuration mappings +_globals = globals() +BrowsingContext.EVENT_CONFIGS = { + "context_created": (EventConfig("context_created", "browsingContext.contextCreated", _globals.get("ContextCreated", dict)) if _globals.get("ContextCreated") else EventConfig("context_created", "browsingContext.contextCreated", dict)), + "context_destroyed": (EventConfig("context_destroyed", "browsingContext.contextDestroyed", _globals.get("ContextDestroyed", dict)) if _globals.get("ContextDestroyed") else EventConfig("context_destroyed", "browsingContext.contextDestroyed", dict)), + "navigation_started": (EventConfig("navigation_started", "browsingContext.navigationStarted", _globals.get("NavigationStarted", dict)) if _globals.get("NavigationStarted") else EventConfig("navigation_started", "browsingContext.navigationStarted", dict)), + "fragment_navigated": (EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", _globals.get("FragmentNavigated", dict)) if _globals.get("FragmentNavigated") else EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", dict)), + "history_updated": (EventConfig("history_updated", "browsingContext.historyUpdated", _globals.get("HistoryUpdated", dict)) if _globals.get("HistoryUpdated") else EventConfig("history_updated", "browsingContext.historyUpdated", dict)), + "dom_content_loaded": (EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", _globals.get("DomContentLoaded", dict)) if _globals.get("DomContentLoaded") else EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", dict)), + "load": (EventConfig("load", "browsingContext.load", _globals.get("Load", dict)) if _globals.get("Load") else EventConfig("load", "browsingContext.load", dict)), + "download_will_begin": (EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBegin", dict)) if _globals.get("DownloadWillBegin") else EventConfig("download_will_begin", "browsingContext.downloadWillBegin", dict)), + "download_end": (EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEnd", dict)) if _globals.get("DownloadEnd") else EventConfig("download_end", "browsingContext.downloadEnd", dict)), + "navigation_aborted": (EventConfig("navigation_aborted", "browsingContext.navigationAborted", _globals.get("NavigationAborted", dict)) if _globals.get("NavigationAborted") else EventConfig("navigation_aborted", "browsingContext.navigationAborted", dict)), + "navigation_committed": (EventConfig("navigation_committed", "browsingContext.navigationCommitted", _globals.get("NavigationCommitted", dict)) if _globals.get("NavigationCommitted") else EventConfig("navigation_committed", "browsingContext.navigationCommitted", dict)), + "navigation_failed": (EventConfig("navigation_failed", "browsingContext.navigationFailed", _globals.get("NavigationFailed", dict)) if _globals.get("NavigationFailed") else EventConfig("navigation_failed", "browsingContext.navigationFailed", dict)), + "user_prompt_closed": (EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", _globals.get("UserPromptClosed", dict)) if _globals.get("UserPromptClosed") else EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", dict)), + "user_prompt_opened": (EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", _globals.get("UserPromptOpened", dict)) if _globals.get("UserPromptOpened") else EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", dict)), + "download_will_begin": EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBeginParams", dict)), + "download_end": EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEndParams", dict)), +} diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index 0f57d07e5f0d4..d90d8c770263a 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -15,22 +15,25 @@ # specific language governing permissions and limitations # under the License. -from collections.abc import Generator +"""Common utilities for BiDi command construction.""" +from typing import Any, Dict, Generator -def command_builder(method: str, params: dict | None = None) -> Generator[dict, dict, dict]: - """Build a command iterator to send to the BiDi protocol. + +def command_builder( + method: str, params: Dict[str, Any] +) -> Generator[Dict[str, Any], Any, Any]: + """Build a BiDi command generator. Args: - method: The method to execute. - params: The parameters to pass to the method. Default is None. + method: The BiDi method name (e.g., "session.status", "browser.close") + params: The parameters for the command + + Yields: + A dictionary representing the BiDi command Returns: - The response from the command execution. + The result from the BiDi command execution """ - if params is None: - params = {} - - command = {"method": method, "params": params} - cmd = yield command - return cmd + result = yield {"method": method, "params": params} + return result diff --git a/py/selenium/webdriver/common/bidi/console.py b/py/selenium/webdriver/common/bidi/console.py old mode 100644 new mode 100755 diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index a6acaefe89b83..4cd6ae2e3c712 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -1,39 +1,34 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# WebDriver BiDi module: emulation from __future__ import annotations -from enum import Enum -from typing import TYPE_CHECKING, Any, TypeVar +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass -from selenium.webdriver.common.bidi.common import command_builder -if TYPE_CHECKING: - from selenium.webdriver.remote.websocket_connection import WebSocketConnection +class ForcedColorsModeTheme: + """ForcedColorsModeTheme.""" + LIGHT = "light" + DARK = "dark" -class ScreenOrientationNatural(Enum): - """Natural screen orientation.""" + +class ScreenOrientationNatural: + """ScreenOrientationNatural.""" PORTRAIT = "portrait" LANDSCAPE = "landscape" -class ScreenOrientationType(Enum): - """Screen orientation type.""" +class ScreenOrientationType: + """ScreenOrientationType.""" PORTRAIT_PRIMARY = "portrait-primary" PORTRAIT_SECONDARY = "portrait-secondary" @@ -41,484 +36,494 @@ class ScreenOrientationType(Enum): LANDSCAPE_SECONDARY = "landscape-secondary" -E = TypeVar("E", ScreenOrientationNatural, ScreenOrientationType) +@dataclass +class SetForcedColorsModeThemeOverrideParameters: + """SetForcedColorsModeThemeOverrideParameters.""" + theme: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None -def _convert_to_enum(value: E | str, enum_class: type[E]) -> E: - if isinstance(value, enum_class): - return value - assert isinstance(value, str) - try: - return enum_class(value.lower()) - except ValueError: - raise ValueError(f"Invalid orientation: {value}") +@dataclass +class SetGeolocationOverrideParameters: + """SetGeolocationOverrideParameters.""" -class ScreenOrientation: - """Represents screen orientation configuration.""" + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None - def __init__( - self, - natural: ScreenOrientationNatural | str, - type: ScreenOrientationType | str, - ): - """Initialize ScreenOrientation. - Args: - natural: Natural screen orientation ("portrait" or "landscape"). - type: Screen orientation type ("portrait-primary", "portrait-secondary", - "landscape-primary", or "landscape-secondary"). +@dataclass +class GeolocationCoordinates: + """GeolocationCoordinates.""" - Raises: - ValueError: If natural or type values are invalid. - """ - # handle string values - self.natural = _convert_to_enum(natural, ScreenOrientationNatural) - self.type = _convert_to_enum(type, ScreenOrientationType) - - def to_dict(self) -> dict[str, str]: - return { - "natural": self.natural.value, - "type": self.type.value, - } + latitude: Any | None = None + longitude: Any | None = None + accuracy: Any | None = None + altitude: Any | None = None + altitude_accuracy: Any | None = None + heading: Any | None = None + speed: Any | None = None -class GeolocationCoordinates: - """Represents geolocation coordinates.""" +@dataclass +class GeolocationPositionError: + """GeolocationPositionError.""" - def __init__( - self, - latitude: float, - longitude: float, - accuracy: float = 1.0, - altitude: float | None = None, - altitude_accuracy: float | None = None, - heading: float | None = None, - speed: float | None = None, - ): - """Initialize GeolocationCoordinates. + type: str = field(default="positionUnavailable", init=False) - Args: - latitude: Latitude coordinate (-90.0 to 90.0). - longitude: Longitude coordinate (-180.0 to 180.0). - accuracy: Accuracy in meters (>= 0.0), defaults to 1.0. - altitude: Altitude in meters or None, defaults to None. - altitude_accuracy: Altitude accuracy in meters (>= 0.0) or None, defaults to None. - heading: Heading in degrees (0.0 to 360.0) or None, defaults to None. - speed: Speed in meters per second (>= 0.0) or None, defaults to None. - - Raises: - ValueError: If coordinates are out of valid range or if altitude_accuracy is provided without altitude. - """ - self.latitude = latitude - self.longitude = longitude - self.accuracy = accuracy - self.altitude = altitude - self.altitude_accuracy = altitude_accuracy - self.heading = heading - self.speed = speed - - @property - def latitude(self) -> float: - return self._latitude - - @latitude.setter - def latitude(self, value: float) -> None: - if not (-90.0 <= value <= 90.0): - raise ValueError("latitude must be between -90.0 and 90.0") - self._latitude = value - - @property - def longitude(self) -> float: - return self._longitude - - @longitude.setter - def longitude(self, value: float) -> None: - if not (-180.0 <= value <= 180.0): - raise ValueError("longitude must be between -180.0 and 180.0") - self._longitude = value - - @property - def accuracy(self) -> float: - return self._accuracy - - @accuracy.setter - def accuracy(self, value: float) -> None: - if value < 0.0: - raise ValueError("accuracy must be >= 0.0") - self._accuracy = value - - @property - def altitude(self) -> float | None: - return self._altitude - - @altitude.setter - def altitude(self, value: float | None) -> None: - self._altitude = value - - @property - def altitude_accuracy(self) -> float | None: - return self._altitude_accuracy - - @altitude_accuracy.setter - def altitude_accuracy(self, value: float | None) -> None: - if value is not None and self.altitude is None: - raise ValueError("altitude_accuracy cannot be set without altitude") - if value is not None and value < 0.0: - raise ValueError("altitude_accuracy must be >= 0.0") - self._altitude_accuracy = value - - @property - def heading(self) -> float | None: - return self._heading - - @heading.setter - def heading(self, value: float | None) -> None: - if value is not None and not (0.0 <= value < 360.0): - raise ValueError("heading must be between 0.0 and 360.0") - self._heading = value - - @property - def speed(self) -> float | None: - return self._speed - - @speed.setter - def speed(self, value: float | None) -> None: - if value is not None and value < 0.0: - raise ValueError("speed must be >= 0.0") - self._speed = value - - def to_dict(self) -> dict[str, float | None]: - result: dict[str, float | None] = { - "latitude": self.latitude, - "longitude": self.longitude, - "accuracy": self.accuracy, - } - if self.altitude is not None: - result["altitude"] = self.altitude +@dataclass +class SetLocaleOverrideParameters: + """SetLocaleOverrideParameters.""" - if self.altitude_accuracy is not None: - result["altitudeAccuracy"] = self.altitude_accuracy + locale: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None - if self.heading is not None: - result["heading"] = self.heading - if self.speed is not None: - result["speed"] = self.speed +@dataclass +class setNetworkConditionsParameters: + """setNetworkConditionsParameters.""" - return result + network_conditions: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None -class GeolocationPositionError: - """Represents a geolocation position error.""" +@dataclass +class NetworkConditionsOffline: + """NetworkConditionsOffline.""" - TYPE_POSITION_UNAVAILABLE = "positionUnavailable" + type: str = field(default="offline", init=False) - def __init__(self, type: str = TYPE_POSITION_UNAVAILABLE): - if type != self.TYPE_POSITION_UNAVAILABLE: - raise ValueError(f'type must be "{self.TYPE_POSITION_UNAVAILABLE}"') - self.type = type - def to_dict(self) -> dict[str, str]: - return {"type": self.type} +@dataclass +class ScreenArea: + """ScreenArea.""" + width: Any | None = None + height: Any | None = None -class Emulation: - """BiDi implementation of the emulation module.""" - def __init__(self, conn: WebSocketConnection) -> None: - self.conn = conn +@dataclass +class SetScreenSettingsOverrideParameters: + """SetScreenSettingsOverrideParameters.""" - def set_geolocation_override( - self, - coordinates: GeolocationCoordinates | None = None, - error: GeolocationPositionError | None = None, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set geolocation override for the given contexts or user contexts. + screen_area: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None - Args: - coordinates: Geolocation coordinates to emulate, or None. - error: Geolocation error to emulate, or None. - contexts: List of browsing context IDs to apply the override to. - user_contexts: List of user context IDs to apply the override to. - - Raises: - ValueError: If both coordinates and error are provided, or if both contexts - and user_contexts are provided, or if neither contexts nor - user_contexts are provided. - """ - if coordinates is not None and error is not None: - raise ValueError("Cannot specify both coordinates and error") - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and userContexts") +@dataclass +class ScreenOrientation: + """ScreenOrientation.""" - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or userContexts") + natural: Any | None = None + type: Any | None = None - params: dict[str, Any] = {} - if coordinates is not None: - params["coordinates"] = coordinates.to_dict() - elif error is not None: - params["error"] = error.to_dict() +@dataclass +class SetScreenOrientationOverrideParameters: + """SetScreenOrientationOverrideParameters.""" - if contexts is not None: - params["contexts"] = contexts - elif user_contexts is not None: - params["userContexts"] = user_contexts + screen_orientation: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None - self.conn.execute(command_builder("emulation.setGeolocationOverride", params)) - def set_timezone_override( - self, - timezone: str | None = None, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set timezone override for the given contexts or user contexts. +@dataclass +class SetUserAgentOverrideParameters: + """SetUserAgentOverrideParameters.""" - Args: - timezone: Timezone identifier (IANA timezone name or offset string like '+01:00'), - or None to clear the override. - contexts: List of browsing context IDs to apply the override to. - user_contexts: List of user context IDs to apply the override to. - - Raises: - ValueError: If both contexts and user_contexts are provided, or if neither - contexts nor user_contexts are provided. - """ - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and user_contexts") + user_agent: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or user_contexts") - params: dict[str, Any] = {"timezone": timezone} +@dataclass +class SetViewportMetaOverrideParameters: + """SetViewportMetaOverrideParameters.""" - if contexts is not None: - params["contexts"] = contexts - elif user_contexts is not None: - params["userContexts"] = user_contexts + viewport_meta: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None - self.conn.execute(command_builder("emulation.setTimezoneOverride", params)) - def set_locale_override( - self, - locale: str | None = None, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set locale override for the given contexts or user contexts. +@dataclass +class SetScriptingEnabledParameters: + """SetScriptingEnabledParameters.""" - Args: - locale: Locale string as per BCP 47, or None to clear override. - contexts: List of browsing context IDs to apply the override to. - user_contexts: List of user context IDs to apply the override to. + enabled: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None - Raises: - ValueError: If both contexts and user_contexts are provided, or if neither - contexts nor user_contexts are provided, or if locale is invalid. - """ - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and userContexts") - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or userContexts") +@dataclass +class SetScrollbarTypeOverrideParameters: + """SetScrollbarTypeOverrideParameters.""" - params: dict[str, Any] = {"locale": locale} + scrollbar_type: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None - if contexts is not None: - params["contexts"] = contexts - elif user_contexts is not None: - params["userContexts"] = user_contexts - self.conn.execute(command_builder("emulation.setLocaleOverride", params)) +@dataclass +class SetTimezoneOverrideParameters: + """SetTimezoneOverrideParameters.""" - def set_scripting_enabled( - self, - enabled: bool | None = False, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set scripting enabled override for the given contexts or user contexts. + timezone: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None - Args: - enabled: False to disable scripting, None to clear the override. - Note: Only emulation of disabled JavaScript is supported. - contexts: List of browsing context IDs to apply the override to. - user_contexts: List of user context IDs to apply the override to. - - Raises: - ValueError: If both contexts and user_contexts are provided, or if neither - contexts nor user_contexts are provided, or if enabled is True. - """ - if enabled: - raise ValueError("Only emulation of disabled JavaScript is supported (enabled must be False or None)") - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and userContexts") +@dataclass +class SetTouchOverrideParameters: + """SetTouchOverrideParameters.""" - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or userContexts") + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None - params: dict[str, Any] = {"enabled": enabled} - if contexts is not None: - params["contexts"] = contexts - elif user_contexts is not None: - params["userContexts"] = user_contexts +class Emulation: + """WebDriver BiDi emulation module.""" - self.conn.execute(command_builder("emulation.setScriptingEnabled", params)) + def __init__(self, conn) -> None: + self._conn = conn - def set_screen_orientation_override( - self, - screen_orientation: ScreenOrientation | None = None, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set screen orientation override for the given contexts or user contexts. + def set_forced_colors_mode_theme_override(self, theme: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setForcedColorsModeThemeOverride.""" + params = { + "theme": theme, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setForcedColorsModeThemeOverride", params) + result = self._conn.execute(cmd) + return result - Args: - screen_orientation: ScreenOrientation object to emulate, or None to clear the override. - contexts: List of browsing context IDs to apply the override to. - user_contexts: List of user context IDs to apply the override to. + def set_geolocation_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setGeolocationOverride.""" + params = { + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setGeolocationOverride", params) + result = self._conn.execute(cmd) + return result - Raises: - ValueError: If both contexts and user_contexts are provided, or if neither - contexts nor user_contexts are provided. - """ - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and userContexts") + def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setLocaleOverride.""" + params = { + "locale": locale, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setLocaleOverride", params) + result = self._conn.execute(cmd) + return result + + def set_network_conditions(self, network_conditions: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setNetworkConditions.""" + params = { + "networkConditions": network_conditions, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setNetworkConditions", params) + result = self._conn.execute(cmd) + return result - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or userContexts") + def set_screen_settings_override(self, screen_area: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setScreenSettingsOverride.""" + params = { + "screenArea": screen_area, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setScreenSettingsOverride", params) + result = self._conn.execute(cmd) + return result - params: dict[str, Any] = { - "screenOrientation": screen_orientation.to_dict() if screen_orientation is not None else None + def set_screen_orientation_override(self, screen_orientation: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setScreenOrientationOverride.""" + params = { + "screenOrientation": screen_orientation, + "contexts": contexts, + "userContexts": user_contexts, } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setScreenOrientationOverride", params) + result = self._conn.execute(cmd) + return result - if contexts is not None: - params["contexts"] = contexts - elif user_contexts is not None: - params["userContexts"] = user_contexts + def set_user_agent_override(self, user_agent: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setUserAgentOverride.""" + params = { + "userAgent": user_agent, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setUserAgentOverride", params) + result = self._conn.execute(cmd) + return result - self.conn.execute(command_builder("emulation.setScreenOrientationOverride", params)) + def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setViewportMetaOverride.""" + params = { + "viewportMeta": viewport_meta, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setViewportMetaOverride", params) + result = self._conn.execute(cmd) + return result - def set_user_agent_override( - self, - user_agent: str | None = None, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set user agent override for the given contexts or user contexts. + def set_scripting_enabled(self, enabled: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setScriptingEnabled.""" + params = { + "enabled": enabled, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setScriptingEnabled", params) + result = self._conn.execute(cmd) + return result - Args: - user_agent: User agent string to emulate, or None to clear the override. - contexts: List of browsing context IDs to apply the override to. - user_contexts: List of user context IDs to apply the override to. + def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setScrollbarTypeOverride.""" + params = { + "scrollbarType": scrollbar_type, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setScrollbarTypeOverride", params) + result = self._conn.execute(cmd) + return result - Raises: - ValueError: If both contexts and user_contexts are provided, or if neither - contexts nor user_contexts are provided. - """ - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and user_contexts") + def set_timezone_override(self, timezone: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setTimezoneOverride.""" + params = { + "timezone": timezone, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setTimezoneOverride", params) + result = self._conn.execute(cmd) + return result + + def set_touch_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setTouchOverride.""" + params = { + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setTouchOverride", params) + result = self._conn.execute(cmd) + return result - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or user_contexts") + def set_geolocation_override( + self, + coordinates=None, + error=None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setGeolocationOverride. - params: dict[str, Any] = {"userAgent": user_agent} + Sets or clears the geolocation override for specified browsing or user contexts. + Args: + coordinates: A GeolocationCoordinates instance (or dict) to override the + position, or ``None`` to clear a previously-set override. + error: A GeolocationPositionError instance (or dict) to simulate a + position-unavailable error. Mutually exclusive with *coordinates*. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + params = {} + if coordinates is not None: + if isinstance(coordinates, dict): + coords_dict = coordinates + else: + coords_dict = {} + if coordinates.latitude is not None: + coords_dict["latitude"] = coordinates.latitude + if coordinates.longitude is not None: + coords_dict["longitude"] = coordinates.longitude + if coordinates.accuracy is not None: + coords_dict["accuracy"] = coordinates.accuracy + if coordinates.altitude is not None: + coords_dict["altitude"] = coordinates.altitude + if coordinates.altitude_accuracy is not None: + coords_dict["altitudeAccuracy"] = coordinates.altitude_accuracy + if coordinates.heading is not None: + coords_dict["heading"] = coordinates.heading + if coordinates.speed is not None: + coords_dict["speed"] = coordinates.speed + params["coordinates"] = coords_dict + if error is not None: + if isinstance(error, dict): + params["error"] = error + else: + params["error"] = { + "type": error.type if error.type is not None else "positionUnavailable" + } if contexts is not None: params["contexts"] = contexts - elif user_contexts is not None: + if user_contexts is not None: params["userContexts"] = user_contexts - - self.conn.execute(command_builder("emulation.setUserAgentOverride", params)) - - def set_network_conditions( + cmd = command_builder("emulation.setGeolocationOverride", params) + result = self._conn.execute(cmd) + return result + def set_timezone_override( self, - offline: bool = False, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set network conditions for the given contexts or user contexts. + timezone=None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setTimezoneOverride. - Args: - offline: True to emulate offline network conditions, False to clear the override. - contexts: List of browsing context IDs to apply the conditions to. - user_contexts: List of user context IDs to apply the conditions to. + Sets or clears the timezone override for specified browsing or user contexts. + Pass ``timezone=None`` (or omit it) to clear a previously-set override. - Raises: - ValueError: If both contexts and user_contexts are provided, or if neither - contexts nor user_contexts are provided. + Args: + timezone: IANA timezone string (e.g. ``"America/New_York"``) or ``None`` + to clear the override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. """ - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and user_contexts") - - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or user_contexts") - - params: dict[str, Any] = {} - - if offline: - params["networkConditions"] = {"type": "offline"} - else: - # if offline is False or None, then clear the override - params["networkConditions"] = None - + params = {"timezone": timezone} if contexts is not None: params["contexts"] = contexts - elif user_contexts is not None: + if user_contexts is not None: params["userContexts"] = user_contexts + cmd = command_builder("emulation.setTimezoneOverride", params) + return self._conn.execute(cmd) + def set_scripting_enabled( + self, + enabled=None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setScriptingEnabled. - self.conn.execute(command_builder("emulation.setNetworkConditions", params)) + Enables or disables scripting for specified browsing or user contexts. + Pass ``enabled=None`` to restore the default behaviour. - def set_screen_settings_override( + Args: + enabled: ``True`` to enable scripting, ``False`` to disable it, or + ``None`` to clear the override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + params = {"enabled": enabled} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setScriptingEnabled", params) + return self._conn.execute(cmd) + def set_user_agent_override( self, - width: int | None = None, - height: int | None = None, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set screen settings override for the given contexts or user contexts. + user_agent=None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setUserAgentOverride. + + Overrides the User-Agent string for specified browsing or user contexts. + Pass ``user_agent=None`` to clear a previously-set override. Args: - width: Screen width in pixels (>= 0). None to clear the override. - height: Screen height in pixels (>= 0). None to clear the override. - contexts: List of browsing context IDs to apply the override to. - user_contexts: List of user context IDs to apply the override to. - - Raises: - ValueError: If only one of width/height is provided, or if both contexts - and user_contexts are provided, or if neither is provided. + user_agent: Custom User-Agent string, or ``None`` to clear the override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. """ - if (width is None) != (height is None): - raise ValueError("Must provide both width and height, or neither to clear the override") - - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and user_contexts") + params = {"userAgent": user_agent} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setUserAgentOverride", params) + return self._conn.execute(cmd) + def set_screen_orientation_override( + self, + screen_orientation=None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setScreenOrientationOverride. - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or user_contexts") + Sets or clears the screen orientation override for specified browsing or + user contexts. - screen_area = None - if width is not None and height is not None: - if not isinstance(width, int) or not isinstance(height, int): - raise ValueError("width and height must be integers") - if width < 0 or height < 0: - raise ValueError("width and height must be >= 0") - screen_area = {"width": width, "height": height} + Args: + screen_orientation: A :class:`ScreenOrientation` instance (or dict with + ``natural`` and ``type`` keys) to lock the orientation, or ``None`` + to clear a previously-set override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + if screen_orientation is None: + so_value = None + elif isinstance(screen_orientation, dict): + so_value = screen_orientation + else: + natural = screen_orientation.natural + orientation_type = screen_orientation.type + so_value = { + "natural": natural.lower() if isinstance(natural, str) else natural, + "type": orientation_type.lower() if isinstance(orientation_type, str) else orientation_type, + } + params = {"screenOrientation": so_value} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setScreenOrientationOverride", params) + return self._conn.execute(cmd) + def set_network_conditions( + self, + network_conditions=None, + offline: bool | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setNetworkConditions. - params: dict[str, Any] = {"screenArea": screen_area} + Sets or clears network condition emulation for specified browsing or user + contexts. + Args: + network_conditions: A dict with the raw ``networkConditions`` value + (e.g. ``{"type": "offline"}``), or ``None`` to clear the override. + Mutually exclusive with *offline*. + offline: Convenience bool — ``True`` sets offline conditions, + ``False`` clears them (sends ``null``). When provided, this takes + precedence over *network_conditions*. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + if offline is not None: + nc_value = {"type": "offline"} if offline else None + else: + nc_value = network_conditions + params = {"networkConditions": nc_value} if contexts is not None: params["contexts"] = contexts - elif user_contexts is not None: + if user_contexts is not None: params["userContexts"] = user_contexts - - self.conn.execute(command_builder("emulation.setScreenSettingsOverride", params)) + cmd = command_builder("emulation.setNetworkConditions", params) + return self._conn.execute(cmd) diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 270ececaf41a1..5dbe71dbd3886 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -1,40 +1,32 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import math -from dataclasses import dataclass, field -from typing import Any - -from selenium.webdriver.common.bidi.common import command_builder +# WebDriver BiDi module: input +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass +import threading +from collections.abc import Callable +from dataclasses import dataclass from selenium.webdriver.common.bidi.session import Session class PointerType: - """Represents the possible pointer types.""" + """PointerType.""" MOUSE = "mouse" PEN = "pen" TOUCH = "touch" - VALID_TYPES = {MOUSE, PEN, TOUCH} - class Origin: - """Represents the possible origin types.""" + """Origin.""" VIEWPORT = "viewport" POINTER = "pointer" @@ -42,421 +34,425 @@ class Origin: @dataclass class ElementOrigin: - """Represents an element origin for input actions.""" - - type: str - element: dict - - def __init__(self, element_reference: dict): - self.type = "element" - self.element = element_reference + """ElementOrigin.""" - def to_dict(self) -> dict: - """Convert the ElementOrigin to a dictionary.""" - return {"type": self.type, "element": self.element} + type: str = field(default="element", init=False) + element: Any | None = None @dataclass -class PointerParameters: - """Represents pointer parameters for pointer actions.""" - - pointer_type: str = PointerType.MOUSE - - def __post_init__(self): - if self.pointer_type not in PointerType.VALID_TYPES: - raise ValueError(f"Invalid pointer type: {self.pointer_type}. Must be one of {PointerType.VALID_TYPES}") +class PerformActionsParameters: + """PerformActionsParameters.""" - def to_dict(self) -> dict: - """Convert the PointerParameters to a dictionary.""" - return {"pointerType": self.pointer_type} + context: Any | None = None + actions: list[Any | None] | None = None @dataclass -class PointerCommonProperties: - """Common properties for pointer actions.""" - - width: int = 1 - height: int = 1 - pressure: float = 0.0 - tangential_pressure: float = 0.0 - twist: int = 0 - altitude_angle: float = 0.0 - azimuth_angle: float = 0.0 - - def __post_init__(self): - if self.width < 1: - raise ValueError("width must be at least 1") - if self.height < 1: - raise ValueError("height must be at least 1") - if not (0.0 <= self.pressure <= 1.0): - raise ValueError("pressure must be between 0.0 and 1.0") - if not (0.0 <= self.tangential_pressure <= 1.0): - raise ValueError("tangential_pressure must be between 0.0 and 1.0") - if not (0 <= self.twist <= 359): - raise ValueError("twist must be between 0 and 359") - if not (0.0 <= self.altitude_angle <= math.pi / 2): - raise ValueError("altitude_angle must be between 0.0 and π/2") - if not (0.0 <= self.azimuth_angle <= 2 * math.pi): - raise ValueError("azimuth_angle must be between 0.0 and 2π") - - def to_dict(self) -> dict: - """Convert the PointerCommonProperties to a dictionary.""" - result: dict[str, Any] = {} - if self.width != 1: - result["width"] = self.width - if self.height != 1: - result["height"] = self.height - if self.pressure != 0.0: - result["pressure"] = self.pressure - if self.tangential_pressure != 0.0: - result["tangentialPressure"] = self.tangential_pressure - if self.twist != 0: - result["twist"] = self.twist - if self.altitude_angle != 0.0: - result["altitudeAngle"] = self.altitude_angle - if self.azimuth_angle != 0.0: - result["azimuthAngle"] = self.azimuth_angle - return result - +class NoneSourceActions: + """NoneSourceActions.""" -# Action classes -@dataclass -class PauseAction: - """Represents a pause action.""" + type: str = field(default="none", init=False) + id: str | None = None + actions: list[Any | None] | None = None - duration: int | None = None - @property - def type(self) -> str: - return "pause" +@dataclass +class KeySourceActions: + """KeySourceActions.""" - def to_dict(self) -> dict: - """Convert the PauseAction to a dictionary.""" - result: dict[str, Any] = {"type": self.type} - if self.duration is not None: - result["duration"] = self.duration - return result + type: str = field(default="key", init=False) + id: str | None = None + actions: list[Any | None] | None = None @dataclass -class KeyDownAction: - """Represents a key down action.""" - - value: str = "" - - @property - def type(self) -> str: - return "keyDown" +class PointerSourceActions: + """PointerSourceActions.""" - def to_dict(self) -> dict: - """Convert the KeyDownAction to a dictionary.""" - return {"type": self.type, "value": self.value} + type: str = field(default="pointer", init=False) + id: str | None = None + parameters: Any | None = None + actions: list[Any | None] | None = None @dataclass -class KeyUpAction: - """Represents a key up action.""" - - value: str = "" - - @property - def type(self) -> str: - return "keyUp" +class PointerParameters: + """PointerParameters.""" - def to_dict(self) -> dict: - """Convert the KeyUpAction to a dictionary.""" - return {"type": self.type, "value": self.value} + pointer_type: Any | None = None @dataclass -class PointerDownAction: - """Represents a pointer down action.""" +class WheelSourceActions: + """WheelSourceActions.""" - button: int = 0 - properties: PointerCommonProperties | None = None + type: str = field(default="wheel", init=False) + id: str | None = None + actions: list[Any | None] | None = None - @property - def type(self) -> str: - return "pointerDown" - def to_dict(self) -> dict: - """Convert the PointerDownAction to a dictionary.""" - result: dict[str, Any] = {"type": self.type, "button": self.button} - if self.properties: - result.update(self.properties.to_dict()) - return result +@dataclass +class PauseAction: + """PauseAction.""" + + type: str = field(default="pause", init=False) + duration: Any | None = None @dataclass -class PointerUpAction: - """Represents a pointer up action.""" +class KeyDownAction: + """KeyDownAction.""" + + type: str = field(default="keyDown", init=False) + value: str | None = None - button: int = 0 - @property - def type(self) -> str: - return "pointerUp" +@dataclass +class KeyUpAction: + """KeyUpAction.""" - def to_dict(self) -> dict: - """Convert the PointerUpAction to a dictionary.""" - return {"type": self.type, "button": self.button} + type: str = field(default="keyUp", init=False) + value: str | None = None @dataclass -class PointerMoveAction: - """Represents a pointer move action.""" - - x: float = 0 - y: float = 0 - duration: int | None = None - origin: str | ElementOrigin | None = None - properties: PointerCommonProperties | None = None - - @property - def type(self) -> str: - return "pointerMove" - - def to_dict(self) -> dict: - """Convert the PointerMoveAction to a dictionary.""" - result: dict[str, Any] = {"type": self.type, "x": self.x, "y": self.y} - if self.duration is not None: - result["duration"] = self.duration - if self.origin is not None: - if isinstance(self.origin, ElementOrigin): - result["origin"] = self.origin.to_dict() - else: - result["origin"] = self.origin - if self.properties: - result.update(self.properties.to_dict()) - return result +class PointerUpAction: + """PointerUpAction.""" + + type: str = field(default="pointerUp", init=False) + button: Any | None = None @dataclass class WheelScrollAction: - """Represents a wheel scroll action.""" - - x: int = 0 - y: int = 0 - delta_x: int = 0 - delta_y: int = 0 - duration: int | None = None - origin: str | ElementOrigin | None = Origin.VIEWPORT - - @property - def type(self) -> str: - return "scroll" - - def to_dict(self) -> dict: - """Convert the WheelScrollAction to a dictionary.""" - result: dict[str, Any] = { - "type": self.type, - "x": self.x, - "y": self.y, - "deltaX": self.delta_x, - "deltaY": self.delta_y, - } - if self.duration is not None: - result["duration"] = self.duration - if self.origin is not None: - if isinstance(self.origin, ElementOrigin): - result["origin"] = self.origin.to_dict() - else: - result["origin"] = self.origin - return result + """WheelScrollAction.""" + type: str = field(default="scroll", init=False) + x: Any | None = None + y: Any | None = None + delta_x: Any | None = None + delta_y: Any | None = None + duration: Any | None = None + origin: Any | None = None -# Source Actions -@dataclass -class NoneSourceActions: - """Represents a sequence of none actions.""" - id: str = "" - actions: list[PauseAction] = field(default_factory=list) - - @property - def type(self) -> str: - return "none" +@dataclass +class PointerCommonProperties: + """PointerCommonProperties.""" - def to_dict(self) -> dict: - """Convert the NoneSourceActions to a dictionary.""" - return {"type": self.type, "id": self.id, "actions": [action.to_dict() for action in self.actions]} + width: Any | None = None + height: Any | None = None + pressure: Any | None = None + tangential_pressure: Any | None = None + twist: Any | None = None + altitude_angle: Any | None = None + azimuth_angle: Any | None = None @dataclass -class KeySourceActions: - """Represents a sequence of key actions.""" +class ReleaseActionsParameters: + """ReleaseActionsParameters.""" + + context: Any | None = None - id: str = "" - actions: list[PauseAction | KeyDownAction | KeyUpAction] = field(default_factory=list) - @property - def type(self) -> str: - return "key" +@dataclass +class SetFilesParameters: + """SetFilesParameters.""" - def to_dict(self) -> dict: - """Convert the KeySourceActions to a dictionary.""" - return {"type": self.type, "id": self.id, "actions": [action.to_dict() for action in self.actions]} + context: Any | None = None + element: Any | None = None + files: list[Any | None] | None = None @dataclass -class PointerSourceActions: - """Represents a sequence of pointer actions.""" - - id: str = "" - parameters: PointerParameters | None = None - actions: list[PauseAction | PointerDownAction | PointerUpAction | PointerMoveAction] = field(default_factory=list) - - def __post_init__(self): - if self.parameters is None: - self.parameters = PointerParameters() - - @property - def type(self) -> str: - return "pointer" - - def to_dict(self) -> dict: - """Convert the PointerSourceActions to a dictionary.""" - result: dict[str, Any] = { - "type": self.type, - "id": self.id, - "actions": [action.to_dict() for action in self.actions], - } - if self.parameters: - result["parameters"] = self.parameters.to_dict() - return result +class FileDialogInfo: + """FileDialogInfo - parameters for the input.fileDialogOpened event.""" + + context: Any | None = None + element: Any | None = None + multiple: bool | None = None + @classmethod + def from_json(cls, params: dict) -> "FileDialogInfo": + """Deserialize event params into FileDialogInfo.""" + return cls( + context=params.get("context"), + element=params.get("element"), + multiple=params.get("multiple"), + ) @dataclass -class WheelSourceActions: - """Represents a sequence of wheel actions.""" +class PointerMoveAction: + """PointerMoveAction.""" - id: str = "" - actions: list[PauseAction | WheelScrollAction] = field(default_factory=list) + type: str = field(default="pointerMove", init=False) + x: Any | None = None + y: Any | None = None + duration: Any | None = None + origin: Any | None = None + properties: Any | None = None - @property - def type(self) -> str: - return "wheel" +@dataclass +class PointerDownAction: + """PointerDownAction.""" - def to_dict(self) -> dict: - """Convert the WheelSourceActions to a dictionary.""" - return {"type": self.type, "id": self.id, "actions": [action.to_dict() for action in self.actions]} + type: str = field(default="pointerDown", init=False) + button: Any | None = None + properties: Any | None = None +# BiDi Event Name to Parameter Type Mapping +EVENT_NAME_MAPPING = { + "file_dialog_opened": "input.fileDialogOpened", +} @dataclass -class FileDialogInfo: - """Represents file dialog information from input.fileDialogOpened event.""" +class EventConfig: + """Configuration for a BiDi event.""" + event_key: str + bidi_event: str + event_class: type - context: str - multiple: bool - element: dict | None = None - @classmethod - def from_dict(cls, data: dict) -> "FileDialogInfo": - """Creates a FileDialogInfo instance from a dictionary. +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. Args: - data: A dictionary containing the file dialog information. + params: Raw BiDi event params with camelCase keys. Returns: - FileDialogInfo: A new instance of FileDialogInfo. + An instance of the dataclass, or the raw dict on failure. """ - return cls(context=data["context"], multiple=data["multiple"], element=data.get("element")) + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) -# Event Class -class FileDialogOpened: - """Event class for input.fileDialogOpened event.""" +class _EventManager: + """Manages event subscriptions and callbacks.""" - event_class = "input.fileDialogOpened" + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + """Subscribe to a BiDi event if not already subscribed.""" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist.""" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: + session = Session(self.conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event]["callbacks"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id - @classmethod - def from_json(cls, json): - """Create FileDialogInfo from JSON data.""" - return FileDialogInfo.from_dict(json) + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() -class Input: - """BiDi implementation of the input module.""" - def __init__(self, conn): - self.conn = conn - self.subscriptions = {} - self.callbacks = {} - def perform_actions( - self, - context: str, - actions: list[NoneSourceActions | KeySourceActions | PointerSourceActions | WheelSourceActions], - ) -> None: - """Performs a sequence of user input actions. +class Input: + """WebDriver BiDi input module.""" + + EVENT_CONFIGS = {} + def __init__(self, conn) -> None: + self._conn = conn + self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) + + def perform_actions(self, context: Any | None = None, actions: List[Any] | None = None): + """Execute input.performActions.""" + params = { + "context": context, + "actions": actions, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("input.performActions", params) + result = self._conn.execute(cmd) + return result - Args: - context: The browsing context ID where actions should be performed. - actions: A list of source actions to perform. - """ - params = {"context": context, "actions": [action.to_dict() for action in actions]} - self.conn.execute(command_builder("input.performActions", params)) + def release_actions(self, context: Any | None = None): + """Execute input.releaseActions.""" + params = { + "context": context, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("input.releaseActions", params) + result = self._conn.execute(cmd) + return result + + def set_files(self, context: Any | None = None, element: Any | None = None, files: List[Any] | None = None): + """Execute input.setFiles.""" + params = { + "context": context, + "element": element, + "files": files, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("input.setFiles", params) + result = self._conn.execute(cmd) + return result - def release_actions(self, context: str) -> None: - """Releases all input state for the given context. + def add_file_dialog_handler(self, callback) -> int: + """Subscribe to the input.fileDialogOpened event. Args: - context: The browsing context ID to release actions for. + callback: Callable invoked with a FileDialogInfo when a file dialog opens. + + Returns: + A handler ID that can be passed to remove_file_dialog_handler. """ - params = {"context": context} - self.conn.execute(command_builder("input.releaseActions", params)) + return self._event_manager.add_event_handler("file_dialog_opened", callback) - def set_files(self, context: str, element: dict, files: list[str]) -> None: - """Sets files for a file input element. + def remove_file_dialog_handler(self, handler_id: int) -> None: + """Unsubscribe a previously registered file dialog event handler. Args: - context: The browsing context ID. - element: The element reference (script.SharedReference). - files: A list of file paths to set. + handler_id: The handler ID returned by add_file_dialog_handler. """ - params = {"context": context, "element": element, "files": files} - self.conn.execute(command_builder("input.setFiles", params)) + return self._event_manager.remove_event_handler("file_dialog_opened", handler_id) - def add_file_dialog_handler(self, handler) -> int: - """Add a handler for file dialog opened events. + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + """Add an event handler. Args: - handler: Callback function that takes a FileDialogInfo object. + event: The event to subscribe to. + callback: The callback function to execute on event. + contexts: The context IDs to subscribe to (optional). Returns: - int: Callback ID for removing the handler later. + The callback ID. """ - # Subscribe to the event if not already subscribed - if FileDialogOpened.event_class not in self.subscriptions: - session = Session(self.conn) - self.conn.execute(session.subscribe(FileDialogOpened.event_class)) - self.subscriptions[FileDialogOpened.event_class] = [] - - # Add callback - the callback receives the parsed FileDialogInfo directly - callback_id = self.conn.add_callback(FileDialogOpened, handler) + return self._event_manager.add_event_handler(event, callback, contexts) - self.subscriptions[FileDialogOpened.event_class].append(callback_id) - self.callbacks[callback_id] = handler - - return callback_id - - def remove_file_dialog_handler(self, callback_id: int) -> None: - """Remove a file dialog handler. + def remove_event_handler(self, event: str, callback_id: int) -> None: + """Remove an event handler. Args: - callback_id: The callback ID returned by add_file_dialog_handler. + event: The event to unsubscribe from. + callback_id: The callback ID. """ - if callback_id in self.callbacks: - del self.callbacks[callback_id] + return self._event_manager.remove_event_handler(event, callback_id) - if FileDialogOpened.event_class in self.subscriptions: - if callback_id in self.subscriptions[FileDialogOpened.event_class]: - self.subscriptions[FileDialogOpened.event_class].remove(callback_id) + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + return self._event_manager.clear_event_handlers() + +# Event Info Type Aliases +# Event: input.fileDialogOpened +FileDialogOpened = globals().get('FileDialogInfo', dict) # Fallback to dict if type not defined - # If no more callbacks for this event, unsubscribe - if not self.subscriptions[FileDialogOpened.event_class]: - session = Session(self.conn) - self.conn.execute(session.unsubscribe(FileDialogOpened.event_class)) - del self.subscriptions[FileDialogOpened.event_class] - self.conn.remove_callback(FileDialogOpened, callback_id) +# Populate EVENT_CONFIGS with event configuration mappings +_globals = globals() +Input.EVENT_CONFIGS = { + "file_dialog_opened": (EventConfig("file_dialog_opened", "input.fileDialogOpened", _globals.get("FileDialogOpened", dict)) if _globals.get("FileDialogOpened") else EventConfig("file_dialog_opened", "input.fileDialogOpened", dict)), +} diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 575545776bda8..faf6c85ae2b6c 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -1,81 +1,109 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# WebDriver BiDi module: log from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator from dataclasses import dataclass -from typing import Any -class LogEntryAdded: - event_class = "log.entryAdded" +class Level: + """Level.""" + + DEBUG = "debug" + INFO = "info" + WARN = "warn" + ERROR = "error" + + +LogLevel = Level + +@dataclass +class BaseLogEntry: + """BaseLogEntry.""" + + level: Any | None = None + source: Any | None = None + text: Any | None = None + timestamp: Any | None = None + stack_trace: Any | None = None - @classmethod - def from_json(cls, json: dict[str, Any]) -> ConsoleLogEntry | JavaScriptLogEntry | None: - if json["type"] == "console": - return ConsoleLogEntry.from_json(json) - elif json["type"] == "javascript": - return JavaScriptLogEntry.from_json(json) - return None + +@dataclass +class GenericLogEntry: + """GenericLogEntry.""" + + type: str | None = None @dataclass class ConsoleLogEntry: - level: str - text: str - timestamp: str - method: str - args: list[dict[str, Any]] - type_: str + """ConsoleLogEntry - a console log entry from the browser.""" + + type_: str | None = None + method: str | None = None + args: list | None = None + level: Any | None = None + text: Any | None = None + source: Any | None = None + timestamp: Any | None = None + stack_trace: Any | None = None @classmethod - def from_json(cls, json: dict[str, Any]) -> ConsoleLogEntry: + def from_json(cls, params: dict) -> "ConsoleLogEntry": + """Deserialize from BiDi params dict.""" return cls( - level=json["level"], - text=json["text"], - timestamp=json["timestamp"], - method=json["method"], - args=json["args"], - type_=json["type"], + type_=params.get("type"), + method=params.get("method"), + args=params.get("args"), + level=params.get("level"), + text=params.get("text"), + source=params.get("source"), + timestamp=params.get("timestamp"), + stack_trace=params.get("stackTrace"), ) - @dataclass -class JavaScriptLogEntry: - level: str - text: str - timestamp: str - stacktrace: dict[str, Any] - type_: str +class JavascriptLogEntry: + """JavascriptLogEntry - a JavaScript error log entry from the browser.""" + + type_: str | None = None + level: Any | None = None + text: Any | None = None + source: Any | None = None + timestamp: Any | None = None + stacktrace: Any | None = None @classmethod - def from_json(cls, json: dict[str, Any]) -> JavaScriptLogEntry: + def from_json(cls, params: dict) -> "JavascriptLogEntry": + """Deserialize from BiDi params dict.""" return cls( - level=json["level"], - text=json["text"], - timestamp=json["timestamp"], - stacktrace=json["stackTrace"], - type_=json["type"], + type_=params.get("type"), + level=params.get("level"), + text=params.get("text"), + source=params.get("source"), + timestamp=params.get("timestamp"), + stacktrace=params.get("stackTrace"), ) +class Log: + """WebDriver BiDi log module.""" -class LogLevel: - """Represents log level.""" + def __init__(self, conn) -> None: + self._conn = conn + + def entry_added(self): + """Execute log.entryAdded.""" + params = { + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("log.entryAdded", params) + result = self._conn.execute(cmd) + return result - DEBUG = "debug" - INFO = "info" - WARN = "warn" - ERROR = "error" diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 82472838dccde..4f44e309bffbb 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -1,338 +1,923 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - +# WebDriver BiDi module: network from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass +import threading from collections.abc import Callable -from typing import Any +from dataclasses import dataclass +from selenium.webdriver.common.bidi.session import Session -from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.remote.websocket_connection import WebSocketConnection +class SameSite: + """SameSite.""" -class NetworkEvent: - """Represents a network event.""" + STRICT = "strict" + LAX = "lax" + NONE = "none" + DEFAULT = "default" - def __init__(self, event_class: str, **kwargs: Any) -> None: - self.event_class = event_class - self.params = kwargs - @classmethod - def from_json(cls, json: dict[str, Any]) -> NetworkEvent: - return cls(event_class=json.get("event_class", ""), **json) +class DataType: + """DataType.""" + REQUEST = "request" + RESPONSE = "response" -class Network: - EVENTS = { - "before_request": "network.beforeRequestSent", - "response_started": "network.responseStarted", - "response_completed": "network.responseCompleted", - "auth_required": "network.authRequired", - "fetch_error": "network.fetchError", - "continue_request": "network.continueRequest", - "continue_auth": "network.continueWithAuth", - } - - PHASES = { - "before_request": "beforeRequestSent", - "response_started": "responseStarted", - "auth_required": "authRequired", - } - - def __init__(self, conn: WebSocketConnection) -> None: - self.conn = conn - self.intercepts: list[str] = [] - self.callbacks: dict[str | int, Any] = {} - self.subscriptions: dict[str, list[int]] = {} - - def _add_intercept( - self, - phases: list[str] | None = None, - contexts: list[str] | None = None, - url_patterns: list[Any] | None = None, - ) -> dict[str, Any]: - """Add an intercept to the network. + +class InterceptPhase: + """InterceptPhase.""" + + BEFOREREQUESTSENT = "beforeRequestSent" + RESPONSESTARTED = "responseStarted" + AUTHREQUIRED = "authRequired" + + +class ContinueWithAuthNoCredentials: + """ContinueWithAuthNoCredentials.""" + + DEFAULT = "default" + CANCEL = "cancel" + + +@dataclass +class AuthChallenge: + """AuthChallenge.""" + + scheme: str | None = None + realm: str | None = None + + +@dataclass +class AuthCredentials: + """AuthCredentials.""" + + type: str = field(default="password", init=False) + username: str | None = None + password: str | None = None + + +@dataclass +class BaseParameters: + """BaseParameters.""" + + context: Any | None = None + is_blocked: bool | None = None + navigation: Any | None = None + redirect_count: Any | None = None + request: Any | None = None + timestamp: Any | None = None + intercepts: list[Any | None] | None = None + + +@dataclass +class StringValue: + """StringValue.""" + + type: str = field(default="string", init=False) + value: str | None = None + + +@dataclass +class Base64Value: + """Base64Value.""" + + type: str = field(default="base64", init=False) + value: str | None = None + + +@dataclass +class Cookie: + """Cookie.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + path: str | None = None + size: Any | None = None + http_only: bool | None = None + secure: bool | None = None + same_site: Any | None = None + expiry: Any | None = None + + +@dataclass +class CookieHeader: + """CookieHeader.""" + + name: str | None = None + value: Any | None = None + + +@dataclass +class FetchTimingInfo: + """FetchTimingInfo.""" + + time_origin: Any | None = None + request_time: Any | None = None + redirect_start: Any | None = None + redirect_end: Any | None = None + fetch_start: Any | None = None + dns_start: Any | None = None + dns_end: Any | None = None + connect_start: Any | None = None + connect_end: Any | None = None + tls_start: Any | None = None + request_start: Any | None = None + response_start: Any | None = None + response_end: Any | None = None + + +@dataclass +class Header: + """Header.""" + + name: str | None = None + value: Any | None = None + + +@dataclass +class Initiator: + """Initiator.""" + + column_number: Any | None = None + line_number: Any | None = None + request: Any | None = None + stack_trace: Any | None = None + type: Any | None = None + + +@dataclass +class ResponseContent: + """ResponseContent.""" + + size: Any | None = None + + +@dataclass +class ResponseData: + """ResponseData.""" + + url: str | None = None + protocol: str | None = None + status: Any | None = None + status_text: str | None = None + from_cache: bool | None = None + headers: list[Any | None] | None = None + mime_type: str | None = None + bytes_received: Any | None = None + headers_size: Any | None = None + body_size: Any | None = None + content: Any | None = None + auth_challenges: list[Any | None] | None = None + + +@dataclass +class SetCookieHeader: + """SetCookieHeader.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + http_only: bool | None = None + expiry: str | None = None + max_age: Any | None = None + path: str | None = None + same_site: Any | None = None + secure: bool | None = None + + +@dataclass +class UrlPatternPattern: + """UrlPatternPattern.""" + + type: str = field(default="pattern", init=False) + protocol: str | None = None + hostname: str | None = None + port: str | None = None + pathname: str | None = None + search: str | None = None + + +@dataclass +class UrlPatternString: + """UrlPatternString.""" + + type: str = field(default="string", init=False) + pattern: str | None = None + + +@dataclass +class AddDataCollectorParameters: + """AddDataCollectorParameters.""" + + data_types: list[Any | None] | None = None + max_encoded_data_size: Any | None = None + collector_type: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None + + +@dataclass +class AddDataCollectorResult: + """AddDataCollectorResult.""" + + collector: Any | None = None + + +@dataclass +class AddInterceptParameters: + """AddInterceptParameters.""" + + phases: list[Any | None] | None = None + contexts: list[Any | None] | None = None + url_patterns: list[Any | None] | None = None + + +@dataclass +class AddInterceptResult: + """AddInterceptResult.""" + + intercept: Any | None = None + + +@dataclass +class ContinueResponseParameters: + """ContinueResponseParameters.""" + + request: Any | None = None + cookies: list[Any | None] | None = None + credentials: Any | None = None + headers: list[Any | None] | None = None + reason_phrase: str | None = None + status_code: Any | None = None + + +@dataclass +class ContinueWithAuthParameters: + """ContinueWithAuthParameters.""" + + request: Any | None = None + + +@dataclass +class ContinueWithAuthCredentials: + """ContinueWithAuthCredentials.""" + + action: str = field(default="provideCredentials", init=False) + credentials: Any | None = None + + +@dataclass +class disownDataParameters: + """disownDataParameters.""" + + data_type: Any | None = None + collector: Any | None = None + request: Any | None = None + + +@dataclass +class FailRequestParameters: + """FailRequestParameters.""" + + request: Any | None = None + + +@dataclass +class GetDataParameters: + """GetDataParameters.""" + + data_type: Any | None = None + collector: Any | None = None + disown: bool | None = None + request: Any | None = None + + +@dataclass +class GetDataResult: + """GetDataResult.""" + + bytes: Any | None = None + + +@dataclass +class ProvideResponseParameters: + """ProvideResponseParameters.""" + + request: Any | None = None + body: Any | None = None + cookies: list[Any | None] | None = None + headers: list[Any | None] | None = None + reason_phrase: str | None = None + status_code: Any | None = None + + +@dataclass +class RemoveDataCollectorParameters: + """RemoveDataCollectorParameters.""" + + collector: Any | None = None + + +@dataclass +class RemoveInterceptParameters: + """RemoveInterceptParameters.""" + + intercept: Any | None = None + + +@dataclass +class SetCacheBehaviorParameters: + """SetCacheBehaviorParameters.""" + + cache_behavior: Any | None = None + contexts: list[Any | None] | None = None + + +@dataclass +class SetExtraHeadersParameters: + """SetExtraHeadersParameters.""" + + headers: list[Any | None] | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None + + +@dataclass +class ResponseStartedParameters: + """ResponseStartedParameters.""" + + response: Any | None = None + + +class BytesValue: + """A string or base64-encoded bytes value used in cookie operations. + + This corresponds to network.BytesValue in the WebDriver BiDi specification, + wrapping either a plain string or a base64-encoded binary value. + """ + + TYPE_STRING = "string" + TYPE_BASE64 = "base64" + + def __init__(self, type: str, value: str) -> None: + self.type = type + self.value = value + + def to_bidi_dict(self) -> dict: + return {"type": self.type, "value": self.value} + +class Request: + """Wraps a BiDi network request event params and provides request action methods.""" + + def __init__(self, conn, params): + self._conn = conn + self._params = params if isinstance(params, dict) else {} + req = self._params.get("request", {}) or {} + self.url = req.get("url", "") + self._request_id = req.get("request") + + def continue_request(self, **kwargs): + """Continue the intercepted request.""" + from selenium.webdriver.common.bidi.common import command_builder as _cb + + params = {"request": self._request_id} + params.update(kwargs) + self._conn.execute(_cb("network.continueRequest", params)) + +# BiDi Event Name to Parameter Type Mapping +EVENT_NAME_MAPPING = { + "auth_required": "network.authRequired", + "before_request": "network.beforeRequestSent", +} + +@dataclass +class EventConfig: + """Configuration for a BiDi event.""" + event_key: str + bidi_event: str + event_class: type + + +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. Args: - phases: A list of phases to intercept. Default is None (empty list). - contexts: A list of contexts to intercept. Default is None. - url_patterns: A list of URL patterns to intercept. Default is None. + params: Raw BiDi event params with camelCase keys. Returns: - str: intercept id + An instance of the dataclass, or the raw dict on failure. """ - if phases is None: - phases = [] - params = {} - if contexts is not None: - params["contexts"] = contexts - if url_patterns is not None: - params["urlPatterns"] = url_patterns - if len(phases) > 0: - params["phases"] = phases - else: - params["phases"] = ["beforeRequestSent"] + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) + + +class _EventManager: + """Manages event subscriptions and callbacks.""" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + """Subscribe to a BiDi event if not already subscribed.""" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist.""" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: + session = Session(self.conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event]["callbacks"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() + + + + +class Network: + """WebDriver BiDi network module.""" + + EVENT_CONFIGS = {} + def __init__(self, conn) -> None: + self._conn = conn + self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) + self.intercepts = [] + + def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_data_size: Any | None = None, collector_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute network.addDataCollector.""" + params = { + "dataTypes": data_types, + "maxEncodedDataSize": max_encoded_data_size, + "collectorType": collector_type, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.addDataCollector", params) + result = self._conn.execute(cmd) + return result + + def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | None = None, url_patterns: List[Any] | None = None): + """Execute network.addIntercept.""" + params = { + "phases": phases, + "contexts": contexts, + "urlPatterns": url_patterns, + } + params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("network.addIntercept", params) + result = self._conn.execute(cmd) + return result - result: dict[str, Any] = self.conn.execute(cmd) - self.intercepts.append(result["intercept"]) + def continue_request(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, method: Any | None = None, url: Any | None = None): + """Execute network.continueRequest.""" + params = { + "request": request, + "body": body, + "cookies": cookies, + "headers": headers, + "method": method, + "url": url, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.continueRequest", params) + result = self._conn.execute(cmd) return result - def _remove_intercept(self, intercept: str | None = None) -> None: - """Remove a specific intercept, or all intercepts. + def continue_response(self, request: Any | None = None, cookies: List[Any] | None = None, credentials: Any | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): + """Execute network.continueResponse.""" + params = { + "request": request, + "cookies": cookies, + "credentials": credentials, + "headers": headers, + "reasonPhrase": reason_phrase, + "statusCode": status_code, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.continueResponse", params) + result = self._conn.execute(cmd) + return result - Args: - intercept: The intercept to remove. Default is None. + def continue_with_auth(self, request: Any | None = None): + """Execute network.continueWithAuth.""" + params = { + "request": request, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.continueWithAuth", params) + result = self._conn.execute(cmd) + return result + + def disown_data(self, data_type: Any | None = None, collector: Any | None = None, request: Any | None = None): + """Execute network.disownData.""" + params = { + "dataType": data_type, + "collector": collector, + "request": request, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.disownData", params) + result = self._conn.execute(cmd) + return result - Raises: - ValueError: If intercept is not found. + def fail_request(self, request: Any | None = None): + """Execute network.failRequest.""" + params = { + "request": request, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.failRequest", params) + result = self._conn.execute(cmd) + return result - Note: - If intercept is None, all intercepts will be removed. - """ - if intercept is None: - intercepts_to_remove = self.intercepts.copy() # create a copy before iterating - for intercept_id in intercepts_to_remove: # remove all intercepts - self.conn.execute(command_builder("network.removeIntercept", {"intercept": intercept_id})) - self.intercepts.remove(intercept_id) - else: - try: - self.conn.execute(command_builder("network.removeIntercept", {"intercept": intercept})) - self.intercepts.remove(intercept) - except Exception as e: - raise Exception(f"Exception: {e}") - - def _on_request(self, event_name: str, callback: Callable[[Request], Any]) -> int: - """Set a callback function to subscribe to a network event. + def get_data(self, data_type: Any | None = None, collector: Any | None = None, disown: bool | None = None, request: Any | None = None): + """Execute network.getData.""" + params = { + "dataType": data_type, + "collector": collector, + "disown": disown, + "request": request, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.getData", params) + result = self._conn.execute(cmd) + return result + + def provide_response(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): + """Execute network.provideResponse.""" + params = { + "request": request, + "body": body, + "cookies": cookies, + "headers": headers, + "reasonPhrase": reason_phrase, + "statusCode": status_code, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.provideResponse", params) + result = self._conn.execute(cmd) + return result + + def remove_data_collector(self, collector: Any | None = None): + """Execute network.removeDataCollector.""" + params = { + "collector": collector, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.removeDataCollector", params) + result = self._conn.execute(cmd) + return result + + def remove_intercept(self, intercept: Any | None = None): + """Execute network.removeIntercept.""" + params = { + "intercept": intercept, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.removeIntercept", params) + result = self._conn.execute(cmd) + return result + + def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[Any] | None = None): + """Execute network.setCacheBehavior.""" + params = { + "cacheBehavior": cache_behavior, + "contexts": contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.setCacheBehavior", params) + result = self._conn.execute(cmd) + return result + + def set_extra_headers(self, headers: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute network.setExtraHeaders.""" + params = { + "headers": headers, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.setExtraHeaders", params) + result = self._conn.execute(cmd) + return result + + def before_request_sent(self, initiator: Any | None = None, method: Any | None = None, params: Any | None = None): + """Execute network.beforeRequestSent.""" + params = { + "initiator": initiator, + "method": method, + "params": params, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.beforeRequestSent", params) + result = self._conn.execute(cmd) + return result + + def fetch_error(self, error_text: Any | None = None, method: Any | None = None, params: Any | None = None): + """Execute network.fetchError.""" + params = { + "errorText": error_text, + "method": method, + "params": params, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.fetchError", params) + result = self._conn.execute(cmd) + return result + + def response_completed(self, response: Any | None = None, method: Any | None = None, params: Any | None = None): + """Execute network.responseCompleted.""" + params = { + "response": response, + "method": method, + "params": params, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.responseCompleted", params) + result = self._conn.execute(cmd) + return result + + def response_started(self, response: Any | None = None): + """Execute network.responseStarted.""" + params = { + "response": response, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.responseStarted", params) + result = self._conn.execute(cmd) + return result + + def _add_intercept(self, phases=None, url_patterns=None): + """Add a low-level network intercept. Args: - event_name: The event to subscribe to. - callback: The callback function to execute on event. - Takes Request object as argument. + phases: list of intercept phases (default: ["beforeRequestSent"]) + url_patterns: optional URL patterns to filter Returns: - int: callback id + dict with "intercept" key containing the intercept ID """ - event = NetworkEvent(event_name) - - def _callback(event_data: NetworkEvent) -> None: - request = Request( - network=self, - request_id=event_data.params["request"].get("request", None), - body_size=event_data.params["request"].get("bodySize", None), - cookies=event_data.params["request"].get("cookies", None), - resource_type=event_data.params["request"].get("goog:resourceType", None), - headers=event_data.params["request"].get("headers", None), - headers_size=event_data.params["request"].get("headersSize", None), - timings=event_data.params["request"].get("timings", None), - url=event_data.params["request"].get("url", None), - ) - callback(request) + from selenium.webdriver.common.bidi.common import command_builder as _cb - callback_id: int = self.conn.add_callback(event, _callback) - - if event_name in self.callbacks: - self.callbacks[event_name].append(callback_id) - else: - self.callbacks[event_name] = [callback_id] - - return callback_id + if phases is None: + phases = ["beforeRequestSent"] + params = {"phases": phases} + if url_patterns: + params["urlPatterns"] = url_patterns + result = self._conn.execute(_cb("network.addIntercept", params)) + if result: + intercept_id = result.get("intercept") + if intercept_id and intercept_id not in self.intercepts: + self.intercepts.append(intercept_id) + return result + def _remove_intercept(self, intercept_id): + """Remove a low-level network intercept.""" + from selenium.webdriver.common.bidi.common import command_builder as _cb - def add_request_handler( - self, - event: str, - callback: Callable[[Request], Any], - url_patterns: list[Any] | None = None, - contexts: list[str] | None = None, - ) -> int: - """Add a request handler to the network. + self._conn.execute(_cb("network.removeIntercept", {"intercept": intercept_id})) + if intercept_id in self.intercepts: + self.intercepts.remove(intercept_id) + def add_request_handler(self, event, callback, url_patterns=None): + """Add a handler for network requests at the specified phase. Args: - event: The event to subscribe to. - callback: The callback function to execute on request interception. - Takes Request object as argument. - url_patterns: A list of URL patterns to intercept. Default is None. - contexts: A list of contexts to intercept. Default is None. + event: Event name, e.g. ``"before_request"``. + callback: Callable receiving a :class:`Request` instance. + url_patterns: optional list of URL pattern dicts to filter. Returns: - int: callback id + callback_id int for later removal via remove_request_handler. """ - try: - event_name = self.EVENTS[event] - phase_name = self.PHASES[event] - except KeyError: - raise Exception(f"Event {event} not found") - - result = self._add_intercept(phases=[phase_name], url_patterns=url_patterns, contexts=contexts) - callback_id = self._on_request(event_name, callback) - - if event_name in self.subscriptions: - self.subscriptions[event_name].append(callback_id) - else: - params: dict[str, Any] = {} - params["events"] = [event_name] - self.conn.execute(command_builder("session.subscribe", params)) - self.subscriptions[event_name] = [callback_id] - - self.callbacks[callback_id] = result["intercept"] - return callback_id + phase_map = { + "before_request": "beforeRequestSent", + "before_request_sent": "beforeRequestSent", + "response_started": "responseStarted", + "auth_required": "authRequired", + } + phase = phase_map.get(event, "beforeRequestSent") + self._add_intercept(phases=[phase], url_patterns=url_patterns) + + def _request_callback(params): + raw = ( + params + if isinstance(params, dict) + else (params.__dict__ if hasattr(params, "__dict__") else {}) + ) + request = Request(self._conn, raw) + callback(request) - def remove_request_handler(self, event: str, callback_id: int) -> None: - """Remove a request handler from the network. + return self.add_event_handler(event, _request_callback) + def remove_request_handler(self, event, callback_id): + """Remove a network request handler. Args: - event: The event to unsubscribe from. - callback_id: The callback id to remove. + event: The event name used when adding the handler. + callback_id: The int returned by add_request_handler. """ - try: - event_name = self.EVENTS[event] - except KeyError: - raise Exception(f"Event {event} not found") - - net_event = NetworkEvent(event_name) - - self.conn.remove_callback(net_event, callback_id) - self._remove_intercept(self.callbacks[callback_id]) - del self.callbacks[callback_id] - self.subscriptions[event_name].remove(callback_id) - if len(self.subscriptions[event_name]) == 0: - params: dict[str, Any] = {} - params["events"] = [event_name] - self.conn.execute(command_builder("session.unsubscribe", params)) - del self.subscriptions[event_name] - - def clear_request_handlers(self) -> None: - """Clear all request handlers from the network.""" - for event_name in self.subscriptions: - net_event = NetworkEvent(event_name) - for callback_id in self.subscriptions[event_name]: - self.conn.remove_callback(net_event, callback_id) - self._remove_intercept(self.callbacks[callback_id]) - del self.callbacks[callback_id] - params: dict[str, Any] = {} - params["events"] = [event_name] - self.conn.execute(command_builder("session.unsubscribe", params)) - self.subscriptions = {} - - def add_auth_handler(self, username: str, password: str) -> int: - """Add an authentication handler to the network. + self.remove_event_handler(event, callback_id) + def clear_request_handlers(self): + """Clear all request handlers and remove all tracked intercepts.""" + self.clear_event_handlers() + for intercept_id in list(self.intercepts): + self._remove_intercept(intercept_id) + def add_auth_handler(self, username, password): + """Add an auth handler that automatically provides credentials. Args: - username: The username to authenticate with. - password: The password to authenticate with. + username: The username for basic authentication. + password: The password for basic authentication. Returns: - int: callback id + callback_id int for later removal via remove_auth_handler. """ - event = "auth_required" - - def _callback(request: Request) -> None: - request._continue_with_auth(username, password) + from selenium.webdriver.common.bidi.common import command_builder as _cb - return self.add_request_handler(event, _callback) - - def remove_auth_handler(self, callback_id: int) -> None: - """Remove an authentication handler from the network. + def _auth_callback(params): + raw = ( + params + if isinstance(params, dict) + else (params.__dict__ if hasattr(params, "__dict__") else {}) + ) + request_id = ( + raw.get("request", {}).get("request") + if isinstance(raw, dict) + else None + ) + if request_id: + self._conn.execute( + _cb( + "network.continueWithAuth", + { + "request": request_id, + "action": "provideCredentials", + "credentials": { + "type": "password", + "username": username, + "password": password, + }, + }, + ) + ) + + return self.add_event_handler("auth_required", _auth_callback) + def remove_auth_handler(self, callback_id): + """Remove an auth handler by callback ID.""" + self.remove_event_handler("auth_required", callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + """Add an event handler. Args: - callback_id: The callback id to remove. + event: The event to subscribe to. + callback: The callback function to execute on event. + contexts: The context IDs to subscribe to (optional). + + Returns: + The callback ID. """ - event = "auth_required" - self.remove_request_handler(event, callback_id) + return self._event_manager.add_event_handler(event, callback, contexts) - -class Request: - """Represents an intercepted network request.""" - - def __init__( - self, - network: Network, - request_id: Any, - body_size: int | None = None, - cookies: Any = None, - resource_type: str | None = None, - headers: Any = None, - headers_size: int | None = None, - method: str | None = None, - timings: Any = None, - url: str | None = None, - ) -> None: - self.network = network - self.request_id = request_id - self.body_size = body_size - self.cookies = cookies - self.resource_type = resource_type - self.headers = headers - self.headers_size = headers_size - self.method = method - self.timings = timings - self.url = url - - def fail_request(self) -> None: - """Fail this request.""" - if not self.request_id: - raise ValueError("Request not found.") - - params: dict[str, Any] = {"request": self.request_id} - self.network.conn.execute(command_builder("network.failRequest", params)) - - def continue_request( - self, - body: Any = None, - method: str | None = None, - headers: Any = None, - cookies: Any = None, - url: str | None = None, - ) -> None: - """Continue after intercepting this request.""" - if not self.request_id: - raise ValueError("Request not found.") - - params: dict[str, Any] = {"request": self.request_id} - if body is not None: - params["body"] = body - if method is not None: - params["method"] = method - if headers is not None: - params["headers"] = headers - if cookies is not None: - params["cookies"] = cookies - if url is not None: - params["url"] = url - - self.network.conn.execute(command_builder("network.continueRequest", params)) - - def _continue_with_auth(self, username: str | None = None, password: str | None = None) -> None: - """Continue with authentication. + def remove_event_handler(self, event: str, callback_id: int) -> None: + """Remove an event handler. Args: - username: The username to authenticate with. - password: The password to authenticate with. - - Note: - If username or password is None, it attempts auth with no credentials. + event: The event to unsubscribe from. + callback_id: The callback ID. """ - params: dict[str, Any] = {} - params["request"] = self.request_id + return self._event_manager.remove_event_handler(event, callback_id) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + return self._event_manager.clear_event_handlers() + +# Event Info Type Aliases +# Event: network.authRequired +AuthRequired = globals().get('AuthRequiredParameters', dict) # Fallback to dict if type not defined - if not username or not password: # no credentials is valid option - params["action"] = "default" - else: - params["action"] = "provideCredentials" - params["credentials"] = {"type": "password", "username": username, "password": password} - self.network.conn.execute(command_builder("network.continueWithAuth", params)) +# Populate EVENT_CONFIGS with event configuration mappings +_globals = globals() +Network.EVENT_CONFIGS = { + "auth_required": (EventConfig("auth_required", "network.authRequired", _globals.get("AuthRequired", dict)) if _globals.get("AuthRequired") else EventConfig("auth_required", "network.authRequired", dict)), + "before_request": EventConfig("before_request", "network.beforeRequestSent", _globals.get("dict", dict)), +} diff --git a/py/selenium/webdriver/common/bidi/permissions.py b/py/selenium/webdriver/common/bidi/permissions.py index 17faa1ff5454f..f00e765c62e3b 100644 --- a/py/selenium/webdriver/common/bidi/permissions.py +++ b/py/selenium/webdriver/common/bidi/permissions.py @@ -15,12 +15,20 @@ # specific language governing permissions and limitations # under the License. +"""WebDriver BiDi Permissions module.""" -from selenium.webdriver.common.bidi.common import command_builder +from __future__ import annotations +from enum import Enum +from typing import Any, Optional, Union -class PermissionState: - """Represents the possible permission states.""" +from .common import command_builder + +_VALID_PERMISSION_STATES = {"granted", "denied", "prompt"} + + +class PermissionState(str, Enum): + """Permission state enumeration.""" GRANTED = "granted" DENIED = "denied" @@ -28,56 +36,69 @@ class PermissionState: class PermissionDescriptor: - """Represents a permission descriptor.""" + """Descriptor for a permission.""" - def __init__(self, name: str): + def __init__(self, name: str) -> None: + """Initialize a PermissionDescriptor. + + Args: + name: The name of the permission (e.g., 'geolocation', 'microphone', 'camera') + """ self.name = name - def to_dict(self) -> dict: - return {"name": self.name} + def __repr__(self) -> str: + return f"PermissionDescriptor('{self.name}')" class Permissions: - """BiDi implementation of the permissions module.""" + """WebDriver BiDi Permissions module.""" - def __init__(self, conn): - self.conn = conn + def __init__(self, websocket_connection: Any) -> None: + """Initialize the Permissions module. + + Args: + websocket_connection: The WebSocket connection for sending BiDi commands + """ + self._conn = websocket_connection def set_permission( self, - descriptor: str | PermissionDescriptor, - state: str, - origin: str, - user_context: str | None = None, + descriptor: Union[PermissionDescriptor, str], + state: Union[PermissionState, str], + origin: Optional[str] = None, + user_context: Optional[str] = None, ) -> None: - """Sets a permission state for a given permission descriptor. + """Set a permission for a given origin. Args: - descriptor: The permission name (str) or PermissionDescriptor object. - Examples: "geolocation", "camera", "microphone". - state: The permission state (granted, denied, prompt). - origin: The origin for which the permission is set. - user_context: The user context id (optional). + descriptor: The permission descriptor or permission name as a string + state: The desired permission state + origin: The origin for which to set the permission + user_context: Optional user context ID to scope the permission Raises: - ValueError: If the permission state is invalid. + ValueError: If the state is not a valid permission state """ - if state not in [PermissionState.GRANTED, PermissionState.DENIED, PermissionState.PROMPT]: - valid_states = f"{PermissionState.GRANTED}, {PermissionState.DENIED}, {PermissionState.PROMPT}" - raise ValueError(f"Invalid permission state. Must be one of: {valid_states}") + state_value = state.value if isinstance(state, PermissionState) else state + if state_value not in _VALID_PERMISSION_STATES: + raise ValueError( + f"Invalid permission state: {state_value!r}. " + f"Must be one of {sorted(_VALID_PERMISSION_STATES)}" + ) if isinstance(descriptor, str): - permission_descriptor = PermissionDescriptor(descriptor) + descriptor_dict = {"name": descriptor} else: - permission_descriptor = descriptor + descriptor_dict = {"name": descriptor.name} - params = { - "descriptor": permission_descriptor.to_dict(), - "state": state, - "origin": origin, + params: dict[str, Any] = { + "descriptor": descriptor_dict, + "state": state_value, } - + if origin is not None: + params["origin"] = origin if user_context is not None: params["userContext"] = user_context - self.conn.execute(command_builder("permissions.setPermission", params)) + cmd = command_builder("permissions.setPermission", params) + self._conn.execute(cmd) diff --git a/py/selenium/webdriver/common/bidi/py.typed b/py/selenium/webdriver/common/bidi/py.typed new file mode 100755 index 0000000000000..e69de29bb2d1d diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index e37b3269a4ade..e13c11f71a5cb 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -1,40 +1,33 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import datetime -import math -from dataclasses import dataclass -from typing import Any +# WebDriver BiDi module: script +from __future__ import annotations -from selenium.common.exceptions import WebDriverException -from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi.log import LogEntryAdded +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass +import threading +from collections.abc import Callable +from dataclasses import dataclass from selenium.webdriver.common.bidi.session import Session -class ResultOwnership: - """Represents the possible result ownership types.""" +class SpecialNumber: + """SpecialNumber.""" - NONE = "none" - ROOT = "root" + NAN = "NaN" + _0 = "-0" + INFINITY = "Infinity" + _INFINITY = "-Infinity" class RealmType: - """Represents the possible realm types.""" + """RealmType.""" WINDOW = "window" DEDICATED_WORKER = "dedicated-worker" @@ -46,502 +39,1224 @@ class RealmType: WORKLET = "worklet" +class ResultOwnership: + """ResultOwnership.""" + + ROOT = "root" + NONE = "none" + + +@dataclass +class ChannelValue: + """ChannelValue.""" + + type: str = field(default="channel", init=False) + value: Any | None = None + + +@dataclass +class ChannelProperties: + """ChannelProperties.""" + + channel: Any | None = None + serialization_options: Any | None = None + ownership: Any | None = None + + +@dataclass +class EvaluateResultSuccess: + """EvaluateResultSuccess.""" + + type: str = field(default="success", init=False) + result: Any | None = None + realm: Any | None = None + + +@dataclass +class EvaluateResultException: + """EvaluateResultException.""" + + type: str = field(default="exception", init=False) + exception_details: Any | None = None + realm: Any | None = None + + +@dataclass +class ExceptionDetails: + """ExceptionDetails.""" + + column_number: Any | None = None + exception: Any | None = None + line_number: Any | None = None + stack_trace: Any | None = None + text: str | None = None + + +@dataclass +class ArrayLocalValue: + """ArrayLocalValue.""" + + type: str = field(default="array", init=False) + value: Any | None = None + + +@dataclass +class DateLocalValue: + """DateLocalValue.""" + + type: str = field(default="date", init=False) + value: str | None = None + + +@dataclass +class MapLocalValue: + """MapLocalValue.""" + + type: str = field(default="map", init=False) + value: Any | None = None + + +@dataclass +class ObjectLocalValue: + """ObjectLocalValue.""" + + type: str = field(default="object", init=False) + value: Any | None = None + + +@dataclass +class RegExpValue: + """RegExpValue.""" + + pattern: str | None = None + flags: str | None = None + + +@dataclass +class RegExpLocalValue: + """RegExpLocalValue.""" + + type: str = field(default="regexp", init=False) + value: Any | None = None + + +@dataclass +class SetLocalValue: + """SetLocalValue.""" + + type: str = field(default="set", init=False) + value: Any | None = None + + @dataclass -class RealmInfo: - """Represents information about a realm.""" +class UndefinedValue: + """UndefinedValue.""" + + type: str = field(default="undefined", init=False) + + +@dataclass +class NullValue: + """NullValue.""" + + type: str = field(default="null", init=False) + + +@dataclass +class StringValue: + """StringValue.""" + + type: str = field(default="string", init=False) + value: str | None = None + + +@dataclass +class NumberValue: + """NumberValue.""" + + type: str = field(default="number", init=False) + value: Any | None = None + + +@dataclass +class BooleanValue: + """BooleanValue.""" + + type: str = field(default="boolean", init=False) + value: bool | None = None + + +@dataclass +class BigIntValue: + """BigIntValue.""" + + type: str = field(default="bigint", init=False) + value: str | None = None + + +@dataclass +class BaseRealmInfo: + """BaseRealmInfo.""" + + realm: Any | None = None + origin: str | None = None + - realm: str - origin: str - type: str - context: str | None = None +@dataclass +class WindowRealmInfo: + """WindowRealmInfo.""" + + type: str = field(default="window", init=False) + context: Any | None = None sandbox: str | None = None - @classmethod - def from_json(cls, json: dict[str, Any]) -> "RealmInfo": - """Creates a RealmInfo instance from a dictionary. - Args: - json: A dictionary containing the realm information. +@dataclass +class DedicatedWorkerRealmInfo: + """DedicatedWorkerRealmInfo.""" - Returns: - RealmInfo: A new instance of RealmInfo. - """ - if "realm" not in json: - raise ValueError("Missing required field 'realm' in RealmInfo") - if "origin" not in json: - raise ValueError("Missing required field 'origin' in RealmInfo") - if "type" not in json: - raise ValueError("Missing required field 'type' in RealmInfo") - - return cls( - realm=json["realm"], - origin=json["origin"], - type=json["type"], - context=json.get("context"), - sandbox=json.get("sandbox"), - ) + type: str = field(default="dedicated-worker", init=False) + owners: list[Any | None] | None = None @dataclass -class Source: - """Represents the source of a script message.""" +class SharedWorkerRealmInfo: + """SharedWorkerRealmInfo.""" - realm: str - context: str | None = None + type: str = field(default="shared-worker", init=False) - @classmethod - def from_json(cls, json: dict[str, Any]) -> "Source": - """Creates a Source instance from a dictionary. - Args: - json: A dictionary containing the source information. +@dataclass +class ServiceWorkerRealmInfo: + """ServiceWorkerRealmInfo.""" - Returns: - Source: A new instance of Source. - """ - if "realm" not in json: - raise ValueError("Missing required field 'realm' in Source") + type: str = field(default="service-worker", init=False) - return cls( - realm=json["realm"], - context=json.get("context"), - ) + +@dataclass +class WorkerRealmInfo: + """WorkerRealmInfo.""" + + type: str = field(default="worker", init=False) @dataclass -class EvaluateResult: - """Represents the result of script evaluation.""" +class PaintWorkletRealmInfo: + """PaintWorkletRealmInfo.""" - type: str - realm: str - result: dict | None = None - exception_details: dict | None = None + type: str = field(default="paint-worklet", init=False) - @classmethod - def from_json(cls, json: dict[str, Any]) -> "EvaluateResult": - """Creates an EvaluateResult instance from a dictionary. - Args: - json: A dictionary containing the evaluation result. +@dataclass +class AudioWorkletRealmInfo: + """AudioWorkletRealmInfo.""" - Returns: - EvaluateResult: A new instance of EvaluateResult. - """ - if "realm" not in json: - raise ValueError("Missing required field 'realm' in EvaluateResult") - if "type" not in json: - raise ValueError("Missing required field 'type' in EvaluateResult") - - return cls( - type=json["type"], - realm=json["realm"], - result=json.get("result"), - exception_details=json.get("exceptionDetails"), - ) + type: str = field(default="audio-worklet", init=False) -class ScriptMessage: - """Represents a script message event.""" +@dataclass +class WorkletRealmInfo: + """WorkletRealmInfo.""" - event_class = "script.message" + type: str = field(default="worklet", init=False) - def __init__(self, channel: str, data: dict, source: Source): - self.channel = channel - self.data = data - self.source = source - @classmethod - def from_json(cls, json: dict[str, Any]) -> "ScriptMessage": - """Creates a ScriptMessage instance from a dictionary. +@dataclass +class SharedReference: + """SharedReference.""" - Args: - json: A dictionary containing the script message. + shared_id: Any | None = None + handle: Any | None = None - Returns: - ScriptMessage: A new instance of ScriptMessage. - """ - if "channel" not in json: - raise ValueError("Missing required field 'channel' in ScriptMessage") - if "data" not in json: - raise ValueError("Missing required field 'data' in ScriptMessage") - if "source" not in json: - raise ValueError("Missing required field 'source' in ScriptMessage") - - return cls( - channel=json["channel"], - data=json["data"], - source=Source.from_json(json["source"]), - ) +@dataclass +class RemoteObjectReference: + """RemoteObjectReference.""" -class RealmCreated: - """Represents a realm created event.""" + handle: Any | None = None + shared_id: Any | None = None - event_class = "script.realmCreated" - def __init__(self, realm_info: RealmInfo): - self.realm_info = realm_info +@dataclass +class SymbolRemoteValue: + """SymbolRemoteValue.""" - @classmethod - def from_json(cls, json: dict[str, Any]) -> "RealmCreated": - """Creates a RealmCreated instance from a dictionary. + type: str = field(default="symbol", init=False) + handle: Any | None = None + internal_id: Any | None = None - Args: - json: A dictionary containing the realm created event. - Returns: - RealmCreated: A new instance of RealmCreated. - """ - return cls(realm_info=RealmInfo.from_json(json)) +@dataclass +class ArrayRemoteValue: + """ArrayRemoteValue.""" + type: str = field(default="array", init=False) + handle: Any | None = None + internal_id: Any | None = None + value: Any | None = None -class RealmDestroyed: - """Represents a realm destroyed event.""" - event_class = "script.realmDestroyed" +@dataclass +class ObjectRemoteValue: + """ObjectRemoteValue.""" - def __init__(self, realm: str): - self.realm = realm + type: str = field(default="object", init=False) + handle: Any | None = None + internal_id: Any | None = None + value: Any | None = None - @classmethod - def from_json(cls, json: dict[str, Any]) -> "RealmDestroyed": - """Creates a RealmDestroyed instance from a dictionary. - Args: - json: A dictionary containing the realm destroyed event. +@dataclass +class FunctionRemoteValue: + """FunctionRemoteValue.""" - Returns: - RealmDestroyed: A new instance of RealmDestroyed. - """ - if "realm" not in json: - raise ValueError("Missing required field 'realm' in RealmDestroyed") + type: str = field(default="function", init=False) + handle: Any | None = None + internal_id: Any | None = None - return cls(realm=json["realm"]) +@dataclass +class RegExpRemoteValue: + """RegExpRemoteValue.""" -class Script: - """BiDi implementation of the script module.""" + handle: Any | None = None + internal_id: Any | None = None - EVENTS = { - "message": "script.message", - "realm_created": "script.realmCreated", - "realm_destroyed": "script.realmDestroyed", - } - def __init__(self, conn, driver=None): - self.conn = conn - self.driver = driver - self.log_entry_subscribed = False - self.subscriptions = {} - self.callbacks = {} +@dataclass +class DateRemoteValue: + """DateRemoteValue.""" + + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class MapRemoteValue: + """MapRemoteValue.""" + + type: str = field(default="map", init=False) + handle: Any | None = None + internal_id: Any | None = None + value: Any | None = None + + +@dataclass +class SetRemoteValue: + """SetRemoteValue.""" + + type: str = field(default="set", init=False) + handle: Any | None = None + internal_id: Any | None = None + value: Any | None = None + + +@dataclass +class WeakMapRemoteValue: + """WeakMapRemoteValue.""" + + type: str = field(default="weakmap", init=False) + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class WeakSetRemoteValue: + """WeakSetRemoteValue.""" + + type: str = field(default="weakset", init=False) + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class GeneratorRemoteValue: + """GeneratorRemoteValue.""" + + type: str = field(default="generator", init=False) + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class ErrorRemoteValue: + """ErrorRemoteValue.""" + + type: str = field(default="error", init=False) + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class ProxyRemoteValue: + """ProxyRemoteValue.""" + + type: str = field(default="proxy", init=False) + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class PromiseRemoteValue: + """PromiseRemoteValue.""" + + type: str = field(default="promise", init=False) + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class TypedArrayRemoteValue: + """TypedArrayRemoteValue.""" + + type: str = field(default="typedarray", init=False) + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class ArrayBufferRemoteValue: + """ArrayBufferRemoteValue.""" + + type: str = field(default="arraybuffer", init=False) + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class NodeListRemoteValue: + """NodeListRemoteValue.""" + + type: str = field(default="nodelist", init=False) + handle: Any | None = None + internal_id: Any | None = None + value: Any | None = None + + +@dataclass +class HTMLCollectionRemoteValue: + """HTMLCollectionRemoteValue.""" + + type: str = field(default="htmlcollection", init=False) + handle: Any | None = None + internal_id: Any | None = None + value: Any | None = None + + +@dataclass +class NodeRemoteValue: + """NodeRemoteValue.""" + + type: str = field(default="node", init=False) + shared_id: Any | None = None + handle: Any | None = None + internal_id: Any | None = None + value: Any | None = None + + +@dataclass +class NodeProperties: + """NodeProperties.""" + + node_type: Any | None = None + child_node_count: Any | None = None + children: list[Any | None] | None = None + local_name: str | None = None + mode: Any | None = None + namespace_uri: str | None = None + node_value: str | None = None + shadow_root: Any | None = None + + +@dataclass +class WindowProxyRemoteValue: + """WindowProxyRemoteValue.""" + + type: str = field(default="window", init=False) + value: Any | None = None + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class WindowProxyProperties: + """WindowProxyProperties.""" + + context: Any | None = None + + +@dataclass +class StackFrame: + """StackFrame.""" + + column_number: Any | None = None + function_name: str | None = None + line_number: Any | None = None + url: str | None = None + + +@dataclass +class StackTrace: + """StackTrace.""" + + call_frames: list[Any | None] | None = None + + +@dataclass +class Source: + """Source.""" + + realm: Any | None = None + context: Any | None = None + + +@dataclass +class RealmTarget: + """RealmTarget.""" + + realm: Any | None = None + + +@dataclass +class ContextTarget: + """ContextTarget.""" + + context: Any | None = None + sandbox: str | None = None + + +@dataclass +class AddPreloadScriptParameters: + """AddPreloadScriptParameters.""" + + function_declaration: str | None = None + arguments: list[Any | None] | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None + sandbox: str | None = None + + +@dataclass +class AddPreloadScriptResult: + """AddPreloadScriptResult.""" + + script: Any | None = None + - # High-level APIs for SCRIPT module +@dataclass +class DisownParameters: + """DisownParameters.""" + + handles: list[Any | None] | None = None + target: Any | None = None + + +@dataclass +class CallFunctionParameters: + """CallFunctionParameters.""" + + function_declaration: str | None = None + await_promise: bool | None = None + target: Any | None = None + arguments: list[Any | None] | None = None + result_ownership: Any | None = None + serialization_options: Any | None = None + this: Any | None = None + user_activation: bool | None = None - def add_console_message_handler(self, handler): - self._subscribe_to_log_entries() - return self.conn.add_callback(LogEntryAdded, self._handle_log_entry("console", handler)) - def add_javascript_error_handler(self, handler): - self._subscribe_to_log_entries() - return self.conn.add_callback(LogEntryAdded, self._handle_log_entry("javascript", handler)) +@dataclass +class EvaluateParameters: + """EvaluateParameters.""" + + expression: str | None = None + target: Any | None = None + await_promise: bool | None = None + result_ownership: Any | None = None + serialization_options: Any | None = None + user_activation: bool | None = None + + +@dataclass +class GetRealmsParameters: + """GetRealmsParameters.""" + + context: Any | None = None + type: Any | None = None + + +@dataclass +class GetRealmsResult: + """GetRealmsResult.""" + + realms: list[Any | None] | None = None + + +@dataclass +class RemovePreloadScriptParameters: + """RemovePreloadScriptParameters.""" + + script: Any | None = None + + +@dataclass +class MessageParameters: + """MessageParameters.""" + + channel: Any | None = None + data: Any | None = None + source: Any | None = None + + +@dataclass +class RealmDestroyedParameters: + """RealmDestroyedParameters.""" + + realm: Any | None = None + + +# BiDi Event Name to Parameter Type Mapping +EVENT_NAME_MAPPING = { + "realm_created": "script.realmCreated", + "realm_destroyed": "script.realmDestroyed", +} + +@dataclass +class EventConfig: + """Configuration for a BiDi event.""" + event_key: str + bidi_event: str + event_class: type - def remove_console_message_handler(self, id): - self.conn.remove_callback(LogEntryAdded, id) - self._unsubscribe_from_log_entries() - remove_javascript_error_handler = remove_console_message_handler +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization - def pin(self, script: str) -> str: - """Pins a script to the current browsing context. + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. Args: - script: The script to pin. + params: Raw BiDi event params with camelCase keys. Returns: - str: The ID of the pinned script. + An instance of the dataclass, or the raw dict on failure. """ - return self._add_preload_script(script) + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) + + +class _EventManager: + """Manages event subscriptions and callbacks.""" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + """Subscribe to a BiDi event if not already subscribed.""" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist.""" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: + session = Session(self.conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event]["callbacks"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() - def unpin(self, script_id: str) -> None: - """Unpins a script from the current browsing context. - Args: - script_id: The ID of the pinned script to unpin. - """ - self._remove_preload_script(script_id) - def execute(self, script: str, *args) -> dict: - """Executes a script in the current browsing context. - Args: - script: The script function to execute. - *args: Arguments to pass to the script function. +class Script: + """WebDriver BiDi script module.""" - Returns: - dict: The result value from the script execution. + EVENT_CONFIGS = {} + def __init__(self, conn, driver=None) -> None: + self._conn = conn + self._driver = driver + self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - Raises: - WebDriverException: If the script execution fails. - """ - if self.driver is None: - raise WebDriverException("Driver reference is required for script execution") - browsing_context_id = self.driver.current_window_handle + def add_preload_script(self, function_declaration: Any | None = None, arguments: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None, sandbox: Any | None = None): + """Execute script.addPreloadScript.""" + params = { + "functionDeclaration": function_declaration, + "arguments": arguments, + "contexts": contexts, + "userContexts": user_contexts, + "sandbox": sandbox, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.addPreloadScript", params) + result = self._conn.execute(cmd) + return result - # Convert arguments to the format expected by BiDi call_function (LocalValue Type) - arguments = [] - for arg in args: - arguments.append(self.__convert_to_local_value(arg)) + def disown(self, handles: List[Any] | None = None, target: Any | None = None): + """Execute script.disown.""" + params = { + "handles": handles, + "target": target, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.disown", params) + result = self._conn.execute(cmd) + return result - target = {"context": browsing_context_id} + def call_function(self, function_declaration: Any | None = None, await_promise: bool | None = None, target: Any | None = None, arguments: List[Any] | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, this: Any | None = None, user_activation: bool | None = None): + """Execute script.callFunction.""" + params = { + "functionDeclaration": function_declaration, + "awaitPromise": await_promise, + "target": target, + "arguments": arguments, + "resultOwnership": result_ownership, + "serializationOptions": serialization_options, + "this": this, + "userActivation": user_activation, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.callFunction", params) + result = self._conn.execute(cmd) + return result - result = self._call_function( - function_declaration=script, await_promise=True, target=target, arguments=arguments if arguments else None - ) + def evaluate(self, expression: Any | None = None, target: Any | None = None, await_promise: bool | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, user_activation: bool | None = None): + """Execute script.evaluate.""" + params = { + "expression": expression, + "target": target, + "awaitPromise": await_promise, + "resultOwnership": result_ownership, + "serializationOptions": serialization_options, + "userActivation": user_activation, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.evaluate", params) + result = self._conn.execute(cmd) + return result - if result.type == "success": - return result.result if result.result is not None else {} - else: - error_message = "Error while executing script" - if result.exception_details: - if "text" in result.exception_details: - error_message += f": {result.exception_details['text']}" - elif "message" in result.exception_details: - error_message += f": {result.exception_details['message']}" - - raise WebDriverException(error_message) - - def __convert_to_local_value(self, value) -> dict: - """Converts a Python value to BiDi LocalValue format.""" - if value is None: - return {"type": "null"} - elif isinstance(value, bool): - return {"type": "boolean", "value": value} - elif isinstance(value, (int, float)): + def get_realms(self, context: Any | None = None, type: Any | None = None): + """Execute script.getRealms.""" + params = { + "context": context, + "type": type, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.getRealms", params) + result = self._conn.execute(cmd) + return result + + def remove_preload_script(self, script: Any | None = None): + """Execute script.removePreloadScript.""" + params = { + "script": script, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.removePreloadScript", params) + result = self._conn.execute(cmd) + return result + + def message(self, channel: Any | None = None, data: Any | None = None, source: Any | None = None): + """Execute script.message.""" + params = { + "channel": channel, + "data": data, + "source": source, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.message", params) + result = self._conn.execute(cmd) + return result + + def execute(self, function_declaration: str, *args, context_id: str | None = None) -> Any: + """Execute a function declaration in the browser context. + + Args: + function_declaration: The function as a string, e.g. ``"() => document.title"``. + *args: Optional Python values to pass as arguments to the function. + Each value is serialised to a BiDi ``LocalValue`` automatically. + Supported types: ``None``, ``bool``, ``int``, ``float`` + (including ``NaN`` and ``Infinity``), ``str``, ``list``, + ``dict``, and ``datetime.datetime``. + context_id: The browsing context ID to run in. Defaults to the + driver's current window handle when a driver was provided. + + Returns: + The inner RemoteValue result dict, or raises WebDriverException on exception. + """ + import math as _math + import datetime as _datetime + from selenium.common.exceptions import WebDriverException as _WebDriverException + + def _serialize_arg(value): + """Serialise a Python value to a BiDi LocalValue dict.""" + if value is None: + return {"type": "null"} + if isinstance(value, bool): + return {"type": "boolean", "value": value} + if isinstance(value, _datetime.datetime): + return {"type": "date", "value": value.isoformat()} if isinstance(value, float): - if math.isnan(value): + if _math.isnan(value): return {"type": "number", "value": "NaN"} - elif math.isinf(value): - if value > 0: - return {"type": "number", "value": "Infinity"} - else: - return {"type": "number", "value": "-Infinity"} - elif value == 0.0 and math.copysign(1.0, value) < 0: - return {"type": "number", "value": "-0"} - - JS_MAX_SAFE_INTEGER = 9007199254740991 - if isinstance(value, int) and (value > JS_MAX_SAFE_INTEGER or value < -JS_MAX_SAFE_INTEGER): - return {"type": "bigint", "value": str(value)} - - return {"type": "number", "value": value} - - elif isinstance(value, str): - return {"type": "string", "value": value} - elif isinstance(value, datetime.datetime): - # Convert Python datetime to JavaScript Date (ISO 8601 format) - return {"type": "date", "value": value.isoformat() + "Z" if value.tzinfo is None else value.isoformat()} - elif isinstance(value, datetime.date): - # Convert Python date to JavaScript Date - dt = datetime.datetime.combine(value, datetime.time.min).replace(tzinfo=datetime.timezone.utc) - return {"type": "date", "value": dt.isoformat()} - elif isinstance(value, set): - return {"type": "set", "value": [self.__convert_to_local_value(item) for item in value]} - elif isinstance(value, (list, tuple)): - return {"type": "array", "value": [self.__convert_to_local_value(item) for item in value]} - elif isinstance(value, dict): - return { - "type": "object", - "value": [ - [self.__convert_to_local_value(k), self.__convert_to_local_value(v)] for k, v in value.items() - ], - } - else: - # For other types, convert to string - return {"type": "string", "value": str(value)} - - # low-level APIs for script module - def _add_preload_script( - self, - function_declaration: str, - arguments: list[dict[str, Any]] | None = None, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - sandbox: str | None = None, - ) -> str: - """Adds a preload script. + if _math.isinf(value): + return {"type": "number", "value": "Infinity" if value > 0 else "-Infinity"} + return {"type": "number", "value": value} + if isinstance(value, int): + _MAX_SAFE_INT = 9007199254740991 + if abs(value) > _MAX_SAFE_INT: + return {"type": "bigint", "value": str(value)} + return {"type": "number", "value": value} + if isinstance(value, str): + return {"type": "string", "value": value} + if isinstance(value, list): + return {"type": "array", "value": [_serialize_arg(v) for v in value]} + if isinstance(value, dict): + return {"type": "object", "value": [[str(k), _serialize_arg(v)] for k, v in value.items()]} + return value + + if context_id is None and self._driver is not None: + try: + context_id = self._driver.current_window_handle + except Exception: + pass + target = {"context": context_id} if context_id else {} + serialized_args = [_serialize_arg(a) for a in args] if args else None + raw = self.call_function( + function_declaration=function_declaration, + await_promise=True, + target=target, + arguments=serialized_args, + ) + if isinstance(raw, dict): + if raw.get("type") == "exception": + exc = raw.get("exceptionDetails", {}) + msg = exc.get("text", str(exc)) if isinstance(exc, dict) else str(exc) + raise _WebDriverException(msg) + if raw.get("type") == "success": + return raw.get("result") + return raw + def _add_preload_script(self, function_declaration, arguments=None, contexts=None, user_contexts=None, sandbox=None): + """Add a preload script with validation. Args: - function_declaration: The function declaration to preload. - arguments: The arguments to pass to the function. - contexts: The browsing context IDs to apply the script to. - user_contexts: The user context IDs to apply the script to. - sandbox: The sandbox name to apply the script to. + function_declaration: The JS function to run on page load. + arguments: Optional list of BiDi arguments. + contexts: Optional list of browsing context IDs. + user_contexts: Optional list of user context IDs. + sandbox: Optional sandbox name. Returns: - str: The preload script ID. + script_id: The ID of the added preload script (str). Raises: - ValueError: If both contexts and user_contexts are provided. + ValueError: If both contexts and user_contexts are specified. """ if contexts is not None and user_contexts is not None: raise ValueError("Cannot specify both contexts and user_contexts") + result = self.add_preload_script( + function_declaration=function_declaration, + arguments=arguments, + contexts=contexts, + user_contexts=user_contexts, + sandbox=sandbox, + ) + if isinstance(result, dict): + return result.get("script") + return result + def _remove_preload_script(self, script_id): + """Remove a preload script by ID. - params: dict[str, Any] = {"functionDeclaration": function_declaration} - - if arguments is not None: - params["arguments"] = arguments - if contexts is not None: - params["contexts"] = contexts - if user_contexts is not None: - params["userContexts"] = user_contexts - if sandbox is not None: - params["sandbox"] = sandbox + Args: + script_id: The ID of the preload script to remove. + """ + return self.remove_preload_script(script=script_id) + def pin(self, function_declaration): + """Pin (add) a preload script that runs on every page load. - result = self.conn.execute(command_builder("script.addPreloadScript", params)) - return result["script"] + Args: + function_declaration: The JS function to execute on page load. - def _remove_preload_script(self, script_id: str) -> None: - """Removes a preload script. + Returns: + script_id: The ID of the pinned script (str). + """ + return self._add_preload_script(function_declaration) + def unpin(self, script_id): + """Unpin (remove) a previously pinned preload script. Args: - script_id: The preload script ID to remove. + script_id: The ID returned by pin(). """ - params = {"script": script_id} - self.conn.execute(command_builder("script.removePreloadScript", params)) + return self._remove_preload_script(script_id=script_id) + def _evaluate(self, expression, target, await_promise, result_ownership=None, serialization_options=None, user_activation=None): + """Evaluate a script expression and return a structured result. - def _disown(self, handles: list[str], target: dict) -> None: - """Disowns the given handles. + Args: + expression: The JavaScript expression to evaluate. + target: A dict like {"context": } or {"realm": }. + await_promise: Whether to await a returned promise. + result_ownership: Optional result ownership setting. + serialization_options: Optional serialization options dict. + user_activation: Optional user activation flag. + + Returns: + An object with .realm, .result (dict or None), and .exception_details (or None). + """ + class _EvalResult: + def __init__(self2, realm, result, exception_details): + self2.realm = realm + self2.result = result + self2.exception_details = exception_details + + raw = self.evaluate( + expression=expression, + target=target, + await_promise=await_promise, + result_ownership=result_ownership, + serialization_options=serialization_options, + user_activation=user_activation, + ) + if isinstance(raw, dict): + realm = raw.get("realm") + if raw.get("type") == "exception": + exc = raw.get("exceptionDetails") + return _EvalResult(realm=realm, result=None, exception_details=exc) + return _EvalResult(realm=realm, result=raw.get("result"), exception_details=None) + return _EvalResult(realm=None, result=raw, exception_details=None) + def _call_function(self, function_declaration, await_promise, target, arguments=None, result_ownership=None, this=None, user_activation=None, serialization_options=None): + """Call a function and return a structured result. Args: - handles: The handles to disown. - target: The target realm or context. + function_declaration: The JS function string. + await_promise: Whether to await the return value. + target: A dict like {"context": }. + arguments: Optional list of BiDi arguments. + result_ownership: Optional result ownership. + this: Optional 'this' binding. + user_activation: Optional user activation flag. + serialization_options: Optional serialization options dict. + + Returns: + An object with .result (dict or None) and .exception_details (or None). """ - params = { - "handles": handles, - "target": target, - } - self.conn.execute(command_builder("script.disown", params)) - - def _call_function( - self, - function_declaration: str, - await_promise: bool, - target: dict, - arguments: list[dict] | None = None, - result_ownership: str | None = None, - serialization_options: dict | None = None, - this: dict | None = None, - user_activation: bool = False, - ) -> EvaluateResult: - """Calls a provided function with given arguments in a given realm. + class _CallResult: + def __init__(self2, result, exception_details): + self2.result = result + self2.exception_details = exception_details + + raw = self.call_function( + function_declaration=function_declaration, + await_promise=await_promise, + target=target, + arguments=arguments, + result_ownership=result_ownership, + this=this, + user_activation=user_activation, + serialization_options=serialization_options, + ) + if isinstance(raw, dict): + if raw.get("type") == "exception": + exc = raw.get("exceptionDetails") + return _CallResult(result=None, exception_details=exc) + if raw.get("type") == "success": + return _CallResult(result=raw.get("result"), exception_details=None) + return _CallResult(result=raw, exception_details=None) + def _get_realms(self, context=None, type=None): + """Get all realms, optionally filtered by context and type. Args: - function_declaration: The function declaration to call. - await_promise: Whether to await promise resolution. - target: The target realm or context. - arguments: The arguments to pass to the function. - result_ownership: The result ownership type. - serialization_options: The serialization options. - this: The 'this' value for the function call. - user_activation: Whether to trigger user activation. + context: Optional browsing context ID to filter by. + type: Optional realm type string to filter by (e.g. RealmType.WINDOW). Returns: - EvaluateResult: The result of the function call. + List of realm info objects with .realm, .origin, .type, .context attributes. """ - params = { - "functionDeclaration": function_declaration, - "awaitPromise": await_promise, - "target": target, - "userActivation": user_activation, - } + class _RealmInfo: + def __init__(self2, realm, origin, type_, context): + self2.realm = realm + self2.origin = origin + self2.type = type_ + self2.context = context + + raw = self.get_realms(context=context, type=type) + realms_list = raw.get("realms", []) if isinstance(raw, dict) else [] + result = [] + for r in realms_list: + if isinstance(r, dict): + result.append(_RealmInfo( + realm=r.get("realm"), + origin=r.get("origin"), + type_=r.get("type"), + context=r.get("context"), + )) + return result + def _disown(self, handles, target): + """Disown handles in a browsing context. - if arguments is not None: - params["arguments"] = arguments - if result_ownership is not None: - params["resultOwnership"] = result_ownership - if serialization_options is not None: - params["serializationOptions"] = serialization_options - if this is not None: - params["this"] = this - - result = self.conn.execute(command_builder("script.callFunction", params)) - return EvaluateResult.from_json(result) - - def _evaluate( - self, - expression: str, - target: dict, - await_promise: bool, - result_ownership: str | None = None, - serialization_options: dict | None = None, - user_activation: bool = False, - ) -> EvaluateResult: - """Evaluates a provided script in a given realm. + Args: + handles: List of handle strings to disown. + target: A dict like {"context": }. + """ + return self.disown(handles=handles, target=target) + def _subscribe_log_entry(self, callback, entry_type_filter=None): + """Subscribe to log.entryAdded BiDi events with optional type filtering.""" + import threading as _threading + from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod + + bidi_event = "log.entryAdded" + + if not hasattr(self, "_log_subscriptions"): + self._log_subscriptions = {} + self._log_lock = _threading.Lock() + + def _deserialize(params): + t = params.get("type") if isinstance(params, dict) else None + if t == "console": + cls = getattr(_log_mod, "ConsoleLogEntry", None) + if cls is not None and hasattr(cls, "from_json"): + try: + return cls.from_json(params) + except Exception: + pass + elif t == "javascript": + cls = getattr(_log_mod, "JavascriptLogEntry", None) + if cls is not None and hasattr(cls, "from_json"): + try: + return cls.from_json(params) + except Exception: + pass + return params + + def _wrapped(raw): + entry = _deserialize(raw) + if entry_type_filter is None: + callback(entry) + else: + t = getattr(entry, "type_", None) or ( + entry.get("type") if isinstance(entry, dict) else None + ) + if t == entry_type_filter: + callback(entry) + + class _BidiRef: + event_class = bidi_event + + def from_json(self2, p): + return p + + _wrapper = _BidiRef() + callback_id = self._conn.add_callback(_wrapper, _wrapped) + with self._log_lock: + if bidi_event not in self._log_subscriptions: + session = _Session(self._conn) + result = session.subscribe([bidi_event]) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self._log_subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + self._log_subscriptions[bidi_event]["callbacks"].append(callback_id) + return callback_id + def _unsubscribe_log_entry(self, callback_id): + """Unsubscribe a log entry callback by ID.""" + from selenium.webdriver.common.bidi.session import Session as _Session + + bidi_event = "log.entryAdded" + if not hasattr(self, "_log_subscriptions"): + return + + class _BidiRef: + event_class = bidi_event + + def from_json(self2, p): + return p + + _wrapper = _BidiRef() + self._conn.remove_callback(_wrapper, callback_id) + with self._log_lock: + entry = self._log_subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + if entry is not None and not entry["callbacks"]: + session = _Session(self._conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self._log_subscriptions[bidi_event] + def add_console_message_handler(self, callback: Callable) -> int: + """Add a handler for console log messages (log.entryAdded type=console). Args: - expression: The script expression to evaluate. - target: The target realm or context. - await_promise: Whether to await promise resolution. - result_ownership: The result ownership type. - serialization_options: The serialization options. - user_activation: Whether to trigger user activation. + callback: Function called with a ConsoleLogEntry on each console message. Returns: - EvaluateResult: The result of the script evaluation. + callback_id for use with remove_console_message_handler. """ - params = { - "expression": expression, - "target": target, - "awaitPromise": await_promise, - "userActivation": user_activation, - } + return self._subscribe_log_entry(callback, entry_type_filter="console") + def remove_console_message_handler(self, callback_id: int) -> None: + """Remove a console message handler by callback ID.""" + self._unsubscribe_log_entry(callback_id) + def add_javascript_error_handler(self, callback: Callable) -> int: + """Add a handler for JavaScript error log messages (log.entryAdded type=javascript). - if result_ownership is not None: - params["resultOwnership"] = result_ownership - if serialization_options is not None: - params["serializationOptions"] = serialization_options + Args: + callback: Function called with a JavascriptLogEntry on each JS error. - result = self.conn.execute(command_builder("script.evaluate", params)) - return EvaluateResult.from_json(result) + Returns: + callback_id for use with remove_javascript_error_handler. + """ + return self._subscribe_log_entry(callback, entry_type_filter="javascript") + def remove_javascript_error_handler(self, callback_id: int) -> None: + """Remove a JavaScript error handler by callback ID.""" + self._unsubscribe_log_entry(callback_id) - def _get_realms( - self, - context: str | None = None, - type: str | None = None, - ) -> list[RealmInfo]: - """Returns a list of all realms, optionally filtered. + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + """Add an event handler. Args: - context: The browsing context ID to filter by. - type: The realm type to filter by. + event: The event to subscribe to. + callback: The callback function to execute on event. + contexts: The context IDs to subscribe to (optional). Returns: - List[RealmInfo]: A list of realm information. + The callback ID. """ - params = {} + return self._event_manager.add_event_handler(event, callback, contexts) - if context is not None: - params["context"] = context - if type is not None: - params["type"] = type + def remove_event_handler(self, event: str, callback_id: int) -> None: + """Remove an event handler. - result = self.conn.execute(command_builder("script.getRealms", params)) - return [RealmInfo.from_json(realm) for realm in result["realms"]] + Args: + event: The event to unsubscribe from. + callback_id: The callback ID. + """ + return self._event_manager.remove_event_handler(event, callback_id) - def _subscribe_to_log_entries(self): - if not self.log_entry_subscribed: - session = Session(self.conn) - self.conn.execute(session.subscribe(LogEntryAdded.event_class)) - self.log_entry_subscribed = True + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + return self._event_manager.clear_event_handlers() - def _unsubscribe_from_log_entries(self): - if self.log_entry_subscribed and LogEntryAdded.event_class not in self.conn.callbacks: - session = Session(self.conn) - self.conn.execute(session.unsubscribe(LogEntryAdded.event_class)) - self.log_entry_subscribed = False +# Event Info Type Aliases +# Event: script.realmCreated +RealmCreated = globals().get('RealmInfo', dict) # Fallback to dict if type not defined + +# Event: script.realmDestroyed +RealmDestroyed = globals().get('RealmDestroyedParameters', dict) # Fallback to dict if type not defined - def _handle_log_entry(self, type, handler): - def _handle_log_entry(log_entry): - if log_entry.type_ == type: - handler(log_entry) - return _handle_log_entry +# Populate EVENT_CONFIGS with event configuration mappings +_globals = globals() +Script.EVENT_CONFIGS = { + "realm_created": (EventConfig("realm_created", "script.realmCreated", _globals.get("RealmCreated", dict)) if _globals.get("RealmCreated") else EventConfig("realm_created", "script.realmCreated", dict)), + "realm_destroyed": (EventConfig("realm_destroyed", "script.realmDestroyed", _globals.get("RealmDestroyed", dict)) if _globals.get("RealmDestroyed") else EventConfig("realm_destroyed", "script.realmDestroyed", dict)), +} diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index 3481c2d77842d..9b1daaae557fa 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -1,134 +1,236 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# WebDriver BiDi module: session +from __future__ import annotations - -from selenium.webdriver.common.bidi.common import command_builder +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass class UserPromptHandlerType: - """Represents the behavior of the user prompt handler.""" + """UserPromptHandlerType.""" ACCEPT = "accept" DISMISS = "dismiss" IGNORE = "ignore" - VALID_TYPES = {ACCEPT, DISMISS, IGNORE} +@dataclass +class CapabilitiesRequest: + """CapabilitiesRequest.""" + + always_match: Any | None = None + first_match: list[Any | None] | None = None + + +@dataclass +class CapabilityRequest: + """CapabilityRequest.""" + + accept_insecure_certs: bool | None = None + browser_name: str | None = None + browser_version: str | None = None + platform_name: str | None = None + proxy: Any | None = None + unhandled_prompt_behavior: Any | None = None + + +@dataclass +class AutodetectProxyConfiguration: + """AutodetectProxyConfiguration.""" + + proxy_type: str = field(default="autodetect", init=False) + + +@dataclass +class DirectProxyConfiguration: + """DirectProxyConfiguration.""" + + proxy_type: str = field(default="direct", init=False) + + +@dataclass +class ManualProxyConfiguration: + """ManualProxyConfiguration.""" + + proxy_type: str = field(default="manual", init=False) + http_proxy: str | None = None + ssl_proxy: str | None = None + no_proxy: list[Any | None] | None = None + + +@dataclass +class SocksProxyConfiguration: + """SocksProxyConfiguration.""" + + socks_proxy: str | None = None + socks_version: Any | None = None + + +@dataclass +class PacProxyConfiguration: + """PacProxyConfiguration.""" + + proxy_type: str = field(default="pac", init=False) + proxy_autoconfig_url: str | None = None + + +@dataclass +class SystemProxyConfiguration: + """SystemProxyConfiguration.""" + + proxy_type: str = field(default="system", init=False) + + +@dataclass +class SubscribeParameters: + """SubscribeParameters.""" + + events: list[str | None] | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None + + +@dataclass +class UnsubscribeByIDRequest: + """UnsubscribeByIDRequest.""" + + subscriptions: list[Any | None] | None = None + + +@dataclass +class UnsubscribeByAttributesRequest: + """UnsubscribeByAttributesRequest.""" + + events: list[str | None] | None = None + + +@dataclass +class StatusResult: + """StatusResult.""" + ready: bool | None = None + message: str | None = None + + +@dataclass +class NewParameters: + """NewParameters.""" + + capabilities: Any | None = None + + +@dataclass +class NewResult: + """NewResult.""" + + session_id: str | None = None + accept_insecure_certs: bool | None = None + browser_name: str | None = None + browser_version: str | None = None + platform_name: str | None = None + set_window_rect: bool | None = None + user_agent: str | None = None + proxy: Any | None = None + unhandled_prompt_behavior: Any | None = None + web_socket_url: str | None = None + + +@dataclass +class SubscribeResult: + """SubscribeResult.""" + + subscription: Any | None = None + + +@dataclass class UserPromptHandler: - """Represents the configuration of the user prompt handler.""" - - def __init__( - self, - alert: str | None = None, - before_unload: str | None = None, - confirm: str | None = None, - default: str | None = None, - file: str | None = None, - prompt: str | None = None, - ): - """Initialize UserPromptHandler. - - Args: - alert: Handler type for alert prompts. - before_unload: Handler type for beforeUnload prompts. - confirm: Handler type for confirm prompts. - default: Default handler type for all prompts. - file: Handler type for file picker prompts. - prompt: Handler type for prompt dialogs. - - Raises: - ValueError: If any handler type is not valid. - """ - for field_name, value in [ - ("alert", alert), - ("before_unload", before_unload), - ("confirm", confirm), - ("default", default), - ("file", file), - ("prompt", prompt), - ]: - if value is not None and value not in UserPromptHandlerType.VALID_TYPES: - raise ValueError( - f"Invalid {field_name} handler type: {value}. Must be one of {UserPromptHandlerType.VALID_TYPES}" - ) - - self.alert = alert - self.before_unload = before_unload - self.confirm = confirm - self.default = default - self.file = file - self.prompt = prompt - - def to_dict(self) -> dict[str, str]: - """Convert the UserPromptHandler to a dictionary for BiDi protocol. - - Returns: - Dictionary representation suitable for BiDi protocol. - """ - field_mapping = { - "alert": "alert", - "before_unload": "beforeUnload", - "confirm": "confirm", - "default": "default", - "file": "file", - "prompt": "prompt", - } + """UserPromptHandler.""" + + alert: Any | None = None + before_unload: Any | None = None + confirm: Any | None = None + default: Any | None = None + file: Any | None = None + prompt: Any | None = None + def to_bidi_dict(self) -> dict: + """Convert to BiDi protocol dict with camelCase keys.""" result = {} - for attr_name, dict_key in field_mapping.items(): - value = getattr(self, attr_name) - if value is not None: - result[dict_key] = value + if self.alert is not None: + result["alert"] = self.alert + if self.before_unload is not None: + result["beforeUnload"] = self.before_unload + if self.confirm is not None: + result["confirm"] = self.confirm + if self.default is not None: + result["default"] = self.default + if self.file is not None: + result["file"] = self.file + if self.prompt is not None: + result["prompt"] = self.prompt return result - class Session: - def __init__(self, conn): - self.conn = conn + """WebDriver BiDi session module.""" + + def __init__(self, conn) -> None: + self._conn = conn - def subscribe(self, *events, browsing_contexts=None): + def status(self): + """Execute session.status.""" params = { - "events": events, } - if browsing_contexts is None: - browsing_contexts = [] - if browsing_contexts: - params["browsingContexts"] = browsing_contexts - return command_builder("session.subscribe", params) + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("session.status", params) + result = self._conn.execute(cmd) + return result - def unsubscribe(self, *events, browsing_contexts=None): + def new(self, capabilities: Any | None = None): + """Execute session.new.""" params = { - "events": events, + "capabilities": capabilities, } - if browsing_contexts is None: - browsing_contexts = [] - if browsing_contexts: - params["browsingContexts"] = browsing_contexts - return command_builder("session.unsubscribe", params) + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("session.new", params) + result = self._conn.execute(cmd) + return result - def status(self): - """The session.status command returns information about the remote end's readiness. + def end(self): + """Execute session.end.""" + params = { + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("session.end", params) + result = self._conn.execute(cmd) + return result + + def subscribe(self, events: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute session.subscribe.""" + params = { + "events": events, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("session.subscribe", params) + result = self._conn.execute(cmd) + return result - Returns information about the remote end's readiness to create new sessions - and may include implementation-specific metadata. + def unsubscribe(self, events: List[Any] | None = None, subscriptions: List[Any] | None = None): + """Execute session.unsubscribe.""" + params = { + "events": events, + "subscriptions": subscriptions, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("session.unsubscribe", params) + result = self._conn.execute(cmd) + return result - Returns: - Dictionary containing the ready state (bool), message (str) and metadata. - """ - cmd = command_builder("session.status", {}) - return self.conn.execute(cmd) diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 992ede07f4100..7e4c9c6dee459 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -1,150 +1,152 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# WebDriver BiDi module: storage from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass -from selenium.webdriver.common.bidi.common import command_builder -if TYPE_CHECKING: - from selenium.webdriver.remote.websocket_connection import WebSocketConnection +@dataclass +class PartitionKey: + """PartitionKey.""" + user_context: str | None = None + source_origin: str | None = None -class SameSite: - """Represents the possible same site values for cookies.""" - STRICT = "strict" - LAX = "lax" - NONE = "none" - DEFAULT = "default" +@dataclass +class GetCookiesParameters: + """GetCookiesParameters.""" + + filter: Any | None = None + partition: Any | None = None + + +@dataclass +class GetCookiesResult: + """GetCookiesResult.""" + + cookies: list[Any | None] | None = None + partition_key: Any | None = None + + +@dataclass +class SetCookieParameters: + """SetCookieParameters.""" + + cookie: Any | None = None + partition: Any | None = None + + +@dataclass +class SetCookieResult: + """SetCookieResult.""" + + partition_key: Any | None = None + + +@dataclass +class DeleteCookiesParameters: + """DeleteCookiesParameters.""" + + filter: Any | None = None + partition: Any | None = None + + +@dataclass +class DeleteCookiesResult: + """DeleteCookiesResult.""" + + partition_key: Any | None = None class BytesValue: - """Represents a bytes value.""" + """A string or base64-encoded bytes value used in cookie operations. + + This corresponds to network.BytesValue in the WebDriver BiDi specification, + wrapping either a plain string or a base64-encoded binary value. + """ - TYPE_BASE64 = "base64" TYPE_STRING = "string" + TYPE_BASE64 = "base64" - def __init__(self, type: str, value: str): + def __init__(self, type: str, value: str) -> None: self.type = type self.value = value - def to_dict(self) -> dict[str, str]: - """Converts the BytesValue to a dictionary. - - Returns: - A dictionary representation of the BytesValue. - """ + def to_bidi_dict(self) -> dict: return {"type": self.type, "value": self.value} +class SameSite: + """SameSite cookie attribute values.""" -class Cookie: - """Represents a cookie.""" - - def __init__( - self, - name: str, - value: BytesValue, - domain: str, - path: str | None = None, - size: int | None = None, - http_only: bool | None = None, - secure: bool | None = None, - same_site: str | None = None, - expiry: int | None = None, - ): - self.name = name - self.value = value - self.domain = domain - self.path = path - self.size = size - self.http_only = http_only - self.secure = secure - self.same_site = same_site - self.expiry = expiry + STRICT = "strict" + LAX = "lax" + NONE = "none" + DEFAULT = "default" + +@dataclass +class StorageCookie: + """A cookie object returned by storage.getCookies.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + path: str | None = None + size: Any | None = None + http_only: bool | None = None + secure: bool | None = None + same_site: Any | None = None + expiry: Any | None = None @classmethod - def from_dict(cls, data: dict[str, Any]) -> Cookie: - """Creates a Cookie instance from a dictionary. - - Args: - data: A dictionary containing the cookie information. - - Returns: - A new instance of Cookie. - """ - # Validation for empty strings - name = data.get("name") - if not name: - raise ValueError("name is required and cannot be empty") - domain = data.get("domain") - if not domain: - raise ValueError("domain is required and cannot be empty") - - value = BytesValue(data.get("value", {}).get("type"), data.get("value", {}).get("value")) + def from_bidi_dict(cls, raw: dict) -> "StorageCookie": + """Deserialize a wire-level cookie dict to a StorageCookie.""" + value_raw = raw.get("value") + if isinstance(value_raw, dict): + value = BytesValue(value_raw.get("type"), value_raw.get("value")) + else: + value = value_raw return cls( - name=str(name), + name=raw.get("name"), value=value, - domain=str(domain), - path=data.get("path"), - size=data.get("size"), - http_only=data.get("httpOnly"), - secure=data.get("secure"), - same_site=data.get("sameSite"), - expiry=data.get("expiry"), + domain=raw.get("domain"), + path=raw.get("path"), + size=raw.get("size"), + http_only=raw.get("httpOnly"), + secure=raw.get("secure"), + same_site=raw.get("sameSite"), + expiry=raw.get("expiry"), ) - +@dataclass class CookieFilter: - """Represents a filter for cookies.""" - - def __init__( - self, - name: str | None = None, - value: BytesValue | None = None, - domain: str | None = None, - path: str | None = None, - size: int | None = None, - http_only: bool | None = None, - secure: bool | None = None, - same_site: str | None = None, - expiry: int | None = None, - ): - self.name = name - self.value = value - self.domain = domain - self.path = path - self.size = size - self.http_only = http_only - self.secure = secure - self.same_site = same_site - self.expiry = expiry - - def to_dict(self) -> dict[str, Any]: - """Converts the CookieFilter to a dictionary. - - Returns: - A dictionary representation of the CookieFilter. - """ - result: dict[str, Any] = {} + """CookieFilter.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + path: str | None = None + size: Any | None = None + http_only: bool | None = None + secure: bool | None = None + same_site: Any | None = None + expiry: Any | None = None + + def to_bidi_dict(self) -> dict: + """Serialize to the BiDi wire-protocol dict.""" + result: dict = {} if self.name is not None: result["name"] = self.name if self.value is not None: - result["value"] = self.value.to_dict() + result["value"] = self.value.to_bidi_dict() if hasattr(self.value, "to_bidi_dict") else self.value if self.domain is not None: result["domain"] = self.domain if self.path is not None: @@ -161,103 +163,28 @@ def to_dict(self) -> dict[str, Any]: result["expiry"] = self.expiry return result - -class PartitionKey: - """Represents a storage partition key.""" - - def __init__(self, user_context: str | None = None, source_origin: str | None = None): - self.user_context = user_context - self.source_origin = source_origin - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> PartitionKey: - """Creates a PartitionKey instance from a dictionary. - - Args: - data: A dictionary containing the partition key information. - - Returns: - A new instance of PartitionKey. - """ - return cls( - user_context=data.get("userContext"), - source_origin=data.get("sourceOrigin"), - ) - - -class BrowsingContextPartitionDescriptor: - """Represents a browsing context partition descriptor.""" - - def __init__(self, context: str): - self.type = "context" - self.context = context - - def to_dict(self) -> dict[str, str]: - """Converts the BrowsingContextPartitionDescriptor to a dictionary. - - Returns: - Dict: A dictionary representation of the BrowsingContextPartitionDescriptor. - """ - return {"type": self.type, "context": self.context} - - -class StorageKeyPartitionDescriptor: - """Represents a storage key partition descriptor.""" - - def __init__(self, user_context: str | None = None, source_origin: str | None = None): - self.type = "storageKey" - self.user_context = user_context - self.source_origin = source_origin - - def to_dict(self) -> dict[str, str]: - """Converts the StorageKeyPartitionDescriptor to a dictionary. - - Returns: - Dict: A dictionary representation of the StorageKeyPartitionDescriptor. - """ - result = {"type": self.type} - if self.user_context is not None: - result["userContext"] = self.user_context - if self.source_origin is not None: - result["sourceOrigin"] = self.source_origin - return result - - +@dataclass class PartialCookie: - """Represents a partial cookie for setting.""" - - def __init__( - self, - name: str, - value: BytesValue, - domain: str, - path: str | None = None, - http_only: bool | None = None, - secure: bool | None = None, - same_site: str | None = None, - expiry: int | None = None, - ): - self.name = name - self.value = value - self.domain = domain - self.path = path - self.http_only = http_only - self.secure = secure - self.same_site = same_site - self.expiry = expiry - - def to_dict(self) -> dict[str, Any]: - """Converts the PartialCookie to a dictionary. - - Returns: - ------- - Dict: A dictionary representation of the PartialCookie. - """ - result: dict[str, Any] = { - "name": self.name, - "value": self.value.to_dict(), - "domain": self.domain, - } + """PartialCookie.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + path: str | None = None + http_only: bool | None = None + secure: bool | None = None + same_site: Any | None = None + expiry: Any | None = None + + def to_bidi_dict(self) -> dict: + """Serialize to the BiDi wire-protocol dict.""" + result: dict = {} + if self.name is not None: + result["name"] = self.name + if self.value is not None: + result["value"] = self.value.to_bidi_dict() if hasattr(self.value, "to_bidi_dict") else self.value + if self.domain is not None: + result["domain"] = self.domain if self.path is not None: result["path"] = self.path if self.http_only is not None: @@ -270,144 +197,132 @@ def to_dict(self) -> dict[str, Any]: result["expiry"] = self.expiry return result +class BrowsingContextPartitionDescriptor: + """BrowsingContextPartitionDescriptor. -class GetCookiesResult: - """Represents the result of a getCookies command.""" - - def __init__(self, cookies: list[Cookie], partition_key: PartitionKey): - self.cookies = cookies - self.partition_key = partition_key - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> GetCookiesResult: - """Creates a GetCookiesResult instance from a dictionary. - - Args: - data: A dictionary containing the get cookies result information. - - Returns: - A new instance of GetCookiesResult. - """ - cookies = [Cookie.from_dict(cookie) for cookie in data.get("cookies", [])] - partition_key = PartitionKey.from_dict(data.get("partitionKey", {})) - return cls(cookies=cookies, partition_key=partition_key) - - -class SetCookieResult: - """Represents the result of a setCookie command.""" - - def __init__(self, partition_key: PartitionKey): - self.partition_key = partition_key - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> SetCookieResult: - """Creates a SetCookieResult instance from a dictionary. - - Args: - data: A dictionary containing the set cookie result information. - - Returns: - A new instance of SetCookieResult. - """ - partition_key = PartitionKey.from_dict(data.get("partitionKey", {})) - return cls(partition_key=partition_key) - - -class DeleteCookiesResult: - """Represents the result of a deleteCookies command.""" + The first positional argument is *context* (a browsing-context ID / window + handle), mirroring how the class is used throughout the test suite: + ``BrowsingContextPartitionDescriptor(driver.current_window_handle)``. + """ - def __init__(self, partition_key: PartitionKey): - self.partition_key = partition_key + def __init__(self, context: Any = None, type: str = "context") -> None: + self.context = context + self.type = type - @classmethod - def from_dict(cls, data: dict[str, Any]) -> DeleteCookiesResult: - """Creates a DeleteCookiesResult instance from a dictionary. + def to_bidi_dict(self) -> dict: + return {"type": "context", "context": self.context} - Args: - data: A dictionary containing the delete cookies result information. +@dataclass +class StorageKeyPartitionDescriptor: + """StorageKeyPartitionDescriptor.""" - Returns: - A new instance of DeleteCookiesResult. - """ - partition_key = PartitionKey.from_dict(data.get("partitionKey", {})) - return cls(partition_key=partition_key) + type: Any | None = "storageKey" + user_context: str | None = None + source_origin: str | None = None + def to_bidi_dict(self) -> dict: + """Serialize to the BiDi wire-protocol dict.""" + result: dict = {"type": "storageKey"} + if self.user_context is not None: + result["userContext"] = self.user_context + if self.source_origin is not None: + result["sourceOrigin"] = self.source_origin + return result class Storage: - """BiDi implementation of the storage module.""" + """WebDriver BiDi storage module.""" - def __init__(self, conn: WebSocketConnection) -> None: - self.conn = conn + def __init__(self, conn) -> None: + self._conn = conn - def get_cookies( - self, - filter: CookieFilter | None = None, - partition: BrowsingContextPartitionDescriptor | StorageKeyPartitionDescriptor | None = None, - ) -> GetCookiesResult: - """Gets cookies matching the specified filter. + def get_cookies(self, filter: Any | None = None, partition: Any | None = None): + """Execute storage.getCookies.""" + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.getCookies", params) + result = self._conn.execute(cmd) + return result - Args: - filter: Optional filter to specify which cookies to retrieve. - partition: Optional partition key to limit the scope of the operation. + def set_cookie(self, cookie: Any | None = None, partition: Any | None = None): + """Execute storage.setCookie.""" + params = { + "cookie": cookie, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.setCookie", params) + result = self._conn.execute(cmd) + return result - Returns: - A GetCookiesResult containing the cookies and partition key. + def delete_cookies(self, filter: Any | None = None, partition: Any | None = None): + """Execute storage.deleteCookies.""" + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.deleteCookies", params) + result = self._conn.execute(cmd) + return result - Example: - result = await storage.get_cookies( - filter=CookieFilter(name="sessionId"), - partition=PartitionKey(...) + def get_cookies(self, filter=None, partition=None): + """Execute storage.getCookies and return a GetCookiesResult.""" + if filter and hasattr(filter, "to_bidi_dict"): + filter = filter.to_bidi_dict() + if partition and hasattr(partition, "to_bidi_dict"): + partition = partition.to_bidi_dict() + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.getCookies", params) + result = self._conn.execute(cmd) + if result and "cookies" in result: + cookies = [ + StorageCookie.from_bidi_dict(c) + for c in result.get("cookies", []) + if isinstance(c, dict) + ] + pk_raw = result.get("partitionKey") + pk = ( + PartitionKey( + user_context=pk_raw.get("userContext"), + source_origin=pk_raw.get("sourceOrigin"), + ) + if isinstance(pk_raw, dict) + else None ) - """ - params = {} - if filter is not None: - params["filter"] = filter.to_dict() - if partition is not None: - params["partition"] = partition.to_dict() - - result = self.conn.execute(command_builder("storage.getCookies", params)) - return GetCookiesResult.from_dict(result) - - def set_cookie( - self, - cookie: PartialCookie, - partition: BrowsingContextPartitionDescriptor | StorageKeyPartitionDescriptor | None = None, - ) -> SetCookieResult: - """Sets a cookie in the browser. - - Args: - cookie: The cookie to set. - partition: Optional partition descriptor. - - Returns: - The result of the set cookie command. - """ - params = {"cookie": cookie.to_dict()} - if partition is not None: - params["partition"] = partition.to_dict() - - result = self.conn.execute(command_builder("storage.setCookie", params)) - return SetCookieResult.from_dict(result) - - def delete_cookies( - self, - filter: CookieFilter | None = None, - partition: BrowsingContextPartitionDescriptor | StorageKeyPartitionDescriptor | None = None, - ) -> DeleteCookiesResult: - """Deletes cookies that match the given parameters. - - Args: - filter: Optional filter to match cookies to delete. - partition: Optional partition descriptor. - - Returns: - The result of the delete cookies command. - """ - params = {} - if filter is not None: - params["filter"] = filter.to_dict() - if partition is not None: - params["partition"] = partition.to_dict() - - result = self.conn.execute(command_builder("storage.deleteCookies", params)) - return DeleteCookiesResult.from_dict(result) + return GetCookiesResult(cookies=cookies, partition_key=pk) + return GetCookiesResult(cookies=[], partition_key=None) + def set_cookie(self, cookie=None, partition=None): + """Execute storage.setCookie.""" + if cookie and hasattr(cookie, "to_bidi_dict"): + cookie = cookie.to_bidi_dict() + if partition and hasattr(partition, "to_bidi_dict"): + partition = partition.to_bidi_dict() + params = { + "cookie": cookie, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.setCookie", params) + result = self._conn.execute(cmd) + return result + def delete_cookies(self, filter=None, partition=None): + """Execute storage.deleteCookies.""" + if filter and hasattr(filter, "to_bidi_dict"): + filter = filter.to_bidi_dict() + if partition and hasattr(partition, "to_bidi_dict"): + partition = partition.to_bidi_dict() + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.deleteCookies", params) + result = self._conn.execute(cmd) + return result diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 7609a04f3b3a4..8a737efeeafde 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -1,78 +1,112 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# WebDriver BiDi module: webExtension +from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass -from selenium.common.exceptions import WebDriverException -from selenium.webdriver.common.bidi.common import command_builder + +@dataclass +class InstallParameters: + """InstallParameters.""" + + extension_data: Any | None = None + + +@dataclass +class ExtensionPath: + """ExtensionPath.""" + + type: str = field(default="path", init=False) + path: str | None = None + + +@dataclass +class ExtensionArchivePath: + """ExtensionArchivePath.""" + + type: str = field(default="archivePath", init=False) + path: str | None = None + + +@dataclass +class ExtensionBase64Encoded: + """ExtensionBase64Encoded.""" + + type: str = field(default="base64", init=False) + value: str | None = None + + +@dataclass +class InstallResult: + """InstallResult.""" + + extension: Any | None = None + + +@dataclass +class UninstallParameters: + """UninstallParameters.""" + + extension: Any | None = None class WebExtension: - """BiDi implementation of the webExtension module.""" + """WebDriver BiDi webExtension module.""" - def __init__(self, conn): - self.conn = conn + def __init__(self, conn) -> None: + self._conn = conn - def install(self, path=None, archive_path=None, base64_value=None) -> dict: - """Installs a web extension in the remote end. + def install(self, path: str | None = None, archive_path: str | None = None, base64_value: str | None = None): + """Install a web extension. - You must provide exactly one of the parameters. + Exactly one of the three keyword arguments must be provided. Args: - path: Path to an extension directory. - archive_path: Path to an extension archive file. - base64_value: Base64 encoded string of the extension archive. + path: Directory path to an unpacked extension (also accepted for + signed ``.xpi`` / ``.crx`` archive files on Firefox). + archive_path: File-system path to a packed extension archive. + base64_value: Base64-encoded extension archive string. Returns: - A dictionary containing the extension ID. - """ - if sum(x is not None for x in (path, archive_path, base64_value)) != 1: - raise ValueError("Exactly one of path, archive_path, or base64_value must be provided") + The raw result dict from the BiDi ``webExtension.install`` command + (contains at least an ``"extension"`` key with the extension ID). + Raises: + ValueError: If more than one, or none, of the arguments is provided. + """ + provided = [k for k, v in {"path": path, "archive_path": archive_path, "base64_value": base64_value}.items() if v is not None] + if len(provided) != 1: + raise ValueError( + f"Exactly one of path, archive_path, or base64_value must be provided; got: {provided}" + ) if path is not None: extension_data = {"type": "path", "path": path} elif archive_path is not None: extension_data = {"type": "archivePath", "path": archive_path} - elif base64_value is not None: + else: extension_data = {"type": "base64", "value": base64_value} - params = {"extensionData": extension_data} - - try: - result = self.conn.execute(command_builder("webExtension.install", params)) - return result - except WebDriverException as e: - if "Method not available" in str(e): - raise WebDriverException( - f"{e!s}. If you are using Chrome or Edge, add '--enable-unsafe-extension-debugging' " - "and '--remote-debugging-pipe' arguments or set options.enable_webextensions = True" - ) from e - raise - - def uninstall(self, extension_id_or_result: str | dict) -> None: - """Uninstalls a web extension from the remote end. + cmd = command_builder("webExtension.install", params) + return self._conn.execute(cmd) + def uninstall(self, extension: Any | None = None): + """Uninstall a web extension. Args: - extension_id_or_result: Either the extension ID as a string or the result dictionary - from a previous install() call containing the extension ID. + extension: Either the extension ID string returned by ``install``, + or the full result dict returned by ``install`` (the + ``"extension"`` value is extracted automatically). """ - if isinstance(extension_id_or_result, dict): - extension_id = extension_id_or_result.get("extension") - else: - extension_id = extension_id_or_result - - params = {"extension": extension_id} - self.conn.execute(command_builder("webExtension.uninstall", params)) + if isinstance(extension, dict): + extension = extension.get("extension") + params = {"extension": extension} + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("webExtension.uninstall", params) + return self._conn.execute(cmd) diff --git a/py/selenium/webdriver/common/by.py b/py/selenium/webdriver/common/by.py index d2a10ac70a7c6..9cc0ac1b1864a 100644 --- a/py/selenium/webdriver/common/by.py +++ b/py/selenium/webdriver/common/by.py @@ -16,9 +16,20 @@ # under the License. """The By implementation.""" +from __future__ import annotations + from typing import Literal -ByType = Literal["id", "xpath", "link text", "partial link text", "name", "tag name", "class name", "css selector"] +ByType = Literal[ + "id", + "xpath", + "link text", + "partial link text", + "name", + "tag name", + "class name", + "css selector", +] class By: diff --git a/py/selenium/webdriver/common/proxy.py b/py/selenium/webdriver/common/proxy.py index 89172d1122c36..28de19afa5742 100644 --- a/py/selenium/webdriver/common/proxy.py +++ b/py/selenium/webdriver/common/proxy.py @@ -17,6 +17,8 @@ """The Proxy implementation.""" +from __future__ import annotations + class ProxyTypeFactory: """Factory for proxy types.""" @@ -33,13 +35,23 @@ class ProxyType: profile preference, 'string' is id of proxy type. """ - DIRECT = ProxyTypeFactory.make(0, "DIRECT") # Direct connection, no proxy (default on Windows). - MANUAL = ProxyTypeFactory.make(1, "MANUAL") # Manual proxy settings (e.g., for httpProxy). + DIRECT = ProxyTypeFactory.make( + 0, "DIRECT" + ) # Direct connection, no proxy (default on Windows). + MANUAL = ProxyTypeFactory.make( + 1, "MANUAL" + ) # Manual proxy settings (e.g., for httpProxy). PAC = ProxyTypeFactory.make(2, "PAC") # Proxy autoconfiguration from URL. RESERVED_1 = ProxyTypeFactory.make(3, "RESERVED1") # Never used. - AUTODETECT = ProxyTypeFactory.make(4, "AUTODETECT") # Proxy autodetection (presumably with WPAD). - SYSTEM = ProxyTypeFactory.make(5, "SYSTEM") # Use system settings (default on Linux). - UNSPECIFIED = ProxyTypeFactory.make(6, "UNSPECIFIED") # Not initialized (for internal use). + AUTODETECT = ProxyTypeFactory.make( + 4, "AUTODETECT" + ) # Proxy autodetection (presumably with WPAD). + SYSTEM = ProxyTypeFactory.make( + 5, "SYSTEM" + ) # Use system settings (default on Linux). + UNSPECIFIED = ProxyTypeFactory.make( + 6, "UNSPECIFIED" + ) # Not initialized (for internal use). @classmethod def load(cls, value): @@ -48,7 +60,11 @@ def load(cls, value): value = str(value).upper() for attr in dir(cls): attr_value = getattr(cls, attr) - if isinstance(attr_value, dict) and "string" in attr_value and attr_value["string"] == value: + if ( + isinstance(attr_value, dict) + and "string" in attr_value + and attr_value["string"] == value + ): return attr_value raise Exception(f"No proxy type is found for {value}") @@ -203,13 +219,17 @@ def to_bidi_dict(self) -> dict: if self.noProxy: # Convert comma-separated string to list if isinstance(self.noProxy, str): - result["noProxy"] = [host.strip() for host in self.noProxy.split(",") if host.strip()] + result["noProxy"] = [ + host.strip() for host in self.noProxy.split(",") if host.strip() + ] elif isinstance(self.noProxy, list): if not all(isinstance(h, str) for h in self.noProxy): raise TypeError("no_proxy list must contain only strings") result["noProxy"] = self.noProxy else: - raise TypeError("no_proxy must be a comma-separated string or a list of strings") + raise TypeError( + "no_proxy must be a comma-separated string or a list of strings" + ) elif proxy_type == "pac": if self.proxyAutoconfigUrl: diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index 898eb8d547aa6..3b9402f7b547e 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -20,6 +20,7 @@ import base64 import contextlib import copy +import inspect import os import pkgutil import tempfile @@ -111,7 +112,9 @@ def get_remote_connection( client_config: ClientConfig | None = None, ) -> RemoteConnection: if isinstance(command_executor, str): - client_config = client_config or ClientConfig(remote_server_addr=command_executor) + client_config = client_config or ClientConfig( + remote_server_addr=command_executor + ) client_config.remote_server_addr = command_executor command_executor = RemoteConnection(client_config=client_config) @@ -392,9 +395,13 @@ def create_web_element(self, element_id: str) -> WebElement: def _unwrap_value(self, value): if isinstance(value, dict): if "element-6066-11e4-a52e-4f735466cecf" in value: - return self.create_web_element(value["element-6066-11e4-a52e-4f735466cecf"]) + return self.create_web_element( + value["element-6066-11e4-a52e-4f735466cecf"] + ) if "shadow-6066-11e4-a52e-4f735466cecf" in value: - return self._shadowroot_cls(self, value["shadow-6066-11e4-a52e-4f735466cecf"]) + return self._shadowroot_cls( + self, value["shadow-6066-11e4-a52e-4f735466cecf"] + ) for key, val in value.items(): value[key] = self._unwrap_value(val) return value @@ -420,18 +427,29 @@ def execute_cdp_cmd(self, cmd: str, cmd_args: dict): Example: `driver.execute_cdp_cmd("Network.getResponseBody", {"requestId": requestId})` """ - return self.execute("executeCdpCommand", {"cmd": cmd, "params": cmd_args})["value"] + return self.execute("executeCdpCommand", {"cmd": cmd, "params": cmd_args})[ + "value" + ] - def execute(self, driver_command: str, params: dict[str, Any] | None = None) -> dict[str, Any]: + def execute( + self, driver_command: str, params: dict[str, Any] | None = None + ) -> dict[str, Any]: """Sends a command to be executed by a command.CommandExecutor. Args: - driver_command: The name of the command to execute as a string. + driver_command: The name of the command to execute as a string. Can also be a generator + for BiDi protocol commands. params: A dictionary of named parameters to send with the command. Returns: The command's JSON response loaded into a dictionary object. """ + # Handle BiDi generator commands + if inspect.isgenerator(driver_command): + # BiDi command: use WebSocketConnection directly + return self.command_executor.execute(driver_command) + + # Legacy WebDriver command: handle normally params = self._wrap_value(params) if self.session_id: @@ -440,7 +458,9 @@ def execute(self, driver_command: str, params: dict[str, Any] | None = None) -> elif "sessionId" not in params: params["sessionId"] = self.session_id - response = cast(RemoteConnection, self.command_executor).execute(driver_command, params) + response = cast(RemoteConnection, self.command_executor).execute( + driver_command, params + ) if response: self.error_handler.check_response(response) @@ -496,7 +516,9 @@ def unpin(self, script_key: ScriptKey) -> None: try: self.pinned_scripts.pop(script_key.id) except KeyError: - raise KeyError(f"No script with key: {script_key} existed in {self.pinned_scripts}") from None + raise KeyError( + f"No script with key: {script_key} existed in {self.pinned_scripts}" + ) from None def get_pinned_scripts(self) -> list[str]: """Return a list of all pinned scripts. @@ -529,7 +551,9 @@ def execute_script(self, script: str, *args) -> Any: converted_args = list(args) command = Command.W3C_EXECUTE_SCRIPT - return self.execute(command, {"script": script, "args": converted_args})["value"] + return self.execute(command, {"script": script, "args": converted_args})[ + "value" + ] def execute_async_script(self, script: str, *args) -> Any: """Asynchronously Executes JavaScript in the current window/frame. @@ -548,7 +572,9 @@ def execute_async_script(self, script: str, *args) -> Any: converted_args = list(args) command = Command.W3C_EXECUTE_SCRIPT_ASYNC - return self.execute(command, {"script": script, "args": converted_args})["value"] + return self.execute(command, {"script": script, "args": converted_args})[ + "value" + ] @property def current_url(self) -> str: @@ -722,7 +748,9 @@ def implicitly_wait(self, time_to_wait: float) -> None: Example: `driver.implicitly_wait(30)` """ - self.execute(Command.SET_TIMEOUTS, {"implicit": int(float(time_to_wait) * 1000)}) + self.execute( + Command.SET_TIMEOUTS, {"implicit": int(float(time_to_wait) * 1000)} + ) def set_script_timeout(self, time_to_wait: float) -> None: """Set the timeout for asynchronous script execution. @@ -751,9 +779,14 @@ def set_page_load_timeout(self, time_to_wait: float) -> None: `driver.set_page_load_timeout(30)` """ try: - self.execute(Command.SET_TIMEOUTS, {"pageLoad": int(float(time_to_wait) * 1000)}) + self.execute( + Command.SET_TIMEOUTS, {"pageLoad": int(float(time_to_wait) * 1000)} + ) except WebDriverException: - self.execute(Command.SET_TIMEOUTS, {"ms": float(time_to_wait) * 1000, "type": "page load"}) + self.execute( + Command.SET_TIMEOUTS, + {"ms": float(time_to_wait) * 1000, "type": "page load"}, + ) @property def timeouts(self) -> Timeouts: @@ -789,7 +822,9 @@ def timeouts(self, timeouts) -> None: """ _ = self.execute(Command.SET_TIMEOUTS, timeouts._to_json())["value"] - def find_element(self, by: str | RelativeBy = By.ID, value: str | None = None) -> WebElement: + def find_element( + self, by: str | RelativeBy = By.ID, value: str | None = None + ) -> WebElement: """Find an element given a By strategy and locator. Args: @@ -810,12 +845,18 @@ def find_element(self, by: str | RelativeBy = By.ID, value: str | None = None) - if isinstance(by, RelativeBy): elements = self.find_elements(by=by, value=value) if not elements: - raise NoSuchElementException(f"Cannot locate relative element with: {by.root}") + raise NoSuchElementException( + f"Cannot locate relative element with: {by.root}" + ) return elements[0] - return self.execute(Command.FIND_ELEMENT, {"using": by, "value": value})["value"] + return self.execute(Command.FIND_ELEMENT, {"using": by, "value": value})[ + "value" + ] - def find_elements(self, by: str | RelativeBy = By.ID, value: str | None = None) -> list[WebElement]: + def find_elements( + self, by: str | RelativeBy = By.ID, value: str | None = None + ) -> list[WebElement]: """Find elements given a By strategy and locator. Args: @@ -837,14 +878,21 @@ def find_elements(self, by: str | RelativeBy = By.ID, value: str | None = None) _pkg = ".".join(__name__.split(".")[:-1]) raw_data = pkgutil.get_data(_pkg, "findElements.js") if raw_data is None: - raise FileNotFoundError(f"Could not find findElements.js in package {_pkg}") + raise FileNotFoundError( + f"Could not find findElements.js in package {_pkg}" + ) raw_function = raw_data.decode("utf8") - find_element_js = f"/* findElements */return ({raw_function}).apply(null, arguments);" + find_element_js = ( + f"/* findElements */return ({raw_function}).apply(null, arguments);" + ) return self.execute_script(find_element_js, by.to_dict()) # Return empty list if driver returns null # See https://github.com/SeleniumHQ/selenium/issues/4555 - return self.execute(Command.FIND_ELEMENTS, {"using": by, "value": value})["value"] or [] + return ( + self.execute(Command.FIND_ELEMENTS, {"using": by, "value": value})["value"] + or [] + ) @property def capabilities(self) -> dict: @@ -941,7 +989,9 @@ def get_window_size(self, windowHandle: str = "current") -> dict: return {k: size[k] for k in ("width", "height")} - def set_window_position(self, x: float, y: float, windowHandle: str = "current") -> dict: + def set_window_position( + self, x: float, y: float, windowHandle: str = "current" + ) -> dict: """Sets the x,y position of the current window. Args: @@ -969,7 +1019,10 @@ def get_window_position(self, windowHandle="current") -> dict: def _check_if_window_handle_is_current(self, windowHandle: str) -> None: """Warns if the window handle is not equal to `current`.""" if windowHandle != "current": - warnings.warn("Only 'current' window is supported for W3C compatible browsers.", stacklevel=2) + warnings.warn( + "Only 'current' window is supported for W3C compatible browsers.", + stacklevel=2, + ) def get_window_rect(self) -> dict: """Get the window's position and size. @@ -997,7 +1050,9 @@ def set_window_rect(self, x=None, y=None, width=None, height=None) -> dict: if (x is None and y is None) and (not height and not width): raise InvalidArgumentException("x and y or height and width need values") - return self.execute(Command.SET_WINDOW_RECT, {"x": x, "y": y, "width": width, "height": height})["value"] + return self.execute( + Command.SET_WINDOW_RECT, {"x": x, "y": y, "width": width, "height": height} + )["value"] @property def file_detector(self) -> FileDetector: @@ -1042,7 +1097,9 @@ def orientation(self, value) -> None: if value.upper() in allowed_values: self.execute(Command.SET_SCREEN_ORIENTATION, {"orientation": value}) else: - raise WebDriverException("You can only set the orientation to 'LANDSCAPE' and 'PORTRAIT'") + raise WebDriverException( + "You can only set the orientation to 'LANDSCAPE' and 'PORTRAIT'" + ) def start_devtools(self) -> tuple[Any, WebSocketConnection]: global cdp @@ -1057,7 +1114,9 @@ def start_devtools(self) -> tuple[Any, WebSocketConnection]: version, ws_url = self._get_cdp_details() if not ws_url: - raise WebDriverException("Unable to find url to connect to from capabilities") + raise WebDriverException( + "Unable to find url to connect to from capabilities" + ) if cdp is None: raise WebDriverException("CDP module not loaded") @@ -1066,20 +1125,28 @@ def start_devtools(self) -> tuple[Any, WebSocketConnection]: if self._websocket_connection: return self._devtools, self._websocket_connection if self.caps["browserName"].lower() == "firefox": - raise RuntimeError("CDP support for Firefox has been removed. Please switch to WebDriver BiDi.") + raise RuntimeError( + "CDP support for Firefox has been removed. Please switch to WebDriver BiDi." + ) if not isinstance(self.command_executor, RemoteConnection): - raise WebDriverException("command_executor must be a RemoteConnection instance for CDP support") + raise WebDriverException( + "command_executor must be a RemoteConnection instance for CDP support" + ) self._websocket_connection = WebSocketConnection( ws_url, self.command_executor.client_config.websocket_timeout, self.command_executor.client_config.websocket_interval, ) - targets = self._websocket_connection.execute(self._devtools.target.get_targets()) + targets = self._websocket_connection.execute( + self._devtools.target.get_targets() + ) for target in targets: if target.target_id == self.current_window_handle: target_id = target.target_id break - session = self._websocket_connection.execute(self._devtools.target.attach_to_target(target_id, True)) + session = self._websocket_connection.execute( + self._devtools.target.attach_to_target(target_id, True) + ) self._websocket_connection.session_id = session return self._devtools, self._websocket_connection @@ -1094,7 +1161,9 @@ async def bidi_connection(self): version, ws_url = self._get_cdp_details() if not ws_url: - raise WebDriverException("Unable to find url to connect to from capabilities") + raise WebDriverException( + "Unable to find url to connect to from capabilities" + ) devtools = cdp.import_devtools(version) async with cdp.open_cdp(ws_url) as conn: @@ -1120,10 +1189,14 @@ def _start_bidi(self) -> None: if self.caps.get("webSocketUrl"): ws_url = self.caps.get("webSocketUrl") else: - raise WebDriverException("Unable to find url to connect to from capabilities") + raise WebDriverException( + "Unable to find url to connect to from capabilities" + ) if not isinstance(self.command_executor, RemoteConnection): - raise WebDriverException("command_executor must be a RemoteConnection instance for BiDi support") + raise WebDriverException( + "command_executor must be a RemoteConnection instance for BiDi support" + ) self._websocket_connection = WebSocketConnection( ws_url, @@ -1321,9 +1394,13 @@ def _get_cdp_details(self): http = urllib3.PoolManager() try: if self.caps.get("browserName") == "chrome": - debugger_address = self.caps.get("goog:chromeOptions").get("debuggerAddress") + debugger_address = self.caps.get("goog:chromeOptions").get( + "debuggerAddress" + ) elif self.caps.get("browserName") in ("MicrosoftEdge", "webview2"): - debugger_address = self.caps.get("ms:edgeOptions").get("debuggerAddress") + debugger_address = self.caps.get("ms:edgeOptions").get( + "debuggerAddress" + ) except AttributeError: raise WebDriverException("Can't get debugger address.") @@ -1351,7 +1428,9 @@ def add_virtual_authenticator(self, options: VirtualAuthenticatorOptions) -> Non driver.add_virtual_authenticator(options) ``` """ - self._authenticator_id = self.execute(Command.ADD_VIRTUAL_AUTHENTICATOR, options.to_dict())["value"] + self._authenticator_id = self.execute( + Command.ADD_VIRTUAL_AUTHENTICATOR, options.to_dict() + )["value"] @property def virtual_authenticator_id(self) -> str | None: @@ -1365,7 +1444,10 @@ def remove_virtual_authenticator(self) -> None: The authenticator is no longer valid after removal, so no methods may be called. """ - self.execute(Command.REMOVE_VIRTUAL_AUTHENTICATOR, {"authenticatorId": self._authenticator_id}) + self.execute( + Command.REMOVE_VIRTUAL_AUTHENTICATOR, + {"authenticatorId": self._authenticator_id}, + ) self._authenticator_id = None @required_virtual_authenticator @@ -1380,13 +1462,20 @@ def add_credential(self, credential: Credential) -> None: driver.add_credential(credential) ``` """ - self.execute(Command.ADD_CREDENTIAL, {**credential.to_dict(), "authenticatorId": self._authenticator_id}) + self.execute( + Command.ADD_CREDENTIAL, + {**credential.to_dict(), "authenticatorId": self._authenticator_id}, + ) @required_virtual_authenticator def get_credentials(self) -> list[Credential]: """Returns the list of credentials owned by the authenticator.""" - credential_data = self.execute(Command.GET_CREDENTIALS, {"authenticatorId": self._authenticator_id}) - return [Credential.from_dict(credential) for credential in credential_data["value"]] + credential_data = self.execute( + Command.GET_CREDENTIALS, {"authenticatorId": self._authenticator_id} + ) + return [ + Credential.from_dict(credential) for credential in credential_data["value"] + ] @required_virtual_authenticator def remove_credential(self, credential_id: str | bytearray) -> None: @@ -1401,13 +1490,16 @@ def remove_credential(self, credential_id: str | bytearray) -> None: credential_id = urlsafe_b64encode(credential_id).decode() self.execute( - Command.REMOVE_CREDENTIAL, {"credentialId": credential_id, "authenticatorId": self._authenticator_id} + Command.REMOVE_CREDENTIAL, + {"credentialId": credential_id, "authenticatorId": self._authenticator_id}, ) @required_virtual_authenticator def remove_all_credentials(self) -> None: """Removes all credentials from the authenticator.""" - self.execute(Command.REMOVE_ALL_CREDENTIALS, {"authenticatorId": self._authenticator_id}) + self.execute( + Command.REMOVE_ALL_CREDENTIALS, {"authenticatorId": self._authenticator_id} + ) @required_virtual_authenticator def set_user_verified(self, verified: bool) -> None: @@ -1420,12 +1512,17 @@ def set_user_verified(self, verified: bool) -> None: Example: `driver.set_user_verified(True)` """ - self.execute(Command.SET_USER_VERIFIED, {"authenticatorId": self._authenticator_id, "isUserVerified": verified}) + self.execute( + Command.SET_USER_VERIFIED, + {"authenticatorId": self._authenticator_id, "isUserVerified": verified}, + ) def get_downloadable_files(self) -> list: """Retrieves the downloadable files as a list of file names.""" if "se:downloadsEnabled" not in self.capabilities: - raise WebDriverException("You must enable downloads in order to work with downloadable files.") + raise WebDriverException( + "You must enable downloads in order to work with downloadable files." + ) return self.execute(Command.GET_DOWNLOADABLE_FILES)["value"]["names"] @@ -1440,12 +1537,16 @@ def download_file(self, file_name: str, target_directory: str) -> None: `driver.download_file("example.zip", "/path/to/directory")` """ if "se:downloadsEnabled" not in self.capabilities: - raise WebDriverException("You must enable downloads in order to work with downloadable files.") + raise WebDriverException( + "You must enable downloads in order to work with downloadable files." + ) if not os.path.exists(target_directory): os.makedirs(target_directory) - contents = self.execute(Command.DOWNLOAD_FILE, {"name": file_name})["value"]["contents"] + contents = self.execute(Command.DOWNLOAD_FILE, {"name": file_name})["value"][ + "contents" + ] with tempfile.TemporaryDirectory() as tmp_dir: zip_file = os.path.join(tmp_dir, file_name + ".zip") @@ -1458,7 +1559,9 @@ def download_file(self, file_name: str, target_directory: str) -> None: def delete_downloadable_files(self) -> None: """Deletes all downloadable files.""" if "se:downloadsEnabled" not in self.capabilities: - raise WebDriverException("You must enable downloads in order to work with downloadable files.") + raise WebDriverException( + "You must enable downloads in order to work with downloadable files." + ) self.execute(Command.DELETE_DOWNLOADABLE_FILES) @@ -1564,5 +1667,10 @@ def _check_fedcm() -> Dialog | None: except NoAlertPresentException: return None - wait = WebDriverWait(self, timeout, poll_frequency=poll_frequency, ignored_exceptions=ignored_exceptions) + wait = WebDriverWait( + self, + timeout, + poll_frequency=poll_frequency, + ignored_exceptions=ignored_exceptions, + ) return wait.until(lambda _: _check_fedcm()) diff --git a/py/selenium/webdriver/remote/websocket_connection.py b/py/selenium/webdriver/remote/websocket_connection.py index 98bf4f4b9057a..68358e4a09974 100644 --- a/py/selenium/webdriver/remote/websocket_connection.py +++ b/py/selenium/webdriver/remote/websocket_connection.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +import dataclasses import json import logging from ssl import CERT_NONE @@ -25,6 +26,40 @@ from selenium.common import WebDriverException + +def _snake_to_camel(name: str) -> str: + """Convert snake_case field name to camelCase for BiDi protocol.""" + parts = name.split("_") + return parts[0] + "".join(p.title() for p in parts[1:]) + + +class _BiDiEncoder(json.JSONEncoder): + """JSON encoder for BiDi dataclass instances. + + Converts snake_case field names to camelCase, strips ``None`` values, + and flattens a ``properties`` field (e.g. ``PointerCommonProperties``) + directly into its parent action dict as required by the BiDi spec. + """ + + def default(self, o): + if dataclasses.is_dataclass(o) and not isinstance(o, type): + result = {} + for f in dataclasses.fields(o): + value = getattr(o, f.name) + if value is None: + continue + camel_key = _snake_to_camel(f.name) + # Flatten PointerCommonProperties fields inline into the parent + if camel_key == "properties" and dataclasses.is_dataclass(value): + for pf in dataclasses.fields(value): + pv = getattr(value, pf.name) + if pv is not None: + result[_snake_to_camel(pf.name)] = pv + else: + result[camel_key] = value + return result + return super().default(o) + logger = logging.getLogger(__name__) @@ -63,7 +98,7 @@ def execute(self, command): if self.session_id: payload["sessionId"] = self.session_id - data = json.dumps(payload) + data = json.dumps(payload, cls=_BiDiEncoder) logger.debug(f"-> {data}"[: self._max_log_message_size]) self._ws.send(data) diff --git a/py/test/selenium/webdriver/common/bidi_browser_tests.py b/py/test/selenium/webdriver/common/bidi_browser_tests.py index 7fd054b73627d..b9e042403dcba 100644 --- a/py/test/selenium/webdriver/common/bidi_browser_tests.py +++ b/py/test/selenium/webdriver/common/bidi_browser_tests.py @@ -20,7 +20,7 @@ import pytest from selenium.common.exceptions import TimeoutException -from selenium.webdriver.common.bidi.browser import ClientWindowInfo, ClientWindowState +from selenium.webdriver.common.bidi.browser import ClientWindowInfo, ClientWindowNamedState from selenium.webdriver.common.bidi.browsing_context import ReadinessState from selenium.webdriver.common.bidi.session import UserPromptHandler, UserPromptHandlerType from selenium.webdriver.common.by import By @@ -100,10 +100,9 @@ def test_raises_exception_when_removing_default_user_context(driver): def test_client_window_state_constants(driver): - assert ClientWindowState.FULLSCREEN == "fullscreen" - assert ClientWindowState.MAXIMIZED == "maximized" - assert ClientWindowState.MINIMIZED == "minimized" - assert ClientWindowState.NORMAL == "normal" + """Test ClientWindowNamedState constants.""" + assert ClientWindowNamedState.MAXIMIZED == "maximized" + assert ClientWindowNamedState.MINIMIZED == "minimized" def test_create_user_context_with_accept_insecure_certs(driver): @@ -177,7 +176,7 @@ def test_create_user_context_with_manual_proxy_all_params(driver, proxy_server): # Visit a site that should be proxied driver.get("http://example.com/") - body_text = driver.find_element("tag name", "body").text + body_text = driver.find_element(By.TAG_NAME, "body").text assert "proxied response" in body_text.lower() finally: From 565fb054c4f39ac2c7d5345a2ca7099504fa004b Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Fri, 27 Feb 2026 14:07:04 +0000 Subject: [PATCH 02/67] fixup --- py/generate_bidi.py | 204 ++++--- py/private/bidi_enhancements_manifest.py | 77 ++- py/selenium/webdriver/common/bidi/__init__.py | 23 + py/selenium/webdriver/common/bidi/browser.py | 43 +- .../webdriver/common/bidi/browsing_context.py | 171 ++++-- py/selenium/webdriver/common/bidi/cdp.py | 515 ------------------ py/selenium/webdriver/common/bidi/common.py | 7 +- py/selenium/webdriver/common/bidi/console.py | 0 .../webdriver/common/bidi/emulation.py | 187 +++---- py/selenium/webdriver/common/bidi/input.py | 36 +- py/selenium/webdriver/common/bidi/log.py | 223 +++++++- py/selenium/webdriver/common/bidi/network.py | 115 ++-- .../webdriver/common/bidi/permissions.py | 10 +- py/selenium/webdriver/common/bidi/script.py | 113 +++- py/selenium/webdriver/common/bidi/session.py | 30 +- py/selenium/webdriver/common/bidi/storage.py | 44 +- .../webdriver/common/bidi/webextension.py | 20 +- 17 files changed, 873 insertions(+), 945 deletions(-) delete mode 100644 py/selenium/webdriver/common/bidi/cdp.py mode change 100755 => 100644 py/selenium/webdriver/common/bidi/console.py diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 1770cf436bef1..2db595ff37cd0 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -18,12 +18,11 @@ import logging import re import sys -from collections import defaultdict from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from textwrap import dedent, indent as tw_indent -from typing import Any, Dict, List, Optional, Set, Tuple +from textwrap import indent as tw_indent +from typing import Any __version__ = "1.0.0" @@ -43,8 +42,7 @@ # WebDriver BiDi module: {{}} from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder +from typing import Any """ @@ -53,7 +51,7 @@ def indent(s: str, n: int) -> str: return tw_indent(s, n * " ") -def load_enhancements_manifest(manifest_path: Optional[str]) -> Dict[str, Any]: +def load_enhancements_manifest(manifest_path: str | None) -> dict[str, Any]: """Load enhancement manifest from a Python file. Args: @@ -124,10 +122,10 @@ def get_annotation(cls, cddl_type: str) -> str: if cddl_type.startswith("["): # Array inner = cddl_type.strip("[]+ ") inner_type = cls.get_annotation(inner) - return f"List[{inner_type}]" + return f"list[{inner_type}]" if cddl_type.startswith("{"): # Map/Dict - return "Dict[str, Any]" + return "dict[str, Any]" # Default to Any for unknown types return "Any" @@ -139,11 +137,11 @@ class CddlCommand: module: str name: str - params: Dict[str, str] = field(default_factory=dict) - result: Optional[str] = None + params: dict[str, str] = field(default_factory=dict) + result: str | None = None description: str = "" - def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: """Generate Python method code for this command. Args: @@ -174,8 +172,15 @@ def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str else: param_list = "self" - # Build method body - body = f" def {method_name}({param_list}):\n" + # Build method body - wrap long signatures over multiple lines if needed + sig_line = f" def {method_name}({param_list}):" + if len(sig_line) > 120 and param_strs: + body = f" def {method_name}(\n self,\n" + for p in param_strs: + body += f" {p},\n" + body += " ):\n" + else: + body = sig_line + "\n" body += f' """{self.description or "Execute " + self.module + "." + self.name}."""\n' # Add validation if specified @@ -237,7 +242,6 @@ def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str if result_param == "download_behavior": body += ' "downloadBehavior": download_behavior,\n' # Add remaining parameters that weren't part of the transform - override_params = enhancements.get("params_override", {}) for cddl_param_name in self.params: if cddl_param_name not in ["downloadBehavior"]: snake_name = self._camel_to_snake(cddl_param_name) @@ -264,45 +268,45 @@ def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str # Extract property from list items body += f' if result and "{extract_field}" in result:\n' body += f' items = result.get("{extract_field}", [])\n' - body += f" return [\n" + body += " return [\n" body += f' item.get("{extract_property}")\n' - body += f" for item in items\n" - body += f" if isinstance(item, dict)\n" - body += f" ]\n" - body += f" return []\n" + body += " for item in items\n" + body += " if isinstance(item, dict)\n" + body += " ]\n" + body += " return []\n" elif extract_field in deserialize_rules: # Extract field and deserialize to typed objects type_name = deserialize_rules[extract_field] body += f' if result and "{extract_field}" in result:\n' body += f' items = result.get("{extract_field}", [])\n' - body += f" return [\n" + body += " return [\n" body += f" {type_name}(\n" body += self._generate_field_args(extract_field, type_name) - body += f" )\n" - body += f" for item in items\n" - body += f" if isinstance(item, dict)\n" - body += f" ]\n" - body += f" return []\n" + body += " )\n" + body += " for item in items\n" + body += " if isinstance(item, dict)\n" + body += " ]\n" + body += " return []\n" else: # Simple field extraction (return the value directly, not wrapped in result dict) body += f' if result and "{extract_field}" in result:\n' body += f' extracted = result.get("{extract_field}")\n' - body += f" return extracted\n" - body += f" return result\n" + body += " return extracted\n" + body += " return result\n" elif "deserialize" in enhancements: # Deserialize response to typed objects (legacy, without extract_field) deserialize_rules = enhancements["deserialize"] for response_field, type_name in deserialize_rules.items(): body += f' if result and "{response_field}" in result:\n' body += f' items = result.get("{response_field}", [])\n' - body += f" return [\n" + body += " return [\n" body += f" {type_name}(\n" body += self._generate_field_args(response_field, type_name) - body += f" )\n" - body += f" for item in items\n" - body += f" if isinstance(item, dict)\n" - body += f" ]\n" - body += f" return []\n" + body += " )\n" + body += " for item in items\n" + body += " if isinstance(item, dict)\n" + body += " ]\n" + body += " return []\n" else: # No special response handling, just return the result body += " return result\n" @@ -351,10 +355,10 @@ class CddlTypeDefinition: module: str name: str - fields: Dict[str, str] = field(default_factory=dict) + fields: dict[str, str] = field(default_factory=dict) description: str = "" - def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str: """Generate Python dataclass code for this type. Args: @@ -366,7 +370,7 @@ def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> # Generate class name from type name (keep it as-is, don't split on underscores) class_name = self.name - code = f"@dataclass\n" + code = "@dataclass\n" code += f"class {class_name}:\n" code += f' """{self.description or self.name}."""\n\n' @@ -386,7 +390,7 @@ def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> literal_value = literal_match.group(1) code += f' {snake_name}: str = field(default="{literal_value}", init=False)\n' # Check if this field is a list type - elif "List[" in python_type: + elif "list[" in python_type: code += f" {snake_name}: {python_type} = field(default_factory=list)\n" else: code += f" {snake_name}: {python_type} = None\n" @@ -453,7 +457,7 @@ class CddlEnum: module: str name: str - values: List[str] = field(default_factory=list) + values: list[str] = field(default_factory=list) description: str = "" def to_python_class(self) -> str: @@ -530,10 +534,10 @@ class CddlModule: """Represents a CDDL module (e.g., script, network, browsing_context).""" name: str - commands: List[CddlCommand] = field(default_factory=list) - types: List[CddlTypeDefinition] = field(default_factory=list) - enums: List[CddlEnum] = field(default_factory=list) - events: List[CddlEvent] = field(default_factory=list) + commands: list[CddlCommand] = field(default_factory=list) + types: list[CddlTypeDefinition] = field(default_factory=list) + enums: list[CddlEnum] = field(default_factory=list) + events: list[CddlEvent] = field(default_factory=list) @staticmethod def _convert_method_to_event_name(method_suffix: str) -> str: @@ -548,7 +552,33 @@ def _convert_method_to_event_name(method_suffix: str) -> str: s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", method_suffix) return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() - def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + def _needs_field_import(self, enhancements: dict[str, Any] | None = None) -> bool: + """Check if any type definition in this module requires the 'field' import. + + Respects the same type exclusions applied during code generation. + """ + enhancements = enhancements or {} + extra_cls_names: set[str] = set() + for extra_cls in enhancements.get("extra_dataclasses", []): + m = re.search(r"^class\s+(\w+)", extra_cls, re.MULTILINE) + if m: + extra_cls_names.add(m.group(1)) + exclude_types = set(enhancements.get("exclude_types", [])) | extra_cls_names + + for type_def in self.types: + if type_def.name in exclude_types: + continue + for field_type in type_def.fields.values(): + # Literal string discriminants use field(default=..., init=False) + if re.match(r'^"', field_type.strip()): + return True + # List-typed fields use field(default_factory=list) + python_type = CddlTypeDefinition._get_python_type(field_type) + if python_type.startswith("list["): + return True + return False + + def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """Generate Python code for this module. Args: @@ -558,17 +588,21 @@ def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: code = MODULE_HEADER.format(self.name) # Add imports if needed - if self.types: - code += "from dataclasses import field\n" + if self.commands: + code += "from .common import command_builder\n" + dataclass_imported = False if self.commands or self.types: - code += "from typing import Generator\n" code += "from dataclasses import dataclass\n" + dataclass_imported = True + if self.types and self._needs_field_import(enhancements): + code += "from dataclasses import field\n" # Add imports for event handling if needed if self.events: code += "import threading\n" code += "from collections.abc import Callable\n" - code += "from dataclasses import dataclass\n" + if not dataclass_imported: + code += "from dataclasses import dataclass\n" code += "from selenium.webdriver.common.bidi.session import Session\n" code += "\n\n" @@ -660,7 +694,13 @@ def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: code += f"{alias} = {target}\n\n" # Generate type dataclasses, skipping any overridden by extra_dataclasses - exclude_types = set(enhancements.get("exclude_types", [])) + # Also auto-exclude types whose names appear in extra_dataclasses + extra_cls_names = set() + for extra_cls in enhancements.get("extra_dataclasses", []): + m = re.search(r"^class\s+(\w+)", extra_cls, re.MULTILINE) + if m: + extra_cls_names.add(m.group(1)) + exclude_types = set(enhancements.get("exclude_types", [])) | extra_cls_names for type_def in self.types: if type_def.name in exclude_types: continue @@ -680,13 +720,16 @@ def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: # Generate EVENT_NAME_MAPPING for the module code += "# BiDi Event Name to Parameter Type Mapping\n" code += "EVENT_NAME_MAPPING = {\n" + # Collect event keys from extra_events so we skip CDDL duplicates + extra_event_keys = {evt["event_key"] for evt in enhancements.get("extra_events", [])} for event_def in self.events: # Convert method name to user-friendly event name # e.g., "browsingContext.contextCreated" -> "context_created" method_parts = event_def.method.split(".") if len(method_parts) == 2: event_name = self._convert_method_to_event_name(method_parts[1]) - code += f' "{event_name}": "{event_def.method}",\n' + if event_name not in extra_event_keys: + code += f' "{event_name}": "{event_def.method}",\n' # Extra events not in the CDDL spec (e.g. Chromium-specific events) for extra_evt in enhancements.get("extra_events", []): code += ( @@ -923,7 +966,13 @@ def clear_event_handlers(self) -> None: code += "\n" # Generate command methods - exclude_methods = enhancements.get("exclude_methods", []) + # Auto-exclude methods whose names appear in extra_methods to prevent duplicates + extra_method_names = set() + for extra_meth in enhancements.get("extra_methods", []): + m = re.search(r"def\s+(\w+)\s*\(", extra_meth) + if m: + extra_method_names.add(m.group(1)) + exclude_methods = set(enhancements.get("exclude_methods", [])) | extra_method_names if self.commands: for command in self.commands: # Get method-specific enhancements @@ -981,24 +1030,44 @@ def clear_event_handlers(self) -> None: code += "\n" # Now populate EVENT_CONFIGS after the aliases are defined - code += f"\n# Populate EVENT_CONFIGS with event configuration mappings\n" + code += "\n# Populate EVENT_CONFIGS with event configuration mappings\n" # Use globals() to look up types dynamically to handle missing types gracefully - code += f"_globals = globals()\n" + code += "_globals = globals()\n" code += f"{class_name}.EVENT_CONFIGS = {{\n" + # Collect extra event keys to skip CDDL duplicates + extra_event_keys_cfg = {evt["event_key"] for evt in enhancements.get("extra_events", [])} for event_def in self.events: # Convert method name to user-friendly event name method_parts = event_def.method.split(".") if len(method_parts) == 2: event_name = self._convert_method_to_event_name(method_parts[1]) + if event_name in extra_event_keys_cfg: + continue # The event class is the event name (e.g., ContextCreated) # Try to get it from globals, default to dict if not found - code += f' "{event_name}": (EventConfig("{event_name}", "{event_def.method}", _globals.get("{event_def.name}", dict)) if _globals.get("{event_def.name}") else EventConfig("{event_name}", "{event_def.method}", dict)),\n' + code += ( + f' "{event_name}": (\n' + f' EventConfig("{event_name}", "{event_def.method}",\n' + f' _globals.get("{event_def.name}", dict))\n' + f' if _globals.get("{event_def.name}")\n' + f' else EventConfig("{event_name}", "{event_def.method}", dict)\n' + f' ),\n' + ) # Extra events not in the CDDL spec for extra_evt in enhancements.get("extra_events", []): ek = extra_evt["event_key"] be = extra_evt["bidi_event"] ec = extra_evt["event_class"] - code += f' "{ek}": EventConfig("{ek}", "{be}", _globals.get("{ec}", dict)),\n' + single = f' "{ek}": EventConfig("{ek}", "{be}", _globals.get("{ec}", dict)),' + if len(single) > 120: + code += ( + f' "{ek}": EventConfig(\n' + f' "{ek}", "{be}",\n' + f' _globals.get("{ec}", dict),\n' + f' ),\n' + ) + else: + code += single + "\n" code += "}\n" return code @@ -1011,9 +1080,9 @@ def __init__(self, cddl_path: str): """Initialize parser with CDDL file path.""" self.cddl_path = Path(cddl_path) self.content = "" - self.modules: Dict[str, CddlModule] = {} - self.definitions: Dict[str, str] = {} - self.event_names: Set[str] = set() # Names of definitions that are events + self.modules: dict[str, CddlModule] = {} + self.definitions: dict[str, str] = {} + self.event_names: set[str] = set() # Names of definitions that are events self._read_file() def _read_file(self) -> None: @@ -1021,12 +1090,12 @@ def _read_file(self) -> None: if not self.cddl_path.exists(): raise FileNotFoundError(f"CDDL file not found: {self.cddl_path}") - with open(self.cddl_path, "r", encoding="utf-8") as f: + with open(self.cddl_path, encoding="utf-8") as f: self.content = f.read() logger.info(f"Loaded CDDL file: {self.cddl_path}") - def parse(self) -> Dict[str, CddlModule]: + def parse(self) -> dict[str, CddlModule]: """Parse CDDL content and return modules.""" # Remove comments content = self._remove_comments(self.content) @@ -1090,9 +1159,6 @@ def _extract_event_names(self) -> None: ... ) """ - # Look for definitions like "BrowsingContextEvent", "SessionEvent", etc. - event_union_pattern = re.compile(r"(\w+\.)?(\w+)Event") - for def_name, def_content in self.definitions.items(): # Check if this looks like an event union (name ends with "Event") and # contains a module-qualified reference like "module.EventName". @@ -1175,7 +1241,7 @@ def _is_enum_definition(self, definition: str) -> bool: # Pattern: "something" / "something_else" return " / " in clean_def and '"' in clean_def - def _extract_enum_values(self, enum_definition: str) -> List[str]: + def _extract_enum_values(self, enum_definition: str) -> list[str]: """Extract individual values from an enum definition. Enums are defined as: "value1" / "value2" / "value3" @@ -1225,7 +1291,7 @@ def _normalize_cddl_type(field_type: str) -> str: result = re.sub(r"-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?", "float", result) return result.strip() - def _extract_type_fields(self, type_definition: str) -> Dict[str, str]: + def _extract_type_fields(self, type_definition: str) -> dict[str, str]: """Extract fields from a type definition block.""" fields = {} @@ -1352,8 +1418,8 @@ def _extract_commands(self) -> None: ) def _extract_parameters( - self, params_type: str, _seen: Optional[Set[str]] = None - ) -> Dict[str, str]: + self, params_type: str, _seen: set[str] | None = None + ) -> dict[str, str]: """Extract parameters from a parameter type definition. Handles both struct types ({...}) and top-level union types (TypeA / TypeB), @@ -1466,7 +1532,7 @@ def module_name_to_filename(module_name: str) -> str: return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() -def generate_init_file(output_path: Path, modules: Dict[str, CddlModule]) -> None: +def generate_init_file(output_path: Path, modules: dict[str, CddlModule]) -> None: """Generate __init__.py file for the module.""" init_path = output_path / "__init__.py" @@ -1481,7 +1547,7 @@ def generate_init_file(output_path: Path, modules: Dict[str, CddlModule]) -> Non filename = module_name_to_filename(module_name) code += f"from .{filename} import {class_name}\n" - code += f"\n__all__ = [\n" + code += "\n__all__ = [\n" for module_name in sorted(modules.keys()): class_name = module_name_to_class_name(module_name) code += f' "{class_name}",\n' @@ -1703,7 +1769,7 @@ def main( cddl_file: str, output_dir: str, spec_version: str = "1.0", - enhancements_manifest: Optional[str] = None, + enhancements_manifest: str | None = None, ) -> None: """Main entry point. diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index ae7229f6ddebd..39af67d4c635b 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -85,7 +85,12 @@ # downloadBehavior is never stripped by the generic None filter. # The BiDi spec marks it as required (can be null, but must be present). "extra_methods": [ - ''' def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): + ''' def set_download_behavior( + self, + allowed: bool | None = None, + destination_folder: str | None = None, + user_contexts: list[Any] | None = None, + ): """Set the download behavior for the browser. Args: @@ -272,8 +277,8 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": self, coordinates=None, error=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setGeolocationOverride. @@ -325,8 +330,8 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": ''' def set_timezone_override( self, timezone=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setTimezoneOverride. @@ -349,8 +354,8 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": ''' def set_scripting_enabled( self, enabled=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setScriptingEnabled. @@ -373,8 +378,8 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": ''' def set_user_agent_override( self, user_agent=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setUserAgentOverride. @@ -396,8 +401,8 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": ''' def set_screen_orientation_override( self, screen_orientation=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setScreenOrientationOverride. @@ -433,8 +438,8 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": self, network_conditions=None, offline: bool | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setNetworkConditions. @@ -534,7 +539,14 @@ def _serialize_arg(value): if raw.get("type") == "success": return raw.get("result") return raw''', - ''' def _add_preload_script(self, function_declaration, arguments=None, contexts=None, user_contexts=None, sandbox=None): + ''' def _add_preload_script( + self, + function_declaration, + arguments=None, + contexts=None, + user_contexts=None, + sandbox=None, + ): """Add a preload script with validation. Args: @@ -586,7 +598,15 @@ def _serialize_arg(value): script_id: The ID returned by pin(). """ return self._remove_preload_script(script_id=script_id)''', - ''' def _evaluate(self, expression, target, await_promise, result_ownership=None, serialization_options=None, user_activation=None): + ''' def _evaluate( + self, + expression, + target, + await_promise, + result_ownership=None, + serialization_options=None, + user_activation=None, + ): """Evaluate a script expression and return a structured result. Args: @@ -621,7 +641,17 @@ def __init__(self2, realm, result, exception_details): return _EvalResult(realm=realm, result=None, exception_details=exc) return _EvalResult(realm=realm, result=raw.get("result"), exception_details=None) return _EvalResult(realm=None, result=raw, exception_details=None)''', - ''' def _call_function(self, function_declaration, await_promise, target, arguments=None, result_ownership=None, this=None, user_activation=None, serialization_options=None): + ''' def _call_function( + self, + function_declaration, + await_promise, + target, + arguments=None, + result_ownership=None, + this=None, + user_activation=None, + serialization_options=None, + ): """Call a function and return a structured result. Args: @@ -1256,7 +1286,12 @@ def to_bidi_dict(self) -> dict: # Suppress the raw generated stubs; hand-written versions follow below "exclude_methods": ["install", "uninstall"], "extra_methods": [ - ''' def install(self, path: str | None = None, archive_path: str | None = None, base64_value: str | None = None): + ''' def install( + self, + path: str | None = None, + archive_path: str | None = None, + base64_value: str | None = None, + ): """Install a web extension. Exactly one of the three keyword arguments must be provided. @@ -1274,7 +1309,11 @@ def to_bidi_dict(self) -> dict: Raises: ValueError: If more than one, or none, of the arguments is provided. """ - provided = [k for k, v in {"path": path, "archive_path": archive_path, "base64_value": base64_value}.items() if v is not None] + provided = [ + k for k, v in { + "path": path, "archive_path": archive_path, "base64_value": base64_value, + }.items() if v is not None + ] if len(provided) != 1: raise ValueError( f"Exactly one of path, archive_path, or base64_value must be provided; got: {provided}" @@ -1502,6 +1541,7 @@ def _add_event_handler( - 'history_updated' Args: + self: The module instance this handler is bound to. event_name: The name of the event to subscribe to callback: Callback function to invoke when event occurs contexts: Optional list of context IDs to limit event subscription @@ -1538,6 +1578,7 @@ def _remove_event_handler( """Remove an event handler by its callback ID. Args: + self: The module instance this handler is bound to. callback_id: The callback ID returned from add_event_handler """ if not hasattr(self, "_event_handlers"): diff --git a/py/selenium/webdriver/common/bidi/__init__.py b/py/selenium/webdriver/common/bidi/__init__.py index ab96f2d81e292..7be7bd4f73856 100644 --- a/py/selenium/webdriver/common/bidi/__init__.py +++ b/py/selenium/webdriver/common/bidi/__init__.py @@ -5,3 +5,26 @@ from __future__ import annotations +from .browser import Browser +from .browsing_context import BrowsingContext +from .emulation import Emulation +from .input import Input +from .log import Log +from .network import Network +from .script import Script +from .session import Session +from .storage import Storage +from .webextension import WebExtension + +__all__ = [ + "Browser", + "BrowsingContext", + "Emulation", + "Input", + "Log", + "Network", + "Script", + "Session", + "Storage", + "WebExtension", +] diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index ed6a4d8f33bc5..acda63f71953e 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: browser from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass def transform_download_params( @@ -131,14 +130,14 @@ class CreateUserContextParameters: class GetClientWindowsResult: """GetClientWindowsResult.""" - client_windows: list[Any | None] | None = None + client_windows: list[Any | None] | None = field(default_factory=list) @dataclass class GetUserContextsResult: """GetUserContextsResult.""" - user_contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -171,7 +170,7 @@ class SetDownloadBehaviorParameters: """SetDownloadBehaviorParameters.""" download_behavior: Any | None = None - user_contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -204,7 +203,12 @@ def close(self): result = self._conn.execute(cmd) return result - def create_user_context(self, accept_insecure_certs: bool | None = None, proxy: Any | None = None, unhandled_prompt_behavior: Any | None = None): + def create_user_context( + self, + accept_insecure_certs: bool | None = None, + proxy: Any | None = None, + unhandled_prompt_behavior: Any | None = None, + ): """Execute browser.createUserContext.""" if proxy and hasattr(proxy, 'to_bidi_dict'): proxy = proxy.to_bidi_dict() @@ -285,23 +289,12 @@ def set_client_window_state(self, client_window: Any | None = None): result = self._conn.execute(cmd) return result - def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): - """Execute browser.setDownloadBehavior.""" - validate_download_behavior(allowed=allowed, destination_folder=destination_folder, user_contexts=user_contexts) - - download_behavior = None - download_behavior = transform_download_params(allowed, destination_folder) - - params = { - "downloadBehavior": download_behavior, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browser.setDownloadBehavior", params) - result = self._conn.execute(cmd) - return result - - def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): + def set_download_behavior( + self, + allowed: bool | None = None, + destination_folder: str | None = None, + user_contexts: list[Any] | None = None, + ): """Set the download behavior for the browser. Args: diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 35aea615d1780..5f128635df29d 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: browsingContext from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class ReadinessState: """ReadinessState.""" @@ -220,14 +219,14 @@ class LocateNodesParameters: context: Any | None = None locator: Any | None = None serialization_options: Any | None = None - start_nodes: list[Any | None] | None = None + start_nodes: list[Any | None] | None = field(default_factory=list) @dataclass class LocateNodesResult: """LocateNodesResult.""" - nodes: list[Any | None] | None = None + nodes: list[Any | None] | None = field(default_factory=list) @dataclass @@ -300,7 +299,7 @@ class SetViewportParameters: context: Any | None = None viewport: Any | None = None device_pixel_ratio: Any | None = None - user_contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -328,20 +327,6 @@ class HistoryUpdatedParameters: url: str | None = None -@dataclass -class DownloadWillBeginParams: - """DownloadWillBeginParams.""" - - suggested_filename: str | None = None - - -@dataclass -class DownloadCanceledParams: - """DownloadCanceledParams.""" - - status: str = field(default="canceled", init=False) - - @dataclass class UserPromptClosedParameters: """UserPromptClosedParameters.""" @@ -390,10 +375,10 @@ class DownloadParams: class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" - download_params: "DownloadParams | None" = None + download_params: DownloadParams | None = None @classmethod - def from_json(cls, params: dict) -> "DownloadEndParams": + def from_json(cls, params: dict) -> DownloadEndParams: """Deserialize from BiDi wire-level params dict.""" dp = DownloadParams( status=params.get("status"), @@ -414,8 +399,6 @@ def from_json(cls, params: dict) -> "DownloadEndParams": "history_updated": "browsingContext.historyUpdated", "dom_content_loaded": "browsingContext.domContentLoaded", "load": "browsingContext.load", - "download_will_begin": "browsingContext.downloadWillBegin", - "download_end": "browsingContext.downloadEnd", "navigation_aborted": "browsingContext.navigationAborted", "navigation_committed": "browsingContext.navigationCommitted", "navigation_failed": "browsingContext.navigationFailed", @@ -630,7 +613,13 @@ def activate(self, context: Any | None = None): result = self._conn.execute(cmd) return result - def capture_screenshot(self, context: str | None = None, format: Any | None = None, clip: Any | None = None, origin: str | None = None): + def capture_screenshot( + self, + context: str | None = None, + format: Any | None = None, + clip: Any | None = None, + origin: str | None = None, + ): """Execute browsingContext.captureScreenshot.""" params = { "context": context, @@ -657,7 +646,13 @@ def close(self, context: Any | None = None, prompt_unload: bool | None = None): result = self._conn.execute(cmd) return result - def create(self, type: Any | None = None, reference_context: Any | None = None, background: bool | None = None, user_context: Any | None = None): + def create( + self, + type: Any | None = None, + reference_context: Any | None = None, + background: bool | None = None, + user_context: Any | None = None, + ): """Execute browsingContext.create.""" params = { "type": type, @@ -711,7 +706,14 @@ def handle_user_prompt(self, context: Any | None = None, accept: bool | None = N result = self._conn.execute(cmd) return result - def locate_nodes(self, context: str | None = None, locator: Any | None = None, serialization_options: Any | None = None, start_nodes: Any | None = None, max_node_count: int | None = None): + def locate_nodes( + self, + context: str | None = None, + locator: Any | None = None, + serialization_options: Any | None = None, + start_nodes: Any | None = None, + max_node_count: int | None = None, + ): """Execute browsingContext.locateNodes.""" params = { "context": context, @@ -740,7 +742,15 @@ def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any result = self._conn.execute(cmd) return result - def print(self, context: Any | None = None, background: bool | None = None, margin: Any | None = None, page: Any | None = None, scale: Any | None = None, shrink_to_fit: bool | None = None): + def print( + self, + context: Any | None = None, + background: bool | None = None, + margin: Any | None = None, + page: Any | None = None, + scale: Any | None = None, + shrink_to_fit: bool | None = None, + ): """Execute browsingContext.print.""" params = { "context": context, @@ -770,7 +780,13 @@ def reload(self, context: Any | None = None, ignore_cache: bool | None = None, w result = self._conn.execute(cmd) return result - def set_viewport(self, context: str | None = None, viewport: Any | None = None, user_contexts: Any | None = None, device_pixel_ratio: Any | None = None): + def set_viewport( + self, + context: str | None = None, + viewport: Any | None = None, + user_contexts: Any | None = None, + device_pixel_ratio: Any | None = None, + ): """Execute browsingContext.setViewport.""" params = { "context": context, @@ -868,20 +884,81 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() BrowsingContext.EVENT_CONFIGS = { - "context_created": (EventConfig("context_created", "browsingContext.contextCreated", _globals.get("ContextCreated", dict)) if _globals.get("ContextCreated") else EventConfig("context_created", "browsingContext.contextCreated", dict)), - "context_destroyed": (EventConfig("context_destroyed", "browsingContext.contextDestroyed", _globals.get("ContextDestroyed", dict)) if _globals.get("ContextDestroyed") else EventConfig("context_destroyed", "browsingContext.contextDestroyed", dict)), - "navigation_started": (EventConfig("navigation_started", "browsingContext.navigationStarted", _globals.get("NavigationStarted", dict)) if _globals.get("NavigationStarted") else EventConfig("navigation_started", "browsingContext.navigationStarted", dict)), - "fragment_navigated": (EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", _globals.get("FragmentNavigated", dict)) if _globals.get("FragmentNavigated") else EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", dict)), - "history_updated": (EventConfig("history_updated", "browsingContext.historyUpdated", _globals.get("HistoryUpdated", dict)) if _globals.get("HistoryUpdated") else EventConfig("history_updated", "browsingContext.historyUpdated", dict)), - "dom_content_loaded": (EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", _globals.get("DomContentLoaded", dict)) if _globals.get("DomContentLoaded") else EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", dict)), - "load": (EventConfig("load", "browsingContext.load", _globals.get("Load", dict)) if _globals.get("Load") else EventConfig("load", "browsingContext.load", dict)), - "download_will_begin": (EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBegin", dict)) if _globals.get("DownloadWillBegin") else EventConfig("download_will_begin", "browsingContext.downloadWillBegin", dict)), - "download_end": (EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEnd", dict)) if _globals.get("DownloadEnd") else EventConfig("download_end", "browsingContext.downloadEnd", dict)), - "navigation_aborted": (EventConfig("navigation_aborted", "browsingContext.navigationAborted", _globals.get("NavigationAborted", dict)) if _globals.get("NavigationAborted") else EventConfig("navigation_aborted", "browsingContext.navigationAborted", dict)), - "navigation_committed": (EventConfig("navigation_committed", "browsingContext.navigationCommitted", _globals.get("NavigationCommitted", dict)) if _globals.get("NavigationCommitted") else EventConfig("navigation_committed", "browsingContext.navigationCommitted", dict)), - "navigation_failed": (EventConfig("navigation_failed", "browsingContext.navigationFailed", _globals.get("NavigationFailed", dict)) if _globals.get("NavigationFailed") else EventConfig("navigation_failed", "browsingContext.navigationFailed", dict)), - "user_prompt_closed": (EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", _globals.get("UserPromptClosed", dict)) if _globals.get("UserPromptClosed") else EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", dict)), - "user_prompt_opened": (EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", _globals.get("UserPromptOpened", dict)) if _globals.get("UserPromptOpened") else EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", dict)), - "download_will_begin": EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBeginParams", dict)), + "context_created": ( + EventConfig("context_created", "browsingContext.contextCreated", + _globals.get("ContextCreated", dict)) + if _globals.get("ContextCreated") + else EventConfig("context_created", "browsingContext.contextCreated", dict) + ), + "context_destroyed": ( + EventConfig("context_destroyed", "browsingContext.contextDestroyed", + _globals.get("ContextDestroyed", dict)) + if _globals.get("ContextDestroyed") + else EventConfig("context_destroyed", "browsingContext.contextDestroyed", dict) + ), + "navigation_started": ( + EventConfig("navigation_started", "browsingContext.navigationStarted", + _globals.get("NavigationStarted", dict)) + if _globals.get("NavigationStarted") + else EventConfig("navigation_started", "browsingContext.navigationStarted", dict) + ), + "fragment_navigated": ( + EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", + _globals.get("FragmentNavigated", dict)) + if _globals.get("FragmentNavigated") + else EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", dict) + ), + "history_updated": ( + EventConfig("history_updated", "browsingContext.historyUpdated", + _globals.get("HistoryUpdated", dict)) + if _globals.get("HistoryUpdated") + else EventConfig("history_updated", "browsingContext.historyUpdated", dict) + ), + "dom_content_loaded": ( + EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", + _globals.get("DomContentLoaded", dict)) + if _globals.get("DomContentLoaded") + else EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", dict) + ), + "load": ( + EventConfig("load", "browsingContext.load", + _globals.get("Load", dict)) + if _globals.get("Load") + else EventConfig("load", "browsingContext.load", dict) + ), + "navigation_aborted": ( + EventConfig("navigation_aborted", "browsingContext.navigationAborted", + _globals.get("NavigationAborted", dict)) + if _globals.get("NavigationAborted") + else EventConfig("navigation_aborted", "browsingContext.navigationAborted", dict) + ), + "navigation_committed": ( + EventConfig("navigation_committed", "browsingContext.navigationCommitted", + _globals.get("NavigationCommitted", dict)) + if _globals.get("NavigationCommitted") + else EventConfig("navigation_committed", "browsingContext.navigationCommitted", dict) + ), + "navigation_failed": ( + EventConfig("navigation_failed", "browsingContext.navigationFailed", + _globals.get("NavigationFailed", dict)) + if _globals.get("NavigationFailed") + else EventConfig("navigation_failed", "browsingContext.navigationFailed", dict) + ), + "user_prompt_closed": ( + EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", + _globals.get("UserPromptClosed", dict)) + if _globals.get("UserPromptClosed") + else EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", dict) + ), + "user_prompt_opened": ( + EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", + _globals.get("UserPromptOpened", dict)) + if _globals.get("UserPromptOpened") + else EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", dict) + ), + "download_will_begin": EventConfig( + "download_will_begin", "browsingContext.downloadWillBegin", + _globals.get("DownloadWillBeginParams", dict), + ), "download_end": EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEndParams", dict)), } diff --git a/py/selenium/webdriver/common/bidi/cdp.py b/py/selenium/webdriver/common/bidi/cdp.py deleted file mode 100644 index b097762fe50cd..0000000000000 --- a/py/selenium/webdriver/common/bidi/cdp.py +++ /dev/null @@ -1,515 +0,0 @@ -# The MIT License(MIT) -# -# Copyright(c) 2018 Hyperion Gray -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files(the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. -# -# This code comes from https://github.com/HyperionGray/trio-chrome-devtools-protocol/tree/master/trio_cdp - -import contextvars -import importlib -import itertools -import json -import logging -import pathlib -from collections import defaultdict -from collections.abc import AsyncGenerator, AsyncIterator, Generator -from contextlib import asynccontextmanager, contextmanager -from dataclasses import dataclass -from typing import Any, TypeVar - -import trio -from trio_websocket import ConnectionClosed as WsConnectionClosed -from trio_websocket import connect_websocket_url - -logger = logging.getLogger("trio_cdp") -T = TypeVar("T") -MAX_WS_MESSAGE_SIZE = 2**24 - -devtools = None -version = None - - -def import_devtools(ver): - """Attempt to load the current latest available devtools into the module cache for use later.""" - global devtools - global version - version = ver - base = "selenium.webdriver.common.devtools.v" - try: - devtools = importlib.import_module(f"{base}{ver}") - return devtools - except ModuleNotFoundError: - # Attempt to parse and load the 'most recent' devtools module. This is likely - # because cdp has been updated but selenium python has not been released yet. - devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools") - versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir()) - latest = max(int(x[1:]) for x in versions) - selenium_logger = logging.getLogger(__name__) - selenium_logger.debug("Falling back to loading `devtools`: v%s", latest) - devtools = importlib.import_module(f"{base}{latest}") - return devtools - - -_connection_context: contextvars.ContextVar = contextvars.ContextVar("connection_context") -_session_context: contextvars.ContextVar = contextvars.ContextVar("session_context") - - -def get_connection_context(fn_name): - """Look up the current connection. - - If there is no current connection, raise a ``RuntimeError`` with a - helpful message. - """ - try: - return _connection_context.get() - except LookupError: - raise RuntimeError(f"{fn_name}() must be called in a connection context.") - - -def get_session_context(fn_name): - """Look up the current session. - - If there is no current session, raise a ``RuntimeError`` with a - helpful message. - """ - try: - return _session_context.get() - except LookupError: - raise RuntimeError(f"{fn_name}() must be called in a session context.") - - -@contextmanager -def connection_context(connection): - """Context manager installs ``connection`` as the session context for the current Trio task.""" - token = _connection_context.set(connection) - try: - yield - finally: - _connection_context.reset(token) - - -@contextmanager -def session_context(session): - """Context manager installs ``session`` as the session context for the current Trio task.""" - token = _session_context.set(session) - try: - yield - finally: - _session_context.reset(token) - - -def set_global_connection(connection): - """Install ``connection`` in the root context so that it will become the default connection for all tasks. - - This is generally not recommended, except it may be necessary in - certain use cases such as running inside Jupyter notebook. - """ - global _connection_context - _connection_context = contextvars.ContextVar("_connection_context", default=connection) - - -def set_global_session(session): - """Install ``session`` in the root context so that it will become the default session for all tasks. - - This is generally not recommended, except it may be necessary in - certain use cases such as running inside Jupyter notebook. - """ - global _session_context - _session_context = contextvars.ContextVar("_session_context", default=session) - - -class BrowserError(Exception): - """This exception is raised when the browser's response to a command indicates that an error occurred.""" - - def __init__(self, obj): - self.code = obj.get("code") - self.message = obj.get("message") - self.detail = obj.get("data") - - def __str__(self): - return f"BrowserError {self.detail}" - - -class CdpConnectionClosed(WsConnectionClosed): - """Raised when a public method is called on a closed CDP connection.""" - - def __init__(self, reason): - """Constructor. - - Args: - reason: wsproto.frame_protocol.CloseReason - """ - self.reason = reason - - def __repr__(self): - """Return representation.""" - return f"{self.__class__.__name__}<{self.reason}>" - - -class InternalError(Exception): - """This exception is only raised when there is faulty logic in TrioCDP or the integration with PyCDP.""" - - pass - - -@dataclass -class CmEventProxy: - """A proxy object returned by :meth:`CdpBase.wait_for()``. - - After the context manager executes, this proxy object will have a - value set that contains the returned event. - """ - - value: Any = None - - -class CdpBase: - def __init__(self, ws, session_id, target_id): - self.ws = ws - self.session_id = session_id - self.target_id = target_id - self.channels = defaultdict(set) - self.id_iter = itertools.count() - self.inflight_cmd = {} - self.inflight_result = {} - - async def execute(self, cmd: Generator[dict, T, Any]) -> T: - """Execute a command on the server and wait for the result. - - Args: - cmd: any CDP command - - Returns: - a CDP result - """ - cmd_id = next(self.id_iter) - cmd_event = trio.Event() - self.inflight_cmd[cmd_id] = cmd, cmd_event - request = next(cmd) - request["id"] = cmd_id - if self.session_id: - request["sessionId"] = self.session_id - request_str = json.dumps(request) - if logger.isEnabledFor(logging.DEBUG): - logger.debug(f"Sending CDP message: {cmd_id} {cmd_event}: {request_str}") - try: - await self.ws.send_message(request_str) - except WsConnectionClosed as wcc: - raise CdpConnectionClosed(wcc.reason) from None - await cmd_event.wait() - response = self.inflight_result.pop(cmd_id) - if logger.isEnabledFor(logging.DEBUG): - logger.debug(f"Received CDP message: {response}") - if isinstance(response, Exception): - if logger.isEnabledFor(logging.DEBUG): - logger.debug(f"Exception raised by {cmd_event} message: {type(response).__name__}") - raise response - return response - - def listen(self, *event_types, buffer_size=10): - """Listen for events. - - Returns: - An async iterator that iterates over events matching the indicated types. - """ - sender, receiver = trio.open_memory_channel(buffer_size) - for event_type in event_types: - self.channels[event_type].add(sender) - return receiver - - @asynccontextmanager - async def wait_for(self, event_type: type[T], buffer_size=10) -> AsyncGenerator[CmEventProxy, None]: - """Wait for an event of the given type and return it. - - This is an async context manager, so you should open it inside - an async with block. The block will not exit until the indicated - event is received. - """ - sender: trio.MemorySendChannel - receiver: trio.MemoryReceiveChannel - sender, receiver = trio.open_memory_channel(buffer_size) - self.channels[event_type].add(sender) - proxy = CmEventProxy() - yield proxy - async with receiver: - event = await receiver.receive() - proxy.value = event - - def _handle_data(self, data): - """Handle incoming WebSocket data. - - Args: - data: a JSON dictionary - """ - if "id" in data: - self._handle_cmd_response(data) - else: - self._handle_event(data) - - def _handle_cmd_response(self, data: dict): - """Handle a response to a command. - - This will set an event flag that will return control to the - task that called the command. - - Args: - data: response as a JSON dictionary - """ - cmd_id = data["id"] - try: - cmd, event = self.inflight_cmd.pop(cmd_id) - except KeyError: - logger.warning("Got a message with a command ID that does not exist: %s", data) - return - if "error" in data: - # If the server reported an error, convert it to an exception and do - # not process the response any further. - self.inflight_result[cmd_id] = BrowserError(data["error"]) - else: - # Otherwise, continue the generator to parse the JSON result - # into a CDP object. - try: - _ = cmd.send(data["result"]) - raise InternalError("The command's generator function did not exit when expected!") - except StopIteration as exit: - return_ = exit.value - self.inflight_result[cmd_id] = return_ - event.set() - - def _handle_event(self, data: dict): - """Handle an event. - - Args: - data: event as a JSON dictionary - """ - global devtools - if devtools is None: - raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") - event = devtools.util.parse_json_event(data) - logger.debug("Received event: %s", event) - to_remove = set() - for sender in self.channels[type(event)]: - try: - sender.send_nowait(event) - except trio.WouldBlock: - logger.error('Unable to send event "%r" due to full channel %s', event, sender) - except trio.BrokenResourceError: - to_remove.add(sender) - if to_remove: - self.channels[type(event)] -= to_remove - - -class CdpSession(CdpBase): - """Contains the state for a CDP session. - - Generally you should not instantiate this object yourself; you should call - :meth:`CdpConnection.open_session`. - """ - - def __init__(self, ws, session_id, target_id): - """Constructor. - - Args: - ws: trio_websocket.WebSocketConnection - session_id: devtools.target.SessionID - target_id: devtools.target.TargetID - """ - super().__init__(ws, session_id, target_id) - - self._dom_enable_count = 0 - self._dom_enable_lock = trio.Lock() - self._page_enable_count = 0 - self._page_enable_lock = trio.Lock() - - @asynccontextmanager - async def dom_enable(self): - """Context manager that executes ``dom.enable()`` when it enters and then calls ``dom.disable()``. - - This keeps track of concurrent callers and only disables DOM - events when all callers have exited. - """ - global devtools - async with self._dom_enable_lock: - self._dom_enable_count += 1 - if self._dom_enable_count == 1: - await self.execute(devtools.dom.enable()) - - yield - - async with self._dom_enable_lock: - self._dom_enable_count -= 1 - if self._dom_enable_count == 0: - await self.execute(devtools.dom.disable()) - - @asynccontextmanager - async def page_enable(self): - """Context manager executes ``page.enable()`` when it enters and then calls ``page.disable()`` when it exits. - - This keeps track of concurrent callers and only disables page - events when all callers have exited. - """ - global devtools - async with self._page_enable_lock: - self._page_enable_count += 1 - if self._page_enable_count == 1: - await self.execute(devtools.page.enable()) - - yield - - async with self._page_enable_lock: - self._page_enable_count -= 1 - if self._page_enable_count == 0: - await self.execute(devtools.page.disable()) - - -class CdpConnection(CdpBase, trio.abc.AsyncResource): - """Contains the connection state for a Chrome DevTools Protocol server. - - CDP can multiplex multiple "sessions" over a single connection. This - class corresponds to the "root" session, i.e. the implicitly created - session that has no session ID. This class is responsible for - reading incoming WebSocket messages and forwarding them to the - corresponding session, as well as handling messages targeted at the - root session itself. You should generally call the - :func:`open_cdp()` instead of instantiating this class directly. - """ - - def __init__(self, ws): - """Constructor. - - Args: - ws: trio_websocket.WebSocketConnection - """ - super().__init__(ws, session_id=None, target_id=None) - self.sessions = {} - - async def aclose(self): - """Close the underlying WebSocket connection. - - This will cause the reader task to gracefully exit when it tries - to read the next message from the WebSocket. All of the public - APIs (``execute()``, ``listen()``, etc.) will raise - ``CdpConnectionClosed`` after the CDP connection is closed. It - is safe to call this multiple times. - """ - await self.ws.aclose() - - @asynccontextmanager - async def open_session(self, target_id) -> AsyncIterator[CdpSession]: - """Context manager opens a session and enables the "simple" style of calling CDP APIs. - - For example, inside a session context, you can call ``await - dom.get_document()`` and it will execute on the current session - automatically. - """ - session = await self.connect_session(target_id) - with session_context(session): - yield session - - async def connect_session(self, target_id) -> "CdpSession": - """Returns a new :class:`CdpSession` connected to the specified target.""" - global devtools - if devtools is None: - raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") - session_id = await self.execute(devtools.target.attach_to_target(target_id, True)) - session = CdpSession(self.ws, session_id, target_id) - self.sessions[session_id] = session - return session - - async def _reader_task(self): - """Runs in the background and handles incoming messages. - - Dispatches responses to commands and events to listeners. - """ - global devtools - if devtools is None: - raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") - while True: - try: - message = await self.ws.get_message() - except WsConnectionClosed: - # If the WebSocket is closed, we don't want to throw an - # exception from the reader task. Instead we will throw - # exceptions from the public API methods, and we can quietly - # exit the reader task here. - break - try: - data = json.loads(message) - except json.JSONDecodeError: - raise BrowserError({"code": -32700, "message": "Client received invalid JSON", "data": message}) - logger.debug("Received message %r", data) - if "sessionId" in data: - session_id = devtools.target.SessionID(data["sessionId"]) - try: - session = self.sessions[session_id] - except KeyError: - raise BrowserError( - { - "code": -32700, - "message": "Browser sent a message for an invalid session", - "data": f"{session_id!r}", - } - ) - session._handle_data(data) - else: - self._handle_data(data) - - for _, session in self.sessions.items(): - for _, senders in session.channels.items(): - for sender in senders: - sender.close() - - -@asynccontextmanager -async def open_cdp(url) -> AsyncIterator[CdpConnection]: - """Async context manager opens a connection to the browser then closes the connection when the block exits. - - The context manager also sets the connection as the default - connection for the current task, so that commands like ``await - target.get_targets()`` will run on this connection automatically. If - you want to use multiple connections concurrently, it is recommended - to open each on in a separate task. - """ - async with trio.open_nursery() as nursery: - conn = await connect_cdp(nursery, url) - try: - with connection_context(conn): - yield conn - finally: - await conn.aclose() - - -async def connect_cdp(nursery, url) -> CdpConnection: - """Connect to the browser specified by ``url`` and spawn a background task in the specified nursery. - - The ``open_cdp()`` context manager is preferred in most situations. - You should only use this function if you need to specify a custom - nursery. This connection is not automatically closed! You can either - use the connection object as a context manager (``async with - conn:``) or else call ``await conn.aclose()`` on it when you are - done with it. If ``set_context`` is True, then the returned - connection will be installed as the default connection for the - current task. This argument is for unusual use cases, such as - running inside of a notebook. - """ - ws = await connect_websocket_url(nursery, url, max_message_size=MAX_WS_MESSAGE_SIZE) - cdp_conn = CdpConnection(ws) - nursery.start_soon(cdp_conn._reader_task) - return cdp_conn diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index d90d8c770263a..d7cb436a08471 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -17,12 +17,13 @@ """Common utilities for BiDi command construction.""" -from typing import Any, Dict, Generator +from collections.abc import Generator +from typing import Any def command_builder( - method: str, params: Dict[str, Any] -) -> Generator[Dict[str, Any], Any, Any]: + method: str, params: dict[str, Any] +) -> Generator[dict[str, Any], Any, Any]: """Build a BiDi command generator. Args: diff --git a/py/selenium/webdriver/common/bidi/console.py b/py/selenium/webdriver/common/bidi/console.py old mode 100755 new mode 100644 diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 4cd6ae2e3c712..cb575bbdc54dd 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: emulation from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass class ForcedColorsModeTheme: @@ -41,16 +40,16 @@ class SetForcedColorsModeThemeOverrideParameters: """SetForcedColorsModeThemeOverrideParameters.""" theme: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass class SetGeolocationOverrideParameters: """SetGeolocationOverrideParameters.""" - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -78,8 +77,8 @@ class SetLocaleOverrideParameters: """SetLocaleOverrideParameters.""" locale: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -87,8 +86,8 @@ class setNetworkConditionsParameters: """setNetworkConditionsParameters.""" network_conditions: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -111,8 +110,8 @@ class SetScreenSettingsOverrideParameters: """SetScreenSettingsOverrideParameters.""" screen_area: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -128,8 +127,8 @@ class SetScreenOrientationOverrideParameters: """SetScreenOrientationOverrideParameters.""" screen_orientation: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -137,8 +136,8 @@ class SetUserAgentOverrideParameters: """SetUserAgentOverrideParameters.""" user_agent: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -146,8 +145,8 @@ class SetViewportMetaOverrideParameters: """SetViewportMetaOverrideParameters.""" viewport_meta: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -155,8 +154,8 @@ class SetScriptingEnabledParameters: """SetScriptingEnabledParameters.""" enabled: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -164,8 +163,8 @@ class SetScrollbarTypeOverrideParameters: """SetScrollbarTypeOverrideParameters.""" scrollbar_type: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -173,16 +172,16 @@ class SetTimezoneOverrideParameters: """SetTimezoneOverrideParameters.""" timezone: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass class SetTouchOverrideParameters: """SetTouchOverrideParameters.""" - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) class Emulation: @@ -191,7 +190,12 @@ class Emulation: def __init__(self, conn) -> None: self._conn = conn - def set_forced_colors_mode_theme_override(self, theme: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_forced_colors_mode_theme_override( + self, + theme: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute emulation.setForcedColorsModeThemeOverride.""" params = { "theme": theme, @@ -203,18 +207,12 @@ def set_forced_colors_mode_theme_override(self, theme: Any | None = None, contex result = self._conn.execute(cmd) return result - def set_geolocation_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setGeolocationOverride.""" - params = { - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setGeolocationOverride", params) - result = self._conn.execute(cmd) - return result - - def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_locale_override( + self, + locale: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute emulation.setLocaleOverride.""" params = { "locale": locale, @@ -226,19 +224,12 @@ def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | N result = self._conn.execute(cmd) return result - def set_network_conditions(self, network_conditions: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setNetworkConditions.""" - params = { - "networkConditions": network_conditions, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setNetworkConditions", params) - result = self._conn.execute(cmd) - return result - - def set_screen_settings_override(self, screen_area: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_screen_settings_override( + self, + screen_area: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute emulation.setScreenSettingsOverride.""" params = { "screenArea": screen_area, @@ -250,31 +241,12 @@ def set_screen_settings_override(self, screen_area: Any | None = None, contexts: result = self._conn.execute(cmd) return result - def set_screen_orientation_override(self, screen_orientation: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setScreenOrientationOverride.""" - params = { - "screenOrientation": screen_orientation, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setScreenOrientationOverride", params) - result = self._conn.execute(cmd) - return result - - def set_user_agent_override(self, user_agent: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setUserAgentOverride.""" - params = { - "userAgent": user_agent, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setUserAgentOverride", params) - result = self._conn.execute(cmd) - return result - - def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_viewport_meta_override( + self, + viewport_meta: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute emulation.setViewportMetaOverride.""" params = { "viewportMeta": viewport_meta, @@ -286,19 +258,12 @@ def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: result = self._conn.execute(cmd) return result - def set_scripting_enabled(self, enabled: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setScriptingEnabled.""" - params = { - "enabled": enabled, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setScriptingEnabled", params) - result = self._conn.execute(cmd) - return result - - def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_scrollbar_type_override( + self, + scrollbar_type: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute emulation.setScrollbarTypeOverride.""" params = { "scrollbarType": scrollbar_type, @@ -310,19 +275,7 @@ def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, context result = self._conn.execute(cmd) return result - def set_timezone_override(self, timezone: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setTimezoneOverride.""" - params = { - "timezone": timezone, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setTimezoneOverride", params) - result = self._conn.execute(cmd) - return result - - def set_touch_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_touch_override(self, contexts: list[Any] | None = None, user_contexts: list[Any] | None = None): """Execute emulation.setTouchOverride.""" params = { "contexts": contexts, @@ -337,8 +290,8 @@ def set_geolocation_override( self, coordinates=None, error=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setGeolocationOverride. @@ -390,8 +343,8 @@ def set_geolocation_override( def set_timezone_override( self, timezone=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setTimezoneOverride. @@ -414,8 +367,8 @@ def set_timezone_override( def set_scripting_enabled( self, enabled=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setScriptingEnabled. @@ -438,8 +391,8 @@ def set_scripting_enabled( def set_user_agent_override( self, user_agent=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setUserAgentOverride. @@ -461,8 +414,8 @@ def set_user_agent_override( def set_screen_orientation_override( self, screen_orientation=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setScreenOrientationOverride. @@ -498,8 +451,8 @@ def set_network_conditions( self, network_conditions=None, offline: bool | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setNetworkConditions. diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 5dbe71dbd3886..13f43361293f2 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: input from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class PointerType: """PointerType.""" @@ -45,7 +44,7 @@ class PerformActionsParameters: """PerformActionsParameters.""" context: Any | None = None - actions: list[Any | None] | None = None + actions: list[Any | None] | None = field(default_factory=list) @dataclass @@ -54,7 +53,7 @@ class NoneSourceActions: type: str = field(default="none", init=False) id: str | None = None - actions: list[Any | None] | None = None + actions: list[Any | None] | None = field(default_factory=list) @dataclass @@ -63,7 +62,7 @@ class KeySourceActions: type: str = field(default="key", init=False) id: str | None = None - actions: list[Any | None] | None = None + actions: list[Any | None] | None = field(default_factory=list) @dataclass @@ -73,7 +72,7 @@ class PointerSourceActions: type: str = field(default="pointer", init=False) id: str | None = None parameters: Any | None = None - actions: list[Any | None] | None = None + actions: list[Any | None] | None = field(default_factory=list) @dataclass @@ -89,7 +88,7 @@ class WheelSourceActions: type: str = field(default="wheel", init=False) id: str | None = None - actions: list[Any | None] | None = None + actions: list[Any | None] | None = field(default_factory=list) @dataclass @@ -163,7 +162,7 @@ class SetFilesParameters: context: Any | None = None element: Any | None = None - files: list[Any | None] | None = None + files: list[Any | None] | None = field(default_factory=list) @dataclass @@ -175,7 +174,7 @@ class FileDialogInfo: multiple: bool | None = None @classmethod - def from_json(cls, params: dict) -> "FileDialogInfo": + def from_json(cls, params: dict) -> FileDialogInfo: """Deserialize event params into FileDialogInfo.""" return cls( context=params.get("context"), @@ -368,7 +367,7 @@ def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - def perform_actions(self, context: Any | None = None, actions: List[Any] | None = None): + def perform_actions(self, context: Any | None = None, actions: list[Any] | None = None): """Execute input.performActions.""" params = { "context": context, @@ -389,7 +388,7 @@ def release_actions(self, context: Any | None = None): result = self._conn.execute(cmd) return result - def set_files(self, context: Any | None = None, element: Any | None = None, files: List[Any] | None = None): + def set_files(self, context: Any | None = None, element: Any | None = None, files: list[Any] | None = None): """Execute input.setFiles.""" params = { "context": context, @@ -454,5 +453,10 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Input.EVENT_CONFIGS = { - "file_dialog_opened": (EventConfig("file_dialog_opened", "input.fileDialogOpened", _globals.get("FileDialogOpened", dict)) if _globals.get("FileDialogOpened") else EventConfig("file_dialog_opened", "input.fileDialogOpened", dict)), + "file_dialog_opened": ( + EventConfig("file_dialog_opened", "input.fileDialogOpened", + _globals.get("FileDialogOpened", dict)) + if _globals.get("FileDialogOpened") + else EventConfig("file_dialog_opened", "input.fileDialogOpened", dict) + ), } diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index faf6c85ae2b6c..7971b807e94a1 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -6,11 +6,12 @@ # WebDriver BiDi module: log from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator +import threading +from collections.abc import Callable from dataclasses import dataclass +from typing import Any + +from selenium.webdriver.common.bidi.session import Session class Level: @@ -56,7 +57,7 @@ class ConsoleLogEntry: stack_trace: Any | None = None @classmethod - def from_json(cls, params: dict) -> "ConsoleLogEntry": + def from_json(cls, params: dict) -> ConsoleLogEntry: """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -81,7 +82,7 @@ class JavascriptLogEntry: stacktrace: Any | None = None @classmethod - def from_json(cls, params: dict) -> "JavascriptLogEntry": + def from_json(cls, params: dict) -> JavascriptLogEntry: """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -92,18 +93,212 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": stacktrace=params.get("stackTrace"), ) +# BiDi Event Name to Parameter Type Mapping +EVENT_NAME_MAPPING = { + "entry_added": "log.entryAdded", +} + +@dataclass +class EventConfig: + """Configuration for a BiDi event.""" + event_key: str + bidi_event: str + event_class: type + + +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. + + Args: + params: Raw BiDi event params with camelCase keys. + + Returns: + An instance of the dataclass, or the raw dict on failure. + """ + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) + + +class _EventManager: + """Manages event subscriptions and callbacks.""" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + """Subscribe to a BiDi event if not already subscribed.""" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist.""" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: + session = Session(self.conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event]["callbacks"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() + + + + class Log: """WebDriver BiDi log module.""" + EVENT_CONFIGS = {} def __init__(self, conn) -> None: self._conn = conn + self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) + + pass + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + """Add an event handler. + + Args: + event: The event to subscribe to. + callback: The callback function to execute on event. + contexts: The context IDs to subscribe to (optional). + + Returns: + The callback ID. + """ + return self._event_manager.add_event_handler(event, callback, contexts) + + def remove_event_handler(self, event: str, callback_id: int) -> None: + """Remove an event handler. + + Args: + event: The event to unsubscribe from. + callback_id: The callback ID. + """ + return self._event_manager.remove_event_handler(event, callback_id) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + return self._event_manager.clear_event_handlers() + +# Event Info Type Aliases +# Event: log.entryAdded +EntryAdded = globals().get('Entry', dict) # Fallback to dict if type not defined - def entry_added(self): - """Execute log.entryAdded.""" - params = { - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("log.entryAdded", params) - result = self._conn.execute(cmd) - return result +# Populate EVENT_CONFIGS with event configuration mappings +_globals = globals() +Log.EVENT_CONFIGS = { + "entry_added": ( + EventConfig("entry_added", "log.entryAdded", + _globals.get("EntryAdded", dict)) + if _globals.get("EntryAdded") + else EventConfig("entry_added", "log.entryAdded", dict) + ), +} diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 4f44e309bffbb..6e02eeabc4ed7 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: network from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class SameSite: """SameSite.""" @@ -75,7 +74,7 @@ class BaseParameters: redirect_count: Any | None = None request: Any | None = None timestamp: Any | None = None - intercepts: list[Any | None] | None = None + intercepts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -171,13 +170,13 @@ class ResponseData: status: Any | None = None status_text: str | None = None from_cache: bool | None = None - headers: list[Any | None] | None = None + headers: list[Any | None] | None = field(default_factory=list) mime_type: str | None = None bytes_received: Any | None = None headers_size: Any | None = None body_size: Any | None = None content: Any | None = None - auth_challenges: list[Any | None] | None = None + auth_challenges: list[Any | None] | None = field(default_factory=list) @dataclass @@ -219,11 +218,11 @@ class UrlPatternString: class AddDataCollectorParameters: """AddDataCollectorParameters.""" - data_types: list[Any | None] | None = None + data_types: list[Any | None] | None = field(default_factory=list) max_encoded_data_size: Any | None = None collector_type: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -237,9 +236,9 @@ class AddDataCollectorResult: class AddInterceptParameters: """AddInterceptParameters.""" - phases: list[Any | None] | None = None - contexts: list[Any | None] | None = None - url_patterns: list[Any | None] | None = None + phases: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = field(default_factory=list) + url_patterns: list[Any | None] | None = field(default_factory=list) @dataclass @@ -254,9 +253,9 @@ class ContinueResponseParameters: """ContinueResponseParameters.""" request: Any | None = None - cookies: list[Any | None] | None = None + cookies: list[Any | None] | None = field(default_factory=list) credentials: Any | None = None - headers: list[Any | None] | None = None + headers: list[Any | None] | None = field(default_factory=list) reason_phrase: str | None = None status_code: Any | None = None @@ -315,8 +314,8 @@ class ProvideResponseParameters: request: Any | None = None body: Any | None = None - cookies: list[Any | None] | None = None - headers: list[Any | None] | None = None + cookies: list[Any | None] | None = field(default_factory=list) + headers: list[Any | None] | None = field(default_factory=list) reason_phrase: str | None = None status_code: Any | None = None @@ -340,16 +339,16 @@ class SetCacheBehaviorParameters: """SetCacheBehaviorParameters.""" cache_behavior: Any | None = None - contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) @dataclass class SetExtraHeadersParameters: """SetExtraHeadersParameters.""" - headers: list[Any | None] | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + headers: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -562,7 +561,14 @@ def __init__(self, conn) -> None: self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) self.intercepts = [] - def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_data_size: Any | None = None, collector_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def add_data_collector( + self, + data_types: list[Any] | None = None, + max_encoded_data_size: Any | None = None, + collector_type: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute network.addDataCollector.""" params = { "dataTypes": data_types, @@ -576,7 +582,12 @@ def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_da result = self._conn.execute(cmd) return result - def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | None = None, url_patterns: List[Any] | None = None): + def add_intercept( + self, + phases: list[Any] | None = None, + contexts: list[Any] | None = None, + url_patterns: list[Any] | None = None, + ): """Execute network.addIntercept.""" params = { "phases": phases, @@ -588,7 +599,15 @@ def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | N result = self._conn.execute(cmd) return result - def continue_request(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, method: Any | None = None, url: Any | None = None): + def continue_request( + self, + request: Any | None = None, + body: Any | None = None, + cookies: list[Any] | None = None, + headers: list[Any] | None = None, + method: Any | None = None, + url: Any | None = None, + ): """Execute network.continueRequest.""" params = { "request": request, @@ -603,7 +622,15 @@ def continue_request(self, request: Any | None = None, body: Any | None = None, result = self._conn.execute(cmd) return result - def continue_response(self, request: Any | None = None, cookies: List[Any] | None = None, credentials: Any | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): + def continue_response( + self, + request: Any | None = None, + cookies: list[Any] | None = None, + credentials: Any | None = None, + headers: list[Any] | None = None, + reason_phrase: Any | None = None, + status_code: Any | None = None, + ): """Execute network.continueResponse.""" params = { "request": request, @@ -650,7 +677,13 @@ def fail_request(self, request: Any | None = None): result = self._conn.execute(cmd) return result - def get_data(self, data_type: Any | None = None, collector: Any | None = None, disown: bool | None = None, request: Any | None = None): + def get_data( + self, + data_type: Any | None = None, + collector: Any | None = None, + disown: bool | None = None, + request: Any | None = None, + ): """Execute network.getData.""" params = { "dataType": data_type, @@ -663,7 +696,15 @@ def get_data(self, data_type: Any | None = None, collector: Any | None = None, d result = self._conn.execute(cmd) return result - def provide_response(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): + def provide_response( + self, + request: Any | None = None, + body: Any | None = None, + cookies: list[Any] | None = None, + headers: list[Any] | None = None, + reason_phrase: Any | None = None, + status_code: Any | None = None, + ): """Execute network.provideResponse.""" params = { "request": request, @@ -698,7 +739,7 @@ def remove_intercept(self, intercept: Any | None = None): result = self._conn.execute(cmd) return result - def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[Any] | None = None): + def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: list[Any] | None = None): """Execute network.setCacheBehavior.""" params = { "cacheBehavior": cache_behavior, @@ -709,7 +750,12 @@ def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[A result = self._conn.execute(cmd) return result - def set_extra_headers(self, headers: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_extra_headers( + self, + headers: list[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute network.setExtraHeaders.""" params = { "headers": headers, @@ -918,6 +964,11 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Network.EVENT_CONFIGS = { - "auth_required": (EventConfig("auth_required", "network.authRequired", _globals.get("AuthRequired", dict)) if _globals.get("AuthRequired") else EventConfig("auth_required", "network.authRequired", dict)), + "auth_required": ( + EventConfig("auth_required", "network.authRequired", + _globals.get("AuthRequired", dict)) + if _globals.get("AuthRequired") + else EventConfig("auth_required", "network.authRequired", dict) + ), "before_request": EventConfig("before_request", "network.beforeRequestSent", _globals.get("dict", dict)), } diff --git a/py/selenium/webdriver/common/bidi/permissions.py b/py/selenium/webdriver/common/bidi/permissions.py index f00e765c62e3b..6dd138da17309 100644 --- a/py/selenium/webdriver/common/bidi/permissions.py +++ b/py/selenium/webdriver/common/bidi/permissions.py @@ -20,7 +20,7 @@ from __future__ import annotations from enum import Enum -from typing import Any, Optional, Union +from typing import Any from .common import command_builder @@ -63,10 +63,10 @@ def __init__(self, websocket_connection: Any) -> None: def set_permission( self, - descriptor: Union[PermissionDescriptor, str], - state: Union[PermissionState, str], - origin: Optional[str] = None, - user_context: Optional[str] = None, + descriptor: PermissionDescriptor | str, + state: PermissionState | str, + origin: str | None = None, + user_context: str | None = None, ) -> None: """Set a permission for a given origin. diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index e13c11f71a5cb..b29721db88503 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: script from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class SpecialNumber: """SpecialNumber.""" @@ -216,7 +215,7 @@ class DedicatedWorkerRealmInfo: """DedicatedWorkerRealmInfo.""" type: str = field(default="dedicated-worker", init=False) - owners: list[Any | None] | None = None + owners: list[Any | None] | None = field(default_factory=list) @dataclass @@ -460,7 +459,7 @@ class NodeProperties: node_type: Any | None = None child_node_count: Any | None = None - children: list[Any | None] | None = None + children: list[Any | None] | None = field(default_factory=list) local_name: str | None = None mode: Any | None = None namespace_uri: str | None = None @@ -499,7 +498,7 @@ class StackFrame: class StackTrace: """StackTrace.""" - call_frames: list[Any | None] | None = None + call_frames: list[Any | None] | None = field(default_factory=list) @dataclass @@ -530,9 +529,9 @@ class AddPreloadScriptParameters: """AddPreloadScriptParameters.""" function_declaration: str | None = None - arguments: list[Any | None] | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + arguments: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) sandbox: str | None = None @@ -547,7 +546,7 @@ class AddPreloadScriptResult: class DisownParameters: """DisownParameters.""" - handles: list[Any | None] | None = None + handles: list[Any | None] | None = field(default_factory=list) target: Any | None = None @@ -558,7 +557,7 @@ class CallFunctionParameters: function_declaration: str | None = None await_promise: bool | None = None target: Any | None = None - arguments: list[Any | None] | None = None + arguments: list[Any | None] | None = field(default_factory=list) result_ownership: Any | None = None serialization_options: Any | None = None this: Any | None = None @@ -589,7 +588,7 @@ class GetRealmsParameters: class GetRealmsResult: """GetRealmsResult.""" - realms: list[Any | None] | None = None + realms: list[Any | None] | None = field(default_factory=list) @dataclass @@ -783,7 +782,14 @@ def __init__(self, conn, driver=None) -> None: self._driver = driver self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - def add_preload_script(self, function_declaration: Any | None = None, arguments: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None, sandbox: Any | None = None): + def add_preload_script( + self, + function_declaration: Any | None = None, + arguments: list[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + sandbox: Any | None = None, + ): """Execute script.addPreloadScript.""" params = { "functionDeclaration": function_declaration, @@ -797,7 +803,7 @@ def add_preload_script(self, function_declaration: Any | None = None, arguments: result = self._conn.execute(cmd) return result - def disown(self, handles: List[Any] | None = None, target: Any | None = None): + def disown(self, handles: list[Any] | None = None, target: Any | None = None): """Execute script.disown.""" params = { "handles": handles, @@ -808,7 +814,17 @@ def disown(self, handles: List[Any] | None = None, target: Any | None = None): result = self._conn.execute(cmd) return result - def call_function(self, function_declaration: Any | None = None, await_promise: bool | None = None, target: Any | None = None, arguments: List[Any] | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, this: Any | None = None, user_activation: bool | None = None): + def call_function( + self, + function_declaration: Any | None = None, + await_promise: bool | None = None, + target: Any | None = None, + arguments: list[Any] | None = None, + result_ownership: Any | None = None, + serialization_options: Any | None = None, + this: Any | None = None, + user_activation: bool | None = None, + ): """Execute script.callFunction.""" params = { "functionDeclaration": function_declaration, @@ -825,7 +841,15 @@ def call_function(self, function_declaration: Any | None = None, await_promise: result = self._conn.execute(cmd) return result - def evaluate(self, expression: Any | None = None, target: Any | None = None, await_promise: bool | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, user_activation: bool | None = None): + def evaluate( + self, + expression: Any | None = None, + target: Any | None = None, + await_promise: bool | None = None, + result_ownership: Any | None = None, + serialization_options: Any | None = None, + user_activation: bool | None = None, + ): """Execute script.evaluate.""" params = { "expression": expression, @@ -889,8 +913,9 @@ def execute(self, function_declaration: str, *args, context_id: str | None = Non Returns: The inner RemoteValue result dict, or raises WebDriverException on exception. """ - import math as _math import datetime as _datetime + import math as _math + from selenium.common.exceptions import WebDriverException as _WebDriverException def _serialize_arg(value): @@ -941,7 +966,14 @@ def _serialize_arg(value): if raw.get("type") == "success": return raw.get("result") return raw - def _add_preload_script(self, function_declaration, arguments=None, contexts=None, user_contexts=None, sandbox=None): + def _add_preload_script( + self, + function_declaration, + arguments=None, + contexts=None, + user_contexts=None, + sandbox=None, + ): """Add a preload script with validation. Args: @@ -993,7 +1025,15 @@ def unpin(self, script_id): script_id: The ID returned by pin(). """ return self._remove_preload_script(script_id=script_id) - def _evaluate(self, expression, target, await_promise, result_ownership=None, serialization_options=None, user_activation=None): + def _evaluate( + self, + expression, + target, + await_promise, + result_ownership=None, + serialization_options=None, + user_activation=None, + ): """Evaluate a script expression and return a structured result. Args: @@ -1028,7 +1068,17 @@ def __init__(self2, realm, result, exception_details): return _EvalResult(realm=realm, result=None, exception_details=exc) return _EvalResult(realm=realm, result=raw.get("result"), exception_details=None) return _EvalResult(realm=None, result=raw, exception_details=None) - def _call_function(self, function_declaration, await_promise, target, arguments=None, result_ownership=None, this=None, user_activation=None, serialization_options=None): + def _call_function( + self, + function_declaration, + await_promise, + target, + arguments=None, + result_ownership=None, + this=None, + user_activation=None, + serialization_options=None, + ): """Call a function and return a structured result. Args: @@ -1106,8 +1156,9 @@ def _disown(self, handles, target): def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading - from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod + from selenium.webdriver.common.bidi.session import Session as _Session bidi_event = "log.entryAdded" @@ -1257,6 +1308,16 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Script.EVENT_CONFIGS = { - "realm_created": (EventConfig("realm_created", "script.realmCreated", _globals.get("RealmCreated", dict)) if _globals.get("RealmCreated") else EventConfig("realm_created", "script.realmCreated", dict)), - "realm_destroyed": (EventConfig("realm_destroyed", "script.realmDestroyed", _globals.get("RealmDestroyed", dict)) if _globals.get("RealmDestroyed") else EventConfig("realm_destroyed", "script.realmDestroyed", dict)), + "realm_created": ( + EventConfig("realm_created", "script.realmCreated", + _globals.get("RealmCreated", dict)) + if _globals.get("RealmCreated") + else EventConfig("realm_created", "script.realmCreated", dict) + ), + "realm_destroyed": ( + EventConfig("realm_destroyed", "script.realmDestroyed", + _globals.get("RealmDestroyed", dict)) + if _globals.get("RealmDestroyed") + else EventConfig("realm_destroyed", "script.realmDestroyed", dict) + ), } diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index 9b1daaae557fa..c1b5be09ca024 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: session from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass class UserPromptHandlerType: @@ -26,7 +25,7 @@ class CapabilitiesRequest: """CapabilitiesRequest.""" always_match: Any | None = None - first_match: list[Any | None] | None = None + first_match: list[Any | None] | None = field(default_factory=list) @dataclass @@ -62,7 +61,7 @@ class ManualProxyConfiguration: proxy_type: str = field(default="manual", init=False) http_proxy: str | None = None ssl_proxy: str | None = None - no_proxy: list[Any | None] | None = None + no_proxy: list[Any | None] | None = field(default_factory=list) @dataclass @@ -92,23 +91,23 @@ class SystemProxyConfiguration: class SubscribeParameters: """SubscribeParameters.""" - events: list[str | None] | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + events: list[str | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass class UnsubscribeByIDRequest: """UnsubscribeByIDRequest.""" - subscriptions: list[Any | None] | None = None + subscriptions: list[Any | None] | None = field(default_factory=list) @dataclass class UnsubscribeByAttributesRequest: """UnsubscribeByAttributesRequest.""" - events: list[str | None] | None = None + events: list[str | None] | None = field(default_factory=list) @dataclass @@ -211,7 +210,12 @@ def end(self): result = self._conn.execute(cmd) return result - def subscribe(self, events: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def subscribe( + self, + events: list[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute session.subscribe.""" params = { "events": events, @@ -223,7 +227,7 @@ def subscribe(self, events: List[Any] | None = None, contexts: List[Any] | None result = self._conn.execute(cmd) return result - def unsubscribe(self, events: List[Any] | None = None, subscriptions: List[Any] | None = None): + def unsubscribe(self, events: list[Any] | None = None, subscriptions: list[Any] | None = None): """Execute session.unsubscribe.""" params = { "events": events, diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 7e4c9c6dee459..3f29b85d13a23 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: storage from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass @dataclass @@ -33,7 +32,7 @@ class GetCookiesParameters: class GetCookiesResult: """GetCookiesResult.""" - cookies: list[Any | None] | None = None + cookies: list[Any | None] | None = field(default_factory=list) partition_key: Any | None = None @@ -107,7 +106,7 @@ class StorageCookie: expiry: Any | None = None @classmethod - def from_bidi_dict(cls, raw: dict) -> "StorageCookie": + def from_bidi_dict(cls, raw: dict) -> StorageCookie: """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): @@ -235,39 +234,6 @@ class Storage: def __init__(self, conn) -> None: self._conn = conn - def get_cookies(self, filter: Any | None = None, partition: Any | None = None): - """Execute storage.getCookies.""" - params = { - "filter": filter, - "partition": partition, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("storage.getCookies", params) - result = self._conn.execute(cmd) - return result - - def set_cookie(self, cookie: Any | None = None, partition: Any | None = None): - """Execute storage.setCookie.""" - params = { - "cookie": cookie, - "partition": partition, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("storage.setCookie", params) - result = self._conn.execute(cmd) - return result - - def delete_cookies(self, filter: Any | None = None, partition: Any | None = None): - """Execute storage.deleteCookies.""" - params = { - "filter": filter, - "partition": partition, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("storage.deleteCookies", params) - result = self._conn.execute(cmd) - return result - def get_cookies(self, filter=None, partition=None): """Execute storage.getCookies and return a GetCookiesResult.""" if filter and hasattr(filter, "to_bidi_dict"): diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 8a737efeeafde..ebbe6729499b2 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: webExtension from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass @dataclass @@ -64,7 +63,12 @@ class WebExtension: def __init__(self, conn) -> None: self._conn = conn - def install(self, path: str | None = None, archive_path: str | None = None, base64_value: str | None = None): + def install( + self, + path: str | None = None, + archive_path: str | None = None, + base64_value: str | None = None, + ): """Install a web extension. Exactly one of the three keyword arguments must be provided. @@ -82,7 +86,11 @@ def install(self, path: str | None = None, archive_path: str | None = None, base Raises: ValueError: If more than one, or none, of the arguments is provided. """ - provided = [k for k, v in {"path": path, "archive_path": archive_path, "base64_value": base64_value}.items() if v is not None] + provided = [ + k for k, v in { + "path": path, "archive_path": archive_path, "base64_value": base64_value, + }.items() if v is not None + ] if len(provided) != 1: raise ValueError( f"Exactly one of path, archive_path, or base64_value must be provided; got: {provided}" From 9d40d06d0d7fae0254735ef6723fb323e01059c0 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 28 Feb 2026 08:54:27 +0000 Subject: [PATCH 03/67] fixup --- py/generate_bidi.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 2db595ff37cd0..4bf0d8b64514e 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -721,7 +721,9 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: code += "# BiDi Event Name to Parameter Type Mapping\n" code += "EVENT_NAME_MAPPING = {\n" # Collect event keys from extra_events so we skip CDDL duplicates - extra_event_keys = {evt["event_key"] for evt in enhancements.get("extra_events", [])} + extra_event_keys = { + evt["event_key"] for evt in enhancements.get("extra_events", []) + } for event_def in self.events: # Convert method name to user-friendly event name # e.g., "browsingContext.contextCreated" -> "context_created" @@ -972,7 +974,9 @@ def clear_event_handlers(self) -> None: m = re.search(r"def\s+(\w+)\s*\(", extra_meth) if m: extra_method_names.add(m.group(1)) - exclude_methods = set(enhancements.get("exclude_methods", [])) | extra_method_names + exclude_methods = ( + set(enhancements.get("exclude_methods", [])) | extra_method_names + ) if self.commands: for command in self.commands: # Get method-specific enhancements @@ -1035,7 +1039,9 @@ def clear_event_handlers(self) -> None: code += "_globals = globals()\n" code += f"{class_name}.EVENT_CONFIGS = {{\n" # Collect extra event keys to skip CDDL duplicates - extra_event_keys_cfg = {evt["event_key"] for evt in enhancements.get("extra_events", [])} + extra_event_keys_cfg = { + evt["event_key"] for evt in enhancements.get("extra_events", []) + } for event_def in self.events: # Convert method name to user-friendly event name method_parts = event_def.method.split(".") @@ -1051,7 +1057,7 @@ def clear_event_handlers(self) -> None: f' _globals.get("{event_def.name}", dict))\n' f' if _globals.get("{event_def.name}")\n' f' else EventConfig("{event_name}", "{event_def.method}", dict)\n' - f' ),\n' + f" ),\n" ) # Extra events not in the CDDL spec for extra_evt in enhancements.get("extra_events", []): @@ -1064,7 +1070,7 @@ def clear_event_handlers(self) -> None: f' "{ek}": EventConfig(\n' f' "{ek}", "{be}",\n' f' _globals.get("{ec}", dict),\n' - f' ),\n' + f" ),\n" ) else: code += single + "\n" From 030b3e603d67fd4178fdb45e5fca1770700e1195 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 28 Feb 2026 08:55:57 +0000 Subject: [PATCH 04/67] fixup --- py/selenium/webdriver/common/bidi/cdp.py | 515 +++++++++++++++++++++++ 1 file changed, 515 insertions(+) create mode 100644 py/selenium/webdriver/common/bidi/cdp.py diff --git a/py/selenium/webdriver/common/bidi/cdp.py b/py/selenium/webdriver/common/bidi/cdp.py new file mode 100644 index 0000000000000..b097762fe50cd --- /dev/null +++ b/py/selenium/webdriver/common/bidi/cdp.py @@ -0,0 +1,515 @@ +# The MIT License(MIT) +# +# Copyright(c) 2018 Hyperion Gray +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files(the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# This code comes from https://github.com/HyperionGray/trio-chrome-devtools-protocol/tree/master/trio_cdp + +import contextvars +import importlib +import itertools +import json +import logging +import pathlib +from collections import defaultdict +from collections.abc import AsyncGenerator, AsyncIterator, Generator +from contextlib import asynccontextmanager, contextmanager +from dataclasses import dataclass +from typing import Any, TypeVar + +import trio +from trio_websocket import ConnectionClosed as WsConnectionClosed +from trio_websocket import connect_websocket_url + +logger = logging.getLogger("trio_cdp") +T = TypeVar("T") +MAX_WS_MESSAGE_SIZE = 2**24 + +devtools = None +version = None + + +def import_devtools(ver): + """Attempt to load the current latest available devtools into the module cache for use later.""" + global devtools + global version + version = ver + base = "selenium.webdriver.common.devtools.v" + try: + devtools = importlib.import_module(f"{base}{ver}") + return devtools + except ModuleNotFoundError: + # Attempt to parse and load the 'most recent' devtools module. This is likely + # because cdp has been updated but selenium python has not been released yet. + devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools") + versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir()) + latest = max(int(x[1:]) for x in versions) + selenium_logger = logging.getLogger(__name__) + selenium_logger.debug("Falling back to loading `devtools`: v%s", latest) + devtools = importlib.import_module(f"{base}{latest}") + return devtools + + +_connection_context: contextvars.ContextVar = contextvars.ContextVar("connection_context") +_session_context: contextvars.ContextVar = contextvars.ContextVar("session_context") + + +def get_connection_context(fn_name): + """Look up the current connection. + + If there is no current connection, raise a ``RuntimeError`` with a + helpful message. + """ + try: + return _connection_context.get() + except LookupError: + raise RuntimeError(f"{fn_name}() must be called in a connection context.") + + +def get_session_context(fn_name): + """Look up the current session. + + If there is no current session, raise a ``RuntimeError`` with a + helpful message. + """ + try: + return _session_context.get() + except LookupError: + raise RuntimeError(f"{fn_name}() must be called in a session context.") + + +@contextmanager +def connection_context(connection): + """Context manager installs ``connection`` as the session context for the current Trio task.""" + token = _connection_context.set(connection) + try: + yield + finally: + _connection_context.reset(token) + + +@contextmanager +def session_context(session): + """Context manager installs ``session`` as the session context for the current Trio task.""" + token = _session_context.set(session) + try: + yield + finally: + _session_context.reset(token) + + +def set_global_connection(connection): + """Install ``connection`` in the root context so that it will become the default connection for all tasks. + + This is generally not recommended, except it may be necessary in + certain use cases such as running inside Jupyter notebook. + """ + global _connection_context + _connection_context = contextvars.ContextVar("_connection_context", default=connection) + + +def set_global_session(session): + """Install ``session`` in the root context so that it will become the default session for all tasks. + + This is generally not recommended, except it may be necessary in + certain use cases such as running inside Jupyter notebook. + """ + global _session_context + _session_context = contextvars.ContextVar("_session_context", default=session) + + +class BrowserError(Exception): + """This exception is raised when the browser's response to a command indicates that an error occurred.""" + + def __init__(self, obj): + self.code = obj.get("code") + self.message = obj.get("message") + self.detail = obj.get("data") + + def __str__(self): + return f"BrowserError {self.detail}" + + +class CdpConnectionClosed(WsConnectionClosed): + """Raised when a public method is called on a closed CDP connection.""" + + def __init__(self, reason): + """Constructor. + + Args: + reason: wsproto.frame_protocol.CloseReason + """ + self.reason = reason + + def __repr__(self): + """Return representation.""" + return f"{self.__class__.__name__}<{self.reason}>" + + +class InternalError(Exception): + """This exception is only raised when there is faulty logic in TrioCDP or the integration with PyCDP.""" + + pass + + +@dataclass +class CmEventProxy: + """A proxy object returned by :meth:`CdpBase.wait_for()``. + + After the context manager executes, this proxy object will have a + value set that contains the returned event. + """ + + value: Any = None + + +class CdpBase: + def __init__(self, ws, session_id, target_id): + self.ws = ws + self.session_id = session_id + self.target_id = target_id + self.channels = defaultdict(set) + self.id_iter = itertools.count() + self.inflight_cmd = {} + self.inflight_result = {} + + async def execute(self, cmd: Generator[dict, T, Any]) -> T: + """Execute a command on the server and wait for the result. + + Args: + cmd: any CDP command + + Returns: + a CDP result + """ + cmd_id = next(self.id_iter) + cmd_event = trio.Event() + self.inflight_cmd[cmd_id] = cmd, cmd_event + request = next(cmd) + request["id"] = cmd_id + if self.session_id: + request["sessionId"] = self.session_id + request_str = json.dumps(request) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"Sending CDP message: {cmd_id} {cmd_event}: {request_str}") + try: + await self.ws.send_message(request_str) + except WsConnectionClosed as wcc: + raise CdpConnectionClosed(wcc.reason) from None + await cmd_event.wait() + response = self.inflight_result.pop(cmd_id) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"Received CDP message: {response}") + if isinstance(response, Exception): + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"Exception raised by {cmd_event} message: {type(response).__name__}") + raise response + return response + + def listen(self, *event_types, buffer_size=10): + """Listen for events. + + Returns: + An async iterator that iterates over events matching the indicated types. + """ + sender, receiver = trio.open_memory_channel(buffer_size) + for event_type in event_types: + self.channels[event_type].add(sender) + return receiver + + @asynccontextmanager + async def wait_for(self, event_type: type[T], buffer_size=10) -> AsyncGenerator[CmEventProxy, None]: + """Wait for an event of the given type and return it. + + This is an async context manager, so you should open it inside + an async with block. The block will not exit until the indicated + event is received. + """ + sender: trio.MemorySendChannel + receiver: trio.MemoryReceiveChannel + sender, receiver = trio.open_memory_channel(buffer_size) + self.channels[event_type].add(sender) + proxy = CmEventProxy() + yield proxy + async with receiver: + event = await receiver.receive() + proxy.value = event + + def _handle_data(self, data): + """Handle incoming WebSocket data. + + Args: + data: a JSON dictionary + """ + if "id" in data: + self._handle_cmd_response(data) + else: + self._handle_event(data) + + def _handle_cmd_response(self, data: dict): + """Handle a response to a command. + + This will set an event flag that will return control to the + task that called the command. + + Args: + data: response as a JSON dictionary + """ + cmd_id = data["id"] + try: + cmd, event = self.inflight_cmd.pop(cmd_id) + except KeyError: + logger.warning("Got a message with a command ID that does not exist: %s", data) + return + if "error" in data: + # If the server reported an error, convert it to an exception and do + # not process the response any further. + self.inflight_result[cmd_id] = BrowserError(data["error"]) + else: + # Otherwise, continue the generator to parse the JSON result + # into a CDP object. + try: + _ = cmd.send(data["result"]) + raise InternalError("The command's generator function did not exit when expected!") + except StopIteration as exit: + return_ = exit.value + self.inflight_result[cmd_id] = return_ + event.set() + + def _handle_event(self, data: dict): + """Handle an event. + + Args: + data: event as a JSON dictionary + """ + global devtools + if devtools is None: + raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") + event = devtools.util.parse_json_event(data) + logger.debug("Received event: %s", event) + to_remove = set() + for sender in self.channels[type(event)]: + try: + sender.send_nowait(event) + except trio.WouldBlock: + logger.error('Unable to send event "%r" due to full channel %s', event, sender) + except trio.BrokenResourceError: + to_remove.add(sender) + if to_remove: + self.channels[type(event)] -= to_remove + + +class CdpSession(CdpBase): + """Contains the state for a CDP session. + + Generally you should not instantiate this object yourself; you should call + :meth:`CdpConnection.open_session`. + """ + + def __init__(self, ws, session_id, target_id): + """Constructor. + + Args: + ws: trio_websocket.WebSocketConnection + session_id: devtools.target.SessionID + target_id: devtools.target.TargetID + """ + super().__init__(ws, session_id, target_id) + + self._dom_enable_count = 0 + self._dom_enable_lock = trio.Lock() + self._page_enable_count = 0 + self._page_enable_lock = trio.Lock() + + @asynccontextmanager + async def dom_enable(self): + """Context manager that executes ``dom.enable()`` when it enters and then calls ``dom.disable()``. + + This keeps track of concurrent callers and only disables DOM + events when all callers have exited. + """ + global devtools + async with self._dom_enable_lock: + self._dom_enable_count += 1 + if self._dom_enable_count == 1: + await self.execute(devtools.dom.enable()) + + yield + + async with self._dom_enable_lock: + self._dom_enable_count -= 1 + if self._dom_enable_count == 0: + await self.execute(devtools.dom.disable()) + + @asynccontextmanager + async def page_enable(self): + """Context manager executes ``page.enable()`` when it enters and then calls ``page.disable()`` when it exits. + + This keeps track of concurrent callers and only disables page + events when all callers have exited. + """ + global devtools + async with self._page_enable_lock: + self._page_enable_count += 1 + if self._page_enable_count == 1: + await self.execute(devtools.page.enable()) + + yield + + async with self._page_enable_lock: + self._page_enable_count -= 1 + if self._page_enable_count == 0: + await self.execute(devtools.page.disable()) + + +class CdpConnection(CdpBase, trio.abc.AsyncResource): + """Contains the connection state for a Chrome DevTools Protocol server. + + CDP can multiplex multiple "sessions" over a single connection. This + class corresponds to the "root" session, i.e. the implicitly created + session that has no session ID. This class is responsible for + reading incoming WebSocket messages and forwarding them to the + corresponding session, as well as handling messages targeted at the + root session itself. You should generally call the + :func:`open_cdp()` instead of instantiating this class directly. + """ + + def __init__(self, ws): + """Constructor. + + Args: + ws: trio_websocket.WebSocketConnection + """ + super().__init__(ws, session_id=None, target_id=None) + self.sessions = {} + + async def aclose(self): + """Close the underlying WebSocket connection. + + This will cause the reader task to gracefully exit when it tries + to read the next message from the WebSocket. All of the public + APIs (``execute()``, ``listen()``, etc.) will raise + ``CdpConnectionClosed`` after the CDP connection is closed. It + is safe to call this multiple times. + """ + await self.ws.aclose() + + @asynccontextmanager + async def open_session(self, target_id) -> AsyncIterator[CdpSession]: + """Context manager opens a session and enables the "simple" style of calling CDP APIs. + + For example, inside a session context, you can call ``await + dom.get_document()`` and it will execute on the current session + automatically. + """ + session = await self.connect_session(target_id) + with session_context(session): + yield session + + async def connect_session(self, target_id) -> "CdpSession": + """Returns a new :class:`CdpSession` connected to the specified target.""" + global devtools + if devtools is None: + raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") + session_id = await self.execute(devtools.target.attach_to_target(target_id, True)) + session = CdpSession(self.ws, session_id, target_id) + self.sessions[session_id] = session + return session + + async def _reader_task(self): + """Runs in the background and handles incoming messages. + + Dispatches responses to commands and events to listeners. + """ + global devtools + if devtools is None: + raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") + while True: + try: + message = await self.ws.get_message() + except WsConnectionClosed: + # If the WebSocket is closed, we don't want to throw an + # exception from the reader task. Instead we will throw + # exceptions from the public API methods, and we can quietly + # exit the reader task here. + break + try: + data = json.loads(message) + except json.JSONDecodeError: + raise BrowserError({"code": -32700, "message": "Client received invalid JSON", "data": message}) + logger.debug("Received message %r", data) + if "sessionId" in data: + session_id = devtools.target.SessionID(data["sessionId"]) + try: + session = self.sessions[session_id] + except KeyError: + raise BrowserError( + { + "code": -32700, + "message": "Browser sent a message for an invalid session", + "data": f"{session_id!r}", + } + ) + session._handle_data(data) + else: + self._handle_data(data) + + for _, session in self.sessions.items(): + for _, senders in session.channels.items(): + for sender in senders: + sender.close() + + +@asynccontextmanager +async def open_cdp(url) -> AsyncIterator[CdpConnection]: + """Async context manager opens a connection to the browser then closes the connection when the block exits. + + The context manager also sets the connection as the default + connection for the current task, so that commands like ``await + target.get_targets()`` will run on this connection automatically. If + you want to use multiple connections concurrently, it is recommended + to open each on in a separate task. + """ + async with trio.open_nursery() as nursery: + conn = await connect_cdp(nursery, url) + try: + with connection_context(conn): + yield conn + finally: + await conn.aclose() + + +async def connect_cdp(nursery, url) -> CdpConnection: + """Connect to the browser specified by ``url`` and spawn a background task in the specified nursery. + + The ``open_cdp()`` context manager is preferred in most situations. + You should only use this function if you need to specify a custom + nursery. This connection is not automatically closed! You can either + use the connection object as a context manager (``async with + conn:``) or else call ``await conn.aclose()`` on it when you are + done with it. If ``set_context`` is True, then the returned + connection will be installed as the default connection for the + current task. This argument is for unusual use cases, such as + running inside of a notebook. + """ + ws = await connect_websocket_url(nursery, url, max_message_size=MAX_WS_MESSAGE_SIZE) + cdp_conn = CdpConnection(ws) + nursery.start_soon(cdp_conn._reader_task) + return cdp_conn From 66bf6e4e4123777689e0323b3a09e5777ff14fdf Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Mon, 2 Mar 2026 11:26:58 +0000 Subject: [PATCH 05/67] handle comments --- py/generate_bidi.py | 64 +++--- py/private/bidi_enhancements_manifest.py | 36 +++- py/selenium/webdriver/common/bidi/browser.py | 38 ++-- .../webdriver/common/bidi/browsing_context.py | 68 +++--- py/selenium/webdriver/common/bidi/common.py | 6 +- .../webdriver/common/bidi/emulation.py | 36 ++-- py/selenium/webdriver/common/bidi/input.py | 30 +-- py/selenium/webdriver/common/bidi/log.py | 4 +- py/selenium/webdriver/common/bidi/network.py | 194 ++++++++++-------- py/selenium/webdriver/common/bidi/script.py | 154 +++++++------- py/selenium/webdriver/common/bidi/session.py | 30 +-- py/selenium/webdriver/common/bidi/storage.py | 14 +- .../webdriver/common/bidi/webextension.py | 12 +- py/selenium/webdriver/remote/webdriver.py | 8 +- 14 files changed, 386 insertions(+), 308 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 4bf0d8b64514e..5d7f39e53abfc 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -368,11 +368,14 @@ def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str dataclass_methods = enhancements.get("dataclass_methods", {}) method_docstrings = enhancements.get("method_docstrings", {}) - # Generate class name from type name (keep it as-is, don't split on underscores) - class_name = self.name + # Generate class name from type name. + # CDDL type names that start with a lowercase letter (e.g. camelCase + # command-parameter types like "setNetworkConditionsParameters") are + # capitalised so that the resulting Python class follows PascalCase. + class_name = self.name[0].upper() + self.name[1:] if self.name else self.name code = "@dataclass\n" code += f"class {class_name}:\n" - code += f' """{self.description or self.name}."""\n\n' + code += f' """{class_name} type definition."""\n\n' if not self.fields: code += " pass\n" @@ -466,9 +469,9 @@ def to_python_class(self) -> str: Generates a simple class with string constants to match the existing pattern in the codebase (e.g., ClientWindowState). """ - class_name = self.name + class_name = self.name[0].upper() + self.name[1:] if self.name else self.name code = f"class {class_name}:\n" - code += f' """{self.description or self.name}."""\n\n' + code += f' """{class_name}."""\n\n' for value in self.values: # Convert value to UPPER_SNAKE_CASE constant name @@ -684,8 +687,19 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """ - # Generate enums first + # Collect names of extra_dataclasses so we can skip CDDL-generated + # enums and types that are overridden by manual definitions. + extra_cls_names = set() + for extra_cls in enhancements.get("extra_dataclasses", []): + m = re.search(r"^class\s+(\w+)", extra_cls, re.MULTILINE) + if m: + extra_cls_names.add(m.group(1)) + exclude_types = set(enhancements.get("exclude_types", [])) | extra_cls_names + + # Generate enums first, skipping any that are overridden via extra_dataclasses for enum_def in self.enums: + if enum_def.name in exclude_types: + continue code += enum_def.to_python_class() code += "\n\n" @@ -694,13 +708,6 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: code += f"{alias} = {target}\n\n" # Generate type dataclasses, skipping any overridden by extra_dataclasses - # Also auto-exclude types whose names appear in extra_dataclasses - extra_cls_names = set() - for extra_cls in enhancements.get("extra_dataclasses", []): - m = re.search(r"^class\s+(\w+)", extra_cls, re.MULTILINE) - if m: - extra_cls_names.add(m.group(1)) - exclude_types = set(enhancements.get("exclude_types", [])) | extra_cls_names for type_def in self.types: if type_def.name in exclude_types: continue @@ -1146,8 +1153,12 @@ def _remove_comments(self, content: str) -> str: def _extract_definitions(self, content: str) -> None: """Extract CDDL definitions (type definitions, commands, etc.).""" # Match pattern: Name = Definition - # Handles multiline definitions properly - pattern = r"(\w+(?:\.\w+)*)\s*=\s*(.+?)(?=\n\w+(?:\.\w+)?\s*=|\Z)" + # Handles multiline definitions properly. + # The \s* after \n in the lookahead allows definitions that start with + # leading whitespace (e.g. " network.BeforeRequestSent = (") to be + # recognised as separate definitions instead of being swallowed into + # the body of the preceding definition. + pattern = r"(\w+(?:\.\w+)*)\s*=\s*(.+?)(?=\n\s*\w+(?:\.\w+)?\s*=|\Z)" for match in re.finditer(pattern, content, re.DOTALL): name = match.group(1).strip() @@ -1589,12 +1600,15 @@ def generate_common_file(output_path: Path) -> None: "\n" '"""Common utilities for BiDi command construction."""\n' "\n" - "from typing import Any, Dict, Generator\n" + "from __future__ import annotations\n" + "\n" + "from collections.abc import Generator\n" + "from typing import Any\n" "\n" "\n" "def command_builder(\n" - " method: str, params: Dict[str, Any]\n" - ") -> Generator[Dict[str, Any], Any, Any]:\n" + " method: str, params: dict[str, Any] | None = None\n" + ") -> Generator[dict[str, Any], Any, Any]:\n" ' """Build a BiDi command generator.\n' "\n" " Args:\n" @@ -1607,6 +1621,8 @@ def generate_common_file(output_path: Path) -> None: " Returns:\n" " The result from the BiDi command execution\n" ' """\n' + " if params is None:\n" + " params = {}\n" ' result = yield {"method": method, "params": params}\n' " return result\n" ) @@ -1680,8 +1696,10 @@ def generate_permissions_file(output_path: Path) -> None: "\n" "from __future__ import annotations\n" "\n" + "from __future__ import annotations\n" + "\n" "from enum import Enum\n" - "from typing import Any, Optional, Union\n" + "from typing import Any\n" "\n" "from .common import command_builder\n" "\n" @@ -1724,10 +1742,10 @@ def generate_permissions_file(output_path: Path) -> None: "\n" " def set_permission(\n" " self,\n" - " descriptor: Union[PermissionDescriptor, str],\n" - " state: Union[PermissionState, str],\n" - " origin: Optional[str] = None,\n" - " user_context: Optional[str] = None,\n" + " descriptor: PermissionDescriptor | str,\n" + " state: PermissionState | str,\n" + " origin: str | None = None,\n" + " user_context: str | None = None,\n" " ) -> None:\n" ' """Set a permission for a given origin.\n' "\n" diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 39af67d4c635b..adf0a17128af3 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -81,6 +81,20 @@ "result_param": "download_behavior", }, }, + # Replace the auto-generated ClientWindowNamedState so we can add the + # convenience NORMAL constant. In the BiDi spec "normal" is the state + # represented by ClientWindowRectState, but exposing it here keeps the + # Python API consistent with the old ClientWindowState enum. + "exclude_types": ["ClientWindowNamedState"], + "extra_dataclasses": [ + '''class ClientWindowNamedState: + """Named states for a browser client window.""" + + FULLSCREEN = "fullscreen" + MAXIMIZED = "maximized" + MINIMIZED = "minimized" + NORMAL = "normal"''', + ], # Override the generator-produced set_download_behavior so that # downloadBehavior is never stripped by the generic None filter. # The BiDi spec marks it as required (can be null, but must be present). @@ -845,8 +859,11 @@ def from_json(self2, p): ], }, "network": { - # Initialize intercepts tracking list in __init__ - "extra_init_code": ["self.intercepts = []"], + # Initialize intercepts tracking list and per-handler intercept map + "extra_init_code": [ + "self.intercepts = []", + "self._handler_intercepts: dict = {}", + ], # Request class wraps a beforeRequestSent event params and provides actions "extra_dataclasses": [ '''class BytesValue: @@ -940,7 +957,8 @@ def continue_request(self, **kwargs): "auth_required": "authRequired", } phase = phase_map.get(event, "beforeRequestSent") - self._add_intercept(phases=[phase], url_patterns=url_patterns) + intercept_result = self._add_intercept(phases=[phase], url_patterns=url_patterns) + intercept_id = intercept_result.get("intercept") if intercept_result else None def _request_callback(params): raw = ( @@ -951,15 +969,21 @@ def _request_callback(params): request = Request(self._conn, raw) callback(request) - return self.add_event_handler(event, _request_callback)''', + callback_id = self.add_event_handler(event, _request_callback) + if intercept_id: + self._handler_intercepts[callback_id] = intercept_id + return callback_id''', ''' def remove_request_handler(self, event, callback_id): - """Remove a network request handler. + """Remove a network request handler and its associated network intercept. Args: event: The event name used when adding the handler. callback_id: The int returned by add_request_handler. """ - self.remove_event_handler(event, callback_id)''', + self.remove_event_handler(event, callback_id) + intercept_id = self._handler_intercepts.pop(callback_id, None) + if intercept_id: + self._remove_intercept(intercept_id)''', ''' def clear_request_handlers(self): """Clear all request handlers and remove all tracked intercepts.""" self.clear_event_handlers() diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index acda63f71953e..71f917634304d 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -60,17 +60,9 @@ def validate_download_behavior( raise ValueError("destination_folder should not be provided when allowed=False") -class ClientWindowNamedState: - """ClientWindowNamedState.""" - - FULLSCREEN = "fullscreen" - MAXIMIZED = "maximized" - MINIMIZED = "minimized" - - @dataclass class ClientWindowInfo: - """ClientWindowInfo.""" + """ClientWindowInfo type definition.""" active: bool | None = None client_window: Any | None = None @@ -112,14 +104,14 @@ def get_y(self): @dataclass class UserContextInfo: - """UserContextInfo.""" + """UserContextInfo type definition.""" user_context: Any | None = None @dataclass class CreateUserContextParameters: - """CreateUserContextParameters.""" + """CreateUserContextParameters type definition.""" accept_insecure_certs: bool | None = None proxy: Any | None = None @@ -128,35 +120,35 @@ class CreateUserContextParameters: @dataclass class GetClientWindowsResult: - """GetClientWindowsResult.""" + """GetClientWindowsResult type definition.""" client_windows: list[Any | None] | None = field(default_factory=list) @dataclass class GetUserContextsResult: - """GetUserContextsResult.""" + """GetUserContextsResult type definition.""" user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass class RemoveUserContextParameters: - """RemoveUserContextParameters.""" + """RemoveUserContextParameters type definition.""" user_context: Any | None = None @dataclass class SetClientWindowStateParameters: - """SetClientWindowStateParameters.""" + """SetClientWindowStateParameters type definition.""" client_window: Any | None = None @dataclass class ClientWindowRectState: - """ClientWindowRectState.""" + """ClientWindowRectState type definition.""" state: str = field(default="normal", init=False) width: Any | None = None @@ -167,7 +159,7 @@ class ClientWindowRectState: @dataclass class SetDownloadBehaviorParameters: - """SetDownloadBehaviorParameters.""" + """SetDownloadBehaviorParameters type definition.""" download_behavior: Any | None = None user_contexts: list[Any | None] | None = field(default_factory=list) @@ -175,7 +167,7 @@ class SetDownloadBehaviorParameters: @dataclass class DownloadBehaviorAllowed: - """DownloadBehaviorAllowed.""" + """DownloadBehaviorAllowed type definition.""" type: str = field(default="allowed", init=False) destination_folder: str | None = None @@ -183,11 +175,19 @@ class DownloadBehaviorAllowed: @dataclass class DownloadBehaviorDenied: - """DownloadBehaviorDenied.""" + """DownloadBehaviorDenied type definition.""" type: str = field(default="denied", init=False) +class ClientWindowNamedState: + """Named states for a browser client window.""" + + FULLSCREEN = "fullscreen" + MAXIMIZED = "maximized" + MINIMIZED = "minimized" + NORMAL = "normal" + class Browser: """WebDriver BiDi browser module.""" diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 5f128635df29d..ede96071778c3 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -48,7 +48,7 @@ class DownloadCompleteParams: @dataclass class Info: - """Info.""" + """Info type definition.""" children: Any | None = None client_window: Any | None = None @@ -61,7 +61,7 @@ class Info: @dataclass class AccessibilityLocator: - """AccessibilityLocator.""" + """AccessibilityLocator type definition.""" type: str = field(default="accessibility", init=False) name: str | None = None @@ -70,7 +70,7 @@ class AccessibilityLocator: @dataclass class CssLocator: - """CssLocator.""" + """CssLocator type definition.""" type: str = field(default="css", init=False) value: str | None = None @@ -78,7 +78,7 @@ class CssLocator: @dataclass class ContextLocator: - """ContextLocator.""" + """ContextLocator type definition.""" type: str = field(default="context", init=False) context: Any | None = None @@ -86,7 +86,7 @@ class ContextLocator: @dataclass class InnerTextLocator: - """InnerTextLocator.""" + """InnerTextLocator type definition.""" type: str = field(default="innerText", init=False) value: str | None = None @@ -97,7 +97,7 @@ class InnerTextLocator: @dataclass class XPathLocator: - """XPathLocator.""" + """XPathLocator type definition.""" type: str = field(default="xpath", init=False) value: str | None = None @@ -105,7 +105,7 @@ class XPathLocator: @dataclass class BaseNavigationInfo: - """BaseNavigationInfo.""" + """BaseNavigationInfo type definition.""" context: Any | None = None navigation: Any | None = None @@ -115,14 +115,14 @@ class BaseNavigationInfo: @dataclass class ActivateParameters: - """ActivateParameters.""" + """ActivateParameters type definition.""" context: Any | None = None @dataclass class CaptureScreenshotParameters: - """CaptureScreenshotParameters.""" + """CaptureScreenshotParameters type definition.""" context: Any | None = None format: Any | None = None @@ -131,7 +131,7 @@ class CaptureScreenshotParameters: @dataclass class ImageFormat: - """ImageFormat.""" + """ImageFormat type definition.""" type: str | None = None quality: Any | None = None @@ -139,7 +139,7 @@ class ImageFormat: @dataclass class ElementClipRectangle: - """ElementClipRectangle.""" + """ElementClipRectangle type definition.""" type: str = field(default="element", init=False) element: Any | None = None @@ -147,7 +147,7 @@ class ElementClipRectangle: @dataclass class BoxClipRectangle: - """BoxClipRectangle.""" + """BoxClipRectangle type definition.""" type: str = field(default="box", init=False) x: Any | None = None @@ -158,14 +158,14 @@ class BoxClipRectangle: @dataclass class CaptureScreenshotResult: - """CaptureScreenshotResult.""" + """CaptureScreenshotResult type definition.""" data: str | None = None @dataclass class CloseParameters: - """CloseParameters.""" + """CloseParameters type definition.""" context: Any | None = None prompt_unload: bool | None = None @@ -173,7 +173,7 @@ class CloseParameters: @dataclass class CreateParameters: - """CreateParameters.""" + """CreateParameters type definition.""" type: Any | None = None reference_context: Any | None = None @@ -183,14 +183,14 @@ class CreateParameters: @dataclass class CreateResult: - """CreateResult.""" + """CreateResult type definition.""" context: Any | None = None @dataclass class GetTreeParameters: - """GetTreeParameters.""" + """GetTreeParameters type definition.""" max_depth: Any | None = None root: Any | None = None @@ -198,14 +198,14 @@ class GetTreeParameters: @dataclass class GetTreeResult: - """GetTreeResult.""" + """GetTreeResult type definition.""" contexts: Any | None = None @dataclass class HandleUserPromptParameters: - """HandleUserPromptParameters.""" + """HandleUserPromptParameters type definition.""" context: Any | None = None accept: bool | None = None @@ -214,7 +214,7 @@ class HandleUserPromptParameters: @dataclass class LocateNodesParameters: - """LocateNodesParameters.""" + """LocateNodesParameters type definition.""" context: Any | None = None locator: Any | None = None @@ -224,14 +224,14 @@ class LocateNodesParameters: @dataclass class LocateNodesResult: - """LocateNodesResult.""" + """LocateNodesResult type definition.""" nodes: list[Any | None] | None = field(default_factory=list) @dataclass class NavigateParameters: - """NavigateParameters.""" + """NavigateParameters type definition.""" context: Any | None = None url: str | None = None @@ -240,7 +240,7 @@ class NavigateParameters: @dataclass class NavigateResult: - """NavigateResult.""" + """NavigateResult type definition.""" navigation: Any | None = None url: str | None = None @@ -248,7 +248,7 @@ class NavigateResult: @dataclass class PrintParameters: - """PrintParameters.""" + """PrintParameters type definition.""" context: Any | None = None background: bool | None = None @@ -260,7 +260,7 @@ class PrintParameters: @dataclass class PrintMarginParameters: - """PrintMarginParameters.""" + """PrintMarginParameters type definition.""" bottom: Any | None = None left: Any | None = None @@ -270,7 +270,7 @@ class PrintMarginParameters: @dataclass class PrintPageParameters: - """PrintPageParameters.""" + """PrintPageParameters type definition.""" height: Any | None = None width: Any | None = None @@ -278,14 +278,14 @@ class PrintPageParameters: @dataclass class PrintResult: - """PrintResult.""" + """PrintResult type definition.""" data: str | None = None @dataclass class ReloadParameters: - """ReloadParameters.""" + """ReloadParameters type definition.""" context: Any | None = None ignore_cache: bool | None = None @@ -294,7 +294,7 @@ class ReloadParameters: @dataclass class SetViewportParameters: - """SetViewportParameters.""" + """SetViewportParameters type definition.""" context: Any | None = None viewport: Any | None = None @@ -304,7 +304,7 @@ class SetViewportParameters: @dataclass class Viewport: - """Viewport.""" + """Viewport type definition.""" width: Any | None = None height: Any | None = None @@ -312,7 +312,7 @@ class Viewport: @dataclass class TraverseHistoryParameters: - """TraverseHistoryParameters.""" + """TraverseHistoryParameters type definition.""" context: Any | None = None delta: Any | None = None @@ -320,7 +320,7 @@ class TraverseHistoryParameters: @dataclass class HistoryUpdatedParameters: - """HistoryUpdatedParameters.""" + """HistoryUpdatedParameters type definition.""" context: Any | None = None timestamp: Any | None = None @@ -329,7 +329,7 @@ class HistoryUpdatedParameters: @dataclass class UserPromptClosedParameters: - """UserPromptClosedParameters.""" + """UserPromptClosedParameters type definition.""" context: Any | None = None accepted: bool | None = None @@ -339,7 +339,7 @@ class UserPromptClosedParameters: @dataclass class UserPromptOpenedParameters: - """UserPromptOpenedParameters.""" + """UserPromptOpenedParameters type definition.""" context: Any | None = None handler: Any | None = None diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index d7cb436a08471..dae051876833e 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -17,12 +17,14 @@ """Common utilities for BiDi command construction.""" +from __future__ import annotations + from collections.abc import Generator from typing import Any def command_builder( - method: str, params: dict[str, Any] + method: str, params: dict[str, Any] | None = None ) -> Generator[dict[str, Any], Any, Any]: """Build a BiDi command generator. @@ -36,5 +38,7 @@ def command_builder( Returns: The result from the BiDi command execution """ + if params is None: + params = {} result = yield {"method": method, "params": params} return result diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index cb575bbdc54dd..fbbe0966d8b3a 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -37,7 +37,7 @@ class ScreenOrientationType: @dataclass class SetForcedColorsModeThemeOverrideParameters: - """SetForcedColorsModeThemeOverrideParameters.""" + """SetForcedColorsModeThemeOverrideParameters type definition.""" theme: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -46,7 +46,7 @@ class SetForcedColorsModeThemeOverrideParameters: @dataclass class SetGeolocationOverrideParameters: - """SetGeolocationOverrideParameters.""" + """SetGeolocationOverrideParameters type definition.""" contexts: list[Any | None] | None = field(default_factory=list) user_contexts: list[Any | None] | None = field(default_factory=list) @@ -54,7 +54,7 @@ class SetGeolocationOverrideParameters: @dataclass class GeolocationCoordinates: - """GeolocationCoordinates.""" + """GeolocationCoordinates type definition.""" latitude: Any | None = None longitude: Any | None = None @@ -67,14 +67,14 @@ class GeolocationCoordinates: @dataclass class GeolocationPositionError: - """GeolocationPositionError.""" + """GeolocationPositionError type definition.""" type: str = field(default="positionUnavailable", init=False) @dataclass class SetLocaleOverrideParameters: - """SetLocaleOverrideParameters.""" + """SetLocaleOverrideParameters type definition.""" locale: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -82,8 +82,8 @@ class SetLocaleOverrideParameters: @dataclass -class setNetworkConditionsParameters: - """setNetworkConditionsParameters.""" +class SetNetworkConditionsParameters: + """SetNetworkConditionsParameters type definition.""" network_conditions: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -92,14 +92,14 @@ class setNetworkConditionsParameters: @dataclass class NetworkConditionsOffline: - """NetworkConditionsOffline.""" + """NetworkConditionsOffline type definition.""" type: str = field(default="offline", init=False) @dataclass class ScreenArea: - """ScreenArea.""" + """ScreenArea type definition.""" width: Any | None = None height: Any | None = None @@ -107,7 +107,7 @@ class ScreenArea: @dataclass class SetScreenSettingsOverrideParameters: - """SetScreenSettingsOverrideParameters.""" + """SetScreenSettingsOverrideParameters type definition.""" screen_area: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -116,7 +116,7 @@ class SetScreenSettingsOverrideParameters: @dataclass class ScreenOrientation: - """ScreenOrientation.""" + """ScreenOrientation type definition.""" natural: Any | None = None type: Any | None = None @@ -124,7 +124,7 @@ class ScreenOrientation: @dataclass class SetScreenOrientationOverrideParameters: - """SetScreenOrientationOverrideParameters.""" + """SetScreenOrientationOverrideParameters type definition.""" screen_orientation: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -133,7 +133,7 @@ class SetScreenOrientationOverrideParameters: @dataclass class SetUserAgentOverrideParameters: - """SetUserAgentOverrideParameters.""" + """SetUserAgentOverrideParameters type definition.""" user_agent: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -142,7 +142,7 @@ class SetUserAgentOverrideParameters: @dataclass class SetViewportMetaOverrideParameters: - """SetViewportMetaOverrideParameters.""" + """SetViewportMetaOverrideParameters type definition.""" viewport_meta: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -151,7 +151,7 @@ class SetViewportMetaOverrideParameters: @dataclass class SetScriptingEnabledParameters: - """SetScriptingEnabledParameters.""" + """SetScriptingEnabledParameters type definition.""" enabled: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -160,7 +160,7 @@ class SetScriptingEnabledParameters: @dataclass class SetScrollbarTypeOverrideParameters: - """SetScrollbarTypeOverrideParameters.""" + """SetScrollbarTypeOverrideParameters type definition.""" scrollbar_type: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -169,7 +169,7 @@ class SetScrollbarTypeOverrideParameters: @dataclass class SetTimezoneOverrideParameters: - """SetTimezoneOverrideParameters.""" + """SetTimezoneOverrideParameters type definition.""" timezone: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -178,7 +178,7 @@ class SetTimezoneOverrideParameters: @dataclass class SetTouchOverrideParameters: - """SetTouchOverrideParameters.""" + """SetTouchOverrideParameters type definition.""" contexts: list[Any | None] | None = field(default_factory=list) user_contexts: list[Any | None] | None = field(default_factory=list) diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 13f43361293f2..c8e58181b343e 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -33,7 +33,7 @@ class Origin: @dataclass class ElementOrigin: - """ElementOrigin.""" + """ElementOrigin type definition.""" type: str = field(default="element", init=False) element: Any | None = None @@ -41,7 +41,7 @@ class ElementOrigin: @dataclass class PerformActionsParameters: - """PerformActionsParameters.""" + """PerformActionsParameters type definition.""" context: Any | None = None actions: list[Any | None] | None = field(default_factory=list) @@ -49,7 +49,7 @@ class PerformActionsParameters: @dataclass class NoneSourceActions: - """NoneSourceActions.""" + """NoneSourceActions type definition.""" type: str = field(default="none", init=False) id: str | None = None @@ -58,7 +58,7 @@ class NoneSourceActions: @dataclass class KeySourceActions: - """KeySourceActions.""" + """KeySourceActions type definition.""" type: str = field(default="key", init=False) id: str | None = None @@ -67,7 +67,7 @@ class KeySourceActions: @dataclass class PointerSourceActions: - """PointerSourceActions.""" + """PointerSourceActions type definition.""" type: str = field(default="pointer", init=False) id: str | None = None @@ -77,14 +77,14 @@ class PointerSourceActions: @dataclass class PointerParameters: - """PointerParameters.""" + """PointerParameters type definition.""" pointer_type: Any | None = None @dataclass class WheelSourceActions: - """WheelSourceActions.""" + """WheelSourceActions type definition.""" type: str = field(default="wheel", init=False) id: str | None = None @@ -93,7 +93,7 @@ class WheelSourceActions: @dataclass class PauseAction: - """PauseAction.""" + """PauseAction type definition.""" type: str = field(default="pause", init=False) duration: Any | None = None @@ -101,7 +101,7 @@ class PauseAction: @dataclass class KeyDownAction: - """KeyDownAction.""" + """KeyDownAction type definition.""" type: str = field(default="keyDown", init=False) value: str | None = None @@ -109,7 +109,7 @@ class KeyDownAction: @dataclass class KeyUpAction: - """KeyUpAction.""" + """KeyUpAction type definition.""" type: str = field(default="keyUp", init=False) value: str | None = None @@ -117,7 +117,7 @@ class KeyUpAction: @dataclass class PointerUpAction: - """PointerUpAction.""" + """PointerUpAction type definition.""" type: str = field(default="pointerUp", init=False) button: Any | None = None @@ -125,7 +125,7 @@ class PointerUpAction: @dataclass class WheelScrollAction: - """WheelScrollAction.""" + """WheelScrollAction type definition.""" type: str = field(default="scroll", init=False) x: Any | None = None @@ -138,7 +138,7 @@ class WheelScrollAction: @dataclass class PointerCommonProperties: - """PointerCommonProperties.""" + """PointerCommonProperties type definition.""" width: Any | None = None height: Any | None = None @@ -151,14 +151,14 @@ class PointerCommonProperties: @dataclass class ReleaseActionsParameters: - """ReleaseActionsParameters.""" + """ReleaseActionsParameters type definition.""" context: Any | None = None @dataclass class SetFilesParameters: - """SetFilesParameters.""" + """SetFilesParameters type definition.""" context: Any | None = None element: Any | None = None diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 7971b807e94a1..eaf52a2ec08c2 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -27,7 +27,7 @@ class Level: @dataclass class BaseLogEntry: - """BaseLogEntry.""" + """BaseLogEntry type definition.""" level: Any | None = None source: Any | None = None @@ -38,7 +38,7 @@ class BaseLogEntry: @dataclass class GenericLogEntry: - """GenericLogEntry.""" + """GenericLogEntry type definition.""" type: str | None = None diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 6e02eeabc4ed7..c9737ac9131d0 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -49,7 +49,7 @@ class ContinueWithAuthNoCredentials: @dataclass class AuthChallenge: - """AuthChallenge.""" + """AuthChallenge type definition.""" scheme: str | None = None realm: str | None = None @@ -57,7 +57,7 @@ class AuthChallenge: @dataclass class AuthCredentials: - """AuthCredentials.""" + """AuthCredentials type definition.""" type: str = field(default="password", init=False) username: str | None = None @@ -66,7 +66,7 @@ class AuthCredentials: @dataclass class BaseParameters: - """BaseParameters.""" + """BaseParameters type definition.""" context: Any | None = None is_blocked: bool | None = None @@ -79,7 +79,7 @@ class BaseParameters: @dataclass class StringValue: - """StringValue.""" + """StringValue type definition.""" type: str = field(default="string", init=False) value: str | None = None @@ -87,7 +87,7 @@ class StringValue: @dataclass class Base64Value: - """Base64Value.""" + """Base64Value type definition.""" type: str = field(default="base64", init=False) value: str | None = None @@ -95,7 +95,7 @@ class Base64Value: @dataclass class Cookie: - """Cookie.""" + """Cookie type definition.""" name: str | None = None value: Any | None = None @@ -110,7 +110,7 @@ class Cookie: @dataclass class CookieHeader: - """CookieHeader.""" + """CookieHeader type definition.""" name: str | None = None value: Any | None = None @@ -118,7 +118,7 @@ class CookieHeader: @dataclass class FetchTimingInfo: - """FetchTimingInfo.""" + """FetchTimingInfo type definition.""" time_origin: Any | None = None request_time: Any | None = None @@ -137,7 +137,7 @@ class FetchTimingInfo: @dataclass class Header: - """Header.""" + """Header type definition.""" name: str | None = None value: Any | None = None @@ -145,7 +145,7 @@ class Header: @dataclass class Initiator: - """Initiator.""" + """Initiator type definition.""" column_number: Any | None = None line_number: Any | None = None @@ -156,14 +156,14 @@ class Initiator: @dataclass class ResponseContent: - """ResponseContent.""" + """ResponseContent type definition.""" size: Any | None = None @dataclass class ResponseData: - """ResponseData.""" + """ResponseData type definition.""" url: str | None = None protocol: str | None = None @@ -181,7 +181,7 @@ class ResponseData: @dataclass class SetCookieHeader: - """SetCookieHeader.""" + """SetCookieHeader type definition.""" name: str | None = None value: Any | None = None @@ -196,7 +196,7 @@ class SetCookieHeader: @dataclass class UrlPatternPattern: - """UrlPatternPattern.""" + """UrlPatternPattern type definition.""" type: str = field(default="pattern", init=False) protocol: str | None = None @@ -208,7 +208,7 @@ class UrlPatternPattern: @dataclass class UrlPatternString: - """UrlPatternString.""" + """UrlPatternString type definition.""" type: str = field(default="string", init=False) pattern: str | None = None @@ -216,7 +216,7 @@ class UrlPatternString: @dataclass class AddDataCollectorParameters: - """AddDataCollectorParameters.""" + """AddDataCollectorParameters type definition.""" data_types: list[Any | None] | None = field(default_factory=list) max_encoded_data_size: Any | None = None @@ -227,14 +227,14 @@ class AddDataCollectorParameters: @dataclass class AddDataCollectorResult: - """AddDataCollectorResult.""" + """AddDataCollectorResult type definition.""" collector: Any | None = None @dataclass class AddInterceptParameters: - """AddInterceptParameters.""" + """AddInterceptParameters type definition.""" phases: list[Any | None] | None = field(default_factory=list) contexts: list[Any | None] | None = field(default_factory=list) @@ -243,14 +243,14 @@ class AddInterceptParameters: @dataclass class AddInterceptResult: - """AddInterceptResult.""" + """AddInterceptResult type definition.""" intercept: Any | None = None @dataclass class ContinueResponseParameters: - """ContinueResponseParameters.""" + """ContinueResponseParameters type definition.""" request: Any | None = None cookies: list[Any | None] | None = field(default_factory=list) @@ -262,22 +262,22 @@ class ContinueResponseParameters: @dataclass class ContinueWithAuthParameters: - """ContinueWithAuthParameters.""" + """ContinueWithAuthParameters type definition.""" request: Any | None = None @dataclass class ContinueWithAuthCredentials: - """ContinueWithAuthCredentials.""" + """ContinueWithAuthCredentials type definition.""" action: str = field(default="provideCredentials", init=False) credentials: Any | None = None @dataclass -class disownDataParameters: - """disownDataParameters.""" +class DisownDataParameters: + """DisownDataParameters type definition.""" data_type: Any | None = None collector: Any | None = None @@ -286,14 +286,14 @@ class disownDataParameters: @dataclass class FailRequestParameters: - """FailRequestParameters.""" + """FailRequestParameters type definition.""" request: Any | None = None @dataclass class GetDataParameters: - """GetDataParameters.""" + """GetDataParameters type definition.""" data_type: Any | None = None collector: Any | None = None @@ -303,14 +303,14 @@ class GetDataParameters: @dataclass class GetDataResult: - """GetDataResult.""" + """GetDataResult type definition.""" bytes: Any | None = None @dataclass class ProvideResponseParameters: - """ProvideResponseParameters.""" + """ProvideResponseParameters type definition.""" request: Any | None = None body: Any | None = None @@ -322,21 +322,21 @@ class ProvideResponseParameters: @dataclass class RemoveDataCollectorParameters: - """RemoveDataCollectorParameters.""" + """RemoveDataCollectorParameters type definition.""" collector: Any | None = None @dataclass class RemoveInterceptParameters: - """RemoveInterceptParameters.""" + """RemoveInterceptParameters type definition.""" intercept: Any | None = None @dataclass class SetCacheBehaviorParameters: - """SetCacheBehaviorParameters.""" + """SetCacheBehaviorParameters type definition.""" cache_behavior: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -344,16 +344,44 @@ class SetCacheBehaviorParameters: @dataclass class SetExtraHeadersParameters: - """SetExtraHeadersParameters.""" + """SetExtraHeadersParameters type definition.""" headers: list[Any | None] | None = field(default_factory=list) contexts: list[Any | None] | None = field(default_factory=list) user_contexts: list[Any | None] | None = field(default_factory=list) +@dataclass +class AuthRequiredParameters: + """AuthRequiredParameters type definition.""" + + response: Any | None = None + + +@dataclass +class BeforeRequestSentParameters: + """BeforeRequestSentParameters type definition.""" + + initiator: Any | None = None + + +@dataclass +class FetchErrorParameters: + """FetchErrorParameters type definition.""" + + error_text: str | None = None + + +@dataclass +class ResponseCompletedParameters: + """ResponseCompletedParameters type definition.""" + + response: Any | None = None + + @dataclass class ResponseStartedParameters: - """ResponseStartedParameters.""" + """ResponseStartedParameters type definition.""" response: Any | None = None @@ -396,6 +424,10 @@ def continue_request(self, **kwargs): # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { "auth_required": "network.authRequired", + "before_request_sent": "network.beforeRequestSent", + "fetch_error": "network.fetchError", + "response_completed": "network.responseCompleted", + "response_started": "network.responseStarted", "before_request": "network.beforeRequestSent", } @@ -560,6 +592,7 @@ def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) self.intercepts = [] + self._handler_intercepts: dict = {} def add_data_collector( self, @@ -767,52 +800,6 @@ def set_extra_headers( result = self._conn.execute(cmd) return result - def before_request_sent(self, initiator: Any | None = None, method: Any | None = None, params: Any | None = None): - """Execute network.beforeRequestSent.""" - params = { - "initiator": initiator, - "method": method, - "params": params, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.beforeRequestSent", params) - result = self._conn.execute(cmd) - return result - - def fetch_error(self, error_text: Any | None = None, method: Any | None = None, params: Any | None = None): - """Execute network.fetchError.""" - params = { - "errorText": error_text, - "method": method, - "params": params, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.fetchError", params) - result = self._conn.execute(cmd) - return result - - def response_completed(self, response: Any | None = None, method: Any | None = None, params: Any | None = None): - """Execute network.responseCompleted.""" - params = { - "response": response, - "method": method, - "params": params, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.responseCompleted", params) - result = self._conn.execute(cmd) - return result - - def response_started(self, response: Any | None = None): - """Execute network.responseStarted.""" - params = { - "response": response, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.responseStarted", params) - result = self._conn.execute(cmd) - return result - def _add_intercept(self, phases=None, url_patterns=None): """Add a low-level network intercept. @@ -861,7 +848,8 @@ def add_request_handler(self, event, callback, url_patterns=None): "auth_required": "authRequired", } phase = phase_map.get(event, "beforeRequestSent") - self._add_intercept(phases=[phase], url_patterns=url_patterns) + intercept_result = self._add_intercept(phases=[phase], url_patterns=url_patterns) + intercept_id = intercept_result.get("intercept") if intercept_result else None def _request_callback(params): raw = ( @@ -872,15 +860,21 @@ def _request_callback(params): request = Request(self._conn, raw) callback(request) - return self.add_event_handler(event, _request_callback) + callback_id = self.add_event_handler(event, _request_callback) + if intercept_id: + self._handler_intercepts[callback_id] = intercept_id + return callback_id def remove_request_handler(self, event, callback_id): - """Remove a network request handler. + """Remove a network request handler and its associated network intercept. Args: event: The event name used when adding the handler. callback_id: The int returned by add_request_handler. """ self.remove_event_handler(event, callback_id) + intercept_id = self._handler_intercepts.pop(callback_id, None) + if intercept_id: + self._remove_intercept(intercept_id) def clear_request_handlers(self): """Clear all request handlers and remove all tracked intercepts.""" self.clear_event_handlers() @@ -960,6 +954,18 @@ def clear_event_handlers(self) -> None: # Event: network.authRequired AuthRequired = globals().get('AuthRequiredParameters', dict) # Fallback to dict if type not defined +# Event: network.beforeRequestSent +BeforeRequestSent = globals().get('BeforeRequestSentParameters', dict) # Fallback to dict if type not defined + +# Event: network.fetchError +FetchError = globals().get('FetchErrorParameters', dict) # Fallback to dict if type not defined + +# Event: network.responseCompleted +ResponseCompleted = globals().get('ResponseCompletedParameters', dict) # Fallback to dict if type not defined + +# Event: network.responseStarted +ResponseStarted = globals().get('ResponseStartedParameters', dict) # Fallback to dict if type not defined + # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() @@ -970,5 +976,29 @@ def clear_event_handlers(self) -> None: if _globals.get("AuthRequired") else EventConfig("auth_required", "network.authRequired", dict) ), + "before_request_sent": ( + EventConfig("before_request_sent", "network.beforeRequestSent", + _globals.get("BeforeRequestSent", dict)) + if _globals.get("BeforeRequestSent") + else EventConfig("before_request_sent", "network.beforeRequestSent", dict) + ), + "fetch_error": ( + EventConfig("fetch_error", "network.fetchError", + _globals.get("FetchError", dict)) + if _globals.get("FetchError") + else EventConfig("fetch_error", "network.fetchError", dict) + ), + "response_completed": ( + EventConfig("response_completed", "network.responseCompleted", + _globals.get("ResponseCompleted", dict)) + if _globals.get("ResponseCompleted") + else EventConfig("response_completed", "network.responseCompleted", dict) + ), + "response_started": ( + EventConfig("response_started", "network.responseStarted", + _globals.get("ResponseStarted", dict)) + if _globals.get("ResponseStarted") + else EventConfig("response_started", "network.responseStarted", dict) + ), "before_request": EventConfig("before_request", "network.beforeRequestSent", _globals.get("dict", dict)), } diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index b29721db88503..061bb17b0deec 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -47,7 +47,7 @@ class ResultOwnership: @dataclass class ChannelValue: - """ChannelValue.""" + """ChannelValue type definition.""" type: str = field(default="channel", init=False) value: Any | None = None @@ -55,7 +55,7 @@ class ChannelValue: @dataclass class ChannelProperties: - """ChannelProperties.""" + """ChannelProperties type definition.""" channel: Any | None = None serialization_options: Any | None = None @@ -64,7 +64,7 @@ class ChannelProperties: @dataclass class EvaluateResultSuccess: - """EvaluateResultSuccess.""" + """EvaluateResultSuccess type definition.""" type: str = field(default="success", init=False) result: Any | None = None @@ -73,7 +73,7 @@ class EvaluateResultSuccess: @dataclass class EvaluateResultException: - """EvaluateResultException.""" + """EvaluateResultException type definition.""" type: str = field(default="exception", init=False) exception_details: Any | None = None @@ -82,7 +82,7 @@ class EvaluateResultException: @dataclass class ExceptionDetails: - """ExceptionDetails.""" + """ExceptionDetails type definition.""" column_number: Any | None = None exception: Any | None = None @@ -93,7 +93,7 @@ class ExceptionDetails: @dataclass class ArrayLocalValue: - """ArrayLocalValue.""" + """ArrayLocalValue type definition.""" type: str = field(default="array", init=False) value: Any | None = None @@ -101,7 +101,7 @@ class ArrayLocalValue: @dataclass class DateLocalValue: - """DateLocalValue.""" + """DateLocalValue type definition.""" type: str = field(default="date", init=False) value: str | None = None @@ -109,7 +109,7 @@ class DateLocalValue: @dataclass class MapLocalValue: - """MapLocalValue.""" + """MapLocalValue type definition.""" type: str = field(default="map", init=False) value: Any | None = None @@ -117,7 +117,7 @@ class MapLocalValue: @dataclass class ObjectLocalValue: - """ObjectLocalValue.""" + """ObjectLocalValue type definition.""" type: str = field(default="object", init=False) value: Any | None = None @@ -125,7 +125,7 @@ class ObjectLocalValue: @dataclass class RegExpValue: - """RegExpValue.""" + """RegExpValue type definition.""" pattern: str | None = None flags: str | None = None @@ -133,7 +133,7 @@ class RegExpValue: @dataclass class RegExpLocalValue: - """RegExpLocalValue.""" + """RegExpLocalValue type definition.""" type: str = field(default="regexp", init=False) value: Any | None = None @@ -141,7 +141,7 @@ class RegExpLocalValue: @dataclass class SetLocalValue: - """SetLocalValue.""" + """SetLocalValue type definition.""" type: str = field(default="set", init=False) value: Any | None = None @@ -149,21 +149,21 @@ class SetLocalValue: @dataclass class UndefinedValue: - """UndefinedValue.""" + """UndefinedValue type definition.""" type: str = field(default="undefined", init=False) @dataclass class NullValue: - """NullValue.""" + """NullValue type definition.""" type: str = field(default="null", init=False) @dataclass class StringValue: - """StringValue.""" + """StringValue type definition.""" type: str = field(default="string", init=False) value: str | None = None @@ -171,7 +171,7 @@ class StringValue: @dataclass class NumberValue: - """NumberValue.""" + """NumberValue type definition.""" type: str = field(default="number", init=False) value: Any | None = None @@ -179,7 +179,7 @@ class NumberValue: @dataclass class BooleanValue: - """BooleanValue.""" + """BooleanValue type definition.""" type: str = field(default="boolean", init=False) value: bool | None = None @@ -187,7 +187,7 @@ class BooleanValue: @dataclass class BigIntValue: - """BigIntValue.""" + """BigIntValue type definition.""" type: str = field(default="bigint", init=False) value: str | None = None @@ -195,7 +195,7 @@ class BigIntValue: @dataclass class BaseRealmInfo: - """BaseRealmInfo.""" + """BaseRealmInfo type definition.""" realm: Any | None = None origin: str | None = None @@ -203,7 +203,7 @@ class BaseRealmInfo: @dataclass class WindowRealmInfo: - """WindowRealmInfo.""" + """WindowRealmInfo type definition.""" type: str = field(default="window", init=False) context: Any | None = None @@ -212,7 +212,7 @@ class WindowRealmInfo: @dataclass class DedicatedWorkerRealmInfo: - """DedicatedWorkerRealmInfo.""" + """DedicatedWorkerRealmInfo type definition.""" type: str = field(default="dedicated-worker", init=False) owners: list[Any | None] | None = field(default_factory=list) @@ -220,49 +220,49 @@ class DedicatedWorkerRealmInfo: @dataclass class SharedWorkerRealmInfo: - """SharedWorkerRealmInfo.""" + """SharedWorkerRealmInfo type definition.""" type: str = field(default="shared-worker", init=False) @dataclass class ServiceWorkerRealmInfo: - """ServiceWorkerRealmInfo.""" + """ServiceWorkerRealmInfo type definition.""" type: str = field(default="service-worker", init=False) @dataclass class WorkerRealmInfo: - """WorkerRealmInfo.""" + """WorkerRealmInfo type definition.""" type: str = field(default="worker", init=False) @dataclass class PaintWorkletRealmInfo: - """PaintWorkletRealmInfo.""" + """PaintWorkletRealmInfo type definition.""" type: str = field(default="paint-worklet", init=False) @dataclass class AudioWorkletRealmInfo: - """AudioWorkletRealmInfo.""" + """AudioWorkletRealmInfo type definition.""" type: str = field(default="audio-worklet", init=False) @dataclass class WorkletRealmInfo: - """WorkletRealmInfo.""" + """WorkletRealmInfo type definition.""" type: str = field(default="worklet", init=False) @dataclass class SharedReference: - """SharedReference.""" + """SharedReference type definition.""" shared_id: Any | None = None handle: Any | None = None @@ -270,7 +270,7 @@ class SharedReference: @dataclass class RemoteObjectReference: - """RemoteObjectReference.""" + """RemoteObjectReference type definition.""" handle: Any | None = None shared_id: Any | None = None @@ -278,7 +278,7 @@ class RemoteObjectReference: @dataclass class SymbolRemoteValue: - """SymbolRemoteValue.""" + """SymbolRemoteValue type definition.""" type: str = field(default="symbol", init=False) handle: Any | None = None @@ -287,7 +287,7 @@ class SymbolRemoteValue: @dataclass class ArrayRemoteValue: - """ArrayRemoteValue.""" + """ArrayRemoteValue type definition.""" type: str = field(default="array", init=False) handle: Any | None = None @@ -297,7 +297,7 @@ class ArrayRemoteValue: @dataclass class ObjectRemoteValue: - """ObjectRemoteValue.""" + """ObjectRemoteValue type definition.""" type: str = field(default="object", init=False) handle: Any | None = None @@ -307,7 +307,7 @@ class ObjectRemoteValue: @dataclass class FunctionRemoteValue: - """FunctionRemoteValue.""" + """FunctionRemoteValue type definition.""" type: str = field(default="function", init=False) handle: Any | None = None @@ -316,7 +316,7 @@ class FunctionRemoteValue: @dataclass class RegExpRemoteValue: - """RegExpRemoteValue.""" + """RegExpRemoteValue type definition.""" handle: Any | None = None internal_id: Any | None = None @@ -324,7 +324,7 @@ class RegExpRemoteValue: @dataclass class DateRemoteValue: - """DateRemoteValue.""" + """DateRemoteValue type definition.""" handle: Any | None = None internal_id: Any | None = None @@ -332,7 +332,7 @@ class DateRemoteValue: @dataclass class MapRemoteValue: - """MapRemoteValue.""" + """MapRemoteValue type definition.""" type: str = field(default="map", init=False) handle: Any | None = None @@ -342,7 +342,7 @@ class MapRemoteValue: @dataclass class SetRemoteValue: - """SetRemoteValue.""" + """SetRemoteValue type definition.""" type: str = field(default="set", init=False) handle: Any | None = None @@ -352,7 +352,7 @@ class SetRemoteValue: @dataclass class WeakMapRemoteValue: - """WeakMapRemoteValue.""" + """WeakMapRemoteValue type definition.""" type: str = field(default="weakmap", init=False) handle: Any | None = None @@ -361,7 +361,7 @@ class WeakMapRemoteValue: @dataclass class WeakSetRemoteValue: - """WeakSetRemoteValue.""" + """WeakSetRemoteValue type definition.""" type: str = field(default="weakset", init=False) handle: Any | None = None @@ -370,7 +370,7 @@ class WeakSetRemoteValue: @dataclass class GeneratorRemoteValue: - """GeneratorRemoteValue.""" + """GeneratorRemoteValue type definition.""" type: str = field(default="generator", init=False) handle: Any | None = None @@ -379,7 +379,7 @@ class GeneratorRemoteValue: @dataclass class ErrorRemoteValue: - """ErrorRemoteValue.""" + """ErrorRemoteValue type definition.""" type: str = field(default="error", init=False) handle: Any | None = None @@ -388,7 +388,7 @@ class ErrorRemoteValue: @dataclass class ProxyRemoteValue: - """ProxyRemoteValue.""" + """ProxyRemoteValue type definition.""" type: str = field(default="proxy", init=False) handle: Any | None = None @@ -397,7 +397,7 @@ class ProxyRemoteValue: @dataclass class PromiseRemoteValue: - """PromiseRemoteValue.""" + """PromiseRemoteValue type definition.""" type: str = field(default="promise", init=False) handle: Any | None = None @@ -406,7 +406,7 @@ class PromiseRemoteValue: @dataclass class TypedArrayRemoteValue: - """TypedArrayRemoteValue.""" + """TypedArrayRemoteValue type definition.""" type: str = field(default="typedarray", init=False) handle: Any | None = None @@ -415,7 +415,7 @@ class TypedArrayRemoteValue: @dataclass class ArrayBufferRemoteValue: - """ArrayBufferRemoteValue.""" + """ArrayBufferRemoteValue type definition.""" type: str = field(default="arraybuffer", init=False) handle: Any | None = None @@ -424,7 +424,7 @@ class ArrayBufferRemoteValue: @dataclass class NodeListRemoteValue: - """NodeListRemoteValue.""" + """NodeListRemoteValue type definition.""" type: str = field(default="nodelist", init=False) handle: Any | None = None @@ -434,7 +434,7 @@ class NodeListRemoteValue: @dataclass class HTMLCollectionRemoteValue: - """HTMLCollectionRemoteValue.""" + """HTMLCollectionRemoteValue type definition.""" type: str = field(default="htmlcollection", init=False) handle: Any | None = None @@ -444,7 +444,7 @@ class HTMLCollectionRemoteValue: @dataclass class NodeRemoteValue: - """NodeRemoteValue.""" + """NodeRemoteValue type definition.""" type: str = field(default="node", init=False) shared_id: Any | None = None @@ -455,7 +455,7 @@ class NodeRemoteValue: @dataclass class NodeProperties: - """NodeProperties.""" + """NodeProperties type definition.""" node_type: Any | None = None child_node_count: Any | None = None @@ -469,7 +469,7 @@ class NodeProperties: @dataclass class WindowProxyRemoteValue: - """WindowProxyRemoteValue.""" + """WindowProxyRemoteValue type definition.""" type: str = field(default="window", init=False) value: Any | None = None @@ -479,14 +479,14 @@ class WindowProxyRemoteValue: @dataclass class WindowProxyProperties: - """WindowProxyProperties.""" + """WindowProxyProperties type definition.""" context: Any | None = None @dataclass class StackFrame: - """StackFrame.""" + """StackFrame type definition.""" column_number: Any | None = None function_name: str | None = None @@ -496,14 +496,14 @@ class StackFrame: @dataclass class StackTrace: - """StackTrace.""" + """StackTrace type definition.""" call_frames: list[Any | None] | None = field(default_factory=list) @dataclass class Source: - """Source.""" + """Source type definition.""" realm: Any | None = None context: Any | None = None @@ -511,14 +511,14 @@ class Source: @dataclass class RealmTarget: - """RealmTarget.""" + """RealmTarget type definition.""" realm: Any | None = None @dataclass class ContextTarget: - """ContextTarget.""" + """ContextTarget type definition.""" context: Any | None = None sandbox: str | None = None @@ -526,7 +526,7 @@ class ContextTarget: @dataclass class AddPreloadScriptParameters: - """AddPreloadScriptParameters.""" + """AddPreloadScriptParameters type definition.""" function_declaration: str | None = None arguments: list[Any | None] | None = field(default_factory=list) @@ -537,14 +537,14 @@ class AddPreloadScriptParameters: @dataclass class AddPreloadScriptResult: - """AddPreloadScriptResult.""" + """AddPreloadScriptResult type definition.""" script: Any | None = None @dataclass class DisownParameters: - """DisownParameters.""" + """DisownParameters type definition.""" handles: list[Any | None] | None = field(default_factory=list) target: Any | None = None @@ -552,7 +552,7 @@ class DisownParameters: @dataclass class CallFunctionParameters: - """CallFunctionParameters.""" + """CallFunctionParameters type definition.""" function_declaration: str | None = None await_promise: bool | None = None @@ -566,7 +566,7 @@ class CallFunctionParameters: @dataclass class EvaluateParameters: - """EvaluateParameters.""" + """EvaluateParameters type definition.""" expression: str | None = None target: Any | None = None @@ -578,7 +578,7 @@ class EvaluateParameters: @dataclass class GetRealmsParameters: - """GetRealmsParameters.""" + """GetRealmsParameters type definition.""" context: Any | None = None type: Any | None = None @@ -586,21 +586,21 @@ class GetRealmsParameters: @dataclass class GetRealmsResult: - """GetRealmsResult.""" + """GetRealmsResult type definition.""" realms: list[Any | None] | None = field(default_factory=list) @dataclass class RemovePreloadScriptParameters: - """RemovePreloadScriptParameters.""" + """RemovePreloadScriptParameters type definition.""" script: Any | None = None @dataclass class MessageParameters: - """MessageParameters.""" + """MessageParameters type definition.""" channel: Any | None = None data: Any | None = None @@ -609,13 +609,14 @@ class MessageParameters: @dataclass class RealmDestroyedParameters: - """RealmDestroyedParameters.""" + """RealmDestroyedParameters type definition.""" realm: Any | None = None # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { + "message": "script.message", "realm_created": "script.realmCreated", "realm_destroyed": "script.realmDestroyed", } @@ -885,18 +886,6 @@ def remove_preload_script(self, script: Any | None = None): result = self._conn.execute(cmd) return result - def message(self, channel: Any | None = None, data: Any | None = None, source: Any | None = None): - """Execute script.message.""" - params = { - "channel": channel, - "data": data, - "source": source, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("script.message", params) - result = self._conn.execute(cmd) - return result - def execute(self, function_declaration: str, *args, context_id: str | None = None) -> Any: """Execute a function declaration in the browser context. @@ -1298,6 +1287,9 @@ def clear_event_handlers(self) -> None: return self._event_manager.clear_event_handlers() # Event Info Type Aliases +# Event: script.message +Message = globals().get('MessageParameters', dict) # Fallback to dict if type not defined + # Event: script.realmCreated RealmCreated = globals().get('RealmInfo', dict) # Fallback to dict if type not defined @@ -1308,6 +1300,12 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Script.EVENT_CONFIGS = { + "message": ( + EventConfig("message", "script.message", + _globals.get("Message", dict)) + if _globals.get("Message") + else EventConfig("message", "script.message", dict) + ), "realm_created": ( EventConfig("realm_created", "script.realmCreated", _globals.get("RealmCreated", dict)) diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index c1b5be09ca024..da12c1cd49792 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -22,7 +22,7 @@ class UserPromptHandlerType: @dataclass class CapabilitiesRequest: - """CapabilitiesRequest.""" + """CapabilitiesRequest type definition.""" always_match: Any | None = None first_match: list[Any | None] | None = field(default_factory=list) @@ -30,7 +30,7 @@ class CapabilitiesRequest: @dataclass class CapabilityRequest: - """CapabilityRequest.""" + """CapabilityRequest type definition.""" accept_insecure_certs: bool | None = None browser_name: str | None = None @@ -42,21 +42,21 @@ class CapabilityRequest: @dataclass class AutodetectProxyConfiguration: - """AutodetectProxyConfiguration.""" + """AutodetectProxyConfiguration type definition.""" proxy_type: str = field(default="autodetect", init=False) @dataclass class DirectProxyConfiguration: - """DirectProxyConfiguration.""" + """DirectProxyConfiguration type definition.""" proxy_type: str = field(default="direct", init=False) @dataclass class ManualProxyConfiguration: - """ManualProxyConfiguration.""" + """ManualProxyConfiguration type definition.""" proxy_type: str = field(default="manual", init=False) http_proxy: str | None = None @@ -66,7 +66,7 @@ class ManualProxyConfiguration: @dataclass class SocksProxyConfiguration: - """SocksProxyConfiguration.""" + """SocksProxyConfiguration type definition.""" socks_proxy: str | None = None socks_version: Any | None = None @@ -74,7 +74,7 @@ class SocksProxyConfiguration: @dataclass class PacProxyConfiguration: - """PacProxyConfiguration.""" + """PacProxyConfiguration type definition.""" proxy_type: str = field(default="pac", init=False) proxy_autoconfig_url: str | None = None @@ -82,14 +82,14 @@ class PacProxyConfiguration: @dataclass class SystemProxyConfiguration: - """SystemProxyConfiguration.""" + """SystemProxyConfiguration type definition.""" proxy_type: str = field(default="system", init=False) @dataclass class SubscribeParameters: - """SubscribeParameters.""" + """SubscribeParameters type definition.""" events: list[str | None] | None = field(default_factory=list) contexts: list[Any | None] | None = field(default_factory=list) @@ -98,21 +98,21 @@ class SubscribeParameters: @dataclass class UnsubscribeByIDRequest: - """UnsubscribeByIDRequest.""" + """UnsubscribeByIDRequest type definition.""" subscriptions: list[Any | None] | None = field(default_factory=list) @dataclass class UnsubscribeByAttributesRequest: - """UnsubscribeByAttributesRequest.""" + """UnsubscribeByAttributesRequest type definition.""" events: list[str | None] | None = field(default_factory=list) @dataclass class StatusResult: - """StatusResult.""" + """StatusResult type definition.""" ready: bool | None = None message: str | None = None @@ -120,14 +120,14 @@ class StatusResult: @dataclass class NewParameters: - """NewParameters.""" + """NewParameters type definition.""" capabilities: Any | None = None @dataclass class NewResult: - """NewResult.""" + """NewResult type definition.""" session_id: str | None = None accept_insecure_certs: bool | None = None @@ -143,7 +143,7 @@ class NewResult: @dataclass class SubscribeResult: - """SubscribeResult.""" + """SubscribeResult type definition.""" subscription: Any | None = None diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 3f29b85d13a23..c5a4666ebaf07 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -14,7 +14,7 @@ @dataclass class PartitionKey: - """PartitionKey.""" + """PartitionKey type definition.""" user_context: str | None = None source_origin: str | None = None @@ -22,7 +22,7 @@ class PartitionKey: @dataclass class GetCookiesParameters: - """GetCookiesParameters.""" + """GetCookiesParameters type definition.""" filter: Any | None = None partition: Any | None = None @@ -30,7 +30,7 @@ class GetCookiesParameters: @dataclass class GetCookiesResult: - """GetCookiesResult.""" + """GetCookiesResult type definition.""" cookies: list[Any | None] | None = field(default_factory=list) partition_key: Any | None = None @@ -38,7 +38,7 @@ class GetCookiesResult: @dataclass class SetCookieParameters: - """SetCookieParameters.""" + """SetCookieParameters type definition.""" cookie: Any | None = None partition: Any | None = None @@ -46,14 +46,14 @@ class SetCookieParameters: @dataclass class SetCookieResult: - """SetCookieResult.""" + """SetCookieResult type definition.""" partition_key: Any | None = None @dataclass class DeleteCookiesParameters: - """DeleteCookiesParameters.""" + """DeleteCookiesParameters type definition.""" filter: Any | None = None partition: Any | None = None @@ -61,7 +61,7 @@ class DeleteCookiesParameters: @dataclass class DeleteCookiesResult: - """DeleteCookiesResult.""" + """DeleteCookiesResult type definition.""" partition_key: Any | None = None diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index ebbe6729499b2..0a3998a611125 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -14,14 +14,14 @@ @dataclass class InstallParameters: - """InstallParameters.""" + """InstallParameters type definition.""" extension_data: Any | None = None @dataclass class ExtensionPath: - """ExtensionPath.""" + """ExtensionPath type definition.""" type: str = field(default="path", init=False) path: str | None = None @@ -29,7 +29,7 @@ class ExtensionPath: @dataclass class ExtensionArchivePath: - """ExtensionArchivePath.""" + """ExtensionArchivePath type definition.""" type: str = field(default="archivePath", init=False) path: str | None = None @@ -37,7 +37,7 @@ class ExtensionArchivePath: @dataclass class ExtensionBase64Encoded: - """ExtensionBase64Encoded.""" + """ExtensionBase64Encoded type definition.""" type: str = field(default="base64", init=False) value: str | None = None @@ -45,14 +45,14 @@ class ExtensionBase64Encoded: @dataclass class InstallResult: - """InstallResult.""" + """InstallResult type definition.""" extension: Any | None = None @dataclass class UninstallParameters: - """UninstallParameters.""" + """UninstallParameters type definition.""" extension: Any | None = None diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index 3b9402f7b547e..979415bd7a6e1 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -446,8 +446,12 @@ def execute( """ # Handle BiDi generator commands if inspect.isgenerator(driver_command): - # BiDi command: use WebSocketConnection directly - return self.command_executor.execute(driver_command) + # BiDi command: route through the WebSocket connection, not the + # HTTP RemoteConnection which only accepts (command, params) pairs. + if not self._websocket_connection: + self._start_bidi() + assert self._websocket_connection is not None + return self._websocket_connection.execute(driver_command) # Legacy WebDriver command: handle normally params = self._wrap_value(params) From 0b60f4543a27ba41dad5c2c6d594b40c3c3493ec Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Mon, 2 Mar 2026 13:55:49 +0000 Subject: [PATCH 06/67] [py] Fix Copilot review: license headers, _BiDiEncoder nested types, revert unrelated requirements changes --- py/generate_bidi.py | 19 ++++++++++++++++++- py/requirements.txt | 1 - py/requirements_lock.txt | 5 ++++- py/selenium/webdriver/common/bidi/__init__.py | 17 +++++++++++++++++ py/selenium/webdriver/common/bidi/browser.py | 17 +++++++++++++++++ .../webdriver/common/bidi/browsing_context.py | 17 +++++++++++++++++ .../webdriver/common/bidi/emulation.py | 17 +++++++++++++++++ py/selenium/webdriver/common/bidi/input.py | 17 +++++++++++++++++ py/selenium/webdriver/common/bidi/log.py | 17 +++++++++++++++++ py/selenium/webdriver/common/bidi/network.py | 17 +++++++++++++++++ py/selenium/webdriver/common/bidi/script.py | 17 +++++++++++++++++ py/selenium/webdriver/common/bidi/session.py | 17 +++++++++++++++++ py/selenium/webdriver/common/bidi/storage.py | 17 +++++++++++++++++ .../webdriver/common/bidi/webextension.py | 17 +++++++++++++++++ .../webdriver/remote/websocket_connection.py | 14 ++++++++++++-- 15 files changed, 221 insertions(+), 5 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 5d7f39e53abfc..412494517772a 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -32,7 +32,24 @@ logger = logging.getLogger("generate_bidi") # File headers -SHARED_HEADER = """# DO NOT EDIT THIS FILE! +SHARED_HEADER = """# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make # changes, edit the generator and regenerate all of the modules.""" diff --git a/py/requirements.txt b/py/requirements.txt index 5f943fdd24f91..fe7abe214f2e5 100644 --- a/py/requirements.txt +++ b/py/requirements.txt @@ -2,7 +2,6 @@ async-generator==1.10 attrs==25.4.0 backports.tarfile==1.2.0 cachetools==7.0.1 - certifi==2026.1.4 cffi==2.0.0 chardet==5.2.0 diff --git a/py/requirements_lock.txt b/py/requirements_lock.txt index 66fa883ec4a53..222690e162f3c 100644 --- a/py/requirements_lock.txt +++ b/py/requirements_lock.txt @@ -416,6 +416,7 @@ jeepney==0.9.0 \ --hash=sha256:cf0e9e845622b81e4a28df94c40345400256ec608d0e55bb8a3feaa9163f5732 # via # -r py/requirements.txt + # keyring # secretstorage jinja2==3.1.6 \ --hash=sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d \ @@ -964,7 +965,9 @@ rich==14.3.2 \ secretstorage==3.5.0 \ --hash=sha256:0ce65888c0725fcb2c5bc0fdb8e5438eece02c523557ea40ce0703c266248137 \ --hash=sha256:f04b8e4689cbce351744d5537bf6b1329c6fc68f91fa666f60a380edddcd11be - # via -r py/requirements.txt + # via + # -r py/requirements.txt + # keyring sniffio==1.3.1 \ --hash=sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2 \ --hash=sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc diff --git a/py/selenium/webdriver/common/bidi/__init__.py b/py/selenium/webdriver/common/bidi/__init__.py index 7be7bd4f73856..bb129d5f6a195 100644 --- a/py/selenium/webdriver/common/bidi/__init__.py +++ b/py/selenium/webdriver/common/bidi/__init__.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 71f917634304d..ff0c2d59b8cf2 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index ede96071778c3..7a0f8faf8687e 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index fbbe0966d8b3a..c58f6d5f78d6c 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index c8e58181b343e..e9c3f8345f05d 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index eaf52a2ec08c2..94f511d7185f8 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index c9737ac9131d0..9dc5fb94d8488 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 061bb17b0deec..0b2ec04101933 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index da12c1cd49792..771a5327151bf 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index c5a4666ebaf07..7623381706040 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 0a3998a611125..99250afca4c68 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/remote/websocket_connection.py b/py/selenium/webdriver/remote/websocket_connection.py index 68358e4a09974..8d6f745d4ac5b 100644 --- a/py/selenium/webdriver/remote/websocket_connection.py +++ b/py/selenium/webdriver/remote/websocket_connection.py @@ -41,6 +41,16 @@ class _BiDiEncoder(json.JSONEncoder): directly into its parent action dict as required by the BiDi spec. """ + def _convert(self, value): + """Recursively convert a value, handling nested dataclasses, lists, and dicts.""" + if dataclasses.is_dataclass(value) and not isinstance(value, type): + return self.default(value) + if isinstance(value, list): + return [self._convert(item) for item in value] + if isinstance(value, dict): + return {k: self._convert(v) for k, v in value.items()} + return value + def default(self, o): if dataclasses.is_dataclass(o) and not isinstance(o, type): result = {} @@ -54,9 +64,9 @@ def default(self, o): for pf in dataclasses.fields(value): pv = getattr(value, pf.name) if pv is not None: - result[_snake_to_camel(pf.name)] = pv + result[_snake_to_camel(pf.name)] = self._convert(pv) else: - result[camel_key] = value + result[camel_key] = self._convert(value) return result return super().default(o) From ec82adf18f50aabf1ab2c675090ac57671aefbd3 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 11:48:05 +0000 Subject: [PATCH 07/67] fixup --- py/generate_bidi.py | 1134 +---------------- .../webdriver/common/bidi/_event_manager.py | 186 +++ .../webdriver/remote/websocket_connection.py | 13 +- 3 files changed, 199 insertions(+), 1134 deletions(-) create mode 100644 py/selenium/webdriver/common/bidi/_event_manager.py diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 412494517772a..8103cafe40684 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -619,11 +619,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: # Add imports for event handling if needed if self.events: - code += "import threading\n" - code += "from collections.abc import Callable\n" - if not dataclass_imported: - code += "from dataclasses import dataclass\n" - code += "from selenium.webdriver.common.bidi.session import Session\n" + code += "from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager\n" code += "\n\n" @@ -801,1131 +797,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """ code += "\n\n" - # Generate EventConfig and _EventManager for modules with events - if self.events: - # Generate EventConfig dataclass - code += """@dataclass -class EventConfig: - \"\"\"Configuration for a BiDi event.\"\"\" - event_key: str - bidi_event: str - event_class: type - - -""" - - # Generate _EventManager class - code += """class _EventWrapper: - \"\"\"Wrapper to provide event_class attribute for WebSocketConnection callbacks.\"\"\" - def __init__(self, bidi_event: str, event_class: type): - self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class - self._python_class = event_class # Keep reference to Python dataclass for deserialization - - def from_json(self, params: dict) -> Any: - \"\"\"Deserialize event params into the wrapped Python dataclass. - - Args: - params: Raw BiDi event params with camelCase keys. - - Returns: - An instance of the dataclass, or the raw dict on failure. - \"\"\" - if self._python_class is None or self._python_class is dict: - return params - try: - # Delegate to a classmethod from_json if the class defines one - if hasattr(self._python_class, \"from_json\") and callable( - self._python_class.from_json - ): - return self._python_class.from_json(params) - import dataclasses as dc - - snake_params = {self._camel_to_snake(k): v for k, v in params.items()} - if dc.is_dataclass(self._python_class): - valid_fields = {f.name for f in dc.fields(self._python_class)} - filtered = {k: v for k, v in snake_params.items() if k in valid_fields} - return self._python_class(**filtered) - return self._python_class(**snake_params) - except Exception: - return params - - @staticmethod - def _camel_to_snake(name: str) -> str: - result = [name[0].lower()] - for char in name[1:]: - if char.isupper(): - result.extend([\"_\", char.lower()]) - else: - result.append(char) - return \"\".join(result) - - -class _EventManager: - \"\"\"Manages event subscriptions and callbacks.\"\"\" - - def __init__(self, conn, event_configs: dict[str, EventConfig]): - self.conn = conn - self.event_configs = event_configs - self.subscriptions: dict = {} - self._event_wrappers = {} # Cache of _EventWrapper objects - self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} - self._available_events = ", ".join(sorted(event_configs.keys())) - self._subscription_lock = threading.Lock() - - # Create event wrappers for each event - for config in event_configs.values(): - wrapper = _EventWrapper(config.bidi_event, config.event_class) - self._event_wrappers[config.bidi_event] = wrapper - - def validate_event(self, event: str) -> EventConfig: - event_config = self.event_configs.get(event) - if not event_config: - raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") - return event_config - - def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: - \"\"\"Subscribe to a BiDi event if not already subscribed.\"\"\" - with self._subscription_lock: - if bidi_event not in self.subscriptions: - session = Session(self.conn) - result = session.subscribe([bidi_event], contexts=contexts) - sub_id = ( - result.get(\"subscription\") if isinstance(result, dict) else None - ) - self.subscriptions[bidi_event] = { - \"callbacks\": [], - \"subscription_id\": sub_id, - } - - def unsubscribe_from_event(self, bidi_event: str) -> None: - \"\"\"Unsubscribe from a BiDi event if no more callbacks exist.\"\"\" - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry is not None and not entry[\"callbacks\"]: - session = Session(self.conn) - sub_id = entry.get(\"subscription_id\") - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - del self.subscriptions[bidi_event] - - def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - self.subscriptions[bidi_event][\"callbacks\"].append(callback_id) - - def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry and callback_id in entry[\"callbacks\"]: - entry[\"callbacks\"].remove(callback_id) - - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - event_config = self.validate_event(event) - # Use the event wrapper for add_callback - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - callback_id = self.conn.add_callback(event_wrapper, callback) - self.subscribe_to_event(event_config.bidi_event, contexts) - self.add_callback_to_tracking(event_config.bidi_event, callback_id) - return callback_id - - def remove_event_handler(self, event: str, callback_id: int) -> None: - event_config = self.validate_event(event) - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - self.conn.remove_callback(event_wrapper, callback_id) - self.remove_callback_from_tracking(event_config.bidi_event, callback_id) - self.unsubscribe_from_event(event_config.bidi_event) - - def clear_event_handlers(self) -> None: - \"\"\"Clear all event handlers.\"\"\" - with self._subscription_lock: - if not self.subscriptions: - return - session = Session(self.conn) - for bidi_event, entry in list(self.subscriptions.items()): - event_wrapper = self._event_wrappers.get(bidi_event) - callbacks = entry[\"callbacks\"] if isinstance(entry, dict) else entry - if event_wrapper: - for callback_id in callbacks: - self.conn.remove_callback(event_wrapper, callback_id) - sub_id = ( - entry.get(\"subscription_id\") if isinstance(entry, dict) else None - ) - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - self.subscriptions.clear() - - -""" - code += "\n\n" + # EventConfig, _EventWrapper, and _EventManager are imported from + # ._event_manager (see the import block above); nothing to emit here. # Generate class - # Convert module name (camelCase or snake_case) to proper class name (PascalCase) - class_name = module_name_to_class_name(self.name) - code += f"class {class_name}:\n" - code += f' """WebDriver BiDi {self.name} module."""\n\n' - - # Add EVENT_CONFIGS dict if there are events - if self.events: - code += ( - " EVENT_CONFIGS = {}\n" # Will be populated after types are defined - ) - - if self.name == "script": - code += " def __init__(self, conn, driver=None) -> None:\n" - code += " self._conn = conn\n" - code += " self._driver = driver\n" - else: - code += " def __init__(self, conn) -> None:\n" - code += " self._conn = conn\n" - - # Initialize _event_manager if there are events - if self.events: - code += " self._event_manager = _EventManager(conn, self.EVENT_CONFIGS)\n" - - # Append extra init code from enhancements (e.g. self.intercepts = []) - for init_line in enhancements.get("extra_init_code", []): - code += f" {init_line}\n" - - code += "\n" - - # Generate command methods - # Auto-exclude methods whose names appear in extra_methods to prevent duplicates - extra_method_names = set() - for extra_meth in enhancements.get("extra_methods", []): - m = re.search(r"def\s+(\w+)\s*\(", extra_meth) - if m: - extra_method_names.add(m.group(1)) - exclude_methods = ( - set(enhancements.get("exclude_methods", [])) | extra_method_names - ) - if self.commands: - for command in self.commands: - # Get method-specific enhancements - # Convert command name to snake_case to match enhancement manifest keys - method_name_snake = command._camel_to_snake(command.name) - if method_name_snake in exclude_methods: - continue - method_enhancements = enhancements.get(method_name_snake, {}) - code += command.to_python_method(method_enhancements) - code += "\n" - else: - code += " pass\n" - - # Emit extra methods from enhancement manifest - for extra_method in enhancements.get("extra_methods", []): - code += extra_method - code += "\n" - - # Add delegating event handler methods if events are present - if self.events: - code += """ - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - \"\"\"Add an event handler. - - Args: - event: The event to subscribe to. - callback: The callback function to execute on event. - contexts: The context IDs to subscribe to (optional). - - Returns: - The callback ID. - \"\"\" - return self._event_manager.add_event_handler(event, callback, contexts) - - def remove_event_handler(self, event: str, callback_id: int) -> None: - \"\"\"Remove an event handler. - - Args: - event: The event to unsubscribe from. - callback_id: The callback ID. - \"\"\" - return self._event_manager.remove_event_handler(event, callback_id) - - def clear_event_handlers(self) -> None: - \"\"\"Clear all event handlers.\"\"\" - return self._event_manager.clear_event_handlers() -""" - - # Generate event info type aliases AFTER the class definition - # This ensures all types are available when we create the aliases - if self.events: - code += "\n# Event Info Type Aliases\n" - for event_def in self.events: - code += event_def.to_python_dataclass() - code += "\n" - - # Now populate EVENT_CONFIGS after the aliases are defined - code += "\n# Populate EVENT_CONFIGS with event configuration mappings\n" - # Use globals() to look up types dynamically to handle missing types gracefully - code += "_globals = globals()\n" - code += f"{class_name}.EVENT_CONFIGS = {{\n" - # Collect extra event keys to skip CDDL duplicates - extra_event_keys_cfg = { - evt["event_key"] for evt in enhancements.get("extra_events", []) - } - for event_def in self.events: - # Convert method name to user-friendly event name - method_parts = event_def.method.split(".") - if len(method_parts) == 2: - event_name = self._convert_method_to_event_name(method_parts[1]) - if event_name in extra_event_keys_cfg: - continue - # The event class is the event name (e.g., ContextCreated) - # Try to get it from globals, default to dict if not found - code += ( - f' "{event_name}": (\n' - f' EventConfig("{event_name}", "{event_def.method}",\n' - f' _globals.get("{event_def.name}", dict))\n' - f' if _globals.get("{event_def.name}")\n' - f' else EventConfig("{event_name}", "{event_def.method}", dict)\n' - f" ),\n" - ) - # Extra events not in the CDDL spec - for extra_evt in enhancements.get("extra_events", []): - ek = extra_evt["event_key"] - be = extra_evt["bidi_event"] - ec = extra_evt["event_class"] - single = f' "{ek}": EventConfig("{ek}", "{be}", _globals.get("{ec}", dict)),' - if len(single) > 120: - code += ( - f' "{ek}": EventConfig(\n' - f' "{ek}", "{be}",\n' - f' _globals.get("{ec}", dict),\n' - f" ),\n" - ) - else: - code += single + "\n" - code += "}\n" - - return code - - -class CddlParser: - """Parse CDDL specification files.""" - - def __init__(self, cddl_path: str): - """Initialize parser with CDDL file path.""" - self.cddl_path = Path(cddl_path) - self.content = "" - self.modules: dict[str, CddlModule] = {} - self.definitions: dict[str, str] = {} - self.event_names: set[str] = set() # Names of definitions that are events - self._read_file() - - def _read_file(self) -> None: - """Read and preprocess CDDL file.""" - if not self.cddl_path.exists(): - raise FileNotFoundError(f"CDDL file not found: {self.cddl_path}") - - with open(self.cddl_path, encoding="utf-8") as f: - self.content = f.read() - - logger.info(f"Loaded CDDL file: {self.cddl_path}") - - def parse(self) -> dict[str, CddlModule]: - """Parse CDDL content and return modules.""" - # Remove comments - content = self._remove_comments(self.content) - - # Extract all definitions - self._extract_definitions(content) - - # Extract event names from event union definitions - self._extract_event_names() - - # Extract type definitions by module - self._extract_types() - - # Extract event definitions by module - self._extract_events() - - # Extract command definitions by module - self._extract_commands() - - # If no modules found, create a default one from the filename - if not self.modules: - module_name = self.cddl_path.stem - default_module = CddlModule(name=module_name) - self.modules[module_name] = default_module - logger.warning(f"No modules found in CDDL, creating default: {module_name}") - - return self.modules - - def _remove_comments(self, content: str) -> str: - """Remove comments from CDDL content.""" - # CDDL uses ; for comments to end of line - lines = content.split("\n") - cleaned = [] - for line in lines: - if ";" in line and not line.strip().startswith(";"): - line = line[: line.index(";")] - elif line.strip().startswith(";"): - continue - cleaned.append(line) - return "\n".join(cleaned) - - def _extract_definitions(self, content: str) -> None: - """Extract CDDL definitions (type definitions, commands, etc.).""" - # Match pattern: Name = Definition - # Handles multiline definitions properly. - # The \s* after \n in the lookahead allows definitions that start with - # leading whitespace (e.g. " network.BeforeRequestSent = (") to be - # recognised as separate definitions instead of being swallowed into - # the body of the preceding definition. - pattern = r"(\w+(?:\.\w+)*)\s*=\s*(.+?)(?=\n\s*\w+(?:\.\w+)?\s*=|\Z)" - - for match in re.finditer(pattern, content, re.DOTALL): - name = match.group(1).strip() - definition = match.group(2).strip() - self.definitions[name] = definition - logger.debug(f"Extracted definition: {name}") - - def _extract_event_names(self) -> None: - """Extract event names from event union definitions. - - Event union definitions follow pattern: - module.ModuleEvent = ( - module.EventName1 // - module.EventName2 // - ... - ) - """ - for def_name, def_content in self.definitions.items(): - # Check if this looks like an event union (name ends with "Event") and - # contains a module-qualified reference like "module.EventName". - # Handles both single-item (no //) and multi-item (// separated) unions. - if "Event" in def_name and re.search(r"\w+\.\w+", def_content): - # Extract event names from the union (works for single and multi-item) - event_refs = re.findall(r"(\w+\.\w+)", def_content) - for event_ref in event_refs: - self.event_names.add(event_ref) - logger.debug(f"Identified event: {event_ref} (from {def_name})") - - def _extract_types(self) -> None: - """Extract type definitions from parsed definitions.""" - # Type definitions follow pattern: module.TypeName = { field: type, ... } - # They have dots in the name and curly braces in the content - # But they DON'T have method: "..." pattern (which means it's not a command) - # Enums follow pattern: module.EnumName = "value1" / "value2" / ... - - for def_name, def_content in self.definitions.items(): - # Skip if not a namespaced name (e.g., skip "EmptyParams", "Extensible") - if "." not in def_name: - continue - - # Skip if it's a command (contains method: pattern) - if "method:" in def_content: - continue - - # Extract module.TypeName - if "." in def_name: - module_name, type_name = def_name.rsplit(".", 1) - - # Create module if not exists - if module_name not in self.modules: - self.modules[module_name] = CddlModule(name=module_name) - - # Check if this is an enum (string union with /) - if self._is_enum_definition(def_content): - # Extract enum values - values = self._extract_enum_values(def_content) - if values: - enum_def = CddlEnum( - module=module_name, - name=type_name, - values=values, - description=f"{type_name}", - ) - self.modules[module_name].enums.append(enum_def) - logger.debug( - f"Found enum: {def_name} with {len(values)} values" - ) - else: - # Extract fields from type definition - fields = self._extract_type_fields(def_content) - - if fields: # Only create type if it has fields - type_def = CddlTypeDefinition( - module=module_name, - name=type_name, - fields=fields, - description=f"{type_name}", - ) - self.modules[module_name].types.append(type_def) - logger.debug( - f"Found type: {def_name} with {len(fields)} fields" - ) - - def _is_enum_definition(self, definition: str) -> bool: - """Check if a definition is an enum (string union with /). - - Enums are defined as: "value1" / "value2" / "value3" - """ - # Clean whitespace - clean_def = definition.strip() - - # Must not have curly braces (that would be a type definition) - if "{" in clean_def or "}" in clean_def: - return False - - # Must contain the union operator / surrounded by quotes - # Pattern: "something" / "something_else" - return " / " in clean_def and '"' in clean_def - - def _extract_enum_values(self, enum_definition: str) -> list[str]: - """Extract individual values from an enum definition. - - Enums are defined as: "value1" / "value2" / "value3" - Can span multiple lines. - """ - values = [] - - # Clean the definition and extract quoted strings - # Split by / and extract quoted values - parts = enum_definition.split("/") - - for part in parts: - part = part.strip() - - # Extract quoted string - use search instead of match to find quotes anywhere - match = re.search(r'"([^"]*)"', part) - if match: - value = match.group(1) - values.append(value) - logger.debug(f"Extracted enum value: {value}") - - return values - - @staticmethod - def _normalize_cddl_type(field_type: str) -> str: - """Normalize a CDDL type expression to a simple Python-compatible form. - - Strips CDDL control operators (.ge, .le, .gt, .lt, .default, etc.) and - replaces interval/constraint expressions with their base types so that - the caller can safely check for nested struct syntax. - - Examples: - '(float .ge 0.0) .default 1.0' -> 'float' - '(float .ge 0.0) / null' -> 'float / null' - '(0.0...360.0) / null' -> 'float / null' - '-90.0..90.0' -> 'float' - 'float / null .default null' -> 'float / null' - """ - result = field_type - # Remove trailing .default annotations - result = re.sub(r"\s*\.default\s+\S+", "", result) - # Replace parenthesised constraint expressions: (baseType .operator ...) -> baseType - result = re.sub(r"\((\w+)\s+\.\w+[^)]*\)", r"\1", result) - # Replace parenthesised numeric interval types: (0.0...360.0) -> float - result = re.sub(r"\(-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?\)", "float", result) - # Replace bare numeric interval types: -90.0..90.0 -> float - result = re.sub(r"-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?", "float", result) - return result.strip() - - def _extract_type_fields(self, type_definition: str) -> dict[str, str]: - """Extract fields from a type definition block.""" - fields = {} - - # Remove outer braces - clean_def = type_definition.strip() - if clean_def.startswith("{"): - clean_def = clean_def[1:] - if clean_def.endswith("}"): - clean_def = clean_def[:-1] - - # Parse each line for field: type patterns - for line in clean_def.split("\n"): - line = line.strip() - if not line or "Extensible" in line or line.startswith("//"): - continue - - # Match pattern: [?] fieldName: type - match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) - if not match: - # Try without optional marker - match = re.match(r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) - - if match: - field_name = match.group(1).strip() - field_type = match.group(2).strip() - normalized_type = self._normalize_cddl_type(field_type) - - # Skip lines that are part of nested definitions - if "{" not in normalized_type and "(" not in normalized_type: - fields[field_name] = normalized_type - logger.debug(f"Extracted field {field_name}: {normalized_type}") - - return fields - - def _extract_events(self) -> None: - """Extract event definitions from parsed definitions. - - Events are definitions that: - 1. Are listed in an event union (e.g., BrowsingContextEvent) - 2. Have method: "..." and params: ... fields - - Event pattern: module.EventName = (method: "module.eventName", params: module.ParamType) - """ - # Find definitions that are in the event_names set - event_pattern = re.compile( - r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" - ) - - for def_name, def_content in self.definitions.items(): - # Skip if not identified as an event - if def_name not in self.event_names: - continue - - # Extract method and params - match = event_pattern.search(def_content) - if match: - method = match.group(1) # e.g., "browsingContext.contextCreated" - params_type = match.group(2) # e.g., "browsingContext.Info" - - # Extract module name from method - if "." in method: - module_name, _ = method.split(".", 1) - - # Create module if not exists - if module_name not in self.modules: - self.modules[module_name] = CddlModule(name=module_name) - - # Extract event name from definition name (e.g., browsingContext.ContextCreated) - _, event_name = def_name.rsplit(".", 1) - - # Create event - event = CddlEvent( - module=module_name, - name=event_name, - method=method, - params_type=params_type, - description=f"Event: {method}", - ) - - self.modules[module_name].events.append(event) - logger.debug( - f"Found event: {def_name} (method={method}, params={params_type})" - ) - - def _extract_commands(self) -> None: - """Extract command definitions from parsed definitions.""" - # Find command definitions that follow pattern: module.Command = (method: "...", params: ...) - command_pattern = re.compile( - r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" - ) - - for def_name, def_content in self.definitions.items(): - # Skip definitions that are events (they share the same pattern) - if def_name in self.event_names: - continue - matches = list(command_pattern.finditer(def_content)) - if matches: - for match in matches: - method = match.group(1) # e.g., "session.new" - params_type = match.group(2) # e.g., "session.NewParameters" - - # Extract module name from method - if "." in method: - module_name, command_name = method.split(".", 1) - - # Create module if not exists - if module_name not in self.modules: - self.modules[module_name] = CddlModule(name=module_name) - - # Extract parameters - params = self._extract_parameters(params_type) - - # Create command - cmd = CddlCommand( - module=module_name, - name=command_name, - params=params, - description=f"Execute {method}", - ) - - self.modules[module_name].commands.append(cmd) - logger.debug( - f"Found command: {method} with params {params_type}" - ) - - def _extract_parameters( - self, params_type: str, _seen: set[str] | None = None - ) -> dict[str, str]: - """Extract parameters from a parameter type definition. - - Handles both struct types ({...}) and top-level union types (TypeA / TypeB), - merging all fields from each alternative as optional parameters. - """ - params = {} - - if _seen is None: - _seen = set() - if params_type in _seen: - return params - _seen.add(params_type) - - if params_type not in self.definitions: - logger.debug(f"Parameter type not found: {params_type}") - return params - - definition = self.definitions[params_type] - - # Handle top-level type alias that is a union of other named types: - # e.g. session.UnsubscribeByAttributesRequest / session.UnsubscribeByIDRequest - # These definitions contain a single line with "/" separating type names - # (not the double-slash "//" used for command unions). - stripped = definition.strip() - if not stripped.startswith("{") and "/" in stripped and "//" not in stripped: - # Each token separated by "/" should be a named type reference - alternatives = [a.strip() for a in stripped.split("/") if a.strip()] - all_named = all(re.match(r"^[\w.]+$", a) for a in alternatives) - if all_named: - for alt_type in alternatives: - alt_params = self._extract_parameters(alt_type, _seen) - params.update(alt_params) - return params - - # Remove the outer curly braces and split by comma - # Then parse each line for key: type patterns - clean_def = stripped - if clean_def.startswith("{"): - clean_def = clean_def[1:] - if clean_def.endswith("}"): - clean_def = clean_def[:-1] - - # Split by newlines and process each line - for line in clean_def.split("\n"): - line = line.strip() - if not line or "Extensible" in line: - continue - - # Match pattern: [?] name: type - # Using a simple pattern that handles optional prefix - match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) - if not match: - # Try without optional marker - match = re.match(r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) - - if match: - param_name = match.group(1).strip() - param_type = match.group(2).strip() - normalized_type = self._normalize_cddl_type(param_type) - - # Skip lines that are part of nested definitions - if "{" not in normalized_type and "(" not in normalized_type: - params[param_name] = normalized_type - logger.debug( - f"Extracted param {param_name}: {normalized_type} from {params_type}" - ) - - return params - - -def module_name_to_class_name(module_name: str) -> str: - """Convert module name to class name (PascalCase). - - Handles both camelCase (browsingContext) and snake_case (browsing_context). - """ - if "_" in module_name: - # Snake_case: browsing_context -> BrowsingContext - return "".join(word.capitalize() for word in module_name.split("_")) - else: - # CamelCase: browsingContext -> BrowsingContext - return module_name[0].upper() + module_name[1:] if module_name else "" - - -def module_name_to_filename(module_name: str) -> str: - """Convert module name to Python filename (snake_case). - - Handles both camelCase (browsingContext) and snake_case (browsing_context). - Special cases: - - browsingContext -> browsing_context - - webExtension -> webextension - """ - # Handle explicit mappings for known camelCase names - camel_to_snake_map = { - "browsingContext": "browsing_context", - "webExtension": "webextension", - } - - if module_name in camel_to_snake_map: - return camel_to_snake_map[module_name] - - if "_" in module_name: - # Already snake_case - return module_name - else: - # Convert camelCase to snake_case for other cases - # This handles cases like "myModuleName" -> "my_module_name" - import re - - s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", module_name) - return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() - - -def generate_init_file(output_path: Path, modules: dict[str, CddlModule]) -> None: - """Generate __init__.py file for the module.""" - init_path = output_path / "__init__.py" - - code = f"""{SHARED_HEADER} - -from __future__ import annotations - -""" - - for module_name in sorted(modules.keys()): - class_name = module_name_to_class_name(module_name) - filename = module_name_to_filename(module_name) - code += f"from .{filename} import {class_name}\n" - - code += "\n__all__ = [\n" - for module_name in sorted(modules.keys()): - class_name = module_name_to_class_name(module_name) - code += f' "{class_name}",\n' - code += "]\n" - - with open(init_path, "w", encoding="utf-8") as f: - f.write(code) - - logger.info(f"Generated: {init_path}") - - -def generate_common_file(output_path: Path) -> None: - """Generate common.py file with shared utilities.""" - common_path = output_path / "common.py" - - code = ( - "# Licensed to the Software Freedom Conservancy (SFC) under one\n" - "# or more contributor license agreements. See the NOTICE file\n" - "# distributed with this work for additional information\n" - "# regarding copyright ownership. The SFC licenses this file\n" - "# to you under the Apache License, Version 2.0 (the\n" - '# "License"); you may not use this file except in compliance\n' - "# with the License. You may obtain a copy of the License at\n" - "#\n" - "# http://www.apache.org/licenses/LICENSE-2.0\n" - "#\n" - "# Unless required by applicable law or agreed to in writing,\n" - "# software distributed under the License is distributed on an\n" - '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' - "# KIND, either express or implied. See the License for the\n" - "# specific language governing permissions and limitations\n" - "# under the License.\n" - "\n" - '"""Common utilities for BiDi command construction."""\n' - "\n" - "from __future__ import annotations\n" - "\n" - "from collections.abc import Generator\n" - "from typing import Any\n" - "\n" - "\n" - "def command_builder(\n" - " method: str, params: dict[str, Any] | None = None\n" - ") -> Generator[dict[str, Any], Any, Any]:\n" - ' """Build a BiDi command generator.\n' - "\n" - " Args:\n" - ' method: The BiDi method name (e.g., "session.status", "browser.close")\n' - " params: The parameters for the command\n" - "\n" - " Yields:\n" - " A dictionary representing the BiDi command\n" - "\n" - " Returns:\n" - " The result from the BiDi command execution\n" - ' """\n' - " if params is None:\n" - " params = {}\n" - ' result = yield {"method": method, "params": params}\n' - " return result\n" - ) - - with open(common_path, "w", encoding="utf-8") as f: - f.write(code) - - logger.info(f"Generated: {common_path}") - - -def generate_console_file(output_path: Path) -> None: - """Generate console.py file with Console enum helper.""" - console_path = output_path / "console.py" - - code = ( - "# Licensed to the Software Freedom Conservancy (SFC) under one\n" - "# or more contributor license agreements. See the NOTICE file\n" - "# distributed with this work for additional information\n" - "# regarding copyright ownership. The SFC licenses this file\n" - "# to you under the Apache License, Version 2.0 (the\n" - '# "License"); you may not use this file except in compliance\n' - "# with the License. You may obtain a copy of the License at\n" - "#\n" - "# http://www.apache.org/licenses/LICENSE-2.0\n" - "#\n" - "# Unless required by applicable law or agreed to in writing,\n" - "# software distributed under the License is distributed on an\n" - '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' - "# KIND, either express or implied. See the License for the\n" - "# specific language governing permissions and limitations\n" - "# under the License.\n" - "\n" - "from enum import Enum\n" - "\n" - "\n" - "class Console(Enum):\n" - ' ALL = "all"\n' - ' LOG = "log"\n' - ' ERROR = "error"\n' - ) - - with open(console_path, "w", encoding="utf-8") as f: - f.write(code) - - logger.info(f"Generated: {console_path}") - - -def generate_permissions_file(output_path: Path) -> None: - """Generate permissions.py file with permission-related classes.""" - permissions_path = output_path / "permissions.py" - - code = ( - "# Licensed to the Software Freedom Conservancy (SFC) under one\n" - "# or more contributor license agreements. See the NOTICE file\n" - "# distributed with this work for additional information\n" - "# regarding copyright ownership. The SFC licenses this file\n" - "# to you under the Apache License, Version 2.0 (the\n" - '# "License"); you may not use this file except in compliance\n' - "# with the License. You may obtain a copy of the License at\n" - "#\n" - "# http://www.apache.org/licenses/LICENSE-2.0\n" - "#\n" - "# Unless required by applicable law or agreed to in writing,\n" - "# software distributed under the License is distributed on an\n" - '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' - "# KIND, either express or implied. See the License for the\n" - "# specific language governing permissions and limitations\n" - "# under the License.\n" - "\n" - '"""WebDriver BiDi Permissions module."""\n' - "\n" - "from __future__ import annotations\n" - "\n" - "from __future__ import annotations\n" - "\n" - "from enum import Enum\n" - "from typing import Any\n" - "\n" - "from .common import command_builder\n" - "\n" - '_VALID_PERMISSION_STATES = {"granted", "denied", "prompt"}\n' - "\n" - "\n" - "class PermissionState(str, Enum):\n" - ' """Permission state enumeration."""\n' - "\n" - ' GRANTED = "granted"\n' - ' DENIED = "denied"\n' - ' PROMPT = "prompt"\n' - "\n" - "\n" - "class PermissionDescriptor:\n" - ' """Descriptor for a permission."""\n' - "\n" - " def __init__(self, name: str) -> None:\n" - ' """Initialize a PermissionDescriptor.\n' - "\n" - " Args:\n" - " name: The name of the permission (e.g., 'geolocation', 'microphone', 'camera')\n" - ' """\n' - " self.name = name\n" - "\n" - " def __repr__(self) -> str:\n" - " return f\"PermissionDescriptor('{self.name}')\"\n" - "\n" - "\n" - "class Permissions:\n" - ' """WebDriver BiDi Permissions module."""\n' - "\n" - " def __init__(self, websocket_connection: Any) -> None:\n" - ' """Initialize the Permissions module.\n' - "\n" - " Args:\n" - " websocket_connection: The WebSocket connection for sending BiDi commands\n" - ' """\n' - " self._conn = websocket_connection\n" - "\n" - " def set_permission(\n" - " self,\n" - " descriptor: PermissionDescriptor | str,\n" - " state: PermissionState | str,\n" - " origin: str | None = None,\n" - " user_context: str | None = None,\n" - " ) -> None:\n" - ' """Set a permission for a given origin.\n' - "\n" - " Args:\n" - " descriptor: The permission descriptor or permission name as a string\n" - " state: The desired permission state\n" - " origin: The origin for which to set the permission\n" - " user_context: Optional user context ID to scope the permission\n" - "\n" - " Raises:\n" - " ValueError: If the state is not a valid permission state\n" - ' """\n' - " state_value = state.value if isinstance(state, PermissionState) else state\n" - " if state_value not in _VALID_PERMISSION_STATES:\n" - " raise ValueError(\n" - ' f"Invalid permission state: {state_value!r}. "\n' - ' f"Must be one of {sorted(_VALID_PERMISSION_STATES)}"\n' - " )\n" - "\n" - " if isinstance(descriptor, str):\n" - ' descriptor_dict = {"name": descriptor}\n' - " else:\n" - ' descriptor_dict = {"name": descriptor.name}\n' - "\n" - " params: dict[str, Any] = {\n" - ' "descriptor": descriptor_dict,\n' - ' "state": state_value,\n' - " }\n" - " if origin is not None:\n" - ' params["origin"] = origin\n' - " if user_context is not None:\n" - ' params["userContext"] = user_context\n' - "\n" - ' cmd = command_builder("permissions.setPermission", params)\n' - " self._conn.execute(cmd)\n" - ) - - with open(permissions_path, "w", encoding="utf-8") as f: - f.write(code) - - logger.info(f"Generated: {permissions_path}") - - -def main( - cddl_file: str, - output_dir: str, - spec_version: str = "1.0", - enhancements_manifest: str | None = None, -) -> None: - """Main entry point. - - Args: - cddl_file: Path to CDDL specification file - output_dir: Output directory for generated modules - spec_version: BiDi spec version - enhancements_manifest: Path to enhancement manifest Python file - """ - output_path = Path(output_dir).resolve() - output_path.mkdir(parents=True, exist_ok=True) - - logger.info(f"WebDriver BiDi Code Generator v{__version__}") - logger.info(f"Input CDDL: {cddl_file}") - logger.info(f"Output directory: {output_path}") - logger.info(f"Spec version: {spec_version}") - - # Load enhancement manifest - manifest = load_enhancements_manifest(enhancements_manifest) - if manifest: - logger.info(f"Loaded enhancement manifest from: {enhancements_manifest}") - - # Parse CDDL - parser = CddlParser(cddl_file) - modules = parser.parse() - - logger.info(f"Parsed {len(modules)} modules") - - # Clean up existing generated files - for file_path in output_path.glob("*.py"): - if file_path.name != "py.typed" and not file_path.name.startswith("_"): - file_path.unlink() - logger.debug(f"Removed: {file_path}") - - # Generate module files using snake_case filenames - for module_name, module in sorted(modules.items()): - filename = module_name_to_filename(module_name) - module_path = output_path / f"{filename}.py" - - # Get module-specific enhancements (merge with dataclass templates) - module_enhancements = manifest.get("enhancements", {}).get(module_name, {}) - - # Add dataclass methods and docstrings to the enhancement data for this module - full_module_enhancements = { - **module_enhancements, - "dataclass_methods": manifest.get("dataclass_methods", {}), - "method_docstrings": manifest.get("method_docstrings", {}), - } - - with open(module_path, "w", encoding="utf-8") as f: - f.write(module.generate_code(full_module_enhancements)) - logger.info(f"Generated: {module_path}") - - # Generate __init__.py - generate_init_file(output_path, modules) - - # Generate common.py - generate_common_file(output_path) - - # Generate permissions.py - generate_permissions_file(output_path) - - # Generate console.py - generate_console_file(output_path) - - # Create py.typed marker - py_typed_path = output_path / "py.typed" - py_typed_path.touch() - logger.info(f"Generated type marker: {py_typed_path}") - - logger.info("Code generation complete!") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Generate Python WebDriver BiDi modules from CDDL specification" - ) - parser.add_argument( - "cddl_file", - help="Path to CDDL specification file", - ) - parser.add_argument( - "output_dir", - help="Output directory for generated Python modules", - ) - parser.add_argument( - "--version", - default="1.0", - help="BiDi spec version (default: 1.0)", - ) - parser.add_argument( - "--enhancements-manifest", - default=None, - help="Path to enhancement manifest Python file (optional)", - ) - parser.add_argument( - "-v", - "--verbose", - action="store_true", - help="Enable verbose logging", - ) - - args = parser.parse_args() - - if args.verbose: - logging.getLogger("generate_bidi").setLevel(logging.DEBUG) - - try: - main( - args.cddl_file, - args.output_dir, - args.version, - args.enhancements_manifest, - ) - sys.exit(0) - except Exception as e: - logger.error(f"Generation failed: {e}", exc_info=True) - sys.exit(1) diff --git a/py/selenium/webdriver/common/bidi/_event_manager.py b/py/selenium/webdriver/common/bidi/_event_manager.py new file mode 100644 index 0000000000000..216a5b8eccb70 --- /dev/null +++ b/py/selenium/webdriver/common/bidi/_event_manager.py @@ -0,0 +1,186 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Shared event management helpers for generated WebDriver BiDi modules. + +``EventConfig``, ``_EventWrapper``, and ``_EventManager`` are emitted +identically into every generated module that exposes events. Rather than +duplicating ~160 lines of code across all of those modules, they are defined +once here and imported by the generated files. +""" + +from __future__ import annotations + +import threading +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +from selenium.webdriver.common.bidi.session import Session + + +@dataclass +class EventConfig: + """Configuration for a BiDi event.""" + + event_key: str + bidi_event: str + event_class: type + + +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. + + Args: + params: Raw BiDi event params with camelCase keys. + + Returns: + An instance of the dataclass, or the raw dict on failure. + """ + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) + + +class _EventManager: + """Manages event subscriptions and callbacks.""" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + """Subscribe to a BiDi event if not already subscribed.""" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist.""" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: + session = Session(self.conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event]["callbacks"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() diff --git a/py/selenium/webdriver/remote/websocket_connection.py b/py/selenium/webdriver/remote/websocket_connection.py index 8d6f745d4ac5b..b4cac118df033 100644 --- a/py/selenium/webdriver/remote/websocket_connection.py +++ b/py/selenium/webdriver/remote/websocket_connection.py @@ -70,6 +70,7 @@ def default(self, o): return result return super().default(o) + logger = logging.getLogger(__name__) @@ -154,7 +155,9 @@ def _serialize_command(self, command): def _deserialize_result(self, result, command): try: _ = command.send(result) - raise WebDriverException("The command's generator function did not exit when expected!") + raise WebDriverException( + "The command's generator function did not exit when expected!" + ) except StopIteration as exit: return exit.value @@ -171,11 +174,15 @@ def on_error(ws, error): def run_socket(): if self.url.startswith("wss://"): - self._ws.run_forever(sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True) + self._ws.run_forever( + sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True + ) else: self._ws.run_forever(suppress_origin=True) - self._ws = WebSocketApp(self.url, on_open=on_open, on_message=on_message, on_error=on_error) + self._ws = WebSocketApp( + self.url, on_open=on_open, on_message=on_message, on_error=on_error + ) self._ws_thread = Thread(target=run_socket, daemon=True) self._ws_thread.start() From e8dc180070b7ee1503a2d85060937362d3798ede Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 12:12:02 +0000 Subject: [PATCH 08/67] remove --version call --- py/private/generate_bidi.bzl | 1 - 1 file changed, 1 deletion(-) diff --git a/py/private/generate_bidi.bzl b/py/private/generate_bidi.bzl index c11b6efe4735f..e072279f85e94 100644 --- a/py/private/generate_bidi.bzl +++ b/py/private/generate_bidi.bzl @@ -53,7 +53,6 @@ def _generate_bidi_impl(ctx): args = [ cddl_file.path, output_base, - "--version", spec_version, ] From 35eff41316653c1cb03292e5b7eb3f1208508a9f Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 12:37:46 +0000 Subject: [PATCH 09/67] correct web extensions --- py/generate_bidi.py | 1274 +++++++++++++++-- py/private/bidi_enhancements_manifest.py | 10 +- py/selenium/webdriver/common/bidi/__init__.py | 17 - py/selenium/webdriver/common/bidi/browser.py | 83 +- .../webdriver/common/bidi/browsing_context.py | 256 ++-- py/selenium/webdriver/common/bidi/common.py | 11 +- .../webdriver/common/bidi/emulation.py | 216 +-- py/selenium/webdriver/common/bidi/input.py | 83 +- py/selenium/webdriver/common/bidi/log.py | 39 +- py/selenium/webdriver/common/bidi/network.py | 312 ++-- .../webdriver/common/bidi/permissions.py | 10 +- py/selenium/webdriver/common/bidi/script.py | 253 ++-- py/selenium/webdriver/common/bidi/session.py | 77 +- py/selenium/webdriver/common/bidi/storage.py | 75 +- .../webdriver/common/bidi/webextension.py | 57 +- 15 files changed, 1764 insertions(+), 1009 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 8103cafe40684..d14e2575c8bfd 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -18,11 +18,12 @@ import logging import re import sys +from collections import defaultdict from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from textwrap import indent as tw_indent -from typing import Any +from textwrap import dedent, indent as tw_indent +from typing import Any, Dict, List, Optional, Set, Tuple __version__ = "1.0.0" @@ -32,24 +33,7 @@ logger = logging.getLogger("generate_bidi") # File headers -SHARED_HEADER = """# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# DO NOT EDIT THIS FILE! +SHARED_HEADER = """# DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make # changes, edit the generator and regenerate all of the modules.""" @@ -59,7 +43,8 @@ # WebDriver BiDi module: {{}} from __future__ import annotations -from typing import Any +from typing import Any, Dict, List, Optional, Union +from .common import command_builder """ @@ -68,7 +53,7 @@ def indent(s: str, n: int) -> str: return tw_indent(s, n * " ") -def load_enhancements_manifest(manifest_path: str | None) -> dict[str, Any]: +def load_enhancements_manifest(manifest_path: Optional[str]) -> Dict[str, Any]: """Load enhancement manifest from a Python file. Args: @@ -139,10 +124,10 @@ def get_annotation(cls, cddl_type: str) -> str: if cddl_type.startswith("["): # Array inner = cddl_type.strip("[]+ ") inner_type = cls.get_annotation(inner) - return f"list[{inner_type}]" + return f"List[{inner_type}]" if cddl_type.startswith("{"): # Map/Dict - return "dict[str, Any]" + return "Dict[str, Any]" # Default to Any for unknown types return "Any" @@ -154,11 +139,11 @@ class CddlCommand: module: str name: str - params: dict[str, str] = field(default_factory=dict) - result: str | None = None + params: Dict[str, str] = field(default_factory=dict) + result: Optional[str] = None description: str = "" - def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: + def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str: """Generate Python method code for this command. Args: @@ -189,15 +174,8 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: else: param_list = "self" - # Build method body - wrap long signatures over multiple lines if needed - sig_line = f" def {method_name}({param_list}):" - if len(sig_line) > 120 and param_strs: - body = f" def {method_name}(\n self,\n" - for p in param_strs: - body += f" {p},\n" - body += " ):\n" - else: - body = sig_line + "\n" + # Build method body + body = f" def {method_name}({param_list}):\n" body += f' """{self.description or "Execute " + self.module + "." + self.name}."""\n' # Add validation if specified @@ -259,6 +237,7 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: if result_param == "download_behavior": body += ' "downloadBehavior": download_behavior,\n' # Add remaining parameters that weren't part of the transform + override_params = enhancements.get("params_override", {}) for cddl_param_name in self.params: if cddl_param_name not in ["downloadBehavior"]: snake_name = self._camel_to_snake(cddl_param_name) @@ -285,45 +264,45 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: # Extract property from list items body += f' if result and "{extract_field}" in result:\n' body += f' items = result.get("{extract_field}", [])\n' - body += " return [\n" + body += f" return [\n" body += f' item.get("{extract_property}")\n' - body += " for item in items\n" - body += " if isinstance(item, dict)\n" - body += " ]\n" - body += " return []\n" + body += f" for item in items\n" + body += f" if isinstance(item, dict)\n" + body += f" ]\n" + body += f" return []\n" elif extract_field in deserialize_rules: # Extract field and deserialize to typed objects type_name = deserialize_rules[extract_field] body += f' if result and "{extract_field}" in result:\n' body += f' items = result.get("{extract_field}", [])\n' - body += " return [\n" + body += f" return [\n" body += f" {type_name}(\n" body += self._generate_field_args(extract_field, type_name) - body += " )\n" - body += " for item in items\n" - body += " if isinstance(item, dict)\n" - body += " ]\n" - body += " return []\n" + body += f" )\n" + body += f" for item in items\n" + body += f" if isinstance(item, dict)\n" + body += f" ]\n" + body += f" return []\n" else: # Simple field extraction (return the value directly, not wrapped in result dict) body += f' if result and "{extract_field}" in result:\n' body += f' extracted = result.get("{extract_field}")\n' - body += " return extracted\n" - body += " return result\n" + body += f" return extracted\n" + body += f" return result\n" elif "deserialize" in enhancements: # Deserialize response to typed objects (legacy, without extract_field) deserialize_rules = enhancements["deserialize"] for response_field, type_name in deserialize_rules.items(): body += f' if result and "{response_field}" in result:\n' body += f' items = result.get("{response_field}", [])\n' - body += " return [\n" + body += f" return [\n" body += f" {type_name}(\n" body += self._generate_field_args(response_field, type_name) - body += " )\n" - body += " for item in items\n" - body += " if isinstance(item, dict)\n" - body += " ]\n" - body += " return []\n" + body += f" )\n" + body += f" for item in items\n" + body += f" if isinstance(item, dict)\n" + body += f" ]\n" + body += f" return []\n" else: # No special response handling, just return the result body += " return result\n" @@ -372,10 +351,10 @@ class CddlTypeDefinition: module: str name: str - fields: dict[str, str] = field(default_factory=dict) + fields: Dict[str, str] = field(default_factory=dict) description: str = "" - def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str: + def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> str: """Generate Python dataclass code for this type. Args: @@ -385,14 +364,11 @@ def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str dataclass_methods = enhancements.get("dataclass_methods", {}) method_docstrings = enhancements.get("method_docstrings", {}) - # Generate class name from type name. - # CDDL type names that start with a lowercase letter (e.g. camelCase - # command-parameter types like "setNetworkConditionsParameters") are - # capitalised so that the resulting Python class follows PascalCase. - class_name = self.name[0].upper() + self.name[1:] if self.name else self.name - code = "@dataclass\n" + # Generate class name from type name (keep it as-is, don't split on underscores) + class_name = self.name + code = f"@dataclass\n" code += f"class {class_name}:\n" - code += f' """{class_name} type definition."""\n\n' + code += f' """{self.description or self.name}."""\n\n' if not self.fields: code += " pass\n" @@ -410,7 +386,7 @@ def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str literal_value = literal_match.group(1) code += f' {snake_name}: str = field(default="{literal_value}", init=False)\n' # Check if this field is a list type - elif "list[" in python_type: + elif "List[" in python_type: code += f" {snake_name}: {python_type} = field(default_factory=list)\n" else: code += f" {snake_name}: {python_type} = None\n" @@ -477,7 +453,7 @@ class CddlEnum: module: str name: str - values: list[str] = field(default_factory=list) + values: List[str] = field(default_factory=list) description: str = "" def to_python_class(self) -> str: @@ -486,9 +462,9 @@ def to_python_class(self) -> str: Generates a simple class with string constants to match the existing pattern in the codebase (e.g., ClientWindowState). """ - class_name = self.name[0].upper() + self.name[1:] if self.name else self.name + class_name = self.name code = f"class {class_name}:\n" - code += f' """{class_name}."""\n\n' + code += f' """{self.description or self.name}."""\n\n' for value in self.values: # Convert value to UPPER_SNAKE_CASE constant name @@ -554,10 +530,10 @@ class CddlModule: """Represents a CDDL module (e.g., script, network, browsing_context).""" name: str - commands: list[CddlCommand] = field(default_factory=list) - types: list[CddlTypeDefinition] = field(default_factory=list) - enums: list[CddlEnum] = field(default_factory=list) - events: list[CddlEvent] = field(default_factory=list) + commands: List[CddlCommand] = field(default_factory=list) + types: List[CddlTypeDefinition] = field(default_factory=list) + enums: List[CddlEnum] = field(default_factory=list) + events: List[CddlEvent] = field(default_factory=list) @staticmethod def _convert_method_to_event_name(method_suffix: str) -> str: @@ -572,33 +548,7 @@ def _convert_method_to_event_name(method_suffix: str) -> str: s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", method_suffix) return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() - def _needs_field_import(self, enhancements: dict[str, Any] | None = None) -> bool: - """Check if any type definition in this module requires the 'field' import. - - Respects the same type exclusions applied during code generation. - """ - enhancements = enhancements or {} - extra_cls_names: set[str] = set() - for extra_cls in enhancements.get("extra_dataclasses", []): - m = re.search(r"^class\s+(\w+)", extra_cls, re.MULTILINE) - if m: - extra_cls_names.add(m.group(1)) - exclude_types = set(enhancements.get("exclude_types", [])) | extra_cls_names - - for type_def in self.types: - if type_def.name in exclude_types: - continue - for field_type in type_def.fields.values(): - # Literal string discriminants use field(default=..., init=False) - if re.match(r'^"', field_type.strip()): - return True - # List-typed fields use field(default_factory=list) - python_type = CddlTypeDefinition._get_python_type(field_type) - if python_type.startswith("list["): - return True - return False - - def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: + def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: """Generate Python code for this module. Args: @@ -608,18 +558,18 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: code = MODULE_HEADER.format(self.name) # Add imports if needed - if self.commands: - code += "from .common import command_builder\n" - dataclass_imported = False + if self.types: + code += "from dataclasses import field\n" if self.commands or self.types: + code += "from typing import Generator\n" code += "from dataclasses import dataclass\n" - dataclass_imported = True - if self.types and self._needs_field_import(enhancements): - code += "from dataclasses import field\n" # Add imports for event handling if needed if self.events: - code += "from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager\n" + code += "import threading\n" + code += "from collections.abc import Callable\n" + code += "from dataclasses import dataclass\n" + code += "from selenium.webdriver.common.bidi.session import Session\n" code += "\n\n" @@ -700,19 +650,8 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """ - # Collect names of extra_dataclasses so we can skip CDDL-generated - # enums and types that are overridden by manual definitions. - extra_cls_names = set() - for extra_cls in enhancements.get("extra_dataclasses", []): - m = re.search(r"^class\s+(\w+)", extra_cls, re.MULTILINE) - if m: - extra_cls_names.add(m.group(1)) - exclude_types = set(enhancements.get("exclude_types", [])) | extra_cls_names - - # Generate enums first, skipping any that are overridden via extra_dataclasses + # Generate enums first for enum_def in self.enums: - if enum_def.name in exclude_types: - continue code += enum_def.to_python_class() code += "\n\n" @@ -721,6 +660,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: code += f"{alias} = {target}\n\n" # Generate type dataclasses, skipping any overridden by extra_dataclasses + exclude_types = set(enhancements.get("exclude_types", [])) for type_def in self.types: if type_def.name in exclude_types: continue @@ -740,18 +680,13 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: # Generate EVENT_NAME_MAPPING for the module code += "# BiDi Event Name to Parameter Type Mapping\n" code += "EVENT_NAME_MAPPING = {\n" - # Collect event keys from extra_events so we skip CDDL duplicates - extra_event_keys = { - evt["event_key"] for evt in enhancements.get("extra_events", []) - } for event_def in self.events: # Convert method name to user-friendly event name # e.g., "browsingContext.contextCreated" -> "context_created" method_parts = event_def.method.split(".") if len(method_parts) == 2: event_name = self._convert_method_to_event_name(method_parts[1]) - if event_name not in extra_event_keys: - code += f' "{event_name}": "{event_def.method}",\n' + code += f' "{event_name}": "{event_def.method}",\n' # Extra events not in the CDDL spec (e.g. Chromium-specific events) for extra_evt in enhancements.get("extra_events", []): code += ( @@ -797,7 +732,1094 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """ code += "\n\n" - # EventConfig, _EventWrapper, and _EventManager are imported from - # ._event_manager (see the import block above); nothing to emit here. + # Generate EventConfig and _EventManager for modules with events + if self.events: + # Generate EventConfig dataclass + code += """@dataclass +class EventConfig: + \"\"\"Configuration for a BiDi event.\"\"\" + event_key: str + bidi_event: str + event_class: type + + +""" + + # Generate _EventManager class + code += """class _EventWrapper: + \"\"\"Wrapper to provide event_class attribute for WebSocketConnection callbacks.\"\"\" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + \"\"\"Deserialize event params into the wrapped Python dataclass. + + Args: + params: Raw BiDi event params with camelCase keys. + + Returns: + An instance of the dataclass, or the raw dict on failure. + \"\"\" + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, \"from_json\") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend([\"_\", char.lower()]) + else: + result.append(char) + return \"\".join(result) + + +class _EventManager: + \"\"\"Manages event subscriptions and callbacks.\"\"\" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + \"\"\"Subscribe to a BiDi event if not already subscribed.\"\"\" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get(\"subscription\") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + \"callbacks\": [], + \"subscription_id\": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + \"\"\"Unsubscribe from a BiDi event if no more callbacks exist.\"\"\" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry[\"callbacks\"]: + session = Session(self.conn) + sub_id = entry.get(\"subscription_id\") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event][\"callbacks\"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry[\"callbacks\"]: + entry[\"callbacks\"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + \"\"\"Clear all event handlers.\"\"\" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry[\"callbacks\"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get(\"subscription_id\") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() + + +""" + code += "\n\n" # Generate class + # Convert module name (camelCase or snake_case) to proper class name (PascalCase) + class_name = module_name_to_class_name(self.name) + code += f"class {class_name}:\n" + code += f' """WebDriver BiDi {self.name} module."""\n\n' + + # Add EVENT_CONFIGS dict if there are events + if self.events: + code += ( + " EVENT_CONFIGS = {}\n" # Will be populated after types are defined + ) + + if self.name == "script": + code += " def __init__(self, conn, driver=None) -> None:\n" + code += " self._conn = conn\n" + code += " self._driver = driver\n" + else: + code += " def __init__(self, conn) -> None:\n" + code += " self._conn = conn\n" + + # Initialize _event_manager if there are events + if self.events: + code += " self._event_manager = _EventManager(conn, self.EVENT_CONFIGS)\n" + + # Append extra init code from enhancements (e.g. self.intercepts = []) + for init_line in enhancements.get("extra_init_code", []): + code += f" {init_line}\n" + + code += "\n" + + # Generate command methods + exclude_methods = enhancements.get("exclude_methods", []) + if self.commands: + for command in self.commands: + # Get method-specific enhancements + # Convert command name to snake_case to match enhancement manifest keys + method_name_snake = command._camel_to_snake(command.name) + if method_name_snake in exclude_methods: + continue + method_enhancements = enhancements.get(method_name_snake, {}) + code += command.to_python_method(method_enhancements) + code += "\n" + else: + code += " pass\n" + + # Emit extra methods from enhancement manifest + for extra_method in enhancements.get("extra_methods", []): + code += extra_method + code += "\n" + + # Add delegating event handler methods if events are present + if self.events: + code += """ + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + \"\"\"Add an event handler. + + Args: + event: The event to subscribe to. + callback: The callback function to execute on event. + contexts: The context IDs to subscribe to (optional). + + Returns: + The callback ID. + \"\"\" + return self._event_manager.add_event_handler(event, callback, contexts) + + def remove_event_handler(self, event: str, callback_id: int) -> None: + \"\"\"Remove an event handler. + + Args: + event: The event to unsubscribe from. + callback_id: The callback ID. + \"\"\" + return self._event_manager.remove_event_handler(event, callback_id) + + def clear_event_handlers(self) -> None: + \"\"\"Clear all event handlers.\"\"\" + return self._event_manager.clear_event_handlers() +""" + + # Generate event info type aliases AFTER the class definition + # This ensures all types are available when we create the aliases + if self.events: + code += "\n# Event Info Type Aliases\n" + for event_def in self.events: + code += event_def.to_python_dataclass() + code += "\n" + + # Now populate EVENT_CONFIGS after the aliases are defined + code += f"\n# Populate EVENT_CONFIGS with event configuration mappings\n" + # Use globals() to look up types dynamically to handle missing types gracefully + code += f"_globals = globals()\n" + code += f"{class_name}.EVENT_CONFIGS = {{\n" + for event_def in self.events: + # Convert method name to user-friendly event name + method_parts = event_def.method.split(".") + if len(method_parts) == 2: + event_name = self._convert_method_to_event_name(method_parts[1]) + # The event class is the event name (e.g., ContextCreated) + # Try to get it from globals, default to dict if not found + code += f' "{event_name}": (EventConfig("{event_name}", "{event_def.method}", _globals.get("{event_def.name}", dict)) if _globals.get("{event_def.name}") else EventConfig("{event_name}", "{event_def.method}", dict)),\n' + # Extra events not in the CDDL spec + for extra_evt in enhancements.get("extra_events", []): + ek = extra_evt["event_key"] + be = extra_evt["bidi_event"] + ec = extra_evt["event_class"] + code += f' "{ek}": EventConfig("{ek}", "{be}", _globals.get("{ec}", dict)),\n' + code += "}\n" + + return code + + +class CddlParser: + """Parse CDDL specification files.""" + + def __init__(self, cddl_path: str): + """Initialize parser with CDDL file path.""" + self.cddl_path = Path(cddl_path) + self.content = "" + self.modules: Dict[str, CddlModule] = {} + self.definitions: Dict[str, str] = {} + self.event_names: Set[str] = set() # Names of definitions that are events + self._read_file() + + def _read_file(self) -> None: + """Read and preprocess CDDL file.""" + if not self.cddl_path.exists(): + raise FileNotFoundError(f"CDDL file not found: {self.cddl_path}") + + with open(self.cddl_path, "r", encoding="utf-8") as f: + self.content = f.read() + + logger.info(f"Loaded CDDL file: {self.cddl_path}") + + def parse(self) -> Dict[str, CddlModule]: + """Parse CDDL content and return modules.""" + # Remove comments + content = self._remove_comments(self.content) + + # Extract all definitions + self._extract_definitions(content) + + # Extract event names from event union definitions + self._extract_event_names() + + # Extract type definitions by module + self._extract_types() + + # Extract event definitions by module + self._extract_events() + + # Extract command definitions by module + self._extract_commands() + + # If no modules found, create a default one from the filename + if not self.modules: + module_name = self.cddl_path.stem + default_module = CddlModule(name=module_name) + self.modules[module_name] = default_module + logger.warning(f"No modules found in CDDL, creating default: {module_name}") + + return self.modules + + def _remove_comments(self, content: str) -> str: + """Remove comments from CDDL content.""" + # CDDL uses ; for comments to end of line + lines = content.split("\n") + cleaned = [] + for line in lines: + if ";" in line and not line.strip().startswith(";"): + line = line[: line.index(";")] + elif line.strip().startswith(";"): + continue + cleaned.append(line) + return "\n".join(cleaned) + + def _extract_definitions(self, content: str) -> None: + """Extract CDDL definitions (type definitions, commands, etc.).""" + # Match pattern: Name = Definition + # Handles multiline definitions properly + pattern = r"(\w+(?:\.\w+)*)\s*=\s*(.+?)(?=\n\w+(?:\.\w+)?\s*=|\Z)" + + for match in re.finditer(pattern, content, re.DOTALL): + name = match.group(1).strip() + definition = match.group(2).strip() + self.definitions[name] = definition + logger.debug(f"Extracted definition: {name}") + + def _extract_event_names(self) -> None: + """Extract event names from event union definitions. + + Event union definitions follow pattern: + module.ModuleEvent = ( + module.EventName1 // + module.EventName2 // + ... + ) + """ + # Look for definitions like "BrowsingContextEvent", "SessionEvent", etc. + event_union_pattern = re.compile(r"(\w+\.)?(\w+)Event") + + for def_name, def_content in self.definitions.items(): + # Check if this looks like an event union (name ends with "Event") and + # contains a module-qualified reference like "module.EventName". + # Handles both single-item (no //) and multi-item (// separated) unions. + if "Event" in def_name and re.search(r"\w+\.\w+", def_content): + # Extract event names from the union (works for single and multi-item) + event_refs = re.findall(r"(\w+\.\w+)", def_content) + for event_ref in event_refs: + self.event_names.add(event_ref) + logger.debug(f"Identified event: {event_ref} (from {def_name})") + + def _extract_types(self) -> None: + """Extract type definitions from parsed definitions.""" + # Type definitions follow pattern: module.TypeName = { field: type, ... } + # They have dots in the name and curly braces in the content + # But they DON'T have method: "..." pattern (which means it's not a command) + # Enums follow pattern: module.EnumName = "value1" / "value2" / ... + + for def_name, def_content in self.definitions.items(): + # Skip if not a namespaced name (e.g., skip "EmptyParams", "Extensible") + if "." not in def_name: + continue + + # Skip if it's a command (contains method: pattern) + if "method:" in def_content: + continue + + # Extract module.TypeName + if "." in def_name: + module_name, type_name = def_name.rsplit(".", 1) + + # Create module if not exists + if module_name not in self.modules: + self.modules[module_name] = CddlModule(name=module_name) + + # Check if this is an enum (string union with /) + if self._is_enum_definition(def_content): + # Extract enum values + values = self._extract_enum_values(def_content) + if values: + enum_def = CddlEnum( + module=module_name, + name=type_name, + values=values, + description=f"{type_name}", + ) + self.modules[module_name].enums.append(enum_def) + logger.debug( + f"Found enum: {def_name} with {len(values)} values" + ) + else: + # Extract fields from type definition + fields = self._extract_type_fields(def_content) + + if fields: # Only create type if it has fields + type_def = CddlTypeDefinition( + module=module_name, + name=type_name, + fields=fields, + description=f"{type_name}", + ) + self.modules[module_name].types.append(type_def) + logger.debug( + f"Found type: {def_name} with {len(fields)} fields" + ) + + def _is_enum_definition(self, definition: str) -> bool: + """Check if a definition is an enum (string union with /). + + Enums are defined as: "value1" / "value2" / "value3" + """ + # Clean whitespace + clean_def = definition.strip() + + # Must not have curly braces (that would be a type definition) + if "{" in clean_def or "}" in clean_def: + return False + + # Must contain the union operator / surrounded by quotes + # Pattern: "something" / "something_else" + return " / " in clean_def and '"' in clean_def + + def _extract_enum_values(self, enum_definition: str) -> List[str]: + """Extract individual values from an enum definition. + + Enums are defined as: "value1" / "value2" / "value3" + Can span multiple lines. + """ + values = [] + + # Clean the definition and extract quoted strings + # Split by / and extract quoted values + parts = enum_definition.split("/") + + for part in parts: + part = part.strip() + + # Extract quoted string - use search instead of match to find quotes anywhere + match = re.search(r'"([^"]*)"', part) + if match: + value = match.group(1) + values.append(value) + logger.debug(f"Extracted enum value: {value}") + + return values + + @staticmethod + def _normalize_cddl_type(field_type: str) -> str: + """Normalize a CDDL type expression to a simple Python-compatible form. + + Strips CDDL control operators (.ge, .le, .gt, .lt, .default, etc.) and + replaces interval/constraint expressions with their base types so that + the caller can safely check for nested struct syntax. + + Examples: + '(float .ge 0.0) .default 1.0' -> 'float' + '(float .ge 0.0) / null' -> 'float / null' + '(0.0...360.0) / null' -> 'float / null' + '-90.0..90.0' -> 'float' + 'float / null .default null' -> 'float / null' + """ + result = field_type + # Remove trailing .default annotations + result = re.sub(r"\s*\.default\s+\S+", "", result) + # Replace parenthesised constraint expressions: (baseType .operator ...) -> baseType + result = re.sub(r"\((\w+)\s+\.\w+[^)]*\)", r"\1", result) + # Replace parenthesised numeric interval types: (0.0...360.0) -> float + result = re.sub(r"\(-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?\)", "float", result) + # Replace bare numeric interval types: -90.0..90.0 -> float + result = re.sub(r"-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?", "float", result) + return result.strip() + + def _extract_type_fields(self, type_definition: str) -> Dict[str, str]: + """Extract fields from a type definition block.""" + fields = {} + + # Remove outer braces + clean_def = type_definition.strip() + if clean_def.startswith("{"): + clean_def = clean_def[1:] + if clean_def.endswith("}"): + clean_def = clean_def[:-1] + + # Parse each line for field: type patterns + for line in clean_def.split("\n"): + line = line.strip() + if not line or "Extensible" in line or line.startswith("//"): + continue + + # Match pattern: [?] fieldName: type + match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + if not match: + # Try without optional marker + match = re.match(r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + + if match: + field_name = match.group(1).strip() + field_type = match.group(2).strip() + normalized_type = self._normalize_cddl_type(field_type) + + # Skip lines that are part of nested definitions + if "{" not in normalized_type and "(" not in normalized_type: + fields[field_name] = normalized_type + logger.debug(f"Extracted field {field_name}: {normalized_type}") + + return fields + + def _extract_events(self) -> None: + """Extract event definitions from parsed definitions. + + Events are definitions that: + 1. Are listed in an event union (e.g., BrowsingContextEvent) + 2. Have method: "..." and params: ... fields + + Event pattern: module.EventName = (method: "module.eventName", params: module.ParamType) + """ + # Find definitions that are in the event_names set + event_pattern = re.compile( + r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" + ) + + for def_name, def_content in self.definitions.items(): + # Skip if not identified as an event + if def_name not in self.event_names: + continue + + # Extract method and params + match = event_pattern.search(def_content) + if match: + method = match.group(1) # e.g., "browsingContext.contextCreated" + params_type = match.group(2) # e.g., "browsingContext.Info" + + # Extract module name from method + if "." in method: + module_name, _ = method.split(".", 1) + + # Create module if not exists + if module_name not in self.modules: + self.modules[module_name] = CddlModule(name=module_name) + + # Extract event name from definition name (e.g., browsingContext.ContextCreated) + _, event_name = def_name.rsplit(".", 1) + + # Create event + event = CddlEvent( + module=module_name, + name=event_name, + method=method, + params_type=params_type, + description=f"Event: {method}", + ) + + self.modules[module_name].events.append(event) + logger.debug( + f"Found event: {def_name} (method={method}, params={params_type})" + ) + + def _extract_commands(self) -> None: + """Extract command definitions from parsed definitions.""" + # Find command definitions that follow pattern: module.Command = (method: "...", params: ...) + command_pattern = re.compile( + r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" + ) + + for def_name, def_content in self.definitions.items(): + # Skip definitions that are events (they share the same pattern) + if def_name in self.event_names: + continue + matches = list(command_pattern.finditer(def_content)) + if matches: + for match in matches: + method = match.group(1) # e.g., "session.new" + params_type = match.group(2) # e.g., "session.NewParameters" + + # Extract module name from method + if "." in method: + module_name, command_name = method.split(".", 1) + + # Create module if not exists + if module_name not in self.modules: + self.modules[module_name] = CddlModule(name=module_name) + + # Extract parameters + params = self._extract_parameters(params_type) + + # Create command + cmd = CddlCommand( + module=module_name, + name=command_name, + params=params, + description=f"Execute {method}", + ) + + self.modules[module_name].commands.append(cmd) + logger.debug( + f"Found command: {method} with params {params_type}" + ) + + def _extract_parameters( + self, params_type: str, _seen: Optional[Set[str]] = None + ) -> Dict[str, str]: + """Extract parameters from a parameter type definition. + + Handles both struct types ({...}) and top-level union types (TypeA / TypeB), + merging all fields from each alternative as optional parameters. + """ + params = {} + + if _seen is None: + _seen = set() + if params_type in _seen: + return params + _seen.add(params_type) + + if params_type not in self.definitions: + logger.debug(f"Parameter type not found: {params_type}") + return params + + definition = self.definitions[params_type] + + # Handle top-level type alias that is a union of other named types: + # e.g. session.UnsubscribeByAttributesRequest / session.UnsubscribeByIDRequest + # These definitions contain a single line with "/" separating type names + # (not the double-slash "//" used for command unions). + stripped = definition.strip() + if not stripped.startswith("{") and "/" in stripped and "//" not in stripped: + # Each token separated by "/" should be a named type reference + alternatives = [a.strip() for a in stripped.split("/") if a.strip()] + all_named = all(re.match(r"^[\w.]+$", a) for a in alternatives) + if all_named: + for alt_type in alternatives: + alt_params = self._extract_parameters(alt_type, _seen) + params.update(alt_params) + return params + + # Remove the outer curly braces and split by comma + # Then parse each line for key: type patterns + clean_def = stripped + if clean_def.startswith("{"): + clean_def = clean_def[1:] + if clean_def.endswith("}"): + clean_def = clean_def[:-1] + + # Split by newlines and process each line + for line in clean_def.split("\n"): + line = line.strip() + if not line or "Extensible" in line: + continue + + # Match pattern: [?] name: type + # Using a simple pattern that handles optional prefix + match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + if not match: + # Try without optional marker + match = re.match(r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + + if match: + param_name = match.group(1).strip() + param_type = match.group(2).strip() + normalized_type = self._normalize_cddl_type(param_type) + + # Skip lines that are part of nested definitions + if "{" not in normalized_type and "(" not in normalized_type: + params[param_name] = normalized_type + logger.debug( + f"Extracted param {param_name}: {normalized_type} from {params_type}" + ) + + return params + + +def module_name_to_class_name(module_name: str) -> str: + """Convert module name to class name (PascalCase). + + Handles both camelCase (browsingContext) and snake_case (browsing_context). + """ + if "_" in module_name: + # Snake_case: browsing_context -> BrowsingContext + return "".join(word.capitalize() for word in module_name.split("_")) + else: + # CamelCase: browsingContext -> BrowsingContext + return module_name[0].upper() + module_name[1:] if module_name else "" + + +def module_name_to_filename(module_name: str) -> str: + """Convert module name to Python filename (snake_case). + + Handles both camelCase (browsingContext) and snake_case (browsing_context). + Special cases: + - browsingContext -> browsing_context + - webExtension -> webextension + """ + # Handle explicit mappings for known camelCase names + camel_to_snake_map = { + "browsingContext": "browsing_context", + "webExtension": "webextension", + } + + if module_name in camel_to_snake_map: + return camel_to_snake_map[module_name] + + if "_" in module_name: + # Already snake_case + return module_name + else: + # Convert camelCase to snake_case for other cases + # This handles cases like "myModuleName" -> "my_module_name" + import re + + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", module_name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + + +def generate_init_file(output_path: Path, modules: Dict[str, CddlModule]) -> None: + """Generate __init__.py file for the module.""" + init_path = output_path / "__init__.py" + + code = f"""{SHARED_HEADER} + +from __future__ import annotations + +""" + + for module_name in sorted(modules.keys()): + class_name = module_name_to_class_name(module_name) + filename = module_name_to_filename(module_name) + code += f"from .{filename} import {class_name}\n" + + code += f"\n__all__ = [\n" + for module_name in sorted(modules.keys()): + class_name = module_name_to_class_name(module_name) + code += f' "{class_name}",\n' + code += "]\n" + + with open(init_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {init_path}") + + +def generate_common_file(output_path: Path) -> None: + """Generate common.py file with shared utilities.""" + common_path = output_path / "common.py" + + code = ( + "# Licensed to the Software Freedom Conservancy (SFC) under one\n" + "# or more contributor license agreements. See the NOTICE file\n" + "# distributed with this work for additional information\n" + "# regarding copyright ownership. The SFC licenses this file\n" + "# to you under the Apache License, Version 2.0 (the\n" + '# "License"); you may not use this file except in compliance\n' + "# with the License. You may obtain a copy of the License at\n" + "#\n" + "# http://www.apache.org/licenses/LICENSE-2.0\n" + "#\n" + "# Unless required by applicable law or agreed to in writing,\n" + "# software distributed under the License is distributed on an\n" + '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' + "# KIND, either express or implied. See the License for the\n" + "# specific language governing permissions and limitations\n" + "# under the License.\n" + "\n" + '"""Common utilities for BiDi command construction."""\n' + "\n" + "from typing import Any, Dict, Generator\n" + "\n" + "\n" + "def command_builder(\n" + " method: str, params: Dict[str, Any]\n" + ") -> Generator[Dict[str, Any], Any, Any]:\n" + ' """Build a BiDi command generator.\n' + "\n" + " Args:\n" + ' method: The BiDi method name (e.g., "session.status", "browser.close")\n' + " params: The parameters for the command\n" + "\n" + " Yields:\n" + " A dictionary representing the BiDi command\n" + "\n" + " Returns:\n" + " The result from the BiDi command execution\n" + ' """\n' + ' result = yield {"method": method, "params": params}\n' + " return result\n" + ) + + with open(common_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {common_path}") + + +def generate_console_file(output_path: Path) -> None: + """Generate console.py file with Console enum helper.""" + console_path = output_path / "console.py" + + code = ( + "# Licensed to the Software Freedom Conservancy (SFC) under one\n" + "# or more contributor license agreements. See the NOTICE file\n" + "# distributed with this work for additional information\n" + "# regarding copyright ownership. The SFC licenses this file\n" + "# to you under the Apache License, Version 2.0 (the\n" + '# "License"); you may not use this file except in compliance\n' + "# with the License. You may obtain a copy of the License at\n" + "#\n" + "# http://www.apache.org/licenses/LICENSE-2.0\n" + "#\n" + "# Unless required by applicable law or agreed to in writing,\n" + "# software distributed under the License is distributed on an\n" + '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' + "# KIND, either express or implied. See the License for the\n" + "# specific language governing permissions and limitations\n" + "# under the License.\n" + "\n" + "from enum import Enum\n" + "\n" + "\n" + "class Console(Enum):\n" + ' ALL = "all"\n' + ' LOG = "log"\n' + ' ERROR = "error"\n' + ) + + with open(console_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {console_path}") + + +def generate_permissions_file(output_path: Path) -> None: + """Generate permissions.py file with permission-related classes.""" + permissions_path = output_path / "permissions.py" + + code = ( + "# Licensed to the Software Freedom Conservancy (SFC) under one\n" + "# or more contributor license agreements. See the NOTICE file\n" + "# distributed with this work for additional information\n" + "# regarding copyright ownership. The SFC licenses this file\n" + "# to you under the Apache License, Version 2.0 (the\n" + '# "License"); you may not use this file except in compliance\n' + "# with the License. You may obtain a copy of the License at\n" + "#\n" + "# http://www.apache.org/licenses/LICENSE-2.0\n" + "#\n" + "# Unless required by applicable law or agreed to in writing,\n" + "# software distributed under the License is distributed on an\n" + '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' + "# KIND, either express or implied. See the License for the\n" + "# specific language governing permissions and limitations\n" + "# under the License.\n" + "\n" + '"""WebDriver BiDi Permissions module."""\n' + "\n" + "from __future__ import annotations\n" + "\n" + "from enum import Enum\n" + "from typing import Any, Optional, Union\n" + "\n" + "from .common import command_builder\n" + "\n" + '_VALID_PERMISSION_STATES = {"granted", "denied", "prompt"}\n' + "\n" + "\n" + "class PermissionState(str, Enum):\n" + ' """Permission state enumeration."""\n' + "\n" + ' GRANTED = "granted"\n' + ' DENIED = "denied"\n' + ' PROMPT = "prompt"\n' + "\n" + "\n" + "class PermissionDescriptor:\n" + ' """Descriptor for a permission."""\n' + "\n" + " def __init__(self, name: str) -> None:\n" + ' """Initialize a PermissionDescriptor.\n' + "\n" + " Args:\n" + " name: The name of the permission (e.g., 'geolocation', 'microphone', 'camera')\n" + ' """\n' + " self.name = name\n" + "\n" + " def __repr__(self) -> str:\n" + " return f\"PermissionDescriptor('{self.name}')\"\n" + "\n" + "\n" + "class Permissions:\n" + ' """WebDriver BiDi Permissions module."""\n' + "\n" + " def __init__(self, websocket_connection: Any) -> None:\n" + ' """Initialize the Permissions module.\n' + "\n" + " Args:\n" + " websocket_connection: The WebSocket connection for sending BiDi commands\n" + ' """\n' + " self._conn = websocket_connection\n" + "\n" + " def set_permission(\n" + " self,\n" + " descriptor: Union[PermissionDescriptor, str],\n" + " state: Union[PermissionState, str],\n" + " origin: Optional[str] = None,\n" + " user_context: Optional[str] = None,\n" + " ) -> None:\n" + ' """Set a permission for a given origin.\n' + "\n" + " Args:\n" + " descriptor: The permission descriptor or permission name as a string\n" + " state: The desired permission state\n" + " origin: The origin for which to set the permission\n" + " user_context: Optional user context ID to scope the permission\n" + "\n" + " Raises:\n" + " ValueError: If the state is not a valid permission state\n" + ' """\n' + " state_value = state.value if isinstance(state, PermissionState) else state\n" + " if state_value not in _VALID_PERMISSION_STATES:\n" + " raise ValueError(\n" + ' f"Invalid permission state: {state_value!r}. "\n' + ' f"Must be one of {sorted(_VALID_PERMISSION_STATES)}"\n' + " )\n" + "\n" + " if isinstance(descriptor, str):\n" + ' descriptor_dict = {"name": descriptor}\n' + " else:\n" + ' descriptor_dict = {"name": descriptor.name}\n' + "\n" + " params: dict[str, Any] = {\n" + ' "descriptor": descriptor_dict,\n' + ' "state": state_value,\n' + " }\n" + " if origin is not None:\n" + ' params["origin"] = origin\n' + " if user_context is not None:\n" + ' params["userContext"] = user_context\n' + "\n" + ' cmd = command_builder("permissions.setPermission", params)\n' + " self._conn.execute(cmd)\n" + ) + + with open(permissions_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {permissions_path}") + + +def main( + cddl_file: str, + output_dir: str, + spec_version: str = "1.0", + enhancements_manifest: Optional[str] = None, +) -> None: + """Main entry point. + + Args: + cddl_file: Path to CDDL specification file + output_dir: Output directory for generated modules + spec_version: BiDi spec version + enhancements_manifest: Path to enhancement manifest Python file + """ + output_path = Path(output_dir).resolve() + output_path.mkdir(parents=True, exist_ok=True) + + logger.info(f"WebDriver BiDi Code Generator v{__version__}") + logger.info(f"Input CDDL: {cddl_file}") + logger.info(f"Output directory: {output_path}") + logger.info(f"Spec version: {spec_version}") + + # Load enhancement manifest + manifest = load_enhancements_manifest(enhancements_manifest) + if manifest: + logger.info(f"Loaded enhancement manifest from: {enhancements_manifest}") + + # Parse CDDL + parser = CddlParser(cddl_file) + modules = parser.parse() + + logger.info(f"Parsed {len(modules)} modules") + + # Clean up existing generated files + for file_path in output_path.glob("*.py"): + if file_path.name != "py.typed" and not file_path.name.startswith("_"): + file_path.unlink() + logger.debug(f"Removed: {file_path}") + + # Generate module files using snake_case filenames + for module_name, module in sorted(modules.items()): + filename = module_name_to_filename(module_name) + module_path = output_path / f"{filename}.py" + + # Get module-specific enhancements (merge with dataclass templates) + module_enhancements = manifest.get("enhancements", {}).get(module_name, {}) + + # Add dataclass methods and docstrings to the enhancement data for this module + full_module_enhancements = { + **module_enhancements, + "dataclass_methods": manifest.get("dataclass_methods", {}), + "method_docstrings": manifest.get("method_docstrings", {}), + } + + with open(module_path, "w", encoding="utf-8") as f: + f.write(module.generate_code(full_module_enhancements)) + logger.info(f"Generated: {module_path}") + + # Generate __init__.py + generate_init_file(output_path, modules) + + # Generate common.py + generate_common_file(output_path) + + # Generate permissions.py + generate_permissions_file(output_path) + + # Generate console.py + generate_console_file(output_path) + + # Create py.typed marker + py_typed_path = output_path / "py.typed" + py_typed_path.touch() + logger.info(f"Generated type marker: {py_typed_path}") + + logger.info("Code generation complete!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate Python WebDriver BiDi modules from CDDL specification" + ) + parser.add_argument( + "cddl_file", + help="Path to CDDL specification file", + ) + parser.add_argument( + "output_dir", + help="Output directory for generated Python modules", + ) + parser.add_argument( + "spec_version", + nargs="?", + default="1.0", + help="BiDi spec version (default: 1.0)", + ) + parser.add_argument( + "--enhancements-manifest", + default=None, + help="Path to enhancement manifest Python file (optional)", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Enable verbose logging", + ) + + args = parser.parse_args() + + if args.verbose: + logging.getLogger("generate_bidi").setLevel(logging.DEBUG) + + try: + main( + args.cddl_file, + args.output_dir, + args.spec_version, + args.enhancements_manifest, + ) + sys.exit(0) + except Exception as e: + logger.error(f"Generation failed: {e}", exc_info=True) + sys.exit(1) diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index adf0a17128af3..5dcce3c25ffeb 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -1351,18 +1351,24 @@ def to_bidi_dict(self) -> dict: params = {"extensionData": extension_data} cmd = command_builder("webExtension.install", params) return self._conn.execute(cmd)''', - ''' def uninstall(self, extension: Any | None = None): + ''' def uninstall(self, extension: str | dict): """Uninstall a web extension. Args: extension: Either the extension ID string returned by ``install``, or the full result dict returned by ``install`` (the ``"extension"`` value is extracted automatically). + + Raises: + ValueError: If extension is not provided or is None. """ if isinstance(extension, dict): extension = extension.get("extension") + + if extension is None: + raise ValueError("extension parameter is required") + params = {"extension": extension} - params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("webExtension.uninstall", params) return self._conn.execute(cmd)''', ], diff --git a/py/selenium/webdriver/common/bidi/__init__.py b/py/selenium/webdriver/common/bidi/__init__.py index bb129d5f6a195..7be7bd4f73856 100644 --- a/py/selenium/webdriver/common/bidi/__init__.py +++ b/py/selenium/webdriver/common/bidi/__init__.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index ff0c2d59b8cf2..7cf9678c9b007 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,10 +6,11 @@ # WebDriver BiDi module: browser from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any - +from typing import Any, Dict, List, Optional, Union from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass def transform_download_params( @@ -77,9 +61,17 @@ def validate_download_behavior( raise ValueError("destination_folder should not be provided when allowed=False") +class ClientWindowNamedState: + """ClientWindowNamedState.""" + + FULLSCREEN = "fullscreen" + MAXIMIZED = "maximized" + MINIMIZED = "minimized" + + @dataclass class ClientWindowInfo: - """ClientWindowInfo type definition.""" + """ClientWindowInfo.""" active: bool | None = None client_window: Any | None = None @@ -121,14 +113,14 @@ def get_y(self): @dataclass class UserContextInfo: - """UserContextInfo type definition.""" + """UserContextInfo.""" user_context: Any | None = None @dataclass class CreateUserContextParameters: - """CreateUserContextParameters type definition.""" + """CreateUserContextParameters.""" accept_insecure_certs: bool | None = None proxy: Any | None = None @@ -137,35 +129,35 @@ class CreateUserContextParameters: @dataclass class GetClientWindowsResult: - """GetClientWindowsResult type definition.""" + """GetClientWindowsResult.""" - client_windows: list[Any | None] | None = field(default_factory=list) + client_windows: list[Any | None] | None = None @dataclass class GetUserContextsResult: - """GetUserContextsResult type definition.""" + """GetUserContextsResult.""" - user_contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = None @dataclass class RemoveUserContextParameters: - """RemoveUserContextParameters type definition.""" + """RemoveUserContextParameters.""" user_context: Any | None = None @dataclass class SetClientWindowStateParameters: - """SetClientWindowStateParameters type definition.""" + """SetClientWindowStateParameters.""" client_window: Any | None = None @dataclass class ClientWindowRectState: - """ClientWindowRectState type definition.""" + """ClientWindowRectState.""" state: str = field(default="normal", init=False) width: Any | None = None @@ -176,15 +168,15 @@ class ClientWindowRectState: @dataclass class SetDownloadBehaviorParameters: - """SetDownloadBehaviorParameters type definition.""" + """SetDownloadBehaviorParameters.""" download_behavior: Any | None = None - user_contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = None @dataclass class DownloadBehaviorAllowed: - """DownloadBehaviorAllowed type definition.""" + """DownloadBehaviorAllowed.""" type: str = field(default="allowed", init=False) destination_folder: str | None = None @@ -192,7 +184,7 @@ class DownloadBehaviorAllowed: @dataclass class DownloadBehaviorDenied: - """DownloadBehaviorDenied type definition.""" + """DownloadBehaviorDenied.""" type: str = field(default="denied", init=False) @@ -220,12 +212,7 @@ def close(self): result = self._conn.execute(cmd) return result - def create_user_context( - self, - accept_insecure_certs: bool | None = None, - proxy: Any | None = None, - unhandled_prompt_behavior: Any | None = None, - ): + def create_user_context(self, accept_insecure_certs: bool | None = None, proxy: Any | None = None, unhandled_prompt_behavior: Any | None = None): """Execute browser.createUserContext.""" if proxy and hasattr(proxy, 'to_bidi_dict'): proxy = proxy.to_bidi_dict() @@ -306,6 +293,22 @@ def set_client_window_state(self, client_window: Any | None = None): result = self._conn.execute(cmd) return result + def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): + """Execute browser.setDownloadBehavior.""" + validate_download_behavior(allowed=allowed, destination_folder=destination_folder, user_contexts=user_contexts) + + download_behavior = None + download_behavior = transform_download_params(allowed, destination_folder) + + params = { + "downloadBehavior": download_behavior, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.setDownloadBehavior", params) + result = self._conn.execute(cmd) + return result + def set_download_behavior( self, allowed: bool | None = None, diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 7a0f8faf8687e..35aea615d1780 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,15 +6,16 @@ # WebDriver BiDi module: browsingContext from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - +from dataclasses import dataclass from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class ReadinessState: """ReadinessState.""" @@ -65,7 +49,7 @@ class DownloadCompleteParams: @dataclass class Info: - """Info type definition.""" + """Info.""" children: Any | None = None client_window: Any | None = None @@ -78,7 +62,7 @@ class Info: @dataclass class AccessibilityLocator: - """AccessibilityLocator type definition.""" + """AccessibilityLocator.""" type: str = field(default="accessibility", init=False) name: str | None = None @@ -87,7 +71,7 @@ class AccessibilityLocator: @dataclass class CssLocator: - """CssLocator type definition.""" + """CssLocator.""" type: str = field(default="css", init=False) value: str | None = None @@ -95,7 +79,7 @@ class CssLocator: @dataclass class ContextLocator: - """ContextLocator type definition.""" + """ContextLocator.""" type: str = field(default="context", init=False) context: Any | None = None @@ -103,7 +87,7 @@ class ContextLocator: @dataclass class InnerTextLocator: - """InnerTextLocator type definition.""" + """InnerTextLocator.""" type: str = field(default="innerText", init=False) value: str | None = None @@ -114,7 +98,7 @@ class InnerTextLocator: @dataclass class XPathLocator: - """XPathLocator type definition.""" + """XPathLocator.""" type: str = field(default="xpath", init=False) value: str | None = None @@ -122,7 +106,7 @@ class XPathLocator: @dataclass class BaseNavigationInfo: - """BaseNavigationInfo type definition.""" + """BaseNavigationInfo.""" context: Any | None = None navigation: Any | None = None @@ -132,14 +116,14 @@ class BaseNavigationInfo: @dataclass class ActivateParameters: - """ActivateParameters type definition.""" + """ActivateParameters.""" context: Any | None = None @dataclass class CaptureScreenshotParameters: - """CaptureScreenshotParameters type definition.""" + """CaptureScreenshotParameters.""" context: Any | None = None format: Any | None = None @@ -148,7 +132,7 @@ class CaptureScreenshotParameters: @dataclass class ImageFormat: - """ImageFormat type definition.""" + """ImageFormat.""" type: str | None = None quality: Any | None = None @@ -156,7 +140,7 @@ class ImageFormat: @dataclass class ElementClipRectangle: - """ElementClipRectangle type definition.""" + """ElementClipRectangle.""" type: str = field(default="element", init=False) element: Any | None = None @@ -164,7 +148,7 @@ class ElementClipRectangle: @dataclass class BoxClipRectangle: - """BoxClipRectangle type definition.""" + """BoxClipRectangle.""" type: str = field(default="box", init=False) x: Any | None = None @@ -175,14 +159,14 @@ class BoxClipRectangle: @dataclass class CaptureScreenshotResult: - """CaptureScreenshotResult type definition.""" + """CaptureScreenshotResult.""" data: str | None = None @dataclass class CloseParameters: - """CloseParameters type definition.""" + """CloseParameters.""" context: Any | None = None prompt_unload: bool | None = None @@ -190,7 +174,7 @@ class CloseParameters: @dataclass class CreateParameters: - """CreateParameters type definition.""" + """CreateParameters.""" type: Any | None = None reference_context: Any | None = None @@ -200,14 +184,14 @@ class CreateParameters: @dataclass class CreateResult: - """CreateResult type definition.""" + """CreateResult.""" context: Any | None = None @dataclass class GetTreeParameters: - """GetTreeParameters type definition.""" + """GetTreeParameters.""" max_depth: Any | None = None root: Any | None = None @@ -215,14 +199,14 @@ class GetTreeParameters: @dataclass class GetTreeResult: - """GetTreeResult type definition.""" + """GetTreeResult.""" contexts: Any | None = None @dataclass class HandleUserPromptParameters: - """HandleUserPromptParameters type definition.""" + """HandleUserPromptParameters.""" context: Any | None = None accept: bool | None = None @@ -231,24 +215,24 @@ class HandleUserPromptParameters: @dataclass class LocateNodesParameters: - """LocateNodesParameters type definition.""" + """LocateNodesParameters.""" context: Any | None = None locator: Any | None = None serialization_options: Any | None = None - start_nodes: list[Any | None] | None = field(default_factory=list) + start_nodes: list[Any | None] | None = None @dataclass class LocateNodesResult: - """LocateNodesResult type definition.""" + """LocateNodesResult.""" - nodes: list[Any | None] | None = field(default_factory=list) + nodes: list[Any | None] | None = None @dataclass class NavigateParameters: - """NavigateParameters type definition.""" + """NavigateParameters.""" context: Any | None = None url: str | None = None @@ -257,7 +241,7 @@ class NavigateParameters: @dataclass class NavigateResult: - """NavigateResult type definition.""" + """NavigateResult.""" navigation: Any | None = None url: str | None = None @@ -265,7 +249,7 @@ class NavigateResult: @dataclass class PrintParameters: - """PrintParameters type definition.""" + """PrintParameters.""" context: Any | None = None background: bool | None = None @@ -277,7 +261,7 @@ class PrintParameters: @dataclass class PrintMarginParameters: - """PrintMarginParameters type definition.""" + """PrintMarginParameters.""" bottom: Any | None = None left: Any | None = None @@ -287,7 +271,7 @@ class PrintMarginParameters: @dataclass class PrintPageParameters: - """PrintPageParameters type definition.""" + """PrintPageParameters.""" height: Any | None = None width: Any | None = None @@ -295,14 +279,14 @@ class PrintPageParameters: @dataclass class PrintResult: - """PrintResult type definition.""" + """PrintResult.""" data: str | None = None @dataclass class ReloadParameters: - """ReloadParameters type definition.""" + """ReloadParameters.""" context: Any | None = None ignore_cache: bool | None = None @@ -311,17 +295,17 @@ class ReloadParameters: @dataclass class SetViewportParameters: - """SetViewportParameters type definition.""" + """SetViewportParameters.""" context: Any | None = None viewport: Any | None = None device_pixel_ratio: Any | None = None - user_contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = None @dataclass class Viewport: - """Viewport type definition.""" + """Viewport.""" width: Any | None = None height: Any | None = None @@ -329,7 +313,7 @@ class Viewport: @dataclass class TraverseHistoryParameters: - """TraverseHistoryParameters type definition.""" + """TraverseHistoryParameters.""" context: Any | None = None delta: Any | None = None @@ -337,16 +321,30 @@ class TraverseHistoryParameters: @dataclass class HistoryUpdatedParameters: - """HistoryUpdatedParameters type definition.""" + """HistoryUpdatedParameters.""" context: Any | None = None timestamp: Any | None = None url: str | None = None +@dataclass +class DownloadWillBeginParams: + """DownloadWillBeginParams.""" + + suggested_filename: str | None = None + + +@dataclass +class DownloadCanceledParams: + """DownloadCanceledParams.""" + + status: str = field(default="canceled", init=False) + + @dataclass class UserPromptClosedParameters: - """UserPromptClosedParameters type definition.""" + """UserPromptClosedParameters.""" context: Any | None = None accepted: bool | None = None @@ -356,7 +354,7 @@ class UserPromptClosedParameters: @dataclass class UserPromptOpenedParameters: - """UserPromptOpenedParameters type definition.""" + """UserPromptOpenedParameters.""" context: Any | None = None handler: Any | None = None @@ -392,10 +390,10 @@ class DownloadParams: class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" - download_params: DownloadParams | None = None + download_params: "DownloadParams | None" = None @classmethod - def from_json(cls, params: dict) -> DownloadEndParams: + def from_json(cls, params: dict) -> "DownloadEndParams": """Deserialize from BiDi wire-level params dict.""" dp = DownloadParams( status=params.get("status"), @@ -416,6 +414,8 @@ def from_json(cls, params: dict) -> DownloadEndParams: "history_updated": "browsingContext.historyUpdated", "dom_content_loaded": "browsingContext.domContentLoaded", "load": "browsingContext.load", + "download_will_begin": "browsingContext.downloadWillBegin", + "download_end": "browsingContext.downloadEnd", "navigation_aborted": "browsingContext.navigationAborted", "navigation_committed": "browsingContext.navigationCommitted", "navigation_failed": "browsingContext.navigationFailed", @@ -630,13 +630,7 @@ def activate(self, context: Any | None = None): result = self._conn.execute(cmd) return result - def capture_screenshot( - self, - context: str | None = None, - format: Any | None = None, - clip: Any | None = None, - origin: str | None = None, - ): + def capture_screenshot(self, context: str | None = None, format: Any | None = None, clip: Any | None = None, origin: str | None = None): """Execute browsingContext.captureScreenshot.""" params = { "context": context, @@ -663,13 +657,7 @@ def close(self, context: Any | None = None, prompt_unload: bool | None = None): result = self._conn.execute(cmd) return result - def create( - self, - type: Any | None = None, - reference_context: Any | None = None, - background: bool | None = None, - user_context: Any | None = None, - ): + def create(self, type: Any | None = None, reference_context: Any | None = None, background: bool | None = None, user_context: Any | None = None): """Execute browsingContext.create.""" params = { "type": type, @@ -723,14 +711,7 @@ def handle_user_prompt(self, context: Any | None = None, accept: bool | None = N result = self._conn.execute(cmd) return result - def locate_nodes( - self, - context: str | None = None, - locator: Any | None = None, - serialization_options: Any | None = None, - start_nodes: Any | None = None, - max_node_count: int | None = None, - ): + def locate_nodes(self, context: str | None = None, locator: Any | None = None, serialization_options: Any | None = None, start_nodes: Any | None = None, max_node_count: int | None = None): """Execute browsingContext.locateNodes.""" params = { "context": context, @@ -759,15 +740,7 @@ def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any result = self._conn.execute(cmd) return result - def print( - self, - context: Any | None = None, - background: bool | None = None, - margin: Any | None = None, - page: Any | None = None, - scale: Any | None = None, - shrink_to_fit: bool | None = None, - ): + def print(self, context: Any | None = None, background: bool | None = None, margin: Any | None = None, page: Any | None = None, scale: Any | None = None, shrink_to_fit: bool | None = None): """Execute browsingContext.print.""" params = { "context": context, @@ -797,13 +770,7 @@ def reload(self, context: Any | None = None, ignore_cache: bool | None = None, w result = self._conn.execute(cmd) return result - def set_viewport( - self, - context: str | None = None, - viewport: Any | None = None, - user_contexts: Any | None = None, - device_pixel_ratio: Any | None = None, - ): + def set_viewport(self, context: str | None = None, viewport: Any | None = None, user_contexts: Any | None = None, device_pixel_ratio: Any | None = None): """Execute browsingContext.setViewport.""" params = { "context": context, @@ -901,81 +868,20 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() BrowsingContext.EVENT_CONFIGS = { - "context_created": ( - EventConfig("context_created", "browsingContext.contextCreated", - _globals.get("ContextCreated", dict)) - if _globals.get("ContextCreated") - else EventConfig("context_created", "browsingContext.contextCreated", dict) - ), - "context_destroyed": ( - EventConfig("context_destroyed", "browsingContext.contextDestroyed", - _globals.get("ContextDestroyed", dict)) - if _globals.get("ContextDestroyed") - else EventConfig("context_destroyed", "browsingContext.contextDestroyed", dict) - ), - "navigation_started": ( - EventConfig("navigation_started", "browsingContext.navigationStarted", - _globals.get("NavigationStarted", dict)) - if _globals.get("NavigationStarted") - else EventConfig("navigation_started", "browsingContext.navigationStarted", dict) - ), - "fragment_navigated": ( - EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", - _globals.get("FragmentNavigated", dict)) - if _globals.get("FragmentNavigated") - else EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", dict) - ), - "history_updated": ( - EventConfig("history_updated", "browsingContext.historyUpdated", - _globals.get("HistoryUpdated", dict)) - if _globals.get("HistoryUpdated") - else EventConfig("history_updated", "browsingContext.historyUpdated", dict) - ), - "dom_content_loaded": ( - EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", - _globals.get("DomContentLoaded", dict)) - if _globals.get("DomContentLoaded") - else EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", dict) - ), - "load": ( - EventConfig("load", "browsingContext.load", - _globals.get("Load", dict)) - if _globals.get("Load") - else EventConfig("load", "browsingContext.load", dict) - ), - "navigation_aborted": ( - EventConfig("navigation_aborted", "browsingContext.navigationAborted", - _globals.get("NavigationAborted", dict)) - if _globals.get("NavigationAborted") - else EventConfig("navigation_aborted", "browsingContext.navigationAborted", dict) - ), - "navigation_committed": ( - EventConfig("navigation_committed", "browsingContext.navigationCommitted", - _globals.get("NavigationCommitted", dict)) - if _globals.get("NavigationCommitted") - else EventConfig("navigation_committed", "browsingContext.navigationCommitted", dict) - ), - "navigation_failed": ( - EventConfig("navigation_failed", "browsingContext.navigationFailed", - _globals.get("NavigationFailed", dict)) - if _globals.get("NavigationFailed") - else EventConfig("navigation_failed", "browsingContext.navigationFailed", dict) - ), - "user_prompt_closed": ( - EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", - _globals.get("UserPromptClosed", dict)) - if _globals.get("UserPromptClosed") - else EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", dict) - ), - "user_prompt_opened": ( - EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", - _globals.get("UserPromptOpened", dict)) - if _globals.get("UserPromptOpened") - else EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", dict) - ), - "download_will_begin": EventConfig( - "download_will_begin", "browsingContext.downloadWillBegin", - _globals.get("DownloadWillBeginParams", dict), - ), + "context_created": (EventConfig("context_created", "browsingContext.contextCreated", _globals.get("ContextCreated", dict)) if _globals.get("ContextCreated") else EventConfig("context_created", "browsingContext.contextCreated", dict)), + "context_destroyed": (EventConfig("context_destroyed", "browsingContext.contextDestroyed", _globals.get("ContextDestroyed", dict)) if _globals.get("ContextDestroyed") else EventConfig("context_destroyed", "browsingContext.contextDestroyed", dict)), + "navigation_started": (EventConfig("navigation_started", "browsingContext.navigationStarted", _globals.get("NavigationStarted", dict)) if _globals.get("NavigationStarted") else EventConfig("navigation_started", "browsingContext.navigationStarted", dict)), + "fragment_navigated": (EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", _globals.get("FragmentNavigated", dict)) if _globals.get("FragmentNavigated") else EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", dict)), + "history_updated": (EventConfig("history_updated", "browsingContext.historyUpdated", _globals.get("HistoryUpdated", dict)) if _globals.get("HistoryUpdated") else EventConfig("history_updated", "browsingContext.historyUpdated", dict)), + "dom_content_loaded": (EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", _globals.get("DomContentLoaded", dict)) if _globals.get("DomContentLoaded") else EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", dict)), + "load": (EventConfig("load", "browsingContext.load", _globals.get("Load", dict)) if _globals.get("Load") else EventConfig("load", "browsingContext.load", dict)), + "download_will_begin": (EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBegin", dict)) if _globals.get("DownloadWillBegin") else EventConfig("download_will_begin", "browsingContext.downloadWillBegin", dict)), + "download_end": (EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEnd", dict)) if _globals.get("DownloadEnd") else EventConfig("download_end", "browsingContext.downloadEnd", dict)), + "navigation_aborted": (EventConfig("navigation_aborted", "browsingContext.navigationAborted", _globals.get("NavigationAborted", dict)) if _globals.get("NavigationAborted") else EventConfig("navigation_aborted", "browsingContext.navigationAborted", dict)), + "navigation_committed": (EventConfig("navigation_committed", "browsingContext.navigationCommitted", _globals.get("NavigationCommitted", dict)) if _globals.get("NavigationCommitted") else EventConfig("navigation_committed", "browsingContext.navigationCommitted", dict)), + "navigation_failed": (EventConfig("navigation_failed", "browsingContext.navigationFailed", _globals.get("NavigationFailed", dict)) if _globals.get("NavigationFailed") else EventConfig("navigation_failed", "browsingContext.navigationFailed", dict)), + "user_prompt_closed": (EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", _globals.get("UserPromptClosed", dict)) if _globals.get("UserPromptClosed") else EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", dict)), + "user_prompt_opened": (EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", _globals.get("UserPromptOpened", dict)) if _globals.get("UserPromptOpened") else EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", dict)), + "download_will_begin": EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBeginParams", dict)), "download_end": EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEndParams", dict)), } diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index dae051876833e..d90d8c770263a 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -17,15 +17,12 @@ """Common utilities for BiDi command construction.""" -from __future__ import annotations - -from collections.abc import Generator -from typing import Any +from typing import Any, Dict, Generator def command_builder( - method: str, params: dict[str, Any] | None = None -) -> Generator[dict[str, Any], Any, Any]: + method: str, params: Dict[str, Any] +) -> Generator[Dict[str, Any], Any, Any]: """Build a BiDi command generator. Args: @@ -38,7 +35,5 @@ def command_builder( Returns: The result from the BiDi command execution """ - if params is None: - params = {} result = yield {"method": method, "params": params} return result diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index c58f6d5f78d6c..a85eaad3e223a 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,10 +6,11 @@ # WebDriver BiDi module: emulation from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any - +from typing import Any, Dict, List, Optional, Union from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass class ForcedColorsModeTheme: @@ -54,24 +38,24 @@ class ScreenOrientationType: @dataclass class SetForcedColorsModeThemeOverrideParameters: - """SetForcedColorsModeThemeOverrideParameters type definition.""" + """SetForcedColorsModeThemeOverrideParameters.""" theme: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class SetGeolocationOverrideParameters: - """SetGeolocationOverrideParameters type definition.""" + """SetGeolocationOverrideParameters.""" - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class GeolocationCoordinates: - """GeolocationCoordinates type definition.""" + """GeolocationCoordinates.""" latitude: Any | None = None longitude: Any | None = None @@ -84,39 +68,39 @@ class GeolocationCoordinates: @dataclass class GeolocationPositionError: - """GeolocationPositionError type definition.""" + """GeolocationPositionError.""" type: str = field(default="positionUnavailable", init=False) @dataclass class SetLocaleOverrideParameters: - """SetLocaleOverrideParameters type definition.""" + """SetLocaleOverrideParameters.""" locale: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass -class SetNetworkConditionsParameters: - """SetNetworkConditionsParameters type definition.""" +class setNetworkConditionsParameters: + """setNetworkConditionsParameters.""" network_conditions: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class NetworkConditionsOffline: - """NetworkConditionsOffline type definition.""" + """NetworkConditionsOffline.""" type: str = field(default="offline", init=False) @dataclass class ScreenArea: - """ScreenArea type definition.""" + """ScreenArea.""" width: Any | None = None height: Any | None = None @@ -124,16 +108,16 @@ class ScreenArea: @dataclass class SetScreenSettingsOverrideParameters: - """SetScreenSettingsOverrideParameters type definition.""" + """SetScreenSettingsOverrideParameters.""" screen_area: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class ScreenOrientation: - """ScreenOrientation type definition.""" + """ScreenOrientation.""" natural: Any | None = None type: Any | None = None @@ -141,64 +125,64 @@ class ScreenOrientation: @dataclass class SetScreenOrientationOverrideParameters: - """SetScreenOrientationOverrideParameters type definition.""" + """SetScreenOrientationOverrideParameters.""" screen_orientation: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class SetUserAgentOverrideParameters: - """SetUserAgentOverrideParameters type definition.""" + """SetUserAgentOverrideParameters.""" user_agent: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class SetViewportMetaOverrideParameters: - """SetViewportMetaOverrideParameters type definition.""" + """SetViewportMetaOverrideParameters.""" viewport_meta: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class SetScriptingEnabledParameters: - """SetScriptingEnabledParameters type definition.""" + """SetScriptingEnabledParameters.""" enabled: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class SetScrollbarTypeOverrideParameters: - """SetScrollbarTypeOverrideParameters type definition.""" + """SetScrollbarTypeOverrideParameters.""" scrollbar_type: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class SetTimezoneOverrideParameters: - """SetTimezoneOverrideParameters type definition.""" + """SetTimezoneOverrideParameters.""" timezone: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class SetTouchOverrideParameters: - """SetTouchOverrideParameters type definition.""" + """SetTouchOverrideParameters.""" - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None class Emulation: @@ -207,12 +191,7 @@ class Emulation: def __init__(self, conn) -> None: self._conn = conn - def set_forced_colors_mode_theme_override( - self, - theme: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def set_forced_colors_mode_theme_override(self, theme: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setForcedColorsModeThemeOverride.""" params = { "theme": theme, @@ -224,12 +203,18 @@ def set_forced_colors_mode_theme_override( result = self._conn.execute(cmd) return result - def set_locale_override( - self, - locale: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def set_geolocation_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setGeolocationOverride.""" + params = { + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setGeolocationOverride", params) + result = self._conn.execute(cmd) + return result + + def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setLocaleOverride.""" params = { "locale": locale, @@ -241,12 +226,19 @@ def set_locale_override( result = self._conn.execute(cmd) return result - def set_screen_settings_override( - self, - screen_area: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def set_network_conditions(self, network_conditions: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setNetworkConditions.""" + params = { + "networkConditions": network_conditions, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setNetworkConditions", params) + result = self._conn.execute(cmd) + return result + + def set_screen_settings_override(self, screen_area: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setScreenSettingsOverride.""" params = { "screenArea": screen_area, @@ -258,12 +250,31 @@ def set_screen_settings_override( result = self._conn.execute(cmd) return result - def set_viewport_meta_override( - self, - viewport_meta: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def set_screen_orientation_override(self, screen_orientation: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setScreenOrientationOverride.""" + params = { + "screenOrientation": screen_orientation, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setScreenOrientationOverride", params) + result = self._conn.execute(cmd) + return result + + def set_user_agent_override(self, user_agent: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setUserAgentOverride.""" + params = { + "userAgent": user_agent, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setUserAgentOverride", params) + result = self._conn.execute(cmd) + return result + + def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setViewportMetaOverride.""" params = { "viewportMeta": viewport_meta, @@ -275,12 +286,19 @@ def set_viewport_meta_override( result = self._conn.execute(cmd) return result - def set_scrollbar_type_override( - self, - scrollbar_type: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def set_scripting_enabled(self, enabled: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setScriptingEnabled.""" + params = { + "enabled": enabled, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setScriptingEnabled", params) + result = self._conn.execute(cmd) + return result + + def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setScrollbarTypeOverride.""" params = { "scrollbarType": scrollbar_type, @@ -292,7 +310,19 @@ def set_scrollbar_type_override( result = self._conn.execute(cmd) return result - def set_touch_override(self, contexts: list[Any] | None = None, user_contexts: list[Any] | None = None): + def set_timezone_override(self, timezone: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setTimezoneOverride.""" + params = { + "timezone": timezone, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setTimezoneOverride", params) + result = self._conn.execute(cmd) + return result + + def set_touch_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setTouchOverride.""" params = { "contexts": contexts, diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index e9c3f8345f05d..5dbe71dbd3886 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,15 +6,16 @@ # WebDriver BiDi module: input from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - +from dataclasses import dataclass from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class PointerType: """PointerType.""" @@ -50,7 +34,7 @@ class Origin: @dataclass class ElementOrigin: - """ElementOrigin type definition.""" + """ElementOrigin.""" type: str = field(default="element", init=False) element: Any | None = None @@ -58,59 +42,59 @@ class ElementOrigin: @dataclass class PerformActionsParameters: - """PerformActionsParameters type definition.""" + """PerformActionsParameters.""" context: Any | None = None - actions: list[Any | None] | None = field(default_factory=list) + actions: list[Any | None] | None = None @dataclass class NoneSourceActions: - """NoneSourceActions type definition.""" + """NoneSourceActions.""" type: str = field(default="none", init=False) id: str | None = None - actions: list[Any | None] | None = field(default_factory=list) + actions: list[Any | None] | None = None @dataclass class KeySourceActions: - """KeySourceActions type definition.""" + """KeySourceActions.""" type: str = field(default="key", init=False) id: str | None = None - actions: list[Any | None] | None = field(default_factory=list) + actions: list[Any | None] | None = None @dataclass class PointerSourceActions: - """PointerSourceActions type definition.""" + """PointerSourceActions.""" type: str = field(default="pointer", init=False) id: str | None = None parameters: Any | None = None - actions: list[Any | None] | None = field(default_factory=list) + actions: list[Any | None] | None = None @dataclass class PointerParameters: - """PointerParameters type definition.""" + """PointerParameters.""" pointer_type: Any | None = None @dataclass class WheelSourceActions: - """WheelSourceActions type definition.""" + """WheelSourceActions.""" type: str = field(default="wheel", init=False) id: str | None = None - actions: list[Any | None] | None = field(default_factory=list) + actions: list[Any | None] | None = None @dataclass class PauseAction: - """PauseAction type definition.""" + """PauseAction.""" type: str = field(default="pause", init=False) duration: Any | None = None @@ -118,7 +102,7 @@ class PauseAction: @dataclass class KeyDownAction: - """KeyDownAction type definition.""" + """KeyDownAction.""" type: str = field(default="keyDown", init=False) value: str | None = None @@ -126,7 +110,7 @@ class KeyDownAction: @dataclass class KeyUpAction: - """KeyUpAction type definition.""" + """KeyUpAction.""" type: str = field(default="keyUp", init=False) value: str | None = None @@ -134,7 +118,7 @@ class KeyUpAction: @dataclass class PointerUpAction: - """PointerUpAction type definition.""" + """PointerUpAction.""" type: str = field(default="pointerUp", init=False) button: Any | None = None @@ -142,7 +126,7 @@ class PointerUpAction: @dataclass class WheelScrollAction: - """WheelScrollAction type definition.""" + """WheelScrollAction.""" type: str = field(default="scroll", init=False) x: Any | None = None @@ -155,7 +139,7 @@ class WheelScrollAction: @dataclass class PointerCommonProperties: - """PointerCommonProperties type definition.""" + """PointerCommonProperties.""" width: Any | None = None height: Any | None = None @@ -168,18 +152,18 @@ class PointerCommonProperties: @dataclass class ReleaseActionsParameters: - """ReleaseActionsParameters type definition.""" + """ReleaseActionsParameters.""" context: Any | None = None @dataclass class SetFilesParameters: - """SetFilesParameters type definition.""" + """SetFilesParameters.""" context: Any | None = None element: Any | None = None - files: list[Any | None] | None = field(default_factory=list) + files: list[Any | None] | None = None @dataclass @@ -191,7 +175,7 @@ class FileDialogInfo: multiple: bool | None = None @classmethod - def from_json(cls, params: dict) -> FileDialogInfo: + def from_json(cls, params: dict) -> "FileDialogInfo": """Deserialize event params into FileDialogInfo.""" return cls( context=params.get("context"), @@ -384,7 +368,7 @@ def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - def perform_actions(self, context: Any | None = None, actions: list[Any] | None = None): + def perform_actions(self, context: Any | None = None, actions: List[Any] | None = None): """Execute input.performActions.""" params = { "context": context, @@ -405,7 +389,7 @@ def release_actions(self, context: Any | None = None): result = self._conn.execute(cmd) return result - def set_files(self, context: Any | None = None, element: Any | None = None, files: list[Any] | None = None): + def set_files(self, context: Any | None = None, element: Any | None = None, files: List[Any] | None = None): """Execute input.setFiles.""" params = { "context": context, @@ -470,10 +454,5 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Input.EVENT_CONFIGS = { - "file_dialog_opened": ( - EventConfig("file_dialog_opened", "input.fileDialogOpened", - _globals.get("FileDialogOpened", dict)) - if _globals.get("FileDialogOpened") - else EventConfig("file_dialog_opened", "input.fileDialogOpened", dict) - ), + "file_dialog_opened": (EventConfig("file_dialog_opened", "input.fileDialogOpened", _globals.get("FileDialogOpened", dict)) if _globals.get("FileDialogOpened") else EventConfig("file_dialog_opened", "input.fileDialogOpened", dict)), } diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 94f511d7185f8..7aa7fbf7a3171 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,11 +6,14 @@ # WebDriver BiDi module: log from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass import threading from collections.abc import Callable from dataclasses import dataclass -from typing import Any - from selenium.webdriver.common.bidi.session import Session @@ -44,7 +30,7 @@ class Level: @dataclass class BaseLogEntry: - """BaseLogEntry type definition.""" + """BaseLogEntry.""" level: Any | None = None source: Any | None = None @@ -55,7 +41,7 @@ class BaseLogEntry: @dataclass class GenericLogEntry: - """GenericLogEntry type definition.""" + """GenericLogEntry.""" type: str | None = None @@ -74,7 +60,7 @@ class ConsoleLogEntry: stack_trace: Any | None = None @classmethod - def from_json(cls, params: dict) -> ConsoleLogEntry: + def from_json(cls, params: dict) -> "ConsoleLogEntry": """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -99,7 +85,7 @@ class JavascriptLogEntry: stacktrace: Any | None = None @classmethod - def from_json(cls, params: dict) -> JavascriptLogEntry: + def from_json(cls, params: dict) -> "JavascriptLogEntry": """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -312,10 +298,5 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Log.EVENT_CONFIGS = { - "entry_added": ( - EventConfig("entry_added", "log.entryAdded", - _globals.get("EntryAdded", dict)) - if _globals.get("EntryAdded") - else EventConfig("entry_added", "log.entryAdded", dict) - ), + "entry_added": (EventConfig("entry_added", "log.entryAdded", _globals.get("EntryAdded", dict)) if _globals.get("EntryAdded") else EventConfig("entry_added", "log.entryAdded", dict)), } diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 9dc5fb94d8488..2290c9fec12d3 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,15 +6,16 @@ # WebDriver BiDi module: network from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - +from dataclasses import dataclass from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class SameSite: """SameSite.""" @@ -66,7 +50,7 @@ class ContinueWithAuthNoCredentials: @dataclass class AuthChallenge: - """AuthChallenge type definition.""" + """AuthChallenge.""" scheme: str | None = None realm: str | None = None @@ -74,7 +58,7 @@ class AuthChallenge: @dataclass class AuthCredentials: - """AuthCredentials type definition.""" + """AuthCredentials.""" type: str = field(default="password", init=False) username: str | None = None @@ -83,7 +67,7 @@ class AuthCredentials: @dataclass class BaseParameters: - """BaseParameters type definition.""" + """BaseParameters.""" context: Any | None = None is_blocked: bool | None = None @@ -91,12 +75,12 @@ class BaseParameters: redirect_count: Any | None = None request: Any | None = None timestamp: Any | None = None - intercepts: list[Any | None] | None = field(default_factory=list) + intercepts: list[Any | None] | None = None @dataclass class StringValue: - """StringValue type definition.""" + """StringValue.""" type: str = field(default="string", init=False) value: str | None = None @@ -104,7 +88,7 @@ class StringValue: @dataclass class Base64Value: - """Base64Value type definition.""" + """Base64Value.""" type: str = field(default="base64", init=False) value: str | None = None @@ -112,7 +96,7 @@ class Base64Value: @dataclass class Cookie: - """Cookie type definition.""" + """Cookie.""" name: str | None = None value: Any | None = None @@ -127,7 +111,7 @@ class Cookie: @dataclass class CookieHeader: - """CookieHeader type definition.""" + """CookieHeader.""" name: str | None = None value: Any | None = None @@ -135,7 +119,7 @@ class CookieHeader: @dataclass class FetchTimingInfo: - """FetchTimingInfo type definition.""" + """FetchTimingInfo.""" time_origin: Any | None = None request_time: Any | None = None @@ -154,7 +138,7 @@ class FetchTimingInfo: @dataclass class Header: - """Header type definition.""" + """Header.""" name: str | None = None value: Any | None = None @@ -162,7 +146,7 @@ class Header: @dataclass class Initiator: - """Initiator type definition.""" + """Initiator.""" column_number: Any | None = None line_number: Any | None = None @@ -173,32 +157,32 @@ class Initiator: @dataclass class ResponseContent: - """ResponseContent type definition.""" + """ResponseContent.""" size: Any | None = None @dataclass class ResponseData: - """ResponseData type definition.""" + """ResponseData.""" url: str | None = None protocol: str | None = None status: Any | None = None status_text: str | None = None from_cache: bool | None = None - headers: list[Any | None] | None = field(default_factory=list) + headers: list[Any | None] | None = None mime_type: str | None = None bytes_received: Any | None = None headers_size: Any | None = None body_size: Any | None = None content: Any | None = None - auth_challenges: list[Any | None] | None = field(default_factory=list) + auth_challenges: list[Any | None] | None = None @dataclass class SetCookieHeader: - """SetCookieHeader type definition.""" + """SetCookieHeader.""" name: str | None = None value: Any | None = None @@ -213,7 +197,7 @@ class SetCookieHeader: @dataclass class UrlPatternPattern: - """UrlPatternPattern type definition.""" + """UrlPatternPattern.""" type: str = field(default="pattern", init=False) protocol: str | None = None @@ -225,7 +209,7 @@ class UrlPatternPattern: @dataclass class UrlPatternString: - """UrlPatternString type definition.""" + """UrlPatternString.""" type: str = field(default="string", init=False) pattern: str | None = None @@ -233,68 +217,68 @@ class UrlPatternString: @dataclass class AddDataCollectorParameters: - """AddDataCollectorParameters type definition.""" + """AddDataCollectorParameters.""" - data_types: list[Any | None] | None = field(default_factory=list) + data_types: list[Any | None] | None = None max_encoded_data_size: Any | None = None collector_type: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class AddDataCollectorResult: - """AddDataCollectorResult type definition.""" + """AddDataCollectorResult.""" collector: Any | None = None @dataclass class AddInterceptParameters: - """AddInterceptParameters type definition.""" + """AddInterceptParameters.""" - phases: list[Any | None] | None = field(default_factory=list) - contexts: list[Any | None] | None = field(default_factory=list) - url_patterns: list[Any | None] | None = field(default_factory=list) + phases: list[Any | None] | None = None + contexts: list[Any | None] | None = None + url_patterns: list[Any | None] | None = None @dataclass class AddInterceptResult: - """AddInterceptResult type definition.""" + """AddInterceptResult.""" intercept: Any | None = None @dataclass class ContinueResponseParameters: - """ContinueResponseParameters type definition.""" + """ContinueResponseParameters.""" request: Any | None = None - cookies: list[Any | None] | None = field(default_factory=list) + cookies: list[Any | None] | None = None credentials: Any | None = None - headers: list[Any | None] | None = field(default_factory=list) + headers: list[Any | None] | None = None reason_phrase: str | None = None status_code: Any | None = None @dataclass class ContinueWithAuthParameters: - """ContinueWithAuthParameters type definition.""" + """ContinueWithAuthParameters.""" request: Any | None = None @dataclass class ContinueWithAuthCredentials: - """ContinueWithAuthCredentials type definition.""" + """ContinueWithAuthCredentials.""" action: str = field(default="provideCredentials", init=False) credentials: Any | None = None @dataclass -class DisownDataParameters: - """DisownDataParameters type definition.""" +class disownDataParameters: + """disownDataParameters.""" data_type: Any | None = None collector: Any | None = None @@ -303,14 +287,14 @@ class DisownDataParameters: @dataclass class FailRequestParameters: - """FailRequestParameters type definition.""" + """FailRequestParameters.""" request: Any | None = None @dataclass class GetDataParameters: - """GetDataParameters type definition.""" + """GetDataParameters.""" data_type: Any | None = None collector: Any | None = None @@ -320,85 +304,57 @@ class GetDataParameters: @dataclass class GetDataResult: - """GetDataResult type definition.""" + """GetDataResult.""" bytes: Any | None = None @dataclass class ProvideResponseParameters: - """ProvideResponseParameters type definition.""" + """ProvideResponseParameters.""" request: Any | None = None body: Any | None = None - cookies: list[Any | None] | None = field(default_factory=list) - headers: list[Any | None] | None = field(default_factory=list) + cookies: list[Any | None] | None = None + headers: list[Any | None] | None = None reason_phrase: str | None = None status_code: Any | None = None @dataclass class RemoveDataCollectorParameters: - """RemoveDataCollectorParameters type definition.""" + """RemoveDataCollectorParameters.""" collector: Any | None = None @dataclass class RemoveInterceptParameters: - """RemoveInterceptParameters type definition.""" + """RemoveInterceptParameters.""" intercept: Any | None = None @dataclass class SetCacheBehaviorParameters: - """SetCacheBehaviorParameters type definition.""" + """SetCacheBehaviorParameters.""" cache_behavior: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None @dataclass class SetExtraHeadersParameters: - """SetExtraHeadersParameters type definition.""" - - headers: list[Any | None] | None = field(default_factory=list) - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) - - -@dataclass -class AuthRequiredParameters: - """AuthRequiredParameters type definition.""" - - response: Any | None = None - - -@dataclass -class BeforeRequestSentParameters: - """BeforeRequestSentParameters type definition.""" - - initiator: Any | None = None - - -@dataclass -class FetchErrorParameters: - """FetchErrorParameters type definition.""" - - error_text: str | None = None + """SetExtraHeadersParameters.""" - -@dataclass -class ResponseCompletedParameters: - """ResponseCompletedParameters type definition.""" - - response: Any | None = None + headers: list[Any | None] | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class ResponseStartedParameters: - """ResponseStartedParameters type definition.""" + """ResponseStartedParameters.""" response: Any | None = None @@ -441,10 +397,6 @@ def continue_request(self, **kwargs): # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { "auth_required": "network.authRequired", - "before_request_sent": "network.beforeRequestSent", - "fetch_error": "network.fetchError", - "response_completed": "network.responseCompleted", - "response_started": "network.responseStarted", "before_request": "network.beforeRequestSent", } @@ -611,14 +563,7 @@ def __init__(self, conn) -> None: self.intercepts = [] self._handler_intercepts: dict = {} - def add_data_collector( - self, - data_types: list[Any] | None = None, - max_encoded_data_size: Any | None = None, - collector_type: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_data_size: Any | None = None, collector_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute network.addDataCollector.""" params = { "dataTypes": data_types, @@ -632,12 +577,7 @@ def add_data_collector( result = self._conn.execute(cmd) return result - def add_intercept( - self, - phases: list[Any] | None = None, - contexts: list[Any] | None = None, - url_patterns: list[Any] | None = None, - ): + def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | None = None, url_patterns: List[Any] | None = None): """Execute network.addIntercept.""" params = { "phases": phases, @@ -649,15 +589,7 @@ def add_intercept( result = self._conn.execute(cmd) return result - def continue_request( - self, - request: Any | None = None, - body: Any | None = None, - cookies: list[Any] | None = None, - headers: list[Any] | None = None, - method: Any | None = None, - url: Any | None = None, - ): + def continue_request(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, method: Any | None = None, url: Any | None = None): """Execute network.continueRequest.""" params = { "request": request, @@ -672,15 +604,7 @@ def continue_request( result = self._conn.execute(cmd) return result - def continue_response( - self, - request: Any | None = None, - cookies: list[Any] | None = None, - credentials: Any | None = None, - headers: list[Any] | None = None, - reason_phrase: Any | None = None, - status_code: Any | None = None, - ): + def continue_response(self, request: Any | None = None, cookies: List[Any] | None = None, credentials: Any | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): """Execute network.continueResponse.""" params = { "request": request, @@ -727,13 +651,7 @@ def fail_request(self, request: Any | None = None): result = self._conn.execute(cmd) return result - def get_data( - self, - data_type: Any | None = None, - collector: Any | None = None, - disown: bool | None = None, - request: Any | None = None, - ): + def get_data(self, data_type: Any | None = None, collector: Any | None = None, disown: bool | None = None, request: Any | None = None): """Execute network.getData.""" params = { "dataType": data_type, @@ -746,15 +664,7 @@ def get_data( result = self._conn.execute(cmd) return result - def provide_response( - self, - request: Any | None = None, - body: Any | None = None, - cookies: list[Any] | None = None, - headers: list[Any] | None = None, - reason_phrase: Any | None = None, - status_code: Any | None = None, - ): + def provide_response(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): """Execute network.provideResponse.""" params = { "request": request, @@ -789,7 +699,7 @@ def remove_intercept(self, intercept: Any | None = None): result = self._conn.execute(cmd) return result - def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: list[Any] | None = None): + def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[Any] | None = None): """Execute network.setCacheBehavior.""" params = { "cacheBehavior": cache_behavior, @@ -800,12 +710,7 @@ def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: list[A result = self._conn.execute(cmd) return result - def set_extra_headers( - self, - headers: list[Any] | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def set_extra_headers(self, headers: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute network.setExtraHeaders.""" params = { "headers": headers, @@ -817,6 +722,52 @@ def set_extra_headers( result = self._conn.execute(cmd) return result + def before_request_sent(self, initiator: Any | None = None, method: Any | None = None, params: Any | None = None): + """Execute network.beforeRequestSent.""" + params = { + "initiator": initiator, + "method": method, + "params": params, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.beforeRequestSent", params) + result = self._conn.execute(cmd) + return result + + def fetch_error(self, error_text: Any | None = None, method: Any | None = None, params: Any | None = None): + """Execute network.fetchError.""" + params = { + "errorText": error_text, + "method": method, + "params": params, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.fetchError", params) + result = self._conn.execute(cmd) + return result + + def response_completed(self, response: Any | None = None, method: Any | None = None, params: Any | None = None): + """Execute network.responseCompleted.""" + params = { + "response": response, + "method": method, + "params": params, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.responseCompleted", params) + result = self._conn.execute(cmd) + return result + + def response_started(self, response: Any | None = None): + """Execute network.responseStarted.""" + params = { + "response": response, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.responseStarted", params) + result = self._conn.execute(cmd) + return result + def _add_intercept(self, phases=None, url_patterns=None): """Add a low-level network intercept. @@ -971,51 +922,10 @@ def clear_event_handlers(self) -> None: # Event: network.authRequired AuthRequired = globals().get('AuthRequiredParameters', dict) # Fallback to dict if type not defined -# Event: network.beforeRequestSent -BeforeRequestSent = globals().get('BeforeRequestSentParameters', dict) # Fallback to dict if type not defined - -# Event: network.fetchError -FetchError = globals().get('FetchErrorParameters', dict) # Fallback to dict if type not defined - -# Event: network.responseCompleted -ResponseCompleted = globals().get('ResponseCompletedParameters', dict) # Fallback to dict if type not defined - -# Event: network.responseStarted -ResponseStarted = globals().get('ResponseStartedParameters', dict) # Fallback to dict if type not defined - # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Network.EVENT_CONFIGS = { - "auth_required": ( - EventConfig("auth_required", "network.authRequired", - _globals.get("AuthRequired", dict)) - if _globals.get("AuthRequired") - else EventConfig("auth_required", "network.authRequired", dict) - ), - "before_request_sent": ( - EventConfig("before_request_sent", "network.beforeRequestSent", - _globals.get("BeforeRequestSent", dict)) - if _globals.get("BeforeRequestSent") - else EventConfig("before_request_sent", "network.beforeRequestSent", dict) - ), - "fetch_error": ( - EventConfig("fetch_error", "network.fetchError", - _globals.get("FetchError", dict)) - if _globals.get("FetchError") - else EventConfig("fetch_error", "network.fetchError", dict) - ), - "response_completed": ( - EventConfig("response_completed", "network.responseCompleted", - _globals.get("ResponseCompleted", dict)) - if _globals.get("ResponseCompleted") - else EventConfig("response_completed", "network.responseCompleted", dict) - ), - "response_started": ( - EventConfig("response_started", "network.responseStarted", - _globals.get("ResponseStarted", dict)) - if _globals.get("ResponseStarted") - else EventConfig("response_started", "network.responseStarted", dict) - ), + "auth_required": (EventConfig("auth_required", "network.authRequired", _globals.get("AuthRequired", dict)) if _globals.get("AuthRequired") else EventConfig("auth_required", "network.authRequired", dict)), "before_request": EventConfig("before_request", "network.beforeRequestSent", _globals.get("dict", dict)), } diff --git a/py/selenium/webdriver/common/bidi/permissions.py b/py/selenium/webdriver/common/bidi/permissions.py index 6dd138da17309..f00e765c62e3b 100644 --- a/py/selenium/webdriver/common/bidi/permissions.py +++ b/py/selenium/webdriver/common/bidi/permissions.py @@ -20,7 +20,7 @@ from __future__ import annotations from enum import Enum -from typing import Any +from typing import Any, Optional, Union from .common import command_builder @@ -63,10 +63,10 @@ def __init__(self, websocket_connection: Any) -> None: def set_permission( self, - descriptor: PermissionDescriptor | str, - state: PermissionState | str, - origin: str | None = None, - user_context: str | None = None, + descriptor: Union[PermissionDescriptor, str], + state: Union[PermissionState, str], + origin: Optional[str] = None, + user_context: Optional[str] = None, ) -> None: """Set a permission for a given origin. diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 0b2ec04101933..c7bfcb3774dff 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,15 +6,16 @@ # WebDriver BiDi module: script from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - +from dataclasses import dataclass from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class SpecialNumber: """SpecialNumber.""" @@ -64,7 +48,7 @@ class ResultOwnership: @dataclass class ChannelValue: - """ChannelValue type definition.""" + """ChannelValue.""" type: str = field(default="channel", init=False) value: Any | None = None @@ -72,7 +56,7 @@ class ChannelValue: @dataclass class ChannelProperties: - """ChannelProperties type definition.""" + """ChannelProperties.""" channel: Any | None = None serialization_options: Any | None = None @@ -81,7 +65,7 @@ class ChannelProperties: @dataclass class EvaluateResultSuccess: - """EvaluateResultSuccess type definition.""" + """EvaluateResultSuccess.""" type: str = field(default="success", init=False) result: Any | None = None @@ -90,7 +74,7 @@ class EvaluateResultSuccess: @dataclass class EvaluateResultException: - """EvaluateResultException type definition.""" + """EvaluateResultException.""" type: str = field(default="exception", init=False) exception_details: Any | None = None @@ -99,7 +83,7 @@ class EvaluateResultException: @dataclass class ExceptionDetails: - """ExceptionDetails type definition.""" + """ExceptionDetails.""" column_number: Any | None = None exception: Any | None = None @@ -110,7 +94,7 @@ class ExceptionDetails: @dataclass class ArrayLocalValue: - """ArrayLocalValue type definition.""" + """ArrayLocalValue.""" type: str = field(default="array", init=False) value: Any | None = None @@ -118,7 +102,7 @@ class ArrayLocalValue: @dataclass class DateLocalValue: - """DateLocalValue type definition.""" + """DateLocalValue.""" type: str = field(default="date", init=False) value: str | None = None @@ -126,7 +110,7 @@ class DateLocalValue: @dataclass class MapLocalValue: - """MapLocalValue type definition.""" + """MapLocalValue.""" type: str = field(default="map", init=False) value: Any | None = None @@ -134,7 +118,7 @@ class MapLocalValue: @dataclass class ObjectLocalValue: - """ObjectLocalValue type definition.""" + """ObjectLocalValue.""" type: str = field(default="object", init=False) value: Any | None = None @@ -142,7 +126,7 @@ class ObjectLocalValue: @dataclass class RegExpValue: - """RegExpValue type definition.""" + """RegExpValue.""" pattern: str | None = None flags: str | None = None @@ -150,7 +134,7 @@ class RegExpValue: @dataclass class RegExpLocalValue: - """RegExpLocalValue type definition.""" + """RegExpLocalValue.""" type: str = field(default="regexp", init=False) value: Any | None = None @@ -158,7 +142,7 @@ class RegExpLocalValue: @dataclass class SetLocalValue: - """SetLocalValue type definition.""" + """SetLocalValue.""" type: str = field(default="set", init=False) value: Any | None = None @@ -166,21 +150,21 @@ class SetLocalValue: @dataclass class UndefinedValue: - """UndefinedValue type definition.""" + """UndefinedValue.""" type: str = field(default="undefined", init=False) @dataclass class NullValue: - """NullValue type definition.""" + """NullValue.""" type: str = field(default="null", init=False) @dataclass class StringValue: - """StringValue type definition.""" + """StringValue.""" type: str = field(default="string", init=False) value: str | None = None @@ -188,7 +172,7 @@ class StringValue: @dataclass class NumberValue: - """NumberValue type definition.""" + """NumberValue.""" type: str = field(default="number", init=False) value: Any | None = None @@ -196,7 +180,7 @@ class NumberValue: @dataclass class BooleanValue: - """BooleanValue type definition.""" + """BooleanValue.""" type: str = field(default="boolean", init=False) value: bool | None = None @@ -204,7 +188,7 @@ class BooleanValue: @dataclass class BigIntValue: - """BigIntValue type definition.""" + """BigIntValue.""" type: str = field(default="bigint", init=False) value: str | None = None @@ -212,7 +196,7 @@ class BigIntValue: @dataclass class BaseRealmInfo: - """BaseRealmInfo type definition.""" + """BaseRealmInfo.""" realm: Any | None = None origin: str | None = None @@ -220,7 +204,7 @@ class BaseRealmInfo: @dataclass class WindowRealmInfo: - """WindowRealmInfo type definition.""" + """WindowRealmInfo.""" type: str = field(default="window", init=False) context: Any | None = None @@ -229,57 +213,57 @@ class WindowRealmInfo: @dataclass class DedicatedWorkerRealmInfo: - """DedicatedWorkerRealmInfo type definition.""" + """DedicatedWorkerRealmInfo.""" type: str = field(default="dedicated-worker", init=False) - owners: list[Any | None] | None = field(default_factory=list) + owners: list[Any | None] | None = None @dataclass class SharedWorkerRealmInfo: - """SharedWorkerRealmInfo type definition.""" + """SharedWorkerRealmInfo.""" type: str = field(default="shared-worker", init=False) @dataclass class ServiceWorkerRealmInfo: - """ServiceWorkerRealmInfo type definition.""" + """ServiceWorkerRealmInfo.""" type: str = field(default="service-worker", init=False) @dataclass class WorkerRealmInfo: - """WorkerRealmInfo type definition.""" + """WorkerRealmInfo.""" type: str = field(default="worker", init=False) @dataclass class PaintWorkletRealmInfo: - """PaintWorkletRealmInfo type definition.""" + """PaintWorkletRealmInfo.""" type: str = field(default="paint-worklet", init=False) @dataclass class AudioWorkletRealmInfo: - """AudioWorkletRealmInfo type definition.""" + """AudioWorkletRealmInfo.""" type: str = field(default="audio-worklet", init=False) @dataclass class WorkletRealmInfo: - """WorkletRealmInfo type definition.""" + """WorkletRealmInfo.""" type: str = field(default="worklet", init=False) @dataclass class SharedReference: - """SharedReference type definition.""" + """SharedReference.""" shared_id: Any | None = None handle: Any | None = None @@ -287,7 +271,7 @@ class SharedReference: @dataclass class RemoteObjectReference: - """RemoteObjectReference type definition.""" + """RemoteObjectReference.""" handle: Any | None = None shared_id: Any | None = None @@ -295,7 +279,7 @@ class RemoteObjectReference: @dataclass class SymbolRemoteValue: - """SymbolRemoteValue type definition.""" + """SymbolRemoteValue.""" type: str = field(default="symbol", init=False) handle: Any | None = None @@ -304,7 +288,7 @@ class SymbolRemoteValue: @dataclass class ArrayRemoteValue: - """ArrayRemoteValue type definition.""" + """ArrayRemoteValue.""" type: str = field(default="array", init=False) handle: Any | None = None @@ -314,7 +298,7 @@ class ArrayRemoteValue: @dataclass class ObjectRemoteValue: - """ObjectRemoteValue type definition.""" + """ObjectRemoteValue.""" type: str = field(default="object", init=False) handle: Any | None = None @@ -324,7 +308,7 @@ class ObjectRemoteValue: @dataclass class FunctionRemoteValue: - """FunctionRemoteValue type definition.""" + """FunctionRemoteValue.""" type: str = field(default="function", init=False) handle: Any | None = None @@ -333,7 +317,7 @@ class FunctionRemoteValue: @dataclass class RegExpRemoteValue: - """RegExpRemoteValue type definition.""" + """RegExpRemoteValue.""" handle: Any | None = None internal_id: Any | None = None @@ -341,7 +325,7 @@ class RegExpRemoteValue: @dataclass class DateRemoteValue: - """DateRemoteValue type definition.""" + """DateRemoteValue.""" handle: Any | None = None internal_id: Any | None = None @@ -349,7 +333,7 @@ class DateRemoteValue: @dataclass class MapRemoteValue: - """MapRemoteValue type definition.""" + """MapRemoteValue.""" type: str = field(default="map", init=False) handle: Any | None = None @@ -359,7 +343,7 @@ class MapRemoteValue: @dataclass class SetRemoteValue: - """SetRemoteValue type definition.""" + """SetRemoteValue.""" type: str = field(default="set", init=False) handle: Any | None = None @@ -369,7 +353,7 @@ class SetRemoteValue: @dataclass class WeakMapRemoteValue: - """WeakMapRemoteValue type definition.""" + """WeakMapRemoteValue.""" type: str = field(default="weakmap", init=False) handle: Any | None = None @@ -378,7 +362,7 @@ class WeakMapRemoteValue: @dataclass class WeakSetRemoteValue: - """WeakSetRemoteValue type definition.""" + """WeakSetRemoteValue.""" type: str = field(default="weakset", init=False) handle: Any | None = None @@ -387,7 +371,7 @@ class WeakSetRemoteValue: @dataclass class GeneratorRemoteValue: - """GeneratorRemoteValue type definition.""" + """GeneratorRemoteValue.""" type: str = field(default="generator", init=False) handle: Any | None = None @@ -396,7 +380,7 @@ class GeneratorRemoteValue: @dataclass class ErrorRemoteValue: - """ErrorRemoteValue type definition.""" + """ErrorRemoteValue.""" type: str = field(default="error", init=False) handle: Any | None = None @@ -405,7 +389,7 @@ class ErrorRemoteValue: @dataclass class ProxyRemoteValue: - """ProxyRemoteValue type definition.""" + """ProxyRemoteValue.""" type: str = field(default="proxy", init=False) handle: Any | None = None @@ -414,7 +398,7 @@ class ProxyRemoteValue: @dataclass class PromiseRemoteValue: - """PromiseRemoteValue type definition.""" + """PromiseRemoteValue.""" type: str = field(default="promise", init=False) handle: Any | None = None @@ -423,7 +407,7 @@ class PromiseRemoteValue: @dataclass class TypedArrayRemoteValue: - """TypedArrayRemoteValue type definition.""" + """TypedArrayRemoteValue.""" type: str = field(default="typedarray", init=False) handle: Any | None = None @@ -432,7 +416,7 @@ class TypedArrayRemoteValue: @dataclass class ArrayBufferRemoteValue: - """ArrayBufferRemoteValue type definition.""" + """ArrayBufferRemoteValue.""" type: str = field(default="arraybuffer", init=False) handle: Any | None = None @@ -441,7 +425,7 @@ class ArrayBufferRemoteValue: @dataclass class NodeListRemoteValue: - """NodeListRemoteValue type definition.""" + """NodeListRemoteValue.""" type: str = field(default="nodelist", init=False) handle: Any | None = None @@ -451,7 +435,7 @@ class NodeListRemoteValue: @dataclass class HTMLCollectionRemoteValue: - """HTMLCollectionRemoteValue type definition.""" + """HTMLCollectionRemoteValue.""" type: str = field(default="htmlcollection", init=False) handle: Any | None = None @@ -461,7 +445,7 @@ class HTMLCollectionRemoteValue: @dataclass class NodeRemoteValue: - """NodeRemoteValue type definition.""" + """NodeRemoteValue.""" type: str = field(default="node", init=False) shared_id: Any | None = None @@ -472,11 +456,11 @@ class NodeRemoteValue: @dataclass class NodeProperties: - """NodeProperties type definition.""" + """NodeProperties.""" node_type: Any | None = None child_node_count: Any | None = None - children: list[Any | None] | None = field(default_factory=list) + children: list[Any | None] | None = None local_name: str | None = None mode: Any | None = None namespace_uri: str | None = None @@ -486,7 +470,7 @@ class NodeProperties: @dataclass class WindowProxyRemoteValue: - """WindowProxyRemoteValue type definition.""" + """WindowProxyRemoteValue.""" type: str = field(default="window", init=False) value: Any | None = None @@ -496,14 +480,14 @@ class WindowProxyRemoteValue: @dataclass class WindowProxyProperties: - """WindowProxyProperties type definition.""" + """WindowProxyProperties.""" context: Any | None = None @dataclass class StackFrame: - """StackFrame type definition.""" + """StackFrame.""" column_number: Any | None = None function_name: str | None = None @@ -513,14 +497,14 @@ class StackFrame: @dataclass class StackTrace: - """StackTrace type definition.""" + """StackTrace.""" - call_frames: list[Any | None] | None = field(default_factory=list) + call_frames: list[Any | None] | None = None @dataclass class Source: - """Source type definition.""" + """Source.""" realm: Any | None = None context: Any | None = None @@ -528,14 +512,14 @@ class Source: @dataclass class RealmTarget: - """RealmTarget type definition.""" + """RealmTarget.""" realm: Any | None = None @dataclass class ContextTarget: - """ContextTarget type definition.""" + """ContextTarget.""" context: Any | None = None sandbox: str | None = None @@ -543,38 +527,38 @@ class ContextTarget: @dataclass class AddPreloadScriptParameters: - """AddPreloadScriptParameters type definition.""" + """AddPreloadScriptParameters.""" function_declaration: str | None = None - arguments: list[Any | None] | None = field(default_factory=list) - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + arguments: list[Any | None] | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None sandbox: str | None = None @dataclass class AddPreloadScriptResult: - """AddPreloadScriptResult type definition.""" + """AddPreloadScriptResult.""" script: Any | None = None @dataclass class DisownParameters: - """DisownParameters type definition.""" + """DisownParameters.""" - handles: list[Any | None] | None = field(default_factory=list) + handles: list[Any | None] | None = None target: Any | None = None @dataclass class CallFunctionParameters: - """CallFunctionParameters type definition.""" + """CallFunctionParameters.""" function_declaration: str | None = None await_promise: bool | None = None target: Any | None = None - arguments: list[Any | None] | None = field(default_factory=list) + arguments: list[Any | None] | None = None result_ownership: Any | None = None serialization_options: Any | None = None this: Any | None = None @@ -583,7 +567,7 @@ class CallFunctionParameters: @dataclass class EvaluateParameters: - """EvaluateParameters type definition.""" + """EvaluateParameters.""" expression: str | None = None target: Any | None = None @@ -595,7 +579,7 @@ class EvaluateParameters: @dataclass class GetRealmsParameters: - """GetRealmsParameters type definition.""" + """GetRealmsParameters.""" context: Any | None = None type: Any | None = None @@ -603,21 +587,21 @@ class GetRealmsParameters: @dataclass class GetRealmsResult: - """GetRealmsResult type definition.""" + """GetRealmsResult.""" - realms: list[Any | None] | None = field(default_factory=list) + realms: list[Any | None] | None = None @dataclass class RemovePreloadScriptParameters: - """RemovePreloadScriptParameters type definition.""" + """RemovePreloadScriptParameters.""" script: Any | None = None @dataclass class MessageParameters: - """MessageParameters type definition.""" + """MessageParameters.""" channel: Any | None = None data: Any | None = None @@ -626,14 +610,13 @@ class MessageParameters: @dataclass class RealmDestroyedParameters: - """RealmDestroyedParameters type definition.""" + """RealmDestroyedParameters.""" realm: Any | None = None # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { - "message": "script.message", "realm_created": "script.realmCreated", "realm_destroyed": "script.realmDestroyed", } @@ -800,14 +783,7 @@ def __init__(self, conn, driver=None) -> None: self._driver = driver self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - def add_preload_script( - self, - function_declaration: Any | None = None, - arguments: list[Any] | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - sandbox: Any | None = None, - ): + def add_preload_script(self, function_declaration: Any | None = None, arguments: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None, sandbox: Any | None = None): """Execute script.addPreloadScript.""" params = { "functionDeclaration": function_declaration, @@ -821,7 +797,7 @@ def add_preload_script( result = self._conn.execute(cmd) return result - def disown(self, handles: list[Any] | None = None, target: Any | None = None): + def disown(self, handles: List[Any] | None = None, target: Any | None = None): """Execute script.disown.""" params = { "handles": handles, @@ -832,17 +808,7 @@ def disown(self, handles: list[Any] | None = None, target: Any | None = None): result = self._conn.execute(cmd) return result - def call_function( - self, - function_declaration: Any | None = None, - await_promise: bool | None = None, - target: Any | None = None, - arguments: list[Any] | None = None, - result_ownership: Any | None = None, - serialization_options: Any | None = None, - this: Any | None = None, - user_activation: bool | None = None, - ): + def call_function(self, function_declaration: Any | None = None, await_promise: bool | None = None, target: Any | None = None, arguments: List[Any] | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, this: Any | None = None, user_activation: bool | None = None): """Execute script.callFunction.""" params = { "functionDeclaration": function_declaration, @@ -859,15 +825,7 @@ def call_function( result = self._conn.execute(cmd) return result - def evaluate( - self, - expression: Any | None = None, - target: Any | None = None, - await_promise: bool | None = None, - result_ownership: Any | None = None, - serialization_options: Any | None = None, - user_activation: bool | None = None, - ): + def evaluate(self, expression: Any | None = None, target: Any | None = None, await_promise: bool | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, user_activation: bool | None = None): """Execute script.evaluate.""" params = { "expression": expression, @@ -903,6 +861,18 @@ def remove_preload_script(self, script: Any | None = None): result = self._conn.execute(cmd) return result + def message(self, channel: Any | None = None, data: Any | None = None, source: Any | None = None): + """Execute script.message.""" + params = { + "channel": channel, + "data": data, + "source": source, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.message", params) + result = self._conn.execute(cmd) + return result + def execute(self, function_declaration: str, *args, context_id: str | None = None) -> Any: """Execute a function declaration in the browser context. @@ -919,9 +889,8 @@ def execute(self, function_declaration: str, *args, context_id: str | None = Non Returns: The inner RemoteValue result dict, or raises WebDriverException on exception. """ - import datetime as _datetime import math as _math - + import datetime as _datetime from selenium.common.exceptions import WebDriverException as _WebDriverException def _serialize_arg(value): @@ -1162,9 +1131,8 @@ def _disown(self, handles, target): def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading - - from selenium.webdriver.common.bidi import log as _log_mod from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod bidi_event = "log.entryAdded" @@ -1304,9 +1272,6 @@ def clear_event_handlers(self) -> None: return self._event_manager.clear_event_handlers() # Event Info Type Aliases -# Event: script.message -Message = globals().get('MessageParameters', dict) # Fallback to dict if type not defined - # Event: script.realmCreated RealmCreated = globals().get('RealmInfo', dict) # Fallback to dict if type not defined @@ -1317,22 +1282,6 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Script.EVENT_CONFIGS = { - "message": ( - EventConfig("message", "script.message", - _globals.get("Message", dict)) - if _globals.get("Message") - else EventConfig("message", "script.message", dict) - ), - "realm_created": ( - EventConfig("realm_created", "script.realmCreated", - _globals.get("RealmCreated", dict)) - if _globals.get("RealmCreated") - else EventConfig("realm_created", "script.realmCreated", dict) - ), - "realm_destroyed": ( - EventConfig("realm_destroyed", "script.realmDestroyed", - _globals.get("RealmDestroyed", dict)) - if _globals.get("RealmDestroyed") - else EventConfig("realm_destroyed", "script.realmDestroyed", dict) - ), + "realm_created": (EventConfig("realm_created", "script.realmCreated", _globals.get("RealmCreated", dict)) if _globals.get("RealmCreated") else EventConfig("realm_created", "script.realmCreated", dict)), + "realm_destroyed": (EventConfig("realm_destroyed", "script.realmDestroyed", _globals.get("RealmDestroyed", dict)) if _globals.get("RealmDestroyed") else EventConfig("realm_destroyed", "script.realmDestroyed", dict)), } diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index 771a5327151bf..9b1daaae557fa 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,10 +6,11 @@ # WebDriver BiDi module: session from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any - +from typing import Any, Dict, List, Optional, Union from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass class UserPromptHandlerType: @@ -39,15 +23,15 @@ class UserPromptHandlerType: @dataclass class CapabilitiesRequest: - """CapabilitiesRequest type definition.""" + """CapabilitiesRequest.""" always_match: Any | None = None - first_match: list[Any | None] | None = field(default_factory=list) + first_match: list[Any | None] | None = None @dataclass class CapabilityRequest: - """CapabilityRequest type definition.""" + """CapabilityRequest.""" accept_insecure_certs: bool | None = None browser_name: str | None = None @@ -59,31 +43,31 @@ class CapabilityRequest: @dataclass class AutodetectProxyConfiguration: - """AutodetectProxyConfiguration type definition.""" + """AutodetectProxyConfiguration.""" proxy_type: str = field(default="autodetect", init=False) @dataclass class DirectProxyConfiguration: - """DirectProxyConfiguration type definition.""" + """DirectProxyConfiguration.""" proxy_type: str = field(default="direct", init=False) @dataclass class ManualProxyConfiguration: - """ManualProxyConfiguration type definition.""" + """ManualProxyConfiguration.""" proxy_type: str = field(default="manual", init=False) http_proxy: str | None = None ssl_proxy: str | None = None - no_proxy: list[Any | None] | None = field(default_factory=list) + no_proxy: list[Any | None] | None = None @dataclass class SocksProxyConfiguration: - """SocksProxyConfiguration type definition.""" + """SocksProxyConfiguration.""" socks_proxy: str | None = None socks_version: Any | None = None @@ -91,7 +75,7 @@ class SocksProxyConfiguration: @dataclass class PacProxyConfiguration: - """PacProxyConfiguration type definition.""" + """PacProxyConfiguration.""" proxy_type: str = field(default="pac", init=False) proxy_autoconfig_url: str | None = None @@ -99,37 +83,37 @@ class PacProxyConfiguration: @dataclass class SystemProxyConfiguration: - """SystemProxyConfiguration type definition.""" + """SystemProxyConfiguration.""" proxy_type: str = field(default="system", init=False) @dataclass class SubscribeParameters: - """SubscribeParameters type definition.""" + """SubscribeParameters.""" - events: list[str | None] | None = field(default_factory=list) - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + events: list[str | None] | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class UnsubscribeByIDRequest: - """UnsubscribeByIDRequest type definition.""" + """UnsubscribeByIDRequest.""" - subscriptions: list[Any | None] | None = field(default_factory=list) + subscriptions: list[Any | None] | None = None @dataclass class UnsubscribeByAttributesRequest: - """UnsubscribeByAttributesRequest type definition.""" + """UnsubscribeByAttributesRequest.""" - events: list[str | None] | None = field(default_factory=list) + events: list[str | None] | None = None @dataclass class StatusResult: - """StatusResult type definition.""" + """StatusResult.""" ready: bool | None = None message: str | None = None @@ -137,14 +121,14 @@ class StatusResult: @dataclass class NewParameters: - """NewParameters type definition.""" + """NewParameters.""" capabilities: Any | None = None @dataclass class NewResult: - """NewResult type definition.""" + """NewResult.""" session_id: str | None = None accept_insecure_certs: bool | None = None @@ -160,7 +144,7 @@ class NewResult: @dataclass class SubscribeResult: - """SubscribeResult type definition.""" + """SubscribeResult.""" subscription: Any | None = None @@ -227,12 +211,7 @@ def end(self): result = self._conn.execute(cmd) return result - def subscribe( - self, - events: list[Any] | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def subscribe(self, events: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute session.subscribe.""" params = { "events": events, @@ -244,7 +223,7 @@ def subscribe( result = self._conn.execute(cmd) return result - def unsubscribe(self, events: list[Any] | None = None, subscriptions: list[Any] | None = None): + def unsubscribe(self, events: List[Any] | None = None, subscriptions: List[Any] | None = None): """Execute session.unsubscribe.""" params = { "events": events, diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 7623381706040..7e4c9c6dee459 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,15 +6,16 @@ # WebDriver BiDi module: storage from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any - +from typing import Any, Dict, List, Optional, Union from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass @dataclass class PartitionKey: - """PartitionKey type definition.""" + """PartitionKey.""" user_context: str | None = None source_origin: str | None = None @@ -39,7 +23,7 @@ class PartitionKey: @dataclass class GetCookiesParameters: - """GetCookiesParameters type definition.""" + """GetCookiesParameters.""" filter: Any | None = None partition: Any | None = None @@ -47,15 +31,15 @@ class GetCookiesParameters: @dataclass class GetCookiesResult: - """GetCookiesResult type definition.""" + """GetCookiesResult.""" - cookies: list[Any | None] | None = field(default_factory=list) + cookies: list[Any | None] | None = None partition_key: Any | None = None @dataclass class SetCookieParameters: - """SetCookieParameters type definition.""" + """SetCookieParameters.""" cookie: Any | None = None partition: Any | None = None @@ -63,14 +47,14 @@ class SetCookieParameters: @dataclass class SetCookieResult: - """SetCookieResult type definition.""" + """SetCookieResult.""" partition_key: Any | None = None @dataclass class DeleteCookiesParameters: - """DeleteCookiesParameters type definition.""" + """DeleteCookiesParameters.""" filter: Any | None = None partition: Any | None = None @@ -78,7 +62,7 @@ class DeleteCookiesParameters: @dataclass class DeleteCookiesResult: - """DeleteCookiesResult type definition.""" + """DeleteCookiesResult.""" partition_key: Any | None = None @@ -123,7 +107,7 @@ class StorageCookie: expiry: Any | None = None @classmethod - def from_bidi_dict(cls, raw: dict) -> StorageCookie: + def from_bidi_dict(cls, raw: dict) -> "StorageCookie": """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): @@ -251,6 +235,39 @@ class Storage: def __init__(self, conn) -> None: self._conn = conn + def get_cookies(self, filter: Any | None = None, partition: Any | None = None): + """Execute storage.getCookies.""" + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.getCookies", params) + result = self._conn.execute(cmd) + return result + + def set_cookie(self, cookie: Any | None = None, partition: Any | None = None): + """Execute storage.setCookie.""" + params = { + "cookie": cookie, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.setCookie", params) + result = self._conn.execute(cmd) + return result + + def delete_cookies(self, filter: Any | None = None, partition: Any | None = None): + """Execute storage.deleteCookies.""" + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.deleteCookies", params) + result = self._conn.execute(cmd) + return result + def get_cookies(self, filter=None, partition=None): """Execute storage.getCookies and return a GetCookiesResult.""" if filter and hasattr(filter, "to_bidi_dict"): diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 99250afca4c68..98d852512f591 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,22 +6,23 @@ # WebDriver BiDi module: webExtension from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any - +from typing import Any, Dict, List, Optional, Union from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass @dataclass class InstallParameters: - """InstallParameters type definition.""" + """InstallParameters.""" extension_data: Any | None = None @dataclass class ExtensionPath: - """ExtensionPath type definition.""" + """ExtensionPath.""" type: str = field(default="path", init=False) path: str | None = None @@ -46,7 +30,7 @@ class ExtensionPath: @dataclass class ExtensionArchivePath: - """ExtensionArchivePath type definition.""" + """ExtensionArchivePath.""" type: str = field(default="archivePath", init=False) path: str | None = None @@ -54,7 +38,7 @@ class ExtensionArchivePath: @dataclass class ExtensionBase64Encoded: - """ExtensionBase64Encoded type definition.""" + """ExtensionBase64Encoded.""" type: str = field(default="base64", init=False) value: str | None = None @@ -62,14 +46,14 @@ class ExtensionBase64Encoded: @dataclass class InstallResult: - """InstallResult type definition.""" + """InstallResult.""" extension: Any | None = None @dataclass class UninstallParameters: - """UninstallParameters type definition.""" + """UninstallParameters.""" extension: Any | None = None @@ -104,9 +88,13 @@ def install( ValueError: If more than one, or none, of the arguments is provided. """ provided = [ - k for k, v in { - "path": path, "archive_path": archive_path, "base64_value": base64_value, - }.items() if v is not None + k + for k, v in { + "path": path, + "archive_path": archive_path, + "base64_value": base64_value, + }.items() + if v is not None ] if len(provided) != 1: raise ValueError( @@ -121,17 +109,24 @@ def install( params = {"extensionData": extension_data} cmd = command_builder("webExtension.install", params) return self._conn.execute(cmd) - def uninstall(self, extension: Any | None = None): + + def uninstall(self, extension: str | dict): """Uninstall a web extension. Args: extension: Either the extension ID string returned by ``install``, or the full result dict returned by ``install`` (the ``"extension"`` value is extracted automatically). + + Raises: + ValueError: If extension is not provided or is None. """ if isinstance(extension, dict): extension = extension.get("extension") + + if extension is None: + raise ValueError("extension parameter is required") + params = {"extension": extension} - params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("webExtension.uninstall", params) return self._conn.execute(cmd) From 750b091522b15a05ea4cbf182eb38ea8d11def43 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 12:57:06 +0000 Subject: [PATCH 10/67] Fix webextension and log from comments --- py/generate_bidi.py | 21 ++++++++++++++++++- py/private/bidi_enhancements_manifest.py | 7 +++++++ py/selenium/webdriver/common/bidi/log.py | 4 +++- .../webdriver/common/bidi/webextension.py | 11 +++------- 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index d14e2575c8bfd..53eb3a9e52fcc 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -672,6 +672,11 @@ def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: code += extra_cls code += "\n\n" + # Emit extra type aliases from enhancement manifest (e.g., union types for events) + for extra_alias in enhancements.get("extra_type_aliases", []): + code += extra_alias + code += "\n\n" + # NOTE: Don't generate event type aliases here - they reference types that may not be defined yet # They will be generated after the class definition instead @@ -976,8 +981,22 @@ def clear_event_handlers(self) -> None: # This ensures all types are available when we create the aliases if self.events: code += "\n# Event Info Type Aliases\n" + # Check for explicit event_type_aliases in the enhancement manifest + event_type_aliases = enhancements.get("event_type_aliases", {}) for event_def in self.events: - code += event_def.to_python_dataclass() + # Convert method name to user-friendly event name + method_parts = event_def.method.split(".") + if len(method_parts) == 2: + event_name = self._convert_method_to_event_name(method_parts[1]) + # Check if there's an explicit alias defined in the enhancement manifest + if event_name in event_type_aliases: + # Use the alias directly + type_name = event_type_aliases[event_name] + code += f"# Event: {event_def.method}\n" + code += f"{event_def.name} = {type_name}\n" + else: + # Fall back to the original behavior + code += event_def.to_python_dataclass() code += "\n" # Now populate EVENT_CONFIGS after the aliases are defined diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 5dcce3c25ffeb..f06a4119625e6 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -284,6 +284,13 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": stacktrace=params.get("stackTrace"), )''', ], + # Define Entry union type for log.entryAdded event deserialization + "extra_type_aliases": [ + "Entry = GenericLogEntry | ConsoleLogEntry | JavascriptLogEntry", + ], + "event_type_aliases": { + "entry_added": "Entry", + }, }, "emulation": { "extra_methods": [ diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 7aa7fbf7a3171..c58018e8a947a 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -96,6 +96,8 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": stacktrace=params.get("stackTrace"), ) +Entry = GenericLogEntry | ConsoleLogEntry | JavascriptLogEntry + # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { "entry_added": "log.entryAdded", @@ -292,7 +294,7 @@ def clear_event_handlers(self) -> None: # Event Info Type Aliases # Event: log.entryAdded -EntryAdded = globals().get('Entry', dict) # Fallback to dict if type not defined +EntryAdded = Entry # Populate EVENT_CONFIGS with event configuration mappings diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 98d852512f591..e007f8e4792a6 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -88,13 +88,9 @@ def install( ValueError: If more than one, or none, of the arguments is provided. """ provided = [ - k - for k, v in { - "path": path, - "archive_path": archive_path, - "base64_value": base64_value, - }.items() - if v is not None + k for k, v in { + "path": path, "archive_path": archive_path, "base64_value": base64_value, + }.items() if v is not None ] if len(provided) != 1: raise ValueError( @@ -109,7 +105,6 @@ def install( params = {"extensionData": extension_data} cmd = command_builder("webExtension.install", params) return self._conn.execute(cmd) - def uninstall(self, extension: str | dict): """Uninstall a web extension. From 02867e0e52bcb5e14522bbfbc91cfcbc282afa85 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 13:06:48 +0000 Subject: [PATCH 11/67] Correct usage of dafault_factory --- py/generate_bidi.py | 13 +++-- py/selenium/webdriver/common/bidi/browser.py | 6 +-- .../webdriver/common/bidi/browsing_context.py | 6 +-- .../webdriver/common/bidi/emulation.py | 48 +++++++++---------- py/selenium/webdriver/common/bidi/input.py | 12 ++--- py/selenium/webdriver/common/bidi/network.py | 34 ++++++------- py/selenium/webdriver/common/bidi/script.py | 18 +++---- py/selenium/webdriver/common/bidi/session.py | 14 +++--- py/selenium/webdriver/common/bidi/storage.py | 2 +- 9 files changed, 80 insertions(+), 73 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 53eb3a9e52fcc..f4915aa1ad123 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -385,9 +385,16 @@ def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> if literal_match: literal_value = literal_match.group(1) code += f' {snake_name}: str = field(default="{literal_value}", init=False)\n' - # Check if this field is a list type - elif "List[" in python_type: - code += f" {snake_name}: {python_type} = field(default_factory=list)\n" + # Check if this field is a list type (using lowercase 'list[' from Python 3.10+ syntax) + elif python_type.startswith("list["): + # Remove the trailing ' | None' from list types since default_factory=list ensures non-None + type_annotation = python_type.replace(" | None", "") + code += f" {snake_name}: {type_annotation} = field(default_factory=list)\n" + # Check if this field is a dict type (using lowercase 'dict[' from Python 3.10+ syntax) + elif python_type.startswith("dict["): + # Remove the trailing ' | None' from dict types since default_factory=dict ensures non-None + type_annotation = python_type.replace(" | None", "") + code += f" {snake_name}: {type_annotation} = field(default_factory=dict)\n" else: code += f" {snake_name}: {python_type} = None\n" diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 7cf9678c9b007..0618beb14ddef 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -131,14 +131,14 @@ class CreateUserContextParameters: class GetClientWindowsResult: """GetClientWindowsResult.""" - client_windows: list[Any | None] | None = None + client_windows: list[Any] = field(default_factory=list) @dataclass class GetUserContextsResult: """GetUserContextsResult.""" - user_contexts: list[Any | None] | None = None + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -171,7 +171,7 @@ class SetDownloadBehaviorParameters: """SetDownloadBehaviorParameters.""" download_behavior: Any | None = None - user_contexts: list[Any | None] | None = None + user_contexts: list[Any] = field(default_factory=list) @dataclass diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 35aea615d1780..d17829709c0c3 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -220,14 +220,14 @@ class LocateNodesParameters: context: Any | None = None locator: Any | None = None serialization_options: Any | None = None - start_nodes: list[Any | None] | None = None + start_nodes: list[Any] = field(default_factory=list) @dataclass class LocateNodesResult: """LocateNodesResult.""" - nodes: list[Any | None] | None = None + nodes: list[Any] = field(default_factory=list) @dataclass @@ -300,7 +300,7 @@ class SetViewportParameters: context: Any | None = None viewport: Any | None = None device_pixel_ratio: Any | None = None - user_contexts: list[Any | None] | None = None + user_contexts: list[Any] = field(default_factory=list) @dataclass diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index a85eaad3e223a..7edb7a9dacd06 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -41,16 +41,16 @@ class SetForcedColorsModeThemeOverrideParameters: """SetForcedColorsModeThemeOverrideParameters.""" theme: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass class SetGeolocationOverrideParameters: """SetGeolocationOverrideParameters.""" - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -78,8 +78,8 @@ class SetLocaleOverrideParameters: """SetLocaleOverrideParameters.""" locale: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -87,8 +87,8 @@ class setNetworkConditionsParameters: """setNetworkConditionsParameters.""" network_conditions: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -111,8 +111,8 @@ class SetScreenSettingsOverrideParameters: """SetScreenSettingsOverrideParameters.""" screen_area: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -128,8 +128,8 @@ class SetScreenOrientationOverrideParameters: """SetScreenOrientationOverrideParameters.""" screen_orientation: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -137,8 +137,8 @@ class SetUserAgentOverrideParameters: """SetUserAgentOverrideParameters.""" user_agent: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -146,8 +146,8 @@ class SetViewportMetaOverrideParameters: """SetViewportMetaOverrideParameters.""" viewport_meta: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -155,8 +155,8 @@ class SetScriptingEnabledParameters: """SetScriptingEnabledParameters.""" enabled: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -164,8 +164,8 @@ class SetScrollbarTypeOverrideParameters: """SetScrollbarTypeOverrideParameters.""" scrollbar_type: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -173,16 +173,16 @@ class SetTimezoneOverrideParameters: """SetTimezoneOverrideParameters.""" timezone: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass class SetTouchOverrideParameters: """SetTouchOverrideParameters.""" - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) class Emulation: diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 5dbe71dbd3886..a294bde307b89 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -45,7 +45,7 @@ class PerformActionsParameters: """PerformActionsParameters.""" context: Any | None = None - actions: list[Any | None] | None = None + actions: list[Any] = field(default_factory=list) @dataclass @@ -54,7 +54,7 @@ class NoneSourceActions: type: str = field(default="none", init=False) id: str | None = None - actions: list[Any | None] | None = None + actions: list[Any] = field(default_factory=list) @dataclass @@ -63,7 +63,7 @@ class KeySourceActions: type: str = field(default="key", init=False) id: str | None = None - actions: list[Any | None] | None = None + actions: list[Any] = field(default_factory=list) @dataclass @@ -73,7 +73,7 @@ class PointerSourceActions: type: str = field(default="pointer", init=False) id: str | None = None parameters: Any | None = None - actions: list[Any | None] | None = None + actions: list[Any] = field(default_factory=list) @dataclass @@ -89,7 +89,7 @@ class WheelSourceActions: type: str = field(default="wheel", init=False) id: str | None = None - actions: list[Any | None] | None = None + actions: list[Any] = field(default_factory=list) @dataclass @@ -163,7 +163,7 @@ class SetFilesParameters: context: Any | None = None element: Any | None = None - files: list[Any | None] | None = None + files: list[Any] = field(default_factory=list) @dataclass diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 2290c9fec12d3..af079f421546c 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -75,7 +75,7 @@ class BaseParameters: redirect_count: Any | None = None request: Any | None = None timestamp: Any | None = None - intercepts: list[Any | None] | None = None + intercepts: list[Any] = field(default_factory=list) @dataclass @@ -171,13 +171,13 @@ class ResponseData: status: Any | None = None status_text: str | None = None from_cache: bool | None = None - headers: list[Any | None] | None = None + headers: list[Any] = field(default_factory=list) mime_type: str | None = None bytes_received: Any | None = None headers_size: Any | None = None body_size: Any | None = None content: Any | None = None - auth_challenges: list[Any | None] | None = None + auth_challenges: list[Any] = field(default_factory=list) @dataclass @@ -219,11 +219,11 @@ class UrlPatternString: class AddDataCollectorParameters: """AddDataCollectorParameters.""" - data_types: list[Any | None] | None = None + data_types: list[Any] = field(default_factory=list) max_encoded_data_size: Any | None = None collector_type: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -237,9 +237,9 @@ class AddDataCollectorResult: class AddInterceptParameters: """AddInterceptParameters.""" - phases: list[Any | None] | None = None - contexts: list[Any | None] | None = None - url_patterns: list[Any | None] | None = None + phases: list[Any] = field(default_factory=list) + contexts: list[Any] = field(default_factory=list) + url_patterns: list[Any] = field(default_factory=list) @dataclass @@ -254,9 +254,9 @@ class ContinueResponseParameters: """ContinueResponseParameters.""" request: Any | None = None - cookies: list[Any | None] | None = None + cookies: list[Any] = field(default_factory=list) credentials: Any | None = None - headers: list[Any | None] | None = None + headers: list[Any] = field(default_factory=list) reason_phrase: str | None = None status_code: Any | None = None @@ -315,8 +315,8 @@ class ProvideResponseParameters: request: Any | None = None body: Any | None = None - cookies: list[Any | None] | None = None - headers: list[Any | None] | None = None + cookies: list[Any] = field(default_factory=list) + headers: list[Any] = field(default_factory=list) reason_phrase: str | None = None status_code: Any | None = None @@ -340,16 +340,16 @@ class SetCacheBehaviorParameters: """SetCacheBehaviorParameters.""" cache_behavior: Any | None = None - contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) @dataclass class SetExtraHeadersParameters: """SetExtraHeadersParameters.""" - headers: list[Any | None] | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + headers: list[Any] = field(default_factory=list) + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index c7bfcb3774dff..492d1fe431680 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -216,7 +216,7 @@ class DedicatedWorkerRealmInfo: """DedicatedWorkerRealmInfo.""" type: str = field(default="dedicated-worker", init=False) - owners: list[Any | None] | None = None + owners: list[Any] = field(default_factory=list) @dataclass @@ -460,7 +460,7 @@ class NodeProperties: node_type: Any | None = None child_node_count: Any | None = None - children: list[Any | None] | None = None + children: list[Any] = field(default_factory=list) local_name: str | None = None mode: Any | None = None namespace_uri: str | None = None @@ -499,7 +499,7 @@ class StackFrame: class StackTrace: """StackTrace.""" - call_frames: list[Any | None] | None = None + call_frames: list[Any] = field(default_factory=list) @dataclass @@ -530,9 +530,9 @@ class AddPreloadScriptParameters: """AddPreloadScriptParameters.""" function_declaration: str | None = None - arguments: list[Any | None] | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + arguments: list[Any] = field(default_factory=list) + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) sandbox: str | None = None @@ -547,7 +547,7 @@ class AddPreloadScriptResult: class DisownParameters: """DisownParameters.""" - handles: list[Any | None] | None = None + handles: list[Any] = field(default_factory=list) target: Any | None = None @@ -558,7 +558,7 @@ class CallFunctionParameters: function_declaration: str | None = None await_promise: bool | None = None target: Any | None = None - arguments: list[Any | None] | None = None + arguments: list[Any] = field(default_factory=list) result_ownership: Any | None = None serialization_options: Any | None = None this: Any | None = None @@ -589,7 +589,7 @@ class GetRealmsParameters: class GetRealmsResult: """GetRealmsResult.""" - realms: list[Any | None] | None = None + realms: list[Any] = field(default_factory=list) @dataclass diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index 9b1daaae557fa..f1430cb6e59d3 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -26,7 +26,7 @@ class CapabilitiesRequest: """CapabilitiesRequest.""" always_match: Any | None = None - first_match: list[Any | None] | None = None + first_match: list[Any] = field(default_factory=list) @dataclass @@ -62,7 +62,7 @@ class ManualProxyConfiguration: proxy_type: str = field(default="manual", init=False) http_proxy: str | None = None ssl_proxy: str | None = None - no_proxy: list[Any | None] | None = None + no_proxy: list[Any] = field(default_factory=list) @dataclass @@ -92,23 +92,23 @@ class SystemProxyConfiguration: class SubscribeParameters: """SubscribeParameters.""" - events: list[str | None] | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + events: list[str] = field(default_factory=list) + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass class UnsubscribeByIDRequest: """UnsubscribeByIDRequest.""" - subscriptions: list[Any | None] | None = None + subscriptions: list[Any] = field(default_factory=list) @dataclass class UnsubscribeByAttributesRequest: """UnsubscribeByAttributesRequest.""" - events: list[str | None] | None = None + events: list[str] = field(default_factory=list) @dataclass diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 7e4c9c6dee459..833e9cdc74f2a 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -33,7 +33,7 @@ class GetCookiesParameters: class GetCookiesResult: """GetCookiesResult.""" - cookies: list[Any | None] | None = None + cookies: list[Any] = field(default_factory=list) partition_key: Any | None = None From fccbe3c86e0691c650f82bb5ccda0a43445b4a71 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 13:16:41 +0000 Subject: [PATCH 12/67] fixing generating extra pass --- py/generate_bidi.py | 2 +- py/selenium/webdriver/common/bidi/log.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index f4915aa1ad123..a53ea96db7481 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -946,7 +946,7 @@ def clear_event_handlers(self) -> None: method_enhancements = enhancements.get(method_name_snake, {}) code += command.to_python_method(method_enhancements) code += "\n" - else: + elif not self.events and not enhancements.get("extra_methods", []): code += " pass\n" # Emit extra methods from enhancement manifest diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index c58018e8a947a..1f16849b8e03d 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -264,7 +264,6 @@ def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - pass def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: """Add an event handler. From d8bf641318712e355e527851c851d1dc545dc6e5 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 13:21:24 +0000 Subject: [PATCH 13/67] fix window tests --- py/private/bidi_enhancements_manifest.py | 37 ++++++++++++++++++++ py/selenium/webdriver/common/bidi/browser.py | 37 ++++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index f06a4119625e6..e33a11d5f2b79 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -130,6 +130,43 @@ if user_contexts is not None: params["userContexts"] = user_contexts cmd = command_builder("browser.setDownloadBehavior", params) + return self._conn.execute(cmd)''', + ''' def set_client_window_state( + self, + client_window: Any | None = None, + state: Any | None = None, + ): + """Set the client window state. + + Args: + client_window: The client window ID to apply the state to. + state: The window state to set. Can be one of: + - A string: "fullscreen", "maximized", "minimized", "normal" + - A ClientWindowRectState object with width, height, x, y + - A dict representing the state + + Raises: + ValueError: If client_window is not provided or state is invalid. + """ + if client_window is None: + raise ValueError("client_window is required") + if state is None: + raise ValueError("state is required") + + # Serialize ClientWindowRectState if needed + state_param = state + if hasattr(state, '__dataclass_fields__'): + # It's a dataclass, convert to dict + state_param = { + k: v for k, v in state.__dict__.items() + if v is not None + } + + params = { + "clientWindow": client_window, + "state": state_param, + } + cmd = command_builder("browser.setClientWindowState", params) return self._conn.execute(cmd)''', ], }, diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 0618beb14ddef..7c1958fd435f0 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -341,3 +341,40 @@ def set_download_behavior( params["userContexts"] = user_contexts cmd = command_builder("browser.setDownloadBehavior", params) return self._conn.execute(cmd) + def set_client_window_state( + self, + client_window: Any | None = None, + state: Any | None = None, + ): + """Set the client window state. + + Args: + client_window: The client window ID to apply the state to. + state: The window state to set. Can be one of: + - A string: "fullscreen", "maximized", "minimized", "normal" + - A ClientWindowRectState object with width, height, x, y + - A dict representing the state + + Raises: + ValueError: If client_window is not provided or state is invalid. + """ + if client_window is None: + raise ValueError("client_window is required") + if state is None: + raise ValueError("state is required") + + # Serialize ClientWindowRectState if needed + state_param = state + if hasattr(state, '__dataclass_fields__'): + # It's a dataclass, convert to dict + state_param = { + k: v for k, v in state.__dict__.items() + if v is not None + } + + params = { + "clientWindow": client_window, + "state": state_param, + } + cmd = command_builder("browser.setClientWindowState", params) + return self._conn.execute(cmd) From 00444ac4ed84fa51bcc42b118f0840c0706672b5 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 13:21:35 +0000 Subject: [PATCH 14/67] fix window tests --- py/private/bidi_enhancements_manifest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index e33a11d5f2b79..2b93f36f1a5dc 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -158,7 +158,7 @@ if hasattr(state, '__dataclass_fields__'): # It's a dataclass, convert to dict state_param = { - k: v for k, v in state.__dict__.items() + k: v for k, v in state.__dict__.items() if v is not None } From 534b5ea82c638f740fb0fd21360b4ab74e7402fc Mon Sep 17 00:00:00 2001 From: Nikolay Borisenko <22616990+nvborisenko@users.noreply.github.com> Date: Mon, 23 Feb 2026 18:56:16 +0300 Subject: [PATCH 15/67] [dotnet] [bidi] Unregister cancelled commands (#17129) --- dotnet/src/webdriver/BiDi/Broker.cs | 35 ++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/dotnet/src/webdriver/BiDi/Broker.cs b/dotnet/src/webdriver/BiDi/Broker.cs index 5f5477f010276..8f01585ec04d9 100644 --- a/dotnet/src/webdriver/BiDi/Broker.cs +++ b/dotnet/src/webdriver/BiDi/Broker.cs @@ -73,12 +73,25 @@ public async Task ExecuteCommandAsync(TCommand comma var timeout = options?.Timeout ?? TimeSpan.FromSeconds(30); cts.CancelAfter(timeout); - cts.Token.Register(() => tcs.TrySetCanceled(cts.Token)); + var data = JsonSerializer.SerializeToUtf8Bytes(command, jsonCommandTypeInfo); var commandInfo = new CommandInfo(tcs, jsonResultTypeInfo); _pendingCommands[command.Id] = commandInfo; - var data = JsonSerializer.SerializeToUtf8Bytes(command, jsonCommandTypeInfo); - await _transport.SendAsync(data, cts.Token).ConfigureAwait(false); + using var ctsRegistration = cts.Token.Register(() => + { + tcs.TrySetCanceled(cts.Token); + _pendingCommands.TryRemove(command.Id, out _); + }); + + try + { + await _transport.SendAsync(data, cts.Token).ConfigureAwait(false); + } + catch + { + _pendingCommands.TryRemove(command.Id, out _); + throw; + } return (TResult)await tcs.Task.ConfigureAwait(false); } @@ -105,7 +118,7 @@ public async ValueTask DisposeAsync() GC.SuppressFinalize(this); } - private void ProcessReceivedMessage(byte[]? data) + private void ProcessReceivedMessage(byte[] data) { long? id = default; string? type = default; @@ -194,19 +207,22 @@ private void ProcessReceivedMessage(byte[]? data) } else { - throw new BiDiException($"The remote end responded with 'success' message type, but no pending command with id {id} was found."); + if (_logger.IsEnabled(LogEventLevel.Warn)) + { + _logger.Warn($"The remote end responded with 'success' message type, but no pending command with id {id} was found. Message content: {System.Text.Encoding.UTF8.GetString(data)}"); + } } break; case "event": - if (method is null) throw new BiDiException("The remote end responded with 'event' message type, but missed required 'method' property."); + if (method is null) throw new BiDiException($"The remote end responded with 'event' message type, but missed required 'method' property. Message content: {System.Text.Encoding.UTF8.GetString(data)}"); var paramsJsonData = new ReadOnlyMemory(data, (int)paramsStartIndex, (int)(paramsEndIndex - paramsStartIndex)); _eventDispatcher.EnqueueEvent(method, paramsJsonData, _bidi); break; case "error": - if (id is null) throw new BiDiException("The remote end responded with 'error' message type, but missed required 'id' property."); + if (id is null) throw new BiDiException($"The remote end responded with 'error' message type, but missed required 'id' property. Message content: {System.Text.Encoding.UTF8.GetString(data)}"); if (_pendingCommands.TryGetValue(id.Value, out var errorCommand)) { @@ -215,7 +231,10 @@ private void ProcessReceivedMessage(byte[]? data) } else { - throw new BiDiException($"The remote end responded with 'error' message type, but no pending command with id {id} was found."); + if (_logger.IsEnabled(LogEventLevel.Warn)) + { + _logger.Warn($"The remote end responded with 'error' message type, but no pending command with id {id} was found. Message content: {System.Text.Encoding.UTF8.GetString(data)}"); + } } break; From c57f45c33b9941368ec7b0fb6696c906f7cb02d4 Mon Sep 17 00:00:00 2001 From: Nikolay Borisenko <22616990+nvborisenko@users.noreply.github.com> Date: Mon, 23 Feb 2026 20:30:31 +0300 Subject: [PATCH 16/67] [dotnet] [bidi] Properly handle websocket close handshake (#17132) --- dotnet/src/webdriver/BiDi/Broker.cs | 11 ++++- dotnet/src/webdriver/BiDi/ITransport.cs | 2 +- .../src/webdriver/BiDi/WebSocketTransport.cs | 43 ++++++++++++------- 3 files changed, 39 insertions(+), 17 deletions(-) diff --git a/dotnet/src/webdriver/BiDi/Broker.cs b/dotnet/src/webdriver/BiDi/Broker.cs index 8f01585ec04d9..162fc392fa920 100644 --- a/dotnet/src/webdriver/BiDi/Broker.cs +++ b/dotnet/src/webdriver/BiDi/Broker.cs @@ -113,7 +113,7 @@ public async ValueTask DisposeAsync() _receiveMessagesCancellationTokenSource.Dispose(); - _transport.Dispose(); + await _transport.DisposeAsync().ConfigureAwait(false); GC.SuppressFinalize(this); } @@ -269,6 +269,15 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) _logger.Error($"Unhandled error occurred while receiving remote messages: {ex}"); } + // Fail all pending commands, as the connection is likely broken if we failed to receive messages. + foreach (var id in _pendingCommands.Keys) + { + if (_pendingCommands.TryRemove(id, out var pendingCommand)) + { + pendingCommand.TaskCompletionSource.TrySetException(ex); + } + } + throw; } } diff --git a/dotnet/src/webdriver/BiDi/ITransport.cs b/dotnet/src/webdriver/BiDi/ITransport.cs index bdf33406b3936..f202535253c7b 100644 --- a/dotnet/src/webdriver/BiDi/ITransport.cs +++ b/dotnet/src/webdriver/BiDi/ITransport.cs @@ -19,7 +19,7 @@ namespace OpenQA.Selenium.BiDi; -interface ITransport : IDisposable +interface ITransport : IAsyncDisposable { Task ReceiveAsync(CancellationToken cancellationToken); diff --git a/dotnet/src/webdriver/BiDi/WebSocketTransport.cs b/dotnet/src/webdriver/BiDi/WebSocketTransport.cs index 5f9ba5333e8e1..fcab07b9a7770 100644 --- a/dotnet/src/webdriver/BiDi/WebSocketTransport.cs +++ b/dotnet/src/webdriver/BiDi/WebSocketTransport.cs @@ -24,7 +24,7 @@ namespace OpenQA.Selenium.BiDi; -sealed class WebSocketTransport(ClientWebSocket webSocket) : ITransport, IDisposable +sealed class WebSocketTransport(ClientWebSocket webSocket) : ITransport { private readonly static ILogger _logger = Internal.Logging.Log.GetLogger(); @@ -67,6 +67,14 @@ public async Task ReceiveAsync(CancellationToken cancellationToken) { result = await _webSocket.ReceiveAsync(segment, cancellationToken).ConfigureAwait(false); + if (result.MessageType == WebSocketMessageType.Close) + { + await _webSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).ConfigureAwait(false); + + throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely, + $"The remote end closed the WebSocket connection. Status: {result.CloseStatus}, Description: {result.CloseStatusDescription}"); + } + _sharedMemoryStream.Write(receiveBuffer, 0, result.Count); } while (!result.EndOfMessage); @@ -107,26 +115,31 @@ public async Task SendAsync(byte[] data, CancellationToken cancellationToken) private bool _disposed; - public void Dispose() + public async ValueTask DisposeAsync() { - Dispose(true); - GC.SuppressFinalize(this); - } + if (_disposed) return; - private void Dispose(bool disposing) - { - if (_disposed) + if (_webSocket.State == WebSocketState.Open) { - return; + try + { + await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).ConfigureAwait(false); + } + catch (Exception ex) + { + if (_logger.IsEnabled(LogEventLevel.Warn)) + { + _logger.Warn($"Error closing WebSocket gracefully: {ex.Message}"); + } + } } - if (disposing) - { - _webSocket.Dispose(); - _sharedMemoryStream.Dispose(); - _socketSendSemaphoreSlim.Dispose(); - } + _webSocket.Dispose(); + _sharedMemoryStream.Dispose(); + _socketSendSemaphoreSlim.Dispose(); _disposed = true; + + GC.SuppressFinalize(this); } } From e261b11fa64504660d51691fb438f151899fc585 Mon Sep 17 00:00:00 2001 From: Andrei Solntsev Date: Tue, 24 Feb 2026 15:31:18 +0200 Subject: [PATCH 17/67] [ruby] fix linter error in `./go authors` script (#17136) fix linter error in `./go authors` script Ruby linter complained about too long line. Had to split it. Hope the backslash works in all environments. --- Rakefile | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Rakefile b/Rakefile index ef76e44ae6227..61abab2db19c0 100644 --- a/Rakefile +++ b/Rakefile @@ -96,7 +96,9 @@ task ios_driver: 'appium:build' desc 'Update AUTHORS file' task :authors do puts 'Updating AUTHORS file' - sh "(git log --use-mailmap --format='%aN <%aE>' ; cat .OLD_AUTHORS ; cat AUTHORS) | sort -uf > AUTHORS.tmp && mv AUTHORS.tmp AUTHORS" + sh "(git log --use-mailmap --format='%aN <%aE>' ; " \ + 'cat .OLD_AUTHORS ; cat AUTHORS) | ' \ + 'sort -uf > AUTHORS.tmp && mv AUTHORS.tmp AUTHORS' end # Example: `./go release_updates selenium-4.31.0 early-stable` From f04905132a3f66647903da21af0d6083d3f6f675 Mon Sep 17 00:00:00 2001 From: BckupMuthu <69098183+BckupMuthu@users.noreply.github.com> Date: Tue, 24 Feb 2026 20:14:44 +0530 Subject: [PATCH 18/67] [nodejs] Color Class for Javascript library (#16944) * Color Class for Javascript library --------- Co-authored-by: Corey Goldberg <1113081+cgoldberg@users.noreply.github.com> --- javascript/selenium-webdriver/index.js | 3 + javascript/selenium-webdriver/lib/color.js | 358 ++++++++++++++++++ .../selenium-webdriver/test/lib/color_test.js | 143 +++++++ 3 files changed, 504 insertions(+) create mode 100644 javascript/selenium-webdriver/lib/color.js create mode 100644 javascript/selenium-webdriver/test/lib/color_test.js diff --git a/javascript/selenium-webdriver/index.js b/javascript/selenium-webdriver/index.js index 962b30c655bec..32e7486dbaa9d 100644 --- a/javascript/selenium-webdriver/index.js +++ b/javascript/selenium-webdriver/index.js @@ -32,6 +32,7 @@ const firefox = require('./firefox') const ie = require('./ie') const input = require('./lib/input') const logging = require('./lib/logging') +const color = require('./lib/color') const promise = require('./lib/promise') const remote = require('./remote') const safari = require('./safari') @@ -790,6 +791,8 @@ exports.logging = logging exports.promise = promise exports.until = until exports.Select = select.Select +exports.Color = color.Color +exports.Colors = color.Colors exports.LogInspector = LogInspector exports.BrowsingContext = BrowsingContext exports.BrowsingContextInspector = BrowsingContextInspector diff --git a/javascript/selenium-webdriver/lib/color.js b/javascript/selenium-webdriver/lib/color.js new file mode 100644 index 0000000000000..3a9d71a86a129 --- /dev/null +++ b/javascript/selenium-webdriver/lib/color.js @@ -0,0 +1,358 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +'use strict' + +/** + * @fileoverview Color parsing and formatting utilities mirroring Selenium's Java Color. + */ + +class Color { + /** + * @param {number} red + * @param {number} green + * @param {number} blue + * @param {number} alpha + */ + constructor(red, green, blue, alpha = 1) { + this.red_ = Color.#clamp255(red) + this.green_ = Color.#clamp255(green) + this.blue_ = Color.#clamp255(blue) + this.alpha_ = Color.#clamp01(alpha) + } + + /** + * Guesses the input color format and returns a Color instance. + * @param {string} value + * @returns {Color} + */ + static fromString(value) { + const v = String(value) + for (const conv of [ + Color.#fromRgb, + Color.#fromRgbPct, + Color.#fromRgba, + Color.#fromRgbaPct, + Color.#fromHex6, + Color.#fromHex3, + Color.#fromHsl, + Color.#fromHsla, + Color.#fromNamed, + ]) { + const c = conv(v) + if (c) return c + } + throw new Error(`Did not know how to convert ${value} into color`) + } + + /** + * Sets opacity (alpha channel). + * @param {number} alpha + */ + setOpacity(alpha) { + this.alpha_ = Color.#clamp01(alpha) + } + + /** + * @returns {string} e.g. "rgb(255, 0, 0)" + */ + asRgb() { + return `rgb(${this.red_}, ${this.green_}, ${this.blue_})` + } + + /** + * @returns {string} e.g. "rgba(255, 0, 0, 1)" + */ + asRgba() { + let a + if (this.alpha_ === 1) { + a = '1' + } else if (this.alpha_ === 0) { + a = '0' + } else { + a = String(this.alpha_) + } + return `rgba(${this.red_}, ${this.green_}, ${this.blue_}, ${a})` + } + + /** + * @returns {string} e.g. "#ff0000" + */ + asHex() { + const toHex = (n) => n.toString(16).padStart(2, '0') + return `#${toHex(this.red_)}${toHex(this.green_)}${toHex(this.blue_)}` + } + + /** @override */ + toString() { + return `Color: ${this.asRgba()}` + } + + /** + * @param {*} other + * @returns {boolean} + */ + equals(other) { + return other instanceof Color && this.asRgba() === other.asRgba() + } + + // Converters + static #fromRgb(v) { + const m = /^\s*rgb\(\s*(\d{1,3})\s*,\s*(\d{1,3})\s*,\s*(\d{1,3})\s*\)\s*$/i.exec(v) + return m ? new Color(+m[1], +m[2], +m[3], 1) : null + } + + static #fromRgbPct(v) { + const m = + /^\s*rgb\(\s*(\d{1,3}|\d{1,2}\.\d+)%\s*,\s*(\d{1,3}|\d{1,2}\.\d+)%\s*,\s*(\d{1,3}|\d{1,2}\.\d+)%\s*\)\s*$/i.exec( + v, + ) + if (!m) return null + const pct = (i) => Math.floor((Math.min(100, Math.max(0, parseFloat(m[i]))) / 100) * 255) + return new Color(pct(1), pct(2), pct(3), 1) + } + + static #fromRgba(v) { + const m = /^\s*rgba\(\s*(\d{1,3})\s*,\s*(\d{1,3})\s*,\s*(\d{1,3})\s*,\s*(0|1|0\.\d+)\s*\)\s*$/i.exec(v) + return m ? new Color(+m[1], +m[2], +m[3], parseFloat(m[4])) : null + } + + static #fromRgbaPct(v) { + const m = + /^\s*rgba\(\s*(\d{1,3}|\d{1,2}\.\d+)%\s*,\s*(\d{1,3}|\d{1,2}\.\d+)%\s*,\s*(\d{1,3}|\d{1,2}\.\d+)%\s*,\s*(0|1|0\.\d+)\s*\)\s*$/i.exec( + v, + ) + if (!m) return null + const pct = (i) => Math.floor((Math.min(100, Math.max(0, parseFloat(m[i]))) / 100) * 255) + return new Color(pct(1), pct(2), pct(3), parseFloat(m[4])) + } + + static #fromHex6(v) { + const m = /^#([\da-f]{2})([\da-f]{2})([\da-f]{2})$/i.exec(v) + return m ? new Color(parseInt(m[1], 16), parseInt(m[2], 16), parseInt(m[3], 16), 1) : null + } + + static #fromHex3(v) { + const m = /^#([\da-f])([\da-f])([\da-f])$/i.exec(v) + return m ? new Color(parseInt(m[1] + m[1], 16), parseInt(m[2] + m[2], 16), parseInt(m[3] + m[3], 16), 1) : null + } + + static #fromHsl(v) { + const m = /^\s*hsl\(\s*(\d{1,3})\s*,\s*(\d{1,3})%\s*,\s*(\d{1,3})%\s*\)\s*$/i.exec(v) + return m ? Color.#hslToColor(+m[1], +m[2] / 100, +m[3] / 100, 1) : null + } + + static #fromHsla(v) { + const m = /^\s*hsla\(\s*(\d{1,3})\s*,\s*(\d{1,3})%\s*,\s*(\d{1,3})%\s*,\s*(0|1|0\.\d+)\s*\)\s*$/i.exec(v) + return m ? Color.#hslToColor(+m[1], +m[2] / 100, +m[3] / 100, parseFloat(m[4])) : null + } + + static #hslToColor(hDeg, s, l, a) { + const h = (((hDeg % 360) + 360) % 360) / 360 + if (s === 0) { + const v = Math.round(l * 255) + return new Color(v, v, v, a) + } + const luminocity2 = l < 0.5 ? l * (1 + s) : l + s - l * s + const luminocity1 = 2 * l - luminocity2 + const hueToRgb = (l1, l2, hue) => { + if (hue < 0) hue += 1 + if (hue > 1) hue -= 1 + if (hue < 1 / 6) return l1 + (l2 - l1) * 6 * hue + if (hue < 1 / 2) return l2 + if (hue < 2 / 3) return l1 + (l2 - l1) * (2 / 3 - hue) * 6 + return l1 + } + const r = Math.round(hueToRgb(luminocity1, luminocity2, h + 1 / 3) * 255) + const g = Math.round(hueToRgb(luminocity1, luminocity2, h) * 255) + const b = Math.round(hueToRgb(luminocity1, luminocity2, h - 1 / 3) * 255) + return new Color(r, g, b, a) + } + + static #fromNamed(v) { + const name = String(v).trim().toLowerCase() + const c = Colors[name] + return c ? new Color(c.red_, c.green_, c.blue_, c.alpha_) : null + } + + static #clamp255(n) { + return Math.max(0, Math.min(255, Math.round(n))) + } + + static #clamp01(n) { + return Math.max(0, Math.min(1, n)) + } +} + +// Basic colour keywords as defined by the W3C HTML/CSS spec. +// Keys are lowercase to match typical CSS usage. +const Colors = { + transparent: new Color(0, 0, 0, 0), + aliceblue: new Color(240, 248, 255, 1), + antiquewhite: new Color(250, 235, 215, 1), + aqua: new Color(0, 255, 255, 1), + aquamarine: new Color(127, 255, 212, 1), + azure: new Color(240, 255, 255, 1), + beige: new Color(245, 245, 220, 1), + bisque: new Color(255, 228, 196, 1), + black: new Color(0, 0, 0, 1), + blanchedalmond: new Color(255, 235, 205, 1), + blue: new Color(0, 0, 255, 1), + blueviolet: new Color(138, 43, 226, 1), + brown: new Color(165, 42, 42, 1), + burlywood: new Color(222, 184, 135, 1), + cadetblue: new Color(95, 158, 160, 1), + chartreuse: new Color(127, 255, 0, 1), + chocolate: new Color(210, 105, 30, 1), + coral: new Color(255, 127, 80, 1), + cornflowerblue: new Color(100, 149, 237, 1), + cornsilk: new Color(255, 248, 220, 1), + crimson: new Color(220, 20, 60, 1), + cyan: new Color(0, 255, 255, 1), + darkblue: new Color(0, 0, 139, 1), + darkcyan: new Color(0, 139, 139, 1), + darkgoldenrod: new Color(184, 134, 11, 1), + darkgray: new Color(169, 169, 169, 1), + darkgreen: new Color(0, 100, 0, 1), + darkgrey: new Color(169, 169, 169, 1), + darkkhaki: new Color(189, 183, 107, 1), + darkmagenta: new Color(139, 0, 139, 1), + darkolivegreen: new Color(85, 107, 47, 1), + darkorange: new Color(255, 140, 0, 1), + darkorchid: new Color(153, 50, 204, 1), + darkred: new Color(139, 0, 0, 1), + darksalmon: new Color(233, 150, 122, 1), + darkseagreen: new Color(143, 188, 143, 1), + darkslateblue: new Color(72, 61, 139, 1), + darkslategray: new Color(47, 79, 79, 1), + darkslategrey: new Color(47, 79, 79, 1), + darkturquoise: new Color(0, 206, 209, 1), + darkviolet: new Color(148, 0, 211, 1), + deeppink: new Color(255, 20, 147, 1), + deepskyblue: new Color(0, 191, 255, 1), + dimgray: new Color(105, 105, 105, 1), + dimgrey: new Color(105, 105, 105, 1), + dodgerblue: new Color(30, 144, 255, 1), + firebrick: new Color(178, 34, 34, 1), + floralwhite: new Color(255, 250, 240, 1), + forestgreen: new Color(34, 139, 34, 1), + fuchsia: new Color(255, 0, 255, 1), + gainsboro: new Color(220, 220, 220, 1), + ghostwhite: new Color(248, 248, 255, 1), + gold: new Color(255, 215, 0, 1), + goldenrod: new Color(218, 165, 32, 1), + gray: new Color(128, 128, 128, 1), + grey: new Color(128, 128, 128, 1), + green: new Color(0, 128, 0, 1), + greenyellow: new Color(173, 255, 47, 1), + honeydew: new Color(240, 255, 240, 1), + hotpink: new Color(255, 105, 180, 1), + indianred: new Color(205, 92, 92, 1), + indigo: new Color(75, 0, 130, 1), + ivory: new Color(255, 255, 240, 1), + khaki: new Color(240, 230, 140, 1), + lavender: new Color(230, 230, 250, 1), + lavenderblush: new Color(255, 240, 245, 1), + lawngreen: new Color(124, 252, 0, 1), + lemonchiffon: new Color(255, 250, 205, 1), + lightblue: new Color(173, 216, 230, 1), + lightcoral: new Color(240, 128, 128, 1), + lightcyan: new Color(224, 255, 255, 1), + lightgoldenrodyellow: new Color(250, 250, 210, 1), + lightgray: new Color(211, 211, 211, 1), + lightgreen: new Color(144, 238, 144, 1), + lightgrey: new Color(211, 211, 211, 1), + lightpink: new Color(255, 182, 193, 1), + lightsalmon: new Color(255, 160, 122, 1), + lightseagreen: new Color(32, 178, 170, 1), + lightskyblue: new Color(135, 206, 250, 1), + lightslategray: new Color(119, 136, 153, 1), + lightslategrey: new Color(119, 136, 153, 1), + lightsteelblue: new Color(176, 196, 222, 1), + lightyellow: new Color(255, 255, 224, 1), + lime: new Color(0, 255, 0, 1), + limegreen: new Color(50, 205, 50, 1), + linen: new Color(250, 240, 230, 1), + magenta: new Color(255, 0, 255, 1), + maroon: new Color(128, 0, 0, 1), + mediumaquamarine: new Color(102, 205, 170, 1), + mediumblue: new Color(0, 0, 205, 1), + mediumorchid: new Color(186, 85, 211, 1), + mediumpurple: new Color(147, 112, 219, 1), + mediumseagreen: new Color(60, 179, 113, 1), + mediumslateblue: new Color(123, 104, 238, 1), + mediumspringgreen: new Color(0, 250, 154, 1), + mediumturquoise: new Color(72, 209, 204, 1), + mediumvioletred: new Color(199, 21, 133, 1), + midnightblue: new Color(25, 25, 112, 1), + mintcream: new Color(245, 255, 250, 1), + mistyrose: new Color(255, 228, 225, 1), + moccasin: new Color(255, 228, 181, 1), + navajowhite: new Color(255, 222, 173, 1), + navy: new Color(0, 0, 128, 1), + oldlace: new Color(253, 245, 230, 1), + olive: new Color(128, 128, 0, 1), + olivedrab: new Color(107, 142, 35, 1), + orange: new Color(255, 165, 0, 1), + orangered: new Color(255, 69, 0, 1), + orchid: new Color(218, 112, 214, 1), + palegoldenrod: new Color(238, 232, 170, 1), + palegreen: new Color(152, 251, 152, 1), + paleturquoise: new Color(175, 238, 238, 1), + palevioletred: new Color(219, 112, 147, 1), + papayawhip: new Color(255, 239, 213, 1), + peachpuff: new Color(255, 218, 185, 1), + peru: new Color(205, 133, 63, 1), + pink: new Color(255, 192, 203, 1), + plum: new Color(221, 160, 221, 1), + powderblue: new Color(176, 224, 230, 1), + purple: new Color(128, 0, 128, 1), + rebeccapurple: new Color(102, 51, 153, 1), + red: new Color(255, 0, 0, 1), + rosybrown: new Color(188, 143, 143, 1), + royalblue: new Color(65, 105, 225, 1), + saddlebrown: new Color(139, 69, 19, 1), + salmon: new Color(250, 128, 114, 1), + sandybrown: new Color(244, 164, 96, 1), + seagreen: new Color(46, 139, 87, 1), + seashell: new Color(255, 245, 238, 1), + sienna: new Color(160, 82, 45, 1), + silver: new Color(192, 192, 192, 1), + skyblue: new Color(135, 206, 235, 1), + slateblue: new Color(106, 90, 205, 1), + slategray: new Color(112, 128, 144, 1), + slategrey: new Color(112, 128, 144, 1), + snow: new Color(255, 250, 250, 1), + springgreen: new Color(0, 255, 127, 1), + steelblue: new Color(70, 130, 180, 1), + tan: new Color(210, 180, 140, 1), + teal: new Color(0, 128, 128, 1), + thistle: new Color(216, 191, 216, 1), + tomato: new Color(255, 99, 71, 1), + turquoise: new Color(64, 224, 208, 1), + violet: new Color(238, 130, 238, 1), + wheat: new Color(245, 222, 179, 1), + white: new Color(255, 255, 255, 1), + whitesmoke: new Color(245, 245, 245, 1), + yellow: new Color(255, 255, 0, 1), + yellowgreen: new Color(154, 205, 50, 1), +} + +module.exports = { + Color, + Colors, +} diff --git a/javascript/selenium-webdriver/test/lib/color_test.js b/javascript/selenium-webdriver/test/lib/color_test.js new file mode 100644 index 0000000000000..6425d3e81112b --- /dev/null +++ b/javascript/selenium-webdriver/test/lib/color_test.js @@ -0,0 +1,143 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +'use strict' + +const assert = require('node:assert') +const { By, Color, Colors } = require('selenium-webdriver') +const { Pages, suite } = require('../../lib/test') + +suite(function (env) { + let driver + + before(async function () { + driver = await env.builder().build() + }) + + after(async function () { + await driver.quit() + }) + + describe('Color', function () { + describe('parsing', function () { + it('parses rgb()', function () { + const c = Color.fromString('rgb(255, 0, 0)') + assert.strictEqual(c.asHex(), '#ff0000') + assert.strictEqual(c.asRgb(), 'rgb(255, 0, 0)') + assert.strictEqual(c.asRgba(), 'rgba(255, 0, 0, 1)') + }) + + it('parses rgba() with alpha', function () { + const c = Color.fromString('rgba(0, 0, 255, 0.5)') + assert.strictEqual(c.asRgba(), 'rgba(0, 0, 255, 0.5)') + }) + + it('parses rgb% with truncation', function () { + const c = Color.fromString('rgb(50%, 50%, 50%)') + // Java impl truncates 127.5 -> 127 + assert.strictEqual(c.asRgb(), 'rgb(127, 127, 127)') + }) + + it('parses hex #rrggbb and #rgb', function () { + assert.strictEqual(Color.fromString('#ff0000').asRgb(), 'rgb(255, 0, 0)') + assert.strictEqual(Color.fromString('#0f0').asRgb(), 'rgb(0, 255, 0)') + }) + + it('parses hsl()', function () { + const c = Color.fromString('hsl(0, 100%, 50%)') + assert.strictEqual(c.asHex(), '#ff0000') + }) + + it('parses named colors', function () { + const c1 = Color.fromString('rebeccapurple') + assert.strictEqual(c1.asRgb(), 'rgb(102, 51, 153)') + const c2 = Color.fromString('transparent') + assert.strictEqual(c2.asRgba(), 'rgba(0, 0, 0, 0)') + const c3 = Color.fromString('gray') + const c4 = Color.fromString('grey') + assert.strictEqual(c3.asRgb(), c4.asRgb()) + assert.ok(Colors.gray instanceof Color) + }) + + it('equals compares normalized rgba string', function () { + const a = Color.fromString('rgba(255, 0, 0, 1)') + const b = Color.fromString('rgb(255, 0, 0)') + assert.ok(a.equals(b)) + }) + }) + + describe('integration with getCssValue()', function () { + before(async function () { + await driver.get(Pages.colorPage) + }) + + it('handles named color', async function () { + const css = await driver.findElement(By.id('namedColor')).getCssValue('background-color') + const c = Color.fromString(css) + assert.strictEqual(c.asHex(), '#008000') // green + }) + + it('handles rgb()', async function () { + const css = await driver.findElement(By.id('rgb')).getCssValue('background-color') + const c = Color.fromString(css) + assert.strictEqual(c.asHex(), '#008000') + }) + + it('handles rgb%()', async function () { + const css = await driver.findElement(By.id('rgbpct')).getCssValue('background-color') + const c = Color.fromString(css) + assert.strictEqual(c.asRgb(), 'rgb(0, 128, 0)') + }) + + it('handles hex #rrggbb', async function () { + const css = await driver.findElement(By.id('hex')).getCssValue('background-color') + const c = Color.fromString(css) + assert.strictEqual(c.asHex(), '#008000') + }) + + it('handles short hex #rgb', async function () { + const css = await driver.findElement(By.id('hexShort')).getCssValue('background-color') + const c = Color.fromString(css) + assert.strictEqual(c.asHex(), '#eeeeee') + }) + + it('handles hsl()', async function () { + const css = await driver.findElement(By.id('hsl')).getCssValue('background-color') + const c = Color.fromString(css) + assert.strictEqual(c.asHex(), '#008000') + }) + + it('handles rgba()', async function () { + const css = await driver.findElement(By.id('rgba')).getCssValue('background-color') + const c = Color.fromString(css) + assert.strictEqual(c.asRgba(), 'rgba(0, 128, 0, 0.5)') + }) + + it('handles rgba%()', async function () { + const css = await driver.findElement(By.id('rgbapct')).getCssValue('background-color') + const c = Color.fromString(css) + assert.strictEqual(c.asRgba(), 'rgba(0, 128, 0, 0.5)') + }) + + it('handles hsla()', async function () { + const css = await driver.findElement(By.id('hsla')).getCssValue('background-color') + const c = Color.fromString(css) + assert.strictEqual(c.asRgba(), 'rgba(0, 128, 0, 0.5)') + }) + }) + }) +}) From 73c4774da715ae4600893ebd7d9e623d2ecca0f1 Mon Sep 17 00:00:00 2001 From: Andrei Solntsev Date: Tue, 24 Feb 2026 21:25:50 +0200 Subject: [PATCH 19/67] [java] fix "or" condition (#17135) fix "or" condition We need to re-try in case of `NoSuchElementException`, not only `StaleElementReferenceException`. `NoSuchElementException` may be thrown by a page object because its @FindBy fields are lazy-initialized. Fixes #17091 --- .../support/ui/ExpectedConditions.java | 37 ++++++++++++++----- .../support/ui/ExpectedConditionsTest.java | 8 ++-- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/java/src/org/openqa/selenium/support/ui/ExpectedConditions.java b/java/src/org/openqa/selenium/support/ui/ExpectedConditions.java index b3ce508462f68..da62c59b8f1e8 100644 --- a/java/src/org/openqa/selenium/support/ui/ExpectedConditions.java +++ b/java/src/org/openqa/selenium/support/ui/ExpectedConditions.java @@ -399,7 +399,7 @@ public static ExpectedCondition textToBePresentInElement( return new ExpectedCondition<>() { private @Nullable String elementText; - private @Nullable StaleElementReferenceException error; + private @Nullable WebDriverException error; @Override public Boolean apply(WebDriver driver) { @@ -408,7 +408,7 @@ public Boolean apply(WebDriver driver) { try { elementText = element.getText(); return elementText.contains(text); - } catch (StaleElementReferenceException e) { + } catch (StaleElementReferenceException | NoSuchElementException e) { error = e; return false; } @@ -481,7 +481,7 @@ public static ExpectedCondition textToBePresentInElementValue( return new ExpectedCondition<>() { private @Nullable String actualValue; - private @Nullable StaleElementReferenceException error; + private @Nullable WebDriverException error; @Override public Boolean apply(WebDriver driver) { @@ -490,7 +490,7 @@ public Boolean apply(WebDriver driver) { try { actualValue = element.getAttribute("value"); return actualValue != null && actualValue.contains(expectedValue); - } catch (StaleElementReferenceException e) { + } catch (StaleElementReferenceException | NoSuchElementException e) { error = e; return false; } @@ -825,7 +825,7 @@ public Boolean apply(WebDriver ignored) { // Calling any method forces a staleness check element.isEnabled(); return false; - } catch (StaleElementReferenceException expected) { + } catch (StaleElementReferenceException | NoSuchElementException expected) { return true; } } @@ -855,7 +855,7 @@ public String toString() { public @Nullable T apply(WebDriver driver) { try { return condition.apply(driver); - } catch (StaleElementReferenceException e) { + } catch (StaleElementReferenceException | NoSuchElementException e) { return null; } } @@ -1677,15 +1677,21 @@ private static boolean isInvisible(final WebElement element) { */ public static ExpectedCondition or(final ExpectedCondition... conditions) { return new ExpectedCondition<>() { + @Nullable final Object[] results = new Object[conditions.length]; + @Override public Boolean apply(WebDriver driver) { - for (ExpectedCondition condition : conditions) { + for (int i = 0; i != conditions.length; ++i) { + ExpectedCondition condition = conditions[i]; + results[i] = null; try { Object result = condition.apply(driver); if (Boolean.TRUE.equals(result) || result != null && !(result instanceof Boolean)) { return true; } - } catch (StaleElementReferenceException ignore) { + results[i] = result; + } catch (RuntimeException e) { + results[i] = e; } } return false; @@ -1696,10 +1702,23 @@ public String toString() { StringBuilder message = new StringBuilder("at least one condition to be valid:").append(lineSeparator()); for (int i = 0; i < conditions.length; i++) { - message.append(i + 1).append(". ").append(conditions[i]).append(lineSeparator()); + message + .append(i + 1) + .append(". ") + .append(conditions[i]) + .append(resultAsString(results[i])) + .append(lineSeparator()); } return message.toString(); } + + private Object resultAsString(@Nullable Object result) { + return Boolean.FALSE.equals(result) + ? "" + : result instanceof WebDriverException + ? " (caused by: " + shortDescription((WebDriverException) result) + ")" + : " (actual: " + result + ")"; + } }; } diff --git a/java/test/org/openqa/selenium/support/ui/ExpectedConditionsTest.java b/java/test/org/openqa/selenium/support/ui/ExpectedConditionsTest.java index 5735dbf6dbbc9..1390ce6664732 100644 --- a/java/test/org/openqa/selenium/support/ui/ExpectedConditionsTest.java +++ b/java/test/org/openqa/selenium/support/ui/ExpectedConditionsTest.java @@ -966,8 +966,7 @@ void whenOneThrows() { .thenReturn("16pt") .thenReturn("17pt") .thenReturn("18pt"); - when(mockElement.getText()) - .thenThrow(new StaleElementReferenceException("Element disappeared")); + when(mockElement.getText()).thenThrow(new NoSuchElementException("Element disappeared")); assertThat( wait.until( @@ -981,7 +980,7 @@ void whenOneThrows() { void whenAllThrow() { String attributeName = "test"; when(mockElement.getAttribute(attributeName)) - .thenThrow(new StaleElementReferenceException("Disappeared 1")); + .thenThrow(new NoSuchElementException("Disappeared 1")); when(mockElement.getCssValue(attributeName)) .thenThrow(new StaleElementReferenceException("Disappeared 2")); when(mockElement.getText()).thenThrow(new StaleElementReferenceException("Disappeared 3")); @@ -999,7 +998,8 @@ void whenAllThrow() { + "1. element to have text \"test\", but..." + " org.openqa.selenium.StaleElementReferenceException: Disappeared 3." + lineSeparator() - + "2. attribute or CSS value \"test\"=\"test\". Current value: \"null\".") + + "2. attribute or CSS value \"test\"=\"test\". Current value: \"null\". (caused" + + " by: org.openqa.selenium.NoSuchElementException: Disappeared 1)") .hasMessageContaining("tried for 1.1 seconds with 250 milliseconds interval"); } } From 7efd537848e9fbe0405a66e76dfd731ba708c563 Mon Sep 17 00:00:00 2001 From: Selenium CI Bot Date: Wed, 25 Feb 2026 01:54:37 +0100 Subject: [PATCH 20/67] [dotnet][rb][java][js][py] Automated Browser Version Update (#17134) Update pinned browser versions Co-authored-by: Selenium CI Bot --- common/repositories.bzl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/common/repositories.bzl b/common/repositories.bzl index 6c7c4b7fcd2f9..e7f65e87e49bd 100644 --- a/common/repositories.bzl +++ b/common/repositories.bzl @@ -11,8 +11,8 @@ def pin_browsers(): http_archive( name = "linux_firefox", - url = "https://ftp.mozilla.org/pub/firefox/releases/147.0.4/linux-x86_64/en-US/firefox-147.0.4.tar.xz", - sha256 = "fdafe715d3b3ee406e306a27fd91f3e8bdee2fd35c1b3834df6c9134be6265ca", + url = "https://ftp.mozilla.org/pub/firefox/releases/148.0/linux-x86_64/en-US/firefox-148.0.tar.xz", + sha256 = "a3ea5907006baa19183d5f582f7781c1d0e22fd1605603c4a76fc14b7f55be23", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -33,8 +33,8 @@ js_library( dmg_archive( name = "mac_firefox", - url = "https://ftp.mozilla.org/pub/firefox/releases/147.0.4/mac/en-US/Firefox%20147.0.4.dmg", - sha256 = "1b7d7e2f1f00fbc18a6890bd3a614d3501168a52ec35ef86a92d915a254bcfb4", + url = "https://ftp.mozilla.org/pub/firefox/releases/148.0/mac/en-US/Firefox%20148.0.dmg", + sha256 = "ec50d7ed2337441d92272617f510cd4c11e44d68699349d466d62211b767e0a2", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -199,8 +199,8 @@ js_library( http_archive( name = "linux_chrome", - url = "https://storage.googleapis.com/chrome-for-testing-public/145.0.7632.77/linux64/chrome-linux64.zip", - sha256 = "e816d455d0f6a510541206ac7a52d6a1854d1c436843536206bc50e376c8f974", + url = "https://storage.googleapis.com/chrome-for-testing-public/145.0.7632.117/linux64/chrome-linux64.zip", + sha256 = "b373fbc7cc71922082febe41fe413e02bb6c4ae31b6b6626ddb2c764289ef15d", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -220,8 +220,8 @@ js_library( ) http_archive( name = "mac_chrome", - url = "https://storage.googleapis.com/chrome-for-testing-public/145.0.7632.77/mac-arm64/chrome-mac-arm64.zip", - sha256 = "034936ee1c7cdbdd81cd1db9ae98d8416c3198617bd39c03f6d6a2cbdc73b7b9", + url = "https://storage.googleapis.com/chrome-for-testing-public/145.0.7632.117/mac-arm64/chrome-mac-arm64.zip", + sha256 = "cced377de50a37287f7f69337bbad5d9f12505bf6b03f7d40a627cc8061f7f59", strip_prefix = "chrome-mac-arm64", patch_cmds = [ "mv 'Google Chrome for Testing.app' Chrome.app", @@ -241,8 +241,8 @@ js_library( ) http_archive( name = "linux_chromedriver", - url = "https://storage.googleapis.com/chrome-for-testing-public/145.0.7632.77/linux64/chromedriver-linux64.zip", - sha256 = "82f9e1128946f053867b8b939acf5892f1f037d2b9f730046f4b450820c45d47", + url = "https://storage.googleapis.com/chrome-for-testing-public/145.0.7632.117/linux64/chromedriver-linux64.zip", + sha256 = "7f397fa71614b2d97dcb06a87f3297c91725042803654d6c7fedcd10387a5df8", strip_prefix = "chromedriver-linux64", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") @@ -259,8 +259,8 @@ js_library( http_archive( name = "mac_chromedriver", - url = "https://storage.googleapis.com/chrome-for-testing-public/145.0.7632.77/mac-arm64/chromedriver-mac-arm64.zip", - sha256 = "9e1c34b1be0067aa4d8650c629d60c45b30a6c3af88eaadb629b4092fe54994c", + url = "https://storage.googleapis.com/chrome-for-testing-public/145.0.7632.117/mac-arm64/chromedriver-mac-arm64.zip", + sha256 = "25d09ebbca90628b480e95e0d1541818f8e7646d0b77fa81c6258fc13c0613d0", strip_prefix = "chromedriver-mac-arm64", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") From dc1967113f8c333129641e0a32cc4324622929a3 Mon Sep 17 00:00:00 2001 From: JAYA DILEEP <114493600+seethinajayadileep@users.noreply.github.com> Date: Thu, 26 Feb 2026 03:34:23 +0530 Subject: [PATCH 21/67] [java] Improve screenshot error message (#17120) * Improve diagnostic message when screenshot file write fails Adds contextual information to WebDriverException thrown during temporary screenshot file creation or write failure. Preserves original IOException as cause. * Restore accidentally removed line * Apply formatting * Refactor screenshot temp file creation error handling --- java/src/org/openqa/selenium/OutputType.java | 25 ++++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/java/src/org/openqa/selenium/OutputType.java b/java/src/org/openqa/selenium/OutputType.java index 1d0e013e41266..3a3997cfd6cb1 100644 --- a/java/src/org/openqa/selenium/OutputType.java +++ b/java/src/org/openqa/selenium/OutputType.java @@ -83,14 +83,29 @@ public File convertFromPngBytes(byte[] data) { } private File save(byte[] data) { + Path tmpFilePath = createScreenshotFile(); try { - Path tmpFilePath = Files.createTempFile("screenshot", ".png"); - File tmpFile = tmpFilePath.toFile(); - tmpFile.deleteOnExit(); Files.write(tmpFilePath, data); - return tmpFile; } catch (IOException e) { - throw new WebDriverException(e); + throw new WebDriverException( + "Failed to create or write screenshot to temporary file: " + + tmpFilePath.toAbsolutePath().toString(), + e); + } + + File tmpFile = tmpFilePath.toFile(); + tmpFile.deleteOnExit(); + return tmpFile; + } + + private Path createScreenshotFile() { + try { + return Files.createTempFile("screenshot", ".png"); + } catch (IOException e) { + throw new WebDriverException( + "Failed to create or write screenshot to temporary file: " + + "temporary file could not be created", + e); } } From 319f327b079d4f0aa0ff6021be0aff41b6bafb81 Mon Sep 17 00:00:00 2001 From: Selenium CI Bot Date: Thu, 26 Feb 2026 01:49:14 +0100 Subject: [PATCH 22/67] [dotnet][rb][java][js][py] Automated Browser Version Update (#17140) Update pinned browser versions Co-authored-by: Selenium CI Bot --- common/repositories.bzl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/common/repositories.bzl b/common/repositories.bzl index e7f65e87e49bd..d0d8bc7901567 100644 --- a/common/repositories.bzl +++ b/common/repositories.bzl @@ -50,8 +50,8 @@ js_library( http_archive( name = "linux_beta_firefox", - url = "https://ftp.mozilla.org/pub/firefox/releases/148.0b15/linux-x86_64/en-US/firefox-148.0b15.tar.xz", - sha256 = "23621cf9537fd8d52d3a4e83bba48984f705facd4c10819aa07ae2531e11e2e5", + url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b1/linux-x86_64/en-US/firefox-149.0b1.tar.xz", + sha256 = "f36c7db981d3098145c55b6cac5f9665066b9bb37b68531cc9dce59b72726c49", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -72,8 +72,8 @@ js_library( dmg_archive( name = "mac_beta_firefox", - url = "https://ftp.mozilla.org/pub/firefox/releases/148.0b15/mac/en-US/Firefox%20148.0b15.dmg", - sha256 = "0da5fee250eb13165dda25f9e29d238e51f9e0d56b9355b9c8746f6bc8d5c1fe", + url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b1/mac/en-US/Firefox%20149.0b1.dmg", + sha256 = "3d1f4abb063b47b392af528cf7501132fc405e2a9f1eccab71913b0fdf3e538c", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -277,8 +277,8 @@ js_library( http_archive( name = "linux_beta_chrome", - url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.16/linux64/chrome-linux64.zip", - sha256 = "6c3241cf5eab6b5eaed9b0b741bae799377dea26985aed08cda51fb75433218e", + url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.31/linux64/chrome-linux64.zip", + sha256 = "5b1961b081f0156a1923a9d9d1bfffdf00f82e8722152c35eb5eb742d63ceeb8", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -298,8 +298,8 @@ js_library( ) http_archive( name = "mac_beta_chrome", - url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.16/mac-arm64/chrome-mac-arm64.zip", - sha256 = "b39fe2de33190da209845e5d21ea44c75d66d0f4c33c5a293d8b6a259d3c4029", + url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.31/mac-arm64/chrome-mac-arm64.zip", + sha256 = "207867110edc624316b18684065df4eb06b938a3fd9141790a726ab280e2640f", strip_prefix = "chrome-mac-arm64", patch_cmds = [ "mv 'Google Chrome for Testing.app' Chrome.app", @@ -319,8 +319,8 @@ js_library( ) http_archive( name = "linux_beta_chromedriver", - url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.16/linux64/chromedriver-linux64.zip", - sha256 = "c6927758a816f0a2f5f10609b34f74080a8c0f08feaf177a68943d8d4aae3a72", + url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.31/linux64/chromedriver-linux64.zip", + sha256 = "a8c7be8669829ed697759390c8c42b4bca3f884fd20980e078129f5282dabe1a", strip_prefix = "chromedriver-linux64", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") @@ -337,8 +337,8 @@ js_library( http_archive( name = "mac_beta_chromedriver", - url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.16/mac-arm64/chromedriver-mac-arm64.zip", - sha256 = "29c44a53be87fccea4a7887a7ed2b45b5812839e357e091c6a784ee17bb8da78", + url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.31/mac-arm64/chromedriver-mac-arm64.zip", + sha256 = "84c3717c0eeba663d0b8890a0fc06faa6fe158227876fc6954461730ccc81634", strip_prefix = "chromedriver-mac-arm64", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") From b6d918fa9ec7422be0c4fb968ca04d7f3c47894c Mon Sep 17 00:00:00 2001 From: tim-burke-systemware Date: Wed, 25 Feb 2026 23:58:19 -0600 Subject: [PATCH 23/67] [java] Catch exception on .contentAsString with binary content (#17139) * Catch exeption on ..asString with binary content Wrap builder.append() in try-catch to handle UnsupportedOperationException. Fixes a bug when retrieving browser downloads from Selenium Grid using RemoteWebDriver builder syntax by ensuring HTTP message content logging does not fail on unsupported operations. Also updates getting the content using non-deprecated method. Fixes #17137 * Fix formatting and unused import * Simplify, just use toString() --------- Co-authored-by: Corey Goldberg <1113081+cgoldberg@users.noreply.github.com> --- .../remote/http/DumpHttpExchangeFilter.java | 2 +- .../router/RemoteWebDriverDownloadTest.java | 58 +++++++++++++------ 2 files changed, 41 insertions(+), 19 deletions(-) diff --git a/java/src/org/openqa/selenium/remote/http/DumpHttpExchangeFilter.java b/java/src/org/openqa/selenium/remote/http/DumpHttpExchangeFilter.java index c4b2c5f8c890e..0c3b591e8dc30 100644 --- a/java/src/org/openqa/selenium/remote/http/DumpHttpExchangeFilter.java +++ b/java/src/org/openqa/selenium/remote/http/DumpHttpExchangeFilter.java @@ -52,7 +52,7 @@ private void expandHeadersAndContent(StringBuilder builder, HttpMessage messa message.forEachHeader( (name, value) -> builder.append(" ").append(name).append(": ").append(value).append("\n")); builder.append("\n"); - builder.append(Contents.string(message)); + builder.append(message); } /** visible for testing only */ diff --git a/java/test/org/openqa/selenium/grid/router/RemoteWebDriverDownloadTest.java b/java/test/org/openqa/selenium/grid/router/RemoteWebDriverDownloadTest.java index ba3ce534e3b76..2f15652e6b62c 100644 --- a/java/test/org/openqa/selenium/grid/router/RemoteWebDriverDownloadTest.java +++ b/java/test/org/openqa/selenium/grid/router/RemoteWebDriverDownloadTest.java @@ -38,9 +38,9 @@ import java.util.stream.Stream; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.MethodSource; import org.openqa.selenium.By; import org.openqa.selenium.Capabilities; @@ -65,6 +65,11 @@ @Ignore(value = SAFARI, reason = "browser must support setting download location") class RemoteWebDriverDownloadTest extends JupiterTestBase { + enum DriverCreationMode { + CONSTRUCTOR, + BUILDER + } + private static final Set FILE_EXTENSIONS = Set.of(".txt", ".jpg"); private Server server; @@ -97,10 +102,11 @@ public void tearDownGrid() { tearDowns.parallelStream().forEach(Safely::safelyCall); } - @Test + @ParameterizedTest + @EnumSource(DriverCreationMode.class) @NoDriverBeforeTest - void canListDownloadedFiles() { - localDriver = createWebdriver(capabilities); + void canListDownloadedFiles(DriverCreationMode mode) { + localDriver = createWebdriver(capabilities, mode); localDriver.get(appServer.whereIs("downloads/download.html")); localDriver.findElement(By.id("file-1")).click(); @@ -121,9 +127,10 @@ void canListDownloadedFiles() { @ParameterizedTest @MethodSource("downloadableFiles") @NoDriverBeforeTest - void canDownloadFiles(By selector, String expectedFileName, String expectedFileContent) + void canDownloadFiles( + DriverCreationMode mode, By selector, String expectedFileName, String expectedFileContent) throws IOException { - localDriver = createWebdriver(capabilities); + localDriver = createWebdriver(capabilities, mode); localDriver.get(appServer.whereIs("downloads/download.html")); localDriver.findElement(selector).click(); @@ -142,16 +149,23 @@ void canDownloadFiles(By selector, String expectedFileName, String expectedFileC } static Stream downloadableFiles() { - return Stream.of( - Arguments.of(By.id("file-1"), "file_1.txt", "Hello, World!"), - Arguments.of( - By.id("file-3"), "file-with-space 0 & _ ' ~.txt", "Hello, filename with space!")); + return Stream.of(DriverCreationMode.values()) + .flatMap( + mode -> + Stream.of( + Arguments.of(mode, By.id("file-1"), "file_1.txt", "Hello, World!"), + Arguments.of( + mode, + By.id("file-3"), + "file-with-space 0 & _ ' ~.txt", + "Hello, filename with space!"))); } - @Test + @ParameterizedTest + @EnumSource(DriverCreationMode.class) @NoDriverBeforeTest - void testCanDeleteFiles() { - localDriver = createWebdriver(capabilities); + void testCanDeleteFiles(DriverCreationMode mode) { + localDriver = createWebdriver(capabilities, mode); localDriver.get(appServer.whereIs("downloads/download.html")); localDriver.findElement(By.id("file-1")).click(); waitForDownloadedFiles(localDriver, 1); @@ -162,16 +176,17 @@ void testCanDeleteFiles() { assertThat(afterDeleteNames).isEmpty(); } - @Test + @ParameterizedTest + @EnumSource(DriverCreationMode.class) @NoDriverBeforeTest - void errorsWhenCapabilityMissing() { + void errorsWhenCapabilityMissing(DriverCreationMode mode) { Browser browser = Browser.detect(); Capabilities caps = new PersistentCapabilities(Objects.requireNonNull(browser).getCapabilities()) .setCapability(ENABLE_DOWNLOADS, false); - localDriver = createWebdriver(caps); + localDriver = createWebdriver(caps, mode); assertThatThrownBy(() -> ((HasDownloads) localDriver).getDownloadedFiles()) .isInstanceOf(WebDriverException.class) .hasMessageStartingWith( @@ -184,8 +199,15 @@ void errorsWhenCapabilityMissing() { "You must enable downloads in order to work with downloadable files"); } - private WebDriver createWebdriver(Capabilities capabilities) { - return new Augmenter().augment(new RemoteWebDriver(server.getUrl(), capabilities)); + private WebDriver createWebdriver(Capabilities capabilities, DriverCreationMode mode) { + return switch (mode) { + case CONSTRUCTOR -> + new Augmenter().augment(new RemoteWebDriver(server.getUrl(), capabilities)); + case BUILDER -> + new Augmenter() + .augment( + RemoteWebDriver.builder().oneOf(capabilities).address(server.getUrl()).build()); + }; } /** ensure we hit no temporary file created by the browser while downloading */ From cbc0bbdd406b5440ee5058bf07ae8d6c526c888c Mon Sep 17 00:00:00 2001 From: Nikolay Borisenko <22616990+nvborisenko@users.noreply.github.com> Date: Thu, 26 Feb 2026 18:40:54 +0300 Subject: [PATCH 24/67] [dotnet] [bidi] Wait until events are dispatched when unsubscribing (#17142) --- dotnet/src/webdriver/BiDi/EventDispatcher.cs | 154 ++++++++++++++++--- 1 file changed, 131 insertions(+), 23 deletions(-) diff --git a/dotnet/src/webdriver/BiDi/EventDispatcher.cs b/dotnet/src/webdriver/BiDi/EventDispatcher.cs index 1a4e93006821c..5a37ee58d9f7b 100644 --- a/dotnet/src/webdriver/BiDi/EventDispatcher.cs +++ b/dotnet/src/webdriver/BiDi/EventDispatcher.cs @@ -34,20 +34,18 @@ internal sealed class EventDispatcher : IAsyncDisposable private readonly ConcurrentDictionary _events = new(); - private readonly Channel _pendingEvents = Channel.CreateUnbounded(new() + private readonly Channel _pendingEvents = Channel.CreateUnbounded(new() { SingleReader = true, SingleWriter = true }); - private readonly Task _eventEmitterTask; - - private static readonly TaskFactory _myTaskFactory = new(CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskContinuationOptions.None, TaskScheduler.Default); + private readonly Task _processEventsTask; public EventDispatcher(Func sessionProvider) { _sessionProvider = sessionProvider; - _eventEmitterTask = _myTaskFactory.StartNew(ProcessEventsAwaiterAsync).Unwrap(); + _processEventsTask = Task.Run(ProcessEventsAsync); } public async Task SubscribeAsync(string eventName, EventHandler eventHandler, SubscriptionOptions? options, JsonTypeInfo jsonTypeInfo, CancellationToken cancellationToken) @@ -55,11 +53,19 @@ public async Task SubscribeAsync(string eventName, Eve { var registration = _events.GetOrAdd(eventName, _ => new EventRegistration(jsonTypeInfo)); - var subscribeResult = await _sessionProvider().SubscribeAsync([eventName], new() { Contexts = options?.Contexts, UserContexts = options?.UserContexts }, cancellationToken).ConfigureAwait(false); + registration.AddHandler(eventHandler); - registration.Handlers.Add(eventHandler); + try + { + var subscribeResult = await _sessionProvider().SubscribeAsync([eventName], new() { Contexts = options?.Contexts, UserContexts = options?.UserContexts }, cancellationToken).ConfigureAwait(false); - return new Subscription(subscribeResult.Subscription, this, eventHandler); + return new Subscription(subscribeResult.Subscription, this, eventHandler); + } + catch + { + registration.RemoveHandler(eventHandler); + throw; + } } public async ValueTask UnsubscribeAsync(Subscription subscription, CancellationToken cancellationToken) @@ -67,15 +73,34 @@ public async ValueTask UnsubscribeAsync(Subscription subscription, CancellationT if (_events.TryGetValue(subscription.EventHandler.EventName, out var registration)) { await _sessionProvider().UnsubscribeAsync([subscription.SubscriptionId], null, cancellationToken).ConfigureAwait(false); - registration.Handlers.Remove(subscription.EventHandler); + + // Wait until all pending events for this method are dispatched + try + { + await registration.DrainAsync(cancellationToken).ConfigureAwait(false); + } + finally + { + registration.RemoveHandler(subscription.EventHandler); + } } } public void EnqueueEvent(string method, ReadOnlyMemory jsonUtf8Bytes, IBiDi bidi) { - if (_events.TryGetValue(method, out var registration) && registration.TypeInfo is not null) + if (_events.TryGetValue(method, out var registration)) { - _pendingEvents.Writer.TryWrite(new PendingEvent(method, jsonUtf8Bytes, bidi, registration.TypeInfo)); + if (_pendingEvents.Writer.TryWrite(new EventItem(jsonUtf8Bytes, bidi, registration))) + { + registration.IncrementEnqueued(); + } + else + { + if (_logger.IsEnabled(LogEventLevel.Warn)) + { + _logger.Warn($"Failed to enqueue BiDi event with method '{method}' for processing. Event will be ignored."); + } + } } else { @@ -86,34 +111,45 @@ public void EnqueueEvent(string method, ReadOnlyMemory jsonUtf8Bytes, IBiD } } - private async Task ProcessEventsAwaiterAsync() + private async Task ProcessEventsAsync() { var reader = _pendingEvents.Reader; + while (await reader.WaitToReadAsync().ConfigureAwait(false)) { - while (reader.TryRead(out var result)) + while (reader.TryRead(out var evt)) { try { - if (_events.TryGetValue(result.Method, out var registration)) - { - // Deserialize on background thread instead of network thread (single parse) - var eventArgs = (EventArgs)JsonSerializer.Deserialize(result.JsonUtf8Bytes.Span, result.TypeInfo)!; - eventArgs.BiDi = result.BiDi; + var eventArgs = (EventArgs)JsonSerializer.Deserialize(evt.JsonUtf8Bytes.Span, evt.Registration.TypeInfo)!; + eventArgs.BiDi = evt.BiDi; - foreach (var handler in registration.Handlers.ToArray()) // copy handlers avoiding modified collection while iterating + foreach (var handler in evt.Registration.GetHandlersSnapshot()) + { + try { await handler.InvokeAsync(eventArgs).ConfigureAwait(false); } + catch (Exception ex) + { + if (_logger.IsEnabled(LogEventLevel.Error)) + { + _logger.Error($"Unhandled error processing BiDi event handler: {ex}"); + } + } } } catch (Exception ex) { if (_logger.IsEnabled(LogEventLevel.Error)) { - _logger.Error($"Unhandled error processing BiDi event handler: {ex}"); + _logger.Error($"Unhandled error deserializing BiDi event: {ex}"); } } + finally + { + evt.Registration.IncrementProcessed(); + } } } } @@ -122,16 +158,88 @@ public async ValueTask DisposeAsync() { _pendingEvents.Writer.Complete(); - await _eventEmitterTask.ConfigureAwait(false); + await _processEventsTask.ConfigureAwait(false); GC.SuppressFinalize(this); } - private readonly record struct PendingEvent(string Method, ReadOnlyMemory JsonUtf8Bytes, IBiDi BiDi, JsonTypeInfo TypeInfo); + private sealed record EventItem(ReadOnlyMemory JsonUtf8Bytes, IBiDi BiDi, EventRegistration Registration); private sealed class EventRegistration(JsonTypeInfo typeInfo) { + private long _enqueueSeq; + private long _processedSeq; + private readonly object _drainLock = new(); + private readonly List _handlers = []; + private List<(long TargetSeq, TaskCompletionSource Tcs)>? _drainWaiters; + public JsonTypeInfo TypeInfo { get; } = typeInfo; - public List Handlers { get; } = []; + + public void AddHandler(EventHandler handler) + { + lock (_drainLock) _handlers.Add(handler); + } + + public void RemoveHandler(EventHandler handler) + { + lock (_drainLock) _handlers.Remove(handler); + } + + public EventHandler[] GetHandlersSnapshot() + { + lock (_drainLock) return [.. _handlers]; + } + + public void IncrementEnqueued() => Interlocked.Increment(ref _enqueueSeq); + + public void IncrementProcessed() + { + var processed = Interlocked.Increment(ref _processedSeq); + + lock (_drainLock) + { + if (_drainWaiters is null) return; + + for (var i = _drainWaiters.Count - 1; i >= 0; i--) + { + if (_drainWaiters[i].TargetSeq <= processed) + { + _drainWaiters[i].Tcs.TrySetResult(true); + _drainWaiters.RemoveAt(i); + } + } + + if (_drainWaiters.Count == 0) _drainWaiters = null; + } + } + + public Task DrainAsync(CancellationToken cancellationToken) + { + lock (_drainLock) + { + var target = Volatile.Read(ref _enqueueSeq); + if (Volatile.Read(ref _processedSeq) >= target) return Task.CompletedTask; + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _drainWaiters ??= []; + _drainWaiters.Add((target, tcs)); + + // Double-check: processing may have caught up between the read and adding the waiter + if (Volatile.Read(ref _processedSeq) >= target) + { + _drainWaiters.Remove((target, tcs)); + if (_drainWaiters.Count == 0) _drainWaiters = null; + return Task.CompletedTask; + } + + if (!cancellationToken.CanBeCanceled) return tcs.Task; + + return tcs.Task.ContinueWith( + static _ => { }, + cancellationToken, + TaskContinuationOptions.None, + TaskScheduler.Default); + } + } } } From 2407c89ed5410e423997d0bbfab56dd1ff63268d Mon Sep 17 00:00:00 2001 From: Nikolay Borisenko <22616990+nvborisenko@users.noreply.github.com> Date: Thu, 26 Feb 2026 21:51:57 +0300 Subject: [PATCH 25/67] [dotnet] Any WebDriver can be disposed asynchronously (#17119) --- dotnet/src/support/BUILD.bazel | 5 ++ .../support/Events/EventFiringWebDriver.cs | 45 ++++++++---- .../src/webdriver/Chromium/ChromiumDriver.cs | 19 ++++- dotnet/src/webdriver/Firefox/FirefoxDriver.cs | 11 --- dotnet/src/webdriver/IWebDriver.cs | 2 +- dotnet/src/webdriver/WebDriver.cs | 72 +++++++++++++++---- dotnet/src/webdriver/assets/nuget/README.md | 2 +- dotnet/test/common/StubDriver.cs | 10 +++ dotnet/test/support/Events/BUILD.bazel | 1 + .../Selenium.WebDriver.Support.Tests.csproj | 2 +- 10 files changed, 127 insertions(+), 42 deletions(-) diff --git a/dotnet/src/support/BUILD.bazel b/dotnet/src/support/BUILD.bazel index b4251681c8452..19d8a5d896384 100644 --- a/dotnet/src/support/BUILD.bazel +++ b/dotnet/src/support/BUILD.bazel @@ -3,6 +3,7 @@ load( "csharp_library", "generated_assembly_info", "nuget_pack", + "nuget_package", ) load( "//dotnet:selenium-dotnet-version.bzl", @@ -44,6 +45,8 @@ csharp_library( ], deps = [ "//dotnet/src/webdriver:webdriver-netstandard2.0", + nuget_package("Microsoft.Bcl.AsyncInterfaces"), + nuget_package("System.Threading.Tasks.Extensions"), ], ) @@ -84,6 +87,8 @@ csharp_library( ], deps = [ "//dotnet/src/webdriver:webdriver-netstandard2.0-strongnamed", + nuget_package("Microsoft.Bcl.AsyncInterfaces"), + nuget_package("System.Threading.Tasks.Extensions"), ], ) diff --git a/dotnet/src/support/Events/EventFiringWebDriver.cs b/dotnet/src/support/Events/EventFiringWebDriver.cs index 70b69589f7500..9933064e8871f 100644 --- a/dotnet/src/support/Events/EventFiringWebDriver.cs +++ b/dotnet/src/support/Events/EventFiringWebDriver.cs @@ -405,6 +405,38 @@ public void Dispose() GC.SuppressFinalize(this); } + /// + /// Asynchronously disposes this instance. + /// + /// A task representing the asynchronous dispose operation. + public async ValueTask DisposeAsync() + { + await this.DisposeAsyncCore().ConfigureAwait(false); + this.Dispose(false); + GC.SuppressFinalize(this); + } + + /// + /// Stops the client from running. + /// + /// If , managed resources are disposed. + protected virtual void Dispose(bool disposing) + { + if (disposing) + { + WrappedDriver.Dispose(); + } + } + + /// + /// Asynchronously performs the core dispose logic. + /// + /// A task representing the asynchronous dispose operation. + protected virtual async ValueTask DisposeAsyncCore() + { + await WrappedDriver.DisposeAsync().ConfigureAwait(false); + } + /// /// Executes JavaScript in the context of the currently selected frame or window. /// @@ -580,19 +612,6 @@ public Screenshot GetScreenshot() return screenshotDriver.GetScreenshot(); } - /// - /// Frees all managed and, optionally, unmanaged resources used by this instance. - /// - /// to dispose of only managed resources; - /// to dispose of managed and unmanaged resources. - protected virtual void Dispose(bool disposing) - { - if (disposing) - { - this.WrappedDriver.Dispose(); - } - } - /// /// Raises the event. /// diff --git a/dotnet/src/webdriver/Chromium/ChromiumDriver.cs b/dotnet/src/webdriver/Chromium/ChromiumDriver.cs index 8857f5229b254..71f864f771513 100644 --- a/dotnet/src/webdriver/Chromium/ChromiumDriver.cs +++ b/dotnet/src/webdriver/Chromium/ChromiumDriver.cs @@ -480,9 +480,9 @@ public void StopCasting(string deviceName) } /// - /// Stops the driver from running + /// Disposes of the resources used by the instance, including any active DevTools session. /// - /// if its in the process of disposing + /// Indicates whether the method is being called from a Dispose method (true) or from a finalizer (false). protected override void Dispose(bool disposing) { if (disposing) @@ -497,6 +497,21 @@ protected override void Dispose(bool disposing) base.Dispose(disposing); } + /// + /// Asynchronously disposes of the resources used by the instance, including any active DevTools session. + /// + /// A task representing the asynchronous dispose operation. + protected override async ValueTask DisposeAsyncCore() + { + if (this.devToolsSession != null) + { + this.devToolsSession.Dispose(); + this.devToolsSession = null; + } + + await base.DisposeAsyncCore().ConfigureAwait(false); + } + private static ICapabilities ConvertOptionsToCapabilities(ChromiumOptions options) { if (options == null) diff --git a/dotnet/src/webdriver/Firefox/FirefoxDriver.cs b/dotnet/src/webdriver/Firefox/FirefoxDriver.cs index 0573d2ed368e7..fd4d58c049eb5 100644 --- a/dotnet/src/webdriver/Firefox/FirefoxDriver.cs +++ b/dotnet/src/webdriver/Firefox/FirefoxDriver.cs @@ -414,17 +414,6 @@ protected virtual void PrepareEnvironment() // Does nothing, but provides a hook for subclasses to do "stuff" } - /// - /// Disposes of the FirefoxDriver and frees all resources. - /// - /// A value indicating whether the user initiated the - /// disposal of the object. Pass if the user is actively - /// disposing the object; otherwise . - protected override void Dispose(bool disposing) - { - base.Dispose(disposing); - } - private static ICapabilities ConvertOptionsToCapabilities(FirefoxOptions options) { if (options == null) diff --git a/dotnet/src/webdriver/IWebDriver.cs b/dotnet/src/webdriver/IWebDriver.cs index 7a92363acbaa3..9177ae81098a1 100644 --- a/dotnet/src/webdriver/IWebDriver.cs +++ b/dotnet/src/webdriver/IWebDriver.cs @@ -43,7 +43,7 @@ namespace OpenQA.Selenium; /// more fully featured browser when there is a requirement for one. /// /// -public interface IWebDriver : ISearchContext, IDisposable +public interface IWebDriver : ISearchContext, IDisposable, IAsyncDisposable { /// /// Gets or sets the URL the browser is currently displaying. diff --git a/dotnet/src/webdriver/WebDriver.cs b/dotnet/src/webdriver/WebDriver.cs index 78c76f08ebadb..0df8343313859 100644 --- a/dotnet/src/webdriver/WebDriver.cs +++ b/dotnet/src/webdriver/WebDriver.cs @@ -223,6 +223,17 @@ public void Dispose() GC.SuppressFinalize(this); } + /// + /// Asynchronously disposes the WebDriver Instance + /// + /// A task representing the asynchronous dispose operation. + public async ValueTask DisposeAsync() + { + await this.DisposeAsyncCore().ConfigureAwait(false); + this.Dispose(false); + GC.SuppressFinalize(this); + } + /// /// Executes JavaScript "asynchronously" in the context of the currently selected frame or window, /// executing the callback function specified as the last argument in the list of arguments. @@ -672,25 +683,60 @@ protected bool RegisterInternalDriverCommand(string commandName, [NotNullWhen(tr /// if its in the process of disposing protected virtual void Dispose(bool disposing) { - try + if (disposing) { if (this.SessionId is not null) { - this.Execute(DriverCommand.Quit, null); + try + { + + this.Execute(DriverCommand.Quit, null); + + } + catch (NotImplementedException) + { + } + catch (InvalidOperationException) + { + } + catch (WebDriverException) + { + } + finally + { + this.SessionId = null!; + } } + + this.CommandExecutor.Dispose(); } - catch (NotImplementedException) - { - } - catch (InvalidOperationException) - { - } - catch (WebDriverException) - { - } - finally + } + + /// + /// Asynchronously performs the core dispose logic. + /// + /// A task representing the asynchronous dispose operation. + protected virtual async ValueTask DisposeAsyncCore() + { + if (this.SessionId is not null) { - this.SessionId = null!; + try + { + await this.ExecuteAsync(DriverCommand.Quit, null).ConfigureAwait(false); + } + catch (NotImplementedException) + { + } + catch (InvalidOperationException) + { + } + catch (WebDriverException) + { + } + finally + { + this.SessionId = null!; + } } this.CommandExecutor.Dispose(); diff --git a/dotnet/src/webdriver/assets/nuget/README.md b/dotnet/src/webdriver/assets/nuget/README.md index e780243d3445b..e05cb38ab350b 100644 --- a/dotnet/src/webdriver/assets/nuget/README.md +++ b/dotnet/src/webdriver/assets/nuget/README.md @@ -6,7 +6,7 @@ Selenium is a set of different software tools each with a different approach to using OpenQA.Selenium.Chrome; using OpenQA.Selenium; -using var driver = new ChromeDriver(); +await using var driver = new ChromeDriver(); driver.Url = "https://www.google.com"; driver.FindElement(By.Name("q")).SendKeys("webdriver" + Keys.Return); diff --git a/dotnet/test/common/StubDriver.cs b/dotnet/test/common/StubDriver.cs index 9f14206c40196..64d6cfa821fa4 100644 --- a/dotnet/test/common/StubDriver.cs +++ b/dotnet/test/common/StubDriver.cs @@ -19,6 +19,7 @@ using System; using System.Collections.ObjectModel; +using System.Threading.Tasks; namespace OpenQA.Selenium; @@ -101,4 +102,13 @@ public void Dispose() } #endregion + + #region IAsyncDisposable Members + + public ValueTask DisposeAsync() + { + throw new NotImplementedException(); + } + + #endregion } diff --git a/dotnet/test/support/Events/BUILD.bazel b/dotnet/test/support/Events/BUILD.bazel index 060534a69a7bc..593314776f691 100644 --- a/dotnet/test/support/Events/BUILD.bazel +++ b/dotnet/test/support/Events/BUILD.bazel @@ -13,6 +13,7 @@ dotnet_nunit_test_suite( "//dotnet/src/support", "//dotnet/src/webdriver:webdriver-net8.0", "//dotnet/test/common:fixtures", + nuget_package("Microsoft.Bcl.AsyncInterfaces"), nuget_package("NUnit"), nuget_package("Moq"), ], diff --git a/dotnet/test/support/Selenium.WebDriver.Support.Tests.csproj b/dotnet/test/support/Selenium.WebDriver.Support.Tests.csproj index be6aaef57a009..036e08d4bfd77 100644 --- a/dotnet/test/support/Selenium.WebDriver.Support.Tests.csproj +++ b/dotnet/test/support/Selenium.WebDriver.Support.Tests.csproj @@ -6,6 +6,7 @@ + @@ -13,7 +14,6 @@ - From 8920500ef031972c9e9bbd46aaea7bebd54ca0b6 Mon Sep 17 00:00:00 2001 From: Viet Nguyen Duc Date: Fri, 27 Feb 2026 14:21:56 +0700 Subject: [PATCH 26/67] [grid] Router bypass WebSocket data path via transparent TCP tunnel (#17146) * [grid] Router bypass WebSocket data path via transparent TCP tunnel * [grid] Fix WebSocket proxy race on upstream close in ProxyNodeWebsockets --------- Signed-off-by: Viet Nguyen Duc --- common/repositories.bzl | 24 +- .../grid/TemplateGridServerCommand.java | 17 +- .../grid/node/ProxyNodeWebsockets.java | 83 ++- .../grid/router/httpd/RouterServer.java | 26 +- .../selenium/netty/server/NettyServer.java | 26 +- .../netty/server/SeleniumHttpInitializer.java | 17 + .../netty/server/TcpTunnelHandler.java | 72 ++ .../netty/server/TcpUpgradeTunnelHandler.java | 330 +++++++++ .../openqa/selenium/grid/router/BUILD.bazel | 4 + .../grid/router/TunnelWebsocketTest.java | 635 ++++++++++++++++++ 10 files changed, 1212 insertions(+), 22 deletions(-) create mode 100644 java/src/org/openqa/selenium/netty/server/TcpTunnelHandler.java create mode 100644 java/src/org/openqa/selenium/netty/server/TcpUpgradeTunnelHandler.java create mode 100644 java/test/org/openqa/selenium/grid/router/TunnelWebsocketTest.java diff --git a/common/repositories.bzl b/common/repositories.bzl index d0d8bc7901567..e7f65e87e49bd 100644 --- a/common/repositories.bzl +++ b/common/repositories.bzl @@ -50,8 +50,8 @@ js_library( http_archive( name = "linux_beta_firefox", - url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b1/linux-x86_64/en-US/firefox-149.0b1.tar.xz", - sha256 = "f36c7db981d3098145c55b6cac5f9665066b9bb37b68531cc9dce59b72726c49", + url = "https://ftp.mozilla.org/pub/firefox/releases/148.0b15/linux-x86_64/en-US/firefox-148.0b15.tar.xz", + sha256 = "23621cf9537fd8d52d3a4e83bba48984f705facd4c10819aa07ae2531e11e2e5", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -72,8 +72,8 @@ js_library( dmg_archive( name = "mac_beta_firefox", - url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b1/mac/en-US/Firefox%20149.0b1.dmg", - sha256 = "3d1f4abb063b47b392af528cf7501132fc405e2a9f1eccab71913b0fdf3e538c", + url = "https://ftp.mozilla.org/pub/firefox/releases/148.0b15/mac/en-US/Firefox%20148.0b15.dmg", + sha256 = "0da5fee250eb13165dda25f9e29d238e51f9e0d56b9355b9c8746f6bc8d5c1fe", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -277,8 +277,8 @@ js_library( http_archive( name = "linux_beta_chrome", - url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.31/linux64/chrome-linux64.zip", - sha256 = "5b1961b081f0156a1923a9d9d1bfffdf00f82e8722152c35eb5eb742d63ceeb8", + url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.16/linux64/chrome-linux64.zip", + sha256 = "6c3241cf5eab6b5eaed9b0b741bae799377dea26985aed08cda51fb75433218e", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -298,8 +298,8 @@ js_library( ) http_archive( name = "mac_beta_chrome", - url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.31/mac-arm64/chrome-mac-arm64.zip", - sha256 = "207867110edc624316b18684065df4eb06b938a3fd9141790a726ab280e2640f", + url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.16/mac-arm64/chrome-mac-arm64.zip", + sha256 = "b39fe2de33190da209845e5d21ea44c75d66d0f4c33c5a293d8b6a259d3c4029", strip_prefix = "chrome-mac-arm64", patch_cmds = [ "mv 'Google Chrome for Testing.app' Chrome.app", @@ -319,8 +319,8 @@ js_library( ) http_archive( name = "linux_beta_chromedriver", - url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.31/linux64/chromedriver-linux64.zip", - sha256 = "a8c7be8669829ed697759390c8c42b4bca3f884fd20980e078129f5282dabe1a", + url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.16/linux64/chromedriver-linux64.zip", + sha256 = "c6927758a816f0a2f5f10609b34f74080a8c0f08feaf177a68943d8d4aae3a72", strip_prefix = "chromedriver-linux64", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") @@ -337,8 +337,8 @@ js_library( http_archive( name = "mac_beta_chromedriver", - url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.31/mac-arm64/chromedriver-mac-arm64.zip", - sha256 = "84c3717c0eeba663d0b8890a0fc06faa6fe158227876fc6954461730ccc81634", + url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.16/mac-arm64/chromedriver-mac-arm64.zip", + sha256 = "29c44a53be87fccea4a7887a7ed2b45b5812839e357e091c6a784ee17bb8da78", strip_prefix = "chromedriver-mac-arm64", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") diff --git a/java/src/org/openqa/selenium/grid/TemplateGridServerCommand.java b/java/src/org/openqa/selenium/grid/TemplateGridServerCommand.java index 8924095ff0aad..0c8e82b4b692a 100644 --- a/java/src/org/openqa/selenium/grid/TemplateGridServerCommand.java +++ b/java/src/org/openqa/selenium/grid/TemplateGridServerCommand.java @@ -18,6 +18,7 @@ package org.openqa.selenium.grid; import java.io.Closeable; +import java.net.URI; import java.util.Collections; import java.util.List; import java.util.Optional; @@ -47,7 +48,10 @@ public Server asServer(Config initialConfig) { Handlers handler = createHandlers(config); return new NettyServer( - new BaseServerOptions(config), handler.httpHandler, handler.websocketHandler) { + new BaseServerOptions(config), + handler.httpHandler, + handler.websocketHandler, + handler.tcpTunnelResolver) { @Override public void stop() { @@ -92,12 +96,23 @@ public abstract static class Handlers implements Closeable { public final BiFunction, Optional>> websocketHandler; + /** Optional resolver for direct TCP tunnel of WebSocket connections. May be null. */ + public final Function> tcpTunnelResolver; + public Handlers( HttpHandler http, BiFunction, Optional>> websocketHandler) { + this(http, websocketHandler, null); + } + + public Handlers( + HttpHandler http, + BiFunction, Optional>> websocketHandler, + Function> tcpTunnelResolver) { this.httpHandler = Require.nonNull("HTTP handler", http); this.websocketHandler = websocketHandler == null ? (str, sink) -> Optional.empty() : websocketHandler; + this.tcpTunnelResolver = tcpTunnelResolver; } @Override diff --git a/java/src/org/openqa/selenium/grid/node/ProxyNodeWebsockets.java b/java/src/org/openqa/selenium/grid/node/ProxyNodeWebsockets.java index 998811a3202b4..a2569e67e1ce2 100644 --- a/java/src/org/openqa/selenium/grid/node/ProxyNodeWebsockets.java +++ b/java/src/org/openqa/selenium/grid/node/ProxyNodeWebsockets.java @@ -251,6 +251,9 @@ private Consumer createWsEndPoint( LOG.info("Establishing connection to " + uri); AtomicBoolean connectionReleased = new AtomicBoolean(false); + // Set to true as soon as the browser signals it is closing so the send lambda can stop + // forwarding data frames without racing against the JDK WebSocket output stream being closed. + AtomicBoolean upstreamClosing = new AtomicBoolean(false); HttpClient client = clientFactory.createClient(ClientConfig.defaultConfig().baseUri(uri)); try { @@ -258,22 +261,66 @@ private Consumer createWsEndPoint( client.openSocket( new HttpRequest(GET, uri.toString()), new ForwardingListener( - node, downstream, sessionConsumer, sessionId, connectionReleased)); + node, + downstream, + sessionConsumer, + sessionId, + connectionReleased, + client, + upstreamClosing)); return (msg) -> { - try { - upstream.send(msg); - } finally { + // Fast path: once the browser has signalled close, there is no point sending further + // data frames — the JDK WebSocket output is already closing and the send would either + // be dropped or throw "Output closed". For the CloseMessage echo we skip the actual + // network write (the JDK stack handles the protocol-level echo internally when it fires + // onClose) and go straight to resource cleanup. + if (upstreamClosing.get()) { if (msg instanceof CloseMessage) { if (connectionReleased.compareAndSet(false, true)) { node.releaseConnection(sessionId); + try { + client.close(); + } catch (Exception e) { + LOG.log(Level.FINE, "Failed to close client after upstream close for " + uri, e); + } } + } else { + LOG.log(Level.FINE, "Dropping in-flight data frame for closing session " + sessionId); + } + return; + } + + // Slow path: upstream is (was) open — attempt the send and catch the narrow race where + // the browser closes between the upstreamClosing check above and the actual write. + try { + upstream.send(msg); + } catch (Exception e) { + LOG.log( + Level.FINE, + "Could not forward message to browser WebSocket for session " + + sessionId + + " (connection likely closed concurrently)", + e); + if (connectionReleased.compareAndSet(false, true)) { + node.releaseConnection(sessionId); try { client.close(); - } catch (Exception e) { - LOG.log(Level.WARNING, "Failed to shutdown the client of " + uri, e); + } catch (Exception ce) { + LOG.log(Level.FINE, "Failed to close client after send error for " + uri, ce); } } + return; + } + if (msg instanceof CloseMessage) { + if (connectionReleased.compareAndSet(false, true)) { + node.releaseConnection(sessionId); + } + try { + client.close(); + } catch (Exception e) { + LOG.log(Level.WARNING, "Failed to shutdown the client of " + uri, e); + } } }; } catch (Exception e) { @@ -289,18 +336,24 @@ private static class ForwardingListener implements WebSocket.Listener { private final Consumer sessionConsumer; private final SessionId sessionId; private final AtomicBoolean connectionReleased; + private final HttpClient client; + private final AtomicBoolean upstreamClosing; public ForwardingListener( Node node, Consumer downstream, Consumer sessionConsumer, SessionId sessionId, - AtomicBoolean connectionReleased) { + AtomicBoolean connectionReleased, + HttpClient client, + AtomicBoolean upstreamClosing) { this.node = node; this.downstream = Objects.requireNonNull(downstream); this.sessionConsumer = Objects.requireNonNull(sessionConsumer); this.sessionId = Objects.requireNonNull(sessionId); this.connectionReleased = Objects.requireNonNull(connectionReleased); + this.client = Objects.requireNonNull(client); + this.upstreamClosing = Objects.requireNonNull(upstreamClosing); } @Override @@ -311,9 +364,19 @@ public void onBinary(byte[] data) { @Override public void onClose(int code, String reason) { + // Signal the send lambda before forwarding the close downstream so that any data frames + // still queued in the Netty pipeline are discarded rather than attempted on a closing stream. + upstreamClosing.set(true); downstream.accept(new CloseMessage(code, reason)); if (connectionReleased.compareAndSet(false, true)) { node.releaseConnection(sessionId); + // Close the HttpClient eagerly so the connection slot is freed even if the client-side + // Close echo never arrives (e.g. the client dropped the TCP connection). + try { + client.close(); + } catch (Exception e) { + LOG.log(Level.FINE, "Failed to close client on upstream WebSocket close", e); + } } } @@ -325,9 +388,15 @@ public void onText(CharSequence data) { @Override public void onError(Throwable cause) { + upstreamClosing.set(true); LOG.log(Level.WARNING, "Error proxying websocket command", cause); if (connectionReleased.compareAndSet(false, true)) { node.releaseConnection(sessionId); + try { + client.close(); + } catch (Exception e) { + LOG.log(Level.FINE, "Failed to close client after WebSocket error", e); + } } } } diff --git a/java/src/org/openqa/selenium/grid/router/httpd/RouterServer.java b/java/src/org/openqa/selenium/grid/router/httpd/RouterServer.java index 6f78c69340a62..ba27c6aa76402 100644 --- a/java/src/org/openqa/selenium/grid/router/httpd/RouterServer.java +++ b/java/src/org/openqa/selenium/grid/router/httpd/RouterServer.java @@ -32,15 +32,19 @@ import java.io.Closeable; import java.io.IOException; import java.io.UncheckedIOException; +import java.net.URI; import java.net.URL; import java.time.Duration; import java.util.Collections; import java.util.Map; +import java.util.Optional; import java.util.Set; +import java.util.function.Function; import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Stream; import org.openqa.selenium.BuildInfo; +import org.openqa.selenium.NoSuchSessionException; import org.openqa.selenium.UsernameAndPassword; import org.openqa.selenium.cli.CliCommand; import org.openqa.selenium.grid.TemplateGridServerCommand; @@ -67,6 +71,8 @@ import org.openqa.selenium.grid.sessionqueue.remote.RemoteNewSessionQueue; import org.openqa.selenium.grid.web.GridUiRoute; import org.openqa.selenium.internal.Require; +import org.openqa.selenium.remote.HttpSessionId; +import org.openqa.selenium.remote.SessionId; import org.openqa.selenium.remote.http.ClientConfig; import org.openqa.selenium.remote.http.Contents; import org.openqa.selenium.remote.http.HttpClient; @@ -183,7 +189,25 @@ protected Handlers createHandlers(Config config) { // access to it. Routable routeWithLiveness = Route.combine(route, get("/readyz").to(() -> readinessCheck)); - return new Handlers(routeWithLiveness, new ProxyWebsocketsIntoGrid(clientFactory, sessions)) { + // Resolve a request URI to the Node URI for direct TCP tunnelling of WebSocket connections. + // Falls back to ProxyWebsocketsIntoGrid (the websocketHandler) when the session is not found. + Function> tcpTunnelResolver = + uri -> + HttpSessionId.getSessionId(uri) + .map(SessionId::new) + .flatMap( + id -> { + try { + return Optional.of(sessions.getUri(id)); + } catch (NoSuchSessionException e) { + return Optional.empty(); + } + }); + + return new Handlers( + routeWithLiveness, + new ProxyWebsocketsIntoGrid(clientFactory, sessions), + tcpTunnelResolver) { @Override public void close() { router.close(); diff --git a/java/src/org/openqa/selenium/netty/server/NettyServer.java b/java/src/org/openqa/selenium/netty/server/NettyServer.java index 4423129736d12..2331eecfb1883 100644 --- a/java/src/org/openqa/selenium/netty/server/NettyServer.java +++ b/java/src/org/openqa/selenium/netty/server/NettyServer.java @@ -36,11 +36,13 @@ import java.net.BindException; import java.net.InetSocketAddress; import java.net.MalformedURLException; +import java.net.URI; import java.net.URL; import java.security.cert.CertificateException; import java.util.Optional; import java.util.function.BiFunction; import java.util.function.Consumer; +import java.util.function.Function; import javax.net.ssl.SSLException; import org.openqa.selenium.grid.server.BaseServerOptions; import org.openqa.selenium.grid.server.Server; @@ -62,6 +64,7 @@ public class NettyServer implements Server { private final BiFunction, Optional>> websocketHandler; private final SslContext sslCtx; private final boolean allowCors; + private final Function> tcpTunnelResolver; private Channel channel; @@ -73,9 +76,28 @@ public NettyServer( BaseServerOptions options, HttpHandler handler, BiFunction, Optional>> websocketHandler) { + this(options, handler, websocketHandler, null); + } + + /** + * Creates a {@link NettyServer} with an optional TCP-level tunnel resolver for WebSocket + * connections. When {@code tcpTunnelResolver} is non-null, WebSocket upgrade requests that + * contain a Selenium session ID are intercepted before the normal WebSocket handler: the Router + * opens a raw TCP connection to the resolved Node URI and bridges the two sockets directly, + * removing itself from the WebSocket data path entirely. + * + * @param tcpTunnelResolver maps a request URI to the target Node URI. Return {@link + * Optional#empty()} to fall through to the normal WebSocket handler. + */ + public NettyServer( + BaseServerOptions options, + HttpHandler handler, + BiFunction, Optional>> websocketHandler, + Function> tcpTunnelResolver) { Require.nonNull("Server options", options); Require.nonNull("Handler", handler); this.websocketHandler = Require.nonNull("Factory for websocket connections", websocketHandler); + this.tcpTunnelResolver = tcpTunnelResolver; InternalLoggerFactory.setDefaultFactory(JdkLoggerFactory.INSTANCE); @@ -155,7 +177,9 @@ public NettyServer start() { b.group(bossGroup, workerGroup) .channel(NioServerSocketChannel.class) .handler(new LoggingHandler(LogLevel.DEBUG)) - .childHandler(new SeleniumHttpInitializer(sslCtx, handler, websocketHandler, allowCors)); + .childHandler( + new SeleniumHttpInitializer( + sslCtx, handler, websocketHandler, allowCors, tcpTunnelResolver)); try { // Using a flag to avoid binding to the host, useful in environments like Docker, diff --git a/java/src/org/openqa/selenium/netty/server/SeleniumHttpInitializer.java b/java/src/org/openqa/selenium/netty/server/SeleniumHttpInitializer.java index 532b87e9f950e..19326a28ff104 100644 --- a/java/src/org/openqa/selenium/netty/server/SeleniumHttpInitializer.java +++ b/java/src/org/openqa/selenium/netty/server/SeleniumHttpInitializer.java @@ -25,9 +25,11 @@ import io.netty.handler.ssl.SslContext; import io.netty.handler.stream.ChunkedWriteHandler; import io.netty.util.AttributeKey; +import java.net.URI; import java.util.Optional; import java.util.function.BiFunction; import java.util.function.Consumer; +import java.util.function.Function; import org.openqa.selenium.internal.Require; import org.openqa.selenium.remote.http.HttpHandler; import org.openqa.selenium.remote.http.Message; @@ -40,16 +42,27 @@ class SeleniumHttpInitializer extends ChannelInitializer { private final BiFunction, Optional>> webSocketHandler; private final SslContext sslCtx; private final boolean allowCors; + private final Function> tcpTunnelResolver; SeleniumHttpInitializer( SslContext sslCtx, HttpHandler seleniumHandler, BiFunction, Optional>> webSocketHandler, boolean allowCors) { + this(sslCtx, seleniumHandler, webSocketHandler, allowCors, null); + } + + SeleniumHttpInitializer( + SslContext sslCtx, + HttpHandler seleniumHandler, + BiFunction, Optional>> webSocketHandler, + boolean allowCors, + Function> tcpTunnelResolver) { this.sslCtx = sslCtx; this.seleniumHandler = Require.nonNull("HTTP handler", seleniumHandler); this.webSocketHandler = Require.nonNull("WebSocket handler", webSocketHandler); this.allowCors = allowCors; + this.tcpTunnelResolver = tcpTunnelResolver; } @Override @@ -63,6 +76,10 @@ protected void initChannel(SocketChannel ch) { // Websocket magic ch.pipeline().addLast("ws-compression", new WebSocketServerCompressionHandler()); + // TCP tunnel intercepts WS upgrades before the normal WS handler when configured. + if (tcpTunnelResolver != null) { + ch.pipeline().addLast("tcp-tunnel", new TcpUpgradeTunnelHandler(tcpTunnelResolver)); + } ch.pipeline().addLast("ws-protocol", new WebSocketUpgradeHandler(KEY, webSocketHandler)); ch.pipeline().addLast("netty-to-se-messages", new MessageInboundConverter()); ch.pipeline().addLast("se-to-netty-messages", new MessageOutboundConverter()); diff --git a/java/src/org/openqa/selenium/netty/server/TcpTunnelHandler.java b/java/src/org/openqa/selenium/netty/server/TcpTunnelHandler.java new file mode 100644 index 0000000000000..b87b94fd9e95f --- /dev/null +++ b/java/src/org/openqa/selenium/netty/server/TcpTunnelHandler.java @@ -0,0 +1,72 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.openqa.selenium.netty.server; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Forwards every inbound {@link io.netty.buffer.ByteBuf} to a target {@link Channel}. Used on both + * ends of a transparent TCP tunnel once the WebSocket upgrade handshake has been proxied. + */ +class TcpTunnelHandler extends ChannelInboundHandlerAdapter { + + private static final Logger LOG = Logger.getLogger(TcpTunnelHandler.class.getName()); + + private final Channel target; + + TcpTunnelHandler(Channel target) { + this.target = target; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + target + .writeAndFlush(msg) + .addListener( + future -> { + if (!future.isSuccess()) { + LOG.log( + Level.WARNING, + "TCP tunnel write failed on " + + ctx.channel() + + " -> " + + target + + ", closing both channels", + future.cause()); + ctx.close(); + target.close(); + } + }); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + target.close(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + LOG.log(Level.WARNING, "TCP tunnel error, closing both channels", cause); + ctx.close(); + target.close(); + } +} diff --git a/java/src/org/openqa/selenium/netty/server/TcpUpgradeTunnelHandler.java b/java/src/org/openqa/selenium/netty/server/TcpUpgradeTunnelHandler.java new file mode 100644 index 0000000000000..985cdf247e6fe --- /dev/null +++ b/java/src/org/openqa/selenium/netty/server/TcpUpgradeTunnelHandler.java @@ -0,0 +1,330 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.openqa.selenium.netty.server; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpObject; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.util.ReferenceCountUtil; +import java.net.URI; +import java.util.Optional; +import java.util.function.Function; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.net.ssl.SSLException; + +/** + * Netty handler placed in the server pipeline before {@link WebSocketUpgradeHandler}. When it sees + * an HTTP WebSocket upgrade request that carries a Selenium session ID, it resolves the Node URI + * and establishes a transparent TCP tunnel, removing the Router from the data path entirely. + * + *

If no Node URI is found for the session (or the request is not a WS upgrade), the request is + * passed to the next handler in the pipeline (falling through to {@link WebSocketUpgradeHandler}). + * + *

If the Node URI uses {@code https}, an SSL handler is added to the node-side channel so that + * the Router transparently terminates TLS with the client and re-establishes it with the Node. + * + *

If the TCP connect to the Node fails (e.g. the Node is unreachable in a Kubernetes + * port-forward topology), the original upgrade request is fired back through the pipeline so the + * normal {@link WebSocketUpgradeHandler} / {@code ProxyWebsocketsIntoGrid} path can handle it. + */ +class TcpUpgradeTunnelHandler extends ChannelInboundHandlerAdapter { + + private static final Logger LOG = Logger.getLogger(TcpUpgradeTunnelHandler.class.getName()); + + /** + * Lazily-initialised, process-wide SSL context used when connecting to HTTPS nodes. All node + * certificates are trusted because Grid nodes commonly use self-signed certificates for internal + * cluster communication. The external client↔Router TLS boundary is separate and unaffected. + */ + private static volatile SslContext clientSslContext; + + private final Function> nodeUriResolver; + + /** + * @param nodeUriResolver maps an HTTP request URI (e.g. {@code /session//bidi}) to the Node + * URI. Return {@link Optional#empty()} to fall through to the normal WS handler. + */ + TcpUpgradeTunnelHandler(Function> nodeUriResolver) { + this.nodeUriResolver = nodeUriResolver; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (!(msg instanceof HttpRequest)) { + ctx.fireChannelRead(msg); + return; + } + + HttpRequest req = (HttpRequest) msg; + + if (!isWebSocketUpgrade(req)) { + ctx.fireChannelRead(req); + return; + } + + String uri = req.uri(); + Optional maybeNodeUri = nodeUriResolver.apply(uri); + + if (maybeNodeUri.isEmpty()) { + ctx.fireChannelRead(req); + return; + } + + URI nodeUri = maybeNodeUri.get(); + Channel clientChannel = ctx.channel(); + + // Pause client reads while connecting so we don't lose or mis-process data. + clientChannel.config().setAutoRead(false); + + boolean useTls = "https".equalsIgnoreCase(nodeUri.getScheme()); + int port = nodeUri.getPort() != -1 ? nodeUri.getPort() : (useTls ? 443 : 80); + String host = nodeUri.getHost(); + + SslContext nodeSslCtx = null; + if (useTls) { + try { + nodeSslCtx = buildClientSslContext(); + } catch (SSLException e) { + LOG.log( + Level.WARNING, + "Failed to build SSL context for HTTPS node at " + + host + + ":" + + port + + ", falling back to WebSocket handler", + e); + clientChannel.config().setAutoRead(true); + ctx.fireChannelRead(req); + return; + } + } + final SslContext finalNodeSslCtx = nodeSslCtx; + + Bootstrap bootstrap = + new Bootstrap() + .group(clientChannel.eventLoop()) + .channel(NioSocketChannel.class) + .handler( + new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) { + if (finalNodeSslCtx != null) { + // SSL handler must be first so the codec operates on plaintext. + ch.pipeline() + .addLast("ssl", finalNodeSslCtx.newHandler(ch.alloc(), host, port)); + } + ch.pipeline().addLast("http-codec", new HttpClientCodec()); + ch.pipeline() + .addLast( + "upgrade-handler", new NodeUpgradeResponseHandler(clientChannel, req)); + } + }); + + ChannelFuture connectFuture = bootstrap.connect(host, port); + connectFuture.addListener( + future -> { + if (!future.isSuccess()) { + // The Node is unreachable (wrong network, K8s port-forward topology, etc.). + // Re-enable reads and pass the request to the next handler so that + // ProxyWebsocketsIntoGrid can try to handle it via its own HTTP client. + LOG.log( + Level.WARNING, + "TCP tunnel connect failed for " + + host + + ":" + + port + + ", falling back to WebSocket handler", + future.cause()); + clientChannel.config().setAutoRead(true); + ctx.fireChannelRead(req); + } + // On success, NodeUpgradeResponseHandler.channelActive sends the request. + }); + } + + private static boolean isWebSocketUpgrade(HttpRequest req) { + return req.headers().containsValue(HttpHeaderNames.CONNECTION, "Upgrade", true) + && req.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_VERSION); + } + + private static SslContext buildClientSslContext() throws SSLException { + if (clientSslContext == null) { + synchronized (TcpUpgradeTunnelHandler.class) { + if (clientSslContext == null) { + // InsecureTrustManagerFactory is appropriate here: Grid nodes commonly use self-signed + // certificates for intra-cluster communication, and the trust boundary that matters to + // end users is the client↔Router TLS connection, not this Router↔Node hop. + clientSslContext = + SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .build(); + } + } + } + return clientSslContext; + } + + // --------------------------------------------------------------------------- + // Inner handler attached to the node-side channel + // --------------------------------------------------------------------------- + + private static final class NodeUpgradeResponseHandler extends ChannelInboundHandlerAdapter { + + private final Channel clientChannel; + private final HttpRequest upgradeRequest; + private boolean tunnelEstablished = false; + + NodeUpgradeResponseHandler(Channel clientChannel, HttpRequest upgradeRequest) { + this.clientChannel = clientChannel; + this.upgradeRequest = upgradeRequest; + } + + @Override + public void channelActive(ChannelHandlerContext ctx) { + // Forward the original upgrade request to the Node. + DefaultHttpRequest nodeReq = + new DefaultHttpRequest( + upgradeRequest.protocolVersion(), upgradeRequest.method(), upgradeRequest.uri()); + nodeReq.headers().set(upgradeRequest.headers()); + ctx.writeAndFlush(nodeReq); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + try { + if (tunnelEstablished || !(msg instanceof HttpObject)) { + // Tunnel is live or not HTTP; any stale buffered data is discarded. + return; + } + + if (!(msg instanceof HttpResponse)) { + // LastHttpContent or other codec artefact before the 101 — skip. + return; + } + + HttpResponse resp = (HttpResponse) msg; + + if (resp.status().code() != 101) { + LOG.warning("Node rejected WebSocket upgrade: " + resp.status()); + ctx.close(); + clientChannel.close(); + return; + } + + tunnelEstablished = true; + Channel nodeChannel = ctx.channel(); + + // Build a proper Netty HTTP 101 response, copying all headers from the Node's response. + // Writing a DefaultFullHttpResponse goes through HttpResponseEncoder, which correctly + // encodes it, and the HttpServerKeepAliveHandler does not close the channel for 101. + DefaultFullHttpResponse clientResponse = + new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, + HttpResponseStatus.SWITCHING_PROTOCOLS, + Unpooled.EMPTY_BUFFER); + clientResponse.headers().set(resp.headers()); + + clientChannel + .writeAndFlush(clientResponse) + .addListener( + writeFuture -> { + if (!writeFuture.isSuccess()) { + LOG.log( + Level.WARNING, + "Failed to write 101 response to client", + writeFuture.cause()); + clientChannel.close(); + nodeChannel.close(); + return; + } + + // Rewire node channel: remove HTTP codec and this handler, add byte tunnel. + // The "ssl" handler (if present) is intentionally left in place — it + // transparently handles TLS framing for the raw byte stream. + nodeChannel.pipeline().remove("upgrade-handler"); + nodeChannel.pipeline().remove("http-codec"); + nodeChannel.pipeline().addLast("tunnel", new TcpTunnelHandler(clientChannel)); + + // Rewire client channel: replace the tcp-tunnel intercept handler with a raw + // byte tunnel, then strip remaining HTTP/WS handlers that are no longer needed. + ChannelPipeline cp = clientChannel.pipeline(); + cp.replace("tcp-tunnel", "tunnel", new TcpTunnelHandler(nodeChannel)); + for (String name : + new String[] { + "codec", + "keep-alive", + "chunked-write", + "ws-compression", + "ws-protocol", + "netty-to-se-messages", + "se-to-netty-messages", + "se-websocket-handler", + "se-request", + "se-response", + "se-handler" + }) { + if (cp.get(name) != null) { + cp.remove(name); + } + } + + // Re-enable reads on the client now that the tunnel is live. + clientChannel.config().setAutoRead(true); + }); + + } finally { + ReferenceCountUtil.release(msg); + } + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + if (!tunnelEstablished) { + LOG.warning("Node channel closed before tunnel was established"); + clientChannel.close(); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + LOG.log(Level.WARNING, "Error during node upgrade handshake", cause); + ctx.close(); + clientChannel.close(); + } + } +} diff --git a/java/test/org/openqa/selenium/grid/router/BUILD.bazel b/java/test/org/openqa/selenium/grid/router/BUILD.bazel index 4b8096bb9bd24..d8fea7a2f56c4 100644 --- a/java/test/org/openqa/selenium/grid/router/BUILD.bazel +++ b/java/test/org/openqa/selenium/grid/router/BUILD.bazel @@ -139,6 +139,10 @@ java_test_suite( "//java/src/org/openqa/selenium/firefox", "//java/src/org/openqa/selenium/grid", "//java/src/org/openqa/selenium/grid/config", + "//java/src/org/openqa/selenium/grid/distributor", + "//java/src/org/openqa/selenium/grid/distributor/local", + "//java/src/org/openqa/selenium/grid/distributor/selector", + "//java/src/org/openqa/selenium/grid/node/local", "//java/src/org/openqa/selenium/json", "//java/src/org/openqa/selenium/remote", "//java/src/org/openqa/selenium/support", diff --git a/java/test/org/openqa/selenium/grid/router/TunnelWebsocketTest.java b/java/test/org/openqa/selenium/grid/router/TunnelWebsocketTest.java new file mode 100644 index 0000000000000..cc09ef5377027 --- /dev/null +++ b/java/test/org/openqa/selenium/grid/router/TunnelWebsocketTest.java @@ -0,0 +1,635 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.openqa.selenium.grid.router; + +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.assertj.core.api.Assertions.assertThat; +import static org.openqa.selenium.remote.Dialect.W3C; +import static org.openqa.selenium.remote.http.HttpMethod.GET; + +import java.net.ServerSocket; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URL; +import java.time.Duration; +import java.time.Instant; +import java.util.Collections; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.openqa.selenium.ImmutableCapabilities; +import org.openqa.selenium.MutableCapabilities; +import org.openqa.selenium.NoSuchSessionException; +import org.openqa.selenium.SessionNotCreatedException; +import org.openqa.selenium.events.EventBus; +import org.openqa.selenium.events.local.GuavaEventBus; +import org.openqa.selenium.grid.config.MapConfig; +import org.openqa.selenium.grid.data.CreateSessionResponse; +import org.openqa.selenium.grid.data.DefaultSlotMatcher; +import org.openqa.selenium.grid.data.RequestId; +import org.openqa.selenium.grid.data.Session; +import org.openqa.selenium.grid.data.SessionRequest; +import org.openqa.selenium.grid.distributor.local.LocalDistributor; +import org.openqa.selenium.grid.distributor.selector.DefaultSlotSelector; +import org.openqa.selenium.grid.node.local.LocalNode; +import org.openqa.selenium.grid.security.Secret; +import org.openqa.selenium.grid.server.BaseServerOptions; +import org.openqa.selenium.grid.server.Server; +import org.openqa.selenium.grid.sessionmap.SessionMap; +import org.openqa.selenium.grid.sessionmap.local.LocalSessionMap; +import org.openqa.selenium.grid.sessionqueue.local.LocalNewSessionQueue; +import org.openqa.selenium.grid.testing.PassthroughHttpClient; +import org.openqa.selenium.grid.testing.TestSessionFactory; +import org.openqa.selenium.internal.Either; +import org.openqa.selenium.netty.server.NettyServer; +import org.openqa.selenium.remote.HttpSessionId; +import org.openqa.selenium.remote.SessionId; +import org.openqa.selenium.remote.http.BinaryMessage; +import org.openqa.selenium.remote.http.HttpClient; +import org.openqa.selenium.remote.http.HttpHandler; +import org.openqa.selenium.remote.http.HttpRequest; +import org.openqa.selenium.remote.http.HttpResponse; +import org.openqa.selenium.remote.http.TextMessage; +import org.openqa.selenium.remote.http.WebSocket; +import org.openqa.selenium.remote.tracing.DefaultTestTracer; +import org.openqa.selenium.remote.tracing.Tracer; +import org.openqa.selenium.support.ui.FluentWait; + +class TunnelWebsocketTest { + + private final HttpHandler nullHandler = req -> new HttpResponse(); + private final MapConfig emptyConfig = new MapConfig(Collections.emptyMap()); + + private Server tunnelServer; + private Server backendServer; + private SessionMap sessions; + + @BeforeEach + void setUp() { + Tracer tracer = DefaultTestTracer.createTracer(); + EventBus events = new GuavaEventBus(); + sessions = new LocalSessionMap(tracer, events); + } + + @AfterEach + void tearDown() { + if (tunnelServer != null) { + tunnelServer.stop(); + } + if (backendServer != null) { + backendServer.stop(); + } + } + + private Function> createResolver() { + return uri -> + HttpSessionId.getSessionId(uri) + .map(SessionId::new) + .flatMap( + id -> { + try { + return Optional.of(sessions.getUri(id)); + } catch (NoSuchSessionException e) { + return Optional.empty(); + } + }); + } + + private Server createEchoBackend( + String response, CountDownLatch receivedLatch, AtomicReference received) { + return new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> + Optional.of( + msg -> { + if (msg instanceof TextMessage) { + received.set(((TextMessage) msg).text()); + receivedLatch.countDown(); + if (!response.isEmpty()) { + sink.accept(new TextMessage(response)); + } + } + })) + .start(); + } + + @Test + void shouldForwardTextMessageToBackend() throws URISyntaxException, InterruptedException { + AtomicReference received = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + + backendServer = createEchoBackend("", latch, received); + + SessionId id = new SessionId(UUID.randomUUID()); + sessions.add( + new Session( + id, + backendServer.getUrl().toURI(), + new ImmutableCapabilities(), + new ImmutableCapabilities(), + Instant.now())); + + tunnelServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> Optional.empty(), + createResolver()) + .start(); + + HttpClient.Factory factory = HttpClient.Factory.createDefault(); + try (WebSocket socket = + factory + .createClient(tunnelServer.getUrl()) + .openSocket( + new HttpRequest(GET, "/session/" + id + "/bidi"), new WebSocket.Listener() {})) { + + socket.sendText("Hello tunnel"); + + assertThat(latch.await(5, SECONDS)).isTrue(); + assertThat(received.get()).isEqualTo("Hello tunnel"); + } + } + + @Test + void shouldForwardTextMessageFromBackendToClient() + throws URISyntaxException, InterruptedException { + backendServer = createEchoBackend("pong", new CountDownLatch(1), new AtomicReference<>()); + + SessionId id = new SessionId(UUID.randomUUID()); + sessions.add( + new Session( + id, + backendServer.getUrl().toURI(), + new ImmutableCapabilities(), + new ImmutableCapabilities(), + Instant.now())); + + tunnelServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> Optional.empty(), + createResolver()) + .start(); + + HttpClient.Factory factory = HttpClient.Factory.createDefault(); + CountDownLatch latch = new CountDownLatch(1); + AtomicReference reply = new AtomicReference<>(); + + try (WebSocket socket = + factory + .createClient(tunnelServer.getUrl()) + .openSocket( + new HttpRequest(GET, "/session/" + id + "/bidi"), + new WebSocket.Listener() { + @Override + public void onText(CharSequence data) { + reply.set(data.toString()); + latch.countDown(); + } + })) { + + socket.sendText("ping"); + + assertThat(latch.await(5, SECONDS)).isTrue(); + assertThat(reply.get()).isEqualTo("pong"); + } + } + + @Test + void shouldForwardBinaryMessages() throws URISyntaxException, InterruptedException { + byte[] payload = new byte[] {1, 2, 3, 4}; + + AtomicReference received = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + + backendServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> + Optional.of( + msg -> { + if (msg instanceof BinaryMessage) { + received.set(((BinaryMessage) msg).data()); + latch.countDown(); + } + })) + .start(); + + SessionId id = new SessionId(UUID.randomUUID()); + sessions.add( + new Session( + id, + backendServer.getUrl().toURI(), + new ImmutableCapabilities(), + new ImmutableCapabilities(), + Instant.now())); + + tunnelServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> Optional.empty(), + createResolver()) + .start(); + + HttpClient.Factory factory = HttpClient.Factory.createDefault(); + try (WebSocket socket = + factory + .createClient(tunnelServer.getUrl()) + .openSocket( + new HttpRequest(GET, "/session/" + id + "/bidi"), new WebSocket.Listener() {})) { + + socket.sendBinary(payload); + + assertThat(latch.await(5, SECONDS)).isTrue(); + assertThat(received.get()).isEqualTo(payload); + } + } + + @Test + void shouldFallBackToWebSocketHandlerWhenSessionNotFound() { + // No session in the map — tunnel resolver returns empty, falling through to the WS handler + // which also returns empty. WebSocketUpgradeHandler responds with 400 Bad Request. + tunnelServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> Optional.empty(), + createResolver()) + .start(); + + HttpClient.Factory factory = HttpClient.Factory.createDefault(); + SessionId unknownId = new SessionId(UUID.randomUUID()); + + boolean exceptionThrown = false; + try { + factory + .createClient(tunnelServer.getUrl()) + .openSocket( + new HttpRequest(GET, "/session/" + unknownId + "/bidi"), new WebSocket.Listener() {}); + } catch (Exception e) { + // Expected: connection is rejected (400) because the session is not in the map. + exceptionThrown = true; + } + assertThat(exceptionThrown).as("Expected openSocket to fail for unknown session").isTrue(); + } + + @Test + void shouldFallBackToWebSocketHandlerWhenNodeIsUnreachable() throws Exception { + // Allocate a port then immediately close the socket so nothing is listening on it. + int closedPort; + try (ServerSocket ss = new ServerSocket(0)) { + closedPort = ss.getLocalPort(); + } + + SessionId id = new SessionId(UUID.randomUUID()); + sessions.add( + new Session( + id, + new URI("http://127.0.0.1:" + closedPort), + new ImmutableCapabilities(), + new ImmutableCapabilities(), + Instant.now())); + + tunnelServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> Optional.empty(), + createResolver()) + .start(); + + HttpClient.Factory factory = HttpClient.Factory.createDefault(); + boolean exceptionThrown = false; + try { + factory + .createClient(tunnelServer.getUrl()) + .openSocket( + new HttpRequest(GET, "/session/" + id + "/bidi"), new WebSocket.Listener() {}); + } catch (Exception e) { + // Expected: TCP connect fails, falls back to the WS handler (which returns empty) → 400. + // The important thing is a graceful rejection, not an abrupt channel close. + exceptionThrown = true; + } + assertThat(exceptionThrown) + .as("Expected openSocket to fail gracefully when node is unreachable") + .isTrue(); + } + + @Test + void shouldTunnelWebSocketThroughHttpsNode() throws URISyntaxException, InterruptedException { + // Start the backend with a self-signed certificate so its URL is https://. + // The tunnel handler detects the https scheme and adds a TLS handler on the node-side channel. + MapConfig httpsConfig = + new MapConfig(Map.of("server", Map.of("https-self-signed", true, "hostname", "localhost"))); + AtomicReference received = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + + backendServer = + new NettyServer( + new BaseServerOptions(httpsConfig), + nullHandler, + (uri, sink) -> + Optional.of( + msg -> { + if (msg instanceof TextMessage) { + received.set(((TextMessage) msg).text()); + latch.countDown(); + } + })) + .start(); + + // backendServer.getUrl() is now https://localhost: + SessionId id = new SessionId(UUID.randomUUID()); + sessions.add( + new Session( + id, + backendServer.getUrl().toURI(), + new ImmutableCapabilities(), + new ImmutableCapabilities(), + Instant.now())); + + tunnelServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> Optional.empty(), + createResolver()) + .start(); + + HttpClient.Factory factory = HttpClient.Factory.createDefault(); + try (WebSocket socket = + factory + .createClient(tunnelServer.getUrl()) + .openSocket( + new HttpRequest(GET, "/session/" + id + "/bidi"), new WebSocket.Listener() {})) { + + socket.sendText("secure-hello"); + + assertThat(latch.await(5, SECONDS)).isTrue(); + assertThat(received.get()).isEqualTo("secure-hello"); + } + } + + @Test + void shouldSupportMultipleMessagesOnSameConnection() + throws URISyntaxException, InterruptedException { + int messageCount = 5; + CountDownLatch latch = new CountDownLatch(messageCount); + AtomicReference count = new AtomicReference<>(0); + + backendServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> + Optional.of( + msg -> { + if (msg instanceof TextMessage) { + count.updateAndGet(c -> c + 1); + latch.countDown(); + } + })) + .start(); + + SessionId id = new SessionId(UUID.randomUUID()); + sessions.add( + new Session( + id, + backendServer.getUrl().toURI(), + new ImmutableCapabilities(), + new ImmutableCapabilities(), + Instant.now())); + + tunnelServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> Optional.empty(), + createResolver()) + .start(); + + HttpClient.Factory factory = HttpClient.Factory.createDefault(); + try (WebSocket socket = + factory + .createClient(tunnelServer.getUrl()) + .openSocket( + new HttpRequest(GET, "/session/" + id + "/bidi"), new WebSocket.Listener() {})) { + + for (int i = 0; i < messageCount; i++) { + socket.sendText("msg-" + i); + } + + assertThat(latch.await(10, SECONDS)).isTrue(); + assertThat(count.get()).isEqualTo(messageCount); + } + } + + /** + * Integration test that exercises the full Grid session lifecycle with BiDi enabled. + * + *

Flow: client requests {@code webSocketUrl: true} → LocalDistributor → LocalNode + * (TestSessionFactory returns {@code webSocketUrl} capability pointing to the Router) → + * LocalSessionMap registration → client reads {@code webSocketUrl} from capabilities → connects + * to Router BiDi WebSocket → TCP tunnel → stub backend server. + * + *

This mirrors what a real WebDriver BiDi client does: request {@code webSocketUrl: true}, + * receive a {@code webSocketUrl} in the response capabilities, and connect to it. + */ + @Test + void shouldTunnelBiDiThroughFullGridSessionLifecycle() + throws URISyntaxException, InterruptedException { + // Stub backend — simulates a Node's BiDi WebSocket endpoint. Echoes a fixed reply. + AtomicReference received = new AtomicReference<>(); + CountDownLatch receivedLatch = new CountDownLatch(1); + AtomicReference reply = new AtomicReference<>(); + CountDownLatch replyLatch = new CountDownLatch(1); + + backendServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> + Optional.of( + msg -> { + if (msg instanceof TextMessage) { + received.set(((TextMessage) msg).text()); + receivedLatch.countDown(); + sink.accept(new TextMessage("bidi-ack")); + } + })) + .start(); + + URI backendUri = backendServer.getUrl().toURI(); + + // Wire up in-process Grid components — mirrors how Standalone sets up the session path. + Tracer tracer = DefaultTestTracer.createTracer(); + GuavaEventBus bus = new GuavaEventBus(); + Secret secret = new Secret("test"); + ImmutableCapabilities stereotype = new ImmutableCapabilities("browserName", "chrome"); + + LocalSessionMap gridSessions = new LocalSessionMap(tracer, bus); + LocalNewSessionQueue queue = + new LocalNewSessionQueue( + tracer, + new DefaultSlotMatcher(), + Duration.ofSeconds(2), + Duration.ofSeconds(5), + Duration.ofSeconds(1), + secret, + 5); + + // routerUrl is set after tunnelServer starts so the TestSessionFactory can embed the Router's + // WebSocket URL in the returned webSocketUrl capability (the real Grid does the same). + AtomicReference routerUrl = new AtomicReference<>(); + + // TestSessionFactory: session URI → backendServer (so the TCP tunnel connects there). + // The returned capabilities include webSocketUrl pointing to the Router's BiDi endpoint, + // which is what a real Node would return after the Router rewrites the capability. + LocalNode node = + LocalNode.builder(tracer, bus, backendUri, backendUri, secret) + .add( + stereotype, + new TestSessionFactory( + stereotype, + (id, caps) -> { + URL rUrl = routerUrl.get(); + MutableCapabilities returnedCaps = new MutableCapabilities(caps); + returnedCaps.setCapability( + "webSocketUrl", + "ws://" + + rUrl.getHost() + + ":" + + rUrl.getPort() + + "/session/" + + id + + "/bidi"); + return new Session(id, backendUri, stereotype, returnedCaps, Instant.now()); + })) + .build(); + + LocalDistributor distributor = + new LocalDistributor( + tracer, + bus, + new PassthroughHttpClient.Factory(node), + gridSessions, + queue, + new DefaultSlotSelector(), + secret, + Duration.ofMinutes(5), + false, + Duration.ofSeconds(5), + Runtime.getRuntime().availableProcessors(), + new DefaultSlotMatcher(), + Duration.ofSeconds(30)); + distributor.add(node); + + // Wait for node capacity, then start the Router so routerUrl is known before newSession(). + new FluentWait<>(distributor) + .withTimeout(Duration.ofSeconds(5)) + .pollingEvery(Duration.ofMillis(100)) + .until(d -> d.getStatus().hasCapacity()); + + HttpClient.Factory clientFactory = HttpClient.Factory.createDefault(); + Function> tcpTunnelResolver = + uri -> + HttpSessionId.getSessionId(uri) + .map(SessionId::new) + .flatMap( + id -> { + try { + return Optional.of(gridSessions.getUri(id)); + } catch (NoSuchSessionException e) { + return Optional.empty(); + } + }); + + tunnelServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + new ProxyWebsocketsIntoGrid(clientFactory, gridSessions), + tcpTunnelResolver) + .start(); + routerUrl.set(tunnelServer.getUrl()); + + // Create a session with webSocketUrl: true — BiDi explicitly enabled by the client. + // LocalDistributor registers the session in gridSessions automatically. + SessionRequest sessionRequest = + new SessionRequest( + new RequestId(UUID.randomUUID()), + Instant.now(), + Set.of(W3C), + Set.of(new ImmutableCapabilities("browserName", "chrome", "webSocketUrl", true)), + Map.of(), + Map.of()); + Either result = + distributor.newSession(sessionRequest); + assertThat(result.isRight()).as("Session creation should succeed").isTrue(); + + // Read webSocketUrl from the returned capabilities — this is how a real client locates the + // BiDi endpoint, not by constructing the path manually. + // The capabilities system deserialises URL-like strings as URI objects, so avoid casting. + Object webSocketUrlCap = + result.right().getSession().getCapabilities().getCapability("webSocketUrl"); + assertThat(webSocketUrlCap).as("webSocketUrl capability must be present").isNotNull(); + + // Connect using the path from webSocketUrl (e.g. /session//bidi). + // The host/port points at the Router which uses the TCP tunnel to reach the backend. + String wsPath = new URI(webSocketUrlCap.toString()).getPath(); + + try (WebSocket socket = + clientFactory + .createClient(tunnelServer.getUrl()) + .openSocket( + new HttpRequest(GET, wsPath), + new WebSocket.Listener() { + @Override + public void onText(CharSequence data) { + reply.set(data.toString()); + replyLatch.countDown(); + } + })) { + + socket.sendText("{\"method\":\"session.new\"}"); + + // Verify client → backend direction. + assertThat(receivedLatch.await(5, SECONDS)).isTrue(); + assertThat(received.get()).isEqualTo("{\"method\":\"session.new\"}"); + + // Verify backend → client direction (the echo reply). + assertThat(replyLatch.await(5, SECONDS)).isTrue(); + assertThat(reply.get()).isEqualTo("bidi-ack"); + } + + distributor.close(); + bus.close(); + } +} From 84655e2f05b96a7175b92c2a497c28750738043e Mon Sep 17 00:00:00 2001 From: Nikolay Borisenko <22616990+nvborisenko@users.noreply.github.com> Date: Fri, 27 Feb 2026 11:08:20 +0300 Subject: [PATCH 27/67] [dotnet] [bidi] Preserve configurable options pattern (#17144) --- dotnet/src/webdriver/BiDi/BiDi.cs | 7 ++- dotnet/src/webdriver/BiDi/BiDiOptions.cs | 24 ---------- .../src/webdriver/BiDi/BiDiOptionsBuilder.cs | 48 +++++++++++++++++++ .../webdriver/BiDi/WebDriver.Extensions.cs | 4 +- .../src/webdriver/BiDi/WebSocketTransport.cs | 4 +- 5 files changed, 58 insertions(+), 29 deletions(-) delete mode 100644 dotnet/src/webdriver/BiDi/BiDiOptions.cs create mode 100644 dotnet/src/webdriver/BiDi/BiDiOptionsBuilder.cs diff --git a/dotnet/src/webdriver/BiDi/BiDi.cs b/dotnet/src/webdriver/BiDi/BiDi.cs index b65b15078cbd1..e6933b933e660 100644 --- a/dotnet/src/webdriver/BiDi/BiDi.cs +++ b/dotnet/src/webdriver/BiDi/BiDi.cs @@ -52,9 +52,12 @@ private BiDi() { } public Emulation.IEmulationModule Emulation => AsModule(); - public static async Task ConnectAsync(string url, BiDiOptions? options = null, CancellationToken cancellationToken = default) + public static async Task ConnectAsync(string url, Action? configure = null, CancellationToken cancellationToken = default) { - var transport = await WebSocketTransport.ConnectAsync(new Uri(url), cancellationToken).ConfigureAwait(false); + BiDiOptionsBuilder builder = new(); + configure?.Invoke(builder); + + var transport = await builder.TransportFactory(new Uri(url), cancellationToken).ConfigureAwait(false); BiDi bidi = new(); diff --git a/dotnet/src/webdriver/BiDi/BiDiOptions.cs b/dotnet/src/webdriver/BiDi/BiDiOptions.cs deleted file mode 100644 index 7bcac32ca6c0d..0000000000000 --- a/dotnet/src/webdriver/BiDi/BiDiOptions.cs +++ /dev/null @@ -1,24 +0,0 @@ -// -// Licensed to the Software Freedom Conservancy (SFC) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The SFC licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. -// - -namespace OpenQA.Selenium.BiDi; - -public sealed class BiDiOptions -{ -} diff --git a/dotnet/src/webdriver/BiDi/BiDiOptionsBuilder.cs b/dotnet/src/webdriver/BiDi/BiDiOptionsBuilder.cs new file mode 100644 index 0000000000000..96b17b516bb56 --- /dev/null +++ b/dotnet/src/webdriver/BiDi/BiDiOptionsBuilder.cs @@ -0,0 +1,48 @@ +// +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +using System.Net.WebSockets; + +namespace OpenQA.Selenium.BiDi; + +///

+/// Provides a fluent API for configuring BiDi connection options, +/// such as the underlying transport mechanism. +/// +public sealed class BiDiOptionsBuilder +{ + internal Func> TransportFactory { get; private set; } + = (uri, ct) => WebSocketTransport.ConnectAsync(uri, null, ct); + + /// + /// Configures the BiDi connection to use a WebSocket transport. + /// + /// + /// WebSocket is the default transport; calling this method is only necessary + /// when you need to customize the underlying + /// (e.g., to set headers, proxy, or certificates). + /// + /// An optional action to configure the before connecting. + /// The current instance for chaining. + public BiDiOptionsBuilder UseWebSocket(Action? configure = null) + { + TransportFactory = (uri, ct) => WebSocketTransport.ConnectAsync(uri, configure, ct); + return this; + } +} diff --git a/dotnet/src/webdriver/BiDi/WebDriver.Extensions.cs b/dotnet/src/webdriver/BiDi/WebDriver.Extensions.cs index 124d6bebd18b6..f3eed2d70a17d 100644 --- a/dotnet/src/webdriver/BiDi/WebDriver.Extensions.cs +++ b/dotnet/src/webdriver/BiDi/WebDriver.Extensions.cs @@ -21,7 +21,7 @@ namespace OpenQA.Selenium.BiDi; public static class WebDriverExtensions { - public static async Task AsBiDiAsync(this IWebDriver webDriver, BiDiOptions? options = null, CancellationToken cancellationToken = default) + public static async Task AsBiDiAsync(this IWebDriver webDriver, Action? configure = null, CancellationToken cancellationToken = default) { if (webDriver is null) throw new ArgumentNullException(nameof(webDriver)); @@ -34,7 +34,7 @@ public static async Task AsBiDiAsync(this IWebDriver webDriver, BiDiOptio if (webSocketUrl is null) throw new BiDiException("The driver is not compatible with bidirectional protocol or \"webSocketUrl\" not enabled in driver options."); - var bidi = await BiDi.ConnectAsync(webSocketUrl, options, cancellationToken).ConfigureAwait(false); + var bidi = await BiDi.ConnectAsync(webSocketUrl, configure, cancellationToken).ConfigureAwait(false); return bidi; } diff --git a/dotnet/src/webdriver/BiDi/WebSocketTransport.cs b/dotnet/src/webdriver/BiDi/WebSocketTransport.cs index fcab07b9a7770..6b0920500b6d3 100644 --- a/dotnet/src/webdriver/BiDi/WebSocketTransport.cs +++ b/dotnet/src/webdriver/BiDi/WebSocketTransport.cs @@ -32,12 +32,14 @@ sealed class WebSocketTransport(ClientWebSocket webSocket) : ITransport private readonly SemaphoreSlim _socketSendSemaphoreSlim = new(1, 1); private readonly MemoryStream _sharedMemoryStream = new(); - public static async Task ConnectAsync(Uri uri, CancellationToken cancellationToken) + public static async Task ConnectAsync(Uri uri, Action? configure, CancellationToken cancellationToken) { ClientWebSocket webSocket = new(); try { + configure?.Invoke(webSocket.Options); + await webSocket.ConnectAsync(uri, cancellationToken).ConfigureAwait(false); } catch (Exception) From 4fc491a07849c01cc631ce38d629e1494b94fc3d Mon Sep 17 00:00:00 2001 From: Alex Rodionov Date: Fri, 27 Feb 2026 07:06:42 -0800 Subject: [PATCH 28/67] [rb] Use portable Ruby (#16936) * [rb] Switch to portable Rubies * [java] Load `java_library` from repo * [java] Load `java_binary` from repo --------- Co-authored-by: Augustin Gottlieb <33221555+aguspe@users.noreply.github.com> --- .github/workflows/bazel.yml | 2 -- BUILD.bazel | 2 +- MODULE.bazel | 7 +++++-- common/remote-build/cc/BUILD | 1 + common/remote-build/cc/armeabi_cc_toolchain_config.bzl | 1 + common/remote-build/cc/cc_toolchain_config.bzl | 1 + cpp/linux-specific/BUILD.bazel | 2 ++ java/BUILD.bazel | 3 ++- java/private/BUILD.bazel | 2 +- java/private/common.bzl | 1 + java/private/dist_info.bzl | 1 + java/private/module.bzl | 3 +++ java/test/org/openqa/selenium/firefox/BUILD.bazel | 2 +- java/test/org/openqa/selenium/grid/router/BUILD.bazel | 2 +- java/test/org/openqa/selenium/netty/server/BUILD.bazel | 2 +- javascript/grid-ui/BUILD.bazel | 2 +- javascript/grid-ui/public/BUILD.bazel | 1 + javascript/private/test_suite.bzl | 5 ++--- rb/Gemfile.lock | 2 +- scripts/BUILD.bazel | 2 +- 20 files changed, 28 insertions(+), 16 deletions(-) diff --git a/.github/workflows/bazel.yml b/.github/workflows/bazel.yml index 9c99ee829177b..34021e6f6a82a 100644 --- a/.github/workflows/bazel.yml +++ b/.github/workflows/bazel.yml @@ -169,7 +169,6 @@ jobs: external-cache: | manifest: crates: rust/Cargo.Bazel.lock - rules_ruby++ruby+ruby: ${{ inputs.os == 'windows' && 'false' || 'rb/.ruby-version' }} "+pin_browsers_extension+linux_beta_chrome": false "+pin_browsers_extension+linux_beta_chromedriver": true "+pin_browsers_extension+linux_beta_firefox": false @@ -198,7 +197,6 @@ jobs: external-cache: | manifest: crates: rust/Cargo.Bazel.lock - rules_ruby++ruby+ruby: rb/.ruby-version repository-cache: true bazelrc: common --color=yes - name: Setup curl for Ubuntu diff --git a/BUILD.bazel b/BUILD.bazel index 63745fbe1d51e..94172c7b82a60 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -3,7 +3,7 @@ load("@buildifier_prebuilt//:rules.bzl", "buildifier") load("@npm//:defs.bzl", "npm_link_all_packages") load("//common:browsers.bzl", "chrome_data", "firefox_data") load("//java:browsers.bzl", "chrome_jvm_flags", "firefox_jvm_flags") -load("//java:defs.bzl", "artifact") +load("//java:defs.bzl", "artifact", "java_binary") exports_files([ "package.json", diff --git a/MODULE.bazel b/MODULE.bazel index acd0d3f5d26ea..901e492751c02 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -24,7 +24,7 @@ bazel_dep(name = "rules_nodejs", version = "6.3.2") bazel_dep(name = "rules_pkg", version = "1.0.1") bazel_dep(name = "rules_python", version = "1.8.3") bazel_dep(name = "rules_proto", version = "7.0.2") -bazel_dep(name = "rules_ruby", version = "0.19.0") +bazel_dep(name = "rules_ruby", version = "0.22.1") bazel_dep(name = "rules_rust", version = "0.67.0") # Until `rules_jvm_external` 6.8 ships @@ -259,7 +259,7 @@ ruby.toolchain( "curl", "libyaml", ], - ruby_build_version = "20260110", + portable_ruby = True, version_file = "//:rb/.ruby-version", ) ruby.bundle_fetch( @@ -358,6 +358,9 @@ ruby.bundle_fetch( }, gemfile = "//:rb/Gemfile", gemfile_lock = "//:rb/Gemfile.lock", + jar_checksums = { + "org.snakeyaml:snakeyaml-engine:2.10": "c99d9fd66c7c251d881a9cd95089b7c8044c29a1b02983d7036981bd4354ec37", + }, ) use_repo(ruby, "bundle", "ruby", "ruby_toolchains") diff --git a/common/remote-build/cc/BUILD b/common/remote-build/cc/BUILD index 88b9ad9cbcd6c..425b4acd189ff 100755 --- a/common/remote-build/cc/BUILD +++ b/common/remote-build/cc/BUILD @@ -14,6 +14,7 @@ # This becomes the BUILD file for @local_config_cc// under non-BSD unixes. +load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_cc//cc:defs.bzl", "cc_toolchain", "cc_toolchain_suite") load(":armeabi_cc_toolchain_config.bzl", "armeabi_cc_toolchain_config") load(":cc_toolchain_config.bzl", "cc_toolchain_config") diff --git a/common/remote-build/cc/armeabi_cc_toolchain_config.bzl b/common/remote-build/cc/armeabi_cc_toolchain_config.bzl index 72ef48ae6d6df..ae0527efe74bb 100755 --- a/common/remote-build/cc/armeabi_cc_toolchain_config.bzl +++ b/common/remote-build/cc/armeabi_cc_toolchain_config.bzl @@ -19,6 +19,7 @@ load( "feature", "tool_path", ) +load("@rules_cc//cc/common:cc_common.bzl", "cc_common") def _impl(ctx): toolchain_identifier = "stub_armeabi-v7a" diff --git a/common/remote-build/cc/cc_toolchain_config.bzl b/common/remote-build/cc/cc_toolchain_config.bzl index 4fd16d733098a..21b78e1ad76c5 100755 --- a/common/remote-build/cc/cc_toolchain_config.bzl +++ b/common/remote-build/cc/cc_toolchain_config.bzl @@ -28,6 +28,7 @@ load( "variable_with_value", "with_feature_set", ) +load("@rules_cc//cc/common:cc_common.bzl", "cc_common") def layering_check_features(compiler): if compiler != "clang": diff --git a/cpp/linux-specific/BUILD.bazel b/cpp/linux-specific/BUILD.bazel index 6fbb8031b2219..f70a1001ffade 100644 --- a/cpp/linux-specific/BUILD.bazel +++ b/cpp/linux-specific/BUILD.bazel @@ -1,3 +1,5 @@ +load("@rules_cc//cc:cc_binary.bzl", "cc_binary") + cc_binary( name = "noblur64", srcs = glob([ diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 71808daaa9c78..25d2f0c494038 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -1,6 +1,7 @@ load("@bazel_skylib//rules:common_settings.bzl", "string_flag") load("@contrib_rules_jvm//java:defs.bzl", "spotbugs_binary", "spotbugs_config") -load(":defs.bzl", "artifact") +load("@rules_java//java:java_plugin.bzl", "java_plugin") +load(":defs.bzl", "artifact", "java_library") exports_files( srcs = [ diff --git a/java/private/BUILD.bazel b/java/private/BUILD.bazel index 750af87a97a5a..eaf0e211d6e78 100644 --- a/java/private/BUILD.bazel +++ b/java/private/BUILD.bazel @@ -1,4 +1,4 @@ -load("@rules_jvm_external//:defs.bzl", "artifact") +load("//java:defs.bzl", "artifact", "java_binary") exports_files( srcs = [ diff --git a/java/private/common.bzl b/java/private/common.bzl index 8f68b4a9b3fcf..1fb1415129762 100644 --- a/java/private/common.bzl +++ b/java/private/common.bzl @@ -1,3 +1,4 @@ +load("@rules_java//java/common:java_info.bzl", "JavaInfo") load("//java/private:module.bzl", "JavaModuleInfo") MavenInfo = provider( diff --git a/java/private/dist_info.bzl b/java/private/dist_info.bzl index 97c8b9bddbc97..09afe15922003 100644 --- a/java/private/dist_info.bzl +++ b/java/private/dist_info.bzl @@ -1,3 +1,4 @@ +load("@rules_java//java/common:java_info.bzl", "JavaInfo") load("//java/private:common.bzl", "MavenInfo", "explode_coordinates", "read_coordinates") load("//java/private:module.bzl", "JavaModuleInfo") diff --git a/java/private/module.bzl b/java/private/module.bzl index f0c9eedec22fa..313dca27e2d85 100644 --- a/java/private/module.bzl +++ b/java/private/module.bzl @@ -1,3 +1,6 @@ +load("@rules_java//java/common:java_common.bzl", "java_common") +load("@rules_java//java/common:java_info.bzl", "JavaInfo") + _GatheredModuleInfo = provider( fields = { "name": "Name of the module, may be `None`.", diff --git a/java/test/org/openqa/selenium/firefox/BUILD.bazel b/java/test/org/openqa/selenium/firefox/BUILD.bazel index ee7f980bd7c94..7d29fa850c18e 100644 --- a/java/test/org/openqa/selenium/firefox/BUILD.bazel +++ b/java/test/org/openqa/selenium/firefox/BUILD.bazel @@ -1,6 +1,6 @@ load("@rules_jvm_external//:defs.bzl", "artifact") load("//common:defs.bzl", "copy_file") -load("//java:defs.bzl", "JUNIT5_DEPS", "java_selenium_test_suite", "java_test_suite") +load("//java:defs.bzl", "JUNIT5_DEPS", "java_library", "java_selenium_test_suite", "java_test_suite") LARGE_TESTS = [ "ExtensionsTest.java", diff --git a/java/test/org/openqa/selenium/grid/router/BUILD.bazel b/java/test/org/openqa/selenium/grid/router/BUILD.bazel index d8fea7a2f56c4..33ca056590d9b 100644 --- a/java/test/org/openqa/selenium/grid/router/BUILD.bazel +++ b/java/test/org/openqa/selenium/grid/router/BUILD.bazel @@ -1,5 +1,5 @@ load("@rules_jvm_external//:defs.bzl", "artifact") -load("//java:defs.bzl", "BIDI_BROWSERS", "JUNIT5_DEPS", "SINGLE_BROWSER", "java_selenium_test_suite", "java_test_suite") +load("//java:defs.bzl", "BIDI_BROWSERS", "JUNIT5_DEPS", "SINGLE_BROWSER", "java_library", "java_selenium_test_suite", "java_test_suite") load("//java:version.bzl", "TOOLS_JAVA_VERSION") load("//java/src/org/openqa/selenium/devtools:versions.bzl", "CDP_DEPS") diff --git a/java/test/org/openqa/selenium/netty/server/BUILD.bazel b/java/test/org/openqa/selenium/netty/server/BUILD.bazel index c21993bc25b1c..29362519a0ebe 100644 --- a/java/test/org/openqa/selenium/netty/server/BUILD.bazel +++ b/java/test/org/openqa/selenium/netty/server/BUILD.bazel @@ -1,5 +1,5 @@ load("@rules_jvm_external//:defs.bzl", "artifact") -load("//java:defs.bzl", "JUNIT5_DEPS", "java_test_suite") +load("//java:defs.bzl", "JUNIT5_DEPS", "java_library", "java_test_suite") SMALL_TEST_SRCS = [ "RequestConverterTest.java", diff --git a/javascript/grid-ui/BUILD.bazel b/javascript/grid-ui/BUILD.bazel index ca3c37dd40584..289b89443d746 100644 --- a/javascript/grid-ui/BUILD.bazel +++ b/javascript/grid-ui/BUILD.bazel @@ -5,7 +5,7 @@ load("@aspect_rules_js//js:defs.bzl", "js_library") load("@aspect_rules_ts//ts:defs.bzl", "ts_project") load("@npm//:defs.bzl", "npm_link_all_packages") load("@rules_pkg//pkg:zip.bzl", "pkg_zip") -load("//java:defs.bzl", "merge_jars") +load("//java:defs.bzl", "java_import", "merge_jars") npm_link_all_packages(name = "node_modules") diff --git a/javascript/grid-ui/public/BUILD.bazel b/javascript/grid-ui/public/BUILD.bazel index b1635bc09dd94..7f68b55a76b48 100644 --- a/javascript/grid-ui/public/BUILD.bazel +++ b/javascript/grid-ui/public/BUILD.bazel @@ -1,4 +1,5 @@ load("@rules_pkg//pkg:zip.bzl", "pkg_zip") +load("//java:defs.bzl", "java_import") pkg_zip( name = "build-zip", diff --git a/javascript/private/test_suite.bzl b/javascript/private/test_suite.bzl index 77bbf3c2e46a2..78f5ade43f0d9 100644 --- a/javascript/private/test_suite.bzl +++ b/javascript/private/test_suite.bzl @@ -1,5 +1,5 @@ load("@rules_jvm_external//:defs.bzl", "artifact") -load("//java:defs.bzl", "selenium_test") +load("//java:defs.bzl", "java_binary", "selenium_test") def closure_test_suite(name, data = [], browsers = None): data = data + [ @@ -28,8 +28,7 @@ def closure_test_suite(name, data = [], browsers = None): kwargs["browsers"] = browsers selenium_test(**kwargs) - - native.java_binary( + java_binary( name = name + "_debug_server", main_class = "org.openqa.selenium.environment.webserver.NettyAppServer", data = data, diff --git a/rb/Gemfile.lock b/rb/Gemfile.lock index 237965ed5a7ae..890506e438335 100644 --- a/rb/Gemfile.lock +++ b/rb/Gemfile.lock @@ -222,4 +222,4 @@ DEPENDENCIES yard (~> 0.9.11, >= 0.9.36) BUNDLED WITH - 2.4.19 + 4.0.6 diff --git a/scripts/BUILD.bazel b/scripts/BUILD.bazel index 0a4514942b3e9..670d27ea0cd6a 100644 --- a/scripts/BUILD.bazel +++ b/scripts/BUILD.bazel @@ -1,6 +1,6 @@ load("@py_dev_requirements//:requirements.bzl", "requirement") load("@rules_python//python:defs.bzl", "py_binary") -load("//java:defs.bzl", "artifact") +load("//java:defs.bzl", "artifact", "java_binary") py_binary( name = "pinned_browsers", From 5d0b49cc871d446107e642539526808a9a025bb0 Mon Sep 17 00:00:00 2001 From: Selenium CI Bot Date: Sat, 28 Feb 2026 01:45:24 +0100 Subject: [PATCH 29/67] [dotnet][rb][java][js][py] Automated Browser Version Update (#17145) Update pinned browser versions Co-authored-by: Selenium CI Bot --- common/repositories.bzl | 42 ++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/common/repositories.bzl b/common/repositories.bzl index e7f65e87e49bd..dc63f08014a9e 100644 --- a/common/repositories.bzl +++ b/common/repositories.bzl @@ -50,8 +50,8 @@ js_library( http_archive( name = "linux_beta_firefox", - url = "https://ftp.mozilla.org/pub/firefox/releases/148.0b15/linux-x86_64/en-US/firefox-148.0b15.tar.xz", - sha256 = "23621cf9537fd8d52d3a4e83bba48984f705facd4c10819aa07ae2531e11e2e5", + url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b2/linux-x86_64/en-US/firefox-149.0b2.tar.xz", + sha256 = "14057fe24b65ef64125d04bde3507f48e489464b9f22cda832d34db92dede817", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -72,8 +72,8 @@ js_library( dmg_archive( name = "mac_beta_firefox", - url = "https://ftp.mozilla.org/pub/firefox/releases/148.0b15/mac/en-US/Firefox%20148.0b15.dmg", - sha256 = "0da5fee250eb13165dda25f9e29d238e51f9e0d56b9355b9c8746f6bc8d5c1fe", + url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b2/mac/en-US/Firefox%20149.0b2.dmg", + sha256 = "cc697fd73992e677e7249be2a42d7e6e9fe1373841ecefe7009acfad113566fb", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -123,10 +123,10 @@ js_library( pkg_archive( name = "mac_edge", - url = "https://msedge.sf.dl.delivery.mp.microsoft.com/filestreamingservice/files/d4760825-3164-457a-9c5d-4d762678c105/MicrosoftEdge-145.0.3800.70.pkg", - sha256 = "9b4888e1c496127571d6ace8ab3e58cbbc5e75e162a8fc7a1004500e6bd83357", + url = "https://msedge.sf.dl.delivery.mp.microsoft.com/filestreamingservice/files/42fe1f19-ccfd-476b-835f-4dc020005cd9/MicrosoftEdge-145.0.3800.82.pkg", + sha256 = "d046111ab7cbbf56a78588afae17f21cc1343a55c9364268cbd98868f6428979", move = { - "MicrosoftEdge-145.0.3800.70.pkg/Payload/Microsoft Edge.app": "Edge.app", + "MicrosoftEdge-145.0.3800.82.pkg/Payload/Microsoft Edge.app": "Edge.app", }, build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") @@ -143,8 +143,8 @@ js_library( deb_archive( name = "linux_edge", - url = "https://packages.microsoft.com/repos/edge/pool/main/m/microsoft-edge-stable/microsoft-edge-stable_145.0.3800.70-1_amd64.deb", - sha256 = "814ca1f400f59ed9d96f688b47028bc72d0fc58d038acb7bde84fd002da91502", + url = "https://packages.microsoft.com/repos/edge/pool/main/m/microsoft-edge-stable/microsoft-edge-stable_145.0.3800.82-1_amd64.deb", + sha256 = "29e8d282072cf3c798e524cd1bd17b4ee79982f0679b17aca0d6ed50d8d2adf9", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -165,8 +165,8 @@ js_library( http_archive( name = "linux_edgedriver", - url = "https://msedgedriver.microsoft.com/145.0.3800.70/edgedriver_linux64.zip", - sha256 = "d8a3d393d6e5c246783260946fe8a91cfda2b2628438842bd3bb7f205b190d9f", + url = "https://msedgedriver.microsoft.com/145.0.3800.82/edgedriver_linux64.zip", + sha256 = "e6fa668cb036938d56a06519b19923ac8e547c3ea4c0d4fdf7a75076b3e1b31a", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -182,8 +182,8 @@ js_library( http_archive( name = "mac_edgedriver", - url = "https://msedgedriver.microsoft.com/145.0.3800.70/edgedriver_mac64_m1.zip", - sha256 = "71f65943c22e4aabce082531f83b1a3312d38c02f66e5fbb73edd57d03063363", + url = "https://msedgedriver.microsoft.com/145.0.3800.82/edgedriver_mac64_m1.zip", + sha256 = "178502258e17aef84051c5a317956ab3d3944e7329e3eee662105b17a3b60dc8", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -277,8 +277,8 @@ js_library( http_archive( name = "linux_beta_chrome", - url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.16/linux64/chrome-linux64.zip", - sha256 = "6c3241cf5eab6b5eaed9b0b741bae799377dea26985aed08cda51fb75433218e", + url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.31/linux64/chrome-linux64.zip", + sha256 = "5b1961b081f0156a1923a9d9d1bfffdf00f82e8722152c35eb5eb742d63ceeb8", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -298,8 +298,8 @@ js_library( ) http_archive( name = "mac_beta_chrome", - url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.16/mac-arm64/chrome-mac-arm64.zip", - sha256 = "b39fe2de33190da209845e5d21ea44c75d66d0f4c33c5a293d8b6a259d3c4029", + url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.31/mac-arm64/chrome-mac-arm64.zip", + sha256 = "207867110edc624316b18684065df4eb06b938a3fd9141790a726ab280e2640f", strip_prefix = "chrome-mac-arm64", patch_cmds = [ "mv 'Google Chrome for Testing.app' Chrome.app", @@ -319,8 +319,8 @@ js_library( ) http_archive( name = "linux_beta_chromedriver", - url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.16/linux64/chromedriver-linux64.zip", - sha256 = "c6927758a816f0a2f5f10609b34f74080a8c0f08feaf177a68943d8d4aae3a72", + url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.31/linux64/chromedriver-linux64.zip", + sha256 = "a8c7be8669829ed697759390c8c42b4bca3f884fd20980e078129f5282dabe1a", strip_prefix = "chromedriver-linux64", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") @@ -337,8 +337,8 @@ js_library( http_archive( name = "mac_beta_chromedriver", - url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.16/mac-arm64/chromedriver-mac-arm64.zip", - sha256 = "29c44a53be87fccea4a7887a7ed2b45b5812839e357e091c6a784ee17bb8da78", + url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.31/mac-arm64/chromedriver-mac-arm64.zip", + sha256 = "84c3717c0eeba663d0b8890a0fc06faa6fe158227876fc6954461730ccc81634", strip_prefix = "chromedriver-mac-arm64", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") From 77b1995dfa1d84b5935a206bcc9c48623a44f078 Mon Sep 17 00:00:00 2001 From: JAYA DILEEP <114493600+seethinajayadileep@users.noreply.github.com> Date: Mon, 2 Mar 2026 01:07:33 +0530 Subject: [PATCH 30/67] [java] Guard against NPE in Platform.extractFromSysProperty (#17151) * [java] Guard against NPE in Platform.extractFromSysProperty * [java] Improve null-safety and formatting in Platform.extractFromSysProperty * [java] Harden Platform.extractFromSysProperty against null inputs --- java/src/org/openqa/selenium/Platform.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/java/src/org/openqa/selenium/Platform.java b/java/src/org/openqa/selenium/Platform.java index 24d4a776166cf..9bef35c6d81ab 100644 --- a/java/src/org/openqa/selenium/Platform.java +++ b/java/src/org/openqa/selenium/Platform.java @@ -17,6 +17,8 @@ package org.openqa.selenium; +import static java.util.Objects.requireNonNullElse; + import java.util.Arrays; import java.util.Locale; import java.util.regex.Matcher; @@ -415,7 +417,7 @@ public static Platform getCurrent() { * @return the most likely platform based on given operating system name */ public static Platform extractFromSysProperty(String osName) { - return extractFromSysProperty(osName, System.getProperty("os.version")); + return extractFromSysProperty(osName, System.getProperty("os.version", "")); } /** @@ -428,7 +430,10 @@ public static Platform extractFromSysProperty(String osName) { * @return the most likely platform based on given operating system name and version */ public static Platform extractFromSysProperty(String osName, String osVersion) { + osName = requireNonNullElse(osName, ""); + osVersion = requireNonNullElse(osVersion, ""); osName = osName.toLowerCase(Locale.ENGLISH); + // os.name for android is linux if ("dalvik".equalsIgnoreCase(System.getProperty("java.vm.name"))) { return Platform.ANDROID; From c67296933bfb4d15c1568c976fdf4cb9fe022531 Mon Sep 17 00:00:00 2001 From: JAYA DILEEP <114493600+seethinajayadileep@users.noreply.github.com> Date: Mon, 2 Mar 2026 01:07:51 +0530 Subject: [PATCH 31/67] [java] Deduplicate Unicode PUA mappings in Keys; make OPTION an alias of ALT and deprecate FN (#17147) * Remove FN key and alias OPTION to ALT to eliminate duplicate Unicode mappings * Fix duplicate PUA mappings in Keys enum; deprecate FN and alias to RIGHT_CONTROL --- java/src/org/openqa/selenium/Keys.java | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/java/src/org/openqa/selenium/Keys.java b/java/src/org/openqa/selenium/Keys.java index 482da1cce469e..de72b50107590 100644 --- a/java/src/org/openqa/selenium/Keys.java +++ b/java/src/org/openqa/selenium/Keys.java @@ -28,8 +28,8 @@ * *

The codes follow conventions partially established by the W3C WebDriver specification and the * Selenium project. Some values (e.g., RIGHT_SHIFT, RIGHT_COMMAND) are used in ChromeDriver but are - * not currently part of the W3C spec. Others (e.g., OPTION, FN) are symbolic and reserved for - * possible future mapping. + * not currently part of the W3C specification. Others (e.g., OPTION) are symbolic aliases for + * existing keys. * *

For consistency across platforms and drivers, values should be verified before assuming native * support. @@ -116,9 +116,15 @@ public enum Keys implements CharSequence { RIGHT_ALT('\uE052'), RIGHT_COMMAND('\uE053'), - // Symbolic macOS keys not yet standardized - OPTION('\uE052'), - FN('\uE051'), // TODO: symbolic only; confirm or remove in future + // macOS-friendly alias (do NOT introduce new codes) + OPTION(Keys.ALT), + + /** + * @deprecated The FN key is not part of the W3C WebDriver specification and does not have a + * standardized Unicode mapping. Its behavior is not guaranteed across drivers/platforms. + */ + @Deprecated + FN(Keys.RIGHT_CONTROL), ZENKAKU_HANKAKU('\uE040'); From 5f7debadcc767bebdbf8529911d3f5c469c2e952 Mon Sep 17 00:00:00 2001 From: Andrei Solntsev Date: Sun, 1 Mar 2026 21:38:11 +0200 Subject: [PATCH 32/67] [java] remove `@Nullable` from return value for `ExpectedConditions` that never return null (#17149) mark `wait.until` return value as non-nullable ... and type of Function both nullable and non-nullable. Expression `wait.until(condition)` never returns null. It either returns an Object or true, or throws TimeoutException. While `Function isTrue` can return null. Fixes #17122 P.S. This is a crazy trick to satisfy Kotlin compiler: ```java @NonNull V until(Function isTrue); ``` Thanks to ``, Kotlin compiler accepts both nullable and non-nullable functions: * `(Function isTrue)` * `(Function isTrue)` --- .../support/ui/ExpectedConditions.java | 38 +++++++++---------- .../selenium/support/ui/FluentWait.java | 3 +- .../org/openqa/selenium/support/ui/Wait.java | 6 ++- 3 files changed, 25 insertions(+), 22 deletions(-) diff --git a/java/src/org/openqa/selenium/support/ui/ExpectedConditions.java b/java/src/org/openqa/selenium/support/ui/ExpectedConditions.java index da62c59b8f1e8..6a231c715f80b 100644 --- a/java/src/org/openqa/selenium/support/ui/ExpectedConditions.java +++ b/java/src/org/openqa/selenium/support/ui/ExpectedConditions.java @@ -247,8 +247,8 @@ public String toString() { */ public static ExpectedCondition<@Nullable List> visibilityOfAllElementsLocatedBy( final By locator) { - return new ExpectedCondition<@Nullable List>() { - private int indexOfInvisibleElement; + return new ExpectedCondition<>() { + private int indexOfInvisibleElement = -1; private @Nullable WebElement invisibleElement; @Override @@ -304,8 +304,8 @@ public String toString() { */ public static ExpectedCondition<@Nullable List> visibilityOfAllElements( final List elements) { - return new ExpectedCondition<@Nullable List>() { - private int indexOfInvisibleElement; + return new ExpectedCondition<>() { + private int indexOfInvisibleElement = -1; private @Nullable WebElement invisibleElement; @Override @@ -345,7 +345,7 @@ public String toString() { * @return the (same) WebElement once it is visible */ public static ExpectedCondition<@Nullable WebElement> visibilityOf(final WebElement element) { - return new ExpectedCondition<@Nullable WebElement>() { + return new ExpectedCondition<>() { @Override public @Nullable WebElement apply(WebDriver driver) { return elementIfVisible(element); @@ -373,7 +373,7 @@ public String toString() { */ public static ExpectedCondition<@Nullable List> presenceOfAllElementsLocatedBy( final By locator) { - return new ExpectedCondition<@Nullable List>() { + return new ExpectedCondition<>() { @Override public @Nullable List apply(WebDriver driver) { List elements = driver.findElements(locator); @@ -596,7 +596,7 @@ public String toString() { */ public static ExpectedCondition<@Nullable WebDriver> frameToBeAvailableAndSwitchToIt( final By locator) { - return new ExpectedCondition<@Nullable WebDriver>() { + return new ExpectedCondition<>() { private @Nullable NotFoundException error; @Override @@ -660,7 +660,7 @@ public String toString() { */ public static ExpectedCondition<@Nullable WebDriver> frameToBeAvailableAndSwitchToIt( final WebElement frame) { - return new ExpectedCondition<@Nullable WebDriver>() { + return new ExpectedCondition<>() { private @Nullable NoSuchFrameException error; @Override @@ -850,7 +850,7 @@ public String toString() { * @return the result of the provided condition */ public static ExpectedCondition<@Nullable T> refreshed(final ExpectedCondition condition) { - return new ExpectedCondition<@Nullable T>() { + return new ExpectedCondition<>() { @Override public @Nullable T apply(WebDriver driver) { try { @@ -934,7 +934,7 @@ public String toString() { } public static ExpectedCondition<@Nullable Alert> alertIsPresent() { - return new ExpectedCondition<@Nullable Alert>() { + return new ExpectedCondition<>() { @Override public @Nullable Alert apply(WebDriver driver) { try { @@ -953,7 +953,7 @@ public String toString() { public static ExpectedCondition numberOfWindowsToBe(final int expectedNumberOfWindows) { return new ExpectedCondition<>() { - private int actualNumberOfWindows; + private int actualNumberOfWindows = -1; private @Nullable WebDriverException error; @Override @@ -1136,7 +1136,7 @@ public String toString() { */ public static ExpectedCondition<@Nullable List> numberOfElementsToBeMoreThan( final By locator, final Integer expectedNumber) { - return new ExpectedCondition<@Nullable List>() { + return new ExpectedCondition<>() { private Integer actualNumber = 0; @Override @@ -1165,7 +1165,7 @@ public String toString() { */ public static ExpectedCondition<@Nullable List> numberOfElementsToBeLessThan( final By locator, final Integer number) { - return new ExpectedCondition<@Nullable List>() { + return new ExpectedCondition<>() { private Integer currentNumber = 0; @Override @@ -1193,7 +1193,7 @@ public String toString() { */ public static ExpectedCondition<@Nullable List> numberOfElementsToBe( final By locator, final Integer expectedNumberOfElements) { - return new ExpectedCondition<@Nullable List>() { + return new ExpectedCondition<>() { private Integer actualNumberOfElements = -1; @Override @@ -1414,7 +1414,7 @@ private static Optional getAttributeOrCssValue(WebElement element, Strin */ public static ExpectedCondition<@Nullable List> visibilityOfNestedElementsLocatedBy( final By parent, final By childLocator) { - return new ExpectedCondition<@Nullable List>() { + return new ExpectedCondition<>() { private int indexOfInvisibleElement = -1; private @Nullable WebElement invisibleChild; @@ -1465,7 +1465,7 @@ public String toString() { */ public static ExpectedCondition<@Nullable List> visibilityOfNestedElementsLocatedBy( final WebElement element, final By childLocator) { - return new ExpectedCondition<@Nullable List>() { + return new ExpectedCondition<>() { private int indexOfInvisibleElement = -1; private @Nullable WebElement invisibleChild; @@ -1575,7 +1575,7 @@ public String toString() { */ public static ExpectedCondition<@Nullable List> presenceOfNestedElementsLocatedBy( final By parent, final By childLocator) { - return new ExpectedCondition<@Nullable List>() { + return new ExpectedCondition<>() { @Override public @Nullable List apply(WebDriver driver) { @@ -1610,7 +1610,7 @@ public static ExpectedCondition invisibilityOfAllElements(final WebElem public static ExpectedCondition invisibilityOfAllElements( final List elements) { return new ExpectedCondition<>() { - private int indexOfVisibleElement; + private int indexOfVisibleElement = -1; private @Nullable WebElement visibleElement; @Override @@ -1808,7 +1808,7 @@ public String toString() { * @return object once JavaScript executes without errors */ public static ExpectedCondition<@Nullable Object> jsReturnsValue(final String javaScript) { - return new ExpectedCondition<@Nullable Object>() { + return new ExpectedCondition<>() { private @Nullable WebDriverException error; @Override diff --git a/java/src/org/openqa/selenium/support/ui/FluentWait.java b/java/src/org/openqa/selenium/support/ui/FluentWait.java index 054d480438ab0..ddb30af6d2720 100644 --- a/java/src/org/openqa/selenium/support/ui/FluentWait.java +++ b/java/src/org/openqa/selenium/support/ui/FluentWait.java @@ -29,6 +29,7 @@ import java.util.List; import java.util.function.Function; import java.util.function.Supplier; +import org.jspecify.annotations.NonNull; import org.jspecify.annotations.Nullable; import org.openqa.selenium.TimeoutException; import org.openqa.selenium.WebDriverException; @@ -200,7 +201,7 @@ public FluentWait ignoring( * @throws TimeoutException If the timeout expires. */ @Override - public V until(Function isTrue) { + public @NonNull V until(Function isTrue) { Instant end = clock.instant().plus(timeout); Throwable lastException; diff --git a/java/src/org/openqa/selenium/support/ui/Wait.java b/java/src/org/openqa/selenium/support/ui/Wait.java index 11cd681b3d14c..35b8a5bf456e6 100644 --- a/java/src/org/openqa/selenium/support/ui/Wait.java +++ b/java/src/org/openqa/selenium/support/ui/Wait.java @@ -18,6 +18,8 @@ package org.openqa.selenium.support.ui; import java.util.function.Function; +import org.jspecify.annotations.NonNull; +import org.jspecify.annotations.Nullable; /** * A generic interface for waiting until a condition is true or not null. The condition may take a @@ -36,9 +38,9 @@ public interface Wait { * implementor may throw whatever is idiomatic for a given test infrastructure (e.g. JUnit4 would * throw {@link AssertionError}). * - * @param the return type of the method, which must not be Void + * @param the return type of the method, which must not be Void * @param isTrue the parameter to pass to the {@link ExpectedCondition} * @return truthy value from the isTrue condition */ - T until(Function isTrue); + @NonNull V until(Function isTrue); } From 97af034cc1015904eb8492d8b50f968bb6725e26 Mon Sep 17 00:00:00 2001 From: Andrei Solntsev Date: Sun, 1 Mar 2026 21:38:41 +0200 Subject: [PATCH 33/67] [java] Fluent setters in few classes like `PrintOptions` etc. (#17148) * make setter fluent: ChromiumNetworkConditions This style allows calling multiple setters in a row, thus making code shorter/readable. * make setter fluent: ErrorHandler This style allows calling multiple setters in a row, thus making code shorter/readable. * make setter fluent: LoggingOptions * make setter fluent: SeleniumManagerOutput * make setter fluent: PrintOptions * make setter fluent: Preferences * make setter fluent: FirefoxProfile * make setter fluent: DesiredCapabilities --- .../chromium/AddHasNetworkConditions.java | 34 ++++++------ .../chromium/ChromiumNetworkConditions.java | 16 ++++-- .../selenium/firefox/FirefoxOptions.java | 15 +++++- .../selenium/firefox/FirefoxProfile.java | 12 +++-- .../openqa/selenium/firefox/Preferences.java | 49 +++-------------- .../selenium/grid/log/LoggingOptions.java | 3 +- .../manager/SeleniumManagerOutput.java | 6 ++- .../openqa/selenium/print/PrintOptions.java | 24 ++++++--- .../selenium/remote/DesiredCapabilities.java | 12 +++-- .../openqa/selenium/remote/ErrorHandler.java | 3 +- .../org/openqa/selenium/PrintPageTest.java | 14 +++-- .../openqa/selenium/WebScriptExecuteTest.java | 5 +- .../browsingcontext/BrowsingContextTest.java | 4 +- .../chrome/ChromeDriverFunctionalTest.java | 14 +++-- .../edge/EdgeDriverFunctionalTest.java | 6 +-- .../firefox/FirefoxDriverPreferenceTest.java | 7 +-- .../selenium/firefox/FirefoxDriverTest.java | 2 +- .../selenium/firefox/FirefoxOptionsTest.java | 32 +++++++---- .../selenium/firefox/PreferencesTest.java | 18 ++++++- .../openqa/selenium/print/PageSizeTest.java | 29 +++------- .../selenium/print/PrintOptionsTest.java | 23 ++++---- .../remote/DesiredCapabilitiesTest.java | 53 ++++++++++++++++++- .../selenium/remote/ErrorHandlerTest.java | 9 +--- 23 files changed, 220 insertions(+), 170 deletions(-) diff --git a/java/src/org/openqa/selenium/chromium/AddHasNetworkConditions.java b/java/src/org/openqa/selenium/chromium/AddHasNetworkConditions.java index 07186b77e1f11..cb976cb2aa87d 100644 --- a/java/src/org/openqa/selenium/chromium/AddHasNetworkConditions.java +++ b/java/src/org/openqa/selenium/chromium/AddHasNetworkConditions.java @@ -17,7 +17,12 @@ package org.openqa.selenium.chromium; +import static java.util.Objects.requireNonNull; import static org.openqa.selenium.chromium.ChromiumDriver.IS_CHROMIUM_BROWSER; +import static org.openqa.selenium.chromium.ChromiumNetworkConditions.DOWNLOAD_THROUGHPUT; +import static org.openqa.selenium.chromium.ChromiumNetworkConditions.LATENCY; +import static org.openqa.selenium.chromium.ChromiumNetworkConditions.OFFLINE; +import static org.openqa.selenium.chromium.ChromiumNetworkConditions.UPLOAD_THROUGHPUT; import com.google.auto.service.AutoService; import java.time.Duration; @@ -73,19 +78,14 @@ public HasNetworkConditions getImplementation( public ChromiumNetworkConditions getNetworkConditions() { @SuppressWarnings("unchecked") Map result = - (Map) executeMethod.execute(GET_NETWORK_CONDITIONS, null); - ChromiumNetworkConditions networkConditions = new ChromiumNetworkConditions(); - networkConditions.setOffline( - (Boolean) result.getOrDefault(ChromiumNetworkConditions.OFFLINE, false)); - networkConditions.setLatency( - Duration.ofMillis((Long) result.getOrDefault(ChromiumNetworkConditions.LATENCY, 0))); - networkConditions.setDownloadThroughput( - ((Number) result.getOrDefault(ChromiumNetworkConditions.DOWNLOAD_THROUGHPUT, -1)) - .intValue()); - networkConditions.setUploadThroughput( - ((Number) result.getOrDefault(ChromiumNetworkConditions.UPLOAD_THROUGHPUT, -1)) - .intValue()); - return networkConditions; + (Map) + requireNonNull(executeMethod.execute(GET_NETWORK_CONDITIONS, null)); + return new ChromiumNetworkConditions() + .setOffline((Boolean) result.getOrDefault(OFFLINE, false)) + .setLatency(Duration.ofMillis((Long) result.getOrDefault(LATENCY, 0))) + .setDownloadThroughput( + ((Number) result.getOrDefault(DOWNLOAD_THROUGHPUT, -1)).intValue()) + .setUploadThroughput(((Number) result.getOrDefault(UPLOAD_THROUGHPUT, -1)).intValue()); } @Override @@ -94,13 +94,13 @@ public void setNetworkConditions(ChromiumNetworkConditions networkConditions) { Map conditions = Map.of( - ChromiumNetworkConditions.OFFLINE, + OFFLINE, networkConditions.getOffline(), - ChromiumNetworkConditions.LATENCY, + LATENCY, networkConditions.getLatency().toMillis(), - ChromiumNetworkConditions.DOWNLOAD_THROUGHPUT, + DOWNLOAD_THROUGHPUT, networkConditions.getDownloadThroughput(), - ChromiumNetworkConditions.UPLOAD_THROUGHPUT, + UPLOAD_THROUGHPUT, networkConditions.getUploadThroughput()); executeMethod.execute(SET_NETWORK_CONDITIONS, Map.of("network_conditions", conditions)); } diff --git a/java/src/org/openqa/selenium/chromium/ChromiumNetworkConditions.java b/java/src/org/openqa/selenium/chromium/ChromiumNetworkConditions.java index d7f81229177e3..b14185dd68440 100644 --- a/java/src/org/openqa/selenium/chromium/ChromiumNetworkConditions.java +++ b/java/src/org/openqa/selenium/chromium/ChromiumNetworkConditions.java @@ -31,6 +31,10 @@ public class ChromiumNetworkConditions { private int downloadThroughput = -1; private int uploadThroughput = -1; + public static ChromiumNetworkConditions withLatency(Duration latency) { + return new ChromiumNetworkConditions().setLatency(latency); + } + /** * @return whether network is simulated to be offline. */ @@ -43,8 +47,9 @@ public boolean getOffline() { * * @param offline when set to true, network is simulated to be offline. */ - public void setOffline(boolean offline) { + public ChromiumNetworkConditions setOffline(boolean offline) { this.offline = offline; + return this; } /** @@ -61,8 +66,9 @@ public Duration getLatency() { * * @param latency amount of latency, typically a Duration of milliseconds. */ - public void setLatency(Duration latency) { + public ChromiumNetworkConditions setLatency(Duration latency) { this.latency = latency; + return this; } /** @@ -79,8 +85,9 @@ public int getDownloadThroughput() { * * @param downloadThroughput throughput in kb/second */ - public void setDownloadThroughput(int downloadThroughput) { + public ChromiumNetworkConditions setDownloadThroughput(int downloadThroughput) { this.downloadThroughput = downloadThroughput; + return this; } /** @@ -97,7 +104,8 @@ public int getUploadThroughput() { * * @param uploadThroughput throughput in kb/second */ - public void setUploadThroughput(int uploadThroughput) { + public ChromiumNetworkConditions setUploadThroughput(int uploadThroughput) { this.uploadThroughput = uploadThroughput; + return this; } } diff --git a/java/src/org/openqa/selenium/firefox/FirefoxOptions.java b/java/src/org/openqa/selenium/firefox/FirefoxOptions.java index c4a817a4a905e..6950661874cc6 100644 --- a/java/src/org/openqa/selenium/firefox/FirefoxOptions.java +++ b/java/src/org/openqa/selenium/firefox/FirefoxOptions.java @@ -67,6 +67,11 @@ public FirefoxOptions() { addPreference("remote.active-protocols", 1); } + public FirefoxOptions(FirefoxProfile profile) { + this(); + setProfile(profile); + } + public FirefoxOptions(Capabilities source) { // We need to initialize all our own fields before calling. this(); @@ -176,7 +181,7 @@ public FirefoxProfile getProfile() { } } - public FirefoxOptions setProfile(FirefoxProfile profile) { + public final FirefoxOptions setProfile(FirefoxProfile profile) { Require.nonNull("Profile", profile); try { @@ -222,6 +227,14 @@ public FirefoxOptions addPreference(String key, Object value) { return setFirefoxOption(Keys.PREFS, Collections.unmodifiableMap(newPrefs)); } + Map prefs() { + return getOption(Keys.PREFS); + } + + String profile() { + return getOption(Keys.PROFILE); + } + public FirefoxOptions setLogLevel(FirefoxDriverLogLevel logLevel) { Require.nonNull("Log level", logLevel); return setFirefoxOption(Keys.LOG, logLevel.toJson()); diff --git a/java/src/org/openqa/selenium/firefox/FirefoxProfile.java b/java/src/org/openqa/selenium/firefox/FirefoxProfile.java index 4e92f1b267f14..f39b5e1fa33c9 100644 --- a/java/src/org/openqa/selenium/firefox/FirefoxProfile.java +++ b/java/src/org/openqa/selenium/firefox/FirefoxProfile.java @@ -177,8 +177,9 @@ private String deriveExtensionName(String originalName) { return name; } - public void setPreference(String key, Object value) { + public FirefoxProfile setPreference(String key, Object value) { additionalPrefs.setPreference(key, value); + return this; } protected Preferences getAdditionalPreferences() { @@ -256,8 +257,9 @@ public boolean shouldLoadNoFocusLib() { * * @param loadNoFocusLib Whether to always load the no focus library. */ - public void setAlwaysLoadNoFocusLib(boolean loadNoFocusLib) { + public FirefoxProfile setAlwaysLoadNoFocusLib(boolean loadNoFocusLib) { this.loadNoFocusLib = loadNoFocusLib; + return this; } /** @@ -266,8 +268,9 @@ public void setAlwaysLoadNoFocusLib(boolean loadNoFocusLib) { * * @param acceptUntrustedSsl Whether untrusted SSL certificates should be accepted. */ - public void setAcceptUntrustedCertificates(boolean acceptUntrustedSsl) { + public FirefoxProfile setAcceptUntrustedCertificates(boolean acceptUntrustedSsl) { this.acceptUntrustedCerts = acceptUntrustedSsl; + return this; } /** @@ -284,8 +287,9 @@ public void setAcceptUntrustedCertificates(boolean acceptUntrustedSsl) { * * @param untrustedIssuer whether to assume untrusted issuer or not. */ - public void setAssumeUntrustedCertificateIssuer(boolean untrustedIssuer) { + public FirefoxProfile setAssumeUntrustedCertificateIssuer(boolean untrustedIssuer) { this.untrustedCertIssuer = untrustedIssuer; + return this; } public void clean(File profileDir) { diff --git a/java/src/org/openqa/selenium/firefox/Preferences.java b/java/src/org/openqa/selenium/firefox/Preferences.java index f42b8e31455c0..51f272b323962 100644 --- a/java/src/org/openqa/selenium/firefox/Preferences.java +++ b/java/src/org/openqa/selenium/firefox/Preferences.java @@ -17,6 +17,7 @@ package org.openqa.selenium.firefox; +import static java.util.Collections.unmodifiableMap; import static org.openqa.selenium.json.Json.MAP_TYPE; import java.io.BufferedReader; @@ -32,19 +33,10 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; import org.openqa.selenium.WebDriverException; -import org.openqa.selenium.internal.Require; import org.openqa.selenium.json.Json; class Preferences { - /** - * The maximum amount of time scripts should be permitted to run. The user may increase this - * timeout, but may not set it below the default value. - */ - private static final String MAX_SCRIPT_RUN_TIME_KEY = "dom.max_script_run_time"; - - private static final int DEFAULT_MAX_SCRIPT_RUN_TIME = 30; - /** * This pattern is used to parse preferences in user.js. It is intended to match all preference * lines in the format generated by Firefox; it won't necessarily match all possible lines that @@ -56,7 +48,6 @@ class Preferences { private static final Pattern PREFERENCE_PATTERN = Pattern.compile("user_pref\\(\"([^\"]+)\", (\"?.+?\"?)\\);"); - private final Map immutablePrefs = new HashMap<>(); private final Map allPrefs = new HashMap<>(); public Preferences() {} @@ -87,6 +78,10 @@ public Preferences(Reader defaults, Reader reader) { } } + Map asMap() { + return unmodifiableMap(allPrefs); + } + private void readUserPrefs(File userPrefs) { try (Reader reader = Files.newBufferedReader(userPrefs.toPath(), Charset.defaultCharset())) { readPreferences(reader); @@ -111,7 +106,6 @@ private void readDefaultPreferences(Reader defaultsReader) { value = ((Long) value).intValue(); } setPreference(key, value); - immutablePrefs.put(key, value); }); Map mutable = (Map) map.get("mutable"); @@ -122,7 +116,7 @@ private void readDefaultPreferences(Reader defaultsReader) { } } - public void setPreference(String key, Object value) { + public Preferences setPreference(String key, Object value) { if (value instanceof String) { if (isStringified((String) value)) { throw new IllegalArgumentException( @@ -134,6 +128,7 @@ public void setPreference(String key, Object value) { } else { allPrefs.put(key, value); } + return this; } private void readPreferences(Reader reader) throws IOException { @@ -197,34 +192,4 @@ private boolean isStringified(String value) { // the first character == " and the last character == " return value.startsWith("\"") && value.endsWith("\""); } - - private void checkPreference(String key, Object value) { - Require.nonNull("Key", key); - Require.nonNull("Value", value); - Require.stateCondition( - !immutablePrefs.containsKey(key) - || (immutablePrefs.containsKey(key) && value.equals(immutablePrefs.get(key))), - "Preference %s may not be overridden: frozen value=%s, requested value=%s", - key, - immutablePrefs.get(key), - value); - if (MAX_SCRIPT_RUN_TIME_KEY.equals(key)) { - int n; - if (value instanceof String) { - n = Integer.parseInt((String) value); - } else if (value instanceof Integer) { - n = (Integer) value; - } else { - throw new IllegalStateException( - String.format( - "%s value must be a number: %s", - MAX_SCRIPT_RUN_TIME_KEY, value.getClass().getName())); - } - Require.stateCondition( - n == 0 || n >= DEFAULT_MAX_SCRIPT_RUN_TIME, - "%s must be == 0 || >= %s", - MAX_SCRIPT_RUN_TIME_KEY, - DEFAULT_MAX_SCRIPT_RUN_TIME); - } - } } diff --git a/java/src/org/openqa/selenium/grid/log/LoggingOptions.java b/java/src/org/openqa/selenium/grid/log/LoggingOptions.java index 7151c63dec0fa..20cb5a770e68a 100644 --- a/java/src/org/openqa/selenium/grid/log/LoggingOptions.java +++ b/java/src/org/openqa/selenium/grid/log/LoggingOptions.java @@ -83,7 +83,7 @@ public String getLogEncoding() { return config.get(LOGGING_SECTION, "log-encoding").orElse(null); } - public void setLoggingLevel() { + public LoggingOptions setLoggingLevel() { String configLevel = config.get(LOGGING_SECTION, "log-level").orElse(DEFAULT_LOG_LEVEL); if (Debug.isDebugAll()) { System.err.println( @@ -105,6 +105,7 @@ public void setLoggingLevel() { + DEFAULT_LOG_LEVELS) .printStackTrace(); } + return this; } public Tracer getTracer() { diff --git a/java/src/org/openqa/selenium/manager/SeleniumManagerOutput.java b/java/src/org/openqa/selenium/manager/SeleniumManagerOutput.java index 3c408679819ba..80629dde4e676 100644 --- a/java/src/org/openqa/selenium/manager/SeleniumManagerOutput.java +++ b/java/src/org/openqa/selenium/manager/SeleniumManagerOutput.java @@ -32,16 +32,18 @@ public List getLogs() { return logs; } - public void setLogs(List logs) { + public SeleniumManagerOutput setLogs(List logs) { this.logs = logs; + return this; } public Result getResult() { return result; } - public void setResult(Result result) { + public SeleniumManagerOutput setResult(Result result) { this.result = result; + return this; } public static class Log { diff --git a/java/src/org/openqa/selenium/print/PrintOptions.java b/java/src/org/openqa/selenium/print/PrintOptions.java index 602c9c41f9a53..1e1b7a86da3d5 100644 --- a/java/src/org/openqa/selenium/print/PrintOptions.java +++ b/java/src/org/openqa/selenium/print/PrintOptions.java @@ -55,15 +55,16 @@ public Orientation getOrientation() { return this.orientation; } - public void setOrientation(Orientation orientation) { + public PrintOptions setOrientation(Orientation orientation) { this.orientation = Require.nonNull("orientation", orientation); + return this; } public String @Nullable [] getPageRanges() { return this.pageRanges; } - public void setPageRanges(String firstRange, String... ranges) { + public PrintOptions setPageRanges(String firstRange, String... ranges) { Require.nonNull("pageRanges", firstRange); this.pageRanges = new String[ranges.length + 1]; // Need to add all ranges and the initial range too. @@ -71,26 +72,30 @@ public void setPageRanges(String firstRange, String... ranges) { this.pageRanges[0] = firstRange; if (ranges.length > 0) System.arraycopy(ranges, 0, this.pageRanges, 1, ranges.length); + return this; } - public void setPageRanges(List ranges) { + public PrintOptions setPageRanges(List ranges) { this.pageRanges = new String[ranges.size()]; this.pageRanges = ranges.toArray(this.pageRanges); + return this; } - public void setBackground(boolean background) { + public PrintOptions setBackground(boolean background) { this.background = Require.nonNull("background", background); + return this; } public boolean getBackground() { return this.background; } - public void setScale(double scale) { + public PrintOptions setScale(double scale) { if (scale < 0.1 || scale > 2) { throw new IllegalArgumentException("Scale value should be between 0.1 and 2"); } this.scale = scale; + return this; } public double getScale() { @@ -101,16 +106,19 @@ public boolean getShrinkToFit() { return this.shrinkToFit; } - public void setShrinkToFit(boolean value) { + public PrintOptions setShrinkToFit(boolean value) { this.shrinkToFit = Require.nonNull("value", value); + return this; } - public void setPageSize(PageSize pageSize) { + public PrintOptions setPageSize(PageSize pageSize) { this.pageSize = Require.nonNull("pageSize", pageSize); + return this; } - public void setPageMargin(PageMargin margin) { + public PrintOptions setPageMargin(PageMargin margin) { this.pageMargin = Require.nonNull("margin", margin); + return this; } public PageSize getPageSize() { diff --git a/java/src/org/openqa/selenium/remote/DesiredCapabilities.java b/java/src/org/openqa/selenium/remote/DesiredCapabilities.java index 501eef6965fb3..518ceb80d8aae 100644 --- a/java/src/org/openqa/selenium/remote/DesiredCapabilities.java +++ b/java/src/org/openqa/selenium/remote/DesiredCapabilities.java @@ -55,16 +55,19 @@ public DesiredCapabilities(Capabilities... others) { } } - public void setBrowserName(String browserName) { + public DesiredCapabilities setBrowserName(String browserName) { setCapability(BROWSER_NAME, browserName); + return this; } - public void setVersion(String version) { + public DesiredCapabilities setVersion(String version) { setCapability(BROWSER_VERSION, version); + return this; } - public void setPlatform(Platform platform) { + public DesiredCapabilities setPlatform(Platform platform) { setCapability(PLATFORM_NAME, platform); + return this; } public boolean acceptInsecureCerts() { @@ -79,8 +82,9 @@ public boolean acceptInsecureCerts() { return true; } - public void setAcceptInsecureCerts(boolean acceptInsecureCerts) { + public DesiredCapabilities setAcceptInsecureCerts(boolean acceptInsecureCerts) { setCapability(ACCEPT_INSECURE_CERTS, acceptInsecureCerts); + return this; } /** diff --git a/java/src/org/openqa/selenium/remote/ErrorHandler.java b/java/src/org/openqa/selenium/remote/ErrorHandler.java index 59a52ddb389ca..63d658824d3b0 100644 --- a/java/src/org/openqa/selenium/remote/ErrorHandler.java +++ b/java/src/org/openqa/selenium/remote/ErrorHandler.java @@ -74,8 +74,9 @@ public boolean isIncludeServerErrors() { return includeServerErrors; } - public void setIncludeServerErrors(boolean includeServerErrors) { + public ErrorHandler setIncludeServerErrors(boolean includeServerErrors) { this.includeServerErrors = includeServerErrors; + return this; } @SuppressWarnings("unchecked") diff --git a/java/test/org/openqa/selenium/PrintPageTest.java b/java/test/org/openqa/selenium/PrintPageTest.java index e4db436a4db48..cad3ccce3b2b7 100644 --- a/java/test/org/openqa/selenium/PrintPageTest.java +++ b/java/test/org/openqa/selenium/PrintPageTest.java @@ -53,8 +53,7 @@ public void canPrintPage() { @Test @Ignore(value = CHROME) public void canPrintTwoPages() { - PrintOptions printOptions = new PrintOptions(); - printOptions.setPageRanges("1-2"); + PrintOptions printOptions = new PrintOptions().setPageRanges("1-2"); Pdf pdf = printer.print(printOptions); assertThat(pdf.getContent()).contains(MAGIC_STRING); @@ -64,12 +63,11 @@ public void canPrintTwoPages() { @Test @Ignore(value = CHROME) public void canPrintWithValidParams() { - PrintOptions printOptions = new PrintOptions(); - PageSize pageSize = new PageSize(); - - printOptions.setPageRanges("1-2"); - printOptions.setOrientation(PrintOptions.Orientation.LANDSCAPE); - printOptions.setPageSize(pageSize); + PrintOptions printOptions = + new PrintOptions() + .setPageRanges("1-2") + .setOrientation(PrintOptions.Orientation.LANDSCAPE) + .setPageSize(new PageSize()); Pdf pdf = printer.print(printOptions); assertThat(pdf.getContent()).contains(MAGIC_STRING); diff --git a/java/test/org/openqa/selenium/WebScriptExecuteTest.java b/java/test/org/openqa/selenium/WebScriptExecuteTest.java index bbce383072c02..7b37a29b98aa7 100644 --- a/java/test/org/openqa/selenium/WebScriptExecuteTest.java +++ b/java/test/org/openqa/selenium/WebScriptExecuteTest.java @@ -280,9 +280,6 @@ void canExecuteScriptWithMapArgument() { @Test void canExecuteScriptWithObjectArgument() { - - PrintOptions options = new PrintOptions(); - RemoteValue value = ((RemoteWebDriver) driver) .script() @@ -293,7 +290,7 @@ void canExecuteScriptWithObjectArgument() { + " Object.prototype.toString.call(arg));\n" + " return arg;\n" + " }}", - options); + new PrintOptions()); assertThat(value.getType()).isEqualTo("object"); diff --git a/java/test/org/openqa/selenium/bidi/browsingcontext/BrowsingContextTest.java b/java/test/org/openqa/selenium/bidi/browsingcontext/BrowsingContextTest.java index 2dd10dc897a2e..eff6f0f3c396e 100644 --- a/java/test/org/openqa/selenium/bidi/browsingcontext/BrowsingContextTest.java +++ b/java/test/org/openqa/selenium/bidi/browsingcontext/BrowsingContextTest.java @@ -523,11 +523,9 @@ void canSetViewportWithDevicePixelRatio() { @NeedsFreshDriver void canPrintPage() { BrowsingContext browsingContext = new BrowsingContext(driver, driver.getWindowHandle()); - driver.get(appServer.whereIs("formPage.html")); - PrintOptions printOptions = new PrintOptions(); - String printPage = browsingContext.print(printOptions); + String printPage = browsingContext.print(new PrintOptions()); assertThat(printPage).isNotEmpty(); // Comparing expected PDF is a hard problem. diff --git a/java/test/org/openqa/selenium/chrome/ChromeDriverFunctionalTest.java b/java/test/org/openqa/selenium/chrome/ChromeDriverFunctionalTest.java index c7bb9418a03df..b51057a2d4959 100644 --- a/java/test/org/openqa/selenium/chrome/ChromeDriverFunctionalTest.java +++ b/java/test/org/openqa/selenium/chrome/ChromeDriverFunctionalTest.java @@ -17,9 +17,11 @@ package org.openqa.selenium.chrome; +import static java.time.Duration.ofMillis; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assumptions.assumeThat; +import static org.openqa.selenium.chromium.ChromiumNetworkConditions.withLatency; import static org.openqa.selenium.testing.drivers.Browser.CHROME; import java.time.Duration; @@ -33,7 +35,6 @@ import org.openqa.selenium.SessionNotCreatedException; import org.openqa.selenium.WebDriver; import org.openqa.selenium.WebDriverException; -import org.openqa.selenium.chromium.ChromiumNetworkConditions; import org.openqa.selenium.chromium.HasCasting; import org.openqa.selenium.chromium.HasCdp; import org.openqa.selenium.chromium.HasNetworkConditions; @@ -64,10 +65,9 @@ public void builderGeneratesDefaultChromeOptions() { @NoDriverBeforeTest public void builderOverridesDefaultChromeOptions() { ChromeOptions options = (ChromeOptions) CHROME.getCapabilities(); - options.setImplicitWaitTimeout(Duration.ofMillis(1)); + options.setImplicitWaitTimeout(ofMillis(1)); localDriver = ChromeDriver.builder().oneOf(options).build(); - assertThat(localDriver.manage().timeouts().getImplicitWaitTimeout()) - .isEqualTo(Duration.ofMillis(1)); + assertThat(localDriver.manage().timeouts().getImplicitWaitTimeout()).isEqualTo(ofMillis(1)); } @Test @@ -171,11 +171,9 @@ public void canCastOnDesktop() throws InterruptedException { void canManageNetworkConditions() { HasNetworkConditions conditions = (HasNetworkConditions) driver; - ChromiumNetworkConditions networkConditions = new ChromiumNetworkConditions(); - networkConditions.setLatency(Duration.ofMillis(200)); + conditions.setNetworkConditions(withLatency(ofMillis(200))); - conditions.setNetworkConditions(networkConditions); - assertThat(conditions.getNetworkConditions().getLatency()).isEqualTo(Duration.ofMillis(200)); + assertThat(conditions.getNetworkConditions().getLatency()).isEqualTo(ofMillis(200)); conditions.deleteNetworkConditions(); diff --git a/java/test/org/openqa/selenium/edge/EdgeDriverFunctionalTest.java b/java/test/org/openqa/selenium/edge/EdgeDriverFunctionalTest.java index df8c87af2e490..97f61f2f980eb 100644 --- a/java/test/org/openqa/selenium/edge/EdgeDriverFunctionalTest.java +++ b/java/test/org/openqa/selenium/edge/EdgeDriverFunctionalTest.java @@ -21,6 +21,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assumptions.assumeThat; +import static org.openqa.selenium.chromium.ChromiumNetworkConditions.withLatency; import static org.openqa.selenium.testing.drivers.Browser.EDGE; import java.time.Duration; @@ -33,7 +34,6 @@ import org.openqa.selenium.SessionNotCreatedException; import org.openqa.selenium.WebDriver; import org.openqa.selenium.WebDriverException; -import org.openqa.selenium.chromium.ChromiumNetworkConditions; import org.openqa.selenium.chromium.HasCasting; import org.openqa.selenium.chromium.HasCdp; import org.openqa.selenium.chromium.HasNetworkConditions; @@ -167,10 +167,8 @@ void canCast() throws InterruptedException { void canManageNetworkConditions() { HasNetworkConditions conditions = (HasNetworkConditions) driver; - ChromiumNetworkConditions networkConditions = new ChromiumNetworkConditions(); - networkConditions.setLatency(Duration.ofMillis(200)); + conditions.setNetworkConditions(withLatency(Duration.ofMillis(200))); - conditions.setNetworkConditions(networkConditions); assertThat(conditions.getNetworkConditions().getLatency()).isEqualTo(Duration.ofMillis(200)); conditions.deleteNetworkConditions(); diff --git a/java/test/org/openqa/selenium/firefox/FirefoxDriverPreferenceTest.java b/java/test/org/openqa/selenium/firefox/FirefoxDriverPreferenceTest.java index 323c2cc0b0eae..710389c5a961b 100644 --- a/java/test/org/openqa/selenium/firefox/FirefoxDriverPreferenceTest.java +++ b/java/test/org/openqa/selenium/firefox/FirefoxDriverPreferenceTest.java @@ -39,9 +39,10 @@ private FirefoxOptions getDefaultOptions() { @Test @NoDriverBeforeTest public void canStartDriverWithSpecifiedProfile() { - FirefoxProfile profile = new FirefoxProfile(); - profile.setPreference("browser.startup.page", 1); - profile.setPreference("browser.startup.homepage", pages.xhtmlTestPage); + FirefoxProfile profile = + new FirefoxProfile() + .setPreference("browser.startup.page", 1) + .setPreference("browser.startup.homepage", pages.xhtmlTestPage); localDriver = new WebDriverBuilder().get(getDefaultOptions().setProfile(profile)); diff --git a/java/test/org/openqa/selenium/firefox/FirefoxDriverTest.java b/java/test/org/openqa/selenium/firefox/FirefoxDriverTest.java index 335540082f1aa..832ce869bd828 100644 --- a/java/test/org/openqa/selenium/firefox/FirefoxDriverTest.java +++ b/java/test/org/openqa/selenium/firefox/FirefoxDriverTest.java @@ -141,7 +141,7 @@ public void shouldBeAbleToStartANamedProfile() { FirefoxProfile profile = new ProfilesIni().getProfile("default"); assumeTrue(profile != null); - localDriver = new WebDriverBuilder().get(new FirefoxOptions().setProfile(profile)); + localDriver = new WebDriverBuilder().get(new FirefoxOptions(profile)); } @Test diff --git a/java/test/org/openqa/selenium/firefox/FirefoxOptionsTest.java b/java/test/org/openqa/selenium/firefox/FirefoxOptionsTest.java index 2cc683e50cc10..473b218ebba21 100644 --- a/java/test/org/openqa/selenium/firefox/FirefoxOptionsTest.java +++ b/java/test/org/openqa/selenium/firefox/FirefoxOptionsTest.java @@ -55,6 +55,26 @@ @Tag("UnitTests") class FirefoxOptionsTest { + @Test + void defaultConstructor() { + FirefoxOptions options = new FirefoxOptions(); + + assertThat(options.getBrowserName()).isEqualTo("firefox"); + assertThat(options.getCapability(ACCEPT_INSECURE_CERTS)).isEqualTo(true); + assertThat(options.prefs()).containsExactly(Map.entry("remote.active-protocols", 1)); + assertThat(options.profile()).isNull(); + } + + @Test + void constructorWithProfile() { + FirefoxOptions options = new FirefoxOptions(new FirefoxProfile().setPreference("foo", "bar")); + + assertThat(options.getBrowserName()).isEqualTo("firefox"); + assertThat(options.getCapability(ACCEPT_INSECURE_CERTS)).isEqualTo(true); + assertThat(options.prefs()).containsExactly(Map.entry("remote.active-protocols", 1)); + assertThat(options.profile()).isBase64(); + } + @Test void canInitFirefoxOptionsWithCapabilities() { FirefoxOptions options = @@ -138,8 +158,7 @@ void shouldGetStringPreferencesFromGetProfile() { String key = "browser.startup.homepage"; String value = "about:robots"; - FirefoxProfile profile = new FirefoxProfile(); - profile.setPreference(key, value); + FirefoxProfile profile = new FirefoxProfile().setPreference(key, value); FirefoxOptions options = new FirefoxOptions(); options.setProfile(profile); @@ -311,11 +330,7 @@ void mergingOptionsWithMutableCapabilities() { String key = "browser.startup.homepage"; String value = "about:robots"; - FirefoxProfile profile = new FirefoxProfile(); - profile.setPreference(key, value); - - options.setProfile(profile); - + options.setProfile(new FirefoxProfile().setPreference(key, value)); options.setLogLevel(DEBUG); File binary = TestUtilities.createTmpFile("binary"); @@ -383,8 +398,7 @@ void mergingOptionsWithOptionsAsMutableCapabilities() throws IOException { String key = "browser.startup.homepage"; String value = "about:robots"; - FirefoxProfile profile = new FirefoxProfile(); - profile.setPreference(key, value); + FirefoxProfile profile = new FirefoxProfile().setPreference(key, value); File binary = TestUtilities.createTmpFile("binary"); diff --git a/java/test/org/openqa/selenium/firefox/PreferencesTest.java b/java/test/org/openqa/selenium/firefox/PreferencesTest.java index 55e835b05677e..5cdbae1a520c9 100644 --- a/java/test/org/openqa/selenium/firefox/PreferencesTest.java +++ b/java/test/org/openqa/selenium/firefox/PreferencesTest.java @@ -17,10 +17,12 @@ package org.openqa.selenium.firefox; +import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; import java.io.Reader; import java.io.StringReader; +import java.util.Map; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; @@ -120,7 +122,21 @@ void canOverrideAFrozenPreferenceWithTheFrozenValue() { preferences.setPreference("frozen.pref", true); - assertThat(preferences.getPreference("frozen.pref")).isEqualTo(true); + assertThat(preferences.asMap()).containsExactly(Map.entry("frozen.pref", true)); + } + + @Test + void canUseFluentSetter() { + Preferences preferences = + new Preferences() + .setPreference("one", true) + .setPreference("two", 2) + .setPreference("three", "3.0") + .setPreference("four", asList(1, 2, 3)); + + assertThat(preferences.asMap()) + .containsExactlyInAnyOrderEntriesOf( + Map.of("one", true, "two", 2, "three", "3.0", "four", asList(1, 2, 3))); } private boolean canSet(Preferences pref, String value) { diff --git a/java/test/org/openqa/selenium/print/PageSizeTest.java b/java/test/org/openqa/selenium/print/PageSizeTest.java index ec8240f73c309..822665dea299f 100644 --- a/java/test/org/openqa/selenium/print/PageSizeTest.java +++ b/java/test/org/openqa/selenium/print/PageSizeTest.java @@ -19,20 +19,12 @@ import static org.assertj.core.api.Assertions.assertThat; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; @Tag("UnitTests") class PageSizeTest { - private PrintOptions printOptions; - - @BeforeEach - void setUp() { - printOptions = new PrintOptions(); - } - @Test void setsDefaultHeightWidth() { PageSize pageSize = new PageSize(); @@ -42,30 +34,25 @@ void setsDefaultHeightWidth() { @Test void verifiesPageSizeA4() { - - printOptions.setPageSize(PageSize.ISO_A4); - assertThat(printOptions.getPageSize().getHeight()).isEqualTo(29.7); - assertThat(printOptions.getPageSize().getWidth()).isEqualTo(21.0); + assertThat(PageSize.ISO_A4.getHeight()).isEqualTo(29.7); + assertThat(PageSize.ISO_A4.getWidth()).isEqualTo(21.0); } @Test void verifiesPageSizeLegal() { - printOptions.setPageSize(PageSize.US_LEGAL); - assertThat(printOptions.getPageSize().getHeight()).isEqualTo(35.56); - assertThat(printOptions.getPageSize().getWidth()).isEqualTo(21.59); + assertThat(PageSize.US_LEGAL.getHeight()).isEqualTo(35.56); + assertThat(PageSize.US_LEGAL.getWidth()).isEqualTo(21.59); } @Test void verifiesPageSizeLetter() { - printOptions.setPageSize(PageSize.US_LETTER); - assertThat(printOptions.getPageSize().getHeight()).isEqualTo(27.94); - assertThat(printOptions.getPageSize().getWidth()).isEqualTo(21.59); + assertThat(PageSize.US_LETTER.getHeight()).isEqualTo(27.94); + assertThat(PageSize.US_LETTER.getWidth()).isEqualTo(21.59); } @Test void verifiesPageSizeTabloid() { - printOptions.setPageSize(PageSize.ANSI_TABLOID); - assertThat(printOptions.getPageSize().getHeight()).isEqualTo(43.18); - assertThat(printOptions.getPageSize().getWidth()).isEqualTo(27.94); + assertThat(PageSize.ANSI_TABLOID.getHeight()).isEqualTo(43.18); + assertThat(PageSize.ANSI_TABLOID.getWidth()).isEqualTo(27.94); } } diff --git a/java/test/org/openqa/selenium/print/PrintOptionsTest.java b/java/test/org/openqa/selenium/print/PrintOptionsTest.java index beb0753b64eca..d7313b1f2c7a3 100644 --- a/java/test/org/openqa/selenium/print/PrintOptionsTest.java +++ b/java/test/org/openqa/selenium/print/PrintOptionsTest.java @@ -39,11 +39,8 @@ void setsDefaultValues() { @Test void setsValuesAsPassed() { - PrintOptions printOptions = new PrintOptions(); - - printOptions.setBackground(true); - printOptions.setScale(1.5); - printOptions.setShrinkToFit(false); + PrintOptions printOptions = + new PrintOptions().setBackground(true).setScale(1.5).setShrinkToFit(false); assertThat(printOptions.getScale()).isEqualTo(1.5); assertThat(printOptions.getBackground()).isTrue(); @@ -66,14 +63,12 @@ void toMapContainsProperKey() { Map map = printOptions.toMap(); assertThat(map).hasSize(7); - assertThat(map).containsKey("page"); - assertThat(map).containsKey("orientation"); - assertThat(map).containsKey("scale"); - assertThat(map).containsKey("shrinkToFit"); - assertThat(map).containsKey("background"); - assertThat(map).containsKey("pageRanges"); - assertThat(map).containsKey("margin"); - assertThat(map.get("margin")).asInstanceOf(MAP).hasSize(4); - assertThat(map.get("page")).asInstanceOf(MAP).hasSize(2); + assertThat(map) + .containsOnlyKeys( + "page", "orientation", "scale", "shrinkToFit", "background", "pageRanges", "margin"); + assertThat(map.get("margin")) + .asInstanceOf(MAP) + .containsOnlyKeys("top", "left", "bottom", "right"); + assertThat(map.get("page")).asInstanceOf(MAP).containsOnlyKeys("width", "height"); } } diff --git a/java/test/org/openqa/selenium/remote/DesiredCapabilitiesTest.java b/java/test/org/openqa/selenium/remote/DesiredCapabilitiesTest.java index 5f3187b85689c..45b805393fdcf 100644 --- a/java/test/org/openqa/selenium/remote/DesiredCapabilitiesTest.java +++ b/java/test/org/openqa/selenium/remote/DesiredCapabilitiesTest.java @@ -18,6 +18,12 @@ package org.openqa.selenium.remote; import static org.assertj.core.api.Assertions.assertThat; +import static org.openqa.selenium.Platform.ANDROID; +import static org.openqa.selenium.Platform.WINDOWS; +import static org.openqa.selenium.remote.CapabilityType.ACCEPT_INSECURE_CERTS; +import static org.openqa.selenium.remote.CapabilityType.BROWSER_NAME; +import static org.openqa.selenium.remote.CapabilityType.BROWSER_VERSION; +import static org.openqa.selenium.remote.CapabilityType.PLATFORM_NAME; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -31,10 +37,53 @@ @Tag("UnitTests") class DesiredCapabilitiesTest { + @Test + void defaultConstructor() { + DesiredCapabilities capabilities = new DesiredCapabilities(); + assertThat(capabilities.asMap()).isEmpty(); + } + + @Test + void constructorWithMap() { + DesiredCapabilities capabilities = new DesiredCapabilities(Map.of("foo", 8)); + assertThat(capabilities.asMap()).containsExactly(Map.entry("foo", 8)); + } + + @Test + void constructorWithBrowserVersion() { + DesiredCapabilities capabilities = new DesiredCapabilities("firefox", "2.0", WINDOWS); + assertThat(capabilities.asMap()) + .containsExactlyInAnyOrderEntriesOf( + Map.of( + BROWSER_VERSION, "2.0", + BROWSER_NAME, "firefox", + PLATFORM_NAME, WINDOWS)); + } + + @Test + void fluentSetters() { + DesiredCapabilities capabilities = + new DesiredCapabilities() + .setBrowserName("edge") + .setVersion("3.0") + .setPlatform(ANDROID) + .setAcceptInsecureCerts(true); + assertThat(capabilities.asMap()) + .containsExactlyInAnyOrderEntriesOf( + Map.of( + BROWSER_VERSION, + "3.0", + BROWSER_NAME, + "edge", + PLATFORM_NAME, + ANDROID, + ACCEPT_INSECURE_CERTS, + true)); + } + @Test void testAddingTheSameCapabilityToAMapTwiceShouldResultInOneEntry() { - Map> capabilitiesToDriver = - new ConcurrentHashMap<>(); + Map> capabilitiesToDriver = new ConcurrentHashMap<>(); capabilitiesToDriver.put(new FirefoxOptions(), WebDriver.class); capabilitiesToDriver.put(new FirefoxOptions(), RemoteWebDriver.class); diff --git a/java/test/org/openqa/selenium/remote/ErrorHandlerTest.java b/java/test/org/openqa/selenium/remote/ErrorHandlerTest.java index 4d375b5cc7583..7c51385f6118b 100644 --- a/java/test/org/openqa/selenium/remote/ErrorHandlerTest.java +++ b/java/test/org/openqa/selenium/remote/ErrorHandlerTest.java @@ -23,7 +23,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; import org.openqa.selenium.InvalidCookieDomainException; @@ -49,7 +48,7 @@ @Tag("UnitTests") @SuppressWarnings("removal") class ErrorHandlerTest { - private ErrorHandler handler; + private final ErrorHandler handler = new ErrorHandler().setIncludeServerErrors(true); private static void assertStackTracesEqual( StackTraceElement[] expected, StackTraceElement[] actual) { @@ -68,12 +67,6 @@ private static Map toMap(Object o) { return new Json().toType(rawJson, Map.class); } - @BeforeEach - public void setUp() { - handler = new ErrorHandler(); - handler.setIncludeServerErrors(true); - } - @Test void testShouldNotThrowIfResponseWasASuccess() { handler.throwIfResponseFailed(createResponse(ErrorCodes.SUCCESS), 100); From 70dfe39bdea97f984c79b6c3d6c6c3c74502da01 Mon Sep 17 00:00:00 2001 From: Nikolay Borisenko <22616990+nvborisenko@users.noreply.github.com> Date: Mon, 2 Mar 2026 12:46:51 +0300 Subject: [PATCH 34/67] [dotnet] [bidi] Add disposed guard (#17161) --- dotnet/src/webdriver/BiDi/BiDi.cs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/dotnet/src/webdriver/BiDi/BiDi.cs b/dotnet/src/webdriver/BiDi/BiDi.cs index e6933b933e660..f008d32a9394f 100644 --- a/dotnet/src/webdriver/BiDi/BiDi.cs +++ b/dotnet/src/webdriver/BiDi/BiDi.cs @@ -27,6 +27,7 @@ namespace OpenQA.Selenium.BiDi; public sealed class BiDi : IBiDi { private readonly ConcurrentDictionary _modules = new(); + private bool _disposed; private Broker Broker { get; set; } = null!; @@ -83,6 +84,13 @@ public Task EndAsync(EndOptions? options = null, CancellationToken ca public async ValueTask DisposeAsync() { + if (_disposed) + { + return; + } + + _disposed = true; + await Broker.DisposeAsync().ConfigureAwait(false); GC.SuppressFinalize(this); } From 90ab1d1c78048eacc3c69a259bf4c4d2fc67b4a9 Mon Sep 17 00:00:00 2001 From: JAYA DILEEP <114493600+seethinajayadileep@users.noreply.github.com> Date: Tue, 3 Mar 2026 02:21:35 +0530 Subject: [PATCH 35/67] [java] Enhance ScriptKey.toString() and mask script content in UnpinnedScriptKey (#17159) * Add toString() to ScriptKey for improved debugging readability * Avoid logging raw script content in UnpinnedScriptKey.toString() * [java] Add unit tests for ScriptKey and UnpinnedScriptKey toString() behavior * [java] Refine UnpinnedScriptKey.toString() and update corresponding tests --------- Co-authored-by: Corey Goldberg <1113081+cgoldberg@users.noreply.github.com> --- java/src/org/openqa/selenium/ScriptKey.java | 5 +++ .../openqa/selenium/UnpinnedScriptKey.java | 13 ++++++ .../org/openqa/selenium/ScriptKeyTest.java | 33 ++++++++++++++ .../selenium/UnpinnedScriptKeyTest.java | 45 +++++++++++++++++++ 4 files changed, 96 insertions(+) create mode 100644 java/test/org/openqa/selenium/ScriptKeyTest.java create mode 100644 java/test/org/openqa/selenium/UnpinnedScriptKeyTest.java diff --git a/java/src/org/openqa/selenium/ScriptKey.java b/java/src/org/openqa/selenium/ScriptKey.java index e5f61c226a6c3..8bcf53ac6ae04 100644 --- a/java/src/org/openqa/selenium/ScriptKey.java +++ b/java/src/org/openqa/selenium/ScriptKey.java @@ -33,6 +33,11 @@ public String getIdentifier() { return identifier; } + @Override + public String toString() { + return identifier; + } + @Override public boolean equals(@Nullable Object o) { if (!(o instanceof ScriptKey)) { diff --git a/java/src/org/openqa/selenium/UnpinnedScriptKey.java b/java/src/org/openqa/selenium/UnpinnedScriptKey.java index b61f443d5f5b8..f8eac4303963c 100644 --- a/java/src/org/openqa/selenium/UnpinnedScriptKey.java +++ b/java/src/org/openqa/selenium/UnpinnedScriptKey.java @@ -112,4 +112,17 @@ public boolean equals(@Nullable Object o) { public int hashCode() { return Objects.hash(super.hashCode(), script); } + + @Override + public String toString() { + // Avoid dumping raw JavaScript into logs: in UnpinnedScriptKey the identifier is the script. + return "UnpinnedScriptKey{" + + "scriptHash=" + + script.hashCode() + + ", scriptId=" + + Objects.toString(scriptId, "unset") + + ", length=" + + script.length() + + "}"; + } } diff --git a/java/test/org/openqa/selenium/ScriptKeyTest.java b/java/test/org/openqa/selenium/ScriptKeyTest.java new file mode 100644 index 0000000000000..1de043da69f03 --- /dev/null +++ b/java/test/org/openqa/selenium/ScriptKeyTest.java @@ -0,0 +1,33 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +package org.openqa.selenium; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; + +class ScriptKeyTest { + @Test + void hasToStringEqualToIdentifier() { + assertThat(new ScriptKey("xxx")).hasToString("xxx"); + } + + @Test + void hasToStringWorksForEmptyIdentifier() { + assertThat(new ScriptKey("")).hasToString(""); + } +} diff --git a/java/test/org/openqa/selenium/UnpinnedScriptKeyTest.java b/java/test/org/openqa/selenium/UnpinnedScriptKeyTest.java new file mode 100644 index 0000000000000..3f7bfc8d2545a --- /dev/null +++ b/java/test/org/openqa/selenium/UnpinnedScriptKeyTest.java @@ -0,0 +1,45 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.openqa.selenium; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; + +class UnpinnedScriptKeyTest { + + @Test + void toStringDoesNotExposeRawScript() { + String script = "return 'SECRET_TOKEN';"; + UnpinnedScriptKey key = new UnpinnedScriptKey(script); + + assertThat(key.toString()).doesNotContain(script); + } + + @Test + void toStringContainsScriptIdWhenPresent() { + UnpinnedScriptKey key = new UnpinnedScriptKey("return 1;"); + key.setScriptId("script-99"); + + String value = key.toString(); + + assertThat(value).contains("UnpinnedScriptKey{"); + assertThat(value).contains("scriptId=script-99"); + assertThat(value).contains("length="); + } +} From 14642bd0de19b678dee4cda60fae35f74dc54de1 Mon Sep 17 00:00:00 2001 From: Andrei Solntsev Date: Mon, 2 Mar 2026 23:16:48 +0200 Subject: [PATCH 36/67] [java] specify nullability in packages `org.openqa.selenium.chrom*` (#17152) --- .../src/org/openqa/selenium/Capabilities.java | 12 +++ .../selenium/chrome/ChromeDriverInfo.java | 3 - .../openqa/selenium/chrome/package-info.java | 21 +++++ .../selenium/chromium/AddHasCasting.java | 8 +- .../openqa/selenium/chromium/AddHasCdp.java | 7 +- .../chromium/AddHasNetworkConditions.java | 5 +- .../selenium/chromium/ChromiumDriver.java | 13 +-- .../chromium/ChromiumDriverLogLevel.java | 2 +- .../selenium/chromium/ChromiumOptions.java | 84 +++++++++---------- .../openqa/selenium/chromium/HasCasting.java | 2 +- .../openqa/selenium/remote/ExecuteMethod.java | 8 +- .../selenium/remote/FedCmDialogImpl.java | 34 ++++---- .../selenium/remote/LocalExecuteMethod.java | 4 +- .../selenium/remote/RemoteExecuteMethod.java | 5 +- .../selenium/remote/RemoteWebDriver.java | 3 +- 15 files changed, 120 insertions(+), 91 deletions(-) create mode 100644 java/src/org/openqa/selenium/chrome/package-info.java diff --git a/java/src/org/openqa/selenium/Capabilities.java b/java/src/org/openqa/selenium/Capabilities.java index a34bb740b1536..36527a618e55b 100644 --- a/java/src/org/openqa/selenium/Capabilities.java +++ b/java/src/org/openqa/selenium/Capabilities.java @@ -17,6 +17,8 @@ package org.openqa.selenium; +import static java.util.Objects.requireNonNull; + import java.io.Serializable; import java.util.Collections; import java.util.Map; @@ -70,6 +72,16 @@ default String getBrowserVersion() { */ @Nullable Object getCapability(String capabilityName); + @SuppressWarnings("unchecked") + default @Nullable T get(String capabilityName) { + return (T) getCapability(capabilityName); + } + + default T required(String capabilityName) { + return requireNonNull( + get(capabilityName), () -> "Capability " + capabilityName + " is not set"); + } + /** * @param capabilityName The capability to check. * @return Whether the value is not null and not false. diff --git a/java/src/org/openqa/selenium/chrome/ChromeDriverInfo.java b/java/src/org/openqa/selenium/chrome/ChromeDriverInfo.java index 0310ba5cccf31..27c9a823d4a32 100644 --- a/java/src/org/openqa/selenium/chrome/ChromeDriverInfo.java +++ b/java/src/org/openqa/selenium/chrome/ChromeDriverInfo.java @@ -21,7 +21,6 @@ import com.google.auto.service.AutoService; import java.util.Optional; -import java.util.logging.Logger; import org.openqa.selenium.Capabilities; import org.openqa.selenium.ImmutableCapabilities; import org.openqa.selenium.SessionNotCreatedException; @@ -34,8 +33,6 @@ @AutoService(WebDriverInfo.class) public class ChromeDriverInfo extends ChromiumDriverInfo { - private static final Logger LOG = Logger.getLogger(ChromeDriverInfo.class.getName()); - @Override public String getDisplayName() { return "Chrome"; diff --git a/java/src/org/openqa/selenium/chrome/package-info.java b/java/src/org/openqa/selenium/chrome/package-info.java new file mode 100644 index 0000000000000..21abcff6c80d8 --- /dev/null +++ b/java/src/org/openqa/selenium/chrome/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.chrome; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/chromium/AddHasCasting.java b/java/src/org/openqa/selenium/chromium/AddHasCasting.java index aeb57693110e9..b52d55e50f0dc 100644 --- a/java/src/org/openqa/selenium/chromium/AddHasCasting.java +++ b/java/src/org/openqa/selenium/chromium/AddHasCasting.java @@ -17,6 +17,9 @@ package org.openqa.selenium.chromium; +import static java.util.Collections.emptyList; +import static java.util.Objects.requireNonNullElse; + import java.util.List; import java.util.Map; import java.util.function.Predicate; @@ -51,10 +54,9 @@ public Class getDescribedInterface() { @Override public HasCasting getImplementation(Capabilities capabilities, ExecuteMethod executeMethod) { return new HasCasting() { - @SuppressWarnings("unchecked") @Override public List> getCastSinks() { - return (List>) executeMethod.execute(GET_CAST_SINKS, null); + return requireNonNullElse(executeMethod.execute(GET_CAST_SINKS, null), emptyList()); } @Override @@ -80,7 +82,7 @@ public void startTabMirroring(String deviceName) { @Override public String getCastIssueMessage() { - return executeMethod.execute(GET_CAST_ISSUE_MESSAGE, null).toString(); + return executeMethod.executeRequired(GET_CAST_ISSUE_MESSAGE, null).toString(); } @Override diff --git a/java/src/org/openqa/selenium/chromium/AddHasCdp.java b/java/src/org/openqa/selenium/chromium/AddHasCdp.java index 3b6604b54130d..e35998388ffb1 100644 --- a/java/src/org/openqa/selenium/chromium/AddHasCdp.java +++ b/java/src/org/openqa/selenium/chromium/AddHasCdp.java @@ -51,11 +51,8 @@ public HasCdp getImplementation(Capabilities capabilities, ExecuteMethod execute Require.nonNull("Command name", commandName); Require.nonNull("Parameters", parameters); - Map toReturn = - (Map) - executeMethod.execute(EXECUTE_CDP, Map.of("cmd", commandName, "params", parameters)); - - return Map.copyOf(toReturn); + return executeMethod.executeRequired( + EXECUTE_CDP, Map.of("cmd", commandName, "params", parameters)); }; } } diff --git a/java/src/org/openqa/selenium/chromium/AddHasNetworkConditions.java b/java/src/org/openqa/selenium/chromium/AddHasNetworkConditions.java index cb976cb2aa87d..3a045c37d8170 100644 --- a/java/src/org/openqa/selenium/chromium/AddHasNetworkConditions.java +++ b/java/src/org/openqa/selenium/chromium/AddHasNetworkConditions.java @@ -17,7 +17,6 @@ package org.openqa.selenium.chromium; -import static java.util.Objects.requireNonNull; import static org.openqa.selenium.chromium.ChromiumDriver.IS_CHROMIUM_BROWSER; import static org.openqa.selenium.chromium.ChromiumNetworkConditions.DOWNLOAD_THROUGHPUT; import static org.openqa.selenium.chromium.ChromiumNetworkConditions.LATENCY; @@ -77,9 +76,7 @@ public HasNetworkConditions getImplementation( @Override public ChromiumNetworkConditions getNetworkConditions() { @SuppressWarnings("unchecked") - Map result = - (Map) - requireNonNull(executeMethod.execute(GET_NETWORK_CONDITIONS, null)); + Map result = executeMethod.executeRequired(GET_NETWORK_CONDITIONS, null); return new ChromiumNetworkConditions() .setOffline((Boolean) result.getOrDefault(OFFLINE, false)) .setLatency(Duration.ofMillis((Long) result.getOrDefault(LATENCY, 0))) diff --git a/java/src/org/openqa/selenium/chromium/ChromiumDriver.java b/java/src/org/openqa/selenium/chromium/ChromiumDriver.java index e5cd3cb48b710..3a3ec3a1bdbc1 100644 --- a/java/src/org/openqa/selenium/chromium/ChromiumDriver.java +++ b/java/src/org/openqa/selenium/chromium/ChromiumDriver.java @@ -17,6 +17,7 @@ package org.openqa.selenium.chromium; +import static java.util.Collections.emptyList; import static org.openqa.selenium.remote.Browser.CHROME; import static org.openqa.selenium.remote.Browser.EDGE; import static org.openqa.selenium.remote.Browser.OPERA; @@ -32,7 +33,6 @@ import java.util.function.Supplier; import java.util.logging.Level; import java.util.logging.Logger; -import org.jspecify.annotations.NullMarked; import org.jspecify.annotations.Nullable; import org.openqa.selenium.BuildInfo; import org.openqa.selenium.Capabilities; @@ -69,7 +69,6 @@ * A {@link WebDriver} implementation that controls a Chromium browser running on the local machine. * It is used as the base class for Chromium-based browser drivers (Chrome, Edge). */ -@NullMarked public class ChromiumDriver extends RemoteWebDriver implements HasAuthentication, HasBiDi, @@ -250,8 +249,9 @@ public void unpin(ScriptKey key) { "Page.removeScriptToEvaluateOnNewDocument", Map.of("identifier", key.getIdentifier()))); } + @Nullable @Override - public Object executeScript(ScriptKey key, Object... args) { + public Object executeScript(ScriptKey key, @Nullable Object... args) { int hashCode = getScriptId(key); String scriptToUse = @@ -342,7 +342,7 @@ public Optional maybeGetBiDi() { @Override public List> getCastSinks() { if (this.casting == null) { - return List.of(); + return emptyList(); } return casting.getCastSinks(); @@ -415,9 +415,4 @@ public void setNetworkConditions(ChromiumNetworkConditions networkConditions) { public void deleteNetworkConditions() { networkConditions.deleteNetworkConditions(); } - - @Override - public void quit() { - super.quit(); - } } diff --git a/java/src/org/openqa/selenium/chromium/ChromiumDriverLogLevel.java b/java/src/org/openqa/selenium/chromium/ChromiumDriverLogLevel.java index 0a8dd108fe4c6..fc5d2ec615a99 100644 --- a/java/src/org/openqa/selenium/chromium/ChromiumDriverLogLevel.java +++ b/java/src/org/openqa/selenium/chromium/ChromiumDriverLogLevel.java @@ -52,7 +52,7 @@ public String toString() { } @Nullable - public static ChromiumDriverLogLevel fromString(String text) { + public static ChromiumDriverLogLevel fromString(@Nullable String text) { if (text != null) { for (ChromiumDriverLogLevel b : ChromiumDriverLogLevel.values()) { if (text.equalsIgnoreCase(b.toString())) { diff --git a/java/src/org/openqa/selenium/chromium/ChromiumOptions.java b/java/src/org/openqa/selenium/chromium/ChromiumOptions.java index ac780d57cd4fa..0aa1b30f81d20 100644 --- a/java/src/org/openqa/selenium/chromium/ChromiumOptions.java +++ b/java/src/org/openqa/selenium/chromium/ChromiumOptions.java @@ -17,9 +17,8 @@ package org.openqa.selenium.chromium; -import static java.util.Collections.unmodifiableList; import static java.util.Collections.unmodifiableMap; -import static java.util.stream.Collectors.toList; +import static java.util.stream.Collectors.toUnmodifiableList; import java.io.File; import java.io.IOException; @@ -30,7 +29,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.TreeMap; import java.util.stream.Stream; @@ -64,12 +62,12 @@ public class ChromiumOptions> extends AbstractDriverOptions> { - private String binary; + private @Nullable String binary; private final List args = new ArrayList<>(); private final List extensionFiles = new ArrayList<>(); private final List extensions = new ArrayList<>(); private final Map experimentalOptions = new HashMap<>(); - private Map androidOptions = new HashMap<>(); + private final Map androidOptions = new HashMap<>(); private final String capabilityName; @@ -87,7 +85,7 @@ public ChromiumOptions(String capabilityType, String browserType, String capabil */ public T setBinary(File path) { binary = Require.nonNull("Path to the chrome executable", path).getPath(); - return (T) this; + return self(); } /** @@ -99,7 +97,7 @@ public T setBinary(File path) { */ public T setBinary(String path) { binary = Require.nonNull("Path to the chrome executable", path); - return (T) this; + return self(); } /** @@ -108,7 +106,7 @@ public T setBinary(String path) { */ public T addArguments(String... arguments) { addArguments(List.of(arguments)); - return (T) this; + return self(); } /** @@ -127,7 +125,7 @@ public T addArguments(String... arguments) { */ public T addArguments(List arguments) { args.addAll(arguments); - return (T) this; + return self(); } /** @@ -136,7 +134,7 @@ public T addArguments(List arguments) { */ public T addExtensions(File... paths) { addExtensions(List.of(paths)); - return (T) this; + return self(); } /** @@ -148,7 +146,7 @@ public T addExtensions(File... paths) { public T addExtensions(List paths) { paths.forEach(path -> Require.argument("Extension", path.toPath()).isFile()); extensionFiles.addAll(paths); - return (T) this; + return self(); } /** @@ -157,7 +155,7 @@ public T addExtensions(List paths) { */ public T addEncodedExtensions(String... encoded) { addEncodedExtensions(List.of(encoded)); - return (T) this; + return self(); } /** @@ -171,12 +169,12 @@ public T addEncodedExtensions(List encoded) { Require.nonNull("Encoded extension", extension); } extensions.addAll(encoded); - return (T) this; + return self(); } public T enableBiDi() { setCapability("webSocketUrl", true); - return (T) this; + return self(); } /** @@ -188,7 +186,7 @@ public T enableBiDi() { */ public T setExperimentalOption(String name, Object value) { experimentalOptions.put(Require.nonNull("Option name", name), value); - return (T) this; + return self(); } public T setAndroidPackage(String androidPackage) { @@ -222,9 +220,12 @@ public T setAndroidProcess(String processName) { private T setAndroidCapability(String name, Object value) { Require.nonNull("Name", name); Require.nonNull("Value", value); - Map newOptions = new TreeMap<>(androidOptions); - newOptions.put(name, value); - androidOptions = Collections.unmodifiableMap(newOptions); + androidOptions.put(name, value); + return self(); + } + + @SuppressWarnings("unchecked") + protected T self() { return (T) this; } @@ -247,30 +248,28 @@ protected Object getExtraCapability(String capabilityName) { options.put("binary", binary); } - options.put("args", unmodifiableList(new ArrayList<>(args))); - - options.put( - "extensions", - unmodifiableList( - Stream.concat( - extensionFiles.stream() - .map( - file -> { - try { - return Base64.getEncoder() - .encodeToString(Files.readAllBytes(file.toPath())); - } catch (IOException e) { - throw new SessionNotCreatedException(e.getMessage(), e); - } - }), - extensions.stream()) - .collect(toList()))); - + options.put("args", List.copyOf(args)); + options.put("extensions", extensionsArgument()); options.putAll(androidOptions); return unmodifiableMap(options); } + private List extensionsArgument() { + return Stream.concat(extensionFiles.stream().map(this::fileContentBase64), extensions.stream()) + .collect(toUnmodifiableList()); + } + + private String fileContentBase64(File file) { + try { + byte[] fileContent = Files.readAllBytes(file.toPath()); + return Base64.getEncoder().encodeToString(fileContent); + } catch (IOException e) { + throw new SessionNotCreatedException( + "Failed to read extension file " + file.getAbsolutePath(), e); + } + } + protected void mergeInPlace(Capabilities capabilities) { Require.nonNull("Capabilities to merge", capabilities); @@ -280,7 +279,7 @@ protected void mergeInPlace(Capabilities capabilities) { } if (name.equals("args") && capabilities.getCapability(name) != null) { - List arguments = (List) (capabilities.getCapability(("args"))); + List arguments = capabilities.required("args"); arguments.forEach( arg -> { if (!args.contains(arg)) { @@ -290,7 +289,7 @@ protected void mergeInPlace(Capabilities capabilities) { } if (name.equals("extensions") && capabilities.getCapability(name) != null) { - List extensionList = (List) (capabilities.getCapability(("extensions"))); + List extensionList = capabilities.required("extensions"); extensionList.forEach( extension -> { if (!extensions.contains(extension)) { @@ -323,12 +322,9 @@ protected void mergeInPlace(Capabilities capabilities) { addExtensions(options.extensionFiles); addEncodedExtensions(options.extensions); - Optional.ofNullable(options.binary).ifPresent(this::setBinary); - + if (options.binary != null) setBinary(options.binary); options.experimentalOptions.forEach(this::setExperimentalOption); - - Optional.ofNullable(options.androidOptions) - .ifPresent(opts -> opts.forEach(this::setAndroidCapability)); + options.androidOptions.forEach(this::setAndroidCapability); } } diff --git a/java/src/org/openqa/selenium/chromium/HasCasting.java b/java/src/org/openqa/selenium/chromium/HasCasting.java index a9f3ca5b95b69..513122e1847b5 100644 --- a/java/src/org/openqa/selenium/chromium/HasCasting.java +++ b/java/src/org/openqa/selenium/chromium/HasCasting.java @@ -28,7 +28,7 @@ public interface HasCasting { /** * Returns the list of cast sinks (Cast devices) available to the Chrome media router. * - * @return array of ID / Name pairs of available cast sink targets + * @return list of ID / Name pairs of available cast sink targets */ List> getCastSinks(); diff --git a/java/src/org/openqa/selenium/remote/ExecuteMethod.java b/java/src/org/openqa/selenium/remote/ExecuteMethod.java index de1caaefe8ca8..e348e5daf4604 100644 --- a/java/src/org/openqa/selenium/remote/ExecuteMethod.java +++ b/java/src/org/openqa/selenium/remote/ExecuteMethod.java @@ -17,6 +17,8 @@ package org.openqa.selenium.remote; +import static java.util.Objects.requireNonNull; + import java.util.Map; import org.jspecify.annotations.NullMarked; import org.jspecify.annotations.Nullable; @@ -35,5 +37,9 @@ public interface ExecuteMethod { * @param parameters The parameters to execute that command with * @return The result of {@link Response#getValue()}. */ - @Nullable Object execute(String commandName, @Nullable Map parameters); + @Nullable T execute(String commandName, @Nullable Map parameters); + + default T executeRequired(String commandName, @Nullable Map parameters) { + return requireNonNull(execute(commandName, parameters)); + } } diff --git a/java/src/org/openqa/selenium/remote/FedCmDialogImpl.java b/java/src/org/openqa/selenium/remote/FedCmDialogImpl.java index 59e832f84e367..ed50cfab338fe 100644 --- a/java/src/org/openqa/selenium/remote/FedCmDialogImpl.java +++ b/java/src/org/openqa/selenium/remote/FedCmDialogImpl.java @@ -17,12 +17,15 @@ package org.openqa.selenium.remote; -import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; +import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.federatedcredentialmanagement.FederatedCredentialManagementAccount; import org.openqa.selenium.federatedcredentialmanagement.FederatedCredentialManagementDialog; +@NullMarked class FedCmDialogImpl implements FederatedCredentialManagementDialog { private final ExecuteMethod executeMethod; @@ -40,9 +43,10 @@ public void selectAccount(int index) { executeMethod.execute(DriverCommand.SELECT_ACCOUNT, Map.of("accountIndex", index)); } + @Nullable @Override public String getDialogType() { - return (String) executeMethod.execute(DriverCommand.GET_FEDCM_DIALOG_TYPE, null); + return executeMethod.execute(DriverCommand.GET_FEDCM_DIALOG_TYPE, null); } @Override @@ -51,29 +55,27 @@ public void clickDialog() { DriverCommand.CLICK_DIALOG, Map.of("dialogButton", "ConfirmIdpLoginContinue")); } + @Nullable @Override public String getTitle() { - Map result = - (Map) executeMethod.execute(DriverCommand.GET_FEDCM_TITLE, null); - return (String) result.getOrDefault("title", null); + Map result = executeMethod.executeRequired(DriverCommand.GET_FEDCM_TITLE, null); + return result.get("title"); } + @Nullable @Override public String getSubtitle() { - Map result = - (Map) executeMethod.execute(DriverCommand.GET_FEDCM_TITLE, null); - return (String) result.getOrDefault("subtitle", null); + Map result = executeMethod.executeRequired(DriverCommand.GET_FEDCM_TITLE, null); + return result.get("subtitle"); } @Override public List getAccounts() { - List> list = - (List>) executeMethod.execute(DriverCommand.GET_ACCOUNTS, null); - ArrayList accounts = - new ArrayList(); - for (Map map : list) { - accounts.add(new FederatedCredentialManagementAccount(map)); - } - return accounts; + List> accounts = + executeMethod.executeRequired(DriverCommand.GET_ACCOUNTS, null); + + return accounts.stream() + .map(map -> new FederatedCredentialManagementAccount(map)) + .collect(Collectors.toList()); } } diff --git a/java/src/org/openqa/selenium/remote/LocalExecuteMethod.java b/java/src/org/openqa/selenium/remote/LocalExecuteMethod.java index 944b9365d8169..0d2b4e4cb5c36 100644 --- a/java/src/org/openqa/selenium/remote/LocalExecuteMethod.java +++ b/java/src/org/openqa/selenium/remote/LocalExecuteMethod.java @@ -18,13 +18,15 @@ package org.openqa.selenium.remote; import java.util.Map; +import org.jspecify.annotations.NullMarked; import org.jspecify.annotations.Nullable; import org.openqa.selenium.WebDriverException; +@NullMarked class LocalExecuteMethod implements ExecuteMethod { @Nullable @Override - public Object execute(String commandName, @Nullable Map parameters) { + public T execute(String commandName, @Nullable Map parameters) { throw new WebDriverException("Cannot execute remote command: " + commandName); } } diff --git a/java/src/org/openqa/selenium/remote/RemoteExecuteMethod.java b/java/src/org/openqa/selenium/remote/RemoteExecuteMethod.java index 9be7a20bfd2be..8f28193c5718e 100644 --- a/java/src/org/openqa/selenium/remote/RemoteExecuteMethod.java +++ b/java/src/org/openqa/selenium/remote/RemoteExecuteMethod.java @@ -31,8 +31,9 @@ public RemoteExecuteMethod(RemoteWebDriver driver) { this.driver = Require.nonNull("Remote WebDriver", driver); } + @SuppressWarnings("unchecked") @Override - public @Nullable Object execute(String commandName, @Nullable Map parameters) { + public @Nullable T execute(String commandName, @Nullable Map parameters) { Response response; if (parameters == null || parameters.isEmpty()) { @@ -41,7 +42,7 @@ public RemoteExecuteMethod(RemoteWebDriver driver) { response = driver.execute(commandName, parameters); } - return response.getValue(); + return (T) response.getValue(); } @Override diff --git a/java/src/org/openqa/selenium/remote/RemoteWebDriver.java b/java/src/org/openqa/selenium/remote/RemoteWebDriver.java index c419961c9b5d1..0ed6e13909865 100644 --- a/java/src/org/openqa/selenium/remote/RemoteWebDriver.java +++ b/java/src/org/openqa/selenium/remote/RemoteWebDriver.java @@ -47,6 +47,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; import org.jspecify.annotations.NonNull; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.AcceptedW3CCapabilityKeys; import org.openqa.selenium.Alert; import org.openqa.selenium.Beta; @@ -528,7 +529,7 @@ public String getWindowHandle() { } @Override - public Object executeScript(String script, Object... args) { + public @Nullable Object executeScript(@NonNull String script, @Nullable Object... args) { List convertedArgs = Stream.of(args).map(new WebElementToJsonConverter()).collect(Collectors.toList()); From d6a46459f74ac57df59071fa421efb9a8d02a7c5 Mon Sep 17 00:00:00 2001 From: Selenium CI Bot Date: Tue, 3 Mar 2026 01:52:03 +0100 Subject: [PATCH 37/67] [dotnet][rb][java][js][py] Automated Browser Version Update (#17164) Update pinned browser versions Co-authored-by: Selenium CI Bot --- common/repositories.bzl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/common/repositories.bzl b/common/repositories.bzl index dc63f08014a9e..7d82fd5b86cf1 100644 --- a/common/repositories.bzl +++ b/common/repositories.bzl @@ -50,8 +50,8 @@ js_library( http_archive( name = "linux_beta_firefox", - url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b2/linux-x86_64/en-US/firefox-149.0b2.tar.xz", - sha256 = "14057fe24b65ef64125d04bde3507f48e489464b9f22cda832d34db92dede817", + url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b3/linux-x86_64/en-US/firefox-149.0b3.tar.xz", + sha256 = "9c69d673378aebcbc4bc5f8dd4becc31cf53d9f604de13f4aa4d3a6dded08475", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -72,8 +72,8 @@ js_library( dmg_archive( name = "mac_beta_firefox", - url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b2/mac/en-US/Firefox%20149.0b2.dmg", - sha256 = "cc697fd73992e677e7249be2a42d7e6e9fe1373841ecefe7009acfad113566fb", + url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b3/mac/en-US/Firefox%20149.0b3.dmg", + sha256 = "7872c9b4c98d8bfe818fd34e4f14bc55bf797c5b55842623b2ba28c6ac102226", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) From 395fda0fc9e4d7a4109779b6957971a494c0c4e9 Mon Sep 17 00:00:00 2001 From: Andrei Solntsev Date: Tue, 3 Mar 2026 12:59:35 +0200 Subject: [PATCH 38/67] Nullability for browsers (#17167) * specify nullability in package `org.openqa.selenium.edge` * move null-check for `args` and `env` to superclass * specify nullability in package `org.openqa.selenium.firefox` * specify nullability in package `org.openqa.selenium.ie` * specify nullability in package `org.openqa.selenium.safari` * specify nullability in package `org.openqa.selenium.federatedcredentialmanagement` * specify nullability in package `org.openqa.selenium.print` --- .../selenium/chrome/ChromeDriverService.java | 9 +--- .../openqa/selenium/edge/EdgeDriverInfo.java | 3 -- .../selenium/edge/EdgeDriverService.java | 2 +- .../openqa/selenium/edge/package-info.java | 50 +++++++++++++++++++ .../FederatedCredentialManagementAccount.java | 22 ++++---- .../FederatedCredentialManagementDialog.java | 2 - .../HasFederatedCredentialManagement.java | 2 - .../package-info.java | 50 +++++++++++++++++++ .../selenium/firefox/AddHasContext.java | 3 +- .../selenium/firefox/AddHasExtensions.java | 7 ++- .../selenium/firefox/FileExtension.java | 20 +++++--- .../firefox/FirefoxCommandContext.java | 12 ++--- .../selenium/firefox/FirefoxDriver.java | 4 +- .../firefox/FirefoxDriverLogLevel.java | 5 +- .../firefox/FirefoxDriverService.java | 2 +- .../selenium/firefox/FirefoxOptions.java | 13 ++++- .../selenium/firefox/FirefoxProfile.java | 15 +++--- .../selenium/firefox/GeckoDriverInfo.java | 3 -- .../openqa/selenium/firefox/Preferences.java | 3 +- .../openqa/selenium/firefox/ProfilesIni.java | 9 +++- .../openqa/selenium/firefox/package-info.java | 21 ++++++++ .../selenium/ie/ElementScrollBehavior.java | 3 ++ .../ie/InternetExplorerDriverInfo.java | 3 -- .../ie/InternetExplorerDriverService.java | 9 +--- .../selenium/ie/InternetExplorerOptions.java | 2 + .../org/openqa/selenium/ie/package-info.java | 50 +++++++++++++++++++ .../org/openqa/selenium/print/PageMargin.java | 2 - .../org/openqa/selenium/print/PageSize.java | 7 +-- .../openqa/selenium/print/PrintOptions.java | 2 - .../openqa/selenium/print/package-info.java | 50 +++++++++++++++++++ .../remote/service/DriverService.java | 5 +- .../selenium/safari/AddHasPermissions.java | 18 +++---- .../selenium/safari/SafariDriverInfo.java | 3 -- .../selenium/safari/SafariDriverService.java | 2 +- .../openqa/selenium/safari/SafariOptions.java | 4 +- .../SafariTechPreviewDriverService.java | 5 +- .../openqa/selenium/safari/package-info.java | 50 +++++++++++++++++++ 37 files changed, 364 insertions(+), 108 deletions(-) create mode 100644 java/src/org/openqa/selenium/edge/package-info.java create mode 100644 java/src/org/openqa/selenium/federatedcredentialmanagement/package-info.java create mode 100644 java/src/org/openqa/selenium/firefox/package-info.java create mode 100644 java/src/org/openqa/selenium/ie/package-info.java create mode 100644 java/src/org/openqa/selenium/print/package-info.java create mode 100644 java/src/org/openqa/selenium/safari/package-info.java diff --git a/java/src/org/openqa/selenium/chrome/ChromeDriverService.java b/java/src/org/openqa/selenium/chrome/ChromeDriverService.java index 921ad928ea5fc..fa2bee0167044 100644 --- a/java/src/org/openqa/selenium/chrome/ChromeDriverService.java +++ b/java/src/org/openqa/selenium/chrome/ChromeDriverService.java @@ -18,7 +18,6 @@ package org.openqa.selenium.chrome; import static java.util.Collections.unmodifiableList; -import static java.util.Collections.unmodifiableMap; import static org.openqa.selenium.remote.Browser.CHROME; import com.google.auto.service.AutoService; @@ -27,7 +26,6 @@ import java.io.OutputStream; import java.time.Duration; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; @@ -110,12 +108,7 @@ public ChromeDriverService( @Nullable List args, @Nullable Map environment) throws IOException { - super( - executable, - port, - timeout, - unmodifiableList(new ArrayList<>(args)), - unmodifiableMap(new HashMap<>(environment))); + super(executable, port, timeout, args, environment); } public String getDriverName() { diff --git a/java/src/org/openqa/selenium/edge/EdgeDriverInfo.java b/java/src/org/openqa/selenium/edge/EdgeDriverInfo.java index 2435c1fae3dbb..23070cee03f44 100644 --- a/java/src/org/openqa/selenium/edge/EdgeDriverInfo.java +++ b/java/src/org/openqa/selenium/edge/EdgeDriverInfo.java @@ -21,7 +21,6 @@ import com.google.auto.service.AutoService; import java.util.Optional; -import java.util.logging.Logger; import org.openqa.selenium.Capabilities; import org.openqa.selenium.ImmutableCapabilities; import org.openqa.selenium.SessionNotCreatedException; @@ -34,8 +33,6 @@ @AutoService(WebDriverInfo.class) public class EdgeDriverInfo extends ChromiumDriverInfo { - private static final Logger LOG = Logger.getLogger(EdgeDriverInfo.class.getName()); - @Override public String getDisplayName() { return "Edge"; diff --git a/java/src/org/openqa/selenium/edge/EdgeDriverService.java b/java/src/org/openqa/selenium/edge/EdgeDriverService.java index 6f1ae54147861..2ba4459ebb519 100644 --- a/java/src/org/openqa/selenium/edge/EdgeDriverService.java +++ b/java/src/org/openqa/selenium/edge/EdgeDriverService.java @@ -104,7 +104,7 @@ public EdgeDriverService( @Nullable List args, @Nullable Map environment) throws IOException { - super(executable, port, timeout, List.copyOf(args), Map.copyOf(environment)); + super(executable, port, timeout, args, environment); } public String getDriverName() { diff --git a/java/src/org/openqa/selenium/edge/package-info.java b/java/src/org/openqa/selenium/edge/package-info.java new file mode 100644 index 0000000000000..0ecb0d05d4dd3 --- /dev/null +++ b/java/src/org/openqa/selenium/edge/package-info.java @@ -0,0 +1,50 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/** + * Mechanisms to configure and run selenium via the command line. There are two key classes {@link + * org.openqa.selenium.cli.CliCommand} and {@link org.openqa.selenium.grid.config.HasRoles}. + * Ultimately, these are used to build a {@link org.openqa.selenium.grid.config.Config} instance, + * for which there are strongly-typed role-specific classes that use a {@code Config}, such as + * {@link org.openqa.selenium.grid.node.docker.DockerOptions}. + * + *

Assuming your {@code CliCommand} extends {@link org.openqa.selenium.grid.TemplateGridCommand}, + * the process for building the set of flags to use is: + * + *

    + *
  1. The default flags are added (these are {@link org.openqa.selenium.grid.server.HelpFlags} + * and {@link org.openqa.selenium.grid.config.ConfigFlags} + *
  2. {@link java.util.ServiceLoader} is used to find all implementations of {@link + * org.openqa.selenium.grid.config.HasRoles} where {@link + * org.openqa.selenium.grid.config.HasRoles#getRoles()} is contained within {@link + * org.openqa.selenium.cli.CliCommand#getConfigurableRoles()}. + *
  3. Finally all flags returned by {@link + * org.openqa.selenium.grid.TemplateGridCommand#getFlagObjects()} are added. + *
+ * + *

The flags are then used by JCommander to parse the command arguments. Once that's done, the + * raw flags are converted to a {@link org.openqa.selenium.grid.config.Config} by combining all of + * the flag objects with system properties and environment variables. This implies that each flag + * object has annotated each field with {@link org.openqa.selenium.grid.config.ConfigValue}. + * + *

Ultimately, this means that flag objects have all (most?) fields annotated with JCommander's + * {@link com.beust.jcommander.Parameter} annotation as well as {@code ConfigValue}. + */ +@NullMarked +package org.openqa.selenium.edge; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/federatedcredentialmanagement/FederatedCredentialManagementAccount.java b/java/src/org/openqa/selenium/federatedcredentialmanagement/FederatedCredentialManagementAccount.java index ff9248a2eb7f5..d14f0c855ee35 100644 --- a/java/src/org/openqa/selenium/federatedcredentialmanagement/FederatedCredentialManagementAccount.java +++ b/java/src/org/openqa/selenium/federatedcredentialmanagement/FederatedCredentialManagementAccount.java @@ -18,7 +18,6 @@ package org.openqa.selenium.federatedcredentialmanagement; import java.util.Map; -import org.jspecify.annotations.NullMarked; import org.jspecify.annotations.Nullable; /** @@ -29,7 +28,6 @@ * @see * https://w3c-fedid.github.io/FedCM/#webdriver-accountlist */ -@NullMarked public class FederatedCredentialManagementAccount { private final @Nullable String accountId; private final @Nullable String email; @@ -57,16 +55,16 @@ public class FederatedCredentialManagementAccount { public static final String LOGIN_STATE_SIGNIN = "SignIn"; public static final String LOGIN_STATE_SIGNUP = "SignUp"; - public FederatedCredentialManagementAccount(Map dict) { - accountId = (String) dict.getOrDefault("accountId", null); - email = (String) dict.getOrDefault("email", null); - name = (String) dict.getOrDefault("name", null); - givenName = (String) dict.getOrDefault("givenName", null); - pictureUrl = (String) dict.getOrDefault("pictureUrl", null); - idpConfigUrl = (String) dict.getOrDefault("idpConfigUrl", null); - loginState = (String) dict.getOrDefault("loginState", null); - termsOfServiceUrl = (String) dict.getOrDefault("termsOfServiceUrl", null); - privacyPolicyUrl = (String) dict.getOrDefault("privacyPolicyUrl", null); + public FederatedCredentialManagementAccount(Map dict) { + accountId = dict.getOrDefault("accountId", null); + email = dict.getOrDefault("email", null); + name = dict.getOrDefault("name", null); + givenName = dict.getOrDefault("givenName", null); + pictureUrl = dict.getOrDefault("pictureUrl", null); + idpConfigUrl = dict.getOrDefault("idpConfigUrl", null); + loginState = dict.getOrDefault("loginState", null); + termsOfServiceUrl = dict.getOrDefault("termsOfServiceUrl", null); + privacyPolicyUrl = dict.getOrDefault("privacyPolicyUrl", null); } public @Nullable String getAccountid() { diff --git a/java/src/org/openqa/selenium/federatedcredentialmanagement/FederatedCredentialManagementDialog.java b/java/src/org/openqa/selenium/federatedcredentialmanagement/FederatedCredentialManagementDialog.java index 3c12d92e0c6c0..4882141cc0177 100644 --- a/java/src/org/openqa/selenium/federatedcredentialmanagement/FederatedCredentialManagementDialog.java +++ b/java/src/org/openqa/selenium/federatedcredentialmanagement/FederatedCredentialManagementDialog.java @@ -18,7 +18,6 @@ package org.openqa.selenium.federatedcredentialmanagement; import java.util.List; -import org.jspecify.annotations.NullMarked; import org.jspecify.annotations.Nullable; /** @@ -26,7 +25,6 @@ * * @see https://w3c-fedid.github.io/FedCM/ */ -@NullMarked public interface FederatedCredentialManagementDialog { String DIALOG_TYPE_ACCOUNT_LIST = "AccountChooser"; diff --git a/java/src/org/openqa/selenium/federatedcredentialmanagement/HasFederatedCredentialManagement.java b/java/src/org/openqa/selenium/federatedcredentialmanagement/HasFederatedCredentialManagement.java index 501a52d6e05e7..2026f66617758 100644 --- a/java/src/org/openqa/selenium/federatedcredentialmanagement/HasFederatedCredentialManagement.java +++ b/java/src/org/openqa/selenium/federatedcredentialmanagement/HasFederatedCredentialManagement.java @@ -17,13 +17,11 @@ package org.openqa.selenium.federatedcredentialmanagement; -import org.jspecify.annotations.NullMarked; import org.jspecify.annotations.Nullable; import org.openqa.selenium.Beta; /** Used by classes to indicate that they can interact with FedCM dialogs. */ @Beta -@NullMarked public interface HasFederatedCredentialManagement { /** * Disables the promise rejection delay. diff --git a/java/src/org/openqa/selenium/federatedcredentialmanagement/package-info.java b/java/src/org/openqa/selenium/federatedcredentialmanagement/package-info.java new file mode 100644 index 0000000000000..359e361a886f5 --- /dev/null +++ b/java/src/org/openqa/selenium/federatedcredentialmanagement/package-info.java @@ -0,0 +1,50 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/** + * Mechanisms to configure and run selenium via the command line. There are two key classes {@link + * org.openqa.selenium.cli.CliCommand} and {@link org.openqa.selenium.grid.config.HasRoles}. + * Ultimately, these are used to build a {@link org.openqa.selenium.grid.config.Config} instance, + * for which there are strongly-typed role-specific classes that use a {@code Config}, such as + * {@link org.openqa.selenium.grid.node.docker.DockerOptions}. + * + *

Assuming your {@code CliCommand} extends {@link org.openqa.selenium.grid.TemplateGridCommand}, + * the process for building the set of flags to use is: + * + *

    + *
  1. The default flags are added (these are {@link org.openqa.selenium.grid.server.HelpFlags} + * and {@link org.openqa.selenium.grid.config.ConfigFlags} + *
  2. {@link java.util.ServiceLoader} is used to find all implementations of {@link + * org.openqa.selenium.grid.config.HasRoles} where {@link + * org.openqa.selenium.grid.config.HasRoles#getRoles()} is contained within {@link + * org.openqa.selenium.cli.CliCommand#getConfigurableRoles()}. + *
  3. Finally all flags returned by {@link + * org.openqa.selenium.grid.TemplateGridCommand#getFlagObjects()} are added. + *
+ * + *

The flags are then used by JCommander to parse the command arguments. Once that's done, the + * raw flags are converted to a {@link org.openqa.selenium.grid.config.Config} by combining all of + * the flag objects with system properties and environment variables. This implies that each flag + * object has annotated each field with {@link org.openqa.selenium.grid.config.ConfigValue}. + * + *

Ultimately, this means that flag objects have all (most?) fields annotated with JCommander's + * {@link com.beust.jcommander.Parameter} annotation as well as {@code ConfigValue}. + */ +@NullMarked +package org.openqa.selenium.federatedcredentialmanagement; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/firefox/AddHasContext.java b/java/src/org/openqa/selenium/firefox/AddHasContext.java index 787892239fc20..4369ab6cda1d2 100644 --- a/java/src/org/openqa/selenium/firefox/AddHasContext.java +++ b/java/src/org/openqa/selenium/firefox/AddHasContext.java @@ -69,8 +69,7 @@ public void setContext(FirefoxCommandContext context) { @Override public FirefoxCommandContext getContext() { - String context = (String) executeMethod.execute(GET_CONTEXT, null); - + String context = executeMethod.executeRequired(GET_CONTEXT, null); return FirefoxCommandContext.fromString(context); } }; diff --git a/java/src/org/openqa/selenium/firefox/AddHasExtensions.java b/java/src/org/openqa/selenium/firefox/AddHasExtensions.java index d1b21880bbdb6..671199bd2e494 100644 --- a/java/src/org/openqa/selenium/firefox/AddHasExtensions.java +++ b/java/src/org/openqa/selenium/firefox/AddHasExtensions.java @@ -95,9 +95,8 @@ public String installExtension(Path path, Boolean temporary) { throw new InvalidArgumentException(path + " is an invalid path", e); } - return (String) - executeMethod.execute( - INSTALL_EXTENSION, Map.of("addon", encoded, "temporary", temporary)); + return executeMethod.executeRequired( + INSTALL_EXTENSION, Map.of("addon", encoded, "temporary", temporary)); } private byte[] zipDirectory(Path path) throws IOException { @@ -105,7 +104,7 @@ private byte[] zipDirectory(Path path) throws IOException { try (ZipOutputStream zos = new ZipOutputStream(baos)) { Files.walkFileTree( path, - new SimpleFileVisitor() { + new SimpleFileVisitor<>() { @Override public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException { diff --git a/java/src/org/openqa/selenium/firefox/FileExtension.java b/java/src/org/openqa/selenium/firefox/FileExtension.java index 7f7253c22df85..db9e755f00527 100644 --- a/java/src/org/openqa/selenium/firefox/FileExtension.java +++ b/java/src/org/openqa/selenium/firefox/FileExtension.java @@ -18,6 +18,7 @@ package org.openqa.selenium.firefox; import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Objects.requireNonNull; import static org.openqa.selenium.json.Json.MAP_TYPE; import java.io.BufferedInputStream; @@ -131,7 +132,7 @@ private String readIdFromManifestJson(File root) { JsonInput json = new Json().newInput(reader)) { String addOnId = null; - Map manifestObject = json.read(MAP_TYPE); + Map manifestObject = json.readNonNull(MAP_TYPE); if (manifestObject.get("applications") instanceof Map) { Map applicationObj = (Map) manifestObject.get("applications"); if (applicationObj.get("gecko") instanceof Map) { @@ -143,10 +144,14 @@ private String readIdFromManifestJson(File root) { } if (addOnId == null || addOnId.isEmpty()) { - addOnId = - ((String) manifestObject.get("name")).replaceAll(" ", "") - + "@" - + manifestObject.get("version"); + String name = + (String) + requireNonNull( + manifestObject.get("name"), + () -> + "Manifest should contain either 'applications' or 'name', but received: " + + manifestObject); + addOnId = name.replaceAll(" ", "") + "@" + manifestObject.get("version"); } return addOnId; @@ -158,9 +163,8 @@ private String readIdFromManifestJson(File root) { } private String readIdFromInstallRdf(File root) { + File installRdf = new File(root, "install.rdf"); try { - File installRdf = new File(root, "install.rdf"); - DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance(); factory.setFeature("http://apache.org/xml/features/disallow-doctype-decl", true); factory.setNamespaceAware(true); @@ -214,7 +218,7 @@ public Iterator getPrefixes(String uri) { } return id; } catch (Exception e) { - throw new WebDriverException(e); + throw new WebDriverException("Failed to read id from " + installRdf.getAbsolutePath(), e); } } } diff --git a/java/src/org/openqa/selenium/firefox/FirefoxCommandContext.java b/java/src/org/openqa/selenium/firefox/FirefoxCommandContext.java index 7c683788938a7..07e5208d0c978 100644 --- a/java/src/org/openqa/selenium/firefox/FirefoxCommandContext.java +++ b/java/src/org/openqa/selenium/firefox/FirefoxCommandContext.java @@ -32,17 +32,15 @@ public enum FirefoxCommandContext { @Override public String toString() { - return String.valueOf(text); + return text; } public static FirefoxCommandContext fromString(String text) { - if (text != null) { - for (FirefoxCommandContext b : FirefoxCommandContext.values()) { - if (text.equalsIgnoreCase(b.text)) { - return b; - } + for (FirefoxCommandContext b : FirefoxCommandContext.values()) { + if (text.equalsIgnoreCase(b.text)) { + return b; } } - return null; + throw new IllegalArgumentException("Unknown Firefox context: " + text); } } diff --git a/java/src/org/openqa/selenium/firefox/FirefoxDriver.java b/java/src/org/openqa/selenium/firefox/FirefoxDriver.java index a84f836f374b8..2546a8556e971 100644 --- a/java/src/org/openqa/selenium/firefox/FirefoxDriver.java +++ b/java/src/org/openqa/selenium/firefox/FirefoxDriver.java @@ -27,7 +27,6 @@ import java.util.logging.Logger; import java.util.stream.Collectors; import java.util.stream.Stream; -import org.jspecify.annotations.NonNull; import org.openqa.selenium.Beta; import org.openqa.selenium.Capabilities; import org.openqa.selenium.ImmutableCapabilities; @@ -169,6 +168,8 @@ public static RemoteWebDriverBuilder builder() { /** Check capabilities and proxy if it is set */ private static Capabilities checkCapabilitiesAndProxy(Capabilities capabilities) { + // TODO I think we can remove this null check + //noinspection ConstantValue if (capabilities == null) { return new ImmutableCapabilities(); } @@ -184,7 +185,6 @@ private static Capabilities checkCapabilitiesAndProxy(Capabilities capabilities) return caps; } - @NonNull @Override public Capabilities getCapabilities() { return capabilities; diff --git a/java/src/org/openqa/selenium/firefox/FirefoxDriverLogLevel.java b/java/src/org/openqa/selenium/firefox/FirefoxDriverLogLevel.java index b46e762214b84..e7366fc74f9d2 100644 --- a/java/src/org/openqa/selenium/firefox/FirefoxDriverLogLevel.java +++ b/java/src/org/openqa/selenium/firefox/FirefoxDriverLogLevel.java @@ -21,6 +21,7 @@ import java.util.Locale; import java.util.Map; import java.util.logging.Level; +import org.jspecify.annotations.Nullable; /** * Log levels defined by GeckoDriver @@ -51,7 +52,8 @@ public String toString() { return super.toString().toLowerCase(Locale.ENGLISH); } - public static FirefoxDriverLogLevel fromString(String text) { + @Nullable + public static FirefoxDriverLogLevel fromString(@Nullable String text) { if (text != null) { for (FirefoxDriverLogLevel b : FirefoxDriverLogLevel.values()) { if (text.equalsIgnoreCase(b.toString())) { @@ -70,6 +72,7 @@ Map toJson() { return Collections.singletonMap("level", toString()); } + @Nullable static FirefoxDriverLogLevel fromJson(Map json) { return fromString(json.get("level")); } diff --git a/java/src/org/openqa/selenium/firefox/FirefoxDriverService.java b/java/src/org/openqa/selenium/firefox/FirefoxDriverService.java index 65e9ed83be01d..69f7626fe8244 100644 --- a/java/src/org/openqa/selenium/firefox/FirefoxDriverService.java +++ b/java/src/org/openqa/selenium/firefox/FirefoxDriverService.java @@ -35,7 +35,7 @@ public abstract class FirefoxDriverService extends DriverService { * @param environment The environment for the launched server. * @throws IOException If an I/O error occurs. */ - public FirefoxDriverService( + protected FirefoxDriverService( @Nullable File executable, int port, @Nullable Duration timeout, diff --git a/java/src/org/openqa/selenium/firefox/FirefoxOptions.java b/java/src/org/openqa/selenium/firefox/FirefoxOptions.java index 6950661874cc6..641d6253fd34a 100644 --- a/java/src/org/openqa/selenium/firefox/FirefoxOptions.java +++ b/java/src/org/openqa/selenium/firefox/FirefoxOptions.java @@ -227,11 +227,11 @@ public FirefoxOptions addPreference(String key, Object value) { return setFirefoxOption(Keys.PREFS, Collections.unmodifiableMap(newPrefs)); } - Map prefs() { + @Nullable Map prefs() { return getOption(Keys.PREFS); } - String profile() { + @Nullable String profile() { return getOption(Keys.PROFILE); } @@ -293,6 +293,7 @@ protected Set getExtraCapabilityNames() { return Collections.unmodifiableSet(names); } + @Nullable @Override protected Object getExtraCapability(String capabilityName) { Require.nonNull("Capability name", capabilityName); @@ -440,6 +441,7 @@ private enum Keys { @Override public void amend(Map sourceOptions, Map toAmend) {} + @Nullable @Override public Object mirror(Map first, Map second) { return null; @@ -465,6 +467,7 @@ public void amend(Map sourceOptions, Map toAmend toAmend.put(key(), Collections.unmodifiableList(new ArrayList<>(newArgs))); } + @Nullable @Override public Object mirror(Map first, Map second) { Object rawFirst = first.getOrDefault(key(), new ArrayList<>()); @@ -494,6 +497,7 @@ public void amend(Map sourceOptions, Map toAmend } } + @Nullable @Override public Object mirror(Map first, Map second) { Object value = second.get(key()); @@ -527,6 +531,7 @@ public void amend(Map sourceOptions, Map toAmend toAmend.put(key(), Collections.unmodifiableMap(collected)); } + @Nullable @Override public Object mirror(Map first, Map second) { Object rawFirst = first.getOrDefault(key(), new TreeMap<>()); @@ -559,6 +564,7 @@ public void amend(Map sourceOptions, Map toAmend toAmend.put(key(), o); } + @Nullable @Override public Object mirror(Map first, Map second) { Object value = second.get(key()); @@ -590,6 +596,7 @@ public void amend(Map sourceOptions, Map toAmend toAmend.put(key(), Collections.unmodifiableMap(collected)); } + @Nullable @Override public Object mirror(Map first, Map second) { Object rawFirst = first.getOrDefault(key(), new TreeMap<>()); @@ -627,6 +634,7 @@ public void amend(Map sourceOptions, Map toAmend toAmend.put(key(), o); } + @Nullable @Override public Object mirror(Map first, Map second) { Object value = second.get(key()); @@ -657,6 +665,7 @@ public String key() { public abstract void amend(Map sourceOptions, Map toAmend); + @Nullable public abstract Object mirror(Map first, Map second); } } diff --git a/java/src/org/openqa/selenium/firefox/FirefoxProfile.java b/java/src/org/openqa/selenium/firefox/FirefoxProfile.java index f39b5e1fa33c9..023bf8571cf40 100644 --- a/java/src/org/openqa/selenium/firefox/FirefoxProfile.java +++ b/java/src/org/openqa/selenium/firefox/FirefoxProfile.java @@ -26,6 +26,7 @@ import java.nio.charset.Charset; import java.util.HashMap; import java.util.Map; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.WebDriverException; import org.openqa.selenium.io.FileHandler; import org.openqa.selenium.io.TemporaryFilesystem; @@ -38,7 +39,7 @@ public class FirefoxProfile { private static final String ASSUME_UNTRUSTED_ISSUER_PREF = "webdriver_assume_untrusted_issuer"; private final Preferences additionalPrefs; private final Map extensions = new HashMap<>(); - private final File model; + private @Nullable final File model; private boolean loadNoFocusLib; private boolean acceptUntrustedCerts; private boolean untrustedCertIssuer; @@ -54,7 +55,7 @@ public FirefoxProfile() { * * @param profileDir The profile directory to use as a model. */ - public FirefoxProfile(File profileDir) { + public FirefoxProfile(@Nullable File profileDir) { additionalPrefs = new Preferences(); model = profileDir; verifyModel(model); @@ -124,7 +125,7 @@ public boolean getBooleanPreference(String key, boolean defaultValue) { return defaultValue; } - private void verifyModel(File model) { + private void verifyModel(@Nullable File model) { if (model == null) { return; } @@ -292,8 +293,10 @@ public FirefoxProfile setAssumeUntrustedCertificateIssuer(boolean untrustedIssue return this; } - public void clean(File profileDir) { - TemporaryFilesystem.getDefaultTmpFS().deleteTempDir(profileDir); + public void clean(@Nullable File profileDir) { + if (profileDir != null) { + TemporaryFilesystem.getDefaultTmpFS().deleteTempDir(profileDir); + } } String toJson() throws IOException { @@ -336,7 +339,7 @@ public File layoutOnDisk() { } } - protected void copyModel(File sourceDir, File profileDir) throws IOException { + protected void copyModel(@Nullable File sourceDir, File profileDir) throws IOException { if (sourceDir == null || !sourceDir.exists()) { return; } diff --git a/java/src/org/openqa/selenium/firefox/GeckoDriverInfo.java b/java/src/org/openqa/selenium/firefox/GeckoDriverInfo.java index dfb371978a18b..da4503b97ea01 100644 --- a/java/src/org/openqa/selenium/firefox/GeckoDriverInfo.java +++ b/java/src/org/openqa/selenium/firefox/GeckoDriverInfo.java @@ -22,7 +22,6 @@ import com.google.auto.service.AutoService; import java.util.Optional; -import java.util.logging.Logger; import org.openqa.selenium.Capabilities; import org.openqa.selenium.ImmutableCapabilities; import org.openqa.selenium.SessionNotCreatedException; @@ -33,8 +32,6 @@ @AutoService(WebDriverInfo.class) public class GeckoDriverInfo implements WebDriverInfo { - private static final Logger LOG = Logger.getLogger(GeckoDriverInfo.class.getName()); - @Override public String getDisplayName() { return "Firefox"; diff --git a/java/src/org/openqa/selenium/firefox/Preferences.java b/java/src/org/openqa/selenium/firefox/Preferences.java index 51f272b323962..59b182573804e 100644 --- a/java/src/org/openqa/selenium/firefox/Preferences.java +++ b/java/src/org/openqa/selenium/firefox/Preferences.java @@ -32,6 +32,7 @@ import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.WebDriverException; import org.openqa.selenium.json.Json; @@ -183,7 +184,7 @@ private Object preferenceAsValue(String toConvert) { } } - Object getPreference(String key) { + @Nullable Object getPreference(String key) { return allPrefs.get(key); } diff --git a/java/src/org/openqa/selenium/firefox/ProfilesIni.java b/java/src/org/openqa/selenium/firefox/ProfilesIni.java index 8e22caa5107e1..535e12921ea20 100644 --- a/java/src/org/openqa/selenium/firefox/ProfilesIni.java +++ b/java/src/org/openqa/selenium/firefox/ProfilesIni.java @@ -28,6 +28,7 @@ import java.text.MessageFormat; import java.util.HashMap; import java.util.Map; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.Platform; import org.openqa.selenium.WebDriverException; import org.openqa.selenium.io.FileHandler; @@ -41,7 +42,7 @@ public ProfilesIni() { profiles = readProfiles(appData); } - protected Map readProfiles(File appData) { + protected Map readProfiles(@Nullable File appData) { Map toReturn = new HashMap<>(); File profilesIni = new File(appData, "profiles.ini"); @@ -95,13 +96,16 @@ protected Map readProfiles(File appData) { return toReturn; } - protected File newProfile(String name, File appData, String path, boolean isRelative) { + @Nullable + protected File newProfile( + @Nullable String name, @Nullable File appData, @Nullable String path, boolean isRelative) { if (name != null && path != null) { return isRelative ? new File(appData, path) : new File(path); } return null; } + @Nullable public FirefoxProfile getProfile(String profileName) { File profileDir = profiles.get(profileName); if (profileDir == null) { @@ -127,6 +131,7 @@ public FirefoxProfile getProfile(String profileName) { return new FirefoxProfile(tempDir); } + @Nullable protected File locateAppDataDirectory(Platform os) { File appData; if (os.is(WINDOWS)) { diff --git a/java/src/org/openqa/selenium/firefox/package-info.java b/java/src/org/openqa/selenium/firefox/package-info.java new file mode 100644 index 0000000000000..45dd7c8f0a7dc --- /dev/null +++ b/java/src/org/openqa/selenium/firefox/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.firefox; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/ie/ElementScrollBehavior.java b/java/src/org/openqa/selenium/ie/ElementScrollBehavior.java index c816170757213..c97cc1de12d20 100644 --- a/java/src/org/openqa/selenium/ie/ElementScrollBehavior.java +++ b/java/src/org/openqa/selenium/ie/ElementScrollBehavior.java @@ -17,6 +17,8 @@ package org.openqa.selenium.ie; +import org.jspecify.annotations.Nullable; + public enum ElementScrollBehavior { TOP(0), BOTTOM(1), @@ -33,6 +35,7 @@ public String toString() { return String.valueOf(value); } + @Nullable public static ElementScrollBehavior fromString(String text) { for (ElementScrollBehavior b : ElementScrollBehavior.values()) { if (text.equalsIgnoreCase(b.toString())) { diff --git a/java/src/org/openqa/selenium/ie/InternetExplorerDriverInfo.java b/java/src/org/openqa/selenium/ie/InternetExplorerDriverInfo.java index 4ac3ae75f856e..7713ba8ff91ff 100644 --- a/java/src/org/openqa/selenium/ie/InternetExplorerDriverInfo.java +++ b/java/src/org/openqa/selenium/ie/InternetExplorerDriverInfo.java @@ -22,7 +22,6 @@ import com.google.auto.service.AutoService; import java.util.Optional; -import java.util.logging.Logger; import org.openqa.selenium.Capabilities; import org.openqa.selenium.ImmutableCapabilities; import org.openqa.selenium.Platform; @@ -34,8 +33,6 @@ @AutoService(WebDriverInfo.class) public class InternetExplorerDriverInfo implements WebDriverInfo { - private static final Logger LOG = Logger.getLogger(InternetExplorerDriverInfo.class.getName()); - @Override public String getDisplayName() { return "Internet Explorer"; diff --git a/java/src/org/openqa/selenium/ie/InternetExplorerDriverService.java b/java/src/org/openqa/selenium/ie/InternetExplorerDriverService.java index 3cb43892c6e85..f02cfa5dd43aa 100644 --- a/java/src/org/openqa/selenium/ie/InternetExplorerDriverService.java +++ b/java/src/org/openqa/selenium/ie/InternetExplorerDriverService.java @@ -18,7 +18,6 @@ package org.openqa.selenium.ie; import static java.util.Collections.unmodifiableList; -import static java.util.Collections.unmodifiableMap; import static org.openqa.selenium.remote.Browser.IE; import com.google.auto.service.AutoService; @@ -26,7 +25,6 @@ import java.io.IOException; import java.time.Duration; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; @@ -85,12 +83,7 @@ public InternetExplorerDriverService( @Nullable List args, @Nullable Map environment) throws IOException { - super( - executable, - port, - timeout, - unmodifiableList(new ArrayList<>(args)), - unmodifiableMap(new HashMap<>(environment))); + super(executable, port, timeout, args, environment); } public String getDriverName() { diff --git a/java/src/org/openqa/selenium/ie/InternetExplorerOptions.java b/java/src/org/openqa/selenium/ie/InternetExplorerOptions.java index 84ceef9965d60..6f1c5ccc1c381 100644 --- a/java/src/org/openqa/selenium/ie/InternetExplorerOptions.java +++ b/java/src/org/openqa/selenium/ie/InternetExplorerOptions.java @@ -42,6 +42,7 @@ import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import java.util.stream.Stream; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.Capabilities; import org.openqa.selenium.internal.Require; import org.openqa.selenium.remote.AbstractDriverOptions; @@ -276,6 +277,7 @@ protected Set getExtraCapabilityNames() { return Collections.emptySet(); } + @Nullable @Override protected Object getExtraCapability(String capabilityName) { Require.nonNull("Capability name", capabilityName); diff --git a/java/src/org/openqa/selenium/ie/package-info.java b/java/src/org/openqa/selenium/ie/package-info.java new file mode 100644 index 0000000000000..e2d33d56dc710 --- /dev/null +++ b/java/src/org/openqa/selenium/ie/package-info.java @@ -0,0 +1,50 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/** + * Mechanisms to configure and run selenium via the command line. There are two key classes {@link + * org.openqa.selenium.cli.CliCommand} and {@link org.openqa.selenium.grid.config.HasRoles}. + * Ultimately, these are used to build a {@link org.openqa.selenium.grid.config.Config} instance, + * for which there are strongly-typed role-specific classes that use a {@code Config}, such as + * {@link org.openqa.selenium.grid.node.docker.DockerOptions}. + * + *

Assuming your {@code CliCommand} extends {@link org.openqa.selenium.grid.TemplateGridCommand}, + * the process for building the set of flags to use is: + * + *

    + *
  1. The default flags are added (these are {@link org.openqa.selenium.grid.server.HelpFlags} + * and {@link org.openqa.selenium.grid.config.ConfigFlags} + *
  2. {@link java.util.ServiceLoader} is used to find all implementations of {@link + * org.openqa.selenium.grid.config.HasRoles} where {@link + * org.openqa.selenium.grid.config.HasRoles#getRoles()} is contained within {@link + * org.openqa.selenium.cli.CliCommand#getConfigurableRoles()}. + *
  3. Finally all flags returned by {@link + * org.openqa.selenium.grid.TemplateGridCommand#getFlagObjects()} are added. + *
+ * + *

The flags are then used by JCommander to parse the command arguments. Once that's done, the + * raw flags are converted to a {@link org.openqa.selenium.grid.config.Config} by combining all of + * the flag objects with system properties and environment variables. This implies that each flag + * object has annotated each field with {@link org.openqa.selenium.grid.config.ConfigValue}. + * + *

Ultimately, this means that flag objects have all (most?) fields annotated with JCommander's + * {@link com.beust.jcommander.Parameter} annotation as well as {@code ConfigValue}. + */ +@NullMarked +package org.openqa.selenium.ie; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/print/PageMargin.java b/java/src/org/openqa/selenium/print/PageMargin.java index 6fe27e6063e5a..6c18bf3588a7d 100644 --- a/java/src/org/openqa/selenium/print/PageMargin.java +++ b/java/src/org/openqa/selenium/print/PageMargin.java @@ -18,9 +18,7 @@ package org.openqa.selenium.print; import java.util.Map; -import org.jspecify.annotations.NullMarked; -@NullMarked public class PageMargin { private final double top; private final double bottom; diff --git a/java/src/org/openqa/selenium/print/PageSize.java b/java/src/org/openqa/selenium/print/PageSize.java index a673fda24ec45..ff8f3ab2489b1 100644 --- a/java/src/org/openqa/selenium/print/PageSize.java +++ b/java/src/org/openqa/selenium/print/PageSize.java @@ -19,9 +19,8 @@ import java.util.HashMap; import java.util.Map; -import org.jspecify.annotations.NullMarked; +import org.openqa.selenium.internal.Require; -@NullMarked public class PageSize { private final double height; private final double width; @@ -52,9 +51,7 @@ public double getWidth() { } public static PageSize setPageSize(PageSize pageSize) { - if (pageSize == null) { - throw new IllegalArgumentException("Page size cannot be null"); - } + Require.nonNull("Page size", pageSize); return new PageSize(pageSize.getHeight(), pageSize.getWidth()); } diff --git a/java/src/org/openqa/selenium/print/PrintOptions.java b/java/src/org/openqa/selenium/print/PrintOptions.java index 1e1b7a86da3d5..019e051979f89 100644 --- a/java/src/org/openqa/selenium/print/PrintOptions.java +++ b/java/src/org/openqa/selenium/print/PrintOptions.java @@ -20,11 +20,9 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import org.jspecify.annotations.NullMarked; import org.jspecify.annotations.Nullable; import org.openqa.selenium.internal.Require; -@NullMarked public class PrintOptions { public enum Orientation { diff --git a/java/src/org/openqa/selenium/print/package-info.java b/java/src/org/openqa/selenium/print/package-info.java new file mode 100644 index 0000000000000..48ef4b6fe9f52 --- /dev/null +++ b/java/src/org/openqa/selenium/print/package-info.java @@ -0,0 +1,50 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/** + * Mechanisms to configure and run selenium via the command line. There are two key classes {@link + * org.openqa.selenium.cli.CliCommand} and {@link org.openqa.selenium.grid.config.HasRoles}. + * Ultimately, these are used to build a {@link org.openqa.selenium.grid.config.Config} instance, + * for which there are strongly-typed role-specific classes that use a {@code Config}, such as + * {@link org.openqa.selenium.grid.node.docker.DockerOptions}. + * + *

Assuming your {@code CliCommand} extends {@link org.openqa.selenium.grid.TemplateGridCommand}, + * the process for building the set of flags to use is: + * + *

    + *
  1. The default flags are added (these are {@link org.openqa.selenium.grid.server.HelpFlags} + * and {@link org.openqa.selenium.grid.config.ConfigFlags} + *
  2. {@link java.util.ServiceLoader} is used to find all implementations of {@link + * org.openqa.selenium.grid.config.HasRoles} where {@link + * org.openqa.selenium.grid.config.HasRoles#getRoles()} is contained within {@link + * org.openqa.selenium.cli.CliCommand#getConfigurableRoles()}. + *
  3. Finally all flags returned by {@link + * org.openqa.selenium.grid.TemplateGridCommand#getFlagObjects()} are added. + *
+ * + *

The flags are then used by JCommander to parse the command arguments. Once that's done, the + * raw flags are converted to a {@link org.openqa.selenium.grid.config.Config} by combining all of + * the flag objects with system properties and environment variables. This implies that each flag + * object has annotated each field with {@link org.openqa.selenium.grid.config.ConfigValue}. + * + *

Ultimately, this means that flag objects have all (most?) fields annotated with JCommander's + * {@link com.beust.jcommander.Parameter} annotation as well as {@code ConfigValue}. + */ +@NullMarked +package org.openqa.selenium.print; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/remote/service/DriverService.java b/java/src/org/openqa/selenium/remote/service/DriverService.java index 11aacd3677441..3276498d3a315 100644 --- a/java/src/org/openqa/selenium/remote/service/DriverService.java +++ b/java/src/org/openqa/selenium/remote/service/DriverService.java @@ -17,6 +17,7 @@ package org.openqa.selenium.remote.service; +import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; import static java.util.concurrent.TimeUnit.SECONDS; import static org.openqa.selenium.concurrent.ExecutorServices.shutdownGracefully; @@ -117,8 +118,8 @@ protected DriverService( this.executable = executable.getCanonicalPath(); } this.timeout = timeout; - this.args = args; - this.environment = environment; + this.args = args == null ? emptyList() : List.copyOf(args); + this.environment = environment == null ? emptyMap() : Map.copyOf(environment); this.url = getUrl(port); } diff --git a/java/src/org/openqa/selenium/safari/AddHasPermissions.java b/java/src/org/openqa/selenium/safari/AddHasPermissions.java index 96e49bdb54084..6aa4572387a89 100644 --- a/java/src/org/openqa/selenium/safari/AddHasPermissions.java +++ b/java/src/org/openqa/selenium/safari/AddHasPermissions.java @@ -67,21 +67,15 @@ public void setPermissions(String permission, boolean value) { @Override public Map getPermissions() { - Object resultObject = executeMethod.execute(GET_PERMISSIONS, null); + Map resultMap = executeMethod.executeRequired(GET_PERMISSIONS, null); - if (resultObject instanceof Map) { - Map resultMap = (Map) resultObject; - Map permissionMap = new HashMap<>(); - for (Map.Entry entry : resultMap.entrySet()) { - if (entry.getKey() instanceof String && entry.getValue() instanceof Boolean) { - permissionMap.put((String) entry.getKey(), (Boolean) entry.getValue()); - } + Map permissionMap = new HashMap<>(); + for (Map.Entry entry : resultMap.entrySet()) { + if (entry.getKey() instanceof String && entry.getValue() instanceof Boolean) { + permissionMap.put((String) entry.getKey(), (Boolean) entry.getValue()); } - return permissionMap; - } else { - throw new IllegalStateException( - "Unexpected result type: " + resultObject.getClass().getName()); } + return permissionMap; } }; } diff --git a/java/src/org/openqa/selenium/safari/SafariDriverInfo.java b/java/src/org/openqa/selenium/safari/SafariDriverInfo.java index 3f9967db274d4..bfd408716a3fb 100644 --- a/java/src/org/openqa/selenium/safari/SafariDriverInfo.java +++ b/java/src/org/openqa/selenium/safari/SafariDriverInfo.java @@ -22,7 +22,6 @@ import com.google.auto.service.AutoService; import java.util.Optional; -import java.util.logging.Logger; import org.openqa.selenium.Capabilities; import org.openqa.selenium.ImmutableCapabilities; import org.openqa.selenium.Platform; @@ -35,8 +34,6 @@ @AutoService(WebDriverInfo.class) public class SafariDriverInfo implements WebDriverInfo { - private static final Logger LOG = Logger.getLogger(SafariDriverInfo.class.getName()); - @Override public String getDisplayName() { return "Safari"; diff --git a/java/src/org/openqa/selenium/safari/SafariDriverService.java b/java/src/org/openqa/selenium/safari/SafariDriverService.java index 19c6e29480cc4..52d3d7bc00f96 100644 --- a/java/src/org/openqa/selenium/safari/SafariDriverService.java +++ b/java/src/org/openqa/selenium/safari/SafariDriverService.java @@ -68,7 +68,7 @@ public SafariDriverService( @Nullable List args, @Nullable Map environment) throws IOException { - super(executable, port, timeout, List.copyOf(args), Map.copyOf(environment)); + super(executable, port, timeout, args, environment); } public String getDriverName() { diff --git a/java/src/org/openqa/selenium/safari/SafariOptions.java b/java/src/org/openqa/selenium/safari/SafariOptions.java index b9bda31f9259b..ab390e6b99c25 100644 --- a/java/src/org/openqa/selenium/safari/SafariOptions.java +++ b/java/src/org/openqa/selenium/safari/SafariOptions.java @@ -23,6 +23,7 @@ import java.util.Collections; import java.util.Set; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.Capabilities; import org.openqa.selenium.WebDriverException; import org.openqa.selenium.internal.Require; @@ -105,7 +106,7 @@ public SafariOptions setAutomaticInspection(boolean automaticInspection) { } public boolean getAutomaticProfiling() { - return Boolean.TRUE.equals(is(Option.AUTOMATIC_PROFILING)); + return is(Option.AUTOMATIC_PROFILING); } /** @@ -144,6 +145,7 @@ protected Set getExtraCapabilityNames() { return Collections.emptySet(); } + @Nullable @Override protected Object getExtraCapability(String capabilityName) { return null; diff --git a/java/src/org/openqa/selenium/safari/SafariTechPreviewDriverService.java b/java/src/org/openqa/selenium/safari/SafariTechPreviewDriverService.java index 01f15fc063df2..d6d03d8f172b9 100644 --- a/java/src/org/openqa/selenium/safari/SafariTechPreviewDriverService.java +++ b/java/src/org/openqa/selenium/safari/SafariTechPreviewDriverService.java @@ -29,6 +29,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.Capabilities; import org.openqa.selenium.WebDriverException; import org.openqa.selenium.net.PortProber; @@ -66,7 +67,7 @@ public SafariTechPreviewDriverService( List args, Map environment) throws IOException { - super(executable, port, timeout, List.copyOf(args), Map.copyOf(environment)); + super(executable, port, timeout, args, environment); } public String getDriverName() { @@ -114,7 +115,7 @@ public static class Builder extends DriverService.Builder< SafariTechPreviewDriverService, SafariTechPreviewDriverService.Builder> { - private Boolean diagnose; + @Nullable private Boolean diagnose; @Override public int score(Capabilities capabilities) { diff --git a/java/src/org/openqa/selenium/safari/package-info.java b/java/src/org/openqa/selenium/safari/package-info.java new file mode 100644 index 0000000000000..6684283ee6752 --- /dev/null +++ b/java/src/org/openqa/selenium/safari/package-info.java @@ -0,0 +1,50 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/** + * Mechanisms to configure and run selenium via the command line. There are two key classes {@link + * org.openqa.selenium.cli.CliCommand} and {@link org.openqa.selenium.grid.config.HasRoles}. + * Ultimately, these are used to build a {@link org.openqa.selenium.grid.config.Config} instance, + * for which there are strongly-typed role-specific classes that use a {@code Config}, such as + * {@link org.openqa.selenium.grid.node.docker.DockerOptions}. + * + *

Assuming your {@code CliCommand} extends {@link org.openqa.selenium.grid.TemplateGridCommand}, + * the process for building the set of flags to use is: + * + *

    + *
  1. The default flags are added (these are {@link org.openqa.selenium.grid.server.HelpFlags} + * and {@link org.openqa.selenium.grid.config.ConfigFlags} + *
  2. {@link java.util.ServiceLoader} is used to find all implementations of {@link + * org.openqa.selenium.grid.config.HasRoles} where {@link + * org.openqa.selenium.grid.config.HasRoles#getRoles()} is contained within {@link + * org.openqa.selenium.cli.CliCommand#getConfigurableRoles()}. + *
  3. Finally all flags returned by {@link + * org.openqa.selenium.grid.TemplateGridCommand#getFlagObjects()} are added. + *
+ * + *

The flags are then used by JCommander to parse the command arguments. Once that's done, the + * raw flags are converted to a {@link org.openqa.selenium.grid.config.Config} by combining all of + * the flag objects with system properties and environment variables. This implies that each flag + * object has annotated each field with {@link org.openqa.selenium.grid.config.ConfigValue}. + * + *

Ultimately, this means that flag objects have all (most?) fields annotated with JCommander's + * {@link com.beust.jcommander.Parameter} annotation as well as {@code ConfigValue}. + */ +@NullMarked +package org.openqa.selenium.safari; + +import org.jspecify.annotations.NullMarked; From 6e653b5998b79ba11a23fe472d9b4e93835ae0c6 Mon Sep 17 00:00:00 2001 From: pinterior Date: Tue, 3 Mar 2026 23:33:15 +0900 Subject: [PATCH 39/67] [py] Use Self as return type of __enter__ in remote.WebDriver (#17170) --- py/selenium/webdriver/remote/webdriver.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index 979415bd7a6e1..ed05f8c1585f8 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -33,6 +33,8 @@ from importlib import import_module from typing import Any, cast +from typing_extensions import Self + from selenium.common.exceptions import ( InvalidArgumentException, JavascriptException, @@ -290,7 +292,7 @@ def __init__( def __repr__(self) -> str: return f'<{type(self).__module__}.{type(self).__name__} (session="{self.session_id}")>' - def __enter__(self) -> "WebDriver": + def __enter__(self) -> Self: return self def __exit__( From d493f293f8f0e6dcc5767c1d19d7e7f1664e12b8 Mon Sep 17 00:00:00 2001 From: Nikolay Borisenko <22616990+nvborisenko@users.noreply.github.com> Date: Wed, 4 Mar 2026 01:04:32 +0300 Subject: [PATCH 40/67] [dotnet] [bidi] Convert RemoteValue to well-known types (#17027) --- .../src/webdriver/BiDi/Script/RemoteValue.cs | 112 +++++-- .../BiDi/Script/RemoteValueConversionTests.cs | 289 ++++++++++++++++++ 2 files changed, 371 insertions(+), 30 deletions(-) create mode 100644 dotnet/test/common/BiDi/Script/RemoteValueConversionTests.cs diff --git a/dotnet/src/webdriver/BiDi/Script/RemoteValue.cs b/dotnet/src/webdriver/BiDi/Script/RemoteValue.cs index 7fcd16c8e1b1e..9f7698575506a 100644 --- a/dotnet/src/webdriver/BiDi/Script/RemoteValue.cs +++ b/dotnet/src/webdriver/BiDi/Script/RemoteValue.cs @@ -54,49 +54,101 @@ namespace OpenQA.Selenium.BiDi.Script; [JsonConverter(typeof(RemoteValueConverter))] public abstract record RemoteValue { - public static implicit operator double(RemoteValue remoteValue) => (double)((NumberRemoteValue)remoteValue).Value; + public static implicit operator bool(RemoteValue remoteValue) => remoteValue.ConvertTo(); + public static implicit operator double(RemoteValue remoteValue) => remoteValue.ConvertTo(); + public static implicit operator float(RemoteValue remoteValue) => remoteValue.ConvertTo(); + public static implicit operator int(RemoteValue remoteValue) => remoteValue.ConvertTo(); + public static implicit operator long(RemoteValue remoteValue) => remoteValue.ConvertTo(); + public static implicit operator string?(RemoteValue remoteValue) => remoteValue.ConvertTo(); - public static implicit operator int(RemoteValue remoteValue) => (int)(double)remoteValue; - public static implicit operator long(RemoteValue remoteValue) => (long)(double)remoteValue; - - public static implicit operator string?(RemoteValue remoteValue) - { - return remoteValue switch + public TResult? ConvertTo() + => (this, typeof(TResult)) switch { - StringRemoteValue stringValue => stringValue.Value, - NullRemoteValue => null, - _ => throw new InvalidCastException($"Cannot convert {remoteValue} to string") + (_, Type t) when t.IsAssignableFrom(GetType()) + => (TResult)(object)this, + (BooleanRemoteValue b, Type t) when t == typeof(bool) + => (TResult)(object)b.Value, + (NullRemoteValue, Type t) when !t.IsValueType || Nullable.GetUnderlyingType(t) is not null + => default, + (NumberRemoteValue n, Type t) when t == typeof(byte) + => (TResult)(object)Convert.ToByte(n.Value), + (NumberRemoteValue n, Type t) when t == typeof(sbyte) + => (TResult)(object)Convert.ToSByte(n.Value), + (NumberRemoteValue n, Type t) when t == typeof(short) + => (TResult)(object)Convert.ToInt16(n.Value), + (NumberRemoteValue n, Type t) when t == typeof(ushort) + => (TResult)(object)Convert.ToUInt16(n.Value), + (NumberRemoteValue n, Type t) when t == typeof(int) + => (TResult)(object)Convert.ToInt32(n.Value), + (NumberRemoteValue n, Type t) when t == typeof(uint) + => (TResult)(object)Convert.ToUInt32(n.Value), + (NumberRemoteValue n, Type t) when t == typeof(long) + => (TResult)(object)Convert.ToInt64(n.Value), + (NumberRemoteValue n, Type t) when t == typeof(ulong) + => (TResult)(object)Convert.ToUInt64(n.Value), + (NumberRemoteValue n, Type t) when t == typeof(double) + => (TResult)(object)n.Value, + (NumberRemoteValue n, Type t) when t == typeof(float) + => (TResult)(object)Convert.ToSingle(n.Value), + (NumberRemoteValue n, Type t) when t == typeof(decimal) + => (TResult)(object)Convert.ToDecimal(n.Value), + (StringRemoteValue s, Type t) when t == typeof(string) + => (TResult)(object)s.Value, + (ArrayRemoteValue a, Type t) when t.IsArray + => ConvertRemoteValuesToArray(a.Value, t.GetElementType()!), + (ArrayRemoteValue a, Type t) when t.IsGenericType && t.IsAssignableFrom(typeof(List<>).MakeGenericType(t.GetGenericArguments()[0])) + => ConvertRemoteValuesToGenericList(a.Value, typeof(List<>).MakeGenericType(t.GetGenericArguments()[0])), + + (_, Type t) when Nullable.GetUnderlyingType(t) is { } underlying + => ConvertToNullable(underlying), + + _ => throw new InvalidCastException($"Cannot convert {GetType().Name} to {typeof(TResult).FullName}") }; - } - // TODO: extend types - public TResult? ConvertTo() + private TResult ConvertToNullable(Type underlyingType) { - var type = typeof(TResult); + var convertMethod = typeof(RemoteValue).GetMethod(nameof(ConvertTo))!.MakeGenericMethod(underlyingType); + var value = convertMethod.Invoke(this, null); + return (TResult)value!; + } - if (typeof(RemoteValue).IsAssignableFrom(type)) // handle native derived types - { - return (TResult)(this as object); - } - if (type == typeof(bool)) - { - return (TResult)(Convert.ToBoolean(((BooleanRemoteValue)this).Value) as object); - } - if (type == typeof(int)) + private static TResult ConvertRemoteValuesToArray(IEnumerable? remoteValues, Type elementType) + { + if (remoteValues is null) { - return (TResult)(Convert.ToInt32(((NumberRemoteValue)this).Value) as object); + return (TResult)(object)Array.CreateInstance(elementType, 0); } - else if (type == typeof(string)) + + var convertMethod = typeof(RemoteValue).GetMethod(nameof(ConvertTo))!.MakeGenericMethod(elementType); + var items = remoteValues.ToList(); + var array = Array.CreateInstance(elementType, items.Count); + + for (int i = 0; i < items.Count; i++) { - return (TResult)(((StringRemoteValue)this).Value as object); + var convertedItem = convertMethod.Invoke(items[i], null); + array.SetValue(convertedItem, i); } - else if (type is object) + + return (TResult)(object)array; + } + + private static TResult ConvertRemoteValuesToGenericList(IEnumerable? remoteValues, Type listType) + { + var elementType = listType.GetGenericArguments()[0]; + var list = (System.Collections.IList)Activator.CreateInstance(listType)!; + + if (remoteValues is not null) { - // :) - return (TResult)new object(); + var convertMethod = typeof(RemoteValue).GetMethod(nameof(ConvertTo))!.MakeGenericMethod(elementType); + + foreach (var item in remoteValues) + { + var convertedItem = convertMethod.Invoke(item, null); + list.Add(convertedItem); + } } - throw new BiDiException("Cannot convert ....."); + return (TResult)list; } } diff --git a/dotnet/test/common/BiDi/Script/RemoteValueConversionTests.cs b/dotnet/test/common/BiDi/Script/RemoteValueConversionTests.cs new file mode 100644 index 0000000000000..b453272371a51 --- /dev/null +++ b/dotnet/test/common/BiDi/Script/RemoteValueConversionTests.cs @@ -0,0 +1,289 @@ +// +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +using System.Collections.Generic; +using NUnit.Framework; + +namespace OpenQA.Selenium.BiDi.Script; + +internal class RemoteValueConversionTests +{ + [Test] + public void CanConvertToNullable() + { + NullRemoteValue arg = new(); + + AssertValue(arg.ConvertTo()); + AssertValue(arg.ConvertTo()); + AssertValue(arg.ConvertTo()); + AssertValue(arg.ConvertTo()); + AssertValue(arg.ConvertTo()); + AssertValue(arg.ConvertTo()); + AssertValue(arg.ConvertTo()); + AssertValue(arg.ConvertTo()); + AssertValue(arg.ConvertTo()); + AssertValue(arg.ConvertTo()); + AssertValue(arg.ConvertTo()); + AssertValue(arg.ConvertTo()); + AssertValue(arg.ConvertTo()); + + static void AssertValue(T value) + { + Assert.That(value, Is.Null); + } + } + + [Test] + public void CanConvertToBool() + { + BooleanRemoteValue arg = new(true); + + AssertValue(arg.ConvertTo()); + AssertValue((bool)arg); + + static void AssertValue(bool value) + { + Assert.That(value, Is.True); + } + } + + [Test] + public void CanConvertToByte() + { + NumberRemoteValue arg = new(5.9); + + AssertValue(arg.ConvertTo()); + + static void AssertValue(byte value) + { + Assert.That(value, Is.EqualTo((byte)6)); + } + } + + [Test] + public void CanConvertToSByte() + { + NumberRemoteValue arg = new(5.9); + + AssertValue(arg.ConvertTo()); + + static void AssertValue(sbyte value) + { + Assert.That(value, Is.EqualTo((sbyte)6)); + } + } + + [Test] + public void CanConvertToInt16() + { + NumberRemoteValue arg = new(5.9); + + AssertValue(arg.ConvertTo()); + AssertValue((short)arg); + + static void AssertValue(short value) + { + Assert.That(value, Is.EqualTo((short)6)); + } + } + + [Test] + public void CanConvertToUInt16() + { + NumberRemoteValue arg = new(5.9); + + AssertValue(arg.ConvertTo()); + AssertValue((ushort)arg); + + static void AssertValue(ushort value) + { + Assert.That(value, Is.EqualTo((ushort)6)); + } + } + + [Test] + public void CanConvertToInt32() + { + NumberRemoteValue arg = new(5.9); + + AssertValue(arg.ConvertTo()); + AssertValue((int)arg); + + static void AssertValue(int value) + { + Assert.That(value, Is.EqualTo(6)); + } + } + + [Test] + public void CanConvertToUInt32() + { + NumberRemoteValue arg = new(5.9); + + AssertValue(arg.ConvertTo()); + AssertValue((uint)arg); + + static void AssertValue(uint value) + { + Assert.That(value, Is.EqualTo(6U)); + } + } + + [Test] + public void CanConvertToInt64() + { + NumberRemoteValue arg = new(5.9); + + AssertValue(arg.ConvertTo()); + AssertValue((long)arg); + + static void AssertValue(long value) + { + Assert.That(value, Is.EqualTo(6L)); + } + } + + [Test] + public void CanConvertToUInt64() + { + NumberRemoteValue arg = new(5.9); + + AssertValue(arg.ConvertTo()); + + static void AssertValue(ulong value) + { + Assert.That(value, Is.EqualTo(6UL)); + } + } + + [Test] + public void CanConvertToDouble() + { + NumberRemoteValue arg = new(5.9); + + AssertValue(arg.ConvertTo()); + AssertValue((double)arg); + + static void AssertValue(double value) + { + Assert.That(value, Is.EqualTo(5.9d)); + } + } + + [Test] + public void CanConvertToFloat() + { + NumberRemoteValue arg = new(5.9); + + AssertValue(arg.ConvertTo()); + AssertValue((float)arg); + + static void AssertValue(float value) + { + Assert.That(value, Is.EqualTo(5.9f)); + } + } + + [Test] + public void CanConvertToDecimal() + { + NumberRemoteValue arg = new(5.9); + + AssertValue(arg.ConvertTo()); + + static void AssertValue(decimal value) + { + Assert.That(value, Is.EqualTo(5.9m)); + } + } + + [Test] + public void CanConvertToString() + { + StringRemoteValue arg = new("abc"); + + AssertValue(arg.ConvertTo()); + AssertValue((string)arg); + + static void AssertValue(string value) + { + Assert.That(value, Is.EqualTo("abc")); + } + } + + [Test] + public void CanConvertToArray() + { + ArrayRemoteValue arg = new() { Value = [new NumberRemoteValue(1), new NumberRemoteValue(2)] }; + + AssertValue(arg.ConvertTo()); + + static void AssertValue(int[] value) + { + Assert.That(value, Is.EqualTo(new int[] { 1, 2 })); + } + } + + [Test] + public void CanConvertToEmptyArray() + { + ArrayRemoteValue arg = new(); + + AssertValue(arg.ConvertTo()); + + static void AssertValue(int[] value) + { + Assert.That(value, Is.Empty); + } + } + + [Test] + public void CanConvertToEnumerableOf() + { + ArrayRemoteValue arg = new() { Value = [new NumberRemoteValue(1), new NumberRemoteValue(2)] }; + + AssertValue(arg.ConvertTo>()); + + AssertValue(arg.ConvertTo>()); + AssertValue(arg.ConvertTo>()); + AssertValue(arg.ConvertTo>()); + AssertValue(arg.ConvertTo>()); + AssertValue(arg.ConvertTo>()); + + static void AssertValue(IEnumerable value) + { + Assert.That(value, Is.EqualTo(new List { 1, 2 })); + } + } + + [Test] + public void CanConvertToEmptyEnumerableOf() + { + ArrayRemoteValue arg = new(); + + AssertValue(arg.ConvertTo>()); + + AssertValue(arg.ConvertTo>()); + + static void AssertValue(IEnumerable value) + { + Assert.That(value, Is.Empty); + } + } +} From 5bfeeee0afd0c6a35a032c11ba7b512f532b7b94 Mon Sep 17 00:00:00 2001 From: Nikolay Borisenko <22616990+nvborisenko@users.noreply.github.com> Date: Wed, 4 Mar 2026 12:09:15 +0300 Subject: [PATCH 41/67] [dotnet] [bidi] Cache BrowsingContext/UserContext per session (#17172) --- .../BiDi/Json/Converters/BrowserUserContextConverter.cs | 8 +++++++- .../BiDi/Json/Converters/BrowsingContextConverter.cs | 8 ++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/dotnet/src/webdriver/BiDi/Json/Converters/BrowserUserContextConverter.cs b/dotnet/src/webdriver/BiDi/Json/Converters/BrowserUserContextConverter.cs index 47e1a2cc4ee1f..e1e939ba11196 100644 --- a/dotnet/src/webdriver/BiDi/Json/Converters/BrowserUserContextConverter.cs +++ b/dotnet/src/webdriver/BiDi/Json/Converters/BrowserUserContextConverter.cs @@ -17,6 +17,8 @@ // under the License. // +using System.Collections.Concurrent; +using System.Runtime.CompilerServices; using System.Text.Json; using System.Text.Json.Serialization; using OpenQA.Selenium.BiDi.Browser; @@ -25,11 +27,15 @@ namespace OpenQA.Selenium.BiDi.Json.Converters; internal class BrowserUserContextConverter(IBiDi bidi) : JsonConverter { + private static readonly ConditionalWeakTable> s_cache = new(); + public override UserContext? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { var id = reader.GetString(); - return new UserContext(id!) { BiDi = bidi }; + var sessionCache = s_cache.GetValue(bidi, _ => new ConcurrentDictionary()); + + return sessionCache.GetOrAdd(id!, key => new UserContext(bidi, key)); } public override void Write(Utf8JsonWriter writer, UserContext value, JsonSerializerOptions options) diff --git a/dotnet/src/webdriver/BiDi/Json/Converters/BrowsingContextConverter.cs b/dotnet/src/webdriver/BiDi/Json/Converters/BrowsingContextConverter.cs index e91923d49fa01..86d7c77c7f074 100644 --- a/dotnet/src/webdriver/BiDi/Json/Converters/BrowsingContextConverter.cs +++ b/dotnet/src/webdriver/BiDi/Json/Converters/BrowsingContextConverter.cs @@ -17,6 +17,8 @@ // under the License. // +using System.Collections.Concurrent; +using System.Runtime.CompilerServices; using System.Text.Json; using System.Text.Json.Serialization; @@ -24,13 +26,15 @@ namespace OpenQA.Selenium.BiDi.Json.Converters; internal class BrowsingContextConverter(IBiDi bidi) : JsonConverter { - private readonly IBiDi _bidi = bidi; + private static readonly ConditionalWeakTable> s_cache = new(); public override BrowsingContext.BrowsingContext? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { var id = reader.GetString(); - return new BrowsingContext.BrowsingContext(id!) { BiDi = _bidi }; + var sessionCache = s_cache.GetValue(bidi, _ => new ConcurrentDictionary()); + + return sessionCache.GetOrAdd(id!, key => new BrowsingContext.BrowsingContext(bidi, key)); } public override void Write(Utf8JsonWriter writer, BrowsingContext.BrowsingContext value, JsonSerializerOptions options) From fe8cc56c7b23544f93484f25fdf6f4048459c2c2 Mon Sep 17 00:00:00 2001 From: JAYA DILEEP <114493600+seethinajayadileep@users.noreply.github.com> Date: Wed, 4 Mar 2026 16:39:26 +0530 Subject: [PATCH 42/67] [java] Keys: enforce CharSequence contract in charAt() (#17166) * Enforce bounds check in Keys.charAt() to comply with CharSequence contract * update charAt test to validate bounds --- java/src/org/openqa/selenium/Keys.java | 6 +++--- java/test/org/openqa/selenium/KeysTest.java | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/java/src/org/openqa/selenium/Keys.java b/java/src/org/openqa/selenium/Keys.java index de72b50107590..922178aab5f42 100644 --- a/java/src/org/openqa/selenium/Keys.java +++ b/java/src/org/openqa/selenium/Keys.java @@ -146,10 +146,10 @@ public int getCodePoint() { @Override public char charAt(int index) { - if (index == 0) { - return keyCode; + if (index != 0) { + throw new IndexOutOfBoundsException("Index: " + index + ", Length: 1"); } - return 0; + return keyCode; } @Override diff --git a/java/test/org/openqa/selenium/KeysTest.java b/java/test/org/openqa/selenium/KeysTest.java index 91e1064fb5ea5..a162b1d9c307c 100644 --- a/java/test/org/openqa/selenium/KeysTest.java +++ b/java/test/org/openqa/selenium/KeysTest.java @@ -35,8 +35,10 @@ void charAtPosition0ReturnsKeyCode() { } @Test - void charAtOtherPositionReturnsZero() { - assertThat(Keys.LEFT.charAt(10)).isEqualTo((char) 0); + void charAtOtherPositionThrows() { + + assertThatExceptionOfType(IndexOutOfBoundsException.class) + .isThrownBy(() -> Keys.LEFT.charAt(10)); } @Test From 580b98a2dde9667fdf30180e3ec0f66a1737c05c Mon Sep 17 00:00:00 2001 From: Swastik Baranwal Date: Wed, 4 Mar 2026 16:41:51 +0530 Subject: [PATCH 43/67] [java][BiDi] implement `speculation` module (#17130) * [java][BiDi] implement `speculation` module --- .../openqa/selenium/bidi/module/BUILD.bazel | 1 + .../bidi/module/SpeculationInspector.java | 76 +++++ .../selenium/bidi/speculation/BUILD.bazel | 24 ++ .../PrefetchStatusUpdatedParameters.java | 54 ++++ .../bidi/speculation/PreloadingStatus.java | 45 +++ .../bidi/speculation/Speculation.java | 28 ++ .../bidi/speculation/package-info.java | 21 ++ .../selenium/bidi/speculation/BUILD.bazel | 31 ++ .../speculation/SpeculationInspectorTest.java | 273 ++++++++++++++++++ 9 files changed, 553 insertions(+) create mode 100644 java/src/org/openqa/selenium/bidi/module/SpeculationInspector.java create mode 100644 java/src/org/openqa/selenium/bidi/speculation/BUILD.bazel create mode 100644 java/src/org/openqa/selenium/bidi/speculation/PrefetchStatusUpdatedParameters.java create mode 100644 java/src/org/openqa/selenium/bidi/speculation/PreloadingStatus.java create mode 100644 java/src/org/openqa/selenium/bidi/speculation/Speculation.java create mode 100644 java/src/org/openqa/selenium/bidi/speculation/package-info.java create mode 100644 java/test/org/openqa/selenium/bidi/speculation/BUILD.bazel create mode 100644 java/test/org/openqa/selenium/bidi/speculation/SpeculationInspectorTest.java diff --git a/java/src/org/openqa/selenium/bidi/module/BUILD.bazel b/java/src/org/openqa/selenium/bidi/module/BUILD.bazel index 652d563957a45..9dfa606989a4a 100644 --- a/java/src/org/openqa/selenium/bidi/module/BUILD.bazel +++ b/java/src/org/openqa/selenium/bidi/module/BUILD.bazel @@ -25,6 +25,7 @@ java_library( "//java/src/org/openqa/selenium/bidi/network", "//java/src/org/openqa/selenium/bidi/permissions", "//java/src/org/openqa/selenium/bidi/script", + "//java/src/org/openqa/selenium/bidi/speculation", "//java/src/org/openqa/selenium/bidi/storage", "//java/src/org/openqa/selenium/bidi/webextension", "//java/src/org/openqa/selenium/json", diff --git a/java/src/org/openqa/selenium/bidi/module/SpeculationInspector.java b/java/src/org/openqa/selenium/bidi/module/SpeculationInspector.java new file mode 100644 index 0000000000000..30b0ea059bde0 --- /dev/null +++ b/java/src/org/openqa/selenium/bidi/module/SpeculationInspector.java @@ -0,0 +1,76 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.openqa.selenium.bidi.module; + +import static java.util.Collections.emptySet; + +import java.util.Collections; +import java.util.Set; +import java.util.function.Consumer; +import org.openqa.selenium.WebDriver; +import org.openqa.selenium.bidi.BiDi; +import org.openqa.selenium.bidi.Event; +import org.openqa.selenium.bidi.HasBiDi; +import org.openqa.selenium.bidi.speculation.PrefetchStatusUpdatedParameters; +import org.openqa.selenium.bidi.speculation.Speculation; +import org.openqa.selenium.internal.Require; + +public class SpeculationInspector implements AutoCloseable { + private final Event prefetchStatusUpdatedEvent; + private final Set browsingContextIds; + + private final BiDi bidi; + + public SpeculationInspector(WebDriver driver) { + this(emptySet(), driver); + } + + public SpeculationInspector(String browsingContextId, WebDriver driver) { + this(Collections.singleton(Require.nonNull("Browsing context id", browsingContextId)), driver); + } + + public SpeculationInspector(Set browsingContextIds, WebDriver driver) { + Require.nonNull("WebDriver", driver); + Require.nonNull("Browsing context id list", browsingContextIds); + + if (!(driver instanceof HasBiDi)) { + throw new IllegalArgumentException("WebDriver instance must support BiDi protocol"); + } + + this.bidi = ((HasBiDi) driver).getBiDi(); + this.browsingContextIds = browsingContextIds; + this.prefetchStatusUpdatedEvent = Speculation.prefetchStatusUpdated(); + } + + public long onPrefetchStatusUpdated(Consumer consumer) { + if (browsingContextIds.isEmpty()) { + return this.bidi.addListener(this.prefetchStatusUpdatedEvent, consumer); + } else { + return this.bidi.addListener(browsingContextIds, this.prefetchStatusUpdatedEvent, consumer); + } + } + + public void removeListener(long subscriptionId) { + this.bidi.removeListener(subscriptionId); + } + + @Override + public void close() { + this.bidi.clearListener(Speculation.prefetchStatusUpdated()); + } +} diff --git a/java/src/org/openqa/selenium/bidi/speculation/BUILD.bazel b/java/src/org/openqa/selenium/bidi/speculation/BUILD.bazel new file mode 100644 index 0000000000000..2b94ce85bafff --- /dev/null +++ b/java/src/org/openqa/selenium/bidi/speculation/BUILD.bazel @@ -0,0 +1,24 @@ +load("@rules_jvm_external//:defs.bzl", "artifact") +load("//java:defs.bzl", "java_library") + +java_library( + name = "speculation", + srcs = glob( + [ + "*.java", + ], + ), + visibility = [ + "//java/src/org/openqa/selenium/bidi:__subpackages__", + "//java/src/org/openqa/selenium/remote:__pkg__", + "//java/test/org/openqa/selenium/bidi:__subpackages__", + "//java/test/org/openqa/selenium/grid:__subpackages__", + ], + deps = [ + "//java/src/org/openqa/selenium:core", + "//java/src/org/openqa/selenium/bidi", + "//java/src/org/openqa/selenium/json", + "//java/src/org/openqa/selenium/remote/http", + artifact("org.jspecify:jspecify"), + ], +) diff --git a/java/src/org/openqa/selenium/bidi/speculation/PrefetchStatusUpdatedParameters.java b/java/src/org/openqa/selenium/bidi/speculation/PrefetchStatusUpdatedParameters.java new file mode 100644 index 0000000000000..357bead447868 --- /dev/null +++ b/java/src/org/openqa/selenium/bidi/speculation/PrefetchStatusUpdatedParameters.java @@ -0,0 +1,54 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.openqa.selenium.bidi.speculation; + +import java.util.Map; + +public class PrefetchStatusUpdatedParameters { + + private final String context; + private final String url; + private final PreloadingStatus status; + + private PrefetchStatusUpdatedParameters(String context, String url, PreloadingStatus status) { + this.context = context; + this.url = url; + this.status = status; + } + + public static PrefetchStatusUpdatedParameters fromJson(Map params) { + String context = (String) params.get("context"); + String url = (String) params.get("url"); + String statusStr = (String) params.get("status"); + PreloadingStatus status = PreloadingStatus.fromString(statusStr); + + return new PrefetchStatusUpdatedParameters(context, url, status); + } + + public String getContext() { + return context; + } + + public String getUrl() { + return url; + } + + public PreloadingStatus getStatus() { + return status; + } +} diff --git a/java/src/org/openqa/selenium/bidi/speculation/PreloadingStatus.java b/java/src/org/openqa/selenium/bidi/speculation/PreloadingStatus.java new file mode 100644 index 0000000000000..334ed7f58b476 --- /dev/null +++ b/java/src/org/openqa/selenium/bidi/speculation/PreloadingStatus.java @@ -0,0 +1,45 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.openqa.selenium.bidi.speculation; + +public enum PreloadingStatus { + PENDING("pending"), + READY("ready"), + SUCCESS("success"), + FAILURE("failure"); + + private final String status; + + PreloadingStatus(String status) { + this.status = status; + } + + @Override + public String toString() { + return status; + } + + public static PreloadingStatus fromString(String status) { + for (PreloadingStatus s : PreloadingStatus.values()) { + if (s.status.equalsIgnoreCase(status)) { + return s; + } + } + throw new IllegalArgumentException("Unknown preloading status: " + status); + } +} diff --git a/java/src/org/openqa/selenium/bidi/speculation/Speculation.java b/java/src/org/openqa/selenium/bidi/speculation/Speculation.java new file mode 100644 index 0000000000000..b4b79e4049940 --- /dev/null +++ b/java/src/org/openqa/selenium/bidi/speculation/Speculation.java @@ -0,0 +1,28 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.openqa.selenium.bidi.speculation; + +import org.openqa.selenium.bidi.Event; + +public class Speculation { + public static Event prefetchStatusUpdated() { + return new Event<>( + "speculation.prefetchStatusUpdated", + params -> PrefetchStatusUpdatedParameters.fromJson(params)); + } +} diff --git a/java/src/org/openqa/selenium/bidi/speculation/package-info.java b/java/src/org/openqa/selenium/bidi/speculation/package-info.java new file mode 100644 index 0000000000000..a613fd23f8b90 --- /dev/null +++ b/java/src/org/openqa/selenium/bidi/speculation/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.bidi.speculation; + +import org.jspecify.annotations.NullMarked; diff --git a/java/test/org/openqa/selenium/bidi/speculation/BUILD.bazel b/java/test/org/openqa/selenium/bidi/speculation/BUILD.bazel new file mode 100644 index 0000000000000..5d579439ef43a --- /dev/null +++ b/java/test/org/openqa/selenium/bidi/speculation/BUILD.bazel @@ -0,0 +1,31 @@ +load("@rules_jvm_external//:defs.bzl", "artifact") +load("//java:defs.bzl", "BIDI_BROWSERS", "JUNIT5_DEPS", "java_selenium_test_suite") + +java_selenium_test_suite( + name = "large-tests", + size = "large", + srcs = glob(["*Test.java"]), + browsers = BIDI_BROWSERS, + tags = [ + "selenium-remote", + ], + deps = [ + "//java/src/org/openqa/selenium/bidi", + "//java/src/org/openqa/selenium/bidi/browsingcontext", + "//java/src/org/openqa/selenium/bidi/log", + "//java/src/org/openqa/selenium/bidi/module", + "//java/src/org/openqa/selenium/bidi/script", + "//java/src/org/openqa/selenium/bidi/speculation", + "//java/src/org/openqa/selenium/firefox", + "//java/src/org/openqa/selenium/grid/security", + "//java/src/org/openqa/selenium/json", + "//java/src/org/openqa/selenium/remote", + "//java/src/org/openqa/selenium/support", + "//java/test/org/openqa/selenium/environment", + "//java/test/org/openqa/selenium/testing:annotations", + "//java/test/org/openqa/selenium/testing:test-base", + "//java/test/org/openqa/selenium/testing/drivers", + artifact("org.junit.jupiter:junit-jupiter-api"), + artifact("org.assertj:assertj-core"), + ] + JUNIT5_DEPS, +) diff --git a/java/test/org/openqa/selenium/bidi/speculation/SpeculationInspectorTest.java b/java/test/org/openqa/selenium/bidi/speculation/SpeculationInspectorTest.java new file mode 100644 index 0000000000000..459a3e058bbe3 --- /dev/null +++ b/java/test/org/openqa/selenium/bidi/speculation/SpeculationInspectorTest.java @@ -0,0 +1,273 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.openqa.selenium.bidi.speculation; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.openqa.selenium.testing.drivers.Browser.FIREFOX; +import static org.openqa.selenium.testing.drivers.Browser.SAFARI; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.openqa.selenium.bidi.module.Script; +import org.openqa.selenium.bidi.module.SpeculationInspector; +import org.openqa.selenium.testing.JupiterTestBase; +import org.openqa.selenium.testing.NeedsFreshDriver; +import org.openqa.selenium.testing.NotYetImplemented; + +class SpeculationInspectorTest extends JupiterTestBase { + + private Script script; + private SpeculationInspector speculationInspector; + + @BeforeEach + public void setUp() { + script = new Script(driver); + speculationInspector = new SpeculationInspector(driver); + } + + @AfterEach + public void cleanUp() { + if (speculationInspector != null) { + speculationInspector.close(); + } + if (script != null) { + script.close(); + } + } + + void addSpeculationRulesAndLink(String rules, String href, String linkText, String linkId) { + String functionDeclaration = + String.format( + "() => {" + + "const script = document.createElement('script');" + + "script.type = 'speculationrules';" + + "script.textContent = `%s`;" + + "document.head.appendChild(script);" + + "const link = document.createElement('a');" + + "link.href = '%s';" + + "link.textContent = '%s';" + + "link.id = '%s';" + + "document.body.appendChild(link);" + + "}", + rules, href, linkText, linkId); + + script.callFunctionInBrowsingContext( + driver.getWindowHandle(), + functionDeclaration, + false, + Optional.empty(), + Optional.empty(), + Optional.empty()); + } + + @Test + @NeedsFreshDriver + @NotYetImplemented(FIREFOX) + @NotYetImplemented(SAFARI) + void canListenToPrefetchStatusUpdatedWithPendingAndReadyEvents() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + List events = new ArrayList<>(); + + speculationInspector.onPrefetchStatusUpdated( + event -> { + events.add(event); + latch.countDown(); + }); + + String testUrl = appServer.whereIs("/common/blank.html"); + driver.get(testUrl); + + String prefetchTarget = appServer.whereIs("/common/dummy.xml"); + String speculationRules = + String.format( + "{\"prefetch\": [{\"source\": \"list\", \"urls\": [\"%s\"]}]}", prefetchTarget); + + addSpeculationRulesAndLink(speculationRules, prefetchTarget, "Test Link", "prefetch-page"); + + // Wait for at least one prefetch event + latch.await(5, TimeUnit.SECONDS); + + // Verify we got at least one event + assertThat(events).hasSizeGreaterThanOrEqualTo(1); + + PrefetchStatusUpdatedParameters firstEvent = events.get(0); + assertThat(firstEvent.getUrl()).isEqualTo(prefetchTarget); + assertThat(firstEvent.getContext()).isEqualTo(driver.getWindowHandle()); + assertThat(firstEvent.getStatus()).isNotNull(); + } + + @Test + @NeedsFreshDriver + @NotYetImplemented(FIREFOX) + @NotYetImplemented(SAFARI) + void canListenToPrefetchStatusUpdatedWithNavigationAndSuccess() + throws ExecutionException, InterruptedException, TimeoutException { + CountDownLatch latch = new CountDownLatch(1); + List events = new ArrayList<>(); + + speculationInspector.onPrefetchStatusUpdated( + event -> { + events.add(event); + latch.countDown(); + }); + + String testUrl = appServer.whereIs("/common/blank.html"); + driver.get(testUrl); + + String prefetchTarget = appServer.whereIs("/common/dummy.xml"); + String speculationRules = + String.format( + "{\"prefetch\": [{\"source\": \"list\", \"urls\": [\"%s\"]}]}", prefetchTarget); + + addSpeculationRulesAndLink(speculationRules, prefetchTarget, "Test Link", "prefetch-page"); + + // Wait for prefetch event + latch.await(5, TimeUnit.SECONDS); + + assertThat(events).hasSizeGreaterThanOrEqualTo(1); + + // Verify first event + assertThat(events.get(0).getUrl()).isEqualTo(prefetchTarget); + assertThat(events.get(0).getContext()).isEqualTo(driver.getWindowHandle()); + + // If prefetch succeeded (status is READY), proceed with success test; otherwise skip + if (events.stream().noneMatch(e -> e.getStatus() == PreloadingStatus.READY)) { + // Prefetch didn't succeed, likely due to Chrome's restrictions + return; + } + + // Set up for success event + CompletableFuture successFuture = new CompletableFuture<>(); + speculationInspector.onPrefetchStatusUpdated( + event -> { + if (event.getStatus() == PreloadingStatus.SUCCESS) { + successFuture.complete(event); + } + }); + + // Navigate to the prefetched page by clicking the link + script.callFunctionInBrowsingContext( + driver.getWindowHandle(), + "() => { const link = document.getElementById('prefetch-page'); if (link) { link.click(); }" + + " }", + false, + Optional.empty(), + Optional.empty(), + Optional.empty()); + + // Wait for success event + PrefetchStatusUpdatedParameters successEvent = successFuture.get(5, TimeUnit.SECONDS); + assertThat(successEvent.getUrl()).isEqualTo(prefetchTarget); + assertThat(successEvent.getStatus()).isEqualTo(PreloadingStatus.SUCCESS); + assertThat(successEvent.getContext()).isEqualTo(driver.getWindowHandle()); + } + + @Test + @NeedsFreshDriver + @NotYetImplemented(FIREFOX) + @NotYetImplemented(SAFARI) + void canListenToPrefetchStatusUpdatedWithFailureEvents() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + List events = new ArrayList<>(); + + speculationInspector.onPrefetchStatusUpdated( + event -> { + events.add(event); + latch.countDown(); + }); + + String testUrl = appServer.whereIs("/common/blank.html"); + driver.get(testUrl); + + // Use a non-existent path that will return 404 + String failedTarget = appServer.whereIs("/nonexistent/path/that/will/404.xml"); + String speculationRules = + String.format("{\"prefetch\": [{\"source\": \"list\", \"urls\": [\"%s\"]}]}", failedTarget); + + addSpeculationRulesAndLink(speculationRules, failedTarget, "Test Link", "prefetch-page"); + + // Wait for event + latch.await(5, TimeUnit.SECONDS); + + // Verify we got at least one event + assertThat(events).hasSizeGreaterThanOrEqualTo(1); + + PrefetchStatusUpdatedParameters firstEvent = events.get(0); + assertThat(firstEvent.getUrl()).isEqualTo(failedTarget); + assertThat(firstEvent.getContext()).isEqualTo(driver.getWindowHandle()); + // Verify status is either PENDING or FAILURE + assertThat(firstEvent.getStatus()).isIn(PreloadingStatus.PENDING, PreloadingStatus.FAILURE); + } + + @Test + @NeedsFreshDriver + @NotYetImplemented(FIREFOX) + @NotYetImplemented(SAFARI) + void canUnsubscribeFromPrefetchStatusUpdated() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + List events = new ArrayList<>(); + + long subscriptionId = + speculationInspector.onPrefetchStatusUpdated( + event -> { + events.add(event); + latch.countDown(); + }); + + String testUrl = appServer.whereIs("/common/blank.html"); + driver.get(testUrl); + + String prefetchTarget = appServer.whereIs("/common/dummy.xml"); + String speculationRules = + String.format( + "{\"prefetch\": [{\"source\": \"list\", \"urls\": [\"%s\"]}]}", prefetchTarget); + + addSpeculationRulesAndLink(speculationRules, prefetchTarget, "Test Link", "prefetch-page"); + + // Wait for events to be emitted + latch.await(5, TimeUnit.SECONDS); + assertThat(events).hasSizeGreaterThanOrEqualTo(1); + + // Unsubscribe + speculationInspector.removeListener(subscriptionId); + + // Clear events and reload + events.clear(); + driver.get(testUrl); + + String prefetchTarget2 = appServer.whereIs("/common/square.png"); + String speculationRules2 = + String.format( + "{\"prefetch\": [{\"source\": \"list\", \"urls\": [\"%s\"]}]}", prefetchTarget2); + + addSpeculationRulesAndLink( + speculationRules2, prefetchTarget2, "Test Link 2", "prefetch-page-2"); + + // Verify no events are emitted after unsubscribing + assertThat(events).isEmpty(); + } +} From 0a348c40c861af31262012bada59e861e3681a93 Mon Sep 17 00:00:00 2001 From: Selenium CI Bot Date: Thu, 5 Mar 2026 01:51:32 +0100 Subject: [PATCH 44/67] [dotnet][rb][java][js][py] Automated Browser Version Update (#17176) Update pinned browser versions Co-authored-by: Selenium CI Bot --- common/repositories.bzl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/common/repositories.bzl b/common/repositories.bzl index 7d82fd5b86cf1..d659c8af5d1a5 100644 --- a/common/repositories.bzl +++ b/common/repositories.bzl @@ -50,8 +50,8 @@ js_library( http_archive( name = "linux_beta_firefox", - url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b3/linux-x86_64/en-US/firefox-149.0b3.tar.xz", - sha256 = "9c69d673378aebcbc4bc5f8dd4becc31cf53d9f604de13f4aa4d3a6dded08475", + url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b4/linux-x86_64/en-US/firefox-149.0b4.tar.xz", + sha256 = "57539691f99f49124487846e1ed30b2c3c1e5413a68a4838644d9f35fcd22ec6", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -72,8 +72,8 @@ js_library( dmg_archive( name = "mac_beta_firefox", - url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b3/mac/en-US/Firefox%20149.0b3.dmg", - sha256 = "7872c9b4c98d8bfe818fd34e4f14bc55bf797c5b55842623b2ba28c6ac102226", + url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b4/mac/en-US/Firefox%20149.0b4.dmg", + sha256 = "69b90c44f1757f0ba0799e82767e38bac8cad4762de76113a14237190f9c4212", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) From eae2652c3f9ffbb6d3036f0eaf1ee2583b1bfede Mon Sep 17 00:00:00 2001 From: Andrei Solntsev Date: Thu, 5 Mar 2026 09:25:02 +0200 Subject: [PATCH 45/67] [java] specify nullability in `org.openqa.selenium.grid.*` packages (#17173) * remove occasionally copy-pasted javadoc * specify nullability in package `org.openqa.selenium.grid` --- .../openqa/selenium/MutableCapabilities.java | 2 +- .../selenium/SessionNotCreatedException.java | 6 +- java/src/org/openqa/selenium/cli/BUILD.bazel | 1 + .../org/openqa/selenium/cli/package-info.java | 3 + .../openqa/selenium/edge/package-info.java | 29 ------ .../package-info.java | 29 ------ java/src/org/openqa/selenium/grid/BUILD.bazel | 2 + .../grid/TemplateGridServerCommand.java | 17 ++-- .../selenium/grid/commands/InfoCommand.java | 2 - .../selenium/grid/commands/package-info.java | 21 ++++ .../selenium/grid/component/BUILD.bazel | 2 + .../selenium/grid/component/package-info.java | 21 ++++ .../selenium/grid/config/AnnotatedConfig.java | 4 +- .../openqa/selenium/grid/config/BUILD.bazel | 1 + .../grid/config/ConcatenatingConfig.java | 3 +- .../openqa/selenium/grid/config/Config.java | 3 +- .../selenium/grid/config/ConfigException.java | 4 +- .../selenium/grid/config/MemoizedConfig.java | 5 +- .../selenium/grid/config/package-info.java | 21 ++++ .../org/openqa/selenium/grid/data/BUILD.bazel | 2 + .../selenium/grid/data/CapabilityCount.java | 1 + .../grid/data/CreateSessionRequest.java | 1 + .../grid/data/CreateSessionResponse.java | 1 + .../selenium/grid/data/DistributorStatus.java | 1 + .../grid/data/NewSessionErrorResponse.java | 1 + .../grid/data/NewSessionResponse.java | 1 + .../org/openqa/selenium/grid/data/NodeId.java | 1 + .../openqa/selenium/grid/data/NodeStatus.java | 3 +- .../openqa/selenium/grid/data/RequestId.java | 1 + .../openqa/selenium/grid/data/Session.java | 3 +- .../selenium/grid/data/SessionClosedData.java | 25 +++-- .../grid/data/SessionCreatedData.java | 1 + .../selenium/grid/data/SessionEventData.java | 16 ++-- .../grid/data/SessionRemovalInfo.java | 5 +- .../selenium/grid/data/SessionRequest.java | 1 + .../grid/data/SessionRequestCapability.java | 8 +- .../org/openqa/selenium/grid/data/Slot.java | 11 ++- .../org/openqa/selenium/grid/data/SlotId.java | 1 + .../selenium/grid/data/package-info.java | 21 ++++ .../selenium/grid/distributor/AddNode.java | 4 +- .../selenium/grid/distributor/BUILD.bazel | 1 + .../selenium/grid/distributor/GridModel.java | 3 +- .../grid/distributor/NodeRegistry.java | 5 +- .../grid/distributor/config/BUILD.bazel | 1 + .../grid/distributor/config/package-info.java | 21 ++++ .../grid/distributor/httpd/BUILD.bazel | 1 + .../grid/distributor/httpd/package-info.java | 21 ++++ .../grid/distributor/local/BUILD.bazel | 1 + .../distributor/local/LocalDistributor.java | 3 + .../distributor/local/LocalGridModel.java | 8 +- .../distributor/local/LocalNodeRegistry.java | 4 +- .../grid/distributor/local/package-info.java | 21 ++++ .../grid/distributor/package-info.java | 3 + .../grid/distributor/remote/BUILD.bazel | 2 + .../grid/distributor/remote/package-info.java | 21 ++++ .../grid/distributor/selector/BUILD.bazel | 1 + .../distributor/selector/package-info.java | 21 ++++ .../openqa/selenium/grid/graphql/BUILD.bazel | 1 + .../selenium/grid/graphql/SessionData.java | 2 + .../openqa/selenium/grid/graphql/Types.java | 8 +- .../selenium/grid/graphql/package-info.java | 21 ++++ .../openqa/selenium/grid/jmx/JMXHelper.java | 4 +- .../org/openqa/selenium/grid/jmx/MBean.java | 4 +- .../selenium/grid/jmx/package-info.java | 21 ++++ .../org/openqa/selenium/grid/log/BUILD.bazel | 1 + .../selenium/grid/log/FlushingHandler.java | 3 +- .../selenium/grid/log/LoggingOptions.java | 4 +- .../selenium/grid/log/package-info.java | 21 ++++ .../org/openqa/selenium/grid/node/Node.java | 3 + .../selenium/grid/node/config/BUILD.bazel | 1 + .../config/DriverServiceSessionFactory.java | 2 + .../grid/node/config/NodeOptions.java | 2 +- .../config/SessionCapabilitiesMutator.java | 9 +- .../grid/node/config/package-info.java | 21 ++++ .../selenium/grid/node/docker/BUILD.bazel | 1 + .../grid/node/docker/DockerOptions.java | 4 + .../grid/node/docker/DockerSession.java | 5 +- .../node/docker/DockerSessionFactory.java | 18 ++-- .../grid/node/docker/package-info.java | 21 ++++ .../selenium/grid/node/httpd/BUILD.bazel | 1 + .../selenium/grid/node/httpd/NodeServer.java | 5 +- .../grid/node/httpd/package-info.java | 21 ++++ .../openqa/selenium/grid/node/k8s/BUILD.bazel | 1 + .../selenium/grid/node/k8s/OneShotNode.java | 12 ++- .../selenium/grid/node/k8s/package-info.java | 21 ++++ .../selenium/grid/node/kubernetes/BUILD.bazel | 1 + .../node/kubernetes/InheritedPodSpec.java | 95 +++++++++++-------- .../node/kubernetes/KubernetesOptions.java | 17 +++- .../node/kubernetes/KubernetesSession.java | 13 +-- .../kubernetes/KubernetesSessionFactory.java | 58 ++++++----- .../grid/node/kubernetes/package-info.java | 21 ++++ .../selenium/grid/node/local/BUILD.bazel | 1 + .../selenium/grid/node/local/LocalNode.java | 8 +- .../selenium/grid/node/local/SessionSlot.java | 6 +- .../grid/node/local/package-info.java | 21 ++++ .../selenium/grid/node/package-info.java | 21 ++++ .../selenium/grid/node/relay/BUILD.bazel | 1 + .../grid/node/relay/RelayOptions.java | 5 +- .../grid/node/relay/RelaySessionFactory.java | 13 +-- .../grid/node/relay/package-info.java | 21 ++++ .../selenium/grid/node/remote/BUILD.bazel | 1 + .../selenium/grid/node/remote/RemoteNode.java | 4 +- .../grid/node/remote/package-info.java | 21 ++++ .../openqa/selenium/grid/package-info.java | 6 +- .../openqa/selenium/grid/router/BUILD.bazel | 1 + .../selenium/grid/router/httpd/BUILD.bazel | 1 + .../grid/router/httpd/package-info.java | 21 ++++ .../selenium/grid/router/package-info.java | 21 ++++ .../openqa/selenium/grid/security/BUILD.bazel | 1 + .../security/BasicAuthenticationFilter.java | 3 +- .../selenium/grid/security/SecretOptions.java | 2 + .../selenium/grid/security/package-info.java | 21 ++++ .../openqa/selenium/grid/server/BUILD.bazel | 1 + .../grid/server/BaseServerOptions.java | 3 +- .../selenium/grid/server/EventBusOptions.java | 3 +- .../selenium/grid/server/package-info.java | 21 ++++ .../openqa/selenium/grid/session/BUILD.bazel | 1 + .../selenium/grid/session/package-info.java | 21 ++++ .../selenium/grid/sessionmap/BUILD.bazel | 1 + .../grid/sessionmap/config/BUILD.bazel | 1 + .../grid/sessionmap/config/package-info.java | 21 ++++ .../grid/sessionmap/httpd/BUILD.bazel | 1 + .../grid/sessionmap/httpd/package-info.java | 21 ++++ .../sessionmap/jdbc/JdbcBackedSessionMap.java | 13 +-- .../grid/sessionmap/jdbc/package-info.java | 21 ++++ .../grid/sessionmap/local/BUILD.bazel | 1 + .../sessionmap/local/LocalSessionMap.java | 3 + .../grid/sessionmap/local/package-info.java | 21 ++++ .../grid/sessionmap/package-info.java | 21 ++++ .../grid/sessionmap/redis/BUILD.bazel | 1 + .../redis/RedisBackedSessionMap.java | 8 +- .../grid/sessionmap/redis/package-info.java | 21 ++++ .../grid/sessionmap/remote/BUILD.bazel | 2 + .../sessionmap/remote/RemoteSessionMap.java | 5 +- .../grid/sessionmap/remote/package-info.java | 21 ++++ .../selenium/grid/sessionqueue/BUILD.bazel | 1 + .../grid/sessionqueue/config/BUILD.bazel | 1 + .../sessionqueue/config/package-info.java | 21 ++++ .../grid/sessionqueue/httpd/BUILD.bazel | 1 + .../grid/sessionqueue/httpd/package-info.java | 21 ++++ .../grid/sessionqueue/local/BUILD.bazel | 1 + .../local/LocalNewSessionQueue.java | 4 +- .../grid/sessionqueue/local/package-info.java | 21 ++++ .../grid/sessionqueue/package-info.java | 21 ++++ .../grid/sessionqueue/remote/BUILD.bazel | 1 + .../sessionqueue/remote/package-info.java | 21 ++++ .../org/openqa/selenium/grid/web/BUILD.bazel | 1 + .../selenium/grid/web/MergedResource.java | 3 +- .../openqa/selenium/grid/web/NoHandler.java | 3 +- .../org/openqa/selenium/grid/web/Values.java | 2 +- .../selenium/grid/web/package-info.java | 21 ++++ .../selenium/ie/InternetExplorerOptions.java | 8 +- .../org/openqa/selenium/ie/package-info.java | 29 ------ .../org/openqa/selenium/internal/Debug.java | 3 +- .../org/openqa/selenium/manager/BUILD.bazel | 1 + .../selenium/manager/SeleniumManager.java | 9 +- .../manager/SeleniumManagerOutput.java | 22 +++-- .../openqa/selenium/manager/package-info.java | 21 ++++ .../org/openqa/selenium/net/UrlChecker.java | 5 +- .../openqa/selenium/print/package-info.java | 29 ------ .../remote/AbstractDriverOptions.java | 29 ++++-- .../openqa/selenium/remote/RemoteLogs.java | 24 ++--- .../openqa/selenium/safari/package-info.java | 29 ------ .../openqa/selenium/grid/config/BUILD.bazel | 1 + .../selenium/grid/config/package-info.java | 21 ++++ .../org/openqa/selenium/grid/data/BUILD.bazel | 1 + .../grid/data/DefaultSlotMatcherTest.java | 1 + .../selenium/grid/data/package-info.java | 21 ++++ .../grid/distributor/AddingNodesTest.java | 5 +- .../selenium/grid/distributor/BUILD.bazel | 1 + .../grid/distributor/local/BUILD.bazel | 1 + .../local/LocalNodeRegistryTest.java | 1 + .../grid/distributor/local/package-info.java | 21 ++++ .../grid/distributor/package-info.java | 21 ++++ .../grid/distributor/selector/BUILD.bazel | 1 + .../distributor/selector/package-info.java | 21 ++++ .../openqa/selenium/grid/graphql/BUILD.bazel | 1 + .../selenium/grid/graphql/package-info.java | 21 ++++ .../openqa/selenium/grid/gridui/BUILD.bazel | 1 + .../selenium/grid/gridui/package-info.java | 21 ++++ .../org/openqa/selenium/grid/node/BUILD.bazel | 1 + .../grid/node/ProxyNodeWebsocketsTest.java | 8 +- .../selenium/grid/node/config/BUILD.bazel | 2 + .../SessionCapabilitiesMutatorTest.java | 54 ++++++----- .../grid/node/config/package-info.java | 21 ++++ .../selenium/grid/node/data/BUILD.bazel | 2 + .../selenium/grid/node/data/package-info.java | 21 ++++ .../selenium/grid/node/docker/BUILD.bazel | 1 + .../grid/node/docker/package-info.java | 21 ++++ .../selenium/grid/node/kubernetes/BUILD.bazel | 1 + .../grid/node/kubernetes/package-info.java | 21 ++++ .../selenium/grid/node/local/BUILD.bazel | 1 + .../grid/node/local/package-info.java | 21 ++++ .../selenium/grid/node/package-info.java | 21 ++++ .../selenium/grid/node/relay/BUILD.bazel | 1 + .../grid/node/relay/package-info.java | 21 ++++ .../openqa/selenium/grid/router/BUILD.bazel | 5 + .../grid/router/RemoteWebDriverBiDiTest.java | 1 + .../grid/router/SessionCleanUpTest.java | 1 + .../selenium/grid/router/package-info.java | 21 ++++ .../openqa/selenium/grid/security/BUILD.bazel | 1 + .../selenium/grid/security/package-info.java | 21 ++++ .../openqa/selenium/grid/server/BUILD.bazel | 1 + .../selenium/grid/server/package-info.java | 21 ++++ .../selenium/remote/RemoteLogsTest.java | 2 +- 205 files changed, 1689 insertions(+), 416 deletions(-) create mode 100644 java/src/org/openqa/selenium/grid/commands/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/component/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/config/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/data/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/distributor/config/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/distributor/httpd/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/distributor/local/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/distributor/remote/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/distributor/selector/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/graphql/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/jmx/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/log/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/node/config/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/node/docker/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/node/httpd/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/node/k8s/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/node/kubernetes/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/node/local/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/node/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/node/relay/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/node/remote/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/router/httpd/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/router/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/security/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/server/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/session/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/sessionmap/config/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/sessionmap/httpd/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/sessionmap/jdbc/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/sessionmap/local/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/sessionmap/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/sessionmap/redis/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/sessionmap/remote/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/sessionqueue/config/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/sessionqueue/httpd/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/sessionqueue/local/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/sessionqueue/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/sessionqueue/remote/package-info.java create mode 100644 java/src/org/openqa/selenium/grid/web/package-info.java create mode 100644 java/src/org/openqa/selenium/manager/package-info.java create mode 100644 java/test/org/openqa/selenium/grid/config/package-info.java create mode 100644 java/test/org/openqa/selenium/grid/data/package-info.java create mode 100644 java/test/org/openqa/selenium/grid/distributor/local/package-info.java create mode 100644 java/test/org/openqa/selenium/grid/distributor/package-info.java create mode 100644 java/test/org/openqa/selenium/grid/distributor/selector/package-info.java create mode 100644 java/test/org/openqa/selenium/grid/graphql/package-info.java create mode 100644 java/test/org/openqa/selenium/grid/gridui/package-info.java create mode 100644 java/test/org/openqa/selenium/grid/node/config/package-info.java create mode 100644 java/test/org/openqa/selenium/grid/node/data/package-info.java create mode 100644 java/test/org/openqa/selenium/grid/node/docker/package-info.java create mode 100644 java/test/org/openqa/selenium/grid/node/kubernetes/package-info.java create mode 100644 java/test/org/openqa/selenium/grid/node/local/package-info.java create mode 100644 java/test/org/openqa/selenium/grid/node/package-info.java create mode 100644 java/test/org/openqa/selenium/grid/node/relay/package-info.java create mode 100644 java/test/org/openqa/selenium/grid/router/package-info.java create mode 100644 java/test/org/openqa/selenium/grid/security/package-info.java create mode 100644 java/test/org/openqa/selenium/grid/server/package-info.java diff --git a/java/src/org/openqa/selenium/MutableCapabilities.java b/java/src/org/openqa/selenium/MutableCapabilities.java index 5d9081eeb66d7..37d68ffceb868 100644 --- a/java/src/org/openqa/selenium/MutableCapabilities.java +++ b/java/src/org/openqa/selenium/MutableCapabilities.java @@ -74,7 +74,7 @@ public void setCapability(String capabilityName, boolean value) { setCapability(capabilityName, (Object) value); } - public void setCapability(String capabilityName, String value) { + public void setCapability(String capabilityName, @Nullable String value) { setCapability(capabilityName, (Object) value); } diff --git a/java/src/org/openqa/selenium/SessionNotCreatedException.java b/java/src/org/openqa/selenium/SessionNotCreatedException.java index 30bb9d6c2ab21..47e397dab835f 100644 --- a/java/src/org/openqa/selenium/SessionNotCreatedException.java +++ b/java/src/org/openqa/selenium/SessionNotCreatedException.java @@ -17,16 +17,18 @@ package org.openqa.selenium; +import org.jspecify.annotations.Nullable; + /** Indicates that a session could not be created. */ public class SessionNotCreatedException extends WebDriverException { - public SessionNotCreatedException(String msg) { + public SessionNotCreatedException(@Nullable String msg) { super( "Could not start a new session. " + msg + (msg != null && msg.contains("Host info") ? "" : " \n" + getHostInformation())); } - public SessionNotCreatedException(String msg, Throwable cause) { + public SessionNotCreatedException(@Nullable String msg, @Nullable Throwable cause) { super( "Could not start a new session. " + msg diff --git a/java/src/org/openqa/selenium/cli/BUILD.bazel b/java/src/org/openqa/selenium/cli/BUILD.bazel index 663b7e81ac101..1a621c3b5000e 100644 --- a/java/src/org/openqa/selenium/cli/BUILD.bazel +++ b/java/src/org/openqa/selenium/cli/BUILD.bazel @@ -9,5 +9,6 @@ java_library( deps = [ "//java/src/org/openqa/selenium:core", "//java/src/org/openqa/selenium/grid/config", + "@maven//:org_jspecify_jspecify", ], ) diff --git a/java/src/org/openqa/selenium/cli/package-info.java b/java/src/org/openqa/selenium/cli/package-info.java index 2e3f8d91a4c4a..68be944df3d78 100644 --- a/java/src/org/openqa/selenium/cli/package-info.java +++ b/java/src/org/openqa/selenium/cli/package-info.java @@ -44,4 +44,7 @@ *

Ultimately, this means that flag objects have all (most?) fields annotated with JCommander's * {@link com.beust.jcommander.Parameter} annotation as well as {@code ConfigValue}. */ +@NullMarked package org.openqa.selenium.cli; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/edge/package-info.java b/java/src/org/openqa/selenium/edge/package-info.java index 0ecb0d05d4dd3..5f1ac3fda669a 100644 --- a/java/src/org/openqa/selenium/edge/package-info.java +++ b/java/src/org/openqa/selenium/edge/package-info.java @@ -15,35 +15,6 @@ // specific language governing permissions and limitations // under the License. -/** - * Mechanisms to configure and run selenium via the command line. There are two key classes {@link - * org.openqa.selenium.cli.CliCommand} and {@link org.openqa.selenium.grid.config.HasRoles}. - * Ultimately, these are used to build a {@link org.openqa.selenium.grid.config.Config} instance, - * for which there are strongly-typed role-specific classes that use a {@code Config}, such as - * {@link org.openqa.selenium.grid.node.docker.DockerOptions}. - * - *

Assuming your {@code CliCommand} extends {@link org.openqa.selenium.grid.TemplateGridCommand}, - * the process for building the set of flags to use is: - * - *

    - *
  1. The default flags are added (these are {@link org.openqa.selenium.grid.server.HelpFlags} - * and {@link org.openqa.selenium.grid.config.ConfigFlags} - *
  2. {@link java.util.ServiceLoader} is used to find all implementations of {@link - * org.openqa.selenium.grid.config.HasRoles} where {@link - * org.openqa.selenium.grid.config.HasRoles#getRoles()} is contained within {@link - * org.openqa.selenium.cli.CliCommand#getConfigurableRoles()}. - *
  3. Finally all flags returned by {@link - * org.openqa.selenium.grid.TemplateGridCommand#getFlagObjects()} are added. - *
- * - *

The flags are then used by JCommander to parse the command arguments. Once that's done, the - * raw flags are converted to a {@link org.openqa.selenium.grid.config.Config} by combining all of - * the flag objects with system properties and environment variables. This implies that each flag - * object has annotated each field with {@link org.openqa.selenium.grid.config.ConfigValue}. - * - *

Ultimately, this means that flag objects have all (most?) fields annotated with JCommander's - * {@link com.beust.jcommander.Parameter} annotation as well as {@code ConfigValue}. - */ @NullMarked package org.openqa.selenium.edge; diff --git a/java/src/org/openqa/selenium/federatedcredentialmanagement/package-info.java b/java/src/org/openqa/selenium/federatedcredentialmanagement/package-info.java index 359e361a886f5..f2a1ac6f33779 100644 --- a/java/src/org/openqa/selenium/federatedcredentialmanagement/package-info.java +++ b/java/src/org/openqa/selenium/federatedcredentialmanagement/package-info.java @@ -15,35 +15,6 @@ // specific language governing permissions and limitations // under the License. -/** - * Mechanisms to configure and run selenium via the command line. There are two key classes {@link - * org.openqa.selenium.cli.CliCommand} and {@link org.openqa.selenium.grid.config.HasRoles}. - * Ultimately, these are used to build a {@link org.openqa.selenium.grid.config.Config} instance, - * for which there are strongly-typed role-specific classes that use a {@code Config}, such as - * {@link org.openqa.selenium.grid.node.docker.DockerOptions}. - * - *

Assuming your {@code CliCommand} extends {@link org.openqa.selenium.grid.TemplateGridCommand}, - * the process for building the set of flags to use is: - * - *

    - *
  1. The default flags are added (these are {@link org.openqa.selenium.grid.server.HelpFlags} - * and {@link org.openqa.selenium.grid.config.ConfigFlags} - *
  2. {@link java.util.ServiceLoader} is used to find all implementations of {@link - * org.openqa.selenium.grid.config.HasRoles} where {@link - * org.openqa.selenium.grid.config.HasRoles#getRoles()} is contained within {@link - * org.openqa.selenium.cli.CliCommand#getConfigurableRoles()}. - *
  3. Finally all flags returned by {@link - * org.openqa.selenium.grid.TemplateGridCommand#getFlagObjects()} are added. - *
- * - *

The flags are then used by JCommander to parse the command arguments. Once that's done, the - * raw flags are converted to a {@link org.openqa.selenium.grid.config.Config} by combining all of - * the flag objects with system properties and environment variables. This implies that each flag - * object has annotated each field with {@link org.openqa.selenium.grid.config.ConfigValue}. - * - *

Ultimately, this means that flag objects have all (most?) fields annotated with JCommander's - * {@link com.beust.jcommander.Parameter} annotation as well as {@code ConfigValue}. - */ @NullMarked package org.openqa.selenium.federatedcredentialmanagement; diff --git a/java/src/org/openqa/selenium/grid/BUILD.bazel b/java/src/org/openqa/selenium/grid/BUILD.bazel index 2b4e66a98c449..dcc5b7f0e2c82 100644 --- a/java/src/org/openqa/selenium/grid/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/BUILD.bazel @@ -89,6 +89,7 @@ java_library( "//java/src/org/openqa/selenium/remote/http", artifact("com.beust:jcommander"), artifact("com.google.guava:guava"), + artifact("org.jspecify:jspecify"), ], ) @@ -171,6 +172,7 @@ java_export( ":base-command", "//java/src/org/openqa/selenium/cli", "//java/src/org/openqa/selenium/grid/config", + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/TemplateGridServerCommand.java b/java/src/org/openqa/selenium/grid/TemplateGridServerCommand.java index 0c8e82b4b692a..963b445d6435b 100644 --- a/java/src/org/openqa/selenium/grid/TemplateGridServerCommand.java +++ b/java/src/org/openqa/selenium/grid/TemplateGridServerCommand.java @@ -26,6 +26,7 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.grid.config.CompoundConfig; import org.openqa.selenium.grid.config.Config; import org.openqa.selenium.grid.config.MemoizedConfig; @@ -91,24 +92,26 @@ protected static Routable baseRoute(String prefix, Route route) { protected abstract Handlers createHandlers(Config config); - public abstract static class Handlers implements Closeable { + protected abstract static class Handlers implements Closeable { public final HttpHandler httpHandler; public final BiFunction, Optional>> websocketHandler; /** Optional resolver for direct TCP tunnel of WebSocket connections. May be null. */ - public final Function> tcpTunnelResolver; + public final @Nullable Function> tcpTunnelResolver; - public Handlers( + protected Handlers( HttpHandler http, - BiFunction, Optional>> websocketHandler) { + @Nullable BiFunction, Optional>> + websocketHandler) { this(http, websocketHandler, null); } - public Handlers( + protected Handlers( HttpHandler http, - BiFunction, Optional>> websocketHandler, - Function> tcpTunnelResolver) { + @Nullable BiFunction, Optional>> + websocketHandler, + @Nullable Function> tcpTunnelResolver) { this.httpHandler = Require.nonNull("HTTP handler", http); this.websocketHandler = websocketHandler == null ? (str, sink) -> Optional.empty() : websocketHandler; diff --git a/java/src/org/openqa/selenium/grid/commands/InfoCommand.java b/java/src/org/openqa/selenium/grid/commands/InfoCommand.java index 41db14a988552..3fcbf02e2df51 100644 --- a/java/src/org/openqa/selenium/grid/commands/InfoCommand.java +++ b/java/src/org/openqa/selenium/grid/commands/InfoCommand.java @@ -34,7 +34,6 @@ import java.util.Collections; import java.util.Set; import java.util.regex.Pattern; -import org.jspecify.annotations.NonNull; import org.openqa.selenium.cli.CliCommand; import org.openqa.selenium.cli.WrappedPrintWriter; import org.openqa.selenium.grid.config.Role; @@ -169,7 +168,6 @@ private String readContent(String path) throws IOException { return formattedText.toString(); } - @NonNull private String unformattedText(String path) throws IOException { try (InputStream in = getClass().getClassLoader().getResourceAsStream(path)) { requireNonNull(in, () -> "Resource is not found in classpath: " + path); diff --git a/java/src/org/openqa/selenium/grid/commands/package-info.java b/java/src/org/openqa/selenium/grid/commands/package-info.java new file mode 100644 index 0000000000000..d57f8ca1ce6b9 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/commands/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.commands; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/component/BUILD.bazel b/java/src/org/openqa/selenium/grid/component/BUILD.bazel index e4692451e6921..8160be880f20f 100644 --- a/java/src/org/openqa/selenium/grid/component/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/component/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_jvm_external//:defs.bzl", "artifact") load("//java:defs.bzl", "java_library") java_library( @@ -10,5 +11,6 @@ java_library( ], deps = [ "//java/src/org/openqa/selenium:core", + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/component/package-info.java b/java/src/org/openqa/selenium/grid/component/package-info.java new file mode 100644 index 0000000000000..eaa2be1630be9 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/component/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.component; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/config/AnnotatedConfig.java b/java/src/org/openqa/selenium/grid/config/AnnotatedConfig.java index f963f87436b7f..7476063aeab3d 100644 --- a/java/src/org/openqa/selenium/grid/config/AnnotatedConfig.java +++ b/java/src/org/openqa/selenium/grid/config/AnnotatedConfig.java @@ -35,6 +35,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.internal.Require; /** @@ -108,7 +109,8 @@ public AnnotatedConfig(Object obj, Set cliArgs, boolean includeCliArgs) this.config = values; } - private String getSingleValue(Object value) { + @Nullable + private String getSingleValue(@Nullable Object value) { if (value == null) { return null; } diff --git a/java/src/org/openqa/selenium/grid/config/BUILD.bazel b/java/src/org/openqa/selenium/grid/config/BUILD.bazel index e2be2a4b14de7..92344f484eb73 100644 --- a/java/src/org/openqa/selenium/grid/config/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/config/BUILD.bazel @@ -22,5 +22,6 @@ java_library( artifact("com.beust:jcommander"), artifact("com.google.guava:guava"), artifact("org.tomlj:tomlj"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/config/ConcatenatingConfig.java b/java/src/org/openqa/selenium/grid/config/ConcatenatingConfig.java index 56e7b6e72bbd2..6f9aa1691624f 100644 --- a/java/src/org/openqa/selenium/grid/config/ConcatenatingConfig.java +++ b/java/src/org/openqa/selenium/grid/config/ConcatenatingConfig.java @@ -26,6 +26,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.internal.Require; public class ConcatenatingConfig implements Config { @@ -34,7 +35,7 @@ public class ConcatenatingConfig implements Config { private final char separator; private final Map values; - public ConcatenatingConfig(String prefix, char separator, Map values) { + public ConcatenatingConfig(@Nullable String prefix, char separator, Map values) { this.prefix = prefix == null || prefix.isEmpty() ? "" : (prefix + separator); this.separator = separator; diff --git a/java/src/org/openqa/selenium/grid/config/Config.java b/java/src/org/openqa/selenium/grid/config/Config.java index c4eb7e2263519..16424bd5a7aa1 100644 --- a/java/src/org/openqa/selenium/grid/config/Config.java +++ b/java/src/org/openqa/selenium/grid/config/Config.java @@ -24,6 +24,7 @@ import java.util.Set; import java.util.logging.Logger; import java.util.stream.Collectors; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.json.Json; public interface Config { @@ -34,7 +35,7 @@ public interface Config { Optional> getAll(String section, String option); - default Optional get(String section, String option) { + default Optional<@Nullable String> get(String section, String option) { return getAll(section, option).map(items -> items.isEmpty() ? null : items.get(0)); } diff --git a/java/src/org/openqa/selenium/grid/config/ConfigException.java b/java/src/org/openqa/selenium/grid/config/ConfigException.java index cf4a0bcd818a7..e396ca26ae671 100644 --- a/java/src/org/openqa/selenium/grid/config/ConfigException.java +++ b/java/src/org/openqa/selenium/grid/config/ConfigException.java @@ -17,13 +17,15 @@ package org.openqa.selenium.grid.config; +import org.jspecify.annotations.Nullable; + public class ConfigException extends RuntimeException { public ConfigException(String message, Object... args) { super(String.format(message, args)); } - public ConfigException(Throwable cause) { + public ConfigException(@Nullable Throwable cause) { super(cause); } } diff --git a/java/src/org/openqa/selenium/grid/config/MemoizedConfig.java b/java/src/org/openqa/selenium/grid/config/MemoizedConfig.java index a111b5af68759..34c8a34826c3a 100644 --- a/java/src/org/openqa/selenium/grid/config/MemoizedConfig.java +++ b/java/src/org/openqa/selenium/grid/config/MemoizedConfig.java @@ -24,6 +24,7 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.internal.Require; public class MemoizedConfig implements Config { @@ -61,7 +62,7 @@ public Optional> getAll(String section, String option) { } @Override - public Optional get(String section, String option) { + public Optional<@Nullable String> get(String section, String option) { Require.nonNull("Section name", section); Require.nonNull("Option", option); @@ -94,7 +95,7 @@ public X getClass(String section, String option, Class typeOfX, String de Require.nonNull("Type to load", typeOfX); Require.nonNull("Default class name", defaultClassName); - AtomicReference thrown = new AtomicReference<>(); + AtomicReference<@Nullable Exception> thrown = new AtomicReference<>(); Object value = seenClasses.computeIfAbsent( new Key(section, option, typeOfX.toGenericString(), defaultClassName), diff --git a/java/src/org/openqa/selenium/grid/config/package-info.java b/java/src/org/openqa/selenium/grid/config/package-info.java new file mode 100644 index 0000000000000..f1393763f3149 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/config/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.config; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/data/BUILD.bazel b/java/src/org/openqa/selenium/grid/data/BUILD.bazel index 761bf51d92992..175cf061f84eb 100644 --- a/java/src/org/openqa/selenium/grid/data/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/data/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_jvm_external//:defs.bzl", "artifact") load("//java:defs.bzl", "java_library") java_library( @@ -15,5 +16,6 @@ java_library( "//java/src/org/openqa/selenium/grid/security", "//java/src/org/openqa/selenium/json", "//java/src/org/openqa/selenium/remote", + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/data/CapabilityCount.java b/java/src/org/openqa/selenium/grid/data/CapabilityCount.java index f04af7680fc2c..4e2e877b8831a 100644 --- a/java/src/org/openqa/selenium/grid/data/CapabilityCount.java +++ b/java/src/org/openqa/selenium/grid/data/CapabilityCount.java @@ -60,6 +60,7 @@ private Object toJson() { UNORDERED)); } + @SuppressWarnings({"unused", "DataFlowIssue"}) private static CapabilityCount fromJson(JsonInput input) { Map toReturn = new HashMap<>(); diff --git a/java/src/org/openqa/selenium/grid/data/CreateSessionRequest.java b/java/src/org/openqa/selenium/grid/data/CreateSessionRequest.java index e89ba10f81886..3c811f91af481 100644 --- a/java/src/org/openqa/selenium/grid/data/CreateSessionRequest.java +++ b/java/src/org/openqa/selenium/grid/data/CreateSessionRequest.java @@ -56,6 +56,7 @@ public Map getMetadata() { return metadata; } + @SuppressWarnings({"unused", "DataFlowIssue"}) private static CreateSessionRequest fromJson(JsonInput input) { Set downstreamDialects = null; Capabilities capabilities = null; diff --git a/java/src/org/openqa/selenium/grid/data/CreateSessionResponse.java b/java/src/org/openqa/selenium/grid/data/CreateSessionResponse.java index 10a0e24c9f585..0cf73a231d77a 100644 --- a/java/src/org/openqa/selenium/grid/data/CreateSessionResponse.java +++ b/java/src/org/openqa/selenium/grid/data/CreateSessionResponse.java @@ -52,6 +52,7 @@ private Map toJson() { return unmodifiableMap(toReturn); } + @SuppressWarnings({"unused", "DataFlowIssue"}) private static CreateSessionResponse fromJson(JsonInput input) { Session session = null; byte[] downstreamResponse = null; diff --git a/java/src/org/openqa/selenium/grid/data/DistributorStatus.java b/java/src/org/openqa/selenium/grid/data/DistributorStatus.java index d989ac3c842d5..208d5d6bceba2 100644 --- a/java/src/org/openqa/selenium/grid/data/DistributorStatus.java +++ b/java/src/org/openqa/selenium/grid/data/DistributorStatus.java @@ -52,6 +52,7 @@ private Map toJson() { return Collections.singletonMap("nodes", getNodes()); } + @SuppressWarnings({"unused", "DataFlowIssue"}) private static DistributorStatus fromJson(JsonInput input) { Set nodes = null; diff --git a/java/src/org/openqa/selenium/grid/data/NewSessionErrorResponse.java b/java/src/org/openqa/selenium/grid/data/NewSessionErrorResponse.java index 1cfba61c8e111..a7e9dc363c750 100644 --- a/java/src/org/openqa/selenium/grid/data/NewSessionErrorResponse.java +++ b/java/src/org/openqa/selenium/grid/data/NewSessionErrorResponse.java @@ -49,6 +49,7 @@ private Map toJson() { return unmodifiableMap(toReturn); } + @SuppressWarnings({"unused", "DataFlowIssue"}) private static NewSessionErrorResponse fromJson(JsonInput input) { String message = null; RequestId requestId = null; diff --git a/java/src/org/openqa/selenium/grid/data/NewSessionResponse.java b/java/src/org/openqa/selenium/grid/data/NewSessionResponse.java index c0a20a443f97a..5e7d7599909c1 100644 --- a/java/src/org/openqa/selenium/grid/data/NewSessionResponse.java +++ b/java/src/org/openqa/selenium/grid/data/NewSessionResponse.java @@ -60,6 +60,7 @@ private Map toJson() { return unmodifiableMap(toReturn); } + @SuppressWarnings({"unused", "DataFlowIssue"}) private static NewSessionResponse fromJson(JsonInput input) { RequestId requestId = null; Session session = null; diff --git a/java/src/org/openqa/selenium/grid/data/NodeId.java b/java/src/org/openqa/selenium/grid/data/NodeId.java index cae18693b8fa8..f5393d356a86a 100644 --- a/java/src/org/openqa/selenium/grid/data/NodeId.java +++ b/java/src/org/openqa/selenium/grid/data/NodeId.java @@ -64,6 +64,7 @@ private Object toJson() { return uuid; } + @SuppressWarnings({"unused"}) private static NodeId fromJson(UUID id) { return new NodeId(id); } diff --git a/java/src/org/openqa/selenium/grid/data/NodeStatus.java b/java/src/org/openqa/selenium/grid/data/NodeStatus.java index 010ed40b81281..aabf309d09a99 100644 --- a/java/src/org/openqa/selenium/grid/data/NodeStatus.java +++ b/java/src/org/openqa/selenium/grid/data/NodeStatus.java @@ -69,7 +69,8 @@ public NodeStatus( this.osInfo = Require.nonNull("Node host OS info", osInfo); } - public static NodeStatus fromJson(JsonInput input) { + @SuppressWarnings({"unused", "DataFlowIssue"}) + private static NodeStatus fromJson(JsonInput input) { NodeId nodeId = null; URI externalUri = null; int maxSessions = 0; diff --git a/java/src/org/openqa/selenium/grid/data/RequestId.java b/java/src/org/openqa/selenium/grid/data/RequestId.java index 1b7ddd3fb3af6..7bdae7ca2c7c5 100644 --- a/java/src/org/openqa/selenium/grid/data/RequestId.java +++ b/java/src/org/openqa/selenium/grid/data/RequestId.java @@ -57,6 +57,7 @@ private Object toJson() { return uuid; } + @SuppressWarnings({"unused"}) private static RequestId fromJson(UUID id) { return new RequestId(id); } diff --git a/java/src/org/openqa/selenium/grid/data/Session.java b/java/src/org/openqa/selenium/grid/data/Session.java index 061cbfb6b212e..7d51e6023f0a3 100644 --- a/java/src/org/openqa/selenium/grid/data/Session.java +++ b/java/src/org/openqa/selenium/grid/data/Session.java @@ -90,6 +90,7 @@ private Map toJson() { return unmodifiableMap(toReturn); } + @SuppressWarnings({"unused", "DataFlowIssue"}) private static Session fromJson(JsonInput input) { SessionId id = null; URI uri = null; @@ -101,7 +102,7 @@ private static Session fromJson(JsonInput input) { while (input.hasNext()) { switch (input.nextName()) { case "capabilities": - caps = ImmutableCapabilities.copyOf(input.read(Capabilities.class)); + caps = ImmutableCapabilities.copyOf(input.readNonNull(Capabilities.class)); break; case "sessionId": diff --git a/java/src/org/openqa/selenium/grid/data/SessionClosedData.java b/java/src/org/openqa/selenium/grid/data/SessionClosedData.java index 4fdfac5f10ebc..1dff56c87b150 100644 --- a/java/src/org/openqa/selenium/grid/data/SessionClosedData.java +++ b/java/src/org/openqa/selenium/grid/data/SessionClosedData.java @@ -22,6 +22,7 @@ import java.time.Instant; import java.util.Map; import java.util.TreeMap; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.Capabilities; import org.openqa.selenium.internal.Require; import org.openqa.selenium.json.JsonException; @@ -37,10 +38,10 @@ public class SessionClosedData { private final SessionId sessionId; private final SessionClosedReason reason; - private final NodeId nodeId; - private final URI nodeUri; - private final Capabilities capabilities; - private final Instant startTime; + private final @Nullable NodeId nodeId; + private final @Nullable URI nodeUri; + private final @Nullable Capabilities capabilities; + private final @Nullable Instant startTime; private final Instant endTime; /** Backward compatible constructor for existing code. */ @@ -52,10 +53,10 @@ public SessionClosedData(SessionId sessionId, SessionClosedReason reason) { public SessionClosedData( SessionId sessionId, SessionClosedReason reason, - NodeId nodeId, - URI nodeUri, - Capabilities capabilities, - Instant startTime, + @Nullable NodeId nodeId, + @Nullable URI nodeUri, + @Nullable Capabilities capabilities, + @Nullable Instant startTime, Instant endTime) { this.sessionId = Require.nonNull("Session ID", sessionId); this.reason = Require.nonNull("Reason", reason); @@ -74,18 +75,22 @@ public SessionClosedReason getReason() { return reason; } + @Nullable public NodeId getNodeId() { return nodeId; } + @Nullable public URI getNodeUri() { return nodeUri; } + @Nullable public Capabilities getCapabilities() { return capabilities; } + @Nullable public Instant getStartTime() { return startTime; } @@ -99,8 +104,9 @@ public Instant getEndTime() { * * @return the session duration, or null if start time was not recorded */ + @Nullable public Duration getSessionDuration() { - if (startTime == null || endTime == null) { + if (startTime == null) { return null; } return Duration.between(startTime, endTime); @@ -132,6 +138,7 @@ private Map toJson() { return result; } + @SuppressWarnings({"unused", "DataFlowIssue"}) private static SessionClosedData fromJson(JsonInput input) { SessionId sessionId = null; SessionClosedReason reason = null; diff --git a/java/src/org/openqa/selenium/grid/data/SessionCreatedData.java b/java/src/org/openqa/selenium/grid/data/SessionCreatedData.java index e3900efb8eea3..82b6d7630053d 100644 --- a/java/src/org/openqa/selenium/grid/data/SessionCreatedData.java +++ b/java/src/org/openqa/selenium/grid/data/SessionCreatedData.java @@ -104,6 +104,7 @@ private Map toJson() { return toReturn; } + @SuppressWarnings({"unused", "DataFlowIssue"}) private static SessionCreatedData fromJson(JsonInput input) { SessionId sessionId = null; NodeId nodeId = null; diff --git a/java/src/org/openqa/selenium/grid/data/SessionEventData.java b/java/src/org/openqa/selenium/grid/data/SessionEventData.java index a2a7977e19142..ad3d8cb9bc131 100644 --- a/java/src/org/openqa/selenium/grid/data/SessionEventData.java +++ b/java/src/org/openqa/selenium/grid/data/SessionEventData.java @@ -22,6 +22,7 @@ import java.util.Collections; import java.util.Map; import java.util.TreeMap; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.internal.Require; import org.openqa.selenium.json.JsonInput; import org.openqa.selenium.remote.SessionId; @@ -75,18 +76,18 @@ public class SessionEventData { private final SessionId sessionId; private final String eventType; - private final NodeId nodeId; - private final URI nodeUri; + private final @Nullable NodeId nodeId; + private final @Nullable URI nodeUri; private final Instant timestamp; private final Map payload; public SessionEventData( SessionId sessionId, String eventType, - NodeId nodeId, - URI nodeUri, + @Nullable NodeId nodeId, + @Nullable URI nodeUri, Instant timestamp, - Map payload) { + @Nullable Map payload) { this.sessionId = Require.nonNull("Session ID", sessionId); this.eventType = Require.nonNull("Event type", eventType); if (!eventType.matches("^[a-zA-Z][a-zA-Z0-9:._-]*$")) { @@ -133,10 +134,12 @@ public String getEventType() { return eventType; } + @Nullable public NodeId getNodeId() { return nodeId; } + @Nullable public URI getNodeUri() { return nodeUri; } @@ -165,6 +168,7 @@ public Object get(String key) { * @param key the key to look up * @return the string value, or null if not present or not a string */ + @Nullable public String getString(String key) { Object value = payload.get(key); return value instanceof String ? (String) value : null; @@ -212,7 +216,7 @@ private Map toJson() { return result; } - @SuppressWarnings("unchecked") + @SuppressWarnings({"unused", "DataFlowIssue"}) private static SessionEventData fromJson(JsonInput input) { SessionId sessionId = null; String eventType = null; diff --git a/java/src/org/openqa/selenium/grid/data/SessionRemovalInfo.java b/java/src/org/openqa/selenium/grid/data/SessionRemovalInfo.java index 6cbe86e8af98b..353771724db06 100644 --- a/java/src/org/openqa/selenium/grid/data/SessionRemovalInfo.java +++ b/java/src/org/openqa/selenium/grid/data/SessionRemovalInfo.java @@ -20,13 +20,14 @@ import java.net.URI; import java.time.Duration; import java.time.Instant; +import org.jspecify.annotations.Nullable; public class SessionRemovalInfo { private final Instant removedAt; private final String reason; - private final URI nodeUri; + private final @Nullable URI nodeUri; - public SessionRemovalInfo(String reason, URI nodeUri) { + public SessionRemovalInfo(String reason, @Nullable URI nodeUri) { this.removedAt = Instant.now(); this.reason = reason; this.nodeUri = nodeUri; diff --git a/java/src/org/openqa/selenium/grid/data/SessionRequest.java b/java/src/org/openqa/selenium/grid/data/SessionRequest.java index aa299de2a8129..3cd4f5f7572e5 100644 --- a/java/src/org/openqa/selenium/grid/data/SessionRequest.java +++ b/java/src/org/openqa/selenium/grid/data/SessionRequest.java @@ -164,6 +164,7 @@ private Map toJson() { return unmodifiableMap(toReturn); } + @SuppressWarnings({"unused", "DataFlowIssue"}) private static SessionRequest fromJson(JsonInput input) { RequestId id = null; Instant enqueued = null; diff --git a/java/src/org/openqa/selenium/grid/data/SessionRequestCapability.java b/java/src/org/openqa/selenium/grid/data/SessionRequestCapability.java index 7091e57cd2b61..c86c7a3199896 100644 --- a/java/src/org/openqa/selenium/grid/data/SessionRequestCapability.java +++ b/java/src/org/openqa/selenium/grid/data/SessionRequestCapability.java @@ -17,11 +17,9 @@ package org.openqa.selenium.grid.data; -import static java.util.Collections.unmodifiableMap; import static java.util.Collections.unmodifiableSet; import java.lang.reflect.Type; -import java.util.HashMap; import java.util.LinkedHashSet; import java.util.Map; import java.util.Objects; @@ -77,12 +75,10 @@ public int hashCode() { } private Map toJson() { - Map toReturn = new HashMap<>(); - toReturn.put("requestId", requestId); - toReturn.put("capabilities", desiredCapabilities); - return unmodifiableMap(toReturn); + return Map.of("requestId", requestId, "capabilities", desiredCapabilities); } + @SuppressWarnings({"unused", "DataFlowIssue"}) private static SessionRequestCapability fromJson(JsonInput input) { RequestId id = null; Set capabilities = null; diff --git a/java/src/org/openqa/selenium/grid/data/Slot.java b/java/src/org/openqa/selenium/grid/data/Slot.java index 938b8debb048e..007c109de6a94 100644 --- a/java/src/org/openqa/selenium/grid/data/Slot.java +++ b/java/src/org/openqa/selenium/grid/data/Slot.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.Objects; import java.util.TreeMap; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.Capabilities; import org.openqa.selenium.ImmutableCapabilities; import org.openqa.selenium.internal.Require; @@ -33,16 +34,17 @@ public class Slot implements Serializable { private final SlotId id; private final Capabilities stereotype; - private final Session session; + private final @Nullable Session session; private final Instant lastStarted; - public Slot(SlotId id, Capabilities stereotype, Instant lastStarted, Session session) { + public Slot(SlotId id, Capabilities stereotype, Instant lastStarted, @Nullable Session session) { this.id = Require.nonNull("Slot ID", id); this.stereotype = ImmutableCapabilities.copyOf(Require.nonNull("Stereotype", stereotype)); this.lastStarted = Require.nonNull("Last started", lastStarted); this.session = session; } + @SuppressWarnings({"unused", "DataFlowIssue"}) private static Slot fromJson(JsonInput input) { SlotId id = null; Capabilities stereotype = null; @@ -79,8 +81,8 @@ private static Slot fromJson(JsonInput input) { return new Slot(id, stereotype, lastStarted, session); } - private Map toJson() { - Map toReturn = new TreeMap<>(); + private Map toJson() { + Map toReturn = new TreeMap<>(); toReturn.put("id", getId()); toReturn.put("lastStarted", getLastStarted()); toReturn.put("session", getSession()); @@ -100,6 +102,7 @@ public Instant getLastStarted() { return lastStarted; } + @Nullable public Session getSession() { return session; } diff --git a/java/src/org/openqa/selenium/grid/data/SlotId.java b/java/src/org/openqa/selenium/grid/data/SlotId.java index 8a126e4071a01..20d4bbb9ce7ab 100644 --- a/java/src/org/openqa/selenium/grid/data/SlotId.java +++ b/java/src/org/openqa/selenium/grid/data/SlotId.java @@ -71,6 +71,7 @@ private Object toJson() { return unmodifiableMap(toReturn); } + @SuppressWarnings({"unused", "DataFlowIssue"}) private static SlotId fromJson(JsonInput input) { NodeId nodeId = null; UUID id = null; diff --git a/java/src/org/openqa/selenium/grid/data/package-info.java b/java/src/org/openqa/selenium/grid/data/package-info.java new file mode 100644 index 0000000000000..578785a87ff4c --- /dev/null +++ b/java/src/org/openqa/selenium/grid/data/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.data; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/distributor/AddNode.java b/java/src/org/openqa/selenium/grid/distributor/AddNode.java index 904faaab6f1ec..4ca2e6d05528c 100644 --- a/java/src/org/openqa/selenium/grid/distributor/AddNode.java +++ b/java/src/org/openqa/selenium/grid/distributor/AddNode.java @@ -17,8 +17,6 @@ package org.openqa.selenium.grid.distributor; -import static org.openqa.selenium.remote.http.Contents.string; - import java.util.stream.Collectors; import org.openqa.selenium.grid.data.NodeStatus; import org.openqa.selenium.grid.data.Slot; @@ -56,7 +54,7 @@ class AddNode implements HttpHandler { @Override public HttpResponse execute(HttpRequest req) { - NodeStatus status = json.toType(string(req), NodeStatus.class); + NodeStatus status = json.toType(req.contentAsString(), NodeStatus.class); Node node = new RemoteNode( diff --git a/java/src/org/openqa/selenium/grid/distributor/BUILD.bazel b/java/src/org/openqa/selenium/grid/distributor/BUILD.bazel index ad61b54f26067..1e3fa9cdd4ec2 100644 --- a/java/src/org/openqa/selenium/grid/distributor/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/distributor/BUILD.bazel @@ -29,5 +29,6 @@ java_library( "//java/src/org/openqa/selenium/grid/sessionmap/remote", "//java/src/org/openqa/selenium/status", artifact("com.google.guava:guava"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/distributor/GridModel.java b/java/src/org/openqa/selenium/grid/distributor/GridModel.java index b975a16082d94..382b030d68311 100644 --- a/java/src/org/openqa/selenium/grid/distributor/GridModel.java +++ b/java/src/org/openqa/selenium/grid/distributor/GridModel.java @@ -18,6 +18,7 @@ package org.openqa.selenium.grid.distributor; import java.util.Set; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.grid.data.Availability; import org.openqa.selenium.grid.data.NodeId; import org.openqa.selenium.grid.data.NodeStatus; @@ -101,7 +102,7 @@ public abstract class GridModel { * @param slotId The ID of the slot to update * @param session The session to associate with the slot, or null to clear */ - public abstract void setSession(SlotId slotId, Session session); + public abstract void setSession(SlotId slotId, @Nullable Session session); /** * Updates the health check count for a node based on its availability. diff --git a/java/src/org/openqa/selenium/grid/distributor/NodeRegistry.java b/java/src/org/openqa/selenium/grid/distributor/NodeRegistry.java index ec3ae18b7ee51..d56a71363b7a9 100644 --- a/java/src/org/openqa/selenium/grid/distributor/NodeRegistry.java +++ b/java/src/org/openqa/selenium/grid/distributor/NodeRegistry.java @@ -20,6 +20,7 @@ import java.io.Closeable; import java.net.URI; import java.util.Set; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.grid.data.Availability; import org.openqa.selenium.grid.data.DistributorStatus; import org.openqa.selenium.grid.data.NodeId; @@ -137,7 +138,7 @@ public interface NodeRegistry extends HasReadyState, Closeable { * @param slotId The slot ID. * @param session The session to associate with the slot, or null to clear. */ - void setSession(SlotId slotId, Session session); + void setSession(SlotId slotId, @Nullable Session session); /** Get the number of active slots. */ int getActiveSlots(); @@ -151,5 +152,5 @@ public interface NodeRegistry extends HasReadyState, Closeable { * @param uri The node URI to look up. * @return The node if found, null otherwise. */ - Node getNode(URI uri); + @Nullable Node getNode(URI uri); } diff --git a/java/src/org/openqa/selenium/grid/distributor/config/BUILD.bazel b/java/src/org/openqa/selenium/grid/distributor/config/BUILD.bazel index f75ae55ab7720..9170f39c8f6c2 100644 --- a/java/src/org/openqa/selenium/grid/distributor/config/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/distributor/config/BUILD.bazel @@ -15,5 +15,6 @@ java_library( "//java/src/org/openqa/selenium/grid/distributor/selector", "//java/src/org/openqa/selenium/grid/server", artifact("com.beust:jcommander"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/distributor/config/package-info.java b/java/src/org/openqa/selenium/grid/distributor/config/package-info.java new file mode 100644 index 0000000000000..7e9ae3048559a --- /dev/null +++ b/java/src/org/openqa/selenium/grid/distributor/config/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.distributor.config; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/distributor/httpd/BUILD.bazel b/java/src/org/openqa/selenium/grid/distributor/httpd/BUILD.bazel index d538459e0b681..1d1b1c0934b0d 100644 --- a/java/src/org/openqa/selenium/grid/distributor/httpd/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/distributor/httpd/BUILD.bazel @@ -23,5 +23,6 @@ java_library( "//java/src/org/openqa/selenium/netty/server", artifact("com.beust:jcommander"), artifact("com.google.guava:guava"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/distributor/httpd/package-info.java b/java/src/org/openqa/selenium/grid/distributor/httpd/package-info.java new file mode 100644 index 0000000000000..67ddea44d69dd --- /dev/null +++ b/java/src/org/openqa/selenium/grid/distributor/httpd/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.distributor.httpd; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/distributor/local/BUILD.bazel b/java/src/org/openqa/selenium/grid/distributor/local/BUILD.bazel index 23552ec528410..c1af8f44e9aa1 100644 --- a/java/src/org/openqa/selenium/grid/distributor/local/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/distributor/local/BUILD.bazel @@ -31,5 +31,6 @@ java_library( "//java/src/org/openqa/selenium/json", "//java/src/org/openqa/selenium/remote", artifact("com.google.guava:guava"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/distributor/local/LocalDistributor.java b/java/src/org/openqa/selenium/grid/distributor/local/LocalDistributor.java index 5a46acdfd93b1..c888c5a0b836e 100644 --- a/java/src/org/openqa/selenium/grid/distributor/local/LocalDistributor.java +++ b/java/src/org/openqa/selenium/grid/distributor/local/LocalDistributor.java @@ -50,6 +50,7 @@ import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Collectors; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.Beta; import org.openqa.selenium.Capabilities; import org.openqa.selenium.ImmutableCapabilities; @@ -448,6 +449,7 @@ private CreateSessionResponse startSession( return result.right(); } + @Nullable private SlotId reserveSlot(RequestId requestId, Capabilities caps) { // Use read lock for slot selection to allow concurrent reads // This reduces contention compared to using write lock for the entire operation @@ -679,6 +681,7 @@ private void handleNewSessionRequest(SessionRequest sessionRequest) { } } + @Nullable protected Node getNodeFromURI(URI uri) { Lock readLock = this.lock.readLock(); readLock.lock(); diff --git a/java/src/org/openqa/selenium/grid/distributor/local/LocalGridModel.java b/java/src/org/openqa/selenium/grid/distributor/local/LocalGridModel.java index cf4985b299b0a..15d271ae8f706 100644 --- a/java/src/org/openqa/selenium/grid/distributor/local/LocalGridModel.java +++ b/java/src/org/openqa/selenium/grid/distributor/local/LocalGridModel.java @@ -34,6 +34,7 @@ import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; import java.util.logging.Logger; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.events.EventBus; import org.openqa.selenium.grid.config.Config; import org.openqa.selenium.grid.data.Availability; @@ -169,7 +170,7 @@ public void refresh(NodeStatus status) { if (node.getNodeId().equals(status.getNodeId())) { iterator.remove(); - // if the node was marked as "down", keep it down until a healthcheck passes: + // if the node was marked as "down", keep it down until a health check passes: // just because the node can hit the event bus doesn't mean it's reachable if (node.getAvailability() == DOWN) { nodes.add(rewrite(status, DOWN)); @@ -375,6 +376,7 @@ public Set getSnapshot() { } } + @Nullable private NodeStatus getNode(NodeId id) { Require.nonNull("Node ID", id); @@ -401,7 +403,7 @@ private NodeStatus rewrite(NodeStatus status, Availability availability) { } @Override - public void release(SessionId id) { + public void release(@Nullable SessionId id) { if (id == null) { return; } @@ -430,7 +432,7 @@ public void release(SessionId id) { } @Override - public void setSession(SlotId slotId, Session session) { + public void setSession(SlotId slotId, @Nullable Session session) { Require.nonNull("Slot ID", slotId); Lock writeLock = lock.writeLock(); diff --git a/java/src/org/openqa/selenium/grid/distributor/local/LocalNodeRegistry.java b/java/src/org/openqa/selenium/grid/distributor/local/LocalNodeRegistry.java index 77612fae8aaea..bcd21fc1675d0 100644 --- a/java/src/org/openqa/selenium/grid/distributor/local/LocalNodeRegistry.java +++ b/java/src/org/openqa/selenium/grid/distributor/local/LocalNodeRegistry.java @@ -46,6 +46,7 @@ import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Collectors; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.HealthCheckFailedException; import org.openqa.selenium.concurrent.GuardedRunnable; import org.openqa.selenium.events.EventBus; @@ -519,7 +520,7 @@ public boolean reserve(SlotId slotId) { } @Override - public void setSession(SlotId slotId, Session session) { + public void setSession(SlotId slotId, @Nullable Session session) { Lock writeLock = lock.writeLock(); writeLock.lock(); try { @@ -565,6 +566,7 @@ public int getIdleSlots() { * @param uri The URI of the node to find * @return The node if found, null otherwise */ + @Nullable public Node getNode(URI uri) { Lock readLock = this.lock.readLock(); readLock.lock(); diff --git a/java/src/org/openqa/selenium/grid/distributor/local/package-info.java b/java/src/org/openqa/selenium/grid/distributor/local/package-info.java new file mode 100644 index 0000000000000..29ed00ad32e30 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/distributor/local/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.distributor.local; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/distributor/package-info.java b/java/src/org/openqa/selenium/grid/distributor/package-info.java index 10f446636bfd2..afd36d731e800 100644 --- a/java/src/org/openqa/selenium/grid/distributor/package-info.java +++ b/java/src/org/openqa/selenium/grid/distributor/package-info.java @@ -27,4 +27,7 @@ * that dialects match, or that a converter of some sort is added. The Node may be the part of the * system responsible for adding this converter. */ +@NullMarked package org.openqa.selenium.grid.distributor; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/distributor/remote/BUILD.bazel b/java/src/org/openqa/selenium/grid/distributor/remote/BUILD.bazel index 53ebc5ce3f192..949a8d397b31f 100644 --- a/java/src/org/openqa/selenium/grid/distributor/remote/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/distributor/remote/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_jvm_external//:defs.bzl", "artifact") load("//java:defs.bzl", "java_library") java_library( @@ -17,5 +18,6 @@ java_library( "//java/src/org/openqa/selenium/grid/sessionmap", "//java/src/org/openqa/selenium/grid/web", "//java/src/org/openqa/selenium/remote", + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/distributor/remote/package-info.java b/java/src/org/openqa/selenium/grid/distributor/remote/package-info.java new file mode 100644 index 0000000000000..7a6aaf0fa411c --- /dev/null +++ b/java/src/org/openqa/selenium/grid/distributor/remote/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.distributor.remote; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/distributor/selector/BUILD.bazel b/java/src/org/openqa/selenium/grid/distributor/selector/BUILD.bazel index 867ab5082ec90..a1ade3769a079 100644 --- a/java/src/org/openqa/selenium/grid/distributor/selector/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/distributor/selector/BUILD.bazel @@ -13,5 +13,6 @@ java_library( "//java/src/org/openqa/selenium/grid/config", "//java/src/org/openqa/selenium/grid/data", artifact("com.google.guava:guava"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/distributor/selector/package-info.java b/java/src/org/openqa/selenium/grid/distributor/selector/package-info.java new file mode 100644 index 0000000000000..0344710f634c2 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/distributor/selector/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.distributor.selector; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/graphql/BUILD.bazel b/java/src/org/openqa/selenium/grid/graphql/BUILD.bazel index a9a9c0a7b5953..102f20a1987fe 100644 --- a/java/src/org/openqa/selenium/grid/graphql/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/graphql/BUILD.bazel @@ -22,5 +22,6 @@ java_library( artifact("com.google.guava:guava"), artifact("com.graphql-java:graphql-java"), artifact("com.github.ben-manes.caffeine:caffeine"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/graphql/SessionData.java b/java/src/org/openqa/selenium/grid/graphql/SessionData.java index 94396fbd2f942..92bc288d58aaa 100644 --- a/java/src/org/openqa/selenium/grid/graphql/SessionData.java +++ b/java/src/org/openqa/selenium/grid/graphql/SessionData.java @@ -20,6 +20,7 @@ import graphql.schema.DataFetcher; import graphql.schema.DataFetchingEnvironment; import java.util.Set; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.grid.data.NodeStatus; import org.openqa.selenium.grid.data.Slot; import org.openqa.selenium.grid.distributor.Distributor; @@ -62,6 +63,7 @@ public org.openqa.selenium.grid.graphql.Session get(DataFetchingEnvironment envi } } + @Nullable private SessionInSlot findSession(String sessionId, Set nodeStatuses) { for (NodeStatus status : nodeStatuses) { for (Slot slot : status.getSlots()) { diff --git a/java/src/org/openqa/selenium/grid/graphql/Types.java b/java/src/org/openqa/selenium/grid/graphql/Types.java index 5bb2a21fa2b1a..da1857d3e0113 100644 --- a/java/src/org/openqa/selenium/grid/graphql/Types.java +++ b/java/src/org/openqa/selenium/grid/graphql/Types.java @@ -27,6 +27,7 @@ import java.net.URI; import java.net.URISyntaxException; import java.net.URL; +import org.jspecify.annotations.Nullable; class Types { @@ -53,10 +54,6 @@ public String serialize(Object o) throws CoercingSerializeException { @Override public URI parseValue(Object input) throws CoercingParseValueException { - if (input == null) { - return null; - } - if (input instanceof URI) { return (URI) input; } @@ -110,8 +107,9 @@ public String serialize(Object o) throws CoercingSerializeException { throw new CoercingSerializeException("Unable to coerce " + o); } + @Nullable @Override - public URL parseValue(Object input) throws CoercingParseValueException { + public URL parseValue(@Nullable Object input) throws CoercingParseValueException { if (input == null) { return null; } diff --git a/java/src/org/openqa/selenium/grid/graphql/package-info.java b/java/src/org/openqa/selenium/grid/graphql/package-info.java new file mode 100644 index 0000000000000..df07cd66175e3 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/graphql/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.graphql; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/jmx/JMXHelper.java b/java/src/org/openqa/selenium/grid/jmx/JMXHelper.java index 54b8644eec577..8dd73bd8ecd37 100644 --- a/java/src/org/openqa/selenium/grid/jmx/JMXHelper.java +++ b/java/src/org/openqa/selenium/grid/jmx/JMXHelper.java @@ -22,10 +22,8 @@ import javax.management.InstanceAlreadyExistsException; import javax.management.MBeanServer; import javax.management.ObjectName; -import org.jspecify.annotations.NullMarked; import org.jspecify.annotations.Nullable; -@NullMarked public class JMXHelper { private static final Logger LOG = Logger.getLogger(JMXHelper.class.getName()); @@ -44,7 +42,7 @@ public class JMXHelper { } } - public void unregister(ObjectName objectName) { + public void unregister(@Nullable ObjectName objectName) { if (objectName != null) { MBeanServer mbs = ManagementFactory.getPlatformMBeanServer(); try { diff --git a/java/src/org/openqa/selenium/grid/jmx/MBean.java b/java/src/org/openqa/selenium/grid/jmx/MBean.java index dce77f1da0746..3a7c7511e9a33 100644 --- a/java/src/org/openqa/selenium/grid/jmx/MBean.java +++ b/java/src/org/openqa/selenium/grid/jmx/MBean.java @@ -99,7 +99,7 @@ MBeanOperationInfo getMBeanOperationInfo() { String name = bean.getClass().getName(); String description = mBean.description(); collectAttributeInfo(bean); - MBeanAttributeInfo[] attributes = + @Nullable MBeanAttributeInfo[] attributes = attributeMap.values().stream() .map(AttributeInfo::getMBeanAttributeInfo) .toArray(MBeanAttributeInfo[]::new); @@ -254,7 +254,7 @@ public void setAttribute(Attribute attribute) { } @Override - public AttributeList getAttributes(String[] attributes) { + public AttributeList getAttributes(String @Nullable [] attributes) { AttributeList resultList = new AttributeList(); // if attributeNames is empty, return an empty result list diff --git a/java/src/org/openqa/selenium/grid/jmx/package-info.java b/java/src/org/openqa/selenium/grid/jmx/package-info.java new file mode 100644 index 0000000000000..51159a8e5004d --- /dev/null +++ b/java/src/org/openqa/selenium/grid/jmx/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.jmx; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/log/BUILD.bazel b/java/src/org/openqa/selenium/grid/log/BUILD.bazel index a8d998a81e209..68c796f96396b 100644 --- a/java/src/org/openqa/selenium/grid/log/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/log/BUILD.bazel @@ -14,5 +14,6 @@ java_library( "//java/src/org/openqa/selenium/json", "//java/src/org/openqa/selenium/remote", artifact("com.beust:jcommander"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/log/FlushingHandler.java b/java/src/org/openqa/selenium/grid/log/FlushingHandler.java index 8adda6023719c..b556231dcbbd6 100644 --- a/java/src/org/openqa/selenium/grid/log/FlushingHandler.java +++ b/java/src/org/openqa/selenium/grid/log/FlushingHandler.java @@ -21,10 +21,11 @@ import java.util.Objects; import java.util.logging.LogRecord; import java.util.logging.StreamHandler; +import org.jspecify.annotations.Nullable; class FlushingHandler extends StreamHandler { - private OutputStream out; + @Nullable private OutputStream out; FlushingHandler(OutputStream out) { setOutputStream(out); diff --git a/java/src/org/openqa/selenium/grid/log/LoggingOptions.java b/java/src/org/openqa/selenium/grid/log/LoggingOptions.java index 20cb5a770e68a..b812bac69f186 100644 --- a/java/src/org/openqa/selenium/grid/log/LoggingOptions.java +++ b/java/src/org/openqa/selenium/grid/log/LoggingOptions.java @@ -30,6 +30,7 @@ import java.util.logging.Level; import java.util.logging.LogManager; import java.util.logging.Logger; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.grid.config.Config; import org.openqa.selenium.grid.config.ConfigException; import org.openqa.selenium.internal.Debug; @@ -79,6 +80,7 @@ public boolean isUsingPlainLogs() { return config.getBool(LOGGING_SECTION, "plain-logs").orElse(DEFAULT_PLAIN_LOGS); } + @Nullable public String getLogEncoding() { return config.get(LOGGING_SECTION, "log-encoding").orElse(null); } @@ -169,7 +171,7 @@ public void configureLogging() { } } - private void configureLogEncoding(Logger logger, String encoding, Handler handler) { + private void configureLogEncoding(Logger logger, @Nullable String encoding, Handler handler) { String message; try { if (encoding != null) { diff --git a/java/src/org/openqa/selenium/grid/log/package-info.java b/java/src/org/openqa/selenium/grid/log/package-info.java new file mode 100644 index 0000000000000..fd4c74b6d998a --- /dev/null +++ b/java/src/org/openqa/selenium/grid/log/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.log; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/node/Node.java b/java/src/org/openqa/selenium/grid/node/Node.java index 1d933e33c6c6a..8171a0f753d75 100644 --- a/java/src/org/openqa/selenium/grid/node/Node.java +++ b/java/src/org/openqa/selenium/grid/node/Node.java @@ -35,6 +35,7 @@ import java.util.logging.Logger; import java.util.stream.Collectors; import java.util.stream.StreamSupport; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.BuildInfo; import org.openqa.selenium.Capabilities; import org.openqa.selenium.NoSuchSessionException; @@ -255,8 +256,10 @@ public TemporaryFilesystem getDownloadsFilesystem(SessionId id) throws IOExcepti throw new UnsupportedOperationException(); } + @Nullable public abstract HttpResponse uploadFile(HttpRequest req, SessionId id); + @Nullable public abstract HttpResponse downloadFile(HttpRequest req, SessionId id); /** diff --git a/java/src/org/openqa/selenium/grid/node/config/BUILD.bazel b/java/src/org/openqa/selenium/grid/node/config/BUILD.bazel index d71932bada408..6c747672982fe 100644 --- a/java/src/org/openqa/selenium/grid/node/config/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/node/config/BUILD.bazel @@ -19,5 +19,6 @@ java_library( "//java/src/org/openqa/selenium/remote", artifact("com.beust:jcommander"), artifact("com.google.guava:guava"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/node/config/DriverServiceSessionFactory.java b/java/src/org/openqa/selenium/grid/node/config/DriverServiceSessionFactory.java index 3a1b8aa09f2de..14a74962fdc6d 100644 --- a/java/src/org/openqa/selenium/grid/node/config/DriverServiceSessionFactory.java +++ b/java/src/org/openqa/selenium/grid/node/config/DriverServiceSessionFactory.java @@ -35,6 +35,7 @@ import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Stream; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.Capabilities; import org.openqa.selenium.ImmutableCapabilities; import org.openqa.selenium.MutableCapabilities; @@ -324,6 +325,7 @@ private Capabilities setInitialCapabilityValue(Capabilities caps, String key, Ob return new PersistentCapabilities(caps).setCapability(key, value); } + @Nullable private String getHost() { try { return new NetworkUtils().getNonLoopbackAddressOfThisMachine(); diff --git a/java/src/org/openqa/selenium/grid/node/config/NodeOptions.java b/java/src/org/openqa/selenium/grid/node/config/NodeOptions.java index 23c4616039dff..be9cf1f76d08c 100644 --- a/java/src/org/openqa/selenium/grid/node/config/NodeOptions.java +++ b/java/src/org/openqa/selenium/grid/node/config/NodeOptions.java @@ -803,7 +803,7 @@ private void report(Map.Entry> entry) private String unquote(String input) { int len = input.length(); if ((input.charAt(0) == '"') && (input.charAt(len - 1) == '"')) { - return new Json().newInput(new StringReader(input)).read(Json.OBJECT_TYPE); + return new Json().newInput(new StringReader(input)).readNonNull(Json.OBJECT_TYPE); } return input; } diff --git a/java/src/org/openqa/selenium/grid/node/config/SessionCapabilitiesMutator.java b/java/src/org/openqa/selenium/grid/node/config/SessionCapabilitiesMutator.java index 1d3431d779ab4..89fcf5c58b017 100644 --- a/java/src/org/openqa/selenium/grid/node/config/SessionCapabilitiesMutator.java +++ b/java/src/org/openqa/selenium/grid/node/config/SessionCapabilitiesMutator.java @@ -27,6 +27,7 @@ import org.openqa.selenium.Capabilities; import org.openqa.selenium.ImmutableCapabilities; import org.openqa.selenium.PersistentCapabilities; +import org.openqa.selenium.internal.Require; public class SessionCapabilitiesMutator implements Function { @@ -49,11 +50,13 @@ public Capabilities apply(Capabilities capabilities) { return capabilities; } - if (slotStereotype.getCapability(SE_VNC_ENABLED) != null) { + Object vncEnabled = slotStereotype.getCapability(SE_VNC_ENABLED); + if (vncEnabled != null) { + Object vncPort = slotStereotype.getCapability(SE_NO_VNC_PORT); capabilities = new PersistentCapabilities(capabilities) - .setCapability(SE_VNC_ENABLED, slotStereotype.getCapability(SE_VNC_ENABLED)) - .setCapability(SE_NO_VNC_PORT, slotStereotype.getCapability(SE_NO_VNC_PORT)); + .setCapability(SE_VNC_ENABLED, vncEnabled) + .setCapability(SE_NO_VNC_PORT, Require.nonNull(SE_NO_VNC_PORT, vncPort)); } String browserName = capabilities.getBrowserName().toLowerCase(Locale.ENGLISH); diff --git a/java/src/org/openqa/selenium/grid/node/config/package-info.java b/java/src/org/openqa/selenium/grid/node/config/package-info.java new file mode 100644 index 0000000000000..51df08ff26d59 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/node/config/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.node.config; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/node/docker/BUILD.bazel b/java/src/org/openqa/selenium/grid/node/docker/BUILD.bazel index 6cc6de13f6695..a4e1467776f4c 100644 --- a/java/src/org/openqa/selenium/grid/node/docker/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/node/docker/BUILD.bazel @@ -22,5 +22,6 @@ java_library( "//java/src/org/openqa/selenium/support", artifact("com.beust:jcommander"), artifact("com.google.guava:guava"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/node/docker/DockerOptions.java b/java/src/org/openqa/selenium/grid/node/docker/DockerOptions.java index d95363d8e9502..ce422300835a2 100644 --- a/java/src/org/openqa/selenium/grid/node/docker/DockerOptions.java +++ b/java/src/org/openqa/selenium/grid/node/docker/DockerOptions.java @@ -37,6 +37,7 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.Capabilities; import org.openqa.selenium.Platform; import org.openqa.selenium.docker.ContainerId; @@ -115,6 +116,7 @@ private Duration getServerStartTimeout() { config.getInt(DOCKER_SECTION, "server-start-timeout").orElse(DEFAULT_SERVER_START_TIMEOUT)); } + @Nullable private String getApiVersion() { return config.get(DOCKER_SECTION, "api-version").orElse(null); } @@ -238,6 +240,7 @@ protected List getDevicesMapping() { return deviceMapping; } + @Nullable private Image getVideoImage(Docker docker) { String videoImage = config.get(DOCKER_SECTION, "video-image").orElse(DEFAULT_VIDEO_IMAGE); if (videoImage.equalsIgnoreCase("false")) { @@ -280,6 +283,7 @@ private Map getGroupingLabels(Optional info) { .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); } + @Nullable @SuppressWarnings("OptionalUsedAsFieldOrParameterType") private DockerAssetsPath getAssetsPath(Optional info) { if (info.isPresent()) { diff --git a/java/src/org/openqa/selenium/grid/node/docker/DockerSession.java b/java/src/org/openqa/selenium/grid/node/docker/DockerSession.java index 6608e7cb7bdca..464be03e2ab5d 100644 --- a/java/src/org/openqa/selenium/grid/node/docker/DockerSession.java +++ b/java/src/org/openqa/selenium/grid/node/docker/DockerSession.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.logging.Level; import java.util.logging.Logger; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.Capabilities; import org.openqa.selenium.docker.Container; import org.openqa.selenium.grid.node.DefaultActiveSession; @@ -38,12 +39,12 @@ public class DockerSession extends DefaultActiveSession { private static final Logger LOG = Logger.getLogger(DockerSession.class.getName()); private final Container container; - private final Container videoContainer; + private final @Nullable Container videoContainer; private final DockerAssetsPath assetsPath; DockerSession( Container container, - Container videoContainer, + @Nullable Container videoContainer, Tracer tracer, HttpClient client, SessionId id, diff --git a/java/src/org/openqa/selenium/grid/node/docker/DockerSessionFactory.java b/java/src/org/openqa/selenium/grid/node/docker/DockerSessionFactory.java index 9aef964df9246..189275b166240 100644 --- a/java/src/org/openqa/selenium/grid/node/docker/DockerSessionFactory.java +++ b/java/src/org/openqa/selenium/grid/node/docker/DockerSessionFactory.java @@ -20,7 +20,6 @@ import static java.util.Optional.ofNullable; import static org.openqa.selenium.docker.ContainerConfig.image; import static org.openqa.selenium.remote.Dialect.W3C; -import static org.openqa.selenium.remote.http.Contents.string; import static org.openqa.selenium.remote.http.HttpMethod.GET; import static org.openqa.selenium.remote.tracing.Tags.EXCEPTION; @@ -44,6 +43,7 @@ import java.util.function.Predicate; import java.util.logging.Level; import java.util.logging.Logger; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.Capabilities; import org.openqa.selenium.Dimension; import org.openqa.selenium.ImmutableCapabilities; @@ -98,8 +98,8 @@ public class DockerSessionFactory implements SessionFactory { private final Image browserImage; private final Capabilities stereotype; private final List devices; - private final Image videoImage; - private final DockerAssetsPath assetsPath; + private final @Nullable Image videoImage; + private final @Nullable DockerAssetsPath assetsPath; private final String networkName; private final boolean runningInDocker; private final Predicate predicate; @@ -117,8 +117,8 @@ public DockerSessionFactory( Image browserImage, Capabilities stereotype, List devices, - Image videoImage, - DockerAssetsPath assetsPath, + @Nullable Image videoImage, + @Nullable DockerAssetsPath assetsPath, String networkName, boolean runningInDocker, Predicate predicate, @@ -162,7 +162,7 @@ public Either apply(CreateSessionRequest sess // Generate unique identifier for consistent naming between browser and recorder containers // Using browserName-timestamp-UUID to avoid conflicts in concurrent session creation String browserName = sessionRequest.getDesiredCapabilities().getBrowserName(); - if (browserName != null && !browserName.isEmpty()) { + if (!browserName.isEmpty()) { browserName = browserName.toLowerCase(); } else { browserName = "unknown"; @@ -374,6 +374,7 @@ private void setCapsToEnvVars( timeZone.ifPresent(zone -> envVars.put("TZ", zone.getID())); } + @Nullable private Container startVideoContainer( Capabilities sessionCapabilities, String browserContainerIp, @@ -438,6 +439,7 @@ private Map getVideoContainerEnvVars( return envVars; } + @Nullable private String getVideoFileName(Capabilities sessionRequestCapabilities, String capabilityName) { Optional testName = ofNullable(sessionRequestCapabilities.getCapability(capabilityName)); @@ -454,6 +456,7 @@ private String getVideoFileName(Capabilities sessionRequestCapabilities, String return null; } + @Nullable private TimeZone getTimeZone(Capabilities sessionRequestCapabilities) { Optional timeZone = ofNullable(sessionRequestCapabilities.getCapability("se:timeZone")); if (timeZone.isPresent()) { @@ -469,6 +472,7 @@ private TimeZone getTimeZone(Capabilities sessionRequestCapabilities) { return null; } + @Nullable private Dimension getScreenResolution(Capabilities sessionRequestCapabilities) { Optional screenResolution = ofNullable(sessionRequestCapabilities.getCapability("se:screenResolution")); @@ -521,7 +525,7 @@ private void waitForServerToStart(HttpClient client, Duration duration) { wait.until( obj -> { HttpResponse response = client.execute(new HttpRequest(GET, "/status")); - LOG.fine(string(response)); + LOG.fine(response::contentAsString); if (401 == response.getStatus()) { LOG.warning( "Server requires basic authentication. " diff --git a/java/src/org/openqa/selenium/grid/node/docker/package-info.java b/java/src/org/openqa/selenium/grid/node/docker/package-info.java new file mode 100644 index 0000000000000..b2e9f6b6b6b1b --- /dev/null +++ b/java/src/org/openqa/selenium/grid/node/docker/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.node.docker; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/node/httpd/BUILD.bazel b/java/src/org/openqa/selenium/grid/node/httpd/BUILD.bazel index 605d9a02092d0..04bcc32571c75 100644 --- a/java/src/org/openqa/selenium/grid/node/httpd/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/node/httpd/BUILD.bazel @@ -24,5 +24,6 @@ java_library( "//java/src/org/openqa/selenium/remote", artifact("com.google.guava:guava"), artifact("dev.failsafe:failsafe"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/node/httpd/NodeServer.java b/java/src/org/openqa/selenium/grid/node/httpd/NodeServer.java index dd93d481b4f7d..a69044b822575 100644 --- a/java/src/org/openqa/selenium/grid/node/httpd/NodeServer.java +++ b/java/src/org/openqa/selenium/grid/node/httpd/NodeServer.java @@ -38,6 +38,7 @@ import java.util.concurrent.Executors; import java.util.logging.Level; import java.util.logging.Logger; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.BuildInfo; import org.openqa.selenium.cli.CliCommand; import org.openqa.selenium.events.EventBus; @@ -72,8 +73,8 @@ public class NodeServer extends TemplateGridServerCommand { private static final Logger LOG = Logger.getLogger(NodeServer.class.getName()); - private Node node; - private EventBus bus; + private @Nullable Node node; + private @Nullable EventBus bus; private final Thread shutdownHook = new Thread(() -> bus.fire(new NodeRemovedEvent(node.getStatus()))); diff --git a/java/src/org/openqa/selenium/grid/node/httpd/package-info.java b/java/src/org/openqa/selenium/grid/node/httpd/package-info.java new file mode 100644 index 0000000000000..e1c8b8d14cf0b --- /dev/null +++ b/java/src/org/openqa/selenium/grid/node/httpd/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.node.httpd; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/node/k8s/BUILD.bazel b/java/src/org/openqa/selenium/grid/node/k8s/BUILD.bazel index 68b5f9cd50cc0..e78639b6c9eef 100644 --- a/java/src/org/openqa/selenium/grid/node/k8s/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/node/k8s/BUILD.bazel @@ -14,5 +14,6 @@ java_library( "//java/src/org/openqa/selenium/remote", artifact("com.beust:jcommander"), artifact("com.google.guava:guava"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/node/k8s/OneShotNode.java b/java/src/org/openqa/selenium/grid/node/k8s/OneShotNode.java index 1281083a0970c..a0ebc6848d71d 100644 --- a/java/src/org/openqa/selenium/grid/node/k8s/OneShotNode.java +++ b/java/src/org/openqa/selenium/grid/node/k8s/OneShotNode.java @@ -39,6 +39,7 @@ import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.StreamSupport; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.Capabilities; import org.openqa.selenium.ImmutableCapabilities; import org.openqa.selenium.NoSuchSessionException; @@ -106,10 +107,10 @@ public class OneShotNode extends Node { private final UUID slotId = UUID.randomUUID(); private final int connectionLimitPerSession; private final AtomicInteger connectionCounter = new AtomicInteger(); - private RemoteWebDriver driver; - private SessionId sessionId; - private HttpClient client; - private Capabilities capabilities; + private @Nullable RemoteWebDriver driver; + private @Nullable SessionId sessionId; + private @Nullable HttpClient client; + private @Nullable Capabilities capabilities; private Instant sessionStart = Instant.EPOCH; private OneShotNode( @@ -264,6 +265,7 @@ private HttpClient extractHttpClient(RemoteWebDriver driver) { } } + @Nullable private Field findClientField(Class clazz) { try { return clazz.getDeclaredField("client"); @@ -334,11 +336,13 @@ public Session getSession(SessionId id) throws NoSuchSessionException { return new Session(sessionId, getUri(), stereotype, capabilities, sessionStart); } + @Nullable @Override public HttpResponse uploadFile(HttpRequest req, SessionId id) { return null; } + @Nullable @Override public HttpResponse downloadFile(HttpRequest req, SessionId id) { return null; diff --git a/java/src/org/openqa/selenium/grid/node/k8s/package-info.java b/java/src/org/openqa/selenium/grid/node/k8s/package-info.java new file mode 100644 index 0000000000000..645e3b9df98d8 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/node/k8s/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.node.k8s; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/node/kubernetes/BUILD.bazel b/java/src/org/openqa/selenium/grid/node/kubernetes/BUILD.bazel index 54641b1b7924e..bef0936b08149 100644 --- a/java/src/org/openqa/selenium/grid/node/kubernetes/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/node/kubernetes/BUILD.bazel @@ -38,6 +38,7 @@ java_export( artifact("io.fabric8:kubernetes-client-api"), artifact("io.fabric8:kubernetes-model-batch"), artifact("io.fabric8:kubernetes-model-core"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/node/kubernetes/InheritedPodSpec.java b/java/src/org/openqa/selenium/grid/node/kubernetes/InheritedPodSpec.java index 827a10b0e35cc..ebd5e07694f5e 100644 --- a/java/src/org/openqa/selenium/grid/node/kubernetes/InheritedPodSpec.java +++ b/java/src/org/openqa/selenium/grid/node/kubernetes/InheritedPodSpec.java @@ -26,43 +26,44 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import org.jspecify.annotations.Nullable; public class InheritedPodSpec { private final List tolerations; - private final Affinity affinity; + private final @Nullable Affinity affinity; private final List imagePullSecrets; - private final String dnsPolicy; - private final PodDNSConfig dnsConfig; - private final PodSecurityContext securityContext; - private final String priorityClassName; + private final @Nullable String dnsPolicy; + private final @Nullable PodDNSConfig dnsConfig; + private final @Nullable PodSecurityContext securityContext; + private final @Nullable String priorityClassName; private final Map nodeSelector; - private final String serviceAccountName; + private final @Nullable String serviceAccountName; private final Map labels; private final Map annotations; - private final String imagePullPolicy; + private final @Nullable String imagePullPolicy; private final Map resourceRequests; private final Map resourceLimits; - private final String assetsClaimName; - private final String nodePodName; - private final String nodePodUid; + private final @Nullable String assetsClaimName; + private final @Nullable String nodePodName; + private final @Nullable String nodePodUid; public InheritedPodSpec( - List tolerations, - Affinity affinity, - List imagePullSecrets, - String dnsPolicy, - PodDNSConfig dnsConfig, - PodSecurityContext securityContext, - String priorityClassName, - Map nodeSelector, - String serviceAccountName, - Map labels, - Map annotations, - String imagePullPolicy, - Map resourceRequests, - Map resourceLimits, - String assetsClaimName) { + @Nullable List tolerations, + @Nullable Affinity affinity, + @Nullable List imagePullSecrets, + @Nullable String dnsPolicy, + @Nullable PodDNSConfig dnsConfig, + @Nullable PodSecurityContext securityContext, + @Nullable String priorityClassName, + @Nullable Map nodeSelector, + @Nullable String serviceAccountName, + @Nullable Map labels, + @Nullable Map annotations, + @Nullable String imagePullPolicy, + @Nullable Map resourceRequests, + @Nullable Map resourceLimits, + @Nullable String assetsClaimName) { this( tolerations, affinity, @@ -84,23 +85,23 @@ public InheritedPodSpec( } public InheritedPodSpec( - List tolerations, - Affinity affinity, - List imagePullSecrets, - String dnsPolicy, - PodDNSConfig dnsConfig, - PodSecurityContext securityContext, - String priorityClassName, - Map nodeSelector, - String serviceAccountName, - Map labels, - Map annotations, - String imagePullPolicy, - Map resourceRequests, - Map resourceLimits, - String assetsClaimName, - String nodePodName, - String nodePodUid) { + @Nullable List tolerations, + @Nullable Affinity affinity, + @Nullable List imagePullSecrets, + @Nullable String dnsPolicy, + @Nullable PodDNSConfig dnsConfig, + @Nullable PodSecurityContext securityContext, + @Nullable String priorityClassName, + @Nullable Map nodeSelector, + @Nullable String serviceAccountName, + @Nullable Map labels, + @Nullable Map annotations, + @Nullable String imagePullPolicy, + @Nullable Map resourceRequests, + @Nullable Map resourceLimits, + @Nullable String assetsClaimName, + @Nullable String nodePodName, + @Nullable String nodePodUid) { this.tolerations = tolerations != null ? List.copyOf(tolerations) : List.of(); this.affinity = affinity; this.imagePullSecrets = imagePullSecrets != null ? List.copyOf(imagePullSecrets) : List.of(); @@ -150,6 +151,7 @@ public List getTolerations() { return tolerations; } + @Nullable public Affinity getAffinity() { return affinity; } @@ -158,18 +160,22 @@ public List getImagePullSecrets() { return imagePullSecrets; } + @Nullable public String getDnsPolicy() { return dnsPolicy; } + @Nullable public PodDNSConfig getDnsConfig() { return dnsConfig; } + @Nullable public PodSecurityContext getSecurityContext() { return securityContext; } + @Nullable public String getPriorityClassName() { return priorityClassName; } @@ -178,6 +184,7 @@ public Map getNodeSelector() { return nodeSelector; } + @Nullable public String getServiceAccountName() { return serviceAccountName; } @@ -190,6 +197,7 @@ public Map getAnnotations() { return annotations; } + @Nullable public String getImagePullPolicy() { return imagePullPolicy; } @@ -202,6 +210,7 @@ public Map getResourceLimits() { return resourceLimits; } + @Nullable public String getAssetsClaimName() { return assetsClaimName; } @@ -213,10 +222,12 @@ public boolean hasNodePodOwnerReference() { && !nodePodUid.isEmpty(); } + @Nullable public String getNodePodName() { return nodePodName; } + @Nullable public String getNodePodUid() { return nodePodUid; } diff --git a/java/src/org/openqa/selenium/grid/node/kubernetes/KubernetesOptions.java b/java/src/org/openqa/selenium/grid/node/kubernetes/KubernetesOptions.java index cb25d7f57792e..de032983ec704 100644 --- a/java/src/org/openqa/selenium/grid/node/kubernetes/KubernetesOptions.java +++ b/java/src/org/openqa/selenium/grid/node/kubernetes/KubernetesOptions.java @@ -41,6 +41,7 @@ import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Collectors; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.Capabilities; import org.openqa.selenium.grid.config.ConfigException; import org.openqa.selenium.grid.node.SessionFactory; @@ -320,7 +321,7 @@ String getNamespace() { return getNamespace(null); } - String getNamespace(KubernetesClient kubeClient) { + String getNamespace(@Nullable KubernetesClient kubeClient) { // Priority: config → client auto-detected namespace → "default" // The fabric8 KubernetesClient.getNamespace() already reads from kubeconfig, // in-cluster service account namespace file, and KUBERNETES_NAMESPACE env var. @@ -366,6 +367,7 @@ Map getNodeSelector() { return parseKeyValueMap(config.get(K8S_SECTION, "node-selector").orElse(null)); } + @Nullable private String getVideoImage() { String image = config.get(K8S_SECTION, "video-image").orElse(DEFAULT_VIDEO_IMAGE); if (image.equalsIgnoreCase("false")) { @@ -374,6 +376,7 @@ private String getVideoImage() { return image; } + @Nullable private String getAssetsPath() { return config.get(K8S_SECTION, "assets-path").orElse(null); } @@ -388,7 +391,10 @@ boolean isRunningInKubernetes() { } InheritedPodSpec inspectNodePod( - KubernetesClient kubeClient, String namespace, String labelInheritPrefix, String assetsPath) { + KubernetesClient kubeClient, + String namespace, + String labelInheritPrefix, + @Nullable String assetsPath) { if (!isRunningInKubernetes()) { LOG.fine("Not running in Kubernetes; skipping Node Pod inspection"); return InheritedPodSpec.empty(); @@ -496,7 +502,8 @@ InheritedPodSpec inspectNodePod( } } - static Map filterByPrefix(Map map, String prefix) { + static Map filterByPrefix( + @Nullable Map map, @Nullable String prefix) { if (map == null || map.isEmpty()) { return Collections.emptyMap(); } @@ -508,7 +515,7 @@ static Map filterByPrefix(Map map, String prefix .collect(Collectors.toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue)); } - static Map parseResourceMap(String resourceString) { + static Map parseResourceMap(@Nullable String resourceString) { if (resourceString == null || resourceString.trim().isEmpty()) { return Collections.emptyMap(); } @@ -522,7 +529,7 @@ static Map parseResourceMap(String resourceString) { return resources; } - static Map parseKeyValueMap(String mapString) { + static Map parseKeyValueMap(@Nullable String mapString) { if (mapString == null || mapString.trim().isEmpty()) { return Collections.emptyMap(); } diff --git a/java/src/org/openqa/selenium/grid/node/kubernetes/KubernetesSession.java b/java/src/org/openqa/selenium/grid/node/kubernetes/KubernetesSession.java index 8337889882940..b466892577412 100644 --- a/java/src/org/openqa/selenium/grid/node/kubernetes/KubernetesSession.java +++ b/java/src/org/openqa/selenium/grid/node/kubernetes/KubernetesSession.java @@ -32,6 +32,7 @@ import java.time.Instant; import java.util.logging.Level; import java.util.logging.Logger; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.Capabilities; import org.openqa.selenium.grid.node.DefaultActiveSession; import org.openqa.selenium.internal.Require; @@ -51,20 +52,20 @@ public class KubernetesSession extends DefaultActiveSession { private final String namespace; private final KubernetesClient kubeClient; private final String podName; - private final String assetsPath; - private final String videoFileName; + private final @Nullable String assetsPath; + private final @Nullable String videoFileName; private final long terminationGracePeriodSeconds; - private final LocalPortForward portForward; + private final @Nullable LocalPortForward portForward; KubernetesSession( String jobName, String namespace, KubernetesClient kubeClient, String podName, - String assetsPath, - String videoFileName, + @Nullable String assetsPath, + @Nullable String videoFileName, long terminationGracePeriodSeconds, - LocalPortForward portForward, + @Nullable LocalPortForward portForward, Tracer tracer, HttpClient client, SessionId id, diff --git a/java/src/org/openqa/selenium/grid/node/kubernetes/KubernetesSessionFactory.java b/java/src/org/openqa/selenium/grid/node/kubernetes/KubernetesSessionFactory.java index 6523bac2e1ec2..b918eff88bb9c 100644 --- a/java/src/org/openqa/selenium/grid/node/kubernetes/KubernetesSessionFactory.java +++ b/java/src/org/openqa/selenium/grid/node/kubernetes/KubernetesSessionFactory.java @@ -19,7 +19,6 @@ import static java.util.Optional.ofNullable; import static org.openqa.selenium.remote.Dialect.W3C; -import static org.openqa.selenium.remote.http.Contents.string; import static org.openqa.selenium.remote.http.HttpMethod.GET; import static org.openqa.selenium.remote.tracing.Tags.EXCEPTION; @@ -81,6 +80,7 @@ import java.util.logging.Logger; import java.util.regex.Pattern; import java.util.regex.PatternSyntaxException; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.Capabilities; import org.openqa.selenium.Dimension; import org.openqa.selenium.ImmutableCapabilities; @@ -134,19 +134,19 @@ public class KubernetesSessionFactory implements SessionFactory { private final String namespace; private final String browserImage; private final Capabilities stereotype; - private final String imagePullPolicy; - private final String serviceAccount; + private final @Nullable String imagePullPolicy; + private final @Nullable String serviceAccount; private final Map resourceRequests; private final Map resourceLimits; private final Map nodeSelector; - private final String videoImage; - private final String assetsPath; + private final @Nullable String videoImage; + private final @Nullable String assetsPath; private final InheritedPodSpec inheritedPodSpec; - private final Job jobTemplate; + private final @Nullable Job jobTemplate; private final long terminationGracePeriodSeconds; private final boolean usePortForwarding; private final Predicate predicate; - private final OwnerReference nodePodOwnerReference; + private final @Nullable OwnerReference nodePodOwnerReference; public KubernetesSessionFactory( Tracer tracer, @@ -158,12 +158,12 @@ public KubernetesSessionFactory( String browserImage, Capabilities stereotype, String imagePullPolicy, - String serviceAccount, + @Nullable String serviceAccount, Map resourceRequests, Map resourceLimits, Map nodeSelector, - String videoImage, - String assetsPath, + @Nullable String videoImage, + @Nullable String assetsPath, InheritedPodSpec inheritedPodSpec, long terminationGracePeriodSeconds, boolean usePortForwarding, @@ -202,8 +202,8 @@ public KubernetesSessionFactory( String browserImage, Capabilities stereotype, Job jobTemplate, - String videoImage, - String assetsPath, + @Nullable String videoImage, + @Nullable String assetsPath, long terminationGracePeriodSeconds, boolean usePortForwarding, Predicate predicate) { @@ -235,8 +235,8 @@ public KubernetesSessionFactory( String browserImage, Capabilities stereotype, Job jobTemplate, - String videoImage, - String assetsPath, + @Nullable String videoImage, + @Nullable String assetsPath, InheritedPodSpec inheritedPodSpec, long terminationGracePeriodSeconds, boolean usePortForwarding, @@ -319,7 +319,7 @@ public Either apply(CreateSessionRequest sess LOG.info("Starting K8s session for " + sessionRequest.getDesiredCapabilities()); String browserName = sessionRequest.getDesiredCapabilities().getBrowserName(); - if (browserName == null || browserName.isEmpty()) { + if (browserName.isEmpty()) { browserName = "unknown"; } else { browserName = browserName.toLowerCase(); @@ -539,7 +539,9 @@ private String generateJobName(String browserName, long timestamp, String unique return name; } - private static OwnerReference createNodePodOwnerReference(InheritedPodSpec inheritedPodSpec) { + @Nullable + private static OwnerReference createNodePodOwnerReference( + @Nullable InheritedPodSpec inheritedPodSpec) { if (inheritedPodSpec == null || !inheritedPodSpec.hasNodePodOwnerReference()) { return null; } @@ -571,7 +573,7 @@ Job buildJobSpec(String jobName, Capabilities sessionCapabilities) { labels.put("app", "selenium-session"); labels.put("se/job-name", jobName); String browser = sessionCapabilities.getBrowserName(); - if (browser != null && !browser.isEmpty()) { + if (!browser.isEmpty()) { labels.put("se/browser", browser.toLowerCase()); } @@ -861,6 +863,7 @@ private Container buildVideoContainer(String jobName, Capabilities sessionCapabi .build(); } + @Nullable private String getVideoFileName(Capabilities sessionRequestCapabilities, String capabilityName) { String trimRegex = getVideoFileNameTrimRegex(); Optional testName = @@ -916,7 +919,7 @@ Job buildJobSpecFromTemplate(String jobName, Capabilities sessionCapabilities) { labels.put("app", "selenium-session"); labels.put("se/job-name", jobName); String browser = sessionCapabilities.getBrowserName(); - if (browser != null && !browser.isEmpty()) { + if (!browser.isEmpty()) { labels.put("se/browser", browser.toLowerCase()); } @@ -990,7 +993,8 @@ && recordVideoForSession(sessionCapabilities)) { return job; } - static Container findContainerByName(List containers, String name) { + @Nullable + static Container findContainerByName(@Nullable List containers, String name) { if (containers == null) { return null; } @@ -1059,7 +1063,8 @@ static void ensurePort(Container container, int port) { } } - static void ensureVolumeMount(Container container, String volumeName, String mountPath) { + static void ensureVolumeMount( + @Nullable Container container, String volumeName, String mountPath) { if (container == null) { return; } @@ -1147,7 +1152,7 @@ private String[] doWaitForPodRunning(String jobName) { .inNamespace(namespace) .withLabel("job-name", jobName) .watch( - new Watcher() { + new Watcher<>() { @Override public void eventReceived(Action action, Pod pod) { if (action == Action.DELETED) { @@ -1166,7 +1171,7 @@ public void eventReceived(Action action, Pod pod) { } @Override - public void onClose(WatcherException cause) { + public void onClose(@Nullable WatcherException cause) { if (!future.isDone() && cause != null) { future.completeExceptionally( new SessionNotCreatedException( @@ -1283,11 +1288,12 @@ private boolean isPodReady(Pod pod) { return true; } + @Nullable private String resolvePodIp(Pod pod) { if (pod.getStatus() == null) { return null; } - List podIps = pod.getStatus().getPodIPs(); + List<@Nullable PodIP> podIps = pod.getStatus().getPodIPs(); if (podIps != null) { for (PodIP podIp : podIps) { if (podIp != null && podIp.getIp() != null && !podIp.getIp().isBlank()) { @@ -1309,7 +1315,7 @@ private void waitForServerToStart(HttpClient client, Duration duration) { wait.until( obj -> { HttpResponse response = client.execute(new HttpRequest(GET, "/status")); - LOG.fine(string(response)); + LOG.fine(response::contentAsString); if (401 == response.getStatus()) { LOG.warning( "Server requires basic authentication. " @@ -1339,7 +1345,7 @@ private Capabilities addForwardCdpEndpoint( .setCapability("se:forwardCdp", forwardCdpPath); } - private void closePortForward(LocalPortForward portForward) { + private void closePortForward(@Nullable LocalPortForward portForward) { if (portForward != null) { try { portForward.close(); @@ -1387,6 +1393,7 @@ private URL getUrl(String host, int port) { } } + @Nullable private TimeZone getTimeZone(Capabilities sessionRequestCapabilities) { Optional timeZone = ofNullable(sessionRequestCapabilities.getCapability("se:timeZone")); if (timeZone.isPresent()) { @@ -1408,6 +1415,7 @@ private boolean recordVideoForSession(Capabilities sessionRequestCapabilities) { return recordVideo.isPresent() && Boolean.parseBoolean(recordVideo.get().toString()); } + @Nullable private Dimension getScreenResolution(Capabilities sessionRequestCapabilities) { Optional screenResolution = ofNullable(sessionRequestCapabilities.getCapability("se:screenResolution")); diff --git a/java/src/org/openqa/selenium/grid/node/kubernetes/package-info.java b/java/src/org/openqa/selenium/grid/node/kubernetes/package-info.java new file mode 100644 index 0000000000000..6f37afbf08e4c --- /dev/null +++ b/java/src/org/openqa/selenium/grid/node/kubernetes/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.node.kubernetes; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/node/local/BUILD.bazel b/java/src/org/openqa/selenium/grid/node/local/BUILD.bazel index 3a29f82adb381..d33fea9fc67d9 100644 --- a/java/src/org/openqa/selenium/grid/node/local/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/node/local/BUILD.bazel @@ -29,5 +29,6 @@ java_library( "//java/src/org/openqa/selenium/remote", artifact("com.google.guava:guava"), artifact("com.github.ben-manes.caffeine:caffeine"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/node/local/LocalNode.java b/java/src/org/openqa/selenium/grid/node/local/LocalNode.java index b1d9fde09ad53..c2aaeaa1a0bb9 100644 --- a/java/src/org/openqa/selenium/grid/node/local/LocalNode.java +++ b/java/src/org/openqa/selenium/grid/node/local/LocalNode.java @@ -76,6 +76,7 @@ import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Collectors; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.Capabilities; import org.openqa.selenium.ImmutableCapabilities; import org.openqa.selenium.MutableCapabilities; @@ -171,7 +172,7 @@ protected LocalNode( EventBus bus, URI uri, URI gridUri, - HealthCheck healthCheck, + @Nullable HealthCheck healthCheck, int maxSessionCount, int drainAfterSessionCount, boolean cdpEnabled, @@ -337,7 +338,8 @@ public void close() { shutdown.run(); } - private void stopTimedOutSession(SessionId id, SessionSlot slot, RemovalCause cause) { + private void stopTimedOutSession( + @Nullable SessionId id, @Nullable SessionSlot slot, RemovalCause cause) { try (Span span = tracer.getCurrentContext().createSpan("node.stop_session")) { AttributeMap attributeMap = tracer.createAttributeMap(); attributeMap.put(AttributeKey.LOGGER_CLASS.getKey(), getClass().getName()); @@ -1013,7 +1015,7 @@ private HttpResponse deleteDownloadedFile(File downloadsDirectory) { for (File file : files) { FileHandler.delete(file); } - Map toReturn = new HashMap<>(); + Map toReturn = new HashMap<>(); toReturn.put("value", null); return new HttpResponse().setContent(asJson(toReturn)); } diff --git a/java/src/org/openqa/selenium/grid/node/local/SessionSlot.java b/java/src/org/openqa/selenium/grid/node/local/SessionSlot.java index fee8b5789101e..cc615d46dcb38 100644 --- a/java/src/org/openqa/selenium/grid/node/local/SessionSlot.java +++ b/java/src/org/openqa/selenium/grid/node/local/SessionSlot.java @@ -29,6 +29,7 @@ import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.StreamSupport; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.Capabilities; import org.openqa.selenium.ImmutableCapabilities; import org.openqa.selenium.NoSuchSessionException; @@ -67,7 +68,7 @@ public class SessionSlot private final boolean supportingBiDi; private final AtomicLong connectionCounter; // volatile ensures memory visibility across threads when session is set after reservation - private volatile ActiveSession currentSession; + private volatile @Nullable ActiveSession currentSession; public SessionSlot(EventBus bus, Capabilities stereotype, SessionFactory factory) { this.bus = Require.nonNull("Event bus", bus); @@ -118,6 +119,7 @@ public boolean isAvailable() { return !reserved.get(); } + @Nullable public ActiveSession getSession() { if (isAvailable()) { throw new NoSuchSessionException("Session is not running"); @@ -143,7 +145,7 @@ public void stop(SessionClosedReason reason) { * @param nodeUri the URI of the node where the session was running (may be null for backward * compatibility) */ - public void stop(SessionClosedReason reason, NodeId nodeId, URI nodeUri) { + public void stop(SessionClosedReason reason, @Nullable NodeId nodeId, @Nullable URI nodeUri) { if (isAvailable()) { return; } diff --git a/java/src/org/openqa/selenium/grid/node/local/package-info.java b/java/src/org/openqa/selenium/grid/node/local/package-info.java new file mode 100644 index 0000000000000..f47f0602b0030 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/node/local/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.node.local; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/node/package-info.java b/java/src/org/openqa/selenium/grid/node/package-info.java new file mode 100644 index 0000000000000..f858c033d8a46 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/node/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.node; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/node/relay/BUILD.bazel b/java/src/org/openqa/selenium/grid/node/relay/BUILD.bazel index a9e354cd1ef94..0ed95a851837a 100644 --- a/java/src/org/openqa/selenium/grid/node/relay/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/node/relay/BUILD.bazel @@ -17,5 +17,6 @@ java_library( "//java/src/org/openqa/selenium/remote", artifact("com.beust:jcommander"), artifact("com.google.guava:guava"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/node/relay/RelayOptions.java b/java/src/org/openqa/selenium/grid/node/relay/RelayOptions.java index 3342a585c2b7e..ac7902438c966 100644 --- a/java/src/org/openqa/selenium/grid/node/relay/RelayOptions.java +++ b/java/src/org/openqa/selenium/grid/node/relay/RelayOptions.java @@ -17,7 +17,6 @@ package org.openqa.selenium.grid.node.relay; -import static org.openqa.selenium.remote.http.Contents.string; import static org.openqa.selenium.remote.http.HttpMethod.GET; import java.net.URI; @@ -30,6 +29,7 @@ import java.util.Map; import java.util.Optional; import java.util.logging.Logger; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.Capabilities; import org.openqa.selenium.ImmutableCapabilities; import org.openqa.selenium.grid.config.Config; @@ -87,6 +87,7 @@ public URI getServiceUri() { } } + @Nullable public URI getServiceStatusUri() { try { if (config.get(RELAY_SECTION, "status-endpoint").isEmpty()) { @@ -134,7 +135,7 @@ private boolean isServiceUp(HttpClient client) { } try { HttpResponse response = client.execute(new HttpRequest(GET, serviceStatusUri.toString())); - LOG.fine(string(response)); + LOG.fine(response::contentAsString); return 200 == response.getStatus(); } catch (Exception e) { throw new ConfigException("Unable to reach the service at " + getServiceUri(), e); diff --git a/java/src/org/openqa/selenium/grid/node/relay/RelaySessionFactory.java b/java/src/org/openqa/selenium/grid/node/relay/RelaySessionFactory.java index 493b368956802..8dd43d5a90e39 100644 --- a/java/src/org/openqa/selenium/grid/node/relay/RelaySessionFactory.java +++ b/java/src/org/openqa/selenium/grid/node/relay/RelaySessionFactory.java @@ -39,6 +39,7 @@ import java.util.Set; import java.util.logging.Level; import java.util.logging.Logger; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.Capabilities; import org.openqa.selenium.ImmutableCapabilities; import org.openqa.selenium.MutableCapabilities; @@ -59,7 +60,6 @@ import org.openqa.selenium.remote.Response; import org.openqa.selenium.remote.SessionId; import org.openqa.selenium.remote.http.ClientConfig; -import org.openqa.selenium.remote.http.Contents; import org.openqa.selenium.remote.http.HttpClient; import org.openqa.selenium.remote.http.HttpMethod; import org.openqa.selenium.remote.http.HttpRequest; @@ -77,7 +77,7 @@ public class RelaySessionFactory implements SessionFactory { private final HttpClient.Factory clientFactory; private final Duration sessionTimeout; private final URL serviceUrl; - private final URL serviceStatusUrl; + private final @Nullable URL serviceStatusUrl; private final String serviceProtocolVersion; private final Capabilities stereotype; @@ -86,13 +86,13 @@ public RelaySessionFactory( HttpClient.Factory clientFactory, Duration sessionTimeout, URI serviceUri, - URI serviceStatusUri, + @Nullable URI serviceStatusUri, String serviceProtocolVersion, Capabilities stereotype) { this.tracer = Require.nonNull("Tracer", tracer); this.clientFactory = Require.nonNull("HTTP client", clientFactory); this.sessionTimeout = Require.nonNull("Session timeout", sessionTimeout); - this.serviceUrl = createUrlFromUri(Require.nonNull("Service URL", serviceUri)); + this.serviceUrl = Require.nonNull("Service URL", createUrlFromUri(serviceUri)); this.serviceStatusUrl = createUrlFromUri(serviceStatusUri); this.serviceProtocolVersion = Require.nonNull("Service protocol version", serviceProtocolVersion); @@ -244,7 +244,7 @@ public boolean isServiceUp() { try (HttpClient client = clientFactory.createClient(clientConfig)) { HttpResponse response = client.execute(new HttpRequest(HttpMethod.GET, serviceStatusUrl.toString())); - LOG.log(Debug.getDebugLogLevel(), () -> Contents.string(response)); + LOG.log(Debug.getDebugLogLevel(), response::contentAsString); return response.getStatus() == 200; } catch (Exception e) { LOG.log( @@ -257,7 +257,8 @@ public boolean isServiceUp() { return false; } - private URL createUrlFromUri(URI uri) { + @Nullable + private URL createUrlFromUri(@Nullable URI uri) { if (uri == null) { return null; } diff --git a/java/src/org/openqa/selenium/grid/node/relay/package-info.java b/java/src/org/openqa/selenium/grid/node/relay/package-info.java new file mode 100644 index 0000000000000..d0133b3427817 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/node/relay/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.node.relay; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/node/remote/BUILD.bazel b/java/src/org/openqa/selenium/grid/node/remote/BUILD.bazel index 400855ed6a686..c2801d30992c5 100644 --- a/java/src/org/openqa/selenium/grid/node/remote/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/node/remote/BUILD.bazel @@ -18,5 +18,6 @@ java_library( "//java/src/org/openqa/selenium/json", "//java/src/org/openqa/selenium/remote", artifact("com.google.guava:guava"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/node/remote/RemoteNode.java b/java/src/org/openqa/selenium/grid/node/remote/RemoteNode.java index 0d0832e658b22..de8f1235c4ccd 100644 --- a/java/src/org/openqa/selenium/grid/node/remote/RemoteNode.java +++ b/java/src/org/openqa/selenium/grid/node/remote/RemoteNode.java @@ -243,7 +243,7 @@ public Session getSession(SessionId id) throws NoSuchSessionException { HttpResponse res = client.with(addSecret).execute(req); - return Values.get(res, Session.class); + return Require.nonNull("Session", Values.get(res, Session.class)); } @Override @@ -299,7 +299,7 @@ private NodeStatus getStatus(HttpHandler handler) { while (in.hasNext()) { if ("node".equals(in.nextName())) { - return in.read(NodeStatus.class); + return in.readNonNull(NodeStatus.class); } else { in.skipValue(); } diff --git a/java/src/org/openqa/selenium/grid/node/remote/package-info.java b/java/src/org/openqa/selenium/grid/node/remote/package-info.java new file mode 100644 index 0000000000000..f026ca4da35cb --- /dev/null +++ b/java/src/org/openqa/selenium/grid/node/remote/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.node.remote; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/package-info.java b/java/src/org/openqa/selenium/grid/package-info.java index 38cfdad52888b..35fc148c138f2 100644 --- a/java/src/org/openqa/selenium/grid/package-info.java +++ b/java/src/org/openqa/selenium/grid/package-info.java @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -package org.openqa.selenium.grid; - /** * The Selenium Grid is composed of a number of moving pieces, all of which are designed to be used * either locally or across an HTTP boundary. @@ -37,3 +35,7 @@ * which the {@code Node} is running. Conversely, when the session comes to an end, the {@code Node} * is responsible for ensuring that the session is removed from the {@code SessionMap}. */ +@NullMarked +package org.openqa.selenium.grid; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/router/BUILD.bazel b/java/src/org/openqa/selenium/grid/router/BUILD.bazel index 70065bd707e4e..02368b044c94c 100644 --- a/java/src/org/openqa/selenium/grid/router/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/router/BUILD.bazel @@ -24,5 +24,6 @@ java_library( "//java/src/org/openqa/selenium/remote", "//java/src/org/openqa/selenium/status", artifact("com.google.guava:guava"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/router/httpd/BUILD.bazel b/java/src/org/openqa/selenium/grid/router/httpd/BUILD.bazel index 782f0b52af005..9f44ba247ca57 100644 --- a/java/src/org/openqa/selenium/grid/router/httpd/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/router/httpd/BUILD.bazel @@ -33,5 +33,6 @@ java_library( "//java/src/org/openqa/selenium/remote", artifact("com.beust:jcommander"), artifact("com.google.guava:guava"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/router/httpd/package-info.java b/java/src/org/openqa/selenium/grid/router/httpd/package-info.java new file mode 100644 index 0000000000000..fdd71e1265265 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/router/httpd/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.router.httpd; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/router/package-info.java b/java/src/org/openqa/selenium/grid/router/package-info.java new file mode 100644 index 0000000000000..1ed3dadc439f8 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/router/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.router; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/security/BUILD.bazel b/java/src/org/openqa/selenium/grid/security/BUILD.bazel index ae959d8c1d8c4..0fc8246210bed 100644 --- a/java/src/org/openqa/selenium/grid/security/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/security/BUILD.bazel @@ -17,5 +17,6 @@ java_library( "//java/src/org/openqa/selenium/json", "//java/src/org/openqa/selenium/remote/http", artifact("com.google.guava:guava"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/security/BasicAuthenticationFilter.java b/java/src/org/openqa/selenium/grid/security/BasicAuthenticationFilter.java index 90a8c5e858e0b..9dadae5a623fd 100644 --- a/java/src/org/openqa/selenium/grid/security/BasicAuthenticationFilter.java +++ b/java/src/org/openqa/selenium/grid/security/BasicAuthenticationFilter.java @@ -22,6 +22,7 @@ import java.net.HttpURLConnection; import java.util.Base64; import java.util.logging.Logger; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.internal.Require; import org.openqa.selenium.remote.http.Filter; import org.openqa.selenium.remote.http.HttpHandler; @@ -55,7 +56,7 @@ public HttpHandler apply(HttpHandler next) { }; } - private boolean isAuthorized(String auth) { + private boolean isAuthorized(@Nullable String auth) { if (auth != null) { final int index = auth.indexOf(' ') + 1; diff --git a/java/src/org/openqa/selenium/grid/security/SecretOptions.java b/java/src/org/openqa/selenium/grid/security/SecretOptions.java index bc8bb36c92ce0..250ab0e08391d 100644 --- a/java/src/org/openqa/selenium/grid/security/SecretOptions.java +++ b/java/src/org/openqa/selenium/grid/security/SecretOptions.java @@ -24,6 +24,7 @@ import java.nio.file.Files; import java.util.Arrays; import java.util.Optional; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.UsernameAndPassword; import org.openqa.selenium.grid.config.Config; import org.openqa.selenium.grid.config.ConfigException; @@ -58,6 +59,7 @@ public Secret getRegistrationSecret() { .orElse(new Secret(secret)); } + @Nullable public UsernameAndPassword getServerAuthentication() { Optional username = config.get(ROUTER_SECTION, "username"); Optional password = config.get(ROUTER_SECTION, "password"); diff --git a/java/src/org/openqa/selenium/grid/security/package-info.java b/java/src/org/openqa/selenium/grid/security/package-info.java new file mode 100644 index 0000000000000..29b62cab27d5a --- /dev/null +++ b/java/src/org/openqa/selenium/grid/security/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.security; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/server/BUILD.bazel b/java/src/org/openqa/selenium/grid/server/BUILD.bazel index bb5d132803c3c..cdcadab917182 100644 --- a/java/src/org/openqa/selenium/grid/server/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/server/BUILD.bazel @@ -23,5 +23,6 @@ java_library( "//java/src/org/openqa/selenium/remote", artifact("com.beust:jcommander"), artifact("com.google.guava:guava"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/server/BaseServerOptions.java b/java/src/org/openqa/selenium/grid/server/BaseServerOptions.java index 73347a593c5d0..c0ab6164e29f1 100644 --- a/java/src/org/openqa/selenium/grid/server/BaseServerOptions.java +++ b/java/src/org/openqa/selenium/grid/server/BaseServerOptions.java @@ -22,6 +22,7 @@ import java.net.URISyntaxException; import java.util.Optional; import java.util.logging.Logger; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.WebDriverException; import org.openqa.selenium.grid.config.Config; import org.openqa.selenium.grid.config.ConfigException; @@ -48,7 +49,7 @@ public BaseServerOptions(Config config) { new JMXHelper().register(this); } - public Optional getHostname() { + public Optional<@Nullable String> getHostname() { return config.get(SERVER_SECTION, "host"); } diff --git a/java/src/org/openqa/selenium/grid/server/EventBusOptions.java b/java/src/org/openqa/selenium/grid/server/EventBusOptions.java index a78e833186465..dd3d5dd86ec75 100644 --- a/java/src/org/openqa/selenium/grid/server/EventBusOptions.java +++ b/java/src/org/openqa/selenium/grid/server/EventBusOptions.java @@ -18,6 +18,7 @@ package org.openqa.selenium.grid.server; import java.time.Duration; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.events.EventBus; import org.openqa.selenium.grid.config.Config; import org.openqa.selenium.internal.Require; @@ -28,7 +29,7 @@ public class EventBusOptions { private static final String DEFAULT_CLASS = "org.openqa.selenium.events.zeromq.ZeroMqEventBus"; private static final int DEFAULT_HEARTBEAT_PERIOD = 60; private final Config config; - private volatile EventBus bus; + private volatile @Nullable EventBus bus; public EventBusOptions(Config config) { this.config = Require.nonNull("Config", config); diff --git a/java/src/org/openqa/selenium/grid/server/package-info.java b/java/src/org/openqa/selenium/grid/server/package-info.java new file mode 100644 index 0000000000000..adecd3fa97fd3 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/server/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.server; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/session/BUILD.bazel b/java/src/org/openqa/selenium/grid/session/BUILD.bazel index 4195249e38602..77100fb69baa5 100644 --- a/java/src/org/openqa/selenium/grid/session/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/session/BUILD.bazel @@ -16,5 +16,6 @@ java_library( "//java/src/org/openqa/selenium/grid/web", "//java/src/org/openqa/selenium/remote", artifact("com.google.guava:guava"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/session/package-info.java b/java/src/org/openqa/selenium/grid/session/package-info.java new file mode 100644 index 0000000000000..c5aca77dd7376 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/session/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.session; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/sessionmap/BUILD.bazel b/java/src/org/openqa/selenium/grid/sessionmap/BUILD.bazel index 0c8d812e8c185..8b0b62764775e 100644 --- a/java/src/org/openqa/selenium/grid/sessionmap/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/sessionmap/BUILD.bazel @@ -19,5 +19,6 @@ java_library( "//java/src/org/openqa/selenium/remote/http", "//java/src/org/openqa/selenium/status", artifact("com.google.guava:guava"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/sessionmap/config/BUILD.bazel b/java/src/org/openqa/selenium/grid/sessionmap/config/BUILD.bazel index b91fb3f78062e..052be637fdab7 100644 --- a/java/src/org/openqa/selenium/grid/sessionmap/config/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/sessionmap/config/BUILD.bazel @@ -13,5 +13,6 @@ java_library( "//java/src/org/openqa/selenium/grid/server", "//java/src/org/openqa/selenium/grid/sessionmap", artifact("com.beust:jcommander"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/sessionmap/config/package-info.java b/java/src/org/openqa/selenium/grid/sessionmap/config/package-info.java new file mode 100644 index 0000000000000..cc2417130c9da --- /dev/null +++ b/java/src/org/openqa/selenium/grid/sessionmap/config/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.sessionmap.config; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/sessionmap/httpd/BUILD.bazel b/java/src/org/openqa/selenium/grid/sessionmap/httpd/BUILD.bazel index 61bbd869ee58d..a7df6e94aca57 100644 --- a/java/src/org/openqa/selenium/grid/sessionmap/httpd/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/sessionmap/httpd/BUILD.bazel @@ -21,5 +21,6 @@ java_library( "//java/src/org/openqa/selenium/netty/server", artifact("com.beust:jcommander"), artifact("com.google.guava:guava"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/sessionmap/httpd/package-info.java b/java/src/org/openqa/selenium/grid/sessionmap/httpd/package-info.java new file mode 100644 index 0000000000000..36fff457cdb04 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/sessionmap/httpd/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.sessionmap.httpd; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/sessionmap/jdbc/JdbcBackedSessionMap.java b/java/src/org/openqa/selenium/grid/sessionmap/jdbc/JdbcBackedSessionMap.java index 9746000c27f1c..aba579061bac7 100644 --- a/java/src/org/openqa/selenium/grid/sessionmap/jdbc/JdbcBackedSessionMap.java +++ b/java/src/org/openqa/selenium/grid/sessionmap/jdbc/JdbcBackedSessionMap.java @@ -32,6 +32,7 @@ import java.sql.SQLException; import java.time.Instant; import java.util.logging.Logger; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.Capabilities; import org.openqa.selenium.ImmutableCapabilities; import org.openqa.selenium.NoSuchSessionException; @@ -69,22 +70,21 @@ public class JdbcBackedSessionMap extends SessionMap implements Closeable { private static final String DATABASE_USER = AttributeKey.DATABASE_USER.getKey(); private static final String DATABASE_CONNECTION_STRING = AttributeKey.DATABASE_CONNECTION_STRING.getKey(); - private static String jdbcUser; - private static String jdbcUrl; - private final EventBus bus; + private static @Nullable String jdbcUser; + private static @Nullable String jdbcUrl; private final Connection connection; public JdbcBackedSessionMap(Tracer tracer, Connection jdbcConnection, EventBus bus) { super(tracer); Require.nonNull("JDBC Connection Object", jdbcConnection); - this.bus = Require.nonNull("Event bus", bus); + Require.nonNull("Event bus", bus); this.connection = jdbcConnection; // Listen to SessionClosedEvent and extract the sessionId - this.bus.addListener(SessionClosedEvent.sessionListener(this::remove)); + bus.addListener(SessionClosedEvent.sessionListener(this::remove)); - this.bus.addListener( + bus.addListener( NodeRemovedEvent.listener( nodeStatus -> nodeStatus.getSlots().stream() @@ -272,6 +272,7 @@ public Session get(SessionId id) throws NoSuchSessionException { } span.addEvent("Retrieved session from the database", attributeMap); + //noinspection DataFlowIssue return new Session(id, uri, stereotype, caps, start); } catch (SQLException e) { span.setAttribute("error", true); diff --git a/java/src/org/openqa/selenium/grid/sessionmap/jdbc/package-info.java b/java/src/org/openqa/selenium/grid/sessionmap/jdbc/package-info.java new file mode 100644 index 0000000000000..f8b28572af4e4 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/sessionmap/jdbc/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.sessionmap.jdbc; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/sessionmap/local/BUILD.bazel b/java/src/org/openqa/selenium/grid/sessionmap/local/BUILD.bazel index 94b7c1589708c..414ec99776a03 100644 --- a/java/src/org/openqa/selenium/grid/sessionmap/local/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/sessionmap/local/BUILD.bazel @@ -18,5 +18,6 @@ java_library( "//java/src/org/openqa/selenium/grid/sessionmap", "//java/src/org/openqa/selenium/remote", artifact("com.github.ben-manes.caffeine:caffeine"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/sessionmap/local/LocalSessionMap.java b/java/src/org/openqa/selenium/grid/sessionmap/local/LocalSessionMap.java index 159c6336c4d75..b47d286c4f782 100644 --- a/java/src/org/openqa/selenium/grid/sessionmap/local/LocalSessionMap.java +++ b/java/src/org/openqa/selenium/grid/sessionmap/local/LocalSessionMap.java @@ -33,6 +33,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.logging.Logger; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.NoSuchSessionException; import org.openqa.selenium.events.EventBus; import org.openqa.selenium.grid.config.Config; @@ -213,6 +214,7 @@ private static class IndexedSessionMap { private final ConcurrentMap> sessionsByUri = new ConcurrentHashMap<>(); private final Object coordinationLock = new Object(); + @Nullable public Session get(SessionId id) { return sessions.get(id); } @@ -232,6 +234,7 @@ public void put(SessionId id, Session session) { } } + @Nullable public Session remove(SessionId id) { synchronized (coordinationLock) { Session removed = sessions.remove(id); diff --git a/java/src/org/openqa/selenium/grid/sessionmap/local/package-info.java b/java/src/org/openqa/selenium/grid/sessionmap/local/package-info.java new file mode 100644 index 0000000000000..b1e9ea1bc027f --- /dev/null +++ b/java/src/org/openqa/selenium/grid/sessionmap/local/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.sessionmap.local; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/sessionmap/package-info.java b/java/src/org/openqa/selenium/grid/sessionmap/package-info.java new file mode 100644 index 0000000000000..025b77cc59fb7 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/sessionmap/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.sessionmap; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/sessionmap/redis/BUILD.bazel b/java/src/org/openqa/selenium/grid/sessionmap/redis/BUILD.bazel index d29484ba06b29..5b9275d4eb11c 100644 --- a/java/src/org/openqa/selenium/grid/sessionmap/redis/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/sessionmap/redis/BUILD.bazel @@ -24,5 +24,6 @@ java_export( "//java/src/org/openqa/selenium/remote", artifact("com.google.guava:guava"), artifact("io.lettuce:lettuce-core"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/sessionmap/redis/RedisBackedSessionMap.java b/java/src/org/openqa/selenium/grid/sessionmap/redis/RedisBackedSessionMap.java index f21f5fabb1d48..e989b91c4df30 100644 --- a/java/src/org/openqa/selenium/grid/sessionmap/redis/RedisBackedSessionMap.java +++ b/java/src/org/openqa/selenium/grid/sessionmap/redis/RedisBackedSessionMap.java @@ -64,19 +64,18 @@ public class RedisBackedSessionMap extends SessionMap { private static final String DATABASE_SYSTEM = AttributeKey.DATABASE_SYSTEM.getKey(); private static final String DATABASE_OPERATION = AttributeKey.DATABASE_OPERATION.getKey(); private final GridRedisClient connection; - private final EventBus bus; private final URI serverUri; public RedisBackedSessionMap(Tracer tracer, URI serverUri, EventBus bus) { super(tracer); Require.nonNull("Redis Server Uri", serverUri); - this.bus = Require.nonNull("Event bus", bus); + Require.nonNull("Event bus", bus); this.connection = new GridRedisClient(serverUri); this.serverUri = serverUri; - this.bus.addListener(SessionClosedEvent.sessionListener(this::remove)); + bus.addListener(SessionClosedEvent.sessionListener(this::remove)); - this.bus.addListener( + bus.addListener( NodeRemovedEvent.listener( nodeStatus -> nodeStatus.getSlots().stream() @@ -195,6 +194,7 @@ public Session get(SessionId id) throws NoSuchSessionException { CAPABILITIES_EVENT.accept(attributeMap, caps); span.addEvent("Retrieved session from the database", attributeMap); + //noinspection DataFlowIssue return new Session(id, uri, stereotype, caps, start); } } diff --git a/java/src/org/openqa/selenium/grid/sessionmap/redis/package-info.java b/java/src/org/openqa/selenium/grid/sessionmap/redis/package-info.java new file mode 100644 index 0000000000000..2fdbf65aadb04 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/sessionmap/redis/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.sessionmap.redis; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/sessionmap/remote/BUILD.bazel b/java/src/org/openqa/selenium/grid/sessionmap/remote/BUILD.bazel index 1d250a65cbb16..59a80b1d3999a 100644 --- a/java/src/org/openqa/selenium/grid/sessionmap/remote/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/sessionmap/remote/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_jvm_external//:defs.bzl", "artifact") load("//java:defs.bzl", "java_library") java_library( @@ -16,5 +17,6 @@ java_library( "//java/src/org/openqa/selenium/grid/sessionmap/config", "//java/src/org/openqa/selenium/grid/web", "//java/src/org/openqa/selenium/remote", + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/sessionmap/remote/RemoteSessionMap.java b/java/src/org/openqa/selenium/grid/sessionmap/remote/RemoteSessionMap.java index 5d49e2b6588c8..10383bb96c349 100644 --- a/java/src/org/openqa/selenium/grid/sessionmap/remote/RemoteSessionMap.java +++ b/java/src/org/openqa/selenium/grid/sessionmap/remote/RemoteSessionMap.java @@ -17,6 +17,7 @@ package org.openqa.selenium.grid.sessionmap.remote; +import static java.util.Objects.requireNonNull; import static org.openqa.selenium.remote.http.Contents.asJson; import static org.openqa.selenium.remote.http.HttpMethod.DELETE; import static org.openqa.selenium.remote.http.HttpMethod.GET; @@ -26,6 +27,7 @@ import java.lang.reflect.Type; import java.net.MalformedURLException; import java.net.URI; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.NoSuchSessionException; import org.openqa.selenium.grid.config.Config; import org.openqa.selenium.grid.data.Session; @@ -79,7 +81,7 @@ public boolean add(Session session) { HttpRequest request = new HttpRequest(POST, "/se/grid/session"); request.setContent(asJson(session)); - return makeRequest(request, Boolean.class); + return requireNonNull(makeRequest(request, Boolean.class)); } @Override @@ -111,6 +113,7 @@ public void remove(SessionId id) { makeRequest(new HttpRequest(DELETE, "/se/grid/session/" + id), Void.class); } + @Nullable private T makeRequest(HttpRequest request, Type typeOfT) { HttpTracing.inject(tracer, tracer.getCurrentContext(), request); diff --git a/java/src/org/openqa/selenium/grid/sessionmap/remote/package-info.java b/java/src/org/openqa/selenium/grid/sessionmap/remote/package-info.java new file mode 100644 index 0000000000000..54737da86682b --- /dev/null +++ b/java/src/org/openqa/selenium/grid/sessionmap/remote/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.sessionmap.remote; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/sessionqueue/BUILD.bazel b/java/src/org/openqa/selenium/grid/sessionqueue/BUILD.bazel index 5f8c1b2f43e8f..c271ee4a19a72 100644 --- a/java/src/org/openqa/selenium/grid/sessionqueue/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/sessionqueue/BUILD.bazel @@ -23,5 +23,6 @@ java_library( "//java/src/org/openqa/selenium/remote/http", "//java/src/org/openqa/selenium/status", artifact("com.google.guava:guava"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/sessionqueue/config/BUILD.bazel b/java/src/org/openqa/selenium/grid/sessionqueue/config/BUILD.bazel index 1e1d4f5482a0c..084b96412213d 100644 --- a/java/src/org/openqa/selenium/grid/sessionqueue/config/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/sessionqueue/config/BUILD.bazel @@ -15,5 +15,6 @@ java_library( "//java/src/org/openqa/selenium/grid/server", "//java/src/org/openqa/selenium/grid/sessionqueue", artifact("com.beust:jcommander"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/sessionqueue/config/package-info.java b/java/src/org/openqa/selenium/grid/sessionqueue/config/package-info.java new file mode 100644 index 0000000000000..d73cbede10e00 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/sessionqueue/config/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.sessionqueue.config; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/sessionqueue/httpd/BUILD.bazel b/java/src/org/openqa/selenium/grid/sessionqueue/httpd/BUILD.bazel index 91cb97059421e..55c443be94687 100644 --- a/java/src/org/openqa/selenium/grid/sessionqueue/httpd/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/sessionqueue/httpd/BUILD.bazel @@ -22,5 +22,6 @@ java_library( "//java/src/org/openqa/selenium/netty/server", artifact("com.beust:jcommander"), artifact("com.google.guava:guava"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/sessionqueue/httpd/package-info.java b/java/src/org/openqa/selenium/grid/sessionqueue/httpd/package-info.java new file mode 100644 index 0000000000000..c3bd786e10df0 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/sessionqueue/httpd/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.sessionqueue.httpd; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/sessionqueue/local/BUILD.bazel b/java/src/org/openqa/selenium/grid/sessionqueue/local/BUILD.bazel index 601027e3142ff..ed8a8e69bbfdc 100644 --- a/java/src/org/openqa/selenium/grid/sessionqueue/local/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/sessionqueue/local/BUILD.bazel @@ -23,5 +23,6 @@ java_library( "//java/src/org/openqa/selenium/grid/sessionqueue/config", "//java/src/org/openqa/selenium/remote", artifact("com.google.guava:guava"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/sessionqueue/local/LocalNewSessionQueue.java b/java/src/org/openqa/selenium/grid/sessionqueue/local/LocalNewSessionQueue.java index 8995553652080..75a46cf8db112 100644 --- a/java/src/org/openqa/selenium/grid/sessionqueue/local/LocalNewSessionQueue.java +++ b/java/src/org/openqa/selenium/grid/sessionqueue/local/LocalNewSessionQueue.java @@ -42,6 +42,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock; import java.util.function.Predicate; import java.util.stream.Collectors; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.Capabilities; import org.openqa.selenium.SessionNotCreatedException; import org.openqa.selenium.concurrent.GuardedRunnable; @@ -110,7 +111,8 @@ public class LocalNewSessionQueue extends NewSessionQueue implements Closeable { thread.setName(NAME); return thread; }); - private final MBean jmxBean; + + @Nullable private final MBean jmxBean; public LocalNewSessionQueue( Tracer tracer, diff --git a/java/src/org/openqa/selenium/grid/sessionqueue/local/package-info.java b/java/src/org/openqa/selenium/grid/sessionqueue/local/package-info.java new file mode 100644 index 0000000000000..708feadfdf3ad --- /dev/null +++ b/java/src/org/openqa/selenium/grid/sessionqueue/local/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.sessionqueue.local; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/sessionqueue/package-info.java b/java/src/org/openqa/selenium/grid/sessionqueue/package-info.java new file mode 100644 index 0000000000000..48b760e8acc54 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/sessionqueue/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.sessionqueue; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/sessionqueue/remote/BUILD.bazel b/java/src/org/openqa/selenium/grid/sessionqueue/remote/BUILD.bazel index 27ddf5111011e..85ead987f69c2 100644 --- a/java/src/org/openqa/selenium/grid/sessionqueue/remote/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/sessionqueue/remote/BUILD.bazel @@ -21,5 +21,6 @@ java_library( "//java/src/org/openqa/selenium/json", "//java/src/org/openqa/selenium/remote", artifact("com.google.guava:guava"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/sessionqueue/remote/package-info.java b/java/src/org/openqa/selenium/grid/sessionqueue/remote/package-info.java new file mode 100644 index 0000000000000..723c8fecc9db8 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/sessionqueue/remote/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.sessionqueue.remote; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/grid/web/BUILD.bazel b/java/src/org/openqa/selenium/grid/web/BUILD.bazel index 6eccbb98ecbf4..b7893bca93ea9 100644 --- a/java/src/org/openqa/selenium/grid/web/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/web/BUILD.bazel @@ -16,5 +16,6 @@ java_library( "//java/src/org/openqa/selenium/remote", "//java/src/org/openqa/selenium/remote/http", artifact("com.google.guava:guava"), + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/web/MergedResource.java b/java/src/org/openqa/selenium/grid/web/MergedResource.java index 15bc14912b693..fc43898b4d8d3 100644 --- a/java/src/org/openqa/selenium/grid/web/MergedResource.java +++ b/java/src/org/openqa/selenium/grid/web/MergedResource.java @@ -20,6 +20,7 @@ import java.util.HashSet; import java.util.Optional; import java.util.Set; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.internal.Require; public class MergedResource implements Resource { @@ -31,7 +32,7 @@ public MergedResource(Resource base) { this(base, null); } - private MergedResource(Resource base, Resource next) { + private MergedResource(Resource base, @Nullable Resource next) { this.base = Require.nonNull("Base resource", base); this.next = Optional.ofNullable(next); } diff --git a/java/src/org/openqa/selenium/grid/web/NoHandler.java b/java/src/org/openqa/selenium/grid/web/NoHandler.java index f8e18fa6fa881..68790ecda4e2c 100644 --- a/java/src/org/openqa/selenium/grid/web/NoHandler.java +++ b/java/src/org/openqa/selenium/grid/web/NoHandler.java @@ -26,6 +26,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.internal.Require; import org.openqa.selenium.json.Json; import org.openqa.selenium.remote.http.HttpHandler; @@ -43,7 +44,7 @@ public NoHandler(Json json) { @Override public HttpResponse execute(HttpRequest req) throws UncheckedIOException { // We're not using ImmutableMap for the outer map because it disallows null values. - Map responseMap = new HashMap<>(); + Map responseMap = new HashMap<>(); responseMap.put("sessionId", null); responseMap.put( "value", diff --git a/java/src/org/openqa/selenium/grid/web/Values.java b/java/src/org/openqa/selenium/grid/web/Values.java index a567bd80a81d4..009de95a68c1d 100644 --- a/java/src/org/openqa/selenium/grid/web/Values.java +++ b/java/src/org/openqa/selenium/grid/web/Values.java @@ -61,7 +61,7 @@ public static T get(HttpResponse response, Type typeOfT) { } } - throw new IllegalStateException("Unable to locate value: " + string(response)); + throw new IllegalStateException("Unable to locate value: " + response.contentAsString()); } catch (IOException e) { throw new UncheckedIOException(e); } diff --git a/java/src/org/openqa/selenium/grid/web/package-info.java b/java/src/org/openqa/selenium/grid/web/package-info.java new file mode 100644 index 0000000000000..b2760746de7da --- /dev/null +++ b/java/src/org/openqa/selenium/grid/web/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.web; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/ie/InternetExplorerOptions.java b/java/src/org/openqa/selenium/ie/InternetExplorerOptions.java index 6f1c5ccc1c381..585bf376cedc3 100644 --- a/java/src/org/openqa/selenium/ie/InternetExplorerOptions.java +++ b/java/src/org/openqa/selenium/ie/InternetExplorerOptions.java @@ -89,7 +89,7 @@ public class InternetExplorerOptions extends AbstractDriverOptions ieOptions = new HashMap<>(); + private final Map ieOptions = new HashMap<>(); public InternetExplorerOptions() { setCapability(BROWSER_NAME, IE.browserName()); @@ -233,7 +233,7 @@ private InternetExplorerOptions amend(String optionName, Object value) { } @Override - public void setCapability(String key, Object value) { + public void setCapability(String key, @Nullable Object value) { if (IE_SWITCHES.equals(key)) { if (value instanceof List) { value = ((List) value).stream().map(Object::toString).collect(Collectors.joining(" ")); @@ -249,9 +249,9 @@ public void setCapability(String key, Object value) { if (IE_OPTIONS.equals(key)) { ieOptions.clear(); - Map streamFrom; + Map streamFrom; if (value instanceof Map) { - streamFrom = (Map) value; + streamFrom = (Map) value; } else if (value instanceof Capabilities) { streamFrom = ((Capabilities) value).asMap(); } else { diff --git a/java/src/org/openqa/selenium/ie/package-info.java b/java/src/org/openqa/selenium/ie/package-info.java index e2d33d56dc710..e341052be7f8a 100644 --- a/java/src/org/openqa/selenium/ie/package-info.java +++ b/java/src/org/openqa/selenium/ie/package-info.java @@ -15,35 +15,6 @@ // specific language governing permissions and limitations // under the License. -/** - * Mechanisms to configure and run selenium via the command line. There are two key classes {@link - * org.openqa.selenium.cli.CliCommand} and {@link org.openqa.selenium.grid.config.HasRoles}. - * Ultimately, these are used to build a {@link org.openqa.selenium.grid.config.Config} instance, - * for which there are strongly-typed role-specific classes that use a {@code Config}, such as - * {@link org.openqa.selenium.grid.node.docker.DockerOptions}. - * - *

Assuming your {@code CliCommand} extends {@link org.openqa.selenium.grid.TemplateGridCommand}, - * the process for building the set of flags to use is: - * - *

    - *
  1. The default flags are added (these are {@link org.openqa.selenium.grid.server.HelpFlags} - * and {@link org.openqa.selenium.grid.config.ConfigFlags} - *
  2. {@link java.util.ServiceLoader} is used to find all implementations of {@link - * org.openqa.selenium.grid.config.HasRoles} where {@link - * org.openqa.selenium.grid.config.HasRoles#getRoles()} is contained within {@link - * org.openqa.selenium.cli.CliCommand#getConfigurableRoles()}. - *
  3. Finally all flags returned by {@link - * org.openqa.selenium.grid.TemplateGridCommand#getFlagObjects()} are added. - *
- * - *

The flags are then used by JCommander to parse the command arguments. Once that's done, the - * raw flags are converted to a {@link org.openqa.selenium.grid.config.Config} by combining all of - * the flag objects with system properties and environment variables. This implies that each flag - * object has annotated each field with {@link org.openqa.selenium.grid.config.ConfigValue}. - * - *

Ultimately, this means that flag objects have all (most?) fields annotated with JCommander's - * {@link com.beust.jcommander.Parameter} annotation as well as {@code ConfigValue}. - */ @NullMarked package org.openqa.selenium.ie; diff --git a/java/src/org/openqa/selenium/internal/Debug.java b/java/src/org/openqa/selenium/internal/Debug.java index 5c694416f6890..c7786ccb7ff8d 100644 --- a/java/src/org/openqa/selenium/internal/Debug.java +++ b/java/src/org/openqa/selenium/internal/Debug.java @@ -29,7 +29,6 @@ public class Debug { private static final boolean IS_DEBUG; private static final AtomicBoolean DEBUG_WARNING_LOGGED = new AtomicBoolean(false); private static boolean loggerConfigured = false; - private static Logger seleniumLogger; static { IS_DEBUG = @@ -64,7 +63,7 @@ public static void configureLogger() { return; } - seleniumLogger = Logger.getLogger("org.openqa.selenium"); + Logger seleniumLogger = Logger.getLogger("org.openqa.selenium"); seleniumLogger.setLevel(Level.FINE); StreamHandler handler = new StreamHandler(System.err, new SimpleFormatter()); diff --git a/java/src/org/openqa/selenium/manager/BUILD.bazel b/java/src/org/openqa/selenium/manager/BUILD.bazel index d66cf5aa885ef..f0c3d396516e2 100644 --- a/java/src/org/openqa/selenium/manager/BUILD.bazel +++ b/java/src/org/openqa/selenium/manager/BUILD.bazel @@ -25,6 +25,7 @@ java_export( "//java/src/org/openqa/selenium:core", "//java/src/org/openqa/selenium/json", "//java/src/org/openqa/selenium/os", + "@maven//:org_jspecify_jspecify", ], ) diff --git a/java/src/org/openqa/selenium/manager/SeleniumManager.java b/java/src/org/openqa/selenium/manager/SeleniumManager.java index bed9e56e08dbe..e042b500c281f 100644 --- a/java/src/org/openqa/selenium/manager/SeleniumManager.java +++ b/java/src/org/openqa/selenium/manager/SeleniumManager.java @@ -36,6 +36,7 @@ import java.util.Properties; import java.util.logging.Level; import java.util.logging.Logger; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.Beta; import org.openqa.selenium.BuildInfo; import org.openqa.selenium.Platform; @@ -70,9 +71,11 @@ public class SeleniumManager { private static final String EXE = ".exe"; private static final String SE_ENV_PREFIX = "SE_"; - private static volatile SeleniumManager manager; - private final String managerPath = System.getenv("SE_MANAGER_PATH"); - private Path binary = managerPath == null ? null : Paths.get(managerPath); + @Nullable private static volatile SeleniumManager manager; + + @Nullable private final String managerPath = System.getenv("SE_MANAGER_PATH"); + + @Nullable private Path binary = managerPath == null ? null : Paths.get(managerPath); private final String seleniumManagerVersion; private boolean binaryInTemporalFolder = false; diff --git a/java/src/org/openqa/selenium/manager/SeleniumManagerOutput.java b/java/src/org/openqa/selenium/manager/SeleniumManagerOutput.java index 80629dde4e676..74413ec5b0d66 100644 --- a/java/src/org/openqa/selenium/manager/SeleniumManagerOutput.java +++ b/java/src/org/openqa/selenium/manager/SeleniumManagerOutput.java @@ -20,14 +20,16 @@ import java.util.Locale; import java.util.Objects; import java.util.logging.Level; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.internal.Require; import org.openqa.selenium.json.JsonInput; public class SeleniumManagerOutput { - private List logs; - private Result result; + private @Nullable List logs; + private @Nullable Result result; + @Nullable public List getLogs() { return logs; } @@ -37,6 +39,7 @@ public SeleniumManagerOutput setLogs(List logs) { return this; } + @Nullable public Result getResult() { return result; } @@ -115,15 +118,19 @@ private static Log fromJson(JsonInput input) { public static class Result { private final int code; - private final String message; - private final String driverPath; - private final String browserPath; + private final @Nullable String message; + private final @Nullable String driverPath; + private final @Nullable String browserPath; public Result(String driverPath) { this(0, null, driverPath, null); } - public Result(int code, String message, String driverPath, String browserPath) { + public Result( + int code, + @Nullable String message, + @Nullable String driverPath, + @Nullable String browserPath) { this.code = code; this.message = message; this.driverPath = driverPath; @@ -134,14 +141,17 @@ public int getCode() { return code; } + @Nullable public String getMessage() { return message; } + @Nullable public String getDriverPath() { return driverPath; } + @Nullable public String getBrowserPath() { return browserPath; } diff --git a/java/src/org/openqa/selenium/manager/package-info.java b/java/src/org/openqa/selenium/manager/package-info.java new file mode 100644 index 0000000000000..48c3543b54a28 --- /dev/null +++ b/java/src/org/openqa/selenium/manager/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.manager; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/net/UrlChecker.java b/java/src/org/openqa/selenium/net/UrlChecker.java index 1392d3dc204b2..44437f2a463c8 100644 --- a/java/src/org/openqa/selenium/net/UrlChecker.java +++ b/java/src/org/openqa/selenium/net/UrlChecker.java @@ -31,6 +31,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Logger; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.io.Read; /** Polls a URL until a HTTP 200 response is received. */ @@ -59,7 +60,7 @@ public void waitUntilAvailable(long timeout, TimeUnit unit, final URL... urls) long start = System.currentTimeMillis(); LOG.fine("Waiting for " + Arrays.toString(urls)); try { - Future callback = + Future<@Nullable Void> callback = EXECUTOR.submit( () -> { HttpURLConnection connection = null; @@ -113,7 +114,7 @@ public void waitUntilUnavailable(long timeout, TimeUnit unit, final URL url) long start = System.currentTimeMillis(); LOG.fine("Waiting for " + url); try { - Future callback = + Future<@Nullable Void> callback = EXECUTOR.submit( () -> { HttpURLConnection connection = null; diff --git a/java/src/org/openqa/selenium/print/package-info.java b/java/src/org/openqa/selenium/print/package-info.java index 48ef4b6fe9f52..a1b118b15459b 100644 --- a/java/src/org/openqa/selenium/print/package-info.java +++ b/java/src/org/openqa/selenium/print/package-info.java @@ -15,35 +15,6 @@ // specific language governing permissions and limitations // under the License. -/** - * Mechanisms to configure and run selenium via the command line. There are two key classes {@link - * org.openqa.selenium.cli.CliCommand} and {@link org.openqa.selenium.grid.config.HasRoles}. - * Ultimately, these are used to build a {@link org.openqa.selenium.grid.config.Config} instance, - * for which there are strongly-typed role-specific classes that use a {@code Config}, such as - * {@link org.openqa.selenium.grid.node.docker.DockerOptions}. - * - *

Assuming your {@code CliCommand} extends {@link org.openqa.selenium.grid.TemplateGridCommand}, - * the process for building the set of flags to use is: - * - *

    - *
  1. The default flags are added (these are {@link org.openqa.selenium.grid.server.HelpFlags} - * and {@link org.openqa.selenium.grid.config.ConfigFlags} - *
  2. {@link java.util.ServiceLoader} is used to find all implementations of {@link - * org.openqa.selenium.grid.config.HasRoles} where {@link - * org.openqa.selenium.grid.config.HasRoles#getRoles()} is contained within {@link - * org.openqa.selenium.cli.CliCommand#getConfigurableRoles()}. - *
  3. Finally all flags returned by {@link - * org.openqa.selenium.grid.TemplateGridCommand#getFlagObjects()} are added. - *
- * - *

The flags are then used by JCommander to parse the command arguments. Once that's done, the - * raw flags are converted to a {@link org.openqa.selenium.grid.config.Config} by combining all of - * the flag objects with system properties and environment variables. This implies that each flag - * object has annotated each field with {@link org.openqa.selenium.grid.config.ConfigValue}. - * - *

Ultimately, this means that flag objects have all (most?) fields annotated with JCommander's - * {@link com.beust.jcommander.Parameter} annotation as well as {@code ConfigValue}. - */ @NullMarked package org.openqa.selenium.print; diff --git a/java/src/org/openqa/selenium/remote/AbstractDriverOptions.java b/java/src/org/openqa/selenium/remote/AbstractDriverOptions.java index 60c97d9befdeb..9017c3d4d9f75 100644 --- a/java/src/org/openqa/selenium/remote/AbstractDriverOptions.java +++ b/java/src/org/openqa/selenium/remote/AbstractDriverOptions.java @@ -34,6 +34,8 @@ import java.util.Set; import java.util.TreeMap; import java.util.TreeSet; +import org.jspecify.annotations.NonNull; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.MutableCapabilities; import org.openqa.selenium.PageLoadStrategy; import org.openqa.selenium.Proxy; @@ -44,12 +46,12 @@ public abstract class AbstractDriverOptions extends MutableCapabilities { public DO setBrowserVersion(String browserVersion) { setCapability(BROWSER_VERSION, Require.nonNull("Browser version", browserVersion)); - return (DO) this; + return self(); } public DO setPlatformName(String platformName) { setCapability(PLATFORM_NAME, Require.nonNull("Platform Name", platformName)); - return (DO) this; + return self(); } public DO setImplicitWaitTimeout(Duration timeout) { @@ -57,7 +59,7 @@ public DO setImplicitWaitTimeout(Duration timeout) { timeouts.put("implicit", timeout.toMillis()); setCapability(TIMEOUTS, Collections.unmodifiableMap(timeouts)); - return (DO) this; + return self(); } public DO setPageLoadTimeout(Duration timeout) { @@ -65,7 +67,7 @@ public DO setPageLoadTimeout(Duration timeout) { timeouts.put("pageLoad", timeout.toMillis()); setCapability(TIMEOUTS, Collections.unmodifiableMap(timeouts)); - return (DO) this; + return self(); } public DO setScriptTimeout(Duration timeout) { @@ -73,37 +75,43 @@ public DO setScriptTimeout(Duration timeout) { timeouts.put("script", timeout.toMillis()); setCapability(TIMEOUTS, Collections.unmodifiableMap(timeouts)); - return (DO) this; + return self(); } public DO setPageLoadStrategy(PageLoadStrategy strategy) { setCapability(PAGE_LOAD_STRATEGY, Require.nonNull("Page load strategy", strategy)); - return (DO) this; + return self(); } public DO setUnhandledPromptBehaviour(UnexpectedAlertBehaviour behaviour) { setCapability( UNHANDLED_PROMPT_BEHAVIOUR, Require.nonNull("Unhandled prompt behavior", behaviour)); - return (DO) this; + return self(); } public DO setAcceptInsecureCerts(boolean acceptInsecureCerts) { setCapability(ACCEPT_INSECURE_CERTS, acceptInsecureCerts); - return (DO) this; + return self(); } public DO setStrictFileInteractability(boolean strictFileInteractability) { setCapability(STRICT_FILE_INTERACTABILITY, strictFileInteractability); - return (DO) this; + return self(); } public DO setProxy(Proxy proxy) { setCapability(PROXY, Require.nonNull("Proxy", proxy)); - return (DO) this; + return self(); } public DO setEnableDownloads(boolean enableDownloads) { setCapability(ENABLE_DOWNLOADS, enableDownloads); + return self(); + } + + @NonNull + @SuppressWarnings("unchecked") + private DO self() { return (DO) this; } @@ -126,6 +134,7 @@ public Object getCapability(String capabilityName) { return super.getCapability(capabilityName); } + @Nullable protected abstract Object getExtraCapability(String capabilityName); @Override diff --git a/java/src/org/openqa/selenium/remote/RemoteLogs.java b/java/src/org/openqa/selenium/remote/RemoteLogs.java index 75607fcd9a711..7d0f230d6b324 100644 --- a/java/src/org/openqa/selenium/remote/RemoteLogs.java +++ b/java/src/org/openqa/selenium/remote/RemoteLogs.java @@ -25,6 +25,7 @@ import java.util.Set; import java.util.logging.Level; import java.util.logging.Logger; +import java.util.stream.Collectors; import org.jspecify.annotations.Nullable; import org.openqa.selenium.Beta; import org.openqa.selenium.UnsupportedCommandException; @@ -111,18 +112,19 @@ private LogEntries getRemoteEntries(String logType) { } @SuppressWarnings("unchecked") List> rawList = (List>) raw; - List remoteEntries = new ArrayList<>(rawList.size()); + List remoteEntries = + rawList.stream().map(this::createLogEntry).collect(Collectors.toList()); - for (Map obj : rawList) { - remoteEntries.add( - new LogEntry( - LogLevelMapping.toLevel((String) obj.get(LEVEL)), - (Long) obj.get(TIMESTAMP), - (String) obj.get(MESSAGE))); - } return new LogEntries(remoteEntries); } + private LogEntry createLogEntry(Map obj) { + return new LogEntry( + LogLevelMapping.toLevel((String) obj.get(LEVEL)), + (Long) obj.get(TIMESTAMP), + (String) obj.get(MESSAGE)); + } + /** * @deprecated logging is not in the W3C WebDriver spec and LocalLogs are no longer supported. */ @@ -146,11 +148,9 @@ private Set getAvailableLocalLogs() { } @Override - @SuppressWarnings("deprecation") public Set getAvailableLogTypes() { - Object raw = executeMethod.execute(DriverCommand.GET_AVAILABLE_LOG_TYPES, null); - @SuppressWarnings("unchecked") - List rawList = (List) raw; + List rawList = + executeMethod.executeRequired(DriverCommand.GET_AVAILABLE_LOG_TYPES, null); Set builder = new LinkedHashSet<>(); builder.addAll(rawList); builder.addAll(getAvailableLocalLogs()); diff --git a/java/src/org/openqa/selenium/safari/package-info.java b/java/src/org/openqa/selenium/safari/package-info.java index 6684283ee6752..f679c88a7dc10 100644 --- a/java/src/org/openqa/selenium/safari/package-info.java +++ b/java/src/org/openqa/selenium/safari/package-info.java @@ -15,35 +15,6 @@ // specific language governing permissions and limitations // under the License. -/** - * Mechanisms to configure and run selenium via the command line. There are two key classes {@link - * org.openqa.selenium.cli.CliCommand} and {@link org.openqa.selenium.grid.config.HasRoles}. - * Ultimately, these are used to build a {@link org.openqa.selenium.grid.config.Config} instance, - * for which there are strongly-typed role-specific classes that use a {@code Config}, such as - * {@link org.openqa.selenium.grid.node.docker.DockerOptions}. - * - *

Assuming your {@code CliCommand} extends {@link org.openqa.selenium.grid.TemplateGridCommand}, - * the process for building the set of flags to use is: - * - *

    - *
  1. The default flags are added (these are {@link org.openqa.selenium.grid.server.HelpFlags} - * and {@link org.openqa.selenium.grid.config.ConfigFlags} - *
  2. {@link java.util.ServiceLoader} is used to find all implementations of {@link - * org.openqa.selenium.grid.config.HasRoles} where {@link - * org.openqa.selenium.grid.config.HasRoles#getRoles()} is contained within {@link - * org.openqa.selenium.cli.CliCommand#getConfigurableRoles()}. - *
  3. Finally all flags returned by {@link - * org.openqa.selenium.grid.TemplateGridCommand#getFlagObjects()} are added. - *
- * - *

The flags are then used by JCommander to parse the command arguments. Once that's done, the - * raw flags are converted to a {@link org.openqa.selenium.grid.config.Config} by combining all of - * the flag objects with system properties and environment variables. This implies that each flag - * object has annotated each field with {@link org.openqa.selenium.grid.config.ConfigValue}. - * - *

Ultimately, this means that flag objects have all (most?) fields annotated with JCommander's - * {@link com.beust.jcommander.Parameter} annotation as well as {@code ConfigValue}. - */ @NullMarked package org.openqa.selenium.safari; diff --git a/java/test/org/openqa/selenium/grid/config/BUILD.bazel b/java/test/org/openqa/selenium/grid/config/BUILD.bazel index 9575e42d240fb..a9584208da731 100644 --- a/java/test/org/openqa/selenium/grid/config/BUILD.bazel +++ b/java/test/org/openqa/selenium/grid/config/BUILD.bazel @@ -13,5 +13,6 @@ java_test_suite( artifact("com.google.guava:guava"), artifact("org.junit.jupiter:junit-jupiter-api"), artifact("org.assertj:assertj-core"), + artifact("org.jspecify:jspecify"), ] + JUNIT5_DEPS, ) diff --git a/java/test/org/openqa/selenium/grid/config/package-info.java b/java/test/org/openqa/selenium/grid/config/package-info.java new file mode 100644 index 0000000000000..f1393763f3149 --- /dev/null +++ b/java/test/org/openqa/selenium/grid/config/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.config; + +import org.jspecify.annotations.NullMarked; diff --git a/java/test/org/openqa/selenium/grid/data/BUILD.bazel b/java/test/org/openqa/selenium/grid/data/BUILD.bazel index 8473abbb9e4e8..ac313a9987ae9 100644 --- a/java/test/org/openqa/selenium/grid/data/BUILD.bazel +++ b/java/test/org/openqa/selenium/grid/data/BUILD.bazel @@ -14,5 +14,6 @@ java_test_suite( artifact("com.google.guava:guava"), artifact("org.junit.jupiter:junit-jupiter-api"), artifact("org.assertj:assertj-core"), + artifact("org.jspecify:jspecify"), ] + JUNIT5_DEPS, ) diff --git a/java/test/org/openqa/selenium/grid/data/DefaultSlotMatcherTest.java b/java/test/org/openqa/selenium/grid/data/DefaultSlotMatcherTest.java index 7b5fb5562707a..99ca4a4c722ef 100644 --- a/java/test/org/openqa/selenium/grid/data/DefaultSlotMatcherTest.java +++ b/java/test/org/openqa/selenium/grid/data/DefaultSlotMatcherTest.java @@ -26,6 +26,7 @@ import org.openqa.selenium.Platform; import org.openqa.selenium.remote.CapabilityType; +@SuppressWarnings("EqualsWithItself") class DefaultSlotMatcherTest { private final DefaultSlotMatcher slotMatcher = new DefaultSlotMatcher(); diff --git a/java/test/org/openqa/selenium/grid/data/package-info.java b/java/test/org/openqa/selenium/grid/data/package-info.java new file mode 100644 index 0000000000000..578785a87ff4c --- /dev/null +++ b/java/test/org/openqa/selenium/grid/data/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.data; + +import org.jspecify.annotations.NullMarked; diff --git a/java/test/org/openqa/selenium/grid/distributor/AddingNodesTest.java b/java/test/org/openqa/selenium/grid/distributor/AddingNodesTest.java index 67f88b7565410..685cc258952dd 100644 --- a/java/test/org/openqa/selenium/grid/distributor/AddingNodesTest.java +++ b/java/test/org/openqa/selenium/grid/distributor/AddingNodesTest.java @@ -35,6 +35,7 @@ import java.util.Set; import java.util.UUID; import java.util.function.Function; +import org.jspecify.annotations.Nullable; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.openqa.selenium.Capabilities; @@ -83,7 +84,7 @@ class AddingNodesTest { private static final Secret registrationSecret = new Secret("caerphilly"); private static final int newSessionThreadPoolSize = Runtime.getRuntime().availableProcessors(); - private Distributor distributor; + private @Nullable Distributor distributor; private Tracer tracer; private EventBus bus; private Wait wait; @@ -384,7 +385,7 @@ static class CustomNode extends Node { private final EventBus bus; private final Function factory; - private Session running; + private @Nullable Session running; protected CustomNode( EventBus bus, diff --git a/java/test/org/openqa/selenium/grid/distributor/BUILD.bazel b/java/test/org/openqa/selenium/grid/distributor/BUILD.bazel index 822b2bde76599..1e4728ddf8183 100644 --- a/java/test/org/openqa/selenium/grid/distributor/BUILD.bazel +++ b/java/test/org/openqa/selenium/grid/distributor/BUILD.bazel @@ -37,6 +37,7 @@ java_selenium_test_suite( artifact("org.assertj:assertj-core"), "//java/src/org/openqa/selenium:core", "//java/src/org/openqa/selenium/remote/http", + artifact("org.jspecify:jspecify"), ] + JUNIT5_DEPS, ) diff --git a/java/test/org/openqa/selenium/grid/distributor/local/BUILD.bazel b/java/test/org/openqa/selenium/grid/distributor/local/BUILD.bazel index 5cef7dac794cf..d43ed7e843143 100644 --- a/java/test/org/openqa/selenium/grid/distributor/local/BUILD.bazel +++ b/java/test/org/openqa/selenium/grid/distributor/local/BUILD.bazel @@ -32,5 +32,6 @@ java_test_suite( artifact("io.opentelemetry:opentelemetry-api"), artifact("org.junit.jupiter:junit-jupiter-api"), artifact("org.assertj:assertj-core"), + artifact("org.jspecify:jspecify"), ] + JUNIT5_DEPS, ) diff --git a/java/test/org/openqa/selenium/grid/distributor/local/LocalNodeRegistryTest.java b/java/test/org/openqa/selenium/grid/distributor/local/LocalNodeRegistryTest.java index 6db9833de7d4b..35e479d9160a1 100644 --- a/java/test/org/openqa/selenium/grid/distributor/local/LocalNodeRegistryTest.java +++ b/java/test/org/openqa/selenium/grid/distributor/local/LocalNodeRegistryTest.java @@ -86,6 +86,7 @@ void setUp() { } @AfterEach + @SuppressWarnings("ConstantValue") void tearDown() { if (registry != null) { registry.close(); diff --git a/java/test/org/openqa/selenium/grid/distributor/local/package-info.java b/java/test/org/openqa/selenium/grid/distributor/local/package-info.java new file mode 100644 index 0000000000000..29ed00ad32e30 --- /dev/null +++ b/java/test/org/openqa/selenium/grid/distributor/local/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.distributor.local; + +import org.jspecify.annotations.NullMarked; diff --git a/java/test/org/openqa/selenium/grid/distributor/package-info.java b/java/test/org/openqa/selenium/grid/distributor/package-info.java new file mode 100644 index 0000000000000..6422c2c745391 --- /dev/null +++ b/java/test/org/openqa/selenium/grid/distributor/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.distributor; + +import org.jspecify.annotations.NullMarked; diff --git a/java/test/org/openqa/selenium/grid/distributor/selector/BUILD.bazel b/java/test/org/openqa/selenium/grid/distributor/selector/BUILD.bazel index 6e87972283521..2ec64dbfc47a9 100644 --- a/java/test/org/openqa/selenium/grid/distributor/selector/BUILD.bazel +++ b/java/test/org/openqa/selenium/grid/distributor/selector/BUILD.bazel @@ -20,5 +20,6 @@ java_test_suite( artifact("com.google.guava:guava"), artifact("org.junit.jupiter:junit-jupiter-api"), artifact("org.assertj:assertj-core"), + artifact("org.jspecify:jspecify"), ] + JUNIT5_DEPS, ) diff --git a/java/test/org/openqa/selenium/grid/distributor/selector/package-info.java b/java/test/org/openqa/selenium/grid/distributor/selector/package-info.java new file mode 100644 index 0000000000000..0344710f634c2 --- /dev/null +++ b/java/test/org/openqa/selenium/grid/distributor/selector/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.distributor.selector; + +import org.jspecify.annotations.NullMarked; diff --git a/java/test/org/openqa/selenium/grid/graphql/BUILD.bazel b/java/test/org/openqa/selenium/grid/graphql/BUILD.bazel index 083c1fcf84545..583baf073a2a1 100644 --- a/java/test/org/openqa/selenium/grid/graphql/BUILD.bazel +++ b/java/test/org/openqa/selenium/grid/graphql/BUILD.bazel @@ -33,5 +33,6 @@ java_test_suite( artifact("com.google.guava:guava"), artifact("org.junit.jupiter:junit-jupiter-api"), artifact("org.assertj:assertj-core"), + artifact("org.jspecify:jspecify"), ] + JUNIT5_DEPS, ) diff --git a/java/test/org/openqa/selenium/grid/graphql/package-info.java b/java/test/org/openqa/selenium/grid/graphql/package-info.java new file mode 100644 index 0000000000000..df07cd66175e3 --- /dev/null +++ b/java/test/org/openqa/selenium/grid/graphql/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.graphql; + +import org.jspecify.annotations.NullMarked; diff --git a/java/test/org/openqa/selenium/grid/gridui/BUILD.bazel b/java/test/org/openqa/selenium/grid/gridui/BUILD.bazel index 9a6fa557a37de..ecd9c10e98b39 100644 --- a/java/test/org/openqa/selenium/grid/gridui/BUILD.bazel +++ b/java/test/org/openqa/selenium/grid/gridui/BUILD.bazel @@ -24,5 +24,6 @@ java_selenium_test_suite( artifact("com.google.guava:guava"), artifact("org.junit.jupiter:junit-jupiter-api"), artifact("org.assertj:assertj-core"), + artifact("org.jspecify:jspecify"), ] + JUNIT5_DEPS, ) diff --git a/java/test/org/openqa/selenium/grid/gridui/package-info.java b/java/test/org/openqa/selenium/grid/gridui/package-info.java new file mode 100644 index 0000000000000..bcfc99636f7e5 --- /dev/null +++ b/java/test/org/openqa/selenium/grid/gridui/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.gridui; + +import org.jspecify.annotations.NullMarked; diff --git a/java/test/org/openqa/selenium/grid/node/BUILD.bazel b/java/test/org/openqa/selenium/grid/node/BUILD.bazel index 370858f09eeae..1d3ab581d282c 100644 --- a/java/test/org/openqa/selenium/grid/node/BUILD.bazel +++ b/java/test/org/openqa/selenium/grid/node/BUILD.bazel @@ -29,5 +29,6 @@ java_test_suite( artifact("org.junit.jupiter:junit-jupiter-api"), artifact("org.assertj:assertj-core"), artifact("org.mockito:mockito-core"), + artifact("org.jspecify:jspecify"), ] + JUNIT5_DEPS, ) diff --git a/java/test/org/openqa/selenium/grid/node/ProxyNodeWebsocketsTest.java b/java/test/org/openqa/selenium/grid/node/ProxyNodeWebsocketsTest.java index 70c2d17a18853..ac4fe45bfa1b9 100644 --- a/java/test/org/openqa/selenium/grid/node/ProxyNodeWebsocketsTest.java +++ b/java/test/org/openqa/selenium/grid/node/ProxyNodeWebsocketsTest.java @@ -26,6 +26,7 @@ import java.util.UUID; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; +import org.jspecify.annotations.Nullable; import org.junit.jupiter.api.Test; import org.openqa.selenium.Capabilities; import org.openqa.selenium.ImmutableCapabilities; @@ -366,13 +367,13 @@ public boolean isReady() { private static class CountingStubNode extends StubNode { private final SessionId ownedSession; - private final Session session; + private final @Nullable Session session; private final AtomicInteger acquireCount; private final AtomicInteger releaseCount; CountingStubNode( SessionId ownedSession, - Session session, + @Nullable Session session, AtomicInteger acquireCount, AtomicInteger releaseCount) { super(new NodeId(UUID.randomUUID()), URI.create("http://localhost:5555")); @@ -400,6 +401,9 @@ public void releaseConnection(SessionId id) { @Override public Session getSession(SessionId id) { + if (session == null) { + throw new UnsupportedOperationException(); + } return session; } } diff --git a/java/test/org/openqa/selenium/grid/node/config/BUILD.bazel b/java/test/org/openqa/selenium/grid/node/config/BUILD.bazel index 95a52fb23cecd..9e078cb73af33 100644 --- a/java/test/org/openqa/selenium/grid/node/config/BUILD.bazel +++ b/java/test/org/openqa/selenium/grid/node/config/BUILD.bazel @@ -41,6 +41,7 @@ java_test_suite( artifact("org.assertj:assertj-core"), artifact("org.junit.jupiter:junit-jupiter-api"), artifact("org.mockito:mockito-core"), + artifact("org.jspecify:jspecify"), ] + JUNIT5_DEPS, ) @@ -77,5 +78,6 @@ java_test_suite( artifact("org.assertj:assertj-core"), artifact("org.junit.jupiter:junit-jupiter-api"), artifact("org.mockito:mockito-core"), + artifact("org.jspecify:jspecify"), ] + JUNIT5_DEPS, ) diff --git a/java/test/org/openqa/selenium/grid/node/config/SessionCapabilitiesMutatorTest.java b/java/test/org/openqa/selenium/grid/node/config/SessionCapabilitiesMutatorTest.java index 527860fc38d0a..b1e3a808b296f 100644 --- a/java/test/org/openqa/selenium/grid/node/config/SessionCapabilitiesMutatorTest.java +++ b/java/test/org/openqa/selenium/grid/node/config/SessionCapabilitiesMutatorTest.java @@ -32,24 +32,20 @@ import org.openqa.selenium.ImmutableCapabilities; public class SessionCapabilitiesMutatorTest { - - private SessionCapabilitiesMutator sessionCapabilitiesMutator; - private Capabilities stereotype; - private Capabilities capabilities; - @Test void shouldMergeStereotypeWithoutOptionsWithCapsWithOptions() { - stereotype = + Capabilities stereotype = new ImmutableCapabilities( "browserName", "chrome", "unhandledPromptBehavior", "accept"); - sessionCapabilitiesMutator = new SessionCapabilitiesMutator(stereotype); + SessionCapabilitiesMutator sessionCapabilitiesMutator = + new SessionCapabilitiesMutator(stereotype); Map chromeOptions = new HashMap<>(); chromeOptions.put("args", List.of("incognito", "window-size=500,500")); - capabilities = + Capabilities capabilities = new ImmutableCapabilities( "browserName", "chrome", "goog:chromeOptions", chromeOptions, @@ -74,15 +70,16 @@ void shouldMergeStereotypeWithOptionsWithCapsWithoutOptions() { Map chromeOptions = new HashMap<>(); chromeOptions.put("args", List.of("incognito", "window-size=500,500")); - stereotype = + Capabilities stereotype = new ImmutableCapabilities( "browserName", "chrome", "goog:chromeOptions", chromeOptions, "unhandledPromptBehavior", "accept"); - sessionCapabilitiesMutator = new SessionCapabilitiesMutator(stereotype); + SessionCapabilitiesMutator sessionCapabilitiesMutator = + new SessionCapabilitiesMutator(stereotype); - capabilities = + Capabilities capabilities = new ImmutableCapabilities( "browserName", "chrome", "pageLoadStrategy", "normal"); @@ -113,10 +110,11 @@ void shouldMergeChromeSpecificOptionsFromStereotypeAndCaps() { stereotypeOptions.put("opt1", "val1"); stereotypeOptions.put("opt2", "val4"); - stereotype = + Capabilities stereotype = new ImmutableCapabilities("browserName", "chrome", "goog:chromeOptions", stereotypeOptions); - sessionCapabilitiesMutator = new SessionCapabilitiesMutator(stereotype); + SessionCapabilitiesMutator sessionCapabilitiesMutator = + new SessionCapabilitiesMutator(stereotype); Map capabilityOptions = new HashMap<>(); capabilityOptions.put("args", List.of("incognito", "--headless")); @@ -125,7 +123,7 @@ void shouldMergeChromeSpecificOptionsFromStereotypeAndCaps() { capabilityOptions.put("opt2", "val2"); capabilityOptions.put("opt3", "val3"); - capabilities = + Capabilities capabilities = new ImmutableCapabilities("browserName", "chrome", "goog:chromeOptions", capabilityOptions); Map modifiedCapabilities = @@ -171,11 +169,12 @@ void shouldMergeEdgeSpecificOptionsFromStereotypeAndCaps() { stereotypeOptions.put("opt1", "val1"); stereotypeOptions.put("opt2", "val4"); - stereotype = + Capabilities stereotype = new ImmutableCapabilities( "browserName", "microsoftedge", "ms:edgeOptions", stereotypeOptions); - sessionCapabilitiesMutator = new SessionCapabilitiesMutator(stereotype); + SessionCapabilitiesMutator sessionCapabilitiesMutator = + new SessionCapabilitiesMutator(stereotype); Map capabilityOptions = new HashMap<>(); capabilityOptions.put("args", List.of("incognito", "--headless")); @@ -184,7 +183,7 @@ void shouldMergeEdgeSpecificOptionsFromStereotypeAndCaps() { capabilityOptions.put("opt2", "val2"); capabilityOptions.put("opt3", "val3"); - capabilities = + Capabilities capabilities = new ImmutableCapabilities( "browserName", "microsoftedge", "ms:edgeOptions", capabilityOptions); @@ -237,11 +236,12 @@ void shouldMergeFirefoxSpecificOptionsFromStereotypeAndCaps() { stereotypeOptions.put("profile", "profile-string"); stereotypeOptions.put("androidDeviceSerial", "emulator-5556"); - stereotype = + Capabilities stereotype = new ImmutableCapabilities( "browserName", "firefox", "moz:firefoxOptions", stereotypeOptions); - sessionCapabilitiesMutator = new SessionCapabilitiesMutator(stereotype); + SessionCapabilitiesMutator sessionCapabilitiesMutator = + new SessionCapabilitiesMutator(stereotype); Map capabilityOptions = new HashMap<>(); capabilityOptions.put("args", Collections.singletonList("-headless")); @@ -260,7 +260,7 @@ void shouldMergeFirefoxSpecificOptionsFromStereotypeAndCaps() { capabilityOptions.put("binary", "/path/to/caps/binary"); capabilityOptions.put("androidPackage", "com.android.chrome"); - capabilities = + Capabilities capabilities = new ImmutableCapabilities( "browserName", "firefox", "moz:firefoxOptions", capabilityOptions); @@ -321,15 +321,16 @@ void shouldMergeFirefoxSpecificOptionsFromStereotypeAndCaps() { @Test void shouldMergeTopLevelStereotypeAndCaps() { - stereotype = + Capabilities stereotype = new ImmutableCapabilities( "browserName", "chrome", "unhandledPromptBehavior", "accept", "pageLoadStrategy", "eager"); - sessionCapabilitiesMutator = new SessionCapabilitiesMutator(stereotype); + SessionCapabilitiesMutator sessionCapabilitiesMutator = + new SessionCapabilitiesMutator(stereotype); - capabilities = + Capabilities capabilities = new ImmutableCapabilities( "browserName", "chrome", "pageLoadStrategy", "normal"); @@ -344,11 +345,12 @@ void shouldMergeTopLevelStereotypeAndCaps() { @Test void shouldAllowUnknownBrowserNames() { - stereotype = new ImmutableCapabilities("browserName", "safari"); + Capabilities stereotype = new ImmutableCapabilities("browserName", "safari"); - sessionCapabilitiesMutator = new SessionCapabilitiesMutator(stereotype); + SessionCapabilitiesMutator sessionCapabilitiesMutator = + new SessionCapabilitiesMutator(stereotype); - capabilities = new ImmutableCapabilities("browserName", "safari"); + Capabilities capabilities = new ImmutableCapabilities("browserName", "safari"); Map modifiedCapabilities = sessionCapabilitiesMutator.apply(capabilities).asMap(); diff --git a/java/test/org/openqa/selenium/grid/node/config/package-info.java b/java/test/org/openqa/selenium/grid/node/config/package-info.java new file mode 100644 index 0000000000000..51df08ff26d59 --- /dev/null +++ b/java/test/org/openqa/selenium/grid/node/config/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.node.config; + +import org.jspecify.annotations.NullMarked; diff --git a/java/test/org/openqa/selenium/grid/node/data/BUILD.bazel b/java/test/org/openqa/selenium/grid/node/data/BUILD.bazel index ae130811e9ee7..c38143bf3639f 100644 --- a/java/test/org/openqa/selenium/grid/node/data/BUILD.bazel +++ b/java/test/org/openqa/selenium/grid/node/data/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_jvm_external//:defs.bzl", "artifact") load("//java:defs.bzl", "java_library") java_library( @@ -13,5 +14,6 @@ java_library( "//java/src/org/openqa/selenium/grid/security", "//java/src/org/openqa/selenium/json", "//java/src/org/openqa/selenium/remote", + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/test/org/openqa/selenium/grid/node/data/package-info.java b/java/test/org/openqa/selenium/grid/node/data/package-info.java new file mode 100644 index 0000000000000..660d3c87372f7 --- /dev/null +++ b/java/test/org/openqa/selenium/grid/node/data/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.node.data; + +import org.jspecify.annotations.NullMarked; diff --git a/java/test/org/openqa/selenium/grid/node/docker/BUILD.bazel b/java/test/org/openqa/selenium/grid/node/docker/BUILD.bazel index 9704b17e181ae..f83a1e10e8086 100644 --- a/java/test/org/openqa/selenium/grid/node/docker/BUILD.bazel +++ b/java/test/org/openqa/selenium/grid/node/docker/BUILD.bazel @@ -19,5 +19,6 @@ java_test_suite( artifact("org.junit.jupiter:junit-jupiter-params"), artifact("org.assertj:assertj-core"), artifact("org.mockito:mockito-core"), + artifact("org.jspecify:jspecify"), ] + JUNIT5_DEPS, ) diff --git a/java/test/org/openqa/selenium/grid/node/docker/package-info.java b/java/test/org/openqa/selenium/grid/node/docker/package-info.java new file mode 100644 index 0000000000000..b2e9f6b6b6b1b --- /dev/null +++ b/java/test/org/openqa/selenium/grid/node/docker/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.node.docker; + +import org.jspecify.annotations.NullMarked; diff --git a/java/test/org/openqa/selenium/grid/node/kubernetes/BUILD.bazel b/java/test/org/openqa/selenium/grid/node/kubernetes/BUILD.bazel index d7deafd35677b..3cabfda65c9e6 100644 --- a/java/test/org/openqa/selenium/grid/node/kubernetes/BUILD.bazel +++ b/java/test/org/openqa/selenium/grid/node/kubernetes/BUILD.bazel @@ -19,5 +19,6 @@ java_test_suite( artifact("org.junit.jupiter:junit-jupiter-api"), artifact("org.assertj:assertj-core"), artifact("org.mockito:mockito-core"), + artifact("org.jspecify:jspecify"), ] + JUNIT5_DEPS, ) diff --git a/java/test/org/openqa/selenium/grid/node/kubernetes/package-info.java b/java/test/org/openqa/selenium/grid/node/kubernetes/package-info.java new file mode 100644 index 0000000000000..6f37afbf08e4c --- /dev/null +++ b/java/test/org/openqa/selenium/grid/node/kubernetes/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.node.kubernetes; + +import org.jspecify.annotations.NullMarked; diff --git a/java/test/org/openqa/selenium/grid/node/local/BUILD.bazel b/java/test/org/openqa/selenium/grid/node/local/BUILD.bazel index 687a49de94817..98a0f03219f32 100644 --- a/java/test/org/openqa/selenium/grid/node/local/BUILD.bazel +++ b/java/test/org/openqa/selenium/grid/node/local/BUILD.bazel @@ -20,5 +20,6 @@ java_test_suite( artifact("io.opentelemetry:opentelemetry-api"), artifact("org.junit.jupiter:junit-jupiter-api"), artifact("org.assertj:assertj-core"), + artifact("org.jspecify:jspecify"), ] + JUNIT5_DEPS, ) diff --git a/java/test/org/openqa/selenium/grid/node/local/package-info.java b/java/test/org/openqa/selenium/grid/node/local/package-info.java new file mode 100644 index 0000000000000..f47f0602b0030 --- /dev/null +++ b/java/test/org/openqa/selenium/grid/node/local/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.node.local; + +import org.jspecify.annotations.NullMarked; diff --git a/java/test/org/openqa/selenium/grid/node/package-info.java b/java/test/org/openqa/selenium/grid/node/package-info.java new file mode 100644 index 0000000000000..f858c033d8a46 --- /dev/null +++ b/java/test/org/openqa/selenium/grid/node/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.node; + +import org.jspecify.annotations.NullMarked; diff --git a/java/test/org/openqa/selenium/grid/node/relay/BUILD.bazel b/java/test/org/openqa/selenium/grid/node/relay/BUILD.bazel index 274b85c1c6869..f9a1525fc7fef 100644 --- a/java/test/org/openqa/selenium/grid/node/relay/BUILD.bazel +++ b/java/test/org/openqa/selenium/grid/node/relay/BUILD.bazel @@ -17,5 +17,6 @@ java_test_suite( artifact("org.junit.jupiter:junit-jupiter-api"), artifact("org.assertj:assertj-core"), artifact("org.mockito:mockito-core"), + artifact("org.jspecify:jspecify"), ] + JUNIT5_DEPS, ) diff --git a/java/test/org/openqa/selenium/grid/node/relay/package-info.java b/java/test/org/openqa/selenium/grid/node/relay/package-info.java new file mode 100644 index 0000000000000..d0133b3427817 --- /dev/null +++ b/java/test/org/openqa/selenium/grid/node/relay/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.node.relay; + +import org.jspecify.annotations.NullMarked; diff --git a/java/test/org/openqa/selenium/grid/router/BUILD.bazel b/java/test/org/openqa/selenium/grid/router/BUILD.bazel index 33ca056590d9b..3dfb296f47f56 100644 --- a/java/test/org/openqa/selenium/grid/router/BUILD.bazel +++ b/java/test/org/openqa/selenium/grid/router/BUILD.bazel @@ -38,6 +38,7 @@ java_library( "//java/src/org/openqa/selenium/remote/http", "//java/src/org/openqa/selenium/support", "//java/test/org/openqa/selenium/testing:test-base", + artifact("org.jspecify:jspecify"), ], ) @@ -67,6 +68,7 @@ java_selenium_test_suite( artifact("org.junit.jupiter:junit-jupiter-api"), artifact("org.junit.jupiter:junit-jupiter-params"), artifact("org.assertj:assertj-core"), + artifact("org.jspecify:jspecify"), ] + CDP_DEPS + JUNIT5_DEPS, ) @@ -98,6 +100,7 @@ java_selenium_test_suite( artifact("org.junit.jupiter:junit-jupiter-api"), artifact("org.junit.jupiter:junit-jupiter-params"), artifact("org.assertj:assertj-core"), + artifact("org.jspecify:jspecify"), ] + CDP_DEPS + JUNIT5_DEPS, ) @@ -120,6 +123,7 @@ java_selenium_test_suite( "//java/test/org/openqa/selenium/testing:test-base", artifact("org.junit.jupiter:junit-jupiter-api"), artifact("org.assertj:assertj-core"), + artifact("org.jspecify:jspecify"), ] + CDP_DEPS + JUNIT5_DEPS, ) @@ -156,5 +160,6 @@ java_test_suite( artifact("org.junit.jupiter:junit-jupiter-params"), artifact("org.assertj:assertj-core"), artifact("org.zeromq:jeromq"), + artifact("org.jspecify:jspecify"), ] + JUNIT5_DEPS, ) diff --git a/java/test/org/openqa/selenium/grid/router/RemoteWebDriverBiDiTest.java b/java/test/org/openqa/selenium/grid/router/RemoteWebDriverBiDiTest.java index fe507998dfbab..8b19e34e3ce03 100644 --- a/java/test/org/openqa/selenium/grid/router/RemoteWebDriverBiDiTest.java +++ b/java/test/org/openqa/selenium/grid/router/RemoteWebDriverBiDiTest.java @@ -80,6 +80,7 @@ void setup() { } @AfterEach + @SuppressWarnings("ConstantValue") void tearDownDeployment() { if (localDriver != null) { localDriver.quit(); diff --git a/java/test/org/openqa/selenium/grid/router/SessionCleanUpTest.java b/java/test/org/openqa/selenium/grid/router/SessionCleanUpTest.java index 623e9656959d9..3eaf57162f4fc 100644 --- a/java/test/org/openqa/selenium/grid/router/SessionCleanUpTest.java +++ b/java/test/org/openqa/selenium/grid/router/SessionCleanUpTest.java @@ -119,6 +119,7 @@ public void setup() { } @AfterEach + @SuppressWarnings("ConstantValue") public void stopServer() { if (server != null) { server.stop(); diff --git a/java/test/org/openqa/selenium/grid/router/package-info.java b/java/test/org/openqa/selenium/grid/router/package-info.java new file mode 100644 index 0000000000000..1ed3dadc439f8 --- /dev/null +++ b/java/test/org/openqa/selenium/grid/router/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.router; + +import org.jspecify.annotations.NullMarked; diff --git a/java/test/org/openqa/selenium/grid/security/BUILD.bazel b/java/test/org/openqa/selenium/grid/security/BUILD.bazel index e0c6b83987074..7e75330e3d60a 100644 --- a/java/test/org/openqa/selenium/grid/security/BUILD.bazel +++ b/java/test/org/openqa/selenium/grid/security/BUILD.bazel @@ -9,5 +9,6 @@ java_test_suite( "//java/src/org/openqa/selenium/remote/http", artifact("org.junit.jupiter:junit-jupiter-api"), artifact("org.assertj:assertj-core"), + artifact("org.jspecify:jspecify"), ] + JUNIT5_DEPS, ) diff --git a/java/test/org/openqa/selenium/grid/security/package-info.java b/java/test/org/openqa/selenium/grid/security/package-info.java new file mode 100644 index 0000000000000..29b62cab27d5a --- /dev/null +++ b/java/test/org/openqa/selenium/grid/security/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.security; + +import org.jspecify.annotations.NullMarked; diff --git a/java/test/org/openqa/selenium/grid/server/BUILD.bazel b/java/test/org/openqa/selenium/grid/server/BUILD.bazel index 2cf09abe9c814..fc7965c8ddded 100644 --- a/java/test/org/openqa/selenium/grid/server/BUILD.bazel +++ b/java/test/org/openqa/selenium/grid/server/BUILD.bazel @@ -21,5 +21,6 @@ java_test_suite( artifact("com.google.guava:guava"), artifact("org.junit.jupiter:junit-jupiter-api"), artifact("org.assertj:assertj-core"), + artifact("org.jspecify:jspecify"), ] + JUNIT5_DEPS, ) diff --git a/java/test/org/openqa/selenium/grid/server/package-info.java b/java/test/org/openqa/selenium/grid/server/package-info.java new file mode 100644 index 0000000000000..adecd3fa97fd3 --- /dev/null +++ b/java/test/org/openqa/selenium/grid/server/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.grid.server; + +import org.jspecify.annotations.NullMarked; diff --git a/java/test/org/openqa/selenium/remote/RemoteLogsTest.java b/java/test/org/openqa/selenium/remote/RemoteLogsTest.java index 49201de14c0aa..e9044760a0e93 100644 --- a/java/test/org/openqa/selenium/remote/RemoteLogsTest.java +++ b/java/test/org/openqa/selenium/remote/RemoteLogsTest.java @@ -133,7 +133,7 @@ void throwsOnBogusRemoteLogsResponse() { @Test void canGetAvailableLogTypes() { List remoteAvailableLogTypes = List.of(LogType.PROFILER, LogType.SERVER); - when(executeMethod.execute(DriverCommand.GET_AVAILABLE_LOG_TYPES, null)) + when(executeMethod.executeRequired(DriverCommand.GET_AVAILABLE_LOG_TYPES, null)) .thenReturn(remoteAvailableLogTypes); Set localAvailableLogTypes = Set.of(LogType.PROFILER, LogType.CLIENT); From 45911b8eeff4a6d7b97ed6aeea987a9af4b0f52e Mon Sep 17 00:00:00 2001 From: Nikolay Borisenko <22616990+nvborisenko@users.noreply.github.com> Date: Thu, 5 Mar 2026 13:14:15 +0300 Subject: [PATCH 46/67] [dotnet] [bidi] Revert... Wait until events are dispatched (#17178) Revert "[dotnet] [bidi] Wait until events are dispatched when unsubscribing (#17142)" This reverts commit 2f7127244336d9a585874ff6bc4a8fa336c64414. --- dotnet/src/webdriver/BiDi/EventDispatcher.cs | 154 +++---------------- 1 file changed, 23 insertions(+), 131 deletions(-) diff --git a/dotnet/src/webdriver/BiDi/EventDispatcher.cs b/dotnet/src/webdriver/BiDi/EventDispatcher.cs index 5a37ee58d9f7b..1a4e93006821c 100644 --- a/dotnet/src/webdriver/BiDi/EventDispatcher.cs +++ b/dotnet/src/webdriver/BiDi/EventDispatcher.cs @@ -34,18 +34,20 @@ internal sealed class EventDispatcher : IAsyncDisposable private readonly ConcurrentDictionary _events = new(); - private readonly Channel _pendingEvents = Channel.CreateUnbounded(new() + private readonly Channel _pendingEvents = Channel.CreateUnbounded(new() { SingleReader = true, SingleWriter = true }); - private readonly Task _processEventsTask; + private readonly Task _eventEmitterTask; + + private static readonly TaskFactory _myTaskFactory = new(CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskContinuationOptions.None, TaskScheduler.Default); public EventDispatcher(Func sessionProvider) { _sessionProvider = sessionProvider; - _processEventsTask = Task.Run(ProcessEventsAsync); + _eventEmitterTask = _myTaskFactory.StartNew(ProcessEventsAwaiterAsync).Unwrap(); } public async Task SubscribeAsync(string eventName, EventHandler eventHandler, SubscriptionOptions? options, JsonTypeInfo jsonTypeInfo, CancellationToken cancellationToken) @@ -53,19 +55,11 @@ public async Task SubscribeAsync(string eventName, Eve { var registration = _events.GetOrAdd(eventName, _ => new EventRegistration(jsonTypeInfo)); - registration.AddHandler(eventHandler); + var subscribeResult = await _sessionProvider().SubscribeAsync([eventName], new() { Contexts = options?.Contexts, UserContexts = options?.UserContexts }, cancellationToken).ConfigureAwait(false); - try - { - var subscribeResult = await _sessionProvider().SubscribeAsync([eventName], new() { Contexts = options?.Contexts, UserContexts = options?.UserContexts }, cancellationToken).ConfigureAwait(false); + registration.Handlers.Add(eventHandler); - return new Subscription(subscribeResult.Subscription, this, eventHandler); - } - catch - { - registration.RemoveHandler(eventHandler); - throw; - } + return new Subscription(subscribeResult.Subscription, this, eventHandler); } public async ValueTask UnsubscribeAsync(Subscription subscription, CancellationToken cancellationToken) @@ -73,34 +67,15 @@ public async ValueTask UnsubscribeAsync(Subscription subscription, CancellationT if (_events.TryGetValue(subscription.EventHandler.EventName, out var registration)) { await _sessionProvider().UnsubscribeAsync([subscription.SubscriptionId], null, cancellationToken).ConfigureAwait(false); - - // Wait until all pending events for this method are dispatched - try - { - await registration.DrainAsync(cancellationToken).ConfigureAwait(false); - } - finally - { - registration.RemoveHandler(subscription.EventHandler); - } + registration.Handlers.Remove(subscription.EventHandler); } } public void EnqueueEvent(string method, ReadOnlyMemory jsonUtf8Bytes, IBiDi bidi) { - if (_events.TryGetValue(method, out var registration)) + if (_events.TryGetValue(method, out var registration) && registration.TypeInfo is not null) { - if (_pendingEvents.Writer.TryWrite(new EventItem(jsonUtf8Bytes, bidi, registration))) - { - registration.IncrementEnqueued(); - } - else - { - if (_logger.IsEnabled(LogEventLevel.Warn)) - { - _logger.Warn($"Failed to enqueue BiDi event with method '{method}' for processing. Event will be ignored."); - } - } + _pendingEvents.Writer.TryWrite(new PendingEvent(method, jsonUtf8Bytes, bidi, registration.TypeInfo)); } else { @@ -111,45 +86,34 @@ public void EnqueueEvent(string method, ReadOnlyMemory jsonUtf8Bytes, IBiD } } - private async Task ProcessEventsAsync() + private async Task ProcessEventsAwaiterAsync() { var reader = _pendingEvents.Reader; - while (await reader.WaitToReadAsync().ConfigureAwait(false)) { - while (reader.TryRead(out var evt)) + while (reader.TryRead(out var result)) { try { - var eventArgs = (EventArgs)JsonSerializer.Deserialize(evt.JsonUtf8Bytes.Span, evt.Registration.TypeInfo)!; - eventArgs.BiDi = evt.BiDi; - - foreach (var handler in evt.Registration.GetHandlersSnapshot()) + if (_events.TryGetValue(result.Method, out var registration)) { - try + // Deserialize on background thread instead of network thread (single parse) + var eventArgs = (EventArgs)JsonSerializer.Deserialize(result.JsonUtf8Bytes.Span, result.TypeInfo)!; + eventArgs.BiDi = result.BiDi; + + foreach (var handler in registration.Handlers.ToArray()) // copy handlers avoiding modified collection while iterating { await handler.InvokeAsync(eventArgs).ConfigureAwait(false); } - catch (Exception ex) - { - if (_logger.IsEnabled(LogEventLevel.Error)) - { - _logger.Error($"Unhandled error processing BiDi event handler: {ex}"); - } - } } } catch (Exception ex) { if (_logger.IsEnabled(LogEventLevel.Error)) { - _logger.Error($"Unhandled error deserializing BiDi event: {ex}"); + _logger.Error($"Unhandled error processing BiDi event handler: {ex}"); } } - finally - { - evt.Registration.IncrementProcessed(); - } } } } @@ -158,88 +122,16 @@ public async ValueTask DisposeAsync() { _pendingEvents.Writer.Complete(); - await _processEventsTask.ConfigureAwait(false); + await _eventEmitterTask.ConfigureAwait(false); GC.SuppressFinalize(this); } - private sealed record EventItem(ReadOnlyMemory JsonUtf8Bytes, IBiDi BiDi, EventRegistration Registration); + private readonly record struct PendingEvent(string Method, ReadOnlyMemory JsonUtf8Bytes, IBiDi BiDi, JsonTypeInfo TypeInfo); private sealed class EventRegistration(JsonTypeInfo typeInfo) { - private long _enqueueSeq; - private long _processedSeq; - private readonly object _drainLock = new(); - private readonly List _handlers = []; - private List<(long TargetSeq, TaskCompletionSource Tcs)>? _drainWaiters; - public JsonTypeInfo TypeInfo { get; } = typeInfo; - - public void AddHandler(EventHandler handler) - { - lock (_drainLock) _handlers.Add(handler); - } - - public void RemoveHandler(EventHandler handler) - { - lock (_drainLock) _handlers.Remove(handler); - } - - public EventHandler[] GetHandlersSnapshot() - { - lock (_drainLock) return [.. _handlers]; - } - - public void IncrementEnqueued() => Interlocked.Increment(ref _enqueueSeq); - - public void IncrementProcessed() - { - var processed = Interlocked.Increment(ref _processedSeq); - - lock (_drainLock) - { - if (_drainWaiters is null) return; - - for (var i = _drainWaiters.Count - 1; i >= 0; i--) - { - if (_drainWaiters[i].TargetSeq <= processed) - { - _drainWaiters[i].Tcs.TrySetResult(true); - _drainWaiters.RemoveAt(i); - } - } - - if (_drainWaiters.Count == 0) _drainWaiters = null; - } - } - - public Task DrainAsync(CancellationToken cancellationToken) - { - lock (_drainLock) - { - var target = Volatile.Read(ref _enqueueSeq); - if (Volatile.Read(ref _processedSeq) >= target) return Task.CompletedTask; - - var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - _drainWaiters ??= []; - _drainWaiters.Add((target, tcs)); - - // Double-check: processing may have caught up between the read and adding the waiter - if (Volatile.Read(ref _processedSeq) >= target) - { - _drainWaiters.Remove((target, tcs)); - if (_drainWaiters.Count == 0) _drainWaiters = null; - return Task.CompletedTask; - } - - if (!cancellationToken.CanBeCanceled) return tcs.Task; - - return tcs.Task.ContinueWith( - static _ => { }, - cancellationToken, - TaskContinuationOptions.None, - TaskScheduler.Default); - } - } + public List Handlers { get; } = []; } } From fed48a3db891eb81580c53fcc0d8628835309593 Mon Sep 17 00:00:00 2001 From: Nikolay Borisenko <22616990+nvborisenko@users.noreply.github.com> Date: Thu, 5 Mar 2026 19:44:57 +0300 Subject: [PATCH 47/67] [dotnet] Use props file for legacy sdk (#17180) --- dotnet/src/webdriver/BUILD.bazel | 20 +++++++++---------- .../Selenium.WebDriver.StrongNamed.nuspec | 4 ++-- .../src/webdriver/Selenium.WebDriver.csproj | 10 +++++----- .../src/webdriver/Selenium.WebDriver.nuspec | 4 ++-- dotnet/src/webdriver/assets/BUILD.bazel | 4 ++-- ...river.targets => Selenium.WebDriver.props} | 0 ...river.targets => Selenium.WebDriver.props} | 2 +- 7 files changed, 22 insertions(+), 22 deletions(-) rename dotnet/src/webdriver/assets/nuget/build/{Selenium.WebDriver.targets => Selenium.WebDriver.props} (100%) rename dotnet/src/webdriver/assets/nuget/buildTransitive/{Selenium.WebDriver.targets => Selenium.WebDriver.props} (89%) diff --git a/dotnet/src/webdriver/BUILD.bazel b/dotnet/src/webdriver/BUILD.bazel index 9fd2588951e96..79e2547b1feca 100644 --- a/dotnet/src/webdriver/BUILD.bazel +++ b/dotnet/src/webdriver/BUILD.bazel @@ -228,15 +228,15 @@ copy_file( ) copy_file( - name = "assets-nuget-build-targets", - src = "//dotnet/src/webdriver/assets:nuget/build/Selenium.WebDriver.targets", - out = "Selenium.WebDriver.targets", + name = "assets-nuget-build-props", + src = "//dotnet/src/webdriver/assets:nuget/build/Selenium.WebDriver.props", + out = "Selenium.WebDriver.props", ) copy_file( - name = "assets-nuget-buildtransitive-targets", - src = "//dotnet/src/webdriver/assets:nuget/buildTransitive/Selenium.WebDriver.targets", - out = "transitiveSelenium.WebDriver.targets", + name = "assets-nuget-buildtransitive-props", + src = "//dotnet/src/webdriver/assets:nuget/buildTransitive/Selenium.WebDriver.props", + out = "transitiveSelenium.WebDriver.props", ) copy_file( @@ -268,8 +268,8 @@ nuget_pack( "//common/manager:selenium-manager-macos": "manager/macos/selenium-manager", "//common/manager:selenium-manager-windows": "manager/windows/selenium-manager.exe", ":assets-nuget-readme": "README.md", - ":assets-nuget-build-targets": "build/Selenium.WebDriver.targets", - ":assets-nuget-buildtransitive-targets": "buildTransitive/Selenium.WebDriver.targets", + ":assets-nuget-build-props": "build/Selenium.WebDriver.props", + ":assets-nuget-buildtransitive-props": "buildTransitive/Selenium.WebDriver.props", }, id = "Selenium.WebDriver", libs = { @@ -295,8 +295,8 @@ nuget_pack( "//common/manager:selenium-manager-macos": "manager/macos/selenium-manager", "//common/manager:selenium-manager-windows": "manager/windows/selenium-manager.exe", ":assets-nuget-readme": "README.md", - ":assets-nuget-build-targets": "build/Selenium.WebDriver.StrongNamed.targets", - ":assets-nuget-buildtransitive-targets": "buildTransitive/Selenium.WebDriver.StrongNamed.targets", + ":assets-nuget-build-props": "build/Selenium.WebDriver.StrongNamed.props", + ":assets-nuget-buildtransitive-props": "buildTransitive/Selenium.WebDriver.StrongNamed.props", }, id = "Selenium.WebDriver.StrongNamed", libs = { diff --git a/dotnet/src/webdriver/Selenium.WebDriver.StrongNamed.nuspec b/dotnet/src/webdriver/Selenium.WebDriver.StrongNamed.nuspec index 70da075ee5ec2..c137f5650e644 100644 --- a/dotnet/src/webdriver/Selenium.WebDriver.StrongNamed.nuspec +++ b/dotnet/src/webdriver/Selenium.WebDriver.StrongNamed.nuspec @@ -54,8 +54,8 @@ - - + + diff --git a/dotnet/src/webdriver/Selenium.WebDriver.csproj b/dotnet/src/webdriver/Selenium.WebDriver.csproj index 9fdcb02d24169..a9d7914b0db3f 100644 --- a/dotnet/src/webdriver/Selenium.WebDriver.csproj +++ b/dotnet/src/webdriver/Selenium.WebDriver.csproj @@ -61,11 +61,6 @@ - - - - - @@ -85,6 +80,11 @@ + + + + + diff --git a/dotnet/src/webdriver/Selenium.WebDriver.nuspec b/dotnet/src/webdriver/Selenium.WebDriver.nuspec index 6e45a4fea9caa..bcd160d4678be 100644 --- a/dotnet/src/webdriver/Selenium.WebDriver.nuspec +++ b/dotnet/src/webdriver/Selenium.WebDriver.nuspec @@ -54,8 +54,8 @@ - - + + diff --git a/dotnet/src/webdriver/assets/BUILD.bazel b/dotnet/src/webdriver/assets/BUILD.bazel index a72d6f97da643..e8f0b7e6994da 100644 --- a/dotnet/src/webdriver/assets/BUILD.bazel +++ b/dotnet/src/webdriver/assets/BUILD.bazel @@ -1,5 +1,5 @@ exports_files([ "nuget/README.md", - "nuget/build/Selenium.WebDriver.targets", - "nuget/buildTransitive/Selenium.WebDriver.targets", + "nuget/build/Selenium.WebDriver.props", + "nuget/buildTransitive/Selenium.WebDriver.props", ]) diff --git a/dotnet/src/webdriver/assets/nuget/build/Selenium.WebDriver.targets b/dotnet/src/webdriver/assets/nuget/build/Selenium.WebDriver.props similarity index 100% rename from dotnet/src/webdriver/assets/nuget/build/Selenium.WebDriver.targets rename to dotnet/src/webdriver/assets/nuget/build/Selenium.WebDriver.props diff --git a/dotnet/src/webdriver/assets/nuget/buildTransitive/Selenium.WebDriver.targets b/dotnet/src/webdriver/assets/nuget/buildTransitive/Selenium.WebDriver.props similarity index 89% rename from dotnet/src/webdriver/assets/nuget/buildTransitive/Selenium.WebDriver.targets rename to dotnet/src/webdriver/assets/nuget/buildTransitive/Selenium.WebDriver.props index 877a684d233d1..21f87ccc264a2 100644 --- a/dotnet/src/webdriver/assets/nuget/buildTransitive/Selenium.WebDriver.targets +++ b/dotnet/src/webdriver/assets/nuget/buildTransitive/Selenium.WebDriver.props @@ -1,4 +1,4 @@ - + From e3393ed203652a7dac176d4dff978943de89e955 Mon Sep 17 00:00:00 2001 From: Nikolay Borisenko <22616990+nvborisenko@users.noreply.github.com> Date: Thu, 5 Mar 2026 19:56:22 +0300 Subject: [PATCH 48/67] [bidi] Convert RemoteValue to IDictionary (#17181) --- .../src/webdriver/BiDi/Script/RemoteValue.cs | 30 ++++++++ .../BiDi/Script/RemoteValueConversionTests.cs | 72 +++++++++++++++++++ 2 files changed, 102 insertions(+) diff --git a/dotnet/src/webdriver/BiDi/Script/RemoteValue.cs b/dotnet/src/webdriver/BiDi/Script/RemoteValue.cs index 9f7698575506a..a567009a5895a 100644 --- a/dotnet/src/webdriver/BiDi/Script/RemoteValue.cs +++ b/dotnet/src/webdriver/BiDi/Script/RemoteValue.cs @@ -98,6 +98,10 @@ public abstract record RemoteValue => ConvertRemoteValuesToArray(a.Value, t.GetElementType()!), (ArrayRemoteValue a, Type t) when t.IsGenericType && t.IsAssignableFrom(typeof(List<>).MakeGenericType(t.GetGenericArguments()[0])) => ConvertRemoteValuesToGenericList(a.Value, typeof(List<>).MakeGenericType(t.GetGenericArguments()[0])), + (MapRemoteValue m, Type t) when t.IsGenericType && t.GetGenericArguments().Length == 2 && t.IsAssignableFrom(typeof(Dictionary<,>).MakeGenericType(t.GetGenericArguments())) + => ConvertRemoteValuesToDictionary(m.Value, typeof(Dictionary<,>).MakeGenericType(t.GetGenericArguments())), + (ObjectRemoteValue o, Type t) when t.IsGenericType && t.GetGenericArguments().Length == 2 && t.IsAssignableFrom(typeof(Dictionary<,>).MakeGenericType(t.GetGenericArguments())) + => ConvertRemoteValuesToDictionary(o.Value, typeof(Dictionary<,>).MakeGenericType(t.GetGenericArguments())), (_, Type t) when Nullable.GetUnderlyingType(t) is { } underlying => ConvertToNullable(underlying), @@ -150,6 +154,32 @@ private static TResult ConvertRemoteValuesToGenericList(IEnumerable(IReadOnlyList>? remoteValues, Type dictionaryType) + { + var typeArgs = dictionaryType.GetGenericArguments(); + var dict = (System.Collections.IDictionary)Activator.CreateInstance(dictionaryType)!; + + if (remoteValues is not null) + { + var convertKeyMethod = typeof(RemoteValue).GetMethod(nameof(ConvertTo))!.MakeGenericMethod(typeArgs[0]); + var convertValueMethod = typeof(RemoteValue).GetMethod(nameof(ConvertTo))!.MakeGenericMethod(typeArgs[1]); + + foreach (var pair in remoteValues) + { + if (pair.Count != 2) + { + throw new FormatException($"Expected a pair of RemoteValues for dictionary entry, but got {pair.Count} values."); + } + + var convertedKey = convertKeyMethod.Invoke(pair[0], null)!; + var convertedValue = convertValueMethod.Invoke(pair[1], null); + dict.Add(convertedKey, convertedValue); + } + } + + return (TResult)dict; + } } public abstract record PrimitiveProtocolRemoteValue : RemoteValue; diff --git a/dotnet/test/common/BiDi/Script/RemoteValueConversionTests.cs b/dotnet/test/common/BiDi/Script/RemoteValueConversionTests.cs index b453272371a51..ddc04be9af2f9 100644 --- a/dotnet/test/common/BiDi/Script/RemoteValueConversionTests.cs +++ b/dotnet/test/common/BiDi/Script/RemoteValueConversionTests.cs @@ -286,4 +286,76 @@ static void AssertValue(IEnumerable value) Assert.That(value, Is.Empty); } } + + [Test] + public void CanConvertMapRemoteValueToDictionary() + { + MapRemoteValue arg = new() + { + Value = + [ + [new StringRemoteValue("key1"), new NumberRemoteValue(1)], + [new StringRemoteValue("key2"), new NumberRemoteValue(2)], + ] + }; + + AssertValue(arg.ConvertTo>()); + AssertValue(arg.ConvertTo>()); + + static void AssertValue(IDictionary value) + { + Assert.That(value, Has.Count.EqualTo(2)); + Assert.That(value["key1"], Is.EqualTo(1)); + Assert.That(value["key2"], Is.EqualTo(2)); + } + } + + [Test] + public void CanConvertEmptyMapRemoteValueToDictionary() + { + MapRemoteValue arg = new(); + + AssertValue(arg.ConvertTo>()); + + static void AssertValue(IDictionary value) + { + Assert.That(value, Is.Empty); + } + } + + [Test] + public void CanConvertObjectRemoteValueToDictionary() + { + ObjectRemoteValue arg = new() + { + Value = + [ + [new StringRemoteValue("a"), new BooleanRemoteValue(true)], + [new StringRemoteValue("b"), new BooleanRemoteValue(false)], + ] + }; + + AssertValue(arg.ConvertTo>()); + AssertValue(arg.ConvertTo>()); + + static void AssertValue(IDictionary value) + { + Assert.That(value, Has.Count.EqualTo(2)); + Assert.That(value["a"], Is.True); + Assert.That(value["b"], Is.False); + } + } + + [Test] + public void CanConvertEmptyObjectRemoteValueToDictionary() + { + ObjectRemoteValue arg = new(); + + AssertValue(arg.ConvertTo>()); + + static void AssertValue(IDictionary value) + { + Assert.That(value, Is.Empty); + } + } } From ada4964adbac5023cefe95017a6ab8b449d06430 Mon Sep 17 00:00:00 2001 From: Andrei Solntsev Date: Fri, 6 Mar 2026 11:09:58 +0200 Subject: [PATCH 49/67] [java] make the signature change in `ExecuteMethod` backward compatible (#17183) * make the signature change in `ExecuteMethod` backward compatible ... to avoid too many changes e.g. in Appium project (its has own implementation `AppiumExecutionMethod`). In PR #17152, I changed signature of `ExecuteMethod.execute()` - and later realized that it breaks backward compatibility in Appium. * rename the additional helper methods `ExecuteMethod` * rename the additional helper methods `ExecuteMethod` --- .../selenium/chromium/AddHasCasting.java | 5 ++- .../openqa/selenium/chromium/AddHasCdp.java | 3 +- .../chromium/AddHasNetworkConditions.java | 3 +- .../selenium/firefox/AddHasContext.java | 2 +- .../selenium/firefox/AddHasExtensions.java | 2 +- .../openqa/selenium/remote/ExecuteMethod.java | 33 +++++++++++++++++-- .../selenium/remote/FedCmDialogImpl.java | 9 +++-- .../selenium/remote/LocalExecuteMethod.java | 2 +- .../selenium/remote/RemoteExecuteMethod.java | 5 ++- .../openqa/selenium/remote/RemoteLogs.java | 3 +- .../selenium/safari/AddHasPermissions.java | 2 +- .../selenium/remote/RemoteLogsTest.java | 2 +- 12 files changed, 46 insertions(+), 25 deletions(-) diff --git a/java/src/org/openqa/selenium/chromium/AddHasCasting.java b/java/src/org/openqa/selenium/chromium/AddHasCasting.java index b52d55e50f0dc..7cb789ecf9b8a 100644 --- a/java/src/org/openqa/selenium/chromium/AddHasCasting.java +++ b/java/src/org/openqa/selenium/chromium/AddHasCasting.java @@ -18,7 +18,6 @@ package org.openqa.selenium.chromium; import static java.util.Collections.emptyList; -import static java.util.Objects.requireNonNullElse; import java.util.List; import java.util.Map; @@ -56,7 +55,7 @@ public HasCasting getImplementation(Capabilities capabilities, ExecuteMethod exe return new HasCasting() { @Override public List> getCastSinks() { - return requireNonNullElse(executeMethod.execute(GET_CAST_SINKS, null), emptyList()); + return executeMethod.execute(GET_CAST_SINKS, null, emptyList()); } @Override @@ -82,7 +81,7 @@ public void startTabMirroring(String deviceName) { @Override public String getCastIssueMessage() { - return executeMethod.executeRequired(GET_CAST_ISSUE_MESSAGE, null).toString(); + return executeMethod.execute(GET_CAST_ISSUE_MESSAGE).toString(); } @Override diff --git a/java/src/org/openqa/selenium/chromium/AddHasCdp.java b/java/src/org/openqa/selenium/chromium/AddHasCdp.java index e35998388ffb1..08e8225a1ec1b 100644 --- a/java/src/org/openqa/selenium/chromium/AddHasCdp.java +++ b/java/src/org/openqa/selenium/chromium/AddHasCdp.java @@ -51,8 +51,7 @@ public HasCdp getImplementation(Capabilities capabilities, ExecuteMethod execute Require.nonNull("Command name", commandName); Require.nonNull("Parameters", parameters); - return executeMethod.executeRequired( - EXECUTE_CDP, Map.of("cmd", commandName, "params", parameters)); + return executeMethod.executeAs(EXECUTE_CDP, Map.of("cmd", commandName, "params", parameters)); }; } } diff --git a/java/src/org/openqa/selenium/chromium/AddHasNetworkConditions.java b/java/src/org/openqa/selenium/chromium/AddHasNetworkConditions.java index 3a045c37d8170..d7774d040297e 100644 --- a/java/src/org/openqa/selenium/chromium/AddHasNetworkConditions.java +++ b/java/src/org/openqa/selenium/chromium/AddHasNetworkConditions.java @@ -75,8 +75,7 @@ public HasNetworkConditions getImplementation( return new HasNetworkConditions() { @Override public ChromiumNetworkConditions getNetworkConditions() { - @SuppressWarnings("unchecked") - Map result = executeMethod.executeRequired(GET_NETWORK_CONDITIONS, null); + Map result = executeMethod.execute(GET_NETWORK_CONDITIONS); return new ChromiumNetworkConditions() .setOffline((Boolean) result.getOrDefault(OFFLINE, false)) .setLatency(Duration.ofMillis((Long) result.getOrDefault(LATENCY, 0))) diff --git a/java/src/org/openqa/selenium/firefox/AddHasContext.java b/java/src/org/openqa/selenium/firefox/AddHasContext.java index 4369ab6cda1d2..bddb8de9e0058 100644 --- a/java/src/org/openqa/selenium/firefox/AddHasContext.java +++ b/java/src/org/openqa/selenium/firefox/AddHasContext.java @@ -69,7 +69,7 @@ public void setContext(FirefoxCommandContext context) { @Override public FirefoxCommandContext getContext() { - String context = executeMethod.executeRequired(GET_CONTEXT, null); + String context = executeMethod.execute(GET_CONTEXT); return FirefoxCommandContext.fromString(context); } }; diff --git a/java/src/org/openqa/selenium/firefox/AddHasExtensions.java b/java/src/org/openqa/selenium/firefox/AddHasExtensions.java index 671199bd2e494..32726c6375bb6 100644 --- a/java/src/org/openqa/selenium/firefox/AddHasExtensions.java +++ b/java/src/org/openqa/selenium/firefox/AddHasExtensions.java @@ -95,7 +95,7 @@ public String installExtension(Path path, Boolean temporary) { throw new InvalidArgumentException(path + " is an invalid path", e); } - return executeMethod.executeRequired( + return executeMethod.executeAs( INSTALL_EXTENSION, Map.of("addon", encoded, "temporary", temporary)); } diff --git a/java/src/org/openqa/selenium/remote/ExecuteMethod.java b/java/src/org/openqa/selenium/remote/ExecuteMethod.java index e348e5daf4604..b65238c0d4262 100644 --- a/java/src/org/openqa/selenium/remote/ExecuteMethod.java +++ b/java/src/org/openqa/selenium/remote/ExecuteMethod.java @@ -18,6 +18,7 @@ package org.openqa.selenium.remote; import static java.util.Objects.requireNonNull; +import static java.util.Objects.requireNonNullElse; import java.util.Map; import org.jspecify.annotations.NullMarked; @@ -37,9 +38,35 @@ public interface ExecuteMethod { * @param parameters The parameters to execute that command with * @return The result of {@link Response#getValue()}. */ - @Nullable T execute(String commandName, @Nullable Map parameters); + @Nullable Object execute(String commandName, @Nullable Map parameters); - default T executeRequired(String commandName, @Nullable Map parameters) { - return requireNonNull(execute(commandName, parameters)); + /** + * Execute the given command and return the default value if the command return null. + * + * @return non-nullable value of type T. + */ + @SuppressWarnings("unchecked") + default T execute(String commandName, @Nullable Map parameters, T defaultValue) { + return (T) requireNonNullElse(execute(commandName, parameters), defaultValue); + } + + /** + * Execute the given command and cast the returned value to T. + * + * @return non-nullable value of type T. + */ + @SuppressWarnings("unchecked") + default T executeAs(String commandName, @Nullable Map parameters) { + return (T) requireNonNull(execute(commandName, parameters)); + } + + /** + * Execute the given command without parameters and cast the returned value to T. + * + * @return non-nullable value of type T. + */ + @SuppressWarnings("unchecked") + default T execute(String commandName) { + return (T) requireNonNull(execute(commandName, null)); } } diff --git a/java/src/org/openqa/selenium/remote/FedCmDialogImpl.java b/java/src/org/openqa/selenium/remote/FedCmDialogImpl.java index ed50cfab338fe..19d0ee4295986 100644 --- a/java/src/org/openqa/selenium/remote/FedCmDialogImpl.java +++ b/java/src/org/openqa/selenium/remote/FedCmDialogImpl.java @@ -46,7 +46,7 @@ public void selectAccount(int index) { @Nullable @Override public String getDialogType() { - return executeMethod.execute(DriverCommand.GET_FEDCM_DIALOG_TYPE, null); + return (String) executeMethod.execute(DriverCommand.GET_FEDCM_DIALOG_TYPE, null); } @Override @@ -58,21 +58,20 @@ public void clickDialog() { @Nullable @Override public String getTitle() { - Map result = executeMethod.executeRequired(DriverCommand.GET_FEDCM_TITLE, null); + Map result = executeMethod.execute(DriverCommand.GET_FEDCM_TITLE); return result.get("title"); } @Nullable @Override public String getSubtitle() { - Map result = executeMethod.executeRequired(DriverCommand.GET_FEDCM_TITLE, null); + Map result = executeMethod.execute(DriverCommand.GET_FEDCM_TITLE); return result.get("subtitle"); } @Override public List getAccounts() { - List> accounts = - executeMethod.executeRequired(DriverCommand.GET_ACCOUNTS, null); + List> accounts = executeMethod.execute(DriverCommand.GET_ACCOUNTS); return accounts.stream() .map(map -> new FederatedCredentialManagementAccount(map)) diff --git a/java/src/org/openqa/selenium/remote/LocalExecuteMethod.java b/java/src/org/openqa/selenium/remote/LocalExecuteMethod.java index 0d2b4e4cb5c36..84a6093dc8928 100644 --- a/java/src/org/openqa/selenium/remote/LocalExecuteMethod.java +++ b/java/src/org/openqa/selenium/remote/LocalExecuteMethod.java @@ -26,7 +26,7 @@ class LocalExecuteMethod implements ExecuteMethod { @Nullable @Override - public T execute(String commandName, @Nullable Map parameters) { + public Object execute(String commandName, @Nullable Map parameters) { throw new WebDriverException("Cannot execute remote command: " + commandName); } } diff --git a/java/src/org/openqa/selenium/remote/RemoteExecuteMethod.java b/java/src/org/openqa/selenium/remote/RemoteExecuteMethod.java index 8f28193c5718e..9be7a20bfd2be 100644 --- a/java/src/org/openqa/selenium/remote/RemoteExecuteMethod.java +++ b/java/src/org/openqa/selenium/remote/RemoteExecuteMethod.java @@ -31,9 +31,8 @@ public RemoteExecuteMethod(RemoteWebDriver driver) { this.driver = Require.nonNull("Remote WebDriver", driver); } - @SuppressWarnings("unchecked") @Override - public @Nullable T execute(String commandName, @Nullable Map parameters) { + public @Nullable Object execute(String commandName, @Nullable Map parameters) { Response response; if (parameters == null || parameters.isEmpty()) { @@ -42,7 +41,7 @@ public RemoteExecuteMethod(RemoteWebDriver driver) { response = driver.execute(commandName, parameters); } - return (T) response.getValue(); + return response.getValue(); } @Override diff --git a/java/src/org/openqa/selenium/remote/RemoteLogs.java b/java/src/org/openqa/selenium/remote/RemoteLogs.java index 7d0f230d6b324..267558e892e14 100644 --- a/java/src/org/openqa/selenium/remote/RemoteLogs.java +++ b/java/src/org/openqa/selenium/remote/RemoteLogs.java @@ -149,8 +149,7 @@ private Set getAvailableLocalLogs() { @Override public Set getAvailableLogTypes() { - List rawList = - executeMethod.executeRequired(DriverCommand.GET_AVAILABLE_LOG_TYPES, null); + List rawList = executeMethod.execute(DriverCommand.GET_AVAILABLE_LOG_TYPES); Set builder = new LinkedHashSet<>(); builder.addAll(rawList); builder.addAll(getAvailableLocalLogs()); diff --git a/java/src/org/openqa/selenium/safari/AddHasPermissions.java b/java/src/org/openqa/selenium/safari/AddHasPermissions.java index 6aa4572387a89..f9edf3496eeea 100644 --- a/java/src/org/openqa/selenium/safari/AddHasPermissions.java +++ b/java/src/org/openqa/selenium/safari/AddHasPermissions.java @@ -67,7 +67,7 @@ public void setPermissions(String permission, boolean value) { @Override public Map getPermissions() { - Map resultMap = executeMethod.executeRequired(GET_PERMISSIONS, null); + Map resultMap = executeMethod.execute(GET_PERMISSIONS); Map permissionMap = new HashMap<>(); for (Map.Entry entry : resultMap.entrySet()) { diff --git a/java/test/org/openqa/selenium/remote/RemoteLogsTest.java b/java/test/org/openqa/selenium/remote/RemoteLogsTest.java index e9044760a0e93..39631bbabd5c0 100644 --- a/java/test/org/openqa/selenium/remote/RemoteLogsTest.java +++ b/java/test/org/openqa/selenium/remote/RemoteLogsTest.java @@ -133,7 +133,7 @@ void throwsOnBogusRemoteLogsResponse() { @Test void canGetAvailableLogTypes() { List remoteAvailableLogTypes = List.of(LogType.PROFILER, LogType.SERVER); - when(executeMethod.executeRequired(DriverCommand.GET_AVAILABLE_LOG_TYPES, null)) + when(executeMethod.execute(DriverCommand.GET_AVAILABLE_LOG_TYPES)) .thenReturn(remoteAvailableLogTypes); Set localAvailableLogTypes = Set.of(LogType.PROFILER, LogType.CLIENT); From 1708745f9ef0bf020422fa5d73c2b58eb7a20148 Mon Sep 17 00:00:00 2001 From: Selenium CI Bot Date: Sat, 7 Mar 2026 01:48:10 +0100 Subject: [PATCH 50/67] [dotnet][rb][java][js][py] Automated Browser Version Update (#17182) Update pinned browser versions Co-authored-by: Selenium CI Bot --- common/repositories.bzl | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/common/repositories.bzl b/common/repositories.bzl index d659c8af5d1a5..398a0ae838826 100644 --- a/common/repositories.bzl +++ b/common/repositories.bzl @@ -50,8 +50,8 @@ js_library( http_archive( name = "linux_beta_firefox", - url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b4/linux-x86_64/en-US/firefox-149.0b4.tar.xz", - sha256 = "57539691f99f49124487846e1ed30b2c3c1e5413a68a4838644d9f35fcd22ec6", + url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b5/linux-x86_64/en-US/firefox-149.0b5.tar.xz", + sha256 = "1c9c3c4955c542815549f3c2737a54c4adba74cc949e07441fede36d02f1a536", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -72,8 +72,8 @@ js_library( dmg_archive( name = "mac_beta_firefox", - url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b4/mac/en-US/Firefox%20149.0b4.dmg", - sha256 = "69b90c44f1757f0ba0799e82767e38bac8cad4762de76113a14237190f9c4212", + url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b5/mac/en-US/Firefox%20149.0b5.dmg", + sha256 = "7bca909999df7ce1a6d6b8a1ebcf121f9f85a2f20dc0205d70362b9de813df51", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -123,10 +123,10 @@ js_library( pkg_archive( name = "mac_edge", - url = "https://msedge.sf.dl.delivery.mp.microsoft.com/filestreamingservice/files/42fe1f19-ccfd-476b-835f-4dc020005cd9/MicrosoftEdge-145.0.3800.82.pkg", - sha256 = "d046111ab7cbbf56a78588afae17f21cc1343a55c9364268cbd98868f6428979", + url = "https://msedge.sf.dl.delivery.mp.microsoft.com/filestreamingservice/files/99e9a53e-4987-4ced-b408-2832d8a3e50a/MicrosoftEdge-145.0.3800.97.pkg", + sha256 = "462c1ee28df2815fe2902f32e9846029b7c57d207affb973be4d76c8db1e9eab", move = { - "MicrosoftEdge-145.0.3800.82.pkg/Payload/Microsoft Edge.app": "Edge.app", + "MicrosoftEdge-145.0.3800.97.pkg/Payload/Microsoft Edge.app": "Edge.app", }, build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") @@ -143,8 +143,8 @@ js_library( deb_archive( name = "linux_edge", - url = "https://packages.microsoft.com/repos/edge/pool/main/m/microsoft-edge-stable/microsoft-edge-stable_145.0.3800.82-1_amd64.deb", - sha256 = "29e8d282072cf3c798e524cd1bd17b4ee79982f0679b17aca0d6ed50d8d2adf9", + url = "https://packages.microsoft.com/repos/edge/pool/main/m/microsoft-edge-stable/microsoft-edge-stable_145.0.3800.97-1_amd64.deb", + sha256 = "66287f30e884d40a6d6413ad47957b211779d09568933cd741bd55ea950ad152", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -277,8 +277,8 @@ js_library( http_archive( name = "linux_beta_chrome", - url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.31/linux64/chrome-linux64.zip", - sha256 = "5b1961b081f0156a1923a9d9d1bfffdf00f82e8722152c35eb5eb742d63ceeb8", + url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.66/linux64/chrome-linux64.zip", + sha256 = "85361fe4a850e6fd68487afd277b48c5579e7f2e08f803783c26f9465793e743", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -298,8 +298,8 @@ js_library( ) http_archive( name = "mac_beta_chrome", - url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.31/mac-arm64/chrome-mac-arm64.zip", - sha256 = "207867110edc624316b18684065df4eb06b938a3fd9141790a726ab280e2640f", + url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.66/mac-arm64/chrome-mac-arm64.zip", + sha256 = "427685bb078198854606f25f7b0b35f52ae4e933fd482632882187c14341b031", strip_prefix = "chrome-mac-arm64", patch_cmds = [ "mv 'Google Chrome for Testing.app' Chrome.app", @@ -319,8 +319,8 @@ js_library( ) http_archive( name = "linux_beta_chromedriver", - url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.31/linux64/chromedriver-linux64.zip", - sha256 = "a8c7be8669829ed697759390c8c42b4bca3f884fd20980e078129f5282dabe1a", + url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.66/linux64/chromedriver-linux64.zip", + sha256 = "90fa704ff45f6512273fc8a9b8da3344665922f6555716dde341817264e77388", strip_prefix = "chromedriver-linux64", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") @@ -337,8 +337,8 @@ js_library( http_archive( name = "mac_beta_chromedriver", - url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.31/mac-arm64/chromedriver-mac-arm64.zip", - sha256 = "84c3717c0eeba663d0b8890a0fc06faa6fe158227876fc6954461730ccc81634", + url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.66/mac-arm64/chromedriver-mac-arm64.zip", + sha256 = "36a89f18770f7422bba363bc21e2ed6da7db2dd479d816da307ff338e01a6612", strip_prefix = "chromedriver-mac-arm64", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") From 4227c4e9b6b6d72387203124dc55c3c25c8b1f48 Mon Sep 17 00:00:00 2001 From: Selenium CI Bot Date: Sun, 8 Mar 2026 01:52:05 +0100 Subject: [PATCH 51/67] [dotnet][rb][java][js][py] Automated Browser Version Update (#17189) Update pinned browser versions Co-authored-by: Selenium CI Bot --- common/repositories.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/repositories.bzl b/common/repositories.bzl index 398a0ae838826..b625b4d4659f1 100644 --- a/common/repositories.bzl +++ b/common/repositories.bzl @@ -165,8 +165,8 @@ js_library( http_archive( name = "linux_edgedriver", - url = "https://msedgedriver.microsoft.com/145.0.3800.82/edgedriver_linux64.zip", - sha256 = "e6fa668cb036938d56a06519b19923ac8e547c3ea4c0d4fdf7a75076b3e1b31a", + url = "https://msedgedriver.microsoft.com/145.0.3800.97/edgedriver_linux64.zip", + sha256 = "fc29a432716ac44d98cda4cb28fb8ee14a87340f95456c3fcea329dc34d6cd1a", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) From e8f653d12febd4a6019d35220ebd2874bdf2f5ee Mon Sep 17 00:00:00 2001 From: Andrei Solntsev Date: Sun, 8 Mar 2026 18:01:08 +0200 Subject: [PATCH 52/67] [java] specify nullability in package `org.openqa.selenium.devtools` (#17185) specify nullability in package `org.openqa.selenium.devtools` Partially implements #14291 --- .../selenium/devtools/CdpClientGenerator.java | 25 +++++--- .../org/openqa/selenium/devtools/Command.java | 6 +- .../openqa/selenium/devtools/Connection.java | 11 ++-- .../selenium/devtools/ConverterFunctions.java | 2 +- .../openqa/selenium/devtools/DevTools.java | 15 +++-- .../selenium/devtools/DevToolsException.java | 12 ++-- .../devtools/RequestFailedException.java | 2 - .../devtools/events/CdpEventTypes.java | 3 +- .../devtools/events/package-info.java | 21 +++++++ .../selenium/devtools/idealized/Network.java | 11 ++-- .../idealized/browser/model/package-info.java | 21 +++++++ .../idealized/log/model/package-info.java | 21 +++++++ .../devtools/idealized/log/package-info.java | 21 +++++++ .../devtools/idealized/package-info.java | 21 +++++++ .../idealized/runtime/model/RemoteObject.java | 7 ++- .../idealized/runtime/model/package-info.java | 21 +++++++ .../idealized/target/model/TargetInfo.java | 12 ++-- .../idealized/target/model/package-info.java | 21 +++++++ .../idealized/target/package-info.java | 21 +++++++ .../selenium/devtools/noop/package-info.java | 21 +++++++ .../openqa/selenium/devtools/v143/BUILD.bazel | 3 + .../selenium/devtools/v143/package-info.java | 21 +++++++ .../openqa/selenium/devtools/v144/BUILD.bazel | 3 + .../selenium/devtools/v144/package-info.java | 21 +++++++ .../openqa/selenium/devtools/v145/BUILD.bazel | 3 + .../selenium/devtools/v145/package-info.java | 21 +++++++ .../org/openqa/selenium/devtools/BUILD.bazel | 3 + .../runtime/model/RemoteObjectTest.java | 57 +++++++++++++++++++ .../idealized/runtime/model/package-info.java | 21 +++++++ 29 files changed, 404 insertions(+), 44 deletions(-) create mode 100644 java/src/org/openqa/selenium/devtools/events/package-info.java create mode 100644 java/src/org/openqa/selenium/devtools/idealized/browser/model/package-info.java create mode 100644 java/src/org/openqa/selenium/devtools/idealized/log/model/package-info.java create mode 100644 java/src/org/openqa/selenium/devtools/idealized/log/package-info.java create mode 100644 java/src/org/openqa/selenium/devtools/idealized/package-info.java create mode 100644 java/src/org/openqa/selenium/devtools/idealized/runtime/model/package-info.java create mode 100644 java/src/org/openqa/selenium/devtools/idealized/target/model/package-info.java create mode 100644 java/src/org/openqa/selenium/devtools/idealized/target/package-info.java create mode 100644 java/src/org/openqa/selenium/devtools/noop/package-info.java create mode 100644 java/src/org/openqa/selenium/devtools/v143/package-info.java create mode 100644 java/src/org/openqa/selenium/devtools/v144/package-info.java create mode 100644 java/src/org/openqa/selenium/devtools/v145/package-info.java create mode 100644 java/test/org/openqa/selenium/devtools/idealized/runtime/model/RemoteObjectTest.java create mode 100644 java/test/org/openqa/selenium/devtools/idealized/runtime/model/package-info.java diff --git a/java/src/org/openqa/selenium/devtools/CdpClientGenerator.java b/java/src/org/openqa/selenium/devtools/CdpClientGenerator.java index f38869992733a..1d34dd66a2816 100644 --- a/java/src/org/openqa/selenium/devtools/CdpClientGenerator.java +++ b/java/src/org/openqa/selenium/devtools/CdpClientGenerator.java @@ -17,6 +17,7 @@ package org.openqa.selenium.devtools; +import static java.nio.charset.StandardCharsets.UTF_8; import static java.nio.file.FileVisitResult.CONTINUE; import static java.util.Collections.unmodifiableMap; import static java.util.stream.Collectors.joining; @@ -176,8 +177,11 @@ public void parse(T target, Map json) { } private static class BaseSpec { + // it seems `name` is always filled from JSON + @SuppressWarnings({"NotNullFieldNotInitialized", "InstanceVariableMayNotBeInitialized"}) protected String name; - protected String description; + + protected @Nullable String description; protected boolean experimental; protected boolean deprecated; } @@ -343,7 +347,7 @@ private void dumpMainClass(Path target) { ensureFileDoesNotExists(commandFile); try { - Files.write(commandFile, unit.toString().getBytes()); + Files.write(commandFile, unit.toString().getBytes(UTF_8)); } catch (IOException e) { throw new UncheckedIOException(e); } @@ -426,7 +430,7 @@ public void dumpTo(Path target) { ensureFileDoesNotExists(eventFile); try { - Files.write(eventFile, unit.toString().getBytes()); + Files.write(eventFile, unit.toString().getBytes(UTF_8)); } catch (IOException e) { throw new UncheckedIOException(e); } @@ -540,7 +544,7 @@ public void dumpTo(Path target) { ensureFileDoesNotExists(typeFile); try { - Files.write(typeFile, unit.toString().getBytes()); + Files.write(typeFile, unit.toString().getBytes(UTF_8)); } catch (IOException e) { throw new UncheckedIOException(e); } @@ -766,7 +770,7 @@ private interface IType { String getJavaDefaultValue(); - TypeDeclaration toTypeDeclaration(); + @Nullable TypeDeclaration toTypeDeclaration(); String getMapper(); } @@ -1240,7 +1244,10 @@ public String getMapper() { } private static class ArrayType implements IType { + // it seems `name` is always filled from JSON + @SuppressWarnings({"NotNullFieldNotInitialized", "InstanceVariableMayNotBeInitialized"}) private IType itemType; + private final String name; public ArrayType(String name) { @@ -1402,7 +1409,10 @@ private static void ensureFileDoesNotExists(Path file) { } } - private static String capitalize(String text) { + private static String capitalize(@Nullable String text) { + if (text == null) { + return ""; + } return text.substring(0, 1).toUpperCase() + text.substring(1); } @@ -1415,9 +1425,6 @@ private static String toJavaConstant(String text) { } private static String sanitizeJavadoc(String description) { - if (description == null) { - return null; - } // Escape */ sequences which would prematurely close the JavaDoc comment return description.replace("*/", "*/"); } diff --git a/java/src/org/openqa/selenium/devtools/Command.java b/java/src/org/openqa/selenium/devtools/Command.java index adcf87c0389e7..8aeca8f09e368 100644 --- a/java/src/org/openqa/selenium/devtools/Command.java +++ b/java/src/org/openqa/selenium/devtools/Command.java @@ -35,7 +35,8 @@ public Command(String method, Map params) { } public Command(String method, Map params, Type typeOfX) { - this(method, params, input -> input.read(Require.nonNull("Type to convert to", typeOfX))); + this( + method, params, input -> input.readNonNull(Require.nonNull("Type to convert to", typeOfX))); } public Command(String method, Map params, Function mapper) { @@ -73,7 +74,10 @@ public boolean getSendsResponse() { /** * Some CDP commands do not appear to send responses, and so are really hard to deal with. Work * around that by flagging those commands. + * + * @deprecated Not needed. All CDP commands return something, at least empty map. */ + @Deprecated public Command doesNotSendResponse() { return new Command<>(method, params, mapper, false); } diff --git a/java/src/org/openqa/selenium/devtools/Connection.java b/java/src/org/openqa/selenium/devtools/Connection.java index 7bdfb064e9d32..90194266957ab 100644 --- a/java/src/org/openqa/selenium/devtools/Connection.java +++ b/java/src/org/openqa/selenium/devtools/Connection.java @@ -47,6 +47,7 @@ import java.util.function.Consumer; import java.util.logging.Level; import java.util.logging.Logger; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.WebDriverException; import org.openqa.selenium.devtools.idealized.target.model.SessionID; import org.openqa.selenium.internal.Either; @@ -90,7 +91,6 @@ public Connection(HttpClient client, String url) { } public Connection(HttpClient client, String url, ClientConfig clientConfig) { - ; this.client = Require.nonNull("HTTP client", client); this.wsConfig = wsClientConfig(clientConfig, url); this.socket = this.client.openSocket(new HttpRequest(GET, wsConfig.baseUri()), new Listener()); @@ -117,7 +117,7 @@ private static ClientConfig wsClientConfig(ClientConfig clientConfig, String uri } } - private static class NamedConsumer implements Consumer { + private static final class NamedConsumer implements Consumer { private final String name; private final Consumer delegate; @@ -142,7 +142,7 @@ public String toString() { } } - public CompletableFuture send(SessionID sessionId, Command command) { + public CompletableFuture send(@Nullable SessionID sessionId, Command command) { long id = NEXT_ID.getAndIncrement(); CompletableFuture result = new CompletableFuture<>(); @@ -185,13 +185,16 @@ public CompletableFuture send(SessionID sessionId, Command command) { socket.sendText(json); if (!command.getSendsResponse()) { + // As far as I see, it never happens. + // All DevTools commands return something - at least empty map. + //noinspection DataFlowIssue result.complete(null); } return result; } - public X sendAndWait(SessionID sessionId, Command command, Duration timeout) { + public X sendAndWait(@Nullable SessionID sessionId, Command command, Duration timeout) { try { CompletableFuture future = send(sessionId, command); return future.get(timeout.toMillis(), MILLISECONDS); diff --git a/java/src/org/openqa/selenium/devtools/ConverterFunctions.java b/java/src/org/openqa/selenium/devtools/ConverterFunctions.java index 054b2010b5a1b..579da57223100 100644 --- a/java/src/org/openqa/selenium/devtools/ConverterFunctions.java +++ b/java/src/org/openqa/selenium/devtools/ConverterFunctions.java @@ -38,7 +38,7 @@ public class ConverterFunctions { Require.nonNull("Read callback", read); return input -> { - @Nullable X value = null; + X value = null; input.beginObject(); while (input.hasNext()) { diff --git a/java/src/org/openqa/selenium/devtools/DevTools.java b/java/src/org/openqa/selenium/devtools/DevTools.java index c29fb7bc7a314..110d8530ad35b 100644 --- a/java/src/org/openqa/selenium/devtools/DevTools.java +++ b/java/src/org/openqa/selenium/devtools/DevTools.java @@ -32,7 +32,6 @@ import java.util.function.Function; import java.util.logging.Level; import java.util.logging.Logger; -import org.jspecify.annotations.NonNull; import org.jspecify.annotations.Nullable; import org.openqa.selenium.WebDriver; import org.openqa.selenium.devtools.idealized.Domains; @@ -47,8 +46,8 @@ public class DevTools implements Closeable { private final Domains protocol; private final Duration timeout = Duration.ofSeconds(30); private final Connection connection; - private volatile String windowHandle; - private volatile SessionID cdpSession; + private volatile @Nullable String windowHandle; + private volatile @Nullable SessionID cdpSession; public DevTools(Function protocol, Connection connection) { this.connection = Require.nonNull("WebSocket connection", connection); @@ -66,7 +65,8 @@ public void close() { } public void disconnectSession() { - if (cdpSession != null) { + SessionID id = cdpSession; + if (id != null) { try { // ensure network interception does cancel the wait for responses getDomains().network().disable(); @@ -75,7 +75,6 @@ public void disconnectSession() { LOG.log(Level.WARNING, "Exception while disabling network", e); } - SessionID id = cdpSession; cdpSession = null; windowHandle = null; @@ -168,7 +167,7 @@ public void createSession(@Nullable String windowHandle) { attachToWindow(windowHandle); } - private void attachToWindow(String windowHandle) { + private void attachToWindow(@Nullable String windowHandle) { TargetID targetId = findTarget(windowHandle); attachToTarget(targetId); this.windowHandle = windowHandle; @@ -210,8 +209,7 @@ private void attachToTarget(TargetID targetId) { } } - @NonNull - private TargetID findTarget(String windowHandle) { + private TargetID findTarget(@Nullable String windowHandle) { // Figure out the targets. List infos = connection.sendAndWait(cdpSession, getDomains().target().getTargets(), timeout); @@ -231,6 +229,7 @@ private Throwable unwrapCause(ExecutionException e) { return e.getCause() != null ? e.getCause() : e; } + @Nullable public SessionID getCdpSession() { return cdpSession; } diff --git a/java/src/org/openqa/selenium/devtools/DevToolsException.java b/java/src/org/openqa/selenium/devtools/DevToolsException.java index ba58cb9c06d7a..eedb38062f6ec 100644 --- a/java/src/org/openqa/selenium/devtools/DevToolsException.java +++ b/java/src/org/openqa/selenium/devtools/DevToolsException.java @@ -17,22 +17,22 @@ package org.openqa.selenium.devtools; -import org.jspecify.annotations.NullMarked; +import static java.util.Objects.requireNonNullElseGet; + import org.jspecify.annotations.Nullable; import org.openqa.selenium.WebDriverException; -@NullMarked public class DevToolsException extends WebDriverException { - public DevToolsException(@Nullable Throwable cause) { - this(cause.getMessage(), cause); + public DevToolsException(Throwable cause) { + this(requireNonNullElseGet(cause.getMessage(), cause::toString), cause); } - public DevToolsException(@Nullable String message) { + public DevToolsException(String message) { this(message, null); } - public DevToolsException(@Nullable String message, @Nullable Throwable cause) { + public DevToolsException(String message, @Nullable Throwable cause) { super(message, cause); addInfo(WebDriverException.DRIVER_INFO, "DevTools Connection"); } diff --git a/java/src/org/openqa/selenium/devtools/RequestFailedException.java b/java/src/org/openqa/selenium/devtools/RequestFailedException.java index 322d610b28951..6ffef0c946fc3 100644 --- a/java/src/org/openqa/selenium/devtools/RequestFailedException.java +++ b/java/src/org/openqa/selenium/devtools/RequestFailedException.java @@ -17,7 +17,6 @@ package org.openqa.selenium.devtools; -import org.jspecify.annotations.NullMarked; import org.openqa.selenium.WebDriverException; import org.openqa.selenium.remote.http.Filter; import org.openqa.selenium.remote.http.HttpHandler; @@ -27,5 +26,4 @@ * browser fails to send a HTTP request. It can be caught in a {@link Filter} to handle the error * by, for example, returning a custom HTTP response. */ -@NullMarked public class RequestFailedException extends WebDriverException {} diff --git a/java/src/org/openqa/selenium/devtools/events/CdpEventTypes.java b/java/src/org/openqa/selenium/devtools/events/CdpEventTypes.java index 372c3d06687c1..a5afd71fc5fe3 100644 --- a/java/src/org/openqa/selenium/devtools/events/CdpEventTypes.java +++ b/java/src/org/openqa/selenium/devtools/events/CdpEventTypes.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Map; import java.util.function.Consumer; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.By; import org.openqa.selenium.JavascriptExecutor; import org.openqa.selenium.WebDriver; @@ -62,7 +63,7 @@ public void initializeListener(WebDriver webDriver) { }; } - public static EventType domMutation(Consumer handler) { + public static EventType domMutation(Consumer<@Nullable DomMutationEvent> handler) { Require.nonNull("Handler", handler); String script = Read.resourceAsString("/org/openqa/selenium/devtools/mutation-listener.js"); diff --git a/java/src/org/openqa/selenium/devtools/events/package-info.java b/java/src/org/openqa/selenium/devtools/events/package-info.java new file mode 100644 index 0000000000000..d5cef07efde14 --- /dev/null +++ b/java/src/org/openqa/selenium/devtools/events/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.devtools.events; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/devtools/idealized/Network.java b/java/src/org/openqa/selenium/devtools/idealized/Network.java index 949811311a6be..2eed867917208 100644 --- a/java/src/org/openqa/selenium/devtools/idealized/Network.java +++ b/java/src/org/openqa/selenium/devtools/idealized/Network.java @@ -19,6 +19,7 @@ import static java.net.HttpURLConnection.HTTP_OK; import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Objects.requireNonNull; import static java.util.logging.Level.WARNING; import java.net.URI; @@ -37,6 +38,7 @@ import java.util.function.Predicate; import java.util.function.Supplier; import java.util.logging.Logger; +import org.jspecify.annotations.Nullable; import org.openqa.selenium.Credentials; import org.openqa.selenium.TimeoutException; import org.openqa.selenium.UsernameAndPassword; @@ -151,7 +153,6 @@ public void addAuthHandler( prepareToInterceptTraffic(); } - @SuppressWarnings("SuspiciousMethodCalls") public void resetNetworkFilter() { filter = defaultFilter; } @@ -216,7 +217,7 @@ public void prepareToInterceptTraffic() { Either message = createSeMessages(pausedRequest); if (message.isRight()) { - HttpResponse res = message.right(); + HttpResponse res = requireNonNull(message.right()); CompletableFuture future = pendingResponses.remove(id); if (future == null) { @@ -310,9 +311,9 @@ protected HttpMethod convertFromCdpHttpMethod(String method) { protected HttpResponse createHttpResponse( Optional statusCode, - String body, - Boolean bodyIsBase64Encoded, - List> headers) { + @Nullable String body, + @Nullable Boolean bodyIsBase64Encoded, + List> headers) { Contents.Supplier content; if (body == null) { diff --git a/java/src/org/openqa/selenium/devtools/idealized/browser/model/package-info.java b/java/src/org/openqa/selenium/devtools/idealized/browser/model/package-info.java new file mode 100644 index 0000000000000..6988e29b3ba26 --- /dev/null +++ b/java/src/org/openqa/selenium/devtools/idealized/browser/model/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.devtools.idealized.browser.model; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/devtools/idealized/log/model/package-info.java b/java/src/org/openqa/selenium/devtools/idealized/log/model/package-info.java new file mode 100644 index 0000000000000..a12de4c826664 --- /dev/null +++ b/java/src/org/openqa/selenium/devtools/idealized/log/model/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.devtools.idealized.log.model; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/devtools/idealized/log/package-info.java b/java/src/org/openqa/selenium/devtools/idealized/log/package-info.java new file mode 100644 index 0000000000000..eb8d488256630 --- /dev/null +++ b/java/src/org/openqa/selenium/devtools/idealized/log/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.devtools.idealized.log; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/devtools/idealized/package-info.java b/java/src/org/openqa/selenium/devtools/idealized/package-info.java new file mode 100644 index 0000000000000..4dcf6a26f93db --- /dev/null +++ b/java/src/org/openqa/selenium/devtools/idealized/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.devtools.idealized; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/devtools/idealized/runtime/model/RemoteObject.java b/java/src/org/openqa/selenium/devtools/idealized/runtime/model/RemoteObject.java index d4f7722263de3..2dbb9630d1ac8 100644 --- a/java/src/org/openqa/selenium/devtools/idealized/runtime/model/RemoteObject.java +++ b/java/src/org/openqa/selenium/devtools/idealized/runtime/model/RemoteObject.java @@ -17,12 +17,14 @@ package org.openqa.selenium.devtools.idealized.runtime.model; +import org.jspecify.annotations.Nullable; + public class RemoteObject { private final String type; - private final Object value; + private final @Nullable Object value; - public RemoteObject(String type, Object value) { + public RemoteObject(String type, @Nullable Object value) { this.type = type; this.value = value; } @@ -38,6 +40,7 @@ public String getType() { return type; } + @Nullable public Object getValue() { return value; } diff --git a/java/src/org/openqa/selenium/devtools/idealized/runtime/model/package-info.java b/java/src/org/openqa/selenium/devtools/idealized/runtime/model/package-info.java new file mode 100644 index 0000000000000..90e68397e2e71 --- /dev/null +++ b/java/src/org/openqa/selenium/devtools/idealized/runtime/model/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.devtools.idealized.runtime.model; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/devtools/idealized/target/model/TargetInfo.java b/java/src/org/openqa/selenium/devtools/idealized/target/model/TargetInfo.java index 6451b73b1b861..d3a54ef04961a 100644 --- a/java/src/org/openqa/selenium/devtools/idealized/target/model/TargetInfo.java +++ b/java/src/org/openqa/selenium/devtools/idealized/target/model/TargetInfo.java @@ -20,6 +20,7 @@ import java.util.Optional; import org.openqa.selenium.Beta; import org.openqa.selenium.devtools.idealized.browser.model.BrowserContextID; +import org.openqa.selenium.internal.Require; import org.openqa.selenium.json.JsonInput; public class TargetInfo { @@ -42,11 +43,11 @@ public TargetInfo( Boolean attached, Optional openerId, Optional browserContextId) { - this.targetId = java.util.Objects.requireNonNull(targetId, "targetId is required"); - this.type = java.util.Objects.requireNonNull(type, "type is required"); - this.title = java.util.Objects.requireNonNull(title, "title is required"); - this.url = java.util.Objects.requireNonNull(url, "url is required"); - this.attached = java.util.Objects.requireNonNull(attached, "attached is required"); + this.targetId = Require.nonNull("targetId", targetId); + this.type = Require.nonNull("type", type); + this.title = Require.nonNull("title", title); + this.url = Require.nonNull("url", url); + this.attached = Require.nonNull("attached", attached); this.openerId = openerId; this.browserContextId = browserContextId; } @@ -82,6 +83,7 @@ public Optional getBrowserContextId() { return browserContextId; } + @SuppressWarnings("DataFlowIssue") private static TargetInfo fromJson(JsonInput input) { TargetID targetId = null; String type = null; diff --git a/java/src/org/openqa/selenium/devtools/idealized/target/model/package-info.java b/java/src/org/openqa/selenium/devtools/idealized/target/model/package-info.java new file mode 100644 index 0000000000000..de674415b30b4 --- /dev/null +++ b/java/src/org/openqa/selenium/devtools/idealized/target/model/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.devtools.idealized.target.model; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/devtools/idealized/target/package-info.java b/java/src/org/openqa/selenium/devtools/idealized/target/package-info.java new file mode 100644 index 0000000000000..00401925522f0 --- /dev/null +++ b/java/src/org/openqa/selenium/devtools/idealized/target/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.devtools.idealized.target; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/devtools/noop/package-info.java b/java/src/org/openqa/selenium/devtools/noop/package-info.java new file mode 100644 index 0000000000000..d59564ae18c2e --- /dev/null +++ b/java/src/org/openqa/selenium/devtools/noop/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.devtools.noop; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/devtools/v143/BUILD.bazel b/java/src/org/openqa/selenium/devtools/v143/BUILD.bazel index c16d20745a71e..819cfbb8ae8e4 100644 --- a/java/src/org/openqa/selenium/devtools/v143/BUILD.bazel +++ b/java/src/org/openqa/selenium/devtools/v143/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_jvm_external//:defs.bzl", "artifact") load("//common:defs.bzl", "copy_file") load("//java:defs.bzl", "java_export", "java_library") load("//java:version.bzl", "SE_VERSION") @@ -27,6 +28,7 @@ java_export( "//java/src/org/openqa/selenium:core", "//java/src/org/openqa/selenium/json", "//java/src/org/openqa/selenium/remote", + artifact("org.jspecify:jspecify"), ], ) @@ -42,6 +44,7 @@ java_library( "//java/src/org/openqa/selenium:core", "//java/src/org/openqa/selenium/json", "//java/src/org/openqa/selenium/remote", + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/devtools/v143/package-info.java b/java/src/org/openqa/selenium/devtools/v143/package-info.java new file mode 100644 index 0000000000000..ad7c75636057c --- /dev/null +++ b/java/src/org/openqa/selenium/devtools/v143/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.devtools.v143; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/devtools/v144/BUILD.bazel b/java/src/org/openqa/selenium/devtools/v144/BUILD.bazel index 885b9016ea619..19226eabfe7ca 100644 --- a/java/src/org/openqa/selenium/devtools/v144/BUILD.bazel +++ b/java/src/org/openqa/selenium/devtools/v144/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_jvm_external//:defs.bzl", "artifact") load("//common:defs.bzl", "copy_file") load("//java:defs.bzl", "java_export", "java_library") load("//java:version.bzl", "SE_VERSION") @@ -27,6 +28,7 @@ java_export( "//java/src/org/openqa/selenium:core", "//java/src/org/openqa/selenium/json", "//java/src/org/openqa/selenium/remote", + artifact("org.jspecify:jspecify"), ], ) @@ -42,6 +44,7 @@ java_library( "//java/src/org/openqa/selenium:core", "//java/src/org/openqa/selenium/json", "//java/src/org/openqa/selenium/remote", + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/devtools/v144/package-info.java b/java/src/org/openqa/selenium/devtools/v144/package-info.java new file mode 100644 index 0000000000000..2b3fd2eb579b6 --- /dev/null +++ b/java/src/org/openqa/selenium/devtools/v144/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.devtools.v144; + +import org.jspecify.annotations.NullMarked; diff --git a/java/src/org/openqa/selenium/devtools/v145/BUILD.bazel b/java/src/org/openqa/selenium/devtools/v145/BUILD.bazel index 09605c4cd51ba..de0c2b7ccd6e6 100644 --- a/java/src/org/openqa/selenium/devtools/v145/BUILD.bazel +++ b/java/src/org/openqa/selenium/devtools/v145/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_jvm_external//:defs.bzl", "artifact") load("//common:defs.bzl", "copy_file") load("//java:defs.bzl", "java_export", "java_library") load("//java:version.bzl", "SE_VERSION") @@ -27,6 +28,7 @@ java_export( "//java/src/org/openqa/selenium:core", "//java/src/org/openqa/selenium/json", "//java/src/org/openqa/selenium/remote", + artifact("org.jspecify:jspecify"), ], ) @@ -42,6 +44,7 @@ java_library( "//java/src/org/openqa/selenium:core", "//java/src/org/openqa/selenium/json", "//java/src/org/openqa/selenium/remote", + artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/devtools/v145/package-info.java b/java/src/org/openqa/selenium/devtools/v145/package-info.java new file mode 100644 index 0000000000000..9af117579c531 --- /dev/null +++ b/java/src/org/openqa/selenium/devtools/v145/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.devtools.v145; + +import org.jspecify.annotations.NullMarked; diff --git a/java/test/org/openqa/selenium/devtools/BUILD.bazel b/java/test/org/openqa/selenium/devtools/BUILD.bazel index 9664a5f059000..7c0b8f68509b9 100644 --- a/java/test/org/openqa/selenium/devtools/BUILD.bazel +++ b/java/test/org/openqa/selenium/devtools/BUILD.bazel @@ -15,6 +15,7 @@ java_test_suite( "//java/src/org/openqa/selenium/remote", artifact("org.junit.jupiter:junit-jupiter-api"), artifact("org.assertj:assertj-core"), + artifact("org.jspecify:jspecify"), ] + JUNIT5_DEPS, ) @@ -46,6 +47,7 @@ java_selenium_test_suite( artifact("com.google.guava:guava"), artifact("org.junit.jupiter:junit-jupiter-api"), artifact("org.assertj:assertj-core"), + artifact("org.jspecify:jspecify"), ] + JUNIT5_DEPS, ) @@ -67,5 +69,6 @@ java_library( "//java/test/org/openqa/selenium/testing/drivers", artifact("org.junit.jupiter:junit-jupiter-api"), artifact("org.assertj:assertj-core"), + artifact("org.jspecify:jspecify"), ] + JUNIT5_DEPS, ) diff --git a/java/test/org/openqa/selenium/devtools/idealized/runtime/model/RemoteObjectTest.java b/java/test/org/openqa/selenium/devtools/idealized/runtime/model/RemoteObjectTest.java new file mode 100644 index 0000000000000..951fb6130116b --- /dev/null +++ b/java/test/org/openqa/selenium/devtools/idealized/runtime/model/RemoteObjectTest.java @@ -0,0 +1,57 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.openqa.selenium.devtools.idealized.runtime.model; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import org.junit.jupiter.api.Test; + +class RemoteObjectTest { + @Test + void toString_embracesStringValueInQuotes() { + assertThat(new RemoteObject("foo", "bar")) + .hasToString( + """ + "bar" + """ + .trim()); + } + + @Test + void toString_escapesQuotesInStringValue() { + assertThat(new RemoteObject("foo", "bar\"baz")) + .hasToString( + """ + "bar\\"baz" + """ + .trim()); + } + + @Test + void toString_withNonStringValues() { + assertThat(new RemoteObject("foo", 42)).hasToString("42"); + assertThat(new RemoteObject("foo", false)).hasToString("false"); + assertThat(new RemoteObject("foo", List.of("a", "b", "c"))).hasToString("[a, b, c]"); + } + + @Test + void toString_withNullValue() { + assertThat(new RemoteObject("foo", null)).hasToString("null"); + } +} diff --git a/java/test/org/openqa/selenium/devtools/idealized/runtime/model/package-info.java b/java/test/org/openqa/selenium/devtools/idealized/runtime/model/package-info.java new file mode 100644 index 0000000000000..90e68397e2e71 --- /dev/null +++ b/java/test/org/openqa/selenium/devtools/idealized/runtime/model/package-info.java @@ -0,0 +1,21 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +@NullMarked +package org.openqa.selenium.devtools.idealized.runtime.model; + +import org.jspecify.annotations.NullMarked; From a15c6cce9faf8e870c5a2cc20610d5b9bbaa4eb7 Mon Sep 17 00:00:00 2001 From: Corey Goldberg <1113081+cgoldberg@users.noreply.github.com> Date: Sun, 8 Mar 2026 14:53:22 -0400 Subject: [PATCH 53/67] [py] Add type stubs for lazy imported classes and modules (#17165) --- py/BUILD.bazel | 16 ++- py/pyproject.toml | 1 + py/selenium/webdriver/__init__.pyi | 131 ++++++++++++++++++ py/selenium/webdriver/chrome/__init__.pyi | 26 ++++ .../webdriver/chrome/remote_connection.py | 2 +- py/selenium/webdriver/edge/__init__.pyi | 26 ++++ .../webdriver/edge/remote_connection.py | 2 +- py/selenium/webdriver/firefox/__init__.pyi | 26 ++++ py/selenium/webdriver/ie/__init__.pyi | 26 ++++ py/selenium/webdriver/safari/__init__.pyi | 26 ++++ py/selenium/webdriver/webkitgtk/__init__.pyi | 26 ++++ py/selenium/webdriver/wpewebkit/__init__.pyi | 26 ++++ 12 files changed, 328 insertions(+), 6 deletions(-) create mode 100644 py/selenium/webdriver/__init__.pyi create mode 100644 py/selenium/webdriver/chrome/__init__.pyi create mode 100644 py/selenium/webdriver/edge/__init__.pyi create mode 100644 py/selenium/webdriver/firefox/__init__.pyi create mode 100644 py/selenium/webdriver/ie/__init__.pyi create mode 100644 py/selenium/webdriver/safari/__init__.pyi create mode 100644 py/selenium/webdriver/webkitgtk/__init__.pyi create mode 100644 py/selenium/webdriver/wpewebkit/__init__.pyi diff --git a/py/BUILD.bazel b/py/BUILD.bazel index 4d11a3063e66a..b560f297b7537 100644 --- a/py/BUILD.bazel +++ b/py/BUILD.bazel @@ -234,6 +234,7 @@ py_library( "selenium/__init__.py", "selenium/webdriver/__init__.py", ] + glob(["selenium/common/**/*.py"]), + data = ["selenium/webdriver/__init__.pyi"], imports = ["."], visibility = ["//visibility:public"], ) @@ -278,9 +279,7 @@ py_library( name = "common", srcs = glob( ["selenium/webdriver/common/**/*.py"], - exclude = [ - "selenium/webdriver/common/bidi/**", - ], + exclude = ["selenium/webdriver/common/bidi/**"], ), data = [ ":manager-linux", @@ -326,6 +325,7 @@ py_library( py_library( name = "chrome", srcs = glob(["selenium/webdriver/chrome/**/*.py"]), + data = ["selenium/webdriver/chrome/__init__.pyi"], imports = ["."], visibility = ["//visibility:public"], deps = [":chromium"], @@ -335,6 +335,7 @@ py_library( py_library( name = "edge", srcs = glob(["selenium/webdriver/edge/**/*.py"]), + data = ["selenium/webdriver/edge/__init__.pyi"], imports = ["."], visibility = ["//visibility:public"], deps = [":chromium"], @@ -344,7 +345,10 @@ py_library( py_library( name = "firefox", srcs = glob(["selenium/webdriver/firefox/**/*.py"]), - data = [":firefox-driver-prefs"], + data = [ + "selenium/webdriver/firefox/__init__.pyi", + ":firefox-driver-prefs", + ], imports = ["."], visibility = ["//visibility:public"], deps = [ @@ -357,6 +361,7 @@ py_library( py_library( name = "safari", srcs = glob(["selenium/webdriver/safari/**/*.py"]), + data = ["selenium/webdriver/safari/__init__.pyi"], imports = ["."], visibility = ["//visibility:public"], deps = [ @@ -369,6 +374,7 @@ py_library( py_library( name = "ie", srcs = glob(["selenium/webdriver/ie/**/*.py"]), + data = ["selenium/webdriver/ie/__init__.pyi"], imports = ["."], visibility = ["//visibility:public"], deps = [ @@ -381,6 +387,7 @@ py_library( py_library( name = "webkitgtk", srcs = glob(["selenium/webdriver/webkitgtk/**/*.py"]), + data = ["selenium/webdriver/webkitgtk/__init__.pyi"], imports = ["."], visibility = ["//visibility:public"], deps = [ @@ -393,6 +400,7 @@ py_library( py_library( name = "wpewebkit", srcs = glob(["selenium/webdriver/wpewebkit/**/*.py"]), + data = ["selenium/webdriver/wpewebkit/__init__.pyi"], imports = ["."], visibility = ["//visibility:public"], deps = [ diff --git a/py/pyproject.toml b/py/pyproject.toml index 8be36470023e9..840869b1d4286 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -67,6 +67,7 @@ binding = "Exec" [tool.setuptools.package-data] "*" = [ "*.py", + "*.pyi", "*.rst", "*.json", "*.xpi", diff --git a/py/selenium/webdriver/__init__.pyi b/py/selenium/webdriver/__init__.pyi new file mode 100644 index 0000000000000..c03fcc67ac1ba --- /dev/null +++ b/py/selenium/webdriver/__init__.pyi @@ -0,0 +1,131 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Type stub with lazy import mapping from __init__.py. + +This stub file is necessary for type checkers and IDEs to automatically have +visibility into lazy modules since they are not imported immediately at runtime. +""" + +# ruff: noqa: I001 + +# Expose runtime version +__version__: str + +# Chrome +from selenium.webdriver.chrome.webdriver import WebDriver as Chrome +from selenium.webdriver.chrome.options import Options as ChromeOptions +from selenium.webdriver.chrome.service import Service as ChromeService + +# Edge +from selenium.webdriver.edge.webdriver import WebDriver as Edge +from selenium.webdriver.edge.webdriver import WebDriver as ChromiumEdge +from selenium.webdriver.edge.options import Options as EdgeOptions +from selenium.webdriver.edge.service import Service as EdgeService + +# Firefox +from selenium.webdriver.firefox.webdriver import WebDriver as Firefox +from selenium.webdriver.firefox.options import Options as FirefoxOptions +from selenium.webdriver.firefox.service import Service as FirefoxService +from selenium.webdriver.firefox.firefox_profile import FirefoxProfile + +# IE +from selenium.webdriver.ie.webdriver import WebDriver as Ie +from selenium.webdriver.ie.options import Options as IeOptions +from selenium.webdriver.ie.service import Service as IeService + +# Safari +from selenium.webdriver.safari.webdriver import WebDriver as Safari +from selenium.webdriver.safari.options import Options as SafariOptions +from selenium.webdriver.safari.service import Service as SafariService + +# Remote +from selenium.webdriver.remote.webdriver import WebDriver as Remote + +# WebKitGTK +from selenium.webdriver.webkitgtk.webdriver import WebDriver as WebKitGTK +from selenium.webdriver.webkitgtk.options import Options as WebKitGTKOptions +from selenium.webdriver.webkitgtk.service import Service as WebKitGTKService + +# WPEWebKit +from selenium.webdriver.wpewebkit.webdriver import WebDriver as WPEWebKit +from selenium.webdriver.wpewebkit.options import Options as WPEWebKitOptions +from selenium.webdriver.wpewebkit.service import Service as WPEWebKitService + +# Common utilities +from selenium.webdriver.common.action_chains import ActionChains +from selenium.webdriver.common.desired_capabilities import DesiredCapabilities +from selenium.webdriver.common.keys import Keys +from selenium.webdriver.common.proxy import Proxy + +# Submodules +from . import chrome +from . import chromium +from . import common +from . import edge +from . import firefox +from . import ie +from . import remote +from . import safari +from . import support +from . import webkitgtk +from . import wpewebkit + +# Exposed names +__all__ = [ + # Classes + "ActionChains", + "Chrome", + "ChromeOptions", + "ChromeService", + "ChromiumEdge", + "DesiredCapabilities", + "Edge", + "EdgeOptions", + "EdgeService", + "Firefox", + "FirefoxOptions", + "FirefoxProfile", + "FirefoxService", + "Ie", + "IeOptions", + "IeService", + "Keys", + "Proxy", + "Remote", + "Safari", + "SafariOptions", + "SafariService", + "WPEWebKit", + "WPEWebKitOptions", + "WPEWebKitService", + "WebKitGTK", + "WebKitGTKOptions", + "WebKitGTKService", + # Submodules + "chrome", + "chromium", + "common", + "edge", + "firefox", + "ie", + "remote", + "safari", + "support", + "webkitgtk", + "wpewebkit", +] diff --git a/py/selenium/webdriver/chrome/__init__.pyi b/py/selenium/webdriver/chrome/__init__.pyi new file mode 100644 index 0000000000000..44d8a67ea67e2 --- /dev/null +++ b/py/selenium/webdriver/chrome/__init__.pyi @@ -0,0 +1,26 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Type stub with lazy import mapping from __init__.py. + +This stub file is necessary for type checkers and IDEs to automatically have +visibility into lazy modules since they are not imported immediately at runtime. +""" + +from . import options, remote_connection, service, webdriver + +__all__ = ["options", "remote_connection", "service", "webdriver"] diff --git a/py/selenium/webdriver/chrome/remote_connection.py b/py/selenium/webdriver/chrome/remote_connection.py index 134a13d06180e..93c684abf8f86 100644 --- a/py/selenium/webdriver/chrome/remote_connection.py +++ b/py/selenium/webdriver/chrome/remote_connection.py @@ -16,8 +16,8 @@ # under the License. -from selenium.webdriver import DesiredCapabilities from selenium.webdriver.chromium.remote_connection import ChromiumRemoteConnection +from selenium.webdriver.common.desired_capabilities import DesiredCapabilities from selenium.webdriver.remote.client_config import ClientConfig diff --git a/py/selenium/webdriver/edge/__init__.pyi b/py/selenium/webdriver/edge/__init__.pyi new file mode 100644 index 0000000000000..44d8a67ea67e2 --- /dev/null +++ b/py/selenium/webdriver/edge/__init__.pyi @@ -0,0 +1,26 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Type stub with lazy import mapping from __init__.py. + +This stub file is necessary for type checkers and IDEs to automatically have +visibility into lazy modules since they are not imported immediately at runtime. +""" + +from . import options, remote_connection, service, webdriver + +__all__ = ["options", "remote_connection", "service", "webdriver"] diff --git a/py/selenium/webdriver/edge/remote_connection.py b/py/selenium/webdriver/edge/remote_connection.py index 5aa4f99cd304f..dc859e145665b 100644 --- a/py/selenium/webdriver/edge/remote_connection.py +++ b/py/selenium/webdriver/edge/remote_connection.py @@ -16,8 +16,8 @@ # under the License. -from selenium.webdriver import DesiredCapabilities from selenium.webdriver.chromium.remote_connection import ChromiumRemoteConnection +from selenium.webdriver.common.desired_capabilities import DesiredCapabilities from selenium.webdriver.remote.client_config import ClientConfig diff --git a/py/selenium/webdriver/firefox/__init__.pyi b/py/selenium/webdriver/firefox/__init__.pyi new file mode 100644 index 0000000000000..e025d5fc409d1 --- /dev/null +++ b/py/selenium/webdriver/firefox/__init__.pyi @@ -0,0 +1,26 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Type stub with lazy import mapping from __init__.py. + +This stub file is necessary for type checkers and IDEs to automatically have +visibility into lazy modules since they are not imported immediately at runtime. +""" + +from . import firefox_profile, options, remote_connection, service, webdriver + +__all__ = ["firefox_profile", "options", "remote_connection", "service", "webdriver"] diff --git a/py/selenium/webdriver/ie/__init__.pyi b/py/selenium/webdriver/ie/__init__.pyi new file mode 100644 index 0000000000000..b3479da1ad72e --- /dev/null +++ b/py/selenium/webdriver/ie/__init__.pyi @@ -0,0 +1,26 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Type stub with lazy import mapping from __init__.py. + +This stub file is necessary for type checkers and IDEs to automatically have +visibility into lazy modules since they are not imported immediately at runtime. +""" + +from . import options, service, webdriver + +__all__ = ["options", "service", "webdriver"] diff --git a/py/selenium/webdriver/safari/__init__.pyi b/py/selenium/webdriver/safari/__init__.pyi new file mode 100644 index 0000000000000..519f3700f14a0 --- /dev/null +++ b/py/selenium/webdriver/safari/__init__.pyi @@ -0,0 +1,26 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Type stub with lazy import mapping from __init__.py. + +This stub file is necessary for type checkers and IDEs to automatically have +visibility into lazy modules since they are not imported immediately at runtime. +""" + +from . import options, permissions, remote_connection, service, webdriver + +__all__ = ["options", "permissions", "remote_connection", "service", "webdriver"] diff --git a/py/selenium/webdriver/webkitgtk/__init__.pyi b/py/selenium/webdriver/webkitgtk/__init__.pyi new file mode 100644 index 0000000000000..b3479da1ad72e --- /dev/null +++ b/py/selenium/webdriver/webkitgtk/__init__.pyi @@ -0,0 +1,26 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Type stub with lazy import mapping from __init__.py. + +This stub file is necessary for type checkers and IDEs to automatically have +visibility into lazy modules since they are not imported immediately at runtime. +""" + +from . import options, service, webdriver + +__all__ = ["options", "service", "webdriver"] diff --git a/py/selenium/webdriver/wpewebkit/__init__.pyi b/py/selenium/webdriver/wpewebkit/__init__.pyi new file mode 100644 index 0000000000000..b3479da1ad72e --- /dev/null +++ b/py/selenium/webdriver/wpewebkit/__init__.pyi @@ -0,0 +1,26 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Type stub with lazy import mapping from __init__.py. + +This stub file is necessary for type checkers and IDEs to automatically have +visibility into lazy modules since they are not imported immediately at runtime. +""" + +from . import options, service, webdriver + +__all__ = ["options", "service", "webdriver"] From 54918ec60f252e18d89cda0c4af8b35a4464c740 Mon Sep 17 00:00:00 2001 From: Selenium CI Bot Date: Mon, 9 Mar 2026 01:51:47 +0100 Subject: [PATCH 54/67] [dotnet][rb][java][js][py] Automated Browser Version Update (#17192) Update pinned browser versions Co-authored-by: Selenium CI Bot --- common/repositories.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/repositories.bzl b/common/repositories.bzl index b625b4d4659f1..82e94da12eae5 100644 --- a/common/repositories.bzl +++ b/common/repositories.bzl @@ -182,8 +182,8 @@ js_library( http_archive( name = "mac_edgedriver", - url = "https://msedgedriver.microsoft.com/145.0.3800.82/edgedriver_mac64_m1.zip", - sha256 = "178502258e17aef84051c5a317956ab3d3944e7329e3eee662105b17a3b60dc8", + url = "https://msedgedriver.microsoft.com/145.0.3800.97/edgedriver_mac64_m1.zip", + sha256 = "8e811839e69209e947f63da484e5813f1eca3eb9af673d780eab9c448b77f386", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) From 520e14b03398342e3d60b7665d833428358c698a Mon Sep 17 00:00:00 2001 From: Nikolay Borisenko <22616990+nvborisenko@users.noreply.github.com> Date: Mon, 9 Mar 2026 10:45:59 +0300 Subject: [PATCH 55/67] [dotnet] Apply selenium theme for docs (#17190) --- dotnet/docs/docfx.json | 3 ++- dotnet/docs/index.md | 4 ++-- dotnet/docs/templates/selenium/public/main.css | 8 ++++++++ 3 files changed, 12 insertions(+), 3 deletions(-) create mode 100644 dotnet/docs/templates/selenium/public/main.css diff --git a/dotnet/docs/docfx.json b/dotnet/docs/docfx.json index 1ae440cb281a2..ab9897e1ab965 100644 --- a/dotnet/docs/docfx.json +++ b/dotnet/docs/docfx.json @@ -48,7 +48,8 @@ "output": "../../build/docs/api/dotnet", "template": [ "default", - "modern" + "modern", + "templates/selenium" ], "globalMetadata": { "_appName": "Selenium .NET API", diff --git a/dotnet/docs/index.md b/dotnet/docs/index.md index 36682c63021de..3a78bc472ce55 100644 --- a/dotnet/docs/index.md +++ b/dotnet/docs/index.md @@ -5,5 +5,5 @@ layout: landingPage # Welcome to the Selenium .NET API Docs ## Modules -- [Selenium.WebDriver](webdriver/OpenQA.Selenium.html) -- [Selenium.Support](support/OpenQA.Selenium.Support.html) \ No newline at end of file +- [Selenium.WebDriver](webdriver/OpenQA.Selenium.yml) +- [Selenium.Support](support/OpenQA.Selenium.Support.yml) diff --git a/dotnet/docs/templates/selenium/public/main.css b/dotnet/docs/templates/selenium/public/main.css new file mode 100644 index 0000000000000..6daec3186d6b5 --- /dev/null +++ b/dotnet/docs/templates/selenium/public/main.css @@ -0,0 +1,8 @@ +:root { + --bs-primary: #43b02a; + --bs-primary-rgb: 67, 176, 42; + --bs-link-color: #43b02a; + --bs-link-color-rgb: 67, 176, 42; + --bs-link-hover-color: #369022; + --bs-link-hover-color-rgb: 54, 144, 34; +} From a07008e65f3aa7acee5097f1eb45bbb418f889a9 Mon Sep 17 00:00:00 2001 From: Selenium CI Bot Date: Tue, 10 Mar 2026 01:47:06 +0100 Subject: [PATCH 56/67] [dotnet][rb][java][js][py] Automated Browser Version Update (#17195) Update pinned browser versions Co-authored-by: Selenium CI Bot --- common/repositories.bzl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/common/repositories.bzl b/common/repositories.bzl index 82e94da12eae5..dac7b4c479002 100644 --- a/common/repositories.bzl +++ b/common/repositories.bzl @@ -50,8 +50,8 @@ js_library( http_archive( name = "linux_beta_firefox", - url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b5/linux-x86_64/en-US/firefox-149.0b5.tar.xz", - sha256 = "1c9c3c4955c542815549f3c2737a54c4adba74cc949e07441fede36d02f1a536", + url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b6/linux-x86_64/en-US/firefox-149.0b6.tar.xz", + sha256 = "b7f14d7b33678a55cd9c27de0f81b586b3d1dab04ddd5a6d1584f32780f67597", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -72,8 +72,8 @@ js_library( dmg_archive( name = "mac_beta_firefox", - url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b5/mac/en-US/Firefox%20149.0b5.dmg", - sha256 = "7bca909999df7ce1a6d6b8a1ebcf121f9f85a2f20dc0205d70362b9de813df51", + url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b6/mac/en-US/Firefox%20149.0b6.dmg", + sha256 = "be690a8bb38861766c72adc7d44c4296a153ba02979f92ed8e494fec846f8b5d", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) From 02d393474633813267e07647e7d923761a72c68e Mon Sep 17 00:00:00 2001 From: David Burns Date: Tue, 10 Mar 2026 18:05:39 +0000 Subject: [PATCH 57/67] [py] Bidi py tests expansion (#17193) --- .../webdriver/common/bidi_errors_tests.py | 149 ++++ .../webdriver/common/bidi_input_tests.py | 554 ++++++++++++++- .../common/bidi_integration_tests.py | 266 +++++++ .../webdriver/common/bidi_log_tests.py | 161 +++++ .../webdriver/common/bidi_script_tests.py | 670 ++++++++++++++++-- .../webdriver/common/bidi_storage_tests.py | 431 ++++++++++- .../common/bidi_webextension_tests.py | 307 +++++++- 7 files changed, 2437 insertions(+), 101 deletions(-) create mode 100644 py/test/selenium/webdriver/common/bidi_errors_tests.py create mode 100644 py/test/selenium/webdriver/common/bidi_integration_tests.py create mode 100644 py/test/selenium/webdriver/common/bidi_log_tests.py diff --git a/py/test/selenium/webdriver/common/bidi_errors_tests.py b/py/test/selenium/webdriver/common/bidi_errors_tests.py new file mode 100644 index 0000000000000..2d826b280ca0b --- /dev/null +++ b/py/test/selenium/webdriver/common/bidi_errors_tests.py @@ -0,0 +1,149 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +from selenium.common.exceptions import WebDriverException +from selenium.webdriver.common.by import By + + +def test_invalid_browsing_context_id(driver): + """Test that invalid browsing context ID raises an error.""" + with pytest.raises(WebDriverException): + driver.browsing_context.close("invalid-context-id") + + +def test_invalid_navigation_url(driver): + """Test that navigation with invalid context should fail.""" + with pytest.raises(WebDriverException): + # Invalid context ID should fail + driver.browsing_context.navigate("invalid-context-id", "about:blank") + + +def test_invalid_geolocation_coordinates(driver): + """Test that invalid geolocation coordinates raise an error.""" + from selenium.webdriver.common.bidi.emulation import GeolocationCoordinates + + with pytest.raises((WebDriverException, ValueError, TypeError)): + # Invalid latitude (> 90) + coords = GeolocationCoordinates(latitude=999, longitude=180, accuracy=10) + driver.emulation.set_geolocation_override(coordinates=coords) + + +def test_invalid_timezone(driver): + """Test that invalid timezone string raises an error.""" + with pytest.raises((WebDriverException, ValueError)): + driver.emulation.set_timezone_override("Invalid/Timezone") + + +def test_invalid_set_cookie(driver, pages): + """Test that setting cookie with None raises an error.""" + pages.load("blank.html") + + with pytest.raises((WebDriverException, TypeError, AttributeError)): + driver.storage.set_cookie(None) + + +def test_remove_nonexistent_context(driver): + """Test that removing non-existent context raises an error.""" + with pytest.raises(WebDriverException): + driver.browser.remove_user_context("non-existent-context-id") + + +def test_invalid_perform_actions_missing_context(driver, pages): + """Test that perform_actions without context raises an error.""" + pages.load("blank.html") + + with pytest.raises(TypeError): + # Missing required 'context' parameter + driver.input.perform_actions(actions=[]) + + +def test_error_recovery_after_invalid_navigation(driver): + """Test that driver can recover after failed navigation.""" + # Try an invalid navigation with bad context + with pytest.raises(WebDriverException): + driver.browsing_context.navigate("invalid-context", "about:blank") + + # Driver should still be functional + driver.get("about:blank") + assert driver.find_element(By.TAG_NAME, "body") is not None + + +def test_multiple_error_conditions(driver, pages): + """Test handling multiple error conditions in sequence.""" + pages.load("blank.html") + + # First error + with pytest.raises(WebDriverException): + driver.browser.remove_user_context("invalid") + + # Driver should still work + assert driver.find_element(By.TAG_NAME, "body") is not None + + # Second error + with pytest.raises((WebDriverException, ValueError)): + driver.emulation.set_timezone_override("Invalid") + + # Driver still functional + driver.get("about:blank") + + +class TestBidiErrorHandling: + """Test class for error handling in BiDi operations.""" + + @pytest.fixture(autouse=True) + def setup(self, driver, pages): + """Setup for each test in this class.""" + pages.load("blank.html") + + def test_error_on_invalid_context_operations(self, driver): + """Test error handling with invalid context operations.""" + # Try to close non-existent context + with pytest.raises(WebDriverException): + driver.browsing_context.close("nonexistent") + + def test_error_recovery_sequence(self, driver): + """Test that driver recovers properly from errors.""" + # First operation fails + with pytest.raises(WebDriverException): + driver.browser.remove_user_context("bad-id") + + # Recovery test + element = driver.find_element(By.TAG_NAME, "body") + assert element is not None + + def test_consecutive_errors(self, driver): + """Test handling consecutive errors.""" + errors_caught = 0 + + # First error + try: + driver.browser.remove_user_context("id1") + except WebDriverException: + errors_caught += 1 + + # Second error + try: + driver.browser.remove_user_context("id2") + except WebDriverException: + errors_caught += 1 + + assert errors_caught == 2 + + # Driver should still work + driver.get("about:blank") diff --git a/py/test/selenium/webdriver/common/bidi_input_tests.py b/py/test/selenium/webdriver/common/bidi_input_tests.py index ecbe0bddd4f73..9929a01117924 100644 --- a/py/test/selenium/webdriver/common/bidi_input_tests.py +++ b/py/test/selenium/webdriver/common/bidi_input_tests.py @@ -27,6 +27,7 @@ KeyDownAction, KeySourceActions, KeyUpAction, + NoneSourceActions, Origin, PauseAction, PointerCommonProperties, @@ -73,7 +74,9 @@ def test_basic_key_input(driver, pages): driver.input.perform_actions(driver.current_window_handle, [key_actions]) - WebDriverWait(driver, 5).until(lambda d: input_element.get_attribute("value") == "hello") + WebDriverWait(driver, 5).until( + lambda d: input_element.get_attribute("value") == "hello" + ) assert input_element.get_attribute("value") == "hello" @@ -97,7 +100,9 @@ def test_key_input_with_pause(driver, pages): driver.input.perform_actions(driver.current_window_handle, [key_actions]) - WebDriverWait(driver, 5).until(lambda d: input_element.get_attribute("value") == "ab") + WebDriverWait(driver, 5).until( + lambda d: input_element.get_attribute("value") == "ab" + ) assert input_element.get_attribute("value") == "ab" @@ -170,7 +175,13 @@ def test_pointer_with_common_properties(driver, pages): # Create pointer properties properties = PointerCommonProperties( - width=2, height=2, pressure=0.5, tangential_pressure=0.0, twist=45, altitude_angle=0.5, azimuth_angle=1.0 + width=2, + height=2, + pressure=0.5, + tangential_pressure=0.0, + twist=45, + altitude_angle=0.5, + azimuth_angle=1.0, ) pointer_actions = PointerSourceActions( @@ -196,7 +207,12 @@ def test_wheel_scroll(driver, pages): # Scroll down wheel_actions = WheelSourceActions( - id="wheel", actions=[WheelScrollAction(x=100, y=100, delta_x=0, delta_y=100, origin=Origin.VIEWPORT)] + id="wheel", + actions=[ + WheelScrollAction( + x=100, y=100, delta_x=0, delta_y=100, origin=Origin.VIEWPORT + ) + ], ) driver.input.perform_actions(driver.current_window_handle, [wheel_actions]) @@ -247,9 +263,13 @@ def test_combined_input_actions(driver, pages): ], ) - driver.input.perform_actions(driver.current_window_handle, [pointer_actions, key_actions]) + driver.input.perform_actions( + driver.current_window_handle, [pointer_actions, key_actions] + ) - WebDriverWait(driver, 5).until(lambda d: input_element.get_attribute("value") == "test") + WebDriverWait(driver, 5).until( + lambda d: input_element.get_attribute("value") == "test" + ) assert input_element.get_attribute("value") == "test" @@ -261,7 +281,9 @@ def test_set_files(driver, pages): assert upload_element.get_attribute("value") == "" # Create a temporary file - with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as temp_file: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".txt", delete=False + ) as temp_file: temp_file.write("test content") temp_file_path = temp_file.name @@ -271,7 +293,9 @@ def test_set_files(driver, pages): element_ref = {"sharedId": element_id} # Set files using BiDi - driver.input.set_files(driver.current_window_handle, element_ref, [temp_file_path]) + driver.input.set_files( + driver.current_window_handle, element_ref, [temp_file_path] + ) # Verify file was set value = upload_element.get_attribute("value") @@ -346,7 +370,9 @@ def test_release_actions(driver, pages): driver.input.perform_actions(driver.current_window_handle, [key_actions2]) # Should be able to type normally - WebDriverWait(driver, 5).until(lambda d: "b" in input_element.get_attribute("value")) + WebDriverWait(driver, 5).until( + lambda d: "b" in input_element.get_attribute("value") + ) @pytest.mark.parametrize("multiple", [True, False]) @@ -362,7 +388,9 @@ def file_dialog_handler(file_dialog_info): handler_id = driver.input.add_file_dialog_handler(file_dialog_handler) assert handler_id is not None - driver.get(f"data:text/html,") + driver.get( + f"data:text/html," + ) # Use script.evaluate to trigger the file dialog with user activation driver.script._evaluate( @@ -413,3 +441,509 @@ def file_dialog_handler(file_dialog_info): # Wait to ensure no events are captured time.sleep(1) assert len(file_dialog_events) == 0 + + +# Edge Cases and Additional Tests + + +def test_perform_actions_with_none_source(driver, pages): + """Test performing NoneSourceActions (pause only).""" + pages.load("single_text_input.html") + + # Create none actions (pause only - no actual input) + none_actions = NoneSourceActions( + id="none_id", + actions=[ + PauseAction(duration=100), + PauseAction(duration=50), + ], + ) + + # Should execute without error + driver.input.perform_actions(driver.current_window_handle, [none_actions]) + + # Verify input field is still empty + input_element = driver.find_element(By.ID, "textInput") + assert input_element.get_attribute("value") == "" + + +def test_perform_actions_rapid_key_sequence(driver, pages): + """Test rapid key input sequence without pause between keys.""" + pages.load("single_text_input.html") + + input_element = driver.find_element(By.ID, "textInput") + + # Create rapid key sequence + key_actions = KeySourceActions( + id="keyboard", + actions=[ + KeyDownAction(value="a"), + KeyUpAction(value="a"), + KeyDownAction(value="b"), + KeyUpAction(value="b"), + KeyDownAction(value="c"), + KeyUpAction(value="c"), + KeyDownAction(value="d"), + KeyUpAction(value="d"), + ], + ) + + driver.input.perform_actions(driver.current_window_handle, [key_actions]) + + WebDriverWait(driver, 5).until( + lambda d: input_element.get_attribute("value") == "abcd" + ) + assert input_element.get_attribute("value") == "abcd" + + +def test_perform_actions_multiple_pointer_buttons(driver, pages): + """Test pointer actions with different button values.""" + pages.load("javascriptPage.html") + + button = driver.find_element(By.ID, "clickField") + location = button.location + size = button.size + x = location["x"] + size["width"] // 2 + y = location["y"] + size["height"] // 2 + + # Test with button 0 (left click) + pointer_actions_left = PointerSourceActions( + id="mouse_left", + parameters=PointerParameters(pointer_type=PointerType.MOUSE), + actions=[ + PointerMoveAction(x=x, y=y), + PointerDownAction(button=0), + PointerUpAction(button=0), + ], + ) + + driver.input.perform_actions(driver.current_window_handle, [pointer_actions_left]) + + WebDriverWait(driver, 5).until(lambda d: button.get_attribute("value") == "Clicked") + assert button.get_attribute("value") == "Clicked" + + +def test_perform_actions_pointer_touch_type(driver, pages): + """Test pointer actions with touch pointer type.""" + pages.load("javascriptPage.html") + + button = driver.find_element(By.ID, "clickField") + location = button.location + size = button.size + x = location["x"] + size["width"] // 2 + y = location["y"] + size["height"] // 2 + + # Create touch actions + touch_actions = PointerSourceActions( + id="touch", + parameters=PointerParameters(pointer_type=PointerType.TOUCH), + actions=[ + PointerMoveAction(x=x, y=y), + PointerDownAction(button=0), + PointerUpAction(button=0), + ], + ) + + driver.input.perform_actions(driver.current_window_handle, [touch_actions]) + + # Touch should work similar to mouse click + WebDriverWait(driver, 5).until(lambda d: button.get_attribute("value") == "Clicked") + assert button.get_attribute("value") == "Clicked" + + +def test_perform_actions_pointer_pen_type(driver, pages): + """Test pointer actions with pen pointer type.""" + pages.load("javascriptPage.html") + + button = driver.find_element(By.ID, "clickField") + location = button.location + size = button.size + x = location["x"] + size["width"] // 2 + y = location["y"] + size["height"] // 2 + + # Create pen actions + pen_actions = PointerSourceActions( + id="pen", + parameters=PointerParameters(pointer_type=PointerType.PEN), + actions=[ + PointerMoveAction(x=x, y=y), + PointerDownAction(button=0), + PointerUpAction(button=0), + ], + ) + + driver.input.perform_actions(driver.current_window_handle, [pen_actions]) + + WebDriverWait(driver, 5).until(lambda d: button.get_attribute("value") == "Clicked") + assert button.get_attribute("value") == "Clicked" + + +def test_perform_actions_pointer_move_with_duration(driver, pages): + """Test pointer move action with duration parameter.""" + pages.load("javascriptPage.html") + + button = driver.find_element(By.ID, "clickField") + location = button.location + size = button.size + x = location["x"] + size["width"] // 2 + y = location["y"] + size["height"] // 2 + + # Start point (off the button) + start_x = x - 100 + start_y = y - 100 + + # Create pointer actions with duration on move + pointer_actions = PointerSourceActions( + id="mouse", + parameters=PointerParameters(pointer_type=PointerType.MOUSE), + actions=[ + PointerMoveAction(x=start_x, y=start_y), + PointerMoveAction(x=x, y=y, duration=500), # Slow move + PointerDownAction(button=0), + PointerUpAction(button=0), + ], + ) + + driver.input.perform_actions(driver.current_window_handle, [pointer_actions]) + + WebDriverWait(driver, 5).until(lambda d: button.get_attribute("value") == "Clicked") + assert button.get_attribute("value") == "Clicked" + + +def test_wheel_scroll_negative_delta(driver, pages): + """Test wheel scroll with negative delta values (up/left).""" + pages.load("scroll3.html") + + # First scroll down + wheel_actions_down = WheelSourceActions( + id="wheel_down", + actions=[ + WheelScrollAction( + x=100, y=100, delta_x=0, delta_y=100, origin=Origin.VIEWPORT + ) + ], + ) + + driver.input.perform_actions(driver.current_window_handle, [wheel_actions_down]) + + scroll_y_down = driver.execute_script("return window.pageYOffset;") + assert scroll_y_down > 0 + + # Then scroll back up (negative delta) + wheel_actions_up = WheelSourceActions( + id="wheel_up", + actions=[ + WheelScrollAction( + x=100, y=100, delta_x=0, delta_y=-50, origin=Origin.VIEWPORT + ) + ], + ) + + driver.input.perform_actions(driver.current_window_handle, [wheel_actions_up]) + + scroll_y_up = driver.execute_script("return window.pageYOffset;") + assert scroll_y_up < scroll_y_down + + +def test_wheel_scroll_with_duration(driver, pages): + """Test wheel scroll action with duration parameter.""" + pages.load("scroll3.html") + + wheel_actions = WheelSourceActions( + id="wheel", + actions=[ + WheelScrollAction( + x=100, + y=100, + delta_x=0, + delta_y=100, + duration=500, + origin=Origin.VIEWPORT, + ) + ], + ) + + driver.input.perform_actions(driver.current_window_handle, [wheel_actions]) + + scroll_y = driver.execute_script("return window.pageYOffset;") + assert scroll_y == 100 + + +def test_wheel_scroll_horizontal(driver, pages): + """Test wheel scroll with horizontal movement.""" + pages.load("scroll3.html") + + # Scroll horizontally + wheel_actions = WheelSourceActions( + id="wheel", + actions=[ + WheelScrollAction( + x=100, y=100, delta_x=50, delta_y=0, origin=Origin.VIEWPORT + ) + ], + ) + + driver.input.perform_actions(driver.current_window_handle, [wheel_actions]) + + # Check horizontal scroll occurred + scroll_x = driver.execute_script("return window.pageXOffset;") + assert scroll_x >= 0 + + +def test_key_input_special_characters(driver, pages): + """Test keyboard input with special characters.""" + pages.load("single_text_input.html") + + input_element = driver.find_element(By.ID, "textInput") + + # Create keyboard actions for special characters + key_actions = KeySourceActions( + id="keyboard", + actions=[ + KeyDownAction(value="!"), + KeyUpAction(value="!"), + KeyDownAction(value="@"), + KeyUpAction(value="@"), + KeyDownAction(value="#"), + KeyUpAction(value="#"), + ], + ) + + driver.input.perform_actions(driver.current_window_handle, [key_actions]) + + WebDriverWait(driver, 5).until( + lambda d: "!" in input_element.get_attribute("value") + ) + + +def test_set_files_empty_file_list(driver, pages): + """Test setting an empty file list on a file input element.""" + pages.load("formPage.html") + + upload_element = driver.find_element(By.ID, "upload") + + # Get element reference for BiDi + element_id = upload_element.id + element_ref = {"sharedId": element_id} + + # Set empty file list + driver.input.set_files(driver.current_window_handle, element_ref, []) + + # Value should be empty + value = upload_element.get_attribute("value") + assert value == "" + + +def test_set_files_with_absolute_path(driver): + """Test setting a file using absolute file path.""" + driver.get("data:text/html,") + + upload_element = driver.find_element(By.ID, "upload") + + # Create a temporary file + with tempfile.NamedTemporaryFile( + mode="w", suffix=".txt", delete=False + ) as temp_file: + temp_file.write("test file content") + temp_file_path = temp_file.name + + try: + # Get element reference + element_id = upload_element.id + element_ref = {"sharedId": element_id} + + # Set file using absolute path + driver.input.set_files( + driver.current_window_handle, element_ref, [temp_file_path] + ) + + value = upload_element.get_attribute("value") + assert os.path.basename(temp_file_path) in value + + finally: + if os.path.exists(temp_file_path): + os.unlink(temp_file_path) + + +def test_release_actions_clears_pointer_state(driver, pages): + """Test that release_actions properly clears pointer state.""" + pages.load("javascriptPage.html") + + button = driver.find_element(By.ID, "clickField") + location = button.location + size = button.size + x = location["x"] + size["width"] // 2 + y = location["y"] + size["height"] // 2 + + # Press pointer button but don't release + pointer_actions = PointerSourceActions( + id="mouse", + parameters=PointerParameters(pointer_type=PointerType.MOUSE), + actions=[ + PointerMoveAction(x=x, y=y), + PointerDownAction(button=0), + # Not releasing button + ], + ) + + driver.input.perform_actions(driver.current_window_handle, [pointer_actions]) + + # Release all actions + driver.input.release_actions(driver.current_window_handle) + + # Now move and try clicking again - should work normally + pointer_actions2 = PointerSourceActions( + id="mouse", + parameters=PointerParameters(pointer_type=PointerType.MOUSE), + actions=[ + PointerMoveAction(x=x, y=y), + PointerDownAction(button=0), + PointerUpAction(button=0), + ], + ) + + driver.input.perform_actions(driver.current_window_handle, [pointer_actions2]) + + WebDriverWait(driver, 5).until(lambda d: button.get_attribute("value") == "Clicked") + assert button.get_attribute("value") == "Clicked" + + +def test_multiple_file_dialog_handlers(driver): + """Test registering multiple file dialog handlers.""" + handlers_triggered = [] + + def handler_1(file_dialog_info): + handlers_triggered.append(1) + + def handler_2(file_dialog_info): + handlers_triggered.append(2) + + # Register two handlers + handler_id_1 = driver.input.add_file_dialog_handler(handler_1) + handler_id_2 = driver.input.add_file_dialog_handler(handler_2) + + assert handler_id_1 is not None + assert handler_id_2 is not None + assert handler_id_1 != handler_id_2 + + # Clean up + driver.input.remove_file_dialog_handler(handler_id_1) + driver.input.remove_file_dialog_handler(handler_id_2) + + +def test_pointer_common_properties_pressure_values(driver, pages): + """Test pointer actions with various pressure values.""" + pages.load("javascriptPage.html") + + button = driver.find_element(By.ID, "clickField") + location = button.location + size = button.size + x = location["x"] + size["width"] // 2 + y = location["y"] + size["height"] // 2 + + # Test with different pressure values + properties = PointerCommonProperties( + width=2, + height=2, + pressure=0.75, # High pressure + tangential_pressure=0.25, + twist=90, + altitude_angle=0.7, + azimuth_angle=1.5, + ) + + pointer_actions = PointerSourceActions( + id="mouse", + parameters=PointerParameters(pointer_type=PointerType.MOUSE), + actions=[ + PointerMoveAction(x=x, y=y, properties=properties), + PointerDownAction(button=0, properties=properties), + PointerUpAction(button=0), + ], + ) + + driver.input.perform_actions(driver.current_window_handle, [pointer_actions]) + + WebDriverWait(driver, 5).until(lambda d: button.get_attribute("value") == "Clicked") + assert button.get_attribute("value") == "Clicked" + + +def test_combined_keyboard_and_wheel_actions(driver, pages): + """Test combining keyboard and wheel scroll actions.""" + pages.load("scroll3.html") + + # Combine keyboard and wheel actions + key_actions = KeySourceActions( + id="keyboard", + actions=[PauseAction(duration=0)], # Sync with wheel + ) + + wheel_actions = WheelSourceActions( + id="wheel", + actions=[ + PauseAction(duration=0), # Sync with keyboard + WheelScrollAction( + x=100, y=100, delta_x=0, delta_y=100, origin=Origin.VIEWPORT + ), + ], + ) + + driver.input.perform_actions( + driver.current_window_handle, [key_actions, wheel_actions] + ) + + scroll_y = driver.execute_script("return window.pageYOffset;") + assert scroll_y == 100 + + +def test_key_input_with_value_attribute(driver, pages): + """Test KeyDownAction and KeyUpAction use value attribute correctly.""" + pages.load("single_text_input.html") + + input_element = driver.find_element(By.ID, "textInput") + + # Use explicit value attribute in actions + key_actions = KeySourceActions( + id="keyboard", + actions=[ + KeyDownAction(value="x"), + KeyUpAction(value="x"), + KeyDownAction(value="y"), + KeyUpAction(value="y"), + KeyDownAction(value="z"), + KeyUpAction(value="z"), + ], + ) + + driver.input.perform_actions(driver.current_window_handle, [key_actions]) + + WebDriverWait(driver, 5).until( + lambda d: input_element.get_attribute("value") == "xyz" + ) + assert input_element.get_attribute("value") == "xyz" + + +def test_wheel_scroll_with_element_origin(driver, pages): + """Test wheel scroll with element origin instead of viewport.""" + pages.load("scroll3.html") + + # Get a reference to a scrollable element (body) + body_element = driver.find_element(By.TAG_NAME, "body") + element_id = body_element.id + element_ref = {"sharedId": element_id} + element_origin = ElementOrigin(element_ref) + + # Scroll with element origin + wheel_actions = WheelSourceActions( + id="wheel", + actions=[ + WheelScrollAction( + x=100, y=100, delta_x=0, delta_y=100, origin=element_origin + ) + ], + ) + + driver.input.perform_actions(driver.current_window_handle, [wheel_actions]) + + scroll_y = driver.execute_script("return window.pageYOffset;") + assert scroll_y >= 0 diff --git a/py/test/selenium/webdriver/common/bidi_integration_tests.py b/py/test/selenium/webdriver/common/bidi_integration_tests.py new file mode 100644 index 0000000000000..85323f49b3ccf --- /dev/null +++ b/py/test/selenium/webdriver/common/bidi_integration_tests.py @@ -0,0 +1,266 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +from selenium.webdriver.common.by import By +from selenium.webdriver.common.window import WindowTypes +from selenium.webdriver.support.ui import WebDriverWait + + +class TestBidiNetworkWithCookies: + """Test integration of network and storage modules.""" + + @pytest.fixture(autouse=True) + def setup(self, driver, pages): + """Setup for each test in this class.""" + pages.load("blank.html") + yield + # Cleanup: delete all cookies to prevent bleed-through + driver.delete_all_cookies() + + def test_cookies_interaction(self, driver, pages): + """Test that cookies work with network operations.""" + pages.load("blank.html") + + # Set a cookie + driver.add_cookie({"name": "test_cookie", "value": "test_value"}) + + # Verify cookie is set + cookies = driver.get_cookies() + assert len(cookies) > 0 + assert any(c.get("name") == "test_cookie" for c in cookies) + + def test_cookie_modification(self, driver, pages): + """Test that modifying cookies works properly.""" + pages.load("blank.html") + + # Add first cookie + driver.add_cookie({"name": "cookie1", "value": "value1"}) + + cookies_before = driver.get_cookies() + initial_count = len(cookies_before) + + # Add second cookie + driver.add_cookie({"name": "cookie2", "value": "value2"}) + + cookies_after = driver.get_cookies() + assert len(cookies_after) > initial_count + + +class TestBidiScriptWithNavigation: + """Test integration of script execution and navigation.""" + + @pytest.fixture(autouse=True) + def setup(self, driver, pages): + """Setup for each test in this class.""" + driver.delete_all_cookies() + pages.load("blank.html") + yield + # Cleanup: delete all cookies to prevent bleed-through + driver.delete_all_cookies() + + def test_script_execution_after_navigation(self, driver, pages): + """Test script execution after page navigation.""" + # First page + pages.load("blank.html") + driver.execute_script("window.page1_loaded = true;") + + # Navigate to different page + pages.load("blank.html") + + # Previous page variable should not exist + result = driver.execute_script("return window.page1_loaded;") + assert result is None + + # New variable should work + driver.execute_script("window.page2_loaded = true;") + result = driver.execute_script("return window.page2_loaded;") + assert result is True + + def test_global_variable_lifecycle(self, driver, pages): + """Test global variable lifecycle across operations.""" + pages.load("blank.html") + + # Set a global variable + driver.execute_script("window.test_var = {data: 'value'};") + + # Verify it exists + result = driver.execute_script("return window.test_var.data;") + assert result == "value" + + # Navigate away + driver.get("about:blank") + + # Variable should not exist anymore + result = driver.execute_script("return typeof window.test_var;") + assert result == "undefined" + + +class TestBidiEmulationWithNavigation: + """Test integration of emulation and navigation.""" + + @pytest.fixture(autouse=True) + def setup(self, driver, pages): + """Setup for each test in this class.""" + pages.load("blank.html") + yield + # Cleanup: delete all cookies to prevent bleed-through + driver.delete_all_cookies() + + def test_basic_navigation(self, driver, pages): + """Test basic navigation.""" + pages.load("blank.html") + assert driver.find_element(By.TAG_NAME, "body") is not None + + +class TestBidiContextManagement: + """Test integration of context creation and management.""" + + def test_create_and_close_context(self, driver): + """Test creating and closing a user context.""" + new_context = driver.browser.create_user_context() + + try: + assert new_context is not None + finally: + driver.browser.remove_user_context(new_context) + + def test_multiple_contexts_creation(self, driver): + """Test creating multiple contexts.""" + context1 = driver.browser.create_user_context() + context2 = driver.browser.create_user_context() + + try: + assert context1 is not None + assert context2 is not None + assert context1 != context2 + finally: + driver.browser.remove_user_context(context1) + driver.browser.remove_user_context(context2) + + +class TestBidiEventHandlers: + """Test integration of event handlers.""" + + @pytest.fixture(autouse=True) + def setup(self, driver, pages): + """Setup for each test in this class.""" + pages.load("blank.html") + yield + # Cleanup: delete all cookies to prevent bleed-through + driver.delete_all_cookies() + + def test_multiple_console_handlers(self, driver): + """Test multiple console message handlers.""" + messages1 = [] + messages2 = [] + + handler1 = driver.script.add_console_message_handler(messages1.append) + handler2 = driver.script.add_console_message_handler(messages2.append) + + try: + driver.execute_script("console.log('test message');") + WebDriverWait(driver, 5).until( + lambda _: len(messages1) > 0 and len(messages2) > 0 + ) + + assert len(messages1) > 0 + assert len(messages2) > 0 + finally: + driver.script.remove_console_message_handler(handler1) + driver.script.remove_console_message_handler(handler2) + + +class TestBidiStorageOperations: + """Test storage operations.""" + + @pytest.fixture(autouse=True) + def setup(self, driver, pages): + """Setup for each test in this class.""" + driver.delete_all_cookies() + pages.load("blank.html") + yield + # Cleanup: delete all cookies to prevent bleed-through + driver.delete_all_cookies() + + def test_cookie_operations(self, driver, pages): + """Test basic cookie operations.""" + pages.load("blank.html") + + # Set cookie + driver.add_cookie({"name": "test", "value": "data"}) + + # Get cookies + cookies = driver.get_cookies() + assert any(c.get("name") == "test" for c in cookies) + + # Delete cookie + driver.delete_cookie("test") + + # Verify deletion + cookies_after = driver.get_cookies() + assert not any(c.get("name") == "test" for c in cookies_after) + + def test_cookie_attributes(self, driver, pages): + """Test cookie with various attributes.""" + pages.load("blank.html") + + driver.add_cookie( + {"name": "attr_cookie", "value": "test_value", "path": "/", "secure": False} + ) + + cookies = driver.get_cookies() + cookie = next((c for c in cookies if c.get("name") == "attr_cookie"), None) + + assert cookie is not None + assert cookie.get("value") == "test_value" + + +class TestBidiBrowsingContexts: + """Test browsing context operations.""" + + @pytest.fixture(autouse=True) + def setup(self, driver): + """Setup for each test in this class.""" + driver.delete_all_cookies() + yield + # Cleanup: delete all cookies to prevent bleed-through + driver.delete_all_cookies() + + def test_create_new_window(self, driver): + """Test creating a new window context.""" + # Create new tab + new_context = driver.browsing_context.create(type=WindowTypes.TAB) + + try: + assert new_context is not None + finally: + driver.browsing_context.close(new_context) + + def test_navigation_in_context(self, driver, pages): + """Test navigation in a specific context.""" + pages.load("blank.html") + + # Navigate using the BiDi API with the current context + driver.browsing_context.navigate( + context=driver.current_window_handle, url=pages.url("blank.html") + ) + + # Verify page loaded + element = driver.find_element(By.TAG_NAME, "body") + assert element is not None diff --git a/py/test/selenium/webdriver/common/bidi_log_tests.py b/py/test/selenium/webdriver/common/bidi_log_tests.py new file mode 100644 index 0000000000000..fbbd3a8166b2d --- /dev/null +++ b/py/test/selenium/webdriver/common/bidi_log_tests.py @@ -0,0 +1,161 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import pytest + +from selenium.webdriver.support.ui import WebDriverWait + + +def test_log_module_initialized(driver): + """Test that the log module is initialized properly.""" + assert driver.script is not None + + +class TestBidiLogging: + """Test class for BiDi logging functionality.""" + + @pytest.fixture(autouse=True) + def setup(self, driver, pages): + """Setup for each test in this class.""" + pages.load("blank.html") + + def test_console_log_message(self, driver): + """Test capturing console.log messages.""" + log_entries = [] + + def callback(log_entry): + log_entries.append(log_entry) + + handler_id = driver.script.add_console_message_handler(callback) + + try: + driver.execute_script("console.log('test message');") + WebDriverWait(driver, 5).until(lambda _: log_entries) + + assert len(log_entries) > 0 + finally: + driver.script.remove_console_message_handler(handler_id) + + def test_console_multiple_messages(self, driver): + """Test capturing multiple console messages.""" + log_entries = [] + + handler_id = driver.script.add_console_message_handler(log_entries.append) + + try: + driver.execute_script( + """ + console.log('message 1'); + console.log('message 2'); + console.log('message 3'); + """ + ) + + WebDriverWait(driver, 5).until(lambda _: len(log_entries) >= 3) + + assert len(log_entries) >= 3 + finally: + driver.script.remove_console_message_handler(handler_id) + + def test_add_and_remove_handler(self, driver): + """Test adding and removing log handlers.""" + log_entries1 = [] + log_entries2 = [] + + handler_id1 = driver.script.add_console_message_handler(log_entries1.append) + handler_id2 = driver.script.add_console_message_handler(log_entries2.append) + + try: + driver.execute_script("console.log('first message');") + WebDriverWait(driver, 5).until( + lambda _: len(log_entries1) > 0 and len(log_entries2) > 0 + ) + + assert len(log_entries1) > 0 + assert len(log_entries2) > 0 + + # Remove first handler + driver.script.remove_console_message_handler(handler_id1) + + initial_count1 = len(log_entries1) + initial_count2 = len(log_entries2) + + # Trigger another message + driver.execute_script("console.log('second message');") + WebDriverWait(driver, 5).until(lambda _: len(log_entries2) > initial_count2) + + # First handler should not receive new messages + assert len(log_entries1) == initial_count1 + assert len(log_entries2) > initial_count2 + finally: + driver.script.remove_console_message_handler(handler_id2) + + def test_handler_receives_all_levels(self, driver): + """Test that a single handler can receive all log levels.""" + log_levels = [] + + def callback(entry): + log_levels.append(entry) + + handler_id = driver.script.add_console_message_handler(callback) + + try: + driver.execute_script( + """ + console.log('log'); + console.warn('warn'); + console.error('error'); + console.debug('debug'); + console.info('info'); + """ + ) + + WebDriverWait(driver, 5).until(lambda _: len(log_levels) >= 5) + + assert len(log_levels) >= 5 + finally: + driver.script.remove_console_message_handler(handler_id) + + def test_log_with_multiple_arguments(self, driver): + """Test console.log with multiple arguments.""" + log_entries = [] + + handler_id = driver.script.add_console_message_handler(log_entries.append) + + try: + driver.execute_script("console.log('arg1', 'arg2', 'arg3');") + WebDriverWait(driver, 5).until(lambda _: log_entries) + + assert len(log_entries) > 0 + finally: + driver.script.remove_console_message_handler(handler_id) + + def test_log_entry_attributes(self, driver): + """Test log entry has expected attributes.""" + log_entries = [] + + handler_id = driver.script.add_console_message_handler(log_entries.append) + + try: + driver.execute_script("console.log('test');") + WebDriverWait(driver, 5).until(lambda _: log_entries) + + assert len(log_entries) > 0 + assert hasattr(log_entries[0], "text") or hasattr(log_entries[0], "args") + finally: + driver.script.remove_console_message_handler(handler_id) diff --git a/py/test/selenium/webdriver/common/bidi_script_tests.py b/py/test/selenium/webdriver/common/bidi_script_tests.py index 35f8e455573be..5d5fc6ef9b780 100644 --- a/py/test/selenium/webdriver/common/bidi_script_tests.py +++ b/py/test/selenium/webdriver/common/bidi_script_tests.py @@ -17,6 +17,7 @@ import pytest +from selenium.common.exceptions import WebDriverException from selenium.webdriver.common.bidi.log import LogLevel from selenium.webdriver.common.bidi.script import RealmType, ResultOwnership from selenium.webdriver.common.by import By @@ -41,18 +42,21 @@ def test_logs_console_messages(driver, pages): pages.load("bidi/logEntryAdded.html") log_entries = [] - driver.script.add_console_message_handler(log_entries.append) + handler_id = driver.script.add_console_message_handler(log_entries.append) - driver.find_element(By.ID, "jsException").click() - driver.find_element(By.ID, "consoleLog").click() + try: + driver.find_element(By.ID, "jsException").click() + driver.find_element(By.ID, "consoleLog").click() - WebDriverWait(driver, 5).until(lambda _: log_entries) + WebDriverWait(driver, 5).until(lambda _: log_entries) - log_entry = log_entries[0] - assert log_entry.level == LogLevel.INFO - assert log_entry.method == "log" - assert log_entry.text == "Hello, world!" - assert log_entry.type_ == "console" + log_entry = log_entries[0] + assert log_entry.level == LogLevel.INFO + assert log_entry.method == "log" + assert log_entry.text == "Hello, world!" + assert log_entry.type_ == "console" + finally: + driver.script.remove_console_message_handler(handler_id) def test_logs_console_errors(driver, pages): @@ -63,34 +67,41 @@ def log_error(entry): if entry.level == LogLevel.ERROR: log_entries.append(entry) - driver.script.add_console_message_handler(log_error) + handler_id = driver.script.add_console_message_handler(log_error) - driver.find_element(By.ID, "consoleLog").click() - driver.find_element(By.ID, "consoleError").click() + try: + driver.find_element(By.ID, "consoleLog").click() + driver.find_element(By.ID, "consoleError").click() - WebDriverWait(driver, 5).until(lambda _: log_entries) + WebDriverWait(driver, 5).until(lambda _: log_entries) - assert len(log_entries) == 1 + assert len(log_entries) == 1 - log_entry = log_entries[0] - assert log_entry.level == LogLevel.ERROR - assert log_entry.method == "error" - assert log_entry.text == "I am console error" - assert log_entry.type_ == "console" + log_entry = log_entries[0] + assert log_entry.level == LogLevel.ERROR + assert log_entry.method == "error" + assert log_entry.text == "I am console error" + assert log_entry.type_ == "console" + finally: + driver.script.remove_console_message_handler(handler_id) def test_logs_multiple_console_messages(driver, pages): pages.load("bidi/logEntryAdded.html") log_entries = [] - driver.script.add_console_message_handler(log_entries.append) - driver.script.add_console_message_handler(log_entries.append) + handler_id1 = driver.script.add_console_message_handler(log_entries.append) + handler_id2 = driver.script.add_console_message_handler(log_entries.append) - driver.find_element(By.ID, "jsException").click() - driver.find_element(By.ID, "consoleLog").click() + try: + driver.find_element(By.ID, "jsException").click() + driver.find_element(By.ID, "consoleLog").click() - WebDriverWait(driver, 5).until(lambda _: len(log_entries) > 1) - assert len(log_entries) == 2 + WebDriverWait(driver, 5).until(lambda _: len(log_entries) > 1) + assert len(log_entries) == 2 + finally: + driver.script.remove_console_message_handler(handler_id1) + driver.script.remove_console_message_handler(handler_id2) def test_removes_console_message_handler(driver, pages): @@ -99,32 +110,41 @@ def test_removes_console_message_handler(driver, pages): log_entries1 = [] log_entries2 = [] - id = driver.script.add_console_message_handler(log_entries1.append) - driver.script.add_console_message_handler(log_entries2.append) + id1 = driver.script.add_console_message_handler(log_entries1.append) + id2 = driver.script.add_console_message_handler(log_entries2.append) - driver.find_element(By.ID, "consoleLog").click() - WebDriverWait(driver, 5).until(lambda _: len(log_entries1) and len(log_entries2)) + try: + driver.find_element(By.ID, "consoleLog").click() + WebDriverWait(driver, 5).until( + lambda _: len(log_entries1) and len(log_entries2) + ) - driver.script.remove_console_message_handler(id) - driver.find_element(By.ID, "consoleLog").click() + driver.script.remove_console_message_handler(id1) + driver.find_element(By.ID, "consoleLog").click() - WebDriverWait(driver, 5).until(lambda _: len(log_entries2) == 2) - assert len(log_entries1) == 1 + WebDriverWait(driver, 5).until(lambda _: len(log_entries2) == 2) + assert len(log_entries1) == 1 + finally: + driver.script.remove_console_message_handler(id1) + driver.script.remove_console_message_handler(id2) def test_javascript_error_messages(driver, pages): pages.load("bidi/logEntryAdded.html") log_entries = [] - driver.script.add_javascript_error_handler(log_entries.append) + handler_id = driver.script.add_javascript_error_handler(log_entries.append) - driver.find_element(By.ID, "jsException").click() - WebDriverWait(driver, 5).until(lambda _: log_entries) + try: + driver.find_element(By.ID, "jsException").click() + WebDriverWait(driver, 5).until(lambda _: log_entries) - log_entry = log_entries[0] - assert log_entry.text == "Error: Not working" - assert log_entry.level == LogLevel.ERROR - assert log_entry.type_ == "javascript" + log_entry = log_entries[0] + assert log_entry.text == "Error: Not working" + assert log_entry.level == LogLevel.ERROR + assert log_entry.type_ == "javascript" + finally: + driver.script.remove_javascript_error_handler(handler_id) def test_removes_javascript_message_handler(driver, pages): @@ -133,17 +153,23 @@ def test_removes_javascript_message_handler(driver, pages): log_entries1 = [] log_entries2 = [] - id = driver.script.add_javascript_error_handler(log_entries1.append) - driver.script.add_javascript_error_handler(log_entries2.append) + id1 = driver.script.add_javascript_error_handler(log_entries1.append) + id2 = driver.script.add_javascript_error_handler(log_entries2.append) - driver.find_element(By.ID, "jsException").click() - WebDriverWait(driver, 5).until(lambda _: len(log_entries1) and len(log_entries2)) + try: + driver.find_element(By.ID, "jsException").click() + WebDriverWait(driver, 5).until( + lambda _: len(log_entries1) and len(log_entries2) + ) - driver.script.remove_javascript_error_handler(id) - driver.find_element(By.ID, "jsException").click() + driver.script.remove_javascript_error_handler(id1) + driver.find_element(By.ID, "jsException").click() - WebDriverWait(driver, 5).until(lambda _: len(log_entries2) == 2) - assert len(log_entries1) == 1 + WebDriverWait(driver, 5).until(lambda _: len(log_entries2) == 2) + assert len(log_entries1) == 1 + finally: + driver.script.remove_javascript_error_handler(id1) + driver.script.remove_javascript_error_handler(id2) def test_add_preload_script(driver, pages): @@ -159,7 +185,9 @@ def test_add_preload_script(driver, pages): # Check if the preload script was executed result = driver.script._evaluate( - "window.preloadExecuted", {"context": driver.current_window_handle}, await_promise=False + "window.preloadExecuted", + {"context": driver.current_window_handle}, + await_promise=False, ) assert result.result["value"] is True @@ -168,15 +196,21 @@ def test_add_preload_script_with_arguments(driver, pages): """Test adding a preload script with channel arguments.""" function_declaration = "(channelFunc) => { channelFunc('test_value'); window.preloadValue = 'received'; }" - arguments = [{"type": "channel", "value": {"channel": "test-channel", "ownership": "root"}}] + arguments = [ + {"type": "channel", "value": {"channel": "test-channel", "ownership": "root"}} + ] - script_id = driver.script._add_preload_script(function_declaration, arguments=arguments) + script_id = driver.script._add_preload_script( + function_declaration, arguments=arguments + ) assert script_id is not None pages.load("blank.html") result = driver.script._evaluate( - "window.preloadValue", {"context": driver.current_window_handle}, await_promise=False + "window.preloadValue", + {"context": driver.current_window_handle}, + await_promise=False, ) assert result.result["value"] == "received" @@ -186,13 +220,17 @@ def test_add_preload_script_with_contexts(driver, pages): function_declaration = "() => { window.contextSpecific = true; }" contexts = [driver.current_window_handle] - script_id = driver.script._add_preload_script(function_declaration, contexts=contexts) + script_id = driver.script._add_preload_script( + function_declaration, contexts=contexts + ) assert script_id is not None pages.load("blank.html") result = driver.script._evaluate( - "window.contextSpecific", {"context": driver.current_window_handle}, await_promise=False + "window.contextSpecific", + {"context": driver.current_window_handle}, + await_promise=False, ) assert result.result["value"] is True @@ -200,36 +238,50 @@ def test_add_preload_script_with_contexts(driver, pages): def test_add_preload_script_with_user_contexts(driver, pages): """Test adding a preload script with user contexts.""" function_declaration = "() => { window.contextSpecific = true; }" + original_handle = driver.current_window_handle user_context = driver.browser.create_user_context() context1 = driver.browsing_context.create(type="window", user_context=user_context) driver.switch_to.window(context1) - user_contexts = [user_context] + try: + user_contexts = [user_context] - script_id = driver.script._add_preload_script(function_declaration, user_contexts=user_contexts) - assert script_id is not None + script_id = driver.script._add_preload_script( + function_declaration, user_contexts=user_contexts + ) + assert script_id is not None - pages.load("blank.html") + pages.load("blank.html") - result = driver.script._evaluate( - "window.contextSpecific", {"context": driver.current_window_handle}, await_promise=False - ) - assert result.result["value"] is True + result = driver.script._evaluate( + "window.contextSpecific", + {"context": driver.current_window_handle}, + await_promise=False, + ) + assert result.result["value"] is True + finally: + driver.switch_to.window(original_handle) + driver.browsing_context.close(context1) + driver.browser.remove_user_context(user_context) def test_add_preload_script_with_sandbox(driver, pages): """Test adding a preload script with sandbox.""" function_declaration = "() => { window.sandboxScript = true; }" - script_id = driver.script._add_preload_script(function_declaration, sandbox="test-sandbox") + script_id = driver.script._add_preload_script( + function_declaration, sandbox="test-sandbox" + ) assert script_id is not None pages.load("blank.html") # calling evaluate without sandbox should return undefined result = driver.script._evaluate( - "window.sandboxScript", {"context": driver.current_window_handle}, await_promise=False + "window.sandboxScript", + {"context": driver.current_window_handle}, + await_promise=False, ) assert result.result["type"] == "undefined" @@ -246,8 +298,12 @@ def test_add_preload_script_invalid_arguments(driver): """Test that providing both contexts and user_contexts raises an error.""" function_declaration = "() => {}" - with pytest.raises(ValueError, match="Cannot specify both contexts and user_contexts"): - driver.script._add_preload_script(function_declaration, contexts=["context1"], user_contexts=["user1"]) + with pytest.raises( + ValueError, match="Cannot specify both contexts and user_contexts" + ): + driver.script._add_preload_script( + function_declaration, contexts=["context1"], user_contexts=["user1"] + ) def test_remove_preload_script(driver, pages): @@ -262,7 +318,9 @@ def test_remove_preload_script(driver, pages): # The script should not have executed result = driver.script._evaluate( - "typeof window.removableScript", {"context": driver.current_window_handle}, await_promise=False + "typeof window.removableScript", + {"context": driver.current_window_handle}, + await_promise=False, ) assert result.result["value"] == "undefined" @@ -271,7 +329,9 @@ def test_evaluate_expression(driver, pages): """Test evaluating a simple expression.""" pages.load("blank.html") - result = driver.script._evaluate("1 + 2", {"context": driver.current_window_handle}, await_promise=False) + result = driver.script._evaluate( + "1 + 2", {"context": driver.current_window_handle}, await_promise=False + ) assert result.realm is not None assert result.result["type"] == "number" @@ -284,7 +344,9 @@ def test_evaluate_with_await_promise(driver, pages): pages.load("blank.html") result = driver.script._evaluate( - "Promise.resolve(42)", {"context": driver.current_window_handle}, await_promise=True + "Promise.resolve(42)", + {"context": driver.current_window_handle}, + await_promise=True, ) assert result.result["type"] == "number" @@ -296,7 +358,9 @@ def test_evaluate_with_exception(driver, pages): pages.load("blank.html") result = driver.script._evaluate( - "throw new Error('Test error')", {"context": driver.current_window_handle}, await_promise=False + "throw new Error('Test error')", + {"context": driver.current_window_handle}, + await_promise=False, ) assert result.exception_details is not None @@ -334,7 +398,11 @@ def test_evaluate_with_serialization_options(driver, pages): """Test evaluating with serialization options.""" pages.load("shadowRootPage.html") - serialization_options = {"maxDomDepth": 2, "maxObjectDepth": 2, "includeShadowTree": "all"} + serialization_options = { + "maxDomDepth": 2, + "maxObjectDepth": 2, + "includeShadowTree": "all", + } result = driver.script._evaluate( "document.body", @@ -386,7 +454,9 @@ def test_call_function_with_this(driver, pages): # First set up an object driver.script._evaluate( - "window.testObj = { value: 10 }", {"context": driver.current_window_handle}, await_promise=False + "window.testObj = { value: 10 }", + {"context": driver.current_window_handle}, + await_promise=False, ) result = driver.script._call_function( @@ -419,7 +489,11 @@ def test_call_function_with_serialization_options(driver, pages): """Test calling a function with serialization options.""" pages.load("shadowRootPage.html") - serialization_options = {"maxDomDepth": 2, "maxObjectDepth": 2, "includeShadowTree": "all"} + serialization_options = { + "maxDomDepth": 2, + "maxObjectDepth": 2, + "includeShadowTree": "all", + } result = driver.script._call_function( "() => document.body", @@ -455,7 +529,9 @@ def test_call_function_with_await_promise(driver, pages): pages.load("blank.html") result = driver.script._call_function( - "() => Promise.resolve('async result')", await_promise=True, target={"context": driver.current_window_handle} + "() => Promise.resolve('async result')", + await_promise=True, + target={"context": driver.current_window_handle}, ) assert result.result["type"] == "string" @@ -534,7 +610,10 @@ def test_disown_handles(driver, pages): # Create an object with root ownership (this will return a handle) result = driver.script._evaluate( - "({foo: 'bar'})", target={"context": driver.current_window_handle}, await_promise=False, result_ownership="root" + "({foo: 'bar'})", + target={"context": driver.current_window_handle}, + await_promise=False, + result_ownership="root", ) handle = result.result["handle"] @@ -551,7 +630,9 @@ def test_disown_handles(driver, pages): assert result_before.result["value"] == "bar" # Disown the handle - driver.script._disown(handles=[handle], target={"context": driver.current_window_handle}) + driver.script._disown( + handles=[handle], target={"context": driver.current_window_handle} + ) # Try using the disowned handle (this should fail) with pytest.raises(Exception): @@ -814,8 +895,6 @@ def test_execute_script_with_exception(driver, pages): """Test executing script that throws an exception.""" pages.load("blank.html") - from selenium.common.exceptions import WebDriverException - with pytest.raises(WebDriverException) as exc_info: driver.script.execute( """() => { @@ -870,3 +949,438 @@ def test_execute_script_with_nested_objects(driver, pages): assert value_dict["userName"] == "John" assert value_dict["userAge"] == 30 assert value_dict["hobbyCount"] == 2 + + +class TestBidiScriptExecution: + """Test script execution via execute_script.""" + + @pytest.fixture(autouse=True) + def setup(self, driver, pages): + """Setup for each test.""" + pages.load("blank.html") + + def test_execute_script_returns_string(self, driver): + """Test executing script that returns string.""" + result = driver.execute_script("return 'hello';") + assert result == "hello" + + def test_execute_script_returns_number(self, driver): + """Test executing script that returns number.""" + result = driver.execute_script("return 42;") + assert result == 42 + + def test_execute_script_returns_boolean(self, driver): + """Test executing script that returns boolean.""" + result = driver.execute_script("return true;") + assert result is True + + def test_execute_script_returns_null(self, driver): + """Test executing script that returns null.""" + result = driver.execute_script("return null;") + assert result is None + + def test_execute_script_returns_object(self, driver): + """Test executing script that returns object.""" + result = driver.execute_script("return {x: 1, y: 2};") + assert isinstance(result, dict) + assert result["x"] == 1 + + def test_execute_script_returns_array(self, driver): + """Test executing script that returns array.""" + result = driver.execute_script("return [1, 2, 3, 4, 5];") + assert isinstance(result, list) + assert len(result) == 5 + + def test_execute_script_dom_query(self, driver, pages): + """Test executing script that queries DOM.""" + pages.load("formPage.html") + result = driver.execute_script( + "return document.querySelectorAll('input').length;" + ) + assert result > 0 + + def test_execute_script_with_arguments(self, driver): + """Test executing script with arguments.""" + result = driver.execute_script("return arguments[0] * arguments[1];", 3, 5) + assert result == 15 + + +class TestBidiScriptGlobalState: + """Test script execution with global state management.""" + + @pytest.fixture(autouse=True) + def setup(self, driver, pages): + """Setup for each test.""" + pages.load("blank.html") + + def test_global_state_persistence(self, driver): + """Test that global state persists across script calls.""" + driver.execute_script("window.testVar = 42;") + result = driver.execute_script("return window.testVar;") + assert result == 42 + + def test_multiple_global_variables(self, driver): + """Test managing multiple global variables.""" + driver.execute_script( + """ + window.var1 = 'first'; + window.var2 = 'second'; + window.var3 = 'third'; + """ + ) + + result = driver.execute_script( + """ + return { + v1: window.var1, + v2: window.var2, + v3: window.var3 + }; + """ + ) + + assert result["v1"] == "first" + assert result["v2"] == "second" + assert result["v3"] == "third" + + def test_function_definition_in_global_scope(self, driver): + """Test defining functions in global scope.""" + driver.execute_script( + """ + window.multiply = function(a, b) { + return a * b; + }; + """ + ) + + result = driver.execute_script("return window.multiply(3, 7);") + assert result == 21 + + def test_complex_object_in_global_scope(self, driver): + """Test storing complex objects globally.""" + driver.execute_script( + """ + window.data = { + users: [ + {name: 'Alice', age: 30}, + {name: 'Bob', age: 25} + ], + metadata: { + version: '1.0', + timestamp: Date.now() + } + }; + """ + ) + + result = driver.execute_script("return window.data.users.length;") + assert result == 2 + + +class TestBidiScriptPreloadScripts: + """Test preload script lifecycle and edge cases.""" + + @pytest.fixture(autouse=True) + def setup(self, driver, pages): + """Setup for each test.""" + pages.load("blank.html") + + def test_multiple_preload_scripts(self, driver, pages): + """Test adding multiple preload scripts.""" + id1 = driver.script._add_preload_script("() => { window.test1 = 'loaded'; }") + id2 = driver.script._add_preload_script("() => { window.test2 = 'loaded'; }") + + try: + pages.load("blank.html") + + result1 = driver.script._evaluate( + "window.test1", + {"context": driver.current_window_handle}, + await_promise=False, + ) + result2 = driver.script._evaluate( + "window.test2", + {"context": driver.current_window_handle}, + await_promise=False, + ) + + assert result1.result["value"] == "loaded" + assert result2.result["value"] == "loaded" + finally: + driver.script._remove_preload_script(script_id=id1) + driver.script._remove_preload_script(script_id=id2) + + def test_preload_script_with_function(self, driver, pages): + """Test preload script defining functions.""" + script_id = driver.script._add_preload_script( + "() => { window.customFunc = (x) => x * 2; }" + ) + + try: + pages.load("blank.html") + result = driver.script._evaluate( + "window.customFunc(5)", + {"context": driver.current_window_handle}, + await_promise=False, + ) + assert result.result["value"] == 10 + finally: + driver.script._remove_preload_script(script_id=script_id) + + def test_preload_script_removal_prevents_execution(self, driver, pages): + """Test that removing preload script prevents its execution.""" + script_id = driver.script._add_preload_script( + "() => { window.shouldNotExist = true; }" + ) + driver.script._remove_preload_script(script_id=script_id) + + pages.load("blank.html") + result = driver.script._evaluate( + "typeof window.shouldNotExist", + {"context": driver.current_window_handle}, + await_promise=False, + ) + assert result.result["value"] == "undefined" + + def test_preload_script_with_dom_manipulation(self, driver, pages): + """Test preload script that manipulates DOM.""" + script_id = driver.script._add_preload_script( + """ + () => { + document.addEventListener('DOMContentLoaded', function() { + var div = document.createElement('div'); + div.id = 'injected-element'; + div.textContent = 'injected'; + document.body.appendChild(div); + }); + } + """ + ) + + try: + pages.load("blank.html") + element = driver.find_element(By.ID, "injected-element") + assert element is not None + assert element.text == "injected" + finally: + driver.script._remove_preload_script(script_id=script_id) + + +class TestBidiScriptContextManagement: + """Test script execution across browsing contexts.""" + + @pytest.fixture(autouse=True) + def setup(self, driver, pages): + """Setup for each test.""" + pages.load("blank.html") + + def test_script_executes_in_current_context(self, driver): + """Test that scripts execute in the current browsing context.""" + # Set variable in current context + driver.execute_script("window.contextVar = 'main';") + + # Verify it's accessible + result = driver.execute_script("return window.contextVar;") + assert result == "main" + + def test_multiple_navigations_maintain_context(self, driver, pages): + """Test script context changes with navigation.""" + # Load first page + pages.load("blank.html") + driver.execute_script("window.page = 'blank';") + + # Load second page - context should reset + pages.load("formPage.html") + result = driver.execute_script("return window.page;") + assert result is None + + # Set new value + driver.execute_script("window.page = 'form';") + result = driver.execute_script("return window.page;") + assert result == "form" + + def test_script_can_access_dom_elements(self, driver, pages): + """Test that scripts can access and manipulate DOM.""" + pages.load("formPage.html") + + # Find element count + result = driver.execute_script( + """ + return document.querySelectorAll('input[type="text"]').length; + """ + ) + assert result > 0 + + def test_script_context_with_console_handler(self, driver, pages): + """Test script execution with console message handler active.""" + log_entries = [] + handler_id = driver.script.add_console_message_handler(log_entries.append) + + try: + pages.load("bidi/logEntryAdded.html") + driver.execute_script("console.log('test message');") + + # Give some time for handler to capture + WebDriverWait(driver, 3).until(lambda _: log_entries) + assert len(log_entries) > 0 + finally: + driver.script.remove_console_message_handler(handler_id) + + def test_script_error_handler_active(self, driver, pages): + """Test script execution with error handler active.""" + errors = [] + handler_id = driver.script.add_javascript_error_handler(errors.append) + + try: + pages.load("bidi/logEntryAdded.html") + # Click element that triggers JS error + driver.find_element(By.ID, "jsException").click() + + # Give time for error handler to capture + WebDriverWait(driver, 5).until(lambda _: errors) + assert len(errors) > 0 + finally: + driver.script.remove_javascript_error_handler(handler_id) + + +class TestBidiScriptComplexOperations: + """Test complex script operations and edge cases.""" + + @pytest.fixture(autouse=True) + def setup(self, driver, pages): + """Setup for each test.""" + pages.load("blank.html") + + def test_execute_script_with_timeout(self, driver): + """Test script execution within time constraints.""" + # Execute script that completes quickly + result = driver.execute_script( + """ + return new Promise((resolve) => { + setTimeout(() => resolve('completed'), 10); + }); + """ + ) + # Note: synchronous execute_script may not wait for promises + # This just tests that the method handles the call + assert result is not None + + def test_execute_script_with_dom_creation(self, driver): + """Test script that creates and manipulates DOM.""" + driver.execute_script( + """ + const div = document.createElement('div'); + div.id = 'created-element'; + div.textContent = 'Created by script'; + document.body.appendChild(div); + """ + ) + + # Verify element was created + result = driver.execute_script( + """ + const elem = document.getElementById('created-element'); + return elem ? elem.textContent : null; + """ + ) + assert result == "Created by script" + + def test_execute_script_with_nested_objects(self, driver): + """Test script that returns deeply nested objects.""" + result = driver.execute_script( + """ + return { + level1: { + level2: { + level3: { + value: 'deep' + } + } + } + }; + """ + ) + + assert result["level1"]["level2"]["level3"]["value"] == "deep" + + def test_execute_script_with_exception_handling(self, driver): + """Test script that handles exceptions internally.""" + result = driver.execute_script( + """ + try { + throw new Error('test error'); + } catch (e) { + return 'error caught: ' + e.message; + } + """ + ) + assert "error caught" in result + + +class TestBidiScriptErrorHandling: + """Test script error and logging scenarios.""" + + @pytest.fixture(autouse=True) + def setup(self, driver, pages): + """Setup for each test.""" + pages.load("blank.html") + + def test_script_error_handler_captures_errors(self, driver, pages): + """Test that error handler can capture script errors.""" + errors = [] + + def error_handler(entry): + errors.append(entry) + + handler_id = driver.script.add_javascript_error_handler(error_handler) + + try: + pages.load("bidi/logEntryAdded.html") + driver.find_element(By.ID, "jsException").click() + + WebDriverWait(driver, 5).until(lambda _: errors) + assert len(errors) > 0 + finally: + driver.script.remove_javascript_error_handler(handler_id) + + def test_multiple_error_handlers(self, driver, pages): + """Test multiple error handlers can be registered.""" + errors1 = [] + errors2 = [] + + handler_id1 = driver.script.add_javascript_error_handler(errors1.append) + handler_id2 = driver.script.add_javascript_error_handler(errors2.append) + + try: + pages.load("bidi/logEntryAdded.html") + driver.find_element(By.ID, "jsException").click() + + # Both handlers should receive events when error occurs + WebDriverWait(driver, 5).until( + lambda _: len(errors1) > 0 and len(errors2) > 0 + ) + assert len(errors1) > 0 + assert len(errors2) > 0 + finally: + driver.script.remove_javascript_error_handler(handler_id1) + driver.script.remove_javascript_error_handler(handler_id2) + + def test_console_message_with_logging(self, driver, pages): + """Test console message handler with actual logging.""" + log_entries = [] + handler_id = driver.script.add_console_message_handler(log_entries.append) + + try: + pages.load("bidi/logEntryAdded.html") + driver.find_element(By.ID, "consoleLog").click() + + WebDriverWait(driver, 5).until(lambda _: log_entries) + assert len(log_entries) > 0 + finally: + driver.script.remove_console_message_handler(handler_id) + + def test_execute_script_syntax_error(self, driver): + """Test executing script with syntax errors.""" + # This should raise an exception + with pytest.raises(Exception): + driver.execute_script("{{invalid syntax}}") diff --git a/py/test/selenium/webdriver/common/bidi_storage_tests.py b/py/test/selenium/webdriver/common/bidi_storage_tests.py index 01fb375f7c39a..157df78dd3cd9 100644 --- a/py/test/selenium/webdriver/common/bidi_storage_tests.py +++ b/py/test/selenium/webdriver/common/bidi_storage_tests.py @@ -98,7 +98,9 @@ def test_get_cookie_by_name(self, driver, pages, webserver): driver.add_cookie({"name": key, "value": value}) # Test - cookie_filter = CookieFilter(name=key, value=BytesValue(BytesValue.TYPE_STRING, "set")) + cookie_filter = CookieFilter( + name=key, value=BytesValue(BytesValue.TYPE_STRING, "set") + ) result = driver.storage.get_cookies(filter=cookie_filter) @@ -120,14 +122,18 @@ def test_get_cookie_in_default_user_context(self, driver, pages, webserver): driver.add_cookie({"name": key, "value": value}) # Test - cookie_filter = CookieFilter(name=key, value=BytesValue(BytesValue.TYPE_STRING, "set")) + cookie_filter = CookieFilter( + name=key, value=BytesValue(BytesValue.TYPE_STRING, "set") + ) driver.switch_to.new_window(WindowTypes.WINDOW) descriptor = BrowsingContextPartitionDescriptor(driver.current_window_handle) params = cookie_filter - result_after_switching_context = driver.storage.get_cookies(filter=params, partition=descriptor) + result_after_switching_context = driver.storage.get_cookies( + filter=params, partition=descriptor + ) assert len(result_after_switching_context.cookies) > 0 assert result_after_switching_context.cookies[0].value.value == value @@ -158,15 +164,21 @@ def test_get_cookie_in_a_user_context(self, driver, pages, webserver): descriptor = StorageKeyPartitionDescriptor(user_context=user_context) - parameters = PartialCookie(key, BytesValue(BytesValue.TYPE_STRING, value), webserver.host) + parameters = PartialCookie( + key, BytesValue(BytesValue.TYPE_STRING, value), webserver.host + ) driver.storage.set_cookie(cookie=parameters, partition=descriptor) # Test - cookie_filter = CookieFilter(name=key, value=BytesValue(BytesValue.TYPE_STRING, "set")) + cookie_filter = CookieFilter( + name=key, value=BytesValue(BytesValue.TYPE_STRING, "set") + ) # Create a new window with the user context - new_window = driver.browsing_context.create(type=WindowTypes.TAB, user_context=user_context) + new_window = driver.browsing_context.create( + type=WindowTypes.TAB, user_context=user_context + ) driver.switch_to.window(new_window) @@ -181,9 +193,13 @@ def test_get_cookie_in_a_user_context(self, driver, pages, webserver): driver.switch_to.window(window_handle) - browsing_context_partition_descriptor = BrowsingContextPartitionDescriptor(window_handle) + browsing_context_partition_descriptor = BrowsingContextPartitionDescriptor( + window_handle + ) - result1 = driver.storage.get_cookies(filter=cookie_filter, partition=browsing_context_partition_descriptor) + result1 = driver.storage.get_cookies( + filter=cookie_filter, partition=browsing_context_partition_descriptor + ) assert len(result1.cookies) == 0 @@ -198,7 +214,9 @@ def test_add_cookie(self, driver, pages, webserver): key = generate_unique_key() value = "foo" - parameters = PartialCookie(key, BytesValue(BytesValue.TYPE_STRING, value), webserver.host) + parameters = PartialCookie( + key, BytesValue(BytesValue.TYPE_STRING, value), webserver.host + ) assert_cookie_is_not_present_with_name(driver, key) # Test @@ -223,7 +241,14 @@ def test_add_and_get_cookie(self, driver, pages, webserver): path = "/simpleTest.html" cookie = PartialCookie( - "fish", value, domain, path=path, http_only=True, secure=False, same_site=SameSite.LAX, expiry=expiry + "fish", + value, + domain, + path=path, + http_only=True, + secure=False, + same_site=SameSite.LAX, + expiry=expiry, ) # Test @@ -336,10 +361,18 @@ def test_add_cookies_with_different_paths(self, driver, pages, webserver): assert_no_cookies_are_present(driver) cookie1 = PartialCookie( - "fish", BytesValue(BytesValue.TYPE_STRING, "cod"), webserver.host, path="/simpleTest.html" + "fish", + BytesValue(BytesValue.TYPE_STRING, "cod"), + webserver.host, + path="/simpleTest.html", ) - cookie2 = PartialCookie("planet", BytesValue(BytesValue.TYPE_STRING, "earth"), webserver.host, path="/") + cookie2 = PartialCookie( + "planet", + BytesValue(BytesValue.TYPE_STRING, "earth"), + webserver.host, + path="/", + ) # Test driver.storage.set_cookie(cookie=cookie1) @@ -353,3 +386,377 @@ def test_add_cookies_with_different_paths(self, driver, pages, webserver): driver.get(pages.url("formPage.html")) assert_cookie_is_not_present_with_name(driver, "fish") + + def test_delete_cookies_by_name_filter(self, driver, pages, webserver): + """Test deleting cookies with specific name filter.""" + assert_no_cookies_are_present(driver) + + key1 = generate_unique_key() + key2 = generate_unique_key() + key3 = generate_unique_key() + + driver.add_cookie({"name": key1, "value": "value1"}) + driver.add_cookie({"name": key2, "value": "value2"}) + driver.add_cookie({"name": key3, "value": "value3"}) + + # Delete only key1 + driver.storage.delete_cookies(filter=CookieFilter(name=key1)) + + # Verify + assert_cookie_is_not_present_with_name(driver, key1) + assert_cookie_is_present_with_name(driver, key2) + assert_cookie_is_present_with_name(driver, key3) + + def test_delete_cookies_multiple_filters(self, driver, pages, webserver): + """Test deleting cookies with multiple filter criteria.""" + assert_no_cookies_are_present(driver) + + key = "multi_filter_delete_test" + value = BytesValue(BytesValue.TYPE_STRING, "test_value") + + # Create two cookies with same name but different http_only attributes + # This ensures the http_only filter actually affects which cookies are deleted + cookie1 = PartialCookie(key, value, webserver.host, http_only=True) + cookie2 = PartialCookie(key, value, webserver.host, http_only=False) + + driver.storage.set_cookie(cookie=cookie1) + driver.storage.set_cookie(cookie=cookie2) + + # Delete only http_only cookies - the http_only filter should actually matter here + driver.storage.delete_cookies(filter=CookieFilter(name=key, http_only=True)) + + # Verify - only the http_only=True cookie should be deleted + result = driver.storage.get_cookies(filter=CookieFilter(name=key)) + + # Should have one cookie remaining (the http_only=False one) + assert len(result.cookies) == 1 + assert result.cookies[0].http_only is False + + def test_delete_cookies_empty_filter(self, driver, pages, webserver): + """Test deleting with empty filter deletes all cookies.""" + assert_no_cookies_are_present(driver) + + # Add multiple cookies + for i in range(3): + driver.add_cookie({"name": f"cookie_{i}", "value": f"value_{i}"}) + + assert_some_cookies_are_present(driver) + + # Delete with empty filter + driver.storage.delete_cookies(filter=CookieFilter()) + + # Verify all deleted + assert_no_cookies_are_present(driver) + + def test_set_cookie_with_http_only_attribute(self, driver, pages, webserver): + """Test setting a cookie with http_only attribute.""" + assert_no_cookies_are_present(driver) + + key = "http_only_cookie" + value = BytesValue(BytesValue.TYPE_STRING, "protected") + + cookie = PartialCookie(key, value, webserver.host, http_only=True) + + # Test + driver.storage.set_cookie(cookie=cookie) + + # Verify + cookie_filter = CookieFilter(name=key, http_only=True) + result = driver.storage.get_cookies(filter=cookie_filter) + + assert len(result.cookies) > 0 + assert result.cookies[0].http_only is True + + def test_set_cookie_with_secure_attribute(self, driver, pages, webserver): + """Test setting a cookie with secure attribute.""" + assert_no_cookies_are_present(driver) + + key = "secure_cookie" + value = BytesValue(BytesValue.TYPE_STRING, "encrypted") + + cookie = PartialCookie(key, value, webserver.host, secure=True) + + # Test + driver.storage.set_cookie(cookie=cookie) + + # Verify + cookie_filter = CookieFilter(name=key, secure=True) + result = driver.storage.get_cookies(filter=cookie_filter) + + assert len(result.cookies) > 0 + assert result.cookies[0].secure is True + + def test_set_cookie_with_same_site_strict(self, driver, pages, webserver): + """Test setting a cookie with SameSite=Strict.""" + assert_no_cookies_are_present(driver) + + key = "samesite_strict" + value = BytesValue(BytesValue.TYPE_STRING, "strict") + + cookie = PartialCookie(key, value, webserver.host, same_site=SameSite.STRICT) + + # Test + driver.storage.set_cookie(cookie=cookie) + + # Verify + cookie_filter = CookieFilter(name=key, same_site=SameSite.STRICT) + result = driver.storage.get_cookies(filter=cookie_filter) + + assert len(result.cookies) > 0 + assert result.cookies[0].same_site == SameSite.STRICT + + def test_set_cookie_with_same_site_lax(self, driver, pages, webserver): + """Test setting a cookie with SameSite=Lax.""" + assert_no_cookies_are_present(driver) + + key = "samesite_lax" + value = BytesValue(BytesValue.TYPE_STRING, "lax") + + cookie = PartialCookie(key, value, webserver.host, same_site=SameSite.LAX) + + # Test + driver.storage.set_cookie(cookie=cookie) + + # Verify + cookie_filter = CookieFilter(name=key, same_site=SameSite.LAX) + result = driver.storage.get_cookies(filter=cookie_filter) + + assert len(result.cookies) > 0 + assert result.cookies[0].same_site == SameSite.LAX + + def test_set_cookie_with_same_site_none(self, driver, pages, webserver): + """Test setting a cookie with SameSite=None (requires Secure).""" + assert_no_cookies_are_present(driver) + + key = "samesite_none" + value = BytesValue(BytesValue.TYPE_STRING, "none") + + # SameSite=None typically requires secure=True + cookie = PartialCookie( + key, value, webserver.host, same_site=SameSite.NONE, secure=True + ) + + # Test + driver.storage.set_cookie(cookie=cookie) + + # Verify + cookie_filter = CookieFilter(name=key, same_site=SameSite.NONE) + result = driver.storage.get_cookies(filter=cookie_filter) + + assert len(result.cookies) > 0 + assert result.cookies[0].same_site == SameSite.NONE + + def test_set_cookie_with_path_and_domain(self, driver, pages, webserver): + """Test setting a cookie with specific path and domain.""" + assert_no_cookies_are_present(driver) + + key = "path_domain_cookie" + value = BytesValue(BytesValue.TYPE_STRING, "scoped") + path = "/simpleTest.html" + + cookie = PartialCookie(key, value, webserver.host, path=path) + + # Test + driver.storage.set_cookie(cookie=cookie) + + # Verify + cookie_filter = CookieFilter(name=key, path=path) + result = driver.storage.get_cookies(filter=cookie_filter) + + assert len(result.cookies) > 0 + assert result.cookies[0].path == path + assert result.cookies[0].domain == webserver.host + + def test_set_cookie_with_future_expiry(self, driver, pages, webserver): + """Test setting a cookie with a future expiry date.""" + assert_no_cookies_are_present(driver) + + key = "future_expiry_cookie" + value = BytesValue(BytesValue.TYPE_STRING, "future") + + # Set expiry to 1 hour from now + future_expiry = int(time.time() + 3600) + + cookie = PartialCookie(key, value, webserver.host, expiry=future_expiry) + + # Test + driver.storage.set_cookie(cookie=cookie) + + # Verify + cookie_filter = CookieFilter(name=key) + result = driver.storage.get_cookies(filter=cookie_filter) + + assert len(result.cookies) > 0 + assert result.cookies[0].expiry == future_expiry + + def test_set_cookie_with_string_value(self, driver, pages, webserver): + """Test setting a cookie with string value (standard format).""" + assert_no_cookies_are_present(driver) + + key = "string_value_cookie" + value = BytesValue(BytesValue.TYPE_STRING, "hello") + + cookie = PartialCookie(key, value, webserver.host) + + # Test + driver.storage.set_cookie(cookie=cookie) + + # Verify + cookie_filter = CookieFilter(name=key) + result = driver.storage.get_cookies(filter=cookie_filter) + + assert len(result.cookies) > 0 + assert result.cookies[0].value.value == "hello" + + def test_get_cookies_filter_by_domain(self, driver, pages, webserver): + """Test getting cookies filtered by domain.""" + assert_no_cookies_are_present(driver) + + key = generate_unique_key() + value = BytesValue(BytesValue.TYPE_STRING, "domain_test") + + cookie = PartialCookie(key, value, webserver.host) + driver.storage.set_cookie(cookie=cookie) + + # Filter by domain + cookie_filter = CookieFilter(domain=webserver.host) + result = driver.storage.get_cookies(filter=cookie_filter) + + # Should find the cookie + cookie_names = [c.name for c in result.cookies] + assert key in cookie_names + + def test_get_cookies_filter_by_path(self, driver, pages, webserver): + """Test getting cookies filtered by path.""" + assert_no_cookies_are_present(driver) + + key1 = generate_unique_key() + key2 = generate_unique_key() + value = BytesValue(BytesValue.TYPE_STRING, "path_test") + + # Cookie with specific path + cookie1 = PartialCookie(key1, value, webserver.host, path="/simpleTest.html") + # Cookie with root path + cookie2 = PartialCookie(key2, value, webserver.host, path="/") + + driver.storage.set_cookie(cookie=cookie1) + driver.storage.set_cookie(cookie=cookie2) + + # Filter by specific path + cookie_filter = CookieFilter(path="/simpleTest.html") + result = driver.storage.get_cookies(filter=cookie_filter) + + assert len(result.cookies) > 0 + assert all(c.path == "/simpleTest.html" for c in result.cookies) + + def test_multiple_cookies_same_name_different_paths(self, driver, pages, webserver): + """Test setting multiple cookies with same name but different paths.""" + assert_no_cookies_are_present(driver) + + key = "multi_path_cookie" + value = BytesValue(BytesValue.TYPE_STRING, "test") + + # Create cookies with same name but different paths + cookie1 = PartialCookie(key, value, webserver.host, path="/") + cookie2 = PartialCookie(key, value, webserver.host, path="/simpleTest.html") + + driver.storage.set_cookie(cookie=cookie1) + driver.storage.set_cookie(cookie=cookie2) + + # Both should exist + cookie_filter = CookieFilter(name=key) + result = driver.storage.get_cookies(filter=cookie_filter) + + # Should find at least 2 cookies with this name (different paths) + assert len(result.cookies) >= 2 + + def test_delete_cookie_by_path(self, driver, pages, webserver): + """Test deleting cookies filtered by path.""" + assert_no_cookies_are_present(driver) + + key1 = generate_unique_key() + key2 = generate_unique_key() + value = BytesValue(BytesValue.TYPE_STRING, "delete_test") + + cookie1 = PartialCookie(key1, value, webserver.host, path="/simpleTest.html") + cookie2 = PartialCookie(key2, value, webserver.host, path="/") + + driver.storage.set_cookie(cookie=cookie1) + driver.storage.set_cookie(cookie=cookie2) + + # Delete only cookies with specific path + driver.storage.delete_cookies(filter=CookieFilter(path="/simpleTest.html")) + + # Verify path-specific cookie is deleted, root path cookie remains + result = driver.storage.get_cookies(filter=CookieFilter()) + cookie_names = [c.name for c in result.cookies] + + assert key1 not in cookie_names or all( + c.path != "/simpleTest.html" for c in result.cookies if c.name == key1 + ) + + def test_cookie_expiry_timestamp(self, driver, pages, webserver): + """Test that cookie expiry is stored correctly as timestamp.""" + assert_no_cookies_are_present(driver) + + key = "expiry_test" + value = BytesValue(BytesValue.TYPE_STRING, "expires") + + # Set expiry to specific time + expiry_time = int(time.time() + 7200) # 2 hours from now + + cookie = PartialCookie(key, value, webserver.host, expiry=expiry_time) + + driver.storage.set_cookie(cookie=cookie) + + # Get and verify + cookie_filter = CookieFilter(name=key) + result = driver.storage.get_cookies(filter=cookie_filter) + + assert len(result.cookies) > 0 + assert result.cookies[0].expiry == expiry_time + + def test_cookie_combined_attributes(self, driver, pages, webserver): + """Test setting and getting cookie with multiple attributes combined.""" + assert_no_cookies_are_present(driver) + + key = "combined_attrs" + value = BytesValue(BytesValue.TYPE_STRING, "all_features") + path = "/simpleTest.html" + expiry = int(time.time() + 3600) + + cookie = PartialCookie( + key, + value, + webserver.host, + path=path, + http_only=True, + secure=True, + same_site=SameSite.LAX, + expiry=expiry, + ) + + # Test + driver.storage.set_cookie(cookie=cookie) + + # Verify with matching filter + cookie_filter = CookieFilter( + name=key, + path=path, + http_only=True, + secure=True, + same_site=SameSite.LAX, + expiry=expiry, + ) + + result = driver.storage.get_cookies(filter=cookie_filter) + + assert len(result.cookies) > 0 + cookie_result = result.cookies[0] + assert cookie_result.name == key + assert cookie_result.value.value == value.value + assert cookie_result.path == path + assert cookie_result.http_only is True + assert cookie_result.secure is True + assert cookie_result.same_site == SameSite.LAX + assert cookie_result.expiry == expiry diff --git a/py/test/selenium/webdriver/common/bidi_webextension_tests.py b/py/test/selenium/webdriver/common/bidi_webextension_tests.py index 7bea9f71e7e16..93c6c5a1d5528 100644 --- a/py/test/selenium/webdriver/common/bidi_webextension_tests.py +++ b/py/test/selenium/webdriver/common/bidi_webextension_tests.py @@ -179,4 +179,309 @@ def test_install_with_extension_id_uninstall(self, chromium_driver): ext_info = chromium_driver.webextension.install(path=path) extension_id = ext_info.get("extension") # Uninstall using the extension ID - uninstall_extension_and_verify_extension_uninstalled(chromium_driver, extension_id) + uninstall_extension_and_verify_extension_uninstalled( + chromium_driver, extension_id + ) + + +# Additional edge case tests for better WPT coverage + + +class TestFirefoxWebExtensionEdgeCases: + """Firefox WebExtension edge case tests.""" + + @pytest.mark.xfail_chrome + @pytest.mark.xfail_edge + def test_uninstall_extension_by_id_string(self, driver, pages): + """Test uninstalling extension using extension ID as string.""" + path = os.path.join(EXTENSIONS, EXTENSION_PATH) + ext_info = install_extension(driver, path=path) + extension_id_string = ext_info.get("extension") + + # Uninstall using ID string directly + driver.webextension.uninstall(extension_id_string) + + # Verify uninstall was successful + driver.browsing_context.reload(driver.current_window_handle) + assert len(driver.find_elements(By.ID, "webextensions-selenium-example")) == 0 + + @pytest.mark.xfail_chrome + @pytest.mark.xfail_edge + def test_uninstall_extension_by_result_dict(self, driver, pages): + """Test uninstalling extension using result dictionary from install.""" + path = os.path.join(EXTENSIONS, EXTENSION_PATH) + ext_info = install_extension(driver, path=path) + + # Uninstall using result dict + driver.webextension.uninstall(ext_info) + + # Verify uninstall was successful + driver.browsing_context.reload(driver.current_window_handle) + assert len(driver.find_elements(By.ID, "webextensions-selenium-example")) == 0 + + @pytest.mark.xfail_chrome + @pytest.mark.xfail_edge + def test_install_returns_extension_id(self, driver, pages): + """Test that install returns proper extension ID in result.""" + path = os.path.join(EXTENSIONS, EXTENSION_PATH) + ext_info = install_extension(driver, path=path) + + # Verify result structure + assert "extension" in ext_info + assert isinstance(ext_info.get("extension"), str) + assert len(ext_info.get("extension", "")) > 0 + assert ext_info.get("extension") == EXTENSION_ID + + # Cleanup + driver.webextension.uninstall(ext_info) + + @pytest.mark.xfail_chrome + @pytest.mark.xfail_edge + def test_extension_content_script_injection(self, driver, pages): + """Test that extension content scripts are properly injected.""" + path = os.path.join(EXTENSIONS, EXTENSION_PATH) + ext_info = install_extension(driver, path=path) + + # Load page and verify content script injection + pages.load("blank.html") + + # Element should be injected by extension + injected_element = WebDriverWait(driver, timeout=5).until( + lambda dr: dr.find_element(By.ID, "webextensions-selenium-example") + ) + + assert injected_element is not None + assert ( + "Content injected by webextensions-selenium-example" + in injected_element.text + ) + + # Cleanup + driver.webextension.uninstall(ext_info) + + @pytest.mark.xfail_chrome + @pytest.mark.xfail_edge + def test_uninstall_removes_content_scripts(self, driver, pages): + """Test that uninstalling extension removes content scripts.""" + path = os.path.join(EXTENSIONS, EXTENSION_PATH) + ext_info = install_extension(driver, path=path) + + # Verify injection works + pages.load("blank.html") + WebDriverWait(driver, timeout=5).until( + lambda dr: dr.find_element(By.ID, "webextensions-selenium-example") + ) + + # Uninstall + driver.webextension.uninstall(ext_info) + + # Reload page and verify injection is gone + driver.browsing_context.reload(driver.current_window_handle) + assert len(driver.find_elements(By.ID, "webextensions-selenium-example")) == 0 + + @pytest.mark.xfail_chrome + @pytest.mark.xfail_edge + def test_install_from_archive_returns_extension_id(self, driver, pages): + """Test that archive install returns proper extension ID.""" + archive_path = os.path.join(EXTENSIONS, EXTENSION_ARCHIVE_PATH) + ext_info = install_extension(driver, archive_path=archive_path) + + # Verify result structure + assert "extension" in ext_info + assert isinstance(ext_info.get("extension"), str) + assert len(ext_info.get("extension", "")) > 0 + + # Cleanup + driver.webextension.uninstall(ext_info) + + @pytest.mark.xfail_chrome + @pytest.mark.xfail_edge + def test_multiple_installations_and_uninstalls(self, driver, pages): + """Test installing and uninstalling extension multiple times.""" + path = os.path.join(EXTENSIONS, EXTENSION_PATH) + + # Install/uninstall cycle 1 + ext_info_1 = install_extension(driver, path=path) + verify_extension_injection(driver, pages) + driver.webextension.uninstall(ext_info_1) + driver.browsing_context.reload(driver.current_window_handle) + assert len(driver.find_elements(By.ID, "webextensions-selenium-example")) == 0 + + # Install/uninstall cycle 2 + ext_info_2 = install_extension(driver, path=path) + verify_extension_injection(driver, pages) + driver.webextension.uninstall(ext_info_2) + driver.browsing_context.reload(driver.current_window_handle) + assert len(driver.find_elements(By.ID, "webextensions-selenium-example")) == 0 + + +class TestChromiumWebExtensionEdgeCases: + """Chrome/Edge WebExtension edge case tests.""" + + @pytest.mark.xfail_firefox + @pytest.fixture + def pages_chromium(self, webserver, chromium_driver): + class Pages: + def load(self, name): + chromium_driver.get(webserver.where_is(name, localhost=False)) + + return Pages() + + @pytest.mark.xfail_firefox + @pytest.fixture + def chromium_driver(self, chromium_options, request): + """Create a Chrome/Edge driver with webextension support enabled.""" + driver_option = request.config.option.drivers[0].lower() + + if driver_option == "chrome": + browser_class = webdriver.Chrome + browser_service = webdriver.ChromeService + elif driver_option == "edge": + browser_class = webdriver.Edge + browser_service = webdriver.EdgeService + + temp_dir = tempfile.mkdtemp(prefix="chromium-profile-") + + chromium_options.enable_bidi = True + chromium_options.enable_webextensions = True + chromium_options.add_argument(f"--user-data-dir={temp_dir}") + chromium_options.add_argument("--no-sandbox") + chromium_options.add_argument("--disable-dev-shm-usage") + + binary = request.config.option.binary + if binary: + chromium_options.binary_location = binary + + executable = request.config.option.executable + if executable: + service = browser_service(executable_path=executable) + else: + service = browser_service() + + chromium_driver = browser_class(options=chromium_options, service=service) + + yield chromium_driver + chromium_driver.quit() + + # delete the temp directory + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + + @pytest.mark.xfail_firefox + def test_uninstall_extension_by_id_string(self, chromium_driver, pages_chromium): + """Test uninstalling extension using extension ID as string.""" + path = os.path.join(EXTENSIONS, EXTENSION_PATH) + ext_info = chromium_driver.webextension.install(path=path) + extension_id_string = ext_info.get("extension") + + # Uninstall using ID string directly + chromium_driver.webextension.uninstall(extension_id_string) + + # Verify uninstall was successful + chromium_driver.browsing_context.reload(chromium_driver.current_window_handle) + assert ( + len(chromium_driver.find_elements(By.ID, "webextensions-selenium-example")) + == 0 + ) + + @pytest.mark.xfail_firefox + def test_uninstall_extension_by_result_dict(self, chromium_driver, pages_chromium): + """Test uninstalling extension using result dictionary from install.""" + path = os.path.join(EXTENSIONS, EXTENSION_PATH) + ext_info = chromium_driver.webextension.install(path=path) + + # Uninstall using result dict + chromium_driver.webextension.uninstall(ext_info) + + # Verify uninstall was successful + chromium_driver.browsing_context.reload(chromium_driver.current_window_handle) + assert ( + len(chromium_driver.find_elements(By.ID, "webextensions-selenium-example")) + == 0 + ) + + @pytest.mark.xfail_firefox + def test_install_returns_extension_id(self, chromium_driver, pages_chromium): + """Test that install returns proper extension ID in result.""" + path = os.path.join(EXTENSIONS, EXTENSION_PATH) + ext_info = chromium_driver.webextension.install(path=path) + + # Verify result structure + assert "extension" in ext_info + assert isinstance(ext_info.get("extension"), str) + assert len(ext_info.get("extension", "")) > 0 + + # Cleanup + chromium_driver.webextension.uninstall(ext_info) + + @pytest.mark.xfail_firefox + def test_extension_content_script_injection(self, chromium_driver, pages_chromium): + """Test that extension content scripts are properly injected.""" + path = os.path.join(EXTENSIONS, EXTENSION_PATH) + ext_info = chromium_driver.webextension.install(path=path) + + # Load page and verify content script injection + pages_chromium.load("blank.html") + + # Element should be injected by extension + injected_element = WebDriverWait(chromium_driver, timeout=5).until( + lambda dr: dr.find_element(By.ID, "webextensions-selenium-example") + ) + + assert injected_element is not None + assert ( + "Content injected by webextensions-selenium-example" + in injected_element.text + ) + + # Cleanup + chromium_driver.webextension.uninstall(ext_info) + + @pytest.mark.xfail_firefox + def test_uninstall_removes_content_scripts(self, chromium_driver, pages_chromium): + """Test that uninstalling extension removes content scripts.""" + path = os.path.join(EXTENSIONS, EXTENSION_PATH) + ext_info = chromium_driver.webextension.install(path=path) + + # Verify injection works + pages_chromium.load("blank.html") + WebDriverWait(chromium_driver, timeout=5).until( + lambda dr: dr.find_element(By.ID, "webextensions-selenium-example") + ) + + # Uninstall + chromium_driver.webextension.uninstall(ext_info) + + # Reload page and verify injection is gone + chromium_driver.browsing_context.reload(chromium_driver.current_window_handle) + assert ( + len(chromium_driver.find_elements(By.ID, "webextensions-selenium-example")) + == 0 + ) + + @pytest.mark.xfail_firefox + def test_multiple_installations_and_uninstalls( + self, chromium_driver, pages_chromium + ): + """Test installing and uninstalling extension multiple times.""" + path = os.path.join(EXTENSIONS, EXTENSION_PATH) + + # Install/uninstall cycle 1 + ext_info_1 = chromium_driver.webextension.install(path=path) + verify_extension_injection(chromium_driver, pages_chromium) + chromium_driver.webextension.uninstall(ext_info_1) + chromium_driver.browsing_context.reload(chromium_driver.current_window_handle) + assert ( + len(chromium_driver.find_elements(By.ID, "webextensions-selenium-example")) + == 0 + ) + + # Install/uninstall cycle 2 + ext_info_2 = chromium_driver.webextension.install(path=path) + verify_extension_injection(chromium_driver, pages_chromium) + chromium_driver.webextension.uninstall(ext_info_2) + chromium_driver.browsing_context.reload(chromium_driver.current_window_handle) + assert ( + len(chromium_driver.find_elements(By.ID, "webextensions-selenium-example")) + == 0 + ) From e0b80c807cf85aba4f3b040bfdc2f18763076d7a Mon Sep 17 00:00:00 2001 From: Nikolay Borisenko <22616990+nvborisenko@users.noreply.github.com> Date: Wed, 11 Mar 2026 00:48:13 +0300 Subject: [PATCH 58/67] [dotnet] [bidi] Simplified how background tasks are started (#17198) --- dotnet/src/webdriver/BiDi/Broker.cs | 4 +--- dotnet/src/webdriver/BiDi/EventDispatcher.cs | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/dotnet/src/webdriver/BiDi/Broker.cs b/dotnet/src/webdriver/BiDi/Broker.cs index 162fc392fa920..c1f02b398b50c 100644 --- a/dotnet/src/webdriver/BiDi/Broker.cs +++ b/dotnet/src/webdriver/BiDi/Broker.cs @@ -37,8 +37,6 @@ internal sealed class Broker : IAsyncDisposable private long _currentCommandId; - private static readonly TaskFactory _myTaskFactory = new(CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskContinuationOptions.None, TaskScheduler.Default); - private readonly Task _receivingMessageTask; private readonly CancellationTokenSource _receiveMessagesCancellationTokenSource; @@ -49,7 +47,7 @@ public Broker(ITransport transport, IBiDi bidi, Func sessionProv _eventDispatcher = new EventDispatcher(sessionProvider); _receiveMessagesCancellationTokenSource = new CancellationTokenSource(); - _receivingMessageTask = _myTaskFactory.StartNew(async () => await ReceiveMessagesAsync(_receiveMessagesCancellationTokenSource.Token), TaskCreationOptions.LongRunning).Unwrap(); + _receivingMessageTask = Task.Run(() => ReceiveMessagesAsync(_receiveMessagesCancellationTokenSource.Token)); } public Task SubscribeAsync(string eventName, EventHandler eventHandler, SubscriptionOptions? options, JsonTypeInfo jsonTypeInfo, CancellationToken cancellationToken) diff --git a/dotnet/src/webdriver/BiDi/EventDispatcher.cs b/dotnet/src/webdriver/BiDi/EventDispatcher.cs index 1a4e93006821c..1b1bfe964dc4a 100644 --- a/dotnet/src/webdriver/BiDi/EventDispatcher.cs +++ b/dotnet/src/webdriver/BiDi/EventDispatcher.cs @@ -42,12 +42,10 @@ internal sealed class EventDispatcher : IAsyncDisposable private readonly Task _eventEmitterTask; - private static readonly TaskFactory _myTaskFactory = new(CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskContinuationOptions.None, TaskScheduler.Default); - public EventDispatcher(Func sessionProvider) { _sessionProvider = sessionProvider; - _eventEmitterTask = _myTaskFactory.StartNew(ProcessEventsAwaiterAsync).Unwrap(); + _eventEmitterTask = Task.Run(ProcessEventsAwaiterAsync); } public async Task SubscribeAsync(string eventName, EventHandler eventHandler, SubscriptionOptions? options, JsonTypeInfo jsonTypeInfo, CancellationToken cancellationToken) From 41a323f48bcccd84679872eaf7fd849bd959962a Mon Sep 17 00:00:00 2001 From: Viet Nguyen Duc Date: Wed, 11 Mar 2026 11:59:56 +0700 Subject: [PATCH 59/67] [grid] Router WebSocket handle dropped close frames, idle disconnects, high-latency proxying (#17197) Signed-off-by: Viet Nguyen Duc --- .../openqa/selenium/grid/router/BUILD.bazel | 3 + .../grid/router/ProxyWebsocketsIntoGrid.java | 235 +++++++++++++--- .../grid/router/httpd/RouterFlags.java | 12 + .../grid/router/httpd/RouterOptions.java | 4 + .../grid/router/httpd/RouterServer.java | 26 +- .../selenium/netty/server/NettyServer.java | 5 + .../netty/server/PostUpgradeHook.java | 39 +++ .../netty/server/TcpUpgradeTunnelHandler.java | 65 +++++ .../netty/server/WebSocketFrameProxy.java | 166 ++++++++++++ .../server/WebSocketKeepAliveHandler.java | 69 +++++ .../netty/server/WebSocketUpgradeHandler.java | 25 +- .../grid/router/TunnelWebsocketTest.java | 252 ++++++++++++++++++ 12 files changed, 854 insertions(+), 47 deletions(-) create mode 100644 java/src/org/openqa/selenium/netty/server/PostUpgradeHook.java create mode 100644 java/src/org/openqa/selenium/netty/server/WebSocketFrameProxy.java create mode 100644 java/src/org/openqa/selenium/netty/server/WebSocketKeepAliveHandler.java diff --git a/java/src/org/openqa/selenium/grid/router/BUILD.bazel b/java/src/org/openqa/selenium/grid/router/BUILD.bazel index 02368b044c94c..d3d7ca5acf2ab 100644 --- a/java/src/org/openqa/selenium/grid/router/BUILD.bazel +++ b/java/src/org/openqa/selenium/grid/router/BUILD.bazel @@ -21,9 +21,12 @@ java_library( "//java/src/org/openqa/selenium/grid/sessionqueue", "//java/src/org/openqa/selenium/grid/web", "//java/src/org/openqa/selenium/json", + "//java/src/org/openqa/selenium/netty/server", "//java/src/org/openqa/selenium/remote", "//java/src/org/openqa/selenium/status", artifact("com.google.guava:guava"), + artifact("io.netty:netty-codec-http"), + artifact("io.netty:netty-transport"), artifact("org.jspecify:jspecify"), ], ) diff --git a/java/src/org/openqa/selenium/grid/router/ProxyWebsocketsIntoGrid.java b/java/src/org/openqa/selenium/grid/router/ProxyWebsocketsIntoGrid.java index 3cdf7784f02e1..9fb1d4a91e7cc 100644 --- a/java/src/org/openqa/selenium/grid/router/ProxyWebsocketsIntoGrid.java +++ b/java/src/org/openqa/selenium/grid/router/ProxyWebsocketsIntoGrid.java @@ -19,9 +19,16 @@ import static org.openqa.selenium.remote.http.HttpMethod.GET; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; import java.net.URI; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.logging.Level; @@ -29,6 +36,8 @@ import org.openqa.selenium.NoSuchSessionException; import org.openqa.selenium.grid.sessionmap.SessionMap; import org.openqa.selenium.internal.Require; +import org.openqa.selenium.netty.server.PostUpgradeHook; +import org.openqa.selenium.netty.server.WebSocketFrameProxy; import org.openqa.selenium.remote.HttpSessionId; import org.openqa.selenium.remote.SessionId; import org.openqa.selenium.remote.http.BinaryMessage; @@ -40,6 +49,16 @@ import org.openqa.selenium.remote.http.TextMessage; import org.openqa.selenium.remote.http.WebSocket; +/** + * Proxies WebSocket connections from the Router to a Grid node. + * + *

After the Netty-side upgrade handshake completes ({@link PostUpgradeHook}), the pipeline is + * simplified: {@code MessageInboundConverter}, {@code MessageOutboundConverter}, and {@code + * WebSocketMessageHandler} are replaced by a {@link WebSocketFrameProxy} that forwards {@link + * io.netty.handler.codec.http.websocketx.WebSocketFrame} objects directly to the node-side {@link + * WebSocket}. This eliminates one intermediate object allocation and one executor-task submission + * per frame in each direction. + */ public class ProxyWebsocketsIntoGrid implements BiFunction, Optional>> { @@ -63,65 +82,211 @@ public Optional> apply(String uri, Consumer downstrea return Optional.empty(); } + URI sessionUri; try { - URI sessionUri = sessions.getUri(sessionId.get()); + sessionUri = sessions.getUri(sessionId.get()); + } catch (NoSuchSessionException e) { + LOG.warning("Attempt to connect to non-existent session: " + uri); + return Optional.empty(); + } + + AtomicBoolean upstreamClosing = new AtomicBoolean(false); + // Holds the client channel once onUpgradeComplete fires; used by DirectForwardingListener + // to write frames without going through MessageOutboundConverter. + AtomicReference clientChannelRef = new AtomicReference<>(); + + HttpClient client = + clientFactory.createClient(ClientConfig.defaultConfig().baseUri(sessionUri)); + try { + WebSocket upstream = + client.openSocket( + new HttpRequest(GET, uri), + new DirectForwardingListener(downstream, clientChannelRef, upstreamClosing, client)); + + return Optional.of( + new FrameProxyConsumer(upstream, client, clientChannelRef, upstreamClosing)); + + } catch (Exception e) { + LOG.log(Level.WARNING, "Connecting to upstream websocket failed", e); + client.close(); + return Optional.empty(); + } + } + + // --------------------------------------------------------------------------- + // Consumer returned to WebSocketUpgradeHandler — also implements PostUpgradeHook + // --------------------------------------------------------------------------- + + private static class FrameProxyConsumer implements Consumer, PostUpgradeHook { + + private final WebSocket upstream; + private final HttpClient client; + private final AtomicReference clientChannelRef; + private final AtomicBoolean upstreamClosing; + + FrameProxyConsumer( + WebSocket upstream, + HttpClient client, + AtomicReference clientChannelRef, + AtomicBoolean upstreamClosing) { + this.upstream = upstream; + this.client = client; + this.clientChannelRef = clientChannelRef; + this.upstreamClosing = upstreamClosing; + } + + /** + * Called by {@link org.openqa.selenium.netty.server.WebSocketUpgradeHandler} on the Netty IO + * thread after the client-side handshake completes. Install {@link WebSocketFrameProxy} and + * strip the three {@code Message}-layer handlers so subsequent data frames never pass through + * the full Selenium handler chain. + */ + @Override + public void onUpgradeComplete(ChannelHandlerContext ctx) { + Channel ch = ctx.channel(); + clientChannelRef.set(ch); + + WebSocketFrameProxy proxy = new WebSocketFrameProxy(upstream, upstreamClosing); + ChannelPipeline pipeline = ctx.pipeline(); + + // Insert the frame proxy just before the inbound Message converter so it intercepts + // WebSocketFrame objects first, then remove the three now-redundant handlers. + pipeline.addBefore("netty-to-se-messages", "frame-proxy", proxy); + removeIfPresent(pipeline, "netty-to-se-messages"); // MessageInboundConverter + removeIfPresent(pipeline, "se-to-netty-messages"); // MessageOutboundConverter + removeIfPresent(pipeline, "se-websocket-handler"); // WebSocketMessageHandler + } + + private static void removeIfPresent(ChannelPipeline pipeline, String name) { + if (pipeline.get(name) != null) { + pipeline.remove(name); + } + } + + /** + * After pipeline rewiring this consumer is only called for {@link CloseMessage} (fired by + * {@link org.openqa.selenium.netty.server.WebSocketUpgradeHandler} on the close handshake). + * Data frames are handled directly by {@link WebSocketFrameProxy}. + */ + @Override + public void accept(Message msg) { + if (upstreamClosing.get()) { + if (msg instanceof CloseMessage) { + closeClient(); + } + return; + } - HttpClient client = - clientFactory.createClient(ClientConfig.defaultConfig().baseUri(sessionUri)); try { - WebSocket upstream = - client.openSocket(new HttpRequest(GET, uri), new ForwardingListener(downstream)); - - return Optional.of( - (msg) -> { - try { - upstream.send(msg); - } finally { - if (msg instanceof CloseMessage) { - try { - client.close(); - } catch (Exception e) { - LOG.log(Level.WARNING, "Failed to shutdown the client of " + sessionUri, e); - } - } - } - }); + upstream.send(msg); } catch (Exception e) { - LOG.log(Level.WARNING, "Connecting to upstream websocket failed", e); + LOG.log( + Level.FINE, + "Could not forward message to node WebSocket (connection likely closed)", + e); + closeClient(); + return; + } + + if (msg instanceof CloseMessage) { + closeClient(); + } + } + + private void closeClient() { + try { client.close(); - return Optional.empty(); + } catch (Exception e) { + LOG.log(Level.WARNING, "Failed to close upstream client", e); } - } catch (NoSuchSessionException e) { - LOG.warning("Attempt to connect to non-existent session: " + uri); - return Optional.empty(); } } - private static class ForwardingListener implements WebSocket.Listener { - private final Consumer downstream; + // --------------------------------------------------------------------------- + // Listener for node → client messages (fast path via direct frame writes) + // --------------------------------------------------------------------------- + + /** + * Writes node-side messages directly to the client {@link Channel} as Netty WebSocket frames, + * bypassing {@code MessageOutboundConverter}. Falls back to the {@code downstream} consumer + * before the client channel reference is set (i.e. before {@link + * PostUpgradeHook#onUpgradeComplete} fires, which is rare). + */ + private static class DirectForwardingListener implements WebSocket.Listener { - public ForwardingListener(Consumer downstream) { - this.downstream = Objects.requireNonNull(downstream); + private final Consumer fallbackDownstream; + private final AtomicReference clientChannelRef; + private final AtomicBoolean upstreamClosing; + private final HttpClient client; + + DirectForwardingListener( + Consumer fallbackDownstream, + AtomicReference clientChannelRef, + AtomicBoolean upstreamClosing, + HttpClient client) { + this.fallbackDownstream = Objects.requireNonNull(fallbackDownstream); + this.clientChannelRef = Objects.requireNonNull(clientChannelRef); + this.upstreamClosing = Objects.requireNonNull(upstreamClosing); + this.client = Objects.requireNonNull(client); } @Override - public void onBinary(byte[] data) { - downstream.accept(new BinaryMessage(data)); + public void onText(CharSequence data) { + Channel ch = clientChannelRef.get(); + if (ch != null) { + // Fast path: write TextWebSocketFrame directly, skipping MessageOutboundConverter. + WebSocketFrameProxy.writeTextFrame(ch, data); + } else { + fallbackDownstream.accept(new TextMessage(data)); + } } @Override - public void onClose(int code, String reason) { - downstream.accept(new CloseMessage(code, reason)); + public void onBinary(byte[] data) { + Channel ch = clientChannelRef.get(); + if (ch != null) { + // Fast path: write BinaryWebSocketFrame directly, skipping MessageOutboundConverter. + WebSocketFrameProxy.writeBinaryFrame(ch, data); + } else { + fallbackDownstream.accept(new BinaryMessage(data)); + } } @Override - public void onText(CharSequence data) { - downstream.accept(new TextMessage(data)); + public void onClose(int code, String reason) { + upstreamClosing.set(true); + // After onUpgradeComplete the pipeline no longer contains MessageOutboundConverter, so + // writing a CloseMessage object via fallbackDownstream would fail to encode. Write the + // Netty frame directly once the client channel reference is available. + Channel ch = clientChannelRef.get(); + if (ch != null && ch.isActive()) { + ch.writeAndFlush(new CloseWebSocketFrame(code, reason)); + } else { + fallbackDownstream.accept(new CloseMessage(code, reason)); + } + try { + client.close(); + } catch (Exception e) { + LOG.log(Level.FINE, "Failed to close client on upstream WebSocket close", e); + } } @Override public void onError(Throwable cause) { + upstreamClosing.set(true); LOG.log(Level.WARNING, "Error proxying websocket command", cause); + // Close the client channel so Playwright/BiDi clients see a clean disconnect rather than + // hanging until the next keepalive ping fires. + Channel ch = clientChannelRef.get(); + if (ch != null && ch.isActive()) { + ch.writeAndFlush(new CloseWebSocketFrame(1011, "upstream error")) + .addListener(ChannelFutureListener.CLOSE); + } + try { + client.close(); + } catch (Exception e) { + LOG.log(Level.FINE, "Failed to close client after WebSocket error", e); + } } } } diff --git a/java/src/org/openqa/selenium/grid/router/httpd/RouterFlags.java b/java/src/org/openqa/selenium/grid/router/httpd/RouterFlags.java index 3960c59d00d31..16e40c16bc849 100644 --- a/java/src/org/openqa/selenium/grid/router/httpd/RouterFlags.java +++ b/java/src/org/openqa/selenium/grid/router/httpd/RouterFlags.java @@ -74,6 +74,18 @@ public class RouterFlags implements HasRoles { @ConfigValue(section = ROUTER_SECTION, name = "disable-ui", example = "true") public boolean disableUi = false; + @Parameter( + names = {"--tcp-tunnel"}, + arity = 1, + description = + "Enable the transparent TCP tunnel for WebSocket connections (BiDi, CDP). " + + "When disabled, all WebSocket traffic is routed through ProxyWebsocketsIntoGrid " + + "instead of the direct byte-bridge. Disable for benchmarking the proxy path " + + "or in network topologies where the Router cannot open direct TCP connections " + + "to Nodes (e.g. Kubernetes port-forward setups).") + @ConfigValue(section = ROUTER_SECTION, name = "tcp-tunnel", example = "false") + public boolean tcpTunnel = true; + @Override public Set getRoles() { return Collections.singleton(ROUTER_ROLE); diff --git a/java/src/org/openqa/selenium/grid/router/httpd/RouterOptions.java b/java/src/org/openqa/selenium/grid/router/httpd/RouterOptions.java index 47d2e902a107a..7c2a5e6bd43a3 100644 --- a/java/src/org/openqa/selenium/grid/router/httpd/RouterOptions.java +++ b/java/src/org/openqa/selenium/grid/router/httpd/RouterOptions.java @@ -51,4 +51,8 @@ public String subPath() { public boolean disableUi() { return config.get(ROUTER_SECTION, "disable-ui").map(Boolean::parseBoolean).orElse(false); } + + public boolean tcpTunnel() { + return config.get(ROUTER_SECTION, "tcp-tunnel").map(Boolean::parseBoolean).orElse(true); + } } diff --git a/java/src/org/openqa/selenium/grid/router/httpd/RouterServer.java b/java/src/org/openqa/selenium/grid/router/httpd/RouterServer.java index ba27c6aa76402..b23620a27bfb7 100644 --- a/java/src/org/openqa/selenium/grid/router/httpd/RouterServer.java +++ b/java/src/org/openqa/selenium/grid/router/httpd/RouterServer.java @@ -191,18 +191,22 @@ protected Handlers createHandlers(Config config) { // Resolve a request URI to the Node URI for direct TCP tunnelling of WebSocket connections. // Falls back to ProxyWebsocketsIntoGrid (the websocketHandler) when the session is not found. + // Passing null disables the tunnel entirely (--tcp-tunnel false), forcing all WebSocket traffic + // through ProxyWebsocketsIntoGrid — useful for benchmarking or restricted network topologies. Function> tcpTunnelResolver = - uri -> - HttpSessionId.getSessionId(uri) - .map(SessionId::new) - .flatMap( - id -> { - try { - return Optional.of(sessions.getUri(id)); - } catch (NoSuchSessionException e) { - return Optional.empty(); - } - }); + routerOptions.tcpTunnel() + ? uri -> + HttpSessionId.getSessionId(uri) + .map(SessionId::new) + .flatMap( + id -> { + try { + return Optional.of(sessions.getUri(id)); + } catch (NoSuchSessionException e) { + return Optional.empty(); + } + }) + : null; return new Handlers( routeWithLiveness, diff --git a/java/src/org/openqa/selenium/netty/server/NettyServer.java b/java/src/org/openqa/selenium/netty/server/NettyServer.java index 2331eecfb1883..4748403c39f86 100644 --- a/java/src/org/openqa/selenium/netty/server/NettyServer.java +++ b/java/src/org/openqa/selenium/netty/server/NettyServer.java @@ -19,6 +19,7 @@ import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.Channel; +import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.MultiThreadIoEventLoopGroup; import io.netty.channel.nio.NioIoHandler; @@ -177,6 +178,10 @@ public NettyServer start() { b.group(bossGroup, workerGroup) .channel(NioServerSocketChannel.class) .handler(new LoggingHandler(LogLevel.DEBUG)) + // OS-level TCP keepalive: kernel probes stale connections that the app cannot detect. + .childOption(ChannelOption.SO_KEEPALIVE, true) + // Disable Nagle: flush small frames (BiDi, CDP) immediately without buffering. + .childOption(ChannelOption.TCP_NODELAY, true) .childHandler( new SeleniumHttpInitializer( sslCtx, handler, websocketHandler, allowCors, tcpTunnelResolver)); diff --git a/java/src/org/openqa/selenium/netty/server/PostUpgradeHook.java b/java/src/org/openqa/selenium/netty/server/PostUpgradeHook.java new file mode 100644 index 0000000000000..074d8375706d8 --- /dev/null +++ b/java/src/org/openqa/selenium/netty/server/PostUpgradeHook.java @@ -0,0 +1,39 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.openqa.selenium.netty.server; + +import io.netty.channel.ChannelHandlerContext; + +/** + * Optional interface that a {@link java.util.function.Consumer} returned by a WebSocket handler + * factory may implement to receive a callback after the Netty-side WebSocket handshake has + * completed. + * + *

Implementing this hook allows the handler to rewire the Netty pipeline (e.g. replace the + * {@code Message}-layer handlers with a lighter-weight frame-level forwarder) once the channel is + * fully in WebSocket mode. + */ +public interface PostUpgradeHook { + + /** + * Called on the Netty IO thread immediately after the {@link + * io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker} handshake future succeeds. + * The channel is now in WebSocket mode; callers may freely modify {@code ctx.pipeline()}. + */ + void onUpgradeComplete(ChannelHandlerContext ctx); +} diff --git a/java/src/org/openqa/selenium/netty/server/TcpUpgradeTunnelHandler.java b/java/src/org/openqa/selenium/netty/server/TcpUpgradeTunnelHandler.java index 985cdf247e6fe..55b427c3ded3c 100644 --- a/java/src/org/openqa/selenium/netty/server/TcpUpgradeTunnelHandler.java +++ b/java/src/org/openqa/selenium/netty/server/TcpUpgradeTunnelHandler.java @@ -24,6 +24,7 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; import io.netty.channel.ChannelPipeline; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; @@ -39,9 +40,12 @@ import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.timeout.IdleStateEvent; +import io.netty.handler.timeout.IdleStateHandler; import io.netty.util.ReferenceCountUtil; import java.net.URI; import java.util.Optional; +import java.util.concurrent.TimeUnit; import java.util.function.Function; import java.util.logging.Level; import java.util.logging.Logger; @@ -139,6 +143,11 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception new Bootstrap() .group(clientChannel.eventLoop()) .channel(NioSocketChannel.class) + // Mirror the server-side socket options so both legs of the tunnel behave + // the same: SO_KEEPALIVE lets the OS probe stale connections, TCP_NODELAY + // flushes small CDP/BiDi frames without Nagle buffering. + .option(ChannelOption.SO_KEEPALIVE, true) + .option(ChannelOption.TCP_NODELAY, true) .handler( new ChannelInitializer() { @Override @@ -303,6 +312,29 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { } } + // Install read-idle detection on both tunnel channels. The tunnel carries raw + // WebSocket bytes (CDP / BiDi); application-level pings from the client (e.g. + // Playwright's 30 s pings) flow through and reset the timer naturally. If no + // bytes arrive for IDLE_TIMEOUT_SECONDS the upstream LB has silently dropped + // the TCP connection — close both ends so the session slot is freed promptly. + int idleSeconds = WebSocketKeepAliveHandler.PING_INTERVAL_SECONDS * 4; + nodeChannel + .pipeline() + .addBefore( + "tunnel", + "node-idle", + new IdleStateHandler(idleSeconds, 0, 0, TimeUnit.SECONDS)); + nodeChannel + .pipeline() + .addAfter( + "node-idle", "node-idle-close", new IdleCloseHandler(clientChannel)); + cp.addBefore( + "tunnel", + "client-idle", + new IdleStateHandler(idleSeconds, 0, 0, TimeUnit.SECONDS)); + cp.addAfter( + "client-idle", "client-idle-close", new IdleCloseHandler(nodeChannel)); + // Re-enable reads on the client now that the tunnel is live. clientChannel.config().setAutoRead(true); }); @@ -327,4 +359,37 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { clientChannel.close(); } } + + // --------------------------------------------------------------------------- + // Idle-close handler shared by both legs of the tunnel + // --------------------------------------------------------------------------- + + /** + * Closes both tunnel channels when no bytes have been received on this channel for the configured + * read-idle window. This cleans up sessions where the intermediate load balancer silently dropped + * the TCP connection without sending a FIN or RST (common with AWS ALB, k8s ingress-nginx at + * their default 60 s idle timeout). + */ + private static final class IdleCloseHandler extends ChannelInboundHandlerAdapter { + + private final Channel peer; + + IdleCloseHandler(Channel peer) { + this.peer = peer; + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof IdleStateEvent) { + LOG.log( + Level.FINE, + "TCP tunnel read-idle timeout on {0}, closing both channels", + ctx.channel()); + ctx.close(); + peer.close(); + return; + } + super.userEventTriggered(ctx, evt); + } + } } diff --git a/java/src/org/openqa/selenium/netty/server/WebSocketFrameProxy.java b/java/src/org/openqa/selenium/netty/server/WebSocketFrameProxy.java new file mode 100644 index 0000000000000..aad98b13e7523 --- /dev/null +++ b/java/src/org/openqa/selenium/netty/server/WebSocketFrameProxy.java @@ -0,0 +1,166 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.openqa.selenium.netty.server; + +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.openqa.selenium.remote.http.BinaryMessage; +import org.openqa.selenium.remote.http.TextMessage; +import org.openqa.selenium.remote.http.WebSocket; + +/** + * Installed in the client-side Netty pipeline by {@link + * org.openqa.selenium.grid.router.ProxyWebsocketsIntoGrid} after the WebSocket upgrade handshake + * completes on both sides. It replaces the {@link MessageInboundConverter} → {@link + * WebSocketMessageHandler} chain by forwarding {@link WebSocketFrame} objects directly to the + * node-side {@link WebSocket}, avoiding one intermediate {@code Message} allocation and one + * executor-task submission per frame. + * + *

The reverse direction (node → client) is handled by {@code DirectForwardingListener} inside + * {@code ProxyWebsocketsIntoGrid}, which writes {@link TextWebSocketFrame}/{@link + * BinaryWebSocketFrame} directly to the client {@link Channel}, bypassing {@link + * MessageOutboundConverter}. + * + *

Close frames are intentionally NOT handled here — they continue to flow through {@link + * WebSocketUpgradeHandler} which calls the registered {@code Consumer} with a {@link + * org.openqa.selenium.remote.http.CloseMessage} and runs the Netty-level close handshake. + * + *

This handler is not {@code @ChannelHandler.Sharable}: each connection gets its own instance so + * that the fragmentation accumulators are per-connection. + */ +public class WebSocketFrameProxy extends SimpleChannelInboundHandler { + + private static final Logger LOG = Logger.getLogger(WebSocketFrameProxy.class.getName()); + + private final WebSocket upstream; + private final AtomicBoolean upstreamClosing; + + // State for reassembling fragmented messages (mirrors MessageInboundConverter). + private enum Continuation { + Text, + Binary, + None + } + + private Continuation next = Continuation.None; + private final StringBuilder textBuffer = new StringBuilder(); + private final ByteArrayOutputStream binaryBuffer = new ByteArrayOutputStream(); + + public WebSocketFrameProxy(WebSocket upstream, AtomicBoolean upstreamClosing) { + super(true); // autoRelease: SimpleChannelInboundHandler releases each frame after read + this.upstream = upstream; + this.upstreamClosing = upstreamClosing; + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame frame) { + if (upstreamClosing.get()) { + LOG.log(Level.FINE, "Dropping data frame: upstream WebSocket is closing"); + return; + } + + try { + forwardFrame(frame); + } catch (Exception e) { + LOG.log(Level.WARNING, "Failed to forward WebSocket frame to node", e); + ctx.fireExceptionCaught(e); + } + } + + private void forwardFrame(WebSocketFrame frame) { + if (frame instanceof TextWebSocketFrame) { + TextWebSocketFrame text = (TextWebSocketFrame) frame; + if (text.isFinalFragment()) { + upstream.send(new TextMessage(text.text())); + } else { + next = Continuation.Text; + textBuffer.append(text.text()); + } + + } else if (frame instanceof BinaryWebSocketFrame) { + BinaryWebSocketFrame binary = (BinaryWebSocketFrame) frame; + if (binary.isFinalFragment()) { + upstream.send(new BinaryMessage(binary.content().nioBuffer())); + } else { + next = Continuation.Binary; + try { + binary.content().readBytes(binaryBuffer, binary.content().readableBytes()); + } catch (IOException e) { + throw new UncheckedIOException("failed to read binary frame", e); + } + } + + } else if (frame instanceof ContinuationWebSocketFrame) { + ContinuationWebSocketFrame cont = (ContinuationWebSocketFrame) frame; + switch (next) { + case Text: + textBuffer.append(cont.text()); + if (cont.isFinalFragment()) { + upstream.send(new TextMessage(textBuffer.toString())); + textBuffer.setLength(0); + next = Continuation.None; + } + break; + case Binary: + try { + cont.content().readBytes(binaryBuffer, cont.content().readableBytes()); + } catch (IOException e) { + throw new UncheckedIOException("failed to read continuation frame", e); + } + if (cont.isFinalFragment()) { + upstream.send(new BinaryMessage(binaryBuffer.toByteArray())); + binaryBuffer.reset(); + next = Continuation.None; + } + break; + default: + // CloseWebSocketFrame continuation or unknown — ignore. + break; + } + } + // CloseWebSocketFrame: handled by WebSocketUpgradeHandler → Consumer. + } + + /** + * Called by the node-side {@code ForwardingListener} to write a text frame directly to the client + * channel, bypassing {@link MessageOutboundConverter}. + */ + public static void writeTextFrame(Channel clientChannel, CharSequence text) { + clientChannel.writeAndFlush(new TextWebSocketFrame(text.toString())); + } + + /** + * Called by the node-side {@code ForwardingListener} to write a binary frame directly to the + * client channel, bypassing {@link MessageOutboundConverter}. + */ + public static void writeBinaryFrame(Channel clientChannel, byte[] data) { + clientChannel.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(data))); + } +} diff --git a/java/src/org/openqa/selenium/netty/server/WebSocketKeepAliveHandler.java b/java/src/org/openqa/selenium/netty/server/WebSocketKeepAliveHandler.java new file mode 100644 index 0000000000000..ca0215a8091b0 --- /dev/null +++ b/java/src/org/openqa/selenium/netty/server/WebSocketKeepAliveHandler.java @@ -0,0 +1,69 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.openqa.selenium.netty.server; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.websocketx.PingWebSocketFrame; +import io.netty.handler.timeout.IdleState; +import io.netty.handler.timeout.IdleStateEvent; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Sends a WebSocket {@link PingWebSocketFrame} whenever no data has been written to the channel for + * {@link #PING_INTERVAL_SECONDS} seconds. This prevents cloud load balancers (AWS ALB default: 60 + * s, GCP default: 600 s, k8s ingress-nginx default: 60 s) and NAT gateways from silently dropping + * idle TCP connections mid-session. + * + *

Must be placed in the pipeline immediately after an {@link + * io.netty.handler.timeout.IdleStateHandler} that is configured with a writer-idle timeout. + * Installed by {@link WebSocketUpgradeHandler} after the WebSocket handshake completes so that it + * only activates for WebSocket connections (never for plain HTTP request/response pairs). + * + *

Incoming {@link io.netty.handler.codec.http.websocketx.PongWebSocketFrame} responses are + * handled by {@link WebSocketUpgradeHandler} (which releases them). No additional handling is + * needed here. + */ +class WebSocketKeepAliveHandler extends ChannelInboundHandlerAdapter { + + static final int PING_INTERVAL_SECONDS = 30; + + private static final Logger LOG = Logger.getLogger(WebSocketKeepAliveHandler.class.getName()); + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof IdleStateEvent && ((IdleStateEvent) evt).state() == IdleState.WRITER_IDLE) { + LOG.log(Level.FINE, "Sending WebSocket ping keepalive on {0}", ctx.channel()); + ctx.writeAndFlush(new PingWebSocketFrame()) + .addListener( + future -> { + if (!future.isSuccess()) { + // Channel is gone; close it so the session slot is released. + LOG.log( + Level.FINE, + "WebSocket ping failed on " + ctx.channel() + ", closing channel", + future.cause()); + ctx.close(); + } + }); + return; + } + super.userEventTriggered(ctx, evt); + } +} diff --git a/java/src/org/openqa/selenium/netty/server/WebSocketUpgradeHandler.java b/java/src/org/openqa/selenium/netty/server/WebSocketUpgradeHandler.java index d4917da5a3859..9a1547950da15 100644 --- a/java/src/org/openqa/selenium/netty/server/WebSocketUpgradeHandler.java +++ b/java/src/org/openqa/selenium/netty/server/WebSocketUpgradeHandler.java @@ -30,6 +30,7 @@ import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelPipeline; import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpHeaderNames; @@ -43,10 +44,12 @@ import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker; import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory; +import io.netty.handler.timeout.IdleStateHandler; import io.netty.util.AttributeKey; import java.util.Arrays; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.TimeUnit; import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.logging.Level; @@ -157,7 +160,27 @@ private void handleHttpRequest(ChannelHandlerContext ctx, HttpRequest req) { if (!future.isSuccess()) { ctx.fireExceptionCaught(future.cause()); } else { - ctx.channel().attr(key).setIfAbsent(maybeHandler.get()); + Consumer handler = maybeHandler.get(); + ctx.channel().attr(key).setIfAbsent(handler); + + // Install application-level keepalive for all WebSocket connections. + // Cloud LBs (AWS ALB: 60 s, k8s ingress-nginx: 60 s) silently drop idle + // TCP connections; OS-level SO_KEEPALIVE alone is not enough because + // most LBs ignore TCP keepalive probes. A WS ping every 30 s resets + // the LB's idle timer at the application level. + ChannelPipeline pipeline = ctx.pipeline(); + pipeline.addAfter( + "ws-protocol", + "ws-idle", + new IdleStateHandler( + 0, WebSocketKeepAliveHandler.PING_INTERVAL_SECONDS, 0, TimeUnit.SECONDS)); + pipeline.addAfter("ws-idle", "ws-keepalive", new WebSocketKeepAliveHandler()); + + // Allow the handler to rewire the pipeline now that the channel + // is fully in WebSocket mode (HTTP codec no longer active). + if (handler instanceof PostUpgradeHook) { + ((PostUpgradeHook) handler).onUpgradeComplete(ctx); + } } }); } diff --git a/java/test/org/openqa/selenium/grid/router/TunnelWebsocketTest.java b/java/test/org/openqa/selenium/grid/router/TunnelWebsocketTest.java index cc09ef5377027..625f78704c165 100644 --- a/java/test/org/openqa/selenium/grid/router/TunnelWebsocketTest.java +++ b/java/test/org/openqa/selenium/grid/router/TunnelWebsocketTest.java @@ -67,6 +67,7 @@ import org.openqa.selenium.remote.HttpSessionId; import org.openqa.selenium.remote.SessionId; import org.openqa.selenium.remote.http.BinaryMessage; +import org.openqa.selenium.remote.http.CloseMessage; import org.openqa.selenium.remote.http.HttpClient; import org.openqa.selenium.remote.http.HttpHandler; import org.openqa.selenium.remote.http.HttpRequest; @@ -632,4 +633,255 @@ public void onText(CharSequence data) { distributor.close(); bus.close(); } + + // --------------------------------------------------------------------------- + // ProxyWebsocketsIntoGrid + WebSocketFrameProxy (fallback path) tests + // + // These tests deliberately omit the TCP tunnel resolver so TcpUpgradeTunnelHandler + // is NOT installed. Every WebSocket upgrade goes through ProxyWebsocketsIntoGrid + // which, after the Netty handshake, rewires the pipeline via WebSocketFrameProxy. + // --------------------------------------------------------------------------- + + private Server createProxyRouter() { + // 3-arg constructor: no tcpTunnelResolver → no TcpUpgradeTunnelHandler in the pipeline. + return new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + new ProxyWebsocketsIntoGrid(HttpClient.Factory.createDefault(), sessions)) + .start(); + } + + @Test + void proxyPath_shouldForwardTextMessageToBackend() + throws URISyntaxException, InterruptedException { + AtomicReference received = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + + backendServer = createEchoBackend("", latch, received); + + SessionId id = new SessionId(UUID.randomUUID()); + sessions.add( + new Session( + id, + backendServer.getUrl().toURI(), + new ImmutableCapabilities(), + new ImmutableCapabilities(), + Instant.now())); + + tunnelServer = createProxyRouter(); + + HttpClient.Factory factory = HttpClient.Factory.createDefault(); + try (WebSocket socket = + factory + .createClient(tunnelServer.getUrl()) + .openSocket( + new HttpRequest(GET, "/session/" + id + "/bidi"), new WebSocket.Listener() {})) { + + socket.sendText("proxy-hello"); + + assertThat(latch.await(5, SECONDS)).isTrue(); + assertThat(received.get()).isEqualTo("proxy-hello"); + } + } + + @Test + void proxyPath_shouldForwardReplyFromBackendToClient() + throws URISyntaxException, InterruptedException { + backendServer = createEchoBackend("proxy-pong", new CountDownLatch(1), new AtomicReference<>()); + + SessionId id = new SessionId(UUID.randomUUID()); + sessions.add( + new Session( + id, + backendServer.getUrl().toURI(), + new ImmutableCapabilities(), + new ImmutableCapabilities(), + Instant.now())); + + tunnelServer = createProxyRouter(); + + HttpClient.Factory factory = HttpClient.Factory.createDefault(); + CountDownLatch latch = new CountDownLatch(1); + AtomicReference reply = new AtomicReference<>(); + + try (WebSocket socket = + factory + .createClient(tunnelServer.getUrl()) + .openSocket( + new HttpRequest(GET, "/session/" + id + "/bidi"), + new WebSocket.Listener() { + @Override + public void onText(CharSequence data) { + reply.set(data.toString()); + latch.countDown(); + } + })) { + + socket.sendText("proxy-ping"); + + assertThat(latch.await(5, SECONDS)).isTrue(); + assertThat(reply.get()).isEqualTo("proxy-pong"); + } + } + + @Test + void proxyPath_shouldForwardBinaryMessages() throws URISyntaxException, InterruptedException { + byte[] payload = new byte[] {10, 20, 30, 40}; + + AtomicReference received = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + + backendServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> + Optional.of( + msg -> { + if (msg instanceof BinaryMessage) { + received.set(((BinaryMessage) msg).data()); + latch.countDown(); + } + })) + .start(); + + SessionId id = new SessionId(UUID.randomUUID()); + sessions.add( + new Session( + id, + backendServer.getUrl().toURI(), + new ImmutableCapabilities(), + new ImmutableCapabilities(), + Instant.now())); + + tunnelServer = createProxyRouter(); + + HttpClient.Factory factory = HttpClient.Factory.createDefault(); + try (WebSocket socket = + factory + .createClient(tunnelServer.getUrl()) + .openSocket( + new HttpRequest(GET, "/session/" + id + "/bidi"), new WebSocket.Listener() {})) { + + socket.sendBinary(payload); + + assertThat(latch.await(5, SECONDS)).isTrue(); + assertThat(received.get()).isEqualTo(payload); + } + } + + @Test + void proxyPath_shouldSupportMultipleMessagesAndBidirectionalFlow() + throws URISyntaxException, InterruptedException { + // Backend echoes every text message back with a ">" prefix to distinguish direction. + int messageCount = 5; + CountDownLatch backendLatch = new CountDownLatch(messageCount); + CountDownLatch clientLatch = new CountDownLatch(messageCount); + + backendServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> + Optional.of( + msg -> { + if (msg instanceof TextMessage) { + backendLatch.countDown(); + sink.accept(new TextMessage(">" + ((TextMessage) msg).text())); + } + })) + .start(); + + SessionId id = new SessionId(UUID.randomUUID()); + sessions.add( + new Session( + id, + backendServer.getUrl().toURI(), + new ImmutableCapabilities(), + new ImmutableCapabilities(), + Instant.now())); + + tunnelServer = createProxyRouter(); + + HttpClient.Factory factory = HttpClient.Factory.createDefault(); + try (WebSocket socket = + factory + .createClient(tunnelServer.getUrl()) + .openSocket( + new HttpRequest(GET, "/session/" + id + "/bidi"), + new WebSocket.Listener() { + @Override + public void onText(CharSequence data) { + clientLatch.countDown(); + } + })) { + + for (int i = 0; i < messageCount; i++) { + socket.sendText("msg-" + i); + } + + assertThat(backendLatch.await(10, SECONDS)).as("backend received all messages").isTrue(); + assertThat(clientLatch.await(10, SECONDS)).as("client received all replies").isTrue(); + } + } + + /** + * Regression test for node-initiated close in proxy path. + * + *

After the Netty pipeline is rewired by {@code WebSocketFrameProxy} (i.e. {@code + * MessageOutboundConverter} is removed), a close message sent by the backend must still reach the + * client as a proper WebSocket close frame — not silently dropped. + */ + @Test + void proxyPath_shouldRelayNodeInitiatedClose() throws URISyntaxException, InterruptedException { + CountDownLatch closeLatch = new CountDownLatch(1); + AtomicReference closeCode = new AtomicReference<>(); + + // Backend sends one text message, then immediately closes the connection. + backendServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> + Optional.of( + msg -> { + if (msg instanceof TextMessage) { + sink.accept(new CloseMessage(1000, "done")); + } + })) + .start(); + + SessionId id = new SessionId(UUID.randomUUID()); + sessions.add( + new Session( + id, + backendServer.getUrl().toURI(), + new ImmutableCapabilities(), + new ImmutableCapabilities(), + Instant.now())); + + tunnelServer = createProxyRouter(); + + HttpClient.Factory factory = HttpClient.Factory.createDefault(); + try (WebSocket socket = + factory + .createClient(tunnelServer.getUrl()) + .openSocket( + new HttpRequest(GET, "/session/" + id + "/bidi"), + new WebSocket.Listener() { + @Override + public void onClose(int code, String reason) { + closeCode.set(code); + closeLatch.countDown(); + } + })) { + + socket.sendText("trigger-close"); + + assertThat(closeLatch.await(5, SECONDS)) + .as("client should receive close frame from node") + .isTrue(); + assertThat(closeCode.get()).isEqualTo(1000); + } + } } From 10296a097d17b3b3177a46c58c24fb26c16f5585 Mon Sep 17 00:00:00 2001 From: Viet Nguyen Duc Date: Wed, 11 Mar 2026 12:22:23 +0700 Subject: [PATCH 60/67] [build] Fix Lint Format CI (#17202) Signed-off-by: Viet Nguyen Duc --- .../webdriver/common/bidi_input_tests.py | 90 +++++-------------- .../common/bidi_integration_tests.py | 12 +-- .../webdriver/common/bidi_log_tests.py | 4 +- .../webdriver/common/bidi_script_tests.py | 60 ++++--------- .../webdriver/common/bidi_storage_tests.py | 44 +++------ .../common/bidi_webextension_tests.py | 47 +++------- 6 files changed, 60 insertions(+), 197 deletions(-) diff --git a/py/test/selenium/webdriver/common/bidi_input_tests.py b/py/test/selenium/webdriver/common/bidi_input_tests.py index 9929a01117924..c415cc2e69b49 100644 --- a/py/test/selenium/webdriver/common/bidi_input_tests.py +++ b/py/test/selenium/webdriver/common/bidi_input_tests.py @@ -74,9 +74,7 @@ def test_basic_key_input(driver, pages): driver.input.perform_actions(driver.current_window_handle, [key_actions]) - WebDriverWait(driver, 5).until( - lambda d: input_element.get_attribute("value") == "hello" - ) + WebDriverWait(driver, 5).until(lambda d: input_element.get_attribute("value") == "hello") assert input_element.get_attribute("value") == "hello" @@ -100,9 +98,7 @@ def test_key_input_with_pause(driver, pages): driver.input.perform_actions(driver.current_window_handle, [key_actions]) - WebDriverWait(driver, 5).until( - lambda d: input_element.get_attribute("value") == "ab" - ) + WebDriverWait(driver, 5).until(lambda d: input_element.get_attribute("value") == "ab") assert input_element.get_attribute("value") == "ab" @@ -208,11 +204,7 @@ def test_wheel_scroll(driver, pages): # Scroll down wheel_actions = WheelSourceActions( id="wheel", - actions=[ - WheelScrollAction( - x=100, y=100, delta_x=0, delta_y=100, origin=Origin.VIEWPORT - ) - ], + actions=[WheelScrollAction(x=100, y=100, delta_x=0, delta_y=100, origin=Origin.VIEWPORT)], ) driver.input.perform_actions(driver.current_window_handle, [wheel_actions]) @@ -263,13 +255,9 @@ def test_combined_input_actions(driver, pages): ], ) - driver.input.perform_actions( - driver.current_window_handle, [pointer_actions, key_actions] - ) + driver.input.perform_actions(driver.current_window_handle, [pointer_actions, key_actions]) - WebDriverWait(driver, 5).until( - lambda d: input_element.get_attribute("value") == "test" - ) + WebDriverWait(driver, 5).until(lambda d: input_element.get_attribute("value") == "test") assert input_element.get_attribute("value") == "test" @@ -281,9 +269,7 @@ def test_set_files(driver, pages): assert upload_element.get_attribute("value") == "" # Create a temporary file - with tempfile.NamedTemporaryFile( - mode="w", suffix=".txt", delete=False - ) as temp_file: + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as temp_file: temp_file.write("test content") temp_file_path = temp_file.name @@ -293,9 +279,7 @@ def test_set_files(driver, pages): element_ref = {"sharedId": element_id} # Set files using BiDi - driver.input.set_files( - driver.current_window_handle, element_ref, [temp_file_path] - ) + driver.input.set_files(driver.current_window_handle, element_ref, [temp_file_path]) # Verify file was set value = upload_element.get_attribute("value") @@ -370,9 +354,7 @@ def test_release_actions(driver, pages): driver.input.perform_actions(driver.current_window_handle, [key_actions2]) # Should be able to type normally - WebDriverWait(driver, 5).until( - lambda d: "b" in input_element.get_attribute("value") - ) + WebDriverWait(driver, 5).until(lambda d: "b" in input_element.get_attribute("value")) @pytest.mark.parametrize("multiple", [True, False]) @@ -388,9 +370,7 @@ def file_dialog_handler(file_dialog_info): handler_id = driver.input.add_file_dialog_handler(file_dialog_handler) assert handler_id is not None - driver.get( - f"data:text/html," - ) + driver.get(f"data:text/html,") # Use script.evaluate to trigger the file dialog with user activation driver.script._evaluate( @@ -490,9 +470,7 @@ def test_perform_actions_rapid_key_sequence(driver, pages): driver.input.perform_actions(driver.current_window_handle, [key_actions]) - WebDriverWait(driver, 5).until( - lambda d: input_element.get_attribute("value") == "abcd" - ) + WebDriverWait(driver, 5).until(lambda d: input_element.get_attribute("value") == "abcd") assert input_element.get_attribute("value") == "abcd" @@ -617,11 +595,7 @@ def test_wheel_scroll_negative_delta(driver, pages): # First scroll down wheel_actions_down = WheelSourceActions( id="wheel_down", - actions=[ - WheelScrollAction( - x=100, y=100, delta_x=0, delta_y=100, origin=Origin.VIEWPORT - ) - ], + actions=[WheelScrollAction(x=100, y=100, delta_x=0, delta_y=100, origin=Origin.VIEWPORT)], ) driver.input.perform_actions(driver.current_window_handle, [wheel_actions_down]) @@ -632,11 +606,7 @@ def test_wheel_scroll_negative_delta(driver, pages): # Then scroll back up (negative delta) wheel_actions_up = WheelSourceActions( id="wheel_up", - actions=[ - WheelScrollAction( - x=100, y=100, delta_x=0, delta_y=-50, origin=Origin.VIEWPORT - ) - ], + actions=[WheelScrollAction(x=100, y=100, delta_x=0, delta_y=-50, origin=Origin.VIEWPORT)], ) driver.input.perform_actions(driver.current_window_handle, [wheel_actions_up]) @@ -676,11 +646,7 @@ def test_wheel_scroll_horizontal(driver, pages): # Scroll horizontally wheel_actions = WheelSourceActions( id="wheel", - actions=[ - WheelScrollAction( - x=100, y=100, delta_x=50, delta_y=0, origin=Origin.VIEWPORT - ) - ], + actions=[WheelScrollAction(x=100, y=100, delta_x=50, delta_y=0, origin=Origin.VIEWPORT)], ) driver.input.perform_actions(driver.current_window_handle, [wheel_actions]) @@ -711,9 +677,7 @@ def test_key_input_special_characters(driver, pages): driver.input.perform_actions(driver.current_window_handle, [key_actions]) - WebDriverWait(driver, 5).until( - lambda d: "!" in input_element.get_attribute("value") - ) + WebDriverWait(driver, 5).until(lambda d: "!" in input_element.get_attribute("value")) def test_set_files_empty_file_list(driver, pages): @@ -741,9 +705,7 @@ def test_set_files_with_absolute_path(driver): upload_element = driver.find_element(By.ID, "upload") # Create a temporary file - with tempfile.NamedTemporaryFile( - mode="w", suffix=".txt", delete=False - ) as temp_file: + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as temp_file: temp_file.write("test file content") temp_file_path = temp_file.name @@ -753,9 +715,7 @@ def test_set_files_with_absolute_path(driver): element_ref = {"sharedId": element_id} # Set file using absolute path - driver.input.set_files( - driver.current_window_handle, element_ref, [temp_file_path] - ) + driver.input.set_files(driver.current_window_handle, element_ref, [temp_file_path]) value = upload_element.get_attribute("value") assert os.path.basename(temp_file_path) in value @@ -882,15 +842,11 @@ def test_combined_keyboard_and_wheel_actions(driver, pages): id="wheel", actions=[ PauseAction(duration=0), # Sync with keyboard - WheelScrollAction( - x=100, y=100, delta_x=0, delta_y=100, origin=Origin.VIEWPORT - ), + WheelScrollAction(x=100, y=100, delta_x=0, delta_y=100, origin=Origin.VIEWPORT), ], ) - driver.input.perform_actions( - driver.current_window_handle, [key_actions, wheel_actions] - ) + driver.input.perform_actions(driver.current_window_handle, [key_actions, wheel_actions]) scroll_y = driver.execute_script("return window.pageYOffset;") assert scroll_y == 100 @@ -917,9 +873,7 @@ def test_key_input_with_value_attribute(driver, pages): driver.input.perform_actions(driver.current_window_handle, [key_actions]) - WebDriverWait(driver, 5).until( - lambda d: input_element.get_attribute("value") == "xyz" - ) + WebDriverWait(driver, 5).until(lambda d: input_element.get_attribute("value") == "xyz") assert input_element.get_attribute("value") == "xyz" @@ -936,11 +890,7 @@ def test_wheel_scroll_with_element_origin(driver, pages): # Scroll with element origin wheel_actions = WheelSourceActions( id="wheel", - actions=[ - WheelScrollAction( - x=100, y=100, delta_x=0, delta_y=100, origin=element_origin - ) - ], + actions=[WheelScrollAction(x=100, y=100, delta_x=0, delta_y=100, origin=element_origin)], ) driver.input.perform_actions(driver.current_window_handle, [wheel_actions]) diff --git a/py/test/selenium/webdriver/common/bidi_integration_tests.py b/py/test/selenium/webdriver/common/bidi_integration_tests.py index 85323f49b3ccf..18098d2e7cb2a 100644 --- a/py/test/selenium/webdriver/common/bidi_integration_tests.py +++ b/py/test/selenium/webdriver/common/bidi_integration_tests.py @@ -175,9 +175,7 @@ def test_multiple_console_handlers(self, driver): try: driver.execute_script("console.log('test message');") - WebDriverWait(driver, 5).until( - lambda _: len(messages1) > 0 and len(messages2) > 0 - ) + WebDriverWait(driver, 5).until(lambda _: len(messages1) > 0 and len(messages2) > 0) assert len(messages1) > 0 assert len(messages2) > 0 @@ -220,9 +218,7 @@ def test_cookie_attributes(self, driver, pages): """Test cookie with various attributes.""" pages.load("blank.html") - driver.add_cookie( - {"name": "attr_cookie", "value": "test_value", "path": "/", "secure": False} - ) + driver.add_cookie({"name": "attr_cookie", "value": "test_value", "path": "/", "secure": False}) cookies = driver.get_cookies() cookie = next((c for c in cookies if c.get("name") == "attr_cookie"), None) @@ -257,9 +253,7 @@ def test_navigation_in_context(self, driver, pages): pages.load("blank.html") # Navigate using the BiDi API with the current context - driver.browsing_context.navigate( - context=driver.current_window_handle, url=pages.url("blank.html") - ) + driver.browsing_context.navigate(context=driver.current_window_handle, url=pages.url("blank.html")) # Verify page loaded element = driver.find_element(By.TAG_NAME, "body") diff --git a/py/test/selenium/webdriver/common/bidi_log_tests.py b/py/test/selenium/webdriver/common/bidi_log_tests.py index fbbd3a8166b2d..540100098ec16 100644 --- a/py/test/selenium/webdriver/common/bidi_log_tests.py +++ b/py/test/selenium/webdriver/common/bidi_log_tests.py @@ -82,9 +82,7 @@ def test_add_and_remove_handler(self, driver): try: driver.execute_script("console.log('first message');") - WebDriverWait(driver, 5).until( - lambda _: len(log_entries1) > 0 and len(log_entries2) > 0 - ) + WebDriverWait(driver, 5).until(lambda _: len(log_entries1) > 0 and len(log_entries2) > 0) assert len(log_entries1) > 0 assert len(log_entries2) > 0 diff --git a/py/test/selenium/webdriver/common/bidi_script_tests.py b/py/test/selenium/webdriver/common/bidi_script_tests.py index 5d5fc6ef9b780..d7ddf6d5d6882 100644 --- a/py/test/selenium/webdriver/common/bidi_script_tests.py +++ b/py/test/selenium/webdriver/common/bidi_script_tests.py @@ -115,9 +115,7 @@ def test_removes_console_message_handler(driver, pages): try: driver.find_element(By.ID, "consoleLog").click() - WebDriverWait(driver, 5).until( - lambda _: len(log_entries1) and len(log_entries2) - ) + WebDriverWait(driver, 5).until(lambda _: len(log_entries1) and len(log_entries2)) driver.script.remove_console_message_handler(id1) driver.find_element(By.ID, "consoleLog").click() @@ -158,9 +156,7 @@ def test_removes_javascript_message_handler(driver, pages): try: driver.find_element(By.ID, "jsException").click() - WebDriverWait(driver, 5).until( - lambda _: len(log_entries1) and len(log_entries2) - ) + WebDriverWait(driver, 5).until(lambda _: len(log_entries1) and len(log_entries2)) driver.script.remove_javascript_error_handler(id1) driver.find_element(By.ID, "jsException").click() @@ -196,13 +192,9 @@ def test_add_preload_script_with_arguments(driver, pages): """Test adding a preload script with channel arguments.""" function_declaration = "(channelFunc) => { channelFunc('test_value'); window.preloadValue = 'received'; }" - arguments = [ - {"type": "channel", "value": {"channel": "test-channel", "ownership": "root"}} - ] + arguments = [{"type": "channel", "value": {"channel": "test-channel", "ownership": "root"}}] - script_id = driver.script._add_preload_script( - function_declaration, arguments=arguments - ) + script_id = driver.script._add_preload_script(function_declaration, arguments=arguments) assert script_id is not None pages.load("blank.html") @@ -220,9 +212,7 @@ def test_add_preload_script_with_contexts(driver, pages): function_declaration = "() => { window.contextSpecific = true; }" contexts = [driver.current_window_handle] - script_id = driver.script._add_preload_script( - function_declaration, contexts=contexts - ) + script_id = driver.script._add_preload_script(function_declaration, contexts=contexts) assert script_id is not None pages.load("blank.html") @@ -247,9 +237,7 @@ def test_add_preload_script_with_user_contexts(driver, pages): try: user_contexts = [user_context] - script_id = driver.script._add_preload_script( - function_declaration, user_contexts=user_contexts - ) + script_id = driver.script._add_preload_script(function_declaration, user_contexts=user_contexts) assert script_id is not None pages.load("blank.html") @@ -270,9 +258,7 @@ def test_add_preload_script_with_sandbox(driver, pages): """Test adding a preload script with sandbox.""" function_declaration = "() => { window.sandboxScript = true; }" - script_id = driver.script._add_preload_script( - function_declaration, sandbox="test-sandbox" - ) + script_id = driver.script._add_preload_script(function_declaration, sandbox="test-sandbox") assert script_id is not None pages.load("blank.html") @@ -298,12 +284,8 @@ def test_add_preload_script_invalid_arguments(driver): """Test that providing both contexts and user_contexts raises an error.""" function_declaration = "() => {}" - with pytest.raises( - ValueError, match="Cannot specify both contexts and user_contexts" - ): - driver.script._add_preload_script( - function_declaration, contexts=["context1"], user_contexts=["user1"] - ) + with pytest.raises(ValueError, match="Cannot specify both contexts and user_contexts"): + driver.script._add_preload_script(function_declaration, contexts=["context1"], user_contexts=["user1"]) def test_remove_preload_script(driver, pages): @@ -329,9 +311,7 @@ def test_evaluate_expression(driver, pages): """Test evaluating a simple expression.""" pages.load("blank.html") - result = driver.script._evaluate( - "1 + 2", {"context": driver.current_window_handle}, await_promise=False - ) + result = driver.script._evaluate("1 + 2", {"context": driver.current_window_handle}, await_promise=False) assert result.realm is not None assert result.result["type"] == "number" @@ -630,9 +610,7 @@ def test_disown_handles(driver, pages): assert result_before.result["value"] == "bar" # Disown the handle - driver.script._disown( - handles=[handle], target={"context": driver.current_window_handle} - ) + driver.script._disown(handles=[handle], target={"context": driver.current_window_handle}) # Try using the disowned handle (this should fail) with pytest.raises(Exception): @@ -994,9 +972,7 @@ def test_execute_script_returns_array(self, driver): def test_execute_script_dom_query(self, driver, pages): """Test executing script that queries DOM.""" pages.load("formPage.html") - result = driver.execute_script( - "return document.querySelectorAll('input').length;" - ) + result = driver.execute_script("return document.querySelectorAll('input').length;") assert result > 0 def test_execute_script_with_arguments(self, driver): @@ -1112,9 +1088,7 @@ def test_multiple_preload_scripts(self, driver, pages): def test_preload_script_with_function(self, driver, pages): """Test preload script defining functions.""" - script_id = driver.script._add_preload_script( - "() => { window.customFunc = (x) => x * 2; }" - ) + script_id = driver.script._add_preload_script("() => { window.customFunc = (x) => x * 2; }") try: pages.load("blank.html") @@ -1129,9 +1103,7 @@ def test_preload_script_with_function(self, driver, pages): def test_preload_script_removal_prevents_execution(self, driver, pages): """Test that removing preload script prevents its execution.""" - script_id = driver.script._add_preload_script( - "() => { window.shouldNotExist = true; }" - ) + script_id = driver.script._add_preload_script("() => { window.shouldNotExist = true; }") driver.script._remove_preload_script(script_id=script_id) pages.load("blank.html") @@ -1356,9 +1328,7 @@ def test_multiple_error_handlers(self, driver, pages): driver.find_element(By.ID, "jsException").click() # Both handlers should receive events when error occurs - WebDriverWait(driver, 5).until( - lambda _: len(errors1) > 0 and len(errors2) > 0 - ) + WebDriverWait(driver, 5).until(lambda _: len(errors1) > 0 and len(errors2) > 0) assert len(errors1) > 0 assert len(errors2) > 0 finally: diff --git a/py/test/selenium/webdriver/common/bidi_storage_tests.py b/py/test/selenium/webdriver/common/bidi_storage_tests.py index 157df78dd3cd9..10d96789fcf3a 100644 --- a/py/test/selenium/webdriver/common/bidi_storage_tests.py +++ b/py/test/selenium/webdriver/common/bidi_storage_tests.py @@ -98,9 +98,7 @@ def test_get_cookie_by_name(self, driver, pages, webserver): driver.add_cookie({"name": key, "value": value}) # Test - cookie_filter = CookieFilter( - name=key, value=BytesValue(BytesValue.TYPE_STRING, "set") - ) + cookie_filter = CookieFilter(name=key, value=BytesValue(BytesValue.TYPE_STRING, "set")) result = driver.storage.get_cookies(filter=cookie_filter) @@ -122,18 +120,14 @@ def test_get_cookie_in_default_user_context(self, driver, pages, webserver): driver.add_cookie({"name": key, "value": value}) # Test - cookie_filter = CookieFilter( - name=key, value=BytesValue(BytesValue.TYPE_STRING, "set") - ) + cookie_filter = CookieFilter(name=key, value=BytesValue(BytesValue.TYPE_STRING, "set")) driver.switch_to.new_window(WindowTypes.WINDOW) descriptor = BrowsingContextPartitionDescriptor(driver.current_window_handle) params = cookie_filter - result_after_switching_context = driver.storage.get_cookies( - filter=params, partition=descriptor - ) + result_after_switching_context = driver.storage.get_cookies(filter=params, partition=descriptor) assert len(result_after_switching_context.cookies) > 0 assert result_after_switching_context.cookies[0].value.value == value @@ -164,21 +158,15 @@ def test_get_cookie_in_a_user_context(self, driver, pages, webserver): descriptor = StorageKeyPartitionDescriptor(user_context=user_context) - parameters = PartialCookie( - key, BytesValue(BytesValue.TYPE_STRING, value), webserver.host - ) + parameters = PartialCookie(key, BytesValue(BytesValue.TYPE_STRING, value), webserver.host) driver.storage.set_cookie(cookie=parameters, partition=descriptor) # Test - cookie_filter = CookieFilter( - name=key, value=BytesValue(BytesValue.TYPE_STRING, "set") - ) + cookie_filter = CookieFilter(name=key, value=BytesValue(BytesValue.TYPE_STRING, "set")) # Create a new window with the user context - new_window = driver.browsing_context.create( - type=WindowTypes.TAB, user_context=user_context - ) + new_window = driver.browsing_context.create(type=WindowTypes.TAB, user_context=user_context) driver.switch_to.window(new_window) @@ -193,13 +181,9 @@ def test_get_cookie_in_a_user_context(self, driver, pages, webserver): driver.switch_to.window(window_handle) - browsing_context_partition_descriptor = BrowsingContextPartitionDescriptor( - window_handle - ) + browsing_context_partition_descriptor = BrowsingContextPartitionDescriptor(window_handle) - result1 = driver.storage.get_cookies( - filter=cookie_filter, partition=browsing_context_partition_descriptor - ) + result1 = driver.storage.get_cookies(filter=cookie_filter, partition=browsing_context_partition_descriptor) assert len(result1.cookies) == 0 @@ -214,9 +198,7 @@ def test_add_cookie(self, driver, pages, webserver): key = generate_unique_key() value = "foo" - parameters = PartialCookie( - key, BytesValue(BytesValue.TYPE_STRING, value), webserver.host - ) + parameters = PartialCookie(key, BytesValue(BytesValue.TYPE_STRING, value), webserver.host) assert_cookie_is_not_present_with_name(driver, key) # Test @@ -532,9 +514,7 @@ def test_set_cookie_with_same_site_none(self, driver, pages, webserver): value = BytesValue(BytesValue.TYPE_STRING, "none") # SameSite=None typically requires secure=True - cookie = PartialCookie( - key, value, webserver.host, same_site=SameSite.NONE, secure=True - ) + cookie = PartialCookie(key, value, webserver.host, same_site=SameSite.NONE, secure=True) # Test driver.storage.set_cookie(cookie=cookie) @@ -691,9 +671,7 @@ def test_delete_cookie_by_path(self, driver, pages, webserver): result = driver.storage.get_cookies(filter=CookieFilter()) cookie_names = [c.name for c in result.cookies] - assert key1 not in cookie_names or all( - c.path != "/simpleTest.html" for c in result.cookies if c.name == key1 - ) + assert key1 not in cookie_names or all(c.path != "/simpleTest.html" for c in result.cookies if c.name == key1) def test_cookie_expiry_timestamp(self, driver, pages, webserver): """Test that cookie expiry is stored correctly as timestamp.""" diff --git a/py/test/selenium/webdriver/common/bidi_webextension_tests.py b/py/test/selenium/webdriver/common/bidi_webextension_tests.py index 93c6c5a1d5528..ef27c24481ad0 100644 --- a/py/test/selenium/webdriver/common/bidi_webextension_tests.py +++ b/py/test/selenium/webdriver/common/bidi_webextension_tests.py @@ -179,9 +179,7 @@ def test_install_with_extension_id_uninstall(self, chromium_driver): ext_info = chromium_driver.webextension.install(path=path) extension_id = ext_info.get("extension") # Uninstall using the extension ID - uninstall_extension_and_verify_extension_uninstalled( - chromium_driver, extension_id - ) + uninstall_extension_and_verify_extension_uninstalled(chromium_driver, extension_id) # Additional edge case tests for better WPT coverage @@ -251,10 +249,7 @@ def test_extension_content_script_injection(self, driver, pages): ) assert injected_element is not None - assert ( - "Content injected by webextensions-selenium-example" - in injected_element.text - ) + assert "Content injected by webextensions-selenium-example" in injected_element.text # Cleanup driver.webextension.uninstall(ext_info) @@ -268,9 +263,7 @@ def test_uninstall_removes_content_scripts(self, driver, pages): # Verify injection works pages.load("blank.html") - WebDriverWait(driver, timeout=5).until( - lambda dr: dr.find_element(By.ID, "webextensions-selenium-example") - ) + WebDriverWait(driver, timeout=5).until(lambda dr: dr.find_element(By.ID, "webextensions-selenium-example")) # Uninstall driver.webextension.uninstall(ext_info) @@ -379,10 +372,7 @@ def test_uninstall_extension_by_id_string(self, chromium_driver, pages_chromium) # Verify uninstall was successful chromium_driver.browsing_context.reload(chromium_driver.current_window_handle) - assert ( - len(chromium_driver.find_elements(By.ID, "webextensions-selenium-example")) - == 0 - ) + assert len(chromium_driver.find_elements(By.ID, "webextensions-selenium-example")) == 0 @pytest.mark.xfail_firefox def test_uninstall_extension_by_result_dict(self, chromium_driver, pages_chromium): @@ -395,10 +385,7 @@ def test_uninstall_extension_by_result_dict(self, chromium_driver, pages_chromiu # Verify uninstall was successful chromium_driver.browsing_context.reload(chromium_driver.current_window_handle) - assert ( - len(chromium_driver.find_elements(By.ID, "webextensions-selenium-example")) - == 0 - ) + assert len(chromium_driver.find_elements(By.ID, "webextensions-selenium-example")) == 0 @pytest.mark.xfail_firefox def test_install_returns_extension_id(self, chromium_driver, pages_chromium): @@ -429,10 +416,7 @@ def test_extension_content_script_injection(self, chromium_driver, pages_chromiu ) assert injected_element is not None - assert ( - "Content injected by webextensions-selenium-example" - in injected_element.text - ) + assert "Content injected by webextensions-selenium-example" in injected_element.text # Cleanup chromium_driver.webextension.uninstall(ext_info) @@ -454,15 +438,10 @@ def test_uninstall_removes_content_scripts(self, chromium_driver, pages_chromium # Reload page and verify injection is gone chromium_driver.browsing_context.reload(chromium_driver.current_window_handle) - assert ( - len(chromium_driver.find_elements(By.ID, "webextensions-selenium-example")) - == 0 - ) + assert len(chromium_driver.find_elements(By.ID, "webextensions-selenium-example")) == 0 @pytest.mark.xfail_firefox - def test_multiple_installations_and_uninstalls( - self, chromium_driver, pages_chromium - ): + def test_multiple_installations_and_uninstalls(self, chromium_driver, pages_chromium): """Test installing and uninstalling extension multiple times.""" path = os.path.join(EXTENSIONS, EXTENSION_PATH) @@ -471,17 +450,11 @@ def test_multiple_installations_and_uninstalls( verify_extension_injection(chromium_driver, pages_chromium) chromium_driver.webextension.uninstall(ext_info_1) chromium_driver.browsing_context.reload(chromium_driver.current_window_handle) - assert ( - len(chromium_driver.find_elements(By.ID, "webextensions-selenium-example")) - == 0 - ) + assert len(chromium_driver.find_elements(By.ID, "webextensions-selenium-example")) == 0 # Install/uninstall cycle 2 ext_info_2 = chromium_driver.webextension.install(path=path) verify_extension_injection(chromium_driver, pages_chromium) chromium_driver.webextension.uninstall(ext_info_2) chromium_driver.browsing_context.reload(chromium_driver.current_window_handle) - assert ( - len(chromium_driver.find_elements(By.ID, "webextensions-selenium-example")) - == 0 - ) + assert len(chromium_driver.find_elements(By.ID, "webextensions-selenium-example")) == 0 From dd6597f35513d5c34f87a30ac5815c69df378ae6 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 11 Mar 2026 11:27:05 +0000 Subject: [PATCH 61/67] correct checks on method arguments --- py/generate_bidi.py | 151 +++++++++++------- py/selenium/webdriver/common/bidi/browser.py | 9 +- .../webdriver/common/bidi/browsing_context.py | 36 +++++ .../webdriver/common/bidi/emulation.py | 30 ++++ py/selenium/webdriver/common/bidi/input.py | 15 ++ py/selenium/webdriver/common/bidi/network.py | 69 ++++++++ py/selenium/webdriver/common/bidi/script.py | 32 ++++ py/selenium/webdriver/common/bidi/session.py | 6 + py/selenium/webdriver/common/bidi/storage.py | 3 + 9 files changed, 292 insertions(+), 59 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index a53ea96db7481..78a7603b929c0 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -18,12 +18,11 @@ import logging import re import sys -from collections import defaultdict from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from textwrap import dedent, indent as tw_indent -from typing import Any, Dict, List, Optional, Set, Tuple +from textwrap import indent as tw_indent +from typing import Any __version__ = "1.0.0" @@ -53,7 +52,7 @@ def indent(s: str, n: int) -> str: return tw_indent(s, n * " ") -def load_enhancements_manifest(manifest_path: Optional[str]) -> Dict[str, Any]: +def load_enhancements_manifest(manifest_path: str | None) -> dict[str, Any]: """Load enhancement manifest from a Python file. Args: @@ -139,11 +138,12 @@ class CddlCommand: module: str name: str - params: Dict[str, str] = field(default_factory=dict) - result: Optional[str] = None + params: dict[str, str] = field(default_factory=dict) + required_params: set[str] = field(default_factory=set) + result: str | None = None description: str = "" - def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: """Generate Python method code for this command. Args: @@ -178,7 +178,17 @@ def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str body = f" def {method_name}({param_list}):\n" body += f' """{self.description or "Execute " + self.module + "." + self.name}."""\n' - # Add validation if specified + # Add automatic validation for required parameters + # (This is used unless there's no required_params, in which case all params are optional) + if self.required_params: + for param_name, snake_param in param_names: + if param_name in self.required_params: + method_snake = self._camel_to_snake(self.name) + body += f" if {snake_param} is None:\n" + body += f' raise TypeError("{method_snake}() missing required argument: {snake_param!r}")\n' + body += "\n" + + # Add validation if specified in enhancements (for additional business logic validation) if "validate" in enhancements: validate_func = enhancements["validate"] # Build parameter list for validation function @@ -264,45 +274,45 @@ def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str # Extract property from list items body += f' if result and "{extract_field}" in result:\n' body += f' items = result.get("{extract_field}", [])\n' - body += f" return [\n" + body += " return [\n" body += f' item.get("{extract_property}")\n' - body += f" for item in items\n" - body += f" if isinstance(item, dict)\n" - body += f" ]\n" - body += f" return []\n" + body += " for item in items\n" + body += " if isinstance(item, dict)\n" + body += " ]\n" + body += " return []\n" elif extract_field in deserialize_rules: # Extract field and deserialize to typed objects type_name = deserialize_rules[extract_field] body += f' if result and "{extract_field}" in result:\n' body += f' items = result.get("{extract_field}", [])\n' - body += f" return [\n" + body += " return [\n" body += f" {type_name}(\n" body += self._generate_field_args(extract_field, type_name) - body += f" )\n" - body += f" for item in items\n" - body += f" if isinstance(item, dict)\n" - body += f" ]\n" - body += f" return []\n" + body += " )\n" + body += " for item in items\n" + body += " if isinstance(item, dict)\n" + body += " ]\n" + body += " return []\n" else: # Simple field extraction (return the value directly, not wrapped in result dict) body += f' if result and "{extract_field}" in result:\n' body += f' extracted = result.get("{extract_field}")\n' - body += f" return extracted\n" - body += f" return result\n" + body += " return extracted\n" + body += " return result\n" elif "deserialize" in enhancements: # Deserialize response to typed objects (legacy, without extract_field) deserialize_rules = enhancements["deserialize"] for response_field, type_name in deserialize_rules.items(): body += f' if result and "{response_field}" in result:\n' body += f' items = result.get("{response_field}", [])\n' - body += f" return [\n" + body += " return [\n" body += f" {type_name}(\n" body += self._generate_field_args(response_field, type_name) - body += f" )\n" - body += f" for item in items\n" - body += f" if isinstance(item, dict)\n" - body += f" ]\n" - body += f" return []\n" + body += " )\n" + body += " for item in items\n" + body += " if isinstance(item, dict)\n" + body += " ]\n" + body += " return []\n" else: # No special response handling, just return the result body += " return result\n" @@ -351,10 +361,10 @@ class CddlTypeDefinition: module: str name: str - fields: Dict[str, str] = field(default_factory=dict) + fields: dict[str, str] = field(default_factory=dict) description: str = "" - def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str: """Generate Python dataclass code for this type. Args: @@ -366,7 +376,7 @@ def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> # Generate class name from type name (keep it as-is, don't split on underscores) class_name = self.name - code = f"@dataclass\n" + code = "@dataclass\n" code += f"class {class_name}:\n" code += f' """{self.description or self.name}."""\n\n' @@ -460,7 +470,7 @@ class CddlEnum: module: str name: str - values: List[str] = field(default_factory=list) + values: list[str] = field(default_factory=list) description: str = "" def to_python_class(self) -> str: @@ -537,10 +547,10 @@ class CddlModule: """Represents a CDDL module (e.g., script, network, browsing_context).""" name: str - commands: List[CddlCommand] = field(default_factory=list) - types: List[CddlTypeDefinition] = field(default_factory=list) - enums: List[CddlEnum] = field(default_factory=list) - events: List[CddlEvent] = field(default_factory=list) + commands: list[CddlCommand] = field(default_factory=list) + types: list[CddlTypeDefinition] = field(default_factory=list) + enums: list[CddlEnum] = field(default_factory=list) + events: list[CddlEvent] = field(default_factory=list) @staticmethod def _convert_method_to_event_name(method_suffix: str) -> str: @@ -555,7 +565,7 @@ def _convert_method_to_event_name(method_suffix: str) -> str: s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", method_suffix) return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() - def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """Generate Python code for this module. Args: @@ -1007,9 +1017,9 @@ def clear_event_handlers(self) -> None: code += "\n" # Now populate EVENT_CONFIGS after the aliases are defined - code += f"\n# Populate EVENT_CONFIGS with event configuration mappings\n" + code += "\n# Populate EVENT_CONFIGS with event configuration mappings\n" # Use globals() to look up types dynamically to handle missing types gracefully - code += f"_globals = globals()\n" + code += "_globals = globals()\n" code += f"{class_name}.EVENT_CONFIGS = {{\n" for event_def in self.events: # Convert method name to user-friendly event name @@ -1037,9 +1047,9 @@ def __init__(self, cddl_path: str): """Initialize parser with CDDL file path.""" self.cddl_path = Path(cddl_path) self.content = "" - self.modules: Dict[str, CddlModule] = {} - self.definitions: Dict[str, str] = {} - self.event_names: Set[str] = set() # Names of definitions that are events + self.modules: dict[str, CddlModule] = {} + self.definitions: dict[str, str] = {} + self.event_names: set[str] = set() # Names of definitions that are events self._read_file() def _read_file(self) -> None: @@ -1047,12 +1057,12 @@ def _read_file(self) -> None: if not self.cddl_path.exists(): raise FileNotFoundError(f"CDDL file not found: {self.cddl_path}") - with open(self.cddl_path, "r", encoding="utf-8") as f: + with open(self.cddl_path, encoding="utf-8") as f: self.content = f.read() logger.info(f"Loaded CDDL file: {self.cddl_path}") - def parse(self) -> Dict[str, CddlModule]: + def parse(self) -> dict[str, CddlModule]: """Parse CDDL content and return modules.""" # Remove comments content = self._remove_comments(self.content) @@ -1201,7 +1211,7 @@ def _is_enum_definition(self, definition: str) -> bool: # Pattern: "something" / "something_else" return " / " in clean_def and '"' in clean_def - def _extract_enum_values(self, enum_definition: str) -> List[str]: + def _extract_enum_values(self, enum_definition: str) -> list[str]: """Extract individual values from an enum definition. Enums are defined as: "value1" / "value2" / "value3" @@ -1251,7 +1261,7 @@ def _normalize_cddl_type(field_type: str) -> str: result = re.sub(r"-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?", "float", result) return result.strip() - def _extract_type_fields(self, type_definition: str) -> Dict[str, str]: + def _extract_type_fields(self, type_definition: str) -> dict[str, str]: """Extract fields from a type definition block.""" fields = {} @@ -1361,14 +1371,17 @@ def _extract_commands(self) -> None: if module_name not in self.modules: self.modules[module_name] = CddlModule(name=module_name) - # Extract parameters - params = self._extract_parameters(params_type) + # Extract parameters and required parameters + params, required_params = self._extract_parameters_and_required( + params_type + ) # Create command cmd = CddlCommand( module=module_name, name=command_name, params=params, + required_params=required_params, description=f"Execute {method}", ) @@ -1378,24 +1391,36 @@ def _extract_commands(self) -> None: ) def _extract_parameters( - self, params_type: str, _seen: Optional[Set[str]] = None - ) -> Dict[str, str]: + self, params_type: str, _seen: set[str] | None = None + ) -> dict[str, str]: """Extract parameters from a parameter type definition. Handles both struct types ({...}) and top-level union types (TypeA / TypeB), merging all fields from each alternative as optional parameters. """ + params, _ = self._extract_parameters_and_required(params_type, _seen) + return params + + def _extract_parameters_and_required( + self, params_type: str, _seen: set[str] | None = None + ) -> tuple[dict[str, str], set[str]]: + """Extract parameters and track which are required from a parameter type definition. + + Returns: + Tuple of (params dict, required_params set) + """ params = {} + required = set() if _seen is None: _seen = set() if params_type in _seen: - return params + return params, required _seen.add(params_type) if params_type not in self.definitions: logger.debug(f"Parameter type not found: {params_type}") - return params + return params, required definition = self.definitions[params_type] @@ -1409,10 +1434,15 @@ def _extract_parameters( alternatives = [a.strip() for a in stripped.split("/") if a.strip()] all_named = all(re.match(r"^[\w.]+$", a) for a in alternatives) if all_named: + # For union types, collect parameters from all alternatives + # but treat them as optional since the caller only needs to pass one alternative for alt_type in alternatives: - alt_params = self._extract_parameters(alt_type, _seen) + alt_params, _ = self._extract_parameters_and_required( + alt_type, _seen + ) params.update(alt_params) - return params + # Note: We intentionally DON'T add to required, since these are union alternatives + return params, required # Remove the outer curly braces and split by comma # Then parse each line for key: type patterns @@ -1429,6 +1459,9 @@ def _extract_parameters( continue # Match pattern: [?] name: type + # Check if parameter has optional marker (?) + is_optional = line.startswith("?") + # Using a simple pattern that handles optional prefix match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) if not match: @@ -1443,11 +1476,13 @@ def _extract_parameters( # Skip lines that are part of nested definitions if "{" not in normalized_type and "(" not in normalized_type: params[param_name] = normalized_type + if not is_optional: + required.add(param_name) logger.debug( - f"Extracted param {param_name}: {normalized_type} from {params_type}" + f"Extracted param {param_name}: {normalized_type} (required={not is_optional}) from {params_type}" ) - return params + return params, required def module_name_to_class_name(module_name: str) -> str: @@ -1492,7 +1527,7 @@ def module_name_to_filename(module_name: str) -> str: return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() -def generate_init_file(output_path: Path, modules: Dict[str, CddlModule]) -> None: +def generate_init_file(output_path: Path, modules: dict[str, CddlModule]) -> None: """Generate __init__.py file for the module.""" init_path = output_path / "__init__.py" @@ -1507,7 +1542,7 @@ def generate_init_file(output_path: Path, modules: Dict[str, CddlModule]) -> Non filename = module_name_to_filename(module_name) code += f"from .{filename} import {class_name}\n" - code += f"\n__all__ = [\n" + code += "\n__all__ = [\n" for module_name in sorted(modules.keys()): class_name = module_name_to_class_name(module_name) code += f' "{class_name}",\n' @@ -1729,7 +1764,7 @@ def main( cddl_file: str, output_dir: str, spec_version: str = "1.0", - enhancements_manifest: Optional[str] = None, + enhancements_manifest: str | None = None, ) -> None: """Main entry point. diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 7c1958fd435f0..c4017265ac757 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -275,6 +275,9 @@ def get_user_contexts(self): def remove_user_context(self, user_context: Any | None = None): """Execute browser.removeUserContext.""" + if user_context is None: + raise TypeError("remove_user_context() missing required argument: 'user_context'") + params = { "userContext": user_context, } @@ -285,6 +288,9 @@ def remove_user_context(self, user_context: Any | None = None): def set_client_window_state(self, client_window: Any | None = None): """Execute browser.setClientWindowState.""" + if client_window is None: + raise TypeError("set_client_window_state() missing required argument: 'client_window'") + params = { "clientWindow": client_window, } @@ -295,6 +301,7 @@ def set_client_window_state(self, client_window: Any | None = None): def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): """Execute browser.setDownloadBehavior.""" + validate_download_behavior(allowed=allowed, destination_folder=destination_folder, user_contexts=user_contexts) download_behavior = None @@ -368,7 +375,7 @@ def set_client_window_state( if hasattr(state, '__dataclass_fields__'): # It's a dataclass, convert to dict state_param = { - k: v for k, v in state.__dict__.items() + k: v for k, v in state.__dict__.items() if v is not None } diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index d17829709c0c3..775bcdb8f9dbb 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -622,6 +622,9 @@ def __init__(self, conn) -> None: def activate(self, context: Any | None = None): """Execute browsingContext.activate.""" + if context is None: + raise TypeError("activate() missing required argument: 'context'") + params = { "context": context, } @@ -632,6 +635,9 @@ def activate(self, context: Any | None = None): def capture_screenshot(self, context: str | None = None, format: Any | None = None, clip: Any | None = None, origin: str | None = None): """Execute browsingContext.captureScreenshot.""" + if context is None: + raise TypeError("capture_screenshot() missing required argument: 'context'") + params = { "context": context, "format": format, @@ -648,6 +654,9 @@ def capture_screenshot(self, context: str | None = None, format: Any | None = No def close(self, context: Any | None = None, prompt_unload: bool | None = None): """Execute browsingContext.close.""" + if context is None: + raise TypeError("close() missing required argument: 'context'") + params = { "context": context, "promptUnload": prompt_unload, @@ -659,6 +668,9 @@ def close(self, context: Any | None = None, prompt_unload: bool | None = None): def create(self, type: Any | None = None, reference_context: Any | None = None, background: bool | None = None, user_context: Any | None = None): """Execute browsingContext.create.""" + if type is None: + raise TypeError("create() missing required argument: 'type'") + params = { "type": type, "referenceContext": reference_context, @@ -701,6 +713,9 @@ def get_tree(self, max_depth: Any | None = None, root: Any | None = None): def handle_user_prompt(self, context: Any | None = None, accept: bool | None = None, user_text: Any | None = None): """Execute browsingContext.handleUserPrompt.""" + if context is None: + raise TypeError("handle_user_prompt() missing required argument: 'context'") + params = { "context": context, "accept": accept, @@ -713,6 +728,11 @@ def handle_user_prompt(self, context: Any | None = None, accept: bool | None = N def locate_nodes(self, context: str | None = None, locator: Any | None = None, serialization_options: Any | None = None, start_nodes: Any | None = None, max_node_count: int | None = None): """Execute browsingContext.locateNodes.""" + if context is None: + raise TypeError("locate_nodes() missing required argument: 'context'") + if locator is None: + raise TypeError("locate_nodes() missing required argument: 'locator'") + params = { "context": context, "locator": locator, @@ -730,6 +750,11 @@ def locate_nodes(self, context: str | None = None, locator: Any | None = None, s def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any | None = None): """Execute browsingContext.navigate.""" + if context is None: + raise TypeError("navigate() missing required argument: 'context'") + if url is None: + raise TypeError("navigate() missing required argument: 'url'") + params = { "context": context, "url": url, @@ -742,6 +767,9 @@ def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any def print(self, context: Any | None = None, background: bool | None = None, margin: Any | None = None, page: Any | None = None, scale: Any | None = None, shrink_to_fit: bool | None = None): """Execute browsingContext.print.""" + if context is None: + raise TypeError("print() missing required argument: 'context'") + params = { "context": context, "background": background, @@ -760,6 +788,9 @@ def print(self, context: Any | None = None, background: bool | None = None, marg def reload(self, context: Any | None = None, ignore_cache: bool | None = None, wait: Any | None = None): """Execute browsingContext.reload.""" + if context is None: + raise TypeError("reload() missing required argument: 'context'") + params = { "context": context, "ignoreCache": ignore_cache, @@ -785,6 +816,11 @@ def set_viewport(self, context: str | None = None, viewport: Any | None = None, def traverse_history(self, context: Any | None = None, delta: Any | None = None): """Execute browsingContext.traverseHistory.""" + if context is None: + raise TypeError("traverse_history() missing required argument: 'context'") + if delta is None: + raise TypeError("traverse_history() missing required argument: 'delta'") + params = { "context": context, "delta": delta, diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 7edb7a9dacd06..8428c233682b8 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -193,6 +193,9 @@ def __init__(self, conn) -> None: def set_forced_colors_mode_theme_override(self, theme: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setForcedColorsModeThemeOverride.""" + if theme is None: + raise TypeError("set_forced_colors_mode_theme_override() missing required argument: 'theme'") + params = { "theme": theme, "contexts": contexts, @@ -216,6 +219,9 @@ def set_geolocation_override(self, contexts: List[Any] | None = None, user_conte def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setLocaleOverride.""" + if locale is None: + raise TypeError("set_locale_override() missing required argument: 'locale'") + params = { "locale": locale, "contexts": contexts, @@ -228,6 +234,9 @@ def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | N def set_network_conditions(self, network_conditions: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setNetworkConditions.""" + if network_conditions is None: + raise TypeError("set_network_conditions() missing required argument: 'network_conditions'") + params = { "networkConditions": network_conditions, "contexts": contexts, @@ -240,6 +249,9 @@ def set_network_conditions(self, network_conditions: Any | None = None, contexts def set_screen_settings_override(self, screen_area: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setScreenSettingsOverride.""" + if screen_area is None: + raise TypeError("set_screen_settings_override() missing required argument: 'screen_area'") + params = { "screenArea": screen_area, "contexts": contexts, @@ -252,6 +264,9 @@ def set_screen_settings_override(self, screen_area: Any | None = None, contexts: def set_screen_orientation_override(self, screen_orientation: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setScreenOrientationOverride.""" + if screen_orientation is None: + raise TypeError("set_screen_orientation_override() missing required argument: 'screen_orientation'") + params = { "screenOrientation": screen_orientation, "contexts": contexts, @@ -264,6 +279,9 @@ def set_screen_orientation_override(self, screen_orientation: Any | None = None, def set_user_agent_override(self, user_agent: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setUserAgentOverride.""" + if user_agent is None: + raise TypeError("set_user_agent_override() missing required argument: 'user_agent'") + params = { "userAgent": user_agent, "contexts": contexts, @@ -276,6 +294,9 @@ def set_user_agent_override(self, user_agent: Any | None = None, contexts: List[ def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setViewportMetaOverride.""" + if viewport_meta is None: + raise TypeError("set_viewport_meta_override() missing required argument: 'viewport_meta'") + params = { "viewportMeta": viewport_meta, "contexts": contexts, @@ -288,6 +309,9 @@ def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: def set_scripting_enabled(self, enabled: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setScriptingEnabled.""" + if enabled is None: + raise TypeError("set_scripting_enabled() missing required argument: 'enabled'") + params = { "enabled": enabled, "contexts": contexts, @@ -300,6 +324,9 @@ def set_scripting_enabled(self, enabled: Any | None = None, contexts: List[Any] def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setScrollbarTypeOverride.""" + if scrollbar_type is None: + raise TypeError("set_scrollbar_type_override() missing required argument: 'scrollbar_type'") + params = { "scrollbarType": scrollbar_type, "contexts": contexts, @@ -312,6 +339,9 @@ def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, context def set_timezone_override(self, timezone: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setTimezoneOverride.""" + if timezone is None: + raise TypeError("set_timezone_override() missing required argument: 'timezone'") + params = { "timezone": timezone, "contexts": contexts, diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index a294bde307b89..2a19d8072781a 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -370,6 +370,11 @@ def __init__(self, conn) -> None: def perform_actions(self, context: Any | None = None, actions: List[Any] | None = None): """Execute input.performActions.""" + if context is None: + raise TypeError("perform_actions() missing required argument: 'context'") + if actions is None: + raise TypeError("perform_actions() missing required argument: 'actions'") + params = { "context": context, "actions": actions, @@ -381,6 +386,9 @@ def perform_actions(self, context: Any | None = None, actions: List[Any] | None def release_actions(self, context: Any | None = None): """Execute input.releaseActions.""" + if context is None: + raise TypeError("release_actions() missing required argument: 'context'") + params = { "context": context, } @@ -391,6 +399,13 @@ def release_actions(self, context: Any | None = None): def set_files(self, context: Any | None = None, element: Any | None = None, files: List[Any] | None = None): """Execute input.setFiles.""" + if context is None: + raise TypeError("set_files() missing required argument: 'context'") + if element is None: + raise TypeError("set_files() missing required argument: 'element'") + if files is None: + raise TypeError("set_files() missing required argument: 'files'") + params = { "context": context, "element": element, diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index af079f421546c..1f6b0471f2414 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -565,6 +565,11 @@ def __init__(self, conn) -> None: def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_data_size: Any | None = None, collector_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute network.addDataCollector.""" + if data_types is None: + raise TypeError("add_data_collector() missing required argument: 'data_types'") + if max_encoded_data_size is None: + raise TypeError("add_data_collector() missing required argument: 'max_encoded_data_size'") + params = { "dataTypes": data_types, "maxEncodedDataSize": max_encoded_data_size, @@ -579,6 +584,9 @@ def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_da def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | None = None, url_patterns: List[Any] | None = None): """Execute network.addIntercept.""" + if phases is None: + raise TypeError("add_intercept() missing required argument: 'phases'") + params = { "phases": phases, "contexts": contexts, @@ -591,6 +599,9 @@ def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | N def continue_request(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, method: Any | None = None, url: Any | None = None): """Execute network.continueRequest.""" + if request is None: + raise TypeError("continue_request() missing required argument: 'request'") + params = { "request": request, "body": body, @@ -606,6 +617,9 @@ def continue_request(self, request: Any | None = None, body: Any | None = None, def continue_response(self, request: Any | None = None, cookies: List[Any] | None = None, credentials: Any | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): """Execute network.continueResponse.""" + if request is None: + raise TypeError("continue_response() missing required argument: 'request'") + params = { "request": request, "cookies": cookies, @@ -621,6 +635,9 @@ def continue_response(self, request: Any | None = None, cookies: List[Any] | Non def continue_with_auth(self, request: Any | None = None): """Execute network.continueWithAuth.""" + if request is None: + raise TypeError("continue_with_auth() missing required argument: 'request'") + params = { "request": request, } @@ -631,6 +648,13 @@ def continue_with_auth(self, request: Any | None = None): def disown_data(self, data_type: Any | None = None, collector: Any | None = None, request: Any | None = None): """Execute network.disownData.""" + if data_type is None: + raise TypeError("disown_data() missing required argument: 'data_type'") + if collector is None: + raise TypeError("disown_data() missing required argument: 'collector'") + if request is None: + raise TypeError("disown_data() missing required argument: 'request'") + params = { "dataType": data_type, "collector": collector, @@ -643,6 +667,9 @@ def disown_data(self, data_type: Any | None = None, collector: Any | None = None def fail_request(self, request: Any | None = None): """Execute network.failRequest.""" + if request is None: + raise TypeError("fail_request() missing required argument: 'request'") + params = { "request": request, } @@ -653,6 +680,11 @@ def fail_request(self, request: Any | None = None): def get_data(self, data_type: Any | None = None, collector: Any | None = None, disown: bool | None = None, request: Any | None = None): """Execute network.getData.""" + if data_type is None: + raise TypeError("get_data() missing required argument: 'data_type'") + if request is None: + raise TypeError("get_data() missing required argument: 'request'") + params = { "dataType": data_type, "collector": collector, @@ -666,6 +698,9 @@ def get_data(self, data_type: Any | None = None, collector: Any | None = None, d def provide_response(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): """Execute network.provideResponse.""" + if request is None: + raise TypeError("provide_response() missing required argument: 'request'") + params = { "request": request, "body": body, @@ -681,6 +716,9 @@ def provide_response(self, request: Any | None = None, body: Any | None = None, def remove_data_collector(self, collector: Any | None = None): """Execute network.removeDataCollector.""" + if collector is None: + raise TypeError("remove_data_collector() missing required argument: 'collector'") + params = { "collector": collector, } @@ -691,6 +729,9 @@ def remove_data_collector(self, collector: Any | None = None): def remove_intercept(self, intercept: Any | None = None): """Execute network.removeIntercept.""" + if intercept is None: + raise TypeError("remove_intercept() missing required argument: 'intercept'") + params = { "intercept": intercept, } @@ -701,6 +742,9 @@ def remove_intercept(self, intercept: Any | None = None): def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[Any] | None = None): """Execute network.setCacheBehavior.""" + if cache_behavior is None: + raise TypeError("set_cache_behavior() missing required argument: 'cache_behavior'") + params = { "cacheBehavior": cache_behavior, "contexts": contexts, @@ -712,6 +756,9 @@ def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[A def set_extra_headers(self, headers: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute network.setExtraHeaders.""" + if headers is None: + raise TypeError("set_extra_headers() missing required argument: 'headers'") + params = { "headers": headers, "contexts": contexts, @@ -724,6 +771,11 @@ def set_extra_headers(self, headers: List[Any] | None = None, contexts: List[Any def before_request_sent(self, initiator: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.beforeRequestSent.""" + if method is None: + raise TypeError("before_request_sent() missing required argument: 'method'") + if params is None: + raise TypeError("before_request_sent() missing required argument: 'params'") + params = { "initiator": initiator, "method": method, @@ -736,6 +788,13 @@ def before_request_sent(self, initiator: Any | None = None, method: Any | None = def fetch_error(self, error_text: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.fetchError.""" + if error_text is None: + raise TypeError("fetch_error() missing required argument: 'error_text'") + if method is None: + raise TypeError("fetch_error() missing required argument: 'method'") + if params is None: + raise TypeError("fetch_error() missing required argument: 'params'") + params = { "errorText": error_text, "method": method, @@ -748,6 +807,13 @@ def fetch_error(self, error_text: Any | None = None, method: Any | None = None, def response_completed(self, response: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.responseCompleted.""" + if response is None: + raise TypeError("response_completed() missing required argument: 'response'") + if method is None: + raise TypeError("response_completed() missing required argument: 'method'") + if params is None: + raise TypeError("response_completed() missing required argument: 'params'") + params = { "response": response, "method": method, @@ -760,6 +826,9 @@ def response_completed(self, response: Any | None = None, method: Any | None = N def response_started(self, response: Any | None = None): """Execute network.responseStarted.""" + if response is None: + raise TypeError("response_started() missing required argument: 'response'") + params = { "response": response, } diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 492d1fe431680..0f59c400a38c2 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -785,6 +785,9 @@ def __init__(self, conn, driver=None) -> None: def add_preload_script(self, function_declaration: Any | None = None, arguments: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None, sandbox: Any | None = None): """Execute script.addPreloadScript.""" + if function_declaration is None: + raise TypeError("add_preload_script() missing required argument: 'function_declaration'") + params = { "functionDeclaration": function_declaration, "arguments": arguments, @@ -799,6 +802,11 @@ def add_preload_script(self, function_declaration: Any | None = None, arguments: def disown(self, handles: List[Any] | None = None, target: Any | None = None): """Execute script.disown.""" + if handles is None: + raise TypeError("disown() missing required argument: 'handles'") + if target is None: + raise TypeError("disown() missing required argument: 'target'") + params = { "handles": handles, "target": target, @@ -810,6 +818,13 @@ def disown(self, handles: List[Any] | None = None, target: Any | None = None): def call_function(self, function_declaration: Any | None = None, await_promise: bool | None = None, target: Any | None = None, arguments: List[Any] | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, this: Any | None = None, user_activation: bool | None = None): """Execute script.callFunction.""" + if function_declaration is None: + raise TypeError("call_function() missing required argument: 'function_declaration'") + if await_promise is None: + raise TypeError("call_function() missing required argument: 'await_promise'") + if target is None: + raise TypeError("call_function() missing required argument: 'target'") + params = { "functionDeclaration": function_declaration, "awaitPromise": await_promise, @@ -827,6 +842,13 @@ def call_function(self, function_declaration: Any | None = None, await_promise: def evaluate(self, expression: Any | None = None, target: Any | None = None, await_promise: bool | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, user_activation: bool | None = None): """Execute script.evaluate.""" + if expression is None: + raise TypeError("evaluate() missing required argument: 'expression'") + if target is None: + raise TypeError("evaluate() missing required argument: 'target'") + if await_promise is None: + raise TypeError("evaluate() missing required argument: 'await_promise'") + params = { "expression": expression, "target": target, @@ -853,6 +875,9 @@ def get_realms(self, context: Any | None = None, type: Any | None = None): def remove_preload_script(self, script: Any | None = None): """Execute script.removePreloadScript.""" + if script is None: + raise TypeError("remove_preload_script() missing required argument: 'script'") + params = { "script": script, } @@ -863,6 +888,13 @@ def remove_preload_script(self, script: Any | None = None): def message(self, channel: Any | None = None, data: Any | None = None, source: Any | None = None): """Execute script.message.""" + if channel is None: + raise TypeError("message() missing required argument: 'channel'") + if data is None: + raise TypeError("message() missing required argument: 'data'") + if source is None: + raise TypeError("message() missing required argument: 'source'") + params = { "channel": channel, "data": data, diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index f1430cb6e59d3..374375a62f2ec 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -194,6 +194,9 @@ def status(self): def new(self, capabilities: Any | None = None): """Execute session.new.""" + if capabilities is None: + raise TypeError("new() missing required argument: 'capabilities'") + params = { "capabilities": capabilities, } @@ -213,6 +216,9 @@ def end(self): def subscribe(self, events: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute session.subscribe.""" + if events is None: + raise TypeError("subscribe() missing required argument: 'events'") + params = { "events": events, "contexts": contexts, diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 833e9cdc74f2a..8742dc61ebccf 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -248,6 +248,9 @@ def get_cookies(self, filter: Any | None = None, partition: Any | None = None): def set_cookie(self, cookie: Any | None = None, partition: Any | None = None): """Execute storage.setCookie.""" + if cookie is None: + raise TypeError("set_cookie() missing required argument: 'cookie'") + params = { "cookie": cookie, "partition": partition, From 9c813d67bd826775c1ced6741e086fc338316edc Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 11 Mar 2026 12:32:27 +0000 Subject: [PATCH 62/67] improve generation so we don't need to run ruffs over it --- py/generate_bidi.py | 81 +++++++-- py/private/bidi_enhancements_manifest.py | 14 +- py/selenium/webdriver/common/bidi/browser.py | 47 +---- .../webdriver/common/bidi/browsing_context.py | 167 ++++++++++++------ .../webdriver/common/bidi/emulation.py | 131 ++++---------- py/selenium/webdriver/common/bidi/input.py | 18 +- py/selenium/webdriver/common/bidi/log.py | 6 +- py/selenium/webdriver/common/bidi/network.py | 119 +++++++++---- py/selenium/webdriver/common/bidi/script.py | 69 ++++++-- py/selenium/webdriver/common/bidi/session.py | 11 +- py/selenium/webdriver/common/bidi/storage.py | 36 ---- 11 files changed, 384 insertions(+), 315 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 78a7603b929c0..affd0a63a750c 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -170,22 +170,34 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: param_strs.append(f"{snake_param}: {python_type} | None = None") if param_strs: - param_list = "self, " + ", ".join(param_strs) + # Check if full signature would exceed line length limit (120 chars) + single_line_signature = f" def {method_name}(self, {', '.join(param_strs)}):" + if len(single_line_signature) > 120: + # Format parameters on multiple lines + body = f" def {method_name}(\n" + body += " self,\n" + for i, param_str in enumerate(param_strs): + if i < len(param_strs) - 1: + body += f" {param_str},\n" + else: + body += f" {param_str},\n" + body += " ):\n" + else: + param_list = "self, " + ", ".join(param_strs) + body = f" def {method_name}({param_list}):\n" else: - param_list = "self" - - # Build method body - body = f" def {method_name}({param_list}):\n" + body = f" def {method_name}(self):\n" body += f' """{self.description or "Execute " + self.module + "." + self.name}."""\n' # Add automatic validation for required parameters # (This is used unless there's no required_params, in which case all params are optional) if self.required_params: + method_snake = self._camel_to_snake(self.name) for param_name, snake_param in param_names: if param_name in self.required_params: - method_snake = self._camel_to_snake(self.name) body += f" if {snake_param} is None:\n" - body += f' raise TypeError("{method_snake}() missing required argument: {snake_param!r}")\n' + msg = f"{method_snake}() missing required argument:" + body += f' raise TypeError("{msg} {{{{snake_param!r}}}}")\n' body += "\n" # Add validation if specified in enhancements (for additional business logic validation) @@ -247,7 +259,6 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: if result_param == "download_behavior": body += ' "downloadBehavior": download_behavior,\n' # Add remaining parameters that weren't part of the transform - override_params = enhancements.get("params_override", {}) for cddl_param_name in self.params: if cddl_param_name not in ["downloadBehavior"]: snake_name = self._camel_to_snake(cddl_param_name) @@ -667,8 +678,20 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """ - # Generate enums first + # Generate enums first (excluding those in exclude_types) + exclude_types = set(enhancements.get("exclude_types", [])) + + # Also exclude any types that have extra_dataclasses overrides + # Extract class names from extra_dataclasses strings + for extra_cls in enhancements.get("extra_dataclasses", []): + # Match "class ClassName:" pattern + match = re.search(r"class\s+(\w+)\s*:", extra_cls) + if match: + exclude_types.add(match.group(1)) + for enum_def in self.enums: + if enum_def.name in exclude_types: + continue code += enum_def.to_python_class() code += "\n\n" @@ -677,7 +700,6 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: code += f"{alias} = {target}\n\n" # Generate type dataclasses, skipping any overridden by extra_dataclasses - exclude_types = set(enhancements.get("exclude_types", [])) for type_def in self.types: if type_def.name in exclude_types: continue @@ -946,6 +968,16 @@ def clear_event_handlers(self) -> None: # Generate command methods exclude_methods = enhancements.get("exclude_methods", []) + + # Automatically exclude methods that are defined in extra_methods + # to prevent generating duplicates + if "extra_methods" in enhancements: + for extra_method in enhancements["extra_methods"]: + # Extract method name from "def method_name(" + match = re.search(r"def\s+(\w+)\s*\(", extra_method) + if match: + exclude_methods = list(exclude_methods) + [match.group(1)] + if self.commands: for command in self.commands: # Get method-specific enhancements @@ -1026,9 +1058,26 @@ def clear_event_handlers(self) -> None: method_parts = event_def.method.split(".") if len(method_parts) == 2: event_name = self._convert_method_to_event_name(method_parts[1]) - # The event class is the event name (e.g., ContextCreated) - # Try to get it from globals, default to dict if not found - code += f' "{event_name}": (EventConfig("{event_name}", "{event_def.method}", _globals.get("{event_def.name}", dict)) if _globals.get("{event_def.name}") else EventConfig("{event_name}", "{event_def.method}", dict)),\n' + # Try to get event class from globals, default to dict if not found + getter = f'_globals.get("{event_def.name}", dict)' + condition = f'_globals.get("{event_def.name}")' + event_class = f'{getter} if {condition} else dict' + + # Build the entry line and check if it exceeds 120 chars + single_line = ( + f' "{event_name}": ' + f'EventConfig("{event_name}", "{event_def.method}", {event_class}),' + ) + + if len(single_line) > 120: + # Break into multiple lines + code += f' "{event_name}": EventConfig(\n' + code += f' "{event_name}",\n' + code += f' "{event_def.method}",\n' + code += f' {event_class},\n' + code += ' ),\n' + else: + code += single_line + '\n' # Extra events not in the CDDL spec for extra_evt in enhancements.get("extra_events", []): ek = extra_evt["event_key"] @@ -1126,9 +1175,6 @@ def _extract_event_names(self) -> None: ... ) """ - # Look for definitions like "BrowsingContextEvent", "SessionEvent", etc. - event_union_pattern = re.compile(r"(\w+\.)?(\w+)Event") - for def_name, def_content in self.definitions.items(): # Check if this looks like an event union (name ends with "Event") and # contains a module-qualified reference like "module.EventName". @@ -1479,7 +1525,8 @@ def _extract_parameters_and_required( if not is_optional: required.add(param_name) logger.debug( - f"Extracted param {param_name}: {normalized_type} (required={not is_optional}) from {params_type}" + f"Extracted param {param_name}: {normalized_type} " + f"(required={not is_optional}) from {params_type}" ) return params, required diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 2b93f36f1a5dc..40647157f8535 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -252,19 +252,7 @@ def from_json(cls, params: dict) -> "DownloadEndParams": ) return cls(download_params=dp)''', ], - # Non-CDDL download events (Chromium-specific, not in the BiDi spec) - "extra_events": [ - { - "event_key": "download_will_begin", - "bidi_event": "browsingContext.downloadWillBegin", - "event_class": "DownloadWillBeginParams", - }, - { - "event_key": "download_end", - "bidi_event": "browsingContext.downloadEnd", - "event_class": "DownloadEndParams", - }, - ], + # Download events are now in the CDDL spec, so no extra_events needed }, "log": { # Make LogLevel an alias for Level so existing code using LogLevel works diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index c4017265ac757..77ae8f0696281 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -61,14 +61,6 @@ def validate_download_behavior( raise ValueError("destination_folder should not be provided when allowed=False") -class ClientWindowNamedState: - """ClientWindowNamedState.""" - - FULLSCREEN = "fullscreen" - MAXIMIZED = "maximized" - MINIMIZED = "minimized" - - @dataclass class ClientWindowInfo: """ClientWindowInfo.""" @@ -212,7 +204,12 @@ def close(self): result = self._conn.execute(cmd) return result - def create_user_context(self, accept_insecure_certs: bool | None = None, proxy: Any | None = None, unhandled_prompt_behavior: Any | None = None): + def create_user_context( + self, + accept_insecure_certs: bool | None = None, + proxy: Any | None = None, + unhandled_prompt_behavior: Any | None = None, + ): """Execute browser.createUserContext.""" if proxy and hasattr(proxy, 'to_bidi_dict'): proxy = proxy.to_bidi_dict() @@ -276,7 +273,7 @@ def get_user_contexts(self): def remove_user_context(self, user_context: Any | None = None): """Execute browser.removeUserContext.""" if user_context is None: - raise TypeError("remove_user_context() missing required argument: 'user_context'") + raise TypeError("remove_user_context() missing required argument: {{snake_param!r}}") params = { "userContext": user_context, @@ -286,36 +283,6 @@ def remove_user_context(self, user_context: Any | None = None): result = self._conn.execute(cmd) return result - def set_client_window_state(self, client_window: Any | None = None): - """Execute browser.setClientWindowState.""" - if client_window is None: - raise TypeError("set_client_window_state() missing required argument: 'client_window'") - - params = { - "clientWindow": client_window, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browser.setClientWindowState", params) - result = self._conn.execute(cmd) - return result - - def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): - """Execute browser.setDownloadBehavior.""" - - validate_download_behavior(allowed=allowed, destination_folder=destination_folder, user_contexts=user_contexts) - - download_behavior = None - download_behavior = transform_download_params(allowed, destination_folder) - - params = { - "downloadBehavior": download_behavior, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browser.setDownloadBehavior", params) - result = self._conn.execute(cmd) - return result - def set_download_behavior( self, allowed: bool | None = None, diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 775bcdb8f9dbb..3f877b06b00ab 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -328,20 +328,6 @@ class HistoryUpdatedParameters: url: str | None = None -@dataclass -class DownloadWillBeginParams: - """DownloadWillBeginParams.""" - - suggested_filename: str | None = None - - -@dataclass -class DownloadCanceledParams: - """DownloadCanceledParams.""" - - status: str = field(default="canceled", init=False) - - @dataclass class UserPromptClosedParameters: """UserPromptClosedParameters.""" @@ -421,8 +407,6 @@ def from_json(cls, params: dict) -> "DownloadEndParams": "navigation_failed": "browsingContext.navigationFailed", "user_prompt_closed": "browsingContext.userPromptClosed", "user_prompt_opened": "browsingContext.userPromptOpened", - "download_will_begin": "browsingContext.downloadWillBegin", - "download_end": "browsingContext.downloadEnd", } def _deserialize_info_list(items: list) -> list | None: @@ -623,7 +607,7 @@ def __init__(self, conn) -> None: def activate(self, context: Any | None = None): """Execute browsingContext.activate.""" if context is None: - raise TypeError("activate() missing required argument: 'context'") + raise TypeError("activate() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -633,10 +617,16 @@ def activate(self, context: Any | None = None): result = self._conn.execute(cmd) return result - def capture_screenshot(self, context: str | None = None, format: Any | None = None, clip: Any | None = None, origin: str | None = None): + def capture_screenshot( + self, + context: str | None = None, + format: Any | None = None, + clip: Any | None = None, + origin: str | None = None, + ): """Execute browsingContext.captureScreenshot.""" if context is None: - raise TypeError("capture_screenshot() missing required argument: 'context'") + raise TypeError("capture_screenshot() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -655,7 +645,7 @@ def capture_screenshot(self, context: str | None = None, format: Any | None = No def close(self, context: Any | None = None, prompt_unload: bool | None = None): """Execute browsingContext.close.""" if context is None: - raise TypeError("close() missing required argument: 'context'") + raise TypeError("close() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -666,10 +656,16 @@ def close(self, context: Any | None = None, prompt_unload: bool | None = None): result = self._conn.execute(cmd) return result - def create(self, type: Any | None = None, reference_context: Any | None = None, background: bool | None = None, user_context: Any | None = None): + def create( + self, + type: Any | None = None, + reference_context: Any | None = None, + background: bool | None = None, + user_context: Any | None = None, + ): """Execute browsingContext.create.""" if type is None: - raise TypeError("create() missing required argument: 'type'") + raise TypeError("create() missing required argument: {{snake_param!r}}") params = { "type": type, @@ -714,7 +710,7 @@ def get_tree(self, max_depth: Any | None = None, root: Any | None = None): def handle_user_prompt(self, context: Any | None = None, accept: bool | None = None, user_text: Any | None = None): """Execute browsingContext.handleUserPrompt.""" if context is None: - raise TypeError("handle_user_prompt() missing required argument: 'context'") + raise TypeError("handle_user_prompt() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -726,12 +722,19 @@ def handle_user_prompt(self, context: Any | None = None, accept: bool | None = N result = self._conn.execute(cmd) return result - def locate_nodes(self, context: str | None = None, locator: Any | None = None, serialization_options: Any | None = None, start_nodes: Any | None = None, max_node_count: int | None = None): + def locate_nodes( + self, + context: str | None = None, + locator: Any | None = None, + serialization_options: Any | None = None, + start_nodes: Any | None = None, + max_node_count: int | None = None, + ): """Execute browsingContext.locateNodes.""" if context is None: - raise TypeError("locate_nodes() missing required argument: 'context'") + raise TypeError("locate_nodes() missing required argument: {{snake_param!r}}") if locator is None: - raise TypeError("locate_nodes() missing required argument: 'locator'") + raise TypeError("locate_nodes() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -751,9 +754,9 @@ def locate_nodes(self, context: str | None = None, locator: Any | None = None, s def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any | None = None): """Execute browsingContext.navigate.""" if context is None: - raise TypeError("navigate() missing required argument: 'context'") + raise TypeError("navigate() missing required argument: {{snake_param!r}}") if url is None: - raise TypeError("navigate() missing required argument: 'url'") + raise TypeError("navigate() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -765,10 +768,18 @@ def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any result = self._conn.execute(cmd) return result - def print(self, context: Any | None = None, background: bool | None = None, margin: Any | None = None, page: Any | None = None, scale: Any | None = None, shrink_to_fit: bool | None = None): + def print( + self, + context: Any | None = None, + background: bool | None = None, + margin: Any | None = None, + page: Any | None = None, + scale: Any | None = None, + shrink_to_fit: bool | None = None, + ): """Execute browsingContext.print.""" if context is None: - raise TypeError("print() missing required argument: 'context'") + raise TypeError("print() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -789,7 +800,7 @@ def print(self, context: Any | None = None, background: bool | None = None, marg def reload(self, context: Any | None = None, ignore_cache: bool | None = None, wait: Any | None = None): """Execute browsingContext.reload.""" if context is None: - raise TypeError("reload() missing required argument: 'context'") + raise TypeError("reload() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -801,7 +812,13 @@ def reload(self, context: Any | None = None, ignore_cache: bool | None = None, w result = self._conn.execute(cmd) return result - def set_viewport(self, context: str | None = None, viewport: Any | None = None, user_contexts: Any | None = None, device_pixel_ratio: Any | None = None): + def set_viewport( + self, + context: str | None = None, + viewport: Any | None = None, + user_contexts: Any | None = None, + device_pixel_ratio: Any | None = None, + ): """Execute browsingContext.setViewport.""" params = { "context": context, @@ -817,9 +834,9 @@ def set_viewport(self, context: str | None = None, viewport: Any | None = None, def traverse_history(self, context: Any | None = None, delta: Any | None = None): """Execute browsingContext.traverseHistory.""" if context is None: - raise TypeError("traverse_history() missing required argument: 'context'") + raise TypeError("traverse_history() missing required argument: {{snake_param!r}}") if delta is None: - raise TypeError("traverse_history() missing required argument: 'delta'") + raise TypeError("traverse_history() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -904,20 +921,70 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() BrowsingContext.EVENT_CONFIGS = { - "context_created": (EventConfig("context_created", "browsingContext.contextCreated", _globals.get("ContextCreated", dict)) if _globals.get("ContextCreated") else EventConfig("context_created", "browsingContext.contextCreated", dict)), - "context_destroyed": (EventConfig("context_destroyed", "browsingContext.contextDestroyed", _globals.get("ContextDestroyed", dict)) if _globals.get("ContextDestroyed") else EventConfig("context_destroyed", "browsingContext.contextDestroyed", dict)), - "navigation_started": (EventConfig("navigation_started", "browsingContext.navigationStarted", _globals.get("NavigationStarted", dict)) if _globals.get("NavigationStarted") else EventConfig("navigation_started", "browsingContext.navigationStarted", dict)), - "fragment_navigated": (EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", _globals.get("FragmentNavigated", dict)) if _globals.get("FragmentNavigated") else EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", dict)), - "history_updated": (EventConfig("history_updated", "browsingContext.historyUpdated", _globals.get("HistoryUpdated", dict)) if _globals.get("HistoryUpdated") else EventConfig("history_updated", "browsingContext.historyUpdated", dict)), - "dom_content_loaded": (EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", _globals.get("DomContentLoaded", dict)) if _globals.get("DomContentLoaded") else EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", dict)), - "load": (EventConfig("load", "browsingContext.load", _globals.get("Load", dict)) if _globals.get("Load") else EventConfig("load", "browsingContext.load", dict)), - "download_will_begin": (EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBegin", dict)) if _globals.get("DownloadWillBegin") else EventConfig("download_will_begin", "browsingContext.downloadWillBegin", dict)), - "download_end": (EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEnd", dict)) if _globals.get("DownloadEnd") else EventConfig("download_end", "browsingContext.downloadEnd", dict)), - "navigation_aborted": (EventConfig("navigation_aborted", "browsingContext.navigationAborted", _globals.get("NavigationAborted", dict)) if _globals.get("NavigationAborted") else EventConfig("navigation_aborted", "browsingContext.navigationAborted", dict)), - "navigation_committed": (EventConfig("navigation_committed", "browsingContext.navigationCommitted", _globals.get("NavigationCommitted", dict)) if _globals.get("NavigationCommitted") else EventConfig("navigation_committed", "browsingContext.navigationCommitted", dict)), - "navigation_failed": (EventConfig("navigation_failed", "browsingContext.navigationFailed", _globals.get("NavigationFailed", dict)) if _globals.get("NavigationFailed") else EventConfig("navigation_failed", "browsingContext.navigationFailed", dict)), - "user_prompt_closed": (EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", _globals.get("UserPromptClosed", dict)) if _globals.get("UserPromptClosed") else EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", dict)), - "user_prompt_opened": (EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", _globals.get("UserPromptOpened", dict)) if _globals.get("UserPromptOpened") else EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", dict)), - "download_will_begin": EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBeginParams", dict)), - "download_end": EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEndParams", dict)), + "context_created": EventConfig( + "context_created", + "browsingContext.contextCreated", + _globals.get("ContextCreated", dict) if _globals.get("ContextCreated") else dict, + ), + "context_destroyed": EventConfig( + "context_destroyed", + "browsingContext.contextDestroyed", + _globals.get("ContextDestroyed", dict) if _globals.get("ContextDestroyed") else dict, + ), + "navigation_started": EventConfig( + "navigation_started", + "browsingContext.navigationStarted", + _globals.get("NavigationStarted", dict) if _globals.get("NavigationStarted") else dict, + ), + "fragment_navigated": EventConfig( + "fragment_navigated", + "browsingContext.fragmentNavigated", + _globals.get("FragmentNavigated", dict) if _globals.get("FragmentNavigated") else dict, + ), + "history_updated": EventConfig( + "history_updated", + "browsingContext.historyUpdated", + _globals.get("HistoryUpdated", dict) if _globals.get("HistoryUpdated") else dict, + ), + "dom_content_loaded": EventConfig( + "dom_content_loaded", + "browsingContext.domContentLoaded", + _globals.get("DomContentLoaded", dict) if _globals.get("DomContentLoaded") else dict, + ), + "load": EventConfig("load", "browsingContext.load", _globals.get("Load", dict) if _globals.get("Load") else dict), + "download_will_begin": EventConfig( + "download_will_begin", + "browsingContext.downloadWillBegin", + _globals.get("DownloadWillBegin", dict) if _globals.get("DownloadWillBegin") else dict, + ), + "download_end": EventConfig( + "download_end", + "browsingContext.downloadEnd", + _globals.get("DownloadEnd", dict) if _globals.get("DownloadEnd") else dict, + ), + "navigation_aborted": EventConfig( + "navigation_aborted", + "browsingContext.navigationAborted", + _globals.get("NavigationAborted", dict) if _globals.get("NavigationAborted") else dict, + ), + "navigation_committed": EventConfig( + "navigation_committed", + "browsingContext.navigationCommitted", + _globals.get("NavigationCommitted", dict) if _globals.get("NavigationCommitted") else dict, + ), + "navigation_failed": EventConfig( + "navigation_failed", + "browsingContext.navigationFailed", + _globals.get("NavigationFailed", dict) if _globals.get("NavigationFailed") else dict, + ), + "user_prompt_closed": EventConfig( + "user_prompt_closed", + "browsingContext.userPromptClosed", + _globals.get("UserPromptClosed", dict) if _globals.get("UserPromptClosed") else dict, + ), + "user_prompt_opened": EventConfig( + "user_prompt_opened", + "browsingContext.userPromptOpened", + _globals.get("UserPromptOpened", dict) if _globals.get("UserPromptOpened") else dict, + ), } diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 8428c233682b8..d482fecc755cb 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -191,10 +191,15 @@ class Emulation: def __init__(self, conn) -> None: self._conn = conn - def set_forced_colors_mode_theme_override(self, theme: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_forced_colors_mode_theme_override( + self, + theme: Any | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute emulation.setForcedColorsModeThemeOverride.""" if theme is None: - raise TypeError("set_forced_colors_mode_theme_override() missing required argument: 'theme'") + raise TypeError("set_forced_colors_mode_theme_override() missing required argument: {{snake_param!r}}") params = { "theme": theme, @@ -206,21 +211,15 @@ def set_forced_colors_mode_theme_override(self, theme: Any | None = None, contex result = self._conn.execute(cmd) return result - def set_geolocation_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setGeolocationOverride.""" - params = { - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setGeolocationOverride", params) - result = self._conn.execute(cmd) - return result - - def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_locale_override( + self, + locale: Any | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute emulation.setLocaleOverride.""" if locale is None: - raise TypeError("set_locale_override() missing required argument: 'locale'") + raise TypeError("set_locale_override() missing required argument: {{snake_param!r}}") params = { "locale": locale, @@ -232,25 +231,15 @@ def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | N result = self._conn.execute(cmd) return result - def set_network_conditions(self, network_conditions: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setNetworkConditions.""" - if network_conditions is None: - raise TypeError("set_network_conditions() missing required argument: 'network_conditions'") - - params = { - "networkConditions": network_conditions, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setNetworkConditions", params) - result = self._conn.execute(cmd) - return result - - def set_screen_settings_override(self, screen_area: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_screen_settings_override( + self, + screen_area: Any | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute emulation.setScreenSettingsOverride.""" if screen_area is None: - raise TypeError("set_screen_settings_override() missing required argument: 'screen_area'") + raise TypeError("set_screen_settings_override() missing required argument: {{snake_param!r}}") params = { "screenArea": screen_area, @@ -262,40 +251,15 @@ def set_screen_settings_override(self, screen_area: Any | None = None, contexts: result = self._conn.execute(cmd) return result - def set_screen_orientation_override(self, screen_orientation: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setScreenOrientationOverride.""" - if screen_orientation is None: - raise TypeError("set_screen_orientation_override() missing required argument: 'screen_orientation'") - - params = { - "screenOrientation": screen_orientation, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setScreenOrientationOverride", params) - result = self._conn.execute(cmd) - return result - - def set_user_agent_override(self, user_agent: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setUserAgentOverride.""" - if user_agent is None: - raise TypeError("set_user_agent_override() missing required argument: 'user_agent'") - - params = { - "userAgent": user_agent, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setUserAgentOverride", params) - result = self._conn.execute(cmd) - return result - - def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_viewport_meta_override( + self, + viewport_meta: Any | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute emulation.setViewportMetaOverride.""" if viewport_meta is None: - raise TypeError("set_viewport_meta_override() missing required argument: 'viewport_meta'") + raise TypeError("set_viewport_meta_override() missing required argument: {{snake_param!r}}") params = { "viewportMeta": viewport_meta, @@ -307,25 +271,15 @@ def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: result = self._conn.execute(cmd) return result - def set_scripting_enabled(self, enabled: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setScriptingEnabled.""" - if enabled is None: - raise TypeError("set_scripting_enabled() missing required argument: 'enabled'") - - params = { - "enabled": enabled, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setScriptingEnabled", params) - result = self._conn.execute(cmd) - return result - - def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_scrollbar_type_override( + self, + scrollbar_type: Any | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute emulation.setScrollbarTypeOverride.""" if scrollbar_type is None: - raise TypeError("set_scrollbar_type_override() missing required argument: 'scrollbar_type'") + raise TypeError("set_scrollbar_type_override() missing required argument: {{snake_param!r}}") params = { "scrollbarType": scrollbar_type, @@ -337,21 +291,6 @@ def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, context result = self._conn.execute(cmd) return result - def set_timezone_override(self, timezone: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setTimezoneOverride.""" - if timezone is None: - raise TypeError("set_timezone_override() missing required argument: 'timezone'") - - params = { - "timezone": timezone, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setTimezoneOverride", params) - result = self._conn.execute(cmd) - return result - def set_touch_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setTouchOverride.""" params = { diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 2a19d8072781a..0990dacc39363 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -371,9 +371,9 @@ def __init__(self, conn) -> None: def perform_actions(self, context: Any | None = None, actions: List[Any] | None = None): """Execute input.performActions.""" if context is None: - raise TypeError("perform_actions() missing required argument: 'context'") + raise TypeError("perform_actions() missing required argument: {{snake_param!r}}") if actions is None: - raise TypeError("perform_actions() missing required argument: 'actions'") + raise TypeError("perform_actions() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -387,7 +387,7 @@ def perform_actions(self, context: Any | None = None, actions: List[Any] | None def release_actions(self, context: Any | None = None): """Execute input.releaseActions.""" if context is None: - raise TypeError("release_actions() missing required argument: 'context'") + raise TypeError("release_actions() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -400,11 +400,11 @@ def release_actions(self, context: Any | None = None): def set_files(self, context: Any | None = None, element: Any | None = None, files: List[Any] | None = None): """Execute input.setFiles.""" if context is None: - raise TypeError("set_files() missing required argument: 'context'") + raise TypeError("set_files() missing required argument: {{snake_param!r}}") if element is None: - raise TypeError("set_files() missing required argument: 'element'") + raise TypeError("set_files() missing required argument: {{snake_param!r}}") if files is None: - raise TypeError("set_files() missing required argument: 'files'") + raise TypeError("set_files() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -469,5 +469,9 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Input.EVENT_CONFIGS = { - "file_dialog_opened": (EventConfig("file_dialog_opened", "input.fileDialogOpened", _globals.get("FileDialogOpened", dict)) if _globals.get("FileDialogOpened") else EventConfig("file_dialog_opened", "input.fileDialogOpened", dict)), + "file_dialog_opened": EventConfig( + "file_dialog_opened", + "input.fileDialogOpened", + _globals.get("FileDialogOpened", dict) if _globals.get("FileDialogOpened") else dict, + ), } diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 1f16849b8e03d..07121242348ea 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -299,5 +299,9 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Log.EVENT_CONFIGS = { - "entry_added": (EventConfig("entry_added", "log.entryAdded", _globals.get("EntryAdded", dict)) if _globals.get("EntryAdded") else EventConfig("entry_added", "log.entryAdded", dict)), + "entry_added": EventConfig( + "entry_added", + "log.entryAdded", + _globals.get("EntryAdded", dict) if _globals.get("EntryAdded") else dict, + ), } diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 1f6b0471f2414..d7baeb07040ce 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -563,12 +563,19 @@ def __init__(self, conn) -> None: self.intercepts = [] self._handler_intercepts: dict = {} - def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_data_size: Any | None = None, collector_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def add_data_collector( + self, + data_types: List[Any] | None = None, + max_encoded_data_size: Any | None = None, + collector_type: Any | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute network.addDataCollector.""" if data_types is None: - raise TypeError("add_data_collector() missing required argument: 'data_types'") + raise TypeError("add_data_collector() missing required argument: {{snake_param!r}}") if max_encoded_data_size is None: - raise TypeError("add_data_collector() missing required argument: 'max_encoded_data_size'") + raise TypeError("add_data_collector() missing required argument: {{snake_param!r}}") params = { "dataTypes": data_types, @@ -582,10 +589,15 @@ def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_da result = self._conn.execute(cmd) return result - def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | None = None, url_patterns: List[Any] | None = None): + def add_intercept( + self, + phases: List[Any] | None = None, + contexts: List[Any] | None = None, + url_patterns: List[Any] | None = None, + ): """Execute network.addIntercept.""" if phases is None: - raise TypeError("add_intercept() missing required argument: 'phases'") + raise TypeError("add_intercept() missing required argument: {{snake_param!r}}") params = { "phases": phases, @@ -597,10 +609,18 @@ def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | N result = self._conn.execute(cmd) return result - def continue_request(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, method: Any | None = None, url: Any | None = None): + def continue_request( + self, + request: Any | None = None, + body: Any | None = None, + cookies: List[Any] | None = None, + headers: List[Any] | None = None, + method: Any | None = None, + url: Any | None = None, + ): """Execute network.continueRequest.""" if request is None: - raise TypeError("continue_request() missing required argument: 'request'") + raise TypeError("continue_request() missing required argument: {{snake_param!r}}") params = { "request": request, @@ -615,10 +635,18 @@ def continue_request(self, request: Any | None = None, body: Any | None = None, result = self._conn.execute(cmd) return result - def continue_response(self, request: Any | None = None, cookies: List[Any] | None = None, credentials: Any | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): + def continue_response( + self, + request: Any | None = None, + cookies: List[Any] | None = None, + credentials: Any | None = None, + headers: List[Any] | None = None, + reason_phrase: Any | None = None, + status_code: Any | None = None, + ): """Execute network.continueResponse.""" if request is None: - raise TypeError("continue_response() missing required argument: 'request'") + raise TypeError("continue_response() missing required argument: {{snake_param!r}}") params = { "request": request, @@ -636,7 +664,7 @@ def continue_response(self, request: Any | None = None, cookies: List[Any] | Non def continue_with_auth(self, request: Any | None = None): """Execute network.continueWithAuth.""" if request is None: - raise TypeError("continue_with_auth() missing required argument: 'request'") + raise TypeError("continue_with_auth() missing required argument: {{snake_param!r}}") params = { "request": request, @@ -649,11 +677,11 @@ def continue_with_auth(self, request: Any | None = None): def disown_data(self, data_type: Any | None = None, collector: Any | None = None, request: Any | None = None): """Execute network.disownData.""" if data_type is None: - raise TypeError("disown_data() missing required argument: 'data_type'") + raise TypeError("disown_data() missing required argument: {{snake_param!r}}") if collector is None: - raise TypeError("disown_data() missing required argument: 'collector'") + raise TypeError("disown_data() missing required argument: {{snake_param!r}}") if request is None: - raise TypeError("disown_data() missing required argument: 'request'") + raise TypeError("disown_data() missing required argument: {{snake_param!r}}") params = { "dataType": data_type, @@ -668,7 +696,7 @@ def disown_data(self, data_type: Any | None = None, collector: Any | None = None def fail_request(self, request: Any | None = None): """Execute network.failRequest.""" if request is None: - raise TypeError("fail_request() missing required argument: 'request'") + raise TypeError("fail_request() missing required argument: {{snake_param!r}}") params = { "request": request, @@ -678,12 +706,18 @@ def fail_request(self, request: Any | None = None): result = self._conn.execute(cmd) return result - def get_data(self, data_type: Any | None = None, collector: Any | None = None, disown: bool | None = None, request: Any | None = None): + def get_data( + self, + data_type: Any | None = None, + collector: Any | None = None, + disown: bool | None = None, + request: Any | None = None, + ): """Execute network.getData.""" if data_type is None: - raise TypeError("get_data() missing required argument: 'data_type'") + raise TypeError("get_data() missing required argument: {{snake_param!r}}") if request is None: - raise TypeError("get_data() missing required argument: 'request'") + raise TypeError("get_data() missing required argument: {{snake_param!r}}") params = { "dataType": data_type, @@ -696,10 +730,18 @@ def get_data(self, data_type: Any | None = None, collector: Any | None = None, d result = self._conn.execute(cmd) return result - def provide_response(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): + def provide_response( + self, + request: Any | None = None, + body: Any | None = None, + cookies: List[Any] | None = None, + headers: List[Any] | None = None, + reason_phrase: Any | None = None, + status_code: Any | None = None, + ): """Execute network.provideResponse.""" if request is None: - raise TypeError("provide_response() missing required argument: 'request'") + raise TypeError("provide_response() missing required argument: {{snake_param!r}}") params = { "request": request, @@ -717,7 +759,7 @@ def provide_response(self, request: Any | None = None, body: Any | None = None, def remove_data_collector(self, collector: Any | None = None): """Execute network.removeDataCollector.""" if collector is None: - raise TypeError("remove_data_collector() missing required argument: 'collector'") + raise TypeError("remove_data_collector() missing required argument: {{snake_param!r}}") params = { "collector": collector, @@ -730,7 +772,7 @@ def remove_data_collector(self, collector: Any | None = None): def remove_intercept(self, intercept: Any | None = None): """Execute network.removeIntercept.""" if intercept is None: - raise TypeError("remove_intercept() missing required argument: 'intercept'") + raise TypeError("remove_intercept() missing required argument: {{snake_param!r}}") params = { "intercept": intercept, @@ -743,7 +785,7 @@ def remove_intercept(self, intercept: Any | None = None): def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[Any] | None = None): """Execute network.setCacheBehavior.""" if cache_behavior is None: - raise TypeError("set_cache_behavior() missing required argument: 'cache_behavior'") + raise TypeError("set_cache_behavior() missing required argument: {{snake_param!r}}") params = { "cacheBehavior": cache_behavior, @@ -754,10 +796,15 @@ def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[A result = self._conn.execute(cmd) return result - def set_extra_headers(self, headers: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_extra_headers( + self, + headers: List[Any] | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute network.setExtraHeaders.""" if headers is None: - raise TypeError("set_extra_headers() missing required argument: 'headers'") + raise TypeError("set_extra_headers() missing required argument: {{snake_param!r}}") params = { "headers": headers, @@ -772,9 +819,9 @@ def set_extra_headers(self, headers: List[Any] | None = None, contexts: List[Any def before_request_sent(self, initiator: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.beforeRequestSent.""" if method is None: - raise TypeError("before_request_sent() missing required argument: 'method'") + raise TypeError("before_request_sent() missing required argument: {{snake_param!r}}") if params is None: - raise TypeError("before_request_sent() missing required argument: 'params'") + raise TypeError("before_request_sent() missing required argument: {{snake_param!r}}") params = { "initiator": initiator, @@ -789,11 +836,11 @@ def before_request_sent(self, initiator: Any | None = None, method: Any | None = def fetch_error(self, error_text: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.fetchError.""" if error_text is None: - raise TypeError("fetch_error() missing required argument: 'error_text'") + raise TypeError("fetch_error() missing required argument: {{snake_param!r}}") if method is None: - raise TypeError("fetch_error() missing required argument: 'method'") + raise TypeError("fetch_error() missing required argument: {{snake_param!r}}") if params is None: - raise TypeError("fetch_error() missing required argument: 'params'") + raise TypeError("fetch_error() missing required argument: {{snake_param!r}}") params = { "errorText": error_text, @@ -808,11 +855,11 @@ def fetch_error(self, error_text: Any | None = None, method: Any | None = None, def response_completed(self, response: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.responseCompleted.""" if response is None: - raise TypeError("response_completed() missing required argument: 'response'") + raise TypeError("response_completed() missing required argument: {{snake_param!r}}") if method is None: - raise TypeError("response_completed() missing required argument: 'method'") + raise TypeError("response_completed() missing required argument: {{snake_param!r}}") if params is None: - raise TypeError("response_completed() missing required argument: 'params'") + raise TypeError("response_completed() missing required argument: {{snake_param!r}}") params = { "response": response, @@ -827,7 +874,7 @@ def response_completed(self, response: Any | None = None, method: Any | None = N def response_started(self, response: Any | None = None): """Execute network.responseStarted.""" if response is None: - raise TypeError("response_started() missing required argument: 'response'") + raise TypeError("response_started() missing required argument: {{snake_param!r}}") params = { "response": response, @@ -995,6 +1042,10 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Network.EVENT_CONFIGS = { - "auth_required": (EventConfig("auth_required", "network.authRequired", _globals.get("AuthRequired", dict)) if _globals.get("AuthRequired") else EventConfig("auth_required", "network.authRequired", dict)), + "auth_required": EventConfig( + "auth_required", + "network.authRequired", + _globals.get("AuthRequired", dict) if _globals.get("AuthRequired") else dict, + ), "before_request": EventConfig("before_request", "network.beforeRequestSent", _globals.get("dict", dict)), } diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 0f59c400a38c2..8e832f4a9cae9 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -783,10 +783,17 @@ def __init__(self, conn, driver=None) -> None: self._driver = driver self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - def add_preload_script(self, function_declaration: Any | None = None, arguments: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None, sandbox: Any | None = None): + def add_preload_script( + self, + function_declaration: Any | None = None, + arguments: List[Any] | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + sandbox: Any | None = None, + ): """Execute script.addPreloadScript.""" if function_declaration is None: - raise TypeError("add_preload_script() missing required argument: 'function_declaration'") + raise TypeError("add_preload_script() missing required argument: {{snake_param!r}}") params = { "functionDeclaration": function_declaration, @@ -803,9 +810,9 @@ def add_preload_script(self, function_declaration: Any | None = None, arguments: def disown(self, handles: List[Any] | None = None, target: Any | None = None): """Execute script.disown.""" if handles is None: - raise TypeError("disown() missing required argument: 'handles'") + raise TypeError("disown() missing required argument: {{snake_param!r}}") if target is None: - raise TypeError("disown() missing required argument: 'target'") + raise TypeError("disown() missing required argument: {{snake_param!r}}") params = { "handles": handles, @@ -816,14 +823,24 @@ def disown(self, handles: List[Any] | None = None, target: Any | None = None): result = self._conn.execute(cmd) return result - def call_function(self, function_declaration: Any | None = None, await_promise: bool | None = None, target: Any | None = None, arguments: List[Any] | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, this: Any | None = None, user_activation: bool | None = None): + def call_function( + self, + function_declaration: Any | None = None, + await_promise: bool | None = None, + target: Any | None = None, + arguments: List[Any] | None = None, + result_ownership: Any | None = None, + serialization_options: Any | None = None, + this: Any | None = None, + user_activation: bool | None = None, + ): """Execute script.callFunction.""" if function_declaration is None: - raise TypeError("call_function() missing required argument: 'function_declaration'") + raise TypeError("call_function() missing required argument: {{snake_param!r}}") if await_promise is None: - raise TypeError("call_function() missing required argument: 'await_promise'") + raise TypeError("call_function() missing required argument: {{snake_param!r}}") if target is None: - raise TypeError("call_function() missing required argument: 'target'") + raise TypeError("call_function() missing required argument: {{snake_param!r}}") params = { "functionDeclaration": function_declaration, @@ -840,14 +857,22 @@ def call_function(self, function_declaration: Any | None = None, await_promise: result = self._conn.execute(cmd) return result - def evaluate(self, expression: Any | None = None, target: Any | None = None, await_promise: bool | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, user_activation: bool | None = None): + def evaluate( + self, + expression: Any | None = None, + target: Any | None = None, + await_promise: bool | None = None, + result_ownership: Any | None = None, + serialization_options: Any | None = None, + user_activation: bool | None = None, + ): """Execute script.evaluate.""" if expression is None: - raise TypeError("evaluate() missing required argument: 'expression'") + raise TypeError("evaluate() missing required argument: {{snake_param!r}}") if target is None: - raise TypeError("evaluate() missing required argument: 'target'") + raise TypeError("evaluate() missing required argument: {{snake_param!r}}") if await_promise is None: - raise TypeError("evaluate() missing required argument: 'await_promise'") + raise TypeError("evaluate() missing required argument: {{snake_param!r}}") params = { "expression": expression, @@ -876,7 +901,7 @@ def get_realms(self, context: Any | None = None, type: Any | None = None): def remove_preload_script(self, script: Any | None = None): """Execute script.removePreloadScript.""" if script is None: - raise TypeError("remove_preload_script() missing required argument: 'script'") + raise TypeError("remove_preload_script() missing required argument: {{snake_param!r}}") params = { "script": script, @@ -889,11 +914,11 @@ def remove_preload_script(self, script: Any | None = None): def message(self, channel: Any | None = None, data: Any | None = None, source: Any | None = None): """Execute script.message.""" if channel is None: - raise TypeError("message() missing required argument: 'channel'") + raise TypeError("message() missing required argument: {{snake_param!r}}") if data is None: - raise TypeError("message() missing required argument: 'data'") + raise TypeError("message() missing required argument: {{snake_param!r}}") if source is None: - raise TypeError("message() missing required argument: 'source'") + raise TypeError("message() missing required argument: {{snake_param!r}}") params = { "channel": channel, @@ -1314,6 +1339,14 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Script.EVENT_CONFIGS = { - "realm_created": (EventConfig("realm_created", "script.realmCreated", _globals.get("RealmCreated", dict)) if _globals.get("RealmCreated") else EventConfig("realm_created", "script.realmCreated", dict)), - "realm_destroyed": (EventConfig("realm_destroyed", "script.realmDestroyed", _globals.get("RealmDestroyed", dict)) if _globals.get("RealmDestroyed") else EventConfig("realm_destroyed", "script.realmDestroyed", dict)), + "realm_created": EventConfig( + "realm_created", + "script.realmCreated", + _globals.get("RealmCreated", dict) if _globals.get("RealmCreated") else dict, + ), + "realm_destroyed": EventConfig( + "realm_destroyed", + "script.realmDestroyed", + _globals.get("RealmDestroyed", dict) if _globals.get("RealmDestroyed") else dict, + ), } diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index 374375a62f2ec..c7dd45ec824b8 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -195,7 +195,7 @@ def status(self): def new(self, capabilities: Any | None = None): """Execute session.new.""" if capabilities is None: - raise TypeError("new() missing required argument: 'capabilities'") + raise TypeError("new() missing required argument: {{snake_param!r}}") params = { "capabilities": capabilities, @@ -214,10 +214,15 @@ def end(self): result = self._conn.execute(cmd) return result - def subscribe(self, events: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def subscribe( + self, + events: List[Any] | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute session.subscribe.""" if events is None: - raise TypeError("subscribe() missing required argument: 'events'") + raise TypeError("subscribe() missing required argument: {{snake_param!r}}") params = { "events": events, diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 8742dc61ebccf..267569f782289 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -235,42 +235,6 @@ class Storage: def __init__(self, conn) -> None: self._conn = conn - def get_cookies(self, filter: Any | None = None, partition: Any | None = None): - """Execute storage.getCookies.""" - params = { - "filter": filter, - "partition": partition, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("storage.getCookies", params) - result = self._conn.execute(cmd) - return result - - def set_cookie(self, cookie: Any | None = None, partition: Any | None = None): - """Execute storage.setCookie.""" - if cookie is None: - raise TypeError("set_cookie() missing required argument: 'cookie'") - - params = { - "cookie": cookie, - "partition": partition, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("storage.setCookie", params) - result = self._conn.execute(cmd) - return result - - def delete_cookies(self, filter: Any | None = None, partition: Any | None = None): - """Execute storage.deleteCookies.""" - params = { - "filter": filter, - "partition": partition, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("storage.deleteCookies", params) - result = self._conn.execute(cmd) - return result - def get_cookies(self, filter=None, partition=None): """Execute storage.getCookies and return a GetCookiesResult.""" if filter and hasattr(filter, "to_bidi_dict"): From 924a4164af07ab8cbb248c142b148c08f05dcfd4 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 11 Mar 2026 12:44:03 +0000 Subject: [PATCH 63/67] make sure not to generate F401 ruff errors --- py/generate_bidi.py | 55 +++++++++++-------- py/selenium/webdriver/common/bidi/browser.py | 7 +-- .../webdriver/common/bidi/browsing_context.py | 15 +++-- py/selenium/webdriver/common/bidi/common.py | 7 ++- .../webdriver/common/bidi/emulation.py | 29 +++++----- py/selenium/webdriver/common/bidi/input.py | 17 +++--- py/selenium/webdriver/common/bidi/log.py | 11 ++-- py/selenium/webdriver/common/bidi/network.py | 43 +++++++-------- .../webdriver/common/bidi/permissions.py | 10 ++-- py/selenium/webdriver/common/bidi/script.py | 27 ++++----- py/selenium/webdriver/common/bidi/session.py | 15 +++-- py/selenium/webdriver/common/bidi/storage.py | 9 ++- .../webdriver/common/bidi/webextension.py | 7 +-- 13 files changed, 126 insertions(+), 126 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index affd0a63a750c..8372d25743c08 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -42,7 +42,7 @@ # WebDriver BiDi module: {{}} from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from typing import Any from .common import command_builder """ @@ -123,10 +123,10 @@ def get_annotation(cls, cddl_type: str) -> str: if cddl_type.startswith("["): # Array inner = cddl_type.strip("[]+ ") inner_type = cls.get_annotation(inner) - return f"List[{inner_type}]" + return f"list[{inner_type}]" if cddl_type.startswith("{"): # Map/Dict - return "Dict[str, Any]" + return "dict[str, Any]" # Default to Any for unknown types return "Any" @@ -171,7 +171,9 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: if param_strs: # Check if full signature would exceed line length limit (120 chars) - single_line_signature = f" def {method_name}(self, {', '.join(param_strs)}):" + single_line_signature = ( + f" def {method_name}(self, {', '.join(param_strs)}):" + ) if len(single_line_signature) > 120: # Format parameters on multiple lines body = f" def {method_name}(\n" @@ -197,7 +199,9 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: if param_name in self.required_params: body += f" if {snake_param} is None:\n" msg = f"{method_snake}() missing required argument:" - body += f' raise TypeError("{msg} {{{{snake_param!r}}}}")\n' + body += ( + f' raise TypeError("{msg} {{{{snake_param!r}}}}")\n' + ) body += "\n" # Add validation if specified in enhancements (for additional business logic validation) @@ -585,18 +589,23 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: enhancements = enhancements or {} code = MODULE_HEADER.format(self.name) + # Collect needed imports to avoid duplicates + needs_dataclass = self.commands or self.types or self.events + needs_field = self.types + needs_threading = self.events + needs_callable = self.events + needs_session = self.events + # Add imports if needed - if self.types: - code += "from dataclasses import field\n" - if self.commands or self.types: - code += "from typing import Generator\n" + if needs_dataclass: code += "from dataclasses import dataclass\n" - - # Add imports for event handling if needed - if self.events: + if needs_field: + code += "from dataclasses import field\n" + if needs_threading: code += "import threading\n" + if needs_callable: code += "from collections.abc import Callable\n" - code += "from dataclasses import dataclass\n" + if needs_session: code += "from selenium.webdriver.common.bidi.session import Session\n" code += "\n\n" @@ -680,7 +689,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: # Generate enums first (excluding those in exclude_types) exclude_types = set(enhancements.get("exclude_types", [])) - + # Also exclude any types that have extra_dataclasses overrides # Extract class names from extra_dataclasses strings for extra_cls in enhancements.get("extra_dataclasses", []): @@ -688,7 +697,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: match = re.search(r"class\s+(\w+)\s*:", extra_cls) if match: exclude_types.add(match.group(1)) - + for enum_def in self.enums: if enum_def.name in exclude_types: continue @@ -968,7 +977,7 @@ def clear_event_handlers(self) -> None: # Generate command methods exclude_methods = enhancements.get("exclude_methods", []) - + # Automatically exclude methods that are defined in extra_methods # to prevent generating duplicates if "extra_methods" in enhancements: @@ -977,7 +986,7 @@ def clear_event_handlers(self) -> None: match = re.search(r"def\s+(\w+)\s*\(", extra_method) if match: exclude_methods = list(exclude_methods) + [match.group(1)] - + if self.commands: for command in self.commands: # Get method-specific enhancements @@ -1061,23 +1070,23 @@ def clear_event_handlers(self) -> None: # Try to get event class from globals, default to dict if not found getter = f'_globals.get("{event_def.name}", dict)' condition = f'_globals.get("{event_def.name}")' - event_class = f'{getter} if {condition} else dict' - + event_class = f"{getter} if {condition} else dict" + # Build the entry line and check if it exceeds 120 chars single_line = ( f' "{event_name}": ' f'EventConfig("{event_name}", "{event_def.method}", {event_class}),' ) - + if len(single_line) > 120: # Break into multiple lines code += f' "{event_name}": EventConfig(\n' code += f' "{event_name}",\n' code += f' "{event_def.method}",\n' - code += f' {event_class},\n' - code += ' ),\n' + code += f" {event_class},\n" + code += " ),\n" else: - code += single_line + '\n' + code += single_line + "\n" # Extra events not in the CDDL spec for extra_evt in enhancements.get("extra_events", []): ek = extra_evt["event_key"] diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 77ae8f0696281..a8fb60c98178d 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: browser from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass def transform_download_params( diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 3f877b06b00ab..777005e0ce4e5 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: browsingContext from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class ReadinessState: """ReadinessState.""" @@ -376,10 +375,10 @@ class DownloadParams: class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" - download_params: "DownloadParams | None" = None + download_params: DownloadParams | None = None @classmethod - def from_json(cls, params: dict) -> "DownloadEndParams": + def from_json(cls, params: dict) -> DownloadEndParams: """Deserialize from BiDi wire-level params dict.""" dp = DownloadParams( status=params.get("status"), diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index d90d8c770263a..d7cb436a08471 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -17,12 +17,13 @@ """Common utilities for BiDi command construction.""" -from typing import Any, Dict, Generator +from collections.abc import Generator +from typing import Any def command_builder( - method: str, params: Dict[str, Any] -) -> Generator[Dict[str, Any], Any, Any]: + method: str, params: dict[str, Any] +) -> Generator[dict[str, Any], Any, Any]: """Build a BiDi command generator. Args: diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index d482fecc755cb..0356372c48f03 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: emulation from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass class ForcedColorsModeTheme: @@ -194,8 +193,8 @@ def __init__(self, conn) -> None: def set_forced_colors_mode_theme_override( self, theme: Any | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setForcedColorsModeThemeOverride.""" if theme is None: @@ -214,8 +213,8 @@ def set_forced_colors_mode_theme_override( def set_locale_override( self, locale: Any | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setLocaleOverride.""" if locale is None: @@ -234,8 +233,8 @@ def set_locale_override( def set_screen_settings_override( self, screen_area: Any | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setScreenSettingsOverride.""" if screen_area is None: @@ -254,8 +253,8 @@ def set_screen_settings_override( def set_viewport_meta_override( self, viewport_meta: Any | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setViewportMetaOverride.""" if viewport_meta is None: @@ -274,8 +273,8 @@ def set_viewport_meta_override( def set_scrollbar_type_override( self, scrollbar_type: Any | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setScrollbarTypeOverride.""" if scrollbar_type is None: @@ -291,7 +290,7 @@ def set_scrollbar_type_override( result = self._conn.execute(cmd) return result - def set_touch_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_touch_override(self, contexts: list[Any] | None = None, user_contexts: list[Any] | None = None): """Execute emulation.setTouchOverride.""" params = { "contexts": contexts, diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 0990dacc39363..7e76cb831543f 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: input from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class PointerType: """PointerType.""" @@ -175,7 +174,7 @@ class FileDialogInfo: multiple: bool | None = None @classmethod - def from_json(cls, params: dict) -> "FileDialogInfo": + def from_json(cls, params: dict) -> FileDialogInfo: """Deserialize event params into FileDialogInfo.""" return cls( context=params.get("context"), @@ -368,7 +367,7 @@ def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - def perform_actions(self, context: Any | None = None, actions: List[Any] | None = None): + def perform_actions(self, context: Any | None = None, actions: list[Any] | None = None): """Execute input.performActions.""" if context is None: raise TypeError("perform_actions() missing required argument: {{snake_param!r}}") @@ -397,7 +396,7 @@ def release_actions(self, context: Any | None = None): result = self._conn.execute(cmd) return result - def set_files(self, context: Any | None = None, element: Any | None = None, files: List[Any] | None = None): + def set_files(self, context: Any | None = None, element: Any | None = None, files: list[Any] | None = None): """Execute input.setFiles.""" if context is None: raise TypeError("set_files() missing required argument: {{snake_param!r}}") diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 07121242348ea..fd712b7c9a8ab 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -6,14 +6,11 @@ # WebDriver BiDi module: log from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable from dataclasses import dataclass +from typing import Any + from selenium.webdriver.common.bidi.session import Session @@ -60,7 +57,7 @@ class ConsoleLogEntry: stack_trace: Any | None = None @classmethod - def from_json(cls, params: dict) -> "ConsoleLogEntry": + def from_json(cls, params: dict) -> ConsoleLogEntry: """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -85,7 +82,7 @@ class JavascriptLogEntry: stacktrace: Any | None = None @classmethod - def from_json(cls, params: dict) -> "JavascriptLogEntry": + def from_json(cls, params: dict) -> JavascriptLogEntry: """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index d7baeb07040ce..74951031c597f 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: network from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class SameSite: """SameSite.""" @@ -565,11 +564,11 @@ def __init__(self, conn) -> None: def add_data_collector( self, - data_types: List[Any] | None = None, + data_types: list[Any] | None = None, max_encoded_data_size: Any | None = None, collector_type: Any | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute network.addDataCollector.""" if data_types is None: @@ -591,9 +590,9 @@ def add_data_collector( def add_intercept( self, - phases: List[Any] | None = None, - contexts: List[Any] | None = None, - url_patterns: List[Any] | None = None, + phases: list[Any] | None = None, + contexts: list[Any] | None = None, + url_patterns: list[Any] | None = None, ): """Execute network.addIntercept.""" if phases is None: @@ -613,8 +612,8 @@ def continue_request( self, request: Any | None = None, body: Any | None = None, - cookies: List[Any] | None = None, - headers: List[Any] | None = None, + cookies: list[Any] | None = None, + headers: list[Any] | None = None, method: Any | None = None, url: Any | None = None, ): @@ -638,9 +637,9 @@ def continue_request( def continue_response( self, request: Any | None = None, - cookies: List[Any] | None = None, + cookies: list[Any] | None = None, credentials: Any | None = None, - headers: List[Any] | None = None, + headers: list[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None, ): @@ -734,8 +733,8 @@ def provide_response( self, request: Any | None = None, body: Any | None = None, - cookies: List[Any] | None = None, - headers: List[Any] | None = None, + cookies: list[Any] | None = None, + headers: list[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None, ): @@ -782,7 +781,7 @@ def remove_intercept(self, intercept: Any | None = None): result = self._conn.execute(cmd) return result - def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[Any] | None = None): + def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: list[Any] | None = None): """Execute network.setCacheBehavior.""" if cache_behavior is None: raise TypeError("set_cache_behavior() missing required argument: {{snake_param!r}}") @@ -798,9 +797,9 @@ def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[A def set_extra_headers( self, - headers: List[Any] | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + headers: list[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute network.setExtraHeaders.""" if headers is None: diff --git a/py/selenium/webdriver/common/bidi/permissions.py b/py/selenium/webdriver/common/bidi/permissions.py index f00e765c62e3b..6dd138da17309 100644 --- a/py/selenium/webdriver/common/bidi/permissions.py +++ b/py/selenium/webdriver/common/bidi/permissions.py @@ -20,7 +20,7 @@ from __future__ import annotations from enum import Enum -from typing import Any, Optional, Union +from typing import Any from .common import command_builder @@ -63,10 +63,10 @@ def __init__(self, websocket_connection: Any) -> None: def set_permission( self, - descriptor: Union[PermissionDescriptor, str], - state: Union[PermissionState, str], - origin: Optional[str] = None, - user_context: Optional[str] = None, + descriptor: PermissionDescriptor | str, + state: PermissionState | str, + origin: str | None = None, + user_context: str | None = None, ) -> None: """Set a permission for a given origin. diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 8e832f4a9cae9..6c2e4298a2dce 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: script from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class SpecialNumber: """SpecialNumber.""" @@ -786,9 +785,9 @@ def __init__(self, conn, driver=None) -> None: def add_preload_script( self, function_declaration: Any | None = None, - arguments: List[Any] | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + arguments: list[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, sandbox: Any | None = None, ): """Execute script.addPreloadScript.""" @@ -807,7 +806,7 @@ def add_preload_script( result = self._conn.execute(cmd) return result - def disown(self, handles: List[Any] | None = None, target: Any | None = None): + def disown(self, handles: list[Any] | None = None, target: Any | None = None): """Execute script.disown.""" if handles is None: raise TypeError("disown() missing required argument: {{snake_param!r}}") @@ -828,7 +827,7 @@ def call_function( function_declaration: Any | None = None, await_promise: bool | None = None, target: Any | None = None, - arguments: List[Any] | None = None, + arguments: list[Any] | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, this: Any | None = None, @@ -946,8 +945,9 @@ def execute(self, function_declaration: str, *args, context_id: str | None = Non Returns: The inner RemoteValue result dict, or raises WebDriverException on exception. """ - import math as _math import datetime as _datetime + import math as _math + from selenium.common.exceptions import WebDriverException as _WebDriverException def _serialize_arg(value): @@ -1188,8 +1188,9 @@ def _disown(self, handles, target): def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading - from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod + from selenium.webdriver.common.bidi.session import Session as _Session bidi_event = "log.entryAdded" diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index c7dd45ec824b8..fcb42a4ad86fc 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: session from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass class UserPromptHandlerType: @@ -216,9 +215,9 @@ def end(self): def subscribe( self, - events: List[Any] | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + events: list[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute session.subscribe.""" if events is None: @@ -234,7 +233,7 @@ def subscribe( result = self._conn.execute(cmd) return result - def unsubscribe(self, events: List[Any] | None = None, subscriptions: List[Any] | None = None): + def unsubscribe(self, events: list[Any] | None = None, subscriptions: list[Any] | None = None): """Execute session.unsubscribe.""" params = { "events": events, diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 267569f782289..089cee2c4fbdf 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: storage from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass @dataclass @@ -107,7 +106,7 @@ class StorageCookie: expiry: Any | None = None @classmethod - def from_bidi_dict(cls, raw: dict) -> "StorageCookie": + def from_bidi_dict(cls, raw: dict) -> StorageCookie: """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index e007f8e4792a6..b1bc09452bc63 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: webExtension from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass @dataclass From 3578dd386ca4071b93ff4b7f7b342c1c07b57ad7 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 11 Mar 2026 19:36:16 +0000 Subject: [PATCH 64/67] ruffs and mypy fixes --- py/generate_bidi.py | 51 +++++++++++++------ py/private/bidi_enhancements_manifest.py | 33 ++++++------ .../webdriver/common/bidi/browsing_context.py | 2 +- py/selenium/webdriver/common/bidi/common.py | 5 +- .../webdriver/common/bidi/emulation.py | 12 ++--- py/selenium/webdriver/common/bidi/input.py | 2 +- py/selenium/webdriver/common/bidi/log.py | 2 +- py/selenium/webdriver/common/bidi/network.py | 8 +-- py/selenium/webdriver/common/bidi/script.py | 2 +- py/selenium/webdriver/common/bidi/storage.py | 4 +- .../webdriver/common/bidi/webextension.py | 11 ++-- 11 files changed, 80 insertions(+), 52 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 8372d25743c08..ce29235456e48 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python3 +#!/usr/bin/env python3.10 """ Generate Python WebDriver BiDi command modules from CDDL specification. @@ -43,7 +43,6 @@ from __future__ import annotations from typing import Any -from .common import command_builder """ @@ -590,17 +589,17 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: code = MODULE_HEADER.format(self.name) # Collect needed imports to avoid duplicates + needs_command_builder = bool(self.commands) needs_dataclass = self.commands or self.types or self.events - needs_field = self.types needs_threading = self.events needs_callable = self.events needs_session = self.events - # Add imports if needed + # Add imports (field import will be added conditionally after code generation) + if needs_command_builder: + code += "from .common import command_builder\n" if needs_dataclass: code += "from dataclasses import dataclass\n" - if needs_field: - code += "from dataclasses import field\n" if needs_threading: code += "import threading\n" if needs_callable: @@ -954,7 +953,7 @@ def clear_event_handlers(self) -> None: # Add EVENT_CONFIGS dict if there are events if self.events: code += ( - " EVENT_CONFIGS = {}\n" # Will be populated after types are defined + " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined ) if self.name == "script": @@ -1095,6 +1094,26 @@ def clear_event_handlers(self) -> None: code += f' "{ek}": EventConfig("{ek}", "{be}", _globals.get("{ec}", dict)),\n' code += "}\n" + # Check if field() is actually used in the generated code + # If so, add the field import after the dataclass import + if "field(" in code: + # Find where to insert the field import + # It should go after "from dataclasses import dataclass" line + dataclass_import_pattern = r"from dataclasses import dataclass\n" + if re.search(dataclass_import_pattern, code): + code = re.sub( + dataclass_import_pattern, + "from dataclasses import dataclass\nfrom dataclasses import field\n", + code, + count=1 + ) + elif "from dataclasses import" not in code: + # If there's no dataclasses import yet, add field import after typing + code = code.replace( + "from typing import Any\n", + "from typing import Any\nfrom dataclasses import field\n" + ) + return code @@ -1634,12 +1653,14 @@ def generate_common_file(output_path: Path) -> None: "\n" '"""Common utilities for BiDi command construction."""\n' "\n" - "from typing import Any, Dict, Generator\n" + "from __future__ import annotations\n" + "\n" + "from typing import Any\n" "\n" "\n" "def command_builder(\n" - " method: str, params: Dict[str, Any]\n" - ") -> Generator[Dict[str, Any], Any, Any]:\n" + " method: str, params: dict[str, Any]\n" + ") -> dict[str, Any]:\n" ' """Build a BiDi command generator.\n' "\n" " Args:\n" @@ -1726,7 +1747,7 @@ def generate_permissions_file(output_path: Path) -> None: "from __future__ import annotations\n" "\n" "from enum import Enum\n" - "from typing import Any, Optional, Union\n" + "from typing import Any\n" "\n" "from .common import command_builder\n" "\n" @@ -1769,10 +1790,10 @@ def generate_permissions_file(output_path: Path) -> None: "\n" " def set_permission(\n" " self,\n" - " descriptor: Union[PermissionDescriptor, str],\n" - " state: Union[PermissionState, str],\n" - " origin: Optional[str] = None,\n" - " user_context: Optional[str] = None,\n" + " descriptor: PermissionDescriptor | str,\n" + " state: PermissionState | str,\n" + " origin: str | None = None,\n" + " user_context: str | None = None,\n" " ) -> None:\n" ' """Set a permission for a given origin.\n' "\n" diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 40647157f8535..647dd7bcfd892 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -338,7 +338,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {} + params: dict[str, Any] = {} if coordinates is not None: if isinstance(coordinates, dict): coords_dict = coordinates @@ -390,7 +390,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {"timezone": timezone} + params: dict[str, Any] = {"timezone": timezone} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -414,7 +414,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {"enabled": enabled} + params: dict[str, Any] = {"enabled": enabled} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -437,7 +437,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {"userAgent": user_agent} + params: dict[str, Any] = {"userAgent": user_agent} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -473,7 +473,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": "natural": natural.lower() if isinstance(natural, str) else natural, "type": orientation_type.lower() if isinstance(orientation_type, str) else orientation_type, } - params = {"screenOrientation": so_value} + params: dict[str, Any] = {"screenOrientation": so_value} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -506,7 +506,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": nc_value = {"type": "offline"} if offline else None else: nc_value = network_conditions - params = {"networkConditions": nc_value} + params: dict[str, Any] = {"networkConditions": nc_value} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -893,8 +893,8 @@ def from_json(self2, p): "network": { # Initialize intercepts tracking list and per-handler intercept map "extra_init_code": [ - "self.intercepts = []", - "self._handler_intercepts: dict = {}", + "self.intercepts: list[Any] = []", + "self._handler_intercepts: dict[str, Any] = {}", ], # Request class wraps a beforeRequestSent event params and provides actions "extra_dataclasses": [ @@ -908,7 +908,7 @@ def from_json(self2, p): TYPE_STRING = "string" TYPE_BASE64 = "base64" - def __init__(self, type: str, value: str) -> None: + def __init__(self, type: Any | None, value: Any | None) -> None: self.type = type self.value = value @@ -1089,7 +1089,7 @@ def _auth_callback(params): TYPE_STRING = "string" TYPE_BASE64 = "base64" - def __init__(self, type: str, value: str) -> None: + def __init__(self, type: Any | None, value: Any | None) -> None: self.type = type self.value = value @@ -1122,7 +1122,7 @@ def from_bidi_dict(cls, raw: dict) -> "StorageCookie": """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): - value = BytesValue(value_raw.get("type"), value_raw.get("value")) + value: Any = BytesValue(value_raw.get("type"), value_raw.get("value")) else: value = value_raw return cls( @@ -1379,6 +1379,7 @@ def to_bidi_dict(self) -> dict: elif archive_path is not None: extension_data = {"type": "archivePath", "path": archive_path} else: + assert base64_value is not None extension_data = {"type": "base64", "value": base64_value} params = {"extensionData": extension_data} cmd = command_builder("webExtension.install", params) @@ -1395,12 +1396,14 @@ def to_bidi_dict(self) -> dict: ValueError: If extension is not provided or is None. """ if isinstance(extension, dict): - extension = extension.get("extension") + extension_id: Any = extension.get("extension") + else: + extension_id = extension - if extension is None: + if extension_id is None: raise ValueError("extension parameter is required") - - params = {"extension": extension} + + params = {"extension": extension_id} cmd = command_builder("webExtension.uninstall", params) return self._conn.execute(cmd)''', ], diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 777005e0ce4e5..5b1a67ce93f11 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -598,7 +598,7 @@ def clear_event_handlers(self) -> None: class BrowsingContext: """WebDriver BiDi browsingContext module.""" - EVENT_CONFIGS = {} + EVENT_CONFIGS: dict[str, EventConfig] = {} def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index d7cb436a08471..168f748d5501b 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -17,13 +17,14 @@ """Common utilities for BiDi command construction.""" -from collections.abc import Generator +from __future__ import annotations + from typing import Any def command_builder( method: str, params: dict[str, Any] -) -> Generator[dict[str, Any], Any, Any]: +) -> dict[str, Any]: """Build a BiDi command generator. Args: diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 0356372c48f03..3dcf8e58881e4 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -320,7 +320,7 @@ def set_geolocation_override( contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {} + params: dict[str, Any] = {} if coordinates is not None: if isinstance(coordinates, dict): coords_dict = coordinates @@ -372,7 +372,7 @@ def set_timezone_override( contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {"timezone": timezone} + params: dict[str, Any] = {"timezone": timezone} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -396,7 +396,7 @@ def set_scripting_enabled( contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {"enabled": enabled} + params: dict[str, Any] = {"enabled": enabled} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -419,7 +419,7 @@ def set_user_agent_override( contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {"userAgent": user_agent} + params: dict[str, Any] = {"userAgent": user_agent} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -455,7 +455,7 @@ def set_screen_orientation_override( "natural": natural.lower() if isinstance(natural, str) else natural, "type": orientation_type.lower() if isinstance(orientation_type, str) else orientation_type, } - params = {"screenOrientation": so_value} + params: dict[str, Any] = {"screenOrientation": so_value} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -488,7 +488,7 @@ def set_network_conditions( nc_value = {"type": "offline"} if offline else None else: nc_value = network_conditions - params = {"networkConditions": nc_value} + params: dict[str, Any] = {"networkConditions": nc_value} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 7e76cb831543f..1d4730534f16d 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -362,7 +362,7 @@ def clear_event_handlers(self) -> None: class Input: """WebDriver BiDi input module.""" - EVENT_CONFIGS = {} + EVENT_CONFIGS: dict[str, EventConfig] = {} def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index fd712b7c9a8ab..488f0740a40b5 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -256,7 +256,7 @@ def clear_event_handlers(self) -> None: class Log: """WebDriver BiDi log module.""" - EVENT_CONFIGS = {} + EVENT_CONFIGS: dict[str, EventConfig] = {} def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 74951031c597f..30de3306ff001 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -368,7 +368,7 @@ class BytesValue: TYPE_STRING = "string" TYPE_BASE64 = "base64" - def __init__(self, type: str, value: str) -> None: + def __init__(self, type: Any | None, value: Any | None) -> None: self.type = type self.value = value @@ -555,12 +555,12 @@ def clear_event_handlers(self) -> None: class Network: """WebDriver BiDi network module.""" - EVENT_CONFIGS = {} + EVENT_CONFIGS: dict[str, EventConfig] = {} def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - self.intercepts = [] - self._handler_intercepts: dict = {} + self.intercepts: list[Any] = [] + self._handler_intercepts: dict[str, Any] = {} def add_data_collector( self, diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 6c2e4298a2dce..221b5963e8ec1 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -776,7 +776,7 @@ def clear_event_handlers(self) -> None: class Script: """WebDriver BiDi script module.""" - EVENT_CONFIGS = {} + EVENT_CONFIGS: dict[str, EventConfig] = {} def __init__(self, conn, driver=None) -> None: self._conn = conn self._driver = driver diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 089cee2c4fbdf..a2606526f3856 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -76,7 +76,7 @@ class BytesValue: TYPE_STRING = "string" TYPE_BASE64 = "base64" - def __init__(self, type: str, value: str) -> None: + def __init__(self, type: Any | None, value: Any | None) -> None: self.type = type self.value = value @@ -110,7 +110,7 @@ def from_bidi_dict(cls, raw: dict) -> StorageCookie: """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): - value = BytesValue(value_raw.get("type"), value_raw.get("value")) + value: Any = BytesValue(value_raw.get("type"), value_raw.get("value")) else: value = value_raw return cls( diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index b1bc09452bc63..70a21d7fd5e5e 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -100,6 +100,7 @@ def install( elif archive_path is not None: extension_data = {"type": "archivePath", "path": archive_path} else: + assert base64_value is not None extension_data = {"type": "base64", "value": base64_value} params = {"extensionData": extension_data} cmd = command_builder("webExtension.install", params) @@ -116,11 +117,13 @@ def uninstall(self, extension: str | dict): ValueError: If extension is not provided or is None. """ if isinstance(extension, dict): - extension = extension.get("extension") + extension_id: Any = extension.get("extension") + else: + extension_id = extension - if extension is None: + if extension_id is None: raise ValueError("extension parameter is required") - - params = {"extension": extension} + + params = {"extension": extension_id} cmd = command_builder("webExtension.uninstall", params) return self._conn.execute(cmd) From ca22cfa080f2d5c68ce0d9c3165cae3501c3da48 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 11 Mar 2026 19:43:50 +0000 Subject: [PATCH 65/67] fix linting --- py/generate_bidi.py | 12 +++++------- py/private/bidi_enhancements_manifest.py | 2 +- py/selenium/webdriver/common/bidi/common.py | 3 ++- py/selenium/webdriver/common/bidi/webextension.py | 2 +- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index ce29235456e48..de41855954651 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -952,9 +952,7 @@ def clear_event_handlers(self) -> None: # Add EVENT_CONFIGS dict if there are events if self.events: - code += ( - " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined - ) + code += " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined if self.name == "script": code += " def __init__(self, conn, driver=None) -> None:\n" @@ -1105,13 +1103,13 @@ def clear_event_handlers(self) -> None: dataclass_import_pattern, "from dataclasses import dataclass\nfrom dataclasses import field\n", code, - count=1 + count=1, ) elif "from dataclasses import" not in code: # If there's no dataclasses import yet, add field import after typing code = code.replace( "from typing import Any\n", - "from typing import Any\nfrom dataclasses import field\n" + "from typing import Any\nfrom dataclasses import field\n", ) return code @@ -1655,12 +1653,12 @@ def generate_common_file(output_path: Path) -> None: "\n" "from __future__ import annotations\n" "\n" - "from typing import Any\n" + "from typing import Any, Generator\n" "\n" "\n" "def command_builder(\n" " method: str, params: dict[str, Any]\n" - ") -> dict[str, Any]:\n" + ") -> Generator[dict[str, Any], Any, Any]:\n" ' """Build a BiDi command generator.\n' "\n" " Args:\n" diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 647dd7bcfd892..d9923531b0293 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -1402,7 +1402,7 @@ def to_bidi_dict(self) -> dict: if extension_id is None: raise ValueError("extension parameter is required") - + params = {"extension": extension_id} cmd = command_builder("webExtension.uninstall", params) return self._conn.execute(cmd)''', diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index 168f748d5501b..59e8afd93ab2e 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -19,12 +19,13 @@ from __future__ import annotations +from collections.abc import Generator from typing import Any def command_builder( method: str, params: dict[str, Any] -) -> dict[str, Any]: +) -> Generator[dict[str, Any], Any, Any]: """Build a BiDi command generator. Args: diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 70a21d7fd5e5e..b5881d01e0bea 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -123,7 +123,7 @@ def uninstall(self, extension: str | dict): if extension_id is None: raise ValueError("extension parameter is required") - + params = {"extension": extension_id} cmd = command_builder("webExtension.uninstall", params) return self._conn.execute(cmd) From 5319a49a8a62379eaf6a335294483480d4f7f1f3 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Fri, 13 Mar 2026 12:12:19 +0000 Subject: [PATCH 66/67] Fix auth tests --- py/generate_bidi.py | 12 +++++++----- py/private/bidi_enhancements_manifest.py | 20 +++++++++++++++++--- py/selenium/webdriver/common/bidi/network.py | 18 ++++++++++++++++-- 3 files changed, 40 insertions(+), 10 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index de41855954651..ce29235456e48 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -952,7 +952,9 @@ def clear_event_handlers(self) -> None: # Add EVENT_CONFIGS dict if there are events if self.events: - code += " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined + code += ( + " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined + ) if self.name == "script": code += " def __init__(self, conn, driver=None) -> None:\n" @@ -1103,13 +1105,13 @@ def clear_event_handlers(self) -> None: dataclass_import_pattern, "from dataclasses import dataclass\nfrom dataclasses import field\n", code, - count=1, + count=1 ) elif "from dataclasses import" not in code: # If there's no dataclasses import yet, add field import after typing code = code.replace( "from typing import Any\n", - "from typing import Any\nfrom dataclasses import field\n", + "from typing import Any\nfrom dataclasses import field\n" ) return code @@ -1653,12 +1655,12 @@ def generate_common_file(output_path: Path) -> None: "\n" "from __future__ import annotations\n" "\n" - "from typing import Any, Generator\n" + "from typing import Any\n" "\n" "\n" "def command_builder(\n" " method: str, params: dict[str, Any]\n" - ") -> Generator[dict[str, Any], Any, Any]:\n" + ") -> dict[str, Any]:\n" ' """Build a BiDi command generator.\n' "\n" " Args:\n" diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index d9923531b0293..d617f7468c034 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -1033,6 +1033,10 @@ def _request_callback(params): """ from selenium.webdriver.common.bidi.common import command_builder as _cb + # Set up network intercept for authRequired phase + intercept_result = self._add_intercept(phases=["authRequired"]) + intercept_id = intercept_result.get("intercept") if intercept_result else None + def _auth_callback(params): raw = ( params @@ -1060,10 +1064,20 @@ def _auth_callback(params): ) ) - return self.add_event_handler("auth_required", _auth_callback)''', + callback_id = self.add_event_handler("auth_required", _auth_callback) + if intercept_id: + self._handler_intercepts[callback_id] = intercept_id + return callback_id''', ''' def remove_auth_handler(self, callback_id): - """Remove an auth handler by callback ID.""" - self.remove_event_handler("auth_required", callback_id)''', + """Remove an auth handler by callback ID and its associated network intercept. + + Args: + callback_id: The handler ID returned by add_auth_handler. + """ + self.remove_event_handler("auth_required", callback_id) + intercept_id = self._handler_intercepts.pop(callback_id, None) + if intercept_id: + self._remove_intercept(intercept_id)''', ], }, "storage": { diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 30de3306ff001..1dd2f5a476049 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -975,6 +975,10 @@ def add_auth_handler(self, username, password): """ from selenium.webdriver.common.bidi.common import command_builder as _cb + # Set up network intercept for authRequired phase + intercept_result = self._add_intercept(phases=["authRequired"]) + intercept_id = intercept_result.get("intercept") if intercept_result else None + def _auth_callback(params): raw = ( params @@ -1002,10 +1006,20 @@ def _auth_callback(params): ) ) - return self.add_event_handler("auth_required", _auth_callback) + callback_id = self.add_event_handler("auth_required", _auth_callback) + if intercept_id: + self._handler_intercepts[callback_id] = intercept_id + return callback_id def remove_auth_handler(self, callback_id): - """Remove an auth handler by callback ID.""" + """Remove an auth handler by callback ID and its associated network intercept. + + Args: + callback_id: The handler ID returned by add_auth_handler. + """ self.remove_event_handler("auth_required", callback_id) + intercept_id = self._handler_intercepts.pop(callback_id, None) + if intercept_id: + self._remove_intercept(intercept_id) def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: """Add an event handler. From 90ca547f2566a2e984b43fea5568390a4a9a7aaa Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Fri, 13 Mar 2026 12:30:15 +0000 Subject: [PATCH 67/67] sort spacing --- py/generate_bidi.py | 3 ++- py/private/bidi_enhancements_manifest.py | 10 ++++++++++ py/selenium/webdriver/common/bidi/browser.py | 4 ++-- .../webdriver/common/bidi/browsing_context.py | 13 ++++++------- py/selenium/webdriver/common/bidi/emulation.py | 4 ++-- py/selenium/webdriver/common/bidi/input.py | 11 +++++------ py/selenium/webdriver/common/bidi/log.py | 9 ++++----- py/selenium/webdriver/common/bidi/network.py | 9 ++++----- py/selenium/webdriver/common/bidi/script.py | 15 ++++++--------- py/selenium/webdriver/common/bidi/session.py | 4 ++-- py/selenium/webdriver/common/bidi/storage.py | 6 +++--- py/selenium/webdriver/common/bidi/webextension.py | 4 ++-- 12 files changed, 48 insertions(+), 44 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index ce29235456e48..32d19ec83cec9 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -1655,12 +1655,13 @@ def generate_common_file(output_path: Path) -> None: "\n" "from __future__ import annotations\n" "\n" + "from collections.abc import Generator\n" "from typing import Any\n" "\n" "\n" "def command_builder(\n" " method: str, params: dict[str, Any]\n" - ") -> dict[str, Any]:\n" + ") -> Generator[dict[str, Any], Any, Any]:\n" ' """Build a BiDi command generator.\n' "\n" " Args:\n" diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index d617f7468c034..dcf464f425e9d 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -37,6 +37,7 @@ # ============================================================================ ENHANCEMENTS: dict[str, dict[str, Any]] = { + "browser": { # Dataclass custom methods "__dataclass_methods__": { @@ -170,6 +171,7 @@ return self._conn.execute(cmd)''', ], }, + "browsingContext": { # Method enhancements "create": { @@ -254,6 +256,7 @@ def from_json(cls, params: dict) -> "DownloadEndParams": ], # Download events are now in the CDDL spec, so no extra_events needed }, + "log": { # Make LogLevel an alias for Level so existing code using LogLevel works "aliases": {"LogLevel": "Level"}, @@ -317,6 +320,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": "entry_added": "Entry", }, }, + "emulation": { "extra_methods": [ ''' def set_geolocation_override( @@ -515,6 +519,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": return self._conn.execute(cmd)''', ], }, + "script": { "extra_methods": [ ''' def execute(self, function_declaration: str, *args, context_id: str | None = None) -> Any: @@ -890,6 +895,7 @@ def from_json(self2, p): self._unsubscribe_log_entry(callback_id)''', ], }, + "network": { # Initialize intercepts tracking list and per-handler intercept map "extra_init_code": [ @@ -1080,6 +1086,7 @@ def _auth_callback(params): self._remove_intercept(intercept_id)''', ], }, + "storage": { # Exclude auto-generated dataclasses that need custom to_bidi_dict() # for JSON-over-WebSocket serialization, or custom constructors. @@ -1319,6 +1326,7 @@ def to_bidi_dict(self) -> dict: return result''', ], }, + "session": { # Override UserPromptHandler to add to_bidi_dict() for JSON serialization "exclude_types": ["UserPromptHandler"], @@ -1352,6 +1360,7 @@ def to_bidi_dict(self) -> dict: return result''', ], }, + "webExtension": { # Suppress the raw generated stubs; hand-written versions follow below "exclude_methods": ["install", "uninstall"], @@ -1422,6 +1431,7 @@ def to_bidi_dict(self) -> dict: return self._conn.execute(cmd)''', ], }, + "input": { # FileDialogInfo needs from_json for event deserialization "exclude_types": ["FileDialogInfo", "PointerMoveAction", "PointerDownAction"], diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index a8fb60c98178d..a4ec770fbb135 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: browser from __future__ import annotations -from dataclasses import dataclass, field from typing import Any - from .common import command_builder +from dataclasses import dataclass +from dataclasses import field def transform_download_params( diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 5b1a67ce93f11..c5489ce865180 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -6,15 +6,14 @@ # WebDriver BiDi module: browsingContext from __future__ import annotations +from typing import Any +from .common import command_builder +from dataclasses import dataclass +from dataclasses import field import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class ReadinessState: """ReadinessState.""" @@ -375,10 +374,10 @@ class DownloadParams: class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" - download_params: DownloadParams | None = None + download_params: "DownloadParams | None" = None @classmethod - def from_json(cls, params: dict) -> DownloadEndParams: + def from_json(cls, params: dict) -> "DownloadEndParams": """Deserialize from BiDi wire-level params dict.""" dp = DownloadParams( status=params.get("status"), diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 3dcf8e58881e4..03347a0a85c04 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: emulation from __future__ import annotations -from dataclasses import dataclass, field from typing import Any - from .common import command_builder +from dataclasses import dataclass +from dataclasses import field class ForcedColorsModeTheme: diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 1d4730534f16d..44fd3c82c3407 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -6,15 +6,14 @@ # WebDriver BiDi module: input from __future__ import annotations +from typing import Any +from .common import command_builder +from dataclasses import dataclass +from dataclasses import field import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class PointerType: """PointerType.""" @@ -174,7 +173,7 @@ class FileDialogInfo: multiple: bool | None = None @classmethod - def from_json(cls, params: dict) -> FileDialogInfo: + def from_json(cls, params: dict) -> "FileDialogInfo": """Deserialize event params into FileDialogInfo.""" return cls( context=params.get("context"), diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 488f0740a40b5..3c6a95d74f6d1 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: log from __future__ import annotations +from typing import Any +from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass -from typing import Any - from selenium.webdriver.common.bidi.session import Session @@ -57,7 +56,7 @@ class ConsoleLogEntry: stack_trace: Any | None = None @classmethod - def from_json(cls, params: dict) -> ConsoleLogEntry: + def from_json(cls, params: dict) -> "ConsoleLogEntry": """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -82,7 +81,7 @@ class JavascriptLogEntry: stacktrace: Any | None = None @classmethod - def from_json(cls, params: dict) -> JavascriptLogEntry: + def from_json(cls, params: dict) -> "JavascriptLogEntry": """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 1dd2f5a476049..6a0edf0b2b5e7 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -6,15 +6,14 @@ # WebDriver BiDi module: network from __future__ import annotations +from typing import Any +from .common import command_builder +from dataclasses import dataclass +from dataclasses import field import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class SameSite: """SameSite.""" diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 221b5963e8ec1..5a7d2792a1221 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -6,15 +6,14 @@ # WebDriver BiDi module: script from __future__ import annotations +from typing import Any +from .common import command_builder +from dataclasses import dataclass +from dataclasses import field import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class SpecialNumber: """SpecialNumber.""" @@ -945,9 +944,8 @@ def execute(self, function_declaration: str, *args, context_id: str | None = Non Returns: The inner RemoteValue result dict, or raises WebDriverException on exception. """ - import datetime as _datetime import math as _math - + import datetime as _datetime from selenium.common.exceptions import WebDriverException as _WebDriverException def _serialize_arg(value): @@ -1188,9 +1186,8 @@ def _disown(self, handles, target): def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading - - from selenium.webdriver.common.bidi import log as _log_mod from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod bidi_event = "log.entryAdded" diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index fcb42a4ad86fc..177421eca5ee8 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: session from __future__ import annotations -from dataclasses import dataclass, field from typing import Any - from .common import command_builder +from dataclasses import dataclass +from dataclasses import field class UserPromptHandlerType: diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index a2606526f3856..fef35106c33b0 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: storage from __future__ import annotations -from dataclasses import dataclass, field from typing import Any - from .common import command_builder +from dataclasses import dataclass +from dataclasses import field @dataclass @@ -106,7 +106,7 @@ class StorageCookie: expiry: Any | None = None @classmethod - def from_bidi_dict(cls, raw: dict) -> StorageCookie: + def from_bidi_dict(cls, raw: dict) -> "StorageCookie": """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index b5881d01e0bea..1c5b342c070d5 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: webExtension from __future__ import annotations -from dataclasses import dataclass, field from typing import Any - from .common import command_builder +from dataclasses import dataclass +from dataclasses import field @dataclass